cubecl_spirv/
arithmetic.rs

1use crate::{
2    SpirvCompiler, SpirvTarget,
3    item::{Elem, Item},
4    variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10    pub fn compile_arithmetic(
11        &mut self,
12        op: Arithmetic,
13        out: Option<core::Variable>,
14        uniform: bool,
15    ) {
16        let out = out.unwrap();
17        match op {
18            Arithmetic::Add(op) => {
19                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
20                    match out_ty.elem() {
21                        Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
22                        Elem::Float(..) => b.f_add(ty, Some(out), lhs, rhs).unwrap(),
23                        Elem::Relaxed => {
24                            b.decorate(out, Decoration::RelaxedPrecision, []);
25                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
26                        }
27                        _ => unreachable!(),
28                    };
29                });
30            }
31            Arithmetic::SaturatingAdd(_) => {
32                unimplemented!("Should be replaced by polyfill");
33            }
34            Arithmetic::Sub(op) => {
35                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
36                    match out_ty.elem() {
37                        Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
38                        Elem::Float(..) => b.f_sub(ty, Some(out), lhs, rhs).unwrap(),
39                        Elem::Relaxed => {
40                            b.decorate(out, Decoration::RelaxedPrecision, []);
41                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
42                        }
43                        _ => unreachable!(),
44                    };
45                });
46            }
47            Arithmetic::SaturatingSub(_) => {
48                unimplemented!("Should be replaced by polyfill");
49            }
50            Arithmetic::Mul(op) => {
51                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
52                    match out_ty.elem() {
53                        Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
54                        Elem::Float(..) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(),
55                        Elem::Relaxed => {
56                            b.decorate(out, Decoration::RelaxedPrecision, []);
57                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
58                        }
59                        _ => unreachable!(),
60                    };
61                });
62            }
63            Arithmetic::MulHi(op) => {
64                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
65                    let out_st = b.type_struct([ty, ty]);
66                    let extended = match out_ty.elem() {
67                        Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
68                        Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
69                        _ => unreachable!(),
70                    };
71                    b.composite_extract(ty, Some(out), extended, [1]).unwrap();
72                });
73            }
74            Arithmetic::Div(op) => {
75                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
76                    match out_ty.elem() {
77                        Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
78                        Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
79                        Elem::Float(..) => b.f_div(ty, Some(out), lhs, rhs).unwrap(),
80                        Elem::Relaxed => {
81                            b.decorate(out, Decoration::RelaxedPrecision, []);
82                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
83                        }
84                        _ => unreachable!(),
85                    };
86                });
87            }
88            Arithmetic::Remainder(op) => {
89                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
90                    match out_ty.elem() {
91                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
92                        Elem::Int(_, true) => b.s_mod(ty, Some(out), lhs, rhs).unwrap(),
93                        Elem::Float(..) => b.f_mod(ty, Some(out), lhs, rhs).unwrap(),
94                        Elem::Relaxed => {
95                            b.decorate(out, Decoration::RelaxedPrecision, []);
96                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
97                        }
98                        _ => unreachable!(),
99                    };
100                });
101            }
102            Arithmetic::Modulo(op) => {
103                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
104                    match out_ty.elem() {
105                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
106                        Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
107                        Elem::Float(..) => b.f_rem(ty, Some(out), lhs, rhs).unwrap(),
108                        Elem::Relaxed => {
109                            b.decorate(out, Decoration::RelaxedPrecision, []);
110                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
111                        }
112                        _ => unreachable!(),
113                    };
114                });
115            }
116            Arithmetic::Dot(op) => {
117                if op.lhs.ty.line_size() == 1 {
118                    self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
119                        match out_ty.elem() {
120                            Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
121                            Elem::Float(..) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(),
122                            Elem::Relaxed => {
123                                b.decorate(out, Decoration::RelaxedPrecision, []);
124                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
125                            }
126                            _ => unreachable!(),
127                        };
128                    });
129                } else {
130                    let lhs = self.compile_variable(op.lhs);
131                    let rhs = self.compile_variable(op.rhs);
132                    let out = self.compile_variable(out);
133                    let ty = out.item().id(self);
134
135                    let lhs_id = self.read(&lhs);
136                    let rhs_id = self.read(&rhs);
137                    let out_id = self.write_id(&out);
138                    self.mark_uniformity(out_id, uniform);
139
140                    if matches!(lhs.elem(), Elem::Int(_, _)) {
141                        self.capabilities.insert(Capability::DotProduct);
142                    }
143                    if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
144                        self.capabilities.insert(Capability::BFloat16DotProductKHR);
145                    }
146
147                    match (lhs.elem(), rhs.elem()) {
148                        (Elem::Int(_, false), Elem::Int(_, false)) => {
149                            self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
150                        }
151                        (Elem::Int(_, true), Elem::Int(_, false)) => {
152                            self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
153                        }
154                        (Elem::Int(_, false), Elem::Int(_, true)) => {
155                            self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
156                        }
157                        (Elem::Int(_, true), Elem::Int(_, true)) => {
158                            self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
159                        }
160                        (Elem::Float(..), Elem::Float(..))
161                        | (Elem::Relaxed, Elem::Float(..))
162                        | (Elem::Float(..), Elem::Relaxed) => {
163                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
164                        }
165                        (Elem::Relaxed, Elem::Relaxed) => {
166                            self.decorate(out_id, Decoration::RelaxedPrecision, []);
167                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
168                        }
169                        _ => unreachable!(),
170                    }
171                    .unwrap();
172                    self.write(&out, out_id);
173                }
174            }
175            Arithmetic::Fma(op) => {
176                let a = self.compile_variable(op.a);
177                let b = self.compile_variable(op.b);
178                let c = self.compile_variable(op.c);
179                let out = self.compile_variable(out);
180                let out_ty = out.item();
181                let relaxed = matches!(
182                    (a.item().elem(), b.item().elem(), c.item().elem()),
183                    (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
184                );
185
186                let a_id = self.read_as(&a, &out_ty);
187                let b_id = self.read_as(&b, &out_ty);
188                let c_id = self.read_as(&c, &out_ty);
189                let out_id = self.write_id(&out);
190                self.mark_uniformity(out_id, uniform);
191
192                let ty = out_ty.id(self);
193
194                let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
195                self.mark_uniformity(mul, uniform);
196                self.f_add(ty, Some(out_id), mul, c_id).unwrap();
197                if relaxed {
198                    self.decorate(mul, Decoration::RelaxedPrecision, []);
199                    self.decorate(out_id, Decoration::RelaxedPrecision, []);
200                }
201                self.write(&out, out_id);
202            }
203            Arithmetic::Recip(op) => {
204                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
205                    let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
206                    b.f_div(ty, Some(out), one, input).unwrap();
207                });
208            }
209            Arithmetic::Neg(op) => {
210                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
211                    match out_ty.elem() {
212                        Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
213                        Elem::Float(..) => b.f_negate(ty, Some(out), input).unwrap(),
214                        Elem::Relaxed => {
215                            b.decorate(out, Decoration::RelaxedPrecision, []);
216                            b.f_negate(ty, Some(out), input).unwrap()
217                        }
218                        _ => unreachable!(),
219                    };
220                });
221            }
222            Arithmetic::Erf(_) => {
223                unreachable!("Replaced by transformer")
224            }
225
226            // Extension functions
227            Arithmetic::Normalize(op) => {
228                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
229                    T::normalize(b, ty, input, out);
230                    if matches!(out_ty.elem(), Elem::Relaxed) {
231                        b.decorate(out, Decoration::RelaxedPrecision, []);
232                    }
233                });
234            }
235            Arithmetic::Magnitude(op) => {
236                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
237                    T::magnitude(b, ty, input, out);
238                    if matches!(out_ty.elem(), Elem::Relaxed) {
239                        b.decorate(out, Decoration::RelaxedPrecision, []);
240                    }
241                });
242            }
243            Arithmetic::Abs(op) => {
244                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
245                    match out_ty.elem() {
246                        Elem::Int(_, _) => T::s_abs(b, ty, input, out),
247                        Elem::Float(..) => T::f_abs(b, ty, input, out),
248                        Elem::Relaxed => {
249                            b.decorate(out, Decoration::RelaxedPrecision, []);
250                            T::f_abs(b, ty, input, out)
251                        }
252                        _ => unreachable!(),
253                    }
254                });
255            }
256            Arithmetic::Exp(op) => {
257                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
258                    T::exp(b, ty, input, out);
259                    if matches!(out_ty.elem(), Elem::Relaxed) {
260                        b.decorate(out, Decoration::RelaxedPrecision, []);
261                    }
262                });
263            }
264            Arithmetic::Log(op) => {
265                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
266                    T::log(b, ty, input, out);
267                    if matches!(out_ty.elem(), Elem::Relaxed) {
268                        b.decorate(out, Decoration::RelaxedPrecision, []);
269                    }
270                })
271            }
272            Arithmetic::Log1p(op) => {
273                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
274                    let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
275                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
276                    let add = match out_ty.elem() {
277                        Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
278                        Elem::Float(..) | Elem::Relaxed => b.f_add(ty, None, input, one).unwrap(),
279                        _ => unreachable!(),
280                    };
281                    b.mark_uniformity(add, uniform);
282                    if relaxed {
283                        b.decorate(add, Decoration::RelaxedPrecision, []);
284                        b.decorate(out, Decoration::RelaxedPrecision, []);
285                    }
286                    T::log(b, ty, add, out)
287                });
288            }
289            Arithmetic::Cos(op) => {
290                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
291                    T::cos(b, ty, input, out);
292                    if matches!(out_ty.elem(), Elem::Relaxed) {
293                        b.decorate(out, Decoration::RelaxedPrecision, []);
294                    }
295                })
296            }
297            Arithmetic::Sin(op) => {
298                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
299                    T::sin(b, ty, input, out);
300                    if matches!(out_ty.elem(), Elem::Relaxed) {
301                        b.decorate(out, Decoration::RelaxedPrecision, []);
302                    }
303                })
304            }
305            Arithmetic::Tanh(op) => {
306                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
307                    T::tanh(b, ty, input, out);
308                    if matches!(out_ty.elem(), Elem::Relaxed) {
309                        b.decorate(out, Decoration::RelaxedPrecision, []);
310                    }
311                })
312            }
313            // No powi for Vulkan, just auto-cast to float
314            Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
315                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
316                    let bool = match out_ty {
317                        Item::Scalar(_) => Elem::Bool.id(b),
318                        Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
319                        _ => unreachable!(),
320                    };
321                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
322                    let zero = out_ty.const_u32(b, 0);
323                    let one = out_ty.const_u32(b, 1);
324                    let two = out_ty.const_u32(b, 2);
325                    let modulo = b.f_rem(ty, None, rhs, two).unwrap();
326                    let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
327                    let abs = b.id();
328                    T::f_abs(b, ty, lhs, abs);
329                    let even = b.id();
330                    T::pow(b, ty, abs, rhs, even);
331                    let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
332                    let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
333                    let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
334                    let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
335                    let pow2 = b.id();
336                    T::pow(b, ty, neg_lhs, rhs, pow2);
337                    let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
338                    let default = b.id();
339                    T::pow(b, ty, lhs, rhs, default);
340                    let ids = [
341                        modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
342                        default,
343                    ];
344                    for id in ids {
345                        b.mark_uniformity(id, uniform);
346                        if relaxed {
347                            b.decorate(id, Decoration::RelaxedPrecision, []);
348                        }
349                    }
350                    let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
351                    b.mark_uniformity(sel1, uniform);
352                    b.select(ty, Some(out), is_zero, even, sel1).unwrap();
353                })
354            }
355            Arithmetic::Sqrt(op) => {
356                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
357                    T::sqrt(b, ty, input, out);
358                    if matches!(out_ty.elem(), Elem::Relaxed) {
359                        b.decorate(out, Decoration::RelaxedPrecision, []);
360                    }
361                })
362            }
363            Arithmetic::Round(op) => {
364                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
365                    T::round(b, ty, input, out);
366                    if matches!(out_ty.elem(), Elem::Relaxed) {
367                        b.decorate(out, Decoration::RelaxedPrecision, []);
368                    }
369                })
370            }
371            Arithmetic::Floor(op) => {
372                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
373                    T::floor(b, ty, input, out);
374                    if matches!(out_ty.elem(), Elem::Relaxed) {
375                        b.decorate(out, Decoration::RelaxedPrecision, []);
376                    }
377                })
378            }
379            Arithmetic::Ceil(op) => {
380                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
381                    T::ceil(b, ty, input, out);
382                    if matches!(out_ty.elem(), Elem::Relaxed) {
383                        b.decorate(out, Decoration::RelaxedPrecision, []);
384                    }
385                })
386            }
387            Arithmetic::Trunc(op) => {
388                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
389                    T::trunc(b, ty, input, out);
390                    if matches!(out_ty.elem(), Elem::Relaxed) {
391                        b.decorate(out, Decoration::RelaxedPrecision, []);
392                    }
393                })
394            }
395            Arithmetic::Clamp(op) => {
396                let input = self.compile_variable(op.input);
397                let min = self.compile_variable(op.min_value);
398                let max = self.compile_variable(op.max_value);
399                let out = self.compile_variable(out);
400                let out_ty = out.item();
401
402                let input = self.read_as(&input, &out_ty);
403                let min = self.read_as(&min, &out_ty);
404                let max = self.read_as(&max, &out_ty);
405                let out_id = self.write_id(&out);
406                self.mark_uniformity(out_id, uniform);
407
408                let ty = out_ty.id(self);
409
410                match out_ty.elem() {
411                    Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
412                    Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
413                    Elem::Float(..) => T::f_clamp(self, ty, input, min, max, out_id),
414                    Elem::Relaxed => {
415                        self.decorate(out_id, Decoration::RelaxedPrecision, []);
416                        T::f_clamp(self, ty, input, min, max, out_id)
417                    }
418                    _ => unreachable!(),
419                }
420                self.write(&out, out_id);
421            }
422
423            Arithmetic::Max(op) => self.compile_binary_op(
424                op,
425                out,
426                uniform,
427                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
428                    Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
429                    Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
430                    Elem::Float(..) => T::f_max(b, ty, lhs, rhs, out),
431                    Elem::Relaxed => {
432                        b.decorate(out, Decoration::RelaxedPrecision, []);
433                        T::f_max(b, ty, lhs, rhs, out)
434                    }
435                    _ => unreachable!(),
436                },
437            ),
438            Arithmetic::Min(op) => self.compile_binary_op(
439                op,
440                out,
441                uniform,
442                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
443                    Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
444                    Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
445                    Elem::Float(..) => T::f_min(b, ty, lhs, rhs, out),
446                    Elem::Relaxed => {
447                        b.decorate(out, Decoration::RelaxedPrecision, []);
448                        T::f_min(b, ty, lhs, rhs, out)
449                    }
450                    _ => unreachable!(),
451                },
452            ),
453        }
454    }
455}