1use crate::{
2 SpirvCompiler, SpirvTarget,
3 item::{Elem, Item},
4 variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic, InstructionModes};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10 pub fn compile_arithmetic(
11 &mut self,
12 op: Arithmetic,
13 out: Option<core::Variable>,
14 modes: InstructionModes,
15 uniform: bool,
16 ) {
17 let out = out.unwrap();
18 match op {
19 Arithmetic::Add(op) => {
20 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
21 match out_ty.elem() {
22 Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
23 Elem::Float(..) => {
24 b.declare_math_mode(modes, out);
25 b.f_add(ty, Some(out), lhs, rhs).unwrap()
26 }
27 Elem::Relaxed => {
28 b.decorate(out, Decoration::RelaxedPrecision, []);
29 b.declare_math_mode(modes, out);
30 b.f_add(ty, Some(out), lhs, rhs).unwrap()
31 }
32 _ => unreachable!(),
33 };
34 });
35 }
36 Arithmetic::SaturatingAdd(_) => {
37 unimplemented!("Should be replaced by polyfill");
38 }
39 Arithmetic::Sub(op) => {
40 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
41 match out_ty.elem() {
42 Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
43 Elem::Float(..) => {
44 b.declare_math_mode(modes, out);
45 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
46 }
47 Elem::Relaxed => {
48 b.decorate(out, Decoration::RelaxedPrecision, []);
49 b.declare_math_mode(modes, out);
50 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
51 }
52 _ => unreachable!(),
53 };
54 });
55 }
56 Arithmetic::SaturatingSub(_) => {
57 unimplemented!("Should be replaced by polyfill");
58 }
59 Arithmetic::Mul(op) => {
60 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
61 match out_ty.elem() {
62 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
63 Elem::Float(..) => {
64 b.declare_math_mode(modes, out);
65 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
66 }
67 Elem::Relaxed => {
68 b.decorate(out, Decoration::RelaxedPrecision, []);
69 b.declare_math_mode(modes, out);
70 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
71 }
72 _ => unreachable!(),
73 };
74 });
75 }
76 Arithmetic::MulHi(op) => {
77 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
78 let out_st = b.type_struct([ty, ty]);
79 let extended = match out_ty.elem() {
80 Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
81 Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
82 _ => unreachable!(),
83 };
84 b.composite_extract(ty, Some(out), extended, [1]).unwrap();
85 });
86 }
87 Arithmetic::Div(op) => {
88 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
89 match out_ty.elem() {
90 Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
91 Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
92 Elem::Float(..) => {
93 b.declare_math_mode(modes, out);
94 b.f_div(ty, Some(out), lhs, rhs).unwrap()
95 }
96 Elem::Relaxed => {
97 b.decorate(out, Decoration::RelaxedPrecision, []);
98 b.declare_math_mode(modes, out);
99 b.f_div(ty, Some(out), lhs, rhs).unwrap()
100 }
101 _ => unreachable!(),
102 };
103 });
104 }
105 Arithmetic::Remainder(op) => {
106 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
107 match out_ty.elem() {
108 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
109 Elem::Int(_, true) => b.s_mod(ty, Some(out), lhs, rhs).unwrap(),
110 Elem::Float(..) => {
111 b.declare_math_mode(modes, out);
112 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
113 }
114 Elem::Relaxed => {
115 b.decorate(out, Decoration::RelaxedPrecision, []);
116 b.declare_math_mode(modes, out);
117 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
118 }
119 _ => unreachable!(),
120 };
121 });
122 }
123 Arithmetic::Modulo(op) => {
124 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
125 match out_ty.elem() {
126 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
127 Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
128 Elem::Float(..) => {
129 b.declare_math_mode(modes, out);
130 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
131 }
132 Elem::Relaxed => {
133 b.decorate(out, Decoration::RelaxedPrecision, []);
134 b.declare_math_mode(modes, out);
135 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
136 }
137 _ => unreachable!(),
138 };
139 });
140 }
141 Arithmetic::Dot(op) => {
142 if op.lhs.ty.line_size() == 1 {
143 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
144 match out_ty.elem() {
145 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
146 Elem::Float(..) => {
147 b.declare_math_mode(modes, out);
148 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
149 }
150 Elem::Relaxed => {
151 b.decorate(out, Decoration::RelaxedPrecision, []);
152 b.declare_math_mode(modes, out);
153 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
154 }
155 _ => unreachable!(),
156 };
157 });
158 } else {
159 let lhs = self.compile_variable(op.lhs);
160 let rhs = self.compile_variable(op.rhs);
161 let out = self.compile_variable(out);
162 let ty = out.item().id(self);
163
164 let lhs_id = self.read(&lhs);
165 let rhs_id = self.read(&rhs);
166 let out_id = self.write_id(&out);
167 self.mark_uniformity(out_id, uniform);
168
169 if matches!(lhs.elem(), Elem::Int(_, _)) {
170 self.capabilities.insert(Capability::DotProduct);
171 }
172 if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
173 self.capabilities.insert(Capability::BFloat16DotProductKHR);
174 }
175
176 match (lhs.elem(), rhs.elem()) {
177 (Elem::Int(_, false), Elem::Int(_, false)) => {
178 self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
179 }
180 (Elem::Int(_, true), Elem::Int(_, false)) => {
181 self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
182 }
183 (Elem::Int(_, false), Elem::Int(_, true)) => {
184 self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
185 }
186 (Elem::Int(_, true), Elem::Int(_, true)) => {
187 self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
188 }
189 (Elem::Float(..), Elem::Float(..))
190 | (Elem::Relaxed, Elem::Float(..))
191 | (Elem::Float(..), Elem::Relaxed) => {
192 self.dot(ty, Some(out_id), lhs_id, rhs_id)
193 }
194 (Elem::Relaxed, Elem::Relaxed) => {
195 self.decorate(out_id, Decoration::RelaxedPrecision, []);
196 self.dot(ty, Some(out_id), lhs_id, rhs_id)
197 }
198 _ => unreachable!(),
199 }
200 .unwrap();
201 self.write(&out, out_id);
202 }
203 }
204 Arithmetic::Fma(op) => {
205 let a = self.compile_variable(op.a);
206 let b = self.compile_variable(op.b);
207 let c = self.compile_variable(op.c);
208 let out = self.compile_variable(out);
209 let out_ty = out.item();
210 let relaxed = matches!(
211 (a.item().elem(), b.item().elem(), c.item().elem()),
212 (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
213 );
214
215 let a_id = self.read_as(&a, &out_ty);
216 let b_id = self.read_as(&b, &out_ty);
217 let c_id = self.read_as(&c, &out_ty);
218 let out_id = self.write_id(&out);
219 self.mark_uniformity(out_id, uniform);
220
221 let ty = out_ty.id(self);
222
223 let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
224 self.mark_uniformity(mul, uniform);
225 self.declare_math_mode(modes, mul);
226 self.f_add(ty, Some(out_id), mul, c_id).unwrap();
227 self.declare_math_mode(modes, out_id);
228 if relaxed {
229 self.decorate(mul, Decoration::RelaxedPrecision, []);
230 self.decorate(out_id, Decoration::RelaxedPrecision, []);
231 }
232 self.write(&out, out_id);
233 }
234 Arithmetic::Recip(op) => {
235 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
236 let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
237 b.declare_math_mode(modes, out);
238 b.f_div(ty, Some(out), one, input).unwrap();
239 });
240 }
241 Arithmetic::Neg(op) => {
242 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
243 match out_ty.elem() {
244 Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
245 Elem::Float(..) => {
246 b.declare_math_mode(modes, out);
247 b.f_negate(ty, Some(out), input).unwrap()
248 }
249 Elem::Relaxed => {
250 b.decorate(out, Decoration::RelaxedPrecision, []);
251 b.declare_math_mode(modes, out);
252 b.f_negate(ty, Some(out), input).unwrap()
253 }
254 _ => unreachable!(),
255 };
256 });
257 }
258 Arithmetic::Erf(_) => {
259 unreachable!("Replaced by transformer")
260 }
261
262 Arithmetic::Normalize(op) => {
264 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
265 b.declare_math_mode(modes, out);
266 T::normalize(b, ty, input, out);
267 if matches!(out_ty.elem(), Elem::Relaxed) {
268 b.decorate(out, Decoration::RelaxedPrecision, []);
269 }
270 });
271 }
272 Arithmetic::Magnitude(op) => {
273 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
274 b.declare_math_mode(modes, out);
275 T::magnitude(b, ty, input, out);
276 if matches!(out_ty.elem(), Elem::Relaxed) {
277 b.decorate(out, Decoration::RelaxedPrecision, []);
278 }
279 });
280 }
281 Arithmetic::Abs(op) => {
282 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
283 match out_ty.elem() {
284 Elem::Int(_, _) => T::s_abs(b, ty, input, out),
285 Elem::Float(..) => {
286 b.declare_math_mode(modes, out);
287 T::f_abs(b, ty, input, out)
288 }
289 Elem::Relaxed => {
290 b.decorate(out, Decoration::RelaxedPrecision, []);
291 b.declare_math_mode(modes, out);
292 T::f_abs(b, ty, input, out)
293 }
294 _ => unreachable!(),
295 }
296 });
297 }
298 Arithmetic::Exp(op) => {
299 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
300 b.declare_math_mode(modes, out);
301 T::exp(b, ty, input, out);
302 if matches!(out_ty.elem(), Elem::Relaxed) {
303 b.decorate(out, Decoration::RelaxedPrecision, []);
304 }
305 });
306 }
307 Arithmetic::Log(op) => {
308 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
309 b.declare_math_mode(modes, out);
310 T::log(b, ty, input, out);
311 if matches!(out_ty.elem(), Elem::Relaxed) {
312 b.decorate(out, Decoration::RelaxedPrecision, []);
313 }
314 })
315 }
316 Arithmetic::Log1p(op) => {
317 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
318 let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
319 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
320 let add = match out_ty.elem() {
321 Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
322 Elem::Float(..) | Elem::Relaxed => {
323 b.declare_math_mode(modes, out);
324 b.f_add(ty, None, input, one).unwrap()
325 }
326 _ => unreachable!(),
327 };
328 b.mark_uniformity(add, uniform);
329 if relaxed {
330 b.decorate(add, Decoration::RelaxedPrecision, []);
331 b.decorate(out, Decoration::RelaxedPrecision, []);
332 }
333 b.declare_math_mode(modes, out);
334 T::log(b, ty, add, out)
335 });
336 }
337 Arithmetic::Cos(op) => {
338 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
339 b.declare_math_mode(modes, out);
340 T::cos(b, ty, input, out);
341 if matches!(out_ty.elem(), Elem::Relaxed) {
342 b.decorate(out, Decoration::RelaxedPrecision, []);
343 }
344 })
345 }
346 Arithmetic::Sin(op) => {
347 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
348 b.declare_math_mode(modes, out);
349 T::sin(b, ty, input, out);
350 if matches!(out_ty.elem(), Elem::Relaxed) {
351 b.decorate(out, Decoration::RelaxedPrecision, []);
352 }
353 })
354 }
355 Arithmetic::Tanh(op) => {
356 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
357 b.declare_math_mode(modes, out);
358 T::tanh(b, ty, input, out);
359 if matches!(out_ty.elem(), Elem::Relaxed) {
360 b.decorate(out, Decoration::RelaxedPrecision, []);
361 }
362 })
363 }
364 Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
366 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
367 let bool = match out_ty {
368 Item::Scalar(_) => Elem::Bool.id(b),
369 Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
370 _ => unreachable!(),
371 };
372 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
373 let zero = out_ty.const_u32(b, 0);
374 let one = out_ty.const_u32(b, 1);
375 let two = out_ty.const_u32(b, 2);
376 let modulo = b.f_rem(ty, None, rhs, two).unwrap();
377 b.declare_math_mode(modes, modulo);
378 let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
379 b.declare_math_mode(modes, is_zero);
380 let abs = b.id();
381 b.declare_math_mode(modes, abs);
382 T::f_abs(b, ty, lhs, abs);
383 let even = b.id();
384 b.declare_math_mode(modes, even);
385 T::pow(b, ty, abs, rhs, even);
386 let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
387 b.declare_math_mode(modes, cond2_0);
388 let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
389 b.declare_math_mode(modes, cond2_1);
390 let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
391 let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
392 b.declare_math_mode(modes, neg_lhs);
393 let pow2 = b.id();
394 b.declare_math_mode(modes, pow2);
395 T::pow(b, ty, neg_lhs, rhs, pow2);
396 let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
397 b.declare_math_mode(modes, pow2_neg);
398 let default = b.id();
399 b.declare_math_mode(modes, default);
400 T::pow(b, ty, lhs, rhs, default);
401 let ids = [
402 modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
403 default,
404 ];
405 for id in ids {
406 b.mark_uniformity(id, uniform);
407 if relaxed {
408 b.decorate(id, Decoration::RelaxedPrecision, []);
409 }
410 }
411 let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
412 b.mark_uniformity(sel1, uniform);
413 b.select(ty, Some(out), is_zero, even, sel1).unwrap();
414 })
415 }
416 Arithmetic::Sqrt(op) => {
417 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
418 b.declare_math_mode(modes, out);
419 T::sqrt(b, ty, input, out);
420 if matches!(out_ty.elem(), Elem::Relaxed) {
421 b.decorate(out, Decoration::RelaxedPrecision, []);
422 }
423 })
424 }
425 Arithmetic::InverseSqrt(op) => {
426 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
427 b.declare_math_mode(modes, out);
428 T::inverse_sqrt(b, ty, input, out);
429 if matches!(out_ty.elem(), Elem::Relaxed) {
430 b.decorate(out, Decoration::RelaxedPrecision, []);
431 }
432 })
433 }
434 Arithmetic::Round(op) => {
435 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
436 T::round(b, ty, input, out);
437 if matches!(out_ty.elem(), Elem::Relaxed) {
438 b.decorate(out, Decoration::RelaxedPrecision, []);
439 }
440 })
441 }
442 Arithmetic::Floor(op) => {
443 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
444 b.declare_math_mode(modes, out);
445 T::floor(b, ty, input, out);
446 if matches!(out_ty.elem(), Elem::Relaxed) {
447 b.decorate(out, Decoration::RelaxedPrecision, []);
448 }
449 })
450 }
451 Arithmetic::Ceil(op) => {
452 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
453 b.declare_math_mode(modes, out);
454 T::ceil(b, ty, input, out);
455 if matches!(out_ty.elem(), Elem::Relaxed) {
456 b.decorate(out, Decoration::RelaxedPrecision, []);
457 }
458 })
459 }
460 Arithmetic::Trunc(op) => {
461 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
462 b.declare_math_mode(modes, out);
463 T::trunc(b, ty, input, out);
464 if matches!(out_ty.elem(), Elem::Relaxed) {
465 b.decorate(out, Decoration::RelaxedPrecision, []);
466 }
467 })
468 }
469 Arithmetic::Clamp(op) => {
470 let input = self.compile_variable(op.input);
471 let min = self.compile_variable(op.min_value);
472 let max = self.compile_variable(op.max_value);
473 let out = self.compile_variable(out);
474 let out_ty = out.item();
475
476 let input = self.read_as(&input, &out_ty);
477 let min = self.read_as(&min, &out_ty);
478 let max = self.read_as(&max, &out_ty);
479 let out_id = self.write_id(&out);
480 self.mark_uniformity(out_id, uniform);
481
482 let ty = out_ty.id(self);
483
484 match out_ty.elem() {
485 Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
486 Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
487 Elem::Float(..) => {
488 self.declare_math_mode(modes, out_id);
489 T::f_clamp(self, ty, input, min, max, out_id)
490 }
491 Elem::Relaxed => {
492 self.decorate(out_id, Decoration::RelaxedPrecision, []);
493 self.declare_math_mode(modes, out_id);
494 T::f_clamp(self, ty, input, min, max, out_id)
495 }
496 _ => unreachable!(),
497 }
498 self.write(&out, out_id);
499 }
500
501 Arithmetic::Max(op) => self.compile_binary_op(
502 op,
503 out,
504 uniform,
505 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
506 Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
507 Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
508 Elem::Float(..) => {
509 b.declare_math_mode(modes, out);
510 T::f_max(b, ty, lhs, rhs, out)
511 }
512 Elem::Relaxed => {
513 b.decorate(out, Decoration::RelaxedPrecision, []);
514 b.declare_math_mode(modes, out);
515 T::f_max(b, ty, lhs, rhs, out)
516 }
517 _ => unreachable!(),
518 },
519 ),
520 Arithmetic::Min(op) => self.compile_binary_op(
521 op,
522 out,
523 uniform,
524 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
525 Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
526 Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
527 Elem::Float(..) => {
528 b.declare_math_mode(modes, out);
529 T::f_min(b, ty, lhs, rhs, out)
530 }
531 Elem::Relaxed => {
532 b.decorate(out, Decoration::RelaxedPrecision, []);
533 b.declare_math_mode(modes, out);
534 T::f_min(b, ty, lhs, rhs, out)
535 }
536 _ => unreachable!(),
537 },
538 ),
539 }
540 }
541}