#![allow(dead_code)]
#[allow(unused_imports)]
use crate::prelude::*;
use oxiz_core::TermId;
/// Identifier of a theory, as passed to `SharedTermsManager::register_term`.
///
/// A plain `usize`; this file attaches no further meaning to the value.
pub type TheoryId = usize;
/// An equality between two terms.
///
/// Values built through [`Equality::new`] keep the side with the smaller raw id in `lhs`,
/// so the derived `PartialEq`/`Hash` treat `a = b` and `b = a` as the same equality.
/// Constructing the struct literally bypasses that ordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Equality {
    /// Left-hand side of the equality.
    pub lhs: TermId,
    /// Right-hand side of the equality.
    pub rhs: TermId,
}
impl Equality {
    /// Builds an equality whose sides are stored in canonical order.
    ///
    /// The arguments are kept as given when `lhs.raw() <= rhs.raw()`; otherwise they are
    /// swapped, so the result does not depend on the order in which the two terms are passed.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        let already_ordered = lhs.raw() <= rhs.raw();
        let (first, second) = if already_ordered {
            (lhs, rhs)
        } else {
            (rhs, lhs)
        };
        Self {
            lhs: first,
            rhs: second,
        }
    }
}
/// Bookkeeping stored for each term that has been registered with the manager.
#[derive(Clone, Debug)]
struct SharedTermInfo {
    /// Theories that have registered this term.
    theories: FxHashSet<TheoryId>,
    /// Set to the term itself when the entry is created.
    // NOTE(review): nothing visible in this file updates this field afterwards; the
    // equivalence classes are tracked through the manager's `parent` map instead.
    representative: TermId,
}
/// Tunable settings for the shared-terms manager.
#[derive(Clone, Debug)]
pub struct SharedTermsConfig {
    /// Batching switch.
    // NOTE(review): no code visible in this file reads this flag - confirm where it is
    // meant to take effect.
    pub enable_batching: bool,
    /// Number of pending equalities at which the manager flushes automatically.
    pub max_batch_size: usize,
}
impl Default for SharedTermsConfig {
fn default() -> Self {
Self {
enable_batching: true,
max_batch_size: 1000,
}
}
}
/// Counters maintained by the shared-terms manager; all start at zero.
#[derive(Clone, Debug, Default)]
pub struct SharedTermsStats {
    /// Distinct terms that have been registered.
    pub terms_registered: u64,
    /// Theory-to-term subscriptions recorded by `register_term`.
    pub subscriptions: u64,
    /// Asserted equalities that merged two previously separate classes.
    pub equalities_propagated: u64,
    /// Flushes that found at least one pending equality.
    pub batches_sent: u64,
}
/// Tracks which theories use which terms and which terms have been asserted equal.
///
/// Equalities are maintained as a union-find forest (`parent`); newly merged pairs are
/// queued in `pending_equalities` until they are flushed.
#[derive(Debug)]
pub struct SharedTermsManager {
    /// Settings supplied at construction time.
    config: SharedTermsConfig,
    /// Per-term bookkeeping, keyed by term.
    terms: FxHashMap<TermId, SharedTermInfo>,
    /// Union-find parent links; a term without an entry is its own root.
    parent: FxHashMap<TermId, TermId>,
    /// Equalities recorded since the last flush.
    pending_equalities: Vec<Equality>,
    /// Theories registered per term; written together with `terms` on registration.
    subscriptions: FxHashMap<TermId, FxHashSet<TheoryId>>,
    /// Running counters.
    stats: SharedTermsStats,
}
impl SharedTermsManager {
    /// Creates an empty manager that uses `config`.
    pub fn new(config: SharedTermsConfig) -> Self {
        Self {
            config,
            terms: FxHashMap::default(),
            parent: FxHashMap::default(),
            pending_equalities: Vec::new(),
            subscriptions: FxHashMap::default(),
            stats: SharedTermsStats::default(),
        }
    }

    /// Creates an empty manager that uses the default configuration.
    pub fn default_config() -> Self {
        Self::new(SharedTermsConfig::default())
    }

