use ndarray::ArrayView2;
use crate::error::Result;
use crate::stats::{mmd_rbf, sliced_wasserstein, MmdEstimator};
use crate::types::{DriftDimension, DriftScore};
/// Tunable parameters for [`EmbeddingDriftDetector`].
///
/// The detector reports `mmd + sliced_weight * sw`, where `mmd` comes from
/// `mmd_rbf` and `sw` from `sliced_wasserstein`.
#[derive(Debug, Clone, Copy)]
pub struct EmbeddingDriftConfig {
/// Threshold handed to `DriftScore::new` alongside the combined score
/// (default `0.1`).
pub threshold: f64,
/// Estimator variant forwarded to `mmd_rbf` (default `MmdEstimator::Unbiased`).
pub estimator: MmdEstimator,
/// Projection count forwarded to `sliced_wasserstein` (default `64`).
pub n_projections: usize,
/// Seed forwarded to both `mmd_rbf` and `sliced_wasserstein` (default `0`).
pub seed: u64,
/// Multiplier applied to the sliced Wasserstein term before it is added to
/// the MMD term (default `0.5`).
pub sliced_weight: f64,
}
impl Default for EmbeddingDriftConfig {
fn default() -> Self {
Self {
threshold: 0.1,
estimator: MmdEstimator::Unbiased,
n_projections: 64,
seed: 0,
sliced_weight: 0.5,
}
}
}
/// Scores drift between two embedding matrices using the settings in its
/// [`EmbeddingDriftConfig`].
///
/// The derived `Default` builds a detector from `EmbeddingDriftConfig::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmbeddingDriftDetector {
// Settings read by `detect` on every call; fixed at construction time.
config: EmbeddingDriftConfig,
}
impl EmbeddingDriftDetector {
pub fn new(config: EmbeddingDriftConfig) -> Self {
Self { config }
}
pub fn detect(
&self,
baseline: &ArrayView2<'_, f32>,
current: &ArrayView2<'_, f32>,
) -> Result<DriftScore> {
let mmd = mmd_rbf(baseline, current, self.config.estimator, self.config.seed)?;
let mmd = mmd.max(0.0);
let sw = sliced_wasserstein(
baseline,
current,
self.config.n_projections,
self.config.seed,
)?;
let combined = mmd + self.config.sliced_weight * sw;
Ok(DriftScore::new(
DriftDimension::Embedding,
combined,
self.config.threshold,
"mmd+sw",
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    // Take the RNG from `ndarray_rand`'s re-export so its `Rng` trait is
    // guaranteed to be the one `RandomExt::random_using` expects.
    use ndarray_rand::rand::rngs::StdRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;
    use proptest::prelude::*;

    /// Draws a `rows x cols` standard-normal matrix from an RNG seeded with
    /// `seed`, so every test sees the same data on every run.
    fn seeded_normal(rows: usize, cols: usize, seed: u64) -> Array2<f32> {
        let mut rng = StdRng::seed_from_u64(seed);
        Array2::random_using((rows, cols), StandardNormal, &mut rng)
    }

    /// Comparing a sample against itself must yield a (near-)zero score that
    /// does not trip the threshold.
    #[test]
    fn identical_embeddings_score_zero() {
        let a = seeded_normal(128, 16, 0);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert!(s.score.abs() < 1e-6, "score={}", s.score);
        assert!(!s.exceeded);
        assert_eq!(s.method, "mmd+sw");
    }

    /// Shifting every coordinate by a constant must be reported as drift.
    #[test]
    fn shifted_embeddings_score_exceeds_threshold() {
        let a = seeded_normal(128, 16, 0);
        let b = a.mapv(|v| v + 3.0);
        let detector = EmbeddingDriftDetector::default();
        let s = detector.detect(&a.view(), &b.view()).unwrap();
        assert!(s.exceeded, "expected drift, got {}", s.score);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 16, ..ProptestConfig::default()
        })]

        /// For any seed, a sample compared against itself shows no drift.
        ///
        /// The generated `seed` actually drives the sampled matrix, so a
        /// failing case can be replayed and shrunk by proptest.
        #[test]
        fn property_identical_inputs_no_drift(seed in 0u64..1000) {
            let a = seeded_normal(64, 8, seed);
            let detector = EmbeddingDriftDetector::default();
            let s = detector.detect(&a.view(), &a.view()).unwrap();
            prop_assert!(s.score.abs() < 1e-6);
            prop_assert!(!s.exceeded);
        }
    }
}