1use crate as cubecl;
2use crate::ir::{Arithmetic, Bitwise, ManagedVariable, Operator, Scope};
3use crate::{
4 flex32,
5 frontend::{CubePrimitive, NativeExpand},
6 prelude::*,
7};
8use crate::{frontend::CubeType, tf32};
9use crate::{
10 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
11 unexpanded,
12};
13use core::{cmp::Ordering, ops::*};
14use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
15use cubecl_ir::ClampOperator;
16use cubecl_macros::derive_expand;
17use half::{bf16, f16};
18
19pub mod add {
20 use super::*;
21
22 pub fn expand<C: CubePrimitive>(
23 scope: &mut Scope,
24 lhs: NativeExpand<C>,
25 rhs: NativeExpand<C>,
26 ) -> NativeExpand<C> {
27 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
28 }
29}
30
31pub mod sub {
32 use cubecl_ir::{ConstantValue, Variable};
33
34 use super::*;
35
36 pub fn expand<C: CubePrimitive>(
37 scope: &mut Scope,
38 lhs: NativeExpand<C>,
39 rhs: NativeExpand<C>,
40 ) -> NativeExpand<C> {
41 match (lhs.expand.as_const(), rhs.expand.as_const()) {
43 (Some(ConstantValue::UInt(lhs_val)), Some(ConstantValue::UInt(rhs_val))) => {
44 let item_lhs = lhs.expand.ty;
45 let item_rhs = rhs.expand.ty;
46
47 let vector_size = find_vectorization(item_lhs, item_rhs);
48
49 let item = item_lhs.with_vector_size(vector_size);
50 let value = (lhs_val - rhs_val).into();
51 ManagedVariable::Plain(Variable::constant(value, item)).into()
52 }
53 _ => binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into(),
54 }
55 }
56}
57
58pub mod mul {
59 use super::*;
60
61 pub fn expand<C: CubePrimitive>(
62 scope: &mut Scope,
63 lhs: NativeExpand<C>,
64 rhs: NativeExpand<C>,
65 ) -> NativeExpand<C> {
66 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
67 }
68}
69
70pub mod div {
71 use super::*;
72
73 pub fn expand<C: CubePrimitive>(
74 scope: &mut Scope,
75 lhs: NativeExpand<C>,
76 rhs: NativeExpand<C>,
77 ) -> NativeExpand<C> {
78 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
79 }
80}
81
82pub mod rem {
83 use super::*;
84
85 pub fn expand<C: CubePrimitive>(
86 scope: &mut Scope,
87 lhs: NativeExpand<C>,
88 rhs: NativeExpand<C>,
89 ) -> NativeExpand<C> {
90 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
91 }
92}
93
94pub mod and {
95 use super::*;
96
97 pub fn expand<C: CubePrimitive>(
98 scope: &mut Scope,
99 lhs: NativeExpand<C>,
100 rhs: NativeExpand<C>,
101 ) -> NativeExpand<bool> {
102 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
103 }
104}
105
106pub mod bitand {
107 use super::*;
108
109 pub fn expand<C: CubePrimitive>(
110 scope: &mut Scope,
111 lhs: NativeExpand<C>,
112 rhs: NativeExpand<C>,
113 ) -> NativeExpand<C> {
114 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
115 }
116}
117
118pub mod bitor {
119 use super::*;
120
121 pub fn expand<C: CubePrimitive>(
122 scope: &mut Scope,
123 lhs: NativeExpand<C>,
124 rhs: NativeExpand<C>,
125 ) -> NativeExpand<C> {
126 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
127 }
128}
129
130pub mod or {
131 use super::*;
132
133 pub fn expand<C: CubePrimitive>(
134 scope: &mut Scope,
135 lhs: NativeExpand<C>,
136 rhs: NativeExpand<C>,
137 ) -> NativeExpand<bool> {
138 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
139 }
140}
141
142pub mod bitxor {
143 use super::*;
144
145 pub fn expand<C: CubePrimitive>(
146 scope: &mut Scope,
147 lhs: NativeExpand<C>,
148 rhs: NativeExpand<C>,
149 ) -> NativeExpand<C> {
150 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
151 }
152}
153
154pub mod shl {
155 use super::*;
156
157 pub fn expand<C: CubePrimitive>(
158 scope: &mut Scope,
159 lhs: NativeExpand<C>,
160 rhs: NativeExpand<C>,
161 ) -> NativeExpand<C> {
162 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
163 }
164}
165
166pub mod shr {
167 use super::*;
168
169 pub fn expand<C: CubePrimitive>(
170 scope: &mut Scope,
171 lhs: NativeExpand<C>,
172 rhs: NativeExpand<C>,
173 ) -> NativeExpand<C> {
174 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
175 }
176}
177
178pub mod clamp {
179 use super::*;
180
181 pub fn expand<C: PartialOrd + CubePrimitive>(
182 scope: &mut Scope,
183 input: NativeExpand<C>,
184 min: NativeExpand<C>,
185 max: NativeExpand<C>,
186 ) -> NativeExpand<C> {
187 unary_expand(scope, input.into(), |op| {
188 Arithmetic::Clamp(ClampOperator {
189 input: op.input,
190 min_value: *min.expand,
191 max_value: *max.expand,
192 })
193 })
194 .into()
195 }
196}
197
198pub mod clamp_max {
199 use super::*;
200
201 pub fn expand<C: PartialOrd + CubePrimitive>(
202 scope: &mut Scope,
203 lhs: NativeExpand<C>,
204 rhs: NativeExpand<C>,
205 ) -> NativeExpand<C> {
206 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
207 }
208}
209
210pub mod clamp_min {
211 use super::*;
212
213 pub fn expand<C: PartialOrd + CubePrimitive>(
214 scope: &mut Scope,
215 lhs: NativeExpand<C>,
216 rhs: NativeExpand<C>,
217 ) -> NativeExpand<C> {
218 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
219 }
220}
221
/// Minimum of `lhs` and `rhs` that only requires [`PartialOrd`], not [`Ord`].
///
/// Delegates to `clamp_max(lhs, rhs)`, i.e. `lhs` capped from above at `rhs`.
pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
    clamp_max(lhs, rhs)
}
227
228pub mod min {
229 use super::*;
230
231 pub fn expand<C: PartialOrd + CubePrimitive>(
232 scope: &mut Scope,
233 lhs: NativeExpand<C>,
234 rhs: NativeExpand<C>,
235 ) -> NativeExpand<C> {
236 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
237 }
238}
239
/// Maximum of `lhs` and `rhs` that only requires [`PartialOrd`], not [`Ord`].
///
/// Delegates to `clamp_min(lhs, rhs)`, i.e. `lhs` raised from below to at least `rhs`.
pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
    clamp_min(lhs, rhs)
}
245
246pub mod max {
247 use super::*;
248
249 pub fn expand<C: PartialOrd + CubePrimitive>(
250 scope: &mut Scope,
251 lhs: NativeExpand<C>,
252 rhs: NativeExpand<C>,
253 ) -> NativeExpand<C> {
254 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
255 }
256}
257
// Generates a binary-function trait `$trait_name` (host method `$method_name`), its expand-side
// companion trait `<$trait_name>Expand`, an empty `impl $trait_name for $type {}` for every
// listed type, and a blanket `<$trait_name>Expand` impl for `NativeExpand<T>`.
// Both operands and the result have type `Self`; expansion emits `$operator` via `binary_expand`.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                // Host-side stub: calling it outside of expansion hits `unexpanded!()`.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Static expand entry point; forwards to the method form on the expand type.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $method_name _method>](scope, rhs)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            // Opt-in: only the explicitly listed types get the host-side trait.
            $(impl $trait_name for $type {})*
            impl<T: CubePrimitive + $trait_name> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $operator).into()
                }
            }
        }
    }
}
289
// Like `impl_binary_func!`, but the result has type `Self::Scalar` instead of `Self`
// (e.g. a reduction such as a dot product over vectorized operands). Expansion uses
// `binary_expand_fixed_output` with an explicitly computed output type.
macro_rules! impl_binary_func_scalar_out {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                    + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized {
                // Host-side stub: calling it outside of expansion hits `unexpanded!()`.
                fn $method_name(self, _rhs: Self) -> Self::Scalar {
                    unexpanded!()
                }

                // Static expand entry point; forwards to the method form on the expand type.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self::Scalar> {
                    lhs.[<__expand_ $method_name _method>](scope, rhs)
                }
            }

            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar;
            }

            $(impl $trait_name for $type {})*
            impl<T: CubePrimitive + $trait_name> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar {
                    let lhs: ManagedVariable = self.into();
                    // Output keeps the operand's element type but not its vectorization.
                    // NOTE(review): vector size 0 presumably denotes a scalar (unvectorized)
                    // type here — confirm against `with_vector_size`.
                    let item = lhs.ty.with_vector_size(0);
                    binary_expand_fixed_output(scope, lhs, rhs.into(), item, $operator).into()
                }
            }
        }
    }
}
325
// Like `impl_binary_func!`, but the right-hand operand has its own type `Rhs`. The listed
// `$type`s implement `$trait_name<$rhs_ty>`; the result keeps the left-hand type `Self`.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ManagedVariable>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]<Rhs>> + Sized {
                // Host-side stub: calling it outside of expansion hits `unexpanded!()`.
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                // Static expand entry point; unlike the other generators it emits the
                // instruction directly instead of forwarding to the method form.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Rhs>,
                ) -> NativeExpand<Self> {
                    binary_expand(scope, lhs.into(), rhs.into(), $operator).into()
                }
            }

            pub trait [<$trait_name Expand>]<Rhs: CubeType>{
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            $(impl $trait_name<$rhs_ty> for $type {})*
            impl<Rhs: CubePrimitive, T: CubePrimitive + $trait_name<Rhs>> [<$trait_name Expand>]<Rhs> for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: NativeExpand<Rhs>) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $operator).into()
                }
            }
        }
    }
}
357
// Bridges a `core::ops` binary operator trait (`$trait`, e.g. `Add`) into expansion.
// Generates `Cube<$trait>` with a static `__expand_<$method>` and `<$trait>Expand` with
// `__expand_<$method>_method`, both blanket-implemented for every `CubePrimitive` whose
// `$trait` output is `Self`. `$method` only names the generated functions, so it must be the
// operator trait's own method name.
macro_rules! impl_core_binop {
    ($trait: ident, $method: ident, $op: expr) => {
        paste::paste! {
            pub trait [<Cube $trait>]: $trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait Expand>]> + Sized {
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            pub trait [<$trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            impl<T: $trait<Output = T> + CubePrimitive> [<Cube $trait>] for T {}
            impl<T: $trait<Output = T> + CubePrimitive> [<$trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $op).into()
                }
            }
        }
    };
}
384
// Compound-assignment counterpart of `impl_core_binop!` (e.g. `AddAssign`). The generated
// expand functions return nothing; they record `$op` through `assign_op_expand`, which
// receives `self` as the destination operand.
macro_rules! impl_core_assign_binop {
    ($trait: ident, $method: ident, $op: expr) => {
        paste::paste! {
            pub trait [<Cube $trait>]: $trait + CubePrimitive + CubeType<ExpandType: [<$trait Expand>]> + Sized {
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            pub trait [<$trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self);
            }

            impl<T: $trait + CubePrimitive> [<Cube $trait>] for T {}
            impl<T: $trait + CubePrimitive> [<$trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) {
                    assign_op_expand(scope, self.into(), rhs.into(), $op);
                }
            }
        }
    };
}
411
412impl_core_binop!(Add, add, Arithmetic::Add);
413impl_core_binop!(Sub, sub, Arithmetic::Sub);
414impl_core_binop!(Mul, mul, Arithmetic::Mul);
415impl_core_binop!(Div, mul, Arithmetic::Div);
416impl_core_binop!(Rem, rem, Arithmetic::Modulo);
417
418impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
419impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
420impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
421impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
422impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
423
// Expand-side support (`OrderingExpand`, runtime variants) for `Ordering`. The explicit
// discriminants give each variant a fixed integer value, which `ordering_disc` looks up by
// variant name ("Less", "Equal", "Greater").
// NOTE(review): `Ordering` is also imported from `core::cmp` at the top of the file; presumably
// `derive_expand` only generates expand-side items for that type instead of emitting a new
// enum — confirm.
#[derive_expand(CubeType, CubeTypeMut, IntoRuntime)]
#[cube(runtime_variants, no_constructors)]
pub enum Ordering {
    Less = -1,
    Equal = 0,
    Greater = 1,
}
431
432fn ordering_disc(name: &'static str) -> NativeExpand<i32> {
433 OrderingExpand::discriminant_of(name).into()
434}
435
436#[allow(non_snake_case)]
437pub trait CubeOrdering {
438 fn Less() -> Ordering {
439 Ordering::Less
440 }
441 fn Equal() -> Ordering {
442 Ordering::Equal
443 }
444 fn Greater() -> Ordering {
445 Ordering::Greater
446 }
447 fn __expand_Less(_scope: &mut Scope) -> OrderingExpand {
448 OrderingExpand {
449 discriminant: ordering_disc("Less"),
450 value: (),
451 }
452 }
453 fn __expand_Equal(_scope: &mut Scope) -> OrderingExpand {
454 OrderingExpand {
455 discriminant: ordering_disc("Equal"),
456 value: (),
457 }
458 }
459 fn __expand_Greater(_scope: &mut Scope) -> OrderingExpand {
460 OrderingExpand {
461 discriminant: ordering_disc("Greater"),
462 value: (),
463 }
464 }
465}
466
467impl CubeOrdering for Ordering {}
468
469pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
470 fn __expand_cmp(
471 scope: &mut Scope,
472 lhs: Self::ExpandType,
473 rhs: Self::ExpandType,
474 ) -> OrderingExpand {
475 lhs.__expand_cmp_method(scope, rhs)
476 }
477
478 fn __expand_min(
479 scope: &mut Scope,
480 lhs: Self::ExpandType,
481 rhs: Self::ExpandType,
482 ) -> Self::ExpandType {
483 lhs.__expand_min_method(scope, rhs)
484 }
485
486 fn __expand_max(
487 scope: &mut Scope,
488 lhs: Self::ExpandType,
489 rhs: Self::ExpandType,
490 ) -> Self::ExpandType {
491 lhs.__expand_max_method(scope, rhs)
492 }
493
494 fn __expand_clamp(
495 scope: &mut Scope,
496 lhs: Self::ExpandType,
497 min: Self::ExpandType,
498 max: Self::ExpandType,
499 ) -> Self::ExpandType {
500 lhs.__expand_clamp_method(scope, min, max)
501 }
502}
503pub trait OrdExpand {
504 fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand;
505 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
506 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
507 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
508}
509
510impl<T: Ord + CubePrimitive> CubeOrd for T {}
511impl<T: Ord + CubePrimitive> OrdExpand for NativeExpand<T> {
512 fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand {
513 let lhs_lt_rhs = lt::expand(scope, self.clone(), rhs.clone());
514 let lhs_gt_rhs = gt::expand(scope, self, rhs);
515 let less = ordering_disc("Less");
516 let equal = ordering_disc("Equal");
517 let greater = ordering_disc("Greater");
518 let eq_or_gt = select::expand(scope, lhs_gt_rhs, greater, equal);
519 let discriminant = select::expand(scope, lhs_lt_rhs, less, eq_or_gt);
520 OrderingExpand {
521 discriminant,
522 value: (),
523 }
524 }
525 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
526 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
527 }
528 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
529 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
530 }
531 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
532 unary_expand(scope, self.into(), |op| {
533 Arithmetic::Clamp(ClampOperator {
534 input: op.input,
535 min_value: *min.expand,
536 max_value: *max.expand,
537 })
538 })
539 .into()
540 }
541}
542
// Binary math functions available on the floating-point types only; operands and result share
// the element type.
impl_binary_func!(
    Powf,
    powf,
    Arithmetic::Powf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);

