cubecl_spirv/
instruction.rs

1use cubecl_core::ir::{
2    self as core, BinaryOperator, Comparison, Instruction, InstructionModes, Operation, Operator,
3    UnaryOperator,
4};
5use rspirv::spirv::{Capability, Decoration, Word};
6
7use crate::{
8    SpirvCompiler, SpirvTarget,
9    item::{Elem, Item},
10    variable::IndexedVariable,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14    pub fn compile_operation(&mut self, inst: Instruction) {
15        // Setting source loc for non-semantic ops is pointless, they don't show up in a profiler/debugger.
16        if !matches!(inst.operation, Operation::NonSemantic(_)) {
17            self.set_source_loc(&inst.source_loc);
18        }
19        let uniform = matches!(inst.out, Some(out) if self.uniformity.is_var_uniform(out));
20        match inst.operation {
21            Operation::Copy(var) => {
22                let input = self.compile_variable(var);
23                let out = self.compile_variable(inst.out());
24                let ty = out.item().id(self);
25                let in_id = self.read(&input);
26                let in_id = input.item().broadcast(self, in_id, None, &out.item());
27                let out_id = self.write_id(&out);
28
29                self.copy_object(ty, Some(out_id), in_id).unwrap();
30                self.mark_uniformity(out_id, uniform);
31                self.write(&out, out_id);
32            }
33            Operation::Arithmetic(operator) => {
34                self.compile_arithmetic(operator, inst.out, inst.modes, uniform)
35            }
36            Operation::Comparison(operator) => {
37                self.compile_cmp(operator, inst.out, inst.modes, uniform)
38            }
39            Operation::Bitwise(operator) => self.compile_bitwise(operator, inst.out, uniform),
40            Operation::Operator(operator) => self.compile_operator(operator, inst.out, uniform),
41            Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out, inst.modes),
42            Operation::Branch(_) => unreachable!("Branches shouldn't exist in optimized IR"),
43            Operation::Metadata(meta) => self.compile_meta(meta, inst.out, uniform),
44            Operation::Plane(plane) => self.compile_plane(plane, inst.out, uniform),
45            Operation::Synchronization(sync) => self.compile_sync(sync),
46            Operation::CoopMma(cmma) => self.compile_cmma(cmma, inst.out),
47            Operation::NonSemantic(debug) => self.compile_debug(debug),
48            Operation::Barrier(_) => panic!("Barrier not supported in SPIR-V"),
49            Operation::Tma(_) => panic!("TMA not supported in SPIR-V"),
50            Operation::Marker(_) => {}
51        }
52    }
53
54    pub fn compile_cmp(
55        &mut self,
56        op: Comparison,
57        out: Option<core::Variable>,
58        modes: InstructionModes,
59        uniform: bool,
60    ) {
61        let out = out.unwrap();
62        match op {
63            Comparison::Equal(op) => {
64                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
65                    match lhs_ty.elem() {
66                        Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs),
67                        Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs),
68                        Elem::Float(..) => {
69                            b.declare_math_mode(modes, out);
70                            b.f_ord_equal(ty, Some(out), lhs, rhs)
71                        }
72                        Elem::Relaxed => {
73                            b.decorate(out, Decoration::RelaxedPrecision, []);
74                            b.declare_math_mode(modes, out);
75                            b.f_ord_equal(ty, Some(out), lhs, rhs)
76                        }
77                        Elem::Void => unreachable!(),
78                    }
79                    .unwrap();
80                });
81            }
82            Comparison::NotEqual(op) => {
83                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
84                    match lhs_ty.elem() {
85                        Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs),
86                        Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs),
87                        Elem::Float(..) => {
88                            b.declare_math_mode(modes, out);
89                            b.f_ord_not_equal(ty, Some(out), lhs, rhs)
90                        }
91                        Elem::Relaxed => {
92                            b.decorate(out, Decoration::RelaxedPrecision, []);
93                            b.declare_math_mode(modes, out);
94                            b.f_ord_not_equal(ty, Some(out), lhs, rhs)
95                        }
96                        Elem::Void => unreachable!(),
97                    }
98                    .unwrap();
99                });
100            }
101            Comparison::Lower(op) => {
102                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
103                    match lhs_ty.elem() {
104                        Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs),
105                        Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs),
106                        Elem::Float(..) => {
107                            b.declare_math_mode(modes, out);
108                            b.f_ord_less_than(ty, Some(out), lhs, rhs)
109                        }
110                        Elem::Relaxed => {
111                            b.decorate(out, Decoration::RelaxedPrecision, []);
112                            b.declare_math_mode(modes, out);
113                            b.f_ord_less_than(ty, Some(out), lhs, rhs)
114                        }
115                        _ => unreachable!(),
116                    }
117                    .unwrap();
118                });
119            }
120            Comparison::LowerEqual(op) => {
121                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
122                    match lhs_ty.elem() {
123                        Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs),
124                        Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs),
125                        Elem::Float(..) => {
126                            b.declare_math_mode(modes, out);
127                            b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
128                        }
129                        Elem::Relaxed => {
130                            b.decorate(out, Decoration::RelaxedPrecision, []);
131                            b.declare_math_mode(modes, out);
132                            b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
133                        }
134                        _ => unreachable!(),
135                    }
136                    .unwrap();
137                });
138            }
139            Comparison::Greater(op) => {
140                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
141                    match lhs_ty.elem() {
142                        Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs),
143                        Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs),
144                        Elem::Float(..) => {
145                            b.declare_math_mode(modes, out);
146                            b.f_ord_greater_than(ty, Some(out), lhs, rhs)
147                        }
148                        Elem::Relaxed => {
149                            b.decorate(out, Decoration::RelaxedPrecision, []);
150                            b.declare_math_mode(modes, out);
151                            b.f_ord_greater_than(ty, Some(out), lhs, rhs)
152                        }
153                        _ => unreachable!(),
154                    }
155                    .unwrap();
156                });
157            }
158            Comparison::GreaterEqual(op) => {
159                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
160                    match lhs_ty.elem() {
161                        Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs),
162                        Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs),
163                        Elem::Float(..) => {
164                            b.declare_math_mode(modes, out);
165                            b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
166                        }
167                        Elem::Relaxed => {
168                            b.decorate(out, Decoration::RelaxedPrecision, []);
169                            b.declare_math_mode(modes, out);
170                            b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
171                        }
172                        _ => unreachable!(),
173                    }
174                    .unwrap();
175                });
176            }
177            Comparison::IsNan(op) => {
178                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
179                    b.is_nan(ty, Some(out), input).unwrap();
180                });
181            }
182            Comparison::IsInf(op) => {
183                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
184                    b.is_inf(ty, Some(out), input).unwrap();
185                });
186            }
187        }
188    }
189
190    pub fn compile_operator(&mut self, op: Operator, out: Option<core::Variable>, uniform: bool) {
191        let out = out.unwrap();
192        match op {
193            Operator::Index(op) | Operator::UncheckedIndex(op) => {
194                let is_atomic = op.list.ty.is_atomic();
195                let value = self.compile_variable(op.list);
196                let index = self.compile_variable(op.index);
197                let out = self.compile_variable(out);
198
199                if is_atomic {
200                    let ptr = match self.index(&value, &index, true) {
201                        IndexedVariable::Pointer(ptr, _) => ptr,
202                        _ => unreachable!("Atomic is always pointer"),
203                    };
204                    let out_id = out.as_binding().unwrap();
205
206                    // This isn't great but atomics can't currently be constructed so should be fine
207                    self.merge_binding(out_id, ptr);
208                } else {
209                    let out_id = self.read_indexed(&out, &value, &index);
210                    self.mark_uniformity(out_id, uniform);
211                    self.write(&out, out_id);
212                }
213            }
214            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
215                let index = self.compile_variable(op.index);
216                let value = self.compile_variable(op.value);
217                let out = self.compile_variable(out);
218                let value_id = self.read_as(&value, &out.indexed_item());
219
220                self.write_indexed(&out, &index, value_id);
221            }
222            Operator::Cast(op) => {
223                let input = self.compile_variable(op.input);
224                let out = self.compile_variable(out);
225                let ty = out.item().id(self);
226                let in_id = self.read(&input);
227                let out_id = self.write_id(&out);
228                self.mark_uniformity(out_id, uniform);
229
230                if let Some(as_const) = input.as_const() {
231                    let cast = self.static_cast(as_const, &input.elem(), &out.item());
232                    self.copy_object(ty, Some(out_id), cast).unwrap();
233                } else {
234                    input.item().cast_to(self, Some(out_id), in_id, &out.item());
235                }
236
237                self.write(&out, out_id);
238            }
239            Operator::And(op) => {
240                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
241                    b.logical_and(ty, Some(out), lhs, rhs).unwrap();
242                });
243            }
244            Operator::Or(op) => {
245                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
246                    b.logical_or(ty, Some(out), lhs, rhs).unwrap();
247                });
248            }
249            Operator::Not(op) => {
250                self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
251                    b.logical_not(ty, Some(out), input).unwrap();
252                });
253            }
254            Operator::Reinterpret(op) => {
255                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
256                    b.bitcast(ty, Some(out), input).unwrap();
257                })
258            }
259            Operator::InitLine(op) => {
260                let values = op
261                    .inputs
262                    .into_iter()
263                    .map(|input| self.compile_variable(input))
264                    .collect::<Vec<_>>()
265                    .into_iter()
266                    .map(|it| self.read(&it))
267                    .collect::<Vec<_>>();
268                let item = self.compile_type(out.ty);
269                let out = self.compile_variable(out);
270                let out_id = self.write_id(&out);
271                self.mark_uniformity(out_id, uniform);
272                let ty = item.id(self);
273                self.composite_construct(ty, Some(out_id), values).unwrap();
274                self.write(&out, out_id);
275            }
276            Operator::CopyMemory(op) => {
277                let input = self.compile_variable(op.input);
278                let in_index = self.compile_variable(op.in_index);
279                let out = self.compile_variable(out);
280                let out_index = self.compile_variable(op.out_index);
281
282                let in_ptr = self.index_ptr(&input, &in_index);
283                let out_ptr = self.index_ptr(&out, &out_index);
284                self.copy_memory(out_ptr, in_ptr, None, None, vec![])
285                    .unwrap();
286            }
287            Operator::CopyMemoryBulk(op) => {
288                self.capabilities.insert(Capability::Addresses);
289                let input = self.compile_variable(op.input);
290                let in_index = self.compile_variable(op.in_index);
291                let out = self.compile_variable(out);
292                let out_index = self.compile_variable(op.out_index);
293                let len = op.len;
294
295                let source = self.index_ptr(&input, &in_index);
296                let target = self.index_ptr(&out, &out_index);
297                let size = self.const_u32(len * out.item().size());
298                self.copy_memory_sized(target, source, size, None, None, vec![])
299                    .unwrap();
300            }
301            Operator::Select(op) => self.compile_select(op.cond, op.then, op.or_else, out, uniform),
302        }
303    }
304
305    pub fn compile_unary_op_cast(
306        &mut self,
307        op: UnaryOperator,
308        out: core::Variable,
309        uniform: bool,
310        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
311    ) {
312        let input = self.compile_variable(op.input);
313        let out = self.compile_variable(out);
314        let out_ty = out.item();
315
316        let input_id = self.read_as(&input, &out_ty);
317        let out_id = self.write_id(&out);
318        self.mark_uniformity(out_id, uniform);
319
320        let ty = out_ty.id(self);
321
322        exec(self, out_ty, ty, input_id, out_id);
323        self.write(&out, out_id);
324    }
325
326    pub fn compile_unary_op(
327        &mut self,
328        op: UnaryOperator,
329        out: core::Variable,
330        uniform: bool,
331        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
332    ) {
333        let input = self.compile_variable(op.input);
334        let out = self.compile_variable(out);
335        let out_ty = out.item();
336
337        let input_id = self.read(&input);
338        let out_id = self.write_id(&out);
339        self.mark_uniformity(out_id, uniform);
340
341        let ty = out_ty.id(self);
342
343        exec(self, out_ty, ty, input_id, out_id);
344        self.write(&out, out_id);
345    }
346
347    pub fn compile_unary_op_bool(
348        &mut self,
349        op: UnaryOperator,
350        out: core::Variable,
351        uniform: bool,
352        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
353    ) {
354        let input = self.compile_variable(op.input);
355        let out = self.compile_variable(out);
356        let in_ty = input.item();
357
358        let input_id = self.read(&input);
359        let out_id = self.write_id(&out);
360        self.mark_uniformity(out_id, uniform);
361
362        let ty = out.item().id(self);
363
364        exec(self, in_ty, ty, input_id, out_id);
365        self.write(&out, out_id);
366    }
367
368    pub fn compile_binary_op(
369        &mut self,
370        op: BinaryOperator,
371        out: core::Variable,
372        uniform: bool,
373        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
374    ) {
375        let lhs = self.compile_variable(op.lhs);
376        let rhs = self.compile_variable(op.rhs);
377        let out = self.compile_variable(out);
378        let out_ty = out.item();
379
380        let lhs_id = self.read_as(&lhs, &out_ty);
381        let rhs_id = self.read_as(&rhs, &out_ty);
382        let out_id = self.write_id(&out);
383        self.mark_uniformity(out_id, uniform);
384
385        let ty = out_ty.id(self);
386
387        exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
388        self.write(&out, out_id);
389    }
390
391    pub fn compile_binary_op_no_cast(
392        &mut self,
393        op: BinaryOperator,
394        out: core::Variable,
395        uniform: bool,
396        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
397    ) {
398        let lhs = self.compile_variable(op.lhs);
399        let rhs = self.compile_variable(op.rhs);
400        let out = self.compile_variable(out);
401        let out_ty = out.item();
402
403        let lhs_id = self.read(&lhs);
404        let rhs_id = self.read(&rhs);
405        let out_id = self.write_id(&out);
406        self.mark_uniformity(out_id, uniform);
407
408        let ty = out_ty.id(self);
409
410        exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
411        self.write(&out, out_id);
412    }
413
414    pub fn compile_binary_op_bool(
415        &mut self,
416        op: BinaryOperator,
417        out: core::Variable,
418        uniform: bool,
419        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
420    ) {
421        let lhs = self.compile_variable(op.lhs);
422        let rhs = self.compile_variable(op.rhs);
423        let out = self.compile_variable(out);
424
425        let in_ty = out.item().same_vectorization(lhs.elem());
426
427        let lhs_id = self.read_as(&lhs, &in_ty);
428        let rhs_id = self.read_as(&rhs, &in_ty);
429        let out_id = self.write_id(&out);
430        self.mark_uniformity(out_id, uniform);
431
432        let ty = out.item().id(self);
433
434        exec(self, in_ty, ty, lhs_id, rhs_id, out_id);
435        self.write(&out, out_id);
436    }
437
438    pub fn compile_select(
439        &mut self,
440        cond: core::Variable,
441        then: core::Variable,
442        or_else: core::Variable,
443        out: core::Variable,
444        uniform: bool,
445    ) {
446        let cond = self.compile_variable(cond);
447        let then = self.compile_variable(then);
448        let or_else = self.compile_variable(or_else);
449        let out = self.compile_variable(out);
450
451        let out_ty = out.item();
452        let ty = out_ty.id(self);
453
454        let cond_id = self.read(&cond);
455        let then = self.read_as(&then, &out_ty);
456        let or_else = self.read_as(&or_else, &out_ty);
457        let out_id = self.write_id(&out);
458        self.mark_uniformity(out_id, uniform);
459
460        self.select(ty, Some(out_id), cond_id, then, or_else)
461            .unwrap();
462        self.write(&out, out_id);
463    }
464
465    pub fn mark_uniformity(&mut self, id: Word, uniform: bool) {
466        if uniform {
467            self.decorate(id, Decoration::Uniform, []);
468        }
469    }
470}