1use parking_lot::RwLock;
114use roxmltree::{Document as XmlDocument, Node};
115use std::collections::HashMap;
116use std::fmt;
117use std::fs;
118use std::path::{Path, PathBuf};
119use std::sync::Arc;
120
/// Errors produced while loading a schema or validating a document against it.
///
/// Variants that refer to a place in the input carry an optional line number;
/// the `Display` implementation appends it when it is present.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The XSD text could not be parsed or is structurally unusable.
    SchemaParseError {
        /// Description of what was wrong with the schema.
        message: String,
    },

    /// The XML document under validation is not well-formed.
    DocumentParseError {
        /// Description reported by the XML parser.
        message: String,
        /// Line of the problem, when known.
        line: Option<usize>,
        /// Column of the problem, when known.
        column: Option<usize>,
    },

    /// An element did not match what the schema expected at that point.
    ElementValidationError {
        /// Name of the offending element.
        element: String,
        /// Description of what was expected.
        expected: String,
        /// Description of what was found instead.
        found: String,
        /// Line of the element, when known.
        line: Option<usize>,
    },

    /// An attribute value failed validation.
    AttributeValidationError {
        /// Name of the element carrying the attribute.
        element: String,
        /// Name of the attribute.
        attribute: String,
        /// Description of the failure.
        message: String,
        /// Line of the element, when known.
        line: Option<usize>,
    },

    /// A text value does not conform to its declared type.
    TypeValidationError {
        /// Name of the element whose value was checked.
        name: String,
        /// Name of the type the value was expected to satisfy.
        expected_type: String,
        /// The value that was found.
        value: String,
        /// Line of the element, when known.
        line: Option<usize>,
    },

    /// An element occurs fewer or more times than the schema allows.
    CardinalityError {
        /// Name of the element that was counted.
        element: String,
        /// Minimum number of occurrences allowed.
        min: usize,
        /// Maximum number of occurrences allowed; `None` is displayed as
        /// "unbounded".
        max: Option<usize>,
        /// Number of occurrences actually found.
        actual: usize,
        /// Line of the parent element, when known.
        line: Option<usize>,
    },

    /// An attribute declared as required is absent.
    RequiredAttributeMissing {
        /// Name of the element that should carry the attribute.
        element: String,
        /// Name of the missing attribute.
        attribute: String,
        /// Line of the element, when known.
        line: Option<usize>,
    },

    /// An element that the schema does not declare was encountered.
    UnknownElement {
        /// Name of the undeclared element.
        element: String,
        /// Line of the element, when known.
        line: Option<usize>,
    },

    /// The schema file does not exist.
    SchemaNotFound {
        /// Path that was looked up.
        path: PathBuf,
    },

    /// Reading the schema file failed.
    IoError {
        /// Text of the underlying I/O error.
        message: String,
    },
}
220
221impl fmt::Display for ValidationError {
222 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223 match self {
224 ValidationError::SchemaParseError { message } => {
225 write!(f, "Schema parse error: {}", message)
226 }
227 ValidationError::DocumentParseError {
228 message,
229 line,
230 column,
231 } => {
232 write!(f, "Document parse error: {}", message)?;
233 if let Some(l) = line {
234 write!(f, " at line {}", l)?;
235 if let Some(c) = column {
236 write!(f, ", column {}", c)?;
237 }
238 }
239 Ok(())
240 }
241 ValidationError::ElementValidationError {
242 element,
243 expected,
244 found,
245 line,
246 } => {
247 write!(
248 f,
249 "Element validation failed for '{}': expected {}, found '{}'",
250 element, expected, found
251 )?;
252 if let Some(l) = line {
253 write!(f, " at line {}", l)?;
254 }
255 Ok(())
256 }
257 ValidationError::AttributeValidationError {
258 element,
259 attribute,
260 message,
261 line,
262 } => {
263 write!(
264 f,
265 "Attribute validation failed for '{}.{}': {}",
266 element, attribute, message
267 )?;
268 if let Some(l) = line {
269 write!(f, " at line {}", l)?;
270 }
271 Ok(())
272 }
273 ValidationError::TypeValidationError {
274 name,
275 expected_type,
276 value,
277 line,
278 } => {
279 write!(
280 f,
281 "Type validation failed for '{}': expected {}, found '{}'",
282 name, expected_type, value
283 )?;
284 if let Some(l) = line {
285 write!(f, " at line {}", l)?;
286 }
287 Ok(())
288 }
289 ValidationError::CardinalityError {
290 element,
291 min,
292 max,
293 actual,
294 line,
295 } => {
296 write!(
297 f,
298 "Cardinality error for '{}': expected {}..{}, found {}",
299 element,
300 min,
301 max.map_or("unbounded".to_string(), |m| m.to_string()),
302 actual
303 )?;
304 if let Some(l) = line {
305 write!(f, " at line {}", l)?;
306 }
307 Ok(())
308 }
309 ValidationError::RequiredAttributeMissing {
310 element,
311 attribute,
312 line,
313 } => {
314 write!(
315 f,
316 "Required attribute '{}' missing from element '{}'",
317 attribute, element
318 )?;
319 if let Some(l) = line {
320 write!(f, " at line {}", l)?;
321 }
322 Ok(())
323 }
324 ValidationError::UnknownElement { element, line } => {
325 write!(f, "Unknown element '{}' not defined in schema", element)?;
326 if let Some(l) = line {
327 write!(f, " at line {}", l)?;
328 }
329 Ok(())
330 }
331 ValidationError::SchemaNotFound { path } => {
332 write!(f, "Schema file not found: {}", path.display())
333 }
334 ValidationError::IoError { message } => {
335 write!(f, "I/O error: {}", message)
336 }
337 }
338 }
339}
340
341impl std::error::Error for ValidationError {}
342
343#[derive(Debug, Clone)]
345struct Schema {
346 elements: HashMap<String, ElementDef>,
347 #[allow(dead_code)]
348 target_namespace: Option<String>,
349}
350
351#[derive(Debug, Clone)]
353struct ElementDef {
354 name: String,
355 type_name: Option<String>,
356 complex_type: Option<ComplexType>,
357 min_occurs: usize,
358 max_occurs: Option<usize>,
359}
360
361#[derive(Debug, Clone)]
363struct ComplexType {
364 sequence: Vec<ElementDef>,
365 attributes: Vec<AttributeDef>,
366}
367
/// One `attribute` declaration taken from a complex type.
#[derive(Clone, Debug)]
struct AttributeDef {
    /// Value of the declaration's `name` attribute.
    name: String,
    /// Declared type name of the attribute's value.
    type_name: String,
    /// Whether the attribute must be present on the element.
    required: bool,
}
375
376#[derive(Debug, Clone)]
380pub struct SchemaValidator {
381 schema: Schema,
382}
383
384impl SchemaValidator {
385 pub fn from_xsd(xsd: &str) -> Result<Self, ValidationError> {
409 let schema = Self::parse_xsd(xsd)?;
410 Ok(Self { schema })
411 }
412
413 fn parse_xsd(xsd: &str) -> Result<Schema, ValidationError> {
415 let doc = XmlDocument::parse(xsd).map_err(|e| ValidationError::SchemaParseError {
416 message: e.to_string(),
417 })?;
418
419 let root = doc.root_element();
420
421 if root.tag_name().name() != "schema" {
423 return Err(ValidationError::SchemaParseError {
424 message: "Root element must be <xs:schema>".to_string(),
425 });
426 }
427
428 let target_namespace = root.attribute("targetNamespace").map(|s| s.to_string());
429 let mut elements = HashMap::new();
430
431 for child in root.children().filter(|n| n.is_element()) {
433 if child.tag_name().name() == "element" {
434 let elem_def = Self::parse_element(&child)?;
435 elements.insert(elem_def.name.clone(), elem_def);
436 }
437 }
438
439 Ok(Schema {
440 elements,
441 target_namespace,
442 })
443 }
444
445 fn parse_element(node: &Node) -> Result<ElementDef, ValidationError> {
447 let name = node
448 .attribute("name")
449 .ok_or_else(|| ValidationError::SchemaParseError {
450 message: "Element must have 'name' attribute".to_string(),
451 })?
452 .to_string();
453
454 let type_name = node.attribute("type").map(|s| s.to_string());
455 let min_occurs = node
456 .attribute("minOccurs")
457 .and_then(|s| s.parse::<usize>().ok())
458 .unwrap_or(1);
459 let max_occurs = node.attribute("maxOccurs").and_then(|s| {
460 if s == "unbounded" {
461 None
462 } else {
463 s.parse::<usize>().ok()
464 }
465 });
466
467 let mut complex_type = None;
469 for child in node.children().filter(|n| n.is_element()) {
470 if child.tag_name().name() == "complexType" {
471 complex_type = Some(Self::parse_complex_type(&child)?);
472 break;
473 }
474 }
475
476 Ok(ElementDef {
477 name,
478 type_name,
479 complex_type,
480 min_occurs,
481 max_occurs,
482 })
483 }
484
485 fn parse_complex_type(node: &Node) -> Result<ComplexType, ValidationError> {
487 let mut sequence = Vec::new();
488 let mut attributes = Vec::new();
489
490 for child in node.children().filter(|n| n.is_element()) {
491 match child.tag_name().name() {
492 "sequence" => {
493 for elem_node in child.children().filter(|n| n.is_element()) {
494 if elem_node.tag_name().name() == "element" {
495 sequence.push(Self::parse_element(&elem_node)?);
496 }
497 }
498 }
499 "attribute" => {
500 attributes.push(Self::parse_attribute(&child)?);
501 }
502 _ => {}
503 }
504 }
505
506 Ok(ComplexType {
507 sequence,
508 attributes,
509 })
510 }
511
512 fn parse_attribute(node: &Node) -> Result<AttributeDef, ValidationError> {
514 let name = node
515 .attribute("name")
516 .ok_or_else(|| ValidationError::SchemaParseError {
517 message: "Attribute must have 'name' attribute".to_string(),
518 })?
519 .to_string();
520
521 let type_name = node
522 .attribute("type")
523 .unwrap_or("xs:string")
524 .to_string();
525
526 let required = node.attribute("use") == Some("required");
527
528 Ok(AttributeDef {
529 name,
530 type_name,
531 required,
532 })
533 }
534
535 pub fn from_file(path: &Path) -> Result<Self, ValidationError> {
557 if !path.exists() {
558 return Err(ValidationError::SchemaNotFound {
559 path: path.to_path_buf(),
560 });
561 }
562
563 let content = fs::read_to_string(path).map_err(|e| ValidationError::IoError {
564 message: e.to_string(),
565 })?;
566
567 Self::from_xsd(&content)
568 }
569
570 pub fn validate(&self, xml: &str) -> Result<(), ValidationError> {
597 let doc = XmlDocument::parse(xml).map_err(|e| ValidationError::DocumentParseError {
598 message: e.to_string(),
599 line: None,
600 column: None,
601 })?;
602
603 let root = doc.root_element();
604 let root_name = root.tag_name().name();
605
606 let schema_elem = self
608 .schema
609 .elements
610 .get(root_name)
611 .ok_or_else(|| ValidationError::UnknownElement {
612 element: root_name.to_string(),
613 line: Some(doc.text_pos_at(root.range().start).row as usize),
614 })?;
615
616 self.validate_element(&root, schema_elem)?;
617
618 Ok(())
619 }
620
621 fn validate_element(
623 &self,
624 node: &Node,
625 schema_elem: &ElementDef,
626 ) -> Result<(), ValidationError> {
627 let line = node.document().text_pos_at(node.range().start).row as usize;
628
629 if let Some(ref type_name) = schema_elem.type_name {
631 self.validate_type(node, type_name, line)?;
632 }
633
634 if let Some(ref complex_type) = schema_elem.complex_type {
636 self.validate_attributes_complex(node, complex_type, line)?;
638
639 self.validate_children_complex(node, complex_type, line)?;
641 }
642
643 Ok(())
644 }
645
646 fn validate_type(
648 &self,
649 node: &Node,
650 type_ref: &str,
651 line: usize,
652 ) -> Result<(), ValidationError> {
653 let text = node.text().unwrap_or("");
654
655 match type_ref {
657 "xs:string" | "string" => {
658 }
660 "xs:integer" | "integer" => {
661 if text.parse::<i64>().is_err() {
662 return Err(ValidationError::TypeValidationError {
663 name: node.tag_name().name().to_string(),
664 expected_type: "xs:integer".to_string(),
665 value: text.to_string(),
666 line: Some(line),
667 });
668 }
669 }
670 "xs:decimal" | "decimal" => {
671 if text.parse::<f64>().is_err() {
672 return Err(ValidationError::TypeValidationError {
673 name: node.tag_name().name().to_string(),
674 expected_type: "xs:decimal".to_string(),
675 value: text.to_string(),
676 line: Some(line),
677 });
678 }
679 }
680 "xs:boolean" | "boolean" => {
681 if !["true", "false", "1", "0"].contains(&text) {
682 return Err(ValidationError::TypeValidationError {
683 name: node.tag_name().name().to_string(),
684 expected_type: "xs:boolean".to_string(),
685 value: text.to_string(),
686 line: Some(line),
687 });
688 }
689 }
690 _ => {
691 }
693 }
694
695 Ok(())
696 }
697
698 fn validate_attributes_complex(
700 &self,
701 node: &Node,
702 complex_type: &ComplexType,
703 line: usize,
704 ) -> Result<(), ValidationError> {
705 let element_name = node.tag_name().name();
706
707 for attr_def in &complex_type.attributes {
709 if attr_def.required && node.attribute(attr_def.name.as_str()).is_none() {
710 return Err(ValidationError::RequiredAttributeMissing {
711 element: element_name.to_string(),
712 attribute: attr_def.name.clone(),
713 line: Some(line),
714 });
715 }
716
717 if let Some(value) = node.attribute(attr_def.name.as_str()) {
719 self.validate_simple_type(value, &attr_def.type_name).map_err(|_| {
720 ValidationError::AttributeValidationError {
721 element: element_name.to_string(),
722 attribute: attr_def.name.clone(),
723 message: format!(
724 "Expected type {}, found '{}'",
725 attr_def.type_name, value
726 ),
727 line: Some(line),
728 }
729 })?;
730 }
731 }
732
733 Ok(())
734 }
735
736 fn validate_children_complex(
738 &self,
739 node: &Node,
740 complex_type: &ComplexType,
741 line: usize,
742 ) -> Result<(), ValidationError> {
743 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
744
745 for child in &children {
747 let child_name = child.tag_name().name();
748
749 let schema_elem = complex_type
751 .sequence
752 .iter()
753 .find(|e| e.name == child_name)
754 .ok_or_else(|| ValidationError::UnknownElement {
755 element: child_name.to_string(),
756 line: Some(child.document().text_pos_at(child.range().start).row as usize),
757 })?;
758
759 self.validate_element(child, schema_elem)?;
760 }
761
762 for elem_def in &complex_type.sequence {
764 let count = children
765 .iter()
766 .filter(|n| n.tag_name().name() == elem_def.name)
767 .count();
768
769 if count < elem_def.min_occurs {
770 return Err(ValidationError::CardinalityError {
771 element: elem_def.name.clone(),
772 min: elem_def.min_occurs,
773 max: elem_def.max_occurs,
774 actual: count,
775 line: Some(line),
776 });
777 }
778
779 if let Some(max) = elem_def.max_occurs {
780 if count > max {
781 return Err(ValidationError::CardinalityError {
782 element: elem_def.name.clone(),
783 min: elem_def.min_occurs,
784 max: elem_def.max_occurs,
785 actual: count,
786 line: Some(line),
787 });
788 }
789 }
790 }
791
792 Ok(())
793 }
794
795 fn validate_simple_type(&self, value: &str, type_name: &str) -> Result<(), ()> {
797 match type_name {
798 "xs:string" | "string" => Ok(()),
799 "xs:integer" | "integer" => value.parse::<i64>().map(|_| ()).map_err(|_| ()),
800 "xs:decimal" | "decimal" => value.parse::<f64>().map(|_| ()).map_err(|_| ()),
801 "xs:boolean" | "boolean" => {
802 if ["true", "false", "1", "0"].contains(&value) {
803 Ok(())
804 } else {
805 Err(())
806 }
807 }
808 _ => Ok(()), }
810 }
811}
812
813pub struct SchemaCache {
833 cache: Arc<RwLock<HashMap<PathBuf, Arc<SchemaValidator>>>>,
834 max_size: usize,
835}
836
837impl SchemaCache {
838 pub fn new(max_size: usize) -> Self {
852 Self {
853 cache: Arc::new(RwLock::new(HashMap::new())),
854 max_size,
855 }
856 }
857
858 pub fn get_or_load(&self, path: &Path) -> Result<Arc<SchemaValidator>, ValidationError> {
882 {
884 let cache = self.cache.read();
885 if let Some(validator) = cache.get(path) {
886 return Ok(Arc::clone(validator));
887 }
888 }
889
890 let mut cache = self.cache.write();
892
893 if let Some(validator) = cache.get(path) {
895 return Ok(Arc::clone(validator));
896 }
897
898 let validator = Arc::new(SchemaValidator::from_file(path)?);
900
901 if cache.len() >= self.max_size {
903 if let Some(oldest_key) = cache.keys().next().cloned() {
904 cache.remove(&oldest_key);
905 }
906 }
907
908 cache.insert(path.to_path_buf(), Arc::clone(&validator));
909
910 Ok(validator)
911 }
912
913 pub fn clear(&self) {
924 self.cache.write().clear();
925 }
926
927 pub fn size(&self) -> usize {
938 self.cache.read().len()
939 }
940}
941
942impl Default for SchemaCache {
943 fn default() -> Self {
945 Self::new(100)
946 }
947}
948
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// A `person` element with one string child and one integer child.
    const SIMPLE_SCHEMA: &str = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
    <xs:element name="person">
        <xs:complexType>
            <xs:sequence>
                <xs:element name="name" type="xs:string"/>
                <xs:element name="age" type="xs:integer"/>
            </xs:sequence>
        </xs:complexType>
    </xs:element>
