use ndarray::{ArrayView2, Axis};
use crate::error::Result;
use crate::stats::{ks_two_sample, psi, PsiBinning};
use crate::types::{check_min_samples, check_same_cols, DriftDimension, DriftScore};
/// Tuning knobs for [`DataDriftDetector`].
#[derive(Debug, Clone, Copy)]
pub struct DataDriftConfig {
    /// Threshold handed to [`DriftScore::new`] alongside the computed score;
    /// the detector itself does not compare against it.
    pub threshold: f64,
    /// Binning strategy forwarded to [`psi`] for every column.
    pub psi_binning: PsiBinning,
}
impl Default for DataDriftConfig {
    /// Default configuration: a threshold of `0.25` and PSI computed over
    /// `PsiBinning::Quantile(10)` bins.
    fn default() -> Self {
        let threshold = 0.25;
        let psi_binning = PsiBinning::Quantile(10);
        Self {
            threshold,
            psi_binning,
        }
    }
}
/// Column-wise data-drift detector combining a two-sample KS statistic and PSI.
///
/// `Default` builds the detector with `DataDriftConfig::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DataDriftDetector {
    // Settings used by `detect`: PSI binning and the reported threshold.
    config: DataDriftConfig,
}
impl DataDriftDetector {
    /// Builds a detector that uses `config` for every call to [`Self::detect`].
    pub fn new(config: DataDriftConfig) -> Self {
        Self { config }
    }

    /// Scores drift between `baseline` and `current`, compared column by column.
    ///
    /// For each column pair the detector computes the two-sample KS statistic
    /// and the PSI (using the configured binning). The per-column value is
    /// `max(ks, psi / 0.25)`. The returned score is the largest per-column value,
    /// or `0.0` when there are no columns. It is reported as a
    /// [`DriftDimension::Data`] score tagged `"ks+psi"`, with the configured
    /// threshold attached.
    ///
    /// A PSI computation that fails is treated as `0.0` for that column, so only
    /// the KS statistic contributes there.
    ///
    /// # Errors
    ///
    /// Propagates errors from `check_same_cols` (column counts differ),
    /// `check_min_samples` (applied to both row counts with a minimum of 2), and
    /// `ks_two_sample` (the first failing column aborts the whole call).
    pub fn detect(
        &self,
        baseline: &ArrayView2<'_, f64>,
        current: &ArrayView2<'_, f64>,
    ) -> Result<DriftScore> {
        // Dividing PSI by this puts a PSI of 0.25 at 1.0, the same upper end as
        // the KS statistic. (Presumably 0.25 is the usual "major shift" PSI
        // level; confirm with the owning team.)
        const PSI_SCALE: f64 = 0.25;

        check_same_cols(baseline, current)?;
        check_min_samples(baseline.nrows(), 2)?;
        check_min_samples(current.nrows(), 2)?;

        let max_score = baseline
            .axis_iter(Axis(1))
            .zip(current.axis_iter(Axis(1)))
            .try_fold(0.0_f64, |best, (b_col, c_col)| -> Result<f64> {
                let ks = ks_two_sample(&b_col, &c_col)?.statistic;
                // Best-effort: a PSI failure must not mask the KS signal.
                let p = psi(&b_col, &c_col, self.config.psi_binning).unwrap_or(0.0);
                let combined = ks.max(p / PSI_SCALE);
                // Strict `>` keeps the running best when `combined` is NaN.
                Ok(if combined > best { combined } else { best })
            })?;

        Ok(DriftScore::new(
            DriftDimension::Data,
            max_score,
            self.config.threshold,
            "ks+psi",
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Comparing a matrix against itself must report zero drift and no exceedance.
    #[test]
    fn identical_matrices_score_zero() {
        let a = Array2::from_shape_fn((100, 4), |(i, j)| (i + j) as f64);
        let detector = DataDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert_eq!(s.score, 0.0);
        assert!(!s.exceeded);
    }

    /// A large shift in a single column is enough to flag the whole matrix.
    #[test]
    fn one_drifted_feature_flags_overall() {
        let baseline = Array2::from_shape_fn((200, 3), |(i, _)| i as f64);
        let mut current = baseline.clone();
        // Shift only column 1; columns 0 and 2 stay identical to the baseline.
        for i in 0..current.nrows() {
            current[[i, 1]] += 100.0;
        }
        let detector = DataDriftDetector::default();
        let s = detector.detect(&baseline.view(), &current.view()).unwrap();
        assert!(s.exceeded, "expected drift, score={}", s.score);
    }

    /// Matrices with different column counts must be rejected, not scored.
    #[test]
    fn rejects_dim_mismatch() {
        let a = Array2::<f64>::zeros((10, 3));
        let b = Array2::<f64>::zeros((10, 4));
        let detector = DataDriftDetector::default();
        assert!(detector.detect(&a.view(), &b.view()).is_err());
    }
}