cubecl_core/frontend/operation/unary.rs

1use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2use cubecl_ir::{Bitwise, Comparison, Operator, Type};
3use half::{bf16, f16};
4
5use crate::{
6    flex32,
7    ir::{Arithmetic, ExpandElement, Scope},
8    prelude::{CubePrimitive, ExpandElementTyped},
9    tf32, unexpanded,
10};
11
12use super::base::{unary_expand, unary_expand_fixed_output};
13
14pub mod not {
15    use super::*;
16
17    pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
18        unary_expand(scope, x.into(), Operator::Not).into()
19    }
20}
21
22pub mod neg {
23    use super::*;
24
25    pub fn expand<E: CubePrimitive>(
26        scope: &mut Scope,
27        x: ExpandElementTyped<E>,
28    ) -> ExpandElementTyped<E> {
29        unary_expand(scope, x.into(), Arithmetic::Neg).into()
30    }
31}
32
/// Declares the unary-operation trait `$name` and implements it for every listed type.
///
/// The generated trait provides two associated functions:
/// - `$func(x)`: a placeholder whose body is `unexpanded!()`; it only supplies the signature.
/// - `$func_expand(scope, x)`: passes `x` and the operator constructor `$op` to
///   `unary_expand`, returning an expand element of the same type as the input.
macro_rules! impl_unary_func {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $($ty:ty),*) => {
        #[doc = concat!("Unary `", stringify!($func), "` operation.")]
        pub trait $name: CubePrimitive + Sized {
            // `x` is never read: the body is only `unexpanded!()`.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            #[doc = concat!("Expand-time counterpart of `", stringify!($func), "`.")]
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let output = unary_expand(scope, x.into(), $op);
                output.into()
            }
        }

        $(impl $name for $ty {})*
    };
}
49
50impl Exp for f32 {
51    fn exp(x: Self) -> Self {
52        x.exp()
53    }
54}
55
/// Like `impl_unary_func!`, but the expanded result's type is not simply the input's type.
///
/// It is the input's type passed through `.line($line_size)`, so the output line size is the
/// constant supplied by the caller rather than the input's own line size. The expansion goes
/// through `unary_expand_fixed_output`.
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $line_size:expr, $($ty:ty),*) => {
        #[doc = concat!("Unary `", stringify!($func), "` operation with a fixed output line size.")]
        pub trait $name: CubePrimitive + Sized {
            // `x` is never read: the body is only `unexpanded!()`.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            #[doc = concat!("Expand-time counterpart of `", stringify!($func), "`.")]
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                // Output type: the input's type re-lined to the caller-provided line size.
                let out_ty = input.ty.line($line_size);
                unary_expand_fixed_output(scope, input, out_ty, $op).into()
            }
        }

        $(impl $name for $ty {})*
    };
}
74
/// Like `impl_unary_func!`, but both the placeholder and the expansion return `$out`
/// instead of `Self`.
///
/// The expanded result is built from `<$out as CubePrimitive>::as_type(scope)` and keeps
/// the input's line size.
macro_rules! impl_unary_func_fixed_out_ty {
    ($name:ident, $func:ident, $func_expand:ident, $out:ty, $op:expr, $($ty:ty),*) => {
        #[doc = concat!("Unary `", stringify!($func), "` operation returning `", stringify!($out), "`.")]
        pub trait $name: CubePrimitive + Sized {
            // `x` is never read: the body is only `unexpanded!()`.
            #[allow(unused_variables)]
            fn $func(x: Self) -> $out {
                unexpanded!()
            }

            #[doc = concat!("Expand-time counterpart of `", stringify!($func), "`.")]
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<$out> {
                let input: ExpandElement = x.into();
                let out_storage = <$out as CubePrimitive>::as_type(scope);
                let line_size = input.ty.line_size();
                let out_ty = Type::new(out_storage).line(line_size);
                unary_expand_fixed_output(scope, input, out_ty, $op).into()
            }
        }

        $(impl $name for $ty {})*
    };
}
93
94impl_unary_func!(
95    Abs,
96    abs,
97    __expand_abs,
98    Arithmetic::Abs,
99    e2m1,
100    e4m3,
101    e5m2,
102    ue8m0,
103    f16,
104    bf16,
105    flex32,
106    tf32,
107    f32,
108    f64,
109    i8,
110    i16,
111    i32,
112    i64,
113    u8,
114    u16,
115    u32,
116    u64
117);
118impl_unary_func!(
119    Exp,
120    exp,
121    __expand_exp,
122    Arithmetic::Exp,
123    f16,
124    bf16,
125    flex32,
126    tf32,
127    // f32,
128    f64
129);
130impl_unary_func!(
131    Log,
132    log,
133    __expand_log,
134    Arithmetic::Log,
135    f16,
136    bf16,
137    flex32,
138    tf32,
139    f32,
140    f64
141);
142impl_unary_func!(
143    Log1p,
144    log1p,
145    __expand_log1p,
146    Arithmetic::Log1p,
147    f16,
148    bf16,
149    flex32,
150    tf32,
151    f32,
152    f64
153);
154impl_unary_func!(
155    Cos,
156    cos,
157    __expand_cos,
158    Arithmetic::Cos,
159    f16,
160    bf16,
161    flex32,
162    tf32,
163    f32,
164    f64
165);
166impl_unary_func!(
167    Sin,
168    sin,
169    __expand_sin,
170    Arithmetic::Sin,
171    f16,
172    bf16,
173    flex32,
174    tf32,
175    f32,
176    f64
177);
178impl_unary_func!(
179    Tan,
180    tan,
181    __expand_tan,
182    Arithmetic::Tan,
183    f16,
184    bf16,
185    flex32,
186    tf32,
187    f32,
188    f64
189);
190impl_unary_func!(
191    Tanh,
192    tanh,
193    __expand_tanh,
194    Arithmetic::Tanh,
195    f16,
196    bf16,
197    flex32,
198    tf32,
199    f32,
200    f64
201);
202impl_unary_func!(
203    Sinh,
204    sinh,
205    __expand_sinh,
206    Arithmetic::Sinh,
207    f16,
208    bf16,
209    flex32,
210    tf32,
211    f32,
212    f64
213);
214impl_unary_func!(
215    Cosh,
216    cosh,
217    __expand_cosh,
218    Arithmetic::Cosh,
219    f16,
220    bf16,
221    flex32,
222    tf32,
223    f32,
224    f64
225);
226impl_unary_func!(
227    ArcCos,
228    acos,
229    __expand_acos,
230    Arithmetic::ArcCos,
231    f16,
232    bf16,
233    flex32,
234    tf32,
235    f32,
236    f64
237);
238impl_unary_func!(
239    ArcSin,
240    asin,
241    __expand_asin,
242    Arithmetic::ArcSin,
243    f16,
244    bf16,
245    flex32,
246    tf32,
247    f32,
248    f64
249);
250impl_unary_func!(
251    ArcTan,
252    atan,
253    __expand_atan,
254    Arithmetic::ArcTan,
255    f16,
256    bf16,
257    flex32,
258    tf32,
259    f32,
260    f64
261);
262impl_unary_func!(
263    ArcSinh,
264    asinh,
265    __expand_asinh,
266    Arithmetic::ArcSinh,
267    f16,
268    bf16,
269    flex32,
270    tf32,
271    f32,
272    f64
273);
274impl_unary_func!(
275    ArcCosh,
276    acosh,
277    __expand_acosh,
278    Arithmetic::ArcCosh,
279    f16,
280    bf16,
281    flex32,
282    tf32,
283    f32,
284    f64
285);
286impl_unary_func!(
287    ArcTanh,
288    atanh,
289    __expand_atanh,
290    Arithmetic::ArcTanh,
291    f16,
292    bf16,
293    flex32,
294    tf32,
295    f32,
296    f64
297);
298impl_unary_func!(
299    Degrees,
300    to_degrees,
301    __expand_to_degrees,
302    Arithmetic::Degrees,
303    f16,
304    bf16,
305    flex32,
306    tf32,
307    f32,
308    f64
309);
310impl_unary_func!(
311    Radians,
312    to_radians,
313    __expand_to_radians,
314    Arithmetic::Radians,
315    f16,
316    bf16,
317    flex32,
318    tf32,
319    f32,
320    f64
321);
322impl_unary_func!(
323    Sqrt,
324    sqrt,
325    __expand_sqrt,
326    Arithmetic::Sqrt,
327    f16,
328    bf16,
329    flex32,
330    tf32,
331    f32,
332    f64
333);
334impl_unary_func!(
335    InverseSqrt,
336    inverse_sqrt,
337    __expand_inverse_sqrt,
338    Arithmetic::InverseSqrt,
339    f16,
340    bf16,
341    flex32,
342    tf32,
343    f32,
344    f64
345);
346impl_unary_func!(
347    Round,
348    round,
349    __expand_round,
350    Arithmetic::Round,
351    f16,
352    bf16,
353    flex32,
354    tf32,
355    f32,
356    f64
357);
358impl_unary_func!(
359    Floor,
360    floor,
361    __expand_floor,
362    Arithmetic::Floor,
363    f16,
364    bf16,
365    flex32,
366    tf32,
367    f32,
368    f64
369);
370impl_unary_func!(
371    Ceil,
372    ceil,
373    __expand_ceil,
374    Arithmetic::Ceil,
375    f16,
376    bf16,
377    flex32,
378    tf32,
379    f32,
380    f64
381);
382impl_unary_func!(
383    Trunc,
384    trunc,
385    __expand_trunc,
386    Arithmetic::Trunc,
387    f16,
388    bf16,
389    flex32,
390    tf32,
391    f32,
392    f64
393);
394impl_unary_func!(
395    Erf,
396    erf,
397    __expand_erf,
398    Arithmetic::Erf,
399    f16,
400    bf16,
401    flex32,
402    tf32,
403    f32,
404    f64
405);
406impl_unary_func!(
407    Recip,
408    recip,
409    __expand_recip,
410    Arithmetic::Recip,
411    f16,
412    bf16,
413    flex32,
414    tf32,
415    f32,
416    f64
417);
418impl_unary_func_fixed_out_vectorization!(
419    Magnitude,
420    magnitude,
421    __expand_magnitude,
422    Arithmetic::Magnitude,
423    0,
424    f16,
425    bf16,
426    flex32,
427    tf32,
428    f32,
429    f64
430);
431impl_unary_func!(
432    Normalize,
433    normalize,
434    __expand_normalize,
435    Arithmetic::Normalize,
436    f16,
437    bf16,
438    flex32,
439    tf32,
440    f32,
441    f64
442);
443impl_unary_func_fixed_out_ty!(
444    CountOnes,
445    count_ones,
446    __expand_count_ones,
447    u32,
448    Bitwise::CountOnes,
449    u8,
450    i8,
451    u16,
452    i16,
453    u32,
454    i32,
455    u64,
456    i64
457);
458impl_unary_func!(
459    ReverseBits,
460    reverse_bits,
461    __expand_reverse_bits,
462    Bitwise::ReverseBits,
463    u8,
464    i8,
465    u16,
466    i16,
467    u32,
468    i32,
469    u64,
470    i64
471);
472
473impl_unary_func!(
474    BitwiseNot,
475    bitwise_not,
476    __expand_bitwise_not,
477    Bitwise::BitwiseNot,
478    u8,
479    i8,
480    u16,
481    i16,
482    u32,
483    i32,
484    u64,
485    i64
486);
487impl_unary_func_fixed_out_ty!(
488    LeadingZeros,
489    leading_zeros,
490    __expand_leading_zeros,
491    u32,
492    Bitwise::LeadingZeros,
493    u8,
494    i8,
495    u16,
496    i16,
497    u32,
498    i32,
499    u64,
500    i64
501);
502impl_unary_func_fixed_out_ty!(
503    FindFirstSet,
504    find_first_set,
505    __expand_find_first_set,
506    u32,
507    Bitwise::FindFirstSet,
508    u8,
509    i8,
510    u16,
511    i16,
512    u32,
513    i32,
514    u64,
515    i64
516);
517impl_unary_func_fixed_out_ty!(
518    IsNan,
519    is_nan,
520    __expand_is_nan,
521    bool,
522    Comparison::IsNan,
523    f16,
524    bf16,
525    flex32,
526    tf32,
527    f32,
528    f64
529);
530impl_unary_func_fixed_out_ty!(
531    IsInf,
532    is_inf,
533    __expand_is_inf,
534    bool,
535    Comparison::IsInf,
536    f16,
537    bf16,
538    flex32,
539    tf32,
540    f32,
541    f64
542);