use crate::geometry::diagram::RegionMask;
use std::collections::HashMap;
/// Smooth approximation of `|x|`: `sqrt(x² + eps²) − eps`.
///
/// The result is `0` at `x = 0`, differentiable everywhere for `eps > 0`,
/// and approaches `|x|` as `eps → 0` (exactly `|x|` when `eps == 0`).
///
/// For `eps > 0` the value is evaluated as `x² / (sqrt(x² + eps²) + eps)`.
/// This is algebraically identical to the definition but has two advantages:
/// - It avoids the cancellation of subtracting two nearly equal numbers when
///   `|x| ≪ eps`.
/// - Because `hypot` is used, it does not overflow for huge `|x|`.
///
/// `eps` is expected to be non-negative. Other values fall back to the
/// literal formula.
#[inline]
fn smooth_abs(x: f64, eps: f64) -> f64 {
    // hypot computes sqrt(x² + eps²) without squaring x, so it cannot overflow early.
    let h = x.hypot(eps);
    if eps > 0.0 && h.is_finite() {
        // x / (h + eps) has magnitude <= 1, so multiplying by x cannot overflow.
        x * (x / (h + eps))
    } else {
        // eps <= 0, NaN, or infinite inputs: the literal formula is exact
        // (eps == 0 gives |x|) or propagates inf/NaN as before.
        h - eps
    }
}
/// Soft maximum of `values` via log-sum-exp with temperature `eps`:
/// `m + eps * ln(Σ exp((v − m) / eps))`, where `m` is the hard maximum.
///
/// Shifting by `m` keeps every exponent `<= 0`, so the sum cannot overflow.
/// For `eps > 0` and finite inputs the result lies in
/// `[m, m + eps * ln(values.len())]`.
///
/// Returns `0.0` for an empty slice. If `m` is not finite (`±inf`, or
/// `-inf` when every entry is `NaN`) it is returned unchanged.
#[inline]
fn smooth_max(values: &[f64], eps: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }

    // f64::max ignores NaN, so starting from -inf yields the largest non-NaN entry.
    let peak = values
        .iter()
        .fold(f64::NEG_INFINITY, |best, &v| best.max(v));
    if !peak.is_finite() {
        return peak;
    }

    let scale = eps.recip();
    let mut total = 0.0_f64;
    for &v in values {
        total += ((v - peak) * scale).exp();
    }
    peak + eps * total.ln()
}
/// Objective comparing fitted region areas against target region areas.
///
/// [`LossType::compute`] evaluates every variant over the union of the region
/// masks present in either map. A mask missing on one side counts as area
/// `0.0` on that side.
///
/// In the formulas below:
/// - `f` is a fitted area and `t` a target area.
/// - Sums and maxima run over all masks.
/// - `F = Σf` and `T = Σt`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LossType {
    /// `Σ(f − t)² / Σt²`. This is the default variant.
    #[default]
    SumSquared,
    /// `Σ|f − t| / Σ|t|`.
    ///
    /// NOTE: the variant name is misspelled ("Absoute"). It is kept as-is
    /// because renaming a public variant would break callers.
    SumAbsoute,
    /// `Σ|f/F − t/T|`: absolute error between the normalised area profiles.
    SumAbsoluteRegionError,
    /// `Σ(f/F − t/T)²`: squared error between the normalised area profiles.
    SumSquaredRegionError,
    /// `max|f − t| / max|t|`.
    MaxAbsolute,
    /// `max (f − t)² / max t²`.
    MaxSquared,
    /// `sqrt(Σ(f − t)² / Σt²)`: the square root of `SumSquared`. It is not
    /// a per-region mean despite the name.
    RootMeanSquared,
    /// `Σ(f − βt)² / Σf²` with `β = Σft / Σt²`, i.e. the target rescaled by
    /// its least-squares fit onto the fitted areas.
    Stress,
    /// `max|f/F − t/T|`: largest deviation between the normalised profiles.
    DiagError,
    /// `SumAbsoute` with `|x|` replaced by the smooth `sqrt(x² + eps²) − eps`.
    SmoothSumAbsolute {
        /// Smoothing width. `compute` clamps it to at least `f64::MIN_POSITIVE`.
        eps: f64,
    },
    /// `SumAbsoluteRegionError` with `|x|` replaced by `sqrt(x² + eps²) − eps`.
    SmoothSumAbsoluteRegionError {
        /// Smoothing width. `compute` clamps it to at least `f64::MIN_POSITIVE`.
        eps: f64,
    },
    /// `MaxAbsolute` with a smoothed absolute value and the maximum replaced
    /// by a log-sum-exp soft maximum of width `eps`.
    SmoothMaxAbsolute {
        /// Smoothing width. `compute` clamps it to at least `f64::MIN_POSITIVE`.
        eps: f64,
    },
    /// `MaxSquared` with the maximum over squared errors replaced by a
    /// log-sum-exp soft maximum of width `eps`.
    SmoothMaxSquared {
        /// Smoothing width in squared-area units. `compute` clamps it to at
        /// least `f64::MIN_POSITIVE`.
        eps: f64,
    },
    /// `DiagError` with a smoothed absolute value and a log-sum-exp soft maximum.
    SmoothDiagError {
        /// Smoothing width. `compute` clamps it to at least `f64::MIN_POSITIVE`.
        eps: f64,
    },
}
impl LossType {
    /// Shorthand for [`LossType::SumSquared`].
    pub fn sse() -> Self {
        Self::SumSquared
    }

