use ndarray::{Array1, Array2};
use crate::{Error, Result};
/// Converts a matrix of squared ground distances into the
/// Wasserstein–Fisher–Rao transport cost `-ln(cos²(min(d / (2·rho), π/2)))`,
/// where `d = sqrt(max(c_sq, 0))`.
///
/// The exact cost is `+∞` once `d ≥ π·rho`. It is replaced by the finite cap
/// `MAX_COST` so downstream Sinkhorn iterations stay finite. Every other cost
/// is clamped to the same cap, which keeps the mapping non-decreasing in `d`.
/// Without that clamp, `f32` angles just below π/2 (cosines around 1e-7)
/// would get a cost near 33, exceeding the value used for "unreachable" pairs.
///
/// Negative entries are treated as zero distance. NaN entries also map to
/// zero cost because `f32::max` ignores a NaN operand, so callers should
/// reject NaN input beforehand.
fn wfr_cost_from_sq_distance(cost_sq: &Array2<f32>, rho: f32) -> Array2<f32> {
    /// Finite stand-in for the infinite cost beyond the WFR cut-off distance.
    const MAX_COST: f32 = 30.0;
    let half_pi = std::f32::consts::FRAC_PI_2;
    let two_rho = 2.0 * rho;
    cost_sq.mapv(|c_sq| {
        let d = c_sq.max(0.0).sqrt();
        let angle = (d / two_rho).min(half_pi);
        let cos_sq = angle.cos().powi(2);
        // cos_sq == 0 gives ln(0) = -inf, i.e. +inf cost, which the cap
        // turns into MAX_COST; tiny positive cos_sq is capped the same way.
        (-cos_sq.ln()).min(MAX_COST)
    })
}
pub fn wfr_distance(
a: &Array1<f32>,
b: &Array1<f32>,
cost: &Array2<f32>,
rho: f32,
reg: f32,
max_iter: usize,
tol: f32,
) -> Result<f32> {
let n = a.len();
if b.len() != n {
return Err(Error::LengthMismatch(n, b.len()));
}
if cost.nrows() != n || cost.ncols() != n {
return Err(Error::CostShapeMismatch(n, n, cost.nrows(), cost.ncols()));
}
if reg <= 0.0 || !reg.is_finite() {
return Err(Error::InvalidRegularization(reg));
}
if rho <= 0.0 || !rho.is_finite() {
return Err(Error::InvalidMassPenalty(rho));
}
if a.iter().any(|&x| x < 0.0) || b.iter().any(|&x| x < 0.0) {
return Err(Error::Domain("WFR requires non-negative masses"));
}
if a.sum() <= 0.0 || b.sum() <= 0.0 {
return Err(Error::Domain("WFR requires positive total mass"));
}
let wfr_cost = wfr_cost_from_sq_distance(cost, rho);
let div = crate::unbalanced_sinkhorn_divergence_same_support(
a, b, &wfr_cost, reg, rho, max_iter, tol,
)?;
Ok(div.max(0.0).sqrt())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use proptest::prelude::*;

    /// Squared Euclidean distance matrix between points on a line.
    fn sq_cost_1d(positions: &[f32]) -> Array2<f32> {
        let n = positions.len();
        Array2::from_shape_fn((n, n), |(i, j)| {
            let d = positions[i] - positions[j];
            d * d
        })
    }

    #[test]
    fn wfr_self_distance_is_zero() {
        let a = array![0.3, 0.5, 0.2];
        let cost = sq_cost_1d(&[0.0, 1.0, 2.0]);
        let d = wfr_distance(&a, &a, &cost, 1.0, 0.1, 500, 1e-4).unwrap();
        assert!(d < 0.1, "self-distance should be near zero: d={d}");
    }

    #[test]
    fn wfr_symmetry() {
        let a = array![0.5, 0.3, 0.2];
        let b = array![0.2, 0.4, 0.4];
        let cost = sq_cost_1d(&[0.0, 1.0, 3.0]);
        let ab = wfr_distance(&a, &b, &cost, 1.0, 0.1, 500, 1e-4).unwrap();
        let ba = wfr_distance(&b, &a, &cost.t().to_owned(), 1.0, 0.1, 500, 1e-4).unwrap();
        assert!(
            (ab - ba).abs() < 0.1,
            "WFR should be symmetric: ab={ab} ba={ba}"
        );
    }

    #[test]
    fn wfr_different_total_mass() {
        let a = array![1.0, 1.0, 1.0];
        let b = array![0.1, 0.1, 0.1];
        let cost = sq_cost_1d(&[0.0, 1.0, 2.0]);
        let d = wfr_distance(&a, &b, &cost, 1.0, 0.1, 500, 1e-4).unwrap();
        assert!(
            d > 0.0,
            "different-mass measures should have positive distance: d={d}"
        );
    }

    #[test]
    fn wfr_identical_measures_different_rho() {
        let a = array![0.5, 0.5];
        let cost = sq_cost_1d(&[0.0, 1.0]);
        for &rho in &[0.1, 1.0, 10.0] {
            let d = wfr_distance(&a, &a, &cost, rho, 0.1, 500, 1e-4).unwrap();
            assert!(
                d < 0.15,
                "self-distance should be near zero for rho={rho}: d={d}"
            );
        }
    }

    #[test]
    fn wfr_positive_for_different_measures() {
        let a = array![1.0, 0.0];
        let b = array![0.0, 1.0];
        let cost = sq_cost_1d(&[0.0, 2.0]);
        let d = wfr_distance(&a, &b, &cost, 1.0, 0.1, 500, 1e-4).unwrap();
        assert!(
            d > 0.01,
            "different measures should have positive distance: d={d}"
        );
    }

    #[test]
    fn wfr_rejects_negative_mass() {
        let a = array![-0.5, 0.5];
        let b = array![0.5, 0.5];
        let cost = sq_cost_1d(&[0.0, 1.0]);
        assert!(matches!(
            wfr_distance(&a, &b, &cost, 1.0, 0.1, 100, 1e-4),
            Err(Error::Domain(..))
        ));
    }

    #[test]
    fn wfr_rejects_zero_total_mass() {
        let a = array![0.0, 0.0];
        let b = array![0.5, 0.5];
        let cost = sq_cost_1d(&[0.0, 1.0]);
        assert!(matches!(
            wfr_distance(&a, &b, &cost, 1.0, 0.1, 100, 1e-4),
            Err(Error::Domain(..))
        ));
    }

    #[test]
    fn wfr_rejects_invalid_rho() {
        let a = array![0.5, 0.5];
        let b = array![0.5, 0.5];
        let cost = sq_cost_1d(&[0.0, 1.0]);
        for &rho in &[-1.0, 0.0, f32::NAN, f32::INFINITY] {
            assert!(
                matches!(
                    wfr_distance(&a, &b, &cost, rho, 0.1, 100, 1e-4),
                    Err(Error::InvalidMassPenalty(..))
                ),
                "rho={rho} should be rejected"
            );
        }
    }

    #[test]
    fn wfr_rejects_invalid_reg() {
        let a = array![0.5, 0.5];
        let b = array![0.5, 0.5];
        let cost = sq_cost_1d(&[0.0, 1.0]);
        for &reg in &[0.0, -0.1, f32::NAN, f32::INFINITY] {
            assert!(
                matches!(
                    wfr_distance(&a, &b, &cost, 1.0, reg, 100, 1e-4),
                    Err(Error::InvalidRegularization(..))
                ),
                "reg={reg} should be rejected"
            );
        }
    }

    #[test]
    fn wfr_rejects_length_mismatch() {
        let a = array![0.5, 0.5];
        let b = array![0.5, 0.5, 0.0];
        let cost = Array2::zeros((2, 2));
        assert!(matches!(
            wfr_distance(&a, &b, &cost, 1.0, 0.1, 100, 1e-4),
            Err(Error::LengthMismatch(..))
        ));
    }

    #[test]
    fn wfr_rejects_cost_shape_mismatch() {
        // `a` and `b` agree in length, so only the cost-shape check can fire.
        let a = array![0.5, 0.5];
        let b = array![0.5, 0.5];
        let cost = sq_cost_1d(&[0.0, 1.0, 2.0]);
        assert!(matches!(
            wfr_distance(&a, &b, &cost, 1.0, 0.1, 100, 1e-4),
            Err(Error::CostShapeMismatch(..))
        ));
    }

    proptest! {
        #[test]
        fn prop_wfr_non_negative(
            n in 2usize..5,
        ) {
            let a_vec: Vec<f32> = (0..n).map(|i| (i as f32 + 1.0) * 0.3).collect();
            let b_vec: Vec<f32> = (0..n).map(|i| ((n - i) as f32) * 0.2).collect();
            let positions: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let a = Array1::from_vec(a_vec);
            let b = Array1::from_vec(b_vec);
            let cost = sq_cost_1d(&positions);
            let d = wfr_distance(&a, &b, &cost, 1.0, 0.1, 500, 1e-3).unwrap();
            prop_assert!(d.is_finite() && d >= 0.0, "d={d}");
        }

        #[test]
        fn prop_wfr_self_distance_near_zero(
            n in 2usize..5,
        ) {
            let a_vec: Vec<f32> = (0..n).map(|i| (i as f32 + 1.0) * 0.3).collect();
            let positions: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let a = Array1::from_vec(a_vec);
            let cost = sq_cost_1d(&positions);
            let d = wfr_distance(&a, &a, &cost, 1.0, 0.1, 500, 1e-4).unwrap();
            prop_assert!(d < 0.15, "self-distance too large: d={d}");
        }

        #[test]
        fn prop_wfr_symmetry(
            n in 2usize..4,
        ) {
            let a_vec: Vec<f32> = (0..n).map(|i| (i as f32 + 1.0) * 0.3).collect();
            let b_vec: Vec<f32> = (0..n).map(|i| ((n - i) as f32) * 0.4).collect();
            let positions: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let a = Array1::from_vec(a_vec);
            let b = Array1::from_vec(b_vec);
            let cost = sq_cost_1d(&positions);
            let ab = wfr_distance(&a, &b, &cost, 1.0, 0.1, 500, 1e-4).unwrap();
            let ba = wfr_distance(&b, &a, &cost, 1.0, 0.1, 500, 1e-4).unwrap();
            prop_assert!((ab - ba).abs() < 0.15, "ab={ab} ba={ba}");
        }
    }
}