use super::GroundTruth;
/// Names a distance or similarity metric.
///
/// NOTE(review): nothing in this file dispatches on `Metric`; the mapping from a
/// variant to its implementation presumably lives with the callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Metric {
    /// Mean squared error.
    MeanSquaredError,
    /// Cosine similarity (cf. `Delta::cosine_similarity`).
    CosineSimilarity,
    /// Kullback–Leibler divergence (cf. `Delta::kl_divergence`).
    KLDivergence,
    /// One-dimensional Wasserstein distance (cf. `Delta::wasserstein_1d`).
    WassersteinDistance,
    /// L2 (Euclidean) norm.
    L2Norm,
}
/// Summary of how far one result is from a reference result.
///
/// Built by `Delta::compute` / `Delta::compute_all` from two ground truths, or
/// directly via `Delta::from_percent` / `Delta::from_stats`.
#[derive(Debug, Clone)]
// `mean_delta` / `std_delta` end with the struct's own name, which
// `clippy::struct_field_names` flags; the names are kept because they mirror the
// public accessors of the same name.
#[allow(clippy::struct_field_names)]
pub struct Delta {
    /// Absolute difference of the two means (taken verbatim in `from_stats`).
    mean_delta: f32,
    /// Absolute difference of the two standard deviations (verbatim in `from_stats`).
    std_delta: f32,
    /// Aggregate error expressed as a percentage.
    percent: f32,
    /// Whether the two means have opposite signs (both clearly away from zero).
    sign_flipped: bool,
    /// Cosine similarity of the raw data, when both sides had data.
    cosine: Option<f32>,
    /// KL divergence of the normalised raw data; only populated by `compute_all`.
    kl_div: Option<f32>,
}
impl Delta {
    /// Probability floor for [`Delta::kl_divergence`] and the minimum total absolute
    /// mass [`Delta::compute_all`] requires before normalising data into distributions.
    const EPSILON: f32 = 1e-10;

    /// Compares `our` against the reference `gt` using summary statistics.
    ///
    /// * `mean_delta` / `std_delta`: absolute differences of the means and of the
    ///   standard deviations.
    /// * `percent`: `(mean_delta + std_delta) / max(|gt.std()|, 0.001) * 100`. The
    ///   floor keeps a (near-)zero reference std from producing an unbounded value.
    /// * `sign_flipped`: set when the means have different signs and both exceed
    ///   `0.01` in magnitude, so noise around zero is not reported as a flip.
    /// * `cosine`: set only when both sides expose non-empty raw data.
    ///
    /// KL divergence is left unset; use [`Delta::compute_all`] for that.
    #[must_use]
    pub fn compute(our: &GroundTruth, gt: &GroundTruth) -> Self {
        let mean_delta = (our.mean() - gt.mean()).abs();
        let std_delta = (our.std() - gt.std()).abs();
        // Floor the reference so a zero std does not divide by zero.
        let ref_val = gt.std().abs().max(0.001);
        let percent = ((mean_delta + std_delta) / ref_val) * 100.0;
        // Means within 0.01 of zero have no meaningful sign.
        let sign_flipped = our.mean().signum() != gt.mean().signum()
            && our.mean().abs() > 0.01
            && gt.mean().abs() > 0.01;
        let cosine = match (our.data(), gt.data()) {
            (Some(a), Some(b)) if !a.is_empty() && !b.is_empty() => {
                Some(Self::cosine_similarity(a, b))
            }
            _ => None,
        };
        Self {
            mean_delta,
            std_delta,
            percent,
            sign_flipped,
            cosine,
            kl_div: None,
        }
    }

    /// Builds a delta carrying only a precomputed percentage; every other
    /// statistic is zero, `false` or `None`.
    #[must_use]
    pub fn from_percent(percent: f32) -> Self {
        Self {
            mean_delta: 0.0,
            std_delta: 0.0,
            percent,
            sign_flipped: false,
            cosine: None,
            kl_div: None,
        }
    }

