// cubecl_core/frontend/operation/binary.rs

1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3    flex32,
4    frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9    unexpanded,
10};
11use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
12use half::{bf16, f16};
13
14pub mod add {
15    use super::*;
16
17    pub fn expand<C: CubePrimitive>(
18        scope: &mut Scope,
19        lhs: ExpandElementTyped<C>,
20        rhs: ExpandElementTyped<C>,
21    ) -> ExpandElementTyped<C> {
22        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
23    }
24}
25
26pub mod sub {
27    use super::*;
28
29    pub fn expand<C: CubePrimitive>(
30        scope: &mut Scope,
31        lhs: ExpandElementTyped<C>,
32        rhs: ExpandElementTyped<C>,
33    ) -> ExpandElementTyped<C> {
34        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
35    }
36}
37
38pub mod mul {
39    use super::*;
40
41    pub fn expand<C: CubePrimitive>(
42        scope: &mut Scope,
43        lhs: ExpandElementTyped<C>,
44        rhs: ExpandElementTyped<C>,
45    ) -> ExpandElementTyped<C> {
46        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
47    }
48}
49
50pub mod div {
51    use super::*;
52
53    pub fn expand<C: CubePrimitive>(
54        scope: &mut Scope,
55        lhs: ExpandElementTyped<C>,
56        rhs: ExpandElementTyped<C>,
57    ) -> ExpandElementTyped<C> {
58        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
59    }
60}
61
62pub mod rem {
63    use super::*;
64
65    pub fn expand<C: CubePrimitive>(
66        scope: &mut Scope,
67        lhs: ExpandElementTyped<C>,
68        rhs: ExpandElementTyped<C>,
69    ) -> ExpandElementTyped<C> {
70        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
71    }
72}
73
74pub mod and {
75    use super::*;
76
77    pub fn expand<C: CubePrimitive>(
78        scope: &mut Scope,
79        lhs: ExpandElementTyped<C>,
80        rhs: ExpandElementTyped<C>,
81    ) -> ExpandElementTyped<bool> {
82        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
83    }
84}
85
86pub mod bitand {
87    use super::*;
88
89    pub fn expand<C: CubePrimitive>(
90        scope: &mut Scope,
91        lhs: ExpandElementTyped<C>,
92        rhs: ExpandElementTyped<C>,
93    ) -> ExpandElementTyped<C> {
94        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
95    }
96}
97
98pub mod bitor {
99    use super::*;
100
101    pub fn expand<C: CubePrimitive>(
102        scope: &mut Scope,
103        lhs: ExpandElementTyped<C>,
104        rhs: ExpandElementTyped<C>,
105    ) -> ExpandElementTyped<C> {
106        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
107    }
108}
109
110pub mod or {
111    use super::*;
112
113    pub fn expand<C: CubePrimitive>(
114        scope: &mut Scope,
115        lhs: ExpandElementTyped<C>,
116        rhs: ExpandElementTyped<C>,
117    ) -> ExpandElementTyped<bool> {
118        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
119    }
120}
121
122pub mod bitxor {
123    use super::*;
124
125    pub fn expand<C: CubePrimitive>(
126        scope: &mut Scope,
127        lhs: ExpandElementTyped<C>,
128        rhs: ExpandElementTyped<C>,
129    ) -> ExpandElementTyped<C> {
130        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
131    }
132}
133
134pub mod shl {
135    use super::*;
136
137    pub fn expand<C: CubePrimitive>(
138        scope: &mut Scope,
139        lhs: ExpandElementTyped<C>,
140        rhs: ExpandElementTyped<C>,
141    ) -> ExpandElementTyped<C> {
142        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
143    }
144}
145
146pub mod shr {
147    use super::*;
148
149    pub fn expand<C: CubePrimitive>(
150        scope: &mut Scope,
151        lhs: ExpandElementTyped<C>,
152        rhs: ExpandElementTyped<C>,
153    ) -> ExpandElementTyped<C> {
154        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
155    }
156}
157
/// For binary functions without special syntax.
///
/// Given a trait name, a method name, an operator constructor and a list of types, this
/// generates:
/// - a trait `$trait_name` whose `$method_name` is a placeholder (`unexpanded!()`) and whose
///   `__expand_$method_name` builds the operation with `binary_expand`;
/// - an empty impl of that trait for every listed type (so each uses the default expansion);
/// - an inherent `__expand_$method_name_method` on `ExpandElementTyped<T>` for every listed
///   type, which forwards to the trait's expand function.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    let rhs: ExpandElement = rhs.into();
                    binary_expand(scope, lhs, rhs, $operator).into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        // Forward to the trait's default expansion so both entry points
                        // share one implementation.
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
185
/// Variant of `impl_binary_func!` whose output type is chosen explicitly.
///
/// The output type is computed as `lhs.ty.line($out_vectorization)` (the left operand's
/// type with the given line size) and passed to `binary_expand_fixed_output`. The generated
/// trait, empty impls and inherent `__expand_$method_name_method` forwarders mirror
/// `impl_binary_func!`.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    // Derive the output type from `lhs` before `lhs` is moved below.
                    let out_ty = lhs.ty.line($out_vectorization);
                    let rhs: ExpandElement = rhs.into();
                    binary_expand_fixed_output(scope, lhs, rhs, out_ty, $operator).into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        // Forward to the trait's default expansion so both entry points
                        // share one implementation.
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
216
/// Variant of `impl_binary_func!` where the right operand has its own type.
///
/// The generated trait is generic over `Rhs`. Each listed type implements it for the single
/// right-hand type `$rhs_ty`, and gets an inherent `__expand_$method_name_method` taking an
/// `ExpandElementTyped<$rhs_ty>`. That method forwards to the trait's expand function,
/// which builds the operation with `binary_expand`.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubeType + Sized>: CubeType + Sized {
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    let rhs: ExpandElement = rhs.into();
                    binary_expand(scope, lhs, rhs, $operator).into()
                }
            }

            $(
                impl $trait_name<$rhs_ty> for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$rhs_ty>,
                    ) -> ExpandElementTyped<$type> {
                        // Forward to the trait's default expansion so both entry points
                        // share one implementation.
                        <$type as $trait_name<$rhs_ty>>::[<__expand_ $method_name>](
                            scope, self, rhs,
                        )
                    }
                }
            )*
        }
    };
}
243
244impl_binary_func!(
245    Powf,
246    powf,
247    Arithmetic::Powf,
248    f16,
249    bf16,
250    flex32,
251    tf32,
252    f32,
253    f64
254);
255impl_binary_func!(
256    ArcTan2,
257    atan2,
258    Arithmetic::ArcTan2,
259    f16,
260    bf16,
261    flex32,
262    tf32,
263    f32,
264    f64
265);
266impl_binary_func!(
267    Max,
268    max,
269    Arithmetic::Max,
270    e2m1,
271    e4m3,
272    e5m2,
273    ue8m0,
274    f16,
275    bf16,
276    flex32,
277    tf32,
278    f32,
279    f64,
280    i8,
281    i16,
282    i32,
283    i64,
284    u8,
285    u16,
286    u32,
287    u64
288);
289impl_binary_func!(
290    Min,
291    min,
292    Arithmetic::Min,
293    e2m1,
294    e4m3,
295    e5m2,
296    ue8m0,
297    f16,
298    bf16,
299    flex32,
300    tf32,
301    f32,
302    f64,
303    i8,
304    i16,
305    i32,
306    i64,
307    u8,
308    u16,
309    u32,
310    u64
311);
312impl_binary_func!(
313    Remainder,
314    rem,
315    Arithmetic::Remainder,
316    e2m1,
317    e4m3,
318    e5m2,
319    ue8m0,
320    f16,
321    bf16,
322    flex32,
323    tf32,
324    f32,
325    f64,
326    i8,
327    i16,
328    i32,
329    i64,
330    u8,
331    u16,
332    u32,
333    u64
334);
335impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32);
336impl_binary_func!(
337    SaturatingAdd,
338    saturating_add,
339    Arithmetic::SaturatingAdd,
340    i8,
341    i16,
342    i32,
343    i64,
344    u8,
345    u16,
346    u32,
347    u64
348);
349impl_binary_func!(
350    SaturatingSub,
351    saturating_sub,
352    Arithmetic::SaturatingSub,
353    i8,
354    i16,
355    i32,
356    i64,
357    u8,
358    u16,
359    u32,
360    u64
361);
362impl_binary_func_fixed_output_vectorization!(
363    Dot,
364    dot,
365    Arithmetic::Dot,
366    0,
367    f16,
368    bf16,
369    flex32,
370    tf32,
371    f32,
372    f64,
373    i8,
374    i16,
375    i32,
376    i64,
377    u8,
378    u16,
379    u32,
380    u64
381);
382
383impl_binary_func_mixed_types!(
384    Powi,
385    powi,
386    i32,
387    Arithmetic::Powi,
388    f16,
389    bf16,
390    flex32,
391    tf32,
392    f32,
393    f64,
394    i8,
395    i16,
396    i32,
397    i64,
398    u8,
399    u16,
400    u32,
401    u64
402);