1#![allow(missing_docs)]
21#![allow(dead_code)]
22
23#[allow(unused_imports)]
24use crate::prelude::*;
25use core::cmp::Ordering;
26use core::fmt;
27use num_bigint::BigInt;
28use num_rational::Rational64;
29use num_traits::ToPrimitive;
30use oxiz_core::ast::{TermId, TermKind, TermManager};
31use oxiz_core::interner::Spur;
32use oxiz_core::sort::SortId;
33use smallvec::SmallVec;
34
/// One observed application gathered from the partial model while building
/// function tables: `(argument sorts, result sort, evaluated argument values, result value)`.
type FuncEntry = (SmallVec<[SortId; 4]>, SortId, Vec<TermId>, TermId);
37
38use super::QuantifiedFormula;
39
/// Maximum number of values `CompletedModel::collect_universes_from_model`
/// keeps in the universe it collects for a single sort.
const MAX_UNIVERSE_SIZE: usize = 1000;
42
43#[derive(Debug, Clone)]
45pub struct CompletedModel {
46 pub assignments: FxHashMap<TermId, TermId>,
48 pub function_interps: FxHashMap<Spur, FunctionInterpretation>,
50 pub universes: FxHashMap<SortId, Vec<TermId>>,
52 pub defaults: FxHashMap<SortId, TermId>,
54 pub generation: u32,
56}
57
58impl CompletedModel {
59 pub fn new() -> Self {
61 Self {
62 assignments: FxHashMap::default(),
63 function_interps: FxHashMap::default(),
64 universes: FxHashMap::default(),
65 defaults: FxHashMap::default(),
66 generation: 0,
67 }
68 }
69
70 pub fn eval(&self, term: TermId) -> Option<TermId> {
72 self.assignments.get(&term).copied()
73 }
74
75 pub fn set(&mut self, term: TermId, value: TermId) {
77 self.assignments.insert(term, value);
78 }
79
80 pub fn universe(&self, sort: SortId) -> Option<&[TermId]> {
82 self.universes.get(&sort).map(|v| v.as_slice())
83 }
84
85 pub fn add_to_universe(&mut self, sort: SortId, value: TermId) {
87 self.universes.entry(sort).or_default().push(value);
88 }
89
90 pub fn default_value(&self, sort: SortId) -> Option<TermId> {
92 self.defaults.get(&sort).copied()
93 }
94
95 pub fn set_default(&mut self, sort: SortId, value: TermId) {
97 self.defaults.insert(sort, value);
98 }
99
100 pub fn has_uninterpreted_sort(&self, sort: SortId) -> bool {
102 self.universes.contains_key(&sort)
103 }
104
105 pub fn eval_apply(&self, func: Spur, evaluated_args: &[TermId]) -> Option<TermId> {
110 if let Some(interp) = self.function_interps.get(&func) {
111 if let Some(result) = interp.lookup(evaluated_args) {
113 return Some(result);
114 }
115 if let Some(else_val) = interp.else_value {
117 return Some(else_val);
118 }
119 if let Some(default) = self.defaults.get(&interp.range) {
121 return Some(*default);
122 }
123 }
124 None
125 }
126
127 pub fn collect_universes_from_model(
134 &mut self,
135 quantifiers: &[QuantifiedFormula],
136 manager: &TermManager,
137 ) {
138 let mut needed_sorts: FxHashSet<SortId> = FxHashSet::default();
140 for quant in quantifiers {
141 for &(_name, sort) in &quant.bound_vars {
142 needed_sorts.insert(sort);
143 }
144 }
145
146 for sort in needed_sorts {
148 if self.universes.contains_key(&sort) {
150 continue;
151 }
152
153 let mut universe_values: Vec<TermId> = Vec::new();
154 let mut seen: FxHashSet<TermId> = FxHashSet::default();
155
156 for (&term, &value) in &self.assignments {
158 if let Some(t) = manager.get(term) {
160 if t.sort == sort && seen.insert(value) {
161 universe_values.push(value);
162 }
163 }
164 if let Some(v) = manager.get(value) {
166 if v.sort == sort && seen.insert(value) {
167 universe_values.push(value);
168 }
169 }
170 }
171
172 for interp in self.function_interps.values() {
174 for entry in &interp.entries {
175 for (i, &arg) in entry.args.iter().enumerate() {
177 if i < interp.domain.len() && interp.domain[i] == sort && seen.insert(arg) {
178 universe_values.push(arg);
179 }
180 }
181 if interp.range == sort && seen.insert(entry.result) {
183 universe_values.push(entry.result);
184 }
185 }
186 }
187
188 universe_values.truncate(MAX_UNIVERSE_SIZE);
190
191 if !universe_values.is_empty() {
192 self.universes.insert(sort, universe_values);
193 }
194 }
195 }
196
197 pub fn complete_function_interpretations(&mut self) {
212 let updates: Vec<(Spur, TermId)> = self
214 .function_interps
215 .iter()
216 .filter_map(|(&name, interp)| {
217 if interp.else_value.is_some() {
218 return None;
219 }
220 if !interp.entries.is_empty() {
224 return None;
225 }
226 if let Some(&default) = self.defaults.get(&interp.range) {
228 return Some((name, default));
229 }
230 None
231 })
232 .collect();
233
234 for (name, else_val) in updates {
235 if let Some(interp) = self.function_interps.get_mut(&name) {
236 interp.else_value = Some(else_val);
237 }
238 }
239 }
240}
241
242impl Default for CompletedModel {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
248#[derive(Debug, Clone)]
250pub struct FunctionInterpretation {
251 pub name: Spur,
253 pub arity: usize,
255 pub domain: SmallVec<[SortId; 4]>,
257 pub range: SortId,
259 pub entries: Vec<FunctionEntry>,
261 pub else_value: Option<TermId>,
263 pub projections: Vec<Option<ProjectionFunctionDef>>,
265}
266
267impl FunctionInterpretation {
268 pub fn new(name: Spur, domain: SmallVec<[SortId; 4]>, range: SortId) -> Self {
270 let arity = domain.len();
271 Self {
272 name,
273 arity,
274 domain,
275 range,
276 entries: Vec::new(),
277 else_value: None,
278 projections: vec![None; arity],
279 }
280 }
281
282 pub fn add_entry(&mut self, args: Vec<TermId>, result: TermId) {
284 if args.len() == self.arity {
285 self.entries.push(FunctionEntry { args, result });
286 }
287 }
288
289 pub fn lookup(&self, args: &[TermId]) -> Option<TermId> {
291 for entry in &self.entries {
292 if entry.args == args {
293 return Some(entry.result);
294 }
295 }
296 self.else_value
297 }
298
299 pub fn is_constant(&self) -> bool {
301 self.arity == 0
302 }
303
304 pub fn is_partial(&self) -> bool {
306 self.else_value.is_none() && !self.entries.is_empty()
307 }
308
309 pub fn max_occurrence_result(&self) -> Option<TermId> {
311 if self.entries.is_empty() {
312 return None;
313 }
314
315 let mut counts: FxHashMap<TermId, usize> = FxHashMap::default();
316 for entry in &self.entries {
317 *counts.entry(entry.result).or_insert(0) += 1;
318 }
319
320 counts
321 .into_iter()
322 .max_by_key(|(_, count)| *count)
323 .map(|(term, _)| term)
324 }
325}
326
/// A single row `f(args) = result` of a [`FunctionInterpretation`] table.
#[derive(Debug, Clone)]
pub struct FunctionEntry {
    /// Argument values; `FunctionInterpretation::add_entry` only stores rows
    /// whose length equals the function's arity.
    pub args: Vec<TermId>,
    /// Result value for these arguments.
    pub result: TermId,
}
335
336#[derive(Debug, Clone)]
338pub struct ProjectionFunctionDef {
339 pub arg_index: usize,
341 pub sort: SortId,
343 pub values: Vec<TermId>,
345 pub value_to_term: FxHashMap<TermId, TermId>,
347 pub term_to_value: FxHashMap<TermId, TermId>,
349}
350
351impl ProjectionFunctionDef {
352 pub fn new(arg_index: usize, sort: SortId) -> Self {
354 Self {
355 arg_index,
356 sort,
357 values: Vec::new(),
358 value_to_term: FxHashMap::default(),
359 term_to_value: FxHashMap::default(),
360 }
361 }
362
363 pub fn add_value(&mut self, value: TermId, term: TermId) {
365 if !self.values.contains(&value) {
366 self.values.push(value);
367 }
368 self.value_to_term.insert(value, term);
369 self.term_to_value.insert(term, value);
370 }
371
372 pub fn project(&self, value: TermId) -> Option<TermId> {
374 self.value_to_term.get(&value).copied()
375 }
376}
377
378#[derive(Debug)]
380pub struct ModelCompleter {
381 macro_solver: MacroSolver,
383 model_fixer: ModelFixer,
385 uninterp_handler: UninterpretedSortHandler,
387 cache: FxHashMap<u64, CompletedModel>,
389 stats: CompletionStats,
391}
392
393impl ModelCompleter {
394 pub fn new() -> Self {
396 Self {
397 macro_solver: MacroSolver::new(),
398 model_fixer: ModelFixer::new(),
399 uninterp_handler: UninterpretedSortHandler::new(),
400 cache: FxHashMap::default(),
401 stats: CompletionStats::default(),
402 }
403 }
404
405 pub fn complete(
417 &mut self,
418 partial_model: &FxHashMap<TermId, TermId>,
419 quantifiers: &[QuantifiedFormula],
420 manager: &mut TermManager,
421 ) -> Result<CompletedModel, CompletionError> {
422 self.stats.num_completions += 1;
423
424 let mut completed = CompletedModel::new();
426 completed.assignments = partial_model.clone();
427
428 self.extract_function_interpretations(&mut completed, manager);
430
431 let macro_results = self.macro_solver.solve_macros(quantifiers, manager)?;
436 for (func_name, macro_interp) in macro_results {
437 completed
439 .function_interps
440 .entry(func_name)
441 .or_insert(macro_interp);
442 }
445
446 self.model_fixer
448 .fix_model(&mut completed, quantifiers, manager)?;
449
450 self.uninterp_handler
452 .complete_universes(&mut completed, manager)?;
453
454 self.set_default_values(&mut completed, manager)?;
456
457 completed.collect_universes_from_model(quantifiers, manager);
461
462 self.set_default_values(&mut completed, manager)?;
464
465 completed.complete_function_interpretations();
467
468 Ok(completed)
469 }
470
471 fn eval_to_const(term: TermId, manager: &mut TermManager) -> Option<TermId> {
478 let t = manager.get(term)?.clone();
479 match &t.kind {
480 TermKind::IntConst(_) | TermKind::RealConst(_) => Some(term),
481 TermKind::Neg(arg) => {
482 let inner = Self::eval_to_const(*arg, manager)?;
483 let inner_t = manager.get(inner)?.clone();
484 match &inner_t.kind {
485 TermKind::IntConst(n) => {
486 let neg_n = -n.clone();
487 Some(manager.mk_int(neg_n))
488 }
489 TermKind::RealConst(r) => {
490 let neg_r = -*r;
491 Some(manager.mk_real(neg_r))
492 }
493 _ => None,
494 }
495 }
496 TermKind::Add(args) => {
497 let args_cloned: SmallVec<[TermId; 4]> = args.clone();
498 let mut sum_r = Rational64::from_integer(0);
500 let mut all_real = true;
501 let mut all_int = true;
502 let mut sum_i = num_bigint::BigInt::from(0i64);
503 for &arg in &args_cloned {
504 if let Some(c) = Self::eval_to_const(arg, manager) {
505 let ct = manager.get(c)?.clone();
506 match &ct.kind {
507 TermKind::RealConst(r) => {
508 sum_r += r;
509 all_int = false;
510 }
511 TermKind::IntConst(n) => {
512 sum_r += Rational64::from_integer(n.to_i64().unwrap_or(0));
513 sum_i += n.clone();
514 all_real = false;
515 }
516 _ => {
517 all_real = false;
518 all_int = false;
519 }
520 }
521 } else {
522 all_real = false;
523 all_int = false;
524 }
525 }
526 if all_int && !args_cloned.is_empty() {
527 Some(manager.mk_int(sum_i))
528 } else if all_real && !args_cloned.is_empty() {
529 Some(manager.mk_real(sum_r))
530 } else {
531 None
532 }
533 }
534 TermKind::Sub(lhs, rhs) => {
535 let (lhs_v, rhs_v) = (*lhs, *rhs);
536 let lc = Self::eval_to_const(lhs_v, manager)?;
537 let rc = Self::eval_to_const(rhs_v, manager)?;
538 let lct = manager.get(lc)?.clone();
539 let rct = manager.get(rc)?.clone();
540 match (&lct.kind, &rct.kind) {
541 (TermKind::IntConst(a), TermKind::IntConst(b)) => Some(manager.mk_int(a - b)),
542 (TermKind::RealConst(a), TermKind::RealConst(b)) => {
543 Some(manager.mk_real(a - b))
544 }
545 _ => None,
546 }
547 }
548 _ => None,
549 }
550 }
551
552 fn eval_arg(term: TermId, model: &CompletedModel, manager: &mut TermManager) -> TermId {
555 if let Some(val) = model.eval(term) {
557 return val;
558 }
559 if let Some(const_val) = Self::eval_to_const(term, manager) {
561 return const_val;
562 }
563 term
564 }
565
566 fn extract_function_interpretations(
574 &self,
575 model: &mut CompletedModel,
576 manager: &mut TermManager,
577 ) {
578 let mut func_entries: FxHashMap<Spur, Vec<FuncEntry>> = FxHashMap::default();
580
581 let apply_entries: Vec<(TermId, TermId)> = model
583 .assignments
584 .iter()
585 .filter_map(|(&term, &value)| {
586 if manager
587 .get(term)
588 .is_some_and(|t| matches!(t.kind, TermKind::Apply { .. }))
589 {
590 Some((term, value))
591 } else {
592 None
593 }
594 })
595 .collect();
596
597 for (term, value) in apply_entries {
598 let Some(t) = manager.get(term).cloned() else {
599 continue;
600 };
601 if let TermKind::Apply { func, args } = &t.kind {
602 let args_cloned: SmallVec<[TermId; 4]> = args.clone();
607 let evaluated_args: Vec<TermId> = args_cloned
608 .iter()
609 .map(|&arg| Self::eval_arg(arg, model, manager))
610 .collect();
611
612 let domain: SmallVec<[SortId; 4]> = args
613 .iter()
614 .map(|&arg| manager.get(arg).map_or(manager.sorts.int_sort, |a| a.sort))
615 .collect();
616
617 func_entries.entry(*func).or_default().push((
618 domain,
619 t.sort,
620 evaluated_args,
621 value,
622 ));
623 }
624 }
625
626 for (func_name, entries) in func_entries {
628 match model.function_interps.entry(func_name) {
629 std::collections::hash_map::Entry::Occupied(mut occupied) => {
630 let interp = occupied.get_mut();
632 for (_domain, _range, args, result) in entries {
633 let already_exists = interp.entries.iter().any(|e| e.args == args);
634 if !already_exists {
635 interp.add_entry(args, result);
636 }
637 }
638 }
639 std::collections::hash_map::Entry::Vacant(vacant) => {
640 if let Some((domain, range, first_args, first_result)) = entries.first() {
641 let mut interp =
643 FunctionInterpretation::new(func_name, domain.clone(), *range);
644 interp.add_entry(first_args.clone(), *first_result);
645 for (_, _, args, result) in entries.iter().skip(1) {
646 let already_exists = interp.entries.iter().any(|e| &e.args == args);
647 if !already_exists {
648 interp.add_entry(args.clone(), *result);
649 }
650 }
651 vacant.insert(interp);
652 }
653 }
654 }
655 }
656 }
657
658 fn set_default_values(
660 &mut self,
661 model: &mut CompletedModel,
662 manager: &mut TermManager,
663 ) -> Result<(), CompletionError> {
664 if !model.defaults.contains_key(&manager.sorts.bool_sort) {
666 model.set_default(manager.sorts.bool_sort, manager.mk_false());
667 }
668
669 if !model.defaults.contains_key(&manager.sorts.int_sort) {
671 model.set_default(manager.sorts.int_sort, manager.mk_int(BigInt::from(0)));
672 }
673
674 if !model.defaults.contains_key(&manager.sorts.real_sort) {
676 model.set_default(
677 manager.sorts.real_sort,
678 manager.mk_real(Rational64::from_integer(0)),
679 );
680 }
681
682 let defaults_to_set: Vec<(SortId, TermId)> = model
685 .universes
686 .iter()
687 .filter_map(|(sort, universe)| {
688 if !model.defaults.contains_key(sort) {
689 universe.first().map(|&first| (*sort, first))
690 } else {
691 None
692 }
693 })
694 .collect();
695
696 for (sort, value) in defaults_to_set {
697 model.set_default(sort, value);
698 }
699
700 Ok(())
701 }
702
703 pub fn stats(&self) -> &CompletionStats {
705 &self.stats
706 }
707}
708
709impl Default for ModelCompleter {
710 fn default() -> Self {
711 Self::new()
712 }
713}
714
715#[derive(Debug)]
721pub struct MacroSolver {
722 macros: FxHashMap<Spur, MacroDefinition>,
724 stats: MacroStats,
726}
727
728impl MacroSolver {
729 pub fn new() -> Self {
731 Self {
732 macros: FxHashMap::default(),
733 stats: MacroStats::default(),
734 }
735 }
736
737 pub fn solve_macros(
739 &mut self,
740 quantifiers: &[QuantifiedFormula],
741 manager: &mut TermManager,
742 ) -> Result<FxHashMap<Spur, FunctionInterpretation>, CompletionError> {
743 let mut results = FxHashMap::default();
744
745 for quant in quantifiers {
746 if let Some(macro_def) = self.try_extract_macro(quant, manager)? {
747 self.stats.num_macros_found += 1;
748 let interp = self.macro_to_interpretation(¯o_def, manager)?;
749 results.insert(macro_def.func_name, interp);
750 self.macros.insert(macro_def.func_name, macro_def);
751 }
752 }
753
754 Ok(results)
755 }
756
757 fn try_extract_macro(
759 &self,
760 quant: &QuantifiedFormula,
761 manager: &TermManager,
762 ) -> Result<Option<MacroDefinition>, CompletionError> {
763 let Some(body_term) = manager.get(quant.body) else {
765 return Ok(None);
766 };
767
768 if let TermKind::Eq(lhs, rhs) = &body_term.kind {
770 if let Some(macro_def) = self.try_extract_macro_from_eq(*lhs, *rhs, quant, manager)? {
772 return Ok(Some(macro_def));
773 }
774 if let Some(macro_def) = self.try_extract_macro_from_eq(*rhs, *lhs, quant, manager)? {
775 return Ok(Some(macro_def));
776 }
777 }
778
779 Ok(None)
780 }
781
782 fn try_extract_macro_from_eq(
784 &self,
785 lhs: TermId,
786 rhs: TermId,
787 quant: &QuantifiedFormula,
788 manager: &TermManager,
789 ) -> Result<Option<MacroDefinition>, CompletionError> {
790 let Some(lhs_term) = manager.get(lhs) else {
791 return Ok(None);
792 };
793
794 if let TermKind::Apply { func, args } = &lhs_term.kind {
796 let mut is_macro = true;
798 for &arg in args.iter() {
799 if let Some(arg_term) = manager.get(arg)
800 && !matches!(arg_term.kind, TermKind::Var(_))
801 {
802 is_macro = false;
803 break;
804 }
805 }
806
807 if is_macro {
808 if !self.contains_function(rhs, *func, manager) {
810 return Ok(Some(MacroDefinition {
811 quantifier: quant.term,
812 func_name: *func,
813 bound_vars: quant.bound_vars.clone(),
814 body: rhs,
815 }));
816 }
817 }
818 }
819
820 Ok(None)
821 }
822
823 fn contains_function(&self, term: TermId, func: Spur, manager: &TermManager) -> bool {
825 let mut visited = FxHashSet::default();
826 self.contains_function_rec(term, func, manager, &mut visited)
827 }
828
829 fn contains_function_rec(
830 &self,
831 term: TermId,
832 func: Spur,
833 manager: &TermManager,
834 visited: &mut FxHashSet<TermId>,
835 ) -> bool {
836 if visited.contains(&term) {
837 return false;
838 }
839 visited.insert(term);
840
841 let Some(t) = manager.get(term) else {
842 return false;
843 };
844
845 match &t.kind {
846 TermKind::Apply { func: f, args } => {
847 if *f == func {
848 return true;
849 }
850 for &arg in args.iter() {
851 if self.contains_function_rec(arg, func, manager, visited) {
852 return true;
853 }
854 }
855 false
856 }
857 _ => {
858 let children = self.get_children(term, manager);
860 for child in children {
861 if self.contains_function_rec(child, func, manager, visited) {
862 return true;
863 }
864 }
865 false
866 }
867 }
868 }
869
870 fn get_children(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
872 let Some(t) = manager.get(term) else {
873 return vec![];
874 };
875
876 match &t.kind {
877 TermKind::Not(arg) | TermKind::Neg(arg) => vec![*arg],
878 TermKind::And(args)
879 | TermKind::Or(args)
880 | TermKind::Add(args)
881 | TermKind::Mul(args) => args.to_vec(),
882 TermKind::Sub(lhs, rhs)
883 | TermKind::Div(lhs, rhs)
884 | TermKind::Mod(lhs, rhs)
885 | TermKind::Eq(lhs, rhs)
886 | TermKind::Lt(lhs, rhs)
887 | TermKind::Le(lhs, rhs)
888 | TermKind::Gt(lhs, rhs)
889 | TermKind::Ge(lhs, rhs)
890 | TermKind::Implies(lhs, rhs) => vec![*lhs, *rhs],
891 TermKind::Ite(cond, then_br, else_br) => vec![*cond, *then_br, *else_br],
892 TermKind::Apply { args, .. } => args.to_vec(),
893 _ => vec![],
894 }
895 }
896
897 fn macro_to_interpretation(
905 &self,
906 macro_def: &MacroDefinition,
907 manager: &mut TermManager,
908 ) -> Result<FunctionInterpretation, CompletionError> {
909 let func_name = macro_def.func_name;
910
911 let domain: SmallVec<[SortId; 4]> =
913 macro_def.bound_vars.iter().map(|&(_, sort)| sort).collect();
914
915 let range = manager
917 .get(macro_def.body)
918 .map_or(manager.sorts.bool_sort, |t| t.sort);
919
920 let interp = FunctionInterpretation::new(func_name, domain, range);
921 Ok(interp)
922 }
923
924 pub fn stats(&self) -> &MacroStats {
926 &self.stats
927 }
928}
929
930impl Default for MacroSolver {
931 fn default() -> Self {
932 Self::new()
933 }
934}
935
/// A macro recognized by [`MacroSolver`]: the quantifier
/// `forall bound_vars. func_name(...) = body`.
#[derive(Debug, Clone)]
pub struct MacroDefinition {
    /// The quantified formula the macro was extracted from.
    pub quantifier: TermId,
    /// Function symbol defined by the macro.
    pub func_name: Spur,
    /// The quantifier's bound variables as `(name, sort)` pairs.
    pub bound_vars: SmallVec<[(Spur, SortId); 4]>,
    /// Side of the equality opposite the `func_name(...)` application.
    pub body: TermId,
}
948
949#[derive(Debug)]
951pub struct ModelFixer {
952 projections: FxHashMap<SortId, Box<dyn ProjectionFunction>>,
954 stats: FixerStats,
956}
957
958impl ModelFixer {
959 pub fn new() -> Self {
961 Self {
962 projections: FxHashMap::default(),
963 stats: FixerStats::default(),
964 }
965 }
966
967 pub fn fix_model(
969 &mut self,
970 model: &mut CompletedModel,
971 quantifiers: &[QuantifiedFormula],
972 manager: &mut TermManager,
973 ) -> Result<(), CompletionError> {
974 self.stats.num_fixes += 1;
975
976 let partial_functions = self.collect_partial_functions(quantifiers, manager);
978
979 for func_name in partial_functions.iter() {
982 let has_interp = model.function_interps.contains_key(func_name);
984 if has_interp {
985 if let Some(interp) = model.function_interps.get_mut(func_name) {
987 for arg_idx in 0..interp.arity {
990 let sort = interp.domain[arg_idx];
991 if self.needs_projection(sort, manager) {
992 interp.projections[arg_idx] = None;
994 }
995 }
996 }
997 }
998 }
999
1000 Ok(())
1019 }
1020
1021 fn collect_partial_functions(
1023 &self,
1024 quantifiers: &[QuantifiedFormula],
1025 manager: &TermManager,
1026 ) -> FxHashSet<Spur> {
1027 let mut functions = FxHashSet::default();
1028
1029 for quant in quantifiers {
1030 self.collect_partial_functions_rec(quant.body, &mut functions, manager);
1031 }
1032
1033 functions
1034 }
1035
1036 fn collect_partial_functions_rec(
1037 &self,
1038 term: TermId,
1039 functions: &mut FxHashSet<Spur>,
1040 manager: &TermManager,
1041 ) {
1042 let Some(t) = manager.get(term) else {
1043 return;
1044 };
1045
1046 if let TermKind::Apply { func, args } = &t.kind {
1047 let has_vars = args.iter().any(|&arg| {
1049 if let Some(arg_t) = manager.get(arg) {
1050 matches!(arg_t.kind, TermKind::Var(_))
1051 } else {
1052 false
1053 }
1054 });
1055
1056 if has_vars {
1057 functions.insert(*func);
1058 }
1059
1060 for &arg in args.iter() {
1062 self.collect_partial_functions_rec(arg, functions, manager);
1063 }
1064 }
1065
1066 match &t.kind {
1068 TermKind::Not(arg) | TermKind::Neg(arg) => {
1069 self.collect_partial_functions_rec(*arg, functions, manager);
1070 }
1071 TermKind::And(args) | TermKind::Or(args) => {
1072 for &arg in args.iter() {
1073 self.collect_partial_functions_rec(arg, functions, manager);
1074 }
1075 }
1076 TermKind::Eq(lhs, rhs) | TermKind::Lt(lhs, rhs) | TermKind::Le(lhs, rhs) => {
1077 self.collect_partial_functions_rec(*lhs, functions, manager);
1078 self.collect_partial_functions_rec(*rhs, functions, manager);
1079 }
1080 _ => {}
1081 }
1082 }
1083
1084 fn add_projection_functions(
1086 &mut self,
1087 interp: &mut FunctionInterpretation,
1088 model: &CompletedModel,
1089 manager: &mut TermManager,
1090 ) -> Result<(), CompletionError> {
1091 for arg_idx in 0..interp.arity {
1093 let sort = interp.domain[arg_idx];
1094
1095 if self.needs_projection(sort, manager) {
1097 let proj_def = self.create_projection(interp, arg_idx, model, manager)?;
1098 interp.projections[arg_idx] = Some(proj_def);
1099 }
1100 }
1101
1102 Ok(())
1103 }
1104
1105 fn needs_projection(&self, sort: SortId, manager: &TermManager) -> bool {
1107 sort == manager.sorts.int_sort || sort == manager.sorts.real_sort
1109 }
1110
1111 fn create_projection(
1113 &mut self,
1114 interp: &FunctionInterpretation,
1115 arg_idx: usize,
1116 model: &CompletedModel,
1117 manager: &mut TermManager,
1118 ) -> Result<ProjectionFunctionDef, CompletionError> {
1119 let sort = interp.domain[arg_idx];
1120 let mut proj_def = ProjectionFunctionDef::new(arg_idx, sort);
1121
1122 for entry in &interp.entries {
1124 if let Some(&arg_term) = entry.args.get(arg_idx) {
1125 let value = model.eval(arg_term).unwrap_or(arg_term);
1127 proj_def.add_value(value, arg_term);
1128 }
1129 }
1130
1131 proj_def
1133 .values
1134 .sort_by(|a, b| self.compare_values(*a, *b, sort, manager));
1135
1136 Ok(proj_def)
1137 }
1138
1139 fn compare_values(
1141 &self,
1142 a: TermId,
1143 b: TermId,
1144 _sort: SortId,
1145 manager: &TermManager,
1146 ) -> Ordering {
1147 let a_term = manager.get(a);
1148 let b_term = manager.get(b);
1149
1150 if let (Some(at), Some(bt)) = (a_term, b_term) {
1151 if let (TermKind::IntConst(av), TermKind::IntConst(bv)) = (&at.kind, &bt.kind) {
1153 return av.cmp(bv);
1154 }
1155
1156 if let (TermKind::RealConst(av), TermKind::RealConst(bv)) = (&at.kind, &bt.kind) {
1158 return av.cmp(bv);
1159 }
1160
1161 match (&at.kind, &bt.kind) {
1163 (TermKind::False, TermKind::True) => return Ordering::Less,
1164 (TermKind::True, TermKind::False) => return Ordering::Greater,
1165 (TermKind::False, TermKind::False) | (TermKind::True, TermKind::True) => {
1166 return Ordering::Equal;
1167 }
1168 _ => {}
1169 }
1170 }
1171
1172 a.0.cmp(&b.0)
1174 }
1175
1176 pub fn stats(&self) -> &FixerStats {
1178 &self.stats
1179 }
1180}
1181
1182impl Default for ModelFixer {
1183 fn default() -> Self {
1184 Self::new()
1185 }
1186}
1187
/// Ordering helper over values of a sort, used for argument projections.
pub trait ProjectionFunction: fmt::Debug + Send + Sync {
    /// Returns whether `a` is ordered before `b` (the `ArithmeticProjection`
    /// implementation in this file returns `a < b`).
    fn compare(&self, a: TermId, b: TermId, manager: &TermManager) -> bool;

    /// Builds a term for `x < y`.
    fn mk_lt(&self, x: TermId, y: TermId, manager: &mut TermManager) -> TermId;
}
1196
1197#[derive(Debug)]
1199pub struct ArithmeticProjection {
1200 is_int: bool,
1202}
1203
1204impl ArithmeticProjection {
1205 pub fn new(is_int: bool) -> Self {
1206 Self { is_int }
1207 }
1208}
1209
1210impl ProjectionFunction for ArithmeticProjection {
1211 fn compare(&self, a: TermId, b: TermId, manager: &TermManager) -> bool {
1212 let a_term = manager.get(a);
1213 let b_term = manager.get(b);
1214
1215 if let (Some(at), Some(bt)) = (a_term, b_term) {
1216 if let (TermKind::IntConst(av), TermKind::IntConst(bv)) = (&at.kind, &bt.kind) {
1217 return av < bv;
1218 }
1219 if let (TermKind::RealConst(av), TermKind::RealConst(bv)) = (&at.kind, &bt.kind) {
1220 return av < bv;
1221 }
1222 }
1223
1224 a.0 < b.0
1225 }
1226
1227 fn mk_lt(&self, x: TermId, y: TermId, manager: &mut TermManager) -> TermId {
1228 manager.mk_lt(x, y)
1229 }
1230}
1231
1232#[derive(Debug)]
1234pub struct UninterpretedSortHandler {
1235 max_universe_size: usize,
1237 stats: UninterpStats,
1239}
1240
1241impl UninterpretedSortHandler {
1242 pub fn new() -> Self {
1244 Self {
1245 max_universe_size: 8,
1246 stats: UninterpStats::default(),
1247 }
1248 }
1249
1250 pub fn with_max_size(max_size: usize) -> Self {
1252 let mut handler = Self::new();
1253 handler.max_universe_size = max_size;
1254 handler
1255 }
1256
1257 pub fn complete_universes(
1259 &mut self,
1260 model: &mut CompletedModel,
1261 manager: &mut TermManager,
1262 ) -> Result<(), CompletionError> {
1263 let uninterp_sorts = self.identify_uninterpreted_sorts(model, manager);
1265
1266 for sort in uninterp_sorts {
1267 if let crate::prelude::hash_map::Entry::Vacant(e) = model.universes.entry(sort) {
1268 let universe = self.create_finite_universe(sort, manager)?;
1270 e.insert(universe);
1271 self.stats.num_universes_created += 1;
1272 }
1273 }
1274
1275 Ok(())
1276 }
1277
1278 fn identify_uninterpreted_sorts(
1280 &self,
1281 model: &CompletedModel,
1282 manager: &TermManager,
1283 ) -> Vec<SortId> {
1284 let mut sorts = Vec::new();
1285
1286 for interp in model.function_interps.values() {
1288 for &sort in &interp.domain {
1289 if self.is_uninterpreted(sort, manager) && !sorts.contains(&sort) {
1290 sorts.push(sort);
1291 }
1292 }
1293 if self.is_uninterpreted(interp.range, manager) && !sorts.contains(&interp.range) {
1294 sorts.push(interp.range);
1295 }
1296 }
1297
1298 sorts
1299 }
1300
1301 fn is_uninterpreted(&self, sort: SortId, manager: &TermManager) -> bool {
1303 sort != manager.sorts.bool_sort
1305 && sort != manager.sorts.int_sort
1306 && sort != manager.sorts.real_sort
1307 }
1308
1309 fn create_finite_universe(
1311 &self,
1312 sort: SortId,
1313 manager: &mut TermManager,
1314 ) -> Result<Vec<TermId>, CompletionError> {
1315 let mut universe = Vec::new();
1316
1317 for i in 0..self.max_universe_size {
1319 let name = format!("u!{}", i);
1320 let const_id = manager.mk_var(&name, sort);
1321 universe.push(const_id);
1322 }
1323
1324 Ok(universe)
1325 }
1326
1327 pub fn stats(&self) -> &UninterpStats {
1329 &self.stats
1330 }
1331}
1332
1333impl Default for UninterpretedSortHandler {
1334 fn default() -> Self {
1335 Self::new()
1336 }
1337}
1338
/// Errors reported by model completion.
#[derive(Debug, Clone)]
pub enum CompletionError {
    /// Completion failed; the payload carries a description.
    CompletionFailed(String),
    /// A resource limit was exceeded.
    ResourceLimit,
    /// The model is invalid; the payload carries a description.
    InvalidModel(String),
}

impl fmt::Display for CompletionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CompletionError::CompletionFailed(reason) => {
                write!(f, "Model completion failed: {}", reason)
            }
            CompletionError::ResourceLimit => {
                f.write_str("Resource limit exceeded during completion")
            }
            CompletionError::InvalidModel(reason) => write!(f, "Invalid model: {}", reason),
        }
    }
}

