fast_distances/distances/
ll_dirichlet.rs1use std::{f64::consts::PI, iter::Sum};
2
3use num::Float;
4
5fn log_single_beta<T: Float>(x: T) -> T {
6 T::ln(T::from(2.0).unwrap()) * (-T::from(2.0).unwrap() * x + T::from(0.5).unwrap())
7 + T::from(0.5).unwrap() * (T::from(2.0).unwrap() * T::from(PI).unwrap() / x).ln()
8 + T::from(0.125).unwrap() / x
9}
10
11fn log_beta<T: Float>(x: T, y: T) -> T
12where
13 T: Float,
14{
15 let a = x.min(y);
16 let b = x.max(y);
17
18 if b < T::from(5.0).unwrap() {
19 let mut value = -T::ln(b);
20 for i in 1..a.to_i64().unwrap() {
21 let ii = T::from(i).unwrap();
22 value = value + T::ln(ii) - T::ln(b + ii);
23 }
24 value
25 } else {
26 log_single_beta(x) + log_single_beta(y) - log_single_beta(x + y)
27 }
28}
29
30pub fn ll_dirichlet<T: Float>(data1: &[T], data2: &[T]) -> T
62where
63 T: Float + Sum,
64{
65 let n1: T = data1.iter().copied().sum();
66 let n2: T = data2.iter().copied().sum();
67
68 let mut log_b = T::from(0.0).unwrap();
69 let mut self_denom1 = T::from(0.0).unwrap();
70 let mut self_denom2 = T::from(0.0).unwrap();
71
72 for i in 0..data1.len() {
73 if data1[i] * data2[i] > T::from(0.9).unwrap() {
74 log_b = log_b + log_beta(data1[i], data2[i]);
75 self_denom1 = self_denom1 + log_single_beta(data1[i]);
76 self_denom2 = self_denom2 + log_single_beta(data2[i]);
77 } else {
78 if data1[i] > T::from(0.9).unwrap() {
79 self_denom1 = self_denom1 + log_single_beta(data1[i]);
80 }
81
82 if data2[i] > T::from(0.9).unwrap() {
83 self_denom2 = self_denom2 + log_single_beta(data2[i]);
84 }
85 }
86 }
87
88 T::sqrt(
89 T::from(1.0).unwrap() / n2
90 * (log_b - log_beta(n1, n2) - (self_denom2 - log_single_beta(n2)))
91 + T::from(1.0).unwrap() / n1
92 * (log_b - log_beta(n2, n1) - (self_denom1 - log_single_beta(n1))),
93 )
94}
95
96#[cfg(test)]
97mod tests {
98 use super::*;
99
100 #[test]
101 fn test_ll_dirichlet_f32() {
102 let data1: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
103 let data2: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];
104
105 let result = ll_dirichlet(&data1, &data2);
106 assert_eq!(result, 0.36789307, "ll_dirichlet with f32");
107 }
108
109 #[test]
110 fn test_ll_dirichlet_f64() {
111 let data1: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
112 let data2: Vec<f64> = vec![5.0, 6.0, 7.0, 8.0];
113
114 let result = ll_dirichlet(&data1, &data2);
115 assert_eq!(result, 0.36789301898248805, "ll_dirichlet with f64");
116 }
117}