use cubecl_core::ir::{Plane, UnaryOperator, Variable};
use rspirv::spirv::{Capability, GroupOperation, Scope, Word};
use crate::{SpirvCompiler, SpirvTarget, item::Elem};
impl<T: SpirvTarget> SpirvCompiler<T> {
pub fn compile_plane(&mut self, plane: Plane, out: Option<Variable>, uniform: bool) {
self.capabilities
.insert(Capability::GroupNonUniformArithmetic);
let subgroup = self.subgroup();
let out = out.unwrap();
match plane {
Plane::Elect => {
let out = self.compile_variable(out);
let out_id = self.write_id(&out);
let bool = self.type_bool();
self.group_non_uniform_elect(bool, Some(out_id), subgroup)
.unwrap();
self.write(&out, out_id);
}
Plane::All(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
match out.vector_size() {
1 => {
self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
b.group_non_uniform_all(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_type(op.input.ty).elem().id(self);
let bool_ty = self.type_bool();
self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
let ids = (0..vec as u32)
.map(|i| {
let elem_i =
b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
b.group_non_uniform_all(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Any(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
match out.vector_size() {
1 => {
self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
b.group_non_uniform_any(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_type(op.input.ty).elem().id(self);
let bool_ty = self.type_bool();
self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
let ids = (0..vec as u32)
.map(|i| {
let elem_i =
b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
b.group_non_uniform_any(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Ballot(op) => {
self.capabilities.insert(Capability::GroupNonUniformBallot);
assert_eq!(
op.input.vector_size(),
1,
"plane_ballot can't work with vectorized values"
);
self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
b.group_non_uniform_ballot(ty, Some(out), subgroup, input)
.unwrap();
});
}
Plane::Broadcast(op) => {
let is_broadcast = self.uniformity.is_var_uniform(op.rhs);
self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
match is_broadcast {
true => {
b.capabilities.insert(Capability::GroupNonUniformBallot);
b.group_non_uniform_broadcast(ty, Some(out), subgroup, lhs, rhs)
.unwrap();
}
false => {
b.capabilities.insert(Capability::GroupNonUniformShuffle);
b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
.unwrap();
}
}
});
}
Plane::Sum(op) => {
self.plane_sum(op, out, GroupOperation::Reduce, uniform);
}
Plane::ExclusiveSum(op) => {
self.plane_sum(op, out, GroupOperation::ExclusiveScan, uniform);
}
Plane::InclusiveSum(op) => {
self.plane_sum(op, out, GroupOperation::InclusiveScan, uniform);
}
Plane::Prod(op) => {
self.plane_prod(op, out, GroupOperation::Reduce, uniform);
}
Plane::ExclusiveProd(op) => {
self.plane_prod(op, out, GroupOperation::ExclusiveScan, uniform);
}
Plane::InclusiveProd(op) => {
self.plane_prod(op, out, GroupOperation::InclusiveScan, uniform);
}
Plane::Min(op) => {
self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
match out_ty.elem() {
Elem::Int(_, false) => b.group_non_uniform_u_min(
ty,
Some(out),
subgroup,
GroupOperation::Reduce,
input,
None,
),
Elem::Int(_, true) => b.group_non_uniform_s_min(
ty,
Some(out),
subgroup,
GroupOperation::Reduce,
input,
None,
),
Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_min(
ty,
Some(out),
subgroup,
GroupOperation::Reduce,
input,
None,
),
_ => unreachable!(),
}
.unwrap();
});
}
Plane::Max(op) => {
self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
match out_ty.elem() {
Elem::Int(_, false) => b.group_non_uniform_u_max(
ty,
Some(out),
subgroup,
GroupOperation::Reduce,
input,
None,
),
Elem::Int(_, true) => b.group_non_uniform_s_max(
ty,
Some(out),
subgroup,
GroupOperation::Reduce,
input,
None,
),
Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_max(
ty,
Some(out),
subgroup,
GroupOperation::Reduce,
input,
None,
),
_ => unreachable!(),
}
.unwrap();
});
}
Plane::Shuffle(op) => {
self.capabilities.insert(Capability::GroupNonUniformShuffle);
self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
.unwrap();
});
}
Plane::ShuffleXor(op) => {
self.capabilities.insert(Capability::GroupNonUniformShuffle);
self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.group_non_uniform_shuffle_xor(ty, Some(out), subgroup, lhs, rhs)
.unwrap();
});
}
Plane::ShuffleUp(op) => {
self.capabilities.insert(Capability::GroupNonUniformShuffle);
self.capabilities
.insert(Capability::GroupNonUniformShuffleRelative);
self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.group_non_uniform_shuffle_up(ty, Some(out), subgroup, lhs, rhs)
.unwrap();
});
}
Plane::ShuffleDown(op) => {
self.capabilities.insert(Capability::GroupNonUniformShuffle);
self.capabilities
.insert(Capability::GroupNonUniformShuffleRelative);
self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.group_non_uniform_shuffle_down(ty, Some(out), subgroup, lhs, rhs)
.unwrap();
});
}
}
}
fn plane_sum(
&mut self,
op: UnaryOperator,
out: Variable,
action: GroupOperation,
uniform: bool,
) {
let subgroup = self.subgroup();
self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
match out_ty.elem() {
Elem::Int(_, _) => {
b.group_non_uniform_i_add(ty, Some(out), subgroup, action, input, None)
}
Elem::Float(..) | Elem::Relaxed => {
b.group_non_uniform_f_add(ty, Some(out), subgroup, action, input, None)
}
elem => unreachable!("{elem}"),
}
.unwrap();
});
}
fn plane_prod(
&mut self,
op: UnaryOperator,
out: Variable,
action: GroupOperation,
uniform: bool,
) {
let subgroup = self.subgroup();
self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
match out_ty.elem() {
Elem::Int(_, _) => {
b.group_non_uniform_i_mul(ty, Some(out), subgroup, action, input, None)
}
Elem::Float(..) | Elem::Relaxed => {
b.group_non_uniform_f_mul(ty, Some(out), subgroup, action, input, None)
}
_ => unreachable!(),
}
.unwrap();
});
}
fn subgroup(&mut self) -> Word {
self.const_u32(Scope::Subgroup as u32)
}
}