cppn/
activation_function.rs1use std::fmt::Debug;
2use std::f64::consts::PI;
3
/// A scalar activation function mapping an `f64` input to an `f64` output.
///
/// Implementors can evaluate themselves numerically, report a name, and render
/// an equivalent gnuplot expression for plotting.
pub trait ActivationFunction
where
    Self: Clone + Debug + Send + Sized + PartialEq + Eq,
{
    /// Returns a human-readable name for this function.
    fn name(&self) -> String;

    /// Evaluates the function at `x`.
    fn calculate(&self, x: f64) -> f64;

    /// Returns a gnuplot expression for this function, with the argument
    /// expression `x` substituted into it.
    fn formula_gnuplot(&self, x: String) -> String;
}
9
/// Returns `x` unchanged, asserting in debug builds that it lies in the closed
/// bipolar range `[-1.0, 1.0]`.
///
/// Bipolar activation outputs are routed through this so that a function whose
/// result escapes its promised range is caught during development. The check
/// is a `debug_assert!`, so it is compiled out of release builds.
///
/// # Panics
///
/// In debug builds, if `x` is below `-1.0`, above `1.0`, or NaN. The panic
/// message includes the offending value to make the faulty input traceable.
#[inline(always)]
fn bipolar_debug_check(x: f64) -> f64 {
    debug_assert!(
        x >= -1.0 && x <= 1.0,
        "value {} is outside the bipolar range [-1.0, 1.0]",
        x
    );
    x
}
15
/// Clamps `x` into the bipolar range `[-1.0, 1.0]`.
///
/// In-range values (including `-0.0`) are returned as-is. NaN also passes
/// through unchanged, because every comparison against NaN is false.
fn bipolar_clip(x: f64) -> f64 {
    match x {
        above if above > 1.0 => 1.0,
        below if below < -1.0 => -1.0,
        inside => inside,
    }
}
26
/// The built-in set of activation functions.
///
/// Keep the variant order stable: it fixes the discriminant values, and
/// index-based serializers may rely on it when the `serde` feature is enabled.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum GeometricActivationFunction {
    /// Identity: `x`.
    Linear,
    /// Identity clamped to `[-1.0, 1.0]`.
    LinearBipolarClipped,
    /// Identity clamped to `[0.0, 1.0]`.
    LinearClipped,
    /// Absolute value: `|x|`.
    Absolute,
    /// Gaussian bump `exp(-(2.5 * x)^2)`, peaking at `1.0` for `x = 0`.
    Gaussian,
    /// Gaussian rescaled to the bipolar range: `2 * exp(-(2.5 * x)^2) - 1`.
    BipolarGaussian,
    /// Logistic curve rescaled to the bipolar range: `2 / (1 + exp(-4.9 * x)) - 1`.
    BipolarSigmoid,
    /// Sine with period 1: `sin(2 * pi * x)`.
    Sine,
    /// Cosine-based activation.
    Cosine,
    /// Constant output `1.0`, independent of the input.
    Constant1,
}
41
42impl ActivationFunction for GeometricActivationFunction {
43 fn calculate(&self, x: f64) -> f64 {
44 match *self {
45 GeometricActivationFunction::Linear => x,
46 GeometricActivationFunction::LinearBipolarClipped => {
47 bipolar_debug_check(bipolar_clip(x))
48 }
49 GeometricActivationFunction::LinearClipped => x.min(1.0).max(0.0),
50 GeometricActivationFunction::Absolute => x.abs(),
51 GeometricActivationFunction::Gaussian => (-((x * 2.5).powi(2))).exp(),
52 GeometricActivationFunction::BipolarGaussian => {
53 bipolar_debug_check(2.0 * (-((x * 2.5).powi(2))).exp() - 1.0)
54 }
55 GeometricActivationFunction::BipolarSigmoid => {
56 bipolar_debug_check((2.0 / (1.0 + (-4.9 * x).exp())) - 1.0)
57 }
58 GeometricActivationFunction::Sine => bipolar_debug_check((2.0 * PI * x).sin()),
59 GeometricActivationFunction::Cosine => bipolar_debug_check(2.0 * PI * x.cos()),
60 GeometricActivationFunction::Constant1 => 1.0,
61 }
62 }
63
64 fn formula_gnuplot(&self, x: String) -> String {
65 match *self {
66 GeometricActivationFunction::Linear => format!("{}", x),
67 GeometricActivationFunction::LinearBipolarClipped => {
68 format!("max(-1.0, min(1.0, {}))", x)
69 }
70 GeometricActivationFunction::LinearClipped => format!("max(0.0, min(1.0, {}))", x),
71 GeometricActivationFunction::Absolute => format!("abs({})", x),
72 GeometricActivationFunction::Gaussian => format!("(exp(-((({}) * 2.5)**2.0))", x),
73 GeometricActivationFunction::BipolarGaussian => {
74 format!("2.0 * exp(-((({}) * 2.5)**2.0)) - 1.0", x)
75 }
76 GeometricActivationFunction::BipolarSigmoid => {
77 format!("2.0 / (1.0 + exp(-4.9 * ({}))) - 1.0", x)
78 }
79 GeometricActivationFunction::Sine => format!("sin({})", x),
80 GeometricActivationFunction::Cosine => format!("cos({})", x),
81 GeometricActivationFunction::Constant1 => format!("1.0"),
82 }
83 }
84
85 fn name(&self) -> String {
86 match *self {
87 GeometricActivationFunction::Linear => "Linear",
88 GeometricActivationFunction::LinearBipolarClipped => "LinearBipolarClipped",
89 GeometricActivationFunction::LinearClipped => "LinearClipped",
90 GeometricActivationFunction::Absolute => "Absolute",
91 GeometricActivationFunction::Gaussian => "Gaussian",
92 GeometricActivationFunction::BipolarGaussian => "BipolarGaussian",
93 GeometricActivationFunction::BipolarSigmoid => "BipolarSigmoid",
94 GeometricActivationFunction::Sine => "Sine",
95 GeometricActivationFunction::Cosine => "Consine",
96 GeometricActivationFunction::Constant1 => "1.0",
97 }.to_string()
98 }
99}
100
// `LinearBipolarClipped` is the identity on `[-1.0, 1.0]` and saturates at
// `-1.0` / `1.0` outside it.
#[test]
fn test_bipolar_linear_clipped() {
    let clipped = GeometricActivationFunction::LinearBipolarClipped;
    let expectations: [(f64, f64); 7] = [
        (0.0, 0.0),
        (1.0, 1.0),
        (-1.0, -1.0),
        (0.5, 0.5),
        (-0.5, -0.5),
        (1.1, 1.0),
        (-1.1, -1.0),
    ];
    for &(input, expected) in expectations.iter() {
        let actual = clipped.calculate(input);
        assert_eq!(expected, actual, "calculate({}) returned {}", input, actual);
    }
}
132
// `LinearClipped` is the identity on `[0.0, 1.0]`; negative inputs clamp to
// `0.0` and inputs above one clamp to `1.0`.
#[test]
fn test_linear_clipped() {
    let clipped = GeometricActivationFunction::LinearClipped;
    let expectations: [(f64, f64); 7] = [
        (0.0, 0.0),
        (1.0, 1.0),
        (-1.0, 0.0),
        (0.5, 0.5),
        (-0.5, 0.0),
        (1.1, 1.0),
        (-1.1, 0.0),
    ];
    for &(input, expected) in expectations.iter() {
        let actual = clipped.calculate(input);
        assert_eq!(expected, actual, "calculate({}) returned {}", input, actual);
    }
}
164
// `Constant1` ignores its input and always yields `1.0`.
#[test]
fn test_constant1() {
    let constant = GeometricActivationFunction::Constant1;
    [0.0, -1.0, 1.0]
        .iter()
        .for_each(|&input| assert_eq!(1.0, constant.calculate(input)));
}