cubecl_core/frontend/operation/
binary.rs

1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3    flex32,
4    frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9    unexpanded,
10};
11use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
12use half::{bf16, f16};
13
14pub mod add {
15    use super::*;
16
17    pub fn expand<C: CubePrimitive>(
18        scope: &mut Scope,
19        lhs: ExpandElementTyped<C>,
20        rhs: ExpandElementTyped<C>,
21    ) -> ExpandElementTyped<C> {
22        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
23    }
24}
25
26pub mod sub {
27    use super::*;
28
29    pub fn expand<C: CubePrimitive>(
30        scope: &mut Scope,
31        lhs: ExpandElementTyped<C>,
32        rhs: ExpandElementTyped<C>,
33    ) -> ExpandElementTyped<C> {
34        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
35    }
36}
37
38pub mod mul {
39    use super::*;
40
41    pub fn expand<C: CubePrimitive>(
42        scope: &mut Scope,
43        lhs: ExpandElementTyped<C>,
44        rhs: ExpandElementTyped<C>,
45    ) -> ExpandElementTyped<C> {
46        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
47    }
48}
49
50pub mod div {
51    use super::*;
52
53    pub fn expand<C: CubePrimitive>(
54        scope: &mut Scope,
55        lhs: ExpandElementTyped<C>,
56        rhs: ExpandElementTyped<C>,
57    ) -> ExpandElementTyped<C> {
58        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
59    }
60}
61
62pub mod rem {
63    use super::*;
64
65    pub fn expand<C: CubePrimitive>(
66        scope: &mut Scope,
67        lhs: ExpandElementTyped<C>,
68        rhs: ExpandElementTyped<C>,
69    ) -> ExpandElementTyped<C> {
70        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
71    }
72}
73
74pub mod and {
75    use super::*;
76
77    pub fn expand<C: CubePrimitive>(
78        scope: &mut Scope,
79        lhs: ExpandElementTyped<C>,
80        rhs: ExpandElementTyped<C>,
81    ) -> ExpandElementTyped<bool> {
82        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
83    }
84}
85
86pub mod bitand {
87    use super::*;
88
89    pub fn expand<C: CubePrimitive>(
90        scope: &mut Scope,
91        lhs: ExpandElementTyped<C>,
92        rhs: ExpandElementTyped<C>,
93    ) -> ExpandElementTyped<C> {
94        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
95    }
96}
97
98pub mod bitor {
99    use super::*;
100
101    pub fn expand<C: CubePrimitive>(
102        scope: &mut Scope,
103        lhs: ExpandElementTyped<C>,
104        rhs: ExpandElementTyped<C>,
105    ) -> ExpandElementTyped<C> {
106        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
107    }
108}
109
110pub mod or {
111    use super::*;
112
113    pub fn expand<C: CubePrimitive>(
114        scope: &mut Scope,
115        lhs: ExpandElementTyped<C>,
116        rhs: ExpandElementTyped<C>,
117    ) -> ExpandElementTyped<bool> {
118        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
119    }
120}
121
122pub mod bitxor {
123    use super::*;
124
125    pub fn expand<C: CubePrimitive>(
126        scope: &mut Scope,
127        lhs: ExpandElementTyped<C>,
128        rhs: ExpandElementTyped<C>,
129    ) -> ExpandElementTyped<C> {
130        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
131    }
132}
133
134pub mod shl {
135    use super::*;
136
137    pub fn expand<C: CubePrimitive>(
138        scope: &mut Scope,
139        lhs: ExpandElementTyped<C>,
140        rhs: ExpandElementTyped<C>,
141    ) -> ExpandElementTyped<C> {
142        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
143    }
144}
145
146pub mod shr {
147    use super::*;
148
149    pub fn expand<C: CubePrimitive>(
150        scope: &mut Scope,
151        lhs: ExpandElementTyped<C>,
152        rhs: ExpandElementTyped<C>,
153    ) -> ExpandElementTyped<C> {
154        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
155    }
156}
157
/// For binary functions without special syntax
///
/// Given a trait name, a method name, an operator and a list of types, this
/// generates:
///
/// * a public trait (bounded by `CubeType + Sized`) carrying the method, whose
///   default body is `unexpanded!()`, plus an associated `__expand_<method>`
///   function that forwards both operands to `binary_expand` with the operator;
/// * an empty implementation of that trait for every listed type;
/// * an inherent `__expand_<method>_method` on `ExpandElementTyped<T>` for every
///   listed type `T`, performing the same `binary_expand` call with `self` as
///   the left-hand side.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    let rhs: ExpandElement = rhs.into();
                    let out = binary_expand(scope, lhs, rhs, $operator);
                    out.into()
                }
            }

            // One trait impl and one inherent method-style expansion per type.
            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        let lhs: ExpandElement = self.into();
                        let rhs: ExpandElement = rhs.into();
                        let out = binary_expand(scope, lhs, rhs, $operator);
                        out.into()
                    }
                }
            )*
        }
    }
}
185
/// Variant of the binary-function generator whose output type is derived from
/// the left-hand operand with a caller-supplied vectorization.
///
/// The generated trait, trait impls and inherent `__expand_<method>_method`
/// have the same shape as a plain same-type binary function, but every
/// expansion computes the output type as `lhs.ty.line(<out_vectorization>)`
/// and calls `binary_expand_fixed_output` with it instead of `binary_expand`.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs_elem: ExpandElement = lhs.into();
                    // Output type: the lhs type re-lined with the fixed vectorization.
                    let out_ty = lhs_elem.ty.line($out_vectorization);
                    let rhs_elem: ExpandElement = rhs.into();
                    let out =
                        binary_expand_fixed_output(scope, lhs_elem, rhs_elem, out_ty, $operator);
                    out.into()
                }
            }

            // One trait impl and one inherent method-style expansion per type.
            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        let lhs_elem: ExpandElement = self.into();
                        // Output type: the lhs type re-lined with the fixed vectorization.
                        let out_ty = lhs_elem.ty.line($out_vectorization);
                        let rhs_elem: ExpandElement = rhs.into();
                        let out = binary_expand_fixed_output(
                            scope, lhs_elem, rhs_elem, out_ty, $operator,
                        );
                        out.into()
                    }
                }
            )*
        }
    }
}
216
/// Binary-function generator for operations whose right-hand operand has a
/// different, fixed type.
///
/// The generated trait is generic over `Rhs: CubeType + Sized`; its method
/// takes `Rhs` and returns `Self`, with a default body of `unexpanded!()`.
/// The associated `__expand_<method>` function forwards both operands to
/// `binary_expand` with the operator. For every listed type the trait is
/// implemented with the given right-hand type, and `ExpandElementTyped<T>`
/// gains an inherent `__expand_<method>_method` taking that right-hand type.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubeType + Sized>: CubeType + Sized {
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    let rhs: ExpandElement = rhs.into();
                    let out = binary_expand(scope, lhs, rhs, $operator);
                    out.into()
                }
            }

            // One trait impl and one inherent method-style expansion per type.
            $(
                impl $trait_name<$rhs_ty> for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$rhs_ty>,
                    ) -> ExpandElementTyped<$type> {
                        let lhs: ExpandElement = self.into();
                        let rhs: ExpandElement = rhs.into();
                        let out = binary_expand(scope, lhs, rhs, $operator);
                        out.into()
                    }
                }
            )*
        }
    }
}
243
// Defines the `Powf` trait (method `powf`, expanded with `Arithmetic::Powf`)
// and implements it for f16, bf16, flex32, tf32, f32 and f64.
impl_binary_func!(
    Powf,
    powf,
    Arithmetic::Powf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
255
// Defines the `Hypot` trait (method `hypot`, expanded with `Arithmetic::Hypot`)
// and implements it for f16, bf16, flex32, tf32, f32 and f64.
impl_binary_func!(
    Hypot,
    hypot,
    Arithmetic::Hypot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
267
// Defines the `Rhypot` trait (method `rhypot`, expanded with
// `Arithmetic::Rhypot`) and implements it for f16, bf16, flex32, tf32, f32
// and f64.
impl_binary_func!(
    Rhypot,
    rhypot,
    Arithmetic::Rhypot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
279
// Defines the `ArcTan2` trait (method `atan2`, expanded with
// `Arithmetic::ArcTan2`) and implements it for f16, bf16, flex32, tf32, f32
// and f64.
impl_binary_func!(
    ArcTan2,
    atan2,
    Arithmetic::ArcTan2,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Defines the `Max` trait (method `max`, expanded with `Arithmetic::Max`) and
// implements it for the `cubecl_common` types e2m1, e4m3, e5m2 and ue8m0, for
// f16, bf16, flex32, tf32, f32 and f64, and for the 8- to 64-bit signed and
// unsigned integers.
impl_binary_func!(
    Max,
    max,
    Arithmetic::Max,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Defines the `Min` trait (method `min`, expanded with `Arithmetic::Min`) and
// implements it for the `cubecl_common` types e2m1, e4m3, e5m2 and ue8m0, for
// f16, bf16, flex32, tf32, f32 and f64, and for the 8- to 64-bit signed and
// unsigned integers.
impl_binary_func!(
    Min,
    min,
    Arithmetic::Min,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Defines the `Remainder` trait (method `rem`, expanded with
// `Arithmetic::Remainder`) and implements it for the `cubecl_common` types
// e2m1, e4m3, e5m2 and ue8m0, for f16, bf16, flex32, tf32, f32 and f64, and
// for the 8- to 64-bit signed and unsigned integers.
// Note: the `rem` module in this file expands with `Arithmetic::Modulo`
// instead; the two use different operators.
impl_binary_func!(
    Remainder,
    rem,
    Arithmetic::Remainder,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Defines the `MulHi` trait (method `mul_hi`, expanded with
// `Arithmetic::MulHi`); only i32 and u32 implement it.
impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32);
// Defines the `SaturatingAdd` trait (method `saturating_add`, expanded with
// `Arithmetic::SaturatingAdd`) and implements it for the 8- to 64-bit signed
// and unsigned integers.
impl_binary_func!(
    SaturatingAdd,
    saturating_add,
    Arithmetic::SaturatingAdd,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Defines the `SaturatingSub` trait (method `saturating_sub`, expanded with
// `Arithmetic::SaturatingSub`) and implements it for the 8- to 64-bit signed
// and unsigned integers.
impl_binary_func!(
    SaturatingSub,
    saturating_sub,
    Arithmetic::SaturatingSub,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Defines the `Dot` trait (method `dot`, expanded with `Arithmetic::Dot`) for
// f16, bf16, flex32, tf32, f32, f64 and the 8- to 64-bit signed and unsigned
// integers. The `0` argument is the fixed output vectorization: the macro
// passes it to `lhs.ty.line(..)` to build the output type.
// NOTE(review): presumably a line size of 0 denotes a scalar result — confirm
// against the definition of `line`.
impl_binary_func_fixed_output_vectorization!(
    Dot,
    dot,
    Arithmetic::Dot,
    0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
407
// Defines the `Powi` trait (method `powi`, expanded with `Arithmetic::Powi`).
// The right-hand operand type is `i32` for every implementing type: f16,
// bf16, flex32, tf32, f32, f64 and the 8- to 64-bit signed and unsigned
// integers.
impl_binary_func_mixed_types!(
    Powi,
    powi,
    i32,
    Arithmetic::Powi,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);