Skip to main content

entrenar/monitor/
drift.rs

1//! Drift Detection Module (ENT-044)
2//!
3//! Detects training anomalies using statistical methods.
4//! Based on renacer's sliding window baseline patterns.
5
/// Outcome of scoring one observation with a [`DriftDetector`].
///
/// Both non-trivial variants carry the two-tailed p-value that produced them.
#[derive(PartialEq, Clone, Debug)]
pub enum DriftStatus {
    /// The observation is consistent with the baseline (or the baseline is still warming up).
    NoDrift,
    /// The p-value fell below the warning threshold but not below the drift threshold.
    Warning(f64),
    /// The p-value fell below the drift threshold.
    Drift(f64),
}
16
/// Severity of a detected [`Anomaly`], bucketed by absolute z-score (levels from renacer).
///
/// Variants are declared from least to most severe, so the derived ordering
/// ranks `Low < Medium < High` (e.g. `severity >= AnomalySeverity::Medium`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AnomalySeverity {
    /// Below 4 standard deviations, but at or above the caller's detection
    /// threshold (e.g. 3).
    Low,
    /// At least 4 but below 5 standard deviations.
    Medium,
    /// 5 or more standard deviations.
    High,
}
27
28/// Sliding window baseline for anomaly detection
29#[derive(Debug, Clone)]
30pub struct SlidingWindowBaseline {
31    window_size: usize,
32    values: Vec<f64>,
33    mean: f64,
34    m2: f64, // For Welford's algorithm
35    count: usize,
36}
37
38impl SlidingWindowBaseline {
39    /// Create a new baseline with given window size
40    pub fn new(window_size: usize) -> Self {
41        Self { window_size, values: Vec::with_capacity(window_size), mean: 0.0, m2: 0.0, count: 0 }
42    }
43
44    /// Update baseline with new value
45    pub fn update(&mut self, value: f64) {
46        contract_pre_update!();
47        if value.is_nan() || value.is_infinite() {
48            return;
49        }
50
51        // Add to window
52        if self.values.len() >= self.window_size {
53            self.values.remove(0);
54        }
55        self.values.push(value);
56
57        // Recalculate stats (simplified - could be optimized)
58        self.count = self.values.len();
59        if self.count > 0 {
60            self.mean = self.values.iter().sum::<f64>() / self.count as f64;
61            if self.count > 1 {
62                self.m2 = self.values.iter().map(|v| (v - self.mean).powi(2)).sum::<f64>();
63            }
64        }
65    }
66
67    /// Get current standard deviation
68    pub fn std(&self) -> f64 {
69        if self.count < 2 {
70            return 0.0;
71        }
72        (self.m2 / (self.count - 1) as f64).sqrt()
73    }
74
75    /// Calculate z-score for a value
76    pub fn z_score(&self, value: f64) -> f64 {
77        let std = self.std();
78        if std == 0.0 {
79            return 0.0;
80        }
81        (value - self.mean) / std
82    }
83
84    /// Detect anomaly with threshold (standard deviations)
85    pub fn detect_anomaly(&self, value: f64, threshold: f64) -> Option<Anomaly> {
86        if self.count < 10 {
87            return None; // Not enough data
88        }
89
90        let z = self.z_score(value).abs();
91        if z < threshold {
92            return None;
93        }
94
95        let severity = if z >= 5.0 {
96            AnomalySeverity::High
97        } else if z >= 4.0 {
98            AnomalySeverity::Medium
99        } else {
100            AnomalySeverity::Low
101        };
102
103        Some(Anomaly {
104            value,
105            z_score: z,
106            severity,
107            baseline_mean: self.mean,
108            baseline_std: self.std(),
109        })
110    }
111
112    /// Get current mean
113    pub fn mean(&self) -> f64 {
114        self.mean
115    }
116
117    /// Get sample count
118    pub fn count(&self) -> usize {
119        self.count
120    }
121}
122
123/// Detected anomaly
124#[derive(Debug, Clone)]
125pub struct Anomaly {
126    /// The anomalous value
127    pub value: f64,
128    /// Z-score (number of standard deviations from mean)
129    pub z_score: f64,
130    /// Severity classification
131    pub severity: AnomalySeverity,
132    /// Baseline mean when anomaly was detected
133    pub baseline_mean: f64,
134    /// Baseline standard deviation
135    pub baseline_std: f64,
136}
137
138/// Drift detector using statistical tests
139#[derive(Debug)]
140pub struct DriftDetector {
141    baseline: SlidingWindowBaseline,
142    threshold: f64,
143    warning_threshold: f64,
144}
145
146impl DriftDetector {
147    /// Create a new drift detector
148    pub fn new(window_size: usize) -> Self {
149        Self {
150            baseline: SlidingWindowBaseline::new(window_size),
151            threshold: 0.05,        // p < 0.05 for drift
152            warning_threshold: 0.1, // p < 0.1 for warning
153        }
154    }
155
156    /// Set detection thresholds
157    pub fn with_thresholds(mut self, warning: f64, drift: f64) -> Self {
158        self.warning_threshold = warning;
159        self.threshold = drift;
160        self
161    }
162
163    /// Update baseline and check for drift
164    pub fn check(&mut self, value: f64) -> DriftStatus {
165        // Get z-score before updating
166        let z = self.baseline.z_score(value).abs();
167
168        // Update baseline
169        self.baseline.update(value);
170
171        if self.baseline.count() < 10 {
172            return DriftStatus::NoDrift;
173        }
174
175        // Convert z-score to approximate p-value
176        let p = z_to_p(z);
177
178        if p < self.threshold {
179            DriftStatus::Drift(p)
180        } else if p < self.warning_threshold {
181            DriftStatus::Warning(p)
182        } else {
183            DriftStatus::NoDrift
184        }
185    }
186}
187
/// Two-tailed p-value for a standard-normal z-score.
///
/// Evaluates the Abramowitz & Stegun 26.2.17 polynomial approximation of the
/// upper-tail probability, `Q(|z|) ≈ φ(z)·t·(b1 + t·(b2 + t·(b3 + t·(b4 + t·b5))))`
/// with `t = 1 / (1 + p·|z|)`, and doubles it. The result is symmetric in `z`,
/// close to 1.0 at `z = 0`, tends to 0 as `|z|` grows, and is NaN for NaN input.
fn z_to_p(z: f64) -> f64 {
    const P: f64 = 0.2316419;
    const INV_SQRT_2PI: f64 = 0.3989423;
    const B1: f64 = 0.3193815;
    const B2: f64 = -0.3565638;
    const B3: f64 = 1.781478;
    const B4: f64 = -1.821256;
    const B5: f64 = 1.330274;

    let t = 1.0 / (1.0 + P * z.abs());
    // Standard normal density φ(z).
    let density = INV_SQRT_2PI * (-z * z / 2.0).exp();
    let poly = B1 + t * (B2 + t * (B3 + t * (B4 + t * B5)));
    let upper_tail = density * t * poly;

    upper_tail * 2.0
}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 values cycling through 48, 49, 50, 51, 52: the mean is exactly 50.0 and
    /// the sample standard deviation is sqrt(200 / 99) ≈ 1.4213.
    fn symmetric_baseline() -> SlidingWindowBaseline {
        let mut baseline = SlidingWindowBaseline::new(100);
        for i in 0..100 {
            baseline.update(50.0 + f64::from(i % 5 - 2));
        }
        baseline
    }

    #[test]
    fn test_sliding_window_new() {
        let baseline = SlidingWindowBaseline::new(100);
        assert_eq!(baseline.count(), 0);
    }

    #[test]
    fn test_sliding_window_update() {
        let mut baseline = SlidingWindowBaseline::new(100);
        for i in 0..10 {
            baseline.update(f64::from(i));
        }
        assert_eq!(baseline.count(), 10);
        assert!((baseline.mean() - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_sliding_window_rolls() {
        let mut baseline = SlidingWindowBaseline::new(5);
        for i in 0..10 {
            baseline.update(f64::from(i));
        }
        // Window should contain [5, 6, 7, 8, 9]
        assert_eq!(baseline.count(), 5);
        assert!((baseline.mean() - 7.0).abs() < 1e-6);
    }

    #[test]
    fn test_z_score() {
        let mut baseline = SlidingWindowBaseline::new(100);
        // Values 0..=99: mean = 49.5, sample std ≈ 29.01
        for i in 0..100 {
            baseline.update(f64::from(i));
        }

        // A value at the mean has z = 0; a value next to it has |z| close to 0
        assert!(baseline.z_score(49.5).abs() < 1e-12);
        assert!(baseline.z_score(50.0).abs() < 0.5);
    }

    #[test]
    fn test_detect_anomaly_none() {
        let mut baseline = SlidingWindowBaseline::new(100);
        for i in 0..100 {
            baseline.update(50.0 + f64::from(i % 5));
        }

        // Normal value
        let anomaly = baseline.detect_anomaly(52.0, 3.0);
        assert!(anomaly.is_none());
    }

    #[test]
    fn test_detect_anomaly_high() {
        let mut baseline = SlidingWindowBaseline::new(100);
        // Add values with some variance so std > 0
        for i in 0..100 {
            baseline.update(50.0 + f64::from(i % 10));
        }

        // Extreme outlier (far from mean 54.5, std ≈ 2.89)
        let anomaly = baseline.detect_anomaly(100.0, 3.0);
        assert!(anomaly.is_some());
        let a = anomaly.expect("operation should succeed");
        assert!(a.z_score > 5.0);
        assert_eq!(a.severity, AnomalySeverity::High);
    }

    #[test]
    fn test_drift_detector_no_drift() {
        let mut detector = DriftDetector::new(100);

        // Stable values
        for _ in 0..50 {
            let status = detector.check(50.0);
            assert_eq!(status, DriftStatus::NoDrift);
        }
    }

    #[test]
    fn test_drift_detector_detects_drift() {
        let mut detector = DriftDetector::new(100);

        // Establish baseline with some variance
        for i in 0..100 {
            detector.check(50.0 + f64::from(i % 10));
        }

        // Sudden large change (mean 54.5, value 200 is ~50 std devs away)
        let status = detector.check(200.0);
        assert!(
            matches!(status, DriftStatus::Drift(p) if p < 0.05),
            "Expected drift, got {status:?}"
        );
    }

    #[test]
    fn test_anomaly_severity_low() {
        let baseline = symmetric_baseline();
        // (55 - 50) / 1.4213 ≈ 3.52 standard deviations
        let a = baseline
            .detect_anomaly(55.0, 3.0)
            .expect("z ≈ 3.5 should exceed the threshold of 3");
        assert!(a.z_score >= 3.0 && a.z_score < 4.0);
        assert_eq!(a.severity, AnomalySeverity::Low);
    }

    #[test]
    fn test_detect_anomaly_constant_baseline() {
        let mut baseline = SlidingWindowBaseline::new(100);
        for _ in 0..100 {
            baseline.update(50.0);
        }
        // With zero spread the z-score is defined as 0, so nothing is flagged
        assert_eq!(baseline.std(), 0.0);
        assert!(baseline.detect_anomaly(1000.0, 3.0).is_none());
    }

    // =========================================================================
    // Additional Coverage Tests
    // =========================================================================

    #[test]
    fn test_update_with_nan() {
        let mut baseline = SlidingWindowBaseline::new(100);
        baseline.update(1.0);
        baseline.update(f64::NAN);
        baseline.update(2.0);
        // NaN should be ignored
        assert_eq!(baseline.count(), 2);
    }

    #[test]
    fn test_update_with_infinity() {
        let mut baseline = SlidingWindowBaseline::new(100);
        baseline.update(1.0);
        baseline.update(f64::INFINITY);
        baseline.update(f64::NEG_INFINITY);
        baseline.update(2.0);
        // Infinities should be ignored
        assert_eq!(baseline.count(), 2);
    }

    #[test]
    fn test_std_with_single_value() {
        let mut baseline = SlidingWindowBaseline::new(100);
        baseline.update(42.0);
        // With single value, std should be 0
        assert_eq!(baseline.std(), 0.0);
    }

    #[test]
    fn test_z_score_zero_std() {
        let mut baseline = SlidingWindowBaseline::new(100);
        baseline.update(5.0);
        baseline.update(5.0);
        // With constant values, std=0, z_score should be 0
        assert_eq!(baseline.z_score(10.0), 0.0);
    }

    #[test]
    fn test_detect_anomaly_not_enough_data() {
        let mut baseline = SlidingWindowBaseline::new(100);
        for i in 0..5 {
            baseline.update(f64::from(i));
        }
        // Less than 10 values, should return None
        let anomaly = baseline.detect_anomaly(100.0, 3.0);
        assert!(anomaly.is_none());
    }

    #[test]
    fn test_anomaly_severity_medium() {
        let baseline = symmetric_baseline();
        // (56 - 50) / 1.4213 ≈ 4.22 standard deviations
        let a = baseline
            .detect_anomaly(56.0, 3.0)
            .expect("z ≈ 4.2 should exceed the threshold of 3");
        assert!(a.z_score >= 4.0 && a.z_score < 5.0);
        assert_eq!(a.severity, AnomalySeverity::Medium);
        assert!((a.baseline_mean - 50.0).abs() < 1e-12);
        assert!((a.baseline_std - (200.0_f64 / 99.0).sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_drift_detector_with_thresholds() {
        let detector = DriftDetector::new(100).with_thresholds(0.15, 0.08);
        // Just verify the builder works
        assert_eq!(detector.threshold, 0.08);
        assert_eq!(detector.warning_threshold, 0.15);
    }

    #[test]
    fn test_drift_status_warning() {
        let mut detector = DriftDetector::new(50).with_thresholds(0.3, 0.05); // More lenient warning

        // Baseline of 50 values: mean 54.5, sample std ≈ 2.90
        for i in 0..50 {
            detector.check(50.0 + f64::from(i % 10));
        }

        // (59 - 54.5) / 2.90 ≈ 1.55 → two-tailed p ≈ 0.12, inside [0.05, 0.3)
        let status = detector.check(59.0);
        assert!(
            matches!(status, DriftStatus::Warning(p) if (0.05..0.3).contains(&p)),
            "Expected warning, got {status:?}"
        );
    }

    #[test]
    fn test_z_to_p_approximation() {
        // Two-tailed p-values of the standard normal distribution
        assert!((z_to_p(0.0) - 1.0).abs() < 1e-6);
        assert!((z_to_p(1.96) - 0.05).abs() < 1e-3);
        assert!((z_to_p(2.576) - 0.01).abs() < 1e-3);
        // Symmetric in z and vanishing in the tails
        assert_eq!(z_to_p(-1.5), z_to_p(1.5));
        assert!(z_to_p(10.0) < 1e-15);
    }

    #[test]
    fn test_drift_status_eq() {
        assert_eq!(DriftStatus::NoDrift, DriftStatus::NoDrift);
        assert_ne!(DriftStatus::NoDrift, DriftStatus::Drift(0.01));
        assert_ne!(DriftStatus::Warning(0.08), DriftStatus::Drift(0.08));
    }

    #[test]
    fn test_anomaly_clone() {
        let anomaly = Anomaly {
            value: 100.0,
            z_score: 5.0,
            severity: AnomalySeverity::High,
            baseline_mean: 50.0,
            baseline_std: 10.0,
        };
        let cloned = anomaly.clone();
        assert_eq!(anomaly.value, cloned.value);
        assert_eq!(anomaly.severity, cloned.severity);
    }

    #[test]
    fn test_sliding_window_baseline_clone() {
        let mut baseline = SlidingWindowBaseline::new(50);
        baseline.update(1.0);
        baseline.update(2.0);
        let cloned = baseline.clone();
        assert_eq!(baseline.count(), cloned.count());
        assert_eq!(baseline.mean(), cloned.mean());
    }
}
445}