cubecl_spirv/
arithmetic.rs

1use crate::{
2    SpirvCompiler, SpirvTarget,
3    item::{Elem, Item},
4    variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10    pub fn compile_arithmetic(
11        &mut self,
12        op: Arithmetic,
13        out: Option<core::Variable>,
14        modes: InstructionModes,
15        uniform: bool,
16    ) {
17        let out = out.unwrap();
18        match op {
19            Arithmetic::Add(op) => {
20                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21                    match out_ty.elem() {
22                        Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23                        Elem::Float(..) => {
24                            b.declare_math_mode(modes, out);
25                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
26                        }
27                        Elem::Relaxed => {
28                            b.decorate(out, Decoration::RelaxedPrecision, []);
29                            b.declare_math_mode(modes, out);
30                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
31                        }
32                        _ => unreachable!(),
33                    };
34                });
35            }
36            Arithmetic::SaturatingAdd(_) => {
37                unimplemented!("Should be replaced by polyfill");
38            }
39            Arithmetic::Sub(op) => {
40                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41                    match out_ty.elem() {
42                        Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43                        Elem::Float(..) => {
44                            b.declare_math_mode(modes, out);
45                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46                        }
47                        Elem::Relaxed => {
48                            b.decorate(out, Decoration::RelaxedPrecision, []);
49                            b.declare_math_mode(modes, out);
50                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51                        }
52                        _ => unreachable!(),
53                    };
54                });
55            }
56            Arithmetic::SaturatingSub(_) => {
57                unimplemented!("Should be replaced by polyfill");
58            }
59            Arithmetic::Mul(op) => {
60                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61                    match out_ty.elem() {
62                        Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63                        Elem::Float(..) => {
64                            b.declare_math_mode(modes, out);
65                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66                        }
67                        Elem::Relaxed => {
68                            b.decorate(out, Decoration::RelaxedPrecision, []);
69                            b.declare_math_mode(modes, out);
70                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71                        }
72                        _ => unreachable!(),
73                    };
74                });
75            }
76            Arithmetic::MulHi(op) => {
77                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78                    let out_st = b.type_struct([ty, ty]);
79                    let extended = match out_ty.elem() {
80                        Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81                        Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82                        _ => unreachable!(),
83                    };
84                    b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85                });
86            }
87            Arithmetic::Div(op) => {
88                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89                    match out_ty.elem() {
90                        Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91                        Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92                        Elem::Float(..) => {
93                            b.declare_math_mode(modes, out);
94                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
95                        }
96                        Elem::Relaxed => {
97                            b.decorate(out, Decoration::RelaxedPrecision, []);
98                            b.declare_math_mode(modes, out);
99                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
100                        }
101                        _ => unreachable!(),
102                    };
103                });
104            }
105            Arithmetic::Remainder(op) => {
106                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107                    match out_ty.elem() {
108                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109                        Elem::Int(_, true) => {
110                            // Convert to float and use `f_mod` (floored division) instead of `s_mod`
111                            // (truncated division) to match remainder semantics across dtypes
112                            // e.g. remainder(-2, 3) = 1, not 2
113                            let f_ty = match out_ty {
114                                Item::Scalar(_elem) => Item::Scalar(Elem::Relaxed),
115                                Item::Vector(_elem, factor) => Item::Vector(Elem::Relaxed, factor),
116                                _ => unreachable!(),
117                            };
118                            let f_ty = f_ty.id(b);
119                            let lhs_f = b.convert_s_to_f(f_ty, None, lhs).unwrap();
120                            let rhs_f = b.convert_s_to_f(f_ty, None, rhs).unwrap();
121                            let rem = b.f_mod(f_ty, None, lhs_f, rhs_f).unwrap();
122                            b.convert_f_to_s(ty, Some(out), rem).unwrap()
123                        }
124                        Elem::Float(..) => {
125                            b.declare_math_mode(modes, out);
126                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
127                        }
128                        Elem::Relaxed => {
129                            b.decorate(out, Decoration::RelaxedPrecision, []);
130                            b.declare_math_mode(modes, out);
131                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
132                        }
133                        _ => unreachable!(),
134                    };
135                });
136            }
137            Arithmetic::Modulo(op) => {
138                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
139                    match out_ty.elem() {
140                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
141                        Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
142                        Elem::Float(..) => {
143                            b.declare_math_mode(modes, out);
144                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
145                        }
146                        Elem::Relaxed => {
147                            b.decorate(out, Decoration::RelaxedPrecision, []);
148                            b.declare_math_mode(modes, out);
149                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
150                        }
151                        _ => unreachable!(),
152                    };
153                });
154            }
155            Arithmetic::Dot(op) => {
156                if op.lhs.ty.line_size() == 1 {
157                    self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
158                        match out_ty.elem() {
159                            Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
160                            Elem::Float(..) => {
161                                b.declare_math_mode(modes, out);
162                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
163                            }
164                            Elem::Relaxed => {
165                                b.decorate(out, Decoration::RelaxedPrecision, []);
166                                b.declare_math_mode(modes, out);
167                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
168                            }
169                            _ => unreachable!(),
170                        };
171                    });
172                } else {
173                    let lhs = self.compile_variable(op.lhs);
174                    let rhs = self.compile_variable(op.rhs);
175                    let out = self.compile_variable(out);
176                    let ty = out.item().id(self);
177
178                    let lhs_id = self.read(&lhs);
179                    let rhs_id = self.read(&rhs);
180                    let out_id = self.write_id(&out);
181                    self.mark_uniformity(out_id, uniform);
182
183                    if matches!(lhs.elem(), Elem::Int(_, _)) {
184                        self.capabilities.insert(Capability::DotProduct);
185                    }
186                    if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
187                        self.capabilities.insert(Capability::BFloat16DotProductKHR);
188                    }
189
190                    match (lhs.elem(), rhs.elem()) {
191                        (Elem::Int(_, false), Elem::Int(_, false)) => {
192                            self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
193                        }
194                        (Elem::Int(_, true), Elem::Int(_, false)) => {
195                            self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
196                        }
197                        (Elem::Int(_, false), Elem::Int(_, true)) => {
198                            self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
199                        }
200                        (Elem::Int(_, true), Elem::Int(_, true)) => {
201                            self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
202                        }
203                        (Elem::Float(..), Elem::Float(..))
204                        | (Elem::Relaxed, Elem::Float(..))
205                        | (Elem::Float(..), Elem::Relaxed) => {
206                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
207                        }
208                        (Elem::Relaxed, Elem::Relaxed) => {
209                            self.decorate(out_id, Decoration::RelaxedPrecision, []);
210                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
211                        }
212                        _ => unreachable!(),
213                    }
214                    .unwrap();
215                    self.write(&out, out_id);
216                }
217            }
218            Arithmetic::Fma(op) => {
219                let a = self.compile_variable(op.a);
220                let b = self.compile_variable(op.b);
221                let c = self.compile_variable(op.c);
222                let out = self.compile_variable(out);
223                let out_ty = out.item();
224                let relaxed = matches!(
225                    (a.item().elem(), b.item().elem(), c.item().elem()),
226                    (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
227                );
228
229                let a_id = self.read_as(&a, &out_ty);
230                let b_id = self.read_as(&b, &out_ty);
231                let c_id = self.read_as(&c, &out_ty);
232                let out_id = self.write_id(&out);
233                self.mark_uniformity(out_id, uniform);
234
235                let ty = out_ty.id(self);
236
237                let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
238                self.mark_uniformity(mul, uniform);
239                self.declare_math_mode(modes, mul);
240                self.f_add(ty, Some(out_id), mul, c_id).unwrap();
241                self.declare_math_mode(modes, out_id);
242                if relaxed {
243                    self.decorate(mul, Decoration::RelaxedPrecision, []);
244                    self.decorate(out_id, Decoration::RelaxedPrecision, []);
245                }
246                self.write(&out, out_id);
247            }
248            Arithmetic::Recip(op) => {
249                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
250                    let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
251                    b.declare_math_mode(modes, out);
252                    b.f_div(ty, Some(out), one, input).unwrap();
253                });
254            }
255            Arithmetic::Neg(op) => {
256                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
257                    match out_ty.elem() {
258                        Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
259                        Elem::Float(..) => {
260                            b.declare_math_mode(modes, out);
261                            b.f_negate(ty, Some(out), input).unwrap()
262                        }
263                        Elem::Relaxed => {
264                            b.decorate(out, Decoration::RelaxedPrecision, []);
265                            b.declare_math_mode(modes, out);
266                            b.f_negate(ty, Some(out), input).unwrap()
267                        }
268                        _ => unreachable!(),
269                    };
270                });
271            }
272            Arithmetic::Erf(_) => {
273                unreachable!("Replaced by transformer")
274            }
275
276            // Extension functions
277            Arithmetic::Normalize(op) => {
278                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
279                    b.declare_math_mode(modes, out);
280                    T::normalize(b, ty, input, out);
281                    if matches!(out_ty.elem(), Elem::Relaxed) {
282                        b.decorate(out, Decoration::RelaxedPrecision, []);
283                    }
284                });
285            }
286            Arithmetic::Magnitude(op) => {
287                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
288                    b.declare_math_mode(modes, out);
289                    T::magnitude(b, ty, input, out);
290                    if matches!(out_ty.elem(), Elem::Relaxed) {
291                        b.decorate(out, Decoration::RelaxedPrecision, []);
292                    }
293                });
294            }
295            Arithmetic::Abs(op) => {
296                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
297                    match out_ty.elem() {
298                        Elem::Int(_, _) => T::s_abs(b, ty, input, out),
299                        Elem::Float(..) => {
300                            b.declare_math_mode(modes, out);
301                            T::f_abs(b, ty, input, out)
302                        }
303                        Elem::Relaxed => {
304                            b.decorate(out, Decoration::RelaxedPrecision, []);
305                            b.declare_math_mode(modes, out);
306                            T::f_abs(b, ty, input, out)
307                        }
308                        _ => unreachable!(),
309                    }
310                });
311            }
312            Arithmetic::Exp(op) => {
313                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
314                    b.declare_math_mode(modes, out);
315                    T::exp(b, ty, input, out);
316                    if matches!(out_ty.elem(), Elem::Relaxed) {
317                        b.decorate(out, Decoration::RelaxedPrecision, []);
318                    }
319                });
320            }
321            Arithmetic::Log(op) => {
322                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
323                    b.declare_math_mode(modes, out);
324                    T::log(b, ty, input, out);
325                    if matches!(out_ty.elem(), Elem::Relaxed) {
326                        b.decorate(out, Decoration::RelaxedPrecision, []);
327                    }
328                })
329            }
330            Arithmetic::Log1p(op) => {
331                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
332                    let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
333                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
334                    let add = match out_ty.elem() {
335                        Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
336                        Elem::Float(..) | Elem::Relaxed => {
337                            b.declare_math_mode(modes, out);
338                            b.f_add(ty, None, input, one).unwrap()
339                        }
340                        _ => unreachable!(),
341                    };
342                    b.mark_uniformity(add, uniform);
343                    if relaxed {
344                        b.decorate(add, Decoration::RelaxedPrecision, []);
345                        b.decorate(out, Decoration::RelaxedPrecision, []);
346                    }
347                    b.declare_math_mode(modes, out);
348                    T::log(b, ty, add, out)
349                });
350            }
351            Arithmetic::Cos(op) => {
352                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
353                    b.declare_math_mode(modes, out);
354                    T::cos(b, ty, input, out);
355                    if matches!(out_ty.elem(), Elem::Relaxed) {
356                        b.decorate(out, Decoration::RelaxedPrecision, []);
357                    }
358                })
359            }
360            Arithmetic::Sin(op) => {
361                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
362                    b.declare_math_mode(modes, out);
363                    T::sin(b, ty, input, out);
364                    if matches!(out_ty.elem(), Elem::Relaxed) {
365                        b.decorate(out, Decoration::RelaxedPrecision, []);
366                    }
367                })
368            }
369            Arithmetic::Tan(op) => {
370                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
371                    b.declare_math_mode(modes, out);
372                    T::tan(b, ty, input, out);
373                    if matches!(out_ty.elem(), Elem::Relaxed) {
374                        b.decorate(out, Decoration::RelaxedPrecision, []);
375                    }
376                })
377            }
378            Arithmetic::Tanh(op) => {
379                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
380                    b.declare_math_mode(modes, out);
381                    T::tanh(b, ty, input, out);
382                    if matches!(out_ty.elem(), Elem::Relaxed) {
383                        b.decorate(out, Decoration::RelaxedPrecision, []);
384                    }
385                })
386            }
387            Arithmetic::Sinh(op) => {
388                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
389                    b.declare_math_mode(modes, out);
390                    T::sinh(b, ty, input, out);
391                    if matches!(out_ty.elem(), Elem::Relaxed) {
392                        b.decorate(out, Decoration::RelaxedPrecision, []);
393                    }
394                })
395            }
396            Arithmetic::Cosh(op) => {
397                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
398                    b.declare_math_mode(modes, out);
399                    T::cosh(b, ty, input, out);
400                    if matches!(out_ty.elem(), Elem::Relaxed) {
401                        b.decorate(out, Decoration::RelaxedPrecision, []);
402                    }
403                })
404            }
405            Arithmetic::ArcCos(op) => {
406                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
407                    b.declare_math_mode(modes, out);
408                    T::acos(b, ty, input, out);
409                    if matches!(out_ty.elem(), Elem::Relaxed) {
410                        b.decorate(out, Decoration::RelaxedPrecision, []);
411                    }
412                })
413            }
414            Arithmetic::ArcSin(op) => {
415                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
416                    b.declare_math_mode(modes, out);
417                    T::asin(b, ty, input, out);
418                    if matches!(out_ty.elem(), Elem::Relaxed) {
419                        b.decorate(out, Decoration::RelaxedPrecision, []);
420                    }
421                })
422            }
423            Arithmetic::ArcTan(op) => {
424                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
425                    b.declare_math_mode(modes, out);
426                    T::atan(b, ty, input, out);
427                    if matches!(out_ty.elem(), Elem::Relaxed) {
428                        b.decorate(out, Decoration::RelaxedPrecision, []);
429                    }
430                })
431            }
432            Arithmetic::ArcSinh(op) => {
433                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
434                    b.declare_math_mode(modes, out);
435                    T::asinh(b, ty, input, out);
436                    if matches!(out_ty.elem(), Elem::Relaxed) {
437                        b.decorate(out, Decoration::RelaxedPrecision, []);
438                    }
439                })
440            }
441            Arithmetic::ArcCosh(op) => {
442                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
443                    b.declare_math_mode(modes, out);
444                    T::acosh(b, ty, input, out);
445                    if matches!(out_ty.elem(), Elem::Relaxed) {
446                        b.decorate(out, Decoration::RelaxedPrecision, []);
447                    }
448                })
449            }
450            Arithmetic::ArcTanh(op) => {
451                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
452                    b.declare_math_mode(modes, out);
453                    T::atanh(b, ty, input, out);
454                    if matches!(out_ty.elem(), Elem::Relaxed) {
455                        b.decorate(out, Decoration::RelaxedPrecision, []);
456                    }
457                })
458            }
459            Arithmetic::Degrees(op) => {
460                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
461                    b.declare_math_mode(modes, out);
462                    T::degrees(b, ty, input, out);
463                    if matches!(out_ty.elem(), Elem::Relaxed) {
464                        b.decorate(out, Decoration::RelaxedPrecision, []);
465                    }
466                })
467            }
468            Arithmetic::Radians(op) => {
469                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
470                    b.declare_math_mode(modes, out);
471                    T::radians(b, ty, input, out);
472                    if matches!(out_ty.elem(), Elem::Relaxed) {
473                        b.decorate(out, Decoration::RelaxedPrecision, []);
474                    }
475                })
476            }
477            Arithmetic::ArcTan2(op) => {
478                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
479                    b.declare_math_mode(modes, out);
480                    T::atan2(b, ty, lhs, rhs, out);
481                    if matches!(out_ty.elem(), Elem::Relaxed) {
482                        b.decorate(out, Decoration::RelaxedPrecision, []);
483                    }
484                })
485            }
486            // No powi for Vulkan, just auto-cast to float
487            Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
488                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
489                    let bool = match out_ty {
490                        Item::Scalar(_) => Elem::Bool.id(b),
491                        Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
492                        _ => unreachable!(),
493                    };
494                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
495                    let zero = out_ty.const_u32(b, 0);
496                    let one = out_ty.const_u32(b, 1);
497                    let two = out_ty.const_u32(b, 2);
498                    let modulo = b.f_rem(ty, None, rhs, two).unwrap();
499                    b.declare_math_mode(modes, modulo);
500                    let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
501                    b.declare_math_mode(modes, is_zero);
502                    let abs = b.id();
503                    b.declare_math_mode(modes, abs);
504                    T::f_abs(b, ty, lhs, abs);
505                    let even = b.id();
506                    b.declare_math_mode(modes, even);
507                    T::pow(b, ty, abs, rhs, even);
508                    let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
509                    b.declare_math_mode(modes, cond2_0);
510                    let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
511                    b.declare_math_mode(modes, cond2_1);
512                    let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
513                    let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
514                    b.declare_math_mode(modes, neg_lhs);
515                    let pow2 = b.id();
516                    b.declare_math_mode(modes, pow2);
517                    T::pow(b, ty, neg_lhs, rhs, pow2);
518                    let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
519                    b.declare_math_mode(modes, pow2_neg);
520                    let default = b.id();
521                    b.declare_math_mode(modes, default);
522                    T::pow(b, ty, lhs, rhs, default);
523                    let ids = [
524                        modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
525                        default,
526                    ];
527                    for id in ids {
528                        b.mark_uniformity(id, uniform);
529                        if relaxed {
530                            b.decorate(id, Decoration::RelaxedPrecision, []);
531                        }
532                    }
533                    let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
534                    b.mark_uniformity(sel1, uniform);
535                    b.select(ty, Some(out), is_zero, even, sel1).unwrap();
536                })
537            }
538            Arithmetic::Sqrt(op) => {
539                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
540                    b.declare_math_mode(modes, out);
541                    T::sqrt(b, ty, input, out);
542                    if matches!(out_ty.elem(), Elem::Relaxed) {
543                        b.decorate(out, Decoration::RelaxedPrecision, []);
544                    }
545                })
546            }
547            Arithmetic::InverseSqrt(op) => {
548                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
549                    b.declare_math_mode(modes, out);
550                    T::inverse_sqrt(b, ty, input, out);
551                    if matches!(out_ty.elem(), Elem::Relaxed) {
552                        b.decorate(out, Decoration::RelaxedPrecision, []);
553                    }
554                })
555            }
556            Arithmetic::Round(op) => {
557                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
558                    T::round(b, ty, input, out);
559                    if matches!(out_ty.elem(), Elem::Relaxed) {
560                        b.decorate(out, Decoration::RelaxedPrecision, []);
561                    }
562                })
563            }
564            Arithmetic::Floor(op) => {
565                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
566                    b.declare_math_mode(modes, out);
567                    T::floor(b, ty, input, out);
568                    if matches!(out_ty.elem(), Elem::Relaxed) {
569                        b.decorate(out, Decoration::RelaxedPrecision, []);
570                    }
571                })
572            }
573            Arithmetic::Ceil(op) => {
574                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
575                    b.declare_math_mode(modes, out);
576                    T::ceil(b, ty, input, out);
577                    if matches!(out_ty.elem(), Elem::Relaxed) {
578                        b.decorate(out, Decoration::RelaxedPrecision, []);
579                    }
580                })
581            }
582            Arithmetic::Trunc(op) => {
583                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
584                    b.declare_math_mode(modes, out);
585                    T::trunc(b, ty, input, out);
586                    if matches!(out_ty.elem(), Elem::Relaxed) {
587                        b.decorate(out, Decoration::RelaxedPrecision, []);
588                    }
589                })
590            }
591            Arithmetic::Clamp(op) => {
592                let input = self.compile_variable(op.input);
593                let min = self.compile_variable(op.min_value);
594                let max = self.compile_variable(op.max_value);
595                let out = self.compile_variable(out);
596                let out_ty = out.item();
597
598                let input = self.read_as(&input, &out_ty);
599                let min = self.read_as(&min, &out_ty);
600                let max = self.read_as(&max, &out_ty);
601                let out_id = self.write_id(&out);
602                self.mark_uniformity(out_id, uniform);
603
604                let ty = out_ty.id(self);
605
606                match out_ty.elem() {
607                    Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
608                    Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
609                    Elem::Float(..) => {
610                        self.declare_math_mode(modes, out_id);
611                        T::f_clamp(self, ty, input, min, max, out_id)
612                    }
613                    Elem::Relaxed => {
614                        self.decorate(out_id, Decoration::RelaxedPrecision, []);
615                        self.declare_math_mode(modes, out_id);
616                        T::f_clamp(self, ty, input, min, max, out_id)
617                    }
618                    _ => unreachable!(),
619                }
620                self.write(&out, out_id);
621            }
622
623            Arithmetic::Max(op) => self.compile_binary_op(
624                op,
625                out,
626                uniform,
627                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
628                    Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
629                    Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
630                    Elem::Float(..) => {
631                        b.declare_math_mode(modes, out);
632                        T::f_max(b, ty, lhs, rhs, out)
633                    }
634                    Elem::Relaxed => {
635                        b.decorate(out, Decoration::RelaxedPrecision, []);
636                        b.declare_math_mode(modes, out);
637                        T::f_max(b, ty, lhs, rhs, out)
638                    }
639                    _ => unreachable!(),
640                },
641            ),
642            Arithmetic::Min(op) => self.compile_binary_op(
643                op,
644                out,
645                uniform,
646                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
647                    Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
648                    Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
649                    Elem::Float(..) => {
650                        b.declare_math_mode(modes, out);
651                        T::f_min(b, ty, lhs, rhs, out)
652                    }
653                    Elem::Relaxed => {
654                        b.decorate(out, Decoration::RelaxedPrecision, []);
655                        b.declare_math_mode(modes, out);
656                        T::f_min(b, ty, lhs, rhs, out)
657                    }
658                    _ => unreachable!(),
659                },
660            ),
661        }
662    }
663}