Skip to main content

oxiz_solver/combination/
coordinator.rs

1//! Theory Combination Coordinator
2#![allow(missing_docs)] // Under development
3//!
4//! This module coordinates multiple theory solvers using the Nelson-Oppen method
5//! with optimizations:
6//! - Lazy vs eager theory combination
7//! - Shared term management
8//! - Equality sharing between theories
9//! - Conflict minimization across theories
10
11use rustc_hash::{FxHashMap, FxHashSet};
12use std::collections::VecDeque;
13
/// Numeric handle used to refer to a term.
///
/// Currently just an alias for `usize`; it stands in until a dedicated term type exists.
pub type TermId = usize;
16
/// Names a theory; each registered solver is keyed by the id it reports.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TheoryId {
    /// Core theory.
    Core,
    /// Arithmetic theory.
    Arithmetic,
    /// Bit-vector theory.
    BitVector,
    /// Array theory.
    Array,
    /// Datatype theory.
    Datatype,
    /// String theory.
    String,
    /// Uninterpreted functions / sorts.
    Uninterpreted,
}
28
29/// Theory interface
30pub trait TheorySolver {
31    /// Get theory ID
32    fn theory_id(&self) -> TheoryId;
33
34    /// Assert a formula
35    fn assert_formula(&mut self, formula: TermId) -> Result<(), String>;
36
37    /// Check satisfiability
38    fn check_sat(&mut self) -> Result<SatResult, String>;
39
40    /// Get model (if SAT)
41    fn get_model(&self) -> Option<FxHashMap<TermId, TermId>>;
42
43    /// Get conflict explanation (if UNSAT)
44    fn get_conflict(&self) -> Option<Vec<TermId>>;
45
46    /// Backtrack to a level
47    fn backtrack(&mut self, level: usize) -> Result<(), String>;
48
49    /// Get implied equalities
50    fn get_implied_equalities(&self) -> Vec<(TermId, TermId)>;
51
52    /// Notify of external equality
53    fn notify_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String>;
54}
55
/// Verdict of a satisfiability check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SatResult {
    /// The assertions are satisfiable.
    Sat,
    /// The assertions are unsatisfiable.
    Unsat,
    /// The solver could not decide.
    Unknown,
}
63
64/// Shared term between theories
65#[derive(Debug, Clone)]
66pub struct SharedTerm {
67    /// The term
68    pub term: TermId,
69    /// Theories that use this term
70    pub theories: FxHashSet<TheoryId>,
71    /// Current equivalence class representative
72    pub representative: TermId,
73}
74
75/// Equality propagation item
76#[derive(Debug, Clone)]
77pub struct EqualityProp {
78    /// Left-hand side
79    pub lhs: TermId,
80    /// Right-hand side
81    pub rhs: TermId,
82    /// Source theory
83    pub source: TheoryId,
84    /// Explanation (justification)
85    pub explanation: Vec<TermId>,
86}
87
/// Counters describing the work done by a [`TheoryCoordinator`].
#[derive(Clone, Debug, Default)]
pub struct CoordinatorStats {
    /// Number of calls to the coordinator's `check_sat`.
    pub check_sat_calls: u64,
    /// Number of times some theory reported `Unsat`.
    pub theory_conflicts: u64,
    /// Number of equalities pushed through propagation.
    pub equalities_propagated: u64,
    /// Number of entries currently tracked in the shared-term table.
    pub shared_terms_count: usize,
    /// Number of eager combination rounds executed.
    pub theory_combination_rounds: u64,
}
97
/// Tuning options for a [`TheoryCoordinator`].
///
/// The default is lazy combination, at most 10 combination rounds, and conflict
/// minimization enabled.
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// `true` selects eager combination (exchange all implied equalities immediately);
    /// `false` selects lazy, on-demand propagation.
    pub eager_combination: bool,
    /// Upper bound on eager combination rounds before giving up with `Unknown`.
    pub max_combination_rounds: usize,
    /// Whether combined conflict explanations are minimized before being returned.
    pub minimize_conflicts: bool,
}

