// entrenar/eval/drift/types.rs
use std::collections::HashMap;
4
/// Statistical test selected for drift detection, paired with its threshold.
///
/// Every variant carries a single `threshold` value. This type only stores
/// and exposes it; how the threshold is compared against a statistic or
/// p-value is decided by the code that consumes the variant.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DriftTest {
    /// Kolmogorov-Smirnov test.
    KS {
        /// Threshold configured for this test.
        threshold: f64,
    },
    /// Chi-Square test.
    ChiSquare {
        /// Threshold configured for this test.
        threshold: f64,
    },
    /// PSI test (presumably Population Stability Index — confirm with the
    /// detector implementation).
    PSI {
        /// Threshold configured for this test.
        threshold: f64,
    },
}

impl DriftTest {
    /// Returns the human-readable name of the test.
    ///
    /// The returned string is a static literal and does not depend on the
    /// configured threshold.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            Self::KS { .. } => "Kolmogorov-Smirnov",
            Self::ChiSquare { .. } => "Chi-Square",
            Self::PSI { .. } => "PSI",
        }
    }

    /// Returns the threshold stored in the variant, whichever variant it is.
    #[must_use]
    pub fn threshold(&self) -> f64 {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match *self {
            Self::KS { threshold }
            | Self::ChiSquare { threshold }
            | Self::PSI { threshold } => threshold,
        }
    }
}
35
/// Severity level attached to a drift finding.
///
/// Variants are declared from least to most severe, so the derived ordering
/// gives `None < Warning < Critical`. This lets callers aggregate findings
/// with `max()` or compare against a minimum level.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Lowest level.
    None,
    /// Intermediate level.
    Warning,
    /// Highest level.
    Critical,
}
46
47#[derive(Clone, Debug)]
49pub struct DriftResult {
50 pub feature: String,
52 pub test: DriftTest,
54 pub statistic: f64,
56 pub p_value: f64,
58 pub drifted: bool,
60 pub severity: Severity,
62}
63
/// Aggregate counts describing a set of drift results.
///
/// This is a plain data record with public fields; the type does not check
/// that the counts are mutually consistent. `Default` yields an all-zero
/// summary.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DriftSummary {
    /// Total number of features covered by the summary.
    pub total_features: usize,
    /// Number of features flagged as drifted.
    pub drifted_features: usize,
    /// Number of findings counted as warnings.
    pub warnings: usize,
    /// Number of findings counted as critical.
    pub critical: usize,
}

impl DriftSummary {
    /// Returns `true` when at least one critical finding was counted.
    #[must_use]
    pub fn has_critical(&self) -> bool {
        self.critical > 0
    }

    /// Returns `true` when at least one feature was flagged as drifted.
    #[must_use]
    pub fn has_drift(&self) -> bool {
        self.drifted_features > 0
    }

