use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Finite stand-in for `log(0)`. `safe_log_simplex` writes it for non-positive
/// entries and the log-sum-exp helpers return it for empty or saturated reductions.
pub const LOG_ZERO_SENTINEL: f64 = -1.0e300;
/// Half of [`LOG_ZERO_SENTINEL`]. A reduction whose maximum is at or below this
/// value is treated as saturated (i.e. as "log of zero").
const LOG_ZERO_SATURATION_THRESHOLD: f64 = LOG_ZERO_SENTINEL * 0.5;
/// Smallest entropic regularisation `eps` accepted by `validate_inputs`.
pub const MIN_EPS: f64 = 1.0e-12;
/// Column-wise log-sum-exp: `out[j] = log(sum_i exp(log_x[i, j]))`.
///
/// A column yields [`LOG_ZERO_SENTINEL`] when the matrix has no rows, when the
/// column maximum is non-finite or at/below the saturation threshold, or when
/// the shifted exponential sum is not strictly positive.
fn logsumexp_axis0(log_x: ArrayView2<'_, f64>) -> Array1<f64> {
    let (n_rows, n_cols) = log_x.dim();
    if n_rows == 0 {
        return Array1::from_elem(n_cols, LOG_ZERO_SENTINEL);
    }
    Array1::from_shape_fn(n_cols, |j| {
        let column = log_x.column(j);
        // Explicit comparison (rather than `f64::max`) keeps NaN entries out of the maximum.
        let peak = column
            .iter()
            .fold(f64::NEG_INFINITY, |best, &v| if v > best { v } else { best });
        if peak <= LOG_ZERO_SATURATION_THRESHOLD || !peak.is_finite() {
            return LOG_ZERO_SENTINEL;
        }
        // Subtract the maximum before exponentiating so the largest term is exp(0) = 1.
        let shifted_sum = column
            .iter()
            .fold(0.0_f64, |sum, &v| sum + (v - peak).exp());
        if shifted_sum > 0.0 {
            peak + shifted_sum.ln()
        } else {
            LOG_ZERO_SENTINEL
        }
    })
}
/// Row-wise log-sum-exp: `out[i] = log(sum_j exp(log_x[i, j]))`.
///
/// A row yields [`LOG_ZERO_SENTINEL`] when the matrix has no columns, when the
/// row maximum is non-finite or at/below the saturation threshold, or when the
/// shifted exponential sum is not strictly positive.
fn logsumexp_axis1(log_x: ArrayView2<'_, f64>) -> Array1<f64> {
    let (n_rows, n_cols) = log_x.dim();
    if n_cols == 0 {
        return Array1::from_elem(n_rows, LOG_ZERO_SENTINEL);
    }
    Array1::from_shape_fn(n_rows, |i| {
        let row = log_x.row(i);
        // Explicit comparison (rather than `f64::max`) keeps NaN entries out of the maximum.
        let peak = row
            .iter()
            .fold(f64::NEG_INFINITY, |best, &v| if v > best { v } else { best });
        if peak <= LOG_ZERO_SATURATION_THRESHOLD || !peak.is_finite() {
            return LOG_ZERO_SENTINEL;
        }
        // Subtract the maximum before exponentiating so the largest term is exp(0) = 1.
        let shifted_sum = row
            .iter()
            .fold(0.0_f64, |sum, &v| sum + (v - peak).exp());
        if shifted_sum > 0.0 {
            peak + shifted_sum.ln()
        } else {
            LOG_ZERO_SENTINEL
        }
    })
}
/// Returns `true` when the largest entry of `log_x` is non-finite or at/below
/// the saturation threshold, i.e. every entry is effectively `log(0)`.
///
/// An empty vector counts as saturated because its maximum stays at `-inf`.
fn log_vector_is_sentinel_saturated(log_x: ArrayView1<'_, f64>) -> bool {
    let peak = log_x
        .iter()
        .fold(f64::NEG_INFINITY, |best, &v| if v > best { v } else { best });
    peak <= LOG_ZERO_SATURATION_THRESHOLD || !peak.is_finite()
}
/// Numerically stable softmax of a log-domain vector.
///
/// Returns an empty array for empty input. Otherwise the maximum is subtracted
/// before exponentiating and the result is normalised to sum to one.
///
/// # Errors
///
/// * any entry is NaN (previously this produced `Ok` with an all-NaN vector,
///   because `NaN <= 0.0` is false and the mass check was skipped);
/// * the maximum is non-finite or at/below the saturation threshold;
/// * the exponentiated mass is not strictly positive.
fn softmax_1d(log_x: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
    if log_x.len() == 0 {
        return Ok(Array1::zeros(0));
    }
    // NaN never wins the `v > max` comparison below, so it must be rejected up front.
    if log_x.iter().any(|v| v.is_nan()) {
        return Err(
            "sinkhorn barycenter degenerated: log_a contains NaN -- try larger eps or check cost matrix"
                .to_string(),
        );
    }
    let max = log_x
        .iter()
        .fold(f64::NEG_INFINITY, |best, &v| if v > best { v } else { best });
    if !max.is_finite() || max <= LOG_ZERO_SATURATION_THRESHOLD {
        return Err(
            "sinkhorn barycenter degenerated: all log_a saturated to sentinel -- try larger eps or check cost matrix"
                .to_string(),
        );
    }
    let mut out = Array1::<f64>::zeros(log_x.len());
    let mut total = 0.0_f64;
    for (slot, &v) in out.iter_mut().zip(log_x.iter()) {
        let e = (v - max).exp();
        *slot = e;
        total += e;
    }
    // `!(total > 0.0)` also catches a NaN total, unlike `total <= 0.0`.
    if !(total > 0.0) {
        return Err(
            "sinkhorn barycenter degenerated: softmax mass underflowed -- try larger eps or check cost matrix"
                .to_string(),
        );
    }
    for slot in out.iter_mut() {
        *slot /= total;
    }
    Ok(out)
}
/// Element-wise natural logarithm that maps non-positive entries to
/// [`LOG_ZERO_SENTINEL`] instead of `-inf`/NaN.
fn safe_log_simplex(row: ArrayView1<'_, f64>) -> Array1<f64> {
    row.mapv(|mass| {
        if mass <= 0.0 {
            LOG_ZERO_SENTINEL
        } else {
            mass.ln()
        }
    })
}
/// Checks shapes and value ranges of the barycenter inputs.
///
/// # Errors
///
/// Returns a descriptive message when:
/// * `atoms` has zero rows or zero columns;
/// * `weights.len()` differs from the number of atom rows;
/// * `cost` is not square with side equal to the number of atom columns;
/// * `eps` is non-finite or below [`MIN_EPS`];
/// * `n_iter` is zero;
/// * any atom, weight or cost entry is non-finite or negative;
/// * the weights sum to zero, or their sum overflows to a non-finite value
///   (which would otherwise normalise every weight to zero).
fn validate_inputs(
    atoms: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
    cost: ArrayView2<'_, f64>,
    eps: f64,
    n_iter: usize,
) -> Result<(), String> {
    let (k, m) = atoms.dim();
    if k == 0 || m == 0 {
        return Err("atoms must have at least one row and one column".to_string());
    }
    if weights.len() != k {
        return Err(format!(
            "weights length {} does not match atoms row count {}",
            weights.len(),
            k
        ));
    }
    let (cm_r, cm_c) = cost.dim();
    if cm_r != m || cm_c != m {
        return Err(format!(
            "cost matrix must be ({}, {}), got ({}, {})",
            m, m, cm_r, cm_c
        ));
    }
    // Written as a negated conjunction so that a NaN `eps` is rejected too.
    if !(eps.is_finite() && eps >= MIN_EPS) {
        return Err(format!("eps must be finite and >= {MIN_EPS:e}, got {eps}"));
    }
    if n_iter == 0 {
        return Err("n_iter must be at least 1".to_string());
    }
    for ((row, col), value) in atoms.indexed_iter() {
        if !value.is_finite() || *value < 0.0 {
            return Err(format!(
                "atoms must be finite and non-negative; got {value} at ({row}, {col})"
            ));
        }
    }
    let mut w_total = 0.0_f64;
    for &w in weights.iter() {
        if !w.is_finite() || w < 0.0 {
            return Err("weights must be finite and non-negative".to_string());
        }
        w_total += w;
    }
    if w_total <= 0.0 {
        return Err("weights must have positive total mass".to_string());
    }
    // Individually finite weights can still overflow when summed; dividing by
    // an infinite total would silently turn every normalised weight into zero.
    if !w_total.is_finite() {
        return Err(format!(
            "weights total mass must be finite, got {w_total}; rescale the weights"
        ));
    }
    for ((i, j), value) in cost.indexed_iter() {
        if !value.is_finite() || *value < 0.0 {
            return Err(format!(
                "cost must be finite and non-negative; got {value} at ({i}, {j})"
            ));
        }
    }
    Ok(())
}
/// Rescales every row of `atoms` so that it sums to one.
///
/// # Errors
///
/// Returns an error naming the offending row when its total mass is not
/// strictly positive (including NaN), or when the sum overflows to a
/// non-finite value. The overflow case used to pass silently and produce an
/// all-zero row, because dividing finite entries by `+inf` gives zero.
fn normalize_atoms(atoms: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let (k, m) = atoms.dim();
    let mut out = Array2::<f64>::zeros((k, m));
    for ki in 0..k {
        let source = atoms.row(ki);
        let total = source.iter().fold(0.0_f64, |acc, &v| acc + v);
        // Negated comparison so a NaN total is rejected as well.
        if !(total > 0.0) {
            return Err(format!(
                "atoms row {ki} has non-positive total mass {total}"
            ));
        }
        if !total.is_finite() {
            return Err(format!(
                "atoms row {ki} total mass overflowed to {total}; rescale the row"
            ));
        }
        let mut target = out.row_mut(ki);
        for (slot, &v) in target.iter_mut().zip(source.iter()) {
            *slot = v / total;
        }
    }
    Ok(out)
}
/// Divides every weight by the sum of all weights.
///
/// No check is made on the total here; a zero total yields non-finite
/// results, so callers are expected to validate the weights first.
fn normalize_weights(weights: ArrayView1<'_, f64>) -> Vec<f64> {
    let mass = weights.iter().sum::<f64>();
    let mut normalized = Vec::with_capacity(weights.len());
    for &w in weights.iter() {
        normalized.push(w / mass);
    }
    normalized
}
/// Log-domain state returned by `sinkhorn_barycenter_forward_state` after the
/// last iteration. With `k` atoms over `m` bins:
pub struct SinkhornState {
/// `(k, m)` scalings, last set to `log_atoms` minus the row-wise log-sum-exp of
/// `log_kernel[i, j] + log_v[k, j]`.
pub log_u: Array2<f64>,
/// `(k, m)` scalings, last set to `log_a` (as it was at the start of the final
/// iteration) minus the column-wise log-sum-exp of `log_kernel[i, j] + log_u[k, i]`.
pub log_v: Array2<f64>,
/// Length-`m` unnormalised log barycenter; callers pass it through a softmax.
pub log_a: Array1<f64>,
/// `(m, m)` log Gibbs kernel, `-cost / eps`.
pub log_kernel: Array2<f64>,
/// `(k, m)` logarithm of the row-normalised atoms, with `LOG_ZERO_SENTINEL`
/// in place of non-positive entries.
pub log_atoms: Array2<f64>,
/// Length-`k` weights divided by their sum.
pub weights: Vec<f64>,
}
pub fn sinkhorn_barycenter_forward_state(
atoms: ArrayView2<'_, f64>,
weights: ArrayView1<'_, f64>,
cost: ArrayView2<'_, f64>,
eps: f64,
n_iter: usize,
) -> Result<SinkhornState, String> {
validate_inputs(atoms, weights, cost, eps, n_iter)?;
let atoms_norm = normalize_atoms(atoms)?;
let weights_norm = normalize_weights(weights);
let (k, m) = atoms_norm.dim();
let mut log_kernel = Array2::<f64>::zeros((m, m));
for i in 0..m {
for j in 0..m {
log_kernel[[i, j]] = -cost[[i, j]] / eps;
}
}
let mut log_atoms = Array2::<f64>::zeros((k, m));
for ki in 0..k {
let row = safe_log_simplex(atoms_norm.row(ki));
for j in 0..m {
log_atoms[[ki, j]] = row[j];
}
}
let mut log_u = Array2::<f64>::zeros((k, m));
let mut log_v = Array2::<f64>::zeros((k, m));
let inv_m = (1.0_f64 / m as f64).ln();
let mut log_a = Array1::<f64>::from_elem(m, inv_m);
let mut scratch = Array2::<f64>::zeros((m, m));
for _ in 0..n_iter {
for ki in 0..k {
for i in 0..m {
let off = log_u[[ki, i]];
for j in 0..m {
scratch[[i, j]] = log_kernel[[i, j]] + off;
}
}
let lse = logsumexp_axis0(scratch.view());
for j in 0..m {
log_v[[ki, j]] = log_a[j] - lse[j];
}
}
for ki in 0..k {
for j in 0..m {
let off = log_v[[ki, j]];
for i in 0..m {
scratch[[i, j]] = log_kernel[[i, j]] + off;
}
}
let lse = logsumexp_axis1(scratch.view());
for i in 0..m {
log_u[[ki, i]] = log_atoms[[ki, i]] - lse[i];
}
}
let mut next_log_a = Array1::<f64>::zeros(m);
for ki in 0..k {
for i_prime in 0..m {
let off = log_u[[ki, i_prime]];
for i in 0..m {
scratch[[i_prime, i]] = log_kernel[[i_prime, i]] + off;
}
}
let lse = logsumexp_axis0(scratch.view());
for i in 0..m {
next_log_a[i] += weights_norm[ki] * lse[i];
}
}
log_a = next_log_a;
}
Ok(SinkhornState {
log_u,
log_v,
log_a,
log_kernel,
log_atoms,
weights: weights_norm,
})
}
pub fn sinkhorn_barycenter(
atoms: ArrayView2<'_, f64>,
weights: ArrayView1<'_, f64>,
cost: ArrayView2<'_, f64>,
eps: f64,
n_iter: usize,
) -> Result<Array1<f64>, String> {
let state = sinkhorn_barycenter_forward_state(atoms, weights, cost, eps, n_iter)?;
if log_vector_is_sentinel_saturated(state.log_a.view()) {
return Err(
"sinkhorn barycenter degenerated: all log_a saturated to sentinel -- try larger eps or check cost matrix"
.to_string(),
);
}
softmax_1d(state.log_a.view())
}
/// Gradients returned by `sinkhorn_barycenter_vjp`.
pub struct SinkhornVjp {
/// `(k, m)` gradient with respect to the raw (un-normalised) `atoms` input.
pub d_atoms: Array2<f64>,
/// Length-`k` gradient with respect to the raw (un-normalised) `weights` input.
pub d_weights: Array1<f64>,
}
/// Vector-Jacobian product of the barycenter with respect to `atoms` and `weights`.
///
/// Re-runs the forward pass, pulls `cotangent` back through the final softmax,
/// then performs `n_iter` reverse sweeps over the `u`/`v`/`a` updates and finally
/// applies the chain rule through the atom-row and weight normalisations.
///
/// NOTE(review): the coupling matrices used by every reverse sweep are built once
/// from the *final* forward state; no per-iteration state is stored. This looks
/// like a fixed-point style approximation of the unrolled gradient -- confirm.
///
/// # Errors
///
/// Propagates forward-pass and softmax errors, and fails when
/// `cotangent.len()` differs from the barycenter size.
pub fn sinkhorn_barycenter_vjp(
atoms: ArrayView2<'_, f64>,
weights: ArrayView1<'_, f64>,
cost: ArrayView2<'_, f64>,
eps: f64,
n_iter: usize,
cotangent: ArrayView1<'_, f64>,
) -> Result<SinkhornVjp, String> {
let state = sinkhorn_barycenter_forward_state(atoms, weights, cost, eps, n_iter)?;
let (k, m) = state.log_u.dim();
if cotangent.len() != m {
return Err(format!(
"cotangent length {} does not match barycenter size {}",
cotangent.len(),
m
));
}
// Softmax pull-back: g_log_a[i] = b[i] * (c[i] - sum_j c[j] * b[j]).
let bary = softmax_1d(state.log_a.view())?;
let mut g_log_a = Array1::<f64>::zeros(m);
let mut weighted = 0.0_f64;
for i in 0..m {
weighted += cotangent[i] * bary[i];
}
for i in 0..m {
g_log_a[i] = bary[i] * (cotangent[i] - weighted);
}
// Gradient accumulators for the log-domain variables and the normalised weights.
let mut g_log_u = Array2::<f64>::zeros((k, m));
let mut g_log_v = Array2::<f64>::zeros((k, m));
let mut g_log_atoms = Array2::<f64>::zeros((k, m));
let mut g_weights = Array1::<f64>::zeros(k);
// Per-atom matrices built from the final state, with
// logp = log_kernel[i, j] + log_u[k, i] + log_v[k, j]:
//   p[i, j] = exp(logp - log_a[j])
//   q[i, j] = exp(logp - log_atoms[k, i])
let mut p_couplings: Vec<Array2<f64>> = Vec::with_capacity(k);
let mut q_couplings: Vec<Array2<f64>> = Vec::with_capacity(k);
for ki in 0..k {
let mut p_mat = Array2::<f64>::zeros((m, m));
let mut q_mat = Array2::<f64>::zeros((m, m));
for i in 0..m {
for j in 0..m {
let logp = state.log_kernel[[i, j]] + state.log_u[[ki, i]] + state.log_v[[ki, j]];
let p_val = (logp - state.log_a[j]).exp();
let q_val = (logp - state.log_atoms[[ki, i]]).exp();
p_mat[[i, j]] = p_val;
q_mat[[i, j]] = q_val;
}
}
p_couplings.push(p_mat);
q_couplings.push(q_mat);
}
// s_per_k[k][j] = column-wise log-sum-exp of log_kernel[i, j] + log_u[k, i],
// i.e. the per-atom term that the forward pass weights and sums into log_a.
let mut s_per_k: Vec<Array1<f64>> = Vec::with_capacity(k);
{
let mut scratch = Array2::<f64>::zeros((m, m));
for ki in 0..k {
for i_prime in 0..m {
let off = state.log_u[[ki, i_prime]];
for j in 0..m {
scratch[[i_prime, j]] = state.log_kernel[[i_prime, j]] + off;
}
}
s_per_k.push(logsumexp_axis0(scratch.view()));
}
}
// Seed: push g_log_a back through the barycenter update into g_log_u
// (scaled by the weight and the p matrix) and into g_weights (via s_per_k).
for ki in 0..k {
for i in 0..m {
let mut acc = 0.0_f64;
for j in 0..m {
acc += g_log_a[j] * state.weights[ki] * p_couplings[ki][[i, j]];
}
g_log_u[[ki, i]] += acc;
}
let mut acc_w = 0.0_f64;
for j in 0..m {
acc_w += g_log_a[j] * s_per_k[ki][j];
}
g_weights[ki] += acc_w;
}
let mut g_log_a_in = Array1::<f64>::zeros(m);
// One reverse sweep per forward iteration. Each stage consumes an accumulator
// and then zeroes it, so the statement order below is significant.
for _ in 0..n_iter {
// Reverse of the u-update: g_log_u flows to g_log_atoms directly and to
// g_log_v through -q; g_log_u is then cleared.
for ki in 0..k {
for i in 0..m {
g_log_atoms[[ki, i]] += g_log_u[[ki, i]];
}
for i in 0..m {
let gu = g_log_u[[ki, i]];
if gu == 0.0 {
continue;
}
for j in 0..m {
g_log_v[[ki, j]] -= gu * q_couplings[ki][[i, j]];
}
}
for i in 0..m {
g_log_u[[ki, i]] = 0.0;
}
}
// Reverse of the v-update: g_log_v flows to the incoming log_a gradient
// directly and to g_log_u through -p; g_log_v is then cleared.
for ki in 0..k {
for j in 0..m {
g_log_a_in[j] += g_log_v[[ki, j]];
}
for j in 0..m {
let gv = g_log_v[[ki, j]];
if gv == 0.0 {
continue;
}
for i in 0..m {
g_log_u[[ki, i]] -= gv * p_couplings[ki][[i, j]];
}
}
for j in 0..m {
g_log_v[[ki, j]] = 0.0;
}
}
// Reverse of the barycenter update for the incoming log_a gradient
// (same form as the seed step above); g_log_a_in is then cleared.
for ki in 0..k {
for i in 0..m {
let mut acc = 0.0_f64;
for j in 0..m {
acc += g_log_a_in[j] * state.weights[ki] * p_couplings[ki][[i, j]];
}
g_log_u[[ki, i]] += acc;
}
let mut acc_w = 0.0_f64;
for j in 0..m {
acc_w += g_log_a_in[j] * s_per_k[ki][j];
}
g_weights[ki] += acc_w;
}
for j in 0..m {
g_log_a_in[j] = 0.0;
}
}
// Chain rule through log(atoms[k, l] / z_k) with z_k the raw row sum:
// d_atoms[k, l] = g_log_atoms[k, l] / atoms[k, l] - sum_i g_log_atoms[k, i] / z_k.
// Zero raw entries contribute no first term.
let mut d_atoms = Array2::<f64>::zeros((k, m));
for ki in 0..k {
let mut z = 0.0_f64;
for j in 0..m {
z += atoms[[ki, j]];
}
let mut sum_g = 0.0_f64;
for i in 0..m {
sum_g += g_log_atoms[[ki, i]];
}
for l in 0..m {
let raw = atoms[[ki, l]];
let first = if raw > 0.0 {
g_log_atoms[[ki, l]] / raw
} else {
0.0
};
d_atoms[[ki, l]] = first - sum_g / z;
}
}
// Chain rule through weights / sum(weights):
// d_weights[k] = (g_weights[k] - sum_l g_weights[l] * w_norm[l]) / sum(weights).
let mut d_weights = Array1::<f64>::zeros(k);
let w_total: f64 = weights.iter().sum();
if w_total > 0.0 {
let mut sum_norm_g = 0.0_f64;
for ki in 0..k {
sum_norm_g += g_weights[ki] * state.weights[ki];
}
for ki in 0..k {
d_weights[ki] = (g_weights[ki] - sum_norm_g) / w_total;
}
}
Ok(SinkhornVjp { d_atoms, d_weights })
}
/// Squared circular distance between `m` equally spaced bins on a ring:
/// `cost[i, j] = min(|i - j|, m - |i - j|)^2`.
///
/// Returns an empty `(0, 0)` matrix when `m == 0`.
pub fn circular_cost(m: usize) -> Array2<f64> {
    Array2::from_shape_fn((m, m), |(i, j)| {
        let gap = i.max(j) - i.min(j);
        // Going the other way round the ring may be shorter.
        let steps = gap.min(m - gap) as f64;
        steps * steps
    })
}
/// Pairwise squared Euclidean distances between the rows of `points`.
///
/// # Errors
///
/// Fails when `points` has no rows or no columns, or when any coordinate is
/// non-finite (the first such entry in row-major order is reported).
pub fn euclidean_cost(points: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let (m, d) = points.dim();
    if m == 0 || d == 0 {
        return Err("euclidean_cost requires at least one point and one dimension".to_string());
    }
    if let Some(((row, col), value)) = points.indexed_iter().find(|(_, v)| !v.is_finite()) {
        return Err(format!(
            "euclidean_cost points must be finite; got {value} at ({row}, {col})"
        ));
    }
    let squared_distance = |(i, j): (usize, usize)| -> f64 {
        let lhs = points.row(i);
        let rhs = points.row(j);
        lhs.iter().zip(rhs.iter()).fold(0.0_f64, |acc, (&a, &b)| {
            let diff = a - b;
            acc + diff * diff
        })
    };
    Ok(Array2::from_shape_fn((m, m), squared_distance))
}
/// Pairwise squared great-circle angle between unit direction vectors in 3-D.
///
/// The angle is computed as `atan2(|a x b|, a . b)` rather than
/// `acos(clamp(a . b))`. `acos` loses roughly half the available precision for
/// nearly parallel or anti-parallel vectors, and because rows are accepted
/// within `1e-6` of unit norm it could return a non-zero cost on the diagonal.
/// The `atan2` form is insensitive to the vectors' scale, is exactly symmetric,
/// and gives exactly zero for identical rows.
///
/// # Errors
///
/// Fails when the rows are not 3-dimensional, when any component is
/// non-finite, or when a row's norm differs from one by more than `1e-6`.
pub fn geodesic_sphere_cost(directions: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let (m, d) = directions.dim();
    if d != 3 {
        return Err(format!(
            "geodesic_sphere_cost requires direction vectors of dimension 3, got {d}"
        ));
    }
    for i in 0..m {
        let mut norm = 0.0_f64;
        for k in 0..3 {
            let v = directions[[i, k]];
            if !v.is_finite() {
                return Err(format!(
                    "geodesic_sphere_cost directions must be finite; got {v} at ({i}, {k})"
                ));
            }
            norm += v * v;
        }
        if (norm.sqrt() - 1.0).abs() > 1.0e-6 {
            return Err(format!(
                "geodesic_sphere_cost row {i} must be unit-norm; got |x| = {}",
                norm.sqrt()
            ));
        }
    }
    Ok(Array2::from_shape_fn((m, m), |(i, j)| {
        let (ax, ay, az) = (directions[[i, 0]], directions[[i, 1]], directions[[i, 2]]);
        let (bx, by, bz) = (directions[[j, 0]], directions[[j, 1]], directions[[j, 2]]);
        let dot = ax * bx + ay * by + az * bz;
        // Cross product components; each is exactly zero when a == b.
        let cx = ay * bz - az * by;
        let cy = az * bx - ax * bz;
        let cz = ax * by - ay * bx;
        let cross_norm = (cx * cx + cy * cy + cz * cz).sqrt();
        // atan2 with a non-negative first argument yields an angle in [0, pi].
        let theta = cross_norm.atan2(dot);
        theta * theta
    }))
}
// Unit tests for the barycenter solver, its VJP and the cost-matrix helpers.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2, array};
// Asserts that both vectors sum to one (within 1e-8) and agree entry-wise within `tol`.
fn approx_simplex_eq(actual: &Array1<f64>, expected: &Array1<f64>, tol: f64) {
assert_eq!(actual.len(), expected.len());
let sum_actual: f64 = actual.iter().sum();
let sum_expected: f64 = expected.iter().sum();
assert!(
(sum_actual - 1.0).abs() < 1.0e-8,
"actual does not sum to 1: {sum_actual}"
);
assert!(
(sum_expected - 1.0).abs() < 1.0e-8,
"expected does not sum to 1: {sum_expected}"
);
for (a, e) in actual.iter().zip(expected.iter()) {
assert!(
(a - e).abs() < tol,
"barycenter entry mismatch: {a} vs {e} (tol {tol})"
);
}
}
// With a single atom and weight 1, the barycenter should reproduce the atom.
#[test]
fn k_eq_1_recovers_the_atom() {
let m = 8;
let atom = array![0.05, 0.1, 0.2, 0.3, 0.2, 0.1, 0.04, 0.01];
let mut atoms = Array2::<f64>::zeros((1, m));
for j in 0..m {
atoms[[0, j]] = atom[j];
}
let weights = array![1.0];
let cost = circular_cost(m);
let bary =
sinkhorn_barycenter(atoms.view(), weights.view(), cost.view(), 0.05, 60).unwrap();
approx_simplex_eq(&bary, &atom, 5.0e-3);
assert_eq!(bary.len(), atom.len(), "barycenter length mismatch");
}
// Two equally weighted bumps on [0, 1]: the barycenter mean should sit near the
// midpoint of the two atom means and strictly between them.
#[test]
fn k_eq_2_mean_is_between() {
let m = 32;
let points: Array2<f64> = Array2::from_shape_fn((m, 1), |(i, _)| i as f64 / (m - 1) as f64);
let mut atom_a = Array1::<f64>::zeros(m);
let mut atom_b = Array1::<f64>::zeros(m);
let mut sa = 0.0;
let mut sb = 0.0;
for j in 0..m {
let x = j as f64 / (m - 1) as f64;
let va = (-((x - 0.2) * (x - 0.2)) / 0.005).exp();
let vb = (-((x - 0.8) * (x - 0.8)) / 0.005).exp();
atom_a[j] = va;
atom_b[j] = vb;
sa += va;
sb += vb;
}
for j in 0..m {
atom_a[j] /= sa;
atom_b[j] /= sb;
}
let mut atoms = Array2::<f64>::zeros((2, m));
for j in 0..m {
atoms[[0, j]] = atom_a[j];
atoms[[1, j]] = atom_b[j];
}
let weights = array![0.5, 0.5];
let cost = euclidean_cost(points.view()).unwrap();
let bary =
sinkhorn_barycenter(atoms.view(), weights.view(), cost.view(), 0.005, 200).unwrap();
let mean_a: f64 = (0..m)
.map(|j| (j as f64 / (m - 1) as f64) * atom_a[j])
.sum();
let mean_b: f64 = (0..m)
.map(|j| (j as f64 / (m - 1) as f64) * atom_b[j])
.sum();
let mean_bary: f64 = (0..m).map(|j| (j as f64 / (m - 1) as f64) * bary[j]).sum();
let expected_mean = 0.5 * (mean_a + mean_b);
assert!(
(mean_bary - expected_mean).abs() < 0.05,
"bary mean {mean_bary} should sit at midpoint {expected_mean}"
);
assert!(
mean_bary > mean_a && mean_bary < mean_b,
"bary mean {mean_bary} should be between atom means ({mean_a}, {mean_b})"
);
}
// Two bumps centred at bins 8 and 24 on a ring of 32: the barycenter mode should
// land at one of the two midway bins (16 or 0).
#[test]
fn cyclic_midpoint_recovers_mccann_interp() {
let m = 32;
let mut atoms = Array2::<f64>::zeros((2, m));
for j in 0..m {
let d_a = (j as i64 - 8).rem_euclid(m as i64);
let d_a = d_a.min(m as i64 - d_a);
let d_b = (j as i64 - 24).rem_euclid(m as i64);
let d_b = d_b.min(m as i64 - d_b);
atoms[[0, j]] = (-(d_a as f64).powi(2) / 1.5).exp();
atoms[[1, j]] = (-(d_b as f64).powi(2) / 1.5).exp();
}
let weights = array![0.5, 0.5];
let cost = circular_cost(m);
let bary =
sinkhorn_barycenter(atoms.view(), weights.view(), cost.view(), 0.5, 200).unwrap();
let mode = bary
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap();
assert!(
mode == 16 || mode == 0,
"barycenter mode {mode} should be midway between atom modes"
);
}
// A small eps (1e-3) must still give a finite, non-negative result that sums to one.
#[test]
fn small_eps_does_not_nan() {
let m = 16;
let atoms = Array2::<f64>::from_shape_fn((2, m), |(k, j)| {
let centre = if k == 0 { 3.0 } else { 11.0 };
(-((j as f64 - centre).powi(2)) / 4.0).exp()
});
let weights = array![0.5, 0.5];
let cost = circular_cost(m);
let bary =
sinkhorn_barycenter(atoms.view(), weights.view(), cost.view(), 1.0e-3, 50).unwrap();
for v in bary.iter() {
assert!(v.is_finite(), "barycenter entry {v} is not finite");
assert!(*v >= 0.0, "barycenter entry {v} is negative");
}
let s: f64 = bary.iter().sum();
assert!((s - 1.0).abs() < 1.0e-8, "barycenter sum {s} != 1");
}
// eps below MIN_EPS must be rejected by input validation.
#[test]
fn rejects_small_eps() {
let m = 4;
let atoms = Array2::<f64>::from_elem((2, m), 1.0 / m as f64);
let weights = array![0.5, 0.5];
let cost = circular_cost(m);
let err = sinkhorn_barycenter(atoms.view(), weights.view(), cost.view(), 1.0e-15, 10);
assert!(err.is_err());
}
// cost / eps = 1e288 / 1e-12 = 1e300 drives the log kernel down to the sentinel,
// so the solver must report saturation instead of returning a vector.
#[test]
fn sentinel_saturated_log_a_returns_error() {
let m = 3;
let atoms = Array2::<f64>::from_elem((1, m), 1.0 / m as f64);
let weights = array![1.0];
let cost = Array2::<f64>::from_elem((m, m), 1.0e288);
let err =
sinkhorn_barycenter(atoms.view(), weights.view(), cost.view(), MIN_EPS, 1).unwrap_err();
assert!(
err.contains("all log_a saturated to sentinel"),
"unexpected error: {err}"
);
}
// Smoke/performance test with 128 atoms over 64 bins.
// NOTE(review): the 5-second wall-clock bound may be flaky on slow or loaded machines.
#[test]
fn batch_kbig_runs_quick() {
let m = 64;
let k = 128;
let atoms = Array2::<f64>::from_shape_fn((k, m), |(ki, j)| {
let centre = (ki as f64) * (m as f64) / (k as f64);
(-((j as f64 - centre).powi(2)) / 8.0).exp()
});
let weights = Array1::<f64>::from_elem(k, 1.0 / k as f64);
let cost = circular_cost(m);
let t0 = std::time::Instant::now();
let bary = sinkhorn_barycenter(atoms.view(), weights.view(), cost.view(), 0.1, 20).unwrap();
let dt = t0.elapsed();
assert!(
dt.as_secs_f64() < 5.0,
"batch sinkhorn took too long: {dt:?}"
);
let s: f64 = bary.iter().sum();
assert!((s - 1.0).abs() < 1.0e-8);
}
// Each cost helper must return a symmetric matrix with a zero diagonal.
#[test]
fn cost_helpers_shape_and_symmetry() {
let m = 5;
let cc = circular_cost(m);
for i in 0..m {
assert_eq!(cc[[i, i]], 0.0);
for j in 0..m {
assert!((cc[[i, j]] - cc[[j, i]]).abs() < 1.0e-12);
assert!(cc[[i, j]] >= 0.0);
}
}
let pts = Array2::<f64>::from_shape_fn((m, 2), |(i, k)| (i + k) as f64);
let ec = euclidean_cost(pts.view()).unwrap();
for i in 0..m {
assert_eq!(ec[[i, i]], 0.0);
for j in 0..m {
assert!((ec[[i, j]] - ec[[j, i]]).abs() < 1.0e-12);
}
}
let dirs = Array2::<f64>::from_shape_fn((3, 3), |(i, k)| if i == k { 1.0 } else { 0.0 });
let gc = geodesic_sphere_cost(dirs.view()).unwrap();
for i in 0..3 {
assert!((gc[[i, i]]).abs() < 1.0e-12);
for j in 0..3 {
assert!((gc[[i, j]] - gc[[j, i]]).abs() < 1.0e-12);
}
}
}
// Compares one entry of the analytic atom gradient against a central finite
// difference of <r, barycenter>, allowing 5% relative error.
#[test]
fn vjp_matches_finite_differences_small() {
let m = 6;
let k = 2;
let atoms = Array2::<f64>::from_shape_fn((k, m), |(ki, j)| {
let centre = if ki == 0 { 1.5 } else { 4.0 };
(-((j as f64 - centre).powi(2)) / 2.0).exp()
});
let mut atoms_norm = atoms.clone();
for ki in 0..k {
let s: f64 = atoms_norm.row(ki).iter().sum();
for j in 0..m {
atoms_norm[[ki, j]] /= s;
}
}
let weights = array![0.5, 0.5];
let cost = circular_cost(m);
let eps = 0.3;
let n_iter = 100;
let r = Array1::<f64>::from_shape_fn(m, |j| j as f64 - (m as f64 - 1.0) / 2.0);
let vjp = sinkhorn_barycenter_vjp(
atoms_norm.view(),
weights.view(),
cost.view(),
eps,
n_iter,
r.view(),
)
.unwrap();
let h = 1.0e-5;
let (ki, j) = (0usize, 2usize);
let mut atoms_plus = atoms_norm.clone();
let mut atoms_minus = atoms_norm.clone();
atoms_plus[[ki, j]] += h;
atoms_minus[[ki, j]] -= h;
let b_plus =
sinkhorn_barycenter(atoms_plus.view(), weights.view(), cost.view(), eps, n_iter)
.unwrap();
let b_minus =
sinkhorn_barycenter(atoms_minus.view(), weights.view(), cost.view(), eps, n_iter)
.unwrap();
let mut fd = 0.0_f64;
for i in 0..m {
fd += r[i] * (b_plus[i] - b_minus[i]) / (2.0 * h);
}
let analytic = vjp.d_atoms[[ki, j]];
let denom = analytic.abs().max(fd.abs()).max(1.0e-6);
let rel = (analytic - fd).abs() / denom;
assert!(
rel < 0.05,
"VJP/FD mismatch: analytic={analytic}, fd={fd}, rel={rel}"
);
}
}