entrenar_inspect/
validate.rs

1//! Model validation (Andon principle - surface problems immediately).
2
3use entrenar_common::Result;
4use std::path::Path;
5
6/// Result of model validation.
7#[derive(Debug, Clone)]
8pub struct ValidationResult {
9    /// Whether the model is valid
10    pub valid: bool,
11    /// List of issues found
12    pub issues: Vec<ValidationIssue>,
13    /// List of warnings
14    pub warnings: Vec<String>,
15    /// Validation checks performed
16    pub checks: Vec<ValidationCheck>,
17}
18
19impl ValidationResult {
20    /// Check if there are any errors (not just warnings).
21    pub fn has_errors(&self) -> bool {
22        self.issues.iter().any(|i| i.severity == Severity::Error)
23    }
24
25    /// Format as human-readable report.
26    pub fn to_report(&self) -> String {
27        let mut report = String::new();
28
29        report.push_str(&format!(
30            "Validation Result: {}\n\n",
31            if self.valid { "PASS" } else { "FAIL" }
32        ));
33
34        if !self.issues.is_empty() {
35            report.push_str("Issues:\n");
36            for issue in &self.issues {
37                let prefix = match issue.severity {
38                    Severity::Error => "✗",
39                    Severity::Warning => "⚠",
40                    Severity::Info => "ℹ",
41                };
42                report.push_str(&format!("  {} {}: {}\n", prefix, issue.code, issue.message));
43                if let Some(suggestion) = &issue.suggestion {
44                    report.push_str(&format!("    → {suggestion}\n"));
45                }
46            }
47            report.push('\n');
48        }
49
50        report.push_str("Checks Performed:\n");
51        for check in &self.checks {
52            let status = if check.passed { "✓" } else { "✗" };
53            report.push_str(&format!("  {} {}\n", status, check.name));
54        }
55
56        report
57    }
58}
59
60/// A validation issue.
61#[derive(Debug, Clone)]
62pub struct ValidationIssue {
63    /// Issue code for programmatic handling
64    pub code: String,
65    /// Human-readable message
66    pub message: String,
67    /// Severity level
68    pub severity: Severity,
69    /// Actionable suggestion
70    pub suggestion: Option<String>,
71    /// Affected tensor name (if applicable)
72    pub tensor: Option<String>,
73}
74
/// Severity level of a validation issue.
///
/// Only [`Severity::Error`] counts towards `ValidationResult::has_errors`.
/// `Hash` is derived so severities can be used as set/map keys (e.g. to count
/// issues per level).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Severity {
    /// Critical error - model unusable
    Error,
    /// Warning - model usable but may have issues
    Warning,
    /// Informational - not a problem
    Info,
}
85
/// A validation check that was performed, with its outcome and timing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationCheck {
    /// Check name (e.g. `"File exists"`)
    pub name: String,
    /// Whether check passed
    pub passed: bool,
    /// Wall-clock duration of the check in whole milliseconds
    pub duration_ms: u64,
}
96
/// Model integrity checker.
///
/// Construct with `IntegrityChecker::new()` and opt into strict mode with
/// `.strict()`, then call `validate` on a model path.
#[derive(Debug, Clone)]
pub struct IntegrityChecker {
    /// Strict-mode flag; see `IntegrityChecker::strict`.
    strict: bool,
}
101
102impl Default for IntegrityChecker {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl IntegrityChecker {
109    /// Create a new integrity checker.
110    pub fn new() -> Self {
111        Self { strict: false }
112    }
113
114    /// Enable strict mode (treat warnings as errors).
115    pub fn strict(mut self) -> Self {
116        self.strict = true;
117        self
118    }
119
120    /// Validate a model file.
121    pub fn validate(&self, path: &Path) -> Result<ValidationResult> {
122        let mut issues = Vec::new();
123        let mut warnings = Vec::new();
124        let mut checks = Vec::new();
125
126        // Check file exists
127        let file_check = self.check_file_exists(path);
128        checks.push(file_check.clone());
129        if !file_check.passed {
130            issues.push(ValidationIssue {
131                code: "V001".to_string(),
132                message: format!("File not found: {}", path.display()),
133                severity: Severity::Error,
134                suggestion: Some("Check the file path".to_string()),
135                tensor: None,
136            });
137            return Ok(ValidationResult {
138                valid: false,
139                issues,
140                warnings,
141                checks,
142            });
143        }
144
145        // Check format
146        let format_check = self.check_format(path);
147        checks.push(format_check.clone());
148        if !format_check.passed {
149            issues.push(ValidationIssue {
150                code: "V002".to_string(),
151                message: "Unsupported or potentially unsafe format".to_string(),
152                severity: if self.strict {
153                    Severity::Error
154                } else {
155                    Severity::Warning
156                },
157                suggestion: Some("Use SafeTensors format for security".to_string()),
158                tensor: None,
159            });
160        }
161
162        // Check file size
163        let size_check = self.check_file_size(path);
164        checks.push(size_check.clone());
165        if !size_check.passed {
166            warnings.push("File size is unusually small - may be corrupted".to_string());
167        }
168
169        // In real implementation, would also check:
170        // - Tensor shapes consistency
171        // - NaN/Inf values
172        // - Data type consistency
173        // - Architecture constraints
174
175        let valid = !issues.iter().any(|i| i.severity == Severity::Error);
176
177        Ok(ValidationResult {
178            valid,
179            issues,
180            warnings,
181            checks,
182        })
183    }
184
185    fn check_file_exists(&self, path: &Path) -> ValidationCheck {
186        let start = std::time::Instant::now();
187        let passed = path.exists();
188        ValidationCheck {
189            name: "File exists".to_string(),
190            passed,
191            duration_ms: start.elapsed().as_millis() as u64,
192        }
193    }
194
195    fn check_format(&self, path: &Path) -> ValidationCheck {
196        let start = std::time::Instant::now();
197        let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
198        let passed = matches!(
199            extension.to_lowercase().as_str(),
200            "safetensors" | "gguf" | "apr"
201        );
202        ValidationCheck {
203            name: "Safe format".to_string(),
204            passed,
205            duration_ms: start.elapsed().as_millis() as u64,
206        }
207    }
208
209    fn check_file_size(&self, path: &Path) -> ValidationCheck {
210        let start = std::time::Instant::now();
211        let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
212        let passed = size > 1000; // At least 1KB
213        ValidationCheck {
214            name: "Valid file size".to_string(),
215            passed,
216            duration_ms: start.elapsed().as_millis() as u64,
217        }
218    }
219}
220
221/// Validate a model file.
222pub fn validate_model(path: impl AsRef<Path>) -> Result<ValidationResult> {
223    IntegrityChecker::new().validate(path.as_ref())
224}
225
#[cfg(test)]
mod tests {
    // Unit tests for model validation. Filesystem-backed tests write zero-filled
    // temp files whose extension and length drive the format and size checks.
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temp file with the given suffix containing `len` zero bytes.
    ///
    /// The returned handle must be kept alive: dropping it deletes the file.
    fn temp_model(suffix: &str, len: usize) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        file.write_all(&vec![0u8; len]).unwrap();
        file
    }

    /// Look up a performed check by name, failing the test if it did not run.
    fn find_check<'a>(result: &'a ValidationResult, name: &str) -> &'a ValidationCheck {
        result
            .checks
            .iter()
            .find(|c| c.name == name)
            .unwrap_or_else(|| panic!("check {name:?} was not performed"))
    }

    #[test]
    fn test_validation_missing_file() {
        let result = validate_model("/nonexistent/model.safetensors").unwrap();
        assert!(!result.valid);
        assert!(result.issues.iter().any(|i| i.code == "V001"));
        assert!(!find_check(&result, "File exists").passed);
    }

    #[test]
    fn test_validation_safe_format() {
        let file = temp_model(".safetensors", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(find_check(&result, "Safe format").passed);
        assert!(result.valid);
    }

    #[test]
    fn test_validation_unsafe_format() {
        let file = temp_model(".pt", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(!find_check(&result, "Safe format").passed);

        // In non-strict mode, unsafe format is a warning not error
        let issue = result
            .issues
            .iter()
            .find(|i| i.code == "V002")
            .expect("unsafe format should be reported as V002");
        assert_eq!(issue.severity, Severity::Warning);
        assert!(result.valid);
    }

    #[test]
    fn test_strict_mode() {
        let file = temp_model(".pt", 2000);

        let result = IntegrityChecker::new()
            .strict()
            .validate(file.path())
            .unwrap();
        assert!(!result.valid); // Unsafe format is error in strict mode
        assert!(result
            .issues
            .iter()
            .any(|i| i.code == "V002" && i.severity == Severity::Error));
    }

    #[test]
    fn test_validation_report() {
        let result = ValidationResult {
            valid: false,
            issues: vec![ValidationIssue {
                code: "V001".to_string(),
                message: "Test error".to_string(),
                severity: Severity::Error,
                suggestion: Some("Fix it".to_string()),
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![ValidationCheck {
                name: "Test check".to_string(),
                passed: false,
                duration_ms: 1,
            }],
        };

        let report = result.to_report();
        assert!(report.contains("FAIL"));
        assert!(report.contains("V001"));
        assert!(report.contains("Fix it"));
    }

    #[test]
    fn test_validation_report_pass() {
        let result = ValidationResult {
            valid: true,
            issues: vec![],
            warnings: vec![],
            checks: vec![ValidationCheck {
                name: "Test check".to_string(),
                passed: true,
                duration_ms: 1,
            }],
        };

        let report = result.to_report();
        assert!(report.contains("PASS"));
        assert!(report.contains("✓"));
    }

    #[test]
    fn test_has_errors_with_error() {
        let result = ValidationResult {
            valid: false,
            issues: vec![ValidationIssue {
                code: "V001".to_string(),
                message: "Error".to_string(),
                severity: Severity::Error,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };
        assert!(result.has_errors());
    }

    #[test]
    fn test_has_errors_with_warning_only() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V002".to_string(),
                message: "Warning".to_string(),
                severity: Severity::Warning,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };
        assert!(!result.has_errors());
    }

    #[test]
    fn test_has_errors_with_info_only() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V003".to_string(),
                message: "Info".to_string(),
                severity: Severity::Info,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };
        assert!(!result.has_errors());
    }

    #[test]
    fn test_integrity_checker_default() {
        let checker = IntegrityChecker::default();
        // Should be non-strict by default
        let file = temp_model(".pt", 2000);

        let result = checker.validate(file.path()).unwrap();
        // Non-strict mode: unsafe format is warning, not error
        assert!(result.valid);
    }

    #[test]
    fn test_severity_equality() {
        assert_eq!(Severity::Error, Severity::Error);
        assert_ne!(Severity::Error, Severity::Warning);
        assert_ne!(Severity::Warning, Severity::Info);
    }

    #[test]
    fn test_report_warning_symbol() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V002".to_string(),
                message: "Warning".to_string(),
                severity: Severity::Warning,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };

        let report = result.to_report();
        assert!(report.contains("⚠"));
    }

    #[test]
    fn test_report_info_symbol() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V003".to_string(),
                message: "Info".to_string(),
                severity: Severity::Info,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };

        let report = result.to_report();
        assert!(report.contains("ℹ"));
    }

    #[test]
    fn test_validation_small_file_warning() {
        let file = temp_model(".safetensors", 100); // Very small file

        let result = validate_model(file.path()).unwrap();
        // Small file should generate a warning and a failed size check...
        assert!(!result.warnings.is_empty(), "small file should produce a warning");
        assert!(!find_check(&result, "Valid file size").passed);
        // ...but in non-strict mode it does not invalidate the model on its own.
        assert!(result.valid);
    }

    #[test]
    fn test_validation_gguf_format() {
        let file = temp_model(".gguf", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(find_check(&result, "Safe format").passed);
    }

    #[test]
    fn test_validation_apr_format() {
        let file = temp_model(".apr", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(find_check(&result, "Safe format").passed);
    }
}