use super::{append_wgsl, emit_expr_string};
use crate::ir::model::expr::Expr;
use crate::ir::model::program::Program;
use crate::ir::model::types::{BinOp, DataType, UnOp};
use crate::lower::wgsl::{Error, LowerCtx};
/// Lowers the binary operation `left op right` into WGSL, appending it to `out`.
///
/// `ty` is the expression's result type when known; it selects between the
/// signed, unsigned and floating-point spellings of `Div` and `Mod`. Integer
/// division and remainder are routed through the `_vyre_safe_*` prelude
/// helpers (presumably guarding against a zero divisor — TODO confirm against
/// the prelude), while `f32` uses WGSL's native operators. Comparisons and the
/// logical connectives produce `0u`/`1u` values rather than `bool`.
///
/// # Errors
/// Propagates any error raised while lowering `left` or `right`.
pub(super) fn emit_binop(
    out: &mut String,
    op: BinOp,
    left: &Expr,
    right: &Expr,
    ty: Option<DataType>,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    match op {
        BinOp::Add => emit_infix(out, left, " + ", right, program, ctx),
        BinOp::Sub => emit_infix(out, left, " - ", right, program, ctx),
        BinOp::Mul => emit_infix(out, left, " * ", right, program, ctx),
        BinOp::Div => match ty {
            Some(DataType::I32) => {
                emit_call2(out, "_vyre_safe_div_i32", left, right, program, ctx)
            }
            Some(DataType::F32) => emit_infix(out, left, " / ", right, program, ctx),
            _ => emit_call2(out, "_vyre_safe_div_u32", left, right, program, ctx),
        },
        BinOp::Mod => match ty {
            // WGSL performs no implicit conversions, so f32 operands cannot be
            // passed to the u32 helper; WGSL's `%` is defined for f32.
            Some(DataType::F32) => emit_infix(out, left, " % ", right, program, ctx),
            // NOTE(review): i32 operands still reach the u32 helper; there is no
            // visible `_vyre_safe_mod_i32` counterpart to `_vyre_safe_div_i32`.
            _ => emit_call2(out, "_vyre_safe_mod_u32", left, right, program, ctx),
        },
        BinOp::BitAnd => emit_infix(out, left, " & ", right, program, ctx),
        BinOp::BitOr => emit_infix(out, left, " | ", right, program, ctx),
        BinOp::BitXor => emit_infix(out, left, " ^ ", right, program, ctx),
        BinOp::Shl => emit_shift(out, left, " << ", right, program, ctx),
        BinOp::Shr => emit_shift(out, left, " >> ", right, program, ctx),
        BinOp::Eq => emit_comparison(out, left, " == ", right, program, ctx),
        BinOp::Ne => emit_comparison(out, left, " != ", right, program, ctx),
        BinOp::Lt => emit_comparison(out, left, " < ", right, program, ctx),
        BinOp::Gt => emit_comparison(out, left, " > ", right, program, ctx),
        BinOp::Le => emit_comparison(out, left, " <= ", right, program, ctx),
        BinOp::Ge => emit_comparison(out, left, " >= ", right, program, ctx),
        BinOp::And => emit_bool_pair(out, left, " && ", right, program, ctx),
        BinOp::Or => emit_bool_pair(out, left, " || ", right, program, ctx),
        BinOp::AbsDiff => emit_abs_diff(out, left, right, program, ctx),
        BinOp::Min => emit_call2(out, "min", left, right, program, ctx),
        BinOp::Max => emit_call2(out, "max", left, right, program, ctx),
    }
}
/// Lowers the unary operation `op` applied to `operand` into WGSL, appending to `out`.
///
/// `ty` is the expression's result type when known and only affects `Negate`.
/// WGSL's unary `-` is not defined for `u32`, so unsigned negation is written
/// as the two's-complement identity `(~x + 1u)`. `LogicalNot` yields `1u` when
/// the operand equals `0u` and `0u` otherwise. The remaining operators map
/// one-to-one onto WGSL builtin calls or the float classification predicates.
///
/// # Errors
/// Propagates any error raised while lowering `operand`.
pub(super) fn emit_unop(
    out: &mut String,
    op: UnOp,
    operand: &Expr,
    ty: Option<DataType>,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    let negating_unsigned = matches!(ty, Some(DataType::U32));
    match op {
        UnOp::Negate if negating_unsigned => {
            out.push_str("(~");
            emit_expr_string(out, operand, program, ctx)?;
            out.push_str(" + 1u)");
            Ok(())
        }
        UnOp::Negate => emit_prefix(out, "-", operand, program, ctx),
        UnOp::BitNot => emit_prefix(out, "~", operand, program, ctx),
        UnOp::LogicalNot => {
            out.push_str("select(0u, 1u, (");
            emit_expr_string(out, operand, program, ctx)?;
            out.push_str(" == 0u))");
            Ok(())
        }
        // Integer bit-manipulation builtins.
        UnOp::Popcount => emit_call1(out, "countOneBits", operand, program, ctx),
        UnOp::Clz => emit_call1(out, "countLeadingZeros", operand, program, ctx),
        UnOp::Ctz => emit_call1(out, "countTrailingZeros", operand, program, ctx),
        UnOp::ReverseBits => emit_call1(out, "reverseBits", operand, program, ctx),
        // Numeric builtins.
        UnOp::Sin => emit_call1(out, "sin", operand, program, ctx),
        UnOp::Cos => emit_call1(out, "cos", operand, program, ctx),
        UnOp::Abs => emit_call1(out, "abs", operand, program, ctx),
        UnOp::Sqrt => emit_call1(out, "sqrt", operand, program, ctx),
        UnOp::Floor => emit_call1(out, "floor", operand, program, ctx),
        UnOp::Ceil => emit_call1(out, "ceil", operand, program, ctx),
        UnOp::Round => emit_call1(out, "round", operand, program, ctx),
        UnOp::Trunc => emit_call1(out, "trunc", operand, program, ctx),
        UnOp::Sign => emit_call1(out, "sign", operand, program, ctx),
        // Float classification, materialised as 0u / 1u.
        UnOp::IsNan => emit_predicate(out, "isNan", operand, program, ctx),
        UnOp::IsInf => emit_predicate(out, "isInf", operand, program, ctx),
        UnOp::IsFinite => emit_predicate(out, "isFinite", operand, program, ctx),
    }
}
/// Emits `(left operator right)`, with the whole expression parenthesized so
/// it can be nested safely inside any other expression.
///
/// `operator` is written verbatim, so callers include any surrounding spaces.
fn emit_infix(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push('(');
    for (position, side) in [left, right].iter().copied().enumerate() {
        if position != 0 {
            out.push_str(operator);
        }
        emit_expr_string(out, side, program, ctx)?;
    }
    out.push(')');
    Ok(())
}
/// Emits a one-argument function call `name(operand)`.
///
/// The opening `name(` is written through `append_wgsl`, and any error it
/// reports is propagated. The closing parenthesis is only written if the
/// operand lowered successfully.
fn emit_call1(
    out: &mut String,
    name: &str,
    operand: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    append_wgsl(out, format_args!("{name}("))?;
    emit_expr_string(out, operand, program, ctx).map(|()| out.push(')'))
}
/// Emits a two-argument function call `name(left, right)`.
///
/// Used for both WGSL builtins (e.g. `min`, `max`) and the `_vyre_safe_*`
/// helper functions. Lowering errors from either argument are propagated.
fn emit_call2(
    out: &mut String,
    name: &str,
    left: &Expr,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    append_wgsl(out, format_args!("{name}("))?;
    for (index, argument) in [left, right].iter().copied().enumerate() {
        if index > 0 {
            out.push_str(", ");
        }
        emit_expr_string(out, argument, program, ctx)?;
    }
    out.push(')');
    Ok(())
}
/// Emits a prefix unary operator applied to `operand`, as `(operator(operand))`.
///
/// The operand is always parenthesized. Without that, an operand whose emitted
/// text begins with `-` (for example a nested negation) would produce `--x`.
/// WGSL lexes `--` as the decrement token, which is invalid inside an
/// expression. The whole result is also parenthesized, matching `emit_infix`,
/// so it composes safely with postfix or infix syntax around it.
///
/// # Errors
/// Propagates any error raised while lowering `operand`.
fn emit_prefix(
    out: &mut String,
    operator: &str,
    operand: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push('(');
    out.push_str(operator);
    out.push('(');
    emit_expr_string(out, operand, program, ctx)?;
    out.push_str("))");
    Ok(())
}
/// Emits a shift `(left operator (right & 31u))`.
///
/// The shift amount is masked to its low five bits, so the emitted shift count
/// always lies in `0..=31`. NOTE(review): `31u` is a u32 literal, so this
/// assumes `right` lowers to a u32 expression — confirm for other types.
fn emit_shift(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    const MASK_AND_CLOSE: &str = " & 31u))";
    out.push('(');
    emit_expr_string(out, left, program, ctx)?;
    out.extend([operator, "("]);
    emit_expr_string(out, right, program, ctx)?;
    out.push_str(MASK_AND_CLOSE);
    Ok(())
}
/// Emits a comparison as a `0u`/`1u` value:
/// `select(0u, 1u, (left operator right))`.
///
/// `select` returns its second argument (`1u`) when the condition holds, so
/// the result is `1u` for true and `0u` for false.
fn emit_comparison(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, (");
    // Each operand is followed by the text that comes after it.
    for (operand, trailer) in [(left, operator), (right, "))")] {
        emit_expr_string(out, operand, program, ctx)?;
        out.push_str(trailer);
    }
    Ok(())
}
/// Emits a logical connective over two `0u`/`1u`-style operands:
/// `select(0u, 1u, ((left != 0u) operator (right != 0u)))`.
///
/// Each operand is first converted to `bool` by testing against `0u`, and the
/// `bool` result is then mapped back to `0u`/`1u`.
fn emit_bool_pair(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, (");
    for (index, operand) in [left, right].iter().copied().enumerate() {
        if index == 1 {
            out.push_str(operator);
        }
        out.push('(');
        emit_expr_string(out, operand, program, ctx)?;
        out.push_str(" != 0u)");
    }
    out.push_str("))");
    Ok(())
}
/// Emits the absolute difference of two operands as
/// `select((left - right), (right - left), (left < right))`.
///
/// When `left < right` the result is `right - left`, otherwise `left - right`,
/// so the subtraction never goes below zero. Each operand's lowered text
/// appears three times in the output. NOTE(review): any side effects in an
/// operand expression would therefore be repeated.
fn emit_abs_diff(
    out: &mut String,
    left: &Expr,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // Operands in emission order, each paired with the text that follows it.
    let pieces: [(&Expr, &str); 6] = [
        (left, " - "),
        (right, "), ("),
        (right, " - "),
        (left, "), ("),
        (left, " < "),
        (right, "))"),
    ];
    out.push_str("select((");
    for (operand, suffix) in pieces {
        emit_expr_string(out, operand, program, ctx)?;
        out.push_str(suffix);
    }
    Ok(())
}
/// Emits a float classification predicate as a `0u`/`1u` value.
///
/// WGSL defines no `isNan`, `isInf` or `isFinite` builtins. Those three names
/// are therefore lowered to IEEE-754 binary32 bit tests on `bitcast<u32>(x)`:
/// - NaN: exponent all ones with a non-zero mantissa (`abs bits > 0x7f800000`).
/// - Inf: exponent all ones with a zero mantissa (`abs bits == 0x7f800000`).
/// - Finite: exponent not all ones.
///
/// Bit tests are used instead of comparisons such as `x != x` so the checks
/// do not depend on the implementation preserving NaN/Inf semantics. Any other
/// `name` is emitted as a plain call, `select(0u, 1u, name(x))`.
/// NOTE(review): the bit tests assume a 32-bit float operand — confirm that no
/// other float widths reach this path.
///
/// # Errors
/// Propagates any error raised while lowering `operand` or writing output.
fn emit_predicate(
    out: &mut String,
    name: &str,
    operand: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // `(bitcast<u32>(x) & mask) comparison reference`
    let bit_test = match name {
        "isNan" => Some(("0x7fffffffu", " > ", "0x7f800000u")),
        "isInf" => Some(("0x7fffffffu", " == ", "0x7f800000u")),
        "isFinite" => Some(("0x7f800000u", " != ", "0x7f800000u")),
        _ => None,
    };
    match bit_test {
        Some((mask, comparison, reference)) => {
            out.push_str("select(0u, 1u, ((bitcast<u32>(");
            emit_expr_string(out, operand, program, ctx)?;
            append_wgsl(out, format_args!(") & {mask}){comparison}{reference}))"))?;
        }
        None => {
            append_wgsl(out, format_args!("select(0u, 1u, {name}("))?;
            emit_expr_string(out, operand, program, ctx)?;
            out.push_str("))");
        }
    }
    Ok(())
}