use crate::common::matrix::DMat;
use crate::{activation::ActivationFunction, error::NetworkError};
use serde::{Deserialize, Serialize};
use typetag;
use super::{he_initialization, ActivationFunctionClone};
// Serializable unit struct carrying the actual ReLU behavior.
// Constructed only via `ReLU::build` and used as a `dyn ActivationFunction`
// trait object; `#[typetag::serde]` (on the impl below) makes it
// round-trippable through serde despite the dynamic dispatch.
#[derive(Serialize, Deserialize, Clone)]
struct ReLUActivation;
pub struct ReLU;
impl ReLU {
fn new() -> Self {
Self {}
}
pub fn build() -> Result<Box<dyn ActivationFunction>, NetworkError> {
Ok(Box::new(ReLUActivation {}))
}
}
impl Default for ReLU {
fn default() -> Self {
Self::new()
}
}
#[typetag::serde]
impl ActivationFunction for ReLUActivation {
    /// In-place ReLU: each entry becomes `x` when positive, `0` otherwise.
    fn forward(&self, input: &mut DMat) {
        input.apply(|x| if x > 0.0 { x } else { 0.0 });
    }

    /// In-place ReLU gradient: overwrites `input` with the upstream gradient
    /// wherever the pre-activation was non-negative, and zero elsewhere.
    /// The subgradient at exactly 0 is taken to be 1.
    fn backward(&self, d_output: &DMat, input: &mut DMat, _output: &DMat) {
        // Step function of the pre-activation values...
        let step = |x| if x < 0.0 { 0.0 } else { 1.0 };
        input.apply(step);
        // ...then masked element-wise against the incoming gradient.
        input.mul_elem(d_output);
    }

    /// ReLU pairs with He initialization, which compensates for the variance
    /// halved by zeroing negative pre-activations.
    fn weight_initialization_factor(&self) -> fn(usize, usize) -> f32 {
        he_initialization
    }
}
impl ActivationFunctionClone for ReLUActivation {
    /// Produces an owned trait-object copy of this activation so boxed
    /// activations can themselves be cloned.
    fn clone_box(&self) -> Box<dyn ActivationFunction> {
        let copy = self.clone();
        Box::new(copy)
    }
}
#[cfg(test)]
mod relu_tests {
    use super::*;
    use crate::{common::matrix::DMat, util::equal_approx};

    /// Positive entries must pass through the forward pass unchanged.
    #[test]
    fn test_relu_forward_positive_values() {
        let mut input = DMat::new(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let relu = ReLU::build().unwrap();
        relu.forward(&mut input);
        let expected = DMat::new(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!(equal_approx(&input, &expected, 1e-6), "ReLU forward pass with positive values failed");
    }

    /// Negative entries are clamped to zero; zeros and positives survive.
    #[test]
    fn test_relu_forward_mixed_values() {
        let mut input = DMat::new(2, 3, &[-1.0, 0.0, 2.0, -3.5, 4.2, 0.0]);
        let relu = ReLU::build().unwrap();
        relu.forward(&mut input);
        let expected = DMat::new(2, 3, &[0.0, 0.0, 2.0, 0.0, 4.2, 0.0]);
        assert!(equal_approx(&input, &expected, 1e-6), "ReLU forward pass with mixed values failed");
    }

    /// Gradient passes through where pre-activation >= 0 (subgradient at 0
    /// is 1) and is zeroed where the pre-activation was negative.
    #[test]
    fn test_relu_backward() {
        let mut input = DMat::new(2, 3, &[-1.0, 0.0, 2.0, -3.5, 4.2, 0.0]);
        let d_output = DMat::new(2, 3, &[0.5, 1.0, 0.7, 0.2, 0.3, 0.1]);
        let relu = ReLU::build().unwrap();
        let output = input.clone();
        relu.backward(&d_output, &mut input, &output);
        let expected = DMat::new(2, 3, &[0.0, 1.0, 0.7, 0.0, 0.3, 0.1]);
        assert!(equal_approx(&input, &expected, 1e-6), "ReLU backward pass failed");
    }

    /// All-negative pre-activations must produce an all-zero gradient.
    #[test]
    fn test_relu_backward_zero_gradient() {
        let mut input = DMat::new(1, 3, &[-1.0, -2.0, -3.0]);
        let d_output = DMat::new(1, 3, &[0.5, 1.0, 0.7]);
        // The forward output of an all-negative 1x3 input is an all-zero 1x3
        // matrix. (Was previously constructed as 2x3 — a latent dimension
        // mismatch that only passed because ReLU's backward ignores `_output`.)
        let output = DMat::new(1, 3, &[0.0; 3]);
        let relu = ReLU::build().unwrap();
        relu.backward(&d_output, &mut input, &output);
        let expected = DMat::new(1, 3, &[0.0, 0.0, 0.0]);
        assert!(equal_approx(&input, &expected, 1e-6), "ReLU backward pass with all negative values failed");
    }
}