#[allow(unused_imports)]
use crate::prelude::*;
use oxiz_core::TermId as CoreTermId;
/// Term identifier used throughout this module; an alias for `oxiz_core::TermId`.
pub type TermId = CoreTermId;
/// Identifier of a theory. A plain `usize`; how ids are assigned is not defined in this file.
pub type TheoryId = usize;
/// Decision-level counter, as stored in `AdvancedSharedTermsManager::current_level`.
pub type DecisionLevel = u32;
/// An equality between two terms.
///
/// `Equality::new` stores the operand with the smaller `raw()` id in `lhs`, so
/// `a = b` and `b = a` built through it compare and hash identically. The fields
/// are public, so a struct literal bypasses that ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
/// Left operand (the smaller raw id when built with `Equality::new`).
pub lhs: TermId,
/// Right operand (the larger raw id when built with `Equality::new`).
pub rhs: TermId,
}
impl Equality {
    /// Builds an equality with its operands in canonical order.
    ///
    /// The operand whose `raw()` id is smaller (or equal) ends up in `lhs`, so
    /// the argument order does not affect the resulting value.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        let (smaller, larger) = if lhs.raw() <= rhs.raw() {
            (lhs, rhs)
        } else {
            (rhs, lhs)
        };
        Self {
            lhs: smaller,
            rhs: larger,
        }
    }

    /// Rebuilds the equality with the operands passed in swapped order.
    ///
    /// The result goes through [`Equality::new`] again, so it is re-normalised
    /// by raw id rather than being a literal field swap.
    pub fn flip(self) -> Self {
        let Self { lhs, rhs } = self;
        Self::new(rhs, lhs)
    }
}
/// Justification recorded for why two terms are considered equal.
#[derive(Debug, Clone)]
pub enum EqualityExplanation {
/// The equality was supplied directly. Also the fallback used by
/// `build_explanation_chain` when no explanation is stored for an equality.
Given,
/// Both sides are the same term (returned by `EGraph::get_explanation` when `a == b`).
Reflexive,
/// Presumably an equality derived by a theory from `support`; never constructed in this file.
TheoryPropagation {
/// Theory the propagation is attributed to.
theory: TheoryId,
/// Equalities recorded as support for the propagation.
support: Vec<Equality>,
},
/// Two explanations chained through a middle term (built by `build_explanation_chain`).
Transitive {
/// Term linking `left` and `right`.
intermediate: TermId,
/// Explanation for the part of the chain before `intermediate`.
left: Box<EqualityExplanation>,
/// Explanation for the part of the chain after `intermediate`.
right: Box<EqualityExplanation>,
},
/// Equality attributed to congruence. `EGraph::process_congruences` only ever
/// builds it with `TermId::new(0)` and an empty argument list.
Congruence {
/// Function term — presumably the applied symbol; TODO confirm with callers.
function: TermId,
/// Argument equalities, each paired with its own explanation.
arg_equalities: Vec<(Equality, Box<EqualityExplanation>)>,
},
}
/// Numeric id of an equivalence class; `EGraph::add_term` hands them out sequentially from 0.
pub type EClassId = u32;
/// Pairs a term with an e-class id. Not referenced by any other code in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ENode {
/// The term.
pub term: TermId,
/// The e-class id associated with `term`.
pub eclass: EClassId,
}
/// Data stored for one equivalence class of terms.
#[derive(Debug, Clone)]
pub struct EClass {
/// Id this class was created with.
pub id: EClassId,
/// Term the class was created for; `EClass::merge` does not change it.
pub representative: TermId,
/// Terms currently in the class.
pub members: FxHashSet<TermId>,
/// E-class ids accumulated from absorbed classes. Starts empty and nothing in
/// this file inserts into it directly — presumably parent classes for congruence; TODO confirm.
pub parents: FxHashSet<EClassId>,
/// Starts at 1 and is increased by the absorbed class's `size` on every merge.
pub size: usize,
}
impl EClass {
    /// Creates a singleton class: `term` is both the only member and the representative.
    fn new(id: EClassId, term: TermId) -> Self {
        let mut class = Self {
            id,
            representative: term,
            members: FxHashSet::default(),
            parents: FxHashSet::default(),
            size: 1,
        };
        class.members.insert(term);
        class
    }

