1use crate::{
2 SpirvCompiler, SpirvTarget,
3 item::{Elem, Item},
4 variable::ConstVal,
5};
6use cubecl_core::ir::{self as core, Arithmetic};
7use rspirv::spirv::{Capability, Decoration, FPEncoding};
8
9impl<T: SpirvTarget> SpirvCompiler<T> {
10 pub fn compile_arithmetic(
11 &mut self,
12 op: Arithmetic,
13 out: Option<core::Variable>,
14 uniform: bool,
15 ) {
16 let out = out.unwrap();
17 match op {
18 Arithmetic::Add(op) => {
19 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
20 match out_ty.elem() {
21 Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(),
22 Elem::Float(..) => b.f_add(ty, Some(out), lhs, rhs).unwrap(),
23 Elem::Relaxed => {
24 b.decorate(out, Decoration::RelaxedPrecision, []);
25 b.f_add(ty, Some(out), lhs, rhs).unwrap()
26 }
27 _ => unreachable!(),
28 };
29 });
30 }
31 Arithmetic::SaturatingAdd(_) => {
32 unimplemented!("Should be replaced by polyfill");
33 }
34 Arithmetic::Sub(op) => {
35 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
36 match out_ty.elem() {
37 Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(),
38 Elem::Float(..) => b.f_sub(ty, Some(out), lhs, rhs).unwrap(),
39 Elem::Relaxed => {
40 b.decorate(out, Decoration::RelaxedPrecision, []);
41 b.f_sub(ty, Some(out), lhs, rhs).unwrap()
42 }
43 _ => unreachable!(),
44 };
45 });
46 }
47 Arithmetic::SaturatingSub(_) => {
48 unimplemented!("Should be replaced by polyfill");
49 }
50 Arithmetic::Mul(op) => {
51 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
52 match out_ty.elem() {
53 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
54 Elem::Float(..) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(),
55 Elem::Relaxed => {
56 b.decorate(out, Decoration::RelaxedPrecision, []);
57 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
58 }
59 _ => unreachable!(),
60 };
61 });
62 }
63 Arithmetic::MulHi(op) => {
64 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
65 let out_st = b.type_struct([ty, ty]);
66 let extended = match out_ty.elem() {
67 Elem::Int(_, false) => b.u_mul_extended(out_st, None, lhs, rhs).unwrap(),
68 Elem::Int(_, true) => b.s_mul_extended(out_st, None, lhs, rhs).unwrap(),
69 _ => unreachable!(),
70 };
71 b.composite_extract(ty, Some(out), extended, [1]).unwrap();
72 });
73 }
74 Arithmetic::Div(op) => {
75 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
76 match out_ty.elem() {
77 Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(),
78 Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(),
79 Elem::Float(..) => b.f_div(ty, Some(out), lhs, rhs).unwrap(),
80 Elem::Relaxed => {
81 b.decorate(out, Decoration::RelaxedPrecision, []);
82 b.f_div(ty, Some(out), lhs, rhs).unwrap()
83 }
84 _ => unreachable!(),
85 };
86 });
87 }
88 Arithmetic::Remainder(op) => {
89 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
90 match out_ty.elem() {
91 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
92 Elem::Int(_, true) => b.s_mod(ty, Some(out), lhs, rhs).unwrap(),
93 Elem::Float(..) => b.f_mod(ty, Some(out), lhs, rhs).unwrap(),
94 Elem::Relaxed => {
95 b.decorate(out, Decoration::RelaxedPrecision, []);
96 b.f_mod(ty, Some(out), lhs, rhs).unwrap()
97 }
98 _ => unreachable!(),
99 };
100 });
101 }
102 Arithmetic::Modulo(op) => {
103 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
104 match out_ty.elem() {
105 Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(),
106 Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(),
107 Elem::Float(..) => b.f_rem(ty, Some(out), lhs, rhs).unwrap(),
108 Elem::Relaxed => {
109 b.decorate(out, Decoration::RelaxedPrecision, []);
110 b.f_rem(ty, Some(out), lhs, rhs).unwrap()
111 }
112 _ => unreachable!(),
113 };
114 });
115 }
116 Arithmetic::Dot(op) => {
117 if op.lhs.ty.line_size() == 1 {
118 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
119 match out_ty.elem() {
120 Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(),
121 Elem::Float(..) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(),
122 Elem::Relaxed => {
123 b.decorate(out, Decoration::RelaxedPrecision, []);
124 b.f_mul(ty, Some(out), lhs, rhs).unwrap()
125 }
126 _ => unreachable!(),
127 };
128 });
129 } else {
130 let lhs = self.compile_variable(op.lhs);
131 let rhs = self.compile_variable(op.rhs);
132 let out = self.compile_variable(out);
133 let ty = out.item().id(self);
134
135 let lhs_id = self.read(&lhs);
136 let rhs_id = self.read(&rhs);
137 let out_id = self.write_id(&out);
138 self.mark_uniformity(out_id, uniform);
139
140 if matches!(lhs.elem(), Elem::Int(_, _)) {
141 self.capabilities.insert(Capability::DotProduct);
142 }
143 if matches!(lhs.elem(), Elem::Float(16, Some(FPEncoding::BFloat16KHR))) {
144 self.capabilities.insert(Capability::BFloat16DotProductKHR);
145 }
146
147 match (lhs.elem(), rhs.elem()) {
148 (Elem::Int(_, false), Elem::Int(_, false)) => {
149 self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None)
150 }
151 (Elem::Int(_, true), Elem::Int(_, false)) => {
152 self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None)
153 }
154 (Elem::Int(_, false), Elem::Int(_, true)) => {
155 self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None)
156 }
157 (Elem::Int(_, true), Elem::Int(_, true)) => {
158 self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None)
159 }
160 (Elem::Float(..), Elem::Float(..))
161 | (Elem::Relaxed, Elem::Float(..))
162 | (Elem::Float(..), Elem::Relaxed) => {
163 self.dot(ty, Some(out_id), lhs_id, rhs_id)
164 }
165 (Elem::Relaxed, Elem::Relaxed) => {
166 self.decorate(out_id, Decoration::RelaxedPrecision, []);
167 self.dot(ty, Some(out_id), lhs_id, rhs_id)
168 }
169 _ => unreachable!(),
170 }
171 .unwrap();
172 self.write(&out, out_id);
173 }
174 }
175 Arithmetic::Fma(op) => {
176 let a = self.compile_variable(op.a);
177 let b = self.compile_variable(op.b);
178 let c = self.compile_variable(op.c);
179 let out = self.compile_variable(out);
180 let out_ty = out.item();
181 let relaxed = matches!(
182 (a.item().elem(), b.item().elem(), c.item().elem()),
183 (Elem::Relaxed, Elem::Relaxed, Elem::Relaxed)
184 );
185
186 let a_id = self.read_as(&a, &out_ty);
187 let b_id = self.read_as(&b, &out_ty);
188 let c_id = self.read_as(&c, &out_ty);
189 let out_id = self.write_id(&out);
190 self.mark_uniformity(out_id, uniform);
191
192 let ty = out_ty.id(self);
193
194 let mul = self.f_mul(ty, None, a_id, b_id).unwrap();
195 self.mark_uniformity(mul, uniform);
196 self.f_add(ty, Some(out_id), mul, c_id).unwrap();
197 if relaxed {
198 self.decorate(mul, Decoration::RelaxedPrecision, []);
199 self.decorate(out_id, Decoration::RelaxedPrecision, []);
200 }
201 self.write(&out, out_id);
202 }
203 Arithmetic::Recip(op) => {
204 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
205 let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
206 b.f_div(ty, Some(out), one, input).unwrap();
207 });
208 }
209 Arithmetic::Neg(op) => {
210 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
211 match out_ty.elem() {
212 Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(),
213 Elem::Float(..) => b.f_negate(ty, Some(out), input).unwrap(),
214 Elem::Relaxed => {
215 b.decorate(out, Decoration::RelaxedPrecision, []);
216 b.f_negate(ty, Some(out), input).unwrap()
217 }
218 _ => unreachable!(),
219 };
220 });
221 }
222 Arithmetic::Erf(_) => {
223 unreachable!("Replaced by transformer")
224 }
225
226 Arithmetic::Normalize(op) => {
228 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
229 T::normalize(b, ty, input, out);
230 if matches!(out_ty.elem(), Elem::Relaxed) {
231 b.decorate(out, Decoration::RelaxedPrecision, []);
232 }
233 });
234 }
235 Arithmetic::Magnitude(op) => {
236 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
237 T::magnitude(b, ty, input, out);
238 if matches!(out_ty.elem(), Elem::Relaxed) {
239 b.decorate(out, Decoration::RelaxedPrecision, []);
240 }
241 });
242 }
243 Arithmetic::Abs(op) => {
244 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
245 match out_ty.elem() {
246 Elem::Int(_, _) => T::s_abs(b, ty, input, out),
247 Elem::Float(..) => T::f_abs(b, ty, input, out),
248 Elem::Relaxed => {
249 b.decorate(out, Decoration::RelaxedPrecision, []);
250 T::f_abs(b, ty, input, out)
251 }
252 _ => unreachable!(),
253 }
254 });
255 }
256 Arithmetic::Exp(op) => {
257 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
258 T::exp(b, ty, input, out);
259 if matches!(out_ty.elem(), Elem::Relaxed) {
260 b.decorate(out, Decoration::RelaxedPrecision, []);
261 }
262 });
263 }
264 Arithmetic::Log(op) => {
265 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
266 T::log(b, ty, input, out);
267 if matches!(out_ty.elem(), Elem::Relaxed) {
268 b.decorate(out, Decoration::RelaxedPrecision, []);
269 }
270 })
271 }
272 Arithmetic::Log1p(op) => {
273 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
274 let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty);
275 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
276 let add = match out_ty.elem() {
277 Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(),
278 Elem::Float(..) | Elem::Relaxed => b.f_add(ty, None, input, one).unwrap(),
279 _ => unreachable!(),
280 };
281 b.mark_uniformity(add, uniform);
282 if relaxed {
283 b.decorate(add, Decoration::RelaxedPrecision, []);
284 b.decorate(out, Decoration::RelaxedPrecision, []);
285 }
286 T::log(b, ty, add, out)
287 });
288 }
289 Arithmetic::Cos(op) => {
290 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
291 T::cos(b, ty, input, out);
292 if matches!(out_ty.elem(), Elem::Relaxed) {
293 b.decorate(out, Decoration::RelaxedPrecision, []);
294 }
295 })
296 }
297 Arithmetic::Sin(op) => {
298 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
299 T::sin(b, ty, input, out);
300 if matches!(out_ty.elem(), Elem::Relaxed) {
301 b.decorate(out, Decoration::RelaxedPrecision, []);
302 }
303 })
304 }
305 Arithmetic::Tanh(op) => {
306 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
307 T::tanh(b, ty, input, out);
308 if matches!(out_ty.elem(), Elem::Relaxed) {
309 b.decorate(out, Decoration::RelaxedPrecision, []);
310 }
311 })
312 }
313 Arithmetic::Powf(op) | Arithmetic::Powi(op) => {
315 self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| {
316 let bool = match out_ty {
317 Item::Scalar(_) => Elem::Bool.id(b),
318 Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b),
319 _ => unreachable!(),
320 };
321 let relaxed = matches!(out_ty.elem(), Elem::Relaxed);
322 let zero = out_ty.const_u32(b, 0);
323 let one = out_ty.const_u32(b, 1);
324 let two = out_ty.const_u32(b, 2);
325 let modulo = b.f_rem(ty, None, rhs, two).unwrap();
326 let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap();
327 let abs = b.id();
328 T::f_abs(b, ty, lhs, abs);
329 let even = b.id();
330 T::pow(b, ty, abs, rhs, even);
331 let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap();
332 let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap();
333 let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap();
334 let neg_lhs = b.f_negate(ty, None, lhs).unwrap();
335 let pow2 = b.id();
336 T::pow(b, ty, neg_lhs, rhs, pow2);
337 let pow2_neg = b.f_negate(ty, None, pow2).unwrap();
338 let default = b.id();
339 T::pow(b, ty, lhs, rhs, default);
340 let ids = [
341 modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg,
342 default,
343 ];
344 for id in ids {
345 b.mark_uniformity(id, uniform);
346 if relaxed {
347 b.decorate(id, Decoration::RelaxedPrecision, []);
348 }
349 }
350 let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap();
351 b.mark_uniformity(sel1, uniform);
352 b.select(ty, Some(out), is_zero, even, sel1).unwrap();
353 })
354 }
355 Arithmetic::Sqrt(op) => {
356 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
357 T::sqrt(b, ty, input, out);
358 if matches!(out_ty.elem(), Elem::Relaxed) {
359 b.decorate(out, Decoration::RelaxedPrecision, []);
360 }
361 })
362 }
363 Arithmetic::Round(op) => {
364 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
365 T::round(b, ty, input, out);
366 if matches!(out_ty.elem(), Elem::Relaxed) {
367 b.decorate(out, Decoration::RelaxedPrecision, []);
368 }
369 })
370 }
371 Arithmetic::Floor(op) => {
372 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
373 T::floor(b, ty, input, out);
374 if matches!(out_ty.elem(), Elem::Relaxed) {
375 b.decorate(out, Decoration::RelaxedPrecision, []);
376 }
377 })
378 }
379 Arithmetic::Ceil(op) => {
380 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
381 T::ceil(b, ty, input, out);
382 if matches!(out_ty.elem(), Elem::Relaxed) {
383 b.decorate(out, Decoration::RelaxedPrecision, []);
384 }
385 })
386 }
387 Arithmetic::Trunc(op) => {
388 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
389 T::trunc(b, ty, input, out);
390 if matches!(out_ty.elem(), Elem::Relaxed) {
391 b.decorate(out, Decoration::RelaxedPrecision, []);
392 }
393 })
394 }
395 Arithmetic::Clamp(op) => {
396 let input = self.compile_variable(op.input);
397 let min = self.compile_variable(op.min_value);
398 let max = self.compile_variable(op.max_value);
399 let out = self.compile_variable(out);
400 let out_ty = out.item();
401
402 let input = self.read_as(&input, &out_ty);
403 let min = self.read_as(&min, &out_ty);
404 let max = self.read_as(&max, &out_ty);
405 let out_id = self.write_id(&out);
406 self.mark_uniformity(out_id, uniform);
407
408 let ty = out_ty.id(self);
409
410 match out_ty.elem() {
411 Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id),
412 Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id),
413 Elem::Float(..) => T::f_clamp(self, ty, input, min, max, out_id),
414 Elem::Relaxed => {
415 self.decorate(out_id, Decoration::RelaxedPrecision, []);
416 T::f_clamp(self, ty, input, min, max, out_id)
417 }
418 _ => unreachable!(),
419 }
420 self.write(&out, out_id);
421 }
422
423 Arithmetic::Max(op) => self.compile_binary_op(
424 op,
425 out,
426 uniform,
427 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
428 Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out),
429 Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out),
430 Elem::Float(..) => T::f_max(b, ty, lhs, rhs, out),
431 Elem::Relaxed => {
432 b.decorate(out, Decoration::RelaxedPrecision, []);
433 T::f_max(b, ty, lhs, rhs, out)
434 }
435 _ => unreachable!(),
436 },
437 ),
438 Arithmetic::Min(op) => self.compile_binary_op(
439 op,
440 out,
441 uniform,
442 |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() {
443 Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out),
444 Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out),
445 Elem::Float(..) => T::f_min(b, ty, lhs, rhs, out),
446 Elem::Relaxed => {
447 b.decorate(out, Decoration::RelaxedPrecision, []);
448 T::f_min(b, ty, lhs, rhs, out)
449 }
450 _ => unreachable!(),
451 },
452 ),
453 }
454 }
455}