mathhook_core/pattern/matching/engine/
core.rs1use super::{apply_replacement, match_commutative, PatternMatches};
6use crate::core::Expression;
7use crate::pattern::matching::patterns::Pattern;
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub trait Matchable {
13 fn matches(&self, pattern: &Pattern) -> Option<PatternMatches>;
52
53 fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression;
107}
108
109impl Matchable for Expression {
110 fn matches(&self, pattern: &Pattern) -> Option<PatternMatches> {
111 let mut bindings = HashMap::new();
112 if match_recursive(self, pattern, &mut bindings) {
113 Some(bindings)
114 } else {
115 None
116 }
117 }
118
119 fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression {
120 if let Some(bindings) = self.matches(pattern) {
121 apply_replacement(replacement, &bindings)
122 } else {
123 match self {
124 Expression::Add(terms) => {
125 let new_terms: Vec<Expression> = terms
126 .iter()
127 .map(|t| t.replace(pattern, replacement))
128 .collect();
129 Expression::Add(Arc::new(new_terms))
130 }
131
132 Expression::Mul(factors) => {
133 let new_factors: Vec<Expression> = factors
134 .iter()
135 .map(|f| f.replace(pattern, replacement))
136 .collect();
137 Expression::Mul(Arc::new(new_factors))
138 }
139
140 Expression::Pow(base, exp) => {
141 let new_base = base.replace(pattern, replacement);
142 let new_exp = exp.replace(pattern, replacement);
143 Expression::Pow(Arc::new(new_base), Arc::new(new_exp))
144 }
145
146 Expression::Function { name, args } => {
147 let new_args: Vec<Expression> = args
148 .iter()
149 .map(|a| a.replace(pattern, replacement))
150 .collect();
151 Expression::Function {
152 name: name.clone(),
153 args: Arc::new(new_args),
154 }
155 }
156
157 _ => self.clone(),
158 }
159 }
160 }
161}
162
163pub(super) fn match_recursive(
168 expr: &Expression,
169 pattern: &Pattern,
170 bindings: &mut PatternMatches,
171) -> bool {
172 match pattern {
173 Pattern::Wildcard { name, constraints } => {
174 if let Some(constraints) = constraints {
175 if !constraints.is_satisfied_by(expr) {
176 return false;
177 }
178 }
179
180 if let Some(existing) = bindings.get(name) {
181 expr == existing
182 } else {
183 bindings.insert(name.clone(), expr.clone());
184 true
185 }
186 }
187
188 Pattern::Exact(pattern_expr) => expr == pattern_expr,
189
190 Pattern::Add(pattern_terms) => {
191 if let Expression::Add(expr_terms) = expr {
192 match_commutative(expr_terms, pattern_terms, bindings)
193 } else {
194 false
195 }
196 }
197
198 Pattern::Mul(pattern_factors) => {
199 if let Expression::Mul(expr_factors) = expr {
200 match_commutative(expr_factors, pattern_factors, bindings)
201 } else {
202 false
203 }
204 }
205
206 Pattern::Pow(pattern_base, pattern_exp) => {
207 if let Expression::Pow(expr_base, expr_exp) = expr {
208 match_recursive(expr_base, pattern_base, bindings)
209 && match_recursive(expr_exp, pattern_exp, bindings)
210 } else {
211 false
212 }
213 }
214
215 Pattern::Function { name, args } => {
216 if let Expression::Function {
217 name: expr_name,
218 args: expr_args,
219 } = expr
220 {
221 if expr_name.as_ref() != name.as_str() {
222 return false;
223 }
224
225 if expr_args.len() != args.len() {
226 return false;
227 }
228
229 for (expr_arg, pattern_arg) in expr_args.iter().zip(args.iter()) {
230 if !match_recursive(expr_arg, pattern_arg, bindings) {
231 return false;
232 }
233 }
234
235 true
236 } else {
237 false
238 }
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::matching::patterns::Pattern;
    use crate::prelude::*;

    #[test]
    fn test_wildcard_pattern_matches() {
        let expr = Expression::integer(42);
        let pattern = Pattern::wildcard("x");

        let bindings = expr
            .matches(&pattern)
            .expect("a bare wildcard should match an integer");
        assert_eq!(bindings.get("x"), Some(&Expression::integer(42)));
    }

    #[test]
    fn test_exact_pattern_matches() {
        let expr = Expression::integer(42);
        let pattern = Pattern::Exact(Expression::integer(42));

        assert!(expr.matches(&pattern).is_some());
    }

    #[test]
    fn test_exact_pattern_no_match() {
        let expr = Expression::integer(42);
        let pattern = Pattern::Exact(Expression::integer(43));

        assert!(expr.matches(&pattern).is_none());
    }

    #[test]
    fn test_addition_pattern() {
        let x = symbol!(x);
        let x_expr = Expression::symbol(x.clone());
        let one = Expression::integer(1);
        let expr = Expression::add(vec![x_expr.clone(), one.clone()]);

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);

        let bindings = expr
            .matches(&pattern)
            .expect("a + b should match x + 1");
        let a_val = bindings.get("a").expect("wildcard `a` should be bound");
        let b_val = bindings.get("b").expect("wildcard `b` should be bound");

        // Addition is commutative, so either assignment of the two terms is valid.
        assert!(
            (a_val == &x_expr && b_val == &one) || (a_val == &one && b_val == &x_expr),
            "unexpected bindings: a = {:?}, b = {:?}",
            a_val,
            b_val
        );
    }

    #[test]
    fn test_multiplication_pattern() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);

        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::integer(2)),
            Pattern::wildcard("x"),
        ]);

        let bindings = expr
            .matches(&pattern)
            .expect("2 * ?x should match 2 * x");
        assert_eq!(bindings.get("x"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_power_pattern() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let pattern = Pattern::Pow(
            Box::new(Pattern::wildcard("base")),
            Box::new(Pattern::Exact(Expression::integer(2))),
        );

        let bindings = expr
            .matches(&pattern)
            .expect("?base ^ 2 should match x ^ 2");
        assert_eq!(bindings.get("base"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_function_pattern() {
        let x = symbol!(x);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);

        let pattern = Pattern::Function {
            name: "sin".to_string(),
            args: vec![Pattern::wildcard("arg")],
        };

        let bindings = expr
            .matches(&pattern)
            .expect("sin(?arg) should match sin(x)");
        assert_eq!(bindings.get("arg"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_wildcard_consistency() {
        let x = symbol!(x);
        // Built from the raw variant rather than `Expression::add`, presumably so
        // that x + x is not folded into a single term before matching.
        let expr = Expression::Add(Arc::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
        ]));

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);

        let bindings = expr
            .matches(&pattern)
            .expect("?a + ?a should match x + x");
        assert_eq!(bindings.get("a"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_wildcard_inconsistency() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);

        assert!(expr.matches(&pattern).is_none());
    }

    #[test]
    fn test_wildcard_with_exclude() {
        let x = symbol!(x);
        let y = symbol!(y);

        let pattern = Pattern::wildcard_excluding("a", vec![Expression::symbol(x.clone())]);

        assert!(Expression::symbol(x.clone()).matches(&pattern).is_none());

        assert!(Expression::symbol(y.clone()).matches(&pattern).is_some());

        let expr_with_x =
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        assert!(expr_with_x.matches(&pattern).is_none());
    }

    #[test]
    fn test_wildcard_with_property() {
        // Accepts any numeric literal, which is enough to tell 42 apart from a symbol.
        fn is_number(expr: &Expression) -> bool {
            matches!(expr, Expression::Number(_))
        }

        let pattern = Pattern::wildcard_with_properties("n", vec![is_number]);

        assert!(Expression::integer(42).matches(&pattern).is_some());

        let x = symbol!(x);
        assert!(Expression::symbol(x.clone()).matches(&pattern).is_none());
    }

    #[test]
    fn test_replace_without_match_returns_original() {
        let expr = Expression::integer(42);
        let pattern = Pattern::Exact(Expression::integer(43));
        let replacement = Pattern::Exact(Expression::integer(7));

        assert_eq!(expr.replace(&pattern, &replacement), expr);
    }

    #[test]
    fn test_replace_at_root() {
        let expr = Expression::integer(42);
        let pattern = Pattern::Exact(Expression::integer(42));
        let replacement = Pattern::Exact(Expression::integer(7));

        assert_eq!(expr.replace(&pattern, &replacement), Expression::integer(7));
    }

    #[test]
    fn test_replace_inside_function_args() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let pattern = Pattern::Exact(Expression::symbol(x.clone()));
        let replacement = Pattern::Exact(Expression::symbol(y.clone()));

        // sin(x) itself does not match `x`, so the rewrite must descend into the args.
        assert_eq!(
            expr.replace(&pattern, &replacement),
            Expression::function("sin", vec![Expression::symbol(y.clone())])
        );
    }
}
415}