Skip to main content

cubecl_spirv/
arithmetic.rs

1use crate::{
2    SpirvCompiler, SpirvTarget,
3    item::{Elem, Item},
4    variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10    pub fn compile_arithmetic(
11        &mut self,
12        op: Arithmetic,
13        out: Option<core::Variable>,
14        modes: InstructionModes,
15        uniform: bool,
16    ) {
17        let out = out.unwrap();
18        match op {
19            Arithmetic::Add(op) => {
20                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21                    match out_ty.elem() {
22                        Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23                        Elem::Float(..) => {
24                            b.declare_math_mode(modes, out);
25                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
26                        }
27                        Elem::Relaxed => {
28                            b.decorate(out, Decoration::RelaxedPrecision, []);
29                            b.declare_math_mode(modes, out);
30                            b.f_add(ty, Some(out), lhs, rhs).unwrap()
31                        }
32                        _ => unreachable!(),
33                    };
34                });
35            }
36            Arithmetic::SaturatingAdd(_) => {
37                unimplemented!("Should be replaced by polyfill");
38            }
39            Arithmetic::Sub(op) => {
40                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41                    match out_ty.elem() {
42                        Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43                        Elem::Float(..) => {
44                            b.declare_math_mode(modes, out);
45                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46                        }
47                        Elem::Relaxed => {
48                            b.decorate(out, Decoration::RelaxedPrecision, []);
49                            b.declare_math_mode(modes, out);
50                            b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51                        }
52                        _ => unreachable!(),
53                    };
54                });
55            }
56            Arithmetic::SaturatingSub(_) => {
57                unimplemented!("Should be replaced by polyfill");
58            }
59            Arithmetic::Mul(op) => {
60                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61                    match out_ty.elem() {
62                        Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63                        Elem::Float(..) => {
64                            b.declare_math_mode(modes, out);
65                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66                        }
67                        Elem::Relaxed => {
68                            b.decorate(out, Decoration::RelaxedPrecision, []);
69                            b.declare_math_mode(modes, out);
70                            b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71                        }
72                        _ => unreachable!(),
73                    };
74                });
75            }
76            Arithmetic::MulHi(op) => {
77                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78                    let out_st = b.type_struct([ty, ty]);
79                    let extended = match out_ty.elem() {
80                        Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81                        Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82                        _ => unreachable!(),
83                    };
84                    b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85                });
86            }
87            Arithmetic::Div(op) => {
88                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89                    match out_ty.elem() {
90                        Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91                        Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92                        Elem::Float(..) => {
93                            b.declare_math_mode(modes, out);
94                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
95                        }
96                        Elem::Relaxed => {
97                            b.decorate(out, Decoration::RelaxedPrecision, []);
98                            b.declare_math_mode(modes, out);
99                            b.f_div(ty, Some(out), lhs, rhs).unwrap()
100                        }
101                        _ => unreachable!(),
102                    };
103                });
104            }
105            Arithmetic::Remainder(op) => {
106                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107                    match out_ty.elem() {
108                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109                        Elem::Int(_, true) => {
110                            // Convert to float and use `f_mod` (floored division) instead of `s_mod`
111                            // (truncated division) to match remainder semantics across dtypes
112                            // e.g. remainder(-2, 3) = 1, not 2
113                            let f_ty = match out_ty {
114                                Item::Scalar(_elem) => Item::Scalar(Elem::Relaxed),
115                                Item::Vector(_elem, factor) => Item::Vector(Elem::Relaxed, factor),
116                                _ => unreachable!(),
117                            };
118                            let f_ty = f_ty.id(b);
119                            let lhs_f = b.convert_s_to_f(f_ty, None, lhs).unwrap();
120                            let rhs_f = b.convert_s_to_f(f_ty, None, rhs).unwrap();
121                            let rem = b.f_mod(f_ty, None, lhs_f, rhs_f).unwrap();
122                            b.convert_f_to_s(ty, Some(out), rem).unwrap()
123                        }
124                        Elem::Float(..) => {
125                            b.declare_math_mode(modes, out);
126                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
127                        }
128                        Elem::Relaxed => {
129                            b.decorate(out, Decoration::RelaxedPrecision, []);
130                            b.declare_math_mode(modes, out);
131                            b.f_mod(ty, Some(out), lhs, rhs).unwrap()
132                        }
133                        _ => unreachable!(),
134                    };
135                });
136            }
137            Arithmetic::Modulo(op) => {
138                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
139                    match out_ty.elem() {
140                        Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
141                        Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
142                        Elem::Float(..) => {
143                            b.declare_math_mode(modes, out);
144                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
145                        }
146                        Elem::Relaxed => {
147                            b.decorate(out, Decoration::RelaxedPrecision, []);
148                            b.declare_math_mode(modes, out);
149                            b.f_rem(ty, Some(out), lhs, rhs).unwrap()
150                        }
151                        _ => unreachable!(),
152                    };
153                });
154            }
155            Arithmetic::Dot(op) => {
156                if op.lhs.ty.vector_size() == 1 {
157                    self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
158                        match out_ty.elem() {
159                            Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
160                            Elem::Float(..) => {
161                                b.declare_math_mode(modes, out);
162                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
163                            }
164                            Elem::Relaxed => {
165                                b.decorate(out, Decoration::RelaxedPrecision, []);
166                                b.declare_math_mode(modes, out);
167                                b.f_mul(ty, Some(out), lhs, rhs).unwrap()
168                            }
169                            _ => unreachable!(),
170                        };
171                    });
172                } else {
173                    let lhs = self.compile_variable(op.lhs);
174                    let rhs = self.compile_variable(op.rhs);
175                    let out = self.compile_variable(out);
176                    let ty = out.item().id(self);
177
178                    let lhs_id = self.read(&lhs);
179                    let rhs_id = self.read(&rhs);
180                    let out_id = self.write_id(&out);
181                    self.mark_uniformity(out_id, uniform);
182
183                    if matches!(lhs.elem(), Elem::Int(_, _)) {
184                        self.capabilities.insert(Capability::DotProduct);
185                    }
186                    if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
187                        self.capabilities.insert(Capability::BFloat16DotProductKHR);
188                    }
189
190                    match (lhs.elem(), rhs.elem()) {
191                        (Elem::Int(_, false), Elem::Int(_, false)) => {
192                            self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
193                        }
194                        (Elem::Int(_, true), Elem::Int(_, false)) => {
195                            self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
196                        }
197                        (Elem::Int(_, false), Elem::Int(_, true)) => {
198                            self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
199                        }
200                        (Elem::Int(_, true), Elem::Int(_, true)) => {
201                            self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
202                        }
203                        (Elem::Float(..), Elem::Float(..))
204                        | (Elem::Relaxed, Elem::Float(..))
205                        | (Elem::Float(..), Elem::Relaxed) => {
206                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
207                        }
208                        (Elem::Relaxed, Elem::Relaxed) => {
209                            self.decorate(out_id, Decoration::RelaxedPrecision, []);
210                            self.dot(ty, Some(out_id), lhs_id, rhs_id)
211                        }
212                        _ => unreachable!(),
213                    }
214                    .unwrap();
215                    self.write(&out, out_id);
216                }
217            }
218            Arithmetic::Fma(op) => {
219                let a = self.compile_variable(op.a);
220                let b = self.compile_variable(op.b);
221                let c = self.compile_variable(op.c);
222                let out = self.compile_variable(out);
223                let out_ty = out.item();
224                let relaxed = matches!(
225                    (a.item().elem(), b.item().elem(), c.item().elem()),
226                    (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
227                );
228
229                let a_id = self.read_as(&a, &out_ty);
230                let b_id = self.read_as(&b, &out_ty);
231                let c_id = self.read_as(&c, &out_ty);
232                let out_id = self.write_id(&out);
233                self.mark_uniformity(out_id, uniform);
234
235                let ty = out_ty.id(self);
236
237                let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
238                self.mark_uniformity(mul, uniform);
239                self.declare_math_mode(modes, mul);
240                self.f_add(ty, Some(out_id), mul, c_id).unwrap();
241                self.declare_math_mode(modes, out_id);
242                if relaxed {
243                    self.decorate(mul, Decoration::RelaxedPrecision, []);
244                    self.decorate(out_id, Decoration::RelaxedPrecision, []);
245                }
246                self.write(&out, out_id);
247            }
248            Arithmetic::Recip(op) => {
249                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
250                    let one = b
251                        .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
252                        .0;
253                    b.declare_math_mode(modes, out);
254                    b.f_div(ty, Some(out), one, input).unwrap();
255                });
256            }
257            Arithmetic::Neg(op) => {
258                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
259                    match out_ty.elem() {
260                        Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
261                        Elem::Float(..) => {
262                            b.declare_math_mode(modes, out);
263                            b.f_negate(ty, Some(out), input).unwrap()
264                        }
265                        Elem::Relaxed => {
266                            b.decorate(out, Decoration::RelaxedPrecision, []);
267                            b.declare_math_mode(modes, out);
268                            b.f_negate(ty, Some(out), input).unwrap()
269                        }
270                        _ => unreachable!(),
271                    };
272                });
273            }
274            Arithmetic::Erf(_) => {
275                unreachable!("Replaced by transformer")
276            }
277
278            // Extension functions
279            Arithmetic::Normalize(op) => {
280                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
281                    b.declare_math_mode(modes, out);
282                    T::normalize(b, ty, input, out);
283                    if matches!(out_ty.elem(), Elem::Relaxed) {
284                        b.decorate(out, Decoration::RelaxedPrecision, []);
285                    }
286                });
287            }
288            Arithmetic::VectorSum(op) => {
289                let input_ir = op.input;
290                let input = self.compile_variable(input_ir);
291                let out = self.compile_variable(out);
292                let in_item = input.item();
293                let out_ty = out.item();
294                let vec_size = in_item.vectorization();
295                let scalar_ty = out_ty.id(self);
296                let input_id = self.read(&input);
297                let out_id = self.write_id(&out);
298                self.mark_uniformity(out_id, uniform);
299
300                if vec_size <= 1 {
301                    // Scalar: identity
302                    self.copy_object(scalar_ty, Some(out_id), input_id).unwrap();
303                } else if matches!(out_ty.elem(), Elem::Float(..) | Elem::Relaxed) {
304                    // Float vector: use OpDot with ones vector for optimal single instruction
305                    self.declare_math_mode(modes, out_id);
306                    let ones_val = in_item
307                        .elem()
308                        .constant(self, ConstVal::Bit32(1.0f32.to_bits()));
309                    let vec_ty = in_item.id(self);
310                    let ones = self.constant_composite(vec_ty, (0..vec_size).map(|_| ones_val));
311                    self.dot(scalar_ty, Some(out_id), input_id, ones).unwrap();
312                    if matches!(out_ty.elem(), Elem::Relaxed) {
313                        self.decorate(out_id, Decoration::RelaxedPrecision, []);
314                    }
315                } else {
316                    // Integer vector: extract and add
317                    let elem_ty = out_ty.id(self);
318                    let mut acc = self
319                        .composite_extract(elem_ty, None, input_id, vec![0])
320                        .unwrap();
321                    for i in 1..vec_size {
322                        let elem = self
323                            .composite_extract(elem_ty, None, input_id, vec![i])
324                            .unwrap();
325                        acc = self.i_add(elem_ty, None, acc, elem).unwrap();
326                    }
327                    self.copy_object(scalar_ty, Some(out_id), acc).unwrap();
328                }
329                self.write(&out, out_id);
330            }
331            Arithmetic::Magnitude(op) => {
332                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
333                    b.declare_math_mode(modes, out);
334                    T::magnitude(b, ty, input, out);
335                    if matches!(out_ty.elem(), Elem::Relaxed) {
336                        b.decorate(out, Decoration::RelaxedPrecision, []);
337                    }
338                });
339            }
340            Arithmetic::Abs(op) => {
341                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
342                    match out_ty.elem() {
343                        Elem::Int(_, _) => T::s_abs(b, ty, input, out),
344                        Elem::Float(..) => {
345                            b.declare_math_mode(modes, out);
346                            T::f_abs(b, ty, input, out)
347                        }
348                        Elem::Relaxed => {
349                            b.decorate(out, Decoration::RelaxedPrecision, []);
350                            b.declare_math_mode(modes, out);
351                            T::f_abs(b, ty, input, out)
352                        }
353                        _ => unreachable!(),
354                    }
355                });
356            }
357            Arithmetic::Exp(op) => {
358                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
359                    b.declare_math_mode(modes, out);
360                    T::exp(b, ty, input, out);
361                    if matches!(out_ty.elem(), Elem::Relaxed) {
362                        b.decorate(out, Decoration::RelaxedPrecision, []);
363                    }
364                });
365            }
366            Arithmetic::Log(op) => {
367                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
368                    b.declare_math_mode(modes, out);
369                    T::log(b, ty, input, out);
370                    if matches!(out_ty.elem(), Elem::Relaxed) {
371                        b.decorate(out, Decoration::RelaxedPrecision, []);
372                    }
373                })
374            }
375            Arithmetic::Log1p(op) => {
376                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
377                    let one = b
378                        .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
379                        .0;
380                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
381                    let add = match out_ty.elem() {
382                        Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
383                        Elem::Float(..) | Elem::Relaxed => {
384                            b.declare_math_mode(modes, out);
385                            b.f_add(ty, None, input, one).unwrap()
386                        }
387                        _ => unreachable!(),
388                    };
389                    b.mark_uniformity(add, uniform);
390                    if relaxed {
391                        b.decorate(add, Decoration::RelaxedPrecision, []);
392                        b.decorate(out, Decoration::RelaxedPrecision, []);
393                    }
394                    b.declare_math_mode(modes, out);
395                    T::log(b, ty, add, out)
396                });
397            }
398            Arithmetic::Cos(op) => {
399                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
400                    b.declare_math_mode(modes, out);
401                    T::cos(b, ty, input, out);
402                    if matches!(out_ty.elem(), Elem::Relaxed) {
403                        b.decorate(out, Decoration::RelaxedPrecision, []);
404                    }
405                })
406            }
407            Arithmetic::Sin(op) => {
408                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
409                    b.declare_math_mode(modes, out);
410                    T::sin(b, ty, input, out);
411                    if matches!(out_ty.elem(), Elem::Relaxed) {
412                        b.decorate(out, Decoration::RelaxedPrecision, []);
413                    }
414                })
415            }
416            Arithmetic::Tan(op) => {
417                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
418                    b.declare_math_mode(modes, out);
419                    T::tan(b, ty, input, out);
420                    if matches!(out_ty.elem(), Elem::Relaxed) {
421                        b.decorate(out, Decoration::RelaxedPrecision, []);
422                    }
423                })
424            }
425            Arithmetic::Tanh(op) => {
426                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
427                    b.declare_math_mode(modes, out);
428                    T::tanh(b, ty, input, out);
429                    if matches!(out_ty.elem(), Elem::Relaxed) {
430                        b.decorate(out, Decoration::RelaxedPrecision, []);
431                    }
432                })
433            }
434            Arithmetic::Sinh(op) => {
435                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
436                    b.declare_math_mode(modes, out);
437                    T::sinh(b, ty, input, out);
438                    if matches!(out_ty.elem(), Elem::Relaxed) {
439                        b.decorate(out, Decoration::RelaxedPrecision, []);
440                    }
441                })
442            }
443            Arithmetic::Cosh(op) => {
444                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
445                    b.declare_math_mode(modes, out);
446                    T::cosh(b, ty, input, out);
447                    if matches!(out_ty.elem(), Elem::Relaxed) {
448                        b.decorate(out, Decoration::RelaxedPrecision, []);
449                    }
450                })
451            }
452            Arithmetic::ArcCos(op) => {
453                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
454                    b.declare_math_mode(modes, out);
455                    T::acos(b, ty, input, out);
456                    if matches!(out_ty.elem(), Elem::Relaxed) {
457                        b.decorate(out, Decoration::RelaxedPrecision, []);
458                    }
459                })
460            }
461            Arithmetic::ArcSin(op) => {
462                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
463                    b.declare_math_mode(modes, out);
464                    T::asin(b, ty, input, out);
465                    if matches!(out_ty.elem(), Elem::Relaxed) {
466                        b.decorate(out, Decoration::RelaxedPrecision, []);
467                    }
468                })
469            }
470            Arithmetic::ArcTan(op) => {
471                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
472                    b.declare_math_mode(modes, out);
473                    T::atan(b, ty, input, out);
474                    if matches!(out_ty.elem(), Elem::Relaxed) {
475                        b.decorate(out, Decoration::RelaxedPrecision, []);
476                    }
477                })
478            }
479            Arithmetic::ArcSinh(op) => {
480                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
481                    b.declare_math_mode(modes, out);
482                    T::asinh(b, ty, input, out);
483                    if matches!(out_ty.elem(), Elem::Relaxed) {
484                        b.decorate(out, Decoration::RelaxedPrecision, []);
485                    }
486                })
487            }
488            Arithmetic::ArcCosh(op) => {
489                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
490                    b.declare_math_mode(modes, out);
491                    T::acosh(b, ty, input, out);
492                    if matches!(out_ty.elem(), Elem::Relaxed) {
493                        b.decorate(out, Decoration::RelaxedPrecision, []);
494                    }
495                })
496            }
497            Arithmetic::ArcTanh(op) => {
498                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
499                    b.declare_math_mode(modes, out);
500                    T::atanh(b, ty, input, out);
501                    if matches!(out_ty.elem(), Elem::Relaxed) {
502                        b.decorate(out, Decoration::RelaxedPrecision, []);
503                    }
504                })
505            }
506            Arithmetic::Degrees(op) => {
507                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
508                    b.declare_math_mode(modes, out);
509                    T::degrees(b, ty, input, out);
510                    if matches!(out_ty.elem(), Elem::Relaxed) {
511                        b.decorate(out, Decoration::RelaxedPrecision, []);
512                    }
513                })
514            }
515            Arithmetic::Radians(op) => {
516                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
517                    b.declare_math_mode(modes, out);
518                    T::radians(b, ty, input, out);
519                    if matches!(out_ty.elem(), Elem::Relaxed) {
520                        b.decorate(out, Decoration::RelaxedPrecision, []);
521                    }
522                })
523            }
524            Arithmetic::ArcTan2(op) => {
525                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
526                    b.declare_math_mode(modes, out);
527                    T::atan2(b, ty, lhs, rhs, out);
528                    if matches!(out_ty.elem(), Elem::Relaxed) {
529                        b.decorate(out, Decoration::RelaxedPrecision, []);
530                    }
531                })
532            }
533            // No powi for Vulkan, just auto-cast to float
534            Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
535                self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
536                    let bool = match out_ty {
537                        Item::Scalar(_) => Elem::Bool.id(b),
538                        Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
539                        _ => unreachable!(),
540                    };
541                    let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
542                    let zero = out_ty.const_u32(b, 0);
543                    let one = out_ty.const_u32(b, 1);
544                    let two = out_ty.const_u32(b, 2);
545                    let modulo = b.f_rem(ty, None, rhs, two).unwrap();
546                    b.declare_math_mode(modes, modulo);
547                    let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
548                    b.declare_math_mode(modes, is_zero);
549                    let abs = b.id();
550                    b.declare_math_mode(modes, abs);
551                    T::f_abs(b, ty, lhs, abs);
552                    let even = b.id();
553                    b.declare_math_mode(modes, even);
554                    T::pow(b, ty, abs, rhs, even);
555                    let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
556                    b.declare_math_mode(modes, cond2_0);
557                    let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
558                    b.declare_math_mode(modes, cond2_1);
559                    let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
560                    let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
561                    b.declare_math_mode(modes, neg_lhs);
562                    let pow2 = b.id();
563                    b.declare_math_mode(modes, pow2);
564                    T::pow(b, ty, neg_lhs, rhs, pow2);
565                    let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
566                    b.declare_math_mode(modes, pow2_neg);
567                    let default = b.id();
568                    b.declare_math_mode(modes, default);
569                    T::pow(b, ty, lhs, rhs, default);
570                    let ids = [
571                        modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
572                        default,
573                    ];
574                    for id in ids {
575                        b.mark_uniformity(id, uniform);
576                        if relaxed {
577                            b.decorate(id, Decoration::RelaxedPrecision, []);
578                        }
579                    }
580                    let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
581                    b.mark_uniformity(sel1, uniform);
582                    b.select(ty, Some(out), is_zero, even, sel1).unwrap();
583                })
584            }
585            Arithmetic::Hypot(_op) => {
586                unreachable!("Replaced by transformer");
587            }
588            Arithmetic::Rhypot(_op) => {
589                unreachable!("Replaced by transformer");
590            }
591            Arithmetic::Sqrt(op) => {
592                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
593                    b.declare_math_mode(modes, out);
594                    T::sqrt(b, ty, input, out);
595                    if matches!(out_ty.elem(), Elem::Relaxed) {
596                        b.decorate(out, Decoration::RelaxedPrecision, []);
597                    }
598                })
599            }
600            Arithmetic::InverseSqrt(op) => {
601                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
602                    b.declare_math_mode(modes, out);
603                    T::inverse_sqrt(b, ty, input, out);
604                    if matches!(out_ty.elem(), Elem::Relaxed) {
605                        b.decorate(out, Decoration::RelaxedPrecision, []);
606                    }
607                })
608            }
609            Arithmetic::Round(op) => {
610                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
611                    T::round(b, ty, input, out);
612                    if matches!(out_ty.elem(), Elem::Relaxed) {
613                        b.decorate(out, Decoration::RelaxedPrecision, []);
614                    }
615                })
616            }
617            Arithmetic::Floor(op) => {
                // Floor, emitted through the unary-op helper.
618                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
                    // Register the instruction's `modes` against the result id.
619                    b.declare_math_mode(modes, out);
                    // Target-defined emission via the `SpirvTarget` implementation.
620                    T::floor(b, ty, input, out);
                    // Relaxed outputs are additionally decorated with `RelaxedPrecision`.
621                    if matches!(out_ty.elem(), Elem::Relaxed) {
622                        b.decorate(out, Decoration::RelaxedPrecision, []);
623                    }
624                })
625            }
626            Arithmetic::Ceil(op) => {
                // Ceiling, emitted through the unary-op helper.
627                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
                    // Register the instruction's `modes` against the result id.
628                    b.declare_math_mode(modes, out);
                    // Target-defined emission via the `SpirvTarget` implementation.
629                    T::ceil(b, ty, input, out);
                    // Relaxed outputs are additionally decorated with `RelaxedPrecision`.
630                    if matches!(out_ty.elem(), Elem::Relaxed) {
631                        b.decorate(out, Decoration::RelaxedPrecision, []);
632                    }
633                })
634            }
635            Arithmetic::Trunc(op) => {
                // Truncation, emitted through the unary-op helper.
636                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
                    // Register the instruction's `modes` against the result id.
637                    b.declare_math_mode(modes, out);
                    // Target-defined emission via the `SpirvTarget` implementation.
638                    T::trunc(b, ty, input, out);
                    // Relaxed outputs are additionally decorated with `RelaxedPrecision`.
639                    if matches!(out_ty.elem(), Elem::Relaxed) {
640                        b.decorate(out, Decoration::RelaxedPrecision, []);
641                    }
642                })
643            }
644            Arithmetic::Clamp(op) => {
                // Clamp takes three operands, so it is lowered by hand here instead of
                // going through the unary/binary closure helpers used by other arms.
                // Compile the three operands and the destination variable.
645                let input = self.compile_variable(op.input);
646                let min = self.compile_variable(op.min_value);
647                let max = self.compile_variable(op.max_value);
648                let out = self.compile_variable(out);
                // The output's item type drives both operand reads and op selection.
649                let out_ty = out.item();
650
                // Read every operand as the output item type (presumably converting
                // where the operand's own type differs; verify in `read_as`).
651                let input = self.read_as(&input, &out_ty);
652                let min = self.read_as(&min, &out_ty);
653                let max = self.read_as(&max, &out_ty);
                // Id that will hold the result, tagged with the caller's `uniform` flag.
654                let out_id = self.write_id(&out);
655                self.mark_uniformity(out_id, uniform);
656
657                let ty = out_ty.id(self);
658
                // Pick the target's clamp by element kind. The second `Elem::Int` field
                // selects `u_clamp` (false) vs `s_clamp` (true), i.e. it evidently acts
                // as a signedness flag.
659                match out_ty.elem() {
660                    Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
661                    Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
662                    Elem::Float(..) => {
                        // Float results also get the instruction's `modes` registered.
663                        self.declare_math_mode(modes, out_id);
664                        T::f_clamp(self, ty, input, min, max, out_id)
665                    }
666                    Elem::Relaxed => {
                        // Same as the float path, plus the `RelaxedPrecision` decoration.
667                        self.decorate(out_id, Decoration::RelaxedPrecision, []);
668                        self.declare_math_mode(modes, out_id);
669                        T::f_clamp(self, ty, input, min, max, out_id)
670                    }
                    // Any other element kind is treated as impossible here and panics.
671                    _ => unreachable!(),
672                }
                // Hand the computed id back to the output variable.
673                self.write(&out, out_id);
674            }
675
676            Arithmetic::Max(op) => self.compile_binary_op(
677                op,
678                out,
679                uniform,
                // Maximum of two operands. The closure picks the target's max by the
                // output element kind: `u_max` / `s_max` for the two `Elem::Int` flag
                // values, `f_max` for floats.
680                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
681                    Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
682                    Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
683                    Elem::Float(..) => {
                        // Float results also get the instruction's `modes` registered.
684                        b.declare_math_mode(modes, out);
685                        T::f_max(b, ty, lhs, rhs, out)
686                    }
687                    Elem::Relaxed => {
                        // Same as the float path, plus the `RelaxedPrecision` decoration.
688                        b.decorate(out, Decoration::RelaxedPrecision, []);
689                        b.declare_math_mode(modes, out);
690                        T::f_max(b, ty, lhs, rhs, out)
691                    }
                    // Any other element kind is treated as impossible here and panics.
692                    _ => unreachable!(),
693                },
694            ),
695            Arithmetic::Min(op) => self.compile_binary_op(
696                op,
697                out,
698                uniform,
                // Minimum of two operands. The closure picks the target's min by the
                // output element kind: `u_min` / `s_min` for the two `Elem::Int` flag
                // values, `f_min` for floats.
699                |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
700                    Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
701                    Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
702                    Elem::Float(..) => {
                        // Float results also get the instruction's `modes` registered.
703                        b.declare_math_mode(modes, out);
704                        T::f_min(b, ty, lhs, rhs, out)
705                    }
706                    Elem::Relaxed => {
                        // Same as the float path, plus the `RelaxedPrecision` decoration.
707                        b.decorate(out, Decoration::RelaxedPrecision, []);
708                        b.declare_math_mode(modes, out);
709                        T::f_min(b, ty, lhs, rhs, out)
710                    }
                    // Any other element kind is treated as impossible here and panics.
711                    _ => unreachable!(),
712                },
713            ),
714        }
715    }
716}