#![allow(dead_code)]
#[allow(unused_imports)]
use crate::prelude::*;
use oxiz_core::ast::{TermId, TermKind, TermManager};
/// Congruence-closure style equality propagator.
///
/// Asserted equalities are merged in a union-find structure (kept in step with
/// an e-graph), parents of merged classes are checked for congruence, and
/// registered watches are counted when their two sides become equal.
#[derive(Debug, Clone)]
pub struct EqualityPropagator {
    /// Equivalence classes of terms.
    union_find: UnionFind,
    /// Parent (use) lists and queued congruence candidates.
    congruence: CongruenceData,
    /// Equalities waiting to be merged, together with their justification.
    pending: VecDeque<(TermId, TermId, Explanation)>,
    /// Watches, stored once under each of their two endpoint terms.
    watched: FxHashMap<TermId, Vec<EqualityWatch>>,
    /// E-graph merged in lockstep with `union_find`.
    egraph: EGraph,
    /// Counters describing propagation work.
    stats: EqualityPropStats,
}
/// Disjoint-set forest over `TermId`s using union by rank and path compression.
///
/// Terms are registered lazily: the first `find` on an unseen term makes it a
/// singleton set.
#[derive(Debug, Clone)]
pub struct UnionFind {
/// Parent pointer of each registered term; a term that is its own parent is a root.
parent: FxHashMap<TermId, TermId>,
/// Rank of each root; `union` attaches the lower-ranked root under the higher one.
rank: FxHashMap<TermId, usize>,
/// Number of terms in the set rooted at the key (only kept up to date for roots).
size: FxHashMap<TermId, usize>,
}
/// Bookkeeping for congruence closure: per-term use lists and the queue of
/// candidate pairs to test for congruence.
#[derive(Debug, Clone)]
pub struct CongruenceData {
/// Read via `get_parents` and treated as the terms that use the key as an
/// argument. NOTE(review): nothing in this file inserts initial entries;
/// presumably populated elsewhere — confirm.
use_list: FxHashMap<TermId, Vec<TermId>>,
/// Signature table keyed by kind + argument list; not read or written in this file.
lookup: FxHashMap<CongruenceKey, TermId>,
/// Candidate pairs queued when classes merge and drained by `check_congruences`.
pending_congruences: VecDeque<(TermId, TermId)>,
}
/// Key type of `CongruenceData::lookup`: a term kind together with its argument list.
/// Not constructed anywhere in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CongruenceKey {
/// Kind (operator) of the term.
pub function: TermKind,
/// Argument terms, in order. NOTE(review): presumably class representatives — confirm.
pub args: Vec<TermId>,
}
/// Partition of terms into e-classes, each with a member list and metadata.
#[derive(Debug, Clone)]
pub struct EGraph {
/// Class id of every term seen by `get_eclass` / `merge`.
eclass: FxHashMap<TermId, EClassId>,
/// Member terms of each live class; an absorbed class's entry is removed on merge.
nodes: FxHashMap<EClassId, Vec<TermId>>,
/// Metadata of each live class; an absorbed class's entry is removed on merge.
data: FxHashMap<EClassId, EClassData>,
/// Next fresh class id; ids only ever increase and are never reused.
next_id: EClassId,
}
/// Identifier of an e-class, allocated sequentially from 0 by `EGraph::get_eclass`.
pub type EClassId = usize;
/// Per-class metadata kept by `EGraph`.
#[derive(Debug, Clone)]
pub struct EClassData {
/// Term that created the class in `get_eclass`; `merge` does not update it.
pub representative: TermId,
/// Number of member terms; `merge` adds the absorbed class's size.
pub size: usize,
/// Initialized empty and never written in this file.
pub parents: Vec<EClassId>,
}
/// Justification for why two terms are equal.
#[derive(Debug, Clone)]
pub enum Explanation {
/// Asserted directly. `explain_equality` also returns this as a placeholder for
/// any derived equality between distinct terms (no proof is reconstructed).
Given,
/// Both sides are the same term.
Reflexivity,
/// Chains two explanations through an intermediate term. Not constructed in this
/// file — NOTE(review): presumably `(middle, lhs = middle, middle = rhs)`; confirm.
Transitivity(TermId, Box<Explanation>, Box<Explanation>),
/// One `(arg1, arg2, why arg1 = arg2)` entry per argument position of two
/// congruent terms, as built by `generate_congruence_explanation`.
Congruence(Vec<(TermId, TermId, Box<Explanation>)>),
/// Equality derived by another theory. Not constructed in this file.
TheoryPropagation(TheoryExplanation),
}
/// Payload of `Explanation::TheoryPropagation`. Not constructed in this file.
#[derive(Debug, Clone)]
pub struct TheoryExplanation {
/// Identifier of the originating theory. NOTE(review): numbering scheme is defined elsewhere.
pub theory_id: usize,
/// Premises the theory relied on — presumably equalities as `(lhs, rhs)`; confirm.
pub antecedents: Vec<(TermId, TermId)>,
}
/// Request registered by `watch_equality`; stored under both `lhs` and `rhs`.
#[derive(Debug, Clone)]
pub struct EqualityWatch {
/// One side of the watched equality.
pub lhs: TermId,
/// The other side of the watched equality.
pub rhs: TermId,
/// Opaque caller-supplied id. This file counts triggered watches but never
/// invokes or returns this value.
pub callback: usize,
}
/// Counters describing the work done by `EqualityPropagator`.
#[derive(Debug, Clone, Default)]
pub struct EqualityPropStats {
/// Successful unions performed by `propagate_equality`.
pub equalities_propagated: usize,
/// Candidate pairs that `check_congruences` found congruent.
pub congruences_found: usize,
/// E-graph merges, incremented alongside each successful union.
pub egraph_merges: usize,
/// Congruence explanations built by `generate_congruence_explanation`.
pub explanations_generated: usize,
/// Watch hits counted by `trigger_watches`.
pub watch_triggers: usize,
}
impl UnionFind {
    /// Creates an empty structure; terms are registered lazily by `find`.
    pub fn new() -> Self {
        Self {
            parent: FxHashMap::default(),
            rank: FxHashMap::default(),
            size: FxHashMap::default(),
        }
    }

