1#![allow(missing_docs)] #[allow(unused_imports)]
13use crate::prelude::*;
14#[cfg(feature = "std")]
15use oxiz_core::TermId as ProofTermId;
16#[cfg(feature = "profiling")]
17use oxiz_core::profiling::{ProfilingCategory, ScopedTimer};
18#[cfg(feature = "std")]
19use oxiz_proof::{CombinationStep, CombinationTheoryId, NelsonOppenCertificate, ProofNodeId};
20
/// Identifier of a term as seen by the coordinator and the theory solvers.
///
/// This is a plain `usize`; the coordinator never interprets it beyond using
/// it as a map key and comparing it for equality.
pub type TermId = usize;
23
/// Identifies the theory a solver (or a propagated equality) belongs to.
///
/// The coordinator stores registered solvers keyed by this value, so at most
/// one solver per variant can be registered at a time. The precise scope of
/// each theory is defined by the solver implementations, not by this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TheoryId {
    /// The core theory.
    Core,
    /// The arithmetic theory.
    Arithmetic,
    /// The bit-vector theory.
    BitVector,
    /// The array theory.
    Array,
    /// The datatype theory.
    Datatype,
    /// The string theory.
    String,
    /// The theory of uninterpreted symbols.
    Uninterpreted,
}
35
/// Interface every theory solver must implement to be driven by
/// [`TheoryCoordinator`].
///
/// All fallible operations report failures as a `String` message, which the
/// coordinator forwards to its caller unchanged.
pub trait TheorySolver {
    /// Returns the theory this solver implements; used as the registration key.
    fn theory_id(&self) -> TheoryId;

    /// Adds `formula` to the set of assertions this solver has to satisfy.
    fn assert_formula(&mut self, formula: TermId) -> Result<(), String>;

    /// Decides the solver's current state and reports the verdict.
    fn check_sat(&mut self) -> Result<SatResult, String>;

    /// Returns the solver's model as a term-to-term map, or `None` if it has
    /// no model to offer.
    fn get_model(&self) -> Option<FxHashMap<TermId, TermId>>;

    /// Returns the terms making up the solver's current conflict, or `None`
    /// if there is none.
    fn get_conflict(&self) -> Option<Vec<TermId>>;

    /// Asks the solver to return to decision level `level`.
    fn backtrack(&mut self, level: usize) -> Result<(), String>;

    /// Returns the term equalities this solver currently reports as implied.
    ///
    /// The method takes `&self`, so the coordinator has to expect the same
    /// pairs to be reported again on subsequent calls.
    fn get_implied_equalities(&self) -> Vec<(TermId, TermId)>;

