Skip to main content

opencv_binding_generator/func/
func_matcher.rs

1use std::borrow::Cow;
2use std::borrow::Cow::Owned;
3use std::collections::{HashMap, HashSet};
4use std::sync::RwLock;
5
6use crate::field::Field;
7use crate::type_ref::Constness;
8use crate::{CowMapBorrowedExt, CppNameStyle, Element, Func};
9
10#[derive(Debug)]
11pub struct FuncMatchProperties<'f> {
12	func: &'f Func<'f, 'f>,
13	name: Cow<'f, str>,
14	constness: Option<Constness>,
15	ret: Option<Cow<'f, str>>,
16	arg_names: Option<Vec<Cow<'f, str>>>,
17	arg_types: Option<Vec<Cow<'f, str>>>,
18}
19
20impl<'f> FuncMatchProperties<'f> {
21	pub fn new(func: &'f Func<'f, 'f>, name: Cow<'f, str>) -> Self {
22		Self {
23			func,
24			name,
25			constness: None,
26			ret: None,
27			arg_names: None,
28			arg_types: None,
29		}
30	}
31
32	pub fn name(&mut self) -> &str {
33		self.name.as_ref()
34	}
35
36	pub fn constness(&mut self) -> Constness {
37		*self.constness.get_or_insert_with(|| self.func.constness())
38	}
39
40	pub fn return_type(&mut self) -> &str {
41		self.ret.get_or_insert_with(|| {
42			self
43				.func
44				.return_type_ref()
45				.map_borrowed(|ret| ret.cpp_name(CppNameStyle::Reference))
46		})
47	}
48
49	pub fn arg_names(&mut self) -> &[Cow<'_, str>] {
50		self
51			.arg_names
52			.get_or_insert_with(|| {
53				match &self.func {
54					Func::Clang { entity, .. } => {
55						self
56							.func
57							.clang_arguments(*entity)
58							.iter()
59							.map(|arg| Owned(arg.cpp_name(CppNameStyle::Declaration).into_owned())) // todo: find a way to store borrowed
60							.collect()
61					}
62					Func::Desc(desc) => desc
63						.arguments
64						.iter()
65						.map(|arg| arg.cpp_name(CppNameStyle::Declaration))
66						.collect(),
67				}
68			})
69			.as_slice()
70	}
71
72	pub fn arg_types(&mut self) -> &[Cow<'_, str>] {
73		self
74			.arg_types
75			.get_or_insert_with(|| match &self.func {
76				Func::Clang { entity, gen_env, .. } => self
77					.func
78					.clang_arguments(*entity)
79					.iter()
80					.map(|arg| {
81						Owned(
82							Field::new(*arg, gen_env)
83								.type_ref()
84								.cpp_name(CppNameStyle::Reference)
85								.into_owned(),
86						)
87					})
88					.collect(),
89				Func::Desc(desc) => desc
90					.arguments
91					.iter()
92					.map(|arg| {
93						arg.type_ref()
94							.map_borrowed(|type_ref| type_ref.cpp_name(CppNameStyle::Reference))
95					})
96					.collect(),
97			})
98			.as_slice()
99	}
100
101	pub fn dump(&mut self) -> String {
102		let constness = self.constness();
103		let name = self.name().to_string();
104		let arg_names = self.arg_names().iter().map(|s| s.to_string()).collect::<Vec<_>>();
105		let arg_types = self.arg_types().iter().map(|s| s.to_string()).collect::<Vec<_>>();
106		format!(
107			"(\"{name}\", vec![(pred!({cnst}, {arg_names:?}, {arg_types:?}), _)]),",
108			cnst = constness.rust_qual_ptr().trim()
109		)
110	}
111}
112
113pub type FuncMatcherInner<'l, RES> = HashMap<&'l str, Vec<(&'l [Pred<'l>], RES)>>;
114
115pub type UsageTracker<'l> = (&'l str, &'l [Pred<'l>]);
116
117#[derive(Debug)]
118pub struct FuncMatcher<'l, RES> {
119	inner: FuncMatcherInner<'l, RES>,
120	usage_tracking: Option<RwLock<HashSet<UsageTracker<'l>>>>,
121}
122
123impl<'l, RES> FuncMatcher<'l, RES> {
124	pub fn empty() -> Self {
125		Self {
126			inner: HashMap::new(),
127			usage_tracking: None,
128		}
129	}
130
131	pub fn create(inner: FuncMatcherInner<'l, RES>) -> Self {
132		Self {
133			inner,
134			usage_tracking: None,
135		}
136	}
137
138	pub fn get(&self, f: &mut FuncMatchProperties) -> Option<&RES> {
139		let mtch = self.inner.get(f.name()).and_then(|matchers| {
140			matchers
141				.iter()
142				.find_map(|(preds, res)| preds.iter().all(|m| m.matches(f)).then_some((preds, res)))
143		});
144		if let Some((preds, res)) = mtch {
145			if let Some(usage_tracking) = self.usage_tracking.as_ref() {
146				let needs_removal = if let Ok(usage_tracking) = usage_tracking.read() {
147					usage_tracking.contains(&(f.name(), *preds))
148				} else {
149					false
150				};
151				if needs_removal && let Ok(mut usage_tracking) = usage_tracking.write() {
152					usage_tracking.retain(|x| x != &(f.name(), *preds));
153				}
154			}
155			Some(res)
156		} else {
157			None
158		}
159	}
160
161	pub fn start_usage_tracking(&mut self) {
162		if !self.inner.is_empty() {
163			let mut usage_tracking = HashSet::new();
164			for (name, matchers) in &self.inner {
165				for (predicates, _) in matchers {
166					usage_tracking.insert((*name, *predicates));
167				}
168			}
169			self.usage_tracking = Some(RwLock::new(usage_tracking));
170		}
171	}
172
173	pub fn finish_usage_tracking(&mut self) -> HashSet<UsageTracker<'_>> {
174		if let Some(out) = self.usage_tracking.take()
175			&& let Ok(usage_tracking) = out.into_inner()
176			&& !usage_tracking.is_empty()
177		{
178			return usage_tracking;
179		}
180		HashSet::new()
181	}
182}
183
184#[derive(Debug, PartialEq, Eq, Hash)]
185pub enum Pred<'l> {
186	Constness(Constness),
187	Return(&'l str),
188	ArgNames(&'l [&'l str]),
189	ArgTypes(&'l [&'l str]),
190}
191
192impl Pred<'_> {
193	pub fn matches(&self, f: &mut FuncMatchProperties) -> bool {
194		match self {
195			Self::Constness(cnst) => f.constness() == *cnst,
196			Self::Return(ret_type) => f.return_type() == *ret_type,
197			Self::ArgNames(arg_names) => f.arg_names() == *arg_names,
198			Self::ArgTypes(arg_types) => f.arg_types() == *arg_types,
199		}
200	}
201}
202
// Unit tests for `FuncMatcher`: predicate matching and unused-entry (usage) tracking.
#[cfg(test)]
mod test {
	use std::collections::{HashMap, HashSet};

