use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use thiserror::Error;
use vyre_spec::AlgebraicLaw;
#[must_use]
pub fn source() -> Option<&'static str> {
crate::transform::compiler::shader_provider::source("visitor_walk")
}
pub const VISIT_STACK_EMPTY: u32 = u32::MAX;
#[must_use]
pub fn visit_step_program(
child_offsets: &str,
children: &str,
stack: &str,
post_order: &str,
post_count: &str,
) -> Program {
let body = vec![
Node::let_bind("top", Expr::load(stack, Expr::u32(0))),
Node::if_then(
Expr::ne(Expr::var("top"), Expr::u32(0)),
vec![
Node::let_bind("node", Expr::load(stack, Expr::var("top"))),
Node::let_bind("new_top", Expr::sub(Expr::var("top"), Expr::u32(1))),
Node::store(stack, Expr::u32(0), Expr::var("new_top")),
Node::let_bind(
"pc",
Expr::atomic_add(post_count, Expr::u32(0), Expr::u32(1)),
),
Node::store(post_order, Expr::var("pc"), Expr::var("node")),
Node::let_bind("cs", Expr::load(child_offsets, Expr::var("node"))),
Node::let_bind(
"ce",
Expr::load(child_offsets, Expr::add(Expr::var("node"), Expr::u32(1))),
),
Node::loop_for(
"i",
Expr::var("cs"),
Expr::var("ce"),
vec![
Node::let_bind("c", Expr::load(children, Expr::var("i"))),
Node::let_bind(
"t2",
Expr::add(Expr::load(stack, Expr::u32(0)), Expr::u32(1)),
),
Node::store(stack, Expr::u32(0), Expr::var("t2")),
Node::store(stack, Expr::var("t2"), Expr::var("c")),
],
),
],
),
];
Program::wrapped(
vec![
BufferDecl::storage(child_offsets, 0, BufferAccess::ReadOnly, DataType::U32),
BufferDecl::storage(children, 1, BufferAccess::ReadOnly, DataType::U32),
BufferDecl::storage(stack, 2, BufferAccess::ReadWrite, DataType::U32),
BufferDecl::storage(post_order, 3, BufferAccess::ReadWrite, DataType::U32),
BufferDecl::storage(post_count, 4, BufferAccess::ReadWrite, DataType::U32)
.with_count(1),
],
[1, 1, 1],
body,
)
}
impl VisitorWalkOp {}
#[must_use]
pub fn index(value: u32) -> Result<usize, VisitorWalkError> {
usize::try_from(value).map_err(|_| VisitorWalkError::IndexOverflow)
}
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
#[must_use]
pub fn postorder(
root: u32,
child_offsets: &[u32],
children: &[u32],
max_stack: usize,
) -> Result<Vec<u32>, VisitorWalkError> {
let node_count = child_offsets
.len()
.checked_sub(1)
.ok_or(VisitorWalkError::EmptyOffsets)?;
let root_index = index(root)?;
if root_index >= node_count {
return Err(VisitorWalkError::InvalidRoot { root, node_count });
}
validate_tree(node_count, child_offsets, children)?;
let mut seen = vec![false; node_count];
let mut sequence = Vec::with_capacity(node_count);
let mut stack = Vec::with_capacity(max_stack.min(node_count).saturating_add(1));
stack.push((root, false));
while let Some((node, expanded)) = stack.pop() {
let node_index = index(node)?;
if expanded {
sequence.push(node);
continue;
}
if seen[node_index] {
return Err(VisitorWalkError::Cycle { node });
}
seen[node_index] = true;
if stack.len().saturating_add(1) > max_stack {
return Err(VisitorWalkError::StackOverflow { max_stack });
}
stack.push((node, true));
let start = index(child_offsets[node_index])?;
let end = index(child_offsets[node_index + 1])?;
for &child in children[start..end].iter().rev() {
if stack.len().saturating_add(1) > max_stack {
return Err(VisitorWalkError::StackOverflow { max_stack });
}
stack.push((child, false));
}
}
Ok(sequence)
}
#[must_use]
pub fn validate_tree(
node_count: usize,
offsets: &[u32],
children: &[u32],
) -> Result<(), VisitorWalkError> {
let mut previous = 0usize;
for &offset in offsets {
let current = index(offset)?;
if current < previous || current > children.len() {
return Err(VisitorWalkError::InvalidOffset);
}
previous = current;
}
for &child in children {
if index(child)? >= node_count {
return Err(VisitorWalkError::InvalidChild { child, node_count });
}
}
Ok(())
}
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum VisitorWalkError {
#[error(
"VisitorEmptyOffsets: child_offsets must include node_count + 1 entries. Fix: emit a valid tree CSR table."
)]
EmptyOffsets,
#[error(
"VisitorInvalidRoot: root {root} outside node_count {node_count}. Fix: pass a valid AST root."
)]
InvalidRoot {
root: u32,
node_count: usize,
},
#[error(
"VisitorInvalidOffset: child offsets must be monotone and within children. Fix: rebuild child_offsets."
)]
InvalidOffset,
#[error("VisitorIndexOverflow: node id cannot fit usize. Fix: split the AST before dispatch.")]
IndexOverflow,
#[error(
"VisitorInvalidChild: child {child} outside node_count {node_count}. Fix: validate AST child references."
)]
InvalidChild {
child: u32,
node_count: usize,
},
#[error(
"VisitorCycle: node {node} was reached twice. Fix: pass a tree or DAG-expanded AST, not a cyclic graph."
)]
Cycle {
node: u32,
},
#[error(
"VisitorStackOverflow: stack exceeded {max_stack} entries. Fix: increase workgroup visitor stack or split the AST."
)]
StackOverflow {
max_stack: usize,
},
}
#[derive(Debug, Default, Clone, Copy)]
pub struct VisitorWalkOp;
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];
#[cfg(test)]
mod ir_program_tests {
use super::*;
#[test]
fn visit_step_program_validates() {
let prog = visit_step_program("co", "c", "stack", "po", "pc");
let errors = crate::validate::validate::validate(&prog);
assert!(errors.is_empty(), "visitor IR must validate: {errors:?}");
}
#[test]
fn visit_step_program_wire_round_trips() {
let prog = visit_step_program("co", "c", "stack", "po", "pc");
let bytes = prog
.to_wire()
.expect("Fix: serialize; restore this invariant before continuing.");
let decoded = Program::from_wire(&bytes)
.expect("Fix: decode; restore this invariant before continuing.");
assert_eq!(decoded.buffers().len(), 5);
}
#[test]
fn visit_stack_empty_sentinel_is_u32_max() {
assert_eq!(VISIT_STACK_EMPTY, u32::MAX);
}
}