use crate::common::matrix::DMat;
use crate::{activation::ActivationFunction, error::NetworkError};
use serde::{Deserialize, Serialize};
use typetag;
use super::{he_initialization, ActivationFunctionClone};
/// Concrete ELU activation that [`ELU::build`] boxes up as a `dyn ActivationFunction`.
///
/// The struct is private to this module; it derives serde's `Serialize`/`Deserialize`
/// as well as `Clone` (the latter is what `clone_box` relies on).
#[derive(Serialize, Deserialize, Clone)]
struct ELUActivation {
/// Scale of the exponential branch used for inputs that are not greater than zero.
/// `ELU::build` only constructs this struct after `ELU::validate` has accepted the value.
alpha: f32,
}
/// Builder for the ELU (Exponential Linear Unit) activation function.
///
/// Configure it with the chained `alpha` setter and finish with `build`, which
/// validates the configuration and returns the boxed activation.
#[derive(Debug, Clone, PartialEq)]
pub struct ELU {
    /// Scale applied to the exponential branch (inputs that are not greater than zero).
    alpha: f32,
}
impl ELU {
    /// Creates a builder with the default `alpha` of `1.0`.
    ///
    /// Not `pub`: code outside this module obtains the same value via `ELU::default()`.
    fn new() -> Self {
        ELU { alpha: 1.0 }
    }

