1#[allow(unused_imports)]
7use crate::prelude::*;
8use oxiz_core::ast::{TermId, TermKind, TermManager};
9use oxiz_core::interner::Spur;
10
11use super::{QuantifiedFormula, QuantifierConfig};
12
/// Strategy used to order candidate trigger patterns.
///
/// NOTE: `PatternGenerator::generate` currently orders patterns by quality
/// for both strategies.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PatternStrategy {
    /// Order candidates by their static quality score, highest first.
    StaticDepth,
    /// Greedy set-cover ordering over term shapes (cf. `PatternCoverScorer`).
    GreedyCover,
}
21
/// Coarse structural classification of a term by its outermost constructor.
///
/// Used to measure how much of the ground-term "shape space" a set of
/// patterns covers.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TermShape {
    /// `true` or `false`.
    BoolConst,
    /// Integer literal.
    IntConst,
    /// Real literal.
    RealConst,
    /// Variable.
    Var,
    /// Function application with `arity` arguments.
    Apply { arity: usize },
    /// Equality.
    Eq,
    /// `<` or `>`.
    StrictIneq,
    /// `<=` or `>=`.
    NonStrictIneq,
    /// Addition with `arity` operands.
    Add { arity: usize },
    /// Multiplication with `arity` operands.
    Mul { arity: usize },
    /// Any other term kind, or a term id unknown to the manager.
    Other,
}
48
49impl TermShape {
50 fn from_term(term: TermId, manager: &TermManager) -> Self {
51 let Some(node) = manager.get(term) else {
52 return Self::Other;
53 };
54
55 match &node.kind {
56 TermKind::True | TermKind::False => Self::BoolConst,
57 TermKind::IntConst(_) => Self::IntConst,
58 TermKind::RealConst(_) => Self::RealConst,
59 TermKind::Var(_) => Self::Var,
60 TermKind::Apply { args, .. } => Self::Apply { arity: args.len() },
61 TermKind::Eq(_, _) => Self::Eq,
62 TermKind::Lt(_, _) | TermKind::Gt(_, _) => Self::StrictIneq,
63 TermKind::Le(_, _) | TermKind::Ge(_, _) => Self::NonStrictIneq,
64 TermKind::Add(args) => Self::Add { arity: args.len() },
65 TermKind::Mul(args) => Self::Mul { arity: args.len() },
66 _ => Self::Other,
67 }
68 }
69}
70
71#[derive(Debug, Default, Clone)]
73pub struct PatternCoverScorer;
74
75impl PatternCoverScorer {
76 pub fn score_cover(
78 &self,
79 candidate_patterns: &[PatternSet],
80 egraph_ground_terms: &[TermShape],
81 ) -> Vec<(usize, f64)> {
82 if candidate_patterns.is_empty() {
83 return Vec::new();
84 }
85
86 let total_shapes = egraph_ground_terms
87 .iter()
88 .cloned()
89 .collect::<FxHashSet<_>>();
90 if total_shapes.is_empty() {
91 return candidate_patterns
92 .iter()
93 .enumerate()
94 .map(|(idx, _)| (idx, 0.0))
95 .collect();
96 }
97
98 let mut remaining = total_shapes;
99 let mut pending: Vec<usize> = (0..candidate_patterns.len()).collect();
100 let mut ranked = Vec::with_capacity(candidate_patterns.len());
101
102 while !pending.is_empty() {
103 let mut best_pos = 0usize;
104 let mut best_gain = 0usize;
105 let mut best_score = 0.0f64;
106
107 for (pos, &idx) in pending.iter().enumerate() {
108 let covered = candidate_patterns[idx]
109 .covered_shapes
110 .iter()
111 .filter(|shape| remaining.contains(*shape))
112 .count();
113 let score = covered as f64 / egraph_ground_terms.len() as f64;
114 if covered > best_gain || (covered == best_gain && score > best_score) {
115 best_pos = pos;
116 best_gain = covered;
117 best_score = score;
118 }
119 }
120
121 let chosen_idx = pending.remove(best_pos);
122 for shape in &candidate_patterns[chosen_idx].covered_shapes {
123 remaining.remove(shape);
124 }
125 ranked.push((chosen_idx, best_score));
126 }
127
128 ranked
129 }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct Pattern {
135 pub terms: Vec<TermId>,
137 pub variables: FxHashSet<Spur>,
139 pub quality: u32,
141 pub pattern_type: PatternType,
143}
144
145impl Pattern {
146 pub fn new(terms: Vec<TermId>) -> Self {
148 Self {
149 terms,
150 variables: FxHashSet::default(),
151 quality: 0,
152 pattern_type: PatternType::MultiPattern,
153 }
154 }
155
156 pub fn extract_variables(&mut self, manager: &TermManager) {
158 self.variables.clear();
159 let terms: Vec<_> = self.terms.to_vec();
161 for term in terms {
162 self.extract_vars_rec(term, manager);
163 }
164 }
165
166 fn extract_vars_rec(&mut self, term: TermId, manager: &TermManager) {
167 let mut visited = FxHashSet::default();
168 self.extract_vars_helper(term, manager, &mut visited);
169 }
170
171 fn extract_vars_helper(
172 &mut self,
173 term: TermId,
174 manager: &TermManager,
175 visited: &mut FxHashSet<TermId>,
176 ) {
177 if visited.contains(&term) {
178 return;
179 }
180 visited.insert(term);
181
182 let Some(t) = manager.get(term) else {
183 return;
184 };
185
186 if let TermKind::Var(name) = t.kind {
187 self.variables.insert(name);
188 return;
189 }
190
191 match &t.kind {
192 TermKind::Apply { args, .. } => {
193 for &arg in args.iter() {
194 self.extract_vars_helper(arg, manager, visited);
195 }
196 }
197 TermKind::Not(arg) | TermKind::Neg(arg) => {
198 self.extract_vars_helper(*arg, manager, visited);
199 }
200 TermKind::And(args) | TermKind::Or(args) => {
201 for &arg in args {
202 self.extract_vars_helper(arg, manager, visited);
203 }
204 }
205 _ => {}
206 }
207 }
208
209 pub fn calculate_quality(&mut self, manager: &TermManager) {
211 let num_funcs = self.count_function_symbols(manager);
217 let num_vars = self.variables.len();
218 let complexity_penalty = self.terms.len();
219
220 self.quality = (num_funcs * 100 + num_vars * 50) as u32 - complexity_penalty as u32;
221 }
222
223 fn count_function_symbols(&self, manager: &TermManager) -> usize {
224 let mut count = 0;
225 let mut visited = FxHashSet::default();
226
227 for &term in &self.terms {
228 count += self.count_funcs_rec(term, manager, &mut visited);
229 }
230
231 count
232 }
233
234 fn count_funcs_rec(
235 &self,
236 term: TermId,
237 manager: &TermManager,
238 visited: &mut FxHashSet<TermId>,
239 ) -> usize {
240 if visited.contains(&term) {
241 return 0;
242 }
243 visited.insert(term);
244
245 let Some(t) = manager.get(term) else {
246 return 0;
247 };
248
249 match &t.kind {
250 TermKind::Apply { args, .. } => {
251 1 + args
252 .iter()
253 .map(|&arg| self.count_funcs_rec(arg, manager, visited))
254 .sum::<usize>()
255 }
256 _ => 0,
257 }
258 }
259}
260
/// How a [`Pattern`] was produced.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PatternType {
    /// A single-term pattern. NOTE(review): nothing visible in this file
    /// assigns this variant.
    SingleTerm,
    /// A pattern of one or more terms; the initial type set by `Pattern::new`.
    MultiPattern,
    /// Built from the quantifier's user-supplied patterns.
    UserSpecified,
    /// Produced automatically from the quantifier body.
    AutoGenerated,
}
273
274#[derive(Debug)]
276pub struct PatternGenerator {
277 max_patterns: usize,
279 min_quality: u32,
281 stats: GeneratorStats,
283 strategy: PatternStrategy,
285}
286
287impl PatternGenerator {
288 pub fn new() -> Self {
290 let config = QuantifierConfig::default();
291 Self {
292 max_patterns: 10,
293 min_quality: 0,
294 stats: GeneratorStats::default(),
295 strategy: config.pattern_strategy,
296 }
297 }
298
299 pub fn generate(
301 &mut self,
302 quantifier: &QuantifiedFormula,
303 manager: &TermManager,
304 ) -> Vec<Pattern> {
305 self.stats.num_generations += 1;
306
307 if !quantifier.patterns.is_empty() {
309 return self.user_patterns_to_patterns(&quantifier.patterns, manager);
310 }
311
312 let mut patterns = Vec::new();
314
315 patterns.extend(self.generate_function_patterns(quantifier.body, manager));
317
318 patterns.extend(self.generate_equality_patterns(quantifier.body, manager));
320
321 patterns.extend(self.generate_arithmetic_patterns(quantifier.body, manager));
323
324 patterns.retain(|p| p.quality >= self.min_quality);
326
327 match self.strategy {
328 PatternStrategy::StaticDepth => {
329 patterns.sort_by_key(|p| std::cmp::Reverse(p.quality));
330 }
331 PatternStrategy::GreedyCover => {
332 patterns.sort_by_key(|p| std::cmp::Reverse(p.quality));
333 }
334 }
335
336 patterns.truncate(self.max_patterns);
338
339 self.stats.num_patterns_generated += patterns.len();
340
341 patterns
342 }
343
344 fn user_patterns_to_patterns(
345 &self,
346 user_patterns: &[Vec<TermId>],
347 manager: &TermManager,
348 ) -> Vec<Pattern> {
349 let mut patterns = Vec::new();
350
351 for pattern_terms in user_patterns {
352 let mut pattern = Pattern::new(pattern_terms.clone());
353 pattern.extract_variables(manager);
354 pattern.calculate_quality(manager);
355 pattern.pattern_type = PatternType::UserSpecified;
356 patterns.push(pattern);
357 }
358
359 patterns
360 }
361
362 fn generate_function_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
363 let mut patterns = Vec::new();
364 let func_apps = self.collect_function_applications(body, manager);
365
366 for func_app in func_apps {
367 let mut pattern = Pattern::new(vec![func_app]);
368 pattern.extract_variables(manager);
369 pattern.calculate_quality(manager);
370 pattern.pattern_type = PatternType::AutoGenerated;
371 patterns.push(pattern);
372 }
373
374 patterns
375 }
376
377 fn generate_equality_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
378 let mut patterns = Vec::new();
379 let equalities = self.collect_equalities(body, manager);
380
381 for eq_term in equalities {
382 let mut pattern = Pattern::new(vec![eq_term]);
383 pattern.extract_variables(manager);
384 pattern.calculate_quality(manager);
385 pattern.pattern_type = PatternType::AutoGenerated;
386 patterns.push(pattern);
387 }
388
389 patterns
390 }
391
392 fn generate_arithmetic_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
393 let mut patterns = Vec::new();
394 let arith_terms = self.collect_arithmetic_terms(body, manager);
395
396 for arith_term in arith_terms {
397 let mut pattern = Pattern::new(vec![arith_term]);
398 pattern.extract_variables(manager);
399 pattern.calculate_quality(manager);
400 pattern.pattern_type = PatternType::AutoGenerated;
401 patterns.push(pattern);
402 }
403
404 patterns
405 }
406
407 fn collect_function_applications(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
408 let mut results = Vec::new();
409 let mut visited = FxHashSet::default();
410 self.collect_funcs_rec(term, &mut results, &mut visited, manager);
411 results
412 }
413
414 fn collect_funcs_rec(
415 &self,
416 term: TermId,
417 results: &mut Vec<TermId>,
418 visited: &mut FxHashSet<TermId>,
419 manager: &TermManager,
420 ) {
421 if visited.contains(&term) {
422 return;
423 }
424 visited.insert(term);
425
426 let Some(t) = manager.get(term) else {
427 return;
428 };
429
430 if let TermKind::Apply { args, .. } = &t.kind {
431 results.push(term);
432 for &arg in args.iter() {
433 self.collect_funcs_rec(arg, results, visited, manager);
434 }
435 }
436
437 match &t.kind {
439 TermKind::Not(arg) | TermKind::Neg(arg) => {
440 self.collect_funcs_rec(*arg, results, visited, manager);
441 }
442 TermKind::And(args) | TermKind::Or(args) => {
443 for &arg in args {
444 self.collect_funcs_rec(arg, results, visited, manager);
445 }
446 }
447 _ => {}
448 }
449 }
450
451 fn collect_equalities(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
452 let mut results = Vec::new();
453 let mut visited = FxHashSet::default();
454 self.collect_eqs_rec(term, &mut results, &mut visited, manager);
455 results
456 }
457
458 fn collect_eqs_rec(
459 &self,
460 term: TermId,
461 results: &mut Vec<TermId>,
462 visited: &mut FxHashSet<TermId>,
463 manager: &TermManager,
464 ) {
465 if visited.contains(&term) {
466 return;
467 }
468 visited.insert(term);
469
470 let Some(t) = manager.get(term) else {
471 return;
472 };
473
474 if matches!(t.kind, TermKind::Eq(_, _)) {
475 results.push(term);
476 }
477
478 match &t.kind {
479 TermKind::Not(arg) | TermKind::Neg(arg) => {
480 self.collect_eqs_rec(*arg, results, visited, manager);
481 }
482 TermKind::And(args) | TermKind::Or(args) => {
483 for &arg in args {
484 self.collect_eqs_rec(arg, results, visited, manager);
485 }
486 }
487 _ => {}
488 }
489 }
490
491 fn collect_arithmetic_terms(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
492 let mut results = Vec::new();
493 let mut visited = FxHashSet::default();
494 self.collect_arith_rec(term, &mut results, &mut visited, manager);
495 results
496 }
497
498 fn collect_arith_rec(
499 &self,
500 term: TermId,
501 results: &mut Vec<TermId>,
502 visited: &mut FxHashSet<TermId>,
503 manager: &TermManager,
504 ) {
505 if visited.contains(&term) {
506 return;
507 }
508 visited.insert(term);
509
510 let Some(t) = manager.get(term) else {
511 return;
512 };
513
514 match &t.kind {
515 TermKind::Lt(_, _) | TermKind::Le(_, _) | TermKind::Gt(_, _) | TermKind::Ge(_, _) => {
516 results.push(term);
517 }
518 TermKind::Not(arg) | TermKind::Neg(arg) => {
519 self.collect_arith_rec(*arg, results, visited, manager);
520 }
521 TermKind::And(args) | TermKind::Or(args) => {
522 for &arg in args {
523 self.collect_arith_rec(arg, results, visited, manager);
524 }
525 }
526 _ => {}
527 }
528 }
529
530 pub fn stats(&self) -> &GeneratorStats {
532 &self.stats
533 }
534}
535
536impl Default for PatternGenerator {
537 fn default() -> Self {
538 Self::new()
539 }
540}
541
/// Counters maintained by `PatternGenerator`.
#[derive(Clone, Debug, Default)]
pub struct GeneratorStats {
    /// Number of calls to `PatternGenerator::generate`.
    pub num_generations: usize,
    /// Total automatically generated patterns returned, counted after
    /// filtering and truncation.
    pub num_patterns_generated: usize,
}
550
551#[derive(Debug)]
553pub struct MultiPatternCoordinator {
554 pattern_sets: Vec<PatternSet>,
556 match_cache: FxHashMap<TermId, Vec<PatternMatch>>,
558}
559
560impl MultiPatternCoordinator {
561 pub fn new() -> Self {
563 Self {
564 pattern_sets: Vec::new(),
565 match_cache: FxHashMap::default(),
566 }
567 }
568
569 pub fn add_pattern_set(&mut self, patterns: Vec<Pattern>, manager: &TermManager) {
571 self.pattern_sets
572 .push(PatternSet::from_patterns(patterns, manager));
573 }
574
575 pub fn find_matches(&mut self, _manager: &TermManager) -> Vec<MultiMatch> {
577 let mut multi_matches = Vec::new();
578
579 for pattern_set in &self.pattern_sets {
580 let mut set_matches = Vec::new();
582
583 for pattern in &pattern_set.patterns {
584 for &term in &pattern.terms {
585 if let Some(cached) = self.match_cache.get(&term) {
586 set_matches.extend(cached.clone());
587 }
588 }
589 }
590
591 if !set_matches.is_empty() {
593 multi_matches.push(MultiMatch {
594 pattern_set: pattern_set.patterns.clone(),
595 matches: set_matches,
596 });
597 }
598 }
599
600 multi_matches
601 }
602
603 pub fn clear_cache(&mut self) {
605 self.match_cache.clear();
606 }
607}
608
609impl Default for MultiPatternCoordinator {
610 fn default() -> Self {
611 Self::new()
612 }
613}
614
615#[derive(Debug, Clone)]
617pub struct PatternSet {
618 pub patterns: Vec<Pattern>,
619 pub matches: Vec<PatternMatch>,
620 pub covered_shapes: FxHashSet<TermShape>,
621}
622
623impl PatternSet {
624 pub fn from_patterns(patterns: Vec<Pattern>, manager: &TermManager) -> Self {
626 let mut covered_shapes = FxHashSet::default();
627 for pattern in &patterns {
628 for &term in &pattern.terms {
629 covered_shapes.insert(TermShape::from_term(term, manager));
630 }
631 }
632 Self {
633 patterns,
634 matches: Vec::new(),
635 covered_shapes,
636 }
637 }
638}
639
640#[derive(Debug, Clone)]
642pub struct PatternMatch {
643 pub pattern: Pattern,
645 pub matched_term: TermId,
647 pub bindings: FxHashMap<Spur, TermId>,
649}
650
651#[derive(Debug, Clone)]
653pub struct MultiMatch {
654 pub pattern_set: Vec<Pattern>,
656 pub matches: Vec<PatternMatch>,
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pattern keeps its terms and starts with no variables.
    #[test]
    fn test_pattern_creation() {
        let created = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(created.terms.len(), 1);
        assert!(created.variables.is_empty());
    }

    #[test]
    fn test_pattern_type_equality() {
        let single = PatternType::SingleTerm;
        assert_eq!(single, PatternType::SingleTerm);
        assert_ne!(single, PatternType::MultiPattern);
    }

    #[test]
    fn test_pattern_generator_creation() {
        assert_eq!(PatternGenerator::new().max_patterns, 10);
    }

    /// Adding even an empty pattern set registers it.
    #[test]
    fn test_multi_pattern_coordinator() {
        let manager = TermManager::new();
        let mut coordinator = MultiPatternCoordinator::new();
        coordinator.add_pattern_set(Vec::new(), &manager);
        assert_eq!(coordinator.pattern_sets.len(), 1);
    }

    #[test]
    fn test_pattern_equality() {
        let first = Pattern::new(vec![TermId::new(1)]);
        let second = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(first, second);
    }
}