use crate::error::{MathError, Result};
use crate::utils::{log_sum_exp, EPS, LOG_MIN};
/// Output of [`SinkhornSolver::solve`]: the entropic transport plan plus solver diagnostics.
#[derive(Debug, Clone)]
pub struct TransportPlan {
/// Coupling matrix: `plan[i][j]` is the mass sent from source point `i` to target point `j`.
/// The source and target weights are rescaled to unit total mass before solving.
pub plan: Vec<Vec<f64>>,
/// Transport cost `sum_ij plan[i][j] * cost_matrix[i][j]` (the entropy term is not included).
pub cost: f64,
/// Number of Sinkhorn sweeps (one u-update followed by one v-update) that were performed.
pub iterations: usize,
/// Summed absolute deviation of the plan's row and column sums from the normalized source
/// and target weights, as measured at the solver's most recent marginal check.
pub marginal_error: f64,
/// `true` if `solve` stopped because its convergence criterion was met, `false` if it ran
/// out of iterations first.
pub converged: bool,
}
/// Log-domain Sinkhorn solver for entropically regularized optimal transport.
#[derive(Debug, Clone)]
pub struct SinkhornSolver {
/// Entropic regularization strength; the log-kernel entries are `-cost / regularization`.
regularization: f64,
/// Upper bound on the number of Sinkhorn sweeps performed by `solve`.
max_iterations: usize,
/// Convergence tolerance: `solve` stops once the largest change of the log scaling vectors
/// is below it and the marginal error is below ten times it.
threshold: f64,
}
impl SinkhornSolver {
    /// Maximum number of fixed-point rounds performed by [`SinkhornSolver::barycenter`].
    const BARYCENTER_ITERATIONS: usize = 20;
    /// Damping factor applied to every barycenter displacement step.
    const BARYCENTER_STEP: f64 = 0.5;

    /// Creates a solver.
    ///
    /// `regularization` is clamped to at least `1e-6` and `max_iterations` to at least one
    /// sweep. The convergence threshold defaults to `1e-6`; see
    /// [`SinkhornSolver::with_threshold`].
    pub fn new(regularization: f64, max_iterations: usize) -> Self {
        Self {
            regularization: regularization.max(1e-6),
            max_iterations: max_iterations.max(1),
            threshold: 1e-6,
        }
    }

