use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::activation_layer::format_output_shape;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
use ndarray::{Array2, ArrayView1, ArrayViewMut1, Axis, Zip};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Lower bound on the softmax denominator, guarding against division by zero.
const EPSILON: f32 = 1e-8;
/// Backward-pass gradients are clamped to `[-GRAD_CLIP_VALUE, GRAD_CLIP_VALUE]`.
const GRAD_CLIP_VALUE: f32 = 1e6;
/// Inputs are clamped to `[INPUT_CLIP_MIN, INPUT_CLIP_MAX]` before exponentiation.
const INPUT_CLIP_MIN: f32 = -500.0;
const INPUT_CLIP_MAX: f32 = 500.0;
/// Rows are processed with rayon once the flattened batch size exceeds this value.
const SOFTMAX_PARALLEL_THRESHOLD: usize = 8;
/// Softmax activation layer.
///
/// Applies the softmax function along the last axis of the input tensor and
/// caches both the input and the output so the backward pass can form the
/// Jacobian-vector product without recomputing the forward pass.
pub struct Softmax {
    /// Input tensor cached by the most recent forward pass.
    input_cache: Option<Tensor>,
    /// Softmax output cached by the most recent forward pass.
    output_cache: Option<Tensor>,
}
impl Softmax {
    /// Creates a new `Softmax` layer with empty caches.
    pub fn new() -> Self {
Softmax {
input_cache: None,
output_cache: None,
}
}
}
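
// A small convenience sketch: `Default` simply delegates to `new()`, so the
// layer can be built with `Softmax::default()`; this also satisfies Clippy's
// `new_without_default` lint. Assumes no conflicting `Default` impl elsewhere.
impl Default for Softmax {
    fn default() -> Self {
        Self::new()
    }
}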
impl Layer for Softmax {
    /// Applies softmax along the last axis of `input`.
    ///
    /// Fails if the input is empty, contains NaN or infinite values, or has
    /// fewer than two dimensions.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
if input.is_empty() {
return Err(ModelError::InputValidationError(
"Input tensor is empty".to_string(),
));
}
if input.iter().any(|&x| x.is_nan() || x.is_infinite()) {
return Err(ModelError::InputValidationError(
"Input tensor contains NaN or infinite values".to_string(),
));
}
self.input_cache = Some(input.clone());
let shape = input.shape();
let ndim = shape.len();
if ndim < 2 {
return Err(ModelError::InputValidationError(format!(
"Softmax requires input with at least 2 dimensions, got shape: {:?}",
shape
)));
}
        // Work on an owned copy and flatten every leading axis into a single
        // batch dimension so softmax can be applied row by row over the last axis.
        let output = input.clone();
        let batch_size: usize = shape[..ndim - 1].iter().product();
        let num_features = shape[ndim - 1];
let mut output_2d = output
.into_shape_with_order((batch_size, num_features))
.map_err(|e| {
ModelError::ProcessingError(format!(
"Failed to reshape for softmax computation: {}",
e
))
})?;
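        // Numerically stable softmax for a single row:
        //   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)).
        // Subtracting the row maximum keeps exp() from overflowing, the clamps
        // bound pathological inputs, and the EPSILON floor on the sum avoids
        // dividing by zero.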
let apply_softmax = |mut row: ArrayViewMut1<f32>| {
let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let clipped_max = max_val.clamp(INPUT_CLIP_MIN, INPUT_CLIP_MAX);
row.map_inplace(|x| {
let clipped_x = x.clamp(INPUT_CLIP_MIN, INPUT_CLIP_MAX);
*x = (clipped_x - clipped_max).exp()
});
let sum = row.sum().max(EPSILON);
row.map_inplace(|x| *x /= sum);
};
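        // Below the threshold the rayon dispatch overhead tends to outweigh the
        // parallel speedup, so small batches are processed sequentially.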
if batch_size > SOFTMAX_PARALLEL_THRESHOLD {
output_2d
.axis_iter_mut(Axis(0))
.into_par_iter()
.for_each(apply_softmax);
} else {
output_2d.axis_iter_mut(Axis(0)).for_each(apply_softmax);
}
let output = output_2d
.into_shape_with_order(shape)
.map_err(|e| {
ModelError::ProcessingError(format!(
"Failed to reshape back after softmax computation: {}",
e
))
})?
.into_dyn();
self.output_cache = Some(output.clone());
Ok(output)
}
    /// Propagates `grad_output` through the softmax using the cached forward output.
    ///
    /// Fails if `forward` has not completed successfully, or if `grad_output`
    /// has a mismatched shape or contains NaN or infinite values.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
match (&self.input_cache, &self.output_cache) {
(Some(input), Some(output)) => {
if grad_output.shape() != input.shape() {
return Err(ModelError::ProcessingError(format!(
"Gradient output shape {:?} doesn't match input shape {:?}",
grad_output.shape(),
input.shape()
)));
}
if grad_output.iter().any(|&x| x.is_nan() || x.is_infinite()) {
return Err(ModelError::InputValidationError(
"Gradient output contains NaN or infinite values".to_string(),
));
}
                // Flatten to 2D exactly as in the forward pass so the cached
                // output and the incoming gradient line up row by row.
                let shape = input.shape();
                let ndim = shape.len();
                let batch_size: usize = shape[..ndim - 1].iter().product();
                let num_features = shape[ndim - 1];
let output_2d = output
.clone()
.into_shape_with_order((batch_size, num_features))
.map_err(|e| {
ModelError::ProcessingError(format!(
"Failed to reshape output for backward: {}",
e
))
})?;
let grad_output_2d = grad_output
.clone()
.into_shape_with_order((batch_size, num_features))
.map_err(|e| {
ModelError::ProcessingError(format!(
"Failed to reshape grad_output for backward: {}",
e
))
})?;
let mut grad_input_2d = Array2::<f32>::zeros((batch_size, num_features));
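                // For y = softmax(x), the Jacobian-vector product collapses to
                //   dL/dx_j = y_j * (dL/dy_j - sum_k y_k * dL/dy_k),
                // so each row only needs the dot product of the cached output with
                // the incoming gradient instead of the full Jacobian matrix.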
let compute_gradient =
|mut grad_row: ArrayViewMut1<f32>,
out_row: ArrayView1<f32>,
grad_out_row: ArrayView1<f32>| {
let dot: f32 = out_row
.iter()
.zip(grad_out_row.iter())
.map(|(&o, &g)| o * g)
.sum();
for j in 0..num_features {
grad_row[j] = out_row[j] * (grad_out_row[j] - dot);
if grad_row[j].is_nan() || grad_row[j].is_infinite() {
grad_row[j] = 0.0;
} else {
grad_row[j] = grad_row[j].clamp(-GRAD_CLIP_VALUE, GRAD_CLIP_VALUE);
}
}
};
if batch_size > SOFTMAX_PARALLEL_THRESHOLD {
Zip::from(grad_input_2d.axis_iter_mut(Axis(0)))
.and(output_2d.axis_iter(Axis(0)))
.and(grad_output_2d.axis_iter(Axis(0)))
.par_for_each(compute_gradient);
} else {
Zip::from(grad_input_2d.axis_iter_mut(Axis(0)))
.and(output_2d.axis_iter(Axis(0)))
.and(grad_output_2d.axis_iter(Axis(0)))
.for_each(compute_gradient);
}
let grad_input = grad_input_2d
.into_shape_with_order(shape)
.map_err(|e| {
ModelError::ProcessingError(format!(
"Failed to reshape grad_input back: {}",
e
))
})?
.into_dyn();
Ok(grad_input)
}
_ => Err(ModelError::ProcessingError(
"Forward pass has not been run yet".to_string(),
)),
}
}
fn layer_type(&self) -> &str {
"Softmax"
}
fn output_shape(&self) -> String {
format_output_shape(&self.input_cache)
}
no_trainable_parameters_layer_functions!();
}
impl ActivationLayer for Softmax {}
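
// A minimal test sketch for the behaviour above. It assumes `Tensor` is an
// alias for `ndarray::ArrayD<f32>` (suggested by the `into_dyn()` calls in this
// file); adjust the constructors if the alias differs.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn forward_rows_sum_to_one() {
        let mut layer = Softmax::new();
        let input = arr2(&[[1.0_f32, 2.0, 3.0], [0.0, 0.0, 0.0]]).into_dyn();
        let output = layer.forward(&input).unwrap();
        // Each row of a softmax output is a probability distribution.
        let row_sums = output.sum_axis(Axis(output.ndim() - 1));
        for &sum in row_sums.iter() {
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn backward_of_uniform_gradient_is_zero() {
        let mut layer = Softmax::new();
        let input = arr2(&[[0.5_f32, -1.0, 2.0]]).into_dyn();
        layer.forward(&input).unwrap();
        // With grad_output = 1 everywhere, dot(y, grad_output) = 1, so
        // y * (grad_output - dot) vanishes.
        let grad_output = ndarray::ArrayD::<f32>::ones(input.raw_dim());
        let grad_input = layer.backward(&grad_output).unwrap();
        for &g in grad_input.iter() {
            assert!(g.abs() < 1e-5);
        }
    }
}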