1use crate::{Vector, VectorPrecision};
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
13pub enum ValidationSeverity {
14 Info,
16 Warning,
18 Error,
20 Critical,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ValidationViolation {
27 pub severity: ValidationSeverity,
29 pub rule: String,
31 pub message: String,
33 pub context: Option<String>,
35}
36
37impl ValidationViolation {
38 pub fn new(
39 severity: ValidationSeverity,
40 rule: impl Into<String>,
41 message: impl Into<String>,
42 ) -> Self {
43 Self {
44 severity,
45 rule: rule.into(),
46 message: message.into(),
47 context: None,
48 }
49 }
50
51 pub fn with_context(mut self, context: impl Into<String>) -> Self {
52 self.context = Some(context.into());
53 self
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ValidationResult {
60 pub passed: bool,
62 pub violations: Vec<ValidationViolation>,
64 pub timestamp: u64,
66}
67
68impl ValidationResult {
69 pub fn success() -> Self {
70 Self {
71 passed: true,
72 violations: Vec::new(),
73 timestamp: std::time::SystemTime::now()
74 .duration_since(std::time::UNIX_EPOCH)
75 .unwrap()
76 .as_secs(),
77 }
78 }
79
80 pub fn with_violations(violations: Vec<ValidationViolation>) -> Self {
81 let passed = !violations.iter().any(|v| {
82 matches!(
83 v.severity,
84 ValidationSeverity::Error | ValidationSeverity::Critical
85 )
86 });
87
88 Self {
89 passed,
90 violations,
91 timestamp: std::time::SystemTime::now()
92 .duration_since(std::time::UNIX_EPOCH)
93 .unwrap()
94 .as_secs(),
95 }
96 }
97
98 pub fn has_errors(&self) -> bool {
99 self.violations.iter().any(|v| {
100 matches!(
101 v.severity,
102 ValidationSeverity::Error | ValidationSeverity::Critical
103 )
104 })
105 }
106
107 pub fn has_warnings(&self) -> bool {
108 self.violations
109 .iter()
110 .any(|v| v.severity == ValidationSeverity::Warning)
111 }
112
113 pub fn error_count(&self) -> usize {
114 self.violations
115 .iter()
116 .filter(|v| {
117 matches!(
118 v.severity,
119 ValidationSeverity::Error | ValidationSeverity::Critical
120 )
121 })
122 .count()
123 }
124
125 pub fn warning_count(&self) -> usize {
126 self.violations
127 .iter()
128 .filter(|v| v.severity == ValidationSeverity::Warning)
129 .count()
130 }
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct VectorValidationRules {
136 pub min_dimensions: Option<usize>,
138 pub max_dimensions: Option<usize>,
140 pub require_normalized: bool,
142 pub normalization_tolerance: f32,
144 pub check_for_invalid_values: bool,
146 pub disallow_zero_vectors: bool,
148 pub expected_precision: Option<VectorPrecision>,
150 pub min_non_zero: Option<usize>,
152 pub max_magnitude: Option<f32>,
154}
155
156impl Default for VectorValidationRules {
157 fn default() -> Self {
158 Self {
159 min_dimensions: Some(1),
160 max_dimensions: None,
161 require_normalized: false,
162 normalization_tolerance: 1e-6,
163 check_for_invalid_values: true,
164 disallow_zero_vectors: false,
165 expected_precision: None,
166 min_non_zero: None,
167 max_magnitude: None,
168 }
169 }
170}
171
172pub struct VectorValidator {
174 rules: VectorValidationRules,
175}
176
177impl VectorValidator {
178 pub fn new(rules: VectorValidationRules) -> Self {
179 Self { rules }
180 }
181
182 pub fn with_default_rules() -> Self {
183 Self::new(VectorValidationRules::default())
184 }
185
186 pub fn validate(&self, vector: &Vector) -> ValidationResult {
188 let mut violations = Vec::new();
189
190 if let Some(min_dim) = self.rules.min_dimensions {
192 if vector.dimensions < min_dim {
193 violations.push(ValidationViolation::new(
194 ValidationSeverity::Error,
195 "min_dimensions",
196 format!(
197 "Vector has {} dimensions, minimum is {}",
198 vector.dimensions, min_dim
199 ),
200 ));
201 }
202 }
203
204 if let Some(max_dim) = self.rules.max_dimensions {
205 if vector.dimensions > max_dim {
206 violations.push(ValidationViolation::new(
207 ValidationSeverity::Error,
208 "max_dimensions",
209 format!(
210 "Vector has {} dimensions, maximum is {}",
211 vector.dimensions, max_dim
212 ),
213 ));
214 }
215 }
216
217 if self.rules.check_for_invalid_values {
219 let values = vector.as_f32();
220 let has_nan = values.iter().any(|v| v.is_nan());
221 let has_inf = values.iter().any(|v| v.is_infinite());
222
223 if has_nan {
224 violations.push(ValidationViolation::new(
225 ValidationSeverity::Critical,
226 "invalid_values",
227 "Vector contains NaN values",
228 ));
229 }
230
231 if has_inf {
232 violations.push(ValidationViolation::new(
233 ValidationSeverity::Critical,
234 "invalid_values",
235 "Vector contains infinite values",
236 ));
237 }
238 }
239
240 if self.rules.disallow_zero_vectors {
242 let magnitude = vector.magnitude();
243 if magnitude < 1e-10 {
244 violations.push(ValidationViolation::new(
245 ValidationSeverity::Error,
246 "zero_vector",
247 "Vector is approximately zero",
248 ));
249 }
250 }
251
252 if self.rules.require_normalized {
254 let magnitude = vector.magnitude();
255 if (magnitude - 1.0).abs() > self.rules.normalization_tolerance {
256 violations.push(ValidationViolation::new(
257 ValidationSeverity::Warning,
258 "normalization",
259 format!("Vector is not normalized (magnitude: {:.6})", magnitude),
260 ));
261 }
262 }
263
264 if let Some(expected_precision) = self.rules.expected_precision {
266 if vector.precision != expected_precision {
267 violations.push(ValidationViolation::new(
268 ValidationSeverity::Warning,
269 "precision",
270 format!(
271 "Vector precision {:?} does not match expected {:?}",
272 vector.precision, expected_precision
273 ),
274 ));
275 }
276 }
277
278 if let Some(min_non_zero) = self.rules.min_non_zero {
280 let values = vector.as_f32();
281 let non_zero_count = values.iter().filter(|&&v| v.abs() > 1e-10).count();
282
283 if non_zero_count < min_non_zero {
284 violations.push(ValidationViolation::new(
285 ValidationSeverity::Warning,
286 "sparsity",
287 format!(
288 "Vector has {} non-zero values, minimum is {}",
289 non_zero_count, min_non_zero
290 ),
291 ));
292 }
293 }
294
295 if let Some(max_mag) = self.rules.max_magnitude {
297 let magnitude = vector.magnitude();
298 if magnitude > max_mag {
299 violations.push(ValidationViolation::new(
300 ValidationSeverity::Error,
301 "magnitude",
302 format!(
303 "Vector magnitude {:.6} exceeds maximum {:.6}",
304 magnitude, max_mag
305 ),
306 ));
307 }
308 }
309
310 ValidationResult::with_violations(violations)
311 }
312
313 pub fn validate_batch(
315 &self,
316 vectors: &[(String, Vector)],
317 ) -> HashMap<String, ValidationResult> {
318 vectors
319 .iter()
320 .map(|(id, vector)| (id.clone(), self.validate(vector)))
321 .collect()
322 }
323
324 pub fn find_invalid(&self, vectors: &[(String, Vector)]) -> Vec<(String, ValidationResult)> {
326 vectors
327 .iter()
328 .map(|(id, vector)| (id.clone(), self.validate(vector)))
329 .filter(|(_, result)| !result.passed)
330 .collect()
331 }
332}
333
/// Checks that a collection of vectors shares a single dimensionality.
///
/// The expected dimension is either fixed up front or learned from the first vector
/// of the first non-empty batch, and then persists across later batches.
/// `Debug` and `Clone` are derived so the (tiny) state can be logged or snapshotted.
#[derive(Debug, Clone)]
pub struct DimensionValidator {
    /// Dimension every vector must have; `None` until set or learned.
    expected_dimension: Option<usize>,
}
338
339impl DimensionValidator {
340 pub fn new() -> Self {
341 Self {
342 expected_dimension: None,
343 }
344 }
345
346 pub fn with_expected_dimension(dimension: usize) -> Self {
347 Self {
348 expected_dimension: Some(dimension),
349 }
350 }
351
352 pub fn validate_consistency(&mut self, vectors: &[(String, Vector)]) -> ValidationResult {
354 let mut violations = Vec::new();
355
356 if vectors.is_empty() {
357 return ValidationResult::success();
358 }
359
360 let expected = if let Some(dim) = self.expected_dimension {
362 dim
363 } else {
364 let first_dim = vectors[0].1.dimensions;
365 self.expected_dimension = Some(first_dim);
366 first_dim
367 };
368
369 for (id, vector) in vectors {
371 if vector.dimensions != expected {
372 violations.push(
373 ValidationViolation::new(
374 ValidationSeverity::Error,
375 "dimension_mismatch",
376 format!(
377 "Vector '{}' has {} dimensions, expected {}",
378 id, vector.dimensions, expected
379 ),
380 )
381 .with_context(format!(
382 "expected={}, actual={}",
383 expected, vector.dimensions
384 )),
385 );
386 }
387 }
388
389 ValidationResult::with_violations(violations)
390 }
391
392 pub fn established_dimension(&self) -> Option<usize> {
394 self.expected_dimension
395 }
396}
397
398pub struct MetadataValidator {
400 required_fields: Vec<String>,
401 field_patterns: HashMap<String, regex::Regex>,
402}
403
404impl MetadataValidator {
405 pub fn new() -> Self {
406 Self {
407 required_fields: Vec::new(),
408 field_patterns: HashMap::new(),
409 }
410 }
411
412 pub fn require_field(&mut self, field: impl Into<String>) -> &mut Self {
413 self.required_fields.push(field.into());
414 self
415 }
416
417 pub fn require_pattern(
418 &mut self,
419 field: impl Into<String>,
420 pattern: &str,
421 ) -> Result<&mut Self> {
422 let regex = regex::Regex::new(pattern)?;
423 self.field_patterns.insert(field.into(), regex);
424 Ok(self)
425 }
426
427 pub fn validate(&self, metadata: &HashMap<String, String>) -> ValidationResult {
429 let mut violations = Vec::new();
430
431 for field in &self.required_fields {
433 if !metadata.contains_key(field) {
434 violations.push(ValidationViolation::new(
435 ValidationSeverity::Error,
436 "missing_field",
437 format!("Required field '{}' is missing", field),
438 ));
439 }
440 }
441
442 for (field, pattern) in &self.field_patterns {
444 if let Some(value) = metadata.get(field) {
445 if !pattern.is_match(value) {
446 violations.push(ValidationViolation::new(
447 ValidationSeverity::Error,
448 "pattern_mismatch",
449 format!(
450 "Field '{}' value '{}' does not match required pattern",
451 field, value
452 ),
453 ));
454 }
455 }
456 }
457
458 ValidationResult::with_violations(violations)
459 }
460}
461
462impl Default for MetadataValidator {
463 fn default() -> Self {
464 Self::new()
465 }
466}
467
468impl Default for DimensionValidator {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
474pub struct ComprehensiveValidator {
476 vector_validator: VectorValidator,
477 dimension_validator: DimensionValidator,
478 metadata_validator: Option<MetadataValidator>,
479}
480
481impl ComprehensiveValidator {
482 pub fn new(vector_rules: VectorValidationRules, expected_dimension: Option<usize>) -> Self {
483 Self {
484 vector_validator: VectorValidator::new(vector_rules),
485 dimension_validator: if let Some(dim) = expected_dimension {
486 DimensionValidator::with_expected_dimension(dim)
487 } else {
488 DimensionValidator::new()
489 },
490 metadata_validator: None,
491 }
492 }
493
494 pub fn with_metadata_validator(mut self, validator: MetadataValidator) -> Self {
495 self.metadata_validator = Some(validator);
496 self
497 }
498
499 pub fn validate_vector(
501 &self,
502 id: &str,
503 vector: &Vector,
504 metadata: Option<&HashMap<String, String>>,
505 ) -> ValidationResult {
506 let mut all_violations = Vec::new();
507
508 let vector_result = self.vector_validator.validate(vector);
510 all_violations.extend(vector_result.violations);
511
512 if let Some(expected_dim) = self.dimension_validator.established_dimension() {
514 if vector.dimensions != expected_dim {
515 all_violations.push(ValidationViolation::new(
516 ValidationSeverity::Error,
517 "dimension_mismatch",
518 format!(
519 "Vector '{}' has {} dimensions, expected {}",
520 id, vector.dimensions, expected_dim
521 ),
522 ));
523 }
524 }
525
526 if let (Some(validator), Some(meta)) = (&self.metadata_validator, metadata) {
528 let meta_result = validator.validate(meta);
529 all_violations.extend(meta_result.violations);
530 }
531
532 ValidationResult::with_violations(all_violations)
533 }
534
535 #[allow(clippy::type_complexity)]
537 pub fn validate_batch(
538 &mut self,
539 vectors: &[(String, Vector, Option<HashMap<String, String>>)],
540 ) -> HashMap<String, ValidationResult> {
541 let mut results = HashMap::new();
542
543 let vectors_only: Vec<(String, Vector)> = vectors
545 .iter()
546 .map(|(id, vec, _)| (id.clone(), vec.clone()))
547 .collect();
548
549 let dim_result = self.dimension_validator.validate_consistency(&vectors_only);
550 if dim_result.has_errors() {
551 for (id, _, _) in vectors {
553 results.insert(id.clone(), dim_result.clone());
554 }
555 return results;
556 }
557
558 for (id, vector, metadata) in vectors {
560 let result = self.validate_vector(id, vector, metadata.as_ref());
561 results.insert(id.clone(), result);
562 }
563
564 results
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a validator whose rules differ from the defaults only as `customize` says.
    fn validator_with(customize: impl FnOnce(&mut VectorValidationRules)) -> VectorValidator {
        let mut rules = VectorValidationRules::default();
        customize(&mut rules);
        VectorValidator::new(rules)
    }

    /// Turns `(key, value)` string pairs into an owned metadata map.
    fn metadata_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_valid_vector() {
        let outcome = validator_with(|_| {}).validate(&Vector::new(vec![1.0, 2.0, 3.0]));

        assert!(outcome.passed);
        assert!(outcome.violations.is_empty());
    }

    #[test]
    fn test_invalid_dimensions() {
        let validator = validator_with(|rules| rules.min_dimensions = Some(5));

        let outcome = validator.validate(&Vector::new(vec![1.0, 2.0]));

        assert!(!outcome.passed);
        assert!(outcome.has_errors());
    }

    #[test]
    fn test_normalized_vector() {
        let validator = validator_with(|rules| rules.require_normalized = true);

        // Not unit length: reported, but only as a warning.
        let scaled = validator.validate(&Vector::new(vec![1.0, 2.0, 3.0]));
        assert!(scaled.has_warnings());

        // A basis vector passes the normalization rule.
        let basis = validator.validate(&Vector::new(vec![1.0, 0.0, 0.0]));
        assert!(basis.passed);
    }

    #[test]
    fn test_invalid_values() {
        let validator = validator_with(|rules| rules.check_for_invalid_values = true);

        let outcome = validator.validate(&Vector::new(vec![1.0, f32::NAN, 3.0]));

        assert!(!outcome.passed);
        assert_eq!(outcome.error_count(), 1);
    }

    #[test]
    fn test_dimension_consistency() {
        let batch = vec![
            ("vec1".to_string(), Vector::new(vec![1.0, 2.0, 3.0])),
            ("vec2".to_string(), Vector::new(vec![4.0, 5.0, 6.0])),
            ("vec3".to_string(), Vector::new(vec![7.0, 8.0])),
        ];

        let outcome = DimensionValidator::new().validate_consistency(&batch);

        assert!(!outcome.passed);
        assert_eq!(outcome.error_count(), 1);
    }

    #[test]
    fn test_metadata_validation() {
        let mut validator = MetadataValidator::new();
        validator
            .require_field("category")
            .require_pattern("status", r"^(active|inactive)$")
            .unwrap();

        let complete = metadata_of(&[("category", "news"), ("status", "active")]);
        assert!(validator.validate(&complete).passed);

        // Missing "category" and a status outside the allowed pattern: two errors.
        let broken = metadata_of(&[("status", "pending")]);
        let outcome = validator.validate(&broken);
        assert!(!outcome.passed);
        assert_eq!(outcome.error_count(), 2);
    }

    #[test]
    fn test_comprehensive_validator() {
        let mut validator = ComprehensiveValidator::new(VectorValidationRules::default(), None);
        let batch = vec![
            ("vec1".to_string(), Vector::new(vec![1.0, 2.0, 3.0]), None),
            ("vec2".to_string(), Vector::new(vec![4.0, 5.0]), None),
        ];

        let outcomes = validator.validate_batch(&batch);

        // A dimension mismatch anywhere fails the whole batch.
        assert!(!outcomes["vec1"].passed);
        assert!(!outcomes["vec2"].passed);
    }
}