Skip to main content

oxirs_embed/
validation.rs

//! Model validation and integrity checks for embedding models
//!
//! This module provides comprehensive validation capabilities including:
//! - Checksum validation (SHA-256, BLAKE3) for model integrity
//! - Dimension consistency checks of input/output tensors against the embedding dimension
//! - Model signature (magic-byte) and format verification
//! - Model metadata validation
//!
//! All operations use proper error handling with no unwrap() calls.
10
11use anyhow::{anyhow, Context, Result};
12use blake3::Hasher as Blake3Hasher;
13use scirs2_core::ndarray_ext::{ArrayView, Ix2};
14use serde::{Deserialize, Serialize};
15use sha2::{Digest, Sha256};
16use std::collections::HashMap;
17use std::fs::File;
18use std::io::{BufReader, Read};
19use std::path::{Path, PathBuf};
20use tokio::fs::File as AsyncFile;
21use tokio::io::AsyncReadExt;
22use tracing::{debug, error, info};
23
24/// Supported checksum algorithms for model validation
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26pub enum ChecksumAlgorithm {
27    /// SHA-256 cryptographic hash
28    Sha256,
29    /// BLAKE3 high-performance hash
30    Blake3,
31}
32
33/// Model format types supported by validation
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35pub enum ModelFormat {
36    /// ONNX format (magic bytes: "ONNX")
37    Onnx,
38    /// SafeTensors format
39    SafeTensors,
40    /// Custom OxiRS embedding format
41    OxirsEmbed,
42    /// PyTorch format
43    PyTorch,
44    /// TensorFlow SavedModel
45    TensorFlow,
46    /// Unknown format
47    Unknown,
48}
49
50/// Model validation configuration
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ValidationConfig {
53    /// Checksum algorithm to use
54    pub checksum_algorithm: ChecksumAlgorithm,
55    /// Whether to validate checksums
56    pub validate_checksum: bool,
57    /// Whether to validate dimensions
58    pub validate_dimensions: bool,
59    /// Whether to validate model signature
60    pub validate_signature: bool,
61    /// Whether to validate metadata
62    pub validate_metadata: bool,
63    /// Expected model format (None = auto-detect)
64    pub expected_format: Option<ModelFormat>,
65    /// Required metadata fields
66    pub required_metadata_fields: Vec<String>,
67}
68
69impl Default for ValidationConfig {
70    fn default() -> Self {
71        Self {
72            checksum_algorithm: ChecksumAlgorithm::Blake3,
73            validate_checksum: true,
74            validate_dimensions: true,
75            validate_signature: true,
76            validate_metadata: true,
77            expected_format: None,
78            required_metadata_fields: vec![
79                "model_name".to_string(),
80                "embedding_dim".to_string(),
81                "version".to_string(),
82            ],
83        }
84    }
85}
86
87/// Model metadata for validation
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ModelMetadata {
90    /// Model name
91    pub model_name: String,
92    /// Model version
93    pub version: String,
94    /// Embedding dimension
95    pub embedding_dim: usize,
96    /// Input dimension
97    pub input_dim: Option<usize>,
98    /// Output dimension
99    pub output_dim: Option<usize>,
100    /// Model format
101    pub format: ModelFormat,
102    /// Expected checksum
103    pub checksum: Option<String>,
104    /// Checksum algorithm used
105    pub checksum_algorithm: Option<ChecksumAlgorithm>,
106    /// Additional metadata
107    pub extra: HashMap<String, serde_json::Value>,
108}
109
110/// Validation result status
111#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub enum ValidationStatus {
113    /// Validation passed
114    Valid,
115    /// Validation failed
116    Invalid,
117    /// Validation skipped
118    Skipped,
119    /// Validation warning (non-critical issues)
120    Warning,
121}
122
123/// Validation result for a specific check
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct ValidationResult {
126    /// Type of validation performed
127    pub validation_type: String,
128    /// Result status
129    pub status: ValidationStatus,
130    /// Detailed message
131    pub message: String,
132    /// Additional details
133    pub details: Option<serde_json::Value>,
134}
135
136/// Comprehensive validation report
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ValidationReport {
139    /// Model path validated
140    pub model_path: PathBuf,
141    /// Overall validation status
142    pub overall_status: ValidationStatus,
143    /// Individual validation results
144    pub results: Vec<ValidationResult>,
145    /// Validation timestamp
146    pub timestamp: chrono::DateTime<chrono::Utc>,
147}
148
149impl ValidationReport {
150    /// Check if validation passed
151    pub fn is_valid(&self) -> bool {
152        self.overall_status == ValidationStatus::Valid
153            || self.overall_status == ValidationStatus::Warning
154    }
155
156    /// Get failed validations
157    pub fn failed_validations(&self) -> Vec<&ValidationResult> {
158        self.results
159            .iter()
160            .filter(|r| r.status == ValidationStatus::Invalid)
161            .collect()
162    }
163}
164
165/// Model validator for embedding models
166pub struct ModelValidator {
167    config: ValidationConfig,
168}
169
170impl ModelValidator {
171    /// Create a new model validator with default configuration
172    pub fn new() -> Self {
173        Self {
174            config: ValidationConfig::default(),
175        }
176    }
177
178    /// Create a model validator with custom configuration
179    pub fn with_config(config: ValidationConfig) -> Self {
180        Self { config }
181    }
182
183    /// Validate a model file synchronously
184    pub fn validate(
185        &self,
186        model_path: &Path,
187        metadata: &ModelMetadata,
188    ) -> Result<ValidationReport> {
189        info!("Starting validation for model: {}", model_path.display());
190
191        let mut results = Vec::new();
192
193        // Validate checksum
194        if self.config.validate_checksum {
195            match self.validate_checksum_sync(model_path, metadata) {
196                Ok(result) => results.push(result),
197                Err(e) => {
198                    error!("Checksum validation failed: {}", e);
199                    results.push(ValidationResult {
200                        validation_type: "checksum".to_string(),
201                        status: ValidationStatus::Invalid,
202                        message: format!("Checksum validation error: {}", e),
203                        details: None,
204                    });
205                }
206            }
207        }
208
209        // Validate signature
210        if self.config.validate_signature {
211            match self.validate_signature(model_path, metadata) {
212                Ok(result) => results.push(result),
213                Err(e) => {
214                    error!("Signature validation failed: {}", e);
215                    results.push(ValidationResult {
216                        validation_type: "signature".to_string(),
217                        status: ValidationStatus::Invalid,
218                        message: format!("Signature validation error: {}", e),
219                        details: None,
220                    });
221                }
222            }
223        }
224
225        // Validate metadata
226        if self.config.validate_metadata {
227            match self.validate_metadata(metadata) {
228                Ok(result) => results.push(result),
229                Err(e) => {
230                    error!("Metadata validation failed: {}", e);
231                    results.push(ValidationResult {
232                        validation_type: "metadata".to_string(),
233                        status: ValidationStatus::Invalid,
234                        message: format!("Metadata validation error: {}", e),
235                        details: None,
236                    });
237                }
238            }
239        }
240
241        // Determine overall status
242        let overall_status = if results
243            .iter()
244            .any(|r| r.status == ValidationStatus::Invalid)
245        {
246            ValidationStatus::Invalid
247        } else if results
248            .iter()
249            .any(|r| r.status == ValidationStatus::Warning)
250        {
251            ValidationStatus::Warning
252        } else {
253            ValidationStatus::Valid
254        };
255
256        Ok(ValidationReport {
257            model_path: model_path.to_path_buf(),
258            overall_status,
259            results,
260            timestamp: chrono::Utc::now(),
261        })
262    }
263
264    /// Validate a model file asynchronously
265    pub async fn validate_async(
266        &self,
267        model_path: &Path,
268        metadata: &ModelMetadata,
269    ) -> Result<ValidationReport> {
270        info!(
271            "Starting async validation for model: {}",
272            model_path.display()
273        );
274
275        let mut results = Vec::new();
276
277        // Validate checksum
278        if self.config.validate_checksum {
279            match self.validate_checksum_async(model_path, metadata).await {
280                Ok(result) => results.push(result),
281                Err(e) => {
282                    error!("Checksum validation failed: {}", e);
283                    results.push(ValidationResult {
284                        validation_type: "checksum".to_string(),
285                        status: ValidationStatus::Invalid,
286                        message: format!("Checksum validation error: {}", e),
287                        details: None,
288                    });
289                }
290            }
291        }
292
293        // Validate signature (async)
294        if self.config.validate_signature {
295            match self.validate_signature_async(model_path, metadata).await {
296                Ok(result) => results.push(result),
297                Err(e) => {
298                    error!("Signature validation failed: {}", e);
299                    results.push(ValidationResult {
300                        validation_type: "signature".to_string(),
301                        status: ValidationStatus::Invalid,
302                        message: format!("Signature validation error: {}", e),
303                        details: None,
304                    });
305                }
306            }
307        }
308
309        // Validate metadata
310        if self.config.validate_metadata {
311            match self.validate_metadata(metadata) {
312                Ok(result) => results.push(result),
313                Err(e) => {
314                    error!("Metadata validation failed: {}", e);
315                    results.push(ValidationResult {
316                        validation_type: "metadata".to_string(),
317                        status: ValidationStatus::Invalid,
318                        message: format!("Metadata validation error: {}", e),
319                        details: None,
320                    });
321                }
322            }
323        }
324
325        // Determine overall status
326        let overall_status = if results
327            .iter()
328            .any(|r| r.status == ValidationStatus::Invalid)
329        {
330            ValidationStatus::Invalid
331        } else if results
332            .iter()
333            .any(|r| r.status == ValidationStatus::Warning)
334        {
335            ValidationStatus::Warning
336        } else {
337            ValidationStatus::Valid
338        };
339
340        Ok(ValidationReport {
341            model_path: model_path.to_path_buf(),
342            overall_status,
343            results,
344            timestamp: chrono::Utc::now(),
345        })
346    }
347
348    /// Validate dimension consistency
349    pub fn validate_dimensions(
350        &self,
351        embedding_dim: usize,
352        input_tensors: &[ArrayView<f32, Ix2>],
353        output_tensors: &[ArrayView<f32, Ix2>],
354    ) -> Result<ValidationResult> {
355        debug!(
356            "Validating dimensions: embedding_dim={}, input_tensors={}, output_tensors={}",
357            embedding_dim,
358            input_tensors.len(),
359            output_tensors.len()
360        );
361
362        // Check input tensor dimensions
363        for (i, tensor) in input_tensors.iter().enumerate() {
364            let shape = tensor.shape();
365            if shape.len() != 2 {
366                return Ok(ValidationResult {
367                    validation_type: "dimension".to_string(),
368                    status: ValidationStatus::Invalid,
369                    message: format!(
370                        "Input tensor {} has invalid rank: expected 2, got {}",
371                        i,
372                        shape.len()
373                    ),
374                    details: Some(serde_json::json!({ "tensor_index": i, "shape": shape })),
375                });
376            }
377
378            // Check if embedding dimension matches
379            if shape[1] != embedding_dim {
380                return Ok(ValidationResult {
381                    validation_type: "dimension".to_string(),
382                    status: ValidationStatus::Invalid,
383                    message: format!(
384                        "Input tensor {} dimension mismatch: expected {}, got {}",
385                        i, embedding_dim, shape[1]
386                    ),
387                    details: Some(
388                        serde_json::json!({ "tensor_index": i, "expected": embedding_dim, "actual": shape[1] }),
389                    ),
390                });
391            }
392        }
393
394        // Check output tensor dimensions
395        for (i, tensor) in output_tensors.iter().enumerate() {
396            let shape = tensor.shape();
397            if shape.len() != 2 {
398                return Ok(ValidationResult {
399                    validation_type: "dimension".to_string(),
400                    status: ValidationStatus::Invalid,
401                    message: format!(
402                        "Output tensor {} has invalid rank: expected 2, got {}",
403                        i,
404                        shape.len()
405                    ),
406                    details: Some(serde_json::json!({ "tensor_index": i, "shape": shape })),
407                });
408            }
409
410            if shape[1] != embedding_dim {
411                return Ok(ValidationResult {
412                    validation_type: "dimension".to_string(),
413                    status: ValidationStatus::Invalid,
414                    message: format!(
415                        "Output tensor {} dimension mismatch: expected {}, got {}",
416                        i, embedding_dim, shape[1]
417                    ),
418                    details: Some(
419                        serde_json::json!({ "tensor_index": i, "expected": embedding_dim, "actual": shape[1] }),
420                    ),
421                });
422            }
423        }
424
425        Ok(ValidationResult {
426            validation_type: "dimension".to_string(),
427            status: ValidationStatus::Valid,
428            message: "All dimension checks passed".to_string(),
429            details: Some(serde_json::json!({
430                "embedding_dim": embedding_dim,
431                "input_tensors_validated": input_tensors.len(),
432                "output_tensors_validated": output_tensors.len(),
433            })),
434        })
435    }
436
437    /// Validate checksum synchronously
438    fn validate_checksum_sync(
439        &self,
440        model_path: &Path,
441        metadata: &ModelMetadata,
442    ) -> Result<ValidationResult> {
443        let expected_checksum = metadata
444            .checksum
445            .as_ref()
446            .ok_or_else(|| anyhow!("No checksum provided in metadata"))?;
447
448        let checksum_algo = metadata
449            .checksum_algorithm
450            .unwrap_or(self.config.checksum_algorithm);
451
452        let computed_checksum = self.compute_checksum_sync(model_path, checksum_algo)?;
453
454        if computed_checksum == *expected_checksum {
455            Ok(ValidationResult {
456                validation_type: "checksum".to_string(),
457                status: ValidationStatus::Valid,
458                message: format!("Checksum validation passed ({:?})", checksum_algo),
459                details: Some(serde_json::json!({
460                    "algorithm": checksum_algo,
461                    "checksum": computed_checksum,
462                })),
463            })
464        } else {
465            Ok(ValidationResult {
466                validation_type: "checksum".to_string(),
467                status: ValidationStatus::Invalid,
468                message: format!("Checksum mismatch ({:?})", checksum_algo),
469                details: Some(serde_json::json!({
470                    "algorithm": checksum_algo,
471                    "expected": expected_checksum,
472                    "actual": computed_checksum,
473                })),
474            })
475        }
476    }
477
478    /// Validate checksum asynchronously
479    async fn validate_checksum_async(
480        &self,
481        model_path: &Path,
482        metadata: &ModelMetadata,
483    ) -> Result<ValidationResult> {
484        let expected_checksum = metadata
485            .checksum
486            .as_ref()
487            .ok_or_else(|| anyhow!("No checksum provided in metadata"))?;
488
489        let checksum_algo = metadata
490            .checksum_algorithm
491            .unwrap_or(self.config.checksum_algorithm);
492
493        let computed_checksum = self
494            .compute_checksum_async(model_path, checksum_algo)
495            .await?;
496
497        if computed_checksum == *expected_checksum {
498            Ok(ValidationResult {
499                validation_type: "checksum".to_string(),
500                status: ValidationStatus::Valid,
501                message: format!("Checksum validation passed ({:?})", checksum_algo),
502                details: Some(serde_json::json!({
503                    "algorithm": checksum_algo,
504                    "checksum": computed_checksum,
505                })),
506            })
507        } else {
508            Ok(ValidationResult {
509                validation_type: "checksum".to_string(),
510                status: ValidationStatus::Invalid,
511                message: format!("Checksum mismatch ({:?})", checksum_algo),
512                details: Some(serde_json::json!({
513                    "algorithm": checksum_algo,
514                    "expected": expected_checksum,
515                    "actual": computed_checksum,
516                })),
517            })
518        }
519    }
520
521    /// Compute checksum synchronously
522    fn compute_checksum_sync(&self, path: &Path, algorithm: ChecksumAlgorithm) -> Result<String> {
523        let file = File::open(path).context("Failed to open model file")?;
524        let mut reader = BufReader::new(file);
525
526        match algorithm {
527            ChecksumAlgorithm::Sha256 => {
528                let mut hasher = Sha256::new();
529                let mut buffer = [0u8; 8192];
530                loop {
531                    let count = reader.read(&mut buffer).context("Failed to read file")?;
532                    if count == 0 {
533                        break;
534                    }
535                    hasher.update(&buffer[..count]);
536                }
537                Ok(format!("{:x}", hasher.finalize()))
538            }
539            ChecksumAlgorithm::Blake3 => {
540                let mut hasher = Blake3Hasher::new();
541                let mut buffer = [0u8; 8192];
542                loop {
543                    let count = reader.read(&mut buffer).context("Failed to read file")?;
544                    if count == 0 {
545                        break;
546                    }
547                    hasher.update(&buffer[..count]);
548                }
549                Ok(hasher.finalize().to_hex().to_string())
550            }
551        }
552    }
553
554    /// Compute checksum asynchronously
555    async fn compute_checksum_async(
556        &self,
557        path: &Path,
558        algorithm: ChecksumAlgorithm,
559    ) -> Result<String> {
560        let mut file = AsyncFile::open(path)
561            .await
562            .context("Failed to open model file")?;
563
564        match algorithm {
565            ChecksumAlgorithm::Sha256 => {
566                let mut hasher = Sha256::new();
567                let mut buffer = vec![0u8; 8192];
568                loop {
569                    let count = file
570                        .read(&mut buffer)
571                        .await
572                        .context("Failed to read file")?;
573                    if count == 0 {
574                        break;
575                    }
576                    hasher.update(&buffer[..count]);
577                }
578                Ok(format!("{:x}", hasher.finalize()))
579            }
580            ChecksumAlgorithm::Blake3 => {
581                let mut hasher = Blake3Hasher::new();
582                let mut buffer = vec![0u8; 8192];
583                loop {
584                    let count = file
585                        .read(&mut buffer)
586                        .await
587                        .context("Failed to read file")?;
588                    if count == 0 {
589                        break;
590                    }
591                    hasher.update(&buffer[..count]);
592                }
593                Ok(hasher.finalize().to_hex().to_string())
594            }
595        }
596    }
597
598    /// Validate model signature and format
599    fn validate_signature(
600        &self,
601        model_path: &Path,
602        metadata: &ModelMetadata,
603    ) -> Result<ValidationResult> {
604        let file = File::open(model_path).context("Failed to open model file")?;
605        let mut reader = BufReader::new(file);
606        let mut magic_bytes = [0u8; 8];
607
608        reader
609            .read_exact(&mut magic_bytes)
610            .context("Failed to read magic bytes")?;
611
612        let detected_format = Self::detect_format(&magic_bytes);
613
614        // Check if format matches expected
615        if let Some(expected) = self.config.expected_format {
616            if detected_format != expected && detected_format != ModelFormat::Unknown {
617                return Ok(ValidationResult {
618                    validation_type: "signature".to_string(),
619                    status: ValidationStatus::Invalid,
620                    message: format!(
621                        "Format mismatch: expected {:?}, got {:?}",
622                        expected, detected_format
623                    ),
624                    details: Some(serde_json::json!({
625                        "expected_format": expected,
626                        "detected_format": detected_format,
627                        "magic_bytes": magic_bytes,
628                    })),
629                });
630            }
631        }
632
633        // Check if format matches metadata
634        if detected_format != metadata.format && detected_format != ModelFormat::Unknown {
635            return Ok(ValidationResult {
636                validation_type: "signature".to_string(),
637                status: ValidationStatus::Warning,
638                message: format!(
639                    "Format mismatch with metadata: metadata says {:?}, detected {:?}",
640                    metadata.format, detected_format
641                ),
642                details: Some(serde_json::json!({
643                    "metadata_format": metadata.format,
644                    "detected_format": detected_format,
645                })),
646            });
647        }
648
649        Ok(ValidationResult {
650            validation_type: "signature".to_string(),
651            status: ValidationStatus::Valid,
652            message: format!("Model signature valid: {:?}", detected_format),
653            details: Some(serde_json::json!({
654                "format": detected_format,
655            })),
656        })
657    }
658
659    /// Validate model signature asynchronously
660    async fn validate_signature_async(
661        &self,
662        model_path: &Path,
663        metadata: &ModelMetadata,
664    ) -> Result<ValidationResult> {
665        let mut file = AsyncFile::open(model_path)
666            .await
667            .context("Failed to open model file")?;
668        let mut magic_bytes = [0u8; 8];
669
670        file.read_exact(&mut magic_bytes)
671            .await
672            .context("Failed to read magic bytes")?;
673
674        let detected_format = Self::detect_format(&magic_bytes);
675
676        // Check if format matches expected
677        if let Some(expected) = self.config.expected_format {
678            if detected_format != expected && detected_format != ModelFormat::Unknown {
679                return Ok(ValidationResult {
680                    validation_type: "signature".to_string(),
681                    status: ValidationStatus::Invalid,
682                    message: format!(
683                        "Format mismatch: expected {:?}, got {:?}",
684                        expected, detected_format
685                    ),
686                    details: Some(serde_json::json!({
687                        "expected_format": expected,
688                        "detected_format": detected_format,
689                        "magic_bytes": magic_bytes,
690                    })),
691                });
692            }
693        }
694
695        // Check if format matches metadata
696        if detected_format != metadata.format && detected_format != ModelFormat::Unknown {
697            return Ok(ValidationResult {
698                validation_type: "signature".to_string(),
699                status: ValidationStatus::Warning,
700                message: format!(
701                    "Format mismatch with metadata: metadata says {:?}, detected {:?}",
702                    metadata.format, detected_format
703                ),
704                details: Some(serde_json::json!({
705                    "metadata_format": metadata.format,
706                    "detected_format": detected_format,
707                })),
708            });
709        }
710
711        Ok(ValidationResult {
712            validation_type: "signature".to_string(),
713            status: ValidationStatus::Valid,
714            message: format!("Model signature valid: {:?}", detected_format),
715            details: Some(serde_json::json!({
716                "format": detected_format,
717            })),
718        })
719    }
720
721    /// Detect model format from magic bytes
722    fn detect_format(magic_bytes: &[u8]) -> ModelFormat {
723        // ONNX: starts with "08 03" or has ONNX protobuf header
724        if magic_bytes.starts_with(&[0x08, 0x03]) {
725            return ModelFormat::Onnx;
726        }
727
728        // SafeTensors: JSON header
729        if magic_bytes.starts_with(b"{") {
730            return ModelFormat::SafeTensors;
731        }
732
733        // PyTorch: ZIP archive (PK header)
734        if magic_bytes.starts_with(&[0x50, 0x4B, 0x03, 0x04]) {
735            return ModelFormat::PyTorch;
736        }
737
738        // TensorFlow SavedModel: protobuf
739        if magic_bytes.starts_with(&[0x0A]) {
740            return ModelFormat::TensorFlow;
741        }
742
743        // OxiRS custom format: "OXIRS\0\0\0"
744        if magic_bytes.starts_with(b"OXIRS") {
745            return ModelFormat::OxirsEmbed;
746        }
747
748        ModelFormat::Unknown
749    }
750
751    /// Validate model metadata
752    fn validate_metadata(&self, metadata: &ModelMetadata) -> Result<ValidationResult> {
753        let mut missing_fields = Vec::new();
754
755        // Check required fields
756        for field in &self.config.required_metadata_fields {
757            match field.as_str() {
758                "model_name" if metadata.model_name.is_empty() => {
759                    missing_fields.push("model_name".to_string());
760                }
761                "version" if metadata.version.is_empty() => {
762                    missing_fields.push("version".to_string());
763                }
764                "embedding_dim" if metadata.embedding_dim == 0 => {
765                    missing_fields.push("embedding_dim".to_string());
766                }
767                _ => {}
768            }
769        }
770
771        if !missing_fields.is_empty() {
772            return Ok(ValidationResult {
773                validation_type: "metadata".to_string(),
774                status: ValidationStatus::Invalid,
775                message: format!("Missing required metadata fields: {:?}", missing_fields),
776                details: Some(serde_json::json!({
777                    "missing_fields": missing_fields,
778                })),
779            });
780        }
781
782        // Validate dimension values
783        if metadata.embedding_dim == 0 {
784            return Ok(ValidationResult {
785                validation_type: "metadata".to_string(),
786                status: ValidationStatus::Invalid,
787                message: "Invalid embedding dimension: must be > 0".to_string(),
788                details: None,
789            });
790        }
791
792        Ok(ValidationResult {
793            validation_type: "metadata".to_string(),
794            status: ValidationStatus::Valid,
795            message: "Metadata validation passed".to_string(),
796            details: Some(serde_json::json!({
797                "model_name": metadata.model_name,
798                "version": metadata.version,
799                "embedding_dim": metadata.embedding_dim,
800            })),
801        })
802    }
803}
804
805impl Default for ModelValidator {
806    fn default() -> Self {
807        Self::new()
808    }
809}
810
#[cfg(test)]
mod tests {
    //! Unit tests for checksum computation, dimension checks, format
    //! detection, metadata validation, and validation reports.

