use crate::{Error, Result};
/// Rescales `values` to zero mean and (approximately) unit population variance.
///
/// Returns an empty vector for empty input. The population standard deviation
/// (divisor `n`, not `n - 1`) is used, and `1e-12` is added to it so the
/// division stays well-defined even when the spread of the input is zero.
fn standardize(values: &[f64]) -> Vec<f64> {
    if values.is_empty() {
        return Vec::new();
    }

    let count = values.len() as f64;
    let mean = values.iter().sum::<f64>() / count;

    // Accumulate squared deviations left-to-right, starting from zero.
    let sum_sq = values
        .iter()
        .fold(0.0_f64, |acc, &v| acc + (v - mean) * (v - mean));
    let std_dev = (sum_sq / count).sqrt();

    // Guard against a vanishing denominator.
    let denom = (std_dev + 1e-12).max(1e-12);

    values.iter().map(|&v| (v - mean) / denom).collect()
}
/// Builds `n` evenly spaced target positions covering `[-1, 1]`.
///
/// `n == 0` yields an empty vector and `n == 1` yields the single midpoint
/// `0.0`; otherwise the first entry is `-1.0` and the last is `1.0`.
fn target_grid(n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![0.0],
        _ => {
            // Number of intervals between the `n` grid points.
            let last = (n - 1) as f64;
            (0..n).map(|j| -1.0 + 2.0 * (j as f64) / last).collect()
        }
    }
}
/// Tuning parameters for the Sinkhorn iterations run by [`sinkhorn_permutation`].
#[derive(Debug, Clone)]
pub struct SinkhornConfig {
    /// Regularisation strength: the cost matrix is divided by this value before
    /// it enters the kernel. `sinkhorn_permutation` rejects zero or negative values.
    pub epsilon: f64,
    /// Upper bound on the number of row/column normalisation sweeps.
    pub max_iter: usize,
    /// Sweeps stop early once every row sum is within `tol` of `1.0`.
    pub tol: f64,
}

