#[allow(unused_imports)]
use crate::prelude::*;
use oxiz_core::TermId;
/// Identifier of a theory. `DelayedCombination` uses it as a bit position in
/// `u64` masks, so only ids below 64 can be represented there.
pub type TheoryId = usize;
/// A term together with the set of theories that have registered it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SharedTerm {
    /// The term being tracked.
    pub term: TermId,
    /// Bitmask of registering theories: bit `i` is set once theory `i` has
    /// registered this term, so at most 64 theories can be represented.
    pub theories: u64,
}
/// An equality `lhs = rhs` reported by a theory and held in the queue of
/// `DelayedCombination` until the queue is flushed.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct DeferredEquality {
    /// Left-hand side of the equality.
    pub lhs: TermId,
    /// Right-hand side of the equality.
    pub rhs: TermId,
    /// Theory that reported the equality.
    pub source_theory: TheoryId,
}
/// Behaviour switches and limits for `DelayedCombination`.
#[derive(Clone, Debug)]
pub struct DelayedCombinationConfig {
    /// When `true`, reported equalities are queued; when `false`, they are
    /// propagated immediately.
    pub lazy_sharing: bool,
    /// When `true`, a reported conflict is counted and flushes the queue;
    /// when `false`, conflicts are ignored.
    pub conflict_driven: bool,
    /// Queue length at which deferred equalities are flushed automatically.
    pub max_deferred: usize,
}

impl Default for DelayedCombinationConfig {
    /// Lazy, conflict-driven sharing with an automatic flush at 1000 queued
    /// equalities.
    fn default() -> Self {
        const DEFAULT_MAX_DEFERRED: usize = 1000;
        Self {
            max_deferred: DEFAULT_MAX_DEFERRED,
            conflict_driven: true,
            lazy_sharing: true,
        }
    }
}
/// Running counters describing the work done by a `DelayedCombination`.
#[derive(Clone, Debug, Default)]
pub struct DelayedCombinationStats {
    /// Equalities placed in the deferral queue.
    pub equalities_deferred: u64,
    /// Equalities handed to the propagation step, either immediately or when
    /// the queue is flushed.
    pub equalities_propagated: u64,
    /// Number of times a non-empty queue was flushed.
    pub forced_propagations: u64,
    /// Conflicts reported while conflict-driven flushing was enabled.
    pub conflicts_detected: u64,
}
/// Tracks which theories use which terms, which theories are active, and a
/// queue of equalities whose propagation is delayed according to a
/// `DelayedCombinationConfig`.
#[derive(Debug)]
pub struct DelayedCombination {
/// Registered terms, keyed by term id, each with a bitmask of the theories
/// that registered it.
shared_terms: FxHashMap<TermId, SharedTerm>,
/// Equalities waiting to be propagated, in the order they were reported.
deferred: Vec<DeferredEquality>,
/// Bitmask of activated theories (bit `i` set = theory `i` active).
active_theories: u64,
/// Behaviour switches and limits.
config: DelayedCombinationConfig,
/// Running counters, exposed through `stats()`.
stats: DelayedCombinationStats,
}
impl DelayedCombination {
    /// Number of theory ids that fit in the `u64` bitmasks used by this type;
    /// valid ids are `0..MAX_THEORIES`.
    pub const MAX_THEORIES: usize = u64::BITS as usize;

    /// Creates a manager with the given configuration and no registered
    /// terms, queued equalities, or active theories.
    pub fn new(config: DelayedCombinationConfig) -> Self {
        Self {
            shared_terms: FxHashMap::default(),
            deferred: Vec::new(),
            active_theories: 0,
            config,
            stats: DelayedCombinationStats::default(),
        }
    }

    /// Creates a manager using `DelayedCombinationConfig::default()`.
    pub fn default_config() -> Self {
        Self::new(DelayedCombinationConfig::default())
    }

    /// Returns the single-bit mask for `theory`, or `None` when the id does
    /// not fit in a `u64` mask. A plain `1 << theory` would overflow for such
    /// ids: it panics in debug builds and silently aliases a lower bit in
    /// release builds.
    fn theory_bit(theory: TheoryId) -> Option<u64> {
        if theory < Self::MAX_THEORIES {
            Some(1u64 << theory)
        } else {
            None
        }
    }

    /// Like `theory_bit`, but panics for out-of-range ids. Mutating methods
    /// use this because recording such an id would corrupt the bitmask.
    fn expect_theory_bit(theory: TheoryId) -> u64 {
        match Self::theory_bit(theory) {
            Some(bit) => bit,
            None => panic!(
                "theory id {} is out of range: at most {} theories are supported",
                theory,
                Self::MAX_THEORIES
            ),
        }
    }

    /// Records that `theory` uses `term`.
    ///
    /// A term counts as shared (see [`Self::is_shared`]) once two or more
    /// distinct theories have registered it.
    ///
    /// # Panics
    /// Panics if `theory >= Self::MAX_THEORIES`.
    pub fn register_shared_term(&mut self, term: TermId, theory: TheoryId) {
        // Validate first so an invalid id leaves no empty entry behind.
        let bit = Self::expect_theory_bit(theory);
        self.shared_terms
            .entry(term)
            .or_insert(SharedTerm { term, theories: 0 })
            .theories |= bit;
    }

    /// Returns `true` if `term` has been registered by more than one theory.
    pub fn is_shared(&self, term: TermId) -> bool {
        matches!(
            self.shared_terms.get(&term),
            Some(st) if st.theories.count_ones() > 1
        )
    }

