use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::activation_layer::format_output_shape;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
use ndarray::Zip;

/// Gradients are clamped to `[-GRAD_CLIP_VALUE, GRAD_CLIP_VALUE]` to avoid overflow.
const GRAD_CLIP_VALUE: f32 = 1e6;
/// Inputs are clamped to `[INPUT_CLIP_MIN, INPUT_CLIP_MAX]` before applying `tanh`;
/// the function saturates to ±1 long before these bounds, so the clamp only guards
/// against extreme values.
const INPUT_CLIP_MIN: f32 = -500.0;
const INPUT_CLIP_MAX: f32 = 500.0;
/// Element count at or above which the element-wise loops run in parallel.
const TANH_PARALLEL_THRESHOLD: usize = 2048;
/// Hyperbolic tangent (tanh) activation layer.
pub struct Tanh {
    /// Input cached during the forward pass, needed to compute the gradient in `backward`.
    input_cache: Option<Tensor>,
}

impl Tanh {
    /// Creates a new `Tanh` layer with an empty input cache.
    pub fn new() -> Self {
        Tanh { input_cache: None }
    }
}

impl Layer for Tanh {
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        // Reject empty tensors and non-finite values up front.
        if input.is_empty() {
            return Err(ModelError::InputValidationError(
                "Input tensor is empty".to_string(),
            ));
        }
        if input.iter().any(|&x| !x.is_finite()) {
            return Err(ModelError::InputValidationError(
                "Input tensor contains NaN or infinite values".to_string(),
            ));
        }

        // Cache the raw input; `backward` re-derives tanh(x) from it.
        self.input_cache = Some(input.clone());

        // Apply tanh element-wise, using the parallel path for large tensors.
        let output = if input.len() >= TANH_PARALLEL_THRESHOLD {
            let mut output = input.clone();
            Zip::from(&mut output).par_for_each(|x| {
                let clipped_x = x.clamp(INPUT_CLIP_MIN, INPUT_CLIP_MAX);
                *x = clipped_x.tanh();
            });
            output
        } else {
            input.mapv(|x| {
                let clipped_x = x.clamp(INPUT_CLIP_MIN, INPUT_CLIP_MAX);
                clipped_x.tanh()
            })
        };
        Ok(output)
    }

    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        if let Some(input) = &self.input_cache {
            // The incoming gradient must match the cached input shape exactly.
            if grad_output.shape() != input.shape() {
                return Err(ModelError::ProcessingError(format!(
                    "Gradient output shape {:?} doesn't match input shape {:?}",
                    grad_output.shape(),
                    input.shape()
                )));
            }
            if grad_output.iter().any(|&x| !x.is_finite()) {
                return Err(ModelError::InputValidationError(
                    "Gradient output contains NaN or infinite values".to_string(),
                ));
            }

            let mut grad_input = grad_output.clone();
            // Chain rule: d tanh(x)/dx = 1 - tanh(x)^2, so the upstream gradient is
            // scaled by that derivative. Non-finite results are zeroed; the rest are
            // clamped to ±GRAD_CLIP_VALUE.
            let compute_gradient = |grad: &mut f32, &inp: &f32| {
                let clipped_inp = inp.clamp(INPUT_CLIP_MIN, INPUT_CLIP_MAX);
                let tanh_val = clipped_inp.tanh();
                let derivative = 1.0 - tanh_val * tanh_val;
                *grad *= derivative;
                if grad.is_finite() {
                    *grad = grad.clamp(-GRAD_CLIP_VALUE, GRAD_CLIP_VALUE);
                } else {
                    *grad = 0.0;
                }
            };

            // Mirror the forward pass: parallel for large tensors, serial otherwise.
            if input.len() >= TANH_PARALLEL_THRESHOLD {
                Zip::from(&mut grad_input)
                    .and(input)
                    .par_for_each(compute_gradient);
            } else {
                Zip::from(&mut grad_input)
                    .and(input)
                    .for_each(compute_gradient);
            }
            Ok(grad_input)
        } else {
            Err(ModelError::ProcessingError(
                "Forward pass has not been run yet".to_string(),
            ))
        }
    }

    fn layer_type(&self) -> &str {
        "Tanh"
    }

    fn output_shape(&self) -> String {
        format_output_shape(&self.input_cache)
    }

    // Tanh has no trainable parameters; this macro supplies the remaining
    // parameter-related `Layer` methods.
    no_trainable_parameters_layer_functions!();
}
impl ActivationLayer for Tanh {}
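
// A minimal usage sketch, assuming `Tensor` is an alias for `ndarray::ArrayD<f32>`
// (as the `Zip`/`mapv` usage above suggests) and that `arr1(..).into_dyn()` produces
// a compatible tensor; adapt the constructors to the crate's actual `Tensor` type.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    #[test]
    fn forward_then_backward_applies_tanh_and_its_derivative() {
        let mut layer = Tanh::new();
        let input = arr1(&[0.0_f32, 1.0, -2.0]).into_dyn();

        // Forward pass: output[i] == tanh(input[i]).
        let output = layer.forward(&input).expect("forward should succeed");
        for (o, i) in output.iter().zip(input.iter()) {
            assert!((o - i.tanh()).abs() < 1e-6);
        }

        // Backward pass with a gradient of ones: grad_input[i] == 1 - tanh(input[i])^2.
        let grad_output = arr1(&[1.0_f32, 1.0, 1.0]).into_dyn();
        let grad_input = layer.backward(&grad_output).expect("backward should succeed");
        for (g, i) in grad_input.iter().zip(input.iter()) {
            let expected = 1.0 - i.tanh().powi(2);
            assert!((g - expected).abs() < 1e-6);
        }
    }
}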