#![warn(missing_docs)]
use ndarray::{Array1, Array2};
use thiserror::Error;
pub mod flow;
pub mod gaussian;
pub mod gromov;
pub mod semidiscrete;
pub mod sparse;
pub mod wfr;
pub use flow::{flow_drift, VectorField};
/// Error type for all fallible operations in this crate.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// The two input distributions have different lengths.
    #[error("distributions have different lengths: {0} vs {1}")]
    LengthMismatch(usize, usize),
    /// A cost matrix did not have the expected `(rows, cols)` shape.
    #[error("cost matrix shape mismatch: expected ({0}, {1}), got ({2}, {3})")]
    CostShapeMismatch(usize, usize, usize, usize),
    /// A probability vector's entries did not sum to 1.0.
    #[error("distribution does not sum to 1.0 (sum = {0})")]
    NotNormalized(f32),
    /// An iterative solver exhausted its iteration budget before reaching tolerance.
    #[error("Sinkhorn did not converge in {0} iterations")]
    SinkhornNotConverged(usize),
    /// The entropic regularization parameter was non-positive or non-finite.
    #[error("regularization parameter must be positive and finite, got {0}")]
    InvalidRegularization(f32),
    /// The marginal-relaxation mass penalty was non-positive or non-finite.
    #[error("mass penalty parameter must be positive and finite, got {0}")]
    InvalidMassPenalty(f32),
    /// Requested low-rank factorization rank is outside `1..=min(n, m)`.
    #[error("rank must be >= 1 and <= min(n, m), got rank={0} for n={1}, m={2}")]
    InvalidRank(usize, usize, usize),
    /// Hierarchical branching factor is outside `2..=min(n, m)`.
    #[error("branching must be >= 2 and <= min(n, m), got branching={0} for n={1}, m={2}")]
    InvalidBranching(usize, usize, usize),
    /// Catch-all domain violation with a static description.
    #[error("{0}")]
    Domain(&'static str),
}

/// Convenience alias used by every fallible function in this crate.
pub type Result<T> = std::result::Result<T, Error>;

