use super::{expr_has_effect, CseCtx, ExprKey};
use crate::ir::{Expr, Node};
impl CseCtx {
    /// Returns a context for a nested scope.
    ///
    /// The child starts with a copy of this context's cached values, so it can reuse
    /// expressions already bound in the enclosing scope. Bindings the child records are
    /// kept in its own map and are never copied back into `self`.
    #[inline]
    pub(crate) fn child(&self) -> Self {
        Self {
            values: self.values.clone(),
        }
    }

    /// Forgets every cached expression-to-variable mapping.
    ///
    /// Used after anything that may write state (assignments, stores, atomics, barriers,
    /// effectful expressions) and after nested control flow, so that later occurrences of
    /// an expression are recomputed instead of being replaced by a stale variable.
    #[inline]
    pub(crate) fn clear_observed_state(&mut self) {
        self.values.clear();
    }

    /// Rewrites a sequence of nodes in order, threading this context through them so
    /// that bindings made by earlier nodes can be reused by later ones.
    #[inline]
    pub(crate) fn nodes(&mut self, nodes: &[Node]) -> Vec<Node> {
        nodes.iter().map(|node| self.node(node)).collect()
    }

    /// Rewrites one statement.
    ///
    /// A pure `let` whose value was already bound in this scope is rewritten to reference
    /// the variable that first held it; the new name is then also recorded for that value
    /// if no earlier binding exists. Writes, barriers and nested control flow reset the
    /// cache.
    pub(crate) fn node(&mut self, node: &Node) -> Node {
        match node {
            Node::Let { name, value } => {
                let value = self.expr(value);
                if expr_has_effect(&value) {
                    // An effectful value is never cached, and its effect may invalidate
                    // anything observed so far.
                    self.clear_observed_state();
                    return Node::let_bind(name, value);
                }
                if Self::is_uncached_literal(&value) {
                    // Integer and boolean literals are never cached or aliased to a variable.
                    return Node::let_bind(name, value);
                }
                // NOTE(review): entries that mention `name` are not invalidated here, so this
                // assumes a `let` never rebinds a name that a cached entry still refers to
                // (e.g. shadowing in a nested scope). Confirm against the IR's scoping rules.
                let key = ExprKey::from_expr(&value);
                let canonical = match self.values.get(&key) {
                    Some(existing) => Expr::var(existing),
                    None => value,
                };
                // The first binding of a value stays its canonical name.
                self.values.entry(key).or_insert_with(|| name.clone());
                Node::let_bind(name, canonical)
            }
            Node::Assign { name, value } => {
                let value = self.expr(value);
                // The assigned variable may be a cached result or appear inside a cached
                // key, so nothing observed before the assignment is reused.
                self.clear_observed_state();
                Node::assign(name, value)
            }
            Node::Store {
                buffer,
                index,
                value,
            } => {
                let index = self.expr(index);
                let value = self.expr(value);
                // Cached loads may read the buffer that was just written.
                self.clear_observed_state();
                Node::store(buffer, index, value)
            }
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                let cond = self.expr(cond);
                // Each branch sees the values cached before the `if`, but not each other's.
                let then = self.child().nodes(then);
                let otherwise = self.child().nodes(otherwise);
                // Writes inside either branch only reset the branch contexts, so the
                // enclosing context is reset conservatively here.
                self.clear_observed_state();
                Node::if_then_else(cond, then, otherwise)
            }
            Node::Loop {
                var,
                from,
                to,
                body,
            } => {
                let from = self.expr(from);
                let to = self.expr(to);
                // The body starts from an empty context: it runs repeatedly, so values cached
                // before the loop could be invalidated by writes from an earlier iteration,
                // which a single linear walk over the body would not see.
                let body = CseCtx::default().nodes(body);
                self.clear_observed_state();
                Node::loop_for(var, from, to, body)
            }
            Node::Return => Node::Return,
            Node::Block(nodes) => {
                let nodes = self.child().nodes(nodes);
                // Writes inside the block only reset the block's own context.
                self.clear_observed_state();
                Node::block(nodes)
            }
            Node::Barrier => {
                self.clear_observed_state();
                Node::Barrier
            }
        }
    }

    /// Returns `true` for literals that are never cached or replaced by a variable.
    ///
    /// NOTE(review): `LitF32` is not listed, so float literals are still CSE'd into
    /// variables — confirm that is intended.
    fn is_uncached_literal(expr: &Expr) -> bool {
        matches!(expr, Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitBool(_))
    }

    /// Rewrites an expression bottom-up. When this context already has a variable holding
    /// an identical value, the expression is replaced by that variable.
    ///
    /// Sub-expressions are rewritten first, in the order they are written. The following
    /// are never replaced:
    /// - variables;
    /// - integer and boolean literals;
    /// - effectful expressions.
    ///
    /// An effectful expression also clears the cache as soon as it has been rewritten. Any
    /// sub-expressions that follow it are therefore not replaced by values observed before
    /// the effect.
    pub(crate) fn expr(&mut self, expr: &Expr) -> Expr {
        let rewritten = match expr {
            Expr::Load { buffer, index } => Expr::Load {
                buffer: buffer.clone(),
                index: Box::new(self.expr(index)),
            },
            Expr::BinOp { op, left, right } => Expr::BinOp {
                op: op.clone(),
                left: Box::new(self.expr(left)),
                right: Box::new(self.expr(right)),
            },
            Expr::UnOp { op, operand } => Expr::UnOp {
                op: op.clone(),
                operand: Box::new(self.expr(operand)),
            },
            Expr::Fma { a, b, c } => Expr::Fma {
                a: Box::new(self.expr(a)),
                b: Box::new(self.expr(b)),
                c: Box::new(self.expr(c)),
            },
            Expr::Call { op_id, args } => Expr::Call {
                op_id: op_id.clone(),
                args: args.iter().map(|arg| self.expr(arg)).collect(),
            },
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => Expr::Select {
                cond: Box::new(self.expr(cond)),
                true_val: Box::new(self.expr(true_val)),
                false_val: Box::new(self.expr(false_val)),
            },
            Expr::Cast { target, value } => Expr::Cast {
                target: target.clone(),
                value: Box::new(self.expr(value)),
            },
            Expr::Atomic {
                op,
                buffer,
                index,
                expected,
                value,
            } => {
                let index = self.expr(index);
                let expected = expected.as_deref().map(|expr| Box::new(self.expr(expr)));
                let value = self.expr(value);
                // An atomic writes memory, so cached loads may now be stale.
                self.clear_observed_state();
                Expr::Atomic {
                    op: op.clone(),
                    buffer: buffer.clone(),
                    index: Box::new(index),
                    expected,
                    value: Box::new(value),
                }
            }
            Expr::LitU32(_)
            | Expr::LitI32(_)
            | Expr::LitF32(_)
            | Expr::LitBool(_)
            | Expr::Var(_)
            | Expr::BufLen { .. }
            | Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. } => expr.clone(),
        };
        if matches!(rewritten, Expr::Var(_)) {
            return rewritten;
        }
        if expr_has_effect(&rewritten) {
            // Matches the `let` handling in `node`: any effect, not only a direct atomic,
            // invalidates what has been observed so far.
            self.clear_observed_state();
            return rewritten;
        }
        if Self::is_uncached_literal(&rewritten) {
            return rewritten;
        }
        let key = ExprKey::from_expr(&rewritten);
        match self.values.get(&key) {
            Some(existing) => Expr::var(existing),
            None => rewritten,
        }
    }
}