use std::hash::{Hash, Hasher};
use rustc_hash::FxHashMap;
use rustc_hash::FxHasher;
use smallvec::SmallVec;
/// Child (operand) class ids of an e-node, in operand order.
///
/// A `SmallVec` with inline capacity 4, so listing the children of a node with at most
/// four operands does not heap-allocate.
pub type EChildren = SmallVec<[EClassId; 4]>;
/// Identifier of an e-class.
///
/// The wrapped value is the index of the class in the e-graph's internal tables; ids are
/// handed out densely from 0 in creation order. An id may stop being canonical after a
/// union — resolve it with `EGraph::find` / `EGraph::find_immut` before comparing classes.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct EClassId(pub u32);
/// An e-node language: an operator whose operands are e-class ids.
///
/// E-nodes are stored as hashcons keys, which is why `Eq + Hash` are required, and they
/// are cloned freely by the e-graph.
pub trait ENodeLang: Hash + Eq + Clone {
    /// Returns a copy of `self` whose operands are replaced, in order, by `children`.
    ///
    /// Within this module the slice always has exactly as many ids as
    /// [`ENodeLang::children`] returned for `self`.
    fn with_children(&self, children: &[EClassId]) -> Self;

    /// Returns the operand class ids of this e-node, in operand order.
    fn children(&self) -> EChildren;
}
/// One equivalence class of e-nodes.
#[derive(Debug, Clone)]
pub struct EClass<L>
where
    L: ENodeLang,
{
    /// The e-nodes known to be equivalent. The slot of a class that lost a union is left
    /// with an empty list.
    pub nodes: Vec<L>,
    /// Ids of classes that contain an e-node having this class as a child. Entries may be
    /// duplicated or non-canonical until the e-graph is rebuilt.
    pub parents: Vec<EClassId>,
}
/// An e-graph: a set of e-classes plus a union-find over their ids.
#[derive(Debug, Clone)]
pub struct EGraph<L>
where
    L: ENodeLang,
{
    /// Class storage, indexed by `EClassId.0`.
    classes: Vec<EClass<L>>,
    /// Maps an e-node (with canonicalized children) to the class that owns it. The stored
    /// id may itself be non-canonical; resolve it with `find`.
    hashcons: FxHashMap<L, EClassId>,
    /// Union-find parent pointers, indexed by `EClassId.0`; `parent[i] == EClassId(i)`
    /// marks a root (canonical id).
    parent: Vec<EClassId>,
    /// Roots produced by `union` that have not yet been processed by `rebuild`.
    pending: Vec<EClassId>,
}
impl<L: ENodeLang> Default for EGraph<L> {
fn default() -> Self {
Self::new()
}
}
impl<L: ENodeLang> EGraph<L> {
    /// Creates an empty e-graph.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty e-graph whose internal tables are preallocated for
    /// `class_capacity` classes.
    #[must_use]
    pub fn with_capacity(class_capacity: usize) -> Self {
        Self {
            classes: Vec::with_capacity(class_capacity),
            hashcons: FxHashMap::with_capacity_and_hasher(class_capacity, Default::default()),
            parent: Vec::with_capacity(class_capacity),
            pending: Vec::with_capacity(class_capacity),
        }
    }

    /// Number of class ids ever allocated.
    ///
    /// This includes classes that have since been merged into another class. Their slots
    /// remain but hold no nodes.
    #[must_use]
    pub fn class_count(&self) -> usize {
        self.classes.len()
    }

    /// Returns the canonical representative (union-find root) of `id`.
    ///
    /// Every id on the walked path is re-pointed directly at the root (path compression).
    ///
    /// # Panics
    ///
    /// Panics if `id` was not issued by this e-graph.
    pub fn find(&mut self, id: EClassId) -> EClassId {
        let mut cur = id;
        while self.parent[cur.0 as usize] != cur {
            cur = self.parent[cur.0 as usize];
        }
        let mut walk = id;
        while self.parent[walk.0 as usize] != cur {
            let next = self.parent[walk.0 as usize];
            self.parent[walk.0 as usize] = cur;
            walk = next;
        }
        cur
    }

