Skip to main content

cubecl_core/frontend/operation/
binary.rs

1use crate as cubecl;
2use crate::ir::{Arithmetic, Bitwise, ManagedVariable, Operator, Scope};
3use crate::{
4    flex32,
5    frontend::{CubePrimitive, NativeExpand},
6    prelude::*,
7};
8use crate::{frontend::CubeType, tf32};
9use crate::{
10    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
11    unexpanded,
12};
13use core::{cmp::Ordering, ops::*};
14use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
15use cubecl_ir::ClampOperator;
16use cubecl_macros::derive_expand;
17use half::{bf16, f16};
18
19pub mod add {
20    use super::*;
21
22    pub fn expand<C: CubePrimitive>(
23        scope: &mut Scope,
24        lhs: NativeExpand<C>,
25        rhs: NativeExpand<C>,
26    ) -> NativeExpand<C> {
27        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
28    }
29}
30
31pub mod sub {
32    use cubecl_ir::{ConstantValue, Variable};
33
34    use super::*;
35
36    pub fn expand<C: CubePrimitive>(
37        scope: &mut Scope,
38        lhs: NativeExpand<C>,
39        rhs: NativeExpand<C>,
40    ) -> NativeExpand<C> {
41        // Dirty hack to enable slice destructuring with trailing patterns on `Sequence`
42        match (lhs.expand.as_const(), rhs.expand.as_const()) {
43            (Some(ConstantValue::UInt(lhs_val)), Some(ConstantValue::UInt(rhs_val))) => {
44                let item_lhs = lhs.expand.ty;
45                let item_rhs = rhs.expand.ty;
46
47                let vector_size = find_vectorization(item_lhs, item_rhs);
48
49                let item = item_lhs.with_vector_size(vector_size);
50                let value = (lhs_val - rhs_val).into();
51                ManagedVariable::Plain(Variable::constant(value, item)).into()
52            }
53            _ => binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into(),
54        }
55    }
56}
57
58pub mod mul {
59    use super::*;
60
61    pub fn expand<C: CubePrimitive>(
62        scope: &mut Scope,
63        lhs: NativeExpand<C>,
64        rhs: NativeExpand<C>,
65    ) -> NativeExpand<C> {
66        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
67    }
68}
69
70pub mod div {
71    use super::*;
72
73    pub fn expand<C: CubePrimitive>(
74        scope: &mut Scope,
75        lhs: NativeExpand<C>,
76        rhs: NativeExpand<C>,
77    ) -> NativeExpand<C> {
78        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
79    }
80}
81
82pub mod rem {
83    use super::*;
84
85    pub fn expand<C: CubePrimitive>(
86        scope: &mut Scope,
87        lhs: NativeExpand<C>,
88        rhs: NativeExpand<C>,
89    ) -> NativeExpand<C> {
90        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
91    }
92}
93
94pub mod and {
95    use super::*;
96
97    pub fn expand<C: CubePrimitive>(
98        scope: &mut Scope,
99        lhs: NativeExpand<C>,
100        rhs: NativeExpand<C>,
101    ) -> NativeExpand<bool> {
102        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
103    }
104}
105
106pub mod bitand {
107    use super::*;
108
109    pub fn expand<C: CubePrimitive>(
110        scope: &mut Scope,
111        lhs: NativeExpand<C>,
112        rhs: NativeExpand<C>,
113    ) -> NativeExpand<C> {
114        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
115    }
116}
117
118pub mod bitor {
119    use super::*;
120
121    pub fn expand<C: CubePrimitive>(
122        scope: &mut Scope,
123        lhs: NativeExpand<C>,
124        rhs: NativeExpand<C>,
125    ) -> NativeExpand<C> {
126        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
127    }
128}
129
130pub mod or {
131    use super::*;
132
133    pub fn expand<C: CubePrimitive>(
134        scope: &mut Scope,
135        lhs: NativeExpand<C>,
136        rhs: NativeExpand<C>,
137    ) -> NativeExpand<bool> {
138        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
139    }
140}
141
142pub mod bitxor {
143    use super::*;
144
145    pub fn expand<C: CubePrimitive>(
146        scope: &mut Scope,
147        lhs: NativeExpand<C>,
148        rhs: NativeExpand<C>,
149    ) -> NativeExpand<C> {
150        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
151    }
152}
153
154pub mod shl {
155    use super::*;
156
157    pub fn expand<C: CubePrimitive>(
158        scope: &mut Scope,
159        lhs: NativeExpand<C>,
160        rhs: NativeExpand<C>,
161    ) -> NativeExpand<C> {
162        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
163    }
164}
165
166pub mod shr {
167    use super::*;
168
169    pub fn expand<C: CubePrimitive>(
170        scope: &mut Scope,
171        lhs: NativeExpand<C>,
172        rhs: NativeExpand<C>,
173    ) -> NativeExpand<C> {
174        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
175    }
176}
177
178pub mod clamp {
179    use super::*;
180
181    pub fn expand<C: PartialOrd + CubePrimitive>(
182        scope: &mut Scope,
183        input: NativeExpand<C>,
184        min: NativeExpand<C>,
185        max: NativeExpand<C>,
186    ) -> NativeExpand<C> {
187        unary_expand(scope, input.into(), |op| {
188            Arithmetic::Clamp(ClampOperator {
189                input: op.input,
190                min_value: *min.expand,
191                max_value: *max.expand,
192            })
193        })
194        .into()
195    }
196}
197
198pub mod clamp_max {
199    use super::*;
200
201    pub fn expand<C: PartialOrd + CubePrimitive>(
202        scope: &mut Scope,
203        lhs: NativeExpand<C>,
204        rhs: NativeExpand<C>,
205    ) -> NativeExpand<C> {
206        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
207    }
208}
209
210pub mod clamp_min {
211    use super::*;
212
213    pub fn expand<C: PartialOrd + CubePrimitive>(
214        scope: &mut Scope,
215        lhs: NativeExpand<C>,
216        rhs: NativeExpand<C>,
217    ) -> NativeExpand<C> {
218        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
219    }
220}
221
222/// The minimum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
223/// `clamp_max` may sometimes be more clear.
224pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
225    clamp_max(lhs, rhs)
226}
227
228pub mod min {
229    use super::*;
230
231    pub fn expand<C: PartialOrd + CubePrimitive>(
232        scope: &mut Scope,
233        lhs: NativeExpand<C>,
234        rhs: NativeExpand<C>,
235    ) -> NativeExpand<C> {
236        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
237    }
238}
239
240/// The maximum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
241/// `clamp_min` may sometimes be more clear.
242pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
243    clamp_min(lhs, rhs)
244}
245
246pub mod max {
247    use super::*;
248
249    pub fn expand<C: PartialOrd + CubePrimitive>(
250        scope: &mut Scope,
251        lhs: NativeExpand<C>,
252        rhs: NativeExpand<C>,
253    ) -> NativeExpand<C> {
254        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
255    }
256}
257
/// For binary functions without special syntax.
///
/// Invoked as `impl_binary_func!(Trait, method, ir_op, Ty1, Ty2, ...)`, it generates:
/// - `TraitExpand`, whose `__expand_method_method` is implemented for every `NativeExpand<T>`
///   with `T: Trait` by emitting `ir_op` through `binary_expand`;
/// - `Trait` itself, with `method(self, rhs)` (body is `unexpanded!()`, so it is not meant to
///   run directly) and the associated entry point `__expand_method`, which forwards to the
///   expand trait;
/// - an empty `impl Trait for TyN {}` for every listed type.
macro_rules! impl_binary_func {
    ($name:ident, $func:ident, $ir_op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait [<$name Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                fn $func(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T> [<$name Expand>] for NativeExpand<T>
            where
                T: CubePrimitive + $name,
            {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let output = binary_expand(scope, self.into(), rhs.into(), $ir_op);
                    output.into()
                }
            }

            $(impl $name for $ty {})*
        }
    };
}
289
/// Like `impl_binary_func`, but the generated function returns `Self::Scalar` rather than
/// `Self` (used, for example, by `dot`).
///
/// The expand implementation emits `ir_op` through `binary_expand_fixed_output`, with an output
/// type equal to the lhs type at vector size 0.
macro_rules! impl_binary_func_scalar_out {
    ($name:ident, $func:ident, $ir_op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait [<$name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar;
            }

            pub trait $name: CubePrimitive
                + CubeType<ExpandType: [<$name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized {
                fn $func(self, _rhs: Self) -> Self::Scalar {
                    unexpanded!()
                }

                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self::Scalar> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T> [<$name Expand>] for NativeExpand<T>
            where
                T: CubePrimitive + $name,
            {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar {
                    let lhs: ManagedVariable = self.into();
                    // Output type: lhs's type with its vector size set to 0.
                    let scalar_ty = lhs.ty.with_vector_size(0);
                    binary_expand_fixed_output(scope, lhs, rhs.into(), scalar_ty, $ir_op).into()
                }
            }

            $(impl $name for $ty {})*
        }
    };
}
325
/// Like `impl_binary_func`, but the right-hand operand has its own type `Rhs` (for example an
/// `i32` exponent for `powi`).
///
/// Invoked as `impl_binary_func_mixed_types!(Trait, method, RhsTy, ir_op, Ty1, Ty2, ...)`. It
/// generates `Trait<Rhs>` and `TraitExpand<Rhs>`, plus `impl Trait<RhsTy> for TyN {}` for every
/// listed type.
macro_rules! impl_binary_func_mixed_types {
    ($name:ident, $func:ident, $rhs_ty:ident, $ir_op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait [<$name Expand>]<Rhs: CubeType> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            pub trait $name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ManagedVariable>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$name Expand>]<Rhs>> + Sized {
                fn $func(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                // Unlike the other generators, this entry point emits the IR operation directly
                // instead of forwarding to the expand trait's method.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Rhs>,
                ) -> NativeExpand<Self> {
                    let output = binary_expand(scope, lhs.into(), rhs.into(), $ir_op);
                    output.into()
                }
            }

            impl<Rhs, T> [<$name Expand>]<Rhs> for NativeExpand<T>
            where
                Rhs: CubePrimitive,
                T: CubePrimitive + $name<Rhs>,
            {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: NativeExpand<Rhs>) -> Self {
                    let output = binary_expand(scope, self.into(), rhs.into(), $ir_op);
                    output.into()
                }
            }

            $(impl $name<$rhs_ty> for $ty {})*
        }
    };
}
357
/// Hooks a `core::ops` binary operator trait (e.g. `Add`) into cube expansion.
///
/// For `impl_core_binop!(Op, op, ir_op)` this generates:
/// - `OpExpand`, implemented for `NativeExpand<T>` (with `T: Op<Output = T> + CubePrimitive`),
///   whose `__expand_op_method` emits `ir_op` through `binary_expand`;
/// - `CubeOp`, blanket-implemented for the same `T`, whose `__expand_op` forwards to it.
///
/// The second argument must be the operator trait's method name, because it determines the
/// generated `__expand_*` identifiers.
macro_rules! impl_core_binop {
    ($op_trait:ident, $op_fn:ident, $ir_op:expr) => {
        paste::paste! {
            pub trait [<$op_trait Expand>] {
                fn [<__expand_ $op_fn _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait [<Cube $op_trait>]: $op_trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$op_trait Expand>]> + Sized {
                fn [<__expand_ $op_fn>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $op_fn _method>](scope, rhs)
                }
            }

            impl<T> [<$op_trait Expand>] for NativeExpand<T>
            where
                T: $op_trait<Output = T> + CubePrimitive,
            {
                fn [<__expand_ $op_fn _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let output = binary_expand(scope, self.into(), rhs.into(), $ir_op);
                    output.into()
                }
            }

            impl<T> [<Cube $op_trait>] for T where T: $op_trait<Output = T> + CubePrimitive {}
        }
    };
}
384
/// Hooks a `core::ops` compound-assignment trait (e.g. `AddAssign`) into cube expansion.
///
/// For `impl_core_assign_binop!(OpAssign, op_assign, ir_op)` this generates:
/// - `OpAssignExpand`, implemented for `NativeExpand<T>` (with `T: OpAssign + CubePrimitive`),
///   whose `__expand_op_assign_method` emits `ir_op` through `assign_op_expand` and returns
///   nothing;
/// - `CubeOpAssign`, blanket-implemented for the same `T`, whose `__expand_op_assign` forwards
///   to it.
macro_rules! impl_core_assign_binop {
    ($op_trait:ident, $op_fn:ident, $ir_op:expr) => {
        paste::paste! {
            pub trait [<$op_trait Expand>] {
                fn [<__expand_ $op_fn _method>](self, scope: &mut Scope, rhs: Self);
            }

            pub trait [<Cube $op_trait>]: $op_trait + CubePrimitive + CubeType<ExpandType: [<$op_trait Expand>]> + Sized {
                fn [<__expand_ $op_fn>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) {
                    lhs.[<__expand_ $op_fn _method>](scope, rhs)
                }
            }

            impl<T> [<$op_trait Expand>] for NativeExpand<T>
            where
                T: $op_trait + CubePrimitive,
            {
                fn [<__expand_ $op_fn _method>](self, scope: &mut Scope, rhs: Self) {
                    assign_op_expand(scope, self.into(), rhs.into(), $ir_op);
                }
            }

            impl<T> [<Cube $op_trait>] for T where T: $op_trait + CubePrimitive {}
        }
    };
}
411
412impl_core_binop!(Add, add, Arithmetic::Add);
413impl_core_binop!(Sub, sub, Arithmetic::Sub);
414impl_core_binop!(Mul, mul, Arithmetic::Mul);
415impl_core_binop!(Div, mul, Arithmetic::Div);
416impl_core_binop!(Rem, rem, Arithmetic::Modulo);
417
418impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
419impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
420impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
421impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
422impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
423
// Cube-side result of a three-way comparison; `OrdExpand::__expand_cmp_method` produces an
// `OrderingExpand` for it. The explicit discriminants (-1 / 0 / 1) are presumably what
// `OrderingExpand::discriminant_of` hands back to `ordering_disc` — confirm against the
// `derive_expand` output.
#[derive_expand(CubeType, CubeTypeMut, IntoRuntime)]
#[cube(runtime_variants, no_constructors)]
pub enum Ordering {
    Less = -1,
    Equal = 0,
    Greater = 1,
}
431
432fn ordering_disc(name: &'static str) -> NativeExpand<i32> {
433    OrderingExpand::discriminant_of(name).into()
434}
435
436#[allow(non_snake_case)]
437pub trait CubeOrdering {
438    fn Less() -> Ordering {
439        Ordering::Less
440    }
441    fn Equal() -> Ordering {
442        Ordering::Equal
443    }
444    fn Greater() -> Ordering {
445        Ordering::Greater
446    }
447    fn __expand_Less(_scope: &mut Scope) -> OrderingExpand {
448        OrderingExpand {
449            discriminant: ordering_disc("Less"),
450            value: (),
451        }
452    }
453    fn __expand_Equal(_scope: &mut Scope) -> OrderingExpand {
454        OrderingExpand {
455            discriminant: ordering_disc("Equal"),
456            value: (),
457        }
458    }
459    fn __expand_Greater(_scope: &mut Scope) -> OrderingExpand {
460        OrderingExpand {
461            discriminant: ordering_disc("Greater"),
462            value: (),
463        }
464    }
465}
466
467impl CubeOrdering for Ordering {}
468
469pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
470    fn __expand_cmp(
471        scope: &mut Scope,
472        lhs: Self::ExpandType,
473        rhs: Self::ExpandType,
474    ) -> OrderingExpand {
475        lhs.__expand_cmp_method(scope, rhs)
476    }
477
478    fn __expand_min(
479        scope: &mut Scope,
480        lhs: Self::ExpandType,
481        rhs: Self::ExpandType,
482    ) -> Self::ExpandType {
483        lhs.__expand_min_method(scope, rhs)
484    }
485
486    fn __expand_max(
487        scope: &mut Scope,
488        lhs: Self::ExpandType,
489        rhs: Self::ExpandType,
490    ) -> Self::ExpandType {
491        lhs.__expand_max_method(scope, rhs)
492    }
493
494    fn __expand_clamp(
495        scope: &mut Scope,
496        lhs: Self::ExpandType,
497        min: Self::ExpandType,
498        max: Self::ExpandType,
499    ) -> Self::ExpandType {
500        lhs.__expand_clamp_method(scope, min, max)
501    }
502}
503pub trait OrdExpand {
504    fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand;
505    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
506    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
507    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
508}
509
510impl<T: Ord + CubePrimitive> CubeOrd for T {}
511impl<T: Ord + CubePrimitive> OrdExpand for NativeExpand<T> {
512    fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand {
513        let lhs_lt_rhs = lt::expand(scope, self.clone(), rhs.clone());
514        let lhs_gt_rhs = gt::expand(scope, self, rhs);
515        let less = ordering_disc("Less");
516        let equal = ordering_disc("Equal");
517        let greater = ordering_disc("Greater");
518        let eq_or_gt = select::expand(scope, lhs_gt_rhs, greater, equal);
519        let discriminant = select::expand(scope, lhs_lt_rhs, less, eq_or_gt);
520        OrderingExpand {
521            discriminant,
522            value: (),
523        }
524    }
525    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
526        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
527    }
528    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
529        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
530    }
531    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
532        unary_expand(scope, self.into(), |op| {
533            Arithmetic::Clamp(ClampOperator {
534                input: op.input,
535                min_value: *min.expand,
536                max_value: *max.expand,
537            })
538        })
539        .into()
540    }
541}
542
543impl_binary_func!(
544    Powf,
545    powf,
546    Arithmetic::Powf,
547    f16,
548    bf16,
549    flex32,
550    tf32,
551    f32,
552    f64
553);
554
555impl_binary_func!(
556    Hypot,
557    hypot,
558    Arithmetic::Hypot,
559    f16,
560    bf16,
561    flex32,
562    tf32,
563    f32,
564    f64
565);
566
567impl_binary_func!(
568    Rhypot,
569    rhypot,
570    Arithmetic::Rhypot,
571    f16,
572    bf16,
573    flex32,
574    tf32,
575    f32,
576    f64
577);
578
579impl_binary_func!(
580    ArcTan2,
581    atan2,
582    Arithmetic::ArcTan2,
583    f16,
584    bf16,
585    flex32,
586    tf32,
587    f32,
588    f64
589);
590impl_binary_func!(
591    Remainder,
592    rem,
593    Arithmetic::Remainder,
594    e2m1,
595    e4m3,
596    e5m2,
597    ue8m0,
598    f16,
599    bf16,
600    flex32,
601    tf32,
602    f32,
603    f64,
604    i8,
605    i16,
606    i32,
607    i64,
608    u8,
609    u16,
610    u32,
611    u64,
612    usize,
613    isize
614);
615impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
616impl_binary_func!(
617    SaturatingAdd,
618    saturating_add,
619    Arithmetic::SaturatingAdd,
620    i8,
621    i16,
622    i32,
623    i64,
624    u8,
625    u16,
626    u32,
627    u64,
628    usize,
629    isize
630);
631impl_binary_func!(
632    SaturatingSub,
633    saturating_sub,
634    Arithmetic::SaturatingSub,
635    i8,
636    i16,
637    i32,
638    i64,
639    u8,
640    u16,
641    u32,
642    u64,
643    usize,
644    isize
645);
646impl_binary_func_scalar_out!(
647    Dot,
648    dot,
649    Arithmetic::Dot,
650    f16,
651    bf16,
652    flex32,
653    tf32,
654    f32,
655    f64,
656    i8,
657    i16,
658    i32,
659    i64,
660    u8,
661    u16,
662    u32,
663    u64,
664    usize,
665    isize
666);
667
668impl_binary_func_mixed_types!(
669    Powi,
670    powi,
671    i32,
672    Arithmetic::Powi,
673    f16,
674    bf16,
675    flex32,
676    tf32,
677    f32,
678    f64,
679    i8,
680    i16,
681    i32,
682    i64,
683    u8,
684    u16,
685    u32,
686    u64,
687    usize,
688    isize
689);