// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
use super::{is_commutative, ExprKey, TypeKey};
use crate::ir::{BinOp, Expr, UnOp};
use std::sync::Arc;

impl ExprKey {
    /// Converts an IR expression tree into a structural key, recursing into every
    /// sub-expression.
    ///
    /// Normalisations applied while building the key:
    /// - `LitF32` is keyed by its raw bit pattern (`f32::to_bits`), so float literals
    ///   are compared bit-for-bit rather than with floating-point equality.
    /// - For a `BinOp` whose operator satisfies `is_commutative`, the operand keys are
    ///   stored smaller-first (per `PartialOrd`), so `a op b` and `b op a` produce the
    ///   same key.
    /// - `Fma { a, b, c }` is encoded as a `Call` named `"fma"` with arguments `[a, b, c]`.
    ///   NOTE(review): an `Expr::Call` whose `op_id` is `"fma"` with the same three
    ///   arguments therefore yields an identical key — confirm this aliasing is intended.
    /// - Every `Atomic` expression maps to the single `ExprKey::Atomic` key; its fields
    ///   are ignored. NOTE(review): callers presumably must not treat two equal
    ///   `Atomic` keys as interchangeable — verify at the use site.
    #[inline]
    pub(crate) fn from_expr(expr: &Expr) -> Self {
        // Recursively key a child expression and box it for the nested variants.
        let boxed = |child: &Expr| Box::new(Self::from_expr(child));

        match expr {
            // Literals are copied through; f32 is keyed by its bit pattern.
            Expr::LitU32(v) => Self::LitU32(*v),
            Expr::LitI32(v) => Self::LitI32(*v),
            Expr::LitF32(v) => Self::LitF32(v.to_bits()),
            Expr::LitBool(v) => Self::LitBool(*v),

            // Named values and buffer accesses.
            Expr::Var(ident) => Self::Var(Arc::from(ident.as_str())),
            Expr::Load { buffer, index } => Self::Load(Arc::from(buffer.as_str()), boxed(index)),
            Expr::BufLen { buffer } => Self::BufLen(Arc::from(buffer.as_str())),

            // Built-in id expressions, keyed by axis.
            Expr::InvocationId { axis } => Self::InvocationId(*axis),
            Expr::WorkgroupId { axis } => Self::WorkgroupId(*axis),
            Expr::LocalId { axis } => Self::LocalId(*axis),

            // Operators.
            Expr::BinOp { op, left, right } => {
                let lhs = Self::from_expr(left);
                let rhs = Self::from_expr(right);
                // Canonical operand order for commutative operators.
                let (first, second) = if is_commutative(op) && rhs < lhs {
                    (rhs, lhs)
                } else {
                    (lhs, rhs)
                };
                Self::BinOp(bin_op_key(op), Box::new(first), Box::new(second))
            }
            Expr::UnOp { op, operand } => Self::UnOp(un_op_key(op), boxed(operand)),

            // Calls; fused multiply-add is keyed as a call named "fma".
            Expr::Call { op_id, args } => {
                let arg_keys: Vec<Self> = args.iter().map(Self::from_expr).collect();
                Self::Call(Arc::from(op_id.as_str()), arg_keys)
            }
            Expr::Fma { a, b, c } => {
                let operands = vec![Self::from_expr(a), Self::from_expr(b), Self::from_expr(c)];
                Self::Call(Arc::from("fma"), operands)
            }

            // Control and conversion.
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => Self::Select(boxed(cond), boxed(true_val), boxed(false_val)),
            Expr::Cast { target, value } => Self::Cast(TypeKey::from(target), boxed(value)),

            // All atomics share one key regardless of their fields.
            Expr::Atomic { .. } => Self::Atomic,
        }
    }
}

/// Maps a binary operator to the one-byte tag stored in `ExprKey::BinOp`.
///
/// Every variant gets a distinct tag in `0..=20`. The tag is part of the key, and
/// keys are compared with `<` when commutative operands are canonicalised. For that
/// reason, existing values should not be renumbered. Give new variants the next
/// unused value.
#[inline]
fn bin_op_key(op: &BinOp) -> u8 {
    match *op {
        // Arithmetic.
        BinOp::Add => 0,
        BinOp::Sub => 1,
        BinOp::Mul => 2,
        BinOp::Div => 3,
        BinOp::Mod => 4,
        // Bitwise and shifts.
        BinOp::BitAnd => 5,
        BinOp::BitOr => 6,
        BinOp::BitXor => 7,
        BinOp::Shl => 8,
        BinOp::Shr => 9,
        // Comparisons.
        BinOp::Eq => 10,
        BinOp::Ne => 11,
        BinOp::Lt => 12,
        BinOp::Gt => 13,
        BinOp::Le => 14,
        BinOp::Ge => 15,
        // Logical.
        BinOp::And => 16,
        BinOp::Or => 17,
        // Two-operand numeric helpers.
        BinOp::AbsDiff => 18,
        BinOp::Min => 19,
        BinOp::Max => 20,
    }
}

/// Maps a unary operator to the one-byte tag stored in `ExprKey::UnOp`.
///
/// Every variant gets a distinct tag in `0..=18`. The tag is part of the key, so
/// existing values should not be renumbered. Give new variants the next unused value.
#[inline]
fn un_op_key(op: &UnOp) -> u8 {
    match *op {
        // Sign / bit inversion.
        UnOp::Negate => 0,
        UnOp::BitNot => 1,
        UnOp::LogicalNot => 2,
        // Bit counting and reordering.
        UnOp::Popcount => 3,
        UnOp::Clz => 4,
        UnOp::Ctz => 5,
        UnOp::ReverseBits => 6,
        // Math functions.
        UnOp::Sin => 7,
        UnOp::Cos => 8,
        UnOp::Abs => 9,
        UnOp::Sqrt => 10,
        // Rounding.
        UnOp::Floor => 11,
        UnOp::Ceil => 12,
        UnOp::Round => 13,
        UnOp::Trunc => 14,
        // Sign and float classification.
        UnOp::Sign => 15,
        UnOp::IsNan => 16,
        UnOp::IsInf => 17,
        UnOp::IsFinite => 18,
    }
}