// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
use super::{append_wgsl, emit_expr_string};
use crate::ir::model::expr::Expr;
use crate::ir::model::program::Program;
use crate::ir::model::types::{BinOp, DataType, UnOp};
use crate::lower::wgsl::{Error, LowerCtx};

pub(super) fn emit_binop(
    out: &mut String,
    op: BinOp,
    left: &Expr,
    right: &Expr,
    ty: Option<DataType>,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    match op {
        BinOp::Add => emit_infix(out, left, " + ", right, program, ctx),
        BinOp::Sub => emit_infix(out, left, " - ", right, program, ctx),
        BinOp::Mul => emit_infix(out, left, " * ", right, program, ctx),
        BinOp::Div => {
            if ty == Some(DataType::I32) {
                emit_call2(out, "_vyre_safe_div_i32", left, right, program, ctx)
            } else if ty == Some(DataType::F32) {
                emit_infix(out, left, " / ", right, program, ctx)
            } else {
                emit_call2(out, "_vyre_safe_div_u32", left, right, program, ctx)
            }
        }
        BinOp::Mod => emit_call2(out, "_vyre_safe_mod_u32", left, right, program, ctx),
        BinOp::BitAnd => emit_infix(out, left, " & ", right, program, ctx),
        BinOp::BitOr => emit_infix(out, left, " | ", right, program, ctx),
        BinOp::BitXor => emit_infix(out, left, " ^ ", right, program, ctx),
        BinOp::Shl => emit_shift(out, left, " << ", right, program, ctx),
        BinOp::Shr => emit_shift(out, left, " >> ", right, program, ctx),
        BinOp::Eq => emit_comparison(out, left, " == ", right, program, ctx),
        BinOp::Ne => emit_comparison(out, left, " != ", right, program, ctx),
        BinOp::Lt => emit_comparison(out, left, " < ", right, program, ctx),
        BinOp::Gt => emit_comparison(out, left, " > ", right, program, ctx),
        BinOp::Le => emit_comparison(out, left, " <= ", right, program, ctx),
        BinOp::Ge => emit_comparison(out, left, " >= ", right, program, ctx),
        BinOp::And => emit_bool_pair(out, left, " && ", right, program, ctx),
        BinOp::Or => emit_bool_pair(out, left, " || ", right, program, ctx),
        BinOp::AbsDiff => emit_abs_diff(out, left, right, program, ctx),
        BinOp::Min => emit_call2(out, "min", left, right, program, ctx),
        BinOp::Max => emit_call2(out, "max", left, right, program, ctx),
    }
}

pub(super) fn emit_unop(
    out: &mut String,
    op: UnOp,
    operand: &Expr,
    ty: Option<DataType>,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    match op {
        UnOp::Negate => match ty {
            Some(DataType::U32) => {
                out.push_str("(~");
                emit_expr_string(out, operand, program, ctx)?;
                out.push_str(" + 1u)");
                Ok(())
            }
            _ => emit_prefix(out, "-", operand, program, ctx),
        },
        UnOp::BitNot => emit_prefix(out, "~", operand, program, ctx),
        UnOp::LogicalNot => {
            out.push_str("select(0u, 1u, (");
            emit_expr_string(out, operand, program, ctx)?;
            out.push_str(" == 0u))");
            Ok(())
        }
        UnOp::Popcount => emit_call1(out, "countOneBits", operand, program, ctx),
        UnOp::Clz => emit_call1(out, "countLeadingZeros", operand, program, ctx),
        UnOp::Ctz => emit_call1(out, "countTrailingZeros", operand, program, ctx),
        UnOp::ReverseBits => emit_call1(out, "reverseBits", operand, program, ctx),
        UnOp::Sin => emit_call1(out, "sin", operand, program, ctx),
        UnOp::Cos => emit_call1(out, "cos", operand, program, ctx),
        UnOp::Abs => emit_call1(out, "abs", operand, program, ctx),
        UnOp::Sqrt => emit_call1(out, "sqrt", operand, program, ctx),
        UnOp::Floor => emit_call1(out, "floor", operand, program, ctx),
        UnOp::Ceil => emit_call1(out, "ceil", operand, program, ctx),
        UnOp::Round => emit_call1(out, "round", operand, program, ctx),
        UnOp::Trunc => emit_call1(out, "trunc", operand, program, ctx),
        UnOp::Sign => emit_call1(out, "sign", operand, program, ctx),
        UnOp::IsNan => emit_predicate(out, "isNan", operand, program, ctx),
        UnOp::IsInf => emit_predicate(out, "isInf", operand, program, ctx),
        UnOp::IsFinite => emit_predicate(out, "isFinite", operand, program, ctx),
    }
}

/// Appends `(<left><operator><right>)` to `out`.
///
/// `operator` is written verbatim, so callers supply any surrounding spaces.
///
/// # Errors
///
/// Stops at, and returns, the first error from `emit_expr_string`; text
/// already appended to `out` is left in place.
fn emit_infix(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("(");
    // Each operand is followed by the text that must come right after it.
    for (operand, suffix) in [(left, operator), (right, ")")] {
        emit_expr_string(out, operand, program, ctx)?;
        out.push_str(suffix);
    }
    Ok(())
}

