1use crate::{
2 SpirvCompiler, SpirvTarget,
3 item::{Elem, Item},
4 variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10 pub fn compile_arithmetic(
11 &mut self,
12 op: Arithmetic,
13 out: Option<core::Variable>,
14 modes: InstructionModes,
15 uniform: bool,
16 ) {
17 let out = out.unwrap();
18 match op {
19 Arithmetic::Add(op) => {
20 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21 match out_ty.elem() {
22 Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23 Elem::Float(..) => {
24 b.declare_math_mode(modes, out);
25 b.f_add(ty, Some(out), lhs, rhs).unwrap()
26 }
27 Elem::Relaxed => {
28 b.decorate(out, Decoration::RelaxedPrecision, []);
29 b.declare_math_mode(modes, out);
30 b.f_add(ty, Some(out), lhs, rhs).unwrap()
31 }
32 _ => unreachable!(),
33 };
34 });
35 }
36 Arithmetic::SaturatingAdd(_) => {
37 unimplemented!("Should be replaced by polyfill");
38 }
39 Arithmetic::Sub(op) => {
40 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41 match out_ty.elem() {
42 Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43 Elem::Float(..) => {
44 b.declare_math_mode(modes, out);
45 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46 }
47 Elem::Relaxed => {
48 b.decorate(out, Decoration::RelaxedPrecision, []);
49 b.declare_math_mode(modes, out);
50 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51 }
52 _ => unreachable!(),
53 };
54 });
55 }
56 Arithmetic::SaturatingSub(_) => {
57 unimplemented!("Should be replaced by polyfill");
58 }
59 Arithmetic::Mul(op) => {
60 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61 match out_ty.elem() {
62 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63 Elem::Float(..) => {
64 b.declare_math_mode(modes, out);
65 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66 }
67 Elem::Relaxed => {
68 b.decorate(out, Decoration::RelaxedPrecision, []);
69 b.declare_math_mode(modes, out);
70 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71 }
72 _ => unreachable!(),
73 };
74 });
75 }
76 Arithmetic::MulHi(op) => {
77 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78 let out_st = b.type_struct([ty, ty]);
79 let extended = match out_ty.elem() {
80 Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81 Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82 _ => unreachable!(),
83 };
84 b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85 });
86 }
87 Arithmetic::Div(op) => {
88 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89 match out_ty.elem() {
90 Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91 Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92 Elem::Float(..) => {
93 b.declare_math_mode(modes, out);
94 b.f_div(ty, Some(out), lhs, rhs).unwrap()
95 }
96 Elem::Relaxed => {
97 b.decorate(out, Decoration::RelaxedPrecision, []);
98 b.declare_math_mode(modes, out);
99 b.f_div(ty, Some(out), lhs, rhs).unwrap()
100 }
101 _ => unreachable!(),
102 };
103 });
104 }
105 Arithmetic::Remainder(op) => {
106 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107 match out_ty.elem() {
108 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109 Elem::Int(_, true) => {
110 let f_ty = match out_ty {
114 Item::Scalar(_elem) => Item::Scalar(Elem::Relaxed),
115 Item::Vector(_elem, factor) => Item::Vector(Elem::Relaxed, factor),
116 _ => unreachable!(),
117 };
118 let f_ty = f_ty.id(b);
119 let lhs_f = b.convert_s_to_f(f_ty, None, lhs).unwrap();
120 let rhs_f = b.convert_s_to_f(f_ty, None, rhs).unwrap();
121 let rem = b.f_mod(f_ty, None, lhs_f, rhs_f).unwrap();
122 b.convert_f_to_s(ty, Some(out), rem).unwrap()
123 }
124 Elem::Float(..) => {
125 b.declare_math_mode(modes, out);
126 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
127 }
128 Elem::Relaxed => {
129 b.decorate(out, Decoration::RelaxedPrecision, []);
130 b.declare_math_mode(modes, out);
131 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
132 }
133 _ => unreachable!(),
134 };
135 });
136 }
137 Arithmetic::Modulo(op) => {
138 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
139 match out_ty.elem() {
140 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
141 Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
142 Elem::Float(..) => {
143 b.declare_math_mode(modes, out);
144 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
145 }
146 Elem::Relaxed => {
147 b.decorate(out, Decoration::RelaxedPrecision, []);
148 b.declare_math_mode(modes, out);
149 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
150 }
151 _ => unreachable!(),
152 };
153 });
154 }
155 Arithmetic::Dot(op) => {
156 if op.lhs.ty.line_size() == 1 {
157 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
158 match out_ty.elem() {
159 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
160 Elem::Float(..) => {
161 b.declare_math_mode(modes, out);
162 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
163 }
164 Elem::Relaxed => {
165 b.decorate(out, Decoration::RelaxedPrecision, []);
166 b.declare_math_mode(modes, out);
167 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
168 }
169 _ => unreachable!(),
170 };
171 });
172 } else {
173 let lhs = self.compile_variable(op.lhs);
174 let rhs = self.compile_variable(op.rhs);
175 let out = self.compile_variable(out);
176 let ty = out.item().id(self);
177
178 let lhs_id = self.read(&lhs);
179 let rhs_id = self.read(&rhs);
180 let out_id = self.write_id(&out);
181 self.mark_uniformity(out_id, uniform);
182
183 if matches!(lhs.elem(), Elem::Int(_, _)) {
184 self.capabilities.insert(Capability::DotProduct);
185 }
186 if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
187 self.capabilities.insert(Capability::BFloat16DotProductKHR);
188 }
189
190 match (lhs.elem(), rhs.elem()) {
191 (Elem::Int(_, false), Elem::Int(_, false)) => {
192 self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
193 }
194 (Elem::Int(_, true), Elem::Int(_, false)) => {
195 self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
196 }
197 (Elem::Int(_, false), Elem::Int(_, true)) => {
198 self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
199 }
200 (Elem::Int(_, true), Elem::Int(_, true)) => {
201 self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
202 }
203 (Elem::Float(..), Elem::Float(..))
204 | (Elem::Relaxed, Elem::Float(..))
205 | (Elem::Float(..), Elem::Relaxed) => {
206 self.dot(ty, Some(out_id), lhs_id, rhs_id)
207 }
208 (Elem::Relaxed, Elem::Relaxed) => {
209 self.decorate(out_id, Decoration::RelaxedPrecision, []);
210 self.dot(ty, Some(out_id), lhs_id, rhs_id)
211 }
212 _ => unreachable!(),
213 }
214 .unwrap();
215 self.write(&out, out_id);
216 }
217 }
218 Arithmetic::Fma(op) => {
219 let a = self.compile_variable(op.a);
220 let b = self.compile_variable(op.b);
221 let c = self.compile_variable(op.c);
222 let out = self.compile_variable(out);
223 let out_ty = out.item();
224 let relaxed = matches!(
225 (a.item().elem(), b.item().elem(), c.item().elem()),
226 (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
227 );
228
229 let a_id = self.read_as(&a, &out_ty);
230 let b_id = self.read_as(&b, &out_ty);
231 let c_id = self.read_as(&c, &out_ty);
232 let out_id = self.write_id(&out);
233 self.mark_uniformity(out_id, uniform);
234
235 let ty = out_ty.id(self);
236
237 let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
238 self.mark_uniformity(mul, uniform);
239 self.declare_math_mode(modes, mul);
240 self.f_add(ty, Some(out_id), mul, c_id).unwrap();
241 self.declare_math_mode(modes, out_id);
242 if relaxed {
243 self.decorate(mul, Decoration::RelaxedPrecision, []);
244 self.decorate(out_id, Decoration::RelaxedPrecision, []);
245 }
246 self.write(&out, out_id);
247 }
248 Arithmetic::Recip(op) => {
249 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
250 let one = b
251 .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
252 .0;
253 b.declare_math_mode(modes, out);
254 b.f_div(ty, Some(out), one, input).unwrap();
255 });
256 }
257 Arithmetic::Neg(op) => {
258 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
259 match out_ty.elem() {
260 Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
261 Elem::Float(..) => {
262 b.declare_math_mode(modes, out);
263 b.f_negate(ty, Some(out), input).unwrap()
264 }
265 Elem::Relaxed => {
266 b.decorate(out, Decoration::RelaxedPrecision, []);
267 b.declare_math_mode(modes, out);
268 b.f_negate(ty, Some(out), input).unwrap()
269 }
270 _ => unreachable!(),
271 };
272 });
273 }
274 Arithmetic::Erf(_) => {
275 unreachable!("Replaced by transformer")
276 }
277
278 Arithmetic::Normalize(op) => {
280 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
281 b.declare_math_mode(modes, out);
282 T::normalize(b, ty, input, out);
283 if matches!(out_ty.elem(), Elem::Relaxed) {
284 b.decorate(out, Decoration::RelaxedPrecision, []);
285 }
286 });
287 }
288 Arithmetic::Magnitude(op) => {
289 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
290 b.declare_math_mode(modes, out);
291 T::magnitude(b, ty, input, out);
292 if matches!(out_ty.elem(), Elem::Relaxed) {
293 b.decorate(out, Decoration::RelaxedPrecision, []);
294 }
295 });
296 }
297 Arithmetic::Abs(op) => {
298 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
299 match out_ty.elem() {
300 Elem::Int(_, _) => T::s_abs(b, ty, input, out),
301 Elem::Float(..) => {
302 b.declare_math_mode(modes, out);
303 T::f_abs(b, ty, input, out)
304 }
305 Elem::Relaxed => {
306 b.decorate(out, Decoration::RelaxedPrecision, []);
307 b.declare_math_mode(modes, out);
308 T::f_abs(b, ty, input, out)
309 }
310 _ => unreachable!(),
311 }
312 });
313 }
314 Arithmetic::Exp(op) => {
315 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
316 b.declare_math_mode(modes, out);
317 T::exp(b, ty, input, out);
318 if matches!(out_ty.elem(), Elem::Relaxed) {
319 b.decorate(out, Decoration::RelaxedPrecision, []);
320 }
321 });
322 }
323 Arithmetic::Log(op) => {
324 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
325 b.declare_math_mode(modes, out);
326 T::log(b, ty, input, out);
327 if matches!(out_ty.elem(), Elem::Relaxed) {
328 b.decorate(out, Decoration::RelaxedPrecision, []);
329 }
330 })
331 }
332 Arithmetic::Log1p(op) => {
333 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
334 let one = b
335 .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
336 .0;
337 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
338 let add = match out_ty.elem() {
339 Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
340 Elem::Float(..) | Elem::Relaxed => {
341 b.declare_math_mode(modes, out);
342 b.f_add(ty, None, input, one).unwrap()
343 }
344 _ => unreachable!(),
345 };
346 b.mark_uniformity(add, uniform);
347 if relaxed {
348 b.decorate(add, Decoration::RelaxedPrecision, []);
349 b.decorate(out, Decoration::RelaxedPrecision, []);
350 }
351 b.declare_math_mode(modes, out);
352 T::log(b, ty, add, out)
353 });
354 }
355 Arithmetic::Cos(op) => {
356 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
357 b.declare_math_mode(modes, out);
358 T::cos(b, ty, input, out);
359 if matches!(out_ty.elem(), Elem::Relaxed) {
360 b.decorate(out, Decoration::RelaxedPrecision, []);
361 }
362 })
363 }
364 Arithmetic::Sin(op) => {
365 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
366 b.declare_math_mode(modes, out);
367 T::sin(b, ty, input, out);
368 if matches!(out_ty.elem(), Elem::Relaxed) {
369 b.decorate(out, Decoration::RelaxedPrecision, []);
370 }
371 })
372 }
373 Arithmetic::Tan(op) => {
374 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
375 b.declare_math_mode(modes, out);
376 T::tan(b, ty, input, out);
377 if matches!(out_ty.elem(), Elem::Relaxed) {
378 b.decorate(out, Decoration::RelaxedPrecision, []);
379 }
380 })
381 }
382 Arithmetic::Tanh(op) => {
383 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
384 b.declare_math_mode(modes, out);
385 T::tanh(b, ty, input, out);
386 if matches!(out_ty.elem(), Elem::Relaxed) {
387 b.decorate(out, Decoration::RelaxedPrecision, []);
388 }
389 })
390 }
391 Arithmetic::Sinh(op) => {
392 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
393 b.declare_math_mode(modes, out);
394 T::sinh(b, ty, input, out);
395 if matches!(out_ty.elem(), Elem::Relaxed) {
396 b.decorate(out, Decoration::RelaxedPrecision, []);
397 }
398 })
399 }
400 Arithmetic::Cosh(op) => {
401 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
402 b.declare_math_mode(modes, out);
403 T::cosh(b, ty, input, out);
404 if matches!(out_ty.elem(), Elem::Relaxed) {
405 b.decorate(out, Decoration::RelaxedPrecision, []);
406 }
407 })
408 }
409 Arithmetic::ArcCos(op) => {
410 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
411 b.declare_math_mode(modes, out);
412 T::acos(b, ty, input, out);
413 if matches!(out_ty.elem(), Elem::Relaxed) {
414 b.decorate(out, Decoration::RelaxedPrecision, []);
415 }
416 })
417 }
418 Arithmetic::ArcSin(op) => {
419 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
420 b.declare_math_mode(modes, out);
421 T::asin(b, ty, input, out);
422 if matches!(out_ty.elem(), Elem::Relaxed) {
423 b.decorate(out, Decoration::RelaxedPrecision, []);
424 }
425 })
426 }
427 Arithmetic::ArcTan(op) => {
428 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
429 b.declare_math_mode(modes, out);
430 T::atan(b, ty, input, out);
431 if matches!(out_ty.elem(), Elem::Relaxed) {
432 b.decorate(out, Decoration::RelaxedPrecision, []);
433 }
434 })
435 }
436 Arithmetic::ArcSinh(op) => {
437 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
438 b.declare_math_mode(modes, out);
439 T::asinh(b, ty, input, out);
440 if matches!(out_ty.elem(), Elem::Relaxed) {
441 b.decorate(out, Decoration::RelaxedPrecision, []);
442 }
443 })
444 }
445 Arithmetic::ArcCosh(op) => {
446 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
447 b.declare_math_mode(modes, out);
448 T::acosh(b, ty, input, out);
449 if matches!(out_ty.elem(), Elem::Relaxed) {
450 b.decorate(out, Decoration::RelaxedPrecision, []);
451 }
452 })
453 }
454 Arithmetic::ArcTanh(op) => {
455 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
456 b.declare_math_mode(modes, out);
457 T::atanh(b, ty, input, out);
458 if matches!(out_ty.elem(), Elem::Relaxed) {
459 b.decorate(out, Decoration::RelaxedPrecision, []);
460 }
461 })
462 }
463 Arithmetic::Degrees(op) => {
464 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
465 b.declare_math_mode(modes, out);
466 T::degrees(b, ty, input, out);
467 if matches!(out_ty.elem(), Elem::Relaxed) {
468 b.decorate(out, Decoration::RelaxedPrecision, []);
469 }
470 })
471 }
472 Arithmetic::Radians(op) => {
473 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
474 b.declare_math_mode(modes, out);
475 T::radians(b, ty, input, out);
476 if matches!(out_ty.elem(), Elem::Relaxed) {
477 b.decorate(out, Decoration::RelaxedPrecision, []);
478 }
479 })
480 }
481 Arithmetic::ArcTan2(op) => {
482 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
483 b.declare_math_mode(modes, out);
484 T::atan2(b, ty, lhs, rhs, out);
485 if matches!(out_ty.elem(), Elem::Relaxed) {
486 b.decorate(out, Decoration::RelaxedPrecision, []);
487 }
488 })
489 }
490 Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
492 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
493 let bool = match out_ty {
494 Item::Scalar(_) => Elem::Bool.id(b),
495 Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
496 _ => unreachable!(),
497 };
498 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
499 let zero = out_ty.const_u32(b, 0);
500 let one = out_ty.const_u32(b, 1);
501 let two = out_ty.const_u32(b, 2);
502 let modulo = b.f_rem(ty, None, rhs, two).unwrap();
503 b.declare_math_mode(modes, modulo);
504 let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
505 b.declare_math_mode(modes, is_zero);
506 let abs = b.id();
507 b.declare_math_mode(modes, abs);
508 T::f_abs(b, ty, lhs, abs);
509 let even = b.id();
510 b.declare_math_mode(modes, even);
511 T::pow(b, ty, abs, rhs, even);
512 let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
513 b.declare_math_mode(modes, cond2_0);
514 let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
515 b.declare_math_mode(modes, cond2_1);
516 let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
517 let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
518 b.declare_math_mode(modes, neg_lhs);
519 let pow2 = b.id();
520 b.declare_math_mode(modes, pow2);
521 T::pow(b, ty, neg_lhs, rhs, pow2);
522 let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
523 b.declare_math_mode(modes, pow2_neg);
524 let default = b.id();
525 b.declare_math_mode(modes, default);
526 T::pow(b, ty, lhs, rhs, default);
527 let ids = [
528 modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
529 default,
530 ];
531 for id in ids {
532 b.mark_uniformity(id, uniform);
533 if relaxed {
534 b.decorate(id, Decoration::RelaxedPrecision, []);
535 }
536 }
537 let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
538 b.mark_uniformity(sel1, uniform);
539 b.select(ty, Some(out), is_zero, even, sel1).unwrap();
540 })
541 }
542 Arithmetic::Hypot(_op) => {
543 unreachable!("Replaced by transformer");
544 }
545 Arithmetic::Rhypot(_op) => {
546 unreachable!("Replaced by transformer");
547 }
548 Arithmetic::Sqrt(op) => {
549 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
550 b.declare_math_mode(modes, out);
551 T::sqrt(b, ty, input, out);
552 if matches!(out_ty.elem(), Elem::Relaxed) {
553 b.decorate(out, Decoration::RelaxedPrecision, []);
554 }
555 })
556 }
557 Arithmetic::InverseSqrt(op) => {
558 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
559 b.declare_math_mode(modes, out);
560 T::inverse_sqrt(b, ty, input, out);
561 if matches!(out_ty.elem(), Elem::Relaxed) {
562 b.decorate(out, Decoration::RelaxedPrecision, []);
563 }
564 })
565 }
566 Arithmetic::Round(op) => {
567 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
568 T::round(b, ty, input, out);
569 if matches!(out_ty.elem(), Elem::Relaxed) {
570 b.decorate(out, Decoration::RelaxedPrecision, []);
571 }
572 })
573 }
574 Arithmetic::Floor(op) => {
575 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
576 b.declare_math_mode(modes, out);
577 T::floor(b, ty, input, out);
578 if matches!(out_ty.elem(), Elem::Relaxed) {
579 b.decorate(out, Decoration::RelaxedPrecision, []);
580 }
581 })
582 }
583 Arithmetic::Ceil(op) => {
584 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
585 b.declare_math_mode(modes, out);
586 T::ceil(b, ty, input, out);
587 if matches!(out_ty.elem(), Elem::Relaxed) {
588 b.decorate(out, Decoration::RelaxedPrecision, []);
589 }
590 })
591 }
592 Arithmetic::Trunc(op) => {
593 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
594 b.declare_math_mode(modes, out);
595 T::trunc(b, ty, input, out);
596 if matches!(out_ty.elem(), Elem::Relaxed) {
597 b.decorate(out, Decoration::RelaxedPrecision, []);
598 }
599 })
600 }
601 Arithmetic::Clamp(op) => {
602 let input = self.compile_variable(op.input);
603 let min = self.compile_variable(op.min_value);
604 let max = self.compile_variable(op.max_value);
605 let out = self.compile_variable(out);
606 let out_ty = out.item();
607
608 let input = self.read_as(&input, &out_ty);
609 let min = self.read_as(&min, &out_ty);
610 let max = self.read_as(&max, &out_ty);
611 let out_id = self.write_id(&out);
612 self.mark_uniformity(out_id, uniform);
613
614 let ty = out_ty.id(self);
615
616 match out_ty.elem() {
617 Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
618 Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
619 Elem::Float(..) => {
620 self.declare_math_mode(modes, out_id);
621 T::f_clamp(self, ty, input, min, max, out_id)
622 }
623 Elem::Relaxed => {
624 self.decorate(out_id, Decoration::RelaxedPrecision, []);
625 self.declare_math_mode(modes, out_id);
626 T::f_clamp(self, ty, input, min, max, out_id)
627 }
628 _ => unreachable!(),
629 }
630 self.write(&out, out_id);
631 }
632
633 Arithmetic::Max(op) => self.compile_binary_op(
634 op,
635 out,
636 uniform,
637 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
638 Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
639 Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
640 Elem::Float(..) => {
641 b.declare_math_mode(modes, out);
642 T::f_max(b, ty, lhs, rhs, out)
643 }
644 Elem::Relaxed => {
645 b.decorate(out, Decoration::RelaxedPrecision, []);
646 b.declare_math_mode(modes, out);
647 T::f_max(b, ty, lhs, rhs, out)
648 }
649 _ => unreachable!(),
650 },
651 ),
652 Arithmetic::Min(op) => self.compile_binary_op(
653 op,
654 out,
655 uniform,
656 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
657 Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
658 Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
659 Elem::Float(..) => {
660 b.declare_math_mode(modes, out);
661 T::f_min(b, ty, lhs, rhs, out)
662 }
663 Elem::Relaxed => {
664 b.decorate(out, Decoration::RelaxedPrecision, []);
665 b.declare_math_mode(modes, out);
666 T::f_min(b, ty, lhs, rhs, out)
667 }
668 _ => unreachable!(),
669 },
670 ),
671 }
672 }
673}