use ndarray::{Array1, ArrayView1, ArrayView2, Axis};
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};
use crate::error::{RagDriftError, Result};
use crate::types::{check_min_samples, check_same_cols};
/// Exact 1-Wasserstein (earth mover's) distance between two empirical 1-D samples.
///
/// Each input is treated as a uniform empirical distribution over its values; the two
/// samples may have different sizes. The distance is the integral over `x` of
/// `|F_a(x) - F_b(x)|`, where `F_a` and `F_b` are the empirical CDFs. It is evaluated exactly
/// by sorting both samples and sweeping the merged breakpoints; between two consecutive
/// breakpoints both CDFs are constant, so each interval contributes `|F_a - F_b| * width`.
///
/// # Errors
///
/// - Whatever `check_min_samples(len, 1)` reports when a sample is too small (presumably
///   empty — TODO confirm against `crate::types`).
/// - `RagDriftError::NumericalInstability` with reason `"non-finite input"` if any value in
///   either sample is NaN or infinite.
/// - `RagDriftError::NumericalInstability` with reason `"non-finite result"` if the distance
///   overflows `f64`. This can only happen when the samples span a range near `f64::MAX`.
pub fn wasserstein_1d(a: &ArrayView1<'_, f64>, b: &ArrayView1<'_, f64>) -> Result<f64> {
    check_min_samples(a.len(), 1)?;
    check_min_samples(b.len(), 1)?;
    if a.iter().chain(b.iter()).any(|x| !x.is_finite()) {
        return Err(RagDriftError::NumericalInstability {
            step: "wasserstein_1d".into(),
            reason: "non-finite input".into(),
        });
    }

    let mut sa: Vec<f64> = a.to_vec();
    let mut sb: Vec<f64> = b.to_vec();
    // All values are finite here, so `total_cmp` agrees with numeric order. It only
    // distinguishes -0.0 from 0.0, which cannot change the integral.
    sa.sort_unstable_by(f64::total_cmp);
    sb.sort_unstable_by(f64::total_cmp);

    let n = sa.len() as f64;
    let m = sb.len() as f64;
    // `i` / `j` count how many values of each sample are <= `prev`, so
    // `i / n` and `j / m` are the CDF values on the interval [prev, next).
    let mut i = 0usize;
    let mut j = 0usize;
    // Both samples are non-empty (checked above), so indexing [0] is safe.
    let mut prev = sa[0].min(sb[0]);
    let mut total = 0.0_f64;

    loop {
        // The next breakpoint is the smallest value not yet consumed from either sample.
        let next = match (sa.get(i), sb.get(j)) {
            (Some(&x), Some(&y)) => x.min(y),
            (Some(&x), None) | (None, Some(&x)) => x,
            (None, None) => break,
        };
        let gap = (i as f64 / n - j as f64 / m).abs();
        // Intervals where the CDFs agree contribute nothing. Skipping them also avoids
        // `0 * inf = NaN` when `next - prev` overflows for extremely spread-out inputs.
        if gap > 0.0 {
            total += gap * (next - prev);
        }
        // Consume every value equal to the breakpoint (handles ties and duplicates).
        while i < sa.len() && sa[i] <= next {
            i += 1;
        }
        while j < sb.len() && sb[j] <= next {
            j += 1;
        }
        prev = next;
    }

    if !total.is_finite() {
        return Err(RagDriftError::NumericalInstability {
            step: "wasserstein_1d".into(),
            reason: "non-finite result".into(),
        });
    }
    Ok(total)
}
pub fn sliced_wasserstein(
baseline: &ArrayView2<'_, f32>,
current: &ArrayView2<'_, f32>,
n_projections: usize,
seed: u64,
) -> Result<f64> {
check_same_cols(baseline, current)?;
check_min_samples(baseline.nrows(), 1)?;
check_min_samples(current.nrows(), 1)?;
if n_projections == 0 {
return Err(RagDriftError::InvalidConfig(
"n_projections must be > 0".into(),
));
}
let dim = baseline.ncols();
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let normal = StandardNormal;
let mut total = 0.0_f64;
for _ in 0..n_projections {
let mut u = Array1::<f64>::zeros(dim);
let mut norm2 = 0.0_f64;
for k in 0..dim {
let v: f64 = normal.sample(&mut rng);
u[k] = v;
norm2 += v * v;
}
let norm = norm2.sqrt().max(1e-12);
for k in 0..dim {
u[k] /= norm;
}
let pa = project(baseline, &u);
let pb = project(current, &u);
total += wasserstein_1d(&pa.view(), &pb.view())?;
}
Ok(total / n_projections as f64)
}
/// Projects each row of `matrix` onto the direction `u`, returning one dot product per row.
///
/// Row entries are widened from `f32` to `f64` before multiplying. If `u` and a row differ
/// in length, only the overlapping prefix is used (`zip` semantics). `sliced_wasserstein`
/// builds `u` with `baseline.ncols()` entries.
fn project(matrix: &ArrayView2<'_, f32>, u: &Array1<f64>) -> Array1<f64> {
    matrix
        .axis_iter(Axis(0))
        .map(|row| {
            row.iter()
                .zip(u.iter())
                .fold(0.0_f64, |acc, (&x, &w)| acc + f64::from(x) * w)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array1, Array2};

    #[test]
    fn analytic_case_uniform_shift() {
        let n = 1024;
        let a = Array1::from((0..n).map(|i| i as f64 / n as f64).collect::<Vec<_>>());
        let b = a.mapv(|x| x + 1.0);
        let w = wasserstein_1d(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(w, 1.0, epsilon = 1e-3);
    }

    #[test]
    fn identical_samples_zero() {
        let a = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let w = wasserstein_1d(&a.view(), &a.view()).unwrap();
        assert_eq!(w, 0.0);
    }

    // Both samples have one element; the distance is the gap between the two point masses.
    #[test]
    fn equal_size_point_masses() {
        let a = Array1::from(vec![0.0]);
        let b = Array1::from(vec![5.0]);
        let w = wasserstein_1d(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(w, 5.0, epsilon = 1e-12);
    }

    #[test]
    fn unequal_size_samples() {
        // F_a steps 0 -> 0.5 at x = 0 and 0.5 -> 1 at x = 1, while F_b steps 0 -> 1 at x = 0.
        // |F_a - F_b| = 0.5 on [0, 1), so the distance is 0.5 in either argument order.
        let a = Array1::from(vec![0.0, 1.0]);
        let b = Array1::from(vec![0.0]);
        let w = wasserstein_1d(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(w, 0.5, epsilon = 1e-12);
        let w_rev = wasserstein_1d(&b.view(), &a.view()).unwrap();
        assert_abs_diff_eq!(w_rev, 0.5, epsilon = 1e-12);
    }

    #[test]
    fn rejects_non_finite_input() {
        let finite = Array1::from(vec![0.0, 1.0]);
        let with_nan = Array1::from(vec![0.0, f64::NAN]);
        let with_inf = Array1::from(vec![f64::INFINITY]);
        assert!(wasserstein_1d(&with_nan.view(), &finite.view()).is_err());
        assert!(wasserstein_1d(&finite.view(), &with_inf.view()).is_err());
    }

    #[test]
    fn sliced_zero_for_identical_matrix() {
        let a = Array2::<f32>::zeros((32, 8));
        let b = a.clone();
        let w = sliced_wasserstein(&a.view(), &b.view(), 16, 0).unwrap();
        assert_abs_diff_eq!(w, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn sliced_increases_with_shift() {
        use ndarray_rand::rand_distr::StandardNormal;
        use ndarray_rand::RandomExt;
        let a = Array2::<f32>::random((128, 16), StandardNormal);
        let mut b = a.clone();
        b.mapv_inplace(|v| v + 2.0);
        let w0 = sliced_wasserstein(&a.view(), &a.view(), 32, 0).unwrap();
        let w1 = sliced_wasserstein(&a.view(), &b.view(), 32, 0).unwrap();
        assert!(w1 > w0);
        assert!(w1 > 0.5);
    }

    #[test]
    fn sliced_is_reproducible_for_fixed_seed() {
        let a = Array2::from_shape_fn((16, 4), |(i, j)| (i * 4 + j) as f32 * 0.1);
        let b = a.mapv(|v| v * 0.5 + 1.0);
        let first = sliced_wasserstein(&a.view(), &b.view(), 8, 42).unwrap();
        let second = sliced_wasserstein(&a.view(), &b.view(), 8, 42).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn rejects_zero_projections() {
        let a = Array2::<f32>::zeros((4, 4));
        let b = Array2::<f32>::zeros((4, 4));
        assert!(sliced_wasserstein(&a.view(), &b.view(), 0, 0).is_err());
    }
}