1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3 flex32,
4 frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9 unexpanded,
10};
11use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
12use half::{bf16, f16};
13
14pub mod add {
15 use super::*;
16
17 pub fn expand<C: CubePrimitive>(
18 scope: &mut Scope,
19 lhs: ExpandElementTyped<C>,
20 rhs: ExpandElementTyped<C>,
21 ) -> ExpandElementTyped<C> {
22 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
23 }
24}
25
26pub mod sub {
27 use super::*;
28
29 pub fn expand<C: CubePrimitive>(
30 scope: &mut Scope,
31 lhs: ExpandElementTyped<C>,
32 rhs: ExpandElementTyped<C>,
33 ) -> ExpandElementTyped<C> {
34 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
35 }
36}
37
38pub mod mul {
39 use super::*;
40
41 pub fn expand<C: CubePrimitive>(
42 scope: &mut Scope,
43 lhs: ExpandElementTyped<C>,
44 rhs: ExpandElementTyped<C>,
45 ) -> ExpandElementTyped<C> {
46 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
47 }
48}
49
50pub mod div {
51 use super::*;
52
53 pub fn expand<C: CubePrimitive>(
54 scope: &mut Scope,
55 lhs: ExpandElementTyped<C>,
56 rhs: ExpandElementTyped<C>,
57 ) -> ExpandElementTyped<C> {
58 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
59 }
60}
61
62pub mod rem {
63 use super::*;
64
65 pub fn expand<C: CubePrimitive>(
66 scope: &mut Scope,
67 lhs: ExpandElementTyped<C>,
68 rhs: ExpandElementTyped<C>,
69 ) -> ExpandElementTyped<C> {
70 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
71 }
72}
73
74pub mod and {
75 use super::*;
76
77 pub fn expand<C: CubePrimitive>(
78 scope: &mut Scope,
79 lhs: ExpandElementTyped<C>,
80 rhs: ExpandElementTyped<C>,
81 ) -> ExpandElementTyped<bool> {
82 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
83 }
84}
85
86pub mod bitand {
87 use super::*;
88
89 pub fn expand<C: CubePrimitive>(
90 scope: &mut Scope,
91 lhs: ExpandElementTyped<C>,
92 rhs: ExpandElementTyped<C>,
93 ) -> ExpandElementTyped<C> {
94 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
95 }
96}
97
98pub mod bitor {
99 use super::*;
100
101 pub fn expand<C: CubePrimitive>(
102 scope: &mut Scope,
103 lhs: ExpandElementTyped<C>,
104 rhs: ExpandElementTyped<C>,
105 ) -> ExpandElementTyped<C> {
106 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
107 }
108}
109
110pub mod or {
111 use super::*;
112
113 pub fn expand<C: CubePrimitive>(
114 scope: &mut Scope,
115 lhs: ExpandElementTyped<C>,
116 rhs: ExpandElementTyped<C>,
117 ) -> ExpandElementTyped<bool> {
118 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
119 }
120}
121
122pub mod bitxor {
123 use super::*;
124
125 pub fn expand<C: CubePrimitive>(
126 scope: &mut Scope,
127 lhs: ExpandElementTyped<C>,
128 rhs: ExpandElementTyped<C>,
129 ) -> ExpandElementTyped<C> {
130 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
131 }
132}
133
134pub mod shl {
135 use super::*;
136
137 pub fn expand<C: CubePrimitive>(
138 scope: &mut Scope,
139 lhs: ExpandElementTyped<C>,
140 rhs: ExpandElementTyped<C>,
141 ) -> ExpandElementTyped<C> {
142 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
143 }
144}
145
146pub mod shr {
147 use super::*;
148
149 pub fn expand<C: CubePrimitive>(
150 scope: &mut Scope,
151 lhs: ExpandElementTyped<C>,
152 rhs: ExpandElementTyped<C>,
153 ) -> ExpandElementTyped<C> {
154 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
155 }
156}
157
/// Declares the binary-function trait `$trait_name` with method `$method_name`,
/// implements it for every listed `$type`, and adds the method-call expansion
/// helper `__expand_<method>_method` on `ExpandElementTyped<$type>`.
///
/// Both operands and the result share the type `Self`; the expansion goes through
/// `binary_expand` with `$operator`.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                // Host-side entry point; calling it outside of expansion hits
                // `unexpanded!()`.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Single source of truth for the lowering; the method-call form
                // below delegates here.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    binary_expand(scope, lhs.into(), rhs.into(), $operator).into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