    /// Builds a delta from already-computed mean and std differences.
    ///
    /// Unlike [`Delta::compute`], `percent` is not normalised by a reference std.
    /// It is simply `(mean_delta + std_delta) * 100`.
    #[must_use]
    pub fn from_stats(mean_delta: f32, std_delta: f32) -> Self {
        let percent = (mean_delta + std_delta) * 100.0;
        Self {
            mean_delta,
            std_delta,
            percent,
            sign_flipped: false,
            cosine: None,
            kl_div: None,
        }
    }

    /// Absolute difference between the two means.
    #[must_use]
    pub fn mean_delta(&self) -> f32 {
        self.mean_delta
    }

    /// Absolute difference between the two standard deviations.
    #[must_use]
    pub fn std_delta(&self) -> f32 {
        self.std_delta
    }

    /// Aggregate error as a percentage (see the constructors for how it is derived).
    #[must_use]
    pub fn percent(&self) -> f32 {
        self.percent
    }

    /// Whether the means have opposite signs, both being clearly non-zero.
    #[must_use]
    pub fn is_sign_flipped(&self) -> bool {
        self.sign_flipped
    }

    /// Cosine similarity of the raw data, if it was computed.
    #[must_use]
    pub fn cosine(&self) -> Option<f32> {
        self.cosine
    }

    /// KL divergence of the normalised raw data, if it was computed
    /// (only [`Delta::compute_all`] fills it in).
    #[must_use]
    pub fn kl_divergence_value(&self) -> Option<f32> {
        self.kl_div
    }

