1#![allow(missing_docs)] use rustc_hash::{FxHashMap, FxHashSet};
12use std::collections::VecDeque;
13
/// Identifier used throughout the coordinator to refer to a term or formula
/// (e.g. the argument of [`TheorySolver::assert_formula`]).
pub type TermId = usize;
16
/// Identifies a theory solver.
///
/// Used as the key under which a [`TheorySolver`] is registered with the
/// coordinator, and to record which theories mention a shared term.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TheoryId {
    Core,
    Arithmetic,
    BitVector,
    Array,
    Datatype,
    String,
    Uninterpreted,
}
28
/// Interface every theory solver plugged into the coordinator must implement.
///
/// All fallible methods report failures as a human-readable `String`, which the
/// coordinator propagates unchanged to its caller.
pub trait TheorySolver {
    /// Identifier under which the coordinator registers this solver.
    fn theory_id(&self) -> TheoryId;

    /// Adds `formula` to this theory's assertions.
    fn assert_formula(&mut self, formula: TermId) -> Result<(), String>;

    /// Checks this theory's current state. The coordinator calls it again after
    /// propagating equalities during theory combination.
    fn check_sat(&mut self) -> Result<SatResult, String>;

    /// Returns this theory's model as a term-to-term map, or `None` when no model
    /// is available. The coordinator merges these maps across all theories.
    fn get_model(&self) -> Option<FxHashMap<TermId, TermId>>;

    /// Returns the terms involved in the current conflict, if any. The coordinator
    /// concatenates these across all theories.
    fn get_conflict(&self) -> Option<Vec<TermId>>;

    /// Backtracks this theory to decision `level` (presumably retracting state
    /// created above it — confirm per implementation).
    fn backtrack(&mut self, level: usize) -> Result<(), String>;

    /// Equalities between terms that this theory currently implies. In eager
    /// combination mode the coordinator polls this every round and only keeps
    /// pairs that involve a shared term.
    fn get_implied_equalities(&self) -> Vec<(TermId, TermId)>;

