use rustc_hash::{FxHashMap, FxHashSet};
use std::fmt;
use std::sync::Arc;
use super::eqsat::{EClassId, EGraph, ENodeLang};
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SnapshotRow {
pub eclass_id: u32,
pub language_op_id: u32,
pub children_offset: u32,
pub children_len: u32,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct Equivalence {
pub left: u32,
pub right: u32,
}
#[derive(Clone, Debug, Default)]
pub struct GpuEGraphSnapshot {
pub rows: Vec<SnapshotRow>,
pub children: Vec<u32>,
pub op_ids: OpIdRegistry,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GpuEGraphSnapshotError {
context: &'static str,
value: usize,
}
impl GpuEGraphSnapshotError {
fn new(context: &'static str, value: usize) -> Self {
Self { context, value }
}
#[must_use]
pub const fn context(&self) -> &'static str {
self.context
}
#[must_use]
pub const fn value(&self) -> usize {
self.value
}
}
impl fmt::Display for GpuEGraphSnapshotError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"GPU e-graph snapshot {} value {} exceeds the u32 column ABI. Fix: shard the e-graph snapshot or widen the GPU snapshot ABI before upload.",
self.context, self.value
)
}
}
impl std::error::Error for GpuEGraphSnapshotError {}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GpuEGraphSnapshotIntegrityError {
context: &'static str,
row: usize,
value: u32,
}
impl GpuEGraphSnapshotIntegrityError {
fn new(context: &'static str, row: usize, value: u32) -> Self {
Self {
context,
row,
value,
}
}
#[must_use]
pub const fn context(&self) -> &'static str {
self.context
}
#[must_use]
pub const fn row(&self) -> usize {
self.row
}
#[must_use]
pub const fn value(&self) -> u32 {
self.value
}
}
impl fmt::Display for GpuEGraphSnapshotIntegrityError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"GPU e-graph snapshot integrity error at row {}: {} value {} is invalid. Fix: rebuild the snapshot from canonical e-graph rows before upload.",
self.row, self.context, self.value
)
}
}
impl std::error::Error for GpuEGraphSnapshotIntegrityError {}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum GpuEGraphDeviceImageError {
Integrity(GpuEGraphSnapshotIntegrityError),
Layout(GpuEGraphSnapshotError),
}
impl fmt::Display for GpuEGraphDeviceImageError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Integrity(error) => error.fmt(f),
Self::Layout(error) => error.fmt(f),
}
}
}
impl std::error::Error for GpuEGraphDeviceImageError {}
impl From<GpuEGraphSnapshotIntegrityError> for GpuEGraphDeviceImageError {
fn from(error: GpuEGraphSnapshotIntegrityError) -> Self {
Self::Integrity(error)
}
}
impl From<GpuEGraphSnapshotError> for GpuEGraphDeviceImageError {
fn from(error: GpuEGraphSnapshotError) -> Self {
Self::Layout(error)
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct GpuEGraphDeviceSpan {
offset: usize,
len: usize,
}
impl GpuEGraphDeviceSpan {
const fn new(offset: usize, len: usize) -> Self {
Self { offset, len }
}
#[must_use]
pub const fn offset(&self) -> usize {
self.offset
}
#[must_use]
pub const fn len(&self) -> usize {
self.len
}
#[must_use]
pub const fn is_empty(&self) -> bool {
self.len == 0
}
fn slice<'a>(&self, words: &'a [u32]) -> &'a [u32] {
&words[self.offset..self.offset + self.len]
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct GpuEGraphDeviceLayout {
row_count: usize,
child_count: usize,
eclass_group_count: usize,
row_eclass_ids: GpuEGraphDeviceSpan,
row_language_op_ids: GpuEGraphDeviceSpan,
row_children_offsets: GpuEGraphDeviceSpan,
row_children_lens: GpuEGraphDeviceSpan,
row_signatures: GpuEGraphDeviceSpan,
children: GpuEGraphDeviceSpan,
group_eclass_ids: GpuEGraphDeviceSpan,
group_offsets: GpuEGraphDeviceSpan,
group_rows: GpuEGraphDeviceSpan,
}
impl GpuEGraphDeviceLayout {
#[must_use]
pub const fn row_count(&self) -> usize {
self.row_count
}
#[must_use]
pub const fn child_count(&self) -> usize {
self.child_count
}
#[must_use]
pub const fn eclass_group_count(&self) -> usize {
self.eclass_group_count
}
#[must_use]
pub const fn row_eclass_ids(&self) -> GpuEGraphDeviceSpan {
self.row_eclass_ids
}
#[must_use]
pub const fn row_language_op_ids(&self) -> GpuEGraphDeviceSpan {
self.row_language_op_ids
}
#[must_use]
pub const fn row_children_offsets(&self) -> GpuEGraphDeviceSpan {
self.row_children_offsets
}
#[must_use]
pub const fn row_children_lens(&self) -> GpuEGraphDeviceSpan {
self.row_children_lens
}
#[must_use]
pub const fn row_signatures(&self) -> GpuEGraphDeviceSpan {
self.row_signatures
}
#[must_use]
pub const fn children(&self) -> GpuEGraphDeviceSpan {
self.children
}
#[must_use]
pub const fn group_eclass_ids(&self) -> GpuEGraphDeviceSpan {
self.group_eclass_ids
}
#[must_use]
pub const fn group_offsets(&self) -> GpuEGraphDeviceSpan {
self.group_offsets
}
#[must_use]
pub const fn group_rows(&self) -> GpuEGraphDeviceSpan {
self.group_rows
}
}
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct GpuEGraphDeviceImage {
words: Vec<u32>,
layout: GpuEGraphDeviceLayout,
}
impl GpuEGraphDeviceImage {
#[must_use]
pub fn words(&self) -> &[u32] {
&self.words
}
#[must_use]
pub const fn layout(&self) -> GpuEGraphDeviceLayout {
self.layout
}
#[must_use]
pub fn row_eclass_ids(&self) -> &[u32] {
self.layout.row_eclass_ids.slice(&self.words)
}
#[must_use]
pub fn row_language_op_ids(&self) -> &[u32] {
self.layout.row_language_op_ids.slice(&self.words)
}
#[must_use]
pub fn row_children_offsets(&self) -> &[u32] {
self.layout.row_children_offsets.slice(&self.words)
}
#[must_use]
pub fn row_children_lens(&self) -> &[u32] {
self.layout.row_children_lens.slice(&self.words)
}
#[must_use]
pub fn row_signatures(&self) -> &[u32] {
self.layout.row_signatures.slice(&self.words)
}
#[must_use]
pub fn children(&self) -> &[u32] {
self.layout.children.slice(&self.words)
}
#[must_use]
pub fn group_eclass_ids(&self) -> &[u32] {
self.layout.group_eclass_ids.slice(&self.words)
}
#[must_use]
pub fn group_offsets(&self) -> &[u32] {
self.layout.group_offsets.slice(&self.words)
}
#[must_use]
pub fn group_rows(&self) -> &[u32] {
self.layout.group_rows.slice(&self.words)
}
}
#[derive(Clone, Debug, Default)]
pub struct OpIdRegistry {
by_name: FxHashMap<Arc<str>, u32>,
names: Vec<Arc<str>>,
}
impl OpIdRegistry {
pub fn intern(&mut self, name: &str) -> u32 {
self.try_intern(name)
.unwrap_or_else(|error| panic!("{error}"))
}
pub fn try_intern(&mut self, name: &str) -> Result<u32, GpuEGraphSnapshotError> {
if let Some(&id) = self.by_name.get(name) {
return Ok(id);
}
let id = u32_len(self.names.len(), "op-id registry")?;
let name: Arc<str> = Arc::from(name);
self.names.push(Arc::clone(&name));
self.by_name.insert(name, id);
Ok(id)
}
#[must_use]
pub fn name_of(&self, id: u32) -> Option<&str> {
self.names.get(id as usize).map(AsRef::as_ref)
}
#[must_use]
pub fn len(&self) -> usize {
self.names.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.names.is_empty()
}
}
impl GpuEGraphSnapshot {
#[must_use]
pub fn build<'a, I>(rows: I) -> Self
where
I: IntoIterator<Item = (u32, &'a str, &'a [u32])>,
{
Self::try_build(rows).unwrap_or_else(|error| panic!("{error}"))
}
pub fn try_build<'a, I>(rows: I) -> Result<Self, GpuEGraphSnapshotError>
where
I: IntoIterator<Item = (u32, &'a str, &'a [u32])>,
{
let mut snapshot = Self::default();
let rows = rows.into_iter();
let (lower_bound, _) = rows.size_hint();
snapshot.rows.reserve(lower_bound);
for (eclass_id, op_name, kids) in rows {
let language_op_id = snapshot.op_ids.try_intern(op_name)?;
let children_offset = u32_len(snapshot.children.len(), "GPU egraph children offset")?;
let children_len = u32_len(kids.len(), "GPU egraph row child count")?;
snapshot.children.extend_from_slice(kids);
snapshot.rows.push(SnapshotRow {
eclass_id,
language_op_id,
children_offset,
children_len,
});
}
Ok(snapshot)
}
#[must_use]
pub fn from_egraph_with<L, F, S>(egraph: &EGraph<L>, mut op_name: F) -> Self
where
L: ENodeLang,
F: FnMut(&L) -> S,
S: AsRef<str>,
{
Self::try_from_egraph_with(egraph, &mut op_name).unwrap_or_else(|error| panic!("{error}"))
}
pub fn try_from_egraph_with<L, F, S>(
egraph: &EGraph<L>,
mut op_name: F,
) -> Result<Self, GpuEGraphSnapshotError>
where
L: ENodeLang,
F: FnMut(&L) -> S,
S: AsRef<str>,
{
let mut snapshot = Self::default();
snapshot.rows.reserve(egraph.class_count());
for (eclass_id, node) in egraph.iter_nodes() {
let language_op_id = snapshot.op_ids.try_intern(op_name(node).as_ref())?;
let children = node.children();
let children_offset = u32_len(snapshot.children.len(), "GPU egraph children offset")?;
let children_len = u32_len(children.len(), "GPU egraph row child count")?;
snapshot
.children
.extend(children.iter().map(|child| egraph.find_immut(*child).0));
snapshot.rows.push(SnapshotRow {
eclass_id: egraph.find_immut(eclass_id).0,
language_op_id,
children_offset,
children_len,
});
}
Ok(snapshot)
}
#[must_use]
pub fn node_count(&self) -> usize {
self.rows.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.rows.is_empty()
}
#[must_use]
pub fn child_count(&self) -> usize {
self.children.len()
}
#[must_use]
pub fn children_of(&self, row_idx: usize) -> Option<&[u32]> {
let row = self.rows.get(row_idx)?;
let start = row.children_offset as usize;
let end = start.checked_add(row.children_len as usize)?;
self.children.get(start..end)
}
#[must_use]
pub fn rows_by_eclass(&self) -> FxHashMap<u32, Vec<usize>> {
let mut out: FxHashMap<u32, Vec<usize>> =
FxHashMap::with_capacity_and_hasher(self.rows.len(), Default::default());
for (i, row) in self.rows.iter().enumerate() {
out.entry(row.eclass_id).or_default().push(i);
}
out
}
pub fn validate_integrity(&self) -> Result<(), GpuEGraphSnapshotIntegrityError> {
let mut eclasses: FxHashSet<u32> =
FxHashSet::with_capacity_and_hasher(self.rows.len(), Default::default());
for row in &self.rows {
eclasses.insert(row.eclass_id);
}
for (row_idx, row) in self.rows.iter().enumerate() {
if self.op_ids.name_of(row.language_op_id).is_none() {
return Err(GpuEGraphSnapshotIntegrityError::new(
"unknown language_op_id",
row_idx,
row.language_op_id,
));
}
let start = row.children_offset as usize;
let end = start
.checked_add(row.children_len as usize)
.ok_or_else(|| {
GpuEGraphSnapshotIntegrityError::new(
"children range overflow",
row_idx,
row.children_len,
)
})?;
if end > self.children.len() {
return Err(GpuEGraphSnapshotIntegrityError::new(
"children range end",
row_idx,
row.children_len,
));
}
for &child in &self.children[start..end] {
if !eclasses.contains(&child) {
return Err(GpuEGraphSnapshotIntegrityError::new(
"dangling child eclass",
row_idx,
child,
));
}
}
}
Ok(())
}
pub fn try_pack_device_image(&self) -> Result<GpuEGraphDeviceImage, GpuEGraphDeviceImageError> {
self.validate_integrity()?;
let mut groups: FxHashMap<u32, Vec<u32>> =
FxHashMap::with_capacity_and_hasher(self.rows.len(), Default::default());
for (row_idx, row) in self.rows.iter().enumerate() {
groups
.entry(row.eclass_id)
.or_default()
.push(u32_len(row_idx, "GPU egraph grouped row index")?);
}
let mut group_eclass_ids = groups.keys().copied().collect::<Vec<_>>();
group_eclass_ids.sort_unstable();
let mut group_offsets = Vec::with_capacity(group_eclass_ids.len() + 1);
let mut group_rows = Vec::with_capacity(self.rows.len());
for eclass_id in &group_eclass_ids {
group_offsets.push(u32_len(group_rows.len(), "GPU egraph group row offset")?);
let Some(rows) = groups.get(eclass_id) else {
return Err(GpuEGraphSnapshotIntegrityError::new(
"missing grouped eclass key",
0,
*eclass_id,
)
.into());
};
group_rows.extend_from_slice(rows);
}
group_offsets.push(u32_len(
group_rows.len(),
"GPU egraph group row terminal offset",
)?);
let row_signatures = self
.rows
.iter()
.map(|row| {
let start = row.children_offset as usize;
let end = start + row.children_len as usize;
egraph_row_signature(row, &self.children[start..end])
})
.collect::<Vec<_>>();
let mut words = Vec::with_capacity(
self.rows.len() * 5
+ self.children.len()
+ group_eclass_ids.len()
+ group_offsets.len()
+ group_rows.len(),
);
let row_eclass_ids = append_words(&mut words, self.rows.iter().map(|row| row.eclass_id));
let row_language_op_ids =
append_words(&mut words, self.rows.iter().map(|row| row.language_op_id));
let row_children_offsets =
append_words(&mut words, self.rows.iter().map(|row| row.children_offset));
let row_children_lens =
append_words(&mut words, self.rows.iter().map(|row| row.children_len));
let row_signatures = append_words(&mut words, row_signatures);
let children = append_words(&mut words, self.children.iter().copied());
let group_eclass_ids_span = append_words(&mut words, group_eclass_ids);
let group_offsets = append_words(&mut words, group_offsets);
let group_rows = append_words(&mut words, group_rows);
Ok(GpuEGraphDeviceImage {
words,
layout: GpuEGraphDeviceLayout {
row_count: self.rows.len(),
child_count: self.children.len(),
eclass_group_count: groups.len(),
row_eclass_ids,
row_language_op_ids,
row_children_offsets,
row_children_lens,
row_signatures,
children,
group_eclass_ids: group_eclass_ids_span,
group_offsets,
group_rows,
},
})
}
#[must_use]
pub fn pack_device_image(&self) -> GpuEGraphDeviceImage {
self.try_pack_device_image()
.unwrap_or_else(|error| panic!("{error}"))
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ApplyEquivalencesReport {
pub requested: usize,
pub valid: usize,
pub merged: usize,
pub rebuild_unions: usize,
}
pub fn apply_equivalences<F>(equivalences: &[Equivalence], mut merger: F) -> usize
where
F: FnMut(u32, u32) -> bool,
{
let mut applied = 0usize;
for eq in equivalences {
if merger(eq.left, eq.right) {
applied += 1;
}
}
applied
}
pub fn apply_equivalences_to_egraph<L>(
egraph: &mut EGraph<L>,
equivalences: &[Equivalence],
) -> ApplyEquivalencesReport
where
L: ENodeLang,
{
let mut report = ApplyEquivalencesReport {
requested: equivalences.len(),
..ApplyEquivalencesReport::default()
};
let Ok(class_count) = u32_len(egraph.class_count(), "CPU egraph class count") else {
return report;
};
for eq in equivalences {
if eq.left >= class_count || eq.right >= class_count {
continue;
}
report.valid += 1;
let left = EClassId(eq.left);
let right = EClassId(eq.right);
if egraph.find(left) != egraph.find(right) {
egraph.union(left, right);
report.merged += 1;
}
}
report.rebuild_unions = egraph.rebuild();
report
}
#[inline]
fn u32_len(value: usize, context: &'static str) -> Result<u32, GpuEGraphSnapshotError> {
u32::try_from(value).map_err(|_| GpuEGraphSnapshotError::new(context, value))
}
fn append_words<I>(words: &mut Vec<u32>, values: I) -> GpuEGraphDeviceSpan
where
I: IntoIterator<Item = u32>,
{
let offset = words.len();
words.extend(values);
GpuEGraphDeviceSpan::new(offset, words.len() - offset)
}
#[must_use]
pub fn gpu_egraph_row_signature(language_op_id: u32, children_len: u32, children: &[u32]) -> u32 {
let mut hash = mix_egraph_signature(0xA24B_AED4, language_op_id);
hash = mix_egraph_signature(hash, children_len);
for &child in children {
hash = mix_egraph_signature(hash, child);
}
hash
}
fn egraph_row_signature(row: &SnapshotRow, children: &[u32]) -> u32 {
gpu_egraph_row_signature(row.language_op_id, row.children_len, children)
}
fn mix_egraph_signature(hash: u32, value: u32) -> u32 {
let mixed = hash
^ value
.wrapping_add(0x9E37_79B9)
.wrapping_add(hash << 6)
.wrapping_add(hash >> 2);
mixed.rotate_left(13).wrapping_mul(0x85EB_CA6B)
}
#[cfg(test)]
mod tests {
use super::*;
use std::hash::Hash;
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum TinyLang {
Lit(u32),
Add(EClassId, EClassId),
}
impl ENodeLang for TinyLang {
fn children(&self) -> super::super::eqsat::EChildren {
match self {
Self::Lit(_) => super::super::eqsat::EChildren::new(),
Self::Add(left, right) => [*left, *right].into_iter().collect(),
}
}
fn with_children(&self, children: &[EClassId]) -> Self {
match self {
Self::Lit(value) => Self::Lit(*value),
Self::Add(_, _) => Self::Add(children[0], children[1]),
}
}
}
#[test]
fn empty_snapshot() {
let snap = GpuEGraphSnapshot::default();
assert!(snap.is_empty());
assert_eq!(snap.node_count(), 0);
assert_eq!(snap.child_count(), 0);
assert!(snap.op_ids.is_empty());
}
#[test]
fn build_three_node_snapshot() {
let snap = GpuEGraphSnapshot::build([
(0u32, "lit_u32", &[][..]),
(1u32, "lit_u32", &[][..]),
(2u32, "binop_add", &[0u32, 1u32][..]),
]);
assert_eq!(snap.node_count(), 3);
assert_eq!(snap.child_count(), 2);
let empty: &[u32] = &[];
assert_eq!(snap.children_of(0), Some(empty));
assert_eq!(snap.children_of(1), Some(empty));
assert_eq!(snap.children_of(2), Some(&[0, 1][..]));
assert_eq!(snap.children_of(99), None);
}
#[test]
fn op_id_intern_dedups() {
let mut reg = OpIdRegistry::default();
let a = reg.intern("foo");
let b = reg.intern("bar");
let c = reg.intern("foo");
assert_eq!(a, c);
assert_ne!(a, b);
assert_eq!(reg.len(), 2);
assert_eq!(reg.name_of(a), Some("foo"));
assert_eq!(reg.name_of(b), Some("bar"));
assert_eq!(reg.name_of(99), None);
}
#[test]
fn gpu_snapshot_u32_layout_conversion_rejects_overflow() {
let error = u32_len(u32::MAX as usize + 1, "test overflow")
.expect_err("Fix: GPU e-graph snapshot must not silently saturate oversized columns");
assert_eq!(error.context(), "test overflow");
assert_eq!(error.value(), u32::MAX as usize + 1);
assert!(
error.to_string().contains("shard the e-graph snapshot")
&& error.to_string().contains("widen the GPU snapshot ABI"),
"oversized GPU snapshot errors must explain both viable fixes"
);
}
#[test]
fn gpu_snapshot_builders_use_fallible_u32_conversion_not_saturation() {
let source = include_str!("eqsat_gpu.rs");
let production = source
.split("#[cfg(test)]")
.next()
.expect("Fix: production eqsat_gpu section must exist");
assert!(
source.contains("pub fn try_build")
&& source.contains("pub fn try_from_egraph_with")
&& source.contains("snapshot.op_ids.try_intern")
&& source.contains("u32::try_from(value).map_err")
&& !source.contains(concat!("unwrap_or", "(u32::MAX)")),
"Fix: GPU e-graph snapshots must reject oversized u32 ABI fields instead of saturating them to u32::MAX."
);
assert!(
!production.contains(".expect("),
"Fix: GPU e-graph snapshot production paths must return structured errors instead of panicking."
);
}
#[test]
fn rows_by_eclass_groups_correctly() {
let snap = GpuEGraphSnapshot::build([
(0u32, "lit_u32", &[][..]),
(0u32, "var", &[][..]),
(1u32, "binop_add", &[0u32][..]),
]);
let groups = snap.rows_by_eclass();
assert_eq!(groups.len(), 2);
assert_eq!(groups.get(&0).unwrap().len(), 2);
assert_eq!(groups.get(&1).unwrap().len(), 1);
}
#[test]
fn generated_snapshot_integrity_accepts_pack_boundaries_and_forward_children() {
for node_count in [1_usize, 2, 7, 8, 9, 16, 17, 31, 32, 33, 65, 128] {
let mut rows = Vec::with_capacity(node_count);
let mut child_storage = Vec::new();
for row in 0..node_count {
let start = child_storage.len();
if row > 0 {
child_storage.push((row - 1) as u32);
}
if row > 1 && row % 3 == 0 {
child_storage.push((row / 2) as u32);
}
rows.push((
row as u32,
if row % 2 == 0 { "lit" } else { "add" },
start,
child_storage.len() - start,
));
}
let build_rows = rows
.iter()
.map(|&(class, op, start, len)| (class, op, &child_storage[start..start + len]))
.collect::<Vec<_>>();
let snapshot = GpuEGraphSnapshot::build(build_rows);
snapshot
.validate_integrity()
.unwrap_or_else(|error| panic!("node_count={node_count}: {error}"));
}
}
#[test]
fn snapshot_integrity_rejects_unknown_op_id() {
let mut snapshot = GpuEGraphSnapshot::build([(0u32, "lit", &[][..])]);
snapshot.rows[0].language_op_id = 99;
let error = snapshot
.validate_integrity()
.expect_err("Fix: malformed GPU snapshot op ids must be rejected before upload.");
assert_eq!(error.context(), "unknown language_op_id");
assert_eq!(error.row(), 0);
assert_eq!(error.value(), 99);
}
#[test]
fn snapshot_integrity_rejects_out_of_bounds_child_range() {
let mut snapshot = GpuEGraphSnapshot::build([(0u32, "lit", &[][..])]);
snapshot.rows[0].children_offset = 1;
snapshot.rows[0].children_len = 1;
let error = snapshot
.validate_integrity()
.expect_err("Fix: malformed GPU snapshot child ranges must be rejected before upload.");
assert_eq!(error.context(), "children range end");
assert_eq!(error.row(), 0);
}
#[test]
fn snapshot_integrity_rejects_dangling_child_eclass() {
let snapshot =
GpuEGraphSnapshot::build([(0u32, "lit", &[][..]), (1u32, "add", &[0u32, 99u32][..])]);
let error = snapshot.validate_integrity().expect_err(
"Fix: malformed GPU snapshot child eclasses must be rejected before upload.",
);
assert_eq!(error.context(), "dangling child eclass");
assert_eq!(error.row(), 1);
assert_eq!(error.value(), 99);
}
#[test]
fn device_image_packs_single_upload_slab_with_sorted_group_index() {
let snapshot = GpuEGraphSnapshot::build([
(2u32, "lit", &[][..]),
(1u32, "lit", &[][..]),
(2u32, "add", &[1u32, 2u32][..]),
]);
let image = snapshot
.try_pack_device_image()
.expect("Fix: valid GPU e-graph snapshot must pack into a device image");
let layout = image.layout();
assert_eq!(layout.row_count(), 3);
assert_eq!(layout.child_count(), 2);
assert_eq!(layout.eclass_group_count(), 2);
assert_eq!(image.row_eclass_ids(), &[2, 1, 2]);
assert_eq!(image.row_language_op_ids(), &[0, 0, 1]);
assert_eq!(image.row_children_offsets(), &[0, 0, 0]);
assert_eq!(image.row_children_lens(), &[0, 0, 2]);
assert_eq!(image.row_signatures().len(), 3);
assert_ne!(image.row_signatures()[0], image.row_signatures()[2]);
assert_eq!(image.children(), &[1, 2]);
assert_eq!(image.group_eclass_ids(), &[1, 2]);
assert_eq!(image.group_offsets(), &[0, 1, 3]);
assert_eq!(image.group_rows(), &[1, 0, 2]);
assert_eq!(
image.words().len(),
layout.row_eclass_ids().len()
+ layout.row_language_op_ids().len()
+ layout.row_children_offsets().len()
+ layout.row_children_lens().len()
+ layout.row_signatures().len()
+ layout.children().len()
+ layout.group_eclass_ids().len()
+ layout.group_offsets().len()
+ layout.group_rows().len()
);
}
#[test]
fn generated_device_image_pack_accepts_empty_and_power_boundaries() {
for node_count in [0_usize, 1, 2, 7, 8, 9, 31, 32, 33, 127, 128, 129] {
let mut rows = Vec::with_capacity(node_count);
let mut child_storage = Vec::new();
for row in 0..node_count {
let start = child_storage.len();
if row > 0 {
child_storage.push((row - 1) as u32);
}
rows.push((
row as u32,
if row & 1 == 0 { "lit" } else { "neg" },
start,
child_storage.len() - start,
));
}
let build_rows = rows
.iter()
.map(|&(class, op, start, len)| (class, op, &child_storage[start..start + len]))
.collect::<Vec<_>>();
let snapshot = GpuEGraphSnapshot::build(build_rows);
let image = snapshot
.try_pack_device_image()
.unwrap_or_else(|error| panic!("node_count={node_count}: {error}"));
assert_eq!(image.layout().row_count(), node_count);
assert_eq!(image.row_eclass_ids().len(), node_count);
assert_eq!(image.row_language_op_ids().len(), node_count);
assert_eq!(image.row_signatures().len(), node_count);
assert_eq!(image.group_rows().len(), node_count);
assert_eq!(image.group_offsets().len(), node_count + 1);
}
}
#[test]
fn row_signatures_group_structural_duplicates_without_eclass_identity() {
let snapshot = GpuEGraphSnapshot::build([
(1u32, "lit", &[][..]),
(2u32, "lit", &[][..]),
(10u32, "add", &[1u32, 2u32][..]),
(11u32, "add", &[1u32, 2u32][..]),
(12u32, "add", &[2u32, 1u32][..]),
(13u32, "mul", &[1u32, 2u32][..]),
]);
let image = snapshot
.try_pack_device_image()
.expect("Fix: valid duplicate-signature snapshot must pack");
assert_eq!(image.row_signatures()[2], image.row_signatures()[3]);
assert_ne!(image.row_signatures()[2], image.row_signatures()[4]);
assert_ne!(image.row_signatures()[2], image.row_signatures()[5]);
}
#[test]
fn device_image_rejects_malformed_snapshot_before_pack() {
let mut snapshot = GpuEGraphSnapshot::build([(0u32, "lit", &[][..])]);
snapshot.rows[0].language_op_id = 42;
let error = snapshot
.try_pack_device_image()
.expect_err("Fix: device image packing must reject malformed snapshots");
match error {
GpuEGraphDeviceImageError::Integrity(error) => {
assert_eq!(error.context(), "unknown language_op_id");
assert_eq!(error.row(), 0);
assert_eq!(error.value(), 42);
}
GpuEGraphDeviceImageError::Layout(error) => {
panic!("expected integrity error, got layout error: {error}")
}
}
}
#[test]
fn snapshot_from_egraph_uses_canonical_children() {
let mut egraph = EGraph::new();
let a = egraph.add(TinyLang::Lit(1));
let b = egraph.add(TinyLang::Lit(2));
let add = egraph.add(TinyLang::Add(a, b));
assert_eq!(add.0, 2);
let snap = GpuEGraphSnapshot::from_egraph_with(&egraph, |node| match node {
TinyLang::Lit(_) => "lit",
TinyLang::Add(_, _) => "add",
});
assert_eq!(snap.node_count(), 3);
assert_eq!(snap.child_count(), 2);
assert_eq!(snap.op_ids.name_of(0), Some("lit"));
assert_eq!(snap.op_ids.name_of(1), Some("add"));
assert_eq!(snap.children_of(2), Some(&[0, 1][..]));
}
#[test]
fn apply_equivalences_counts_state_changes() {
let equivalences = vec![
Equivalence { left: 0, right: 1 },
Equivalence { left: 1, right: 0 }, Equivalence { left: 2, right: 3 },
];
let mut canonical: FxHashMap<u32, u32> = FxHashMap::default();
let applied = apply_equivalences(&equivalences, |a, b| {
let canon_a = *canonical.get(&a).unwrap_or(&a);
let canon_b = *canonical.get(&b).unwrap_or(&b);
if canon_a == canon_b {
false
} else {
let (lo, hi) = if canon_a < canon_b {
(canon_a, canon_b)
} else {
(canon_b, canon_a)
};
canonical.insert(hi, lo);
canonical.insert(a, lo);
canonical.insert(b, lo);
true
}
});
assert_eq!(applied, 2);
}
#[test]
fn apply_equivalences_empty_batch() {
let applied = apply_equivalences(&[], |_, _| true);
assert_eq!(applied, 0);
}
#[test]
fn apply_equivalences_to_egraph_merges_valid_ids() {
let mut egraph = EGraph::new();
let a = egraph.add(TinyLang::Lit(1));
let b = egraph.add(TinyLang::Lit(2));
let c = egraph.add(TinyLang::Lit(3));
let report = apply_equivalences_to_egraph(
&mut egraph,
&[
Equivalence {
left: a.0,
right: b.0,
},
Equivalence {
left: c.0,
right: 99,
},
],
);
assert_eq!(
report,
ApplyEquivalencesReport {
requested: 2,
valid: 1,
merged: 1,
rebuild_unions: 0,
}
);
assert_eq!(egraph.find(a), egraph.find(b));
assert_ne!(egraph.find(a), egraph.find(c));
}
}