vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::model::node::Node;
use crate::ir::model::program::BufferDecl;
use crate::ir::model::types::DataType;
use crate::ir::validate::barrier;
use crate::ir::validate::bytes_rejection;
use crate::ir::validate::depth::{self, LimitState};
use crate::ir::validate::expr_rules::validate_expr;
use crate::ir::validate::shadowing;
use crate::ir::validate::typecheck::expr_type;
use crate::ir::validate::{err, Binding, ValidationError};
use rustc_hash::FxHashMap;

/// Undo log for bindings introduced inside a nested scope.
///
/// Each entry pairs a variable name with the binding it replaced in the scope
/// map (`None` when the name was previously unbound). Entries are appended by
/// `insert_binding` and replayed newest-first by `restore_scope`.
type ScopeLog = Vec<(String, Option<Binding>)>;

#[inline]
pub fn validate_nodes(
    nodes: &[Node],
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &mut FxHashMap<String, Binding>,
    divergent: bool,
    depth: usize,
    limits: &mut LimitState,
    errors: &mut Vec<ValidationError>,
) {
    validate_nodes_inner(
        nodes, buffers, scope, divergent, depth, limits, errors, None,
    );
}

/// Validates every node in `nodes` in order, then reports statements that
/// follow a `return`.
///
/// `scope_log`, when present, is handed to each node so that bindings it
/// introduces are recorded for later rollback by the caller.
fn validate_nodes_inner(
    nodes: &[Node],
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &mut FxHashMap<String, Binding>,
    divergent: bool,
    depth: usize,
    limits: &mut LimitState,
    errors: &mut Vec<ValidationError>,
    mut scope_log: Option<&mut ScopeLog>,
) {
    for stmt in nodes {
        // `as_deref_mut` reborrows the log so it can be lent to every statement.
        validate_node_inner(
            stmt,
            buffers,
            scope,
            divergent,
            depth,
            limits,
            errors,
            scope_log.as_deref_mut(),
        );
    }

    // A `return` is only acceptable as the final statement, so any `return`
    // among the statements before the last one leaves unreachable code behind.
    let leading = match nodes.split_last() {
        Some((_, leading)) => leading,
        None => return,
    };
    if leading.iter().any(|stmt| matches!(stmt, Node::Return)) {
        errors.push(err(
            "unreachable statements after `return`. Fix: remove statements after `return` or reorder them.".to_string(),
        ));
    }
}