    /// Returns the root of `x`'s set, registering `x` as a singleton on first use.
    ///
    /// Every term on the path from `x` to the root is re-pointed directly at the
    /// root (path compression).
    pub fn find(&mut self, x: TermId) -> TermId {
        let mut root = match self.parent.get(&x).copied() {
            Some(up) => up,
            None => {
                self.parent.insert(x, x);
                self.rank.insert(x, 0);
                self.size.insert(x, 1);
                return x;
            }
        };
        // First pass: climb until a self-parented term is reached.
        loop {
            let up = self.parent[&root];
            if up == root {
                break;
            }
            root = up;
        }
        // Second pass: hang every term on the walked path directly off the root.
        let mut cursor = x;
        while cursor != root {
            let next = self.parent[&cursor];
            self.parent.insert(cursor, root);
            cursor = next;
        }
        root
    }

    /// Joins the sets of `x` and `y` using union by rank.
    ///
    /// Returns `false` if they were already in the same set, `true` otherwise.
    pub fn union(&mut self, x: TermId, y: TermId) -> bool {
        let first = self.find(x);
        let second = self.find(y);
        if first == second {
            return false;
        }
        let first_rank = self.rank.get(&first).copied().unwrap_or(0);
        let second_rank = self.rank.get(&second).copied().unwrap_or(0);
        // The strictly lower-ranked root becomes the child; on a tie `y`'s root
        // goes under `x`'s root and that root's rank grows by one.
        let (child, root) = if first_rank < second_rank {
            (first, second)
        } else {
            (second, first)
        };
        if first_rank == second_rank {
            *self.rank.entry(root).or_insert(0) += 1;
        }
        self.parent.insert(child, root);
        let child_size = self.size.get(&child).copied().unwrap_or(1);
        *self.size.entry(root).or_insert(1) += child_size;
        true
    }

    /// Reports whether `x` and `y` are in the same set (registering either if new).
    pub fn connected(&mut self, x: TermId, y: TermId) -> bool {
        let root_of_x = self.find(x);
        root_of_x == self.find(y)
    }

    /// Returns the number of terms in `x`'s set.
    pub fn set_size(&mut self, x: TermId) -> usize {
        let representative = self.find(x);
        self.size[&representative]
    }
}
impl EqualityPropagator {
    /// Creates a propagator with no asserted equalities, no watches and zeroed statistics.
    pub fn new() -> Self {
        Self {
            union_find: UnionFind::new(),
            congruence: CongruenceData::new(),
            pending: VecDeque::new(),
            watched: FxHashMap::default(),
            egraph: EGraph::new(),
            stats: EqualityPropStats::default(),
        }
    }

