1use crate::{
2 SpirvCompiler, SpirvTarget,
3 item::{Elem, Item},
4 variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10 pub fn compile_arithmetic(
11 &mut self,
12 op: Arithmetic,
13 out: Option<core::Variable>,
14 modes: InstructionModes,
15 uniform: bool,
16 ) {
17 let out = out.unwrap();
18 match op {
19 Arithmetic::Add(op) => {
20 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21 match out_ty.elem() {
22 Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23 Elem::Float(..) => {
24 b.declare_math_mode(modes, out);
25 b.f_add(ty, Some(out), lhs, rhs).unwrap()
26 }
27 Elem::Relaxed => {
28 b.decorate(out, Decoration::RelaxedPrecision, []);
29 b.declare_math_mode(modes, out);
30 b.f_add(ty, Some(out), lhs, rhs).unwrap()
31 }
32 _ => unreachable!(),
33 };
34 });
35 }
36 Arithmetic::SaturatingAdd(_) => {
37 unimplemented!("Should be replaced by polyfill");
38 }
39 Arithmetic::Sub(op) => {
40 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41 match out_ty.elem() {
42 Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43 Elem::Float(..) => {
44 b.declare_math_mode(modes, out);
45 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46 }
47 Elem::Relaxed => {
48 b.decorate(out, Decoration::RelaxedPrecision, []);
49 b.declare_math_mode(modes, out);
50 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51 }
52 _ => unreachable!(),
53 };
54 });
55 }
56 Arithmetic::SaturatingSub(_) => {
57 unimplemented!("Should be replaced by polyfill");
58 }
59 Arithmetic::Mul(op) => {
60 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61 match out_ty.elem() {
62 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63 Elem::Float(..) => {
64 b.declare_math_mode(modes, out);
65 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66 }
67 Elem::Relaxed => {
68 b.decorate(out, Decoration::RelaxedPrecision, []);
69 b.declare_math_mode(modes, out);
70 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71 }
72 _ => unreachable!(),
73 };
74 });
75 }
76 Arithmetic::MulHi(op) => {
77 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78 let out_st = b.type_struct([ty, ty]);
79 let extended = match out_ty.elem() {
80 Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81 Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82 _ => unreachable!(),
83 };
84 b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85 });
86 }
87 Arithmetic::Div(op) => {
88 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89 match out_ty.elem() {
90 Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91 Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92 Elem::Float(..) => {
93 b.declare_math_mode(modes, out);
94 b.f_div(ty, Some(out), lhs, rhs).unwrap()
95 }
96 Elem::Relaxed => {
97 b.decorate(out, Decoration::RelaxedPrecision, []);
98 b.declare_math_mode(modes, out);
99 b.f_div(ty, Some(out), lhs, rhs).unwrap()
100 }
101 _ => unreachable!(),
102 };
103 });
104 }
105 Arithmetic::Remainder(op) => {
106 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107 match out_ty.elem() {
108 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109 Elem::Int(_, true) => {
110 let f_ty = match out_ty {
114 Item::Scalar(_elem) => Item::Scalar(Elem::Relaxed),
115 Item::Vector(_elem, factor) => Item::Vector(Elem::Relaxed, factor),
116 _ => unreachable!(),
117 };
118 let f_ty = f_ty.id(b);
119 let lhs_f = b.convert_s_to_f(f_ty, None, lhs).unwrap();
120 let rhs_f = b.convert_s_to_f(f_ty, None, rhs).unwrap();
121 let rem = b.f_mod(f_ty, None, lhs_f, rhs_f).unwrap();
122 b.convert_f_to_s(ty, Some(out), rem).unwrap()
123 }
124 Elem::Float(..) => {
125 b.declare_math_mode(modes, out);
126 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
127 }
128 Elem::Relaxed => {
129 b.decorate(out, Decoration::RelaxedPrecision, []);
130 b.declare_math_mode(modes, out);
131 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
132 }
133 _ => unreachable!(),
134 };
135 });
136 }
137 Arithmetic::Modulo(op) => {
138 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
139 match out_ty.elem() {
140 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
141 Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
142 Elem::Float(..) => {
143 b.declare_math_mode(modes, out);
144 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
145 }
146 Elem::Relaxed => {
147 b.decorate(out, Decoration::RelaxedPrecision, []);
148 b.declare_math_mode(modes, out);
149 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
150 }
151 _ => unreachable!(),
152 };
153 });
154 }
155 Arithmetic::Dot(op) => {
156 if op.lhs.ty.vector_size() == 1 {
157 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
158 match out_ty.elem() {
159 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
160 Elem::Float(..) => {
161 b.declare_math_mode(modes, out);
162 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
163 }
164 Elem::Relaxed => {
165 b.decorate(out, Decoration::RelaxedPrecision, []);
166 b.declare_math_mode(modes, out);
167 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
168 }
169 _ => unreachable!(),
170 };
171 });
172 } else {
173 let lhs = self.compile_variable(op.lhs);
174 let rhs = self.compile_variable(op.rhs);
175 let out = self.compile_variable(out);
176 let ty = out.item().id(self);
177
178 let lhs_id = self.read(&lhs);
179 let rhs_id = self.read(&rhs);
180 let out_id = self.write_id(&out);
181 self.mark_uniformity(out_id, uniform);
182
183 if matches!(lhs.elem(), Elem::Int(_, _)) {
184 self.capabilities.insert(Capability::DotProduct);
185 }
186 if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
187 self.capabilities.insert(Capability::BFloat16DotProductKHR);
188 }
189
190 match (lhs.elem(), rhs.elem()) {
191 (Elem::Int(_, false), Elem::Int(_, false)) => {
192 self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
193 }
194 (Elem::Int(_, true), Elem::Int(_, false)) => {
195 self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
196 }
197 (Elem::Int(_, false), Elem::Int(_, true)) => {
198 self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
199 }
200 (Elem::Int(_, true), Elem::Int(_, true)) => {
201 self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
202 }
203 (Elem::Float(..), Elem::Float(..))
204 | (Elem::Relaxed, Elem::Float(..))
205 | (Elem::Float(..), Elem::Relaxed) => {
206 self.dot(ty, Some(out_id), lhs_id, rhs_id)
207 }
208 (Elem::Relaxed, Elem::Relaxed) => {
209 self.decorate(out_id, Decoration::RelaxedPrecision, []);
210 self.dot(ty, Some(out_id), lhs_id, rhs_id)
211 }
212 _ => unreachable!(),
213 }
214 .unwrap();
215 self.write(&out, out_id);
216 }
217 }
218 Arithmetic::Fma(op) => {
219 let a = self.compile_variable(op.a);
220 let b = self.compile_variable(op.b);
221 let c = self.compile_variable(op.c);
222 let out = self.compile_variable(out);
223 let out_ty = out.item();
224 let relaxed = matches!(
225 (a.item().elem(), b.item().elem(), c.item().elem()),
226 (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
227 );
228
229 let a_id = self.read_as(&a, &out_ty);
230 let b_id = self.read_as(&b, &out_ty);
231 let c_id = self.read_as(&c, &out_ty);
232 let out_id = self.write_id(&out);
233 self.mark_uniformity(out_id, uniform);
234
235 let ty = out_ty.id(self);
236
237 let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
238 self.mark_uniformity(mul, uniform);
239 self.declare_math_mode(modes, mul);
240 self.f_add(ty, Some(out_id), mul, c_id).unwrap();
241 self.declare_math_mode(modes, out_id);
242 if relaxed {
243 self.decorate(mul, Decoration::RelaxedPrecision, []);
244 self.decorate(out_id, Decoration::RelaxedPrecision, []);
245 }
246 self.write(&out, out_id);
247 }
248 Arithmetic::Recip(op) => {
249 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
250 let one = b
251 .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
252 .0;
253 b.declare_math_mode(modes, out);
254 b.f_div(ty, Some(out), one, input).unwrap();
255 });
256 }
257 Arithmetic::Neg(op) => {
258 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
259 match out_ty.elem() {
260 Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
261 Elem::Float(..) => {
262 b.declare_math_mode(modes, out);
263 b.f_negate(ty, Some(out), input).unwrap()
264 }
265 Elem::Relaxed => {
266 b.decorate(out, Decoration::RelaxedPrecision, []);
267 b.declare_math_mode(modes, out);
268 b.f_negate(ty, Some(out), input).unwrap()
269 }
270 _ => unreachable!(),
271 };
272 });
273 }
274 Arithmetic::Erf(_) => {
275 unreachable!("Replaced by transformer")
276 }
277
278 Arithmetic::Normalize(op) => {
280 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
281 b.declare_math_mode(modes, out);
282 T::normalize(b, ty, input, out);
283 if matches!(out_ty.elem(), Elem::Relaxed) {
284 b.decorate(out, Decoration::RelaxedPrecision, []);
285 }
286 });
287 }
288 Arithmetic::VectorSum(op) => {
289 let input_ir = op.input;
290 let input = self.compile_variable(input_ir);
291 let out = self.compile_variable(out);
292 let in_item = input.item();
293 let out_ty = out.item();
294 let vec_size = in_item.vectorization();
295 let scalar_ty = out_ty.id(self);
296 let input_id = self.read(&input);
297 let out_id = self.write_id(&out);
298 self.mark_uniformity(out_id, uniform);
299
300 if vec_size <= 1 {
301 self.copy_object(scalar_ty, Some(out_id), input_id).unwrap();
303 } else if matches!(out_ty.elem(), Elem::Float(..) | Elem::Relaxed) {
304 self.declare_math_mode(modes, out_id);
306 let ones_val = in_item
307 .elem()
308 .constant(self, ConstVal::Bit32(1.0f32.to_bits()));
309 let vec_ty = in_item.id(self);
310 let ones = self.constant_composite(vec_ty, (0..vec_size).map(|_| ones_val));
311 self.dot(scalar_ty, Some(out_id), input_id, ones).unwrap();
312 if matches!(out_ty.elem(), Elem::Relaxed) {
313 self.decorate(out_id, Decoration::RelaxedPrecision, []);
314 }
315 } else {
316 let elem_ty = out_ty.id(self);
318 let mut acc = self
319 .composite_extract(elem_ty, None, input_id, vec![0])
320 .unwrap();
321 for i in 1..vec_size {
322 let elem = self
323 .composite_extract(elem_ty, None, input_id, vec![i])
324 .unwrap();
325 acc = self.i_add(elem_ty, None, acc, elem).unwrap();
326 }
327 self.copy_object(scalar_ty, Some(out_id), acc).unwrap();
328 }
329 self.write(&out, out_id);
330 }
331 Arithmetic::Magnitude(op) => {
332 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
333 b.declare_math_mode(modes, out);
334 T::magnitude(b, ty, input, out);
335 if matches!(out_ty.elem(), Elem::Relaxed) {
336 b.decorate(out, Decoration::RelaxedPrecision, []);
337 }
338 });
339 }
340 Arithmetic::Abs(op) => {
341 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
342 match out_ty.elem() {
343 Elem::Int(_, _) => T::s_abs(b, ty, input, out),
344 Elem::Float(..) => {
345 b.declare_math_mode(modes, out);
346 T::f_abs(b, ty, input, out)
347 }
348 Elem::Relaxed => {
349 b.decorate(out, Decoration::RelaxedPrecision, []);
350 b.declare_math_mode(modes, out);
351 T::f_abs(b, ty, input, out)
352 }
353 _ => unreachable!(),
354 }
355 });
356 }
357 Arithmetic::Exp(op) => {
358 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
359 b.declare_math_mode(modes, out);
360 T::exp(b, ty, input, out);
361 if matches!(out_ty.elem(), Elem::Relaxed) {
362 b.decorate(out, Decoration::RelaxedPrecision, []);
363 }
364 });
365 }
366 Arithmetic::Log(op) => {
367 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
368 b.declare_math_mode(modes, out);
369 T::log(b, ty, input, out);
370 if matches!(out_ty.elem(), Elem::Relaxed) {
371 b.decorate(out, Decoration::RelaxedPrecision, []);
372 }
373 })
374 }
375 Arithmetic::Log1p(op) => {
376 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
377 let one = b
378 .static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty)
379 .0;
380 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
381 let add = match out_ty.elem() {
382 Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
383 Elem::Float(..) | Elem::Relaxed => {
384 b.declare_math_mode(modes, out);
385 b.f_add(ty, None, input, one).unwrap()
386 }
387 _ => unreachable!(),
388 };
389 b.mark_uniformity(add, uniform);
390 if relaxed {
391 b.decorate(add, Decoration::RelaxedPrecision, []);
392 b.decorate(out, Decoration::RelaxedPrecision, []);
393 }
394 b.declare_math_mode(modes, out);
395 T::log(b, ty, add, out)
396 });
397 }
398 Arithmetic::Cos(op) => {
399 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
400 b.declare_math_mode(modes, out);
401 T::cos(b, ty, input, out);
402 if matches!(out_ty.elem(), Elem::Relaxed) {
403 b.decorate(out, Decoration::RelaxedPrecision, []);
404 }
405 })
406 }
407 Arithmetic::Sin(op) => {
408 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
409 b.declare_math_mode(modes, out);
410 T::sin(b, ty, input, out);
411 if matches!(out_ty.elem(), Elem::Relaxed) {
412 b.decorate(out, Decoration::RelaxedPrecision, []);
413 }
414 })
415 }
416 Arithmetic::Tan(op) => {
417 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
418 b.declare_math_mode(modes, out);
419 T::tan(b, ty, input, out);
420 if matches!(out_ty.elem(), Elem::Relaxed) {
421 b.decorate(out, Decoration::RelaxedPrecision, []);
422 }
423 })
424 }
425 Arithmetic::Tanh(op) => {
426 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
427 b.declare_math_mode(modes, out);
428 T::tanh(b, ty, input, out);
429 if matches!(out_ty.elem(), Elem::Relaxed) {
430 b.decorate(out, Decoration::RelaxedPrecision, []);
431 }
432 })
433 }
434 Arithmetic::Sinh(op) => {
435 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
436 b.declare_math_mode(modes, out);
437 T::sinh(b, ty, input, out);
438 if matches!(out_ty.elem(), Elem::Relaxed) {
439 b.decorate(out, Decoration::RelaxedPrecision, []);
440 }
441 })
442 }
443 Arithmetic::Cosh(op) => {
444 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
445 b.declare_math_mode(modes, out);
446 T::cosh(b, ty, input, out);
447 if matches!(out_ty.elem(), Elem::Relaxed) {
448 b.decorate(out, Decoration::RelaxedPrecision, []);
449 }
450 })
451 }
452 Arithmetic::ArcCos(op) => {
453 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
454 b.declare_math_mode(modes, out);
455 T::acos(b, ty, input, out);
456 if matches!(out_ty.elem(), Elem::Relaxed) {
457 b.decorate(out, Decoration::RelaxedPrecision, []);
458 }
459 })
460 }
461 Arithmetic::ArcSin(op) => {
462 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
463 b.declare_math_mode(modes, out);
464 T::asin(b, ty, input, out);
465 if matches!(out_ty.elem(), Elem::Relaxed) {
466 b.decorate(out, Decoration::RelaxedPrecision, []);
467 }
468 })
469 }
470 Arithmetic::ArcTan(op) => {
471 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
472 b.declare_math_mode(modes, out);
473 T::atan(b, ty, input, out);
474 if matches!(out_ty.elem(), Elem::Relaxed) {
475 b.decorate(out, Decoration::RelaxedPrecision, []);
476 }
477 })
478 }
479 Arithmetic::ArcSinh(op) => {
480 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
481 b.declare_math_mode(modes, out);
482 T::asinh(b, ty, input, out);
483 if matches!(out_ty.elem(), Elem::Relaxed) {
484 b.decorate(out, Decoration::RelaxedPrecision, []);
485 }
486 })
487 }
488 Arithmetic::ArcCosh(op) => {
489 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
490 b.declare_math_mode(modes, out);
491 T::acosh(b, ty, input, out);
492 if matches!(out_ty.elem(), Elem::Relaxed) {
493 b.decorate(out, Decoration::RelaxedPrecision, []);
494 }
495 })
496 }
497 Arithmetic::ArcTanh(op) => {
498 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
499 b.declare_math_mode(modes, out);
500 T::atanh(b, ty, input, out);
501 if matches!(out_ty.elem(), Elem::Relaxed) {
502 b.decorate(out, Decoration::RelaxedPrecision, []);
503 }
504 })
505 }
506 Arithmetic::Degrees(op) => {
507 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
508 b.declare_math_mode(modes, out);
509 T::degrees(b, ty, input, out);
510 if matches!(out_ty.elem(), Elem::Relaxed) {
511 b.decorate(out, Decoration::RelaxedPrecision, []);
512 }
513 })
514 }
515 Arithmetic::Radians(op) => {
516 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
517 b.declare_math_mode(modes, out);
518 T::radians(b, ty, input, out);
519 if matches!(out_ty.elem(), Elem::Relaxed) {
520 b.decorate(out, Decoration::RelaxedPrecision, []);
521 }
522 })
523 }
524 Arithmetic::ArcTan2(op) => {
525 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
526 b.declare_math_mode(modes, out);
527 T::atan2(b, ty, lhs, rhs, out);
528 if matches!(out_ty.elem(), Elem::Relaxed) {
529 b.decorate(out, Decoration::RelaxedPrecision, []);
530 }
531 })
532 }
533 Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
535 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
536 let bool = match out_ty {
537 Item::Scalar(_) => Elem::Bool.id(b),
538 Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
539 _ => unreachable!(),
540 };
541 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
542 let zero = out_ty.const_u32(b, 0);
543 let one = out_ty.const_u32(b, 1);
544 let two = out_ty.const_u32(b, 2);
545 let modulo = b.f_rem(ty, None, rhs, two).unwrap();
546 b.declare_math_mode(modes, modulo);
547 let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
548 b.declare_math_mode(modes, is_zero);
549 let abs = b.id();
550 b.declare_math_mode(modes, abs);
551 T::f_abs(b, ty, lhs, abs);
552 let even = b.id();
553 b.declare_math_mode(modes, even);
554 T::pow(b, ty, abs, rhs, even);
555 let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
556 b.declare_math_mode(modes, cond2_0);
557 let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
558 b.declare_math_mode(modes, cond2_1);
559 let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
560 let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
561 b.declare_math_mode(modes, neg_lhs);
562 let pow2 = b.id();
563 b.declare_math_mode(modes, pow2);
564 T::pow(b, ty, neg_lhs, rhs, pow2);
565 let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
566 b.declare_math_mode(modes, pow2_neg);
567 let default = b.id();
568 b.declare_math_mode(modes, default);
569 T::pow(b, ty, lhs, rhs, default);
570 let ids = [
571 modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
572 default,
573 ];
574 for id in ids {
575 b.mark_uniformity(id, uniform);
576 if relaxed {
577 b.decorate(id, Decoration::RelaxedPrecision, []);
578 }
579 }
580 let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
581 b.mark_uniformity(sel1, uniform);
582 b.select(ty, Some(out), is_zero, even, sel1).unwrap();
583 })
584 }
585 Arithmetic::Hypot(_op) => {
586 unreachable!("Replaced by transformer");
587 }
588 Arithmetic::Rhypot(_op) => {
589 unreachable!("Replaced by transformer");
590 }
591 Arithmetic::Sqrt(op) => {
592 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
593 b.declare_math_mode(modes, out);
594 T::sqrt(b, ty, input, out);
595 if matches!(out_ty.elem(), Elem::Relaxed) {
596 b.decorate(out, Decoration::RelaxedPrecision, []);
597 }
598 })
599 }
600 Arithmetic::InverseSqrt(op) => {
601 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
602 b.declare_math_mode(modes, out);
603 T::inverse_sqrt(b, ty, input, out);
604 if matches!(out_ty.elem(), Elem::Relaxed) {
605 b.decorate(out, Decoration::RelaxedPrecision, []);
606 }
607 })
608 }
609 Arithmetic::Round(op) => {
610 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
611 T::round(b, ty, input, out);
612 if matches!(out_ty.elem(), Elem::Relaxed) {
613 b.decorate(out, Decoration::RelaxedPrecision, []);
614 }
615 })
616 }
617 Arithmetic::Floor(op) => {
618 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
619 b.declare_math_mode(modes, out);
620 T::floor(b, ty, input, out);
621 if matches!(out_ty.elem(), Elem::Relaxed) {
622 b.decorate(out, Decoration::RelaxedPrecision, []);
623 }
624 })
625 }
626 Arithmetic::Ceil(op) => {
627 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
628 b.declare_math_mode(modes, out);
629 T::ceil(b, ty, input, out);
630 if matches!(out_ty.elem(), Elem::Relaxed) {
631 b.decorate(out, Decoration::RelaxedPrecision, []);
632 }
633 })
634 }
635 Arithmetic::Trunc(op) => {
636 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
637 b.declare_math_mode(modes, out);
638 T::trunc(b, ty, input, out);
639 if matches!(out_ty.elem(), Elem::Relaxed) {
640 b.decorate(out, Decoration::RelaxedPrecision, []);
641 }
642 })
643 }
644 Arithmetic::Clamp(op) => {
645 let input = self.compile_variable(op.input);
646 let min = self.compile_variable(op.min_value);
647 let max = self.compile_variable(op.max_value);
648 let out = self.compile_variable(out);
649 let out_ty = out.item();
650
651 let input = self.read_as(&input, &out_ty);
652 let min = self.read_as(&min, &out_ty);
653 let max = self.read_as(&max, &out_ty);
654 let out_id = self.write_id(&out);
655 self.mark_uniformity(out_id, uniform);
656
657 let ty = out_ty.id(self);
658
659 match out_ty.elem() {
660 Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
661 Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
662 Elem::Float(..) => {
663 self.declare_math_mode(modes, out_id);
664 T::f_clamp(self, ty, input, min, max, out_id)
665 }
666 Elem::Relaxed => {
667 self.decorate(out_id, Decoration::RelaxedPrecision, []);
668 self.declare_math_mode(modes, out_id);
669 T::f_clamp(self, ty, input, min, max, out_id)
670 }
671 _ => unreachable!(),
672 }
673 self.write(&out, out_id);
674 }
675
676 Arithmetic::Max(op) => self.compile_binary_op(
677 op,
678 out,
679 uniform,
680 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
681 Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
682 Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
683 Elem::Float(..) => {
684 b.declare_math_mode(modes, out);
685 T::f_max(b, ty, lhs, rhs, out)
686 }
687 Elem::Relaxed => {
688 b.decorate(out, Decoration::RelaxedPrecision, []);
689 b.declare_math_mode(modes, out);
690 T::f_max(b, ty, lhs, rhs, out)
691 }
692 _ => unreachable!(),
693 },
694 ),
695 Arithmetic::Min(op) => self.compile_binary_op(
696 op,
697 out,
698 uniform,
699 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
700 Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
701 Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
702 Elem::Float(..) => {
703 b.declare_math_mode(modes, out);
704 T::f_min(b, ty, lhs, rhs, out)
705 }
706 Elem::Relaxed => {
707 b.decorate(out, Decoration::RelaxedPrecision, []);
708 b.declare_math_mode(modes, out);
709 T::f_min(b, ty, lhs, rhs, out)
710 }
711 _ => unreachable!(),
712 },
713 ),
714 }
715 }
716}