use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::{seeded_rng, Distribution, Normal, SeedableRng, Uniform};
use crate::error::{Result, TransformError};
/// Exact 1-Wasserstein (earth mover's) distance between two empirical distributions on ℝ.
///
/// Every sample of `u` carries mass `1 / u.len()` and every sample of `v` carries mass
/// `1 / v.len()`, so the two slices may have different lengths. The distance is
/// `∫ |F_u(x) − F_v(x)| dx`, evaluated exactly as a sum over the intervals between
/// consecutive distinct sample values (both empirical CDFs are constant on each interval).
///
/// Runs in `O(n log n + m log m)` time: two sorts followed by one linear sweep.
///
/// # Errors
///
/// Returns [`TransformError::InvalidInput`] if either slice is empty or contains a
/// non-finite value (NaN or ±∞).
pub fn wasserstein_1d(u: &[f64], v: &[f64]) -> Result<f64> {
    if u.is_empty() {
        return Err(TransformError::InvalidInput(
            "First distribution is empty".to_string(),
        ));
    }
    if v.is_empty() {
        return Err(TransformError::InvalidInput(
            "Second distribution is empty".to_string(),
        ));
    }
    // NaN has no place in a CDF and would make the sort comparator inconsistent
    // (which modern `std` sorts are allowed to panic on); infinities make the integral
    // meaningless. Reject both up front.
    if u.iter().any(|x| !x.is_finite()) {
        return Err(TransformError::InvalidInput(
            "First distribution contains non-finite values".to_string(),
        ));
    }
    if v.iter().any(|x| !x.is_finite()) {
        return Err(TransformError::InvalidInput(
            "Second distribution contains non-finite values".to_string(),
        ));
    }

    let mut u_sorted = u.to_vec();
    let mut v_sorted = v.to_vec();
    u_sorted.sort_unstable_by(f64::total_cmp);
    v_sorted.sort_unstable_by(f64::total_cmp);

    let (n, m) = (u_sorted.len(), v_sorted.len());
    let (n_f, m_f) = (n as f64, m as f64);
    // `i` / `j` count the samples of u / v that are <= the current breakpoint.
    let (mut i, mut j) = (0usize, 0usize);
    let mut prev = u_sorted[0].min(v_sorted[0]);
    let mut distance = 0.0f64;
    loop {
        // Next breakpoint: the smallest sample not yet absorbed into either CDF.
        let x = match (u_sorted.get(i), v_sorted.get(j)) {
            (Some(&a), Some(&b)) => a.min(b),
            (Some(&a), None) => a,
            (None, Some(&b)) => b,
            (None, None) => break,
        };
        // On [prev, x) both CDFs equal the fraction of their samples that are <= prev.
        let cdf_gap = (i as f64 / n_f - j as f64 / m_f).abs();
        distance += cdf_gap * (x - prev);
        // Absorb every sample equal to x (always advances at least one pointer).
        while i < n && u_sorted[i] <= x {
            i += 1;
        }
        while j < m && v_sorted[j] <= x {
            j += 1;
        }
        prev = x;
    }
    Ok(distance)
}
/// Exact Earth Mover's Distance between two histograms under an arbitrary ground cost.
///
/// Both histograms are normalised to unit mass, and the balanced transportation problem
///
/// ```text
/// minimise  Σ_ij T_ij · C_ij   subject to  Σ_j T_ij = a_i,  Σ_i T_ij = b_j,  T ≥ 0
/// ```
///
/// is solved exactly by `emd_transport_plan`. The optimal cost is then multiplied by
/// `min(sum(hist1), sum(hist2))`. For two histograms of equal total mass `M`, the result is
/// therefore the cost of transporting all of `M`.
///
/// # Errors
///
/// * [`TransformError::InvalidInput`] in any of these cases:
///   * a histogram is empty;
///   * `cost_matrix` is not `hist1.len() × hist2.len()`;
///   * a histogram entry is negative or non-finite;
///   * a cost entry is non-finite;
///   * a histogram has (near-)zero total mass.
/// * [`TransformError::ComputationError`] if the solver cannot route all of the mass. That
///   only happens through numerical breakdown.
pub fn earth_mover_distance(
    hist1: &[f64],
    hist2: &[f64],
    cost_matrix: &Array2<f64>,
) -> Result<f64> {
    let n = hist1.len();
    let m = hist2.len();
    if n == 0 || m == 0 {
        return Err(TransformError::InvalidInput(
            "Histograms must be non-empty".to_string(),
        ));
    }
    if cost_matrix.nrows() != n || cost_matrix.ncols() != m {
        return Err(TransformError::InvalidInput(format!(
            "cost_matrix shape ({}, {}) must match histogram lengths ({}, {})",
            cost_matrix.nrows(),
            cost_matrix.ncols(),
            n,
            m
        )));
    }
    for (i, &v) in hist1.iter().enumerate() {
        if v < 0.0 {
            return Err(TransformError::InvalidInput(format!(
                "hist1[{}] = {} is negative",
                i, v
            )));
        }
        if !v.is_finite() {
            return Err(TransformError::InvalidInput(format!(
                "hist1[{}] = {} is not finite",
                i, v
            )));
        }
    }
    for (j, &v) in hist2.iter().enumerate() {
        if v < 0.0 {
            return Err(TransformError::InvalidInput(format!(
                "hist2[{}] = {} is negative",
                j, v
            )));
        }
        if !v.is_finite() {
            return Err(TransformError::InvalidInput(format!(
                "hist2[{}] = {} is not finite",
                j, v
            )));
        }
    }
    for i in 0..n {
        for j in 0..m {
            let c = cost_matrix[[i, j]];
            if !c.is_finite() {
                return Err(TransformError::InvalidInput(format!(
                    "cost_matrix[[{}, {}]] = {} is not finite",
                    i, j, c
                )));
            }
        }
    }
    let sum1: f64 = hist1.iter().sum();
    let sum2: f64 = hist2.iter().sum();
    if sum1 < 1e-15 || sum2 < 1e-15 {
        return Err(TransformError::InvalidInput(
            "Histograms must have positive total mass".to_string(),
        ));
    }
    let supply: Vec<f64> = hist1.iter().map(|&x| x / sum1).collect();
    let demand: Vec<f64> = hist2.iter().map(|&x| x / sum2).collect();

    // A greedy "cheapest pair first" assignment is not optimal in general: with
    // a = b = [.5, .5] and C = [[0, 1], [1, 100]] it costs 50 instead of 1.
    // An exact solver is used instead.
    let flow = emd_transport_plan(&supply, &demand, cost_matrix).ok_or_else(|| {
        TransformError::ComputationError("EMD solver failed to transport all mass".to_string())
    })?;

    let mut emd = 0.0f64;
    for i in 0..n {
        for j in 0..m {
            emd += flow[i * m + j] * cost_matrix[[i, j]];
        }
    }
    let scale = sum1.min(sum2);
    Ok(emd * scale)
}