/// Validates a single statement node, appending every problem found to
/// `errors`.
///
/// The depth limits are checked first, then the node is handled by kind:
///
/// * `Let` validates the initializer, runs the shadowing check, and binds the
///   name as mutable (recording the insertion in `scope_log` when present).
/// * `Assign` requires the target to be bound and mutable, then validates the
///   value.
/// * `Store` runs the bytes-rejection check, compares the value type with the
///   buffer element type, then validates index and value.
/// * `If` validates the condition (must be `U32` or `Bool` when its type is
///   known) and both branches, each in its own nested scope.
/// * `Loop` validates both bounds (must be `U32` when their type is known) and
///   the body, in a nested scope that holds the immutable loop variable.
/// * `Block` validates its statements in a nested scope.
/// * `Barrier` defers to the barrier check with the current `divergent` flag.
/// * `Return` needs no per-node validation.
#[allow(clippy::too_many_lines)]
fn validate_node_inner(
    node: &Node,
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &mut FxHashMap<String, Binding>,
    divergent: bool,
    depth: usize,
    limits: &mut LimitState,
    errors: &mut Vec<ValidationError>,
    scope_log: Option<&mut ScopeLog>,
) {
    depth::check_limits(limits, depth, errors);

    match node {
        Node::Let { name, value } => {
            validate_expr(value, buffers, scope, errors);
            shadowing::check_local(name, scope, errors);
            // When the initializer's type cannot be determined, the binding
            // falls back to `U32`.
            let binding = Binding {
                ty: expr_type(value, buffers, scope).unwrap_or(DataType::U32),
                mutable: true,
            };
            insert_binding(scope, name.clone(), binding, scope_log);
        }
        Node::Assign { name, value } => {
            match scope.get(name.as_str()) {
                // Immutable bindings are the ones created for loop variables.
                Some(target) if !target.mutable => errors.push(err(format!(
                    "V011: assignment to loop variable `{name}`. Fix: loop variables are immutable."
                ))),
                Some(_) => {}
                None => errors.push(err(format!(
                    "assignment to undeclared variable `{name}`. Fix: add `let {name} = ...;` before this assignment."
                ))),
            }
            validate_expr(value, buffers, scope, errors);
        }
        Node::Store {
            buffer,
            index,
            value,
        } => {
            bytes_rejection::check_store(buffer, buffers, errors);
            // The type comparison only applies when both the buffer and the
            // value's type are known.
            let typed_target = buffers
                .get(buffer.as_str())
                .and_then(|decl| expr_type(value, buffers, scope).map(|val_ty| (decl, val_ty)));
            if let Some((decl, val_ty)) = typed_target {
                // Bytes, Bool, and U32 are equivalent on GPU (all lower to u32
                // storage). Bool is the IR-level name for a {0u,1u} u32 result.
                let elem = &decl.element;
                let compatible = val_ty == *elem
                    || matches!(
                        (&val_ty, elem),
                        (DataType::U32, DataType::Bytes)
                            | (DataType::Bytes, DataType::U32)
                            | (DataType::U32, DataType::Bool)
                            | (DataType::Bool, DataType::U32)
                            | (DataType::F32, DataType::F32)
                    );
                if !compatible {
                    errors.push(err(format!(
                        "store value type `{val_ty}` does not match buffer `{buffer}` element type `{elem}`. Fix: insert an explicit cast or use a matching type."
                    )));
                }
            }
            validate_expr(index, buffers, scope, errors);
            validate_expr(value, buffers, scope, errors);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            validate_expr(cond, buffers, scope, errors);
            match expr_type(cond, buffers, scope) {
                // An unknown condition type is not reported here.
                None | Some(DataType::U32 | DataType::Bool) => {}
                Some(cond_ty) => errors.push(err(format!(
                    "if condition must be `u32` or `bool`, got `{cond_ty}`. Fix: cast or rewrite the condition to produce U32 or Bool."
                ))),
            }
            // Both branches are validated as divergent, each in a fresh scope.
            let branches: [&[Node]; 2] = [then, otherwise];
            for branch in branches {
                validate_scoped_nested_nodes(
                    branch,
                    buffers,
                    scope,
                    true,
                    depth,
                    limits,
                    errors,
                    |_, _| {},
                );
            }
        }
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            validate_expr(from, buffers, scope, errors);
            validate_expr(to, buffers, scope, errors);
            match expr_type(from, buffers, scope) {
                Some(from_ty) if from_ty != DataType::U32 => errors.push(err(format!(
                    "V015: loop bound expression must be `u32`, got `{from_ty}`. Fix: ensure `from` and `to` are U32."
                ))),
                _ => {}
            }
            match expr_type(to, buffers, scope) {
                Some(to_ty) if to_ty != DataType::U32 => errors.push(err(format!(
                    "V015: loop bound expression must be `u32`, got `{to_ty}`. Fix: ensure `from` and `to` are U32."
                ))),
                _ => {}
            }
            shadowing::check_local(var, scope, errors);
            // The loop variable is a `U32` that the body may read but not assign.
            let loop_binding = Binding {
                ty: DataType::U32,
                mutable: false,
            };
            // The body is validated as divergent; the loop variable is bound
            // inside the nested scope so it is rolled back afterwards.
            validate_scoped_nested_nodes(
                body,
                buffers,
                scope,
                true,
                depth,
                limits,
                errors,
                |body_scope, undo_log| {
                    insert_binding(body_scope, var.clone(), loop_binding, Some(undo_log));
                },
            );
        }
        Node::Return => {}
        Node::Block(nodes) => {
            // A plain block inherits the surrounding divergence state.
            validate_scoped_nested_nodes(
                nodes,
                buffers,
                scope,
                divergent,
                depth,
                limits,
                errors,
                |_, _| {},
            );
        }
        Node::Barrier => {
            barrier::check_barrier(divergent, errors);
        }
    }
}

fn validate_scoped_nested_nodes(
    nodes: &[Node],
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &mut FxHashMap<String, Binding>,
    divergent: bool,
    depth: usize,
    limits: &mut LimitState,
    errors: &mut Vec<ValidationError>,
    configure_scope: impl FnOnce(&mut FxHashMap<String, Binding>, &mut ScopeLog),
) {
    let mut scope_log = Vec::new();
    configure_scope(scope, &mut scope_log);
    validate_nodes_inner(
        nodes,
        buffers,
        scope,
        divergent,
        depth.saturating_add(1),
        limits,
        errors,
        Some(&mut scope_log),
    );
    restore_scope(scope, scope_log);
}

/// Binds `name` to `binding` in `scope`, replacing any existing binding.
///
/// When `scope_log` is present, the name and the binding it replaced (`None`
/// if the name was unbound) are appended to it so the insertion can be undone
/// later. When it is absent, nothing is recorded.
fn insert_binding(
    scope: &mut FxHashMap<String, Binding>,
    name: String,
    binding: Binding,
    scope_log: Option<&mut ScopeLog>,
) {
    match scope_log {
        Some(log) => {
            // Both the map and the log need an owned name, so clone once here.
            let previous = scope.insert(name.clone(), binding);
            log.push((name, previous));
        }
        None => {
            // No log to feed: move the name into the map without cloning it.
            scope.insert(name, binding);
        }
    }
}

/// Undoes the insertions recorded in `scope_log`, newest first.
///
/// For each entry, the previous binding is put back if there was one;
/// otherwise the name is removed from `scope`. Replaying newest-first means
/// that, for a name bound several times, the oldest recorded state wins.
fn restore_scope(scope: &mut FxHashMap<String, Binding>, scope_log: ScopeLog) {
    for (name, previous) in scope_log.into_iter().rev() {
        match previous {
            Some(binding) => {
                scope.insert(name, binding);
            }
            None => {
                scope.remove(&name);
            }
        }
    }
}