use crate::ir::model::node::Node;
use crate::ir::model::program::BufferDecl;
use crate::ir::model::types::DataType;
use crate::ir::validate::barrier;
use crate::ir::validate::bytes_rejection;
use crate::ir::validate::depth::{self, LimitState};
use crate::ir::validate::expr_rules::validate_expr;
use crate::ir::validate::shadowing;
use crate::ir::validate::typecheck::expr_type;
use crate::ir::validate::{err, Binding, ValidationError};
use rustc_hash::FxHashMap;
/// Undo log for bindings inserted while a nested scope is being validated.
///
/// Each entry is `(name, previous)`, where `previous` is the binding the scope
/// map held for `name` before the insertion, or `None` if the name was unbound.
type ScopeLog = Vec<(String, Option<Binding>)>;
#[inline]
pub fn validate_nodes(
nodes: &[Node],
buffers: &FxHashMap<&str, &BufferDecl>,
scope: &mut FxHashMap<String, Binding>,
divergent: bool,
depth: usize,
limits: &mut LimitState,
errors: &mut Vec<ValidationError>,
) {
validate_nodes_inner(
nodes, buffers, scope, divergent, depth, limits, errors, None,
);
}
/// Validates each node of `nodes` in order, then checks for code after `return`.
///
/// `scope_log`, when present, is handed to every node so that bindings created
/// directly in this sequence can be recorded for later restoration.
///
/// After all nodes have been validated, a single "unreachable statements" error
/// is pushed if a `Node::Return` appears anywhere other than the last position.
fn validate_nodes_inner(
    nodes: &[Node],
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &mut FxHashMap<String, Binding>,
    divergent: bool,
    depth: usize,
    limits: &mut LimitState,
    errors: &mut Vec<ValidationError>,
    mut scope_log: Option<&mut ScopeLog>,
) {
    for node in nodes {
        // Reborrow the log for this node so it stays usable on the next iteration.
        let log = scope_log.as_deref_mut();
        validate_node_inner(node, buffers, scope, divergent, depth, limits, errors, log);
    }

    // A `return` is only legal as the final statement: any `return` among the
    // nodes preceding the last one leaves unreachable statements behind it.
    if let Some((_, before_last)) = nodes.split_last() {
        let returns_early = before_last.iter().any(|n| matches!(n, Node::Return));
        if returns_early {
            errors.push(err(String::from(
                "unreachable statements after `return`. Fix: remove statements after `return` or reorder them.",
            )));
        }
    }
}
#[allow(clippy::too_many_lines)]
fn validate_node_inner(
node: &Node,
buffers: &FxHashMap<&str, &BufferDecl>,
scope: &mut FxHashMap<String, Binding>,
divergent: bool,
depth: usize,
limits: &mut LimitState,
errors: &mut Vec<ValidationError>,
scope_log: Option<&mut ScopeLog>,
) {
depth::check_limits(limits, depth, errors);
match node {
Node::Let { name, value } => {
validate_expr(value, buffers, scope, errors);
shadowing::check_local(name, scope, errors);
let ty = expr_type(value, buffers, scope).unwrap_or(DataType::U32);
insert_binding(
scope,
name.clone(),
Binding { ty, mutable: true },
scope_log,
);
}
Node::Assign { name, value } => {
if let Some(binding) = scope.get(name.as_str()) {
if !binding.mutable {
errors.push(err(format!(
"V011: assignment to loop variable `{name}`. Fix: loop variables are immutable."
)));
}
} else {
errors.push(err(format!(
"assignment to undeclared variable `{name}`. Fix: add `let {name} = ...;` before this assignment."
)));
}
validate_expr(value, buffers, scope, errors);
}
Node::Store {
buffer,
index,
value,
} => {
bytes_rejection::check_store(buffer, buffers, errors);
if let Some(buf) = buffers.get(buffer.as_str()) {
if let Some(val_ty) = expr_type(value, buffers, scope) {
let elem = &buf.element;
let compatible = val_ty == *elem
|| matches!(
(&val_ty, elem),
(DataType::U32, DataType::Bytes)
| (DataType::Bytes, DataType::U32)
| (DataType::U32, DataType::Bool)
| (DataType::Bool, DataType::U32)
)
|| matches!((&val_ty, elem), (DataType::F32, DataType::F32));
if !compatible {
errors.push(err(format!(
"store value type `{val_ty}` does not match buffer `{buffer}` element type `{elem}`. Fix: insert an explicit cast or use a matching type.",
elem = elem
)));
}
}
}
validate_expr(index, buffers, scope, errors);
validate_expr(value, buffers, scope, errors);
}
Node::If {
cond,
then,
otherwise,
} => {
validate_expr(cond, buffers, scope, errors);
if let Some(cond_ty) = expr_type(cond, buffers, scope) {
if !matches!(cond_ty, DataType::U32 | DataType::Bool) {
errors.push(err(format!(
"if condition must be `u32` or `bool`, got `{cond_ty}`. Fix: cast or rewrite the condition to produce U32 or Bool."
)));
}
}
validate_scoped_nested_nodes(
then,
buffers,
scope,
true,
depth,
limits,
errors,
|_, _| {},
);
validate_scoped_nested_nodes(
otherwise,
buffers,
scope,
true,
depth,
limits,
errors,
|_, _| {},
);
}
Node::Loop {
var,
from,
to,
body,
} => {
validate_expr(from, buffers, scope, errors);
validate_expr(to, buffers, scope, errors);
if let Some(from_ty) = expr_type(from, buffers, scope) {
if from_ty != DataType::U32 {
errors.push(err(format!(
"V015: loop bound expression must be `u32`, got `{from_ty}`. Fix: ensure `from` and `to` are U32."
)));
}
}
if let Some(to_ty) = expr_type(to, buffers, scope) {
if to_ty != DataType::U32 {
errors.push(err(format!(
"V015: loop bound expression must be `u32`, got `{to_ty}`. Fix: ensure `from` and `to` are U32."
)));
}
}
shadowing::check_local(var, scope, errors);
validate_scoped_nested_nodes(
body,
buffers,
scope,
true,
depth,
limits,
errors,
|scope, scope_log| {
insert_binding(
scope,
var.clone(),
Binding {
ty: DataType::U32,
mutable: false,
},
Some(scope_log),
);
},
);
}
Node::Return => {}
Node::Block(nodes) => {
validate_scoped_nested_nodes(
nodes,
buffers,
scope,
divergent,
depth,
limits,
errors,
|_, _| {},
);
}
Node::Barrier => {
barrier::check_barrier(divergent, errors);
}
}
}
fn validate_scoped_nested_nodes(
nodes: &[Node],
buffers: &FxHashMap<&str, &BufferDecl>,
scope: &mut FxHashMap<String, Binding>,
divergent: bool,
depth: usize,
limits: &mut LimitState,
errors: &mut Vec<ValidationError>,
configure_scope: impl FnOnce(&mut FxHashMap<String, Binding>, &mut ScopeLog),
) {
let mut scope_log = Vec::new();
configure_scope(scope, &mut scope_log);
validate_nodes_inner(
nodes,
buffers,
scope,
divergent,
depth.saturating_add(1),
limits,
errors,
Some(&mut scope_log),
);
restore_scope(scope, scope_log);
}
/// Inserts `binding` under `name` in `scope`, optionally recording what it replaced.
///
/// When `scope_log` is `Some`, the pair `(name, previous)` is appended to it,
/// where `previous` is the binding `scope` held for `name` before the insert
/// (`None` if the name was unbound). When `scope_log` is `None`, the replaced
/// binding, if any, is dropped.
fn insert_binding(
    scope: &mut FxHashMap<String, Binding>,
    name: String,
    binding: Binding,
    scope_log: Option<&mut ScopeLog>,
) {
    match scope_log {
        Some(log) => {
            // The name is needed twice here (map key and log entry), so this
            // is the only path that has to clone it.
            let previous = scope.insert(name.clone(), binding);
            log.push((name, previous));
        }
        None => {
            // Nothing to record: move the owned name straight into the map
            // instead of cloning it and dropping the original.
            scope.insert(name, binding);
        }
    }
}
/// Reverts every insertion recorded in `scope_log`, newest first.
///
/// For each entry, the previous binding is put back if there was one;
/// otherwise the name is removed from `scope`. Walking the log in reverse
/// means a name logged more than once ends up with its oldest recorded state.
fn restore_scope(scope: &mut FxHashMap<String, Binding>, scope_log: ScopeLog) {
    for (name, shadowed) in scope_log.into_iter().rev() {
        match shadowed {
            Some(binding) => {
                scope.insert(name, binding);
            }
            None => {
                scope.remove(&name);
            }
        }
    }
}