use crate::ops::AlgebraicLaw;
use crate::ir::transform::compiler::{U32X4_INPUTS, U32_OUTPUTS};
use crate::lower::wgsl::compiler::wgsl_backend;
use crate::ops::{IntrinsicDescriptor, OpSpec};
use thiserror::Error;
/// Returns the WGSL shader text for the visitor-walk op.
///
/// The shader is embedded at compile time from
/// `lower/wgsl/compiler/visitor_walk.wgsl`, so no file I/O happens at runtime.
#[must_use]
pub const fn source() -> &'static str {
    const VISITOR_WALK_WGSL: &str =
        include_str!("../../../lower/wgsl/compiler/visitor_walk.wgsl");
    VISITOR_WALK_WGSL
}
impl VisitorWalkOp {
/// Operation spec that registers `compiler_primitives.visitor_walk` as an intrinsic.
///
/// Arguments, in the order `OpSpec::intrinsic` takes them:
/// - the shared `U32X4_INPUTS` / `U32_OUTPUTS` signatures (the names suggest four
///   `u32` inputs and one `u32` output — NOTE(review): confirm in `ir::transform::compiler`);
/// - the [`LAWS`] table declared in this file;
/// - the WGSL lowering backend `wgsl_backend`;
/// - an `IntrinsicDescriptor` naming `compiler_primitives_visitor_walk`, the string
///   `workgroup_visitor` (presumably the shader entry point — TODO confirm against
///   `IntrinsicDescriptor::new`), and the `structured_intrinsic_cpu` CPU hook.
pub const SPEC: OpSpec = OpSpec::intrinsic(
"compiler_primitives.visitor_walk",
U32X4_INPUTS,
U32_OUTPUTS,
LAWS,
wgsl_backend,
IntrinsicDescriptor::new("compiler_primitives_visitor_walk", "workgroup_visitor", crate::ops::cpu_op::structured_intrinsic_cpu),
);
}
/// Widens a `u32` node id or CSR offset into a `usize` index.
///
/// # Errors
///
/// Returns [`VisitorWalkError::IndexOverflow`] if the value does not fit in
/// `usize` on the current target.
pub fn index(value: u32) -> Result<usize, VisitorWalkError> {
    usize::try_from(value)
        .ok()
        .ok_or(VisitorWalkError::IndexOverflow)
}
/// Algebraic laws declared for this op: a single `Bounded` law spanning the
/// whole `u32` range, `0..=u32::MAX`.
pub const LAWS: &[AlgebraicLaw] = &[
    AlgebraicLaw::Bounded { hi: u32::MAX, lo: 0 },
];
/// Walks a CSR-encoded tree from `root` and returns its nodes in post-order.
///
/// Node `i`'s children are `children[child_offsets[i]..child_offsets[i + 1]]`, so
/// the node count is `child_offsets.len() - 1`. Every child is emitted before its
/// parent, and siblings are visited in the order they appear in `children`.
///
/// The traversal uses an explicit stack rather than recursion. The initial root
/// entry is placed unconditionally; every later push fails if the stack already
/// holds `max_stack` entries.
///
/// # Errors
///
/// - [`VisitorWalkError::EmptyOffsets`] if `child_offsets` is empty.
/// - [`VisitorWalkError::InvalidRoot`] if `root` is not below the node count.
/// - Any error from [`validate_tree`] for non-monotone offsets, offsets past the
///   end of `children`, or out-of-range child ids.
/// - [`VisitorWalkError::Cycle`] if a node is reached a second time. This covers
///   real cycles and also nodes shared by several parents.
/// - [`VisitorWalkError::StackOverflow`] if the stack limit is reached.
/// - [`VisitorWalkError::IndexOverflow`] if an id or offset does not fit `usize`.
pub fn postorder(
    root: u32,
    child_offsets: &[u32],
    children: &[u32],
    max_stack: usize,
) -> Result<Vec<u32>, VisitorWalkError> {
    let node_count = child_offsets
        .len()
        .checked_sub(1)
        .ok_or(VisitorWalkError::EmptyOffsets)?;
    let root_index = index(root)?;
    if root_index >= node_count {
        return Err(VisitorWalkError::InvalidRoot { root, node_count });
    }
    // After this, every child id is < node_count and every offset range is a
    // valid, non-decreasing slice of `children`, so the indexing below is in bounds.
    validate_tree(node_count, child_offsets, children)?;

    let mut seen = vec![false; node_count];
    // `seen` guarantees each node is emitted at most once, so this never regrows.
    let mut sequence = Vec::with_capacity(node_count);
    // Entries are `(node, expanded)`: `false` = "expand my children next",
    // `true` = "all children done, emit me".
    let mut stack = vec![(root, false)];

    while let Some((node, expanded)) = stack.pop() {
        let node_index = index(node)?;
        if expanded {
            sequence.push(node);
            continue;
        }
        if seen[node_index] {
            return Err(VisitorWalkError::Cycle { node });
        }
        seen[node_index] = true;

        if stack.len() >= max_stack {
            return Err(VisitorWalkError::StackOverflow { max_stack });
        }
        stack.push((node, true));

        let start = index(child_offsets[node_index])?;
        let end = index(child_offsets[node_index + 1])?;
        // Push in reverse so the first child is popped (and emitted) first.
        for &child in children[start..end].iter().rev() {
            if stack.len() >= max_stack {
                return Err(VisitorWalkError::StackOverflow { max_stack });
            }
            stack.push((child, false));
        }
    }
    Ok(sequence)
}
/// Validates a CSR-encoded tree before it is walked.
///
/// Node `i`'s children are `children[offsets[i]..offsets[i + 1]]`, so `offsets`
/// must hold at least `node_count + 1` entries. Those entries must not decrease
/// and must not exceed `children.len()`, and every child id must be below
/// `node_count`.
///
/// # Errors
///
/// - [`VisitorWalkError::EmptyOffsets`] if `offsets` has fewer than
///   `node_count + 1` entries. Without this check, a walker reading
///   `offsets[node + 1]` for the last node would index out of bounds.
/// - [`VisitorWalkError::InvalidOffset`] if the offsets decrease or point past the
///   end of `children`.
/// - [`VisitorWalkError::InvalidChild`] if a child id is not below `node_count`.
/// - [`VisitorWalkError::IndexOverflow`] if an offset or child id does not fit `usize`.
pub fn validate_tree(
    node_count: usize,
    offsets: &[u32],
    children: &[u32],
) -> Result<(), VisitorWalkError> {
    // Equivalent to `offsets.len() < node_count + 1`, but cannot overflow when
    // `node_count == usize::MAX`.
    if offsets.len() <= node_count {
        return Err(VisitorWalkError::EmptyOffsets);
    }
    let mut previous = 0usize;
    for &offset in offsets {
        let current = index(offset)?;
        if current < previous || current > children.len() {
            return Err(VisitorWalkError::InvalidOffset);
        }
        previous = current;
    }
    for &child in children {
        if index(child)? >= node_count {
            return Err(VisitorWalkError::InvalidChild { child, node_count });
        }
    }
    Ok(())
}
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum VisitorWalkError {
#[error("VisitorEmptyOffsets: child_offsets must include node_count + 1 entries. Fix: emit a valid tree CSR table.")]
EmptyOffsets,
#[error("VisitorInvalidRoot: root {root} outside node_count {node_count}. Fix: pass a valid AST root.")]
InvalidRoot {
root: u32,
node_count: usize,
},
#[error("VisitorInvalidOffset: child offsets must be monotone and within children. Fix: rebuild child_offsets.")]
InvalidOffset,
#[error("VisitorIndexOverflow: node id cannot fit usize. Fix: split the AST before dispatch.")]
IndexOverflow,
#[error("VisitorInvalidChild: child {child} outside node_count {node_count}. Fix: validate AST child references.")]
InvalidChild {
child: u32,
node_count: usize,
},
#[error("VisitorCycle: node {node} was reached twice. Fix: pass a tree or DAG-expanded AST, not a cyclic graph.")]
Cycle {
node: u32,
},
#[error("VisitorStackOverflow: stack exceeded {max_stack} entries. Fix: increase workgroup visitor stack or split the AST.")]
StackOverflow {
max_stack: usize,
},
}
/// Stateless marker type for the `compiler_primitives.visitor_walk` op.
#[derive(Clone, Copy, Debug, Default)]
pub struct VisitorWalkOp;
/// Workgroup dimensions `[x, y, z]`: 64 invocations along x, one along y and z.
pub const WORKGROUP_SIZE: [u32; 3] = [64_u32, 1_u32, 1_u32];