Skip to main content

oxiz_solver/mbqi/
patterns.rs

1//! Pattern Matching and Trigger Generation for MBQI
2//!
//! This module implements trigger (pattern) generation for E-matching style
//! quantifier instantiation, plus a coordinator that collects cached matches
//! for groups of patterns.
5
6#[allow(unused_imports)]
7use crate::prelude::*;
8use oxiz_core::ast::{TermId, TermKind, TermManager};
9use oxiz_core::interner::Spur;
10
11use super::QuantifiedFormula;
12
13/// A pattern for E-matching
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct Pattern {
16    /// The pattern terms
17    pub terms: Vec<TermId>,
18    /// Variables in the pattern
19    pub variables: FxHashSet<Spur>,
20    /// Pattern quality score
21    pub quality: u32,
22    /// Pattern type
23    pub pattern_type: PatternType,
24}
25
26impl Pattern {
27    /// Create a new pattern
28    pub fn new(terms: Vec<TermId>) -> Self {
29        Self {
30            terms,
31            variables: FxHashSet::default(),
32            quality: 0,
33            pattern_type: PatternType::MultiPattern,
34        }
35    }
36
37    /// Extract variables from the pattern
38    pub fn extract_variables(&mut self, manager: &TermManager) {
39        self.variables.clear();
40        // Collect terms first to avoid borrow checker issues
41        let terms: Vec<_> = self.terms.to_vec();
42        for term in terms {
43            self.extract_vars_rec(term, manager);
44        }
45    }
46
47    fn extract_vars_rec(&mut self, term: TermId, manager: &TermManager) {
48        let mut visited = FxHashSet::default();
49        self.extract_vars_helper(term, manager, &mut visited);
50    }
51
52    fn extract_vars_helper(
53        &mut self,
54        term: TermId,
55        manager: &TermManager,
56        visited: &mut FxHashSet<TermId>,
57    ) {
58        if visited.contains(&term) {
59            return;
60        }
61        visited.insert(term);
62
63        let Some(t) = manager.get(term) else {
64            return;
65        };
66
67        if let TermKind::Var(name) = t.kind {
68            self.variables.insert(name);
69            return;
70        }
71
72        match &t.kind {
73            TermKind::Apply { args, .. } => {
74                for &arg in args.iter() {
75                    self.extract_vars_helper(arg, manager, visited);
76                }
77            }
78            TermKind::Not(arg) | TermKind::Neg(arg) => {
79                self.extract_vars_helper(*arg, manager, visited);
80            }
81            TermKind::And(args) | TermKind::Or(args) => {
82                for &arg in args {
83                    self.extract_vars_helper(arg, manager, visited);
84                }
85            }
86            _ => {}
87        }
88    }
89
90    /// Calculate pattern quality
91    pub fn calculate_quality(&mut self, manager: &TermManager) {
92        // Quality factors:
93        // 1. Number of function symbols (more = better)
94        // 2. Number of variables covered
95        // 3. Pattern complexity
96
97        let num_funcs = self.count_function_symbols(manager);
98        let num_vars = self.variables.len();
99        let complexity_penalty = self.terms.len();
100
101        self.quality = (num_funcs * 100 + num_vars * 50) as u32 - complexity_penalty as u32;
102    }
103
104    fn count_function_symbols(&self, manager: &TermManager) -> usize {
105        let mut count = 0;
106        let mut visited = FxHashSet::default();
107
108        for &term in &self.terms {
109            count += self.count_funcs_rec(term, manager, &mut visited);
110        }
111
112        count
113    }
114
115    fn count_funcs_rec(
116        &self,
117        term: TermId,
118        manager: &TermManager,
119        visited: &mut FxHashSet<TermId>,
120    ) -> usize {
121        if visited.contains(&term) {
122            return 0;
123        }
124        visited.insert(term);
125
126        let Some(t) = manager.get(term) else {
127            return 0;
128        };
129
130        match &t.kind {
131            TermKind::Apply { args, .. } => {
132                1 + args
133                    .iter()
134                    .map(|&arg| self.count_funcs_rec(arg, manager, visited))
135                    .sum::<usize>()
136            }
137            _ => 0,
138        }
139    }
140}
141
/// Classification attached to a [`Pattern`].
///
/// The variants cover two separate concerns: arity (`SingleTerm`,
/// `MultiPattern`) and origin (`UserSpecified`, `AutoGenerated`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PatternType {
    /// The pattern consists of one term.
    SingleTerm,
    /// The pattern consists of several terms.
    MultiPattern,
    /// The pattern was supplied by the user on the quantifier.
    UserSpecified,
    /// The pattern was produced automatically.
    AutoGenerated,
}
154
155/// Pattern generator
156#[derive(Debug)]
157pub struct PatternGenerator {
158    /// Maximum patterns to generate
159    max_patterns: usize,
160    /// Minimum pattern quality
161    min_quality: u32,
162    /// Statistics
163    stats: GeneratorStats,
164}
165
166impl PatternGenerator {
167    /// Create a new pattern generator
168    pub fn new() -> Self {
169        Self {
170            max_patterns: 10,
171            min_quality: 0,
172            stats: GeneratorStats::default(),
173        }
174    }
175
176    /// Generate patterns for a quantifier
177    pub fn generate(
178        &mut self,
179        quantifier: &QuantifiedFormula,
180        manager: &TermManager,
181    ) -> Vec<Pattern> {
182        self.stats.num_generations += 1;
183
184        // If user specified patterns, use those
185        if !quantifier.patterns.is_empty() {
186            return self.user_patterns_to_patterns(&quantifier.patterns, manager);
187        }
188
189        // Auto-generate patterns
190        let mut patterns = Vec::new();
191
192        // Strategy 1: Function application patterns
193        patterns.extend(self.generate_function_patterns(quantifier.body, manager));
194
195        // Strategy 2: Equality patterns
196        patterns.extend(self.generate_equality_patterns(quantifier.body, manager));
197
198        // Strategy 3: Arithmetic patterns
199        patterns.extend(self.generate_arithmetic_patterns(quantifier.body, manager));
200
201        // Filter by quality
202        patterns.retain(|p| p.quality >= self.min_quality);
203
204        // Sort by quality (best first)
205        patterns.sort_by_key(|p| std::cmp::Reverse(p.quality));
206
207        // Limit number of patterns
208        patterns.truncate(self.max_patterns);
209
210        self.stats.num_patterns_generated += patterns.len();
211
212        patterns
213    }
214
215    fn user_patterns_to_patterns(
216        &self,
217        user_patterns: &[Vec<TermId>],
218        manager: &TermManager,
219    ) -> Vec<Pattern> {
220        let mut patterns = Vec::new();
221
222        for pattern_terms in user_patterns {
223            let mut pattern = Pattern::new(pattern_terms.clone());
224            pattern.extract_variables(manager);
225            pattern.calculate_quality(manager);
226            pattern.pattern_type = PatternType::UserSpecified;
227            patterns.push(pattern);
228        }
229
230        patterns
231    }
232
233    fn generate_function_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
234        let mut patterns = Vec::new();
235        let func_apps = self.collect_function_applications(body, manager);
236
237        for func_app in func_apps {
238            let mut pattern = Pattern::new(vec![func_app]);
239            pattern.extract_variables(manager);
240            pattern.calculate_quality(manager);
241            pattern.pattern_type = PatternType::AutoGenerated;
242            patterns.push(pattern);
243        }
244
245        patterns
246    }
247
248    fn generate_equality_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
249        let mut patterns = Vec::new();
250        let equalities = self.collect_equalities(body, manager);
251
252        for eq_term in equalities {
253            let mut pattern = Pattern::new(vec![eq_term]);
254            pattern.extract_variables(manager);
255            pattern.calculate_quality(manager);
256            pattern.pattern_type = PatternType::AutoGenerated;
257            patterns.push(pattern);
258        }
259
260        patterns
261    }
262
263    fn generate_arithmetic_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
264        let mut patterns = Vec::new();
265        let arith_terms = self.collect_arithmetic_terms(body, manager);
266
267        for arith_term in arith_terms {
268            let mut pattern = Pattern::new(vec![arith_term]);
269            pattern.extract_variables(manager);
270            pattern.calculate_quality(manager);
271            pattern.pattern_type = PatternType::AutoGenerated;
272            patterns.push(pattern);
273        }
274
275        patterns
276    }
277
278    fn collect_function_applications(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
279        let mut results = Vec::new();
280        let mut visited = FxHashSet::default();
281        self.collect_funcs_rec(term, &mut results, &mut visited, manager);
282        results
283    }
284
285    fn collect_funcs_rec(
286        &self,
287        term: TermId,
288        results: &mut Vec<TermId>,
289        visited: &mut FxHashSet<TermId>,
290        manager: &TermManager,
291    ) {
292        if visited.contains(&term) {
293            return;
294        }
295        visited.insert(term);
296
297        let Some(t) = manager.get(term) else {
298            return;
299        };
300
301        if let TermKind::Apply { args, .. } = &t.kind {
302            results.push(term);
303            for &arg in args.iter() {
304                self.collect_funcs_rec(arg, results, visited, manager);
305            }
306        }
307
308        // Recurse into other term types
309        match &t.kind {
310            TermKind::Not(arg) | TermKind::Neg(arg) => {
311                self.collect_funcs_rec(*arg, results, visited, manager);
312            }
313            TermKind::And(args) | TermKind::Or(args) => {
314                for &arg in args {
315                    self.collect_funcs_rec(arg, results, visited, manager);
316                }
317            }
318            _ => {}
319        }
320    }
321
322    fn collect_equalities(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
323        let mut results = Vec::new();
324        let mut visited = FxHashSet::default();
325        self.collect_eqs_rec(term, &mut results, &mut visited, manager);
326        results
327    }
328
329    fn collect_eqs_rec(
330        &self,
331        term: TermId,
332        results: &mut Vec<TermId>,
333        visited: &mut FxHashSet<TermId>,
334        manager: &TermManager,
335    ) {
336        if visited.contains(&term) {
337            return;
338        }
339        visited.insert(term);
340
341        let Some(t) = manager.get(term) else {
342            return;
343        };
344
345        if matches!(t.kind, TermKind::Eq(_, _)) {
346            results.push(term);
347        }
348
349        match &t.kind {
350            TermKind::Not(arg) | TermKind::Neg(arg) => {
351                self.collect_eqs_rec(*arg, results, visited, manager);
352            }
353            TermKind::And(args) | TermKind::Or(args) => {
354                for &arg in args {
355                    self.collect_eqs_rec(arg, results, visited, manager);
356                }
357            }
358            _ => {}
359        }
360    }
361
362    fn collect_arithmetic_terms(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
363        let mut results = Vec::new();
364        let mut visited = FxHashSet::default();
365        self.collect_arith_rec(term, &mut results, &mut visited, manager);
366        results
367    }
368
369    fn collect_arith_rec(
370        &self,
371        term: TermId,
372        results: &mut Vec<TermId>,
373        visited: &mut FxHashSet<TermId>,
374        manager: &TermManager,
375    ) {
376        if visited.contains(&term) {
377            return;
378        }
379        visited.insert(term);
380
381        let Some(t) = manager.get(term) else {
382            return;
383        };
384
385        match &t.kind {
386            TermKind::Lt(_, _) | TermKind::Le(_, _) | TermKind::Gt(_, _) | TermKind::Ge(_, _) => {
387                results.push(term);
388            }
389            TermKind::Not(arg) | TermKind::Neg(arg) => {
390                self.collect_arith_rec(*arg, results, visited, manager);
391            }
392            TermKind::And(args) | TermKind::Or(args) => {
393                for &arg in args {
394                    self.collect_arith_rec(arg, results, visited, manager);
395                }
396            }
397            _ => {}
398        }
399    }
400
401    /// Get statistics
402    pub fn stats(&self) -> &GeneratorStats {
403        &self.stats
404    }
405}
406
407impl Default for PatternGenerator {
408    fn default() -> Self {
409        Self::new()
410    }
411}
412
/// Counters kept by a pattern generator.
#[derive(Clone, Debug, Default)]
pub struct GeneratorStats {
    /// How many times generation has been requested.
    pub num_generations: usize,
    /// Cumulative number of patterns generated.
    pub num_patterns_generated: usize,
}
421
422/// Multi-pattern coordinator
423#[derive(Debug)]
424pub struct MultiPatternCoordinator {
425    /// Pattern sets
426    pattern_sets: Vec<PatternSet>,
427    /// Matching cache
428    match_cache: FxHashMap<TermId, Vec<PatternMatch>>,
429}
430
431impl MultiPatternCoordinator {
432    /// Create a new coordinator
433    pub fn new() -> Self {
434        Self {
435            pattern_sets: Vec::new(),
436            match_cache: FxHashMap::default(),
437        }
438    }
439
440    /// Add a pattern set
441    pub fn add_pattern_set(&mut self, patterns: Vec<Pattern>) {
442        self.pattern_sets.push(PatternSet {
443            patterns,
444            matches: Vec::new(),
445        });
446    }
447
448    /// Find matches for all pattern sets
449    pub fn find_matches(&mut self, _manager: &TermManager) -> Vec<MultiMatch> {
450        let mut multi_matches = Vec::new();
451
452        for pattern_set in &self.pattern_sets {
453            // Find matches for each pattern in the set
454            let mut set_matches = Vec::new();
455
456            for pattern in &pattern_set.patterns {
457                for &term in &pattern.terms {
458                    if let Some(cached) = self.match_cache.get(&term) {
459                        set_matches.extend(cached.clone());
460                    }
461                }
462            }
463
464            // Combine matches
465            if !set_matches.is_empty() {
466                multi_matches.push(MultiMatch {
467                    pattern_set: pattern_set.patterns.clone(),
468                    matches: set_matches,
469                });
470            }
471        }
472
473        multi_matches
474    }
475
476    /// Clear cache
477    pub fn clear_cache(&mut self) {
478        self.match_cache.clear();
479    }
480}
481
482impl Default for MultiPatternCoordinator {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
488/// A set of patterns that must be matched together
489#[derive(Debug, Clone)]
490struct PatternSet {
491    patterns: Vec<Pattern>,
492    matches: Vec<PatternMatch>,
493}
494
495/// A match for a pattern
496#[derive(Debug, Clone)]
497pub struct PatternMatch {
498    /// The pattern that matched
499    pub pattern: Pattern,
500    /// The matched term
501    pub matched_term: TermId,
502    /// Variable bindings
503    pub bindings: FxHashMap<Spur, TermId>,
504}
505
506/// A multi-pattern match
507#[derive(Debug, Clone)]
508pub struct MultiMatch {
509    /// The pattern set
510    pub pattern_set: Vec<Pattern>,
511    /// Individual matches
512    pub matches: Vec<PatternMatch>,
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_creation() {
        let single = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(1, single.terms.len());
        assert!(single.variables.is_empty());
    }

    #[test]
    fn test_pattern_type_equality() {
        let single = PatternType::SingleTerm;
        assert_eq!(single, PatternType::SingleTerm);
        assert_ne!(single, PatternType::MultiPattern);
    }

    #[test]
    fn test_pattern_generator_creation() {
        let PatternGenerator { max_patterns, .. } = PatternGenerator::new();
        assert_eq!(10, max_patterns);
    }

    #[test]
    fn test_multi_pattern_coordinator() {
        let mut coordinator = MultiPatternCoordinator::new();
        coordinator.add_pattern_set(Vec::new());
        assert_eq!(1, coordinator.pattern_sets.len());
    }

    #[test]
    fn test_pattern_equality() {
        let make = || Pattern::new(vec![TermId::new(1)]);
        assert_eq!(make(), make());
    }
}