cubecl_spirv/
arithmetic.rs

1use crate::{
2    SpirvCompiler, SpirvTarget,
3    item::{Elem, Item},
4    variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10    pub fn compile_arithmetic(
11        &mut self,
12        op: Arithmetic,
13        out: Option<core::Variable>,
14        modes: InstructionModes,
15        uniform: bool,
16    ) {
17        let out = out.unwrap();
18        match op {
19            Arithmetic::Add(op) => {
20                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21                    match out_ty.elem() {
22                        Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23                        Elem::Float(..) => {
24                            b.declare_math_mode(modes, out);
25                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
26                        }
27                        Elem::Relaxed => {
28                            b.decorate(out, Decoration::RelaxedPrecision, []);
29                            b.declare_math_mode(modes, out);
30                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
31                        }
32                        _ => unreachable!(),
33                    };
34                });
35            }
36            Arithmetic::SaturatingAdd(_) => {
37                unimplemented!("Should be replaced by polyfill");
38            }
39            Arithmetic::Sub(op) => {
40                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41                    match out_ty.elem() {
42                        Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43                        Elem::Float(..) => {
44                            b.declare_math_mode(modes, out);
45                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46                        }
47                        Elem::Relaxed => {
48                            b.decorate(out, Decoration::RelaxedPrecision, []);
49                            b.declare_math_mode(modes, out);
50                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51                        }
52                        _ => unreachable!(),
53                    };
54                });
55            }
56            Arithmetic::SaturatingSub(_) => {
57                unimplemented!("Should be replaced by polyfill");
58            }
59            Arithmetic::Mul(op) => {
60                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61                    match out_ty.elem() {
62                        Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63                        Elem::Float(..) => {
64                            b.declare_math_mode(modes, out);
65                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66                        }
67                        Elem::Relaxed => {
68                            b.decorate(out, Decoration::RelaxedPrecision, []);
69                            b.declare_math_mode(modes, out);
70                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71                        }
72                        _ => unreachable!(),
73                    };
74                });
75            }
76            Arithmetic::MulHi(op) => {
77                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78                    let out_st = b.type_struct([ty, ty]);
79                    let extended = match out_ty.elem() {
80                        Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81                        Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82                        _ => unreachable!(),
83                    };
84                    b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85                });
86            }
87            Arithmetic::Div(op) => {
88                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89                    match out_ty.elem() {
90                        Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91                        Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92                        Elem::Float(..) => {
93                            b.declare_math_mode(modes, out);
94                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
95                        }
96                        Elem::Relaxed => {
97                            b.decorate(out, Decoration::RelaxedPrecision, []);
98                            b.declare_math_mode(modes, out);
99                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
100                        }
101                        _ => unreachable!(),
102                    };
103                });
104            }
105            Arithmetic::Remainder(op) => {
106                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107                    match out_ty.elem() {
108                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109                        Elem::Int(_, true) => b.s_mod(ty, Some(out), lhs, rhs).unwrap(),
110                        Elem::Float(..) => {
111                            b.declare_math_mode(modes, out);
112                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
113                        }
114                        Elem::Relaxed => {
115                            b.decorate(out, Decoration::RelaxedPrecision, []);
116                            b.declare_math_mode(modes, out);
117                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
118                        }
119                        _ => unreachable!(),
120                    };
121                });
122            }
123            Arithmetic::Modulo(op) => {
124                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
125                    match out_ty.elem() {
126                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
127                        Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
128                        Elem::Float(..) => {
129                            b.declare_math_mode(modes, out);
130                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
131                        }
132                        Elem::Relaxed => {
133                            b.decorate(out, Decoration::RelaxedPrecision, []);
134                            b.declare_math_mode(modes, out);
135                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
136                        }
137                        _ => unreachable!(),
138                    };
139                });
140            }
141            Arithmetic::Dot(op) => {
142                if op.lhs.ty.line_size() == 1 {
143                    self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
144                        match out_ty.elem() {
145                            Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
146                            Elem::Float(..) => {
147                                b.declare_math_mode(modes, out);
148                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
149                            }
150                            Elem::Relaxed => {
151                                b.decorate(out, Decoration::RelaxedPrecision, []);
152                                b.declare_math_mode(modes, out);
153                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
154                            }
155                            _ => unreachable!(),
156                        };
157                    });
158                } else {
159                    let lhs = self.compile_variable(op.lhs);
160                    let rhs = self.compile_variable(op.rhs);
161                    let out = self.compile_variable(out);
162                    let ty = out.item().id(self);
163
164                    let lhs_id = self.read(&lhs);
165                    let rhs_id = self.read(&rhs);
166                    let out_id = self.write_id(&out);
167                    self.mark_uniformity(out_id, uniform);
168
169                    if matches!(lhs.elem(), Elem::Int(_, _)) {
170                        self.capabilities.insert(Capability::DotProduct);
171                    }
172                    if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
173                        self.capabilities.insert(Capability::BFloat16DotProductKHR);
174                    }
175
176                    match (lhs.elem(), rhs.elem()) {
177                        (Elem::Int(_, false), Elem::Int(_, false)) => {
178                            self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
179                        }
180                        (Elem::Int(_, true), Elem::Int(_, false)) => {
181                            self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
182                        }
183                        (Elem::Int(_, false), Elem::Int(_, true)) => {
184                            self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
185                        }
186                        (Elem::Int(_, true), Elem::Int(_, true)) => {
187                            self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
188                        }
189                        (Elem::Float(..), Elem::Float(..))
190                        | (Elem::Relaxed, Elem::Float(..))
191                        | (Elem::Float(..), Elem::Relaxed) => {
192                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
193                        }
194                        (Elem::Relaxed, Elem::Relaxed) => {
195                            self.decorate(out_id, Decoration::RelaxedPrecision, []);
196                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
197                        }
198                        _ => unreachable!(),
199                    }
200                    .unwrap();
201                    self.write(&out, out_id);
202                }
203            }
204            Arithmetic::Fma(op) => {
205                let a = self.compile_variable(op.a);
206                let b = self.compile_variable(op.b);
207                let c = self.compile_variable(op.c);
208                let out = self.compile_variable(out);
209                let out_ty = out.item();
210                let relaxed = matches!(
211                    (a.item().elem(), b.item().elem(), c.item().elem()),
212                    (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
213                );
214
215                let a_id = self.read_as(&a, &out_ty);
216                let b_id = self.read_as(&b, &out_ty);
217                let c_id = self.read_as(&c, &out_ty);
218                let out_id = self.write_id(&out);
219                self.mark_uniformity(out_id, uniform);
220
221                let ty = out_ty.id(self);
222
223                let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
224                self.mark_uniformity(mul, uniform);
225                self.declare_math_mode(modes, mul);
226                self.f_add(ty, Some(out_id), mul, c_id).unwrap();
227                self.declare_math_mode(modes, out_id);
228                if relaxed {
229                    self.decorate(mul, Decoration::RelaxedPrecision, []);
230                    self.decorate(out_id, Decoration::RelaxedPrecision, []);
231                }
232                self.write(&out, out_id);
233            }
234            Arithmetic::Recip(op) => {
235                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
236                    let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
237                    b.declare_math_mode(modes, out);
238                    b.f_div(ty, Some(out), one, input).unwrap();
239                });
240            }
241            Arithmetic::Neg(op) => {
242                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
243                    match out_ty.elem() {
244                        Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
245                        Elem::Float(..) => {
246                            b.declare_math_mode(modes, out);
247                            b.f_negate(ty, Some(out), input).unwrap()
248                        }
249                        Elem::Relaxed => {
250                            b.decorate(out, Decoration::RelaxedPrecision, []);
251                            b.declare_math_mode(modes, out);
252                            b.f_negate(ty, Some(out), input).unwrap()
253                        }
254                        _ => unreachable!(),
255                    };
256                });
257            }
258            Arithmetic::Erf(_) => {
259                unreachable!("Replaced by transformer")
260            }
261
262            // Extension functions
263            Arithmetic::Normalize(op) => {
264                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
265                    b.declare_math_mode(modes, out);
266                    T::normalize(b, ty, input, out);
267                    if matches!(out_ty.elem(), Elem::Relaxed) {
268                        b.decorate(out, Decoration::RelaxedPrecision, []);
269                    }
270                });
271            }
272            Arithmetic::Magnitude(op) => {
273                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
274                    b.declare_math_mode(modes, out);
275                    T::magnitude(b, ty, input, out);
276                    if matches!(out_ty.elem(), Elem::Relaxed) {
277                        b.decorate(out, Decoration::RelaxedPrecision, []);
278                    }
279                });
280            }
281            Arithmetic::Abs(op) => {
282                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
283                    match out_ty.elem() {
284                        Elem::Int(_, _) => T::s_abs(b, ty, input, out),
285                        Elem::Float(..) => {
286                            b.declare_math_mode(modes, out);
287                            T::f_abs(b, ty, input, out)
288                        }
289                        Elem::Relaxed => {
290                            b.decorate(out, Decoration::RelaxedPrecision, []);
291                            b.declare_math_mode(modes, out);
292                            T::f_abs(b, ty, input, out)
293                        }
294                        _ => unreachable!(),
295                    }
296                });
297            }
298            Arithmetic::Exp(op) => {
299                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
300                    b.declare_math_mode(modes, out);
301                    T::exp(b, ty, input, out);
302                    if matches!(out_ty.elem(), Elem::Relaxed) {
303                        b.decorate(out, Decoration::RelaxedPrecision, []);
304                    }
305                });
306            }
307            Arithmetic::Log(op) => {
308                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
309                    b.declare_math_mode(modes, out);
310                    T::log(b, ty, input, out);
311                    if matches!(out_ty.elem(), Elem::Relaxed) {
312                        b.decorate(out, Decoration::RelaxedPrecision, []);
313                    }
314                })
315            }
316            Arithmetic::Log1p(op) => {
317                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
318                    let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
319                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
320                    let add = match out_ty.elem() {
321                        Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
322                        Elem::Float(..) | Elem::Relaxed => {
323                            b.declare_math_mode(modes, out);
324                            b.f_add(ty, None, input, one).unwrap()
325                        }
326                        _ => unreachable!(),
327                    };
328                    b.mark_uniformity(add, uniform);
329                    if relaxed {
330                        b.decorate(add, Decoration::RelaxedPrecision, []);
331                        b.decorate(out, Decoration::RelaxedPrecision, []);
332                    }
333                    b.declare_math_mode(modes, out);
334                    T::log(b, ty, add, out)
335                });
336            }
337            Arithmetic::Cos(op) => {
338                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
339                    b.declare_math_mode(modes, out);
340                    T::cos(b, ty, input, out);
341                    if matches!(out_ty.elem(), Elem::Relaxed) {
342                        b.decorate(out, Decoration::RelaxedPrecision, []);
343                    }
344                })
345            }
346            Arithmetic::Sin(op) => {
347                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
348                    b.declare_math_mode(modes, out);
349                    T::sin(b, ty, input, out);
350                    if matches!(out_ty.elem(), Elem::Relaxed) {
351                        b.decorate(out, Decoration::RelaxedPrecision, []);
352                    }
353                })
354            }
355            Arithmetic::Tanh(op) => {
356                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
357                    b.declare_math_mode(modes, out);
358                    T::tanh(b, ty, input, out);
359                    if matches!(out_ty.elem(), Elem::Relaxed) {
360                        b.decorate(out, Decoration::RelaxedPrecision, []);
361                    }
362                })
363            }
364            // No powi for Vulkan, just auto-cast to float
365            Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
366                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
367                    let bool = match out_ty {
368                        Item::Scalar(_) => Elem::Bool.id(b),
369                        Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
370                        _ => unreachable!(),
371                    };
372                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
373                    let zero = out_ty.const_u32(b, 0);
374                    let one = out_ty.const_u32(b, 1);
375                    let two = out_ty.const_u32(b, 2);
376                    let modulo = b.f_rem(ty, None, rhs, two).unwrap();
377                    b.declare_math_mode(modes, modulo);
378                    let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
379                    b.declare_math_mode(modes, is_zero);
380                    let abs = b.id();
381                    b.declare_math_mode(modes, abs);
382                    T::f_abs(b, ty, lhs, abs);
383                    let even = b.id();
384                    b.declare_math_mode(modes, even);
385                    T::pow(b, ty, abs, rhs, even);
386                    let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
387                    b.declare_math_mode(modes, cond2_0);
388                    let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
389                    b.declare_math_mode(modes, cond2_1);
390                    let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
391                    let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
392                    b.declare_math_mode(modes, neg_lhs);
393                    let pow2 = b.id();
394                    b.declare_math_mode(modes, pow2);
395                    T::pow(b, ty, neg_lhs, rhs, pow2);
396                    let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
397                    b.declare_math_mode(modes, pow2_neg);
398                    let default = b.id();
399                    b.declare_math_mode(modes, default);
400                    T::pow(b, ty, lhs, rhs, default);
401                    let ids = [
402                        modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
403                        default,
404                    ];
405                    for id in ids {
406                        b.mark_uniformity(id, uniform);
407                        if relaxed {
408                            b.decorate(id, Decoration::RelaxedPrecision, []);
409                        }
410                    }
411                    let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
412                    b.mark_uniformity(sel1, uniform);
413                    b.select(ty, Some(out), is_zero, even, sel1).unwrap();
414                })
415            }
416            Arithmetic::Sqrt(op) => {
417                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
418                    b.declare_math_mode(modes, out);
419                    T::sqrt(b, ty, input, out);
420                    if matches!(out_ty.elem(), Elem::Relaxed) {
421                        b.decorate(out, Decoration::RelaxedPrecision, []);
422                    }
423                })
424            }
425            Arithmetic::InverseSqrt(op) => {
426                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
427                    b.declare_math_mode(modes, out);
428                    T::inverse_sqrt(b, ty, input, out);
429                    if matches!(out_ty.elem(), Elem::Relaxed) {
430                        b.decorate(out, Decoration::RelaxedPrecision, []);
431                    }
432                })
433            }
434            Arithmetic::Round(op) => {
435                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
436                    T::round(b, ty, input, out);
437                    if matches!(out_ty.elem(), Elem::Relaxed) {
438                        b.decorate(out, Decoration::RelaxedPrecision, []);
439                    }
440                })
441            }
442            Arithmetic::Floor(op) => {
443                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
444                    b.declare_math_mode(modes, out);
445                    T::floor(b, ty, input, out);
446                    if matches!(out_ty.elem(), Elem::Relaxed) {
447                        b.decorate(out, Decoration::RelaxedPrecision, []);
448                    }
449                })
450            }
451            Arithmetic::Ceil(op) => {
452                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
453                    b.declare_math_mode(modes, out);
454                    T::ceil(b, ty, input, out);
455                    if matches!(out_ty.elem(), Elem::Relaxed) {
456                        b.decorate(out, Decoration::RelaxedPrecision, []);
457                    }
458                })
459            }
460            Arithmetic::Trunc(op) => {
461                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
462                    b.declare_math_mode(modes, out);
463                    T::trunc(b, ty, input, out);
464                    if matches!(out_ty.elem(), Elem::Relaxed) {
465                        b.decorate(out, Decoration::RelaxedPrecision, []);
466                    }
467                })
468            }
469            Arithmetic::Clamp(op) => {
470                let input = self.compile_variable(op.input);
471                let min = self.compile_variable(op.min_value);
472                let max = self.compile_variable(op.max_value);
473                let out = self.compile_variable(out);
474                let out_ty = out.item();
475
476                let input = self.read_as(&input, &out_ty);
477                let min = self.read_as(&min, &out_ty);
478                let max = self.read_as(&max, &out_ty);
479                let out_id = self.write_id(&out);
480                self.mark_uniformity(out_id, uniform);
481
482                let ty = out_ty.id(self);
483
484                match out_ty.elem() {
485                    Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
486                    Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
487                    Elem::Float(..) => {
488                        self.declare_math_mode(modes, out_id);
489                        T::f_clamp(self, ty, input, min, max, out_id)
490                    }
491                    Elem::Relaxed => {
492                        self.decorate(out_id, Decoration::RelaxedPrecision, []);
493                        self.declare_math_mode(modes, out_id);
494                        T::f_clamp(self, ty, input, min, max, out_id)
495                    }
496                    _ => unreachable!(),
497                }
498                self.write(&out, out_id);
499            }
500
501            Arithmetic::Max(op) => self.compile_binary_op(
502                op,
503                out,
504                uniform,
505                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
506                    Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
507                    Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
508                    Elem::Float(..) => {
509                        b.declare_math_mode(modes, out);
510                        T::f_max(b, ty, lhs, rhs, out)
511                    }
512                    Elem::Relaxed => {
513                        b.decorate(out, Decoration::RelaxedPrecision, []);
514                        b.declare_math_mode(modes, out);
515                        T::f_max(b, ty, lhs, rhs, out)
516                    }
517                    _ => unreachable!(),
518                },
519            ),
520            Arithmetic::Min(op) => self.compile_binary_op(
521                op,
522                out,
523                uniform,
524                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
525                    Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
526                    Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
527                    Elem::Float(..) => {
528                        b.declare_math_mode(modes, out);
529                        T::f_min(b, ty, lhs, rhs, out)
530                    }
531                    Elem::Relaxed => {
532                        b.decorate(out, Decoration::RelaxedPrecision, []);
533                        b.declare_math_mode(modes, out);
534                        T::f_min(b, ty, lhs, rhs, out)
535                    }
536                    _ => unreachable!(),
537                },
538            ),
539        }
540    }
541}