1use std::collections::HashMap;
4
5use super::statistical::{bin_counts, chi_square_p_value, ks_p_value};
6use super::types::{
7 CategoricalBaseline, DriftCallback, DriftResult, DriftSummary, DriftTest, Severity,
8};
9
10pub struct DriftDetector {
12 tests: Vec<DriftTest>,
13 baseline: Option<Vec<Vec<f64>>>,
14 baseline_categorical: Option<CategoricalBaseline>,
15 warning_multiplier: f64,
16 callbacks: Vec<DriftCallback>,
17}
18
19impl DriftDetector {
20 pub fn new(tests: Vec<DriftTest>) -> Self {
22 Self {
23 tests,
24 baseline: None,
25 baseline_categorical: None,
26 warning_multiplier: 0.8, callbacks: Vec::new(),
28 }
29 }
30
31 pub fn on_drift<F>(&mut self, callback: F)
35 where
36 F: Fn(&[DriftResult]) + Send + Sync + 'static,
37 {
38 self.callbacks.push(Box::new(callback));
39 }
40
41 pub fn check_and_trigger(&self, current: &[Vec<f64>]) -> Vec<DriftResult> {
46 let results = self.check(current);
47
48 let has_drift = results.iter().any(|r| r.drifted);
50
51 if has_drift {
52 for callback in &self.callbacks {
53 callback(&results);
54 }
55 }
56
57 results
58 }
59
60 pub fn check_categorical_and_trigger(&self, current: &[Vec<usize>]) -> Vec<DriftResult> {
62 let results = self.check_categorical(current);
63
64 let has_drift = results.iter().any(|r| r.drifted);
65
66 if has_drift {
67 for callback in &self.callbacks {
68 callback(&results);
69 }
70 }
71
72 results
73 }
74
75 pub fn set_baseline(&mut self, data: &[Vec<f64>]) {
78 if data.is_empty() {
79 return;
80 }
81 let n_features = data[0].len();
83 let mut columns = vec![Vec::new(); n_features];
84 for row in data {
85 for (i, &val) in row.iter().enumerate() {
86 if i < n_features {
87 columns[i].push(val);
88 }
89 }
90 }
91 self.baseline = Some(columns);
92 }
93
94 pub fn set_baseline_categorical(&mut self, data: &[Vec<usize>]) {
96 if data.is_empty() {
97 return;
98 }
99 let n_features = data[0].len();
100 let mut histograms = vec![HashMap::new(); n_features];
101 for row in data {
102 for (i, &val) in row.iter().enumerate() {
103 if i < n_features {
104 *histograms[i].entry(val).or_insert(0) += 1;
105 }
106 }
107 }
108 self.baseline_categorical = Some(histograms);
109 }
110
111 pub fn check(&self, current: &[Vec<f64>]) -> Vec<DriftResult> {
113 let mut results = Vec::new();
114
115 let baseline = match &self.baseline {
116 Some(b) => b,
117 None => return results,
118 };
119
120 if current.is_empty() {
121 return results;
122 }
123
124 let n_features = current[0].len().min(baseline.len());
126 let mut current_columns = vec![Vec::new(); n_features];
127 for row in current {
128 for (i, &val) in row.iter().enumerate() {
129 if i < n_features {
130 current_columns[i].push(val);
131 }
132 }
133 }
134
135 for (feature_idx, (baseline_col, current_col)) in
137 baseline.iter().zip(current_columns.iter()).enumerate()
138 {
139 for test in &self.tests {
140 let result = match test {
141 DriftTest::KS { threshold } => {
142 self.ks_test(feature_idx, baseline_col, current_col, *threshold)
143 }
144 DriftTest::PSI { threshold } => {
145 self.psi_test(feature_idx, baseline_col, current_col, *threshold)
146 }
147 DriftTest::ChiSquare { .. } => continue, };
149 results.push(result);
150 }
151 }
152
153 results
154 }
155
156 pub fn check_categorical(&self, current: &[Vec<usize>]) -> Vec<DriftResult> {
158 let mut results = Vec::new();
159
160 let baseline = match &self.baseline_categorical {
161 Some(b) => b,
162 None => return results,
163 };
164
165 if current.is_empty() {
166 return results;
167 }
168
169 let n_features = current[0].len().min(baseline.len());
171 let mut current_histograms = vec![HashMap::new(); n_features];
172 for row in current {
173 for (i, &val) in row.iter().enumerate() {
174 if i < n_features {
175 *current_histograms[i].entry(val).or_insert(0) += 1;
176 }
177 }
178 }
179
180 for (feature_idx, (baseline_hist, current_hist)) in
182 baseline.iter().zip(current_histograms.iter()).enumerate()
183 {
184 for test in &self.tests {
185 if let DriftTest::ChiSquare { threshold } = test {
186 let result =
187 self.chi_square_test(feature_idx, baseline_hist, current_hist, *threshold);
188 results.push(result);
189 }
190 }
191 }
192
193 results
194 }
195
196 fn ks_test(
198 &self,
199 feature_idx: usize,
200 baseline: &[f64],
201 current: &[f64],
202 threshold: f64,
203 ) -> DriftResult {
204 let mut sorted_baseline: Vec<f64> = baseline.to_vec();
206 let mut sorted_current: Vec<f64> = current.to_vec();
207 sorted_baseline.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
208 sorted_current.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
209
210 let n1 = sorted_baseline.len() as f64;
212 let n2 = sorted_current.len() as f64;
213
214 let mut d_max = 0.0f64;
215 let mut i = 0usize;
216 let mut j = 0usize;
217
218 while i < sorted_baseline.len() && j < sorted_current.len() {
219 let cdf1 = (i + 1) as f64 / n1;
220 let cdf2 = (j + 1) as f64 / n2;
221
222 let diff = (cdf1 - cdf2).abs();
223 d_max = d_max.max(diff);
224
225 if sorted_baseline[i] <= sorted_current[j] {
226 i += 1;
227 } else {
228 j += 1;
229 }
230 }
231
232 let n_eff = (n1 * n2) / (n1 + n2);
234 let lambda = d_max * n_eff.sqrt();
235 let p_value = ks_p_value(lambda);
236
237 let (drifted, severity) = self.classify_result(p_value, threshold);
238
239 DriftResult {
240 feature: format!("feature_{feature_idx}"),
241 test: DriftTest::KS { threshold },
242 statistic: d_max,
243 p_value,
244 drifted,
245 severity,
246 }
247 }
248
249 fn psi_test(
251 &self,
252 feature_idx: usize,
253 baseline: &[f64],
254 current: &[f64],
255 threshold: f64,
256 ) -> DriftResult {
257 if baseline.is_empty() || current.is_empty() {
258 return DriftResult {
259 feature: format!("feature_{feature_idx}"),
260 test: DriftTest::PSI { threshold },
261 statistic: 0.0,
262 p_value: 0.0,
263 drifted: false,
264 severity: Severity::None,
265 };
266 }
267
268 let n_bins = 10;
270 let mut sorted_baseline: Vec<f64> = baseline.to_vec();
271 sorted_baseline.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273 let mut edges = Vec::with_capacity(n_bins + 1);
275 edges.push(f64::NEG_INFINITY);
276 for i in 1..n_bins {
277 let idx = (sorted_baseline.len() * i / n_bins.max(1)).min(sorted_baseline.len() - 1);
278 edges.push(sorted_baseline[idx]);
279 }
280 edges.push(f64::INFINITY);
281
282 let baseline_counts = bin_counts(baseline, &edges);
284 let current_counts = bin_counts(current, &edges);
285
286 let total_baseline = baseline.len() as f64;
288 let total_current = current.len() as f64;
289
290 let mut psi = 0.0;
291 for (b_count, c_count) in baseline_counts.iter().zip(current_counts.iter()) {
292 let b_pct = (*b_count as f64 + 0.0001) / (total_baseline + 0.001);
293 let c_pct = (*c_count as f64 + 0.0001) / (total_current + 0.001);
294 psi += (c_pct - b_pct) * (c_pct / b_pct).max(f64::MIN_POSITIVE).ln();
295 }
296
297 let (drifted, severity) = if psi >= threshold {
298 (true, Severity::Critical)
299 } else if psi >= threshold * self.warning_multiplier {
300 (true, Severity::Warning)
301 } else {
302 (false, Severity::None)
303 };
304
305 DriftResult {
306 feature: format!("feature_{feature_idx}"),
307 test: DriftTest::PSI { threshold },
308 statistic: psi,
309 p_value: psi, drifted,
311 severity,
312 }
313 }
314
315 fn chi_square_test(
317 &self,
318 feature_idx: usize,
319 baseline: &HashMap<usize, usize>,
320 current: &HashMap<usize, usize>,
321 threshold: f64,
322 ) -> DriftResult {
323 let mut categories: Vec<usize> = baseline.keys().chain(current.keys()).copied().collect();
325 categories.sort_unstable();
326 categories.dedup();
327
328 let total_baseline: f64 = baseline.values().sum::<usize>() as f64;
329 let total_current: f64 = current.values().sum::<usize>() as f64;
330
331 if total_baseline == 0.0 || total_current == 0.0 {
332 return DriftResult {
333 feature: format!("feature_{feature_idx}"),
334 test: DriftTest::ChiSquare { threshold },
335 statistic: 0.0,
336 p_value: 1.0,
337 drifted: false,
338 severity: Severity::None,
339 };
340 }
341
342 let mut chi_sq = 0.0;
344 let mut df: usize = 0;
345
346 for &cat in &categories {
347 let observed = *current.get(&cat).unwrap_or(&0) as f64;
348 let baseline_pct = *baseline.get(&cat).unwrap_or(&0) as f64 / total_baseline;
349 let expected = baseline_pct * total_current;
350
351 if expected > 0.0 {
352 chi_sq += (observed - expected).powi(2) / expected;
353 df += 1;
354 }
355 }
356
357 df = df.saturating_sub(1); let p_value = chi_square_p_value(chi_sq, df);
359
360 let (drifted, severity) = self.classify_result(p_value, threshold);
361
362 DriftResult {
363 feature: format!("feature_{feature_idx}"),
364 test: DriftTest::ChiSquare { threshold },
365 statistic: chi_sq,
366 p_value,
367 drifted,
368 severity,
369 }
370 }
371
372 fn classify_result(&self, p_value: f64, threshold: f64) -> (bool, Severity) {
374 if p_value < threshold {
375 (true, Severity::Critical)
376 } else if p_value < threshold / self.warning_multiplier {
377 (true, Severity::Warning)
378 } else {
379 (false, Severity::None)
380 }
381 }
382
383 pub fn summary(results: &[DriftResult]) -> DriftSummary {
385 let total = results.len();
386 let drifted = results.iter().filter(|r| r.drifted).count();
387 let warnings = results.iter().filter(|r| r.severity == Severity::Warning).count();
388 let critical = results.iter().filter(|r| r.severity == Severity::Critical).count();
389
390 DriftSummary { total_features: total, drifted_features: drifted, warnings, critical }
391 }
392}