1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3 flex32,
4 frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9 unexpanded,
10};
11use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
12use half::{bf16, f16};
13
14pub mod add {
15 use super::*;
16
17 pub fn expand<C: CubePrimitive>(
18 scope: &mut Scope,
19 lhs: ExpandElementTyped<C>,
20 rhs: ExpandElementTyped<C>,
21 ) -> ExpandElementTyped<C> {
22 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
23 }
24}
25
26pub mod sub {
27 use super::*;
28
29 pub fn expand<C: CubePrimitive>(
30 scope: &mut Scope,
31 lhs: ExpandElementTyped<C>,
32 rhs: ExpandElementTyped<C>,
33 ) -> ExpandElementTyped<C> {
34 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
35 }
36}
37
38pub mod mul {
39 use super::*;
40
41 pub fn expand<C: CubePrimitive>(
42 scope: &mut Scope,
43 lhs: ExpandElementTyped<C>,
44 rhs: ExpandElementTyped<C>,
45 ) -> ExpandElementTyped<C> {
46 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
47 }
48}
49
50pub mod div {
51 use super::*;
52
53 pub fn expand<C: CubePrimitive>(
54 scope: &mut Scope,
55 lhs: ExpandElementTyped<C>,
56 rhs: ExpandElementTyped<C>,
57 ) -> ExpandElementTyped<C> {
58 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
59 }
60}
61
62pub mod rem {
63 use super::*;
64
65 pub fn expand<C: CubePrimitive>(
66 scope: &mut Scope,
67 lhs: ExpandElementTyped<C>,
68 rhs: ExpandElementTyped<C>,
69 ) -> ExpandElementTyped<C> {
70 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
71 }
72}
73
74pub mod and {
75 use super::*;
76
77 pub fn expand<C: CubePrimitive>(
78 scope: &mut Scope,
79 lhs: ExpandElementTyped<C>,
80 rhs: ExpandElementTyped<C>,
81 ) -> ExpandElementTyped<bool> {
82 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
83 }
84}
85
86pub mod bitand {
87 use super::*;
88
89 pub fn expand<C: CubePrimitive>(
90 scope: &mut Scope,
91 lhs: ExpandElementTyped<C>,
92 rhs: ExpandElementTyped<C>,
93 ) -> ExpandElementTyped<C> {
94 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
95 }
96}
97
98pub mod bitor {
99 use super::*;
100
101 pub fn expand<C: CubePrimitive>(
102 scope: &mut Scope,
103 lhs: ExpandElementTyped<C>,
104 rhs: ExpandElementTyped<C>,
105 ) -> ExpandElementTyped<C> {
106 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
107 }
108}
109
110pub mod or {
111 use super::*;
112
113 pub fn expand<C: CubePrimitive>(
114 scope: &mut Scope,
115 lhs: ExpandElementTyped<C>,
116 rhs: ExpandElementTyped<C>,
117 ) -> ExpandElementTyped<bool> {
118 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
119 }
120}
121
122pub mod bitxor {
123 use super::*;
124
125 pub fn expand<C: CubePrimitive>(
126 scope: &mut Scope,
127 lhs: ExpandElementTyped<C>,
128 rhs: ExpandElementTyped<C>,
129 ) -> ExpandElementTyped<C> {
130 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
131 }
132}
133
134pub mod shl {
135 use super::*;
136
137 pub fn expand<C: CubePrimitive>(
138 scope: &mut Scope,
139 lhs: ExpandElementTyped<C>,
140 rhs: ExpandElementTyped<C>,
141 ) -> ExpandElementTyped<C> {
142 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
143 }
144}
145
146pub mod shr {
147 use super::*;
148
149 pub fn expand<C: CubePrimitive>(
150 scope: &mut Scope,
151 lhs: ExpandElementTyped<C>,
152 rhs: ExpandElementTyped<C>,
153 ) -> ExpandElementTyped<C> {
154 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
155 }
156}
157
/// Generates a same-type binary operation: a public trait, its impls, and the
/// matching inherent method on `ExpandElementTyped`.
///
/// Arguments, in order:
/// - `$trait_name`: identifier of the trait to declare.
/// - `$method_name`: identifier of the trait method; also spliced (via
///   `paste`) into `__expand_<method>` and `__expand_<method>_method`.
/// - `$operator`: expression passed as the last argument of `binary_expand`.
/// - `$type, ...`: comma-separated types that receive an empty trait impl and
///   an inherent `__expand_<method>_method` on `ExpandElementTyped<$type>`.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                // Default body is only the `unexpanded!()` placeholder.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Associated expansion function: both operands and the result
                // are typed with `Self`.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let expanded = binary_expand(scope, lhs.into(), rhs.into(), $operator);
                    expanded.into()
                }
            }

            // The trait's default methods are used as-is for every listed type.
            $(impl $trait_name for $type {})*

            // Method-call flavour of the expansion, with `self` as the left operand.
            $(impl ExpandElementTyped<$type> {
                pub fn [<__expand_ $method_name _method>](
                    self,
                    scope: &mut Scope,
                    rhs: ExpandElementTyped<$type>,
                ) -> ExpandElementTyped<$type> {
                    let expanded = binary_expand(scope, self.into(), rhs.into(), $operator);
                    expanded.into()
                }
            })*
        }
    }
}
185
/// Generates a same-type binary operation whose output type is computed from
/// the left operand instead of being left to `binary_expand`.
///
/// Arguments, in order:
/// - `$trait_name`: identifier of the trait to declare.
/// - `$method_name`: identifier of the trait method; also spliced (via
///   `paste`) into `__expand_<method>` and `__expand_<method>_method`.
/// - `$operator`: expression passed as the last argument of
///   `binary_expand_fixed_output`.
/// - `$out_vectorization`: expression passed to `.line(..)` on the left
///   operand's `ty` to build the output type.
/// - `$type, ...`: comma-separated types that receive an empty trait impl and
///   an inherent `__expand_<method>_method` on `ExpandElementTyped<$type>`.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                // Default body is only the `unexpanded!()` placeholder.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    // The output type must be read from the left operand
                    // before that operand is moved into the call below.
                    let left: ExpandElement = lhs.into();
                    let out_ty = left.ty.line($out_vectorization);
                    let expanded =
                        binary_expand_fixed_output(scope, left, rhs.into(), out_ty, $operator);
                    expanded.into()
                }
            }

            // The trait's default methods are used as-is for every listed type.
            $(impl $trait_name for $type {})*

            // Method-call flavour of the expansion, with `self` as the left operand.
            $(impl ExpandElementTyped<$type> {
                pub fn [<__expand_ $method_name _method>](
                    self,
                    scope: &mut Scope,
                    rhs: ExpandElementTyped<$type>,
                ) -> ExpandElementTyped<$type> {
                    let left: ExpandElement = self.into();
                    let out_ty = left.ty.line($out_vectorization);
                    let expanded =
                        binary_expand_fixed_output(scope, left, rhs.into(), out_ty, $operator);
                    expanded.into()
                }
            })*
        }
    }
}
216
/// Generates a binary operation whose right operand has a different, fixed
/// type from the left operand and the result.
///
/// Arguments, in order:
/// - `$trait_name`: identifier of the trait to declare; the trait is generic
///   over `Rhs: CubeType + Sized`.
/// - `$method_name`: identifier of the trait method; also spliced (via
///   `paste`) into `__expand_<method>` and `__expand_<method>_method`.
/// - `$rhs_ty`: identifier of the right-operand type used for every impl.
/// - `$operator`: expression passed as the last argument of `binary_expand`.
/// - `$type, ...`: comma-separated left-operand types that receive
///   `impl $trait_name<$rhs_ty>` and an inherent `__expand_<method>_method`
///   on `ExpandElementTyped<$type>`.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubeType + Sized>: CubeType + Sized {
                // Default body is only the `unexpanded!()` placeholder.
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                // The result keeps the left operand's type (`Self`).
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    let expanded = binary_expand(scope, lhs.into(), rhs.into(), $operator);
                    expanded.into()
                }
            }

            // The trait's default methods are used as-is for every listed type.
            $(impl $trait_name<$rhs_ty> for $type {})*

            // Method-call flavour of the expansion, with `self` as the left operand.
            $(impl ExpandElementTyped<$type> {
                pub fn [<__expand_ $method_name _method>](
                    self,
                    scope: &mut Scope,
                    rhs: ExpandElementTyped<$rhs_ty>,
                ) -> ExpandElementTyped<$type> {
                    let expanded = binary_expand(scope, self.into(), rhs.into(), $operator);
                    expanded.into()
                }
            })*
        }
    }
}
243
// Declares the `Powf` trait with method `powf`, expanded through
// `Arithmetic::Powf`, and implements it for the types listed below. Each
// listed type also gets `__expand_powf_method` on its `ExpandElementTyped`.
impl_binary_func!(
    Powf,
    powf,
    Arithmetic::Powf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `ArcTan2` trait with method `atan2`, expanded through
// `Arithmetic::ArcTan2`, and implements it for the types listed below. Each
// listed type also gets `__expand_atan2_method` on its `ExpandElementTyped`.
impl_binary_func!(
    ArcTan2,
    atan2,
    Arithmetic::ArcTan2,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Max` trait with method `max`, expanded through
// `Arithmetic::Max`, and implements it for the types listed below (the
// `cubecl_common` formats, the float types, and signed/unsigned integers).
// Each listed type also gets `__expand_max_method` on its `ExpandElementTyped`.
impl_binary_func!(
    Max,
    max,
    Arithmetic::Max,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Declares the `Min` trait with method `min`, expanded through
// `Arithmetic::Min`, and implements it for the types listed below (the
// `cubecl_common` formats, the float types, and signed/unsigned integers).
// Each listed type also gets `__expand_min_method` on its `ExpandElementTyped`.
impl_binary_func!(
    Min,
    min,
    Arithmetic::Min,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Declares the `Remainder` trait with method `rem`, expanded through
// `Arithmetic::Remainder`, and implements it for the types listed below.
// Each listed type also gets `__expand_rem_method` on its `ExpandElementTyped`.
//
// NOTE(review): `rem::expand` earlier in this file uses `Arithmetic::Modulo`
// instead of `Arithmetic::Remainder`. Confirm the two are meant to differ.
impl_binary_func!(
    Remainder,
    rem,
    Arithmetic::Remainder,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Declares the `MulHi` trait with method `mul_hi`, expanded through
// `Arithmetic::MulHi`; implemented only for `i32` and `u32`.
impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32);
// Declares the `SaturatingAdd` trait with method `saturating_add`, expanded
// through `Arithmetic::SaturatingAdd`, and implements it for the integer
// types listed below. Each listed type also gets
// `__expand_saturating_add_method` on its `ExpandElementTyped`.
impl_binary_func!(
    SaturatingAdd,
    saturating_add,
    Arithmetic::SaturatingAdd,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Declares the `SaturatingSub` trait with method `saturating_sub`, expanded
// through `Arithmetic::SaturatingSub`, and implements it for the integer
// types listed below. Each listed type also gets
// `__expand_saturating_sub_method` on its `ExpandElementTyped`.
impl_binary_func!(
    SaturatingSub,
    saturating_sub,
    Arithmetic::SaturatingSub,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Declares the `Dot` trait with method `dot`, expanded through
// `Arithmetic::Dot`, and implements it for the types listed below.
// The fourth argument (`0`) is what the macro passes to `.line(..)` on the
// left operand's type to build the output type.
// NOTE(review): presumably `0` selects a non-vectorized (scalar) result;
// confirm against the definition of `line`.
impl_binary_func_fixed_output_vectorization!(
    Dot,
    dot,
    Arithmetic::Dot,
    0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
382
// Declares the `Powi<Rhs>` trait with method `powi`, expanded through
// `Arithmetic::Powi`. The right operand type is fixed to `i32` (third
// argument): every type listed after the operator gets `impl Powi<i32>` and
// an `__expand_powi_method` taking an `ExpandElementTyped<i32>`.
impl_binary_func_mixed_types!(
    Powi,
    powi,
    i32,
    Arithmetic::Powi,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);