use crate::activations::Activation;
use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::fmt::Debug;
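
/// Softmax activation function.
///
/// Exponentiates the input and normalizes it along the configured axis so
/// that the values along that axis are non-negative and sum to one.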
#[derive(Debug, Clone, Copy)]
pub struct Softmax {
    axis: usize,
}

impl Softmax {
    /// Creates a new `Softmax` that normalizes along the given axis.
    pub fn new(axis: usize) -> Self {
        Self { axis }
    }
}

impl Default for Softmax {
    fn default() -> Self {
        Self::new(0)
    }
}

impl<F: Float + Debug + NumAssign> Activation<F> for Softmax {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if input.ndim() <= self.axis {
            return Err(NeuralError::InferenceError(format!(
                "Softmax axis {} is out of bounds for input with {} dimensions",
                self.axis,
                input.ndim()
            )));
        }

        // Fast path for 1-D inputs: subtract the maximum before exponentiating
        // for numerical stability, then normalize by the sum.
        if input.ndim() == 1 && self.axis == 0 {
            let max_val = input.fold(F::neg_infinity(), |a, &b| a.max(b));
            let mut output = input.clone();
            for val in output.iter_mut() {
                *val = (*val - max_val).exp();
            }
            let sum = output.fold(F::zero(), |a, &b| a + b);
            for val in output.iter_mut() {
                *val /= sum;
            }
            return Ok(output);
        }

        // General case: `max_vals` and `sum_vals` have the softmax axis removed,
        // so each subview taken along that axis is combined with them element-wise.
        let max_vals = input.map_axis(Axis(self.axis), |view| {
            view.fold(F::neg_infinity(), |a, &b| a.max(b))
        });
        let mut output = input.clone();
        for mut out_subview in output.axis_iter_mut(Axis(self.axis)) {
            // Subtract the per-position maximum and exponentiate (numerical stability).
            out_subview.zip_mut_with(&max_vals, |val, &max_val| {
                *val = (*val - max_val).exp();
            });
        }
        let sum_vals =
            output.map_axis(Axis(self.axis), |view| view.fold(F::zero(), |a, &b| a + b));
        for mut out_subview in output.axis_iter_mut(Axis(self.axis)) {
            // Normalize so the values along the softmax axis sum to one.
            out_subview.zip_mut_with(&sum_vals, |val, &sum_val| {
                *val /= sum_val;
            });
        }
        Ok(output)
    }

    fn backward(
        &self,
        grad_output: &Array<F, IxDyn>,
        output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // For softmax output s and upstream gradient g:
        // grad_input_i = s_i * (g_i - sum_j g_j * s_j)
        if output.ndim() == 1 && self.axis == 0 {
            let dot_product = grad_output
                .iter()
                .zip(output.iter())
                .map(|(&g, &s)| g * s)
                .fold(F::zero(), |a, b| a + b);
            let grad_input = output
                .iter()
                .zip(grad_output.iter())
                .map(|(&s, &g)| s * (g - dot_product))
                .collect::<Vec<_>>();
            return Ok(Array::from_vec(grad_input).into_dyn());
        }

        // General case: compute sum_j g_j * s_j along the softmax axis and
        // broadcast it back over the output shape.
        let weighted_sum = (grad_output * output).sum_axis(Axis(self.axis));
        let mut sum_shape = output.shape().to_vec();
        sum_shape[self.axis] = 1;
        let weighted_sum_reshaped = weighted_sum.into_shape_with_order(sum_shape)?;
        let weighted_sum_broadcast = weighted_sum_reshaped
            .broadcast(output.shape())
            .ok_or_else(|| {
                NeuralError::InferenceError(
                    "Failed to broadcast weighted sum to the output shape in softmax backward"
                        .to_string(),
                )
            })?;
        let grad_input = output * (grad_output - &weighted_sum_broadcast);
        Ok(grad_input)
    }
}
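
// Layer adapter: exposes the softmax activation as a parameter-free layer.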
impl<F: Float + Debug + ScalarOperand + NumAssign> Layer<F> for Softmax {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        <Self as Activation<F>>::forward(self, input)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Softmax caches no state, so recompute the forward output before
        // applying the activation's backward pass.
        let output = <Self as Activation<F>>::forward(self, input)?;
        <Self as Activation<F>>::backward(self, grad_output, &output)
    }

    fn update(&mut self, _learning_rate: F) -> Result<()> {
        // Softmax has no trainable parameters, so there is nothing to update.
        Ok(())
    }
}
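
// A minimal usage sketch, written as a unit test. It assumes `f64` satisfies
// the `Float + Debug + NumAssign` bounds used above and relies only on the
// `Array` re-export already imported in this module; adapt it to the crate's
// actual test conventions as needed.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn softmax_forward_sums_to_one() {
        let softmax = Softmax::new(0);
        let input = Array::from_vec(vec![1.0_f64, 2.0, 3.0]).into_dyn();
        let output =
            <Softmax as Activation<f64>>::forward(&softmax, &input).expect("forward failed");

        // The probabilities along the softmax axis should sum to one.
        let sum: f64 = output.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Larger logits should map to larger probabilities.
        let probs: Vec<f64> = output.iter().copied().collect();
        assert!(probs[2] > probs[1] && probs[1] > probs[0]);
    }
}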