mathhook_core/pattern/matching/engine/
commutative.rs1use super::core::match_recursive;
7use super::PatternMatches;
8use crate::core::Expression;
9use crate::pattern::matching::patterns::Pattern;
10
11pub(super) fn match_commutative(
17 expr_items: &[Expression],
18 pattern_items: &[Pattern],
19 bindings: &mut PatternMatches,
20) -> bool {
21 if pattern_items.is_empty() {
22 return expr_items.is_empty();
23 }
24
25 if pattern_items.len() == 1 {
26 if expr_items.len() == 1 {
27 return match_recursive(&expr_items[0], &pattern_items[0], bindings);
28 } else {
29 return false;
30 }
31 }
32
33 if expr_items.len() != pattern_items.len() {
34 return false;
35 }
36
37 let is_commutative = check_commutativity(expr_items);
38
39 let backup_bindings = bindings.clone();
40 let mut ordered_match = true;
41
42 for (expr_item, pattern_item) in expr_items.iter().zip(pattern_items.iter()) {
43 if !match_recursive(expr_item, pattern_item, bindings) {
44 ordered_match = false;
45 break;
46 }
47 }
48
49 if ordered_match {
50 return true;
51 }
52
53 *bindings = backup_bindings;
54
55 if !is_commutative {
56 return false;
57 }
58
59 if pattern_items.len() <= 6 {
60 try_permutation_match(expr_items, pattern_items, bindings)
61 } else {
62 try_greedy_match(expr_items, pattern_items, bindings)
63 }
64}
65
66pub fn check_commutativity(items: &[Expression]) -> bool {
68 use crate::core::commutativity::Commutativity;
69
70 for item in items {
71 if item.commutativity() == Commutativity::Noncommutative {
72 return false;
73 }
74 }
75 true
76}
77
78pub fn try_permutation_match(
80 expr_items: &[Expression],
81 pattern_items: &[Pattern],
82 bindings: &mut PatternMatches,
83) -> bool {
84 if expr_items.len() != pattern_items.len() {
85 return false;
86 }
87
88 let indices: Vec<usize> = (0..pattern_items.len()).collect();
89 try_permutations(&indices, 0, expr_items, pattern_items, bindings)
90}
91
92pub fn try_permutations(
94 indices: &[usize],
95 start: usize,
96 expr_items: &[Expression],
97 pattern_items: &[Pattern],
98 bindings: &mut PatternMatches,
99) -> bool {
100 if start == indices.len() {
101 let backup_bindings = bindings.clone();
102 for (expr_idx, &pattern_idx) in indices.iter().enumerate() {
103 if !match_recursive(&expr_items[expr_idx], &pattern_items[pattern_idx], bindings) {
104 *bindings = backup_bindings;
105 return false;
106 }
107 }
108 return true;
109 }
110
111 for i in start..indices.len() {
112 let mut perm = indices.to_vec();
113 perm.swap(start, i);
114 if try_permutations(&perm, start + 1, expr_items, pattern_items, bindings) {
115 return true;
116 }
117 }
118
119 false
120}
121
122pub fn try_greedy_match(
124 expr_items: &[Expression],
125 pattern_items: &[Pattern],
126 bindings: &mut PatternMatches,
127) -> bool {
128 if expr_items.len() != pattern_items.len() {
129 return false;
130 }
131
132 let mut used_expr: Vec<bool> = vec![false; expr_items.len()];
133 let backup_bindings = bindings.clone();
134
135 for pattern_item in pattern_items {
136 let mut matched = false;
137 for (expr_idx, expr_item) in expr_items.iter().enumerate() {
138 if !used_expr[expr_idx] {
139 let mut temp_bindings = bindings.clone();
140 if match_recursive(expr_item, pattern_item, &mut temp_bindings) {
141 *bindings = temp_bindings;
142 used_expr[expr_idx] = true;
143 matched = true;
144 break;
145 }
146 }
147 }
148
149 if !matched {
150 *bindings = backup_bindings;
151 return false;
152 }
153 }
154
155 true
156}
157
// Behavioral tests for commutative operand matching. They drive the matcher
// through `Matchable::matches`. That `Add`/`Mul` patterns reach
// `match_commutative` from there is assumed, not shown in this file — TODO
// confirm.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::matching::engine::Matchable;
    use crate::pattern::matching::patterns::Pattern;
    use crate::prelude::*;

    // Two wildcards against a two-term sum: a match must exist, and the
    // bindings must be one of the two possible orderings. Both are accepted
    // because `Expression::add` may or may not keep the operands in the order
    // given.
    #[test]
    fn test_commutative_addition_matching() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);

        let matches = expr.matches(&pattern);
        assert!(matches.is_some());

        if let Some(bindings) = matches {
            let a_val = bindings.get("a").unwrap();
            let b_val = bindings.get("b").unwrap();

            assert!(
                (a_val == &Expression::symbol(y.clone())
                    && b_val == &Expression::symbol(x.clone()))
                    || (a_val == &Expression::symbol(x.clone())
                        && b_val == &Expression::symbol(y.clone()))
            );
        }
    }

    // Symbol times integer against two wildcards: only the existence of a
    // match is checked, not which operand binds to which wildcard.
    #[test]
    fn test_commutative_multiplication_matching() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::symbol(x.clone()), Expression::integer(3)]);

        let pattern = Pattern::Mul(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);

        let matches = expr.matches(&pattern);
        assert!(matches.is_some());
    }

    // Three-term sum against three wildcards; only existence is checked.
    #[test]
    fn test_three_term_commutative_match() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        let expr = Expression::add(vec![
            Expression::symbol(z.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);

        let pattern = Pattern::Add(vec![
            Pattern::wildcard("a"),
            Pattern::wildcard("b"),
            Pattern::wildcard("c"),
        ]);

        let matches = expr.matches(&pattern);
        assert!(matches.is_some());
    }

    // Matrix symbols are expected to be noncommutative (per the assertion
    // message). The permutation fallback must therefore be skipped, and the
    // exact pattern A·B must not match B·A.
    #[test]
    fn test_matrix_multiplication_no_match_reversed() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(b.clone()),
            Expression::symbol(a.clone()),
        ])
        .simplify();

        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::symbol(a.clone())),
            Pattern::Exact(Expression::symbol(b.clone())),
        ]);

        let matches = expr.matches(&pattern);
        assert!(
            matches.is_none(),
            "AB pattern should NOT match BA expression for noncommutative matrices"
        );
    }

    // Counterpart of the test above: the same-order product must still
    // match via the ordered attempt.
    #[test]
    fn test_matrix_multiplication_matches_same_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ])
        .simplify();

        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::symbol(a.clone())),
            Pattern::Exact(Expression::symbol(b.clone())),
        ]);

        let matches = expr.matches(&pattern);
        assert!(
            matches.is_some(),
            "AB pattern should match AB expression for matrices"
        );
    }

    // Plain scalar symbols commute, so x·y must match y·x. This happens
    // either because `simplify` already reorders the product or through the
    // permutation fallback.
    #[test]
    fn test_scalar_multiplication_matches_reversed() {
        let x = symbol!(x);
        let y = symbol!(y);

        let expr = Expression::mul(vec![
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ])
        .simplify();

        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::symbol(x.clone())),
            Pattern::Exact(Expression::symbol(y.clone())),
        ]);

        let matches = expr.matches(&pattern);
        assert!(
            matches.is_some(),
            "xy pattern should match yx expression for commutative scalars"
        );
    }

    // Same ordering guarantee as the matrix case, for operator-typed symbols.
    #[test]
    fn test_operator_multiplication_no_match_reversed() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);

        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(p.clone()),
        ])
        .simplify();

        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::symbol(p.clone())),
            Pattern::Exact(Expression::symbol(x.clone())),
        ]);

        let matches = expr.matches(&pattern);
        assert!(
            matches.is_none(),
            "px pattern should NOT match xp expression for noncommutative operators"
        );
    }

    // Same ordering guarantee as the matrix case, for quaternion-typed symbols.
    #[test]
    fn test_quaternion_multiplication_no_match_reversed() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let expr = Expression::mul(vec![
            Expression::symbol(j.clone()),
            Expression::symbol(i.clone()),
        ])
        .simplify();

        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::symbol(i.clone())),
            Pattern::Exact(Expression::symbol(j.clone())),
        ]);

        let matches = expr.matches(&pattern);
        assert!(
            matches.is_none(),
            "ij pattern should NOT match ji expression for noncommutative quaternions"
        );
    }

    // Wildcards over a noncommutative product must bind positionally:
    // x -> A, y -> B.
    #[test]
    fn test_matrix_wildcard_pattern_preserves_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ])
        .simplify();

        let pattern = Pattern::Mul(vec![Pattern::wildcard("x"), Pattern::wildcard("y")]);

        let matches = expr.matches(&pattern);
        assert!(matches.is_some());

        if let Some(bindings) = matches {
            assert_eq!(bindings.get("x"), Some(&Expression::symbol(a.clone())));
            assert_eq!(bindings.get("y"), Some(&Expression::symbol(b.clone())));
        }
    }

    // Mixed scalar/matrix product A·c·B. `check_commutativity` refuses
    // reordering if any operand is noncommutative, so (assuming the matrices
    // report Noncommutative) even the scalar `c` cannot be moved. Only the
    // exact A·c·B ordering may match.
    #[test]
    fn test_mixed_commutative_noncommutative_respects_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(c);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(b.clone()),
        ])
        .simplify();

        let pattern_wrong_order = Pattern::Mul(vec![
            Pattern::Exact(Expression::symbol(a.clone())),
            Pattern::Exact(Expression::symbol(b.clone())),
            Pattern::Exact(Expression::symbol(c.clone())),
        ]);

        assert!(
            expr.matches(&pattern_wrong_order).is_none(),
            "AcB should NOT match ABc pattern when matrices are involved"
        );

        let pattern_correct_order = Pattern::Mul(vec![
            Pattern::Exact(Expression::symbol(a.clone())),
            Pattern::Exact(Expression::symbol(c.clone())),
            Pattern::Exact(Expression::symbol(b.clone())),
        ]);

        assert!(
            expr.matches(&pattern_correct_order).is_some(),
            "AcB should match AcB pattern"
        );
    }
}