    /// Sets `alpha` and hands the builder back for chaining.
    ///
    /// No validation happens here; the value is checked when `build` is called.
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Checks that the configured `alpha` is usable.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConfigError` when `alpha` is NaN, infinite, zero or negative.
    fn validate(&self) -> Result<(), NetworkError> {
        // NaN compares false against everything, so a bare `alpha <= 0.0` test lets it
        // through; an infinite alpha is just as harmful because `inf * (exp(0) - 1)` is NaN.
        if !self.alpha.is_finite() {
            return Err(NetworkError::ConfigError(format!(
                "Alpha for ELU must be a finite number, but was {}",
                self.alpha
            )));
        }
        if self.alpha <= 0.0 {
            return Err(NetworkError::ConfigError(format!(
                "Alpha for ELU must be greater than 0.0, but was {}",
                self.alpha
            )));
        }
        Ok(())
    }

    /// Validates the configuration and produces the boxed activation function.
    ///
    /// # Errors
    ///
    /// Propagates the `NetworkError::ConfigError` from `validate` when `alpha` is rejected.
    pub fn build(self) -> Result<Box<dyn ActivationFunction>, NetworkError> {
        self.validate()?;
        Ok(Box::new(ELUActivation { alpha: self.alpha }))
    }
}
impl Default for ELU {
fn default() -> Self {
Self::new()
}
}
// ELU behaviour exposed through the `ActivationFunction` trait. The `typetag::serde`
// attribute registers this impl for serde handling of the boxed trait object.
#[typetag::serde]
impl ActivationFunction for ELUActivation {
    /// Applies ELU to every element of `input` in place.
    ///
    /// Elements greater than zero are left unchanged; all others become
    /// `alpha * (exp(x) - 1)`.
    fn forward(&self, input: &mut DMat) {
        input.apply(|x| {
            if x > 0.0 {
                x
            } else {
                // `exp_m1` computes exp(x) - 1 directly. Spelling it as `x.exp() - 1.0`
                // cancels catastrophically for small negative x (exp(x) rounds to 1.0
                // in f32 and the result collapses to 0).
                self.alpha * x.exp_m1()
            }
        });
    }

    /// Overwrites `input` with the ELU derivative of its current values multiplied
    /// element-wise by `d_output`.
    ///
    /// The derivative is `1` for elements greater than zero and `alpha * exp(x)`
    /// otherwise. `_output` is not used by this implementation.
    fn backward(&self, d_output: &DMat, input: &mut DMat, _output: &DMat) {
        // Step 1: replace each value with the local derivative.
        input.apply(|x| {
            if x > 0.0 {
                1.0
            } else {
                self.alpha * x.exp()
            }
        });
        // Step 2: scale by the incoming gradient.
        input.mul_elem(d_output);
    }

    /// Returns `he_initialization` as the weight-initialization function for this activation.
    fn weight_initialization_factor(&self) -> fn(usize, usize) -> f32 {
        he_initialization
    }
}
/// Lets a boxed ELU activation be duplicated through the `ActivationFunctionClone` trait.
impl ActivationFunctionClone for ELUActivation {
    /// Clones this activation and returns the copy as a boxed trait object.
    fn clone_box(&self) -> Box<dyn ActivationFunction> {
        let duplicate: ELUActivation = Clone::clone(self);
        Box::new(duplicate)
    }
}
/// Unit tests for the ELU builder and its forward/backward passes.
#[cfg(test)]
mod elu_tests {
    use super::*;
    use crate::{common::matrix::DMat, util::equal_approx};

    /// Positive inputs must pass through the forward pass unchanged.
    #[test]
    fn test_elu_forward_positive_values() {
        let elu = ELU::new().alpha(1.0).build().unwrap();
        let mut input = DMat::new(1, 3, &[1.0, 2.0, 3.0]);

        elu.forward(&mut input);

        let expected = DMat::new(1, 3, &[1.0, 2.0, 3.0]);
        assert!(
            equal_approx(&input, &expected, 1e-6),
            "ELU forward pass with positive values failed"
        );
    }

    /// Negative inputs take the exponential branch; zero maps to zero.
    #[test]
    fn test_elu_forward_mixed_values() {
        let elu: Box<dyn ActivationFunction> = ELU::new().alpha(1.0).build().unwrap();
        let mut input = DMat::new(2, 3, &[-1.0, 0.0, 2.0, -3.5, 4.2, 0.0]);

        elu.forward(&mut input);

        let expected = DMat::new(
            2,
            3,
            &[
                1.0 * ((-1.0_f32).exp() - 1.0),
                0.0,
                2.0,
                1.0 * ((-3.5_f32).exp() - 1.0),
                4.2,
                0.0,
            ],
        );
        assert!(
            equal_approx(&input, &expected, 1e-6),
            "ELU forward pass with mixed values failed"
        );
    }

    /// For positive inputs the derivative is 1, so the gradient equals `d_output`.
    #[test]
    fn test_elu_backward_positive_values() {
        let elu = ELU::new().alpha(1.0).build().unwrap();
        let mut input = DMat::new(1, 3, &[1.0, 2.0, 3.0]);
        let d_output = DMat::new(1, 3, &[0.5, 1.0, 0.7]);
        // Placeholder for the unused `_output` argument; shaped like `input` for consistency.
        let output: DMat = DMat::new(1, 3, &[0.0; 3]);

        elu.backward(&d_output, &mut input, &output);

        let expected = DMat::new(1, 3, &[0.5, 1.0, 0.7]);
        assert!(
            equal_approx(&input, &expected, 1e-6),
            "ELU backward pass with positive values failed"
        );
    }

    /// Non-positive inputs use `alpha * exp(x)` as derivative (which is `alpha` at zero).
    #[test]
    fn test_elu_backward_mixed_values() {
        let mut input = DMat::new(2, 3, &[-1.0, 0.0, 2.0, -3.5, 4.2, 0.0]);
        let d_output = DMat::new(2, 3, &[0.5, 1.0, 0.7, 0.2, 0.3, 0.1]);
        let output: DMat = DMat::new(2, 3, &[0.0; 6]);
        let elu = ELU::new().alpha(1.0).build().unwrap();

        elu.backward(&d_output, &mut input, &output);

        let expected = DMat::new(
            2,
            3,
            &[
                1.0 * (-1.0_f32).exp() * 0.5,
                1.0,
                0.7,
                1.0 * (-3.5_f32).exp() * 0.2,
                0.3,
                0.1,
            ],
        );
        assert!(
            equal_approx(&input, &expected, 1e-6),
            "ELU backward pass with mixed values failed"
        );
    }

    /// The configured alpha must scale the exponential branch of the forward pass.
    #[test]
    fn test_different_elu_alpha() {
        // (alpha, input value, expected forward output)
        let test_cases = [
            (0.5, -2.0, 0.5 * ((-2.0_f32).exp() - 1.0)),
            (1.5, -3.0, 1.5 * ((-3.0_f32).exp() - 1.0)),
        ];

        for (alpha, input_value, expected_output) in test_cases {
            let elu = ELU::new().alpha(alpha).build().unwrap();
            let mut input = DMat::new(1, 1, &[input_value]);

            elu.forward(&mut input);

            let expected = DMat::new(1, 1, &[expected_output]);
            assert!(
                equal_approx(&input, &expected, 1e-6),
                "ELU forward pass with alpha failed"
            );
        }
    }

    /// Zero and negative alpha values must be rejected by `build`.
    #[test]
    fn test_invalid_alpha() {
        let invalid_alphas = [-1.0, 0.0];

        for &alpha in &invalid_alphas {
            let result = ELU::new().alpha(alpha).build();
            assert!(
                result.is_err(),
                "ELU should not accept non-positive alpha values"
            );
        }
    }
}