use std::error::Error as StdError;
use std::fmt;
use std::hash::{Hash, Hasher};
use rustc_hash::FxHashMap;
use rustc_hash::FxHasher;
use smallvec::SmallVec;
pub type EChildren = SmallVec<[EClassId; 4]>;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EClassId(pub u32);
pub trait ENodeLang: Clone + Eq + Hash {
fn children(&self) -> EChildren;
#[must_use]
fn with_children(&self, children: &[EClassId]) -> Self;
}
#[derive(Debug, Clone)]
pub struct EClass<L: ENodeLang> {
pub nodes: Vec<L>,
pub parents: Vec<EClassId>,
}
#[derive(Debug, Clone)]
pub struct EGraph<L: ENodeLang> {
classes: Vec<EClass<L>>,
hashcons: FxHashMap<L, EClassId>,
parent: Vec<EClassId>,
pending: Vec<EClassId>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EGraphError {
Capacity {
context: &'static str,
requested: usize,
source: String,
},
ClassIdOverflow {
index: usize,
},
ClassIdOutOfBounds {
context: &'static str,
id: EClassId,
len: usize,
},
}
impl fmt::Display for EGraphError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Capacity {
context,
requested,
source,
} => write!(
f,
"{context} could not reserve {requested} additional slots: {source}. Fix: lower the saturation batch size or split the optimizer workload."
),
Self::ClassIdOverflow { index } => write!(
f,
"egraph class index {index} exceeds the u32 EClassId space. Fix: split the egraph or extract before adding more classes."
),
Self::ClassIdOutOfBounds { context, id, len } => write!(
f,
"{context} referenced eclass id {} but only {len} class slots exist. Fix: pass ids returned by this EGraph instance.",
id.0
),
}
}
}
impl StdError for EGraphError {}
fn log_egraph_compat_error(context: &'static str, error: &EGraphError) {
tracing::error!(
context,
error = %error,
"legacy infallible egraph API failed; use the matching try_* API to handle this condition explicitly"
);
}
impl<L: ENodeLang> Default for EGraph<L> {
fn default() -> Self {
Self::new()
}
}
impl<L: ENodeLang> EGraph<L> {
#[must_use]
pub fn new() -> Self {
Self::empty_unreserved()
}
#[must_use]
pub fn with_capacity(class_capacity: usize) -> Self {
match Self::try_with_capacity(class_capacity) {
Ok(egraph) => egraph,
Err(error) => {
log_egraph_compat_error("egraph with_capacity", &error);
Self::empty_unreserved()
}
}
}
pub fn try_with_capacity(class_capacity: usize) -> Result<Self, EGraphError> {
let mut classes = Vec::new();
reserve_vec_exact(&mut classes, class_capacity, "egraph class storage")?;
let mut hashcons = FxHashMap::default();
reserve_hashcons(&mut hashcons, class_capacity, "egraph hashcons storage")?;
let mut parent = Vec::new();
reserve_vec_exact(
&mut parent,
class_capacity,
"egraph union-find parent storage",
)?;
let mut pending = Vec::new();
reserve_vec_exact(&mut pending, class_capacity, "egraph rebuild queue storage")?;
Ok(Self {
classes,
hashcons,
parent,
pending,
})
}
fn empty_unreserved() -> Self {
Self {
classes: Vec::new(),
hashcons: FxHashMap::default(),
parent: Vec::new(),
pending: Vec::new(),
}
}
#[must_use]
pub fn class_count(&self) -> usize {
self.classes.len()
}
pub fn find(&mut self, id: EClassId) -> EClassId {
match self.try_find(id) {
Ok(found) => found,
Err(error) => {
log_egraph_compat_error("egraph find", &error);
id
}
}
}
pub fn try_find(&mut self, id: EClassId) -> Result<EClassId, EGraphError> {
let mut cur = id;
loop {
let cur_idx = eclass_index(cur, self.parent.len(), "egraph find")?;
let parent = self.parent[cur_idx];
if parent == cur {
break;
}
cur = parent;
}
let mut walk = id;
loop {
let walk_idx = eclass_index(walk, self.parent.len(), "egraph path compression")?;
let next = self.parent[walk_idx];
if next == cur {
break;
}
self.parent[walk_idx] = cur;
walk = next;
}
Ok(cur)
}
#[must_use]
pub fn find_immut(&self, id: EClassId) -> EClassId {
match self.try_find_immut(id) {
Ok(found) => found,
Err(error) => {
log_egraph_compat_error("egraph immutable find", &error);
id
}
}
}
pub fn try_find_immut(&self, id: EClassId) -> Result<EClassId, EGraphError> {
let mut cur = id;
loop {
let cur_idx = eclass_index(cur, self.parent.len(), "egraph immutable find")?;
let parent = self.parent[cur_idx];
if parent == cur {
break;
}
cur = parent;
}
Ok(cur)
}
fn canonicalize(&self, node: &L) -> L {
match self.try_canonicalize(node) {
Ok(canonical) => canonical,
Err(error) => {
log_egraph_compat_error("egraph canonicalize", &error);
node.clone()
}
}
}
fn try_canonicalize(&self, node: &L) -> Result<L, EGraphError> {
let canon_children: EChildren = node
.children()
.into_iter()
.map(|c| self.try_find_immut(c))
.collect::<Result<_, _>>()?;
Ok(node.with_children(&canon_children))
}
pub fn add(&mut self, node: L) -> EClassId {
match self.try_add(node) {
Ok(id) => id,
Err(error) => {
log_egraph_compat_error("egraph add", &error);
EClassId(0)
}
}
}
#[expect(
clippy::needless_pass_by_value,
reason = "public insertion API consumes language nodes; canonicalized misses store an owned node"
)]
pub fn try_add(&mut self, node: L) -> Result<EClassId, EGraphError> {
let canon = self.try_canonicalize(&node)?;
if let Some(&existing) = self.hashcons.get(&canon) {
return self.try_find(existing);
}
let new_id = try_eclass_id_from_index(self.classes.len())?;
let canon_children = canon.children();
reserve_vec_exact(&mut self.parent, 1, "egraph parent insertion")?;
reserve_vec_exact(&mut self.classes, 1, "egraph class insertion")?;
reserve_hashcons(&mut self.hashcons, 1, "egraph hashcons insertion")?;
let mut child_indices: SmallVec<[(usize, EClassId); 4]> = SmallVec::new();
for child in &canon_children {
let child_canon = self.try_find(*child)?;
let child_idx = eclass_index(
child_canon,
self.classes.len(),
"egraph child parent registration",
)?;
child_indices.push((child_idx, child_canon));
}
for (position, (child_idx, _)) in child_indices.iter().enumerate() {
if child_indices[..position]
.iter()
.any(|(seen_idx, _)| seen_idx == child_idx)
{
continue;
}
let occurrences = child_indices
.iter()
.filter(|(seen_idx, _)| seen_idx == child_idx)
.count();
reserve_vec_exact(
&mut self.classes[*child_idx].parents,
occurrences,
"egraph child parent registration",
)?;
}
let mut nodes = Vec::new();
reserve_vec_exact(&mut nodes, 1, "egraph singleton enode storage")?;
nodes.push(canon.clone());
self.parent.push(new_id);
for (child_idx, _) in child_indices {
self.classes[child_idx].parents.push(new_id);
}
self.classes.push(EClass {
nodes,
parents: Vec::new(),
});
self.hashcons.insert(canon, new_id);
Ok(new_id)
}
pub fn union(&mut self, a: EClassId, b: EClassId) -> EClassId {
match self.try_union(a, b) {
Ok(id) => id,
Err(error) => {
log_egraph_compat_error("egraph union", &error);
a
}
}
}
pub fn try_union(&mut self, a: EClassId, b: EClassId) -> Result<EClassId, EGraphError> {
let a_root = self.try_find(a)?;
let b_root = self.try_find(b)?;
if a_root == b_root {
return Ok(a_root);
}
let (winner, loser) = if a_root.0 < b_root.0 {
(a_root, b_root)
} else {
(b_root, a_root)
};
let winner_idx = eclass_index(winner, self.classes.len(), "egraph union winner")?;
let loser_idx = eclass_index(loser, self.classes.len(), "egraph union loser")?;
let loser_nodes_len = self.classes[loser_idx].nodes.len();
let loser_parents_len = self.classes[loser_idx].parents.len();
reserve_vec_exact(
&mut self.classes[winner_idx].nodes,
loser_nodes_len,
"egraph union node merge",
)?;
reserve_vec_exact(
&mut self.classes[winner_idx].parents,
loser_parents_len,
"egraph union parent merge",
)?;
reserve_vec_exact(&mut self.pending, 1, "egraph rebuild queue push")?;
self.parent[loser_idx] = winner;
let loser_class = std::mem::replace(
&mut self.classes[loser_idx],
EClass {
nodes: Vec::new(),
parents: Vec::new(),
},
);
self.classes[winner_idx].nodes.extend(loser_class.nodes);
self.classes[winner_idx].parents.extend(loser_class.parents);
self.pending.push(winner);
Ok(winner)
}
pub fn rebuild(&mut self) -> usize {
match self.try_rebuild() {
Ok(count) => count,
Err(error) => {
log_egraph_compat_error("egraph rebuild", &error);
0
}
}
}
pub fn try_rebuild(&mut self) -> Result<usize, EGraphError> {
let mut new_unions = 0;
while let Some(class_id) = self.pending.pop() {
let canonical = self.try_find(class_id)?;
let canonical_idx = eclass_index(canonical, self.classes.len(), "egraph rebuild")?;
let nodes_len = self.classes[canonical_idx].nodes.len();
let mut canon_nodes = Vec::new();
reserve_vec_exact(
&mut canon_nodes,
nodes_len,
"egraph rebuild canonical node staging",
)?;
reserve_hashcons(
&mut self.hashcons,
nodes_len,
"egraph rebuild hashcons staging",
)?;
let nodes = std::mem::take(&mut self.classes[canonical_idx].nodes);
for node in nodes {
let new_canon = self.try_canonicalize(&node)?;
if let Some(&existing) = self.hashcons.get(&new_canon) {
let existing_canon = self.try_find(existing)?;
if existing_canon != canonical {
let unified = self.try_union(existing_canon, canonical)?;
new_unions += 1;
if unified != canonical {
reserve_vec_exact(
&mut self.pending,
1,
"egraph rebuild winner reschedule",
)?;
self.pending.push(unified);
}
}
}
self.hashcons.insert(new_canon.clone(), canonical);
canon_nodes.push(new_canon);
}
try_dedup_enodes_by_hash(&mut canon_nodes)?;
self.classes[canonical_idx].nodes = canon_nodes;
}
Ok(new_unions)
}
pub fn iter_nodes(&self) -> impl Iterator<Item = (EClassId, &L)> {
self.classes
.iter()
.enumerate()
.filter_map(|(idx, class)| {
let class_id = match try_eclass_id_from_index(idx) {
Ok(class_id) => class_id,
Err(error) => {
log_egraph_compat_error("egraph iter_nodes class id", &error);
return None;
}
};
(self.parent[idx] == class_id).then_some((class_id, class))
})
.flat_map(|(class_id, class)| class.nodes.iter().map(move |n| (class_id, n)))
}
#[must_use]
pub fn class(&self, id: EClassId) -> Option<&EClass<L>> {
match self.try_class(id) {
Ok(class) => class,
Err(error) => {
log_egraph_compat_error("egraph class lookup", &error);
None
}
}
}
pub fn try_class(&self, id: EClassId) -> Result<Option<&EClass<L>>, EGraphError> {
let canon = self.try_find_immut(id)?;
let idx = eclass_index(canon, self.classes.len(), "egraph class lookup")?;
Ok(self.classes.get(idx))
}
}
fn eclass_id_from_index(index: usize) -> EClassId {
match try_eclass_id_from_index(index) {
Ok(id) => id,
Err(error) => {
log_egraph_compat_error("egraph class id conversion", &error);
EClassId(0)
}
}
}
fn try_eclass_id_from_index(index: usize) -> Result<EClassId, EGraphError> {
u32::try_from(index)
.map(EClassId)
.map_err(|_| EGraphError::ClassIdOverflow { index })
}
fn eclass_index(id: EClassId, len: usize, context: &'static str) -> Result<usize, EGraphError> {
let index =
usize::try_from(id.0).map_err(|_| EGraphError::ClassIdOutOfBounds { context, id, len })?;
if index < len {
Ok(index)
} else {
Err(EGraphError::ClassIdOutOfBounds { context, id, len })
}
}
fn reserve_vec_exact<T>(
vec: &mut Vec<T>,
additional: usize,
context: &'static str,
) -> Result<(), EGraphError> {
vec.try_reserve_exact(additional)
.map_err(|source| EGraphError::Capacity {
context,
requested: additional,
source: source.to_string(),
})
}
fn reserve_hashcons<L: Eq + Hash>(
hashcons: &mut FxHashMap<L, EClassId>,
additional: usize,
context: &'static str,
) -> Result<(), EGraphError> {
hashcons
.try_reserve(additional)
.map_err(|source| EGraphError::Capacity {
context,
requested: additional,
source: source.to_string(),
})
}
fn dedup_enodes_by_hash<L: ENodeLang>(nodes: &mut Vec<L>) {
if let Err(error) = try_dedup_enodes_by_hash(nodes) {
log_egraph_compat_error("egraph dedup", &error);
}
}
fn try_dedup_enodes_by_hash<L: ENodeLang>(nodes: &mut Vec<L>) -> Result<(), EGraphError> {
if nodes.len() <= 1 {
return Ok(());
}
let mut keyed = Vec::new();
reserve_vec_exact(&mut keyed, nodes.len(), "egraph dedup hash staging")?;
keyed.extend(nodes.drain(..).map(|node| (stable_enode_hash(&node), node)));
keyed.sort_unstable_by_key(|(hash, _)| *hash);
let mut deduped: Vec<(u64, L)> = Vec::new();
reserve_vec_exact(&mut deduped, keyed.len(), "egraph dedup output staging")?;
for (hash, node) in keyed {
let duplicate_in_hash_bucket = deduped
.iter()
.rev()
.take_while(|(existing_hash, _)| *existing_hash == hash)
.any(|(_, existing)| existing == &node);
if !duplicate_in_hash_bucket {
deduped.push((hash, node));
}
}
reserve_vec_exact(nodes, deduped.len(), "egraph dedup node restoration")?;
nodes.extend(deduped.into_iter().map(|(_, node)| node));
Ok(())
}
fn stable_enode_hash<L: ENodeLang>(node: &L) -> u64 {
let mut hasher = FxHasher::default();
node.hash(&mut hasher);
hasher.finish()
}
pub trait Rule<L: ENodeLang> {
fn name(&self) -> &'static str;
fn matches(&self, egraph: &EGraph<L>) -> Vec<(EClassId, EClassId)>;
}
pub trait Family<L: ENodeLang> {
fn name(&self) -> &'static str;
fn rules(&self) -> Vec<Box<dyn Rule<L>>>;
}
pub fn saturate<L: ENodeLang>(
egraph: &mut EGraph<L>,
rules: &[Box<dyn Rule<L>>],
max_iters: usize,
) -> usize {
match try_saturate(egraph, rules, max_iters) {
Ok(iters) => iters,
Err(error) => {
log_egraph_compat_error("egraph saturate", &error);
0
}
}
}
pub fn try_saturate<L: ENodeLang>(
egraph: &mut EGraph<L>,
rules: &[Box<dyn Rule<L>>],
max_iters: usize,
) -> Result<usize, EGraphError> {
let mut equivalences = Vec::new();
reserve_vec_exact(
&mut equivalences,
egraph.class_count(),
"egraph saturation equivalence staging",
)?;
for iter in 0..max_iters {
equivalences.clear();
for rule in rules {
let matches = rule.matches(egraph);
reserve_vec_exact(
&mut equivalences,
matches.len(),
"egraph saturation rule-match staging",
)?;
equivalences.extend(matches);
}
if equivalences.is_empty() {
return Ok(iter);
}
for (a, b) in equivalences.drain(..) {
egraph.try_union(a, b)?;
}
let extra = egraph.try_rebuild()?;
if extra == 0 && egraph.pending.is_empty() {
}
}
Ok(max_iters)
}
pub struct DeviceAwareRule<L: ENodeLang, F: Fn() -> bool> {
inner: Box<dyn Rule<L>>,
predicate: F,
}
impl<L: ENodeLang, F: Fn() -> bool> DeviceAwareRule<L, F> {
pub fn new(inner: Box<dyn Rule<L>>, predicate: F) -> Self {
Self { inner, predicate }
}
}
impl<L: ENodeLang, F: Fn() -> bool> Rule<L> for DeviceAwareRule<L, F> {
fn name(&self) -> &'static str {
self.inner.name()
}
fn matches(&self, egraph: &EGraph<L>) -> Vec<(EClassId, EClassId)> {
if (self.predicate)() {
self.inner.matches(egraph)
} else {
Vec::new()
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FamilySaturationReport {
pub family: &'static str,
pub iters_used: usize,
pub budget: usize,
}
pub fn saturate_per_family<L: ENodeLang>(
egraph: &mut EGraph<L>,
families: &[&dyn Family<L>],
budget_for: impl Fn(&str) -> usize,
) -> Vec<FamilySaturationReport> {
match try_saturate_per_family(egraph, families, budget_for) {
Ok(report) => report,
Err(error) => {
log_egraph_compat_error("egraph saturate_per_family", &error);
Vec::new()
}
}
}
pub fn try_saturate_per_family<L: ENodeLang>(
egraph: &mut EGraph<L>,
families: &[&dyn Family<L>],
budget_for: impl Fn(&str) -> usize,
) -> Result<Vec<FamilySaturationReport>, EGraphError> {
let mut out = Vec::new();
reserve_vec_exact(
&mut out,
families.len(),
"egraph family saturation report staging",
)?;
for family in families {
let name = family.name();
let budget = budget_for(name);
if budget == 0 {
out.push(FamilySaturationReport {
family: name,
iters_used: 0,
budget: 0,
});
continue;
}
let rules = family.rules();
let iters_used = try_saturate(egraph, &rules, budget)?;
out.push(FamilySaturationReport {
family: name,
iters_used,
budget,
});
}
Ok(out)
}
pub fn extract_best<L: ENodeLang>(
egraph: &EGraph<L>,
class_id: EClassId,
cost_fn: impl Fn(&L) -> u64,
) -> Option<(L, u64)> {
match try_extract_best(egraph, class_id, cost_fn) {
Ok(best) => best,
Err(error) => {
log_egraph_compat_error("egraph extract_best", &error);
None
}
}
}
pub fn try_extract_best<L: ENodeLang>(
egraph: &EGraph<L>,
class_id: EClassId,
cost_fn: impl Fn(&L) -> u64,
) -> Result<Option<(L, u64)>, EGraphError> {
let class_count = egraph.class_count();
let mut costs: Vec<Option<(L, u64)>> = Vec::new();
reserve_vec_exact(&mut costs, class_count, "egraph extraction cost table")?;
costs.resize_with(class_count, || None);
let mut changed = true;
let mut iters = 0;
while changed && iters < 1024 {
changed = false;
iters += 1;
for (cid, node) in egraph.iter_nodes() {
let canon_cid_idx = eclass_index(cid, class_count, "egraph extraction class")?;
let mut node_cost = cost_fn(node);
let mut child_overflow = false;
for child in node.children() {
let canon_child = egraph.try_find_immut(child)?;
let canon_child_idx =
eclass_index(canon_child, class_count, "egraph extraction child class")?;
if let Some((_, c)) = costs.get(canon_child_idx).and_then(Option::as_ref) {
node_cost = node_cost.saturating_add(*c);
} else {
child_overflow = true;
break;
}
}
if child_overflow {
continue;
}
let Some(slot) = costs.get_mut(canon_cid_idx) else {
continue;
};
match slot {
Some((_, existing_cost)) if *existing_cost <= node_cost => {}
_ => {
*slot = Some((node.clone(), node_cost));
changed = true;
}
}
}
}
let canon = egraph.try_find_immut(class_id)?;
let canon_idx = eclass_index(canon, class_count, "egraph extraction root class")?;
Ok(costs.get(canon_idx).and_then(Clone::clone))
}
#[cfg(test)]
mod tests {
use super::*;
use rustc_hash::FxHashSet;
use smallvec::smallvec;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Arith {
Const(u32),
Add(EClassId, EClassId),
Mul(EClassId, EClassId),
}
impl ENodeLang for Arith {
fn children(&self) -> EChildren {
match self {
Self::Const(_) => EChildren::new(),
Self::Add(a, b) | Self::Mul(a, b) => smallvec![*a, *b],
}
}
fn with_children(&self, children: &[EClassId]) -> Self {
match self {
Self::Const(n) => Self::Const(*n),
Self::Add(_, _) => Self::Add(children[0], children[1]),
Self::Mul(_, _) => Self::Mul(children[0], children[1]),
}
}
}
fn arith_cost(node: &Arith) -> u64 {
match node {
Arith::Const(_) => 1,
Arith::Add(_, _) => 2,
Arith::Mul(_, _) => 3,
}
}
#[test]
fn empty_egraph_has_zero_classes() {
let egraph: EGraph<Arith> = EGraph::new();
assert_eq!(egraph.class_count(), 0);
}
#[test]
fn add_const_creates_one_class() {
let mut egraph: EGraph<Arith> = EGraph::new();
let _ = egraph.add(Arith::Const(7));
assert_eq!(egraph.class_count(), 1);
}
#[test]
fn add_same_const_twice_returns_same_class() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(7));
let b = egraph.add(Arith::Const(7));
assert_eq!(a, b);
assert_eq!(egraph.class_count(), 1);
}
#[test]
fn add_distinct_consts_creates_distinct_classes() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(7));
let b = egraph.add(Arith::Const(8));
assert_ne!(a, b);
assert_eq!(egraph.class_count(), 2);
}
#[test]
fn add_compound_node_creates_proper_class() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(1));
let b = egraph.add(Arith::Const(2));
let sum = egraph.add(Arith::Add(a, b));
assert_eq!(egraph.class_count(), 3);
assert_ne!(sum, a);
assert_ne!(sum, b);
}
#[test]
fn union_merges_two_classes() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(1));
let b = egraph.add(Arith::Const(2));
let unified = egraph.union(a, b);
assert_eq!(egraph.find(a), unified);
assert_eq!(egraph.find(b), unified);
}
#[test]
fn union_is_idempotent() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(1));
let b = egraph.add(Arith::Const(2));
let first = egraph.union(a, b);
let second = egraph.union(a, b);
assert_eq!(first, second);
}
#[test]
fn rebuild_canonicalizes_compound_nodes_after_union() {
let mut egraph: EGraph<Arith> = EGraph::new();
let one = egraph.add(Arith::Const(1));
let two = egraph.add(Arith::Const(2));
let _add_12 = egraph.add(Arith::Add(one, two));
let _add_22 = egraph.add(Arith::Add(two, two));
egraph.union(one, two);
let _ = egraph.rebuild();
let post_one = egraph.find(one);
let post_two = egraph.find(two);
assert_eq!(post_one, post_two, "1 and 2 must be in the same class");
}
#[test]
fn extract_best_picks_cheapest_equivalent() {
let mut egraph: EGraph<Arith> = EGraph::new();
let one = egraph.add(Arith::Const(1));
let two = egraph.add(Arith::Const(2));
let three = egraph.add(Arith::Const(3));
let add_12 = egraph.add(Arith::Add(one, two));
egraph.union(add_12, three);
let _ = egraph.rebuild();
let (best, cost) = extract_best(&egraph, add_12, arith_cost).expect("Fix: must extract");
assert_eq!(best, Arith::Const(3));
assert_eq!(cost, 1);
}
#[test]
fn extract_best_returns_only_node_when_no_alternatives() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(42));
let (best, cost) = extract_best(&egraph, a, arith_cost).expect("Fix: must extract");
assert_eq!(best, Arith::Const(42));
assert_eq!(cost, 1);
}
struct UnionEqualConstsRule;
impl Rule<Arith> for UnionEqualConstsRule {
fn name(&self) -> &'static str {
"union_equal_consts"
}
fn matches(&self, egraph: &EGraph<Arith>) -> Vec<(EClassId, EClassId)> {
let mut by_value: FxHashMap<u32, Vec<EClassId>> = FxHashMap::default();
for (cid, node) in egraph.iter_nodes() {
if let Arith::Const(v) = node {
by_value.entry(*v).or_default().push(cid);
}
}
let mut out = Vec::new();
for ids in by_value.values() {
for window in ids.windows(2) {
out.push((window[0], window[1]));
}
}
out
}
}
#[test]
fn saturate_runs_to_fixed_point() {
let mut egraph: EGraph<Arith> = EGraph::new();
let _a = egraph.add(Arith::Const(7));
let _b = egraph.add(Arith::Const(8));
let rules: Vec<Box<dyn Rule<Arith>>> = vec![Box::new(UnionEqualConstsRule)];
let iters = saturate(&mut egraph, &rules, 10);
assert!(iters <= 10);
assert!(iters <= 1);
}
#[test]
fn find_immut_returns_canonical_after_union() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(1));
let b = egraph.add(Arith::Const(2));
egraph.union(a, b);
let canon_a = egraph.find_immut(a);
let canon_b = egraph.find_immut(b);
assert_eq!(canon_a, canon_b);
}
#[test]
fn class_lookup_returns_canonical_class() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(7));
let class = egraph.class(a).expect("Fix: class must exist");
assert!(matches!(class.nodes[0], Arith::Const(7)));
}
#[test]
fn rebuild_propagates_through_parents() {
let mut egraph: EGraph<Arith> = EGraph::new();
let one = egraph.add(Arith::Const(1));
let two = egraph.add(Arith::Const(2));
let add_12 = egraph.add(Arith::Add(one, two));
egraph.union(one, two);
let _ = egraph.rebuild();
let class = egraph.class(add_12).expect("Fix: class must still exist");
match &class.nodes[0] {
Arith::Add(a, b) => {
let canon_a = egraph.find_immut(*a);
let canon_b = egraph.find_immut(*b);
assert_eq!(
canon_a, canon_b,
"Add(1,2)'s children must canonicalize to the same class after union"
);
}
other => panic!("expected Add; got {other:?}"),
}
}
struct PairConstSelfRule;
impl Rule<Arith> for PairConstSelfRule {
fn name(&self) -> &'static str {
"pair_const_self"
}
fn matches(&self, egraph: &EGraph<Arith>) -> Vec<(EClassId, EClassId)> {
let mut out = Vec::new();
for (cid, node) in egraph.iter_nodes() {
if let Arith::Const(_) = node {
out.push((cid, cid));
}
}
out
}
}
#[test]
fn device_aware_rule_predicate_true_forwards_matches() {
let egraph: EGraph<Arith> = EGraph::new();
let inner: Box<dyn Rule<Arith>> = Box::new(PairConstSelfRule);
let rule = DeviceAwareRule::new(inner, || true);
assert!(
rule.matches(&egraph).is_empty(),
"empty egraph must yield empty matches even with predicate true"
);
let mut egraph: EGraph<Arith> = EGraph::new();
let _a = egraph.add(Arith::Const(7));
let inner: Box<dyn Rule<Arith>> = Box::new(PairConstSelfRule);
let rule = DeviceAwareRule::new(inner, || true);
assert!(
!rule.matches(&egraph).is_empty(),
"predicate true must forward the inner rule's matches"
);
}
#[test]
fn device_aware_rule_predicate_false_returns_empty() {
let mut egraph: EGraph<Arith> = EGraph::new();
let _ = egraph.add(Arith::Const(7));
let _ = egraph.add(Arith::Const(7)); let inner: Box<dyn Rule<Arith>> = Box::new(UnionEqualConstsRule);
let rule = DeviceAwareRule::new(inner, || false);
let matches = rule.matches(&egraph);
assert!(
matches.is_empty(),
"predicate false must short-circuit to empty"
);
}
#[test]
fn device_aware_rule_forwards_inner_name() {
let inner: Box<dyn Rule<Arith>> = Box::new(UnionEqualConstsRule);
let rule = DeviceAwareRule::new(inner, || true);
assert_eq!(rule.name(), "union_equal_consts");
}
struct ConstUnionFamily {
name: &'static str,
}
impl Family<Arith> for ConstUnionFamily {
fn name(&self) -> &'static str {
self.name
}
fn rules(&self) -> Vec<Box<dyn Rule<Arith>>> {
vec![Box::new(UnionEqualConstsRule)]
}
}
#[test]
fn saturate_per_family_skips_zero_budget() {
let mut egraph: EGraph<Arith> = EGraph::new();
let _ = egraph.add(Arith::Const(7));
let fam = ConstUnionFamily { name: "f0" };
let report = saturate_per_family(&mut egraph, &[&fam], |_| 0);
assert_eq!(report.len(), 1);
assert_eq!(report[0].family, "f0");
assert_eq!(report[0].iters_used, 0);
assert_eq!(report[0].budget, 0);
}
#[test]
fn saturate_per_family_runs_each_family_independently() {
let mut egraph: EGraph<Arith> = EGraph::new();
let _ = egraph.add(Arith::Const(1));
let _ = egraph.add(Arith::Const(2));
let fam_a = ConstUnionFamily { name: "alpha" };
let fam_b = ConstUnionFamily { name: "beta" };
let report = saturate_per_family(&mut egraph, &[&fam_a, &fam_b], |name| match name {
"alpha" => 3,
"beta" => 5,
_ => 0,
});
assert_eq!(report.len(), 2);
assert_eq!(report[0].family, "alpha");
assert_eq!(report[0].budget, 3);
assert!(report[0].iters_used <= 3);
assert_eq!(report[1].family, "beta");
assert_eq!(report[1].budget, 5);
assert!(report[1].iters_used <= 5);
}
#[test]
fn saturate_per_family_empty_input_returns_empty() {
let mut egraph: EGraph<Arith> = EGraph::new();
let report = saturate_per_family(&mut egraph, &[], |_| 10);
assert!(report.is_empty());
}
#[test]
fn saturate_per_family_reports_iters_used_le_budget() {
let mut egraph: EGraph<Arith> = EGraph::new();
let _ = egraph.add(Arith::Const(1));
let _ = egraph.add(Arith::Const(2));
let fam = ConstUnionFamily { name: "single" };
let report = saturate_per_family(&mut egraph, &[&fam], |_| 100);
assert_eq!(report.len(), 1);
assert!(
report[0].iters_used <= report[0].budget,
"iters_used ({}) must not exceed budget ({})",
report[0].iters_used,
report[0].budget
);
}
#[test]
fn iter_nodes_visits_only_canonical_classes() {
let mut egraph: EGraph<Arith> = EGraph::new();
let a = egraph.add(Arith::Const(1));
let b = egraph.add(Arith::Const(2));
egraph.union(a, b);
let _ = egraph.rebuild();
let unique_classes: FxHashSet<EClassId> = egraph.iter_nodes().map(|(cid, _)| cid).collect();
assert_eq!(
unique_classes.len(),
1,
"post-union iter must visit exactly one canonical class id"
);
let total_nodes = egraph.iter_nodes().count();
assert_eq!(
total_nodes, 2,
"the merged class still holds both original nodes (Const(1) + Const(2))"
);
}
#[test]
fn fallible_find_reports_foreign_class_id() {
let mut egraph: EGraph<Arith> = EGraph::new();
let err = egraph
.try_find(EClassId(0))
.expect_err("empty graph must reject foreign class id 0");
assert!(
matches!(
err,
EGraphError::ClassIdOutOfBounds {
context: "egraph find",
id: EClassId(0),
len: 0
}
),
"unexpected error: {err}"
);
}
#[test]
fn fallible_add_rejects_foreign_child_ids() {
let mut egraph: EGraph<Arith> = EGraph::new();
let err = egraph
.try_add(Arith::Add(EClassId(0), EClassId(0)))
.expect_err("foreign children must be rejected before insertion");
assert!(
matches!(
err,
EGraphError::ClassIdOutOfBounds {
context: "egraph immutable find",
id: EClassId(0),
len: 0
}
),
"unexpected error: {err}"
);
assert_eq!(
egraph.class_count(),
0,
"failed fallible insertion must not allocate a partial class"
);
}
#[test]
fn fallible_add_handles_duplicate_children_without_late_allocation_path() {
let mut egraph: EGraph<Arith> =
EGraph::try_with_capacity(2).expect("Fix: unit-test oracle precondition - small egraph reservation must succeed");
let one = egraph
.try_add(Arith::Const(1))
.expect("Fix: unit-test oracle precondition - const insert must succeed");
let add = egraph
.try_add(Arith::Add(one, one))
.expect("Fix: unit-test oracle precondition - duplicate child registration must be pre-reserved");
let class = egraph
.try_class(add)
.expect("Fix: unit-test oracle precondition - class lookup must be valid")
.expect("Fix: unit-test oracle precondition - class must exist");
assert!(matches!(class.nodes[0], Arith::Add(_, _)));
}
#[test]
fn try_class_id_from_index_rejects_overflow() {
if usize::BITS <= u32::BITS {
return;
}
let overflow_index = (u32::MAX as usize) + 1;
let err = try_eclass_id_from_index(overflow_index)
.expect_err("overflowing class index must be rejected");
assert_eq!(
err,
EGraphError::ClassIdOverflow {
index: overflow_index
}
);
}
#[test]
fn fallible_saturate_and_extract_match_infallible_contracts() {
let mut egraph: EGraph<Arith> =
EGraph::try_with_capacity(4).expect("Fix: unit-test oracle precondition - small egraph reservation must succeed");
let one = egraph.try_add(Arith::Const(1)).expect("Fix: unit-test oracle precondition - insert one");
let two = egraph.try_add(Arith::Const(2)).expect("Fix: unit-test oracle precondition - insert two");
let three = egraph.try_add(Arith::Const(3)).expect("Fix: unit-test oracle precondition - insert three");
let add_12 = egraph.try_add(Arith::Add(one, two)).expect("Fix: unit-test oracle precondition - insert add");
egraph
.try_union(add_12, three)
.expect("Fix: unit-test oracle precondition - union equivalent nodes");
egraph.try_rebuild().expect("Fix: unit-test oracle precondition - rebuild equivalent nodes");
let rules: Vec<Box<dyn Rule<Arith>>> = vec![Box::new(UnionEqualConstsRule)];
let iters = try_saturate(&mut egraph, &rules, 10).expect("Fix: unit-test oracle precondition - fallible saturation");
assert!(iters <= 10);
let (best, cost) = try_extract_best(&egraph, add_12, arith_cost)
.expect("Fix: unit-test oracle precondition - fallible extraction")
.expect("Fix: unit-test oracle precondition - best node must exist");
assert_eq!(best, Arith::Const(3));
assert_eq!(cost, 1);
}
#[test]
fn eqsat_production_uses_fallible_staging_and_checked_ids() {
let src = std::fs::read_to_string(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/optimizer/eqsat.rs"
))
.expect("Fix: eqsat source must be readable");
let production = src
.split("#[cfg(test)]")
.next()
.expect("Fix: unit-test oracle precondition - production section must exist");
assert!(
!production.contains("unwrap_or(u32::MAX)"),
"class ids must never saturate to a poisoned sentinel"
);
assert!(
!production.contains("Vec::with_capacity("),
"optimizer staging must use fallible reservations"
);
assert!(
!production.contains("with_capacity_and_hasher("),
"hashcons staging must use fallible reservations"
);
assert!(
!production.contains(" as usize"),
"EClassId indexing must go through checked conversion helpers"
);
assert!(production.contains("try_with_capacity"));
assert!(production.contains("try_saturate"));
assert!(production.contains("try_extract_best"));
assert!(
!production.contains(".expect("),
"legacy egraph compatibility APIs must not panic in production; callers that need hard errors use try_*"
);
}
}