use super::{is_commutative, ExprKey, TypeKey};
use crate::ir::{BinOp, Expr, UnOp};
use std::sync::Arc;
impl ExprKey {
#[inline]
pub(crate) fn from_expr(expr: &Expr) -> Self {
match expr {
Expr::LitU32(value) => Self::LitU32(*value),
Expr::LitI32(value) => Self::LitI32(*value),
Expr::LitF32(value) => Self::LitF32(value.to_bits()),
Expr::LitBool(value) => Self::LitBool(*value),
Expr::Var(name) => Self::Var(Arc::from(name.as_str())),
Expr::Load { buffer, index } => {
Self::Load(Arc::from(buffer.as_str()), Box::new(Self::from_expr(index)))
}
Expr::BufLen { buffer } => Self::BufLen(Arc::from(buffer.as_str())),
Expr::InvocationId { axis } => Self::InvocationId(*axis),
Expr::WorkgroupId { axis } => Self::WorkgroupId(*axis),
Expr::LocalId { axis } => Self::LocalId(*axis),
Expr::BinOp { op, left, right } => {
let mut left = Self::from_expr(left);
let mut right = Self::from_expr(right);
if is_commutative(op) && right < left {
std::mem::swap(&mut left, &mut right);
}
Self::BinOp(bin_op_key(op), Box::new(left), Box::new(right))
}
Expr::UnOp { op, operand } => {
Self::UnOp(un_op_key(op), Box::new(Self::from_expr(operand)))
}
Expr::Call { op_id, args } => Self::Call(
Arc::from(op_id.as_str()),
args.iter().map(Self::from_expr).collect::<Vec<_>>(),
),
Expr::Fma { a, b, c } => Self::Call(
Arc::from("fma"),
vec![Self::from_expr(a), Self::from_expr(b), Self::from_expr(c)],
),
Expr::Select {
cond,
true_val,
false_val,
} => Self::Select(
Box::new(Self::from_expr(cond)),
Box::new(Self::from_expr(true_val)),
Box::new(Self::from_expr(false_val)),
),
Expr::Cast { target, value } => {
Self::Cast(TypeKey::from(target), Box::new(Self::from_expr(value)))
}
Expr::Atomic { .. } => Self::Atomic,
}
}
}
/// Maps a binary operator to the `u8` tag stored in an `ExprKey::BinOp` key.
///
/// Each variant receives its own distinct tag, so no two operators share one.
#[inline]
fn bin_op_key(op: &BinOp) -> u8 {
    match *op {
        // Arithmetic.
        BinOp::Add => 0,
        BinOp::Sub => 1,
        BinOp::Mul => 2,
        BinOp::Div => 3,
        BinOp::Mod => 4,
        // Bitwise and shifts.
        BinOp::BitAnd => 5,
        BinOp::BitOr => 6,
        BinOp::BitXor => 7,
        BinOp::Shl => 8,
        BinOp::Shr => 9,
        // Comparisons.
        BinOp::Eq => 10,
        BinOp::Ne => 11,
        BinOp::Lt => 12,
        BinOp::Gt => 13,
        BinOp::Le => 14,
        BinOp::Ge => 15,
        // Logical connectives.
        BinOp::And => 16,
        BinOp::Or => 17,
        // Distance and extrema.
        BinOp::AbsDiff => 18,
        BinOp::Min => 19,
        BinOp::Max => 20,
    }
}
/// Maps a unary operator to the `u8` tag stored in an `ExprKey::UnOp` key.
///
/// Each variant receives its own distinct tag, so no two operators share one.
#[inline]
fn un_op_key(op: &UnOp) -> u8 {
    match *op {
        // Sign / bitwise / logical inversion.
        UnOp::Negate => 0,
        UnOp::BitNot => 1,
        UnOp::LogicalNot => 2,
        // Bit counting and manipulation.
        UnOp::Popcount => 3,
        UnOp::Clz => 4,
        UnOp::Ctz => 5,
        UnOp::ReverseBits => 6,
        // Trigonometric.
        UnOp::Sin => 7,
        UnOp::Cos => 8,
        // Magnitude and root.
        UnOp::Abs => 9,
        UnOp::Sqrt => 10,
        // Rounding.
        UnOp::Floor => 11,
        UnOp::Ceil => 12,
        UnOp::Round => 13,
        UnOp::Trunc => 14,
        UnOp::Sign => 15,
        // Floating-point classification.
        UnOp::IsNan => 16,
        UnOp::IsInf => 17,
        UnOp::IsFinite => 18,
    }
}