cubecl_core/frontend/operation/
binary.rs

1use crate::{
2    flex32,
3    frontend::{CubePrimitive, ExpandElementTyped},
4    prelude::*,
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9    unexpanded,
10};
11use crate::{
12    ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope},
13    prelude::assign_op_expand,
14};
15use core::ops::*;
16use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
17use cubecl_ir::ClampOperator;
18use half::{bf16, f16};
19
20pub mod add {
21    use super::*;
22
23    pub fn expand<C: CubePrimitive>(
24        scope: &mut Scope,
25        lhs: ExpandElementTyped<C>,
26        rhs: ExpandElementTyped<C>,
27    ) -> ExpandElementTyped<C> {
28        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
29    }
30}
31
32pub mod sub {
33    use super::*;
34
35    pub fn expand<C: CubePrimitive>(
36        scope: &mut Scope,
37        lhs: ExpandElementTyped<C>,
38        rhs: ExpandElementTyped<C>,
39    ) -> ExpandElementTyped<C> {
40        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
41    }
42}
43
44pub mod mul {
45    use super::*;
46
47    pub fn expand<C: CubePrimitive>(
48        scope: &mut Scope,
49        lhs: ExpandElementTyped<C>,
50        rhs: ExpandElementTyped<C>,
51    ) -> ExpandElementTyped<C> {
52        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
53    }
54}
55
56pub mod div {
57    use super::*;
58
59    pub fn expand<C: CubePrimitive>(
60        scope: &mut Scope,
61        lhs: ExpandElementTyped<C>,
62        rhs: ExpandElementTyped<C>,
63    ) -> ExpandElementTyped<C> {
64        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
65    }
66}
67
68pub mod rem {
69    use super::*;
70
71    pub fn expand<C: CubePrimitive>(
72        scope: &mut Scope,
73        lhs: ExpandElementTyped<C>,
74        rhs: ExpandElementTyped<C>,
75    ) -> ExpandElementTyped<C> {
76        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
77    }
78}
79
80pub mod and {
81    use super::*;
82
83    pub fn expand<C: CubePrimitive>(
84        scope: &mut Scope,
85        lhs: ExpandElementTyped<C>,
86        rhs: ExpandElementTyped<C>,
87    ) -> ExpandElementTyped<bool> {
88        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
89    }
90}
91
92pub mod bitand {
93    use super::*;
94
95    pub fn expand<C: CubePrimitive>(
96        scope: &mut Scope,
97        lhs: ExpandElementTyped<C>,
98        rhs: ExpandElementTyped<C>,
99    ) -> ExpandElementTyped<C> {
100        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
101    }
102}
103
104pub mod bitor {
105    use super::*;
106
107    pub fn expand<C: CubePrimitive>(
108        scope: &mut Scope,
109        lhs: ExpandElementTyped<C>,
110        rhs: ExpandElementTyped<C>,
111    ) -> ExpandElementTyped<C> {
112        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
113    }
114}
115
116pub mod or {
117    use super::*;
118
119    pub fn expand<C: CubePrimitive>(
120        scope: &mut Scope,
121        lhs: ExpandElementTyped<C>,
122        rhs: ExpandElementTyped<C>,
123    ) -> ExpandElementTyped<bool> {
124        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
125    }
126}
127
128pub mod bitxor {
129    use super::*;
130
131    pub fn expand<C: CubePrimitive>(
132        scope: &mut Scope,
133        lhs: ExpandElementTyped<C>,
134        rhs: ExpandElementTyped<C>,
135    ) -> ExpandElementTyped<C> {
136        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
137    }
138}
139
140pub mod shl {
141    use super::*;
142
143    pub fn expand<C: CubePrimitive>(
144        scope: &mut Scope,
145        lhs: ExpandElementTyped<C>,
146        rhs: ExpandElementTyped<C>,
147    ) -> ExpandElementTyped<C> {
148        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
149    }
150}
151
152pub mod shr {
153    use super::*;
154
155    pub fn expand<C: CubePrimitive>(
156        scope: &mut Scope,
157        lhs: ExpandElementTyped<C>,
158        rhs: ExpandElementTyped<C>,
159    ) -> ExpandElementTyped<C> {
160        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
161    }
162}
163
164pub mod clamp {
165    use super::*;
166
167    pub fn expand<C: PartialOrd + CubePrimitive>(
168        scope: &mut Scope,
169        input: ExpandElementTyped<C>,
170        min: ExpandElementTyped<C>,
171        max: ExpandElementTyped<C>,
172    ) -> ExpandElementTyped<C> {
173        unary_expand(scope, input.into(), |op| {
174            Arithmetic::Clamp(ClampOperator {
175                input: op.input,
176                min_value: *min.expand,
177                max_value: *max.expand,
178            })
179        })
180        .into()
181    }
182}
183
184pub mod clamp_max {
185    use super::*;
186
187    pub fn expand<C: PartialOrd + CubePrimitive>(
188        scope: &mut Scope,
189        lhs: ExpandElementTyped<C>,
190        rhs: ExpandElementTyped<C>,
191    ) -> ExpandElementTyped<C> {
192        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
193    }
194}
195
196pub mod clamp_min {
197    use super::*;
198
199    pub fn expand<C: PartialOrd + CubePrimitive>(
200        scope: &mut Scope,
201        lhs: ExpandElementTyped<C>,
202        rhs: ExpandElementTyped<C>,
203    ) -> ExpandElementTyped<C> {
204        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
205    }
206}
207
208/// The minimum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
209/// `clamp_max` may sometimes be more clear.
210pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
211    clamp_max(lhs, rhs)
212}
213
214pub mod min {
215    use super::*;
216
217    pub fn expand<C: PartialOrd + CubePrimitive>(
218        scope: &mut Scope,
219        lhs: ExpandElementTyped<C>,
220        rhs: ExpandElementTyped<C>,
221    ) -> ExpandElementTyped<C> {
222        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
223    }
224}
225
226/// The maximum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
227/// `clamp_min` may sometimes be more clear.
228pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
229    clamp_min(lhs, rhs)
230}
231
232pub mod max {
233    use super::*;
234
235    pub fn expand<C: PartialOrd + CubePrimitive>(
236        scope: &mut Scope,
237        lhs: ExpandElementTyped<C>,
238        rhs: ExpandElementTyped<C>,
239    ) -> ExpandElementTyped<C> {
240        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
241    }
242}
243
/// Declares a same-type binary function that has no dedicated operator syntax.
///
/// Arguments: trait name, method name, operator expression, then the list of types that
/// implement the trait. The expansion contains:
/// - the trait itself, with an `unexpanded!()` method and an `__expand_<method>` entry point;
/// - a companion `<Trait>Expand` trait carrying `__expand_<method>_method`;
/// - an empty impl of the trait for every listed type;
/// - a blanket `<Trait>Expand` impl for `ExpandElementTyped<T>` that calls `binary_expand`.
macro_rules! impl_binary_func {
    ($name:ident, $method:ident, $op:expr, $($ty:ty),*) => {
        paste::paste! {
            // Expand-side trait, required of the expand type of every `$name` implementor.
            pub trait [<$name Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                // Placeholder whose body is `unexpanded!()`.
                fn $method(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Function-style entry point; defers to the method form on `lhs`.
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            impl<T> [<$name Expand>] for ExpandElementTyped<T>
            where
                T: CubePrimitive + $name,
            {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let out = binary_expand(scope, self.into(), rhs.into(), $op);
                    out.into()
                }
            }

            $(impl $name for $ty {})*
        }
    };
}
275
/// Declares a same-type binary function whose output type is derived from the left operand's
/// type via `.line(<output vectorization>)` rather than taken from the operands as-is.
///
/// Arguments: trait name, method name, operator expression, output-vectorization expression,
/// then the list of implementing types. The generated items mirror `impl_binary_func`, except
/// that the blanket impl calls `binary_expand_fixed_output` with the computed output type.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($name:ident, $method:ident, $op:expr, $out_vec: expr, $($ty:ty),*) => {
        paste::paste! {
            // Expand-side trait, required of the expand type of every `$name` implementor.
            pub trait [<$name Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                // Placeholder whose body is `unexpanded!()`.
                fn $method(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Function-style entry point; defers to the method form on `lhs`.
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            impl<T> [<$name Expand>] for ExpandElementTyped<T>
            where
                T: CubePrimitive + $name,
            {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let left: ExpandElement = self.into();
                    // Output type: the left operand's type re-lined with the requested value.
                    let out_ty = left.ty.line($out_vec);
                    let out = binary_expand_fixed_output(scope, left, rhs.into(), out_ty, $op);
                    out.into()
                }
            }

            $(impl $name for $ty {})*
        }
    };
}
308
/// Declares a binary function whose right-hand operand has a different type than `Self`.
///
/// Arguments: trait name, method name, the concrete right-hand type used for the generated
/// impls, operator expression, then the list of implementing (left-hand) types. Both the
/// trait's `__expand_<method>` entry point and the blanket `<Trait>Expand` impl call
/// `binary_expand` directly with the given operator.
macro_rules! impl_binary_func_mixed_types {
    ($name:ident, $method:ident, $rhs: ident, $op:expr, $($ty:ty),*) => {
        paste::paste! {
            // Expand-side trait, generic over the right-hand primitive.
            pub trait [<$name Expand>]<Rhs: CubeType> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            pub trait $name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ExpandElement>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$name Expand>]<Rhs>> + Sized
            {
                // Placeholder whose body is `unexpanded!()`.
                fn $method(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                // Function-style entry point; builds the operation directly.
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    let out = binary_expand(scope, lhs.into(), rhs.into(), $op);
                    out.into()
                }
            }

            impl<Rhs, T> [<$name Expand>]<Rhs> for ExpandElementTyped<T>
            where
                Rhs: CubePrimitive,
                T: CubePrimitive + $name<Rhs>,
            {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: ExpandElementTyped<Rhs>) -> Self {
                    let out = binary_expand(scope, self.into(), rhs.into(), $op);
                    out.into()
                }
            }

            $(impl $name<$rhs> for $ty {})*
        }
    };
}
340
/// Bridges a `core::ops` binary operator trait to the expansion machinery.
///
/// Arguments: the operator trait (e.g. `Add`), its method name (e.g. `add`) and the operator
/// expression handed to `binary_expand`. The expansion contains a `Cube<Trait>` trait with an
/// `__expand_<method>` entry point, a `<Trait>Expand` trait with `__expand_<method>_method`,
/// and blanket impls of both for every `T: Trait<Output = T> + CubePrimitive`.
macro_rules! impl_core_binop {
    ($std_trait: ident, $fn_name: ident, $operator: expr) => {
        paste::paste! {
            // Expand-side trait carrying the method form of the operator.
            pub trait [<$std_trait Expand>] {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait [<Cube $std_trait>]: $std_trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$std_trait Expand>]> + Sized {
                // Function-style entry point; defers to the method form on `lhs`.
                fn [<__expand_ $fn_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    lhs.[<__expand_ $fn_name _method>](scope, rhs)
                }
            }

            impl<T> [<Cube $std_trait>] for T where T: $std_trait<Output = T> + CubePrimitive {}

            impl<T> [<$std_trait Expand>] for ExpandElementTyped<T>
            where
                T: $std_trait<Output = T> + CubePrimitive,
            {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let out = binary_expand(scope, self.into(), rhs.into(), $operator);
                    out.into()
                }
            }
        }
    };
}
367
/// Bridges a `core::ops` compound-assignment trait to the expansion machinery.
///
/// Arguments: the assignment trait (e.g. `AddAssign`), its method name (e.g. `add_assign`) and
/// the operator expression handed to `assign_op_expand`. The generated entry points return
/// nothing. Blanket impls cover every `T: Trait + CubePrimitive`.
macro_rules! impl_core_assign_binop {
    ($std_trait: ident, $fn_name: ident, $operator: expr) => {
        paste::paste! {
            // Expand-side trait carrying the method form of the assignment operator.
            pub trait [<$std_trait Expand>] {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self);
            }

            pub trait [<Cube $std_trait>]: $std_trait + CubePrimitive + CubeType<ExpandType: [<$std_trait Expand>]> + Sized {
                // Function-style entry point; defers to the method form on `lhs`.
                fn [<__expand_ $fn_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) {
                    lhs.[<__expand_ $fn_name _method>](scope, rhs)
                }
            }

            impl<T> [<Cube $std_trait>] for T where T: $std_trait + CubePrimitive {}

            impl<T> [<$std_trait Expand>] for ExpandElementTyped<T>
            where
                T: $std_trait + CubePrimitive,
            {
                fn [<__expand_ $fn_name _method>](self, scope: &mut Scope, rhs: Self) {
                    assign_op_expand(scope, self.into(), rhs.into(), $operator);
                }
            }
        }
    };
}
394
395impl_core_binop!(Add, add, Arithmetic::Add);
396impl_core_binop!(Sub, sub, Arithmetic::Sub);
397impl_core_binop!(Mul, mul, Arithmetic::Mul);
398impl_core_binop!(Div, mul, Arithmetic::Div);
399impl_core_binop!(Rem, rem, Arithmetic::Modulo);
400
401impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
402impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
403impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
404impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
405impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
406
407pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
408    fn __expand_min(
409        scope: &mut Scope,
410        lhs: Self::ExpandType,
411        rhs: Self::ExpandType,
412    ) -> Self::ExpandType {
413        lhs.__expand_min_method(scope, rhs)
414    }
415
416    fn __expand_max(
417        scope: &mut Scope,
418        lhs: Self::ExpandType,
419        rhs: Self::ExpandType,
420    ) -> Self::ExpandType {
421        lhs.__expand_max_method(scope, rhs)
422    }
423
424    fn __expand_clamp(
425        scope: &mut Scope,
426        lhs: Self::ExpandType,
427        min: Self::ExpandType,
428        max: Self::ExpandType,
429    ) -> Self::ExpandType {
430        lhs.__expand_clamp_method(scope, min, max)
431    }
432}
433pub trait OrdExpand {
434    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
435    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
436    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
437}
438
439impl<T: Ord + CubePrimitive> CubeOrd for T {}
440impl<T: Ord + CubePrimitive> OrdExpand for ExpandElementTyped<T> {
441    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
442        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
443    }
444    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
445        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
446    }
447    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
448        unary_expand(scope, self.into(), |op| {
449            Arithmetic::Clamp(ClampOperator {
450                input: op.input,
451                min_value: *min.expand,
452                max_value: *max.expand,
453            })
454        })
455        .into()
456    }
457}
458
459impl_binary_func!(
460    Powf,
461    powf,
462    Arithmetic::Powf,
463    f16,
464    bf16,
465    flex32,
466    tf32,
467    f32,
468    f64
469);
470
471impl_binary_func!(
472    Hypot,
473    hypot,
474    Arithmetic::Hypot,
475    f16,
476    bf16,
477    flex32,
478    tf32,
479    f32,
480    f64
481);
482
483impl_binary_func!(
484    Rhypot,
485    rhypot,
486    Arithmetic::Rhypot,
487    f16,
488    bf16,
489    flex32,
490    tf32,
491    f32,
492    f64
493);
494
495impl_binary_func!(
496    ArcTan2,
497    atan2,
498    Arithmetic::ArcTan2,
499    f16,
500    bf16,
501    flex32,
502    tf32,
503    f32,
504    f64
505);
506impl_binary_func!(
507    Remainder,
508    rem,
509    Arithmetic::Remainder,
510    e2m1,
511    e4m3,
512    e5m2,
513    ue8m0,
514    f16,
515    bf16,
516    flex32,
517    tf32,
518    f32,
519    f64,
520    i8,
521    i16,
522    i32,
523    i64,
524    u8,
525    u16,
526    u32,
527    u64,
528    usize,
529    isize
530);
531impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
532impl_binary_func!(
533    SaturatingAdd,
534    saturating_add,
535    Arithmetic::SaturatingAdd,
536    i8,
537    i16,
538    i32,
539    i64,
540    u8,
541    u16,
542    u32,
543    u64,
544    usize,
545    isize
546);
547impl_binary_func!(
548    SaturatingSub,
549    saturating_sub,
550    Arithmetic::SaturatingSub,
551    i8,
552    i16,
553    i32,
554    i64,
555    u8,
556    u16,
557    u32,
558    u64,
559    usize,
560    isize
561);
562impl_binary_func_fixed_output_vectorization!(
563    Dot,
564    dot,
565    Arithmetic::Dot,
566    0,
567    f16,
568    bf16,
569    flex32,
570    tf32,
571    f32,
572    f64,
573    i8,
574    i16,
575    i32,
576    i64,
577    u8,
578    u16,
579    u32,
580    u64,
581    usize,
582    isize
583);
584
585impl_binary_func_mixed_types!(
586    Powi,
587    powi,
588    i32,
589    Arithmetic::Powi,
590    f16,
591    bf16,
592    flex32,
593    tf32,
594    f32,
595    f64,
596    i8,
597    i16,
598    i32,
599    i64,
600    u8,
601    u16,
602    u32,
603    u64,
604    usize,
605    isize
606);