1use entrenar_common::Result;
4use std::path::Path;
5
/// Outcome of validating a model file: an overall verdict plus everything that
/// was found and every check that ran.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Overall pass/fail verdict. The integrity checker sets this to `true` only
    /// when no issue of [`Severity::Error`] was recorded.
    pub valid: bool,
    /// Structured findings, each with a code, message and severity.
    pub issues: Vec<ValidationIssue>,
    /// Free-form warning messages that are not tied to an issue code.
    pub warnings: Vec<String>,
    /// Every check that was performed, in execution order.
    pub checks: Vec<ValidationCheck>,
}

impl ValidationResult {
    /// Returns `true` if any recorded issue has [`Severity::Error`].
    pub fn has_errors(&self) -> bool {
        self.issues.iter().any(|i| i.severity == Severity::Error)
    }

    /// Renders a human-readable, multi-line report.
    ///
    /// The report has the following sections, in order:
    /// - the PASS/FAIL verdict;
    /// - the issues, if any, each with its optional suggestion;
    /// - the free-form warnings, if any;
    /// - the list of checks performed, each marked ✓ or ✗.
    pub fn to_report(&self) -> String {
        // `fmt::Write` lets us format straight into the String. Writing to a
        // String cannot fail, so the returned `fmt::Result`s are ignored.
        use std::fmt::Write as _;

        let mut report = String::new();

        let verdict = if self.valid { "PASS" } else { "FAIL" };
        let _ = writeln!(report, "Validation Result: {verdict}\n");

        if !self.issues.is_empty() {
            report.push_str("Issues:\n");
            for issue in &self.issues {
                let prefix = match issue.severity {
                    Severity::Error => "✗",
                    Severity::Warning => "⚠",
                    Severity::Info => "ℹ",
                };
                let _ = writeln!(report, " {} {}: {}", prefix, issue.code, issue.message);
                if let Some(suggestion) = &issue.suggestion {
                    let _ = writeln!(report, " → {suggestion}");
                }
            }
            report.push('\n');
        }

        // Free-form warnings (e.g. a suspiciously small file) are collected
        // separately from `issues`; without this section they would be
        // invisible in the rendered report.
        if !self.warnings.is_empty() {
            report.push_str("Warnings:\n");
            for warning in &self.warnings {
                let _ = writeln!(report, " ⚠ {warning}");
            }
            report.push('\n');
        }

        report.push_str("Checks Performed:\n");
        for check in &self.checks {
            let status = if check.passed { "✓" } else { "✗" };
            let _ = writeln!(report, " {} {}", status, check.name);
        }

        report
    }
}

/// A single structured finding produced during validation.
#[derive(Debug, Clone)]
pub struct ValidationIssue {
    /// Short issue code, e.g. `V001` (missing file) or `V002` (unsupported format).
    pub code: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// How serious the finding is; only [`Severity::Error`] fails validation.
    pub severity: Severity,
    /// Optional hint on how to resolve the issue.
    pub suggestion: Option<String>,
    /// Optional tensor identifier the issue relates to (presumably a tensor
    /// name — confirm with producers; file-level checks leave it `None`).
    pub tensor: Option<String>,
}

/// Severity of a [`ValidationIssue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Severity {
    /// Fails validation (see [`ValidationResult::has_errors`]).
    Error,
    /// Reported, but does not by itself fail validation.
    Warning,
    /// Informational only.
    Info,
}

