1#![allow(missing_docs)]
21
22#[allow(unused_imports)]
23use crate::prelude::*;
24use core::cmp::Ordering;
25use core::fmt;
26use num_bigint::BigInt;
27use num_rational::Rational64;
28use num_traits::ToPrimitive;
29use oxiz_core::ast::{TermId, TermKind, TermManager};
30use oxiz_core::interner::Spur;
31use oxiz_core::sort::SortId;
32use smallvec::SmallVec;
33
/// One observed function application gathered by
/// `ModelCompleter::extract_function_interpretations`:
/// (argument sorts, result sort, evaluated argument values, assigned result value).
type FuncEntry = (SmallVec<[SortId; 4]>, SortId, Vec<TermId>, TermId);
36
37use super::QuantifiedFormula;
38
/// Maximum number of values `CompletedModel::collect_universes_from_model`
/// keeps in the universe it builds for a single sort; extra candidates are truncated.
const MAX_UNIVERSE_SIZE: usize = 1000;
41
/// A model produced by model completion: concrete term assignments together
/// with function interpretations, finite universes and per-sort default values.
#[derive(Debug, Clone)]
pub struct CompletedModel {
    /// Value assigned to each term (term -> value term).
    pub assignments: FxHashMap<TermId, TermId>,
    /// Interpretation of each function symbol, keyed by its interned name.
    pub function_interps: FxHashMap<Spur, FunctionInterpretation>,
    /// Finite universe (list of value terms) recorded for a sort.
    pub universes: FxHashMap<SortId, Vec<TermId>>,
    /// Fallback value per sort; `eval_apply` uses it as the last resort for a
    /// function's range sort.
    pub defaults: FxHashMap<SortId, TermId>,
    /// Generation counter; `new` initialises it to 0.
    pub generation: u32,
}
56
57impl CompletedModel {
58 pub fn new() -> Self {
60 Self {
61 assignments: FxHashMap::default(),
62 function_interps: FxHashMap::default(),
63 universes: FxHashMap::default(),
64 defaults: FxHashMap::default(),
65 generation: 0,
66 }
67 }
68
69 pub fn eval(&self, term: TermId) -> Option<TermId> {
71 self.assignments.get(&term).copied()
72 }
73
74 pub fn set(&mut self, term: TermId, value: TermId) {
76 self.assignments.insert(term, value);
77 }
78
79 pub fn universe(&self, sort: SortId) -> Option<&[TermId]> {
81 self.universes.get(&sort).map(|v| v.as_slice())
82 }
83
84 pub fn add_to_universe(&mut self, sort: SortId, value: TermId) {
86 self.universes.entry(sort).or_default().push(value);
87 }
88
89 pub fn default_value(&self, sort: SortId) -> Option<TermId> {
91 self.defaults.get(&sort).copied()
92 }
93
94 pub fn set_default(&mut self, sort: SortId, value: TermId) {
96 self.defaults.insert(sort, value);
97 }
98
99 pub fn has_uninterpreted_sort(&self, sort: SortId) -> bool {
101 self.universes.contains_key(&sort)
102 }
103
104 pub fn eval_apply(&self, func: Spur, evaluated_args: &[TermId]) -> Option<TermId> {
109 if let Some(interp) = self.function_interps.get(&func) {
110 if let Some(result) = interp.lookup(evaluated_args) {
112 return Some(result);
113 }
114 if let Some(else_val) = interp.else_value {
116 return Some(else_val);
117 }
118 if let Some(default) = self.defaults.get(&interp.range) {
120 return Some(*default);
121 }
122 }
123 None
124 }
125
126 pub fn collect_universes_from_model(
133 &mut self,
134 quantifiers: &[QuantifiedFormula],
135 manager: &TermManager,
136 ) {
137 let mut needed_sorts: FxHashSet<SortId> = FxHashSet::default();
139 for quant in quantifiers {
140 for &(_name, sort) in &quant.bound_vars {
141 needed_sorts.insert(sort);
142 }
143 }
144
145 for sort in needed_sorts {
147 if self.universes.contains_key(&sort) {
149 continue;
150 }
151
152 let mut universe_values: Vec<TermId> = Vec::new();
153 let mut seen: FxHashSet<TermId> = FxHashSet::default();
154
155 for (&term, &value) in &self.assignments {
157 if let Some(t) = manager.get(term) {
159 if t.sort == sort && seen.insert(value) {
160 universe_values.push(value);
161 }
162 }
163 if let Some(v) = manager.get(value) {
165 if v.sort == sort && seen.insert(value) {
166 universe_values.push(value);
167 }
168 }
169 }
170
171 for interp in self.function_interps.values() {
173 for entry in &interp.entries {
174 for (i, &arg) in entry.args.iter().enumerate() {
176 if i < interp.domain.len() && interp.domain[i] == sort && seen.insert(arg) {
177 universe_values.push(arg);
178 }
179 }
180 if interp.range == sort && seen.insert(entry.result) {
182 universe_values.push(entry.result);
183 }
184 }
185 }
186
187 universe_values.truncate(MAX_UNIVERSE_SIZE);
189
190 if !universe_values.is_empty() {
191 self.universes.insert(sort, universe_values);
192 }
193 }
194 }
195
196 pub fn complete_function_interpretations(&mut self) {
211 let updates: Vec<(Spur, TermId)> = self
213 .function_interps
214 .iter()
215 .filter_map(|(&name, interp)| {
216 if interp.else_value.is_some() {
217 return None;
218 }
219 if !interp.entries.is_empty() {
223 return None;
224 }
225 if let Some(&default) = self.defaults.get(&interp.range) {
227 return Some((name, default));
228 }
229 None
230 })
231 .collect();
232
233 for (name, else_val) in updates {
234 if let Some(interp) = self.function_interps.get_mut(&name) {
235 interp.else_value = Some(else_val);
236 }
237 }
238 }
239}
240
241impl Default for CompletedModel {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
/// Finite interpretation of a function symbol.
///
/// It consists of explicit argument -> result entries, an optional else value
/// and optional per-argument projections.
#[derive(Debug, Clone)]
pub struct FunctionInterpretation {
    /// Interned function name.
    pub name: Spur,
    /// Number of arguments (`domain.len()` when built with `new`).
    pub arity: usize,
    /// Sort of each argument position.
    pub domain: SmallVec<[SortId; 4]>,
    /// Result sort.
    pub range: SortId,
    /// Explicit points of the function, matched by exact argument equality.
    pub entries: Vec<FunctionEntry>,
    /// Result for arguments not covered by `entries`, if any.
    pub else_value: Option<TermId>,
    /// Optional projection per argument position (one `None` slot per
    /// argument when built with `new`).
    pub projections: Vec<Option<ProjectionFunctionDef>>,
}
265
266impl FunctionInterpretation {
267 pub fn new(name: Spur, domain: SmallVec<[SortId; 4]>, range: SortId) -> Self {
269 let arity = domain.len();
270 Self {
271 name,
272 arity,
273 domain,
274 range,
275 entries: Vec::new(),
276 else_value: None,
277 projections: vec![None; arity],
278 }
279 }
280
281 pub fn add_entry(&mut self, args: Vec<TermId>, result: TermId) {
283 if args.len() == self.arity {
284 self.entries.push(FunctionEntry { args, result });
285 }
286 }
287
288 pub fn lookup(&self, args: &[TermId]) -> Option<TermId> {
290 for entry in &self.entries {
291 if entry.args == args {
292 return Some(entry.result);
293 }
294 }
295 self.else_value
296 }
297
298 pub fn is_constant(&self) -> bool {
300 self.arity == 0
301 }
302
303 pub fn is_partial(&self) -> bool {
305 self.else_value.is_none() && !self.entries.is_empty()
306 }
307
308 pub fn max_occurrence_result(&self) -> Option<TermId> {
310 if self.entries.is_empty() {
311 return None;
312 }
313
314 let mut counts: FxHashMap<TermId, usize> = FxHashMap::default();
315 for entry in &self.entries {
316 *counts.entry(entry.result).or_insert(0) += 1;
317 }
318
319 counts
320 .into_iter()
321 .max_by_key(|(_, count)| *count)
322 .map(|(term, _)| term)
323 }
324}
325
/// A single point `args -> result` of a function interpretation.
#[derive(Debug, Clone)]
pub struct FunctionEntry {
    /// Argument values, one per argument position.
    pub args: Vec<TermId>,
    /// Result value for these arguments.
    pub result: TermId,
}
334
/// Projection data for one argument position of a function interpretation.
///
/// It holds the distinct values seen at that position and a two-way mapping
/// between those values and the argument terms they came from.
#[derive(Debug, Clone)]
pub struct ProjectionFunctionDef {
    /// Index of the projected argument position.
    pub arg_index: usize,
    /// Sort of the projected argument.
    pub sort: SortId,
    /// Distinct values seen at this position, in insertion order unless
    /// sorted afterwards.
    pub values: Vec<TermId>,
    /// Value -> most recently registered term with that value.
    pub value_to_term: FxHashMap<TermId, TermId>,
    /// Term -> its value.
    pub term_to_value: FxHashMap<TermId, TermId>,
}
349
350impl ProjectionFunctionDef {
351 pub fn new(arg_index: usize, sort: SortId) -> Self {
353 Self {
354 arg_index,
355 sort,
356 values: Vec::new(),
357 value_to_term: FxHashMap::default(),
358 term_to_value: FxHashMap::default(),
359 }
360 }
361
362 pub fn add_value(&mut self, value: TermId, term: TermId) {
364 if !self.values.contains(&value) {
365 self.values.push(value);
366 }
367 self.value_to_term.insert(value, term);
368 self.term_to_value.insert(term, value);
369 }
370
371 pub fn project(&self, value: TermId) -> Option<TermId> {
373 self.value_to_term.get(&value).copied()
374 }
375}
376
/// Drives model completion.
///
/// It combines macro solving, model fixing, universe construction for
/// uninterpreted sorts and default-value assignment.
#[derive(Debug)]
pub struct ModelCompleter {
    /// Finds macro definitions among quantifiers.
    macro_solver: MacroSolver,
    /// Adjusts function interpretations for quantified formulas.
    model_fixer: ModelFixer,
    /// Creates finite universes for uninterpreted sorts.
    uninterp_handler: UninterpretedSortHandler,
    /// Cache of completed models; only initialised (empty) in the visible code.
    cache: FxHashMap<u64, CompletedModel>,
    /// Completion counters.
    stats: CompletionStats,
}
391
392impl ModelCompleter {
393 pub fn new() -> Self {
395 Self {
396 macro_solver: MacroSolver::new(),
397 model_fixer: ModelFixer::new(),
398 uninterp_handler: UninterpretedSortHandler::new(),
399 cache: FxHashMap::default(),
400 stats: CompletionStats::default(),
401 }
402 }
403
404 pub fn complete(
416 &mut self,
417 partial_model: &FxHashMap<TermId, TermId>,
418 quantifiers: &[QuantifiedFormula],
419 manager: &mut TermManager,
420 ) -> Result<CompletedModel, CompletionError> {
421 self.stats.num_completions += 1;
422
423 let mut completed = CompletedModel::new();
425 completed.assignments = partial_model.clone();
426
427 self.extract_function_interpretations(&mut completed, manager);
429
430 let macro_results = self.macro_solver.solve_macros(quantifiers, manager)?;
435 for (func_name, macro_interp) in macro_results {
436 completed
438 .function_interps
439 .entry(func_name)
440 .or_insert(macro_interp);
441 }
444
445 self.model_fixer
447 .fix_model(&mut completed, quantifiers, manager)?;
448
449 self.uninterp_handler
451 .complete_universes(&mut completed, manager)?;
452
453 self.set_default_values(&mut completed, manager)?;
455
456 completed.collect_universes_from_model(quantifiers, manager);
460
461 self.set_default_values(&mut completed, manager)?;
463
464 completed.complete_function_interpretations();
466
467 Ok(completed)
468 }
469
470 fn eval_to_const(term: TermId, manager: &mut TermManager) -> Option<TermId> {
477 let t = manager.get(term)?.clone();
478 match &t.kind {
479 TermKind::IntConst(_) | TermKind::RealConst(_) => Some(term),
480 TermKind::Neg(arg) => {
481 let inner = Self::eval_to_const(*arg, manager)?;
482 let inner_t = manager.get(inner)?.clone();
483 match &inner_t.kind {
484 TermKind::IntConst(n) => {
485 let neg_n = -n.clone();
486 Some(manager.mk_int(neg_n))
487 }
488 TermKind::RealConst(r) => {
489 let neg_r = -*r;
490 Some(manager.mk_real(neg_r))
491 }
492 _ => None,
493 }
494 }
495 TermKind::Add(args) => {
496 let args_cloned: SmallVec<[TermId; 4]> = args.clone();
497 let mut sum_r = Rational64::from_integer(0);
499 let mut all_real = true;
500 let mut all_int = true;
501 let mut sum_i = num_bigint::BigInt::from(0i64);
502 for &arg in &args_cloned {
503 if let Some(c) = Self::eval_to_const(arg, manager) {
504 let ct = manager.get(c)?.clone();
505 match &ct.kind {
506 TermKind::RealConst(r) => {
507 sum_r += r;
508 all_int = false;
509 }
510 TermKind::IntConst(n) => {
511 sum_r += Rational64::from_integer(n.to_i64().unwrap_or(0));
512 sum_i += n.clone();
513 all_real = false;
514 }
515 _ => {
516 all_real = false;
517 all_int = false;
518 }
519 }
520 } else {
521 all_real = false;
522 all_int = false;
523 }
524 }
525 if all_int && !args_cloned.is_empty() {
526 Some(manager.mk_int(sum_i))
527 } else if all_real && !args_cloned.is_empty() {
528 Some(manager.mk_real(sum_r))
529 } else {
530 None
531 }
532 }
533 TermKind::Sub(lhs, rhs) => {
534 let (lhs_v, rhs_v) = (*lhs, *rhs);
535 let lc = Self::eval_to_const(lhs_v, manager)?;
536 let rc = Self::eval_to_const(rhs_v, manager)?;
537 let lct = manager.get(lc)?.clone();
538 let rct = manager.get(rc)?.clone();
539 match (&lct.kind, &rct.kind) {
540 (TermKind::IntConst(a), TermKind::IntConst(b)) => Some(manager.mk_int(a - b)),
541 (TermKind::RealConst(a), TermKind::RealConst(b)) => {
542 Some(manager.mk_real(a - b))
543 }
544 _ => None,
545 }
546 }
547 _ => None,
548 }
549 }
550
551 fn eval_arg(term: TermId, model: &CompletedModel, manager: &mut TermManager) -> TermId {
554 if let Some(val) = model.eval(term) {
556 return val;
557 }
558 if let Some(const_val) = Self::eval_to_const(term, manager) {
560 return const_val;
561 }
562 term
563 }
564
565 fn extract_function_interpretations(
573 &self,
574 model: &mut CompletedModel,
575 manager: &mut TermManager,
576 ) {
577 let mut func_entries: FxHashMap<Spur, Vec<FuncEntry>> = FxHashMap::default();
579
580 let apply_entries: Vec<(TermId, TermId)> = model
582 .assignments
583 .iter()
584 .filter_map(|(&term, &value)| {
585 if manager
586 .get(term)
587 .is_some_and(|t| matches!(t.kind, TermKind::Apply { .. }))
588 {
589 Some((term, value))
590 } else {
591 None
592 }
593 })
594 .collect();
595
596 for (term, value) in apply_entries {
597 let Some(t) = manager.get(term).cloned() else {
598 continue;
599 };
600 if let TermKind::Apply { func, args } = &t.kind {
601 let args_cloned: SmallVec<[TermId; 4]> = args.clone();
606 let evaluated_args: Vec<TermId> = args_cloned
607 .iter()
608 .map(|&arg| Self::eval_arg(arg, model, manager))
609 .collect();
610
611 let domain: SmallVec<[SortId; 4]> = args
612 .iter()
613 .map(|&arg| manager.get(arg).map_or(manager.sorts.int_sort, |a| a.sort))
614 .collect();
615
616 func_entries.entry(*func).or_default().push((
617 domain,
618 t.sort,
619 evaluated_args,
620 value,
621 ));
622 }
623 }
624
625 for (func_name, entries) in func_entries {
627 match model.function_interps.entry(func_name) {
628 std::collections::hash_map::Entry::Occupied(mut occupied) => {
629 let interp = occupied.get_mut();
631 for (_domain, _range, args, result) in entries {
632 let already_exists = interp.entries.iter().any(|e| e.args == args);
633 if !already_exists {
634 interp.add_entry(args, result);
635 }
636 }
637 }
638 std::collections::hash_map::Entry::Vacant(vacant) => {
639 if let Some((domain, range, first_args, first_result)) = entries.first() {
640 let mut interp =
642 FunctionInterpretation::new(func_name, domain.clone(), *range);
643 interp.add_entry(first_args.clone(), *first_result);
644 for (_, _, args, result) in entries.iter().skip(1) {
645 let already_exists = interp.entries.iter().any(|e| &e.args == args);
646 if !already_exists {
647 interp.add_entry(args.clone(), *result);
648 }
649 }
650 vacant.insert(interp);
651 }
652 }
653 }
654 }
655 }
656
657 fn set_default_values(
659 &mut self,
660 model: &mut CompletedModel,
661 manager: &mut TermManager,
662 ) -> Result<(), CompletionError> {
663 if !model.defaults.contains_key(&manager.sorts.bool_sort) {
665 model.set_default(manager.sorts.bool_sort, manager.mk_false());
666 }
667
668 if !model.defaults.contains_key(&manager.sorts.int_sort) {
670 model.set_default(manager.sorts.int_sort, manager.mk_int(BigInt::from(0)));
671 }
672
673 if !model.defaults.contains_key(&manager.sorts.real_sort) {
675 model.set_default(
676 manager.sorts.real_sort,
677 manager.mk_real(Rational64::from_integer(0)),
678 );
679 }
680
681 let defaults_to_set: Vec<(SortId, TermId)> = model
684 .universes
685 .iter()
686 .filter_map(|(sort, universe)| {
687 if !model.defaults.contains_key(sort) {
688 universe.first().map(|&first| (*sort, first))
689 } else {
690 None
691 }
692 })
693 .collect();
694
695 for (sort, value) in defaults_to_set {
696 model.set_default(sort, value);
697 }
698
699 Ok(())
700 }
701
702 pub fn stats(&self) -> &CompletionStats {
704 &self.stats
705 }
706}
707
708impl Default for ModelCompleter {
709 fn default() -> Self {
710 Self::new()
711 }
712}
713
/// Detects quantified formulas that define a function as a macro, i.e.
/// `forall xs. f(xs) = body` where `body` does not mention `f`.
#[derive(Debug)]
pub struct MacroSolver {
    /// Macros found so far, keyed by function name (filled by `solve_macros`).
    macros: FxHashMap<Spur, MacroDefinition>,
    /// Macro-solving counters.
    stats: MacroStats,
}
726
727impl MacroSolver {
728 pub fn new() -> Self {
730 Self {
731 macros: FxHashMap::default(),
732 stats: MacroStats::default(),
733 }
734 }
735
736 pub fn solve_macros(
738 &mut self,
739 quantifiers: &[QuantifiedFormula],
740 manager: &mut TermManager,
741 ) -> Result<FxHashMap<Spur, FunctionInterpretation>, CompletionError> {
742 let mut results = FxHashMap::default();
743
744 for quant in quantifiers {
745 if let Some(macro_def) = self.try_extract_macro(quant, manager)? {
746 self.stats.num_macros_found += 1;
747 let interp = self.macro_to_interpretation(¯o_def, manager)?;
748 results.insert(macro_def.func_name, interp);
749 self.macros.insert(macro_def.func_name, macro_def);
750 }
751 }
752
753 Ok(results)
754 }
755
756 fn try_extract_macro(
758 &self,
759 quant: &QuantifiedFormula,
760 manager: &TermManager,
761 ) -> Result<Option<MacroDefinition>, CompletionError> {
762 let Some(body_term) = manager.get(quant.body) else {
764 return Ok(None);
765 };
766
767 if let TermKind::Eq(lhs, rhs) = &body_term.kind {
769 if let Some(macro_def) = self.try_extract_macro_from_eq(*lhs, *rhs, quant, manager)? {
771 return Ok(Some(macro_def));
772 }
773 if let Some(macro_def) = self.try_extract_macro_from_eq(*rhs, *lhs, quant, manager)? {
774 return Ok(Some(macro_def));
775 }
776 }
777
778 Ok(None)
779 }
780
781 fn try_extract_macro_from_eq(
783 &self,
784 lhs: TermId,
785 rhs: TermId,
786 quant: &QuantifiedFormula,
787 manager: &TermManager,
788 ) -> Result<Option<MacroDefinition>, CompletionError> {
789 let Some(lhs_term) = manager.get(lhs) else {
790 return Ok(None);
791 };
792
793 if let TermKind::Apply { func, args } = &lhs_term.kind {
795 let mut is_macro = true;
797 for &arg in args.iter() {
798 if let Some(arg_term) = manager.get(arg)
799 && !matches!(arg_term.kind, TermKind::Var(_))
800 {
801 is_macro = false;
802 break;
803 }
804 }
805
806 if is_macro {
807 if !self.contains_function(rhs, *func, manager) {
809 return Ok(Some(MacroDefinition {
810 quantifier: quant.term,
811 func_name: *func,
812 bound_vars: quant.bound_vars.clone(),
813 body: rhs,
814 }));
815 }
816 }
817 }
818
819 Ok(None)
820 }
821
822 fn contains_function(&self, term: TermId, func: Spur, manager: &TermManager) -> bool {
824 let mut visited = FxHashSet::default();
825 self.contains_function_rec(term, func, manager, &mut visited)
826 }
827
828 fn contains_function_rec(
829 &self,
830 term: TermId,
831 func: Spur,
832 manager: &TermManager,
833 visited: &mut FxHashSet<TermId>,
834 ) -> bool {
835 if visited.contains(&term) {
836 return false;
837 }
838 visited.insert(term);
839
840 let Some(t) = manager.get(term) else {
841 return false;
842 };
843
844 match &t.kind {
845 TermKind::Apply { func: f, args } => {
846 if *f == func {
847 return true;
848 }
849 for &arg in args.iter() {
850 if self.contains_function_rec(arg, func, manager, visited) {
851 return true;
852 }
853 }
854 false
855 }
856 _ => {
857 let children = self.get_children(term, manager);
859 for child in children {
860 if self.contains_function_rec(child, func, manager, visited) {
861 return true;
862 }
863 }
864 false
865 }
866 }
867 }
868
869 fn get_children(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
871 let Some(t) = manager.get(term) else {
872 return vec![];
873 };
874
875 match &t.kind {
876 TermKind::Not(arg) | TermKind::Neg(arg) => vec![*arg],
877 TermKind::And(args)
878 | TermKind::Or(args)
879 | TermKind::Add(args)
880 | TermKind::Mul(args) => args.to_vec(),
881 TermKind::Sub(lhs, rhs)
882 | TermKind::Div(lhs, rhs)
883 | TermKind::Mod(lhs, rhs)
884 | TermKind::Eq(lhs, rhs)
885 | TermKind::Lt(lhs, rhs)
886 | TermKind::Le(lhs, rhs)
887 | TermKind::Gt(lhs, rhs)
888 | TermKind::Ge(lhs, rhs)
889 | TermKind::Implies(lhs, rhs) => vec![*lhs, *rhs],
890 TermKind::Ite(cond, then_br, else_br) => vec![*cond, *then_br, *else_br],
891 TermKind::Apply { args, .. } => args.to_vec(),
892 _ => vec![],
893 }
894 }
895
896 fn macro_to_interpretation(
904 &self,
905 macro_def: &MacroDefinition,
906 manager: &mut TermManager,
907 ) -> Result<FunctionInterpretation, CompletionError> {
908 let func_name = macro_def.func_name;
909
910 let domain: SmallVec<[SortId; 4]> =
912 macro_def.bound_vars.iter().map(|&(_, sort)| sort).collect();
913
914 let range = manager
916 .get(macro_def.body)
917 .map_or(manager.sorts.bool_sort, |t| t.sort);
918
919 let interp = FunctionInterpretation::new(func_name, domain, range);
920 Ok(interp)
921 }
922
923 pub fn stats(&self) -> &MacroStats {
925 &self.stats
926 }
927}
928
929impl Default for MacroSolver {
930 fn default() -> Self {
931 Self::new()
932 }
933}
934
/// A macro found by `MacroSolver`.
///
/// The function `func_name` is defined as `body` over the quantifier's bound
/// variables.
#[derive(Debug, Clone)]
pub struct MacroDefinition {
    /// The quantified formula term the macro was extracted from.
    pub quantifier: TermId,
    /// Interned name of the defined function.
    pub func_name: Spur,
    /// Bound variables of the quantifier as (name, sort) pairs.
    pub bound_vars: SmallVec<[(Spur, SortId); 4]>,
    /// Right-hand side of the defining equality.
    pub body: TermId,
}
947
/// Adjusts function interpretations for functions applied to quantified
/// variables, including argument projections for arithmetic sorts.
#[derive(Debug)]
pub struct ModelFixer {
    /// Per-sort projection functions; not populated by the visible code.
    projections: FxHashMap<SortId, Box<dyn ProjectionFunction>>,
    /// Model-fixing counters.
    stats: FixerStats,
}
956
957impl ModelFixer {
958 pub fn new() -> Self {
960 Self {
961 projections: FxHashMap::default(),
962 stats: FixerStats::default(),
963 }
964 }
965
966 pub fn fix_model(
968 &mut self,
969 model: &mut CompletedModel,
970 quantifiers: &[QuantifiedFormula],
971 manager: &mut TermManager,
972 ) -> Result<(), CompletionError> {
973 self.stats.num_fixes += 1;
974
975 let partial_functions = self.collect_partial_functions(quantifiers, manager);
977
978 for func_name in partial_functions.iter() {
981 let has_interp = model.function_interps.contains_key(func_name);
983 if has_interp {
984 if let Some(interp) = model.function_interps.get_mut(func_name) {
986 for arg_idx in 0..interp.arity {
989 let sort = interp.domain[arg_idx];
990 if self.needs_projection(sort, manager) {
991 interp.projections[arg_idx] = None;
993 }
994 }
995 }
996 }
997 }
998
999 Ok(())
1018 }
1019
1020 fn collect_partial_functions(
1022 &self,
1023 quantifiers: &[QuantifiedFormula],
1024 manager: &TermManager,
1025 ) -> FxHashSet<Spur> {
1026 let mut functions = FxHashSet::default();
1027
1028 for quant in quantifiers {
1029 self.collect_partial_functions_rec(quant.body, &mut functions, manager);
1030 }
1031
1032 functions
1033 }
1034
1035 fn collect_partial_functions_rec(
1036 &self,
1037 term: TermId,
1038 functions: &mut FxHashSet<Spur>,
1039 manager: &TermManager,
1040 ) {
1041 let Some(t) = manager.get(term) else {
1042 return;
1043 };
1044
1045 if let TermKind::Apply { func, args } = &t.kind {
1046 let has_vars = args.iter().any(|&arg| {
1048 if let Some(arg_t) = manager.get(arg) {
1049 matches!(arg_t.kind, TermKind::Var(_))
1050 } else {
1051 false
1052 }
1053 });
1054
1055 if has_vars {
1056 functions.insert(*func);
1057 }
1058
1059 for &arg in args.iter() {
1061 self.collect_partial_functions_rec(arg, functions, manager);
1062 }
1063 }
1064
1065 match &t.kind {
1067 TermKind::Not(arg) | TermKind::Neg(arg) => {
1068 self.collect_partial_functions_rec(*arg, functions, manager);
1069 }
1070 TermKind::And(args) | TermKind::Or(args) => {
1071 for &arg in args.iter() {
1072 self.collect_partial_functions_rec(arg, functions, manager);
1073 }
1074 }
1075 TermKind::Eq(lhs, rhs) | TermKind::Lt(lhs, rhs) | TermKind::Le(lhs, rhs) => {
1076 self.collect_partial_functions_rec(*lhs, functions, manager);
1077 self.collect_partial_functions_rec(*rhs, functions, manager);
1078 }
1079 _ => {}
1080 }
1081 }
1082
1083 fn add_projection_functions(
1085 &mut self,
1086 interp: &mut FunctionInterpretation,
1087 model: &CompletedModel,
1088 manager: &mut TermManager,
1089 ) -> Result<(), CompletionError> {
1090 for arg_idx in 0..interp.arity {
1092 let sort = interp.domain[arg_idx];
1093
1094 if self.needs_projection(sort, manager) {
1096 let proj_def = self.create_projection(interp, arg_idx, model, manager)?;
1097 interp.projections[arg_idx] = Some(proj_def);
1098 }
1099 }
1100
1101 Ok(())
1102 }
1103
1104 fn needs_projection(&self, sort: SortId, manager: &TermManager) -> bool {
1106 sort == manager.sorts.int_sort || sort == manager.sorts.real_sort
1108 }
1109
1110 fn create_projection(
1112 &mut self,
1113 interp: &FunctionInterpretation,
1114 arg_idx: usize,
1115 model: &CompletedModel,
1116 manager: &mut TermManager,
1117 ) -> Result<ProjectionFunctionDef, CompletionError> {
1118 let sort = interp.domain[arg_idx];
1119 let mut proj_def = ProjectionFunctionDef::new(arg_idx, sort);
1120
1121 for entry in &interp.entries {
1123 if let Some(&arg_term) = entry.args.get(arg_idx) {
1124 let value = model.eval(arg_term).unwrap_or(arg_term);
1126 proj_def.add_value(value, arg_term);
1127 }
1128 }
1129
1130 proj_def
1132 .values
1133 .sort_by(|a, b| self.compare_values(*a, *b, sort, manager));
1134
1135 Ok(proj_def)
1136 }
1137
1138 fn compare_values(
1140 &self,
1141 a: TermId,
1142 b: TermId,
1143 _sort: SortId,
1144 manager: &TermManager,
1145 ) -> Ordering {
1146 let a_term = manager.get(a);
1147 let b_term = manager.get(b);
1148
1149 if let (Some(at), Some(bt)) = (a_term, b_term) {
1150 if let (TermKind::IntConst(av), TermKind::IntConst(bv)) = (&at.kind, &bt.kind) {
1152 return av.cmp(bv);
1153 }
1154
1155 if let (TermKind::RealConst(av), TermKind::RealConst(bv)) = (&at.kind, &bt.kind) {
1157 return av.cmp(bv);
1158 }
1159
1160 match (&at.kind, &bt.kind) {
1162 (TermKind::False, TermKind::True) => return Ordering::Less,
1163 (TermKind::True, TermKind::False) => return Ordering::Greater,
1164 (TermKind::False, TermKind::False) | (TermKind::True, TermKind::True) => {
1165 return Ordering::Equal;
1166 }
1167 _ => {}
1168 }
1169 }
1170
1171 a.0.cmp(&b.0)
1173 }
1174
1175 pub fn stats(&self) -> &FixerStats {
1177 &self.stats
1178 }
1179}
1180
1181impl Default for ModelFixer {
1182 fn default() -> Self {
1183 Self::new()
1184 }
1185}
1186
/// Ordering support for the values of a sort, used for argument projections.
pub trait ProjectionFunction: fmt::Debug + Send + Sync {
    /// Returns `true` when `a` is ordered strictly before `b`.
    fn compare(&self, a: TermId, b: TermId, manager: &TermManager) -> bool;

    /// Builds a term stating that `x` is less than `y`.
    /// (`ArithmeticProjection` implements this with `TermManager::mk_lt`.)
    fn mk_lt(&self, x: TermId, y: TermId, manager: &mut TermManager) -> TermId;
}
1195
/// Projection for arithmetic sorts.
///
/// Integer and real constants are compared by value.
#[derive(Debug)]
pub struct ArithmeticProjection {
    /// Whether this projection is for integers rather than reals.
    /// NOTE(review): not consulted by `compare` or `mk_lt`.
    is_int: bool,
}
1202
1203impl ArithmeticProjection {
1204 pub fn new(is_int: bool) -> Self {
1205 Self { is_int }
1206 }
1207}
1208
1209impl ProjectionFunction for ArithmeticProjection {
1210 fn compare(&self, a: TermId, b: TermId, manager: &TermManager) -> bool {
1211 let a_term = manager.get(a);
1212 let b_term = manager.get(b);
1213
1214 if let (Some(at), Some(bt)) = (a_term, b_term) {
1215 if let (TermKind::IntConst(av), TermKind::IntConst(bv)) = (&at.kind, &bt.kind) {
1216 return av < bv;
1217 }
1218 if let (TermKind::RealConst(av), TermKind::RealConst(bv)) = (&at.kind, &bt.kind) {
1219 return av < bv;
1220 }
1221 }
1222
1223 a.0 < b.0
1224 }
1225
1226 fn mk_lt(&self, x: TermId, y: TermId, manager: &mut TermManager) -> TermId {
1227 manager.mk_lt(x, y)
1228 }
1229}
1230
/// Creates finite universes for sorts that are neither Bool, Int nor Real.
#[derive(Debug)]
pub struct UninterpretedSortHandler {
    /// Number of elements created per universe (8 when built with `new`).
    max_universe_size: usize,
    /// Universe-creation counters.
    stats: UninterpStats,
}
1239
1240impl UninterpretedSortHandler {
1241 pub fn new() -> Self {
1243 Self {
1244 max_universe_size: 8,
1245 stats: UninterpStats::default(),
1246 }
1247 }
1248
1249 pub fn with_max_size(max_size: usize) -> Self {
1251 let mut handler = Self::new();
1252 handler.max_universe_size = max_size;
1253 handler
1254 }
1255
1256 pub fn complete_universes(
1258 &mut self,
1259 model: &mut CompletedModel,
1260 manager: &mut TermManager,
1261 ) -> Result<(), CompletionError> {
1262 let uninterp_sorts = self.identify_uninterpreted_sorts(model, manager);
1264
1265 for sort in uninterp_sorts {
1266 if let crate::prelude::hash_map::Entry::Vacant(e) = model.universes.entry(sort) {
1267 let universe = self.create_finite_universe(sort, manager)?;
1269 e.insert(universe);
1270 self.stats.num_universes_created += 1;
1271 }
1272 }
1273
1274 Ok(())
1275 }
1276
1277 fn identify_uninterpreted_sorts(
1279 &self,
1280 model: &CompletedModel,
1281 manager: &TermManager,
1282 ) -> Vec<SortId> {
1283 let mut sorts = Vec::new();
1284
1285 for interp in model.function_interps.values() {
1287 for &sort in &interp.domain {
1288 if self.is_uninterpreted(sort, manager) && !sorts.contains(&sort) {
1289 sorts.push(sort);
1290 }
1291 }
1292 if self.is_uninterpreted(interp.range, manager) && !sorts.contains(&interp.range) {
1293 sorts.push(interp.range);
1294 }
1295 }
1296
1297 sorts
1298 }
1299
1300 fn is_uninterpreted(&self, sort: SortId, manager: &TermManager) -> bool {
1302 sort != manager.sorts.bool_sort
1304 && sort != manager.sorts.int_sort
1305 && sort != manager.sorts.real_sort
1306 }
1307
1308 fn create_finite_universe(
1310 &self,
1311 sort: SortId,
1312 manager: &mut TermManager,
1313 ) -> Result<Vec<TermId>, CompletionError> {
1314 let mut universe = Vec::new();
1315
1316 for i in 0..self.max_universe_size {
1318 let name = format!("u!{}", i);
1319 let const_id = manager.mk_var(&name, sort);
1320 universe.push(const_id);
1321 }
1322
1323 Ok(universe)
1324 }
1325
1326 pub fn stats(&self) -> &UninterpStats {
1328 &self.stats
1329 }
1330}
1331
1332impl Default for UninterpretedSortHandler {
1333 fn default() -> Self {
1334 Self::new()
1335 }
1336}
1337
/// Errors that model completion can report.
#[derive(Debug, Clone)]
pub enum CompletionError {
    /// Completion failed; the message gives details.
    CompletionFailed(String),
    /// A resource limit was exceeded during completion.
    ResourceLimit,
    /// The model is invalid; the message gives details.
    InvalidModel(String),
}
1348
1349impl fmt::Display for CompletionError {
1350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1351 match self {
1352 Self::CompletionFailed(msg) => write!(f, "Model completion failed: {}", msg),
1353 Self::ResourceLimit => write!(f, "Resource limit exceeded during completion"),
1354 Self::InvalidModel(msg) => write!(f, "Invalid model: {}", msg),
1355 }
1356 }
1357}
1358
/// Lets `CompletionError` be used as a `core::error::Error`; relies on its
/// `Display` and `Debug` impls, with no underlying source.
impl core::error::Error for CompletionError {}
1360
/// Counters for `ModelCompleter`.
#[derive(Debug, Clone, Default)]
pub struct CompletionStats {
    /// Number of calls to `ModelCompleter::complete`.
    pub num_completions: usize,
    /// Failed-completion counter.
    pub num_failures: usize,
}
1367
/// Counters for `MacroSolver`.
#[derive(Debug, Clone, Default)]
pub struct MacroStats {
    /// Macros found by `MacroSolver::solve_macros`.
    pub num_macros_found: usize,
    /// Applied-macro counter; not updated by the visible code.
    pub num_macros_applied: usize,
}
1374
/// Counters for `ModelFixer`.
#[derive(Debug, Clone, Default)]
pub struct FixerStats {
    /// Number of calls to `ModelFixer::fix_model`.
    pub num_fixes: usize,
    /// Created-projection counter; not updated by the visible code.
    pub num_projections_created: usize,
}
1381
/// Counters for `UninterpretedSortHandler`.
#[derive(Debug, Clone, Default)]
pub struct UninterpStats {
    /// Universes created by `UninterpretedSortHandler::complete_universes`.
    pub num_universes_created: usize,
    /// Counter for the combined size of created universes.
    pub total_universe_size: usize,
}
1388
#[cfg(test)]
mod tests {
    use super::*;
    use oxiz_core::interner::Key;

