1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3    flex32,
4    frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9    unexpanded,
10};
11use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
12use half::{bf16, f16};
13
14pub mod add {
15    use super::*;
16
17    pub fn expand<C: CubePrimitive>(
18        scope: &mut Scope,
19        lhs: ExpandElementTyped<C>,
20        rhs: ExpandElementTyped<C>,
21    ) -> ExpandElementTyped<C> {
22        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
23    }
24}
25
26pub mod sub {
27    use super::*;
28
29    pub fn expand<C: CubePrimitive>(
30        scope: &mut Scope,
31        lhs: ExpandElementTyped<C>,
32        rhs: ExpandElementTyped<C>,
33    ) -> ExpandElementTyped<C> {
34        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
35    }
36}
37
38pub mod mul {
39    use super::*;
40
41    pub fn expand<C: CubePrimitive>(
42        scope: &mut Scope,
43        lhs: ExpandElementTyped<C>,
44        rhs: ExpandElementTyped<C>,
45    ) -> ExpandElementTyped<C> {
46        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
47    }
48}
49
50pub mod div {
51    use super::*;
52
53    pub fn expand<C: CubePrimitive>(
54        scope: &mut Scope,
55        lhs: ExpandElementTyped<C>,
56        rhs: ExpandElementTyped<C>,
57    ) -> ExpandElementTyped<C> {
58        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
59    }
60}
61
62pub mod rem {
63    use super::*;
64
65    pub fn expand<C: CubePrimitive>(
66        scope: &mut Scope,
67        lhs: ExpandElementTyped<C>,
68        rhs: ExpandElementTyped<C>,
69    ) -> ExpandElementTyped<C> {
70        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
71    }
72}
73
74pub mod and {
75    use super::*;
76
77    pub fn expand<C: CubePrimitive>(
78        scope: &mut Scope,
79        lhs: ExpandElementTyped<C>,
80        rhs: ExpandElementTyped<C>,
81    ) -> ExpandElementTyped<bool> {
82        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
83    }
84}
85
86pub mod bitand {
87    use super::*;
88
89    pub fn expand<C: CubePrimitive>(
90        scope: &mut Scope,
91        lhs: ExpandElementTyped<C>,
92        rhs: ExpandElementTyped<C>,
93    ) -> ExpandElementTyped<C> {
94        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
95    }
96}
97
98pub mod bitor {
99    use super::*;
100
101    pub fn expand<C: CubePrimitive>(
102        scope: &mut Scope,
103        lhs: ExpandElementTyped<C>,
104        rhs: ExpandElementTyped<C>,
105    ) -> ExpandElementTyped<C> {
106        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
107    }
108}
109
110pub mod or {
111    use super::*;
112
113    pub fn expand<C: CubePrimitive>(
114        scope: &mut Scope,
115        lhs: ExpandElementTyped<C>,
116        rhs: ExpandElementTyped<C>,
117    ) -> ExpandElementTyped<bool> {
118        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
119    }
120}
121
122pub mod bitxor {
123    use super::*;
124
125    pub fn expand<C: CubePrimitive>(
126        scope: &mut Scope,
127        lhs: ExpandElementTyped<C>,
128        rhs: ExpandElementTyped<C>,
129    ) -> ExpandElementTyped<C> {
130        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
131    }
132}
133
134pub mod shl {
135    use super::*;
136
137    pub fn expand<C: CubePrimitive>(
138        scope: &mut Scope,
139        lhs: ExpandElementTyped<C>,
140        rhs: ExpandElementTyped<C>,
141    ) -> ExpandElementTyped<C> {
142        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
143    }
144}
145
146pub mod shr {
147    use super::*;
148
149    pub fn expand<C: CubePrimitive>(
150        scope: &mut Scope,
151        lhs: ExpandElementTyped<C>,
152        rhs: ExpandElementTyped<C>,
153    ) -> ExpandElementTyped<C> {
154        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
155    }
156}
157
/// Declares a binary-function trait and implements it for a list of element types.
///
/// * `$trait_name` / `$method_name` — the generated trait and its method.
/// * `$operator` — the IR operator handed to `binary_expand` (e.g. `Arithmetic::Max`).
/// * `$type` — every type that receives an empty impl of the trait.
///
/// The trait's plain method is an `unexpanded!()` placeholder; its `__expand_<method>`
/// counterpart performs the expansion. Each listed type additionally gets an inherent
/// `__expand_<method>_method` on `ExpandElementTyped<$type>` that forwards to the
/// trait's `__expand_<method>` function.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                // Placeholder body; the `__expand_` variant below is the functional one.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    let rhs: ExpandElement = rhs.into();
                    let output = binary_expand(scope, lhs, rhs, $operator);
                    output.into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    // Method-call form; delegates to the trait so the expansion lives in one place.
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
185
/// Like `impl_binary_func`, but the output type is pinned rather than taken from the
/// operands.
///
/// The output type is the left operand's type with its line size replaced by
/// `$out_vectorization` (via `ty.line(..)`). It is passed to `binary_expand_fixed_output`
/// together with `$operator`. Each listed `$type` gets an empty trait impl plus an inherent
/// `__expand_<method>_method` on `ExpandElementTyped<$type>` that forwards to the trait's
/// `__expand_<method>` function.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                // Placeholder body; the `__expand_` variant below is the functional one.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    // Derive the output type from `lhs` before it is moved into the call.
                    let out_ty = lhs.ty.line($out_vectorization);
                    let output =
                        binary_expand_fixed_output(scope, lhs, rhs.into(), out_ty, $operator);
                    output.into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    // Method-call form; delegates to the trait so the expansion lives in one place.
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
216
/// Declares a binary-function trait whose right-hand operand type differs from `Self`.
///
/// * `$trait_name<Rhs>` / `$method_name` — the generated generic trait and its method.
/// * `$rhs_ty` — the concrete `Rhs` used for every impl this macro emits.
/// * `$operator` — the IR operator handed to `binary_expand`.
/// * `$type` — every type that receives an empty `impl $trait_name<$rhs_ty>`.
///
/// Each listed type also gets an inherent `__expand_<method>_method` on
/// `ExpandElementTyped<$type>` that forwards to the trait's `__expand_<method>` function.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubeType + Sized>: CubeType + Sized {
                // Placeholder body; the `__expand_` variant below is the functional one.
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    let rhs: ExpandElement = rhs.into();
                    let output = binary_expand(scope, lhs, rhs, $operator);
                    output.into()
                }
            }

            $(
                impl $trait_name<$rhs_ty> for $type {}

                impl ExpandElementTyped<$type> {
                    // Method-call form; delegates to the trait so the expansion lives in one place.
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$rhs_ty>,
                    ) -> ExpandElementTyped<$type> {
                        <$type as $trait_name<$rhs_ty>>::[<__expand_ $method_name>](
                            scope, self, rhs,
                        )
                    }
                }
            )*
        }
    };
}
243
// `Powf` trait / `powf` method, lowered to `Arithmetic::Powf`.
// Implemented for the floating-point element types only.
impl_binary_func!(
    Powf,
    powf,
    Arithmetic::Powf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Max` trait / `max` method, lowered to `Arithmetic::Max`.
// Implemented for the low-precision formats, floating-point types, and all integer widths.
impl_binary_func!(
    Max,
    max,
    Arithmetic::Max,
    // Low-precision formats from `cubecl_common`.
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    // Floating-point types.
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    // Signed integers.
    i8,
    i16,
    i32,
    i64,
    // Unsigned integers.
    u8,
    u16,
    u32,
    u64
);
// `Min` trait / `min` method, lowered to `Arithmetic::Min`.
// Same type coverage as `Max`.
impl_binary_func!(
    Min,
    min,
    Arithmetic::Min,
    // Low-precision formats from `cubecl_common`.
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    // Floating-point types.
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    // Signed integers.
    i8,
    i16,
    i32,
    i64,
    // Unsigned integers.
    u8,
    u16,
    u32,
    u64
);
// `Remainder` trait / `rem` method, lowered to `Arithmetic::Remainder`.
// This is a different IR operator from `Arithmetic::Modulo`, which the `rem` module's
// `expand` uses; presumably the two differ in sign convention — TODO confirm in the IR docs.
impl_binary_func!(
    Remainder,
    rem,
    Arithmetic::Remainder,
    // Low-precision formats from `cubecl_common`.
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    // Floating-point types.
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    // Signed integers.
    i8,
    i16,
    i32,
    i64,
    // Unsigned integers.
    u8,
    u16,
    u32,
    u64
);
// `MulHi` trait / `mul_hi` method, lowered to `Arithmetic::MulHi`; only `i32` and `u32`.
// Presumably yields the high half of the widened product — TODO confirm IR semantics.
impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32);
// `SaturatingAdd` trait / `saturating_add` method, lowered to `Arithmetic::SaturatingAdd`.
// Implemented for the integer types only.
impl_binary_func!(
    SaturatingAdd,
    saturating_add,
    Arithmetic::SaturatingAdd,
    // Signed integers.
    i8,
    i16,
    i32,
    i64,
    // Unsigned integers.
    u8,
    u16,
    u32,
    u64
);
// `SaturatingSub` trait / `saturating_sub` method, lowered to `Arithmetic::SaturatingSub`.
// Implemented for the integer types only.
impl_binary_func!(
    SaturatingSub,
    saturating_sub,
    Arithmetic::SaturatingSub,
    // Signed integers.
    i8,
    i16,
    i32,
    i64,
    // Unsigned integers.
    u8,
    u16,
    u32,
    u64
);
// `Dot` trait / `dot` method, lowered to `Arithmetic::Dot` with a fixed output line size.
// Implemented for the floating-point and integer types, but not the low-precision formats.
impl_binary_func_fixed_output_vectorization!(
    Dot,
    dot,
    Arithmetic::Dot,
    // Output line size passed to `ty.line(..)`; presumably 0 means a scalar result —
    // TODO confirm against `Type::line`.
    0,
    // Floating-point types.
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    // Signed integers.
    i8,
    i16,
    i32,
    i64,
    // Unsigned integers.
    u8,
    u16,
    u32,
    u64
);
371
// `Powi<Rhs>` trait / `powi` method, lowered to `Arithmetic::Powi`.
// The right-hand operand is always `i32`, whatever the left-hand element type is.
impl_binary_func_mixed_types!(
    Powi,
    powi,
    // Fixed `Rhs` type for every impl below.
    i32,
    Arithmetic::Powi,
    // Floating-point types.
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    // Signed integers.
    i8,
    i16,
    i32,
    i64,
    // Unsigned integers.
    u8,
    u16,
    u32,
    u64
);