use rustc_hash::FxHashSet;
use crate::ir::{Expr, Ident, Node};
pub(super) fn collect_buffer_targets(
node: &Node,
loads: &mut FxHashSet<Ident>,
stores: &mut FxHashSet<Ident>,
atomics: &mut FxHashSet<Ident>,
) {
match node {
Node::Let { value, .. } | Node::Assign { value, .. } => {
collect_buffer_targets_from_expr(value, loads, atomics);
}
Node::Store {
buffer,
index,
value,
} => {
stores.insert(Ident::from(buffer));
collect_buffer_targets_from_expr(index, loads, atomics);
collect_buffer_targets_from_expr(value, loads, atomics);
}
Node::If {
cond,
then,
otherwise,
} => {
collect_buffer_targets_from_expr(cond, loads, atomics);
for n in then.iter().chain(otherwise.iter()) {
collect_buffer_targets(n, loads, stores, atomics);
}
}
Node::Loop { from, to, body, .. } => {
collect_buffer_targets_from_expr(from, loads, atomics);
collect_buffer_targets_from_expr(to, loads, atomics);
for n in body {
collect_buffer_targets(n, loads, stores, atomics);
}
}
Node::Block(body) => {
for n in body {
collect_buffer_targets(n, loads, stores, atomics);
}
}
Node::Region { body, .. } => {
for n in body.iter() {
collect_buffer_targets(n, loads, stores, atomics);
}
}
Node::AllReduce { buffer, .. } | Node::Broadcast { buffer, .. } => {
loads.insert(buffer.clone());
stores.insert(buffer.clone());
}
Node::AllGather { input, output, .. } | Node::ReduceScatter { input, output, .. } => {
loads.insert(input.clone());
stores.insert(output.clone());
}
Node::IndirectDispatch { .. }
| Node::Return
| Node::Barrier { .. }
| Node::AsyncLoad { .. }
| Node::AsyncStore { .. }
| Node::AsyncWait { .. }
| Node::Trap { .. }
| Node::Resume { .. }
| Node::Opaque(_) => {}
}
}
fn collect_buffer_targets_from_expr(
expr: &Expr,
loads: &mut FxHashSet<Ident>,
atomics: &mut FxHashSet<Ident>,
) {
match expr {
Expr::Load { buffer, index } => {
loads.insert(Ident::from(buffer));
collect_buffer_targets_from_expr(index, loads, atomics);
}
Expr::Atomic {
buffer,
index,
expected,
value,
..
} => {
atomics.insert(Ident::from(buffer));
collect_buffer_targets_from_expr(index, loads, atomics);
if let Some(expected) = expected {
collect_buffer_targets_from_expr(expected, loads, atomics);
}
collect_buffer_targets_from_expr(value, loads, atomics);
}
Expr::BinOp { left, right, .. } => {
collect_buffer_targets_from_expr(left, loads, atomics);
collect_buffer_targets_from_expr(right, loads, atomics);
}
Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
collect_buffer_targets_from_expr(operand, loads, atomics);
}
Expr::Fma { a, b, c } => {
collect_buffer_targets_from_expr(a, loads, atomics);
collect_buffer_targets_from_expr(b, loads, atomics);
collect_buffer_targets_from_expr(c, loads, atomics);
}
Expr::Call { args, .. } => {
for arg in args {
collect_buffer_targets_from_expr(arg, loads, atomics);
}
}
Expr::Select {
cond,
true_val,
false_val,
} => {
collect_buffer_targets_from_expr(cond, loads, atomics);
collect_buffer_targets_from_expr(true_val, loads, atomics);
collect_buffer_targets_from_expr(false_val, loads, atomics);
}
Expr::SubgroupBallot { cond } => collect_buffer_targets_from_expr(cond, loads, atomics),
Expr::SubgroupShuffle { value, lane } => {
collect_buffer_targets_from_expr(value, loads, atomics);
collect_buffer_targets_from_expr(lane, loads, atomics);
}
Expr::SubgroupAdd { value } => collect_buffer_targets_from_expr(value, loads, atomics),
_ => {}
}
}