mathhook_core/pattern/matching/engine/
commutative.rs

1//! Commutative matching algorithms
2//!
3//! Handles pattern matching for commutative operations (Add, Mul) with
4//! permutation-based and greedy heuristic matching strategies.
5
6use super::core::match_recursive;
7use super::PatternMatches;
8use crate::core::Expression;
9use crate::pattern::matching::patterns::Pattern;
10
11/// Match commutative operations (Add, Mul)
12///
13/// Tries to match expression terms/factors against pattern terms/factors.
14/// For commutative expressions (scalars), considers all possible orderings.
15/// For noncommutative expressions (matrices, operators), requires exact order.
16pub(super) fn match_commutative(
17    expr_items: &[Expression],
18    pattern_items: &[Pattern],
19    bindings: &mut PatternMatches,
20) -> bool {
21    if pattern_items.is_empty() {
22        return expr_items.is_empty();
23    }
24
25    if pattern_items.len() == 1 {
26        if expr_items.len() == 1 {
27            return match_recursive(&expr_items[0], &pattern_items[0], bindings);
28        } else {
29            return false;
30        }
31    }
32
33    if expr_items.len() != pattern_items.len() {
34        return false;
35    }
36
37    let is_commutative = check_commutativity(expr_items);
38
39    let backup_bindings = bindings.clone();
40    let mut ordered_match = true;
41
42    for (expr_item, pattern_item) in expr_items.iter().zip(pattern_items.iter()) {
43        if !match_recursive(expr_item, pattern_item, bindings) {
44            ordered_match = false;
45            break;
46        }
47    }
48
49    if ordered_match {
50        return true;
51    }
52
53    *bindings = backup_bindings;
54
55    if !is_commutative {
56        return false;
57    }
58
59    if pattern_items.len() <= 6 {
60        try_permutation_match(expr_items, pattern_items, bindings)
61    } else {
62        try_greedy_match(expr_items, pattern_items, bindings)
63    }
64}
65
66/// Check if all expressions in the collection are commutative
67pub fn check_commutativity(items: &[Expression]) -> bool {
68    use crate::core::commutativity::Commutativity;
69
70    for item in items {
71        if item.commutativity() == Commutativity::Noncommutative {
72            return false;
73        }
74    }
75    true
76}
77
78/// Try all permutations of pattern items to find a match
79pub fn try_permutation_match(
80    expr_items: &[Expression],
81    pattern_items: &[Pattern],
82    bindings: &mut PatternMatches,
83) -> bool {
84    if expr_items.len() != pattern_items.len() {
85        return false;
86    }
87
88    let indices: Vec<usize> = (0..pattern_items.len()).collect();
89    try_permutations(&indices, 0, expr_items, pattern_items, bindings)
90}
91
92/// Recursive permutation generator and matcher
93pub fn try_permutations(
94    indices: &[usize],
95    start: usize,
96    expr_items: &[Expression],
97    pattern_items: &[Pattern],
98    bindings: &mut PatternMatches,
99) -> bool {
100    if start == indices.len() {
101        let backup_bindings = bindings.clone();
102        for (expr_idx, &pattern_idx) in indices.iter().enumerate() {
103            if !match_recursive(&expr_items[expr_idx], &pattern_items[pattern_idx], bindings) {
104                *bindings = backup_bindings;
105                return false;
106            }
107        }
108        return true;
109    }
110
111    for i in start..indices.len() {
112        let mut perm = indices.to_vec();
113        perm.swap(start, i);
114        if try_permutations(&perm, start + 1, expr_items, pattern_items, bindings) {
115            return true;
116        }
117    }
118
119    false
120}
121
122/// Greedy heuristic matching for large commutative patterns
123pub fn try_greedy_match(
124    expr_items: &[Expression],
125    pattern_items: &[Pattern],
126    bindings: &mut PatternMatches,
127) -> bool {
128    if expr_items.len() != pattern_items.len() {
129        return false;
130    }
131
132    let mut used_expr: Vec<bool> = vec![false; expr_items.len()];
133    let backup_bindings = bindings.clone();
134
135    for pattern_item in pattern_items {
136        let mut matched = false;
137        for (expr_idx, expr_item) in expr_items.iter().enumerate() {
138            if !used_expr[expr_idx] {
139                let mut temp_bindings = bindings.clone();
140                if match_recursive(expr_item, pattern_item, &mut temp_bindings) {
141                    *bindings = temp_bindings;
142                    used_expr[expr_idx] = true;
143                    matched = true;
144                    break;
145                }
146            }
147        }
148
149        if !matched {
150            *bindings = backup_bindings;
151            return false;
152        }
153    }
154
155    true
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::matching::engine::Matchable;
    use crate::pattern::matching::patterns::Pattern;
    use crate::prelude::*;

    /// Build a `Mul` pattern whose factors must equal `factors` exactly, in order.
    fn exact_product(factors: &[&Expression]) -> Pattern {
        Pattern::Mul(
            factors
                .iter()
                .map(|&factor| Pattern::Exact(factor.clone()))
                .collect(),
        )
    }

    // Two wildcards must bind to the two summands, in either assignment.
    #[test]
    fn test_commutative_addition_matching() {
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        let expr = Expression::add(vec![y.clone(), x.clone()]);

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);

        let matches = expr.matches(&pattern);
        assert!(matches.is_some());

        let bindings = matches.unwrap();
        let a_val = bindings.get("a").unwrap();
        let b_val = bindings.get("b").unwrap();

        let bound_as_written = a_val == &y && b_val == &x;
        let bound_swapped = a_val == &x && b_val == &y;
        assert!(bound_as_written || bound_swapped);
    }

    // A symbol times an integer matches a two-wildcard product pattern.
    #[test]
    fn test_commutative_multiplication_matching() {
        let expr = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::integer(3),
        ]);

        let pattern = Pattern::Mul(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);

        assert!(expr.matches(&pattern).is_some());
    }

    // Three summands match three wildcards.
    #[test]
    fn test_three_term_commutative_match() {
        let expr = Expression::add(vec![
            Expression::symbol(symbol!(z)),
            Expression::symbol(symbol!(y)),
            Expression::symbol(symbol!(x)),
        ]);

        let pattern = Pattern::Add(vec![
            Pattern::wildcard("a"),
            Pattern::wildcard("b"),
            Pattern::wildcard("c"),
        ]);

        assert!(expr.matches(&pattern).is_some());
    }

    // Matrices: the expression BA must not satisfy the pattern AB.
    #[test]
    fn test_matrix_multiplication_no_match_reversed() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));

        let expr = Expression::mul(vec![b.clone(), a.clone()]).simplify();
        let pattern = exact_product(&[&a, &b]);

        assert!(
            expr.matches(&pattern).is_none(),
            "AB pattern should NOT match BA expression for noncommutative matrices"
        );
    }

    // Matrices: the expression AB satisfies the pattern AB.
    #[test]
    fn test_matrix_multiplication_matches_same_order() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));

        let expr = Expression::mul(vec![a.clone(), b.clone()]).simplify();
        let pattern = exact_product(&[&a, &b]);

        assert!(
            expr.matches(&pattern).is_some(),
            "AB pattern should match AB expression for matrices"
        );
    }

    // Scalars: factor order is irrelevant.
    #[test]
    fn test_scalar_multiplication_matches_reversed() {
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));

        let expr = Expression::mul(vec![y.clone(), x.clone()]).simplify();
        let pattern = exact_product(&[&x, &y]);

        assert!(
            expr.matches(&pattern).is_some(),
            "xy pattern should match yx expression for commutative scalars"
        );
    }

    // Operators: the expression xp must not satisfy the pattern px.
    #[test]
    fn test_operator_multiplication_no_match_reversed() {
        let p = Expression::symbol(symbol!(p; operator));
        let x = Expression::symbol(symbol!(x; operator));

        let expr = Expression::mul(vec![x.clone(), p.clone()]).simplify();
        let pattern = exact_product(&[&p, &x]);

        assert!(
            expr.matches(&pattern).is_none(),
            "px pattern should NOT match xp expression for noncommutative operators"
        );
    }

    // Quaternions: the expression ji must not satisfy the pattern ij.
    #[test]
    fn test_quaternion_multiplication_no_match_reversed() {
        let i = Expression::symbol(symbol!(i; quaternion));
        let j = Expression::symbol(symbol!(j; quaternion));

        let expr = Expression::mul(vec![j.clone(), i.clone()]).simplify();
        let pattern = exact_product(&[&i, &j]);

        assert!(
            expr.matches(&pattern).is_none(),
            "ij pattern should NOT match ji expression for noncommutative quaternions"
        );
    }

    // Wildcards over a matrix product bind positionally.
    #[test]
    fn test_matrix_wildcard_pattern_preserves_order() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));

        let expr = Expression::mul(vec![a.clone(), b.clone()]).simplify();
        let pattern = Pattern::Mul(vec![Pattern::wildcard("x"), Pattern::wildcard("y")]);

        let matches = expr.matches(&pattern);
        assert!(matches.is_some());

        let bindings = matches.unwrap();
        assert_eq!(bindings.get("x"), Some(&a));
        assert_eq!(bindings.get("y"), Some(&b));
    }

    // A scalar among matrices does not make the product reorderable.
    #[test]
    fn test_mixed_commutative_noncommutative_respects_order() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));
        let c = Expression::symbol(symbol!(c));

        let expr = Expression::mul(vec![a.clone(), c.clone(), b.clone()]).simplify();

        let pattern_wrong_order = exact_product(&[&a, &b, &c]);
        assert!(
            expr.matches(&pattern_wrong_order).is_none(),
            "AcB should NOT match ABc pattern when matrices are involved"
        );

        let pattern_correct_order = exact_product(&[&a, &c, &b]);
        assert!(
            expr.matches(&pattern_correct_order).is_some(),
            "AcB should match AcB pattern"
        );
    }
}