impl core::error::Error for CompletionError {}
1361
/// Counters kept by [`ModelCompleter`].
#[derive(Debug, Clone, Default)]
pub struct CompletionStats {
    /// Number of calls to `ModelCompleter::complete`.
    pub num_completions: usize,
    /// Failure counter.
    pub num_failures: usize,
}

/// Counters kept by [`MacroSolver`].
#[derive(Debug, Clone, Default)]
pub struct MacroStats {
    /// Number of macro definitions recognized by `MacroSolver::solve_macros`.
    pub num_macros_found: usize,
    /// Applied-macro counter.
    pub num_macros_applied: usize,
}

/// Counters kept by [`ModelFixer`].
#[derive(Debug, Clone, Default)]
pub struct FixerStats {
    /// Number of calls to `ModelFixer::fix_model`.
    pub num_fixes: usize,
    /// Projection counter.
    pub num_projections_created: usize,
}

/// Counters kept by [`UninterpretedSortHandler`].
#[derive(Debug, Clone, Default)]
pub struct UninterpStats {
    /// Number of universes inserted by `UninterpretedSortHandler::complete_universes`.
    pub num_universes_created: usize,
    /// Accumulated universe-size counter.
    pub total_universe_size: usize,
}
1389
#[cfg(test)]
mod tests {
    use super::*;
    use oxiz_core::interner::Key;

