1use crate::{
2 flex32,
3 frontend::{CubePrimitive, ExpandElementTyped},
4 prelude::*,
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9 unexpanded,
10};
11use crate::{
12 ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope},
13 prelude::assign_op_expand,
14};
15use core::ops::*;
16use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
17use cubecl_ir::ClampOperator;
18use half::{bf16, f16};
19
20pub mod add {
21 use super::*;
22
23 pub fn expand<C: CubePrimitive>(
24 scope: &mut Scope,
25 lhs: ExpandElementTyped<C>,
26 rhs: ExpandElementTyped<C>,
27 ) -> ExpandElementTyped<C> {
28 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
29 }
30}
31
32pub mod sub {
33 use cubecl_ir::{ConstantValue, Variable};
34
35 use super::*;
36
37 pub fn expand<C: CubePrimitive>(
38 scope: &mut Scope,
39 lhs: ExpandElementTyped<C>,
40 rhs: ExpandElementTyped<C>,
41 ) -> ExpandElementTyped<C> {
42 match (lhs.expand.as_const(), rhs.expand.as_const()) {
44 (Some(ConstantValue::UInt(lhs_val)), Some(ConstantValue::UInt(rhs_val))) => {
45 let item_lhs = lhs.expand.ty;
46 let item_rhs = rhs.expand.ty;
47
48 let line_size = find_vectorization(item_lhs, item_rhs);
49
50 let item = item_lhs.line(line_size);
51 let value = (lhs_val - rhs_val).into();
52 ExpandElement::Plain(Variable::constant(value, item)).into()
53 }
54 _ => binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into(),
55 }
56 }
57}
58
59pub mod mul {
60 use super::*;
61
62 pub fn expand<C: CubePrimitive>(
63 scope: &mut Scope,
64 lhs: ExpandElementTyped<C>,
65 rhs: ExpandElementTyped<C>,
66 ) -> ExpandElementTyped<C> {
67 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
68 }
69}
70
71pub mod div {
72 use super::*;
73
74 pub fn expand<C: CubePrimitive>(
75 scope: &mut Scope,
76 lhs: ExpandElementTyped<C>,
77 rhs: ExpandElementTyped<C>,
78 ) -> ExpandElementTyped<C> {
79 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
80 }
81}
82
83pub mod rem {
84 use super::*;
85
86 pub fn expand<C: CubePrimitive>(
87 scope: &mut Scope,
88 lhs: ExpandElementTyped<C>,
89 rhs: ExpandElementTyped<C>,
90 ) -> ExpandElementTyped<C> {
91 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
92 }
93}
94
95pub mod and {
96 use super::*;
97
98 pub fn expand<C: CubePrimitive>(
99 scope: &mut Scope,
100 lhs: ExpandElementTyped<C>,
101 rhs: ExpandElementTyped<C>,
102 ) -> ExpandElementTyped<bool> {
103 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
104 }
105}
106
107pub mod bitand {
108 use super::*;
109
110 pub fn expand<C: CubePrimitive>(
111 scope: &mut Scope,
112 lhs: ExpandElementTyped<C>,
113 rhs: ExpandElementTyped<C>,
114 ) -> ExpandElementTyped<C> {
115 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
116 }
117}
118
119pub mod bitor {
120 use super::*;
121
122 pub fn expand<C: CubePrimitive>(
123 scope: &mut Scope,
124 lhs: ExpandElementTyped<C>,
125 rhs: ExpandElementTyped<C>,
126 ) -> ExpandElementTyped<C> {
127 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
128 }
129}
130
131pub mod or {
132 use super::*;
133
134 pub fn expand<C: CubePrimitive>(
135 scope: &mut Scope,
136 lhs: ExpandElementTyped<C>,
137 rhs: ExpandElementTyped<C>,
138 ) -> ExpandElementTyped<bool> {
139 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
140 }
141}
142
143pub mod bitxor {
144 use super::*;
145
146 pub fn expand<C: CubePrimitive>(
147 scope: &mut Scope,
148 lhs: ExpandElementTyped<C>,
149 rhs: ExpandElementTyped<C>,
150 ) -> ExpandElementTyped<C> {
151 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
152 }
153}
154
155pub mod shl {
156 use super::*;
157
158 pub fn expand<C: CubePrimitive>(
159 scope: &mut Scope,
160 lhs: ExpandElementTyped<C>,
161 rhs: ExpandElementTyped<C>,
162 ) -> ExpandElementTyped<C> {
163 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
164 }
165}
166
167pub mod shr {
168 use super::*;
169
170 pub fn expand<C: CubePrimitive>(
171 scope: &mut Scope,
172 lhs: ExpandElementTyped<C>,
173 rhs: ExpandElementTyped<C>,
174 ) -> ExpandElementTyped<C> {
175 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
176 }
177}
178
179pub mod clamp {
180 use super::*;
181
182 pub fn expand<C: PartialOrd + CubePrimitive>(
183 scope: &mut Scope,
184 input: ExpandElementTyped<C>,
185 min: ExpandElementTyped<C>,
186 max: ExpandElementTyped<C>,
187 ) -> ExpandElementTyped<C> {
188 unary_expand(scope, input.into(), |op| {
189 Arithmetic::Clamp(ClampOperator {
190 input: op.input,
191 min_value: *min.expand,
192 max_value: *max.expand,
193 })
194 })
195 .into()
196 }
197}
198
199pub mod clamp_max {
200 use super::*;
201
202 pub fn expand<C: PartialOrd + CubePrimitive>(
203 scope: &mut Scope,
204 lhs: ExpandElementTyped<C>,
205 rhs: ExpandElementTyped<C>,
206 ) -> ExpandElementTyped<C> {
207 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
208 }
209}
210
211pub mod clamp_min {
212 use super::*;
213
214 pub fn expand<C: PartialOrd + CubePrimitive>(
215 scope: &mut Scope,
216 lhs: ExpandElementTyped<C>,
217 rhs: ExpandElementTyped<C>,
218 ) -> ExpandElementTyped<C> {
219 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
220 }
221}
222
/// Returns the smaller of `lhs` and `rhs`.
///
/// Delegates to `clamp_max(lhs, rhs)`, a function defined outside this view. The
/// kernel-side lowering lives in `min::expand`, which emits `Arithmetic::Min`.
pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
    clamp_max(lhs, rhs)
}
228
229pub mod min {
230 use super::*;
231
232 pub fn expand<C: PartialOrd + CubePrimitive>(
233 scope: &mut Scope,
234 lhs: ExpandElementTyped<C>,
235 rhs: ExpandElementTyped<C>,
236 ) -> ExpandElementTyped<C> {
237 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
238 }
239}
240
/// Returns the larger of `lhs` and `rhs`.
///
/// Delegates to `clamp_min(lhs, rhs)`, a function defined outside this view. The
/// kernel-side lowering lives in `max::expand`, which emits `Arithmetic::Max`.
pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
    clamp_min(lhs, rhs)
}
246
247pub mod max {
248 use super::*;
249
250 pub fn expand<C: PartialOrd + CubePrimitive>(
251 scope: &mut Scope,
252 lhs: ExpandElementTyped<C>,
253 rhs: ExpandElementTyped<C>,
254 ) -> ExpandElementTyped<C> {
255 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
256 }
257}
258
/// Generates a binary function trait `$name` with method `$func` for each listed primitive.
///
/// Also generates its expand-side companion `${name}Expand`. Expansion lowers to `$op`
/// through `binary_expand`. Both operands and the result share the same type.
macro_rules! impl_binary_func {
    ($name:ident, $func:ident, $op:expr, $($prim:ty),*) => {
        paste::paste! {
            /// Expand-time half of a generated binary function; emits the IR instruction.
            pub trait [<$name Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            impl<T: CubePrimitive + $name> [<$name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let out = binary_expand(scope, self.into(), rhs.into(), $op);
                    out.into()
                }
            }

            /// User-facing half of a generated binary function. The plain method body is
            /// `unexpanded!()`; the real work happens in the `__expand_*` associated function.
            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                fn $func(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            $(impl $name for $prim {})*
        }
    };
}
290
/// Like `impl_binary_func!`, but the output type is `lhs`'s type re-lined to `$line_size`.
///
/// The instruction is emitted through `binary_expand_fixed_output` instead of deriving the
/// output line size from the operands.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($name:ident, $func:ident, $op:expr, $line_size:expr, $($prim:ty),*) => {
        paste::paste! {
            /// Expand-time half of a generated fixed-output binary function.
            pub trait [<$name Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            impl<T: CubePrimitive + $name> [<$name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let left: ExpandElement = self.into();
                    let out_ty = left.ty.line($line_size);
                    let right: ExpandElement = rhs.into();
                    binary_expand_fixed_output(scope, left, right, out_ty, $op).into()
                }
            }

            /// User-facing half of a generated fixed-output binary function. The plain method
            /// body is `unexpanded!()`.
            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                fn $func(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            $(impl $name for $prim {})*
        }
    };
}
323
/// Generates a binary function trait `$name<Rhs>` whose right operand type may differ from
/// `Self`.
///
/// Each listed primitive implements `$name<$rhs_prim>`. Expansion lowers to `$op` through
/// `binary_expand`, and the result is typed as `Self`.
macro_rules! impl_binary_func_mixed_types {
    ($name:ident, $func:ident, $rhs_prim:ident, $op:expr, $($prim:ty),*) => {
        paste::paste! {
            /// Expand-time half of a generated mixed-type binary function.
            pub trait [<$name Expand>]<Rhs: CubeType> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            impl<Rhs: CubePrimitive, T: CubePrimitive + $name<Rhs>> [<$name Expand>]<Rhs>
                for ExpandElementTyped<T>
            {
                fn [<__expand_ $func _method>](
                    self,
                    scope: &mut Scope,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> Self {
                    let (left, right): (ExpandElement, ExpandElement) = (self.into(), rhs.into());
                    binary_expand(scope, left, right, $op).into()
                }
            }

            /// User-facing half of a generated mixed-type binary function. The plain method
            /// body is `unexpanded!()`.
            pub trait $name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ExpandElement>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$name Expand>]<Rhs>> + Sized
            {
                fn $func(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    let (left, right): (ExpandElement, ExpandElement) = (lhs.into(), rhs.into());
                    binary_expand(scope, left, right, $op).into()
                }
            }

            $(impl $name<$rhs_prim> for $prim {})*
        }
    };
}
355
/// Bridges a `core::ops` binary operator trait into kernel expansion.
///
/// Generates `Cube$op_trait`, blanket-implemented for every `CubePrimitive` whose operator
/// returns `Self`, plus `${op_trait}Expand`, which lowers the operation to `$instr`.
macro_rules! impl_core_binop {
    ($op_trait:ident, $fn_name:ident, $instr:expr) => {
        paste::paste! {
            /// Expand-time half of a kernel binary operator.
            pub trait [<$op_trait Expand>] {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            impl<T: $op_trait<Output = T> + CubePrimitive> [<$op_trait Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let result = binary_expand(scope, self.into(), rhs.into(), $instr);
                    result.into()
                }
            }

            /// Kernel-side operator trait; forwards to the expand-side method on `lhs`.
            pub trait [<Cube $op_trait>]:
                $op_trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$op_trait Expand>]> + Sized
            {
                fn [<__expand_ $fn_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $fn_name _method>](scope, rhs)
                }
            }

            impl<T: $op_trait<Output = T> + CubePrimitive> [<Cube $op_trait>] for T {}
        }
    };
}
382
/// Bridges a `core::ops` compound-assignment trait into kernel expansion.
///
/// Generates `Cube$op_trait`, blanket-implemented for every `CubePrimitive` implementing the
/// trait, plus `${op_trait}Expand`, which emits `$instr` through `assign_op_expand`.
macro_rules! impl_core_assign_binop {
    ($op_trait:ident, $fn_name:ident, $instr:expr) => {
        paste::paste! {
            /// Expand-time half of a kernel compound-assignment operator.
            pub trait [<$op_trait Expand>] {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self);
            }

            impl<T: $op_trait + CubePrimitive> [<$op_trait Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self) {
                    let (target, operand): (ExpandElement, ExpandElement) = (self.into(), rhs.into());
                    assign_op_expand(scope, target, operand, $instr);
                }
            }

            /// Kernel-side compound-assignment trait; forwards to the expand-side method on `lhs`.
            pub trait [<Cube $op_trait>]:
                $op_trait + CubePrimitive + CubeType<ExpandType: [<$op_trait Expand>]> + Sized
            {
                fn [<__expand_ $fn_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) {
                    lhs.[<__expand_ $fn_name _method>](scope, rhs)
                }
            }

            impl<T: $op_trait + CubePrimitive> [<Cube $op_trait>] for T {}
        }
    };
}
409
410impl_core_binop!(Add, add, Arithmetic::Add);
411impl_core_binop!(Sub, sub, Arithmetic::Sub);
412impl_core_binop!(Mul, mul, Arithmetic::Mul);
413impl_core_binop!(Div, mul, Arithmetic::Div);
414impl_core_binop!(Rem, rem, Arithmetic::Modulo);
415
416impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
417impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
418impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
419impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
420impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
421
422pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
423 fn __expand_min(
424 scope: &mut Scope,
425 lhs: Self::ExpandType,
426 rhs: Self::ExpandType,
427 ) -> Self::ExpandType {
428 lhs.__expand_min_method(scope, rhs)
429 }
430
431 fn __expand_max(
432 scope: &mut Scope,
433 lhs: Self::ExpandType,
434 rhs: Self::ExpandType,
435 ) -> Self::ExpandType {
436 lhs.__expand_max_method(scope, rhs)
437 }
438
439 fn __expand_clamp(
440 scope: &mut Scope,
441 lhs: Self::ExpandType,
442 min: Self::ExpandType,
443 max: Self::ExpandType,
444 ) -> Self::ExpandType {
445 lhs.__expand_clamp_method(scope, min, max)
446 }
447}
448pub trait OrdExpand {
449 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
450 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
451 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
452}
453
454impl<T: Ord + CubePrimitive> CubeOrd for T {}
455impl<T: Ord + CubePrimitive> OrdExpand for ExpandElementTyped<T> {
456 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
457 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
458 }
459 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
460 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
461 }
462 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
463 unary_expand(scope, self.into(), |op| {
464 Arithmetic::Clamp(ClampOperator {
465 input: op.input,
466 min_value: *min.expand,
467 max_value: *max.expand,
468 })
469 })
470 .into()
471 }
472}
473
474impl_binary_func!(
475 Powf,
476 powf,
477 Arithmetic::Powf,
478 f16,
479 bf16,
480 flex32,
481 tf32,
482 f32,
483 f64
484);
485
486impl_binary_func!(
487 Hypot,
488 hypot,
489 Arithmetic::Hypot,
490 f16,
491 bf16,
492 flex32,
493 tf32,
494 f32,
495 f64
496);
497
498impl_binary_func!(
499 Rhypot,
500 rhypot,
501 Arithmetic::Rhypot,
502 f16,
503 bf16,
504 flex32,
505 tf32,
506 f32,
507 f64
508);
509
510impl_binary_func!(
511 ArcTan2,
512 atan2,
513 Arithmetic::ArcTan2,
514 f16,
515 bf16,
516 flex32,
517 tf32,
518 f32,
519 f64
520);
521impl_binary_func!(
522 Remainder,
523 rem,
524 Arithmetic::Remainder,
525 e2m1,
526 e4m3,
527 e5m2,
528 ue8m0,
529 f16,
530 bf16,
531 flex32,
532 tf32,
533 f32,
534 f64,
535 i8,
536 i16,
537 i32,
538 i64,
539 u8,
540 u16,
541 u32,
542 u64,
543 usize,
544 isize
545);
546impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
547impl_binary_func!(
548 SaturatingAdd,
549 saturating_add,
550 Arithmetic::SaturatingAdd,
551 i8,
552 i16,
553 i32,
554 i64,
555 u8,
556 u16,
557 u32,
558 u64,
559 usize,
560 isize
561);
562impl_binary_func!(
563 SaturatingSub,
564 saturating_sub,
565 Arithmetic::SaturatingSub,
566 i8,
567 i16,
568 i32,
569 i64,
570 u8,
571 u16,
572 u32,
573 u64,
574 usize,
575 isize
576);
577impl_binary_func_fixed_output_vectorization!(
578 Dot,
579 dot,
580 Arithmetic::Dot,
581 0,
582 f16,
583 bf16,
584 flex32,
585 tf32,
586 f32,
587 f64,
588 i8,
589 i16,
590 i32,
591 i64,
592 u8,
593 u16,
594 u32,
595 u64,
596 usize,
597 isize
598);
599
600impl_binary_func_mixed_types!(
601 Powi,
602 powi,
603 i32,
604 Arithmetic::Powi,
605 f16,
606 bf16,
607 flex32,
608 tf32,
609 f32,
610 f64,
611 i8,
612 i16,
613 i32,
614 i64,
615 u8,
616 u16,
617 u32,
618 u64,
619 usize,
620 isize
621);