    /// Shorthand for [`LossType::RootMeanSquared`].
    pub fn rmse() -> Self {
        Self::RootMeanSquared
    }

    /// Shorthand for [`LossType::Stress`].
    pub fn stress() -> Self {
        Self::Stress
    }

    /// Shorthand for [`LossType::MaxAbsolute`].
    pub fn max_absolute() -> Self {
        Self::MaxAbsolute
    }

    /// Shorthand for [`LossType::MaxSquared`].
    pub fn max_squared() -> Self {
        Self::MaxSquared
    }

    /// Shorthand for [`LossType::SumAbsoute`]. The variant name is
    /// misspelled and kept for compatibility.
    pub fn sum_absolute() -> Self {
        Self::SumAbsoute
    }

    /// Shorthand for [`LossType::SumAbsoluteRegionError`].
    pub fn sum_absolute_region_error() -> Self {
        Self::SumAbsoluteRegionError
    }

    /// Shorthand for [`LossType::SumSquaredRegionError`].
    pub fn sum_squared_region_error() -> Self {
        Self::SumSquaredRegionError
    }

    /// Shorthand for [`LossType::DiagError`].
    pub fn diag_error() -> Self {
        Self::DiagError
    }

    /// Shorthand for [`LossType::SmoothSumAbsolute`] with smoothing width `eps`.
    pub fn smooth_sum_absolute(eps: f64) -> Self {
        Self::SmoothSumAbsolute { eps }
    }

    /// Shorthand for [`LossType::SmoothSumAbsoluteRegionError`] with smoothing width `eps`.
    pub fn smooth_sum_absolute_region_error(eps: f64) -> Self {
        Self::SmoothSumAbsoluteRegionError { eps }
    }

    /// Shorthand for [`LossType::SmoothMaxAbsolute`] with smoothing width `eps`.
    pub fn smooth_max_absolute(eps: f64) -> Self {
        Self::SmoothMaxAbsolute { eps }
    }

    /// Shorthand for [`LossType::SmoothMaxSquared`] with smoothing width `eps`.
    pub fn smooth_max_squared(eps: f64) -> Self {
        Self::SmoothMaxSquared { eps }
    }

    /// Shorthand for [`LossType::SmoothDiagError`] with smoothing width `eps`.
    pub fn smooth_diag_error(eps: f64) -> Self {
        Self::SmoothDiagError { eps }
    }

