use crate::activation::Function;
use crate::linalg::Matrix;
use crate::Float;
/// Hyperbolic tangent activation: `tanh(x) = (e^x - e^-x) / (e^x + e^-x)`.
///
/// A stateless unit struct. The per-element math lives in the private helpers
/// below and is applied to whole matrices through its `Function` impl.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tanh;

impl Tanh {
    /// Creates a new `Tanh` activation.
    pub fn new() -> Self {
        Self
    }

    /// Computes `tanh(num)` for a single scalar.
    ///
    /// Evaluated as `1 - 2 / (e^(2x) + 1)` instead of the textbook ratio
    /// `(e^x - e^-x) / (e^x + e^-x)`. The ratio becomes `inf / inf = NaN` as soon as
    /// `e^|x|` overflows (|x| > ~88.7 for `f32`, > ~709.8 for `f64`). This form
    /// saturates cleanly: `e^(2x) -> inf` gives `1`, and `e^(2x) -> 0` gives `-1`.
    fn num_fun<T: Float>(&self, num: T) -> T {
        let one = T::one();
        let two = one + one;
        let e_2z = (num + num).exp();
        one - two / (e_2z + one)
    }

    /// Computes the derivative `d/dx tanh(x) = 1 - tanh(x)^2` for a single scalar.
    fn num_der<T: Float>(&self, num: T) -> T {
        let val = self.num_fun(num);
        T::one() - val * val
    }
}
impl<T: Float> Function<T> for Tanh {
fn name(&self) -> String {
"Tanh".to_string()
}
fn call(&self, matrix: Matrix<T>) -> Matrix<T> {
matrix.map(|x| self.num_fun(x))
}
fn derivative(&self, matrix: Matrix<T>) -> Matrix<T> {
matrix.map(|x| self.num_der(x))
}
}
#[cfg(test)]
mod tests {
    use crate::activation::tanh::Tanh;
    use crate::activation::Function;
    use crate::linalg::Matrix;
    use crate::matrix;

    /// Smoke test: `call` runs on a row vector. The output is printed for manual inspection.
    #[test]
    fn tanh_test() {
        let a = matrix![[1f32, 0f32, 2f32, 3f32]];
        let tanh = Tanh::new();
        println!("{:?} {}", a.shape(), tanh.call(a));
    }

    /// Smoke test: `derivative` runs on a row vector. The output is printed for manual inspection.
    #[test]
    fn tanh_der() {
        let a = matrix![[1f32, 0f32, 2f32, 3f32]];
        let tanh = Tanh::new();
        println!("{:?} {}", a.shape(), tanh.derivative(a));
    }

    /// The scalar helper agrees with `f32::tanh` and is an odd function.
    #[test]
    fn tanh_scalar_matches_std() {
        let tanh = Tanh::new();
        for &x in &[-3f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0] {
            let got = tanh.num_fun(x);
            let expected = x.tanh();
            assert!(
                (got - expected).abs() < 1e-5,
                "tanh({}) = {}, expected {}",
                x,
                got,
                expected
            );
            // tanh is odd: tanh(-x) == -tanh(x).
            assert!((tanh.num_fun(-x) + got).abs() < 1e-5);
        }
    }

    /// The scalar derivative agrees with `1 - tanh(x)^2` and equals 1 at the origin.
    #[test]
    fn tanh_scalar_derivative() {
        let tanh = Tanh::new();
        assert_eq!(tanh.num_der(0f32), 1.0);
        for &x in &[-3f32, -1.0, -0.5, 0.5, 1.0, 2.0, 3.0] {
            let th = x.tanh();
            let expected = 1.0 - th * th;
            let got = tanh.num_der(x);
            assert!(
                (got - expected).abs() < 1e-5,
                "tanh'({}) = {}, expected {}",
                x,
                got,
                expected
            );
        }
    }
}