// ragdrift_core/detectors/data.rs

//! Data drift detector: feature-wise two-sample Kolmogorov–Smirnov (KS) statistic and
//! Population Stability Index (PSI) on tabular features.
2
3use ndarray::{ArrayView2, Axis};
4
5use crate::error::Result;
6use crate::stats::{ks_two_sample, psi, PsiBinning};
7use crate::types::{check_min_samples, check_same_cols, DriftDimension, DriftScore};
8
9/// Configuration for [`DataDriftDetector`].
10#[derive(Debug, Clone, Copy)]
11pub struct DataDriftConfig {
12    /// Threshold on the combined per-feature score.
13    pub threshold: f64,
14    /// Binning strategy passed to PSI for each feature.
15    pub psi_binning: PsiBinning,
16}
17
18impl Default for DataDriftConfig {
19    fn default() -> Self {
20        Self {
21            threshold: 0.25,
22            psi_binning: PsiBinning::Quantile(10),
23        }
24    }
25}
26
27/// Detects drift on tabular feature matrices.
28///
29/// Computes both KS and PSI per feature column. The reported score is the
30/// max over features of `max(KS_D, PSI / 0.25)` so a single threshold of
31/// 0.25 lines up with the standard PSI table.
32#[derive(Debug, Clone, Copy, Default)]
33pub struct DataDriftDetector {
34    config: DataDriftConfig,
35}
36
37impl DataDriftDetector {
38    /// Construct a detector from a custom config.
39    pub fn new(config: DataDriftConfig) -> Self {
40        Self { config }
41    }
42
43    /// Compute drift between two `(n_samples, n_features)` matrices.
44    pub fn detect(
45        &self,
46        baseline: &ArrayView2<'_, f64>,
47        current: &ArrayView2<'_, f64>,
48    ) -> Result<DriftScore> {
49        check_same_cols(baseline, current)?;
50        check_min_samples(baseline.nrows(), 2)?;
51        check_min_samples(current.nrows(), 2)?;
52
53        let mut max_score = 0.0_f64;
54        for col in 0..baseline.ncols() {
55            let b_col = baseline.index_axis(Axis(1), col);
56            let c_col = current.index_axis(Axis(1), col);
57            let ks = ks_two_sample(&b_col, &c_col)?.statistic;
58            // PSI may fail if a single feature has too few unique values.
59            // Treat that feature as zero-drift rather than failing the whole
60            // detection.
61            let p = psi(&b_col, &c_col, self.config.psi_binning).unwrap_or(0.0);
62            let combined = ks.max(p / 0.25);
63            if combined > max_score {
64                max_score = combined;
65            }
66        }
67        Ok(DriftScore::new(
68            DriftDimension::Data,
69            max_score,
70            self.config.threshold,
71            "ks+psi",
72        ))
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Comparing a matrix with itself gives a score of exactly zero and does
    /// not trip the threshold.
    #[test]
    fn identical_matrices_score_zero() {
        let data = Array2::from_shape_fn((100, 4), |(row, col)| (row + col) as f64);
        let view = data.view();
        let outcome = DataDriftDetector::default().detect(&view, &view).unwrap();
        assert_eq!(outcome.score, 0.0);
        assert!(!outcome.exceeded);
    }

    /// Drift in a single column is enough to flag the whole matrix.
    #[test]
    fn one_drifted_feature_flags_overall() {
        let baseline = Array2::from_shape_fn((200, 3), |(row, _)| row as f64);
        // Same values as the baseline, except column 1 is shifted by 100.
        let current = Array2::from_shape_fn((200, 3), |(row, col)| {
            let value = row as f64;
            if col == 1 { value + 100.0 } else { value }
        });
        let outcome = DataDriftDetector::default()
            .detect(&baseline.view(), &current.view())
            .unwrap();
        assert!(outcome.exceeded, "expected drift, score={}", outcome.score);
    }

    /// Matrices with different column counts are rejected with an error.
    #[test]
    fn rejects_dim_mismatch() {
        let narrow = Array2::<f64>::zeros((10, 3));
        let wide = Array2::<f64>::zeros((10, 4));
        let outcome = DataDriftDetector::default().detect(&narrow.view(), &wide.view());
        assert!(outcome.is_err());
    }
}