use rustc_hash::{FxHashMap, FxHashSet};
use std::fmt;
use std::sync::Arc;
use std::time::Instant;
use super::eqsat::{
try_extract_best_with_budget, EClassId, EGraph, EGraphError, ENodeLang,
DEFAULT_EXTRACTION_ITER_BUDGET,
};
/// One e-node row of a [`GpuEGraphSnapshot`].
///
/// Every column is a `u32` so the row can be uploaded as-is.
/// `children_offset` / `children_len` describe a run inside
/// [`GpuEGraphSnapshot::children`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SnapshotRow {
    /// E-class this node belongs to.
    pub eclass_id: u32,
    /// Operator id interned in the snapshot's [`OpIdRegistry`].
    pub language_op_id: u32,
    /// Index of this row's first child in the shared children column.
    pub children_offset: u32,
    /// Number of child e-class ids belonging to this row.
    pub children_len: u32,
}
/// A request to merge the e-classes `left` and `right`.
///
/// Ids are raw `u32` e-class ids; they are range-checked by the consumer
/// (see [`apply_equivalences_to_egraph`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equivalence {
    /// First e-class id of the pair.
    pub left: u32,
    /// Second e-class id of the pair.
    pub right: u32,
}
/// Column-oriented snapshot of e-graph nodes, staged for device upload.
///
/// Each [`SnapshotRow`] refers to a contiguous run of `children`; operator
/// names are interned into `op_ids`.
#[derive(Default, Debug, Clone)]
pub struct GpuEGraphSnapshot {
    /// One row per e-node, in the order the builder produced them.
    pub rows: Vec<SnapshotRow>,
    /// Child e-class ids of every row, concatenated.
    pub children: Vec<u32>,
    /// Interned operator names referenced by `SnapshotRow::language_op_id`.
    pub op_ids: OpIdRegistry,
}
/// A snapshot length, offset, or index that does not fit the `u32` column ABI.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuEGraphSnapshotError {
    /// Which column or counter overflowed.
    context: &'static str,
    /// The offending (too large) value.
    value: usize,
}
impl GpuEGraphSnapshotError {
    /// Records that `value` overflowed the `u32` ABI while computing `context`.
    const fn new(context: &'static str, value: usize) -> Self {
        Self { context, value }
    }

    /// Short description of the column or counter that overflowed.
    #[must_use]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// The value that could not be narrowed to `u32`.
    #[must_use]
    pub const fn value(&self) -> usize {
        self.value
    }
}
impl fmt::Display for GpuEGraphSnapshotError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { context, value } = self;
        write!(
            f,
            "GPU e-graph snapshot {} value {} exceeds the u32 column ABI. Fix: shard the e-graph snapshot or widen the GPU snapshot ABI before upload.",
            context, value
        )
    }
}
impl std::error::Error for GpuEGraphSnapshotError {}
/// A structural defect found in a snapshot row (bad op id, child range, or
/// dangling child e-class).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuEGraphSnapshotIntegrityError {
    /// Which check failed.
    context: &'static str,
    /// Index of the offending row in `GpuEGraphSnapshot::rows`.
    row: usize,
    /// The column value reported for the failure.
    value: u32,
}
impl GpuEGraphSnapshotIntegrityError {
    /// Records a failed `context` check on `row`, reporting `value`.
    const fn new(context: &'static str, row: usize, value: u32) -> Self {
        Self {
            context,
            row,
            value,
        }
    }

    /// Short description of the failed check.
    #[must_use]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Row index at which the check failed.
    #[must_use]
    pub const fn row(&self) -> usize {
        self.row
    }

    /// Column value attached to the failure.
    #[must_use]
    pub const fn value(&self) -> u32 {
        self.value
    }
}
impl fmt::Display for GpuEGraphSnapshotIntegrityError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            context,
            row,
            value,
        } = self;
        write!(
            f,
            "GPU e-graph snapshot integrity error at row {}: {} value {} is invalid. Fix: rebuild the snapshot from canonical e-graph rows before upload.",
            row, context, value
        )
    }
}
impl std::error::Error for GpuEGraphSnapshotIntegrityError {}
/// Failure while packing a [`GpuEGraphSnapshot`] into a [`GpuEGraphDeviceImage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuEGraphDeviceImageError {
    /// The snapshot was rejected by `GpuEGraphSnapshot::validate_integrity`.
    Integrity(GpuEGraphSnapshotIntegrityError),
    /// A packed offset or index did not fit the `u32` ABI.
    Layout(GpuEGraphSnapshotError),
}
impl fmt::Display for GpuEGraphDeviceImageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both variants forward to the wrapped error's own message verbatim.
        match self {
            Self::Integrity(inner) => fmt::Display::fmt(inner, f),
            Self::Layout(inner) => fmt::Display::fmt(inner, f),
        }
    }
}
impl std::error::Error for GpuEGraphDeviceImageError {}
impl From<GpuEGraphSnapshotError> for GpuEGraphDeviceImageError {
    fn from(error: GpuEGraphSnapshotError) -> Self {
        Self::Layout(error)
    }
}
impl From<GpuEGraphSnapshotIntegrityError> for GpuEGraphDeviceImageError {
    fn from(error: GpuEGraphSnapshotIntegrityError) -> Self {
        Self::Integrity(error)
    }
}
/// Failure of [`bridge_equivalence_batch_with_report`], tagged by pipeline stage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuEGraphBridgeError {
    /// Building the snapshot overflowed the `u32` ABI.
    Snapshot(GpuEGraphSnapshotError),
    /// Packing the snapshot into a device image failed.
    DeviceImage(GpuEGraphDeviceImageError),
    /// The e-graph layer (extraction) reported an error.
    EGraph(EGraphError),
}
impl fmt::Display for GpuEGraphBridgeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each stage prefixes the wrapped error's message with its own label.
        match self {
            Self::Snapshot(error) => write!(f, "GPU e-graph bridge snapshot failed: {error}"),
            Self::DeviceImage(error) => {
                write!(f, "GPU e-graph bridge device image failed: {error}")
            }
            Self::EGraph(error) => write!(f, "GPU e-graph bridge extraction failed: {error}"),
        }
    }
}
impl std::error::Error for GpuEGraphBridgeError {}
impl From<EGraphError> for GpuEGraphBridgeError {
    fn from(error: EGraphError) -> Self {
        Self::EGraph(error)
    }
}
impl From<GpuEGraphDeviceImageError> for GpuEGraphBridgeError {
    fn from(error: GpuEGraphDeviceImageError) -> Self {
        Self::DeviceImage(error)
    }
}
impl From<GpuEGraphSnapshotError> for GpuEGraphBridgeError {
    fn from(error: GpuEGraphSnapshotError) -> Self {
        Self::Snapshot(error)
    }
}
/// Half-open window `[offset, offset + len)` into a device image's word slab.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct GpuEGraphDeviceSpan {
    offset: usize,
    len: usize,
}
impl GpuEGraphDeviceSpan {
    /// Creates a span covering `len` words starting at word `offset`.
    const fn new(offset: usize, len: usize) -> Self {
        Self { offset, len }
    }

