Skip to main content

entrenar/eval/drift/
detector.rs

1//! Drift detector implementation.
2
3use std::collections::HashMap;
4
5use super::statistical::{bin_counts, chi_square_p_value, ks_p_value};
6use super::types::{
7    CategoricalBaseline, DriftCallback, DriftResult, DriftSummary, DriftTest, Severity,
8};
9
10/// Drift detector with statistical tests and callbacks
11pub struct DriftDetector {
12    tests: Vec<DriftTest>,
13    baseline: Option<Vec<Vec<f64>>>,
14    baseline_categorical: Option<CategoricalBaseline>,
15    warning_multiplier: f64,
16    callbacks: Vec<DriftCallback>,
17}
18
19impl DriftDetector {
20    /// Create a new drift detector with specified tests
21    pub fn new(tests: Vec<DriftTest>) -> Self {
22        Self {
23            tests,
24            baseline: None,
25            baseline_categorical: None,
26            warning_multiplier: 0.8, // Warning at 80% of threshold
27            callbacks: Vec::new(),
28        }
29    }
30
31    /// Register callback for drift events (Andon Cord)
32    ///
33    /// Callbacks are invoked when drift is detected via `check_and_trigger`.
34    pub fn on_drift<F>(&mut self, callback: F)
35    where
36        F: Fn(&[DriftResult]) + Send + Sync + 'static,
37    {
38        self.callbacks.push(Box::new(callback));
39    }
40
41    /// Check for drift and trigger callbacks if drift detected
42    ///
43    /// Returns the drift results and invokes all registered callbacks
44    /// if any feature shows drift.
45    pub fn check_and_trigger(&self, current: &[Vec<f64>]) -> Vec<DriftResult> {
46        let results = self.check(current);
47
48        // Check if any drift detected
49        let has_drift = results.iter().any(|r| r.drifted);
50
51        if has_drift {
52            for callback in &self.callbacks {
53                callback(&results);
54            }
55        }
56
57        results
58    }
59
60    /// Check categorical features for drift and trigger callbacks
61    pub fn check_categorical_and_trigger(&self, current: &[Vec<usize>]) -> Vec<DriftResult> {
62        let results = self.check_categorical(current);
63
64        let has_drift = results.iter().any(|r| r.drifted);
65
66        if has_drift {
67            for callback in &self.callbacks {
68                callback(&results);
69            }
70        }
71
72        results
73    }
74
75    /// Set baseline distribution for continuous features
76    /// Each row is a sample, each column is a feature
77    pub fn set_baseline(&mut self, data: &[Vec<f64>]) {
78        if data.is_empty() {
79            return;
80        }
81        // Transpose: store column-wise for easier feature comparison
82        let n_features = data[0].len();
83        let mut columns = vec![Vec::new(); n_features];
84        for row in data {
85            for (i, &val) in row.iter().enumerate() {
86                if i < n_features {
87                    columns[i].push(val);
88                }
89            }
90        }
91        self.baseline = Some(columns);
92    }
93
94    /// Set baseline distribution for categorical features
95    pub fn set_baseline_categorical(&mut self, data: &[Vec<usize>]) {
96        if data.is_empty() {
97            return;
98        }
99        let n_features = data[0].len();
100        let mut histograms = vec![HashMap::new(); n_features];
101        for row in data {
102            for (i, &val) in row.iter().enumerate() {
103                if i < n_features {
104                    *histograms[i].entry(val).or_insert(0) += 1;
105                }
106            }
107        }
108        self.baseline_categorical = Some(histograms);
109    }
110
111    /// Check new data for drift against baseline
112    pub fn check(&self, current: &[Vec<f64>]) -> Vec<DriftResult> {
113        let mut results = Vec::new();
114
115        let baseline = match &self.baseline {
116            Some(b) => b,
117            None => return results,
118        };
119
120        if current.is_empty() {
121            return results;
122        }
123
124        // Transpose current data to column-wise
125        let n_features = current[0].len().min(baseline.len());
126        let mut current_columns = vec![Vec::new(); n_features];
127        for row in current {
128            for (i, &val) in row.iter().enumerate() {
129                if i < n_features {
130                    current_columns[i].push(val);
131                }
132            }
133        }
134
135        // Run tests on each feature
136        for (feature_idx, (baseline_col, current_col)) in
137            baseline.iter().zip(current_columns.iter()).enumerate()
138        {
139            for test in &self.tests {
140                let result = match test {
141                    DriftTest::KS { threshold } => {
142                        self.ks_test(feature_idx, baseline_col, current_col, *threshold)
143                    }
144                    DriftTest::PSI { threshold } => {
145                        self.psi_test(feature_idx, baseline_col, current_col, *threshold)
146                    }
147                    DriftTest::ChiSquare { .. } => continue, // Skip for continuous
148                };
149                results.push(result);
150            }
151        }
152
153        results
154    }
155
156    /// Check categorical features for drift
157    pub fn check_categorical(&self, current: &[Vec<usize>]) -> Vec<DriftResult> {
158        let mut results = Vec::new();
159
160        let baseline = match &self.baseline_categorical {
161            Some(b) => b,
162            None => return results,
163        };
164
165        if current.is_empty() {
166            return results;
167        }
168
169        // Build current histograms
170        let n_features = current[0].len().min(baseline.len());
171        let mut current_histograms = vec![HashMap::new(); n_features];
172        for row in current {
173            for (i, &val) in row.iter().enumerate() {
174                if i < n_features {
175                    *current_histograms[i].entry(val).or_insert(0) += 1;
176                }
177            }
178        }
179
180        // Run chi-square test on each feature
181        for (feature_idx, (baseline_hist, current_hist)) in
182            baseline.iter().zip(current_histograms.iter()).enumerate()
183        {
184            for test in &self.tests {
185                if let DriftTest::ChiSquare { threshold } = test {
186                    let result =
187                        self.chi_square_test(feature_idx, baseline_hist, current_hist, *threshold);
188                    results.push(result);
189                }
190            }
191        }
192
193        results
194    }
195
196    /// Kolmogorov-Smirnov test for continuous features
197    fn ks_test(
198        &self,
199        feature_idx: usize,
200        baseline: &[f64],
201        current: &[f64],
202        threshold: f64,
203    ) -> DriftResult {
204        // Sort both distributions
205        let mut sorted_baseline: Vec<f64> = baseline.to_vec();
206        let mut sorted_current: Vec<f64> = current.to_vec();
207        sorted_baseline.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
208        sorted_current.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
209
210        // Compute empirical CDFs and find maximum difference
211        let n1 = sorted_baseline.len() as f64;
212        let n2 = sorted_current.len() as f64;
213
214        let mut d_max = 0.0f64;
215        let mut i = 0usize;
216        let mut j = 0usize;
217
218        while i < sorted_baseline.len() && j < sorted_current.len() {
219            let cdf1 = (i + 1) as f64 / n1;
220            let cdf2 = (j + 1) as f64 / n2;
221
222            let diff = (cdf1 - cdf2).abs();
223            d_max = d_max.max(diff);
224
225            if sorted_baseline[i] <= sorted_current[j] {
226                i += 1;
227            } else {
228                j += 1;
229            }
230        }
231
232        // Approximate p-value using asymptotic formula
233        let n_eff = (n1 * n2) / (n1 + n2);
234        let lambda = d_max * n_eff.sqrt();
235        let p_value = ks_p_value(lambda);
236
237        let (drifted, severity) = self.classify_result(p_value, threshold);
238
239        DriftResult {
240            feature: format!("feature_{feature_idx}"),
241            test: DriftTest::KS { threshold },
242            statistic: d_max,
243            p_value,
244            drifted,
245            severity,
246        }
247    }
248
249    /// Population Stability Index (PSI) test
250    fn psi_test(
251        &self,
252        feature_idx: usize,
253        baseline: &[f64],
254        current: &[f64],
255        threshold: f64,
256    ) -> DriftResult {
257        if baseline.is_empty() || current.is_empty() {
258            return DriftResult {
259                feature: format!("feature_{feature_idx}"),
260                test: DriftTest::PSI { threshold },
261                statistic: 0.0,
262                p_value: 0.0,
263                drifted: false,
264                severity: Severity::None,
265            };
266        }
267
268        // Create 10 bins based on baseline deciles
269        let n_bins = 10;
270        let mut sorted_baseline: Vec<f64> = baseline.to_vec();
271        sorted_baseline.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273        // Calculate bin edges (deciles)
274        let mut edges = Vec::with_capacity(n_bins + 1);
275        edges.push(f64::NEG_INFINITY);
276        for i in 1..n_bins {
277            let idx = (sorted_baseline.len() * i / n_bins.max(1)).min(sorted_baseline.len() - 1);
278            edges.push(sorted_baseline[idx]);
279        }
280        edges.push(f64::INFINITY);
281
282        // Count samples in each bin
283        let baseline_counts = bin_counts(baseline, &edges);
284        let current_counts = bin_counts(current, &edges);
285
286        // Calculate PSI
287        let total_baseline = baseline.len() as f64;
288        let total_current = current.len() as f64;
289
290        let mut psi = 0.0;
291        for (b_count, c_count) in baseline_counts.iter().zip(current_counts.iter()) {
292            let b_pct = (*b_count as f64 + 0.0001) / (total_baseline + 0.001);
293            let c_pct = (*c_count as f64 + 0.0001) / (total_current + 0.001);
294            psi += (c_pct - b_pct) * (c_pct / b_pct).max(f64::MIN_POSITIVE).ln();
295        }
296
297        let (drifted, severity) = if psi >= threshold {
298            (true, Severity::Critical)
299        } else if psi >= threshold * self.warning_multiplier {
300            (true, Severity::Warning)
301        } else {
302            (false, Severity::None)
303        };
304
305        DriftResult {
306            feature: format!("feature_{feature_idx}"),
307            test: DriftTest::PSI { threshold },
308            statistic: psi,
309            p_value: psi, // PSI doesn't use p-value, store the PSI value
310            drifted,
311            severity,
312        }
313    }
314
315    /// Chi-square test for categorical features
316    fn chi_square_test(
317        &self,
318        feature_idx: usize,
319        baseline: &HashMap<usize, usize>,
320        current: &HashMap<usize, usize>,
321        threshold: f64,
322    ) -> DriftResult {
323        // Get all categories
324        let mut categories: Vec<usize> = baseline.keys().chain(current.keys()).copied().collect();
325        categories.sort_unstable();
326        categories.dedup();
327
328        let total_baseline: f64 = baseline.values().sum::<usize>() as f64;
329        let total_current: f64 = current.values().sum::<usize>() as f64;
330
331        if total_baseline == 0.0 || total_current == 0.0 {
332            return DriftResult {
333                feature: format!("feature_{feature_idx}"),
334                test: DriftTest::ChiSquare { threshold },
335                statistic: 0.0,
336                p_value: 1.0,
337                drifted: false,
338                severity: Severity::None,
339            };
340        }
341
342        // Calculate chi-square statistic
343        let mut chi_sq = 0.0;
344        let mut df: usize = 0;
345
346        for &cat in &categories {
347            let observed = *current.get(&cat).unwrap_or(&0) as f64;
348            let baseline_pct = *baseline.get(&cat).unwrap_or(&0) as f64 / total_baseline;
349            let expected = baseline_pct * total_current;
350
351            if expected > 0.0 {
352                chi_sq += (observed - expected).powi(2) / expected;
353                df += 1;
354            }
355        }
356
357        df = df.saturating_sub(1); // degrees of freedom = categories - 1
358        let p_value = chi_square_p_value(chi_sq, df);
359
360        let (drifted, severity) = self.classify_result(p_value, threshold);
361
362        DriftResult {
363            feature: format!("feature_{feature_idx}"),
364            test: DriftTest::ChiSquare { threshold },
365            statistic: chi_sq,
366            p_value,
367            drifted,
368            severity,
369        }
370    }
371
372    /// Classify result based on p-value and threshold
373    fn classify_result(&self, p_value: f64, threshold: f64) -> (bool, Severity) {
374        if p_value < threshold {
375            (true, Severity::Critical)
376        } else if p_value < threshold / self.warning_multiplier {
377            (true, Severity::Warning)
378        } else {
379            (false, Severity::None)
380        }
381    }
382
383    /// Get summary of drift results
384    pub fn summary(results: &[DriftResult]) -> DriftSummary {
385        let total = results.len();
386        let drifted = results.iter().filter(|r| r.drifted).count();
387        let warnings = results.iter().filter(|r| r.severity == Severity::Warning).count();
388        let critical = results.iter().filter(|r| r.severity == Severity::Critical).count();
389
390        DriftSummary { total_features: total, drifted_features: drifted, warnings, critical }
391    }
392}