mathhook_core/pattern/matching/engine/
core.rs

1//! Core pattern matching implementation
2//!
3//! Provides the Matchable trait and recursive matching algorithms.
4
5use super::{apply_replacement, match_commutative, PatternMatches};
6use crate::core::Expression;
7use crate::pattern::matching::patterns::Pattern;
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// Trait for types that support pattern matching
12pub trait Matchable {
13    /// Match this expression against a pattern
14    ///
15    /// Returns bindings for wildcard names if the match succeeds,
16    /// or None if the pattern doesn't match.
17    ///
18    /// # Arguments
19    ///
20    /// * `pattern` - The pattern to match against
21    ///
22    /// # Examples
23    ///
24    /// ```
25    /// use mathhook_core::prelude::*;
26    /// use mathhook_core::pattern::{Pattern, Matchable};
27    ///
28    /// let x = symbol!(x);
29    /// let expr = Expression::add(vec![
30    ///     Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
31    ///     Expression::integer(1)
32    /// ]);
33    ///
34    /// // Pattern: a*x + b
35    /// let pattern = Pattern::Add(vec![
36    ///     Pattern::Mul(vec![
37    ///         Pattern::wildcard("a"),
38    ///         Pattern::Exact(Expression::symbol(x.clone()))
39    ///     ]),
40    ///     Pattern::wildcard("b")
41    /// ]);
42    ///
43    /// let matches = expr.matches(&pattern);
44    /// assert!(matches.is_some());
45    ///
46    /// if let Some(bindings) = matches {
47    ///     assert_eq!(bindings.get("a"), Some(&Expression::integer(2)));
48    ///     assert_eq!(bindings.get("b"), Some(&Expression::integer(1)));
49    /// }
50    /// ```
51    fn matches(&self, pattern: &Pattern) -> Option<PatternMatches>;
52
53    /// Replace all occurrences of a pattern with a replacement expression
54    ///
55    /// Uses pattern matching to find matches and applies the replacement,
56    /// substituting wildcards with their matched values.
57    ///
58    /// # Arguments
59    ///
60    /// * `pattern` - The pattern to match
61    /// * `replacement` - The replacement pattern (can contain wildcards from match)
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use mathhook_core::prelude::*;
67    /// use mathhook_core::pattern::{Pattern, Matchable};
68    ///
69    /// let x = symbol!(x);
70    /// // sin(x)^2 + cos(x)^2
71    /// let expr = Expression::add(vec![
72    ///     Expression::pow(
73    ///         Expression::function("sin".to_string(), vec![Expression::symbol(x.clone())]),
74    ///         Expression::integer(2)
75    ///     ),
76    ///     Expression::pow(
77    ///         Expression::function("cos".to_string(), vec![Expression::symbol(x.clone())]),
78    ///         Expression::integer(2)
79    ///     )
80    /// ]);
81    ///
82    /// // Pattern: sin(a)^2 + cos(a)^2
83    /// let pattern = Pattern::Add(vec![
84    ///     Pattern::Pow(
85    ///         Box::new(Pattern::Function {
86    ///             name: "sin".to_string(),
87    ///             args: vec![Pattern::wildcard("a")]
88    ///         }),
89    ///         Box::new(Pattern::Exact(Expression::integer(2)))
90    ///     ),
91    ///     Pattern::Pow(
92    ///         Box::new(Pattern::Function {
93    ///             name: "cos".to_string(),
94    ///             args: vec![Pattern::wildcard("a")]
95    ///         }),
96    ///         Box::new(Pattern::Exact(Expression::integer(2)))
97    ///     )
98    /// ]);
99    ///
100    /// // Replacement: 1
101    /// let replacement = Pattern::Exact(Expression::integer(1));
102    ///
103    /// let result = expr.replace(&pattern, &replacement);
104    /// assert_eq!(result, Expression::integer(1));
105    /// ```
106    fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression;
107}
108
109impl Matchable for Expression {
110    fn matches(&self, pattern: &Pattern) -> Option<PatternMatches> {
111        let mut bindings = HashMap::new();
112        if match_recursive(self, pattern, &mut bindings) {
113            Some(bindings)
114        } else {
115            None
116        }
117    }
118
119    fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression {
120        if let Some(bindings) = self.matches(pattern) {
121            apply_replacement(replacement, &bindings)
122        } else {
123            match self {
124                Expression::Add(terms) => {
125                    let new_terms: Vec<Expression> = terms
126                        .iter()
127                        .map(|t| t.replace(pattern, replacement))
128                        .collect();
129                    Expression::Add(Arc::new(new_terms))
130                }
131
132                Expression::Mul(factors) => {
133                    let new_factors: Vec<Expression> = factors
134                        .iter()
135                        .map(|f| f.replace(pattern, replacement))
136                        .collect();
137                    Expression::Mul(Arc::new(new_factors))
138                }
139
140                Expression::Pow(base, exp) => {
141                    let new_base = base.replace(pattern, replacement);
142                    let new_exp = exp.replace(pattern, replacement);
143                    Expression::Pow(Arc::new(new_base), Arc::new(new_exp))
144                }
145
146                Expression::Function { name, args } => {
147                    let new_args: Vec<Expression> = args
148                        .iter()
149                        .map(|a| a.replace(pattern, replacement))
150                        .collect();
151                    Expression::Function {
152                        name: name.clone(),
153                        args: Arc::new(new_args),
154                    }
155                }
156
157                _ => self.clone(),
158            }
159        }
160    }
161}
162
163/// Recursive helper for pattern matching
164///
165/// Attempts to match an expression against a pattern, accumulating
166/// wildcard bindings in the provided HashMap.
167pub(super) fn match_recursive(
168    expr: &Expression,
169    pattern: &Pattern,
170    bindings: &mut PatternMatches,
171) -> bool {
172    match pattern {
173        Pattern::Wildcard { name, constraints } => {
174            if let Some(constraints) = constraints {
175                if !constraints.is_satisfied_by(expr) {
176                    return false;
177                }
178            }
179
180            if let Some(existing) = bindings.get(name) {
181                expr == existing
182            } else {
183                bindings.insert(name.clone(), expr.clone());
184                true
185            }
186        }
187
188        Pattern::Exact(pattern_expr) => expr == pattern_expr,
189
190        Pattern::Add(pattern_terms) => {
191            if let Expression::Add(expr_terms) = expr {
192                match_commutative(expr_terms, pattern_terms, bindings)
193            } else {
194                false
195            }
196        }
197
198        Pattern::Mul(pattern_factors) => {
199            if let Expression::Mul(expr_factors) = expr {
200                match_commutative(expr_factors, pattern_factors, bindings)
201            } else {
202                false
203            }
204        }
205
206        Pattern::Pow(pattern_base, pattern_exp) => {
207            if let Expression::Pow(expr_base, expr_exp) = expr {
208                match_recursive(expr_base, pattern_base, bindings)
209                    && match_recursive(expr_exp, pattern_exp, bindings)
210            } else {
211                false
212            }
213        }
214
215        Pattern::Function { name, args } => {
216            if let Expression::Function {
217                name: expr_name,
218                args: expr_args,
219            } = expr
220            {
221                if expr_name.as_ref() != name.as_str() {
222                    return false;
223                }
224
225                if expr_args.len() != args.len() {
226                    return false;
227                }
228
229                for (expr_arg, pattern_arg) in expr_args.iter().zip(args.iter()) {
230                    if !match_recursive(expr_arg, pattern_arg, bindings) {
231                        return false;
232                    }
233                }
234
235                true
236            } else {
237                false
238            }
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::matching::patterns::Pattern;
    use crate::prelude::*;

    #[test]
    fn test_wildcard_pattern_matches() {
        let bindings = Expression::integer(42)
            .matches(&Pattern::wildcard("x"))
            .expect("a bare wildcard should match any expression");

        assert_eq!(bindings.get("x"), Some(&Expression::integer(42)));
    }

    #[test]
    fn test_exact_pattern_matches() {
        let target = Pattern::Exact(Expression::integer(42));
        assert!(Expression::integer(42).matches(&target).is_some());
    }

    #[test]
    fn test_exact_pattern_no_match() {
        let target = Pattern::Exact(Expression::integer(43));
        assert!(Expression::integer(42).matches(&target).is_none());
    }

    #[test]
    fn test_addition_pattern() {
        let x = symbol!(x);
        let sym_x = Expression::symbol(x.clone());
        let one = Expression::integer(1);
        let sum = Expression::add(vec![sym_x.clone(), one.clone()]);

        let shape = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);
        let bindings = sum.matches(&shape).expect("x + 1 should match a + b");

        let a_val = bindings.get("a").expect("`a` should be bound");
        let b_val = bindings.get("b").expect("`b` should be bound");

        // Addition is commutative, so either assignment of the operands is acceptable.
        let captured = (a_val, b_val);
        assert!(captured == (&sym_x, &one) || captured == (&one, &sym_x));
    }

    #[test]
    fn test_multiplication_pattern() {
        let x = symbol!(x);
        let product = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);

        let shape = Pattern::Mul(vec![
            Pattern::Exact(Expression::integer(2)),
            Pattern::wildcard("x"),
        ]);
        let bindings = product.matches(&shape).expect("2*x should match 2*_x");

        assert_eq!(bindings.get("x"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_power_pattern() {
        let x = symbol!(x);
        let square = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let shape = Pattern::Pow(
            Box::new(Pattern::wildcard("base")),
            Box::new(Pattern::Exact(Expression::integer(2))),
        );
        let bindings = square.matches(&shape).expect("x^2 should match _base^2");

        assert_eq!(bindings.get("base"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_function_pattern() {
        let x = symbol!(x);
        let call = Expression::function("sin", vec![Expression::symbol(x.clone())]);

        let shape = Pattern::Function {
            name: "sin".to_string(),
            args: vec![Pattern::wildcard("arg")],
        };
        let bindings = call.matches(&shape).expect("sin(x) should match sin(_arg)");

        assert_eq!(bindings.get("arg"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_wildcard_consistency() {
        let x = symbol!(x);
        // Built from the raw variant so the constructor cannot fold x + x.
        let doubled = Expression::Add(Arc::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
        ]));

        let shape = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);
        let bindings = doubled
            .matches(&shape)
            .expect("x + x should match a + a");

        assert_eq!(bindings.get("a"), Some(&Expression::symbol(x.clone())));
    }

    #[test]
    fn test_wildcard_inconsistency() {
        let x = symbol!(x);
        let y = symbol!(y);
        let mixed = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let shape = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);

        assert!(mixed.matches(&shape).is_none());
    }

    #[test]
    fn test_wildcard_with_exclude() {
        let x = symbol!(x);
        let y = symbol!(y);

        let shape = Pattern::wildcard_excluding("a", vec![Expression::symbol(x.clone())]);

        // The excluded symbol itself is rejected...
        assert!(Expression::symbol(x.clone()).matches(&shape).is_none());
        // ...other symbols are accepted...
        assert!(Expression::symbol(y.clone()).matches(&shape).is_some());
        // ...and so is anything that merely contains the excluded symbol? No: rejected too.
        let containing_x =
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        assert!(containing_x.matches(&shape).is_none());
    }

    #[test]
    fn test_wildcard_with_property() {
        fn is_integer(expr: &Expression) -> bool {
            matches!(expr, Expression::Number(_))
        }

        let shape = Pattern::wildcard_with_properties("n", vec![is_integer]);

        assert!(Expression::integer(42).matches(&shape).is_some());

        let x = symbol!(x);
        assert!(Expression::symbol(x.clone()).matches(&shape).is_none());
    }
}