impl_binary_func!(
    Hypot,
    hypot,
    Arithmetic::Hypot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);

impl_binary_func!(
    Rhypot,
    rhypot,
    Arithmetic::Rhypot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);

impl_binary_func!(
    ArcTan2,
    atan2,
    Arithmetic::ArcTan2,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `rem` via `Arithmetic::Remainder`, distinct from the `%` operator, which expands to
// `Arithmetic::Modulo`. Implemented for every float type (including the `cubecl_common`
// types `e2m1`, `e4m3`, `e5m2`, `ue8m0`) and every integer type.
// NOTE(review): the method name `rem` matches `core::ops::Rem::rem`; method calls may be
// ambiguous where both traits are in scope.
impl_binary_func!(
    Remainder,
    rem,
    Arithmetic::Remainder,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// Integer-only binary functions.
// NOTE(review): `mul_hi` presumably yields the high half of the double-width product; it is
// only provided for `i32`, `u32`, `usize` and `isize`.
impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
impl_binary_func!(
    SaturatingAdd,
    saturating_add,
    Arithmetic::SaturatingAdd,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
impl_binary_func!(
    SaturatingSub,
    saturating_sub,
    Arithmetic::SaturatingSub,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// `dot` returns `Self::Scalar` rather than `Self`; available on all float and integer types.
impl_binary_func_scalar_out!(
    Dot,
    dot,
    Arithmetic::Dot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
667
// `powi` raises any float or integer base to an `i32` exponent; the result keeps the base type.
impl_binary_func_mixed_types!(
    Powi,
    powi,
    i32,
    Arithmetic::Powi,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);