    /// Whether this loss is classified as smooth.
    ///
    /// Squared-error forms and the eps-smoothed `Smooth*` variants return
    /// `true`. Forms built on a raw absolute value or a hard maximum return
    /// `false`.
    pub fn is_smooth(&self) -> bool {
        match self {
            LossType::SumSquared
            | LossType::RootMeanSquared
            | LossType::Stress
            | LossType::SumSquaredRegionError
            | LossType::SmoothSumAbsolute { .. }
            | LossType::SmoothSumAbsoluteRegionError { .. }
            | LossType::SmoothMaxAbsolute { .. }
            | LossType::SmoothMaxSquared { .. }
            | LossType::SmoothDiagError { .. } => true,
            LossType::SumAbsoute
            | LossType::SumAbsoluteRegionError
            | LossType::MaxAbsolute
            | LossType::MaxSquared
            | LossType::DiagError => false,
        }
    }

    /// Evaluates this loss between `fitted` and `target` region areas.
    ///
    /// Both maps are aligned on the union of their masks. A mask missing
    /// from one side counts as area `0.0` there. Every sum and maximum is
    /// taken in ascending mask order, so the result does not depend on
    /// `HashMap` iteration order.
    ///
    /// Degenerate inputs:
    /// - Returns `0.0` when both maps are empty.
    /// - Returns `0.0` when the target normaliser (`Σt²`, `Σ|t|`, `max|t|`
    ///   or `max t²`) is below `1e-20`. For `Stress`, `Σf²` below `1e-20`
    ///   also returns `0.0`.
    /// - The region-error and diag-error variants return `f64::MAX` when
    ///   `|Σf|` or `|Σt|` is below `1e-10`.
    pub fn compute(
        &self,
        fitted: &HashMap<RegionMask, f64>,
        target: &HashMap<RegionMask, f64>,
    ) -> f64 {
        let pairs = Self::aligned(fitted, target);
        if pairs.is_empty() {
            return 0.0;
        }
        match self {
            LossType::SumSquared => {
                let sum_t2 = Self::sum_terms(&pairs, |_, t| t * t);
                if sum_t2 < 1e-20 {
                    return 0.0;
                }
                Self::sum_terms(&pairs, |f, t| (f - t).powi(2)) / sum_t2
            }
            LossType::RootMeanSquared => {
                let sum_t2 = Self::sum_terms(&pairs, |_, t| t * t);
                if sum_t2 < 1e-20 {
                    return 0.0;
                }
                (Self::sum_terms(&pairs, |f, t| (f - t).powi(2)) / sum_t2).sqrt()
            }
            LossType::Stress => {
                let sum_t2 = Self::sum_terms(&pairs, |_, t| t * t);
                let sum_f2 = Self::sum_terms(&pairs, |f, _| f * f);
                // NOTE(review): an all-zero fit scores 0.0 here (the best
                // possible value) although stress is undefined (0/0) there.
                if sum_t2 < 1e-20 || sum_f2 < 1e-20 {
                    return 0.0;
                }
                // Least-squares scale that best maps the target onto the fit.
                let beta = Self::sum_terms(&pairs, |f, t| f * t) / sum_t2;
                Self::sum_terms(&pairs, |f, t| (f - beta * t).powi(2)) / sum_f2
            }
            LossType::MaxAbsolute => {
                let max_t = Self::max_term(&pairs, |_, t| t.abs());
                if max_t < 1e-20 {
                    return 0.0;
                }
                Self::max_term(&pairs, |f, t| (f - t).abs()) / max_t
            }
            LossType::MaxSquared => {
                let max_t2 = Self::max_term(&pairs, |_, t| t * t);
                if max_t2 < 1e-20 {
                    return 0.0;
                }
                Self::max_term(&pairs, |f, t| (f - t).powi(2)) / max_t2
            }
            LossType::DiagError => {
                let Some((ssf, sst)) = Self::region_totals(&pairs) else {
                    return f64::MAX;
                };
                Self::max_term(&pairs, |f, t| (f / ssf - t / sst).abs())
            }
            LossType::SumAbsoute => {
                let sum_abs_t = Self::sum_terms(&pairs, |_, t| t.abs());
                if sum_abs_t < 1e-20 {
                    return 0.0;
                }
                Self::sum_terms(&pairs, |f, t| (f - t).abs()) / sum_abs_t
            }
            LossType::SumAbsoluteRegionError => {
                let Some((ssf, sst)) = Self::region_totals(&pairs) else {
                    return f64::MAX;
                };
                Self::sum_terms(&pairs, |f, t| (f / ssf - t / sst).abs())
            }
            LossType::SumSquaredRegionError => {
                let Some((ssf, sst)) = Self::region_totals(&pairs) else {
                    return f64::MAX;
                };
                Self::sum_terms(&pairs, |f, t| (f / ssf - t / sst).powi(2))
            }
            LossType::SmoothSumAbsolute { eps } => {
                let eps = Self::clamp_eps(*eps);
                let sum_abs_t = Self::sum_terms(&pairs, |_, t| t.abs());
                if sum_abs_t < 1e-20 {
                    return 0.0;
                }
                Self::sum_terms(&pairs, |f, t| smooth_abs(f - t, eps)) / sum_abs_t
            }
            LossType::SmoothSumAbsoluteRegionError { eps } => {
                let eps = Self::clamp_eps(*eps);
                let Some((ssf, sst)) = Self::region_totals(&pairs) else {
                    return f64::MAX;
                };
                Self::sum_terms(&pairs, |f, t| smooth_abs(f / ssf - t / sst, eps))
            }
            LossType::SmoothMaxAbsolute { eps } => {
                let eps = Self::clamp_eps(*eps);
                let max_t = Self::max_term(&pairs, |_, t| t.abs());
                if max_t < 1e-20 {
                    return 0.0;
                }
                let smoothed = Self::terms(&pairs, |f, t| smooth_abs(f - t, eps));
                smooth_max(&smoothed, eps) / max_t
            }
            LossType::SmoothMaxSquared { eps } => {
                let eps = Self::clamp_eps(*eps);
                let max_t2 = Self::max_term(&pairs, |_, t| t * t);
                if max_t2 < 1e-20 {
                    return 0.0;
                }
                let squared = Self::terms(&pairs, |f, t| (f - t).powi(2));
                smooth_max(&squared, eps) / max_t2
            }
            LossType::SmoothDiagError { eps } => {
                let eps = Self::clamp_eps(*eps);
                let Some((ssf, sst)) = Self::region_totals(&pairs) else {
                    return f64::MAX;
                };
                let smoothed = Self::terms(&pairs, |f, t| smooth_abs(f / ssf - t / sst, eps));
                smooth_max(&smoothed, eps)
            }
        }
    }

