pub use super::{Loss, LossType};
pub use irithyll_core::loss::softmax::*;
#[cfg(test)]
mod tests {
    use super::*;

    /// Tight tolerance for quantities with an exact closed form
    /// (e.g. sigmoid(0) = 0.5, hessian at 0 = 0.25).
    const EPS: f64 = 1e-10;
    /// Looser tolerance for values computed through `exp`/`ln`,
    /// where a few ulps of rounding accumulate.
    const LOOSE_EPS: f64 = 1e-8;

    /// Central-difference check that `gradient(target, pred)` is the
    /// derivative of `loss(target, pred)` with respect to `pred`.
    ///
    /// Uses step `h = 1e-7` and a `1e-5` tolerance, which comfortably
    /// covers the O(h^2) truncation error of the symmetric difference.
    fn assert_gradient_matches_numeric(loss: &SoftmaxLoss, target: f64, pred: f64) {
        let h = 1e-7;
        let numerical = (loss.loss(target, pred + h) - loss.loss(target, pred - h)) / (2.0 * h);
        let analytical = loss.gradient(target, pred);
        assert!(
            (numerical - analytical).abs() < 1e-5,
            "numerical={numerical}, analytical={analytical}"
        );
    }

    #[test]
    fn test_n_outputs() {
        let loss = SoftmaxLoss { n_classes: 5 };
        assert_eq!(loss.n_outputs(), 5);
    }

    #[test]
    fn test_n_outputs_binary() {
        let loss = SoftmaxLoss { n_classes: 2 };
        assert_eq!(loss.n_outputs(), 2);
    }

    // Gradient convention under test: grad = p - target, with p = sigmoid(pred),
    // so at pred = 0 (p = 0.5) the gradient is -0.5 for the true class
    // and +0.5 for a wrong class.

    #[test]
    fn test_gradient_correct_class() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let g = loss.gradient(1.0, 0.0);
        assert!((g - (-0.5)).abs() < EPS);
    }

    #[test]
    fn test_gradient_wrong_class() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let g = loss.gradient(0.0, 0.0);
        assert!((g - 0.5).abs() < EPS);
    }

    #[test]
    fn test_gradient_confident_correct() {
        // A confidently correct prediction should have a small negative gradient
        // (p close to 1, so p - 1 is slightly below zero).
        let loss = SoftmaxLoss { n_classes: 3 };
        let g = loss.gradient(1.0, 5.0);
        assert!(g < 0.0);
        assert!(g > -0.01);
    }

    #[test]
    fn test_hessian_positive() {
        // The hessian p(1 - p) must stay strictly positive everywhere,
        // including far into the saturated regions, or Newton-style
        // boosting steps would divide by zero.
        let loss = SoftmaxLoss { n_classes: 3 };
        assert!(loss.hessian(0.0, 0.0) > 0.0);
        assert!(loss.hessian(1.0, 5.0) > 0.0);
        assert!(loss.hessian(0.0, -5.0) > 0.0);
        assert!(loss.hessian(1.0, 100.0) > 0.0);
    }

    #[test]
    fn test_hessian_max_at_zero() {
        // p(1 - p) is maximized at p = 0.5, i.e. pred = 0, where it equals 0.25.
        let loss = SoftmaxLoss { n_classes: 3 };
        let h_zero = loss.hessian(0.0, 0.0);
        let h_large = loss.hessian(0.0, 5.0);
        assert!((h_zero - 0.25).abs() < EPS);
        assert!(h_large < h_zero);
    }

    #[test]
    fn test_loss_value_at_zero() {
        // At pred = 0 both classes have probability 0.5, so the
        // negative log-likelihood is ln(2) regardless of the target.
        let loss = SoftmaxLoss { n_classes: 3 };
        let l1 = loss.loss(1.0, 0.0);
        let l0 = loss.loss(0.0, 0.0);
        let ln2 = 2.0_f64.ln();
        assert!((l1 - ln2).abs() < LOOSE_EPS);
        assert!((l0 - ln2).abs() < LOOSE_EPS);
    }

    #[test]
    fn test_loss_decreases_with_correct_prediction() {
        let loss = SoftmaxLoss { n_classes: 3 };
        let l_zero = loss.loss(1.0, 0.0);
        let l_positive = loss.loss(1.0, 3.0);
        assert!(l_positive < l_zero);
    }

    #[test]
    fn test_predict_transform_is_sigmoid() {
        // sigmoid(0) = 0.5 exactly; large |pred| saturates toward 0/1.
        let loss = SoftmaxLoss { n_classes: 3 };
        assert!((loss.predict_transform(0.0) - 0.5).abs() < EPS);
        assert!(loss.predict_transform(10.0) > 0.99);
        assert!(loss.predict_transform(-10.0) < 0.01);
    }

    #[test]
    fn test_initial_prediction_is_zero() {
        // The initial raw score is zero for any target distribution,
        // including the empty slice.
        let loss = SoftmaxLoss { n_classes: 3 };
        let targets = [0.0, 1.0, 2.0, 1.0, 0.0];
        assert!((loss.initial_prediction(&targets)).abs() < EPS);
        assert!((loss.initial_prediction(&[])).abs() < EPS);
    }

    #[test]
    fn test_gradient_is_derivative_of_loss() {
        // Check the analytical gradient against a numerical derivative
        // on both sides of zero and for both target classes.
        let loss = SoftmaxLoss { n_classes: 3 };
        assert_gradient_matches_numeric(&loss, 1.0, 1.5);
        assert_gradient_matches_numeric(&loss, 0.0, -0.5);
    }
}