1use crate::{
2 SpirvCompiler, SpirvTarget,
3 item::{Elem, Item},
4 variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10 pub fn compile_arithmetic(
11 &mut self,
12 op: Arithmetic,
13 out: Option<core::Variable>,
14 modes: InstructionModes,
15 uniform: bool,
16 ) {
17 let out = out.unwrap();
18 match op {
19 Arithmetic::Add(op) => {
20 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21 match out_ty.elem() {
22 Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23 Elem::Float(..) => {
24 b.declare_math_mode(modes, out);
25 b.f_add(ty, Some(out), lhs, rhs).unwrap()
26 }
27 Elem::Relaxed => {
28 b.decorate(out, Decoration::RelaxedPrecision, []);
29 b.declare_math_mode(modes, out);
30 b.f_add(ty, Some(out), lhs, rhs).unwrap()
31 }
32 _ => unreachable!(),
33 };
34 });
35 }
36 Arithmetic::SaturatingAdd(_) => {
37 unimplemented!("Should be replaced by polyfill");
38 }
39 Arithmetic::Sub(op) => {
40 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41 match out_ty.elem() {
42 Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43 Elem::Float(..) => {
44 b.declare_math_mode(modes, out);
45 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46 }
47 Elem::Relaxed => {
48 b.decorate(out, Decoration::RelaxedPrecision, []);
49 b.declare_math_mode(modes, out);
50 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51 }
52 _ => unreachable!(),
53 };
54 });
55 }
56 Arithmetic::SaturatingSub(_) => {
57 unimplemented!("Should be replaced by polyfill");
58 }
59 Arithmetic::Mul(op) => {
60 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61 match out_ty.elem() {
62 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63 Elem::Float(..) => {
64 b.declare_math_mode(modes, out);
65 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66 }
67 Elem::Relaxed => {
68 b.decorate(out, Decoration::RelaxedPrecision, []);
69 b.declare_math_mode(modes, out);
70 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71 }
72 _ => unreachable!(),
73 };
74 });
75 }
76 Arithmetic::MulHi(op) => {
77 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78 let out_st = b.type_struct([ty, ty]);
79 let extended = match out_ty.elem() {
80 Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81 Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82 _ => unreachable!(),
83 };
84 b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85 });
86 }
87 Arithmetic::Div(op) => {
88 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89 match out_ty.elem() {
90 Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91 Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92 Elem::Float(..) => {
93 b.declare_math_mode(modes, out);
94 b.f_div(ty, Some(out), lhs, rhs).unwrap()
95 }
96 Elem::Relaxed => {
97 b.decorate(out, Decoration::RelaxedPrecision, []);
98 b.declare_math_mode(modes, out);
99 b.f_div(ty, Some(out), lhs, rhs).unwrap()
100 }
101 _ => unreachable!(),
102 };
103 });
104 }
105 Arithmetic::Remainder(op) => {
106 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107 match out_ty.elem() {
108 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109 Elem::Int(_, true) => {
110 let f_ty = match out_ty {
114 Item::Scalar(_elem) => Item::Scalar(Elem::Relaxed),
115 Item::Vector(_elem, factor) => Item::Vector(Elem::Relaxed, factor),
116 _ => unreachable!(),
117 };
118 let f_ty = f_ty.id(b);
119 let lhs_f = b.convert_s_to_f(f_ty, None, lhs).unwrap();
120 let rhs_f = b.convert_s_to_f(f_ty, None, rhs).unwrap();
121 let rem = b.f_mod(f_ty, None, lhs_f, rhs_f).unwrap();
122 b.convert_f_to_s(ty, Some(out), rem).unwrap()
123 }
124 Elem::Float(..) => {
125 b.declare_math_mode(modes, out);
126 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
127 }
128 Elem::Relaxed => {
129 b.decorate(out, Decoration::RelaxedPrecision, []);
130 b.declare_math_mode(modes, out);
131 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
132 }
133 _ => unreachable!(),
134 };
135 });
136 }
137 Arithmetic::Modulo(op) => {
138 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
139 match out_ty.elem() {
140 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
141 Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
142 Elem::Float(..) => {
143 b.declare_math_mode(modes, out);
144 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
145 }
146 Elem::Relaxed => {
147 b.decorate(out, Decoration::RelaxedPrecision, []);
148 b.declare_math_mode(modes, out);
149 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
150 }
151 _ => unreachable!(),
152 };
153 });
154 }
155 Arithmetic::Dot(op) => {
156 if op.lhs.ty.line_size() == 1 {
157 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
158 match out_ty.elem() {
159 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
160 Elem::Float(..) => {
161 b.declare_math_mode(modes, out);
162 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
163 }
164 Elem::Relaxed => {
165 b.decorate(out, Decoration::RelaxedPrecision, []);
166 b.declare_math_mode(modes, out);
167 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
168 }
169 _ => unreachable!(),
170 };
171 });
172 } else {
173 let lhs = self.compile_variable(op.lhs);
174 let rhs = self.compile_variable(op.rhs);
175 let out = self.compile_variable(out);
176 let ty = out.item().id(self);
177
178 let lhs_id = self.read(&lhs);
179 let rhs_id = self.read(&rhs);
180 let out_id = self.write_id(&out);
181 self.mark_uniformity(out_id, uniform);
182
183 if matches!(lhs.elem(), Elem::Int(_, _)) {
184 self.capabilities.insert(Capability::DotProduct);
185 }
186 if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
187 self.capabilities.insert(Capability::BFloat16DotProductKHR);
188 }
189
190 match (lhs.elem(), rhs.elem()) {
191 (Elem::Int(_, false), Elem::Int(_, false)) => {
192 self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
193 }
194 (Elem::Int(_, true), Elem::Int(_, false)) => {
195 self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
196 }
197 (Elem::Int(_, false), Elem::Int(_, true)) => {
198 self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
199 }
200 (Elem::Int(_, true), Elem::Int(_, true)) => {
201 self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
202 }
203 (Elem::Float(..), Elem::Float(..))
204 | (Elem::Relaxed, Elem::Float(..))
205 | (Elem::Float(..), Elem::Relaxed) => {
206 self.dot(ty, Some(out_id), lhs_id, rhs_id)
207 }
208 (Elem::Relaxed, Elem::Relaxed) => {
209 self.decorate(out_id, Decoration::RelaxedPrecision, []);
210 self.dot(ty, Some(out_id), lhs_id, rhs_id)
211 }
212 _ => unreachable!(),
213 }
214 .unwrap();
215 self.write(&out, out_id);
216 }
217 }
218 Arithmetic::Fma(op) => {
219 let a = self.compile_variable(op.a);
220 let b = self.compile_variable(op.b);
221 let c = self.compile_variable(op.c);
222 let out = self.compile_variable(out);
223 let out_ty = out.item();
224 let relaxed = matches!(
225 (a.item().elem(), b.item().elem(), c.item().elem()),
226 (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
227 );
228
229 let a_id = self.read_as(&a, &out_ty);
230 let b_id = self.read_as(&b, &out_ty);
231 let c_id = self.read_as(&c, &out_ty);
232 let out_id = self.write_id(&out);
233 self.mark_uniformity(out_id, uniform);
234
235 let ty = out_ty.id(self);
236
237 let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
238 self.mark_uniformity(mul, uniform);
239 self.declare_math_mode(modes, mul);
240 self.f_add(ty, Some(out_id), mul, c_id).unwrap();
241 self.declare_math_mode(modes, out_id);
242 if relaxed {
243 self.decorate(mul, Decoration::RelaxedPrecision, []);
244 self.decorate(out_id, Decoration::RelaxedPrecision, []);
245 }
246 self.write(&out, out_id);
247 }
248 Arithmetic::Recip(op) => {
249 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
250 let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
251 b.declare_math_mode(modes, out);
252 b.f_div(ty, Some(out), one, input).unwrap();
253 });
254 }
255 Arithmetic::Neg(op) => {
256 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
257 match out_ty.elem() {
258 Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
259 Elem::Float(..) => {
260 b.declare_math_mode(modes, out);
261 b.f_negate(ty, Some(out), input).unwrap()
262 }
263 Elem::Relaxed => {
264 b.decorate(out, Decoration::RelaxedPrecision, []);
265 b.declare_math_mode(modes, out);
266 b.f_negate(ty, Some(out), input).unwrap()
267 }
268 _ => unreachable!(),
269 };
270 });
271 }
272 Arithmetic::Erf(_) => {
273 unreachable!("Replaced by transformer")
274 }
275
276 Arithmetic::Normalize(op) => {
278 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
279 b.declare_math_mode(modes, out);
280 T::normalize(b, ty, input, out);
281 if matches!(out_ty.elem(), Elem::Relaxed) {
282 b.decorate(out, Decoration::RelaxedPrecision, []);
283 }
284 });
285 }
286 Arithmetic::Magnitude(op) => {
287 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
288 b.declare_math_mode(modes, out);
289 T::magnitude(b, ty, input, out);
290 if matches!(out_ty.elem(), Elem::Relaxed) {
291 b.decorate(out, Decoration::RelaxedPrecision, []);
292 }
293 });
294 }
295 Arithmetic::Abs(op) => {
296 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
297 match out_ty.elem() {
298 Elem::Int(_, _) => T::s_abs(b, ty, input, out),
299 Elem::Float(..) => {
300 b.declare_math_mode(modes, out);
301 T::f_abs(b, ty, input, out)
302 }
303 Elem::Relaxed => {
304 b.decorate(out, Decoration::RelaxedPrecision, []);
305 b.declare_math_mode(modes, out);
306 T::f_abs(b, ty, input, out)
307 }
308 _ => unreachable!(),
309 }
310 });
311 }
312 Arithmetic::Exp(op) => {
313 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
314 b.declare_math_mode(modes, out);
315 T::exp(b, ty, input, out);
316 if matches!(out_ty.elem(), Elem::Relaxed) {
317 b.decorate(out, Decoration::RelaxedPrecision, []);
318 }
319 });
320 }
321 Arithmetic::Log(op) => {
322 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
323 b.declare_math_mode(modes, out);
324 T::log(b, ty, input, out);
325 if matches!(out_ty.elem(), Elem::Relaxed) {
326 b.decorate(out, Decoration::RelaxedPrecision, []);
327 }
328 })
329 }
330 Arithmetic::Log1p(op) => {
331 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
332 let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
333 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
334 let add = match out_ty.elem() {
335 Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
336 Elem::Float(..) | Elem::Relaxed => {
337 b.declare_math_mode(modes, out);
338 b.f_add(ty, None, input, one).unwrap()
339 }
340 _ => unreachable!(),
341 };
342 b.mark_uniformity(add, uniform);
343 if relaxed {
344 b.decorate(add, Decoration::RelaxedPrecision, []);
345 b.decorate(out, Decoration::RelaxedPrecision, []);
346 }
347 b.declare_math_mode(modes, out);
348 T::log(b, ty, add, out)
349 });
350 }
351 Arithmetic::Cos(op) => {
352 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
353 b.declare_math_mode(modes, out);
354 T::cos(b, ty, input, out);
355 if matches!(out_ty.elem(), Elem::Relaxed) {
356 b.decorate(out, Decoration::RelaxedPrecision, []);
357 }
358 })
359 }
360 Arithmetic::Sin(op) => {
361 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
362 b.declare_math_mode(modes, out);
363 T::sin(b, ty, input, out);
364 if matches!(out_ty.elem(), Elem::Relaxed) {
365 b.decorate(out, Decoration::RelaxedPrecision, []);
366 }
367 })
368 }
369 Arithmetic::Tan(op) => {
370 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
371 b.declare_math_mode(modes, out);
372 T::tan(b, ty, input, out);
373 if matches!(out_ty.elem(), Elem::Relaxed) {
374 b.decorate(out, Decoration::RelaxedPrecision, []);
375 }
376 })
377 }
378 Arithmetic::Tanh(op) => {
379 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
380 b.declare_math_mode(modes, out);
381 T::tanh(b, ty, input, out);
382 if matches!(out_ty.elem(), Elem::Relaxed) {
383 b.decorate(out, Decoration::RelaxedPrecision, []);
384 }
385 })
386 }
387 Arithmetic::Sinh(op) => {
388 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
389 b.declare_math_mode(modes, out);
390 T::sinh(b, ty, input, out);
391 if matches!(out_ty.elem(), Elem::Relaxed) {
392 b.decorate(out, Decoration::RelaxedPrecision, []);
393 }
394 })
395 }
396 Arithmetic::Cosh(op) => {
397 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
398 b.declare_math_mode(modes, out);
399 T::cosh(b, ty, input, out);
400 if matches!(out_ty.elem(), Elem::Relaxed) {
401 b.decorate(out, Decoration::RelaxedPrecision, []);
402 }
403 })
404 }
405 Arithmetic::ArcCos(op) => {
406 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
407 b.declare_math_mode(modes, out);
408 T::acos(b, ty, input, out);
409 if matches!(out_ty.elem(), Elem::Relaxed) {
410 b.decorate(out, Decoration::RelaxedPrecision, []);
411 }
412 })
413 }
414 Arithmetic::ArcSin(op) => {
415 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
416 b.declare_math_mode(modes, out);
417 T::asin(b, ty, input, out);
418 if matches!(out_ty.elem(), Elem::Relaxed) {
419 b.decorate(out, Decoration::RelaxedPrecision, []);
420 }
421 })
422 }
423 Arithmetic::ArcTan(op) => {
424 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
425 b.declare_math_mode(modes, out);
426 T::atan(b, ty, input, out);
427 if matches!(out_ty.elem(), Elem::Relaxed) {
428 b.decorate(out, Decoration::RelaxedPrecision, []);
429 }
430 })
431 }
432 Arithmetic::ArcSinh(op) => {
433 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
434 b.declare_math_mode(modes, out);
435 T::asinh(b, ty, input, out);
436 if matches!(out_ty.elem(), Elem::Relaxed) {
437 b.decorate(out, Decoration::RelaxedPrecision, []);
438 }
439 })
440 }
441 Arithmetic::ArcCosh(op) => {
442 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
443 b.declare_math_mode(modes, out);
444 T::acosh(b, ty, input, out);
445 if matches!(out_ty.elem(), Elem::Relaxed) {
446 b.decorate(out, Decoration::RelaxedPrecision, []);
447 }
448 })
449 }
450 Arithmetic::ArcTanh(op) => {
451 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
452 b.declare_math_mode(modes, out);
453 T::atanh(b, ty, input, out);
454 if matches!(out_ty.elem(), Elem::Relaxed) {
455 b.decorate(out, Decoration::RelaxedPrecision, []);
456 }
457 })
458 }
459 Arithmetic::Degrees(op) => {
460 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
461 b.declare_math_mode(modes, out);
462 T::degrees(b, ty, input, out);
463 if matches!(out_ty.elem(), Elem::Relaxed) {
464 b.decorate(out, Decoration::RelaxedPrecision, []);
465 }
466 })
467 }
468 Arithmetic::Radians(op) => {
469 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
470 b.declare_math_mode(modes, out);
471 T::radians(b, ty, input, out);
472 if matches!(out_ty.elem(), Elem::Relaxed) {
473 b.decorate(out, Decoration::RelaxedPrecision, []);
474 }
475 })
476 }
477 Arithmetic::ArcTan2(op) => {
478 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
479 b.declare_math_mode(modes, out);
480 T::atan2(b, ty, lhs, rhs, out);
481 if matches!(out_ty.elem(), Elem::Relaxed) {
482 b.decorate(out, Decoration::RelaxedPrecision, []);
483 }
484 })
485 }
486 Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
488 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
489 let bool = match out_ty {
490 Item::Scalar(_) => Elem::Bool.id(b),
491 Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
492 _ => unreachable!(),
493 };
494 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
495 let zero = out_ty.const_u32(b, 0);
496 let one = out_ty.const_u32(b, 1);
497 let two = out_ty.const_u32(b, 2);
498 let modulo = b.f_rem(ty, None, rhs, two).unwrap();
499 b.declare_math_mode(modes, modulo);
500 let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
501 b.declare_math_mode(modes, is_zero);
502 let abs = b.id();
503 b.declare_math_mode(modes, abs);
504 T::f_abs(b, ty, lhs, abs);
505 let even = b.id();
506 b.declare_math_mode(modes, even);
507 T::pow(b, ty, abs, rhs, even);
508 let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
509 b.declare_math_mode(modes, cond2_0);
510 let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
511 b.declare_math_mode(modes, cond2_1);
512 let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
513 let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
514 b.declare_math_mode(modes, neg_lhs);
515 let pow2 = b.id();
516 b.declare_math_mode(modes, pow2);
517 T::pow(b, ty, neg_lhs, rhs, pow2);
518 let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
519 b.declare_math_mode(modes, pow2_neg);
520 let default = b.id();
521 b.declare_math_mode(modes, default);
522 T::pow(b, ty, lhs, rhs, default);
523 let ids = [
524 modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
525 default,
526 ];
527 for id in ids {
528 b.mark_uniformity(id, uniform);
529 if relaxed {
530 b.decorate(id, Decoration::RelaxedPrecision, []);
531 }
532 }
533 let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
534 b.mark_uniformity(sel1, uniform);
535 b.select(ty, Some(out), is_zero, even, sel1).unwrap();
536 })
537 }
538 Arithmetic::Sqrt(op) => {
539 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
540 b.declare_math_mode(modes, out);
541 T::sqrt(b, ty, input, out);
542 if matches!(out_ty.elem(), Elem::Relaxed) {
543 b.decorate(out, Decoration::RelaxedPrecision, []);
544 }
545 })
546 }
547 Arithmetic::InverseSqrt(op) => {
548 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
549 b.declare_math_mode(modes, out);
550 T::inverse_sqrt(b, ty, input, out);
551 if matches!(out_ty.elem(), Elem::Relaxed) {
552 b.decorate(out, Decoration::RelaxedPrecision, []);
553 }
554 })
555 }
556 Arithmetic::Round(op) => {
557 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
558 T::round(b, ty, input, out);
559 if matches!(out_ty.elem(), Elem::Relaxed) {
560 b.decorate(out, Decoration::RelaxedPrecision, []);
561 }
562 })
563 }
564 Arithmetic::Floor(op) => {
565 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
566 b.declare_math_mode(modes, out);
567 T::floor(b, ty, input, out);
568 if matches!(out_ty.elem(), Elem::Relaxed) {
569 b.decorate(out, Decoration::RelaxedPrecision, []);
570 }
571 })
572 }
573 Arithmetic::Ceil(op) => {
574 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
575 b.declare_math_mode(modes, out);
576 T::ceil(b, ty, input, out);
577 if matches!(out_ty.elem(), Elem::Relaxed) {
578 b.decorate(out, Decoration::RelaxedPrecision, []);
579 }
580 })
581 }
582 Arithmetic::Trunc(op) => {
583 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
584 b.declare_math_mode(modes, out);
585 T::trunc(b, ty, input, out);
586 if matches!(out_ty.elem(), Elem::Relaxed) {
587 b.decorate(out, Decoration::RelaxedPrecision, []);
588 }
589 })
590 }
591 Arithmetic::Clamp(op) => {
592 let input = self.compile_variable(op.input);
593 let min = self.compile_variable(op.min_value);
594 let max = self.compile_variable(op.max_value);
595 let out = self.compile_variable(out);
596 let out_ty = out.item();
597
598 let input = self.read_as(&input, &out_ty);
599 let min = self.read_as(&min, &out_ty);
600 let max = self.read_as(&max, &out_ty);
601 let out_id = self.write_id(&out);
602 self.mark_uniformity(out_id, uniform);
603
604 let ty = out_ty.id(self);
605
606 match out_ty.elem() {
607 Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
608 Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
609 Elem::Float(..) => {
610 self.declare_math_mode(modes, out_id);
611 T::f_clamp(self, ty, input, min, max, out_id)
612 }
613 Elem::Relaxed => {
614 self.decorate(out_id, Decoration::RelaxedPrecision, []);
615 self.declare_math_mode(modes, out_id);
616 T::f_clamp(self, ty, input, min, max, out_id)
617 }
618 _ => unreachable!(),
619 }
620 self.write(&out, out_id);
621 }
622
623 Arithmetic::Max(op) => self.compile_binary_op(
624 op,
625 out,
626 uniform,
627 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
628 Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
629 Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
630 Elem::Float(..) => {
631 b.declare_math_mode(modes, out);
632 T::f_max(b, ty, lhs, rhs, out)
633 }
634 Elem::Relaxed => {
635 b.decorate(out, Decoration::RelaxedPrecision, []);
636 b.declare_math_mode(modes, out);
637 T::f_max(b, ty, lhs, rhs, out)
638 }
639 _ => unreachable!(),
640 },
641 ),
642 Arithmetic::Min(op) => self.compile_binary_op(
643 op,
644 out,
645 uniform,
646 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
647 Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
648 Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
649 Elem::Float(..) => {
650 b.declare_math_mode(modes, out);
651 T::f_min(b, ty, lhs, rhs, out)
652 }
653 Elem::Relaxed => {
654 b.decorate(out, Decoration::RelaxedPrecision, []);
655 b.declare_math_mode(modes, out);
656 T::f_min(b, ty, lhs, rhs, out)
657 }
658 _ => unreachable!(),
659 },
660 ),
661 }
662 }
663}