use crate::ir::model::expr::ExprNode;
use crate::ir::{AtomicOp, BinOp, DataType, Expr, Ident, MemoryOrdering, UnOp};
use rustc_hash::FxHashMap;
use std::sync::Arc;
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ExprId(pub u32);
impl ExprId {
#[must_use]
#[inline]
pub fn raw(self) -> u32 {
self.0
}
}
#[allow(missing_docs)]
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum FlatExpr {
LitU32(u32),
LitI32(i32),
LitF32Bits(u32),
LitBool(bool),
Var(Ident),
Load {
buffer: Ident,
index: ExprId,
},
BufLen {
buffer: Ident,
},
InvocationId {
axis: u8,
},
WorkgroupId {
axis: u8,
},
LocalId {
axis: u8,
},
BinOp {
op: BinOp,
left: ExprId,
right: ExprId,
},
UnOp {
op: UnOp,
operand: ExprId,
},
Call {
op_id: Ident,
args: Vec<ExprId>,
},
Select {
cond: ExprId,
true_val: ExprId,
false_val: ExprId,
},
Cast {
target: DataType,
value: ExprId,
},
Fma {
a: ExprId,
b: ExprId,
c: ExprId,
},
Atomic {
op: AtomicOp,
buffer: Ident,
index: ExprId,
expected: Option<ExprId>,
value: ExprId,
ordering: MemoryOrdering,
},
SubgroupBallot {
cond: ExprId,
},
SubgroupShuffle {
value: ExprId,
lane: ExprId,
},
SubgroupAdd {
value: ExprId,
},
SubgroupLocalId,
SubgroupSize,
Opaque(OpaqueId),
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct OpaqueId(usize);
#[derive(Debug, Default)]
pub struct ExprArena {
nodes: Vec<FlatExpr>,
opaques: Vec<Arc<dyn ExprNode>>,
opaque_lookup: FxHashMap<OpaqueId, usize>,
hashcons: FxHashMap<FlatExpr, ExprId>,
}
impl ExprArena {
#[must_use]
#[inline]
pub fn len(&self) -> usize {
self.nodes.len()
}
#[must_use]
#[inline]
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
#[must_use]
pub fn get(&self, id: ExprId) -> &FlatExpr {
self.nodes
.get(id.0 as usize)
.unwrap_or_else(|| panic!("Fix: ExprId({}) not produced by this arena", id.0))
}
pub fn intern(&mut self, expr: &Expr) -> ExprId {
let flat = self.flatten(expr);
self.intern_flat(flat)
}
#[must_use]
pub fn rebuild(&self, id: ExprId) -> Expr {
match self.get(id).clone() {
FlatExpr::LitU32(v) => Expr::LitU32(v),
FlatExpr::LitI32(v) => Expr::LitI32(v),
FlatExpr::LitF32Bits(bits) => Expr::LitF32(f32::from_bits(bits)),
FlatExpr::LitBool(v) => Expr::LitBool(v),
FlatExpr::Var(name) => Expr::Var(name),
FlatExpr::Load { buffer, index } => Expr::Load {
buffer,
index: Box::new(self.rebuild(index)),
},
FlatExpr::BufLen { buffer } => Expr::BufLen { buffer },
FlatExpr::InvocationId { axis } => Expr::InvocationId { axis },
FlatExpr::WorkgroupId { axis } => Expr::WorkgroupId { axis },
FlatExpr::LocalId { axis } => Expr::LocalId { axis },
FlatExpr::BinOp { op, left, right } => Expr::BinOp {
op,
left: Box::new(self.rebuild(left)),
right: Box::new(self.rebuild(right)),
},
FlatExpr::UnOp { op, operand } => Expr::UnOp {
op,
operand: Box::new(self.rebuild(operand)),
},
FlatExpr::Call { op_id, args } => Expr::Call {
op_id,
args: args.into_iter().map(|a| self.rebuild(a)).collect(),
},
FlatExpr::Select {
cond,
true_val,
false_val,
} => Expr::Select {
cond: Box::new(self.rebuild(cond)),
true_val: Box::new(self.rebuild(true_val)),
false_val: Box::new(self.rebuild(false_val)),
},
FlatExpr::Cast { target, value } => Expr::Cast {
target,
value: Box::new(self.rebuild(value)),
},
FlatExpr::Fma { a, b, c } => Expr::Fma {
a: Box::new(self.rebuild(a)),
b: Box::new(self.rebuild(b)),
c: Box::new(self.rebuild(c)),
},
FlatExpr::Atomic {
op,
buffer,
index,
expected,
value,
ordering,
} => Expr::Atomic {
op,
buffer,
index: Box::new(self.rebuild(index)),
expected: expected.map(|id| Box::new(self.rebuild(id))),
value: Box::new(self.rebuild(value)),
ordering,
},
FlatExpr::SubgroupBallot { cond } => Expr::SubgroupBallot {
cond: Box::new(self.rebuild(cond)),
},
FlatExpr::SubgroupShuffle { value, lane } => Expr::SubgroupShuffle {
value: Box::new(self.rebuild(value)),
lane: Box::new(self.rebuild(lane)),
},
FlatExpr::SubgroupAdd { value } => Expr::SubgroupAdd {
value: Box::new(self.rebuild(value)),
},
FlatExpr::SubgroupLocalId => Expr::SubgroupLocalId,
FlatExpr::SubgroupSize => Expr::SubgroupSize,
FlatExpr::Opaque(opaque_id) => {
let idx = self
.opaque_lookup
.get(&opaque_id)
.copied()
.expect("Fix: rebuild saw an OpaqueId not produced by this arena");
Expr::Opaque(Arc::clone(&self.opaques[idx]))
}
}
}
fn intern_flat(&mut self, flat: FlatExpr) -> ExprId {
if let Some(existing) = self.hashcons.get(&flat) {
return *existing;
}
let id = ExprId(self.nodes.len() as u32);
self.nodes.push(flat.clone());
self.hashcons.insert(flat, id);
id
}
fn flatten(&mut self, expr: &Expr) -> FlatExpr {
match expr {
Expr::LitU32(v) => FlatExpr::LitU32(*v),
Expr::LitI32(v) => FlatExpr::LitI32(*v),
Expr::LitF32(v) => FlatExpr::LitF32Bits(v.to_bits()),
Expr::LitBool(v) => FlatExpr::LitBool(*v),
Expr::Var(name) => FlatExpr::Var(name.clone()),
Expr::Load { buffer, index } => FlatExpr::Load {
buffer: buffer.clone(),
index: self.intern(index),
},
Expr::BufLen { buffer } => FlatExpr::BufLen {
buffer: buffer.clone(),
},
Expr::InvocationId { axis } => FlatExpr::InvocationId { axis: *axis },
Expr::WorkgroupId { axis } => FlatExpr::WorkgroupId { axis: *axis },
Expr::LocalId { axis } => FlatExpr::LocalId { axis: *axis },
Expr::BinOp { op, left, right } => FlatExpr::BinOp {
op: *op,
left: self.intern(left),
right: self.intern(right),
},
Expr::UnOp { op, operand } => FlatExpr::UnOp {
op: op.clone(),
operand: self.intern(operand),
},
Expr::Call { op_id, args } => FlatExpr::Call {
op_id: op_id.clone(),
args: args.iter().map(|a| self.intern(a)).collect(),
},
Expr::Select {
cond,
true_val,
false_val,
} => FlatExpr::Select {
cond: self.intern(cond),
true_val: self.intern(true_val),
false_val: self.intern(false_val),
},
Expr::Cast { target, value } => FlatExpr::Cast {
target: target.clone(),
value: self.intern(value),
},
Expr::Fma { a, b, c } => FlatExpr::Fma {
a: self.intern(a),
b: self.intern(b),
c: self.intern(c),
},
Expr::Atomic {
op,
buffer,
index,
expected,
value,
ordering,
} => FlatExpr::Atomic {
op: *op,
buffer: buffer.clone(),
index: self.intern(index),
expected: expected.as_deref().map(|e| self.intern(e)),
value: self.intern(value),
ordering: *ordering,
},
Expr::SubgroupBallot { cond } => FlatExpr::SubgroupBallot {
cond: self.intern(cond),
},
Expr::SubgroupShuffle { value, lane } => FlatExpr::SubgroupShuffle {
value: self.intern(value),
lane: self.intern(lane),
},
Expr::SubgroupAdd { value } => FlatExpr::SubgroupAdd {
value: self.intern(value),
},
Expr::SubgroupLocalId => FlatExpr::SubgroupLocalId,
Expr::SubgroupSize => FlatExpr::SubgroupSize,
Expr::Opaque(arc) => {
let opaque_id = OpaqueId(Arc::as_ptr(arc) as *const () as usize);
self.opaque_lookup.entry(opaque_id).or_insert_with(|| {
let idx = self.opaques.len();
self.opaques.push(Arc::clone(arc));
idx
});
FlatExpr::Opaque(opaque_id)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::Expr;
#[test]
fn distinct_literals_get_distinct_ids() {
let mut arena = ExprArena::default();
let a = arena.intern(&Expr::u32(1));
let b = arena.intern(&Expr::u32(2));
assert_ne!(a, b);
assert_eq!(arena.len(), 2);
}
#[test]
fn equal_literals_collapse_to_one_id() {
let mut arena = ExprArena::default();
let a = arena.intern(&Expr::u32(7));
let b = arena.intern(&Expr::u32(7));
assert_eq!(a, b);
assert_eq!(
arena.len(),
1,
"second intern of equal Expr must not grow the arena"
);
}
#[test]
fn equal_subtrees_collapse_at_every_level() {
let mut arena = ExprArena::default();
let lhs = Expr::add(Expr::u32(1), Expr::u32(2));
let id_a = arena.intern(&lhs);
assert_eq!(arena.len(), 3);
let id_b = arena.intern(&lhs);
assert_eq!(id_a, id_b);
assert_eq!(arena.len(), 3, "no new nodes for an identical subtree");
}
#[test]
fn shared_leaves_dedup_across_distinct_parents() {
let mut arena = ExprArena::default();
let _ = arena.intern(&Expr::add(Expr::u32(1), Expr::u32(2)));
let _ = arena.intern(&Expr::sub(Expr::u32(1), Expr::u32(2)));
assert_eq!(arena.len(), 4);
}
#[test]
fn rebuild_round_trip_for_simple_literal() {
let mut arena = ExprArena::default();
let id = arena.intern(&Expr::u32(42));
assert_eq!(arena.rebuild(id), Expr::u32(42));
}
#[test]
fn rebuild_round_trip_for_nested_expression() {
let mut arena = ExprArena::default();
let original = Expr::Select {
cond: Box::new(Expr::lt(Expr::var("i"), Expr::u32(8))),
true_val: Box::new(Expr::add(Expr::var("i"), Expr::u32(1))),
false_val: Box::new(Expr::u32(0)),
};
let id = arena.intern(&original);
assert_eq!(arena.rebuild(id), original);
}
#[test]
fn rebuild_round_trip_for_load() {
let mut arena = ExprArena::default();
let original = Expr::load("buf", Expr::add(Expr::var("base"), Expr::u32(4)));
let id = arena.intern(&original);
assert_eq!(arena.rebuild(id), original);
}
#[test]
fn nan_literals_intern_to_same_id() {
let mut arena = ExprArena::default();
let a = arena.intern(&Expr::f32(f32::NAN));
let b = arena.intern(&Expr::f32(f32::NAN));
assert_eq!(a, b);
assert_eq!(arena.len(), 1);
}
#[test]
fn positive_zero_and_negative_zero_get_distinct_ids() {
let mut arena = ExprArena::default();
let pos = arena.intern(&Expr::f32(0.0));
let neg = arena.intern(&Expr::f32(-0.0));
assert_ne!(pos, neg);
}
#[test]
fn expr_id_is_copy_and_eq() {
let id = ExprId(0);
let copy = id;
assert_eq!(id, copy);
}
#[test]
fn arena_intern_is_idempotent_under_rebuild_intern_cycle() {
let mut arena = ExprArena::default();
let original = Expr::Fma {
a: Box::new(Expr::var("x")),
b: Box::new(Expr::var("y")),
c: Box::new(Expr::var("z")),
};
let id = arena.intern(&original);
let rebuilt = arena.rebuild(id);
let id2 = arena.intern(&rebuilt);
assert_eq!(id, id2, "rebuild then re-intern must hit the same ExprId");
}
#[test]
fn opaque_expr_interning_via_arc_identity() {
use crate::ir::model::expr::ExprNode;
use crate::ir::DataType;
use std::any::Any;
#[derive(Debug)]
struct DummyOpaque;
impl ExprNode for DummyOpaque {
fn extension_kind(&self) -> &'static str {
"dummy"
}
fn debug_identity(&self) -> &str {
"dummy"
}
fn result_type(&self) -> Option<DataType> {
Some(DataType::U32)
}
fn cse_safe(&self) -> bool {
true
}
fn stable_fingerprint(&self) -> [u8; 32] {
[0u8; 32]
}
fn validate_extension(&self) -> Result<(), String> {
Ok(())
}
fn as_any(&self) -> &dyn Any {
self
}
}
let arc: Arc<dyn ExprNode> = Arc::new(DummyOpaque);
let mut arena = ExprArena::default();
let id_a = arena.intern(&Expr::Opaque(Arc::clone(&arc)));
let id_b = arena.intern(&Expr::Opaque(arc));
assert_eq!(id_a, id_b, "two Arcs of the same allocation must collapse");
assert_eq!(arena.len(), 1);
}
}