ragdrift-core 0.1.4

Pure-Rust core for ragdrift: 5-dimensional drift detection for RAG systems.
Documentation
//! Embedding drift detector: composes MMD^2 (RBF) and sliced Wasserstein-1.

use ndarray::ArrayView2;

use crate::error::Result;
use crate::stats::{mmd_rbf, sliced_wasserstein, MmdEstimator};
use crate::types::{DriftDimension, DriftScore};

/// Configuration for [`EmbeddingDriftDetector`].
///
/// Every field is public and the type is `Copy`, so a config can be built
/// with struct-update syntax on top of [`Default`] and passed by value.
/// The `Default` impl yields `threshold = 0.1`, the unbiased MMD^2 estimator,
/// 64 projections, seed 0 and `sliced_weight = 0.5`.
#[derive(Debug, Clone, Copy)]
pub struct EmbeddingDriftConfig {
    /// Threshold above which the combined score is flagged.
    ///
    /// The combined score is `mmd + sliced_weight * sw`; it is handed to
    /// [`DriftScore`] together with this threshold. Default: 0.1.
    pub threshold: f64,
    /// MMD^2 estimator. Default: unbiased.
    pub estimator: MmdEstimator,
    /// Number of random projections for sliced Wasserstein. Default: 64.
    pub n_projections: usize,
    /// RNG seed for projection sampling and bandwidth subsample.
    ///
    /// The same value is forwarded to both the MMD^2 and the sliced
    /// Wasserstein computations. Default: 0.
    pub seed: u64,
    /// Weight applied to the sliced Wasserstein term in the combined score.
    /// MMD^2 and sliced W1 have different units, so they need scaling. The
    /// default of 0.5 is a reasonable starting point; tune from observed
    /// production scores.
    pub sliced_weight: f64,
}

impl Default for EmbeddingDriftConfig {
    fn default() -> Self {
        Self {
            threshold: 0.1,
            estimator: MmdEstimator::Unbiased,
            n_projections: 64,
            seed: 0,
            sliced_weight: 0.5,
        }
    }
}

/// Detects drift between two embedding matrices.
///
/// The detector carries nothing but its [`EmbeddingDriftConfig`], so it is
/// `Copy`; the derived `Default` uses the default configuration.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmbeddingDriftDetector {
    /// Settings read on every call to [`EmbeddingDriftDetector::detect`].
    config: EmbeddingDriftConfig,
}

impl EmbeddingDriftDetector {
    /// Construct a detector from a custom config.
    pub fn new(config: EmbeddingDriftConfig) -> Self {
        Self { config }
    }

    /// Compute drift between two `(n_samples, dim)` embedding matrices.
    pub fn detect(
        &self,
        baseline: &ArrayView2<'_, f32>,
        current: &ArrayView2<'_, f32>,
    ) -> Result<DriftScore> {
        let mmd = mmd_rbf(baseline, current, self.config.estimator, self.config.seed)?;
        let mmd = mmd.max(0.0);
        let sw = sliced_wasserstein(
            baseline,
            current,
            self.config.n_projections,
            self.config.seed,
        )?;
        let combined = mmd + self.config.sliced_weight * sw;
        Ok(DriftScore::new(
            DriftDimension::Embedding,
            combined,
            self.config.threshold,
            "mmd+sw",
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;
    use proptest::prelude::*;

    /// Draw a `(rows, cols)` standard-normal matrix from an RNG seeded with
    /// `seed`, so every test input is reproducible.
    ///
    /// The RNG type is taken from the `rand` re-exported by `ndarray_rand`,
    /// which guarantees its traits are the ones `random_using` expects.
    fn seeded_normal(seed: u64, rows: usize, cols: usize) -> Array2<f32> {
        let mut rng =
            <ndarray_rand::rand::rngs::StdRng as ndarray_rand::rand::SeedableRng>::seed_from_u64(
                seed,
            );
        Array2::<f32>::random_using((rows, cols), StandardNormal, &mut rng)
    }

    /// Comparing a matrix with itself must give a (numerically) zero score
    /// that does not trip the threshold.
    #[test]
    fn identical_embeddings_score_zero() {
        let a = seeded_normal(0, 128, 16);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert!(s.score.abs() < 1e-6, "score={}", s.score);
        assert!(!s.exceeded);
        assert_eq!(s.method, "mmd+sw");
    }

    /// Shifting every coordinate by 3.0 must be reported as drift under the
    /// default configuration.
    #[test]
    fn shifted_embeddings_score_exceeds_threshold() {
        let a = seeded_normal(1, 128, 16);
        let mut b = a.clone();
        b.mapv_inplace(|v| v + 3.0);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &b.view()).unwrap();
        assert!(s.exceeded, "expected drift, got {}", s.score);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 16, ..ProptestConfig::default()
        })]
        /// For any seed, a matrix compared with itself shows no drift.
        #[test]
        fn property_identical_inputs_no_drift(seed in 0u64..1000) {
            // The data is derived from the proptest-supplied seed, so a
            // failing case can be replayed and shrunk.
            let a = seeded_normal(seed, 64, 8);
            let detector = EmbeddingDriftDetector::default();
            let s = detector.detect(&a.view(), &a.view()).unwrap();
            prop_assert!(s.score.abs() < 1e-6);
            prop_assert!(!s.exceeded);
        }
    }
}