1use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2use cubecl_ir::{Bitwise, Comparison, Operator, Type};
3use half::{bf16, f16};
4
5use crate::{
6 flex32,
7 ir::{Arithmetic, ExpandElement, Scope},
8 prelude::{CubePrimitive, ExpandElementTyped},
9 tf32, unexpanded,
10};
11
12use super::base::{unary_expand, unary_expand_fixed_output};
13
14pub mod not {
15 use super::*;
16
17 pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
18 unary_expand(scope, x.into(), Operator::Not).into()
19 }
20}
21
22pub mod neg {
23 use super::*;
24
25 pub fn expand<E: CubePrimitive>(
26 scope: &mut Scope,
27 x: ExpandElementTyped<E>,
28 ) -> ExpandElementTyped<E> {
29 unary_expand(scope, x.into(), Arithmetic::Neg).into()
30 }
31}
32
// Declares the unary-operation trait `$trait_name` and implements it (with no overrides) for
// every listed type.
//
// * `$method_name` is the host-side stub; its body only invokes `unexpanded!()`.
// * `$method_name_expand` hands the operand to `unary_expand` together with `$operator` and
//   returns an expansion typed as `Self`, i.e. the output has the same primitive type as the
//   input.
macro_rules! impl_unary_func {
    (
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $operator:expr,
        $($ty:ty),*
    ) => {
        pub trait $trait_name: CubePrimitive + Sized {
            // The stub never reads `x`; only the expanded form consumes the operand.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> Self {
                unexpanded!()
            }

            fn $method_name_expand(
                scope: &mut Scope,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                let output = unary_expand(scope, x.into(), $operator);
                output.into()
            }
        }

        $(
            impl $trait_name for $ty {}
        )*
    };
}
49
50impl Exp for f32 {
51 fn exp(x: Self) -> Self {
52 x.exp()
53 }
54}
55
56macro_rules! impl_unary_func_fixed_out_vectorization {
57 ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
58 pub trait $trait_name: CubePrimitive + Sized {
59 #[allow(unused_variables)]
60 fn $method_name(x: Self) -> Self {
61 unexpanded!()
62 }
63
64 fn $method_name_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
65 let expand_element: ExpandElement = x.into();
66 let item = expand_element.ty.line($out_vectorization);
67 unary_expand_fixed_output(scope, expand_element, item, $operator).into()
68 }
69 }
70
71 $(impl $trait_name for $type {})*
72 }
73}
74
75macro_rules! impl_unary_func_fixed_out_ty {
76 ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $out_ty: ty, $operator:expr, $($type:ty),*) => {
77 pub trait $trait_name: CubePrimitive + Sized {
78 #[allow(unused_variables)]
79 fn $method_name(x: Self) -> $out_ty {
80 unexpanded!()
81 }
82
83 fn $method_name_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<$out_ty> {
84 let expand_element: ExpandElement = x.into();
85 let item = Type::new(<$out_ty as CubePrimitive>::as_type(scope)).line(expand_element.ty.line_size());
86 unary_expand_fixed_output(scope, expand_element, item, $operator).into()
87 }
88 }
89
90 $(impl $trait_name for $type {})*
91 }
92}
93
94impl_unary_func!(
95 Abs,
96 abs,
97 __expand_abs,
98 Arithmetic::Abs,
99 e2m1,
100 e4m3,
101 e5m2,
102 ue8m0,
103 f16,
104 bf16,
105 flex32,
106 tf32,
107 f32,
108 f64,
109 i8,
110 i16,
111 i32,
112 i64,
113 u8,
114 u16,
115 u32,
116 u64
117);
118impl_unary_func!(
119 Exp,
120 exp,
121 __expand_exp,
122 Arithmetic::Exp,
123 f16,
124 bf16,
125 flex32,
126 tf32,
127 f64
129);
130impl_unary_func!(
131 Log,
132 log,
133 __expand_log,
134 Arithmetic::Log,
135 f16,
136 bf16,
137 flex32,
138 tf32,
139 f32,
140 f64
141);
142impl_unary_func!(
143 Log1p,
144 log1p,
145 __expand_log1p,
146 Arithmetic::Log1p,
147 f16,
148 bf16,
149 flex32,
150 tf32,
151 f32,
152 f64
153);
154impl_unary_func!(
155 Cos,
156 cos,
157 __expand_cos,
158 Arithmetic::Cos,
159 f16,
160 bf16,
161 flex32,
162 tf32,
163 f32,
164 f64
165);
166impl_unary_func!(
167 Sin,
168 sin,
169 __expand_sin,
170 Arithmetic::Sin,
171 f16,
172 bf16,
173 flex32,
174 tf32,
175 f32,
176 f64
177);
178impl_unary_func!(
179 Tan,
180 tan,
181 __expand_tan,
182 Arithmetic::Tan,
183 f16,
184 bf16,
185 flex32,
186 tf32,
187 f32,
188 f64
189);
190impl_unary_func!(
191 Tanh,
192 tanh,
193 __expand_tanh,
194 Arithmetic::Tanh,
195 f16,
196 bf16,
197 flex32,
198 tf32,
199 f32,
200 f64
201);
202impl_unary_func!(
203 Sinh,
204 sinh,
205 __expand_sinh,
206 Arithmetic::Sinh,
207 f16,
208 bf16,
209 flex32,
210 tf32,
211 f32,
212 f64
213);
214impl_unary_func!(
215 Cosh,
216 cosh,
217 __expand_cosh,
218 Arithmetic::Cosh,
219 f16,
220 bf16,
221 flex32,
222 tf32,
223 f32,
224 f64
225);
226impl_unary_func!(
227 ArcCos,
228 acos,
229 __expand_acos,
230 Arithmetic::ArcCos,
231 f16,
232 bf16,
233 flex32,
234 tf32,
235 f32,
236 f64
237);
238impl_unary_func!(
239 ArcSin,
240 asin,
241 __expand_asin,
242 Arithmetic::ArcSin,
243 f16,
244 bf16,
245 flex32,
246 tf32,
247 f32,
248 f64
249);
250impl_unary_func!(
251 ArcTan,
252 atan,
253 __expand_atan,
254 Arithmetic::ArcTan,
255 f16,
256 bf16,
257 flex32,
258 tf32,
259 f32,
260 f64
261);
262impl_unary_func!(
263 ArcSinh,
264 asinh,
265 __expand_asinh,
266 Arithmetic::ArcSinh,
267 f16,
268 bf16,
269 flex32,
270 tf32,
271 f32,
272 f64
273);
274impl_unary_func!(
275 ArcCosh,
276 acosh,
277 __expand_acosh,
278 Arithmetic::ArcCosh,
279 f16,
280 bf16,
281 flex32,
282 tf32,
283 f32,
284 f64
285);
286impl_unary_func!(
287 ArcTanh,
288 atanh,
289 __expand_atanh,
290 Arithmetic::ArcTanh,
291 f16,
292 bf16,
293 flex32,
294 tf32,
295 f32,
296 f64
297);
298impl_unary_func!(
299 Degrees,
300 to_degrees,
301 __expand_to_degrees,
302 Arithmetic::Degrees,
303 f16,
304 bf16,
305 flex32,
306 tf32,
307 f32,
308 f64
309);
310impl_unary_func!(
311 Radians,
312 to_radians,
313 __expand_to_radians,
314 Arithmetic::Radians,
315 f16,
316 bf16,
317 flex32,
318 tf32,
319 f32,
320 f64
321);
322impl_unary_func!(
323 Sqrt,
324 sqrt,
325 __expand_sqrt,
326 Arithmetic::Sqrt,
327 f16,
328 bf16,
329 flex32,
330 tf32,
331 f32,
332 f64
333);
334impl_unary_func!(
335 InverseSqrt,
336 inverse_sqrt,
337 __expand_inverse_sqrt,
338 Arithmetic::InverseSqrt,
339 f16,
340 bf16,
341 flex32,
342 tf32,
343 f32,
344 f64
345);
346impl_unary_func!(
347 Round,
348 round,
349 __expand_round,
350 Arithmetic::Round,
351 f16,
352 bf16,
353 flex32,
354 tf32,
355 f32,
356 f64
357);
358impl_unary_func!(
359 Floor,
360 floor,
361 __expand_floor,
362 Arithmetic::Floor,
363 f16,
364 bf16,
365 flex32,
366 tf32,
367 f32,
368 f64
369);
370impl_unary_func!(
371 Ceil,
372 ceil,
373 __expand_ceil,
374 Arithmetic::Ceil,
375 f16,
376 bf16,
377 flex32,
378 tf32,
379 f32,
380 f64
381);
382impl_unary_func!(
383 Trunc,
384 trunc,
385 __expand_trunc,
386 Arithmetic::Trunc,
387 f16,
388 bf16,
389 flex32,
390 tf32,
391 f32,
392 f64
393);
394impl_unary_func!(
395 Erf,
396 erf,
397 __expand_erf,
398 Arithmetic::Erf,
399 f16,
400 bf16,
401 flex32,
402 tf32,
403 f32,
404 f64
405);
406impl_unary_func!(
407 Recip,
408 recip,
409 __expand_recip,
410 Arithmetic::Recip,
411 f16,
412 bf16,
413 flex32,
414 tf32,
415 f32,
416 f64
417);
418impl_unary_func_fixed_out_vectorization!(
419 Magnitude,
420 magnitude,
421 __expand_magnitude,
422 Arithmetic::Magnitude,
423 0,
424 f16,
425 bf16,
426 flex32,
427 tf32,
428 f32,
429 f64
430);
431impl_unary_func!(
432 Normalize,
433 normalize,
434 __expand_normalize,
435 Arithmetic::Normalize,
436 f16,
437 bf16,
438 flex32,
439 tf32,
440 f32,
441 f64
442);
443impl_unary_func_fixed_out_ty!(
444 CountOnes,
445 count_ones,
446 __expand_count_ones,
447 u32,
448 Bitwise::CountOnes,
449 u8,
450 i8,
451 u16,
452 i16,
453 u32,
454 i32,
455 u64,
456 i64
457);
458impl_unary_func!(
459 ReverseBits,
460 reverse_bits,
461 __expand_reverse_bits,
462 Bitwise::ReverseBits,
463 u8,
464 i8,
465 u16,
466 i16,
467 u32,
468 i32,
469 u64,
470 i64
471);
472
473impl_unary_func!(
474 BitwiseNot,
475 bitwise_not,
476 __expand_bitwise_not,
477 Bitwise::BitwiseNot,
478 u8,
479 i8,
480 u16,
481 i16,
482 u32,
483 i32,
484 u64,
485 i64
486);
487impl_unary_func_fixed_out_ty!(
488 LeadingZeros,
489 leading_zeros,
490 __expand_leading_zeros,
491 u32,
492 Bitwise::LeadingZeros,
493 u8,
494 i8,
495 u16,
496 i16,
497 u32,
498 i32,
499 u64,
500 i64
501);
502impl_unary_func_fixed_out_ty!(
503 FindFirstSet,
504 find_first_set,
505 __expand_find_first_set,
506 u32,
507 Bitwise::FindFirstSet,
508 u8,
509 i8,
510 u16,
511 i16,
512 u32,
513 i32,
514 u64,
515 i64
516);
517impl_unary_func_fixed_out_ty!(
518 IsNan,
519 is_nan,
520 __expand_is_nan,
521 bool,
522 Comparison::IsNan,
523 f16,
524 bf16,
525 flex32,
526 tf32,
527 f32,
528 f64
529);
530impl_unary_func_fixed_out_ty!(
531 IsInf,
532 is_inf,
533 __expand_is_inf,
534 bool,
535 Comparison::IsInf,
536 f16,
537 bf16,
538 flex32,
539 tf32,
540 f32,
541 f64
542);