1use indexmap::IndexSet;
17use once_cell::sync::Lazy;
18use regex::Regex;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, VecDeque};
21use std::time::Instant;
22use tracing::{debug, warn};
23
/// Default for `ClassifierConfig::max_depth`: the deepest `Entity::depth` accepted.
const MAX_RECURSIVE_DEPTH: usize = 3;

/// Default for `ClassifierConfig::max_expansion_ratio` (expanded size divided by input size).
const MAX_EXPANSION_RATIO: f64 = 10.0;

/// Default for `ClassifierConfig::max_expanded_size`, compared against summed `Entity::size`.
const MAX_EXPANDED_SIZE: usize = 1_000_000;
/// Largest number of entities `validate_entity_chain` accepts in a single chain.
const MAX_ENTITY_CHAIN_LENGTH: usize = 50;
35
36static BUILTIN_ENTITIES: Lazy<IndexSet<&str>> = Lazy::new(|| {
38 let mut set = IndexSet::new();
39 set.insert("lt");
40 set.insert("gt");
41 set.insert("amp");
42 set.insert("quot");
43 set.insert("apos");
44 set
45});
46
/// Case-insensitive keywords that mark an entity *name* as a known attack.
///
/// The pattern is unanchored, so it matches anywhere inside the name. The `lol[2-9]`
/// alternative is already covered by the plain `lol` alternative that precedes it.
static MALICIOUS_PATTERNS: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r"(?i)(lol|lol[2-9]|billion|bomb|evil|attack|exploit|payload|xxe|external|system|public)",
    )
    .unwrap()
});
54
/// Matches an external-identifier declaration inside an entity value: the keyword `SYSTEM`
/// or `PUBLIC` (any case), whitespace, then a single- or double-quoted literal.
static EXTERNAL_PATTERNS: Lazy<Regex> =
    Lazy::new(|| Regex::new(r#"(?i)(SYSTEM|PUBLIC)\s+['"][^'"]*['"]"#).unwrap());
58
/// Matches the start of a remote or file reference anywhere in a value: an `http`, `https`,
/// `ftp`, `ftps`, `file` or `smb` scheme prefix (any case), or two consecutive backslashes.
static NETWORK_URL_PATTERNS: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)(https?://|ftp://|file://|ftps://|smb://|\\\\)").unwrap());
62
/// Matches a general entity reference of the form `&name;` inside a value. It is used to
/// count how many other entities a value refers to.
static RECURSIVE_PATTERNS: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"&[a-zA-Z_][a-zA-Z0-9._-]*;").unwrap());
66
/// Verdict assigned to a single entity by `EntityClassifier::classify_entity`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum EntityClass {
    /// The name is one of the XML built-ins (`lt`, `gt`, `amp`, `quot`, `apos`).
    SafeBuiltin,
    /// The name is on the DDEX whitelist or in `ClassifierConfig::custom_safe_entities`.
    SafeDdex,
    /// No allow-list and no heuristic matched.
    CustomLocal,
    /// A heuristic fired, but the evidence is not conclusive.
    Suspicious {
        /// Human-readable description of the heuristic that fired.
        reason: String,
        /// Strength of the signal; the classifier in this file emits values in `0.6..=1.0`.
        confidence: f64,
    },
    /// The entity matched an attack signature.
    Malicious {
        /// Category of the detected attack.
        attack_type: AttackType,
        /// Human-readable description of what matched.
        reason: String,
    },
}
91
/// Category attached to `EntityClass::Malicious` and `ClassificationResult::Malicious`.
///
/// The classifier code visible in this file only produces `ExternalEntity`,
/// `NetworkRequest` and `EntityBomb`; the remaining variants are not constructed here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttackType {
    /// The value contains a `SYSTEM`/`PUBLIC` external identifier.
    ExternalEntity,
    ExponentialExpansion,
    RecursiveEntity,
    /// The value contains a network or file URL prefix.
    NetworkRequest,
    FileAccess,
    ParameterEntity,
    /// The entity name matches a known attack keyword.
    EntityBomb,
}
110
/// Two-way safe/malicious outcome.
///
/// NOTE(review): nothing in the visible part of this file constructs or consumes this type;
/// confirm whether it is used elsewhere before relying on it.
#[derive(Debug)]
pub enum ClassificationResult {
    /// The input was judged safe.
    Safe {
        /// Human-readable justification.
        reason: String,
        /// Confidence in the verdict.
        confidence: f64,
    },
    /// The input was judged malicious.
    Malicious {
        /// Category of the detected attack.
        attack_type: AttackType,
        /// Human-readable justification.
        reason: String,
    },
}
129
/// One entity declaration as seen by the classifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entity {
    /// Entity name, without the surrounding `&` and `;`.
    pub name: String,
    /// Replacement text of the entity (empty for entities built by `create_external_entity`).
    pub value: String,
    /// `true` for parameter entities.
    pub is_parameter: bool,
    /// System identifier, when the entity is declared as external.
    pub system_id: Option<String>,
    /// Public identifier, when the entity is declared as external.
    pub public_id: Option<String>,
    /// Nesting depth, compared against `ClassifierConfig::max_depth`.
    pub depth: usize,
    /// Expanded size, summed by chain validation; the `create_*` helpers set it to `value.len()`.
    pub size: usize,
}
148
/// Outcome of `EntityClassifier::validate_entity_chain`.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `false` as soon as any check recorded an error or a strongly suspicious entity.
    pub is_safe: bool,
    /// The dangerous classification recorded for the chain; stays `SafeBuiltin` when no
    /// malicious or strongly suspicious entity was seen.
    pub classification: EntityClass,
    /// Aggregate measurements taken while validating.
    pub metrics: EntityMetrics,
    /// Messages for suspicious entities; these do not by themselves fail the chain.
    pub warnings: Vec<String>,
    /// Messages for every violated limit or malicious entity.
    pub errors: Vec<String>,
}
163
/// Measurements collected by one `validate_entity_chain` call.
#[derive(Debug, Clone, Default)]
pub struct EntityMetrics {
    /// Number of entities in the chain.
    pub entity_count: usize,
    /// Largest `Entity::depth` in the chain.
    pub max_depth: usize,
    /// Sum of `Entity::size` over the chain.
    pub total_expanded_size: usize,
    /// `total_expanded_size` divided by the summed reference lengths (`name.len() + 2` per
    /// entity); `1.0` for an empty chain.
    pub expansion_ratio: f64,
    /// Number of entities carrying a system or public identifier.
    pub external_references: usize,
    /// Number of entities whose value matches the network URL pattern.
    pub network_urls: usize,
    /// Wall-clock duration of the validation call, in milliseconds.
    pub processing_time_ms: u64,
}
182
/// Limits and switches enforced by `EntityClassifier`.
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// Largest `Entity::depth` accepted for custom local entities.
    pub max_depth: usize,
    /// Largest accepted ratio of expanded size to input size for a chain.
    pub max_expansion_ratio: f64,
    /// Largest accepted total of `Entity::size` for a chain.
    pub max_expanded_size: usize,
    /// When `false`, a chain containing any entity with a system or public id is rejected.
    pub allow_external_entities: bool,
    /// When `false`, chain validation rejects custom local parameter entities.
    pub allow_parameter_entities: bool,
    /// Extra entity names classified as `EntityClass::SafeDdex` regardless of value.
    pub custom_safe_entities: IndexSet<String>,
    /// When `true`, each chain validation appends its metrics to the classifier's history.
    pub collect_metrics: bool,
}
201
202impl Default for ClassifierConfig {
203 fn default() -> Self {
204 Self {
205 max_depth: MAX_RECURSIVE_DEPTH,
206 max_expansion_ratio: MAX_EXPANSION_RATIO,
207 max_expanded_size: MAX_EXPANDED_SIZE,
208 allow_external_entities: false,
209 allow_parameter_entities: false,
210 custom_safe_entities: IndexSet::new(),
211 collect_metrics: true,
212 }
213 }
214}
215
/// Classifies XML entities and validates entity chains against a `ClassifierConfig`.
pub struct EntityClassifier {
    /// Limits and switches applied by every check.
    config: ClassifierConfig,
    /// Entity names treated as safe DDEX vocabulary; filled once at construction.
    ddex_whitelist: IndexSet<String>,
    /// Memoized results of `classify_entity`, keyed by a string built from name and value.
    entity_cache: HashMap<String, EntityClass>,
    /// Metrics of the most recent chain validations (at most 100 are kept).
    metrics_history: VecDeque<EntityMetrics>,
}
223
224impl EntityClassifier {
225 pub fn new() -> Self {
227 Self::with_config(ClassifierConfig::default())
228 }
229
230 pub fn with_config(config: ClassifierConfig) -> Self {
232 let ddex_whitelist = Self::load_ddex_whitelist();
233
234 Self {
235 config,
236 ddex_whitelist,
237 entity_cache: HashMap::new(),
238 metrics_history: VecDeque::with_capacity(100), }
240 }
241
242 pub fn classify_entity(&mut self, name: &str, value: &str) -> EntityClass {
244 let cache_key = format!("{}:{}", name, value);
245
246 if let Some(cached) = self.entity_cache.get(&cache_key) {
248 return cached.clone();
249 }
250
251 let classification = self.classify_entity_internal(name, value);
252
253 self.entity_cache.insert(cache_key, classification.clone());
255
256 classification
257 }
258
259 fn classify_entity_internal(&self, name: &str, value: &str) -> EntityClass {
261 if BUILTIN_ENTITIES.contains(name) {
263 return EntityClass::SafeBuiltin;
264 }
265
266 if self.ddex_whitelist.contains(name) {
268 return EntityClass::SafeDdex;
269 }
270
271 if self.config.custom_safe_entities.contains(name) {
273 return EntityClass::SafeDdex; }
275
276 if EXTERNAL_PATTERNS.is_match(value) {
278 return EntityClass::Malicious {
279 attack_type: AttackType::ExternalEntity,
280 reason: "Entity contains SYSTEM or PUBLIC external reference".to_string(),
281 };
282 }
283
284 if NETWORK_URL_PATTERNS.is_match(value) {
286 return EntityClass::Malicious {
287 attack_type: AttackType::NetworkRequest,
288 reason: "Entity contains network URL".to_string(),
289 };
290 }
291
292 if MALICIOUS_PATTERNS.is_match(name) {
294 return EntityClass::Malicious {
295 attack_type: AttackType::EntityBomb,
296 reason: format!("Entity name '{}' matches known attack patterns", name),
297 };
298 }
299
300 let entity_refs = RECURSIVE_PATTERNS.find_iter(value).count();
302 if entity_refs > 5 {
303 return EntityClass::Suspicious {
304 reason: format!("Entity contains {} recursive references", entity_refs),
305 confidence: (entity_refs as f64 / 10.0).min(1.0),
306 };
307 }
308
309 if value.len() > 10000 {
311 return EntityClass::Suspicious {
312 reason: format!("Entity value is very large ({} bytes)", value.len()),
313 confidence: 0.7,
314 };
315 }
316
317 if self.has_repetitive_pattern(value) {
319 return EntityClass::Suspicious {
320 reason: "Entity contains repetitive patterns".to_string(),
321 confidence: 0.6,
322 };
323 }
324
325 EntityClass::CustomLocal
327 }
328
329 pub fn is_safe_entity(&mut self, entity: &Entity) -> bool {
331 let classification = self.classify_entity(&entity.name, &entity.value);
332
333 match classification {
334 EntityClass::SafeBuiltin | EntityClass::SafeDdex => true,
335 EntityClass::CustomLocal => {
336 entity.depth <= self.config.max_depth
338 && entity.size <= self.config.max_expanded_size
339 && !entity.is_parameter }
341 EntityClass::Suspicious { confidence, .. } => {
342 confidence < 0.5
344 }
345 EntityClass::Malicious { .. } => false,
346 }
347 }
348
349 pub fn validate_entity_chain(&mut self, entities: &[Entity]) -> ValidationResult {
351 let start_time = Instant::now();
352 let mut metrics = EntityMetrics::default();
353 let mut warnings = Vec::new();
354 let mut errors = Vec::new();
355 let mut most_dangerous = EntityClass::SafeBuiltin;
356 let mut is_safe = true;
357
358 if entities.len() > MAX_ENTITY_CHAIN_LENGTH {
360 errors.push(format!(
361 "Entity chain too long: {} entities (max: {})",
362 entities.len(),
363 MAX_ENTITY_CHAIN_LENGTH
364 ));
365 is_safe = false;
366 }
367
368 let mut total_input_size = 0;
370 let mut total_output_size = 0;
371 let mut max_depth = 0;
372 let mut external_refs = 0;
373 let mut network_urls = 0;
374
375 for entity in entities {
377 let classification = self.classify_entity(&entity.name, &entity.value);
378
379 total_input_size += entity.name.len() + 2; total_output_size += entity.size;
382 max_depth = max_depth.max(entity.depth);
383
384 if entity.system_id.is_some() || entity.public_id.is_some() {
385 external_refs += 1;
386 }
387
388 if NETWORK_URL_PATTERNS.is_match(&entity.value) {
389 network_urls += 1;
390 }
391
392 match &classification {
394 EntityClass::SafeBuiltin | EntityClass::SafeDdex => {
395 }
397 EntityClass::CustomLocal => {
398 if entity.depth > self.config.max_depth {
399 errors.push(format!(
400 "Entity '{}' exceeds maximum depth: {} > {}",
401 entity.name, entity.depth, self.config.max_depth
402 ));
403 is_safe = false;
404 }
405
406 if entity.is_parameter && !self.config.allow_parameter_entities {
407 errors.push(format!("Parameter entity '{}' not allowed", entity.name));
408 is_safe = false;
409 }
410 }
411 EntityClass::Suspicious { reason, confidence } => {
412 warnings.push(format!(
413 "Suspicious entity '{}': {} (confidence: {:.2})",
414 entity.name, reason, confidence
415 ));
416
417 if *confidence > 0.7 {
418 is_safe = false;
419 most_dangerous = classification.clone();
420 }
421 }
422 EntityClass::Malicious {
423 attack_type,
424 reason,
425 } => {
426 errors.push(format!(
427 "Malicious entity '{}' ({:?}): {}",
428 entity.name, attack_type, reason
429 ));
430 is_safe = false;
431 most_dangerous = classification.clone();
432 }
433 }
434 }
435
436 let expansion_ratio = if total_input_size > 0 {
438 total_output_size as f64 / total_input_size as f64
439 } else {
440 1.0
441 };
442
443 if expansion_ratio > self.config.max_expansion_ratio {
445 errors.push(format!(
446 "Expansion ratio too high: {:.2} > {}",
447 expansion_ratio, self.config.max_expansion_ratio
448 ));
449 is_safe = false;
450 }
451
452 if total_output_size > self.config.max_expanded_size {
453 errors.push(format!(
454 "Total expanded size too large: {} > {}",
455 total_output_size, self.config.max_expanded_size
456 ));
457 is_safe = false;
458 }
459
460 if external_refs > 0 && !self.config.allow_external_entities {
461 errors.push(format!(
462 "External entities not allowed ({} found)",
463 external_refs
464 ));
465 is_safe = false;
466 }
467
468 metrics.entity_count = entities.len();
470 metrics.max_depth = max_depth;
471 metrics.total_expanded_size = total_output_size;
472 metrics.expansion_ratio = expansion_ratio;
473 metrics.external_references = external_refs;
474 metrics.network_urls = network_urls;
475 metrics.processing_time_ms = start_time.elapsed().as_millis() as u64;
476
477 if self.config.collect_metrics {
479 self.metrics_history.push_back(metrics.clone());
480 if self.metrics_history.len() > 100 {
481 self.metrics_history.pop_front();
482 }
483 }
484
485 if !is_safe {
487 warn!(
488 "Entity chain validation failed: {} errors, {} warnings",
489 errors.len(),
490 warnings.len()
491 );
492 } else if !warnings.is_empty() {
493 debug!(
494 "Entity chain validation passed with {} warnings",
495 warnings.len()
496 );
497 }
498
499 ValidationResult {
500 is_safe,
501 classification: most_dangerous,
502 metrics,
503 warnings,
504 errors,
505 }
506 }
507
508 pub fn get_metrics_history(&self) -> &VecDeque<EntityMetrics> {
510 &self.metrics_history
511 }
512
513 pub fn clear_cache(&mut self) {
515 self.entity_cache.clear();
516 }
517
518 fn load_ddex_whitelist() -> IndexSet<String> {
520 let mut whitelist = IndexSet::new();
521
522 whitelist.insert("ddex".to_string());
525 whitelist.insert("ern".to_string());
526 whitelist.insert("avs".to_string());
527 whitelist.insert("iso".to_string());
528 whitelist.insert("musicbrainz".to_string());
529 whitelist.insert("isrc".to_string());
530 whitelist.insert("iswc".to_string());
531 whitelist.insert("isni".to_string());
532 whitelist.insert("dpid".to_string());
533 whitelist.insert("grid".to_string());
534 whitelist.insert("mwli".to_string());
535 whitelist.insert("spar".to_string());
536
537 whitelist.insert("NewReleaseMessage".to_string());
539 whitelist.insert("MessageHeader".to_string());
540 whitelist.insert("MessageId".to_string());
541 whitelist.insert("MessageSender".to_string());
542 whitelist.insert("SentOnBehalfOf".to_string());
543 whitelist.insert("MessageRecipient".to_string());
544 whitelist.insert("MessageCreatedDateTime".to_string());
545 whitelist.insert("MessageAuditTrail".to_string());
546
547 whitelist.insert("ReleaseList".to_string());
549 whitelist.insert("Release".to_string());
550 whitelist.insert("ReleaseId".to_string());
551 whitelist.insert("ReleaseReference".to_string());
552 whitelist.insert("ReferenceTitle".to_string());
553 whitelist.insert("ReleaseDetailsByTerritory".to_string());
554
555 whitelist.insert("ResourceList".to_string());
557 whitelist.insert("SoundRecording".to_string());
558 whitelist.insert("MusicalWork".to_string());
559 whitelist.insert("Image".to_string());
560 whitelist.insert("Text".to_string());
561 whitelist.insert("Video".to_string());
562
563 whitelist.insert("DealList".to_string());
565 whitelist.insert("ReleaseDeal".to_string());
566 whitelist.insert("Deal".to_string());
567 whitelist.insert("DealTerms".to_string());
568 whitelist.insert("CommercialModelType".to_string());
569 whitelist.insert("Usage".to_string());
570 whitelist.insert("Territory".to_string());
571
572 debug!("Loaded {} DDEX entities to whitelist", whitelist.len());
573
574 whitelist
575 }
576
577 fn has_repetitive_pattern(&self, value: &str) -> bool {
579 if value.len() < 20 {
580 return false;
581 }
582
583 let chars: Vec<char> = value.chars().collect();
585 let len = chars.len();
586
587 for pattern_len in 2..=10.min(len / 4) {
589 let mut matches = 0;
590 let pattern = &chars[0..pattern_len];
591
592 for i in (0..len).step_by(pattern_len) {
593 if i + pattern_len <= len && &chars[i..i + pattern_len] == pattern {
594 matches += 1;
595 }
596 }
597
598 if matches * pattern_len > len / 2 {
600 return true;
601 }
602 }
603
604 false
605 }
606}
607
608impl Default for EntityClassifier {
609 fn default() -> Self {
610 Self::new()
611 }
612}
613
614pub fn create_entity(name: &str, value: &str) -> Entity {
616 Entity {
617 name: name.to_string(),
618 value: value.to_string(),
619 is_parameter: false,
620 system_id: None,
621 public_id: None,
622 depth: 0,
623 size: value.len(),
624 }
625}
626
627pub fn create_parameter_entity(name: &str, value: &str) -> Entity {
629 Entity {
630 name: name.to_string(),
631 value: value.to_string(),
632 is_parameter: true,
633 system_id: None,
634 public_id: None,
635 depth: 0,
636 size: value.len(),
637 }
638}
639
640pub fn create_external_entity(name: &str, system_id: &str) -> Entity {
642 Entity {
643 name: name.to_string(),
644 value: String::new(),
645 is_parameter: false,
646 system_id: Some(system_id.to_string()),
647 public_id: None,
648 depth: 0,
649 size: 0,
650 }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// The five XML built-ins are recognised by name.
    #[test]
    fn test_builtin_entity_classification() {
        let mut classifier = EntityClassifier::new();

        for &(name, value) in &[("lt", "<"), ("amp", "&")] {
            assert_eq!(
                classifier.classify_entity(name, value),
                EntityClass::SafeBuiltin
            );
        }
    }

    /// A whitelisted name is safe even though its value looks like a URL.
    #[test]
    fn test_ddex_entity_classification() {
        let mut classifier = EntityClassifier::new();

        let verdict = classifier.classify_entity("ddex", "http://ddex.net/xml/ern/43");

        assert_eq!(verdict, EntityClass::SafeDdex);
    }

    /// External identifiers and network URLs in values are reported as attacks.
    #[test]
    fn test_malicious_entity_detection() {
        let mut classifier = EntityClassifier::new();

        let external =
            classifier.classify_entity("xxe", "<!ENTITY xxe SYSTEM \"file:///etc/passwd\">");
        assert!(
            matches!(
                external,
                EntityClass::Malicious {
                    attack_type: AttackType::ExternalEntity,
                    ..
                }
            ),
            "Should detect external entity attack"
        );

        let network = classifier.classify_entity("evil", "http://attacker.com/evil.xml");
        assert!(
            matches!(
                network,
                EntityClass::Malicious {
                    attack_type: AttackType::NetworkRequest,
                    ..
                }
            ),
            "Should detect network request attack"
        );
    }

    /// A "billion laughs"-style chain is rejected with at least one error.
    #[test]
    fn test_entity_chain_validation() {
        let mut classifier = EntityClassifier::new();

        let chain = [
            create_entity("safe", "content"),
            create_entity("lol", "&lol2;&lol2;&lol2;"),
            create_entity("lol2", "&lol3;&lol3;&lol3;"),
            create_entity("lol3", "haha"),
        ];

        let outcome = classifier.validate_entity_chain(&chain);

        assert!(!outcome.is_safe);
        assert!(!outcome.errors.is_empty());
    }

    /// Ordinary short entities pass without errors.
    #[test]
    fn test_safe_entity_chain() {
        let mut classifier = EntityClassifier::new();

        let chain = [
            create_entity("title", "My Song"),
            create_entity("artist", "My Artist"),
        ];

        let outcome = classifier.validate_entity_chain(&chain);

        assert!(outcome.is_safe);
        assert!(outcome.errors.is_empty());
    }

    /// A 1000-byte value behind a six-byte reference yields a ratio far above 50.
    #[test]
    fn test_expansion_ratio_detection() {
        let mut classifier = EntityClassifier::new();

        // `create_entity` sets depth 0 and size = value length (1000).
        let chain = [create_entity("bomb", &"A".repeat(1000))];

        let outcome = classifier.validate_entity_chain(&chain);

        assert!(outcome.metrics.expansion_ratio > 50.0);
    }
}