Skip to main content

opencv_binding_generator/writer/rust_native/
smart_ptr.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::LazyLock;
4
5use super::class::ClassExt;
6use super::element::RustElement;
7use super::type_ref::TypeRefExt;
8use super::RustNativeGeneratedElement;
9use crate::class::ClassDesc;
10use crate::field::{Field, FieldDesc};
11use crate::func::{FuncCppBody, FuncDesc, FuncKind, ReturnKind};
12use crate::smart_ptr::SmartPtrDesc;
13use crate::type_ref::{Constness, CppNameStyle, FishStyle, NameStyle, TypeRef, TypeRefKind};
14use crate::writer::rust_native::class::rust_generate_debug_fields;
15use crate::{Class, CompiledInterpolation, Element, Func, IteratorExt, SmartPtr, StrExt, StringExt, SupportedModule};
16
17impl RustElement for SmartPtr<'_, '_> {
18	fn rust_module(&self) -> SupportedModule {
19		self.pointee().rust_module()
20	}
21
22	// fixme, we shouldn't override the rust_module_reference and rely on rust_module to provide the correct module
23	fn rust_module_reference(&self) -> Cow<'_, str> {
24		"core".into()
25	}
26
27	fn rust_name(&self, style: NameStyle) -> Cow<'_, str> {
28		let decl_name = self.rust_leafname(style.turbo_fish_style());
29		match style {
30			NameStyle::Declaration => decl_name,
31			NameStyle::Reference(_) => {
32				let mut out = self.rust_module_reference().into_owned();
33				out.extend_sep("::", &decl_name);
34				out.into()
35			}
36		}
37	}
38
39	fn rust_leafname(&self, fish_style: FishStyle) -> Cow<'_, str> {
40		format!(
41			"Ptr{fish}<{typ}>",
42			fish = fish_style.rust_qual(),
43			typ = self.pointee().rust_name(NameStyle::ref_()),
44		)
45		.into()
46	}
47}
48
49impl RustNativeGeneratedElement for SmartPtr<'_, '_> {
50	fn element_safe_id(&self) -> String {
51		format!("{}-{}", self.rust_module().opencv_name(), self.rust_localalias())
52	}
53
54	fn gen_rust(&self, _opencv_version: &str) -> String {
55		static TPL: LazyLock<CompiledInterpolation> =
56			LazyLock::new(|| include_str!("tpl/smart_ptr/rust.tpl.rs").compile_interpolation());
57
58		static TRAIT_RAW_TPL: LazyLock<CompiledInterpolation> =
59			LazyLock::new(|| include_str!("tpl/smart_ptr/trait_raw.tpl.rs").compile_interpolation());
60
61		static BASE_CAST_TPL: LazyLock<CompiledInterpolation> =
62			LazyLock::new(|| include_str!("tpl/smart_ptr/base_cast.tpl.rs").compile_interpolation());
63
64		static IMPL_DEBUG_TPL: LazyLock<CompiledInterpolation> =
65			LazyLock::new(|| include_str!("tpl/smart_ptr/impl_debug.rs").compile_interpolation());
66
67		static CTOR_TPL: LazyLock<CompiledInterpolation> =
68			LazyLock::new(|| include_str!("tpl/smart_ptr/ctor.tpl.rs").compile_interpolation());
69
70		let rust_localalias = self.rust_localalias();
71		let rust_full = self.rust_name(NameStyle::ref_());
72		let pointee_type = self.pointee();
73		let pointee_kind = pointee_type.kind();
74		let inner_rust_full = pointee_type.rust_name(NameStyle::ref_());
75		let type_ref = self.type_ref();
76		let smartptr_class = smartptr_class(&type_ref);
77
78		let extern_get_inner_ptr = method_get_inner_ptr(
79			smartptr_class.clone(),
80			pointee_type.as_ref().clone().with_inherent_constness(Constness::Const),
81		)
82		.identifier();
83		let extern_get_inner_ptr_mut = method_get_inner_ptr(
84			smartptr_class.clone(),
85			pointee_type.as_ref().clone().with_inherent_constness(Constness::Mut),
86		)
87		.identifier();
88
89		let mut impls = String::with_capacity(1024);
90		if let Some(cls) = pointee_kind.as_class().filter(|cls| cls.kind().is_trait()) {
91			TRAIT_RAW_TPL.interpolate_into(
92				&mut impls,
93				&HashMap::from([
94					("rust_full", rust_full.as_ref()),
95					("base_rust_as_raw_const", &cls.rust_as_raw_name(Constness::Const)),
96					("base_rust_as_raw_mut", &cls.rust_as_raw_name(Constness::Mut)),
97					("base_rust_full_mut", &cls.rust_trait_name(NameStyle::ref_(), Constness::Mut)),
98					(
99						"base_rust_full_const",
100						&cls.rust_trait_name(NameStyle::ref_(), Constness::Const),
101					),
102				]),
103			);
104			let mut debug_fields = rust_generate_debug_fields(
105				cls.field_methods(&cls.fields(|f| f.exclude_kind().is_included()), Some(Constness::Const)),
106			);
107			for base in all_bases(&cls) {
108				let base_rust_local = base.rust_name(NameStyle::decl());
109				TRAIT_RAW_TPL.interpolate_into(
110					&mut impls,
111					&HashMap::from([
112						("rust_full", rust_full.as_ref()),
113						("base_rust_as_raw_const", &base.rust_as_raw_name(Constness::Const)),
114						("base_rust_as_raw_mut", &base.rust_as_raw_name(Constness::Mut)),
115						("base_rust_full_mut", &base.rust_trait_name(NameStyle::ref_(), Constness::Mut)),
116						(
117							"base_rust_full_const",
118							&base.rust_trait_name(NameStyle::ref_(), Constness::Const),
119						),
120					]),
121				);
122
123				let extern_cast_to_base = method_cast_to_base(smartptr_class.clone(), base.type_ref(), &base_rust_local).identifier();
124				BASE_CAST_TPL.interpolate_into(
125					&mut impls,
126					&HashMap::from([
127						("rust_full", rust_full.as_ref()),
128						("base_rust_full_ref", &base.rust_name(NameStyle::ref_())),
129						("extern_cast_to_base", &extern_cast_to_base),
130					]),
131				);
132				let base_fields = base.fields(|f| f.exclude_kind().is_included());
133				let base_field_const_methods = base.field_methods(&base_fields, Some(Constness::Const));
134				debug_fields.push_str(&rust_generate_debug_fields(base_field_const_methods));
135			}
136			IMPL_DEBUG_TPL.interpolate_into(
137				&mut impls,
138				&HashMap::from([
139					("rust_full", rust_full.as_ref()),
140					("rust_localalias", &rust_localalias),
141					("debug_fields", &debug_fields),
142				]),
143			);
144		};
145
146		let rust_as_raw_const = type_ref.rust_as_raw_name(Constness::Const);
147		let rust_as_raw_mut = type_ref.rust_as_raw_name(Constness::Mut);
148
149		let ctor = if gen_ctor(&pointee_kind) {
150			let extern_new = method_new(smartptr_class.clone(), type_ref.clone(), pointee_type.as_ref().clone()).identifier();
151			CTOR_TPL.interpolate(&HashMap::from([
152				("inner_rust_full", inner_rust_full.as_ref()),
153				("extern_new", &extern_new),
154			]))
155		} else {
156			"".to_string()
157		};
158
159		let extern_new_null = method_new_null(smartptr_class.clone(), type_ref).identifier();
160		let extern_delete = FuncDesc::method_delete(smartptr_class).identifier();
161		TPL.interpolate(&HashMap::from([
162			("rust_localalias", rust_localalias.as_ref()),
163			("rust_as_raw_const", &rust_as_raw_const),
164			("rust_as_raw_mut", &rust_as_raw_mut),
165			("rust_full", &rust_full),
166			("inner_rust_full", &inner_rust_full),
167			("extern_new_null", &extern_new_null),
168			("extern_delete", &extern_delete),
169			("extern_get_inner_ptr", &extern_get_inner_ptr),
170			("extern_get_inner_ptr_mut", &extern_get_inner_ptr_mut),
171			("ctor", &ctor),
172			("impls", &impls),
173		]))
174	}
175
176	fn gen_rust_externs(&self) -> String {
177		extern_functions(self).iter().map(Func::gen_rust_externs).join("")
178	}
179
180	fn gen_cpp(&self) -> String {
181		static TPL: LazyLock<CompiledInterpolation> =
182			LazyLock::new(|| include_str!("tpl/smart_ptr/cpp.tpl.cpp").compile_interpolation());
183
184		TPL.interpolate(&[("methods", extern_functions(self).iter().map(Func::gen_cpp).join(""))].into())
185	}
186}
187
188fn extern_functions<'tu, 'ge>(ptr: &SmartPtr<'tu, 'ge>) -> Vec<Func<'tu, 'ge>> {
189	let type_ref = ptr.type_ref();
190	let pointee_type = ptr.pointee();
191	let pointee_kind = pointee_type.kind();
192	let smartptr_class = smartptr_class(&type_ref);
193
194	let mut out = Vec::with_capacity(6);
195	out.push(method_get_inner_ptr(
196		smartptr_class.clone(),
197		pointee_type.as_ref().clone().with_inherent_constness(Constness::Const),
198	));
199	out.push(method_get_inner_ptr(
200		smartptr_class.clone(),
201		pointee_type.as_ref().clone().with_inherent_constness(Constness::Mut),
202	));
203	out.push(method_new_null(smartptr_class.clone(), type_ref.clone()));
204	out.push(FuncDesc::method_delete(smartptr_class.clone()));
205	if let Some(cls) = pointee_kind.as_class().filter(|cls| cls.kind().is_trait()) {
206		for base in all_bases(&cls) {
207			out.push(method_cast_to_base(
208				smartptr_class.clone(),
209				base.type_ref(),
210				&base.rust_name(NameStyle::decl()),
211			));
212		}
213	}
214	if gen_ctor(&pointee_kind) {
215		out.push(method_new(smartptr_class, type_ref, pointee_type.into_owned()));
216	}
217	out
218}
219
220fn gen_ctor(pointee_kind: &TypeRefKind) -> bool {
221	match pointee_kind.canonical().as_ref() {
222		TypeRefKind::Primitive(_, _) => true,
223		TypeRefKind::Class(cls) => !cls.is_abstract(),
224		_ => false,
225	}
226}
227
228fn all_bases<'tu, 'ge>(cls: &Class<'tu, 'ge>) -> Vec<Class<'tu, 'ge>> {
229	let mut out = cls
230		.all_bases()
231		.into_iter()
232		.filter(|b| b.exclude_kind().is_included())
233		.collect::<Vec<_>>();
234	out.sort_unstable_by(|left, right| {
235		left
236			.cpp_name(CppNameStyle::Reference)
237			.cmp(&right.cpp_name(CppNameStyle::Reference))
238	});
239	out
240}
241
242pub trait SmartPtrExt {
243	fn rust_localalias(&self) -> Cow<'_, str>;
244}
245
246impl SmartPtrExt for SmartPtr<'_, '_> {
247	fn rust_localalias(&self) -> Cow<'_, str> {
248		/*
249		let pointee = self.pointee();
250		let pointee_alias = pointee.rust_safe_id(true);
251		let pointee_alias = if let Some(rem) = pointee_alias.strip_prefix("const_") {
252			format!("Const{rem}").into()
253		} else {
254			pointee_alias
255		};
256		format!("PtrOf{pointee_alias}").into()
257		*/
258		// fixme: Not adding const here in rust_safe_id() leads to some smart pointers losing the const qualifier on the internal
259		// type (e.g. cv::Ptr<const cv::optflow::PCAPrior>). If we add it (see commented code above) it leads to problems with casting
260		// because casting doesn't take constness into account. This might not be a problem per se (e.g. if we own Ptr<PCAPrior> there is
261		// no problem to pass it as Ptr<const PCAPrior>), otherwise fix it so that it works with add_const = true in rust_safe_id().
262		format!("PtrOf{typ}", typ = self.pointee().rust_safe_id(false)).into()
263	}
264}
265
266fn smartptr_class<'tu, 'ge>(smart_ptr_type_ref: &TypeRef<'tu, 'ge>) -> Class<'tu, 'ge> {
267	Class::new_desc(ClassDesc::boxed(
268		smart_ptr_type_ref.cpp_name(CppNameStyle::Reference),
269		SupportedModule::Core,
270	))
271}
272
273fn method_new<'tu, 'ge>(
274	smartptr_class: Class<'tu, 'ge>,
275	smartptr_type_ref: TypeRef<'tu, 'ge>,
276	pointee_type: TypeRef<'tu, 'ge>,
277) -> Func<'tu, 'ge> {
278	let pointee_kind = pointee_type.kind();
279	let val = if pointee_kind.is_copy(pointee_type.type_hint()) {
280		if pointee_kind.as_class().is_some_and(|cls| cls.kind().is_simple()) {
281			panic!("Ptr with simple class is not supported");
282		} else {
283			format!("new {typ}(val)", typ = pointee_type.cpp_name(CppNameStyle::Reference)).into()
284		}
285	} else {
286		Cow::Borrowed("val")
287	};
288	Func::new_desc(
289		FuncDesc::new(
290			FuncKind::Constructor(smartptr_class),
291			Constness::Const,
292			ReturnKind::InfallibleNaked,
293			"new",
294			SupportedModule::Core,
295			[Field::new_desc(FieldDesc::new("val", pointee_type))],
296			smartptr_type_ref,
297		)
298		.cpp_body(FuncCppBody::ManualCallReturn(
299			format!("return new {{{{ret_type}}}}({val});").into(),
300		)),
301	)
302}
303
304fn method_new_null<'tu, 'ge>(smartptr_class: Class<'tu, 'ge>, smartptr_type_ref: TypeRef<'tu, 'ge>) -> Func<'tu, 'ge> {
305	Func::new_desc(
306		FuncDesc::new(
307			FuncKind::Constructor(smartptr_class),
308			Constness::Const,
309			ReturnKind::InfallibleNaked,
310			"new_null",
311			SupportedModule::Core,
312			[],
313			smartptr_type_ref,
314		)
315		.cpp_body(FuncCppBody::ManualCallReturn("return new {{ret_type}}();".into())),
316	)
317}
318
319fn method_cast_to_base<'tu, 'ge>(
320	smartptr_class: Class<'tu, 'ge>,
321	base_type_ref: TypeRef<'tu, 'ge>,
322	base_rust_local: &str,
323) -> Func<'tu, 'ge> {
324	let cpp_body = FuncCppBody::ManualCallReturn(
325		format!(
326			"return new {{{{ret_type}}}}(instance->dynamicCast<{base_type}>());",
327			base_type = base_type_ref.cpp_name(CppNameStyle::Reference)
328		)
329		.into(),
330	);
331	Func::new_desc(
332		FuncDesc::new(
333			FuncKind::InstanceMethod(smartptr_class),
334			Constness::Mut,
335			ReturnKind::InfallibleNaked,
336			format!("to_PtrOf{base_rust_local}"),
337			SupportedModule::Core,
338			[],
339			TypeRef::new_smartptr(SmartPtr::new_desc(SmartPtrDesc::new(base_type_ref))),
340		)
341		.cpp_body(cpp_body),
342	)
343}
344
345fn method_get_inner_ptr<'tu, 'ge>(smartptr_class: Class<'tu, 'ge>, pointee_type: TypeRef<'tu, 'ge>) -> Func<'tu, 'ge> {
346	// if the pointee type is actually const make sure to also generate a const function, needed for
347	// Ptr<const cv::optflow::PCAPrior*>
348	let constness = pointee_type.constness();
349	let return_type_ref = if pointee_type.kind().extern_pass_kind().is_by_ptr() {
350		pointee_type
351	} else {
352		TypeRef::new_pointer(pointee_type)
353	};
354	Func::new_desc(
355		FuncDesc::new(
356			FuncKind::InstanceMethod(smartptr_class),
357			constness,
358			ReturnKind::InfallibleNaked,
359			format!("getInnerPtr{}", constness.rust_name_qual()),
360			SupportedModule::Core,
361			[],
362			return_type_ref,
363		)
364		.cpp_body(FuncCppBody::ManualCallReturn("return instance->get();".into())),
365	)
366}