fast_neural_network/
activation.rs

//! Activation functions and their derivatives.
//!
//! Activation functions determine the output of a neuron; their derivatives are used to compute the back-propagation gradient.
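//!
//! A minimal usage sketch (the path assumes this file is compiled as the
//! `activation` module of the `fast_neural_network` crate, as the file layout
//! suggests):
//!
//! ```ignore
//! use fast_neural_network::activation::{der_sigm, sigm};
//!
//! // Sigmoid value and gradient at x = 0.0 (both exact in f64).
//! assert_eq!(sigm(&0.0), 0.5);
//! assert_eq!(der_sigm(&0.0), 0.25);
//! ```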

use serde::{Deserialize, Serialize};
6
/// The types of activation functions provided by this module.
///
/// The network automatically uses the matching derivative when propagating.
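///
/// A hypothetical dispatch sketch (not the network's actual internals; the
/// `eval` helper exists only for illustration) showing how a variant pairs
/// with its forward function and derivative:
///
/// ```ignore
/// use fast_neural_network::activation::*;
///
/// fn eval(kind: &ActivationType, x: &f64) -> (f64, f64) {
///     match kind {
///         ActivationType::Sigmoid => (sigm(x), der_sigm(x)),
///         ActivationType::Relu => (relu(x), der_relu(x)),
///         // ...the remaining variants dispatch the same way.
///         _ => unimplemented!(),
///     }
/// }
///
/// assert_eq!(eval(&ActivationType::Relu, &-2.0), (0.0, 0.0));
/// ```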
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ActivationType {
    Sigmoid,
    Tanh,
    ArcTanh,
    Relu,
    LeakyRelu,
    ELU,
    Swish,
    SoftMax,
    SoftPlus,
}

pub fn sigm(x: &f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

pub fn der_sigm(x: &f64) -> f64 {
    sigm(x) * (1.0 - sigm(x))
}

pub fn tanh(x: &f64) -> f64 {
    x.tanh()
}

pub fn der_tanh(x: &f64) -> f64 {
    1.0 - x.tanh().powi(2)
}

// Note: despite the `ArcTanh` name, this computes the arctangent (`atan`);
// the derivative below, 1 / (1 + x^2), matches the arctangent.
pub fn arc_tanh(x: &f64) -> f64 {
    x.atan()
}

pub fn der_arc_tanh(x: &f64) -> f64 {
    1.0 / (1.0 + x.powi(2))
}

pub fn relu(x: &f64) -> f64 {
    f64::max(0.0, *x)
}

pub fn der_relu(x: &f64) -> f64 {
    if *x <= 0.0 {
        0.0
    } else {
        1.0
    }
}

pub fn leaky_relu(x: &f64) -> f64 {
    if *x <= 0.0 {
        0.01 * x
    } else {
        *x
    }
}

pub fn der_leaky_relu(x: &f64) -> f64 {
    if *x <= 0.0 {
        0.01
    } else {
        1.0
    }
}

// pub fn parametric_relu(x: f64, alpha: f64) -> f64 {
//     if x <= 0.0 {
//         alpha * x
//     } else {
//         x
//     }
// }

// pub fn der_parametric_relu(x: f64, alpha: f64) -> f64 {
//     if x <= 0.0 {
//         alpha
//     } else {
//         1.0
//     }
// }

// ELU with a fixed alpha of 0.01 (the same coefficient as `leaky_relu`).
pub fn elu(x: &f64) -> f64 {
    if *x <= 0.0 {
        0.01 * (x.exp() - 1.0)
    } else {
        *x
    }
}

// For x <= 0, ELU'(x) = alpha * e^x, which equals elu(x) + alpha (alpha = 0.01).
pub fn der_elu(x: &f64) -> f64 {
    if *x <= 0.0 {
        elu(x) + 0.01
    } else {
        1.0
    }
}

pub fn swish(x: &f64) -> f64 {
    x * sigm(x)
}

// Uses the identity swish'(x) = swish(x) + sigm(x) * (1 - swish(x)).
pub fn der_swish(x: &f64) -> f64 {
    swish(x) + sigm(x) * (1.0 - swish(x))
}

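/// Softmax of a single logit `x` against the full logit vector `total`.
///
/// A small usage sketch (module path assumed as above; `ndarray`'s `array!`
/// macro is used only for the example):
///
/// ```ignore
/// use fast_neural_network::activation::softmax;
/// use ndarray::array;
///
/// let logits = array![1.0, 2.0, 3.0];
/// let sum: f64 = logits.iter().map(|x| softmax(x, &logits)).sum();
/// assert!((sum - 1.0).abs() < 1e-12);
/// ```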
pub fn softmax(x: &f64, total: &ndarray::Array1<f64>) -> f64 {
    x.exp() / total.iter().map(|x| x.exp()).sum::<f64>()
}

pub fn softmax_array<const SIZE: usize>(x: &f64, total: &[f64; SIZE]) -> f64 {
    x.exp() / total.iter().map(|x| x.exp()).sum::<f64>()
}

// Diagonal term of the softmax Jacobian: s_i * (1 - s_i).
pub fn der_softmax(x: &f64, total: &ndarray::Array1<f64>) -> f64 {
    softmax(x, total) * (1.0 - softmax(x, total))
}

// Softplus: ln(1 + e^x).
pub fn softplus(x: &f64) -> f64 {
    x.exp().ln_1p()
}

// Derivative of softplus: the sigmoid, 1 / (1 + e^(-x)).
pub fn der_softplus(x: &f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}
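
// A few sanity checks, offered as a sketch: they exercise the forward
// functions and their derivatives at points where the expected values are
// exact or easy to bound in f64.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn sigmoid_and_derivative_at_zero() {
        assert_eq!(sigm(&0.0), 0.5);
        assert_eq!(der_sigm(&0.0), 0.25);
    }

    #[test]
    fn relu_family_on_negative_input() {
        assert_eq!(relu(&-1.0), 0.0);
        assert_eq!(der_relu(&-1.0), 0.0);
        assert_eq!(leaky_relu(&-1.0), -0.01);
        assert_eq!(der_leaky_relu(&-1.0), 0.01);
    }

    #[test]
    fn softmax_over_a_vector_sums_to_one() {
        let logits = array![1.0, 2.0, 3.0];
        let total: f64 = logits.iter().map(|x| softmax(x, &logits)).sum();
        assert!((total - 1.0).abs() < 1e-12);
    }

    #[test]
    fn softplus_derivative_matches_central_difference() {
        // softplus(x) = ln(1 + e^x); its derivative is the sigmoid.
        let (x, h) = (0.7, 1e-6);
        let numeric = (softplus(&(x + h)) - softplus(&(x - h))) / (2.0 * h);
        assert!((numeric - der_softplus(&x)).abs() < 1e-6);
    }
}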