use rustc_hash::FxHashMap;
use std::sync::Arc;
use super::eqsat::{EClassId, EGraph, ENodeLang};
/// One e-node of a [`GpuEGraphSnapshot`], encoded entirely as `u32` fields.
///
/// The node's children are not stored inline: they occupy
/// `children[children_offset .. children_offset + children_len]` in the
/// owning snapshot's shared `children` buffer (see `children_of`).
///
/// NOTE(review): no `#[repr(C)]` is declared here, so Rust does not guarantee
/// field order/layout. If rows are ever reinterpreted as raw bytes for a GPU
/// upload, confirm that the uploading code enforces a stable layout.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SnapshotRow {
/// E-class this node belongs to (mapped through `find_immut` when the
/// snapshot is built with `from_egraph_with`).
pub eclass_id: u32,
/// Operator id, as interned in the snapshot's [`OpIdRegistry`].
pub language_op_id: u32,
/// Index of this node's first child in the snapshot's `children` buffer.
pub children_offset: u32,
/// Number of children this node owns in the `children` buffer.
pub children_len: u32,
}
/// A request to merge two e-classes, identified by raw `u32` class ids.
///
/// Consumed by `apply_equivalences` (which forwards `left`/`right` to a
/// caller-supplied merger in that order) and by
/// `apply_equivalences_to_egraph` (which range-checks both ids and then
/// unions them).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct Equivalence {
/// First e-class id of the pair.
pub left: u32,
/// Second e-class id of the pair.
pub right: u32,
}
/// Flattened snapshot of an e-graph that uses only `u32` ids.
///
/// Each e-node becomes one [`SnapshotRow`]. All child ids are packed into the
/// single `children` buffer, and each row addresses its own slice of it by
/// offset and length. Operator names are replaced by dense ids from `op_ids`.
#[derive(Clone, Debug, Default)]
pub struct GpuEGraphSnapshot {
/// One row per e-node, in the order the nodes were added.
pub rows: Vec<SnapshotRow>,
/// Concatenated child e-class ids of every row.
pub children: Vec<u32>,
/// Interner that maps operator names to the `language_op_id`s used in `rows`.
pub op_ids: OpIdRegistry,
}
/// Interner that maps operator names to dense `u32` ids and back.
///
/// Ids are handed out in first-seen order starting at `0`, so every id is
/// also a valid index into the name table.
#[derive(Clone, Debug, Default)]
pub struct OpIdRegistry {
    /// Name -> id lookup.
    by_name: FxHashMap<Arc<str>, u32>,
    /// Id -> name table: `names[id]` is the name that was interned as `id`.
    names: Vec<Arc<str>>,
}

impl OpIdRegistry {
    /// Returns the id for `name`, interning it on first use.
    ///
    /// A name seen before gets its existing id back without allocating. A new
    /// name receives the next dense id (`0`, `1`, `2`, ...). The string is
    /// stored once and shared between both maps via `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if every `u32` id is already in use.
    pub fn intern(&mut self, name: &str) -> u32 {
        if let Some(&id) = self.by_name.get(name) {
            return id;
        }
        // Checked conversion: a wrapping `as` cast would silently reuse an id
        // that already belongs to a different name.
        let id = u32::try_from(self.names.len())
            .expect("OpIdRegistry: u32 op-id space exhausted");
        let name: Arc<str> = Arc::from(name);
        self.names.push(Arc::clone(&name));
        self.by_name.insert(name, id);
        id
    }

    /// Returns the name interned as `id`, or `None` if no such id was issued.
    #[must_use]
    pub fn name_of(&self, id: u32) -> Option<&str> {
        self.names.get(id as usize).map(AsRef::as_ref)
    }

    /// Number of distinct names interned so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.names.len()
    }

    /// Returns `true` if nothing has been interned yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.names.is_empty()
    }
}
impl GpuEGraphSnapshot {
    /// Builds a snapshot from `(eclass_id, op_name, children)` triples.
    ///
    /// Rows are emitted in iteration order. Each `op_name` is interned into
    /// `op_ids`. Each row's children are appended to the shared `children`
    /// buffer, and the row records the offset/length of its slice.
    ///
    /// # Panics
    ///
    /// Panics if the flattened `children` buffer would grow past `u32::MAX`
    /// entries, because offsets and lengths are stored as `u32`.
    #[must_use]
    pub fn build<'a, I>(rows: I) -> Self
    where
        I: IntoIterator<Item = (u32, &'a str, &'a [u32])>,
    {
        let mut snapshot = Self::default();
        let rows = rows.into_iter();
        // The lower bound is always a safe preallocation; extra rows just grow the Vec.
        let (lower_bound, _) = rows.size_hint();
        snapshot.rows.reserve(lower_bound);
        for (eclass_id, op_name, kids) in rows {
            let language_op_id = snapshot.op_ids.intern(op_name);
            snapshot.push_row(eclass_id, language_op_id, kids.iter().copied());
        }
        snapshot
    }