    /// Cosine similarity of two slices.
    ///
    /// Delegates to `crate::nn::functional::cosine_similarity_slice`. Per this
    /// module's tests, it yields `0.0` for empty input, mismatched lengths, or a
    /// zero-norm vector.
    #[must_use]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        crate::nn::functional::cosine_similarity_slice(a, b)
    }

    /// Kullback–Leibler divergence `KL(p || q) = Σ pᵢ · ln(pᵢ / qᵢ)`.
    ///
    /// Terms with `pᵢ <= 1e-10` are skipped, since `p·ln p → 0`. Each `qᵢ` is
    /// floored at `1e-10`, so a zero in `q` gives a large finite value rather than
    /// `inf`/`NaN`. Inputs are not renormalised; pass probability vectors.
    ///
    /// Returns `f32::INFINITY` when the slices differ in length or are empty.
    #[must_use]
    pub fn kl_divergence(p: &[f32], q: &[f32]) -> f32 {
        if p.len() != q.len() || p.is_empty() {
            return f32::INFINITY;
        }
        p.iter()
            .zip(q)
            .filter(|&(&pi, _)| pi > Self::EPSILON)
            .map(|(&pi, &qi)| pi * (pi / qi.max(Self::EPSILON)).ln())
            .fold(0.0, |acc, term| acc + term)
    }

    /// Approximate 1-D Wasserstein-1 (earth mover's) distance between two samples.
    ///
    /// Both samples are sorted and linearly resampled to `max(a.len(), b.len())`
    /// points. The result is the mean absolute difference of matching order
    /// statistics; for equal-length inputs this is exactly the empirical W1.
    ///
    /// Returns `0.0` when either input is empty.
    #[must_use]
    pub fn wasserstein_1d(a: &[f32], b: &[f32]) -> f32 {
        if a.is_empty() || b.is_empty() {
            return 0.0;
        }
        let a_sorted = Self::sorted_copy(a);
        let b_sorted = Self::sorted_copy(b);
        let n = a_sorted.len().max(b_sorted.len());
        let a_resampled = Self::resample(&a_sorted, n);
        let b_resampled = Self::resample(&b_sorted, n);
        let total = a_resampled
            .iter()
            .zip(&b_resampled)
            .map(|(x, y)| (x - y).abs())
            .sum::<f32>();
        total / n as f32
    }

    /// Returns an ascending copy of `data`.
    ///
    /// Uses IEEE 754 total ordering. A `partial_cmp`-based comparator is not a total
    /// order once NaNs are present, and the standard sort may panic on such
    /// comparators.
    fn sorted_copy(data: &[f32]) -> Vec<f32> {
        let mut out = data.to_vec();
        out.sort_unstable_by(f32::total_cmp);
        out
    }

    /// Linearly interpolates sorted samples `data` onto `target_len` evenly spaced
    /// positions, with both endpoints included.
    ///
    /// Intended for up-sampling (`data.len() <= target_len`), which is how
    /// [`Delta::wasserstein_1d`] calls it. Empty `data` yields an empty vector
    /// instead of underflowing.
    fn resample(data: &[f32], target_len: usize) -> Vec<f32> {
        if data.len() == target_len {
            return data.to_vec();
        }
        let Some(last) = data.len().checked_sub(1) else {
            return Vec::new();
        };
        // `saturating_sub` only matters when `target_len == 0` (empty range below).
        let steps = target_len.saturating_sub(1) as f32;
        (0..target_len)
            .map(|i| {
                let idx = (i as f32 / steps) * last as f32;
                let low = idx.floor() as usize;
                let high = (low + 1).min(last);
                let frac = idx - low as f32;
                data[low] * (1.0 - frac) + data[high] * frac
            })
            .collect()
    }

    /// Like [`Delta::compute`], but also fills in the KL divergence.
    ///
    /// Each side's raw data becomes a distribution by taking absolute values and
    /// dividing by their sum; the value stored is `KL(our || gt)`. It stays `None`
    /// when either side lacks data or has total absolute mass `<= 1e-10`. It is
    /// `Some(f32::INFINITY)` when the two data sets differ in length.
    #[must_use]
    pub fn compute_all(our: &GroundTruth, gt: &GroundTruth) -> Self {
        let mut delta = Self::compute(our, gt);
        if let (Some(our_data), Some(gt_data)) = (our.data(), gt.data()) {
            let our_sum: f32 = our_data.iter().map(|x| x.abs()).sum();
            let gt_sum: f32 = gt_data.iter().map(|x| x.abs()).sum();
            if our_sum > Self::EPSILON && gt_sum > Self::EPSILON {
                let p: Vec<f32> = our_data.iter().map(|x| x.abs() / our_sum).collect();
                let q: Vec<f32> = gt_data.iter().map(|x| x.abs() / gt_sum).collect();
                delta.kl_div = Some(Self::kl_divergence(&p, &q));
            }
        }
        delta
    }
}
impl Default for Delta {
fn default() -> Self {
Self::from_percent(0.0)
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the metric helpers and constructors on `Delta`.
    use super::*;

    #[test]
    fn test_cosine_identical() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [1.0_f32, 2.0, 3.0];
        let cos = Delta::cosine_similarity(&a, &b);
        assert!((cos - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_opposite() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [-1.0_f32, -2.0, -3.0];
        let cos = Delta::cosine_similarity(&a, &b);
        assert!((cos + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        let cos = Delta::cosine_similarity(&a, &b);
        assert!(cos.abs() < 1e-6);
    }

    #[test]
    fn test_kl_identical() {
        let p = [0.25_f32, 0.25, 0.25, 0.25];
        let q = [0.25_f32, 0.25, 0.25, 0.25];
        let kl = Delta::kl_divergence(&p, &q);
        assert!(kl.abs() < 1e-6);
    }

    #[test]
    fn test_kl_different() {
        let p = [0.9_f32, 0.1];
        let q = [0.5_f32, 0.5];
        let kl = Delta::kl_divergence(&p, &q);
        assert!(kl > 0.0);
        // 0.9·ln(1.8) + 0.1·ln(0.2) ≈ 0.368064
        assert!((kl - 0.368_064).abs() < 1e-5);
    }

    #[test]
    fn test_wasserstein_identical() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [1.0_f32, 2.0, 3.0];
        let w = Delta::wasserstein_1d(&a, &b);
        assert!(w < 1e-6);
    }

    #[test]
    fn test_wasserstein_shifted() {
        let a = [0.0_f32, 1.0, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        let w = Delta::wasserstein_1d(&a, &b);
        assert!((w - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_delta_compute_with_data() {
        let ours = GroundTruth::from_slice(&[1.0, 2.0, 3.0]);
        let truth = GroundTruth::from_slice(&[1.0, 2.0, 3.0]);
        let delta = Delta::compute(&ours, &truth);
        let cos = delta
            .cosine()
            .expect("both sides carry data, so cosine must be set");
        assert!((cos - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_delta_compute_sign_flip() {
        let ours = GroundTruth::from_stats(0.5, 1.0);
        let truth = GroundTruth::from_stats(-0.5, 1.0);
        let delta = Delta::compute(&ours, &truth);
        assert!(delta.is_sign_flipped());
    }

    #[test]
    fn test_delta_compute_no_sign_flip() {
        let ours = GroundTruth::from_stats(0.5, 1.0);
        let truth = GroundTruth::from_stats(0.6, 1.0);
        let delta = Delta::compute(&ours, &truth);
        assert!(!delta.is_sign_flipped());
    }

    #[test]
    fn test_delta_from_percent() {
        let delta = Delta::from_percent(42.5);
        assert!((delta.percent() - 42.5).abs() < 1e-6);
        assert!(delta.mean_delta().abs() < 1e-6);
        assert!(delta.std_delta().abs() < 1e-6);
    }

    #[test]
    fn test_delta_from_stats() {
        let delta = Delta::from_stats(0.1, 0.2);
        assert!((delta.mean_delta() - 0.1).abs() < 1e-5);
        assert!((delta.std_delta() - 0.2).abs() < 1e-5);
        assert!((delta.percent() - 30.0).abs() < 1e-4);
    }

    #[test]
    fn test_delta_default() {
        let delta = Delta::default();
        assert!(delta.percent().abs() < 1e-6);
    }

    #[test]
    fn test_delta_kl_divergence_value() {
        let delta = Delta::from_percent(5.0);
        assert!(delta.kl_divergence_value().is_none());
    }

    #[test]
    fn test_delta_compute_all_with_kl() {
        let ours = GroundTruth::from_slice(&[0.1, 0.2, 0.3, 0.4]);
        let truth = GroundTruth::from_slice(&[0.1, 0.2, 0.3, 0.4]);
        let delta = Delta::compute_all(&ours, &truth);
        let kl = delta
            .kl_divergence_value()
            .expect("both sides carry non-zero data, so KL must be set");
        assert!(kl.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_empty() {
        let cos = Delta::cosine_similarity(&[], &[]);
        assert!(cos.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_different_lengths() {
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        let cos = Delta::cosine_similarity(&a, &b);
        assert!(cos.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_zero_norm() {
        let a = [0.0_f32, 0.0, 0.0];
        let b = [1.0_f32, 2.0, 3.0];
        let cos = Delta::cosine_similarity(&a, &b);
        assert!(cos.abs() < 1e-6);
    }

    #[test]
    fn test_kl_empty() {
        let kl = Delta::kl_divergence(&[], &[]);
        assert!(kl.is_infinite());
    }

    #[test]
    fn test_kl_different_lengths() {
        let p = [0.5_f32, 0.5];
        let q = [0.33_f32, 0.33, 0.34];
        let kl = Delta::kl_divergence(&p, &q);
        assert!(kl.is_infinite());
    }

    #[test]
    fn test_wasserstein_empty() {
        let w = Delta::wasserstein_1d(&[], &[1.0]);
        assert!(w.abs() < 1e-6);
    }

    #[test]
    fn test_wasserstein_different_lengths() {
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0, 4.0];
        let w = Delta::wasserstein_1d(&a, &b);
        assert!(w.is_finite());
        // `a` resamples to [1, 4/3, 5/3, 2]; |diffs| = [0, 2/3, 4/3, 2], mean = 1.
        assert!((w - 1.0).abs() < 1e-5);
    }
}