    /// Records that `theory` uses `term`.
    ///
    /// `stats.terms_registered` counts each distinct term once and `stats.subscriptions`
    /// counts each distinct `(term, theory)` pair once, so repeating a registration leaves
    /// both counters unchanged.
    pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
        let info = self.terms.entry(term).or_insert_with(|| {
            self.stats.terms_registered += 1;
            SharedTermInfo {
                theories: FxHashSet::default(),
                representative: term,
            }
        });
        // `insert` returns `false` when the theory was already present; only a genuinely
        // new pair counts as a subscription.
        if info.theories.insert(theory) {
            self.stats.subscriptions += 1;
        }
        self.subscriptions.entry(term).or_default().insert(theory);
    }

    /// Returns `true` when more than one theory has registered `term`.
    ///
    /// Unregistered terms are not shared.
    pub fn is_shared(&self, term: TermId) -> bool {
        self.terms
            .get(&term)
            .is_some_and(|info| info.theories.len() > 1)
    }

    /// Returns the theories that have registered `term`, in unspecified order.
    ///
    /// The result is empty when the term has never been registered.
    pub fn get_theories(&self, term: TermId) -> Vec<TheoryId> {
        self.terms
            .get(&term)
            .map(|info| info.theories.iter().copied().collect())
            .unwrap_or_default()
    }

    /// Asserts `lhs = rhs`.
    ///
    /// If the two terms are already in the same class nothing happens. Otherwise the
    /// classes are merged, the equality is queued, `stats.equalities_propagated` is
    /// incremented, and the queue is flushed once it holds at least
    /// `config.max_batch_size` entries.
    // NOTE(review): `config.enable_batching` is not consulted here or anywhere else in
    // this impl - confirm whether it should gate the queueing/flush behaviour.
    pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) {
        let lhs_rep = self.find(lhs);
        let rhs_rep = self.find(rhs);
        if lhs_rep == rhs_rep {
            return;
        }

        // Link one root under the other; both are roots, so no cycle can form.
        self.parent.insert(lhs_rep, rhs_rep);
        self.pending_equalities.push(Equality::new(lhs, rhs));
        self.stats.equalities_propagated += 1;

        if self.pending_equalities.len() >= self.config.max_batch_size {
            self.flush_equalities();
        }
    }

    /// Returns the root of `term`'s class, compressing the path walked on the way.
    ///
    /// Implemented iteratively: roots are linked without rank, so parent chains can grow
    /// as long as the number of merged terms, and a recursive walk could exhaust the stack.
    fn find(&mut self, term: TermId) -> TermId {
        // Pass 1: follow parent links until a term with no parent (or a self-link).
        let mut root = term;
        while let Some(&next) = self.parent.get(&root) {
            if next == root {
                break;
            }
            root = next;
        }

        // Pass 2: re-point every term on the walked path directly at the root.
        // `insert` hands back the previous parent, which is the next term to visit.
        let mut current = term;
        while current != root {
            let Some(next) = self.parent.insert(current, root) else {
                break;
            };
            current = next;
        }

        root
    }

    /// Returns `true` when `lhs` and `rhs` are in the same equivalence class.
    ///
    /// Takes `&mut self` because the lookups compress union-find paths.
    pub fn are_equal(&mut self, lhs: TermId, rhs: TermId) -> bool {
        self.find(lhs) == self.find(rhs)
    }

    /// Returns the equalities queued since the last flush, oldest first.
    pub fn get_pending_equalities(&self) -> &[Equality] {
        &self.pending_equalities
    }

    /// Empties the pending queue, counting one batch in `stats.batches_sent` if it was
    /// non-empty.
    ///
    /// The queued equalities are dropped here; callers that need them must read
    /// [`Self::get_pending_equalities`] first.
    pub fn flush_equalities(&mut self) {
        if !self.pending_equalities.is_empty() {
            self.stats.batches_sent += 1;
            self.pending_equalities.clear();
        }
    }

    /// Returns every term registered by more than one theory, in unspecified order.
    pub fn get_shared_terms(&self) -> Vec<TermId> {
        self.terms
            .iter()
            .filter(|(_, info)| info.theories.len() > 1)
            .map(|(&term, _)| term)
            .collect()
    }

    /// Returns the running counters.
    pub fn stats(&self) -> &SharedTermsStats {
        &self.stats
    }

    /// Drops all registered terms, equivalence classes, pending equalities and
    /// subscriptions, and zeroes the counters. The configuration is kept.
    pub fn reset(&mut self) {
        self.terms.clear();
        self.parent.clear();
        self.pending_equalities.clear();
        self.subscriptions.clear();
        self.stats = SharedTermsStats::default();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `TermId` from a raw index.
    fn term(id: u32) -> TermId {
        TermId::new(id)
    }

    /// A freshly built manager has registered nothing.
    #[test]
    fn test_manager_creation() {
        let fresh = SharedTermsManager::default_config();
        assert_eq!(fresh.stats().terms_registered, 0);
    }

    /// A term registered by two different theories is reported as shared.
    #[test]
    fn test_register_term() {
        let mut mgr = SharedTermsManager::default_config();
        let shared = term(1);
        for theory in [0, 1] {
            mgr.register_term(shared, theory);
        }
        assert!(mgr.is_shared(shared));
        let owners = mgr.get_theories(shared);
        assert_eq!(owners.len(), 2);
    }

    /// Asserting an equality merges the classes and queues exactly one entry.
    #[test]
    fn test_equality() {
        let mut mgr = SharedTermsManager::default_config();
        let (a, b) = (term(1), term(2));
        mgr.assert_equality(a, b);
        assert!(mgr.are_equal(a, b));
        assert_eq!(mgr.get_pending_equalities().len(), 1);
    }

    /// `1 = 2` and `2 = 3` together imply `1 = 3`.
    #[test]
    fn test_equality_transitivity() {
        let mut mgr = SharedTermsManager::default_config();
        let (a, b, c) = (term(1), term(2), term(3));
        mgr.assert_equality(a, b);
        mgr.assert_equality(b, c);
        assert!(mgr.are_equal(a, c));
    }

    /// Flushing empties the pending queue.
    #[test]
    fn test_flush_equalities() {
        let mut mgr = SharedTermsManager::default_config();
        mgr.assert_equality(term(1), term(2));
        let queued_before = mgr.get_pending_equalities().len();
        assert_eq!(queued_before, 1);
        mgr.flush_equalities();
        let queued_after = mgr.get_pending_equalities().len();
        assert_eq!(queued_after, 0);
    }
}