    /// Builds an interned key for tests; small indices are always valid.
    fn spur(index: usize) -> Spur {
        Spur::try_from_usize(index).expect("valid spur")
    }

    #[test]
    fn test_completed_model_creation() {
        let model = CompletedModel::new();
        assert!(model.assignments.is_empty());
        assert!(model.function_interps.is_empty());
    }

    #[test]
    fn test_completed_model_eval() {
        let (key, val) = (TermId::new(1), TermId::new(2));
        let mut model = CompletedModel::new();

        model.set(key, val);

        assert_eq!(model.eval(key), Some(val));
        assert!(model.eval(TermId::new(99)).is_none());
    }

    #[test]
    fn test_function_interpretation_lookup() {
        let sort = SortId::new(1);
        let domain: SmallVec<[SortId; 4]> = [sort, sort].into_iter().collect();
        let mut interp = FunctionInterpretation::new(spur(1), domain, sort);

        let point = vec![TermId::new(1), TermId::new(2)];
        let image = TermId::new(10);
        interp.add_entry(point.clone(), image);

        assert_eq!(interp.lookup(&point), Some(image));
        assert_eq!(interp.lookup(&[TermId::new(99)]), None);
    }

    #[test]
    fn test_function_interpretation_else_value() {
        let fallback = TermId::new(42);
        let mut interp = FunctionInterpretation::new(spur(1), SmallVec::new(), SortId::new(1));

        interp.else_value = Some(fallback);

        assert_eq!(interp.lookup(&[TermId::new(99)]), Some(fallback));
    }

