// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
use super::{append_wgsl, emit_expr_string};
use crate::ir::model::expr::Expr;
use crate::ir::model::program::Program;
use crate::ir::model::types::DataType;
use crate::lower::wgsl::{Error, LowerCtx};

/// Appends to `out` the WGSL text for casting `value` from `source` to `target`.
///
/// When the two types are equal the operand is emitted unchanged. Otherwise the
/// `(source, target)` pair selects one of the emission strategies below; any pair
/// that is not listed produces a lowering error.
///
/// As the emitted text shows, `U64` is written as a `vec2<u32>` whose `.x` lane
/// holds the value produced by a scalar widening (the `.y` lane receives zero or
/// the sign fill), and `Bool` results are written as `select(0u, 1u, …)`.
///
/// Several strategies emit the operand text more than once.
/// NOTE(review): if operand expressions can carry side effects, confirm that
/// repeated emission is acceptable to callers.
///
/// # Errors
///
/// Propagates any error from emitting the operand, and returns
/// `Error::lowering` for an unsupported `(source, target)` pair.
pub(super) fn emit_cast_expr(
    out: &mut String,
    source: DataType,
    target: DataType,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    match (&source, &target) {
        // Identity cast: nothing to wrap.
        _ if source == target => emit_expr_string(out, value, program, ctx),

        // Same-shape conversions between `U64` and `Vec2U32` pass the operand through.
        (DataType::U64, DataType::Vec2U32) | (DataType::Vec2U32, DataType::U64) => {
            emit_expr_string(out, value, program, ctx)
        }

        // --- Scalar bit reinterpretation -------------------------------------
        (DataType::U32, DataType::I32) => {
            emit_wrapped(out, "bitcast<i32>(", ")", value, program, ctx)
        }
        (DataType::I32 | DataType::F32, DataType::U32) => {
            emit_wrapped(out, "bitcast<u32>(", ")", value, program, ctx)
        }
        (DataType::U32, DataType::F32) => {
            emit_wrapped(out, "bitcast<f32>(", ")", value, program, ctx)
        }

        // --- Truthiness (non-zero test folded into 0u / 1u) -------------------
        (DataType::U32, DataType::Bool) | (DataType::Bool, DataType::U32) => {
            emit_bool_u32_expr(out, value, "0u", program, ctx)
        }
        (DataType::I32, DataType::Bool) => emit_bool_i32_expr(out, value, program, ctx),
        (DataType::U64 | DataType::Vec2U32, DataType::Bool) => {
            emit_vec2_truth(out, value, program, ctx)
        }
        (DataType::Vec4U32, DataType::Bool) => emit_vec4_truth(out, value, program, ctx),

        // --- Widening from U32 ------------------------------------------------
        (DataType::U32, DataType::U64) => emit_vec2_scalar_zero(out, value, program, ctx),
        (DataType::U32, DataType::Vec2U32) => emit_vec2_repeated(out, value, program, ctx),
        (DataType::U32, DataType::Vec4U32) => emit_vec4_repeated(out, value, program, ctx),

        // --- Widening from I32 ------------------------------------------------
        (DataType::I32, DataType::U64) => {
            // First lane: the raw bits. Second lane: all ones when negative, else zero.
            emit_wrapped(out, "vec2<u32>(bitcast<u32>(", ")", value, program, ctx)?;
            emit_wrapped(
                out,
                ", select(0u, 0xffffffffu, (",
                " < 0i)))",
                value,
                program,
                ctx,
            )
        }
        (DataType::I32, DataType::Vec2U32) => {
            emit_vec2_bitcast_repeated(out, value, program, ctx)
        }
        (DataType::I32, DataType::Vec4U32) => {
            emit_vec4_bitcast_repeated(out, value, program, ctx)
        }

        // --- Widening from Bool -----------------------------------------------
        (DataType::Bool, DataType::I32) => {
            out.push_str("i32(");
            emit_bool_u32_expr(out, value, "0u", program, ctx)?;
            out.push(')');
            Ok(())
        }
        (DataType::Bool, DataType::U64) => {
            out.push_str("vec2<u32>(");
            emit_bool_u32_expr(out, value, "0u", program, ctx)?;
            out.push_str(", 0u)");
            Ok(())
        }
        (DataType::Bool, DataType::Vec2U32) => emit_vec2_bool(out, value, program, ctx),
        (DataType::Bool, DataType::Vec4U32) => emit_vec4_bool(out, value, program, ctx),

        // --- Narrowing from vector-shaped values ------------------------------
        (DataType::U64 | DataType::Vec2U32 | DataType::Vec4U32, DataType::U32) => {
            emit_component(out, value, ".x", program, ctx)
        }
        (DataType::U64 | DataType::Vec2U32 | DataType::Vec4U32, DataType::I32) => {
            emit_wrapped(out, "bitcast<i32>((", ").x)", value, program, ctx)
        }
        (DataType::Vec4U32, DataType::Vec2U32 | DataType::U64) => {
            emit_vec4_to_vec2(out, value, program, ctx)
        }

        // Anything else should have been rejected before lowering.
        _ => Err(Error::lowering(format!(
            "unsupported cast from `{source}` to `{target}` reached WGSL lowering. Fix: run Program validation before lowering and reject unsupported casts."
        ))),
    }
}

