cubecl_spirv/
instruction.rs

1use cubecl_core::ir::{
2    self as core, BinaryOperator, Comparison, Instruction, InstructionModes, Operation, Operator,
3    UnaryOperator,
4};
5use rspirv::spirv::{Capability, Decoration, Word};
6
7use crate::{
8    SpirvCompiler, SpirvTarget,
9    item::{Elem, Item},
10    variable::IndexedVariable,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14    pub fn compile_operation(&mut self, inst: Instruction) {
15        // Setting source loc for non-semantic ops is pointless, they don't show up in a profiler/debugger.
16        if !matches!(inst.operation, Operation::NonSemantic(_)) {
17            self.set_source_loc(&inst.source_loc);
18        }
19        let uniform = matches!(inst.out, Some(out) if self.uniformity.is_var_uniform(out));
20        match inst.operation {
21            Operation::Copy(var) => {
22                let input = self.compile_variable(var);
23                let out = self.compile_variable(inst.out());
24                let ty = out.item().id(self);
25                let in_id = self.read(&input);
26                let in_id = input.item().broadcast(self, in_id, None, &out.item());
27                let out_id = self.write_id(&out);
28
29                self.copy_object(ty, Some(out_id), in_id).unwrap();
30                self.mark_uniformity(out_id, uniform);
31                self.write(&out, out_id);
32            }
33            Operation::Arithmetic(operator) => {
34                self.compile_arithmetic(operator, inst.out, inst.modes, uniform)
35            }
36            Operation::Comparison(operator) => {
37                self.compile_cmp(operator, inst.out, inst.modes, uniform)
38            }
39            Operation::Bitwise(operator) => self.compile_bitwise(operator, inst.out, uniform),
40            Operation::Operator(operator) => self.compile_operator(operator, inst.out, uniform),
41            Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out, inst.modes),
42            Operation::Branch(_) => unreachable!("Branches shouldn't exist in optimized IR"),
43            Operation::Metadata(meta) => self.compile_meta(meta, inst.out, uniform),
44            Operation::Plane(plane) => self.compile_plane(plane, inst.out, uniform),
45            Operation::Synchronization(sync) => self.compile_sync(sync),
46            Operation::CoopMma(cmma) => self.compile_cmma(cmma, inst.out),
47            Operation::NonSemantic(debug) => self.compile_debug(debug),
48            Operation::Barrier(_) => panic!("Barrier not supported in SPIR-V"),
49            Operation::Tma(_) => panic!("TMA not supported in SPIR-V"),
50            Operation::Marker(_) => {}
51        }
52    }
53
54    pub fn compile_cmp(
55        &mut self,
56        op: Comparison,
57        out: Option<core::Variable>,
58        modes: InstructionModes,
59        uniform: bool,
60    ) {
61        let out = out.unwrap();
62        match op {
63            Comparison::Equal(op) => {
64                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
65                    match lhs_ty.elem() {
66                        Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs),
67                        Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs),
68                        Elem::Float(..) => {
69                            b.declare_math_mode(modes, out);
70                            b.f_ord_equal(ty, Some(out), lhs, rhs)
71                        }
72                        Elem::Relaxed => {
73                            b.decorate(out, Decoration::RelaxedPrecision, []);
74                            b.declare_math_mode(modes, out);
75                            b.f_ord_equal(ty, Some(out), lhs, rhs)
76                        }
77                        Elem::Void => unreachable!(),
78                    }
79                    .unwrap();
80                });
81            }
82            Comparison::NotEqual(op) => {
83                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
84                    match lhs_ty.elem() {
85                        Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs),
86                        Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs),
87                        Elem::Float(..) => {
88                            b.declare_math_mode(modes, out);
89                            b.f_ord_not_equal(ty, Some(out), lhs, rhs)
90                        }
91                        Elem::Relaxed => {
92                            b.decorate(out, Decoration::RelaxedPrecision, []);
93                            b.declare_math_mode(modes, out);
94                            b.f_ord_not_equal(ty, Some(out), lhs, rhs)
95                        }
96                        Elem::Void => unreachable!(),
97                    }
98                    .unwrap();
99                });
100            }
101            Comparison::Lower(op) => {
102                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
103                    match lhs_ty.elem() {
104                        Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs),
105                        Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs),
106                        Elem::Float(..) => {
107                            b.declare_math_mode(modes, out);
108                            b.f_ord_less_than(ty, Some(out), lhs, rhs)
109                        }
110                        Elem::Relaxed => {
111                            b.decorate(out, Decoration::RelaxedPrecision, []);
112                            b.declare_math_mode(modes, out);
113                            b.f_ord_less_than(ty, Some(out), lhs, rhs)
114                        }
115                        _ => unreachable!(),
116                    }
117                    .unwrap();
118                });
119            }
120            Comparison::LowerEqual(op) => {
121                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
122                    match lhs_ty.elem() {
123                        Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs),
124                        Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs),
125                        Elem::Float(..) => {
126                            b.declare_math_mode(modes, out);
127                            b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
128                        }
129                        Elem::Relaxed => {
130                            b.decorate(out, Decoration::RelaxedPrecision, []);
131                            b.declare_math_mode(modes, out);
132                            b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
133                        }
134                        _ => unreachable!(),
135                    }
136                    .unwrap();
137                });
138            }
139            Comparison::Greater(op) => {
140                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
141                    match lhs_ty.elem() {
142                        Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs),
143                        Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs),
144                        Elem::Float(..) => {
145                            b.declare_math_mode(modes, out);
146                            b.f_ord_greater_than(ty, Some(out), lhs, rhs)
147                        }
148                        Elem::Relaxed => {
149                            b.decorate(out, Decoration::RelaxedPrecision, []);
150                            b.declare_math_mode(modes, out);
151                            b.f_ord_greater_than(ty, Some(out), lhs, rhs)
152                        }
153                        _ => unreachable!(),
154                    }
155                    .unwrap();
156                });
157            }
158            Comparison::GreaterEqual(op) => {
159                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
160                    match lhs_ty.elem() {
161                        Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs),
162                        Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs),
163                        Elem::Float(..) => {
164                            b.declare_math_mode(modes, out);
165                            b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
166                        }
167                        Elem::Relaxed => {
168                            b.decorate(out, Decoration::RelaxedPrecision, []);
169                            b.declare_math_mode(modes, out);
170                            b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
171                        }
172                        _ => unreachable!(),
173                    }
174                    .unwrap();
175                });
176            }
177            Comparison::IsNan(op) => {
178                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
179                    b.is_nan(ty, Some(out), input).unwrap();
180                });
181            }
182            Comparison::IsInf(op) => {
183                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
184                    b.is_inf(ty, Some(out), input).unwrap();
185                });
186            }
187        }
188    }
189
190    pub fn compile_operator(&mut self, op: Operator, out: Option<core::Variable>, uniform: bool) {
191        let out = out.unwrap();
192        match op {
193            Operator::Index(op) | Operator::UncheckedIndex(op) => {
194                let is_atomic = op.list.ty.is_atomic();
195                let value = self.compile_variable(op.list);
196                let index = self.compile_variable(op.index);
197                let out = self.compile_variable(out);
198
199                if is_atomic {
200                    let ptr = match self.index(&value, &index, true) {
201                        IndexedVariable::Pointer(ptr, _) => ptr,
202                        _ => unreachable!("Atomic is always pointer"),
203                    };
204                    self.state.atomic_scopes.insert(ptr, value.scope());
205                    let out_id = out.as_binding().unwrap();
206
207                    // This isn't great but atomics can't currently be constructed so should be fine
208                    self.merge_binding(out_id, ptr);
209                } else {
210                    let out_id = self.read_indexed(&out, &value, &index);
211                    self.mark_uniformity(out_id, uniform);
212                    self.write(&out, out_id);
213                }
214            }
215            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
216                let index = self.compile_variable(op.index);
217                let value = self.compile_variable(op.value);
218                let out = self.compile_variable(out);
219                let value_id = self.read_as(&value, &out.indexed_item());
220
221                self.write_indexed(&out, &index, value_id);
222            }
223            Operator::Cast(op) => {
224                let input = self.compile_variable(op.input);
225                let out = self.compile_variable(out);
226                let ty = out.item().id(self);
227                let in_id = self.read(&input);
228                let out_id = self.write_id(&out);
229                self.mark_uniformity(out_id, uniform);
230
231                if let Some(as_const) = input.as_const() {
232                    let cast = self.static_cast(as_const, &input.elem(), &out.item()).0;
233                    self.copy_object(ty, Some(out_id), cast).unwrap();
234                } else {
235                    input.item().cast_to(self, Some(out_id), in_id, &out.item());
236                }
237
238                self.write(&out, out_id);
239            }
240            Operator::And(op) => {
241                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
242                    b.logical_and(ty, Some(out), lhs, rhs).unwrap();
243                });
244            }
245            Operator::Or(op) => {
246                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
247                    b.logical_or(ty, Some(out), lhs, rhs).unwrap();
248                });
249            }
250            Operator::Not(op) => {
251                self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
252                    b.logical_not(ty, Some(out), input).unwrap();
253                });
254            }
255            Operator::Reinterpret(op) => {
256                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
257                    b.bitcast(ty, Some(out), input).unwrap();
258                })
259            }
260            Operator::InitLine(op) => {
261                let values = op
262                    .inputs
263                    .into_iter()
264                    .map(|input| self.compile_variable(input))
265                    .collect::<Vec<_>>()
266                    .into_iter()
267                    .map(|it| self.read(&it))
268                    .collect::<Vec<_>>();
269                let item = self.compile_type(out.ty);
270                let out = self.compile_variable(out);
271                let out_id = self.write_id(&out);
272                self.mark_uniformity(out_id, uniform);
273                let ty = item.id(self);
274                self.composite_construct(ty, Some(out_id), values).unwrap();
275                self.write(&out, out_id);
276            }
277            Operator::CopyMemory(op) => {
278                let input = self.compile_variable(op.input);
279                let in_index = self.compile_variable(op.in_index);
280                let out = self.compile_variable(out);
281                let out_index = self.compile_variable(op.out_index);
282
283                let in_ptr = self.index_ptr(&input, &in_index);
284                let out_ptr = self.index_ptr(&out, &out_index);
285                self.copy_memory(out_ptr, in_ptr, None, None, vec![])
286                    .unwrap();
287            }
288            Operator::CopyMemoryBulk(op) => {
289                self.capabilities.insert(Capability::Addresses);
290                let input = self.compile_variable(op.input);
291                let in_index = self.compile_variable(op.in_index);
292                let out = self.compile_variable(out);
293                let out_index = self.compile_variable(op.out_index);
294                let len = op.len;
295
296                let source = self.index_ptr(&input, &in_index);
297                let target = self.index_ptr(&out, &out_index);
298                let size = self.const_u32(len as u32 * out.item().size());
299                self.copy_memory_sized(target, source, size, None, None, vec![])
300                    .unwrap();
301            }
302            Operator::Select(op) => self.compile_select(op.cond, op.then, op.or_else, out, uniform),
303        }
304    }
305
306    pub fn compile_unary_op_cast(
307        &mut self,
308        op: UnaryOperator,
309        out: core::Variable,
310        uniform: bool,
311        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
312    ) {
313        let input = self.compile_variable(op.input);
314        let out = self.compile_variable(out);
315        let out_ty = out.item();
316
317        let input_id = self.read_as(&input, &out_ty);
318        let out_id = self.write_id(&out);
319        self.mark_uniformity(out_id, uniform);
320
321        let ty = out_ty.id(self);
322
323        exec(self, out_ty, ty, input_id, out_id);
324        self.write(&out, out_id);
325    }
326
327    pub fn compile_unary_op(
328        &mut self,
329        op: UnaryOperator,
330        out: core::Variable,
331        uniform: bool,
332        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
333    ) {
334        let input = self.compile_variable(op.input);
335        let out = self.compile_variable(out);
336        let out_ty = out.item();
337
338        let input_id = self.read(&input);
339        let out_id = self.write_id(&out);
340        self.mark_uniformity(out_id, uniform);
341
342        let ty = out_ty.id(self);
343
344        exec(self, out_ty, ty, input_id, out_id);
345        self.write(&out, out_id);
346    }
347
348    pub fn compile_unary_op_bool(
349        &mut self,
350        op: UnaryOperator,
351        out: core::Variable,
352        uniform: bool,
353        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
354    ) {
355        let input = self.compile_variable(op.input);
356        let out = self.compile_variable(out);
357        let in_ty = input.item();
358
359        let input_id = self.read(&input);
360        let out_id = self.write_id(&out);
361        self.mark_uniformity(out_id, uniform);
362
363        let ty = out.item().id(self);
364
365        exec(self, in_ty, ty, input_id, out_id);
366        self.write(&out, out_id);
367    }
368
369    pub fn compile_binary_op(
370        &mut self,
371        op: BinaryOperator,
372        out: core::Variable,
373        uniform: bool,
374        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
375    ) {
376        let lhs = self.compile_variable(op.lhs);
377        let rhs = self.compile_variable(op.rhs);
378        let out = self.compile_variable(out);
379        let out_ty = out.item();
380
381        let lhs_id = self.read_as(&lhs, &out_ty);
382        let rhs_id = self.read_as(&rhs, &out_ty);
383        let out_id = self.write_id(&out);
384        self.mark_uniformity(out_id, uniform);
385
386        let ty = out_ty.id(self);
387
388        exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
389        self.write(&out, out_id);
390    }
391
392    pub fn compile_binary_op_no_cast(
393        &mut self,
394        op: BinaryOperator,
395        out: core::Variable,
396        uniform: bool,
397        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
398    ) {
399        let lhs = self.compile_variable(op.lhs);
400        let rhs = self.compile_variable(op.rhs);
401        let out = self.compile_variable(out);
402        let out_ty = out.item();
403
404        let lhs_id = self.read(&lhs);
405        let rhs_id = self.read(&rhs);
406        let out_id = self.write_id(&out);
407        self.mark_uniformity(out_id, uniform);
408
409        let ty = out_ty.id(self);
410
411        exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
412        self.write(&out, out_id);
413    }
414
415    pub fn compile_binary_op_bool(
416        &mut self,
417        op: BinaryOperator,
418        out: core::Variable,
419        uniform: bool,
420        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
421    ) {
422        let lhs = self.compile_variable(op.lhs);
423        let rhs = self.compile_variable(op.rhs);
424        let out = self.compile_variable(out);
425
426        let in_ty = out.item().same_vectorization(lhs.elem());
427
428        let lhs_id = self.read_as(&lhs, &in_ty);
429        let rhs_id = self.read_as(&rhs, &in_ty);
430        let out_id = self.write_id(&out);
431        self.mark_uniformity(out_id, uniform);
432
433        let ty = out.item().id(self);
434
435        exec(self, in_ty, ty, lhs_id, rhs_id, out_id);
436        self.write(&out, out_id);
437    }
438
439    pub fn compile_select(
440        &mut self,
441        cond: core::Variable,
442        then: core::Variable,
443        or_else: core::Variable,
444        out: core::Variable,
445        uniform: bool,
446    ) {
447        let cond = self.compile_variable(cond);
448        let then = self.compile_variable(then);
449        let or_else = self.compile_variable(or_else);
450        let out = self.compile_variable(out);
451
452        let out_ty = out.item();
453        let ty = out_ty.id(self);
454
455        let cond_id = self.read(&cond);
456        let then = self.read_as(&then, &out_ty);
457        let or_else = self.read_as(&or_else, &out_ty);
458        let out_id = self.write_id(&out);
459        self.mark_uniformity(out_id, uniform);
460
461        self.select(ty, Some(out_id), cond_id, then, or_else)
462            .unwrap();
463        self.write(&out, out_id);
464    }
465
466    pub fn mark_uniformity(&mut self, id: Word, uniform: bool) {
467        if uniform {
468            self.decorate(id, Decoration::Uniform, []);
469        }
470    }
471}