use parking_lot::RwLock;
use roxmltree::{Document as XmlDocument, Node, ParsingOptions};
use std::collections::HashMap;
use std::fmt;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Errors reported while loading a schema or validating a document against it.
///
/// Variants with a `line` field carry the source row of the offending node when the
/// reporting code supplies one.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationError {
/// The schema text could not be parsed, or lacks something the loader requires.
SchemaParseError {
/// Human-readable description of the problem.
message: String,
},
/// The XML document under validation could not be parsed.
DocumentParseError {
/// Human-readable description of the problem.
message: String,
/// Row of the failure, when known.
line: Option<usize>,
/// Column of the failure, when known.
column: Option<usize>,
},
/// An element did not match what the schema expected at that point.
ElementValidationError {
/// Name of the offending element.
element: String,
/// Description of what was expected.
expected: String,
/// Description of what was found instead.
found: String,
/// Row of the element, when known.
line: Option<usize>,
},
/// An attribute value was rejected.
AttributeValidationError {
/// Name of the element carrying the attribute.
element: String,
/// Name of the offending attribute.
attribute: String,
/// Human-readable description of the problem.
message: String,
/// Row of the element, when known.
line: Option<usize>,
},
/// Element text did not conform to the declared simple type.
TypeValidationError {
/// Name of the element whose content was checked.
name: String,
/// Name of the type the content was checked against.
expected_type: String,
/// The text that failed the check.
value: String,
/// Row of the element, when known.
line: Option<usize>,
},
/// A child element occurred fewer or more times than the schema allows.
CardinalityError {
/// Name of the child element that was counted.
element: String,
/// Minimum number of occurrences allowed.
min: usize,
/// Maximum number of occurrences allowed; `None` is displayed as "unbounded".
max: Option<usize>,
/// Number of occurrences actually found.
actual: usize,
/// Row of the parent element, when known.
line: Option<usize>,
},
/// An attribute declared as required is absent.
RequiredAttributeMissing {
/// Name of the element that should carry the attribute.
element: String,
/// Name of the missing attribute.
attribute: String,
/// Row of the element, when known.
line: Option<usize>,
},
/// An element appears in the document but has no declaration in the schema.
UnknownElement {
/// Name of the undeclared element.
element: String,
/// Row of the element, when known.
line: Option<usize>,
},
/// The schema file does not exist.
SchemaNotFound {
/// Path that was requested.
path: PathBuf,
},
/// Reading the schema file failed.
IoError {
/// Text of the underlying I/O error.
message: String,
},
}
impl fmt::Display for ValidationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ValidationError::SchemaParseError { message } => {
write!(f, "Schema parse error: {}", message)
}
ValidationError::DocumentParseError {
message,
line,
column,
} => {
write!(f, "Document parse error: {}", message)?;
if let Some(l) = line {
write!(f, " at line {}", l)?;
if let Some(c) = column {
write!(f, ", column {}", c)?;
}
}
Ok(())
}
ValidationError::ElementValidationError {
element,
expected,
found,
line,
} => {
write!(
f,
"Element validation failed for '{}': expected {}, found '{}'",
element, expected, found
)?;
if let Some(l) = line {
write!(f, " at line {}", l)?;
}
Ok(())
}
ValidationError::AttributeValidationError {
element,
attribute,
message,
line,
} => {
write!(
f,
"Attribute validation failed for '{}.{}': {}",
element, attribute, message
)?;
if let Some(l) = line {
write!(f, " at line {}", l)?;
}
Ok(())
}
ValidationError::TypeValidationError {
name,
expected_type,
value,
line,
} => {
write!(
f,
"Type validation failed for '{}': expected {}, found '{}'",
name, expected_type, value
)?;
if let Some(l) = line {
write!(f, " at line {}", l)?;
}
Ok(())
}
ValidationError::CardinalityError {
element,
min,
max,
actual,
line,
} => {
write!(
f,
"Cardinality error for '{}': expected {}..{}, found {}",
element,
min,
max.map_or("unbounded".to_string(), |m| m.to_string()),
actual
)?;
if let Some(l) = line {
write!(f, " at line {}", l)?;
}
Ok(())
}
ValidationError::RequiredAttributeMissing {
element,
attribute,
line,
} => {
write!(
f,
"Required attribute '{}' missing from element '{}'",
attribute, element
)?;
if let Some(l) = line {
write!(f, " at line {}", l)?;
}
Ok(())
}
ValidationError::UnknownElement { element, line } => {
write!(f, "Unknown element '{}' not defined in schema", element)?;
if let Some(l) = line {
write!(f, " at line {}", l)?;
}
Ok(())
}
ValidationError::SchemaNotFound { path } => {
write!(f, "Schema file not found: {}", path.display())
}
ValidationError::IoError { message } => {
write!(f, "I/O error: {}", message)
}
}
}
}
// Empty impl: the trait's default methods are used, so no error source is exposed.
impl std::error::Error for ValidationError {}
/// In-memory form of a parsed schema.
#[derive(Debug, Clone)]
struct Schema {
/// Top-level element declarations, keyed by element name.
elements: HashMap<String, ElementDef>,
}
/// One element declaration read from the schema.
#[derive(Debug, Clone)]
struct ElementDef {
/// Value of the declaration's `name` attribute.
name: String,
/// Value of the declaration's `type` attribute, if it has one.
type_name: Option<String>,
/// Inline `complexType` child of the declaration, if it has one.
complex_type: Option<ComplexType>,
/// Minimum number of occurrences required inside the parent.
min_occurs: usize,
/// Maximum number of occurrences allowed; `None` means no upper bound is enforced.
max_occurs: Option<usize>,
}
/// Content model of an element: its permitted children and attributes.
#[derive(Debug, Clone)]
struct ComplexType {
/// Element declarations collected from `sequence` children of the complex type.
sequence: Vec<ElementDef>,
/// Attribute declarations found directly under the complex type.
attributes: Vec<AttributeDef>,
}
/// One attribute declaration read from the schema.
#[derive(Debug, Clone)]
struct AttributeDef {
/// Value of the declaration's `name` attribute.
name: String,
/// Value of the declaration's `type` attribute; `xs:string` when the schema omits it.
type_name: String,
/// `true` when the declaration says `use="required"`.
required: bool,
}
/// Validates XML documents against a schema loaded from XSD text or an XSD file.
#[derive(Debug, Clone)]
pub struct SchemaValidator {
/// The parsed schema that documents are checked against.
schema: Schema,
}
impl SchemaValidator {
/// Builds a validator from XSD source text.
///
/// # Errors
/// Propagates any error returned by `parse_xsd`.
pub fn from_xsd(xsd: &str) -> Result<Self, ValidationError> {
    Self::parse_xsd(xsd).map(|schema| Self { schema })
}
/// Parses XSD text into a `Schema` holding its top-level element declarations.
///
/// DTD processing is disabled. Only direct `element` children of the root are read; a later
/// declaration with the same name replaces an earlier one.
///
/// # Errors
/// Returns `ValidationError::SchemaParseError` when the text is not well-formed XML, when the
/// root element's local name is not `schema`, or when an element declaration is rejected.
fn parse_xsd(xsd: &str) -> Result<Schema, ValidationError> {
    let options = ParsingOptions {
        allow_dtd: false,
        ..Default::default()
    };
    let parsed = match XmlDocument::parse_with_options(xsd, options) {
        Ok(document) => document,
        Err(e) => {
            return Err(ValidationError::SchemaParseError {
                message: e.to_string(),
            });
        }
    };

    let schema_root = parsed.root_element();
    if schema_root.tag_name().name() != "schema" {
        return Err(ValidationError::SchemaParseError {
            message: "Root element must be <xs:schema>".to_string(),
        });
    }

    let declarations = schema_root
        .children()
        .filter(|n| n.is_element() && n.tag_name().name() == "element");

    let mut elements = HashMap::new();
    for declaration in declarations {
        let definition = Self::parse_element(&declaration)?;
        elements.insert(definition.name.clone(), definition);
    }
    Ok(Schema { elements })
}
/// Reads one `element` declaration, including an inline `complexType` if present.
///
/// Occurrence bounds follow the XSD defaults: `minOccurs` and `maxOccurs` are both 1 when
/// absent, and `maxOccurs="unbounded"` removes the upper bound (stored as `None`).
///
/// # Errors
/// Returns `ValidationError::SchemaParseError` when the declaration has no `name`, when
/// `minOccurs`/`maxOccurs` is not a non-negative integer (or `unbounded` for `maxOccurs`),
/// when the resulting maximum is below the minimum, or when the nested complex type is
/// rejected.
fn parse_element(node: &Node<'_, '_>) -> Result<ElementDef, ValidationError> {
    let name = node
        .attribute("name")
        .ok_or_else(|| ValidationError::SchemaParseError {
            message: "Element must have 'name' attribute".to_string(),
        })?
        .to_string();
    let type_name = node.attribute("type").map(str::to_string);

    // A malformed bound is reported rather than silently replaced by a default, so a typo in
    // the schema cannot quietly loosen validation.
    let min_occurs = match node.attribute("minOccurs") {
        None => 1,
        Some(raw) => {
            raw.trim()
                .parse::<usize>()
                .map_err(|_| ValidationError::SchemaParseError {
                    message: format!("Element '{}' has invalid minOccurs '{}'", name, raw),
                })?
        }
    };
    let max_occurs = match node.attribute("maxOccurs").map(str::trim) {
        // XSD default: an element may appear at most once unless stated otherwise.
        None => Some(1),
        Some("unbounded") => None,
        Some(raw) => Some(raw.parse::<usize>().map_err(|_| {
            ValidationError::SchemaParseError {
                message: format!("Element '{}' has invalid maxOccurs '{}'", name, raw),
            }
        })?),
    };
    if let Some(max) = max_occurs {
        if max < min_occurs {
            return Err(ValidationError::SchemaParseError {
                message: format!(
                    "Element '{}' has maxOccurs ({}) less than minOccurs ({})",
                    name, max, min_occurs
                ),
            });
        }
    }

    // Only the first inline complexType is considered.
    let complex_type = node
        .children()
        .filter(|n| n.is_element())
        .find(|n| n.tag_name().name() == "complexType")
        .map(|n| Self::parse_complex_type(&n))
        .transpose()?;

    Ok(ElementDef {
        name,
        type_name,
        complex_type,
        min_occurs,
        max_occurs,
    })
}
/// Reads a `complexType` node into its child-element and attribute declarations.
///
/// `element` children of every direct `sequence` child are appended in document order, and
/// direct `attribute` children are collected. Any other child is ignored.
///
/// # Errors
/// Propagates errors from `parse_element` and `parse_attribute`.
fn parse_complex_type(node: &Node<'_, '_>) -> Result<ComplexType, ValidationError> {
    let mut members = Vec::new();
    let mut attrs = Vec::new();

    for part in node.children().filter(|n| n.is_element()) {
        let tag = part.tag_name().name();
        if tag == "sequence" {
            let declared = part
                .children()
                .filter(|n| n.is_element() && n.tag_name().name() == "element");
            for item in declared {
                members.push(Self::parse_element(&item)?);
            }
        } else if tag == "attribute" {
            attrs.push(Self::parse_attribute(&part)?);
        }
    }

    Ok(ComplexType {
        sequence: members,
        attributes: attrs,
    })
}
/// Reads one `attribute` declaration.
///
/// The type defaults to `xs:string` when the declaration has no `type`, and the attribute is
/// required only when the declaration says `use="required"`.
///
/// # Errors
/// Returns `ValidationError::SchemaParseError` when the declaration has no `name`.
fn parse_attribute(node: &Node<'_, '_>) -> Result<AttributeDef, ValidationError> {
    let name = match node.attribute("name") {
        Some(declared) => declared.to_string(),
        None => {
            return Err(ValidationError::SchemaParseError {
                message: "Attribute must have 'name' attribute".to_string(),
            });
        }
    };
    let type_name = match node.attribute("type") {
        Some(declared) => declared.to_string(),
        None => "xs:string".to_string(),
    };
    let required = matches!(node.attribute("use"), Some("required"));

    Ok(AttributeDef {
        name,
        type_name,
        required,
    })
}
/// Builds a validator from an XSD file on disk.
///
/// The file is read in one step and a missing file is recognised from the read error itself,
/// so there is no window between an existence check and the read.
///
/// # Errors
/// - `ValidationError::SchemaNotFound` when the read fails because the file does not exist.
/// - `ValidationError::IoError` for any other read failure.
/// - Any error returned by `from_xsd` for the file's contents.
pub fn from_file(path: &Path) -> Result<Self, ValidationError> {
    let content = fs::read_to_string(path).map_err(|e| {
        if e.kind() == std::io::ErrorKind::NotFound {
            ValidationError::SchemaNotFound {
                path: path.to_path_buf(),
            }
        } else {
            ValidationError::IoError {
                message: e.to_string(),
            }
        }
    })?;
    Self::from_xsd(&content)
}
/// Validates an XML document against the loaded schema.
///
/// The document is parsed with DTD processing disabled, its root element is looked up among
/// the schema's top-level element declarations, and validation recurses from there.
///
/// # Errors
/// - `ValidationError::DocumentParseError` when `xml` cannot be parsed; `line` and `column`
///   carry the position reported by the parser.
/// - `ValidationError::UnknownElement` when the root element is not declared in the schema.
/// - Any error returned by `validate_element`.
pub fn validate(&self, xml: &str) -> Result<(), ValidationError> {
    let options = ParsingOptions {
        allow_dtd: false,
        ..Default::default()
    };
    let doc = XmlDocument::parse_with_options(xml, options).map_err(|e| {
        // Expose the parser's position as structured data instead of leaving callers to
        // scrape it out of the message text.
        let pos = e.pos();
        ValidationError::DocumentParseError {
            message: e.to_string(),
            line: Some(pos.row as usize),
            column: Some(pos.col as usize),
        }
    })?;

    let root = doc.root_element();
    let root_name = root.tag_name().name();
    let schema_elem =
        self.schema
            .elements
            .get(root_name)
            .ok_or_else(|| ValidationError::UnknownElement {
                element: root_name.to_string(),
                line: Some(doc.text_pos_at(root.range().start).row as usize),
            })?;

    self.validate_element(&root, schema_elem)
}
/// Validates one element against its schema declaration.
///
/// The element's text is checked first when the declaration names a type; then, when the
/// declaration has an inline complex type, attributes are checked before children.
///
/// # Errors
/// Propagates the first error from `validate_type`, `validate_attributes_complex` or
/// `validate_children_complex`.
fn validate_element(
    &self,
    node: &Node<'_, '_>,
    schema_elem: &ElementDef,
) -> Result<(), ValidationError> {
    let start = node.range().start;
    let line = node.document().text_pos_at(start).row as usize;

    if let Some(type_name) = schema_elem.type_name.as_deref() {
        self.validate_type(node, type_name, line)?;
    }

    match schema_elem.complex_type.as_ref() {
        Some(complex_type) => {
            self.validate_attributes_complex(node, complex_type, line)?;
            self.validate_children_complex(node, complex_type, line)
        }
        None => Ok(()),
    }
}
/// Checks an element's text content against a built-in simple type.
///
/// Integer, decimal and boolean content is checked with the same rules as attribute values
/// (via `validate_simple_type`). Leading and trailing XML whitespace is ignored, so values
/// that sit on their own line in pretty-printed documents are accepted. String types and
/// type names this validator does not recognise are not checked.
///
/// # Errors
/// Returns `ValidationError::TypeValidationError` naming the element, the canonical `xs:`
/// type name and the original (untrimmed) text when the content does not conform.
fn validate_type(
    &self,
    node: &Node<'_, '_>,
    type_ref: &str,
    line: usize,
) -> Result<(), ValidationError> {
    // Errors always report the prefixed spelling, whichever form the schema used.
    let expected_type = match type_ref {
        "xs:integer" | "integer" => "xs:integer",
        "xs:decimal" | "decimal" => "xs:decimal",
        "xs:boolean" | "boolean" => "xs:boolean",
        _ => return Ok(()),
    };

    let text = node.text().unwrap_or("");
    // These types collapse whitespace, so surrounding space, tab, CR and LF are not part of
    // the value.
    let lexical = text.trim_matches(|c: char| matches!(c, ' ' | '\t' | '\r' | '\n'));

    self.validate_simple_type(lexical, type_ref)
        .map_err(|_| ValidationError::TypeValidationError {
            name: node.tag_name().name().to_string(),
            expected_type: expected_type.to_string(),
            value: text.to_string(),
            line: Some(line),
        })
}
/// Checks an element's attributes against the attribute declarations of its complex type.
///
/// Declarations are visited in schema order. Attributes present on the element but not
/// declared in the schema are not examined.
///
/// # Errors
/// - `ValidationError::RequiredAttributeMissing` when a required attribute is absent.
/// - `ValidationError::AttributeValidationError` when a present attribute's value fails
///   `validate_simple_type`.
fn validate_attributes_complex(
    &self,
    node: &Node<'_, '_>,
    complex_type: &ComplexType,
    line: usize,
) -> Result<(), ValidationError> {
    let element_name = node.tag_name().name();

    for decl in &complex_type.attributes {
        match node.attribute(decl.name.as_str()) {
            Some(value) => {
                if self.validate_simple_type(value, &decl.type_name).is_err() {
                    return Err(ValidationError::AttributeValidationError {
                        element: element_name.to_string(),
                        attribute: decl.name.clone(),
                        message: format!("Expected type {}, found '{}'", decl.type_name, value),
                        line: Some(line),
                    });
                }
            }
            None => {
                if decl.required {
                    return Err(ValidationError::RequiredAttributeMissing {
                        element: element_name.to_string(),
                        attribute: decl.name.clone(),
                        line: Some(line),
                    });
                }
            }
        }
    }
    Ok(())
}
/// Checks an element's child elements against the sequence of its complex type.
///
/// Two passes are made. First every child, in document order, must match a declaration by
/// name and is validated recursively. Then, in schema order, each declaration's occurrence
/// count is compared with its bounds. The relative order of children is not checked.
///
/// # Errors
/// - `ValidationError::UnknownElement` for a child with no matching declaration.
/// - Any error from validating a child recursively.
/// - `ValidationError::CardinalityError` (reported at the parent's line) when a declared
///   child occurs too few or too many times.
fn validate_children_complex(
    &self,
    node: &Node<'_, '_>,
    complex_type: &ComplexType,
    line: usize,
) -> Result<(), ValidationError> {
    let kids: Vec<_> = node.children().filter(|n| n.is_element()).collect();

    // Pass 1: every child must be declared and valid in its own right.
    for kid in kids.iter() {
        let tag = kid.tag_name().name();
        match complex_type.sequence.iter().find(|decl| decl.name == tag) {
            Some(decl) => self.validate_element(kid, decl)?,
            None => {
                let row = kid.document().text_pos_at(kid.range().start).row as usize;
                return Err(ValidationError::UnknownElement {
                    element: tag.to_string(),
                    line: Some(row),
                });
            }
        }
    }

    // Pass 2: occurrence bounds of each declared child.
    for decl in complex_type.sequence.iter() {
        let actual = kids
            .iter()
            .filter(|kid| kid.tag_name().name() == decl.name)
            .count();
        let too_few = actual < decl.min_occurs;
        let too_many = matches!(decl.max_occurs, Some(limit) if actual > limit);
        if too_few || too_many {
            return Err(ValidationError::CardinalityError {
                element: decl.name.clone(),
                min: decl.min_occurs,
                max: decl.max_occurs,
                actual,
                line: Some(line),
            });
        }
    }
    Ok(())
}
/// Checks a value against the lexical form of a built-in simple type.
///
/// Leading and trailing XML whitespace (space, tab, CR, LF) is ignored for the checked
/// types. The accepted forms are:
/// - `xs:integer` / `integer`: optional sign followed by one or more ASCII digits, of any
///   length (the value is not limited to a machine integer).
/// - `xs:decimal` / `decimal`: optional sign, then digits with at most one `.`, with at
///   least one digit overall. Exponents, `NaN` and infinities are rejected.
/// - `xs:boolean` / `boolean`: `true`, `false`, `1` or `0`.
///
/// `xs:string` / `string` and any type name not listed above accept every value.
///
/// # Errors
/// Returns `Err(())` when the value does not match the type's lexical form.
fn validate_simple_type(&self, value: &str, type_name: &str) -> Result<(), ()> {
    /// Drops a single leading `+` or `-`, if present.
    fn strip_sign(s: &str) -> &str {
        match s.chars().next() {
            Some('+') | Some('-') => &s[1..],
            _ => s,
        }
    }
    /// True when every byte is an ASCII digit (vacuously true for the empty string).
    fn all_digits(s: &str) -> bool {
        s.bytes().all(|b| b.is_ascii_digit())
    }

    let lexical = value.trim_matches(|c: char| matches!(c, ' ' | '\t' | '\r' | '\n'));

    let valid = match type_name {
        "xs:string" | "string" => true,
        "xs:integer" | "integer" => {
            let digits = strip_sign(lexical);
            !digits.is_empty() && all_digits(digits)
        }
        "xs:decimal" | "decimal" => {
            let unsigned = strip_sign(lexical);
            let (whole, fraction) = match unsigned.find('.') {
                Some(dot) => (&unsigned[..dot], &unsigned[dot + 1..]),
                None => (unsigned, ""),
            };
            // "1.", ".5" and "1.5" are all valid; "." and "" are not.
            let has_digit = !(whole.is_empty() && fraction.is_empty());
            has_digit && all_digits(whole) && all_digits(fraction)
        }
        "xs:boolean" | "boolean" => matches!(lexical, "true" | "false" | "1" | "0"),
        // Unrecognised type names are not checked.
        _ => true,
    };

    if valid {
        Ok(())
    } else {
        Err(())
    }
}
}
/// Contents of a [`SchemaCache`], kept together so one lock guards both pieces.
struct CacheState {
    /// Cached validators keyed by the path they were loaded from.
    validators: HashMap<PathBuf, Arc<SchemaValidator>>,
    /// Keys of `validators` in insertion order; the front is the oldest entry.
    insertion_order: std::collections::VecDeque<PathBuf>,
}