/// Writes `prefix`, then the WGSL text of `value`, then `suffix` into `out`.
///
/// If emitting `value` fails, the error is returned and `suffix` is not written.
fn emit_wrapped(
    out: &mut String,
    prefix: &str,
    suffix: &str,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str(prefix);
    // The closing text is appended only once the operand was emitted successfully.
    emit_expr_string(out, value, program, ctx).map(|()| out.push_str(suffix))
}

/// Emits `select(0u, 1u, (<value> != <zero>))`.
///
/// `zero` is the literal text the operand is compared against; it is inserted
/// verbatim into the output.
fn emit_bool_u32_expr(
    out: &mut String,
    value: &Expr,
    zero: &str,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, (");
    match emit_expr_string(out, value, program, ctx) {
        // Operand written: finish the comparison and close both parentheses.
        Ok(()) => append_wgsl(out, format_args!(" != {zero}))")),
        failure => failure,
    }
}

fn emit_bool_i32_expr(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, (");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str(" != 0i))");
    Ok(())
}

fn emit_vec2_scalar_zero(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec2<u32>(");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str(", 0u)");
    Ok(())
}

/// Emits `vec2<u32>(<value>, <value>)`, writing the operand text once per lane.
fn emit_vec2_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // Each entry is the text that precedes one copy of the operand.
    for lead in ["vec2<u32>(", ", "] {
        out.push_str(lead);
        emit_expr_string(out, value, program, ctx)?;
    }
    out.push(')');
    Ok(())
}

/// Emits `vec4<u32>(<value>, <value>, <value>, <value>)`, writing the operand
/// text once per lane.
fn emit_vec4_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // Constructor opener for the first lane, a comma separator for the other three.
    for lead in ["vec4<u32>(", ", ", ", ", ", "] {
        out.push_str(lead);
        emit_expr_string(out, value, program, ctx)?;
    }
    out.push(')');
    Ok(())
}

fn emit_vec2_bitcast_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec2<u32>(bitcast<u32>(");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str("), bitcast<u32>(");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str("))");
    Ok(())
}

fn emit_vec4_bitcast_repeated(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec4<u32>(");
    for index in 0..4 {
        if index > 0 {
            out.push_str(", ");
        }
        out.push_str("bitcast<u32>(");
        emit_expr_string(out, value, program, ctx)?;
        out.push(')');
    }
    out.push(')');
    Ok(())
}

/// Emits a `vec2<u32>(…)` constructor whose two lanes are each the
/// `select(0u, 1u, (<value> != 0u))` form of the operand.
fn emit_vec2_bool(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    for lead in ["vec2<u32>(", ", "] {
        out.push_str(lead);
        emit_bool_u32_expr(out, value, "0u", program, ctx)?;
    }
    out.push(')');
    Ok(())
}

/// Emits a `vec4<u32>(…)` constructor whose four lanes are each the
/// `select(0u, 1u, (<value> != 0u))` form of the operand.
fn emit_vec4_bool(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // Constructor opener for the first lane, a comma separator for the other three.
    for lead in ["vec4<u32>(", ", ", ", ", ", "] {
        out.push_str(lead);
        emit_bool_u32_expr(out, value, "0u", program, ctx)?;
    }
    out.push(')');
    Ok(())
}

fn emit_component(
    out: &mut String,
    value: &Expr,
    suffix: &str,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push('(');
    emit_expr_string(out, value, program, ctx)?;
    out.push(')');
    out.push_str(suffix);
    Ok(())
}

/// Emits `select(0u, 1u, (((<value>).x != 0u) || ((<value>).y != 0u)))`:
/// 1u when either of the two lanes is non-zero, otherwise 0u.
fn emit_vec2_truth(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // Text preceding each copy of the operand; the closing text follows the loop.
    for lead in ["select(0u, 1u, (((", ").x != 0u) || (("] {
        out.push_str(lead);
        emit_expr_string(out, value, program, ctx)?;
    }
    out.push_str(").y != 0u)))");
    Ok(())
}

fn emit_vec4_to_vec2(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("vec2<u32>((");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str(").x, (");
    emit_expr_string(out, value, program, ctx)?;
    out.push_str(").y)");
    Ok(())
}

/// Emits a `select(0u, 1u, …)` whose condition ORs together non-zero tests on
/// the `.x`, `.y`, `.z` and `.w` lanes of the operand: 1u when any lane is
/// non-zero, otherwise 0u.
fn emit_vec4_truth(
    out: &mut String,
    value: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // Text preceding each of the four copies of the operand.
    let leads = [
        "select(0u, 1u, (((",
        ").x != 0u) || ((",
        ").y != 0u) || ((",
        ").z != 0u) || ((",
    ];
    for lead in leads {
        out.push_str(lead);
        emit_expr_string(out, value, program, ctx)?;
    }
    // Final lane test plus the closing parentheses of the condition and `select`.
    out.push_str(").w != 0u)))");
    Ok(())
}