    /// Snapshots every node yielded by `egraph.iter_nodes()`, naming each
    /// node's operator with `op_name`.
    ///
    /// The row's e-class id and every child id are passed through
    /// `egraph.find_immut`, so the snapshot records canonical ids. The graph
    /// is only borrowed immutably. Rows follow the order of `iter_nodes()`.
    ///
    /// # Panics
    ///
    /// Panics if the flattened `children` buffer would grow past `u32::MAX`
    /// entries.
    #[must_use]
    pub fn from_egraph_with<L, F, S>(egraph: &EGraph<L>, mut op_name: F) -> Self
    where
        L: ENodeLang,
        F: FnMut(&L) -> S,
        S: AsRef<str>,
    {
        let mut snapshot = Self::default();
        // Capacity hint only; presumably there can be more nodes than classes.
        snapshot.rows.reserve(egraph.class_count());
        for (eclass_id, node) in egraph.iter_nodes() {
            let language_op_id = snapshot.op_ids.intern(op_name(node).as_ref());
            let eclass_id = egraph.find_immut(eclass_id).0;
            // Bind the children so the mapping iterator below can borrow them.
            let children = node.children();
            let canonical_children = children.iter().map(|child| egraph.find_immut(*child).0);
            snapshot.push_row(eclass_id, language_op_id, canonical_children);
        }
        snapshot
    }

    /// Appends `children` to the shared buffer and pushes the row that owns them.
    ///
    /// # Panics
    ///
    /// Panics if the buffer length would exceed `u32::MAX`. A wrapping cast
    /// would make this row, and every later row, point at the wrong children.
    fn push_row<C>(&mut self, eclass_id: u32, language_op_id: u32, children: C)
    where
        C: IntoIterator<Item = u32>,
    {
        let start = self.children.len();
        self.children.extend(children);
        let end = self.children.len();
        // Checking the end bound also guarantees that a u32 consumer can
        // compute `children_offset + children_len` without overflowing.
        let children_end = u32::try_from(end).unwrap_or_else(|_| {
            panic!(
                "GpuEGraphSnapshot: children buffer length {} exceeds u32::MAX",
                end
            )
        });
        // `start <= end` and `end` fits in u32, so this cast is lossless.
        let children_offset = start as u32;
        self.rows.push(SnapshotRow {
            eclass_id,
            language_op_id,
            children_offset,
            children_len: children_end - children_offset,
        });
    }

    /// Number of rows (e-nodes) in the snapshot.
    #[must_use]
    pub fn node_count(&self) -> usize {
        self.rows.len()
    }

    /// Returns `true` if the snapshot has no rows.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }

    /// Total number of child ids in the shared `children` buffer.
    #[must_use]
    pub fn child_count(&self) -> usize {
        self.children.len()
    }

    /// Returns the children of row `row_idx`.
    ///
    /// Returns `None` if the row does not exist, or if its recorded
    /// offset/length fall outside the `children` buffer.
    #[must_use]
    pub fn children_of(&self, row_idx: usize) -> Option<&[u32]> {
        let row = self.rows.get(row_idx)?;
        let start = row.children_offset as usize;
        let end = start.checked_add(row.children_len as usize)?;
        self.children.get(start..end)
    }