impl Default for SinkhornConfig {
    /// Returns `epsilon = 0.1`, `max_iter = 100`, `tol = 1e-6`.
    fn default() -> Self {
        let epsilon = 0.1;
        let max_iter = 100;
        let tol = 1e-6;
        SinkhornConfig {
            epsilon,
            max_iter,
            tol,
        }
    }
}
/// Computes a soft permutation (coupling) matrix for `values` using
/// log-domain Sinkhorn iterations.
///
/// The inputs are standardised, compared against an evenly spaced target grid
/// on `[-1, 1]` with a squared-difference cost, and the kernel
/// `exp(-(cost - min_cost) / epsilon)` is then alternately row- and
/// column-normalised. Sweeps stop after `config.max_iter` iterations or as soon
/// as every row sum is within `config.tol` of `1.0`.
///
/// The result has `n * n` entries in row-major order: entry `i * n + j` is the
/// weight linking input `i` to target position `j`. A single-element input
/// yields `[1.0]`.
///
/// # Errors
///
/// * [`Error::EmptyInput`] if `values` is empty.
/// * [`Error::InvalidTemperature`] if `config.epsilon` is NaN, zero or negative.
pub fn sinkhorn_permutation(values: &[f64], config: &SinkhornConfig) -> Result<Vec<f64>> {
    // Numerically stable `ln(sum(exp(x)))`. Returns `-inf` when the largest
    // entry is not finite, so no `inf - inf` is ever formed.
    fn log_sum_exp(xs: &[f64]) -> f64 {
        let mut m = f64::NEG_INFINITY;
        for &x in xs {
            if x > m {
                m = x;
            }
        }
        if !m.is_finite() {
            return f64::NEG_INFINITY;
        }
        let mut s = 0.0;
        for &x in xs {
            s += (x - m).exp();
        }
        m + s.ln()
    }

    let n = values.len();
    if n == 0 {
        return Err(Error::EmptyInput);
    }
    // NaN fails every ordered comparison, so `epsilon <= 0.0` alone would let a
    // NaN temperature through and silently produce an all-NaN matrix.
    if config.epsilon.is_nan() || config.epsilon <= 0.0 {
        return Err(Error::InvalidTemperature(config.epsilon));
    }
    if n == 1 {
        return Ok(vec![1.0]);
    }

    let z = standardize(values);
    let t = target_grid(n);

    // Row-major squared distance between standardised input `i` and target `j`.
    let mut cost = Vec::with_capacity(n * n);
    for &zi in &z {
        for &tj in &t {
            let diff = zi - tj;
            cost.push(diff * diff);
        }
    }

    // Shifting by the minimum cost keeps the largest log-kernel entry at zero.
    let inv_eps = 1.0 / config.epsilon;
    let min_cost = cost.iter().copied().fold(f64::INFINITY, f64::min);
    let log_k: Vec<f64> = cost
        .iter()
        .map(|&c| -((c - min_cost) * inv_eps))
        .collect();

    let mut log_u = vec![0.0; n];
    let mut log_v = vec![0.0; n];
    let mut scratch = vec![0.0; n];

    for _ in 0..config.max_iter {
        // Row scaling: depends only on the previous column scaling, so it can
        // be written in place without a temporary vector.
        for (i, lu) in log_u.iter_mut().enumerate() {
            for j in 0..n {
                scratch[j] = log_k[i * n + j] + log_v[j];
            }
            *lu = -log_sum_exp(&scratch);
        }

        // Column scaling: depends only on the freshly updated row scaling.
        for (j, lv) in log_v.iter_mut().enumerate() {
            for i in 0..n {
                scratch[i] = log_k[i * n + j] + log_u[i];
            }
            *lv = -log_sum_exp(&scratch);
        }

        // Columns were normalised last, so the residual is measured on rows.
        let mut max_err: f64 = 0.0;
        for i in 0..n {
            for j in 0..n {
                scratch[j] = log_u[i] + log_k[i * n + j] + log_v[j];
            }
            let row_sum = log_sum_exp(&scratch).exp();
            max_err = max_err.max((row_sum - 1.0).abs());
        }
        if max_err < config.tol {
            break;
        }
    }

    let p: Vec<f64> = (0..n * n)
        .map(|idx| (log_u[idx / n] + log_k[idx] + log_v[idx % n]).exp())
        .collect();
    Ok(p)
}
/// Computes soft (fractional) 1-based ranks of `values`.
///
/// Each rank is the expectation of the position index `1..=n` under the
/// corresponding row of the coupling returned by [`sinkhorn_permutation`].
/// Up to 2000 Sinkhorn sweeps are allowed when `epsilon < 0.05`, otherwise 200;
/// the tolerance is taken from `SinkhornConfig::default()`.
///
/// # Errors
///
/// Propagates the errors of [`sinkhorn_permutation`]: `Error::EmptyInput` for
/// empty input and `Error::InvalidTemperature` for an epsilon it rejects. The
/// epsilon check applies for every input length, including a single element.
pub fn sinkhorn_rank(values: &[f64], epsilon: f64) -> Result<Vec<f64>> {
    let config = SinkhornConfig {
        epsilon,
        max_iter: if epsilon < 0.05 { 2000 } else { 200 },
        ..Default::default()
    };

    // Delegate before any length-based shortcut so that validation does not
    // depend on how many values were passed. For a single element the coupling
    // is `[1.0]`, which yields rank `1.0` below.
    let p = sinkhorn_permutation(values, &config)?;

    // `values` is non-empty here (empty input returned `Err` above), so the
    // chunk size is never zero.
    let n = values.len();
    let ranks = p
        .chunks(n)
        .map(|row| {
            row.iter()
                .enumerate()
                .fold(0.0_f64, |acc, (j, &w)| acc + w * (j as f64 + 1.0))
        })
        .collect();
    Ok(ranks)
}
/// Computes a soft ascending sort of `values`.
///
/// Output position `j` is the weighted sum of the inputs using column `j` of
/// the coupling returned by [`sinkhorn_permutation`]. Up to 2000 Sinkhorn
/// sweeps are allowed when `epsilon < 0.05`, otherwise 200; the tolerance is
/// taken from `SinkhornConfig::default()`.
///
/// # Errors
///
/// Propagates the errors of [`sinkhorn_permutation`]: `Error::EmptyInput` for
/// empty input and `Error::InvalidTemperature` for an epsilon it rejects. The
/// epsilon check applies for every input length, including a single element.
pub fn sinkhorn_sort(values: &[f64], epsilon: f64) -> Result<Vec<f64>> {
    let config = SinkhornConfig {
        epsilon,
        max_iter: if epsilon < 0.05 { 2000 } else { 200 },
        ..Default::default()
    };

    // Delegate before any length-based shortcut so that validation does not
    // depend on how many values were passed.
    let p = sinkhorn_permutation(values, &config)?;

    let n = values.len();
    if n == 1 {
        // A single value is already sorted; hand it back verbatim.
        return Ok(values.to_vec());
    }

    let sorted = (0..n)
        .map(|j| {
            values
                .iter()
                .enumerate()
                .fold(0.0_f64, |acc, (i, &v)| acc + p[i * n + j] * v)
        })
        .collect();
    Ok(sorted)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding separable row/column offsets to the log-kernel must leave the
    /// coupling produced by the Sinkhorn sweeps unchanged.
    #[test]
    fn test_sinkhorn_coupling_invariant_to_row_col_cost_shifts() {
        // Stable `ln(sum(exp(x)))`, local to this test.
        fn log_sum_exp(xs: &[f64]) -> f64 {
            let peak = xs
                .iter()
                .fold(f64::NEG_INFINITY, |best, &x| if x > best { x } else { best });
            if !peak.is_finite() {
                return f64::NEG_INFINITY;
            }
            let total = xs.iter().fold(0.0_f64, |acc, &x| acc + (x - peak).exp());
            peak + total.ln()
        }

        // Runs the log-domain Sinkhorn sweeps on an explicit `n x n` log-kernel
        // and returns the resulting row-major coupling.
        fn sinkhorn_from_log_k(log_k: &[f64], n: usize, config: &SinkhornConfig) -> Vec<f64> {
            let mut log_u = vec![0.0; n];
            let mut log_v = vec![0.0; n];
            let mut scratch = vec![0.0; n];

            for _ in 0..config.max_iter {
                let next_u: Vec<f64> = (0..n)
                    .map(|i| {
                        for j in 0..n {
                            scratch[j] = log_k[i * n + j] + log_v[j];
                        }
                        -log_sum_exp(&scratch)
                    })
                    .collect();

                let next_v: Vec<f64> = (0..n)
                    .map(|j| {
                        for i in 0..n {
                            scratch[i] = log_k[i * n + j] + next_u[i];
                        }
                        -log_sum_exp(&scratch)
                    })
                    .collect();

                let mut max_err = 0.0f64;
                for i in 0..n {
                    for j in 0..n {
                        scratch[j] = next_u[i] + log_k[i * n + j] + next_v[j];
                    }
                    let row_sum = log_sum_exp(&scratch).exp();
                    max_err = max_err.max((row_sum - 1.0).abs());
                }

                log_u = next_u;
                log_v = next_v;
                if max_err < config.tol {
                    break;
                }
            }

            (0..n * n)
                .map(|idx| (log_u[idx / n] + log_k[idx] + log_v[idx % n]).exp())
                .collect()
        }

        let values = vec![3.0, 1.0, 2.0, 4.0, 0.5];
        let n = values.len();
        let config = SinkhornConfig {
            epsilon: 0.2,
            max_iter: 600,
            tol: 1e-10,
        };

        // Build the log-kernel from the standardised values and the target grid.
        let z = standardize(&values);
        let t = target_grid(n);
        let mut cost = Vec::with_capacity(n * n);
        for &zi in &z {
            for &tj in &t {
                let diff = zi - tj;
                cost.push(diff * diff);
            }
        }
        let inv_eps = 1.0 / config.epsilon;
        let min_cost = cost.iter().copied().fold(f64::INFINITY, |a, b| a.min(b));
        let log_k: Vec<f64> = cost
            .iter()
            .map(|&c| -((c - min_cost) * inv_eps))
            .collect();

        // Perturb every entry by a per-row plus a per-column offset.
        let row = [0.3, -0.2, 0.1, 0.0, 0.15];
        let col = [-0.25, 0.05, 0.2, -0.1, 0.0];
        let log_k_shift: Vec<f64> = log_k
            .iter()
            .enumerate()
            .map(|(idx, &k)| k + (row[idx / n] + col[idx % n]))
            .collect();

        let p1 = sinkhorn_from_log_k(&log_k, n, &config);
        let p2 = sinkhorn_from_log_k(&log_k_shift, n, &config);

        let max_abs = p1
            .iter()
            .zip(p2.iter())
            .fold(0.0f64, |worst, (a, b)| worst.max((a - b).abs()));
        assert!(
            max_abs < 1e-7,
            "expected coupling invariant to separable shifts: max_abs={max_abs}"
        );
    }

    /// Soft ranks of three distinct values land near their hard ranks.
    #[test]
    fn test_sinkhorn_rank_basic() {
        let values = vec![3.0, 1.0, 2.0];
        let ranks = sinkhorn_rank(&values, 0.1).unwrap();
        let (largest, smallest, middle) = (ranks[0], ranks[1], ranks[2]);
        assert!(
            largest > 2.5,
            "largest should have rank ~3, got {}",
            largest
        );
        assert!(
            smallest < 1.5,
            "smallest should have rank ~1, got {}",
            smallest
        );
        assert!(
            middle > 1.5 && middle < 2.5,
            "middle should have rank ~2, got {}",
            middle
        );
    }

    /// The soft-sorted output is ascending up to a slack of 0.5.
    #[test]
    fn test_sinkhorn_sort_basic() {
        let values = vec![3.0, 1.0, 2.0];
        let sorted = sinkhorn_sort(&values, 0.1).unwrap();
        assert!(sorted[0] < sorted[1] + 0.5);
        assert!(sorted[1] < sorted[2] + 0.5);
    }

    /// Reordering the input must not change the soft-sorted output.
    #[test]
    fn test_sinkhorn_sort_permutation_invariant() {
        let a = vec![5.0, 1.0, 2.0, 4.0, 3.0];
        let b = vec![1.0, 3.0, 5.0, 2.0, 4.0];
        let sa = sinkhorn_sort(&a, 0.2).unwrap();
        let sb = sinkhorn_sort(&b, 0.2).unwrap();
        assert_eq!(sa.len(), sb.len());
        for (i, (x, y)) in sa.iter().zip(sb.iter()).enumerate() {
            assert!((x - y).abs() < 1e-6, "i={} sa={} sb={}", i, x, y);
        }
    }

    /// The soft-sorted output has the same total as the input.
    #[test]
    fn test_sinkhorn_sort_preserves_sum() {
        let values = vec![3.0, 1.0, 2.0, 4.0];
        let sorted = sinkhorn_sort(&values, 0.2).unwrap();
        let s_in = values.iter().sum::<f64>();
        let s_out = sorted.iter().sum::<f64>();
        assert!((s_in - s_out).abs() < 1e-6, "in={} out={}", s_in, s_out);
    }

    /// Every soft rank lies within `[1, n]` (up to a small tolerance).
    #[test]
    fn test_sinkhorn_rank_range() {
        let values = vec![3.0, 1.0, 2.0, 4.0];
        let ranks = sinkhorn_rank(&values, 0.2).unwrap();
        for &r in &ranks {
            assert!(r >= 1.0 - 1e-6);
            assert!(r <= values.len() as f64 + 1e-6);
        }
    }

    /// Rows and columns of the coupling each sum to roughly one.
    #[test]
    fn test_doubly_stochastic() {
        let values = vec![3.0, 1.0, 2.0, 4.0];
        let config = SinkhornConfig::default();
        let p = sinkhorn_permutation(&values, &config).unwrap();
        let n = values.len();

        // Each row is one contiguous chunk of the row-major matrix.
        for (i, row) in p.chunks(n).enumerate() {
            let row_sum: f64 = row.iter().sum();
            assert!(
                (row_sum - 1.0).abs() < 0.01,
                "row {} sum = {}, expected 1.0",
                i,
                row_sum
            );
        }

        // Each column is reached by striding through the matrix with step `n`.
        for j in 0..n {
            let col_sum: f64 = p.iter().skip(j).step_by(n).sum();
            assert!(
                (col_sum - 1.0).abs() < 0.01,
                "col {} sum = {}, expected 1.0",
                j,
                col_sum
            );
        }
    }

    /// A smaller epsilon gives ranks closer to integers than a larger one.
    #[test]
    fn test_epsilon_effect() {
        // Mean squared distance of each rank from its nearest integer.
        let mean_sq_frac = |ranks: &[f64]| -> f64 {
            ranks
                .iter()
                .map(|&r| (r - r.round()).powi(2))
                .sum::<f64>()
                / ranks.len() as f64
        };

        let values = vec![3.0, 1.0, 2.0];
        let ranks_sharp = sinkhorn_rank(&values, 0.01).unwrap();
        let ranks_smooth = sinkhorn_rank(&values, 1.0).unwrap();
        let sharp_var = mean_sq_frac(&ranks_sharp);
        let smooth_var = mean_sq_frac(&ranks_smooth);
        assert!(
            sharp_var < smooth_var,
            "sharp_var={} should be < smooth_var={}",
            sharp_var,
            smooth_var
        );
    }

    /// Empty input is rejected by both public entry points.
    #[test]
    fn test_empty_input() {
        assert!(sinkhorn_rank(&[], 0.1).is_err());
        assert!(sinkhorn_sort(&[], 0.1).is_err());
    }

    /// A single element has rank one and sorts to itself.
    #[test]
    fn test_single_element() {
        let ranks = sinkhorn_rank(&[42.0], 0.1).unwrap();
        assert_eq!(ranks, vec![1.0]);
        let sorted = sinkhorn_sort(&[42.0], 0.1).unwrap();
        assert_eq!(sorted, vec![42.0]);
    }
}