#![allow(missing_docs)]
#[allow(unused_imports)]
use crate::prelude::*;
/// Identifier of a term. Formulas are also passed around as `TermId`s
/// (see `TheorySolver::assert_formula`).
pub type TermId = usize;
/// Identifies a theory solver.
///
/// Used as the key under which [`TheorySolver`]s are registered in a
/// [`TheoryCoordinator`], and to record which theories use a [`SharedTerm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TheoryId {
Core,
Arithmetic,
BitVector,
Array,
Datatype,
String,
Uninterpreted,
}
/// Interface a theory solver implements to be driven by a [`TheoryCoordinator`].
///
/// Fallible methods report failures as a `String` message, which the
/// coordinator propagates to its caller unchanged.
pub trait TheorySolver {
/// Theory this solver belongs to; read by `TheoryCoordinator::register_theory`
/// to pick the key the solver is stored under.
fn theory_id(&self) -> TheoryId;
/// Adds `formula` to this theory's assertions.
fn assert_formula(&mut self, formula: TermId) -> Result<(), String>;
/// Reports whether this theory's current state is satisfiable. Called by
/// `TheoryCoordinator::check_sat` and again after equalities are propagated.
fn check_sat(&mut self) -> Result<SatResult, String>;
/// Term-to-value assignment, or `None` if no model is available. A single
/// `None` makes the coordinator's combined model `None`.
fn get_model(&self) -> Option<FxHashMap<TermId, TermId>>;
/// Conflict explanation terms, if any; the coordinator concatenates these
/// across all theories in `TheoryCoordinator::get_conflict`.
fn get_conflict(&self) -> Option<Vec<TermId>>;
/// Called by `TheoryCoordinator::backtrack` with the target decision level.
/// NOTE(review): presumably restores the state at `level` — confirm per impl.
fn backtrack(&mut self, level: usize) -> Result<(), String>;
/// Equalities `(lhs, rhs)` this theory reports as implied; polled by the
/// coordinator during eager theory combination.
fn get_implied_equalities(&self) -> Vec<(TermId, TermId)>;
/// Informs this theory that `lhs = rhs` was propagated. The coordinator never
/// calls this on the theory that produced the equality.
fn notify_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String>;
}
/// Outcome of a satisfiability check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SatResult {
/// The checked assertions are satisfiable.
Sat,
/// The checked assertions are unsatisfiable.
Unsat,
/// The check was inconclusive (e.g. the combination round limit was hit).
Unknown,
}
/// Entry in the coordinator's shared-term table.
#[derive(Debug, Clone)]
pub struct SharedTerm {
/// The term this entry describes (the same value as its key in the table).
pub term: TermId,
/// Theories that registered this term; the term counts as shared when more
/// than one theory did.
pub theories: FxHashSet<TheoryId>,
/// Union-find parent link: the next term on the path to this term's
/// equivalence-class representative, or `term` itself when it is the root.
pub representative: TermId,
}
/// An equality `lhs = rhs` to be propagated between theories.
#[derive(Debug, Clone)]
pub struct EqualityProp {
/// Left-hand side of the equality.
pub lhs: TermId,
/// Right-hand side of the equality.
pub rhs: TermId,
/// Theory that produced the equality; it is not re-notified of it.
pub source: TheoryId,
/// Justification terms for the equality.
/// NOTE(review): every producer in this file leaves this empty and nothing
/// in this file reads it.
pub explanation: Vec<TermId>,
}
/// Counters collected by a [`TheoryCoordinator`].
#[derive(Debug, Clone, Default)]
pub struct CoordinatorStats {
/// Calls to `TheoryCoordinator::check_sat`.
pub check_sat_calls: u64,
/// Theory checks that ended in `Unsat`.
pub theory_conflicts: u64,
/// Equalities propagated between theories.
pub equalities_propagated: u64,
/// Number of terms registered via `TheoryCoordinator::add_shared_term`.
pub shared_terms_count: usize,
/// Iterations of the eager theory-combination loop.
pub theory_combination_rounds: u64,
}
/// Tuning knobs for a [`TheoryCoordinator`].
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
/// Exchange theory-implied equalities once every theory reports `Sat`.
/// Default `false`: only equalities queued with `enqueue_equality` are propagated.
pub eager_combination: bool,
/// Limit on eager combination rounds before `check_sat` answers `Unknown`
/// (default 10).
pub max_combination_rounds: usize,
/// Sort and deduplicate the combined conflict in `get_conflict` (default `true`).
pub minimize_conflicts: bool,
}
impl Default for CoordinatorConfig {
fn default() -> Self {
Self {
eager_combination: false,
max_combination_rounds: 10,
minimize_conflicts: true,
}
}
}
/// Combines several [`TheorySolver`]s: checks each one and exchanges
/// equalities over shared terms between them.
pub struct TheoryCoordinator {
config: CoordinatorConfig,
stats: CoordinatorStats,
/// Registered solvers, at most one per theory id.
theories: FxHashMap<TheoryId, Box<dyn TheorySolver>>,
/// Shared-term table keyed by term; entries also carry the union-find links
/// that record which terms have been merged into one equivalence class.
shared_terms: FxHashMap<TermId, SharedTerm>,
/// Equalities queued by `enqueue_equality`; cleared by `backtrack`.
pending_equalities: VecDeque<EqualityProp>,
/// Decision level, advanced by `increment_level` and set by `backtrack`.
current_level: usize,
}
impl TheoryCoordinator {
pub fn new(config: CoordinatorConfig) -> Self {
Self {
config,
stats: CoordinatorStats::default(),
theories: FxHashMap::default(),
shared_terms: FxHashMap::default(),
pending_equalities: VecDeque::new(),
current_level: 0,
}
}
pub fn register_theory(&mut self, theory: Box<dyn TheorySolver>) {
let theory_id = theory.theory_id();
self.theories.insert(theory_id, theory);
}
pub fn assert_formula(&mut self, formula: TermId, theory: TheoryId) -> Result<(), String> {
if let Some(solver) = self.theories.get_mut(&theory) {
solver.assert_formula(formula)?;
self.identify_shared_terms(formula)?;
} else {
return Err(format!("Theory {:?} not registered", theory));
}
Ok(())
}
pub fn check_sat(&mut self) -> Result<SatResult, String> {
self.stats.check_sat_calls += 1;
for solver in self.theories.values_mut() {
let result = solver.check_sat()?;
match result {
SatResult::Unsat => {
self.stats.theory_conflicts += 1;
return Ok(SatResult::Unsat);
}
SatResult::Unknown => {
return Ok(SatResult::Unknown);
}
SatResult::Sat => {
}
}
}
if self.config.eager_combination {
self.eager_theory_combination()
} else {
self.lazy_theory_combination()
}
}
fn eager_theory_combination(&mut self) -> Result<SatResult, String> {
let mut iteration = 0;
loop {
self.stats.theory_combination_rounds += 1;
iteration += 1;
if iteration > self.config.max_combination_rounds {
return Ok(SatResult::Unknown);
}
let mut new_equalities = Vec::new();
for (theory_id, solver) in &self.theories {
let equalities = solver.get_implied_equalities();
for (lhs, rhs) in equalities {
if self.is_shared_term(lhs) || self.is_shared_term(rhs) {
new_equalities.push(EqualityProp {
lhs,
rhs,
source: *theory_id,
explanation: vec![],
});
}
}
}
if new_equalities.is_empty() {
return Ok(SatResult::Sat);
}
for eq in new_equalities {
self.propagate_equality(eq)?;
}
for solver in self.theories.values_mut() {
match solver.check_sat()? {
SatResult::Unsat => {
self.stats.theory_conflicts += 1;
return Ok(SatResult::Unsat);
}
SatResult::Unknown => {
return Ok(SatResult::Unknown);
}
SatResult::Sat => {}
}
}
}
}
fn lazy_theory_combination(&mut self) -> Result<SatResult, String> {
while let Some(eq) = self.pending_equalities.pop_front() {
self.propagate_equality(eq)?;
for solver in self.theories.values_mut() {
match solver.check_sat()? {
SatResult::Unsat => {
self.stats.theory_conflicts += 1;
return Ok(SatResult::Unsat);
}
SatResult::Unknown => {
return Ok(SatResult::Unknown);
}
SatResult::Sat => {}
}
}
}
Ok(SatResult::Sat)
}
fn propagate_equality(&mut self, eq: EqualityProp) -> Result<(), String> {
self.stats.equalities_propagated += 1;
self.merge_equivalence_classes(eq.lhs, eq.rhs)?;
let theories_to_notify = self.get_theories_for_terms(eq.lhs, eq.rhs);
for theory_id in theories_to_notify {
if theory_id != eq.source
&& let Some(solver) = self.theories.get_mut(&theory_id)
{
solver.notify_equality(eq.lhs, eq.rhs)?;
}
}
Ok(())
}
fn identify_shared_terms(&mut self, _formula: TermId) -> Result<(), String> {
self.stats.shared_terms_count = self.shared_terms.len();
Ok(())
}
fn is_shared_term(&self, term: TermId) -> bool {
self.shared_terms
.get(&term)
.is_some_and(|st| st.theories.len() > 1)
}
fn get_theories_for_terms(&self, lhs: TermId, rhs: TermId) -> FxHashSet<TheoryId> {
let mut theories = FxHashSet::default();
if let Some(st) = self.shared_terms.get(&lhs) {
theories.extend(&st.theories);
}
if let Some(st) = self.shared_terms.get(&rhs) {
theories.extend(&st.theories);
}
theories
}
fn merge_equivalence_classes(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
let lhs_rep = self.find_representative(lhs);
let rhs_rep = self.find_representative(rhs);
if lhs_rep == rhs_rep {
return Ok(());
}
if let Some(st) = self.shared_terms.get_mut(&lhs_rep) {
st.representative = rhs_rep;
}
Ok(())
}
fn find_representative(&self, term: TermId) -> TermId {
if let Some(st) = self.shared_terms.get(&term)
&& st.representative != term
{
return self.find_representative(st.representative);
}
term
}
pub fn add_shared_term(&mut self, term: TermId, theory: TheoryId) {
self.shared_terms
.entry(term)
.or_insert_with(|| SharedTerm {
term,
theories: FxHashSet::default(),
representative: term,
})
.theories
.insert(theory);
self.stats.shared_terms_count = self.shared_terms.len();
}
pub fn enqueue_equality(&mut self, lhs: TermId, rhs: TermId, source: TheoryId) {
self.pending_equalities.push_back(EqualityProp {
lhs,
rhs,
source,
explanation: vec![],
});
}
pub fn backtrack(&mut self, level: usize) -> Result<(), String> {
self.current_level = level;
for solver in self.theories.values_mut() {
solver.backtrack(level)?;
}
self.pending_equalities.clear();
Ok(())
}
pub fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
let mut combined_model = FxHashMap::default();
for solver in self.theories.values() {
if let Some(model) = solver.get_model() {
combined_model.extend(model);
} else {
return None;
}
}
Some(combined_model)
}
pub fn get_conflict(&self) -> Option<Vec<TermId>> {
let mut combined_conflict = Vec::new();
for solver in self.theories.values() {
if let Some(conflict) = solver.get_conflict() {
combined_conflict.extend(conflict);
}
}
if combined_conflict.is_empty() {
None
} else {
if self.config.minimize_conflicts {
Some(self.minimize_conflict(combined_conflict))
} else {
Some(combined_conflict)
}
}
}
fn minimize_conflict(&self, mut conflict: Vec<TermId>) -> Vec<TermId> {
conflict.sort();
conflict.dedup();
conflict
}
pub fn stats(&self) -> &CoordinatorStats {
&self.stats
}
pub fn current_level(&self) -> usize {
self.current_level
}
pub fn increment_level(&mut self) {
self.current_level += 1;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Theory solver stub with a fixed verdict and an optional fixed conflict.
    struct MockTheory {
        id: TheoryId,
        sat_result: SatResult,
        conflict: Option<Vec<TermId>>,
    }

    /// Builds a boxed stub that reports `sat_result` and no conflict.
    fn mock(id: TheoryId, sat_result: SatResult) -> Box<MockTheory> {
        Box::new(MockTheory {
            id,
            sat_result,
            conflict: None,
        })
    }

    /// Coordinator with the default configuration.
    fn new_coordinator() -> TheoryCoordinator {
        TheoryCoordinator::new(CoordinatorConfig::default())
    }

    /// Coordinator with two theories that both report a conflict.
    fn conflicting_coordinator(minimize_conflicts: bool) -> TheoryCoordinator {
        let mut coordinator = TheoryCoordinator::new(CoordinatorConfig {
            minimize_conflicts,
            ..Default::default()
        });
        coordinator.register_theory(Box::new(MockTheory {
            id: TheoryId::Arithmetic,
            sat_result: SatResult::Unsat,
            conflict: Some(vec![3, 1]),
        }));
        coordinator.register_theory(Box::new(MockTheory {
            id: TheoryId::BitVector,
            sat_result: SatResult::Unsat,
            conflict: Some(vec![1, 2]),
        }));
        coordinator
    }

    impl TheorySolver for MockTheory {
        fn theory_id(&self) -> TheoryId {
            self.id
        }
        fn assert_formula(&mut self, _formula: TermId) -> Result<(), String> {
            Ok(())
        }
        fn check_sat(&mut self) -> Result<SatResult, String> {
            Ok(self.sat_result)
        }
        fn get_model(&self) -> Option<FxHashMap<TermId, TermId>> {
            Some(FxHashMap::default())
        }
        fn get_conflict(&self) -> Option<Vec<TermId>> {
            self.conflict.clone()
        }
        fn backtrack(&mut self, _level: usize) -> Result<(), String> {
            Ok(())
        }
        fn get_implied_equalities(&self) -> Vec<(TermId, TermId)> {
            vec![]
        }
        fn notify_equality(&mut self, _lhs: TermId, _rhs: TermId) -> Result<(), String> {
            Ok(())
        }
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = new_coordinator();
        assert_eq!(coordinator.stats.check_sat_calls, 0);
    }

    #[test]
    fn test_register_theory() {
        let mut coordinator = new_coordinator();
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Sat));
        assert!(coordinator.theories.contains_key(&TheoryId::Arithmetic));
    }

    #[test]
    fn test_check_sat_single_theory() {
        let mut coordinator = new_coordinator();
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Sat));
        let result = coordinator.check_sat();
        assert!(result.is_ok());
        assert_eq!(
            result.expect("test operation should succeed"),
            SatResult::Sat
        );
        assert_eq!(coordinator.stats.check_sat_calls, 1);
    }

    #[test]
    fn test_check_sat_unsat_theory() {
        let mut coordinator = new_coordinator();
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Unsat));
        assert_eq!(coordinator.check_sat(), Ok(SatResult::Unsat));
        assert_eq!(coordinator.stats.theory_conflicts, 1);
    }

    #[test]
    fn test_assert_formula_unregistered_theory() {
        let mut coordinator = new_coordinator();
        assert_eq!(
            coordinator.assert_formula(1, TheoryId::Array),
            Err("Theory Array not registered".to_string())
        );
        coordinator.register_theory(mock(TheoryId::Array, SatResult::Sat));
        assert_eq!(coordinator.assert_formula(1, TheoryId::Array), Ok(()));
    }

    #[test]
    fn test_eager_combination_without_equalities() {
        let mut coordinator = TheoryCoordinator::new(CoordinatorConfig {
            eager_combination: true,
            ..Default::default()
        });
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Sat));
        assert_eq!(coordinator.check_sat(), Ok(SatResult::Sat));
    }

    #[test]
    fn test_shared_term_management() {
        let mut coordinator = new_coordinator();
        coordinator.add_shared_term(1, TheoryId::Arithmetic);
        coordinator.add_shared_term(1, TheoryId::BitVector);
        assert!(coordinator.is_shared_term(1));
        assert_eq!(coordinator.stats.shared_terms_count, 1);
    }

    #[test]
    fn test_equivalence_classes() {
        let mut coordinator = new_coordinator();
        coordinator.add_shared_term(1, TheoryId::Arithmetic);
        coordinator.add_shared_term(2, TheoryId::Arithmetic);
        coordinator
            .merge_equivalence_classes(1, 2)
            .expect("test operation should succeed");
        let rep1 = coordinator.find_representative(1);
        let rep2 = coordinator.find_representative(2);
        assert_eq!(rep1, rep2);
    }

    #[test]
    fn test_equality_propagation() {
        let mut coordinator = new_coordinator();
        coordinator.enqueue_equality(1, 2, TheoryId::Arithmetic);
        assert_eq!(coordinator.pending_equalities.len(), 1);
    }

    #[test]
    fn test_lazy_combination_drains_queue() {
        let mut coordinator = new_coordinator();
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Sat));
        coordinator.add_shared_term(1, TheoryId::Arithmetic);
        coordinator.add_shared_term(2, TheoryId::Arithmetic);
        coordinator.enqueue_equality(1, 2, TheoryId::BitVector);
        assert_eq!(coordinator.check_sat(), Ok(SatResult::Sat));
        assert!(coordinator.pending_equalities.is_empty());
        assert_eq!(coordinator.stats.equalities_propagated, 1);
        assert_eq!(
            coordinator.find_representative(1),
            coordinator.find_representative(2)
        );
    }

    #[test]
    fn test_backtrack() {
        let mut coordinator = new_coordinator();
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Sat));
        coordinator.increment_level();
        coordinator.increment_level();
        assert_eq!(coordinator.current_level(), 2);
        coordinator
            .backtrack(0)
            .expect("test operation should succeed");
        assert_eq!(coordinator.current_level(), 0);
    }

    #[test]
    fn test_backtrack_clears_pending_equalities() {
        let mut coordinator = new_coordinator();
        coordinator.enqueue_equality(1, 2, TheoryId::Arithmetic);
        coordinator
            .backtrack(0)
            .expect("test operation should succeed");
        assert!(coordinator.pending_equalities.is_empty());
    }

    #[test]
    fn test_get_model() {
        let mut coordinator = new_coordinator();
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Sat));
        let model = coordinator.get_model();
        assert!(model.is_some());
    }

    #[test]
    fn test_get_conflict_none_without_conflicts() {
        let mut coordinator = new_coordinator();
        coordinator.register_theory(mock(TheoryId::Arithmetic, SatResult::Sat));
        assert_eq!(coordinator.get_conflict(), None);
    }

    #[test]
    fn test_get_conflict_combines_theories() {
        assert_eq!(
            conflicting_coordinator(true).get_conflict(),
            Some(vec![1, 2, 3])
        );
        let mut raw = conflicting_coordinator(false)
            .get_conflict()
            .expect("both theories report a conflict");
        // Theory iteration order is unspecified, so compare order-insensitively.
        raw.sort_unstable();
        assert_eq!(raw, vec![1, 1, 2, 3]);
    }

    #[test]
    fn test_conflict_minimization() {
        let coordinator = TheoryCoordinator::new(CoordinatorConfig {
            minimize_conflicts: true,
            ..Default::default()
        });
        let conflict = vec![1, 2, 2, 3, 1, 4];
        let minimized = coordinator.minimize_conflict(conflict);
        assert_eq!(minimized, vec![1, 2, 3, 4]);
    }
}