    /// Evaluates this loss together with its analytic gradient with respect
    /// to each fitted area, for the variants that implement one.
    ///
    /// Supported variants return `Some((loss, grad))`:
    /// - [`LossType::SumSquared`] and [`LossType::RootMeanSquared`] are
    ///   supported.
    /// - `loss` equals [`LossType::compute`] for the same inputs.
    /// - `grad` maps every mask in the union of both maps to
    ///   `∂loss/∂fitted[mask]`.
    /// - When the inputs are empty or `Σt²` is below `1e-20`, the loss is
    ///   `0.0` and `grad` is empty.
    ///
    /// Every other variant returns `None`, meaning no analytic gradient is
    /// available.
    pub fn compute_with_gradient(
        &self,
        fitted: &HashMap<RegionMask, f64>,
        target: &HashMap<RegionMask, f64>,
    ) -> Option<(f64, HashMap<RegionMask, f64>)> {
        let pairs = Self::aligned(fitted, target);
        if pairs.is_empty() {
            return Some((0.0, HashMap::new()));
        }
        match self {
            LossType::SumSquared => {
                let sum_t2 = Self::sum_terms(&pairs, |_, t| t * t);
                if sum_t2 < 1e-20 {
                    return Some((0.0, HashMap::new()));
                }
                let loss = Self::sum_terms(&pairs, |f, t| (f - t).powi(2)) / sum_t2;
                // ∂/∂f_i [Σ(f − t)² / Σt²] = 2 (f_i − t_i) / Σt²
                let grad = pairs
                    .iter()
                    .map(|&(mask, f, t)| (mask, 2.0 * (f - t) / sum_t2))
                    .collect();
                Some((loss, grad))
            }
            LossType::RootMeanSquared => {
                let sum_t2 = Self::sum_terms(&pairs, |_, t| t * t);
                if sum_t2 < 1e-20 {
                    return Some((0.0, HashMap::new()));
                }
                let loss = (Self::sum_terms(&pairs, |f, t| (f - t).powi(2)) / sum_t2).sqrt();
                // L = sqrt(S / T)  ⇒  ∂L/∂f_i = (f_i − t_i) / (T · L).
                // At L == 0 every residual is zero and 0 is a valid subgradient.
                let scale = sum_t2 * loss;
                let grad = pairs
                    .iter()
                    .map(|&(mask, f, t)| {
                        let g = if loss == 0.0 { 0.0 } else { (f - t) / scale };
                        (mask, g)
                    })
                    .collect();
                Some((loss, grad))
            }
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // decision about its gradient here.
            LossType::Stress
            | LossType::MaxAbsolute
            | LossType::MaxSquared
            | LossType::DiagError
            | LossType::SumAbsoute
            | LossType::SumAbsoluteRegionError
            | LossType::SumSquaredRegionError
            | LossType::SmoothSumAbsolute { .. }
            | LossType::SmoothSumAbsoluteRegionError { .. }
            | LossType::SmoothMaxAbsolute { .. }
            | LossType::SmoothMaxSquared { .. }
            | LossType::SmoothDiagError { .. } => None,
        }
    }

