Skip to main content

cubecl_core/frontend/operation/
unary.rs

1use core::ops::Not;
2use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
3use cubecl_ir::{Bitwise, Comparison, Operator};
4use half::{bf16, f16};
5
6use crate::{
7    flex32,
8    ir::{Arithmetic, ManagedVariable, Scope},
9    prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret},
10    tf32, unexpanded,
11};
12
13use super::base::{unary_expand, unary_expand_fixed_output};
14
15pub mod not {
16    use super::*;
17
18    pub fn expand<T: CubeNot>(scope: &mut Scope, x: NativeExpand<T>) -> NativeExpand<T> {
19        if x.expand.ty.is_bool() {
20            unary_expand(scope, x.into(), Operator::Not).into()
21        } else {
22            unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
23        }
24    }
25}
26
27pub mod neg {
28    use super::*;
29
30    pub fn expand<E: CubePrimitive>(scope: &mut Scope, x: NativeExpand<E>) -> NativeExpand<E> {
31        unary_expand(scope, x.into(), Arithmetic::Neg).into()
32    }
33}
34
/// Generates a unary op trait whose result has the same type as its input.
///
/// `impl_unary_func!(Trait, method, Op, T1, T2, ...)` emits:
/// - `pub trait Trait` with a placeholder `method(self) -> Self` whose body is
///   `unexpanded!()`, plus `__expand_method`, which forwards to the companion trait;
/// - `pub trait TraitExpand` with `__expand_method_method`, implemented for every
///   `NativeExpand<T>` with `T: Trait` by emitting one `unary_expand` of `Op`;
/// - an empty `impl Trait for Ti {}` for each listed primitive.
macro_rules! impl_unary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($primitive:ty),*) => {
        paste::paste! {
            pub trait $trait_name:
                CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized
            {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    unary_expand(scope, self.into(), $operator).into()
                }
            }

            $(impl $trait_name for $primitive {})*
        }
    };
}
62
63impl Exp for f32 {
64    fn exp(self) -> Self {
65        self.exp()
66    }
67}
68
/// Generates a unary op trait whose result is the scalar type of its input
/// (`Self::Scalar`) rather than `Self`. This covers ops such as a magnitude or a sum over
/// a vector's elements.
///
/// `impl_unary_func_scalar_out!(Trait, method, Op, T1, T2, ...)` emits:
/// - `pub trait Trait` with a placeholder `method(self) -> Self::Scalar` whose body is
///   `unexpanded!()`, plus `__expand_method` returning `NativeExpand<Self::Scalar>`;
/// - `pub trait TraitExpand` with `__expand_method_method`, implemented for every
///   `NativeExpand<T>` with `T: Trait` by emitting `Op` through
///   `unary_expand_fixed_output` with an explicit output type;
/// - an empty `impl Trait for Ti {}` for each listed primitive.
macro_rules! impl_unary_func_scalar_out {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized {
                // The placeholder's return type must agree with `__expand_…` below,
                // which yields `Self::Scalar`. Declaring `Self` here would make the plain
                // and expanded signatures disagree whenever `Self` is not its own scalar
                // type.
                #[allow(unused_variables)]
                fn $method_name(self) -> Self::Scalar {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self::Scalar> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar;
            }

            $(impl $trait_name for $type {})*
            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar {
                    let expand_element: ManagedVariable = self.into();
                    // Presumably vector size 0 selects the non-vectorized form of the input
                    // type, matching the `Self::Scalar` result. TODO confirm
                    // `with_vector_size(0)` semantics.
                    let item = expand_element.ty.with_vector_size(0);
                    unary_expand_fixed_output(scope, expand_element, item, $operator).into()
                }
            }
        }
    }
}
101
/// Generates a unary op trait whose result element type is fixed to `$out_ty` (for example
/// a `u32` bit count or a `bool` predicate) while keeping the input's vector size.
///
/// The placeholder `method(self)` returns `Self::WithScalar<$out_ty>` and its body is
/// `unexpanded!()`. The expanded form emits `$operator` through `unary_expand_fixed_output`
/// with an explicitly built output type. Every listed primitive gets an empty impl.
macro_rules! impl_unary_func_fixed_out_ty {
    ($trait_name:ident, $method_name:ident, $out_ty:ty, $operator:expr, $($primitive:ty),*) => {
        paste::paste! {
            pub trait $trait_name:
                CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                    + CubePrimitiveExpand<WithScalar<$out_ty> = NativeExpand<Self::WithScalar<$out_ty>>>>
                + Sized
            {
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $method_name(self) -> Self::WithScalar<$out_ty> {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self::WithScalar<$out_ty>> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](
                    self,
                    scope: &mut Scope,
                ) -> Self::WithScalar<$out_ty>;
            }

            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](
                    self,
                    scope: &mut Scope,
                ) -> Self::WithScalar<$out_ty> {
                    let input: ManagedVariable = self.into();
                    // Element type comes from `$out_ty`; the vector size is copied from the input.
                    let out_elem = <$out_ty as CubePrimitive>::as_type(scope);
                    let out_type = out_elem.with_vector_size(input.ty.vector_size());
                    unary_expand_fixed_output(scope, input, out_type, $operator).into()
                }
            }

            $(impl $trait_name for $primitive {})*
        }
    };
}
132
// Rust exposes both logical negation (`!b` on `bool`) and bitwise complement (`!n` on
// integers) through the single `core::ops::Not` trait, so `Not` needs its own generator.
// The plain method comes from `core::ops::Not` itself (a supertrait here). Every expansion
// is routed through `not::expand`, which picks the matching IR operation.
macro_rules! impl_not {
    ($trait_name:ident, $method_name:ident, $($primitive:ty),*) => {
        paste::paste! {
            pub trait [<Cube $trait_name>]:
                $trait_name<Output = Self>
                + CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]>
            {
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            impl<T: [<Cube $trait_name>] + CubePrimitive> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    not::expand(scope, self.into())
                }
            }

            $(impl [<Cube $trait_name>] for $primitive {})*
        }
    };
}
156
157impl_not!(
158    Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
159);
160
161impl_unary_func!(
162    Abs,
163    abs,
164    Arithmetic::Abs,
165    e2m1,
166    e4m3,
167    e5m2,
168    ue8m0,
169    f16,
170    bf16,
171    flex32,
172    tf32,
173    f32,
174    f64,
175    i8,
176    i16,
177    i32,
178    i64,
179    u8,
180    u16,
181    u32,
182    u64,
183    usize,
184    isize
185);
186impl_unary_func!(
187    Exp,
188    exp,
189    Arithmetic::Exp,
190    f16,
191    bf16,
192    flex32,
193    tf32,
194    // f32,
195    f64
196);
197impl_unary_func!(Log, ln, Arithmetic::Log, f16, bf16, flex32, tf32, f32, f64);
198impl_unary_func!(
199    Log1p,
200    log1p,
201    Arithmetic::Log1p,
202    f16,
203    bf16,
204    flex32,
205    tf32,
206    f32,
207    f64
208);
209impl_unary_func!(Cos, cos, Arithmetic::Cos, f16, bf16, flex32, tf32, f32, f64);
210impl_unary_func!(Sin, sin, Arithmetic::Sin, f16, bf16, flex32, tf32, f32, f64);
211impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
212impl_unary_func!(
213    Tanh,
214    tanh,
215    Arithmetic::Tanh,
216    f16,
217    bf16,
218    flex32,
219    tf32,
220    f32,
221    f64
222);
223impl_unary_func!(
224    Sinh,
225    sinh,
226    Arithmetic::Sinh,
227    f16,
228    bf16,
229    flex32,
230    tf32,
231    f32,
232    f64
233);
234impl_unary_func!(
235    Cosh,
236    cosh,
237    Arithmetic::Cosh,
238    f16,
239    bf16,
240    flex32,
241    tf32,
242    f32,
243    f64
244);
245impl_unary_func!(
246    ArcCos,
247    acos,
248    Arithmetic::ArcCos,
249    f16,
250    bf16,
251    flex32,
252    tf32,
253    f32,
254    f64
255);
256impl_unary_func!(
257    ArcSin,
258    asin,
259    Arithmetic::ArcSin,
260    f16,
261    bf16,
262    flex32,
263    tf32,
264    f32,
265    f64
266);
267impl_unary_func!(
268    ArcTan,
269    atan,
270    Arithmetic::ArcTan,
271    f16,
272    bf16,
273    flex32,
274    tf32,
275    f32,
276    f64
277);
278impl_unary_func!(
279    ArcSinh,
280    asinh,
281    Arithmetic::ArcSinh,
282    f16,
283    bf16,
284    flex32,
285    tf32,
286    f32,
287    f64
288);
289impl_unary_func!(
290    ArcCosh,
291    acosh,
292    Arithmetic::ArcCosh,
293    f16,
294    bf16,
295    flex32,
296    tf32,
297    f32,
298    f64
299);
300impl_unary_func!(
301    ArcTanh,
302    atanh,
303    Arithmetic::ArcTanh,
304    f16,
305    bf16,
306    flex32,
307    tf32,
308    f32,
309    f64
310);
311impl_unary_func!(
312    Degrees,
313    to_degrees,
314    Arithmetic::Degrees,
315    f16,
316    bf16,
317    flex32,
318    tf32,
319    f32,
320    f64
321);
322impl_unary_func!(
323    Radians,
324    to_radians,
325    Arithmetic::Radians,
326    f16,
327    bf16,
328    flex32,
329    tf32,
330    f32,
331    f64
332);
333impl_unary_func!(
334    Sqrt,
335    sqrt,
336    Arithmetic::Sqrt,
337    f16,
338    bf16,
339    flex32,
340    tf32,
341    f32,
342    f64
343);
344impl_unary_func!(
345    InverseSqrt,
346    inverse_sqrt,
347    Arithmetic::InverseSqrt,
348    f16,
349    bf16,
350    flex32,
351    tf32,
352    f32,
353    f64
354);
355impl_unary_func!(
356    Round,
357    round,
358    Arithmetic::Round,
359    f16,
360    bf16,
361    flex32,
362    tf32,
363    f32,
364    f64
365);
366impl_unary_func!(
367    Floor,
368    floor,
369    Arithmetic::Floor,
370    f16,
371    bf16,
372    flex32,
373    tf32,
374    f32,
375    f64
376);
377impl_unary_func!(
378    Ceil,
379    ceil,
380    Arithmetic::Ceil,
381    f16,
382    bf16,
383    flex32,
384    tf32,
385    f32,
386    f64
387);
388impl_unary_func!(
389    Trunc,
390    trunc,
391    Arithmetic::Trunc,
392    f16,
393    bf16,
394    flex32,
395    tf32,
396    f32,
397    f64
398);
399impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
400impl_unary_func!(
401    Recip,
402    recip,
403    Arithmetic::Recip,
404    f16,
405    bf16,
406    flex32,
407    tf32,
408    f32,
409    f64
410);
411impl_unary_func_scalar_out!(
412    Magnitude,
413    magnitude,
414    Arithmetic::Magnitude,
415    f16,
416    bf16,
417    flex32,
418    tf32,
419    f32,
420    f64
421);
422impl_unary_func_scalar_out!(
423    VectorSum,
424    vector_sum,
425    Arithmetic::VectorSum,
426    e2m1,
427    e4m3,
428    e5m2,
429    ue8m0,
430    f16,
431    bf16,
432    flex32,
433    tf32,
434    f32,
435    f64,
436    i8,
437    i16,
438    i32,
439    i64,
440    u8,
441    u16,
442    u32,
443    u64,
444    usize,
445    isize
446);
447impl_unary_func!(
448    Normalize,
449    normalize,
450    Arithmetic::Normalize,
451    f16,
452    bf16,
453    flex32,
454    tf32,
455    f32,
456    f64
457);
458impl_unary_func_fixed_out_ty!(
459    CountOnes,
460    count_ones,
461    u32,
462    Bitwise::CountOnes,
463    u8,
464    i8,
465    u16,
466    i16,
467    u32,
468    i32,
469    u64,
470    i64,
471    usize,
472    isize
473);
474impl_unary_func!(
475    ReverseBits,
476    reverse_bits,
477    Bitwise::ReverseBits,
478    u8,
479    i8,
480    u16,
481    i16,
482    u32,
483    i32,
484    u64,
485    i64,
486    usize,
487    isize
488);
489
490impl_unary_func_fixed_out_ty!(
491    LeadingZeros,
492    leading_zeros,
493    u32,
494    Bitwise::LeadingZeros,
495    u8,
496    i8,
497    u16,
498    i16,
499    u32,
500    i32,
501    u64,
502    i64,
503    usize,
504    isize
505);
506impl_unary_func_fixed_out_ty!(
507    TrailingZeros,
508    trailing_zeros,
509    u32,
510    Bitwise::TrailingZeros,
511    u8,
512    i8,
513    u16,
514    i16,
515    u32,
516    i32,
517    u64,
518    i64,
519    usize,
520    isize
521);
522impl_unary_func_fixed_out_ty!(
523    FindFirstSet,
524    find_first_set,
525    u32,
526    Bitwise::FindFirstSet,
527    u8,
528    i8,
529    u16,
530    i16,
531    u32,
532    i32,
533    u64,
534    i64,
535    usize,
536    isize
537);
538impl_unary_func_fixed_out_ty!(
539    IsNan,
540    is_nan,
541    bool,
542    Comparison::IsNan,
543    f16,
544    bf16,
545    flex32,
546    tf32,
547    f32,
548    f64
549);
550impl_unary_func_fixed_out_ty!(
551    IsInf,
552    is_inf,
553    bool,
554    Comparison::IsInf,
555    f16,
556    bf16,
557    flex32,
558    tf32,
559    f32,
560    f64
561);
562
563pub trait FloatBits:
564    CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
565{
566    type Bits: CubePrimitive;
567
568    fn __expand_from_bits(scope: &mut Scope, bits: NativeExpand<Self::Bits>) -> NativeExpand<Self> {
569        Self::__expand_reinterpret(scope, bits)
570    }
571
572    fn __expand_to_bits(scope: &mut Scope, this: NativeExpand<Self>) -> NativeExpand<Self::Bits> {
573        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
574    }
575}
576
577pub trait FloatBitsExpand: Sized {
578    type Bits: CubePrimitive;
579
580    fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits>;
581}
582
583impl<F: FloatBits> FloatBitsExpand for NativeExpand<F> {
584    type Bits = F::Bits;
585
586    fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits> {
587        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
588    }
589}
590
591impl FloatBits for e2m1x2 {
592    type Bits = u8;
593}
594
595impl FloatBits for e5m2 {
596    type Bits = u8;
597}
598
599impl FloatBits for e4m3 {
600    type Bits = u8;
601}
602
603impl FloatBits for f16 {
604    type Bits = u16;
605}
606
607impl FloatBits for bf16 {
608    type Bits = u16;
609}
610
611impl FloatBits for f32 {
612    type Bits = u32;
613}
614
615impl FloatBits for f64 {
616    type Bits = u64;
617}