185
/// Like `impl_binary_func`, but the result type is derived from the left operand
/// via `lhs.ty.line($out_vectorization)` and emitted with
/// `binary_expand_fixed_output`, instead of following the operands' own type.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                // Host-side entry point; calling it outside of expansion hits
                // `unexpanded!()`.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Single source of truth for the lowering; the method-call form
                // below delegates here.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    // NOTE(review): presumably `line(..)` keeps lhs's element type and
                    // only overrides the line size — confirm against `Type::line`.
                    let out_ty = lhs.ty.line($out_vectorization);
                    binary_expand_fixed_output(scope, lhs, rhs.into(), out_ty, $operator).into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
216
/// Like `impl_binary_func`, but the right operand has its own type: the trait is
/// generic over `Rhs`, and each listed `$type` implements it for `Rhs = $rhs_ty`.
/// The result keeps the left operand's type `Self`.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubeType + Sized>: CubeType + Sized {
                // Host-side entry point; calling it outside of expansion hits
                // `unexpanded!()`.
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                // Single source of truth for the lowering; the method-call form
                // below delegates here.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    binary_expand(scope, lhs.into(), rhs.into(), $operator).into()
                }
            }

            $(
                impl $trait_name<$rhs_ty> for $type {}

                impl ExpandElementTyped<$type> {
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$rhs_ty>,
                    ) -> ExpandElementTyped<$type> {
                        <$type as $trait_name<$rhs_ty>>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
243
244impl_binary_func!(
245 Powf,
246 powf,
247 Arithmetic::Powf,
248 f16,
249 bf16,
250 flex32,
251 tf32,
252 f32,
253 f64
254);
255
256impl_binary_func!(
257 Hypot,
258 hypot,
259 Arithmetic::Hypot,
260 f16,
261 bf16,
262 flex32,
263 tf32,
264 f32,
265 f64
266);
267
268impl_binary_func!(
269 Rhypot,
270 rhypot,
271 Arithmetic::Rhypot,
272 f16,
273 bf16,
274 flex32,
275 tf32,
276 f32,
277 f64
278);
279
280impl_binary_func!(
281 ArcTan2,
282 atan2,
283 Arithmetic::ArcTan2,
284 f16,
285 bf16,
286 flex32,
287 tf32,
288 f32,
289 f64
290);
291impl_binary_func!(
292 Max,
293 max,
294 Arithmetic::Max,
295 e2m1,
296 e4m3,
297 e5m2,
298 ue8m0,
299 f16,
300 bf16,
301 flex32,
302 tf32,
303 f32,
304 f64,
305 i8,
306 i16,
307 i32,
308 i64,
309 u8,
310 u16,
311 u32,
312 u64
313);
314impl_binary_func!(
315 Min,
316 min,
317 Arithmetic::Min,
318 e2m1,
319 e4m3,
320 e5m2,
321 ue8m0,
322 f16,
323 bf16,
324 flex32,
325 tf32,
326 f32,
327 f64,
328 i8,
329 i16,
330 i32,
331 i64,
332 u8,
333 u16,
334 u32,
335 u64
336);
337impl_binary_func!(
338 Remainder,
339 rem,
340 Arithmetic::Remainder,
341 e2m1,
342 e4m3,
343 e5m2,
344 ue8m0,
345 f16,
346 bf16,
347 flex32,
348 tf32,
349 f32,
350 f64,
351 i8,
352 i16,
353 i32,
354 i64,
355 u8,
356 u16,
357 u32,
358 u64
359);
360impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32);
361impl_binary_func!(
362 SaturatingAdd,
363 saturating_add,
364 Arithmetic::SaturatingAdd,
365 i8,
366 i16,
367 i32,
368 i64,
369 u8,
370 u16,
371 u32,
372 u64
373);
374impl_binary_func!(
375 SaturatingSub,
376 saturating_sub,
377 Arithmetic::SaturatingSub,
378 i8,
379 i16,
380 i32,
381 i64,
382 u8,
383 u16,
384 u32,
385 u64
386);
// `Dot::dot`, lowered to `Arithmetic::Dot`. The output type is
// `lhs.ty.line(0)` rather than the operands' own type; presumably line size 0
// yields a scalar result — confirm against `Type::line`. Unlike Max/Min, the
// `e2m1`/`e4m3`/`e5m2`/`ue8m0` formats are not included here.
impl_binary_func_fixed_output_vectorization!(
    Dot,
    dot,
    Arithmetic::Dot,
    0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
407
// `Powi::powi`, lowered to `Arithmetic::Powi`. The exponent (right operand) is
// always `i32`, while the base and the result keep the listed type, so the
// mixed-types macro is used instead of `impl_binary_func`.
impl_binary_func_mixed_types!(
    Powi,
    powi,
    i32,
    Arithmetic::Powi,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);