Skip to main content

entrenar/eval/drift/
types.rs

//! Type definitions for drift detection: statistical tests, severity levels, per-feature results, and summary counts.
2
3use std::collections::HashMap;
4
/// Statistical test used to decide whether a feature has drifted.
///
/// Each variant stores its own `threshold`; read it uniformly via [`DriftTest::threshold`].
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum DriftTest {
    /// Kolmogorov-Smirnov test, suited to continuous features.
    KS { threshold: f64 },
    /// Chi-square test, suited to categorical features.
    ChiSquare { threshold: f64 },
    /// Population Stability Index, a widely used industry drift metric.
    PSI { threshold: f64 },
}

impl DriftTest {
    /// Returns the display name of the test (e.g. `"Kolmogorov-Smirnov"` for [`DriftTest::KS`]).
    pub fn name(&self) -> &'static str {
        match *self {
            Self::KS { .. } => "Kolmogorov-Smirnov",
            Self::ChiSquare { .. } => "Chi-Square",
            Self::PSI { .. } => "PSI",
        }
    }

    /// Returns the threshold stored in whichever variant `self` is.
    pub fn threshold(&self) -> f64 {
        // Every variant has a `threshold` field, so this or-pattern is irrefutable.
        let (Self::KS { threshold } | Self::ChiSquare { threshold } | Self::PSI { threshold }) =
            *self;
        threshold
    }
}
35
/// How serious a detected drift is; variants are listed from least to most severe.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Severity {
    /// No drift was detected.
    None,
    /// Possible drift: log a warning and keep going.
    Warning,
    /// Severe drift: stop inference or trigger a retrain.
    Critical,
}
46
47/// Drift detection result
48#[derive(Clone, Debug)]
49pub struct DriftResult {
50    /// Feature name or index
51    pub feature: String,
52    /// Test used for detection
53    pub test: DriftTest,
54    /// Test statistic value
55    pub statistic: f64,
56    /// P-value (for KS and Chi-sq) or PSI value
57    pub p_value: f64,
58    /// Whether drift was detected
59    pub drifted: bool,
60    /// Severity of the drift
61    pub severity: Severity,
62}
63
/// Aggregate counts describing the outcome of a drift-detection run.
#[derive(Clone, Debug)]
pub struct DriftSummary {
    /// How many features were checked in total.
    pub total_features: usize,
    /// How many of those features showed drift.
    pub drifted_features: usize,
    /// How many drifts were rated at warning level.
    pub warnings: usize,
    /// How many drifts were rated at critical level.
    pub critical: usize,
}

impl DriftSummary {
    /// `true` when at least one critical-level drift was counted.
    pub fn has_critical(&self) -> bool {
        self.critical != 0
    }

    /// `true` when at least one feature drifted (at warning or critical level).
    pub fn has_drift(&self) -> bool {
        self.drifted_features != 0
    }

