1use nalgebra::{DVector, DMatrix};
4use serde::{Serialize, Deserialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct DiscreteDistribution {
8 pub support: Vec<f64>,
9 pub probabilities: Vec<f64>,
10}
11
12impl DiscreteDistribution {
13 pub fn new(support: Vec<f64>, probabilities: Vec<f64>) -> Result<Self, String> {
14 let sum: f64 = probabilities.iter().sum();
15 if (sum - 1.0).abs() > 0.01 {
16 return Err(format!("Probabilities must sum to 1.0, got {}", sum));
17 }
18 Ok(Self { support, probabilities })
19 }
20
21 pub fn normalized(support: Vec<f64>, mut probabilities: Vec<f64>) -> Self {
22 let sum: f64 = probabilities.iter().sum();
23 if sum > 0.0 { for p in &mut probabilities { *p /= sum; } }
24 Self { support, probabilities }
25 }
26
27 pub fn mean(&self) -> f64 {
28 self.support.iter().zip(self.probabilities.iter()).map(|(x, p)| x * p).sum()
29 }
30
31 pub fn variance(&self) -> f64 {
32 let mu = self.mean();
33 self.support.iter().zip(self.probabilities.iter()).map(|(x, p)| p * (x - mu).powi(2)).sum()
34 }
35
36 pub fn entropy(&self) -> f64 {
37 -self.probabilities.iter().filter(|&&p| p > 1e-15).map(|&p| p * p.ln()).sum::<f64>()
38 }
39}
40
41pub fn cost_matrix(a: &DiscreteDistribution, b: &DiscreteDistribution, p: f64) -> DMatrix<f64> {
42 let n = a.support.len();
43 let m = b.support.len();
44 let mut c = DMatrix::zeros(n, m);
45 for i in 0..n {
46 for j in 0..m {
47 c[(i, j)] = (a.support[i] - b.support[j]).abs().powf(p);
48 }
49 }
50 c
51}
52
53pub fn wasserstein_1d(a: &DiscreteDistribution, b: &DiscreteDistribution) -> f64 {
54 let mut a_pairs: Vec<_> = a.support.iter().zip(a.probabilities.iter()).collect();
55 let mut b_pairs: Vec<_> = b.support.iter().zip(b.probabilities.iter()).collect();
56 a_pairs.sort_by(|x, y| x.0.partial_cmp(y.0).unwrap());
57 b_pairs.sort_by(|x, y| x.0.partial_cmp(y.0).unwrap());
58
59 let mut all_points: Vec<f64> = a.support.iter().chain(b.support.iter()).copied().collect();
60 all_points.sort_by(|a, b| a.partial_cmp(b).unwrap());
61 all_points.dedup();
62
63 let mut w1 = 0.0;
64 for window in all_points.windows(2) {
65 let dx = window[1] - window[0];
66 let mid = (window[0] + window[1]) / 2.0;
67 let fa = cdf_at(&a_pairs, mid);
68 let fb = cdf_at(&b_pairs, mid);
69 w1 += (fa - fb).abs() * dx;
70 }
71 w1
72}
73
74fn cdf_at(sorted_pairs: &[(&f64, &f64)], x: f64) -> f64 {
75 sorted_pairs.iter().filter(|(xi, _)| **xi <= x).map(|(_, p)| **p).sum()
76}
77
78pub fn sinkhorn(a: &DVector<f64>, b: &DVector<f64>, cost: &DMatrix<f64>, reg: f64, max_iter: usize, tol: f64) -> (DMatrix<f64>, DVector<f64>, DVector<f64>) {
79 let n = a.nrows();
80 let m = b.nrows();
81 let mut u = DVector::from_element(n, 1.0 / n as f64);
82 let mut v = DVector::from_element(m, 1.0 / m as f64);
83 let k = cost.map(|c| (-c / reg).exp());
84
85 for _ in 0..max_iter {
86 let kv = &k * &v;
87 for i in 0..n { if kv[i].abs() > 1e-15 { u[i] = a[i] / kv[i]; } }
88 let ktu = &k.transpose() * &u;
89 for j in 0..m { if ktu[j].abs() > 1e-15 { v[j] = b[j] / ktu[j]; } }
90 let ktu_new = &k.transpose() * &u;
91 let err: f64 = (&ktu_new - b).iter().map(|x| x.abs()).sum();
92 if err < tol { break; }
93 }
94
95 let mut plan = DMatrix::zeros(n, m);
96 for i in 0..n { for j in 0..m { plan[(i, j)] = u[i] * k[(i, j)] * v[j]; } }
97 (plan, u, v)
98}
99
100pub fn kl_divergence(p: &DVector<f64>, q: &DVector<f64>) -> f64 {
101 p.iter().zip(q.iter())
102 .filter(|(&pi, _)| pi > 1e-15)
103 .map(|(&pi, &qi)| pi * (pi / qi.max(1e-15)).ln())
104 .sum()
105}
106
107pub fn js_divergence(p: &DVector<f64>, q: &DVector<f64>) -> f64 {
108 let m = (p + q).scale(0.5);
109 0.5 * kl_divergence(p, &m) + 0.5 * kl_divergence(q, &m)
110}
111
112pub fn wasserstein_gradient_step(positions: &mut [f64], target: &DiscreteDistribution, step_size: f64) {
113 let n = positions.len();
114 positions.sort_by(|a, b| a.partial_cmp(b).unwrap());
115 for i in 0..n {
116 let quantile = (i as f64 + 0.5) / n as f64;
117 let target_val = quantile_function(target, quantile);
118 positions[i] += step_size * (target_val - positions[i]);
119 }
120}
121
122pub fn quantile_function(dist: &DiscreteDistribution, q: f64) -> f64 {
123 let mut pairs: Vec<_> = dist.support.iter().zip(dist.probabilities.iter()).collect();
124 pairs.sort_by(|a, b| a.0.partial_cmp(b.0).unwrap());
125 let mut cumsum = 0.0;
126 for (&x, &p) in &pairs {
127 cumsum += p;
128 if cumsum >= q { return x; }
129 }
130 *pairs.last().unwrap().0
131}
132
// Unit tests. Expected values are worked out by hand so that a regression in
// any routine changes an asserted number rather than slipping past a loose bound.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_discrete_distribution_new() {
        assert!(DiscreteDistribution::new(vec![0.0, 1.0], vec![0.5, 0.5]).is_ok());
        assert!(DiscreteDistribution::new(vec![0.0, 1.0], vec![0.3, 0.3]).is_err());
    }

    #[test]
    fn test_discrete_distribution_normalized() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 3.0]);
        let sum: f64 = d.probabilities.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
        assert_relative_eq!(d.probabilities[0], 1.0 / 6.0, epsilon = 1e-12);
        assert_relative_eq!(d.probabilities[2], 0.5, epsilon = 1e-12);
    }

    #[test]
    fn test_discrete_distribution_mean() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        assert_relative_eq!(d.mean(), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_discrete_distribution_variance() {
        let d = DiscreteDistribution::normalized(vec![0.0, 2.0], vec![1.0, 1.0]);
        assert_relative_eq!(d.variance(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_entropy() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        assert_relative_eq!(d.entropy(), 2.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_wasserstein_1d_positive() {
        // Shifting every atom by 1 moves all mass a distance of 1.
        let a = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        let b = DiscreteDistribution::normalized(vec![1.0, 2.0], vec![1.0, 1.0]);
        let w1 = wasserstein_1d(&a, &b);
        assert_relative_eq!(w1, 1.0, epsilon = 1e-12);
        assert_relative_eq!(wasserstein_1d(&b, &a), w1, epsilon = 1e-12);
    }

    #[test]
    fn test_wasserstein_same() {
        let a = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        let w1 = wasserstein_1d(&a, &a);
        assert_relative_eq!(w1, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_sinkhorn() {
        let a = DVector::from_vec(vec![0.5, 0.5]);
        let b = DVector::from_vec(vec![0.5, 0.5]);
        let cost = DMatrix::from_row_slice(2, 2, &[0.0, 1.0, 1.0, 0.0]);
        let (plan, u, v) = sinkhorn(&a, &b, &cost, 0.1, 100, 1e-6);
        assert_eq!(plan.nrows(), 2);
        assert_eq!(plan.ncols(), 2);
        assert_eq!(u.nrows(), 2);
        assert_eq!(v.nrows(), 2);
        // The plan must reproduce both marginals.
        for i in 0..2 {
            assert_relative_eq!(plan[(i, 0)] + plan[(i, 1)], a[i], epsilon = 1e-6);
            assert_relative_eq!(plan[(0, i)] + plan[(1, i)], b[i], epsilon = 1e-6);
        }
        // Zero-cost diagonal moves should carry nearly all of the mass.
        assert!(plan[(0, 0)] > plan[(0, 1)]);
        assert!(plan[(1, 1)] > plan[(1, 0)]);
    }

    #[test]
    fn test_kl_divergence_same() {
        let p = DVector::from_vec(vec![0.5, 0.5]);
        let kl = kl_divergence(&p, &p);
        assert_relative_eq!(kl, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kl_divergence_positive() {
        // KL([1, 0] || [0.5, 0.5]) = 1 * ln(1 / 0.5) = ln 2.
        let p = DVector::from_vec(vec![1.0, 0.0]);
        let q = DVector::from_vec(vec![0.5, 0.5]);
        let kl = kl_divergence(&p, &q);
        assert_relative_eq!(kl, 2.0_f64.ln(), epsilon = 1e-12);
    }

    #[test]
    fn test_js_symmetry() {
        let p = DVector::from_vec(vec![0.8, 0.2]);
        let q = DVector::from_vec(vec![0.3, 0.7]);
        let js_pq = js_divergence(&p, &q);
        let js_qp = js_divergence(&q, &p);
        assert_relative_eq!(js_pq, js_qp, epsilon = 1e-10);
        // JS divergence in nats lies in [0, ln 2].
        assert!(js_pq > 0.0);
        assert!(js_pq <= 2.0_f64.ln());
    }

    #[test]
    fn test_quantile_function() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0, 2.0], vec![1.0, 1.0, 1.0]);
        assert_relative_eq!(quantile_function(&d, 0.25), 0.0, epsilon = 1e-12);
        assert_relative_eq!(quantile_function(&d, 0.5), 1.0, epsilon = 1e-12);
        assert_relative_eq!(quantile_function(&d, 0.9), 2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_wasserstein_gradient_step() {
        let mut positions = vec![0.0, 0.1, 0.2];
        let target = DiscreteDistribution::normalized(vec![5.0, 6.0, 7.0], vec![1.0, 1.0, 1.0]);
        wasserstein_gradient_step(&mut positions, &target, 0.5);
        // Each particle covers half the gap to its matching quantile (5, 6, 7).
        assert_relative_eq!(positions[0], 2.5, epsilon = 1e-12);
        assert_relative_eq!(positions[1], 3.05, epsilon = 1e-12);
        assert_relative_eq!(positions[2], 3.6, epsilon = 1e-12);
    }
}