# wass
Optimal transport primitives for geometry-aware distribution comparison. Implements the Sinkhorn algorithm for entropy-regularized OT, including unbalanced transport for robust partial matching.
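For reference, the standard entropy-regularized problem that Sinkhorn solves (notation as in the contract below: marginals $a$, $b$, cost $C$, regularization $\varepsilon$):

$$
\min_{P \in U(a, b)} \langle P, C \rangle - \varepsilon H(P),
\qquad
U(a, b) = \{ P \ge 0 : P \mathbf{1} = a,\ P^\top \mathbf{1} = b \},
$$

where $H(P) = -\sum_{ij} P_{ij} (\log P_{ij} - 1)$ is the entropy; larger $\varepsilon$ gives faster convergence but more diffuse transport plans.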
## Contract

- Invariants (must never change):
  - Numerics: All probability inputs are treated as `f32`.
  - Divergence: `sinkhorn_divergence_same_support` is the canonical debiased divergence; it guarantees non-negativity (modulo float noise) and symmetry.
  - Convergence: `*_with_convergence` methods return `Err` if the residual does not fall below `tol` within `max_iter`; they do not return a partial result silently.
- Support / Dependencies:
  - Input: Expects `ndarray::Array1` (masses) and `Array2` (costs).
  - Geometry: The cost matrix $C$ encodes the ground geometry.
  - Unbalanced: $\rho$ (rho) controls the penalty for mass destruction; $\rho \to \infty$ recovers balanced OT (a standard objective is sketched after this list).
- Exports:
  - `sinkhorn`: Standard dense OT.
  - `sinkhorn_divergence_same_support`: MMD-like geometric divergence.
  - `wasserstein_1d`: $O(n)$ exact solver for 1D.
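For the unbalanced case, a standard formulation (the crate's exact parameterization may differ) replaces the hard marginal constraints with KL penalties weighted by $\rho$:

$$
\min_{P \ge 0} \langle P, C \rangle - \varepsilon H(P)
+ \rho\, \mathrm{KL}(P \mathbf{1} \,\|\, a)
+ \rho\, \mathrm{KL}(P^\top \mathbf{1} \,\|\, b),
$$

so creating or destroying mass costs $\rho$ per unit (in KL), and $\rho \to \infty$ enforces the balanced constraints exactly.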
Dual-licensed under MIT or Apache-2.0.
```rust
// Argument values and exact call signatures below are illustrative
// reconstructions; consult the crate docs for the precise API.
use wass::{sinkhorn, sinkhorn_divergence_same_support, wasserstein_1d};
use ndarray::array;

// 1D Wasserstein (fast, closed-form)
let a = array![0.5, 0.5];
let b = array![0.25, 0.75];
let w1 = wasserstein_1d(&a, &b);

// General transport with Sinkhorn
let cost = array![[0.0, 1.0], [1.0, 0.0]];
let a = array![0.5, 0.5];
let b = array![0.25, 0.75];
let (plan, total_cost) = sinkhorn(&a, &b, &cost, 0.1, 100);

// Sinkhorn Divergence (debiased; requires same-support cost matrix)
let div = sinkhorn_divergence_same_support(&a, &b, &cost, 0.1, 100).unwrap();
```
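Per the contract above, the `*_with_convergence` variants surface non-convergence as an `Err` rather than returning a partial result. A hedged sketch of handling that (the exact signature and error type of `sinkhorn_with_convergence` are assumptions here, not checked against the crate):

```rust
// Assumed shape: sinkhorn_with_convergence(a, b, cost, epsilon, max_iter, tol)
// returning Result<(plan, cost), _> -- adjust to the crate's actual API.
match sinkhorn_with_convergence(&a, &b, &cost, 0.1, 500, 1e-6) {
    Ok((plan, total_cost)) => println!("converged: cost = {total_cost}"),
    Err(e) => eprintln!("residual did not reach tol within max_iter: {e:?}"),
}
```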
## Key Features
- Balanced OT: Standard Sinkhorn for probability distributions.
- Unbalanced OT: Robust transport for partial matches, outliers, and unnormalized measures (e.g. document alignment).
- Sparse OT: L2-regularized transport for interpretable, sparse alignments (via the `sparse` module).
- Log-domain stabilization: Numerically stable implementations for small epsilon / large costs (sketched below).
- Divergences: Proper debiased Sinkhorn divergences (positive, definite) for metric use.
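Log-domain stabilization avoids overflow/underflow in the Gibbs kernel $\exp(-C/\varepsilon)$ by carrying dual potentials and evaluating updates with log-sum-exp. A minimal sketch of the technique itself (not the crate's internal code; `logsumexp` and `log_sinkhorn_sweep` are local helpers written for this README):

```rust
use ndarray::{Array1, Array2};

/// Numerically stable log(sum(exp(v))).
fn logsumexp(v: &Array1<f32>) -> f32 {
    let m = v.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    if m == f32::NEG_INFINITY {
        return m; // empty or all -inf: exp-sum is 0
    }
    m + v.mapv(|x| (x - m).exp()).sum().ln()
}

/// One log-domain Sinkhorn sweep, updating dual potentials f, g in place.
/// Parameterization: P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / eps).
fn log_sinkhorn_sweep(
    f: &mut Array1<f32>,
    g: &mut Array1<f32>,
    log_a: &Array1<f32>,
    log_b: &Array1<f32>,
    cost: &Array2<f32>,
    eps: f32,
) {
    let (n, m) = cost.dim();
    for i in 0..n {
        // f_i = -eps * LSE_j( log b_j + (g_j - C_ij) / eps )
        let row = Array1::from_iter((0..m).map(|j| log_b[j] + (g[j] - cost[(i, j)]) / eps));
        f[i] = -eps * logsumexp(&row);
    }
    for j in 0..m {
        let col = Array1::from_iter((0..n).map(|i| log_a[i] + (f[i] - cost[(i, j)]) / eps));
        g[j] = -eps * logsumexp(&col);
    }
}
```

Because the kernel is never materialized in linear space, this stays finite even when $C_{ij}/\varepsilon$ is in the hundreds, which is exactly the regime where naive Sinkhorn underflows to zero.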
## Examples
Run these to see OT in action:
- Robust Document Alignment: Shows how unbalanced OT aligns core topics while ignoring outliers (headers/footers/typos).
- Mass Mismatch: Shows how divergence scales with the unbalanced penalty parameter.
- Balanced Divergence: Demonstrates the debiased divergence on same-support distributions.
## Functions
| Function | Use Case | Complexity |
|---|---|---|
| `wasserstein_1d` | 1D distributions | O(n) |
| `sinkhorn` | General transport (dense) | O(n^2 x iter) |
| `sinkhorn_with_convergence` | With early stopping | O(n^2 x iter) |
| `sinkhorn_divergence_same_support` | Debiased divergence (same support) | O(n^2 x iter) |
| `sinkhorn_divergence_general` | Debiased divergence (different supports) | O(mn x iter) |
| `unbalanced_sinkhorn_divergence_general` | Robust comparison (different supports) | O(mn x iter) |
| `sparse::solve_semidual_l2` | Sparse transport (L2) | O(n^2 x iter) |
| `sliced_wasserstein` | High-dim approx | O(n_proj x n log n) |
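For intuition on the O(n) 1D case: on a shared grid, $W_1$ is the integral of the CDF gap, $W_1 = \int |F_a(x) - F_b(x)|\,dx$. A minimal sketch of that formula under the assumption of a uniform grid with spacing `dx` (this illustrates the math, not necessarily `wasserstein_1d`'s exact interface):

```rust
use ndarray::Array1;

/// W1 between two histograms on the same uniform grid:
/// accumulate |running CDF difference| times the cell width.
fn w1_same_grid(a: &Array1<f32>, b: &Array1<f32>, dx: f32) -> f32 {
    let (mut cdf_gap, mut total) = (0.0f32, 0.0f32);
    for (&pa, &pb) in a.iter().zip(b.iter()) {
        cdf_gap += pa - pb;          // F_a(x_i) - F_b(x_i)
        total += cdf_gap.abs() * dx; // |F_a - F_b| over this cell
    }
    total
}
```

A single pass over sorted support suffices, which is why the 1D solver needs no Sinkhorn iterations at all.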
Note: `sinkhorn_divergence` is deprecated; it only computes a true divergence when the cost is square. Use the explicit `*_same_support` / `*_general` variants instead.
## Why Optimal Transport?
- No support issues: Unlike KL divergence, OT compares distributions with disjoint supports.
- Geometry-aware: Respects the underlying metric space (e.g. word embedding distance).
- Robustness: Unbalanced OT handles outliers and noise ("pizza" vs "sushi") without breaking the alignment of the signal ("AI" vs "ML").
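As a one-line worked example of the disjoint-support point: for point masses $\mu = \delta_0$ and $\nu = \delta_1$,

$$
\mathrm{KL}(\mu \,\|\, \nu) = +\infty, \qquad W_1(\mu, \nu) = |0 - 1| = 1,
$$

so OT returns a finite, geometry-respecting distance exactly where KL blows up.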