use crate::ir::model::expr::Expr;
use crate::ir::model::node::Node;
use crate::ir::model::program::Program;
use std::collections::HashSet;
/// Visits every [`Node`] reachable from the program's entry block in
/// depth-first pre-order.
///
/// `f` sees each node before any of its children. Inside an `If`, every node
/// of the `then` branch is visited before the `otherwise` branch; `Loop`
/// bodies and `Block` contents are visited in source order.
///
/// The walk keeps its own explicit stack instead of recursing.
#[inline]
pub fn walk_nodes(program: &Program, mut f: impl FnMut(&Node)) {
    let roots = program.entry();
    let mut pending: Vec<&Node> = Vec::with_capacity(roots.len());
    // Siblings go on in reverse so the first one is popped (visited) first.
    pending.extend(roots.iter().rev());

    while let Some(current) = pending.pop() {
        f(current);
        match current {
            Node::If { then, otherwise, .. } => {
                // `otherwise` is pushed first so `then` ends up on top.
                pending.extend(otherwise.iter().rev());
                pending.extend(then.iter().rev());
            }
            Node::Loop { body, .. } => pending.extend(body.iter().rev()),
            Node::Block(inner) => pending.extend(inner.iter().rev()),
            Node::Let { .. }
            | Node::Assign { .. }
            | Node::Store { .. }
            | Node::Return
            | Node::Barrier => {}
        }
    }
}
#[inline]
pub fn walk_exprs(program: &Program, mut f: impl FnMut(&Expr)) {
let mut node_stack = Vec::with_capacity(program.entry().len());
for node in program.entry().iter().rev() {
node_stack.push(node);
}
let mut expr_stack = Vec::with_capacity(program.entry().len().saturating_mul(2));
while let Some(node) = node_stack.pop() {
match node {
Node::Let { value, .. } | Node::Assign { value, .. } => {
expr_stack.push(value);
}
Node::Store { index, value, .. } => {
expr_stack.push(value);
expr_stack.push(index);
}
Node::If {
cond,
then,
otherwise,
} => {
for n in otherwise.iter().rev() {
node_stack.push(n);
}
for n in then.iter().rev() {
node_stack.push(n);
}
expr_stack.push(cond);
}
Node::Loop { from, to, body, .. } => {
for n in body.iter().rev() {
node_stack.push(n);
}
expr_stack.push(to);
expr_stack.push(from);
}
Node::Block(nodes) => {
for n in nodes.iter().rev() {
node_stack.push(n);
}
}
Node::Return | Node::Barrier => {}
}
while let Some(expr) = expr_stack.pop() {
f(expr);
match expr {
Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_) => {}
Expr::Load { index, .. } => expr_stack.push(index),
Expr::BufLen { .. } => {}
Expr::InvocationId { .. } | Expr::WorkgroupId { .. } | Expr::LocalId { .. } => {}
Expr::BinOp { left, right, .. } => {
expr_stack.push(right);
expr_stack.push(left);
}
Expr::Fma { a, b, c, .. } => {
expr_stack.push(c);
expr_stack.push(b);
expr_stack.push(a);
}
Expr::UnOp { operand, .. } => expr_stack.push(operand),
Expr::Call { args, .. } => {
for arg in args.iter().rev() {
expr_stack.push(arg);
}
}
Expr::Select {
cond,
true_val,
false_val,
} => {
expr_stack.push(false_val);
expr_stack.push(true_val);
expr_stack.push(cond);
}
Expr::Cast { value, .. } => expr_stack.push(value),
Expr::Atomic {
index,
expected,
value,
..
} => {
expr_stack.push(value);
if let Some(expected) = expected {
expr_stack.push(expected);
}
expr_stack.push(index);
}
}
}
}
}
/// Mutable counterpart of [`walk_nodes`]. It visits every [`Node`] reachable
/// from the entry block in depth-first pre-order and passes `f` a mutable
/// reference to each one.
///
/// `f` runs on a node *before* that node's children are scheduled. The walk
/// therefore descends into whatever the node contains after `f` returns. For
/// example, if `f` overwrites a node with a `Block`, the walk continues into
/// the new block's contents. Sibling order matches [`walk_nodes`]: the
/// `then` branch comes before `otherwise`, and bodies are visited in source
/// order.
#[inline]
pub fn walk_nodes_mut(program: &mut Program, mut f: impl FnMut(&mut Node)) {
    let root_count = program.entry().len();
    let mut pending: Vec<&mut Node> = Vec::with_capacity(root_count);
    // Reverse so the first root is popped (visited) first.
    pending.extend(program.entry_mut().iter_mut().rev());

    while let Some(current) = pending.pop() {
        f(&mut *current);
        // Children are read only after `f` returns, so any rewrite it made is honored.
        match current {
            Node::If { then, otherwise, .. } => {
                pending.extend(otherwise.iter_mut().rev());
                pending.extend(then.iter_mut().rev());
            }
            Node::Loop { body, .. } => pending.extend(body.iter_mut().rev()),
            Node::Block(inner) => pending.extend(inner.iter_mut().rev()),
            Node::Let { .. }
            | Node::Assign { .. }
            | Node::Store { .. }
            | Node::Return
            | Node::Barrier => {}
        }
    }
}
/// Returns the set of buffer names the program touches anywhere, including
/// inside branches, loop bodies and nested blocks.
///
/// A buffer counts as referenced if it is:
/// - the target of a `Store` statement, or
/// - named by a `Load`, `BufLen`, or `Atomic` expression.
///
/// Each name appears once in the result, however often it is used.
#[must_use]
#[inline]
pub fn referenced_buffers(program: &Program) -> HashSet<String> {
    let mut buffers: HashSet<String> = HashSet::new();

    // Statement-level writes.
    walk_nodes(program, |node| {
        if let Node::Store { buffer, .. } = node {
            buffers.insert(buffer.clone());
        }
    });

    // Expression-level reads, length queries and atomics. The match is left
    // exhaustive on purpose: a new `Expr` variant forces a decision here.
    walk_exprs(program, |expr| match expr {
        Expr::Load { buffer, .. } | Expr::BufLen { buffer } | Expr::Atomic { buffer, .. } => {
            buffers.insert(buffer.to_string());
        }
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::BinOp { .. }
        | Expr::Fma { .. }
        | Expr::UnOp { .. }
        | Expr::Call { .. }
        | Expr::Select { .. }
        | Expr::Cast { .. } => {}
    });

    buffers
}
#[must_use]
#[inline]
pub fn collect_call_op_ids(program: &Program) -> Vec<String> {
let mut op_ids = Vec::with_capacity(program.entry().len());
walk_exprs(program, |expr| {
if let Expr::Call { op_id, .. } = expr {
op_ids.push(op_id.clone());
}
});
op_ids
}