pub use super::{Loss, LossType};
pub use irithyll_core::loss::quantile::*;
#[cfg(test)]
mod tests {
    //! Unit tests for the re-exported `QuantileLoss` (pinball loss):
    //! gradient/hessian conventions, loss values, initial prediction,
    //! `LossType` reporting, and `tau` validation.

    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const EPS: f64 = 1e-12;

    /// Returns `true` when `actual` is strictly within [`EPS`] of `expected`.
    fn near(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Quantile regression produces a single scalar output.
    #[test]
    fn test_n_outputs() {
        let loss = QuantileLoss::new(0.5);
        assert_eq!(loss.n_outputs(), 1);
    }

    /// Prediction above target: the expected gradient is `1 - tau` (0.1 for tau = 0.9),
    /// regardless of how large the gap is.
    #[test]
    fn test_gradient_over_predict() {
        let loss = QuantileLoss::new(0.9);
        for (target, pred) in [(1.0, 3.0), (0.0, 100.0)] {
            assert!(near(loss.gradient(target, pred), 0.1));
        }
    }

    /// Prediction below target: the expected gradient is `-tau` (-0.9 for tau = 0.9),
    /// regardless of how large the gap is.
    #[test]
    fn test_gradient_under_predict() {
        let loss = QuantileLoss::new(0.9);
        for (target, pred) in [(3.0, 1.0), (100.0, 0.0)] {
            assert!(near(loss.gradient(target, pred), -0.9));
        }
    }

    /// At an exact hit with tau = 0.5 the gradient is expected to be 0.5.
    #[test]
    fn test_gradient_at_exact() {
        let loss = QuantileLoss::new(0.5);
        assert!(near(loss.gradient(5.0, 5.0), 0.5));
    }

    /// The hessian is expected to be the constant 1.0 for any (target, pred) pair.
    #[test]
    fn test_hessian_is_one() {
        let loss = QuantileLoss::new(0.9);
        for (target, pred) in [(0.0, 0.0), (100.0, -50.0), (-7.0, 42.0)] {
            assert!(near(loss.hessian(target, pred), 1.0));
        }
    }

    /// Pinball loss: `tau * gap` when under-predicting, `(1 - tau) * gap` when
    /// over-predicting, and zero on an exact hit.
    #[test]
    fn test_loss_pinball() {
        let loss = QuantileLoss::new(0.9);
        assert!(near(loss.loss(5.0, 3.0), 0.9 * 2.0));
        assert!(near(loss.loss(3.0, 5.0), 0.1 * 2.0));
        assert!(near(loss.loss(4.0, 4.0), 0.0));
    }

    /// With tau = 0.5 both sides weigh equally: |5 - 3| / 2 = 1.0 either way.
    #[test]
    fn test_median_loss_is_half_mae() {
        let loss = QuantileLoss::new(0.5);
        assert!(near(loss.loss(5.0, 3.0), 1.0));
        assert!(near(loss.loss(3.0, 5.0), 1.0));
    }

    /// Raw model output is returned unchanged.
    #[test]
    fn test_predict_transform_is_identity() {
        let loss = QuantileLoss::new(0.5);
        assert!(near(loss.predict_transform(42.0), 42.0));
    }

    /// On targets 1..=5 the initial prediction is expected to be 3.0 for tau = 0.5
    /// and 5.0 for tau = 0.9.
    #[test]
    fn test_initial_prediction_is_quantile() {
        let targets = [1.0, 2.0, 3.0, 4.0, 5.0];
        for (tau, expected) in [(0.5, 3.0), (0.9, 5.0)] {
            let loss = QuantileLoss::new(tau);
            assert!(near(loss.initial_prediction(&targets), expected));
        }
    }

    /// An empty target slice is expected to yield an initial prediction of 0.0.
    #[test]
    fn test_initial_prediction_empty() {
        let loss = QuantileLoss::new(0.5);
        assert!(near(loss.initial_prediction(&[]), 0.0));
    }

    /// `loss_type` reports the `Quantile` variant carrying the configured tau.
    #[test]
    fn test_loss_type_returns_some() {
        let loss = QuantileLoss::new(0.75);
        match loss.loss_type() {
            Some(LossType::Quantile { tau }) => assert!(near(tau, 0.75)),
            other => panic!("expected Quantile, got {other:?}"),
        }
    }

    /// The analytical gradient matches a central finite difference of the loss
    /// on both the over- and under-prediction branches.
    #[test]
    fn test_gradient_is_subderivative_of_loss() {
        let loss = QuantileLoss::new(0.75);
        let target = 2.5;
        let h = 1e-6;
        // Central difference of the loss with respect to the prediction `p`.
        let central_diff =
            |p: f64| (loss.loss(target, p + h) - loss.loss(target, p - h)) / (2.0 * h);

        // Over-prediction branch (pred > target).
        let numerical = central_diff(4.0);
        let analytical = loss.gradient(target, 4.0);
        assert!(
            (numerical - analytical).abs() < 1e-4,
            "over: numerical={numerical}, analytical={analytical}"
        );

        // Under-prediction branch (pred < target).
        let numerical2 = central_diff(1.0);
        let analytical2 = loss.gradient(target, 1.0);
        assert!(
            (numerical2 - analytical2).abs() < 1e-4,
            "under: numerical={numerical2}, analytical={analytical2}"
        );
    }

    /// tau = 0 lies outside the open interval (0, 1) and must be rejected.
    #[test]
    #[should_panic(expected = "tau must be in (0, 1)")]
    fn test_invalid_tau_zero() {
        let _ = QuantileLoss::new(0.0);
    }

    /// tau = 1 lies outside the open interval (0, 1) and must be rejected.
    #[test]
    #[should_panic(expected = "tau must be in (0, 1)")]
    fn test_invalid_tau_one() {
        let _ = QuantileLoss::new(1.0);
    }
}