    /// Reports the equality `lhs = rhs` from theory `source`.
    ///
    /// With `lazy_sharing` disabled, the equality is propagated immediately
    /// and is not counted as deferred. Otherwise it is queued. Once the queue
    /// holds `max_deferred` entries, the whole queue is flushed through
    /// [`Self::force_propagation`].
    pub fn defer_equality(&mut self, lhs: TermId, rhs: TermId, source: TheoryId) {
        if !self.config.lazy_sharing {
            self.propagate_equality(lhs, rhs, source);
            return;
        }
        self.deferred.push(DeferredEquality {
            lhs,
            rhs,
            source_theory: source,
        });
        self.stats.equalities_deferred += 1;
        if self.deferred.len() >= self.config.max_deferred {
            self.force_propagation();
        }
    }

    /// Propagation step for a single equality.
    ///
    /// NOTE(review): currently this only bumps `equalities_propagated`; no
    /// equality is forwarded anywhere. Presumably a placeholder — confirm.
    fn propagate_equality(&mut self, _lhs: TermId, _rhs: TermId, _source: TheoryId) {
        self.stats.equalities_propagated += 1;
    }

    /// Propagates every queued equality in arrival order and empties the
    /// queue. Counts one forced propagation, but only if the queue was
    /// non-empty.
    pub fn force_propagation(&mut self) {
        if self.deferred.is_empty() {
            return;
        }
        self.stats.forced_propagations += 1;
        // Move the queue out so `propagate_equality` can borrow `self`
        // mutably, without the extra allocation of `drain(..).collect()`.
        let pending = core::mem::take(&mut self.deferred);
        for eq in pending {
            self.propagate_equality(eq.lhs, eq.rhs, eq.source_theory);
        }
    }

    /// Reacts to a conflict. When `conflict_driven` is enabled, the conflict
    /// is counted and the queue is flushed. Otherwise this does nothing, and
    /// the conflict is not counted.
    pub fn handle_conflict(&mut self) {
        if !self.config.conflict_driven {
            return;
        }
        self.stats.conflicts_detected += 1;
        self.force_propagation();
    }

    /// Marks `theory` as active.
    ///
    /// # Panics
    /// Panics if `theory >= Self::MAX_THEORIES`.
    pub fn activate_theory(&mut self, theory: TheoryId) {
        self.active_theories |= Self::expect_theory_bit(theory);
    }

    /// Returns `true` if `theory` has been activated. Ids outside
    /// `0..MAX_THEORIES` can never be active, so they yield `false`.
    pub fn is_theory_active(&self, theory: TheoryId) -> bool {
        matches!(
            Self::theory_bit(theory),
            Some(bit) if (self.active_theories & bit) != 0
        )
    }

    /// Lists, in ascending order, the theories that registered `term`.
    /// Returns an empty vector if `term` was never registered.
    pub fn get_sharing_theories(&self, term: TermId) -> Vec<TheoryId> {
        self.shared_terms
            .get(&term)
            .map(|shared| {
                (0..Self::MAX_THEORIES)
                    .filter(|&i| (shared.theories & (1u64 << i)) != 0)
                    .collect::<Vec<TheoryId>>()
            })
            .unwrap_or_default()
    }

    /// Returns the running counters.
    pub fn stats(&self) -> &DelayedCombinationStats {
        &self.stats
    }

    /// Resets all counters to zero. Queued equalities, registered terms and
    /// active theories are left untouched.
    pub fn reset_stats(&mut self) {
        self.stats = DelayedCombinationStats::default();
    }
}
impl Default for DelayedCombination {
fn default() -> Self {
Self::default_config()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A manager built with the default configuration.
    fn fresh() -> DelayedCombination {
        DelayedCombination::default_config()
    }

    /// Defers the sample equality `t0 = t1`, reported by theory 0.
    fn defer_sample(combo: &mut DelayedCombination) {
        combo.defer_equality(TermId::new(0), TermId::new(1), 0);
    }

    #[test]
    fn test_delayed_combination_creation() {
        let combo = fresh();
        assert_eq!(combo.stats().equalities_deferred, 0);
    }

    #[test]
    fn test_register_shared_term() {
        let mut combo = fresh();
        let t = TermId::new(0);
        for theory in [0, 1] {
            combo.register_shared_term(t, theory);
        }
        assert!(combo.is_shared(t));
    }

    #[test]
    fn test_defer_equality() {
        let mut combo = fresh();
        defer_sample(&mut combo);
        assert_eq!(combo.stats().equalities_deferred, 1);
        assert_eq!(combo.deferred.len(), 1);
    }

    #[test]
    fn test_force_propagation() {
        let mut combo = fresh();
        defer_sample(&mut combo);
        combo.force_propagation();
        assert!(combo.deferred.is_empty());
        let counters = combo.stats();
        assert_eq!(counters.equalities_propagated, 1);
        assert_eq!(counters.forced_propagations, 1);
    }

    #[test]
    fn test_activate_theory() {
        let mut combo = fresh();
        combo.activate_theory(2);
        assert!(combo.is_theory_active(2));
        assert!(!combo.is_theory_active(3));
    }

    #[test]
    fn test_get_sharing_theories() {
        let mut combo = fresh();
        let t = TermId::new(0);
        combo.register_shared_term(t, 1);
        combo.register_shared_term(t, 3);
        let sharing = combo.get_sharing_theories(t);
        assert_eq!(sharing.len(), 2);
        assert!(sharing.contains(&1));
        assert!(sharing.contains(&3));
    }

    #[test]
    fn test_handle_conflict() {
        let mut combo = fresh();
        defer_sample(&mut combo);
        combo.handle_conflict();
        assert!(combo.deferred.is_empty());
        assert_eq!(combo.stats().conflicts_detected, 1);
    }
}