// opencv_binding_generator/func/func_matcher.rs

1use std::borrow::Cow;
2use std::borrow::Cow::Owned;
3use std::collections::{HashMap, HashSet};
4use std::sync::RwLock;
5
6use crate::field::Field;
7use crate::type_ref::Constness;
8use crate::{CowMapBorrowedExt, CppNameStyle, Element, Func};
9
10#[derive(Debug)]
11pub struct FuncMatchProperties<'f> {
12	func: &'f Func<'f, 'f>,
13	name: Cow<'f, str>,
14	constness: Option<Constness>,
15	ret: Option<Cow<'f, str>>,
16	arg_names: Option<Vec<Cow<'f, str>>>,
17	arg_types: Option<Vec<Cow<'f, str>>>,
18}
19
20impl<'f> FuncMatchProperties<'f> {
21	pub fn new(func: &'f Func<'f, 'f>, name: Cow<'f, str>) -> Self {
22		Self {
23			func,
24			name,
25			constness: None,
26			ret: None,
27			arg_names: None,
28			arg_types: None,
29		}
30	}
31
32	pub fn name(&mut self) -> &str {
33		self.name.as_ref()
34	}
35
36	pub fn constness(&mut self) -> Constness {
37		*self.constness.get_or_insert_with(|| self.func.constness())
38	}
39
40	pub fn return_type(&mut self) -> &str {
41		self.ret.get_or_insert_with(|| {
42			self
43				.func
44				.return_type_ref()
45				.map_borrowed(|ret| ret.cpp_name(CppNameStyle::Reference))
46		})
47	}
48
49	pub fn arg_names(&mut self) -> &[Cow<str>] {
50		self
51			.arg_names
52			.get_or_insert_with(|| {
53				match &self.func {
54					Func::Clang { entity, .. } => {
55						self
56							.func
57							.clang_arguments(*entity)
58							.iter()
59							.map(|arg| Owned(arg.cpp_name(CppNameStyle::Declaration).into_owned())) // todo: find a way to store borrowed
60							.collect()
61					}
62					Func::Desc(desc) => desc
63						.arguments
64						.iter()
65						.map(|arg| arg.cpp_name(CppNameStyle::Declaration))
66						.collect(),
67				}
68			})
69			.as_slice()
70	}
71
72	pub fn arg_types(&mut self) -> &[Cow<str>] {
73		self
74			.arg_types
75			.get_or_insert_with(|| match &self.func {
76				Func::Clang { entity, gen_env, .. } => self
77					.func
78					.clang_arguments(*entity)
79					.iter()
80					.map(|arg| {
81						Owned(
82							Field::new(*arg, gen_env)
83								.type_ref()
84								.cpp_name(CppNameStyle::Reference)
85								.into_owned(),
86						)
87					})
88					.collect(),
89				Func::Desc(desc) => desc
90					.arguments
91					.iter()
92					.map(|arg| {
93						arg.type_ref()
94							.map_borrowed(|type_ref| type_ref.cpp_name(CppNameStyle::Reference))
95					})
96					.collect(),
97			})
98			.as_slice()
99	}
100
101	pub fn dump(&mut self) -> String {
102		let constness = self.constness();
103		let name = self.name().to_string();
104		let arg_names = self.arg_names().iter().map(|s| s.to_string()).collect::<Vec<_>>();
105		let arg_types = self.arg_types().iter().map(|s| s.to_string()).collect::<Vec<_>>();
106		format!(
107			"(\"{name}\", vec![(pred!({cnst}, {arg_names:?}, {arg_types:?}), _)]),",
108			cnst = constness.rust_qual_ptr().trim()
109		)
110	}
111}
112
113pub type FuncMatcherInner<'l, RES> = HashMap<&'l str, Vec<(&'l [Pred<'l>], RES)>>;
114
115pub type UsageTracker<'l> = (&'l str, &'l [Pred<'l>]);
116
117#[derive(Debug)]
118pub struct FuncMatcher<'l, RES> {
119	inner: FuncMatcherInner<'l, RES>,
120	usage_tracking: Option<RwLock<HashSet<UsageTracker<'l>>>>,
121}
122
123impl<'l, RES> FuncMatcher<'l, RES> {
124	pub fn empty() -> Self {
125		Self {
126			inner: HashMap::new(),
127			usage_tracking: None,
128		}
129	}
130
131	pub fn create(inner: FuncMatcherInner<'l, RES>) -> Self {
132		Self {
133			inner,
134			usage_tracking: None,
135		}
136	}
137
138	pub fn get(&self, f: &mut FuncMatchProperties) -> Option<&RES> {
139		let mtch = self.inner.get(f.name()).and_then(|matchers| {
140			matchers
141				.iter()
142				.find_map(|(preds, res)| preds.iter().all(|m| m.matches(f)).then_some((preds, res)))
143		});
144		if let Some((preds, res)) = mtch {
145			if let Some(usage_tracking) = self.usage_tracking.as_ref() {
146				let needs_removal = if let Ok(usage_tracking) = usage_tracking.read() {
147					usage_tracking.contains(&(f.name(), *preds))
148				} else {
149					false
150				};
151				if needs_removal {
152					if let Ok(mut usage_tracking) = usage_tracking.write() {
153						usage_tracking.retain(|x| x != &(f.name(), *preds));
154					}
155				}
156			}
157			Some(res)
158		} else {
159			None
160		}
161	}
162
163	pub fn start_usage_tracking(&mut self) {
164		if !self.inner.is_empty() {
165			let mut usage_tracking = HashSet::new();
166			for (name, matchers) in &self.inner {
167				for (predicates, _) in matchers {
168					usage_tracking.insert((*name, *predicates));
169				}
170			}
171			self.usage_tracking = Some(RwLock::new(usage_tracking));
172		}
173	}
174
175	pub fn finish_usage_tracking(&mut self) -> HashSet<UsageTracker> {
176		if let Some(out) = self.usage_tracking.take() {
177			if let Ok(usage_tracking) = out.into_inner() {
178				if !usage_tracking.is_empty() {
179					return usage_tracking;
180				}
181			}
182		}
183		HashSet::new()
184	}
185}
186
187#[derive(Debug, PartialEq, Eq, Hash)]
188pub enum Pred<'l> {
189	Constness(Constness),
190	Return(&'l str),
191	ArgNames(&'l [&'l str]),
192	ArgTypes(&'l [&'l str]),
193}
194
195impl Pred<'_> {
196	pub fn matches(&self, f: &mut FuncMatchProperties) -> bool {
197		match self {
198			Self::Constness(cnst) => f.constness() == *cnst,
199			Self::Return(ret_type) => f.return_type() == *ret_type,
200			Self::ArgNames(arg_names) => f.arg_names() == *arg_names,
201			Self::ArgTypes(arg_types) => f.arg_types() == *arg_types,
202		}
203	}
204}
205
#[cfg(test)]
mod test {
	use std::collections::{HashMap, HashSet};

