Skip to main content

alkahest_cas/pattern/
matcher.rs

//! AC-aware pattern matching for symbolic expressions.
//!
//! A `Pattern` is a template that may contain named wildcards.  A wildcard
//! matches any sub-expression and binds it to a name; all occurrences of the
//! same wildcard name must match the *same* (structurally equal) expression.
//!
//! # AC semantics
//!
//! `Add` and `Mul` are treated as *associative and commutative* (AC)
//! operators.  A pattern like `a + b` therefore matches the terms of a sum
//! in any order, not just the literal first and second children.  When the
//! sum has more terms than the pattern, the pattern's last argument (if it
//! is a wildcard) absorbs all remaining terms as a single sub-sum.
//!
//! # Search depth
//!
//! To prevent combinatorial explosion the AC search is bounded:
//! - `Add`/`Mul` nodes nested `MAX_AC_DEPTH` or more levels deep are matched
//!   positionally instead of up to reordering.
//! - For a pattern AC node of arity k that is shorter than an n-ary term,
//!   the candidate splits are the ordered selections of k-1 of the n terms;
//!   for equal arity, at most n! orderings are tried.
//!
//! These bounds are conservative for normal CAS expressions.  Callers who
//! need exhaustive matching on large sums/products should pass a custom
//! config in future (extension point).
25use crate::kernel::{ExprData, ExprId, ExprPool};
26use std::collections::HashMap;
27
// Nesting depth at which `match_one` stops AC (order-insensitive) matching:
// an `Add`/`Mul` node reached with `ac_depth >= MAX_AC_DEPTH` is matched
// argument-by-argument in positional order instead.
const MAX_AC_DEPTH: usize = 6;
30
31// ---------------------------------------------------------------------------
32// Public types
33// ---------------------------------------------------------------------------
34
35/// A pattern for matching against symbolic expressions.
36///
37/// Patterns share the same expression representation as regular expressions
38/// but may include `Symbol` nodes that act as wildcards.  A symbol whose
39/// name starts with a lower-case letter (e.g. `a`, `f`, `lhs`) is treated
40/// as a *wildcard variable* that binds to any sub-expression.  Upper-case
41/// or multi-character names that don't match the wildcard convention are
42/// treated as literal symbols.
43///
44/// Use `Pattern::parse` to build patterns from strings, or `Pattern::from_expr`
45/// to use any `ExprId` directly as a pattern (all symbols become wildcards).
46#[derive(Clone, Debug)]
47pub struct Pattern {
48    pub root: ExprId,
49}
50
51impl Pattern {
52    /// Create a pattern from an existing expression.  All `Symbol` nodes in
53    /// the expression become wildcards.
54    pub fn from_expr(root: ExprId) -> Self {
55        Pattern { root }
56    }
57}
58
59/// A binding from wildcard names to matched expression ids.
60#[derive(Clone, Debug, PartialEq, Eq)]
61pub struct Substitution {
62    pub bindings: HashMap<String, ExprId>,
63}
64
65impl Substitution {
66    fn new() -> Self {
67        Substitution {
68            bindings: HashMap::new(),
69        }
70    }
71
72    /// Attempt to bind `name` to `id`.  Returns `false` if `name` is already
73    /// bound to a different expression.
74    fn bind(&mut self, name: &str, id: ExprId) -> bool {
75        match self.bindings.get(name) {
76            Some(&existing) if existing != id => false,
77            _ => {
78                self.bindings.insert(name.to_string(), id);
79                true
80            }
81        }
82    }
83
84    /// Apply the substitution to a pattern expression, returning the
85    /// concrete expression.  Wildcards are replaced by their bindings;
86    /// unbound wildcards are left as-is.
87    pub fn apply(&self, pattern: ExprId, pool: &ExprPool) -> ExprId {
88        apply_subst(pattern, self, pool)
89    }
90}
91
92fn apply_subst(pat: ExprId, subst: &Substitution, pool: &ExprPool) -> ExprId {
93    enum Node {
94        Wildcard(String),
95        Literal,
96        Add(Vec<ExprId>),
97        Mul(Vec<ExprId>),
98        Pow(ExprId, ExprId),
99        Func(String, Vec<ExprId>),
100    }
101
102    let node = pool.with(pat, |data| match data {
103        ExprData::Symbol { name, .. } if is_wildcard(name) => Node::Wildcard(name.clone()),
104        ExprData::Add(args) => Node::Add(args.clone()),
105        ExprData::Mul(args) => Node::Mul(args.clone()),
106        ExprData::Pow { base, exp } => Node::Pow(*base, *exp),
107        ExprData::Func { name, args } => Node::Func(name.clone(), args.clone()),
108        _ => Node::Literal,
109    });
110
111    match node {
112        Node::Wildcard(name) => subst.bindings.get(&name).copied().unwrap_or(pat),
113        Node::Literal => pat,
114        Node::Add(args) => {
115            let new_args: Vec<_> = args.iter().map(|&a| apply_subst(a, subst, pool)).collect();
116            pool.add(new_args)
117        }
118        Node::Mul(args) => {
119            let new_args: Vec<_> = args.iter().map(|&a| apply_subst(a, subst, pool)).collect();
120            pool.mul(new_args)
121        }
122        Node::Pow(base, exp) => pool.pow(
123            apply_subst(base, subst, pool),
124            apply_subst(exp, subst, pool),
125        ),
126        Node::Func(name, args) => {
127            let new_args: Vec<_> = args.iter().map(|&a| apply_subst(a, subst, pool)).collect();
128            pool.func(name, new_args)
129        }
130    }
131}
132
133// ---------------------------------------------------------------------------
134// Helpers
135// ---------------------------------------------------------------------------
136
/// Decide whether a symbol name denotes a wildcard.
///
/// Only the first character is inspected: the name is a wildcard exactly
/// when it is non-empty and that character is lower-case.  The rest of the
/// name is not checked.
fn is_wildcard(name: &str) -> bool {
    matches!(name.chars().next(), Some(first) if first.is_lowercase())
}
142
143// ---------------------------------------------------------------------------
144// Core matching — non-AC
145// ---------------------------------------------------------------------------
146
147/// Try to match `pat` against `expr` given an existing partial `subst`.
148/// Returns the extended substitution on success, `None` on failure.
149fn match_one(
150    pat: ExprId,
151    expr: ExprId,
152    subst: Substitution,
153    pool: &ExprPool,
154    ac_depth: usize,
155) -> Option<Substitution> {
156    enum PatNode {
157        Wildcard(String),
158        Integer(i64),
159        Symbol(String),
160        Add(Vec<ExprId>),
161        Mul(Vec<ExprId>),
162        Pow(ExprId, ExprId),
163        Func(String, Vec<ExprId>),
164        Literal,
165    }
166
167    enum ExprNode {
168        Integer(i64),
169        Symbol(String),
170        Add(Vec<ExprId>),
171        Mul(Vec<ExprId>),
172        Pow(ExprId, ExprId),
173        Func(String, Vec<ExprId>),
174        Other,
175    }
176
177    let pat_node = pool.with(pat, |data| match data {
178        ExprData::Symbol { name, .. } if is_wildcard(name) => PatNode::Wildcard(name.clone()),
179        ExprData::Symbol { name, .. } => PatNode::Symbol(name.clone()),
180        ExprData::Integer(n) => PatNode::Integer(n.0.to_i64().unwrap_or(i64::MIN)),
181        ExprData::Add(args) => PatNode::Add(args.clone()),
182        ExprData::Mul(args) => PatNode::Mul(args.clone()),
183        ExprData::Pow { base, exp } => PatNode::Pow(*base, *exp),
184        ExprData::Func { name, args } => PatNode::Func(name.clone(), args.clone()),
185        ExprData::Rational(_) | ExprData::Float(_) => PatNode::Literal,
186        ExprData::Piecewise { .. } | ExprData::Predicate { .. } => PatNode::Literal,
187        ExprData::Forall { .. } | ExprData::Exists { .. } | ExprData::BigO(_) => PatNode::Literal,
188    });
189
190    let expr_node = pool.with(expr, |data| match data {
191        ExprData::Symbol { name, .. } => ExprNode::Symbol(name.clone()),
192        ExprData::Integer(n) => ExprNode::Integer(n.0.to_i64().unwrap_or(i64::MIN)),
193        ExprData::Add(args) => ExprNode::Add(args.clone()),
194        ExprData::Mul(args) => ExprNode::Mul(args.clone()),
195        ExprData::Pow { base, exp } => ExprNode::Pow(*base, *exp),
196        ExprData::Func { name, args } => ExprNode::Func(name.clone(), args.clone()),
197        _ => ExprNode::Other,
198    });
199
200    match pat_node {
201        // Wildcard: bind to the whole expression
202        PatNode::Wildcard(name) => {
203            let mut s = subst;
204            if s.bind(&name, expr) {
205                Some(s)
206            } else {
207                None
208            }
209        }
210
211        // Literal integer — must match exactly
212        PatNode::Integer(pn) => {
213            if matches!(expr_node, ExprNode::Integer(en) if en == pn) {
214                Some(subst)
215            } else {
216                None
217            }
218        }
219
220        // Literal symbol — must match the same symbol name (not a wildcard)
221        PatNode::Symbol(pname) => {
222            if matches!(expr_node, ExprNode::Symbol(ref ename) if *ename == pname) {
223                Some(subst)
224            } else {
225                None
226            }
227        }
228
229        // AC-aware Add matching
230        PatNode::Add(pat_args) => {
231            let ExprNode::Add(expr_args) = expr_node else {
232                return None;
233            };
234            if ac_depth >= MAX_AC_DEPTH {
235                // Fall back to exact positional matching to bound depth
236                return match_args_exact(&pat_args, &expr_args, subst, pool, ac_depth + 1);
237            }
238            match_ac_args(&pat_args, &expr_args, subst, pool, ac_depth, true)
239        }
240
241        // AC-aware Mul matching
242        PatNode::Mul(pat_args) => {
243            let ExprNode::Mul(expr_args) = expr_node else {
244                return None;
245            };
246            if ac_depth >= MAX_AC_DEPTH {
247                return match_args_exact(&pat_args, &expr_args, subst, pool, ac_depth + 1);
248            }
249            match_ac_args(&pat_args, &expr_args, subst, pool, ac_depth, true)
250        }
251
252        // Pow — exact structural match
253        PatNode::Pow(pb, pe) => {
254            let ExprNode::Pow(eb, ee) = expr_node else {
255                return None;
256            };
257            let s = match_one(pb, eb, subst, pool, ac_depth + 1)?;
258            match_one(pe, ee, s, pool, ac_depth + 1)
259        }
260
261        // Named function — name must match, args AC-matched if Add/Mul
262        PatNode::Func(pname, pargs) => {
263            let ExprNode::Func(ename, eargs) = expr_node else {
264                return None;
265            };
266            if pname != ename {
267                return None;
268            }
269            match_args_exact(&pargs, &eargs, subst, pool, ac_depth + 1)
270        }
271
272        // Rational/Float literal in pattern — match only if same id (structural equality)
273        PatNode::Literal => {
274            if pat == expr {
275                Some(subst)
276            } else {
277                None
278            }
279        }
280    }
281}
282
283/// Match pattern args against expr args positionally (no AC permutations).
284fn match_args_exact(
285    pat_args: &[ExprId],
286    expr_args: &[ExprId],
287    subst: Substitution,
288    pool: &ExprPool,
289    ac_depth: usize,
290) -> Option<Substitution> {
291    if pat_args.len() != expr_args.len() {
292        return None;
293    }
294    let mut s = subst;
295    for (&p, &e) in pat_args.iter().zip(expr_args.iter()) {
296        s = match_one(p, e, s, pool, ac_depth)?;
297    }
298    Some(s)
299}
300
301// ---------------------------------------------------------------------------
302// AC matching
303// ---------------------------------------------------------------------------
304
305/// AC-aware matching for n-ary Add or Mul.
306///
307/// If `pat_args.len() == expr_args.len()` we try all permutations of
308/// expr_args against pat_args (bounded by MAX_AC_DEPTH checks above).
309///
310/// If `pat_args.len() < expr_args.len()` we additionally try all size-k
311/// subsets of expr_args for the first k-1 pat_args, bundling the remainder
312/// into a single Add/Mul node bound to the last wildcard if it is one.
313///
314/// This approach is *sound* (every returned substitution is valid) and
315/// *complete for ground patterns* — every valid ground match is returned.
316fn match_ac_args(
317    pat_args: &[ExprId],
318    expr_args: &[ExprId],
319    subst: Substitution,
320    pool: &ExprPool,
321    ac_depth: usize,
322    is_add: bool,
323) -> Option<Substitution> {
324    if pat_args.is_empty() && expr_args.is_empty() {
325        return Some(subst);
326    }
327    if pat_args.is_empty() || expr_args.is_empty() {
328        return None;
329    }
330
331    // Exact-length case: try all permutations
332    if pat_args.len() == expr_args.len() {
333        return try_permutations(pat_args, expr_args, subst, pool, ac_depth);
334    }
335
336    // Pattern is shorter: try matching a subset of expr_args to the first
337    // pat_args, leaving the rest as a residual bound to the last pattern arg
338    // (only if it's a wildcard).
339    if pat_args.len() < expr_args.len() {
340        let last_pat = *pat_args.last().unwrap();
341        let is_last_wildcard = pool.with(
342            last_pat,
343            |data| matches!(data, ExprData::Symbol { name, .. } if is_wildcard(name)),
344        );
345
346        if !is_last_wildcard {
347            // Can't absorb remainder — no match
348            return None;
349        }
350
351        let prefix_len = pat_args.len() - 1;
352        // Try all size-(prefix_len) subsets of expr_args for the prefix pattern args
353        let indices: Vec<usize> = (0..expr_args.len()).collect();
354        return try_subsets(
355            pat_args, expr_args, &indices, prefix_len, subst, pool, ac_depth, is_add,
356        );
357    }
358
359    // Pattern is longer than expr: no match possible
360    None
361}
362
363/// Try matching pat_args against all permutations of a chosen expr_args subset.
364fn try_permutations(
365    pat_args: &[ExprId],
366    expr_args: &[ExprId],
367    subst: Substitution,
368    pool: &ExprPool,
369    ac_depth: usize,
370) -> Option<Substitution> {
371    // Generate permutations via Heap's algorithm
372    let mut perm: Vec<usize> = (0..expr_args.len()).collect();
373    loop {
374        // Try current permutation
375        let mut s = subst.clone();
376        let mut ok = true;
377        for (i, &pat_id) in pat_args.iter().enumerate() {
378            match match_one(pat_id, expr_args[perm[i]], s.clone(), pool, ac_depth + 1) {
379                Some(new_s) => s = new_s,
380                None => {
381                    ok = false;
382                    break;
383                }
384            }
385        }
386        if ok {
387            return Some(s);
388        }
389
390        // Advance to next permutation (Heap's algorithm)
391        if !next_permutation(&mut perm) {
392            break;
393        }
394    }
395    None
396}
397
/// Rearrange `perm` in place into its lexicographic successor.
///
/// Returns `true` if a successor exists.  Returns `false`, leaving `perm`
/// unchanged, when `perm` is already the last ordering (non-increasing) or
/// has fewer than two elements.
fn next_permutation(perm: &mut [usize]) -> bool {
    // Pivot: the right-most position whose value is smaller than its right
    // neighbour.  Everything after it is non-increasing.
    let Some(pivot) = perm.windows(2).rposition(|pair| pair[0] < pair[1]) else {
        return false;
    };

    // Right-most element of the tail that is larger than the pivot value;
    // the element right after the pivot qualifies, so one always exists.
    let pivot_val = perm[pivot];
    let tail_start = pivot + 1;
    let offset = perm[tail_start..]
        .iter()
        .rposition(|&v| v > pivot_val)
        .expect("element after the pivot is larger than the pivot");

    perm.swap(pivot, tail_start + offset);
    // The tail is still non-increasing; reversing makes it the smallest
    // possible arrangement.
    perm[tail_start..].reverse();
    true
}
417
418/// Try matching prefix pattern args against all size-`prefix_len` subsets of
419/// expr_args, binding the remainder to the last wildcard.
420#[allow(clippy::too_many_arguments)]
421fn try_subsets(
422    pat_args: &[ExprId],
423    expr_args: &[ExprId],
424    indices: &[usize],
425    prefix_len: usize,
426    subst: Substitution,
427    pool: &ExprPool,
428    ac_depth: usize,
429    is_add: bool,
430) -> Option<Substitution> {
431    if prefix_len == 0 {
432        // All expr_args go to the last wildcard
433        let last_pat = *pat_args.last().unwrap();
434        let residual: Vec<ExprId> = indices.iter().map(|&i| expr_args[i]).collect();
435        let residual_expr = match residual.len() {
436            0 => return None,
437            1 => residual[0],
438            _ => {
439                if is_add {
440                    pool.add(residual)
441                } else {
442                    pool.mul(residual)
443                }
444            }
445        };
446        let mut s = subst;
447        s.bind(
448            &pool.with(last_pat, |data| {
449                if let ExprData::Symbol { name, .. } = data {
450                    name.clone()
451                } else {
452                    String::new()
453                }
454            }),
455            residual_expr,
456        );
457        return if s.bindings.values().next().is_some() {
458            Some(s)
459        } else {
460            None
461        };
462    }
463
464    // Pick one element for the next prefix slot and recurse
465    for chosen_pos in 0..indices.len() {
466        let chosen = indices[chosen_pos];
467        let remaining: Vec<usize> = indices
468            .iter()
469            .enumerate()
470            .filter(|&(j, _)| j != chosen_pos)
471            .map(|(_, &i)| i)
472            .collect();
473        let pat_idx = pat_args.len() - 1 - prefix_len; // next prefix pattern index
474        if let Some(s) = match_one(
475            pat_args[pat_idx],
476            expr_args[chosen],
477            subst.clone(),
478            pool,
479            ac_depth + 1,
480        ) {
481            if let Some(final_s) = try_subsets(
482                pat_args,
483                expr_args,
484                &remaining,
485                prefix_len - 1,
486                s,
487                pool,
488                ac_depth,
489                is_add,
490            ) {
491                return Some(final_s);
492            }
493        }
494    }
495    None
496}
497
498// ---------------------------------------------------------------------------
499// Public API
500// ---------------------------------------------------------------------------
501
502/// Find all AC-aware matches of `pattern` anywhere in `expr`.
503///
504/// Returns a list of substitutions, one per distinct match site.  The
505/// search recurses into sub-expressions so that `match_pattern(a + b, f(x + y))`
506/// can match `x + y` inside `f(...)`.
507///
508/// # Example
509/// ```
510/// # use alkahest_cas::kernel::{ExprPool, Domain};
511/// # use alkahest_cas::pattern::{Pattern, match_pattern};
512/// let pool = ExprPool::new();
513/// let x = pool.symbol("x", Domain::Real);
514/// let y = pool.symbol("y", Domain::Real);
515/// let a = pool.symbol("a", Domain::Real);  // wildcard
516/// let b = pool.symbol("b", Domain::Real);  // wildcard
517/// let pat = Pattern::from_expr(pool.add(vec![a, b]));
518/// let expr = pool.add(vec![x, y]);
519/// let matches = match_pattern(&pat, expr, &pool);
520/// assert!(!matches.is_empty());
521/// ```
522pub fn match_pattern(pattern: &Pattern, expr: ExprId, pool: &ExprPool) -> Vec<Substitution> {
523    let mut results = Vec::new();
524    collect_matches(pattern.root, expr, pool, &mut results);
525    results
526}
527
528/// Recursively search `expr` and its sub-expressions for matches of `pat`.
529fn collect_matches(pat: ExprId, expr: ExprId, pool: &ExprPool, results: &mut Vec<Substitution>) {
530    // Try matching at this node
531    if let Some(s) = match_one(pat, expr, Substitution::new(), pool, 0) {
532        results.push(s);
533    }
534
535    // Recurse into children
536    let children: Vec<ExprId> = pool.with(expr, |data| match data {
537        ExprData::Add(args) | ExprData::Mul(args) => args.clone(),
538        ExprData::Pow { base, exp } => vec![*base, *exp],
539        ExprData::Func { args, .. } => args.clone(),
540        _ => vec![],
541    });
542
543    for child in children {
544        collect_matches(pat, child, pool, results);
545    }
546}
547
548// ---------------------------------------------------------------------------
549// Tests
550// ---------------------------------------------------------------------------
551
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh, empty expression pool for a test.
    fn pool() -> ExprPool {
        ExprPool::new()
    }

    /// Real-valued symbol named `name`.  In a pattern, a lower-case first
    /// letter makes it a wildcard; anything else makes it a literal.
    fn sym(p: &ExprPool, name: &str) -> ExprId {
        p.symbol(name, Domain::Real)
    }

    #[test]
    fn wildcard_matches_anything() {
        let p = pool();
        let a = sym(&p, "a"); // wildcard
        let x = sym(&p, "x");
        let found = match_pattern(&Pattern::from_expr(a), x, &p);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].bindings["a"], x);
    }

    #[test]
    fn literal_symbol_exact_match() {
        let p = pool();
        // Upper-case names are literals, not wildcards.
        let x_lit = sym(&p, "X");
        let y_lit = sym(&p, "Y");
        let pat = Pattern::from_expr(x_lit);
        // A literal does not match a different symbol...
        assert!(match_pattern(&pat, y_lit, &p).is_empty());
        // ...but does match itself.
        assert!(!match_pattern(&pat, x_lit, &p).is_empty());
    }

    #[test]
    fn add_pattern_ac_match() {
        let p = pool();
        let (a, b) = (sym(&p, "a"), sym(&p, "b"));
        let (x, y) = (sym(&p, "x"), sym(&p, "y"));
        // Pattern a + b should match x + y, as {a→x, b→y} or {a→y, b→x}.
        let pat = Pattern::from_expr(p.add(vec![a, b]));
        let found = match_pattern(&pat, p.add(vec![x, y]), &p);
        assert!(!found.is_empty(), "a+b should match x+y");
    }

    #[test]
    fn add_pattern_two_splits_for_three_terms() {
        // Pattern a + b on x + y + z: `a` takes one term and the trailing
        // wildcard `b` absorbs the remaining two-term sum.
        let p = pool();
        let (a, b) = (sym(&p, "a"), sym(&p, "b"));
        let (x, y, z) = (sym(&p, "x"), sym(&p, "y"), sym(&p, "z"));
        let pat = Pattern::from_expr(p.add(vec![a, b]));
        let found = match_pattern(&pat, p.add(vec![x, y, z]), &p);
        assert!(!found.is_empty(), "a+b should match subsets of x+y+z");
    }

    #[test]
    fn substitution_apply() {
        let p = pool();
        let a = sym(&p, "a");
        let x = sym(&p, "x");
        let one = p.integer(1_i32);
        let template = p.add(vec![a, one]); // a + 1
        let mut subst = Substitution::new();
        subst.bind("a", x);
        // Applying {a→x} to a + 1 gives x + 1.
        assert_eq!(subst.apply(template, &p), p.add(vec![x, one]));
    }

    #[test]
    fn match_inside_function() {
        // Pattern a + b should match inside f(x + y).
        let p = pool();
        let (a, b) = (sym(&p, "a"), sym(&p, "b"));
        let (x, y) = (sym(&p, "x"), sym(&p, "y"));
        let call = p.func("f", vec![p.add(vec![x, y])]);
        let pat = Pattern::from_expr(p.add(vec![a, b]));
        let found = match_pattern(&pat, call, &p);
        assert!(!found.is_empty(), "should find a+b inside f(x+y)");
    }

    #[test]
    fn no_spurious_matches() {
        // Pattern a * b should NOT match x + y.
        let p = pool();
        let (a, b) = (sym(&p, "a"), sym(&p, "b"));
        let (x, y) = (sym(&p, "x"), sym(&p, "y"));
        let pat = Pattern::from_expr(p.mul(vec![a, b]));
        let sum = p.add(vec![x, y]);
        assert!(
            match_pattern(&pat, sum, &p).is_empty(),
            "mul pattern should not match add"
        );
    }

    #[test]
    fn consistent_wildcard_bindings() {
        // Pattern a + a: both copies of `a` must bind to the same thing.
        let p = pool();
        let a = sym(&p, "a");
        let (x, y) = (sym(&p, "x"), sym(&p, "y"));
        let pat = Pattern::from_expr(p.add(vec![a, a]));
        // x + x should match.
        assert!(!match_pattern(&pat, p.add(vec![x, x]), &p).is_empty());
        // x + y should NOT match.
        assert!(match_pattern(&pat, p.add(vec![x, y]), &p).is_empty());
    }
}
673        assert!(!match_pattern(&pat, p.add(vec![x, x]), &p).is_empty());
674        // x + y should NOT match
675        assert!(match_pattern(&pat, p.add(vec![x, y]), &p).is_empty());
676    }
677}