/// Bounded, thread-safe cache of validators keyed by schema file path.
///
/// When the cache is full, the entry that was inserted first is evicted (FIFO). Lookups do
/// not refresh an entry's position.
pub struct SchemaCache {
    cache: Arc<RwLock<CacheState>>,
    max_size: usize,
}

impl SchemaCache {
    /// Creates an empty cache that holds at most `max_size` validators.
    ///
    /// With `max_size == 0` nothing is ever stored; every request loads the schema afresh.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: Arc::new(RwLock::new(CacheState {
                validators: HashMap::new(),
                insertion_order: std::collections::VecDeque::new(),
            })),
            max_size,
        }
    }

    /// Returns the cached validator for `path`, loading and caching it on a miss.
    ///
    /// The fast path takes only a read lock. On a miss the write lock is taken and the cache
    /// is checked again before loading, because another thread may have filled the entry in
    /// between. The write lock stays held while the file is read and parsed, so concurrent
    /// callers wait rather than loading the same schema twice.
    ///
    /// # Errors
    /// Propagates any error from `SchemaValidator::from_file`; nothing is cached on failure.
    pub fn get_or_load(&self, path: &Path) -> Result<Arc<SchemaValidator>, ValidationError> {
        {
            let state = self.cache.read();
            if let Some(validator) = state.validators.get(path) {
                return Ok(Arc::clone(validator));
            }
        }

        let mut state = self.cache.write();
        if let Some(validator) = state.validators.get(path) {
            return Ok(Arc::clone(validator));
        }

        let validator = Arc::new(SchemaValidator::from_file(path)?);
        if self.max_size == 0 {
            // A zero-capacity cache must stay empty.
            return Ok(validator);
        }

        // Make room by dropping the oldest insertions first.
        while state.validators.len() >= self.max_size {
            match state.insertion_order.pop_front() {
                Some(oldest) => {
                    state.validators.remove(&oldest);
                }
                None => break,
            }
        }

        state
            .validators
            .insert(path.to_path_buf(), Arc::clone(&validator));
        state.insertion_order.push_back(path.to_path_buf());
        Ok(validator)
    }

    /// Removes every cached validator.
    pub fn clear(&self) {
        let mut state = self.cache.write();
        state.validators.clear();
        state.insertion_order.clear();
    }

    /// Returns the number of validators currently cached.
    pub fn size(&self) -> usize {
        self.cache.read().validators.len()
    }
}