/// Optimal flow for the balanced transportation problem, computed by successive shortest paths.
///
/// The flow is returned as a row-major `supply.len() × demand.len()` vector.
///
/// Each round runs Bellman-Ford on the residual graph:
/// * a virtual source connects to every source with remaining supply;
/// * forward edges `i → j` have cost `C_ij` and unlimited capacity;
/// * backward edges `j → i` have cost `-C_ij` and capacity equal to the current flow `T_ij`.
///
/// The cheapest path to a sink with remaining demand is augmented by its bottleneck.
/// Starting from zero flow, each augmentation preserves cost-optimality, so the final flow
/// is an optimal plan.
///
/// Returns `None` if no augmenting path exists, if the predecessor structure is corrupted
/// by round-off, or if the iteration cap is reached.
fn emd_transport_plan(supply: &[f64], demand: &[f64], cost: &Array2<f64>) -> Option<Vec<f64>> {
    // Masses at or below this are treated as exhausted (absorbs round-off residue).
    const MASS_EPS: f64 = 1e-14;
    let n = supply.len();
    let m = demand.len();

    // Distance improvements smaller than `tol` are ignored. Round-off therefore cannot
    // manufacture negative cycles in the residual graph.
    let mut cost_scale = 0.0f64;
    for i in 0..n {
        for j in 0..m {
            cost_scale = cost_scale.max(cost[[i, j]].abs());
        }
    }
    let tol = 1e-12 * cost_scale;
    let improves =
        |candidate: f64, current: f64| current.is_infinite() || candidate < current - tol;

    let mut rem_supply = supply.to_vec();
    let mut rem_demand = demand.to_vec();
    let mut flow = vec![0.0f64; n * m];

    // Safety valve only; realistic inputs finish far earlier.
    let max_augmentations = 4 * (n + m) * (n + m) + n * m;
    let mut augmentations = 0usize;
    loop {
        let supply_left = rem_supply.iter().any(|&s| s > MASS_EPS);
        let demand_left = rem_demand.iter().any(|&d| d > MASS_EPS);
        if !supply_left || !demand_left {
            return Some(flow);
        }
        if augmentations == max_augmentations {
            return None;
        }
        augmentations += 1;

        // Bellman-Ford from the virtual source.
        let mut dist_src = vec![f64::INFINITY; n];
        let mut dist_snk = vec![f64::INFINITY; m];
        // pred_snk[j]: source whose forward edge last improved sink j.
        let mut pred_snk = vec![usize::MAX; m];
        // pred_src[i]: sink whose backward edge last improved source i (MAX = virtual source).
        let mut pred_src = vec![usize::MAX; n];
        for (d, &s) in dist_src.iter_mut().zip(&rem_supply) {
            if s > MASS_EPS {
                *d = 0.0;
            }
        }
        for _ in 0..=(n + m) {
            let mut changed = false;
            for i in 0..n {
                let di = dist_src[i];
                if di.is_infinite() {
                    continue;
                }
                for j in 0..m {
                    let cand = di + cost[[i, j]];
                    if improves(cand, dist_snk[j]) {
                        dist_snk[j] = cand;
                        pred_snk[j] = i;
                        changed = true;
                    }
                }
            }
            for j in 0..m {
                let dj = dist_snk[j];
                if dj.is_infinite() {
                    continue;
                }
                for i in 0..n {
                    if flow[i * m + j] > MASS_EPS {
                        let cand = dj - cost[[i, j]];
                        if improves(cand, dist_src[i]) {
                            dist_src[i] = cand;
                            pred_src[i] = j;
                            changed = true;
                        }
                    }
                }
            }
            if !changed {
                break;
            }
        }

        // Cheapest reachable sink that still needs mass.
        let mut end: Option<usize> = None;
        for j in 0..m {
            if rem_demand[j] > MASS_EPS
                && dist_snk[j].is_finite()
                && end.map_or(true, |e| dist_snk[j] < dist_snk[e])
            {
                end = Some(j);
            }
        }
        let end = end?;

        // Walk predecessors back to a source fed directly by the virtual source.
        let mut forward: Vec<(usize, usize)> = Vec::new();
        let mut backward: Vec<(usize, usize)> = Vec::new();
        let mut sink = end;
        let start = loop {
            let src = pred_snk[sink];
            // A simple path visits each source at most once; more means a corrupted tree.
            if src == usize::MAX || forward.len() >= n {
                return None;
            }
            forward.push((src, sink));
            let prev_sink = pred_src[src];
            if prev_sink == usize::MAX {
                break src;
            }
            backward.push((src, prev_sink));
            sink = prev_sink;
        };

        // Bottleneck: remaining supply, remaining demand, and flow on cancelled edges.
        // Every term exceeds MASS_EPS, so each round makes progress.
        let mut delta = rem_supply[start].min(rem_demand[end]);
        for &(i, j) in &backward {
            delta = delta.min(flow[i * m + j]);
        }
        rem_supply[start] -= delta;
        rem_demand[end] -= delta;
        for &(i, j) in &forward {
            flow[i * m + j] += delta;
        }
        for &(i, j) in &backward {
            flow[i * m + j] = (flow[i * m + j] - delta).max(0.0);
        }
    }
}
/// Entropy-regularised optimal transport plan, computed with the log-domain Sinkhorn–Knopp
/// iteration.
///
/// `a` and `b` are normalised to unit mass before solving, so the returned `n × m` plan
/// `P = diag(u) · K · diag(v)`, with `K = exp(-C / reg)`, has total mass ≈ 1.
/// * Column sums of `P` match the normalised `b` up to round-off, because the last half-step
///   of every iteration updates `v`.
/// * Row sums match the normalised `a` only as closely as the iteration has converged.
///
/// Iteration stops after `max_iter` rounds, or earlier once the largest change in `log u`
/// drops below `1e-9`. Zero entries in `a` or `b` are allowed and produce zero rows or
/// columns in the plan.
///
/// # Errors
///
/// Returns [`TransformError::InvalidInput`] in any of these cases:
/// * a distribution is empty;
/// * `cost_matrix` is not `n × m`;
/// * `reg` is not a positive finite number;
/// * an entry of `a` or `b` is negative or non-finite;
/// * a distribution has (near-)zero total mass;
/// * a cost entry is NaN.
pub fn sinkhorn(
    a: &scirs2_core::ndarray::ArrayView1<f64>,
    b: &scirs2_core::ndarray::ArrayView1<f64>,
    cost_matrix: &Array2<f64>,
    reg: f64,
    max_iter: usize,
) -> Result<Array2<f64>> {
    let n = a.len();
    let m = b.len();
    if n == 0 || m == 0 {
        return Err(TransformError::InvalidInput(
            "Distributions must be non-empty".to_string(),
        ));
    }
    if cost_matrix.nrows() != n || cost_matrix.ncols() != m {
        return Err(TransformError::InvalidInput(format!(
            "cost_matrix shape ({}, {}) must match distribution lengths ({}, {})",
            cost_matrix.nrows(),
            cost_matrix.ncols(),
            n,
            m
        )));
    }
    if reg <= 0.0 {
        return Err(TransformError::InvalidInput(
            "Regularization parameter reg must be positive".to_string(),
        ));
    }
    // NaN slips past `reg <= 0.0`; both NaN and +inf would poison the Gibbs kernel.
    if !reg.is_finite() {
        return Err(TransformError::InvalidInput(
            "Regularization parameter reg must be finite".to_string(),
        ));
    }
    // A negative mass would hit `ln` below and silently turn the whole plan into NaN.
    for &(label, dist) in &[("a", a), ("b", b)] {
        if let Some((idx, &v)) = dist
            .iter()
            .enumerate()
            .find(|&(_, &v)| !(v >= 0.0 && v.is_finite()))
        {
            return Err(TransformError::InvalidInput(format!(
                "{}[{}] = {} must be finite and non-negative",
                label, idx, v
            )));
        }
    }
    let sum_a: f64 = a.iter().sum();
    let sum_b: f64 = b.iter().sum();
    if sum_a < 1e-15 {
        return Err(TransformError::InvalidInput(
            "Source distribution a has zero mass".to_string(),
        ));
    }
    if sum_b < 1e-15 {
        return Err(TransformError::InvalidInput(
            "Target distribution b has zero mass".to_string(),
        ));
    }
    let a_norm: Vec<f64> = a.iter().map(|&v| v / sum_a).collect();
    let b_norm: Vec<f64> = b.iter().map(|&v| v / sum_b).collect();

    // log K = -C / reg. Working in the log domain avoids underflow of exp(-C / reg).
    let mut log_k = Array2::zeros((n, m));
    for i in 0..n {
        for j in 0..m {
            let c = cost_matrix[[i, j]];
            if c.is_nan() {
                return Err(TransformError::InvalidInput(format!(
                    "cost_matrix[[{}, {}]] is NaN",
                    i, j
                )));
            }
            log_k[[i, j]] = -c / reg;
        }
    }

    let mut log_u = vec![0.0f64; n];
    let mut log_v = vec![0.0f64; m];
    let mut log_u_prev = vec![0.0f64; n];
    let tol = 1e-9;
    for _iter in 0..max_iter {
        log_u_prev.copy_from_slice(&log_u);
        // u = a / (K v), in log form.
        for i in 0..n {
            let lse = logsumexp_row_plus(&log_k, i, &log_v);
            log_u[i] = a_norm[i].ln() - lse;
        }
        // v = b / (Kᵀ u), in log form.
        for j in 0..m {
            let lse = logsumexp_col_plus(&log_k, j, &log_u);
            log_v[j] = b_norm[j].ln() - lse;
        }
        // NaN differences (from -inf - -inf on zero-mass rows) are ignored by f64::max.
        let delta: f64 = log_u
            .iter()
            .zip(log_u_prev.iter())
            .map(|(new, old)| (new - old).abs())
            .fold(0.0_f64, f64::max);
        if delta < tol {
            break;
        }
    }

    let mut plan = Array2::zeros((n, m));
    for i in 0..n {
        for j in 0..m {
            let log_t = log_u[i] + log_k[[i, j]] + log_v[j];
            plan[[i, j]] = log_t.exp().max(0.0);
        }
    }
    Ok(plan)
}
/// Entropic optimal-transport cost `⟨P, C⟩ = Σ_ij P_ij · C_ij`.
///
/// `P` is the plan returned by [`sinkhorn`] with a fixed budget of 1000 iterations.
///
/// The total is clamped below at `0.0`. A NaN total also becomes `0.0`, because `f64::max`
/// returns the non-NaN argument.
///
/// # Errors
///
/// Propagates every error reported by [`sinkhorn`].
pub fn sinkhorn_distance(
    a: &scirs2_core::ndarray::ArrayView1<f64>,
    b: &scirs2_core::ndarray::ArrayView1<f64>,
    cost_matrix: &Array2<f64>,
    reg: f64,
) -> Result<f64> {
    const SINKHORN_MAX_ITER: usize = 1000;
    let plan = sinkhorn(a, b, cost_matrix, reg, SINKHORN_MAX_ITER)?;
    let (rows, cols) = (plan.nrows(), plan.ncols());
    // Row-major accumulation into a single running total.
    let transport_cost = (0..rows)
        .flat_map(|i| (0..cols).map(move |j| (i, j)))
        .fold(0.0f64, |acc, (i, j)| acc + plan[[i, j]] * cost_matrix[[i, j]]);
    Ok(transport_cost.max(0.0))
}
/// Monte-Carlo estimate of the sliced 1-Wasserstein distance between two point clouds.
///
/// Each slice works as follows:
/// 1. Draw a direction from a standard normal vector and normalise it to unit length.
/// 2. Project both clouds (rows of `x` and `y`) onto that direction.
/// 3. Compute the exact 1-D distance with [`wasserstein_1d`].
///
/// The result is the mean over all usable directions. The same `seed` always reproduces the
/// same directions.
///
/// # Errors
///
/// * [`TransformError::InvalidInput`] in any of these cases:
///   * a cloud is empty;
///   * the clouds have no feature columns;
///   * the feature counts differ;
///   * `n_projections == 0`;
///   * a projected value is non-finite (reported via [`wasserstein_1d`]).
/// * [`TransformError::ComputationError`] in either of these cases:
///   * the normal distribution cannot be built;
///   * every sampled direction was degenerate (norm < 1e-15).
pub fn sliced_wasserstein(
    x: &ArrayView2<f64>,
    y: &ArrayView2<f64>,
    n_projections: usize,
    seed: u64,
) -> Result<f64> {
    let n = x.nrows();
    let m = y.nrows();
    let d = x.ncols();
    if n == 0 || m == 0 {
        return Err(TransformError::InvalidInput(
            "Point clouds must be non-empty".to_string(),
        ));
    }
    if d == 0 {
        return Err(TransformError::InvalidInput(
            "Point clouds must have at least one feature dimension".to_string(),
        ));
    }
    if y.ncols() != d {
        return Err(TransformError::InvalidInput(format!(
            "Point cloud dimension mismatch: x has {} features, y has {}",
            d,
            y.ncols()
        )));
    }
    if n_projections == 0 {
        return Err(TransformError::InvalidInput(
            "n_projections must be positive".to_string(),
        ));
    }

    // Projects every row of `points` onto the unit vector `direction`.
    fn project_rows(points: &ArrayView2<f64>, direction: &[f64]) -> Vec<f64> {
        (0..points.nrows())
            .map(|i| {
                direction
                    .iter()
                    .enumerate()
                    .map(|(k, &dk)| dk * points[[i, k]])
                    .sum::<f64>()
            })
            .collect()
    }

    let mut rng = seeded_rng(seed);
    let normal = Normal::new(0.0f64, 1.0f64)
        .map_err(|e| TransformError::ComputationError(e.to_string()))?;
    let mut total_distance = 0.0f64;
    let mut used_projections = 0usize;
    for _ in 0..n_projections {
        let raw: Vec<f64> = (0..d).map(|_| normal.sample(&mut rng)).collect();
        let norm: f64 = raw.iter().map(|&v| v * v).sum::<f64>().sqrt();
        if norm < 1e-15 {
            // No usable direction. Skip it, and do not count it in the average below.
            continue;
        }
        let direction: Vec<f64> = raw.iter().map(|&v| v / norm).collect();
        let proj_x = project_rows(x, &direction);
        let proj_y = project_rows(y, &direction);
        total_distance += wasserstein_1d(&proj_x, &proj_y)?;
        used_projections += 1;
    }
    if used_projections == 0 {
        return Err(TransformError::ComputationError(
            "All sampled projection directions were degenerate".to_string(),
        ));
    }
    // Average over the slices actually evaluated, not the number requested.
    Ok(total_distance / used_projections as f64)
}
/// Entropy-regularised Wasserstein barycenter of histograms on a shared 1-D grid.
///
/// All inputs live on the same `n` equally spaced grid points. The ground cost is the
/// squared index distance, normalised so the two grid ends are 1 apart (see
/// `build_grid_cost_matrix`).
///
/// The barycenter is computed with log-domain iterative Bregman projections:
/// * `p ∝ Π_k (K v_k)^{w_k}`
/// * `u_k = p / (K v_k)`
/// * `v_k = μ_k / (Kᵀ u_k)`
///
/// Each input is normalised to unit mass first; an all-zero input is replaced by the
/// uniform histogram. Weights are normalised to sum to 1. The loop stops after `max_iter`
/// rounds, or once the largest change of `p` falls below `1e-7`. The returned histogram
/// sums to 1.
///
/// # Errors
///
/// Returns [`TransformError::InvalidInput`] in any of these cases:
/// * no distributions are given;
/// * the weight count differs from the distribution count;
/// * `reg` is not a positive finite number;
/// * the distributions are empty or have different lengths;
/// * a mass or weight is negative or non-finite;
/// * the weights sum to (nearly) zero.
pub fn wasserstein_barycenter(
    distributions: &[scirs2_core::ndarray::ArrayView1<f64>],
    weights: &[f64],
    reg: f64,
    max_iter: usize,
) -> Result<Array1<f64>> {
    let k = distributions.len();
    if k == 0 {
        return Err(TransformError::InvalidInput(
            "At least one distribution is required".to_string(),
        ));
    }
    if weights.len() != k {
        return Err(TransformError::InvalidInput(format!(
            "weights length {} must match number of distributions {}",
            weights.len(),
            k
        )));
    }
    // Same contract as `sinkhorn`: reg <= 0 or non-finite makes -C / reg meaningless.
    if reg <= 0.0 {
        return Err(TransformError::InvalidInput(
            "Regularization parameter reg must be positive".to_string(),
        ));
    }
    if !reg.is_finite() {
        return Err(TransformError::InvalidInput(
            "Regularization parameter reg must be finite".to_string(),
        ));
    }
    let n = distributions[0].len();
    if n == 0 {
        return Err(TransformError::InvalidInput(
            "Distributions must be non-empty".to_string(),
        ));
    }
    for (idx, d) in distributions.iter().enumerate() {
        if d.len() != n {
            return Err(TransformError::InvalidInput(format!(
                "Distribution {} has length {}, expected {}",
                idx,
                d.len(),
                n
            )));
        }
        if let Some((j, &v)) = d
            .iter()
            .enumerate()
            .find(|&(_, &v)| !(v >= 0.0 && v.is_finite()))
        {
            return Err(TransformError::InvalidInput(format!(
                "Distribution {} has invalid mass {} at index {}; masses must be finite and non-negative",
                idx, v, j
            )));
        }
    }
    // Negative weights would turn the geometric mean into a meaningless ratio.
    if let Some((idx, &w)) = weights
        .iter()
        .enumerate()
        .find(|&(_, &w)| !(w >= 0.0 && w.is_finite()))
    {
        return Err(TransformError::InvalidInput(format!(
            "weights[{}] = {} must be finite and non-negative",
            idx, w
        )));
    }
    let weight_sum: f64 = weights.iter().sum();
    if weight_sum < 1e-15 {
        return Err(TransformError::InvalidInput(
            "Weights must have positive total".to_string(),
        ));
    }
    let weights_norm: Vec<f64> = weights.iter().map(|&w| w / weight_sum).collect();
    let dists_norm: Vec<Vec<f64>> = distributions
        .iter()
        .map(|d| {
            let s: f64 = d.iter().sum();
            if s > 1e-15 {
                d.iter().map(|&v| v / s).collect()
            } else {
                vec![1.0 / n as f64; n]
            }
        })
        .collect();

    let cost_matrix = build_grid_cost_matrix(n);
    let mut log_k = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            log_k[[i, j]] = -cost_matrix[[i, j]] / reg;
        }
    }

    // Finite stand-in for ln(0). A true -inf would give NaN in `lmu - lse` when a column
    // log-sum-exp is also -inf.
    let safe_ln = |x: f64| if x > 1e-300 { x.ln() } else { -700.0 };

    let mut p = vec![1.0 / n as f64; n];
    let mut p_prev = vec![0.0f64; n];
    let mut log_v_all: Vec<Vec<f64>> = vec![vec![0.0f64; n]; k];
    let tol = 1e-7;
    for _outer in 0..max_iter {
        p_prev.copy_from_slice(&p);

        // log(K v_k) for every input distribution.
        let log_kv: Vec<Vec<f64>> = log_v_all
            .iter()
            .map(|log_v_k| {
                (0..n)
                    .map(|i| logsumexp_row_plus(&log_k, i, log_v_k))
                    .collect()
            })
            .collect();

        // Weighted geometric mean of the K v_k, renormalised to unit mass.
        let log_p: Vec<f64> = (0..n)
            .map(|i| {
                let mut val = 0.0f64;
                for idx in 0..k {
                    val += weights_norm[idx] * log_kv[idx][i];
                }
                val
            })
            .collect();
        let log_z = logsumexp_slice(&log_p);
        for (pi, &lp) in p.iter_mut().zip(&log_p) {
            *pi = (lp - log_z).exp();
        }

        // u_k = p / (K v_k), then v_k = μ_k / (Kᵀ u_k).
        for (idx, dist_k) in dists_norm.iter().enumerate() {
            let log_u_k: Vec<f64> = (0..n).map(|i| safe_ln(p[i]) - log_kv[idx][i]).collect();
            log_v_all[idx] = (0..n)
                .map(|j| safe_ln(dist_k[j]) - logsumexp_col_plus(&log_k, j, &log_u_k))
                .collect();
        }

        let delta: f64 = p
            .iter()
            .zip(p_prev.iter())
            .map(|(new, old)| (new - old).abs())
            .fold(0.0f64, f64::max);
        if delta < tol {
            break;
        }
    }
    Ok(Array1::from_vec(p))
}
/// Turns a transport plan into a map via barycentric projection.
///
/// Each source index `i` is sent to the plan-weighted mean of the target points:
/// `T(i) = Σ_j P_ij · y_j / Σ_j P_ij`.
///
/// All images are computed eagerly. The returned closure hands out a clone of `T(i)`.
/// It yields a zero vector of the target dimension in two cases:
/// * plan row `i` has total mass ≤ 1e-15;
/// * `i` is past the last source row.
///
/// # Errors
///
/// Returns [`TransformError::InvalidInput`] in any of these cases:
/// * the plan's column count differs from the number of target points;
/// * the plan is empty;
/// * the target points have zero dimensions.
pub fn ot_plan_to_transport_map(
    plan: &Array2<f64>,
    target_points: &Array2<f64>,
) -> Result<impl Fn(usize) -> Vec<f64>> {
    let (sources, targets) = (plan.nrows(), plan.ncols());
    let target_rows = target_points.nrows();
    if targets != target_rows {
        return Err(TransformError::InvalidInput(format!(
            "plan has {} columns but target_points has {} rows",
            targets, target_rows
        )));
    }
    if sources == 0 || targets == 0 {
        return Err(TransformError::InvalidInput(
            "Plan must be non-empty".to_string(),
        ));
    }
    let dim = target_points.ncols();
    if dim == 0 {
        return Err(TransformError::InvalidInput(
            "Target points must have at least one dimension".to_string(),
        ));
    }

    let images: Vec<Vec<f64>> = (0..sources)
        .map(|i| {
            let mass: f64 = (0..targets).map(|j| plan[[i, j]]).sum();
            if mass > 1e-15 {
                (0..dim)
                    .map(|fd| {
                        (0..targets).fold(0.0f64, |acc, j| {
                            acc + plan[[i, j]] / mass * target_points[[j, fd]]
                        })
                    })
                    .collect()
            } else {
                // Row carries no mass: nothing to average over.
                vec![0.0f64; dim]
            }
        })
        .collect();

    Ok(move |i: usize| -> Vec<f64> {
        match images.get(i) {
            Some(point) => point.clone(),
            None => vec![0.0f64; dim],
        }
    })
}
/// Squared-distance cost between the points `0, 1, …, n-1` of a 1-D grid.
///
/// For `n > 1` the costs are divided by `(n - 1)²`, so the two grid ends are at cost 1.
/// For `n ≤ 1` no scaling is applied.
fn build_grid_cost_matrix(n: usize) -> Array2<f64> {
    let scale = match n {
        0 | 1 => 1.0,
        _ => {
            let span = (n - 1) as f64;
            1.0 / (span * span)
        }
    };
    Array2::from_shape_fn((n, n), |(i, j)| {
        let diff = i as f64 - j as f64;
        diff * diff * scale
    })
}
/// Pairwise ground-cost matrix `C_ij = Σ_k |x_ik − y_jk|^p` between the rows of `x` and `y`.
///
/// This is the p-th power of the Minkowski p-distance; no `1/p` root is taken. It is the
/// usual cost for a p-Wasserstein problem. The result has shape `(x.nrows(), y.nrows())`.
///
/// # Errors
///
/// Returns [`TransformError::InvalidInput`] if `x` and `y` have different column counts,
/// or if `p` is not a positive finite number.
pub fn pairwise_cost_matrix(
    x: &ArrayView2<f64>,
    y: &ArrayView2<f64>,
    p: f64,
) -> Result<Array2<f64>> {
    let d = x.ncols();
    if y.ncols() != d {
        return Err(TransformError::InvalidInput(format!(
            "Dimension mismatch: x has {} features, y has {}",
            d,
            y.ncols()
        )));
    }
    // p <= 0 is not a cost: coincident coordinates would cost 0^0 = 1 (p = 0) or +inf
    // (p < 0). NaN would poison every entry.
    if !(p > 0.0 && p.is_finite()) {
        return Err(TransformError::InvalidInput(format!(
            "Exponent p must be a positive finite number, got {}",
            p
        )));
    }
    Ok(Array2::from_shape_fn((x.nrows(), y.nrows()), |(i, j)| {
        (0..d)
            .map(|k| (x[[i, k]] - y[[j, k]]).abs().powf(p))
            .sum::<f64>()
    }))
}
/// Stable `ln Σ_j exp(log_k[[i, j]] + log_v[j])` over `j in 0..log_v.len()`.
///
/// `log_k` must have at least `log_v.len()` columns. The result is returned directly, without
/// the exp/ln step, when the largest term is ±∞: that covers empty `log_v`, all-`-inf` terms,
/// and any `+inf` term. NaN terms are skipped when picking the shift, but they still
/// propagate through the sum.
#[inline]
fn logsumexp_row_plus(log_k: &Array2<f64>, i: usize, log_v: &[f64]) -> f64 {
    let term = |j: usize| log_k[[i, j]] + log_v[j];
    let cols = 0..log_v.len();
    let shift = cols
        .clone()
        .map(&term)
        .fold(f64::NEG_INFINITY, |best, t| if t > best { t } else { best });
    if shift.is_infinite() {
        return shift;
    }
    // Subtracting the largest term keeps every exp() in (0, 1].
    let scaled_total = cols.fold(0.0f64, |acc, j| acc + (term(j) - shift).exp());
    shift + scaled_total.ln()
}
/// Stable `ln Σ_i exp(log_k[[i, j]] + log_u[i])` over `i in 0..log_u.len()`.
///
/// `log_k` must have at least `log_u.len()` rows. The result is returned directly, without
/// the exp/ln step, when the largest term is ±∞: that covers empty `log_u`, all-`-inf` terms,
/// and any `+inf` term. NaN terms are skipped when picking the shift, but they still
/// propagate through the sum.
#[inline]
fn logsumexp_col_plus(log_k: &Array2<f64>, j: usize, log_u: &[f64]) -> f64 {
    let term = |i: usize| log_k[[i, j]] + log_u[i];
    let rows = 0..log_u.len();
    let shift = rows
        .clone()
        .map(&term)
        .fold(f64::NEG_INFINITY, |best, t| if t > best { t } else { best });
    if shift.is_infinite() {
        return shift;
    }
    // Subtracting the largest term keeps every exp() in (0, 1].
    let scaled_total = rows.fold(0.0f64, |acc, i| acc + (term(i) - shift).exp());
    shift + scaled_total.ln()
}
/// Numerically stable `ln Σ exp(x)` over the entries of `log_p`.
///
/// Returns `-inf` for an empty or all-`-inf` slice and `+inf` if any entry is `+inf`.
/// NaN entries are skipped when choosing the shift (`f64::max` returns the non-NaN
/// argument), but they still propagate through the sum.
#[inline]
fn logsumexp_slice(log_p: &[f64]) -> f64 {
    let shift = log_p
        .iter()
        .fold(f64::NEG_INFINITY, |best, &x| best.max(x));
    if shift.is_infinite() {
        return shift;
    }
    // Factoring out the largest entry keeps every exp() in (0, 1].
    let scaled_total = log_p.iter().map(|&x| (x - shift).exp()).sum::<f64>();
    shift + scaled_total.ln()
}
// Unit tests for the optimal-transport utilities in this module.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
// Identical samples must give a 1-D Wasserstein distance of (numerically) zero.
#[test]
fn test_wasserstein_1d_identical() {
let u = vec![1.0, 2.0, 3.0];
let v = vec![1.0, 2.0, 3.0];
let dist = wasserstein_1d(&u, &v).expect("wasserstein_1d failed");
assert!(dist.abs() < 1e-10, "Identical distributions: W=0, got {}", dist);
}
// Shifting every sample by +1 must give W1 ≈ 1.
#[test]
fn test_wasserstein_1d_shift() {
let u: Vec<f64> = (0..10).map(|i| i as f64).collect();
let v: Vec<f64> = (0..10).map(|i| i as f64 + 1.0).collect();
let dist = wasserstein_1d(&u, &v).expect("wasserstein_1d failed");
assert!(
(dist - 1.0).abs() < 1e-6,
"Shifted by 1: W1 should be ~1, got {}",
dist
);
}
// An empty sample set on either side is rejected.
#[test]
fn test_wasserstein_1d_empty_error() {
assert!(wasserstein_1d(&[], &[1.0]).is_err());
assert!(wasserstein_1d(&[1.0], &[]).is_err());
}
// Disjoint supports: only checks that the distance is non-negative.
#[test]
fn test_wasserstein_1d_non_negative() {
let u = vec![0.5, 1.5, 2.5];
let v = vec![3.0, 4.0, 5.0];
let dist = wasserstein_1d(&u, &v).expect("wasserstein_1d failed");
assert!(dist >= 0.0);
}
// Same histogram on both sides with a zero-diagonal (squared index) cost gives EMD 0.
#[test]
fn test_earth_mover_distance_identical() {
let hist = vec![0.25, 0.25, 0.25, 0.25];
let cost = Array2::from_shape_fn((4, 4), |(i, j)| {
let diff = i as f64 - j as f64;
diff * diff
});
let emd = earth_mover_distance(&hist, &hist, &cost).expect("EMD failed");
assert!(emd.abs() < 1e-10, "Identical histograms: EMD=0, got {}", emd);
}
// All mass moves one bin under |i - j| cost, so EMD must be positive and at most 1.
#[test]
fn test_earth_mover_distance_adjacent() {
let h1 = vec![1.0, 0.0, 0.0, 0.0];
let h2 = vec![0.0, 1.0, 0.0, 0.0];
let cost = Array2::from_shape_fn((4, 4), |(i, j)| (i as f64 - j as f64).abs());
let emd = earth_mover_distance(&h1, &h2, &cost).expect("EMD failed");
assert!(emd > 0.0);
assert!(emd <= 1.0 + 1e-10, "Expected EMD <= 1, got {}", emd);
}
// A 2x2 cost matrix cannot serve histograms of lengths 2 and 3.
#[test]
fn test_earth_mover_dimension_mismatch() {
let h1 = vec![0.5, 0.5];
let h2 = vec![0.3, 0.3, 0.4];
let cost = Array2::zeros((2, 2));
assert!(earth_mover_distance(&h1, &h2, &cost).is_err());
}
// The Sinkhorn plan has shape (len(a), len(b)).
#[test]
fn test_sinkhorn_shape() {
let a = Array1::from_vec(vec![0.5, 0.5]);
let b = Array1::from_vec(vec![0.3, 0.7]);
let cost = Array2::from_shape_fn((2, 2), |(i, j)| (i as f64 - j as f64).powi(2));
let plan = sinkhorn(&a.view(), &b.view(), &cost, 0.1, 100).expect("sinkhorn failed");
assert_eq!(plan.shape(), &[2, 2]);
}
// Row/column sums of the plan should be close (within 0.1) to a and b; both inputs
// already sum to 1, so normalisation does not change them.
#[test]
fn test_sinkhorn_marginals() {
let a = Array1::from_vec(vec![0.4, 0.6]);
let b = Array1::from_vec(vec![0.3, 0.7]);
let cost = Array2::from_shape_fn((2, 2), |(i, j)| (i as f64 - j as f64).powi(2));
let plan = sinkhorn(&a.view(), &b.view(), &cost, 0.01, 500).expect("sinkhorn failed");
let row_sum: Vec<f64> = (0..2).map(|i| plan.row(i).sum()).collect();
let col_sum: Vec<f64> = (0..2).map(|j| plan.column(j).sum()).collect();
for i in 0..2 {
assert!(
(row_sum[i] - a[i]).abs() < 0.1,
"row_sum[{}] = {} should be ~{}",
i,
row_sum[i],
a[i]
);
assert!(
(col_sum[i] - b[i]).abs() < 0.1,
"col_sum[{}] = {} should be ~{}",
i,
col_sum[i],
b[i]
);
}
}
// Zero and negative regularisation parameters are rejected.
#[test]
fn test_sinkhorn_invalid_reg() {
let a = Array1::from_vec(vec![0.5, 0.5]);
let b = Array1::from_vec(vec![0.5, 0.5]);
let cost = Array2::zeros((2, 2));
assert!(sinkhorn(&a.view(), &b.view(), &cost, 0.0, 10).is_err());
assert!(sinkhorn(&a.view(), &b.view(), &cost, -1.0, 10).is_err());
}
// The entropic distance must be finite and non-negative.
#[test]
fn test_sinkhorn_distance_non_negative() {
let a = Array1::from_vec(vec![0.2, 0.3, 0.5]);
let b = Array1::from_vec(vec![0.1, 0.5, 0.4]);
let cost = Array2::from_shape_fn((3, 3), |(i, j)| (i as f64 - j as f64).powi(2));
let dist = sinkhorn_distance(&a.view(), &b.view(), &cost, 0.05).expect("sinkhorn_distance failed");
assert!(dist >= 0.0, "Distance must be non-negative, got {}", dist);
assert!(dist.is_finite(), "Distance must be finite");
}
// Identical marginals with a zero-diagonal cost and small reg give a near-zero cost.
#[test]
fn test_sinkhorn_distance_identical_zero() {
let a = Array1::from_vec(vec![0.5, 0.5]);
let b = Array1::from_vec(vec![0.5, 0.5]);
let cost = Array2::from_shape_fn((2, 2), |(i, j)| (i as f64 - j as f64).powi(2));
let dist = sinkhorn_distance(&a.view(), &b.view(), &cost, 0.01).expect("sinkhorn_distance failed");
assert!(dist < 0.01, "Identical distributions: small distance, got {}", dist);
}
// A unit square versus the same square translated by (2, 2): positive, finite SW.
#[test]
fn test_sliced_wasserstein_basic() {
let x = Array2::from_shape_vec(
(4, 2),
vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
)
.expect("Failed");
let y = Array2::from_shape_vec(
(4, 2),
vec![2.0, 2.0, 3.0, 2.0, 2.0, 3.0, 3.0, 3.0],
)
.expect("Failed");
let dist = sliced_wasserstein(&x.view(), &y.view(), 50, 42).expect("sliced_wasserstein failed");
assert!(dist > 0.0, "Non-identical clouds should have positive SW distance");
assert!(dist.is_finite());
}
// A cloud compared with itself has zero sliced distance for every projection.
#[test]
fn test_sliced_wasserstein_identical_zero() {
let x = Array2::from_shape_vec(
(3, 2),
vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
)
.expect("Failed");
let dist = sliced_wasserstein(&x.view(), &x.view(), 50, 0).expect("sliced_wasserstein failed");
assert!(dist.abs() < 1e-10, "Identical clouds: SW=0, got {}", dist);
}
// Clouds with different feature counts are rejected.
#[test]
fn test_sliced_wasserstein_dim_mismatch() {
let x = Array2::zeros((3, 2));
let y = Array2::zeros((3, 3));
assert!(sliced_wasserstein(&x.view(), &y.view(), 10, 0).is_err());
}
// Two point masses: checks length, unit total mass and non-negativity only; it does
// not assert where the mass ends up.
#[test]
fn test_wasserstein_barycenter_midpoint() {
let d1 = Array1::from_vec(vec![0.0, 1.0, 0.0, 0.0]);
let d2 = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0]);
let dists = vec![d1.view(), d2.view()];
let weights = vec![0.5, 0.5];
let bary =
wasserstein_barycenter(&dists, &weights, 0.02, 200).expect("barycenter failed");
assert_eq!(bary.len(), 4);
let sum: f64 = bary.iter().sum();
assert!((sum - 1.0).abs() < 1e-6, "Barycenter must sum to 1, got {}", sum);
for &v in bary.iter() {
assert!(v >= -1e-10, "Barycenter must be non-negative, got {}", v);
}
}
// A single input distribution still yields a unit-mass histogram of the same length.
#[test]
fn test_wasserstein_barycenter_single() {
let d1 = Array1::from_vec(vec![0.2, 0.5, 0.3]);
let dists = vec![d1.view()];
let weights = vec![1.0];
let bary = wasserstein_barycenter(&dists, &weights, 0.01, 100).expect("barycenter failed");
assert_eq!(bary.len(), 3);
let sum: f64 = bary.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
}
// Two weights for one distribution are rejected.
#[test]
fn test_wasserstein_barycenter_weight_mismatch() {
let d1 = Array1::from_vec(vec![0.5, 0.5]);
let dists = vec![d1.view()];
let weights = vec![0.5, 0.5]; assert!(wasserstein_barycenter(&dists, &weights, 0.1, 10).is_err());
}
// Mapped images have the target dimension and finite coordinates.
#[test]
fn test_ot_plan_to_transport_map_basic() {
let plan = Array2::from_shape_vec((2, 3), vec![0.5, 0.3, 0.2, 0.1, 0.6, 0.3]).expect("Failed");
let targets = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).expect("Failed");
let map_fn = ot_plan_to_transport_map(&plan, &targets).expect("transport map failed");
let img0 = map_fn(0);
let img1 = map_fn(1);
assert_eq!(img0.len(), 2);
assert_eq!(img1.len(), 2);
for &v in img0.iter().chain(img1.iter()) {
assert!(v.is_finite());
}
}
// Plan columns (3) must equal target rows (4), otherwise an error is returned.
#[test]
fn test_ot_plan_transport_map_dimension_error() {
let plan = Array2::zeros((2, 3));
let targets = Array2::zeros((4, 2)); assert!(ot_plan_to_transport_map(&plan, &targets).is_err());
}
// With p = 2 the cost is the squared Euclidean distance: (0,0)-(1,0) → 1 and (1,0)-(1,0) → 0.
#[test]
fn test_pairwise_cost_matrix() {
let x = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 0.0]).expect("Failed");
let y = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 2.0, 0.0]).expect("Failed");
let cost = pairwise_cost_matrix(&x.view(), &y.view(), 2.0).expect("cost failed");
assert!((cost[[0, 0]] - 1.0).abs() < 1e-10);
assert!((cost[[1, 0]] - 0.0).abs() < 1e-10);
}
}