cppn/
activation_function.rs

1use std::fmt::Debug;
2use std::f64::consts::PI;
3
/// A scalar activation function.
///
/// Implementors provide three views of the same function: a numeric
/// evaluation (`calculate`), a name (`name`), and a gnuplot expression
/// (`formula_gnuplot`) in which `x` is the expression text substituted as
/// the function's argument.
pub trait ActivationFunction
where
    Self: Clone + Debug + Send + Sized + PartialEq + Eq,
{
    /// Evaluates the function at `x`.
    fn calculate(&self, x: f64) -> f64;

    /// Returns the function's name as an owned string.
    fn name(&self) -> String;

    /// Renders the function as a gnuplot expression with `x` as its argument.
    fn formula_gnuplot(&self, x: String) -> String;
}
9
/// Returns `x` unchanged, asserting in debug builds that it lies in [-1, 1].
///
/// The check is a `debug_assert!`, so it compiles away when debug
/// assertions are disabled. NaN fails the check because the range test
/// is false for NaN.
#[inline]
fn bipolar_debug_check(x: f64) -> f64 {
    debug_assert!((-1.0..=1.0).contains(&x));
    x
}
15
/// Saturates `x` to the closed interval [-1, 1].
///
/// Values already inside the interval are returned as-is. NaN is also
/// returned unchanged, since neither bound comparison holds for it.
fn bipolar_clip(x: f64) -> f64 {
    match x {
        above if above > 1.0 => 1.0,
        below if below < -1.0 => -1.0,
        inside => inside,
    }
}
26
/// The set of built-in activation functions.
///
/// The exact formula for each variant is defined by its
/// `ActivationFunction::calculate` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum GeometricActivationFunction {
    /// Identity: `f(x) = x`.
    Linear,
    /// Identity saturated to [-1, 1].
    LinearBipolarClipped,
    /// Identity saturated to [0, 1].
    LinearClipped,
    /// Absolute value `|x|`.
    Absolute,
    /// Gaussian bump `exp(-(2.5x)^2)`, peaking at 1 for `x = 0`.
    Gaussian,
    /// Gaussian bump rescaled as `2 * exp(-(2.5x)^2) - 1`.
    BipolarGaussian,
    /// Sigmoid rescaled as `2 / (1 + exp(-4.9x)) - 1`.
    BipolarSigmoid,
    /// Sine wave.
    Sine,
    /// Cosine wave.
    Cosine,
    /// Ignores its input and always yields 1.0.
    Constant1,
}
41
42impl ActivationFunction for GeometricActivationFunction {
43    fn calculate(&self, x: f64) -> f64 {
44        match *self {
45            GeometricActivationFunction::Linear => x,
46            GeometricActivationFunction::LinearBipolarClipped => {
47                bipolar_debug_check(bipolar_clip(x))
48            }
49            GeometricActivationFunction::LinearClipped => x.min(1.0).max(0.0),
50            GeometricActivationFunction::Absolute => x.abs(),
51            GeometricActivationFunction::Gaussian => (-((x * 2.5).powi(2))).exp(),
52            GeometricActivationFunction::BipolarGaussian => {
53                bipolar_debug_check(2.0 * (-((x * 2.5).powi(2))).exp() - 1.0)
54            }
55            GeometricActivationFunction::BipolarSigmoid => {
56                bipolar_debug_check((2.0 / (1.0 + (-4.9 * x).exp())) - 1.0)
57            }
58            GeometricActivationFunction::Sine => bipolar_debug_check((2.0 * PI * x).sin()),
59            GeometricActivationFunction::Cosine => bipolar_debug_check(2.0 * PI * x.cos()),
60            GeometricActivationFunction::Constant1 => 1.0,
61        }
62    }
63
64    fn formula_gnuplot(&self, x: String) -> String {
65        match *self {
66            GeometricActivationFunction::Linear => format!("{}", x),
67            GeometricActivationFunction::LinearBipolarClipped => {
68                format!("max(-1.0, min(1.0, {}))", x)
69            }
70            GeometricActivationFunction::LinearClipped => format!("max(0.0, min(1.0, {}))", x),
71            GeometricActivationFunction::Absolute => format!("abs({})", x),
72            GeometricActivationFunction::Gaussian => format!("(exp(-((({}) * 2.5)**2.0))", x),
73            GeometricActivationFunction::BipolarGaussian => {
74                format!("2.0 * exp(-((({}) * 2.5)**2.0)) - 1.0", x)
75            }
76            GeometricActivationFunction::BipolarSigmoid => {
77                format!("2.0 / (1.0 + exp(-4.9 * ({}))) - 1.0", x)
78            }
79            GeometricActivationFunction::Sine => format!("sin({})", x),
80            GeometricActivationFunction::Cosine => format!("cos({})", x),
81            GeometricActivationFunction::Constant1 => format!("1.0"),
82        }
83    }
84
85    fn name(&self) -> String {
86        match *self {
87            GeometricActivationFunction::Linear => "Linear",
88            GeometricActivationFunction::LinearBipolarClipped => "LinearBipolarClipped",
89            GeometricActivationFunction::LinearClipped => "LinearClipped",
90            GeometricActivationFunction::Absolute => "Absolute",
91            GeometricActivationFunction::Gaussian => "Gaussian",
92            GeometricActivationFunction::BipolarGaussian => "BipolarGaussian",
93            GeometricActivationFunction::BipolarSigmoid => "BipolarSigmoid",
94            GeometricActivationFunction::Sine => "Sine",
95            GeometricActivationFunction::Cosine => "Consine",
96            GeometricActivationFunction::Constant1 => "1.0",
97        }.to_string()
98    }
99}
100
#[test]
fn test_bipolar_linear_clipped() {
    // (input, expected) pairs: in-range values pass through, out-of-range
    // values saturate at the bipolar bounds.
    let cases = [
        (0.0, 0.0),
        (1.0, 1.0),
        (-1.0, -1.0),
        (0.5, 0.5),
        (-0.5, -0.5),
        (1.1, 1.0),
        (-1.1, -1.0),
    ];
    for &(input, expected) in cases.iter() {
        let actual = GeometricActivationFunction::LinearBipolarClipped.calculate(input);
        assert_eq!(expected, actual);
    }
}
132
#[test]
fn test_linear_clipped() {
    // (input, expected) pairs: negatives clamp to 0.0, values above 1.0
    // clamp to 1.0, everything in [0, 1] passes through.
    let cases = [
        (0.0, 0.0),
        (1.0, 1.0),
        (-1.0, 0.0),
        (0.5, 0.5),
        (-0.5, 0.0),
        (1.1, 1.0),
        (-1.1, 0.0),
    ];
    for &(input, expected) in cases.iter() {
        let actual = GeometricActivationFunction::LinearClipped.calculate(input);
        assert_eq!(expected, actual);
    }
}
164
#[test]
fn test_constant1() {
    // The constant function must ignore its input entirely.
    for &input in [0.0, -1.0, 1.0].iter() {
        assert_eq!(1.0, GeometricActivationFunction::Constant1.calculate(input));
    }
}