hipparchus_metrics/
distribution.rs

use hipparchus_mean::Fp;
use crate::metrics::Metrics;

/// Metrics for probability distributions
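///
/// A minimal usage sketch (module paths assumed from this crate's layout; the
/// `Metrics` trait must be in scope for `measure` to resolve):
///
/// ```ignore
/// use hipparchus_metrics::distribution::DistributionMetrics;
/// use hipparchus_metrics::metrics::Metrics;
///
/// let p = vec![0.5f64, 0.5];
/// let q = vec![0.9f64, 0.1];
/// let kl = DistributionMetrics::KullbackLeiblerDivergence.measure(&p, &q);
/// ```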
#[repr(i32)]
#[derive(Clone, PartialEq, Debug)]
pub enum DistributionMetrics
{
    /// Cross entropy: -Σ p(i)·ln q(i)
    CrossEntropy = 1,

    /// Kullback-Leibler divergence: Σ p(i)·(ln p(i) - ln q(i))
    KullbackLeiblerDivergence = 2,

    /// Jensen-Shannon divergence: (KL(p‖m) + KL(q‖m)) / 2, where m = (p + q) / 2
    JensenShannonDivergence = 3,

    /// Hellinger distance: sqrt(Σ (√p(i) - √q(i))² / 2)
    Hellinger = 4,
}

impl<T:Fp> Metrics<&[T], T> for DistributionMetrics
{
    fn measure(self, x:&[T], y:&[T]) -> T
    {
        let it = x.iter().zip(y.iter());
        match self
        {
            DistributionMetrics::CrossEntropy => it.fold(T::zero(), |agg, (&p, &q)|
            {
                agg - p.mul(q.ln())
            }),
            DistributionMetrics::KullbackLeiblerDivergence => it.fold(T::zero(), |agg, (&p, &q)|
            {
                agg + p.mul(p.ln() - q.ln())
            }),
            DistributionMetrics::JensenShannonDivergence =>
            {
                // m is the element-wise midpoint distribution (x + y) / 2;
                // JS is then the average of the two KL divergences against m.
                let half = T::from(0.5f64).unwrap();
                let v = it.map(|(&p, &q)| p.add(q).mul(half)).collect::<Vec<T>>();
                let m = v.as_slice();
                let klxm = DistributionMetrics::KullbackLeiblerDivergence.measure(x, m);
                let klym = DistributionMetrics::KullbackLeiblerDivergence.measure(y, m);
                (klxm + klym) * half
            }
            DistributionMetrics::Hellinger => it.fold(T::zero(), |agg, (&p, &q)|
            {
                agg + (p.sqrt() - q.sqrt()).powi(2)
            }).div(T::from(2).unwrap()).sqrt(),
        }
    }
}
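
// Hand-worked check of the values asserted below: for p = q = [0.5, 0.5],
// cross entropy is -(0.5 * ln 0.5 + 0.5 * ln 0.5) = ln 2 ≈ 0.693147, the first
// test case; KL, JS and Hellinger all collapse to zero whenever p == q, which
// is what test_distribution_zero exercises.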

#[cfg(test)]
mod tests
{
    use super::*;
    use rstest::*;
    use float_cmp::assert_approx_eq;

    #[rstest]
    #[case(vec![0.5, 0.5], vec![0.5, 0.5], DistributionMetrics::CrossEntropy, 0.693147)]
    #[case(vec![0.00001, 0.99999], vec![0.99999, 0.00001], DistributionMetrics::KullbackLeiblerDivergence, 11.512684)]
    #[case(vec![0.00001, 0.99999], vec![0.99999, 0.00001], DistributionMetrics::JensenShannonDivergence, 0.6930221)]
    #[case(vec![0.0, 1.0], vec![1.0, 0.0], DistributionMetrics::Hellinger, 1.0)]
    fn test_distribution(#[case] d1: Vec<f32>, #[case] d2: Vec<f32>, #[case] m: DistributionMetrics, #[case] expected: f32)
    {
        let actual = m.measure(&d1, &d2);
        assert_approx_eq!(f32, expected, actual);
    }

    #[rstest]
    #[case(vec![0.5, 0.5], DistributionMetrics::KullbackLeiblerDivergence)]
    #[case(vec![0.5, 0.5], DistributionMetrics::JensenShannonDivergence)]
    #[case(vec![0.5, 0.5], DistributionMetrics::Hellinger)]
    fn test_distribution_zero(#[case] d: Vec<f32>, #[case] m: DistributionMetrics)
    {
        let actual = m.measure(&d, &d);
        assert_approx_eq!(f32, 0.0, actual);
    }
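
    // A hedged addition to the original suite: Jensen-Shannon divergence is
    // symmetric in its arguments (unlike KL), so measuring in either order
    // must agree. The input vectors here are illustrative, not canonical.
    #[rstest]
    #[case(vec![0.3, 0.7], vec![0.6, 0.4])]
    #[case(vec![0.00001, 0.99999], vec![0.99999, 0.00001])]
    fn test_js_symmetry(#[case] d1: Vec<f32>, #[case] d2: Vec<f32>)
    {
        let xy = DistributionMetrics::JensenShannonDivergence.measure(&d1, &d2);
        let yx = DistributionMetrics::JensenShannonDivergence.measure(&d2, &d1);
        assert_approx_eq!(f32, xy, yx);
    }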
}