1use crate::{
2 flex32,
3 frontend::{CubePrimitive, ExpandElementTyped},
4 prelude::*,
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9 unexpanded,
10};
11use crate::{
12 ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope},
13 prelude::assign_op_expand,
14};
15use core::ops::*;
16use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
17use cubecl_ir::ClampOperator;
18use half::{bf16, f16};
19
20pub mod add {
21 use super::*;
22
23 pub fn expand<C: CubePrimitive>(
24 scope: &mut Scope,
25 lhs: ExpandElementTyped<C>,
26 rhs: ExpandElementTyped<C>,
27 ) -> ExpandElementTyped<C> {
28 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
29 }
30}
31
32pub mod sub {
33 use super::*;
34
35 pub fn expand<C: CubePrimitive>(
36 scope: &mut Scope,
37 lhs: ExpandElementTyped<C>,
38 rhs: ExpandElementTyped<C>,
39 ) -> ExpandElementTyped<C> {
40 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
41 }
42}
43
44pub mod mul {
45 use super::*;
46
47 pub fn expand<C: CubePrimitive>(
48 scope: &mut Scope,
49 lhs: ExpandElementTyped<C>,
50 rhs: ExpandElementTyped<C>,
51 ) -> ExpandElementTyped<C> {
52 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
53 }
54}
55
56pub mod div {
57 use super::*;
58
59 pub fn expand<C: CubePrimitive>(
60 scope: &mut Scope,
61 lhs: ExpandElementTyped<C>,
62 rhs: ExpandElementTyped<C>,
63 ) -> ExpandElementTyped<C> {
64 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
65 }
66}
67
68pub mod rem {
69 use super::*;
70
71 pub fn expand<C: CubePrimitive>(
72 scope: &mut Scope,
73 lhs: ExpandElementTyped<C>,
74 rhs: ExpandElementTyped<C>,
75 ) -> ExpandElementTyped<C> {
76 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
77 }
78}
79
80pub mod and {
81 use super::*;
82
83 pub fn expand<C: CubePrimitive>(
84 scope: &mut Scope,
85 lhs: ExpandElementTyped<C>,
86 rhs: ExpandElementTyped<C>,
87 ) -> ExpandElementTyped<bool> {
88 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
89 }
90}
91
92pub mod bitand {
93 use super::*;
94
95 pub fn expand<C: CubePrimitive>(
96 scope: &mut Scope,
97 lhs: ExpandElementTyped<C>,
98 rhs: ExpandElementTyped<C>,
99 ) -> ExpandElementTyped<C> {
100 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
101 }
102}
103
104pub mod bitor {
105 use super::*;
106
107 pub fn expand<C: CubePrimitive>(
108 scope: &mut Scope,
109 lhs: ExpandElementTyped<C>,
110 rhs: ExpandElementTyped<C>,
111 ) -> ExpandElementTyped<C> {
112 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
113 }
114}
115
116pub mod or {
117 use super::*;
118
119 pub fn expand<C: CubePrimitive>(
120 scope: &mut Scope,
121 lhs: ExpandElementTyped<C>,
122 rhs: ExpandElementTyped<C>,
123 ) -> ExpandElementTyped<bool> {
124 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
125 }
126}
127
128pub mod bitxor {
129 use super::*;
130
131 pub fn expand<C: CubePrimitive>(
132 scope: &mut Scope,
133 lhs: ExpandElementTyped<C>,
134 rhs: ExpandElementTyped<C>,
135 ) -> ExpandElementTyped<C> {
136 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
137 }
138}
139
140pub mod shl {
141 use super::*;
142
143 pub fn expand<C: CubePrimitive>(
144 scope: &mut Scope,
145 lhs: ExpandElementTyped<C>,
146 rhs: ExpandElementTyped<C>,
147 ) -> ExpandElementTyped<C> {
148 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
149 }
150}
151
152pub mod shr {
153 use super::*;
154
155 pub fn expand<C: CubePrimitive>(
156 scope: &mut Scope,
157 lhs: ExpandElementTyped<C>,
158 rhs: ExpandElementTyped<C>,
159 ) -> ExpandElementTyped<C> {
160 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
161 }
162}
163
164pub mod clamp {
165 use super::*;
166
167 pub fn expand<C: PartialOrd + CubePrimitive>(
168 scope: &mut Scope,
169 input: ExpandElementTyped<C>,
170 min: ExpandElementTyped<C>,
171 max: ExpandElementTyped<C>,
172 ) -> ExpandElementTyped<C> {
173 unary_expand(scope, input.into(), |op| {
174 Arithmetic::Clamp(ClampOperator {
175 input: op.input,
176 min_value: *min.expand,
177 max_value: *max.expand,
178 })
179 })
180 .into()
181 }
182}
183
184pub mod clamp_max {
185 use super::*;
186
187 pub fn expand<C: PartialOrd + CubePrimitive>(
188 scope: &mut Scope,
189 lhs: ExpandElementTyped<C>,
190 rhs: ExpandElementTyped<C>,
191 ) -> ExpandElementTyped<C> {
192 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
193 }
194}
195
196pub mod clamp_min {
197 use super::*;
198
199 pub fn expand<C: PartialOrd + CubePrimitive>(
200 scope: &mut Scope,
201 lhs: ExpandElementTyped<C>,
202 rhs: ExpandElementTyped<C>,
203 ) -> ExpandElementTyped<C> {
204 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
205 }
206}
207
208pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
211 clamp_max(lhs, rhs)
212}
213
214pub mod min {
215 use super::*;
216
217 pub fn expand<C: PartialOrd + CubePrimitive>(
218 scope: &mut Scope,
219 lhs: ExpandElementTyped<C>,
220 rhs: ExpandElementTyped<C>,
221 ) -> ExpandElementTyped<C> {
222 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
223 }
224}
225
226pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
229 clamp_min(lhs, rhs)
230}
231
232pub mod max {
233 use super::*;
234
235 pub fn expand<C: PartialOrd + CubePrimitive>(
236 scope: &mut Scope,
237 lhs: ExpandElementTyped<C>,
238 rhs: ExpandElementTyped<C>,
239 ) -> ExpandElementTyped<C> {
240 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
241 }
242}
243
/// Defines a same-typed binary function trait `$name` (method `$func`) and its expand-side
/// trait `<$name>Expand`, then implements `$name` for every listed primitive.
///
/// Calling `$func` outside a kernel hits `unexpanded!()`. The expand path lowers to
/// `binary_expand` with the operator constructor `$op`.
macro_rules! impl_binary_func {
    ($name:ident, $func:ident, $op:expr, $($prim:ty),*) => {
        paste::paste! {
            // Expand-side trait, implemented on the expanded (typed) value.
            pub trait [<$name Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                fn $func(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Associated-function form; delegates to the method form on the left operand.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: CubePrimitive + $name> [<$name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let out = binary_expand(scope, self.into(), rhs.into(), $op);
                    out.into()
                }
            }

            $(impl $name for $prim {})*
        }
    };
}
275
/// Like `impl_binary_func`, but the output type is forced: it is derived from the left
/// operand's type with line size `$out_line_size`, and the expansion goes through
/// `binary_expand_fixed_output`.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($name:ident, $func:ident, $op:expr, $out_line_size:expr, $($prim:ty),*) => {
        paste::paste! {
            // Expand-side trait, implemented on the expanded (typed) value.
            pub trait [<$name Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                fn $func(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Associated-function form; delegates to the method form on the left operand.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: CubePrimitive + $name> [<$name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let lhs: ExpandElement = self.into();
                    // The output keeps the left operand's element type but uses the fixed
                    // line size given to the macro.
                    let out_ty = lhs.ty.line($out_line_size);
                    let out = binary_expand_fixed_output(scope, lhs, rhs.into(), out_ty, $op);
                    out.into()
                }
            }

            $(impl $name for $prim {})*
        }
    };
}
308
/// Defines a binary function trait `$name<Rhs>` whose right operand may have a different
/// primitive type than `Self`. Each listed primitive gets `impl $name<$rhs> for _`.
///
/// Unlike `impl_binary_func`, the associated expand function calls `binary_expand` directly
/// instead of going through the expand-side trait.
macro_rules! impl_binary_func_mixed_types {
    ($name:ident, $func:ident, $rhs:ident, $op:expr, $($prim:ty),*) => {
        paste::paste! {
            // Expand-side trait, generic over the right operand's cube type.
            pub trait [<$name Expand>]<Rhs: CubeType> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            pub trait $name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ExpandElement>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$name Expand>]<Rhs>> + Sized {
                fn $func(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    let out = binary_expand(scope, lhs.into(), rhs.into(), $op);
                    out.into()
                }
            }

            impl<Rhs: CubePrimitive, T: CubePrimitive + $name<Rhs>> [<$name Expand>]<Rhs> for ExpandElementTyped<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: ExpandElementTyped<Rhs>) -> Self {
                    let out = binary_expand(scope, self.into(), rhs.into(), $op);
                    out.into()
                }
            }

            $(impl $name<$rhs> for $prim {})*
        }
    };
}
340
/// Bridges a `core::ops` binary operator trait (`Add`, `Sub`, ...) into cube expansion.
///
/// Generates `Cube<Trait>` and `<Trait>Expand`, plus blanket impls for every `CubePrimitive`
/// whose core impl has `Output = Self`. The expansion uses `binary_expand` with `$op`.
macro_rules! impl_core_binop {
    ($op_trait:ident, $func:ident, $op:expr) => {
        paste::paste! {
            // Expand-side trait, implemented on the expanded (typed) value.
            pub trait [<$op_trait Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait [<Cube $op_trait>]:
                $op_trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$op_trait Expand>]> + Sized
            {
                // Associated-function form; delegates to the method form on the left operand.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: $op_trait<Output = T> + CubePrimitive> [<Cube $op_trait>] for T {}

            impl<T: $op_trait<Output = T> + CubePrimitive> [<$op_trait Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let out = binary_expand(scope, self.into(), rhs.into(), $op);
                    out.into()
                }
            }
        }
    };
}
367
/// Bridges a `core::ops` compound-assignment trait (`AddAssign`, ...) into cube expansion.
///
/// Generates `Cube<Trait>` and `<Trait>Expand`, plus blanket impls for every `CubePrimitive`
/// implementing the core trait. The expansion calls `assign_op_expand` with `$op` and returns
/// nothing.
macro_rules! impl_core_assign_binop {
    ($op_trait:ident, $func:ident, $op:expr) => {
        paste::paste! {
            // Expand-side trait, implemented on the expanded (typed) value.
            pub trait [<$op_trait Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self);
            }

            pub trait [<Cube $op_trait>]:
                $op_trait + CubePrimitive + CubeType<ExpandType: [<$op_trait Expand>]> + Sized
            {
                // Associated-function form; delegates to the method form on the left operand.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: $op_trait + CubePrimitive> [<Cube $op_trait>] for T {}

            impl<T: $op_trait + CubePrimitive> [<$op_trait Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) {
                    let (target, value): (ExpandElement, ExpandElement) = (self.into(), rhs.into());
                    assign_op_expand(scope, target, value, $op);
                }
            }
        }
    };
}
394
395impl_core_binop!(Add, add, Arithmetic::Add);
396impl_core_binop!(Sub, sub, Arithmetic::Sub);
397impl_core_binop!(Mul, mul, Arithmetic::Mul);
398impl_core_binop!(Div, mul, Arithmetic::Div);
399impl_core_binop!(Rem, rem, Arithmetic::Modulo);
400
401impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
402impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
403impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
404impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
405impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
406
407pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
408 fn __expand_min(
409 scope: &mut Scope,
410 lhs: Self::ExpandType,
411 rhs: Self::ExpandType,
412 ) -> Self::ExpandType {
413 lhs.__expand_min_method(scope, rhs)
414 }
415
416 fn __expand_max(
417 scope: &mut Scope,
418 lhs: Self::ExpandType,
419 rhs: Self::ExpandType,
420 ) -> Self::ExpandType {
421 lhs.__expand_max_method(scope, rhs)
422 }
423
424 fn __expand_clamp(
425 scope: &mut Scope,
426 lhs: Self::ExpandType,
427 min: Self::ExpandType,
428 max: Self::ExpandType,
429 ) -> Self::ExpandType {
430 lhs.__expand_clamp_method(scope, min, max)
431 }
432}
433pub trait OrdExpand {
434 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
435 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
436 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
437}
438
439impl<T: Ord + CubePrimitive> CubeOrd for T {}
440impl<T: Ord + CubePrimitive> OrdExpand for ExpandElementTyped<T> {
441 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
442 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
443 }
444 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
445 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
446 }
447 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
448 unary_expand(scope, self.into(), |op| {
449 Arithmetic::Clamp(ClampOperator {
450 input: op.input,
451 min_value: *min.expand,
452 max_value: *max.expand,
453 })
454 })
455 .into()
456 }
457}
458
459impl_binary_func!(
460 Powf,
461 powf,
462 Arithmetic::Powf,
463 f16,
464 bf16,
465 flex32,
466 tf32,
467 f32,
468 f64
469);
470
471impl_binary_func!(
472 Hypot,
473 hypot,
474 Arithmetic::Hypot,
475 f16,
476 bf16,
477 flex32,
478 tf32,
479 f32,
480 f64
481);
482
483impl_binary_func!(
484 Rhypot,
485 rhypot,
486 Arithmetic::Rhypot,
487 f16,
488 bf16,
489 flex32,
490 tf32,
491 f32,
492 f64
493);
494
495impl_binary_func!(
496 ArcTan2,
497 atan2,
498 Arithmetic::ArcTan2,
499 f16,
500 bf16,
501 flex32,
502 tf32,
503 f32,
504 f64
505);
506impl_binary_func!(
507 Remainder,
508 rem,
509 Arithmetic::Remainder,
510 e2m1,
511 e4m3,
512 e5m2,
513 ue8m0,
514 f16,
515 bf16,
516 flex32,
517 tf32,
518 f32,
519 f64,
520 i8,
521 i16,
522 i32,
523 i64,
524 u8,
525 u16,
526 u32,
527 u64,
528 usize,
529 isize
530);
531impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
532impl_binary_func!(
533 SaturatingAdd,
534 saturating_add,
535 Arithmetic::SaturatingAdd,
536 i8,
537 i16,
538 i32,
539 i64,
540 u8,
541 u16,
542 u32,
543 u64,
544 usize,
545 isize
546);
547impl_binary_func!(
548 SaturatingSub,
549 saturating_sub,
550 Arithmetic::SaturatingSub,
551 i8,
552 i16,
553 i32,
554 i64,
555 u8,
556 u16,
557 u32,
558 u64,
559 usize,
560 isize
561);
562impl_binary_func_fixed_output_vectorization!(
563 Dot,
564 dot,
565 Arithmetic::Dot,
566 0,
567 f16,
568 bf16,
569 flex32,
570 tf32,
571 f32,
572 f64,
573 i8,
574 i16,
575 i32,
576 i64,
577 u8,
578 u16,
579 u32,
580 u64,
581 usize,
582 isize
583);
584
585impl_binary_func_mixed_types!(
586 Powi,
587 powi,
588 i32,
589 Arithmetic::Powi,
590 f16,
591 bf16,
592 flex32,
593 tf32,
594 f32,
595 f64,
596 i8,
597 i16,
598 i32,
599 i64,
600 u8,
601 u16,
602 u32,
603 u64,
604 usize,
605 isize
606);