// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
use crate::ir::model::expr::Expr;
use crate::ir::model::program::BufferDecl;
use crate::ir::model::types::{BinOp, DataType};
use crate::ir::validate::{err, Binding, ValidationError};
use rustc_hash::FxHashMap;

/// Validates one operand of a binary operation against the 32-bit scalar types.
///
/// `side` is interpolated verbatim into the diagnostic (presumably the caller
/// passes something like "left"/"right" — confirm against call sites).
///
/// The operand's type is inferred with [`expr_type`]:
/// - If it cannot be inferred, the operand is accepted without comment.
/// - If it is one of `U32`, `I32`, `Bool`, `Bytes` or `F32`, nothing is reported.
/// - For any other inferred type, exactly one error is appended to `errors`.
#[inline]
pub(crate) fn validate_u32_binop_operand(
    side: &str,
    expr: &Expr,
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &FxHashMap<String, Binding>,
    errors: &mut Vec<ValidationError>,
) {
    let ty = match expr_type(expr, buffers, scope) {
        Some(ty) => ty,
        // Unknown operand type: there is nothing to check here.
        None => return,
    };

    // Every 32-bit scalar is binop-compatible. Bytes, Bool and I32 all lower to
    // u32 on GPU (Bool as {0u,1u}, I32 as a two's-complement bit
    // reinterpretation). F32 is accepted as the floating-point 32-bit class.
    let is_32bit_scalar = matches!(
        ty,
        DataType::U32 | DataType::I32 | DataType::Bool | DataType::Bytes | DataType::F32
    );
    if is_32bit_scalar {
        return;
    }

    // NOTE(review): the message names only `u32`/`f32`, although I32, Bool and
    // Bytes are accepted as well.
    errors.push(err(format!(
        "binary operation {side} operand must be `u32` or `f32`, got `{ty}`. Fix: cast or rewrite the operand to produce U32 or F32."
    )));
}

/// Validates the operand of a unary operation against the types `op` accepts.
///
/// The operand's type is inferred with [`expr_type`]. If it cannot be inferred,
/// nothing is checked. Otherwise an inferred type outside the accepted set
/// appends exactly one error to `errors`.
///
/// Accepted operand types per operation:
/// - `Negate`: `U32`, `I32`, `F32`
/// - `LogicalNot`: `U32`, `Bool`, `I32`
/// - `BitNot`, `Popcount`, `Clz`, `Ctz`, `ReverseBits`: `U32`, `I32`
/// - `Sin`, `Cos`, `Abs`, `Sqrt`, `Floor`, `Ceil`, `Round`, `Trunc`, `Sign`,
///   `IsNan`, `IsInf`, `IsFinite`: `F32` only
#[inline]
pub(crate) fn validate_unop_operand(
    op: &crate::ir::model::types::UnOp,
    expr: &Expr,
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &FxHashMap<String, Binding>,
    errors: &mut Vec<ValidationError>,
) {
    let ty = match expr_type(expr, buffers, scope) {
        Some(ty) => ty,
        None => return,
    };

    // Each arm evaluates to `Some(message)` when the operand type is rejected.
    let rejection = match op {
        crate::ir::model::types::UnOp::Negate => {
            // I32 is the canonical signed form (primitive.math.neg) and F32
            // gets IEEE-754 negation. U32 is accepted too, keeping the
            // two's-complement reading for bitwise-negation composition callers.
            let accepted = matches!(ty, DataType::U32 | DataType::I32 | DataType::F32);
            (!accepted).then(|| {
                format!(
                    "unary operation `{op:?}` operand must be a signed scalar, got `{ty}`. Fix: cast or rewrite the operand to U32/I32/F32."
                )
            })
        }
        crate::ir::model::types::UnOp::LogicalNot => {
            let accepted = matches!(ty, DataType::U32 | DataType::Bool | DataType::I32);
            (!accepted).then(|| {
                format!(
                    "unary operation `{op:?}` operand must be `u32` or `bool`, got `{ty}`. Fix: cast or rewrite the operand to produce U32."
                )
            })
        }
        crate::ir::model::types::UnOp::BitNot
        | crate::ir::model::types::UnOp::Popcount
        | crate::ir::model::types::UnOp::Clz
        | crate::ir::model::types::UnOp::Ctz
        | crate::ir::model::types::UnOp::ReverseBits => {
            let accepted = matches!(ty, DataType::U32 | DataType::I32);
            (!accepted).then(|| {
                format!(
                    "unary operation `{op:?}` operand must be an integer, got `{ty}`. Fix: cast or rewrite the operand to produce U32 or I32."
                )
            })
        }
        crate::ir::model::types::UnOp::Sin
        | crate::ir::model::types::UnOp::Cos
        | crate::ir::model::types::UnOp::Abs
        | crate::ir::model::types::UnOp::Sqrt
        | crate::ir::model::types::UnOp::Floor
        | crate::ir::model::types::UnOp::Ceil
        | crate::ir::model::types::UnOp::Round
        | crate::ir::model::types::UnOp::Trunc
        | crate::ir::model::types::UnOp::Sign
        | crate::ir::model::types::UnOp::IsNan
        | crate::ir::model::types::UnOp::IsInf
        | crate::ir::model::types::UnOp::IsFinite => (ty != DataType::F32).then(|| {
            format!(
                "unary operation `{op:?}` operand must be `f32`, got `{ty}`. Fix: cast or rewrite the operand to produce F32."
            )
        }),
    };

    if let Some(message) = rejection {
        errors.push(err(message));
    }
}

