// oxiz_solver/mbqi/patterns.rs

//! Pattern Matching and Trigger Generation for MBQI
//!
//! This module implements pattern matching and trigger (pattern) generation
//! for E-matching-style quantifier instantiation.
5
6use lasso::Spur;
7use oxiz_core::ast::{TermId, TermKind, TermManager};
8use rustc_hash::{FxHashMap, FxHashSet};
9
10use super::QuantifiedFormula;
11
12/// A pattern for E-matching
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct Pattern {
15    /// The pattern terms
16    pub terms: Vec<TermId>,
17    /// Variables in the pattern
18    pub variables: FxHashSet<Spur>,
19    /// Pattern quality score
20    pub quality: u32,
21    /// Pattern type
22    pub pattern_type: PatternType,
23}
24
25impl Pattern {
26    /// Create a new pattern
27    pub fn new(terms: Vec<TermId>) -> Self {
28        Self {
29            terms,
30            variables: FxHashSet::default(),
31            quality: 0,
32            pattern_type: PatternType::MultiPattern,
33        }
34    }
35
36    /// Extract variables from the pattern
37    pub fn extract_variables(&mut self, manager: &TermManager) {
38        self.variables.clear();
39        // Collect terms first to avoid borrow checker issues
40        let terms: Vec<_> = self.terms.to_vec();
41        for term in terms {
42            self.extract_vars_rec(term, manager);
43        }
44    }
45
46    fn extract_vars_rec(&mut self, term: TermId, manager: &TermManager) {
47        let mut visited = FxHashSet::default();
48        self.extract_vars_helper(term, manager, &mut visited);
49    }
50
51    fn extract_vars_helper(
52        &mut self,
53        term: TermId,
54        manager: &TermManager,
55        visited: &mut FxHashSet<TermId>,
56    ) {
57        if visited.contains(&term) {
58            return;
59        }
60        visited.insert(term);
61
62        let Some(t) = manager.get(term) else {
63            return;
64        };
65
66        if let TermKind::Var(name) = t.kind {
67            self.variables.insert(name);
68            return;
69        }
70
71        match &t.kind {
72            TermKind::Apply { args, .. } => {
73                for &arg in args.iter() {
74                    self.extract_vars_helper(arg, manager, visited);
75                }
76            }
77            TermKind::Not(arg) | TermKind::Neg(arg) => {
78                self.extract_vars_helper(*arg, manager, visited);
79            }
80            TermKind::And(args) | TermKind::Or(args) => {
81                for &arg in args {
82                    self.extract_vars_helper(arg, manager, visited);
83                }
84            }
85            _ => {}
86        }
87    }
88
89    /// Calculate pattern quality
90    pub fn calculate_quality(&mut self, manager: &TermManager) {
91        // Quality factors:
92        // 1. Number of function symbols (more = better)
93        // 2. Number of variables covered
94        // 3. Pattern complexity
95
96        let num_funcs = self.count_function_symbols(manager);
97        let num_vars = self.variables.len();
98        let complexity_penalty = self.terms.len();
99
100        self.quality = (num_funcs * 100 + num_vars * 50) as u32 - complexity_penalty as u32;
101    }
102
103    fn count_function_symbols(&self, manager: &TermManager) -> usize {
104        let mut count = 0;
105        let mut visited = FxHashSet::default();
106
107        for &term in &self.terms {
108            count += self.count_funcs_rec(term, manager, &mut visited);
109        }
110
111        count
112    }
113
114    fn count_funcs_rec(
115        &self,
116        term: TermId,
117        manager: &TermManager,
118        visited: &mut FxHashSet<TermId>,
119    ) -> usize {
120        if visited.contains(&term) {
121            return 0;
122        }
123        visited.insert(term);
124
125        let Some(t) = manager.get(term) else {
126            return 0;
127        };
128
129        match &t.kind {
130            TermKind::Apply { args, .. } => {
131                1 + args
132                    .iter()
133                    .map(|&arg| self.count_funcs_rec(arg, manager, visited))
134                    .sum::<usize>()
135            }
136            _ => 0,
137        }
138    }
139}
140
/// Classification attached to a pattern.
///
/// Variants describe either the pattern's shape (one term vs. several) or
/// where it came from (written by the user vs. produced heuristically).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PatternType {
    /// Pattern consisting of a single term.
    SingleTerm,
    /// Pattern consisting of several terms that must match together.
    MultiPattern,
    /// Pattern supplied explicitly with the quantifier.
    UserSpecified,
    /// Pattern produced by the automatic generator.
    AutoGenerated,
}
153
154/// Pattern generator
155#[derive(Debug)]
156pub struct PatternGenerator {
157    /// Maximum patterns to generate
158    max_patterns: usize,
159    /// Minimum pattern quality
160    min_quality: u32,
161    /// Statistics
162    stats: GeneratorStats,
163}
164
165impl PatternGenerator {
166    /// Create a new pattern generator
167    pub fn new() -> Self {
168        Self {
169            max_patterns: 10,
170            min_quality: 0,
171            stats: GeneratorStats::default(),
172        }
173    }
174
175    /// Generate patterns for a quantifier
176    pub fn generate(
177        &mut self,
178        quantifier: &QuantifiedFormula,
179        manager: &TermManager,
180    ) -> Vec<Pattern> {
181        self.stats.num_generations += 1;
182
183        // If user specified patterns, use those
184        if !quantifier.patterns.is_empty() {
185            return self.user_patterns_to_patterns(&quantifier.patterns, manager);
186        }
187
188        // Auto-generate patterns
189        let mut patterns = Vec::new();
190
191        // Strategy 1: Function application patterns
192        patterns.extend(self.generate_function_patterns(quantifier.body, manager));
193
194        // Strategy 2: Equality patterns
195        patterns.extend(self.generate_equality_patterns(quantifier.body, manager));
196
197        // Strategy 3: Arithmetic patterns
198        patterns.extend(self.generate_arithmetic_patterns(quantifier.body, manager));
199
200        // Filter by quality
201        patterns.retain(|p| p.quality >= self.min_quality);
202
203        // Sort by quality (best first)
204        patterns.sort_by(|a, b| b.quality.cmp(&a.quality));
205
206        // Limit number of patterns
207        patterns.truncate(self.max_patterns);
208
209        self.stats.num_patterns_generated += patterns.len();
210
211        patterns
212    }
213
214    fn user_patterns_to_patterns(
215        &self,
216        user_patterns: &[Vec<TermId>],
217        manager: &TermManager,
218    ) -> Vec<Pattern> {
219        let mut patterns = Vec::new();
220
221        for pattern_terms in user_patterns {
222            let mut pattern = Pattern::new(pattern_terms.clone());
223            pattern.extract_variables(manager);
224            pattern.calculate_quality(manager);
225            pattern.pattern_type = PatternType::UserSpecified;
226            patterns.push(pattern);
227        }
228
229        patterns
230    }
231
232    fn generate_function_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
233        let mut patterns = Vec::new();
234        let func_apps = self.collect_function_applications(body, manager);
235
236        for func_app in func_apps {
237            let mut pattern = Pattern::new(vec![func_app]);
238            pattern.extract_variables(manager);
239            pattern.calculate_quality(manager);
240            pattern.pattern_type = PatternType::AutoGenerated;
241            patterns.push(pattern);
242        }
243
244        patterns
245    }
246
247    fn generate_equality_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
248        let mut patterns = Vec::new();
249        let equalities = self.collect_equalities(body, manager);
250
251        for eq_term in equalities {
252            let mut pattern = Pattern::new(vec![eq_term]);
253            pattern.extract_variables(manager);
254            pattern.calculate_quality(manager);
255            pattern.pattern_type = PatternType::AutoGenerated;
256            patterns.push(pattern);
257        }
258
259        patterns
260    }
261
262    fn generate_arithmetic_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
263        let mut patterns = Vec::new();
264        let arith_terms = self.collect_arithmetic_terms(body, manager);
265
266        for arith_term in arith_terms {
267            let mut pattern = Pattern::new(vec![arith_term]);
268            pattern.extract_variables(manager);
269            pattern.calculate_quality(manager);
270            pattern.pattern_type = PatternType::AutoGenerated;
271            patterns.push(pattern);
272        }
273
274        patterns
275    }
276
277    fn collect_function_applications(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
278        let mut results = Vec::new();
279        let mut visited = FxHashSet::default();
280        self.collect_funcs_rec(term, &mut results, &mut visited, manager);
281        results
282    }
283
284    fn collect_funcs_rec(
285        &self,
286        term: TermId,
287        results: &mut Vec<TermId>,
288        visited: &mut FxHashSet<TermId>,
289        manager: &TermManager,
290    ) {
291        if visited.contains(&term) {
292            return;
293        }
294        visited.insert(term);
295
296        let Some(t) = manager.get(term) else {
297            return;
298        };
299
300        if let TermKind::Apply { args, .. } = &t.kind {
301            results.push(term);
302            for &arg in args.iter() {
303                self.collect_funcs_rec(arg, results, visited, manager);
304            }
305        }
306
307        // Recurse into other term types
308        match &t.kind {
309            TermKind::Not(arg) | TermKind::Neg(arg) => {
310                self.collect_funcs_rec(*arg, results, visited, manager);
311            }
312            TermKind::And(args) | TermKind::Or(args) => {
313                for &arg in args {
314                    self.collect_funcs_rec(arg, results, visited, manager);
315                }
316            }
317            _ => {}
318        }
319    }
320
321    fn collect_equalities(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
322        let mut results = Vec::new();
323        let mut visited = FxHashSet::default();
324        self.collect_eqs_rec(term, &mut results, &mut visited, manager);
325        results
326    }
327
328    fn collect_eqs_rec(
329        &self,
330        term: TermId,
331        results: &mut Vec<TermId>,
332        visited: &mut FxHashSet<TermId>,
333        manager: &TermManager,
334    ) {
335        if visited.contains(&term) {
336            return;
337        }
338        visited.insert(term);
339
340        let Some(t) = manager.get(term) else {
341            return;
342        };
343
344        if matches!(t.kind, TermKind::Eq(_, _)) {
345            results.push(term);
346        }
347
348        match &t.kind {
349            TermKind::Not(arg) | TermKind::Neg(arg) => {
350                self.collect_eqs_rec(*arg, results, visited, manager);
351            }
352            TermKind::And(args) | TermKind::Or(args) => {
353                for &arg in args {
354                    self.collect_eqs_rec(arg, results, visited, manager);
355                }
356            }
357            _ => {}
358        }
359    }
360
361    fn collect_arithmetic_terms(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
362        let mut results = Vec::new();
363        let mut visited = FxHashSet::default();
364        self.collect_arith_rec(term, &mut results, &mut visited, manager);
365        results
366    }
367
368    fn collect_arith_rec(
369        &self,
370        term: TermId,
371        results: &mut Vec<TermId>,
372        visited: &mut FxHashSet<TermId>,
373        manager: &TermManager,
374    ) {
375        if visited.contains(&term) {
376            return;
377        }
378        visited.insert(term);
379
380        let Some(t) = manager.get(term) else {
381            return;
382        };
383
384        match &t.kind {
385            TermKind::Lt(_, _) | TermKind::Le(_, _) | TermKind::Gt(_, _) | TermKind::Ge(_, _) => {
386                results.push(term);
387            }
388            TermKind::Not(arg) | TermKind::Neg(arg) => {
389                self.collect_arith_rec(*arg, results, visited, manager);
390            }
391            TermKind::And(args) | TermKind::Or(args) => {
392                for &arg in args {
393                    self.collect_arith_rec(arg, results, visited, manager);
394                }
395            }
396            _ => {}
397        }
398    }
399
400    /// Get statistics
401    pub fn stats(&self) -> &GeneratorStats {
402        &self.stats
403    }
404}
405
406impl Default for PatternGenerator {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
/// Counters describing the work done by a pattern generator.
#[derive(Clone, Debug, Default)]
pub struct GeneratorStats {
    /// Number of times `generate` has been called.
    pub num_generations: usize,
    /// Running total of auto-generated patterns returned by `generate`.
    pub num_patterns_generated: usize,
}
420
421/// Multi-pattern coordinator
422#[derive(Debug)]
423pub struct MultiPatternCoordinator {
424    /// Pattern sets
425    pattern_sets: Vec<PatternSet>,
426    /// Matching cache
427    match_cache: FxHashMap<TermId, Vec<PatternMatch>>,
428}
429
430impl MultiPatternCoordinator {
431    /// Create a new coordinator
432    pub fn new() -> Self {
433        Self {
434            pattern_sets: Vec::new(),
435            match_cache: FxHashMap::default(),
436        }
437    }
438
439    /// Add a pattern set
440    pub fn add_pattern_set(&mut self, patterns: Vec<Pattern>) {
441        self.pattern_sets.push(PatternSet {
442            patterns,
443            matches: Vec::new(),
444        });
445    }
446
447    /// Find matches for all pattern sets
448    pub fn find_matches(&mut self, _manager: &TermManager) -> Vec<MultiMatch> {
449        let mut multi_matches = Vec::new();
450
451        for pattern_set in &self.pattern_sets {
452            // Find matches for each pattern in the set
453            let mut set_matches = Vec::new();
454
455            for pattern in &pattern_set.patterns {
456                for &term in &pattern.terms {
457                    if let Some(cached) = self.match_cache.get(&term) {
458                        set_matches.extend(cached.clone());
459                    }
460                }
461            }
462
463            // Combine matches
464            if !set_matches.is_empty() {
465                multi_matches.push(MultiMatch {
466                    pattern_set: pattern_set.patterns.clone(),
467                    matches: set_matches,
468                });
469            }
470        }
471
472        multi_matches
473    }
474
475    /// Clear cache
476    pub fn clear_cache(&mut self) {
477        self.match_cache.clear();
478    }
479}
480
481impl Default for MultiPatternCoordinator {
482    fn default() -> Self {
483        Self::new()
484    }
485}
486
487/// A set of patterns that must be matched together
488#[derive(Debug, Clone)]
489struct PatternSet {
490    patterns: Vec<Pattern>,
491    matches: Vec<PatternMatch>,
492}
493
494/// A match for a pattern
495#[derive(Debug, Clone)]
496pub struct PatternMatch {
497    /// The pattern that matched
498    pub pattern: Pattern,
499    /// The matched term
500    pub matched_term: TermId,
501    /// Variable bindings
502    pub bindings: FxHashMap<Spur, TermId>,
503}
504
505/// A multi-pattern match
506#[derive(Debug, Clone)]
507pub struct MultiMatch {
508    /// The pattern set
509    pub pattern_set: Vec<Pattern>,
510    /// Individual matches
511    pub matches: Vec<PatternMatch>,
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_creation() {
        let pattern = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(pattern.terms.len(), 1, "exactly one term was supplied");
        assert!(pattern.variables.is_empty(), "variables start out empty");
    }

    #[test]
    fn test_pattern_type_equality() {
        let single = PatternType::SingleTerm;
        assert_eq!(single, PatternType::SingleTerm);
        assert_ne!(single, PatternType::MultiPattern);
    }

    #[test]
    fn test_pattern_generator_creation() {
        assert_eq!(PatternGenerator::new().max_patterns, 10);
    }

    #[test]
    fn test_multi_pattern_coordinator() {
        let mut coordinator = MultiPatternCoordinator::new();
        coordinator.add_pattern_set(Vec::new());
        assert_eq!(coordinator.pattern_sets.len(), 1);
    }

    #[test]
    fn test_pattern_equality() {
        let make = || Pattern::new(vec![TermId::new(1)]);
        assert_eq!(make(), make());
    }
}