1use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2use cubecl_ir::{Bitwise, Comparison, Operator, Type};
3use half::{bf16, f16};
4
5use crate::{
6 flex32,
7 ir::{Arithmetic, ExpandElement, Scope},
8 prelude::{CubePrimitive, ExpandElementTyped},
9 tf32, unexpanded,
10};
11
12use super::base::{unary_expand, unary_expand_fixed_output};
13
14pub mod not {
15 use super::*;
16
17 pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
18 unary_expand(scope, x.into(), Operator::Not).into()
19 }
20}
21
22pub mod neg {
23 use super::*;
24
25 pub fn expand<E: CubePrimitive>(
26 scope: &mut Scope,
27 x: ExpandElementTyped<E>,
28 ) -> ExpandElementTyped<E> {
29 unary_expand(scope, x.into(), Arithmetic::Neg).into()
30 }
31}
32
/// Declares a unary trait and implements it for a list of primitive types.
///
/// Arguments, in order: the trait name, the plain method name, the expand
/// method name, the operator constructor passed to `unary_expand`, and the
/// comma-separated list of types that receive an empty (default-only) impl.
///
/// The plain method body is the `unexpanded!()` placeholder. The expand
/// method converts its argument to an `ExpandElement`, runs it through
/// `unary_expand` with the given operator, and converts the result back into
/// an `ExpandElementTyped<Self>`.
macro_rules! impl_unary_func {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $($ty:ty),*) => {
        pub trait $name: CubePrimitive + Sized {
            // `x` is intentionally unused by the placeholder body.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                let output = unary_expand(scope, input, $op);
                output.into()
            }
        }

        $(impl $name for $ty {})*
    };
}
49
50impl Exp for f32 {
51 fn exp(x: Self) -> Self {
52 x.exp()
53 }
54}
55
56macro_rules! impl_unary_func_fixed_out_vectorization {
57 ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
58 pub trait $trait_name: CubePrimitive + Sized {
59 #[allow(unused_variables)]
60 fn $method_name(x: Self) -> Self {
61 unexpanded!()
62 }
63
64 fn $method_name_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
65 let expand_element: ExpandElement = x.into();
66 let item = expand_element.ty.line($out_vectorization);
67 unary_expand_fixed_output(scope, expand_element, item, $operator).into()
68 }
69 }
70
71 $(impl $trait_name for $type {})*
72 }
73}
74
75macro_rules! impl_unary_func_fixed_out_ty {
76 ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $out_ty: ty, $operator:expr, $($type:ty),*) => {
77 pub trait $trait_name: CubePrimitive + Sized {
78 #[allow(unused_variables)]
79 fn $method_name(x: Self) -> $out_ty {
80 unexpanded!()
81 }
82
83 fn $method_name_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<$out_ty> {
84 let expand_element: ExpandElement = x.into();
85 let item = Type::new(<$out_ty as CubePrimitive>::as_type(scope)).line(expand_element.ty.line_size());
86 unary_expand_fixed_output(scope, expand_element, item, $operator).into()
87 }
88 }
89
90 $(impl $trait_name for $type {})*
91 }
92}
93
94impl_unary_func!(
95 Abs,
96 abs,
97 __expand_abs,
98 Arithmetic::Abs,
99 e2m1,
100 e4m3,
101 e5m2,
102 ue8m0,
103 f16,
104 bf16,
105 flex32,
106 tf32,
107 f32,
108 f64,
109 i8,
110 i16,
111 i32,
112 i64,
113 u8,
114 u16,
115 u32,
116 u64
117);
118impl_unary_func!(
119 Exp,
120 exp,
121 __expand_exp,
122 Arithmetic::Exp,
123 f16,
124 bf16,
125 flex32,
126 tf32,
127 f64
129);
130impl_unary_func!(
131 Log,
132 log,
133 __expand_log,
134 Arithmetic::Log,
135 f16,
136 bf16,
137 flex32,
138 tf32,
139 f32,
140 f64
141);
142impl_unary_func!(
143 Log1p,
144 log1p,
145 __expand_log1p,
146 Arithmetic::Log1p,
147 f16,
148 bf16,
149 flex32,
150 tf32,
151 f32,
152 f64
153);
154impl_unary_func!(
155 Cos,
156 cos,
157 __expand_cos,
158 Arithmetic::Cos,
159 f16,
160 bf16,
161 flex32,
162 tf32,
163 f32,
164 f64
165);
166impl_unary_func!(
167 Sin,
168 sin,
169 __expand_sin,
170 Arithmetic::Sin,
171 f16,
172 bf16,
173 flex32,
174 tf32,
175 f32,
176 f64
177);
178impl_unary_func!(
179 Tanh,
180 tanh,
181 __expand_tanh,
182 Arithmetic::Tanh,
183 f16,
184 bf16,
185 flex32,
186 tf32,
187 f32,
188 f64
189);
190impl_unary_func!(
191 Sqrt,
192 sqrt,
193 __expand_sqrt,
194 Arithmetic::Sqrt,
195 f16,
196 bf16,
197 flex32,
198 tf32,
199 f32,
200 f64
201);
202impl_unary_func!(
203 InverseSqrt,
204 inverse_sqrt,
205 __expand_inverse_sqrt,
206 Arithmetic::InverseSqrt,
207 f16,
208 bf16,
209 flex32,
210 tf32,
211 f32,
212 f64
213);
214impl_unary_func!(
215 Round,
216 round,
217 __expand_round,
218 Arithmetic::Round,
219 f16,
220 bf16,
221 flex32,
222 tf32,
223 f32,
224 f64
225);
226impl_unary_func!(
227 Floor,
228 floor,
229 __expand_floor,
230 Arithmetic::Floor,
231 f16,
232 bf16,
233 flex32,
234 tf32,
235 f32,
236 f64
237);
238impl_unary_func!(
239 Ceil,
240 ceil,
241 __expand_ceil,
242 Arithmetic::Ceil,
243 f16,
244 bf16,
245 flex32,
246 tf32,
247 f32,
248 f64
249);
250impl_unary_func!(
251 Trunc,
252 trunc,
253 __expand_trunc,
254 Arithmetic::Trunc,
255 f16,
256 bf16,
257 flex32,
258 tf32,
259 f32,
260 f64
261);
262impl_unary_func!(
263 Erf,
264 erf,
265 __expand_erf,
266 Arithmetic::Erf,
267 f16,
268 bf16,
269 flex32,
270 tf32,
271 f32,
272 f64
273);
274impl_unary_func!(
275 Recip,
276 recip,
277 __expand_recip,
278 Arithmetic::Recip,
279 f16,
280 bf16,
281 flex32,
282 tf32,
283 f32,
284 f64
285);
286impl_unary_func_fixed_out_vectorization!(
287 Magnitude,
288 magnitude,
289 __expand_magnitude,
290 Arithmetic::Magnitude,
291 0,
292 f16,
293 bf16,
294 flex32,
295 tf32,
296 f32,
297 f64
298);
299impl_unary_func!(
300 Normalize,
301 normalize,
302 __expand_normalize,
303 Arithmetic::Normalize,
304 f16,
305 bf16,
306 flex32,
307 tf32,
308 f32,
309 f64
310);
311impl_unary_func_fixed_out_ty!(
312 CountOnes,
313 count_ones,
314 __expand_count_ones,
315 u32,
316 Bitwise::CountOnes,
317 u8,
318 i8,
319 u16,
320 i16,
321 u32,
322 i32,
323 u64,
324 i64
325);
326impl_unary_func!(
327 ReverseBits,
328 reverse_bits,
329 __expand_reverse_bits,
330 Bitwise::ReverseBits,
331 u8,
332 i8,
333 u16,
334 i16,
335 u32,
336 i32,
337 u64,
338 i64
339);
340
341impl_unary_func!(
342 BitwiseNot,
343 bitwise_not,
344 __expand_bitwise_not,
345 Bitwise::BitwiseNot,
346 u8,
347 i8,
348 u16,
349 i16,
350 u32,
351 i32,
352 u64,
353 i64
354);
355impl_unary_func_fixed_out_ty!(
356 LeadingZeros,
357 leading_zeros,
358 __expand_leading_zeros,
359 u32,
360 Bitwise::LeadingZeros,
361 u8,
362 i8,
363 u16,
364 i16,
365 u32,
366 i32,
367 u64,
368 i64
369);
370impl_unary_func_fixed_out_ty!(
371 FindFirstSet,
372 find_first_set,
373 __expand_find_first_set,
374 u32,
375 Bitwise::FindFirstSet,
376 u8,
377 i8,
378 u16,
379 i16,
380 u32,
381 i32,
382 u64,
383 i64
384);
385impl_unary_func_fixed_out_ty!(
386 IsNan,
387 is_nan,
388 __expand_is_nan,
389 bool,
390 Comparison::IsNan,
391 f16,
392 bf16,
393 flex32,
394 tf32,
395 f32,
396 f64
397);
398impl_unary_func_fixed_out_ty!(
399 IsInf,
400 is_inf,
401 __expand_is_inf,
402 bool,
403 Comparison::IsInf,
404 f16,
405 bf16,
406 flex32,
407 tf32,
408 f32,
409 f64
410);