/// Appends the one-argument call `<name>(<operand>)` to `out`.
///
/// # Errors
///
/// Returns the error from `append_wgsl` or `emit_expr_string`, whichever
/// fails first; the closing parenthesis is only written on success.
fn emit_call1(
    out: &mut String,
    name: &str,
    operand: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    append_wgsl(out, format_args!("{}(", name))?;
    emit_expr_string(out, operand, program, ctx).map(|()| out.push(')'))
}

/// Appends the two-argument call `<name>(<left>, <right>)` to `out`.
///
/// # Errors
///
/// Returns the error from `append_wgsl` or `emit_expr_string`, whichever
/// fails first; text already appended to `out` is left in place.
fn emit_call2(
    out: &mut String,
    name: &str,
    left: &Expr,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    append_wgsl(out, format_args!("{}(", name))?;
    // Each argument is paired with the text that terminates it.
    for (argument, terminator) in [(left, ", "), (right, ")")] {
        emit_expr_string(out, argument, program, ctx)?;
        out.push_str(terminator);
    }
    Ok(())
}

/// Appends the prefix expression `(<operator><operand>)` to `out`.
///
/// The expression is parenthesized so that the operator can never fuse with
/// the first character of the operand text. Without the delimiter, negating
/// an operand whose text itself begins with `-` (a nested negation, for
/// example) would be written as `--…`, which WGSL tokenizers may read as the
/// decrement token rather than two unary minus operators. This mirrors the
/// way `emit_infix` wraps its output.
///
/// # Errors
///
/// Returns the error from `emit_expr_string`; the closing parenthesis is only
/// written on success.
fn emit_prefix(
    out: &mut String,
    operator: &str,
    operand: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push('(');
    out.push_str(operator);
    emit_expr_string(out, operand, program, ctx)?;
    out.push(')');
    Ok(())
}

/// Appends `(<left><operator>(<right> & 31u))` to `out`.
///
/// The shift amount is always masked with `& 31u` in the emitted text.
/// `operator` is written verbatim, so callers supply any surrounding spaces.
///
/// # Errors
///
/// Stops at, and returns, the first error from `emit_expr_string`.
fn emit_shift(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("(");
    emit_expr_string(out, left, program, ctx)?;
    // Operator, then the opening parenthesis of the masked shift amount.
    out.extend([operator, "("]);
    emit_expr_string(out, right, program, ctx)?;
    out.push_str(" & 31u))");
    Ok(())
}

/// Appends `select(0u, 1u, (<left><operator><right>))` to `out`.
///
/// The comparison is wrapped in `select(0u, 1u, ...)` rather than emitted as
/// a bare boolean expression. `operator` is written verbatim.
///
/// # Errors
///
/// Stops at, and returns, the first error from `emit_expr_string`.
fn emit_comparison(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, (");
    // Each operand is followed by the text that must come right after it.
    for (operand, tail) in [(left, operator), (right, "))")] {
        emit_expr_string(out, operand, program, ctx)?;
        out.push_str(tail);
    }
    Ok(())
}

/// Appends `select(0u, 1u, ((<left> != 0u)<operator>(<right> != 0u)))` to
/// `out`.
///
/// Each operand is compared against `0u` before the two results are joined
/// with `operator`, which is written verbatim.
///
/// # Errors
///
/// Stops at, and returns, the first error from `emit_expr_string`.
fn emit_bool_pair(
    out: &mut String,
    left: &Expr,
    operator: &str,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    out.push_str("select(0u, 1u, ((");
    emit_expr_string(out, left, program, ctx)?;
    // Close the left test, write the connective, open the right test.
    out.extend([" != 0u)", operator, "("]);
    emit_expr_string(out, right, program, ctx)?;
    out.push_str(" != 0u)))");
    Ok(())
}

/// Appends `select((<left> - <right>), (<right> - <left>), (<left> < <right>))`
/// to `out`.
///
/// Each operand's text is emitted three times, once per position in which it
/// appears above.
///
/// # Errors
///
/// Stops at, and returns, the first error from `emit_expr_string`.
fn emit_abs_diff(
    out: &mut String,
    left: &Expr,
    right: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    // Literal text that precedes each operand occurrence, in output order.
    let pieces = [
        ("select((", left),
        (" - ", right),
        ("), (", right),
        (" - ", left),
        ("), (", left),
        (" < ", right),
    ];
    for (lead, operand) in pieces {
        out.push_str(lead);
        emit_expr_string(out, operand, program, ctx)?;
    }
    out.push_str("))");
    Ok(())
}

/// Appends `select(0u, 1u, <name>(<operand>))` to `out`.
///
/// NOTE(review): `name` is emitted as-is; this function does not check that
/// it resolves to a callable in the generated shader.
///
/// # Errors
///
/// Returns the error from `append_wgsl` or `emit_expr_string`, whichever
/// fails first; the closing parentheses are only written on success.
fn emit_predicate(
    out: &mut String,
    name: &str,
    operand: &Expr,
    program: &Program,
    ctx: &LowerCtx<'_>,
) -> Result<(), Error> {
    append_wgsl(out, format_args!("select(0u, 1u, {}(", name))?;
    emit_expr_string(out, operand, program, ctx).map(|()| out.push_str("))"))
}