vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::model::expr::Expr;
use crate::ir::model::program::{BufferDecl, Program};
use crate::ir::model::types::{BufferAccess, DataType};
use crate::ir::validate::atomic_rules;
use crate::ir::validate::bytes_rejection;
use crate::ir::validate::cast::cast_is_valid;
use crate::ir::validate::depth::{self, DEFAULT_MAX_CALL_DEPTH};
use crate::ir::validate::typecheck::{self, expr_type, validate_u32_binop_operand};
use crate::ir::validate::{err, Binding, ValidationError};
use rustc_hash::FxHashMap;

/// Recursively validates one IR expression, appending every problem found to
/// `errors` rather than stopping at the first one.
///
/// `buffers` maps buffer names to their declarations and `scope` holds the
/// variables bound at this point. Both are used for name resolution and for
/// the type queries made through `expr_type` and the `typecheck` helpers.
///
/// Diagnostics raised directly here:
/// - references to undeclared variables, `buflen` of an unknown buffer, and
///   invocation/workgroup/local ID axes greater than 2;
/// - calls: V016 (unknown op), V020 (non-inlinable op), V021 (argument count
///   differs from the callee's ReadOnly/Uniform inputs), V022 (callee does not
///   declare exactly one output), V017 (call depth check failed);
/// - casts: V023 (cast into `Bytes` from a non-`Bytes` type) and V012 (cast
///   rejected by `cast_is_valid`).
///
/// Loads and atomics are delegated to `bytes_rejection::check_load` and
/// `atomic_rules::validate_atomic`. Child expressions are always visited, even
/// when the parent already produced an error.
// The single match keeps every variant's rules side by side; splitting it up
// only to satisfy clippy's line budget would scatter them.
#[allow(clippy::too_many_lines)]
#[inline]
pub(crate) fn validate_expr(
    expr: &Expr,
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &FxHashMap<String, Binding>,
    errors: &mut Vec<ValidationError>,
) {
    match expr {
        Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_) => {}
        Expr::Var(name) => {
            let declared = scope.contains_key(name.as_str());
            if !declared {
                errors.push(err(format!(
                    "reference to undeclared variable `{name}`. Fix: add `let {name} = ...;` before this use."
                )));
            }
        }
        Expr::Load { buffer, index } => {
            bytes_rejection::check_load(buffer, buffers, errors);
            validate_expr(index, buffers, scope, errors);
        }
        Expr::BufLen { buffer } => {
            if buffers.get(buffer.as_str()).is_none() {
                errors.push(err(format!(
                    "buflen of unknown buffer `{buffer}`. Fix: declare it in Program::buffers."
                )));
            }
        }
        Expr::InvocationId { axis } | Expr::WorkgroupId { axis } | Expr::LocalId { axis } => {
            // Only x (0), y (1) and z (2) exist.
            let in_range = *axis <= 2;
            if !in_range {
                errors.push(err(format!(
                    "invocation/workgroup ID axis {axis} out of range. Fix: use 0 (x), 1 (y), or 2 (z)."
                )));
            }
        }
        Expr::BinOp { left, right, .. } => {
            let operands = [("left", left), ("right", right)];
            // Structural validation of both sides first, then the operand checks,
            // matching the original diagnostic order.
            for (_, operand) in operands {
                validate_expr(operand, buffers, scope, errors);
            }
            for (side, operand) in operands {
                validate_u32_binop_operand(side, operand, buffers, scope, errors);
            }
        }
        Expr::UnOp { op, operand } => {
            validate_expr(operand, buffers, scope, errors);
            typecheck::validate_unop_operand(op, operand, buffers, scope, errors);
        }
        Expr::Call { op_id, args } => {
            match crate::ops::registry::lookup(op_id) {
                None => errors.push(err(format!(
                    "V016: unknown op `{op_id}`. Fix: use a registered op id or add the op to core::ops::*."
                ))),
                Some(spec) if !spec.inlinable() => errors.push(err(format!(
                    "V020: call to non-inlinable op `{op_id}` is rejected by validation. Fix: lower this operation through its dedicated backend path or rewrite the caller with explicit IR."
                ))),
                Some(spec) => {
                    // Signature checks run only when the spec yields a callee program.
                    if let Some(callee) = spec.program() {
                        let expected = call_input_count(&callee);
                        let supplied = args.len();
                        if supplied != expected {
                            errors.push(err(format!(
                                "V021: call to `{op_id}` has {} args but callee declares {expected} ReadOnly/Uniform inputs. Fix: pass exactly one argument per input buffer in binding order.",
                                supplied
                            )));
                        }

                        let outputs = output_marker_count(&callee.buffers);
                        if outputs != 1 {
                            errors.push(err(format!(
                                "V022: inline-able op `{op_id}` declares {outputs} output buffers. Fix: mark exactly one result buffer with BufferDecl::output(...)."
                            )));
                        }

                        let depth_exceeded = depth::max_call_depth(op_id, 1).is_err();
                        if depth_exceeded {
                            errors.push(err(format!(
                                "V017: call depth exceeds maximum of {DEFAULT_MAX_CALL_DEPTH}. Fix: reduce call nesting or mutually recursive operations."
                            )));
                        }
                    }
                }
            }
            for arg in args {
                validate_expr(arg, buffers, scope, errors);
            }
        }
        Expr::Fma { a, b, c } => {
            for term in [a, b, c] {
                validate_expr(term, buffers, scope, errors);
            }
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            for part in [cond, true_val, false_val] {
                validate_expr(part, buffers, scope, errors);
            }
        }
        Expr::Cast { target, value } => {
            validate_expr(value, buffers, scope, errors);
            // No cast diagnostic is emitted when the operand's type cannot be inferred.
            if let Some(src) = expr_type(value, buffers, scope) {
                let into_bytes_from_other = target == &DataType::Bytes && src != DataType::Bytes;
                if into_bytes_from_other {
                    errors.push(err(String::from(
                        "V023: cast to Bytes is unsupported in WGSL lowering. Fix: use buffer load/store directly for byte data.",
                    )));
                } else if !cast_is_valid(src.clone(), target.clone()) {
                    errors.push(err(format!(
                        "V012: unsupported cast from `{src}` to `{target}`. Fix: use a supported casts.md conversion or rewrite the expression before validation."
                    )));
                }
            }
        }
        Expr::Atomic {
            op,
            buffer,
            index,
            expected,
            value,
        } => {
            atomic_rules::validate_atomic(
                op,
                buffer,
                index,
                expected.as_deref(),
                value,
                buffers,
                scope,
                errors,
            );
            validate_expr(index, buffers, scope, errors);
            if let Some(compare) = expected.as_deref() {
                validate_expr(compare, buffers, scope, errors);
            }
            validate_expr(value, buffers, scope, errors);
        }
    }
}

