// pc_rl_core/activation.rs
//
// Author: Julian Bolivar
// Version: 1.0.0
// Date: 2026-03-25

//! Activation functions for neural network layers.
//!
//! Provides an enum of common activation functions with element-wise
//! `apply` and `derivative` operations. Used by layers, the PC actor,
//! and the MLP critic.

use serde::{Deserialize, Serialize};

13/// Supported activation function variants.
14///
15/// Each variant implements `apply(x)` for the forward pass and
16/// `derivative(fx)` which takes the **post-activation** value.
17///
18/// # Examples
19///
20/// ```
21/// use pc_rl_core::activation::Activation;
22///
23/// let act = Activation::Tanh;
24/// let y = act.apply(0.5);
25/// let dy = act.derivative(y);
26/// ```
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum Activation {
29    /// Hyperbolic tangent: output in (-1, 1).
30    Tanh,
31    /// Rectified linear unit: max(0, x).
32    Relu,
33    /// Logistic sigmoid: output in (0, 1).
34    Sigmoid,
35    /// Exponential linear unit: smooth in negatives, avoids dying neurons.
36    Elu,
37    /// Softsign: bounded in (-1, 1) with slower saturation than tanh.
38    /// Preserves more gradient in high-saturation regions.
39    Softsign,
40    /// Identity function: output equals input.
41    Linear,
42}
43
44impl Activation {
45    /// Applies the activation function to a single scalar value.
46    ///
47    /// # Parameters
48    ///
49    /// * `x` - Pre-activation input value.
50    ///
51    /// # Returns
52    ///
53    /// The activated output value.
54    pub fn apply(&self, x: f64) -> f64 {
55        match self {
56            Activation::Tanh => x.tanh(),
57            Activation::Relu => x.max(0.0),
58            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
59            Activation::Elu => {
60                if x > 0.0 {
61                    x
62                } else {
63                    x.exp() - 1.0
64                }
65            }
66            Activation::Softsign => x / (1.0 + x.abs()),
67            Activation::Linear => x,
68        }
69    }
70
71    /// Computes the derivative given the post-activation value.
72    ///
73    /// # Parameters
74    ///
75    /// * `fx` - The post-activation value (output of `apply`).
76    ///
77    /// # Returns
78    ///
79    /// The derivative at `fx`.
80    pub fn derivative(&self, fx: f64) -> f64 {
81        match self {
82            Activation::Tanh => 1.0 - fx * fx,
83            Activation::Relu => {
84                if fx > 0.0 {
85                    1.0
86                } else {
87                    0.0
88                }
89            }
90            Activation::Sigmoid => fx * (1.0 - fx),
91            Activation::Elu => {
92                if fx > 0.0 {
93                    1.0
94                } else {
95                    fx + 1.0
96                }
97            }
98            Activation::Softsign => {
99                // fx = x/(1+|x|), so (1-|fx|) = 1/(1+|x|), derivative = (1-|fx|)^2
100                let t = 1.0 - fx.abs();
101                t * t
102            }
103            Activation::Linear => 1.0,
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when `a` and `b` differ by less than 1e-12.
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    /// Every activation variant, for exhaustive sweeps.
    const ALL_VARIANTS: [Activation; 6] = [
        Activation::Tanh,
        Activation::Relu,
        Activation::Sigmoid,
        Activation::Elu,
        Activation::Softsign,
        Activation::Linear,
    ];

    // ── apply tests ──────────────────────────────────────────────

    #[test]
    fn test_tanh_apply_zero() {
        assert_eq!(Activation::Tanh.apply(0.0), 0.0);
    }

    #[test]
    fn test_tanh_apply_known() {
        assert!(approx(Activation::Tanh.apply(1.0), 1.0_f64.tanh()));
    }

    #[test]
    fn test_tanh_apply_negative() {
        assert!(approx(Activation::Tanh.apply(-2.0), (-2.0_f64).tanh()));
    }

    #[test]
    fn test_relu_apply_negative_is_zero() {
        assert_eq!(Activation::Relu.apply(-5.0), 0.0);
    }

    #[test]
    fn test_relu_apply_zero_is_zero() {
        assert_eq!(Activation::Relu.apply(0.0), 0.0);
    }

    #[test]
    fn test_relu_apply_positive_is_identity() {
        assert_eq!(Activation::Relu.apply(3.7), 3.7);
    }

    #[test]
    fn test_sigmoid_apply_zero_is_half() {
        assert!(approx(Activation::Sigmoid.apply(0.0), 0.5));
    }

    #[test]
    fn test_sigmoid_apply_large_stays_below_one() {
        // At x=30, exp(-30) ≈ 9.4e-14 is still representable in f64;
        // by x=100, f64 rounds the sigmoid to exactly 1.0.
        let y = Activation::Sigmoid.apply(30.0);
        assert!(y < 1.0);
        assert!(y > 0.99);
    }

    #[test]
    fn test_sigmoid_apply_very_negative_stays_above_zero() {
        assert!(Activation::Sigmoid.apply(-100.0) > 0.0);
    }

    #[test]
    fn test_elu_apply_positive_is_identity() {
        assert_eq!(Activation::Elu.apply(3.0), 3.0);
    }

    #[test]
    fn test_elu_apply_zero_is_zero() {
        assert!(Activation::Elu.apply(0.0).abs() < 1e-12);
    }

    #[test]
    fn test_elu_apply_negative_is_exp_minus_one() {
        assert!(approx(Activation::Elu.apply(-1.0), (-1.0_f64).exp() - 1.0));
    }

    #[test]
    fn test_elu_apply_large_negative_approaches_minus_one() {
        assert!((Activation::Elu.apply(-100.0) + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softsign_apply_positive() {
        // softsign(2.0) = 2.0 / (1 + 2.0) = 2/3
        assert!(approx(Activation::Softsign.apply(2.0), 2.0 / 3.0));
    }

    #[test]
    fn test_softsign_apply_zero() {
        assert!(Activation::Softsign.apply(0.0).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_apply_negative() {
        // softsign(-3.0) = -3.0 / (1 + 3.0) = -0.75
        assert!(approx(Activation::Softsign.apply(-3.0), -0.75));
    }

    #[test]
    fn test_softsign_apply_bounded() {
        // Output must lie in (-1, 1) no matter how extreme the input.
        assert!(Activation::Softsign.apply(100.0) < 1.0);
        assert!(Activation::Softsign.apply(-100.0) > -1.0);
    }

    #[test]
    fn test_linear_apply_is_identity() {
        assert_eq!(Activation::Linear.apply(42.0), 42.0);
    }

    // ── derivative tests ─────────────────────────────────────────

    #[test]
    fn test_tanh_derivative_formula() {
        // derivative(fx) = 1 - fx^2; fx = 0.5 gives 0.75.
        assert!(approx(Activation::Tanh.derivative(0.5), 0.75));
    }

    #[test]
    fn test_tanh_derivative_at_zero_is_one() {
        assert!(approx(Activation::Tanh.derivative(0.0), 1.0));
    }

    #[test]
    fn test_relu_derivative_zero_output_is_zero() {
        assert_eq!(Activation::Relu.derivative(0.0), 0.0);
    }

    #[test]
    fn test_relu_derivative_positive_output_is_one() {
        assert_eq!(Activation::Relu.derivative(2.0), 1.0);
    }

    #[test]
    fn test_sigmoid_derivative_formula() {
        // derivative(fx) = fx * (1 - fx); fx = 0.7 gives 0.21.
        assert!(approx(Activation::Sigmoid.derivative(0.7), 0.21));
    }

    #[test]
    fn test_sigmoid_derivative_at_half() {
        // derivative(0.5) = 0.5 * 0.5 = 0.25
        assert!(approx(Activation::Sigmoid.derivative(0.5), 0.25));
    }

    #[test]
    fn test_elu_derivative_positive_is_one() {
        assert_eq!(Activation::Elu.derivative(2.0), 1.0);
    }

    #[test]
    fn test_elu_derivative_negative_is_fx_plus_one() {
        // fx = -0.6 gives derivative -0.6 + 1.0 = 0.4.
        assert!(approx(Activation::Elu.derivative(-0.6), 0.4));
    }

    #[test]
    fn test_elu_derivative_at_minus_one_is_zero() {
        // ELU bottoms out at -1.0; the slope there is -1.0 + 1.0 = 0.0.
        assert!(Activation::Elu.derivative(-1.0).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_derivative_at_zero() {
        // derivative(softsign(0)) = 1 / (1 + 0)^2 = 1.0
        assert!(approx(Activation::Softsign.derivative(0.0), 1.0));
    }

    #[test]
    fn test_softsign_derivative_positive() {
        // derivative takes fx (post-activation); fx = x/(1+|x|) implies
        // |x| = |fx|/(1-|fx|). For fx=0.5: |x| = 1.0, so the
        // derivative is 1/(1+1)^2 = 0.25.
        assert!(approx(Activation::Softsign.derivative(0.5), 0.25));
    }

    #[test]
    fn test_softsign_derivative_negative() {
        // For fx=-0.5: |x| = 0.5/0.5 = 1.0, derivative = 1/(1+1)^2 = 0.25
        assert!(approx(Activation::Softsign.derivative(-0.5), 0.25));
    }

    #[test]
    fn test_softsign_derivative_high_saturation() {
        // For fx=0.9: |x| = 0.9/0.1 = 9, derivative = 1/(1+9)^2 = 0.01
        assert!(approx(Activation::Softsign.derivative(0.9), 0.01));
    }

    #[test]
    fn test_softsign_derivative_always_positive() {
        assert!([-0.9, -0.5, 0.0, 0.5, 0.9]
            .iter()
            .all(|&fx| Activation::Softsign.derivative(fx) > 0.0));
    }

    #[test]
    fn test_linear_derivative_always_one() {
        for &fx in &[999.0, -42.0, 0.0] {
            assert_eq!(Activation::Linear.derivative(fx), 1.0);
        }
    }

    // ── robustness tests ─────────────────────────────────────────

    #[test]
    fn test_all_activations_produce_finite_output_for_extreme_inputs() {
        for act in &ALL_VARIANTS {
            for &x in &[-100.0, 100.0] {
                let y = act.apply(x);
                assert!(y.is_finite(), "{:?}.apply({}) was not finite", act, x);
            }
        }
    }

    #[test]
    fn test_all_derivatives_finite_for_typical_post_activation_values() {
        let cases = [
            (Activation::Tanh, 0.5),
            (Activation::Relu, 1.0),
            (Activation::Sigmoid, 0.5),
            (Activation::Elu, -0.5),
            (Activation::Softsign, 0.5),
            (Activation::Linear, 0.0),
        ];
        for &(act, fx) in &cases {
            let d = act.derivative(fx);
            assert!(d.is_finite(), "{:?}.derivative({}) was not finite", act, fx);
        }
    }

    // ── serde tests ──────────────────────────────────────────────

    #[test]
    fn test_serde_roundtrip_all_variants() {
        for act in &ALL_VARIANTS {
            let json = serde_json::to_string(act).unwrap();
            let back: Activation = serde_json::from_str(&json).unwrap();
            assert_eq!(*act, back);
        }
    }

    #[test]
    fn test_serde_unknown_variant_returns_error() {
        assert!(serde_json::from_str::<Activation>("\"Softmax\"").is_err());
    }
}