    /// First word index of the span.
    #[must_use]
    pub const fn offset(&self) -> usize {
        self.offset
    }

    /// Number of words in the span.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// `true` when the span covers no words.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Borrows this span's words out of `words`.
    ///
    /// Panics if the span extends past the end of `words`.
    fn slice<'a>(&self, words: &'a [u32]) -> &'a [u32] {
        let end = self.offset + self.len;
        &words[self.offset..end]
    }
}
/// Where each column lives inside a [`GpuEGraphDeviceImage`]'s word slab.
///
/// The per-row columns have `row_count` entries and `children` has
/// `child_count` entries. `group_eclass_ids` has `eclass_group_count`
/// entries, and `group_offsets` has one extra terminal entry, so group `g`
/// owns `group_rows[group_offsets[g]..group_offsets[g + 1]]`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct GpuEGraphDeviceLayout {
    row_count: usize,
    child_count: usize,
    eclass_group_count: usize,
    row_eclass_ids: GpuEGraphDeviceSpan,
    row_language_op_ids: GpuEGraphDeviceSpan,
    row_children_offsets: GpuEGraphDeviceSpan,
    row_children_lens: GpuEGraphDeviceSpan,
    row_signatures: GpuEGraphDeviceSpan,
    children: GpuEGraphDeviceSpan,
    group_eclass_ids: GpuEGraphDeviceSpan,
    group_offsets: GpuEGraphDeviceSpan,
    group_rows: GpuEGraphDeviceSpan,
}
impl GpuEGraphDeviceLayout {
    // ---- element counts -------------------------------------------------

    /// Number of snapshot rows (e-nodes) packed into the image.
    #[must_use]
    pub const fn row_count(&self) -> usize {
        self.row_count
    }

    /// Total number of child entries packed into the image.
    #[must_use]
    pub const fn child_count(&self) -> usize {
        self.child_count
    }

    /// Number of distinct e-classes in the group index.
    #[must_use]
    pub const fn eclass_group_count(&self) -> usize {
        self.eclass_group_count
    }

    // ---- per-row columns --------------------------------------------------

    /// Span of the per-row e-class id column.
    #[must_use]
    pub const fn row_eclass_ids(&self) -> GpuEGraphDeviceSpan {
        self.row_eclass_ids
    }

    /// Span of the per-row operator id column.
    #[must_use]
    pub const fn row_language_op_ids(&self) -> GpuEGraphDeviceSpan {
        self.row_language_op_ids
    }

    /// Span of the per-row children offset column.
    #[must_use]
    pub const fn row_children_offsets(&self) -> GpuEGraphDeviceSpan {
        self.row_children_offsets
    }

    /// Span of the per-row children length column.
    #[must_use]
    pub const fn row_children_lens(&self) -> GpuEGraphDeviceSpan {
        self.row_children_lens
    }

    /// Span of the per-row structural signature column.
    #[must_use]
    pub const fn row_signatures(&self) -> GpuEGraphDeviceSpan {
        self.row_signatures
    }

    // ---- shared children column -----------------------------------------

    /// Span of the concatenated child e-class ids.
    #[must_use]
    pub const fn children(&self) -> GpuEGraphDeviceSpan {
        self.children
    }

    // ---- e-class group index --------------------------------------------

    /// Span of the sorted, distinct e-class ids.
    #[must_use]
    pub const fn group_eclass_ids(&self) -> GpuEGraphDeviceSpan {
        self.group_eclass_ids
    }

    /// Span of the group start offsets (plus the terminal offset).
    #[must_use]
    pub const fn group_offsets(&self) -> GpuEGraphDeviceSpan {
        self.group_offsets
    }

    /// Span of the row indices, bucketed by group.
    #[must_use]
    pub const fn group_rows(&self) -> GpuEGraphDeviceSpan {
        self.group_rows
    }
}
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct GpuEGraphDeviceImage {
words: Vec<u32>,
layout: GpuEGraphDeviceLayout,
}
impl GpuEGraphDeviceImage {
#[must_use]
pub fn words(&self) -> &[u32] {
&self.words
}
#[must_use]
pub const fn layout(&self) -> GpuEGraphDeviceLayout {
self.layout
}
#[must_use]
pub fn row_eclass_ids(&self) -> &[u32] {
self.layout.row_eclass_ids.slice(&self.words)
}
#[must_use]
pub fn row_language_op_ids(&self) -> &[u32] {
self.layout.row_language_op_ids.slice(&self.words)
}
#[must_use]
pub fn row_children_offsets(&self) -> &[u32] {
self.layout.row_children_offsets.slice(&self.words)
}
#[must_use]
pub fn row_children_lens(&self) -> &[u32] {
self.layout.row_children_lens.slice(&self.words)
}
#[must_use]
pub fn row_signatures(&self) -> &[u32] {
self.layout.row_signatures.slice(&self.words)
}
#[must_use]
pub fn children(&self) -> &[u32] {
self.layout.children.slice(&self.words)
}
#[must_use]
pub fn group_eclass_ids(&self) -> &[u32] {
self.layout.group_eclass_ids.slice(&self.words)
}
#[must_use]
pub fn group_offsets(&self) -> &[u32] {
self.layout.group_offsets.slice(&self.words)
}
#[must_use]
pub fn group_rows(&self) -> &[u32] {
self.layout.group_rows.slice(&self.words)
}
}
/// Bidirectional interner mapping operator names to dense `u32` ids.
///
/// Ids are assigned in first-seen order starting at 0.
#[derive(Default, Debug, Clone)]
pub struct OpIdRegistry {
    by_name: FxHashMap<Arc<str>, u32>,
    names: Vec<Arc<str>>,
}
impl OpIdRegistry {
    /// Infallible variant of [`Self::try_intern`].
    ///
    /// Returns `u32::MAX` when the registry cannot hand out another `u32` id.
    /// NOTE(review): `u32::MAX` is also a legal id for the very last name that
    /// fits, so the sentinel is ambiguous at that boundary; prefer
    /// `try_intern` when that matters.
    pub fn intern(&mut self, name: &str) -> u32 {
        let Ok(id) = self.try_intern(name) else {
            return u32::MAX;
        };
        id
    }

    /// Returns the id for `name`, allocating the next dense id if it is new.
    ///
    /// # Errors
    /// Returns [`GpuEGraphSnapshotError`] when the next id would not fit in `u32`.
    pub fn try_intern(&mut self, name: &str) -> Result<u32, GpuEGraphSnapshotError> {
        if let Some(existing) = self.by_name.get(name).copied() {
            return Ok(existing);
        }
        let next = u32_len(self.names.len(), "op-id registry")?;
        // One shared allocation backs both the lookup key and the id table.
        let shared: Arc<str> = Arc::from(name);
        self.by_name.insert(Arc::clone(&shared), next);
        self.names.push(shared);
        Ok(next)
    }