impl Default for CoordinatorConfig {
    fn default() -> Self {
        CoordinatorConfig {
            minimize_conflicts: true,
            max_combination_rounds: 10,
            eager_combination: false,
        }
    }
}
118
119/// Theory combination coordinator
120pub struct TheoryCoordinator {
121    config: CoordinatorConfig,
122    stats: CoordinatorStats,
123    /// Registered theory solvers
124    theories: FxHashMap<TheoryId, Box<dyn TheorySolver>>,
125    /// Shared terms between theories
126    shared_terms: FxHashMap<TermId, SharedTerm>,
127    /// Pending equality propagations
128    pending_equalities: VecDeque<EqualityProp>,
129    /// Current decision level
130    current_level: usize,
131}
132
133impl TheoryCoordinator {
134    /// Create a new theory coordinator
135    pub fn new(config: CoordinatorConfig) -> Self {
136        Self {
137            config,
138            stats: CoordinatorStats::default(),
139            theories: FxHashMap::default(),
140            shared_terms: FxHashMap::default(),
141            pending_equalities: VecDeque::new(),
142            current_level: 0,
143        }
144    }
145
146    /// Register a theory solver
147    pub fn register_theory(&mut self, theory: Box<dyn TheorySolver>) {
148        let theory_id = theory.theory_id();
149        self.theories.insert(theory_id, theory);
150    }
151
152    /// Assert a formula to the appropriate theory
153    pub fn assert_formula(&mut self, formula: TermId, theory: TheoryId) -> Result<(), String> {
154        if let Some(solver) = self.theories.get_mut(&theory) {
155            solver.assert_formula(formula)?;
156
157            // Identify shared terms
158            self.identify_shared_terms(formula)?;
159        } else {
160            return Err(format!("Theory {:?} not registered", theory));
161        }
162
163        Ok(())
164    }
165
166    /// Check satisfiability with theory combination
167    pub fn check_sat(&mut self) -> Result<SatResult, String> {
168        self.stats.check_sat_calls += 1;
169
170        // Phase 1: Check individual theories
171        for solver in self.theories.values_mut() {
172            let result = solver.check_sat()?;
173
174            match result {
175                SatResult::Unsat => {
176                    self.stats.theory_conflicts += 1;
177                    return Ok(SatResult::Unsat);
178                }
179                SatResult::Unknown => {
180                    return Ok(SatResult::Unknown);
181                }
182                SatResult::Sat => {
183                    // Continue to next theory
184                }
185            }
186        }
187
188        // Phase 2: Theory combination via equality sharing
189        if self.config.eager_combination {
190            self.eager_theory_combination()
191        } else {
192            self.lazy_theory_combination()
193        }
194    }
195
196    /// Eager theory combination: propagate all equalities immediately
197    fn eager_theory_combination(&mut self) -> Result<SatResult, String> {
198        let mut iteration = 0;
199
200        loop {
201            self.stats.theory_combination_rounds += 1;
202            iteration += 1;
203
204            if iteration > self.config.max_combination_rounds {
205                return Ok(SatResult::Unknown);
206            }
207
208            // Collect implied equalities from all theories
209            let mut new_equalities = Vec::new();
210
211            for (theory_id, solver) in &self.theories {
212                let equalities = solver.get_implied_equalities();
213
214                for (lhs, rhs) in equalities {
215                    // Only propagate equalities between shared terms
216                    if self.is_shared_term(lhs) || self.is_shared_term(rhs) {
217                        new_equalities.push(EqualityProp {
218                            lhs,
219                            rhs,
220                            source: *theory_id,
221                            explanation: vec![],
222                        });
223                    }
224                }
225            }
226
227            // No new equalities: fixed point reached
228            if new_equalities.is_empty() {
229                return Ok(SatResult::Sat);
230            }
231
232            // Propagate equalities to all theories
233            for eq in new_equalities {
234                self.propagate_equality(eq)?;
235            }
236
237            // Re-check theories for conflicts
238            for solver in self.theories.values_mut() {
239                match solver.check_sat()? {
240                    SatResult::Unsat => {
241                        self.stats.theory_conflicts += 1;
242                        return Ok(SatResult::Unsat);
243                    }
244                    SatResult::Unknown => {
245                        return Ok(SatResult::Unknown);
246                    }
247                    SatResult::Sat => {}
248                }
249            }
250        }
251    }
252
253    /// Lazy theory combination: propagate equalities on-demand
254    fn lazy_theory_combination(&mut self) -> Result<SatResult, String> {
255        // Process pending equalities
256        while let Some(eq) = self.pending_equalities.pop_front() {
257            self.propagate_equality(eq)?;
258
259            // Check for conflicts after each propagation
260            for solver in self.theories.values_mut() {
261                match solver.check_sat()? {
262                    SatResult::Unsat => {
263                        self.stats.theory_conflicts += 1;
264                        return Ok(SatResult::Unsat);
265                    }
266                    SatResult::Unknown => {
267                        return Ok(SatResult::Unknown);
268                    }
269                    SatResult::Sat => {}
270                }
271            }
272        }
273
274        Ok(SatResult::Sat)
275    }
276
277    /// Propagate an equality to all relevant theories
278    fn propagate_equality(&mut self, eq: EqualityProp) -> Result<(), String> {
279        self.stats.equalities_propagated += 1;
280
281        // Update equivalence classes
282        self.merge_equivalence_classes(eq.lhs, eq.rhs)?;
283
284        // Notify all theories that use these terms
285        let theories_to_notify = self.get_theories_for_terms(eq.lhs, eq.rhs);
286
287        for theory_id in theories_to_notify {
288            if theory_id != eq.source
289                && let Some(solver) = self.theories.get_mut(&theory_id)
290            {
291                solver.notify_equality(eq.lhs, eq.rhs)?;
292            }
293        }
294
295        Ok(())
296    }
297
298    /// Identify shared terms in a formula
299    fn identify_shared_terms(&mut self, _formula: TermId) -> Result<(), String> {
300        // Placeholder: would traverse formula AST and identify terms used by multiple theories
301        // For now, just update stats
302        self.stats.shared_terms_count = self.shared_terms.len();
303        Ok(())
304    }
305
306    /// Check if a term is shared between theories
307    fn is_shared_term(&self, term: TermId) -> bool {
308        self.shared_terms
309            .get(&term)
310            .is_some_and(|st| st.theories.len() > 1)
311    }
312
313    /// Get theories that use given terms
314    fn get_theories_for_terms(&self, lhs: TermId, rhs: TermId) -> FxHashSet<TheoryId> {
315        let mut theories = FxHashSet::default();
316
317        if let Some(st) = self.shared_terms.get(&lhs) {
318            theories.extend(&st.theories);
319        }
320
321        if let Some(st) = self.shared_terms.get(&rhs) {
322            theories.extend(&st.theories);
323        }
324
325        theories
326    }
327
328    /// Merge equivalence classes for two terms
329    fn merge_equivalence_classes(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
330        // Get representatives
331        let lhs_rep = self.find_representative(lhs);
332        let rhs_rep = self.find_representative(rhs);
333
334        if lhs_rep == rhs_rep {
335            return Ok(());
336        }
337
338        // Union: make lhs_rep point to rhs_rep
339        if let Some(st) = self.shared_terms.get_mut(&lhs_rep) {
340            st.representative = rhs_rep;
341        }
342
343        Ok(())
344    }
345
346    /// Find equivalence class representative
347    fn find_representative(&self, term: TermId) -> TermId {
348        if let Some(st) = self.shared_terms.get(&term)
349            && st.representative != term
350        {
351            // Path compression would be applied here
352            return self.find_representative(st.representative);
353        }
354        term
355    }
356
357    /// Add a shared term
358    pub fn add_shared_term(&mut self, term: TermId, theory: TheoryId) {
359        self.shared_terms
360            .entry(term)
361            .or_insert_with(|| SharedTerm {
362                term,
363                theories: FxHashSet::default(),
364                representative: term,
365            })
366            .theories
367            .insert(theory);
368
369        self.stats.shared_terms_count = self.shared_terms.len();
370    }
371
372    /// Enqueue an equality for propagation
373    pub fn enqueue_equality(&mut self, lhs: TermId, rhs: TermId, source: TheoryId) {
374        self.pending_equalities.push_back(EqualityProp {
375            lhs,
376            rhs,
377            source,
378            explanation: vec![],
379        });
380    }
381
382    /// Backtrack all theories to a level
383    pub fn backtrack(&mut self, level: usize) -> Result<(), String> {
384        self.current_level = level;
385
386        for solver in self.theories.values_mut() {
387            solver.backtrack(level)?;
388        }
389
390        // Clear pending equalities
391        self.pending_equalities.clear();
392
393        Ok(())
394    }
395
396    /// Get combined model from all theories
397    pub fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
398        let mut combined_model = FxHashMap::default();
399
400        for solver in self.theories.values() {
401            if let Some(model) = solver.get_model() {
402                combined_model.extend(model);
403            } else {
404                return None;
405            }
406        }
407
408        Some(combined_model)
409    }
410
411    /// Get combined conflict explanation
412    pub fn get_conflict(&self) -> Option<Vec<TermId>> {
413        // Collect conflicts from all theories
414        let mut combined_conflict = Vec::new();
415
416        for solver in self.theories.values() {
417            if let Some(conflict) = solver.get_conflict() {
418                combined_conflict.extend(conflict);
419            }
420        }
421
422        if combined_conflict.is_empty() {
423            None
424        } else {
425            // Minimize if enabled
426            if self.config.minimize_conflicts {
427                Some(self.minimize_conflict(combined_conflict))
428            } else {
429                Some(combined_conflict)
430            }
431        }
432    }
433
434    /// Minimize a conflict explanation
435    fn minimize_conflict(&self, mut conflict: Vec<TermId>) -> Vec<TermId> {
436        // Placeholder: would use resolution to minimize
437        // For now, just remove duplicates
438        conflict.sort();
439        conflict.dedup();
440        conflict
441    }
442
443    /// Get statistics
444    pub fn stats(&self) -> &CoordinatorStats {
445        &self.stats
446    }
447
448    /// Get current decision level
449    pub fn current_level(&self) -> usize {
450        self.current_level
451    }
452
453    /// Increment decision level
454    pub fn increment_level(&mut self) {
455        self.current_level += 1;
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub solver that accepts every call and always answers `sat_result`.
    struct MockTheory {
        id: TheoryId,
        sat_result: SatResult,
    }

    impl TheorySolver for MockTheory {
        fn theory_id(&self) -> TheoryId {
            self.id
        }

        fn assert_formula(&mut self, _formula: TermId) -> Result<(), String> {
            Ok(())
        }

        fn check_sat(&mut self) -> Result<SatResult, String> {
            Ok(self.sat_result)
        }

        fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
            Some(FxHashMap::default())
        }

        fn get_conflict(&self) -> Option<Vec<TermId>> {
            None
        }

        fn backtrack(&mut self, _level: usize) -> Result<(), String> {
            Ok(())
        }

        fn get_implied_equalities(&self) -> Vec<(TermId, TermId)> {
            Vec::new()
        }

        fn notify_equality(&mut self, _lhs: TermId, _rhs: TermId) -> Result<(), String> {
            Ok(())
        }
    }

    /// A coordinator using the default configuration.
    fn default_coordinator() -> TheoryCoordinator {
        TheoryCoordinator::new(CoordinatorConfig::default())
    }

    /// An arithmetic stub that always answers `Sat`.
    fn sat_arithmetic() -> Box<dyn TheorySolver> {
        Box::new(MockTheory {
            id: TheoryId::Arithmetic,
            sat_result: SatResult::Sat,
        })
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = default_coordinator();
        assert_eq!(coordinator.stats().check_sat_calls, 0);
    }

    #[test]
    fn test_register_theory() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());
        assert!(coordinator.theories.contains_key(&TheoryId::Arithmetic));
    }

    #[test]
    fn test_check_sat_single_theory() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());

        assert_eq!(coordinator.check_sat(), Ok(SatResult::Sat));
        assert_eq!(coordinator.stats().check_sat_calls, 1);
    }

    #[test]
    fn test_shared_term_management() {
        let mut coordinator = default_coordinator();
        for theory in [TheoryId::Arithmetic, TheoryId::BitVector] {
            coordinator.add_shared_term(1, theory);
        }

        assert!(coordinator.is_shared_term(1));
        assert_eq!(coordinator.stats().shared_terms_count, 1);
    }

    #[test]
    fn test_equivalence_classes() {
        let mut coordinator = default_coordinator();
        for term in [1, 2] {
            coordinator.add_shared_term(term, TheoryId::Arithmetic);
        }

        coordinator.merge_equivalence_classes(1, 2).unwrap();

        assert_eq!(
            coordinator.find_representative(1),
            coordinator.find_representative(2)
        );
    }

    #[test]
    fn test_equality_propagation() {
        let mut coordinator = default_coordinator();
        coordinator.enqueue_equality(1, 2, TheoryId::Arithmetic);
        assert_eq!(coordinator.pending_equalities.len(), 1);
    }

    #[test]
    fn test_backtrack() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());

        (0..2).for_each(|_| coordinator.increment_level());
        assert_eq!(coordinator.current_level(), 2);

        coordinator.backtrack(0).unwrap();
        assert_eq!(coordinator.current_level(), 0);
    }

    #[test]
    fn test_get_model() {
        let mut coordinator = default_coordinator();
        coordinator.register_theory(sat_arithmetic());

        assert!(coordinator.get_model().is_some());
    }

    #[test]
    fn test_conflict_minimization() {
        let coordinator = TheoryCoordinator::new(CoordinatorConfig {
            minimize_conflicts: true,
            ..Default::default()
        });

        let minimized = coordinator.minimize_conflict(vec![1, 2, 2, 3, 1, 4]);
        assert_eq!(minimized, vec![1, 2, 3, 4]);
    }
}