cubecl_core/frontend/operation/binary.rs

1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3    flex32,
4    frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9    unexpanded,
10};
11use half::{bf16, f16};
12
13pub mod add {
14    use super::*;
15
16    pub fn expand<C: CubePrimitive>(
17        scope: &mut Scope,
18        lhs: ExpandElementTyped<C>,
19        rhs: ExpandElementTyped<C>,
20    ) -> ExpandElementTyped<C> {
21        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
22    }
23}
24
25pub mod sub {
26    use super::*;
27
28    pub fn expand<C: CubePrimitive>(
29        scope: &mut Scope,
30        lhs: ExpandElementTyped<C>,
31        rhs: ExpandElementTyped<C>,
32    ) -> ExpandElementTyped<C> {
33        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
34    }
35}
36
37pub mod mul {
38    use super::*;
39
40    pub fn expand<C: CubePrimitive>(
41        scope: &mut Scope,
42        lhs: ExpandElementTyped<C>,
43        rhs: ExpandElementTyped<C>,
44    ) -> ExpandElementTyped<C> {
45        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
46    }
47}
48
49pub mod div {
50    use super::*;
51
52    pub fn expand<C: CubePrimitive>(
53        scope: &mut Scope,
54        lhs: ExpandElementTyped<C>,
55        rhs: ExpandElementTyped<C>,
56    ) -> ExpandElementTyped<C> {
57        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
58    }
59}
60
61pub mod rem {
62    use super::*;
63
64    pub fn expand<C: CubePrimitive>(
65        scope: &mut Scope,
66        lhs: ExpandElementTyped<C>,
67        rhs: ExpandElementTyped<C>,
68    ) -> ExpandElementTyped<C> {
69        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
70    }
71}
72
73pub mod and {
74    use super::*;
75
76    pub fn expand<C: CubePrimitive>(
77        scope: &mut Scope,
78        lhs: ExpandElementTyped<C>,
79        rhs: ExpandElementTyped<C>,
80    ) -> ExpandElementTyped<bool> {
81        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
82    }
83}
84
85pub mod bitand {
86    use super::*;
87
88    pub fn expand<C: CubePrimitive>(
89        scope: &mut Scope,
90        lhs: ExpandElementTyped<C>,
91        rhs: ExpandElementTyped<C>,
92    ) -> ExpandElementTyped<C> {
93        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
94    }
95}
96
97pub mod bitor {
98    use super::*;
99
100    pub fn expand<C: CubePrimitive>(
101        scope: &mut Scope,
102        lhs: ExpandElementTyped<C>,
103        rhs: ExpandElementTyped<C>,
104    ) -> ExpandElementTyped<C> {
105        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
106    }
107}
108
109pub mod or {
110    use super::*;
111
112    pub fn expand<C: CubePrimitive>(
113        scope: &mut Scope,
114        lhs: ExpandElementTyped<C>,
115        rhs: ExpandElementTyped<C>,
116    ) -> ExpandElementTyped<bool> {
117        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
118    }
119}
120
121pub mod bitxor {
122    use super::*;
123
124    pub fn expand<C: CubePrimitive>(
125        scope: &mut Scope,
126        lhs: ExpandElementTyped<C>,
127        rhs: ExpandElementTyped<C>,
128    ) -> ExpandElementTyped<C> {
129        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
130    }
131}
132
133pub mod shl {
134    use super::*;
135
136    pub fn expand<C: CubePrimitive>(
137        scope: &mut Scope,
138        lhs: ExpandElementTyped<C>,
139        rhs: ExpandElementTyped<C>,
140    ) -> ExpandElementTyped<C> {
141        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
142    }
143}
144
145pub mod shr {
146    use super::*;
147
148    pub fn expand<C: CubePrimitive>(
149        scope: &mut Scope,
150        lhs: ExpandElementTyped<C>,
151        rhs: ExpandElementTyped<C>,
152    ) -> ExpandElementTyped<C> {
153        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
154    }
155}
156
/// For binary functions without special syntax.
///
/// Generates a trait `$trait_name` with two items:
/// - `$method_name`, a placeholder calling `unexpanded!()`;
/// - `$func_name_expand`, which expands `lhs`/`rhs` into `$operator` via `binary_expand`.
///
/// The trait is implemented for every listed `$type`. Each `ExpandElementTyped<$type>` also gets
/// an inherent `$method_name_expand` method that forwards to the trait's expansion function.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => {
        #[doc = concat!(
            "Binary `", stringify!($method_name), "` operation, expanded through `binary_expand`."
        )]
        pub trait $trait_name: CubeType + Sized {
            /// Placeholder that calls `unexpanded!()`. Presumably it is only a marker for the
            /// cube macro and is not meant to be executed directly.
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            /// Expands `lhs` and `rhs` into the operation, typed as `Self`.
            fn $func_name_expand(
                scope: &mut Scope,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                let output = binary_expand(scope, lhs.into(), rhs.into(), $operator);
                output.into()
            }
        }

        $(
            impl $trait_name for $type {}

            impl ExpandElementTyped<$type> {
                /// Method-call form of the trait's expansion function, with `self` as `lhs`.
                pub fn $method_name_expand(self, scope: &mut Scope, rhs: ExpandElementTyped<$type>) -> ExpandElementTyped<$type> {
                    <$type as $trait_name>::$func_name_expand(scope, self, rhs)
                }
            }
        )*
    }
}
182
/// Like `impl_binary_func!`, but the output item's `vectorization` is replaced by
/// `$out_vectorization` instead of being taken unchanged from `lhs`.
///
/// The output item starts as a copy of `lhs.item`. Only its `vectorization` field is
/// overwritten before expanding through `binary_expand_fixed_output`.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $out_vectorization:expr, $($type:ty),*) => {
        #[doc = concat!(
            "Binary `", stringify!($method_name),
            "` operation whose output vectorization is fixed, expanded through `binary_expand_fixed_output`."
        )]
        pub trait $trait_name: CubeType + Sized {
            /// Placeholder that calls `unexpanded!()`. Presumably it is only a marker for the
            /// cube macro and is not meant to be executed directly.
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            /// Expands `lhs` and `rhs` into the operation with a fixed output vectorization.
            fn $func_name_expand(
                scope: &mut Scope,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                let lhs: ExpandElement = lhs.into();
                // Every field of `lhs.item` is kept except `vectorization`, which is overridden.
                let mut out_item = lhs.item;
                out_item.vectorization = $out_vectorization;
                let output = binary_expand_fixed_output(scope, lhs, rhs.into(), out_item, $operator);
                output.into()
            }
        }

        $(
            impl $trait_name for $type {}

            impl ExpandElementTyped<$type> {
                /// Method-call form of the trait's expansion function, with `self` as `lhs`.
                pub fn $method_name_expand(self, scope: &mut Scope, rhs: ExpandElementTyped<$type>) -> ExpandElementTyped<$type> {
                    <$type as $trait_name>::$func_name_expand(scope, self, rhs)
                }
            }
        )*
    }
}
213
214impl_binary_func!(
215    Powf,
216    powf,
217    __expand_powf,
218    __expand_powf_method,
219    Arithmetic::Powf,
220    f16,
221    bf16,
222    flex32,
223    tf32,
224    f32,
225    f64
226);
227impl_binary_func!(
228    Max,
229    max,
230    __expand_max,
231    __expand_max_method,
232    Arithmetic::Max,
233    f16,
234    bf16,
235    flex32,
236    tf32,
237    f32,
238    f64,
239    i8,
240    i16,
241    i32,
242    i64,
243    u8,
244    u16,
245    u32,
246    u64
247);
248impl_binary_func!(
249    Min,
250    min,
251    __expand_min,
252    __expand_min_method,
253    Arithmetic::Min,
254    f16,
255    bf16,
256    flex32,
257    tf32,
258    f32,
259    f64,
260    i8,
261    i16,
262    i32,
263    i64,
264    u8,
265    u16,
266    u32,
267    u64
268);
269impl_binary_func!(
270    Remainder,
271    rem,
272    __expand_rem,
273    __expand_rem_method,
274    Arithmetic::Remainder,
275    f16,
276    bf16,
277    flex32,
278    tf32,
279    f32,
280    f64,
281    i8,
282    i16,
283    i32,
284    i64,
285    u8,
286    u16,
287    u32,
288    u64
289);
290impl_binary_func!(
291    MulHi,
292    mul_hi,
293    __expand_mul_hi,
294    __expand_mul_hi_method,
295    Arithmetic::MulHi,
296    i32,
297    u32
298);
299impl_binary_func_fixed_output_vectorization!(
300    Dot,
301    dot,
302    __expand_dot,
303    __expand_dot_method,
304    Arithmetic::Dot,
305    None,
306    f16,
307    bf16,
308    flex32,
309    tf32,
310    f32,
311    f64,
312    i8,
313    i16,
314    i32,
315    i64,
316    u8,
317    u16,
318    u32,
319    u64
320);