1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/*
    Appellation: kinds <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum::{Display, EnumCount, EnumIs, EnumIter, VariantNames};

/// A named unary operation.
///
/// Every variant's textual form is its lowercase name (e.g. `UnaryOp::Abs` <-> `"abs"`).
/// strum's `Display` and `VariantNames` use `serialize_all = "lowercase"`. With the
/// `serde` feature enabled, (de)serialization uses `rename_all = "lowercase"`, so the
/// two representations agree.
// Serde's default (externally tagged) representation is used on purpose. For an
// `untagged` enum, serde writes every unit variant as `null` and ignores
// `rename_all`/`alias`. Variants could then not be distinguished or round-tripped, and
// string inputs such as "cos" would fail to deserialize.
#[cfg_attr(
    feature = "serde",
    derive(Deserialize, Serialize),
    serde(rename_all = "lowercase")
)]
#[derive(
    Clone,
    Copy,
    Debug,
    Display,
    EnumCount,
    EnumIs,
    EnumIter,
    Eq,
    Hash,
    Ord,
    PartialEq,
    PartialOrd,
    VariantNames,
)]
#[repr(u8)]
#[strum(serialize_all = "lowercase")]
pub enum UnaryOp {
    Abs,
    Cos,
    Cosh,
    Exp,
    Floor,
    /// Reciprocal. With the `serde` feature, `"inverse"`, `"recip"` and `"reciprocal"`
    /// are also accepted when deserializing.
    #[cfg_attr(
        feature = "serde",
        serde(alias = "inverse", alias = "recip", alias = "reciprocal")
    )]
    Inv,
    Ln,
    Neg,
    Not,
    Sin,
    Sinh,
    /// Square root. With the `serde` feature, `"square_root"` is also accepted when
    /// deserializing.
    #[cfg_attr(feature = "serde", serde(alias = "square_root"))]
    Sqrt,
    Square,
    Tan,
    Tanh,
}

impl UnaryOp {
    pub fn differentiable(&self) -> bool {
        match self {
            UnaryOp::Floor | UnaryOp::Inv => false,
            _ => true,
        }
    }

    unit_enum_constructor!(
        (Abs, abs),
        (Cos, cos),
        (Cosh, cosh),
        (Exp, exp),
        (Floor, floor),
        (Inv, inv),
        (Ln, ln),
        (Neg, neg),
        (Not, not),
        (Sin, sin),
        (Sinh, sinh),
        (Sqrt, sqrt),
        (Square, square),
        (Tan, tan),
        (Tanh, tanh)
    );
}