oxirs_vec/
validation.rs

1//! Comprehensive data validation for vector operations
2//!
3//! This module provides validation utilities for vectors, indices, and search operations
4//! to ensure data integrity and catch errors early.
5
6use crate::{Vector, VectorPrecision};
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Validation severity level
///
/// Variants are declared from least to most severe. Because `PartialOrd` and
/// `Ord` are derived, that declaration order is also the comparison order:
/// `Info < Warning < Error < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ValidationSeverity {
    /// Informational message
    Info,
    /// Warning that should be addressed
    Warning,
    /// Error that must be fixed
    Error,
    /// Critical error that prevents operation
    Critical,
}
23
/// Validation rule violation
///
/// One entry in a validation report: which rule was broken, how severe the
/// problem is, and a human-readable explanation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationViolation {
    /// Severity level
    pub severity: ValidationSeverity,
    /// Identifier of the rule that was violated (for example `"min_dimensions"`)
    pub rule: String,
    /// Detailed, human-readable message
    pub message: String,
    /// Optional extra context; `None` unless one was attached
    pub context: Option<String>,
}
36
37impl ValidationViolation {
38    pub fn new(
39        severity: ValidationSeverity,
40        rule: impl Into<String>,
41        message: impl Into<String>,
42    ) -> Self {
43        Self {
44            severity,
45            rule: rule.into(),
46            message: message.into(),
47            context: None,
48        }
49    }
50
51    pub fn with_context(mut self, context: impl Into<String>) -> Self {
52        self.context = Some(context.into());
53        self
54    }
55}
56
/// Validation result
///
/// The outcome of one validation run: the overall verdict plus every
/// violation that was recorded along the way.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResult {
    /// Whether validation passed (no errors or critical issues)
    pub passed: bool,
    /// List of violations, of any severity
    pub violations: Vec<ValidationViolation>,
    /// Validation timestamp, in whole seconds since the Unix epoch
    pub timestamp: u64,
}
67
68impl ValidationResult {
69    pub fn success() -> Self {
70        Self {
71            passed: true,
72            violations: Vec::new(),
73            timestamp: std::time::SystemTime::now()
74                .duration_since(std::time::UNIX_EPOCH)
75                .unwrap()
76                .as_secs(),
77        }
78    }
79
80    pub fn with_violations(violations: Vec<ValidationViolation>) -> Self {
81        let passed = !violations.iter().any(|v| {
82            matches!(
83                v.severity,
84                ValidationSeverity::Error | ValidationSeverity::Critical
85            )
86        });
87
88        Self {
89            passed,
90            violations,
91            timestamp: std::time::SystemTime::now()
92                .duration_since(std::time::UNIX_EPOCH)
93                .unwrap()
94                .as_secs(),
95        }
96    }
97
98    pub fn has_errors(&self) -> bool {
99        self.violations.iter().any(|v| {
100            matches!(
101                v.severity,
102                ValidationSeverity::Error | ValidationSeverity::Critical
103            )
104        })
105    }
106
107    pub fn has_warnings(&self) -> bool {
108        self.violations
109            .iter()
110            .any(|v| v.severity == ValidationSeverity::Warning)
111    }
112
113    pub fn error_count(&self) -> usize {
114        self.violations
115            .iter()
116            .filter(|v| {
117                matches!(
118                    v.severity,
119                    ValidationSeverity::Error | ValidationSeverity::Critical
120                )
121            })
122            .count()
123    }
124
125    pub fn warning_count(&self) -> usize {
126        self.violations
127            .iter()
128            .filter(|v| v.severity == ValidationSeverity::Warning)
129            .count()
130    }
131}
132
/// Vector validation rules
///
/// Configuration consumed by `VectorValidator`. `Option` fields disable the
/// corresponding check when `None`; `bool` fields disable it when `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorValidationRules {
    /// Minimum allowed dimensions (fewer is reported as an error)
    pub min_dimensions: Option<usize>,
    /// Maximum allowed dimensions (more is reported as an error)
    pub max_dimensions: Option<usize>,
    /// Require normalized vectors (L2 norm = 1); a miss is reported as a warning
    pub require_normalized: bool,
    /// Tolerance for normalization check: the largest accepted absolute
    /// difference between the vector's magnitude and 1.0
    pub normalization_tolerance: f32,
    /// Check for NaN or infinite values (reported as critical)
    pub check_for_invalid_values: bool,
    /// Check for zero vectors: a magnitude below 1e-10 is reported as an error
    pub disallow_zero_vectors: bool,
    /// Expected precision (if any); a mismatch is reported as a warning
    pub expected_precision: Option<VectorPrecision>,
    /// Minimum non-zero values (for sparse vectors). A component counts as
    /// non-zero when its absolute value exceeds 1e-10; falling short is
    /// reported as a warning
    pub min_non_zero: Option<usize>,
    /// Maximum magnitude; a strictly larger magnitude is reported as an error
    pub max_magnitude: Option<f32>,
}
155
156impl Default for VectorValidationRules {
157    fn default() -> Self {
158        Self {
159            min_dimensions: Some(1),
160            max_dimensions: None,
161            require_normalized: false,
162            normalization_tolerance: 1e-6,
163            check_for_invalid_values: true,
164            disallow_zero_vectors: false,
165            expected_precision: None,
166            min_non_zero: None,
167            max_magnitude: None,
168        }
169    }
170}
171
172/// Vector validator
173pub struct VectorValidator {
174    rules: VectorValidationRules,
175}
176
177impl VectorValidator {
178    pub fn new(rules: VectorValidationRules) -> Self {
179        Self { rules }
180    }
181
182    pub fn with_default_rules() -> Self {
183        Self::new(VectorValidationRules::default())
184    }
185
186    /// Validate a single vector
187    pub fn validate(&self, vector: &Vector) -> ValidationResult {
188        let mut violations = Vec::new();
189
190        // Check dimensions
191        if let Some(min_dim) = self.rules.min_dimensions {
192            if vector.dimensions < min_dim {
193                violations.push(ValidationViolation::new(
194                    ValidationSeverity::Error,
195                    "min_dimensions",
196                    format!(
197                        "Vector has {} dimensions, minimum is {}",
198                        vector.dimensions, min_dim
199                    ),
200                ));
201            }
202        }
203
204        if let Some(max_dim) = self.rules.max_dimensions {
205            if vector.dimensions > max_dim {
206                violations.push(ValidationViolation::new(
207                    ValidationSeverity::Error,
208                    "max_dimensions",
209                    format!(
210                        "Vector has {} dimensions, maximum is {}",
211                        vector.dimensions, max_dim
212                    ),
213                ));
214            }
215        }
216
217        // Check for invalid values
218        if self.rules.check_for_invalid_values {
219            let values = vector.as_f32();
220            let has_nan = values.iter().any(|v| v.is_nan());
221            let has_inf = values.iter().any(|v| v.is_infinite());
222
223            if has_nan {
224                violations.push(ValidationViolation::new(
225                    ValidationSeverity::Critical,
226                    "invalid_values",
227                    "Vector contains NaN values",
228                ));
229            }
230
231            if has_inf {
232                violations.push(ValidationViolation::new(
233                    ValidationSeverity::Critical,
234                    "invalid_values",
235                    "Vector contains infinite values",
236                ));
237            }
238        }
239
240        // Check for zero vector
241        if self.rules.disallow_zero_vectors {
242            let magnitude = vector.magnitude();
243            if magnitude < 1e-10 {
244                violations.push(ValidationViolation::new(
245                    ValidationSeverity::Error,
246                    "zero_vector",
247                    "Vector is approximately zero",
248                ));
249            }
250        }
251
252        // Check normalization
253        if self.rules.require_normalized {
254            let magnitude = vector.magnitude();
255            if (magnitude - 1.0).abs() > self.rules.normalization_tolerance {
256                violations.push(ValidationViolation::new(
257                    ValidationSeverity::Warning,
258                    "normalization",
259                    format!("Vector is not normalized (magnitude: {:.6})", magnitude),
260                ));
261            }
262        }
263
264        // Check precision
265        if let Some(expected_precision) = self.rules.expected_precision {
266            if vector.precision != expected_precision {
267                violations.push(ValidationViolation::new(
268                    ValidationSeverity::Warning,
269                    "precision",
270                    format!(
271                        "Vector precision {:?} does not match expected {:?}",
272                        vector.precision, expected_precision
273                    ),
274                ));
275            }
276        }
277
278        // Check sparsity
279        if let Some(min_non_zero) = self.rules.min_non_zero {
280            let values = vector.as_f32();
281            let non_zero_count = values.iter().filter(|&&v| v.abs() > 1e-10).count();
282
283            if non_zero_count < min_non_zero {
284                violations.push(ValidationViolation::new(
285                    ValidationSeverity::Warning,
286                    "sparsity",
287                    format!(
288                        "Vector has {} non-zero values, minimum is {}",
289                        non_zero_count, min_non_zero
290                    ),
291                ));
292            }
293        }
294
295        // Check maximum magnitude
296        if let Some(max_mag) = self.rules.max_magnitude {
297            let magnitude = vector.magnitude();
298            if magnitude > max_mag {
299                violations.push(ValidationViolation::new(
300                    ValidationSeverity::Error,
301                    "magnitude",
302                    format!(
303                        "Vector magnitude {:.6} exceeds maximum {:.6}",
304                        magnitude, max_mag
305                    ),
306                ));
307            }
308        }
309
310        ValidationResult::with_violations(violations)
311    }
312
313    /// Validate multiple vectors
314    pub fn validate_batch(
315        &self,
316        vectors: &[(String, Vector)],
317    ) -> HashMap<String, ValidationResult> {
318        vectors
319            .iter()
320            .map(|(id, vector)| (id.clone(), self.validate(vector)))
321            .collect()
322    }
323
324    /// Validate and return only invalid vectors
325    pub fn find_invalid(&self, vectors: &[(String, Vector)]) -> Vec<(String, ValidationResult)> {
326        vectors
327            .iter()
328            .map(|(id, vector)| (id.clone(), self.validate(vector)))
329            .filter(|(_, result)| !result.passed)
330            .collect()
331    }
332}
333
/// Dimension consistency validator
///
/// Tracks the dimension every vector is expected to have. The dimension is
/// either supplied up front or left unset until one is established.
#[derive(Debug, Clone)]
pub struct DimensionValidator {
    /// Dimension all vectors must match; `None` until one is established.
    expected_dimension: Option<usize>,
}
338
339impl DimensionValidator {
340    pub fn new() -> Self {
341        Self {
342            expected_dimension: None,
343        }
344    }
345
346    pub fn with_expected_dimension(dimension: usize) -> Self {
347        Self {
348            expected_dimension: Some(dimension),
349        }
350    }
351
352    /// Validate dimension consistency across multiple vectors
353    pub fn validate_consistency(&mut self, vectors: &[(String, Vector)]) -> ValidationResult {
354        let mut violations = Vec::new();
355
356        if vectors.is_empty() {
357            return ValidationResult::success();
358        }
359
360        // Check if expected dimension is set, if not, use first vector's dimension
361        let expected = if let Some(dim) = self.expected_dimension {
362            dim
363        } else {
364            let first_dim = vectors[0].1.dimensions;
365            self.expected_dimension = Some(first_dim);
366            first_dim
367        };
368
369        // Check all vectors have consistent dimensions
370        for (id, vector) in vectors {
371            if vector.dimensions != expected {
372                violations.push(
373                    ValidationViolation::new(
374                        ValidationSeverity::Error,
375                        "dimension_mismatch",
376                        format!(
377                            "Vector '{}' has {} dimensions, expected {}",
378                            id, vector.dimensions, expected
379                        ),
380                    )
381                    .with_context(format!(
382                        "expected={}, actual={}",
383                        expected, vector.dimensions
384                    )),
385                );
386            }
387        }
388
389        ValidationResult::with_violations(violations)
390    }
391
392    /// Get the established dimension
393    pub fn established_dimension(&self) -> Option<usize> {
394        self.expected_dimension
395    }
396}
397
398/// Metadata validator
399pub struct MetadataValidator {
400    required_fields: Vec<String>,
401    field_patterns: HashMap<String, regex::Regex>,
402}
403
404impl MetadataValidator {
405    pub fn new() -> Self {
406        Self {
407            required_fields: Vec::new(),
408            field_patterns: HashMap::new(),
409        }
410    }
411
412    pub fn require_field(&mut self, field: impl Into<String>) -> &mut Self {
413        self.required_fields.push(field.into());
414        self
415    }
416
417    pub fn require_pattern(
418        &mut self,
419        field: impl Into<String>,
420        pattern: &str,
421    ) -> Result<&mut Self> {
422        let regex = regex::Regex::new(pattern)?;
423        self.field_patterns.insert(field.into(), regex);
424        Ok(self)
425    }
426
427    /// Validate metadata
428    pub fn validate(&self, metadata: &HashMap<String, String>) -> ValidationResult {
429        let mut violations = Vec::new();
430
431        // Check required fields
432        for field in &self.required_fields {
433            if !metadata.contains_key(field) {
434                violations.push(ValidationViolation::new(
435                    ValidationSeverity::Error,
436                    "missing_field",
437                    format!("Required field '{}' is missing", field),
438                ));
439            }
440        }
441
442        // Check patterns
443        for (field, pattern) in &self.field_patterns {
444            if let Some(value) = metadata.get(field) {
445                if !pattern.is_match(value) {
446                    violations.push(ValidationViolation::new(
447                        ValidationSeverity::Error,
448                        "pattern_mismatch",
449                        format!(
450                            "Field '{}' value '{}' does not match required pattern",
451                            field, value
452                        ),
453                    ));
454                }
455            }
456        }
457
458        ValidationResult::with_violations(violations)
459    }
460}
461
462impl Default for MetadataValidator {
463    fn default() -> Self {
464        Self::new()
465    }
466}
467
468impl Default for DimensionValidator {
469    fn default() -> Self {
470        Self::new()
471    }
472}
473
/// Comprehensive validator for all operations
///
/// Bundles the per-vector rule checks, the dimension check and an optional
/// metadata check behind a single entry point.
pub struct ComprehensiveValidator {
    /// Applies the per-vector rules.
    vector_validator: VectorValidator,
    /// Holds the dimension every vector is expected to have, if any.
    dimension_validator: DimensionValidator,
    /// Metadata checks; `None` means metadata is never validated.
    metadata_validator: Option<MetadataValidator>,
}
480
481impl ComprehensiveValidator {
482    pub fn new(vector_rules: VectorValidationRules, expected_dimension: Option<usize>) -> Self {
483        Self {
484            vector_validator: VectorValidator::new(vector_rules),
485            dimension_validator: if let Some(dim) = expected_dimension {
486                DimensionValidator::with_expected_dimension(dim)
487            } else {
488                DimensionValidator::new()
489            },
490            metadata_validator: None,
491        }
492    }
493
494    pub fn with_metadata_validator(mut self, validator: MetadataValidator) -> Self {
495        self.metadata_validator = Some(validator);
496        self
497    }
498
499    /// Validate vector with all rules
500    pub fn validate_vector(
501        &self,
502        id: &str,
503        vector: &Vector,
504        metadata: Option<&HashMap<String, String>>,
505    ) -> ValidationResult {
506        let mut all_violations = Vec::new();
507
508        // Vector validation
509        let vector_result = self.vector_validator.validate(vector);
510        all_violations.extend(vector_result.violations);
511
512        // Dimension validation (single vector check)
513        if let Some(expected_dim) = self.dimension_validator.established_dimension() {
514            if vector.dimensions != expected_dim {
515                all_violations.push(ValidationViolation::new(
516                    ValidationSeverity::Error,
517                    "dimension_mismatch",
518                    format!(
519                        "Vector '{}' has {} dimensions, expected {}",
520                        id, vector.dimensions, expected_dim
521                    ),
522                ));
523            }
524        }
525
526        // Metadata validation
527        if let (Some(validator), Some(meta)) = (&self.metadata_validator, metadata) {
528            let meta_result = validator.validate(meta);
529            all_violations.extend(meta_result.violations);
530        }
531
532        ValidationResult::with_violations(all_violations)
533    }
534
535    /// Validate batch of vectors
536    #[allow(clippy::type_complexity)]
537    pub fn validate_batch(
538        &mut self,
539        vectors: &[(String, Vector, Option<HashMap<String, String>>)],
540    ) -> HashMap<String, ValidationResult> {
541        let mut results = HashMap::new();
542
543        // First pass: dimension consistency
544        let vectors_only: Vec<(String, Vector)> = vectors
545            .iter()
546            .map(|(id, vec, _)| (id.clone(), vec.clone()))
547            .collect();
548
549        let dim_result = self.dimension_validator.validate_consistency(&vectors_only);
550        if dim_result.has_errors() {
551            // If dimension consistency fails, report it for all vectors
552            for (id, _, _) in vectors {
553                results.insert(id.clone(), dim_result.clone());
554            }
555            return results;
556        }
557
558        // Second pass: individual validation
559        for (id, vector, metadata) in vectors {
560            let result = self.validate_vector(id, vector, metadata.as_ref());
561            results.insert(id.clone(), result);
562        }
563
564        results
565    }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed vector produces no violations under the default rules.
    #[test]
    fn test_valid_vector() {
        let validator = VectorValidator::new(VectorValidationRules::default());

        let outcome = validator.validate(&Vector::new(vec![1.0, 2.0, 3.0]));

        assert!(outcome.passed);
        assert!(outcome.violations.is_empty());
    }

    /// Fewer dimensions than `min_dimensions` is an error.
    #[test]
    fn test_invalid_dimensions() {
        let validator = VectorValidator::new(VectorValidationRules {
            min_dimensions: Some(5),
            ..Default::default()
        });

        let too_short = Vector::new(vec![1.0, 2.0]);
        let outcome = validator.validate(&too_short);

        assert!(!outcome.passed);
        assert!(outcome.has_errors());
    }

    /// Non-unit vectors only warn; unit vectors pass cleanly.
    #[test]
    fn test_normalized_vector() {
        let validator = VectorValidator::new(VectorValidationRules {
            require_normalized: true,
            ..Default::default()
        });

        // Not normalized
        let unnormalized = validator.validate(&Vector::new(vec![1.0, 2.0, 3.0]));
        assert!(unnormalized.has_warnings());

        // Normalized
        let unit = validator.validate(&Vector::new(vec![1.0, 0.0, 0.0]));
        assert!(unit.passed);
    }

    /// A NaN component is reported once and fails validation.
    #[test]
    fn test_invalid_values() {
        let validator = VectorValidator::new(VectorValidationRules {
            check_for_invalid_values: true,
            ..Default::default()
        });

        let with_nan = Vector::new(vec![1.0, f32::NAN, 3.0]);
        let outcome = validator.validate(&with_nan);

        assert!(!outcome.passed);
        assert_eq!(outcome.error_count(), 1);
    }

    /// The first vector fixes the dimension; the odd one out is reported.
    #[test]
    fn test_dimension_consistency() {
        let batch = vec![
            ("vec1".to_string(), Vector::new(vec![1.0, 2.0, 3.0])),
            ("vec2".to_string(), Vector::new(vec![4.0, 5.0, 6.0])),
            // Wrong dimension
            ("vec3".to_string(), Vector::new(vec![7.0, 8.0])),
        ];

        let mut validator = DimensionValidator::new();
        let outcome = validator.validate_consistency(&batch);

        assert!(!outcome.passed);
        assert_eq!(outcome.error_count(), 1);
    }

    /// Required fields and patterns are both enforced.
    #[test]
    fn test_metadata_validation() {
        let mut validator = MetadataValidator::new();
        validator.require_field("category");
        validator
            .require_pattern("status", r"^(active|inactive)$")
            .unwrap();

        let mut complete = HashMap::new();
        complete.insert("category".to_string(), "news".to_string());
        complete.insert("status".to_string(), "active".to_string());
        assert!(validator.validate(&complete).passed);

        // Wrong pattern, missing category
        let mut broken = HashMap::new();
        broken.insert("status".to_string(), "pending".to_string());

        let outcome = validator.validate(&broken);
        assert!(!outcome.passed);
        assert_eq!(outcome.error_count(), 2);
    }

    /// An inconsistent batch fails every vector in it.
    #[test]
    fn test_comprehensive_validator() {
        // Don't set expected dimension upfront
        let mut validator = ComprehensiveValidator::new(VectorValidationRules::default(), None);

        let batch = vec![
            ("vec1".to_string(), Vector::new(vec![1.0, 2.0, 3.0]), None),
            // Wrong dimension
            ("vec2".to_string(), Vector::new(vec![4.0, 5.0]), None),
        ];

        let outcomes = validator.validate_batch(&batch);

        // vec1 has 3 dims and vec2 has 2, so the consistency check fails and
        // that failure is reported for every vector in the batch.
        for id in ["vec1", "vec2"].iter() {
            assert!(!outcomes[*id].passed);
        }
    }
}