Skip to main content

oxiz_solver/mbqi/
patterns.rs

//! Pattern Matching and Trigger Generation for MBQI
//!
//! This module implements heuristic trigger (pattern) generation, coverage-based
//! ranking of candidate pattern sets, and bookkeeping for multi-pattern matches
//! used by E-matching style quantifier instantiation.
5
6#[allow(unused_imports)]
7use crate::prelude::*;
8use oxiz_core::ast::{TermId, TermKind, TermManager};
9use oxiz_core::interner::Spur;
10
11use super::{QuantifiedFormula, QuantifierConfig};
12
/// How candidate patterns are ranked before truncation.
///
/// `PatternGenerator::new` reads the active strategy from
/// `QuantifierConfig::default().pattern_strategy`.
/// NOTE(review): `PatternGenerator::generate` currently orders candidates by
/// `Pattern::quality` for both variants.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PatternStrategy {
    /// Rank shallower, cheaper patterns first.
    StaticDepth,
    /// Rank patterns that greedily cover more e-graph ground-term shapes first.
    GreedyCover,
}
21
/// Coarse, top-level structural classification of a term, used for coverage scoring.
///
/// Only the outermost constructor (plus arity for n-ary nodes) is recorded, so
/// structurally different terms can share a shape.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TermShape {
    /// `true` or `false`.
    BoolConst,
    /// Integer literal.
    IntConst,
    /// Real literal.
    RealConst,
    /// A `Var` term.
    Var,
    /// Uninterpreted function application with the given number of arguments.
    Apply { arity: usize },
    /// Equality (`=`).
    Eq,
    /// Strict comparison (`<` or `>`).
    StrictIneq,
    /// Non-strict comparison (`<=` or `>=`).
    NonStrictIneq,
    /// Sum with the given number of summands.
    Add { arity: usize },
    /// Product with the given number of factors.
    Mul { arity: usize },
    /// Any term without a dedicated shape, or an unknown term id.
    Other,
}
48
49impl TermShape {
50    fn from_term(term: TermId, manager: &TermManager) -> Self {
51        let Some(node) = manager.get(term) else {
52            return Self::Other;
53        };
54
55        match &node.kind {
56            TermKind::True | TermKind::False => Self::BoolConst,
57            TermKind::IntConst(_) => Self::IntConst,
58            TermKind::RealConst(_) => Self::RealConst,
59            TermKind::Var(_) => Self::Var,
60            TermKind::Apply { args, .. } => Self::Apply { arity: args.len() },
61            TermKind::Eq(_, _) => Self::Eq,
62            TermKind::Lt(_, _) | TermKind::Gt(_, _) => Self::StrictIneq,
63            TermKind::Le(_, _) | TermKind::Ge(_, _) => Self::NonStrictIneq,
64            TermKind::Add(args) => Self::Add { arity: args.len() },
65            TermKind::Mul(args) => Self::Mul { arity: args.len() },
66            _ => Self::Other,
67        }
68    }
69}
70
/// Stateless scorer that ranks candidate pattern sets by greedy coverage of
/// observed ground-term shapes.
#[derive(Clone, Debug, Default)]
pub struct PatternCoverScorer;
74
75impl PatternCoverScorer {
76    /// Score pattern sets by greedy set cover over e-graph ground-term shapes.
77    pub fn score_cover(
78        &self,
79        candidate_patterns: &[PatternSet],
80        egraph_ground_terms: &[TermShape],
81    ) -> Vec<(usize, f64)> {
82        if candidate_patterns.is_empty() {
83            return Vec::new();
84        }
85
86        let total_shapes = egraph_ground_terms
87            .iter()
88            .cloned()
89            .collect::<FxHashSet<_>>();
90        if total_shapes.is_empty() {
91            return candidate_patterns
92                .iter()
93                .enumerate()
94                .map(|(idx, _)| (idx, 0.0))
95                .collect();
96        }
97
98        let mut remaining = total_shapes;
99        let mut pending: Vec<usize> = (0..candidate_patterns.len()).collect();
100        let mut ranked = Vec::with_capacity(candidate_patterns.len());
101
102        while !pending.is_empty() {
103            let mut best_pos = 0usize;
104            let mut best_gain = 0usize;
105            let mut best_score = 0.0f64;
106
107            for (pos, &idx) in pending.iter().enumerate() {
108                let covered = candidate_patterns[idx]
109                    .covered_shapes
110                    .iter()
111                    .filter(|shape| remaining.contains(*shape))
112                    .count();
113                let score = covered as f64 / egraph_ground_terms.len() as f64;
114                if covered > best_gain || (covered == best_gain && score > best_score) {
115                    best_pos = pos;
116                    best_gain = covered;
117                    best_score = score;
118                }
119            }
120
121            let chosen_idx = pending.remove(best_pos);
122            for shape in &candidate_patterns[chosen_idx].covered_shapes {
123                remaining.remove(shape);
124            }
125            ranked.push((chosen_idx, best_score));
126        }
127
128        ranked
129    }
130}
131
132/// A pattern for E-matching
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct Pattern {
135    /// The pattern terms
136    pub terms: Vec<TermId>,
137    /// Variables in the pattern
138    pub variables: FxHashSet<Spur>,
139    /// Pattern quality score
140    pub quality: u32,
141    /// Pattern type
142    pub pattern_type: PatternType,
143}
144
145impl Pattern {
146    /// Create a new pattern
147    pub fn new(terms: Vec<TermId>) -> Self {
148        Self {
149            terms,
150            variables: FxHashSet::default(),
151            quality: 0,
152            pattern_type: PatternType::MultiPattern,
153        }
154    }
155
156    /// Extract variables from the pattern
157    pub fn extract_variables(&mut self, manager: &TermManager) {
158        self.variables.clear();
159        // Collect terms first to avoid borrow checker issues
160        let terms: Vec<_> = self.terms.to_vec();
161        for term in terms {
162            self.extract_vars_rec(term, manager);
163        }
164    }
165
166    fn extract_vars_rec(&mut self, term: TermId, manager: &TermManager) {
167        let mut visited = FxHashSet::default();
168        self.extract_vars_helper(term, manager, &mut visited);
169    }
170
171    fn extract_vars_helper(
172        &mut self,
173        term: TermId,
174        manager: &TermManager,
175        visited: &mut FxHashSet<TermId>,
176    ) {
177        if visited.contains(&term) {
178            return;
179        }
180        visited.insert(term);
181
182        let Some(t) = manager.get(term) else {
183            return;
184        };
185
186        if let TermKind::Var(name) = t.kind {
187            self.variables.insert(name);
188            return;
189        }
190
191        match &t.kind {
192            TermKind::Apply { args, .. } => {
193                for &arg in args.iter() {
194                    self.extract_vars_helper(arg, manager, visited);
195                }
196            }
197            TermKind::Not(arg) | TermKind::Neg(arg) => {
198                self.extract_vars_helper(*arg, manager, visited);
199            }
200            TermKind::And(args) | TermKind::Or(args) => {
201                for &arg in args {
202                    self.extract_vars_helper(arg, manager, visited);
203                }
204            }
205            _ => {}
206        }
207    }
208
209    /// Calculate pattern quality
210    pub fn calculate_quality(&mut self, manager: &TermManager) {
211        // Quality factors:
212        // 1. Number of function symbols (more = better)
213        // 2. Number of variables covered
214        // 3. Pattern complexity
215
216        let num_funcs = self.count_function_symbols(manager);
217        let num_vars = self.variables.len();
218        let complexity_penalty = self.terms.len();
219
220        self.quality = (num_funcs * 100 + num_vars * 50) as u32 - complexity_penalty as u32;
221    }
222
223    fn count_function_symbols(&self, manager: &TermManager) -> usize {
224        let mut count = 0;
225        let mut visited = FxHashSet::default();
226
227        for &term in &self.terms {
228            count += self.count_funcs_rec(term, manager, &mut visited);
229        }
230
231        count
232    }
233
234    fn count_funcs_rec(
235        &self,
236        term: TermId,
237        manager: &TermManager,
238        visited: &mut FxHashSet<TermId>,
239    ) -> usize {
240        if visited.contains(&term) {
241            return 0;
242        }
243        visited.insert(term);
244
245        let Some(t) = manager.get(term) else {
246            return 0;
247        };
248
249        match &t.kind {
250            TermKind::Apply { args, .. } => {
251                1 + args
252                    .iter()
253                    .map(|&arg| self.count_funcs_rec(arg, manager, visited))
254                    .sum::<usize>()
255            }
256            _ => 0,
257        }
258    }
259}
260
/// Classification of a [`Pattern`] by shape or origin.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PatternType {
    /// Pattern consisting of a single term.
    SingleTerm,
    /// Pattern consisting of several terms; the default chosen by `Pattern::new`.
    MultiPattern,
    /// Pattern supplied by the user on the quantifier.
    UserSpecified,
    /// Pattern produced by `PatternGenerator`'s automatic heuristics.
    AutoGenerated,
}
273
274/// Pattern generator
275#[derive(Debug)]
276pub struct PatternGenerator {
277    /// Maximum patterns to generate
278    max_patterns: usize,
279    /// Minimum pattern quality
280    min_quality: u32,
281    /// Statistics
282    stats: GeneratorStats,
283    /// Pattern ranking strategy
284    strategy: PatternStrategy,
285}
286
287impl PatternGenerator {
288    /// Create a new pattern generator
289    pub fn new() -> Self {
290        let config = QuantifierConfig::default();
291        Self {
292            max_patterns: 10,
293            min_quality: 0,
294            stats: GeneratorStats::default(),
295            strategy: config.pattern_strategy,
296        }
297    }
298
299    /// Generate patterns for a quantifier
300    pub fn generate(
301        &mut self,
302        quantifier: &QuantifiedFormula,
303        manager: &TermManager,
304    ) -> Vec<Pattern> {
305        self.stats.num_generations += 1;
306
307        // If user specified patterns, use those
308        if !quantifier.patterns.is_empty() {
309            return self.user_patterns_to_patterns(&quantifier.patterns, manager);
310        }
311
312        // Auto-generate patterns
313        let mut patterns = Vec::new();
314
315        // Strategy 1: Function application patterns
316        patterns.extend(self.generate_function_patterns(quantifier.body, manager));
317
318        // Strategy 2: Equality patterns
319        patterns.extend(self.generate_equality_patterns(quantifier.body, manager));
320
321        // Strategy 3: Arithmetic patterns
322        patterns.extend(self.generate_arithmetic_patterns(quantifier.body, manager));
323
324        // Filter by quality
325        patterns.retain(|p| p.quality >= self.min_quality);
326
327        match self.strategy {
328            PatternStrategy::StaticDepth => {
329                patterns.sort_by_key(|p| std::cmp::Reverse(p.quality));
330            }
331            PatternStrategy::GreedyCover => {
332                patterns.sort_by_key(|p| std::cmp::Reverse(p.quality));
333            }
334        }
335
336        // Limit number of patterns
337        patterns.truncate(self.max_patterns);
338
339        self.stats.num_patterns_generated += patterns.len();
340
341        patterns
342    }
343
344    fn user_patterns_to_patterns(
345        &self,
346        user_patterns: &[Vec<TermId>],
347        manager: &TermManager,
348    ) -> Vec<Pattern> {
349        let mut patterns = Vec::new();
350
351        for pattern_terms in user_patterns {
352            let mut pattern = Pattern::new(pattern_terms.clone());
353            pattern.extract_variables(manager);
354            pattern.calculate_quality(manager);
355            pattern.pattern_type = PatternType::UserSpecified;
356            patterns.push(pattern);
357        }
358
359        patterns
360    }
361
362    fn generate_function_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
363        let mut patterns = Vec::new();
364        let func_apps = self.collect_function_applications(body, manager);
365
366        for func_app in func_apps {
367            let mut pattern = Pattern::new(vec![func_app]);
368            pattern.extract_variables(manager);
369            pattern.calculate_quality(manager);
370            pattern.pattern_type = PatternType::AutoGenerated;
371            patterns.push(pattern);
372        }
373
374        patterns
375    }
376
377    fn generate_equality_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
378        let mut patterns = Vec::new();
379        let equalities = self.collect_equalities(body, manager);
380
381        for eq_term in equalities {
382            let mut pattern = Pattern::new(vec![eq_term]);
383            pattern.extract_variables(manager);
384            pattern.calculate_quality(manager);
385            pattern.pattern_type = PatternType::AutoGenerated;
386            patterns.push(pattern);
387        }
388
389        patterns
390    }
391
392    fn generate_arithmetic_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
393        let mut patterns = Vec::new();
394        let arith_terms = self.collect_arithmetic_terms(body, manager);
395
396        for arith_term in arith_terms {
397            let mut pattern = Pattern::new(vec![arith_term]);
398            pattern.extract_variables(manager);
399            pattern.calculate_quality(manager);
400            pattern.pattern_type = PatternType::AutoGenerated;
401            patterns.push(pattern);
402        }
403
404        patterns
405    }
406
407    fn collect_function_applications(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
408        let mut results = Vec::new();
409        let mut visited = FxHashSet::default();
410        self.collect_funcs_rec(term, &mut results, &mut visited, manager);
411        results
412    }
413
414    fn collect_funcs_rec(
415        &self,
416        term: TermId,
417        results: &mut Vec<TermId>,
418        visited: &mut FxHashSet<TermId>,
419        manager: &TermManager,
420    ) {
421        if visited.contains(&term) {
422            return;
423        }
424        visited.insert(term);
425
426        let Some(t) = manager.get(term) else {
427            return;
428        };
429
430        if let TermKind::Apply { args, .. } = &t.kind {
431            results.push(term);
432            for &arg in args.iter() {
433                self.collect_funcs_rec(arg, results, visited, manager);
434            }
435        }
436
437        // Recurse into other term types
438        match &t.kind {
439            TermKind::Not(arg) | TermKind::Neg(arg) => {
440                self.collect_funcs_rec(*arg, results, visited, manager);
441            }
442            TermKind::And(args) | TermKind::Or(args) => {
443                for &arg in args {
444                    self.collect_funcs_rec(arg, results, visited, manager);
445                }
446            }
447            _ => {}
448        }
449    }
450
451    fn collect_equalities(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
452        let mut results = Vec::new();
453        let mut visited = FxHashSet::default();
454        self.collect_eqs_rec(term, &mut results, &mut visited, manager);
455        results
456    }
457
458    fn collect_eqs_rec(
459        &self,
460        term: TermId,
461        results: &mut Vec<TermId>,
462        visited: &mut FxHashSet<TermId>,
463        manager: &TermManager,
464    ) {
465        if visited.contains(&term) {
466            return;
467        }
468        visited.insert(term);
469
470        let Some(t) = manager.get(term) else {
471            return;
472        };
473
474        if matches!(t.kind, TermKind::Eq(_, _)) {
475            results.push(term);
476        }
477
478        match &t.kind {
479            TermKind::Not(arg) | TermKind::Neg(arg) => {
480                self.collect_eqs_rec(*arg, results, visited, manager);
481            }
482            TermKind::And(args) | TermKind::Or(args) => {
483                for &arg in args {
484                    self.collect_eqs_rec(arg, results, visited, manager);
485                }
486            }
487            _ => {}
488        }
489    }
490
491    fn collect_arithmetic_terms(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
492        let mut results = Vec::new();
493        let mut visited = FxHashSet::default();
494        self.collect_arith_rec(term, &mut results, &mut visited, manager);
495        results
496    }
497
498    fn collect_arith_rec(
499        &self,
500        term: TermId,
501        results: &mut Vec<TermId>,
502        visited: &mut FxHashSet<TermId>,
503        manager: &TermManager,
504    ) {
505        if visited.contains(&term) {
506            return;
507        }
508        visited.insert(term);
509
510        let Some(t) = manager.get(term) else {
511            return;
512        };
513
514        match &t.kind {
515            TermKind::Lt(_, _) | TermKind::Le(_, _) | TermKind::Gt(_, _) | TermKind::Ge(_, _) => {
516                results.push(term);
517            }
518            TermKind::Not(arg) | TermKind::Neg(arg) => {
519                self.collect_arith_rec(*arg, results, visited, manager);
520            }
521            TermKind::And(args) | TermKind::Or(args) => {
522                for &arg in args {
523                    self.collect_arith_rec(arg, results, visited, manager);
524                }
525            }
526            _ => {}
527        }
528    }
529
530    /// Get statistics
531    pub fn stats(&self) -> &GeneratorStats {
532        &self.stats
533    }
534}
535
536impl Default for PatternGenerator {
537    fn default() -> Self {
538        Self::new()
539    }
540}
541
/// Counters maintained by `PatternGenerator::generate`.
#[derive(Clone, Debug, Default)]
pub struct GeneratorStats {
    /// How many times `generate` has been called.
    pub num_generations: usize,
    /// Total number of patterns produced across all calls.
    pub num_patterns_generated: usize,
}
550
551/// Multi-pattern coordinator
552#[derive(Debug)]
553pub struct MultiPatternCoordinator {
554    /// Pattern sets
555    pattern_sets: Vec<PatternSet>,
556    /// Matching cache
557    match_cache: FxHashMap<TermId, Vec<PatternMatch>>,
558}
559
560impl MultiPatternCoordinator {
561    /// Create a new coordinator
562    pub fn new() -> Self {
563        Self {
564            pattern_sets: Vec::new(),
565            match_cache: FxHashMap::default(),
566        }
567    }
568
569    /// Add a pattern set
570    pub fn add_pattern_set(&mut self, patterns: Vec<Pattern>, manager: &TermManager) {
571        self.pattern_sets
572            .push(PatternSet::from_patterns(patterns, manager));
573    }
574
575    /// Find matches for all pattern sets
576    pub fn find_matches(&mut self, _manager: &TermManager) -> Vec<MultiMatch> {
577        let mut multi_matches = Vec::new();
578
579        for pattern_set in &self.pattern_sets {
580            // Find matches for each pattern in the set
581            let mut set_matches = Vec::new();
582
583            for pattern in &pattern_set.patterns {
584                for &term in &pattern.terms {
585                    if let Some(cached) = self.match_cache.get(&term) {
586                        set_matches.extend(cached.clone());
587                    }
588                }
589            }
590
591            // Combine matches
592            if !set_matches.is_empty() {
593                multi_matches.push(MultiMatch {
594                    pattern_set: pattern_set.patterns.clone(),
595                    matches: set_matches,
596                });
597            }
598        }
599
600        multi_matches
601    }
602
603    /// Clear cache
604    pub fn clear_cache(&mut self) {
605        self.match_cache.clear();
606    }
607}
608
609impl Default for MultiPatternCoordinator {
610    fn default() -> Self {
611        Self::new()
612    }
613}
614
615/// A set of patterns that must be matched together
616#[derive(Debug, Clone)]
617pub struct PatternSet {
618    pub patterns: Vec<Pattern>,
619    pub matches: Vec<PatternMatch>,
620    pub covered_shapes: FxHashSet<TermShape>,
621}
622
623impl PatternSet {
624    /// Build a pattern set and precompute the term shapes it can match.
625    pub fn from_patterns(patterns: Vec<Pattern>, manager: &TermManager) -> Self {
626        let mut covered_shapes = FxHashSet::default();
627        for pattern in &patterns {
628            for &term in &pattern.terms {
629                covered_shapes.insert(TermShape::from_term(term, manager));
630            }
631        }
632        Self {
633            patterns,
634            matches: Vec::new(),
635            covered_shapes,
636        }
637    }
638}
639
640/// A match for a pattern
641#[derive(Debug, Clone)]
642pub struct PatternMatch {
643    /// The pattern that matched
644    pub pattern: Pattern,
645    /// The matched term
646    pub matched_term: TermId,
647    /// Variable bindings
648    pub bindings: FxHashMap<Spur, TermId>,
649}
650
651/// A multi-pattern match
652#[derive(Debug, Clone)]
653pub struct MultiMatch {
654    /// The pattern set
655    pub pattern_set: Vec<Pattern>,
656    /// Individual matches
657    pub matches: Vec<PatternMatch>,
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_creation() {
        let fresh = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(1, fresh.terms.len());
        assert!(fresh.variables.is_empty());
    }

    #[test]
    fn test_pattern_type_equality() {
        let single = PatternType::SingleTerm;
        assert_eq!(single, PatternType::SingleTerm);
        assert_ne!(single, PatternType::MultiPattern);
    }

    #[test]
    fn test_pattern_generator_creation() {
        assert_eq!(PatternGenerator::new().max_patterns, 10);
    }

    #[test]
    fn test_multi_pattern_coordinator() {
        let manager = TermManager::new();
        let mut coordinator = MultiPatternCoordinator::new();
        coordinator.add_pattern_set(Vec::new(), &manager);
        assert_eq!(coordinator.pattern_sets.len(), 1);
    }

    #[test]
    fn test_pattern_equality() {
        let build = || Pattern::new(vec![TermId::new(1)]);
        assert_eq!(build(), build());
    }
}