    /// Asserts `lhs = rhs` and propagates its consequences to a fixpoint.
    ///
    /// If the two terms are already known to be equal this is a no-op and
    /// `explanation` is dropped.
    ///
    /// # Errors
    ///
    /// Returns an error if a term inspected during congruence checking is not
    /// present in `tm`. Work still queued at that point stays queued.
    pub fn assert_equality(
        &mut self,
        lhs: TermId,
        rhs: TermId,
        explanation: Explanation,
        tm: &TermManager,
    ) -> Result<(), String> {
        if self.union_find.connected(lhs, rhs) {
            return Ok(());
        }
        self.pending.push_back((lhs, rhs, explanation));
        self.propagate(tm)
    }

    /// Alternates between merging pending equalities and checking congruence
    /// candidates until both queues are empty.
    ///
    /// Congruence checks enqueue new equalities, and merging those may enqueue new
    /// congruence candidates, so one pass over each queue is not enough. The loop
    /// terminates because candidates are only produced by successful unions, and
    /// there can be at most one fewer union than there are registered terms.
    fn propagate(&mut self, tm: &TermManager) -> Result<(), String> {
        loop {
            while let Some((lhs, rhs, explanation)) = self.pending.pop_front() {
                self.propagate_equality(lhs, rhs, explanation, tm)?;
            }
            if self.congruence.pending_congruences.is_empty() {
                return Ok(());
            }
            self.check_congruences(tm)?;
        }
    }

    /// Merges the classes of `lhs` and `rhs`, counts newly satisfied watches and
    /// queues congruence candidates.
    ///
    /// Parent lists are read from the two class roots *before* the merge, so each
    /// candidate pairs a parent of one class with a parent of the other.
    /// `_explanation` is currently not recorded.
    fn propagate_equality(
        &mut self,
        lhs: TermId,
        rhs: TermId,
        _explanation: Explanation,
        _tm: &TermManager,
    ) -> Result<(), String> {
        let lhs_root = self.union_find.find(lhs);
        let rhs_root = self.union_find.find(rhs);
        if lhs_root == rhs_root {
            return Ok(());
        }
        // Snapshot both classes before they are merged.
        let lhs_parents = self.congruence.get_parents(lhs_root);
        let rhs_parents = self.congruence.get_parents(rhs_root);
        self.trigger_watches(lhs, rhs)?;

        self.union_find.union(lhs, rhs);
        self.stats.equalities_propagated += 1;
        self.egraph.merge(lhs, rhs);
        self.stats.egraph_merges += 1;
        // Use lists are keyed by class root. The merged list is stored under both
        // old roots, so whichever one `union` kept as the new root owns it.
        self.congruence.merge_use_lists(lhs_root, rhs_root);

        for &lhs_parent in &lhs_parents {
            for &rhs_parent in &rhs_parents {
                if lhs_parent != rhs_parent {
                    self.congruence
                        .pending_congruences
                        .push_back((lhs_parent, rhs_parent));
                }
            }
        }
        Ok(())
    }

    /// Drains the congruence-candidate queue and enqueues an equality for every
    /// pair found congruent.
    fn check_congruences(&mut self, tm: &TermManager) -> Result<(), String> {
        while let Some((t1, t2)) = self.congruence.pending_congruences.pop_front() {
            if self.are_congruent(t1, t2, tm)? {
                self.stats.congruences_found += 1;
                let explanation = self.generate_congruence_explanation(t1, t2, tm)?;
                self.pending.push_back((t1, t2, explanation));
            }
        }
        Ok(())
    }