/// Infer the static type of an expression, if it can be determined from the IR.
///
/// The walk uses an explicit work stack plus a result stack instead of native
/// recursion. Every visited sub-expression leaves exactly one
/// `Option<DataType>` on the result stack, and combining tasks fold their
/// children's results.
///
/// Typing rules, as implemented:
/// - Literals: their own type.
/// - `Var`: the binding's `ty` from `scope`. `Load`: the `element` type of the
///   named buffer. Either yields `None` when the name is not found.
/// - `BufLen`, `InvocationId`, `WorkgroupId`, `LocalId`, `Atomic`: `U32`.
/// - `Call`: `None`, since it is not inferred here.
/// - `Cast`: the target type. The inner expression is not inspected.
/// - `Add`/`Sub`/`Mul`/`Div`/`Min`/`Max`: the operands' type when both agree.
///   Otherwise, including when both are unknown, the result is `U32`.
///   NOTE(review): an unknown operand mixed with a known one (e.g. `F32`) is
///   therefore typed `U32`. Confirm this is intended.
/// - Every other binary operator: `U32`.
/// - `Negate`, `BitNot`, `Popcount`, `Clz`, `Ctz`, `ReverseBits`: the operand's
///   type. `LogicalNot`: `U32`. `Sin` … `Sign`: `F32`. `IsNan`/`IsInf`/`IsFinite`:
///   `Bool`.
/// - `Select`: the branch type when both branches agree, else `None`. The
///   condition is not inspected.
/// - `Fma`: `F32` when all three inputs are `F32`, else `None`.
#[inline]
pub(crate) fn expr_type(
    expr: &Expr,
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &FxHashMap<String, Binding>,
) -> Option<DataType> {
    /// One unit of pending work for the type walk.
    enum Task<'e> {
        /// Type this sub-expression, leaving one result on the result stack.
        Visit(&'e Expr),
        /// Fold the two topmost results (left, then right) of an arithmetic binop.
        Arith,
        /// Fold the two topmost results (true, then false) of a select.
        Select,
        /// Fold the three topmost results (a, b, c) of a fused multiply-add.
        Fma,
    }

    let mut todo = vec![Task::Visit(expr)];
    let mut types: Vec<Option<DataType>> = Vec::new();

    while let Some(task) = todo.pop() {
        let node = match task {
            Task::Visit(node) => node,
            Task::Arith => {
                let rhs = types.pop().flatten();
                let lhs = types.pop().flatten();
                let shared = if lhs == rhs { lhs } else { None };
                types.push(Some(shared.unwrap_or(DataType::U32)));
                continue;
            }
            Task::Select => {
                let on_false = types.pop().flatten();
                let on_true = types.pop().flatten();
                let merged = if on_true == on_false { on_true } else { None };
                types.push(merged);
                continue;
            }
            Task::Fma => {
                let c = types.pop().flatten();
                let b = types.pop().flatten();
                let a = types.pop().flatten();
                let all_f32 = [a, b, c].iter().all(|t| *t == Some(DataType::F32));
                types.push(if all_f32 { Some(DataType::F32) } else { None });
                continue;
            }
        };

        // Leaves resolve immediately. Composites schedule their fold task
        // underneath their children (children pushed in reverse so the first
        // child is typed first) and `continue`.
        let ty = match node {
            Expr::LitU32(_) => Some(DataType::U32),
            Expr::LitI32(_) => Some(DataType::I32),
            Expr::LitF32(_) => Some(DataType::F32),
            Expr::LitBool(_) => Some(DataType::Bool),
            Expr::Var(name) => scope.get(name.as_str()).map(|binding| binding.ty.clone()),
            Expr::Load { buffer, .. } => {
                buffers.get(buffer.as_str()).map(|decl| decl.element.clone())
            }
            Expr::BufLen { .. }
            | Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. }
            | Expr::Atomic { .. } => Some(DataType::U32),
            Expr::Call { .. } => None,
            Expr::Cast { target, .. } => Some(target.clone()),
            Expr::BinOp { op, left, right } => match op {
                BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Min | BinOp::Max => {
                    todo.push(Task::Arith);
                    todo.push(Task::Visit(right));
                    todo.push(Task::Visit(left));
                    continue;
                }
                _ => Some(DataType::U32),
            },
            Expr::UnOp { op, operand } => match op {
                // Type-preserving ops: the operand's result *is* this node's
                // result, so descend without scheduling a fold.
                crate::ir::model::types::UnOp::Negate
                | crate::ir::model::types::UnOp::BitNot
                | crate::ir::model::types::UnOp::Popcount
                | crate::ir::model::types::UnOp::Clz
                | crate::ir::model::types::UnOp::Ctz
                | crate::ir::model::types::UnOp::ReverseBits => {
                    todo.push(Task::Visit(operand));
                    continue;
                }
                crate::ir::model::types::UnOp::LogicalNot => Some(DataType::U32),
                crate::ir::model::types::UnOp::Sin
                | crate::ir::model::types::UnOp::Cos
                | crate::ir::model::types::UnOp::Abs
                | crate::ir::model::types::UnOp::Sqrt
                | crate::ir::model::types::UnOp::Floor
                | crate::ir::model::types::UnOp::Ceil
                | crate::ir::model::types::UnOp::Round
                | crate::ir::model::types::UnOp::Trunc
                | crate::ir::model::types::UnOp::Sign => Some(DataType::F32),
                crate::ir::model::types::UnOp::IsNan
                | crate::ir::model::types::UnOp::IsInf
                | crate::ir::model::types::UnOp::IsFinite => Some(DataType::Bool),
            },
            Expr::Select {
                true_val,
                false_val,
                ..
            } => {
                todo.push(Task::Select);
                todo.push(Task::Visit(false_val));
                todo.push(Task::Visit(true_val));
                continue;
            }
            Expr::Fma { a, b, c } => {
                todo.push(Task::Fma);
                todo.push(Task::Visit(c));
                todo.push(Task::Visit(b));
                todo.push(Task::Visit(a));
                continue;
            }
        };
        types.push(ty);
    }

    types.pop().flatten()
}