use ndarray::{ArrayView1, ArrayView2};
use crate::detectors::embedding::EmbeddingDriftDetector;
use crate::error::Result;
use crate::stats::ks_two_sample;
use crate::types::{DriftDimension, DriftScore};
/// Tuning knobs for [`ResponseDriftDetector`].
#[derive(Debug, Clone, Copy)]
pub struct ResponseDriftConfig {
    /// Threshold handed to `DriftScore::new` when a score is built.
    /// NOTE(review): the exact `exceeded` comparison (`>` vs `>=`) lives in `DriftScore`.
    pub threshold: f64,
    /// Multiplier on the embedding drift score that `detect_full` adds to the
    /// length/entropy KS statistic.
    pub semantic_weight: f64,
}

impl Default for ResponseDriftConfig {
    /// Threshold `0.2`, semantic weight `1.0`.
    fn default() -> Self {
        ResponseDriftConfig {
            semantic_weight: 1.0,
            threshold: 0.2,
        }
    }
}
/// Detects drift in model responses (lengths, optionally entropy and embeddings).
///
/// The `Default` instance uses `ResponseDriftConfig::default()`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ResponseDriftDetector {
    /// Threshold and weighting used when building scores.
    config: ResponseDriftConfig,
}
impl ResponseDriftDetector {
    /// Creates a detector that uses `config` for its threshold and semantic weight.
    pub fn new(config: ResponseDriftConfig) -> Self {
        Self { config }
    }

    /// Scores drift in response lengths only.
    ///
    /// The score is the two-sample Kolmogorov–Smirnov statistic between
    /// `baseline_lengths` and `current_lengths`. It is reported under the
    /// method tag `"length-ks"` and compared against `config.threshold`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ks_two_sample`.
    /// NOTE(review): presumably raised for empty or otherwise invalid samples —
    /// confirm in `crate::stats`.
    pub fn detect(
        &self,
        baseline_lengths: &ArrayView1<'_, f64>,
        current_lengths: &ArrayView1<'_, f64>,
    ) -> Result<DriftScore> {
        let length_d = Self::ks_statistic(baseline_lengths, current_lengths)?;
        Ok(self.build_score(length_d, "length-ks"))
    }

    /// Scores drift in response lengths and response entropy.
    ///
    /// Each signal gets its own KS statistic, and the score is the larger of
    /// the two. A shift in either signal alone can therefore flag drift.
    /// Method tag: `"length-ks+entropy-ks"`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ks_two_sample`. The length comparison runs
    /// first, so its error is the one returned if both would fail.
    pub fn detect_with_entropy(
        &self,
        baseline_lengths: &ArrayView1<'_, f64>,
        current_lengths: &ArrayView1<'_, f64>,
        baseline_entropy: &ArrayView1<'_, f64>,
        current_entropy: &ArrayView1<'_, f64>,
    ) -> Result<DriftScore> {
        let combined = Self::length_entropy_statistic(
            baseline_lengths,
            current_lengths,
            baseline_entropy,
            current_entropy,
        )?;
        Ok(self.build_score(combined, "length-ks+entropy-ks"))
    }

    /// Scores drift in response lengths, entropy, and embeddings.
    ///
    /// The score is
    /// `max(length_D, entropy_D) + config.semantic_weight * embedding_score`.
    /// `embedding_score` is the `score` of a default-configured
    /// `EmbeddingDriftDetector`, built fresh on every call.
    ///
    /// The embedding term is added rather than max-ed in. The combined value
    /// may therefore fall outside the `[0, 1]` range of a KS statistic.
    /// Method tag: `"length-ks+entropy-ks+mmd+sw"`.
    ///
    /// # Errors
    ///
    /// Propagates errors in this order:
    /// 1. `ks_two_sample` on lengths;
    /// 2. `ks_two_sample` on entropy;
    /// 3. `EmbeddingDriftDetector::detect`.
    pub fn detect_full(
        &self,
        baseline_lengths: &ArrayView1<'_, f64>,
        current_lengths: &ArrayView1<'_, f64>,
        baseline_entropy: &ArrayView1<'_, f64>,
        current_entropy: &ArrayView1<'_, f64>,
        baseline_embeddings: &ArrayView2<'_, f32>,
        current_embeddings: &ArrayView2<'_, f32>,
    ) -> Result<DriftScore> {
        // Same length/entropy statistic as `detect_with_entropy`, computed first
        // so error precedence matches the scalar-only path.
        let ks_d = Self::length_entropy_statistic(
            baseline_lengths,
            current_lengths,
            baseline_entropy,
            current_entropy,
        )?;
        let embedding =
            EmbeddingDriftDetector::default().detect(baseline_embeddings, current_embeddings)?;
        let combined = ks_d + self.config.semantic_weight * embedding.score;
        Ok(self.build_score(combined, "length-ks+entropy-ks+mmd+sw"))
    }

    /// Two-sample KS statistic `D` between `baseline` and `current`.
    fn ks_statistic(
        baseline: &ArrayView1<'_, f64>,
        current: &ArrayView1<'_, f64>,
    ) -> Result<f64> {
        Ok(ks_two_sample(baseline, current)?.statistic)
    }

    /// Larger of the length and entropy KS statistics; lengths are tested first.
    fn length_entropy_statistic(
        baseline_lengths: &ArrayView1<'_, f64>,
        current_lengths: &ArrayView1<'_, f64>,
        baseline_entropy: &ArrayView1<'_, f64>,
        current_entropy: &ArrayView1<'_, f64>,
    ) -> Result<f64> {
        let length_d = Self::ks_statistic(baseline_lengths, current_lengths)?;
        let entropy_d = Self::ks_statistic(baseline_entropy, current_entropy)?;
        // `f64::max` returns the other operand when exactly one side is NaN.
        Ok(length_d.max(entropy_d))
    }

    /// Wraps `value` in a response-dimension `DriftScore` using the configured threshold.
    fn build_score(&self, value: f64, method: &'static str) -> DriftScore {
        DriftScore::new(
            DriftDimension::Response,
            value,
            self.config.threshold,
            method,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// A sample compared against itself yields a zero KS statistic and no drift.
    #[test]
    fn identical_lengths_score_zero() {
        let lengths: Array1<f64> = (1..=100_i32).map(f64::from).collect();
        let result = ResponseDriftDetector::default()
            .detect(&lengths.view(), &lengths.view())
            .unwrap();
        assert_eq!(result.score, 0.0);
        assert!(!result.exceeded);
    }

    /// Disjoint constant length samples must be flagged as drift.
    #[test]
    fn shifted_lengths_flag_drift() {
        let before = Array1::from_elem(50, 10.0);
        let after = Array1::from_elem(50, 100.0);
        let result = ResponseDriftDetector::default()
            .detect(&before.view(), &after.view())
            .unwrap();
        assert!(result.exceeded);
    }

    /// With identical lengths, an entropy shift alone must still flag drift.
    #[test]
    fn entropy_drift_caught() {
        let lengths_before = Array1::from_elem(50, 10.0);
        let lengths_after = lengths_before.clone();
        let entropy_before = Array1::from_elem(50, 1.0);
        let entropy_after = Array1::from_elem(50, 5.0);
        let result = ResponseDriftDetector::default()
            .detect_with_entropy(
                &lengths_before.view(),
                &lengths_after.view(),
                &entropy_before.view(),
                &entropy_after.view(),
            )
            .unwrap();
        assert!(result.exceeded);
    }
}