use crate::ir::{AtomicOp, Expr, Ident, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
#[derive(Debug, Default)]
#[vyre_pass(
name = "dead_store_elim",
requires = [],
invalidates = []
)]
pub struct DeadStoreElim;
impl DeadStoreElim {
#[must_use]
pub fn analyze(program: &Program) -> PassAnalysis {
if program
.entry()
.iter()
.any(|n| node_map::any_descendant(n, &mut has_redundant_store_pair))
{
PassAnalysis::RUN
} else {
PassAnalysis::SKIP
}
}
#[must_use]
pub fn transform(program: Program) -> PassResult {
let scaffold = program.with_rewritten_entry(Vec::new());
let mut changed = false;
let entry: Vec<Node> = drop_dead_stores(
program
.into_entry_vec()
.into_iter()
.map(|n| rewrite_node(n, &mut changed))
.collect(),
&mut changed,
);
PassResult {
program: scaffold.with_rewritten_entry(entry),
changed,
}
}
#[must_use]
pub fn fingerprint(program: &Program) -> u64 {
fingerprint_program(program)
}
}
fn rewrite_node(node: Node, changed: &mut bool) -> Node {
let recursed = node_map::map_children(node, &mut |child| rewrite_node(child, changed));
node_map::map_body(recursed, &mut |body| drop_dead_stores(body, changed))
}
fn drop_dead_stores(body: Vec<Node>, changed: &mut bool) -> Vec<Node> {
let mut keep = vec![true; body.len()];
for first_idx in 0..body.len() {
if !keep[first_idx] {
continue;
}
let Node::Store {
buffer: first_buf,
index: first_idx_expr,
..
} = &body[first_idx]
else {
continue;
};
for second_idx in (first_idx + 1)..body.len() {
let between = &body[(first_idx + 1)..second_idx];
if any_node_observes_buffer(between, first_buf) {
break;
}
match &body[second_idx] {
Node::Store {
buffer: second_buf,
index: second_idx_expr,
..
} if second_buf == first_buf
&& expr_structurally_eq(first_idx_expr, second_idx_expr) =>
{
keep[first_idx] = false;
*changed = true;
break;
}
node if node_observes_buffer(node, first_buf) => {
break;
}
_ => {
}
}
}
}
body.into_iter()
.zip(keep)
.filter_map(|(node, alive)| alive.then_some(node))
.collect()
}
fn any_node_observes_buffer(nodes: &[Node], buffer: &Ident) -> bool {
nodes.iter().any(|n| node_observes_buffer(n, buffer))
}
fn node_observes_buffer(node: &Node, buffer: &Ident) -> bool {
match node {
Node::Store {
buffer: other,
index,
value,
} => {
(other == buffer).then(|| false).unwrap_or_else(|| {
expr_touches_buffer(index, buffer) || expr_touches_buffer(value, buffer)
})
}
Node::Let { value, .. } | Node::Assign { value, .. } => expr_touches_buffer(value, buffer),
Node::If {
cond,
then,
otherwise,
} => {
expr_touches_buffer(cond, buffer)
|| any_node_observes_buffer(then, buffer)
|| any_node_observes_buffer(otherwise, buffer)
}
Node::Loop { from, to, body, .. } => {
expr_touches_buffer(from, buffer)
|| expr_touches_buffer(to, buffer)
|| any_node_observes_buffer(body, buffer)
}
Node::Block(body) => any_node_observes_buffer(body, buffer),
Node::Region { body, .. } => any_node_observes_buffer(body.as_ref(), buffer),
Node::Barrier { .. } => true,
Node::AsyncLoad {
source,
destination,
offset,
size,
..
}
| Node::AsyncStore {
source,
destination,
offset,
size,
..
} => {
source == buffer
|| destination == buffer
|| expr_touches_buffer(offset, buffer)
|| expr_touches_buffer(size, buffer)
}
Node::AsyncWait { .. } => true,
Node::IndirectDispatch { count_buffer, .. } => count_buffer == buffer,
Node::Trap { address, .. } => expr_touches_buffer(address, buffer),
Node::Resume { .. } => true,
Node::Return | Node::Opaque(_) => true,
}
}
fn expr_touches_buffer(expr: &Expr, buffer: &Ident) -> bool {
match expr {
Expr::Load {
buffer: other,
index,
} => other == buffer || expr_touches_buffer(index, buffer),
Expr::BufLen { buffer: other } => other == buffer,
Expr::Atomic {
buffer: other,
index,
expected,
value,
op,
..
} => {
other == buffer
|| expr_touches_buffer(index, buffer)
|| matches!(
op,
AtomicOp::CompareExchange | AtomicOp::CompareExchangeWeak
)
|| expected
.as_deref()
.is_some_and(|e| expr_touches_buffer(e, buffer))
|| expr_touches_buffer(value, buffer)
}
Expr::BinOp { left, right, .. } => {
expr_touches_buffer(left, buffer) || expr_touches_buffer(right, buffer)
}
Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
expr_touches_buffer(operand, buffer)
}
Expr::Fma { a, b, c } => {
expr_touches_buffer(a, buffer)
|| expr_touches_buffer(b, buffer)
|| expr_touches_buffer(c, buffer)
}
Expr::Select {
cond,
true_val,
false_val,
} => {
expr_touches_buffer(cond, buffer)
|| expr_touches_buffer(true_val, buffer)
|| expr_touches_buffer(false_val, buffer)
}
Expr::Call { args, .. } => args.iter().any(|a| expr_touches_buffer(a, buffer)),
Expr::SubgroupShuffle { value, .. } | Expr::SubgroupAdd { value } => {
expr_touches_buffer(value, buffer)
}
Expr::SubgroupBallot { cond } => expr_touches_buffer(cond, buffer),
Expr::Opaque(_) => true,
Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_)
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::SubgroupLocalId
| Expr::SubgroupSize => false,
}
}
fn expr_structurally_eq(left: &Expr, right: &Expr) -> bool {
left == right
}
fn has_redundant_store_pair(node: &Node) -> bool {
let body: &[Node] = match node {
Node::If {
then, otherwise, ..
} => {
return contains_buffer_pair(then) || contains_buffer_pair(otherwise);
}
Node::Loop { body, .. } | Node::Block(body) => body,
Node::Region { body, .. } => body.as_ref(),
_ => return false,
};
contains_buffer_pair(body)
}
fn contains_buffer_pair(body: &[Node]) -> bool {
let mut seen: rustc_hash::FxHashSet<&Ident> = rustc_hash::FxHashSet::default();
for n in body {
if let Node::Store { buffer, .. } = n {
if !seen.insert(buffer) {
return true;
}
}
}
false
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Ident, Node};
fn buf(name: &str) -> BufferDecl {
BufferDecl::storage(name, 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
}
fn program_with_entry(entry: Vec<Node>) -> Program {
Program::wrapped(vec![buf("buf"), buf("other")], [1, 1, 1], entry)
}
fn count_stores(node: &Node) -> usize {
match node {
Node::Store { .. } => 1,
Node::If {
then, otherwise, ..
} => {
then.iter().map(count_stores).sum::<usize>()
+ otherwise.iter().map(count_stores).sum::<usize>()
}
Node::Loop { body, .. } | Node::Block(body) => body.iter().map(count_stores).sum(),
Node::Region { body, .. } => body.iter().map(count_stores).sum(),
_ => 0,
}
}
#[test]
fn drops_first_of_two_back_to_back_stores_to_same_index() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(result.changed);
let total: usize = result.program.entry().iter().map(count_stores).sum();
assert_eq!(total, 1, "first dead store must be dropped");
}
#[test]
fn keeps_both_stores_to_different_indices() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::store("buf", Expr::u32(1), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(!result.changed);
let total: usize = result.program.entry().iter().map(count_stores).sum();
assert_eq!(total, 2);
}
#[test]
fn keeps_both_stores_to_different_buffers() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::store("other", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(!result.changed);
}
#[test]
fn keeps_first_store_when_intervening_load_observes_it() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::let_bind("x", Expr::load("buf", Expr::u32(0))),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(
!result.changed,
"must not drop a store whose value is observably read before the overwrite"
);
}
#[test]
fn drops_first_when_intervening_let_reads_a_different_buffer() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::let_bind("x", Expr::load("other", Expr::u32(0))),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(result.changed);
let total: usize = result.program.entry().iter().map(count_stores).sum();
assert_eq!(total, 1, "only the overwritten store is removed");
}
#[test]
fn keeps_first_store_when_barrier_separates_it() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::barrier(),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(
!result.changed,
"Barrier between stores must keep the first one alive"
);
}
#[test]
fn keeps_first_store_when_atomic_read_intervenes() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::let_bind(
"x",
Expr::Atomic {
op: AtomicOp::Exchange,
buffer: Ident::from("buf"),
index: Box::new(Expr::u32(0)),
expected: None,
value: Box::new(Expr::u32(0)),
ordering: crate::ir::MemoryOrdering::Relaxed,
},
),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(!result.changed);
}
#[test]
fn drops_dead_store_inside_if_branch() {
let entry = vec![Node::if_then(
Expr::var("c"),
vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
],
)];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(result.changed);
let total: usize = result.program.entry().iter().map(count_stores).sum();
assert_eq!(total, 1);
}
#[test]
fn keeps_stores_separated_by_nested_if() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::if_then(Expr::var("c"), vec![Node::Return]),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(
!result.changed,
"nested If between stores is opaque under conservative DSE — keep the first"
);
}
#[test]
fn drops_chain_of_three_redundant_stores() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
Node::store("buf", Expr::u32(0), Expr::u32(3)),
];
let program = program_with_entry(entry);
let result = DeadStoreElim::transform(program);
assert!(result.changed);
let total: usize = result.program.entry().iter().map(count_stores).sum();
assert_eq!(total, 1, "only the last store survives");
}
#[test]
fn analyze_skips_program_with_no_redundant_pair() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::store("other", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
assert_eq!(DeadStoreElim::analyze(&program), PassAnalysis::SKIP);
}
#[test]
fn analyze_runs_when_redundant_pair_present() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(1)),
Node::store("buf", Expr::u32(0), Expr::u32(2)),
];
let program = program_with_entry(entry);
assert_eq!(DeadStoreElim::analyze(&program), PassAnalysis::RUN);
}
}