vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Visitor for IR traversal.
//!
//! Optimization passes, lowering, and analysis use these utilities to walk
//! the IR tree without manually matching every variant. All traversals are
//! implemented with an explicit stack rather than recursion. This is a
//! critical design choice: it prevents stack overflows when processing deep
//! ASTs (e.g., highly nested `If` or `Block` nodes) during adversarial or
//! extreme workloads.

use crate::ir::model::expr::Expr;
use crate::ir::model::node::Node;
use crate::ir::model::program::Program;
use std::collections::HashSet;

/// Visit every statement node of `program` in depth-first pre-order, invoking
/// `f` once per node.
///
/// Traversal starts at the program's entry block. A node is handed to `f`
/// before any of its children. For `If`, the whole `then` branch is visited
/// before the `otherwise` branch. `Loop` and `Block` bodies are visited in
/// their stored order. Pending nodes are kept on a heap-allocated `Vec`
/// instead of the native call stack, so deep nesting does not risk a stack
/// overflow.
///
/// # Examples
///
/// ```
/// use vyre::ir::{visit::walk_nodes, Program};
///
/// let program = Program::empty();
/// walk_nodes(&program, |_node| {
///     // process node
/// });
/// ```
#[inline]
pub fn walk_nodes(program: &Program, mut f: impl FnMut(&Node)) {
    // Seeded back-to-front so the first entry statement is the first one popped.
    let mut pending: Vec<&Node> = program.entry().iter().rev().collect();

    while let Some(current) = pending.pop() {
        f(current);

        // Children go on last-to-first so they come back off in source order.
        // Leaf variants are spelled out (no `_` arm) so that adding a new
        // container variant is a compile error here rather than a silent skip.
        match current {
            Node::If {
                then, otherwise, ..
            } => {
                otherwise.iter().rev().for_each(|child| pending.push(child));
                then.iter().rev().for_each(|child| pending.push(child));
            }
            Node::Loop { body, .. } => body.iter().rev().for_each(|child| pending.push(child)),
            Node::Block(inner) => inner.iter().rev().for_each(|child| pending.push(child)),
            Node::Let { .. }
            | Node::Assign { .. }
            | Node::Store { .. }
            | Node::Return
            | Node::Barrier => {}
        }
    }
}

/// Walk all expressions in a program, calling `f` on each.
///
/// The traversal visits every `Expr` nested inside every node, again using
/// an explicit stack. This is the primary way to inspect or transform the
/// value-producing parts of a program.
///
/// # Examples
///
/// ```
/// use vyre::ir::{visit::walk_exprs, Program};
///
/// let program = Program::empty();
/// walk_exprs(&program, |_expr| {
///     // process expression
/// });
/// ```
#[inline]
pub fn walk_exprs(program: &Program, mut f: impl FnMut(&Expr)) {
    let mut node_stack = Vec::with_capacity(program.entry().len());
    for node in program.entry().iter().rev() {
        node_stack.push(node);
    }

    let mut expr_stack = Vec::with_capacity(program.entry().len().saturating_mul(2));

    while let Some(node) = node_stack.pop() {
        match node {
            Node::Let { value, .. } | Node::Assign { value, .. } => {
                expr_stack.push(value);
            }
            Node::Store { index, value, .. } => {
                expr_stack.push(value);
                expr_stack.push(index);
            }
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                for n in otherwise.iter().rev() {
                    node_stack.push(n);
                }
                for n in then.iter().rev() {
                    node_stack.push(n);
                }
                expr_stack.push(cond);
            }
            Node::Loop { from, to, body, .. } => {
                for n in body.iter().rev() {
                    node_stack.push(n);
                }
                expr_stack.push(to);
                expr_stack.push(from);
            }
            Node::Block(nodes) => {
                for n in nodes.iter().rev() {
                    node_stack.push(n);
                }
            }
            Node::Return | Node::Barrier => {}
        }

        while let Some(expr) = expr_stack.pop() {
            f(expr);
            match expr {
                Expr::LitU32(_)
                | Expr::LitI32(_)
                | Expr::LitF32(_)
                | Expr::LitBool(_)
                | Expr::Var(_) => {}
                Expr::Load { index, .. } => expr_stack.push(index),
                Expr::BufLen { .. } => {}
                Expr::InvocationId { .. } | Expr::WorkgroupId { .. } | Expr::LocalId { .. } => {}
                Expr::BinOp { left, right, .. } => {
                    expr_stack.push(right);
                    expr_stack.push(left);
                }
                Expr::Fma { a, b, c, .. } => {
                    expr_stack.push(c);
                    expr_stack.push(b);
                    expr_stack.push(a);
                }
                Expr::UnOp { operand, .. } => expr_stack.push(operand),
                Expr::Call { args, .. } => {
                    for arg in args.iter().rev() {
                        expr_stack.push(arg);
                    }
                }
                Expr::Select {
                    cond,
                    true_val,
                    false_val,
                } => {
                    expr_stack.push(false_val);
                    expr_stack.push(true_val);
                    expr_stack.push(cond);
                }
                Expr::Cast { value, .. } => expr_stack.push(value),
                Expr::Atomic {
                    index,
                    expected,
                    value,
                    ..
                } => {
                    expr_stack.push(value);
                    if let Some(expected) = expected {
                        expr_stack.push(expected);
                    }
                    expr_stack.push(index);
                }
            }
        }
    }
}

