use crate::common::matrix::DMat;
use crate::{activation::ActivationFunction, error::NetworkError};
use serde::{Deserialize, Serialize};
use typetag;
use super::{he_initialization, ActivationFunctionClone};
/// The runtime LeakyReLU activation, produced by [`LeakyReLU::build`].
///
/// Serializable via serde so trained networks containing this activation
/// can be persisted and restored (see the `#[typetag::serde]` trait impl).
#[derive(Serialize, Deserialize, Clone)]
struct LeakyReLUActivation {
    // Slope applied to non-positive inputs; the builder guarantees alpha > 0.0.
    alpha: f32,
}
/// Builder for the LeakyReLU activation function.
///
/// Construct with `LeakyReLU::default()`, optionally override the slope with
/// [`LeakyReLU::alpha`], then call [`LeakyReLU::build`] to validate the
/// configuration and obtain a boxed `ActivationFunction`.
pub struct LeakyReLU {
    // Candidate negative slope; validated (must be > 0.0) only at build time.
    alpha: f32,
}
impl LeakyReLU {
    /// Creates a builder with the conventional default slope of 0.01.
    fn new() -> Self {
        Self { alpha: 0.01 }
    }

    /// Overrides the negative-slope coefficient. The value is only checked
    /// when [`LeakyReLU::build`] is called.
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Rejects configurations whose slope is not strictly positive.
    fn validate(&self) -> Result<(), NetworkError> {
        if self.alpha <= 0.0 {
            Err(NetworkError::ConfigError(format!(
                "Alpha for LeakyReLU must be greater than 0.0, but was {}",
                self.alpha
            )))
        } else {
            Ok(())
        }
    }

    /// Validates the configuration and produces the boxed activation.
    ///
    /// # Errors
    /// Returns `NetworkError::ConfigError` when alpha is not > 0.0.
    pub fn build(self) -> Result<Box<dyn ActivationFunction>, NetworkError> {
        self.validate()?;
        let activation = LeakyReLUActivation { alpha: self.alpha };
        Ok(Box::new(activation))
    }
}
impl Default for LeakyReLU {
fn default() -> Self {
Self::new()
}
}
#[typetag::serde]
impl ActivationFunction for LeakyReLUActivation {
    /// In-place LeakyReLU: positive entries pass through unchanged,
    /// all other entries are scaled by alpha.
    fn forward(&self, input: &mut DMat) {
        let slope = self.alpha;
        input.apply(|v| if v > 0.0 { v } else { slope * v });
    }

    /// Overwrites `input` with d_output ⊙ LeakyReLU'(input).
    fn backward(&self, d_output: &DMat, input: &mut DMat, _output: &DMat) {
        let slope = self.alpha;
        // Local derivative: 1 on the positive side, alpha elsewhere
        // (alpha is used at exactly 0, matching the forward branch).
        input.apply(|v| if v > 0.0 { 1.0 } else { slope });
        // Chain rule: scale element-wise by the upstream gradient.
        input.mul_elem(d_output);
    }

    /// He initialization suits ReLU-family activations.
    fn weight_initialization_factor(&self) -> fn(usize, usize) -> f32 {
        he_initialization
    }
}
// Enables duplicating boxed `dyn ActivationFunction` trait objects.
impl ActivationFunctionClone for LeakyReLUActivation {
    fn clone_box(&self) -> Box<dyn ActivationFunction> {
        // `LeakyReLUActivation` derives `Clone`, so this is a cheap field copy.
        Box::new(self.clone())
    }
}
#[cfg(test)]
mod leakyrelu_tests {
    use super::*;
    use crate::{common::matrix::DMat, util::equal_approx};

    #[test]
    fn test_leakyrelu_forward() {
        let activation = LeakyReLU::new().alpha(0.01).build().unwrap();
        let mut data = DMat::new(2, 3, &[1.0, -2.0, 3.0, -4.0, 5.0, -6.0]);
        activation.forward(&mut data);
        // Positives pass through; negatives are scaled by 0.01.
        let want = DMat::new(2, 3, &[1.0, -0.02, 3.0, -0.04, 5.0, -0.06]);
        assert!(equal_approx(&data, &want, 1e-4), "LeakyReLU forward pass failed");
    }

    #[test]
    fn test_leakyrelu_backward() {
        let activation = LeakyReLU::new().alpha(0.01).build().unwrap();
        let mut grad_in = DMat::new(2, 3, &[1.0, -2.0, 3.0, -4.0, 5.0, -6.0]);
        let upstream = DMat::new(2, 3, &[0.5, 1.0, 0.7, 0.2, 0.3, 0.1]);
        let unused_output: DMat = DMat::new(2, 3, &[0.0; 6]);
        activation.backward(&upstream, &mut grad_in, &unused_output);
        // Derivative is 1 for positives and 0.01 otherwise, times upstream.
        let want = DMat::new(2, 3, &[0.5, 0.01, 0.7, 0.002, 0.3, 0.001]);
        assert!(equal_approx(&grad_in, &want, 1e-4), "LeakyReLU backward pass failed");
    }

    #[test]
    fn test_leakyrelu_invalid_alpha() {
        // A non-positive slope must be rejected at build time.
        let result = LeakyReLU::new().alpha(-0.01).build();
        assert!(result.is_err(), "LeakyReLU should not allow negative alpha");
    }
}