    /// Interned name shared by the function-interpretation tests.
    fn test_spur() -> Spur {
        Spur::try_from_usize(1).expect("valid spur")
    }

    #[test]
    fn test_completed_model_creation() {
        let empty = CompletedModel::new();
        assert!(empty.assignments.is_empty());
        assert!(empty.function_interps.is_empty());
    }

    #[test]
    fn test_completed_model_eval() {
        let (term, value) = (TermId::new(1), TermId::new(2));
        let mut model = CompletedModel::new();

        model.set(term, value);

        assert_eq!(model.eval(term), Some(value));
        assert_eq!(model.eval(TermId::new(99)), None);
    }

    #[test]
    fn test_function_interpretation_lookup() {
        let domain: SmallVec<[SortId; 4]> =
            SmallVec::from_slice(&[SortId::new(1), SortId::new(1)]);
        let mut interp = FunctionInterpretation::new(test_spur(), domain, SortId::new(1));

        let args = vec![TermId::new(1), TermId::new(2)];
        let expected = TermId::new(10);
        interp.add_entry(args.clone(), expected);

        assert_eq!(interp.lookup(&args), Some(expected));
        assert_eq!(interp.lookup(&[TermId::new(99)]), None);
    }

    #[test]
    fn test_function_interpretation_else_value() {
        let mut interp =
            FunctionInterpretation::new(test_spur(), SmallVec::new(), SortId::new(1));

        let fallback = TermId::new(42);
        interp.else_value = Some(fallback);

        assert_eq!(interp.lookup(&[TermId::new(99)]), Some(fallback));
    }

