ragdrift-core 0.1.4

Pure-Rust core for ragdrift: 5-dimensional drift detection for RAG systems.
Documentation
//! Wasserstein-1 distance.
//!
//! - [`wasserstein_1d`] computes the exact 1D Wasserstein-1 between two
//!   empirical distributions of arbitrary (possibly different) sizes.
//! - [`sliced_wasserstein`] estimates the sliced Wasserstein-1 distance, a
//!   tractable proxy for Wasserstein-1 in higher dimensions, by averaging the
//!   1D Wasserstein-1 over random projections (Bonneel et al., "Sliced and
//!   Radon Wasserstein Barycenters of Measures", JMIV 2015).

use ndarray::{Array1, ArrayView1, ArrayView2, Axis};
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};

use crate::error::{RagDriftError, Result};
use crate::types::{check_min_samples, check_same_cols};

/// Exact 1D Wasserstein-1 distance between two empirical distributions.
///
/// Computes `W1(P, Q) = integral |F_P(x) - F_Q(x)| dx`, where `F_P` and `F_Q`
/// are the empirical CDFs of `a` and `b`. Runs in `O(n log n + m log m)` time
/// (dominated by sorting) followed by a linear merge. Sample sizes need not
/// match.
///
/// # Errors
///
/// - Propagates the error from `check_min_samples(len, 1)` for either input.
/// - Returns `RagDriftError::NumericalInstability` if any value in `a` or `b`
///   is NaN or infinite.
///
/// ```rust
/// use ndarray::Array1;
/// use ragdrift_core::stats::wasserstein_1d;
///
/// let a = Array1::from(vec![0.0, 0.5, 1.0]);
/// let b = Array1::from(vec![1.0, 1.5, 2.0]);
/// let w = wasserstein_1d(&a.view(), &b.view()).unwrap();
/// assert!((w - 1.0).abs() < 1e-9);
/// ```
pub fn wasserstein_1d(a: &ArrayView1<'_, f64>, b: &ArrayView1<'_, f64>) -> Result<f64> {
    check_min_samples(a.len(), 1)?;
    check_min_samples(b.len(), 1)?;
    if a.iter().chain(b.iter()).any(|x| !x.is_finite()) {
        return Err(RagDriftError::NumericalInstability {
            step: "wasserstein_1d".into(),
            reason: "non-finite input".into(),
        });
    }

    // All values are finite here, so `total_cmp` agrees with numeric order
    // (it only additionally places -0.0 before 0.0, which cannot change the
    // integral) and, unlike `partial_cmp(..).unwrap()`, has no panic path.
    let mut sa: Vec<f64> = a.iter().copied().collect();
    let mut sb: Vec<f64> = b.iter().copied().collect();
    sa.sort_unstable_by(f64::total_cmp);
    sb.sort_unstable_by(f64::total_cmp);

    let n = sa.len() as f64;
    let m = sb.len() as f64;
    let (mut i, mut j) = (0_usize, 0_usize);
    let mut prev = sa[0].min(sb[0]);
    let mut total = 0.0_f64;
    // Sweep the merged breakpoints left to right. On [prev, next) both
    // empirical CDFs are constant: F_a = i / n and F_b = j / m, where i and j
    // count the samples <= prev.
    loop {
        let next = match (sa.get(i), sb.get(j)) {
            (Some(&x), Some(&y)) => x.min(y),
            (Some(&x), None) => x,
            (None, Some(&y)) => y,
            // Both samples consumed: both CDFs equal 1 from here on.
            (None, None) => break,
        };
        let fa = i as f64 / n;
        let fb = j as f64 / m;
        total += (fa - fb).abs() * (next - prev);
        // Consume every sample (including ties) sitting at this breakpoint.
        while i < sa.len() && sa[i] <= next {
            i += 1;
        }
        while j < sb.len() && sb[j] <= next {
            j += 1;
        }
        prev = next;
    }
    Ok(total)
}

/// Sliced Wasserstein-1 distance between two row-wise point clouds.
///
/// Draws `n_projections` random unit directions, projects every row of
/// `baseline` and `current` onto each one, and returns the mean of the
/// resulting exact 1D distances computed by [`wasserstein_1d`]. Directions come
/// from a generator seeded with `seed`, so equal seeds give equal results.
///
/// # Errors
///
/// - Propagates errors from `check_same_cols` and `check_min_samples`.
/// - Returns `RagDriftError::InvalidConfig` when `n_projections == 0`.
/// - Propagates any error from [`wasserstein_1d`] on a projection, for example
///   when the inputs contain non-finite values.
///
/// ```rust
/// use ndarray::Array2;
/// use ragdrift_core::stats::sliced_wasserstein;
///
/// let a = Array2::<f32>::zeros((32, 16));
/// let b = Array2::<f32>::zeros((32, 16));
/// let w = sliced_wasserstein(&a.view(), &b.view(), 32, 0).unwrap();
/// assert!(w.abs() < 1e-9);
/// ```
pub fn sliced_wasserstein(
    baseline: &ArrayView2<'_, f32>,
    current: &ArrayView2<'_, f32>,
    n_projections: usize,
    seed: u64,
) -> Result<f64> {
    check_same_cols(baseline, current)?;
    check_min_samples(baseline.nrows(), 1)?;
    check_min_samples(current.nrows(), 1)?;
    if n_projections == 0 {
        return Err(RagDriftError::InvalidConfig(
            "n_projections must be > 0".into(),
        ));
    }

    let n_cols = baseline.ncols();
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

    let sum = (0..n_projections).try_fold(0.0_f64, |acc, _| -> Result<f64> {
        // An isotropic Gaussian draw, once L2-normalised, is uniformly
        // distributed on the unit sphere.
        let mut direction = Array1::<f64>::zeros(n_cols);
        let mut sq_len = 0.0_f64;
        for slot in direction.iter_mut() {
            let draw: f64 = StandardNormal.sample(&mut rng);
            *slot = draw;
            sq_len += draw * draw;
        }
        // Floor the length so a degenerate all-zero draw cannot divide by zero.
        let len = sq_len.sqrt().max(1e-12);
        direction.mapv_inplace(|c| c / len);

        let proj_base = project(baseline, &direction);
        let proj_curr = project(current, &direction);
        Ok(acc + wasserstein_1d(&proj_base.view(), &proj_curr.view())?)
    })?;

    Ok(sum / n_projections as f64)
}