    /// Joins both maps on the sorted, de-duplicated union of their masks.
    ///
    /// Yields one `(mask, fitted, target)` triple per mask, with `0.0`
    /// standing in for a side that lacks the mask.
    fn aligned(
        fitted: &HashMap<RegionMask, f64>,
        target: &HashMap<RegionMask, f64>,
    ) -> Vec<(RegionMask, f64, f64)> {
        let mut masks: Vec<RegionMask> = fitted.keys().chain(target.keys()).copied().collect();
        // Sorting fixes the floating-point summation order across calls.
        masks.sort_unstable();
        masks.dedup();
        masks
            .into_iter()
            .map(|mask| {
                let f = fitted.get(&mask).copied().unwrap_or(0.0);
                let t = target.get(&mask).copied().unwrap_or(0.0);
                (mask, f, t)
            })
            .collect()
    }

    /// Sums `term(fitted, target)` over all aligned pairs, in mask order.
    fn sum_terms(pairs: &[(RegionMask, f64, f64)], term: impl Fn(f64, f64) -> f64) -> f64 {
        pairs.iter().map(|&(_, f, t)| term(f, t)).sum()
    }

    /// Maximum of `term(fitted, target)` over all aligned pairs.
    ///
    /// The fold starts from `0.0`, and `f64::max` skips `NaN`s.
    fn max_term(pairs: &[(RegionMask, f64, f64)], term: impl Fn(f64, f64) -> f64) -> f64 {
        pairs.iter().map(|&(_, f, t)| term(f, t)).fold(0.0, f64::max)
    }

    /// Evaluates `term(fitted, target)` for each aligned pair, in mask order.
    fn terms(pairs: &[(RegionMask, f64, f64)], term: impl Fn(f64, f64) -> f64) -> Vec<f64> {
        pairs.iter().map(|&(_, f, t)| term(f, t)).collect()
    }