    /// Name registered under `id`, if any.
    #[must_use]
    pub fn name_of(&self, id: u32) -> Option<&str> {
        self.names.get(id as usize).map(|name| &**name)
    }

    /// Number of distinct interned names.
    #[must_use]
    pub fn len(&self) -> usize {
        self.names.len()
    }

    /// `true` when no names have been interned.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.names.is_empty()
    }
}
impl GpuEGraphSnapshot {
    /// Builds a snapshot from `(eclass_id, op_name, children)` rows.
    ///
    /// Infallible wrapper over [`Self::try_build`]: if any column would
    /// overflow the `u32` ABI the error is discarded and an empty snapshot is
    /// returned. Use `try_build` when the caller needs to see the failure.
    #[must_use]
    pub fn build<'a, I>(rows: I) -> Self
    where
        I: IntoIterator<Item = (u32, &'a str, &'a [u32])>,
    {
        Self::try_build(rows).unwrap_or_default()
    }

    /// Builds a snapshot from `(eclass_id, op_name, children)` rows.
    ///
    /// Rows and their children are appended in iteration order. E-class and
    /// child ids are copied verbatim, without canonicalization.
    ///
    /// # Errors
    /// Returns [`GpuEGraphSnapshotError`] in any of these cases:
    /// - the op-id registry overflows;
    /// - a row's child offset or child count does not fit in `u32`;
    /// - the children column's terminal offset does not fit in `u32`.
    pub fn try_build<'a, I>(rows: I) -> Result<Self, GpuEGraphSnapshotError>
    where
        I: IntoIterator<Item = (u32, &'a str, &'a [u32])>,
    {
        let mut snapshot = Self::default();
        let rows = rows.into_iter();
        let (lower_bound, _) = rows.size_hint();
        snapshot.rows.reserve(lower_bound);
        for (eclass_id, op_name, kids) in rows {
            let language_op_id = snapshot.op_ids.try_intern(op_name)?;
            let (children_offset, children_len) =
                Self::child_run_columns(snapshot.children.len(), kids.len())?;
            snapshot.children.extend_from_slice(kids);
            snapshot.rows.push(SnapshotRow {
                eclass_id,
                language_op_id,
                children_offset,
                children_len,
            });
        }
        Ok(snapshot)
    }

    /// Snapshots every node of `egraph`, naming operators with `op_name`.
    ///
    /// Infallible wrapper over [`Self::try_from_egraph_with`]: on a `u32` ABI
    /// overflow an empty snapshot is returned.
    #[must_use]
    pub fn from_egraph_with<L, F, S>(egraph: &EGraph<L>, op_name: F) -> Self
    where
        L: ENodeLang,
        F: FnMut(&L) -> S,
        S: AsRef<str>,
    {
        Self::try_from_egraph_with(egraph, op_name).unwrap_or_default()
    }

    /// Snapshots every node yielded by `egraph.iter_nodes()`.
    ///
    /// Both the row e-class id and every child id are mapped through
    /// `egraph.find_immut`, so the snapshot refers to canonical e-classes.
    /// `op_name` supplies the operator name that is interned per node.
    ///
    /// # Errors
    /// Returns [`GpuEGraphSnapshotError`] in any of these cases:
    /// - the op-id registry overflows;
    /// - a row's child offset or child count does not fit in `u32`;
    /// - the children column's terminal offset does not fit in `u32`.
    pub fn try_from_egraph_with<L, F, S>(
        egraph: &EGraph<L>,
        mut op_name: F,
    ) -> Result<Self, GpuEGraphSnapshotError>
    where
        L: ENodeLang,
        F: FnMut(&L) -> S,
        S: AsRef<str>,
    {
        let mut snapshot = Self::default();
        // Presizing hint only; the e-graph may hold more nodes than classes.
        snapshot.rows.reserve(egraph.class_count());
        for (eclass_id, node) in egraph.iter_nodes() {
            let language_op_id = snapshot.op_ids.try_intern(op_name(node).as_ref())?;
            let children = node.children();
            let (children_offset, children_len) =
                Self::child_run_columns(snapshot.children.len(), children.len())?;
            snapshot
                .children
                .extend(children.iter().map(|child| egraph.find_immut(*child).0));
            snapshot.rows.push(SnapshotRow {
                eclass_id: egraph.find_immut(eclass_id).0,
                language_op_id,
                children_offset,
                children_len,
            });
        }
        Ok(snapshot)
    }

    /// Converts the child run `[offset, offset + len)` into `u32` columns.
    ///
    /// The per-row offset and count must fit in `u32`, and so must the run's
    /// end. The run's end is the children column length after appending.
    /// Without that last check the final row could grow the children column
    /// beyond what the `u32` ABI can describe. It mirrors the terminal-offset
    /// check applied to the packed group index.
    fn child_run_columns(
        offset: usize,
        len: usize,
    ) -> Result<(u32, u32), GpuEGraphSnapshotError> {
        let children_offset = u32_len(offset, "GPU egraph children offset")?;
        let children_len = u32_len(len, "GPU egraph row child count")?;
        u32_len(
            offset.saturating_add(len),
            "GPU egraph children terminal offset",
        )?;
        Ok((children_offset, children_len))
    }

    /// Number of rows (e-nodes) in the snapshot.
    #[must_use]
    pub fn node_count(&self) -> usize {
        self.rows.len()
    }

    /// `true` when the snapshot has no rows.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }

    /// Total number of entries in the shared children column.
    #[must_use]
    pub fn child_count(&self) -> usize {
        self.children.len()
    }

    /// Children of row `row_idx`.
    ///
    /// Returns `None` when the row does not exist or its range is out of bounds.
    #[must_use]
    pub fn children_of(&self, row_idx: usize) -> Option<&[u32]> {
        let row = self.rows.get(row_idx)?;
        let start = row.children_offset as usize;
        let end = start.checked_add(row.children_len as usize)?;
        self.children.get(start..end)
    }

    /// Maps each e-class id to the indices of its rows, in row order.
    #[must_use]
    pub fn rows_by_eclass(&self) -> FxHashMap<u32, Vec<usize>> {
        let mut out: FxHashMap<u32, Vec<usize>> =
            FxHashMap::with_capacity_and_hasher(self.rows.len(), Default::default());
        for (i, row) in self.rows.iter().enumerate() {
            out.entry(row.eclass_id).or_default().push(i);
        }
        out
    }

