1use serde::{Deserialize, Serialize};
12
/// The set of supported activation functions.
///
/// Each variant has a forward evaluation ([`Activation::apply`]) and a
/// matching derivative ([`Activation::derivative`]). The enum derives
/// `Serialize`/`Deserialize` so a chosen activation can be persisted
/// and restored with serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Activation {
    /// Hyperbolic tangent: `tanh(x)`; output in (-1, 1).
    Tanh,
    /// Rectified linear unit: `max(0, x)`.
    Relu,
    /// Logistic sigmoid: `1 / (1 + e^-x)`; output in (0, 1).
    Sigmoid,
    /// Exponential linear unit: `x` for x > 0, `e^x - 1` otherwise.
    Elu,
    /// Softsign: `x / (1 + |x|)`; output in (-1, 1).
    Softsign,
    /// Identity: returns `x` unchanged.
    Linear,
}
43
44impl Activation {
45 pub fn apply(&self, x: f64) -> f64 {
55 match self {
56 Activation::Tanh => x.tanh(),
57 Activation::Relu => x.max(0.0),
58 Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
59 Activation::Elu => {
60 if x > 0.0 {
61 x
62 } else {
63 x.exp() - 1.0
64 }
65 }
66 Activation::Softsign => x / (1.0 + x.abs()),
67 Activation::Linear => x,
68 }
69 }
70
71 pub fn derivative(&self, fx: f64) -> f64 {
81 match self {
82 Activation::Tanh => 1.0 - fx * fx,
83 Activation::Relu => {
84 if fx > 0.0 {
85 1.0
86 } else {
87 0.0
88 }
89 }
90 Activation::Sigmoid => fx * (1.0 - fx),
91 Activation::Elu => {
92 if fx > 0.0 {
93 1.0
94 } else {
95 fx + 1.0
96 }
97 }
98 Activation::Softsign => {
99 let t = 1.0 - fx.abs();
101 t * t
102 }
103 Activation::Linear => 1.0,
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for exhaustive sweeps.
    const ALL: [Activation; 6] = [
        Activation::Tanh,
        Activation::Relu,
        Activation::Sigmoid,
        Activation::Elu,
        Activation::Softsign,
        Activation::Linear,
    ];

    /// True when `a` and `b` differ by less than `eps`.
    fn approx(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn test_tanh_apply_zero() {
        assert_eq!(Activation::Tanh.apply(0.0), 0.0);
    }

    #[test]
    fn test_tanh_apply_known() {
        assert!(approx(Activation::Tanh.apply(1.0), 1.0_f64.tanh(), 1e-12));
    }

    #[test]
    fn test_tanh_apply_negative() {
        assert!(approx(Activation::Tanh.apply(-2.0), (-2.0_f64).tanh(), 1e-12));
    }

    #[test]
    fn test_relu_apply_negative_is_zero() {
        assert_eq!(Activation::Relu.apply(-5.0), 0.0);
    }

    #[test]
    fn test_relu_apply_zero_is_zero() {
        assert_eq!(Activation::Relu.apply(0.0), 0.0);
    }

    #[test]
    fn test_relu_apply_positive_is_identity() {
        assert_eq!(Activation::Relu.apply(3.7), 3.7);
    }

    #[test]
    fn test_sigmoid_apply_zero_is_half() {
        assert!(approx(Activation::Sigmoid.apply(0.0), 0.5, 1e-12));
    }

    #[test]
    fn test_sigmoid_apply_large_stays_below_one() {
        let y = Activation::Sigmoid.apply(30.0);
        assert!(y < 1.0 && y > 0.99);
    }

    #[test]
    fn test_sigmoid_apply_very_negative_stays_above_zero() {
        assert!(Activation::Sigmoid.apply(-100.0) > 0.0);
    }

    #[test]
    fn test_elu_apply_positive_is_identity() {
        assert_eq!(Activation::Elu.apply(3.0), 3.0);
    }

    #[test]
    fn test_elu_apply_zero_is_zero() {
        assert!(approx(Activation::Elu.apply(0.0), 0.0, 1e-12));
    }

    #[test]
    fn test_elu_apply_negative_is_exp_minus_one() {
        assert!(approx(Activation::Elu.apply(-1.0), (-1.0_f64).exp() - 1.0, 1e-12));
    }

    #[test]
    fn test_elu_apply_large_negative_approaches_minus_one() {
        assert!(approx(Activation::Elu.apply(-100.0), -1.0, 1e-10));
    }

    #[test]
    fn test_softsign_apply_positive() {
        assert!(approx(Activation::Softsign.apply(2.0), 2.0 / 3.0, 1e-12));
    }

    #[test]
    fn test_softsign_apply_zero() {
        assert!(approx(Activation::Softsign.apply(0.0), 0.0, 1e-12));
    }

    #[test]
    fn test_softsign_apply_negative() {
        assert!(approx(Activation::Softsign.apply(-3.0), -0.75, 1e-12));
    }

    #[test]
    fn test_softsign_apply_bounded() {
        // Softsign asymptotically approaches ±1 but never reaches it.
        assert!(Activation::Softsign.apply(100.0) < 1.0);
        assert!(Activation::Softsign.apply(-100.0) > -1.0);
    }

    #[test]
    fn test_linear_apply_is_identity() {
        assert_eq!(Activation::Linear.apply(42.0), 42.0);
    }

    #[test]
    fn test_tanh_derivative_formula() {
        // 1 - 0.5^2 = 0.75
        assert!(approx(Activation::Tanh.derivative(0.5), 0.75, 1e-12));
    }

    #[test]
    fn test_tanh_derivative_at_zero_is_one() {
        assert!(approx(Activation::Tanh.derivative(0.0), 1.0, 1e-12));
    }

    #[test]
    fn test_relu_derivative_zero_output_is_zero() {
        assert_eq!(Activation::Relu.derivative(0.0), 0.0);
    }

    #[test]
    fn test_relu_derivative_positive_output_is_one() {
        assert_eq!(Activation::Relu.derivative(2.0), 1.0);
    }

    #[test]
    fn test_sigmoid_derivative_formula() {
        // 0.7 * (1 - 0.7) = 0.21
        assert!(approx(Activation::Sigmoid.derivative(0.7), 0.21, 1e-12));
    }

    #[test]
    fn test_sigmoid_derivative_at_half() {
        assert!(approx(Activation::Sigmoid.derivative(0.5), 0.25, 1e-12));
    }

    #[test]
    fn test_elu_derivative_positive_is_one() {
        assert_eq!(Activation::Elu.derivative(2.0), 1.0);
    }

    #[test]
    fn test_elu_derivative_negative_is_fx_plus_one() {
        assert!(approx(Activation::Elu.derivative(-0.6), 0.4, 1e-12));
    }

    #[test]
    fn test_elu_derivative_at_minus_one_is_zero() {
        assert!(approx(Activation::Elu.derivative(-1.0), 0.0, 1e-12));
    }

    #[test]
    fn test_softsign_derivative_at_zero() {
        assert!(approx(Activation::Softsign.derivative(0.0), 1.0, 1e-12));
    }

    #[test]
    fn test_softsign_derivative_positive() {
        // (1 - 0.5)^2 = 0.25
        assert!(approx(Activation::Softsign.derivative(0.5), 0.25, 1e-12));
    }

    #[test]
    fn test_softsign_derivative_negative() {
        // Symmetric in |fx|: (1 - 0.5)^2 = 0.25
        assert!(approx(Activation::Softsign.derivative(-0.5), 0.25, 1e-12));
    }

    #[test]
    fn test_softsign_derivative_high_saturation() {
        // (1 - 0.9)^2 = 0.01
        assert!(approx(Activation::Softsign.derivative(0.9), 0.01, 1e-12));
    }

    #[test]
    fn test_softsign_derivative_always_positive() {
        let samples = [-0.9, -0.5, 0.0, 0.5, 0.9];
        assert!(samples
            .iter()
            .all(|&fx| Activation::Softsign.derivative(fx) > 0.0));
    }

    #[test]
    fn test_linear_derivative_always_one() {
        for &fx in &[999.0, -42.0, 0.0] {
            assert_eq!(Activation::Linear.derivative(fx), 1.0);
        }
    }

    #[test]
    fn test_all_activations_produce_finite_output_for_extreme_inputs() {
        for act in &ALL {
            for &x in &[-100.0, 100.0] {
                let y = act.apply(x);
                assert!(y.is_finite(), "{:?}.apply({}) was not finite", act, x);
            }
        }
    }

    #[test]
    fn test_all_derivatives_finite_for_typical_post_activation_values() {
        let cases: [(Activation, f64); 6] = [
            (Activation::Tanh, 0.5),
            (Activation::Relu, 1.0),
            (Activation::Sigmoid, 0.5),
            (Activation::Elu, -0.5),
            (Activation::Softsign, 0.5),
            (Activation::Linear, 0.0),
        ];
        for &(act, fx) in &cases {
            let d = act.derivative(fx);
            assert!(d.is_finite(), "{:?}.derivative({}) was not finite", act, fx);
        }
    }

    #[test]
    fn test_serde_roundtrip_all_variants() {
        for act in &ALL {
            let json = serde_json::to_string(act).unwrap();
            let back: Activation = serde_json::from_str(&json).unwrap();
            assert_eq!(back, *act);
        }
    }

    #[test]
    fn test_serde_unknown_variant_returns_error() {
        assert!(serde_json::from_str::<Activation>("\"Softmax\"").is_err());
    }
}