1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3 flex32,
4 frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9 unexpanded,
10};
11use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
12use half::{bf16, f16};
13
14pub mod add {
15 use super::*;
16
17 pub fn expand<C: CubePrimitive>(
18 scope: &mut Scope,
19 lhs: ExpandElementTyped<C>,
20 rhs: ExpandElementTyped<C>,
21 ) -> ExpandElementTyped<C> {
22 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
23 }
24}
25
26pub mod sub {
27 use super::*;
28
29 pub fn expand<C: CubePrimitive>(
30 scope: &mut Scope,
31 lhs: ExpandElementTyped<C>,
32 rhs: ExpandElementTyped<C>,
33 ) -> ExpandElementTyped<C> {
34 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
35 }
36}
37
38pub mod mul {
39 use super::*;
40
41 pub fn expand<C: CubePrimitive>(
42 scope: &mut Scope,
43 lhs: ExpandElementTyped<C>,
44 rhs: ExpandElementTyped<C>,
45 ) -> ExpandElementTyped<C> {
46 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
47 }
48}
49
50pub mod div {
51 use super::*;
52
53 pub fn expand<C: CubePrimitive>(
54 scope: &mut Scope,
55 lhs: ExpandElementTyped<C>,
56 rhs: ExpandElementTyped<C>,
57 ) -> ExpandElementTyped<C> {
58 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
59 }
60}
61
62pub mod rem {
63 use super::*;
64
65 pub fn expand<C: CubePrimitive>(
66 scope: &mut Scope,
67 lhs: ExpandElementTyped<C>,
68 rhs: ExpandElementTyped<C>,
69 ) -> ExpandElementTyped<C> {
70 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
71 }
72}
73
74pub mod and {
75 use super::*;
76
77 pub fn expand<C: CubePrimitive>(
78 scope: &mut Scope,
79 lhs: ExpandElementTyped<C>,
80 rhs: ExpandElementTyped<C>,
81 ) -> ExpandElementTyped<bool> {
82 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
83 }
84}
85
86pub mod bitand {
87 use super::*;
88
89 pub fn expand<C: CubePrimitive>(
90 scope: &mut Scope,
91 lhs: ExpandElementTyped<C>,
92 rhs: ExpandElementTyped<C>,
93 ) -> ExpandElementTyped<C> {
94 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
95 }
96}
97
98pub mod bitor {
99 use super::*;
100
101 pub fn expand<C: CubePrimitive>(
102 scope: &mut Scope,
103 lhs: ExpandElementTyped<C>,
104 rhs: ExpandElementTyped<C>,
105 ) -> ExpandElementTyped<C> {
106 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
107 }
108}
109
110pub mod or {
111 use super::*;
112
113 pub fn expand<C: CubePrimitive>(
114 scope: &mut Scope,
115 lhs: ExpandElementTyped<C>,
116 rhs: ExpandElementTyped<C>,
117 ) -> ExpandElementTyped<bool> {
118 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
119 }
120}
121
122pub mod bitxor {
123 use super::*;
124
125 pub fn expand<C: CubePrimitive>(
126 scope: &mut Scope,
127 lhs: ExpandElementTyped<C>,
128 rhs: ExpandElementTyped<C>,
129 ) -> ExpandElementTyped<C> {
130 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
131 }
132}
133
134pub mod shl {
135 use super::*;
136
137 pub fn expand<C: CubePrimitive>(
138 scope: &mut Scope,
139 lhs: ExpandElementTyped<C>,
140 rhs: ExpandElementTyped<C>,
141 ) -> ExpandElementTyped<C> {
142 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
143 }
144}
145
146pub mod shr {
147 use super::*;
148
149 pub fn expand<C: CubePrimitive>(
150 scope: &mut Scope,
151 lhs: ExpandElementTyped<C>,
152 rhs: ExpandElementTyped<C>,
153 ) -> ExpandElementTyped<C> {
154 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
155 }
156}
157
/// Declares a binary trait whose expansion lowers to the given IR operator, implements it
/// (with its default methods) for every listed type, and adds an inherent
/// `__expand_<method>_method` on `ExpandElementTyped<T>` for each of those types.
///
/// Both operands and the result share the same type.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                /// User-facing entry point; its body is `unexpanded!()`, so it is not meant
                /// to be executed directly.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                /// Emits the binary operator on `lhs` and `rhs` into `scope`.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    binary_expand(scope, lhs.into(), rhs.into(), $operator).into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    /// Method-call form of the expansion; `self` is the left operand.
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        // The trait impl above is empty, so this runs the default expansion.
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
185
/// Like `impl_binary_func`, but the output type is derived from the left operand's type
/// with its line size replaced by the given value (forwarded to `Type::line`), and the
/// operation is emitted through `binary_expand_fixed_output`.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubeType + Sized {
                /// User-facing entry point; its body is `unexpanded!()`, so it is not meant
                /// to be executed directly.
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                /// Emits the binary operator into `scope` with a fixed output line size.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    let lhs: ExpandElement = lhs.into();
                    // Output type: lhs's type, re-lined to the requested size.
                    let out_ty = lhs.ty.line($out_vectorization);
                    binary_expand_fixed_output(scope, lhs, rhs.into(), out_ty, $operator).into()
                }
            }

            $(
                impl $trait_name for $type {}

                impl ExpandElementTyped<$type> {
                    /// Method-call form of the expansion; `self` is the left operand.
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$type>,
                    ) -> ExpandElementTyped<$type> {
                        // The trait impl above is empty, so this runs the default expansion.
                        <$type as $trait_name>::[<__expand_ $method_name>](scope, self, rhs)
                    }
                }
            )*
        }
    };
}
216
/// Like `impl_binary_func`, but the right operand has its own type: the trait is generic
/// over `Rhs`, and each listed type implements it for the single right-hand type given.
/// The result keeps the left operand's type.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubeType + Sized>: CubeType + Sized {
                /// User-facing entry point; its body is `unexpanded!()`, so it is not meant
                /// to be executed directly.
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                /// Emits the binary operator on `lhs` and `rhs` into `scope`.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: ExpandElementTyped<Self>,
                    rhs: ExpandElementTyped<Rhs>,
                ) -> ExpandElementTyped<Self> {
                    binary_expand(scope, lhs.into(), rhs.into(), $operator).into()
                }
            }

            $(
                impl $trait_name<$rhs_ty> for $type {}

                impl ExpandElementTyped<$type> {
                    /// Method-call form of the expansion; `self` is the left operand.
                    pub fn [<__expand_ $method_name _method>](
                        self,
                        scope: &mut Scope,
                        rhs: ExpandElementTyped<$rhs_ty>,
                    ) -> ExpandElementTyped<$type> {
                        // The trait impl above is empty, so this runs the default expansion.
                        <$type as $trait_name<$rhs_ty>>::[<__expand_ $method_name>](
                            scope, self, rhs,
                        )
                    }
                }
            )*
        }
    };
}
243
244impl_binary_func!(
245 Powf,
246 powf,
247 Arithmetic::Powf,
248 f16,
249 bf16,
250 flex32,
251 tf32,
252 f32,
253 f64
254);
255impl_binary_func!(
256 Max,
257 max,
258 Arithmetic::Max,
259 e2m1,
260 e4m3,
261 e5m2,
262 ue8m0,
263 f16,
264 bf16,
265 flex32,
266 tf32,
267 f32,
268 f64,
269 i8,
270 i16,
271 i32,
272 i64,
273 u8,
274 u16,
275 u32,
276 u64
277);
278impl_binary_func!(
279 Min,
280 min,
281 Arithmetic::Min,
282 e2m1,
283 e4m3,
284 e5m2,
285 ue8m0,
286 f16,
287 bf16,
288 flex32,
289 tf32,
290 f32,
291 f64,
292 i8,
293 i16,
294 i32,
295 i64,
296 u8,
297 u16,
298 u32,
299 u64
300);
301impl_binary_func!(
302 Remainder,
303 rem,
304 Arithmetic::Remainder,
305 e2m1,
306 e4m3,
307 e5m2,
308 ue8m0,
309 f16,
310 bf16,
311 flex32,
312 tf32,
313 f32,
314 f64,
315 i8,
316 i16,
317 i32,
318 i64,
319 u8,
320 u16,
321 u32,
322 u64
323);
324impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32);
325impl_binary_func!(
326 SaturatingAdd,
327 saturating_add,
328 Arithmetic::SaturatingAdd,
329 i8,
330 i16,
331 i32,
332 i64,
333 u8,
334 u16,
335 u32,
336 u64
337);
338impl_binary_func!(
339 SaturatingSub,
340 saturating_sub,
341 Arithmetic::SaturatingSub,
342 i8,
343 i16,
344 i32,
345 i64,
346 u8,
347 u16,
348 u32,
349 u64
350);
351impl_binary_func_fixed_output_vectorization!(
352 Dot,
353 dot,
354 Arithmetic::Dot,
355 0,
356 f16,
357 bf16,
358 flex32,
359 tf32,
360 f32,
361 f64,
362 i8,
363 i16,
364 i32,
365 i64,
366 u8,
367 u16,
368 u32,
369 u64
370);
371
372impl_binary_func_mixed_types!(
373 Powi,
374 powi,
375 i32,
376 Arithmetic::Powi,
377 f16,
378 bf16,
379 flex32,
380 tf32,
381 f32,
382 f64,
383 i8,
384 i16,
385 i32,
386 i64,
387 u8,
388 u16,
389 u32,
390 u64
391);