1use cubecl_core::ir::{
2 self as core, BinaryOperator, Comparison, Instruction, InstructionModes, Operation, Operator,
3 UnaryOperator,
4};
5use rspirv::spirv::{Capability, Decoration, Word};
6
7use crate::{
8 SpirvCompiler, SpirvTarget,
9 item::{Elem, Item},
10 variable::IndexedVariable,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14 pub fn compile_operation(&mut self, inst: Instruction) {
15 if !matches!(inst.operation, Operation::NonSemantic(_)) {
17 self.set_source_loc(&inst.source_loc);
18 }
19 let uniform = matches!(inst.out, Some(out) if self.uniformity.is_var_uniform(out));
20 match inst.operation {
21 Operation::Copy(var) => {
22 let input = self.compile_variable(var);
23 let out = self.compile_variable(inst.out());
24 let ty = out.item().id(self);
25 let in_id = self.read(&input);
26 let in_id = input.item().broadcast(self, in_id, None, &out.item());
27 let out_id = self.write_id(&out);
28
29 self.copy_object(ty, Some(out_id), in_id).unwrap();
30 self.mark_uniformity(out_id, uniform);
31 self.write(&out, out_id);
32 }
33 Operation::Arithmetic(operator) => {
34 self.compile_arithmetic(operator, inst.out, inst.modes, uniform)
35 }
36 Operation::Comparison(operator) => {
37 self.compile_cmp(operator, inst.out, inst.modes, uniform)
38 }
39 Operation::Bitwise(operator) => self.compile_bitwise(operator, inst.out, uniform),
40 Operation::Operator(operator) => self.compile_operator(operator, inst.out, uniform),
41 Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out, inst.modes),
42 Operation::Branch(_) => unreachable!("Branches shouldn't exist in optimized IR"),
43 Operation::Metadata(meta) => self.compile_meta(meta, inst.out, uniform),
44 Operation::Plane(plane) => self.compile_plane(plane, inst.out, uniform),
45 Operation::Synchronization(sync) => self.compile_sync(sync),
46 Operation::CoopMma(cmma) => self.compile_cmma(cmma, inst.out),
47 Operation::NonSemantic(debug) => self.compile_debug(debug),
48 Operation::Barrier(_) => panic!("Barrier not supported in SPIR-V"),
49 Operation::Tma(_) => panic!("TMA not supported in SPIR-V"),
50 Operation::Marker(_) => {}
51 }
52 }
53
54 pub fn compile_cmp(
55 &mut self,
56 op: Comparison,
57 out: Option<core::Variable>,
58 modes: InstructionModes,
59 uniform: bool,
60 ) {
61 let out = out.unwrap();
62 match op {
63 Comparison::Equal(op) => {
64 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
65 match lhs_ty.elem() {
66 Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs),
67 Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs),
68 Elem::Float(..) => {
69 b.declare_math_mode(modes, out);
70 b.f_ord_equal(ty, Some(out), lhs, rhs)
71 }
72 Elem::Relaxed => {
73 b.decorate(out, Decoration::RelaxedPrecision, []);
74 b.declare_math_mode(modes, out);
75 b.f_ord_equal(ty, Some(out), lhs, rhs)
76 }
77 Elem::Void => unreachable!(),
78 }
79 .unwrap();
80 });
81 }
82 Comparison::NotEqual(op) => {
83 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
84 match lhs_ty.elem() {
85 Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs),
86 Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs),
87 Elem::Float(..) => {
88 b.declare_math_mode(modes, out);
89 b.f_ord_not_equal(ty, Some(out), lhs, rhs)
90 }
91 Elem::Relaxed => {
92 b.decorate(out, Decoration::RelaxedPrecision, []);
93 b.declare_math_mode(modes, out);
94 b.f_ord_not_equal(ty, Some(out), lhs, rhs)
95 }
96 Elem::Void => unreachable!(),
97 }
98 .unwrap();
99 });
100 }
101 Comparison::Lower(op) => {
102 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
103 match lhs_ty.elem() {
104 Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs),
105 Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs),
106 Elem::Float(..) => {
107 b.declare_math_mode(modes, out);
108 b.f_ord_less_than(ty, Some(out), lhs, rhs)
109 }
110 Elem::Relaxed => {
111 b.decorate(out, Decoration::RelaxedPrecision, []);
112 b.declare_math_mode(modes, out);
113 b.f_ord_less_than(ty, Some(out), lhs, rhs)
114 }
115 _ => unreachable!(),
116 }
117 .unwrap();
118 });
119 }
120 Comparison::LowerEqual(op) => {
121 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
122 match lhs_ty.elem() {
123 Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs),
124 Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs),
125 Elem::Float(..) => {
126 b.declare_math_mode(modes, out);
127 b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
128 }
129 Elem::Relaxed => {
130 b.decorate(out, Decoration::RelaxedPrecision, []);
131 b.declare_math_mode(modes, out);
132 b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
133 }
134 _ => unreachable!(),
135 }
136 .unwrap();
137 });
138 }
139 Comparison::Greater(op) => {
140 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
141 match lhs_ty.elem() {
142 Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs),
143 Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs),
144 Elem::Float(..) => {
145 b.declare_math_mode(modes, out);
146 b.f_ord_greater_than(ty, Some(out), lhs, rhs)
147 }
148 Elem::Relaxed => {
149 b.decorate(out, Decoration::RelaxedPrecision, []);
150 b.declare_math_mode(modes, out);
151 b.f_ord_greater_than(ty, Some(out), lhs, rhs)
152 }
153 _ => unreachable!(),
154 }
155 .unwrap();
156 });
157 }
158 Comparison::GreaterEqual(op) => {
159 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
160 match lhs_ty.elem() {
161 Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs),
162 Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs),
163 Elem::Float(..) => {
164 b.declare_math_mode(modes, out);
165 b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
166 }
167 Elem::Relaxed => {
168 b.decorate(out, Decoration::RelaxedPrecision, []);
169 b.declare_math_mode(modes, out);
170 b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
171 }
172 _ => unreachable!(),
173 }
174 .unwrap();
175 });
176 }
177 Comparison::IsNan(op) => {
178 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
179 b.is_nan(ty, Some(out), input).unwrap();
180 });
181 }
182 Comparison::IsInf(op) => {
183 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
184 b.is_inf(ty, Some(out), input).unwrap();
185 });
186 }
187 }
188 }
189
190 pub fn compile_operator(&mut self, op: Operator, out: Option<core::Variable>, uniform: bool) {
191 let out = out.unwrap();
192 match op {
193 Operator::Index(op) | Operator::UncheckedIndex(op) => {
194 let is_atomic = op.list.ty.is_atomic();
195 let value = self.compile_variable(op.list);
196 let index = self.compile_variable(op.index);
197 let out = self.compile_variable(out);
198
199 if is_atomic {
200 let ptr = match self.index(&value, &index, true) {
201 IndexedVariable::Pointer(ptr, _) => ptr,
202 _ => unreachable!("Atomic is always pointer"),
203 };
204 self.state.atomic_scopes.insert(ptr, value.scope());
205 let out_id = out.as_binding().unwrap();
206
207 self.merge_binding(out_id, ptr);
209 } else {
210 let out_id = self.read_indexed(&out, &value, &index);
211 self.mark_uniformity(out_id, uniform);
212 self.write(&out, out_id);
213 }
214 }
215 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
216 let index = self.compile_variable(op.index);
217 let value = self.compile_variable(op.value);
218 let out = self.compile_variable(out);
219 let value_id = self.read_as(&value, &out.indexed_item());
220
221 self.write_indexed(&out, &index, value_id);
222 }
223 Operator::Cast(op) => {
224 let input = self.compile_variable(op.input);
225 let out = self.compile_variable(out);
226 let ty = out.item().id(self);
227 let in_id = self.read(&input);
228 let out_id = self.write_id(&out);
229 self.mark_uniformity(out_id, uniform);
230
231 if let Some(as_const) = input.as_const() {
232 let cast = self.static_cast(as_const, &input.elem(), &out.item()).0;
233 self.copy_object(ty, Some(out_id), cast).unwrap();
234 } else {
235 input.item().cast_to(self, Some(out_id), in_id, &out.item());
236 }
237
238 self.write(&out, out_id);
239 }
240 Operator::And(op) => {
241 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
242 b.logical_and(ty, Some(out), lhs, rhs).unwrap();
243 });
244 }
245 Operator::Or(op) => {
246 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
247 b.logical_or(ty, Some(out), lhs, rhs).unwrap();
248 });
249 }
250 Operator::Not(op) => {
251 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
252 b.logical_not(ty, Some(out), input).unwrap();
253 });
254 }
255 Operator::Reinterpret(op) => {
256 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
257 b.bitcast(ty, Some(out), input).unwrap();
258 })
259 }
260 Operator::InitLine(op) => {
261 let values = op
262 .inputs
263 .into_iter()
264 .map(|input| self.compile_variable(input))
265 .collect::<Vec<_>>()
266 .into_iter()
267 .map(|it| self.read(&it))
268 .collect::<Vec<_>>();
269 let item = self.compile_type(out.ty);
270 let out = self.compile_variable(out);
271 let out_id = self.write_id(&out);
272 self.mark_uniformity(out_id, uniform);
273 let ty = item.id(self);
274 self.composite_construct(ty, Some(out_id), values).unwrap();
275 self.write(&out, out_id);
276 }
277 Operator::CopyMemory(op) => {
278 let input = self.compile_variable(op.input);
279 let in_index = self.compile_variable(op.in_index);
280 let out = self.compile_variable(out);
281 let out_index = self.compile_variable(op.out_index);
282
283 let in_ptr = self.index_ptr(&input, &in_index);
284 let out_ptr = self.index_ptr(&out, &out_index);
285 self.copy_memory(out_ptr, in_ptr, None, None, vec![])
286 .unwrap();
287 }
288 Operator::CopyMemoryBulk(op) => {
289 self.capabilities.insert(Capability::Addresses);
290 let input = self.compile_variable(op.input);
291 let in_index = self.compile_variable(op.in_index);
292 let out = self.compile_variable(out);
293 let out_index = self.compile_variable(op.out_index);
294 let len = op.len;
295
296 let source = self.index_ptr(&input, &in_index);
297 let target = self.index_ptr(&out, &out_index);
298 let size = self.const_u32(len as u32 * out.item().size());
299 self.copy_memory_sized(target, source, size, None, None, vec![])
300 .unwrap();
301 }
302 Operator::Select(op) => self.compile_select(op.cond, op.then, op.or_else, out, uniform),
303 }
304 }
305
306 pub fn compile_unary_op_cast(
307 &mut self,
308 op: UnaryOperator,
309 out: core::Variable,
310 uniform: bool,
311 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
312 ) {
313 let input = self.compile_variable(op.input);
314 let out = self.compile_variable(out);
315 let out_ty = out.item();
316
317 let input_id = self.read_as(&input, &out_ty);
318 let out_id = self.write_id(&out);
319 self.mark_uniformity(out_id, uniform);
320
321 let ty = out_ty.id(self);
322
323 exec(self, out_ty, ty, input_id, out_id);
324 self.write(&out, out_id);
325 }
326
327 pub fn compile_unary_op(
328 &mut self,
329 op: UnaryOperator,
330 out: core::Variable,
331 uniform: bool,
332 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
333 ) {
334 let input = self.compile_variable(op.input);
335 let out = self.compile_variable(out);
336 let out_ty = out.item();
337
338 let input_id = self.read(&input);
339 let out_id = self.write_id(&out);
340 self.mark_uniformity(out_id, uniform);
341
342 let ty = out_ty.id(self);
343
344 exec(self, out_ty, ty, input_id, out_id);
345 self.write(&out, out_id);
346 }
347
348 pub fn compile_unary_op_bool(
349 &mut self,
350 op: UnaryOperator,
351 out: core::Variable,
352 uniform: bool,
353 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
354 ) {
355 let input = self.compile_variable(op.input);
356 let out = self.compile_variable(out);
357 let in_ty = input.item();
358
359 let input_id = self.read(&input);
360 let out_id = self.write_id(&out);
361 self.mark_uniformity(out_id, uniform);
362
363 let ty = out.item().id(self);
364
365 exec(self, in_ty, ty, input_id, out_id);
366 self.write(&out, out_id);
367 }
368
369 pub fn compile_binary_op(
370 &mut self,
371 op: BinaryOperator,
372 out: core::Variable,
373 uniform: bool,
374 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
375 ) {
376 let lhs = self.compile_variable(op.lhs);
377 let rhs = self.compile_variable(op.rhs);
378 let out = self.compile_variable(out);
379 let out_ty = out.item();
380
381 let lhs_id = self.read_as(&lhs, &out_ty);
382 let rhs_id = self.read_as(&rhs, &out_ty);
383 let out_id = self.write_id(&out);
384 self.mark_uniformity(out_id, uniform);
385
386 let ty = out_ty.id(self);
387
388 exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
389 self.write(&out, out_id);
390 }
391
392 pub fn compile_binary_op_no_cast(
393 &mut self,
394 op: BinaryOperator,
395 out: core::Variable,
396 uniform: bool,
397 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
398 ) {
399 let lhs = self.compile_variable(op.lhs);
400 let rhs = self.compile_variable(op.rhs);
401 let out = self.compile_variable(out);
402 let out_ty = out.item();
403
404 let lhs_id = self.read(&lhs);
405 let rhs_id = self.read(&rhs);
406 let out_id = self.write_id(&out);
407 self.mark_uniformity(out_id, uniform);
408
409 let ty = out_ty.id(self);
410
411 exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
412 self.write(&out, out_id);
413 }
414
415 pub fn compile_binary_op_bool(
416 &mut self,
417 op: BinaryOperator,
418 out: core::Variable,
419 uniform: bool,
420 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
421 ) {
422 let lhs = self.compile_variable(op.lhs);
423 let rhs = self.compile_variable(op.rhs);
424 let out = self.compile_variable(out);
425
426 let in_ty = out.item().same_vectorization(lhs.elem());
427
428 let lhs_id = self.read_as(&lhs, &in_ty);
429 let rhs_id = self.read_as(&rhs, &in_ty);
430 let out_id = self.write_id(&out);
431 self.mark_uniformity(out_id, uniform);
432
433 let ty = out.item().id(self);
434
435 exec(self, in_ty, ty, lhs_id, rhs_id, out_id);
436 self.write(&out, out_id);
437 }
438
439 pub fn compile_select(
440 &mut self,
441 cond: core::Variable,
442 then: core::Variable,
443 or_else: core::Variable,
444 out: core::Variable,
445 uniform: bool,
446 ) {
447 let cond = self.compile_variable(cond);
448 let then = self.compile_variable(then);
449 let or_else = self.compile_variable(or_else);
450 let out = self.compile_variable(out);
451
452 let out_ty = out.item();
453 let ty = out_ty.id(self);
454
455 let cond_id = self.read(&cond);
456 let then = self.read_as(&then, &out_ty);
457 let or_else = self.read_as(&or_else, &out_ty);
458 let out_id = self.write_id(&out);
459 self.mark_uniformity(out_id, uniform);
460
461 self.select(ty, Some(out_id), cond_id, then, or_else)
462 .unwrap();
463 self.write(&out, out_id);
464 }
465
466 pub fn mark_uniformity(&mut self, id: Word, uniform: bool) {
467 if uniform {
468 self.decorate(id, Decoration::Uniform, []);
469 }
470 }
471}