/// Projects each row of `matrix` onto the direction `u`.
///
/// Returns one value per row: the dot product of that row, with each element
/// widened from `f32` to `f64`, with `u`. Products are accumulated left to
/// right in `f64`. If the lengths differ, only the first `min(ncols, u.len())`
/// elements of each row contribute.
fn project(matrix: &ArrayView2<'_, f32>, u: &Array1<f64>) -> Array1<f64> {
    matrix
        .axis_iter(Axis(0))
        .map(|row| {
            row.iter()
                .zip(u.iter())
                .fold(0.0_f64, |acc, (&x, &w)| acc + f64::from(x) * w)
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array1, Array2};

    #[test]
    fn analytic_case_uniform_shift() {
        // Evenly spaced grid on [0, 1) shifted by +1: every atom moves by
        // exactly 1, so W1 = 1 (up to floating-point rounding).
        let n = 1024;
        let a = Array1::from((0..n).map(|i| i as f64 / n as f64).collect::<Vec<_>>());
        let b = a.mapv(|x| x + 1.0);
        let w = wasserstein_1d(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(w, 1.0, epsilon = 1e-3);
    }

    #[test]
    fn identical_samples_zero() {
        let a = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let w = wasserstein_1d(&a.view(), &a.view()).unwrap();
        assert_eq!(w, 0.0);
    }

    #[test]
    fn two_point_masses() {
        // Single point at 0 vs single point at 5: W1 = 5.
        let a = Array1::from(vec![0.0]);
        let b = Array1::from(vec![5.0]);
        let w = wasserstein_1d(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(w, 5.0, epsilon = 1e-12);
    }

    #[test]
    fn unequal_sample_sizes() {
        // Half the mass at 0 and half at 1 vs all of it at 0.5: each half
        // moves 0.5, so W1 = 0.5. Exercises the n != m CDF weighting.
        let a = Array1::from(vec![0.0, 1.0]);
        let b = Array1::from(vec![0.5]);
        let w = wasserstein_1d(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(w, 0.5, epsilon = 1e-12);
    }

    #[test]
    fn symmetric_in_arguments() {
        let a = Array1::from(vec![0.0, 0.3, 2.0, 2.5]);
        let b = Array1::from(vec![1.0, 4.0, -1.0]);
        let ab = wasserstein_1d(&a.view(), &b.view()).unwrap();
        let ba = wasserstein_1d(&b.view(), &a.view()).unwrap();
        assert_abs_diff_eq!(ab, ba, epsilon = 1e-12);
    }

    #[test]
    fn rejects_non_finite() {
        let good = Array1::from(vec![0.0, 1.0]);
        let nan = Array1::from(vec![0.0, f64::NAN]);
        let inf = Array1::from(vec![f64::INFINITY]);
        assert!(wasserstein_1d(&good.view(), &nan.view()).is_err());
        assert!(wasserstein_1d(&inf.view(), &good.view()).is_err());
    }

    #[test]
    fn sliced_zero_for_identical_matrix() {
        let a = Array2::<f32>::zeros((32, 8));
        let b = a.clone();
        let w = sliced_wasserstein(&a.view(), &b.view(), 16, 0).unwrap();
        assert_abs_diff_eq!(w, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn sliced_increases_with_shift() {
        use ndarray_rand::rand_distr::StandardNormal;
        use ndarray_rand::RandomExt;
        let a = Array2::<f32>::random((128, 16), StandardNormal);
        let mut b = a.clone();
        b.mapv_inplace(|v| v + 2.0);
        let w0 = sliced_wasserstein(&a.view(), &a.view(), 32, 0).unwrap();
        let w1 = sliced_wasserstein(&a.view(), &b.view(), 32, 0).unwrap();
        assert!(w1 > w0);
        assert!(w1 > 0.5);
    }

    #[test]
    fn sliced_is_reproducible_for_fixed_seed() {
        use ndarray_rand::rand_distr::StandardNormal;
        use ndarray_rand::RandomExt;
        let a = Array2::<f32>::random((64, 8), StandardNormal);
        let b = a.mapv(|v| v * 1.5);
        let w1 = sliced_wasserstein(&a.view(), &b.view(), 16, 7).unwrap();
        let w2 = sliced_wasserstein(&a.view(), &b.view(), 16, 7).unwrap();
        assert_eq!(w1, w2);
    }

    #[test]
    fn sliced_rejects_non_finite() {
        let mut a = Array2::<f32>::zeros((4, 3));
        a[[1, 2]] = f32::NAN;
        let b = Array2::<f32>::zeros((4, 3));
        assert!(sliced_wasserstein(&a.view(), &b.view(), 4, 0).is_err());
    }

    #[test]
    fn rejects_zero_projections() {
        let a = Array2::<f32>::zeros((4, 4));
        let b = Array2::<f32>::zeros((4, 4));
        assert!(sliced_wasserstein(&a.view(), &b.view(), 0, 0).is_err());
    }
}