Skip to main content

ragdrift_core/detectors/
embedding.rs

1//! Embedding drift detector: composes MMD^2 (RBF) and sliced Wasserstein-1.
2
3use ndarray::ArrayView2;
4
5use crate::error::Result;
6use crate::stats::{mmd_rbf, sliced_wasserstein, MmdEstimator};
7use crate::types::{DriftDimension, DriftScore};
8
9/// Configuration for [`EmbeddingDriftDetector`].
10#[derive(Debug, Clone, Copy)]
11pub struct EmbeddingDriftConfig {
12    /// Threshold above which the combined score is flagged.
13    pub threshold: f64,
14    /// MMD^2 estimator. Default: unbiased.
15    pub estimator: MmdEstimator,
16    /// Number of random projections for sliced Wasserstein. Default: 64.
17    pub n_projections: usize,
18    /// RNG seed for projection sampling and bandwidth subsample.
19    pub seed: u64,
20    /// Weight applied to the sliced Wasserstein term in the combined score.
21    /// MMD^2 and sliced W1 have different units, so they need scaling. The
22    /// default of 0.5 is a reasonable starting point; tune from observed
23    /// production scores.
24    pub sliced_weight: f64,
25}
26
27impl Default for EmbeddingDriftConfig {
28    fn default() -> Self {
29        Self {
30            threshold: 0.1,
31            estimator: MmdEstimator::Unbiased,
32            n_projections: 64,
33            seed: 0,
34            sliced_weight: 0.5,
35        }
36    }
37}
38
39/// Detects drift between two embedding matrices.
40#[derive(Debug, Clone, Copy, Default)]
41pub struct EmbeddingDriftDetector {
42    config: EmbeddingDriftConfig,
43}
44
45impl EmbeddingDriftDetector {
46    /// Construct a detector from a custom config.
47    pub fn new(config: EmbeddingDriftConfig) -> Self {
48        Self { config }
49    }
50
51    /// Compute drift between two `(n_samples, dim)` embedding matrices.
52    pub fn detect(
53        &self,
54        baseline: &ArrayView2<'_, f32>,
55        current: &ArrayView2<'_, f32>,
56    ) -> Result<DriftScore> {
57        let mmd = mmd_rbf(baseline, current, self.config.estimator, self.config.seed)?;
58        let mmd = mmd.max(0.0);
59        let sw = sliced_wasserstein(
60            baseline,
61            current,
62            self.config.n_projections,
63            self.config.seed,
64        )?;
65        let combined = mmd + self.config.sliced_weight * sw;
66        Ok(DriftScore::new(
67            DriftDimension::Embedding,
68            combined,
69            self.config.threshold,
70            "mmd+sw",
71        ))
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    // Take the RNG from ndarray-rand's own `rand` re-export so its version is
    // guaranteed to match the one `RandomExt::random_using` expects.
    use ndarray_rand::rand::rngs::StdRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;
    use proptest::prelude::*;

    /// Draw a standard-normal matrix of the given shape from a seeded RNG so
    /// every test input is reproducible.
    fn seeded_normal(shape: (usize, usize), seed: u64) -> Array2<f32> {
        let mut rng = StdRng::seed_from_u64(seed);
        Array2::random_using(shape, StandardNormal, &mut rng)
    }

    /// Comparing a matrix with itself must give a (near) zero score and no
    /// drift flag.
    #[test]
    fn identical_embeddings_score_zero() {
        let a = seeded_normal((128, 16), 0);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert!(s.score.abs() < 1e-6, "score={}", s.score);
        assert!(!s.exceeded);
        assert_eq!(s.method, "mmd+sw");
    }

    /// A constant shift of every coordinate must push the score past the
    /// default threshold.
    #[test]
    fn shifted_embeddings_score_exceeds_threshold() {
        let a = seeded_normal((128, 16), 1);
        let mut b = a.clone();
        b.mapv_inplace(|v| v + 3.0);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &b.view()).unwrap();
        assert!(s.exceeded, "expected drift, got {}", s.score);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 16, ..ProptestConfig::default()
        })]
        /// For any seed, identical inputs must never be reported as drift.
        /// The proptest-generated seed drives the sampled matrix, so a
        /// failing case can be replayed from the reported seed.
        #[test]
        fn property_identical_inputs_no_drift(seed in 0u64..1000) {
            let a = seeded_normal((64, 8), seed);
            let detector = EmbeddingDriftDetector::default();
            let s = detector.detect(&a.view(), &a.view()).unwrap();
            prop_assert!(s.score.abs() < 1e-6);
            prop_assert!(!s.exceeded);
        }
    }
}