    /// Like [`EGraph::find`] but without path compression, so it works through `&self`.
    ///
    /// # Panics
    ///
    /// Panics if `id` was not issued by this e-graph.
    #[must_use]
    pub fn find_immut(&self, id: EClassId) -> EClassId {
        let mut cur = id;
        while self.parent[cur.0 as usize] != cur {
            cur = self.parent[cur.0 as usize];
        }
        cur
    }

    /// Returns a copy of `node` with every child replaced by its canonical id.
    fn canonicalize(&self, node: &L) -> L {
        let canon_children: EChildren = node
            .children()
            .into_iter()
            .map(|c| self.find_immut(c))
            .collect();
        node.with_children(&canon_children)
    }

    /// Adds `node` and returns the canonical id of the class that contains it.
    ///
    /// If an equal node (after canonicalizing its children) is already present, that
    /// class is returned and nothing is allocated. Otherwise a new singleton class is
    /// created and registered as a parent of each child class.
    ///
    /// # Panics
    ///
    /// Panics if `node` refers to an id not issued by this e-graph. Also panics if the
    /// class count would exceed what an `EClassId` (u32) can represent.
    pub fn add(&mut self, node: L) -> EClassId {
        let canon = self.canonicalize(&node);
        if let Some(&existing) = self.hashcons.get(&canon) {
            return self.find(existing);
        }
        let next_index = self.classes.len();
        // A silent `as u32` wrap here would alias an existing class id.
        assert!(
            next_index <= u32::MAX as usize,
            "EGraph::add: class count exceeds EClassId range"
        );
        let new_id = EClassId(next_index as u32);
        self.parent.push(new_id);
        for child in canon.children() {
            let child_canon = self.find(child);
            if let Some(class) = self.classes.get_mut(child_canon.0 as usize) {
                class.parents.push(new_id);
            }
        }
        self.classes.push(EClass {
            nodes: vec![canon.clone()],
            parents: Vec::new(),
        });
        self.hashcons.insert(canon, new_id);
        new_id
    }

    /// Merges the classes of `a` and `b` and returns the surviving root.
    ///
    /// The root with the smaller id wins. The loser's nodes and parent list are moved
    /// into it, and the loser's slot is left empty.
    ///
    /// Congruence is not restored here: parents that became equal through this merge are
    /// not merged yet. The winner is queued instead, and [`EGraph::rebuild`] must run
    /// before relying on the congruence invariant.
    pub fn union(&mut self, a: EClassId, b: EClassId) -> EClassId {
        let a_root = self.find(a);
        let b_root = self.find(b);
        if a_root == b_root {
            return a_root;
        }
        let (winner, loser) = if a_root.0 < b_root.0 {
            (a_root, b_root)
        } else {
            (b_root, a_root)
        };
        self.parent[loser.0 as usize] = winner;
        let loser_class = std::mem::replace(
            &mut self.classes[loser.0 as usize],
            EClass {
                nodes: Vec::new(),
                parents: Vec::new(),
            },
        );
        self.classes[winner.0 as usize]
            .nodes
            .extend(loser_class.nodes);
        self.classes[winner.0 as usize]
            .parents
            .extend(loser_class.parents);
        self.pending.push(winner);
        winner
    }

