1use cubecl_core::ir::{
2 self as core, BinaryOperator, Comparison, Instruction, InstructionModes, Operation, Operator,
3 UnaryOperator,
4};
5use rspirv::spirv::{Capability, Decoration, Word};
6
7use crate::{
8 SpirvCompiler, SpirvTarget,
9 item::{Elem, Item},
10 variable::IndexedVariable,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14 pub fn compile_operation(&mut self, inst: Instruction) {
15 if !matches!(inst.operation, Operation::NonSemantic(_)) {
17 self.set_source_loc(&inst.source_loc);
18 }
19 let uniform = matches!(inst.out, Some(out) if self.uniformity.is_var_uniform(out));
20 match inst.operation {
21 Operation::Copy(var) => {
22 let input = self.compile_variable(var);
23 let out = self.compile_variable(inst.out());
24 let ty = out.item().id(self);
25 let in_id = self.read(&input);
26 let in_id = input.item().broadcast(self, in_id, None, &out.item());
27 let out_id = self.write_id(&out);
28
29 self.copy_object(ty, Some(out_id), in_id).unwrap();
30 self.mark_uniformity(out_id, uniform);
31 self.write(&out, out_id);
32 }
33 Operation::Arithmetic(operator) => {
34 self.compile_arithmetic(operator, inst.out, inst.modes, uniform)
35 }
36 Operation::Comparison(operator) => {
37 self.compile_cmp(operator, inst.out, inst.modes, uniform)
38 }
39 Operation::Bitwise(operator) => self.compile_bitwise(operator, inst.out, uniform),
40 Operation::Operator(operator) => self.compile_operator(operator, inst.out, uniform),
41 Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out, inst.modes),
42 Operation::Branch(_) => unreachable!("Branches shouldn't exist in optimized IR"),
43 Operation::Metadata(meta) => self.compile_meta(meta, inst.out, uniform),
44 Operation::Plane(plane) => self.compile_plane(plane, inst.out, uniform),
45 Operation::Synchronization(sync) => self.compile_sync(sync),
46 Operation::CoopMma(cmma) => self.compile_cmma(cmma, inst.out),
47 Operation::NonSemantic(debug) => self.compile_debug(debug),
48 Operation::Barrier(_) => panic!("Barrier not supported in SPIR-V"),
49 Operation::Tma(_) => panic!("TMA not supported in SPIR-V"),
50 Operation::Marker(_) => {}
51 }
52 }
53
54 pub fn compile_cmp(
55 &mut self,
56 op: Comparison,
57 out: Option<core::Variable>,
58 modes: InstructionModes,
59 uniform: bool,
60 ) {
61 let out = out.unwrap();
62 match op {
63 Comparison::Equal(op) => {
64 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
65 match lhs_ty.elem() {
66 Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs),
67 Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs),
68 Elem::Float(..) => {
69 b.declare_math_mode(modes, out);
70 b.f_ord_equal(ty, Some(out), lhs, rhs)
71 }
72 Elem::Relaxed => {
73 b.decorate(out, Decoration::RelaxedPrecision, []);
74 b.declare_math_mode(modes, out);
75 b.f_ord_equal(ty, Some(out), lhs, rhs)
76 }
77 Elem::Void => unreachable!(),
78 }
79 .unwrap();
80 });
81 }
82 Comparison::NotEqual(op) => {
83 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
84 match lhs_ty.elem() {
85 Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs),
86 Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs),
87 Elem::Float(..) => {
88 b.declare_math_mode(modes, out);
89 b.f_ord_not_equal(ty, Some(out), lhs, rhs)
90 }
91 Elem::Relaxed => {
92 b.decorate(out, Decoration::RelaxedPrecision, []);
93 b.declare_math_mode(modes, out);
94 b.f_ord_not_equal(ty, Some(out), lhs, rhs)
95 }
96 Elem::Void => unreachable!(),
97 }
98 .unwrap();
99 });
100 }
101 Comparison::Lower(op) => {
102 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
103 match lhs_ty.elem() {
104 Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs),
105 Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs),
106 Elem::Float(..) => {
107 b.declare_math_mode(modes, out);
108 b.f_ord_less_than(ty, Some(out), lhs, rhs)
109 }
110 Elem::Relaxed => {
111 b.decorate(out, Decoration::RelaxedPrecision, []);
112 b.declare_math_mode(modes, out);
113 b.f_ord_less_than(ty, Some(out), lhs, rhs)
114 }
115 _ => unreachable!(),
116 }
117 .unwrap();
118 });
119 }
120 Comparison::LowerEqual(op) => {
121 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
122 match lhs_ty.elem() {
123 Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs),
124 Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs),
125 Elem::Float(..) => {
126 b.declare_math_mode(modes, out);
127 b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
128 }
129 Elem::Relaxed => {
130 b.decorate(out, Decoration::RelaxedPrecision, []);
131 b.declare_math_mode(modes, out);
132 b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
133 }
134 _ => unreachable!(),
135 }
136 .unwrap();
137 });
138 }
139 Comparison::Greater(op) => {
140 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
141 match lhs_ty.elem() {
142 Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs),
143 Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs),
144 Elem::Float(..) => {
145 b.declare_math_mode(modes, out);
146 b.f_ord_greater_than(ty, Some(out), lhs, rhs)
147 }
148 Elem::Relaxed => {
149 b.decorate(out, Decoration::RelaxedPrecision, []);
150 b.declare_math_mode(modes, out);
151 b.f_ord_greater_than(ty, Some(out), lhs, rhs)
152 }
153 _ => unreachable!(),
154 }
155 .unwrap();
156 });
157 }
158 Comparison::GreaterEqual(op) => {
159 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
160 match lhs_ty.elem() {
161 Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs),
162 Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs),
163 Elem::Float(..) => {
164 b.declare_math_mode(modes, out);
165 b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
166 }
167 Elem::Relaxed => {
168 b.decorate(out, Decoration::RelaxedPrecision, []);
169 b.declare_math_mode(modes, out);
170 b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
171 }
172 _ => unreachable!(),
173 }
174 .unwrap();
175 });
176 }
177 Comparison::IsNan(op) => {
178 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
179 b.is_nan(ty, Some(out), input).unwrap();
180 });
181 }
182 Comparison::IsInf(op) => {
183 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
184 b.is_inf(ty, Some(out), input).unwrap();
185 });
186 }
187 }
188 }
189
190 pub fn compile_operator(&mut self, op: Operator, out: Option<core::Variable>, uniform: bool) {
191 let out = out.unwrap();
192 match op {
193 Operator::Index(op) | Operator::UncheckedIndex(op) => {
194 let is_atomic = op.list.ty.is_atomic();
195 let value = self.compile_variable(op.list);
196 let index = self.compile_variable(op.index);
197 let out = self.compile_variable(out);
198
199 if is_atomic {
200 let ptr = match self.index(&value, &index, true) {
201 IndexedVariable::Pointer(ptr, _) => ptr,
202 _ => unreachable!("Atomic is always pointer"),
203 };
204 let out_id = out.as_binding().unwrap();
205
206 self.merge_binding(out_id, ptr);
208 } else {
209 let out_id = self.read_indexed(&out, &value, &index);
210 self.mark_uniformity(out_id, uniform);
211 self.write(&out, out_id);
212 }
213 }
214 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
215 let index = self.compile_variable(op.index);
216 let value = self.compile_variable(op.value);
217 let out = self.compile_variable(out);
218 let value_id = self.read_as(&value, &out.indexed_item());
219
220 self.write_indexed(&out, &index, value_id);
221 }
222 Operator::Cast(op) => {
223 let input = self.compile_variable(op.input);
224 let out = self.compile_variable(out);
225 let ty = out.item().id(self);
226 let in_id = self.read(&input);
227 let out_id = self.write_id(&out);
228 self.mark_uniformity(out_id, uniform);
229
230 if let Some(as_const) = input.as_const() {
231 let cast = self.static_cast(as_const, &input.elem(), &out.item());
232 self.copy_object(ty, Some(out_id), cast).unwrap();
233 } else {
234 input.item().cast_to(self, Some(out_id), in_id, &out.item());
235 }
236
237 self.write(&out, out_id);
238 }
239 Operator::And(op) => {
240 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
241 b.logical_and(ty, Some(out), lhs, rhs).unwrap();
242 });
243 }
244 Operator::Or(op) => {
245 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
246 b.logical_or(ty, Some(out), lhs, rhs).unwrap();
247 });
248 }
249 Operator::Not(op) => {
250 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
251 b.logical_not(ty, Some(out), input).unwrap();
252 });
253 }
254 Operator::Reinterpret(op) => {
255 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
256 b.bitcast(ty, Some(out), input).unwrap();
257 })
258 }
259 Operator::InitLine(op) => {
260 let values = op
261 .inputs
262 .into_iter()
263 .map(|input| self.compile_variable(input))
264 .collect::<Vec<_>>()
265 .into_iter()
266 .map(|it| self.read(&it))
267 .collect::<Vec<_>>();
268 let item = self.compile_type(out.ty);
269 let out = self.compile_variable(out);
270 let out_id = self.write_id(&out);
271 self.mark_uniformity(out_id, uniform);
272 let ty = item.id(self);
273 self.composite_construct(ty, Some(out_id), values).unwrap();
274 self.write(&out, out_id);
275 }
276 Operator::CopyMemory(op) => {
277 let input = self.compile_variable(op.input);
278 let in_index = self.compile_variable(op.in_index);
279 let out = self.compile_variable(out);
280 let out_index = self.compile_variable(op.out_index);
281
282 let in_ptr = self.index_ptr(&input, &in_index);
283 let out_ptr = self.index_ptr(&out, &out_index);
284 self.copy_memory(out_ptr, in_ptr, None, None, vec![])
285 .unwrap();
286 }
287 Operator::CopyMemoryBulk(op) => {
288 self.capabilities.insert(Capability::Addresses);
289 let input = self.compile_variable(op.input);
290 let in_index = self.compile_variable(op.in_index);
291 let out = self.compile_variable(out);
292 let out_index = self.compile_variable(op.out_index);
293 let len = op.len;
294
295 let source = self.index_ptr(&input, &in_index);
296 let target = self.index_ptr(&out, &out_index);
297 let size = self.const_u32(len * out.item().size());
298 self.copy_memory_sized(target, source, size, None, None, vec![])
299 .unwrap();
300 }
301 Operator::Select(op) => self.compile_select(op.cond, op.then, op.or_else, out, uniform),
302 }
303 }
304
305 pub fn compile_unary_op_cast(
306 &mut self,
307 op: UnaryOperator,
308 out: core::Variable,
309 uniform: bool,
310 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
311 ) {
312 let input = self.compile_variable(op.input);
313 let out = self.compile_variable(out);
314 let out_ty = out.item();
315
316 let input_id = self.read_as(&input, &out_ty);
317 let out_id = self.write_id(&out);
318 self.mark_uniformity(out_id, uniform);
319
320 let ty = out_ty.id(self);
321
322 exec(self, out_ty, ty, input_id, out_id);
323 self.write(&out, out_id);
324 }
325
326 pub fn compile_unary_op(
327 &mut self,
328 op: UnaryOperator,
329 out: core::Variable,
330 uniform: bool,
331 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
332 ) {
333 let input = self.compile_variable(op.input);
334 let out = self.compile_variable(out);
335 let out_ty = out.item();
336
337 let input_id = self.read(&input);
338 let out_id = self.write_id(&out);
339 self.mark_uniformity(out_id, uniform);
340
341 let ty = out_ty.id(self);
342
343 exec(self, out_ty, ty, input_id, out_id);
344 self.write(&out, out_id);
345 }
346
347 pub fn compile_unary_op_bool(
348 &mut self,
349 op: UnaryOperator,
350 out: core::Variable,
351 uniform: bool,
352 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
353 ) {
354 let input = self.compile_variable(op.input);
355 let out = self.compile_variable(out);
356 let in_ty = input.item();
357
358 let input_id = self.read(&input);
359 let out_id = self.write_id(&out);
360 self.mark_uniformity(out_id, uniform);
361
362 let ty = out.item().id(self);
363
364 exec(self, in_ty, ty, input_id, out_id);
365 self.write(&out, out_id);
366 }
367
368 pub fn compile_binary_op(
369 &mut self,
370 op: BinaryOperator,
371 out: core::Variable,
372 uniform: bool,
373 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
374 ) {
375 let lhs = self.compile_variable(op.lhs);
376 let rhs = self.compile_variable(op.rhs);
377 let out = self.compile_variable(out);
378 let out_ty = out.item();
379
380 let lhs_id = self.read_as(&lhs, &out_ty);
381 let rhs_id = self.read_as(&rhs, &out_ty);
382 let out_id = self.write_id(&out);
383 self.mark_uniformity(out_id, uniform);
384
385 let ty = out_ty.id(self);
386
387 exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
388 self.write(&out, out_id);
389 }
390
391 pub fn compile_binary_op_no_cast(
392 &mut self,
393 op: BinaryOperator,
394 out: core::Variable,
395 uniform: bool,
396 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
397 ) {
398 let lhs = self.compile_variable(op.lhs);
399 let rhs = self.compile_variable(op.rhs);
400 let out = self.compile_variable(out);
401 let out_ty = out.item();
402
403 let lhs_id = self.read(&lhs);
404 let rhs_id = self.read(&rhs);
405 let out_id = self.write_id(&out);
406 self.mark_uniformity(out_id, uniform);
407
408 let ty = out_ty.id(self);
409
410 exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
411 self.write(&out, out_id);
412 }
413
414 pub fn compile_binary_op_bool(
415 &mut self,
416 op: BinaryOperator,
417 out: core::Variable,
418 uniform: bool,
419 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
420 ) {
421 let lhs = self.compile_variable(op.lhs);
422 let rhs = self.compile_variable(op.rhs);
423 let out = self.compile_variable(out);
424
425 let in_ty = out.item().same_vectorization(lhs.elem());
426
427 let lhs_id = self.read_as(&lhs, &in_ty);
428 let rhs_id = self.read_as(&rhs, &in_ty);
429 let out_id = self.write_id(&out);
430 self.mark_uniformity(out_id, uniform);
431
432 let ty = out.item().id(self);
433
434 exec(self, in_ty, ty, lhs_id, rhs_id, out_id);
435 self.write(&out, out_id);
436 }
437
438 pub fn compile_select(
439 &mut self,
440 cond: core::Variable,
441 then: core::Variable,
442 or_else: core::Variable,
443 out: core::Variable,
444 uniform: bool,
445 ) {
446 let cond = self.compile_variable(cond);
447 let then = self.compile_variable(then);
448 let or_else = self.compile_variable(or_else);
449 let out = self.compile_variable(out);
450
451 let out_ty = out.item();
452 let ty = out_ty.id(self);
453
454 let cond_id = self.read(&cond);
455 let then = self.read_as(&then, &out_ty);
456 let or_else = self.read_as(&or_else, &out_ty);
457 let out_id = self.write_id(&out);
458 self.mark_uniformity(out_id, uniform);
459
460 self.select(ty, Some(out_id), cond_id, then, or_else)
461 .unwrap();
462 self.write(&out, out_id);
463 }
464
465 pub fn mark_uniformity(&mut self, id: Word, uniform: bool) {
466 if uniform {
467 self.decorate(id, Decoration::Uniform, []);
468 }
469 }
470}