1use crate::error::{SchemaError, SchemaResult};
40use crate::types::{
41 SchemaType, ValidationLevel, ValidationReport, ValidationResult, ValidationRule,
42 ValidationRuleType,
43};
44use parking_lot::Mutex;
45use serde::Deserialize;
46use std::collections::HashMap;
47use tracing::debug;
48
/// Settings that control how a `ValidationEngine` evaluates rules.
///
/// Start from [`ValidationEngineConfig::new`] (identical to `Default`) and
/// refine the value with the chainable `with_*` setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationEngineConfig {
    /// Stop evaluating further rules once an error-level rule has failed.
    pub fail_fast: bool,
    /// Request that warning-level findings be treated as errors.
    pub warnings_as_errors: bool,
    /// Upper bound on the number of applicable rules evaluated per schema.
    pub max_rules_per_schema: usize,
}

impl Default for ValidationEngineConfig {
    /// Lenient defaults: no fail-fast, warnings stay warnings, at most 100
    /// rules per schema.
    fn default() -> Self {
        Self {
            fail_fast: false,
            warnings_as_errors: false,
            max_rules_per_schema: 100,
        }
    }
}

impl ValidationEngineConfig {
    /// Creates a configuration holding the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `fail_fast` replaced.
    #[must_use]
    pub fn with_fail_fast(mut self, fail_fast: bool) -> Self {
        self.fail_fast = fail_fast;
        self
    }

    /// Returns the configuration with `warnings_as_errors` replaced.
    #[must_use]
    pub fn with_warnings_as_errors(mut self, warnings_as_errors: bool) -> Self {
        self.warnings_as_errors = warnings_as_errors;
        self
    }