    /// Checks that every row can be safely consumed by a device kernel.
    ///
    /// # Errors
    /// Returns the first [`GpuEGraphSnapshotIntegrityError`] found. The checks
    /// are, in order:
    /// - the row's op id must be registered;
    /// - the row's child range must stay inside `children`;
    /// - every child must be the e-class id of some row.
    pub fn validate_integrity(&self) -> Result<(), GpuEGraphSnapshotIntegrityError> {
        let mut eclasses: FxHashSet<u32> =
            FxHashSet::with_capacity_and_hasher(self.rows.len(), Default::default());
        eclasses.extend(self.rows.iter().map(|row| row.eclass_id));
        for (row_idx, row) in self.rows.iter().enumerate() {
            if self.op_ids.name_of(row.language_op_id).is_none() {
                return Err(GpuEGraphSnapshotIntegrityError::new(
                    "unknown language_op_id",
                    row_idx,
                    row.language_op_id,
                ));
            }
            let start = row.children_offset as usize;
            let end = start
                .checked_add(row.children_len as usize)
                .ok_or_else(|| {
                    GpuEGraphSnapshotIntegrityError::new(
                        "children range overflow",
                        row_idx,
                        row.children_len,
                    )
                })?;
            if end > self.children.len() {
                return Err(GpuEGraphSnapshotIntegrityError::new(
                    "children range end",
                    row_idx,
                    row.children_len,
                ));
            }
            if let Some(&child) = self.children[start..end]
                .iter()
                .find(|child| !eclasses.contains(child))
            {
                return Err(GpuEGraphSnapshotIntegrityError::new(
                    "dangling child eclass",
                    row_idx,
                    child,
                ));
            }
        }
        Ok(())
    }

    /// Buckets row indices by e-class, ordered by ascending e-class id.
    ///
    /// Row indices inside each bucket stay in snapshot order.
    fn sorted_eclass_groups(&self) -> Result<Vec<(u32, Vec<u32>)>, GpuEGraphSnapshotError> {
        let mut buckets: FxHashMap<u32, Vec<u32>> =
            FxHashMap::with_capacity_and_hasher(self.rows.len(), Default::default());
        for (row_idx, row) in self.rows.iter().enumerate() {
            buckets
                .entry(row.eclass_id)
                .or_default()
                .push(u32_len(row_idx, "GPU egraph grouped row index")?);
        }
        let mut groups = buckets.into_iter().collect::<Vec<_>>();
        // Keys are unique, so an unstable sort is still fully deterministic.
        groups.sort_unstable_by_key(|&(eclass_id, _)| eclass_id);
        Ok(groups)
    }

    /// Validates the snapshot and packs it into one contiguous `u32` slab.
    ///
    /// Columns are written in this order:
    /// - the row columns: e-class ids, op ids, child offsets, child lengths,
    ///   and signatures;
    /// - the children column;
    /// - the group index: sorted e-class ids, offsets with a terminal entry,
    ///   and grouped row indices.
    ///
    /// # Errors
    /// Returns [`GpuEGraphDeviceImageError::Integrity`] for malformed rows, or
    /// [`GpuEGraphDeviceImageError::Layout`] when an index or offset does not
    /// fit in `u32`.
    pub fn try_pack_device_image(&self) -> Result<GpuEGraphDeviceImage, GpuEGraphDeviceImageError> {
        self.validate_integrity()?;
        let groups = self.sorted_eclass_groups()?;
        let mut group_eclass_ids = Vec::with_capacity(groups.len());
        let mut group_offsets = Vec::with_capacity(groups.len() + 1);
        let mut group_rows = Vec::with_capacity(self.rows.len());
        for (eclass_id, rows) in &groups {
            group_eclass_ids.push(*eclass_id);
            group_offsets.push(u32_len(group_rows.len(), "GPU egraph group row offset")?);
            group_rows.extend_from_slice(rows);
        }
        group_offsets.push(u32_len(
            group_rows.len(),
            "GPU egraph group row terminal offset",
        )?);
        let eclass_group_count = groups.len();
        let mut words = Vec::with_capacity(
            self.rows.len() * 5
                + self.children.len()
                + group_eclass_ids.len()
                + group_offsets.len()
                + group_rows.len(),
        );
        let row_eclass_ids = append_words(&mut words, self.rows.iter().map(|row| row.eclass_id));
        let row_language_op_ids =
            append_words(&mut words, self.rows.iter().map(|row| row.language_op_id));
        let row_children_offsets =
            append_words(&mut words, self.rows.iter().map(|row| row.children_offset));
        let row_children_lens =
            append_words(&mut words, self.rows.iter().map(|row| row.children_len));
        let row_signatures = append_words(
            &mut words,
            self.rows.iter().map(|row| {
                // Child ranges were bounds-checked by `validate_integrity` above.
                let start = row.children_offset as usize;
                let end = start + row.children_len as usize;
                egraph_row_signature(row, &self.children[start..end])
            }),
        );
        let children = append_words(&mut words, self.children.iter().copied());
        let group_eclass_ids_span = append_words(&mut words, group_eclass_ids);
        let group_offsets = append_words(&mut words, group_offsets);
        let group_rows = append_words(&mut words, group_rows);
        Ok(GpuEGraphDeviceImage {
            words,
            layout: GpuEGraphDeviceLayout {
                row_count: self.rows.len(),
                child_count: self.children.len(),
                eclass_group_count,
                row_eclass_ids,
                row_language_op_ids,
                row_children_offsets,
                row_children_lens,
                row_signatures,
                children,
                group_eclass_ids: group_eclass_ids_span,
                group_offsets,
                group_rows,
            },
        })
    }

