ragdrift-core 0.1.4

Pure-Rust core for ragdrift: 5-dimensional drift detection for RAG systems.
Documentation
//! Maximum Mean Discrepancy with an RBF (Gaussian) kernel.
//!
//! MMD compares two empirical distributions by their mean embeddings in a
//! reproducing kernel Hilbert space. The Gaussian kernel,
//! `k(x, y) = exp(-||x - y||^2 / (2 * h^2))`, is the standard choice; the
//! bandwidth `h` defaults to the median pairwise distance over a random
//! subsample of the combined data (the median heuristic).
//!
//! Reference: Gretton et al., "A Kernel Two-Sample Test", JMLR 2012.

use ndarray::{ArrayView2, Axis};
use rand::seq::IndexedRandom;
use rand::SeedableRng;

use crate::error::{RagDriftError, Result};
use crate::types::{check_min_samples, check_same_cols};

/// Selects which MMD^2 estimator [`mmd_rbf`] uses.
///
/// The default is [`MmdEstimator::Unbiased`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MmdEstimator {
    /// Biased (V-statistic) estimator that keeps the diagonal `k(x_i, x_i)`
    /// terms. Slightly cheaper and never negative.
    Biased,
    /// Unbiased (U-statistic) estimator that drops the diagonal terms, so it
    /// may come out slightly below zero when both samples share a distribution.
    #[default]
    Unbiased,
}

/// Compute MMD^2 with an RBF kernel and the median-heuristic bandwidth.
///
/// Rows of `baseline` and `current` are observations and columns are
/// features. The kernel is `k(x, y) = exp(-||x - y||^2 / (2 * h^2))`, where
/// `h` is the median pairwise distance over a seeded random subsample of
/// both inputs combined.
///
/// `seed` controls subsampling for the bandwidth estimate, so repeat runs are
/// deterministic. Both samples must have at least 2 rows and matching column
/// count.
///
/// # Errors
///
/// - An error from `check_same_cols` when the column counts differ.
/// - An error from `check_min_samples` when either sample has fewer than 2 rows.
/// - Any error produced while estimating the bandwidth.
/// - [`RagDriftError::NumericalInstability`] when the resulting MMD^2 is not
///   finite (for example because the inputs contain NaN or infinity).
///
/// ```rust
/// use ndarray::Array2;
/// use ragdrift_core::stats::{mmd_rbf, MmdEstimator};
///
/// let a = Array2::<f32>::zeros((32, 8));
/// let b = Array2::<f32>::zeros((32, 8));
/// let v = mmd_rbf(&a.view(), &b.view(), MmdEstimator::Biased, 0).unwrap();
/// assert!(v.abs() < 1e-6);
/// ```
pub fn mmd_rbf(
    baseline: &ArrayView2<'_, f32>,
    current: &ArrayView2<'_, f32>,
    estimator: MmdEstimator,
    seed: u64,
) -> Result<f64> {
    check_same_cols(baseline, current)?;
    check_min_samples(baseline.nrows(), 2)?;
    check_min_samples(current.nrows(), 2)?;

    let bandwidth = median_pairwise_bandwidth(baseline, current, seed)?;
    // Defensive guard: a non-positive bandwidth would make the kernel undefined.
    if bandwidth <= 0.0 {
        return Ok(0.0);
    }
    let inv_two_h2 = 1.0 / (2.0 * bandwidth * bandwidth);

    // The unbiased estimator drops the i == j terms of the within-sample sums.
    let unbiased = estimator == MmdEstimator::Unbiased;
    let kxx = sum_kernel(baseline, baseline, inv_two_h2, unbiased);
    let kyy = sum_kernel(current, current, inv_two_h2, unbiased);
    let kxy = sum_kernel(baseline, current, inv_two_h2, false);

    // Form the pair counts in f64. Multiplying in `usize` overflows on 32-bit
    // targets once a sample exceeds 65_535 rows.
    let n = baseline.nrows() as f64;
    let m = current.nrows() as f64;
    let (xx_pairs, yy_pairs) = match estimator {
        MmdEstimator::Biased => (n * n, m * m),
        MmdEstimator::Unbiased => (n * (n - 1.0), m * (m - 1.0)),
    };
    let mmd2 = kxx / xx_pairs + kyy / yy_pairs - 2.0 * kxy / (n * m);

    if !mmd2.is_finite() {
        return Err(RagDriftError::NumericalInstability {
            step: "mmd".into(),
            reason: "non-finite mmd^2".into(),
        });
    }
    Ok(mmd2)
}

