// fast_distances/distances/ll_dirichlet.rs

1use std::{f64::consts::PI, iter::Sum};
2
3use num::Float;
4
5fn log_single_beta<T: Float>(x: T) -> T {
6    T::ln(T::from(2.0).unwrap()) * (-T::from(2.0).unwrap() * x + T::from(0.5).unwrap())
7        + T::from(0.5).unwrap() * (T::from(2.0).unwrap() * T::from(PI).unwrap() / x).ln()
8        + T::from(0.125).unwrap() / x
9}
10
11fn log_beta<T: Float>(x: T, y: T) -> T
12where
13    T: Float,
14{
15    let a = x.min(y);
16    let b = x.max(y);
17
18    if b < T::from(5.0).unwrap() {
19        let mut value = -T::ln(b);
20        for i in 1..a.to_i64().unwrap() {
21            let ii = T::from(i).unwrap();
22            value = value + T::ln(ii) - T::ln(b + ii);
23        }
24        value
25    } else {
26        log_single_beta(x) + log_single_beta(y) - log_single_beta(x + y)
27    }
28}
29
30/// Calculates the symmetric relative log likelihood (log Dirichlet likelihood) of rolling
31/// `data2` versus `data1` in `n2` trials on a die that rolled `data1` in `n1` trials.
32///
33/// The formula used is based on the Dirichlet-Multinomial model, and it computes the difference
34/// in likelihood between the two sets of data under a Dirichlet distribution assumption. This
35/// measure is useful for comparing the distribution of counts between two categorical datasets,
36/// typically for hypothesis testing or evaluating model performance when categorical data is involved.
37///
38/// The equation is as follows:
39///
40/// ..math::
41///     D(data1, data2) = \sqrt{ \frac{1}{n2} \left( \log \beta(data1, data2) - \log \beta(n1, n2) - ( \text{self\_denom2} - \log \text{single\_beta}(n2) ) \right) + \frac{1}{n1} \left( \log \beta(data2, data1) - \log \beta(n2, n1) - ( \text{self\_denom1} - \log \text{single\_beta}(n1) ) \right) }
42///
43/// # Arguments
44///
45/// * `data1` - A slice of `T` values representing the first data set (e.g., counts from one die roll).
46/// * `data2` - A slice of `T` values representing the second data set (e.g., counts from another die roll).
47///
48/// # Returns
49///
50/// Returns a `T` value representing the log likelihood of `data2` relative to `data1`. A higher value indicates that `data2` is more likely given `data1`.
51///
52/// # Examples
53///
54/// ```
55/// use fast_distances::*;
56/// let data1: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
57/// let data2: Vec<f64> = vec![5.0, 6.0, 7.0, 8.0];
58/// let result = ll_dirichlet(&data1, &data2);
59/// println!("Log Dirichlet likelihood: {}", result);
60/// ```
61pub fn ll_dirichlet<T: Float>(data1: &[T], data2: &[T]) -> T
62where
63    T: Float + Sum,
64{
65    let n1: T = data1.iter().copied().sum();
66    let n2: T = data2.iter().copied().sum();
67
68    let mut log_b = T::from(0.0).unwrap();
69    let mut self_denom1 = T::from(0.0).unwrap();
70    let mut self_denom2 = T::from(0.0).unwrap();
71
72    for i in 0..data1.len() {
73        if data1[i] * data2[i] > T::from(0.9).unwrap() {
74            log_b = log_b + log_beta(data1[i], data2[i]);
75            self_denom1 = self_denom1 + log_single_beta(data1[i]);
76            self_denom2 = self_denom2 + log_single_beta(data2[i]);
77        } else {
78            if data1[i] > T::from(0.9).unwrap() {
79                self_denom1 = self_denom1 + log_single_beta(data1[i]);
80            }
81
82            if data2[i] > T::from(0.9).unwrap() {
83                self_denom2 = self_denom2 + log_single_beta(data2[i]);
84            }
85        }
86    }
87
88    T::sqrt(
89        T::from(1.0).unwrap() / n2
90            * (log_b - log_beta(n1, n2) - (self_denom2 - log_single_beta(n2)))
91            + T::from(1.0).unwrap() / n1
92                * (log_b - log_beta(n2, n1) - (self_denom1 - log_single_beta(n1))),
93    )
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within `tol` of `expected` (and is not NaN).
    ///
    /// The pinned reference values come from `ln`/`sqrt`, whose last-ulp results can differ
    /// between platforms, so exact `assert_eq!` on them is brittle.
    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
        assert!(
            (actual - expected).abs() <= tol,
            "{what}: expected {expected} ± {tol}, got {actual}"
        );
    }

    #[test]
    fn test_ll_dirichlet_f32() {
        let data1: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let data2: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];

        let result = ll_dirichlet(&data1, &data2);
        assert_close(f64::from(result), 0.36789307, 1e-5, "ll_dirichlet with f32");
    }

    #[test]
    fn test_ll_dirichlet_f64() {
        let data1: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let data2: Vec<f64> = vec![5.0, 6.0, 7.0, 8.0];

        let result = ll_dirichlet(&data1, &data2);
        assert_close(result, 0.36789301898248805, 1e-12, "ll_dirichlet with f64");
    }

    #[test]
    fn test_ll_dirichlet_is_symmetric() {
        let data1: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let data2: Vec<f64> = vec![5.0, 6.0, 7.0, 8.0];

        // Swapping the arguments only swaps the two summed terms, so the result is identical.
        assert_eq!(ll_dirichlet(&data1, &data2), ll_dirichlet(&data2, &data1));
    }
}