Skip to main content

lau_diffusion_agents/
transport_diffusion.rs

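//! Discrete optimal-transport utilities: 1-D Wasserstein distance, entropy-regularized (Sinkhorn) transport plans, KL / Jensen–Shannon divergences, and a particle-based Wasserstein gradient-flow step.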
2
3use nalgebra::{DVector, DMatrix};
4use serde::{Serialize, Deserialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct DiscreteDistribution {
8    pub support: Vec<f64>,
9    pub probabilities: Vec<f64>,
10}
11
12impl DiscreteDistribution {
13    pub fn new(support: Vec<f64>, probabilities: Vec<f64>) -> Result<Self, String> {
14        let sum: f64 = probabilities.iter().sum();
15        if (sum - 1.0).abs() > 0.01 {
16            return Err(format!("Probabilities must sum to 1.0, got {}", sum));
17        }
18        Ok(Self { support, probabilities })
19    }
20
21    pub fn normalized(support: Vec<f64>, mut probabilities: Vec<f64>) -> Self {
22        let sum: f64 = probabilities.iter().sum();
23        if sum > 0.0 { for p in &mut probabilities { *p /= sum; } }
24        Self { support, probabilities }
25    }
26
27    pub fn mean(&self) -> f64 {
28        self.support.iter().zip(self.probabilities.iter()).map(|(x, p)| x * p).sum()
29    }
30
31    pub fn variance(&self) -> f64 {
32        let mu = self.mean();
33        self.support.iter().zip(self.probabilities.iter()).map(|(x, p)| p * (x - mu).powi(2)).sum()
34    }
35
36    pub fn entropy(&self) -> f64 {
37        -self.probabilities.iter().filter(|&&p| p > 1e-15).map(|&p| p * p.ln()).sum::<f64>()
38    }
39}
40
41pub fn cost_matrix(a: &DiscreteDistribution, b: &DiscreteDistribution, p: f64) -> DMatrix<f64> {
42    let n = a.support.len();
43    let m = b.support.len();
44    let mut c = DMatrix::zeros(n, m);
45    for i in 0..n {
46        for j in 0..m {
47            c[(i, j)] = (a.support[i] - b.support[j]).abs().powf(p);
48        }
49    }
50    c
51}
52
53pub fn wasserstein_1d(a: &DiscreteDistribution, b: &DiscreteDistribution) -> f64 {
54    let mut a_pairs: Vec<_> = a.support.iter().zip(a.probabilities.iter()).collect();
55    let mut b_pairs: Vec<_> = b.support.iter().zip(b.probabilities.iter()).collect();
56    a_pairs.sort_by(|x, y| x.0.partial_cmp(y.0).unwrap());
57    b_pairs.sort_by(|x, y| x.0.partial_cmp(y.0).unwrap());
58
59    let mut all_points: Vec<f64> = a.support.iter().chain(b.support.iter()).copied().collect();
60    all_points.sort_by(|a, b| a.partial_cmp(b).unwrap());
61    all_points.dedup();
62
63    let mut w1 = 0.0;
64    for window in all_points.windows(2) {
65        let dx = window[1] - window[0];
66        let mid = (window[0] + window[1]) / 2.0;
67        let fa = cdf_at(&a_pairs, mid);
68        let fb = cdf_at(&b_pairs, mid);
69        w1 += (fa - fb).abs() * dx;
70    }
71    w1
72}
73
/// Cumulative mass at `x`: the total mass of every `(location, mass)` atom
/// whose location is `≤ x`.
///
/// Every atom is examined, so the result does not actually depend on the
/// input being sorted.
fn cdf_at(sorted_pairs: &[(&f64, &f64)], x: f64) -> f64 {
    sorted_pairs
        .iter()
        .filter_map(|&(&loc, &mass)| if loc <= x { Some(mass) } else { None })
        .sum()
}
77
78pub fn sinkhorn(a: &DVector<f64>, b: &DVector<f64>, cost: &DMatrix<f64>, reg: f64, max_iter: usize, tol: f64) -> (DMatrix<f64>, DVector<f64>, DVector<f64>) {
79    let n = a.nrows();
80    let m = b.nrows();
81    let mut u = DVector::from_element(n, 1.0 / n as f64);
82    let mut v = DVector::from_element(m, 1.0 / m as f64);
83    let k = cost.map(|c| (-c / reg).exp());
84
85    for _ in 0..max_iter {
86        let kv = &k * &v;
87        for i in 0..n { if kv[i].abs() > 1e-15 { u[i] = a[i] / kv[i]; } }
88        let ktu = &k.transpose() * &u;
89        for j in 0..m { if ktu[j].abs() > 1e-15 { v[j] = b[j] / ktu[j]; } }
90        let ktu_new = &k.transpose() * &u;
91        let err: f64 = (&ktu_new - b).iter().map(|x| x.abs()).sum();
92        if err < tol { break; }
93    }
94
95    let mut plan = DMatrix::zeros(n, m);
96    for i in 0..n { for j in 0..m { plan[(i, j)] = u[i] * k[(i, j)] * v[j]; } }
97    (plan, u, v)
98}
99
100pub fn kl_divergence(p: &DVector<f64>, q: &DVector<f64>) -> f64 {
101    p.iter().zip(q.iter())
102        .filter(|(&pi, _)| pi > 1e-15)
103        .map(|(&pi, &qi)| pi * (pi / qi.max(1e-15)).ln())
104        .sum()
105}
106
107pub fn js_divergence(p: &DVector<f64>, q: &DVector<f64>) -> f64 {
108    let m = (p + q).scale(0.5);
109    0.5 * kl_divergence(p, &m) + 0.5 * kl_divergence(q, &m)
110}
111
112pub fn wasserstein_gradient_step(positions: &mut [f64], target: &DiscreteDistribution, step_size: f64) {
113    let n = positions.len();
114    positions.sort_by(|a, b| a.partial_cmp(b).unwrap());
115    for i in 0..n {
116        let quantile = (i as f64 + 0.5) / n as f64;
117        let target_val = quantile_function(target, quantile);
118        positions[i] += step_size * (target_val - positions[i]);
119    }
120}
121
122pub fn quantile_function(dist: &DiscreteDistribution, q: f64) -> f64 {
123    let mut pairs: Vec<_> = dist.support.iter().zip(dist.probabilities.iter()).collect();
124    pairs.sort_by(|a, b| a.0.partial_cmp(b.0).unwrap());
125    let mut cumsum = 0.0;
126    for (&x, &p) in &pairs {
127        cumsum += p;
128        if cumsum >= q { return x; }
129    }
130    *pairs.last().unwrap().0
131}
132
#[cfg(test)]
mod tests {
    //! Unit tests pinning exact values where the math gives them, rather than
    //! only signs or shapes.
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_discrete_distribution_normalized() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 3.0]);
        let sum: f64 = d.probabilities.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
        assert_relative_eq!(d.probabilities[2], 0.5, epsilon = 1e-12);
    }

    #[test]
    fn test_discrete_distribution_new_validates_sum() {
        assert!(DiscreteDistribution::new(vec![0.0, 1.0], vec![0.5, 0.5]).is_ok());
        assert!(DiscreteDistribution::new(vec![0.0, 1.0], vec![0.2, 0.2]).is_err());
    }

    #[test]
    fn test_discrete_distribution_mean() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        assert_relative_eq!(d.mean(), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_discrete_distribution_variance() {
        let d = DiscreteDistribution::normalized(vec![0.0, 2.0], vec![1.0, 1.0]);
        assert_relative_eq!(d.variance(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_entropy() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        assert_relative_eq!(d.entropy(), 2.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_cost_matrix() {
        let a = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        let b = DiscreteDistribution::normalized(vec![0.0, 2.0], vec![1.0, 1.0]);
        let c = cost_matrix(&a, &b, 2.0);
        assert_eq!((c.nrows(), c.ncols()), (2, 2));
        assert_relative_eq!(c[(0, 0)], 0.0);
        assert_relative_eq!(c[(0, 1)], 4.0);
        assert_relative_eq!(c[(1, 0)], 1.0);
        assert_relative_eq!(c[(1, 1)], 1.0);
    }

    #[test]
    fn test_wasserstein_1d_positive() {
        let a = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        let b = DiscreteDistribution::normalized(vec![1.0, 2.0], vec![1.0, 1.0]);
        let w1 = wasserstein_1d(&a, &b);
        assert!(w1 > 0.0);
        // Shifting a distribution by 1 moves every unit of mass a distance of 1.
        assert_relative_eq!(w1, 1.0, epsilon = 1e-12);
        assert_relative_eq!(wasserstein_1d(&b, &a), w1, epsilon = 1e-12);
    }

    #[test]
    fn test_wasserstein_same() {
        let a = DiscreteDistribution::normalized(vec![0.0, 1.0], vec![1.0, 1.0]);
        let w1 = wasserstein_1d(&a, &a);
        assert_relative_eq!(w1, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_sinkhorn() {
        let a = DVector::from_vec(vec![0.5, 0.5]);
        let b = DVector::from_vec(vec![0.5, 0.5]);
        let cost = DMatrix::from_row_slice(2, 2, &[0.0, 1.0, 1.0, 0.0]);
        let (plan, _, _) = sinkhorn(&a, &b, &cost, 0.1, 100, 1e-9);
        assert_eq!((plan.nrows(), plan.ncols()), (2, 2));
        for i in 0..2 {
            let row: f64 = (0..2).map(|j| plan[(i, j)]).sum();
            assert_relative_eq!(row, a[i], epsilon = 1e-6);
        }
        for j in 0..2 {
            let col: f64 = (0..2).map(|i| plan[(i, j)]).sum();
            assert_relative_eq!(col, b[j], epsilon = 1e-6);
        }
        // With reg = 0.1, the unit-cost off-diagonal cells carry ~exp(-10) of the mass.
        assert!(plan[(0, 1)] < 1e-4 && plan[(1, 0)] < 1e-4);
    }

    #[test]
    fn test_kl_divergence_same() {
        let p = DVector::from_vec(vec![0.5, 0.5]);
        let kl = kl_divergence(&p, &p);
        assert_relative_eq!(kl, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kl_divergence_positive() {
        let p = DVector::from_vec(vec![1.0, 0.0]);
        let q = DVector::from_vec(vec![0.5, 0.5]);
        let kl = kl_divergence(&p, &q);
        assert!(kl >= 0.0);
        assert_relative_eq!(kl, 2.0_f64.ln(), epsilon = 1e-12);
    }

    #[test]
    fn test_js_symmetry() {
        let p = DVector::from_vec(vec![0.8, 0.2]);
        let q = DVector::from_vec(vec![0.3, 0.7]);
        let js_pq = js_divergence(&p, &q);
        let js_qp = js_divergence(&q, &p);
        assert_relative_eq!(js_pq, js_qp, epsilon = 1e-10);
        assert!(js_pq > 0.0 && js_pq <= 2.0_f64.ln());
    }

    #[test]
    fn test_js_disjoint_supports() {
        let p = DVector::from_vec(vec![1.0, 0.0]);
        let q = DVector::from_vec(vec![0.0, 1.0]);
        assert_relative_eq!(js_divergence(&p, &q), 2.0_f64.ln(), epsilon = 1e-12);
    }

    #[test]
    fn test_quantile_function() {
        let d = DiscreteDistribution::normalized(vec![0.0, 1.0, 2.0], vec![1.0, 1.0, 1.0]);
        assert_eq!(quantile_function(&d, 0.25), 0.0);
        assert_eq!(quantile_function(&d, 0.5), 1.0);
        assert_eq!(quantile_function(&d, 0.9), 2.0);
    }

    #[test]
    fn test_wasserstein_gradient_step() {
        let mut positions = vec![0.0, 0.1, 0.2];
        let target = DiscreteDistribution::normalized(vec![5.0, 6.0, 7.0], vec![1.0, 1.0, 1.0]);
        wasserstein_gradient_step(&mut positions, &target, 0.5);
        assert!(positions[0] > 0.0);
        // Quantile levels 1/6, 1/2, 5/6 map to targets 5, 6, 7; each particle moves halfway.
        assert_relative_eq!(positions[0], 2.5, epsilon = 1e-12);
        assert_relative_eq!(positions[1], 3.05, epsilon = 1e-12);
        assert_relative_eq!(positions[2], 3.6, epsilon = 1e-12);
    }
}