/// Record of one check that was run, whether it passed, and how long it took.
#[derive(Debug, Clone)]
pub struct ValidationCheck {
    /// Display name of the check, e.g. `"Safe format"`.
    pub name: String,
    /// Whether the check passed.
    pub passed: bool,
    /// Time spent running the check, in milliseconds.
    pub duration_ms: u64,
}
96
/// Runs file-level integrity checks (existence, format, size) on a model file.
///
/// Create one with `IntegrityChecker::new()` or `Default`, and optionally
/// switch it to strict mode with `IntegrityChecker::strict`.
#[derive(Debug, Clone)]
pub struct IntegrityChecker {
    /// When `true`, an unsupported format is reported as an error rather than
    /// a warning, which makes the overall result invalid.
    strict: bool,
}
101
102impl Default for IntegrityChecker {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl IntegrityChecker {
109 pub fn new() -> Self {
111 Self { strict: false }
112 }
113
114 pub fn strict(mut self) -> Self {
116 self.strict = true;
117 self
118 }
119
120 pub fn validate(&self, path: &Path) -> Result<ValidationResult> {
122 let mut issues = Vec::new();
123 let mut warnings = Vec::new();
124 let mut checks = Vec::new();
125
126 let file_check = self.check_file_exists(path);
128 checks.push(file_check.clone());
129 if !file_check.passed {
130 issues.push(ValidationIssue {
131 code: "V001".to_string(),
132 message: format!("File not found: {}", path.display()),
133 severity: Severity::Error,
134 suggestion: Some("Check the file path".to_string()),
135 tensor: None,
136 });
137 return Ok(ValidationResult {
138 valid: false,
139 issues,
140 warnings,
141 checks,
142 });
143 }
144
145 let format_check = self.check_format(path);
147 checks.push(format_check.clone());
148 if !format_check.passed {
149 issues.push(ValidationIssue {
150 code: "V002".to_string(),
151 message: "Unsupported or potentially unsafe format".to_string(),
152 severity: if self.strict {
153 Severity::Error
154 } else {
155 Severity::Warning
156 },
157 suggestion: Some("Use SafeTensors format for security".to_string()),
158 tensor: None,
159 });
160 }
161
162 let size_check = self.check_file_size(path);
164 checks.push(size_check.clone());
165 if !size_check.passed {
166 warnings.push("File size is unusually small - may be corrupted".to_string());
167 }
168
169 let valid = !issues.iter().any(|i| i.severity == Severity::Error);
176
177 Ok(ValidationResult {
178 valid,
179 issues,
180 warnings,
181 checks,
182 })
183 }
184
185 fn check_file_exists(&self, path: &Path) -> ValidationCheck {
186 let start = std::time::Instant::now();
187 let passed = path.exists();
188 ValidationCheck {
189 name: "File exists".to_string(),
190 passed,
191 duration_ms: start.elapsed().as_millis() as u64,
192 }
193 }
194
195 fn check_format(&self, path: &Path) -> ValidationCheck {
196 let start = std::time::Instant::now();
197 let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
198 let passed = matches!(
199 extension.to_lowercase().as_str(),
200 "safetensors" | "gguf" | "apr"
201 );
202 ValidationCheck {
203 name: "Safe format".to_string(),
204 passed,
205 duration_ms: start.elapsed().as_millis() as u64,
206 }
207 }
208
209 fn check_file_size(&self, path: &Path) -> ValidationCheck {
210 let start = std::time::Instant::now();
211 let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
212 let passed = size > 1000; ValidationCheck {
214 name: "Valid file size".to_string(),
215 passed,
216 duration_ms: start.elapsed().as_millis() as u64,
217 }
218 }
219}
220
221pub fn validate_model(path: impl AsRef<Path>) -> Result<ValidationResult> {
223 IntegrityChecker::new().validate(path.as_ref())
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file with the given suffix holding `len` zero bytes.
    /// The file is deleted when the returned handle is dropped.
    fn temp_model(suffix: &str, len: usize) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        file.write_all(&vec![0u8; len]).unwrap();
        file
    }

    /// Finds the check called `name`, panicking with a clear message if it never ran.
    fn find_check<'a>(result: &'a ValidationResult, name: &str) -> &'a ValidationCheck {
        result
            .checks
            .iter()
            .find(|c| c.name == name)
            .unwrap_or_else(|| panic!("check {name:?} was not performed"))
    }

    /// Returns the severity of the first issue with the given code, if any.
    fn issue_severity(result: &ValidationResult, code: &str) -> Option<Severity> {
        result
            .issues
            .iter()
            .find(|i| i.code == code)
            .map(|i| i.severity)
    }

    #[test]
    fn test_validation_missing_file() {
        let result = validate_model("/nonexistent/model.safetensors").unwrap();
        assert!(!result.valid);
        assert!(result.issues.iter().any(|i| i.code == "V001"));
    }

    #[test]
    fn test_validation_safe_format() {
        let file = temp_model(".safetensors", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(find_check(&result, "Safe format").passed);
        assert!(result.valid);
    }

    #[test]
    fn test_validation_unsafe_format() {
        let file = temp_model(".pt", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(!find_check(&result, "Safe format").passed);
        // Non-strict mode downgrades the format problem to a warning.
        assert_eq!(issue_severity(&result, "V002"), Some(Severity::Warning));
        assert!(result.valid);
    }

    #[test]
    fn test_strict_mode() {
        let file = temp_model(".pt", 2000);

        let result = IntegrityChecker::new()
            .strict()
            .validate(file.path())
            .unwrap();
        assert!(!result.valid);
        assert_eq!(issue_severity(&result, "V002"), Some(Severity::Error));
    }

    #[test]
    fn test_validation_report() {
        let result = ValidationResult {
            valid: false,
            issues: vec![ValidationIssue {
                code: "V001".to_string(),
                message: "Test error".to_string(),
                severity: Severity::Error,
                suggestion: Some("Fix it".to_string()),
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![ValidationCheck {
                name: "Test check".to_string(),
                passed: false,
                duration_ms: 1,
            }],
        };

        let report = result.to_report();
        assert!(report.contains("FAIL"));
        assert!(report.contains("V001"));
        assert!(report.contains("Fix it"));
    }

    #[test]
    fn test_validation_report_pass() {
        let result = ValidationResult {
            valid: true,
            issues: vec![],
            warnings: vec![],
            checks: vec![ValidationCheck {
                name: "Test check".to_string(),
                passed: true,
                duration_ms: 1,
            }],
        };

        let report = result.to_report();
        assert!(report.contains("PASS"));
        assert!(report.contains("✓"));
    }

    #[test]
    fn test_has_errors_with_error() {
        let result = ValidationResult {
            valid: false,
            issues: vec![ValidationIssue {
                code: "V001".to_string(),
                message: "Error".to_string(),
                severity: Severity::Error,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };
        assert!(result.has_errors());
    }

    #[test]
    fn test_has_errors_with_warning_only() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V002".to_string(),
                message: "Warning".to_string(),
                severity: Severity::Warning,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };
        assert!(!result.has_errors());
    }

    #[test]
    fn test_has_errors_with_info_only() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V003".to_string(),
                message: "Info".to_string(),
                severity: Severity::Info,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };
        assert!(!result.has_errors());
    }

    #[test]
    fn test_integrity_checker_default() {
        let checker = IntegrityChecker::default();
        let file = temp_model(".pt", 2000);

        let result = checker.validate(file.path()).unwrap();
        // The default checker is lenient: an unsupported format is only a warning.
        assert!(result.valid);
        assert_eq!(issue_severity(&result, "V002"), Some(Severity::Warning));
    }

    #[test]
    fn test_severity_equality() {
        assert_eq!(Severity::Error, Severity::Error);
        assert_ne!(Severity::Error, Severity::Warning);
        assert_ne!(Severity::Warning, Severity::Info);
    }

    #[test]
    fn test_report_warning_symbol() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V002".to_string(),
                message: "Warning".to_string(),
                severity: Severity::Warning,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };

        let report = result.to_report();
        assert!(report.contains("⚠"));
    }

    #[test]
    fn test_report_info_symbol() {
        let result = ValidationResult {
            valid: true,
            issues: vec![ValidationIssue {
                code: "V003".to_string(),
                message: "Info".to_string(),
                severity: Severity::Info,
                suggestion: None,
                tensor: None,
            }],
            warnings: vec![],
            checks: vec![],
        };

        let report = result.to_report();
        assert!(report.contains("ℹ"));
    }

    #[test]
    fn test_validation_small_file_warning() {
        let file = temp_model(".safetensors", 100);

        let result = validate_model(file.path()).unwrap();
        // Both effects of the size check must be present, not just either one.
        assert!(!find_check(&result, "Valid file size").passed);
        assert!(result.warnings.iter().any(|w| w.contains("unusually small")));
        // A small file is a warning only; it does not invalidate the result.
        assert!(result.valid);
    }

    #[test]
    fn test_validation_gguf_format() {
        let file = temp_model(".gguf", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(find_check(&result, "Safe format").passed);
    }

    #[test]
    fn test_validation_apr_format() {
        let file = temp_model(".apr", 2000);

        let result = validate_model(file.path()).unwrap();
        assert!(find_check(&result, "Safe format").passed);
    }
}