    /// Restores the congruence invariant after one or more [`EGraph::union`] calls.
    ///
    /// For every class queued by `union`, the class's own nodes are re-canonicalized. So
    /// are the nodes of every parent class, since those may now mention a non-canonical
    /// child.
    ///
    /// Two classes can end up owning the same canonical node. They are then congruent, so
    /// they are unioned too, which queues more work. The loop runs until the queue is
    /// empty.
    ///
    /// Returns the number of additional unions performed.
    pub fn rebuild(&mut self) -> usize {
        let mut new_unions = 0;
        while let Some(pending_id) = self.pending.pop() {
            // The merged class holds both halves' nodes; canonicalize and dedup them.
            new_unions += self.repair_class(pending_id);

            // Parents of the merged class may now be congruent to each other (or to some
            // other class). Repair each distinct parent class once.
            let root = self.find(pending_id);
            let stale_parents = std::mem::take(&mut self.classes[root.0 as usize].parents);
            let parents = self.canonical_unique_ids(stale_parents);
            for &parent in &parents {
                new_unions += self.repair_class(parent);
            }

            // The repairs above may have merged parents together or merged `root` away.
            // Hand the canonical list back to whichever class is the root now.
            let parents = self.canonical_unique_ids(parents);
            let root = self.find(root);
            self.classes[root.0 as usize].parents.extend(parents);
        }
        new_unions
    }

    /// Re-canonicalizes every node of the class containing `id`.
    ///
    /// The class is unioned with any other class that already owns an equal canonical
    /// node. Returns the number of unions performed.
    fn repair_class(&mut self, id: EClassId) -> usize {
        let mut class_id = self.find(id);
        let stale_nodes = std::mem::take(&mut self.classes[class_id.0 as usize].nodes);
        let mut repaired = Vec::with_capacity(stale_nodes.len());
        let mut unions = 0;
        for node in stale_nodes {
            let canon = self.canonicalize(&node);
            if canon != node {
                // `node` mentions a non-root child, so it is no longer any node's
                // canonical form; its hashcons entry is unreachable.
                self.hashcons.remove(&node);
            }
            if let Some(&existing) = self.hashcons.get(&canon) {
                let existing_root = self.find(existing);
                if existing_root != class_id {
                    // Congruence: the same canonical node lives in another class.
                    class_id = self.union(existing_root, class_id);
                    unions += 1;
                }
            }
            self.hashcons.insert(canon.clone(), class_id);
            repaired.push(canon);
        }
        // While our nodes were taken out, a union above may have moved the other class's
        // nodes into the surviving root. Fold them in so nothing is left in a dead slot.
        let root = self.find(class_id);
        repaired.append(&mut self.classes[root.0 as usize].nodes);
        dedup_enodes_by_hash(&mut repaired);
        self.classes[root.0 as usize].nodes = repaired;
        unions
    }

    /// Maps every id to its root, then sorts and removes duplicates.
    fn canonical_unique_ids(&mut self, ids: Vec<EClassId>) -> Vec<EClassId> {
        let mut out: Vec<EClassId> = ids.into_iter().map(|id| self.find(id)).collect();
        out.sort_unstable();
        out.dedup();
        out
    }

    /// Iterates over every e-node of every canonical class, paired with that class's id.
    pub fn iter_nodes(&self) -> impl Iterator<Item = (EClassId, &L)> {
        self.classes
            .iter()
            .enumerate()
            .filter(|(idx, _)| self.parent[*idx] == EClassId(*idx as u32))
            .flat_map(|(idx, class)| class.nodes.iter().map(move |n| (EClassId(idx as u32), n)))
    }