    /// Absorbs `other` into `self`: unions members and parents and adds the sizes.
    ///
    /// `id` and `representative` of `self` are left untouched.
    fn merge(&mut self, other: &EClass) {
        self.members.extend(other.members.iter().copied());
        self.parents.extend(other.parents.iter().copied());
        self.size += other.size;
    }
}
/// Equivalence-class store combining a union-find over class ids with per-class member sets.
#[derive(Debug, Clone)]
pub struct EGraph {
/// Class id assigned to each term when it was first added (not necessarily a root).
term_to_eclass: FxHashMap<TermId, EClassId>,
/// Class data keyed by class id.
eclasses: FxHashMap<EClassId, EClass>,
/// Next id `add_term` will hand out.
next_eclass_id: EClassId,
/// Union-find parent links; a class without an entry is its own root.
parent: FxHashMap<EClassId, EClassId>,
/// Union-by-rank ranks; a missing entry counts as 0.
rank: FxHashMap<EClassId, usize>,
/// Class pairs waiting for `process_congruences`. Nothing in this file pushes onto it.
pending_congruences: VecDeque<(EClassId, EClassId)>,
/// Explanation given to each `merge`, keyed by `(absorbed root, surviving root)`.
merge_explanations: FxHashMap<(EClassId, EClassId), EqualityExplanation>,
}
impl EGraph {
pub fn new() -> Self {
Self {
term_to_eclass: FxHashMap::default(),
eclasses: FxHashMap::default(),
next_eclass_id: 0,
parent: FxHashMap::default(),
rank: FxHashMap::default(),
pending_congruences: VecDeque::new(),
merge_explanations: FxHashMap::default(),
}
}
/// Registers `term` and returns the id of its class.
///
/// A term seen before yields the current root of its class; a new term gets
/// a fresh singleton class with the next sequential id.
pub fn add_term(&mut self, term: TermId) -> EClassId {
    match self.term_to_eclass.get(&term).copied() {
        Some(known) => self.find(known),
        None => {
            let fresh = self.next_eclass_id;
            self.next_eclass_id += 1;
            self.eclasses.insert(fresh, EClass::new(fresh, term));
            self.term_to_eclass.insert(term, fresh);
            fresh
        }
    }
}
/// Returns the root of the class containing `eclass_id`, compressing the walked path.
///
/// An id with no parent link (including one never handed out) is its own root.
pub fn find(&mut self, eclass_id: EClassId) -> EClassId {
    // Pass 1: follow parent links until a class with no parent (or a self-link).
    let mut root = eclass_id;
    while let Some(&up) = self.parent.get(&root) {
        if up == root {
            break;
        }
        root = up;
    }

    // Pass 2: re-point every class on that path straight at the root.
    let mut cursor = eclass_id;
    while cursor != root {
        match self.parent.insert(cursor, root) {
            Some(previous) => cursor = previous,
            None => break,
        }
    }
    root
}
pub fn merge(
&mut self,
a: EClassId,
b: EClassId,
explanation: EqualityExplanation,
) -> Result<EClassId, String> {
let a_root = self.find(a);
let b_root = self.find(b);
if a_root == b_root {
return Ok(a_root);
}
let a_rank = self.rank.get(&a_root).copied().unwrap_or(0);
let b_rank = self.rank.get(&b_root).copied().unwrap_or(0);
let (child, parent_id) = if a_rank < b_rank {
(a_root, b_root)
} else if a_rank > b_rank {
(b_root, a_root)
} else {
self.rank.insert(b_root, b_rank + 1);
(a_root, b_root)
};
self.parent.insert(child, parent_id);
if let Some(child_eclass) = self.eclasses.get(&child).cloned()
&& let Some(parent_eclass) = self.eclasses.get_mut(&parent_id)
{
parent_eclass.merge(&child_eclass);
}
self.merge_explanations
.insert((child, parent_id), explanation);
self.queue_congruence_checks(child, parent_id);
Ok(parent_id)
}
/// Hook invoked by `merge` after two classes are united.
///
/// The body is empty, so nothing is pushed onto `pending_congruences` from here.
fn queue_congruence_checks(&mut self, _a: EClassId, _b: EClassId) {
}
/// Drains `pending_congruences`, merging every queued pair that is still in distinct classes.
///
/// Each such merge is recorded with a `Congruence` explanation carrying
/// `TermId::new(0)` and no argument equalities.
///
/// # Errors
///
/// Propagates any error returned by [`EGraph::merge`].
pub fn process_congruences(&mut self) -> Result<(), String> {
    loop {
        let Some((a, b)) = self.pending_congruences.pop_front() else {
            return Ok(());
        };
        let a_root = self.find(a);
        let b_root = self.find(b);
        if a_root == b_root {
            continue;
        }
        let explanation = EqualityExplanation::Congruence {
            function: TermId::new(0),
            arg_equalities: Vec::new(),
        };
        self.merge(a_root, b_root, explanation)?;
    }
}
/// Returns the representative term of `term`'s class.
///
/// `None` if the term was never added or no class data exists for its root.
pub fn get_representative(&mut self, term: TermId) -> Option<TermId> {
    let class_id = self.term_to_eclass.get(&term).copied()?;
    let root = self.find(class_id);
    let class = self.eclasses.get(&root)?;
    Some(class.representative)
}
/// Reports whether `a` and `b` are in the same class.
///
/// Returns `false` if either term has not been added, even when `a == b`.
pub fn are_equal(&mut self, a: TermId, b: TermId) -> bool {
    let Some(&class_a) = self.term_to_eclass.get(&a) else {
        return false;
    };
    let Some(&class_b) = self.term_to_eclass.get(&b) else {
        return false;
    };
    let root_a = self.find(class_a);
    let root_b = self.find(class_b);
    root_a == root_b
}
/// Returns the stored explanation for `a = b`, if one can be found.
///
/// * `None` when the terms are not in the same class (including unknown terms).
/// * `Reflexive` when `a == b`.
/// * Otherwise the explanation passed to the `merge` call that united the two
///   terms' own class ids directly.
///
/// `merge` stores explanations under `(absorbed root, surviving root)`, and which
/// side is absorbed depends on rank, so both key orders are tried. Equalities that
/// hold only through a chain of merges have no single stored entry and yield `None`.
pub fn get_explanation(&mut self, a: TermId, b: TermId) -> Option<EqualityExplanation> {
    if !self.are_equal(a, b) {
        return None;
    }
    if a == b {
        return Some(EqualityExplanation::Reflexive);
    }
    let a_class = *self.term_to_eclass.get(&a)?;
    let b_class = *self.term_to_eclass.get(&b)?;
    self.merge_explanations
        .get(&(a_class, b_class))
        .or_else(|| self.merge_explanations.get(&(b_class, a_class)))
        .cloned()
}
/// Returns the members of `term`'s class in unspecified order.
///
/// The result is empty if the term was never added or its root has no class data.
pub fn get_eclass_members(&mut self, term: TermId) -> Vec<TermId> {
    let Some(&class_id) = self.term_to_eclass.get(&term) else {
        return Vec::new();
    };
    let root = self.find(class_id);
    match self.eclasses.get(&root) {
        Some(class) => class.members.iter().copied().collect(),
        None => Vec::new(),
    }
}
/// Empties every table and restarts class-id allocation at 0.
pub fn clear(&mut self) {
    // Union-find state.
    self.parent.clear();
    self.rank.clear();
    // Class data and term index.
    self.eclasses.clear();
    self.term_to_eclass.clear();
    // Bookkeeping.
    self.merge_explanations.clear();
    self.pending_congruences.clear();
    self.next_eclass_id = 0;
}
}
impl Default for EGraph {
fn default() -> Self {
Self::new()
}
}
/// Per-term bookkeeping kept by `AdvancedSharedTermsManager`.
#[derive(Debug, Clone)]
pub struct SharedTermInfo {
/// Theories that have registered this term.
pub theories: FxHashSet<TheoryId>,
/// Set to `true` by `register_term` once a second distinct theory registers the term.
pub is_interface: bool,
/// Initialised to `TermId::new(0)` by `SharedTermInfo::new`; not updated anywhere in this file.
pub representative: TermId,
/// Initialised to 1; not updated anywhere in this file.
pub class_size: usize,
/// Decision level that was current when the term was first registered.
pub shared_at_level: DecisionLevel,
}
impl SharedTermInfo {
    /// Creates the record for a term first registered by `theory` at `level`.
    ///
    /// The representative starts as the placeholder `TermId::new(0)` and the
    /// class size as 1.
    fn new(theory: TheoryId, level: DecisionLevel) -> Self {
        let mut info = Self {
            theories: FxHashSet::default(),
            is_interface: false,
            representative: TermId::new(0),
            class_size: 1,
            shared_at_level: level,
        };
        info.theories.insert(theory);
        info
    }
}
/// Tunables for `AdvancedSharedTermsManager`.
#[derive(Debug, Clone)]
pub struct SharedTermsConfig {
/// Not read anywhere in this file — presumably toggles batching; TODO confirm.
pub enable_batching: bool,
/// Number of pending equalities at which `assert_equality` flushes the queue.
pub max_batch_size: usize,
/// When `false`, the manager bypasses its e-graph entirely.
pub enable_egraph: bool,
/// When `true` (together with `enable_egraph`), `minimize_interface` collapses
/// interface terms to class representatives.
pub minimize_interface: bool,
/// When `true`, `assert_equality` stores the explanation for each queued equality.
pub track_explanations: bool,
}
/// Defaults: every feature switched on, batches capped at 1000 equalities.
impl Default for SharedTermsConfig {
    fn default() -> Self {
        Self {
            // Feature switches.
            enable_egraph: true,
            enable_batching: true,
            minimize_interface: true,
            track_explanations: true,
            // Limits.
            max_batch_size: 1000,
        }
    }
}
/// Counters maintained by `AdvancedSharedTermsManager`.
#[derive(Debug, Clone, Default)]
pub struct SharedTermsStats {
/// Terms registered for the first time.
pub terms_registered: u64,
/// Calls to `register_term` (repeat registrations are counted too).
pub subscriptions: u64,
/// Equalities queued by `assert_equality`.
pub equalities_propagated: u64,
/// Non-empty flushes of the pending-equality queue.
pub batches_sent: u64,
/// E-graph merge calls issued by `assert_equality`.
pub eclass_merges: u64,
/// Not updated anywhere in this file.
pub congruences: u64,
/// Number of times a term became an interface term.
pub interface_terms: u64,
}
/// Tracks which terms are shared between theories and the equalities asserted over them.
#[derive(Debug)]
pub struct AdvancedSharedTermsManager {
/// Behaviour switches and limits.
config: SharedTermsConfig,
/// Per-term bookkeeping for every registered term.
terms: FxHashMap<TermId, SharedTermInfo>,
/// Equivalence classes; only consulted when `config.enable_egraph` is set.
egraph: EGraph,
/// Equalities queued by `assert_equality` and dropped by `flush_equalities`.
pending_equalities: Vec<Equality>,
/// Theories that registered each term.
subscriptions: FxHashMap<TermId, FxHashSet<TheoryId>>,
/// Terms registered by more than one theory.
interface_terms: FxHashSet<TermId>,
/// Terms first registered at each decision level; consumed by `backtrack`.
decision_levels: FxHashMap<DecisionLevel, Vec<TermId>>,
/// Current decision level.
current_level: DecisionLevel,
/// Running counters.
stats: SharedTermsStats,
/// Explanation stored for each asserted equality when `config.track_explanations` is set.
explanations: FxHashMap<Equality, EqualityExplanation>,
}
impl AdvancedSharedTermsManager {
/// Creates an empty manager at decision level 0 using `config`.
pub fn new(config: SharedTermsConfig) -> Self {
    Self {
        config,
        current_level: 0,
        egraph: EGraph::new(),
        stats: SharedTermsStats::default(),
        terms: Default::default(),
        subscriptions: Default::default(),
        interface_terms: Default::default(),
        decision_levels: Default::default(),
        pending_equalities: Default::default(),
        explanations: Default::default(),
    }
}
/// Creates a manager configured with `SharedTermsConfig::default()`.
pub fn default_config() -> Self {
    let config = SharedTermsConfig::default();
    Self::new(config)
}
/// Records that `theory` uses `term`.
///
/// On first registration the term is recorded under the current decision level.
/// When a second distinct theory registers it, the term becomes an interface
/// term. Every call bumps the `subscriptions` counter, and the term is added to
/// the e-graph when `enable_egraph` is set.
pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
    let level = self.current_level;
    let first_registration = !self.terms.contains_key(&term);
    if first_registration {
        self.stats.terms_registered += 1;
    }

