vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::model::expr::Expr;
use crate::ir::model::program::Program;
use crate::ir::model::types::BufferAccess;
use crate::lower::wgsl::analysis::expr_type;
use crate::lower::wgsl::emit::axis_letter;
use crate::lower::wgsl::{Error, LowerCtx};
use std::fmt::{Arguments, Write as _};

mod atomic;
mod cast;
mod operators;

/// Appends the WGSL spelling of `expr` to `out`.
///
/// Nested operands are emitted recursively through [`emit_expr_string`].
/// Boolean literals are lowered as `1u` / `0u`, and `Select` conditions are
/// compared against `0u` to match that encoding.
///
/// # Errors
///
/// Returns a lowering [`Error`] when:
/// - a `Load` or `BufLen` references a buffer that `program` does not declare;
/// - an `f32` literal is NaN or infinite, which WGSL cannot spell as a literal;
/// - a `Call` reaches emission (calls must be inlined beforehand);
/// - the source type of a `Cast` cannot be determined;
/// - an `Atomic` targets a buffer not recorded in `ctx.atomic_buffers`;
/// - or emitting any nested operand fails.
#[inline]
pub(crate) fn emit_expr(
    out: &mut String,
    expr: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    match expr {
        Expr::LitU32(v) => append_wgsl(out, format_args!("{v}u")),
        // `-2147483648i` is not valid WGSL: the suffixed literal `2147483648i` is out
        // of i32 range before the unary minus applies. Negate an AbstractInt and
        // convert it instead.
        Expr::LitI32(v) if *v == i32::MIN => append_wgsl(out, format_args!("i32({v})")),
        Expr::LitI32(v) => append_wgsl(out, format_args!("{v}i")),
        // `Display` renders non-finite values as `NaN` / `inf`, which would yield
        // `NaNf` / `inff` — tokens WGSL does not accept.
        Expr::LitF32(v) if !v.is_finite() => Err(lowering(format_args!(
            "f32 literal `{v}` is not finite and has no WGSL literal spelling. Fix: reject NaN and infinite constants in validation or rewrite them before lowering."
        ))),
        Expr::LitF32(v) => append_wgsl(out, format_args!("{v}f")),
        Expr::LitBool(true) => {
            out.push_str("1u");
            Ok(())
        }
        Expr::LitBool(false) => {
            out.push_str("0u");
            Ok(())
        }
        Expr::Var(name) => {
            out.push_str(&LowerCtx::safe_var(name));
            Ok(())
        }
        Expr::Load { buffer, index } => {
            // Every declared buffer, workgroup or not, is read through the same
            // `_vyre_load_<name>` accessor, so only existence needs checking here.
            if program.buffer(buffer).is_none() {
                return Err(lowering(format_args!(
                    "load references unknown buffer `{buffer}`. Fix: declare the buffer before lowering or reject the program in validation."
                )));
            }
            append_wgsl(out, format_args!("_vyre_load_{buffer}("))?;
            emit_expr_string(out, index, program, ctx)?;
            out.push(')');
            Ok(())
        }
        Expr::BufLen { buffer } => match program.buffer(buffer) {
            // Workgroup buffers have a declared element count, emitted as a constant.
            Some(decl) if decl.access == BufferAccess::Workgroup => {
                append_wgsl(out, format_args!("{}u", decl.count))
            }
            Some(_) => append_wgsl(out, format_args!("arrayLength(&{buffer}.data)")),
            None => Err(lowering(format_args!(
                "length references unknown buffer `{buffer}`. Fix: declare the buffer before lowering or reject the program in validation."
            ))),
        },
        Expr::InvocationId { axis } => {
            append_wgsl(out, format_args!("_vyre_gid.{}", axis_letter(*axis)?))
        }
        Expr::WorkgroupId { axis } => {
            append_wgsl(out, format_args!("_vyre_wgid.{}", axis_letter(*axis)?))
        }
        Expr::LocalId { axis } => {
            append_wgsl(out, format_args!("_vyre_lid.{}", axis_letter(*axis)?))
        }
        Expr::BinOp { op, left, right } => {
            // The operator is typed by its left operand.
            let ty = expr_type(left, &ctx.buffer_map, &ctx.vars);
            operators::emit_binop(out, op.clone(), left, right, ty, program, ctx)
        }
        Expr::UnOp { op, operand } => {
            let ty = expr_type(operand, &ctx.buffer_map, &ctx.vars);
            operators::emit_unop(out, op.clone(), operand, ty, program, ctx)
        }
        Expr::Call { op_id, .. } => Err(lowering(format_args!(
            "residual call to `{op_id}` reached WGSL emission. Fix: run ir::inline_calls successfully before lowering and ensure the operation is a registered Category A composition."
        ))),
        Expr::Fma { a, b, c } => {
            out.push_str("fma(");
            emit_expr_string(out, a, program, ctx)?;
            out.push_str(", ");
            emit_expr_string(out, b, program, ctx)?;
            out.push_str(", ");
            emit_expr_string(out, c, program, ctx)?;
            out.push(')');
            Ok(())
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            // WGSL `select(f, t, c)` takes the false value first; the condition is
            // a u32-encoded boolean, so it is converted with `!= 0u`.
            out.push_str("select(");
            emit_expr_string(out, false_val, program, ctx)?;
            out.push_str(", ");
            emit_expr_string(out, true_val, program, ctx)?;
            out.push_str(", (");
            emit_expr_string(out, cond, program, ctx)?;
            out.push_str(" != 0u))");
            Ok(())
        }
        Expr::Cast { target, value } => {
            let source = expr_type(value, &ctx.buffer_map, &ctx.vars).ok_or_else(|| {
                lowering(format_args!(
                    "cast source type for target `{target}` is unknown. Fix: validate operation calls and ensure the cast input has a statically known IR type."
                ))
            })?;
            cast::emit_cast_expr(out, source, target.clone(), value, program, ctx)
        }
        Expr::Atomic {
            op,
            buffer,
            index,
            expected,
            value,
        } => {
            if !ctx.atomic_buffers.contains(buffer.as_str()) {
                return Err(lowering(format_args!(
                    "atomic expression targets buffer `{buffer}` but it was not marked atomic by the lowering pre-pass. Fix: run the complete Program entry through WGSL lowering."
                )));
            }
            atomic::emit_atomic_expr(
                out,
                op.clone(),
                buffer,
                index,
                expected.as_deref(),
                value,
                program,
                ctx,
            )
        }
    }
}

#[inline]
pub(crate) fn emit_expr_string(
    out: &mut String,
    expr: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    emit_expr(out, expr, program, ctx)
}

pub(super) fn append_wgsl(out: &mut String, args: Arguments<'_>) -> Result<(), Error> {
    out.write_fmt(args).map_err(|_| {
        Error::lowering(
            "WGSL string write failed. Fix: retry lowering with a valid in-memory String buffer."
                .to_string(),
        )
    })
}

pub(super) fn lowering(args: Arguments<'_>) -> Error {
    Error::lowering(args.to_string())
}

#[cfg(test)]
mod tests;