use burn::tensor::{backend::Backend, Tensor};
/// LeCun's scaled hyperbolic tangent activation: `f(x) = 1.7159 * tanh(0.666 * x)`.
///
/// The constants follow LeCun et al., "Efficient BackProp" (1998), where the
/// recommended form is `1.7159 * tanh(2/3 * x)` so that `f(±1) ≈ ±1` and the
/// second derivative peaks near `x = 1`.
///
/// NOTE(review): `SCALE` is a truncated approximation of 2/3 (0.6666…);
/// consider `2.0 / 3.0` for exactness — kept as-is here to preserve behavior.
pub struct LeCun;

impl LeCun {
    /// Input scaling factor (approximates 2/3 from the original paper).
    pub const SCALE: f32 = 0.666;
    /// Output amplitude; chosen so the activation's range is (-1.7159, 1.7159).
    pub const AMPLITUDE: f32 = 1.7159;

    /// Applies the activation element-wise.
    ///
    /// Works for any backend `B` and tensor rank `D`; the output tensor has
    /// the same shape as `x`. Consumes the input tensor.
    pub fn forward<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
        (x * Self::SCALE).tanh() * Self::AMPLITUDE
    }
}
/// Extension trait exposing the LeCun tanh activation as a tensor method,
/// so callers can write `tensor.lecun()` instead of `LeCun::forward(tensor)`.
pub trait LeCunActivation {
    /// Applies `1.7159 * tanh(0.666 * x)` element-wise, consuming `self`.
    fn lecun(self) -> Self;
}
/// Blanket implementation for every backend and tensor rank; delegates
/// directly to [`LeCun::forward`], so the two call forms are interchangeable.
impl<B: Backend, const D: usize> LeCunActivation for Tensor<B, D> {
    fn lecun(self) -> Self {
        LeCun::forward(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    // CPU backend for the tests; the explicit alias shadows the glob-imported
    // `Backend` trait, which is legal and intentional here.
    type Backend = NdArray<f32>;

    /// Scalar reference implementation the tensor results are compared against.
    fn reference(x: f32) -> f32 {
        1.7159f32 * (0.666f32 * x).tanh()
    }

    #[test]
    fn test_lecun_tanh_zero() {
        let device = Default::default();
        let zeros = Tensor::<Backend, 1>::zeros([5], &device);
        // tanh(0) == 0, so the output of an all-zero input must sum to zero.
        let total = LeCun::forward(zeros).sum().into_scalar();
        assert!((total - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_lecun_tanh_range() {
        let device = Default::default();
        // Probe a spread of negative, zero, and positive inputs.
        for &sample in [-10.0f32, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0].iter() {
            let input = Tensor::<Backend, 1>::full([1], sample, &device);
            let got = LeCun::forward(input).into_scalar();
            assert!(
                (got - reference(sample)).abs() < 1e-5,
                "LeCun activation incorrect at x={}",
                sample
            );
        }
    }

    #[test]
    fn test_lecun_tanh_multidimensional() {
        let device = Default::default();
        let input = Tensor::<Backend, 2>::random(
            [4, 8],
            burn::tensor::Distribution::Uniform(-2.0, 2.0),
            &device,
        );
        let output = LeCun::forward(input.clone());
        assert_eq!(output.dims(), [4, 8]);
        // Check every element against the scalar reference.
        for row in 0..4 {
            for col in 0..8 {
                let x_val = input
                    .clone()
                    .slice([row..row + 1, col..col + 1])
                    .into_scalar();
                let y_val = output
                    .clone()
                    .slice([row..row + 1, col..col + 1])
                    .into_scalar();
                let want = reference(x_val);
                assert!(
                    (y_val - want).abs() < 1e-5,
                    "Element [{}, {}] incorrect: got {}, expected {}",
                    row,
                    col,
                    y_val,
                    want
                );
            }
        }
    }

    #[test]
    fn test_lecun_tanh_saturation() {
        let device = Default::default();
        // tanh saturates at ±1, so outputs approach ±1.7159 for large |x|.
        let pos = LeCun::forward(Tensor::<Backend, 1>::full([1], 100.0f32, &device));
        assert!(pos.into_scalar() > 1.7);
        let neg = LeCun::forward(Tensor::<Backend, 1>::full([1], -100.0f32, &device));
        assert!(neg.into_scalar() < -1.7);
    }

    #[test]
    fn test_lecun_trait() {
        let device = Default::default();
        let input = Tensor::<Backend, 1>::from_floats([0.0f32, 1.0, -1.0], &device);
        let via_trait = input.clone().lecun();
        let via_forward = LeCun::forward(input);
        // The trait method must be an exact alias for the free function.
        for idx in 0..3 {
            let a = via_trait.clone().slice([idx..idx + 1]).into_scalar();
            let b = via_forward.clone().slice([idx..idx + 1]).into_scalar();
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lecun_3d_tensor() {
        let device = Default::default();
        let input = Tensor::<Backend, 3>::random(
            [2, 3, 4],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let output = LeCun::forward(input);
        assert_eq!(output.dims(), [2, 3, 4]);
        // Every output must lie within the activation's codomain.
        for a in 0..2 {
            for b in 0..3 {
                for c in 0..4 {
                    let val = output
                        .clone()
                        .slice([a..a + 1, b..b + 1, c..c + 1])
                        .into_scalar();
                    assert!(
                        val >= -1.716 && val <= 1.716,
                        "Value out of range: {}",
                        val
                    );
                }
            }
        }
    }
}