1use lasso::Spur;
18use oxiz_core::ast::{TermId, TermKind, TermManager};
19use oxiz_core::sort::SortId;
20use rustc_hash::{FxHashMap, FxHashSet};
21use smallvec::SmallVec;
22
/// A quantified formula registered with [`MBQISolver`], together with its
/// per-quantifier instantiation bookkeeping.
#[derive(Debug, Clone)]
pub struct QuantifiedFormula {
    /// The original quantified term (a `Forall` or `Exists` node).
    pub term: TermId,
    /// Bound variables as `(name, sort)` pairs, copied from the quantifier's `vars`.
    pub bound_vars: SmallVec<[(Spur, SortId); 2]>,
    /// The quantifier body, which mentions the bound variables.
    pub body: TermId,
    /// `true` for `Forall`, `false` for `Exists`.
    pub universal: bool,
    /// Number of instantiations generated so far for this quantifier.
    pub instantiation_count: usize,
    /// Upper bound on `instantiation_count` (see `can_instantiate`).
    pub max_instantiations: usize,
}
39
40impl QuantifiedFormula {
41 pub fn new(
43 term: TermId,
44 bound_vars: SmallVec<[(Spur, SortId); 2]>,
45 body: TermId,
46 universal: bool,
47 ) -> Self {
48 Self {
49 term,
50 bound_vars,
51 body,
52 universal,
53 instantiation_count: 0,
54 max_instantiations: 100,
55 }
56 }
57
58 pub fn can_instantiate(&self) -> bool {
60 self.instantiation_count < self.max_instantiations
61 }
62}
63
/// One instance of a quantifier produced by [`MBQISolver`].
#[derive(Debug, Clone)]
pub struct Instantiation {
    /// The quantified term this instance was produced from.
    pub quantifier: TermId,
    /// Map from bound-variable name to the term substituted for it.
    pub substitution: FxHashMap<Spur, TermId>,
    /// The quantifier body with `substitution` applied.
    pub result: TermId,
}
74
/// Outcome of one [`MBQISolver::check_with_model`] round.
#[derive(Debug, Clone)]
pub enum MBQIResult {
    /// MBQI is disabled, or no quantifiers are registered.
    NoQuantifiers,
    /// The round produced no new instantiations.
    Satisfied,
    /// Previously unseen instantiations to add to the ground problem.
    NewInstantiations(Vec<Instantiation>),
    /// Not constructed by `MBQISolver` in this module.
    /// NOTE(review): presumably carries the terms of a conflict — confirm with callers.
    Conflict(Vec<TermId>),
    /// An instantiation budget has been exhausted.
    InstantiationLimit,
}
89
/// Model-based quantifier instantiation (MBQI) engine.
#[derive(Debug)]
pub struct MBQISolver {
    /// Registered quantifiers, both universal and existential.
    quantifiers: Vec<QuantifiedFormula>,
    /// Keys of every instantiation already emitted, used to suppress duplicates.
    /// Each key is the quantifier term plus its substitution sorted by variable key.
    generated_instantiations: FxHashSet<(TermId, Vec<(Spur, TermId)>)>,
    /// Ground terms, grouped by sort, that may be substituted for bound variables.
    candidates_by_sort: FxHashMap<SortId, Vec<TermId>>,
    /// Global budget on instantiations across all quantifiers.
    max_total_instantiations: usize,
    /// Instantiations emitted so far (reset by `clear`).
    total_instantiation_count: usize,
    /// When `false`, `check_with_model` returns `NoQuantifiers` immediately.
    enabled: bool,
}
106
107impl Default for MBQISolver {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl MBQISolver {
114 pub fn new() -> Self {
116 Self {
117 quantifiers: Vec::new(),
118 generated_instantiations: FxHashSet::default(),
119 candidates_by_sort: FxHashMap::default(),
120 max_total_instantiations: 10000,
121 total_instantiation_count: 0,
122 enabled: true,
123 }
124 }
125
126 pub fn with_limit(max_total: usize) -> Self {
128 let mut solver = Self::new();
129 solver.max_total_instantiations = max_total;
130 solver
131 }
132
133 pub fn set_enabled(&mut self, enabled: bool) {
135 self.enabled = enabled;
136 }
137
138 pub fn is_enabled(&self) -> bool {
140 self.enabled
141 }
142
143 pub fn clear(&mut self) {
145 self.quantifiers.clear();
146 self.generated_instantiations.clear();
147 self.candidates_by_sort.clear();
148 self.total_instantiation_count = 0;
149 }
150
151 pub fn add_quantifier(&mut self, term: TermId, manager: &TermManager) {
153 let Some(t) = manager.get(term) else {
154 return;
155 };
156
157 match &t.kind {
158 TermKind::Forall { vars, body, .. } => {
159 self.quantifiers
160 .push(QuantifiedFormula::new(term, vars.clone(), *body, true));
161 }
162 TermKind::Exists { vars, body, .. } => {
163 self.quantifiers
167 .push(QuantifiedFormula::new(term, vars.clone(), *body, false));
168 }
169 _ => {}
170 }
171 }
172
173 pub fn add_candidate(&mut self, term: TermId, sort: SortId) {
175 self.candidates_by_sort.entry(sort).or_default().push(term);
176 }
177
178 pub fn collect_ground_terms(&mut self, term: TermId, manager: &TermManager) {
180 self.collect_ground_terms_rec(term, manager, &mut FxHashSet::default());
181 }
182
183 fn collect_ground_terms_rec(
184 &mut self,
185 term: TermId,
186 manager: &TermManager,
187 visited: &mut FxHashSet<TermId>,
188 ) {
189 if visited.contains(&term) {
190 return;
191 }
192 visited.insert(term);
193
194 let Some(t) = manager.get(term) else {
195 return;
196 };
197
198 match &t.kind {
201 TermKind::Var(_) => {
202 self.add_candidate(term, t.sort);
204 }
205 TermKind::IntConst(_) | TermKind::RealConst(_) | TermKind::BitVecConst { .. } => {
206 self.add_candidate(term, t.sort);
207 }
208 TermKind::Apply { args, .. } => {
209 self.add_candidate(term, t.sort);
211 for &arg in args {
212 self.collect_ground_terms_rec(arg, manager, visited);
213 }
214 }
215 TermKind::Add(args)
216 | TermKind::Mul(args)
217 | TermKind::And(args)
218 | TermKind::Or(args) => {
219 for &arg in args {
220 self.collect_ground_terms_rec(arg, manager, visited);
221 }
222 }
223 TermKind::Sub(lhs, rhs)
224 | TermKind::Div(lhs, rhs)
225 | TermKind::Mod(lhs, rhs)
226 | TermKind::Eq(lhs, rhs)
227 | TermKind::Lt(lhs, rhs)
228 | TermKind::Le(lhs, rhs)
229 | TermKind::Gt(lhs, rhs)
230 | TermKind::Ge(lhs, rhs) => {
231 self.collect_ground_terms_rec(*lhs, manager, visited);
232 self.collect_ground_terms_rec(*rhs, manager, visited);
233 }
234 TermKind::Not(arg) | TermKind::Neg(arg) => {
235 self.collect_ground_terms_rec(*arg, manager, visited);
236 }
237 TermKind::Ite(cond, then_br, else_br) => {
238 self.collect_ground_terms_rec(*cond, manager, visited);
239 self.collect_ground_terms_rec(*then_br, manager, visited);
240 self.collect_ground_terms_rec(*else_br, manager, visited);
241 }
242 TermKind::Forall { body, .. } | TermKind::Exists { body, .. } => {
243 let _ = body;
246 }
247 _ => {}
248 }
249 }
250
251 pub fn get_candidates(&self, sort: SortId) -> &[TermId] {
253 self.candidates_by_sort
254 .get(&sort)
255 .map_or(&[], |v| v.as_slice())
256 }
257
258 pub fn at_limit(&self) -> bool {
260 self.total_instantiation_count >= self.max_total_instantiations
261 }
262
263 pub fn check_with_model(
269 &mut self,
270 model: &FxHashMap<TermId, TermId>,
271 manager: &mut TermManager,
272 ) -> MBQIResult {
273 if !self.enabled {
274 return MBQIResult::NoQuantifiers;
275 }
276
277 if self.quantifiers.is_empty() {
278 return MBQIResult::NoQuantifiers;
279 }
280
281 if self.at_limit() {
282 return MBQIResult::InstantiationLimit;
283 }
284
285 let mut new_instantiations = Vec::new();
286
287 for i in 0..self.quantifiers.len() {
289 if !self.quantifiers[i].can_instantiate() {
290 continue;
291 }
292
293 if !self.quantifiers[i].universal {
294 continue;
297 }
298
299 let instantiations = self.find_counterexamples(i, model, manager);
301
302 for inst in instantiations {
303 if self.at_limit() {
304 break;
305 }
306
307 let mut key_vec: Vec<_> = inst.substitution.iter().map(|(&k, &v)| (k, v)).collect();
309 key_vec.sort_by_key(|(k, _)| *k);
310 let key = (inst.quantifier, key_vec);
311
312 if self.generated_instantiations.contains(&key) {
313 continue;
314 }
315
316 self.generated_instantiations.insert(key);
317 self.quantifiers[i].instantiation_count += 1;
318 self.total_instantiation_count += 1;
319 new_instantiations.push(inst);
320 }
321 }
322
323 if new_instantiations.is_empty() {
324 MBQIResult::Satisfied
325 } else {
326 MBQIResult::NewInstantiations(new_instantiations)
327 }
328 }
329
330 fn find_counterexamples(
332 &self,
333 quantifier_idx: usize,
334 model: &FxHashMap<TermId, TermId>,
335 manager: &mut TermManager,
336 ) -> Vec<Instantiation> {
337 let quant = &self.quantifiers[quantifier_idx];
338 let mut results = Vec::new();
339
340 let candidates = self.build_candidate_lists(&quant.bound_vars, manager);
342
343 let combinations = self.enumerate_combinations(&candidates, 10); for combo in combinations {
347 let mut subst: FxHashMap<Spur, TermId> = FxHashMap::default();
349 for (i, &candidate) in combo.iter().enumerate() {
350 if i < quant.bound_vars.len() {
351 subst.insert(quant.bound_vars[i].0, candidate);
352 }
353 }
354
355 let ground_body = self.apply_substitution(quant.body, &subst, manager);
357
358 let evaluated = self.evaluate_under_model(ground_body, model, manager);
360
361 if let Some(t) = manager.get(evaluated) {
363 if matches!(t.kind, TermKind::False) {
364 results.push(Instantiation {
365 quantifier: quant.term,
366 substitution: subst,
367 result: ground_body,
368 });
369
370 if results.len() >= 5 {
372 break;
373 }
374 }
375 }
376 }
377
378 if results.is_empty() {
380 let model_instantiation = self.instantiate_from_model(quantifier_idx, model, manager);
382 if let Some(inst) = model_instantiation {
383 results.push(inst);
384 }
385 }
386
387 results
388 }
389
390 fn build_candidate_lists(
392 &self,
393 bound_vars: &[(Spur, SortId)],
394 manager: &mut TermManager,
395 ) -> Vec<Vec<TermId>> {
396 let mut result = Vec::new();
397
398 for &(_name, sort) in bound_vars {
399 let mut candidates = Vec::new();
400
401 if let Some(pool) = self.candidates_by_sort.get(&sort) {
403 candidates.extend(pool.iter().copied().take(10)); }
405
406 if sort == manager.sorts.int_sort {
408 for i in 0..3 {
410 let int_id = manager.mk_int(i);
411 if !candidates.contains(&int_id) {
412 candidates.push(int_id);
413 }
414 }
415 } else if sort == manager.sorts.bool_sort {
416 let true_id = manager.mk_true();
417 let false_id = manager.mk_false();
418 if !candidates.contains(&true_id) {
419 candidates.push(true_id);
420 }
421 if !candidates.contains(&false_id) {
422 candidates.push(false_id);
423 }
424 }
425
426 result.push(candidates);
427 }
428
429 result
430 }
431
432 fn enumerate_combinations(
434 &self,
435 candidates: &[Vec<TermId>],
436 max_per_dim: usize,
437 ) -> Vec<Vec<TermId>> {
438 if candidates.is_empty() {
439 return vec![vec![]];
440 }
441
442 let mut results = Vec::new();
443 let mut indices = vec![0usize; candidates.len()];
444
445 loop {
446 let combo: Vec<TermId> = indices
448 .iter()
449 .enumerate()
450 .filter_map(|(i, &idx)| candidates.get(i).and_then(|c| c.get(idx).copied()))
451 .collect();
452
453 if combo.len() == candidates.len() {
454 results.push(combo);
455 }
456
457 if results.len() >= 100 {
459 break;
460 }
461
462 let mut carry = true;
464 for (i, idx) in indices.iter_mut().enumerate() {
465 if carry {
466 *idx += 1;
467 let limit = candidates.get(i).map_or(1, |c| c.len().min(max_per_dim));
468 if *idx >= limit {
469 *idx = 0;
470 } else {
471 carry = false;
472 }
473 }
474 }
475
476 if carry {
477 break;
479 }
480 }
481
482 results
483 }
484
485 fn apply_substitution(
487 &self,
488 term: TermId,
489 subst: &FxHashMap<Spur, TermId>,
490 manager: &mut TermManager,
491 ) -> TermId {
492 let Some(t) = manager.get(term).cloned() else {
495 return term;
496 };
497
498 match &t.kind {
499 TermKind::Var(name) => {
500 if let Some(&replacement) = subst.get(name) {
502 replacement
503 } else {
504 term
505 }
506 }
507 TermKind::Not(arg) => {
508 let new_arg = self.apply_substitution(*arg, subst, manager);
509 manager.mk_not(new_arg)
510 }
511 TermKind::And(args) => {
512 let new_args: Vec<_> = args
513 .iter()
514 .map(|&a| self.apply_substitution(a, subst, manager))
515 .collect();
516 manager.mk_and(new_args)
517 }
518 TermKind::Or(args) => {
519 let new_args: Vec<_> = args
520 .iter()
521 .map(|&a| self.apply_substitution(a, subst, manager))
522 .collect();
523 manager.mk_or(new_args)
524 }
525 TermKind::Eq(lhs, rhs) => {
526 let new_lhs = self.apply_substitution(*lhs, subst, manager);
527 let new_rhs = self.apply_substitution(*rhs, subst, manager);
528 manager.mk_eq(new_lhs, new_rhs)
529 }
530 TermKind::Lt(lhs, rhs) => {
531 let new_lhs = self.apply_substitution(*lhs, subst, manager);
532 let new_rhs = self.apply_substitution(*rhs, subst, manager);
533 manager.mk_lt(new_lhs, new_rhs)
534 }
535 TermKind::Le(lhs, rhs) => {
536 let new_lhs = self.apply_substitution(*lhs, subst, manager);
537 let new_rhs = self.apply_substitution(*rhs, subst, manager);
538 manager.mk_le(new_lhs, new_rhs)
539 }
540 TermKind::Gt(lhs, rhs) => {
541 let new_lhs = self.apply_substitution(*lhs, subst, manager);
542 let new_rhs = self.apply_substitution(*rhs, subst, manager);
543 manager.mk_gt(new_lhs, new_rhs)
544 }
545 TermKind::Ge(lhs, rhs) => {
546 let new_lhs = self.apply_substitution(*lhs, subst, manager);
547 let new_rhs = self.apply_substitution(*rhs, subst, manager);
548 manager.mk_ge(new_lhs, new_rhs)
549 }
550 TermKind::Add(args) => {
551 let new_args: SmallVec<[TermId; 4]> = args
552 .iter()
553 .map(|&a| self.apply_substitution(a, subst, manager))
554 .collect();
555 manager.mk_add(new_args)
556 }
557 TermKind::Sub(lhs, rhs) => {
558 let new_lhs = self.apply_substitution(*lhs, subst, manager);
559 let new_rhs = self.apply_substitution(*rhs, subst, manager);
560 manager.mk_sub(new_lhs, new_rhs)
561 }
562 TermKind::Mul(args) => {
563 let new_args: SmallVec<[TermId; 4]> = args
564 .iter()
565 .map(|&a| self.apply_substitution(a, subst, manager))
566 .collect();
567 manager.mk_mul(new_args)
568 }
569 TermKind::Neg(arg) => {
570 let new_arg = self.apply_substitution(*arg, subst, manager);
571 manager.mk_neg(new_arg)
572 }
573 TermKind::Implies(lhs, rhs) => {
574 let new_lhs = self.apply_substitution(*lhs, subst, manager);
575 let new_rhs = self.apply_substitution(*rhs, subst, manager);
576 manager.mk_implies(new_lhs, new_rhs)
577 }
578 TermKind::Ite(cond, then_br, else_br) => {
579 let new_cond = self.apply_substitution(*cond, subst, manager);
580 let new_then = self.apply_substitution(*then_br, subst, manager);
581 let new_else = self.apply_substitution(*else_br, subst, manager);
582 manager.mk_ite(new_cond, new_then, new_else)
583 }
584 TermKind::Apply { func, args } => {
585 let func_name = manager.resolve_str(*func).to_string();
586 let new_args: SmallVec<[TermId; 4]> = args
587 .iter()
588 .map(|&a| self.apply_substitution(a, subst, manager))
589 .collect();
590 manager.mk_apply(&func_name, new_args, t.sort)
591 }
592 _ => term,
594 }
595 }
596
597 fn evaluate_under_model(
599 &self,
600 term: TermId,
601 model: &FxHashMap<TermId, TermId>,
602 manager: &mut TermManager,
603 ) -> TermId {
604 if let Some(&val) = model.get(&term) {
606 return val;
607 }
608
609 let Some(t) = manager.get(term).cloned() else {
610 return term;
611 };
612
613 match &t.kind {
614 TermKind::True | TermKind::False | TermKind::IntConst(_) | TermKind::RealConst(_) => {
615 term
617 }
618 TermKind::Var(_) => {
619 model.get(&term).copied().unwrap_or(term)
621 }
622 TermKind::Not(arg) => {
623 let eval_arg = self.evaluate_under_model(*arg, model, manager);
624 if let Some(arg_t) = manager.get(eval_arg) {
625 match arg_t.kind {
626 TermKind::True => return manager.mk_false(),
627 TermKind::False => return manager.mk_true(),
628 _ => {}
629 }
630 }
631 manager.mk_not(eval_arg)
632 }
633 TermKind::And(args) => {
634 let mut all_true = true;
635 for &arg in args {
636 let eval_arg = self.evaluate_under_model(arg, model, manager);
637 if let Some(arg_t) = manager.get(eval_arg) {
638 match arg_t.kind {
639 TermKind::False => return manager.mk_false(),
640 TermKind::True => continue,
641 _ => all_true = false,
642 }
643 } else {
644 all_true = false;
645 }
646 }
647 if all_true { manager.mk_true() } else { term }
648 }
649 TermKind::Or(args) => {
650 let mut all_false = true;
651 for &arg in args {
652 let eval_arg = self.evaluate_under_model(arg, model, manager);
653 if let Some(arg_t) = manager.get(eval_arg) {
654 match arg_t.kind {
655 TermKind::True => return manager.mk_true(),
656 TermKind::False => continue,
657 _ => all_false = false,
658 }
659 } else {
660 all_false = false;
661 }
662 }
663 if all_false { manager.mk_false() } else { term }
664 }
665 TermKind::Eq(lhs, rhs) => {
666 let eval_lhs = self.evaluate_under_model(*lhs, model, manager);
667 let eval_rhs = self.evaluate_under_model(*rhs, model, manager);
668
669 if eval_lhs == eval_rhs {
671 return manager.mk_true();
672 }
673
674 let lhs_t = manager.get(eval_lhs).cloned();
676 let rhs_t = manager.get(eval_rhs).cloned();
677
678 if let (Some(l), Some(r)) = (lhs_t, rhs_t) {
679 match (&l.kind, &r.kind) {
680 (TermKind::IntConst(a), TermKind::IntConst(b)) => {
681 if a == b {
682 return manager.mk_true();
683 } else {
684 return manager.mk_false();
685 }
686 }
687 (TermKind::True, TermKind::True) | (TermKind::False, TermKind::False) => {
688 return manager.mk_true();
689 }
690 (TermKind::True, TermKind::False) | (TermKind::False, TermKind::True) => {
691 return manager.mk_false();
692 }
693 _ => {}
694 }
695 }
696
697 term
698 }
699 TermKind::Lt(lhs, rhs) => {
700 let eval_lhs = self.evaluate_under_model(*lhs, model, manager);
701 let eval_rhs = self.evaluate_under_model(*rhs, model, manager);
702
703 let lhs_t = manager.get(eval_lhs).cloned();
704 let rhs_t = manager.get(eval_rhs).cloned();
705
706 if let (Some(l), Some(r)) = (lhs_t, rhs_t) {
707 if let (TermKind::IntConst(a), TermKind::IntConst(b)) = (&l.kind, &r.kind) {
708 if a < b {
709 return manager.mk_true();
710 } else {
711 return manager.mk_false();
712 }
713 }
714 }
715
716 term
717 }
718 _ => {
720 manager.simplify(term)
722 }
723 }
724 }
725
726 fn instantiate_from_model(
728 &self,
729 quantifier_idx: usize,
730 model: &FxHashMap<TermId, TermId>,
731 manager: &mut TermManager,
732 ) -> Option<Instantiation> {
733 let quant = &self.quantifiers[quantifier_idx];
734 let mut subst: FxHashMap<Spur, TermId> = FxHashMap::default();
735
736 for &(name, sort) in &quant.bound_vars {
738 let mut found = None;
740 for (&term, &_value) in model {
741 if let Some(t) = manager.get(term) {
742 if t.sort == sort {
743 found = Some(term);
744 break;
745 }
746 }
747 }
748
749 let candidate = match found {
751 Some(t) => t,
752 None => {
753 if sort == manager.sorts.int_sort {
754 manager.mk_int(0)
755 } else if sort == manager.sorts.bool_sort {
756 manager.mk_true()
757 } else {
758 manager.mk_true()
760 }
761 }
762 };
763
764 subst.insert(name, candidate);
765 }
766
767 let ground_body = self.apply_substitution(quant.body, &subst, manager);
769
770 Some(Instantiation {
771 quantifier: quant.term,
772 substitution: subst,
773 result: ground_body,
774 })
775 }
776
777 pub fn stats(&self) -> MBQIStats {
779 MBQIStats {
780 num_quantifiers: self.quantifiers.len(),
781 total_instantiations: self.total_instantiation_count,
782 max_instantiations: self.max_total_instantiations,
783 unique_instantiations: self.generated_instantiations.len(),
784 }
785 }
786}
787
/// Snapshot of solver counters returned by [`MBQISolver::stats`].
#[derive(Debug, Clone)]
pub struct MBQIStats {
    /// Number of registered quantifiers.
    pub num_quantifiers: usize,
    /// Instantiations emitted since creation or the last `clear`.
    pub total_instantiations: usize,
    /// Configured global instantiation budget.
    pub max_instantiations: usize,
    /// Number of distinct instantiation keys recorded.
    pub unique_instantiations: usize,
}
800
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mbqi_new() {
        let mbqi = MBQISolver::new();
        assert!(mbqi.is_enabled());
        assert_eq!(mbqi.quantifiers.len(), 0);
    }

    #[test]
    fn test_mbqi_disable() {
        let mut mbqi = MBQISolver::new();
        mbqi.set_enabled(false);
        assert!(!mbqi.is_enabled());

        let model = FxHashMap::default();
        let mut manager = TermManager::new();
        let result = mbqi.check_with_model(&model, &mut manager);
        assert!(matches!(result, MBQIResult::NoQuantifiers));
    }

    #[test]
    fn test_mbqi_no_quantifiers() {
        let mut mbqi = MBQISolver::new();
        let model = FxHashMap::default();
        let mut manager = TermManager::new();

        let result = mbqi.check_with_model(&model, &mut manager);
        assert!(matches!(result, MBQIResult::NoQuantifiers));
    }

    #[test]
    fn test_mbqi_add_candidate() {
        let mut mbqi = MBQISolver::new();
        let manager = TermManager::new();

        let sort = manager.sorts.int_sort;
        mbqi.add_candidate(TermId::new(1), sort);
        mbqi.add_candidate(TermId::new(2), sort);

        let candidates = mbqi.get_candidates(sort);
        assert_eq!(candidates.len(), 2);
    }

    #[test]
    fn test_mbqi_stats() {
        let mbqi = MBQISolver::new();
        let stats = mbqi.stats();

        assert_eq!(stats.num_quantifiers, 0);
        assert_eq!(stats.total_instantiations, 0);
        assert_eq!(stats.unique_instantiations, 0);
        assert_eq!(stats.max_instantiations, 10000);
    }

    #[test]
    fn test_mbqi_clear() {
        let mut mbqi = MBQISolver::new();
        let manager = TermManager::new();
        let sort = manager.sorts.int_sort;

        mbqi.add_candidate(TermId::new(1), sort);
        mbqi.total_instantiation_count = 5;

        mbqi.clear();

        assert_eq!(mbqi.quantifiers.len(), 0);
        assert_eq!(mbqi.total_instantiation_count, 0);
        // `clear` must also drop the candidate pools.
        assert!(mbqi.get_candidates(sort).is_empty());
    }

    #[test]
    fn test_mbqi_with_limit() {
        let mbqi = MBQISolver::with_limit(100);
        assert_eq!(mbqi.max_total_instantiations, 100);
        assert_eq!(mbqi.stats().max_instantiations, 100);
    }

    #[test]
    fn test_mbqi_at_limit() {
        let mut mbqi = MBQISolver::with_limit(2);
        assert!(!mbqi.at_limit());
        mbqi.total_instantiation_count = 2;
        assert!(mbqi.at_limit());
    }

    #[test]
    fn test_quantified_formula_budget() {
        let mut q = QuantifiedFormula::new(TermId::new(1), SmallVec::new(), TermId::new(2), true);
        assert!(q.can_instantiate());
        q.instantiation_count = q.max_instantiations;
        assert!(!q.can_instantiate());
    }

    #[test]
    fn test_enumerate_combinations_empty() {
        let mbqi = MBQISolver::new();
        let candidates: Vec<Vec<TermId>> = vec![];
        let combos = mbqi.enumerate_combinations(&candidates, 10);
        assert_eq!(combos.len(), 1);
        assert!(combos[0].is_empty());
    }

    #[test]
    fn test_enumerate_combinations_single() {
        let mbqi = MBQISolver::new();
        let candidates = vec![vec![TermId::new(1), TermId::new(2)]];
        let combos = mbqi.enumerate_combinations(&candidates, 10);
        assert_eq!(combos.len(), 2);
    }

    #[test]
    fn test_enumerate_combinations_multiple() {
        let mbqi = MBQISolver::new();
        let candidates = vec![
            vec![TermId::new(1), TermId::new(2)],
            vec![TermId::new(3), TermId::new(4)],
        ];
        let combos = mbqi.enumerate_combinations(&candidates, 10);
        assert_eq!(combos.len(), 4);
        // The first dimension varies fastest.
        assert_eq!(combos[0], vec![TermId::new(1), TermId::new(3)]);
        assert_eq!(combos[1], vec![TermId::new(2), TermId::new(3)]);
    }

    #[test]
    fn test_enumerate_combinations_empty_dimension() {
        let mbqi = MBQISolver::new();
        let candidates = vec![vec![TermId::new(1), TermId::new(2)], vec![]];
        assert!(mbqi.enumerate_combinations(&candidates, 10).is_empty());
    }

    #[test]
    fn test_enumerate_combinations_respects_max_per_dim() {
        let mbqi = MBQISolver::new();
        let candidates = vec![
            vec![TermId::new(1), TermId::new(2), TermId::new(3)],
            vec![TermId::new(4), TermId::new(5), TermId::new(6)],
        ];
        let combos = mbqi.enumerate_combinations(&candidates, 2);
        assert_eq!(combos.len(), 4);
        assert!(combos
            .iter()
            .all(|c| !c.contains(&TermId::new(3)) && !c.contains(&TermId::new(6))));
    }

    #[test]
    fn test_enumerate_combinations_caps_total() {
        let mbqi = MBQISolver::new();
        let dim: Vec<TermId> = (1..=10).map(TermId::new).collect();
        let candidates = vec![dim.clone(), dim.clone(), dim];
        assert_eq!(mbqi.enumerate_combinations(&candidates, 10).len(), 100);
    }
}