Skip to main content

entrenar/storage/preflight/
validator.rs

1//! Main Preflight validation system.
2
3use super::{
4    CheckMetadata, CheckResult, PreflightCheck, PreflightContext, PreflightError, PreflightResults,
5};
6
7/// Preflight validation system
8///
9/// Runs a series of checks before training to catch common issues early.
10#[derive(Debug, Default)]
11pub struct Preflight {
12    /// List of checks to run
13    checks: Vec<PreflightCheck>,
14    /// Context for checks
15    context: PreflightContext,
16}
17
18impl Preflight {
19    /// Create a new preflight validator
20    pub fn new() -> Self {
21        Self::default()
22    }
23
24    /// Create with standard data integrity checks
25    pub fn standard() -> Self {
26        Self::new()
27            .add_check(PreflightCheck::no_nan_values())
28            .add_check(PreflightCheck::no_inf_values())
29            .add_check(PreflightCheck::consistent_dimensions())
30            .add_check(PreflightCheck::no_constant_features())
31    }
32
33    /// Create with all checks (data + environment)
34    pub fn comprehensive() -> Self {
35        Self::standard()
36            .add_check(PreflightCheck::min_samples(10))
37            .add_check(PreflightCheck::min_features(1))
38            .add_check(PreflightCheck::disk_space_mb(100))
39            .add_check(PreflightCheck::memory_mb(256))
40            .add_check(PreflightCheck::gpu_available())
41    }
42
43    /// Add a check
44    pub fn add_check(mut self, check: PreflightCheck) -> Self {
45        self.checks.push(check);
46        self
47    }
48
49    /// Set context
50    pub fn with_context(mut self, context: PreflightContext) -> Self {
51        self.context = context;
52        self
53    }
54
55    /// Get the number of checks
56    pub fn check_count(&self) -> usize {
57        self.checks.len()
58    }
59
60    /// Run all checks
61    pub fn run(&self, data: &[Vec<f64>]) -> PreflightResults {
62        let mut results = Vec::new();
63        let mut passed_count = 0;
64        let mut failed_count = 0;
65        let mut warning_count = 0;
66        let mut skipped_count = 0;
67        let mut all_required_passed = true;
68
69        for check in &self.checks {
70            let result = check.run(data, &self.context);
71
72            match &result {
73                CheckResult::Passed { .. } => passed_count += 1,
74                CheckResult::Failed { .. } => {
75                    failed_count += 1;
76                    if check.required {
77                        all_required_passed = false;
78                    }
79                }
80                CheckResult::Warning { .. } => warning_count += 1,
81                CheckResult::Skipped { .. } => skipped_count += 1,
82            }
83
84            results.push((CheckMetadata::from(check), result));
85        }
86
87        PreflightResults::new(
88            results,
89            all_required_passed,
90            passed_count,
91            failed_count,
92            warning_count,
93            skipped_count,
94        )
95    }
96
97    /// Run checks and return error if any required check fails
98    pub fn validate(&self, data: &[Vec<f64>]) -> Result<PreflightResults, PreflightError> {
99        let results = self.run(data);
100
101        if results.all_passed() {
102            Ok(results)
103        } else {
104            Err(PreflightError::ValidationFailed {
105                checks_failed: results.failed_count(),
106                total_checks: self.checks.len(),
107            })
108        }
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // =========================================================================
    // Preflight Tests
    // =========================================================================

    #[test]
    fn test_preflight_new() {
        let preflight = Preflight::new();
        assert_eq!(preflight.check_count(), 0);
    }

    #[test]
    fn test_preflight_add_check() {
        let preflight = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::no_inf_values());
        assert_eq!(preflight.check_count(), 2);
    }

    #[test]
    fn test_preflight_standard() {
        // NaN, Inf, consistent dimensions, constant features.
        let preflight = Preflight::standard();
        assert_eq!(preflight.check_count(), 4);
    }

    #[test]
    fn test_preflight_comprehensive() {
        // The standard checks plus min samples, min features, disk, memory and GPU.
        let preflight = Preflight::comprehensive();
        assert_eq!(preflight.check_count(), Preflight::standard().check_count() + 5);
    }

    #[test]
    fn test_preflight_run_all_pass() {
        let preflight = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::no_inf_values())
            .add_check(PreflightCheck::min_samples(2));

        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let results = preflight.run(&data);

        assert!(results.all_passed());
        assert_eq!(results.passed_count(), 3);
        assert_eq!(results.failed_count(), 0);
    }

    #[test]
    fn test_preflight_run_with_failure() {
        let preflight = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::min_samples(10));

        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let results = preflight.run(&data);

        assert!(!results.all_passed());
        assert_eq!(results.passed_count(), 1);
        assert_eq!(results.failed_count(), 1);
    }

    #[test]
    fn test_preflight_optional_check_doesnt_fail() {
        let preflight = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::no_constant_features()); // Optional

        let data = vec![vec![1.0, 2.0], vec![1.0, 4.0]]; // First column constant
        let results = preflight.run(&data);

        // Should pass because constant features check is optional
        assert!(results.all_passed());
        assert_eq!(results.warning_count(), 1);
    }

    #[test]
    fn test_preflight_validate_success() {
        let preflight = Preflight::new().add_check(PreflightCheck::no_nan_values());
        let data = vec![vec![1.0, 2.0]];
        match preflight.validate(&data) {
            // The Ok path hands back the full results of the run.
            Ok(results) => assert_eq!(results.passed_count(), 1),
            Err(_) => panic!("NaN-free data should pass validation"),
        }
    }

    #[test]
    fn test_preflight_validate_failure() {
        let preflight = Preflight::new().add_check(PreflightCheck::min_samples(100));
        let data = vec![vec![1.0, 2.0]];
        let result = preflight.validate(&data);
        // The error must report the failing check and the total registered.
        assert!(matches!(
            result,
            Err(PreflightError::ValidationFailed { checks_failed: 1, total_checks: 1 })
        ));
    }

    #[test]
    fn test_preflight_results_failed_checks() {
        let preflight = Preflight::new()
            .add_check(PreflightCheck::no_nan_values())
            .add_check(PreflightCheck::min_samples(10));

        let data = vec![vec![1.0]];
        let results = preflight.run(&data);

        let failed = results.failed_checks();
        assert_eq!(failed.len(), 1);
        assert_eq!(failed[0].0.name, "min_samples");
    }

    #[test]
    fn test_preflight_results_warnings() {
        let preflight = Preflight::new().add_check(PreflightCheck::no_constant_features());

        let data = vec![vec![1.0, 2.0], vec![1.0, 3.0]];
        let results = preflight.run(&data);

        let warnings = results.warnings();
        assert_eq!(warnings.len(), 1);
    }

    #[test]
    fn test_preflight_results_report() {
        let preflight = Preflight::new().add_check(PreflightCheck::no_nan_values());

        let data = vec![vec![1.0, 2.0]];
        let results = preflight.run(&data);
        let report = results.report();

        assert!(report.contains("Preflight Check Results"));
        assert!(report.contains("PASSED"));
        assert!(report.contains("no_nan_values"));
    }

    #[test]
    fn test_preflight_with_context() {
        let ctx = PreflightContext::new().with_min_samples(5);
        let preflight =
            Preflight::new().add_check(PreflightCheck::min_samples(1)).with_context(ctx);

        let data = vec![vec![1.0], vec![2.0], vec![3.0]];
        let results = preflight.run(&data);

        // Context min_samples=5 should override check's default of 1
        assert!(!results.all_passed());
    }

    // =========================================================================
    // Property Tests
    // =========================================================================

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn prop_preflight_results_counts_consistent(
            n_checks in 1usize..10
        ) {
            let mut preflight = Preflight::new();
            for _ in 0..n_checks {
                preflight = preflight.add_check(PreflightCheck::no_nan_values());
            }

            let data = vec![vec![1.0, 2.0]];
            let results = preflight.run(&data);

            let total = results.passed_count()
                + results.failed_count()
                + results.warning_count()
                + results.skipped_count();

            // Every check lands in exactly one bucket...
            prop_assert_eq!(total, n_checks);
            // ...and on NaN-free data every NaN check passes.
            prop_assert_eq!(results.passed_count(), n_checks);
        }
    }
}