ragdrift_core/detectors/
embedding.rs1use ndarray::ArrayView2;
4
5use crate::error::Result;
6use crate::stats::{mmd_rbf, sliced_wasserstein, MmdEstimator};
7use crate::types::{DriftDimension, DriftScore};
8
9#[derive(Debug, Clone, Copy)]
11pub struct EmbeddingDriftConfig {
12 pub threshold: f64,
14 pub estimator: MmdEstimator,
16 pub n_projections: usize,
18 pub seed: u64,
20 pub sliced_weight: f64,
25}
26
27impl Default for EmbeddingDriftConfig {
28 fn default() -> Self {
29 Self {
30 threshold: 0.1,
31 estimator: MmdEstimator::Unbiased,
32 n_projections: 64,
33 seed: 0,
34 sliced_weight: 0.5,
35 }
36 }
37}
38
39#[derive(Debug, Clone, Copy, Default)]
41pub struct EmbeddingDriftDetector {
42 config: EmbeddingDriftConfig,
43}
44
45impl EmbeddingDriftDetector {
46 pub fn new(config: EmbeddingDriftConfig) -> Self {
48 Self { config }
49 }
50
51 pub fn detect(
53 &self,
54 baseline: &ArrayView2<'_, f32>,
55 current: &ArrayView2<'_, f32>,
56 ) -> Result<DriftScore> {
57 let mmd = mmd_rbf(baseline, current, self.config.estimator, self.config.seed)?;
58 let mmd = mmd.max(0.0);
59 let sw = sliced_wasserstein(
60 baseline,
61 current,
62 self.config.n_projections,
63 self.config.seed,
64 )?;
65 let combined = mmd + self.config.sliced_weight * sw;
66 Ok(DriftScore::new(
67 DriftDimension::Embedding,
68 combined,
69 self.config.threshold,
70 "mmd+sw",
71 ))
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    // Use the `rand` re-exported by `ndarray_rand` so the RNG type is guaranteed
    // to match the `Rng` trait that `random_using` expects.
    use ndarray_rand::rand::rngs::StdRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;
    use proptest::prelude::*;

    /// Draws a `rows x cols` standard-normal matrix from an RNG seeded with
    /// `seed`, so a given seed always produces the same test data.
    fn seeded_normal(rows: usize, cols: usize, seed: u64) -> Array2<f32> {
        let mut rng = StdRng::seed_from_u64(seed);
        Array2::<f32>::random_using((rows, cols), StandardNormal, &mut rng)
    }

    /// Comparing a sample with itself must give a (numerically) zero score
    /// that does not exceed the threshold, labelled with the expected method.
    #[test]
    fn identical_embeddings_score_zero() {
        let a = seeded_normal(128, 16, 0);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert!(s.score.abs() < 1e-6, "score={}", s.score);
        assert!(!s.exceeded);
        assert_eq!(s.method, "mmd+sw");
    }

    /// Shifting every coordinate by a constant `3.0` must be flagged as drift
    /// under the default configuration.
    #[test]
    fn shifted_embeddings_score_exceeds_threshold() {
        let a = seeded_normal(128, 16, 1);
        let mut b = a.clone();
        b.mapv_inplace(|v| v + 3.0);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &b.view()).unwrap();
        assert!(s.exceeded, "expected drift, got {}", s.score);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 16, ..ProptestConfig::default()
        })]
        /// For any seed, a sample compared with itself shows no drift. The
        /// data is derived from `seed`, so a failing case reported by proptest
        /// can be replayed exactly.
        #[test]
        fn property_identical_inputs_no_drift(seed in 0u64..1000) {
            let a = seeded_normal(64, 8, seed);
            let detector = EmbeddingDriftDetector::default();
            let s = detector.detect(&a.view(), &a.view()).unwrap();
            prop_assert!(s.score.abs() < 1e-6);
            prop_assert!(!s.exceeded);
        }
    }
}