    /// Returns the share of drifted features as a percentage of
    /// `total_features`.
    ///
    /// Returns `0.0` when `total_features` is zero, so an empty summary never
    /// yields `NaN` from a `0 / 0` division.
    #[must_use]
    pub fn drift_percentage(&self) -> f64 {
        if self.total_features == 0 {
            return 0.0;
        }
        // Feature counts are far below 2^53, so the usize -> f64 conversion
        // is exact in practice.
        100.0 * self.drifted_features as f64 / self.total_features as f64
    }
}
97
98pub type DriftCallback = Box<dyn Fn(&[DriftResult]) + Send + Sync>;
100
/// Baseline data for categorical inputs: a list of `usize -> usize` maps.
///
/// NOTE(review): presumably one map per feature, keyed by category id with
/// the observed count as the value; confirm against the code that builds it.
pub type CategoricalBaseline = Vec<HashMap<usize, usize>>;
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point values.
    const EPS: f64 = 1e-9;

    /// Builds a `DriftSummary` from positional counts to keep the tests terse.
    fn summary(
        total_features: usize,
        drifted_features: usize,
        warnings: usize,
        critical: usize,
    ) -> DriftSummary {
        DriftSummary { total_features, drifted_features, warnings, critical }
    }

    // ---- DriftTest::name ----

    #[test]
    fn test_drift_test_ks_name() {
        assert_eq!(DriftTest::KS { threshold: 0.05 }.name(), "Kolmogorov-Smirnov");
    }

    #[test]
    fn test_drift_test_chi_square_name() {
        assert_eq!(DriftTest::ChiSquare { threshold: 0.05 }.name(), "Chi-Square");
    }

    #[test]
    fn test_drift_test_psi_name() {
        assert_eq!(DriftTest::PSI { threshold: 0.1 }.name(), "PSI");
    }

    // ---- DriftTest::threshold ----

    #[test]
    fn test_drift_test_ks_threshold() {
        let configured = DriftTest::KS { threshold: 0.05 };
        assert!((configured.threshold() - 0.05).abs() < EPS);
    }

    #[test]
    fn test_drift_test_chi_square_threshold() {
        let configured = DriftTest::ChiSquare { threshold: 0.01 };
        assert!((configured.threshold() - 0.01).abs() < EPS);
    }

    #[test]
    fn test_drift_test_psi_threshold() {
        let configured = DriftTest::PSI { threshold: 0.25 };
        assert!((configured.threshold() - 0.25).abs() < EPS);
    }

    // ---- DriftTest derived traits ----

    #[test]
    fn test_drift_test_clone() {
        // `DriftTest` is `Copy`, so a plain assignment duplicates the value
        // and leaves the original usable.
        let original = DriftTest::KS { threshold: 0.05 };
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    #[test]
    fn test_drift_test_debug() {
        let rendered = format!("{:?}", DriftTest::KS { threshold: 0.05 });
        assert!(rendered.contains("KS"));
        assert!(rendered.contains("threshold"));
    }

    // ---- Severity ----

    #[test]
    fn test_severity_none() {
        let level = Severity::None;
        assert_eq!(level, Severity::None);
    }

    #[test]
    fn test_severity_warning() {
        let level = Severity::Warning;
        assert_eq!(level, Severity::Warning);
    }

    #[test]
    fn test_severity_critical() {
        let level = Severity::Critical;
        assert_eq!(level, Severity::Critical);
    }

    #[test]
    fn test_severity_clone() {
        // `Severity` is `Copy`; assignment duplicates it.
        let original = Severity::Warning;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    #[test]
    fn test_severity_debug() {
        assert_eq!(format!("{:?}", Severity::None), "None");
        assert_eq!(format!("{:?}", Severity::Warning), "Warning");
        assert_eq!(format!("{:?}", Severity::Critical), "Critical");
    }

    // ---- DriftSummary queries ----

    #[test]
    fn test_drift_summary_has_critical() {
        assert!(summary(10, 3, 2, 1).has_critical());
        assert!(!summary(10, 2, 2, 0).has_critical());
    }

    #[test]
    fn test_drift_summary_has_drift() {
        assert!(summary(10, 3, 3, 0).has_drift());
        assert!(!summary(10, 0, 0, 0).has_drift());
    }

    #[test]
    fn test_drift_summary_drift_percentage() {
        let pct = summary(10, 3, 2, 1).drift_percentage();
        assert!((pct - 30.0).abs() < EPS);
    }

    #[test]
    fn test_drift_summary_drift_percentage_zero_features() {
        // With no features the percentage is defined as 0.0 rather than NaN.
        let pct = summary(0, 0, 0, 0).drift_percentage();
        assert!((pct - 0.0).abs() < EPS);
    }

    // ---- Clone / Debug on the record types ----

    #[test]
    fn test_drift_result_clone() {
        let original = DriftResult {
            feature: "age".to_string(),
            test: DriftTest::KS { threshold: 0.05 },
            statistic: 0.15,
            p_value: 0.02,
            drifted: true,
            severity: Severity::Warning,
        };
        let duplicate = original.clone();
        assert_eq!(original.feature, duplicate.feature);
        assert_eq!(original.drifted, duplicate.drifted);
    }

    #[test]
    fn test_drift_summary_clone() {
        let original = summary(10, 3, 2, 1);
        let duplicate = original.clone();
        assert_eq!(original.total_features, duplicate.total_features);
    }

    #[test]
    fn test_drift_summary_debug() {
        let rendered = format!("{:?}", summary(10, 3, 2, 1));
        assert!(rendered.contains("DriftSummary"));
        assert!(rendered.contains("total_features"));
    }
}