/// Reports V022 when a program marks more than one buffer as its output.
///
/// Zero or one output buffer is accepted; the count is taken from each
/// declaration's `is_output()` flag.
#[inline]
pub(crate) fn validate_output_markers(buffers: &[BufferDecl], errors: &mut Vec<ValidationError>) {
    let outputs = buffers.iter().filter(|decl| decl.is_output()).count();
    if outputs <= 1 {
        return;
    }
    errors.push(err(format!(
        "V022: program declares {outputs} output buffers. Fix: mark at most one result buffer with BufferDecl::output(...)."
    )));
}

/// Returns how many of `buffers` are marked as outputs, i.e. how many have
/// `is_output()` returning `true`.
#[inline]
pub fn output_marker_count(buffers: &[BufferDecl]) -> usize {
    buffers
        .iter()
        .map(|decl| usize::from(decl.is_output()))
        .sum()
}

/// Returns the number of buffers in `program` declared with
/// `BufferAccess::ReadOnly` or `BufferAccess::Uniform` access.
///
/// An `Expr::Call` to `program` is expected to pass exactly this many
/// arguments; a mismatch is reported as V021 during expression validation.
#[inline]
pub fn call_input_count(program: &Program) -> usize {
    let mut inputs = 0;
    for decl in program.buffers.iter() {
        if matches!(decl.access, BufferAccess::ReadOnly | BufferAccess::Uniform) {
            inputs += 1;
        }
    }
    inputs
}