/// Small additive guard against division by (near-)zero in scaling updates.
const EPSILON: f32 = 1e-7;
/// Numerically stable log-sum-exp of `f(0)..f(len-1)`.
///
/// Returns `-inf` for an empty range, and propagates a non-finite maximum
/// (e.g. all terms `-inf`, or any `+inf`/NaN) directly. `f` is invoked twice
/// per index: once for the max pass, once for the sum pass.
#[inline]
fn logsumexp_by(len: usize, mut f: impl FnMut(usize) -> f32) -> f32 {
    if len == 0 {
        return f32::NEG_INFINITY;
    }
    // First pass: locate the peak so the exponentials below cannot overflow.
    let peak = (0..len).map(&mut f).fold(f32::NEG_INFINITY, f32::max);
    if !peak.is_finite() {
        return peak;
    }
    // Second pass: sum shifted exponentials, then undo the shift in log space.
    let total: f32 = (0..len).map(|i| (f(i) - peak).exp()).sum();
    peak + total.ln()
}
/// 1-D Wasserstein-1 distance between two discrete distributions supported on
/// the same uniformly spaced grid, computed as the L1 distance between their
/// CDFs (including the final, normally vanishing, entry).
///
/// # Panics
/// Panics if the two slices have different lengths.
pub fn wasserstein_1d(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "distributions must have same length");
    if a.is_empty() {
        return 0.0;
    }
    // Accumulate both CDFs on the fly and sum their absolute gap.
    let mut acc_a = 0.0f32;
    let mut acc_b = 0.0f32;
    let mut total = 0.0f32;
    for (&xa, &xb) in a.iter().zip(b.iter()) {
        acc_a += xa;
        acc_b += xb;
        total += (acc_a - acc_b).abs();
    }
    total
}
/// Entropic-regularized optimal transport via balanced Sinkhorn scaling in
/// the primal (kernel) domain. Runs exactly `max_iter` iterations and returns
/// the transport plan together with its transport cost `<C, P>`.
///
/// Prefer [`sinkhorn_with_convergence`] when a stopping tolerance is needed,
/// or [`sinkhorn_log`] for small `reg` (this kernel form can underflow).
///
/// # Panics
/// Panics if `cost` is not `a.len() x b.len()`.
pub fn sinkhorn(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    max_iter: usize,
) -> (Array2<f32>, f32) {
    let rows = a.len();
    let cols = b.len();
    assert_eq!(cost.shape(), &[rows, cols], "cost matrix shape mismatch");
    // Gibbs kernel K = exp(-C / reg).
    let kernel: Array2<f32> = cost.mapv(|c| (-c / reg).exp());
    let mut u: Array1<f32> = Array1::ones(rows);
    let mut v: Array1<f32> = Array1::ones(cols);
    for _ in 0..max_iter {
        // u <- a ./ (K v), then v <- b ./ (K^T u); EPSILON guards zeros.
        let kv = kernel.dot(&v);
        for (ui, (&ai, &kvi)) in u.iter_mut().zip(a.iter().zip(kv.iter())) {
            *ui = ai / (kvi + EPSILON);
        }
        let ktu = kernel.t().dot(&u);
        for (vj, (&bj, &ktuj)) in v.iter_mut().zip(b.iter().zip(ktu.iter())) {
            *vj = bj / (ktuj + EPSILON);
        }
    }
    // P_ij = u_i * K_ij * v_j.
    let mut plan = Array2::zeros((rows, cols));
    for i in 0..rows {
        for j in 0..cols {
            plan[[i, j]] = u[i] * kernel[[i, j]] * v[j];
        }
    }
    let distance: f32 = cost.iter().zip(plan.iter()).map(|(&c, &p)| c * p).sum();
    (plan, distance)
}
/// Balanced Sinkhorn with an explicit marginal-error stopping criterion.
///
/// Iterates the kernel-domain scaling updates until both row and column
/// marginals of the implied plan match `a` and `b` within `tol` (max
/// absolute error), then returns `(plan, transport_cost, iterations_used)`.
///
/// # Errors
/// Returns [`Error::SinkhornNotConverged`] if the tolerance is not reached
/// within `max_iter` iterations.
///
/// # Panics
/// Panics if `cost` is not `a.len() x b.len()`.
pub fn sinkhorn_with_convergence(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    max_iter: usize,
    tol: f32,
) -> Result<(Array2<f32>, f32, usize)> {
    let m = a.len();
    let n = b.len();
    assert_eq!(cost.shape(), &[m, n], "cost matrix shape mismatch");
    // Gibbs kernel K = exp(-C / reg).
    let k: Array2<f32> = cost.mapv(|c| (-c / reg).exp());
    let mut u = Array1::ones(m);
    let mut v = Array1::ones(n);
    for iter in 0..max_iter {
        let kv = k.dot(&v);
        for i in 0..m {
            u[i] = a[i] / (kv[i] + EPSILON);
        }
        let ktu = k.t().dot(&u);
        for j in 0..n {
            v[j] = b[j] / (ktu[j] + EPSILON);
        }
        // Row marginals need a fresh K*v because v just changed.
        let kv2 = k.dot(&v);
        let mut max_err = 0.0f32;
        for i in 0..m {
            let row_sum = u[i] * kv2[i];
            max_err = max_err.max((row_sum - a[i]).abs());
        }
        // u has NOT changed since `ktu` was computed, so the column marginal
        // is exactly v[j] * ktu[j]; reusing `ktu` removes a redundant
        // O(m*n) transposed matrix-vector product per iteration.
        for j in 0..n {
            let col_sum = v[j] * ktu[j];
            max_err = max_err.max((col_sum - b[j]).abs());
        }
        if max_err < tol {
            let mut plan = Array2::zeros((m, n));
            for i in 0..m {
                for j in 0..n {
                    plan[[i, j]] = u[i] * k[[i, j]] * v[j];
                }
            }
            let distance: f32 = cost.iter().zip(plan.iter()).map(|(&c, &p)| c * p).sum();
            return Ok((plan, distance, iter + 1));
        }
    }
    Err(Error::SinkhornNotConverged(max_iter))
}
/// Approximate earth mover's distance: log-domain Sinkhorn with a small
/// fixed regularization (`reg = 0.01`) and 200 iterations; returns only the
/// transport cost.
pub fn earth_mover_distance(a: &Array1<f32>, b: &Array1<f32>, cost: &Array2<f32>) -> f32 {
    sinkhorn_log(a, b, cost, 0.01, 200).1
}
/// Debiased Sinkhorn divergence `OT(a,b) - (OT(a,a) + OT(b,b)) / 2` when the
/// inputs share a support (equal lengths, one square cost matrix), clamped at
/// zero; for unequal lengths it silently falls back to the raw (biased) OT
/// cost — hence the deprecation.
#[deprecated(
    note = "Use sinkhorn_divergence_same_support or sinkhorn_divergence_general for a true divergence"
)]
pub fn sinkhorn_divergence(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    max_iter: usize,
) -> f32 {
    let (_, ot_ab) = sinkhorn_log(a, b, cost, reg, max_iter);
    if a.len() != b.len() {
        // Different supports: no debiasing possible with a single cost matrix.
        return ot_ab;
    }
    let (_, ot_aa) = sinkhorn_log(a, a, cost, reg, max_iter);
    let (_, ot_bb) = sinkhorn_log(b, b, cost, reg, max_iter);
    (ot_ab - 0.5 * (ot_aa + ot_bb)).max(0.0)
}
/// Debiased entropic Sinkhorn divergence for two measures on the same
/// support: `OT(a,b) - (OT(a,a) + OT(b,b)) / 2`, clamped at zero.
///
/// # Errors
/// [`Error::LengthMismatch`] / [`Error::CostShapeMismatch`] on bad shapes,
/// plus anything [`sinkhorn_log_with_convergence`] returns.
pub fn sinkhorn_divergence_same_support(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    max_iter: usize,
    tol: f32,
) -> Result<f32> {
    let n = a.len();
    if b.len() != n {
        return Err(Error::LengthMismatch(n, b.len()));
    }
    if cost.nrows() != n || cost.ncols() != n {
        return Err(Error::CostShapeMismatch(n, n, cost.nrows(), cost.ncols()));
    }
    let ot_ab = sinkhorn_log_with_convergence(a, b, cost, reg, max_iter, tol)?.1;
    let ot_aa = sinkhorn_log_with_convergence(a, a, cost, reg, max_iter, tol)?.1;
    let ot_bb = sinkhorn_log_with_convergence(b, b, cost, reg, max_iter, tol)?.1;
    Ok((ot_ab - 0.5 * (ot_aa + ot_bb)).max(0.0))
}
pub fn sinkhorn_divergence_general(
a: &Array1<f32>,
b: &Array1<f32>,
cost_ab: &Array2<f32>,
cost_aa: &Array2<f32>,
cost_bb: &Array2<f32>,
reg: f32,
max_iter: usize,
tol: f32,
) -> Result<f32> {
let m = a.len();
let n = b.len();
if cost_ab.nrows() != m || cost_ab.ncols() != n {
return Err(Error::CostShapeMismatch(
m,
n,
cost_ab.nrows(),
cost_ab.ncols(),
));
}
if cost_aa.nrows() != m || cost_aa.ncols() != m {
return Err(Error::CostShapeMismatch(
m,
m,
cost_aa.nrows(),
cost_aa.ncols(),
));
}
if cost_bb.nrows() != n || cost_bb.ncols() != n {
return Err(Error::CostShapeMismatch(
n,
n,
cost_bb.nrows(),
cost_bb.ncols(),
));
}
let (_p_pq, ot_pq, _iters_pq) =
sinkhorn_log_with_convergence(a, b, cost_ab, reg, max_iter, tol)?;
let (_p_pp, ot_pp, _iters_pp) =
sinkhorn_log_with_convergence(a, a, cost_aa, reg, max_iter, tol)?;
let (_p_qq, ot_qq, _iters_qq) =
sinkhorn_log_with_convergence(b, b, cost_bb, reg, max_iter, tol)?;
Ok((ot_pq - 0.5 * (ot_pp + ot_qq)).max(0.0))
}
/// Pairwise Euclidean (L2) distance matrix between the rows of `x` (m x d)
/// and the rows of `y` (n x d); result is m x n.
///
/// # Panics
/// Panics if the two point sets have different dimensionality.
pub fn euclidean_cost_matrix(x: &Array2<f32>, y: &Array2<f32>) -> Array2<f32> {
    let d = x.ncols();
    assert_eq!(y.ncols(), d, "point dimensions must match");
    Array2::from_shape_fn((x.nrows(), y.nrows()), |(i, j)| {
        let mut acc = 0.0f32;
        for k in 0..d {
            let delta = x[[i, k]] - y[[j, k]];
            acc += delta * delta;
        }
        acc.sqrt()
    })
}
/// Pairwise squared-Euclidean distance matrix between the rows of `x` (m x d)
/// and the rows of `y` (n x d); result is m x n.
///
/// # Panics
/// Panics if the two point sets have different dimensionality.
pub fn sq_euclidean_cost_matrix(x: &Array2<f32>, y: &Array2<f32>) -> Array2<f32> {
    let d = x.ncols();
    assert_eq!(y.ncols(), d, "point dimensions must match");
    Array2::from_shape_fn((x.nrows(), y.nrows()), |(i, j)| {
        let mut acc = 0.0f32;
        for k in 0..d {
            let delta = x[[i, k]] - y[[j, k]];
            acc += delta * delta;
        }
        acc
    })
}
/// Log-domain (stabilized) Sinkhorn: runs exactly `max_iter` dual-potential
/// updates and returns `(plan, transport_cost)`.
///
/// `a` and `b` are normalized to probability vectors internally, so callers
/// may pass unnormalized nonnegative weights. All arithmetic happens on
/// log-potentials via [`logsumexp_by`], so this stays finite for small `reg`
/// where the kernel-domain [`sinkhorn`] would underflow.
///
/// NOTE(review): unlike [`sinkhorn_log_with_convergence`], `cost` dimensions
/// are not validated here — an out-of-shape matrix panics on indexing.
pub fn sinkhorn_log(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    max_iter: usize,
) -> (Array2<f32>, f32) {
    let m = a.len();
    let n = b.len();
    // Normalize both marginals; EPSILON guards a zero total mass.
    let a_sum = a.sum();
    let b_sum = b.sum();
    let a = a / (a_sum + EPSILON);
    let b = b / (b_sum + EPSILON);
    // log(0) maps to -inf so zero-mass atoms drop out of the updates.
    let log_a = a.mapv(|x| if x <= 0.0 { f32::NEG_INFINITY } else { x.ln() });
    let log_b = b.mapv(|x| if x <= 0.0 { f32::NEG_INFINITY } else { x.ln() });
    // f, g are the scaled dual potentials (f = reg*log u, g = reg*log v).
    let mut f: Array1<f32> = Array1::zeros(m);
    let mut g: Array1<f32> = Array1::zeros(n);
    for _ in 0..max_iter {
        // Row update: f_i = reg * (log a_i - LSE_j((g_j - C_ij)/reg)).
        for i in 0..m {
            let lse = logsumexp_by(n, |j| (g[j] - cost[[i, j]]) / reg);
            f[i] = reg * (log_a[i] - lse);
        }
        // Column update, using the freshly updated f.
        for j in 0..n {
            let lse = logsumexp_by(m, |i| (f[i] - cost[[i, j]]) / reg);
            g[j] = reg * (log_b[j] - lse);
        }
    }
    // Recover the primal plan P_ij = exp((f_i + g_j - C_ij)/reg) and its cost.
    let mut plan = Array2::zeros((m, n));
    let mut distance = 0.0;
    for i in 0..m {
        for j in 0..n {
            let log_p = (f[i] + g[j] - cost[[i, j]]) / reg;
            plan[[i, j]] = log_p.exp();
            distance += plan[[i, j]] * cost[[i, j]];
        }
    }
    (plan, distance)
}
/// Log-domain Sinkhorn with full input validation and a marginal-error
/// stopping test; returns `(plan, transport_cost, iterations_used)`.
///
/// `a` and `b` are normalized internally, so unnormalized nonnegative
/// weights are accepted. Convergence is checked every 10 iterations (and on
/// the final one) by comparing both marginals of the implied plan against
/// the normalized inputs with max absolute error below `tol`.
///
/// # Errors
/// - [`Error::CostShapeMismatch`] if `cost` is not `a.len() x b.len()`.
/// - [`Error::InvalidRegularization`] if `reg` is non-positive or non-finite.
/// - [`Error::Domain`] for negative masses or zero total mass.
/// - [`Error::SinkhornNotConverged`] if `tol` is not met within `max_iter`.
pub fn sinkhorn_log_with_convergence(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    max_iter: usize,
    tol: f32,
) -> Result<(Array2<f32>, f32, usize)> {
    let m = a.len();
    let n = b.len();
    if cost.nrows() != m || cost.ncols() != n {
        return Err(Error::CostShapeMismatch(m, n, cost.nrows(), cost.ncols()));
    }
    if reg <= 0.0 || !reg.is_finite() {
        return Err(Error::InvalidRegularization(reg));
    }
    if a.iter().any(|&x| x < 0.0) || b.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("sinkhorn requires nonnegative masses"));
    }
    let a_sum = a.sum();
    let b_sum = b.sum();
    if a_sum <= 0.0 || b_sum <= 0.0 {
        return Err(Error::Domain("sinkhorn requires positive total mass"));
    }
    let a = a / (a_sum + EPSILON);
    let b = b / (b_sum + EPSILON);
    // log(0) -> -inf excludes zero-mass atoms from the updates.
    let log_a = a.mapv(|x| if x <= 0.0 { f32::NEG_INFINITY } else { x.ln() });
    let log_b = b.mapv(|x| if x <= 0.0 { f32::NEG_INFINITY } else { x.ln() });
    // Scaled dual potentials: f = reg*log u, g = reg*log v.
    let mut f: Array1<f32> = Array1::zeros(m);
    let mut g: Array1<f32> = Array1::zeros(n);
    // The marginal test costs a full O(m*n) pass, so only run it
    // periodically (and on the last iteration).
    let check_every = 10usize.max(1);
    for iter in 0..max_iter {
        for i in 0..m {
            let lse = logsumexp_by(n, |j| (g[j] - cost[[i, j]]) / reg);
            f[i] = reg * (log_a[i] - lse);
        }
        for j in 0..n {
            let lse = logsumexp_by(m, |i| (f[i] - cost[[i, j]]) / reg);
            g[j] = reg * (log_b[j] - lse);
        }
        if (iter + 1) % check_every == 0 || iter + 1 == max_iter {
            let mut max_err = 0.0f32;
            // Row marginal: sum_j P_ij = exp(f_i/reg) * exp(LSE_j((g_j - C_ij)/reg)).
            for i in 0..m {
                let lse = logsumexp_by(n, |j| (g[j] - cost[[i, j]]) / reg);
                let row_sum = (f[i] / reg).exp() * lse.exp();
                max_err = max_err.max((row_sum - a[i]).abs());
            }
            // Column marginal, symmetric to the row case.
            for j in 0..n {
                let lse = logsumexp_by(m, |i| (f[i] - cost[[i, j]]) / reg);
                let col_sum = (g[j] / reg).exp() * lse.exp();
                max_err = max_err.max((col_sum - b[j]).abs());
            }
            if max_err < tol {
                // Converged: materialize P_ij = exp((f_i + g_j - C_ij)/reg).
                let mut plan = Array2::zeros((m, n));
                let mut distance = 0.0;
                for i in 0..m {
                    for j in 0..n {
                        let log_p = (f[i] + g[j] - cost[[i, j]]) / reg;
                        let pij = log_p.exp();
                        plan[[i, j]] = pij;
                        distance += pij * cost[[i, j]];
                    }
                }
                return Ok((plan, distance, iter + 1));
            }
        }
    }
    Err(Error::SinkhornNotConverged(max_iter))
}
/// Unbalanced entropic OT (KL-relaxed marginals) in the log domain.
///
/// Runs damped Sinkhorn scaling with exponent `alpha = rho / (rho + reg)`
/// and stops when the fixed-point residual of the log scaling factors drops
/// below `tol` (checked every 10 iterations). Returns
/// `(plan, objective, iterations_used)`.
///
/// The returned objective is assembled as
/// `<C, P> + reg * KL(P ‖ K) + rho * (KL(P·1 ‖ a) + KL(Pᵀ·1 ‖ b))`
/// with `K = exp(-C/reg)` and generalized (mass-aware) KL terms.
/// NOTE(review): `reg * KL(P ‖ K)` already contains the linear cost term
/// `<C, P>`, so the sum above appears to count it twice — confirm this
/// matches the intended functional.
///
/// # Errors
/// Shape, parameter, and domain violations, plus
/// [`Error::SinkhornNotConverged`] if `tol` is not met within `max_iter`.
pub fn unbalanced_sinkhorn_log_with_convergence(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    rho: f32,
    max_iter: usize,
    tol: f32,
) -> Result<(Array2<f32>, f32, usize)> {
    let m = a.len();
    let n = b.len();
    if cost.nrows() != m || cost.ncols() != n {
        return Err(Error::CostShapeMismatch(m, n, cost.nrows(), cost.ncols()));
    }
    if reg <= 0.0 || !reg.is_finite() {
        return Err(Error::InvalidRegularization(reg));
    }
    if rho <= 0.0 || !rho.is_finite() {
        return Err(Error::InvalidMassPenalty(rho));
    }
    if a.iter().any(|&x| x < 0.0) || b.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("unbalanced OT requires nonnegative masses"));
    }
    let a_sum = a.sum();
    let b_sum = b.sum();
    if a_sum <= 0.0 || b_sum <= 0.0 {
        return Err(Error::Domain("unbalanced OT requires positive total mass"));
    }
    // Damping exponent: alpha -> 1 recovers balanced Sinkhorn (rho >> reg).
    let alpha = rho / (rho + reg);
    // log(0) -> -inf excludes zero-mass atoms from every update.
    let log_a = a.mapv(|x| if x <= 0.0 { f32::NEG_INFINITY } else { x.ln() });
    let log_b = b.mapv(|x| if x <= 0.0 { f32::NEG_INFINITY } else { x.ln() });
    // log_v (and per-iteration log_u) are log scaling factors of the plan
    // P_ij = u_i * K_ij * v_j with K = exp(-C/reg).
    let mut log_v: Array1<f32> = Array1::zeros(n);
    let check_every = 10usize.max(1);
    for iter in 0..max_iter {
        // Damped row update: log u_i = alpha * (log a_i - log (K v)_i).
        let mut log_u_new = Array1::zeros(m);
        for i in 0..m {
            if log_a[i] == f32::NEG_INFINITY {
                log_u_new[i] = f32::NEG_INFINITY;
                continue;
            }
            let lkv = logsumexp_by(n, |j| log_v[j] - (cost[[i, j]] / reg));
            if lkv == f32::NEG_INFINITY {
                log_u_new[i] = f32::NEG_INFINITY;
            } else {
                log_u_new[i] = alpha * (log_a[i] - lkv);
            }
        }
        // Damped column update against the fresh log_u.
        let mut log_v_new = Array1::zeros(n);
        for j in 0..n {
            if log_b[j] == f32::NEG_INFINITY {
                log_v_new[j] = f32::NEG_INFINITY;
                continue;
            }
            let lktu = logsumexp_by(m, |i| log_u_new[i] - (cost[[i, j]] / reg));
            if lktu == f32::NEG_INFINITY {
                log_v_new[j] = f32::NEG_INFINITY;
            } else {
                log_v_new[j] = alpha * (log_b[j] - lktu);
            }
        }
        let log_u = log_u_new;
        log_v = log_v_new;
        // Convergence test: residual of the scaling fixed point in log space.
        // Only run periodically — it costs a full O(m*n) pass.
        if (iter + 1) % check_every == 0 || iter + 1 == max_iter {
            let mut max_diff = 0.0f32;
            for i in 0..m {
                if log_a[i] == f32::NEG_INFINITY {
                    continue;
                }
                let lkv = logsumexp_by(n, |j| log_v[j] - (cost[[i, j]] / reg));
                if lkv != f32::NEG_INFINITY {
                    let val = alpha * (log_a[i] - lkv);
                    max_diff = max_diff.max((val - log_u[i]).abs());
                }
            }
            for j in 0..n {
                if log_b[j] == f32::NEG_INFINITY {
                    continue;
                }
                let lktu = logsumexp_by(m, |i| log_u[i] - (cost[[i, j]] / reg));
                if lktu != f32::NEG_INFINITY {
                    let val = alpha * (log_b[j] - lktu);
                    max_diff = max_diff.max((val - log_v[j]).abs());
                }
            }
            if max_diff < tol {
                // Converged: materialize the plan and evaluate the objective.
                let mut plan = Array2::zeros((m, n));
                let mut transport_cost = 0.0;
                for i in 0..m {
                    for j in 0..n {
                        if log_u[i] == f32::NEG_INFINITY || log_v[j] == f32::NEG_INFINITY {
                            continue;
                        }
                        let log_p = log_u[i] + log_v[j] - (cost[[i, j]] / reg);
                        let pij = log_p.exp();
                        plan[[i, j]] = pij;
                        transport_cost += pij * cost[[i, j]];
                    }
                }
                // Generalized (unnormalized) KL with mass terms, accumulated
                // in f64: p ~ 0 contributes q; q = 0 with p > 0 gives +inf.
                fn kl_mass(p: &Array1<f32>, q: &Array1<f32>) -> f32 {
                    let mut s: f64 = 0.0;
                    for (&pi, &qi) in p.iter().zip(q.iter()) {
                        if pi <= 1e-12 {
                            s += qi as f64;
                            continue;
                        }
                        if qi <= 0.0 {
                            return f32::INFINITY;
                        }
                        let pi64 = pi as f64;
                        let qi64 = qi as f64;
                        s += pi64 * (pi64 / qi64).ln() - pi64 + qi64;
                    }
                    s as f32
                }
                let row = plan.sum_axis(ndarray::Axis(1));
                let col = plan.sum_axis(ndarray::Axis(0));
                let kl_row = kl_mass(&row, a);
                let kl_col = kl_mass(&col, b);
                // KL(P ‖ K): sum of p*(ln p - ln k) - p over entries, with the
                // total kernel mass sum_k added back after the loop.
                let mut kl_plan: f64 = 0.0;
                let mut sum_k: f64 = 0.0;
                for i in 0..m {
                    for j in 0..n {
                        let cij = cost[[i, j]] as f64;
                        sum_k += (-cij / (reg as f64)).exp();
                        let pij = plan[[i, j]] as f64;
                        if pij <= 1e-12 {
                            continue;
                        }
                        let log_k = -cij / (reg as f64);
                        kl_plan += pij * (pij.ln() - log_k) - pij;
                    }
                }
                kl_plan += sum_k;
                let obj = transport_cost + reg * (kl_plan as f32) + rho * (kl_row + kl_col);
                return Ok((plan, obj, iter + 1));
            }
        }
    }
    Err(Error::SinkhornNotConverged(max_iter))
}
/// Unbalanced Sinkhorn divergence for two measures on the same support
/// (equal lengths, one square cost matrix).
///
/// Runs damped, iterate-averaged symmetric fixed-point updates for the four
/// potentials (cross terms `f_ba`/`g_ab`, self terms `f_aa`/`g_bb`), then
/// combines them — scaled by `rho + eps/2` — into the debiased divergence,
/// plus the mass-difference correction `eps/2 * (Σa - Σb)²`. Iteration stops
/// early once the fixed-point residual drops below `tol` (checked every 10
/// iterations); non-convergence is not an error here.
///
/// # Errors
/// Shape, parameter, and domain validation failures only.
pub fn unbalanced_sinkhorn_divergence_same_support(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost: &Array2<f32>,
    reg: f32,
    rho: f32,
    max_iter: usize,
    tol: f32,
) -> Result<f32> {
    let n = a.len();
    if b.len() != n {
        return Err(Error::LengthMismatch(n, b.len()));
    }
    if cost.nrows() != n || cost.ncols() != n {
        return Err(Error::CostShapeMismatch(n, n, cost.nrows(), cost.ncols()));
    }
    if reg <= 0.0 || !reg.is_finite() {
        return Err(Error::InvalidRegularization(reg));
    }
    if rho <= 0.0 || !rho.is_finite() {
        return Err(Error::InvalidMassPenalty(rho));
    }
    if a.iter().any(|&x| x < 0.0) || b.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("unbalanced OT requires nonnegative masses"));
    }
    if a.sum() <= 0.0 || b.sum() <= 0.0 {
        return Err(Error::Domain("unbalanced OT requires positive total mass"));
    }
    // Clamped log-weights: -1e5 stands in for log(0) to keep arithmetic finite.
    fn log_weights(w: &Array1<f32>) -> Array1<f32> {
        w.mapv(|x| if x <= 0.0 { -100000.0 } else { x.ln() })
    }
    // Row-side softmin: out_i = -eps * LSE_j(h_j - C_ij/eps).
    fn softmin_xy(eps: f32, cost: &Array2<f32>, h_y: &Array1<f32>) -> Array1<f32> {
        let m = cost.nrows();
        let n = cost.ncols();
        debug_assert_eq!(h_y.len(), n);
        let mut out = Array1::zeros(m);
        for i in 0..m {
            let lse = logsumexp_by(n, |j| h_y[j] - cost[[i, j]] / eps);
            out[i] = -eps * lse;
        }
        out
    }
    // Column-side softmin: out_j = -eps * LSE_i(h_i - C_ij/eps).
    fn softmin_yx(eps: f32, cost: &Array2<f32>, h_x: &Array1<f32>) -> Array1<f32> {
        let m = cost.nrows();
        let n = cost.ncols();
        debug_assert_eq!(h_x.len(), m);
        let mut out = Array1::zeros(n);
        for j in 0..n {
            let lse = logsumexp_by(m, |i| h_x[i] - cost[[i, j]] / eps);
            out[j] = -eps * lse;
        }
        out
    }
    let eps = reg;
    let damping = rho / (rho + eps);
    let a_log = log_weights(a);
    let b_log = log_weights(b);
    let mut g_ab = damping * softmin_yx(eps, cost, &a_log);
    let mut f_ba = damping * softmin_xy(eps, cost, &b_log);
    let mut f_aa = damping * softmin_xy(eps, cost, &a_log);
    // BUG FIX: g_bb is a column-side potential — initialize and update it with
    // the column softmin (softmin_yx), matching
    // `unbalanced_sinkhorn_divergence_general`. The previous row softmin
    // (softmin_xy) is only equivalent when `cost` is symmetric.
    let mut g_bb = damping * softmin_yx(eps, cost, &b_log);
    let check_every = 10usize.max(1);
    for iter in 0..max_iter {
        let h_b = &b_log + &(g_ab.mapv(|x| x / eps));
        let h_a = &a_log + &(f_ba.mapv(|x| x / eps));
        let ft_ba = damping * softmin_xy(eps, cost, &h_b);
        let gt_ab = damping * softmin_yx(eps, cost, &h_a);
        let h_aa = &a_log + &(f_aa.mapv(|x| x / eps));
        let h_bb = &b_log + &(g_bb.mapv(|x| x / eps));
        let ft_aa = damping * softmin_xy(eps, cost, &h_aa);
        let gt_bb = damping * softmin_yx(eps, cost, &h_bb);
        // Averaged (damped) fixed-point step stabilizes the symmetric updates.
        let f_ba_new = 0.5 * (&f_ba + &ft_ba);
        let g_ab_new = 0.5 * (&g_ab + &gt_ab);
        let f_aa_new = 0.5 * (&f_aa + &ft_aa);
        let g_bb_new = 0.5 * (&g_bb + &gt_bb);
        f_ba = f_ba_new;
        g_ab = g_ab_new;
        f_aa = f_aa_new;
        g_bb = g_bb_new;
        if (iter + 1) % check_every == 0 || iter + 1 == max_iter {
            let mut max_diff = 0.0f32;
            for i in 0..n {
                max_diff = max_diff.max((f_ba[i] - ft_ba[i]).abs());
                max_diff = max_diff.max((f_aa[i] - ft_aa[i]).abs());
            }
            for j in 0..n {
                max_diff = max_diff.max((g_ab[j] - gt_ab[j]).abs());
                max_diff = max_diff.max((g_bb[j] - gt_bb[j]).abs());
            }
            if max_diff < tol {
                break;
            }
        }
    }
    // Closed-form divergence from self minus cross potentials, f64-accumulated.
    let scale = rho + eps / 2.0;
    let mut term_a: f64 = 0.0;
    for i in 0..n {
        let ai = a[i] as f64;
        if ai == 0.0 {
            continue;
        }
        let x = (-f_aa[i] / rho).exp() - (-f_ba[i] / rho).exp();
        term_a += ai * (scale as f64) * (x as f64);
    }
    let mut term_b: f64 = 0.0;
    for j in 0..n {
        let bj = b[j] as f64;
        if bj == 0.0 {
            continue;
        }
        let x = (-g_bb[j] / rho).exp() - (-g_ab[j] / rho).exp();
        term_b += bj * (scale as f64) * (x as f64);
    }
    // Quadratic correction for unequal total masses.
    let mass_corr = 0.5 * eps * (a.sum() - b.sum()).powi(2);
    Ok((term_a + term_b) as f32 + mass_corr)
}
/// Unbalanced Sinkhorn divergence for measures on different supports.
///
/// Same damped, iterate-averaged fixed-point scheme as
/// [`unbalanced_sinkhorn_divergence_same_support`], but with separate cost
/// matrices: `cost_ab` (m x n cross term), `cost_aa` (m x m self term) and
/// `cost_bb` (n x n self term). The divergence is assembled from the final
/// potentials scaled by `rho + eps/2`, plus the mass-difference correction
/// `eps/2 * (Σa - Σb)²`. Iteration stops early once the fixed-point residual
/// of all four potentials drops below `tol` (checked every 10 iterations);
/// non-convergence is not an error here — the last iterates are used.
///
/// # Errors
/// Shape, parameter, and domain validation failures only.
pub fn unbalanced_sinkhorn_divergence_general(
    a: &Array1<f32>,
    b: &Array1<f32>,
    cost_ab: &Array2<f32>,
    cost_aa: &Array2<f32>,
    cost_bb: &Array2<f32>,
    reg: f32,
    rho: f32,
    max_iter: usize,
    tol: f32,
) -> Result<f32> {
    let m = a.len();
    let n = b.len();
    if cost_ab.nrows() != m || cost_ab.ncols() != n {
        return Err(Error::CostShapeMismatch(
            m,
            n,
            cost_ab.nrows(),
            cost_ab.ncols(),
        ));
    }
    if cost_aa.nrows() != m || cost_aa.ncols() != m {
        return Err(Error::CostShapeMismatch(
            m,
            m,
            cost_aa.nrows(),
            cost_aa.ncols(),
        ));
    }
    if cost_bb.nrows() != n || cost_bb.ncols() != n {
        return Err(Error::CostShapeMismatch(
            n,
            n,
            cost_bb.nrows(),
            cost_bb.ncols(),
        ));
    }
    if reg <= 0.0 || !reg.is_finite() {
        return Err(Error::InvalidRegularization(reg));
    }
    if rho <= 0.0 || !rho.is_finite() {
        return Err(Error::InvalidMassPenalty(rho));
    }
    if a.iter().any(|&x| x < 0.0) || b.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("unbalanced OT requires nonnegative masses"));
    }
    if a.sum() <= 0.0 || b.sum() <= 0.0 {
        return Err(Error::Domain("unbalanced OT requires positive total mass"));
    }
    // Clamped log-weights: -1e5 stands in for log(0) to keep arithmetic finite.
    fn log_weights(w: &Array1<f32>) -> Array1<f32> {
        w.mapv(|x| if x <= 0.0 { -100000.0 } else { x.ln() })
    }
    // Row-side softmin: out_i = -eps * LSE_j(h_j - C_ij/eps).
    fn softmin_rows(eps: f32, cost: &Array2<f32>, h: &Array1<f32>) -> Array1<f32> {
        let m = cost.nrows();
        let n = cost.ncols();
        debug_assert_eq!(h.len(), n);
        let mut out = Array1::zeros(m);
        for i in 0..m {
            let lse = logsumexp_by(n, |j| h[j] - cost[[i, j]] / eps);
            out[i] = -eps * lse;
        }
        out
    }
    // Column-side softmin: out_j = -eps * LSE_i(h_i - C_ij/eps).
    fn softmin_cols(eps: f32, cost: &Array2<f32>, h: &Array1<f32>) -> Array1<f32> {
        let m = cost.nrows();
        let n = cost.ncols();
        debug_assert_eq!(h.len(), m);
        let mut out = Array1::zeros(n);
        for j in 0..n {
            let lse = logsumexp_by(m, |i| h[i] - cost[[i, j]] / eps);
            out[j] = -eps * lse;
        }
        out
    }
    let eps = reg;
    let damping = rho / (rho + eps);
    let a_log = log_weights(a);
    let b_log = log_weights(b);
    let mut g_ab = damping * softmin_cols(eps, cost_ab, &a_log);
    let mut f_ba = damping * softmin_rows(eps, cost_ab, &b_log);
    let mut f_aa = damping * softmin_rows(eps, cost_aa, &a_log);
    let mut g_bb = damping * softmin_cols(eps, cost_bb, &b_log);
    let check_every = 10usize.max(1);
    for iter in 0..max_iter {
        let h_b = &b_log + &(g_ab.mapv(|x| x / eps));
        let h_a = &a_log + &(f_ba.mapv(|x| x / eps));
        let ft_ba = damping * softmin_rows(eps, cost_ab, &h_b);
        let gt_ab = damping * softmin_cols(eps, cost_ab, &h_a);
        let h_aa = &a_log + &(f_aa.mapv(|x| x / eps));
        let ft_aa = damping * softmin_rows(eps, cost_aa, &h_aa);
        let h_bb = &b_log + &(g_bb.mapv(|x| x / eps));
        let gt_bb = damping * softmin_cols(eps, cost_bb, &h_bb);
        // Averaged (damped) fixed-point step stabilizes the symmetric updates.
        let f_ba_new = 0.5 * (&f_ba + &ft_ba);
        let g_ab_new = 0.5 * (&g_ab + &gt_ab);
        let f_aa_new = 0.5 * (&f_aa + &ft_aa);
        let g_bb_new = 0.5 * (&g_bb + &gt_bb);
        f_ba = f_ba_new;
        g_ab = g_ab_new;
        f_aa = f_aa_new;
        g_bb = g_bb_new;
        if (iter + 1) % check_every == 0 || iter + 1 == max_iter {
            let mut max_diff = 0.0f32;
            for i in 0..m {
                max_diff = max_diff.max((f_ba[i] - ft_ba[i]).abs());
                max_diff = max_diff.max((f_aa[i] - ft_aa[i]).abs());
            }
            for j in 0..n {
                max_diff = max_diff.max((g_ab[j] - gt_ab[j]).abs());
                max_diff = max_diff.max((g_bb[j] - gt_bb[j]).abs());
            }
            if max_diff < tol {
                break;
            }
        }
    }
    // Closed-form divergence from self minus cross potentials, f64-accumulated.
    let scale = rho + eps / 2.0;
    let mut term_a: f64 = 0.0;
    for i in 0..m {
        let ai = a[i] as f64;
        if ai == 0.0 {
            continue;
        }
        let x = (-f_aa[i] / rho).exp() - (-f_ba[i] / rho).exp();
        term_a += ai * (scale as f64) * (x as f64);
    }
    let mut term_b: f64 = 0.0;
    for j in 0..n {
        let bj = b[j] as f64;
        if bj == 0.0 {
            continue;
        }
        let x = (-g_bb[j] / rho).exp() - (-g_ab[j] / rho).exp();
        term_b += bj * (scale as f64) * (x as f64);
    }
    // Quadratic correction for unequal total masses.
    let mass_corr = 0.5 * eps * (a.sum() - b.sum()).powi(2);
    Ok((term_a + term_b) as f32 + mass_corr)
}
/// A rank-factored transport coupling `P = Q · diag(g) · Rᵀ`, stored as
/// flat row-major factors. Produced by [`sinkhorn_low_rank`].
#[derive(Debug, Clone)]
pub struct LowRankCoupling {
    /// Left factor Q, row-major `n x rank`.
    pub q: Vec<f32>,
    /// Middle diagonal factor g, length `rank`.
    pub g: Vec<f32>,
    /// Right factor R, row-major `m x rank` (so `P = Q diag(g) Rᵀ`).
    pub r: Vec<f32>,
    /// Transport cost `<C, P>` of the factored plan.
    pub cost: f32,
    /// Number of scaling sweeps actually performed.
    pub iterations: usize,
    n: usize,
    m: usize,
    rank: usize,
}

