use crate::ir::model::expr::Expr;
use crate::ir::model::program::BufferDecl;
use crate::ir::model::types::{BinOp, DataType};
use crate::ir::validate::{err, Binding, ValidationError};
use rustc_hash::FxHashMap;
/// Checks that one operand of a binary operation has a type binary operations accept.
///
/// The operand's type is inferred with [`expr_type`]. When inference yields `None`
/// (type unknown), nothing is reported. Otherwise the type must be one of
/// `DataType::U32`, `DataType::I32`, `DataType::F32`, `DataType::Bool`, or
/// `DataType::Bytes`. Any other type appends exactly one [`ValidationError`] to `errors`.
///
/// `side` is interpolated verbatim into the diagnostic to identify which operand failed.
#[inline]
pub(crate) fn validate_u32_binop_operand(
    side: &str,
    expr: &Expr,
    buffers: &FxHashMap<&str, &BufferDecl>,
    scope: &FxHashMap<String, Binding>,
    errors: &mut Vec<ValidationError>,
) {
    if let Some(ty) = expr_type(expr, buffers, scope) {
        if !matches!(
            ty,
            DataType::U32 | DataType::Bytes | DataType::F32 | DataType::I32 | DataType::Bool
        ) {
            // List every accepted type, rendered through `DataType`'s `Display` (the same
            // formatting used for `{ty}`). The old text claimed only `u32`/`f32` were
            // allowed, which contradicted the check above.
            let expected = format!(
                "`{}`, `{}`, `{}`, `{}`, or `{}`",
                DataType::U32,
                DataType::I32,
                DataType::F32,
                DataType::Bool,
                DataType::Bytes
            );
            errors.push(err(format!(
                "binary operation {side} operand must be {expected}, got `{ty}`. Fix: cast or rewrite the operand to produce U32 or F32."
            )));
        }
    }
}
#[inline]
pub(crate) fn validate_unop_operand(
op: &crate::ir::model::types::UnOp,
expr: &Expr,
buffers: &FxHashMap<&str, &BufferDecl>,
scope: &FxHashMap<String, Binding>,
errors: &mut Vec<ValidationError>,
) {
if let Some(ty) = expr_type(expr, buffers, scope) {
match op {
crate::ir::model::types::UnOp::Negate => {
if !matches!(ty, DataType::U32 | DataType::I32 | DataType::F32) {
errors.push(err(format!(
"unary operation `{op:?}` operand must be a signed scalar, got `{ty}`. Fix: cast or rewrite the operand to U32/I32/F32."
)));
}
}
crate::ir::model::types::UnOp::LogicalNot => {
if !matches!(ty, DataType::U32 | DataType::Bool | DataType::I32) {
errors.push(err(format!(
"unary operation `{op:?}` operand must be `u32` or `bool`, got `{ty}`. Fix: cast or rewrite the operand to produce U32."
)));
}
}
crate::ir::model::types::UnOp::BitNot
| crate::ir::model::types::UnOp::Popcount
| crate::ir::model::types::UnOp::Clz
| crate::ir::model::types::UnOp::Ctz
| crate::ir::model::types::UnOp::ReverseBits => {
if !matches!(ty, DataType::U32 | DataType::I32) {
errors.push(err(format!(
"unary operation `{op:?}` operand must be an integer, got `{ty}`. Fix: cast or rewrite the operand to produce U32 or I32."
)));
}
}
crate::ir::model::types::UnOp::Sin
| crate::ir::model::types::UnOp::Cos
| crate::ir::model::types::UnOp::Abs
| crate::ir::model::types::UnOp::Sqrt
| crate::ir::model::types::UnOp::Floor
| crate::ir::model::types::UnOp::Ceil
| crate::ir::model::types::UnOp::Round
| crate::ir::model::types::UnOp::Trunc
| crate::ir::model::types::UnOp::Sign
| crate::ir::model::types::UnOp::IsNan
| crate::ir::model::types::UnOp::IsInf
| crate::ir::model::types::UnOp::IsFinite => {
if ty != DataType::F32 {
errors.push(err(format!(
"unary operation `{op:?}` operand must be `f32`, got `{ty}`. Fix: cast or rewrite the operand to produce F32."
)));
}
}
}
}
}
#[inline]
pub(crate) fn expr_type(
expr: &Expr,
buffers: &FxHashMap<&str, &BufferDecl>,
scope: &FxHashMap<String, Binding>,
) -> Option<DataType> {
enum Frame<'a> {
Enter(&'a Expr),
Bin,
Un,
Select,
Fma,
}
let mut frames = vec![Frame::Enter(expr)];
let mut values: Vec<Option<DataType>> = Vec::new();
while let Some(frame) = frames.pop() {
match frame {
Frame::Enter(expr) => match expr {
Expr::LitU32(_) => values.push(Some(DataType::U32)),
Expr::LitI32(_) => values.push(Some(DataType::I32)),
Expr::LitF32(_) => values.push(Some(DataType::F32)),
Expr::LitBool(_) => values.push(Some(DataType::Bool)),
Expr::Var(name) => values.push(scope.get(name.as_str()).map(|b| b.ty.clone())),
Expr::Load { buffer, .. } => {
values.push(buffers.get(buffer.as_str()).map(|b| b.element.clone()))
}
Expr::BufLen { .. }
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::Atomic { .. } => values.push(Some(DataType::U32)),
Expr::Call { .. } => values.push(None),
Expr::Cast { target, .. } => values.push(Some(target.clone())),
Expr::BinOp { op, left, right } => {
if matches!(
op,
BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Min | BinOp::Max
) {
frames.push(Frame::Bin);
frames.push(Frame::Enter(right));
frames.push(Frame::Enter(left));
} else {
values.push(Some(DataType::U32));
}
}
Expr::UnOp { op, operand } => match op {
crate::ir::model::types::UnOp::Negate
| crate::ir::model::types::UnOp::BitNot
| crate::ir::model::types::UnOp::Popcount
| crate::ir::model::types::UnOp::Clz
| crate::ir::model::types::UnOp::Ctz
| crate::ir::model::types::UnOp::ReverseBits => {
frames.push(Frame::Un);
frames.push(Frame::Enter(operand));
}
crate::ir::model::types::UnOp::LogicalNot => values.push(Some(DataType::U32)),
crate::ir::model::types::UnOp::Sin
| crate::ir::model::types::UnOp::Cos
| crate::ir::model::types::UnOp::Abs
| crate::ir::model::types::UnOp::Sqrt
| crate::ir::model::types::UnOp::Floor
| crate::ir::model::types::UnOp::Ceil
| crate::ir::model::types::UnOp::Round
| crate::ir::model::types::UnOp::Trunc
| crate::ir::model::types::UnOp::Sign => values.push(Some(DataType::F32)),
crate::ir::model::types::UnOp::IsNan
| crate::ir::model::types::UnOp::IsInf
| crate::ir::model::types::UnOp::IsFinite => values.push(Some(DataType::Bool)),
},
Expr::Select {
true_val,
false_val,
..
} => {
frames.push(Frame::Select);
frames.push(Frame::Enter(false_val));
frames.push(Frame::Enter(true_val));
}
Expr::Fma { a, b, c } => {
frames.push(Frame::Fma);
frames.push(Frame::Enter(c));
frames.push(Frame::Enter(b));
frames.push(Frame::Enter(a));
}
},
Frame::Bin => {
let r = values.pop().unwrap_or(None);
let l = values.pop().unwrap_or(None);
if l == r && l == Some(DataType::F32) {
values.push(Some(DataType::F32));
} else {
values.push(Some(
l.as_ref()
.filter(|_| l == r)
.cloned()
.unwrap_or(DataType::U32),
));
}
}
Frame::Un => {
let operand = values.pop().unwrap_or(None);
values.push(operand);
}
Frame::Select => {
let f = values.pop().unwrap_or(None);
let t = values.pop().unwrap_or(None);
values.push(if t == f { t } else { None });
}
Frame::Fma => {
let tc = values.pop().unwrap_or(None);
let tb = values.pop().unwrap_or(None);
let ta = values.pop().unwrap_or(None);
values.push(
if ta == Some(DataType::F32)
&& tb == Some(DataType::F32)
&& tc == Some(DataType::F32)
{
Some(DataType::F32)
} else {
None
},
);
}
}
}
values.pop().flatten()
}