    #[test]
    fn test_function_interpretation_max_occurrence() {
        let domain: SmallVec<[SortId; 4]> = SmallVec::from_slice(&[SortId::new(1)]);
        let mut interp = FunctionInterpretation::new(test_spur(), domain, SortId::new(1));

        let frequent = TermId::new(10);
        let rare = TermId::new(20);
        for (arg, result) in [(1, frequent), (2, frequent), (3, rare)] {
            interp.add_entry(vec![TermId::new(arg)], result);
        }

        assert_eq!(interp.max_occurrence_result(), Some(frequent));
    }

    #[test]
    fn test_projection_function_def() {
        let mut proj = ProjectionFunctionDef::new(0, SortId::new(1));
        let (value, term) = (TermId::new(1), TermId::new(10));

        proj.add_value(value, term);

        assert_eq!(proj.project(value), Some(term));
        assert_eq!(proj.values.len(), 1);
    }

    #[test]
    fn test_model_completer_creation() {
        let completer = ModelCompleter::new();
        assert_eq!(completer.stats().num_completions, 0);
    }

    #[test]
    fn test_macro_solver_creation() {
        let solver = MacroSolver::new();
        assert_eq!(solver.stats().num_macros_found, 0);
    }

    #[test]
    fn test_model_fixer_creation() {
        let fixer = ModelFixer::new();
        assert_eq!(fixer.stats().num_fixes, 0);
    }

    #[test]
    fn test_uninterpreted_sort_handler_creation() {
        assert_eq!(UninterpretedSortHandler::new().max_universe_size, 8);
    }

    #[test]
    fn test_uninterpreted_sort_handler_custom_size() {
        assert_eq!(
            UninterpretedSortHandler::with_max_size(16).max_universe_size,
            16
        );
    }

    #[test]
    fn test_arithmetic_projection() {
        assert!(ArithmeticProjection::new(true).is_int);
    }

    #[test]
    fn test_completion_error_display() {
        let rendered = CompletionError::CompletionFailed("test".to_string()).to_string();
        assert!(rendered.contains("test"));
    }
}