Skip to main content

cubecl_core/frontend/operation/
binary.rs

1use crate::{
2    flex32,
3    frontend::{CubePrimitive, ExpandElementTyped},
4    prelude::*,
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9    unexpanded,
10};
11use crate::{
12    ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope},
13    prelude::assign_op_expand,
14};
15use core::ops::*;
16use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
17use cubecl_ir::ClampOperator;
18use half::{bf16, f16};
19
20pub mod add {
21    use super::*;
22
23    pub fn expand<C: CubePrimitive>(
24        scope: &mut Scope,
25        lhs: ExpandElementTyped<C>,
26        rhs: ExpandElementTyped<C>,
27    ) -> ExpandElementTyped<C> {
28        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
29    }
30}
31
32pub mod sub {
33    use cubecl_ir::{ConstantValue, Variable};
34
35    use super::*;
36
37    pub fn expand<C: CubePrimitive>(
38        scope: &mut Scope,
39        lhs: ExpandElementTyped<C>,
40        rhs: ExpandElementTyped<C>,
41    ) -> ExpandElementTyped<C> {
42        // Dirty hack to enable slice destructuring with trailing patterns on `Sequence`
43        match (lhs.expand.as_const(), rhs.expand.as_const()) {
44            (Some(ConstantValue::UInt(lhs_val)), Some(ConstantValue::UInt(rhs_val))) => {
45                let item_lhs = lhs.expand.ty;
46                let item_rhs = rhs.expand.ty;
47
48                let line_size = find_vectorization(item_lhs, item_rhs);
49
50                let item = item_lhs.line(line_size);
51                let value = (lhs_val - rhs_val).into();
52                ExpandElement::Plain(Variable::constant(value, item)).into()
53            }
54            _ => binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into(),
55        }
56    }
57}
58
59pub mod mul {
60    use super::*;
61
62    pub fn expand<C: CubePrimitive>(
63        scope: &mut Scope,
64        lhs: ExpandElementTyped<C>,
65        rhs: ExpandElementTyped<C>,
66    ) -> ExpandElementTyped<C> {
67        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
68    }
69}
70
71pub mod div {
72    use super::*;
73
74    pub fn expand<C: CubePrimitive>(
75        scope: &mut Scope,
76        lhs: ExpandElementTyped<C>,
77        rhs: ExpandElementTyped<C>,
78    ) -> ExpandElementTyped<C> {
79        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
80    }
81}
82
83pub mod rem {
84    use super::*;
85
86    pub fn expand<C: CubePrimitive>(
87        scope: &mut Scope,
88        lhs: ExpandElementTyped<C>,
89        rhs: ExpandElementTyped<C>,
90    ) -> ExpandElementTyped<C> {
91        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
92    }
93}
94
95pub mod and {
96    use super::*;
97
98    pub fn expand<C: CubePrimitive>(
99        scope: &mut Scope,
100        lhs: ExpandElementTyped<C>,
101        rhs: ExpandElementTyped<C>,
102    ) -> ExpandElementTyped<bool> {
103        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
104    }
105}
106
107pub mod bitand {
108    use super::*;
109
110    pub fn expand<C: CubePrimitive>(
111        scope: &mut Scope,
112        lhs: ExpandElementTyped<C>,
113        rhs: ExpandElementTyped<C>,
114    ) -> ExpandElementTyped<C> {
115        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
116    }
117}
118
119pub mod bitor {
120    use super::*;
121
122    pub fn expand<C: CubePrimitive>(
123        scope: &mut Scope,
124        lhs: ExpandElementTyped<C>,
125        rhs: ExpandElementTyped<C>,
126    ) -> ExpandElementTyped<C> {
127        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
128    }
129}
130
131pub mod or {
132    use super::*;
133
134    pub fn expand<C: CubePrimitive>(
135        scope: &mut Scope,
136        lhs: ExpandElementTyped<C>,
137        rhs: ExpandElementTyped<C>,
138    ) -> ExpandElementTyped<bool> {
139        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
140    }
141}
142
143pub mod bitxor {
144    use super::*;
145
146    pub fn expand<C: CubePrimitive>(
147        scope: &mut Scope,
148        lhs: ExpandElementTyped<C>,
149        rhs: ExpandElementTyped<C>,
150    ) -> ExpandElementTyped<C> {
151        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
152    }
153}
154
155pub mod shl {
156    use super::*;
157
158    pub fn expand<C: CubePrimitive>(
159        scope: &mut Scope,
160        lhs: ExpandElementTyped<C>,
161        rhs: ExpandElementTyped<C>,
162    ) -> ExpandElementTyped<C> {
163        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
164    }
165}
166
167pub mod shr {
168    use super::*;
169
170    pub fn expand<C: CubePrimitive>(
171        scope: &mut Scope,
172        lhs: ExpandElementTyped<C>,
173        rhs: ExpandElementTyped<C>,
174    ) -> ExpandElementTyped<C> {
175        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
176    }
177}
178
179pub mod clamp {
180    use super::*;
181
182    pub fn expand<C: PartialOrd + CubePrimitive>(
183        scope: &mut Scope,
184        input: ExpandElementTyped<C>,
185        min: ExpandElementTyped<C>,
186        max: ExpandElementTyped<C>,
187    ) -> ExpandElementTyped<C> {
188        unary_expand(scope, input.into(), |op| {
189            Arithmetic::Clamp(ClampOperator {
190                input: op.input,
191                min_value: *min.expand,
192                max_value: *max.expand,
193            })
194        })
195        .into()
196    }
197}
198
199pub mod clamp_max {
200    use super::*;
201
202    pub fn expand<C: PartialOrd + CubePrimitive>(
203        scope: &mut Scope,
204        lhs: ExpandElementTyped<C>,
205        rhs: ExpandElementTyped<C>,
206    ) -> ExpandElementTyped<C> {
207        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
208    }
209}
210
211pub mod clamp_min {
212    use super::*;
213
214    pub fn expand<C: PartialOrd + CubePrimitive>(
215        scope: &mut Scope,
216        lhs: ExpandElementTyped<C>,
217        rhs: ExpandElementTyped<C>,
218    ) -> ExpandElementTyped<C> {
219        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
220    }
221}
222
223/// The minimum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
224/// `clamp_max` may sometimes be more clear.
225pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
226    clamp_max(lhs, rhs)
227}
228
229pub mod min {
230    use super::*;
231
232    pub fn expand<C: PartialOrd + CubePrimitive>(
233        scope: &mut Scope,
234        lhs: ExpandElementTyped<C>,
235        rhs: ExpandElementTyped<C>,
236    ) -> ExpandElementTyped<C> {
237        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
238    }
239}
240
241/// The maximum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
242/// `clamp_min` may sometimes be more clear.
243pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
244    clamp_min(lhs, rhs)
245}
246
247pub mod max {
248    use super::*;
249
250    pub fn expand<C: PartialOrd + CubePrimitive>(
251        scope: &mut Scope,
252        lhs: ExpandElementTyped<C>,
253        rhs: ExpandElementTyped<C>,
254    ) -> ExpandElementTyped<C> {
255        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
256    }
257}
258
/// For binary functions without special syntax.
///
/// Each invocation generates:
/// - a trait `$trait_name` with a host-side `$method_name(self, rhs)` whose body is
///   `unexpanded!()`, plus an associated `__expand_<method>` that forwards to the expand trait;
/// - a trait `<$trait_name>Expand` declaring `__expand_<method>_method(self, scope, rhs)`;
/// - an empty `impl $trait_name` for every listed `$type`;
/// - a blanket `<$trait_name>Expand` impl for `ExpandElementTyped<T>` that calls `binary_expand`
///   with `$operator`, so the output has the same type as both operands.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            // User-facing trait; the expand type must implement the generated `...Expand` trait.
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Static entry point; delegates to the method form on the expanded `lhs`.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $method_name _method>](scope, rhs)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            // Opt the listed primitive types into the function.
            $(impl $trait_name for $type {})*
            impl<T: CubePrimitive + $trait_name> [<$trait_name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $operator).into()
                }
            }
        }
    }
}
290
/// Like `impl_binary_func`, but the output item is forced to a fixed line size.
///
/// The expand impl builds the output item as `lhs.ty.line($out_vectorization)` and lowers the call
/// with `binary_expand_fixed_output` instead of `binary_expand`. The output line size therefore
/// does not follow the operands.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Static entry point; delegates to the method form on the expanded `lhs`.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $method_name _method>](scope, rhs)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            $(impl $trait_name for $type {})*
            impl<T: CubePrimitive + $trait_name> [<$trait_name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let lhs: ExpandElement = self.into();
                    // Output keeps lhs's element type but uses the fixed line size.
                    let item = lhs.ty.line($out_vectorization);
                    binary_expand_fixed_output(scope, lhs, rhs.into(), item, $operator).into()
                }
            }
        }
    }
}
323
/// Binary functions whose right-hand side has a different type (`Rhs`) from the left-hand side.
///
/// The generated trait is generic over `Rhs`, but the listed `$type`s implement it only for
/// `Rhs = $rhs_ty`. Unlike `impl_binary_func`, the static `__expand_<method>` calls
/// `binary_expand` directly rather than forwarding to the `...Expand` trait method.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ExpandElement>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]<Rhs>> + Sized {
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    binary_expand(scope, lhs.into(), rhs.into(), $operator).into()
                }
            }

            pub trait [<$trait_name Expand>]<Rhs: CubeType>{
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            // Each listed left-hand type accepts only `$rhs_ty` on the right.
            $(impl $trait_name<$rhs_ty> for $type {})*
            impl<Rhs: CubePrimitive, T: CubePrimitive + $trait_name<Rhs>> [<$trait_name Expand>]<Rhs> for ExpandElementTyped<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: ExpandElementTyped<Rhs>) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $operator).into()
                }
            }
        }
    }
}
355
/// Bridges a `core::ops` binary operator trait (`Add`, `Sub`, ...) to its expansion.
///
/// Generates `Cube<$trait>` and `<$trait>Expand`, with blanket impls for every
/// `T: $trait<Output = T> + CubePrimitive`. `$method` becomes part of the generated names
/// `__expand_<method>` and `__expand_<method>_method`. It must therefore match the operator
/// trait's own method (e.g. `Div` -> `div`).
macro_rules! impl_core_binop {
    ($trait: ident, $method: ident, $op: expr) => {
        paste::paste! {
            pub trait [<Cube $trait>]: $trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait Expand>]> + Sized {
                // Static entry point; delegates to the method form on the expanded `lhs`.
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            pub trait [<$trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            impl<T: $trait<Output = T> + CubePrimitive> [<Cube $trait>] for T {}
            impl<T: $trait<Output = T> + CubePrimitive> [<$trait Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $op).into()
                }
            }
        }
    };
}
382
/// Bridges a `core::ops` compound-assignment trait (`AddAssign`, ...) to its expansion.
///
/// Same shape as `impl_core_binop`, but the generated functions return `()`. They lower the
/// operation through `assign_op_expand` instead of `binary_expand`. `$method` must match the
/// trait's own method name (e.g. `AddAssign` -> `add_assign`).
macro_rules! impl_core_assign_binop {
    ($trait: ident, $method: ident, $op: expr) => {
        paste::paste! {
            pub trait [<Cube $trait>]: $trait + CubePrimitive + CubeType<ExpandType: [<$trait Expand>]> + Sized {
                // Static entry point; delegates to the method form on the expanded `lhs`.
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            pub trait [<$trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self);
            }

            impl<T: $trait + CubePrimitive> [<Cube $trait>] for T {}
            impl<T: $trait + CubePrimitive> [<$trait Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) {
                    assign_op_expand(scope, self.into(), rhs.into(), $op);
                }
            }
        }
    };
}
409
410impl_core_binop!(Add, add, Arithmetic::Add);
411impl_core_binop!(Sub, sub, Arithmetic::Sub);
412impl_core_binop!(Mul, mul, Arithmetic::Mul);
413impl_core_binop!(Div, mul, Arithmetic::Div);
414impl_core_binop!(Rem, rem, Arithmetic::Modulo);
415
416impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
417impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
418impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
419impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
420impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
421
422pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
423    fn __expand_min(
424        scope: &mut Scope,
425        lhs: Self::ExpandType,
426        rhs: Self::ExpandType,
427    ) -> Self::ExpandType {
428        lhs.__expand_min_method(scope, rhs)
429    }
430
431    fn __expand_max(
432        scope: &mut Scope,
433        lhs: Self::ExpandType,
434        rhs: Self::ExpandType,
435    ) -> Self::ExpandType {
436        lhs.__expand_max_method(scope, rhs)
437    }
438
439    fn __expand_clamp(
440        scope: &mut Scope,
441        lhs: Self::ExpandType,
442        min: Self::ExpandType,
443        max: Self::ExpandType,
444    ) -> Self::ExpandType {
445        lhs.__expand_clamp_method(scope, min, max)
446    }
447}
448pub trait OrdExpand {
449    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
450    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
451    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
452}
453
454impl<T: Ord + CubePrimitive> CubeOrd for T {}
455impl<T: Ord + CubePrimitive> OrdExpand for ExpandElementTyped<T> {
456    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
457        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
458    }
459    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
460        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
461    }
462    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
463        unary_expand(scope, self.into(), |op| {
464            Arithmetic::Clamp(ClampOperator {
465                input: op.input,
466                min_value: *min.expand,
467                max_value: *max.expand,
468            })
469        })
470        .into()
471    }
472}
473
474impl_binary_func!(
475    Powf,
476    powf,
477    Arithmetic::Powf,
478    f16,
479    bf16,
480    flex32,
481    tf32,
482    f32,
483    f64
484);
485
486impl_binary_func!(
487    Hypot,
488    hypot,
489    Arithmetic::Hypot,
490    f16,
491    bf16,
492    flex32,
493    tf32,
494    f32,
495    f64
496);
497
498impl_binary_func!(
499    Rhypot,
500    rhypot,
501    Arithmetic::Rhypot,
502    f16,
503    bf16,
504    flex32,
505    tf32,
506    f32,
507    f64
508);
509
510impl_binary_func!(
511    ArcTan2,
512    atan2,
513    Arithmetic::ArcTan2,
514    f16,
515    bf16,
516    flex32,
517    tf32,
518    f32,
519    f64
520);
521impl_binary_func!(
522    Remainder,
523    rem,
524    Arithmetic::Remainder,
525    e2m1,
526    e4m3,
527    e5m2,
528    ue8m0,
529    f16,
530    bf16,
531    flex32,
532    tf32,
533    f32,
534    f64,
535    i8,
536    i16,
537    i32,
538    i64,
539    u8,
540    u16,
541    u32,
542    u64,
543    usize,
544    isize
545);
546impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
547impl_binary_func!(
548    SaturatingAdd,
549    saturating_add,
550    Arithmetic::SaturatingAdd,
551    i8,
552    i16,
553    i32,
554    i64,
555    u8,
556    u16,
557    u32,
558    u64,
559    usize,
560    isize
561);
562impl_binary_func!(
563    SaturatingSub,
564    saturating_sub,
565    Arithmetic::SaturatingSub,
566    i8,
567    i16,
568    i32,
569    i64,
570    u8,
571    u16,
572    u32,
573    u64,
574    usize,
575    isize
576);
577impl_binary_func_fixed_output_vectorization!(
578    Dot,
579    dot,
580    Arithmetic::Dot,
581    0,
582    f16,
583    bf16,
584    flex32,
585    tf32,
586    f32,
587    f64,
588    i8,
589    i16,
590    i32,
591    i64,
592    u8,
593    u16,
594    u32,
595    u64,
596    usize,
597    isize
598);
599
600impl_binary_func_mixed_types!(
601    Powi,
602    powi,
603    i32,
604    Arithmetic::Powi,
605    f16,
606    bf16,
607    flex32,
608    tf32,
609    f32,
610    f64,
611    i8,
612    i16,
613    i32,
614    i64,
615    u8,
616    u16,
617    u32,
618    u64,
619    usize,
620    isize
621);