vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Validation rules for atomic memory operations.
//!
//! Atomic operations are among the most error-prone primitives in GPU
//! compute: they require a read-write buffer, a `u32` element type, and
//! (for compare-exchange) a correctly supplied expected value. This
//! module checks all of those preconditions so that malformed atomics
//! are caught at IR validation time rather than producing silent data
//! races on the GPU.

use crate::ir::model::expr::Expr;
use crate::ir::model::program::BufferDecl;
use crate::ir::model::types::{AtomicOp, BufferAccess, DataType};
use crate::ir::validate::typecheck::expr_type;
use crate::ir::validate::{err, Binding, ValidationError};
use rustc_hash::FxHashMap;

/// Validate an `Expr::Atomic` against buffer and type rules.
///
/// If `buffer` is not declared in `buffers`, a single "unknown buffer" error
/// is recorded and no further checks run (there is no access mode or element
/// type to inspect). Otherwise the validator enforces these invariants:
/// 1. The target buffer is declared with `BufferAccess::ReadWrite`.
///    `BufferAccess::Workgroup` targets are rejected with `V025` (the
///    workgroup-atomic memory model is not yet specified); every other
///    non-`ReadWrite` access mode is rejected with `V009`.
/// 2. The buffer's element type is `U32` (`V014`). A `Bytes` buffer is also
///    reported with `V013`, so it produces both errors.
/// 3. The value operand is `U32`.
/// 4. For `AtomicOp::CompareExchange`, an expected operand is present and
///    is also `U32`; for all other ops, no expected operand is present.
///
/// The operand type checks (3, and the `U32` half of 4) only run when
/// `expr_type` can infer the operand's type; an uninferable operand is not
/// reported here. The `_index` operand is not inspected by this function.
///
/// Violations are appended to `errors` as `ValidationError` values with
/// actionable `Fix:` hints.
///
/// # Examples
///
/// ```
/// use vyre::ir::validate::atomic_rules::validate_atomic;
/// use vyre::ir::model::types::AtomicOp;
/// use vyre::ir::model::program::BufferDecl;
/// use vyre::ir::model::expr::Expr;
/// use vyre::ir::validate::{Binding, ValidationError};
/// use rustc_hash::FxHashMap;
///
/// let mut errors: Vec<ValidationError> = Vec::new();
/// let buffers: FxHashMap<&str, &BufferDecl> = FxHashMap::default();
/// let scope: FxHashMap<String, Binding> = FxHashMap::default();
///
/// // Missing buffer declaration -> error
/// validate_atomic(
///     &AtomicOp::Add,
///     "missing",
///     &Expr::u32(0),
///     None,
///     &Expr::u32(1),
///     &buffers,
///     &scope,
///     &mut errors,
/// );
/// assert!(!errors.is_empty());
/// ```
///
/// # Errors
///
/// Appends a `ValidationError` when any of the invariants above is
/// violated, or when `buffer` is not declared.
#[inline]
pub fn validate_atomic(
    op: &AtomicOp,
    buffer: &str,
    _index: &Expr,
    expected: Option<&Expr>,
    value: &Expr,
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &FxHashMap<String, Binding>,
    errors: &mut Vec<ValidationError>,
) {
    let buf = match buffers.get(buffer) {
        Some(buf) => buf,
        None => {
            errors.push(err(format!(
                "atomic on unknown buffer `{buffer}`. Fix: declare it in Program::buffers."
            )));
            // Without a declaration there is no access mode or element type
            // to validate, so stop after the single unknown-buffer error.
            return;
        }
    };

    // L.1.36 / audit finding #5: split the "non-writable" check so
    // Workgroup buffers get their own V025 code. The vyre atomic
    // memory model is currently only defined for `ReadWrite`
    // storage buffers; Workgroup atomics need additional OOB and
    // ordering specification that has not yet been committed. V009
    // stays reserved for `ReadOnly`/`Uniform` targets only.
    match &buf.access {
        BufferAccess::ReadWrite => {}
        BufferAccess::Workgroup => {
            errors.push(err(format!(
                "V025: atomic on workgroup buffer `{buffer}` is not yet specified. Fix: use a storage ReadWrite buffer for atomics, or wait for the workgroup-atomic memory model to land."
            )));
        }
        _ => {
            errors.push(err(format!(
                "V009: atomic on non-writable buffer `{buffer}`. Fix: declare it with BufferAccess::ReadWrite."
            )));
        }
    }

    // A `Bytes` buffer trips both this generic V013 check and the
    // atomic-specific V014 check below.
    if buf.element == DataType::Bytes {
        errors.push(err(format!(
            "V013: operation on buffer `{buffer}` with element type `bytes` is not supported. Fix: use a typed buffer."
        )));
    }
    if buf.element != DataType::U32 {
        errors.push(err(format!(
            "V014: atomic on buffer `{buffer}` with non-u32 element type `{elem}`. Fix: atomics only support U32 elements.",
            elem = buf.element
        )));
    }

    // Only report a mismatch when the operand's type can be inferred.
    if let Some(val_ty) = expr_type(value, buffers, scope) {
        if val_ty != DataType::U32 {
            errors.push(err(format!(
                "atomic value type `{val_ty}` does not match required `u32`. Fix: ensure the atomic operand is U32."
            )));
        }
    }

    // `expected` must be present exactly when the op is CompareExchange.
    match (op, expected) {
        (AtomicOp::CompareExchange, Some(expected_expr)) => {
            if let Some(expected_ty) = expr_type(expected_expr, buffers, scope) {
                if expected_ty != DataType::U32 {
                    errors.push(err(format!(
                        "compare-exchange expected type `{expected_ty}` does not match required `u32`. Fix: ensure Expr::Atomic.expected is U32."
                    )));
                }
            }
        }
        (AtomicOp::CompareExchange, None) => errors.push(err(
            "compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange."
                .to_string(),
        )),
        (_, Some(_)) => errors.push(err(
            "non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange."
                .to_string(),
        )),
        (_, None) => {}
    }
}