	use crate::class::ClassDesc;
	use crate::field::{Field, FieldDesc};
	use crate::func::func_matcher::{FuncMatcher, Pred};
	use crate::func::{FuncDesc, FuncKind, ReturnKind};
	use crate::type_ref::{Constness, TypeRef, TypeRefDesc, TypeRefTypeHint};
	use crate::writer::rust_native::type_ref::Lifetime;
	use crate::{Func, SupportedModule};

	// Expands to the `cv::_InputArray` constructor taking `(vec: const unsigned char*, n: int)` that every test below
	// matches against. A macro rather than a helper fn keeps the expression expanded in place, so its types and
	// lifetimes are exactly those of writing it out inline.
	macro_rules! input_array_from_byte_slice {
		() => {
			Func::new_desc(
				FuncDesc::new(
					FuncKind::Constructor(ClassDesc::cv_input_array()),
					Constness::Mut,
					ReturnKind::Fallible,
					"_InputArray",
					SupportedModule::Core,
					[
						Field::new_desc(FieldDesc::new(
							"vec",
							TypeRef::new_array(TypeRefDesc::uchar().with_inherent_constness(Constness::Const), None),
						)),
						Field::new_desc(FieldDesc::new(
							"n",
							TypeRefDesc::int().with_type_hint(TypeRefTypeHint::LenForSlice(["vec".to_string()].as_slice().into(), 1)),
						)),
					],
					TypeRefDesc::cv_input_array()
						.with_inherent_constness(Constness::Const)
						.with_type_hint(TypeRefTypeHint::BoxedAsRef(Constness::Const, &["vec"], Lifetime::Elided)),
				)
				.rust_custom_leafname("from_byte_slice"),
			)
		};
	}

	#[test]
	fn test_func_matcher() {
		let func = input_array_from_byte_slice!();

		// every predicate kind given and satisfied -> match
		{
			let func_matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[
						Pred::Return("const cv::_InputArray"),
						Pred::ArgNames(&["vec", "n"]),
						Pred::ArgTypes(&["const unsigned char*", "int"]),
						Pred::Constness(Constness::Mut),
					]
					.as_slice(),
					"MATCH",
				)],
			)]));
			let mut props = func.matcher();
			assert_eq!(func_matcher.get(&mut props), Some(&"MATCH"));
		}

		// only some predicate kinds given, all satisfied -> match
		{
			let func_matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[Pred::ArgNames(&["vec", "n"]), Pred::Constness(Constness::Mut)].as_slice(),
					"MATCH",
				)],
			)]));
			let mut props = func.matcher();
			assert_eq!(func_matcher.get(&mut props), Some(&"MATCH"));
		}

		// only some predicate kinds given, argument names differ -> no match
		{
			let func_matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[Pred::ArgNames(&["vec", "notN"]), Pred::Constness(Constness::Mut)].as_slice(),
					"MATCH",
				)],
			)]));
			let mut props = func.matcher();
			assert_eq!(func_matcher.get(&mut props), None);
		}
	}

	#[test]
	fn test_func_matcher_usage_tracking() {
		let func = input_array_from_byte_slice!();

		// the only rule gets matched -> nothing is reported as unused
		{
			let mut func_matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[
						Pred::Return("const cv::_InputArray"),
						Pred::ArgNames(&["vec", "n"]),
						Pred::ArgTypes(&["const unsigned char*", "int"]),
						Pred::Constness(Constness::Mut),
					]
					.as_slice(),
					"MATCH",
				)],
			)]));
			func_matcher.start_usage_tracking();
			let mut props = func.matcher();
			let _ = func_matcher.get(&mut props);
			assert!(func_matcher.finish_usage_tracking().is_empty());
		}

		// the only rule never matches -> it is reported as unused
		{
			let mut func_matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![([Pred::ArgNames(&["vec", "notN"])].as_slice(), "MATCH")],
			)]));
			func_matcher.start_usage_tracking();
			let mut props = func.matcher();
			let _ = func_matcher.get(&mut props);
			let unused = func_matcher.finish_usage_tracking();
			assert_eq!(
				unused,
				HashSet::from([("cv::_InputArray::_InputArray", [Pred::ArgNames(&["vec", "notN"])].as_slice())]),
			);
		}
	}
}
360}