    /// Returns the class that `id` currently belongs to, after resolving it to its root.
    ///
    /// # Panics
    ///
    /// Panics (inside `find_immut`) if `id` was not issued by this e-graph. For valid ids
    /// the result is always `Some`.
    #[must_use]
    pub fn class(&self, id: EClassId) -> Option<&EClass<L>> {
        let canon = self.find_immut(id);
        self.classes.get(canon.0 as usize)
    }
}
/// Removes duplicate e-nodes from `nodes` in place.
///
/// The survivors end up ordered by their hash. Nodes are sorted by hash so that equal
/// nodes share a run. Equality is then checked only against earlier survivors in the same
/// hash run, which also handles distinct nodes that collide on the hash.
fn dedup_enodes_by_hash<L: ENodeLang>(nodes: &mut Vec<L>) {
    if nodes.len() < 2 {
        return;
    }
    let mut hashed: Vec<(u64, L)> = nodes
        .drain(..)
        .map(|n| (stable_enode_hash(&n), n))
        .collect();
    hashed.sort_unstable_by_key(|&(h, _)| h);

    let mut unique: Vec<(u64, L)> = Vec::with_capacity(hashed.len());
    for (h, n) in hashed {
        // Only the tail of `unique` can share this hash, because input is sorted by hash.
        let already_kept = unique
            .iter()
            .rev()
            .take_while(|(kept_hash, _)| *kept_hash == h)
            .any(|(_, kept)| *kept == n);
        if already_kept {
            continue;
        }
        unique.push((h, n));
    }
    nodes.extend(unique.into_iter().map(|(_, n)| n));
}
/// Hashes an e-node with a default-constructed `FxHasher`.
///
/// The result is the same for equal nodes within a run. It is used only to group
/// candidates in `dedup_enodes_by_hash`.
fn stable_enode_hash<L: ENodeLang>(node: &L) -> u64 {
    let mut state = FxHasher::default();
    Hash::hash(node, &mut state);
    state.finish()
}
/// A rewrite rule that discovers equivalences in an e-graph.
pub trait Rule<L: ENodeLang> {
    /// Scans `egraph` and returns pairs of classes that this rule deems equivalent.
    ///
    /// `saturate` unions each returned pair.
    fn matches(&self, egraph: &EGraph<L>) -> Vec<(EClassId, EClassId)>;

    /// Human-readable identifier of the rule.
    fn name(&self) -> &'static str;
}
/// A named group of rules that is saturated together under one iteration budget (see
/// `saturate_per_family`).
pub trait Family<L: ENodeLang> {
    /// Builds the rules belonging to this family.
    fn rules(&self) -> Vec<Box<dyn Rule<L>>>;

