use crate::core::CoreError;
use crate::error::ApexSolverResult;
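/// A robust loss function applied to a squared residual.
///
/// `evaluate` takes the squared residual `s = r * r` and returns
/// `[rho(s), rho'(s), rho''(s)]`: the loss value together with its first and
/// second derivatives in `s`. The first derivative is what gets used as the
/// per-residual weight when reweighting (see the `w_*` values in the tests below).
///
/// A minimal usage sketch, marked `ignore` because the exact import path of
/// these types depends on how the crate re-exports this module:
///
/// ```ignore
/// let loss = HuberLoss::new(1.345)?;
/// let residual: f64 = 3.0;
/// let [rho, weight, _curvature] = loss.evaluate(residual * residual);
/// assert!(weight < 1.0); // large residuals are down-weighted
/// ```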
pub trait LossFunction: Send + Sync {
fn evaluate(&self, s: f64) -> [f64; 3];
}
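/// Standard squared-error (L2) loss: `rho(s) = s`.
///
/// Applies no robustification; every residual keeps full weight.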
#[derive(Debug, Clone, Copy)]
pub struct L2Loss;
impl L2Loss {
pub fn new() -> Self {
L2Loss
}
}
impl Default for L2Loss {
fn default() -> Self {
Self::new()
}
}
impl LossFunction for L2Loss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
[s, 1.0, 0.0]
}
}
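/// Absolute-value (L1) loss on the residual magnitude: `rho(s) = 2 * sqrt(s)`.
///
/// Falls back to the L2 values near `s = 0` to avoid the singular derivative
/// at the origin.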
#[derive(Debug, Clone, Copy)]
pub struct L1Loss;
impl L1Loss {
pub fn new() -> Self {
L1Loss
}
}
impl Default for L1Loss {
fn default() -> Self {
Self::new()
}
}
impl LossFunction for L1Loss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
if s < f64::EPSILON {
return [s, 1.0, 0.0];
}
let sqrt_s = s.sqrt();
[2.0 * sqrt_s, 1.0 / sqrt_s, -1.0 / (2.0 * s * sqrt_s)]
}
}
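/// Huber loss: quadratic for `s <= scale^2`, linear in the residual magnitude
/// beyond that, which caps the influence of outliers.
///
/// `scale` marks the transition point; 1.345 is the constant commonly quoted
/// for 95% asymptotic efficiency on Gaussian residuals.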
#[derive(Debug, Clone)]
pub struct HuberLoss {
scale: f64,
scale2: f64,
}
impl HuberLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
Ok(HuberLoss {
scale,
scale2: scale * scale,
})
}
}
impl LossFunction for HuberLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
if s > self.scale2 {
let r = s.sqrt();
// Clamp to the smallest positive float so the weight never underflows to zero.
let rho1 = (self.scale / r).max(f64::MIN_POSITIVE);
[2.0 * self.scale * r - self.scale2, rho1, -rho1 / (2.0 * s)]
} else {
[s, 1.0, 0.0]
}
}
}
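/// Cauchy (Lorentzian) loss: grows only logarithmically in `s`, with weights
/// that decay as `1 / (1 + s / scale^2)`.
///
/// 2.3849 is the tuning constant commonly quoted for 95% efficiency.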
#[derive(Debug, Clone)]
pub struct CauchyLoss {
scale2: f64,
c: f64,
}
impl CauchyLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
let scale2 = scale * scale;
Ok(CauchyLoss {
scale2,
c: 1.0 / scale2,
})
}
}
impl LossFunction for CauchyLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
let sum = 1.0 + s * self.c;
let inv = 1.0 / sum;
// Clamp to the smallest positive float so the weight never underflows to zero.
[self.scale2 * sum.ln() / 2.0, inv.max(f64::MIN_POSITIVE), -self.c * (inv * inv)]
}
}
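/// Fair loss: convex, everywhere-differentiable loss whose weights decay
/// smoothly with the residual magnitude, giving a softer transition than
/// Huber. A commonly used tuning constant is about 1.4.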
#[derive(Debug, Clone)]
pub struct FairLoss {
scale: f64,
}
impl FairLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
Ok(FairLoss { scale })
}
}
impl LossFunction for FairLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
if s < f64::EPSILON {
return [s, 1.0, 0.0];
}
let x = s.sqrt();
let abs_x = x.abs();
let c_plus_x = self.scale + abs_x;
let rho = self.scale * self.scale * (abs_x / self.scale - (1.0 + abs_x / self.scale).ln());
let rho_prime = 0.5 / c_plus_x;
let rho_double_prime = -1.0 / (4.0 * s * c_plus_x * c_plus_x);
[rho, rho_prime, rho_double_prime]
}
}
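/// Geman-McClure loss: `rho(s) = s / (1 + s / scale^2)`.
///
/// Saturates at `scale^2`, so the influence of gross outliers redescends
/// towards zero.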
#[derive(Debug, Clone)]
pub struct GemanMcClureLoss {
c: f64,
}
impl GemanMcClureLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
let scale2 = scale * scale;
Ok(GemanMcClureLoss { c: 1.0 / scale2 })
}
}
impl LossFunction for GemanMcClureLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
let denom = 1.0 + s * self.c;
let inv = 1.0 / denom;
let inv2 = inv * inv;
[s * inv, inv2, -2.0 * self.c * inv2 * inv]
}
}
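/// Welsch (Leclerc) loss: `rho(s) = (scale^2 / 2) * (1 - exp(-s / scale^2))`,
/// with exponentially decaying weights. 2.9846 is the commonly quoted
/// 95%-efficiency tuning constant.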
#[derive(Debug, Clone)]
pub struct WelschLoss {
scale2: f64,
inv_scale2: f64,
}
impl WelschLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
let scale2 = scale * scale;
Ok(WelschLoss {
scale2,
inv_scale2: 1.0 / scale2,
})
}
}
impl LossFunction for WelschLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
let exp_term = (-s * self.inv_scale2).exp();
[
(self.scale2 / 2.0) * (1.0 - exp_term),
0.5 * exp_term,
-0.5 * self.inv_scale2 * exp_term,
]
}
}
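/// Tukey biweight (bisquare) loss: residuals with magnitude above `scale`
/// receive exactly zero weight, so gross outliers are rejected outright.
/// 4.6851 is the commonly quoted 95%-efficiency tuning constant.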
#[derive(Debug, Clone)]
pub struct TukeyBiweightLoss {
scale: f64,
scale2: f64,
}
impl TukeyBiweightLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
Ok(TukeyBiweightLoss {
scale,
scale2: scale * scale,
})
}
}
impl LossFunction for TukeyBiweightLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
let x = s.sqrt();
if x > self.scale {
[self.scale2 / 6.0, 0.0, 0.0]
} else {
let ratio = x / self.scale;
let ratio2 = ratio * ratio;
let one_minus_ratio2 = 1.0 - ratio2;
let one_minus_ratio2_sq = one_minus_ratio2 * one_minus_ratio2;
[
(self.scale2 / 6.0) * (1.0 - one_minus_ratio2 * one_minus_ratio2_sq),
0.5 * one_minus_ratio2_sq,
-(ratio / self.scale2) * one_minus_ratio2,
]
}
}
}
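/// Andrews sine-wave loss: redescending loss based on `1 - cos(|r| / scale)`,
/// with zero weight once the residual magnitude exceeds `pi * scale`.
/// 1.339 is the commonly quoted tuning constant.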
#[derive(Debug, Clone)]
pub struct AndrewsWaveLoss {
scale: f64,
scale2: f64,
threshold: f64,
}
impl AndrewsWaveLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
Ok(AndrewsWaveLoss {
scale,
scale2: scale * scale,
threshold: std::f64::consts::PI * scale,
})
}
}
impl LossFunction for AndrewsWaveLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
let x = s.sqrt();
if x > self.threshold {
[2.0 * self.scale2, 0.0, 0.0]
} else {
let arg = x / self.scale;
let sin_val = arg.sin();
let cos_val = arg.cos();
[
self.scale2 * (1.0 - cos_val),
0.5 * sin_val,
(0.25 / self.scale) * cos_val / x.max(f64::EPSILON),
]
}
}
}
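/// Ramsay's Ea loss: redescending loss with weights that decay like
/// `exp(-scale * |r|)`. Unlike the other losses here, a larger `scale` makes
/// the down-weighting more aggressive; 0.3 is a commonly used constant.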
#[derive(Debug, Clone)]
pub struct RamsayEaLoss {
scale: f64,
inv_scale2: f64,
}
impl RamsayEaLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
Ok(RamsayEaLoss {
scale,
inv_scale2: 1.0 / (scale * scale),
})
}
}
impl LossFunction for RamsayEaLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
let x = s.sqrt();
let ax = self.scale * x;
let exp_term = (-ax).exp();
let rho = self.inv_scale2 * (1.0 - exp_term * (1.0 + ax));
let rho_prime = 0.5 * exp_term;
let rho_double_prime = -(self.scale / (4.0 * x.max(f64::EPSILON))) * exp_term;
[rho, rho_prime, rho_double_prime]
}
}
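/// Trimmed (truncated) quadratic loss: `rho(s) = s / 2` for `s <= scale^2`
/// and constant beyond that, so residuals larger than `scale` contribute
/// nothing to the gradient.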
#[derive(Debug, Clone)]
pub struct TrimmedMeanLoss {
scale2: f64,
}
impl TrimmedMeanLoss {
pub fn new(scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(
CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
);
}
Ok(TrimmedMeanLoss {
scale2: scale * scale,
})
}
}
impl LossFunction for TrimmedMeanLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
if s <= self.scale2 {
[s / 2.0, 0.5, 0.0]
} else {
[self.scale2 / 2.0, 0.0, 0.0]
}
}
}
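/// Lp-norm loss: `rho(s) = s^(p / 2) = |r|^p` for `p > 0`.
///
/// `p = 2` reproduces the squared loss; values below 2 down-weight large
/// residuals, at the cost of non-convexity for `p < 1`.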
#[derive(Debug, Clone)]
pub struct LpNormLoss {
p: f64,
}
impl LpNormLoss {
pub fn new(p: f64) -> ApexSolverResult<Self> {
if p <= 0.0 {
return Err(CoreError::InvalidInput("p must be positive".to_string()).into());
}
Ok(LpNormLoss { p })
}
}
impl LossFunction for LpNormLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
if s < f64::EPSILON {
return [s, 1.0, 0.0];
}
let exp_rho = self.p / 2.0;
let exp_rho_prime = exp_rho - 1.0;
let exp_rho_double_prime = exp_rho_prime - 1.0;
[
s.powf(exp_rho),
exp_rho * s.powf(exp_rho_prime),
exp_rho * exp_rho_prime * s.powf(exp_rho_double_prime),
]
}
}
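/// General adaptive robust loss in the style of Barron ("A General and
/// Adaptive Robust Loss Function", CVPR 2019), parameterized by a shape
/// `alpha` and a `scale`.
///
/// The special cases `alpha ≈ 0` (Cauchy-like) and `alpha ≈ 2` (squared loss)
/// are handled explicitly; negative `alpha` values give increasingly
/// aggressive redescending behaviour. A construction sketch (ignored doctest;
/// the import path depends on the crate layout):
///
/// ```ignore
/// let cauchy_like = BarronGeneralLoss::new(0.0, 1.0)?;
/// let quadratic = BarronGeneralLoss::new(2.0, 1.0)?;
/// ```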
#[derive(Debug, Clone)]
pub struct BarronGeneralLoss {
alpha: f64,
scale: f64,
scale2: f64,
}
impl BarronGeneralLoss {
pub fn new(alpha: f64, scale: f64) -> ApexSolverResult<Self> {
if scale <= 0.0 {
return Err(CoreError::InvalidInput("scale must be positive".to_string()).into());
}
Ok(BarronGeneralLoss {
alpha,
scale,
scale2: scale * scale,
})
}
}
impl LossFunction for BarronGeneralLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
if self.alpha.abs() < 1e-6 {
let denom = 1.0 + s / self.scale2;
let inv = 1.0 / denom;
return [
(self.scale2 / 2.0) * denom.ln(),
inv.max(f64::MIN_POSITIVE),
-inv * inv / self.scale2,
];
}
if (self.alpha - 2.0).abs() < 1e-6 {
return [s, 1.0, 0.0];
}
let x = s.sqrt();
let normalized = x / self.scale;
let normalized2 = normalized * normalized;
let inner = self.alpha.abs() / 2.0 * normalized2 + 1.0;
let power = inner.powf(self.alpha / 2.0);
let rho = (self.alpha.abs() / self.scale2) * (power - 1.0);
let rho_prime = 0.5 * inner.powf(self.alpha / 2.0 - 1.0);
let rho_double_prime =
(self.alpha - 2.0) / (4.0 * self.scale2) * inner.powf(self.alpha / 2.0 - 2.0);
[rho, rho_prime, rho_double_prime]
}
}
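/// Student's t-distribution loss: `rho(s) = ((nu + 1) / 2) * ln(1 + s / nu)`,
/// where `nu` is the degrees of freedom. Smaller `nu` means heavier tails and
/// stronger down-weighting of outliers.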
#[derive(Debug, Clone)]
pub struct TDistributionLoss {
nu: f64,
half_nu_plus_1: f64,
}
impl TDistributionLoss {
pub fn new(nu: f64) -> ApexSolverResult<Self> {
if nu <= 0.0 {
return Err(
CoreError::InvalidInput("degrees of freedom must be positive".to_string()).into(),
);
}
Ok(TDistributionLoss {
nu,
half_nu_plus_1: (nu + 1.0) / 2.0,
})
}
}
impl LossFunction for TDistributionLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
let inner = 1.0 + s / self.nu;
let rho = self.half_nu_plus_1 * inner.ln();
let denom = self.nu + s;
let rho_prime = self.half_nu_plus_1 / denom;
let rho_double_prime = -self.half_nu_plus_1 / (denom * denom);
[rho, rho_prime, rho_double_prime]
}
}
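/// Thin wrapper around [`BarronGeneralLoss`] that forwards `evaluate`.
///
/// `Default` constructs the `alpha = 0.0`, `scale = 1.0` (Cauchy-like) case.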
#[derive(Debug, Clone)]
pub struct AdaptiveBarronLoss {
inner: BarronGeneralLoss,
}
impl AdaptiveBarronLoss {
pub fn new(alpha: f64, scale: f64) -> ApexSolverResult<Self> {
Ok(AdaptiveBarronLoss {
inner: BarronGeneralLoss::new(alpha, scale)?,
})
}
const fn new_default() -> Self {
AdaptiveBarronLoss {
inner: BarronGeneralLoss {
alpha: 0.0,
scale: 1.0,
scale2: 1.0,
},
}
}
}
impl LossFunction for AdaptiveBarronLoss {
#[inline]
fn evaluate(&self, s: f64) -> [f64; 3] {
self.inner.evaluate(s)
}
}
impl Default for AdaptiveBarronLoss {
fn default() -> Self {
Self::new_default()
}
}
#[cfg(test)]
mod tests {
use super::*;
type TestResult = Result<(), Box<dyn std::error::Error>>;
const EPSILON: f64 = 1e-6;
fn numerical_derivative(loss: &dyn LossFunction, s: f64, h: f64) -> (f64, f64) {
let [rho_plus, _, _] = loss.evaluate(s + h);
let [rho_minus, _, _] = loss.evaluate(s - h);
let [rho, _, _] = loss.evaluate(s);
let rho_prime_numerical = (rho_plus - rho_minus) / (2.0 * h);
let rho_double_prime_numerical = (rho_plus - 2.0 * rho + rho_minus) / (h * h);
(rho_prime_numerical, rho_double_prime_numerical)
}
#[test]
fn test_l2_loss() -> TestResult {
let loss = L2Loss;
let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert_eq!(rho_prime, 1.0);
assert_eq!(rho_double_prime, 0.0);
let [rho, rho_prime, rho_double_prime] = loss.evaluate(4.0);
assert_eq!(rho, 4.0);
assert_eq!(rho_prime, 1.0);
assert_eq!(rho_double_prime, 0.0);
Ok(())
}
#[test]
fn test_l1_loss() -> TestResult {
let loss = L1Loss;
let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert!(rho_prime.is_finite());
assert!(rho_double_prime.is_finite());
let [rho, rho_prime, _] = loss.evaluate(4.0);
assert!((rho - 4.0).abs() < EPSILON);
assert!((rho_prime - 0.5).abs() < EPSILON);
Ok(())
}
#[test]
fn test_fair_loss() -> TestResult {
let loss = FairLoss::new(1.3999)?;
let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert_eq!(rho_prime, 1.0);
assert_eq!(rho_double_prime, 0.0);
let [_, rho_prime, _] = loss.evaluate(1.0);
assert!(rho_prime > 0.2 && rho_prime < 0.25);
let [_, rho_prime_outlier, _] = loss.evaluate(100.0);
assert!(rho_prime_outlier < rho_prime);
let [_, rho_prime_4, rho_double_prime_4] = loss.evaluate(4.0);
assert!(rho_prime_4.is_finite() && rho_prime_4 > 0.0);
assert!(rho_double_prime_4.is_finite() && rho_double_prime_4 < 0.0);
Ok(())
}
#[test]
fn test_geman_mcclure_loss() -> TestResult {
let loss = GemanMcClureLoss::new(1.0)?;
let [rho, rho_prime, _] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert!((rho_prime - 1.0).abs() < EPSILON);
let [_, rho_prime_small, _] = loss.evaluate(1.0);
let [_, rho_prime_large, _] = loss.evaluate(100.0);
assert!(rho_prime_large < rho_prime_small);
assert!(rho_prime_large < 0.1);
let s = 2.0;
let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
assert!((rho_prime - rho_prime_num).abs() < 1e-4);
assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
Ok(())
}
#[test]
fn test_welsch_loss() -> TestResult {
let loss = WelschLoss::new(2.9846)?;
let [rho, rho_prime, _] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert!((rho_prime - 0.5).abs() < EPSILON);
let [_, rho_prime_10, _] = loss.evaluate(10.0);
let [_, rho_prime_100, _] = loss.evaluate(100.0);
assert!(rho_prime_100 < rho_prime_10);
assert!(rho_prime_100 < 0.01);
let s = 5.0;
let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
assert!((rho_prime - rho_prime_num).abs() < 1e-4);
assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
Ok(())
}
#[test]
fn test_tukey_biweight_loss() -> TestResult {
let loss = TukeyBiweightLoss::new(4.6851)?;
let [rho, rho_prime, _] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert!((rho_prime - 0.5).abs() < EPSILON);
let scale2 = 4.6851 * 4.6851;
let [_, rho_prime_in, _] = loss.evaluate(scale2 * 0.5);
assert!(rho_prime_in > 0.05);
let [_, rho_prime_out, _] = loss.evaluate(scale2 * 1.5);
assert_eq!(rho_prime_out, 0.0);
let [_, rho_prime_5, rho_double_prime_5] = loss.evaluate(5.0);
assert!(rho_prime_5.is_finite() && rho_prime_5 > 0.0);
assert!(rho_double_prime_5.is_finite() && rho_double_prime_5 < 0.0);
Ok(())
}
#[test]
fn test_andrews_wave_loss() -> TestResult {
let loss = AndrewsWaveLoss::new(1.339)?;
let [rho, rho_prime, _] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert!(rho_prime.abs() < EPSILON);
let [_, rho_prime_in, _] = loss.evaluate(1.0);
assert!(rho_prime_in > 0.33 && rho_prime_in < 0.35);
let scale = 1.339;
let [_, rho_prime_out, _] = loss.evaluate((scale * std::f64::consts::PI + 0.1).powi(2));
assert!(rho_prime_out.abs() < 0.01);
let [_, rho_prime_1, rho_double_prime_1] = loss.evaluate(1.0);
assert!(rho_prime_1.is_finite() && rho_prime_1 > 0.0);
assert!(rho_double_prime_1.is_finite());
Ok(())
}
#[test]
fn test_ramsay_ea_loss() -> TestResult {
let loss = RamsayEaLoss::new(0.3)?;
let [rho, _, _] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
let [_, rho_prime_small, _] = loss.evaluate(1.0);
let [_, rho_prime_large, _] = loss.evaluate(100.0);
assert!(rho_prime_large < rho_prime_small);
let s = 4.0;
let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
assert!((rho_prime - rho_prime_num).abs() < 1e-4);
assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
Ok(())
}
#[test]
fn test_trimmed_mean_loss() -> TestResult {
let loss = TrimmedMeanLoss::new(2.0)?;
let scale2 = 4.0;
let [rho, rho_prime, rho_double_prime] = loss.evaluate(2.0);
assert!((rho - 1.0).abs() < EPSILON);
assert!((rho_prime - 0.5).abs() < EPSILON);
assert_eq!(rho_double_prime, 0.0);
let [rho_out, rho_prime_out, rho_double_prime_out] = loss.evaluate(10.0);
assert!((rho_out - scale2 / 2.0).abs() < EPSILON);
assert_eq!(rho_prime_out, 0.0);
assert_eq!(rho_double_prime_out, 0.0);
Ok(())
}
#[test]
fn test_lp_norm_loss() -> TestResult {
let l1 = LpNormLoss::new(1.0)?;
let [rho_l1, _, _] = l1.evaluate(4.0);
assert!((rho_l1 - 2.0).abs() < EPSILON);
let l2 = LpNormLoss::new(2.0)?;
let [rho_l2, rho_prime_l2, rho_double_prime_l2] = l2.evaluate(4.0);
assert!((rho_l2 - 4.0).abs() < EPSILON);
assert!((rho_prime_l2 - 1.0).abs() < EPSILON);
assert_eq!(rho_double_prime_l2, 0.0);
let l05 = LpNormLoss::new(0.5)?;
let [_, rho_prime_05, _] = l05.evaluate(4.0);
assert!(rho_prime_05 < 1.0);
let loss = LpNormLoss::new(1.5)?;
let s = 4.0;
let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
assert!((rho_prime - rho_prime_num).abs() < 1e-4);
assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
Ok(())
}
#[test]
fn test_barron_general_loss_special_cases() -> TestResult {
let cauchy = BarronGeneralLoss::new(0.0, 1.0)?;
let [_, rho_prime_small, _] = cauchy.evaluate(1.0);
let [_, rho_prime_large, _] = cauchy.evaluate(100.0);
assert!(rho_prime_large < rho_prime_small);
let l2 = BarronGeneralLoss::new(2.0, 1.0)?;
let [rho, rho_prime, rho_double_prime] = l2.evaluate(4.0);
assert!((rho - 4.0).abs() < EPSILON);
assert!((rho_prime - 1.0).abs() < EPSILON);
assert!(rho_double_prime.abs() < EPSILON);
let charbonnier = BarronGeneralLoss::new(1.0, 1.0)?;
let [_, rho_prime_char, _] = charbonnier.evaluate(4.0);
assert!(rho_prime_char > 0.0 && rho_prime_char < 1.0);
let gm = BarronGeneralLoss::new(-2.0, 1.0)?;
let [_, rho_prime_small, _] = gm.evaluate(1.0);
let [_, rho_prime_large, _] = gm.evaluate(100.0);
assert!(rho_prime_large < rho_prime_small);
assert!(rho_prime_large < 0.1);
Ok(())
}
#[test]
fn test_constructor_validation() -> TestResult {
assert!(FairLoss::new(0.0).is_err());
assert!(FairLoss::new(-1.0).is_err());
assert!(GemanMcClureLoss::new(0.0).is_err());
assert!(WelschLoss::new(-1.0).is_err());
assert!(TukeyBiweightLoss::new(0.0).is_err());
assert!(AndrewsWaveLoss::new(-1.0).is_err());
assert!(RamsayEaLoss::new(0.0).is_err());
assert!(TrimmedMeanLoss::new(-1.0).is_err());
assert!(BarronGeneralLoss::new(0.0, 0.0).is_err());
assert!(BarronGeneralLoss::new(1.0, -1.0).is_err());
assert!(LpNormLoss::new(0.0).is_err());
assert!(LpNormLoss::new(-1.0).is_err());
assert!(FairLoss::new(1.0).is_ok());
assert!(LpNormLoss::new(1.5).is_ok());
assert!(BarronGeneralLoss::new(1.0, 1.0).is_ok());
Ok(())
}
#[test]
fn test_loss_comparison() -> TestResult {
let s_outlier = 100.0;
let l2 = L2Loss;
let huber = HuberLoss::new(1.345)?;
let cauchy = CauchyLoss::new(2.3849)?;
let [_, w_l2, _] = l2.evaluate(s_outlier);
let [_, w_huber, _] = huber.evaluate(s_outlier);
let [_, w_cauchy, _] = cauchy.evaluate(s_outlier);
assert!(w_l2 > w_huber);
assert!(w_huber > w_cauchy);
assert!(w_cauchy < 0.1);
Ok(())
}
#[test]
fn test_t_distribution_loss() -> TestResult {
let loss = TDistributionLoss::new(5.0)?;
let [rho, rho_prime, _] = loss.evaluate(0.0);
assert_eq!(rho, 0.0);
assert!((rho_prime - 0.6).abs() < 0.01);
let [_, rho_prime_small, _] = loss.evaluate(1.0);
let [_, rho_prime_large, _] = loss.evaluate(100.0);
assert!(rho_prime_large < rho_prime_small);
assert!(rho_prime_large < 0.1);
let s = 4.0;
let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
assert!((rho_prime - rho_prime_num).abs() < 1e-4);
assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-4);
Ok(())
}
#[test]
fn test_t_distribution_loss_different_nu() -> TestResult {
let t3 = TDistributionLoss::new(3.0)?;
let t10 = TDistributionLoss::new(10.0)?;
let s_outlier = 100.0;
let [_, w_t3, _] = t3.evaluate(s_outlier);
let [_, w_t10, _] = t10.evaluate(s_outlier);
assert!(w_t3 < w_t10);
Ok(())
}
#[test]
fn test_adaptive_barron_loss() -> TestResult {
let adaptive = AdaptiveBarronLoss::new(0.0, 1.0)?;
let [rho, _, _] = adaptive.evaluate(0.0);
assert!(rho.abs() < EPSILON);
let [_, rho_prime_small, _] = adaptive.evaluate(1.0);
let [_, rho_prime_large, _] = adaptive.evaluate(100.0);
assert!(rho_prime_large < rho_prime_small);
let barron = BarronGeneralLoss::new(0.0, 1.0)?;
let [rho_a, rho_prime_a, rho_double_prime_a] = adaptive.evaluate(4.0);
let [rho_b, rho_prime_b, rho_double_prime_b] = barron.evaluate(4.0);
assert!((rho_a - rho_b).abs() < EPSILON);
assert!((rho_prime_a - rho_prime_b).abs() < EPSILON);
assert!((rho_double_prime_a - rho_double_prime_b).abs() < EPSILON);
Ok(())
}
#[test]
fn test_adaptive_barron_default() -> TestResult {
let adaptive = AdaptiveBarronLoss::default();
let [_, rho_prime, _] = adaptive.evaluate(4.0);
assert!(rho_prime > 0.0 && rho_prime < 1.0);
Ok(())
}
#[test]
fn test_new_loss_constructor_validation() -> TestResult {
assert!(TDistributionLoss::new(0.0).is_err());
assert!(TDistributionLoss::new(-1.0).is_err());
assert!(TDistributionLoss::new(5.0).is_ok());
assert!(AdaptiveBarronLoss::new(0.0, 0.0).is_err());
assert!(AdaptiveBarronLoss::new(1.0, -1.0).is_err());
assert!(AdaptiveBarronLoss::new(0.0, 1.0).is_ok());
Ok(())
}
}