    /// Returns `(Σf, Σt)` for the region-normalised losses.
    ///
    /// Returns `None` when either total is below `1e-10` in magnitude.
    /// A `NaN` total is passed through rather than rejected.
    fn region_totals(pairs: &[(RegionMask, f64, f64)]) -> Option<(f64, f64)> {
        let ssf = Self::sum_terms(pairs, |f, _| f);
        let sst = Self::sum_terms(pairs, |_, t| t);
        if ssf.abs() < 1e-10 || sst.abs() < 1e-10 {
            None
        } else {
            Some((ssf, sst))
        }
    }

    /// Clamps a smoothing width to be strictly positive.
    ///
    /// Non-positive values and `NaN` become `f64::MIN_POSITIVE`, which keeps
    /// `1.0 / eps` in `smooth_max` finite.
    fn clamp_eps(eps: f64) -> f64 {
        eps.max(f64::MIN_POSITIVE)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sse() {
        let loss = LossType::sse();
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 10.0);
        fitted.insert(0b010, 20.0);
        fitted.insert(0b100, 30.0);
        let mut target = HashMap::new();
        target.insert(0b001, 12.0);
        target.insert(0b010, 18.0);
        target.insert(0b100, 28.0);
        let expected = 12.0 / 1252.0;
        assert!((loss.compute(&fitted, &target) - expected).abs() < 1e-12);
    }