</xs:schema>"#;

    /// Writes `SIMPLE_SCHEMA` to a fresh temporary file. The file is deleted
    /// when the returned handle is dropped, so callers must keep it alive.
    fn schema_file() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(SIMPLE_SCHEMA.as_bytes()).unwrap();
        file
    }

    #[test]
    fn test_schema_validator_creation() {
        assert!(SchemaValidator::from_xsd(SIMPLE_SCHEMA).is_ok());
    }

    #[test]
    fn test_valid_document() {
        let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();

        let xml = r#"<?xml version="1.0"?>
<person>
    <name>Alice</name>
    <age>30</age>
</person>"#;

        assert!(validator.validate(xml).is_ok());
    }

    #[test]
    fn test_invalid_type() {
        let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();

        let xml = r#"<?xml version="1.0"?>
<person>
    <name>Alice</name>
    <age>thirty</age>
</person>"#;

        match validator.validate(xml) {
            Err(ValidationError::TypeValidationError {
                name,
                expected_type,
                value,
                ..
            }) => {
                assert_eq!(name, "age");
                assert_eq!(expected_type, "xs:integer");
                assert_eq!(value, "thirty");
            }
            other => panic!("Expected TypeValidationError, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_element() {
        let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();

        let xml = r#"<?xml version="1.0"?>
<person>
    <name>Alice</name>
    <age>30</age>
    <email>alice@example.com</email>
</person>"#;

        match validator.validate(xml) {
            Err(ValidationError::UnknownElement { element, .. }) => {
                assert_eq!(element, "email");
            }
            other => panic!("Expected UnknownElement error, got {:?}", other),
        }
    }

    #[test]
    fn test_missing_required_child() {
        let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();

        let xml = r#"<?xml version="1.0"?>
<person>
    <name>Alice</name>
</person>"#;

        match validator.validate(xml) {
            Err(ValidationError::CardinalityError {
                element,
                min,
                actual,
                ..
            }) => {
                assert_eq!(element, "age");
                assert_eq!(min, 1);
                assert_eq!(actual, 0);
            }
            other => panic!("Expected CardinalityError, got {:?}", other),
        }
    }

    #[test]
    fn test_attribute_validation() {
        let schema = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
    <xs:element name="item">
        <xs:complexType>
            <xs:attribute name="id" type="xs:integer" use="required"/>
        </xs:complexType>
    </xs:element>
</xs:schema>"#;

        let validator = SchemaValidator::from_xsd(schema).unwrap();

        assert!(validator
            .validate(r#"<?xml version="1.0"?><item id="7"/>"#)
            .is_ok());

        assert!(matches!(
            validator.validate(r#"<?xml version="1.0"?><item/>"#),
            Err(ValidationError::RequiredAttributeMissing { .. })
        ));

        assert!(matches!(
            validator.validate(r#"<?xml version="1.0"?><item id="abc"/>"#),
            Err(ValidationError::AttributeValidationError { .. })
        ));
    }

    #[test]
    fn test_malformed_xml() {
        let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();

        // <name> is never closed.
        let xml = r#"<?xml version="1.0"?>
<person>
    <name>Alice
    <age>30</age>
</person>"#;

        assert!(matches!(
            validator.validate(xml),
            Err(ValidationError::DocumentParseError { .. })
        ));
    }

    #[test]
    fn test_schema_cache() {
        let cache = SchemaCache::new(5);
        assert_eq!(cache.size(), 0);

        let temp_file = schema_file();
        let path = temp_file.path();

        // First lookup loads the schema.
        let validator1 = cache.get_or_load(path).unwrap();
        assert_eq!(cache.size(), 1);

        // Second lookup is served from the cache.
        let validator2 = cache.get_or_load(path).unwrap();
        assert_eq!(cache.size(), 1);

        assert!(Arc::ptr_eq(&validator1, &validator2));

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = SchemaCache::new(2);

        let files: Vec<NamedTempFile> = (0..3).map(|_| schema_file()).collect();

        cache.get_or_load(files[0].path()).unwrap();
        cache.get_or_load(files[1].path()).unwrap();
        assert_eq!(cache.size(), 2);

        // A third schema must displace one of the first two.
        cache.get_or_load(files[2].path()).unwrap();
        assert_eq!(cache.size(), 2);
    }

    #[test]
    fn test_error_display() {
        let err = ValidationError::TypeValidationError {
            name: "age".to_string(),
            expected_type: "xs:integer".to_string(),
            value: "thirty".to_string(),
            line: Some(5),
        };

        let display = err.to_string();
        for fragment in &["age", "xs:integer", "thirty", "line 5"] {
            assert!(
                display.contains(fragment),
                "'{}' missing from '{}'",
                fragment,
                display
            );
        }
    }

    #[test]
    fn test_schema_not_found() {
        let result = SchemaValidator::from_file(Path::new("/nonexistent/schema.xsd"));
        assert!(matches!(
            result,
            Err(ValidationError::SchemaNotFound { .. })
        ));
    }

    #[test]
    fn test_invalid_schema() {
        // Not well-formed XML.
        assert!(matches!(
            SchemaValidator::from_xsd("<xs:schema"),
            Err(ValidationError::SchemaParseError { .. })
        ));

        // Well-formed, but the root element is not a schema.
        assert!(matches!(
            SchemaValidator::from_xsd(r#"<?xml version="1.0"?><root/>"#),
            Err(ValidationError::SchemaParseError { .. })
        ));

        // An element declaration without a name.
        let nameless = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
    <xs:element type="xs:string"/>
</xs:schema>"#;
        assert!(matches!(
            SchemaValidator::from_xsd(nameless),
            Err(ValidationError::SchemaParseError { .. })
        ));

        // A reference to an undeclared type: loading must not panic. Whether
        // such a schema is accepted is deliberately not pinned here.
        let unknown_type = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
    <xs:element name="broken" type="nonexistent:type"/>
</xs:schema>"#;
        let _ = SchemaValidator::from_xsd(unknown_type);
    }

    #[test]
    fn test_boolean_type_validation() {
        let schema = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
    <xs:element name="flag" type="xs:boolean"/>
</xs:schema>"#;

        let validator = SchemaValidator::from_xsd(schema).unwrap();

        for val in &["true", "false", "1", "0"] {
            let xml = format!(r#"<?xml version="1.0"?><flag>{}</flag>"#, val);
            assert!(validator.validate(&xml).is_ok(), "'{}' was rejected", val);
        }

        let xml = r#"<?xml version="1.0"?><flag>yes</flag>"#;
        assert!(validator.validate(xml).is_err());
    }

    #[test]
    fn test_decimal_type_validation() {
        let schema = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
    <xs:element name="price" type="xs:decimal"/>
</xs:schema>"#;

        let validator = SchemaValidator::from_xsd(schema).unwrap();

        let xml = r#"<?xml version="1.0"?><price>19.99</price>"#;
        assert!(validator.validate(xml).is_ok());

        let xml = r#"<?xml version="1.0"?><price>not a number</price>"#;
        assert!(validator.validate(xml).is_err());
    }

    #[test]
    fn test_concurrent_cache_access() {
        use std::thread;

        let cache = Arc::new(SchemaCache::new(10));

        let temp_file = schema_file();
        let path = temp_file.path().to_path_buf();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let cache = Arc::clone(&cache);
                let path = path.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        let _validator = cache.get_or_load(&path).unwrap();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        // Every thread asked for the same path, so exactly one entry exists.
        assert_eq!(cache.size(), 1);
    }
}