hipparchus_metrics/
distribution.rs1use hipparchus_mean::Fp;
2use crate::metrics::Metrics;
3
/// Measures of dissimilarity between two discrete probability distributions.
///
/// A variant is evaluated through `Metrics::measure(self, x, y)`, where `x` and
/// `y` are slices of probabilities paired index by index (below written `p` and
/// `q`). All logarithms are natural logarithms, so information quantities are in nats.
///
/// The enum carries no data, so it is `Copy`: a metric can be reused after a
/// call to `measure` (which takes `self` by value) without cloning.
#[repr(i32)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DistributionMetrics
{
    /// Cross entropy `H(p, q) = -Σ pᵢ·ln(qᵢ)`.
    CrossEntropy = 1,

    /// Kullback–Leibler divergence `D_KL(p ‖ q) = Σ pᵢ·(ln pᵢ − ln qᵢ)`; not symmetric in `p`, `q`.
    KullbackLeiblerDivergence = 2,

    /// Jensen–Shannon divergence `½·D_KL(p ‖ m) + ½·D_KL(q ‖ m)` with mixture `m = ½·(p + q)`.
    JensenShannonDivergence = 3,

    /// Hellinger distance `sqrt(½·Σ (√pᵢ − √qᵢ)²)`.
    Hellinger = 4,
}
21
22impl<T:Fp> Metrics<&[T], T> for DistributionMetrics
23{
24 fn measure(self, x:&[T], y:&[T]) -> T
25 {
26 let it = x.iter().zip(y.iter());
27 match self
28 {
29 DistributionMetrics::CrossEntropy => it.fold(T::zero(), | agg, (&p, &q)|
30 {
31 agg - p.mul(q.ln())
32 }),
33 DistributionMetrics::KullbackLeiblerDivergence => it.fold(T::zero(), | agg, (&p, &q)|
34 {
35 agg + p.mul(p.ln()-q.ln())
36 }),
37 DistributionMetrics::JensenShannonDivergence =>
38 {
39 let half = T::from(0.5f64).unwrap();
40 let v = it.map(|(p, &q)| p.add(q).mul(half) ).collect::<Vec<T>>();
41 let m = v.as_slice().try_into().unwrap();
42 let klxm = DistributionMetrics::KullbackLeiblerDivergence.measure(x, m);
43 let klym = DistributionMetrics::KullbackLeiblerDivergence.measure(y, m);
44 (klxm+klym) * half
45 }
46 DistributionMetrics::Hellinger => it.fold(T::zero(), | agg, (&p, &q)|
47 {
48 agg + (p.sqrt() - q.sqrt()).powi(2)
49 }).div(T::from(2).unwrap()).sqrt(),
50 }
51 }
52}
53
#[cfg(test)]
mod tests
{
    use super::*;
    use rstest::*;
    use float_cmp::assert_approx_eq;

    /// Every metric, evaluated on a fixed pair of distributions, must reproduce a
    /// precomputed reference value within `float_cmp`'s default `f32` margin.
    #[rstest]
    #[case(vec![0.5, 0.5], vec![0.5, 0.5], DistributionMetrics::CrossEntropy, 0.693147)]
    #[case(vec![0.00001, 0.99999], vec![0.99999, 0.00001], DistributionMetrics::KullbackLeiblerDivergence, 11.512684)]
    #[case(vec![0.00001, 0.99999], vec![0.99999, 0.00001], DistributionMetrics::JensenShannonDivergence, 0.6930221)]
    #[case(vec![0.0, 1.0], vec![1.0, 0.0], DistributionMetrics::Hellinger, 1.0)]
    fn test_distribution(
        #[case] p: Vec<f32>,
        #[case] q: Vec<f32>,
        #[case] metric: DistributionMetrics,
        #[case] reference: f32,
    )
    {
        let measured = metric.measure(&p, &q);
        assert_approx_eq!(f32, reference, measured);
    }

    /// Comparing a distribution with itself must give (approximately) zero.
    #[rstest]
    #[case(vec![0.5, 0.5], DistributionMetrics::KullbackLeiblerDivergence)]
    #[case(vec![0.5, 0.5], DistributionMetrics::JensenShannonDivergence)]
    #[case(vec![0.5, 0.5], DistributionMetrics::Hellinger)]
    fn test_distribution_zero(#[case] dist: Vec<f32>, #[case] metric: DistributionMetrics)
    {
        let self_distance = metric.measure(&dist, &dist);
        assert_approx_eq!(f32, 0.0, self_distance);
    }
}