    #[test]
    fn test_function_interpretation_max_occurrence() {
        let sort = SortId::new(1);
        let domain: SmallVec<[SortId; 4]> = [sort].into_iter().collect();
        let mut interp = FunctionInterpretation::new(spur(1), domain, sort);

        let (common, rare) = (TermId::new(10), TermId::new(20));
        for (arg, res) in [(1, common), (2, common), (3, rare)] {
            interp.add_entry(vec![TermId::new(arg)], res);
        }

        assert_eq!(interp.max_occurrence_result(), Some(common));
    }

    #[test]
    fn test_projection_function_def() {
        let value = TermId::new(1);
        let term = TermId::new(10);
        let mut proj = ProjectionFunctionDef::new(0, SortId::new(1));

        proj.add_value(value, term);

        assert_eq!(proj.values.len(), 1);
        assert_eq!(proj.project(value), Some(term));
    }

    #[test]
    fn test_model_completer_creation() {
        assert_eq!(ModelCompleter::new().stats.num_completions, 0);
    }

    #[test]
    fn test_macro_solver_creation() {
        assert_eq!(MacroSolver::new().stats.num_macros_found, 0);
    }

    #[test]
    fn test_model_fixer_creation() {
        assert_eq!(ModelFixer::new().stats.num_fixes, 0);
    }

    #[test]
    fn test_uninterpreted_sort_handler_creation() {
        assert_eq!(UninterpretedSortHandler::new().max_universe_size, 8);
    }

    #[test]
    fn test_uninterpreted_sort_handler_custom_size() {
        assert_eq!(
            UninterpretedSortHandler::with_max_size(16).max_universe_size,
            16
        );
    }

    #[test]
    fn test_arithmetic_projection() {
        assert!(ArithmeticProjection::new(true).is_int);
    }

    #[test]
    fn test_completion_error_display() {
        let rendered = CompletionError::CompletionFailed("test".to_string()).to_string();
        assert!(rendered.contains("test"));
    }
}