1use crate::kernel::{ExprData, ExprId, ExprPool};
26use std::collections::HashMap;
27
/// Largest `ac_depth` below which `match_one` matches `Add`/`Mul` arguments up to
/// reordering (via `match_ac_args`). At this depth or deeper it falls back to
/// positional matching via `match_args_exact` — presumably to bound the factorial
/// permutation search inside deeply nested patterns.
const MAX_AC_DEPTH: usize = 6;
30
31#[derive(Clone, Debug)]
47pub struct Pattern {
48 pub root: ExprId,
49}
50
51impl Pattern {
52 pub fn from_expr(root: ExprId) -> Self {
55 Pattern { root }
56 }
57}
58
59#[derive(Clone, Debug, PartialEq, Eq)]
61pub struct Substitution {
62 pub bindings: HashMap<String, ExprId>,
63}
64
65impl Substitution {
66 fn new() -> Self {
67 Substitution {
68 bindings: HashMap::new(),
69 }
70 }
71
72 fn bind(&mut self, name: &str, id: ExprId) -> bool {
75 match self.bindings.get(name) {
76 Some(&existing) if existing != id => false,
77 _ => {
78 self.bindings.insert(name.to_string(), id);
79 true
80 }
81 }
82 }
83
84 pub fn apply(&self, pattern: ExprId, pool: &ExprPool) -> ExprId {
88 apply_subst(pattern, self, pool)
89 }
90}
91
92fn apply_subst(pat: ExprId, subst: &Substitution, pool: &ExprPool) -> ExprId {
93 enum Node {
94 Wildcard(String),
95 Literal,
96 Add(Vec<ExprId>),
97 Mul(Vec<ExprId>),
98 Pow(ExprId, ExprId),
99 Func(String, Vec<ExprId>),
100 }
101
102 let node = pool.with(pat, |data| match data {
103 ExprData::Symbol { name, .. } if is_wildcard(name) => Node::Wildcard(name.clone()),
104 ExprData::Add(args) => Node::Add(args.clone()),
105 ExprData::Mul(args) => Node::Mul(args.clone()),
106 ExprData::Pow { base, exp } => Node::Pow(*base, *exp),
107 ExprData::Func { name, args } => Node::Func(name.clone(), args.clone()),
108 _ => Node::Literal,
109 });
110
111 match node {
112 Node::Wildcard(name) => subst.bindings.get(&name).copied().unwrap_or(pat),
113 Node::Literal => pat,
114 Node::Add(args) => {
115 let new_args: Vec<_> = args.iter().map(|&a| apply_subst(a, subst, pool)).collect();
116 pool.add(new_args)
117 }
118 Node::Mul(args) => {
119 let new_args: Vec<_> = args.iter().map(|&a| apply_subst(a, subst, pool)).collect();
120 pool.mul(new_args)
121 }
122 Node::Pow(base, exp) => pool.pow(
123 apply_subst(base, subst, pool),
124 apply_subst(exp, subst, pool),
125 ),
126 Node::Func(name, args) => {
127 let new_args: Vec<_> = args.iter().map(|&a| apply_subst(a, subst, pool)).collect();
128 pool.func(name, new_args)
129 }
130 }
131}
132
/// Returns `true` if a symbol named `name` acts as a wildcard in patterns.
///
/// A wildcard is any name whose first character is lowercase (Unicode-aware).
/// The empty name is never a wildcard.
fn is_wildcard(name: &str) -> bool {
    matches!(name.chars().next(), Some(first) if first.is_lowercase())
}
142
143fn match_one(
150 pat: ExprId,
151 expr: ExprId,
152 subst: Substitution,
153 pool: &ExprPool,
154 ac_depth: usize,
155) -> Option<Substitution> {
156 enum PatNode {
157 Wildcard(String),
158 Integer(i64),
159 Symbol(String),
160 Add(Vec<ExprId>),
161 Mul(Vec<ExprId>),
162 Pow(ExprId, ExprId),
163 Func(String, Vec<ExprId>),
164 Literal,
165 }
166
167 enum ExprNode {
168 Integer(i64),
169 Symbol(String),
170 Add(Vec<ExprId>),
171 Mul(Vec<ExprId>),
172 Pow(ExprId, ExprId),
173 Func(String, Vec<ExprId>),
174 Other,
175 }
176
177 let pat_node = pool.with(pat, |data| match data {
178 ExprData::Symbol { name, .. } if is_wildcard(name) => PatNode::Wildcard(name.clone()),
179 ExprData::Symbol { name, .. } => PatNode::Symbol(name.clone()),
180 ExprData::Integer(n) => PatNode::Integer(n.0.to_i64().unwrap_or(i64::MIN)),
181 ExprData::Add(args) => PatNode::Add(args.clone()),
182 ExprData::Mul(args) => PatNode::Mul(args.clone()),
183 ExprData::Pow { base, exp } => PatNode::Pow(*base, *exp),
184 ExprData::Func { name, args } => PatNode::Func(name.clone(), args.clone()),
185 ExprData::Rational(_) | ExprData::Float(_) => PatNode::Literal,
186 ExprData::Piecewise { .. } | ExprData::Predicate { .. } => PatNode::Literal,
187 ExprData::Forall { .. } | ExprData::Exists { .. } | ExprData::BigO(_) => PatNode::Literal,
188 });
189
190 let expr_node = pool.with(expr, |data| match data {
191 ExprData::Symbol { name, .. } => ExprNode::Symbol(name.clone()),
192 ExprData::Integer(n) => ExprNode::Integer(n.0.to_i64().unwrap_or(i64::MIN)),
193 ExprData::Add(args) => ExprNode::Add(args.clone()),
194 ExprData::Mul(args) => ExprNode::Mul(args.clone()),
195 ExprData::Pow { base, exp } => ExprNode::Pow(*base, *exp),
196 ExprData::Func { name, args } => ExprNode::Func(name.clone(), args.clone()),
197 _ => ExprNode::Other,
198 });
199
200 match pat_node {
201 PatNode::Wildcard(name) => {
203 let mut s = subst;
204 if s.bind(&name, expr) {
205 Some(s)
206 } else {
207 None
208 }
209 }
210
211 PatNode::Integer(pn) => {
213 if matches!(expr_node, ExprNode::Integer(en) if en == pn) {
214 Some(subst)
215 } else {
216 None
217 }
218 }
219
220 PatNode::Symbol(pname) => {
222 if matches!(expr_node, ExprNode::Symbol(ref ename) if *ename == pname) {
223 Some(subst)
224 } else {
225 None
226 }
227 }
228
229 PatNode::Add(pat_args) => {
231 let ExprNode::Add(expr_args) = expr_node else {
232 return None;
233 };
234 if ac_depth >= MAX_AC_DEPTH {
235 return match_args_exact(&pat_args, &expr_args, subst, pool, ac_depth + 1);
237 }
238 match_ac_args(&pat_args, &expr_args, subst, pool, ac_depth, true)
239 }
240
241 PatNode::Mul(pat_args) => {
243 let ExprNode::Mul(expr_args) = expr_node else {
244 return None;
245 };
246 if ac_depth >= MAX_AC_DEPTH {
247 return match_args_exact(&pat_args, &expr_args, subst, pool, ac_depth + 1);
248 }
249 match_ac_args(&pat_args, &expr_args, subst, pool, ac_depth, true)
250 }
251
252 PatNode::Pow(pb, pe) => {
254 let ExprNode::Pow(eb, ee) = expr_node else {
255 return None;
256 };
257 let s = match_one(pb, eb, subst, pool, ac_depth + 1)?;
258 match_one(pe, ee, s, pool, ac_depth + 1)
259 }
260
261 PatNode::Func(pname, pargs) => {
263 let ExprNode::Func(ename, eargs) = expr_node else {
264 return None;
265 };
266 if pname != ename {
267 return None;
268 }
269 match_args_exact(&pargs, &eargs, subst, pool, ac_depth + 1)
270 }
271
272 PatNode::Literal => {
274 if pat == expr {
275 Some(subst)
276 } else {
277 None
278 }
279 }
280 }
281}
282
283fn match_args_exact(
285 pat_args: &[ExprId],
286 expr_args: &[ExprId],
287 subst: Substitution,
288 pool: &ExprPool,
289 ac_depth: usize,
290) -> Option<Substitution> {
291 if pat_args.len() != expr_args.len() {
292 return None;
293 }
294 let mut s = subst;
295 for (&p, &e) in pat_args.iter().zip(expr_args.iter()) {
296 s = match_one(p, e, s, pool, ac_depth)?;
297 }
298 Some(s)
299}
300
301fn match_ac_args(
317 pat_args: &[ExprId],
318 expr_args: &[ExprId],
319 subst: Substitution,
320 pool: &ExprPool,
321 ac_depth: usize,
322 is_add: bool,
323) -> Option<Substitution> {
324 if pat_args.is_empty() && expr_args.is_empty() {
325 return Some(subst);
326 }
327 if pat_args.is_empty() || expr_args.is_empty() {
328 return None;
329 }
330
331 if pat_args.len() == expr_args.len() {
333 return try_permutations(pat_args, expr_args, subst, pool, ac_depth);
334 }
335
336 if pat_args.len() < expr_args.len() {
340 let last_pat = *pat_args.last().unwrap();
341 let is_last_wildcard = pool.with(
342 last_pat,
343 |data| matches!(data, ExprData::Symbol { name, .. } if is_wildcard(name)),
344 );
345
346 if !is_last_wildcard {
347 return None;
349 }
350
351 let prefix_len = pat_args.len() - 1;
352 let indices: Vec<usize> = (0..expr_args.len()).collect();
354 return try_subsets(
355 pat_args, expr_args, &indices, prefix_len, subst, pool, ac_depth, is_add,
356 );
357 }
358
359 None
361}
362
363fn try_permutations(
365 pat_args: &[ExprId],
366 expr_args: &[ExprId],
367 subst: Substitution,
368 pool: &ExprPool,
369 ac_depth: usize,
370) -> Option<Substitution> {
371 let mut perm: Vec<usize> = (0..expr_args.len()).collect();
373 loop {
374 let mut s = subst.clone();
376 let mut ok = true;
377 for (i, &pat_id) in pat_args.iter().enumerate() {
378 match match_one(pat_id, expr_args[perm[i]], s.clone(), pool, ac_depth + 1) {
379 Some(new_s) => s = new_s,
380 None => {
381 ok = false;
382 break;
383 }
384 }
385 }
386 if ok {
387 return Some(s);
388 }
389
390 if !next_permutation(&mut perm) {
392 break;
393 }
394 }
395 None
396}
397
/// Advances `perm` to the next permutation in lexicographic order.
///
/// Returns `true` if a next permutation existed; `perm` then holds it. Returns
/// `false`, leaving `perm` unchanged, if `perm` is already the last
/// (non-increasing) arrangement. Slices of length 0 or 1 always return `false`.
fn next_permutation(perm: &mut [usize]) -> bool {
    let len = perm.len();
    // Pivot: the element just before the longest non-increasing suffix.
    let Some(pivot) = (1..len).rev().find(|&k| perm[k - 1] < perm[k]).map(|k| k - 1) else {
        return false;
    };
    // Rightmost suffix element larger than the pivot. `perm[pivot + 1]` always
    // qualifies, so this search cannot fail.
    let swap_with = (pivot + 1..len)
        .rev()
        .find(|&k| perm[k] > perm[pivot])
        .unwrap();
    perm.swap(pivot, swap_with);
    // The suffix is still non-increasing; reversing it gives its smallest arrangement.
    perm[pivot + 1..].reverse();
    true
}
417
418#[allow(clippy::too_many_arguments)]
421fn try_subsets(
422 pat_args: &[ExprId],
423 expr_args: &[ExprId],
424 indices: &[usize],
425 prefix_len: usize,
426 subst: Substitution,
427 pool: &ExprPool,
428 ac_depth: usize,
429 is_add: bool,
430) -> Option<Substitution> {
431 if prefix_len == 0 {
432 let last_pat = *pat_args.last().unwrap();
434 let residual: Vec<ExprId> = indices.iter().map(|&i| expr_args[i]).collect();
435 let residual_expr = match residual.len() {
436 0 => return None,
437 1 => residual[0],
438 _ => {
439 if is_add {
440 pool.add(residual)
441 } else {
442 pool.mul(residual)
443 }
444 }
445 };
446 let mut s = subst;
447 s.bind(
448 &pool.with(last_pat, |data| {
449 if let ExprData::Symbol { name, .. } = data {
450 name.clone()
451 } else {
452 String::new()
453 }
454 }),
455 residual_expr,
456 );
457 return if s.bindings.values().next().is_some() {
458 Some(s)
459 } else {
460 None
461 };
462 }
463
464 for chosen_pos in 0..indices.len() {
466 let chosen = indices[chosen_pos];
467 let remaining: Vec<usize> = indices
468 .iter()
469 .enumerate()
470 .filter(|&(j, _)| j != chosen_pos)
471 .map(|(_, &i)| i)
472 .collect();
473 let pat_idx = pat_args.len() - 1 - prefix_len; if let Some(s) = match_one(
475 pat_args[pat_idx],
476 expr_args[chosen],
477 subst.clone(),
478 pool,
479 ac_depth + 1,
480 ) {
481 if let Some(final_s) = try_subsets(
482 pat_args,
483 expr_args,
484 &remaining,
485 prefix_len - 1,
486 s,
487 pool,
488 ac_depth,
489 is_add,
490 ) {
491 return Some(final_s);
492 }
493 }
494 }
495 None
496}
497
498pub fn match_pattern(pattern: &Pattern, expr: ExprId, pool: &ExprPool) -> Vec<Substitution> {
523 let mut results = Vec::new();
524 collect_matches(pattern.root, expr, pool, &mut results);
525 results
526}
527
528fn collect_matches(pat: ExprId, expr: ExprId, pool: &ExprPool, results: &mut Vec<Substitution>) {
530 if let Some(s) = match_one(pat, expr, Substitution::new(), pool, 0) {
532 results.push(s);
533 }
534
535 let children: Vec<ExprId> = pool.with(expr, |data| match data {
537 ExprData::Add(args) | ExprData::Mul(args) => args.clone(),
538 ExprData::Pow { base, exp } => vec![*base, *exp],
539 ExprData::Func { args, .. } => args.clone(),
540 _ => vec![],
541 });
542
543 for child in children {
544 collect_matches(pat, child, pool, results);
545 }
546}
547
// Unit tests for the matcher. `is_wildcard` treats every symbol whose name starts
// with a lowercase letter as a wildcard, but only on the pattern side. In the
// expressions being searched, `x`, `y`, ... are plain symbols.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    // Fresh expression pool for each test.
    fn pool() -> ExprPool {
        ExprPool::new()
    }

    // A bare wildcard pattern matches the whole expression and binds to it.
    #[test]
    fn wildcard_matches_anything() {
        let p = pool();
        let a = p.symbol("a", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let pat = Pattern::from_expr(a);
        let matches = match_pattern(&pat, x, &p);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].bindings["a"], x);
    }

    // An uppercase symbol in a pattern is a literal: `X` matches `X` but not `Y`.
    #[test]
    fn literal_symbol_exact_match() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let xpat = p.symbol("X", Domain::Real);
        let pat = Pattern::from_expr(xpat);
        let y = p.symbol("Y", Domain::Real);
        assert!(match_pattern(&pat, y, &p).is_empty());
        assert!(!match_pattern(&pat, xpat, &p).is_empty());
        // `x` is not used by the assertions; this only silences the unused-variable lint.
        let _ = x;
    }

    // Two wildcards against a two-term sum: equal argument counts, matched by permutation.
    #[test]
    fn add_pattern_ac_match() {
        let p = pool();
        let a = p.symbol("a", Domain::Real);
        let b = p.symbol("b", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let pat = Pattern::from_expr(p.add(vec![a, b]));
        let expr = p.add(vec![x, y]);
        let matches = match_pattern(&pat, expr, &p);
        assert!(!matches.is_empty(), "a+b should match x+y");
    }

    // Two wildcards against a three-term sum: one wildcard takes a single term and
    // the trailing wildcard absorbs the remaining terms.
    #[test]
    fn add_pattern_two_splits_for_three_terms() {
        let p = pool();
        let a = p.symbol("a", Domain::Real);
        let b = p.symbol("b", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let z = p.symbol("z", Domain::Real);
        let pat = Pattern::from_expr(p.add(vec![a, b]));
        let expr = p.add(vec![x, y, z]);
        let matches = match_pattern(&pat, expr, &p);
        assert!(!matches.is_empty(), "a+b should match subsets of x+y+z");
    }

    // Applying {a -> x} to `a + 1` yields `x + 1`. The check compares `ExprId`s,
    // which presumably relies on the pool interning structurally equal sums.
    #[test]
    fn substitution_apply() {
        let p = pool();
        let a = p.symbol("a", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let one = p.integer(1_i32);
        let pat = p.add(vec![a, one]);
        let mut subst = Substitution::new();
        subst.bind("a", x);
        let result = subst.apply(pat, &p);
        let expected = p.add(vec![x, one]);
        assert_eq!(result, expected);
    }

    // `match_pattern` searches subexpressions: `a + b` is found as the argument of `f(x + y)`.
    #[test]
    fn match_inside_function() {
        let p = pool();
        let a = p.symbol("a", Domain::Real);
        let b = p.symbol("b", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let inner = p.add(vec![x, y]);
        let f = p.func("f", vec![inner]);
        let pat = Pattern::from_expr(p.add(vec![a, b]));
        let matches = match_pattern(&pat, f, &p);
        assert!(!matches.is_empty(), "should find a+b inside f(x+y)");
    }

    // A product pattern never matches a sum.
    #[test]
    fn no_spurious_matches() {
        let p = pool();
        let a = p.symbol("a", Domain::Real);
        let b = p.symbol("b", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let pat = Pattern::from_expr(p.mul(vec![a, b]));
        let expr = p.add(vec![x, y]);
        assert!(
            match_pattern(&pat, expr, &p).is_empty(),
            "mul pattern should not match add"
        );
    }

    // A wildcard repeated within one pattern must bind to the same expression each time.
    #[test]
    fn consistent_wildcard_bindings() {
        let p = pool();
        let a = p.symbol("a", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let pat = Pattern::from_expr(p.add(vec![a, a]));
        assert!(!match_pattern(&pat, p.add(vec![x, x]), &p).is_empty());
        assert!(match_pattern(&pat, p.add(vec![x, y]), &p).is_empty());
    }
}