1use lasso::Spur;
7use oxiz_core::ast::{TermId, TermKind, TermManager};
8use rustc_hash::{FxHashMap, FxHashSet};
9
10use super::QuantifiedFormula;
11
12#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct Pattern {
15 pub terms: Vec<TermId>,
17 pub variables: FxHashSet<Spur>,
19 pub quality: u32,
21 pub pattern_type: PatternType,
23}
24
25impl Pattern {
26 pub fn new(terms: Vec<TermId>) -> Self {
28 Self {
29 terms,
30 variables: FxHashSet::default(),
31 quality: 0,
32 pattern_type: PatternType::MultiPattern,
33 }
34 }
35
36 pub fn extract_variables(&mut self, manager: &TermManager) {
38 self.variables.clear();
39 let terms: Vec<_> = self.terms.to_vec();
41 for term in terms {
42 self.extract_vars_rec(term, manager);
43 }
44 }
45
46 fn extract_vars_rec(&mut self, term: TermId, manager: &TermManager) {
47 let mut visited = FxHashSet::default();
48 self.extract_vars_helper(term, manager, &mut visited);
49 }
50
51 fn extract_vars_helper(
52 &mut self,
53 term: TermId,
54 manager: &TermManager,
55 visited: &mut FxHashSet<TermId>,
56 ) {
57 if visited.contains(&term) {
58 return;
59 }
60 visited.insert(term);
61
62 let Some(t) = manager.get(term) else {
63 return;
64 };
65
66 if let TermKind::Var(name) = t.kind {
67 self.variables.insert(name);
68 return;
69 }
70
71 match &t.kind {
72 TermKind::Apply { args, .. } => {
73 for &arg in args.iter() {
74 self.extract_vars_helper(arg, manager, visited);
75 }
76 }
77 TermKind::Not(arg) | TermKind::Neg(arg) => {
78 self.extract_vars_helper(*arg, manager, visited);
79 }
80 TermKind::And(args) | TermKind::Or(args) => {
81 for &arg in args {
82 self.extract_vars_helper(arg, manager, visited);
83 }
84 }
85 _ => {}
86 }
87 }
88
89 pub fn calculate_quality(&mut self, manager: &TermManager) {
91 let num_funcs = self.count_function_symbols(manager);
97 let num_vars = self.variables.len();
98 let complexity_penalty = self.terms.len();
99
100 self.quality = (num_funcs * 100 + num_vars * 50) as u32 - complexity_penalty as u32;
101 }
102
103 fn count_function_symbols(&self, manager: &TermManager) -> usize {
104 let mut count = 0;
105 let mut visited = FxHashSet::default();
106
107 for &term in &self.terms {
108 count += self.count_funcs_rec(term, manager, &mut visited);
109 }
110
111 count
112 }
113
114 fn count_funcs_rec(
115 &self,
116 term: TermId,
117 manager: &TermManager,
118 visited: &mut FxHashSet<TermId>,
119 ) -> usize {
120 if visited.contains(&term) {
121 return 0;
122 }
123 visited.insert(term);
124
125 let Some(t) = manager.get(term) else {
126 return 0;
127 };
128
129 match &t.kind {
130 TermKind::Apply { args, .. } => {
131 1 + args
132 .iter()
133 .map(|&arg| self.count_funcs_rec(arg, manager, visited))
134 .sum::<usize>()
135 }
136 _ => 0,
137 }
138 }
139}
140
/// Origin / shape classification of a [`Pattern`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PatternType {
    /// A pattern made of a single term.
    SingleTerm,
    /// A pattern over several terms; the initial type assigned by `Pattern::new`.
    MultiPattern,
    /// A pattern taken from the quantifier's explicitly supplied patterns.
    UserSpecified,
    /// A pattern produced by the generator's heuristics.
    AutoGenerated,
}
153
154#[derive(Debug)]
156pub struct PatternGenerator {
157 max_patterns: usize,
159 min_quality: u32,
161 stats: GeneratorStats,
163}
164
165impl PatternGenerator {
166 pub fn new() -> Self {
168 Self {
169 max_patterns: 10,
170 min_quality: 0,
171 stats: GeneratorStats::default(),
172 }
173 }
174
175 pub fn generate(
177 &mut self,
178 quantifier: &QuantifiedFormula,
179 manager: &TermManager,
180 ) -> Vec<Pattern> {
181 self.stats.num_generations += 1;
182
183 if !quantifier.patterns.is_empty() {
185 return self.user_patterns_to_patterns(&quantifier.patterns, manager);
186 }
187
188 let mut patterns = Vec::new();
190
191 patterns.extend(self.generate_function_patterns(quantifier.body, manager));
193
194 patterns.extend(self.generate_equality_patterns(quantifier.body, manager));
196
197 patterns.extend(self.generate_arithmetic_patterns(quantifier.body, manager));
199
200 patterns.retain(|p| p.quality >= self.min_quality);
202
203 patterns.sort_by(|a, b| b.quality.cmp(&a.quality));
205
206 patterns.truncate(self.max_patterns);
208
209 self.stats.num_patterns_generated += patterns.len();
210
211 patterns
212 }
213
214 fn user_patterns_to_patterns(
215 &self,
216 user_patterns: &[Vec<TermId>],
217 manager: &TermManager,
218 ) -> Vec<Pattern> {
219 let mut patterns = Vec::new();
220
221 for pattern_terms in user_patterns {
222 let mut pattern = Pattern::new(pattern_terms.clone());
223 pattern.extract_variables(manager);
224 pattern.calculate_quality(manager);
225 pattern.pattern_type = PatternType::UserSpecified;
226 patterns.push(pattern);
227 }
228
229 patterns
230 }
231
232 fn generate_function_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
233 let mut patterns = Vec::new();
234 let func_apps = self.collect_function_applications(body, manager);
235
236 for func_app in func_apps {
237 let mut pattern = Pattern::new(vec![func_app]);
238 pattern.extract_variables(manager);
239 pattern.calculate_quality(manager);
240 pattern.pattern_type = PatternType::AutoGenerated;
241 patterns.push(pattern);
242 }
243
244 patterns
245 }
246
247 fn generate_equality_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
248 let mut patterns = Vec::new();
249 let equalities = self.collect_equalities(body, manager);
250
251 for eq_term in equalities {
252 let mut pattern = Pattern::new(vec![eq_term]);
253 pattern.extract_variables(manager);
254 pattern.calculate_quality(manager);
255 pattern.pattern_type = PatternType::AutoGenerated;
256 patterns.push(pattern);
257 }
258
259 patterns
260 }
261
262 fn generate_arithmetic_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
263 let mut patterns = Vec::new();
264 let arith_terms = self.collect_arithmetic_terms(body, manager);
265
266 for arith_term in arith_terms {
267 let mut pattern = Pattern::new(vec![arith_term]);
268 pattern.extract_variables(manager);
269 pattern.calculate_quality(manager);
270 pattern.pattern_type = PatternType::AutoGenerated;
271 patterns.push(pattern);
272 }
273
274 patterns
275 }
276
277 fn collect_function_applications(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
278 let mut results = Vec::new();
279 let mut visited = FxHashSet::default();
280 self.collect_funcs_rec(term, &mut results, &mut visited, manager);
281 results
282 }
283
284 fn collect_funcs_rec(
285 &self,
286 term: TermId,
287 results: &mut Vec<TermId>,
288 visited: &mut FxHashSet<TermId>,
289 manager: &TermManager,
290 ) {
291 if visited.contains(&term) {
292 return;
293 }
294 visited.insert(term);
295
296 let Some(t) = manager.get(term) else {
297 return;
298 };
299
300 if let TermKind::Apply { args, .. } = &t.kind {
301 results.push(term);
302 for &arg in args.iter() {
303 self.collect_funcs_rec(arg, results, visited, manager);
304 }
305 }
306
307 match &t.kind {
309 TermKind::Not(arg) | TermKind::Neg(arg) => {
310 self.collect_funcs_rec(*arg, results, visited, manager);
311 }
312 TermKind::And(args) | TermKind::Or(args) => {
313 for &arg in args {
314 self.collect_funcs_rec(arg, results, visited, manager);
315 }
316 }
317 _ => {}
318 }
319 }
320
321 fn collect_equalities(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
322 let mut results = Vec::new();
323 let mut visited = FxHashSet::default();
324 self.collect_eqs_rec(term, &mut results, &mut visited, manager);
325 results
326 }
327
328 fn collect_eqs_rec(
329 &self,
330 term: TermId,
331 results: &mut Vec<TermId>,
332 visited: &mut FxHashSet<TermId>,
333 manager: &TermManager,
334 ) {
335 if visited.contains(&term) {
336 return;
337 }
338 visited.insert(term);
339
340 let Some(t) = manager.get(term) else {
341 return;
342 };
343
344 if matches!(t.kind, TermKind::Eq(_, _)) {
345 results.push(term);
346 }
347
348 match &t.kind {
349 TermKind::Not(arg) | TermKind::Neg(arg) => {
350 self.collect_eqs_rec(*arg, results, visited, manager);
351 }
352 TermKind::And(args) | TermKind::Or(args) => {
353 for &arg in args {
354 self.collect_eqs_rec(arg, results, visited, manager);
355 }
356 }
357 _ => {}
358 }
359 }
360
361 fn collect_arithmetic_terms(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
362 let mut results = Vec::new();
363 let mut visited = FxHashSet::default();
364 self.collect_arith_rec(term, &mut results, &mut visited, manager);
365 results
366 }
367
368 fn collect_arith_rec(
369 &self,
370 term: TermId,
371 results: &mut Vec<TermId>,
372 visited: &mut FxHashSet<TermId>,
373 manager: &TermManager,
374 ) {
375 if visited.contains(&term) {
376 return;
377 }
378 visited.insert(term);
379
380 let Some(t) = manager.get(term) else {
381 return;
382 };
383
384 match &t.kind {
385 TermKind::Lt(_, _) | TermKind::Le(_, _) | TermKind::Gt(_, _) | TermKind::Ge(_, _) => {
386 results.push(term);
387 }
388 TermKind::Not(arg) | TermKind::Neg(arg) => {
389 self.collect_arith_rec(*arg, results, visited, manager);
390 }
391 TermKind::And(args) | TermKind::Or(args) => {
392 for &arg in args {
393 self.collect_arith_rec(arg, results, visited, manager);
394 }
395 }
396 _ => {}
397 }
398 }
399
400 pub fn stats(&self) -> &GeneratorStats {
402 &self.stats
403 }
404}
405
406impl Default for PatternGenerator {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
/// Counters maintained by `PatternGenerator`.
#[derive(Clone, Debug, Default)]
pub struct GeneratorStats {
    /// Number of `generate` calls made.
    pub num_generations: usize,
    /// Total auto-generated patterns returned across all `generate` calls.
    pub num_patterns_generated: usize,
}
420
421#[derive(Debug)]
423pub struct MultiPatternCoordinator {
424 pattern_sets: Vec<PatternSet>,
426 match_cache: FxHashMap<TermId, Vec<PatternMatch>>,
428}
429
430impl MultiPatternCoordinator {
431 pub fn new() -> Self {
433 Self {
434 pattern_sets: Vec::new(),
435 match_cache: FxHashMap::default(),
436 }
437 }
438
439 pub fn add_pattern_set(&mut self, patterns: Vec<Pattern>) {
441 self.pattern_sets.push(PatternSet {
442 patterns,
443 matches: Vec::new(),
444 });
445 }
446
447 pub fn find_matches(&mut self, _manager: &TermManager) -> Vec<MultiMatch> {
449 let mut multi_matches = Vec::new();
450
451 for pattern_set in &self.pattern_sets {
452 let mut set_matches = Vec::new();
454
455 for pattern in &pattern_set.patterns {
456 for &term in &pattern.terms {
457 if let Some(cached) = self.match_cache.get(&term) {
458 set_matches.extend(cached.clone());
459 }
460 }
461 }
462
463 if !set_matches.is_empty() {
465 multi_matches.push(MultiMatch {
466 pattern_set: pattern_set.patterns.clone(),
467 matches: set_matches,
468 });
469 }
470 }
471
472 multi_matches
473 }
474
475 pub fn clear_cache(&mut self) {
477 self.match_cache.clear();
478 }
479}
480
481impl Default for MultiPatternCoordinator {
482 fn default() -> Self {
483 Self::new()
484 }
485}
486
487#[derive(Debug, Clone)]
489struct PatternSet {
490 patterns: Vec<Pattern>,
491 matches: Vec<PatternMatch>,
492}
493
494#[derive(Debug, Clone)]
496pub struct PatternMatch {
497 pub pattern: Pattern,
499 pub matched_term: TermId,
501 pub bindings: FxHashMap<Spur, TermId>,
503}
504
505#[derive(Debug, Clone)]
507pub struct MultiMatch {
508 pub pattern_set: Vec<Pattern>,
510 pub matches: Vec<PatternMatch>,
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pattern holds its terms and has no variables yet.
    #[test]
    fn test_pattern_creation() {
        let single = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(1, single.terms.len());
        assert!(single.variables.is_empty());
    }

    #[test]
    fn test_pattern_type_equality() {
        let single = PatternType::SingleTerm;
        assert_eq!(single, PatternType::SingleTerm);
        assert_ne!(single, PatternType::MultiPattern);
    }

    #[test]
    fn test_pattern_generator_creation() {
        assert_eq!(PatternGenerator::new().max_patterns, 10);
    }

    #[test]
    fn test_multi_pattern_coordinator() {
        let mut coordinator = MultiPatternCoordinator::new();
        coordinator.add_pattern_set(Vec::new());
        assert_eq!(coordinator.pattern_sets.len(), 1);
    }

    /// Two fresh patterns over the same terms compare equal.
    #[test]
    fn test_pattern_equality() {
        let id = TermId::new(1);
        assert_eq!(Pattern::new(vec![id]), Pattern::new(vec![id]));
    }
}