use crate::ir_inner::model::expr::Expr;
use crate::ir_inner::model::node::Node;
use crate::ir_inner::model::program::Program;
/// Compact set of effect flags, stored as individual bits of a `u32`.
///
/// Sets are merged with `|` / `|=` and queried with [`ProgramEffects::contains`].
/// The derived `Default` yields the same value as [`ProgramEffects::empty`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ProgramEffects(u32);

impl ProgramEffects {
    /// Bit 0: recorded by the walker in this file for `Node::Store` and `Node::AsyncStore`.
    pub const BUFFER_WRITE: Self = Self(0x01);
    /// Bit 1: recorded for `Expr::Atomic`.
    pub const ATOMIC: Self = Self(0x02);
    /// Bit 2: host I/O going by the name; the walker in this file never sets it.
    pub const HOST_IO: Self = Self(0x04);
    /// Bit 3: recorded for `Node::IndirectDispatch`.
    pub const GPU_DISPATCH: Self = Self(0x08);
    /// Bit 4: recorded for `Node::Barrier`.
    pub const BARRIER: Self = Self(0x10);
    /// Bit 5: recorded for `Node::AsyncLoad`.
    pub const ASYNC_LOAD: Self = Self(0x20);
    /// Bit 6: recorded for `Node::Trap`.
    pub const TRAP: Self = Self(0x40);

    /// Returns the set with no flags raised.
    pub const fn empty() -> Self {
        Self(0)
    }

    /// Returns `true` when every flag raised in `other` is also raised in `self`.
    ///
    /// The empty set is contained in every set, including itself.
    #[must_use]
    #[inline]
    pub const fn contains(self, other: Self) -> bool {
        // `other` is a subset exactly when it has no bit that `self` lacks.
        (other.0 & !self.0) == 0
    }

    /// Returns the raw bit pattern backing this set.
    #[must_use]
    #[inline]
    pub const fn bits(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}

/// `a | b` is the union of both flag sets.
impl core::ops::BitOr for ProgramEffects {
    type Output = Self;

