//! Activation functions and their derivatives.
//!
//! An activation function determines the output of a neuron; its derivative is used to compute the back-propagation gradient.

use serde::{Deserialize, Serialize};

/// The types of activation functions contained in this module.
/// >   The network automatically uses the matching derivative when back-propagating (see the illustrative dispatch sketch below the enum).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ActivationType {
    Sigmoid,
    Tanh,
    ArcTanh,
    Relu,
    LeakyRelu,
    ELU,
    Swish,
    SoftMax,
    SoftPlus,
}
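
// Illustrative sketch of how a network could map an `ActivationType` to the matching
// function/derivative pair. The `apply`/`derivative` method names are assumptions made
// for this example, not an established API of this crate; `SoftMax` is left out because
// it needs the whole layer output rather than a single value.
impl ActivationType {
    /// Apply the activation to a single pre-activation value.
    pub fn apply(&self, x: f64) -> f64 {
        match self {
            ActivationType::Sigmoid => sigm(x),
            ActivationType::Tanh => tanh(x),
            ActivationType::ArcTanh => arc_tanh(x),
            ActivationType::Relu => relu(x),
            ActivationType::LeakyRelu => leaky_relu(x),
            ActivationType::ELU => elu(x),
            ActivationType::Swish => swish(x),
            ActivationType::SoftPlus => softplus(x),
            ActivationType::SoftMax => unimplemented!("softmax needs the full layer output"),
        }
    }

    /// Evaluate the matching derivative at a single pre-activation value.
    pub fn derivative(&self, x: f64) -> f64 {
        match self {
            ActivationType::Sigmoid => der_sigm(x),
            ActivationType::Tanh => der_tanh(x),
            ActivationType::ArcTanh => der_arc_tanh(x),
            ActivationType::Relu => der_relu(x),
            ActivationType::LeakyRelu => der_leaky_relu(x),
            ActivationType::ELU => der_elu(x),
            ActivationType::Swish => der_swish(x),
            ActivationType::SoftPlus => der_softplus(x),
            ActivationType::SoftMax => unimplemented!("softmax needs the full layer output"),
        }
    }
}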

pub fn sigm(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}
pub fn der_sigm(x: f64) -> f64 {
    let s = sigm(x);
    s * (1.0 - s)
}

pub fn tanh(x: f64) -> f64 {
    x.tanh()
}

pub fn der_tanh(x: f64) -> f64 {
    1.0 - x.tanh().powi(2)
}

/// Note: despite the name, this computes the inverse tangent (`atan`), not the
/// inverse hyperbolic tangent; `der_arc_tanh` below is the matching derivative.
pub fn arc_tanh(x: f64) -> f64 {
    x.atan()
}

/// Derivative of `arc_tanh` (i.e. of `atan`): 1 / (1 + x²).
pub fn der_arc_tanh(x: f64) -> f64 {
    1.0 / (1.0 + x.powi(2))
}

pub fn relu(x: f64) -> f64 {
    f64::max(0.0, x)
}

pub fn der_relu(x: f64) -> f64 {
    if x <= 0.0 {
        0.0
    } else {
        1.0
    }
}

pub fn leaky_relu(x: f64) -> f64 {
    if x <= 0.0 {
        0.01 * x
    } else {
        x
    }
}

pub fn der_leaky_relu(x: f64) -> f64 {
    if x <= 0.0 {
        0.01
    } else {
        1.0
    }
}

// pub fn parametric_relu(x: f64, alpha: f64) -> f64 {
//     if x <= 0.0 {
//         alpha * x
//     } else {
//         x
//     }
// }

// pub fn der_parametric_relu(x: f64, alpha: f64) -> f64 {
//     if x <= 0.0 {
//         alpha
//     } else {
//         1.0
//     }
// }

/// ELU with α = 0.01: α * (eˣ - 1) for x ≤ 0, and x otherwise.
pub fn elu(x: f64) -> f64 {
    if x <= 0.0 {
        0.01 * (x.exp() - 1.0)
    } else {
        x
    }
}

/// Derivative of `elu`: for x ≤ 0 it is α * eˣ, which equals `elu(x) + α`.
pub fn der_elu(x: f64) -> f64 {
    if x <= 0.0 {
        elu(x) + 0.01
    } else {
        1.0
    }
}

pub fn swish(x: f64) -> f64 {
    x * sigm(x)
}

/// Derivative of `swish`: d/dx [x·σ(x)] = σ(x) + x·σ(x)·(1 - σ(x)) = swish(x) + σ(x)·(1 - swish(x)).
pub fn der_swish(x: f64) -> f64 {
    swish(x) + sigm(x) * (1.0 - swish(x))
}

/// Softmax of the logit `x` given the full vector of layer logits `total` (which should contain `x`).
pub fn softmax(x: f64, total: &ndarray::Array1<f64>) -> f64 {
    x.exp() / total.iter().map(|x| x.exp()).sum::<f64>()
}

/// Fixed-size-array variant of `softmax`.
pub fn softmax_array<const SIZE: usize>(x: f64, total: &[f64; SIZE]) -> f64 {
    x.exp() / total.iter().map(|x| x.exp()).sum::<f64>()
}

/// Derivative of `softmax` with respect to its own logit, i.e. the diagonal of the softmax Jacobian.
pub fn der_softmax(x: f64, total: &ndarray::Array1<f64>) -> f64 {
    let s = softmax(x, total);
    s * (1.0 - s)
}

pub fn softplus(x: f64) -> f64 {
    // softplus(x) = ln(1 + eˣ)
    x.exp().ln_1p()
}

/// Derivative of `softplus`: the sigmoid, 1 / (1 + e⁻ˣ).
pub fn der_softplus(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}
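
// Illustrative sanity checks (a sketch, not an exhaustive test suite): each smooth
// analytic derivative above is compared against a central finite-difference estimate,
// and the softmax outputs of a layer are checked to sum to one. The sample points,
// step size, and tolerances are assumptions chosen for demonstration.
#[cfg(test)]
mod derivative_checks {
    use super::*;

    // Central difference: (f(x + h) - f(x - h)) / (2h).
    fn numeric_derivative(f: fn(f64) -> f64, x: f64, h: f64) -> f64 {
        (f(x + h) - f(x - h)) / (2.0 * h)
    }

    #[test]
    fn analytic_matches_numeric() {
        let cases: &[(fn(f64) -> f64, fn(f64) -> f64)] = &[
            (sigm, der_sigm),
            (tanh, der_tanh),
            (arc_tanh, der_arc_tanh),
            (swish, der_swish),
            (softplus, der_softplus),
        ];
        for &(f, df) in cases {
            for &x in &[-2.0, -0.5, 0.5, 2.0] {
                let numeric = numeric_derivative(f, x, 1e-6);
                assert!((df(x) - numeric).abs() < 1e-4, "derivative mismatch at x = {x}");
            }
        }
    }

    #[test]
    fn softmax_sums_to_one() {
        let logits = ndarray::array![1.0, 2.0, 3.0];
        let sum: f64 = logits.iter().map(|&v| softmax(v, &logits)).sum();
        assert!((sum - 1.0).abs() < 1e-12);
    }
}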