vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Visitor walk — a bounded post-order tree traversal primitive.
//!
//! Compilers spend most of their time walking trees: AST visitors, scope-tree
//! traversals, dominator walks.  `visitor_walk` gives vyre a first-class
//! primitive for that coordination.  It takes a root node and a CSR child
//! table, then emits a post-order sequence using an explicit stack that lives
//! in workgroup-local SRAM.  The WGSL kernel uses the same stack bound and
//! cycle-detection logic, so conform can prove the visit order is identical on
//! CPU and GPU.

use crate::ops::AlgebraicLaw;
use crate::ir::transform::compiler::{U32X4_INPUTS, U32_OUTPUTS};
use crate::lower::wgsl::compiler::wgsl_backend;
use crate::ops::{IntrinsicDescriptor, OpSpec};
use thiserror::Error;

/// Return the portable WGSL kernel text that backs [`VisitorWalkOp`].
///
/// The shader file is embedded into the binary at compile time via
/// `include_str!`, so calling this performs no filesystem access.
#[must_use]
pub const fn source() -> &'static str {
    const WGSL: &str = include_str!("../../../lower/wgsl/compiler/visitor_walk.wgsl");
    WGSL
}

impl VisitorWalkOp {
    /// Declarative registry entry for the visitor walk.
    ///
    /// Registers the op id `compiler_primitives.visitor_walk` with the
    /// `U32X4_INPUTS` input signature, the `U32_OUTPUTS` output signature,
    /// the [`LAWS`] table, and `wgsl_backend` for lowering. The attached
    /// intrinsic descriptor carries the strings
    /// `compiler_primitives_visitor_walk` and `workgroup_visitor` and is paired
    /// with `structured_intrinsic_cpu` (presumably the host-side reference
    /// path — confirm against `crate::ops::cpu_op`).
    pub const SPEC: OpSpec = OpSpec::intrinsic(
        "compiler_primitives.visitor_walk",
        U32X4_INPUTS,
        U32_OUTPUTS,
        LAWS,
        wgsl_backend,
        IntrinsicDescriptor::new(
            "compiler_primitives_visitor_walk",
            "workgroup_visitor",
            crate::ops::cpu_op::structured_intrinsic_cpu,
        ),
    );
}

/// Convert a `u32` node id or offset into a host-side `usize` index.
///
/// The conversion is checked rather than an `as` cast, so it can never
/// silently truncate.
///
/// # Errors
///
/// Returns [`VisitorWalkError::IndexOverflow`] when `value` is not
/// representable as `usize` on the current target.
pub fn index(value: u32) -> Result<usize, VisitorWalkError> {
    match usize::try_from(value) {
        Ok(converted) => Ok(converted),
        Err(_) => Err(VisitorWalkError::IndexOverflow),
    }
}

/// Algebraic laws attached to [`VisitorWalkOp::SPEC`].
///
/// This is a single `Bounded` law whose range spans `0..=u32::MAX`, i.e. the
/// whole `u32` domain.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
    hi: u32::MAX,
    lo: 0,
}];

/// Produce the post-order visit sequence of the subtree rooted at `root`.
///
/// The tree is a CSR child table: the children of node `n` are
/// `children[child_offsets[n]..child_offsets[n + 1]]`, and the node count is
/// `child_offsets.len() - 1`. Children are visited in table order, and every
/// node is emitted after all of its descendants. On success, the last element
/// of the result is therefore `root`.
///
/// The walk uses an explicit stack of `(node, expanded)` entries holding at
/// most `max_stack` entries. Expanding a node needs one slot for its emit
/// marker plus one slot per child.
///
/// # Errors
///
/// - [`VisitorWalkError::EmptyOffsets`] when `child_offsets` is empty.
/// - [`VisitorWalkError::InvalidRoot`] when `root >= node_count`.
/// - Any error from [`validate_tree`] when the CSR table is malformed.
/// - [`VisitorWalkError::Cycle`] when a node is reached more than once. This
///   covers both true cycles and nodes shared by several parents.
/// - [`VisitorWalkError::StackOverflow`] when expanding a node would grow the
///   stack beyond `max_stack` entries (so `max_stack == 0` always fails).
/// - [`VisitorWalkError::IndexOverflow`] when a `u32` id does not fit `usize`.
pub fn postorder(
    root: u32,
    child_offsets: &[u32],
    children: &[u32],
    max_stack: usize,
) -> Result<Vec<u32>, VisitorWalkError> {
    let node_count = child_offsets
        .len()
        .checked_sub(1)
        .ok_or(VisitorWalkError::EmptyOffsets)?;
    let root_index = index(root)?;
    if root_index >= node_count {
        return Err(VisitorWalkError::InvalidRoot { root, node_count });
    }
    // After validation, offsets are monotone and within `0..=children.len()`,
    // and every child id is `< node_count`. The table and slice indexing below
    // therefore cannot panic.
    validate_tree(node_count, child_offsets, children)?;

    let mut seen = vec![false; node_count];
    let mut sequence = Vec::new();
    let mut stack = vec![(root, false)];
    while let Some((node, expanded)) = stack.pop() {
        let node_index = index(node)?;
        if expanded {
            // Second visit: every descendant has already been emitted.
            sequence.push(node);
            continue;
        }
        // In a tree, each node is expanded exactly once. A repeat means the
        // input is cyclic or shares a subtree between parents.
        if seen[node_index] {
            return Err(VisitorWalkError::Cycle { node });
        }
        seen[node_index] = true;

        let start = index(child_offsets[node_index])?;
        let end = index(child_offsets[node_index + 1])?;
        let kids = &children[start..end];
        // Expansion pushes one emit marker plus one entry per child. The stack
        // only grows during this step, so checking the final depth once
        // rejects exactly the inputs that a check before every push would.
        let depth = stack.len().saturating_add(1).saturating_add(kids.len());
        if depth > max_stack {
            return Err(VisitorWalkError::StackOverflow { max_stack });
        }
        stack.push((node, true));
        // Push in reverse so the first child is popped (and visited) first.
        stack.extend(kids.iter().rev().map(|&child| (child, false)));
    }
    Ok(sequence)
}

