1use lasso::Spur;
38use oxiz_core::ast::{TermId, TermKind, TermManager};
39use oxiz_core::sort::SortId;
40use rustc_hash::{FxHashMap, FxHashSet};
41
/// Identifier of a pattern variable. Used by the `Bind`/`CheckEq`
/// instructions and as the key of every match substitution.
pub type PatternVar = u32;
44
/// One instruction of a compiled pattern program.
///
/// Programs are interpreted by `CodeTree::execute_pattern` against a single
/// ground term, walking it with a "current term" cursor. For the check
/// instructions, a jump to a `failure_pc` at or beyond the end of the program
/// aborts the match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodeTreeInstr {
    /// Succeeds when the current term is a function application of `symbol`
    /// with exactly `arity` arguments; otherwise jumps to `failure_pc`.
    CheckSymbol {
        /// Expected head symbol of the application.
        symbol: Spur,
        /// Expected number of arguments.
        arity: usize,
        /// Instruction index to continue at when the check fails.
        failure_pc: usize,
    },

    /// Succeeds when the current term is itself a variable term
    /// (`TermKind::Var`); otherwise jumps to `failure_pc`.
    CheckVar {
        /// Instruction index to continue at when the check fails.
        failure_pc: usize,
    },

    /// Succeeds when the current term is exactly the term `value` (compared
    /// by `TermId`); otherwise jumps to `failure_pc`.
    CheckConstant {
        /// Term the current term must be identical to.
        value: TermId,
        /// Instruction index to continue at when the check fails.
        failure_pc: usize,
    },

    /// Records the current term as the binding of `var` in the substitution.
    Bind {
        /// Pattern variable being bound.
        var: PatternVar,
    },

    /// Succeeds when `var` is already bound to the current term; otherwise
    /// jumps to `failure_pc`. Used for repeated occurrences of a variable.
    CheckEq {
        /// Pattern variable whose binding is compared.
        var: PatternVar,
        /// Instruction index to continue at when the check fails.
        failure_pc: usize,
    },

    /// Saves the current term and descends into its argument at `index`.
    /// The match is aborted if the current term has no such argument.
    MoveToChild {
        /// Zero-based argument position.
        index: usize,
    },

    /// Returns to the term saved by the matching `MoveToChild`.
    MoveToParent,

    /// Ends the program successfully, reporting a match for `quantifier`'s
    /// pattern `pattern_idx` together with the current substitution.
    Yield {
        /// Quantifier the pattern belongs to.
        quantifier: TermId,
        /// Index of the pattern within the quantifier's pattern list.
        pattern_idx: usize,
    },

    /// Ends the program without a match.
    Halt,
}
113
/// A pattern term compiled into a linear [`CodeTreeInstr`] program.
#[derive(Debug, Clone)]
pub struct CompiledPattern {
    /// Source pattern term(s); `CodeTree::compile_pattern` stores exactly one.
    pub pattern: Vec<TermId>,
    /// The compiled program; programs built by `CodeTree` end with `Yield`
    /// followed by `Halt`.
    pub instructions: Vec<CodeTreeInstr>,
    /// Sort of each pattern variable bound by the program, keyed by variable.
    pub variables: FxHashMap<PatternVar, SortId>,
    /// Quantifier this pattern was taken from.
    pub quantifier: TermId,
    /// Index of the pattern within the quantifier's pattern list.
    pub pattern_index: usize,
}
128
/// Mutable interpreter state while running one compiled pattern against one
/// ground term.
#[derive(Debug, Clone)]
struct MatchContext {
    /// Subterm the cursor currently points at.
    current_term: TermId,
    /// Index of the next instruction to execute.
    pc: usize,
    /// Pattern-variable bindings collected so far.
    substitution: FxHashMap<PatternVar, TermId>,
    /// Saved ancestor terms: `MoveToChild` pushes, `MoveToParent` pops.
    term_stack: Vec<TermId>,
}
141
142impl MatchContext {
143 fn new(root: TermId) -> Self {
144 Self {
145 current_term: root,
146 pc: 0,
147 substitution: FxHashMap::default(),
148 term_stack: vec![root],
149 }
150 }
151}
152
/// A successful match of one compiled pattern against a ground term.
#[derive(Debug, Clone)]
pub struct Match {
    /// Quantifier whose pattern matched.
    pub quantifier: TermId,
    /// Index of the matching pattern within the quantifier's pattern list.
    pub pattern_index: usize,
    /// Bindings of the pattern variables to ground subterms.
    pub substitution: FxHashMap<PatternVar, TermId>,
}
163
/// Counters collected by a [`CodeTree`]; all start at zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct CodeTreeStats {
    /// Number of pattern terms compiled by `add_patterns`.
    pub patterns_compiled: usize,
    /// Total length of all compiled programs, including `Yield`/`Halt`.
    pub total_instructions: usize,
    /// Distinct ground terms added since the last `clear_ground_terms`.
    pub ground_terms_indexed: usize,
    /// Pattern executions that produced a match in `find_matches`.
    pub matches_found: usize,
    /// Pattern executions attempted in `find_matches`.
    pub match_attempts: usize,
    /// Pattern executions that did not produce a match in `find_matches`.
    pub failed_matches: usize,
    /// Cumulative time spent in `find_matches`, in microseconds.
    pub matching_time_us: u64,
}
182
183pub struct CodeTree {
185 symbol_index: FxHashMap<Spur, Vec<CompiledPattern>>,
188
189 variable_patterns: Vec<CompiledPattern>,
191
192 ground_terms: FxHashMap<Spur, FxHashSet<TermId>>,
195
196 all_ground_terms: FxHashSet<TermId>,
198
199 stats: CodeTreeStats,
201}
202
203impl CodeTree {
204 pub fn new() -> Self {
206 Self {
207 symbol_index: FxHashMap::default(),
208 variable_patterns: Vec::new(),
209 ground_terms: FxHashMap::default(),
210 all_ground_terms: FxHashSet::default(),
211 stats: CodeTreeStats::default(),
212 }
213 }
214
215 pub fn stats(&self) -> &CodeTreeStats {
217 &self.stats
218 }
219
220 pub fn reset_stats(&mut self) {
222 self.stats = CodeTreeStats::default();
223 }
224
225 pub fn add_patterns(
233 &mut self,
234 quantifier: TermId,
235 patterns: &[Vec<TermId>],
236 var_mapping: &FxHashMap<Spur, PatternVar>,
237 tm: &TermManager,
238 ) {
239 for (pattern_idx, pattern) in patterns.iter().enumerate() {
240 if pattern.is_empty() {
241 continue;
242 }
243
244 for pattern_term in pattern {
246 let compiled =
247 self.compile_pattern(*pattern_term, quantifier, pattern_idx, var_mapping, tm);
248
249 if !compiled.instructions.is_empty() {
250 if let Some(root_sym) = self.get_root_symbol(*pattern_term, tm) {
252 self.symbol_index
253 .entry(root_sym)
254 .or_default()
255 .push(compiled);
256 } else {
257 self.variable_patterns.push(compiled);
259 }
260
261 self.stats.patterns_compiled += 1;
262 }
263 }
264 }
265 }
266
267 fn compile_pattern(
269 &mut self,
270 pattern: TermId,
271 quantifier: TermId,
272 pattern_idx: usize,
273 var_mapping: &FxHashMap<Spur, PatternVar>,
274 tm: &TermManager,
275 ) -> CompiledPattern {
276 let mut instructions = Vec::new();
277 let mut variables = FxHashMap::default();
278 let mut bound_vars = FxHashMap::default();
279
280 self.compile_term(
281 pattern,
282 var_mapping,
283 &mut bound_vars,
284 &mut variables,
285 &mut instructions,
286 tm,
287 );
288
289 instructions.push(CodeTreeInstr::Yield {
291 quantifier,
292 pattern_idx,
293 });
294 instructions.push(CodeTreeInstr::Halt);
295
296 self.stats.total_instructions += instructions.len();
297
298 CompiledPattern {
299 pattern: vec![pattern],
300 instructions,
301 variables,
302 quantifier,
303 pattern_index: pattern_idx,
304 }
305 }
306
307 fn compile_term(
309 &self,
310 term: TermId,
311 var_mapping: &FxHashMap<Spur, PatternVar>,
312 bound_vars: &mut FxHashMap<PatternVar, usize>,
313 variables: &mut FxHashMap<PatternVar, SortId>,
314 instructions: &mut Vec<CodeTreeInstr>,
315 tm: &TermManager,
316 ) {
317 let term_data = tm.get(term).expect("term should exist in manager");
318 match &term_data.kind {
319 TermKind::Var(name) => {
320 let sort = term_data.sort;
321
322 if let Some(&var_id) = var_mapping.get(name) {
324 if let Some(&_first_occurrence) = bound_vars.get(&var_id) {
325 let failure_pc = instructions.len() + 1;
327 instructions.push(CodeTreeInstr::CheckEq {
328 var: var_id,
329 failure_pc,
330 });
331 } else {
332 bound_vars.insert(var_id, instructions.len());
334 variables.insert(var_id, sort);
335 instructions.push(CodeTreeInstr::Bind { var: var_id });
336 }
337 } else {
338 let failure_pc = instructions.len() + 1;
340 instructions.push(CodeTreeInstr::CheckVar { failure_pc });
341 }
342 }
343
344 TermKind::IntConst(_) | TermKind::RealConst(_) | TermKind::BitVecConst { .. } => {
345 let failure_pc = instructions.len() + 1;
347 instructions.push(CodeTreeInstr::CheckConstant {
348 value: term,
349 failure_pc,
350 });
351 }
352
353 TermKind::Apply { func, args } => {
354 let symbol = *func;
356 let arity = args.len();
357 let failure_pc = instructions.len() + arity + 2;
358
359 instructions.push(CodeTreeInstr::CheckSymbol {
360 symbol,
361 arity,
362 failure_pc,
363 });
364
365 for (i, &arg) in args.iter().enumerate() {
367 instructions.push(CodeTreeInstr::MoveToChild { index: i });
368 self.compile_term(arg, var_mapping, bound_vars, variables, instructions, tm);
369 instructions.push(CodeTreeInstr::MoveToParent);
370 }
371 }
372
373 _ => {
374 let failure_pc = instructions.len() + 1;
376 instructions.push(CodeTreeInstr::CheckConstant {
377 value: term,
378 failure_pc,
379 });
380 }
381 }
382 }
383
384 pub fn add_ground_term(&mut self, term: TermId, tm: &TermManager) {
388 if self.all_ground_terms.contains(&term) {
389 return; }
391
392 self.all_ground_terms.insert(term);
393
394 if let Some(root_sym) = self.get_root_symbol(term, tm) {
396 self.ground_terms.entry(root_sym).or_default().insert(term);
397 }
398
399 self.stats.ground_terms_indexed += 1;
400 }
401
402 pub fn find_matches(&mut self, tm: &TermManager) -> Vec<Match> {
406 let start = std::time::Instant::now();
407 let mut matches = Vec::new();
408
409 for (symbol, patterns) in &self.symbol_index {
411 if let Some(ground_terms) = self.ground_terms.get(symbol) {
412 for term in ground_terms {
413 for pattern in patterns {
414 self.stats.match_attempts += 1;
415 if let Some(m) = self.execute_pattern(pattern, *term, tm) {
416 matches.push(m);
417 self.stats.matches_found += 1;
418 } else {
419 self.stats.failed_matches += 1;
420 }
421 }
422 }
423 }
424 }
425
426 for pattern in &self.variable_patterns {
428 for term in &self.all_ground_terms {
429 self.stats.match_attempts += 1;
430 if let Some(m) = self.execute_pattern(pattern, *term, tm) {
431 matches.push(m);
432 self.stats.matches_found += 1;
433 } else {
434 self.stats.failed_matches += 1;
435 }
436 }
437 }
438
439 self.stats.matching_time_us += start.elapsed().as_micros() as u64;
440 matches
441 }
442
443 fn execute_pattern(
445 &self,
446 pattern: &CompiledPattern,
447 ground_term: TermId,
448 tm: &TermManager,
449 ) -> Option<Match> {
450 let mut context = MatchContext::new(ground_term);
451
452 while context.pc < pattern.instructions.len() {
453 match &pattern.instructions[context.pc] {
454 CodeTreeInstr::CheckSymbol {
455 symbol,
456 arity,
457 failure_pc,
458 } => {
459 if let Some(current) = tm.get(context.current_term)
460 && let TermKind::Apply { func, args } = ¤t.kind
461 && func == symbol
462 && args.len() == *arity
463 {
464 context.pc += 1;
465 continue;
466 }
467 context.pc = *failure_pc;
469 if context.pc >= pattern.instructions.len() {
470 return None;
471 }
472 }
473
474 CodeTreeInstr::CheckVar { failure_pc } => {
475 if let Some(current) = tm.get(context.current_term)
476 && matches!(current.kind, TermKind::Var(_))
477 {
478 context.pc += 1;
479 continue;
480 }
481 context.pc = *failure_pc;
482 if context.pc >= pattern.instructions.len() {
483 return None;
484 }
485 }
486
487 CodeTreeInstr::CheckConstant { value, failure_pc } => {
488 if context.current_term == *value {
489 context.pc += 1;
490 } else {
491 context.pc = *failure_pc;
492 if context.pc >= pattern.instructions.len() {
493 return None;
494 }
495 }
496 }
497
498 CodeTreeInstr::Bind { var } => {
499 context.substitution.insert(*var, context.current_term);
500 context.pc += 1;
501 }
502
503 CodeTreeInstr::CheckEq { var, failure_pc } => {
504 if let Some(&bound_term) = context.substitution.get(var)
505 && bound_term == context.current_term
506 {
507 context.pc += 1;
508 continue;
509 }
510 context.pc = *failure_pc;
511 if context.pc >= pattern.instructions.len() {
512 return None;
513 }
514 }
515
516 CodeTreeInstr::MoveToChild { index } => {
517 if let Some(current) = tm.get(context.current_term)
518 && let TermKind::Apply { args, .. } = ¤t.kind
519 && *index < args.len()
520 {
521 context.term_stack.push(context.current_term);
522 context.current_term = args[*index];
523 context.pc += 1;
524 continue;
525 }
526 return None; }
528
529 CodeTreeInstr::MoveToParent => {
530 if let Some(parent) = context.term_stack.pop() {
531 context.current_term = parent;
532 context.pc += 1;
533 } else {
534 return None; }
536 }
537
538 CodeTreeInstr::Yield {
539 quantifier,
540 pattern_idx,
541 } => {
542 return Some(Match {
543 quantifier: *quantifier,
544 pattern_index: *pattern_idx,
545 substitution: context.substitution.clone(),
546 });
547 }
548
549 CodeTreeInstr::Halt => {
550 return None;
551 }
552 }
553 }
554
555 None
556 }
557
558 fn get_root_symbol(&self, term: TermId, tm: &TermManager) -> Option<Spur> {
560 if let Some(term_data) = tm.get(term)
561 && let TermKind::Apply { func, .. } = &term_data.kind
562 {
563 return Some(*func);
564 }
565 None
566 }
567
568 pub fn clear_ground_terms(&mut self) {
570 self.ground_terms.clear();
571 self.all_ground_terms.clear();
572 self.stats.ground_terms_indexed = 0;
573 }
574
575 pub fn remove_quantifier(&mut self, quantifier: TermId) {
577 for patterns in self.symbol_index.values_mut() {
579 patterns.retain(|p| p.quantifier != quantifier);
580 }
581
582 self.variable_patterns
584 .retain(|p| p.quantifier != quantifier);
585 }
586}
587
588impl Default for CodeTree {
589 fn default() -> Self {
590 Self::new()
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use lasso::{Key, Rodeo};

    /// Builds a fresh term manager together with an empty string interner.
    fn setup_term_manager() -> (TermManager, Rodeo) {
        let manager = TermManager::new();
        let interner = Rodeo::default();
        (manager, interner)
    }

    #[test]
    fn test_code_tree_creation() {
        let fresh = CodeTree::new();
        let counters = fresh.stats();
        assert_eq!(counters.patterns_compiled, 0);
        assert_eq!(counters.ground_terms_indexed, 0);
    }

    #[test]
    fn test_code_tree_stats() {
        let mut tree = CodeTree::new();
        assert_eq!(tree.stats().matches_found, 0);

        // Poke the private counter directly, then check reset clears it.
        tree.stats.matches_found = 10;
        assert_eq!(tree.stats().matches_found, 10);

        tree.reset_stats();
        assert_eq!(tree.stats().matches_found, 0);
    }

    #[test]
    fn test_clear_ground_terms() {
        let mut tree = CodeTree::new();
        let (_manager, _interner) = setup_term_manager();

        tree.all_ground_terms.insert(TermId::new(1));
        tree.stats.ground_terms_indexed = 1;

        tree.clear_ground_terms();

        assert!(tree.all_ground_terms.is_empty());
        assert_eq!(tree.stats.ground_terms_indexed, 0);
    }

    #[test]
    fn test_match_context_creation() {
        let root = TermId::new(1);
        let ctx = MatchContext::new(root);

        assert_eq!(ctx.current_term, root);
        assert_eq!(ctx.pc, 0);
        assert!(ctx.substitution.is_empty());
        assert_eq!(ctx.term_stack.len(), 1);
    }

    #[test]
    fn test_code_tree_instruction_check_symbol() {
        let symbol = Spur::try_from_usize(0).unwrap();
        let instr = CodeTreeInstr::CheckSymbol {
            symbol,
            arity: 2,
            failure_pc: 10,
        };

        let CodeTreeInstr::CheckSymbol {
            arity, failure_pc, ..
        } = instr
        else {
            panic!("Wrong instruction type");
        };
        assert_eq!(arity, 2);
        assert_eq!(failure_pc, 10);
    }

    #[test]
    fn test_code_tree_instruction_bind() {
        let instr = CodeTreeInstr::Bind { var: 42 };

        let CodeTreeInstr::Bind { var } = instr else {
            panic!("Wrong instruction type");
        };
        assert_eq!(var, 42);
    }

    #[test]
    fn test_compiled_pattern() {
        let quantifier = TermId::new(2);
        let program = vec![
            CodeTreeInstr::Bind { var: 0 },
            CodeTreeInstr::Yield {
                quantifier,
                pattern_idx: 0,
            },
            CodeTreeInstr::Halt,
        ];
        let compiled = CompiledPattern {
            pattern: vec![TermId::new(1)],
            instructions: program,
            variables: FxHashMap::default(),
            quantifier,
            pattern_index: 0,
        };

        assert_eq!(compiled.instructions.len(), 3);
        assert_eq!(compiled.pattern_index, 0);
    }

    #[test]
    fn test_match_result() {
        let substitution: FxHashMap<PatternVar, TermId> =
            [(0, TermId::new(100))].into_iter().collect();

        let found = Match {
            quantifier: TermId::new(1),
            pattern_index: 0,
            substitution,
        };

        assert_eq!(found.quantifier, TermId::new(1));
        assert_eq!(found.pattern_index, 0);
        assert_eq!(found.substitution.len(), 1);
    }
}