1use crate::ir::{Arithmetic, Bitwise, ExpandElement, Operator, Scope};
2use crate::{
3 flex32,
4 frontend::{CubePrimitive, ExpandElementTyped},
5};
6use crate::{frontend::CubeType, tf32};
7use crate::{
8 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
9 unexpanded,
10};
11use half::{bf16, f16};
12
13pub mod add {
14 use super::*;
15
16 pub fn expand<C: CubePrimitive>(
17 scope: &mut Scope,
18 lhs: ExpandElementTyped<C>,
19 rhs: ExpandElementTyped<C>,
20 ) -> ExpandElementTyped<C> {
21 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
22 }
23}
24
25pub mod sub {
26 use super::*;
27
28 pub fn expand<C: CubePrimitive>(
29 scope: &mut Scope,
30 lhs: ExpandElementTyped<C>,
31 rhs: ExpandElementTyped<C>,
32 ) -> ExpandElementTyped<C> {
33 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into()
34 }
35}
36
37pub mod mul {
38 use super::*;
39
40 pub fn expand<C: CubePrimitive>(
41 scope: &mut Scope,
42 lhs: ExpandElementTyped<C>,
43 rhs: ExpandElementTyped<C>,
44 ) -> ExpandElementTyped<C> {
45 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
46 }
47}
48
49pub mod div {
50 use super::*;
51
52 pub fn expand<C: CubePrimitive>(
53 scope: &mut Scope,
54 lhs: ExpandElementTyped<C>,
55 rhs: ExpandElementTyped<C>,
56 ) -> ExpandElementTyped<C> {
57 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
58 }
59}
60
61pub mod rem {
62 use super::*;
63
64 pub fn expand<C: CubePrimitive>(
65 scope: &mut Scope,
66 lhs: ExpandElementTyped<C>,
67 rhs: ExpandElementTyped<C>,
68 ) -> ExpandElementTyped<C> {
69 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
70 }
71}
72
73pub mod and {
74 use super::*;
75
76 pub fn expand<C: CubePrimitive>(
77 scope: &mut Scope,
78 lhs: ExpandElementTyped<C>,
79 rhs: ExpandElementTyped<C>,
80 ) -> ExpandElementTyped<bool> {
81 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
82 }
83}
84
85pub mod bitand {
86 use super::*;
87
88 pub fn expand<C: CubePrimitive>(
89 scope: &mut Scope,
90 lhs: ExpandElementTyped<C>,
91 rhs: ExpandElementTyped<C>,
92 ) -> ExpandElementTyped<C> {
93 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
94 }
95}
96
97pub mod bitor {
98 use super::*;
99
100 pub fn expand<C: CubePrimitive>(
101 scope: &mut Scope,
102 lhs: ExpandElementTyped<C>,
103 rhs: ExpandElementTyped<C>,
104 ) -> ExpandElementTyped<C> {
105 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
106 }
107}
108
109pub mod or {
110 use super::*;
111
112 pub fn expand<C: CubePrimitive>(
113 scope: &mut Scope,
114 lhs: ExpandElementTyped<C>,
115 rhs: ExpandElementTyped<C>,
116 ) -> ExpandElementTyped<bool> {
117 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
118 }
119}
120
121pub mod bitxor {
122 use super::*;
123
124 pub fn expand<C: CubePrimitive>(
125 scope: &mut Scope,
126 lhs: ExpandElementTyped<C>,
127 rhs: ExpandElementTyped<C>,
128 ) -> ExpandElementTyped<C> {
129 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
130 }
131}
132
133pub mod shl {
134 use super::*;
135
136 pub fn expand<C: CubePrimitive>(
137 scope: &mut Scope,
138 lhs: ExpandElementTyped<C>,
139 rhs: ExpandElementTyped<C>,
140 ) -> ExpandElementTyped<C> {
141 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
142 }
143}
144
145pub mod shr {
146 use super::*;
147
148 pub fn expand<C: CubePrimitive>(
149 scope: &mut Scope,
150 lhs: ExpandElementTyped<C>,
151 rhs: ExpandElementTyped<C>,
152 ) -> ExpandElementTyped<C> {
153 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
154 }
155}
156
/// Declares a public binary-function trait and implements it for a list of types.
///
/// Macro arguments, in order:
/// - `$trait_name`: name of the generated trait.
/// - `$method_name`: user-facing method. Its default body is `unexpanded!()`, so it is
///   not meant to be executed directly.
/// - `$func_name_expand`: associated function on the trait that emits `$operator`
///   through `binary_expand`.
/// - `$method_name_expand`: inherent method generated on `ExpandElementTyped<$type>`.
/// - `$operator`: IR operator passed to `binary_expand`.
/// - `$type`: each type that receives an (empty) impl of the trait plus the inherent method.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => {
        pub trait $trait_name: CubeType + Sized {
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            fn $func_name_expand(
                scope: &mut Scope,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                let expanded = binary_expand(scope, lhs.into(), rhs.into(), $operator);
                expanded.into()
            }
        }

        $(
            impl $trait_name for $type {}

            impl ExpandElementTyped<$type> {
                pub fn $method_name_expand(
                    self,
                    scope: &mut Scope,
                    rhs: ExpandElementTyped<$type>,
                ) -> ExpandElementTyped<$type> {
                    // Delegate to the trait's associated function so the method form and
                    // the function form always emit the same instruction.
                    <$type as $trait_name>::$func_name_expand(scope, self, rhs)
                }
            }
        )*
    };
}
182
/// Like `impl_binary_func`, but the output item has its vectorization overridden.
///
/// The output item is copied from the left operand's item, and its `vectorization` field
/// is then replaced with `$out_vectorization`. The result goes through
/// `binary_expand_fixed_output`.
/// NOTE(review): the expand functions are still typed as returning
/// `ExpandElementTyped<Self>`, even though the runtime vectorization may differ from the
/// input's.
macro_rules! impl_binary_func_fixed_output_vectorization {
    ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        pub trait $trait_name: CubeType + Sized {
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            fn $func_name_expand(
                scope: &mut Scope,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                let lhs_elem: ExpandElement = lhs.into();
                // Start from the lhs item and override only its vectorization.
                let mut out_item = lhs_elem.item;
                out_item.vectorization = $out_vectorization;
                let expanded =
                    binary_expand_fixed_output(scope, lhs_elem, rhs.into(), out_item, $operator);
                expanded.into()
            }
        }

        $(
            impl $trait_name for $type {}

            impl ExpandElementTyped<$type> {
                pub fn $method_name_expand(
                    self,
                    scope: &mut Scope,
                    rhs: ExpandElementTyped<$type>,
                ) -> ExpandElementTyped<$type> {
                    // Single source of truth: reuse the trait's associated function.
                    <$type as $trait_name>::$func_name_expand(scope, self, rhs)
                }
            }
        )*
    };
}
213
214impl_binary_func!(
215 Powf,
216 powf,
217 __expand_powf,
218 __expand_powf_method,
219 Arithmetic::Powf,
220 f16,
221 bf16,
222 flex32,
223 tf32,
224 f32,
225 f64
226);
227impl_binary_func!(
228 Max,
229 max,
230 __expand_max,
231 __expand_max_method,
232 Arithmetic::Max,
233 f16,
234 bf16,
235 flex32,
236 tf32,
237 f32,
238 f64,
239 i8,
240 i16,
241 i32,
242 i64,
243 u8,
244 u16,
245 u32,
246 u64
247);
248impl_binary_func!(
249 Min,
250 min,
251 __expand_min,
252 __expand_min_method,
253 Arithmetic::Min,
254 f16,
255 bf16,
256 flex32,
257 tf32,
258 f32,
259 f64,
260 i8,
261 i16,
262 i32,
263 i64,
264 u8,
265 u16,
266 u32,
267 u64
268);
269impl_binary_func!(
270 Remainder,
271 rem,
272 __expand_rem,
273 __expand_rem_method,
274 Arithmetic::Remainder,
275 f16,
276 bf16,
277 flex32,
278 tf32,
279 f32,
280 f64,
281 i8,
282 i16,
283 i32,
284 i64,
285 u8,
286 u16,
287 u32,
288 u64
289);
290impl_binary_func!(
291 MulHi,
292 mul_hi,
293 __expand_mul_hi,
294 __expand_mul_hi_method,
295 Arithmetic::MulHi,
296 i32,
297 u32
298);
299impl_binary_func_fixed_output_vectorization!(
300 Dot,
301 dot,
302 __expand_dot,
303 __expand_dot_method,
304 Arithmetic::Dot,
305 None,
306 f16,
307 bf16,
308 flex32,
309 tf32,
310 f32,
311 f64,
312 i8,
313 i16,
314 i32,
315 i64,
316 u8,
317 u16,
318 u32,
319 u64
320);