use crate::ir::model::expr::Expr;
use crate::ir::model::program::{BufferDecl, Program};
use crate::ir::model::types::{BufferAccess, DataType};
use crate::ir::validate::atomic_rules;
use crate::ir::validate::bytes_rejection;
use crate::ir::validate::cast::cast_is_valid;
use crate::ir::validate::depth::{self, DEFAULT_MAX_CALL_DEPTH};
use crate::ir::validate::typecheck::{self, expr_type, validate_u32_binop_operand};
use crate::ir::validate::{err, Binding, ValidationError};
use rustc_hash::FxHashMap;
#[allow(clippy::too_many_lines)]
#[inline]
pub(crate) fn validate_expr(
expr: &Expr,
buffers: &FxHashMap<&str, &BufferDecl>,
scope: &FxHashMap<String, Binding>,
errors: &mut Vec<ValidationError>,
) {
match expr {
Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_) => {}
Expr::Var(name) => {
if !scope.contains_key(name.as_str()) {
errors.push(err(format!(
"reference to undeclared variable `{name}`. Fix: add `let {name} = ...;` before this use."
)));
}
}
Expr::Load { buffer, index } => {
bytes_rejection::check_load(buffer, buffers, errors);
validate_expr(index, buffers, scope, errors);
}
Expr::BufLen { buffer } => {
if !buffers.contains_key(buffer.as_str()) {
errors.push(err(format!(
"buflen of unknown buffer `{buffer}`. Fix: declare it in Program::buffers."
)));
}
}
Expr::InvocationId { axis } | Expr::WorkgroupId { axis } | Expr::LocalId { axis } => {
if *axis > 2 {
errors.push(err(format!(
"invocation/workgroup ID axis {axis} out of range. Fix: use 0 (x), 1 (y), or 2 (z)."
)));
}
}
Expr::BinOp { left, right, .. } => {
validate_expr(left, buffers, scope, errors);
validate_expr(right, buffers, scope, errors);
validate_u32_binop_operand("left", left, buffers, scope, errors);
validate_u32_binop_operand("right", right, buffers, scope, errors);
}
Expr::UnOp { op, operand } => {
validate_expr(operand, buffers, scope, errors);
typecheck::validate_unop_operand(op, operand, buffers, scope, errors);
}
Expr::Call { op_id, args } => {
if let Some(spec) = crate::ops::registry::lookup(op_id) {
if !spec.inlinable() {
errors.push(err(format!(
"V020: call to non-inlinable op `{op_id}` is rejected by validation. Fix: lower this operation through its dedicated backend path or rewrite the caller with explicit IR."
)));
} else if let Some(callee) = spec.program() {
let expected = call_input_count(&callee);
if args.len() != expected {
errors.push(err(format!(
"V021: call to `{op_id}` has {} args but callee declares {expected} ReadOnly/Uniform inputs. Fix: pass exactly one argument per input buffer in binding order.",
args.len()
)));
}
let outputs = output_marker_count(&callee.buffers);
if outputs != 1 {
errors.push(err(format!(
"V022: inline-able op `{op_id}` declares {outputs} output buffers. Fix: mark exactly one result buffer with BufferDecl::output(...)."
)));
}
if depth::max_call_depth(op_id, 1).is_err() {
errors.push(err(format!(
"V017: call depth exceeds maximum of {DEFAULT_MAX_CALL_DEPTH}. Fix: reduce call nesting or mutually recursive operations."
)));
}
}
} else {
errors.push(err(format!(
"V016: unknown op `{op_id}`. Fix: use a registered op id or add the op to core::ops::*."
)));
}
for arg in args {
validate_expr(arg, buffers, scope, errors);
}
}
Expr::Fma { a, b, c } => {
validate_expr(a, buffers, scope, errors);
validate_expr(b, buffers, scope, errors);
validate_expr(c, buffers, scope, errors);
}
Expr::Select {
cond,
true_val,
false_val,
} => {
validate_expr(cond, buffers, scope, errors);
validate_expr(true_val, buffers, scope, errors);
validate_expr(false_val, buffers, scope, errors);
}
Expr::Cast { target, value } => {
validate_expr(value, buffers, scope, errors);
if let Some(src) = expr_type(value, buffers, scope) {
if target == &DataType::Bytes && src != DataType::Bytes {
errors.push(err(
"V023: cast to Bytes is unsupported in WGSL lowering. Fix: use buffer load/store directly for byte data."
.to_string(),
));
} else if !cast_is_valid(src.clone(), target.clone()) {
errors.push(err(format!(
"V012: unsupported cast from `{src}` to `{target}`. Fix: use a supported casts.md conversion or rewrite the expression before validation."
)));
}
}
}
Expr::Atomic {
op,
buffer,
index,
expected,
value,
} => {
atomic_rules::validate_atomic(
op,
buffer,
index,
expected.as_deref(),
value,
buffers,
scope,
errors,
);
validate_expr(index, buffers, scope, errors);
if let Some(expected) = expected {
validate_expr(expected, buffers, scope, errors);
}
validate_expr(value, buffers, scope, errors);
}
}
}
/// Reports a V022 error when more than one buffer in `buffers` is marked as an output.
///
/// A count of zero or one is accepted; only the "too many outputs" case is diagnosed.
#[inline]
pub(crate) fn validate_output_markers(buffers: &[BufferDecl], errors: &mut Vec<ValidationError>) {
    match output_marker_count(buffers) {
        0 | 1 => {}
        outputs => errors.push(err(format!(
            "V022: program declares {outputs} output buffers. Fix: mark at most one result buffer with BufferDecl::output(...)."
        ))),
    }
}
/// Returns how many buffers in `buffers` report `is_output()` as true.
#[inline]
pub fn output_marker_count(buffers: &[BufferDecl]) -> usize {
    buffers
        .iter()
        .fold(0, |total, decl| total + usize::from(decl.is_output()))
}
/// Counts the buffers of `program` declared with `BufferAccess::ReadOnly` or
/// `BufferAccess::Uniform` access.
///
/// Call validation compares this count against the number of arguments at a call site.
#[inline]
pub fn call_input_count(program: &Program) -> usize {
    let mut inputs = 0;
    for decl in &program.buffers {
        if let BufferAccess::ReadOnly | BufferAccess::Uniform = decl.access {
            inputs += 1;
        }
    }
    inputs
}