// cubecl_spirv/arithmetic.rs

1use crate::{
2    SpirvCompiler, SpirvTarget,
3    item::{Elem, Item},
4    variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10    pub fn compile_arithmetic(
11        &mut self,
12        op: Arithmetic,
13        out: Option<core::Variable>,
14        modes: InstructionModes,
15        uniform: bool,
16    ) {
17        let out = out.unwrap();
18        match op {
19            Arithmetic::Add(op) => {
20                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21                    match out_ty.elem() {
22                        Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23                        Elem::Float(..) => {
24                            b.declare_math_mode(modes, out);
25                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
26                        }
27                        Elem::Relaxed => {
28                            b.decorate(out, Decoration::RelaxedPrecision, []);
29                            b.declare_math_mode(modes, out);
30                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
31                        }
32                        _ => unreachable!(),
33                    };
34                });
35            }
36            Arithmetic::SaturatingAdd(_) => {
37                unimplemented!("Should be replaced by polyfill");
38            }
39            Arithmetic::Sub(op) => {
40                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41                    match out_ty.elem() {
42                        Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43                        Elem::Float(..) => {
44                            b.declare_math_mode(modes, out);
45                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46                        }
47                        Elem::Relaxed => {
48                            b.decorate(out, Decoration::RelaxedPrecision, []);
49                            b.declare_math_mode(modes, out);
50                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51                        }
52                        _ => unreachable!(),
53                    };
54                });
55            }
56            Arithmetic::SaturatingSub(_) => {
57                unimplemented!("Should be replaced by polyfill");
58            }
59            Arithmetic::Mul(op) => {
60                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61                    match out_ty.elem() {
62                        Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63                        Elem::Float(..) => {
64                            b.declare_math_mode(modes, out);
65                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66                        }
67                        Elem::Relaxed => {
68                            b.decorate(out, Decoration::RelaxedPrecision, []);
69                            b.declare_math_mode(modes, out);
70                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71                        }
72                        _ => unreachable!(),
73                    };
74                });
75            }
76            Arithmetic::MulHi(op) => {
77                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78                    let out_st = b.type_struct([ty, ty]);
79                    let extended = match out_ty.elem() {
80                        Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81                        Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82                        _ => unreachable!(),
83                    };
84                    b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85                });
86            }
87            Arithmetic::Div(op) => {
88                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89                    match out_ty.elem() {
90                        Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91                        Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92                        Elem::Float(..) => {
93                            b.declare_math_mode(modes, out);
94                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
95                        }
96                        Elem::Relaxed => {
97                            b.decorate(out, Decoration::RelaxedPrecision, []);
98                            b.declare_math_mode(modes, out);
99                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
100                        }
101                        _ => unreachable!(),
102                    };
103                });
104            }
105            Arithmetic::Remainder(op) => {
106                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107                    match out_ty.elem() {
108                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109                        Elem::Int(_, true) => {
110                            // Convert to float and use `f_mod` (floored division) instead of `s_mod`
111                            // (truncated division) to match remainder semantics across dtypes
112                            // e.g. remainder(-2, 3) = 1, not 2
113                            let f_ty = match out_ty {
114                                Item::Scalar(_elem) => Item::Scalar(Elem::Relaxed),
115                                Item::Vector(_elem, factor) => Item::Vector(Elem::Relaxed, factor),
116                                _ => unreachable!(),
117                            };
118                            let f_ty = f_ty.id(b);
119                            let lhs_f = b.convert_s_to_f(f_ty, None, lhs).unwrap();
120                            let rhs_f = b.convert_s_to_f(f_ty, None, rhs).unwrap();
121                            let rem = b.f_mod(f_ty, None, lhs_f, rhs_f).unwrap();
122                            b.convert_f_to_s(ty, Some(out), rem).unwrap()
123                        }
124                        Elem::Float(..) => {
125                            b.declare_math_mode(modes, out);
126                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
127                        }
128                        Elem::Relaxed => {
129                            b.decorate(out, Decoration::RelaxedPrecision, []);
130                            b.declare_math_mode(modes, out);
131                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
132                        }
133                        _ => unreachable!(),
134                    };
135                });
136            }
137            Arithmetic::Modulo(op) => {
138                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
139                    match out_ty.elem() {
140                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
141                        Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
142                        Elem::Float(..) => {
143                            b.declare_math_mode(modes, out);
144                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
145                        }
146                        Elem::Relaxed => {
147                            b.decorate(out, Decoration::RelaxedPrecision, []);
148                            b.declare_math_mode(modes, out);
149                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
150                        }
151                        _ => unreachable!(),
152                    };
153                });
154            }
155            Arithmetic::Dot(op) => {
156                if op.lhs.ty.line_size() == 1 {
157                    self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
158                        match out_ty.elem() {
159                            Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
160                            Elem::Float(..) => {
161                                b.declare_math_mode(modes, out);
162                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
163                            }
164                            Elem::Relaxed => {
165                                b.decorate(out, Decoration::RelaxedPrecision, []);
166                                b.declare_math_mode(modes, out);
167                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
168                            }
169                            _ => unreachable!(),
170                        };
171                    });
172                } else {
173                    let lhs = self.compile_variable(op.lhs);
174                    let rhs = self.compile_variable(op.rhs);
175                    let out = self.compile_variable(out);
176                    let ty = out.item().id(self);
177
178                    let lhs_id = self.read(&lhs);
179                    let rhs_id = self.read(&rhs);
180                    let out_id = self.write_id(&out);
181                    self.mark_uniformity(out_id, uniform);
182
183                    if matches!(lhs.elem(), Elem::Int(_, _)) {
184                        self.capabilities.insert(Capability::DotProduct);
185                    }
186                    if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
187                        self.capabilities.insert(Capability::BFloat16DotProductKHR);
188                    }
189
190                    match (lhs.elem(), rhs.elem()) {
191                        (Elem::Int(_, false), Elem::Int(_, false)) => {
192                            self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
193                        }
194                        (Elem::Int(_, true), Elem::Int(_, false)) => {
195                            self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
196                        }
197                        (Elem::Int(_, false), Elem::Int(_, true)) => {
198                            self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
199                        }
200                        (Elem::Int(_, true), Elem::Int(_, true)) => {
201                            self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
202                        }
203                        (Elem::Float(..), Elem::Float(..))
204                        | (Elem::Relaxed, Elem::Float(..))
205                        | (Elem::Float(..), Elem::Relaxed) => {
206                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
207                        }
208                        (Elem::Relaxed, Elem::Relaxed) => {
209                            self.decorate(out_id, Decoration::RelaxedPrecision, []);
210                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
211                        }
212                        _ => unreachable!(),
213                    }
214                    .unwrap();
215                    self.write(&out, out_id);
216                }
217            }
218            Arithmetic::Fma(op) => {
219                let a = self.compile_variable(op.a);
220                let b = self.compile_variable(op.b);
221                let c = self.compile_variable(op.c);
222                let out = self.compile_variable(out);
223                let out_ty = out.item();
224                let relaxed = matches!(
225                    (a.item().elem(), b.item().elem(), c.item().elem()),
226                    (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
227                );
228
229                let a_id = self.read_as(&a, &out_ty);
230                let b_id = self.read_as(&b, &out_ty);
231                let c_id = self.read_as(&c, &out_ty);
232                let out_id = self.write_id(&out);
233                self.mark_uniformity(out_id, uniform);
234
235                let ty = out_ty.id(self);
236
237                let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
238                self.mark_uniformity(mul, uniform);
239                self.declare_math_mode(modes, mul);
240                self.f_add(ty, Some(out_id), mul, c_id).unwrap();
241                self.declare_math_mode(modes, out_id);
242                if relaxed {
243                    self.decorate(mul, Decoration::RelaxedPrecision, []);
244                    self.decorate(out_id, Decoration::RelaxedPrecision, []);
245                }
246                self.write(&out, out_id);
247            }
248            Arithmetic::Recip(op) => {
249                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
250                    let one = b
251                        .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
252                        .0;
253                    b.declare_math_mode(modes, out);
254                    b.f_div(ty, Some(out), one, input).unwrap();
255                });
256            }
257            Arithmetic::Neg(op) => {
258                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
259                    match out_ty.elem() {
260                        Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
261                        Elem::Float(..) => {
262                            b.declare_math_mode(modes, out);
263                            b.f_negate(ty, Some(out), input).unwrap()
264                        }
265                        Elem::Relaxed => {
266                            b.decorate(out, Decoration::RelaxedPrecision, []);
267                            b.declare_math_mode(modes, out);
268                            b.f_negate(ty, Some(out), input).unwrap()
269                        }
270                        _ => unreachable!(),
271                    };
272                });
273            }
274            Arithmetic::Erf(_) => {
275                unreachable!("Replaced by transformer")
276            }
277
278            // Extension functions
279            Arithmetic::Normalize(op) => {
280                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
281                    b.declare_math_mode(modes, out);
282                    T::normalize(b, ty, input, out);
283                    if matches!(out_ty.elem(), Elem::Relaxed) {
284                        b.decorate(out, Decoration::RelaxedPrecision, []);
285                    }
286                });
287            }
288            Arithmetic::Magnitude(op) => {
289                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
290                    b.declare_math_mode(modes, out);
291                    T::magnitude(b, ty, input, out);
292                    if matches!(out_ty.elem(), Elem::Relaxed) {
293                        b.decorate(out, Decoration::RelaxedPrecision, []);
294                    }
295                });
296            }
297            Arithmetic::Abs(op) => {
298                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
299                    match out_ty.elem() {
300                        Elem::Int(_, _) => T::s_abs(b, ty, input, out),
301                        Elem::Float(..) => {
302                            b.declare_math_mode(modes, out);
303                            T::f_abs(b, ty, input, out)
304                        }
305                        Elem::Relaxed => {
306                            b.decorate(out, Decoration::RelaxedPrecision, []);
307                            b.declare_math_mode(modes, out);
308                            T::f_abs(b, ty, input, out)
309                        }
310                        _ => unreachable!(),
311                    }
312                });
313            }
314            Arithmetic::Exp(op) => {
315                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
316                    b.declare_math_mode(modes, out);
317                    T::exp(b, ty, input, out);
318                    if matches!(out_ty.elem(), Elem::Relaxed) {
319                        b.decorate(out, Decoration::RelaxedPrecision, []);
320                    }
321                });
322            }
323            Arithmetic::Log(op) => {
324                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
325                    b.declare_math_mode(modes, out);
326                    T::log(b, ty, input, out);
327                    if matches!(out_ty.elem(), Elem::Relaxed) {
328                        b.decorate(out, Decoration::RelaxedPrecision, []);
329                    }
330                })
331            }
332            Arithmetic::Log1p(op) => {
333                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
334                    let one = b
335                        .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
336                        .0;
337                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
338                    let add = match out_ty.elem() {
339                        Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
340                        Elem::Float(..) | Elem::Relaxed => {
341                            b.declare_math_mode(modes, out);
342                            b.f_add(ty, None, input, one).unwrap()
343                        }
344                        _ => unreachable!(),
345                    };
346                    b.mark_uniformity(add, uniform);
347                    if relaxed {
348                        b.decorate(add, Decoration::RelaxedPrecision, []);
349                        b.decorate(out, Decoration::RelaxedPrecision, []);
350                    }
351                    b.declare_math_mode(modes, out);
352                    T::log(b, ty, add, out)
353                });
354            }
355            Arithmetic::Cos(op) => {
356                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
357                    b.declare_math_mode(modes, out);
358                    T::cos(b, ty, input, out);
359                    if matches!(out_ty.elem(), Elem::Relaxed) {
360                        b.decorate(out, Decoration::RelaxedPrecision, []);
361                    }
362                })
363            }
364            Arithmetic::Sin(op) => {
365                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
366                    b.declare_math_mode(modes, out);
367                    T::sin(b, ty, input, out);
368                    if matches!(out_ty.elem(), Elem::Relaxed) {
369                        b.decorate(out, Decoration::RelaxedPrecision, []);
370                    }
371                })
372            }
373            Arithmetic::Tan(op) => {
374                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
375                    b.declare_math_mode(modes, out);
376                    T::tan(b, ty, input, out);
377                    if matches!(out_ty.elem(), Elem::Relaxed) {
378                        b.decorate(out, Decoration::RelaxedPrecision, []);
379                    }
380                })
381            }
382            Arithmetic::Tanh(op) => {
383                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
384                    b.declare_math_mode(modes, out);
385                    T::tanh(b, ty, input, out);
386                    if matches!(out_ty.elem(), Elem::Relaxed) {
387                        b.decorate(out, Decoration::RelaxedPrecision, []);
388                    }
389                })
390            }
391            Arithmetic::Sinh(op) => {
392                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
393                    b.declare_math_mode(modes, out);
394                    T::sinh(b, ty, input, out);
395                    if matches!(out_ty.elem(), Elem::Relaxed) {
396                        b.decorate(out, Decoration::RelaxedPrecision, []);
397                    }
398                })
399            }
400            Arithmetic::Cosh(op) => {
401                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
402                    b.declare_math_mode(modes, out);
403                    T::cosh(b, ty, input, out);
404                    if matches!(out_ty.elem(), Elem::Relaxed) {
405                        b.decorate(out, Decoration::RelaxedPrecision, []);
406                    }
407                })
408            }
409            Arithmetic::ArcCos(op) => {
410                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
411                    b.declare_math_mode(modes, out);
412                    T::acos(b, ty, input, out);
413                    if matches!(out_ty.elem(), Elem::Relaxed) {
414                        b.decorate(out, Decoration::RelaxedPrecision, []);
415                    }
416                })
417            }
418            Arithmetic::ArcSin(op) => {
419                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
420                    b.declare_math_mode(modes, out);
421                    T::asin(b, ty, input, out);
422                    if matches!(out_ty.elem(), Elem::Relaxed) {
423                        b.decorate(out, Decoration::RelaxedPrecision, []);
424                    }
425                })
426            }
427            Arithmetic::ArcTan(op) => {
428                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
429                    b.declare_math_mode(modes, out);
430                    T::atan(b, ty, input, out);
431                    if matches!(out_ty.elem(), Elem::Relaxed) {
432                        b.decorate(out, Decoration::RelaxedPrecision, []);
433                    }
434                })
435            }
436            Arithmetic::ArcSinh(op) => {
437                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
438                    b.declare_math_mode(modes, out);
439                    T::asinh(b, ty, input, out);
440                    if matches!(out_ty.elem(), Elem::Relaxed) {
441                        b.decorate(out, Decoration::RelaxedPrecision, []);
442                    }
443                })
444            }
445            Arithmetic::ArcCosh(op) => {
446                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
447                    b.declare_math_mode(modes, out);
448                    T::acosh(b, ty, input, out);
449                    if matches!(out_ty.elem(), Elem::Relaxed) {
450                        b.decorate(out, Decoration::RelaxedPrecision, []);
451                    }
452                })
453            }
454            Arithmetic::ArcTanh(op) => {
455                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
456                    b.declare_math_mode(modes, out);
457                    T::atanh(b, ty, input, out);
458                    if matches!(out_ty.elem(), Elem::Relaxed) {
459                        b.decorate(out, Decoration::RelaxedPrecision, []);
460                    }
461                })
462            }
463            Arithmetic::Degrees(op) => {
464                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
465                    b.declare_math_mode(modes, out);
466                    T::degrees(b, ty, input, out);
467                    if matches!(out_ty.elem(), Elem::Relaxed) {
468                        b.decorate(out, Decoration::RelaxedPrecision, []);
469                    }
470                })
471            }
472            Arithmetic::Radians(op) => {
473                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
474                    b.declare_math_mode(modes, out);
475                    T::radians(b, ty, input, out);
476                    if matches!(out_ty.elem(), Elem::Relaxed) {
477                        b.decorate(out, Decoration::RelaxedPrecision, []);
478                    }
479                })
480            }
481            Arithmetic::ArcTan2(op) => {
482                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
483                    b.declare_math_mode(modes, out);
484                    T::atan2(b, ty, lhs, rhs, out);
485                    if matches!(out_ty.elem(), Elem::Relaxed) {
486                        b.decorate(out, Decoration::RelaxedPrecision, []);
487                    }
488                })
489            }
490            // No powi for Vulkan, just auto-cast to float
491            Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
492                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
493                    let bool = match out_ty {
494                        Item::Scalar(_) => Elem::Bool.id(b),
495                        Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
496                        _ => unreachable!(),
497                    };
498                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
499                    let zero = out_ty.const_u32(b, 0);
500                    let one = out_ty.const_u32(b, 1);
501                    let two = out_ty.const_u32(b, 2);
502                    let modulo = b.f_rem(ty, None, rhs, two).unwrap();
503                    b.declare_math_mode(modes, modulo);
504                    let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
505                    b.declare_math_mode(modes, is_zero);
506                    let abs = b.id();
507                    b.declare_math_mode(modes, abs);
508                    T::f_abs(b, ty, lhs, abs);
509                    let even = b.id();
510                    b.declare_math_mode(modes, even);
511                    T::pow(b, ty, abs, rhs, even);
512                    let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
513                    b.declare_math_mode(modes, cond2_0);
514                    let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
515                    b.declare_math_mode(modes, cond2_1);
516                    let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
517                    let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
518                    b.declare_math_mode(modes, neg_lhs);
519                    let pow2 = b.id();
520                    b.declare_math_mode(modes, pow2);
521                    T::pow(b, ty, neg_lhs, rhs, pow2);
522                    let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
523                    b.declare_math_mode(modes, pow2_neg);
524                    let default = b.id();
525                    b.declare_math_mode(modes, default);
526                    T::pow(b, ty, lhs, rhs, default);
527                    let ids = [
528                        modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
529                        default,
530                    ];
531                    for id in ids {
532                        b.mark_uniformity(id, uniform);
533                        if relaxed {
534                            b.decorate(id, Decoration::RelaxedPrecision, []);
535                        }
536                    }
537                    let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
538                    b.mark_uniformity(sel1, uniform);
539                    b.select(ty, Some(out), is_zero, even, sel1).unwrap();
540                })
541            }
542            Arithmetic::Hypot(_op) => {
543                unreachable!("Replaced by transformer");
544            }
545            Arithmetic::Rhypot(_op) => {
546                unreachable!("Replaced by transformer");
547            }
548            Arithmetic::Sqrt(op) => {
549                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
550                    b.declare_math_mode(modes, out);
551                    T::sqrt(b, ty, input, out);
552                    if matches!(out_ty.elem(), Elem::Relaxed) {
553                        b.decorate(out, Decoration::RelaxedPrecision, []);
554                    }
555                })
556            }
557            Arithmetic::InverseSqrt(op) => {
558                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
559                    b.declare_math_mode(modes, out);
560                    T::inverse_sqrt(b, ty, input, out);
561                    if matches!(out_ty.elem(), Elem::Relaxed) {
562                        b.decorate(out, Decoration::RelaxedPrecision, []);
563                    }
564                })
565            }
566            Arithmetic::Round(op) => {
567                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
568                    T::round(b, ty, input, out);
569                    if matches!(out_ty.elem(), Elem::Relaxed) {
570                        b.decorate(out, Decoration::RelaxedPrecision, []);
571                    }
572                })
573            }
574            Arithmetic::Floor(op) => {
575                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
576                    b.declare_math_mode(modes, out);
577                    T::floor(b, ty, input, out);
578                    if matches!(out_ty.elem(), Elem::Relaxed) {
579                        b.decorate(out, Decoration::RelaxedPrecision, []);
580                    }
581                })
582            }
583            Arithmetic::Ceil(op) => {
584                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
585                    b.declare_math_mode(modes, out);
586                    T::ceil(b, ty, input, out);
587                    if matches!(out_ty.elem(), Elem::Relaxed) {
588                        b.decorate(out, Decoration::RelaxedPrecision, []);
589                    }
590                })
591            }
592            Arithmetic::Trunc(op) => {
593                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
594                    b.declare_math_mode(modes, out);
595                    T::trunc(b, ty, input, out);
596                    if matches!(out_ty.elem(), Elem::Relaxed) {
597                        b.decorate(out, Decoration::RelaxedPrecision, []);
598                    }
599                })
600            }
601            Arithmetic::Clamp(op) => {
602                let input = self.compile_variable(op.input);
603                let min = self.compile_variable(op.min_value);
604                let max = self.compile_variable(op.max_value);
605                let out = self.compile_variable(out);
606                let out_ty = out.item();
607
608                let input = self.read_as(&input, &out_ty);
609                let min = self.read_as(&min, &out_ty);
610                let max = self.read_as(&max, &out_ty);
611                let out_id = self.write_id(&out);
612                self.mark_uniformity(out_id, uniform);
613
614                let ty = out_ty.id(self);
615
616                match out_ty.elem() {
617                    Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
618                    Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
619                    Elem::Float(..) => {
620                        self.declare_math_mode(modes, out_id);
621                        T::f_clamp(self, ty, input, min, max, out_id)
622                    }
623                    Elem::Relaxed => {
624                        self.decorate(out_id, Decoration::RelaxedPrecision, []);
625                        self.declare_math_mode(modes, out_id);
626                        T::f_clamp(self, ty, input, min, max, out_id)
627                    }
628                    _ => unreachable!(),
629                }
630                self.write(&out, out_id);
631            }
632
633            Arithmetic::Max(op) => self.compile_binary_op(
634                op,
635                out,
636                uniform,
637                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
638                    Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
639                    Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
640                    Elem::Float(..) => {
641                        b.declare_math_mode(modes, out);
642                        T::f_max(b, ty, lhs, rhs, out)
643                    }
644                    Elem::Relaxed => {
645                        b.decorate(out, Decoration::RelaxedPrecision, []);
646                        b.declare_math_mode(modes, out);
647                        T::f_max(b, ty, lhs, rhs, out)
648                    }
649                    _ => unreachable!(),
650                },
651            ),
652            Arithmetic::Min(op) => self.compile_binary_op(
653                op,
654                out,
655                uniform,
656                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
657                    Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
658                    Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
659                    Elem::Float(..) => {
660                        b.declare_math_mode(modes, out);
661                        T::f_min(b, ty, lhs, rhs, out)
662                    }
663                    Elem::Relaxed => {
664                        b.decorate(out, Decoration::RelaxedPrecision, []);
665                        b.declare_math_mode(modes, out);
666                        T::f_min(b, ty, lhs, rhs, out)
667                    }
668                    _ => unreachable!(),
669                },
670            ),
671        }
672    }
673}