cubecl_spirv/
instruction.rs

1use cubecl_core::ir::{
2    self as core, BinaryOperator, Comparison, Instruction, Operation, Operator, UnaryOperator,
3};
4use rspirv::spirv::{Capability, Decoration, Word};
5
6use crate::{
7    SpirvCompiler, SpirvTarget,
8    item::{Elem, Item},
9    variable::IndexedVariable,
10};
11
12impl<T: SpirvTarget> SpirvCompiler<T> {
13    pub fn compile_operation(&mut self, inst: Instruction) {
14        // Setting source loc for non-semantic ops is pointless, they don't show up in a profiler/debugger.
15        if !matches!(inst.operation, Operation::NonSemantic(_)) {
16            self.set_source_loc(&inst.source_loc);
17        }
18        let uniform = matches!(inst.out, Some(out) if self.uniformity.is_var_uniform(out));
19        match inst.operation {
20            Operation::Copy(var) => {
21                let input = self.compile_variable(var);
22                let out = self.compile_variable(inst.out());
23                let ty = out.item().id(self);
24                let in_id = self.read(&input);
25                let in_id = input.item().broadcast(self, in_id, None, &out.item());
26                let out_id = self.write_id(&out);
27
28                self.copy_object(ty, Some(out_id), in_id).unwrap();
29                self.mark_uniformity(out_id, uniform);
30                self.write(&out, out_id);
31            }
32            Operation::Arithmetic(operator) => self.compile_arithmetic(operator, inst.out, uniform),
33            Operation::Comparison(operator) => self.compile_cmp(operator, inst.out, uniform),
34            Operation::Bitwise(operator) => self.compile_bitwise(operator, inst.out, uniform),
35            Operation::Operator(operator) => self.compile_operator(operator, inst.out, uniform),
36            Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out),
37            Operation::Branch(_) => unreachable!("Branches shouldn't exist in optimized IR"),
38            Operation::Metadata(meta) => self.compile_meta(meta, inst.out, uniform),
39            Operation::Plane(plane) => self.compile_plane(plane, inst.out, uniform),
40            Operation::Synchronization(sync) => self.compile_sync(sync),
41            Operation::CoopMma(cmma) => self.compile_cmma(cmma, inst.out),
42            Operation::NonSemantic(debug) => self.compile_debug(debug),
43            Operation::Barrier(_) => panic!("Barrier not supported in SPIR-V"),
44            Operation::Tma(_) => panic!("TMA not supported in SPIR-V"),
45            Operation::Free(_) => {}
46        }
47    }
48
49    pub fn compile_cmp(&mut self, op: Comparison, out: Option<core::Variable>, uniform: bool) {
50        let out = out.unwrap();
51        match op {
52            Comparison::Equal(op) => {
53                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
54                    match lhs_ty.elem() {
55                        Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs),
56                        Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs),
57                        Elem::Float(..) => b.f_ord_equal(ty, Some(out), lhs, rhs),
58                        Elem::Relaxed => {
59                            b.decorate(out, Decoration::RelaxedPrecision, []);
60                            b.f_ord_equal(ty, Some(out), lhs, rhs)
61                        }
62                        Elem::Void => unreachable!(),
63                    }
64                    .unwrap();
65                });
66            }
67            Comparison::NotEqual(op) => {
68                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
69                    match lhs_ty.elem() {
70                        Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs),
71                        Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs),
72                        Elem::Float(..) => b.f_ord_not_equal(ty, Some(out), lhs, rhs),
73                        Elem::Relaxed => {
74                            b.decorate(out, Decoration::RelaxedPrecision, []);
75                            b.f_ord_not_equal(ty, Some(out), lhs, rhs)
76                        }
77                        Elem::Void => unreachable!(),
78                    }
79                    .unwrap();
80                });
81            }
82            Comparison::Lower(op) => {
83                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
84                    match lhs_ty.elem() {
85                        Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs),
86                        Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs),
87                        Elem::Float(..) => b.f_ord_less_than(ty, Some(out), lhs, rhs),
88                        Elem::Relaxed => {
89                            b.decorate(out, Decoration::RelaxedPrecision, []);
90                            b.f_ord_less_than(ty, Some(out), lhs, rhs)
91                        }
92                        _ => unreachable!(),
93                    }
94                    .unwrap();
95                });
96            }
97            Comparison::LowerEqual(op) => {
98                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
99                    match lhs_ty.elem() {
100                        Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs),
101                        Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs),
102                        Elem::Float(..) => b.f_ord_less_than_equal(ty, Some(out), lhs, rhs),
103                        Elem::Relaxed => {
104                            b.decorate(out, Decoration::RelaxedPrecision, []);
105                            b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
106                        }
107                        _ => unreachable!(),
108                    }
109                    .unwrap();
110                });
111            }
112            Comparison::Greater(op) => {
113                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
114                    match lhs_ty.elem() {
115                        Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs),
116                        Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs),
117                        Elem::Float(..) => b.f_ord_greater_than(ty, Some(out), lhs, rhs),
118                        Elem::Relaxed => {
119                            b.decorate(out, Decoration::RelaxedPrecision, []);
120                            b.f_ord_greater_than(ty, Some(out), lhs, rhs)
121                        }
122                        _ => unreachable!(),
123                    }
124                    .unwrap();
125                });
126            }
127            Comparison::GreaterEqual(op) => {
128                self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
129                    match lhs_ty.elem() {
130                        Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs),
131                        Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs),
132                        Elem::Float(..) => b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs),
133                        Elem::Relaxed => {
134                            b.decorate(out, Decoration::RelaxedPrecision, []);
135                            b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
136                        }
137                        _ => unreachable!(),
138                    }
139                    .unwrap();
140                });
141            }
142            Comparison::IsNan(op) => {
143                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
144                    b.is_nan(ty, Some(out), input).unwrap();
145                });
146            }
147            Comparison::IsInf(op) => {
148                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
149                    b.is_inf(ty, Some(out), input).unwrap();
150                });
151            }
152        }
153    }
154
155    pub fn compile_operator(&mut self, op: Operator, out: Option<core::Variable>, uniform: bool) {
156        let out = out.unwrap();
157        match op {
158            Operator::Index(op) | Operator::UncheckedIndex(op) => {
159                let is_atomic = op.list.ty.is_atomic();
160                let value = self.compile_variable(op.list);
161                let index = self.compile_variable(op.index);
162                let out = self.compile_variable(out);
163
164                if is_atomic {
165                    let ptr = match self.index(&value, &index, true) {
166                        IndexedVariable::Pointer(ptr, _) => ptr,
167                        _ => unreachable!("Atomic is always pointer"),
168                    };
169                    let out_id = out.as_binding().unwrap();
170
171                    // This isn't great but atomics can't currently be constructed so should be fine
172                    self.merge_binding(out_id, ptr);
173                } else {
174                    let out_id = self.read_indexed(&out, &value, &index);
175                    self.mark_uniformity(out_id, uniform);
176                    self.write(&out, out_id);
177                }
178            }
179            Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
180                let index = self.compile_variable(op.index);
181                let value = self.compile_variable(op.value);
182                let out = self.compile_variable(out);
183                let value_id = self.read_as(&value, &out.indexed_item());
184
185                self.write_indexed(&out, &index, value_id);
186            }
187            Operator::Cast(op) => {
188                let input = self.compile_variable(op.input);
189                let out = self.compile_variable(out);
190                let ty = out.item().id(self);
191                let in_id = self.read(&input);
192                let out_id = self.write_id(&out);
193                self.mark_uniformity(out_id, uniform);
194
195                if let Some(as_const) = input.as_const() {
196                    let cast = self.static_cast(as_const, &input.elem(), &out.item());
197                    self.copy_object(ty, Some(out_id), cast).unwrap();
198                } else {
199                    input.item().cast_to(self, Some(out_id), in_id, &out.item());
200                }
201
202                self.write(&out, out_id);
203            }
204            Operator::And(op) => {
205                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
206                    b.logical_and(ty, Some(out), lhs, rhs).unwrap();
207                });
208            }
209            Operator::Or(op) => {
210                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
211                    b.logical_or(ty, Some(out), lhs, rhs).unwrap();
212                });
213            }
214            Operator::Not(op) => {
215                self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
216                    b.logical_not(ty, Some(out), input).unwrap();
217                });
218            }
219            Operator::Reinterpret(op) => {
220                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
221                    b.bitcast(ty, Some(out), input).unwrap();
222                })
223            }
224            Operator::InitLine(op) => {
225                let values = op
226                    .inputs
227                    .into_iter()
228                    .map(|input| self.compile_variable(input))
229                    .collect::<Vec<_>>()
230                    .into_iter()
231                    .map(|it| self.read(&it))
232                    .collect::<Vec<_>>();
233                let item = self.compile_type(out.ty);
234                let out = self.compile_variable(out);
235                let out_id = self.write_id(&out);
236                self.mark_uniformity(out_id, uniform);
237                let ty = item.id(self);
238                self.composite_construct(ty, Some(out_id), values).unwrap();
239                self.write(&out, out_id);
240            }
241            Operator::CopyMemory(op) => {
242                let input = self.compile_variable(op.input);
243                let in_index = self.compile_variable(op.in_index);
244                let out = self.compile_variable(out);
245                let out_index = self.compile_variable(op.out_index);
246
247                let in_ptr = self.index_ptr(&input, &in_index);
248                let out_ptr = self.index_ptr(&out, &out_index);
249                self.copy_memory(out_ptr, in_ptr, None, None, vec![])
250                    .unwrap();
251            }
252            Operator::CopyMemoryBulk(op) => {
253                self.capabilities.insert(Capability::Addresses);
254                let input = self.compile_variable(op.input);
255                let in_index = self.compile_variable(op.in_index);
256                let out = self.compile_variable(out);
257                let out_index = self.compile_variable(op.out_index);
258                let len = op.len;
259
260                let source = self.index_ptr(&input, &in_index);
261                let target = self.index_ptr(&out, &out_index);
262                let size = self.const_u32(len * out.item().size());
263                self.copy_memory_sized(target, source, size, None, None, vec![])
264                    .unwrap();
265            }
266            Operator::Select(op) => self.compile_select(op.cond, op.then, op.or_else, out, uniform),
267        }
268    }
269
270    pub fn compile_unary_op_cast(
271        &mut self,
272        op: UnaryOperator,
273        out: core::Variable,
274        uniform: bool,
275        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
276    ) {
277        let input = self.compile_variable(op.input);
278        let out = self.compile_variable(out);
279        let out_ty = out.item();
280
281        let input_id = self.read_as(&input, &out_ty);
282        let out_id = self.write_id(&out);
283        self.mark_uniformity(out_id, uniform);
284
285        let ty = out_ty.id(self);
286
287        exec(self, out_ty, ty, input_id, out_id);
288        self.write(&out, out_id);
289    }
290
291    pub fn compile_unary_op(
292        &mut self,
293        op: UnaryOperator,
294        out: core::Variable,
295        uniform: bool,
296        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
297    ) {
298        let input = self.compile_variable(op.input);
299        let out = self.compile_variable(out);
300        let out_ty = out.item();
301
302        let input_id = self.read(&input);
303        let out_id = self.write_id(&out);
304        self.mark_uniformity(out_id, uniform);
305
306        let ty = out_ty.id(self);
307
308        exec(self, out_ty, ty, input_id, out_id);
309        self.write(&out, out_id);
310    }
311
312    pub fn compile_unary_op_bool(
313        &mut self,
314        op: UnaryOperator,
315        out: core::Variable,
316        uniform: bool,
317        exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
318    ) {
319        let input = self.compile_variable(op.input);
320        let out = self.compile_variable(out);
321        let in_ty = input.item();
322
323        let input_id = self.read(&input);
324        let out_id = self.write_id(&out);
325        self.mark_uniformity(out_id, uniform);
326
327        let ty = out.item().id(self);
328
329        exec(self, in_ty, ty, input_id, out_id);
330        self.write(&out, out_id);
331    }
332
333    pub fn compile_binary_op(
334        &mut self,
335        op: BinaryOperator,
336        out: core::Variable,
337        uniform: bool,
338        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
339    ) {
340        let lhs = self.compile_variable(op.lhs);
341        let rhs = self.compile_variable(op.rhs);
342        let out = self.compile_variable(out);
343        let out_ty = out.item();
344
345        let lhs_id = self.read_as(&lhs, &out_ty);
346        let rhs_id = self.read_as(&rhs, &out_ty);
347        let out_id = self.write_id(&out);
348        self.mark_uniformity(out_id, uniform);
349
350        let ty = out_ty.id(self);
351
352        exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
353        self.write(&out, out_id);
354    }
355
356    pub fn compile_binary_op_no_cast(
357        &mut self,
358        op: BinaryOperator,
359        out: core::Variable,
360        uniform: bool,
361        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
362    ) {
363        let lhs = self.compile_variable(op.lhs);
364        let rhs = self.compile_variable(op.rhs);
365        let out = self.compile_variable(out);
366        let out_ty = out.item();
367
368        let lhs_id = self.read(&lhs);
369        let rhs_id = self.read(&rhs);
370        let out_id = self.write_id(&out);
371        self.mark_uniformity(out_id, uniform);
372
373        let ty = out_ty.id(self);
374
375        exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
376        self.write(&out, out_id);
377    }
378
379    pub fn compile_binary_op_bool(
380        &mut self,
381        op: BinaryOperator,
382        out: core::Variable,
383        uniform: bool,
384        exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
385    ) {
386        let lhs = self.compile_variable(op.lhs);
387        let rhs = self.compile_variable(op.rhs);
388        let out = self.compile_variable(out);
389
390        let in_ty = out.item().same_vectorization(lhs.elem());
391
392        let lhs_id = self.read_as(&lhs, &in_ty);
393        let rhs_id = self.read_as(&rhs, &in_ty);
394        let out_id = self.write_id(&out);
395        self.mark_uniformity(out_id, uniform);
396
397        let ty = out.item().id(self);
398
399        exec(self, in_ty, ty, lhs_id, rhs_id, out_id);
400        self.write(&out, out_id);
401    }
402
403    pub fn compile_select(
404        &mut self,
405        cond: core::Variable,
406        then: core::Variable,
407        or_else: core::Variable,
408        out: core::Variable,
409        uniform: bool,
410    ) {
411        let cond = self.compile_variable(cond);
412        let then = self.compile_variable(then);
413        let or_else = self.compile_variable(or_else);
414        let out = self.compile_variable(out);
415
416        let out_ty = out.item();
417        let ty = out_ty.id(self);
418
419        let cond_id = self.read(&cond);
420        let then = self.read_as(&then, &out_ty);
421        let or_else = self.read_as(&or_else, &out_ty);
422        let out_id = self.write_id(&out);
423        self.mark_uniformity(out_id, uniform);
424
425        self.select(ty, Some(out_id), cond_id, then, or_else)
426            .unwrap();
427        self.write(&out, out_id);
428    }
429
430    pub fn mark_uniformity(&mut self, id: Word, uniform: bool) {
431        if uniform {
432            self.decorate(id, Decoration::Uniform, []);
433        }
434    }
435}