vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Top-level validation entry point.
//!
//! This module runs the complete validation pipeline on a `Program`:
//! workgroup dimensions, buffer declarations, output markers, node
//! structure, expression types, and depth limits. Every error is returned
//! as a `ValidationError` with an actionable `Fix:` hint.

pub use super::depth::{DEFAULT_MAX_CALL_DEPTH, DEFAULT_MAX_NESTING_DEPTH, DEFAULT_MAX_NODE_COUNT};
use super::expr_rules::validate_output_markers;
use super::{depth, err, nodes, ValidationError};
use crate::ir::model::program::Program;
use crate::ir::model::types::BufferAccess;
use rustc_hash::{FxHashMap, FxHashSet};

/// Validate a program for structural and semantic correctness.
///
/// The validator checks the stable rules documented in
/// `vyre/docs/ir/validation.md`. The checks run in this order, and errors
/// appear in the returned vector in the same order:
///
/// 1. every `workgroup_size` dimension must be at least 1;
/// 2. buffer declarations:
///    - buffer names must be unique;
///    - binding slots must be unique among non-workgroup buffers;
///    - workgroup buffers must declare a positive element count;
/// 3. output markers on the buffer declarations (`validate_output_markers`);
/// 4. the entry node tree (`nodes::validate_nodes`). It is checked against
///    the declared buffers, with a fresh `depth::LimitState` enforcing the
///    depth limits. Those limits default, presumably, to the `DEFAULT_MAX_*`
///    constants this module re-exports.
///
/// No check short-circuits: every phase runs even if an earlier one reported
/// errors, so callers see all violations at once. A successful validation
/// (empty error vector) means the program is safe to lower to any backend.
///
/// # Examples
///
/// ```
/// use vyre::ir::{Program, validate};
///
/// let program = Program::new(Vec::new(), [1, 1, 1], Vec::new());
/// let errors = validate(&program);
/// assert!(errors.is_empty());
/// ```
#[inline]
pub fn validate(program: &Program) -> Vec<ValidationError> {
    // Capacity is only a heuristic: roughly one slot per buffer and per
    // top-level node.
    let mut errors = Vec::with_capacity(program.buffers().len() + program.entry().len());

    validate_workgroup_size(program, &mut errors);
    validate_buffer_decls(program, &mut errors);
    validate_output_markers(program.buffers(), &mut errors);
    validate_entry_nodes(program, &mut errors);

    errors
}

/// Report every `workgroup_size` axis whose size is 0.
fn validate_workgroup_size(program: &Program, errors: &mut Vec<ValidationError>) {
    for (axis, &size) in program.workgroup_size.iter().enumerate() {
        if size == 0 {
            errors.push(err(format!(
                "workgroup_size[{axis}] is 0. Fix: all workgroup dimensions must be >= 1."
            )));
        }
    }
}

/// Check buffer declarations: unique names, unique binding slots among
/// non-workgroup buffers, and a positive element count for workgroup buffers.
fn validate_buffer_decls(program: &Program, errors: &mut Vec<ValidationError>) {
    let mut seen_names = FxHashSet::default();
    seen_names.reserve(program.buffers().len());
    let mut seen_bindings = FxHashSet::default();
    seen_bindings.reserve(program.buffers().len());

    for buf in program.buffers() {
        if !seen_names.insert(&buf.name) {
            errors.push(err(format!(
                "duplicate buffer name `{}`. Fix: each buffer must have a unique name.",
                buf.name
            )));
        }

        let is_workgroup = buf.access == BufferAccess::Workgroup;

        // Workgroup buffers are exempt from the binding-slot uniqueness check;
        // only the remaining buffers are compared by `binding`.
        if !is_workgroup && !seen_bindings.insert(buf.binding) {
            errors.push(err(format!(
                "duplicate binding slot {} (buffer `{}`). Fix: each buffer must have a unique binding.",
                buf.binding, buf.name
            )));
        }
        if is_workgroup && buf.count == 0 {
            errors.push(err(format!(
                "workgroup buffer `{}` has count 0. Fix: declare a positive element count.",
                buf.name
            )));
        }
    }
}

/// Validate the program's entry node tree against its declared buffers.
fn validate_entry_nodes(program: &Program, errors: &mut Vec<ValidationError>) {
    // Name -> declaration lookup for the node validator. If two buffers share a
    // name, the later declaration wins here; the duplicate itself has already
    // been reported by `validate_buffer_decls`.
    let mut buffer_map: FxHashMap<&str, &crate::ir::model::program::BufferDecl> =
        FxHashMap::default();
    buffer_map.reserve(program.buffers().len());
    buffer_map.extend(program.buffers().iter().map(|b| (b.name.as_str(), b)));

    let mut scope = FxHashMap::default();
    let mut limits = depth::LimitState::default();
    nodes::validate_nodes(
        program.entry(),
        &buffer_map,
        &mut scope,
        false, // divergent
        0,
        &mut limits,
        errors,
    );
}