1use anyhow::{anyhow, Context, Result};
12use blake3::Hasher as Blake3Hasher;
13use scirs2_core::ndarray_ext::{ArrayView, Ix2};
14use serde::{Deserialize, Serialize};
15use sha2::{Digest, Sha256};
16use std::collections::HashMap;
17use std::fs::File;
18use std::io::{BufReader, Read};
19use std::path::{Path, PathBuf};
20use tokio::fs::File as AsyncFile;
21use tokio::io::AsyncReadExt;
22use tracing::{debug, error, info};
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26pub enum ChecksumAlgorithm {
27 Sha256,
29 Blake3,
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35pub enum ModelFormat {
36 Onnx,
38 SafeTensors,
40 OxirsEmbed,
42 PyTorch,
44 TensorFlow,
46 Unknown,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ValidationConfig {
53 pub checksum_algorithm: ChecksumAlgorithm,
55 pub validate_checksum: bool,
57 pub validate_dimensions: bool,
59 pub validate_signature: bool,
61 pub validate_metadata: bool,
63 pub expected_format: Option<ModelFormat>,
65 pub required_metadata_fields: Vec<String>,
67}
68
69impl Default for ValidationConfig {
70 fn default() -> Self {
71 Self {
72 checksum_algorithm: ChecksumAlgorithm::Blake3,
73 validate_checksum: true,
74 validate_dimensions: true,
75 validate_signature: true,
76 validate_metadata: true,
77 expected_format: None,
78 required_metadata_fields: vec![
79 "model_name".to_string(),
80 "embedding_dim".to_string(),
81 "version".to_string(),
82 ],
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ModelMetadata {
90 pub model_name: String,
92 pub version: String,
94 pub embedding_dim: usize,
96 pub input_dim: Option<usize>,
98 pub output_dim: Option<usize>,
100 pub format: ModelFormat,
102 pub checksum: Option<String>,
104 pub checksum_algorithm: Option<ChecksumAlgorithm>,
106 pub extra: HashMap<String, serde_json::Value>,
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub enum ValidationStatus {
113 Valid,
115 Invalid,
117 Skipped,
119 Warning,
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct ValidationResult {
126 pub validation_type: String,
128 pub status: ValidationStatus,
130 pub message: String,
132 pub details: Option<serde_json::Value>,
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ValidationReport {
139 pub model_path: PathBuf,
141 pub overall_status: ValidationStatus,
143 pub results: Vec<ValidationResult>,
145 pub timestamp: chrono::DateTime<chrono::Utc>,
147}
148
149impl ValidationReport {
150 pub fn is_valid(&self) -> bool {
152 self.overall_status == ValidationStatus::Valid
153 || self.overall_status == ValidationStatus::Warning
154 }
155
156 pub fn failed_validations(&self) -> Vec<&ValidationResult> {
158 self.results
159 .iter()
160 .filter(|r| r.status == ValidationStatus::Invalid)
161 .collect()
162 }
163}
164
165pub struct ModelValidator {
167 config: ValidationConfig,
168}
169
170impl ModelValidator {
171 pub fn new() -> Self {
173 Self {
174 config: ValidationConfig::default(),
175 }
176 }
177
178 pub fn with_config(config: ValidationConfig) -> Self {
180 Self { config }
181 }
182
183 pub fn validate(
185 &self,
186 model_path: &Path,
187 metadata: &ModelMetadata,
188 ) -> Result<ValidationReport> {
189 info!("Starting validation for model: {}", model_path.display());
190
191 let mut results = Vec::new();
192
193 if self.config.validate_checksum {
195 match self.validate_checksum_sync(model_path, metadata) {
196 Ok(result) => results.push(result),
197 Err(e) => {
198 error!("Checksum validation failed: {}", e);
199 results.push(ValidationResult {
200 validation_type: "checksum".to_string(),
201 status: ValidationStatus::Invalid,
202 message: format!("Checksum validation error: {}", e),
203 details: None,
204 });
205 }
206 }
207 }
208
209 if self.config.validate_signature {
211 match self.validate_signature(model_path, metadata) {
212 Ok(result) => results.push(result),
213 Err(e) => {
214 error!("Signature validation failed: {}", e);
215 results.push(ValidationResult {
216 validation_type: "signature".to_string(),
217 status: ValidationStatus::Invalid,
218 message: format!("Signature validation error: {}", e),
219 details: None,
220 });
221 }
222 }
223 }
224
225 if self.config.validate_metadata {
227 match self.validate_metadata(metadata) {
228 Ok(result) => results.push(result),
229 Err(e) => {
230 error!("Metadata validation failed: {}", e);
231 results.push(ValidationResult {
232 validation_type: "metadata".to_string(),
233 status: ValidationStatus::Invalid,
234 message: format!("Metadata validation error: {}", e),
235 details: None,
236 });
237 }
238 }
239 }
240
241 let overall_status = if results
243 .iter()
244 .any(|r| r.status == ValidationStatus::Invalid)
245 {
246 ValidationStatus::Invalid
247 } else if results
248 .iter()
249 .any(|r| r.status == ValidationStatus::Warning)
250 {
251 ValidationStatus::Warning
252 } else {
253 ValidationStatus::Valid
254 };
255
256 Ok(ValidationReport {
257 model_path: model_path.to_path_buf(),
258 overall_status,
259 results,
260 timestamp: chrono::Utc::now(),
261 })
262 }
263
264 pub async fn validate_async(
266 &self,
267 model_path: &Path,
268 metadata: &ModelMetadata,
269 ) -> Result<ValidationReport> {
270 info!(
271 "Starting async validation for model: {}",
272 model_path.display()
273 );
274
275 let mut results = Vec::new();
276
277 if self.config.validate_checksum {
279 match self.validate_checksum_async(model_path, metadata).await {
280 Ok(result) => results.push(result),
281 Err(e) => {
282 error!("Checksum validation failed: {}", e);
283 results.push(ValidationResult {
284 validation_type: "checksum".to_string(),
285 status: ValidationStatus::Invalid,
286 message: format!("Checksum validation error: {}", e),
287 details: None,
288 });
289 }
290 }
291 }
292
293 if self.config.validate_signature {
295 match self.validate_signature_async(model_path, metadata).await {
296 Ok(result) => results.push(result),
297 Err(e) => {
298 error!("Signature validation failed: {}", e);
299 results.push(ValidationResult {
300 validation_type: "signature".to_string(),
301 status: ValidationStatus::Invalid,
302 message: format!("Signature validation error: {}", e),
303 details: None,
304 });
305 }
306 }
307 }
308
309 if self.config.validate_metadata {
311 match self.validate_metadata(metadata) {
312 Ok(result) => results.push(result),
313 Err(e) => {
314 error!("Metadata validation failed: {}", e);
315 results.push(ValidationResult {
316 validation_type: "metadata".to_string(),
317 status: ValidationStatus::Invalid,
318 message: format!("Metadata validation error: {}", e),
319 details: None,
320 });
321 }
322 }
323 }
324
325 let overall_status = if results
327 .iter()
328 .any(|r| r.status == ValidationStatus::Invalid)
329 {
330 ValidationStatus::Invalid
331 } else if results
332 .iter()
333 .any(|r| r.status == ValidationStatus::Warning)
334 {
335 ValidationStatus::Warning
336 } else {
337 ValidationStatus::Valid
338 };
339
340 Ok(ValidationReport {
341 model_path: model_path.to_path_buf(),
342 overall_status,
343 results,
344 timestamp: chrono::Utc::now(),
345 })
346 }
347
348 pub fn validate_dimensions(
350 &self,
351 embedding_dim: usize,
352 input_tensors: &[ArrayView<f32, Ix2>],
353 output_tensors: &[ArrayView<f32, Ix2>],
354 ) -> Result<ValidationResult> {
355 debug!(
356 "Validating dimensions: embedding_dim={}, input_tensors={}, output_tensors={}",
357 embedding_dim,
358 input_tensors.len(),
359 output_tensors.len()
360 );
361
362 for (i, tensor) in input_tensors.iter().enumerate() {
364 let shape = tensor.shape();
365 if shape.len() != 2 {
366 return Ok(ValidationResult {
367 validation_type: "dimension".to_string(),
368 status: ValidationStatus::Invalid,
369 message: format!(
370 "Input tensor {} has invalid rank: expected 2, got {}",
371 i,
372 shape.len()
373 ),
374 details: Some(serde_json::json!({ "tensor_index": i, "shape": shape })),
375 });
376 }
377
378 if shape[1] != embedding_dim {
380 return Ok(ValidationResult {
381 validation_type: "dimension".to_string(),
382 status: ValidationStatus::Invalid,
383 message: format!(
384 "Input tensor {} dimension mismatch: expected {}, got {}",
385 i, embedding_dim, shape[1]
386 ),
387 details: Some(
388 serde_json::json!({ "tensor_index": i, "expected": embedding_dim, "actual": shape[1] }),
389 ),
390 });
391 }
392 }
393
394 for (i, tensor) in output_tensors.iter().enumerate() {
396 let shape = tensor.shape();
397 if shape.len() != 2 {
398 return Ok(ValidationResult {
399 validation_type: "dimension".to_string(),
400 status: ValidationStatus::Invalid,
401 message: format!(
402 "Output tensor {} has invalid rank: expected 2, got {}",
403 i,
404 shape.len()
405 ),
406 details: Some(serde_json::json!({ "tensor_index": i, "shape": shape })),
407 });
408 }
409
410 if shape[1] != embedding_dim {
411 return Ok(ValidationResult {
412 validation_type: "dimension".to_string(),
413 status: ValidationStatus::Invalid,
414 message: format!(
415 "Output tensor {} dimension mismatch: expected {}, got {}",
416 i, embedding_dim, shape[1]
417 ),
418 details: Some(
419 serde_json::json!({ "tensor_index": i, "expected": embedding_dim, "actual": shape[1] }),
420 ),
421 });
422 }
423 }
424
425 Ok(ValidationResult {
426 validation_type: "dimension".to_string(),
427 status: ValidationStatus::Valid,
428 message: "All dimension checks passed".to_string(),
429 details: Some(serde_json::json!({
430 "embedding_dim": embedding_dim,
431 "input_tensors_validated": input_tensors.len(),
432 "output_tensors_validated": output_tensors.len(),
433 })),
434 })
435 }
436
437 fn validate_checksum_sync(
439 &self,
440 model_path: &Path,
441 metadata: &ModelMetadata,
442 ) -> Result<ValidationResult> {
443 let expected_checksum = metadata
444 .checksum
445 .as_ref()
446 .ok_or_else(|| anyhow!("No checksum provided in metadata"))?;
447
448 let checksum_algo = metadata
449 .checksum_algorithm
450 .unwrap_or(self.config.checksum_algorithm);
451
452 let computed_checksum = self.compute_checksum_sync(model_path, checksum_algo)?;
453
454 if computed_checksum == *expected_checksum {
455 Ok(ValidationResult {
456 validation_type: "checksum".to_string(),
457 status: ValidationStatus::Valid,
458 message: format!("Checksum validation passed ({:?})", checksum_algo),
459 details: Some(serde_json::json!({
460 "algorithm": checksum_algo,
461 "checksum": computed_checksum,
462 })),
463 })
464 } else {
465 Ok(ValidationResult {
466 validation_type: "checksum".to_string(),
467 status: ValidationStatus::Invalid,
468 message: format!("Checksum mismatch ({:?})", checksum_algo),
469 details: Some(serde_json::json!({
470 "algorithm": checksum_algo,
471 "expected": expected_checksum,
472 "actual": computed_checksum,
473 })),
474 })
475 }
476 }
477
478 async fn validate_checksum_async(
480 &self,
481 model_path: &Path,
482 metadata: &ModelMetadata,
483 ) -> Result<ValidationResult> {
484 let expected_checksum = metadata
485 .checksum
486 .as_ref()
487 .ok_or_else(|| anyhow!("No checksum provided in metadata"))?;
488
489 let checksum_algo = metadata
490 .checksum_algorithm
491 .unwrap_or(self.config.checksum_algorithm);
492
493 let computed_checksum = self
494 .compute_checksum_async(model_path, checksum_algo)
495 .await?;
496
497 if computed_checksum == *expected_checksum {
498 Ok(ValidationResult {
499 validation_type: "checksum".to_string(),
500 status: ValidationStatus::Valid,
501 message: format!("Checksum validation passed ({:?})", checksum_algo),
502 details: Some(serde_json::json!({
503 "algorithm": checksum_algo,
504 "checksum": computed_checksum,
505 })),
506 })
507 } else {
508 Ok(ValidationResult {
509 validation_type: "checksum".to_string(),
510 status: ValidationStatus::Invalid,
511 message: format!("Checksum mismatch ({:?})", checksum_algo),
512 details: Some(serde_json::json!({
513 "algorithm": checksum_algo,
514 "expected": expected_checksum,
515 "actual": computed_checksum,
516 })),
517 })
518 }
519 }
520
521 fn compute_checksum_sync(&self, path: &Path, algorithm: ChecksumAlgorithm) -> Result<String> {
523 let file = File::open(path).context("Failed to open model file")?;
524 let mut reader = BufReader::new(file);
525
526 match algorithm {
527 ChecksumAlgorithm::Sha256 => {
528 let mut hasher = Sha256::new();
529 let mut buffer = [0u8; 8192];
530 loop {
531 let count = reader.read(&mut buffer).context("Failed to read file")?;
532 if count == 0 {
533 break;
534 }
535 hasher.update(&buffer[..count]);
536 }
537 Ok(format!("{:x}", hasher.finalize()))
538 }
539 ChecksumAlgorithm::Blake3 => {
540 let mut hasher = Blake3Hasher::new();
541 let mut buffer = [0u8; 8192];
542 loop {
543 let count = reader.read(&mut buffer).context("Failed to read file")?;
544 if count == 0 {
545 break;
546 }
547 hasher.update(&buffer[..count]);
548 }
549 Ok(hasher.finalize().to_hex().to_string())
550 }
551 }
552 }
553
554 async fn compute_checksum_async(
556 &self,
557 path: &Path,
558 algorithm: ChecksumAlgorithm,
559 ) -> Result<String> {
560 let mut file = AsyncFile::open(path)
561 .await
562 .context("Failed to open model file")?;
563
564 match algorithm {
565 ChecksumAlgorithm::Sha256 => {
566 let mut hasher = Sha256::new();
567 let mut buffer = vec![0u8; 8192];
568 loop {
569 let count = file
570 .read(&mut buffer)
571 .await
572 .context("Failed to read file")?;
573 if count == 0 {
574 break;
575 }
576 hasher.update(&buffer[..count]);
577 }
578 Ok(format!("{:x}", hasher.finalize()))
579 }
580 ChecksumAlgorithm::Blake3 => {
581 let mut hasher = Blake3Hasher::new();
582 let mut buffer = vec![0u8; 8192];
583 loop {
584 let count = file
585 .read(&mut buffer)
586 .await
587 .context("Failed to read file")?;
588 if count == 0 {
589 break;
590 }
591 hasher.update(&buffer[..count]);
592 }
593 Ok(hasher.finalize().to_hex().to_string())
594 }
595 }
596 }
597
598 fn validate_signature(
600 &self,
601 model_path: &Path,
602 metadata: &ModelMetadata,
603 ) -> Result<ValidationResult> {
604 let file = File::open(model_path).context("Failed to open model file")?;
605 let mut reader = BufReader::new(file);
606 let mut magic_bytes = [0u8; 8];
607
608 reader
609 .read_exact(&mut magic_bytes)
610 .context("Failed to read magic bytes")?;
611
612 let detected_format = Self::detect_format(&magic_bytes);
613
614 if let Some(expected) = self.config.expected_format {
616 if detected_format != expected && detected_format != ModelFormat::Unknown {
617 return Ok(ValidationResult {
618 validation_type: "signature".to_string(),
619 status: ValidationStatus::Invalid,
620 message: format!(
621 "Format mismatch: expected {:?}, got {:?}",
622 expected, detected_format
623 ),
624 details: Some(serde_json::json!({
625 "expected_format": expected,
626 "detected_format": detected_format,
627 "magic_bytes": magic_bytes,
628 })),
629 });
630 }
631 }
632
633 if detected_format != metadata.format && detected_format != ModelFormat::Unknown {
635 return Ok(ValidationResult {
636 validation_type: "signature".to_string(),
637 status: ValidationStatus::Warning,
638 message: format!(
639 "Format mismatch with metadata: metadata says {:?}, detected {:?}",
640 metadata.format, detected_format
641 ),
642 details: Some(serde_json::json!({
643 "metadata_format": metadata.format,
644 "detected_format": detected_format,
645 })),
646 });
647 }
648
649 Ok(ValidationResult {
650 validation_type: "signature".to_string(),
651 status: ValidationStatus::Valid,
652 message: format!("Model signature valid: {:?}", detected_format),
653 details: Some(serde_json::json!({
654 "format": detected_format,
655 })),
656 })
657 }
658
659 async fn validate_signature_async(
661 &self,
662 model_path: &Path,
663 metadata: &ModelMetadata,
664 ) -> Result<ValidationResult> {
665 let mut file = AsyncFile::open(model_path)
666 .await
667 .context("Failed to open model file")?;
668 let mut magic_bytes = [0u8; 8];
669
670 file.read_exact(&mut magic_bytes)
671 .await
672 .context("Failed to read magic bytes")?;
673
674 let detected_format = Self::detect_format(&magic_bytes);
675
676 if let Some(expected) = self.config.expected_format {
678 if detected_format != expected && detected_format != ModelFormat::Unknown {
679 return Ok(ValidationResult {
680 validation_type: "signature".to_string(),
681 status: ValidationStatus::Invalid,
682 message: format!(
683 "Format mismatch: expected {:?}, got {:?}",
684 expected, detected_format
685 ),
686 details: Some(serde_json::json!({
687 "expected_format": expected,
688 "detected_format": detected_format,
689 "magic_bytes": magic_bytes,
690 })),
691 });
692 }
693 }
694
695 if detected_format != metadata.format && detected_format != ModelFormat::Unknown {
697 return Ok(ValidationResult {
698 validation_type: "signature".to_string(),
699 status: ValidationStatus::Warning,
700 message: format!(
701 "Format mismatch with metadata: metadata says {:?}, detected {:?}",
702 metadata.format, detected_format
703 ),
704 details: Some(serde_json::json!({
705 "metadata_format": metadata.format,
706 "detected_format": detected_format,
707 })),
708 });
709 }
710
711 Ok(ValidationResult {
712 validation_type: "signature".to_string(),
713 status: ValidationStatus::Valid,
714 message: format!("Model signature valid: {:?}", detected_format),
715 details: Some(serde_json::json!({
716 "format": detected_format,
717 })),
718 })
719 }
720
721 fn detect_format(magic_bytes: &[u8]) -> ModelFormat {
723 if magic_bytes.starts_with(&[0x08, 0x03]) {
725 return ModelFormat::Onnx;
726 }
727
728 if magic_bytes.starts_with(b"{") {
730 return ModelFormat::SafeTensors;
731 }
732
733 if magic_bytes.starts_with(&[0x50, 0x4B, 0x03, 0x04]) {
735 return ModelFormat::PyTorch;
736 }
737
738 if magic_bytes.starts_with(&[0x0A]) {
740 return ModelFormat::TensorFlow;
741 }
742
743 if magic_bytes.starts_with(b"OXIRS") {
745 return ModelFormat::OxirsEmbed;
746 }
747
748 ModelFormat::Unknown
749 }
750
751 fn validate_metadata(&self, metadata: &ModelMetadata) -> Result<ValidationResult> {
753 let mut missing_fields = Vec::new();
754
755 for field in &self.config.required_metadata_fields {
757 match field.as_str() {
758 "model_name" if metadata.model_name.is_empty() => {
759 missing_fields.push("model_name".to_string());
760 }
761 "version" if metadata.version.is_empty() => {
762 missing_fields.push("version".to_string());
763 }
764 "embedding_dim" if metadata.embedding_dim == 0 => {
765 missing_fields.push("embedding_dim".to_string());
766 }
767 _ => {}
768 }
769 }
770
771 if !missing_fields.is_empty() {
772 return Ok(ValidationResult {
773 validation_type: "metadata".to_string(),
774 status: ValidationStatus::Invalid,
775 message: format!("Missing required metadata fields: {:?}", missing_fields),
776 details: Some(serde_json::json!({
777 "missing_fields": missing_fields,
778 })),
779 });
780 }
781
782 if metadata.embedding_dim == 0 {
784 return Ok(ValidationResult {
785 validation_type: "metadata".to_string(),
786 status: ValidationStatus::Invalid,
787 message: "Invalid embedding dimension: must be > 0".to_string(),
788 details: None,
789 });
790 }
791
792 Ok(ValidationResult {
793 validation_type: "metadata".to_string(),
794 status: ValidationStatus::Valid,
795 message: "Metadata validation passed".to_string(),
796 details: Some(serde_json::json!({
797 "model_name": metadata.model_name,
798 "version": metadata.version,
799 "embedding_dim": metadata.embedding_dim,
800 })),
801 })
802 }
803}
804
805impl Default for ModelValidator {
806 fn default() -> Self {
807 Self::new()
808 }
809}
810
811#[cfg(test)]
812mod tests {
813 use super::*;
814 use scirs2_core::ndarray_ext::Array;
815 use std::io::Write;
816 use tempfile::NamedTempFile;
817
818 #[test]
819 fn test_checksum_sha256() {
820 let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
821 temp_file
822 .write_all(b"test data")
823 .expect("Failed to write data");
824
825 let validator = ModelValidator::new();
826 let checksum = validator
827 .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Sha256)
828 .expect("Failed to compute checksum");
829
830 assert_eq!(
832 checksum,
833 "916f0027a575074ce72a331777c3478d6513f786a591bd892da1a577bf2335f9"
834 );
835 }
836
837 #[test]
838 fn test_checksum_blake3() {
839 let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
840 temp_file
841 .write_all(b"test data")
842 .expect("Failed to write data");
843
844 let validator = ModelValidator::new();
845 let checksum = validator
846 .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Blake3)
847 .expect("Failed to compute checksum");
848
849 assert!(!checksum.is_empty());
851 assert_eq!(checksum.len(), 64); }
853
854 #[test]
855 fn test_dimension_validation_valid() {
856 let validator = ModelValidator::new();
857 let embedding_dim = 128;
858
859 let input1 = Array::zeros((10, 128));
860 let input2 = Array::zeros((20, 128));
861 let output1 = Array::zeros((10, 128));
862
863 let input_views = vec![input1.view(), input2.view()];
864 let output_views = vec![output1.view()];
865
866 let result = validator
867 .validate_dimensions(embedding_dim, &input_views, &output_views)
868 .expect("Validation failed");
869
870 assert_eq!(result.status, ValidationStatus::Valid);
871 }
872
873 #[test]
874 fn test_dimension_validation_invalid() {
875 let validator = ModelValidator::new();
876 let embedding_dim = 128;
877
878 let input1 = Array::zeros((10, 128));
879 let input2 = Array::zeros((20, 64)); let output1 = Array::zeros((10, 128));
881
882 let input_views = vec![input1.view(), input2.view()];
883 let output_views = vec![output1.view()];
884
885 let result = validator
886 .validate_dimensions(embedding_dim, &input_views, &output_views)
887 .expect("Validation failed");
888
889 assert_eq!(result.status, ValidationStatus::Invalid);
890 assert!(result.message.contains("dimension mismatch"));
891 }
892
893 #[test]
894 fn test_format_detection_onnx() {
895 let magic = [0x08, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
896 let format = ModelValidator::detect_format(&magic);
897 assert_eq!(format, ModelFormat::Onnx);
898 }
899
900 #[test]
901 fn test_format_detection_pytorch() {
902 let magic = [0x50, 0x4B, 0x03, 0x04, 0x00, 0x00, 0x00, 0x00];
903 let format = ModelValidator::detect_format(&magic);
904 assert_eq!(format, ModelFormat::PyTorch);
905 }
906
907 #[test]
908 fn test_format_detection_oxirs() {
909 let magic = b"OXIRS\0\0\0";
910 let format = ModelValidator::detect_format(magic);
911 assert_eq!(format, ModelFormat::OxirsEmbed);
912 }
913
914 #[test]
915 fn test_metadata_validation_valid() {
916 let validator = ModelValidator::new();
917 let metadata = ModelMetadata {
918 model_name: "test_model".to_string(),
919 version: "1.0.0".to_string(),
920 embedding_dim: 128,
921 input_dim: Some(128),
922 output_dim: Some(128),
923 format: ModelFormat::OxirsEmbed,
924 checksum: None,
925 checksum_algorithm: None,
926 extra: HashMap::new(),
927 };
928
929 let result = validator
930 .validate_metadata(&metadata)
931 .expect("Validation failed");
932 assert_eq!(result.status, ValidationStatus::Valid);
933 }
934
935 #[test]
936 fn test_metadata_validation_missing_fields() {
937 let validator = ModelValidator::new();
938 let metadata = ModelMetadata {
939 model_name: "".to_string(), version: "1.0.0".to_string(),
941 embedding_dim: 128,
942 input_dim: Some(128),
943 output_dim: Some(128),
944 format: ModelFormat::OxirsEmbed,
945 checksum: None,
946 checksum_algorithm: None,
947 extra: HashMap::new(),
948 };
949
950 let result = validator
951 .validate_metadata(&metadata)
952 .expect("Validation failed");
953 assert_eq!(result.status, ValidationStatus::Invalid);
954 assert!(result.message.contains("Missing required"));
955 }
956
957 #[test]
958 fn test_metadata_validation_invalid_dimension() {
959 let validator = ModelValidator::new();
960 let metadata = ModelMetadata {
961 model_name: "test".to_string(),
962 version: "1.0.0".to_string(),
963 embedding_dim: 0, input_dim: Some(128),
965 output_dim: Some(128),
966 format: ModelFormat::OxirsEmbed,
967 checksum: None,
968 checksum_algorithm: None,
969 extra: HashMap::new(),
970 };
971
972 let result = validator
973 .validate_metadata(&metadata)
974 .expect("Validation failed");
975 assert_eq!(result.status, ValidationStatus::Invalid);
976 }
977
978 #[tokio::test]
979 async fn test_async_checksum() {
980 let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
981 temp_file
982 .write_all(b"async test data")
983 .expect("Failed to write data");
984
985 let validator = ModelValidator::new();
986 let checksum = validator
987 .compute_checksum_async(temp_file.path(), ChecksumAlgorithm::Blake3)
988 .await
989 .expect("Failed to compute checksum");
990
991 assert!(!checksum.is_empty());
992 }
993
994 #[test]
995 fn test_validation_report_is_valid() {
996 let report = ValidationReport {
997 model_path: PathBuf::from("/test/model"),
998 overall_status: ValidationStatus::Valid,
999 results: vec![],
1000 timestamp: chrono::Utc::now(),
1001 };
1002
1003 assert!(report.is_valid());
1004 }
1005
1006 #[test]
1007 fn test_validation_report_failed_validations() {
1008 let report = ValidationReport {
1009 model_path: PathBuf::from("/test/model"),
1010 overall_status: ValidationStatus::Invalid,
1011 results: vec![
1012 ValidationResult {
1013 validation_type: "checksum".to_string(),
1014 status: ValidationStatus::Invalid,
1015 message: "Checksum mismatch".to_string(),
1016 details: None,
1017 },
1018 ValidationResult {
1019 validation_type: "dimension".to_string(),
1020 status: ValidationStatus::Valid,
1021 message: "OK".to_string(),
1022 details: None,
1023 },
1024 ],
1025 timestamp: chrono::Utc::now(),
1026 };
1027
1028 let failed = report.failed_validations();
1029 assert_eq!(failed.len(), 1);
1030 assert_eq!(failed[0].validation_type, "checksum");
1031 }
1032
1033 #[test]
1034 fn test_comprehensive_validation() {
1035 let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
1036 temp_file
1038 .write_all(&[0x08, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
1039 .expect("Failed to write");
1040 temp_file
1041 .write_all(b"model data here")
1042 .expect("Failed to write");
1043
1044 let validator = ModelValidator::new();
1045
1046 let checksum = validator
1048 .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Blake3)
1049 .expect("Failed to compute checksum");
1050
1051 let metadata = ModelMetadata {
1052 model_name: "test_model".to_string(),
1053 version: "1.0.0".to_string(),
1054 embedding_dim: 128,
1055 input_dim: Some(128),
1056 output_dim: Some(128),
1057 format: ModelFormat::Onnx,
1058 checksum: Some(checksum),
1059 checksum_algorithm: Some(ChecksumAlgorithm::Blake3),
1060 extra: HashMap::new(),
1061 };
1062
1063 let report = validator
1064 .validate(temp_file.path(), &metadata)
1065 .expect("Validation failed");
1066
1067 assert!(report.is_valid());
1068 assert!(report.results.len() >= 2); }
1070}