    /// Returns the solver with its convergence threshold replaced (clamped to at least `1e-12`).
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.threshold = threshold.max(1e-12);
        self
    }

    /// Pairwise squared Euclidean distances: entry `[i][j]` is `|source[i] - target[j]|^2`.
    ///
    /// Only the first `source[i].len()` coordinates of each target point are used.
    ///
    /// # Panics
    /// Panics if a target point has fewer coordinates than a source point.
    #[inline]
    pub fn compute_cost_matrix(source: &[Vec<f64>], target: &[Vec<f64>]) -> Vec<Vec<f64>> {
        source
            .iter()
            .map(|s| {
                target
                    .iter()
                    .map(|t| Self::squared_euclidean(s, t))
                    .collect()
            })
            .collect()
    }

    /// Squared Euclidean distance over the first `a.len()` coordinates.
    ///
    /// The sum uses four independent accumulators so the compiler can vectorize it. The tail
    /// (fewer than four leftover coordinates) is folded into the first accumulator.
    ///
    /// # Panics
    /// Panics if `b` is shorter than `a`.
    #[inline]
    fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
        // One up-front slice check replaces the per-element bounds checks.
        let b = &b[..a.len()];
        let mut lanes = [0.0_f64; 4];
        let a_chunks = a.chunks_exact(4);
        let b_chunks = b.chunks_exact(4);
        let a_tail = a_chunks.remainder();
        let b_tail = b_chunks.remainder();
        for (ca, cb) in a_chunks.zip(b_chunks) {
            for ((acc, &x), &y) in lanes.iter_mut().zip(ca).zip(cb) {
                let d = x - y;
                *acc += d * d;
            }
        }
        for (&x, &y) in a_tail.iter().zip(b_tail) {
            let d = x - y;
            lanes[0] += d * d;
        }
        lanes[0] + lanes[1] + lanes[2] + lanes[3]
    }

    /// Rescales `weights` to unit total mass.
    ///
    /// Returns `None` when the total is not a positive finite number, because dividing by it
    /// would fill the result with NaN or infinity.
    fn normalize(weights: &[f64]) -> Option<Vec<f64>> {
        let total: f64 = weights.iter().sum();
        (total.is_finite() && total > 0.0).then(|| weights.iter().map(|&w| w / total).collect())
    }

    /// Solves the entropically regularized transport problem between two discrete measures
    /// with log-domain Sinkhorn iterations.
    ///
    /// Both weight vectors are rescaled to unit mass first, so unnormalized weights are fine.
    /// The solver runs until both of the following hold, or until `max_iterations` sweeps
    /// have been performed:
    /// - the largest change of the log scaling vectors in one sweep is below the threshold;
    /// - the marginal error is below ten times the threshold.
    ///
    /// # Errors
    /// - `MathError::empty_input` if either weight vector is empty or its total mass is not a
    ///   positive finite number.
    /// - `MathError::dimension_mismatch` if `cost_matrix` does not have
    ///   `source_weights.len()` rows of `target_weights.len()` entries each.
    pub fn solve(
        &self,
        cost_matrix: &[Vec<f64>],
        source_weights: &[f64],
        target_weights: &[f64],
    ) -> Result<TransportPlan> {
        let n = source_weights.len();
        let m = target_weights.len();
        if n == 0 || m == 0 {
            return Err(MathError::empty_input("weights"));
        }
        if cost_matrix.len() != n {
            return Err(MathError::dimension_mismatch(n, cost_matrix.len()));
        }
        if let Some(row) = cost_matrix.iter().find(|row| row.len() != m) {
            // The row count already matched, so report the offending row width instead.
            return Err(MathError::dimension_mismatch(m, row.len()));
        }
        let a = Self::normalize(source_weights)
            .ok_or_else(|| MathError::empty_input("source_weights"))?;
        let b = Self::normalize(target_weights)
            .ok_or_else(|| MathError::empty_input("target_weights"))?;

        // Gibbs kernel in log space: log K_ij = -C_ij / regularization.
        let log_k: Vec<Vec<f64>> = cost_matrix
            .iter()
            .map(|row| row.iter().map(|&c| -c / self.regularization).collect())
            .collect();
        // A zero weight has ln = -inf; clamping keeps every update finite.
        let log_a: Vec<f64> = a.iter().map(|&ai| ai.ln().max(LOG_MIN)).collect();
        let log_b: Vec<f64> = b.iter().map(|&bi| bi.ln().max(LOG_MIN)).collect();

        let mut log_u = vec![0.0_f64; n];
        let mut log_v = vec![0.0_f64; m];
        // Scratch buffers reused by every sweep to avoid per-row allocations.
        let mut log_terms_row = vec![0.0_f64; m];
        let mut log_terms_col = vec![0.0_f64; n];

        let mut converged = false;
        let mut iterations = 0;
        let mut marginal_error = f64::INFINITY;
        // Sweep count at which `marginal_error` was last computed.
        let mut last_checked = 0;

        for iter in 0..self.max_iterations {
            iterations = iter + 1;

            // u-update: log u_i = log a_i - LSE_j(log v_j + log K_ij).
            let mut max_u_change: f64 = 0.0;
            for (i, log_k_row) in log_k.iter().enumerate() {
                for ((term, &lv), &lk) in log_terms_row.iter_mut().zip(&log_v).zip(log_k_row) {
                    *term = lv + lk;
                }
                let updated = log_a[i] - log_sum_exp(&log_terms_row);
                max_u_change = max_u_change.max((updated - log_u[i]).abs());
                log_u[i] = updated;
            }

            // v-update: log v_j = log b_j - LSE_i(log u_i + log K_ij).
            let mut max_v_change: f64 = 0.0;
            for j in 0..m {
                for ((term, &lu), log_k_row) in log_terms_col.iter_mut().zip(&log_u).zip(&log_k)
                {
                    *term = lu + log_k_row[j];
                }
                let updated = log_b[j] - log_sum_exp(&log_terms_col);
                max_v_change = max_v_change.max((updated - log_v[j]).abs());
                log_v[j] = updated;
            }

            // The O(n*m) marginal check only runs every 10 sweeps or once the steps are tiny.
            let step_small = max_u_change.max(max_v_change) < self.threshold;
            if iter % 10 == 0 || step_small {
                marginal_error = self.compute_marginal_error(&log_u, &log_v, &log_k, &a, &b);
                last_checked = iterations;
                if step_small && marginal_error < self.threshold * 10.0 {
                    converged = true;
                    break;
                }
            }
        }
        if last_checked != iterations {
            // The periodic check can be up to nine sweeps old; report the error of the plan
            // actually returned.
            marginal_error = self.compute_marginal_error(&log_u, &log_v, &log_k, &a, &b);
        }

        let plan: Vec<Vec<f64>> = log_k
            .iter()
            .zip(&log_u)
            .map(|(log_k_row, &lu)| {
                log_k_row
                    .iter()
                    .zip(&log_v)
                    // exp is never negative; `max(0.0)` additionally maps a NaN entry to 0.
                    .map(|(&lk, &lv)| (lu + lk + lv).exp().max(0.0))
                    .collect()
            })
            .collect();
        let cost = plan
            .iter()
            .zip(cost_matrix.iter())
            .map(|(gamma_row, cost_row)| {
                gamma_row
                    .iter()
                    .zip(cost_row.iter())
                    .map(|(&g, &c)| g * c)
                    .sum::<f64>()
            })
            .sum();

        Ok(TransportPlan {
            plan,
            cost,
            iterations,
            marginal_error,
            converged,
        })
    }

    /// L1 deviation of the plan `exp(log_u_i + log_K_ij + log_v_j)` from the target marginals.
    ///
    /// Row sums are compared with `a` and column sums with `b`; both sums are evaluated in
    /// log space via `log_sum_exp`.
    fn compute_marginal_error(
        &self,
        log_u: &[f64],
        log_v: &[f64],
        log_k: &[Vec<f64>],
        a: &[f64],
        b: &[f64],
    ) -> f64 {
        let n = log_u.len();
        let m = log_v.len();

        let mut row_terms = vec![0.0_f64; m];
        let mut row_error = 0.0;
        for i in 0..n {
            for (term, (&lk, &lv)) in row_terms.iter_mut().zip(log_k[i].iter().zip(log_v)) {
                *term = log_u[i] + lk + lv;
            }
            row_error += (log_sum_exp(&row_terms).exp() - a[i]).abs();
        }

        let mut col_terms = vec![0.0_f64; n];
        let mut col_error = 0.0;
        for j in 0..m {
            for ((term, &lu), log_k_row) in col_terms.iter_mut().zip(log_u).zip(log_k) {
                *term = lu + log_k_row[j] + log_v[j];
            }
            col_error += (log_sum_exp(&col_terms).exp() - b[j]).abs();
        }

        row_error + col_error
    }

    /// Entropic transport cost between two uniformly weighted point clouds, using squared
    /// Euclidean ground cost.
    ///
    /// # Errors
    /// Propagates the errors of [`SinkhornSolver::solve`]; in particular `empty_input` when
    /// either cloud is empty.
    ///
    /// # Panics
    /// Panics if a target point has fewer coordinates than a source point.
    pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
        let cost_matrix = Self::compute_cost_matrix(source, target);
        let source_weights = vec![1.0 / source.len() as f64; source.len()];
        let target_weights = vec![1.0 / target.len() as f64; target.len()];
        self.solve(&cost_matrix, &source_weights, &target_weights)
            .map(|result| result.cost)
    }

    /// Approximates a free-support Wasserstein barycenter of several point clouds.
    ///
    /// Setup:
    /// - The `support_size` barycenter points start evenly spaced on the diagonal
    ///   (`[t; dim]` with `t` from 0 to 1).
    /// - Each input cloud is uniformly weighted.
    /// - `weights` (normalized to sum to one, default uniform) sets each cloud's influence.
    ///
    /// Every round does the following:
    /// 1. Transport each cloud onto the current support.
    /// 2. Move every support point halfway toward the weighted average of its barycentric
    ///    projections.
    /// 3. Stop early once no coordinate moves by `EPS` or more.
    ///
    /// Returns an empty vector when `support_size` is zero.
    ///
    /// # Errors
    /// - `MathError::empty_input` if `distributions` is empty, any distribution is empty, or
    ///   `weights` has no positive finite total mass.
    /// - `MathError::dimension_mismatch` if `weights` does not have one entry per distribution.
    ///
    /// # Panics
    /// Panics if a point of some distribution does not have exactly `dim` coordinates.
    pub fn barycenter(
        &self,
        distributions: &[&[Vec<f64>]],
        weights: Option<&[f64]>,
        support_size: usize,
        dim: usize,
    ) -> Result<Vec<Vec<f64>>> {
        if distributions.is_empty() {
            return Err(MathError::empty_input("distributions"));
        }
        // An empty cloud cannot be transported. Skipping it would silently bias the result,
        // because the remaining weights would no longer sum to one.
        if distributions.iter().any(|d| d.is_empty()) {
            return Err(MathError::empty_input("distribution"));
        }
        let k = distributions.len();
        let barycenter_weights = match weights {
            Some(w) if w.len() != k => return Err(MathError::dimension_mismatch(k, w.len())),
            Some(w) => Self::normalize(w).ok_or_else(|| MathError::empty_input("weights"))?,
            None => vec![1.0 / k as f64; k],
        };
        if support_size == 0 {
            return Ok(Vec::new());
        }

        let spacing = (support_size - 1).max(1) as f64;
        let mut barycenter: Vec<Vec<f64>> = (0..support_size)
            .map(|i| vec![i as f64 / spacing; dim])
            .collect();
        let target_w = vec![1.0 / support_size as f64; support_size];
        // Each plan column sums to about 1/support_size. Rescaling by support_size turns a
        // column into convex weights over the cloud's points (barycentric projection).
        let column_scale = support_size as f64;

        for _ in 0..Self::BARYCENTER_ITERATIONS {
            let mut displacements = vec![vec![0.0; dim]; support_size];
            for (&distribution, &lambda) in distributions.iter().zip(&barycenter_weights) {
                let cost_matrix = Self::compute_cost_matrix(distribution, &barycenter);
                let n = distribution.len();
                let source_w = vec![1.0 / n as f64; n];
                let plan = self.solve(&cost_matrix, &source_w, &target_w)?.plan;
                for (j, (disp, anchor)) in displacements.iter_mut().zip(&barycenter).enumerate()
                {
                    for (point, plan_row) in distribution.iter().zip(&plan) {
                        let weight = plan_row[j] * column_scale;
                        for d in 0..dim {
                            disp[d] += lambda * weight * (point[d] - anchor[d]);
                        }
                    }
                }
            }

            let mut max_update: f64 = 0.0;
            for (point, disp) in barycenter.iter_mut().zip(&displacements) {
                for (coord, &shift) in point.iter_mut().zip(disp) {
                    let delta = shift * Self::BARYCENTER_STEP;
                    *coord += delta;
                    max_update = max_update.max(delta.abs());
                }
            }
            if max_update < EPS {
                break;
            }
        }
        Ok(barycenter)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical clouds should be (almost) free to transport onto each other.
    #[test]
    fn test_sinkhorn_identity() {
        let points = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let cost = SinkhornSolver::new(0.1, 100)
            .distance(&points, &points)
            .unwrap();
        assert!(cost < 0.1, "Identity should have near-zero cost: {}", cost);
    }

    /// Shifting a unit square by one along x should cost roughly one unit.
    #[test]
    fn test_sinkhorn_translation() {
        let solver = SinkhornSolver::new(0.05, 200);
        let square: Vec<Vec<f64>> = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
            .iter()
            .map(|&(x, y)| vec![x, y])
            .collect();
        let shifted: Vec<Vec<f64>> = square.iter().map(|p| vec![p[0] + 1.0, p[1]]).collect();
        let cost = solver.distance(&square, &shifted).unwrap();
        assert!(
            0.5 < cost && cost < 2.0,
            "Translation cost should be ~1.0: {}",
            cost
        );
    }

    /// A small symmetric problem must converge with a tiny marginal error.
    #[test]
    fn test_sinkhorn_convergence() {
        let solver = SinkhornSolver::new(0.1, 100).with_threshold(1e-6);
        // |i - j| for a 3x3 grid.
        let cost_matrix: Vec<Vec<f64>> = (0..3)
            .map(|i: i32| (0..3).map(|j: i32| f64::from((i - j).abs())).collect())
            .collect();
        let uniform = vec![1.0 / 3.0; 3];
        let outcome = solver.solve(&cost_matrix, &uniform, &uniform).unwrap();
        assert!(outcome.converged, "Should converge");
        assert!(
            outcome.marginal_error < 0.01,
            "Marginal error too high: {}",
            outcome.marginal_error
        );
    }

    /// Row and column sums of the plan should reproduce the input weights.
    #[test]
    fn test_transport_plan_marginals() {
        let solver = SinkhornSolver::new(0.1, 100);
        let cost_matrix = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let (a, b) = (vec![0.3, 0.7], vec![0.6, 0.4]);
        let outcome = solver.solve(&cost_matrix, &a, &b).unwrap();

        for (i, (row, &ai)) in outcome.plan.iter().zip(&a).enumerate() {
            let row_sum: f64 = row.iter().sum();
            assert!(
                (row_sum - ai).abs() < 0.05,
                "Row {} sum {} != {}",
                i,
                row_sum,
                ai
            );
        }

        for (j, &bj) in b.iter().enumerate() {
            let col_sum: f64 = outcome.plan.iter().map(|row| row[j]).sum();
            assert!(
                (col_sum - bj).abs() < 0.05,
                "Col {} sum {} != {}",
                j,
                col_sum,
                bj
            );
        }
    }
}