impl Default for SchemaCache {
    /// Creates a cache with room for 100 validators.
    fn default() -> Self {
        Self::new(100)
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Shared fixture: a `person` element whose sequence holds a string `name` and an integer `age`.
const SIMPLE_SCHEMA: &str = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="person">
<xs:complexType>
<xs:sequence>
<xs:element name="name" type="xs:string"/>
<xs:element name="age" type="xs:integer"/>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>"#;
// The shared fixture schema must load without error.
#[test]
fn test_schema_validator_creation() {
let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA);
assert!(validator.is_ok());
}
// A document with one `name` and one integer `age` passes validation.
#[test]
fn test_valid_document() {
let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();
let xml = r#"<?xml version="1.0"?>
<person>
<name>Alice</name>
<age>30</age>
</person>"#;
assert!(validator.validate(xml).is_ok());
}
// Non-numeric text in an integer element yields a TypeValidationError describing the
// element, the expected type and the offending value.
#[test]
fn test_invalid_type() {
let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();
let xml = r#"<?xml version="1.0"?>
<person>
<name>Alice</name>
<age>thirty</age>
</person>"#;
let result = validator.validate(xml);
assert!(result.is_err());
if let Err(ValidationError::TypeValidationError {
name,
expected_type,
value,
..
}) = result
{
assert_eq!(name, "age");
assert_eq!(expected_type, "xs:integer");
assert_eq!(value, "thirty");
} else {
panic!("Expected TypeValidationError");
}
}
// A child element that the schema does not declare is reported as UnknownElement.
#[test]
fn test_unknown_element() {
let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();
let xml = r#"<?xml version="1.0"?>
<person>
<name>Alice</name>
<age>30</age>
<email>alice@example.com</email>
</person>"#;
let result = validator.validate(xml);
assert!(result.is_err());
if let Err(ValidationError::UnknownElement { element, .. }) = result {
assert_eq!(element, "email");
} else {
panic!("Expected UnknownElement error");
}
}
// A document with an unclosed `name` tag is reported as DocumentParseError.
#[test]
fn test_malformed_xml() {
let validator = SchemaValidator::from_xsd(SIMPLE_SCHEMA).unwrap();
let xml = r#"<?xml version="1.0"?>
<person>
<name>Alice
<age>30</age>
</person>"#;
let result = validator.validate(xml);
assert!(result.is_err());
assert!(matches!(
result,
Err(ValidationError::DocumentParseError { .. })
));
}
// Loading the same path twice returns the same shared validator and stores one entry;
// `clear` empties the cache.
#[test]
fn test_schema_cache() {
use std::io::Write;
use tempfile::NamedTempFile;
let cache = SchemaCache::new(5);
assert_eq!(cache.size(), 0);
let mut temp_file = NamedTempFile::new().unwrap();
temp_file.write_all(SIMPLE_SCHEMA.as_bytes()).unwrap();
let path = temp_file.path();
let validator1 = cache.get_or_load(path).unwrap();
assert_eq!(cache.size(), 1);
let validator2 = cache.get_or_load(path).unwrap();
assert_eq!(cache.size(), 1);
assert!(Arc::ptr_eq(&validator1, &validator2));
cache.clear();
assert_eq!(cache.size(), 0);
}
// Loading a third schema into a cache of capacity 2 keeps the size at 2.
#[test]
fn test_cache_eviction() {
use std::io::Write;
use tempfile::NamedTempFile;
let cache = SchemaCache::new(2);
let mut files = vec![];
for _ in 0..3 {
let mut temp_file = NamedTempFile::new().unwrap();
temp_file.write_all(SIMPLE_SCHEMA.as_bytes()).unwrap();
files.push(temp_file);
}
cache.get_or_load(files[0].path()).unwrap();
cache.get_or_load(files[1].path()).unwrap();
assert_eq!(cache.size(), 2);
cache.get_or_load(files[2].path()).unwrap();
assert_eq!(cache.size(), 2);
}
// The Display text of a TypeValidationError mentions every field, including the line.
#[test]
fn test_error_display() {
let err = ValidationError::TypeValidationError {
name: "age".to_string(),
expected_type: "xs:integer".to_string(),
value: "thirty".to_string(),
line: Some(5),
};
let display = err.to_string();
assert!(display.contains("age"));
assert!(display.contains("xs:integer"));
assert!(display.contains("thirty"));
assert!(display.contains("line 5"));
}
// Loading from a path that does not exist is reported as SchemaNotFound.
#[test]
fn test_schema_not_found() {
let result = SchemaValidator::from_file(Path::new("/nonexistent/schema.xsd"));
assert!(result.is_err());
assert!(matches!(
result,
Err(ValidationError::SchemaNotFound { .. })
));
}
// Pins what the loader accepts and rejects: type references are not resolved at load time,
// while a wrong root element or a nameless element declaration is a SchemaParseError.
#[test]
fn test_invalid_schema() {
    let unresolved_type = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="broken" type="nonexistent:type"/>
</xs:schema>"#;
    // Unknown type names are only consulted during validation, so loading succeeds.
    assert!(SchemaValidator::from_xsd(unresolved_type).is_ok());

    let wrong_root = r#"<?xml version="1.0"?><not-a-schema/>"#;
    assert!(matches!(
        SchemaValidator::from_xsd(wrong_root),
        Err(ValidationError::SchemaParseError { .. })
    ));

    let nameless_element = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element type="xs:string"/>
</xs:schema>"#;
    assert!(matches!(
        SchemaValidator::from_xsd(nameless_element),
        Err(ValidationError::SchemaParseError { .. })
    ));
}
// Boolean elements accept exactly "true", "false", "1" and "0"; "yes" is rejected.
#[test]
fn test_boolean_type_validation() {
let schema = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="flag" type="xs:boolean"/>
</xs:schema>"#;
let validator = SchemaValidator::from_xsd(schema).unwrap();
for val in &["true", "false", "1", "0"] {
let xml = format!(r#"<?xml version="1.0"?><flag>{}</flag>"#, val);
assert!(validator.validate(&xml).is_ok());
}
let xml = r#"<?xml version="1.0"?><flag>yes</flag>"#;
assert!(validator.validate(xml).is_err());
}
// Decimal elements accept "19.99" and reject non-numeric text.
#[test]
fn test_decimal_type_validation() {
let schema = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="price" type="xs:decimal"/>
</xs:schema>"#;
let validator = SchemaValidator::from_xsd(schema).unwrap();
let xml = r#"<?xml version="1.0"?><price>19.99</price>"#;
assert!(validator.validate(xml).is_ok());
let xml = r#"<?xml version="1.0"?><price>not a number</price>"#;
assert!(validator.validate(xml).is_err());
}
// Ten threads each request the same path 100 times; every request must succeed and the
// cache must end up holding exactly one entry.
#[test]
fn test_concurrent_cache_access() {
use std::io::Write;
use std::sync::Arc;
use std::thread;
use tempfile::NamedTempFile;
let cache = Arc::new(SchemaCache::new(10));
let mut temp_file = NamedTempFile::new().unwrap();
temp_file.write_all(SIMPLE_SCHEMA.as_bytes()).unwrap();
let path = temp_file.path().to_path_buf();
let mut handles = vec![];
for _ in 0..10 {
let cache_clone = Arc::clone(&cache);
let path_clone = path.clone();
let handle = thread::spawn(move || {
for _ in 0..100 {
let _validator = cache_clone.get_or_load(&path_clone).unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
assert_eq!(cache.size(), 1);
}
}