cubecl_core/frontend/operation/unary.rs

1use core::ops::Not;
2use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
3use cubecl_ir::{Bitwise, Comparison, Operator, Type};
4use half::{bf16, f16};
5
6use crate::{
7    flex32,
8    ir::{Arithmetic, ExpandElement, Scope},
9    prelude::{CubePrimitive, CubeType, ExpandElementTyped, Reinterpret},
10    tf32, unexpanded,
11};
12
13use super::base::{unary_expand, unary_expand_fixed_output};
14
15pub mod not {
16    use super::*;
17
18    pub fn expand<T: CubeNot>(
19        scope: &mut Scope,
20        x: ExpandElementTyped<T>,
21    ) -> ExpandElementTyped<T> {
22        if x.expand.ty.is_bool() {
23            unary_expand(scope, x.into(), Operator::Not).into()
24        } else {
25            unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
26        }
27    }
28}
29
30pub mod neg {
31    use super::*;
32
33    pub fn expand<E: CubePrimitive>(
34        scope: &mut Scope,
35        x: ExpandElementTyped<E>,
36    ) -> ExpandElementTyped<E> {
37        unary_expand(scope, x.into(), Arithmetic::Neg).into()
38    }
39}
40
/// Declares a unary operation whose result has the same type as its input.
///
/// `impl_unary_func!(Trait, method, Op, T1, T2, ...)` generates:
/// - `pub trait Trait` with a host-side `method(self) -> Self` whose default body is
///   `unexpanded!()` (NOTE(review): presumably this fails if called outside kernel
///   expansion; confirm against the `unexpanded!` definition), and the static expansion
///   hook `__expand_method(scope, x)`;
/// - `pub trait TraitExpand` with `__expand_method_method(self, scope)`, the method-call
///   form of the expansion;
/// - an empty `impl Trait for Ti {}` for every listed type;
/// - a blanket `TraitExpand` impl for `ExpandElementTyped<T>` that lowers the call through
///   `unary_expand(scope, x, Op)`.
macro_rules! impl_unary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                // Static entry point; forwards to the method-call form on the expanded value.
                fn [<__expand_ $method_name>](scope: &mut Scope, x: ExpandElementTyped<Self>) -> ExpandElementTyped<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            $(impl $trait_name for $type {})*
            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    // Output type is left to `unary_expand` (same as the input).
                    unary_expand(scope, self.into(), $operator).into()
                }
            }
        }
    }
}
68
69impl Exp for f32 {
70    fn exp(self) -> Self {
71        self.exp()
72    }
73}
74
/// Variant of `impl_unary_func!` whose IR output has a fixed line size.
///
/// The emitted output type is the input type with its line size replaced by
/// `$out_vectorization` (via `ty.line(..)`), and the operation is emitted through
/// `unary_expand_fixed_output`. The typed wrapper returned to callers is still
/// `ExpandElementTyped<Self>`, even when the emitted line size differs from the input's.
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                // Static entry point; forwards to the method-call form on the expanded value.
                fn [<__expand_ $method_name>](scope: &mut Scope, x: ExpandElementTyped<Self>) -> ExpandElementTyped<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            $(impl $trait_name for $type {})*
            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    let expand_element: ExpandElement = self.into();
                    // Same element type as the input, but with the caller-chosen line size.
                    let item = expand_element.ty.line($out_vectorization);
                    unary_expand_fixed_output(scope, expand_element, item, $operator).into()
                }
            }
        }
    }
}
104
/// Declares a unary operation whose result element type is fixed to `$out_ty`,
/// regardless of the input type (e.g. `u32` for bit counts, `bool` for float checks).
///
/// Both the host-side method and the expansion return `$out_ty`. The emitted output keeps
/// the input's line size (`line(expand_element.ty.line_size())`).
macro_rules! impl_unary_func_fixed_out_ty {
    ($trait_name:ident, $method_name:ident, $out_ty: ty, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                // NOTE(review): `wrong_self_convention` is presumably allowed because generated
                // names such as `is_nan` take `self` by value.
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $method_name(self) -> $out_ty {
                    unexpanded!()
                }

                // Static entry point; forwards to the method-call form on the expanded value.
                fn [<__expand_ $method_name>](scope: &mut Scope, x: ExpandElementTyped<Self>) -> ExpandElementTyped<$out_ty> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> ExpandElementTyped<$out_ty>;
            }

            $(impl $trait_name for $type {})*
            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> ExpandElementTyped<$out_ty> {
                    let expand_element: ExpandElement = self.into();
                    // Output element type is `$out_ty`; the line size is copied from the input.
                    let item = Type::new(<$out_ty as CubePrimitive>::as_type(scope)).line(expand_element.ty.line_size());
                    unary_expand_fixed_output(scope, expand_element, item, $operator).into()
                }
            }
        }
    }
}
134
// Needs special handling because Rust combines bitwise and logical NOT into one trait
// (`core::ops::Not`); `not::expand` picks the IR operation from the operand type.
/// Generates `Cube<Trait>` (e.g. `CubeNot`) for a `core::ops` trait whose `Output` is
/// `Self`, plus `<Trait>Expand`, whose blanket impl on `ExpandElementTyped<T>` lowers
/// through `not::expand`.
///
/// Unlike `impl_unary_func!`, no host-side method is declared here. The required
/// supertrait `$trait_name<Output = Self>` already provides it.
macro_rules! impl_not {
    ($trait_name:ident, $method_name:ident, $($type:ty),*) => {
        paste::paste! {
            pub trait [<Cube $trait_name>]: $trait_name<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> {
                // Static entry point; forwards to the method-call form on the expanded value.
                fn [<__expand_ $method_name>](scope: &mut Scope, x: ExpandElementTyped<Self>) -> ExpandElementTyped<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            $(impl [<Cube $trait_name>] for $type {})*
            impl<T: [<Cube $trait_name>] + CubePrimitive> [<$trait_name Expand>] for ExpandElementTyped<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    not::expand(scope, self.into())
                }
            }
        }
    }
}
158
// `!` is available on `bool` (lowered to a logical NOT) and on every integer primitive
// (lowered to a bitwise NOT); see `not::expand`.
impl_not!(
    Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
);
162
// `abs` is declared for every numeric primitive: the low-precision float types
// (`e2m1`, `e4m3`, `e5m2`, `ue8m0`), the standard floats, and both signed and unsigned
// integers. The result keeps the input type.
impl_unary_func!(
    Abs,
    abs,
    Arithmetic::Abs,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
impl_unary_func!(
    Exp,
    exp,
    Arithmetic::Exp,
    f16,
    bf16,
    flex32,
    tf32,
    // `f32` is omitted on purpose: it has a hand-written `impl Exp for f32` with a real host-side `exp`.
    f64
);
199impl_unary_func!(Log, ln, Arithmetic::Log, f16, bf16, flex32, tf32, f32, f64);
200impl_unary_func!(
201    Log1p,
202    log1p,
203    Arithmetic::Log1p,
204    f16,
205    bf16,
206    flex32,
207    tf32,
208    f32,
209    f64
210);
211impl_unary_func!(Cos, cos, Arithmetic::Cos, f16, bf16, flex32, tf32, f32, f64);
212impl_unary_func!(Sin, sin, Arithmetic::Sin, f16, bf16, flex32, tf32, f32, f64);
213impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
214impl_unary_func!(
215    Tanh,
216    tanh,
217    Arithmetic::Tanh,
218    f16,
219    bf16,
220    flex32,
221    tf32,
222    f32,
223    f64
224);
225impl_unary_func!(
226    Sinh,
227    sinh,
228    Arithmetic::Sinh,
229    f16,
230    bf16,
231    flex32,
232    tf32,
233    f32,
234    f64
235);
236impl_unary_func!(
237    Cosh,
238    cosh,
239    Arithmetic::Cosh,
240    f16,
241    bf16,
242    flex32,
243    tf32,
244    f32,
245    f64
246);
247impl_unary_func!(
248    ArcCos,
249    acos,
250    Arithmetic::ArcCos,
251    f16,
252    bf16,
253    flex32,
254    tf32,
255    f32,
256    f64
257);
258impl_unary_func!(
259    ArcSin,
260    asin,
261    Arithmetic::ArcSin,
262    f16,
263    bf16,
264    flex32,
265    tf32,
266    f32,
267    f64
268);
269impl_unary_func!(
270    ArcTan,
271    atan,
272    Arithmetic::ArcTan,
273    f16,
274    bf16,
275    flex32,
276    tf32,
277    f32,
278    f64
279);
280impl_unary_func!(
281    ArcSinh,
282    asinh,
283    Arithmetic::ArcSinh,
284    f16,
285    bf16,
286    flex32,
287    tf32,
288    f32,
289    f64
290);
291impl_unary_func!(
292    ArcCosh,
293    acosh,
294    Arithmetic::ArcCosh,
295    f16,
296    bf16,
297    flex32,
298    tf32,
299    f32,
300    f64
301);
302impl_unary_func!(
303    ArcTanh,
304    atanh,
305    Arithmetic::ArcTanh,
306    f16,
307    bf16,
308    flex32,
309    tf32,
310    f32,
311    f64
312);
313impl_unary_func!(
314    Degrees,
315    to_degrees,
316    Arithmetic::Degrees,
317    f16,
318    bf16,
319    flex32,
320    tf32,
321    f32,
322    f64
323);
324impl_unary_func!(
325    Radians,
326    to_radians,
327    Arithmetic::Radians,
328    f16,
329    bf16,
330    flex32,
331    tf32,
332    f32,
333    f64
334);
335impl_unary_func!(
336    Sqrt,
337    sqrt,
338    Arithmetic::Sqrt,
339    f16,
340    bf16,
341    flex32,
342    tf32,
343    f32,
344    f64
345);
346impl_unary_func!(
347    InverseSqrt,
348    inverse_sqrt,
349    Arithmetic::InverseSqrt,
350    f16,
351    bf16,
352    flex32,
353    tf32,
354    f32,
355    f64
356);
357impl_unary_func!(
358    Round,
359    round,
360    Arithmetic::Round,
361    f16,
362    bf16,
363    flex32,
364    tf32,
365    f32,
366    f64
367);
368impl_unary_func!(
369    Floor,
370    floor,
371    Arithmetic::Floor,
372    f16,
373    bf16,
374    flex32,
375    tf32,
376    f32,
377    f64
378);
379impl_unary_func!(
380    Ceil,
381    ceil,
382    Arithmetic::Ceil,
383    f16,
384    bf16,
385    flex32,
386    tf32,
387    f32,
388    f64
389);
390impl_unary_func!(
391    Trunc,
392    trunc,
393    Arithmetic::Trunc,
394    f16,
395    bf16,
396    flex32,
397    tf32,
398    f32,
399    f64
400);
401impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
402impl_unary_func!(
403    Recip,
404    recip,
405    Arithmetic::Recip,
406    f16,
407    bf16,
408    flex32,
409    tf32,
410    f32,
411    f64
412);
// `magnitude` uses the fixed-output-line-size variant: the emitted result has line size
// `0` rather than the input's line size. NOTE(review): presumably `0` denotes a scalar
// result (one value per input line); confirm against `Type::line`.
impl_unary_func_fixed_out_vectorization!(
    Magnitude,
    magnitude,
    Arithmetic::Magnitude,
    0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Unlike `magnitude`, `normalize` uses the plain variant, so its result keeps the input type.
impl_unary_func!(
    Normalize,
    normalize,
    Arithmetic::Normalize,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Bit-level operations on every integer primitive. As with Rust's own `count_ones` and
// `leading_zeros`, the counting operations return `u32` (keeping the input's line size),
// while `reverse_bits` returns the input type.
impl_unary_func_fixed_out_ty!(
    CountOnes,
    count_ones,
    u32,
    Bitwise::CountOnes,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
impl_unary_func!(
    ReverseBits,
    reverse_bits,
    Bitwise::ReverseBits,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);

impl_unary_func_fixed_out_ty!(
    LeadingZeros,
    leading_zeros,
    u32,
    Bitwise::LeadingZeros,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
impl_unary_func_fixed_out_ty!(
    FindFirstSet,
    find_first_set,
    u32,
    Bitwise::FindFirstSet,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
// Float classification predicates: the result element type is `bool`, with the input's
// line size.
impl_unary_func_fixed_out_ty!(
    IsNan,
    is_nan,
    bool,
    Comparison::IsNan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func_fixed_out_ty!(
    IsInf,
    is_inf,
    bool,
    Comparison::IsInf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
524
525pub trait FloatBits:
526    CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
527{
528    type Bits: CubePrimitive;
529
530    fn __expand_from_bits(
531        scope: &mut Scope,
532        bits: ExpandElementTyped<Self::Bits>,
533    ) -> ExpandElementTyped<Self> {
534        Self::__expand_reinterpret(scope, bits)
535    }
536
537    fn __expand_to_bits(
538        scope: &mut Scope,
539        this: ExpandElementTyped<Self>,
540    ) -> ExpandElementTyped<Self::Bits> {
541        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
542    }
543}
544
545pub trait FloatBitsExpand: Sized {
546    type Bits: CubePrimitive;
547
548    fn __expand_to_bits_method(self, scope: &mut Scope) -> ExpandElementTyped<Self::Bits>;
549}
550
551impl<F: FloatBits> FloatBitsExpand for ExpandElementTyped<F> {
552    type Bits = F::Bits;
553
554    fn __expand_to_bits_method(self, scope: &mut Scope) -> ExpandElementTyped<Self::Bits> {
555        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
556    }
557}
558
559impl FloatBits for e2m1x2 {
560    type Bits = u8;
561}
562
563impl FloatBits for e5m2 {
564    type Bits = u8;
565}
566
567impl FloatBits for e4m3 {
568    type Bits = u8;
569}
570
571impl FloatBits for f16 {
572    type Bits = u16;
573}
574
575impl FloatBits for bf16 {
576    type Bits = u16;
577}
578
579impl FloatBits for f32 {
580    type Bits = u32;
581}
582
583impl FloatBits for f64 {
584    type Bits = u64;
585}