// cubecl_core/frontend/operation/binary.rs

1use crate::ir::Operator;
2use crate::{
3    flex32,
4    frontend::{CubeContext, CubePrimitive, ExpandElement, ExpandElementTyped},
5};
6use crate::{
7    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
8    unexpanded,
9};
10use crate::{frontend::CubeType, tf32};
11use half::{bf16, f16};
12
13pub mod add {
14    use super::*;
15
16    pub fn expand<C: CubePrimitive>(
17        context: &mut CubeContext,
18        lhs: impl Into<ExpandElementTyped<C>>,
19        rhs: impl Into<ExpandElementTyped<C>>,
20    ) -> ExpandElementTyped<C> {
21        binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Add).into()
22    }
23}
24
25pub mod sub {
26    use super::*;
27
28    pub fn expand<C: CubePrimitive>(
29        context: &mut CubeContext,
30        lhs: impl Into<ExpandElementTyped<C>>,
31        rhs: impl Into<ExpandElementTyped<C>>,
32    ) -> ExpandElementTyped<C> {
33        binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Sub).into()
34    }
35}
36
37pub mod mul {
38    use super::*;
39
40    pub fn expand<C: CubePrimitive>(
41        context: &mut CubeContext,
42        lhs: impl Into<ExpandElementTyped<C>>,
43        rhs: impl Into<ExpandElementTyped<C>>,
44    ) -> ExpandElementTyped<C> {
45        binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Mul).into()
46    }
47}
48
49pub mod div {
50    use super::*;
51
52    pub fn expand<C: CubePrimitive>(
53        context: &mut CubeContext,
54        lhs: impl Into<ExpandElementTyped<C>>,
55        rhs: impl Into<ExpandElementTyped<C>>,
56    ) -> ExpandElementTyped<C> {
57        binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Div).into()
58    }
59}
60
61pub mod rem {
62    use super::*;
63
64    pub fn expand<C: CubePrimitive>(
65        context: &mut CubeContext,
66        lhs: impl Into<ExpandElementTyped<C>>,
67        rhs: impl Into<ExpandElementTyped<C>>,
68    ) -> ExpandElementTyped<C> {
69        binary_expand(
70            context,
71            lhs.into().into(),
72            rhs.into().into(),
73            Operator::Modulo,
74        )
75        .into()
76    }
77}
78
79pub mod and {
80    use super::*;
81
82    pub fn expand<C: CubePrimitive>(
83        context: &mut CubeContext,
84        lhs: impl Into<ExpandElementTyped<C>>,
85        rhs: impl Into<ExpandElementTyped<C>>,
86    ) -> ExpandElementTyped<bool> {
87        binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::And).into()
88    }
89}
90
91pub mod bitand {
92    use super::*;
93
94    pub fn expand<C: CubePrimitive>(
95        context: &mut CubeContext,
96        lhs: impl Into<ExpandElementTyped<C>>,
97        rhs: impl Into<ExpandElementTyped<C>>,
98    ) -> ExpandElementTyped<C> {
99        binary_expand(
100            context,
101            lhs.into().into(),
102            rhs.into().into(),
103            Operator::BitwiseAnd,
104        )
105        .into()
106    }
107}
108
109pub mod bitor {
110    use super::*;
111
112    pub fn expand<C: CubePrimitive>(
113        context: &mut CubeContext,
114        lhs: impl Into<ExpandElementTyped<C>>,
115        rhs: impl Into<ExpandElementTyped<C>>,
116    ) -> ExpandElementTyped<C> {
117        binary_expand(
118            context,
119            lhs.into().into(),
120            rhs.into().into(),
121            Operator::BitwiseOr,
122        )
123        .into()
124    }
125}
126
127pub mod or {
128    use super::*;
129
130    pub fn expand<C: CubePrimitive>(
131        context: &mut CubeContext,
132        lhs: impl Into<ExpandElementTyped<C>>,
133        rhs: impl Into<ExpandElementTyped<C>>,
134    ) -> ExpandElementTyped<bool> {
135        binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Or).into()
136    }
137}
138
139pub mod bitxor {
140    use super::*;
141
142    pub fn expand<C: CubePrimitive>(
143        context: &mut CubeContext,
144        lhs: impl Into<ExpandElementTyped<C>>,
145        rhs: impl Into<ExpandElementTyped<C>>,
146    ) -> ExpandElementTyped<C> {
147        binary_expand(
148            context,
149            lhs.into().into(),
150            rhs.into().into(),
151            Operator::BitwiseXor,
152        )
153        .into()
154    }
155}
156
157pub mod shl {
158    use super::*;
159
160    pub fn expand<C: CubePrimitive>(
161        context: &mut CubeContext,
162        lhs: impl Into<ExpandElementTyped<C>>,
163        rhs: impl Into<ExpandElementTyped<C>>,
164    ) -> ExpandElementTyped<C> {
165        binary_expand(
166            context,
167            lhs.into().into(),
168            rhs.into().into(),
169            Operator::ShiftLeft,
170        )
171        .into()
172    }
173}
174
175pub mod shr {
176    use super::*;
177
178    pub fn expand<C: CubePrimitive>(
179        context: &mut CubeContext,
180        lhs: impl Into<ExpandElementTyped<C>>,
181        rhs: impl Into<ExpandElementTyped<C>>,
182    ) -> ExpandElementTyped<C> {
183        binary_expand(
184            context,
185            lhs.into().into(),
186            rhs.into().into(),
187            Operator::ShiftRight,
188        )
189        .into()
190    }
191}
192
/// For binary functions without special syntax.
///
/// Generates a trait `$trait_name` with:
/// - `$method_name(self, rhs)`: a placeholder whose body is `unexpanded!()`;
/// - `$func_name_expand(context, lhs, rhs)`: lowers the call through
///   `binary_expand` with `$operator`, keeping the operand type.
///
/// The trait is implemented (with its default methods) for every listed
/// `$type`. Each `ExpandElementTyped<$type>` also gets an inherent
/// `$method_name_expand(self, context, rhs)` method. That method forwards to the
/// trait's `$func_name_expand`, so both entry points share a single lowering and
/// cannot drift apart. A trailing comma after the type list is accepted.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $($type:ty),* $(,)?) => {
        pub trait $trait_name: CubeType + Sized {
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            fn $func_name_expand(
                context: &mut CubeContext,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                binary_expand(context, lhs.into(), rhs.into(), $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
        $(impl ExpandElementTyped<$type> {
            pub fn $method_name_expand(self, context: &mut CubeContext, rhs: ExpandElementTyped<$type>) -> ExpandElementTyped<$type> {
                // Single source of truth: reuse the trait's lowering.
                <$type as $trait_name>::$func_name_expand(context, self, rhs)
            }
        })*
    }
}
218
/// Like `impl_binary_func!`, but the output item's vectorization is set explicitly.
///
/// The generated `$func_name_expand` takes `lhs.item`, sets its `vectorization`
/// field to `$out_vectorization`, and lowers the call through
/// `binary_expand_fixed_output` with that item and `$operator`.
///
/// The trait is implemented (with its default methods) for every listed
/// `$type`. Each `ExpandElementTyped<$type>` also gets an inherent
/// `$method_name_expand(self, context, rhs)` method. That method forwards to the
/// trait's `$func_name_expand`, so the item/vectorization override lives in one
/// place. A trailing comma after the type list is accepted.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $out_vectorization: expr, $($type:ty),* $(,)?) => {
        pub trait $trait_name: CubeType + Sized {
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            fn $func_name_expand(
                context: &mut CubeContext,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                let lhs: ExpandElement = lhs.into();
                // The output item starts from lhs's item; only its vectorization is overridden.
                let mut item = lhs.item;
                item.vectorization = $out_vectorization;
                binary_expand_fixed_output(context, lhs, rhs.into(), item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
        $(impl ExpandElementTyped<$type> {
            pub fn $method_name_expand(self, context: &mut CubeContext, rhs: ExpandElementTyped<$type>) -> ExpandElementTyped<$type> {
                // Single source of truth: reuse the trait's lowering.
                <$type as $trait_name>::$func_name_expand(context, self, rhs)
            }
        })*
    }
}
249
250impl_binary_func!(
251    Powf,
252    powf,
253    __expand_powf,
254    __expand_powf_method,
255    Operator::Powf,
256    f16,
257    bf16,
258    flex32,
259    tf32,
260    f32,
261    f64
262);
263impl_binary_func!(
264    Max,
265    max,
266    __expand_max,
267    __expand_max_method,
268    Operator::Max,
269    f16,
270    bf16,
271    flex32,
272    tf32,
273    f32,
274    f64,
275    i8,
276    i16,
277    i32,
278    i64,
279    u8,
280    u16,
281    u32,
282    u64
283);
284impl_binary_func!(
285    Min,
286    min,
287    __expand_min,
288    __expand_min_method,
289    Operator::Min,
290    f16,
291    bf16,
292    flex32,
293    tf32,
294    f32,
295    f64,
296    i8,
297    i16,
298    i32,
299    i64,
300    u8,
301    u16,
302    u32,
303    u64
304);
305impl_binary_func!(
306    Remainder,
307    rem,
308    __expand_rem,
309    __expand_rem_method,
310    Operator::Remainder,
311    f16,
312    bf16,
313    flex32,
314    tf32,
315    f32,
316    f64,
317    i8,
318    i16,
319    i32,
320    i64,
321    u8,
322    u16,
323    u32,
324    u64
325);
326impl_binary_func_fixed_output_vectorization!(
327    Dot,
328    dot,
329    __expand_dot,
330    __expand_dot_method,
331    Operator::Dot,
332    None,
333    f16,
334    bf16,
335    flex32,
336    tf32,
337    f32,
338    f64,
339    i8,
340    i16,
341    i32,
342    i64,
343    u8,
344    u16,
345    u32,
346    u64
347);