    #[inline]
    fn bitor(self, rhs: Self) -> Self {
        Self(self.bits() | rhs.bits())
    }
}

/// `a |= b` raises in `a` every flag that is raised in `b`.
impl core::ops::BitOrAssign for ProgramEffects {
    #[inline]
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs;
    }
}
/// Computes the union of all effect flags found in `program`.
///
/// Every node yielded by `program.entry()` is handed to `walk_node`, which
/// descends into nested nodes and expressions. A program in which nothing
/// raises a flag yields [`ProgramEffects::empty`].
#[must_use]
pub fn compute_program_effects(program: &Program) -> ProgramEffects {
    let mut summary = ProgramEffects::empty();
    for top_level in program.entry() {
        walk_node(top_level, &mut summary);
    }
    summary
}
/// Adds to `effects` every flag implied by `node`, by the nodes nested inside
/// it, and by the expressions it carries.
///
/// Flags are only ever raised, never cleared, so the visiting order does not
/// influence the final set.
fn walk_node(node: &Node, effects: &mut ProgramEffects) {
    /// Visits each node of a sequence in order.
    fn walk_each<'a>(nodes: impl IntoIterator<Item = &'a Node>, effects: &mut ProgramEffects) {
        for child in nodes {
            walk_node(child, effects);
        }
    }

    match node {
        // Nothing recorded and nothing visited. Note that opaque nodes are not
        // inspected at all, so whatever they do is not reflected in the result.
        Node::Return | Node::AsyncWait { .. } | Node::Resume { .. } | Node::Opaque(_) => {}

        // A flag only; no sub-expressions are visited.
        Node::Barrier { .. } => *effects |= ProgramEffects::BARRIER,
        Node::IndirectDispatch { .. } => *effects |= ProgramEffects::GPU_DISPATCH,

        // A flag plus the expressions the node carries.
        Node::Store { index, value, .. } => {
            *effects |= ProgramEffects::BUFFER_WRITE;
            walk_expr(index, effects);
            walk_expr(value, effects);
        }
        Node::AsyncLoad { offset, size, .. } => {
            *effects |= ProgramEffects::ASYNC_LOAD;
            walk_expr(offset, effects);
            walk_expr(size, effects);
        }
        Node::AsyncStore { offset, size, .. } => {
            // Reported as a buffer write; no separate async flag is raised here.
            *effects |= ProgramEffects::BUFFER_WRITE;
            walk_expr(offset, effects);
            walk_expr(size, effects);
        }
        Node::Trap { address, .. } => {
            *effects |= ProgramEffects::TRAP;
            walk_expr(address, effects);
        }

        // No flag of their own; only the bound/assigned expression matters.
        Node::Let { value, .. } | Node::Assign { value, .. } => walk_expr(value, effects),

        // Structural nodes: recurse into everything they contain.
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            walk_expr(cond, effects);
            // Both branches are visited regardless of the condition.
            walk_each(then, effects);
            walk_each(otherwise, effects);
        }
        Node::Loop { from, to, body, .. } => {
            walk_expr(from, effects);
            walk_expr(to, effects);
            walk_each(body, effects);
        }
        Node::Block(nodes) => walk_each(nodes, effects),
        Node::Region { body, .. } => walk_each(body.iter(), effects),
    }
}
/// Adds to `effects` every flag implied by `expr` and its sub-expressions.
///
/// `Expr::Atomic` is the only expression that raises a flag itself
/// ([`ProgramEffects::ATOMIC`]); every other variant merely forwards to its
/// operands, if it has any.
fn walk_expr(expr: &Expr, effects: &mut ProgramEffects) {
    match expr {
        // The one effectful expression.
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            *effects |= ProgramEffects::ATOMIC;
            walk_expr(index, effects);
            // `expected` is optional; visit it only when present.
            if let Some(exp) = expected {
                walk_expr(exp, effects);
            }
            walk_expr(value, effects);
        }

        // Three operands.
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            let parts: [&Expr; 3] = [cond, true_val, false_val];
            for part in parts {
                walk_expr(part, effects);
            }
        }
        Expr::Fma { a, b, c } => {
            let terms: [&Expr; 3] = [a, b, c];
            for term in terms {
                walk_expr(term, effects);
            }
        }

        // Two operands.
        Expr::BinOp { left, right, .. } => {
            let sides: [&Expr; 2] = [left, right];
            for side in sides {
                walk_expr(side, effects);
            }
        }
        Expr::SubgroupShuffle { value, lane } => {
            let inputs: [&Expr; 2] = [value, lane];
            for input in inputs {
                walk_expr(input, effects);
            }
        }

        // A variable number of operands. The callee itself is not examined.
        Expr::Call { args, .. } => {
            for argument in args {
                walk_expr(argument, effects);
            }
        }

        // A single operand.
        Expr::Load { index, .. } => walk_expr(index, effects),
        Expr::UnOp { operand, .. } => walk_expr(operand, effects),
        Expr::Cast { value, .. } => walk_expr(value, effects),
        Expr::SubgroupBallot { cond } => walk_expr(cond, effects),
        Expr::SubgroupAdd { value } => walk_expr(value, effects),

        // Leaves: nothing to record, nothing to descend into. Opaque
        // expressions are not inspected.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::BufLen { .. }
        | Expr::Opaque(_) => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr as IrExpr, Node as IrNode, Program};

    /// Builds a program via `Program::wrapped`, always passing `[1, 1, 1]` as
    /// its second argument.
    fn program_with(body: Vec<IrNode>, buffers: Vec<BufferDecl>) -> Program {
        Program::wrapped(buffers, [1, 1, 1], body)
    }

    /// Declares the read-write `U32` storage buffer (binding 0, count 1) that
    /// the tests below share, under the given name.
    fn u32_storage(name: &'static str) -> BufferDecl {
        BufferDecl::storage(name, 0, BufferAccess::ReadWrite, DataType::U32).with_count(1)
    }

    /// A lone `Return` raises no flag.
    #[test]
    fn empty_program_has_no_effects() {
        let program = program_with(vec![IrNode::Return], vec![]);
        let found = compute_program_effects(&program);
        assert_eq!(found, ProgramEffects::empty());
    }

    /// A plain store raises `BUFFER_WRITE` and nothing atomic or barrier-like.
    #[test]
    fn store_records_buffer_write() {
        let write = IrNode::store("out", IrExpr::u32(0), IrExpr::u32(7));
        let program = program_with(vec![write], vec![u32_storage("out")]);
        let found = compute_program_effects(&program);
        assert!(found.contains(ProgramEffects::BUFFER_WRITE));
        assert!(!found.contains(ProgramEffects::ATOMIC));
        assert!(!found.contains(ProgramEffects::BARRIER));
    }

    /// A barrier node raises `BARRIER`.
    #[test]
    fn barrier_records_barrier() {
        let program = program_with(vec![IrNode::barrier(), IrNode::Return], vec![]);
        let found = compute_program_effects(&program);
        assert!(found.contains(ProgramEffects::BARRIER));
    }

    /// An atomic expression used as a stored value raises both `ATOMIC` and
    /// `BUFFER_WRITE`.
    #[test]
    fn atomic_records_atomic() {
        let atomic_value = IrExpr::atomic_add("out", IrExpr::u32(0), IrExpr::u32(1));
        let write = IrNode::store("out", IrExpr::u32(0), atomic_value);
        let program = program_with(vec![write], vec![u32_storage("out")]);
        let found = compute_program_effects(&program);
        assert!(found.contains(ProgramEffects::ATOMIC));
        assert!(found.contains(ProgramEffects::BUFFER_WRITE));
    }

    /// Effects from both branches of an `If` are collected.
    #[test]
    fn nested_in_if_branches_collects_effects() {
        let branch = IrNode::If {
            cond: IrExpr::bool(true),
            then: vec![IrNode::barrier()],
            otherwise: vec![IrNode::store("o", IrExpr::u32(0), IrExpr::u32(1))],
        };
        let program = program_with(vec![branch], vec![u32_storage("o")]);
        let found = compute_program_effects(&program);
        assert!(found.contains(ProgramEffects::BARRIER));
        assert!(found.contains(ProgramEffects::BUFFER_WRITE));
    }

    /// Binding a literal and returning raises no flag.
    #[test]
    fn pure_arithmetic_program_has_no_effects() {
        let body = vec![IrNode::let_bind("x", IrExpr::u32(7)), IrNode::Return];
        let found = compute_program_effects(&program_with(body, vec![]));
        assert_eq!(found, ProgramEffects::empty());
    }

    /// Nodes inside a `Region` body are visited.
    #[test]
    fn region_traversal_descends_into_body() {
        let region = IrNode::Region {
            generator: "test.r".into(),
            source_region: None,
            body: std::sync::Arc::new(vec![IrNode::barrier()]),
        };
        let found = compute_program_effects(&program_with(vec![region], vec![]));
        assert!(found.contains(ProgramEffects::BARRIER));
    }

    /// The resulting set does not depend on the order of the effectful nodes.
    #[test]
    fn effects_form_a_stable_set() {
        let store = || IrNode::store("o", IrExpr::u32(0), IrExpr::u32(1));
        let barrier_first =
            program_with(vec![IrNode::barrier(), store()], vec![u32_storage("o")]);
        let store_first = program_with(vec![store(), IrNode::barrier()], vec![u32_storage("o")]);
        assert_eq!(
            compute_program_effects(&barrier_first),
            compute_program_effects(&store_first)
        );
    }
}