    /// Name used to look up the family's budget and to label its report.
    fn name(&self) -> &'static str;
}
/// Applies `rules` to `egraph` until no rule yields a new equivalence, or until
/// `max_iters` iterations have run.
///
/// Each iteration does three things:
/// 1. Collects the matches of every rule against the current graph.
/// 2. Unions every pair whose classes are not already equal.
/// 3. Calls [`EGraph::rebuild`].
///
/// Returns the number of iterations that merged at least one pair of classes. This is the
/// index of the first iteration that found nothing new, or `max_iters` if the budget ran
/// out.
pub fn saturate<L: ENodeLang>(
    egraph: &mut EGraph<L>,
    rules: &[Box<dyn Rule<L>>],
    max_iters: usize,
) -> usize {
    let mut equivalences = Vec::with_capacity(egraph.class_count());
    for iter in 0..max_iters {
        equivalences.clear();
        for rule in rules {
            equivalences.extend(rule.matches(egraph));
        }
        // Rules usually keep re-reporting equivalences that already hold. Only a pair that
        // joins two distinct classes is progress; otherwise any rule that matches anything
        // would prevent a fixed point and always exhaust the budget.
        let mut merged_any = false;
        for (a, b) in equivalences.drain(..) {
            if egraph.find(a) != egraph.find(b) {
                egraph.union(a, b);
                merged_any = true;
            }
        }
        if !merged_any {
            return iter;
        }
        egraph.rebuild();
    }
    max_iters
}
/// Wraps a rule so that it only fires while `predicate` returns `true`.
///
/// NOTE(review): judging by the name, the predicate presumably reports whether the target
/// device supports the rewrite. That is not visible here; confirm with callers.
pub struct DeviceAwareRule<L, F>
where
    L: ENodeLang,
    F: Fn() -> bool,
{
    inner: Box<dyn Rule<L>>,
    predicate: F,
}
impl<L: ENodeLang, F: Fn() -> bool> DeviceAwareRule<L, F> {
pub fn new(inner: Box<dyn Rule<L>>, predicate: F) -> Self {
Self { inner, predicate }
}
}
impl<L, F> Rule<L> for DeviceAwareRule<L, F>
where
    L: ENodeLang,
    F: Fn() -> bool,
{
    /// Reports the wrapped rule's name unchanged.
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Delegates to the wrapped rule while the predicate holds.
    ///
    /// Otherwise returns no matches without consulting the wrapped rule.
    fn matches(&self, egraph: &EGraph<L>) -> Vec<(EClassId, EClassId)> {
        let enabled = (self.predicate)();
        if !enabled {
            return Vec::new();
        }
        self.inner.matches(egraph)
    }
}
/// Outcome of saturating a single rule family in `saturate_per_family`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FamilySaturationReport {
    /// The family's name, as returned by `Family::name`.
    pub family: &'static str,
    /// What `saturate` returned for this family, or 0 when the family was skipped because
    /// its budget was 0.
    pub iters_used: usize,
    /// Iteration budget the family was given.
    pub budget: usize,
}
/// Saturates the e-graph one rule family at a time, each under its own iteration budget.
///
/// `budget_for` maps a family name to its budget. A budget of 0 skips the family: its
/// rules are not even built, but a report entry is still produced. Families run in the
/// order given, each on the graph left behind by the previous one.
///
/// Returns one [`FamilySaturationReport`] per family, in input order.
pub fn saturate_per_family<L: ENodeLang>(
    egraph: &mut EGraph<L>,
    families: &[&dyn Family<L>],
    budget_for: impl Fn(&str) -> usize,
) -> Vec<FamilySaturationReport> {
    families
        .iter()
        .map(|family| {
            let family_name = family.name();
            let budget = budget_for(family_name);
            let iters_used = if budget == 0 {
                0
            } else {
                saturate(egraph, &family.rules(), budget)
            };
            FamilySaturationReport {
                family: family_name,
                iters_used,
                budget,
            }
        })
        .collect()
}
pub fn extract_best<L: ENodeLang>(
egraph: &EGraph<L>,
class_id: EClassId,
cost_fn: impl Fn(&L) -> u64,
) -> Option<(L, u64)> {
let mut costs = FxHashMap::with_capacity_and_hasher(egraph.class_count(), Default::default());
let mut changed = true;
let mut iters = 0;
while changed && iters < 1024 {
changed = false;
iters += 1;
for (cid, node) in egraph.iter_nodes() {
let canon_cid = egraph.find_immut(cid);
let mut node_cost = cost_fn(node);
let mut child_overflow = false;
for child in node.children() {
let canon_child = egraph.find_immut(child);
if let Some((_, c)) = costs.get(&canon_child) {
node_cost = node_cost.saturating_add(*c);
} else {
child_overflow = true;
break;
}
}
if child_overflow {
continue;
}
match costs.get(&canon_cid) {
Some((_, existing_cost)) if *existing_cost <= node_cost => {}
_ => {
costs.insert(canon_cid, (node.clone(), node_cost));
changed = true;
}
}
}
}
let canon = egraph.find_immut(class_id);
costs.get(&canon).cloned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_hash::FxHashSet;
    use smallvec::smallvec;

    /// Minimal arithmetic language used to exercise the e-graph.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Arith {
        Const(u32),
        Add(EClassId, EClassId),
        Mul(EClassId, EClassId),
    }

    impl ENodeLang for Arith {
        fn children(&self) -> EChildren {
            match self {
                Self::Const(_) => EChildren::new(),
                Self::Add(a, b) | Self::Mul(a, b) => smallvec![*a, *b],
            }
        }

        fn with_children(&self, children: &[EClassId]) -> Self {
            match self {
                Self::Const(n) => Self::Const(*n),
                Self::Add(_, _) => Self::Add(children[0], children[1]),
                Self::Mul(_, _) => Self::Mul(children[0], children[1]),
            }
        }
    }

    /// Per-operator cost: constants are cheapest, multiplication the most expensive.
    fn arith_cost(node: &Arith) -> u64 {
        match node {
            Arith::Const(_) => 1,
            Arith::Add(_, _) => 2,
            Arith::Mul(_, _) => 3,
        }
    }

    #[test]
    fn empty_egraph_has_zero_classes() {
        let egraph: EGraph<Arith> = EGraph::new();
        assert_eq!(egraph.class_count(), 0);
    }

    #[test]
    fn add_const_creates_one_class() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let _ = egraph.add(Arith::Const(7));
        assert_eq!(egraph.class_count(), 1);
    }

    #[test]
    fn add_same_const_twice_returns_same_class() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(7));
        let b = egraph.add(Arith::Const(7));
        assert_eq!(a, b);
        assert_eq!(egraph.class_count(), 1);
    }

    #[test]
    fn add_distinct_consts_creates_distinct_classes() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(7));
        let b = egraph.add(Arith::Const(8));
        assert_ne!(a, b);
        assert_eq!(egraph.class_count(), 2);
    }

    #[test]
    fn add_compound_node_creates_proper_class() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(1));
        let b = egraph.add(Arith::Const(2));
        let sum = egraph.add(Arith::Add(a, b));
        assert_eq!(egraph.class_count(), 3);
        assert_ne!(sum, a);
        assert_ne!(sum, b);
    }

    #[test]
    fn union_merges_two_classes() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(1));
        let b = egraph.add(Arith::Const(2));
        let unified = egraph.union(a, b);
        assert_eq!(egraph.find(a), unified);
        assert_eq!(egraph.find(b), unified);
    }

    #[test]
    fn union_is_idempotent() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(1));
        let b = egraph.add(Arith::Const(2));
        let first = egraph.union(a, b);
        let second = egraph.union(a, b);
        assert_eq!(first, second);
    }

    #[test]
    fn rebuild_canonicalizes_compound_nodes_after_union() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let one = egraph.add(Arith::Const(1));
        let two = egraph.add(Arith::Const(2));
        let _add_12 = egraph.add(Arith::Add(one, two));
        let _add_22 = egraph.add(Arith::Add(two, two));
        egraph.union(one, two);
        let _ = egraph.rebuild();
        let post_one = egraph.find(one);
        let post_two = egraph.find(two);
        assert_eq!(post_one, post_two, "1 and 2 must be in the same class");
    }

    #[test]
    fn extract_best_picks_cheapest_equivalent() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let one = egraph.add(Arith::Const(1));
        let two = egraph.add(Arith::Const(2));
        let three = egraph.add(Arith::Const(3));
        let add_12 = egraph.add(Arith::Add(one, two));
        egraph.union(add_12, three);
        let _ = egraph.rebuild();
        let (best, cost) = extract_best(&egraph, add_12, arith_cost).expect("must extract");
        assert_eq!(best, Arith::Const(3));
        assert_eq!(cost, 1);
    }

    #[test]
    fn extract_best_returns_only_node_when_no_alternatives() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(42));
        let (best, cost) = extract_best(&egraph, a, arith_cost).expect("must extract");
        assert_eq!(best, Arith::Const(42));
        assert_eq!(cost, 1);
    }

    #[test]
    fn extract_best_sums_child_costs() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let one = egraph.add(Arith::Const(1));
        let two = egraph.add(Arith::Const(2));
        let product = egraph.add(Arith::Mul(one, two));
        let (best, cost) = extract_best(&egraph, product, arith_cost).expect("must extract");
        assert_eq!(best, Arith::Mul(one, two));
        assert_eq!(cost, 3 + 1 + 1, "Mul cost plus both Const children");
    }

    /// Reports every pair of canonical classes holding the same constant.
    ///
    /// `EGraph::add` hashconses equal constants into one class, so this rule only fires
    /// when duplicates are created some other way.
    struct UnionEqualConstsRule;

    impl Rule<Arith> for UnionEqualConstsRule {
        fn name(&self) -> &'static str {
            "union_equal_consts"
        }

        fn matches(&self, egraph: &EGraph<Arith>) -> Vec<(EClassId, EClassId)> {
            let mut by_value: FxHashMap<u32, Vec<EClassId>> = FxHashMap::default();
            for (cid, node) in egraph.iter_nodes() {
                if let Arith::Const(v) = node {
                    by_value.entry(*v).or_default().push(cid);
                }
            }
            let mut out = Vec::new();
            for ids in by_value.values() {
                for window in ids.windows(2) {
                    out.push((window[0], window[1]));
                }
            }
            out
        }
    }

    /// Test double that always reports the same single equivalence, regardless of the graph.
    struct FixedPairRule(EClassId, EClassId);

    impl Rule<Arith> for FixedPairRule {
        fn name(&self) -> &'static str {
            "fixed_pair"
        }

        fn matches(&self, _egraph: &EGraph<Arith>) -> Vec<(EClassId, EClassId)> {
            vec![(self.0, self.1)]
        }
    }

    #[test]
    fn saturate_runs_to_fixed_point() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let _a = egraph.add(Arith::Const(7));
        let _b = egraph.add(Arith::Const(8));
        let rules: Vec<Box<dyn Rule<Arith>>> = vec![Box::new(UnionEqualConstsRule)];
        let iters = saturate(&mut egraph, &rules, 10);
        assert_eq!(iters, 0, "no rule matches, so the first iteration is already a fixed point");
    }

    #[test]
    fn find_immut_returns_canonical_after_union() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(1));
        let b = egraph.add(Arith::Const(2));
        egraph.union(a, b);
        let canon_a = egraph.find_immut(a);
        let canon_b = egraph.find_immut(b);
        assert_eq!(canon_a, canon_b);
    }

    #[test]
    fn class_lookup_returns_canonical_class() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(7));
        let class = egraph.class(a).expect("class must exist");
        assert!(matches!(class.nodes[0], Arith::Const(7)));
    }

    #[test]
    fn rebuild_propagates_through_parents() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let one = egraph.add(Arith::Const(1));
        let two = egraph.add(Arith::Const(2));
        let add_12 = egraph.add(Arith::Add(one, two));
        egraph.union(one, two);
        let _ = egraph.rebuild();
        let class = egraph.class(add_12).expect("class must still exist");
        match &class.nodes[0] {
            Arith::Add(a, b) => {
                let canon_a = egraph.find_immut(*a);
                let canon_b = egraph.find_immut(*b);
                assert_eq!(
                    canon_a, canon_b,
                    "Add(1,2)'s children must canonicalize to the same class after union"
                );
            }
            other => panic!("expected Add; got {other:?}"),
        }
    }

    #[test]
    fn union_equal_consts_rule_ignores_distinct_values() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let _ = egraph.add(Arith::Const(1));
        let _ = egraph.add(Arith::Const(2));
        assert!(
            UnionEqualConstsRule.matches(&egraph).is_empty(),
            "no duplicate values should produce no matches"
        );
    }

    #[test]
    fn device_aware_rule_predicate_true_forwards_matches() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(1));
        let b = egraph.add(Arith::Const(2));
        let inner: Box<dyn Rule<Arith>> = Box::new(FixedPairRule(a, b));
        let rule = DeviceAwareRule::new(inner, || true);
        assert_eq!(
            rule.matches(&egraph),
            vec![(a, b)],
            "predicate true must forward the inner rule's matches verbatim"
        );
    }

    #[test]
    fn device_aware_rule_predicate_false_returns_empty() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(1));
        let b = egraph.add(Arith::Const(2));
        // The inner rule always matches, so an empty result proves the short-circuit.
        let inner: Box<dyn Rule<Arith>> = Box::new(FixedPairRule(a, b));
        let rule = DeviceAwareRule::new(inner, || false);
        let matches = rule.matches(&egraph);
        assert!(
            matches.is_empty(),
            "predicate false must short-circuit to empty"
        );
    }

    #[test]
    fn device_aware_rule_forwards_inner_name() {
        let inner: Box<dyn Rule<Arith>> = Box::new(UnionEqualConstsRule);
        let rule = DeviceAwareRule::new(inner, || true);
        assert_eq!(rule.name(), "union_equal_consts");
    }

    /// A family consisting solely of [`UnionEqualConstsRule`], with a configurable name.
    struct ConstUnionFamily {
        name: &'static str,
    }

    impl Family<Arith> for ConstUnionFamily {
        fn name(&self) -> &'static str {
            self.name
        }

        fn rules(&self) -> Vec<Box<dyn Rule<Arith>>> {
            vec![Box::new(UnionEqualConstsRule)]
        }
    }

    #[test]
    fn saturate_per_family_skips_zero_budget() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let _ = egraph.add(Arith::Const(7));
        let fam = ConstUnionFamily { name: "f0" };
        let report = saturate_per_family(&mut egraph, &[&fam], |_| 0);
        assert_eq!(report.len(), 1);
        assert_eq!(report[0].family, "f0");
        assert_eq!(report[0].iters_used, 0);
        assert_eq!(report[0].budget, 0);
    }

    #[test]
    fn saturate_per_family_runs_each_family_independently() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let _ = egraph.add(Arith::Const(1));
        let _ = egraph.add(Arith::Const(2));
        let fam_a = ConstUnionFamily { name: "alpha" };
        let fam_b = ConstUnionFamily { name: "beta" };
        let report = saturate_per_family(&mut egraph, &[&fam_a, &fam_b], |name| match name {
            "alpha" => 3,
            "beta" => 5,
            _ => 0,
        });
        assert_eq!(report.len(), 2);
        assert_eq!(report[0].family, "alpha");
        assert_eq!(report[0].budget, 3);
        assert!(report[0].iters_used <= 3);
        assert_eq!(report[1].family, "beta");
        assert_eq!(report[1].budget, 5);
        assert!(report[1].iters_used <= 5);
    }

    #[test]
    fn saturate_per_family_empty_input_returns_empty() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let report = saturate_per_family(&mut egraph, &[], |_| 10);
        assert!(report.is_empty());
    }

    #[test]
    fn saturate_per_family_reports_iters_used_le_budget() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let _ = egraph.add(Arith::Const(1));
        let _ = egraph.add(Arith::Const(2));
        let fam = ConstUnionFamily { name: "single" };
        let report = saturate_per_family(&mut egraph, &[&fam], |_| 100);
        assert_eq!(report.len(), 1);
        assert!(
            report[0].iters_used <= report[0].budget,
            "iters_used ({}) must not exceed budget ({})",
            report[0].iters_used,
            report[0].budget
        );
    }

    #[test]
    fn iter_nodes_visits_only_canonical_classes() {
        let mut egraph: EGraph<Arith> = EGraph::new();
        let a = egraph.add(Arith::Const(1));
        let b = egraph.add(Arith::Const(2));
        egraph.union(a, b);
        let _ = egraph.rebuild();
        let unique_classes: FxHashSet<EClassId> = egraph.iter_nodes().map(|(cid, _)| cid).collect();
        assert_eq!(
            unique_classes.len(),
            1,
            "post-union iter must visit exactly one canonical class id"
        );
        let total_nodes = egraph.iter_nodes().count();
        assert_eq!(
            total_nodes, 2,
            "the merged class still holds both original nodes (Const(1) + Const(2))"
        );
    }
}