    /// Infallible wrapper over [`Self::try_pack_device_image`].
    ///
    /// Returns an empty image when packing fails.
    #[must_use]
    pub fn pack_device_image(&self) -> GpuEGraphDeviceImage {
        self.try_pack_device_image().unwrap_or_default()
    }
}
/// Counters produced by [`apply_equivalences_to_egraph`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ApplyEquivalencesReport {
    /// Number of equivalences in the batch.
    pub requested: usize,
    /// Equivalences whose two ids were both in range.
    pub valid: usize,
    /// Valid equivalences that joined two previously distinct classes.
    pub merged: usize,
    /// Value returned by the e-graph's `rebuild()` after the batch.
    pub rebuild_unions: usize,
}
/// Sizes, timings, and parity flags from [`bridge_equivalence_batch_with_report`].
///
/// `gpu_*` / unprefixed apply fields describe the run on the caller's e-graph;
/// `cpu_*` fields describe the run on a cloned reference e-graph.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GpuEGraphBridgeReport {
    /// Rows in the pre-apply snapshot.
    pub snapshot_rows: usize,
    /// Child entries in the pre-apply snapshot.
    pub snapshot_children: usize,
    /// Length of the packed device slab in `u32` words.
    pub device_words: usize,
    /// Distinct e-classes in the packed group index.
    pub device_eclass_groups: usize,
    /// Equivalences submitted.
    pub equivalences_requested: usize,
    /// In-range equivalences on the live e-graph.
    pub equivalences_valid: usize,
    /// Merges performed on the live e-graph.
    pub equivalences_merged: usize,
    /// Rebuild result on the live e-graph.
    pub rebuild_unions: usize,
    /// In-range equivalences on the reference clone.
    pub cpu_equivalences_valid: usize,
    /// Merges performed on the reference clone.
    pub cpu_equivalences_merged: usize,
    /// Rebuild result on the reference clone.
    pub cpu_rebuild_unions: usize,
    /// Time to build the snapshot (ns, at least 1).
    pub snapshot_ns: u64,
    /// Time to pack the device image (ns, at least 1).
    pub pack_ns: u64,
    /// Time to apply the batch to the reference clone (ns, at least 1).
    pub cpu_apply_ns: u64,
    /// Time to apply the batch to the live e-graph (ns, at least 1).
    pub gpu_apply_ns: u64,
    /// Extraction time on the reference clone (ns, at least 1).
    pub cpu_extraction_ns: u64,
    /// Extraction time on the live e-graph (ns, at least 1).
    pub gpu_extraction_ns: u64,
    /// Best extraction cost on the reference clone, if any.
    pub cpu_extraction_cost: Option<u64>,
    /// Best extraction cost on the live e-graph, if any.
    pub gpu_extraction_cost: Option<u64>,
    /// Apply reports and best extractions matched between both runs.
    pub recall_parity: bool,
    /// Two independent snapshots produced the same e-class group index.
    pub class_id_deterministic: bool,
}
/// Runs one equivalence batch through the GPU bridge pipeline and reports
/// sizes, timings, and CPU/"GPU" parity.
///
/// Steps, in order:
/// 1. Snapshot `egraph` twice and pack both copies into device images.
///    Comparing the two group indexes sets `class_id_deterministic`.
/// 2. Apply `equivalences` to a clone of `egraph` (the CPU reference) and to
///    `egraph` itself (the "GPU" side). Both sides currently run
///    [`apply_equivalences_to_egraph`].
/// 3. Extract the best term rooted at `root` from both graphs using
///    `cost_fn` and [`DEFAULT_EXTRACTION_ITER_BUDGET`].
///
/// `egraph` is mutated: it ends up with the batch applied and rebuilt.
///
/// # Errors
/// Snapshot, pack, and extraction failures are propagated as
/// [`GpuEGraphBridgeError`].
pub fn bridge_equivalence_batch_with_report<L, F, S, C>(
    egraph: &mut EGraph<L>,
    root: EClassId,
    op_name: F,
    equivalences: &[Equivalence],
    cost_fn: C,
) -> Result<GpuEGraphBridgeReport, GpuEGraphBridgeError>
where
    L: ENodeLang,
    F: Fn(&L) -> S,
    S: AsRef<str>,
    C: Fn(&L) -> u64 + Copy,
{
    // Runs `work` and pairs its output with the non-zero elapsed nanoseconds.
    fn timed<T>(work: impl FnOnce() -> T) -> (T, u64) {
        let started = Instant::now();
        let output = work();
        (output, elapsed_nonzero_ns(started))
    }

    // Read-only view of the pre-apply state; released before any mutation.
    let frozen: &EGraph<L> = egraph;
    let (snapshot, snapshot_ns) =
        timed(|| GpuEGraphSnapshot::try_from_egraph_with(frozen, &op_name));
    let snapshot = snapshot?;
    let replay_snapshot = GpuEGraphSnapshot::try_from_egraph_with(frozen, &op_name)?;
    let (image, pack_ns) = timed(|| snapshot.try_pack_device_image());
    let image = image?;
    let replay_image = replay_snapshot.try_pack_device_image()?;
    let class_id_deterministic = image.group_eclass_ids() == replay_image.group_eclass_ids()
        && image.group_offsets() == replay_image.group_offsets()
        && image.group_rows() == replay_image.group_rows();
    let mut cpu_parity = frozen.clone();

    let (cpu_apply, cpu_apply_ns) =
        timed(|| apply_equivalences_to_egraph(&mut cpu_parity, equivalences));
    let (gpu_apply, gpu_apply_ns) = timed(|| apply_equivalences_to_egraph(egraph, equivalences));

    let (cpu_extraction, cpu_extraction_ns) = timed(|| {
        try_extract_best_with_budget(
            &cpu_parity,
            root,
            |node| cost_fn(node),
            DEFAULT_EXTRACTION_ITER_BUDGET,
        )
    });
    let cpu_extraction = cpu_extraction?;
    let (gpu_extraction, gpu_extraction_ns) = timed(|| {
        try_extract_best_with_budget(
            &*egraph,
            root,
            |node| cost_fn(node),
            DEFAULT_EXTRACTION_ITER_BUDGET,
        )
    });
    let gpu_extraction = gpu_extraction?;

    let cpu_extraction_cost = cpu_extraction.best.as_ref().map(|(_, cost)| *cost);
    let gpu_extraction_cost = gpu_extraction.best.as_ref().map(|(_, cost)| *cost);
    let recall_parity = cpu_apply == gpu_apply && cpu_extraction.best == gpu_extraction.best;
    Ok(GpuEGraphBridgeReport {
        snapshot_rows: snapshot.node_count(),
        snapshot_children: snapshot.child_count(),
        device_words: image.words().len(),
        device_eclass_groups: image.layout().eclass_group_count(),
        equivalences_requested: gpu_apply.requested,
        equivalences_valid: gpu_apply.valid,
        equivalences_merged: gpu_apply.merged,
        rebuild_unions: gpu_apply.rebuild_unions,
        cpu_equivalences_valid: cpu_apply.valid,
        cpu_equivalences_merged: cpu_apply.merged,
        cpu_rebuild_unions: cpu_apply.rebuild_unions,
        snapshot_ns,
        pack_ns,
        cpu_apply_ns,
        gpu_apply_ns,
        cpu_extraction_ns,
        gpu_extraction_ns,
        cpu_extraction_cost,
        gpu_extraction_cost,
        recall_parity,
        class_id_deterministic,
    })
}
/// Feeds each equivalence to `merger` in order and counts how many calls
/// returned `true` (i.e. actually changed state, by the caller's definition).
pub fn apply_equivalences<F>(equivalences: &[Equivalence], mut merger: F) -> usize
where
    F: FnMut(u32, u32) -> bool,
{
    equivalences
        .iter()
        .filter(|pair| merger(pair.left, pair.right))
        .count()
}
/// Unions every in-range equivalence into `egraph`, then rebuilds it once.
///
/// An equivalence is *valid* when both ids are below `egraph.class_count()`.
/// It is *merged* when the two ids had different `find` results before the
/// union. Out-of-range pairs are skipped and counted only in `requested`.
///
/// NOTE(review): this assumes `class_count()` bounds the valid e-class id
/// range; confirm against `EGraph` if ids can exceed the live class count.
pub fn apply_equivalences_to_egraph<L>(
    egraph: &mut EGraph<L>,
    equivalences: &[Equivalence],
) -> ApplyEquivalencesReport
where
    L: ENodeLang,
{
    let mut report = ApplyEquivalencesReport {
        requested: equivalences.len(),
        ..ApplyEquivalencesReport::default()
    };
    // Compare in `usize` space. Narrowing the count to `u32` used to make an
    // over-large e-graph silently drop the whole batch and skip `rebuild`,
    // even though every `u32` id is in range in that case.
    let class_count = egraph.class_count();
    for eq in equivalences {
        if (eq.left as usize) >= class_count || (eq.right as usize) >= class_count {
            continue;
        }
        report.valid += 1;
        let left = EClassId(eq.left);
        let right = EClassId(eq.right);
        if egraph.find(left) != egraph.find(right) {
            egraph.union(left, right);
            report.merged += 1;
        }
    }
    report.rebuild_unions = egraph.rebuild();
    report
}
/// Narrows a `usize` length, offset, or index to the `u32` column ABI.
///
/// # Errors
/// Returns [`GpuEGraphSnapshotError`] carrying `context` and the original
/// `value` when `value > u32::MAX`; the value is never saturated.
// NOTE: the test module scans this file for the exact conversion expression
// below, so keep it textually intact when refactoring.
#[inline]
fn u32_len(value: usize, context: &'static str) -> Result<u32, GpuEGraphSnapshotError> {
    u32::try_from(value).map_err(|_| GpuEGraphSnapshotError::new(context, value))
}
/// Appends `values` to the slab `words` and returns the span they now occupy.
fn append_words<I>(words: &mut Vec<u32>, values: I) -> GpuEGraphDeviceSpan
where
    I: IntoIterator<Item = u32>,
{
    let start = words.len();
    words.extend(values);
    let appended = words.len() - start;
    GpuEGraphDeviceSpan::new(start, appended)
}
/// Nanoseconds since `start`, saturated to `u64::MAX` and clamped to at least 1
/// so timing fields never report zero.
fn elapsed_nonzero_ns(start: Instant) -> u64 {
    let nanos = start.elapsed().as_nanos();
    u64::try_from(nanos).map_or(u64::MAX, |ns| ns.max(1))
}
/// Structural signature of one e-node row: folds the operator id, the child
/// count, and each child id (in order) through [`mix_egraph_signature`].
///
/// The row's own e-class id is deliberately excluded, so structurally
/// identical nodes in different classes share a signature.
/// NOTE(review): presumably mirrored by a device-side kernel; keep bit-exact.
#[must_use]
pub fn gpu_egraph_row_signature(language_op_id: u32, children_len: u32, children: &[u32]) -> u32 {
    let seed = mix_egraph_signature(
        mix_egraph_signature(0xA24B_AED4, language_op_id),
        children_len,
    );
    children
        .iter()
        .fold(seed, |acc, &child| mix_egraph_signature(acc, child))
}
fn egraph_row_signature(row: &SnapshotRow, children: &[u32]) -> u32 {
gpu_egraph_row_signature(row.language_op_id, row.children_len, children)
}
/// One signature-mixing step.
///
/// The combine has the same shape as Boost's `hash_combine`:
/// `hash ^ (value + 0x9E3779B9 + (hash << 6) + (hash >> 2))`, with wrapping
/// arithmetic. It is followed by a rotate-left-13 and a multiply by
/// `0x85EBCA6B`. Keep it bit-exact: row signatures depend on it.
fn mix_egraph_signature(hash: u32, value: u32) -> u32 {
    let addend = value
        .wrapping_add(0x9E37_79B9)
        .wrapping_add(hash << 6)
        .wrapping_add(hash >> 2);
    (hash ^ addend).rotate_left(13).wrapping_mul(0x85EB_CA6B)
}
#[cfg(test)]
mod tests {
    //! Unit tests for snapshot building, integrity checks, device packing,
    //! equivalence application, and the end-to-end bridge report.
    use super::*;
    use std::hash::Hash;