	use crate::class::ClassDesc;
	use crate::field::{Field, FieldDesc};
	use crate::func::func_matcher::{FuncMatcher, Pred};
	use crate::func::{FuncDesc, FuncKind, ReturnKind};
	use crate::type_ref::{Constness, TypeRef, TypeRefDesc, TypeRefTypeHint};
	use crate::writer::rust_native::type_ref::Lifetime;
	use crate::{Func, SupportedModule};

	// Checks that an entry matches only when every one of its predicates holds.
	#[test]
	fn test_func_matcher() {
		// Fixture: mutable, fallible `_InputArray` constructor taking `vec` (array of const uchar)
		// and `n` (int, type-hinted via `LenForSlice` referencing "vec"), with Rust name
		// `from_byte_slice`.
		let f = Func::new_desc(
			FuncDesc::new(
				FuncKind::Constructor(ClassDesc::cv_input_array()),
				Constness::Mut,
				ReturnKind::Fallible,
				"_InputArray",
				SupportedModule::Core,
				[
					Field::new_desc(FieldDesc::new(
						"vec",
						TypeRef::new_array(TypeRefDesc::uchar().with_inherent_constness(Constness::Const), None),
					)),
					Field::new_desc(FieldDesc::new(
						"n",
						TypeRefDesc::int().with_type_hint(TypeRefTypeHint::LenForSlice(["vec".to_string()].as_slice().into(), 1)),
					)),
				],
				TypeRefDesc::cv_input_array()
					.with_inherent_constness(Constness::Const)
					.with_type_hint(TypeRefTypeHint::BoxedAsRef(Constness::Const, &["vec"], Lifetime::Elided)),
			)
			.rust_custom_leafname("from_byte_slice"),
		);

		// match with all predicates
		{
			let matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[
						Pred::Return("const cv::_InputArray"),
						Pred::ArgNames(&["vec", "n"]),
						Pred::ArgTypes(&["const unsigned char*", "int"]),
						Pred::Constness(Constness::Mut),
					]
					.as_slice(),
					"MATCH",
				)],
			)]));

			let mut f_matcher = f.matcher();
			let res = matcher.get(&mut f_matcher);
			assert_eq!(Some(&"MATCH"), res);
		}

		// match with limited predicates
		{
			let matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[Pred::ArgNames(&["vec", "n"]), Pred::Constness(Constness::Mut)].as_slice(),
					"MATCH",
				)],
			)]));

			let mut f_matcher = f.matcher();
			let res = matcher.get(&mut f_matcher);
			assert_eq!(Some(&"MATCH"), res);
		}

		// no match with limited predicates
		{
			// The second argument name differs ("notN" vs "n"), so the entry must be rejected.
			let matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[Pred::ArgNames(&["vec", "notN"]), Pred::Constness(Constness::Mut)].as_slice(),
					"MATCH",
				)],
			)]));

			let mut f_matcher = f.matcher();
			let res = matcher.get(&mut f_matcher);
			assert_eq!(None, res);
		}
	}

	// Checks that matched entries are dropped from usage tracking and unmatched ones are reported.
	#[test]
	fn test_func_matcher_usage_tracking() {
		// Same fixture as in `test_func_matcher`.
		let f = Func::new_desc(
			FuncDesc::new(
				FuncKind::Constructor(ClassDesc::cv_input_array()),
				Constness::Mut,
				ReturnKind::Fallible,
				"_InputArray",
				SupportedModule::Core,
				[
					Field::new_desc(FieldDesc::new(
						"vec",
						TypeRef::new_array(TypeRefDesc::uchar().with_inherent_constness(Constness::Const), None),
					)),
					Field::new_desc(FieldDesc::new(
						"n",
						TypeRefDesc::int().with_type_hint(TypeRefTypeHint::LenForSlice(["vec".to_string()].as_slice().into(), 1)),
					)),
				],
				TypeRefDesc::cv_input_array()
					.with_inherent_constness(Constness::Const)
					.with_type_hint(TypeRefTypeHint::BoxedAsRef(Constness::Const, &["vec"], Lifetime::Elided)),
			)
			.rust_custom_leafname("from_byte_slice"),
		);

		// with match
		{
			// The only entry matches, so nothing is left in the unused set.
			let mut matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![(
					[
						Pred::Return("const cv::_InputArray"),
						Pred::ArgNames(&["vec", "n"]),
						Pred::ArgTypes(&["const unsigned char*", "int"]),
						Pred::Constness(Constness::Mut),
					]
					.as_slice(),
					"MATCH",
				)],
			)]));
			matcher.start_usage_tracking();
			let mut f_matcher = f.matcher();
			matcher.get(&mut f_matcher);
			let usage_tracking = matcher.finish_usage_tracking();
			assert!(usage_tracking.is_empty());
		}

		// no match
		{
			// The entry never matches, so it must be reported back as unused.
			let mut matcher = FuncMatcher::create(HashMap::from([(
				"cv::_InputArray::_InputArray",
				vec![([Pred::ArgNames(&["vec", "notN"])].as_slice(), "MATCH")],
			)]));
			matcher.start_usage_tracking();
			let mut f_matcher = f.matcher();
			matcher.get(&mut f_matcher);
			let usage_tracking = matcher.finish_usage_tracking();
			assert_eq!(
				HashSet::from([("cv::_InputArray::_InputArray", [Pred::ArgNames(&["vec", "notN"])].as_slice())]),
				usage_tracking
			);
		}
	}
}