Skip to main content

oxiz_solver/combination/
coordinator.rs

//! Theory Combination Coordinator
//!
//! This module coordinates multiple theory solvers using the Nelson-Oppen method,
//! with the following optimizations:
//! - Lazy vs. eager theory combination
//! - Shared term management
//! - Equality sharing between theories
//! - Conflict minimization across theories
9
10#![allow(missing_docs)] // Under development
11
12#[allow(unused_imports)]
13use crate::prelude::*;
14
/// Identifier of a term.
///
/// Placeholder alias: terms are plain indices until a real term
/// representation is wired in.
pub type TermId = usize;
17
/// Identifies which theory a solver decides.
///
/// Registered solvers are keyed by this value, so at most one solver per
/// theory can be active at a time.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TheoryId {
    /// Core (propositional / equality) theory.
    Core,
    /// Arithmetic.
    Arithmetic,
    /// Fixed-width bit-vectors.
    BitVector,
    /// Arrays.
    Array,
    /// Algebraic datatypes.
    Datatype,
    /// Strings.
    String,
    /// Uninterpreted functions.
    Uninterpreted,
}
29
30/// Theory interface
31pub trait TheorySolver {
32    /// Get theory ID
33    fn theory_id(&self) -> TheoryId;
34
35    /// Assert a formula
36    fn assert_formula(&mut self, formula: TermId) -> Result<(), String>;
37
38    /// Check satisfiability
39    fn check_sat(&mut self) -> Result<SatResult, String>;
40
41    /// Get model (if SAT)
42    fn get_model(&self) -> Option<FxHashMap<TermId, TermId>>;
43
44    /// Get conflict explanation (if UNSAT)
45    fn get_conflict(&self) -> Option<Vec<TermId>>;
46
47    /// Backtrack to a level
48    fn backtrack(&mut self, level: usize) -> Result<(), String>;
49
50    /// Get implied equalities
51    fn get_implied_equalities(&self) -> Vec<(TermId, TermId)>;
52
53    /// Notify of external equality
54    fn notify_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String>;
55}
56
/// Outcome of a satisfiability check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SatResult {
    /// A satisfying assignment exists.
    Sat,
    /// No satisfying assignment exists.
    Unsat,
    /// The check was inconclusive.
    Unknown,
}
64
65/// Shared term between theories
66#[derive(Debug, Clone)]
67pub struct SharedTerm {
68    /// The term
69    pub term: TermId,
70    /// Theories that use this term
71    pub theories: FxHashSet<TheoryId>,
72    /// Current equivalence class representative
73    pub representative: TermId,
74}
75
76/// Equality propagation item
77#[derive(Debug, Clone)]
78pub struct EqualityProp {
79    /// Left-hand side
80    pub lhs: TermId,
81    /// Right-hand side
82    pub rhs: TermId,
83    /// Source theory
84    pub source: TheoryId,
85    /// Explanation (justification)
86    pub explanation: Vec<TermId>,
87}
88
/// Counters collected by the theory-combination coordinator.
#[derive(Clone, Debug, Default)]
pub struct CoordinatorStats {
    /// Number of top-level `check_sat` calls on the coordinator.
    pub check_sat_calls: u64,
    /// Number of times a theory answered `Unsat` during a check.
    pub theory_conflicts: u64,
    /// Number of equalities propagated between theories.
    pub equalities_propagated: u64,
    /// Number of terms registered as shared via `add_shared_term`.
    pub shared_terms_count: usize,
    /// Number of eager combination rounds started.
    pub theory_combination_rounds: u64,
}
98
/// Tuning knobs for theory combination.
#[derive(Clone, Debug)]
pub struct CoordinatorConfig {
    /// `true`: exchange theory-implied equalities eagerly, round by round.
    /// `false`: propagate only explicitly queued equalities (lazy mode).
    pub eager_combination: bool,
    /// Upper bound on eager combination rounds before giving up with `Unknown`.
    pub max_combination_rounds: usize,
    /// Post-process combined conflict explanations before returning them.
    pub minimize_conflicts: bool,
}

