Skip to main content

oxiz_theories/
quantifier_code_tree.rs

1//! E-matching Code Tree for Multi-Pattern Matching.
2//!
3//! This module implements a code tree data structure for efficient E-matching,
4//! which is the core of pattern-based quantifier instantiation in SMT solvers.
5//!
6//! ## Code Tree Structure
7//!
8//! A code tree is a compiled representation of multiple patterns that enables
9//! simultaneous matching against a database of ground terms. It combines:
10//!
11//! 1. **Finite Automaton**: Patterns are compiled into a tree-structured automaton
12//! 2. **Backtracking**: Handles multiple possible matches efficiently
13//! 3. **Substitution Building**: Incrementally builds variable substitutions
14//! 4. **Indexing**: Fast lookup of matching candidates by function symbols
15//!
16//! ## Advantages over Naive Matching
17//!
18//! - **Shared Prefixes**: Common pattern prefixes are matched once
19//! - **Indexing**: Only relevant terms are considered (by root symbol)
20//! - **Early Pruning**: Failed matches are detected early
21//! - **Incremental**: Supports incremental addition of ground terms
22//!
23//! ## Example
24//!
25//! For patterns:
26//! - `f(x, g(x))` from quantifier `∀x. P(f(x, g(x)))`
27//! - `f(y, g(z))` from quantifier `∀y,z. Q(f(y, g(z)))`
28//!
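29//! A code tree can share the matching of `f(?, g(?))` and then check variable constraints. Note: this implementation currently compiles each pattern into its own independent instruction sequence, so common prefixes are not (yet) shared.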
30//!
31//! ## References
32//!
33//! - de Moura & Bjørner: "Efficient E-Matching for SMT Solvers" (2007)
34//! - Z3's `smt/mam.cpp` (matching abstract machine used for E-matching) and `ast/pattern/pattern_inference.cpp`
35//! - Simplify's E-matching implementation
36
37use lasso::Spur;
38use oxiz_core::ast::{TermId, TermKind, TermManager};
39use oxiz_core::sort::SortId;
40use rustc_hash::{FxHashMap, FxHashSet};
41
42/// Variable identifier in patterns.
43pub type PatternVar = u32;
44
45/// Instruction in the code tree (compiled pattern matching bytecode).
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum CodeTreeInstr {
48    /// Check if current term has a specific function symbol.
49    /// Args: (symbol_id, child_count, failure_pc)
50    CheckSymbol {
51        /// Function symbol
52        symbol: Spur,
53        /// Number of arguments
54        arity: usize,
55        /// PC on failure
56        failure_pc: usize,
57    },
58
59    /// Check if current term is a variable (for patterns with nested variables).
60    /// Args: (failure_pc)
61    CheckVar {
62        /// PC on failure
63        failure_pc: usize,
64    },
65
66    /// Check if current term is a constant (specific value).
67    /// Args: (value_id, failure_pc)
68    CheckConstant {
69        /// Expected value
70        value: TermId,
71        /// PC on failure
72        failure_pc: usize,
73    },
74
75    /// Bind current term to a pattern variable.
76    /// Args: (variable_id)
77    Bind {
78        /// Pattern variable ID
79        var: PatternVar,
80    },
81
82    /// Check if current term matches a previously bound variable (occurs check).
83    /// Args: (variable_id, failure_pc)
84    CheckEq {
85        /// Pattern variable ID
86        var: PatternVar,
87        /// PC on failure
88        failure_pc: usize,
89    },
90
91    /// Move to the i-th child of current term.
92    /// Args: (child_index)
93    MoveToChild {
94        /// Child index
95        index: usize,
96    },
97
98    /// Move to parent term (backtrack).
99    MoveToParent,
100
101    /// Yield a match with the current substitution.
102    /// Args: (quantifier_id, pattern_index)
103    Yield {
104        /// Quantifier term ID
105        quantifier: TermId,
106        /// Pattern index
107        pattern_idx: usize,
108    },
109
110    /// Halt execution (end of program).
111    Halt,
112}
113
114/// A compiled pattern for a quantified formula.
115#[derive(Debug, Clone)]
116pub struct CompiledPattern {
117    /// Original pattern terms.
118    pub pattern: Vec<TermId>,
119    /// Compiled instructions.
120    pub instructions: Vec<CodeTreeInstr>,
121    /// Pattern variables and their sorts.
122    pub variables: FxHashMap<PatternVar, SortId>,
123    /// Quantifier this pattern belongs to.
124    pub quantifier: TermId,
125    /// Pattern index within quantifier.
126    pub pattern_index: usize,
127}
128
129/// Match context during execution.
130#[derive(Debug, Clone)]
131struct MatchContext {
132    /// Current term being examined.
133    current_term: TermId,
134    /// Current instruction pointer.
135    pc: usize,
136    /// Variable substitutions.
137    substitution: FxHashMap<PatternVar, TermId>,
138    /// Term stack (for MoveToChild/MoveToParent).
139    term_stack: Vec<TermId>,
140}
141
142impl MatchContext {
143    fn new(root: TermId) -> Self {
144        Self {
145            current_term: root,
146            pc: 0,
147            substitution: FxHashMap::default(),
148            term_stack: vec![root],
149        }
150    }
151}
152
153/// Match result from code tree execution.
154#[derive(Debug, Clone)]
155pub struct Match {
156    /// Quantifier being instantiated.
157    pub quantifier: TermId,
158    /// Pattern index.
159    pub pattern_index: usize,
160    /// Variable substitution (pattern_var -> ground_term).
161    pub substitution: FxHashMap<PatternVar, TermId>,
162}
163
/// Counters describing the work done by a `CodeTree`.
#[derive(Clone, Debug, Default)]
pub struct CodeTreeStats {
    /// Pattern terms compiled into programs.
    pub patterns_compiled: usize,
    /// Total instructions emitted across all compiled programs.
    pub total_instructions: usize,
    /// Distinct ground terms added to the database.
    pub ground_terms_indexed: usize,
    /// Program runs that ended in a match.
    pub matches_found: usize,
    /// Program runs started (one per pattern/candidate pair).
    pub match_attempts: usize,
    /// Program runs that ended without a match.
    pub failed_matches: usize,
    /// Accumulated wall-clock time spent matching, in microseconds.
    pub matching_time_us: u64,
}
182
183/// E-matching code tree for multi-pattern matching.
184pub struct CodeTree {
185    /// Compiled patterns, indexed by root symbol.
186    /// Key: function symbol, Value: list of compiled patterns starting with that symbol.
187    symbol_index: FxHashMap<Spur, Vec<CompiledPattern>>,
188
189    /// Patterns starting with variables (need to check all ground terms).
190    variable_patterns: Vec<CompiledPattern>,
191
192    /// Database of ground terms, indexed by root symbol.
193    /// Key: function symbol, Value: set of term IDs with that root.
194    ground_terms: FxHashMap<Spur, FxHashSet<TermId>>,
195
196    /// All ground terms (for variable patterns).
197    all_ground_terms: FxHashSet<TermId>,
198
199    /// Statistics.
200    stats: CodeTreeStats,
201}
202
203impl CodeTree {
204    /// Create a new code tree.
205    pub fn new() -> Self {
206        Self {
207            symbol_index: FxHashMap::default(),
208            variable_patterns: Vec::new(),
209            ground_terms: FxHashMap::default(),
210            all_ground_terms: FxHashSet::default(),
211            stats: CodeTreeStats::default(),
212        }
213    }
214
215    /// Get statistics.
216    pub fn stats(&self) -> &CodeTreeStats {
217        &self.stats
218    }
219
220    /// Reset statistics.
221    pub fn reset_stats(&mut self) {
222        self.stats = CodeTreeStats::default();
223    }
224
225    /// Compile and add patterns from a quantified formula.
226    ///
227    /// # Arguments
228    /// * `quantifier` - The quantified formula term ID
229    /// * `patterns` - List of pattern lists (multi-patterns)
230    /// * `var_mapping` - Maps bound variable names to pattern variable IDs
231    /// * `tm` - Term manager for accessing term structure
232    pub fn add_patterns(
233        &mut self,
234        quantifier: TermId,
235        patterns: &[Vec<TermId>],
236        var_mapping: &FxHashMap<Spur, PatternVar>,
237        tm: &TermManager,
238    ) {
239        for (pattern_idx, pattern) in patterns.iter().enumerate() {
240            if pattern.is_empty() {
241                continue;
242            }
243
244            // Compile each pattern in the multi-pattern
245            for pattern_term in pattern {
246                let compiled =
247                    self.compile_pattern(*pattern_term, quantifier, pattern_idx, var_mapping, tm);
248
249                if !compiled.instructions.is_empty() {
250                    // Index by root symbol
251                    if let Some(root_sym) = self.get_root_symbol(*pattern_term, tm) {
252                        self.symbol_index
253                            .entry(root_sym)
254                            .or_default()
255                            .push(compiled);
256                    } else {
257                        // Pattern starts with variable
258                        self.variable_patterns.push(compiled);
259                    }
260
261                    self.stats.patterns_compiled += 1;
262                }
263            }
264        }
265    }
266
267    /// Compile a single pattern term into instructions.
268    fn compile_pattern(
269        &mut self,
270        pattern: TermId,
271        quantifier: TermId,
272        pattern_idx: usize,
273        var_mapping: &FxHashMap<Spur, PatternVar>,
274        tm: &TermManager,
275    ) -> CompiledPattern {
276        let mut instructions = Vec::new();
277        let mut variables = FxHashMap::default();
278        let mut bound_vars = FxHashMap::default();
279
280        self.compile_term(
281            pattern,
282            var_mapping,
283            &mut bound_vars,
284            &mut variables,
285            &mut instructions,
286            tm,
287        );
288
289        // Add yield and halt
290        instructions.push(CodeTreeInstr::Yield {
291            quantifier,
292            pattern_idx,
293        });
294        instructions.push(CodeTreeInstr::Halt);
295
296        self.stats.total_instructions += instructions.len();
297
298        CompiledPattern {
299            pattern: vec![pattern],
300            instructions,
301            variables,
302            quantifier,
303            pattern_index: pattern_idx,
304        }
305    }
306
307    /// Recursively compile a pattern term.
308    fn compile_term(
309        &self,
310        term: TermId,
311        var_mapping: &FxHashMap<Spur, PatternVar>,
312        bound_vars: &mut FxHashMap<PatternVar, usize>,
313        variables: &mut FxHashMap<PatternVar, SortId>,
314        instructions: &mut Vec<CodeTreeInstr>,
315        tm: &TermManager,
316    ) {
317        let term_data = tm.get(term).expect("term should exist in manager");
318        match &term_data.kind {
319            TermKind::Var(name) => {
320                let sort = term_data.sort;
321
322                // Pattern variable
323                if let Some(&var_id) = var_mapping.get(name) {
324                    if let Some(&_first_occurrence) = bound_vars.get(&var_id) {
325                        // Variable already seen, check equality
326                        let failure_pc = instructions.len() + 1;
327                        instructions.push(CodeTreeInstr::CheckEq {
328                            var: var_id,
329                            failure_pc,
330                        });
331                    } else {
332                        // First occurrence, bind it
333                        bound_vars.insert(var_id, instructions.len());
334                        variables.insert(var_id, sort);
335                        instructions.push(CodeTreeInstr::Bind { var: var_id });
336                    }
337                } else {
338                    // Free variable (shouldn't happen in well-formed patterns)
339                    let failure_pc = instructions.len() + 1;
340                    instructions.push(CodeTreeInstr::CheckVar { failure_pc });
341                }
342            }
343
344            TermKind::IntConst(_) | TermKind::RealConst(_) | TermKind::BitVecConst { .. } => {
345                // Constant value - check exact match
346                let failure_pc = instructions.len() + 1;
347                instructions.push(CodeTreeInstr::CheckConstant {
348                    value: term,
349                    failure_pc,
350                });
351            }
352
353            TermKind::Apply { func, args } => {
354                // Function application
355                let symbol = *func;
356                let arity = args.len();
357                let failure_pc = instructions.len() + arity + 2;
358
359                instructions.push(CodeTreeInstr::CheckSymbol {
360                    symbol,
361                    arity,
362                    failure_pc,
363                });
364
365                // Recursively compile each argument
366                for (i, &arg) in args.iter().enumerate() {
367                    instructions.push(CodeTreeInstr::MoveToChild { index: i });
368                    self.compile_term(arg, var_mapping, bound_vars, variables, instructions, tm);
369                    instructions.push(CodeTreeInstr::MoveToParent);
370                }
371            }
372
373            _ => {
374                // Other term kinds (e.g., Lambda) - treat as opaque
375                let failure_pc = instructions.len() + 1;
376                instructions.push(CodeTreeInstr::CheckConstant {
377                    value: term,
378                    failure_pc,
379                });
380            }
381        }
382    }
383
384    /// Add a ground term to the database.
385    ///
386    /// This term will be considered for matching against patterns.
387    pub fn add_ground_term(&mut self, term: TermId, tm: &TermManager) {
388        if self.all_ground_terms.contains(&term) {
389            return; // Already indexed
390        }
391
392        self.all_ground_terms.insert(term);
393
394        // Index by root symbol
395        if let Some(root_sym) = self.get_root_symbol(term, tm) {
396            self.ground_terms.entry(root_sym).or_default().insert(term);
397        }
398
399        self.stats.ground_terms_indexed += 1;
400    }
401
402    /// Find all matches for the indexed patterns against the ground terms.
403    ///
404    /// Returns a list of matches (quantifier, pattern_index, substitution).
405    pub fn find_matches(&mut self, tm: &TermManager) -> Vec<Match> {
406        let start = std::time::Instant::now();
407        let mut matches = Vec::new();
408
409        // Match patterns indexed by symbol
410        for (symbol, patterns) in &self.symbol_index {
411            if let Some(ground_terms) = self.ground_terms.get(symbol) {
412                for term in ground_terms {
413                    for pattern in patterns {
414                        self.stats.match_attempts += 1;
415                        if let Some(m) = self.execute_pattern(pattern, *term, tm) {
416                            matches.push(m);
417                            self.stats.matches_found += 1;
418                        } else {
419                            self.stats.failed_matches += 1;
420                        }
421                    }
422                }
423            }
424        }
425
426        // Match variable patterns (against all ground terms)
427        for pattern in &self.variable_patterns {
428            for term in &self.all_ground_terms {
429                self.stats.match_attempts += 1;
430                if let Some(m) = self.execute_pattern(pattern, *term, tm) {
431                    matches.push(m);
432                    self.stats.matches_found += 1;
433                } else {
434                    self.stats.failed_matches += 1;
435                }
436            }
437        }
438
439        self.stats.matching_time_us += start.elapsed().as_micros() as u64;
440        matches
441    }
442
443    /// Execute a compiled pattern against a ground term.
444    fn execute_pattern(
445        &self,
446        pattern: &CompiledPattern,
447        ground_term: TermId,
448        tm: &TermManager,
449    ) -> Option<Match> {
450        let mut context = MatchContext::new(ground_term);
451
452        while context.pc < pattern.instructions.len() {
453            match &pattern.instructions[context.pc] {
454                CodeTreeInstr::CheckSymbol {
455                    symbol,
456                    arity,
457                    failure_pc,
458                } => {
459                    if let Some(current) = tm.get(context.current_term)
460                        && let TermKind::Apply { func, args } = &current.kind
461                        && func == symbol
462                        && args.len() == *arity
463                    {
464                        context.pc += 1;
465                        continue;
466                    }
467                    // Failed match
468                    context.pc = *failure_pc;
469                    if context.pc >= pattern.instructions.len() {
470                        return None;
471                    }
472                }
473
474                CodeTreeInstr::CheckVar { failure_pc } => {
475                    if let Some(current) = tm.get(context.current_term)
476                        && matches!(current.kind, TermKind::Var(_))
477                    {
478                        context.pc += 1;
479                        continue;
480                    }
481                    context.pc = *failure_pc;
482                    if context.pc >= pattern.instructions.len() {
483                        return None;
484                    }
485                }
486
487                CodeTreeInstr::CheckConstant { value, failure_pc } => {
488                    if context.current_term == *value {
489                        context.pc += 1;
490                    } else {
491                        context.pc = *failure_pc;
492                        if context.pc >= pattern.instructions.len() {
493                            return None;
494                        }
495                    }
496                }
497
498                CodeTreeInstr::Bind { var } => {
499                    context.substitution.insert(*var, context.current_term);
500                    context.pc += 1;
501                }
502
503                CodeTreeInstr::CheckEq { var, failure_pc } => {
504                    if let Some(&bound_term) = context.substitution.get(var)
505                        && bound_term == context.current_term
506                    {
507                        context.pc += 1;
508                        continue;
509                    }
510                    context.pc = *failure_pc;
511                    if context.pc >= pattern.instructions.len() {
512                        return None;
513                    }
514                }
515
516                CodeTreeInstr::MoveToChild { index } => {
517                    if let Some(current) = tm.get(context.current_term)
518                        && let TermKind::Apply { args, .. } = &current.kind
519                        && *index < args.len()
520                    {
521                        context.term_stack.push(context.current_term);
522                        context.current_term = args[*index];
523                        context.pc += 1;
524                        continue;
525                    }
526                    return None; // Cannot move to child
527                }
528
529                CodeTreeInstr::MoveToParent => {
530                    if let Some(parent) = context.term_stack.pop() {
531                        context.current_term = parent;
532                        context.pc += 1;
533                    } else {
534                        return None; // No parent
535                    }
536                }
537
538                CodeTreeInstr::Yield {
539                    quantifier,
540                    pattern_idx,
541                } => {
542                    return Some(Match {
543                        quantifier: *quantifier,
544                        pattern_index: *pattern_idx,
545                        substitution: context.substitution.clone(),
546                    });
547                }
548
549                CodeTreeInstr::Halt => {
550                    return None;
551                }
552            }
553        }
554
555        None
556    }
557
558    /// Get the root symbol of a term (function name at root).
559    fn get_root_symbol(&self, term: TermId, tm: &TermManager) -> Option<Spur> {
560        if let Some(term_data) = tm.get(term)
561            && let TermKind::Apply { func, .. } = &term_data.kind
562        {
563            return Some(*func);
564        }
565        None
566    }
567
568    /// Clear all ground terms and reset matching state.
569    pub fn clear_ground_terms(&mut self) {
570        self.ground_terms.clear();
571        self.all_ground_terms.clear();
572        self.stats.ground_terms_indexed = 0;
573    }
574
575    /// Remove patterns associated with a quantifier.
576    pub fn remove_quantifier(&mut self, quantifier: TermId) {
577        // Remove from symbol index
578        for patterns in self.symbol_index.values_mut() {
579            patterns.retain(|p| p.quantifier != quantifier);
580        }
581
582        // Remove from variable patterns
583        self.variable_patterns
584            .retain(|p| p.quantifier != quantifier);
585    }
586}
587
588impl Default for CodeTree {
589    fn default() -> Self {
590        Self::new()
591    }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use lasso::{Key, Rodeo};

    /// Build a fresh term manager together with an empty string interner.
    fn setup_term_manager() -> (TermManager, Rodeo) {
        let manager = TermManager::new();
        let interner = Rodeo::default();
        (manager, interner)
    }

    #[test]
    fn test_code_tree_creation() {
        let tree = CodeTree::new();
        let counters = tree.stats();
        assert_eq!(counters.patterns_compiled, 0);
        assert_eq!(counters.ground_terms_indexed, 0);
    }

    #[test]
    fn test_code_tree_stats() {
        let mut tree = CodeTree::new();
        assert_eq!(tree.stats().matches_found, 0);

        tree.stats.matches_found = 10;
        assert_eq!(tree.stats().matches_found, 10);

        tree.reset_stats();
        assert_eq!(tree.stats().matches_found, 0);
    }

    #[test]
    fn test_clear_ground_terms() {
        let mut tree = CodeTree::new();
        let (_tm, _) = setup_term_manager();

        // Simplified: seed the database directly instead of building real terms.
        tree.all_ground_terms.insert(TermId::new(1));
        tree.stats.ground_terms_indexed = 1;

        tree.clear_ground_terms();
        assert!(tree.all_ground_terms.is_empty());
        assert_eq!(tree.stats.ground_terms_indexed, 0);
    }

    #[test]
    fn test_match_context_creation() {
        let root = TermId::new(1);
        let ctx = MatchContext::new(root);
        assert_eq!(ctx.current_term, root);
        assert_eq!(ctx.pc, 0);
        assert!(ctx.substitution.is_empty());
        assert_eq!(ctx.term_stack.len(), 1);
    }

    #[test]
    fn test_code_tree_instruction_check_symbol() {
        let instr = CodeTreeInstr::CheckSymbol {
            symbol: Spur::try_from_usize(0).unwrap(),
            arity: 2,
            failure_pc: 10,
        };

        let CodeTreeInstr::CheckSymbol {
            arity, failure_pc, ..
        } = instr
        else {
            panic!("Wrong instruction type");
        };
        assert_eq!(arity, 2);
        assert_eq!(failure_pc, 10);
    }

    #[test]
    fn test_code_tree_instruction_bind() {
        let instr = CodeTreeInstr::Bind { var: 42 };

        let CodeTreeInstr::Bind { var } = instr else {
            panic!("Wrong instruction type");
        };
        assert_eq!(var, 42);
    }

    #[test]
    fn test_compiled_pattern() {
        let owner = TermId::new(2);
        let program = vec![
            CodeTreeInstr::Bind { var: 0 },
            CodeTreeInstr::Yield {
                quantifier: owner,
                pattern_idx: 0,
            },
            CodeTreeInstr::Halt,
        ];
        let pattern = CompiledPattern {
            pattern: vec![TermId::new(1)],
            instructions: program,
            variables: FxHashMap::default(),
            quantifier: owner,
            pattern_index: 0,
        };

        assert_eq!(pattern.instructions.len(), 3);
        assert_eq!(pattern.pattern_index, 0);
    }

    #[test]
    fn test_match_result() {
        let substitution: FxHashMap<PatternVar, TermId> =
            std::iter::once((0, TermId::new(100))).collect();

        let m = Match {
            quantifier: TermId::new(1),
            pattern_index: 0,
            substitution,
        };

        assert_eq!(m.quantifier, TermId::new(1));
        assert_eq!(m.pattern_index, 0);
        assert_eq!(m.substitution.len(), 1);
    }
}