use crate::stats::{mmd_rbf, sliced_wasserstein, MmdEstimator};
use crate::types::{DriftDimension, DriftScore};
use crate::Result;
use ndarray::ArrayView2;
/// Configuration for detecting drift between two sets of embeddings.
///
/// `detect` combines an RBF-kernel MMD estimate with a sliced Wasserstein
/// distance and reports the sum against `threshold`.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingDriftDetector {
    /// Threshold forwarded to the resulting `DriftScore`.
    threshold: f64,
    /// Projection count passed to `sliced_wasserstein` (64 unless overridden).
    n_projections: usize,
    /// Seed passed to `sliced_wasserstein` (0 unless overridden).
    seed: u64,
}
impl EmbeddingDriftDetector {
    /// Creates a detector whose scores are compared against `threshold`.
    ///
    /// The detector starts with 64 projections and a seed of 0; chain
    /// `with_projections` / `with_seed` to override either value.
    #[must_use]
    pub fn new(threshold: f64) -> Self {
        Self {
            threshold,
            n_projections: 64,
            seed: 0,
        }
    }

    /// Returns the detector with the projection count replaced by `n`.
    ///
    /// Consumes `self`, so the returned value must be used; dropping it
    /// discards the configuration.
    #[must_use]
    pub fn with_projections(mut self, n: usize) -> Self {
        self.n_projections = n;
        self
    }

    /// Returns the detector with the seed replaced by `seed`.
    ///
    /// Consumes `self`, so the returned value must be used; dropping it
    /// discards the configuration.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Scores the drift between `baseline` and `current`.
    ///
    /// The score is `max(mmd2, 0) + sw`, where `mmd2` comes from `mmd_rbf`
    /// using `MmdEstimator::Unbiased` (third argument left as `None`) and
    /// `sw` comes from `sliced_wasserstein` with this detector's projection
    /// count and seed. The result is wrapped in a `DriftScore` for
    /// `DriftDimension::Embedding`, carrying the configured threshold and the
    /// method label `"mmd+sw"`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `mmd_rbf` or `sliced_wasserstein`.
    pub fn detect(
        &self,
        baseline: ArrayView2<f32>,
        current: ArrayView2<f32>,
    ) -> Result<DriftScore> {
        let mmd2 = mmd_rbf(baseline, current, None, MmdEstimator::Unbiased)?;
        let sw = sliced_wasserstein(baseline, current, self.n_projections, self.seed)?;
        // Clamp the MMD term at zero so a negative estimate cannot offset the
        // Wasserstein term in the combined score.
        let score = mmd2.max(0.0) + sw;
        Ok(DriftScore::new(
            DriftDimension::Embedding,
            score,
            self.threshold,
            "mmd+sw",
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Comparing a sample with itself must stay under the threshold and yield
    /// a near-zero score.
    #[test]
    fn identical_embeddings_score_low() {
        let sample: Array2<f32> = Array2::from_elem((50, 16), 1.0_f32);
        let detector = EmbeddingDriftDetector::new(0.05);
        let s = detector
            .detect(sample.view(), sample.view())
            .unwrap();
        assert!(!s.exceeded);
        assert!(s.score < 1e-3, "score was {}", s.score);
    }

    /// Adding a constant offset of 5.0 to every coordinate must push the
    /// score past the threshold.
    #[test]
    fn shifted_embeddings_score_high() {
        let shape = (100, 8);
        let baseline: Array2<f32> =
            Array2::from_shape_fn(shape, |(row, col)| (row + col) as f32 * 0.01);
        let shifted: Array2<f32> =
            Array2::from_shape_fn(shape, |(row, col)| (row + col) as f32 * 0.01 + 5.0);

        let detector = EmbeddingDriftDetector::new(0.1);
        let s = detector
            .detect(baseline.view(), shifted.view())
            .unwrap();
        assert!(s.exceeded, "score was {}", s.score);
    }
}