entrenar/storage/preflight/
validator.rs1use super::{
4 CheckMetadata, CheckResult, PreflightCheck, PreflightContext, PreflightError, PreflightResults,
5};
6
7#[derive(Debug, Default)]
11pub struct Preflight {
12 checks: Vec<PreflightCheck>,
14 context: PreflightContext,
16}
17
18impl Preflight {
19 pub fn new() -> Self {
21 Self::default()
22 }
23
24 pub fn standard() -> Self {
26 Self::new()
27 .add_check(PreflightCheck::no_nan_values())
28 .add_check(PreflightCheck::no_inf_values())
29 .add_check(PreflightCheck::consistent_dimensions())
30 .add_check(PreflightCheck::no_constant_features())
31 }
32
33 pub fn comprehensive() -> Self {
35 Self::standard()
36 .add_check(PreflightCheck::min_samples(10))
37 .add_check(PreflightCheck::min_features(1))
38 .add_check(PreflightCheck::disk_space_mb(100))
39 .add_check(PreflightCheck::memory_mb(256))
40 .add_check(PreflightCheck::gpu_available())
41 }
42
43 pub fn add_check(mut self, check: PreflightCheck) -> Self {
45 self.checks.push(check);
46 self
47 }
48
49 pub fn with_context(mut self, context: PreflightContext) -> Self {
51 self.context = context;
52 self
53 }
54
55 pub fn check_count(&self) -> usize {
57 self.checks.len()
58 }
59
60 pub fn run(&self, data: &[Vec<f64>]) -> PreflightResults {
62 let mut results = Vec::new();
63 let mut passed_count = 0;
64 let mut failed_count = 0;
65 let mut warning_count = 0;
66 let mut skipped_count = 0;
67 let mut all_required_passed = true;
68
69 for check in &self.checks {
70 let result = check.run(data, &self.context);
71
72 match &result {
73 CheckResult::Passed { .. } => passed_count += 1,
74 CheckResult::Failed { .. } => {
75 failed_count += 1;
76 if check.required {
77 all_required_passed = false;
78 }
79 }
80 CheckResult::Warning { .. } => warning_count += 1,
81 CheckResult::Skipped { .. } => skipped_count += 1,
82 }
83
84 results.push((CheckMetadata::from(check), result));
85 }
86
87 PreflightResults::new(
88 results,
89 all_required_passed,
90 passed_count,
91 failed_count,
92 warning_count,
93 skipped_count,
94 )
95 }
96
97 pub fn validate(&self, data: &[Vec<f64>]) -> Result<PreflightResults, PreflightError> {
99 let results = self.run(data);
100
101 if results.all_passed() {
102 Ok(results)
103 } else {
104 Err(PreflightError::ValidationFailed {
105 checks_failed: results.failed_count(),
106 total_checks: self.checks.len(),
107 })
108 }
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed validator has no checks registered.
    #[test]
    fn test_preflight_new() {
        assert_eq!(Preflight::new().check_count(), 0);
    }

    /// Each `add_check` call registers exactly one more check.
    #[test]
    fn test_preflight_add_check() {
        let validator = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::no_inf_values());

        assert_eq!(validator.check_count(), 2);
    }

    /// The standard preset registers at least three checks.
    #[test]
    fn test_preflight_standard() {
        let validator = Preflight::standard();
        assert!(validator.check_count() >= 3);
    }

    /// The comprehensive preset registers at least five checks.
    #[test]
    fn test_preflight_comprehensive() {
        let validator = Preflight::comprehensive();
        assert!(validator.check_count() >= 5);
    }

    /// Finite data with enough rows passes every registered check.
    #[test]
    fn test_preflight_run_all_pass() {
        let rows = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let validator = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::no_inf_values())
            .add_check(PreflightCheck::min_samples(2));

        let outcome = validator.run(&rows);

        assert!(outcome.all_passed());
        assert_eq!(outcome.passed_count(), 3);
        assert_eq!(outcome.failed_count(), 0);
    }

    /// Two rows against `min_samples(10)`: one check passes, one fails.
    #[test]
    fn test_preflight_run_with_failure() {
        let rows = [vec![1.0, 2.0], vec![3.0, 4.0]];
        let validator = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::min_samples(10));

        let outcome = validator.run(&rows);

        assert!(!outcome.all_passed());
        assert_eq!(outcome.passed_count(), 1);
        assert_eq!(outcome.failed_count(), 1);
    }

    /// The first column is constant (1.0 in both rows); the run is expected to
    /// record one warning while still reporting an overall pass.
    #[test]
    fn test_preflight_optional_check_doesnt_fail() {
        let rows = [vec![1.0, 2.0], vec![1.0, 4.0]];
        let validator = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::no_constant_features());

        let outcome = validator.run(&rows);

        assert!(outcome.all_passed());
        assert_eq!(outcome.warning_count(), 1);
    }

    /// `validate` yields `Ok` when the only check passes.
    #[test]
    fn test_preflight_validate_success() {
        let rows = [vec![1.0, 2.0]];
        let validator = Preflight::new().add_check(PreflightCheck::no_nan_values());

        assert!(validator.validate(&rows).is_ok());
    }

    /// `validate` yields `Err` when one row is checked against `min_samples(100)`.
    #[test]
    fn test_preflight_validate_failure() {
        let rows = [vec![1.0, 2.0]];
        let validator = Preflight::new().add_check(PreflightCheck::min_samples(100));

        assert!(validator.validate(&rows).is_err());
    }

    /// `failed_checks` exposes the failing check together with its metadata name.
    #[test]
    fn test_preflight_results_failed_checks() {
        let rows = [vec![1.0]];
        let validator = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::min_samples(10));

        let outcome = validator.run(&rows);
        let failures = outcome.failed_checks();

        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].0.name, "min_samples");
    }

    /// `warnings` lists the single warning raised for the constant first column.
    #[test]
    fn test_preflight_results_warnings() {
        let rows = [vec![1.0, 2.0], vec![1.0, 3.0]];
        let validator = Preflight::new().add_check(PreflightCheck::no_constant_features());

        let outcome = validator.run(&rows);

        assert_eq!(outcome.warnings().len(), 1);
    }

    /// The textual report contains the header, the status and the check name.
    #[test]
    fn test_preflight_results_report() {
        let rows = [vec![1.0, 2.0]];
        let validator = Preflight::new().add_check(PreflightCheck::no_nan_values());

        let text = validator.run(&rows).report();

        assert!(text.contains("Preflight Check Results"));
        assert!(text.contains("PASSED"));
        assert!(text.contains("no_nan_values"));
    }

    /// Three rows, a `min_samples(1)` check and a context built with
    /// `with_min_samples(5)`: the run is expected not to pass.
    // NOTE(review): presumably the context threshold takes precedence over the
    // check's own value — confirm in `PreflightCheck::run`.
    #[test]
    fn test_preflight_with_context() {
        let rows = [vec![1.0], vec![2.0], vec![3.0]];
        let context = PreflightContext::new().with_min_samples(5);
        let validator = Preflight::new()
            .add_check(PreflightCheck::min_samples(1))
            .with_context(context);

        let outcome = validator.run(&rows);

        assert!(!outcome.all_passed());
    }

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        /// The four per-status counters always add up to the number of checks run.
        #[test]
        fn prop_preflight_results_counts_consistent(
            n_checks in 1usize..10
        ) {
            // Register the same check `n_checks` times.
            let validator = (0..n_checks).fold(Preflight::new(), |acc, _| {
                acc.add_check(PreflightCheck::no_nan_values())
            });

            let rows = [vec![1.0, 2.0]];
            let outcome = validator.run(&rows);

            let total: usize = [
                outcome.passed_count(),
                outcome.failed_count(),
                outcome.warning_count(),
                outcome.skipped_count(),
            ]
            .iter()
            .sum();

            prop_assert_eq!(total, n_checks);
        }
    }
}