use super::{append_wgsl, emit_expr_string};
use crate::ir::model::expr::Expr;
use crate::ir::model::program::Program;
use crate::ir::model::types::DataType;
use crate::lower::wgsl::{Error, LowerCtx};
/// Lowers a cast of `value` from `source` to `target`, appending the WGSL text to `out`.
///
/// Identity casts emit the operand unchanged. Same-width scalar reinterpretations use
/// `bitcast`. Conversions to `Bool` collapse the operand to a `0u`/`1u` word via `select`.
/// Widening to `U64` builds a `vec2<u32>` pair. Narrowing from `U64`/`Vec2U32`/`Vec4U32`
/// reads the leading component(s).
///
/// # Errors
///
/// Propagates any error raised while emitting the operand. Returns a lowering error when
/// the `(source, target)` pair has no WGSL translation here; validation is expected to
/// reject such casts before lowering.
pub(super) fn emit_cast_expr(
    out: &mut String,
    source: DataType,
    target: DataType,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    if source == target {
        // Identity cast: nothing to convert.
        return emit_expr_string(out, value, program, ctx);
    }
    match (&source, &target) {
        // Same-width reinterpretations of the raw 32-bit pattern.
        (DataType::U32, DataType::I32) => {
            emit_wrapped(out, "bitcast<i32>(", ")", value, program, ctx)
        }
        (DataType::I32 | DataType::F32, DataType::U32) => {
            emit_wrapped(out, "bitcast<u32>(", ")", value, program, ctx)
        }
        (DataType::U32, DataType::F32) => {
            emit_wrapped(out, "bitcast<f32>(", ")", value, program, ctx)
        }
        // Truth tests that collapse the operand to a 0u/1u word. Bool values are
        // themselves compared against `0u`, i.e. they are u32 words in the output.
        (DataType::U32, DataType::Bool) | (DataType::Bool, DataType::U32) => {
            emit_bool_u32_expr(out, value, "0u", program, ctx)
        }
        (DataType::I32, DataType::Bool) => emit_bool_i32_expr(out, value, program, ctx),
        (DataType::U64 | DataType::Vec2U32, DataType::Bool) => {
            emit_vec2_truth(out, value, program, ctx)
        }
        (DataType::Vec4U32, DataType::Bool) => emit_vec4_truth(out, value, program, ctx),
        (DataType::Bool, DataType::I32) => {
            out.push_str("i32(");
            emit_bool_u32_expr(out, value, "0u", program, ctx)?;
            out.push(')');
            Ok(())
        }
        // Widening into the two-word `vec2<u32>` form used for U64.
        (DataType::U32, DataType::U64) => emit_vec2_scalar_zero(out, value, program, ctx),
        (DataType::I32, DataType::U64) => {
            // First lane: the raw bits. Second lane: all ones when negative, else zero.
            out.push_str("vec2<u32>(bitcast<u32>(");
            emit_expr_string(out, value, program, ctx)?;
            out.push_str("), select(0u, 0xffffffffu, (");
            emit_expr_string(out, value, program, ctx)?;
            out.push_str(" < 0i)))");
            Ok(())
        }
        (DataType::Bool, DataType::U64) => {
            out.push_str("vec2<u32>(");
            emit_bool_u32_expr(out, value, "0u", program, ctx)?;
            out.push_str(", 0u)");
            Ok(())
        }
        // Broadcasting a scalar into every lane of a vector.
        (DataType::U32, DataType::Vec2U32) => emit_vec2_repeated(out, value, program, ctx),
        (DataType::U32, DataType::Vec4U32) => emit_vec4_repeated(out, value, program, ctx),
        (DataType::I32, DataType::Vec2U32) => emit_vec2_bitcast_repeated(out, value, program, ctx),
        (DataType::I32, DataType::Vec4U32) => emit_vec4_bitcast_repeated(out, value, program, ctx),
        (DataType::Bool, DataType::Vec2U32) => emit_vec2_bool(out, value, program, ctx),
        (DataType::Bool, DataType::Vec4U32) => emit_vec4_bool(out, value, program, ctx),
        // Narrowing from vector-shaped values keeps only the leading lane(s).
        (DataType::U64 | DataType::Vec2U32 | DataType::Vec4U32, DataType::U32) => {
            emit_component(out, value, ".x", program, ctx)
        }
        (DataType::U64 | DataType::Vec2U32 | DataType::Vec4U32, DataType::I32) => {
            emit_wrapped(out, "bitcast<i32>((", ").x)", value, program, ctx)
        }
        (DataType::Vec4U32, DataType::Vec2U32 | DataType::U64) => {
            emit_vec4_to_vec2(out, value, program, ctx)
        }
        // U64 and Vec2U32 are both lowered as `vec2<u32>`, so the text passes through.
        (DataType::U64, DataType::Vec2U32) | (DataType::Vec2U32, DataType::U64) => {
            emit_expr_string(out, value, program, ctx)
        }
        _ => Err(Error::lowering(format!(
            "unsupported cast from `{source}` to `{target}` reached WGSL lowering. Fix: run Program validation before lowering and reject unsupported casts."
        ))),
    }
}
/// Emits `value` between the literal `prefix` and `suffix`, e.g. `bitcast<i32>(` … `)`.
///
/// # Errors
///
/// Returns any error produced while emitting `value`. By then `prefix` has already been
/// appended to `out`.
fn emit_wrapped(
    out: &mut String,
    prefix: &str,
    suffix: &str,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str(prefix);
    emit_expr_string(out, value, program, ctx).map(|()| out.push_str(suffix))
}
/// Emits `select(0u, 1u, (value != zero))`: `1u` when `value` differs from the
/// caller-supplied zero literal, otherwise `0u`.
///
/// # Errors
///
/// Returns any error from emitting `value` or from `append_wgsl`.
fn emit_bool_u32_expr(
    out: &mut String,
    value: &Expr,
    zero: &str,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, (");
    emit_expr_string(out, value, program, ctx).and_then(|()| {
        // Close the comparison, the inner parenthesis, and the `select(` call.
        append_wgsl(out, format_args!(" != {zero}))"))
    })
}
/// Emits `select(0u, 1u, (value != 0i))`: `1u` for a non-zero `i32` operand, otherwise `0u`.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_bool_i32_expr(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, (");
    emit_expr_string(out, value, program, ctx).map(|()| out.push_str(" != 0i))"))
}
fn emit_vec2_scalar_zero(
out: &mut String,
value: &Expr,
program: &Program,
ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
out.push_str("vec2<u32>(");
emit_expr_string(out, value, program, ctx)?;
out.push_str(", 0u)");
Ok(())
}
/// Broadcasts a scalar `u32` expression into both lanes: `vec2<u32>(value)`.
///
/// WGSL's single-argument vector constructor splats its operand, so `value` is emitted
/// once instead of once per lane. This keeps the generated text linear in the size of
/// `value` when casts nest. It also avoids duplicating the operand's evaluation, which
/// matters if it can be costly or side-effecting (not visible from here).
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec2_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec2<u32>(");
    emit_expr_string(out, value, program, ctx)?;
    out.push(')');
    Ok(())
}
/// Broadcasts a scalar `u32` expression into all four lanes: `vec4<u32>(value)`.
///
/// Uses WGSL's splat constructor so `value` is emitted once rather than four times. The
/// previous form quadrupled the operand text at every level of nested casts.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec4_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec4<u32>(");
    emit_expr_string(out, value, program, ctx)?;
    out.push(')');
    Ok(())
}
/// Reinterprets an `i32` expression as `u32` and broadcasts it into both lanes:
/// `vec2<u32>(bitcast<u32>(value))`.
///
/// The bitcast is emitted once and splatted by the single-argument constructor, so
/// `value` does not appear twice in the output.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec2_bitcast_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec2<u32>(bitcast<u32>(");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str("))");
    Ok(())
}
/// Reinterprets an `i32` expression as `u32` and broadcasts it into all four lanes:
/// `vec4<u32>(bitcast<u32>(value))`.
///
/// One bitcast is splatted by the constructor instead of emitting `value` four times.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec4_bitcast_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec4<u32>(bitcast<u32>(");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str("))");
    Ok(())
}
/// Normalizes a bool word to `0u`/`1u` and broadcasts it into both lanes:
/// `vec2<u32>(select(0u, 1u, (value != 0u)))`.
///
/// The `select` is emitted once and splatted, so `value` is not duplicated.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec2_bool(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec2<u32>(");
    emit_bool_u32_expr(out, value, "0u", program, ctx)?;
    out.push(')');
    Ok(())
}
/// Normalizes a bool word to `0u`/`1u` and broadcasts it into all four lanes:
/// `vec4<u32>(select(0u, 1u, (value != 0u)))`.
///
/// The `select` is emitted once and splatted, instead of appearing four times.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec4_bool(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec4<u32>(");
    emit_bool_u32_expr(out, value, "0u", program, ctx)?;
    out.push(')');
    Ok(())
}
/// Emits `(value)` followed by `suffix`, e.g. `(value).x` for a component access.
///
/// # Errors
///
/// Returns any error produced while emitting `value`. By then the opening `(` has
/// already been appended to `out`.
fn emit_component(
    out: &mut String,
    value: &Expr,
    suffix: &str,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push('(');
    emit_expr_string(out, value, program, ctx).map(|()| {
        out.push(')');
        out.push_str(suffix);
    })
}
/// Emits `1u` when any lane of a two-lane `u32` vector is non-zero, else `0u`:
/// `select(0u, 1u, any((value) != vec2<u32>(0u)))`.
///
/// The component-wise `!=` yields a `vec2<bool>` that WGSL's `any` reduces. `value` is
/// therefore emitted once, instead of once per lane as in the `.x || .y` form.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec2_truth(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, any((");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str(") != vec2<u32>(0u)))");
    Ok(())
}
/// Keeps the first two lanes of a `vec4<u32>` expression: `(value).xy`.
///
/// The `.xy` swizzle yields a `vec2<u32>` directly. The previous
/// `vec2<u32>((value).x, (value).y)` form emitted `value` twice.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec4_to_vec2(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push('(');
    emit_expr_string(out, value, program, ctx)?;
    out.push_str(").xy");
    Ok(())
}
/// Emits `1u` when any lane of a four-lane `u32` vector is non-zero, else `0u`:
/// `select(0u, 1u, any((value) != vec4<u32>(0u)))`.
///
/// `value` is emitted once and reduced with WGSL's `any`. The previous four-way `||`
/// chain emitted it four times, which compounds when casts nest.
///
/// # Errors
///
/// Returns any error produced while emitting `value`.
fn emit_vec4_truth(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, any((");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str(") != vec4<u32>(0u)))");
    Ok(())
}