1use cubecl_core::ir::{Plane, UnaryOperator, Variable};
2use rspirv::spirv::{Capability, GroupOperation, Scope, Word};
3
4use crate::{SpirvCompiler, SpirvTarget, item::Elem};
5
6impl<T: SpirvTarget> SpirvCompiler<T> {
7 pub fn compile_plane(&mut self, plane: Plane, out: Option<Variable>, uniform: bool) {
8 self.capabilities
9 .insert(Capability::GroupNonUniformArithmetic);
10 let subgroup = self.subgroup();
11 let out = out.unwrap();
12 match plane {
13 Plane::Elect => {
14 let out = self.compile_variable(out);
15 let out_id = self.write_id(&out);
16 let bool = self.type_bool();
17 self.group_non_uniform_elect(bool, Some(out_id), subgroup)
18 .unwrap();
19 self.write(&out, out_id);
20 }
21 Plane::All(op) => {
22 self.capabilities.insert(Capability::GroupNonUniformVote);
23 match out.line_size() {
24 1 => {
25 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
26 b.group_non_uniform_all(ty, Some(out), subgroup, input)
27 .unwrap();
28 });
29 }
30 vec => {
31 let elem_ty = self.compile_type(op.input.ty).elem().id(self);
32 let bool_ty = self.type_bool();
33
34 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
35 let ids = (0..vec)
36 .map(|i| {
37 let elem_i =
38 b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
39 b.group_non_uniform_all(bool_ty, None, subgroup, elem_i)
40 .unwrap()
41 })
42 .collect::<Vec<_>>();
43 b.composite_construct(ty, Some(out), ids).unwrap();
44 });
45 }
46 };
47 }
48 Plane::Any(op) => {
49 self.capabilities.insert(Capability::GroupNonUniformVote);
50 match out.line_size() {
51 1 => {
52 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
53 b.group_non_uniform_any(ty, Some(out), subgroup, input)
54 .unwrap();
55 });
56 }
57 vec => {
58 let elem_ty = self.compile_type(op.input.ty).elem().id(self);
59 let bool_ty = self.type_bool();
60
61 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
62 let ids = (0..vec)
63 .map(|i| {
64 let elem_i =
65 b.composite_extract(elem_ty, None, input, vec![i]).unwrap();
66 b.group_non_uniform_any(bool_ty, None, subgroup, elem_i)
67 .unwrap()
68 })
69 .collect::<Vec<_>>();
70 b.composite_construct(ty, Some(out), ids).unwrap();
71 });
72 }
73 };
74 }
75 Plane::Ballot(op) => {
76 self.capabilities.insert(Capability::GroupNonUniformBallot);
77 assert_eq!(
78 op.input.line_size(),
79 1,
80 "plane_ballot can't work with vectorized values"
81 );
82 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
83 b.group_non_uniform_ballot(ty, Some(out), subgroup, input)
84 .unwrap();
85 });
86 }
87 Plane::Broadcast(op) => {
88 let is_broadcast = self.uniformity.is_var_uniform(op.rhs);
89 self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
90 match is_broadcast {
91 true => {
92 b.capabilities.insert(Capability::GroupNonUniformBallot);
93 b.group_non_uniform_broadcast(ty, Some(out), subgroup, lhs, rhs)
94 .unwrap();
95 }
96 false => {
97 b.capabilities.insert(Capability::GroupNonUniformShuffle);
98 b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
99 .unwrap();
100 }
101 }
102 });
103 }
104 Plane::Sum(op) => {
105 self.plane_sum(op, out, GroupOperation::Reduce, uniform);
106 }
107 Plane::ExclusiveSum(op) => {
108 self.plane_sum(op, out, GroupOperation::ExclusiveScan, uniform);
109 }
110 Plane::InclusiveSum(op) => {
111 self.plane_sum(op, out, GroupOperation::InclusiveScan, uniform);
112 }
113 Plane::Prod(op) => {
114 self.plane_prod(op, out, GroupOperation::Reduce, uniform);
115 }
116 Plane::ExclusiveProd(op) => {
117 self.plane_prod(op, out, GroupOperation::ExclusiveScan, uniform);
118 }
119 Plane::InclusiveProd(op) => {
120 self.plane_prod(op, out, GroupOperation::InclusiveScan, uniform);
121 }
122 Plane::Min(op) => {
123 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
124 match out_ty.elem() {
125 Elem::Int(_, false) => b.group_non_uniform_u_min(
126 ty,
127 Some(out),
128 subgroup,
129 GroupOperation::Reduce,
130 input,
131 None,
132 ),
133 Elem::Int(_, true) => b.group_non_uniform_s_min(
134 ty,
135 Some(out),
136 subgroup,
137 GroupOperation::Reduce,
138 input,
139 None,
140 ),
141 Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_min(
142 ty,
143 Some(out),
144 subgroup,
145 GroupOperation::Reduce,
146 input,
147 None,
148 ),
149 _ => unreachable!(),
150 }
151 .unwrap();
152 });
153 }
154 Plane::Max(op) => {
155 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
156 match out_ty.elem() {
157 Elem::Int(_, false) => b.group_non_uniform_u_max(
158 ty,
159 Some(out),
160 subgroup,
161 GroupOperation::Reduce,
162 input,
163 None,
164 ),
165 Elem::Int(_, true) => b.group_non_uniform_s_max(
166 ty,
167 Some(out),
168 subgroup,
169 GroupOperation::Reduce,
170 input,
171 None,
172 ),
173 Elem::Float(..) | Elem::Relaxed => b.group_non_uniform_f_max(
174 ty,
175 Some(out),
176 subgroup,
177 GroupOperation::Reduce,
178 input,
179 None,
180 ),
181 _ => unreachable!(),
182 }
183 .unwrap();
184 });
185 }
186 Plane::Shuffle(op) => {
187 self.capabilities.insert(Capability::GroupNonUniformShuffle);
188 self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
189 b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs)
190 .unwrap();
191 });
192 }
193 Plane::ShuffleXor(op) => {
194 self.capabilities.insert(Capability::GroupNonUniformShuffle);
195 self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
196 b.group_non_uniform_shuffle_xor(ty, Some(out), subgroup, lhs, rhs)
197 .unwrap();
198 });
199 }
200 Plane::ShuffleUp(op) => {
201 self.capabilities.insert(Capability::GroupNonUniformShuffle);
202 self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
203 b.group_non_uniform_shuffle_up(ty, Some(out), subgroup, lhs, rhs)
204 .unwrap();
205 });
206 }
207 Plane::ShuffleDown(op) => {
208 self.capabilities.insert(Capability::GroupNonUniformShuffle);
209 self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| {
210 b.group_non_uniform_shuffle_down(ty, Some(out), subgroup, lhs, rhs)
211 .unwrap();
212 });
213 }
214 }
215 }
216
217 fn plane_sum(
218 &mut self,
219 op: UnaryOperator,
220 out: Variable,
221 action: GroupOperation,
222 uniform: bool,
223 ) {
224 let subgroup = self.subgroup();
225 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
226 match out_ty.elem() {
227 Elem::Int(_, _) => {
228 b.group_non_uniform_i_add(ty, Some(out), subgroup, action, input, None)
229 }
230 Elem::Float(..) | Elem::Relaxed => {
231 b.group_non_uniform_f_add(ty, Some(out), subgroup, action, input, None)
232 }
233 elem => unreachable!("{elem}"),
234 }
235 .unwrap();
236 });
237 }
238
239 fn plane_prod(
240 &mut self,
241 op: UnaryOperator,
242 out: Variable,
243 action: GroupOperation,
244 uniform: bool,
245 ) {
246 let subgroup = self.subgroup();
247 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
248 match out_ty.elem() {
249 Elem::Int(_, _) => {
250 b.group_non_uniform_i_mul(ty, Some(out), subgroup, action, input, None)
251 }
252 Elem::Float(..) | Elem::Relaxed => {
253 b.group_non_uniform_f_mul(ty, Some(out), subgroup, action, input, None)
254 }
255 _ => unreachable!(),
256 }
257 .unwrap();
258 });
259 }
260
261 fn subgroup(&mut self) -> Word {
262 self.const_u32(Scope::Subgroup as u32)
263 }
264}