ragdrift_core/detectors/
data.rs1use crate::stats::{ks_two_sample, psi};
4use crate::types::{DriftDimension, DriftScore};
5use crate::Result;
6use ndarray::{ArrayView2, Axis};
7
/// Detector for drift in tabular feature data.
///
/// Holds the configuration used by `detect`: the decision threshold handed to
/// the resulting score, and the binning parameters forwarded to the PSI
/// computation.
#[derive(Debug, Clone)]
pub struct DataDriftDetector {
    /// Threshold passed through to the resulting drift score.
    threshold: f64,
    /// Number of bins forwarded to the PSI computation.
    n_bins: usize,
    /// Small constant forwarded to the PSI computation.
    /// NOTE(review): presumably a smoothing term for empty bins — confirm against `psi`.
    eps: f64,
}
18
19impl DataDriftDetector {
20 pub fn new(threshold: f64) -> Self {
22 Self {
23 threshold,
24 n_bins: 10,
25 eps: 1e-4,
26 }
27 }
28
29 pub fn with_bins(mut self, n_bins: usize) -> Self {
31 self.n_bins = n_bins;
32 self
33 }
34
35 pub fn detect(
38 &self,
39 baseline: ArrayView2<f64>,
40 current: ArrayView2<f64>,
41 ) -> Result<DriftScore> {
42 if baseline.ncols() != current.ncols() {
43 return Err(crate::error::RagDriftError::DimensionMismatch {
44 expected: baseline.ncols(),
45 actual: current.ncols(),
46 context: "DataDriftDetector::detect",
47 });
48 }
49 let mut worst = 0.0_f64;
50 for ((base_col, curr_col), _idx) in baseline
51 .axis_iter(Axis(1))
52 .zip(current.axis_iter(Axis(1)))
53 .zip(0..)
54 {
55 let b: Vec<f64> = base_col.iter().copied().collect();
56 let c: Vec<f64> = curr_col.iter().copied().collect();
57 let ks = ks_two_sample(&b, &c)?.d;
58 let p = if b.len() >= self.n_bins {
59 psi(&b, &c, self.n_bins, self.eps).unwrap_or(0.0)
60 } else {
61 0.0
62 };
63 let combined = ks.max(p);
66 if combined > worst {
67 worst = combined;
68 }
69 }
70 Ok(DriftScore::new(
71 DriftDimension::Data,
72 worst,
73 self.threshold,
74 "ks+psi",
75 ))
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Comparing a matrix against itself must produce a near-zero score that
    /// does not exceed the threshold.
    #[test]
    fn identical_features_score_low() {
        // Each column is the ramp 0..200 offset by its column index.
        let data: Array2<f64> =
            Array2::from_shape_fn((200, 3), |(row, col)| (row as f64) + (col as f64));
        let detector = DataDriftDetector::new(0.1);
        let result = detector.detect(data.view(), data.view()).unwrap();
        assert!(result.score < 1e-3, "score was {}", result.score);
        assert!(!result.exceeded);
    }

    /// Shifting a single column by 100 must push the score over the threshold.
    #[test]
    fn shifted_feature_flagged() {
        let baseline: Array2<f64> = Array2::from_shape_fn((200, 3), |(row, _)| row as f64);
        // Only the last column differs from the baseline.
        let current: Array2<f64> = Array2::from_shape_fn((200, 3), |(row, col)| {
            if col == 2 {
                row as f64 + 100.0
            } else {
                row as f64
            }
        });
        let detector = DataDriftDetector::new(0.1);
        let result = detector.detect(baseline.view(), current.view()).unwrap();
        assert!(result.exceeded, "score was {}", result.score);
    }

    /// Inputs with different column counts must be rejected.
    #[test]
    fn dimension_mismatch_errors() {
        let three_cols = Array2::<f64>::zeros((10, 3));
        let four_cols = Array2::<f64>::zeros((10, 4));
        let detector = DataDriftDetector::new(0.1);
        let outcome = detector.detect(three_cols.view(), four_cols.view());
        assert!(outcome.is_err());
    }
}