    /// Minimal two-operator language used to drive a real `EGraph`.
    #[derive(Clone, Debug, Eq, Hash, PartialEq)]
    enum TinyLang {
        Lit(u32),
        Add(EClassId, EClassId),
    }
    impl ENodeLang for TinyLang {
        fn children(&self) -> super::super::eqsat::EChildren {
            match self {
                Self::Lit(_) => super::super::eqsat::EChildren::new(),
                Self::Add(left, right) => [*left, *right].into_iter().collect(),
            }
        }
        fn with_children(&self, children: &[EClassId]) -> Self {
            match self {
                Self::Lit(value) => Self::Lit(*value),
                Self::Add(_, _) => Self::Add(children[0], children[1]),
            }
        }
    }
    fn tiny_op_name(node: &TinyLang) -> &'static str {
        match node {
            TinyLang::Lit(_) => "lit",
            TinyLang::Add(_, _) => "add",
        }
    }
    fn tiny_cost(node: &TinyLang) -> u64 {
        match node {
            TinyLang::Lit(_) => 1,
            TinyLang::Add(_, _) => 4,
        }
    }
    #[test]
    fn empty_snapshot() {
        let snap = GpuEGraphSnapshot::default();
        assert!(snap.is_empty());
        assert_eq!(snap.node_count(), 0);
        assert_eq!(snap.child_count(), 0);
        assert!(snap.op_ids.is_empty());
    }
    #[test]
    fn build_three_node_snapshot() {
        let snap = GpuEGraphSnapshot::build([
            (0u32, "lit_u32", &[][..]),
            (1u32, "lit_u32", &[][..]),
            (2u32, "binop_add", &[0u32, 1u32][..]),
        ]);
        assert_eq!(snap.node_count(), 3);
        assert_eq!(snap.child_count(), 2);
        let empty: &[u32] = &[];
        assert_eq!(snap.children_of(0), Some(empty));
        assert_eq!(snap.children_of(1), Some(empty));
        assert_eq!(snap.children_of(2), Some(&[0, 1][..]));
        assert_eq!(snap.children_of(99), None);
    }
    #[test]
    fn op_id_intern_dedups() {
        let mut reg = OpIdRegistry::default();
        let a = reg.intern("foo");
        let b = reg.intern("bar");
        let c = reg.intern("foo");
        assert_eq!(a, c);
        assert_ne!(a, b);
        assert_eq!(reg.len(), 2);
        assert_eq!(reg.name_of(a), Some("foo"));
        assert_eq!(reg.name_of(b), Some("bar"));
        assert_eq!(reg.name_of(99), None);
    }
    #[test]
    fn gpu_snapshot_u32_layout_conversion_rejects_overflow() {
        let error = u32_len(u32::MAX as usize + 1, "test overflow")
            .expect_err("Fix: GPU e-graph snapshot must not silently saturate oversized columns");
        assert_eq!(error.context(), "test overflow");
        assert_eq!(error.value(), u32::MAX as usize + 1);
        assert!(
            error.to_string().contains("shard the e-graph snapshot")
                && error.to_string().contains("widen the GPU snapshot ABI"),
            "oversized GPU snapshot errors must explain both viable fixes"
        );
    }
    /// Source-level guard against reintroducing saturating `u32` casts or
    /// panicking helpers in production code.
    ///
    /// The positive checks must search `production` (the text before the test
    /// module), not the whole `source`. Otherwise they match this test's own
    /// string literals and can never fail.
    #[test]
    fn gpu_snapshot_builders_use_fallible_u32_conversion_not_saturation() {
        let source = include_str!("eqsat_gpu.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: production eqsat_gpu section must exist");
        assert!(
            production.contains("pub fn try_build")
                && production.contains("pub fn try_from_egraph_with")
                && production.contains("snapshot.op_ids.try_intern")
                && production.contains("u32::try_from(value).map_err")
                && !source.contains(concat!("unwrap_or", "(u32::MAX)")),
            "Fix: GPU e-graph snapshots must reject oversized u32 ABI fields instead of saturating them to u32::MAX."
        );
        assert!(
            !production.contains(".expect("),
            "Fix: GPU e-graph snapshot production paths must return structured errors instead of panicking."
        );
    }
    #[test]
    fn rows_by_eclass_groups_correctly() {
        let snap = GpuEGraphSnapshot::build([
            (0u32, "lit_u32", &[][..]),
            (0u32, "var", &[][..]),
            (1u32, "binop_add", &[0u32][..]),
        ]);
        let groups = snap.rows_by_eclass();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get(&0).unwrap().len(), 2);
        assert_eq!(groups.get(&1).unwrap().len(), 1);
    }
    #[test]
    fn generated_snapshot_integrity_accepts_pack_boundaries_and_forward_children() {
        for node_count in [1_usize, 2, 7, 8, 9, 16, 17, 31, 32, 33, 65, 128] {
            let mut rows = Vec::with_capacity(node_count);
            let mut child_storage = Vec::new();
            for row in 0..node_count {
                let start = child_storage.len();
                if row > 0 {
                    child_storage.push((row - 1) as u32);
                }
                if row > 1 && row % 3 == 0 {
                    child_storage.push((row / 2) as u32);
                }
                rows.push((
                    row as u32,
                    if row % 2 == 0 { "lit" } else { "add" },
                    start,
                    child_storage.len() - start,
                ));
            }
            let build_rows = rows
                .iter()
                .map(|&(class, op, start, len)| (class, op, &child_storage[start..start + len]))
                .collect::<Vec<_>>();
            let snapshot = GpuEGraphSnapshot::build(build_rows);
            snapshot
                .validate_integrity()
                .unwrap_or_else(|error| panic!("node_count={node_count}: {error}"));
        }
    }
    #[test]
    fn snapshot_integrity_rejects_unknown_op_id() {
        let mut snapshot = GpuEGraphSnapshot::build([(0u32, "lit", &[][..])]);
        snapshot.rows[0].language_op_id = 99;
        let error = snapshot
            .validate_integrity()
            .expect_err("Fix: malformed GPU snapshot op ids must be rejected before upload.");
        assert_eq!(error.context(), "unknown language_op_id");
        assert_eq!(error.row(), 0);
        assert_eq!(error.value(), 99);
    }
    #[test]
    fn snapshot_integrity_rejects_out_of_bounds_child_range() {
        let mut snapshot = GpuEGraphSnapshot::build([(0u32, "lit", &[][..])]);
        snapshot.rows[0].children_offset = 1;
        snapshot.rows[0].children_len = 1;
        let error = snapshot
            .validate_integrity()
            .expect_err("Fix: malformed GPU snapshot child ranges must be rejected before upload.");
        assert_eq!(error.context(), "children range end");
        assert_eq!(error.row(), 0);
    }
    #[test]
    fn snapshot_integrity_rejects_dangling_child_eclass() {
        let snapshot =
            GpuEGraphSnapshot::build([(0u32, "lit", &[][..]), (1u32, "add", &[0u32, 99u32][..])]);
        let error = snapshot.validate_integrity().expect_err(
            "Fix: malformed GPU snapshot child eclasses must be rejected before upload.",
        );
        assert_eq!(error.context(), "dangling child eclass");
        assert_eq!(error.row(), 1);
        assert_eq!(error.value(), 99);
    }
    #[test]
    fn device_image_packs_single_upload_slab_with_sorted_group_index() {
        let snapshot = GpuEGraphSnapshot::build([
            (2u32, "lit", &[][..]),
            (1u32, "lit", &[][..]),
            (2u32, "add", &[1u32, 2u32][..]),
        ]);
        let image = snapshot
            .try_pack_device_image()
            .expect("Fix: valid GPU e-graph snapshot must pack into a device image");
        let layout = image.layout();
        assert_eq!(layout.row_count(), 3);
        assert_eq!(layout.child_count(), 2);
        assert_eq!(layout.eclass_group_count(), 2);
        assert_eq!(image.row_eclass_ids(), &[2, 1, 2]);
        assert_eq!(image.row_language_op_ids(), &[0, 0, 1]);
        assert_eq!(image.row_children_offsets(), &[0, 0, 0]);
        assert_eq!(image.row_children_lens(), &[0, 0, 2]);
        assert_eq!(image.row_signatures().len(), 3);
        assert_ne!(image.row_signatures()[0], image.row_signatures()[2]);
        assert_eq!(image.children(), &[1, 2]);
        assert_eq!(image.group_eclass_ids(), &[1, 2]);
        assert_eq!(image.group_offsets(), &[0, 1, 3]);
        assert_eq!(image.group_rows(), &[1, 0, 2]);
        assert_eq!(
            image.words().len(),
            layout.row_eclass_ids().len()
                + layout.row_language_op_ids().len()
                + layout.row_children_offsets().len()
                + layout.row_children_lens().len()
                + layout.row_signatures().len()
                + layout.children().len()
                + layout.group_eclass_ids().len()
                + layout.group_offsets().len()
                + layout.group_rows().len()
        );
    }
    #[test]
    fn generated_device_image_pack_accepts_empty_and_power_boundaries() {
        for node_count in [0_usize, 1, 2, 7, 8, 9, 31, 32, 33, 127, 128, 129] {
            let mut rows = Vec::with_capacity(node_count);
            let mut child_storage = Vec::new();
            for row in 0..node_count {
                let start = child_storage.len();
                if row > 0 {
                    child_storage.push((row - 1) as u32);
                }
                rows.push((
                    row as u32,
                    if row & 1 == 0 { "lit" } else { "neg" },
                    start,
                    child_storage.len() - start,
                ));
            }
            let build_rows = rows
                .iter()
                .map(|&(class, op, start, len)| (class, op, &child_storage[start..start + len]))
                .collect::<Vec<_>>();
            let snapshot = GpuEGraphSnapshot::build(build_rows);
            let image = snapshot
                .try_pack_device_image()
                .unwrap_or_else(|error| panic!("node_count={node_count}: {error}"));
            assert_eq!(image.layout().row_count(), node_count);
            assert_eq!(image.row_eclass_ids().len(), node_count);
            assert_eq!(image.row_language_op_ids().len(), node_count);
            assert_eq!(image.row_signatures().len(), node_count);
            assert_eq!(image.group_rows().len(), node_count);
            assert_eq!(image.group_offsets().len(), node_count + 1);
        }
    }
    #[test]
    fn row_signatures_group_structural_duplicates_without_eclass_identity() {
        let snapshot = GpuEGraphSnapshot::build([
            (1u32, "lit", &[][..]),
            (2u32, "lit", &[][..]),
            (10u32, "add", &[1u32, 2u32][..]),
            (11u32, "add", &[1u32, 2u32][..]),
            (12u32, "add", &[2u32, 1u32][..]),
            (13u32, "mul", &[1u32, 2u32][..]),
        ]);
        let image = snapshot
            .try_pack_device_image()
            .expect("Fix: valid duplicate-signature snapshot must pack");
        assert_eq!(image.row_signatures()[2], image.row_signatures()[3]);
        assert_ne!(image.row_signatures()[2], image.row_signatures()[4]);
        assert_ne!(image.row_signatures()[2], image.row_signatures()[5]);
    }
    #[test]
    fn device_image_rejects_malformed_snapshot_before_pack() {
        let mut snapshot = GpuEGraphSnapshot::build([(0u32, "lit", &[][..])]);
        snapshot.rows[0].language_op_id = 42;
        let error = snapshot
            .try_pack_device_image()
            .expect_err("Fix: device image packing must reject malformed snapshots");
        match error {
            GpuEGraphDeviceImageError::Integrity(error) => {
                assert_eq!(error.context(), "unknown language_op_id");
                assert_eq!(error.row(), 0);
                assert_eq!(error.value(), 42);
            }
            GpuEGraphDeviceImageError::Layout(error) => {
                panic!("expected integrity error, got layout error: {error}")
            }
        }
    }
    #[test]
    fn snapshot_from_egraph_uses_canonical_children() {
        let mut egraph = EGraph::new();
        let a = egraph.add(TinyLang::Lit(1));
        let b = egraph.add(TinyLang::Lit(2));
        let add = egraph.add(TinyLang::Add(a, b));
        assert_eq!(add.0, 2);
        let snap = GpuEGraphSnapshot::from_egraph_with(&egraph, |node| match node {
            TinyLang::Lit(_) => "lit",
            TinyLang::Add(_, _) => "add",
        });
        assert_eq!(snap.node_count(), 3);
        assert_eq!(snap.child_count(), 2);
        assert_eq!(snap.op_ids.name_of(0), Some("lit"));
        assert_eq!(snap.op_ids.name_of(1), Some("add"));
        assert_eq!(snap.children_of(2), Some(&[0, 1][..]));
    }
    #[test]
    fn apply_equivalences_counts_state_changes() {
        let equivalences = vec![
            Equivalence { left: 0, right: 1 },
            Equivalence { left: 1, right: 0 },
            Equivalence { left: 2, right: 3 },
        ];
        // Tiny union-find stand-in: the merger reports `true` only when the
        // two ids were not already in the same set.
        let mut canonical: FxHashMap<u32, u32> = FxHashMap::default();
        let applied = apply_equivalences(&equivalences, |a, b| {
            let canon_a = *canonical.get(&a).unwrap_or(&a);
            let canon_b = *canonical.get(&b).unwrap_or(&b);
            if canon_a == canon_b {
                false
            } else {
                let (lo, hi) = if canon_a < canon_b {
                    (canon_a, canon_b)
                } else {
                    (canon_b, canon_a)
                };
                canonical.insert(hi, lo);
                canonical.insert(a, lo);
                canonical.insert(b, lo);
                true
            }
        });
        assert_eq!(applied, 2);
    }
    #[test]
    fn apply_equivalences_empty_batch() {
        let applied = apply_equivalences(&[], |_, _| true);
        assert_eq!(applied, 0);
    }
    #[test]
    fn apply_equivalences_to_egraph_merges_valid_ids() {
        let mut egraph = EGraph::new();
        let a = egraph.add(TinyLang::Lit(1));
        let b = egraph.add(TinyLang::Lit(2));
        let c = egraph.add(TinyLang::Lit(3));
        let report = apply_equivalences_to_egraph(
            &mut egraph,
            &[
                Equivalence {
                    left: a.0,
                    right: b.0,
                },
                Equivalence {
                    left: c.0,
                    right: 99,
                },
            ],
        );
        assert_eq!(
            report,
            ApplyEquivalencesReport {
                requested: 2,
                valid: 1,
                merged: 1,
                rebuild_unions: 0,
            }
        );
        assert_eq!(egraph.find(a), egraph.find(b));
        assert_ne!(egraph.find(a), egraph.find(c));
    }
    #[test]
    fn gpu_egraph_bridge_reports_compact_columns_apply_and_extraction_parity() {
        let mut egraph = EGraph::new();
        let one = egraph.add(TinyLang::Lit(1));
        let two = egraph.add(TinyLang::Lit(2));
        let add = egraph.add(TinyLang::Add(one, two));
        let folded = egraph.add(TinyLang::Lit(3));
        let report = bridge_equivalence_batch_with_report(
            &mut egraph,
            add,
            tiny_op_name,
            &[Equivalence {
                left: add.0,
                right: folded.0,
            }],
            tiny_cost,
        )
        .expect("Fix: valid GPU e-graph bridge probe must produce a parity report");
        assert_eq!(report.snapshot_rows, 4);
        assert_eq!(report.snapshot_children, 2);
        assert!(report.device_words > report.snapshot_rows);
        assert_eq!(report.device_eclass_groups, 4);
        assert_eq!(report.equivalences_requested, 1);
        assert_eq!(report.equivalences_valid, 1);
        assert_eq!(report.equivalences_merged, 1);
        assert_eq!(report.cpu_equivalences_valid, 1);
        assert_eq!(report.cpu_equivalences_merged, 1);
        assert_eq!(report.cpu_extraction_cost, Some(1));
        assert_eq!(report.gpu_extraction_cost, Some(1));
        assert!(report.snapshot_ns > 0);
        assert!(report.pack_ns > 0);
        assert!(report.cpu_apply_ns > 0);
        assert!(report.gpu_apply_ns > 0);
        assert!(report.cpu_extraction_ns > 0);
        assert!(report.gpu_extraction_ns > 0);
        assert!(report.recall_parity);
        assert!(report.class_id_deterministic);
    }
}