Skip to main content

ragdrift_core/detectors/
data.rs

//! Tabular data drift detection via per-feature KS and PSI statistics, reduced to a worst-feature score.
2
3use crate::stats::{ks_two_sample, psi};
4use crate::types::{DriftDimension, DriftScore};
5use crate::Result;
6use ndarray::{ArrayView2, Axis};
7
/// Detects feature-wise distribution drift on tabular data.
///
/// For each column, computes the KS statistic D and PSI, then takes the
/// maximum across features. KS dominates when the shape of the distribution
/// changes; PSI dominates when bin masses redistribute.
#[derive(Debug, Clone)]
pub struct DataDriftDetector {
    /// Alert threshold handed to the resulting drift score alongside the
    /// worst per-feature value.
    threshold: f64,
    /// Number of bins requested from the PSI computation.
    n_bins: usize,
    /// Epsilon forwarded to the PSI computation (its exact role is defined by
    /// the PSI helper, not here).
    eps: f64,
}
18
19impl DataDriftDetector {
20    /// Create a detector with the given threshold, 10 PSI bins, and epsilon = 1e-4.
21    pub fn new(threshold: f64) -> Self {
22        Self {
23            threshold,
24            n_bins: 10,
25            eps: 1e-4,
26        }
27    }
28
29    /// Override the number of PSI bins.
30    pub fn with_bins(mut self, n_bins: usize) -> Self {
31        self.n_bins = n_bins;
32        self
33    }
34
35    /// Run the detector. Inputs are `(n_samples, n_features)` and must share
36    /// `n_features`.
37    pub fn detect(
38        &self,
39        baseline: ArrayView2<f64>,
40        current: ArrayView2<f64>,
41    ) -> Result<DriftScore> {
42        if baseline.ncols() != current.ncols() {
43            return Err(crate::error::RagDriftError::DimensionMismatch {
44                expected: baseline.ncols(),
45                actual: current.ncols(),
46                context: "DataDriftDetector::detect",
47            });
48        }
49        let mut worst = 0.0_f64;
50        for ((base_col, curr_col), _idx) in baseline
51            .axis_iter(Axis(1))
52            .zip(current.axis_iter(Axis(1)))
53            .zip(0..)
54        {
55            let b: Vec<f64> = base_col.iter().copied().collect();
56            let c: Vec<f64> = curr_col.iter().copied().collect();
57            let ks = ks_two_sample(&b, &c)?.d;
58            let p = if b.len() >= self.n_bins {
59                psi(&b, &c, self.n_bins, self.eps).unwrap_or(0.0)
60            } else {
61                0.0
62            };
63            // Map PSI (unbounded but typically <=1) and D ([0,1]) onto a comparable scale.
64            // We use the raw maximum: PSI > D in the regime that matters for alerts.
65            let combined = ks.max(p);
66            if combined > worst {
67                worst = combined;
68            }
69        }
70        Ok(DriftScore::new(
71            DriftDimension::Data,
72            worst,
73            self.threshold,
74            "ks+psi",
75        ))
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Number of rows used by the synthetic fixtures below.
    const ROWS: usize = 200;
    /// Number of feature columns used by the synthetic fixtures below.
    const COLS: usize = 3;

    /// Comparing a matrix against itself must yield a near-zero score that
    /// does not trip the threshold.
    #[test]
    fn identical_features_score_low() {
        let x: Array2<f64> =
            Array2::from_shape_fn((ROWS, COLS), |(row, col)| (row as f64) + (col as f64));
        let detector = DataDriftDetector::new(0.1);
        let result = detector.detect(x.view(), x.view()).unwrap();
        assert!(result.score < 1e-3, "score was {}", result.score);
        assert!(!result.exceeded);
    }

    /// Shifting a single column by a constant must be enough to exceed the
    /// threshold, since the score is the worst feature.
    #[test]
    fn shifted_feature_flagged() {
        let x: Array2<f64> = Array2::from_shape_fn((ROWS, COLS), |(row, _)| row as f64);
        // Only column 2 is shifted; columns 0 and 1 match the baseline.
        let y: Array2<f64> = Array2::from_shape_fn((ROWS, COLS), |(row, col)| {
            if col == 2 {
                row as f64 + 100.0
            } else {
                row as f64
            }
        });
        let detector = DataDriftDetector::new(0.1);
        let result = detector.detect(x.view(), y.view()).unwrap();
        assert!(result.exceeded, "score was {}", result.score);
    }

    /// Inputs with different column counts are rejected with an error.
    #[test]
    fn dimension_mismatch_errors() {
        let three_cols: Array2<f64> = Array2::zeros((10, 3));
        let four_cols: Array2<f64> = Array2::zeros((10, 4));
        let detector = DataDriftDetector::new(0.1);
        let outcome = detector.detect(three_cols.view(), four_cols.view());
        assert!(outcome.is_err());
    }
}
123}