1#![allow(missing_docs)] #[allow(unused_imports)]
13use crate::prelude::*;
14
/// Identifier used throughout this module to refer to a term.
///
/// It is a plain `usize`; this file never interprets the value beyond
/// comparing, hashing and sorting it.
pub type TermId = usize;
17
/// Tag naming one theory.
///
/// The coordinator uses it as the key under which a solver is registered and
/// as the marker recording which theories mention a shared term.
/// NOTE(review): the precise scope of each variant is not defined in this
/// file — confirm against the individual theory solvers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TheoryId {
    Core,
    Arithmetic,
    BitVector,
    Array,
    Datatype,
    String,
    Uninterpreted,
}
29
/// Interface a theory-specific solver exposes to `TheoryCoordinator`.
///
/// Terms are referred to by [`TermId`]; every fallible operation reports
/// failure as a plain `String`.
pub trait TheorySolver {
    /// Returns the theory this solver handles.
    ///
    /// The coordinator uses this value as the registration key.
    fn theory_id(&self) -> TheoryId;

    /// Hands `formula` to the solver as a new assertion.
    fn assert_formula(&mut self, formula: TermId) -> Result<(), String>;

    /// Checks the solver's current state and reports its verdict.
    fn check_sat(&mut self) -> Result<SatResult, String>;

    /// Returns the solver's model as a term-to-term map, or `None` when it has
    /// none to offer.
    ///
    /// NOTE(review): presumably each key is a term and each value the term
    /// denoting its value — confirm with implementors.
    fn get_model(&self) -> Option<FxHashMap<TermId, TermId>>;

    /// Returns the terms of the solver's current conflict, or `None` when it
    /// reports no conflict.
    fn get_conflict(&self) -> Option<Vec<TermId>>;

    /// Asks the solver to backtrack to `level`.
    fn backtrack(&mut self, level: usize) -> Result<(), String>;

    /// Returns pairs of terms the solver reports as equal.
    ///
    /// In eager combination mode the coordinator forwards the pairs that touch
    /// a shared term to the other theories.
    fn get_implied_equalities(&self) -> Vec<(TermId, TermId)>;

