use crate::error::RagDriftError;
use crate::Result;
use ndarray::{Array2, ArrayView2, Axis};
use rand::distributions::Distribution;
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::Normal;
/// Computes the 1-Wasserstein (earth mover's) distance between two empirical
/// one-dimensional samples.
///
/// Each slice is treated as a uniform discrete distribution over its values.
/// When both samples have the same length the distance is the mean absolute
/// difference of the sorted values. Otherwise it is the integral of the
/// absolute difference between the two empirical CDFs, evaluated piecewise
/// over the gaps between consecutive values of the pooled, sorted sample.
///
/// Non-finite inputs are not rejected; a NaN in either sample yields a NaN
/// result rather than a panic.
///
/// # Errors
///
/// Returns [`RagDriftError::InsufficientSamples`] when either slice is empty.
pub fn wasserstein_1d(a: &[f64], b: &[f64]) -> Result<f64> {
    if a.is_empty() || b.is_empty() {
        return Err(RagDriftError::InsufficientSamples {
            required: 1,
            got: a.len().min(b.len()),
            context: "wasserstein_1d",
        });
    }

    // `total_cmp` is a genuine total order (NaN included), so the sort can
    // never observe an inconsistent comparator.
    let mut a_sorted: Vec<f64> = a.to_vec();
    let mut b_sorted: Vec<f64> = b.to_vec();
    a_sorted.sort_unstable_by(f64::total_cmp);
    b_sorted.sort_unstable_by(f64::total_cmp);

    let n = a_sorted.len();
    let m = b_sorted.len();

    // Equal sizes: the optimal coupling pairs the order statistics one-to-one.
    if n == m {
        let sum: f64 = a_sorted
            .iter()
            .zip(b_sorted.iter())
            .map(|(x, y)| (x - y).abs())
            .sum();
        return Ok(sum / n as f64);
    }

    let n_f = n as f64;
    let m_f = m as f64;

    // Pooled support points; both halves are already sorted runs.
    let mut all: Vec<f64> = Vec::with_capacity(n + m);
    all.extend_from_slice(&a_sorted);
    all.extend_from_slice(&b_sorted);
    all.sort_by(f64::total_cmp);

    let mut i = 0usize;
    let mut j = 0usize;
    let mut total = 0.0_f64;
    for pair in all.windows(2) {
        let (lo, hi) = (pair[0], pair[1]);
        // Advance the counters *before* weighting the gap so that `i / n` and
        // `j / m` are the empirical CDFs at `lo`, which is their constant
        // value on [lo, hi). Doing this afterwards would drop the first gap.
        while i < n && a_sorted[i] <= lo {
            i += 1;
        }
        while j < m && b_sorted[j] <= lo {
            j += 1;
        }
        let fa = i as f64 / n_f;
        let fb = j as f64 / m_f;
        total += (fa - fb).abs() * (hi - lo);
    }
    Ok(total)
}
/// Estimates the sliced Wasserstein distance between two sets of vectors.
///
/// Rows of `x` and `y` are treated as samples and columns as coordinates.
/// `n_projections` unit directions are drawn by `sample_unit_directions`
/// using `seed`; both inputs are projected onto each direction, the
/// one-dimensional distance of the projections is computed with
/// `wasserstein_1d`, and the mean over all directions is returned.
///
/// # Errors
///
/// * [`RagDriftError::DimensionMismatch`] if the column counts differ.
/// * [`RagDriftError::InsufficientSamples`] if either input has no rows.
/// * [`RagDriftError::InvalidConfig`] if `n_projections` is zero.
/// * Any error returned by `wasserstein_1d` for a projection.
pub fn sliced_wasserstein(
    x: ArrayView2<f32>,
    y: ArrayView2<f32>,
    n_projections: usize,
    seed: u64,
) -> Result<f64> {
    let dim = x.ncols();
    if dim != y.ncols() {
        return Err(RagDriftError::DimensionMismatch {
            expected: dim,
            actual: y.ncols(),
            context: "sliced_wasserstein",
        });
    }

    // The smaller row count is zero exactly when at least one input is empty.
    let fewest_rows = x.nrows().min(y.nrows());
    if fewest_rows == 0 {
        return Err(RagDriftError::InsufficientSamples {
            required: 1,
            got: fewest_rows,
            context: "sliced_wasserstein",
        });
    }

    if n_projections == 0 {
        return Err(RagDriftError::InvalidConfig(
            "sliced_wasserstein: n_projections must be > 0".into(),
        ));
    }

    // One direction per row of the returned matrix.
    let directions = sample_unit_directions(dim, n_projections, seed);
    let summed = directions
        .axis_iter(Axis(0))
        .try_fold(0.0_f64, |acc, direction| {
            let weights = direction.as_slice().unwrap();
            let x_proj: Vec<f64> = project(&x, weights);
            let y_proj: Vec<f64> = project(&y, weights);
            wasserstein_1d(&x_proj, &y_proj).map(|w| acc + w)
        })?;

    Ok(summed / n_projections as f64)
}
/// Projects every row of `x` onto `dir`.
///
/// Returns one dot product per row, with the products widened to and
/// accumulated in `f64`. Elements are paired with `zip`, so if a row and
/// `dir` differ in length the extra elements of the longer one are ignored.
fn project(x: &ArrayView2<f32>, dir: &[f32]) -> Vec<f64> {
    let mut projected = Vec::with_capacity(x.nrows());
    for sample in x.axis_iter(Axis(0)) {
        let dot: f64 = sample
            .iter()
            .zip(dir)
            .map(|(&coord, &weight)| f64::from(coord) * f64::from(weight))
            .sum();
        projected.push(dot);
    }
    projected
}
/// Draws `n` pseudo-random directions in `dim` dimensions, one per row.
///
/// Every coordinate is sampled from a standard normal distribution using a
/// `StdRng` seeded with `seed`, then each row is divided by its Euclidean
/// norm. The norm is clamped from below by `f32::EPSILON` so the division
/// never uses a zero divisor.
fn sample_unit_directions(dim: usize, n: usize, seed: u64) -> Array2<f32> {
    let mut rng = StdRng::seed_from_u64(seed);
    let gaussian = Normal::new(0.0_f32, 1.0_f32).unwrap();
    let mut directions = Array2::<f32>::zeros((n, dim));

    for mut direction in directions.axis_iter_mut(Axis(0)) {
        // Fill the row and accumulate its squared length in a single pass.
        let mut squared_len = 0.0_f32;
        direction.iter_mut().for_each(|slot| {
            let draw = gaussian.sample(&mut rng);
            *slot = draw;
            squared_len += draw * draw;
        });

        let scale = squared_len.sqrt().max(f32::EPSILON);
        direction.iter_mut().for_each(|slot| *slot /= scale);
    }

    directions
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// A sample compared against itself has exactly zero distance.
    #[test]
    fn identical_samples_have_w_zero() {
        let sample: Vec<f64> = (0..100).map(|k| k as f64).collect();
        let w = wasserstein_1d(&sample, &sample).unwrap();
        assert_eq!(w, 0.0);
    }

    /// Shifting a sample by one unit gives a distance of about one.
    #[test]
    fn unit_uniform_shift_has_w_one() {
        let base: Vec<f64> = (0..1000).map(|k| k as f64 / 1000.0).collect();
        let shifted: Vec<f64> = base.iter().map(|v| v + 1.0).collect();
        let w = wasserstein_1d(&base, &shifted).unwrap();
        let err = (w - 1.0).abs();
        assert!(err < 1e-3, "w was {}", w);
    }

    /// Samples of different sizes go through the CDF-based branch.
    #[test]
    fn unequal_lengths_handled() {
        let small: Vec<f64> = (0..50).map(|k| k as f64 / 50.0).collect();
        let large: Vec<f64> = (0..200).map(|k| k as f64 / 200.0 + 1.0).collect();
        let w = wasserstein_1d(&small, &large).unwrap();
        let err = (w - 1.0).abs();
        assert!(err < 0.05, "w was {}", w);
    }

    /// Identical point clouds have a sliced distance of (nearly) zero.
    #[test]
    fn sliced_w_identical_near_zero() {
        let cloud: Array2<f32> = Array2::ones((50, 8));
        let s = sliced_wasserstein(cloud.view(), cloud.view(), 32, 0).unwrap();
        assert!(s.abs() < 1e-6, "s was {}", s);
    }

    /// A constant shift between clouds yields a clearly positive distance.
    #[test]
    fn sliced_w_shift_positive() {
        let origin: Array2<f32> = Array2::zeros((100, 8));
        let shifted: Array2<f32> = Array2::from_elem((100, 8), 2.0);
        let s = sliced_wasserstein(origin.view(), shifted.view(), 64, 7).unwrap();
        assert!(s > 0.5, "s was {}", s);
    }
}