/// Visit every statement node of `program` in depth-first pre-order with
/// mutable access, invoking `f` once per node.
///
/// This is the mutable counterpart to [`walk_nodes`] and visits nodes in the
/// same order. `f` runs on a node before that node's children are queued. If
/// `f` replaces a node (for example, swapping a `Block` for a different
/// node), the walk descends into the replacement's children, not the
/// original's. Pending nodes are held on a heap-allocated stack rather than
/// through recursion.
///
/// # Examples
///
/// ```
/// use vyre::ir::{visit::walk_nodes_mut, Program};
///
/// let mut program = Program::empty();
/// walk_nodes_mut(&mut program, |_node| {
///     // modify node
/// });
/// ```
#[inline]
pub fn walk_nodes_mut(program: &mut Program, mut f: impl FnMut(&mut Node)) {
    // Seeded back-to-front so the first entry statement is the first one popped.
    let mut pending: Vec<&mut Node> = program.entry_mut().iter_mut().rev().collect();

    while let Some(current) = pending.pop() {
        // Reborrow so `current` stays usable below to reach the (possibly
        // rewritten) node's children.
        f(&mut *current);

        // Children go on last-to-first so they come back off in source order.
        match current {
            Node::If {
                then, otherwise, ..
            } => {
                otherwise.iter_mut().rev().for_each(|child| pending.push(child));
                then.iter_mut().rev().for_each(|child| pending.push(child));
            }
            Node::Loop { body, .. } => body.iter_mut().rev().for_each(|child| pending.push(child)),
            Node::Block(inner) => inner.iter_mut().rev().for_each(|child| pending.push(child)),
            Node::Let { .. }
            | Node::Assign { .. }
            | Node::Store { .. }
            | Node::Return
            | Node::Barrier => {}
        }
    }
}

/// Collect the name of every buffer the program refers to.
///
/// A buffer counts as referenced when it is named by:
///
/// - a `Store` statement, or
/// - a `Load`, `BufLen`, or `Atomic` expression anywhere in the program,
///   including inside nested control flow.
///
/// Each name appears once in the returned set, however many times it is
/// used.
///
/// # Examples
///
/// ```
/// use vyre::ir::{visit::referenced_buffers, Program};
///
/// let program = Program::empty();
/// let buffers = referenced_buffers(&program);
/// assert!(buffers.is_empty());
/// ```
#[must_use]
#[inline]
pub fn referenced_buffers(program: &Program) -> HashSet<String> {
    let mut seen: HashSet<String> = HashSet::new();

    // Statement-level buffer references: only `Store` names a buffer.
    walk_nodes(program, |node| {
        if let Node::Store { buffer, .. } = node {
            seen.insert(buffer.clone());
        }
    });

    // Expression-level buffer references. The match is deliberately
    // exhaustive so that a new buffer-touching `Expr` variant cannot be
    // forgotten here.
    walk_exprs(program, |expr| match expr {
        Expr::Load { buffer, .. } | Expr::BufLen { buffer } | Expr::Atomic { buffer, .. } => {
            seen.insert(buffer.to_string());
        }
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::BinOp { .. }
        | Expr::Fma { .. }
        | Expr::UnOp { .. }
        | Expr::Call { .. }
        | Expr::Select { .. }
        | Expr::Cast { .. } => {}
    });

    seen
}

/// Collect the operation ID of every [`Expr::Call`] in the program, in
/// traversal order.
///
/// Calls are reported in the order [`walk_exprs`] visits them. One entry is
/// produced per call site, and nothing is deduplicated: an operation called
/// several times appears several times. Callers that need the set of
/// distinct operations should deduplicate the result themselves.
///
/// # Examples
///
/// ```
/// use vyre::ir::{visit::collect_call_op_ids, Expr, Node, Program};
///
/// let program = Program::new(
///     Vec::new(),
///     [1, 1, 1],
///     vec![
///         Node::let_bind("x", Expr::call("primitive.math.add", vec![Expr::u32(1)])),
///         Node::let_bind("y", Expr::call("primitive.math.add", vec![Expr::u32(2)])),
///     ],
/// );
/// // Both call sites are reported; duplicates are kept.
/// assert_eq!(
///     collect_call_op_ids(&program),
///     vec!["primitive.math.add", "primitive.math.add"]
/// );
/// ```
#[must_use]
#[inline]
pub fn collect_call_op_ids(program: &Program) -> Vec<String> {
    // Start empty. The number of top-level statements says nothing about how
    // many calls exist, so pre-sizing from it only wastes an allocation for
    // call-free programs.
    let mut op_ids = Vec::new();
    walk_exprs(program, |expr| {
        if let Expr::Call { op_id, .. } = expr {
            op_ids.push(op_id.clone());
        }
    });
    op_ids
}