use crate::error::{NeuralError, Result};
use crate::losses::Loss;
use scirs2_core::ndarray::{Array, Zip};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Categorical cross-entropy loss `-Σ y_true · ln(y_pred)`.
///
/// `epsilon` is a small positive constant used to clamp predictions away
/// from 0 and 1 before taking logarithms, keeping the loss finite.
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    // Numerical-stability clamp applied to predictions before `ln()`.
    epsilon: f64,
}

impl Default for CrossEntropyLoss {
    /// Default clamp of `1e-10` — small enough not to bias well-formed
    /// probability inputs.
    fn default() -> Self {
        CrossEntropyLoss::new(1e-10)
    }
}

impl CrossEntropyLoss {
    /// Creates a loss with the given clamping constant `epsilon`.
    pub fn new(epsilon: f64) -> Self {
        CrossEntropyLoss { epsilon }
    }
}
impl<F: Float + Debug + NumAssign> Loss<F> for CrossEntropyLoss {
    /// Mean categorical cross-entropy `-Σ y_true · ln(y_pred)`.
    ///
    /// For 2-D input the leading axis is the batch axis and per-sample
    /// losses are averaged over it; 1-D input is treated as one sample
    /// (divisor 1). Predictions are clamped to `[epsilon, 1 - epsilon]`
    /// before the logarithm, and terms with `y_true <= 0` are skipped so
    /// zero targets contribute exactly zero.
    ///
    /// # Errors
    /// `NeuralError::InferenceError` when shapes differ, when `epsilon` or
    /// the batch size cannot be converted to `F`, or when the input has
    /// more than two dimensions.
    fn forward(
        &self,
        predictions: &Array<F, scirs2_core::ndarray::IxDyn>,
        targets: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<F> {
        if predictions.shape() != targets.shape() {
            return Err(NeuralError::InferenceError(format!(
                "Shape mismatch in CrossEntropy: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }
        // Robustness fix: the batched path below indexes with `[[i, j]]`,
        // which panics at runtime for arrays of rank 3 or higher. Surface
        // that as an error instead of a panic.
        if predictions.ndim() > 2 {
            return Err(NeuralError::InferenceError(format!(
                "CrossEntropy supports 1-D or 2-D inputs, got {} dimensions",
                predictions.ndim()
            )));
        }
        let epsilon = F::from(self.epsilon).ok_or_else(|| {
            NeuralError::InferenceError("Could not convert epsilon to float".to_string())
        })?;
        // Averaging divisor: number of rows for 2-D input, 1 for 1-D.
        let n = F::from(if predictions.ndim() > 1 {
            predictions.shape()[0]
        } else {
            1
        })
        .ok_or_else(|| {
            NeuralError::InferenceError("Could not convert batch size to float".to_string())
        })?;
        let mut loss = F::zero();
        if predictions.ndim() > 1 {
            for i in 0..predictions.shape()[0] {
                let mut sample_loss = F::zero();
                for j in 0..predictions.shape()[1] {
                    // Clamp so `ln()` is always finite.
                    let y_pred = predictions[[i, j]].max(epsilon).min(F::one() - epsilon);
                    let y_true = targets[[i, j]];
                    // Skip zero targets: lim_{t->0} t·ln(p) = 0.
                    if y_true > F::zero() {
                        sample_loss -= y_true * y_pred.ln();
                    }
                }
                loss += sample_loss;
            }
            loss /= n;
        } else {
            Zip::from(predictions).and(targets).for_each(|&p, &t| {
                let p_safe = p.max(epsilon).min(F::one() - epsilon);
                if t > F::zero() {
                    loss -= t * p_safe.ln();
                }
            });
        }
        Ok(loss)
    }

    /// Gradient of the loss with respect to `predictions`.
    ///
    /// Uses the fused softmax + cross-entropy form `(y_pred - y_true) / n`.
    /// NOTE(review): that expression is the gradient with respect to softmax
    /// *logits*; the gradient of plain cross-entropy w.r.t. probabilities
    /// would be `-y_true / y_pred`. Confirm callers feed softmax outputs.
    ///
    /// # Errors
    /// `NeuralError::InferenceError` when shapes differ or when `epsilon` /
    /// the batch size cannot be converted to `F`.
    fn backward(
        &self,
        predictions: &Array<F, scirs2_core::ndarray::IxDyn>,
        targets: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        if predictions.shape() != targets.shape() {
            return Err(NeuralError::InferenceError(format!(
                "Shape mismatch in CrossEntropy gradient: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }
        let epsilon = F::from(self.epsilon).ok_or_else(|| {
            NeuralError::InferenceError("Could not convert epsilon to float".to_string())
        })?;
        // Same averaging divisor as `forward` so gradient scale matches loss.
        let n = F::from(if predictions.ndim() > 1 {
            predictions.shape()[0]
        } else {
            1
        })
        .ok_or_else(|| {
            NeuralError::InferenceError("Could not convert batch size to float".to_string())
        })?;
        let mut gradients = predictions.clone();
        Zip::from(&mut gradients).and(targets).for_each(|grad, &t| {
            // Bug fix: the clamped prediction was previously computed and
            // then discarded (`_p_safe`), so the gradient used the raw
            // value. Apply the clamp so backward matches forward's numeric
            // treatment of predictions at exactly 0 or 1.
            let p_safe = grad.max(epsilon).min(F::one() - epsilon);
            *grad = (p_safe - t) / n;
        });
        Ok(gradients)
    }
}