/// Sum the RBF kernel `exp(-||x_i - y_j||^2 * inv_two_h2)` over every row
/// pair `(x_i, y_j)` drawn from `xs` and `ys`.
///
/// `inv_two_h2` is `1 / (2 h^2)` for bandwidth `h`. When `exclude_diagonal`
/// is set, pairs with equal row indices are skipped; that only makes sense
/// when `xs` and `ys` are the same matrix. Rows of `xs` are handled in
/// parallel with rayon, and each row's contributions are added in column order.
#[cfg(feature = "parallel")]
fn sum_kernel(
    xs: &ArrayView2<'_, f32>,
    ys: &ArrayView2<'_, f32>,
    inv_two_h2: f64,
    exclude_diagonal: bool,
) -> f64 {
    use rayon::prelude::*;

    xs.axis_iter(Axis(0))
        .into_par_iter()
        .enumerate()
        .map(|(row, lhs)| {
            ys.axis_iter(Axis(0))
                .enumerate()
                .filter(|&(col, _)| !(exclude_diagonal && row == col))
                .fold(0.0_f64, |acc, (_, rhs)| {
                    // Squared Euclidean distance, accumulated in f64.
                    let sq_dist = lhs
                        .iter()
                        .zip(rhs.iter())
                        .fold(0.0_f64, |s, (&p, &q)| {
                            let delta = f64::from(p) - f64::from(q);
                            s + delta * delta
                        });
                    acc + (-sq_dist * inv_two_h2).exp()
                })
        })
        .sum()
}

/// Sum the RBF kernel `exp(-||x_i - y_j||^2 * inv_two_h2)` over every row
/// pair `(x_i, y_j)` drawn from `xs` and `ys`, on a single thread.
///
/// `inv_two_h2` is `1 / (2 h^2)` for bandwidth `h`. When `exclude_diagonal`
/// is set, pairs with equal row indices are skipped; that only makes sense
/// when `xs` and `ys` are the same matrix.
#[cfg(not(feature = "parallel"))]
fn sum_kernel(
    xs: &ArrayView2<'_, f32>,
    ys: &ArrayView2<'_, f32>,
    inv_two_h2: f64,
    exclude_diagonal: bool,
) -> f64 {
    // One accumulator is threaded through both folds, so additions happen in
    // plain row-major pair order.
    xs.axis_iter(Axis(0))
        .enumerate()
        .fold(0.0_f64, |total, (row, lhs)| {
            ys.axis_iter(Axis(0))
                .enumerate()
                .filter(|&(col, _)| !(exclude_diagonal && row == col))
                .fold(total, |acc, (_, rhs)| {
                    // Squared Euclidean distance, accumulated in f64.
                    let sq_dist = lhs
                        .iter()
                        .zip(rhs.iter())
                        .fold(0.0_f64, |s, (&p, &q)| {
                            let delta = f64::from(p) - f64::from(q);
                            s + delta * delta
                        });
                    acc + (-sq_dist * inv_two_h2).exp()
                })
        })
}

