use serde::{Deserialize, Serialize};
use typetag;
use crate::{classification::ClassificationEvaluator, error::NetworkError, matrix::DMat, MetricEvaluator, Metrics};
use super::{LossFunction, LossFunctionClone};
/// Serializable cross-entropy loss function; constructed via the
/// [`CrossEntropy`] builder rather than directly.
#[derive(Serialize, Deserialize, Clone)]
struct CrossEntropyLoss {
    // Predictions are clamped into [epsilon, 1 - epsilon] before ln()
    // so the logarithm never sees 0 or 1.
    epsilon: f32,
}
/// Builder for the cross-entropy loss function.
pub struct CrossEntropy {
    // Clamping bound for predicted probabilities; must lie strictly in (0, 1)
    // (enforced by `validate` at build time).
    epsilon: f32,
}
impl CrossEntropy {
    /// Creates a builder with the default epsilon (`f32::EPSILON`).
    fn new() -> Self {
        Self {
            epsilon: f32::EPSILON,
        }
    }

    /// Overrides the clamping epsilon; returns the builder for chaining.
    pub fn epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Ensures the configured epsilon lies strictly inside (0, 1).
    ///
    /// # Errors
    /// Returns [`NetworkError::ConfigError`] when epsilon is out of range.
    fn validate(&self) -> Result<(), NetworkError> {
        // Guard form preserves the original comparison semantics
        // (a NaN epsilon falls through to Ok, as before).
        match self.epsilon {
            e if e <= 0.0 || e >= 1.0 => Err(NetworkError::ConfigError(format!(
                "Epsilon for CrossEntropy must be in the range (0, 1), but was {}",
                e
            ))),
            _ => Ok(()),
        }
    }

    /// Validates the configuration and produces the boxed loss function.
    ///
    /// # Errors
    /// Propagates the configuration error from [`Self::validate`].
    pub fn build(self) -> Result<Box<dyn LossFunction>, NetworkError> {
        self.validate()?;
        Ok(Box::new(CrossEntropyLoss {
            epsilon: self.epsilon,
        }))
    }
}
impl Default for CrossEntropy {
fn default() -> Self {
Self::new()
}
}
impl LossFunctionClone for CrossEntropyLoss {
    /// Produces a fresh boxed copy of this loss as a trait object.
    fn clone_box(&self) -> Box<dyn LossFunction> {
        let duplicate = self.clone();
        Box::new(duplicate)
    }
}
#[typetag::serde]
impl LossFunction for CrossEntropyLoss {
    /// Cross-entropy loss: the sum of `-t * ln(p)` over all entries,
    /// averaged over the rows of `predicted`.
    ///
    /// Each prediction is clamped into `[epsilon, 1 - epsilon]` first so
    /// `ln` never receives 0 (which would yield -inf).
    fn forward(&self, predicted: &DMat, target: &DMat) -> f32 {
        let (rows, cols) = (predicted.rows(), predicted.cols());
        let mut sum = 0.0;
        for r in 0..rows {
            for c in 0..cols {
                // Same clamp order as before: lower bound first, then upper.
                let p = predicted
                    .at(r, c)
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                sum -= target.at(r, c) * p.ln();
            }
        }
        sum / rows as f32
    }

    /// Gradient of the loss w.r.t. the predictions: `p - t` per entry.
    // NOTE(review): `p - t` is the combined softmax + cross-entropy gradient —
    // this presumes `predicted` already passed through softmax; confirm at the
    // call site. The gradient is not divided by the row count, unlike `forward`.
    fn backward(&self, predicted: &DMat, target: &DMat) -> DMat {
        let mut gradient = DMat::zeros(predicted.rows(), predicted.cols());
        gradient.apply_with_indices(|r, c, cell| {
            *cell = predicted.at(r, c) - target.at(r, c);
        });
        gradient
    }

    /// Delegates metric computation to the shared classification evaluator.
    fn calculate_metrics(&self, targets: &DMat, predictions: &DMat) -> Metrics {
        ClassificationEvaluator.evaluate(targets, predictions)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{common::matrix::DMat, util};

    // Builds the loss under test with a small, valid epsilon.
    fn make_loss() -> Box<dyn LossFunction> {
        CrossEntropy::new().epsilon(1e-7).build().unwrap()
    }

    #[test]
    fn test_forward() {
        let predictions = DMat::new(2, 2, &[0.9, 0.1, 0.2, 0.8]);
        let targets = DMat::new(2, 2, &[1.0, 0.0, 0.0, 1.0]);
        let loss_value = make_loss().forward(&predictions, &targets);
        assert!((loss_value - 0.164_252_03).abs() < 1e-6);
    }

    #[test]
    fn test_backward() {
        let predictions = DMat::new(2, 2, &[0.9, 0.1, 0.2, 0.8]);
        let targets = DMat::new(2, 2, &[1.0, 0.0, 0.0, 1.0]);
        let expected = DMat::new(2, 2, &[-0.1, 0.1, 0.2, -0.2]);
        let gradient = make_loss().backward(&predictions, &targets);
        assert!(util::equal_approx(&gradient, &expected, 1e-6));
    }

    #[test]
    fn test_crossentropy_validate() {
        // Values strictly inside (0, 1) pass; the boundaries are rejected.
        assert!(CrossEntropy::new().epsilon(1e-7).validate().is_ok());
        assert!(CrossEntropy::new().epsilon(0.0).validate().is_err());
        assert!(CrossEntropy::new().epsilon(1.0).validate().is_err());
    }
}