mathhook_core/pattern/matching/engine/
core.rs

//! Core pattern matching implementation
//!
//! Provides the Matchable trait and recursive matching algorithms.
4
5use super::{apply_replacement, match_commutative, PatternMatches};
6use crate::core::Expression;
7use crate::pattern::matching::patterns::Pattern;
8use std::collections::HashMap;
9
10/// Trait for types that support pattern matching
11pub trait Matchable {
12    /// Match this expression against a pattern
13    ///
14    /// Returns bindings for wildcard names if the match succeeds,
15    /// or None if the pattern doesn't match.
16    ///
17    /// # Arguments
18    ///
19    /// * `pattern` - The pattern to match against
20    ///
21    /// # Examples
22    ///
23    /// ```
24    /// use mathhook_core::prelude::*;
25    /// use mathhook_core::pattern::{Pattern, Matchable};
26    ///
27    /// let x = symbol!(x);
28    /// let expr = Expression::add(vec![
29    ///     Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
30    ///     Expression::integer(1)
31    /// ]);
32    ///
33    /// // Pattern: a*x + b
34    /// let pattern = Pattern::Add(vec![
35    ///     Pattern::Mul(vec![
36    ///         Pattern::wildcard("a"),
37    ///         Pattern::Exact(Expression::symbol(x.clone()))
38    ///     ]),
39    ///     Pattern::wildcard("b")
40    /// ]);
41    ///
42    /// let matches = expr.matches(&pattern);
43    /// assert!(matches.is_some());
44    ///
45    /// if let Some(bindings) = matches {
46    ///     assert_eq!(bindings.get("a"), Some(&Expression::integer(2)));
47    ///     assert_eq!(bindings.get("b"), Some(&Expression::integer(1)));
48    /// }
49    /// ```
50    fn matches(&self, pattern: &Pattern) -> Option<PatternMatches>;
51
52    /// Replace all occurrences of a pattern with a replacement expression
53    ///
54    /// Uses pattern matching to find matches and applies the replacement,
55    /// substituting wildcards with their matched values.
56    ///
57    /// # Arguments
58    ///
59    /// * `pattern` - The pattern to match
60    /// * `replacement` - The replacement pattern (can contain wildcards from match)
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use mathhook_core::prelude::*;
66    /// use mathhook_core::pattern::{Pattern, Matchable};
67    ///
68    /// let x = symbol!(x);
69    /// // sin(x)^2 + cos(x)^2
70    /// let expr = Expression::add(vec![
71    ///     Expression::pow(
72    ///         Expression::function("sin".to_string(), vec![Expression::symbol(x.clone())]),
73    ///         Expression::integer(2)
74    ///     ),
75    ///     Expression::pow(
76    ///         Expression::function("cos".to_string(), vec![Expression::symbol(x.clone())]),
77    ///         Expression::integer(2)
78    ///     )
79    /// ]);
80    ///
81    /// // Pattern: sin(a)^2 + cos(a)^2
82    /// let pattern = Pattern::Add(vec![
83    ///     Pattern::Pow(
84    ///         Box::new(Pattern::Function {
85    ///             name: "sin".to_string(),
86    ///             args: vec![Pattern::wildcard("a")]
87    ///         }),
88    ///         Box::new(Pattern::Exact(Expression::integer(2)))
89    ///     ),
90    ///     Pattern::Pow(
91    ///         Box::new(Pattern::Function {
92    ///             name: "cos".to_string(),
93    ///             args: vec![Pattern::wildcard("a")]
94    ///         }),
95    ///         Box::new(Pattern::Exact(Expression::integer(2)))
96    ///     )
97    /// ]);
98    ///
99    /// // Replacement: 1
100    /// let replacement = Pattern::Exact(Expression::integer(1));
101    ///
102    /// let result = expr.replace(&pattern, &replacement);
103    /// assert_eq!(result, Expression::integer(1));
104    /// ```
105    fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression;
106}
107
108impl Matchable for Expression {
109    fn matches(&self, pattern: &Pattern) -> Option<PatternMatches> {
110        let mut bindings = HashMap::new();
111        if match_recursive(self, pattern, &mut bindings) {
112            Some(bindings)
113        } else {
114            None
115        }
116    }
117
118    fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression {
119        if let Some(bindings) = self.matches(pattern) {
120            apply_replacement(replacement, &bindings)
121        } else {
122            match self {
123                Expression::Add(terms) => {
124                    let new_terms: Vec<Expression> = terms
125                        .iter()
126                        .map(|t| t.replace(pattern, replacement))
127                        .collect();
128                    Expression::Add(Box::new(new_terms))
129                }
130
131                Expression::Mul(factors) => {
132                    let new_factors: Vec<Expression> = factors
133                        .iter()
134                        .map(|f| f.replace(pattern, replacement))
135                        .collect();
136                    Expression::Mul(Box::new(new_factors))
137                }
138
139                Expression::Pow(base, exp) => {
140                    let new_base = base.replace(pattern, replacement);
141                    let new_exp = exp.replace(pattern, replacement);
142                    Expression::Pow(Box::new(new_base), Box::new(new_exp))
143                }
144
145                Expression::Function { name, args } => {
146                    let new_args: Vec<Expression> = args
147                        .iter()
148                        .map(|a| a.replace(pattern, replacement))
149                        .collect();
150                    Expression::Function {
151                        name: name.clone(),
152                        args: Box::new(new_args),
153                    }
154                }
155
156                _ => self.clone(),
157            }
158        }
159    }
160}
161
162/// Recursive helper for pattern matching
163///
164/// Attempts to match an expression against a pattern, accumulating
165/// wildcard bindings in the provided HashMap.
166pub(super) fn match_recursive(
167    expr: &Expression,
168    pattern: &Pattern,
169    bindings: &mut PatternMatches,
170) -> bool {
171    match pattern {
172        Pattern::Wildcard { name, constraints } => {
173            if let Some(constraints) = constraints {
174                if !constraints.is_satisfied_by(expr) {
175                    return false;
176                }
177            }
178
179            if let Some(existing) = bindings.get(name) {
180                expr == existing
181            } else {
182                bindings.insert(name.clone(), expr.clone());
183                true
184            }
185        }
186
187        Pattern::Exact(pattern_expr) => expr == pattern_expr,
188
189        Pattern::Add(pattern_terms) => {
190            if let Expression::Add(expr_terms) = expr {
191                match_commutative(expr_terms, pattern_terms, bindings)
192            } else {
193                false
194            }
195        }
196
197        Pattern::Mul(pattern_factors) => {
198            if let Expression::Mul(expr_factors) = expr {
199                match_commutative(expr_factors, pattern_factors, bindings)
200            } else {
201                false
202            }
203        }
204
205        Pattern::Pow(pattern_base, pattern_exp) => {
206            if let Expression::Pow(expr_base, expr_exp) = expr {
207                match_recursive(expr_base, pattern_base, bindings)
208                    && match_recursive(expr_exp, pattern_exp, bindings)
209            } else {
210                false
211            }
212        }
213
214        Pattern::Function { name, args } => {
215            if let Expression::Function {
216                name: expr_name,
217                args: expr_args,
218            } = expr
219            {
220                if expr_name != name {
221                    return false;
222                }
223
224                if expr_args.len() != args.len() {
225                    return false;
226                }
227
228                for (expr_arg, pattern_arg) in expr_args.iter().zip(args.iter()) {
229                    if !match_recursive(expr_arg, pattern_arg, bindings) {
230                        return false;
231                    }
232                }
233
234                true
235            } else {
236                false
237            }
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::matching::patterns::Pattern;
    use crate::prelude::*;

    /// A bare wildcard matches anything and captures it under its name.
    #[test]
    fn test_wildcard_pattern_matches() {
        let expr = Expression::integer(42);

        let bindings = expr.matches(&Pattern::wildcard("x")).unwrap();

        assert_eq!(bindings.get("x"), Some(&Expression::integer(42)));
    }

    /// An exact pattern matches an equal expression.
    #[test]
    fn test_exact_pattern_matches() {
        let same = Pattern::Exact(Expression::integer(42));

        assert!(Expression::integer(42).matches(&same).is_some());
    }

    /// An exact pattern rejects a different expression.
    #[test]
    fn test_exact_pattern_no_match() {
        let other = Pattern::Exact(Expression::integer(43));

        assert!(Expression::integer(42).matches(&other).is_none());
    }

    /// Two distinct wildcards under `Add` bind the two terms in either order.
    #[test]
    fn test_addition_pattern() {
        let x = symbol!(x);
        let expr = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);

        let bindings = expr.matches(&pattern).unwrap();

        let sym = Expression::symbol(x.clone());
        let one = Expression::integer(1);
        let captured = (bindings.get("a").unwrap(), bindings.get("b").unwrap());
        assert!(captured == (&sym, &one) || captured == (&one, &sym));
    }

    /// An exact factor plus a wildcard factor under `Mul`.
    #[test]
    fn test_multiplication_pattern() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::integer(2)),
            Pattern::wildcard("x"),
        ]);

        let bindings = expr.matches(&pattern).unwrap();

        assert_eq!(bindings.get("x"), Some(&Expression::symbol(x.clone())));
    }

    /// A wildcard base with an exact exponent.
    #[test]
    fn test_power_pattern() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        let pattern = Pattern::Pow(
            Box::new(Pattern::wildcard("base")),
            Box::new(Pattern::Exact(Expression::integer(2))),
        );

        let bindings = expr.matches(&pattern).unwrap();

        assert_eq!(bindings.get("base"), Some(&Expression::symbol(x.clone())));
    }

    /// A named function pattern captures its single argument.
    #[test]
    fn test_function_pattern() {
        let x = symbol!(x);
        let expr = Expression::function("sin".to_string(), vec![Expression::symbol(x.clone())]);
        let pattern = Pattern::Function {
            name: "sin".to_string(),
            args: vec![Pattern::wildcard("arg")],
        };

        let bindings = expr.matches(&pattern).unwrap();

        assert_eq!(bindings.get("arg"), Some(&Expression::symbol(x.clone())));
    }

    /// The same wildcard used twice matches two equal terms.
    #[test]
    fn test_wildcard_consistency() {
        let x = symbol!(x);
        // Built from the raw variant, as in the original test.
        let expr = Expression::Add(Box::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
        ]));
        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);

        let bindings = expr.matches(&pattern).unwrap();

        assert_eq!(bindings.get("a"), Some(&Expression::symbol(x.clone())));
    }

    /// The same wildcard used twice rejects two different terms.
    #[test]
    fn test_wildcard_inconsistency() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);

        assert!(expr.matches(&pattern).is_none());
    }

    /// An excluding wildcard rejects the excluded symbol and expressions containing it.
    #[test]
    fn test_wildcard_with_exclude() {
        let x = symbol!(x);
        let y = symbol!(y);
        let pattern = Pattern::wildcard_excluding("a", vec![Expression::symbol(x.clone())]);

        let excluded = Expression::symbol(x.clone());
        assert!(excluded.matches(&pattern).is_none());

        let unrelated = Expression::symbol(y.clone());
        assert!(unrelated.matches(&pattern).is_some());

        let contains_excluded =
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        assert!(contains_excluded.matches(&pattern).is_none());
    }

    /// A property-constrained wildcard only accepts expressions passing the predicate.
    #[test]
    fn test_wildcard_with_property() {
        // Accepts any `Expression::Number`.
        fn is_number(expr: &Expression) -> bool {
            matches!(expr, Expression::Number(_))
        }

        let pattern = Pattern::wildcard_with_properties("n", vec![is_number]);

        assert!(Expression::integer(42).matches(&pattern).is_some());

        let x = symbol!(x);
        assert!(Expression::symbol(x.clone()).matches(&pattern).is_none());
    }
}