use crate::ir_inner::model::expr::Expr;
use crate::validate::binding::Binding;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
/// Reports whether `expr` is built solely from uniform leaves.
///
/// The expression tree is walked through `BinOp`, `UnOp`, `Cast`, `Select`
/// and `Fma` nodes. The result is `true` only if every leaf reached is one of:
///
/// * a `u32` / `i32` / `f32` / `bool` literal,
/// * `BufLen` or `WorkgroupId`,
/// * a `Var` whose entry in `scope` has `uniform == true`.
///
/// Any other node makes the whole expression non-uniform and the function
/// returns `false` as soon as it is encountered. That includes a `Var` that
/// is missing from `scope` or bound with `uniform == false`, as well as
/// `InvocationId`, `LocalId`, loads, calls, atomics, the subgroup
/// expressions and `Opaque`.
///
/// NOTE(review): "uniform" presumably means "same value for every invocation
/// in a workgroup" — confirm against the validator that consumes this.
pub(crate) fn is_uniform(expr: &Expr, scope: &FxHashMap<crate::ir::Ident, Binding>) -> bool {
    // Explicit work list instead of recursion; the first 32 pending nodes are
    // stored inline by `SmallVec` before it spills to the heap.
    let mut pending = SmallVec::<[&Expr; 32]>::new();
    pending.push(expr);

    loop {
        // An exhausted work list means no non-uniform node was found.
        let Some(node) = pending.pop() else {
            return true;
        };

        match node {
            // Leaves that are never uniform, whatever the scope contains.
            Expr::InvocationId { .. }
            | Expr::LocalId { .. }
            | Expr::SubgroupLocalId
            | Expr::SubgroupSize
            | Expr::Load { .. }
            | Expr::Call { .. }
            | Expr::Atomic { .. }
            | Expr::SubgroupBallot { .. }
            | Expr::SubgroupShuffle { .. }
            | Expr::SubgroupAdd { .. }
            | Expr::Opaque(_) => return false,

            // A variable is uniform only if it is bound and flagged uniform.
            Expr::Var(name) => {
                let bound_uniform = scope
                    .get(name.as_str())
                    .is_some_and(|binding| binding.uniform);
                if !bound_uniform {
                    return false;
                }
            }

            // Leaves that are always accepted.
            Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_) => {}
            Expr::BufLen { .. } | Expr::WorkgroupId { .. } => {}

            // Interior nodes: queue every operand. Operands are pushed in
            // reverse so the left-most one is popped (visited) first.
            Expr::UnOp { operand, .. } => pending.push(operand),
            Expr::Cast { value, .. } => pending.push(value),
            Expr::BinOp { left, right, .. } => {
                pending.push(right);
                pending.push(left);
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                pending.push(false_val);
                pending.push(true_val);
                pending.push(cond);
            }
            Expr::Fma { a, b, c } => {
                pending.push(c);
                pending.push(b);
                pending.push(a);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{DataType, Expr, Ident};

    /// Scope type accepted by `is_uniform`.
    type Scope = FxHashMap<crate::ir::Ident, Binding>;

    /// Builds a scope with one mutable `U32` binding called `name`, whose
    /// `uniform` flag is set to the given value.
    fn single_binding_scope(name: &str, uniform: bool) -> Scope {
        let mut bindings = Scope::default();
        bindings.insert(
            Ident::from(name),
            Binding {
                ty: DataType::U32,
                mutable: true,
                uniform,
            },
        );
        bindings
    }

    /// A scope with no bindings at all.
    fn empty_scope() -> FxHashMap<crate::ir::Ident, Binding> {
        Scope::default()
    }

    /// A scope where `name` is bound and flagged uniform.
    fn scope_with_uniform(name: &str) -> FxHashMap<crate::ir::Ident, Binding> {
        single_binding_scope(name, true)
    }

    /// A scope where `name` is bound but not flagged uniform.
    fn scope_with_divergent(name: &str) -> FxHashMap<crate::ir::Ident, Binding> {
        single_binding_scope(name, false)
    }

    #[test]
    fn literals_are_uniform() {
        let no_bindings = empty_scope();
        let literals = [
            Expr::u32(42),
            Expr::f32(std::f32::consts::PI),
            Expr::LitBool(true),
            Expr::i32(-1),
        ];
        for literal in &literals {
            assert!(is_uniform(literal, &no_bindings));
        }
    }

    #[test]
    fn invocation_id_is_divergent() {
        let invocation = Expr::InvocationId { axis: 0 };
        assert!(!is_uniform(&invocation, &empty_scope()));
    }

    #[test]
    fn workgroup_id_is_uniform() {
        let workgroup = Expr::WorkgroupId { axis: 0 };
        assert!(is_uniform(&workgroup, &empty_scope()));
    }

    #[test]
    fn uniform_var_is_uniform() {
        let bindings = scope_with_uniform("x");
        let var = Expr::Var(Ident::from("x"));
        assert!(is_uniform(&var, &bindings));
    }

    #[test]
    fn divergent_var_is_divergent() {
        let bindings = scope_with_divergent("x");
        let var = Expr::Var(Ident::from("x"));
        assert!(!is_uniform(&var, &bindings));
    }

    #[test]
    fn unknown_var_is_divergent() {
        // A variable with no binding in scope is rejected.
        let var = Expr::Var(Ident::from("unknown"));
        assert!(!is_uniform(&var, &empty_scope()));
    }

    #[test]
    fn binop_of_uniform_is_uniform() {
        let sum = Expr::add(Expr::u32(1), Expr::u32(2));
        assert!(is_uniform(&sum, &empty_scope()));
    }

    #[test]
    fn binop_with_divergent_is_divergent() {
        // One divergent operand makes the whole expression divergent.
        let sum = Expr::add(Expr::u32(1), Expr::InvocationId { axis: 0 });
        assert!(!is_uniform(&sum, &empty_scope()));
    }

    #[test]
    fn load_is_always_divergent() {
        // Rejected even though the index is a literal.
        let load = Expr::load("buf", Expr::u32(0));
        assert!(!is_uniform(&load, &empty_scope()));
    }

    #[test]
    fn fma_uniform_when_all_uniform() {
        let [a, b, c] = [1.0, 2.0, 3.0].map(|value| Box::new(Expr::f32(value)));
        let fused = Expr::Fma { a, b, c };
        assert!(is_uniform(&fused, &empty_scope()));
    }
}