1use burn::tensor::{backend::Backend, Tensor};
6
/// LeCun's scaled hyperbolic tangent activation:
/// `f(x) = 1.7159 * tanh(0.666 * x)` (see `LeCun::forward`).
///
/// Stateless unit struct — used purely as a namespace for the activation.
/// NOTE(review): the canonical slope from LeCun et al., "Efficient BackProp"
/// is 2/3 ≈ 0.6667; this file uses 0.666 throughout — confirm intent.
pub struct LeCun;
31
32impl LeCun {
33 pub fn forward<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
43 let scaled = x * 0.666f32;
45 scaled.tanh() * 1.7159f32
46 }
47}
48
/// Extension trait that adds a fluent `.lecun()` method to tensors.
pub trait LeCunActivation {
    /// Applies the LeCun scaled tanh activation, consuming `self`.
    fn lecun(self) -> Self;
}
56
impl<B: Backend, const D: usize> LeCunActivation for Tensor<B, D> {
    // Delegates to the free-function form so both call styles stay in sync.
    fn lecun(self) -> Self {
        LeCun::forward(self)
    }
}
62
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    // Shadows the glob-imported `Backend` trait from `super::*`; explicit
    // items take precedence over glob imports, so this alias is what `Tensor`
    // below resolves against.
    type Backend = NdArray<f32>;

    /// tanh is odd, so f(0) must be exactly 0 for every element.
    #[test]
    fn test_lecun_tanh_zero() {
        let device = Default::default();
        let zeros = Tensor::<Backend, 1>::zeros([5], &device);
        let output = LeCun::forward(zeros);

        let total = output.sum().into_scalar();
        assert!((total - 0.0).abs() < 1e-6);
    }

    /// Tensor-op result must match a scalar reference computation across
    /// negative, zero, and positive inputs.
    #[test]
    fn test_lecun_tanh_range() {
        let device = Default::default();

        for &input in &[-10.0f32, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0] {
            let tensor = Tensor::<Backend, 1>::full([1], input, &device);
            let actual = LeCun::forward(tensor).into_scalar();
            let reference = 1.7159f32 * (0.666f32 * input).tanh();

            assert!(
                (actual - reference).abs() < 1e-5,
                "LeCun activation incorrect at x={}",
                input
            );
        }
    }

    /// Element-wise correctness and shape preservation on a 2-D tensor.
    #[test]
    fn test_lecun_tanh_multidimensional() {
        let device = Default::default();
        let input = Tensor::<Backend, 2>::random(
            [4, 8],
            burn::tensor::Distribution::Uniform(-2.0, 2.0),
            &device,
        );

        let output = LeCun::forward(input.clone());

        assert_eq!(output.dims(), [4, 8]);

        // Compare every element against the scalar formula.
        for row in 0..4 {
            for col in 0..8 {
                let x_val = input
                    .clone()
                    .slice([row..row + 1, col..col + 1])
                    .into_scalar();
                let y_val = output
                    .clone()
                    .slice([row..row + 1, col..col + 1])
                    .into_scalar();
                let expected = 1.7159f32 * (0.666f32 * x_val).tanh();
                assert!(
                    (y_val - expected).abs() < 1e-5,
                    "Element [{}, {}] incorrect: got {}, expected {}",
                    row,
                    col,
                    y_val,
                    expected
                );
            }
        }
    }

    /// Far from the origin the activation must saturate near ±1.7159.
    #[test]
    fn test_lecun_tanh_saturation() {
        let device = Default::default();

        let saturated_pos = LeCun::forward(Tensor::<Backend, 1>::full([1], 100.0f32, &device));
        assert!(saturated_pos.into_scalar() > 1.7);

        let saturated_neg = LeCun::forward(Tensor::<Backend, 1>::full([1], -100.0f32, &device));
        assert!(saturated_neg.into_scalar() < -1.7);
    }

    /// The `.lecun()` extension method must agree with the direct call.
    #[test]
    fn test_lecun_trait() {
        let device = Default::default();
        let input = Tensor::<Backend, 1>::from_floats([0.0f32, 1.0, -1.0], &device);

        let via_trait = input.clone().lecun();
        let via_direct = LeCun::forward(input);

        for idx in 0..3 {
            let a = via_trait.clone().slice([idx..idx + 1]).into_scalar();
            let b = via_direct.clone().slice([idx..idx + 1]).into_scalar();
            assert!((a - b).abs() < 1e-6);
        }
    }

    /// Shape preservation and output-range bound on a rank-3 tensor:
    /// |f(x)| < 1.7159 for all finite x, checked with a small margin.
    #[test]
    fn test_lecun_3d_tensor() {
        let device = Default::default();
        let input = Tensor::<Backend, 3>::random(
            [2, 3, 4],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &device,
        );

        let output = LeCun::forward(input.clone());

        assert_eq!(output.dims(), [2, 3, 4]);

        for d0 in 0..2 {
            for d1 in 0..3 {
                for d2 in 0..4 {
                    let y_val = output
                        .clone()
                        .slice([d0..d0 + 1, d1..d1 + 1, d2..d2 + 1])
                        .into_scalar();
                    assert!(
                        y_val >= -1.716 && y_val <= 1.716,
                        "Value out of range: {}",
                        y_val
                    );
                }
            }
        }
    }
}