impl Default for CoordinatorConfig {
    /// Lazy combination, at most 10 combination rounds, conflict
    /// minimization enabled.
    fn default() -> Self {
        let eager_combination = false;
        let max_combination_rounds = 10;
        let minimize_conflicts = true;
        Self {
            eager_combination,
            max_combination_rounds,
            minimize_conflicts,
        }
    }
}
119
120/// Theory combination coordinator
121pub struct TheoryCoordinator {
122    config: CoordinatorConfig,
123    stats: CoordinatorStats,
124    /// Registered theory solvers
125    theories: FxHashMap<TheoryId, Box<dyn TheorySolver>>,
126    /// Shared terms between theories
127    shared_terms: FxHashMap<TermId, SharedTerm>,
128    /// Pending equality propagations
129    pending_equalities: VecDeque<EqualityProp>,
130    /// Current decision level
131    current_level: usize,
132}
133
134impl TheoryCoordinator {
135    /// Create a new theory coordinator
136    pub fn new(config: CoordinatorConfig) -> Self {
137        Self {
138            config,
139            stats: CoordinatorStats::default(),
140            theories: FxHashMap::default(),
141            shared_terms: FxHashMap::default(),
142            pending_equalities: VecDeque::new(),
143            current_level: 0,
144        }
145    }
146
147    /// Register a theory solver
148    pub fn register_theory(&mut self, theory: Box<dyn TheorySolver>) {
149        let theory_id = theory.theory_id();
150        self.theories.insert(theory_id, theory);
151    }
152
153    /// Assert a formula to the appropriate theory
154    pub fn assert_formula(&mut self, formula: TermId, theory: TheoryId) -> Result<(), String> {
155        if let Some(solver) = self.theories.get_mut(&theory) {
156            solver.assert_formula(formula)?;
157
158            // Identify shared terms
159            self.identify_shared_terms(formula)?;
160        } else {
161            return Err(format!("Theory {:?} not registered", theory));
162        }
163
164        Ok(())
165    }
166
167    /// Check satisfiability with theory combination
168    pub fn check_sat(&mut self) -> Result<SatResult, String> {
169        self.stats.check_sat_calls += 1;
170
171        // Phase 1: Check individual theories
172        for solver in self.theories.values_mut() {
173            let result = solver.check_sat()?;
174
175            match result {
176                SatResult::Unsat => {
177                    self.stats.theory_conflicts += 1;
178                    return Ok(SatResult::Unsat);
179                }
180                SatResult::Unknown => {
181                    return Ok(SatResult::Unknown);
182                }
183                SatResult::Sat => {
184                    // Continue to next theory
185                }
186            }
187        }
188
189        // Phase 2: Theory combination via equality sharing
190        if self.config.eager_combination {
191            self.eager_theory_combination()
192        } else {
193            self.lazy_theory_combination()
194        }
195    }
196
197    /// Eager theory combination: propagate all equalities immediately
198    fn eager_theory_combination(&mut self) -> Result<SatResult, String> {
199        let mut iteration = 0;
200
201        loop {
202            self.stats.theory_combination_rounds += 1;
203            iteration += 1;
204
205            if iteration > self.config.max_combination_rounds {
206                return Ok(SatResult::Unknown);
207            }
208
209            // Collect implied equalities from all theories
210            let mut new_equalities = Vec::new();
211
212            for (theory_id, solver) in &self.theories {
213                let equalities = solver.get_implied_equalities();
214
215                for (lhs, rhs) in equalities {
216                    // Only propagate equalities between shared terms
217                    if self.is_shared_term(lhs) || self.is_shared_term(rhs) {
218                        new_equalities.push(EqualityProp {
219                            lhs,
220                            rhs,
221                            source: *theory_id,
222                            explanation: vec![],
223                        });
224                    }
225                }
226            }
227
228            // No new equalities: fixed point reached
229            if new_equalities.is_empty() {
230                return Ok(SatResult::Sat);
231            }
232
233            // Propagate equalities to all theories
234            for eq in new_equalities {
235                self.propagate_equality(eq)?;
236            }
237
238            // Re-check theories for conflicts
239            for solver in self.theories.values_mut() {
240                match solver.check_sat()? {
241                    SatResult::Unsat => {
242                        self.stats.theory_conflicts += 1;
243                        return Ok(SatResult::Unsat);
244                    }
245                    SatResult::Unknown => {
246                        return Ok(SatResult::Unknown);
247                    }
248                    SatResult::Sat => {}
249                }
250            }
251        }
252    }
253
254    /// Lazy theory combination: propagate equalities on-demand
255    fn lazy_theory_combination(&mut self) -> Result<SatResult, String> {
256        // Process pending equalities
257        while let Some(eq) = self.pending_equalities.pop_front() {
258            self.propagate_equality(eq)?;
259
260            // Check for conflicts after each propagation
261            for solver in self.theories.values_mut() {
262                match solver.check_sat()? {
263                    SatResult::Unsat => {
264                        self.stats.theory_conflicts += 1;
265                        return Ok(SatResult::Unsat);
266                    }
267                    SatResult::Unknown => {
268                        return Ok(SatResult::Unknown);
269                    }
270                    SatResult::Sat => {}
271                }
272            }
273        }
274
275        Ok(SatResult::Sat)
276    }
277
278    /// Propagate an equality to all relevant theories
279    fn propagate_equality(&mut self, eq: EqualityProp) -> Result<(), String> {
280        self.stats.equalities_propagated += 1;
281
282        // Update equivalence classes
283        self.merge_equivalence_classes(eq.lhs, eq.rhs)?;
284
285        // Notify all theories that use these terms
286        let theories_to_notify = self.get_theories_for_terms(eq.lhs, eq.rhs);
287
288        for theory_id in theories_to_notify {
289            if theory_id != eq.source
290                && let Some(solver) = self.theories.get_mut(&theory_id)
291            {
292                solver.notify_equality(eq.lhs, eq.rhs)?;
293            }
294        }
295
296        Ok(())
297    }
298
299    /// Identify shared terms in a formula
300    fn identify_shared_terms(&mut self, _formula: TermId) -> Result<(), String> {
301        // Placeholder: would traverse formula AST and identify terms used by multiple theories
302        // For now, just update stats
303        self.stats.shared_terms_count = self.shared_terms.len();
304        Ok(())
305    }
306
307    /// Check if a term is shared between theories
308    fn is_shared_term(&self, term: TermId) -> bool {
309        self.shared_terms
310            .get(&term)
311            .is_some_and(|st| st.theories.len() > 1)
312    }
313
314    /// Get theories that use given terms
315    fn get_theories_for_terms(&self, lhs: TermId, rhs: TermId) -> FxHashSet<TheoryId> {
316        let mut theories = FxHashSet::default();
317
318        if let Some(st) = self.shared_terms.get(&lhs) {
319            theories.extend(&st.theories);
320        }
321
322        if let Some(st) = self.shared_terms.get(&rhs) {
323            theories.extend(&st.theories);
324        }
325
326        theories
327    }
328
329    /// Merge equivalence classes for two terms
330    fn merge_equivalence_classes(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
331        // Get representatives
332        let lhs_rep = self.find_representative(lhs);
333        let rhs_rep = self.find_representative(rhs);
334
335        if lhs_rep == rhs_rep {
336            return Ok(());
337        }
338
339        // Union: make lhs_rep point to rhs_rep
340        if let Some(st) = self.shared_terms.get_mut(&lhs_rep) {
341            st.representative = rhs_rep;
342        }
343
344        Ok(())
345    }
346
347    /// Find equivalence class representative
348    fn find_representative(&self, term: TermId) -> TermId {
349        if let Some(st) = self.shared_terms.get(&term)
350            && st.representative != term
351        {
352            // Path compression would be applied here
353            return self.find_representative(st.representative);
354        }
355        term
356    }
357
358    /// Add a shared term
359    pub fn add_shared_term(&mut self, term: TermId, theory: TheoryId) {
360        self.shared_terms
361            .entry(term)
362            .or_insert_with(|| SharedTerm {
363                term,
364                theories: FxHashSet::default(),
365                representative: term,
366            })
367            .theories
368            .insert(theory);
369
370        self.stats.shared_terms_count = self.shared_terms.len();
371    }
372
373    /// Enqueue an equality for propagation
374    pub fn enqueue_equality(&mut self, lhs: TermId, rhs: TermId, source: TheoryId) {
375        self.pending_equalities.push_back(EqualityProp {
376            lhs,
377            rhs,
378            source,
379            explanation: vec![],
380        });
381    }
382
383    /// Backtrack all theories to a level
384    pub fn backtrack(&mut self, level: usize) -> Result<(), String> {
385        self.current_level = level;
386
387        for solver in self.theories.values_mut() {
388            solver.backtrack(level)?;
389        }
390
391        // Clear pending equalities
392        self.pending_equalities.clear();
393
394        Ok(())
395    }
396
397    /// Get combined model from all theories
398    pub fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
399        let mut combined_model = FxHashMap::default();
400
401        for solver in self.theories.values() {
402            if let Some(model) = solver.get_model() {
403                combined_model.extend(model);
404            } else {
405                return None;
406            }
407        }
408
409        Some(combined_model)
410    }
411
412    /// Get combined conflict explanation
413    pub fn get_conflict(&self) -> Option<Vec<TermId>> {
414        // Collect conflicts from all theories
415        let mut combined_conflict = Vec::new();
416
417        for solver in self.theories.values() {
418            if let Some(conflict) = solver.get_conflict() {
419                combined_conflict.extend(conflict);
420            }
421        }
422
423        if combined_conflict.is_empty() {
424            None
425        } else {
426            // Minimize if enabled
427            if self.config.minimize_conflicts {
428                Some(self.minimize_conflict(combined_conflict))
429            } else {
430                Some(combined_conflict)
431            }
432        }
433    }
434
435    /// Minimize a conflict explanation
436    fn minimize_conflict(&self, mut conflict: Vec<TermId>) -> Vec<TermId> {
437        // Placeholder: would use resolution to minimize
438        // For now, just remove duplicates
439        conflict.sort();
440        conflict.dedup();
441        conflict
442    }
443
444    /// Get statistics
445    pub fn stats(&self) -> &CoordinatorStats {
446        &self.stats
447    }
448
449    /// Get current decision level
450    pub fn current_level(&self) -> usize {
451        self.current_level
452    }
453
454    /// Increment decision level
455    pub fn increment_level(&mut self) {
456        self.current_level += 1;
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `TheorySolver` double: every operation succeeds, and
    /// `check_sat` always answers with the configured `sat_result`.
    struct MockTheory {
        id: TheoryId,
        sat_result: SatResult,
    }

    impl TheorySolver for MockTheory {
        fn theory_id(&self) -> TheoryId {
            self.id
        }

        fn assert_formula(&mut self, _formula: TermId) -> Result<(), String> {
            Ok(())
        }

        fn check_sat(&mut self) -> Result<SatResult, String> {
            Ok(self.sat_result)
        }

        fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
            Some(FxHashMap::default())
        }

        fn get_conflict(&self) -> Option<Vec<TermId>> {
            None
        }

        fn backtrack(&mut self, _level: usize) -> Result<(), String> {
            Ok(())
        }

        fn get_implied_equalities(&self) -> Vec<(TermId, TermId)> {
            Vec::new()
        }

        fn notify_equality(&mut self, _lhs: TermId, _rhs: TermId) -> Result<(), String> {
            Ok(())
        }
    }

    /// Coordinator built with the default configuration.
    fn default_coordinator() -> TheoryCoordinator {
        TheoryCoordinator::new(CoordinatorConfig::default())
    }

    /// Boxed arithmetic mock that always answers `Sat`.
    fn sat_arithmetic() -> Box<MockTheory> {
        Box::new(MockTheory {
            id: TheoryId::Arithmetic,
            sat_result: SatResult::Sat,
        })
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = default_coordinator();
        assert_eq!(coordinator.stats.check_sat_calls, 0);
    }

    #[test]
    fn test_register_theory() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());
        assert!(coordinator.theories.contains_key(&TheoryId::Arithmetic));
    }

    #[test]
    fn test_check_sat_single_theory() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());

        let outcome = coordinator
            .check_sat()
            .expect("test operation should succeed");
        assert_eq!(outcome, SatResult::Sat);
        assert_eq!(coordinator.stats.check_sat_calls, 1);
    }

    #[test]
    fn test_shared_term_management() {
        let mut coordinator = default_coordinator();
        for theory in [TheoryId::Arithmetic, TheoryId::BitVector] {
            coordinator.add_shared_term(1, theory);
        }

        assert!(coordinator.is_shared_term(1));
        assert_eq!(coordinator.stats.shared_terms_count, 1);
    }

    #[test]
    fn test_equivalence_classes() {
        let mut coordinator = default_coordinator();
        for term in [1, 2] {
            coordinator.add_shared_term(term, TheoryId::Arithmetic);
        }

        coordinator
            .merge_equivalence_classes(1, 2)
            .expect("test operation should succeed");

        assert_eq!(
            coordinator.find_representative(1),
            coordinator.find_representative(2)
        );
    }

    #[test]
    fn test_equality_propagation() {
        let mut coordinator = default_coordinator();
        coordinator.enqueue_equality(1, 2, TheoryId::Arithmetic);
        assert_eq!(coordinator.pending_equalities.len(), 1);
    }

    #[test]
    fn test_backtrack() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());

        for _ in 0..2 {
            coordinator.increment_level();
        }
        assert_eq!(coordinator.current_level(), 2);

        coordinator
            .backtrack(0)
            .expect("test operation should succeed");
        assert_eq!(coordinator.current_level(), 0);
    }

    #[test]
    fn test_get_model() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());
        assert!(coordinator.get_model().is_some());
    }

    #[test]
    fn test_conflict_minimization() {
        let coordinator = TheoryCoordinator::new(CoordinatorConfig {
            minimize_conflicts: true,
            ..Default::default()
        });

        let minimized = coordinator.minimize_conflict(vec![1, 2, 2, 3, 1, 4]);
        assert_eq!(minimized, vec![1, 2, 3, 4]);
    }
}