    use super::*;
    use scirs2_core::ndarray_ext::Array;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file holding exactly `contents`.
    ///
    /// Keep the returned handle alive while its path is in use: dropping a
    /// `NamedTempFile` deletes the underlying file.
    fn temp_file_with(contents: &[u8]) -> NamedTempFile {
        let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
        temp_file
            .write_all(contents)
            .expect("Failed to write data");
        temp_file
    }

    /// SHA-256 of a known input must match the published digest.
    #[test]
    fn test_checksum_sha256() {
        let temp_file = temp_file_with(b"test data");

        let validator = ModelValidator::new();
        let checksum = validator
            .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Sha256)
            .expect("Failed to compute checksum");

        // Expected SHA-256 of "test data"
        assert_eq!(
            checksum,
            "916f0027a575074ce72a331777c3478d6513f786a591bd892da1a577bf2335f9"
        );
    }

    /// Boundary case: hashing an empty file must succeed and yield the
    /// well-known SHA-256 digest of the empty input.
    #[test]
    fn test_checksum_sha256_empty_file() {
        let temp_file = temp_file_with(b"");

        let validator = ModelValidator::new();
        let checksum = validator
            .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Sha256)
            .expect("Failed to compute checksum");

        assert_eq!(
            checksum,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    /// BLAKE3 output must match the reference implementation and be
    /// deterministic across repeated runs on the same file.
    #[test]
    fn test_checksum_blake3() {
        let temp_file = temp_file_with(b"test data");

        let validator = ModelValidator::new();
        let checksum = validator
            .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Blake3)
            .expect("Failed to compute checksum");

        assert_eq!(checksum.len(), 64); // BLAKE3 produces 32-byte (64 hex chars) hash
        // Cross-check against the blake3 crate's one-shot hash of the same bytes.
        let expected = blake3::hash(b"test data").to_hex().to_string();
        assert_eq!(checksum, expected);

        // Re-hashing the unchanged file must give an identical digest.
        let again = validator
            .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Blake3)
            .expect("Failed to compute checksum");
        assert_eq!(checksum, again);
    }

    /// Inputs and outputs that all share the embedding dimension are valid.
    #[test]
    fn test_dimension_validation_valid() {
        let validator = ModelValidator::new();
        let embedding_dim = 128;

        let input1 = Array::zeros((10, 128));
        let input2 = Array::zeros((20, 128));
        let output1 = Array::zeros((10, 128));

        let input_views = vec![input1.view(), input2.view()];
        let output_views = vec![output1.view()];

        let result = validator
            .validate_dimensions(embedding_dim, &input_views, &output_views)
            .expect("Validation failed");

        assert_eq!(result.status, ValidationStatus::Valid);
    }

    /// A single input with the wrong column count makes the check invalid.
    #[test]
    fn test_dimension_validation_invalid() {
        let validator = ModelValidator::new();
        let embedding_dim = 128;

        let input1 = Array::zeros((10, 128));
        let input2 = Array::zeros((20, 64)); // Wrong dimension
        let output1 = Array::zeros((10, 128));

        let input_views = vec![input1.view(), input2.view()];
        let output_views = vec![output1.view()];

        let result = validator
            .validate_dimensions(embedding_dim, &input_views, &output_views)
            .expect("Validation failed");

        assert_eq!(result.status, ValidationStatus::Invalid);
        assert!(result.message.contains("dimension mismatch"));
    }

    /// Leading bytes `08 03` are classified as ONNX.
    #[test]
    fn test_format_detection_onnx() {
        let magic = [0x08, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        let format = ModelValidator::detect_format(&magic);
        assert_eq!(format, ModelFormat::Onnx);
    }

    /// A ZIP local-file header (`PK\x03\x04`) is classified as PyTorch.
    #[test]
    fn test_format_detection_pytorch() {
        let magic = [0x50, 0x4B, 0x03, 0x04, 0x00, 0x00, 0x00, 0x00];
        let format = ModelValidator::detect_format(&magic);
        assert_eq!(format, ModelFormat::PyTorch);
    }

    /// The `OXIRS` prefix is classified as the native embedding format.
    #[test]
    fn test_format_detection_oxirs() {
        let magic = b"OXIRS\0\0\0";
        let format = ModelValidator::detect_format(magic);
        assert_eq!(format, ModelFormat::OxirsEmbed);
    }

    /// Fully populated metadata passes validation.
    #[test]
    fn test_metadata_validation_valid() {
        let validator = ModelValidator::new();
        let metadata = ModelMetadata {
            model_name: "test_model".to_string(),
            version: "1.0.0".to_string(),
            embedding_dim: 128,
            input_dim: Some(128),
            output_dim: Some(128),
            format: ModelFormat::OxirsEmbed,
            checksum: None,
            checksum_algorithm: None,
            extra: HashMap::new(),
        };

        let result = validator
            .validate_metadata(&metadata)
            .expect("Validation failed");
        assert_eq!(result.status, ValidationStatus::Valid);
    }

    /// An empty model name is reported as a missing required field.
    #[test]
    fn test_metadata_validation_missing_fields() {
        let validator = ModelValidator::new();
        let metadata = ModelMetadata {
            model_name: "".to_string(), // Missing
            version: "1.0.0".to_string(),
            embedding_dim: 128,
            input_dim: Some(128),
            output_dim: Some(128),
            format: ModelFormat::OxirsEmbed,
            checksum: None,
            checksum_algorithm: None,
            extra: HashMap::new(),
        };

        let result = validator
            .validate_metadata(&metadata)
            .expect("Validation failed");
        assert_eq!(result.status, ValidationStatus::Invalid);
        assert!(result.message.contains("Missing required"));
    }

    /// A zero embedding dimension is rejected.
    #[test]
    fn test_metadata_validation_invalid_dimension() {
        let validator = ModelValidator::new();
        let metadata = ModelMetadata {
            model_name: "test".to_string(),
            version: "1.0.0".to_string(),
            embedding_dim: 0, // Invalid
            input_dim: Some(128),
            output_dim: Some(128),
            format: ModelFormat::OxirsEmbed,
            checksum: None,
            checksum_algorithm: None,
            extra: HashMap::new(),
        };

        let result = validator
            .validate_metadata(&metadata)
            .expect("Validation failed");
        assert_eq!(result.status, ValidationStatus::Invalid);
    }

    /// The async checksum path must agree with the sync path for every
    /// supported algorithm on identical file contents.
    #[tokio::test]
    async fn test_async_checksum() {
        let temp_file = temp_file_with(b"async test data");
        let validator = ModelValidator::new();

        for algorithm in [ChecksumAlgorithm::Sha256, ChecksumAlgorithm::Blake3] {
            let async_checksum = validator
                .compute_checksum_async(temp_file.path(), algorithm)
                .await
                .expect("Failed to compute checksum");
            let sync_checksum = validator
                .compute_checksum_sync(temp_file.path(), algorithm)
                .expect("Failed to compute checksum");

            assert!(!async_checksum.is_empty());
            assert_eq!(async_checksum, sync_checksum, "mismatch for {algorithm:?}");
        }
    }

    /// A report whose overall status is Valid reports itself valid and has
    /// no failed validations.
    #[test]
    fn test_validation_report_is_valid() {
        let report = ValidationReport {
            model_path: PathBuf::from("/test/model"),
            overall_status: ValidationStatus::Valid,
            results: vec![],
            timestamp: chrono::Utc::now(),
        };

        assert!(report.is_valid());
        assert!(report.failed_validations().is_empty());
    }

    /// Only the Invalid result is surfaced by `failed_validations`, and the
    /// report as a whole is not valid.
    #[test]
    fn test_validation_report_failed_validations() {
        let report = ValidationReport {
            model_path: PathBuf::from("/test/model"),
            overall_status: ValidationStatus::Invalid,
            results: vec![
                ValidationResult {
                    validation_type: "checksum".to_string(),
                    status: ValidationStatus::Invalid,
                    message: "Checksum mismatch".to_string(),
                    details: None,
                },
                ValidationResult {
                    validation_type: "dimension".to_string(),
                    status: ValidationStatus::Valid,
                    message: "OK".to_string(),
                    details: None,
                },
            ],
            timestamp: chrono::Utc::now(),
        };

        assert!(!report.is_valid());
        let failed = report.failed_validations();
        assert_eq!(failed.len(), 1);
        assert_eq!(failed[0].validation_type, "checksum");
    }

    /// End-to-end: an ONNX-prefixed file whose metadata carries the correct
    /// BLAKE3 checksum validates successfully.
    #[test]
    fn test_comprehensive_validation() {
        // ONNX magic bytes followed by an arbitrary payload.
        let mut contents = vec![0x08, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        contents.extend_from_slice(b"model data here");
        let temp_file = temp_file_with(&contents);

        let validator = ModelValidator::new();

        // Compute actual checksum
        let checksum = validator
            .compute_checksum_sync(temp_file.path(), ChecksumAlgorithm::Blake3)
            .expect("Failed to compute checksum");

        let metadata = ModelMetadata {
            model_name: "test_model".to_string(),
            version: "1.0.0".to_string(),
            embedding_dim: 128,
            input_dim: Some(128),
            output_dim: Some(128),
            format: ModelFormat::Onnx,
            checksum: Some(checksum),
            checksum_algorithm: Some(ChecksumAlgorithm::Blake3),
            extra: HashMap::new(),
        };

        let report = validator
            .validate(temp_file.path(), &metadata)
            .expect("Validation failed");

        assert!(report.is_valid());
        assert!(report.results.len() >= 2); // At least checksum and signature
    }
}