1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use crate::{
    frontend::{CubeContext, UInt, BF16, F16, F32, F64, I32, I64},
    ir::Operator,
    prelude::{CubePrimitive, ExpandElementTyped},
    unexpanded,
};

use super::base::unary_expand;

/// Expansion support for the boolean `not` operation.
pub mod not {
    use super::*;

    /// Expands `not` applied to a boolean expand element.
    ///
    /// `x` is converted into the untyped element expected by `unary_expand`.
    /// It is lowered with `Operator::Not` on `context`, and the resulting element
    /// is converted back into an `ExpandElementTyped<bool>`.
    pub fn expand(
        context: &mut CubeContext,
        x: ExpandElementTyped<bool>,
    ) -> ExpandElementTyped<bool> {
        let negated = unary_expand(context, x.into(), Operator::Not);
        negated.into()
    }
}

/// Defines a unary-operation trait and implements it for a list of types.
///
/// Invocation shape: `impl_unary_func!(Trait, method, expand_method, operator, Type, ...)`.
/// The generated public trait has two default methods:
/// - `method(x)`, whose body is `unexpanded!()`;
/// - `expand_method(context, x)`, which converts `x`, passes it together with
///   `operator` to `unary_expand`, and converts the result into
///   `ExpandElementTyped<Self>`.
///
/// Every listed type receives an empty `impl`, so it uses both defaults unchanged.
macro_rules! impl_unary_func {
    ($Trait:ident, $func:ident, $func_expand:ident, $op:expr, $($elem:ty),*) => {
        #[doc = concat!(
            "Unary `", stringify!($func), "` operation, expanded through `unary_expand`."
        )]
        pub trait $Trait: CubePrimitive + Sized {
            #[doc = concat!(
                "Placeholder for `", stringify!($func), "`; its body is `unexpanded!()`."
            )]
            // `x` only fixes the signature; the body never reads it.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            #[doc = concat!(
                "Expansion of `", stringify!($func), "`: emits its operator via `unary_expand`."
            )]
            fn $func_expand(
                context: &mut CubeContext,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                let expanded = unary_expand(context, x.into(), $op);
                expanded.into()
            }
        }

        $(
            impl $Trait for $elem {}
        )*
    };
}

// Each `impl_unary_func!` invocation below defines a public trait (e.g. `Abs`) with a
// placeholder method (`abs`, whose body is `unexpanded!()`) and an expansion method
// (`__expand_abs`) that hands the listed `Operator` variant to `unary_expand`. Every
// listed type then gets an empty impl and therefore uses both default methods.

// `Abs` is the only trait in this list that also covers the integer types I32, I64 and UInt.
impl_unary_func!(
    Abs,
    abs,
    __expand_abs,
    Operator::Abs,
    F16,
    BF16,
    F32,
    F64,
    I32,
    I64,
    UInt
);
// The remaining traits are implemented only for the float types F16, BF16, F32 and F64.
impl_unary_func!(Exp, exp, __expand_exp, Operator::Exp, F16, BF16, F32, F64);
impl_unary_func!(Log, log, __expand_log, Operator::Log, F16, BF16, F32, F64);
impl_unary_func!(
    Log1p,
    log1p,
    __expand_log1p,
    Operator::Log1p,
    F16,
    BF16,
    F32,
    F64
);
impl_unary_func!(Cos, cos, __expand_cos, Operator::Cos, F16, BF16, F32, F64);
impl_unary_func!(Sin, sin, __expand_sin, Operator::Sin, F16, BF16, F32, F64);
impl_unary_func!(
    Tanh,
    tanh,
    __expand_tanh,
    Operator::Tanh,
    F16,
    BF16,
    F32,
    F64
);
impl_unary_func!(
    Sqrt,
    sqrt,
    __expand_sqrt,
    Operator::Sqrt,
    F16,
    BF16,
    F32,
    F64
);
impl_unary_func!(
    Floor,
    floor,
    __expand_floor,
    Operator::Floor,
    F16,
    BF16,
    F32,
    F64
);
impl_unary_func!(
    Ceil,
    ceil,
    __expand_ceil,
    Operator::Ceil,
    F16,
    BF16,
    F32,
    F64
);
impl_unary_func!(Erf, erf, __expand_erf, Operator::Erf, F16, BF16, F32, F64);
impl_unary_func!(
    Recip,
    recip,
    __expand_recip,
    Operator::Recip,
    F16,
    BF16,
    F32,
    F64
);