use crate::ir::model::expr::Expr;
use crate::ir::model::node::Node;
use crate::ir::model::program::Program;
use crate::ir::model::types::AtomicOp;
use std::collections::{HashMap, HashSet};
/// Aggregated record of how a program's entry nodes use each buffer,
/// keyed by buffer name.
#[derive(Default, Debug)]
pub(crate) struct HelperUsage {
    /// One usage entry per buffer name encountered while walking the program.
    buffers: HashMap<String, BufferHelperUsage>,
}
/// Usage flags gathered for a single buffer.
#[derive(Default, Debug)]
pub(crate) struct BufferHelperUsage {
    /// Set once any `Expr::Load` reads from this buffer.
    pub(crate) load: bool,
    /// Set once any `Node::Store` writes to this buffer.
    pub(crate) store: bool,
    /// Distinct atomic operations applied to this buffer through `Expr::Atomic`.
    pub(crate) atomic_ops: HashSet<AtomicOp>,
}
impl HelperUsage {
pub(crate) fn collect(program: &Program) -> Self {
let mut usage = Self::default();
for node in program.entry() {
usage.record_node(node);
}
usage
}
pub(crate) fn buffer(&self, name: &str) -> BufferHelperUsageView<'_> {
match self.buffers.get(name) {
Some(usage) => BufferHelperUsageView {
load: usage.load,
store: usage.store,
atomic_ops: Some(&usage.atomic_ops),
},
None => BufferHelperUsageView::default(),
}
}
fn record_node(&mut self, node: &Node) {
match node {
Node::Let { value, .. } | Node::Assign { value, .. } => self.record_expr(value),
Node::Store {
buffer,
index,
value,
} => {
self.buffers.entry(buffer.clone()).or_default().store = true;
self.record_expr(index);
self.record_expr(value);
}
Node::If {
cond,
then,
otherwise,
} => {
self.record_expr(cond);
for node in then.iter().chain(otherwise) {
self.record_node(node);
}
}
Node::Loop { from, to, body, .. } => {
self.record_expr(from);
self.record_expr(to);
for node in body {
self.record_node(node);
}
}
Node::Block(nodes) => {
for node in nodes {
self.record_node(node);
}
}
Node::Return | Node::Barrier => {}
}
}
fn record_expr(&mut self, expr: &Expr) {
match expr {
Expr::Load { buffer, index } => {
self.buffers.entry(buffer.to_string()).or_default().load = true;
self.record_expr(index);
}
Expr::Atomic {
op,
buffer,
index,
expected,
value,
} => {
self.buffers
.entry(buffer.to_string())
.or_default()
.atomic_ops
.insert(op.clone());
self.record_expr(index);
if let Some(expected) = expected {
self.record_expr(expected);
}
self.record_expr(value);
}
Expr::BinOp { left, right, .. } => {
self.record_expr(left);
self.record_expr(right);
}
Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
self.record_expr(operand);
}
Expr::Call { args, .. } => {
for arg in args {
self.record_expr(arg);
}
}
Expr::Select {
cond,
true_val,
false_val,
} => {
self.record_expr(cond);
self.record_expr(true_val);
self.record_expr(false_val);
}
Expr::Fma { a, b, c } => {
self.record_expr(a);
self.record_expr(b);
self.record_expr(c);
}
Expr::BufLen { .. }
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_) => {}
}
}
}
/// Cheap, copyable snapshot of one buffer's recorded usage.
///
/// Produced by `HelperUsage::buffer`. `atomic_ops` is `None` when the buffer
/// was never referenced, and otherwise borrows the recorded (possibly empty)
/// set of atomic operations.
// `Debug` is derived for consistency with the other usage structs in this
// module; `AtomicOp: Debug` already holds because `BufferHelperUsage`
// derives `Debug` over a `HashSet<AtomicOp>`.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct BufferHelperUsageView<'a> {
    /// Whether the buffer is read by at least one load.
    pub(crate) load: bool,
    /// Whether the buffer is written by at least one store.
    pub(crate) store: bool,
    /// Atomic operations applied to the buffer, if it was referenced at all.
    pub(crate) atomic_ops: Option<&'a HashSet<AtomicOp>>,
}
impl BufferHelperUsageView<'_> {
    /// Returns `true` when at least one atomic operation was recorded for
    /// this buffer; an unreferenced buffer or an empty set yields `false`.
    pub(crate) fn uses_atomics(self) -> bool {
        matches!(self.atomic_ops, Some(ops) if !ops.is_empty())
    }
}