Skip to main content

cubecl_core/frontend/operation/
unary.rs

1use core::ops::Not;
2use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
3use cubecl_ir::{Bitwise, Comparison, Operator};
4use half::{bf16, f16};
5
6use crate::{
7    flex32,
8    ir::{Arithmetic, ManagedVariable, Scope},
9    prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret},
10    tf32, unexpanded,
11};
12
13use super::base::{unary_expand, unary_expand_fixed_output};
14
15pub mod not {
16    use super::*;
17
18    pub fn expand<T: CubeNot>(scope: &mut Scope, x: NativeExpand<T>) -> NativeExpand<T> {
19        if x.expand.ty.is_bool() {
20            unary_expand(scope, x.into(), Operator::Not).into()
21        } else {
22            unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
23        }
24    }
25}
26
27pub mod neg {
28    use super::*;
29
30    pub fn expand<E: CubePrimitive>(scope: &mut Scope, x: NativeExpand<E>) -> NativeExpand<E> {
31        unary_expand(scope, x.into(), Arithmetic::Neg).into()
32    }
33}
34
/// Declares a unary operation whose output has the same type as its operand.
///
/// Invoked as `impl_unary_func!(Trait, method, Operator, T1, T2, ...)`. It generates:
/// - `Trait`, whose host-side `method` is only a placeholder (`unexpanded!()`) and whose
///   `__expand_method` forwards to the expand-side trait;
/// - `TraitExpand`, implemented for every `NativeExpand<T>` with `T: Trait`, which emits
///   `Operator` through `unary_expand`;
/// - an empty `impl Trait for Ti {}` for every listed type.
macro_rules! impl_unary_func {
    ($name:ident, $func:ident, $op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized {
                #[allow(unused_variables)]
                fn $func(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $func>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self> {
                    x.[<__expand_ $func _method>](scope)
                }
            }

            pub trait [<$name Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope) -> Self;
            }

            // Lowering: one IR instruction, output type identical to the input type.
            impl<T: $name + CubePrimitive> [<$name Expand>] for NativeExpand<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope) -> Self {
                    unary_expand(scope, self.into(), $op).into()
                }
            }

            $(impl $name for $ty {})*
        }
    };
}
62
63impl Exp for f32 {
64    fn exp(self) -> Self {
65        self.exp()
66    }
67}
68
/// Declares a unary operation that reduces its operand to a single scalar of the operand's
/// element type (for example the magnitude of a value).
///
/// Invoked as `impl_unary_func_scalar_out!(Trait, method, Operator, T1, T2, ...)`. Both the
/// host-side placeholder `method` and the generated `__expand_method` produce `Self::Scalar`,
/// so code written against the trait type-checks the same way it is expanded.
macro_rules! impl_unary_func_scalar_out {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized {
                // Returns `Self::Scalar` (not `Self`) to agree with `__expand_*` below, which
                // yields `NativeExpand<Self::Scalar>`.
                #[allow(unused_variables)]
                fn $method_name(self) -> Self::Scalar {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self::Scalar> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar;
            }

            $(impl $trait_name for $type {})*
            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar {
                    let expand_element: ManagedVariable = self.into();
                    // The output type is derived from the input's type, with the vector size reset.
                    // NOTE(review): relies on a vector size of 0 meaning "scalar" for
                    // `with_vector_size`. Confirm against the IR `Type` definition.
                    let item = expand_element.ty.with_vector_size(0);
                    unary_expand_fixed_output(scope, expand_element, item, $operator).into()
                }
            }
        }
    }
}
101
/// Declares a unary operation whose result element type is fixed (`$out`) regardless of the
/// operand's element type, e.g. bit counts (`u32`) or float classification (`bool`).
///
/// The output keeps the operand's vector size; only the element type changes. The host-side
/// method is a placeholder (`unexpanded!()`). `__expand_*` emits `$op` through
/// `unary_expand_fixed_output`.
macro_rules! impl_unary_func_fixed_out_ty {
    ($name:ident, $func:ident, $out:ty, $op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait $name: CubePrimitive + CubeType<ExpandType: [<$name Expand>]
            + CubePrimitiveExpand<WithScalar<$out> = NativeExpand<Self::WithScalar<$out>>>> + Sized {
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $func(self) -> Self::WithScalar<$out> {
                    unexpanded!()
                }

                fn [<__expand_ $func>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self::WithScalar<$out>> {
                    x.[<__expand_ $func _method>](scope)
                }
            }

            pub trait [<$name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $func _method>](self, scope: &mut Scope) -> Self::WithScalar<$out>;
            }

            impl<T: $name + CubePrimitive> [<$name Expand>] for NativeExpand<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope) -> Self::WithScalar<$out> {
                    let input: ManagedVariable = self.into();
                    // Same vector width as the operand, but with `$out` as the element type.
                    let width = input.ty.vector_size();
                    let out_ty = <$out as CubePrimitive>::as_type(scope).with_vector_size(width);
                    unary_expand_fixed_output(scope, input, out_ty, $op).into()
                }
            }

            $(impl $name for $ty {})*
        }
    };
}
132
/// Declares the cube-side wrapper for an operator whose host-side trait comes from
/// `core::ops` (used here for `Not`).
///
/// Needs special handling because Rust uses a single trait, `Not` (the `!` operator), for both
/// logical negation of `bool` and bitwise negation of integers. Instead of declaring its own
/// host-side method, the generated `Cube<Trait>` trait requires `<Trait><Output = Self>`.
/// Expansion is delegated to `not::expand`, which chooses the logical or bitwise IR
/// operator from the operand type.
macro_rules! impl_not {
    ($trait_name:ident, $method_name:ident, $($type:ty),*) => {
        paste::paste! {
            pub trait [<Cube $trait_name>]: $trait_name<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> {
                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            $(impl [<Cube $trait_name>] for $type {})*
            impl<T: [<Cube $trait_name>] + CubePrimitive> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    not::expand(scope, self.into())
                }
            }
        }
    }
}
156
157impl_not!(
158    Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
159);
160
161impl_unary_func!(
162    Abs,
163    abs,
164    Arithmetic::Abs,
165    e2m1,
166    e4m3,
167    e5m2,
168    ue8m0,
169    f16,
170    bf16,
171    flex32,
172    tf32,
173    f32,
174    f64,
175    i8,
176    i16,
177    i32,
178    i64,
179    u8,
180    u16,
181    u32,
182    u64,
183    usize,
184    isize
185);
186impl_unary_func!(
187    Exp,
188    exp,
189    Arithmetic::Exp,
190    f16,
191    bf16,
192    flex32,
193    tf32,
194    // f32,
195    f64
196);
197impl_unary_func!(Log, ln, Arithmetic::Log, f16, bf16, flex32, tf32, f32, f64);
198impl_unary_func!(
199    Log1p,
200    log1p,
201    Arithmetic::Log1p,
202    f16,
203    bf16,
204    flex32,
205    tf32,
206    f32,
207    f64
208);
209impl_unary_func!(Cos, cos, Arithmetic::Cos, f16, bf16, flex32, tf32, f32, f64);
210impl_unary_func!(Sin, sin, Arithmetic::Sin, f16, bf16, flex32, tf32, f32, f64);
211impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
212impl_unary_func!(
213    Tanh,
214    tanh,
215    Arithmetic::Tanh,
216    f16,
217    bf16,
218    flex32,
219    tf32,
220    f32,
221    f64
222);
223impl_unary_func!(
224    Sinh,
225    sinh,
226    Arithmetic::Sinh,
227    f16,
228    bf16,
229    flex32,
230    tf32,
231    f32,
232    f64
233);
234impl_unary_func!(
235    Cosh,
236    cosh,
237    Arithmetic::Cosh,
238    f16,
239    bf16,
240    flex32,
241    tf32,
242    f32,
243    f64
244);
245impl_unary_func!(
246    ArcCos,
247    acos,
248    Arithmetic::ArcCos,
249    f16,
250    bf16,
251    flex32,
252    tf32,
253    f32,
254    f64
255);
256impl_unary_func!(
257    ArcSin,
258    asin,
259    Arithmetic::ArcSin,
260    f16,
261    bf16,
262    flex32,
263    tf32,
264    f32,
265    f64
266);
267impl_unary_func!(
268    ArcTan,
269    atan,
270    Arithmetic::ArcTan,
271    f16,
272    bf16,
273    flex32,
274    tf32,
275    f32,
276    f64
277);
278impl_unary_func!(
279    ArcSinh,
280    asinh,
281    Arithmetic::ArcSinh,
282    f16,
283    bf16,
284    flex32,
285    tf32,
286    f32,
287    f64
288);
289impl_unary_func!(
290    ArcCosh,
291    acosh,
292    Arithmetic::ArcCosh,
293    f16,
294    bf16,
295    flex32,
296    tf32,
297    f32,
298    f64
299);
300impl_unary_func!(
301    ArcTanh,
302    atanh,
303    Arithmetic::ArcTanh,
304    f16,
305    bf16,
306    flex32,
307    tf32,
308    f32,
309    f64
310);
311impl_unary_func!(
312    Degrees,
313    to_degrees,
314    Arithmetic::Degrees,
315    f16,
316    bf16,
317    flex32,
318    tf32,
319    f32,
320    f64
321);
322impl_unary_func!(
323    Radians,
324    to_radians,
325    Arithmetic::Radians,
326    f16,
327    bf16,
328    flex32,
329    tf32,
330    f32,
331    f64
332);
333impl_unary_func!(
334    Sqrt,
335    sqrt,
336    Arithmetic::Sqrt,
337    f16,
338    bf16,
339    flex32,
340    tf32,
341    f32,
342    f64
343);
344impl_unary_func!(
345    InverseSqrt,
346    inverse_sqrt,
347    Arithmetic::InverseSqrt,
348    f16,
349    bf16,
350    flex32,
351    tf32,
352    f32,
353    f64
354);
355impl_unary_func!(
356    Round,
357    round,
358    Arithmetic::Round,
359    f16,
360    bf16,
361    flex32,
362    tf32,
363    f32,
364    f64
365);
366impl_unary_func!(
367    Floor,
368    floor,
369    Arithmetic::Floor,
370    f16,
371    bf16,
372    flex32,
373    tf32,
374    f32,
375    f64
376);
377impl_unary_func!(
378    Ceil,
379    ceil,
380    Arithmetic::Ceil,
381    f16,
382    bf16,
383    flex32,
384    tf32,
385    f32,
386    f64
387);
388impl_unary_func!(
389    Trunc,
390    trunc,
391    Arithmetic::Trunc,
392    f16,
393    bf16,
394    flex32,
395    tf32,
396    f32,
397    f64
398);
399impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
400impl_unary_func!(
401    Recip,
402    recip,
403    Arithmetic::Recip,
404    f16,
405    bf16,
406    flex32,
407    tf32,
408    f32,
409    f64
410);
411impl_unary_func_scalar_out!(
412    Magnitude,
413    magnitude,
414    Arithmetic::Magnitude,
415    f16,
416    bf16,
417    flex32,
418    tf32,
419    f32,
420    f64
421);
422impl_unary_func!(
423    Normalize,
424    normalize,
425    Arithmetic::Normalize,
426    f16,
427    bf16,
428    flex32,
429    tf32,
430    f32,
431    f64
432);
433impl_unary_func_fixed_out_ty!(
434    CountOnes,
435    count_ones,
436    u32,
437    Bitwise::CountOnes,
438    u8,
439    i8,
440    u16,
441    i16,
442    u32,
443    i32,
444    u64,
445    i64,
446    usize,
447    isize
448);
449impl_unary_func!(
450    ReverseBits,
451    reverse_bits,
452    Bitwise::ReverseBits,
453    u8,
454    i8,
455    u16,
456    i16,
457    u32,
458    i32,
459    u64,
460    i64,
461    usize,
462    isize
463);
464
465impl_unary_func_fixed_out_ty!(
466    LeadingZeros,
467    leading_zeros,
468    u32,
469    Bitwise::LeadingZeros,
470    u8,
471    i8,
472    u16,
473    i16,
474    u32,
475    i32,
476    u64,
477    i64,
478    usize,
479    isize
480);
481impl_unary_func_fixed_out_ty!(
482    TrailingZeros,
483    trailing_zeros,
484    u32,
485    Bitwise::TrailingZeros,
486    u8,
487    i8,
488    u16,
489    i16,
490    u32,
491    i32,
492    u64,
493    i64,
494    usize,
495    isize
496);
497impl_unary_func_fixed_out_ty!(
498    FindFirstSet,
499    find_first_set,
500    u32,
501    Bitwise::FindFirstSet,
502    u8,
503    i8,
504    u16,
505    i16,
506    u32,
507    i32,
508    u64,
509    i64,
510    usize,
511    isize
512);
513impl_unary_func_fixed_out_ty!(
514    IsNan,
515    is_nan,
516    bool,
517    Comparison::IsNan,
518    f16,
519    bf16,
520    flex32,
521    tf32,
522    f32,
523    f64
524);
525impl_unary_func_fixed_out_ty!(
526    IsInf,
527    is_inf,
528    bool,
529    Comparison::IsInf,
530    f16,
531    bf16,
532    flex32,
533    tf32,
534    f32,
535    f64
536);
537
538pub trait FloatBits:
539    CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
540{
541    type Bits: CubePrimitive;
542
543    fn __expand_from_bits(scope: &mut Scope, bits: NativeExpand<Self::Bits>) -> NativeExpand<Self> {
544        Self::__expand_reinterpret(scope, bits)
545    }
546
547    fn __expand_to_bits(scope: &mut Scope, this: NativeExpand<Self>) -> NativeExpand<Self::Bits> {
548        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
549    }
550}
551
552pub trait FloatBitsExpand: Sized {
553    type Bits: CubePrimitive;
554
555    fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits>;
556}
557
558impl<F: FloatBits> FloatBitsExpand for NativeExpand<F> {
559    type Bits = F::Bits;
560
561    fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits> {
562        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
563    }
564}
565
566impl FloatBits for e2m1x2 {
567    type Bits = u8;
568}
569
570impl FloatBits for e5m2 {
571    type Bits = u8;
572}
573
574impl FloatBits for e4m3 {
575    type Bits = u8;
576}
577
578impl FloatBits for f16 {
579    type Bits = u16;
580}
581
582impl FloatBits for bf16 {
583    type Bits = u16;
584}
585
586impl FloatBits for f32 {
587    type Bits = u32;
588}
589
590impl FloatBits for f64 {
591    type Bits = u64;
592}