    let info = self
        .terms
        .entry(term)
        .or_insert_with(|| SharedTermInfo::new(theory, level));
    let theories_before = info.theories.len();
    info.theories.insert(theory);
    // Promote only on the 1 -> 2 transition so the counter moves once per term.
    if theories_before == 1 && info.theories.len() > 1 {
        self.interface_terms.insert(term);
        info.is_interface = true;
        self.stats.interface_terms += 1;
    }

    self.stats.subscriptions += 1;
    self.subscriptions.entry(term).or_default().insert(theory);

    if self.config.enable_egraph {
        self.egraph.add_term(term);
    }
    if first_registration {
        self.decision_levels.entry(level).or_default().push(term);
    }
}
/// Reports whether more than one theory has registered `term`.
///
/// Unknown terms are not shared.
pub fn is_shared(&self, term: TermId) -> bool {
    self.terms
        .get(&term)
        .is_some_and(|info| info.theories.len() > 1)
}
/// Reports whether `term` is currently in the interface-term set.
pub fn is_interface_term(&self, term: TermId) -> bool {
self.interface_terms.contains(&term)
}
/// Returns the theories that registered `term`, in unspecified order.
///
/// Unknown terms yield an empty vector.
pub fn get_theories(&self, term: TermId) -> Vec<TheoryId> {
    match self.terms.get(&term) {
        Some(info) => info.theories.iter().copied().collect(),
        None => Vec::new(),
    }
}
/// Asserts `lhs = rhs` and queues the equality for propagation.
///
/// With the e-graph enabled, an equality that already holds is ignored;
/// otherwise both terms are added and their classes merged. The normalised
/// equality is then queued and, when `track_explanations` is set, stored with
/// `explanation`. The queue is flushed once it reaches `max_batch_size`.
/// `enable_batching` is not consulted here.
///
/// # Errors
///
/// Propagates any error returned by the e-graph merge.
pub fn assert_equality(
    &mut self,
    lhs: TermId,
    rhs: TermId,
    explanation: EqualityExplanation,
) -> Result<(), String> {
    if self.config.enable_egraph {
        if self.egraph.are_equal(lhs, rhs) {
            return Ok(());
        }
        let left_class = self.egraph.add_term(lhs);
        let right_class = self.egraph.add_term(rhs);
        self.egraph
            .merge(left_class, right_class, explanation.clone())?;
        self.stats.eclass_merges += 1;
    }

    let queued = Equality::new(lhs, rhs);
    self.pending_equalities.push(queued);
    self.stats.equalities_propagated += 1;
    if self.config.track_explanations {
        self.explanations.insert(queued, explanation);
    }

    let batch_full = self.pending_equalities.len() >= self.config.max_batch_size;
    if batch_full {
        self.flush_equalities();
    }
    Ok(())
}
/// Reports whether `lhs` and `rhs` are known to be equal.
///
/// A term is always equal to itself, whether or not it has been registered and
/// regardless of `enable_egraph`. Distinct terms are equal only when the e-graph
/// is enabled and places them in the same class.
pub fn are_equal(&mut self, lhs: TermId, rhs: TermId) -> bool {
    // Answer reflexive queries up front so both configurations agree; the
    // e-graph alone reports `false` for any term it has not seen.
    if lhs == rhs {
        return true;
    }
    self.config.enable_egraph && self.egraph.are_equal(lhs, rhs)
}
/// Returns the representative of `term`'s class.
///
/// Falls back to `term` itself when the e-graph is disabled or has no
/// representative for it.
pub fn get_representative(&mut self, term: TermId) -> TermId {
    if self.config.enable_egraph {
        self.egraph.get_representative(term).unwrap_or(term)
    } else {
        term
    }
}
/// Returns the terms in `term`'s class.
///
/// With the e-graph disabled this is just `[term]`; with it enabled the
/// e-graph's answer is returned as-is (empty for a term it does not know).
pub fn get_eclass_members(&mut self, term: TermId) -> Vec<TermId> {
    if self.config.enable_egraph {
        self.egraph.get_eclass_members(term)
    } else {
        vec![term]
    }
}
/// Returns an explanation for `lhs = rhs`, if one is available.
///
/// The explanation stored by `assert_equality` takes precedence; otherwise the
/// e-graph is asked, provided it is enabled.
pub fn get_equality_explanation(
    &mut self,
    lhs: TermId,
    rhs: TermId,
) -> Option<EqualityExplanation> {
    let key = Equality::new(lhs, rhs);
    let stored = self.explanations.get(&key).cloned();
    if stored.is_some() || !self.config.enable_egraph {
        return stored;
    }
    self.egraph.get_explanation(lhs, rhs)
}
/// Returns the equalities queued since the last flush, oldest first.
pub fn get_pending_equalities(&self) -> &[Equality] {
    self.pending_equalities.as_slice()
}
/// Drops all queued equalities and counts one batch.
///
/// An empty queue is left alone and does not count as a batch. Nothing is
/// delivered anywhere by this method; it only clears the queue.
pub fn flush_equalities(&mut self) {
    if self.pending_equalities.is_empty() {
        return;
    }
    self.pending_equalities.clear();
    self.stats.batches_sent += 1;
}
/// Returns every term registered by more than one theory, in unspecified order.
pub fn get_shared_terms(&self) -> Vec<TermId> {
    self.terms
        .iter()
        .filter_map(|(&term, info)| (info.theories.len() > 1).then_some(term))
        .collect()
}
/// Returns a snapshot of the interface-term set, in unspecified order.
pub fn get_interface_terms(&self) -> Vec<TermId> {
    let mut snapshot = Vec::with_capacity(self.interface_terms.len());
    snapshot.extend(self.interface_terms.iter().copied());
    snapshot
}
/// Returns one representative per equivalence class among the interface terms.
///
/// When either `minimize_interface` or `enable_egraph` is off, the full
/// interface-term list is returned unchanged.
pub fn minimize_interface(&mut self) -> Vec<TermId> {
    let reduction_enabled = self.config.minimize_interface && self.config.enable_egraph;
    if !reduction_enabled {
        return self.get_interface_terms();
    }

    let candidates = self.get_interface_terms();
    let mut representatives = FxHashSet::default();
    for candidate in candidates {
        let representative = self.get_representative(candidate);
        representatives.insert(representative);
    }
    representatives.into_iter().collect()
}
/// Enters a new decision level by incrementing `current_level`.
pub fn push_decision_level(&mut self) {
self.current_level += 1;
}
/// Returns to decision level `level`, forgetting terms first registered above it.
///
/// Only `terms`, `subscriptions`, `interface_terms` and `decision_levels` are
/// rolled back; the e-graph, stored explanations, pending equalities and
/// statistics are left as they are.
///
/// # Errors
///
/// Fails if `level` is greater than the current level.
pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
    if level > self.current_level {
        return Err("Cannot backtrack to future level".to_string());
    }

    // Collect the stale levels first; the map cannot be edited while iterating it.
    let mut discarded: Vec<DecisionLevel> = Vec::new();
    for &recorded in self.decision_levels.keys() {
        if recorded > level {
            discarded.push(recorded);
        }
    }

    for recorded in discarded {
        let Some(terms) = self.decision_levels.remove(&recorded) else {
            continue;
        };
        for term in &terms {
            self.terms.remove(term);
            self.subscriptions.remove(term);
            self.interface_terms.remove(term);
        }
    }

    self.current_level = level;
    Ok(())
}
/// Returns the manager's running counters.
pub fn stats(&self) -> &SharedTermsStats {
&self.stats
}
/// Restores the manager to its freshly constructed state, keeping only `config`.
pub fn reset(&mut self) {
    // Term bookkeeping.
    self.terms.clear();
    self.subscriptions.clear();
    self.interface_terms.clear();
    // Equality state.
    self.egraph.clear();
    self.pending_equalities.clear();
    self.explanations.clear();
    // Decision-level tracking.
    self.decision_levels.clear();
    self.current_level = 0;
    // Counters.
    self.stats = SharedTermsStats::default();
}
/// Runs the e-graph's pending congruence merges; a no-op when the e-graph is disabled.
///
/// # Errors
///
/// Propagates any error from `EGraph::process_congruences`.
pub fn process_congruences(&mut self) -> Result<(), String> {
    if self.config.enable_egraph {
        self.egraph.process_congruences()?;
    }
    Ok(())
}
/// Placeholder: the body is empty, so `_term_theories` is ignored and no state changes.
pub fn detect_shared_terms(&mut self, _term_theories: &FxHashMap<TermId, FxHashSet<TheoryId>>) {
}
/// Folds a sequence of equalities into one explanation.
///
/// A single equality yields its stored explanation (or `Given` if none is
/// stored). Longer sequences are combined left to right into nested
/// `Transitive` nodes, each linking the chain so far with the next equality.
///
/// The `intermediate` of each node is the term that the previous equality
/// shares with the next one. `Equality::new` orders operands by raw id, so the
/// shared term can sit on either side; if the two equalities share no term, the
/// previous equality's `rhs` is used.
///
/// # Errors
///
/// Fails when `equalities` is empty.
pub fn build_explanation_chain(
    &self,
    equalities: &[Equality],
) -> Result<EqualityExplanation, String> {
    let Some((&first, rest)) = equalities.split_first() else {
        return Err("No equalities to explain".to_string());
    };
    let stored_or_given = |eq: &Equality| {
        self.explanations
            .get(eq)
            .cloned()
            .unwrap_or(EqualityExplanation::Given)
    };

    let mut current = first;
    let mut explanation = stored_or_given(&current);
    for &eq in rest {
        let intermediate = if current.rhs == eq.lhs || current.rhs == eq.rhs {
            current.rhs
        } else if current.lhs == eq.lhs || current.lhs == eq.rhs {
            current.lhs
        } else {
            current.rhs
        };
        explanation = EqualityExplanation::Transitive {
            intermediate,
            left: Box::new(explanation),
            right: Box::new(stored_or_given(&eq)),
        };
        current = eq;
    }
    Ok(explanation)
}
}
impl Default for AdvancedSharedTermsManager {
fn default() -> Self {
Self::default_config()
}
}
/// Collapses a set of candidate interface terms to one term per equivalence class.
///
/// Derives `Debug` like every other public type in this file.
#[derive(Debug)]
pub struct InterfaceTermMinimizer {
    /// Equivalence classes built from the equalities added so far.
    egraph: EGraph,
    /// Terms to be reduced by `minimize`.
    candidates: FxHashSet<TermId>,
}
impl InterfaceTermMinimizer {
    /// Creates a minimizer with no candidates and an empty e-graph.
    pub fn new() -> Self {
        Self {
            candidates: FxHashSet::default(),
            egraph: EGraph::new(),
        }
    }