    /// Returns the configuration with `max_rules_per_schema` replaced.
    ///
    /// Completes the builder API: previously this field could only be set by
    /// assigning to it directly.
    #[must_use]
    pub fn with_max_rules_per_schema(mut self, max_rules_per_schema: usize) -> Self {
        self.max_rules_per_schema = max_rules_per_schema;
        self
    }
}
85
/// Rule-based schema validator.
///
/// Holds the rules registered for every subject and for individual subjects,
/// together with a cache of compiled regular expressions used by the rule
/// checks.
pub struct ValidationEngine {
    /// Engine-wide settings supplied at construction.
    config: ValidationEngineConfig,
    /// Rules considered for every subject.
    global_rules: Vec<ValidationRule>,
    /// Additional rules, keyed by subject name.
    subject_rules: HashMap<String, Vec<ValidationRule>>,
    /// Compiled regexes keyed by their pattern text; wrapped in a mutex so it
    /// can be filled from `&self` methods.
    regex_cache: Mutex<HashMap<String, regex::Regex>>,
}
97
98impl ValidationEngine {
99 pub fn new(config: ValidationEngineConfig) -> Self {
101 Self {
102 config,
103 global_rules: Vec::new(),
104 subject_rules: HashMap::new(),
105 regex_cache: Mutex::new(HashMap::new()),
106 }
107 }
108
/// Appends `rule` to the global rule list, which is considered for every
/// subject.
pub fn add_rule(&mut self, rule: ValidationRule) {
    self.global_rules.push(rule);
}
113
/// Appends every rule yielded by `rules` to the global rule list, keeping
/// the iteration order.
pub fn add_rules(&mut self, rules: impl IntoIterator<Item = ValidationRule>) {
    self.global_rules.extend(rules);
}
118
119 pub fn add_subject_rule(&mut self, subject: &str, rule: ValidationRule) {
121 self.subject_rules
122 .entry(subject.to_string())
123 .or_default()
124 .push(rule);
125 }
126
127 pub fn remove_rule(&mut self, name: &str) -> bool {
129 let before = self.global_rules.len();
130 self.global_rules.retain(|r| r.name != name);
131
132 for rules in self.subject_rules.values_mut() {
133 rules.retain(|r| r.name != name);
134 }
135
136 self.global_rules.len() != before
137 }
138
139 pub fn rules(&self) -> &[ValidationRule] {
141 &self.global_rules
142 }
143
144 pub fn list_rules(&self) -> Vec<ValidationRule> {
146 self.global_rules.clone()
147 }
148
149 pub fn subject_rules(&self, subject: &str) -> Option<&[ValidationRule]> {
151 self.subject_rules.get(subject).map(|v| v.as_slice())
152 }
153
154 pub fn clear(&mut self) {
156 self.global_rules.clear();
157 self.subject_rules.clear();
158 self.regex_cache.lock().clear();
159 }
160
161 fn get_or_compile_regex(&self, pattern: &str) -> SchemaResult<regex::Regex> {
166 let mut cache = self.regex_cache.lock();
167 if let Some(re) = cache.get(pattern) {
168 return Ok(re.clone());
169 }
170 let re = regex::RegexBuilder::new(pattern)
171 .size_limit(1_000_000)
172 .build()
173 .map_err(|e| SchemaError::Validation(format!("Invalid regex pattern: {}", e)))?;
174 const MAX_REGEX_CACHE: usize = 1_000;
176 if cache.len() >= MAX_REGEX_CACHE {
177 cache.clear();
178 }
179 cache.insert(pattern.to_string(), re.clone());
180 Ok(re)
181 }
182
183 pub fn validate(
185 &self,
186 schema_type: SchemaType,
187 subject: &str,
188 schema: &str,
189 ) -> SchemaResult<ValidationReport> {
190 let mut report = ValidationReport::new();
191 let mut rules_evaluated = 0;
192
193 let applicable_rules: Vec<&ValidationRule> = self
195 .global_rules
196 .iter()
197 .chain(self.subject_rules.get(subject).into_iter().flatten())
198 .filter(|r| r.applies(schema_type, subject))
199 .take(self.config.max_rules_per_schema)
200 .collect();
201
202 debug!(
203 "Validating schema for subject {} with {} applicable rules",
204 subject,
205 applicable_rules.len()
206 );
207
208 for rule in applicable_rules {
209 let result = self.execute_rule(rule, schema_type, schema)?;
210
211 if self.config.fail_fast && !result.passed && result.level == ValidationLevel::Error {
213 report.add_result(result);
214 return Ok(report);
215 }
216
217 report.add_result(result);
218 rules_evaluated += 1;
219 }
220
221 debug!(
222 "Validation complete: {} rules evaluated, {} errors, {} warnings",
223 rules_evaluated, report.summary.errors, report.summary.warnings
224 );
225
226 Ok(report)
227 }
228
229 fn execute_rule(
231 &self,
232 rule: &ValidationRule,
233 schema_type: SchemaType,
234 schema: &str,
235 ) -> SchemaResult<ValidationResult> {
236 match rule.rule_type {
237 ValidationRuleType::MaxSize => self.validate_max_size(rule, schema),
238 ValidationRuleType::NamingConvention => {
239 self.validate_naming_convention(rule, schema_type, schema)
240 }
241 ValidationRuleType::FieldRequired => {
242 self.validate_field_required(rule, schema_type, schema)
243 }
244 ValidationRuleType::FieldType => self.validate_field_type(rule, schema_type, schema),
245 ValidationRuleType::Regex => self.validate_regex(rule, schema),
246 ValidationRuleType::JsonSchema => self.validate_json_schema(rule, schema),
247 }
248 }
249
250 fn validate_max_size(
252 &self,
253 rule: &ValidationRule,
254 schema: &str,
255 ) -> SchemaResult<ValidationResult> {
256 #[derive(Deserialize)]
257 struct Config {
258 max_bytes: usize,
259 }
260
261 let config: Config = serde_json::from_str(&rule.config)
262 .map_err(|e| SchemaError::Validation(format!("Invalid max_size config: {}", e)))?;
263
264 let size = schema.len();
265 if size > config.max_bytes {
266 Ok(ValidationResult::fail(
267 &rule.name,
268 rule.level,
269 format!(
270 "Schema size {} bytes exceeds maximum {} bytes",
271 size, config.max_bytes
272 ),
273 ))
274 } else {
275 Ok(ValidationResult::pass(&rule.name))
276 }
277 }
278
279 fn validate_naming_convention(
281 &self,
282 rule: &ValidationRule,
283 schema_type: SchemaType,
284 schema: &str,
285 ) -> SchemaResult<ValidationResult> {
286 #[derive(Deserialize)]
287 struct Config {
288 pattern: String,
289 #[serde(default = "default_name_field")]
290 field: String,
291 }
292
293 fn default_name_field() -> String {
294 "name".to_string()
295 }
296
297 let config: Config = serde_json::from_str(&rule.config).map_err(|e| {
298 SchemaError::Validation(format!("Invalid naming_convention config: {}", e))
299 })?;
300
301 let regex = self.get_or_compile_regex(&config.pattern)?;
303
304 let name = match schema_type {
306 SchemaType::Avro | SchemaType::Json => {
307 let parsed: serde_json::Value = serde_json::from_str(schema)
308 .map_err(|e| SchemaError::Validation(format!("Invalid JSON schema: {}", e)))?;
309 parsed
310 .get(&config.field)
311 .and_then(|v| v.as_str())
312 .map(|s| s.to_string())
313 }
314 SchemaType::Protobuf => {
315 extract_protobuf_name(schema)
317 }
318 };
319
320 match name {
321 Some(n) if regex.is_match(&n) => Ok(ValidationResult::pass(&rule.name)),
322 Some(n) => Ok(ValidationResult::fail(
323 &rule.name,
324 rule.level,
325 format!("Name '{}' does not match pattern '{}'", n, config.pattern),
326 )),
327 None => Ok(ValidationResult::fail(
328 &rule.name,
329 rule.level,
330 format!("Could not extract '{}' field from schema", config.field),
331 )),
332 }
333 }
334
335 fn validate_field_required(
337 &self,
338 rule: &ValidationRule,
339 schema_type: SchemaType,
340 schema: &str,
341 ) -> SchemaResult<ValidationResult> {
342 #[derive(Deserialize)]
343 struct Config {
344 field: String,
345 }
346
347 let config: Config = serde_json::from_str(&rule.config).map_err(|e| {
348 SchemaError::Validation(format!("Invalid field_required config: {}", e))
349 })?;
350
351 match schema_type {
352 SchemaType::Avro | SchemaType::Json => {
353 let parsed: serde_json::Value = serde_json::from_str(schema)
354 .map_err(|e| SchemaError::Validation(format!("Invalid JSON schema: {}", e)))?;
355
356 if has_field_recursive(&parsed, &config.field) {
357 Ok(ValidationResult::pass(&rule.name))
358 } else {
359 Ok(ValidationResult::fail(
360 &rule.name,
361 rule.level,
362 format!("Required field '{}' not found in schema", config.field),
363 ))
364 }
365 }
366 SchemaType::Protobuf => {
367 let field_pattern = format!(r"\b{}\b", regex::escape(&config.field));
369 let field_regex = regex::Regex::new(&field_pattern).map_err(|e| {
370 SchemaError::Validation(format!("Invalid field pattern: {}", e))
371 })?;
372 if field_regex.is_match(schema) {
373 Ok(ValidationResult::pass(&rule.name))
374 } else {
375 Ok(ValidationResult::fail(
376 &rule.name,
377 rule.level,
378 format!("Required field '{}' not found in schema", config.field),
379 ))
380 }
381 }
382 }
383 }
384
385 fn validate_field_type(
387 &self,
388 rule: &ValidationRule,
389 schema_type: SchemaType,
390 schema: &str,
391 ) -> SchemaResult<ValidationResult> {
392 #[derive(Deserialize)]
393 struct Config {
394 field: String,
395 #[serde(rename = "type")]
396 expected_type: String,
397 }
398
399 let config: Config = serde_json::from_str(&rule.config)
400 .map_err(|e| SchemaError::Validation(format!("Invalid field_type config: {}", e)))?;
401
402 match schema_type {
403 SchemaType::Avro => {
404 let parsed: serde_json::Value = serde_json::from_str(schema)
405 .map_err(|e| SchemaError::Validation(format!("Invalid Avro schema: {}", e)))?;
406
407 if let Some(field_type) = find_avro_field_type(&parsed, &config.field) {
408 if field_type == config.expected_type {
409 Ok(ValidationResult::pass(&rule.name))
410 } else {
411 Ok(ValidationResult::fail(
412 &rule.name,
413 rule.level,
414 format!(
415 "Field '{}' has type '{}', expected '{}'",
416 config.field, field_type, config.expected_type
417 ),
418 ))
419 }
420 } else {
421 Ok(ValidationResult::fail(
422 &rule.name,
423 rule.level,
424 format!("Field '{}' not found in schema", config.field),
425 ))
426 }
427 }
428 _ => {
429 Ok(ValidationResult::pass(&rule.name))
431 }
432 }
433 }
434
435 fn validate_regex(
437 &self,
438 rule: &ValidationRule,
439 schema: &str,
440 ) -> SchemaResult<ValidationResult> {
441 #[derive(Deserialize)]
442 struct Config {
443 pattern: String,
444 #[serde(default)]
445 must_match: bool,
446 }
447
448 let config: Config = serde_json::from_str(&rule.config)
449 .map_err(|e| SchemaError::Validation(format!("Invalid regex config: {}", e)))?;
450
451 let regex = self.get_or_compile_regex(&config.pattern)?;
453
454 let matches = regex.is_match(schema);
455 let expected = config.must_match;
456
457 if matches == expected {
458 Ok(ValidationResult::pass(&rule.name))
459 } else if expected {
460 Ok(ValidationResult::fail(
461 &rule.name,
462 rule.level,
463 format!(
464 "Schema does not match required pattern '{}'",
465 config.pattern
466 ),
467 ))
468 } else {
469 Ok(ValidationResult::fail(
470 &rule.name,
471 rule.level,
472 format!("Schema matches forbidden pattern '{}'", config.pattern),
473 ))
474 }
475 }
476
/// Validates the schema document against a JSON Schema taken from the rule
/// config.
///
/// With the `json-schema` feature enabled, the rule config must be a JSON
/// object with a `schema` member. The `schema` argument is parsed as JSON
/// and checked against it; on failure, up to three validator messages are
/// joined with "; " into the result message.
///
/// Without the feature nothing is checked: a warning is logged and a failed
/// result at `ValidationLevel::Warning` is returned (the rule's own level is
/// not used in that case).
///
/// # Errors
///
/// With the feature enabled, returns `SchemaError::Validation` when the rule
/// config does not deserialize, `schema` is not valid JSON, or the
/// configured JSON Schema cannot be compiled into a validator.
fn validate_json_schema(
    &self,
    rule: &ValidationRule,
    schema: &str,
) -> SchemaResult<ValidationResult> {
    #[cfg(feature = "json-schema")]
    {
        #[derive(Deserialize)]
        struct Config {
            schema: serde_json::Value,
        }

        let config: Config = serde_json::from_str(&rule.config).map_err(|e| {
            SchemaError::Validation(format!("Invalid json_schema config: {}", e))
        })?;

        // The document under test: the registered schema itself, as JSON.
        let instance: serde_json::Value = serde_json::from_str(schema)
            .map_err(|e| SchemaError::Validation(format!("Invalid JSON in schema: {}", e)))?;

        // The validator is built from the rule config on every call.
        let validator = jsonschema::JSONSchema::compile(&config.schema).map_err(|e| {
            SchemaError::Validation(format!("Invalid JSON Schema validator: {}", e))
        })?;

        if validator.is_valid(&instance) {
            Ok(ValidationResult::pass(&rule.name))
        } else {
            // Validate again only on failure, to collect at most three messages.
            let errors: Vec<String> = validator
                .validate(&instance)
                .err()
                .into_iter()
                .flatten()
                .map(|e| e.to_string())
                .take(3)
                .collect();

            Ok(ValidationResult::fail(
                &rule.name,
                rule.level,
                format!("JSON Schema validation failed: {}", errors.join("; ")),
            ))
        }
    }

    #[cfg(not(feature = "json-schema"))]
    {
        // Touch both parameters so neither is reported as unused in this build.
        let _ = (rule, schema);
        tracing::warn!("JSON Schema validation requires the 'json-schema' feature");
        Ok(ValidationResult::fail(
            &rule.name,
            ValidationLevel::Warning,
            "JSON Schema validation skipped: 'json-schema' feature not enabled",
        ))
    }
}
534}
535
/// Returns the name of the first `message` declared in a Protobuf schema.
///
/// This is a line-based heuristic, not a parser: it takes the first line
/// whose first whitespace-separated token is `message` and reads the
/// identifier that follows. Only the leading run of ASCII letters, digits and
/// `_` is kept, so `message Foo{` yields `Foo` rather than `Foo{`. Any
/// whitespace (including tabs) may separate the keyword from the name.
///
/// Returns `None` when no such line carries a usable identifier.
fn extract_protobuf_name(schema: &str) -> Option<String> {
    schema.lines().find_map(|line| {
        let mut tokens = line.split_whitespace();
        if tokens.next()? != "message" {
            return None;
        }

        let candidate = tokens.next()?;
        // Cut at the first character that cannot be part of an identifier,
        // e.g. a brace glued to the name.
        let end = candidate
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(candidate.len());
        let name = &candidate[..end];

        if name.is_empty() {
            None
        } else {
            Some(name.to_string())
        }
    })
}
550
551fn has_field_recursive(value: &serde_json::Value, field: &str) -> bool {
553 match value {
554 serde_json::Value::Object(map) => {
555 if map.contains_key(field) {
556 return true;
557 }
558 for v in map.values() {
559 if has_field_recursive(v, field) {
560 return true;
561 }
562 }
563 false
564 }
565 serde_json::Value::Array(arr) => arr.iter().any(|v| has_field_recursive(v, field)),
566 _ => false,
567 }
568}
569
570fn find_avro_field_type(schema: &serde_json::Value, field_name: &str) -> Option<String> {
572 if let Some(fields) = schema.get("fields").and_then(|f| f.as_array()) {
573 for field in fields {
574 if field.get("name").and_then(|n| n.as_str()) == Some(field_name) {
575 return field.get("type").map(|t| match t {
576 serde_json::Value::String(s) => s.clone(),
577 serde_json::Value::Object(o) => o
578 .get("type")
579 .and_then(|t| t.as_str())
580 .unwrap_or("complex")
581 .to_string(),
582 serde_json::Value::Array(_) => "union".to_string(),
583 _ => "unknown".to_string(),
584 });
585 }
586 }
587 }
588 None
589}
590
591pub mod presets {
593 use super::*;
594
595 pub fn max_size(max_bytes: usize) -> ValidationRule {
597 ValidationRule::new(
598 "max-schema-size",
599 ValidationRuleType::MaxSize,
600 format!(r#"{{"max_bytes": {}}}"#, max_bytes),
601 )
602 .with_description(format!("Schema must be smaller than {} bytes", max_bytes))
603 }
604
605 pub fn require_doc() -> ValidationRule {
607 ValidationRule::new(
608 "require-doc",
609 ValidationRuleType::FieldRequired,
610 r#"{"field": "doc"}"#,
611 )
612 .with_description("Schema must have a 'doc' field for documentation")
613 .with_schema_types(vec![SchemaType::Avro])
614 }
615
616 pub fn require_namespace() -> ValidationRule {
618 ValidationRule::new(
619 "require-namespace",
620 ValidationRuleType::FieldRequired,
621 r#"{"field": "namespace"}"#,
622 )
623 .with_description("Avro schema must have a namespace")
624 .with_schema_types(vec![SchemaType::Avro])
625 }
626
627 pub fn pascal_case_name() -> ValidationRule {
629 ValidationRule::new(
630 "pascal-case-name",
631 ValidationRuleType::NamingConvention,
632 r#"{"pattern": "^[A-Z][a-zA-Z0-9]*$", "field": "name"}"#,
633 )
634 .with_description("Schema name must be PascalCase")
635 .with_level(ValidationLevel::Warning)
636 }
637
638 pub fn forbid_pattern(name: &str, pattern: &str, description: &str) -> ValidationRule {
640 ValidationRule::new(
641 name,
642 ValidationRuleType::Regex,
643 format!(r#"{{"pattern": "{}", "must_match": false}}"#, pattern),
644 )
645 .with_description(description)
646 }
647
648 pub fn production_ruleset() -> Vec<ValidationRule> {
650 vec![
651 max_size(100 * 1024), require_doc(), require_namespace(), pascal_case_name(), ]
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine with the default configuration and the given global rules.
    fn engine_with(rules: Vec<ValidationRule>) -> ValidationEngine {
        let mut engine = ValidationEngine::new(ValidationEngineConfig::default());
        for rule in rules {
            engine.add_rule(rule);
        }
        engine
    }

    /// Validates `schema` as Avro under subject "test" and reports whether
    /// the resulting report is valid.
    fn avro_is_valid(engine: &ValidationEngine, schema: &str) -> bool {
        engine
            .validate(SchemaType::Avro, "test", schema)
            .unwrap()
            .is_valid()
    }

    #[test]
    fn test_validation_engine_creation() {
        let engine = engine_with(Vec::new());
        assert!(engine.rules().is_empty());
    }

    #[test]
    fn test_add_and_remove_rule() {
        let mut engine = engine_with(Vec::new());

        engine.add_rule(presets::max_size(1024));
        assert_eq!(engine.rules().len(), 1);

        engine.remove_rule("max-schema-size");
        assert!(engine.rules().is_empty());
    }

    #[test]
    fn test_max_size_validation() {
        let engine = engine_with(vec![presets::max_size(100)]);

        // Well under the 100-byte limit.
        assert!(avro_is_valid(&engine, r#"{"type":"string"}"#));

        // 200 bytes, over the limit.
        let oversized = "x".repeat(200);
        assert!(!avro_is_valid(&engine, &oversized));
    }

    #[test]
    fn test_field_required_validation() {
        let engine = engine_with(vec![presets::require_doc()]);

        let documented = r#"{"type":"record","name":"User","doc":"A user","fields":[]}"#;
        assert!(avro_is_valid(&engine, documented));

        let undocumented = r#"{"type":"record","name":"User","fields":[]}"#;
        assert!(!avro_is_valid(&engine, undocumented));
    }

    #[test]
    fn test_naming_convention_validation() {
        let pascal_rule = ValidationRule::new(
            "pascal-case",
            ValidationRuleType::NamingConvention,
            r#"{"pattern": "^[A-Z][a-zA-Z0-9]*$"}"#,
        )
        .with_level(ValidationLevel::Error);
        let engine = engine_with(vec![pascal_rule]);

        assert!(avro_is_valid(&engine, r#"{"name":"UserEvent"}"#));
        assert!(!avro_is_valid(&engine, r#"{"name":"userEvent"}"#));
    }

    #[test]
    fn test_regex_validation() {
        let engine = engine_with(vec![presets::forbid_pattern(
            "no-ssn",
            r"ssn|social.?security",
            "Schema must not contain SSN fields",
        )]);

        let clean = r#"{"type":"record","name":"User","fields":[{"name":"id","type":"long"}]}"#;
        assert!(avro_is_valid(&engine, clean));

        let with_ssn =
            r#"{"type":"record","name":"User","fields":[{"name":"ssn","type":"string"}]}"#;
        assert!(!avro_is_valid(&engine, with_ssn));
    }

    #[test]
    fn test_field_type_validation() {
        let engine = engine_with(vec![ValidationRule::new(
            "id-must-be-long",
            ValidationRuleType::FieldType,
            r#"{"field": "id", "type": "long"}"#,
        )]);

        let correct = r#"{"type":"record","name":"User","fields":[{"name":"id","type":"long"}]}"#;
        assert!(avro_is_valid(&engine, correct));

        let wrong = r#"{"type":"record","name":"User","fields":[{"name":"id","type":"int"}]}"#;
        assert!(!avro_is_valid(&engine, wrong));
    }

    #[test]
    fn test_subject_specific_rules() {
        let mut engine = engine_with(Vec::new());
        engine.add_subject_rule(
            "users-value",
            ValidationRule::new(
                "users-rule",
                ValidationRuleType::MaxSize,
                r#"{"max_bytes": 50}"#,
            ),
        );

        // The schema is shorter than 50 bytes, so both subjects validate
        // cleanly whether or not the subject rule is picked up.
        let schema = r#"{"type":"string"}"#;
        for subject in ["orders-value", "users-value"] {
            let report = engine.validate(SchemaType::Avro, subject, schema).unwrap();
            assert!(report.is_valid());
        }
    }

    #[test]
    fn test_fail_fast() {
        let config = ValidationEngineConfig::default().with_fail_fast(true);
        let mut engine = ValidationEngine::new(config);

        // Both limits are below the 3-byte schema, so each rule would fail.
        for (name, rule_config) in [
            ("rule1", r#"{"max_bytes": 1}"#),
            ("rule2", r#"{"max_bytes": 2}"#),
        ] {
            engine.add_rule(ValidationRule::new(
                name,
                ValidationRuleType::MaxSize,
                rule_config,
            ));
        }

        let report = engine.validate(SchemaType::Avro, "test", "xxx").unwrap();

        // Fail-fast stops after the first failure, leaving a single result.
        assert_eq!(report.results.len(), 1);
    }

    #[test]
    fn test_production_ruleset() {
        let mut engine = ValidationEngine::new(ValidationEngineConfig::default());
        engine.add_rules(presets::production_ruleset());

        let schema = r#"{
            "type": "record",
            "name": "UserCreated",
            "namespace": "com.example.events",
            "doc": "Event emitted when a new user is created",
            "fields": [
                {"name": "userId", "type": "long", "doc": "Unique user ID"}
            ]
        }"#;

        let report = engine
            .validate(SchemaType::Avro, "users-value", schema)
            .unwrap();
        assert!(report.is_valid(), "Errors: {:?}", report.error_messages());
    }

    #[test]
    fn test_protobuf_name_extraction() {
        let proto = r#"
            syntax = "proto3";
            message UserEvent {
                int64 id = 1;
            }
        "#;

        assert_eq!(
            extract_protobuf_name(proto),
            Some("UserEvent".to_string())
        );
    }
}