1#[allow(unused_imports)]
7use crate::prelude::*;
8use oxiz_core::ast::{TermId, TermKind, TermManager};
9use oxiz_core::interner::Spur;
10
11use super::QuantifiedFormula;
12
/// A candidate pattern for a quantified formula, made up of one or more terms.
///
/// `variables` and `quality` start out empty / zero in [`Pattern::new`] and are
/// filled in by [`Pattern::extract_variables`] and [`Pattern::calculate_quality`].
/// Equality is derived, so every field (including the derived ones) takes part in `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Pattern {
    /// The terms that make up the pattern.
    pub terms: Vec<TermId>,
    /// Names of the variables found in `terms` by [`Pattern::extract_variables`].
    pub variables: FxHashSet<Spur>,
    /// Heuristic score set by [`Pattern::calculate_quality`]; the generator ranks
    /// auto-generated patterns by this value, highest first.
    pub quality: u32,
    /// How the pattern was obtained (see [`PatternType`]).
    pub pattern_type: PatternType,
}
25
26impl Pattern {
27 pub fn new(terms: Vec<TermId>) -> Self {
29 Self {
30 terms,
31 variables: FxHashSet::default(),
32 quality: 0,
33 pattern_type: PatternType::MultiPattern,
34 }
35 }
36
37 pub fn extract_variables(&mut self, manager: &TermManager) {
39 self.variables.clear();
40 let terms: Vec<_> = self.terms.to_vec();
42 for term in terms {
43 self.extract_vars_rec(term, manager);
44 }
45 }
46
47 fn extract_vars_rec(&mut self, term: TermId, manager: &TermManager) {
48 let mut visited = FxHashSet::default();
49 self.extract_vars_helper(term, manager, &mut visited);
50 }
51
52 fn extract_vars_helper(
53 &mut self,
54 term: TermId,
55 manager: &TermManager,
56 visited: &mut FxHashSet<TermId>,
57 ) {
58 if visited.contains(&term) {
59 return;
60 }
61 visited.insert(term);
62
63 let Some(t) = manager.get(term) else {
64 return;
65 };
66
67 if let TermKind::Var(name) = t.kind {
68 self.variables.insert(name);
69 return;
70 }
71
72 match &t.kind {
73 TermKind::Apply { args, .. } => {
74 for &arg in args.iter() {
75 self.extract_vars_helper(arg, manager, visited);
76 }
77 }
78 TermKind::Not(arg) | TermKind::Neg(arg) => {
79 self.extract_vars_helper(*arg, manager, visited);
80 }
81 TermKind::And(args) | TermKind::Or(args) => {
82 for &arg in args {
83 self.extract_vars_helper(arg, manager, visited);
84 }
85 }
86 _ => {}
87 }
88 }
89
90 pub fn calculate_quality(&mut self, manager: &TermManager) {
92 let num_funcs = self.count_function_symbols(manager);
98 let num_vars = self.variables.len();
99 let complexity_penalty = self.terms.len();
100
101 self.quality = (num_funcs * 100 + num_vars * 50) as u32 - complexity_penalty as u32;
102 }
103
104 fn count_function_symbols(&self, manager: &TermManager) -> usize {
105 let mut count = 0;
106 let mut visited = FxHashSet::default();
107
108 for &term in &self.terms {
109 count += self.count_funcs_rec(term, manager, &mut visited);
110 }
111
112 count
113 }
114
115 fn count_funcs_rec(
116 &self,
117 term: TermId,
118 manager: &TermManager,
119 visited: &mut FxHashSet<TermId>,
120 ) -> usize {
121 if visited.contains(&term) {
122 return 0;
123 }
124 visited.insert(term);
125
126 let Some(t) = manager.get(term) else {
127 return 0;
128 };
129
130 match &t.kind {
131 TermKind::Apply { args, .. } => {
132 1 + args
133 .iter()
134 .map(|&arg| self.count_funcs_rec(arg, manager, visited))
135 .sum::<usize>()
136 }
137 _ => 0,
138 }
139 }
140}
141
/// How a [`Pattern`] was obtained.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatternType {
    /// Presumably a pattern made of a single term. NOTE(review): no code in this
    /// file assigns it; single-term generated patterns are tagged `AutoGenerated`.
    SingleTerm,
    /// The initial type assigned by [`Pattern::new`].
    MultiPattern,
    /// Taken from the quantifier's explicit patterns by [`PatternGenerator`].
    UserSpecified,
    /// Derived from the quantifier body by [`PatternGenerator`].
    AutoGenerated,
}
154
/// Produces candidate [`Pattern`]s for quantified formulas.
#[derive(Debug)]
pub struct PatternGenerator {
    /// Maximum number of auto-generated patterns returned per `generate` call.
    max_patterns: usize,
    /// Auto-generated patterns whose quality is below this value are dropped.
    min_quality: u32,
    /// Running counters updated by `generate`.
    stats: GeneratorStats,
}
165
166impl PatternGenerator {
167 pub fn new() -> Self {
169 Self {
170 max_patterns: 10,
171 min_quality: 0,
172 stats: GeneratorStats::default(),
173 }
174 }
175
176 pub fn generate(
178 &mut self,
179 quantifier: &QuantifiedFormula,
180 manager: &TermManager,
181 ) -> Vec<Pattern> {
182 self.stats.num_generations += 1;
183
184 if !quantifier.patterns.is_empty() {
186 return self.user_patterns_to_patterns(&quantifier.patterns, manager);
187 }
188
189 let mut patterns = Vec::new();
191
192 patterns.extend(self.generate_function_patterns(quantifier.body, manager));
194
195 patterns.extend(self.generate_equality_patterns(quantifier.body, manager));
197
198 patterns.extend(self.generate_arithmetic_patterns(quantifier.body, manager));
200
201 patterns.retain(|p| p.quality >= self.min_quality);
203
204 patterns.sort_by_key(|p| std::cmp::Reverse(p.quality));
206
207 patterns.truncate(self.max_patterns);
209
210 self.stats.num_patterns_generated += patterns.len();
211
212 patterns
213 }
214
215 fn user_patterns_to_patterns(
216 &self,
217 user_patterns: &[Vec<TermId>],
218 manager: &TermManager,
219 ) -> Vec<Pattern> {
220 let mut patterns = Vec::new();
221
222 for pattern_terms in user_patterns {
223 let mut pattern = Pattern::new(pattern_terms.clone());
224 pattern.extract_variables(manager);
225 pattern.calculate_quality(manager);
226 pattern.pattern_type = PatternType::UserSpecified;
227 patterns.push(pattern);
228 }
229
230 patterns
231 }
232
233 fn generate_function_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
234 let mut patterns = Vec::new();
235 let func_apps = self.collect_function_applications(body, manager);
236
237 for func_app in func_apps {
238 let mut pattern = Pattern::new(vec![func_app]);
239 pattern.extract_variables(manager);
240 pattern.calculate_quality(manager);
241 pattern.pattern_type = PatternType::AutoGenerated;
242 patterns.push(pattern);
243 }
244
245 patterns
246 }
247
248 fn generate_equality_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
249 let mut patterns = Vec::new();
250 let equalities = self.collect_equalities(body, manager);
251
252 for eq_term in equalities {
253 let mut pattern = Pattern::new(vec![eq_term]);
254 pattern.extract_variables(manager);
255 pattern.calculate_quality(manager);
256 pattern.pattern_type = PatternType::AutoGenerated;
257 patterns.push(pattern);
258 }
259
260 patterns
261 }
262
263 fn generate_arithmetic_patterns(&self, body: TermId, manager: &TermManager) -> Vec<Pattern> {
264 let mut patterns = Vec::new();
265 let arith_terms = self.collect_arithmetic_terms(body, manager);
266
267 for arith_term in arith_terms {
268 let mut pattern = Pattern::new(vec![arith_term]);
269 pattern.extract_variables(manager);
270 pattern.calculate_quality(manager);
271 pattern.pattern_type = PatternType::AutoGenerated;
272 patterns.push(pattern);
273 }
274
275 patterns
276 }
277
278 fn collect_function_applications(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
279 let mut results = Vec::new();
280 let mut visited = FxHashSet::default();
281 self.collect_funcs_rec(term, &mut results, &mut visited, manager);
282 results
283 }
284
285 fn collect_funcs_rec(
286 &self,
287 term: TermId,
288 results: &mut Vec<TermId>,
289 visited: &mut FxHashSet<TermId>,
290 manager: &TermManager,
291 ) {
292 if visited.contains(&term) {
293 return;
294 }
295 visited.insert(term);
296
297 let Some(t) = manager.get(term) else {
298 return;
299 };
300
301 if let TermKind::Apply { args, .. } = &t.kind {
302 results.push(term);
303 for &arg in args.iter() {
304 self.collect_funcs_rec(arg, results, visited, manager);
305 }
306 }
307
308 match &t.kind {
310 TermKind::Not(arg) | TermKind::Neg(arg) => {
311 self.collect_funcs_rec(*arg, results, visited, manager);
312 }
313 TermKind::And(args) | TermKind::Or(args) => {
314 for &arg in args {
315 self.collect_funcs_rec(arg, results, visited, manager);
316 }
317 }
318 _ => {}
319 }
320 }
321
322 fn collect_equalities(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
323 let mut results = Vec::new();
324 let mut visited = FxHashSet::default();
325 self.collect_eqs_rec(term, &mut results, &mut visited, manager);
326 results
327 }
328
329 fn collect_eqs_rec(
330 &self,
331 term: TermId,
332 results: &mut Vec<TermId>,
333 visited: &mut FxHashSet<TermId>,
334 manager: &TermManager,
335 ) {
336 if visited.contains(&term) {
337 return;
338 }
339 visited.insert(term);
340
341 let Some(t) = manager.get(term) else {
342 return;
343 };
344
345 if matches!(t.kind, TermKind::Eq(_, _)) {
346 results.push(term);
347 }
348
349 match &t.kind {
350 TermKind::Not(arg) | TermKind::Neg(arg) => {
351 self.collect_eqs_rec(*arg, results, visited, manager);
352 }
353 TermKind::And(args) | TermKind::Or(args) => {
354 for &arg in args {
355 self.collect_eqs_rec(arg, results, visited, manager);
356 }
357 }
358 _ => {}
359 }
360 }
361
362 fn collect_arithmetic_terms(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
363 let mut results = Vec::new();
364 let mut visited = FxHashSet::default();
365 self.collect_arith_rec(term, &mut results, &mut visited, manager);
366 results
367 }
368
369 fn collect_arith_rec(
370 &self,
371 term: TermId,
372 results: &mut Vec<TermId>,
373 visited: &mut FxHashSet<TermId>,
374 manager: &TermManager,
375 ) {
376 if visited.contains(&term) {
377 return;
378 }
379 visited.insert(term);
380
381 let Some(t) = manager.get(term) else {
382 return;
383 };
384
385 match &t.kind {
386 TermKind::Lt(_, _) | TermKind::Le(_, _) | TermKind::Gt(_, _) | TermKind::Ge(_, _) => {
387 results.push(term);
388 }
389 TermKind::Not(arg) | TermKind::Neg(arg) => {
390 self.collect_arith_rec(*arg, results, visited, manager);
391 }
392 TermKind::And(args) | TermKind::Or(args) => {
393 for &arg in args {
394 self.collect_arith_rec(arg, results, visited, manager);
395 }
396 }
397 _ => {}
398 }
399 }
400
401 pub fn stats(&self) -> &GeneratorStats {
403 &self.stats
404 }
405}
406
407impl Default for PatternGenerator {
408 fn default() -> Self {
409 Self::new()
410 }
411}
412
/// Counters maintained by [`PatternGenerator`].
#[derive(Debug, Clone, Default)]
pub struct GeneratorStats {
    /// Number of calls to `PatternGenerator::generate`, including calls that
    /// returned user-specified patterns.
    pub num_generations: usize,
    /// Total number of auto-generated patterns returned (after quality filtering
    /// and truncation); user-specified patterns are not counted here.
    pub num_patterns_generated: usize,
}
421
/// Groups patterns into sets and gathers cached matches for each set.
#[derive(Debug)]
pub struct MultiPatternCoordinator {
    /// Registered pattern sets, in the order they were added.
    pattern_sets: Vec<PatternSet>,
    /// Matches cached per term; read by `find_matches` and emptied by
    /// `clear_cache`. NOTE(review): nothing in this file inserts entries —
    /// confirm it is populated elsewhere.
    match_cache: FxHashMap<TermId, Vec<PatternMatch>>,
}
430
431impl MultiPatternCoordinator {
432 pub fn new() -> Self {
434 Self {
435 pattern_sets: Vec::new(),
436 match_cache: FxHashMap::default(),
437 }
438 }
439
440 pub fn add_pattern_set(&mut self, patterns: Vec<Pattern>) {
442 self.pattern_sets.push(PatternSet {
443 patterns,
444 matches: Vec::new(),
445 });
446 }
447
448 pub fn find_matches(&mut self, _manager: &TermManager) -> Vec<MultiMatch> {
450 let mut multi_matches = Vec::new();
451
452 for pattern_set in &self.pattern_sets {
453 let mut set_matches = Vec::new();
455
456 for pattern in &pattern_set.patterns {
457 for &term in &pattern.terms {
458 if let Some(cached) = self.match_cache.get(&term) {
459 set_matches.extend(cached.clone());
460 }
461 }
462 }
463
464 if !set_matches.is_empty() {
466 multi_matches.push(MultiMatch {
467 pattern_set: pattern_set.patterns.clone(),
468 matches: set_matches,
469 });
470 }
471 }
472
473 multi_matches
474 }
475
476 pub fn clear_cache(&mut self) {
478 self.match_cache.clear();
479 }
480}
481
482impl Default for MultiPatternCoordinator {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
/// A group of patterns registered with [`MultiPatternCoordinator`].
#[derive(Debug, Clone)]
struct PatternSet {
    /// The patterns in this set.
    patterns: Vec<Pattern>,
    /// Matches for this set. NOTE(review): initialised empty by
    /// `add_pattern_set` and never read or updated in this file.
    matches: Vec<PatternMatch>,
}
494
/// A match of a [`Pattern`] against a term.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// The pattern that matched.
    pub pattern: Pattern,
    /// The term the pattern matched.
    pub matched_term: TermId,
    /// Presumably maps each pattern variable name to the term it is bound to in
    /// this match — confirm against the code that builds matches.
    pub bindings: FxHashMap<Spur, TermId>,
}
505
/// The matches gathered for one pattern set by
/// [`MultiPatternCoordinator::find_matches`].
#[derive(Debug, Clone)]
pub struct MultiMatch {
    /// A copy of the patterns in the set these matches belong to.
    pub pattern_set: Vec<Pattern>,
    /// All cached matches collected for the set's pattern terms.
    pub matches: Vec<PatternMatch>,
}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built pattern holds exactly the given terms and no variables.
    #[test]
    fn test_pattern_creation() {
        let fresh = Pattern::new(vec![TermId::new(1)]);
        assert_eq!(1, fresh.terms.len());
        assert!(fresh.variables.is_empty());
    }

    /// `PatternType` equality distinguishes variants.
    #[test]
    fn test_pattern_type_equality() {
        let single = PatternType::SingleTerm;
        assert_eq!(single, PatternType::SingleTerm);
        assert_ne!(single, PatternType::MultiPattern);
    }

    /// A new generator caps its output at ten patterns.
    #[test]
    fn test_pattern_generator_creation() {
        assert_eq!(PatternGenerator::new().max_patterns, 10);
    }

    /// Adding an empty pattern set registers exactly one set.
    #[test]
    fn test_multi_pattern_coordinator() {
        let mut coordinator = MultiPatternCoordinator::new();
        coordinator.add_pattern_set(Vec::new());
        assert_eq!(1, coordinator.pattern_sets.len());
    }

    /// Patterns built from identical terms compare equal.
    #[test]
    fn test_pattern_equality() {
        let make = || Pattern::new(vec![TermId::new(1)]);
        assert_eq!(make(), make());
    }
}