impl LowRankCoupling {
    /// Materializes the dense `n x m` plan, row-major.
    pub fn to_dense(&self) -> Vec<f32> {
        let mut dense = Vec::with_capacity(self.n * self.m);
        for i in 0..self.n {
            let qi = &self.q[i * self.rank..(i + 1) * self.rank];
            for j in 0..self.m {
                let rj = &self.r[j * self.rank..(j + 1) * self.rank];
                let cell: f32 = qi
                    .iter()
                    .zip(self.g.iter())
                    .zip(rj.iter())
                    .map(|((&qv, &gv), &rv)| qv * gv * rv)
                    .sum();
                dense.push(cell);
            }
        }
        dense
    }

    /// Matrix-vector product `P v` in O((n + m) * rank) without
    /// materializing P.
    ///
    /// # Panics
    /// Panics if `v.len() != m`.
    pub fn apply(&self, v: &[f32]) -> Vec<f32> {
        assert_eq!(v.len(), self.m, "v must have length m={}", self.m);
        // t = diag(g) * (Rᵀ v)
        let mut t = vec![0.0f32; self.rank];
        for (j, &vj) in v.iter().enumerate() {
            let rj = &self.r[j * self.rank..(j + 1) * self.rank];
            for (tk, &rv) in t.iter_mut().zip(rj.iter()) {
                *tk += rv * vj;
            }
        }
        for (tk, &gk) in t.iter_mut().zip(self.g.iter()) {
            *tk *= gk;
        }
        // result = Q t
        (0..self.n)
            .map(|i| {
                self.q[i * self.rank..(i + 1) * self.rank]
                    .iter()
                    .zip(t.iter())
                    .map(|(&qv, &tv)| qv * tv)
                    .sum()
            })
            .collect()
    }