    /// Adds `term` to the candidate set and to the e-graph.
    pub fn add_candidate(&mut self, term: TermId) {
        self.egraph.add_term(term);
        self.candidates.insert(term);
    }

    /// Records `lhs = rhs`, adding either term to the e-graph if needed.
    ///
    /// Neither term becomes a candidate through this call.
    ///
    /// # Errors
    ///
    /// Propagates any error from the e-graph merge.
    pub fn add_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
        let left_class = self.egraph.add_term(lhs);
        let right_class = self.egraph.add_term(rhs);
        self.egraph
            .merge(left_class, right_class, EqualityExplanation::Given)
            .map(|_| ())
    }

    /// Returns the distinct class representatives of the candidates, in unspecified order.
    ///
    /// A candidate with no representative in the e-graph stands for itself.
    pub fn minimize(&mut self) -> Vec<TermId> {
        let Self { egraph, candidates } = self;
        let mut representatives = FxHashSet::default();
        for &candidate in candidates.iter() {
            let representative = egraph.get_representative(candidate).unwrap_or(candidate);
            representatives.insert(representative);
        }
        representatives.into_iter().collect()
    }

    /// Forgets all candidates and equalities.
    pub fn clear(&mut self) {
        self.candidates.clear();
        self.egraph.clear();
    }
}
impl Default for InterfaceTermMinimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a term id from a raw number.
    fn term(id: u32) -> TermId {
        TermId::new(id)
    }

    /// Operand order must not affect the constructed equality.
    #[test]
    fn test_equality_creation() {
        let forward = Equality::new(term(1), term(2));
        let backward = Equality::new(term(2), term(1));
        assert_eq!(forward, backward);
    }

    /// A fresh e-graph starts allocating class ids at zero.
    #[test]
    fn test_egraph_creation() {
        assert_eq!(EGraph::new().next_eclass_id, 0);
    }

    /// Adding the same term twice yields the same class.
    #[test]
    fn test_egraph_add_term() {
        let mut graph = EGraph::new();
        let first = graph.add_term(term(1));
        let second = graph.add_term(term(1));
        assert_eq!(first, second);
    }

    /// Merging two classes makes their terms equal.
    #[test]
    fn test_egraph_merge() {
        let mut graph = EGraph::new();
        let left = graph.add_term(term(1));
        let right = graph.add_term(term(2));
        graph
            .merge(left, right, EqualityExplanation::Given)
            .expect("Merge failed");
        assert!(graph.are_equal(term(1), term(2)));
    }

    /// Equality is transitive across successive merges.
    #[test]
    fn test_egraph_transitivity() {
        let mut graph = EGraph::new();
        let first = graph.add_term(term(1));
        let second = graph.add_term(term(2));
        let third = graph.add_term(term(3));
        graph
            .merge(first, second, EqualityExplanation::Given)
            .expect("Merge failed");
        graph
            .merge(second, third, EqualityExplanation::Given)
            .expect("Merge failed");
        assert!(graph.are_equal(term(1), term(3)));
    }

    /// A new manager has not registered anything.
    #[test]
    fn test_manager_creation() {
        let mgr = AdvancedSharedTermsManager::default_config();
        assert_eq!(mgr.stats().terms_registered, 0);
    }

    /// A term registered by two theories is shared and an interface term.
    #[test]
    fn test_register_term() {
        let mut mgr = AdvancedSharedTermsManager::default_config();
        for theory in [0, 1] {
            mgr.register_term(term(1), theory);
        }
        assert!(mgr.is_shared(term(1)));
        assert!(mgr.is_interface_term(term(1)));
        assert_eq!(mgr.get_theories(term(1)).len(), 2);
    }

    /// Asserting an equality merges the terms and queues one equality.
    #[test]
    fn test_equality_assertion() {
        let mut mgr = AdvancedSharedTermsManager::default_config();
        mgr.assert_equality(term(1), term(2), EqualityExplanation::Given)
            .expect("Assert failed");
        assert!(mgr.are_equal(term(1), term(2)));
        assert_eq!(mgr.get_pending_equalities().len(), 1);
    }

    /// Equal terms share a representative.
    #[test]
    fn test_representative() {
        let mut mgr = AdvancedSharedTermsManager::default_config();
        mgr.assert_equality(term(1), term(2), EqualityExplanation::Given)
            .expect("Assert failed");
        let from_first = mgr.get_representative(term(1));
        let from_second = mgr.get_representative(term(2));
        assert_eq!(from_first, from_second);
    }

    /// The class of a merged term lists both terms.
    #[test]
    fn test_eclass_members() {
        let mut mgr = AdvancedSharedTermsManager::default_config();
        mgr.assert_equality(term(1), term(2), EqualityExplanation::Given)
            .expect("Assert failed");
        let class = mgr.get_eclass_members(term(1));
        assert!(class.contains(&term(1)));
        assert!(class.contains(&term(2)));
    }

    /// Flushing empties the pending queue.
    #[test]
    fn test_flush_equalities() {
        let mut mgr = AdvancedSharedTermsManager::default_config();
        mgr.assert_equality(term(1), term(2), EqualityExplanation::Given)
            .expect("Assert failed");
        assert_eq!(mgr.get_pending_equalities().len(), 1);
        mgr.flush_equalities();
        assert_eq!(mgr.get_pending_equalities().len(), 0);
    }

    /// Two equal interface terms collapse to a single representative.
    #[test]
    fn test_interface_term_minimization() {
        let mut mgr = AdvancedSharedTermsManager::default_config();
        for id in [1, 2] {
            mgr.register_term(term(id), 0);
            mgr.register_term(term(id), 1);
        }
        mgr.assert_equality(term(1), term(2), EqualityExplanation::Given)
            .expect("Assert failed");
        let reduced = mgr.minimize_interface();
        assert_eq!(reduced.len(), 1);
    }

    /// Backtracking forgets terms registered above the target level only.
    #[test]
    fn test_decision_levels() {
        let mut mgr = AdvancedSharedTermsManager::default_config();
        mgr.push_decision_level();
        mgr.register_term(term(1), 0);
        mgr.push_decision_level();
        mgr.register_term(term(2), 0);
        mgr.backtrack(1).expect("Backtrack failed");
        assert!(mgr.terms.contains_key(&term(1)));
        assert!(!mgr.terms.contains_key(&term(2)));
    }

    /// Three candidates with one equality reduce to two representatives.
    #[test]
    fn test_interface_minimizer() {
        let mut reducer = InterfaceTermMinimizer::new();
        for id in 1..=3 {
            reducer.add_candidate(term(id));
        }
        reducer
            .add_equality(term(1), term(2))
            .expect("Equality failed");
        let reduced = reducer.minimize();
        assert_eq!(reduced.len(), 2);
    }
}