use crate::{Error, Result};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Lower bound on `|sum(plan)|` for a Sinkhorn transport plan to be trusted.
///
/// Below this threshold the plan is treated as degenerate (it carries essentially
/// no mass) and the Sinkhorn-based pairings replace it by the uniform coupling
/// `1 / n²` before rounding it to a matching.
const DEGENERATE_PLAN_TOL: f32 = 1.0e-8;
/// Squared Euclidean distance `Σ_k (a[k] − b[k])²` between two vectors.
///
/// Terms are accumulated in index order in `f32`, starting from `0.0`.
/// The lengths must match. This is checked with `debug_assert_eq!` in debug builds.
/// In release builds, indexing panics if `b` is shorter than `a`.
#[inline]
fn l2_squared(a: &ArrayView1<'_, f32>, b: &ArrayView1<'_, f32>) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().enumerate().fold(0.0f32, |acc, (k, &ak)| {
        let diff = ak - b[k];
        acc + diff * diff
    })
}
/// Plain (square-rooted) Euclidean distance, compiled only for tests.
///
/// The sqrt-based reference pairing uses it to check that the production code,
/// which compares squared distances, picks the same columns.
#[cfg(test)]
#[inline]
fn l2(a: &ArrayView1<'_, f32>, b: &ArrayView1<'_, f32>) -> f32 {
    let squared = l2_squared(a, b);
    squared.sqrt()
}
/// Builds the `n × n` squared-Euclidean cost matrix `C[i, j] = ‖x_i − y_j‖²`.
///
/// `x` and `y` must have the same number of rows `n` and the same number of columns.
///
/// # Errors
/// Returns `Error::Shape` if the row counts differ (checked first) or the
/// column counts differ.
fn sq_euclidean_cost_matrix_from_views(
    x: &ArrayView2<f32>,
    y: &ArrayView2<f32>,
) -> Result<Array2<f32>> {
    let n = x.nrows();
    if y.nrows() != n {
        return Err(Error::Shape("x and y must have same number of rows"));
    }
    if x.ncols() != y.ncols() {
        return Err(Error::Shape("x and y must have same dimension"));
    }
    Ok(Array2::from_shape_fn((n, n), |(i, j)| {
        l2_squared(&x.row(i), &y.row(j))
    }))
}
/// Greedy maximum-weight bipartite matching on a square, nonnegative weight matrix.
///
/// The function ranks all `n²` edges `(row, col)` by weight, heaviest first. It then
/// accepts each edge whose row and column are both still free. The sort is stable,
/// so equal weights are visited in row-major order. The result `perm` satisfies
/// `perm[row] = col` and is a permutation of `0..n`. An empty matrix yields an
/// empty vector.
///
/// # Errors
/// - `Error::Shape` if `w` is not square.
/// - `Error::Domain` if `w` contains NaN or a negative entry.
pub(crate) fn greedy_bipartite_match_from_weights(w: &ArrayView2<f32>) -> Result<Vec<usize>> {
    let size = w.nrows();
    if w.ncols() != size {
        return Err(Error::Shape("weight matrix must be square"));
    }
    if size == 0 {
        return Ok(Vec::new());
    }
    if w.iter().any(|v| v.is_nan()) {
        return Err(Error::Domain("weight matrix contains NaN"));
    }
    if w.iter().any(|&v| v < 0.0) {
        return Err(Error::Domain("weight matrix must be nonnegative"));
    }

    // Row-major edge list; the stable sort below relies on this order for ties.
    let mut ranked: Vec<(usize, usize, f32)> = Vec::with_capacity(size * size);
    ranked.extend(
        (0..size).flat_map(move |row| (0..size).map(move |col| (row, col, w[[row, col]]))),
    );
    ranked.sort_by(|lhs, rhs| rhs.2.total_cmp(&lhs.2));

    let mut row_done = vec![false; size];
    let mut col_done = vec![false; size];
    let mut perm = vec![usize::MAX; size];
    let mut left = size;
    for (row, col, _) in ranked {
        if !(row_done[row] || col_done[col]) {
            row_done[row] = true;
            col_done[col] = true;
            perm[row] = col;
            left -= 1;
            if left == 0 {
                break;
            }
        }
    }
    if perm.iter().any(|&col| col == usize::MAX) {
        return Err(Error::Domain("failed to construct a full matching"));
    }
    Ok(perm)
}
/// Sinkhorn-then-greedy pairing: solves entropic optimal transport between the
/// rows of `x` and `y`, then rounds the coupling to a one-to-one matching.
///
/// The cost is the squared Euclidean distance and both marginals are uniform
/// (`1 / n`). The plan comes from
/// `wass::sinkhorn_log_with_convergence(a, b, cost, reg, max_iter, tol)`.
/// If the plan's total mass is below `DEGENERATE_PLAN_TOL`, it is replaced by
/// the uniform coupling `1 / n²`. The plan is then rounded with
/// [`greedy_bipartite_match_from_weights`], so a successful result is a
/// permutation in which `perm[i]` is the row of `y` paired with row `i` of `x`.
/// Empty inputs return an empty vector before the solver parameters are checked.
///
/// # Errors
/// - `Error::Shape` if `x` and `y` differ in row count or dimension.
/// - `Error::Domain` if `reg`/`tol` is not positive and finite, `max_iter == 0`,
///   `x`/`y` contain NaN/±inf, or the Sinkhorn call returns any error.
/// - Any error from [`greedy_bipartite_match_from_weights`], e.g. a NaN plan.
pub fn minibatch_ot_greedy_pairing(
    x: &ArrayView2<f32>,
    y: &ArrayView2<f32>,
    reg: f32,
    max_iter: usize,
    tol: f32,
) -> Result<Vec<usize>> {
    let batch = x.nrows();
    if y.nrows() != batch {
        return Err(Error::Shape("x and y must have same number of rows"));
    }
    if x.ncols() != y.ncols() {
        return Err(Error::Shape("x and y must have same dimension"));
    }
    if batch == 0 {
        return Ok(Vec::new());
    }

    let is_positive_finite = |v: f32| v.is_finite() && v > 0.0;
    if !is_positive_finite(reg) {
        return Err(Error::Domain("reg must be positive and finite"));
    }
    if max_iter == 0 {
        return Err(Error::Domain("max_iter must be >= 1"));
    }
    if !is_positive_finite(tol) {
        return Err(Error::Domain("tol must be positive and finite"));
    }
    if !x.iter().chain(y.iter()).all(|v| v.is_finite()) {
        return Err(Error::Domain("x/y contain NaN/Inf"));
    }

    let cost = sq_euclidean_cost_matrix_from_views(x, y)?;
    let marginal = Array1::<f32>::from_elem(batch, 1.0 / batch as f32);
    let (raw_plan, _, _) =
        wass::sinkhorn_log_with_convergence(&marginal, &marginal, &cost, reg, max_iter, tol)
            .map_err(|_| Error::Domain("sinkhorn coupling did not converge"))?;

    // Near-zero total mass means the plan carries no usable information.
    let degenerate = raw_plan.sum().abs() < DEGENERATE_PLAN_TOL;
    let coupling = if degenerate {
        Array2::<f32>::from_elem((batch, batch), 1.0 / (batch * batch) as f32)
    } else {
        raw_plan
    };
    greedy_bipartite_match_from_weights(&coupling.view())
}
pub fn minibatch_ot_selective_pairing(
x: &ArrayView2<f32>,
y: &ArrayView2<f32>,
reg: f32,
max_iter: usize,
tol: f32,
keep_frac: f32,
) -> Result<Vec<usize>> {
let n = x.nrows();
if y.nrows() != n {
return Err(Error::Shape("x and y must have same number of rows"));
}
if x.ncols() != y.ncols() {
return Err(Error::Shape("x and y must have same dimension"));
}
if n == 0 {
return Ok(Vec::new());
}
if !reg.is_finite() || reg <= 0.0 {
return Err(Error::Domain("reg must be positive and finite"));
}
if max_iter == 0 {
return Err(Error::Domain("max_iter must be >= 1"));
}
if !tol.is_finite() || tol <= 0.0 {
return Err(Error::Domain("tol must be positive and finite"));
}
if !keep_frac.is_finite() || keep_frac <= 0.0 {
return Err(Error::Domain("keep_frac must be positive and finite"));
}
let keep_frac = keep_frac.min(1.0);
if x.iter().any(|&v| !v.is_finite()) || y.iter().any(|&v| !v.is_finite()) {
return Err(Error::Domain("x/y contain NaN/Inf"));
}
let cost = sq_euclidean_cost_matrix_from_views(x, y)?;
let a = Array1::<f32>::from_elem(n, 1.0 / n as f32);
let b = Array1::<f32>::from_elem(n, 1.0 / n as f32);
let (plan, _dist, _iters) =
wass::sinkhorn_log_with_convergence(&a, &b, &cost, reg, max_iter, tol)
.map_err(|_| Error::Domain("sinkhorn coupling did not converge"))?;
let plan = if plan.sum().abs() < DEGENERATE_PLAN_TOL {
Array2::<f32>::from_elem((n, n), 1.0 / (n * n) as f32)
} else {
plan
};
let mut row_exp_cost = vec![0.0f32; n];
let mut row_nn = vec![0usize; n];
for i in 0..n {
let mut e = 0.0f32;
let mut best_j = 0usize;
let mut best_c = f32::INFINITY;
for j in 0..n {
e += plan[[i, j]] * cost[[i, j]];
let c = cost[[i, j]];
if c < best_c {
best_c = c;
best_j = j;
}
}
row_exp_cost[i] = e;
row_nn[i] = best_j;
}
let keep = ((keep_frac * n as f32).round() as usize).clamp(1, n);
let mut rows: Vec<usize> = (0..n).collect();
rows.sort_by(|&i, &j| row_exp_cost[i].total_cmp(&row_exp_cost[j])); let selected = &rows[..keep];
let mut used_col = vec![false; n];
let mut perm = vec![usize::MAX; n];
for &i in selected {
let mut best_j = usize::MAX;
let mut best_w = -1.0f32;
for j in 0..n {
if used_col[j] {
continue;
}
let w = plan[[i, j]];
if w > best_w {
best_w = w;
best_j = j;
}
}
if best_j == usize::MAX {
return Err(Error::Domain("failed to assign columns for selected rows"));
}
used_col[best_j] = true;
perm[i] = best_j;
}
for i in 0..n {
if perm[i] != usize::MAX {
continue;
}
perm[i] = row_nn[i];
}
debug_assert!(perm.iter().all(|&j| j < n));
Ok(perm)
}
/// Greedy row-by-row nearest-neighbour matching under squared Euclidean distance.
///
/// Rows of `x` are processed in index order. Row `i` takes the closest column of
/// `y` that no earlier row has claimed, using the first such column on ties. The
/// result is always a permutation. Earlier rows choose first, so the matching
/// depends on row order and is not globally optimal.
///
/// # Errors
/// - `Error::Shape` if `x` and `y` differ in row count or dimension.
/// - `Error::Domain` if `x` or `y` contains NaN/±inf.
pub fn minibatch_rowwise_nearest_pairing(
    x: &ArrayView2<f32>,
    y: &ArrayView2<f32>,
) -> Result<Vec<usize>> {
    let n = x.nrows();
    if y.nrows() != n {
        return Err(Error::Shape("x and y must have same number of rows"));
    }
    if x.ncols() != y.ncols() {
        return Err(Error::Shape("x and y must have same dimension"));
    }
    if n == 0 {
        return Ok(Vec::new());
    }
    if x.iter().any(|&v| !v.is_finite()) || y.iter().any(|&v| !v.is_finite()) {
        return Err(Error::Domain("x/y contain NaN/Inf"));
    }
    let mut used = vec![false; n];
    let mut perm = vec![usize::MAX; n];
    for (i, perm_i) in perm.iter_mut().enumerate() {
        let xi = x.row(i);
        let mut best_j = usize::MAX;
        let mut best = f32::INFINITY;
        for (j, used_j) in used.iter().enumerate() {
            if *used_j {
                continue;
            }
            let c = l2_squared(&xi, &y.row(j));
            // Accept the first unused column unconditionally. Finite but huge inputs
            // can overflow the squared distance to +inf, which would never beat the
            // INFINITY sentinel and would leave the row unmatched.
            if best_j == usize::MAX || c < best {
                best = c;
                best_j = j;
            }
        }
        if best_j == usize::MAX {
            // Defensive: row i always sees n - i >= 1 unused columns.
            return Err(Error::Domain("failed to construct a full matching"));
        }
        used[best_j] = true;
        *perm_i = best_j;
    }
    Ok(perm)
}
pub fn minibatch_partial_rowwise_pairing(
x: &ArrayView2<f32>,
y: &ArrayView2<f32>,
keep_frac: f32,
) -> Result<Vec<usize>> {
let n = x.nrows();
if y.nrows() != n {
return Err(Error::Shape("x and y must have same number of rows"));
}
if x.ncols() != y.ncols() {
return Err(Error::Shape("x and y must have same dimension"));
}
if n == 0 {
return Ok(Vec::new());
}
if !keep_frac.is_finite() || keep_frac <= 0.0 {
return Err(Error::Domain("keep_frac must be positive and finite"));
}
let keep_frac = keep_frac.min(1.0);
if x.iter().any(|&v| !v.is_finite()) || y.iter().any(|&v| !v.is_finite()) {
return Err(Error::Domain("x/y contain NaN/Inf"));
}
let mut best: Vec<(usize, f32)> = Vec::with_capacity(n);
for i in 0..n {
let xi = x.row(i);
let mut best_j = 0usize;
let mut best_c = f32::INFINITY;
for j in 0..n {
let yj = y.row(j);
let c = l2_squared(&xi, &yj);
if c < best_c {
best_c = c;
best_j = j;
}
}
best.push((best_j, best_c));
}
let keep = ((keep_frac * n as f32).round() as usize).clamp(1, n);
let mut rows: Vec<usize> = (0..n).collect();
rows.sort_by(|&i, &j| best[i].1.total_cmp(&best[j].1)); let selected = &rows[..keep];
let mut used_col = vec![false; n];
let mut perm = vec![usize::MAX; n];
for &i in selected {
let xi = x.row(i);
let mut best_j = usize::MAX;
let mut best_c = f32::INFINITY;
for (j, used_j) in used_col.iter().enumerate() {
if *used_j {
continue;
}
let yj = y.row(j);
let c = l2_squared(&xi, &yj);
if c < best_c {
best_c = c;
best_j = j;
}
}
if best_j == usize::MAX {
return Err(Error::Domain("failed to assign columns for selected rows"));
}
used_col[best_j] = true;
perm[i] = best_j;
}
for i in 0..n {
if perm[i] != usize::MAX {
continue;
}
perm[i] = best[i].0;
}
debug_assert!(perm.iter().all(|&j| j < n));
Ok(perm)
}
pub fn minibatch_exp_greedy_pairing(
x: &ArrayView2<f32>,
y: &ArrayView2<f32>,
temp: f32,
) -> Result<Vec<usize>> {
let n = x.nrows();
if y.nrows() != n {
return Err(Error::Shape("x and y must have same number of rows"));
}
if x.ncols() != y.ncols() {
return Err(Error::Shape("x and y must have same dimension"));
}
if n == 0 {
return Ok(Vec::new());
}
if !temp.is_finite() || temp <= 0.0 {
return Err(Error::Domain("temp must be positive and finite"));
}
if x.iter().any(|&v| !v.is_finite()) || y.iter().any(|&v| !v.is_finite()) {
return Err(Error::Domain("x/y contain NaN/Inf"));
}
let d = x.ncols();
let mut edges: Vec<(usize, usize, f32)> = Vec::with_capacity(n * n);
for i in 0..n {
let xi = x.row(i);
for j in 0..n {
let yj = y.row(j);
let mut dist_sq = 0.0f32;
for k in 0..d {
let dd = xi[k] - yj[k];
dist_sq += dd * dd;
}
let dist = dist_sq.sqrt();
edges.push((i, j, -dist / temp));
}
}
edges.sort_by(|a, b| b.2.total_cmp(&a.2));
let mut matched_row = vec![false; n];
let mut matched_col = vec![false; n];
let mut perm = vec![usize::MAX; n];
let mut remaining = n;
for (i, j, _wij) in edges {
if matched_row[i] || matched_col[j] {
continue;
}
matched_row[i] = true;
matched_col[j] = true;
perm[i] = j;
remaining -= 1;
if remaining == 0 {
break;
}
}
if perm.contains(&usize::MAX) {
return Err(Error::Domain("failed to construct a full matching"));
}
Ok(perm)
}
pub fn apply_pairing(
pairing: &crate::sd_fm::RfmMinibatchPairing,
x0s: &ArrayView2<f32>,
ys: &ArrayView2<f32>,
cfg: &crate::sd_fm::RfmMinibatchOtConfig,
) -> Result<Vec<usize>> {
use crate::sd_fm::RfmMinibatchPairing;
match *pairing {
RfmMinibatchPairing::SinkhornGreedy => {
validate_sinkhorn_fields(cfg)?;
minibatch_ot_greedy_pairing(x0s, ys, cfg.reg, cfg.max_iter, cfg.tol)
}
RfmMinibatchPairing::SinkhornSelective { keep_frac } => {
validate_sinkhorn_fields(cfg)?;
if !keep_frac.is_finite() || keep_frac <= 0.0 {
return Err(Error::Domain(
"rfm_cfg.keep_frac must be positive and finite",
));
}
minibatch_ot_selective_pairing(x0s, ys, cfg.reg, cfg.max_iter, cfg.tol, keep_frac)
}
RfmMinibatchPairing::RowwiseNearest => minibatch_rowwise_nearest_pairing(x0s, ys),
RfmMinibatchPairing::ExpGreedy { temp } => {
if !temp.is_finite() || temp <= 0.0 {
return Err(Error::Domain("rfm_cfg.temp must be positive and finite"));
}
minibatch_exp_greedy_pairing(x0s, ys, temp)
}
RfmMinibatchPairing::PartialRowwise { keep_frac } => {
if !keep_frac.is_finite() || keep_frac <= 0.0 {
return Err(Error::Domain(
"rfm_cfg.keep_frac must be positive and finite",
));
}
minibatch_partial_rowwise_pairing(x0s, ys, keep_frac)
}
}
}
fn validate_sinkhorn_fields(cfg: &crate::sd_fm::RfmMinibatchOtConfig) -> Result<()> {
if !cfg.reg.is_finite() || cfg.reg <= 0.0 {
return Err(Error::Domain("rfm_cfg.reg must be positive and finite"));
}
if cfg.max_iter == 0 {
return Err(Error::Domain("rfm_cfg.max_iter must be >= 1"));
}
if !cfg.tol.is_finite() || cfg.tol <= 0.0 {
return Err(Error::Domain("rfm_cfg.tol must be positive and finite"));
}
Ok(())
}
#[cfg(test)]
mod tests {
// Tests for the minibatch pairing strategies in the parent module. Random
// inputs come from a seeded ChaCha8 RNG, so each proptest case is reproducible
// from its `seed`. Several tests compare against the external `wass` crate's
// cost-matrix and Sinkhorn routines.
use super::*;
use ndarray::array;
use proptest::prelude::*;
// A small diagonally dominant weight matrix must yield a valid permutation
// (every column assigned exactly once).
#[test]
fn greedy_matching_is_a_permutation() {
let w = array![[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.0, 0.1, 0.7]];
let p = greedy_bipartite_match_from_weights(&w.view()).unwrap();
assert_eq!(p.len(), 3);
let mut seen = [false; 3];
for &j in &p {
assert!(j < 3);
assert!(!seen[j]);
seen[j] = true;
}
}
// Running the Sinkhorn-greedy pairing twice on the same input must give the
// same result, and that result must be a permutation.
#[test]
fn minibatch_ot_pairing_is_deterministic_and_a_permutation() {
let x = array![[0.0f32, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let y = array![[0.0f32, 0.0], [1.0, 0.1], [0.1, 1.0], [1.0, 1.0]];
let p1 = minibatch_ot_greedy_pairing(&x.view(), &y.view(), 1.0, 5000, 2e-3).unwrap();
let p2 = minibatch_ot_greedy_pairing(&x.view(), &y.view(), 1.0, 5000, 2e-3).unwrap();
assert_eq!(p1, p2);
let mut seen = vec![false; p1.len()];
for &j in &p1 {
assert!(j < p1.len());
assert!(!seen[j]);
seen[j] = true;
}
}
/// Returns true iff `p` is a permutation of `0..p.len()`: every entry is in
/// range and no entry repeats.
fn is_permutation(p: &[usize]) -> bool {
let n = p.len();
let mut seen = vec![false; n];
for &j in p {
if j >= n || seen[j] {
return false;
}
seen[j] = true;
}
true
}
/// Returns true iff every entry of `p` is `< p.len()`. Repeats are allowed,
/// because the partial/selective pairings may reuse columns for unselected rows.
fn is_in_range(p: &[usize]) -> bool {
let n = p.len();
p.iter().all(|&j| j < n)
}
/// Reference for the row selection in `minibatch_partial_rowwise_pairing`.
/// Ranks rows by the squared distance to their nearest `y` row (ascending, stable
/// sort) and returns the first `keep` row indices. `keep` uses the same
/// round-then-clamp formula as the production code.
fn selected_rows_by_nn_cost_sq(
x: &ArrayView2<f32>,
y: &ArrayView2<f32>,
keep_frac: f32,
) -> Vec<usize> {
let n = x.nrows();
let keep = ((keep_frac.min(1.0) * n as f32).round() as usize).clamp(1, n);
let mut best_cost: Vec<(usize, f32)> = Vec::with_capacity(n);
for i in 0..n {
let xi = x.row(i);
let mut best = f32::INFINITY;
for j in 0..n {
let yj = y.row(j);
let c = l2_squared(&xi, &yj);
if c < best {
best = c;
}
}
best_cost.push((i, best));
}
best_cost.sort_by(|a, b| a.1.total_cmp(&b.1)); best_cost.into_iter().take(keep).map(|(i, _)| i).collect()
}
/// Copy of the greedy row-by-row nearest pairing that ranks columns by plain
/// Euclidean distance (`l2`, with sqrt) rather than squared distance. It is used
/// to check that dropping the sqrt does not change the pairing. It has no
/// `n == 0` early return; for empty input the loop simply does nothing.
fn rowwise_nearest_pairing_sqrt_reference(
x: &ArrayView2<f32>,
y: &ArrayView2<f32>,
) -> Result<Vec<usize>> {
let n = x.nrows();
if y.nrows() != n {
return Err(Error::Shape("x and y must have same number of rows"));
}
if x.ncols() != y.ncols() {
return Err(Error::Shape("x and y must have same dimension"));
}
if x.iter().any(|&v| !v.is_finite()) || y.iter().any(|&v| !v.is_finite()) {
return Err(Error::Domain("x/y contain NaN/Inf"));
}
let mut used = vec![false; n];
let mut perm = vec![usize::MAX; n];
for (i, perm_i) in perm.iter_mut().enumerate() {
let xi = x.row(i);
let mut best_j = usize::MAX;
let mut best = f32::INFINITY;
for (j, used_j) in used.iter().enumerate() {
if *used_j {
continue;
}
let yj = y.row(j);
let c = l2(&xi, &yj);
if c < best {
best = c;
best_j = j;
}
}
if best_j == usize::MAX {
return Err(Error::Domain("failed to construct a full matching"));
}
used[best_j] = true;
*perm_i = best_j;
}
Ok(perm)
}
// Our squared cost matrix should equal the element-wise square of
// `wass::euclidean_cost_matrix`.
// NOTE(review): the 1e-5 tolerance is absolute, but squared distances between
// N(0,1) vectors in up to 7 dims can reach the tens. Rounding from the
// sqrt->square round trip may then approach the bound; consider a relative
// tolerance if this flakes.
proptest! {
#[test]
fn prop_sq_cost_matrix_is_wass_squared(
n in 1usize..8,
d in 1usize..8,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let cost_wass_l2 = wass::euclidean_cost_matrix(&x, &y);
let cost_ours_sq = sq_euclidean_cost_matrix_from_views(&x.view(), &y.view()).unwrap();
prop_assert_eq!(cost_wass_l2.shape(), cost_ours_sq.shape());
for i in 0..n {
for j in 0..n {
let wass_sq = cost_wass_l2[[i, j]] * cost_wass_l2[[i, j]];
let ours = cost_ours_sq[[i, j]];
prop_assert!((wass_sq - ours).abs() <= 1e-5,
"mismatch at ({i},{j}): wass^2={wass_sq} ours={ours}");
}
}
}
}
// Our squared cost matrix should agree entry-wise with
// `wass::sq_euclidean_cost_matrix` (same absolute-tolerance caveat as above).
proptest! {
#[test]
fn prop_sq_cost_matrix_matches_wass_sq_cost(
n in 1usize..8,
d in 1usize..8,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let cost_wass_sq = wass::sq_euclidean_cost_matrix(&x, &y);
let cost_ours_sq = sq_euclidean_cost_matrix_from_views(&x.view(), &y.view()).unwrap();
prop_assert_eq!(cost_wass_sq.shape(), cost_ours_sq.shape());
for i in 0..n {
for j in 0..n {
let w = cost_wass_sq[[i, j]];
let o = cost_ours_sq[[i, j]];
prop_assert!((w - o).abs() <= 1e-5,
"mismatch at ({i},{j}): wass_sq={w} ours={o}");
}
}
}
}
// sqrt is monotone, so the argmin over squared distances should equal the
// argmin over plain distances. This is checked for the first x vector against
// all ys.
// NOTE(review): both argmins use strict `<`. If f32 sqrt rounding ever mapped
// two distinct squared distances to the same value, the two argmins could
// differ; this is assumed to be vanishingly rare for random data.
proptest! {
#[test]
fn prop_argmin_same_for_sq_vs_sqrt(
n in 1usize..16,
d in 1usize..16,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut xs: Vec<Vec<f32>> = Vec::with_capacity(n);
let mut ys: Vec<Vec<f32>> = Vec::with_capacity(n);
for _ in 0..n {
let mut v = vec![0.0f32; d];
let mut w = vec![0.0f32; d];
for k in 0..d {
v[k] = StandardNormal.sample(&mut rng);
w[k] = StandardNormal.sample(&mut rng);
}
xs.push(v);
ys.push(w);
}
let x0 = ndarray::ArrayView1::from(&xs[0]);
let mut best_sq = f32::INFINITY;
let mut best_sq_j = 0usize;
let mut best_eu = f32::INFINITY;
let mut best_eu_j = 0usize;
for (j, yj_vec) in ys.iter().enumerate() {
let yj = ndarray::ArrayView1::from(yj_vec);
let dsq = l2_squared(&x0, &yj);
let deu = dsq.sqrt();
if dsq < best_sq {
best_sq = dsq;
best_sq_j = j;
}
if deu < best_eu {
best_eu = deu;
best_eu_j = j;
}
}
prop_assert_eq!(best_sq_j, best_eu_j);
}
}
// The fast exp-greedy pairing must match the definition: greedy matching on
// the explicit weight matrix exp(-euclidean_cost / temp).
// NOTE(review): in the reference, exp(-cost/temp) can underflow to 0 for large
// cost and small temp, creating ties that are resolved by row-major order. The
// test relies on such ties not affecting the choices made for random data.
proptest! {
#[test]
fn prop_exp_greedy_matches_weight_matrix_definition(
n in 1usize..7,
d in 1usize..8,
temp in 0.05f32..1.0f32,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let perm_fast = minibatch_exp_greedy_pairing(&x.view(), &y.view(), temp).unwrap();
let cost = wass::euclidean_cost_matrix(&x, &y);
let mut w = cost.clone();
for i in 0..n {
for j in 0..n {
w[[i, j]] = (-cost[[i, j]] / temp).exp();
}
}
let perm_def = greedy_bipartite_match_from_weights(&w.view()).unwrap();
prop_assert!(is_permutation(&perm_fast));
prop_assert_eq!(perm_fast, perm_def);
}
}
// x_i has every coordinate at 10*i and y_j every coordinate at 10*(n-1-j), so
// the exact nearest y for x_i is y_{n-1-i}. Every other pair is at distance
// >= 80 with temp = 0.01, so an implementation that evaluated exp(-dist/temp)
// in f32 would see those weights underflow to 0. The test only asserts that
// the result is a permutation different from the identity.
#[test]
fn exp_greedy_works_with_extreme_underflow_regime() {
let n = 8usize;
let d = 64usize;
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = (i as f32) * 10.0;
y[[i, k]] = ((n - 1 - i) as f32) * 10.0;
}
}
let temp = 0.01; let perm = minibatch_exp_greedy_pairing(&x.view(), &y.view(), temp).unwrap();
assert!(is_permutation(&perm));
let identity: Vec<usize> = (0..n).collect();
assert_ne!(
perm, identity,
"exp-greedy collapsed to identity (underflow)"
);
}
// Partial rowwise pairing is deterministic, has length n, and uses only valid
// column indices. Columns may repeat, so this is not a permutation check.
proptest! {
#![proptest_config(ProptestConfig {
cases: 32,
.. ProptestConfig::default()
})]
#[test]
fn prop_partial_rowwise_in_range_and_deterministic(
n in 1usize..32,
d in 1usize..16,
keep_frac in 0.05f32..1.0f32,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let p1 = minibatch_partial_rowwise_pairing(&x.view(), &y.view(), keep_frac).unwrap();
let p2 = minibatch_partial_rowwise_pairing(&x.view(), &y.view(), keep_frac).unwrap();
prop_assert_eq!(&p1, &p2);
prop_assert_eq!(p1.len(), n);
prop_assert!(is_in_range(&p1));
}
}
// Selective Sinkhorn pairing is deterministic, has length n, and uses only
// valid column indices (columns may repeat).
proptest! {
#![proptest_config(ProptestConfig {
cases: 16,
.. ProptestConfig::default()
})]
#[test]
fn prop_sinkhorn_selective_in_range_and_deterministic(
n in 1usize..10,
d in 1usize..10,
keep_frac in 0.1f32..1.0f32,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let reg = 1.0;
let max_iter = 5_000;
let tol = 2e-3;
let p1 = minibatch_ot_selective_pairing(&x.view(), &y.view(), reg, max_iter, tol, keep_frac).unwrap();
let p2 = minibatch_ot_selective_pairing(&x.view(), &y.view(), reg, max_iter, tol, keep_frac).unwrap();
prop_assert_eq!(&p1, &p2);
prop_assert_eq!(p1.len(), n);
prop_assert!(is_in_range(&p1));
}
}
// The rows that partial rowwise pairing selects (recomputed by
// `selected_rows_by_nn_cost_sq`) must all be assigned distinct columns.
proptest! {
#![proptest_config(ProptestConfig {
cases: 32,
.. ProptestConfig::default()
})]
#[test]
fn prop_partial_rowwise_keep_frac_enforces_unique_cols_on_selected_rows(
n in 2usize..32,
d in 1usize..16,
keep_frac in 0.05f32..1.0f32,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let perm = minibatch_partial_rowwise_pairing(&x.view(), &y.view(), keep_frac).unwrap();
prop_assert_eq!(perm.len(), n);
prop_assert!(is_in_range(&perm));
let selected = selected_rows_by_nn_cost_sq(&x.view(), &y.view(), keep_frac);
let mut seen = std::collections::HashSet::<usize>::new();
for &i in &selected {
let j = perm[i];
prop_assert!(seen.insert(j), "expected unique columns for selected rows; duplicate col {j}");
}
}
}
// x and y coincide except that y's last row is moved far away (an outlier).
// With keep_frac = 1.0 every row gets a unique column, so exactly one row must
// take the outlier column. With keep_frac = 0.5 the worst-fitting rows fall back
// to their nearest column, so the outlier column is expected to go unused.
#[test]
fn keep_frac_monotone_outlier_usage_partial_rowwise_and_sinkhorn_selective() {
let n = 32usize;
let d = 8usize;
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = (i as f32) * 0.01 + (k as f32) * 0.001;
y[[i, k]] = (i as f32) * 0.01 + (k as f32) * 0.001;
}
}
for k in 0..d {
y[[n - 1, k]] = 1_000.0;
}
let count_outlier = |perm: &[usize]| perm.iter().filter(|&&j| j == n - 1).count();
let p_full = minibatch_partial_rowwise_pairing(&x.view(), &y.view(), 1.0).unwrap();
let p_half = minibatch_partial_rowwise_pairing(&x.view(), &y.view(), 0.5).unwrap();
assert_eq!(count_outlier(&p_full), 1);
assert_eq!(count_outlier(&p_half), 0);
let s_full =
minibatch_ot_selective_pairing(&x.view(), &y.view(), 1.0, 5_000, 2e-3, 1.0).unwrap();
let s_half =
minibatch_ot_selective_pairing(&x.view(), &y.view(), 1.0, 5_000, 2e-3, 0.5).unwrap();
assert_eq!(count_outlier(&s_full), 1);
assert_eq!(count_outlier(&s_half), 0);
}
// With small `reg` and a loose iteration budget, the Sinkhorn solve may fail.
// Only successful results are checked, and each must be a permutation.
proptest! {
#[test]
fn prop_sinkhorn_pairing_ok_implies_permutation(
n in 2usize..10,
d in 1usize..8,
reg in 0.05f32..0.5f32,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
if let Ok(p) = minibatch_ot_greedy_pairing(&x.view(), &y.view(), reg, 500, 2e-2) {
prop_assert_eq!(p.len(), n);
prop_assert!(is_permutation(&p));
}
}
}
// Adding a row constant r_i and a column constant c_j to the cost multiplies
// the Gibbs kernel exp(-C/reg) by separable factors. The Sinkhorn scalings
// absorb these factors, so the entropic plan should be unchanged up to solver
// tolerance.
proptest! {
#![proptest_config(ProptestConfig {
cases: 32,
.. ProptestConfig::default()
})]
#[test]
fn prop_sinkhorn_plan_approximately_invariant_to_rowcol_shifts(
n in 2usize..7,
d in 1usize..8,
reg in 0.10f32..0.5f32,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let cost = wass::euclidean_cost_matrix(&x, &y);
let mut row = vec![0.0f32; n];
let mut col = vec![0.0f32; n];
for i in 0..n {
let r: f32 = StandardNormal.sample(&mut rng);
let c: f32 = StandardNormal.sample(&mut rng);
row[i] = r * 0.1;
col[i] = c * 0.1;
}
let mut cost2 = cost.clone();
for i in 0..n {
for j in 0..n {
cost2[[i, j]] = cost[[i, j]] + row[i] + col[j];
}
}
let a = Array1::<f32>::from_elem(n, 1.0 / n as f32);
let b = Array1::<f32>::from_elem(n, 1.0 / n as f32);
let (p1, _d1, _it1) = wass::sinkhorn_log_with_convergence(&a, &b, &cost, reg, 20_000, 1e-4).unwrap();
let (p2, _d2, _it2) = wass::sinkhorn_log_with_convergence(&a, &b, &cost2, reg, 20_000, 1e-4).unwrap();
let mut max_abs = 0.0f32;
for i in 0..n {
for j in 0..n {
let d = (p1[[i, j]] - p2[[i, j]]).abs();
if d > max_abs {
max_abs = d;
}
}
}
prop_assert!(max_abs <= 5e-3, "expected near-invariant plan; max_abs={max_abs}");
}
}
// The production rowwise pairing (squared distances) must agree with the
// sqrt-based reference, and must be a permutation.
proptest! {
#[test]
fn prop_rowwise_pairing_invariant_to_sqrt(
n in 2usize..32,
d in 1usize..16,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let p_sq = minibatch_rowwise_nearest_pairing(&x.view(), &y.view()).unwrap();
let p_sqrt = rowwise_nearest_pairing_sqrt_reference(&x.view(), &y.view()).unwrap();
prop_assert!(is_permutation(&p_sq));
prop_assert_eq!(p_sq, p_sqrt);
}
}
// The full rowwise pairing is a permutation, so exactly one row must take the
// outlier column. Partial pairing with keep_frac = 0.8 lets the worst rows fall
// back to their nearest column, so the outlier column should go unused.
#[test]
fn partial_rowwise_avoids_forcing_outlier_column() {
let n = 16usize;
let d = 4usize;
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = i as f32 * 0.01;
y[[i, k]] = i as f32 * 0.01;
}
}
for k in 0..d {
y[[n - 1, k]] = 1_000.0;
}
let full = minibatch_rowwise_nearest_pairing(&x.view(), &y.view()).unwrap();
let partial = minibatch_partial_rowwise_pairing(&x.view(), &y.view(), 0.8).unwrap();
assert_eq!(full.iter().filter(|&&j| j == n - 1).count(), 1);
assert_eq!(partial.iter().filter(|&&j| j == n - 1).count(), 0);
}
// Sinkhorn version of the outlier test above: the greedy (permutation) pairing
// must use the outlier column once, and the selective pairing with
// keep_frac = 0.8 must not use it.
#[test]
fn sinkhorn_selective_avoids_forcing_outlier_column() {
let n = 16usize;
let d = 4usize;
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = i as f32 * 0.01;
y[[i, k]] = i as f32 * 0.01;
}
}
for k in 0..d {
y[[n - 1, k]] = 1_000.0;
}
let full = minibatch_ot_greedy_pairing(&x.view(), &y.view(), 1.0, 2_000, 1e-4).unwrap();
let sel =
minibatch_ot_selective_pairing(&x.view(), &y.view(), 1.0, 2_000, 1e-4, 0.8).unwrap();
assert_eq!(full.iter().filter(|&&j| j == n - 1).count(), 1);
assert_eq!(sel.iter().filter(|&&j| j == n - 1).count(), 0);
}
// Every public pairing API must reject a dimension mismatch between x and y,
// and must reject a NaN in x.
proptest! {
#![proptest_config(ProptestConfig {
cases: 32,
.. ProptestConfig::default()
})]
#[test]
fn prop_pairing_apis_error_on_shape_mismatch_or_nan(
n in 1usize..16,
d in 1usize..16,
seed in any::<u64>(),
) {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut x = Array2::<f32>::zeros((n, d));
let mut y = Array2::<f32>::zeros((n, d));
for i in 0..n {
for k in 0..d {
x[[i, k]] = StandardNormal.sample(&mut rng);
y[[i, k]] = StandardNormal.sample(&mut rng);
}
}
let y_bad = Array2::<f32>::zeros((n, d + 1));
prop_assert!(minibatch_rowwise_nearest_pairing(&x.view(), &y_bad.view()).is_err());
prop_assert!(minibatch_partial_rowwise_pairing(&x.view(), &y_bad.view(), 0.8).is_err());
prop_assert!(minibatch_exp_greedy_pairing(&x.view(), &y_bad.view(), 0.2).is_err());
prop_assert!(minibatch_ot_greedy_pairing(&x.view(), &y_bad.view(), 1.0, 100, 1e-2).is_err());
prop_assert!(minibatch_ot_selective_pairing(&x.view(), &y_bad.view(), 1.0, 100, 1e-2, 0.8).is_err());
x[[0, 0]] = f32::NAN;
prop_assert!(minibatch_rowwise_nearest_pairing(&x.view(), &y.view()).is_err());
prop_assert!(minibatch_partial_rowwise_pairing(&x.view(), &y.view(), 0.8).is_err());
prop_assert!(minibatch_exp_greedy_pairing(&x.view(), &y.view(), 0.2).is_err());
prop_assert!(minibatch_ot_greedy_pairing(&x.view(), &y.view(), 1.0, 200, 1e-2).is_err());
prop_assert!(minibatch_ot_selective_pairing(&x.view(), &y.view(), 1.0, 200, 1e-2, 0.8).is_err());
}
}
}