    /// Informs the solver that `lhs` and `rhs` are equal, as propagated by the
    /// coordinator on behalf of another theory.
    fn notify_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String>;
}
56
/// Verdict of a satisfiability check, returned by individual solvers and by
/// the coordinator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SatResult {
    /// No solver objected.
    Sat,
    /// A solver reported unsatisfiability; the coordinator counts this as a
    /// theory conflict.
    Unsat,
    /// No definite answer (a solver could not decide, or the eager combination
    /// loop hit its round limit).
    Unknown,
}
64
/// Bookkeeping record for a term tracked by the coordinator.
#[derive(Debug, Clone)]
pub struct SharedTerm {
    /// The term this record describes (the same id it is stored under).
    pub term: TermId,
    /// Theories that have been registered as using the term; the term counts
    /// as shared once more than one theory is present.
    pub theories: FxHashSet<TheoryId>,
    /// Parent link of the equivalence-class structure. A record whose
    /// representative equals its own term is the root of its class.
    pub representative: TermId,
}
75
/// An equality between two terms waiting to be, or being, propagated between
/// theories.
#[derive(Debug, Clone)]
pub struct EqualityProp {
    /// Left-hand term of the equality.
    pub lhs: TermId,
    /// Right-hand term of the equality.
    pub rhs: TermId,
    /// Theory the equality originated from; it is not notified again when the
    /// equality is propagated.
    pub source: TheoryId,
    /// Terms justifying the equality. Every producer in this file leaves it
    /// empty.
    pub explanation: Vec<TermId>,
}
88
/// Counters collected by the coordinator; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct CoordinatorStats {
    /// Number of calls to the coordinator's `check_sat`.
    pub check_sat_calls: u64,
    /// Number of times a solver answered `Unsat` during a coordinator check.
    pub theory_conflicts: u64,
    /// Number of equalities pushed through equality propagation.
    pub equalities_propagated: u64,
    /// Size of the shared-term table as of its last update.
    pub shared_terms_count: usize,
    /// Number of eager combination rounds started.
    pub theory_combination_rounds: u64,
}
98
/// Tunable behaviour of the coordinator.
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// When `true`, `check_sat` runs the eager combination loop (solvers are
    /// polled for implied equalities); otherwise only queued equalities are
    /// propagated.
    pub eager_combination: bool,
    /// Upper bound on eager combination rounds before the result is reported
    /// as `Unknown`.
    pub max_combination_rounds: usize,
    /// When `true`, combined conflicts are sorted and de-duplicated before
    /// being returned.
    pub minimize_conflicts: bool,
}
109
110impl Default for CoordinatorConfig {
111 fn default() -> Self {
112 Self {
113 eager_combination: false,
114 max_combination_rounds: 10,
115 minimize_conflicts: true,
116 }
117 }
118}
119
/// Owns a set of theory solvers and coordinates satisfiability checks and
/// equality exchange between them.
pub struct TheoryCoordinator {
    /// Behaviour switches supplied at construction time.
    config: CoordinatorConfig,
    /// Running counters, exposed read-only through `stats()`.
    stats: CoordinatorStats,
    /// Registered solvers, keyed by the theory id each one reports.
    theories: FxHashMap<TheoryId, Box<dyn TheorySolver>>,
    /// Tracked terms together with the theories using them and their
    /// equivalence-class links.
    shared_terms: FxHashMap<TermId, SharedTerm>,
    /// Equalities queued for propagation, processed front to back.
    pending_equalities: VecDeque<EqualityProp>,
    /// Current decision level as maintained by `increment_level` and
    /// `backtrack`.
    current_level: usize,
}
133
134impl TheoryCoordinator {
135 pub fn new(config: CoordinatorConfig) -> Self {
137 Self {
138 config,
139 stats: CoordinatorStats::default(),
140 theories: FxHashMap::default(),
141 shared_terms: FxHashMap::default(),
142 pending_equalities: VecDeque::new(),
143 current_level: 0,
144 }
145 }
146
147 pub fn register_theory(&mut self, theory: Box<dyn TheorySolver>) {
149 let theory_id = theory.theory_id();
150 self.theories.insert(theory_id, theory);
151 }
152
153 pub fn assert_formula(&mut self, formula: TermId, theory: TheoryId) -> Result<(), String> {
155 if let Some(solver) = self.theories.get_mut(&theory) {
156 solver.assert_formula(formula)?;
157
158 self.identify_shared_terms(formula)?;
160 } else {
161 return Err(format!("Theory {:?} not registered", theory));
162 }
163
164 Ok(())
165 }
166
167 pub fn check_sat(&mut self) -> Result<SatResult, String> {
169 self.stats.check_sat_calls += 1;
170
171 for solver in self.theories.values_mut() {
173 let result = solver.check_sat()?;
174
175 match result {
176 SatResult::Unsat => {
177 self.stats.theory_conflicts += 1;
178 return Ok(SatResult::Unsat);
179 }
180 SatResult::Unknown => {
181 return Ok(SatResult::Unknown);
182 }
183 SatResult::Sat => {
184 }
186 }
187 }
188
189 if self.config.eager_combination {
191 self.eager_theory_combination()
192 } else {
193 self.lazy_theory_combination()
194 }
195 }
196
197 fn eager_theory_combination(&mut self) -> Result<SatResult, String> {
199 let mut iteration = 0;
200
201 loop {
202 self.stats.theory_combination_rounds += 1;
203 iteration += 1;
204
205 if iteration > self.config.max_combination_rounds {
206 return Ok(SatResult::Unknown);
207 }
208
209 let mut new_equalities = Vec::new();
211
212 for (theory_id, solver) in &self.theories {
213 let equalities = solver.get_implied_equalities();
214
215 for (lhs, rhs) in equalities {
216 if self.is_shared_term(lhs) || self.is_shared_term(rhs) {
218 new_equalities.push(EqualityProp {
219 lhs,
220 rhs,
221 source: *theory_id,
222 explanation: vec![],
223 });
224 }
225 }
226 }
227
228 if new_equalities.is_empty() {
230 return Ok(SatResult::Sat);
231 }
232
233 for eq in new_equalities {
235 self.propagate_equality(eq)?;
236 }
237
238 for solver in self.theories.values_mut() {
240 match solver.check_sat()? {
241 SatResult::Unsat => {
242 self.stats.theory_conflicts += 1;
243 return Ok(SatResult::Unsat);
244 }
245 SatResult::Unknown => {
246 return Ok(SatResult::Unknown);
247 }
248 SatResult::Sat => {}
249 }
250 }
251 }
252 }
253
254 fn lazy_theory_combination(&mut self) -> Result<SatResult, String> {
256 while let Some(eq) = self.pending_equalities.pop_front() {
258 self.propagate_equality(eq)?;
259
260 for solver in self.theories.values_mut() {
262 match solver.check_sat()? {
263 SatResult::Unsat => {
264 self.stats.theory_conflicts += 1;
265 return Ok(SatResult::Unsat);
266 }
267 SatResult::Unknown => {
268 return Ok(SatResult::Unknown);
269 }
270 SatResult::Sat => {}
271 }
272 }
273 }
274
275 Ok(SatResult::Sat)
276 }
277
278 fn propagate_equality(&mut self, eq: EqualityProp) -> Result<(), String> {
280 self.stats.equalities_propagated += 1;
281
282 self.merge_equivalence_classes(eq.lhs, eq.rhs)?;
284
285 let theories_to_notify = self.get_theories_for_terms(eq.lhs, eq.rhs);
287
288 for theory_id in theories_to_notify {
289 if theory_id != eq.source
290 && let Some(solver) = self.theories.get_mut(&theory_id)
291 {
292 solver.notify_equality(eq.lhs, eq.rhs)?;
293 }
294 }
295
296 Ok(())
297 }
298
299 fn identify_shared_terms(&mut self, _formula: TermId) -> Result<(), String> {
301 self.stats.shared_terms_count = self.shared_terms.len();
304 Ok(())
305 }
306
307 fn is_shared_term(&self, term: TermId) -> bool {
309 self.shared_terms
310 .get(&term)
311 .is_some_and(|st| st.theories.len() > 1)
312 }
313
314 fn get_theories_for_terms(&self, lhs: TermId, rhs: TermId) -> FxHashSet<TheoryId> {
316 let mut theories = FxHashSet::default();
317
318 if let Some(st) = self.shared_terms.get(&lhs) {
319 theories.extend(&st.theories);
320 }
321
322 if let Some(st) = self.shared_terms.get(&rhs) {
323 theories.extend(&st.theories);
324 }
325
326 theories
327 }
328
329 fn merge_equivalence_classes(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
331 let lhs_rep = self.find_representative(lhs);
333 let rhs_rep = self.find_representative(rhs);
334
335 if lhs_rep == rhs_rep {
336 return Ok(());
337 }
338
339 if let Some(st) = self.shared_terms.get_mut(&lhs_rep) {
341 st.representative = rhs_rep;
342 }
343
344 Ok(())
345 }
346
347 fn find_representative(&self, term: TermId) -> TermId {
349 if let Some(st) = self.shared_terms.get(&term)
350 && st.representative != term
351 {
352 return self.find_representative(st.representative);
354 }
355 term
356 }
357
358 pub fn add_shared_term(&mut self, term: TermId, theory: TheoryId) {
360 self.shared_terms
361 .entry(term)
362 .or_insert_with(|| SharedTerm {
363 term,
364 theories: FxHashSet::default(),
365 representative: term,
366 })
367 .theories
368 .insert(theory);
369
370 self.stats.shared_terms_count = self.shared_terms.len();
371 }
372
373 pub fn enqueue_equality(&mut self, lhs: TermId, rhs: TermId, source: TheoryId) {
375 self.pending_equalities.push_back(EqualityProp {
376 lhs,
377 rhs,
378 source,
379 explanation: vec![],
380 });
381 }
382
383 pub fn backtrack(&mut self, level: usize) -> Result<(), String> {
385 self.current_level = level;
386
387 for solver in self.theories.values_mut() {
388 solver.backtrack(level)?;
389 }
390
391 self.pending_equalities.clear();
393
394 Ok(())
395 }
396
397 pub fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
399 let mut combined_model = FxHashMap::default();
400
401 for solver in self.theories.values() {
402 if let Some(model) = solver.get_model() {
403 combined_model.extend(model);
404 } else {
405 return None;
406 }
407 }
408
409 Some(combined_model)
410 }
411
412 pub fn get_conflict(&self) -> Option<Vec<TermId>> {
414 let mut combined_conflict = Vec::new();
416
417 for solver in self.theories.values() {
418 if let Some(conflict) = solver.get_conflict() {
419 combined_conflict.extend(conflict);
420 }
421 }
422
423 if combined_conflict.is_empty() {
424 None
425 } else {
426 if self.config.minimize_conflicts {
428 Some(self.minimize_conflict(combined_conflict))
429 } else {
430 Some(combined_conflict)
431 }
432 }
433 }
434
435 fn minimize_conflict(&self, mut conflict: Vec<TermId>) -> Vec<TermId> {
437 conflict.sort();
440 conflict.dedup();
441 conflict
442 }
443
444 pub fn stats(&self) -> &CoordinatorStats {
446 &self.stats
447 }
448
449 pub fn current_level(&self) -> usize {
451 self.current_level
452 }
453
454 pub fn increment_level(&mut self) {
456 self.current_level += 1;
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Solver stub that always answers with a fixed verdict and otherwise
    /// does nothing.
    struct MockTheory {
        id: TheoryId,
        sat_result: SatResult,
    }

    impl TheorySolver for MockTheory {
        fn theory_id(&self) -> TheoryId {
            self.id
        }

        fn assert_formula(&mut self, _formula: TermId) -> Result<(), String> {
            Ok(())
        }

        fn check_sat(&mut self) -> Result<SatResult, String> {
            Ok(self.sat_result)
        }

        fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
            Some(FxHashMap::default())
        }

        fn get_conflict(&self) -> Option<Vec<TermId>> {
            None
        }

        fn backtrack(&mut self, _level: usize) -> Result<(), String> {
            Ok(())
        }

        fn get_implied_equalities(&self) -> Vec<(TermId, TermId)> {
            vec![]
        }

        fn notify_equality(&mut self, _lhs: TermId, _rhs: TermId) -> Result<(), String> {
            Ok(())
        }
    }

    /// Coordinator built from the default configuration, without theories.
    fn default_coordinator() -> TheoryCoordinator {
        TheoryCoordinator::new(CoordinatorConfig::default())
    }

    /// Default coordinator with a single always-`Sat` arithmetic mock
    /// registered.
    fn coordinator_with_arithmetic() -> TheoryCoordinator {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(Box::new(MockTheory {
            id: TheoryId::Arithmetic,
            sat_result: SatResult::Sat,
        }));
        coordinator
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = default_coordinator();
        assert_eq!(coordinator.stats.check_sat_calls, 0);
    }

    #[test]
    fn test_register_theory() {
        let coordinator = coordinator_with_arithmetic();
        assert!(coordinator.theories.contains_key(&TheoryId::Arithmetic));
    }

    #[test]
    fn test_check_sat_single_theory() {
        let mut coordinator = coordinator_with_arithmetic();

        let result = coordinator.check_sat();
        assert!(result.is_ok());

        let verdict = result.expect("test operation should succeed");
        assert_eq!(verdict, SatResult::Sat);
        assert_eq!(coordinator.stats.check_sat_calls, 1);
    }

    #[test]
    fn test_shared_term_management() {
        let mut coordinator = default_coordinator();

        for theory in [TheoryId::Arithmetic, TheoryId::BitVector].iter() {
            coordinator.add_shared_term(1, *theory);
        }

        assert!(coordinator.is_shared_term(1));
        assert_eq!(coordinator.stats.shared_terms_count, 1);
    }

    #[test]
    fn test_equivalence_classes() {
        let mut coordinator = default_coordinator();

        for term in 1..=2 {
            coordinator.add_shared_term(term, TheoryId::Arithmetic);
        }

        coordinator
            .merge_equivalence_classes(1, 2)
            .expect("test operation should succeed");

        assert_eq!(
            coordinator.find_representative(1),
            coordinator.find_representative(2)
        );
    }

    #[test]
    fn test_equality_propagation() {
        let mut coordinator = default_coordinator();

        coordinator.enqueue_equality(1, 2, TheoryId::Arithmetic);
        assert_eq!(coordinator.pending_equalities.len(), 1);
    }

    #[test]
    fn test_backtrack() {
        let mut coordinator = coordinator_with_arithmetic();

        // Two levels up, then all the way back down.
        coordinator.increment_level();
        coordinator.increment_level();
        assert_eq!(coordinator.current_level(), 2);

        coordinator
            .backtrack(0)
            .expect("test operation should succeed");
        assert_eq!(coordinator.current_level(), 0);
    }

    #[test]
    fn test_get_model() {
        let coordinator = coordinator_with_arithmetic();
        assert!(coordinator.get_model().is_some());
    }

    #[test]
    fn test_conflict_minimization() {
        let config = CoordinatorConfig {
            minimize_conflicts: true,
            ..Default::default()
        };
        let coordinator = TheoryCoordinator::new(config);

        let minimized = coordinator.minimize_conflict(vec![1, 2, 2, 3, 1, 4]);

        assert_eq!(minimized, vec![1, 2, 3, 4]);
    }
}