/// Validate that `offsets` and `children` form a well-formed CSR child table
/// for a tree of `node_count` nodes.
///
/// The table is accepted when all of the following hold:
/// - `offsets` has exactly `node_count + 1` entries, one start per node plus
///   a terminator;
/// - offsets never decrease, and none exceeds `children.len()`;
/// - every child id is `< node_count`.
///
/// # Errors
///
/// - [`VisitorWalkError::EmptyOffsets`] when `offsets.len() != node_count + 1`.
/// - [`VisitorWalkError::InvalidOffset`] when an offset decreases or exceeds
///   `children.len()`.
/// - [`VisitorWalkError::InvalidChild`] for the first child id `>= node_count`.
/// - [`VisitorWalkError::IndexOverflow`] when an id does not fit in `usize`.
pub fn validate_tree(
    node_count: usize,
    offsets: &[u32],
    children: &[u32],
) -> Result<(), VisitorWalkError> {
    // Without exactly one offset per node plus a terminator, a child id that
    // passes the `< node_count` check below could have no `offsets[id + 1]`
    // entry. A table like that is not well-formed.
    if offsets.len().checked_sub(1) != Some(node_count) {
        return Err(VisitorWalkError::EmptyOffsets);
    }
    let mut previous = 0usize;
    for &offset in offsets {
        let current = index(offset)?;
        if current < previous || current > children.len() {
            return Err(VisitorWalkError::InvalidOffset);
        }
        previous = current;
    }
    for &child in children {
        if index(child)? >= node_count {
            return Err(VisitorWalkError::InvalidChild { child, node_count });
        }
    }
    Ok(())
}

/// Errors reported by [`postorder`], [`validate_tree`], and [`index`].
///
/// Every message starts with a `Visitor*` tag and ends with a `Fix:` hint
/// describing how the caller can repair its input.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum VisitorWalkError {
    /// The offset table is missing entries. A CSR table needs
    /// `node_count + 1` offsets, and an empty table cannot even supply the
    /// terminal offset.
    #[error("VisitorEmptyOffsets: child_offsets must include node_count + 1 entries. Fix: emit a valid tree CSR table.")]
    EmptyOffsets,
    /// The requested root is not a node of the tree (`root >= node_count`).
    #[error("VisitorInvalidRoot: root {root} outside node_count {node_count}. Fix: pass a valid AST root.")]
    InvalidRoot {
        /// The rejected root id.
        root: u32,
        /// Number of nodes in the tree (`child_offsets.len() - 1`).
        node_count: usize,
    },
    /// An offset is smaller than the one before it, or exceeds
    /// `children.len()`.
    #[error("VisitorInvalidOffset: child offsets must be monotone and within children. Fix: rebuild child_offsets.")]
    InvalidOffset,
    /// A `u32` id or offset could not be converted to a host `usize` index.
    #[error("VisitorIndexOverflow: node id cannot fit usize. Fix: split the AST before dispatch.")]
    IndexOverflow,
    /// A child id does not name a node of the tree (`child >= node_count`).
    #[error("VisitorInvalidChild: child {child} outside node_count {node_count}. Fix: validate AST child references.")]
    InvalidChild {
        /// The out-of-range child id.
        child: u32,
        /// Number of nodes the child id was checked against.
        node_count: usize,
    },
    /// A node was reached a second time during traversal. This fires both for
    /// genuine cycles and for a node listed under more than one parent.
    #[error("VisitorCycle: node {node} was reached twice. Fix: pass a tree or DAG-expanded AST, not a cyclic graph.")]
    Cycle {
        /// The node that was reached again.
        node: u32,
    },
    /// Growing the explicit traversal stack would exceed its entry bound.
    #[error("VisitorStackOverflow: stack exceeded {max_stack} entries. Fix: increase workgroup visitor stack or split the AST.")]
    StackOverflow {
        /// The `max_stack` bound that was passed to [`postorder`].
        max_stack: usize,
    },
}

/// Zero-sized marker for the category C visitor-walk intrinsic.
///
/// The type carries no data. The operation itself is described by its
/// associated `SPEC` constant.
#[derive(Clone, Copy, Debug, Default)]
pub struct VisitorWalkOp;

/// Three-component workgroup size assumed by the reference WGSL lowering:
/// 64 along the first axis and 1 along the other two.
pub const WORKGROUP_SIZE: [u32; 3] = [64_u32, 1_u32, 1_u32];