1use crate::ir::Operator;
2use crate::{
3 flex32,
4 frontend::{CubeContext, CubePrimitive, ExpandElement, ExpandElementTyped},
5};
6use crate::{
7 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
8 unexpanded,
9};
10use crate::{frontend::CubeType, tf32};
11use half::{bf16, f16};
12
13pub mod add {
14 use super::*;
15
16 pub fn expand<C: CubePrimitive>(
17 context: &mut CubeContext,
18 lhs: impl Into<ExpandElementTyped<C>>,
19 rhs: impl Into<ExpandElementTyped<C>>,
20 ) -> ExpandElementTyped<C> {
21 binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Add).into()
22 }
23}
24
25pub mod sub {
26 use super::*;
27
28 pub fn expand<C: CubePrimitive>(
29 context: &mut CubeContext,
30 lhs: impl Into<ExpandElementTyped<C>>,
31 rhs: impl Into<ExpandElementTyped<C>>,
32 ) -> ExpandElementTyped<C> {
33 binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Sub).into()
34 }
35}
36
37pub mod mul {
38 use super::*;
39
40 pub fn expand<C: CubePrimitive>(
41 context: &mut CubeContext,
42 lhs: impl Into<ExpandElementTyped<C>>,
43 rhs: impl Into<ExpandElementTyped<C>>,
44 ) -> ExpandElementTyped<C> {
45 binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Mul).into()
46 }
47}
48
49pub mod div {
50 use super::*;
51
52 pub fn expand<C: CubePrimitive>(
53 context: &mut CubeContext,
54 lhs: impl Into<ExpandElementTyped<C>>,
55 rhs: impl Into<ExpandElementTyped<C>>,
56 ) -> ExpandElementTyped<C> {
57 binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Div).into()
58 }
59}
60
61pub mod rem {
62 use super::*;
63
64 pub fn expand<C: CubePrimitive>(
65 context: &mut CubeContext,
66 lhs: impl Into<ExpandElementTyped<C>>,
67 rhs: impl Into<ExpandElementTyped<C>>,
68 ) -> ExpandElementTyped<C> {
69 binary_expand(
70 context,
71 lhs.into().into(),
72 rhs.into().into(),
73 Operator::Modulo,
74 )
75 .into()
76 }
77}
78
79pub mod and {
80 use super::*;
81
82 pub fn expand<C: CubePrimitive>(
83 context: &mut CubeContext,
84 lhs: impl Into<ExpandElementTyped<C>>,
85 rhs: impl Into<ExpandElementTyped<C>>,
86 ) -> ExpandElementTyped<bool> {
87 binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::And).into()
88 }
89}
90
91pub mod bitand {
92 use super::*;
93
94 pub fn expand<C: CubePrimitive>(
95 context: &mut CubeContext,
96 lhs: impl Into<ExpandElementTyped<C>>,
97 rhs: impl Into<ExpandElementTyped<C>>,
98 ) -> ExpandElementTyped<C> {
99 binary_expand(
100 context,
101 lhs.into().into(),
102 rhs.into().into(),
103 Operator::BitwiseAnd,
104 )
105 .into()
106 }
107}
108
109pub mod bitor {
110 use super::*;
111
112 pub fn expand<C: CubePrimitive>(
113 context: &mut CubeContext,
114 lhs: impl Into<ExpandElementTyped<C>>,
115 rhs: impl Into<ExpandElementTyped<C>>,
116 ) -> ExpandElementTyped<C> {
117 binary_expand(
118 context,
119 lhs.into().into(),
120 rhs.into().into(),
121 Operator::BitwiseOr,
122 )
123 .into()
124 }
125}
126
127pub mod or {
128 use super::*;
129
130 pub fn expand<C: CubePrimitive>(
131 context: &mut CubeContext,
132 lhs: impl Into<ExpandElementTyped<C>>,
133 rhs: impl Into<ExpandElementTyped<C>>,
134 ) -> ExpandElementTyped<bool> {
135 binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Or).into()
136 }
137}
138
139pub mod bitxor {
140 use super::*;
141
142 pub fn expand<C: CubePrimitive>(
143 context: &mut CubeContext,
144 lhs: impl Into<ExpandElementTyped<C>>,
145 rhs: impl Into<ExpandElementTyped<C>>,
146 ) -> ExpandElementTyped<C> {
147 binary_expand(
148 context,
149 lhs.into().into(),
150 rhs.into().into(),
151 Operator::BitwiseXor,
152 )
153 .into()
154 }
155}
156
157pub mod shl {
158 use super::*;
159
160 pub fn expand<C: CubePrimitive>(
161 context: &mut CubeContext,
162 lhs: impl Into<ExpandElementTyped<C>>,
163 rhs: impl Into<ExpandElementTyped<C>>,
164 ) -> ExpandElementTyped<C> {
165 binary_expand(
166 context,
167 lhs.into().into(),
168 rhs.into().into(),
169 Operator::ShiftLeft,
170 )
171 .into()
172 }
173}
174
175pub mod shr {
176 use super::*;
177
178 pub fn expand<C: CubePrimitive>(
179 context: &mut CubeContext,
180 lhs: impl Into<ExpandElementTyped<C>>,
181 rhs: impl Into<ExpandElementTyped<C>>,
182 ) -> ExpandElementTyped<C> {
183 binary_expand(
184 context,
185 lhs.into().into(),
186 rhs.into().into(),
187 Operator::ShiftRight,
188 )
189 .into()
190 }
191}
192
/// Generates a binary-operation trait plus its per-type implementations.
///
/// Arguments:
/// - `$trait_name` / `$method_name`: the trait and the method exposed to kernel source.
/// - `$func_name_expand`: the trait's associated expansion function.
/// - `$method_name_expand`: an inherent method on `ExpandElementTyped<$type>` offering
///   the same expansion in method-call form.
/// - `$operator`: the `Operator` constructor passed to `binary_expand`.
/// - `$type`: the element types the trait is implemented for (no trailing comma).
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => {
        /// Binary operation available on cube primitives.
        ///
        /// The plain method body is `unexpanded!()`. Presumably it exists only so that
        /// kernel source type-checks; the expansion function emits the actual operation.
        pub trait $trait_name: CubeType + Sized {
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            /// Emits the operation for `lhs` and `rhs` through `binary_expand`.
            fn $func_name_expand(
                context: &mut CubeContext,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                binary_expand(context, lhs.into(), rhs.into(), $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
        $(impl ExpandElementTyped<$type> {
            /// Method-call form of the trait's expansion function.
            ///
            /// Delegates to the trait so that both entry points always emit the same
            /// operation.
            pub fn $method_name_expand(self, context: &mut CubeContext, rhs: ExpandElementTyped<$type>) -> ExpandElementTyped<$type> {
                <$type as $trait_name>::$func_name_expand(context, self, rhs)
            }
        })*
    }
}
218
/// Like `impl_binary_func`, but the output's vectorization is overridden.
///
/// The output item is copied from `lhs`, and its `vectorization` field is then set to
/// `$out_vectorization` instead of being left as derived from the operands. The
/// operation is emitted through `binary_expand_fixed_output` with that item.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        /// Binary operation whose output vectorization is fixed by the generating macro.
        ///
        /// The plain method body is `unexpanded!()`. Presumably it exists only so that
        /// kernel source type-checks; the expansion function emits the actual operation.
        pub trait $trait_name: CubeType + Sized {
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            /// Emits the operation for `lhs` and `rhs` with the fixed output vectorization.
            fn $func_name_expand(
                context: &mut CubeContext,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                let lhs: ExpandElement = lhs.into();
                // Start from the left operand's item and override only its vectorization.
                let mut item = lhs.item;
                item.vectorization = $out_vectorization;
                binary_expand_fixed_output(context, lhs, rhs.into(), item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
        $(impl ExpandElementTyped<$type> {
            /// Method-call form of the trait's expansion function.
            ///
            /// Delegates to the trait so that the output-item override lives in exactly
            /// one place.
            pub fn $method_name_expand(self, context: &mut CubeContext, rhs: ExpandElementTyped<$type>) -> ExpandElementTyped<$type> {
                <$type as $trait_name>::$func_name_expand(context, self, rhs)
            }
        })*
    }
}
249
250impl_binary_func!(
251 Powf,
252 powf,
253 __expand_powf,
254 __expand_powf_method,
255 Operator::Powf,
256 f16,
257 bf16,
258 flex32,
259 tf32,
260 f32,
261 f64
262);
263impl_binary_func!(
264 Max,
265 max,
266 __expand_max,
267 __expand_max_method,
268 Operator::Max,
269 f16,
270 bf16,
271 flex32,
272 tf32,
273 f32,
274 f64,
275 i8,
276 i16,
277 i32,
278 i64,
279 u8,
280 u16,
281 u32,
282 u64
283);
284impl_binary_func!(
285 Min,
286 min,
287 __expand_min,
288 __expand_min_method,
289 Operator::Min,
290 f16,
291 bf16,
292 flex32,
293 tf32,
294 f32,
295 f64,
296 i8,
297 i16,
298 i32,
299 i64,
300 u8,
301 u16,
302 u32,
303 u64
304);
305impl_binary_func!(
306 Remainder,
307 rem,
308 __expand_rem,
309 __expand_rem_method,
310 Operator::Remainder,
311 f16,
312 bf16,
313 flex32,
314 tf32,
315 f32,
316 f64,
317 i8,
318 i16,
319 i32,
320 i64,
321 u8,
322 u16,
323 u32,
324 u64
325);
326impl_binary_func_fixed_output_vectorization!(
327 Dot,
328 dot,
329 __expand_dot,
330 __expand_dot_method,
331 Operator::Dot,
332 None,
333 f16,
334 bf16,
335 flex32,
336 tf32,
337 f32,
338 f64,
339 i8,
340 i16,
341 i32,
342 i64,
343 u8,
344 u16,
345 u32,
346 u64
347);