    #[test]
    fn test_rmse() {
        let loss = LossType::rmse();
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 10.0);
        fitted.insert(0b010, 20.0);
        fitted.insert(0b100, 30.0);
        let mut target = HashMap::new();
        target.insert(0b001, 12.0);
        target.insert(0b010, 18.0);
        target.insert(0b100, 28.0);
        let expected = (12.0_f64 / 1252.0).sqrt();
        assert!((loss.compute(&fitted, &target) - expected).abs() < 1e-12);
    }

    #[test]
    fn test_stress() {
        let loss = LossType::stress();
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 10.0);
        fitted.insert(0b010, 20.0);
        let mut target = HashMap::new();
        target.insert(0b001, 12.0);
        target.insert(0b010, 18.0);
        let result = loss.compute(&fitted, &target);
        assert!(
            (result - 0.015385).abs() < 1e-5,
            "expected 0.015385, got {}",
            result
        );
    }

    #[test]
    fn test_max_absolute() {
        let loss = LossType::max_absolute();
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 10.0);
        fitted.insert(0b010, 20.0);
        fitted.insert(0b100, 30.0);
        let mut target = HashMap::new();
        target.insert(0b001, 8.0);
        target.insert(0b010, 25.0);
        target.insert(0b100, 28.0);
        assert!((loss.compute(&fitted, &target) - 5.0 / 28.0).abs() < 1e-12);
    }

    #[test]
    fn test_empty_target() {
        let loss = LossType::sse();
        let fitted = HashMap::new();
        let target = HashMap::new();
        assert_eq!(loss.compute(&fitted, &target), 0.0);
    }

    #[test]
    fn test_missing_fitted_area() {
        let loss = LossType::sse();
        let fitted = HashMap::new();
        let mut target = HashMap::new();
        target.insert(0b001, 5.0);
        target.insert(0b010, 3.0);
        assert!((loss.compute(&fitted, &target) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_extra_fitted_area() {
        let loss = LossType::sse();
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 5.0);
        fitted.insert(0b010, 3.0);
        fitted.insert(0b100, 7.0);
        let mut target = HashMap::new();
        target.insert(0b001, 5.0);
        target.insert(0b010, 3.0);
        let expected = 49.0 / 34.0;
        assert!((loss.compute(&fitted, &target) - expected).abs() < 1e-12);
    }

    #[test]
    fn test_stress_with_zero_target() {
        let loss = LossType::stress();
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 5.0);
        fitted.insert(0b010, 0.0);
        fitted.insert(0b100, 3.0);
        let mut target = HashMap::new();
        target.insert(0b001, 0.0);
        target.insert(0b010, 0.0);
        target.insert(0b100, 3.0);
        let result = loss.compute(&fitted, &target);
        assert!(
            (result - 25.0 / 34.0).abs() < 1e-10,
            "expected 0.735294, got {}",
            result
        );
    }

    #[test]
    fn test_equality() {
        assert_eq!(LossType::sse(), LossType::SumSquared);
        assert_eq!(LossType::stress(), LossType::Stress);
        assert_ne!(LossType::sse(), LossType::rmse());
    }

    #[test]
    fn test_clone() {
        // LossType is Copy; this checks that a copied value compares equal.
        let loss = LossType::sse();
        let cloned = loss;
        assert_eq!(loss, cloned);
    }

    #[test]
    fn test_is_smooth() {
        assert!(LossType::SumSquared.is_smooth());
        assert!(LossType::RootMeanSquared.is_smooth());
        assert!(LossType::Stress.is_smooth());
        assert!(LossType::SumSquaredRegionError.is_smooth());
        assert!(!LossType::SumAbsoute.is_smooth());
        assert!(!LossType::SumAbsoluteRegionError.is_smooth());
        assert!(!LossType::MaxAbsolute.is_smooth());
        assert!(!LossType::MaxSquared.is_smooth());
        assert!(!LossType::DiagError.is_smooth());
    }

    #[test]
    fn test_smooth_abs_basics() {
        assert!((smooth_abs(0.0, 1e-3) - 0.0).abs() < 1e-3);
        assert!((smooth_abs(1.0, 1e-6) - 1.0).abs() < 1e-6);
        assert!((smooth_abs(-2.5, 1e-6) - 2.5).abs() < 1e-6);
        assert!(smooth_abs(0.0, 0.5) >= 0.0);
        assert!(smooth_abs(-3.0, 0.5) >= 0.0);
    }

    #[test]
    fn test_smooth_max_basics() {
        let xs = vec![1.0, 2.0, 3.0];
        assert!((smooth_max(&xs, 1e-6) - 3.0).abs() < 1e-3);
        assert_eq!(smooth_max(&[], 1.0), 0.0);
        assert!((smooth_max(&[5.0], 1e-6) - 5.0).abs() < 1e-9);
        let big = vec![1e6, 1e6 - 1.0];
        let res = smooth_max(&big, 1.0);
        assert!(res.is_finite());
        assert!((res - 1e6).abs() < 5.0);
    }

    #[test]
    fn test_smooth_variants_converge_to_true_loss() {
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 10.0);
        fitted.insert(0b010, 20.0);
        fitted.insert(0b100, 30.0);
        let mut target = HashMap::new();
        target.insert(0b001, 8.0);
        target.insert(0b010, 25.0);
        target.insert(0b100, 28.0);
        let pairs: &[(LossType, LossType)] = &[
            (LossType::SumAbsoute, LossType::smooth_sum_absolute(1e-9)),
            (LossType::MaxAbsolute, LossType::smooth_max_absolute(1e-9)),
            (LossType::MaxSquared, LossType::smooth_max_squared(1e-9)),
            (
                LossType::SumAbsoluteRegionError,
                LossType::smooth_sum_absolute_region_error(1e-9),
            ),
            (LossType::DiagError, LossType::smooth_diag_error(1e-9)),
        ];
        for &(true_loss, smooth_loss) in pairs {
            let exact = true_loss.compute(&fitted, &target);
            let smoothed = smooth_loss.compute(&fitted, &target);
            assert!(
                (smoothed - exact).abs() < 1e-3 * exact.abs().max(1e-3),
                "{:?} vs {:?}: smoothed = {}, exact = {}",
                true_loss,
                smooth_loss,
                smoothed,
                exact
            );
        }
    }

    #[test]
    fn test_smooth_variants_are_smooth() {
        assert!(LossType::smooth_sum_absolute(1e-3).is_smooth());
        assert!(LossType::smooth_sum_absolute_region_error(1e-3).is_smooth());
        assert!(LossType::smooth_max_absolute(1e-3).is_smooth());
        assert!(LossType::smooth_max_squared(1e-3).is_smooth());
        assert!(LossType::smooth_diag_error(1e-3).is_smooth());
    }

    #[test]
    fn test_sse_gradient() {
        let loss = LossType::sse();
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 10.0);
        fitted.insert(0b010, 20.0);
        fitted.insert(0b100, 30.0);
        let mut target = HashMap::new();
        target.insert(0b001, 12.0);
        target.insert(0b010, 18.0);
        target.insert(0b100, 28.0);
        let (value, grad) = loss
            .compute_with_gradient(&fitted, &target)
            .expect("SumSquared has an analytic gradient");
        assert!((value - loss.compute(&fitted, &target)).abs() < 1e-15);
        assert_eq!(grad.len(), 3);
        // ∂/∂f_i = 2 (f_i − t_i) / Σt², with Σt² = 1252.
        let expected: [(RegionMask, f64); 3] = [
            (0b001, -4.0 / 1252.0),
            (0b010, 4.0 / 1252.0),
            (0b100, 4.0 / 1252.0),
        ];
        for &(mask, want) in expected.iter() {
            let got = grad
                .get(&mask)
                .copied()
                .expect("gradient entry for every mask");
            assert!(
                (got - want).abs() < 1e-15,
                "mask {:?}: got {}, want {}",
                mask,
                got,
                want
            );
        }
    }

    #[test]
    fn test_gradient_empty_inputs() {
        let fitted: HashMap<RegionMask, f64> = HashMap::new();
        let target: HashMap<RegionMask, f64> = HashMap::new();
        let (value, grad) = LossType::sse()
            .compute_with_gradient(&fitted, &target)
            .expect("empty inputs yield a zero loss");
        assert_eq!(value, 0.0);
        assert!(grad.is_empty());
    }

    #[test]
    fn test_gradient_unavailable_for_non_analytic_losses() {
        let mut fitted = HashMap::new();
        fitted.insert(0b001, 1.0);
        let mut target = HashMap::new();
        target.insert(0b001, 2.0);
        assert!(LossType::stress()
            .compute_with_gradient(&fitted, &target)
            .is_none());
        assert!(LossType::max_absolute()
            .compute_with_gradient(&fitted, &target)
            .is_none());
    }

    #[test]
    fn test_region_errors() {
        let mut fitted = HashMap::new();
        fitted.insert(0b01, 1.0);
        fitted.insert(0b10, 3.0);
        let mut target = HashMap::new();
        target.insert(0b01, 2.0);
        target.insert(0b10, 2.0);
        // Normalised profiles: fitted (0.25, 0.75), target (0.5, 0.5).
        let cases = [
            (LossType::sum_absolute_region_error(), 0.5),
            (LossType::sum_squared_region_error(), 0.125),
            (LossType::diag_error(), 0.25),
        ];
        for &(loss, want) in cases.iter() {
            let got = loss.compute(&fitted, &target);
            assert!(
                (got - want).abs() < 1e-12,
                "{:?}: got {}, want {}",
                loss,
                got,
                want
            );
        }
    }

    #[test]
    fn test_region_errors_degenerate_total() {
        let mut fitted = HashMap::new();
        fitted.insert(0b01, 0.0);
        let mut target = HashMap::new();
        target.insert(0b01, 1.0);
        let losses = [
            LossType::diag_error(),
            LossType::sum_absolute_region_error(),
            LossType::sum_squared_region_error(),
            LossType::smooth_diag_error(1e-3),
        ];
        for loss in losses.iter() {
            assert_eq!(loss.compute(&fitted, &target), f64::MAX, "{:?}", loss);
        }
    }
}