    /// Groups row indices by `eclass_id`.
    ///
    /// Each class's index list is in ascending row order.
    #[must_use]
    pub fn rows_by_eclass(&self) -> FxHashMap<u32, Vec<usize>> {
        let mut out: FxHashMap<u32, Vec<usize>> =
            FxHashMap::with_capacity_and_hasher(self.rows.len(), Default::default());
        for (i, row) in self.rows.iter().enumerate() {
            out.entry(row.eclass_id).or_default().push(i);
        }
        out
    }
}
/// Summary counters returned by `apply_equivalences_to_egraph`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ApplyEquivalencesReport {
/// Number of equivalences passed in (`equivalences.len()`).
pub requested: usize,
/// Equivalences whose `left` and `right` ids were both in range for the e-graph.
pub valid: usize,
/// Valid equivalences whose classes were distinct and were therefore unioned.
pub merged: usize,
/// Value returned by `egraph.rebuild()` after all unions were applied
/// (presumably the number of extra congruence unions; confirm against `EGraph::rebuild`).
pub rebuild_unions: usize,
}
/// Feeds every equivalence to `merger` in slice order and counts the `true` results.
///
/// `merger(left, right)` is called exactly once per element. It is expected
/// to return `true` only when it actually changed state (merged two distinct
/// classes), so the result is the number of effective merges rather than
/// the number of requests.
pub fn apply_equivalences<F>(equivalences: &[Equivalence], mut merger: F) -> usize
where
    F: FnMut(u32, u32) -> bool,
{
    equivalences
        .iter()
        .filter(|pair| merger(pair.left, pair.right))
        .count()
}
/// Unions every in-range equivalence into `egraph`, then rebuilds it.
///
/// A pair is *valid* when both ids are below `egraph.class_count()`, sampled
/// once before any union. Invalid pairs are skipped silently. A valid pair is
/// unioned only if `find` reports its two classes as distinct, so `merged`
/// counts real merges rather than requests. `egraph.rebuild()` runs exactly
/// once at the end, even when nothing was merged.
///
/// Returns an [`ApplyEquivalencesReport`] with the `requested`, `valid` and
/// `merged` counts and the value returned by `rebuild()`.
pub fn apply_equivalences_to_egraph<L>(
    egraph: &mut EGraph<L>,
    equivalences: &[Equivalence],
) -> ApplyEquivalencesReport
where
    L: ENodeLang,
{
    let mut report = ApplyEquivalencesReport {
        requested: equivalences.len(),
        ..ApplyEquivalencesReport::default()
    };
    // Compare in `usize` instead of casting `class_count` down to `u32`: the
    // cast would wrap on very large graphs and reject ids that are in range.
    // NOTE(review): this assumes every id below `class_count()` is a valid
    // `EClassId` (i.e. it is the size of the id space, not the number of live
    // canonical classes). Confirm against `EGraph::class_count`.
    let class_count = egraph.class_count();
    let in_range = |id: u32| matches!(usize::try_from(id), Ok(idx) if idx < class_count);
    for eq in equivalences {
        if !(in_range(eq.left) && in_range(eq.right)) {
            continue;
        }
        report.valid += 1;
        let left = EClassId(eq.left);
        let right = EClassId(eq.right);
        if egraph.find(left) != egraph.find(right) {
            egraph.union(left, right);
            report.merged += 1;
        }
    }
    report.rebuild_unions = egraph.rebuild();
    report
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::Hash;

    /// Minimal language (literals and binary addition) for driving the real `EGraph`.
    #[derive(Clone, Debug, Eq, Hash, PartialEq)]
    enum TinyLang {
        Lit(u32),
        Add(EClassId, EClassId),
    }

    impl ENodeLang for TinyLang {
        fn children(&self) -> super::super::eqsat::EChildren {
            match self {
                Self::Lit(_) => super::super::eqsat::EChildren::new(),
                Self::Add(left, right) => [*left, *right].into_iter().collect(),
            }
        }

        fn with_children(&self, children: &[EClassId]) -> Self {
            match self {
                Self::Lit(value) => Self::Lit(*value),
                Self::Add(_, _) => Self::Add(children[0], children[1]),
            }
        }
    }

    #[test]
    fn empty_snapshot() {
        let snap = GpuEGraphSnapshot::default();
        assert!(snap.is_empty());
        assert_eq!(snap.node_count(), 0);
        assert_eq!(snap.child_count(), 0);
        assert!(snap.op_ids.is_empty());
    }

    #[test]
    fn build_three_node_snapshot() {
        let snap = GpuEGraphSnapshot::build([
            (0u32, "lit_u32", &[][..]),
            (1u32, "lit_u32", &[][..]),
            (2u32, "binop_add", &[0u32, 1u32][..]),
        ]);
        assert_eq!(snap.node_count(), 3);
        assert_eq!(snap.child_count(), 2);
        // Repeated op names share a single interned id.
        assert_eq!(snap.op_ids.len(), 2);
        assert_eq!(snap.rows[0].language_op_id, snap.rows[1].language_op_id);
        // Leaves contribute no children, so the add node's slice starts at 0.
        assert_eq!(snap.rows[2].children_offset, 0);
        assert_eq!(snap.rows[2].children_len, 2);
        let empty: &[u32] = &[];
        assert_eq!(snap.children_of(0), Some(empty));
        assert_eq!(snap.children_of(1), Some(empty));
        assert_eq!(snap.children_of(2), Some(&[0, 1][..]));
        assert_eq!(snap.children_of(99), None);
    }

    #[test]
    fn children_of_rejects_out_of_bounds_slice() {
        // Hand-built row whose slice points past the end of `children`.
        let snap = GpuEGraphSnapshot {
            rows: vec![SnapshotRow {
                eclass_id: 0,
                language_op_id: 0,
                children_offset: 5,
                children_len: 1,
            }],
            children: Vec::new(),
            op_ids: OpIdRegistry::default(),
        };
        assert_eq!(snap.children_of(0), None);
    }

    #[test]
    fn op_id_intern_dedups() {
        let mut reg = OpIdRegistry::default();
        let a = reg.intern("foo");
        let b = reg.intern("bar");
        let c = reg.intern("foo");
        assert_eq!(a, c);
        assert_ne!(a, b);
        assert_eq!(reg.len(), 2);
        assert_eq!(reg.name_of(a), Some("foo"));
        assert_eq!(reg.name_of(b), Some("bar"));
        assert_eq!(reg.name_of(99), None);
    }

    #[test]
    fn rows_by_eclass_groups_correctly() {
        let snap = GpuEGraphSnapshot::build([
            (0u32, "lit_u32", &[][..]),
            (0u32, "var", &[][..]),
            (1u32, "binop_add", &[0u32][..]),
        ]);
        let groups = snap.rows_by_eclass();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get(&0).map(Vec::as_slice), Some(&[0usize, 1][..]));
        assert_eq!(groups.get(&1).map(Vec::as_slice), Some(&[2usize][..]));
    }

    #[test]
    fn snapshot_from_egraph_uses_canonical_children() {
        let mut egraph = EGraph::new();
        let a = egraph.add(TinyLang::Lit(1));
        let b = egraph.add(TinyLang::Lit(2));
        let add = egraph.add(TinyLang::Add(a, b));
        assert_eq!(add.0, 2);
        let snap = GpuEGraphSnapshot::from_egraph_with(&egraph, |node| match node {
            TinyLang::Lit(_) => "lit",
            TinyLang::Add(_, _) => "add",
        });
        assert_eq!(snap.node_count(), 3);
        assert_eq!(snap.child_count(), 2);
        assert_eq!(snap.op_ids.name_of(0), Some("lit"));
        assert_eq!(snap.op_ids.name_of(1), Some("add"));
        assert_eq!(snap.children_of(2), Some(&[0, 1][..]));
    }

    #[test]
    fn apply_equivalences_counts_state_changes() {
        let equivalences = [
            Equivalence { left: 0, right: 1 },
            // Same pair reversed: already merged, so it must not be counted.
            Equivalence { left: 1, right: 0 },
            Equivalence { left: 2, right: 3 },
        ];
        // Tiny one-level union-find model, sufficient for these inputs.
        let mut canonical: FxHashMap<u32, u32> = FxHashMap::default();
        let applied = apply_equivalences(&equivalences, |a, b| {
            let canon_a = *canonical.get(&a).unwrap_or(&a);
            let canon_b = *canonical.get(&b).unwrap_or(&b);
            if canon_a == canon_b {
                false
            } else {
                let (lo, hi) = if canon_a < canon_b {
                    (canon_a, canon_b)
                } else {
                    (canon_b, canon_a)
                };
                canonical.insert(hi, lo);
                canonical.insert(a, lo);
                canonical.insert(b, lo);
                true
            }
        });
        assert_eq!(applied, 2);
    }

    #[test]
    fn apply_equivalences_empty_batch() {
        let applied = apply_equivalences(&[], |_, _| true);
        assert_eq!(applied, 0);
    }

    #[test]
    fn apply_equivalences_to_egraph_merges_valid_ids() {
        let mut egraph = EGraph::new();
        let a = egraph.add(TinyLang::Lit(1));
        let b = egraph.add(TinyLang::Lit(2));
        let c = egraph.add(TinyLang::Lit(3));
        let report = apply_equivalences_to_egraph(
            &mut egraph,
            &[
                Equivalence {
                    left: a.0,
                    right: b.0,
                },
                Equivalence {
                    left: c.0,
                    right: 99,
                },
            ],
        );
        assert_eq!(
            report,
            ApplyEquivalencesReport {
                requested: 2,
                valid: 1,
                merged: 1,
                rebuild_unions: 0,
            }
        );
        assert_eq!(egraph.find(a), egraph.find(b));
        assert_ne!(egraph.find(a), egraph.find(c));
    }
}