// ncps/activation.rs
1//! Custom activation functions for NCPS
2//!
3//! This module provides activation functions not available in Burn's standard library.
4
5use burn::tensor::{backend::Backend, Tensor};
6
/// LeCun's scaled tanh activation function.
///
/// This activation function is defined as:
/// `f(x) = 1.7159 * tanh(0.666 * x)`
///
/// The constants follow LeCun et al., "Efficient BackProp" (1998), where the
/// recommended form is `1.7159 * tanh(2/3 * x)`; `0.666` is a truncated 2/3.
/// The scaling factors are chosen such that:
/// - `f(±1) ≈ ±1`, so the effective gain stays close to 1 over `[-1, 1]`,
///   which helps gradient flow with normalized inputs
/// - The output range is approximately [-1.7159, 1.7159]
///
/// # Example
///
/// ```rust
/// use burn::backend::NdArray;
/// use burn::tensor::Tensor;
/// use ncps::activation::LeCun;
///
/// type Backend = NdArray<f32>;
/// let device = Default::default();
///
/// let x = Tensor::<Backend, 1>::from_floats([0.0, 1.0, -1.0], &device);
/// let y = LeCun::forward(x);
/// ```
pub struct LeCun;
31
32impl LeCun {
33    /// Applies the LeCun tanh activation function.
34    ///
35    /// # Arguments
36    ///
37    /// * `x` - Input tensor of any dimension
38    ///
39    /// # Returns
40    ///
41    /// Tensor with LeCun activation applied element-wise
42    pub fn forward<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
43        // LeCun tanh: 1.7159 * tanh(0.666 * x)
44        let scaled = x * 0.666f32;
45        scaled.tanh() * 1.7159f32
46    }
47}
48
/// Extension trait for applying LeCun activation directly on tensors.
///
/// Lets callers write `tensor.lecun()` instead of `LeCun::forward(tensor)`,
/// which reads better inside method chains.
pub trait LeCunActivation {
    /// Applies the LeCun tanh activation `1.7159 * tanh(0.666 * x)` element-wise.
    fn lecun(self) -> Self;
}
56
57impl<B: Backend, const D: usize> LeCunActivation for Tensor<B, D> {
58    fn lecun(self) -> Self {
59        LeCun::forward(self)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    type Backend = NdArray<f32>;

    /// Scalar reference implementation used to cross-check tensor results.
    fn reference(v: f32) -> f32 {
        1.7159f32 * (0.666f32 * v).tanh()
    }

    #[test]
    fn test_lecun_tanh_zero() {
        let device = Default::default();
        let input = Tensor::<Backend, 1>::zeros([5], &device);
        let output = LeCun::forward(input);

        // tanh is odd, so the activation maps 0 to exactly 0.
        assert!(output.sum().into_scalar().abs() < 1e-6);
    }

    #[test]
    fn test_lecun_tanh_range() {
        let device = Default::default();

        // Spot-check agreement with the scalar formula across the domain.
        for &val in [-10.0f32, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0].iter() {
            let input = Tensor::<Backend, 1>::full([1], val, &device);
            let got = LeCun::forward(input).into_scalar();

            assert!(
                (got - reference(val)).abs() < 1e-5,
                "LeCun activation incorrect at x={}",
                val
            );
        }
    }

    #[test]
    fn test_lecun_tanh_multidimensional() {
        let device = Default::default();
        let input = Tensor::<Backend, 2>::random(
            [4, 8],
            burn::tensor::Distribution::Uniform(-2.0, 2.0),
            &device,
        );

        let output = LeCun::forward(input.clone());
        assert_eq!(output.dims(), [4, 8]);

        // Compare every element against the scalar reference formula,
        // extracting values through single-element slices.
        for row in 0..4 {
            for col in 0..8 {
                let x_val = input
                    .clone()
                    .slice([row..row + 1, col..col + 1])
                    .into_scalar();
                let y_val = output
                    .clone()
                    .slice([row..row + 1, col..col + 1])
                    .into_scalar();
                assert!(
                    (y_val - reference(x_val)).abs() < 1e-5,
                    "Element [{}, {}] incorrect: got {}, expected {}",
                    row,
                    col,
                    y_val,
                    reference(x_val)
                );
            }
        }
    }

    #[test]
    fn test_lecun_tanh_saturation() {
        let device = Default::default();

        // tanh saturates, so large-magnitude inputs approach ±1.7159.
        let saturated_pos =
            LeCun::forward(Tensor::<Backend, 1>::full([1], 100.0f32, &device)).into_scalar();
        assert!(saturated_pos > 1.7);

        let saturated_neg =
            LeCun::forward(Tensor::<Backend, 1>::full([1], -100.0f32, &device)).into_scalar();
        assert!(saturated_neg < -1.7);
    }

    #[test]
    fn test_lecun_trait() {
        let device = Default::default();
        let input = Tensor::<Backend, 1>::from_floats([0.0f32, 1.0, -1.0], &device);

        // The trait extension must agree with the direct call element-wise.
        let via_trait = input.clone().lecun();
        let via_forward = LeCun::forward(input);

        for idx in 0..3 {
            let a = via_trait.clone().slice([idx..idx + 1]).into_scalar();
            let b = via_forward.clone().slice([idx..idx + 1]).into_scalar();
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lecun_3d_tensor() {
        let device = Default::default();
        let input = Tensor::<Backend, 3>::random(
            [2, 3, 4],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &device,
        );

        let output = LeCun::forward(input);
        assert_eq!(output.dims(), [2, 3, 4]);

        // Every output must lie inside the theoretical range [-1.7159, 1.7159].
        for a in 0..2 {
            for b in 0..3 {
                for c in 0..4 {
                    let y_val = output
                        .clone()
                        .slice([a..a + 1, b..b + 1, c..c + 1])
                        .into_scalar();
                    assert!(
                        y_val >= -1.716 && y_val <= 1.716,
                        "Value out of range: {}",
                        y_val
                    );
                }
            }
        }
    }
}