use crate::common::matrix::DMat;
use crate::{activation::ActivationFunction, error::NetworkError};
use serde::{Deserialize, Serialize};
use typetag;
use super::{xavier_initialization, ActivationFunctionClone};
/// Serializable tanh activation; the concrete type stored behind
/// `Box<dyn ActivationFunction>` (see `Tanh::build`). Unit struct — tanh has
/// no parameters, so (de)serialization via serde/typetag is trivial.
#[derive(Serialize, Deserialize, Clone)]
struct TanhActivation;
pub struct Tanh;
impl Tanh {
fn new() -> Self {
Self {}
}
pub fn build() -> Result<Box<dyn ActivationFunction>, NetworkError> {
Ok(Box::new(TanhActivation {}))
}
}
impl Default for Tanh {
fn default() -> Self {
Self::new()
}
}
#[typetag::serde]
impl ActivationFunction for TanhActivation {
    /// Applies tanh element-wise, overwriting `input` with its activation.
    fn forward(&self, input: &mut DMat) {
        input.apply(|x| x.tanh());
    }

    /// Back-propagates through tanh.
    ///
    /// `input` holds the values left by `forward`, i.e. `t = tanh(z)`
    /// (the file's own test applies `forward` before `backward`); it is
    /// overwritten with `d_output ⊙ (1 - t²)`, the chain-rule product with
    /// tanh's derivative.
    ///
    /// Fix: d/dz tanh(z) = 1 - tanh²(z). The previous code computed
    /// `t * (1 - t²)` — the sigmoid-style `σ(1-σ)` pattern with a spurious
    /// leading factor — which scales every gradient by tanh(z) and yields a
    /// zero gradient at z = 0, exactly where tanh's slope is maximal (1).
    fn backward(&self, d_output: &DMat, input: &mut DMat, _output: &DMat) {
        // Replace t = tanh(z) with the local derivative 1 - t².
        input.apply(|x| 1.0 - x * x);
        // Chain rule: element-wise product with the upstream gradient.
        input.mul_elem(d_output);
    }

    /// Xavier/Glorot initialization — the conventional choice for the
    /// symmetric, zero-centered tanh nonlinearity.
    fn weight_initialization_factor(&self) -> fn(usize, usize) -> f32 {
        xavier_initialization
    }
}
// Enables deep-cloning through a `Box<dyn ActivationFunction>` trait object.
impl ActivationFunctionClone for TanhActivation {
fn clone_box(&self) -> Box<dyn ActivationFunction> {
// TanhActivation derives Clone, so boxing a copy is all that's needed.
Box::new(self.clone())
}
}
#[cfg(test)]
mod tanh_tests {
    use super::*;
    use crate::{common::matrix::DMat, util::equal_approx};

    /// tanh(0) = 0.
    #[test]
    fn test_tanh_forward_zero_input() {
        let mut input = DMat::new(1, 1, &[0.0f32]);
        let tanh = TanhActivation;
        tanh.forward(&mut input);
        let expected = DMat::new(1, 1, &[0.0f32]);
        assert!(equal_approx(&input, &expected, 1e-6), "Tanh forward pass with zero input failed");
    }

    /// Element-wise tanh over a matrix of mixed signs and magnitudes.
    #[test]
    fn test_tanh_forward_mixed_values() {
        let mut input = DMat::new(2, 3, &[-1.0f32, 0.0, 2.0, -3.5, 4.2, 0.0]);
        let tanh = TanhActivation;
        tanh.forward(&mut input);
        let expected = DMat::new(
            2,
            3,
            &[
                (-1.0f32).tanh(),
                0.0,
                2.0f32.tanh(),
                (-3.5f32).tanh(),
                4.2f32.tanh(),
                0.0,
            ],
        );
        assert!(equal_approx(&input, &expected, 1e-6), "Tanh forward pass with mixed values failed");
    }

    /// Backward must produce `d_output ⊙ (1 - t²)` where `t = tanh(z)` is the
    /// value `forward` left in the buffer.
    ///
    /// Fix: the previous expectation carried a spurious leading `t *` factor
    /// (`t·(1 - t²)·d`), mirroring — and thereby locking in — the sigmoid-style
    /// derivative bug rather than checking the true tanh derivative `1 - t²`.
    #[test]
    fn test_tanh_backward() {
        let mut input = DMat::new(2, 3, &[-1.0f32, 0.0, 2.0, -3.5, 4.2, 0.0]);
        let d_output = DMat::new(2, 3, &[0.5f32, 1.0, 0.7, 0.2, 0.3, 0.1]);
        let tanh = TanhActivation;
        tanh.forward(&mut input);
        let activated = input.clone();
        let output = DMat::new(2, 3, &[0.0; 6]);
        tanh.backward(&d_output, &mut input, &output);
        // Build the expectation from the correct formula: (1 - t²) * d.
        let mut expected_values = [0.0f32; 6];
        for row in 0..2 {
            for col in 0..3 {
                let t = activated.at(row, col);
                expected_values[row * 3 + col] = (1.0 - t * t) * d_output.at(row, col);
            }
        }
        let expected = DMat::new(2, 3, &expected_values);
        assert!(equal_approx(&input, &expected, 1e-6), "Tanh backward pass failed");
    }

    /// tanh saturates to ±1 at ±∞.
    #[test]
    fn test_tanh_bounds() {
        let test_cases = [(f32::NEG_INFINITY, -1.0f32), (f32::INFINITY, 1.0f32)];
        let tanh = TanhActivation;
        for (input_value, expected_output) in test_cases {
            let mut input = DMat::new(1, 1, &[input_value]);
            tanh.forward(&mut input);
            let expected = DMat::new(1, 1, &[expected_output]);
            assert!(equal_approx(&input, &expected, 1e-6), "Tanh forward pass at extreme bounds failed");
        }
    }

    /// tanh is odd: tanh(-x) = -tanh(x).
    #[test]
    fn test_tanh_symmetry() {
        let test_cases = [
            (-2.0f32, -2.0f32.tanh()),
            (2.0f32, 2.0f32.tanh()),
            (-0.5f32, -0.5f32.tanh()),
            (0.5f32, 0.5f32.tanh()),
        ];
        let tanh = TanhActivation;
        for (input_value, expected_output) in test_cases {
            let mut input = DMat::new(1, 1, &[input_value]);
            tanh.forward(&mut input);
            let expected = DMat::new(1, 1, &[expected_output]);
            assert!(equal_approx(&input, &expected, 1e-6), "Tanh forward pass symmetry test failed");
        }
    }

    /// The Xavier factor for a square fan-in/fan-out must be positive.
    #[test]
    fn test_tanh_weight_initialization() {
        let tanh = TanhActivation;
        let factor = tanh.weight_initialization_factor();
        let weight_matrix = factor(3, 3);
        assert!(weight_matrix > 0.0, "Tanh weight initialization factor should be positive");
    }

    /// `clone_box` on the boxed trait object must not panic.
    #[test]
    fn test_tanh_clone() {
        let tanh = Tanh::build().unwrap();
        let _cloned_tanh = tanh.clone_box();
    }
}