/// Median pairwise Euclidean distance over a random subsample of the rows of
/// `a` and `b` combined (the "median heuristic" for the RBF bandwidth).
///
/// Up to 64 rows are drawn without replacement with a `StdRng` seeded from
/// `seed`, and at most 500 of their pairwise distances are evaluated. For an
/// even number of distances the upper of the two middle values is returned.
/// The result is clamped to at least `1e-12` so callers can divide by it even
/// when every sampled row is identical.
///
/// # Errors
///
/// - [`RagDriftError::InsufficientSamples`] if fewer than 2 rows are available.
/// - [`RagDriftError::NumericalInstability`] if a sampled distance is not
///   finite (NaN or infinite values in the sampled rows).
fn median_pairwise_bandwidth(
    a: &ArrayView2<'_, f32>,
    b: &ArrayView2<'_, f32>,
    seed: u64,
) -> Result<f64> {
    // Cap the work so the heuristic stays cheap on 10k+ rows.
    const MAX_ROWS: usize = 64;
    const MAX_PAIRS: usize = 500;
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

    let combined: Vec<ndarray::ArrayView1<'_, f32>> =
        a.axis_iter(Axis(0)).chain(b.axis_iter(Axis(0))).collect();
    let total = combined.len();
    if total < 2 {
        return Err(RagDriftError::InsufficientSamples {
            needed: 2,
            got: total,
        });
    }

    let n_pick = total.min(MAX_ROWS);
    let picked: Vec<&ndarray::ArrayView1<'_, f32>> =
        combined.choose_multiple(&mut rng, n_pick).collect();

    let mut dists: Vec<f64> = Vec::with_capacity((n_pick * (n_pick - 1) / 2).min(MAX_PAIRS));
    'pairs: for (i, x) in picked.iter().enumerate() {
        for y in &picked[i + 1..] {
            if dists.len() >= MAX_PAIRS {
                break 'pairs;
            }
            let sq_dist = x.iter().zip(y.iter()).fold(0.0_f64, |s, (&p, &q)| {
                let diff = f64::from(p) - f64::from(q);
                s + diff * diff
            });
            let dist = sq_dist.sqrt();
            // Report non-finite input as an error instead of letting it reach
            // the ordering step, where a NaN used to cause a panic.
            if !dist.is_finite() {
                return Err(RagDriftError::NumericalInstability {
                    step: "mmd_bandwidth".into(),
                    reason: "non-finite pairwise distance in bandwidth subsample".into(),
                });
            }
            dists.push(dist);
        }
    }
    if dists.is_empty() {
        return Ok(0.0);
    }

    // Every entry is finite and non-negative here, so `total_cmp` matches
    // numeric order; selection is O(len) and yields the same element a full
    // sort would place at `mid`.
    let mid = dists.len() / 2;
    let (_, median, _) = dists.select_nth_unstable_by(mid, f64::total_cmp);
    Ok((*median).max(1e-12))
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;

    /// Seeded standard-normal sample, so the statistical thresholds below are
    /// reproducible rather than dependent on the thread RNG.
    fn normal_sample(rows: usize, cols: usize, seed: u64) -> Array2<f32> {
        // Use the `rand` version re-exported by ndarray_rand so the RNG type
        // matches the one `random_using` expects.
        let mut rng =
            <ndarray_rand::rand::rngs::SmallRng as ndarray_rand::rand::SeedableRng>::seed_from_u64(
                seed,
            );
        Array2::random_using((rows, cols), StandardNormal, &mut rng)
    }

    #[test]
    fn identical_arrays_zero_mmd_biased() {
        let a = Array2::<f32>::zeros((16, 4));
        let b = a.clone();
        let v = mmd_rbf(&a.view(), &b.view(), MmdEstimator::Biased, 0).unwrap();
        assert_abs_diff_eq!(v, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn identical_arrays_zero_mmd_unbiased() {
        let a = Array2::<f32>::zeros((16, 4));
        let v = mmd_rbf(&a.view(), &a.view(), MmdEstimator::Unbiased, 0).unwrap();
        assert_abs_diff_eq!(v, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn default_estimator_is_unbiased() {
        assert_eq!(MmdEstimator::default(), MmdEstimator::Unbiased);
    }

    #[test]
    fn same_distribution_small_mmd() {
        // Both drawn from N(0, I_8); MMD^2 should be small relative to a
        // shifted distribution.
        let a = normal_sample(128, 8, 11);
        let b = normal_sample(128, 8, 12);
        let v = mmd_rbf(&a.view(), &b.view(), MmdEstimator::Unbiased, 1).unwrap();
        assert!(v.abs() < 0.05, "expected small MMD^2 under H0, got {v}");
    }

    #[test]
    fn shifted_distribution_larger_mmd() {
        let a = normal_sample(128, 8, 21);
        let mut b = normal_sample(128, 8, 22);
        b.mapv_inplace(|v| v + 2.0);
        let v0 = mmd_rbf(&a.view(), &a.view(), MmdEstimator::Biased, 1).unwrap();
        let v1 = mmd_rbf(&a.view(), &b.view(), MmdEstimator::Biased, 1).unwrap();
        assert!(v1 > v0, "shifted MMD should exceed identical, {v1} vs {v0}");
        assert!(v1 > 0.05);
    }

    #[test]
    fn rejects_dim_mismatch() {
        let a = Array2::<f32>::zeros((4, 4));
        let b = Array2::<f32>::zeros((4, 8));
        assert!(mmd_rbf(&a.view(), &b.view(), MmdEstimator::Biased, 0).is_err());
    }

    #[test]
    fn rejects_too_few_samples() {
        let a = Array2::<f32>::zeros((1, 4));
        let b = Array2::<f32>::zeros((4, 4));
        assert!(mmd_rbf(&a.view(), &b.view(), MmdEstimator::Biased, 0).is_err());
    }

    #[test]
    fn rejects_too_few_current_samples() {
        let a = Array2::<f32>::zeros((4, 4));
        let b = Array2::<f32>::zeros((1, 4));
        assert!(mmd_rbf(&a.view(), &b.view(), MmdEstimator::Unbiased, 0).is_err());
    }
}