    /// Informs the solver that another theory established `lhs = rhs`.
    fn notify_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String>;
}
62
/// Verdict of a satisfiability check, for a single theory or for the
/// combination of all registered theories.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SatResult {
    /// The checked state is satisfiable.
    Sat,
    /// The checked state is unsatisfiable.
    Unsat,
    /// No verdict was reached. The coordinator also returns this when eager
    /// combination exceeds its round limit.
    Unknown,
}
70
/// Bookkeeping record for a term tracked by the coordinator.
///
/// A record is created the first time a term is passed to
/// `TheoryCoordinator::add_shared_term`; the term counts as *shared* once
/// more than one theory has been recorded for it.
#[derive(Debug, Clone)]
pub struct SharedTerm {
    /// The term this record describes.
    pub term: TermId,
    /// Theories that were registered as using the term.
    pub theories: FxHashSet<TheoryId>,
    /// Parent link in the coordinator's equivalence-class forest. Starts out
    /// equal to `term` (the term is its own class) and is redirected when
    /// classes are merged.
    pub representative: TermId,
}
81
/// An equality `lhs = rhs` travelling from one theory to the others.
#[derive(Debug, Clone)]
pub struct EqualityProp {
    /// Left-hand side of the equality.
    pub lhs: TermId,
    /// Right-hand side of the equality.
    pub rhs: TermId,
    /// Theory that produced the equality; it is not notified again when the
    /// equality is propagated.
    pub source: TheoryId,
    /// Terms justifying the equality. Every equality constructed in this file
    /// leaves the list empty.
    pub explanation: Vec<TermId>,
}
94
/// Counters maintained by [`TheoryCoordinator`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct CoordinatorStats {
    /// Number of calls to `TheoryCoordinator::check_sat`.
    pub check_sat_calls: u64,
    /// Number of times a theory solver answered `Unsat` during a check.
    pub theory_conflicts: u64,
    /// Number of equalities pushed through the propagation step.
    pub equalities_propagated: u64,
    /// Size of the coordinator's shared-term table at the last refresh.
    pub shared_terms_count: usize,
    /// Number of eager combination rounds started.
    pub theory_combination_rounds: u64,
}
104
/// Tunables for [`TheoryCoordinator`].
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// When `true`, `check_sat` runs the eager combination loop (polling every
    /// theory for implied equalities); otherwise it only drains the queue of
    /// explicitly enqueued equalities.
    pub eager_combination: bool,
    /// Upper bound on eager combination rounds before `check_sat` gives up
    /// with `SatResult::Unknown`.
    pub max_combination_rounds: usize,
    /// When `true`, conflicts returned by `get_conflict` are sorted and
    /// de-duplicated.
    pub minimize_conflicts: bool,
    /// When `true` (and the `std` feature is enabled), a certificate is
    /// recorded whenever a theory reports `Unsat`.
    pub proof_mode: bool,
}
117
118impl Default for CoordinatorConfig {
119 fn default() -> Self {
120 Self {
121 eager_combination: false,
122 max_combination_rounds: 10,
123 minimize_conflicts: true,
124 proof_mode: false,
125 }
126 }
127}
128
/// Drives a set of theory solvers and exchanges equalities over shared terms
/// between them.
pub struct TheoryCoordinator {
    /// Behavioural switches chosen at construction time.
    config: CoordinatorConfig,
    /// Running counters, exposed through `stats()`.
    stats: CoordinatorStats,
    /// Registered solvers, one per theory.
    theories: FxHashMap<TheoryId, Box<dyn TheorySolver>>,
    /// Tracked terms together with their equivalence-class links.
    shared_terms: FxHashMap<TermId, SharedTerm>,
    /// Equalities queued by `enqueue_equality`, drained by lazy combination.
    pending_equalities: VecDeque<EqualityProp>,
    /// Implied equalities already fetched from a theory, keyed by theory and
    /// the decision level at which they were fetched.
    theory_propagation_cache: FxHashMap<(TheoryId, u32), Vec<EqualityProp>>,
    /// Every equality propagated since the last backtrack, in order.
    propagated_equalities_log: Vec<EqualityProp>,
    /// Certificate recorded for the most recent `Unsat` answer, if any.
    #[cfg(feature = "std")]
    last_certificate: Option<NelsonOppenCertificate>,
    /// Current decision level.
    current_level: usize,
}
149
150impl TheoryCoordinator {
151 pub fn new(config: CoordinatorConfig) -> Self {
153 Self {
154 config,
155 stats: CoordinatorStats::default(),
156 theories: FxHashMap::default(),
157 shared_terms: FxHashMap::default(),
158 pending_equalities: VecDeque::new(),
159 theory_propagation_cache: FxHashMap::default(),
160 propagated_equalities_log: Vec::new(),
161 #[cfg(feature = "std")]
162 last_certificate: None,
163 current_level: 0,
164 }
165 }
166
167 pub fn register_theory(&mut self, theory: Box<dyn TheorySolver>) {
169 let theory_id = theory.theory_id();
170 self.theories.insert(theory_id, theory);
171 }
172
173 pub fn assert_formula(&mut self, formula: TermId, theory: TheoryId) -> Result<(), String> {
175 if let Some(solver) = self.theories.get_mut(&theory) {
176 solver.assert_formula(formula)?;
177 self.clear_from_level(self.current_level as u32);
178
179 self.identify_shared_terms(formula)?;
181 } else {
182 return Err(format!("Theory {:?} not registered", theory));
183 }
184
185 Ok(())
186 }
187
188 pub fn check_sat(&mut self) -> Result<SatResult, String> {
190 #[cfg(feature = "profiling")]
191 let _timer = ScopedTimer::new(ProfilingCategory::TheoryCheck);
192 self.stats.check_sat_calls += 1;
193
194 for solver in self.theories.values_mut() {
196 let result = solver.check_sat()?;
197
198 match result {
199 SatResult::Unsat => {
200 self.stats.theory_conflicts += 1;
201 self.maybe_record_certificate_from_log();
202 return Ok(SatResult::Unsat);
203 }
204 SatResult::Unknown => {
205 return Ok(SatResult::Unknown);
206 }
207 SatResult::Sat => {
208 }
210 }
211 }
212
213 if self.config.eager_combination {
215 self.eager_theory_combination()
216 } else {
217 self.lazy_theory_combination()
218 }
219 }
220
221 fn eager_theory_combination(&mut self) -> Result<SatResult, String> {
223 let mut iteration = 0;
224
225 loop {
226 self.stats.theory_combination_rounds += 1;
227 iteration += 1;
228
229 if iteration > self.config.max_combination_rounds {
230 return Ok(SatResult::Unknown);
231 }
232
233 let mut new_equalities = Vec::new();
235
236 for theory_id in self.theories.keys().copied().collect::<Vec<_>>() {
237 let equalities = self.cached_theory_propagation(theory_id)?;
238
239 for eq in equalities {
240 if self.is_shared_term(eq.lhs) || self.is_shared_term(eq.rhs) {
242 new_equalities.push(eq);
243 }
244 }
245 }
246
247 if new_equalities.is_empty() {
249 return Ok(SatResult::Sat);
250 }
251
252 for eq in new_equalities {
254 self.propagate_equality(eq)?;
255 }
256
257 for solver in self.theories.values_mut() {
259 match solver.check_sat()? {
260 SatResult::Unsat => {
261 self.stats.theory_conflicts += 1;
262 self.maybe_record_certificate_from_log();
263 return Ok(SatResult::Unsat);
264 }
265 SatResult::Unknown => {
266 return Ok(SatResult::Unknown);
267 }
268 SatResult::Sat => {}
269 }
270 }
271 }
272 }
273
274 fn lazy_theory_combination(&mut self) -> Result<SatResult, String> {
276 while let Some(eq) = self.pending_equalities.pop_front() {
278 self.propagate_equality(eq)?;
279
280 for solver in self.theories.values_mut() {
282 match solver.check_sat()? {
283 SatResult::Unsat => {
284 self.stats.theory_conflicts += 1;
285 self.maybe_record_certificate_from_log();
286 return Ok(SatResult::Unsat);
287 }
288 SatResult::Unknown => {
289 return Ok(SatResult::Unknown);
290 }
291 SatResult::Sat => {}
292 }
293 }
294 }
295
296 Ok(SatResult::Sat)
297 }
298
299 fn propagate_equality(&mut self, eq: EqualityProp) -> Result<(), String> {
301 self.stats.equalities_propagated += 1;
302 let logged_eq = eq.clone();
303
304 self.merge_equivalence_classes(eq.lhs, eq.rhs)?;
306
307 let theories_to_notify = self.get_theories_for_terms(eq.lhs, eq.rhs);
309
310 for theory_id in theories_to_notify {
311 if theory_id != eq.source
312 && let Some(solver) = self.theories.get_mut(&theory_id)
313 {
314 solver.notify_equality(eq.lhs, eq.rhs)?;
315 }
316 }
317
318 self.clear_from_level(self.current_level as u32);
319 self.propagated_equalities_log.push(logged_eq);
320
321 Ok(())
322 }
323
324 fn identify_shared_terms(&mut self, _formula: TermId) -> Result<(), String> {
326 self.stats.shared_terms_count = self.shared_terms.len();
329 Ok(())
330 }
331
332 fn is_shared_term(&self, term: TermId) -> bool {
334 self.shared_terms
335 .get(&term)
336 .is_some_and(|st| st.theories.len() > 1)
337 }
338
339 fn get_theories_for_terms(&self, lhs: TermId, rhs: TermId) -> FxHashSet<TheoryId> {
341 let mut theories = FxHashSet::default();
342
343 if let Some(st) = self.shared_terms.get(&lhs) {
344 theories.extend(&st.theories);
345 }
346
347 if let Some(st) = self.shared_terms.get(&rhs) {
348 theories.extend(&st.theories);
349 }
350
351 theories
352 }
353
354 fn merge_equivalence_classes(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
356 let lhs_rep = self.find_representative(lhs);
358 let rhs_rep = self.find_representative(rhs);
359
360 if lhs_rep == rhs_rep {
361 return Ok(());
362 }
363
364 if let Some(st) = self.shared_terms.get_mut(&lhs_rep) {
366 st.representative = rhs_rep;
367 }
368
369 Ok(())
370 }
371
372 fn find_representative(&self, term: TermId) -> TermId {
374 if let Some(st) = self.shared_terms.get(&term)
375 && st.representative != term
376 {
377 return self.find_representative(st.representative);
379 }
380 term
381 }
382
383 pub fn add_shared_term(&mut self, term: TermId, theory: TheoryId) {
385 self.shared_terms
386 .entry(term)
387 .or_insert_with(|| SharedTerm {
388 term,
389 theories: FxHashSet::default(),
390 representative: term,
391 })
392 .theories
393 .insert(theory);
394
395 self.stats.shared_terms_count = self.shared_terms.len();
396 }
397
398 pub fn enqueue_equality(&mut self, lhs: TermId, rhs: TermId, source: TheoryId) {
400 self.pending_equalities.push_back(EqualityProp {
401 lhs,
402 rhs,
403 source,
404 explanation: vec![],
405 });
406 }
407
408 pub fn backtrack(&mut self, level: usize) -> Result<(), String> {
410 self.current_level = level;
411
412 for solver in self.theories.values_mut() {
413 solver.backtrack(level)?;
414 }
415
416 self.pending_equalities.clear();
418 self.clear_above_level(level as u32);
419 self.propagated_equalities_log.clear();
420 #[cfg(feature = "std")]
421 {
422 self.last_certificate = None;
423 }
424
425 Ok(())
426 }
427
428 pub fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
430 let mut combined_model = FxHashMap::default();
431
432 for solver in self.theories.values() {
433 if let Some(model) = solver.get_model() {
434 combined_model.extend(model);
435 } else {
436 return None;
437 }
438 }
439
440 Some(combined_model)
441 }
442
443 pub fn get_conflict(&self) -> Option<Vec<TermId>> {
445 let mut combined_conflict = Vec::new();
447
448 for solver in self.theories.values() {
449 if let Some(conflict) = solver.get_conflict() {
450 combined_conflict.extend(conflict);
451 }
452 }
453
454 if combined_conflict.is_empty() {
455 None
456 } else {
457 if self.config.minimize_conflicts {
459 Some(self.minimize_conflict(combined_conflict))
460 } else {
461 Some(combined_conflict)
462 }
463 }
464 }
465
466 fn minimize_conflict(&self, mut conflict: Vec<TermId>) -> Vec<TermId> {
468 conflict.sort();
471 conflict.dedup();
472 conflict
473 }
474
475 pub fn stats(&self) -> &CoordinatorStats {
477 &self.stats
478 }
479
480 pub fn current_level(&self) -> usize {
482 self.current_level
483 }
484
485 #[cfg(feature = "std")]
487 pub fn proof_certificate(&self) -> Option<&NelsonOppenCertificate> {
488 self.last_certificate.as_ref()
489 }
490
491 pub fn increment_level(&mut self) {
493 self.current_level += 1;
494 }
495
496 fn maybe_record_certificate_from_log(&mut self) {
497 #[cfg(feature = "std")]
498 {
499 if !self.config.proof_mode {
500 return;
501 }
502
503 self.last_certificate = self.build_certificate_from_log();
504 }
505 }
506
507 fn cached_theory_propagation(
508 &mut self,
509 theory_id: TheoryId,
510 ) -> Result<Vec<EqualityProp>, String> {
511 let level = self.current_level as u32;
512 let key = (theory_id, level);
513
514 if let Some(cached) = self.theory_propagation_cache.get(&key) {
515 return Ok(cached.clone());
516 }
517
518 let solver = self
519 .theories
520 .get(&theory_id)
521 .ok_or_else(|| format!("Theory {:?} not registered", theory_id))?;
522
523 let propagated: Vec<EqualityProp> = solver
524 .get_implied_equalities()
525 .into_iter()
526 .map(|(lhs, rhs)| EqualityProp {
527 lhs,
528 rhs,
529 source: theory_id,
530 explanation: vec![],
531 })
532 .collect();
533
534 self.theory_propagation_cache
535 .insert(key, propagated.clone());
536
537 Ok(propagated)
538 }
539
540 fn clear_above_level(&mut self, level: u32) {
541 self.theory_propagation_cache
542 .retain(|(_, cached_level), _| *cached_level <= level);
543 }
544
545 fn clear_from_level(&mut self, level: u32) {
546 self.theory_propagation_cache
547 .retain(|(_, cached_level), _| *cached_level < level);
548 }
549
550 #[cfg(feature = "std")]
551 fn build_certificate_from_log(&self) -> Option<NelsonOppenCertificate> {
552 let last_eq = self.propagated_equalities_log.last()?;
553 let mut certificate =
554 NelsonOppenCertificate::new(self.to_proof_theory_id(last_eq.source), ProofNodeId(0));
555
556 for eq in &self.propagated_equalities_log {
557 let lhs = Self::to_proof_term_id(eq.lhs)?;
558 let rhs = Self::to_proof_term_id(eq.rhs)?;
559 certificate.add_step(CombinationStep {
560 theory: self.to_proof_theory_id(eq.source),
561 propagated_equalities: vec![(lhs, rhs)],
562 justification: Vec::new(),
563 });
564 }
565
566 Some(certificate)
567 }
568
569 #[cfg(feature = "std")]
570 fn to_proof_term_id(term: TermId) -> Option<ProofTermId> {
571 let raw = u32::try_from(term).ok()?;
572 Some(ProofTermId::new(raw))
573 }
574
575 #[cfg(feature = "std")]
576 const fn to_proof_theory_id(&self, theory: TheoryId) -> CombinationTheoryId {
577 let raw = match theory {
578 TheoryId::Core => 0,
579 TheoryId::Arithmetic => 1,
580 TheoryId::BitVector => 2,
581 TheoryId::Array => 3,
582 TheoryId::Datatype => 4,
583 TheoryId::String => 5,
584 TheoryId::Uninterpreted => 6,
585 };
586 CombinationTheoryId(raw)
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double whose answers are fixed at construction time.
    struct MockTheory {
        id: TheoryId,
        sat_result: SatResult,
        implied_equalities: Vec<(TermId, TermId)>,
    }

    impl TheorySolver for MockTheory {
        fn theory_id(&self) -> TheoryId {
            self.id
        }

        fn assert_formula(&mut self, _formula: TermId) -> Result<(), String> {
            Ok(())
        }

        fn check_sat(&mut self) -> Result<SatResult, String> {
            Ok(self.sat_result)
        }

        fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
            Some(FxHashMap::default())
        }

        fn get_conflict(&self) -> Option<Vec<TermId>> {
            None
        }

        fn backtrack(&mut self, _level: usize) -> Result<(), String> {
            Ok(())
        }

        fn get_implied_equalities(&self) -> Vec<(TermId, TermId)> {
            self.implied_equalities.clone()
        }

        fn notify_equality(&mut self, _lhs: TermId, _rhs: TermId) -> Result<(), String> {
            Ok(())
        }
    }

    /// Builds a satisfiable arithmetic mock reporting `implied_equalities`.
    fn arithmetic_mock(implied_equalities: Vec<(TermId, TermId)>) -> Box<dyn TheorySolver> {
        Box::new(MockTheory {
            id: TheoryId::Arithmetic,
            sat_result: SatResult::Sat,
            implied_equalities,
        })
    }

    /// Builds a coordinator with the default configuration.
    fn default_coordinator() -> TheoryCoordinator {
        TheoryCoordinator::new(CoordinatorConfig::default())
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = default_coordinator();
        assert_eq!(coordinator.stats.check_sat_calls, 0);
    }

    #[test]
    fn test_register_theory() {
        let mut coordinator = default_coordinator();

        coordinator.register_theory(arithmetic_mock(Vec::new()));

        assert!(coordinator.theories.contains_key(&TheoryId::Arithmetic));
    }

    #[test]
    fn test_check_sat_single_theory() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(arithmetic_mock(Vec::new()));

        let result = coordinator.check_sat();

        assert!(result.is_ok());
        assert_eq!(
            result.expect("test operation should succeed"),
            SatResult::Sat
        );
        assert_eq!(coordinator.stats.check_sat_calls, 1);
    }

    #[test]
    fn test_shared_term_management() {
        let mut coordinator = default_coordinator();

        // The same term used by two theories is one shared term.
        for theory in [TheoryId::Arithmetic, TheoryId::BitVector] {
            coordinator.add_shared_term(1, theory);
        }

        assert!(coordinator.is_shared_term(1));
        assert_eq!(coordinator.stats.shared_terms_count, 1);
    }

    #[test]
    fn test_equivalence_classes() {
        let mut coordinator = default_coordinator();
        for term in [1, 2] {
            coordinator.add_shared_term(term, TheoryId::Arithmetic);
        }

        coordinator
            .merge_equivalence_classes(1, 2)
            .expect("test operation should succeed");

        assert_eq!(
            coordinator.find_representative(1),
            coordinator.find_representative(2)
        );
    }

    #[test]
    fn test_equality_propagation() {
        let mut coordinator = default_coordinator();

        coordinator.enqueue_equality(1, 2, TheoryId::Arithmetic);

        assert_eq!(coordinator.pending_equalities.len(), 1);
    }

    #[test]
    fn test_backtrack() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(arithmetic_mock(Vec::new()));

        coordinator.increment_level();
        coordinator.increment_level();
        assert_eq!(coordinator.current_level(), 2);

        coordinator
            .backtrack(0)
            .expect("test operation should succeed");
        assert_eq!(coordinator.current_level(), 0);
    }

    #[test]
    fn test_get_model() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(arithmetic_mock(Vec::new()));

        assert!(coordinator.get_model().is_some());
    }

    #[test]
    fn test_conflict_minimization() {
        let coordinator = TheoryCoordinator::new(CoordinatorConfig {
            minimize_conflicts: true,
            ..Default::default()
        });

        let minimized = coordinator.minimize_conflict(vec![1, 2, 2, 3, 1, 4]);

        assert_eq!(minimized, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_theory_propagation_cache_clears_on_backtrack() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(arithmetic_mock(vec![(1, 2)]));

        // Level 0: the first fetch fills one cache slot.
        let at_level_zero = coordinator
            .cached_theory_propagation(TheoryId::Arithmetic)
            .expect("initial cache fill should succeed");
        assert_eq!(at_level_zero.len(), 1);
        assert_eq!(coordinator.theory_propagation_cache.len(), 1);

        // Level 1: a second slot is added for the new level.
        coordinator.increment_level();
        let at_level_one = coordinator
            .cached_theory_propagation(TheoryId::Arithmetic)
            .expect("level-one cache fill should succeed");
        assert_eq!(at_level_one.len(), 1);
        assert_eq!(coordinator.theory_propagation_cache.len(), 2);

        // Back at level 0 only the level-0 slot survives.
        coordinator
            .backtrack(0)
            .expect("backtrack should clear higher-level cache entries");
        assert_eq!(coordinator.theory_propagation_cache.len(), 1);
    }
}