ragdrift_core/detectors/
data.rs1use ndarray::{ArrayView2, Axis};
4
5use crate::error::Result;
6use crate::stats::{ks_two_sample, psi, PsiBinning};
7use crate::types::{check_min_samples, check_same_cols, DriftDimension, DriftScore};
8
9#[derive(Debug, Clone, Copy)]
11pub struct DataDriftConfig {
12 pub threshold: f64,
14 pub psi_binning: PsiBinning,
16}
17
18impl Default for DataDriftConfig {
19 fn default() -> Self {
20 Self {
21 threshold: 0.25,
22 psi_binning: PsiBinning::Quantile(10),
23 }
24 }
25}
26
27#[derive(Debug, Clone, Copy, Default)]
33pub struct DataDriftDetector {
34 config: DataDriftConfig,
35}
36
37impl DataDriftDetector {
38 pub fn new(config: DataDriftConfig) -> Self {
40 Self { config }
41 }
42
43 pub fn detect(
45 &self,
46 baseline: &ArrayView2<'_, f64>,
47 current: &ArrayView2<'_, f64>,
48 ) -> Result<DriftScore> {
49 check_same_cols(baseline, current)?;
50 check_min_samples(baseline.nrows(), 2)?;
51 check_min_samples(current.nrows(), 2)?;
52
53 let mut max_score = 0.0_f64;
54 for col in 0..baseline.ncols() {
55 let b_col = baseline.index_axis(Axis(1), col);
56 let c_col = current.index_axis(Axis(1), col);
57 let ks = ks_two_sample(&b_col, &c_col)?.statistic;
58 let p = psi(&b_col, &c_col, self.config.psi_binning).unwrap_or(0.0);
62 let combined = ks.max(p / 0.25);
63 if combined > max_score {
64 max_score = combined;
65 }
66 }
67 Ok(DriftScore::new(
68 DriftDimension::Data,
69 max_score,
70 self.config.threshold,
71 "ks+psi",
72 ))
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Comparing a matrix with itself yields a score of exactly zero and no alert.
    #[test]
    fn identical_matrices_score_zero() {
        let a = Array2::from_shape_fn((100, 4), |(i, j)| (i + j) as f64);
        let detector = DataDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert_eq!(s.score, 0.0);
        assert!(!s.exceeded);
    }

    /// Shifting a single column is enough to flag the overall result as exceeded.
    #[test]
    fn one_drifted_feature_flags_overall() {
        let baseline = Array2::from_shape_fn((200, 3), |(i, _)| i as f64);
        let mut current = baseline.clone();
        // Shift only column 1; the other columns stay identical to the baseline.
        for i in 0..current.nrows() {
            current[[i, 1]] += 100.0;
        }
        let detector = DataDriftDetector::default();
        let s = detector.detect(&baseline.view(), &current.view()).unwrap();
        assert!(s.exceeded, "expected drift, score={}", s.score);
    }

    /// Inputs with different column counts are rejected with an error.
    #[test]
    fn rejects_dim_mismatch() {
        let a = Array2::<f64>::zeros((10, 3));
        let b = Array2::<f64>::zeros((10, 4));
        let detector = DataDriftDetector::default();
        assert!(detector.detect(&a.view(), &b.view()).is_err());
    }
}