    /// Row marginals `P · 1` of the factored plan.
    pub fn row_marginals(&self) -> Vec<f32> {
        self.apply(&vec![1.0f32; self.m])
    }

    /// Column marginals `Pᵀ · 1` of the factored plan.
    pub fn col_marginals(&self) -> Vec<f32> {
        // s = diag(g) * (Qᵀ 1)
        let mut s = vec![0.0f32; self.rank];
        for row in self.q.chunks(self.rank) {
            for (sk, &qv) in s.iter_mut().zip(row.iter()) {
                *sk += qv;
            }
        }
        for (sk, &gk) in s.iter_mut().zip(self.g.iter()) {
            *sk *= gk;
        }
        // result = R s
        self.r
            .chunks(self.rank)
            .map(|rj| rj.iter().zip(s.iter()).map(|(&rv, &sv)| rv * sv).sum())
            .collect()
    }
}
/// Low-rank Sinkhorn: approximates the coupling as `P = Q · diag(g) · Rᵀ`.
///
/// `Q` (n x rank) and `R` (m x rank) are seeded from `rank` anchor kernel
/// columns/rows — chosen randomly with a fixed ChaCha8 seed, so results are
/// deterministic — then alternately rescaled so the row/column marginals of
/// the factored plan match the normalized `a`/`b`. The middle factor `g`
/// stays fixed at all-ones; only Q and R are updated.
///
/// Convergence is declared when the largest mass-weighted rescaling
/// correction falls below `tol`; if that never happens, the result after
/// `max_iter` sweeps is returned with `iterations == max_iter` (this
/// function does not error on non-convergence).
///
/// # Errors
/// - [`Error::CostShapeMismatch`] if `cost.len() != n * m`.
/// - [`Error::InvalidRegularization`] / [`Error::InvalidRank`] for bad
///   parameters.
/// - [`Error::Domain`] for negative masses or zero total mass.
pub fn sinkhorn_low_rank(
    a: &[f32],
    b: &[f32],
    cost: &[f32],
    reg: f32,
    rank: usize,
    max_iter: usize,
    tol: f32,
) -> Result<LowRankCoupling> {
    let n = a.len();
    let m = b.len();
    if cost.len() != n * m {
        return Err(Error::CostShapeMismatch(n, m, n, cost.len() / n.max(1)));
    }
    if reg <= 0.0 || !reg.is_finite() {
        return Err(Error::InvalidRegularization(reg));
    }
    if rank < 1 || rank > n.min(m) {
        return Err(Error::InvalidRank(rank, n, m));
    }
    if a.iter().any(|&x| x < 0.0) || b.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("sinkhorn requires nonnegative masses"));
    }
    let a_sum: f32 = a.iter().sum();
    let b_sum: f32 = b.iter().sum();
    if a_sum <= 0.0 || b_sum <= 0.0 {
        return Err(Error::Domain("sinkhorn requires positive total mass"));
    }
    let a_norm: Vec<f32> = a.iter().map(|&x| x / (a_sum + EPSILON)).collect();
    let b_norm: Vec<f32> = b.iter().map(|&x| x / (b_sum + EPSILON)).collect();
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, Uniform};
    // Fixed seed -> reproducible anchor selection and factorization.
    let mut rng = ChaCha8Rng::seed_from_u64(0xCAFE);
    let log_k: Vec<f32> = cost.iter().map(|&c| -c / reg).collect();
    // Anchor columns/rows for the initial factors: all of them when rank
    // covers the axis, otherwise a random sample (duplicates possible).
    let col_indices: Vec<usize> = if rank >= m {
        (0..rank.min(m)).collect()
    } else {
        let unif_m = Uniform::new(0usize, m).unwrap();
        (0..rank).map(|_| unif_m.sample(&mut rng)).collect()
    };
    let row_indices: Vec<usize> = if rank >= n {
        (0..rank.min(n)).collect()
    } else {
        let unif_n = Uniform::new(0usize, n).unwrap();
        (0..rank).map(|_| unif_n.sample(&mut rng)).collect()
    };
    // Q seeded from kernel columns weighted by the source marginal; the
    // EPSILON floors keep every entry strictly positive.
    let mut q = vec![0.0f32; n * rank];
    for k in 0..rank {
        let jk = col_indices[k];
        for i in 0..n {
            q[i * rank + k] = log_k[i * m + jk].exp().max(EPSILON) * a_norm[i].max(EPSILON);
        }
    }
    // R seeded from kernel rows weighted by the target marginal.
    let mut r = vec![0.0f32; m * rank];
    for k in 0..rank {
        let ik = row_indices[k];
        for j in 0..m {
            r[j * rank + k] = log_k[ik * m + j].exp().max(EPSILON) * b_norm[j].max(EPSILON);
        }
    }
    // Middle factor stays all-ones throughout (never rescaled).
    let g = vec![1.0f32; rank];
    let mut iterations = max_iter;
    for iter in 0..max_iter {
        // v_k = g_k * sum_j R_jk: column sums of R, used for row marginals.
        let mut v = vec![0.0f32; rank];
        for k in 0..rank {
            let mut s = 0.0f32;
            for j in 0..m {
                s += r[j * rank + k];
            }
            v[k] = g[k] * s;
        }
        // Rescale each row of Q so the factored row marginal matches a_norm.
        let mut max_row_err = 0.0f32;
        for i in 0..n {
            let mut row_sum = 0.0f32;
            for k in 0..rank {
                row_sum += q[i * rank + k] * v[k];
            }
            if a_norm[i] > 0.0 && row_sum > EPSILON {
                let scale = a_norm[i] / row_sum;
                max_row_err = max_row_err.max((1.0 - scale).abs() * a_norm[i]);
                for k in 0..rank {
                    q[i * rank + k] *= scale;
                }
            }
        }
        // u_k = g_k * sum_i Q_ik: column sums of Q, used for column marginals.
        let mut u = vec![0.0f32; rank];
        for k in 0..rank {
            let mut s = 0.0f32;
            for i in 0..n {
                s += q[i * rank + k];
            }
            u[k] = g[k] * s;
        }
        // Rescale each row of R so the factored column marginal matches b_norm.
        let mut max_col_err = 0.0f32;
        for j in 0..m {
            let mut col_sum = 0.0f32;
            for k in 0..rank {
                col_sum += r[j * rank + k] * u[k];
            }
            if b_norm[j] > 0.0 && col_sum > EPSILON {
                let scale = b_norm[j] / col_sum;
                max_col_err = max_col_err.max((1.0 - scale).abs() * b_norm[j]);
                for k in 0..rank {
                    r[j * rank + k] *= scale;
                }
            }
        }
        let max_err = max_row_err.max(max_col_err);
        if max_err < tol {
            iterations = iter + 1;
            break;
        }
    }
    // <C, P> evaluated factor-by-factor in f64 (never materializing P):
    // sum_k g_k * sum_i Q_ik * (sum_j C_ij R_jk), skipping negligible Q entries.
    let mut transport_cost = 0.0f32;
    for k in 0..rank {
        let mut s = 0.0f64;
        for i in 0..n {
            let qik = q[i * rank + k] as f64;
            if qik < 1e-12 {
                continue;
            }
            let mut cr = 0.0f64;
            for j in 0..m {
                cr += cost[i * m + j] as f64 * r[j * rank + k] as f64;
            }
            s += qik * cr;
        }
        transport_cost += (g[k] as f64 * s) as f32;
    }
    Ok(LowRankCoupling {
        q,
        g,
        r,
        cost: transport_cost,
        iterations,
        n,
        m,
        rank,
    })
}
/// Empirical p-Wasserstein distance between two equal-size 1-D samples:
/// sort both, pair order statistics, and average the p-th power gaps
/// (taking the 1/p root for p > 1; p == 1 within 1e-7 skips the powers).
///
/// # Panics
/// Panics if the samples have different lengths or `p < 1.0`.
pub fn wasserstein_1d_samples(a: &[f32], b: &[f32], p: f32) -> f32 {
    assert_eq!(a.len(), b.len(), "samples must have same length");
    assert!(p >= 1.0, "p must be >= 1.0, got {}", p);
    if a.is_empty() {
        return 0.0;
    }
    let sorted = |s: &[f32]| -> Vec<f32> {
        let mut v = s.to_vec();
        v.sort_by(f32::total_cmp);
        v
    };
    let sa = sorted(a);
    let sb = sorted(b);
    let count = sa.len() as f32;
    if (p - 1.0).abs() < 1e-7 {
        sa.iter().zip(&sb).map(|(x, y)| (x - y).abs()).sum::<f32>() / count
    } else {
        let total: f32 = sa.iter().zip(&sb).map(|(x, y)| (x - y).abs().powf(p)).sum();
        (total / count).powf(1.0 / p)
    }
}
/// Projects every row of `points` onto `direction` and returns the
/// projections sorted ascending (total order, so NaNs cannot panic the sort).
///
/// `n_points` is expected to equal `points.nrows()`; it is only read on the
/// `simd` path, where each row's dot product goes through `innr::dense::dot`
/// (rows must be contiguous — `as_slice().unwrap()` panics otherwise).
/// Without the feature a single ndarray matrix-vector product is used.
fn project_and_sort(
    points: &Array2<f32>,
    direction: &Array1<f32>,
    #[allow(unused_variables)] n_points: usize,
) -> Vec<f32> {
    #[cfg(feature = "simd")]
    let mut proj = {
        let mut v = Vec::with_capacity(n_points);
        for i in 0..n_points {
            v.push(innr::dense::dot(
                points.row(i).as_slice().unwrap(),
                direction.as_slice().unwrap(),
            ));
        }
        v
    };
    #[cfg(not(feature = "simd"))]
    let mut proj = points.dot(direction).to_vec();
    proj.sort_by(|a, b| a.total_cmp(b));
    proj
}
/// Draws a random direction in R^d by normalizing a standard-normal vector
/// (the usual recipe for a uniform point on the unit sphere).
///
/// The norm is floored at EPSILON before dividing, so a (vanishingly
/// unlikely) near-zero draw yields a vector that is not exactly unit length
/// instead of producing infinities.
fn random_unit_direction(d: usize, rng: &mut impl rand::Rng) -> Array1<f32> {
    use rand_distr::{Distribution, StandardNormal};
    let mut direction: Array1<f32> = Array1::zeros(d);
    for i in 0..d {
        direction[i] = StandardNormal.sample(rng);
    }
    // Norm via innr when the `simd` feature is on; plain dot product otherwise.
    #[cfg(feature = "simd")]
    let norm = innr::dense::norm(direction.as_slice().unwrap());
    #[cfg(not(feature = "simd"))]
    let norm = direction.dot(&direction).sqrt();
    direction /= norm.max(EPSILON);
    direction
}
/// p-Wasserstein distance between two already-sorted 1-D projections,
/// truncated to the shorter length; 0.0 when either is empty. p == 1
/// (within 1e-7) skips the power/root round trip.
fn w_p_sorted(proj_x: &[f32], proj_y: &[f32], p: f32) -> f32 {
    let len = proj_x.len().min(proj_y.len());
    if len == 0 {
        return 0.0;
    }
    let pairs = proj_x[..len].iter().zip(&proj_y[..len]);
    if (p - 1.0).abs() < 1e-7 {
        pairs.map(|(x, y)| (x - y).abs()).sum::<f32>() / len as f32
    } else {
        let total: f32 = pairs.map(|(x, y)| (x - y).abs().powf(p)).sum();
        (total / len as f32).powf(1.0 / p)
    }
}
/// Sliced Wasserstein-p distance: averages the p-th powers of 1-D Wasserstein
/// distances over `n_projections` random directions, then takes the 1/p root.
///
/// Deterministic for a fixed `seed` (ChaCha8). Returns 0.0 when either point
/// set is empty or `n_projections == 0`.
///
/// # Panics
/// Panics if the point dimensions differ or `p < 1.0`.
pub fn sliced_wasserstein(
    x: &Array2<f32>,
    y: &Array2<f32>,
    n_projections: usize,
    seed: u64,
    p: f32,
) -> f32 {
    let d = x.ncols();
    assert_eq!(y.ncols(), d, "point dimensions must match");
    assert!(p >= 1.0, "p must be >= 1.0, got {}", p);
    let m = x.nrows();
    let n = y.nrows();
    if m == 0 || n == 0 || n_projections == 0 {
        return 0.0;
    }
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut total = 0.0f64;
    for _ in 0..n_projections {
        let direction = random_unit_direction(d, &mut rng);
        let proj_x = project_and_sort(x, &direction, m);
        let proj_y = project_and_sort(y, &direction, n);
        let wp = w_p_sorted(&proj_x, &proj_y, p);
        // BUG FIX: raise to the exact exponent p. The previous
        // `powi(p as i32)` truncated fractional p (e.g. p = 1.5 became
        // power 1), which was inconsistent with the 1/p root below.
        total += (wp as f64).powf(p as f64);
    }
    (total / n_projections as f64).powf(1.0 / p as f64) as f32
}
/// Max-sliced Wasserstein-p: the largest 1-D Wasserstein distance found over
/// `max_iter` random projection directions (pure random search, no gradient
/// ascent). Deterministic for a fixed `seed`; 0.0 for empty inputs or
/// `max_iter == 0`.
///
/// # Panics
/// Panics if the point dimensions differ or `p < 1.0`.
pub fn max_sliced_wasserstein(
    x: &Array2<f32>,
    y: &Array2<f32>,
    max_iter: usize,
    seed: u64,
    p: f32,
) -> f32 {
    let d = x.ncols();
    assert_eq!(y.ncols(), d, "point dimensions must match");
    assert!(p >= 1.0, "p must be >= 1.0, got {}", p);
    let m = x.nrows();
    let n = y.nrows();
    if m == 0 || n == 0 || max_iter == 0 {
        return 0.0;
    }
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..max_iter).fold(0.0f32, |best, _| {
        let direction = random_unit_direction(d, &mut rng);
        let proj_x = project_and_sort(x, &direction, m);
        let proj_y = project_and_sort(y, &direction, n);
        best.max(w_p_sorted(&proj_x, &proj_y, p))
    })
}
/// Hierarchical (divide-and-conquer) entropic OT on flat row-major inputs:
/// recursively partitions source and target indices, solves a coarse
/// coupling between clusters, and refines within matched cluster pairs.
/// Returns `(transport_cost, dense n*m coupling)`.
///
/// # Errors
/// Shape, parameter, and domain validation failures, plus anything the
/// recursive solver returns.
pub fn sinkhorn_hierarchical(
    a: &[f32],
    b: &[f32],
    cost: &[f32],
    n: usize,
    m: usize,
    reg: f32,
    branching: usize,
    max_depth: usize,
    max_iter: usize,
    tol: f32,
) -> Result<(f32, Vec<f32>)> {
    // Validation order is part of the observable contract (which error wins
    // when several inputs are invalid), so it is preserved exactly.
    if cost.len() != n * m {
        return Err(Error::CostShapeMismatch(n, m, n, cost.len() / n.max(1)));
    }
    if reg <= 0.0 || !reg.is_finite() {
        return Err(Error::InvalidRegularization(reg));
    }
    if a.len() != n || b.len() != m {
        return Err(Error::CostShapeMismatch(n, m, a.len(), b.len()));
    }
    if a.iter().any(|&x| x < 0.0) || b.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("sinkhorn requires nonnegative masses"));
    }
    let total_a: f32 = a.iter().sum();
    let total_b: f32 = b.iter().sum();
    if total_a <= 0.0 || total_b <= 0.0 {
        return Err(Error::Domain("sinkhorn requires positive total mass"));
    }
    if branching < 2 || branching > n.min(m) {
        return Err(Error::InvalidBranching(branching, n, m));
    }
    let a_norm: Vec<f32> = a.iter().map(|&x| x / total_a).collect();
    let b_norm: Vec<f32> = b.iter().map(|&x| x / total_b).collect();
    let src_all: Vec<usize> = (0..n).collect();
    let tgt_all: Vec<usize> = (0..m).collect();
    let mut coupling = vec![0.0f32; n * m];
    hierarchical_recurse(
        &a_norm, &b_norm, cost, n, m, &src_all, &tgt_all, reg, branching, max_depth, max_iter,
        tol, 1.0, &mut coupling,
    )?;
    let transport_cost: f32 = coupling
        .iter()
        .zip(cost.iter())
        .map(|(&p, &c)| p * c)
        .sum();
    Ok((transport_cost, coupling))
}
/// Splits `indices` into at most `k` contiguous groups after sorting them by
/// `cost_row` ascending (stable sort; incomparable values treated as equal).
/// Group sizes differ by at most one; empty groups are dropped.
fn partition_by_cost_projection(
    cost_row: impl Fn(usize) -> f32,
    indices: &[usize],
    k: usize,
) -> Vec<Vec<usize>> {
    let mut ranked: Vec<(usize, f32)> = indices.iter().map(|&i| (i, cost_row(i))).collect();
    ranked.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(std::cmp::Ordering::Equal));
    // The first `rem` groups absorb one extra element each.
    let quot = ranked.len() / k;
    let rem = ranked.len() % k;
    let mut groups = Vec::with_capacity(k);
    let mut cursor = 0;
    for g in 0..k {
        let size = quot + usize::from(g < rem);
        if size > 0 {
            groups.push(
                ranked[cursor..cursor + size]
                    .iter()
                    .map(|&(i, _)| i)
                    .collect(),
            );
        }
        cursor += size;
    }
    groups
}
/// Mean pairwise cost between two index groups over a row-major cost matrix
/// with `m_cols` columns; 0.0 if either group is empty.
fn group_cost(src_group: &[usize], tgt_group: &[usize], cost: &[f32], m_cols: usize) -> f32 {
    if src_group.is_empty() || tgt_group.is_empty() {
        return 0.0;
    }
    let sum: f32 = src_group
        .iter()
        .flat_map(|&i| tgt_group.iter().map(move |&j| cost[i * m_cols + j]))
        .sum();
    sum / (src_group.len() * tgt_group.len()) as f32
}
/// Solves one leaf subproblem of the hierarchical scheme: balanced entropic
/// OT between the masses restricted to `src_idx` x `tgt_idx`, scattering the
/// mass-scaled local plan into the global row-major `coupling` buffer
/// (entries are accumulated with `+=`).
///
/// If the convergent log-domain solver fails, falls back to a
/// fixed-iteration `sinkhorn_log` solve instead of propagating the error.
/// Silently returns (writes nothing) if either restricted marginal carries
/// no mass.
fn solve_local_subproblem(
    a: &[f32],
    b: &[f32],
    cost: &[f32],
    full_m: usize,
    src_idx: &[usize],
    tgt_idx: &[usize],
    reg: f32,
    max_iter: usize,
    tol: f32,
    mass: f32,
    coupling: &mut [f32],
) {
    let local_n = src_idx.len();
    let local_m = tgt_idx.len();
    // Restrict and renormalize the marginals to the local index sets.
    let local_a: Vec<f32> = src_idx.iter().map(|&i| a[i]).collect();
    let local_b: Vec<f32> = tgt_idx.iter().map(|&j| b[j]).collect();
    let a_sum: f32 = local_a.iter().sum();
    let b_sum: f32 = local_b.iter().sum();
    if a_sum <= 0.0 || b_sum <= 0.0 {
        return;
    }
    let a_local: Array1<f32> = Array1::from_vec(local_a.iter().map(|&x| x / a_sum).collect());
    let b_local: Array1<f32> = Array1::from_vec(local_b.iter().map(|&x| x / b_sum).collect());
    // Gather the local cost submatrix from the flat global matrix.
    let mut local_cost = Array2::zeros((local_n, local_m));
    for (li, &gi) in src_idx.iter().enumerate() {
        for (lj, &gj) in tgt_idx.iter().enumerate() {
            local_cost[[li, lj]] = cost[gi * full_m + gj];
        }
    }
    let (plan, _, _) =
        sinkhorn_log_with_convergence(&a_local, &b_local, &local_cost, reg, max_iter, tol)
            .unwrap_or_else(|_| {
                // Best-effort fallback: fixed-iteration solve, no convergence
                // guarantee, rather than failing the whole hierarchy.
                let (plan, dist) = sinkhorn_log(&a_local, &b_local, &local_cost, reg, max_iter);
                (plan, dist, max_iter)
            });
    // Scatter the local plan back, scaled by this subproblem's mass share.
    for (li, &gi) in src_idx.iter().enumerate() {
        for (lj, &gj) in tgt_idx.iter().enumerate() {
            coupling[gi * full_m + gj] += plan[[li, lj]] * mass;
        }
    }
}
/// Recursive core of the hierarchical Sinkhorn solver.
///
/// Splits the current source/target index sets into up to `branching`
/// clusters (by mean-cost projection), solves a coarse OT problem between
/// cluster masses, then recurses into every cluster pair that receives
/// non-negligible coarse mass. At `depth == 0`, or when either side is
/// already no larger than `branching`, the subproblem is solved densely by
/// `solve_local_subproblem`.
///
/// `a`/`b` are the full global marginals, `cost` the full row-major
/// `full_n x full_m` cost matrix, and `coupling` the global accumulator of
/// the same shape. `mass` is this subproblem's share of the total mass; the
/// local plan is scaled by it before being added to `coupling`.
// NOTE(review): `full_n` is only threaded through the recursion unchanged;
// it is never read here.
fn hierarchical_recurse(
    a: &[f32],
    b: &[f32],
    cost: &[f32],
    full_n: usize,
    full_m: usize,
    src_idx: &[usize],
    tgt_idx: &[usize],
    reg: f32,
    branching: usize,
    depth: usize,
    max_iter: usize,
    tol: f32,
    mass: f32,
    coupling: &mut [f32],
) -> Result<()> {
    let local_n = src_idx.len();
    let local_m = tgt_idx.len();
    // Base case: recursion budget exhausted, or too small to split further.
    if depth == 0 || local_n <= branching || local_m <= branching {
        solve_local_subproblem(
            a, b, cost, full_m, src_idx, tgt_idx, reg, max_iter, tol, mass, coupling,
        );
        return Ok(());
    }
    // Cluster sources by their mean cost to the current target set ...
    let src_groups = partition_by_cost_projection(
        |i| {
            let mut s = 0.0f32;
            for &j in tgt_idx {
                s += cost[i * full_m + j];
            }
            s / local_m as f32
        },
        src_idx,
        branching.min(local_n),
    );
    // ... and targets by their mean cost to the current source set.
    let tgt_groups = partition_by_cost_projection(
        |j| {
            let mut s = 0.0f32;
            for &i in src_idx {
                s += cost[i * full_m + j];
            }
            s / local_n as f32
        },
        tgt_idx,
        branching.min(local_m),
    );
    let k_src = src_groups.len();
    let k_tgt = tgt_groups.len();
    // Aggregate marginal mass per cluster.
    let coarse_a: Vec<f32> = src_groups
        .iter()
        .map(|g| g.iter().map(|&i| a[i]).sum::<f32>())
        .collect();
    let coarse_b: Vec<f32> = tgt_groups
        .iter()
        .map(|g| g.iter().map(|&j| b[j]).sum::<f32>())
        .collect();
    let coarse_a_sum: f32 = coarse_a.iter().sum();
    let coarse_b_sum: f32 = coarse_b.iter().sum();
    // No mass on either side: nothing to transport at this level.
    if coarse_a_sum <= 0.0 || coarse_b_sum <= 0.0 {
        return Ok(());
    }
    let coarse_a_norm: Array1<f32> =
        Array1::from_vec(coarse_a.iter().map(|&x| x / coarse_a_sum).collect());
    let coarse_b_norm: Array1<f32> =
        Array1::from_vec(coarse_b.iter().map(|&x| x / coarse_b_sum).collect());
    // Coarse cost = mean pairwise cost between cluster members.
    let mut coarse_cost = Array2::zeros((k_src, k_tgt));
    for (si, sg) in src_groups.iter().enumerate() {
        for (ti, tg) in tgt_groups.iter().enumerate() {
            coarse_cost[[si, ti]] = group_cost(sg, tg, cost, full_m);
        }
    }
    // Solve the coarse problem; fall back to the fixed-iteration solver when
    // the convergence-checked variant reports failure.
    let (coarse_plan, _, _) = sinkhorn_log_with_convergence(
        &coarse_a_norm,
        &coarse_b_norm,
        &coarse_cost,
        reg,
        max_iter,
        tol,
    )
    .unwrap_or_else(|_| {
        let (plan, dist) =
            sinkhorn_log(&coarse_a_norm, &coarse_b_norm, &coarse_cost, reg, max_iter);
        (plan, dist, max_iter)
    });
    // Skip cluster pairs whose coarse mass is negligible.
    let coarse_threshold = 1e-8;
    for (si, sg) in src_groups.iter().enumerate() {
        for (ti, tg) in tgt_groups.iter().enumerate() {
            let w = coarse_plan[[si, ti]];
            if w < coarse_threshold {
                continue;
            }
            // Recurse with the mass budget scaled by the coarse plan entry.
            hierarchical_recurse(
                a,
                b,
                cost,
                full_n,
                full_m,
                sg,
                tg,
                reg,
                branching,
                depth - 1,
                max_iter,
                tol,
                mass * w,
                coupling,
            )?;
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use proptest::prelude::*;

    // --- wasserstein_1d: W1 between aligned histograms (CDF differences) ---
    #[test]
    fn test_wasserstein_1d_same() {
        let a = [0.25, 0.25, 0.25, 0.25];
        let w = wasserstein_1d(&a, &a);
        assert!(w < 1e-7, "same distribution should have 0 distance");
    }
    #[test]
    fn test_wasserstein_1d_shift() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 0.0, 0.0, 1.0];
        let w = wasserstein_1d(&a, &b);
        assert!(
            (w - 3.0).abs() < 0.01,
            "point mass shift of 3 should have distance ~3"
        );
    }

    // --- entropic Sinkhorn: plan mass, marginals, and cost sanity ---
    #[test]
    fn test_sinkhorn_basic() {
        let a = array![0.5, 0.5];
        let b = array![0.5, 0.5];
        let cost = array![[0.0, 1.0], [1.0, 0.0]];
        let (plan, distance) = sinkhorn(&a, &b, &cost, 0.1, 100);
        let plan_sum: f32 = plan.iter().sum();
        assert!((plan_sum - 1.0).abs() < 0.01, "plan should sum to 1");
        let row_sums: Vec<f32> = (0..2).map(|i| plan.row(i).sum()).collect();
        assert!((row_sums[0] - 0.5).abs() < 0.1);
        assert!((row_sums[1] - 0.5).abs() < 0.1);
        assert!((0.0..1.0).contains(&distance));
    }
    #[test]
    fn test_sinkhorn_identical() {
        let a = array![0.5, 0.5];
        let b = array![0.5, 0.5];
        let cost = array![[0.0, 1.0], [1.0, 0.0]];
        let (_, distance) = sinkhorn(&a, &b, &cost, 0.1, 100);
        assert!(
            distance < 0.5,
            "identical distributions should have low OT cost"
        );
    }

    // --- cost-matrix constructors ---
    #[test]
    fn test_euclidean_cost_matrix() {
        let x = array![[0.0, 0.0], [1.0, 0.0]];
        let y = array![[0.0, 0.0], [0.0, 1.0]];
        let cost = euclidean_cost_matrix(&x, &y);
        assert!((cost[[0, 0]] - 0.0).abs() < 1e-7);
        assert!((cost[[0, 1]] - 1.0).abs() < 1e-7);
        assert!((cost[[1, 0]] - 1.0).abs() < 1e-7);
        assert!((cost[[1, 1]] - 2.0_f32.sqrt()).abs() < 1e-7);
    }
    #[test]
    fn test_sq_euclidean_cost_matrix() {
        let x = array![[0.0, 0.0], [1.0, 0.0]];
        let y = array![[0.0, 0.0], [0.0, 1.0]];
        let cost = sq_euclidean_cost_matrix(&x, &y);
        assert!((cost[[0, 0]] - 0.0).abs() < 1e-7);
        assert!((cost[[0, 1]] - 1.0).abs() < 1e-7);
        assert!((cost[[1, 0]] - 1.0).abs() < 1e-7);
        assert!((cost[[1, 1]] - 2.0).abs() < 1e-7);
    }
    #[test]
    fn test_sq_vs_euclidean_cost_matrix() {
        // Squared-Euclidean entries must be the elementwise square of the
        // Euclidean entries.
        let x = array![[0.0, 0.0], [1.0, 0.5], [0.3, -0.7]];
        let y = array![[0.5, 0.5], [-1.0, 0.0], [0.0, 1.0]];
        let l2 = euclidean_cost_matrix(&x, &y);
        let sq = sq_euclidean_cost_matrix(&x, &y);
        for i in 0..3 {
            for j in 0..3 {
                let expected = l2[[i, j]] * l2[[i, j]];
                assert!(
                    (sq[[i, j]] - expected).abs() < 1e-5,
                    "mismatch at ({i},{j}): sq={} expected={}",
                    sq[[i, j]],
                    expected
                );
            }
        }
    }

    // --- sample-based 1-D Wasserstein (empirical quantile matching) ---
    #[test]
    fn w1d_samples_shift() {
        let w = wasserstein_1d_samples(&[0.0, 1.0], &[1.0, 2.0], 1.0);
        assert!((w - 1.0).abs() < 1e-6, "W1 shift: {}", w);
    }
    #[test]
    fn w1d_samples_w2_shift() {
        let w = wasserstein_1d_samples(&[0.0, 1.0], &[1.0, 2.0], 2.0);
        assert!((w - 1.0).abs() < 1e-6, "W2 shift: {}", w);
    }
    #[test]
    fn w1d_samples_self_distance() {
        let a = [0.5, 1.3, -0.2, 4.0];
        let w = wasserstein_1d_samples(&a, &a, 1.0);
        assert!(w < 1e-7, "self-distance: {}", w);
    }
    #[test]
    fn w1d_samples_unsorted_input() {
        // Inputs need not be pre-sorted.
        let a = [3.0, 1.0, 2.0];
        let b = [6.0, 4.0, 5.0];
        let w = wasserstein_1d_samples(&a, &b, 1.0);
        assert!((w - 3.0).abs() < 1e-6, "unsorted W1: {}", w);
    }

    // --- sliced Wasserstein (random 1-D projections) ---
    #[test]
    fn test_sliced_wasserstein_self_distance() {
        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 3.0]];
        let sw = sliced_wasserstein(&x, &x, 50, 42, 1.0);
        assert!(sw < 1e-5, "self-distance should be ~0: {}", sw);
    }
    #[test]
    fn test_sliced_wasserstein_symmetric() {
        let x = array![[0.0, 0.0], [1.0, 1.0]];
        let y = array![[5.0, 5.0], [6.0, 6.0]];
        let sw_xy = sliced_wasserstein(&x, &y, 100, 42, 1.0);
        let sw_yx = sliced_wasserstein(&y, &x, 100, 42, 1.0);
        assert!(
            (sw_xy - sw_yx).abs() < 1e-5,
            "symmetry: {} vs {}",
            sw_xy,
            sw_yx
        );
    }
    #[test]
    fn test_sliced_wasserstein_nonneg() {
        let x = array![[0.0, 0.0], [1.0, 1.0]];
        let y = array![[10.0, 10.0], [11.0, 11.0]];
        let sw = sliced_wasserstein(&x, &y, 50, 42, 1.0);
        assert!(sw >= 0.0, "non-negative: {}", sw);
    }
    #[test]
    fn test_sliced_wasserstein_separation() {
        // Farther point clouds should score a larger distance.
        let x = array![[0.0, 0.0], [1.0, 1.0]];
        let y_close = array![[0.1, 0.1], [1.1, 1.1]];
        let y_far = array![[10.0, 10.0], [11.0, 11.0]];
        let sw_close = sliced_wasserstein(&x, &y_close, 100, 42, 1.0);
        let sw_far = sliced_wasserstein(&x, &y_far, 100, 42, 1.0);
        assert!(
            sw_far > sw_close,
            "distant > close: {} vs {}",
            sw_far,
            sw_close
        );
    }
    #[test]
    fn test_sliced_wasserstein_w2() {
        let x = array![[0.0, 0.0], [1.0, 1.0]];
        let y = array![[10.0, 10.0], [11.0, 11.0]];
        let sw1 = sliced_wasserstein(&x, &y, 100, 42, 1.0);
        let sw2 = sliced_wasserstein(&x, &y, 100, 42, 2.0);
        assert!(sw1 > 5.0, "SW1 should be large: {}", sw1);
        assert!(sw2 > 5.0, "SW2 should be large: {}", sw2);
    }

    // --- max-sliced Wasserstein (worst projection) ---
    #[test]
    fn test_max_sliced_self_distance() {
        let x = array![[0.0, 0.0], [1.0, 1.0]];
        let msw = max_sliced_wasserstein(&x, &x, 50, 42, 1.0);
        assert!(msw < 1e-5, "max-sliced self-distance: {}", msw);
    }
    #[test]
    fn test_max_sliced_ge_sliced() {
        // The maximum over projections dominates the average over the same
        // projections.
        let x = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let y = array![[5.0, 5.0], [6.0, 5.0], [5.0, 6.0]];
        let n_proj = 200;
        let sw = sliced_wasserstein(&x, &y, n_proj, 42, 1.0);
        let msw = max_sliced_wasserstein(&x, &y, n_proj, 42, 1.0);
        assert!(
            msw >= sw - 1e-5,
            "max-sliced ({}) should >= sliced ({})",
            msw,
            sw
        );
    }
    #[test]
    fn test_max_sliced_axis_aligned() {
        let x = array![[0.0, 0.0], [1.0, 0.0]];
        let y = array![[10.0, 0.0], [11.0, 0.0]];
        let msw = max_sliced_wasserstein(&x, &y, 200, 42, 1.0);
        assert!(msw > 9.0, "axis-aligned max-sliced: {}", msw);
    }

    // --- property tests: divergence non-negativity and logsumexp laws ---
    proptest! {
        #[test]
        fn prop_sinkhorn_divergence_non_negative(
            (a, b) in (2usize..8).prop_flat_map(|n| {
                (
                    prop::collection::vec(0.0f32..1.0, n),
                    prop::collection::vec(0.0f32..1.0, n),
                )
            }),
        ) {
            let n = a.len();
            let mut a_dist = Array1::from_vec(a);
            let mut b_dist = Array1::from_vec(b);
            let sa = a_dist.sum();
            let sb = b_dist.sum();
            // Normalize; if a draw is all-zero, fall back to a point mass.
            if sa > 0.0 { a_dist /= sa; } else { a_dist[0] = 1.0; }
            if sb > 0.0 { b_dist /= sb; } else { b_dist[0] = 1.0; }
            let mut cost = Array2::zeros((n, n));
            for i in 0..n {
                for j in 0..n {
                    cost[[i, j]] = (i as f32 - j as f32).abs();
                }
            }
            let div = sinkhorn_divergence_same_support(&a_dist, &b_dist, &cost, 0.1, 2000, 1e-2).unwrap();
            prop_assert!(div >= -1e-6);
        }
        #[test]
        fn logsumexp_translation_invariant(
            xs in prop::collection::vec(-50.0f32..50.0, 1..64),
            shift in -10.0f32..10.0
        ) {
            // LSE(x + c) == LSE(x) + c
            let l1 = logsumexp_by(xs.len(), |i| xs[i]);
            let l2 = logsumexp_by(xs.len(), |i| xs[i] + shift);
            prop_assert!((l2 - (l1 + shift)).abs() < 1e-5);
        }
        #[test]
        fn logsumexp_matches_naive_on_safe_range(
            xs in prop::collection::vec(-20.0f32..20.0, 1..64),
        ) {
            // In a range where exp() cannot overflow, the stable version must
            // agree with the direct formula.
            let naive = xs.iter().map(|&x| x.exp()).sum::<f32>().ln();
            let stable = logsumexp_by(xs.len(), |i| xs[i]);
            prop_assert!((stable - naive).abs() < 1e-5);
        }
        #[test]
        fn logsumexp_bounds_by_max(
            xs in prop::collection::vec(-50.0f32..50.0, 1..64),
        ) {
            // max(x) <= LSE(x) <= max(x) + ln(n)
            let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let lse = logsumexp_by(xs.len(), |i| xs[i]);
            prop_assert!(lse >= max - 1e-5);
            prop_assert!(lse <= max + (xs.len() as f32).ln() + 1e-5);
        }
    }

    // --- property tests: metric axioms for sample-based and sliced W ---
    proptest! {
        #[test]
        fn prop_w1d_samples_nonneg(
            a in prop::collection::vec(-100.0f32..100.0, 2..32),
            b in prop::collection::vec(-100.0f32..100.0, 2..32),
        ) {
            let n = a.len().min(b.len());
            let w = wasserstein_1d_samples(&a[..n], &b[..n], 1.0);
            prop_assert!(w >= -1e-7, "non-negative: {}", w);
        }
        #[test]
        fn prop_w1d_samples_symmetric(
            a in prop::collection::vec(-100.0f32..100.0, 2..32),
            b in prop::collection::vec(-100.0f32..100.0, 2..32),
        ) {
            let n = a.len().min(b.len());
            let ab = wasserstein_1d_samples(&a[..n], &b[..n], 1.0);
            let ba = wasserstein_1d_samples(&b[..n], &a[..n], 1.0);
            prop_assert!((ab - ba).abs() < 1e-5, "symmetric: {} vs {}", ab, ba);
        }
        #[test]
        fn prop_w1d_samples_self_zero(
            a in prop::collection::vec(-100.0f32..100.0, 2..32),
        ) {
            let w = wasserstein_1d_samples(&a, &a, 1.0);
            prop_assert!(w < 1e-6, "self-distance: {}", w);
        }
        #[test]
        fn prop_sliced_wasserstein_nonneg(
            seed in 0u64..1000,
        ) {
            // Properties must hold for any projection seed.
            let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 0.5]];
            let y = array![[3.0, 3.0], [4.0, 4.0], [5.0, 3.5]];
            let sw = sliced_wasserstein(&x, &y, 20, seed, 1.0);
            prop_assert!(sw >= -1e-7, "non-negative: {}", sw);
        }
        #[test]
        fn prop_sliced_wasserstein_self_zero(
            seed in 0u64..1000,
        ) {
            let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 0.5]];
            let sw = sliced_wasserstein(&x, &x, 20, seed, 1.0);
            prop_assert!(sw < 1e-5, "self-distance: {}", sw);
        }
        #[test]
        fn prop_sliced_wasserstein_symmetric(
            seed in 0u64..1000,
        ) {
            let x = array![[0.0, 0.0], [1.0, 1.0]];
            let y = array![[3.0, 3.0], [4.0, 4.0]];
            let sw_xy = sliced_wasserstein(&x, &y, 20, seed, 1.0);
            let sw_yx = sliced_wasserstein(&y, &x, 20, seed, 1.0);
            prop_assert!((sw_xy - sw_yx).abs() < 1e-5, "symmetric: {} vs {}", sw_xy, sw_yx);
        }
    }

    // --- exact earth mover's distance ---
    #[test]
    fn earth_mover_distance_identical_is_zero() {
        let a = array![0.5, 0.5];
        let cost = array![[0.0, 1.0], [1.0, 0.0]];
        let emd = earth_mover_distance(&a, &a, &cost);
        assert!(emd < 0.05, "identical distributions: emd={}", emd);
    }
    #[test]
    fn earth_mover_distance_shifted_distributions() {
        let a = array![0.7, 0.3];
        let b = array![0.3, 0.7];
        let cost = array![[0.0, 1.0], [1.0, 0.0]];
        let emd = earth_mover_distance(&a, &b, &cost);
        assert!(
            emd > 0.2,
            "shifted distributions should have positive cost: emd={}",
            emd
        );
        assert!(
            emd < 0.6,
            "cost bounded by total mass * max cost: emd={}",
            emd
        );
    }
    #[test]
    fn earth_mover_distance_point_mass_shift() {
        let a = array![1.0, 0.0];
        let b = array![0.0, 1.0];
        let cost = array![[0.0, 3.0], [3.0, 0.0]];
        let emd = earth_mover_distance(&a, &b, &cost);
        assert!(
            (emd - 3.0).abs() < 0.2,
            "point mass shift of 3: emd={}",
            emd
        );
    }

    // --- log-domain Sinkhorn: marginal feasibility and non-negativity ---
    #[test]
    fn sinkhorn_log_plan_has_valid_marginals() {
        let a = array![0.3, 0.5, 0.2];
        let b = array![0.4, 0.4, 0.2];
        let cost = array![[0.0, 1.0, 2.0], [1.0, 0.0, 1.5], [2.0, 1.5, 0.0]];
        let (plan, _, _) = sinkhorn_log_with_convergence(&a, &b, &cost, 0.1, 2000, 1e-4).unwrap();
        for i in 0..3 {
            let row_sum: f32 = plan.row(i).sum();
            assert!(
                (row_sum - a[i]).abs() < 0.02,
                "row {} sum={}, expected={}",
                i,
                row_sum,
                a[i]
            );
        }
        for j in 0..3 {
            let col_sum: f32 = plan.column(j).sum();
            assert!(
                (col_sum - b[j]).abs() < 0.02,
                "col {} sum={}, expected={}",
                j,
                col_sum,
                b[j]
            );
        }
    }
    #[test]
    fn sinkhorn_log_plan_is_nonneg() {
        let a = array![0.5, 0.5];
        let b = array![0.3, 0.7];
        let cost = array![[0.0, 2.0], [2.0, 0.0]];
        let (plan, _, _) = sinkhorn_log_with_convergence(&a, &b, &cost, 0.05, 500, 1e-6).unwrap();
        assert!(
            plan.iter().all(|&p| p >= -1e-7),
            "plan has negative entries"
        );
    }
    #[test]
    fn wasserstein_1d_triangle_inequality() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];
        let c = [0.0, 0.0, 0.0, 1.0];
        let ab = wasserstein_1d(&a, &b);
        let bc = wasserstein_1d(&b, &c);
        let ac = wasserstein_1d(&a, &c);
        assert!(
            ac <= ab + bc + 1e-6,
            "triangle inequality: {ac} > {ab} + {bc}"
        );
    }
    #[test]
    fn sinkhorn_divergence_zero_on_diagonal_same_support() {
        let a = array![0.2, 0.3, 0.5];
        let cost = array![[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]];
        let div = sinkhorn_divergence_same_support(&a, &a, &cost, 0.1, 500, 1e-4).unwrap();
        assert!(div.abs() < 1e-5, "div={}", div);
    }
    #[test]
    fn sinkhorn_divergence_is_symmetric_same_support() {
        let a = array![0.2, 0.3, 0.5];
        let b = array![0.5, 0.4, 0.1];
        let cost = array![[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]];
        let ab = sinkhorn_divergence_same_support(&a, &b, &cost, 0.1, 500, 1e-4).unwrap();
        let ba = sinkhorn_divergence_same_support(&b, &a, &cost, 0.1, 500, 1e-4).unwrap();
        assert!((ab - ba).abs() < 1e-5, "ab={} ba={}", ab, ba);
    }

    // Helper: row-major squared-distance cost matrix on integer grids.
    fn flat_cost(n: usize, m: usize) -> Vec<f32> {
        let mut c = vec![0.0f32; n * m];
        for i in 0..n {
            for j in 0..m {
                let d = i as f32 - j as f32;
                c[i * m + j] = d * d;
            }
        }
        c
    }

    // --- low-rank Sinkhorn factorization ---
    #[test]
    fn low_rank_transport_cost_nonneg() {
        let n = 5;
        let a = vec![1.0 / n as f32; n];
        let b = vec![1.0 / n as f32; n];
        let cost = flat_cost(n, n);
        let lr = sinkhorn_low_rank(&a, &b, &cost, 0.1, 3, 200, 1e-5).unwrap();
        assert!(
            lr.cost >= -1e-6,
            "transport cost should be non-negative, got {}",
            lr.cost
        );
    }
    #[test]
    fn low_rank_row_marginals_match() {
        let n = 6;
        let m = 4;
        let a: Vec<f32> = {
            let raw = vec![1.0, 2.0, 3.0, 2.0, 1.0, 1.0];
            let s: f32 = raw.iter().sum();
            raw.iter().map(|&x| x / s).collect()
        };
        let b: Vec<f32> = {
            let raw = vec![2.0, 1.0, 1.0, 2.0];
            let s: f32 = raw.iter().sum();
            raw.iter().map(|&x| x / s).collect()
        };
        let cost = flat_cost(n, m);
        let lr = sinkhorn_low_rank(&a, &b, &cost, 0.5, 3, 500, 1e-6).unwrap();
        let row_marg = lr.row_marginals();
        assert_eq!(row_marg.len(), n);
        for i in 0..n {
            assert!(
                (row_marg[i] - a[i]).abs() < 0.05,
                "row marginal[{}]: got {}, expected {}",
                i,
                row_marg[i],
                a[i]
            );
        }
    }
    #[test]
    fn low_rank_col_marginals_match() {
        let n = 6;
        let m = 4;
        let a: Vec<f32> = {
            let raw = vec![1.0, 2.0, 3.0, 2.0, 1.0, 1.0];
            let s: f32 = raw.iter().sum();
            raw.iter().map(|&x| x / s).collect()
        };
        let b: Vec<f32> = {
            let raw = vec![2.0, 1.0, 1.0, 2.0];
            let s: f32 = raw.iter().sum();
            raw.iter().map(|&x| x / s).collect()
        };
        let cost = flat_cost(n, m);
        let lr = sinkhorn_low_rank(&a, &b, &cost, 0.5, 3, 500, 1e-6).unwrap();
        let col_marg = lr.col_marginals();
        assert_eq!(col_marg.len(), m);
        for j in 0..m {
            assert!(
                (col_marg[j] - b[j]).abs() < 0.05,
                "col marginal[{}]: got {}, expected {}",
                j,
                col_marg[j],
                b[j]
            );
        }
    }
    #[test]
    fn low_rank_full_rank_approximates_sinkhorn() {
        // With rank == n the factorization should roughly reproduce the
        // dense Sinkhorn cost (same order of magnitude).
        let n = 4;
        let a_arr = array![0.25, 0.25, 0.25, 0.25];
        let b_arr = array![0.1, 0.3, 0.4, 0.2];
        let cost_arr = array![
            [0.0, 1.0, 4.0, 9.0],
            [1.0, 0.0, 1.0, 4.0],
            [4.0, 1.0, 0.0, 1.0],
            [9.0, 4.0, 1.0, 0.0]
        ];
        let (_, full_cost) = sinkhorn_log(&a_arr, &b_arr, &cost_arr, 0.5, 200);
        let a_flat: Vec<f32> = a_arr.to_vec();
        let b_flat: Vec<f32> = b_arr.to_vec();
        let cost_flat: Vec<f32> = cost_arr.iter().copied().collect();
        let lr = sinkhorn_low_rank(&a_flat, &b_flat, &cost_flat, 0.5, n, 500, 1e-6).unwrap();
        let ratio = lr.cost / full_cost;
        assert!(
            (0.5..2.0).contains(&ratio),
            "full-rank low-rank cost ({}) should approximate sinkhorn_log cost ({}), ratio={}",
            lr.cost,
            full_cost,
            ratio
        );
    }
    #[test]
    fn low_rank_memory_scales_linearly() {
        // Factor storage is (n + m + 1) * rank, strictly below dense n * m.
        let n = 100;
        let m = 80;
        let rank = 5;
        let a = vec![1.0 / n as f32; n];
        let b = vec![1.0 / m as f32; m];
        let cost = flat_cost(n, m);
        let lr = sinkhorn_low_rank(&a, &b, &cost, 1.0, rank, 100, 1e-4).unwrap();
        let factor_size = lr.q.len() + lr.r.len() + lr.g.len();
        let dense_size = n * m;
        assert_eq!(lr.q.len(), n * rank);
        assert_eq!(lr.r.len(), m * rank);
        assert_eq!(lr.g.len(), rank);
        assert!(
            factor_size < dense_size,
            "factor storage ({}) should be less than dense ({})",
            factor_size,
            dense_size
        );
    }
    #[test]
    fn low_rank_apply_matches_dense() {
        // Factored matrix-vector product must agree with materializing the
        // dense plan and multiplying.
        let n = 5;
        let m = 4;
        let a: Vec<f32> = {
            let raw = vec![1.0, 2.0, 1.0, 2.0, 1.0];
            let s: f32 = raw.iter().sum();
            raw.iter().map(|&x| x / s).collect()
        };
        let b: Vec<f32> = {
            let raw = vec![1.0, 1.0, 1.0, 1.0];
            let s: f32 = raw.iter().sum();
            raw.iter().map(|&x| x / s).collect()
        };
        let cost = flat_cost(n, m);
        let lr = sinkhorn_low_rank(&a, &b, &cost, 0.5, 3, 300, 1e-5).unwrap();
        let v = vec![1.0, 0.0, 0.0, 0.0];
        let result_apply = lr.apply(&v);
        let dense = lr.to_dense();
        let mut result_dense = vec![0.0f32; n];
        for i in 0..n {
            for j in 0..m {
                result_dense[i] += dense[i * m + j] * v[j];
            }
        }
        for i in 0..n {
            assert!(
                (result_apply[i] - result_dense[i]).abs() < 1e-5,
                "apply[{}]={} vs dense[{}]={}",
                i,
                result_apply[i],
                i,
                result_dense[i]
            );
        }
    }
    #[test]
    fn low_rank_invalid_inputs() {
        // rank 0, rank > min(n, m), negative reg, and bad cost shape must
        // all be rejected.
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = vec![0.0, 1.0, 1.0, 0.0];
        assert!(sinkhorn_low_rank(&a, &b, &cost, 0.1, 0, 100, 1e-5).is_err());
        assert!(sinkhorn_low_rank(&a, &b, &cost, 0.1, 3, 100, 1e-5).is_err());
        assert!(sinkhorn_low_rank(&a, &b, &cost, -0.1, 1, 100, 1e-5).is_err());
        assert!(sinkhorn_low_rank(&a, &b, &[0.0, 1.0], 0.1, 1, 100, 1e-5).is_err());
    }
    #[test]
    fn low_rank_to_dense_nonneg() {
        let n = 5;
        let a = vec![0.2; n];
        let b = vec![0.2; n];
        let cost = flat_cost(n, n);
        let lr = sinkhorn_low_rank(&a, &b, &cost, 0.5, 3, 200, 1e-5).unwrap();
        let dense = lr.to_dense();
        for (idx, &val) in dense.iter().enumerate() {
            assert!(val >= -1e-7, "dense[{}] = {} is negative", idx, val);
        }
    }

    // --- hierarchical Sinkhorn ---
    #[test]
    fn hierarchical_cost_nonnegative() {
        let n = 8;
        let m = 8;
        let a = vec![1.0 / n as f32; n];
        let b = vec![1.0 / m as f32; m];
        let cost = flat_cost(n, m);
        let (tc, coupling) =
            sinkhorn_hierarchical(&a, &b, &cost, n, m, 0.5, 4, 1, 200, 1e-5).unwrap();
        assert!(tc >= 0.0, "transport cost should be nonneg, got {}", tc);
        for (idx, &val) in coupling.iter().enumerate() {
            assert!(val >= -1e-7, "coupling[{}] = {} is negative", idx, val);
        }
    }
    #[test]
    fn hierarchical_marginals_approx() {
        // Hierarchical solves are approximate, hence the loose 0.15 bound.
        let n = 6;
        let m = 6;
        let a = vec![1.0 / n as f32; n];
        let b = vec![1.0 / m as f32; m];
        let cost = flat_cost(n, m);
        let (_, coupling) =
            sinkhorn_hierarchical(&a, &b, &cost, n, m, 0.5, 3, 1, 200, 1e-5).unwrap();
        for i in 0..n {
            let row_sum: f32 = (0..m).map(|j| coupling[i * m + j]).sum();
            assert!(
                (row_sum - a[i]).abs() < 0.15,
                "row {} sum = {}, expected ~{}",
                i,
                row_sum,
                a[i]
            );
        }
        for j in 0..m {
            let col_sum: f32 = (0..n).map(|i| coupling[i * m + j]).sum();
            assert!(
                (col_sum - b[j]).abs() < 0.15,
                "col {} sum = {}, expected ~{}",
                j,
                col_sum,
                b[j]
            );
        }
    }
    #[test]
    fn hierarchical_approx_quality() {
        // Two well-separated clusters: hierarchical cost should be within an
        // order of magnitude of the dense log-domain solution.
        let n = 16;
        let m = 16;
        let mut a_flat = vec![0.0f32; n];
        let mut b_flat = vec![0.0f32; m];
        for i in 0..4 {
            a_flat[i] = 0.25;
        }
        for i in 12..16 {
            b_flat[i] = 0.25;
        }
        let cost_flat = flat_cost(n, m);
        let a_arr = Array1::from_vec(a_flat.clone());
        let b_arr = Array1::from_vec(b_flat.clone());
        let cost_arr = Array2::from_shape_vec((n, m), cost_flat.clone()).unwrap();
        let (_, regular_cost) = sinkhorn_log(&a_arr, &b_arr, &cost_arr, 1.0, 200);
        let (hier_cost, _) =
            sinkhorn_hierarchical(&a_flat, &b_flat, &cost_flat, n, m, 1.0, 4, 1, 200, 1e-5)
                .unwrap();
        assert!(
            hier_cost > 0.0,
            "hierarchical cost should be positive, got {}",
            hier_cost
        );
        assert!(
            regular_cost > 10.0,
            "regular cost should be large for separated clusters, got {}",
            regular_cost
        );
        let ratio = hier_cost / regular_cost;
        assert!(
            (0.1..20.0).contains(&ratio),
            "hierarchical cost ({}) should be same order of magnitude as regular ({}), ratio={}",
            hier_cost,
            regular_cost,
            ratio
        );
    }
    #[test]
    fn hierarchical_deeper_recursion() {
        let n = 16;
        let m = 16;
        let a = vec![1.0 / n as f32; n];
        let b = vec![1.0 / m as f32; m];
        let cost = flat_cost(n, m);
        let (tc, coupling) =
            sinkhorn_hierarchical(&a, &b, &cost, n, m, 0.5, 4, 2, 200, 1e-5).unwrap();
        assert!(tc >= 0.0, "transport cost nonneg, got {}", tc);
        let total_mass: f32 = coupling.iter().sum();
        assert!(
            total_mass > 0.0,
            "total coupling mass should be positive, got {}",
            total_mass
        );
    }
    #[test]
    fn hierarchical_invalid_inputs() {
        // branching < 2, branching > min(n, m), negative reg, bad cost shape.
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let cost = vec![0.0, 1.0, 1.0, 0.0];
        assert!(sinkhorn_hierarchical(&a, &b, &cost, 2, 2, 0.1, 1, 1, 100, 1e-5).is_err());
        assert!(sinkhorn_hierarchical(&a, &b, &cost, 2, 2, 0.1, 3, 1, 100, 1e-5).is_err());
        assert!(sinkhorn_hierarchical(&a, &b, &cost, 2, 2, -0.1, 2, 1, 100, 1e-5).is_err());
        assert!(sinkhorn_hierarchical(&a, &b, &[0.0, 1.0], 2, 2, 0.1, 2, 1, 100, 1e-5).is_err());
    }
    #[test]
    fn hierarchical_timing_large() {
        // Smoke test on a larger instance; timing is printed, not asserted.
        let n = 64;
        let m = 64;
        let a = vec![1.0 / n as f32; n];
        let b = vec![1.0 / m as f32; m];
        let cost = flat_cost(n, m);
        let start = std::time::Instant::now();
        let (tc, _) = sinkhorn_hierarchical(&a, &b, &cost, n, m, 1.0, 8, 2, 100, 1e-4).unwrap();
        let hier_elapsed = start.elapsed();
        assert!(tc >= 0.0);
        eprintln!(
            "hierarchical 64x64 branching=8 depth=2: cost={:.4}, time={:?}",
            tc, hier_elapsed
        );
    }
}