    /// Returns whether `t1` and `t2` have the same kind and pairwise-equal arguments.
    ///
    /// # Errors
    ///
    /// Returns an error if either term is not present in `tm`.
    fn are_congruent(&mut self, t1: TermId, t2: TermId, tm: &TermManager) -> Result<bool, String> {
        if t1 == t2 {
            return Ok(true);
        }
        let term1 = tm.get(t1).ok_or("term not found")?;
        let term2 = tm.get(t2).ok_or("term not found")?;
        if core::mem::discriminant(&term1.kind) != core::mem::discriminant(&term2.kind) {
            return Ok(false);
        }
        let args1 = self.get_args(&term1.kind);
        let args2 = self.get_args(&term2.kind);
        // `get_args` returns no arguments both for leaves and for every kind it
        // does not decompose. Sharing an enum variant says nothing about such
        // terms' contents, so two distinct ones must never be merged here.
        if args1.is_empty() || args1.len() != args2.len() {
            return Ok(false);
        }
        for (&arg1, &arg2) in args1.iter().zip(args2.iter()) {
            if !self.union_find.connected(arg1, arg2) {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Builds an `Explanation::Congruence` for `t1 = t2` from per-argument explanations.
    ///
    /// # Errors
    ///
    /// Returns an error if either term is missing from `tm` or an argument pair
    /// is not known to be equal.
    fn generate_congruence_explanation(
        &mut self,
        t1: TermId,
        t2: TermId,
        tm: &TermManager,
    ) -> Result<Explanation, String> {
        let term1 = tm.get(t1).ok_or("term not found")?;
        let term2 = tm.get(t2).ok_or("term not found")?;
        let args1 = self.get_args(&term1.kind);
        let args2 = self.get_args(&term2.kind);
        let mut arg_explanations = Vec::new();
        for (arg1, arg2) in args1.iter().zip(args2.iter()) {
            let expl = self.explain_equality(*arg1, *arg2)?;
            arg_explanations.push((*arg1, *arg2, Box::new(expl)));
        }
        self.stats.explanations_generated += 1;
        Ok(Explanation::Congruence(arg_explanations))
    }

    /// Explains why `lhs = rhs`.
    ///
    /// Returns `Reflexivity` for identical terms and `Given` for any other pair in
    /// the same class. `Given` is a placeholder here: no proof is reconstructed.
    ///
    /// # Errors
    ///
    /// Returns an error if the terms are not known to be equal.
    pub fn explain_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<Explanation, String> {
        if lhs == rhs {
            return Ok(Explanation::Reflexivity);
        }
        if !self.union_find.connected(lhs, rhs) {
            return Err("Terms are not equal".to_string());
        }
        Ok(Explanation::Given)
    }

    /// Registers interest in `lhs = rhs` under the opaque id `callback`.
    ///
    /// The watch is counted by a later merge that makes its two sides equal. A
    /// watch whose sides are already equal when registered is never counted.
    pub fn watch_equality(&mut self, lhs: TermId, rhs: TermId, callback: usize) {
        let watch = EqualityWatch { lhs, rhs, callback };
        self.watched.entry(lhs).or_default().push(watch.clone());
        self.watched.entry(rhs).or_default().push(watch);
    }

    /// Counts the watches that merging the classes of `lhs` and `rhs` newly satisfies.
    ///
    /// Must be called *before* the classes are merged. Every watch is stored under
    /// both endpoints, so a newly satisfied watch has one endpoint in each class.
    /// Scanning only the smaller class therefore finds each such watch exactly once.
    /// Watches that were already satisfied are not counted again. Callbacks are
    /// only counted (in `stats.watch_triggers`), never invoked.
    fn trigger_watches(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
        let lhs_root = self.union_find.find(lhs);
        let rhs_root = self.union_find.find(rhs);
        if lhs_root == rhs_root {
            return Ok(());
        }
        let lhs_class = self.egraph.get_eclass(lhs);
        let rhs_class = self.egraph.get_eclass(rhs);
        let lhs_len = self.egraph.nodes.get(&lhs_class).map_or(0, |nodes| nodes.len());
        let rhs_len = self.egraph.nodes.get(&rhs_class).map_or(0, |nodes| nodes.len());
        let (scanned_class, other_root) = if lhs_len <= rhs_len {
            (lhs_class, rhs_root)
        } else {
            (rhs_class, lhs_root)
        };

        let mut hits = 0usize;
        if let Some(members) = self.egraph.nodes.get(&scanned_class) {
            for member in members {
                if let Some(watches) = self.watched.get(member) {
                    for watch in watches {
                        let other = if watch.lhs == *member { watch.rhs } else { watch.lhs };
                        if self.union_find.find(other) == other_root {
                            hits += 1;
                        }
                    }
                }
            }
        }
        self.stats.watch_triggers += hits;
        Ok(())
    }

    /// Returns the direct arguments of the kinds this propagator decomposes.
    ///
    /// Those kinds are `And`, `Or`, `Not`, `Eq`, `Le`, `Lt`, `Add` and `Mul`.
    /// Every other kind yields an empty vector.
    fn get_args(&self, kind: &TermKind) -> Vec<TermId> {
        match kind {
            TermKind::And(args) | TermKind::Or(args) => args.to_vec(),
            TermKind::Not(arg) => vec![*arg],
            TermKind::Eq(l, r) | TermKind::Le(l, r) | TermKind::Lt(l, r) => vec![*l, *r],
            TermKind::Add(args) | TermKind::Mul(args) => args.to_vec(),
            _ => vec![],
        }
    }

    /// Returns the accumulated propagation statistics.
    pub fn stats(&self) -> &EqualityPropStats {
        &self.stats
    }
}
impl CongruenceData {
    /// Creates empty use lists, an empty signature table and an empty candidate queue.
    pub fn new() -> Self {
        Self {
            use_list: FxHashMap::default(),
            lookup: FxHashMap::default(),
            pending_congruences: VecDeque::new(),
        }
    }

    /// Concatenates the use lists of `t1` and `t2` (`t1`'s entries first) and
    /// stores a copy of the combined list under both keys.
    pub fn merge_use_lists(&mut self, t1: TermId, t2: TermId) {
        let mut combined = self.get_parents(t1);
        combined.extend(self.get_parents(t2));
        self.use_list.insert(t1, combined.clone());
        self.use_list.insert(t2, combined);
    }

    /// Returns a copy of `t`'s use list, or an empty list if it has none.
    pub fn get_parents(&self, t: TermId) -> Vec<TermId> {
        self.use_list.get(&t).map_or_else(Vec::new, Vec::clone)
    }
}
impl EGraph {
    /// Creates an e-graph with no classes; ids are allocated from 0.
    pub fn new() -> Self {
        Self {
            eclass: FxHashMap::default(),
            nodes: FxHashMap::default(),
            data: FxHashMap::default(),
            next_id: 0,
        }
    }

    /// Returns the class of `term`, creating a fresh singleton class on first sight.
    pub fn get_eclass(&mut self, term: TermId) -> EClassId {
        if let Some(&known) = self.eclass.get(&term) {
            return known;
        }
        let fresh = self.next_id;
        self.next_id += 1;
        self.eclass.insert(term, fresh);
        self.nodes.insert(fresh, vec![term]);
        let metadata = EClassData {
            representative: term,
            size: 1,
            parents: Vec::new(),
        };
        self.data.insert(fresh, metadata);
        fresh
    }

    /// Merges the classes of `t1` and `t2`; the smaller class is folded into the
    /// larger one (ties keep `t1`'s class).
    pub fn merge(&mut self, t1: TermId, t2: TermId) {
        let id1 = self.get_eclass(t1);
        let id2 = self.get_eclass(t2);
        if id1 == id2 {
            return;
        }
        let (absorbed, survivor) = if self.data[&id1].size < self.data[&id2].size {
            (id1, id2)
        } else {
            (id2, id1)
        };

        // Re-home every member of the absorbed class, then drop its member list.
        let moved = self.nodes[&absorbed].clone();
        for &member in &moved {
            self.eclass.insert(member, survivor);
        }
        if let Some(survivor_nodes) = self.nodes.get_mut(&survivor) {
            survivor_nodes.extend(moved);
        }
        self.nodes.remove(&absorbed);

        // Fold the absorbed class's size into the survivor and drop its metadata.
        let absorbed_size = self.data.remove(&absorbed).map_or(0, |meta| meta.size);
        if let Some(survivor_data) = self.data.get_mut(&survivor) {
            survivor_data.size += absorbed_size;
        }
    }
}
// Implements `Default` for each listed type by delegating to its zero-argument `new`.
macro_rules! default_via_new {
    ($($ty:ident),+ $(,)?) => {
        $(
            impl Default for $ty {
                fn default() -> Self {
                    Self::new()
                }
            }
        )+
    };
}

default_via_new!(EqualityPropagator, UnionFind, CongruenceData, EGraph);
#[cfg(test)]
mod tests {
    use super::*;

    /// Union is transitive: after joining 1~2 and 2~3, terms 1 and 3 share a set.
    #[test]
    fn test_union_find() {
        let (a, b, c) = (TermId::from(1), TermId::from(2), TermId::from(3));
        let mut sets = UnionFind::new();

        assert!(!sets.connected(a, b));
        sets.union(a, b);
        assert!(sets.connected(a, b));
        sets.union(b, c);
        assert!(sets.connected(a, c));
    }

    /// A freshly built propagator has not propagated anything yet.
    #[test]
    fn test_equality_propagator() {
        let propagator = EqualityPropagator::new();
        assert_eq!(propagator.stats().equalities_propagated, 0);
    }

    /// Distinct terms start in distinct e-classes and share one after `merge`.
    #[test]
    fn test_egraph() {
        let mut graph = EGraph::new();
        let (a, b) = (TermId::from(1), TermId::from(2));

        let (class_a, class_b) = (graph.get_eclass(a), graph.get_eclass(b));
        assert_ne!(class_a, class_b);

        graph.merge(a, b);
        assert_eq!(graph.get_eclass(a), graph.get_eclass(b));
    }
}