    /// Share of checked features that drifted, expressed as a percentage in `0.0..=100.0`
    /// (assuming `drifted_features <= total_features`).
    ///
    /// Returns `0.0` when no features were checked, avoiding a division by zero.
    pub fn drift_percentage(&self) -> f64 {
        match self.total_features {
            0 => 0.0,
            total => self.drifted_features as f64 * 100.0 / total as f64,
        }
    }
}
97
98/// Callback type for drift events (Andon Cord)
99pub type DriftCallback = Box<dyn Fn(&[DriftResult]) + Send + Sync>;
100
/// Reference histograms for categorical features.
///
/// One map per feature; each map is keyed by category (as a `usize`) and stores that
/// category's count in the baseline data.
pub type CategoricalBaseline = Vec<HashMap<usize, usize>>;
103
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drift_test_ks_name() {
        let test = DriftTest::KS { threshold: 0.05 };
        assert_eq!(test.name(), "Kolmogorov-Smirnov");
    }

    #[test]
    fn test_drift_test_chi_square_name() {
        let test = DriftTest::ChiSquare { threshold: 0.05 };
        assert_eq!(test.name(), "Chi-Square");
    }

    #[test]
    fn test_drift_test_psi_name() {
        let test = DriftTest::PSI { threshold: 0.1 };
        assert_eq!(test.name(), "PSI");
    }

    #[test]
    fn test_drift_test_ks_threshold() {
        let test = DriftTest::KS { threshold: 0.05 };
        assert!((test.threshold() - 0.05).abs() < 1e-9);
    }

    #[test]
    fn test_drift_test_chi_square_threshold() {
        let test = DriftTest::ChiSquare { threshold: 0.01 };
        assert!((test.threshold() - 0.01).abs() < 1e-9);
    }

    #[test]
    fn test_drift_test_psi_threshold() {
        let test = DriftTest::PSI { threshold: 0.25 };
        assert!((test.threshold() - 0.25).abs() < 1e-9);
    }

    #[test]
    fn test_drift_test_clone() {
        let test = DriftTest::KS { threshold: 0.05 };
        // `DriftTest` is `Copy`, so plain assignment produces an independent duplicate.
        let cloned = test;
        assert_eq!(test, cloned);
        // Equality must distinguish both the threshold and the variant.
        assert_ne!(test, DriftTest::KS { threshold: 0.1 });
        assert_ne!(test, DriftTest::PSI { threshold: 0.05 });
    }

    #[test]
    fn test_drift_test_debug() {
        let test = DriftTest::KS { threshold: 0.05 };
        let debug_str = format!("{test:?}");
        assert!(debug_str.contains("KS"));
        assert!(debug_str.contains("threshold"));
    }

    // A self-comparison alone can never fail, so each severity test also checks that the
    // variant differs from the other two.
    #[test]
    fn test_severity_none() {
        assert_eq!(Severity::None, Severity::None);
        assert_ne!(Severity::None, Severity::Warning);
        assert_ne!(Severity::None, Severity::Critical);
    }

    #[test]
    fn test_severity_warning() {
        assert_eq!(Severity::Warning, Severity::Warning);
        assert_ne!(Severity::Warning, Severity::None);
        assert_ne!(Severity::Warning, Severity::Critical);
    }

    #[test]
    fn test_severity_critical() {
        assert_eq!(Severity::Critical, Severity::Critical);
        assert_ne!(Severity::Critical, Severity::None);
        assert_ne!(Severity::Critical, Severity::Warning);
    }

    #[test]
    fn test_severity_clone() {
        let sev = Severity::Warning;
        let cloned = sev;
        assert_eq!(sev, cloned);
        assert_ne!(cloned, Severity::Critical);
    }

    #[test]
    fn test_severity_debug() {
        assert_eq!(format!("{:?}", Severity::None), "None");
        assert_eq!(format!("{:?}", Severity::Warning), "Warning");
        assert_eq!(format!("{:?}", Severity::Critical), "Critical");
    }

    #[test]
    fn test_drift_summary_has_critical() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        assert!(summary.has_critical());

        let no_critical =
            DriftSummary { total_features: 10, drifted_features: 2, warnings: 2, critical: 0 };
        assert!(!no_critical.has_critical());
    }

    #[test]
    fn test_drift_summary_has_drift() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 3, critical: 0 };
        assert!(summary.has_drift());

        let no_drift =
            DriftSummary { total_features: 10, drifted_features: 0, warnings: 0, critical: 0 };
        assert!(!no_drift.has_drift());
    }

    #[test]
    fn test_drift_summary_drift_percentage() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        assert!((summary.drift_percentage() - 30.0).abs() < 1e-9);

        let quarter =
            DriftSummary { total_features: 4, drifted_features: 1, warnings: 1, critical: 0 };
        assert!((quarter.drift_percentage() - 25.0).abs() < 1e-9);
    }

    #[test]
    fn test_drift_summary_drift_percentage_zero_features() {
        let summary =
            DriftSummary { total_features: 0, drifted_features: 0, warnings: 0, critical: 0 };
        assert!((summary.drift_percentage() - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_drift_result_clone() {
        let result = DriftResult {
            feature: "age".to_string(),
            test: DriftTest::KS { threshold: 0.05 },
            statistic: 0.15,
            p_value: 0.02,
            drifted: true,
            severity: Severity::Warning,
        };
        let cloned = result.clone();
        // `DriftResult` does not derive `PartialEq`, so compare every field explicitly;
        // floats are compared bit-for-bit since a clone must be exact.
        assert_eq!(result.feature, cloned.feature);
        assert_eq!(result.test, cloned.test);
        assert_eq!(result.statistic.to_bits(), cloned.statistic.to_bits());
        assert_eq!(result.p_value.to_bits(), cloned.p_value.to_bits());
        assert_eq!(result.drifted, cloned.drifted);
        assert_eq!(result.severity, cloned.severity);
    }

    #[test]
    fn test_drift_summary_clone() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        let cloned = summary.clone();
        assert_eq!(summary.total_features, cloned.total_features);
        assert_eq!(summary.drifted_features, cloned.drifted_features);
        assert_eq!(summary.warnings, cloned.warnings);
        assert_eq!(summary.critical, cloned.critical);
    }

    #[test]
    fn test_drift_summary_debug() {
        let summary =
            DriftSummary { total_features: 10, drifted_features: 3, warnings: 2, critical: 1 };
        let debug_str = format!("{summary:?}");
        assert!(debug_str.contains("DriftSummary"));
        assert!(debug_str.contains("total_features"));
    }
}