    /// Informs this theory that `lhs = rhs` was derived by another theory.
    fn notify_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String>;
}
55
/// Verdict of a satisfiability check, from a single theory or the combination.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SatResult {
    /// The assertions are satisfiable.
    Sat,
    /// The assertions are contradictory.
    Unsat,
    /// The check was inconclusive (e.g. a theory gave up, or eager combination
    /// hit its round limit).
    Unknown,
}
63
/// Entry of the coordinator's shared-term table, which doubles as a
/// union-find forest for equivalence classes of terms.
#[derive(Debug, Clone)]
pub struct SharedTerm {
    /// The term itself (identical to its key in the table).
    pub term: TermId,
    /// Theories that registered this term; it only counts as *shared* when two
    /// or more theories are listed.
    pub theories: FxHashSet<TheoryId>,
    /// Parent link in the equivalence-class forest; equal to `term` when this
    /// term is the root (representative) of its class.
    pub representative: TermId,
}
74
/// An equality `lhs = rhs` to be propagated between theories.
#[derive(Debug, Clone)]
pub struct EqualityProp {
    /// Left-hand side of the equality.
    pub lhs: TermId,
    /// Right-hand side of the equality.
    pub rhs: TermId,
    /// Theory that produced the equality; it is not re-notified of it.
    pub source: TheoryId,
    /// Justification terms. NOTE(review): every construction site in the
    /// visible code leaves this empty.
    pub explanation: Vec<TermId>,
}
87
/// Counters describing the coordinator's work so far.
#[derive(Debug, Clone, Default)]
pub struct CoordinatorStats {
    /// Number of calls to the coordinator's `check_sat`.
    pub check_sat_calls: u64,
    /// Number of times a theory answered `Unsat` during coordination.
    pub theory_conflicts: u64,
    /// Number of equalities pushed through equality propagation.
    pub equalities_propagated: u64,
    /// Number of shared terms registered with the coordinator.
    pub shared_terms_count: usize,
    /// Number of eager theory-combination rounds counted.
    pub theory_combination_rounds: u64,
}
97
/// Tuning knobs for [`TheoryCoordinator`].
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// `true` selects eager combination (poll implied equalities each round);
    /// `false` selects lazy combination (drain explicitly enqueued equalities).
    pub eager_combination: bool,
    /// Maximum number of eager combination rounds before answering `Unknown`.
    pub max_combination_rounds: usize,
    /// Whether `get_conflict` post-processes the combined conflict (sort + dedup).
    pub minimize_conflicts: bool,
}
108
109impl Default for CoordinatorConfig {
110 fn default() -> Self {
111 Self {
112 eager_combination: false,
113 max_combination_rounds: 10,
114 minimize_conflicts: true,
115 }
116 }
117}
118
/// Coordinates several theory solvers: forwards assertions, runs their
/// satisfiability checks, and propagates equalities over shared terms between
/// them.
pub struct TheoryCoordinator {
    /// Behavior switches (eager vs. lazy combination, round cap, minimization).
    config: CoordinatorConfig,
    /// Counters exposed through `stats()`.
    stats: CoordinatorStats,
    /// Registered solvers, keyed by the id each one reports.
    theories: FxHashMap<TheoryId, Box<dyn TheorySolver>>,
    /// Shared-term table; its `representative` links also form the
    /// equivalence-class forest used when merging equalities.
    shared_terms: FxHashMap<TermId, SharedTerm>,
    /// Equalities queued by `enqueue_equality`, drained by lazy combination and
    /// cleared on backtrack.
    pending_equalities: VecDeque<EqualityProp>,
    /// Decision level maintained by `increment_level` / `backtrack`.
    current_level: usize,
}
132
133impl TheoryCoordinator {
134 pub fn new(config: CoordinatorConfig) -> Self {
136 Self {
137 config,
138 stats: CoordinatorStats::default(),
139 theories: FxHashMap::default(),
140 shared_terms: FxHashMap::default(),
141 pending_equalities: VecDeque::new(),
142 current_level: 0,
143 }
144 }
145
146 pub fn register_theory(&mut self, theory: Box<dyn TheorySolver>) {
148 let theory_id = theory.theory_id();
149 self.theories.insert(theory_id, theory);
150 }
151
152 pub fn assert_formula(&mut self, formula: TermId, theory: TheoryId) -> Result<(), String> {
154 if let Some(solver) = self.theories.get_mut(&theory) {
155 solver.assert_formula(formula)?;
156
157 self.identify_shared_terms(formula)?;
159 } else {
160 return Err(format!("Theory {:?} not registered", theory));
161 }
162
163 Ok(())
164 }
165
166 pub fn check_sat(&mut self) -> Result<SatResult, String> {
168 self.stats.check_sat_calls += 1;
169
170 for solver in self.theories.values_mut() {
172 let result = solver.check_sat()?;
173
174 match result {
175 SatResult::Unsat => {
176 self.stats.theory_conflicts += 1;
177 return Ok(SatResult::Unsat);
178 }
179 SatResult::Unknown => {
180 return Ok(SatResult::Unknown);
181 }
182 SatResult::Sat => {
183 }
185 }
186 }
187
188 if self.config.eager_combination {
190 self.eager_theory_combination()
191 } else {
192 self.lazy_theory_combination()
193 }
194 }
195
196 fn eager_theory_combination(&mut self) -> Result<SatResult, String> {
198 let mut iteration = 0;
199
200 loop {
201 self.stats.theory_combination_rounds += 1;
202 iteration += 1;
203
204 if iteration > self.config.max_combination_rounds {
205 return Ok(SatResult::Unknown);
206 }
207
208 let mut new_equalities = Vec::new();
210
211 for (theory_id, solver) in &self.theories {
212 let equalities = solver.get_implied_equalities();
213
214 for (lhs, rhs) in equalities {
215 if self.is_shared_term(lhs) || self.is_shared_term(rhs) {
217 new_equalities.push(EqualityProp {
218 lhs,
219 rhs,
220 source: *theory_id,
221 explanation: vec![],
222 });
223 }
224 }
225 }
226
227 if new_equalities.is_empty() {
229 return Ok(SatResult::Sat);
230 }
231
232 for eq in new_equalities {
234 self.propagate_equality(eq)?;
235 }
236
237 for solver in self.theories.values_mut() {
239 match solver.check_sat()? {
240 SatResult::Unsat => {
241 self.stats.theory_conflicts += 1;
242 return Ok(SatResult::Unsat);
243 }
244 SatResult::Unknown => {
245 return Ok(SatResult::Unknown);
246 }
247 SatResult::Sat => {}
248 }
249 }
250 }
251 }
252
253 fn lazy_theory_combination(&mut self) -> Result<SatResult, String> {
255 while let Some(eq) = self.pending_equalities.pop_front() {
257 self.propagate_equality(eq)?;
258
259 for solver in self.theories.values_mut() {
261 match solver.check_sat()? {
262 SatResult::Unsat => {
263 self.stats.theory_conflicts += 1;
264 return Ok(SatResult::Unsat);
265 }
266 SatResult::Unknown => {
267 return Ok(SatResult::Unknown);
268 }
269 SatResult::Sat => {}
270 }
271 }
272 }
273
274 Ok(SatResult::Sat)
275 }
276
277 fn propagate_equality(&mut self, eq: EqualityProp) -> Result<(), String> {
279 self.stats.equalities_propagated += 1;
280
281 self.merge_equivalence_classes(eq.lhs, eq.rhs)?;
283
284 let theories_to_notify = self.get_theories_for_terms(eq.lhs, eq.rhs);
286
287 for theory_id in theories_to_notify {
288 if theory_id != eq.source
289 && let Some(solver) = self.theories.get_mut(&theory_id)
290 {
291 solver.notify_equality(eq.lhs, eq.rhs)?;
292 }
293 }
294
295 Ok(())
296 }
297
298 fn identify_shared_terms(&mut self, _formula: TermId) -> Result<(), String> {
300 self.stats.shared_terms_count = self.shared_terms.len();
303 Ok(())
304 }
305
306 fn is_shared_term(&self, term: TermId) -> bool {
308 self.shared_terms
309 .get(&term)
310 .is_some_and(|st| st.theories.len() > 1)
311 }
312
313 fn get_theories_for_terms(&self, lhs: TermId, rhs: TermId) -> FxHashSet<TheoryId> {
315 let mut theories = FxHashSet::default();
316
317 if let Some(st) = self.shared_terms.get(&lhs) {
318 theories.extend(&st.theories);
319 }
320
321 if let Some(st) = self.shared_terms.get(&rhs) {
322 theories.extend(&st.theories);
323 }
324
325 theories
326 }
327
328 fn merge_equivalence_classes(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
330 let lhs_rep = self.find_representative(lhs);
332 let rhs_rep = self.find_representative(rhs);
333
334 if lhs_rep == rhs_rep {
335 return Ok(());
336 }
337
338 if let Some(st) = self.shared_terms.get_mut(&lhs_rep) {
340 st.representative = rhs_rep;
341 }
342
343 Ok(())
344 }
345
346 fn find_representative(&self, term: TermId) -> TermId {
348 if let Some(st) = self.shared_terms.get(&term)
349 && st.representative != term
350 {
351 return self.find_representative(st.representative);
353 }
354 term
355 }
356
357 pub fn add_shared_term(&mut self, term: TermId, theory: TheoryId) {
359 self.shared_terms
360 .entry(term)
361 .or_insert_with(|| SharedTerm {
362 term,
363 theories: FxHashSet::default(),
364 representative: term,
365 })
366 .theories
367 .insert(theory);
368
369 self.stats.shared_terms_count = self.shared_terms.len();
370 }
371
372 pub fn enqueue_equality(&mut self, lhs: TermId, rhs: TermId, source: TheoryId) {
374 self.pending_equalities.push_back(EqualityProp {
375 lhs,
376 rhs,
377 source,
378 explanation: vec![],
379 });
380 }
381
382 pub fn backtrack(&mut self, level: usize) -> Result<(), String> {
384 self.current_level = level;
385
386 for solver in self.theories.values_mut() {
387 solver.backtrack(level)?;
388 }
389
390 self.pending_equalities.clear();
392
393 Ok(())
394 }
395
396 pub fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
398 let mut combined_model = FxHashMap::default();
399
400 for solver in self.theories.values() {
401 if let Some(model) = solver.get_model() {
402 combined_model.extend(model);
403 } else {
404 return None;
405 }
406 }
407
408 Some(combined_model)
409 }
410
411 pub fn get_conflict(&self) -> Option<Vec<TermId>> {
413 let mut combined_conflict = Vec::new();
415
416 for solver in self.theories.values() {
417 if let Some(conflict) = solver.get_conflict() {
418 combined_conflict.extend(conflict);
419 }
420 }
421
422 if combined_conflict.is_empty() {
423 None
424 } else {
425 if self.config.minimize_conflicts {
427 Some(self.minimize_conflict(combined_conflict))
428 } else {
429 Some(combined_conflict)
430 }
431 }
432 }
433
434 fn minimize_conflict(&self, mut conflict: Vec<TermId>) -> Vec<TermId> {
436 conflict.sort();
439 conflict.dedup();
440 conflict
441 }
442
443 pub fn stats(&self) -> &CoordinatorStats {
445 &self.stats
446 }
447
448 pub fn current_level(&self) -> usize {
450 self.current_level
451 }
452
453 pub fn increment_level(&mut self) {
455 self.current_level += 1;
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Theory stub whose `check_sat` always answers `sat_result`. Every other
    /// query succeeds trivially: empty model, no conflict, no implied equalities.
    struct MockTheory {
        id: TheoryId,
        sat_result: SatResult,
    }

    impl TheorySolver for MockTheory {
        fn theory_id(&self) -> TheoryId {
            self.id
        }

        fn assert_formula(&mut self, _formula: TermId) -> Result<(), String> {
            Ok(())
        }

        fn check_sat(&mut self) -> Result<SatResult, String> {
            Ok(self.sat_result)
        }

        fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
            Some(FxHashMap::default())
        }

        fn get_conflict(&self) -> Option<Vec<TermId>> {
            None
        }

        fn backtrack(&mut self, _level: usize) -> Result<(), String> {
            Ok(())
        }

        fn get_implied_equalities(&self) -> Vec<(TermId, TermId)> {
            Vec::new()
        }

        fn notify_equality(&mut self, _lhs: TermId, _rhs: TermId) -> Result<(), String> {
            Ok(())
        }
    }

    /// Coordinator built from the default configuration, with nothing registered.
    fn fresh_coordinator() -> TheoryCoordinator {
        TheoryCoordinator::new(CoordinatorConfig::default())
    }

    /// Coordinator holding a single arithmetic stub that always answers `Sat`.
    fn coordinator_with_sat_arithmetic() -> TheoryCoordinator {
        let mut coordinator = fresh_coordinator();
        coordinator.register_theory(Box::new(MockTheory {
            id: TheoryId::Arithmetic,
            sat_result: SatResult::Sat,
        }));
        coordinator
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = fresh_coordinator();
        assert_eq!(coordinator.stats.check_sat_calls, 0);
    }

    #[test]
    fn test_register_theory() {
        let coordinator = coordinator_with_sat_arithmetic();
        assert!(coordinator.theories.contains_key(&TheoryId::Arithmetic));
    }

    #[test]
    fn test_check_sat_single_theory() {
        let mut coordinator = coordinator_with_sat_arithmetic();
        assert_eq!(coordinator.check_sat(), Ok(SatResult::Sat));
        assert_eq!(coordinator.stats.check_sat_calls, 1);
    }

    #[test]
    fn test_shared_term_management() {
        let mut coordinator = fresh_coordinator();
        for theory in [TheoryId::Arithmetic, TheoryId::BitVector] {
            coordinator.add_shared_term(1, theory);
        }
        assert!(coordinator.is_shared_term(1));
        assert_eq!(coordinator.stats.shared_terms_count, 1);
    }

    #[test]
    fn test_equivalence_classes() {
        let mut coordinator = fresh_coordinator();
        coordinator.add_shared_term(1, TheoryId::Arithmetic);
        coordinator.add_shared_term(2, TheoryId::Arithmetic);

        coordinator
            .merge_equivalence_classes(1, 2)
            .expect("merging two registered terms should succeed");

        assert_eq!(
            coordinator.find_representative(1),
            coordinator.find_representative(2)
        );
    }

    #[test]
    fn test_equality_propagation() {
        let mut coordinator = fresh_coordinator();
        coordinator.enqueue_equality(1, 2, TheoryId::Arithmetic);
        assert_eq!(coordinator.pending_equalities.len(), 1);
    }

    #[test]
    fn test_backtrack() {
        let mut coordinator = coordinator_with_sat_arithmetic();
        (0..2).for_each(|_| coordinator.increment_level());
        assert_eq!(coordinator.current_level(), 2);

        coordinator
            .backtrack(0)
            .expect("mock theories never fail to backtrack");
        assert_eq!(coordinator.current_level(), 0);
    }

    #[test]
    fn test_get_model() {
        let coordinator = coordinator_with_sat_arithmetic();
        assert!(coordinator.get_model().is_some());
    }

    #[test]
    fn test_conflict_minimization() {
        let coordinator = TheoryCoordinator::new(CoordinatorConfig {
            minimize_conflicts: true,
            ..Default::default()
        });

        let minimized = coordinator.minimize_conflict(vec![1, 2, 2, 3, 1, 4]);
        assert_eq!(minimized, vec![1, 2, 3, 4]);
    }
}