use crate::ir::model::expr::Expr;
use crate::ir::model::program::Program;
use crate::ir::model::types::BufferAccess;
use crate::lower::wgsl::analysis::expr_type;
use crate::lower::wgsl::emit::axis_letter;
use crate::lower::wgsl::{Error, LowerCtx};
use std::fmt::{Arguments, Write as _};
mod atomic;
mod cast;
mod operators;
/// Appends the WGSL source text for `expr` to `out`.
///
/// This function emits literals, variable references, invocation/workgroup/local ids,
/// buffer loads, buffer lengths, `fma` and `select` itself. Binary/unary operators,
/// casts and atomics are delegated to the `operators`, `cast` and `atomic` submodules.
///
/// # Errors
///
/// Returns a lowering [`Error`] when:
/// - an `f32` literal is NaN or infinite (WGSL has no literal spelling for these);
/// - a `Load` or `BufLen` names a buffer that `program` does not declare;
/// - a residual `Call` reaches emission (calls must be inlined first);
/// - the source type of a `Cast` cannot be inferred;
/// - an `Atomic` targets a buffer not recorded in `ctx.atomic_buffers`;
/// - an axis lookup or any delegated emitter fails.
pub(crate) fn emit_expr(
    out: &mut String,
    expr: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    match expr {
        Expr::LitU32(v) => append_wgsl(out, format_args!("{v}u")),
        // `-2147483648i` is parsed as unary minus applied to `2147483648i`, which is
        // out of range for i32 and rejected by WGSL. The minimum value can only be
        // spelled through an abstract-int conversion.
        Expr::LitI32(i32::MIN) => {
            out.push_str("i32(-2147483648)");
            Ok(())
        }
        Expr::LitI32(v) => append_wgsl(out, format_args!("{v}i")),
        Expr::LitF32(v) => {
            // Rust's `Display` renders non-finite floats as `NaN`/`inf`/`-inf`, which
            // would become the invalid WGSL tokens `NaNf`/`inff`. Fail here with an
            // actionable message instead of producing a shader that cannot compile.
            if !v.is_finite() {
                return Err(lowering(format_args!(
                    "f32 literal `{v}` is not finite and has no WGSL literal spelling. Fix: fold or reject non-finite constants before lowering."
                )));
            }
            append_wgsl(out, format_args!("{v}f"))
        }
        // Booleans are represented as u32 0/1 in the emitted WGSL.
        Expr::LitBool(true) => {
            out.push_str("1u");
            Ok(())
        }
        Expr::LitBool(false) => {
            out.push_str("0u");
            Ok(())
        }
        Expr::Var(name) => {
            out.push_str(&LowerCtx::safe_var(name));
            Ok(())
        }
        Expr::Load { buffer, index } => {
            // Every declared buffer, workgroup or not, is read through the same
            // `_vyre_load_<name>(index)` helper call; only existence is checked here.
            if program.buffer(buffer).is_none() {
                return Err(lowering(format_args!(
                    "load references unknown buffer `{buffer}`. Fix: declare the buffer before lowering or reject the program in validation."
                )));
            }
            append_wgsl(out, format_args!("_vyre_load_{buffer}("))?;
            emit_expr_string(out, index, program, ctx)?;
            out.push(')');
            Ok(())
        }
        Expr::BufLen { buffer } => {
            let b = program.buffer(buffer).ok_or_else(|| {
                lowering(format_args!(
                    "length references unknown buffer `{buffer}`. Fix: declare the buffer before lowering or reject the program in validation."
                ))
            })?;
            // Workgroup buffers emit their declared `count` as a u32 literal; all other
            // buffers are sized at runtime via `arrayLength` on their `data` member.
            if b.access == BufferAccess::Workgroup {
                append_wgsl(out, format_args!("{}u", b.count))
            } else {
                append_wgsl(out, format_args!("arrayLength(&{buffer}.data)"))
            }
        }
        Expr::InvocationId { axis } => {
            append_wgsl(out, format_args!("_vyre_gid.{}", axis_letter(*axis)?))
        }
        Expr::WorkgroupId { axis } => {
            append_wgsl(out, format_args!("_vyre_wgid.{}", axis_letter(*axis)?))
        }
        Expr::LocalId { axis } => {
            append_wgsl(out, format_args!("_vyre_lid.{}", axis_letter(*axis)?))
        }
        Expr::BinOp { op, left, right } => {
            // The operand type (taken from the left side) drives operator selection.
            let ty = expr_type(left, &ctx.buffer_map, &ctx.vars);
            operators::emit_binop(out, op.clone(), left, right, ty, program, ctx)
        }
        Expr::UnOp { op, operand } => {
            let ty = expr_type(operand, &ctx.buffer_map, &ctx.vars);
            operators::emit_unop(out, op.clone(), operand, ty, program, ctx)
        }
        Expr::Call { op_id, .. } => Err(lowering(format_args!(
            "residual call to `{op_id}` reached WGSL emission. Fix: run ir::inline_calls successfully before lowering and ensure the operation is a registered Category A composition."
        ))),
        Expr::Fma { a, b, c } => {
            out.push_str("fma(");
            emit_expr_string(out, a, program, ctx)?;
            out.push_str(", ");
            emit_expr_string(out, b, program, ctx)?;
            out.push_str(", ");
            emit_expr_string(out, c, program, ctx)?;
            out.push(')');
            Ok(())
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            // WGSL `select(f, t, c)` takes the false value first. The condition is a
            // u32-encoded boolean, so it is converted with `!= 0u`.
            out.push_str("select(");
            emit_expr_string(out, false_val, program, ctx)?;
            out.push_str(", ");
            emit_expr_string(out, true_val, program, ctx)?;
            out.push_str(", (");
            emit_expr_string(out, cond, program, ctx)?;
            out.push_str(" != 0u))");
            Ok(())
        }
        Expr::Cast { target, value } => {
            let source = expr_type(value, &ctx.buffer_map, &ctx.vars).ok_or_else(|| {
                lowering(format_args!(
                    "cast source type for target `{target}` is unknown. Fix: validate operation calls and ensure the cast input has a statically known IR type."
                ))
            })?;
            cast::emit_cast_expr(out, source, target.clone(), value, program, ctx)
        }
        Expr::Atomic {
            op,
            buffer,
            index,
            expected,
            value,
        } => {
            if !ctx.atomic_buffers.contains(buffer.as_str()) {
                return Err(lowering(format_args!(
                    "atomic expression targets buffer `{buffer}` but it was not marked atomic by the lowering pre-pass. Fix: run the complete Program entry through WGSL lowering."
                )));
            }
            atomic::emit_atomic_expr(
                out,
                op.clone(),
                buffer,
                index,
                expected.as_deref(),
                value,
                program,
                ctx,
            )
        }
    }
}
/// Appends the WGSL for `expr` to `out`; a forwarding alias of [`emit_expr`].
///
/// # Errors
///
/// Any error produced by [`emit_expr`] is propagated unchanged.
#[inline]
pub(crate) fn emit_expr_string(
    out: &mut String,
    expr: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    emit_expr(out, expr, program, ctx)?;
    Ok(())
}
pub(super) fn append_wgsl(out: &mut String, args: Arguments<'_>) -> Result<(), Error> {
out.write_fmt(args).map_err(|_| {
Error::lowering(
"WGSL string write failed. Fix: retry lowering with a valid in-memory String buffer."
.to_string(),
)
})
}
pub(super) fn lowering(args: Arguments<'_>) -> Error {
Error::lowering(args.to_string())
}
#[cfg(test)]
mod tests;