ddex_builder/security/
entity_classifier.rs

//! Entity Classification System for DDEX Builder
//!
//! This module provides a comprehensive multi-layer classification system to distinguish
//! between legitimate DDEX entities and malicious attempts. It implements defense against
//! XXE attacks, entity expansion attacks, and other XML-based security threats.
//!
//! ## Features
//!
//! - Multi-layer entity classification (SafeBuiltin, SafeDdex, CustomLocal, Suspicious, Malicious)
//! - Recursive depth tracking and expansion ratio calculation
//! - DDEX-specific entity whitelist from official schemas
//! - Pattern matching for known attack vectors
//! - Metrics collection for security monitoring
//! - Performance-optimized caching system
15
16use indexmap::IndexSet;
17use once_cell::sync::Lazy;
18use regex::Regex;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, VecDeque};
21use std::time::Instant;
22use tracing::{debug, warn};
23
/// Default ceiling on how deeply entity definitions may nest
/// (used as `ClassifierConfig::max_depth` by default).
const MAX_RECURSIVE_DEPTH: usize = 3;

/// Default ceiling on the expansion ratio, i.e. expanded bytes divided by
/// the bytes of the `&name;` references that produced them.
const MAX_EXPANSION_RATIO: f64 = 10.0;

/// Default ceiling, in bytes, on the summed size of all expanded entities.
const MAX_EXPANDED_SIZE: usize = 1_000_000; // 1MB

/// Largest number of entities accepted in a single chain.
const MAX_ENTITY_CHAIN_LENGTH: usize = 50;
35
36/// Standard XML built-in entity patterns
37static BUILTIN_ENTITIES: Lazy<IndexSet<&str>> = Lazy::new(|| {
38    let mut set = IndexSet::new();
39    set.insert("lt");
40    set.insert("gt");
41    set.insert("amp");
42    set.insert("quot");
43    set.insert("apos");
44    set
45});
46
47/// Known malicious entity patterns
48static MALICIOUS_PATTERNS: Lazy<Regex> = Lazy::new(|| {
49    Regex::new(
50        r"(?i)(lol|lol[2-9]|billion|bomb|evil|attack|exploit|payload|xxe|external|system|public)",
51    )
52    .unwrap()
53});
54
55/// External reference patterns
56static EXTERNAL_PATTERNS: Lazy<Regex> =
57    Lazy::new(|| Regex::new(r#"(?i)(SYSTEM|PUBLIC)\s+['"][^'"]*['"]"#).unwrap());
58
59/// Network URL patterns
60static NETWORK_URL_PATTERNS: Lazy<Regex> =
61    Lazy::new(|| Regex::new(r"(?i)(https?://|ftp://|file://|ftps://|smb://|\\\\)").unwrap());
62
63/// Recursive entity reference patterns
64static RECURSIVE_PATTERNS: Lazy<Regex> =
65    Lazy::new(|| Regex::new(r"&[a-zA-Z_][a-zA-Z0-9._-]*;").unwrap());
66
67/// Entity classification levels
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
69pub enum EntityClass {
70    /// Standard XML built-in entities (&lt;, &gt;, &amp;, &quot;, &apos;)
71    SafeBuiltin,
72    /// DDEX-specific entities from official schemas
73    SafeDdex,
74    /// User-defined entities that need validation
75    CustomLocal,
76    /// Entities that match suspicious patterns but aren't confirmed malicious
77    Suspicious {
78        /// Reason for suspicious classification
79        reason: String,
80        /// Confidence level (0.0-1.0)
81        confidence: f64,
82    },
83    /// Confirmed malicious entities
84    Malicious {
85        /// Type of attack detected
86        attack_type: AttackType,
87        /// Reason for malicious classification
88        reason: String,
89    },
90}
91
92/// Types of XML entity attacks
93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
94pub enum AttackType {
95    /// XXE (XML External Entity) attack
96    ExternalEntity,
97    /// Billion laughs / exponential expansion attack
98    ExponentialExpansion,
99    /// Recursive entity definition
100    RecursiveEntity,
101    /// Network request attempt
102    NetworkRequest,
103    /// File access attempt
104    FileAccess,
105    /// Parameter entity attack
106    ParameterEntity,
107    /// Generic entity bomb
108    EntityBomb,
109}
110
111/// Classification result
112#[derive(Debug)]
113pub enum ClassificationResult {
114    /// Entity is safe
115    Safe {
116        /// Reason for safe classification
117        reason: String,
118        /// Confidence level (0.0-1.0)
119        confidence: f64,
120    },
121    /// Entity is potentially malicious
122    Malicious {
123        /// Type of attack detected
124        attack_type: AttackType,
125        /// Reason for classification
126        reason: String,
127    },
128}
129
/// A single entity declaration as handed to the classifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entity {
    /// Name of the entity, without the leading `&` and trailing `;`.
    pub name: String,
    /// Replacement text (the entity's definition).
    pub value: String,
    /// `true` when this is a parameter entity.
    pub is_parameter: bool,
    /// System identifier of an external entity, if one was declared.
    pub system_id: Option<String>,
    /// Public identifier of an external entity, if one was declared.
    pub public_id: Option<String>,
    /// How deeply this entity is nested within its chain.
    pub depth: usize,
    /// Byte size attributed to the entity's value.
    pub size: usize,
}
148
149/// Result of entity chain validation
150#[derive(Debug, Clone)]
151pub struct ValidationResult {
152    /// Whether the entity chain is safe
153    pub is_safe: bool,
154    /// Classification of the most dangerous entity
155    pub classification: EntityClass,
156    /// Detailed analysis metrics
157    pub metrics: EntityMetrics,
158    /// Warning messages for suspicious but allowed entities
159    pub warnings: Vec<String>,
160    /// Errors for blocked entities
161    pub errors: Vec<String>,
162}
163
/// Measurements collected during one entity-chain analysis.
#[derive(Debug, Clone, Default)]
pub struct EntityMetrics {
    /// How many entities the chain contained.
    pub entity_count: usize,
    /// Deepest nesting level seen on any entity.
    pub max_depth: usize,
    /// Sum of the entities' sizes, in bytes.
    pub total_expanded_size: usize,
    /// Expanded size divided by the size of the references (output/input).
    pub expansion_ratio: f64,
    /// How many entities carried a system or public identifier.
    pub external_references: usize,
    /// How many entity values contained a network URL.
    pub network_urls: usize,
    /// Wall-clock duration of the analysis, in milliseconds.
    pub processing_time_ms: u64,
}
182
183/// Configuration for entity classification
184#[derive(Debug, Clone)]
185pub struct ClassifierConfig {
186    /// Maximum allowed recursive depth
187    pub max_depth: usize,
188    /// Maximum expansion ratio
189    pub max_expansion_ratio: f64,
190    /// Maximum total expanded size
191    pub max_expanded_size: usize,
192    /// Whether to allow external entities
193    pub allow_external_entities: bool,
194    /// Whether to allow parameter entities
195    pub allow_parameter_entities: bool,
196    /// Custom safe entities (in addition to DDEX whitelist)
197    pub custom_safe_entities: IndexSet<String>,
198    /// Whether to collect detailed metrics
199    pub collect_metrics: bool,
200}
201
202impl Default for ClassifierConfig {
203    fn default() -> Self {
204        Self {
205            max_depth: MAX_RECURSIVE_DEPTH,
206            max_expansion_ratio: MAX_EXPANSION_RATIO,
207            max_expanded_size: MAX_EXPANDED_SIZE,
208            allow_external_entities: false,
209            allow_parameter_entities: false,
210            custom_safe_entities: IndexSet::new(),
211            collect_metrics: true,
212        }
213    }
214}
215
216/// DDEX Entity Classifier
217pub struct EntityClassifier {
218    config: ClassifierConfig,
219    ddex_whitelist: IndexSet<String>,
220    entity_cache: HashMap<String, EntityClass>,
221    metrics_history: VecDeque<EntityMetrics>,
222}
223
224impl EntityClassifier {
225    /// Create a new entity classifier with default configuration
226    pub fn new() -> Self {
227        Self::with_config(ClassifierConfig::default())
228    }
229
230    /// Create a new entity classifier with custom configuration
231    pub fn with_config(config: ClassifierConfig) -> Self {
232        let ddex_whitelist = Self::load_ddex_whitelist();
233
234        Self {
235            config,
236            ddex_whitelist,
237            entity_cache: HashMap::new(),
238            metrics_history: VecDeque::with_capacity(100), // Keep last 100 analyses
239        }
240    }
241
242    /// Classify a single entity by name and value
243    pub fn classify_entity(&mut self, name: &str, value: &str) -> EntityClass {
244        let cache_key = format!("{}:{}", name, value);
245
246        // Check cache first
247        if let Some(cached) = self.entity_cache.get(&cache_key) {
248            return cached.clone();
249        }
250
251        let classification = self.classify_entity_internal(name, value);
252
253        // Cache the result
254        self.entity_cache.insert(cache_key, classification.clone());
255
256        classification
257    }
258
259    /// Internal classification logic
260    fn classify_entity_internal(&self, name: &str, value: &str) -> EntityClass {
261        // 1. Check if it's a standard XML built-in entity
262        if BUILTIN_ENTITIES.contains(name) {
263            return EntityClass::SafeBuiltin;
264        }
265
266        // 2. Check if it's in the DDEX whitelist
267        if self.ddex_whitelist.contains(name) {
268            return EntityClass::SafeDdex;
269        }
270
271        // 3. Check if it's in custom safe entities
272        if self.config.custom_safe_entities.contains(name) {
273            return EntityClass::SafeDdex; // Treat custom safe as DDEX-level
274        }
275
276        // 4. Check for external references in value (highest priority)
277        if EXTERNAL_PATTERNS.is_match(value) {
278            return EntityClass::Malicious {
279                attack_type: AttackType::ExternalEntity,
280                reason: "Entity contains SYSTEM or PUBLIC external reference".to_string(),
281            };
282        }
283
284        // 5. Check for network URLs
285        if NETWORK_URL_PATTERNS.is_match(value) {
286            return EntityClass::Malicious {
287                attack_type: AttackType::NetworkRequest,
288                reason: "Entity contains network URL".to_string(),
289            };
290        }
291
292        // 6. Check for malicious patterns in name (lower priority)
293        if MALICIOUS_PATTERNS.is_match(name) {
294            return EntityClass::Malicious {
295                attack_type: AttackType::EntityBomb,
296                reason: format!("Entity name '{}' matches known attack patterns", name),
297            };
298        }
299
300        // 7. Check for recursive references
301        let entity_refs = RECURSIVE_PATTERNS.find_iter(value).count();
302        if entity_refs > 5 {
303            return EntityClass::Suspicious {
304                reason: format!("Entity contains {} recursive references", entity_refs),
305                confidence: (entity_refs as f64 / 10.0).min(1.0),
306            };
307        }
308
309        // 8. Check value size
310        if value.len() > 10000 {
311            return EntityClass::Suspicious {
312                reason: format!("Entity value is very large ({} bytes)", value.len()),
313                confidence: 0.7,
314            };
315        }
316
317        // 9. Check for repetitive patterns (possible expansion bomb)
318        if self.has_repetitive_pattern(value) {
319            return EntityClass::Suspicious {
320                reason: "Entity contains repetitive patterns".to_string(),
321                confidence: 0.6,
322            };
323        }
324
325        // Default to custom local (needs validation)
326        EntityClass::CustomLocal
327    }
328
329    /// Check if an entity is safe for use
330    pub fn is_safe_entity(&mut self, entity: &Entity) -> bool {
331        let classification = self.classify_entity(&entity.name, &entity.value);
332
333        match classification {
334            EntityClass::SafeBuiltin | EntityClass::SafeDdex => true,
335            EntityClass::CustomLocal => {
336                // Additional validation for custom entities
337                entity.depth <= self.config.max_depth
338                    && entity.size <= self.config.max_expanded_size
339                    && !entity.is_parameter // Be strict about parameter entities
340            }
341            EntityClass::Suspicious { confidence, .. } => {
342                // Allow suspicious entities with low confidence
343                confidence < 0.5
344            }
345            EntityClass::Malicious { .. } => false,
346        }
347    }
348
349    /// Validate a complete entity chain
350    pub fn validate_entity_chain(&mut self, entities: &[Entity]) -> ValidationResult {
351        let start_time = Instant::now();
352        let mut metrics = EntityMetrics::default();
353        let mut warnings = Vec::new();
354        let mut errors = Vec::new();
355        let mut most_dangerous = EntityClass::SafeBuiltin;
356        let mut is_safe = true;
357
358        // Basic chain validation
359        if entities.len() > MAX_ENTITY_CHAIN_LENGTH {
360            errors.push(format!(
361                "Entity chain too long: {} entities (max: {})",
362                entities.len(),
363                MAX_ENTITY_CHAIN_LENGTH
364            ));
365            is_safe = false;
366        }
367
368        // Track entity expansion and depth
369        let mut total_input_size = 0;
370        let mut total_output_size = 0;
371        let mut max_depth = 0;
372        let mut external_refs = 0;
373        let mut network_urls = 0;
374
375        // Analyze each entity
376        for entity in entities {
377            let classification = self.classify_entity(&entity.name, &entity.value);
378
379            // Update metrics
380            total_input_size += entity.name.len() + 2; // &name;
381            total_output_size += entity.size;
382            max_depth = max_depth.max(entity.depth);
383
384            if entity.system_id.is_some() || entity.public_id.is_some() {
385                external_refs += 1;
386            }
387
388            if NETWORK_URL_PATTERNS.is_match(&entity.value) {
389                network_urls += 1;
390            }
391
392            // Check individual entity safety
393            match &classification {
394                EntityClass::SafeBuiltin | EntityClass::SafeDdex => {
395                    // These are always safe
396                }
397                EntityClass::CustomLocal => {
398                    if entity.depth > self.config.max_depth {
399                        errors.push(format!(
400                            "Entity '{}' exceeds maximum depth: {} > {}",
401                            entity.name, entity.depth, self.config.max_depth
402                        ));
403                        is_safe = false;
404                    }
405
406                    if entity.is_parameter && !self.config.allow_parameter_entities {
407                        errors.push(format!("Parameter entity '{}' not allowed", entity.name));
408                        is_safe = false;
409                    }
410                }
411                EntityClass::Suspicious { reason, confidence } => {
412                    warnings.push(format!(
413                        "Suspicious entity '{}': {} (confidence: {:.2})",
414                        entity.name, reason, confidence
415                    ));
416
417                    if *confidence > 0.7 {
418                        is_safe = false;
419                        most_dangerous = classification.clone();
420                    }
421                }
422                EntityClass::Malicious {
423                    attack_type,
424                    reason,
425                } => {
426                    errors.push(format!(
427                        "Malicious entity '{}' ({:?}): {}",
428                        entity.name, attack_type, reason
429                    ));
430                    is_safe = false;
431                    most_dangerous = classification.clone();
432                }
433            }
434        }
435
436        // Calculate expansion ratio
437        let expansion_ratio = if total_input_size > 0 {
438            total_output_size as f64 / total_input_size as f64
439        } else {
440            1.0
441        };
442
443        // Check overall limits
444        if expansion_ratio > self.config.max_expansion_ratio {
445            errors.push(format!(
446                "Expansion ratio too high: {:.2} > {}",
447                expansion_ratio, self.config.max_expansion_ratio
448            ));
449            is_safe = false;
450        }
451
452        if total_output_size > self.config.max_expanded_size {
453            errors.push(format!(
454                "Total expanded size too large: {} > {}",
455                total_output_size, self.config.max_expanded_size
456            ));
457            is_safe = false;
458        }
459
460        if external_refs > 0 && !self.config.allow_external_entities {
461            errors.push(format!(
462                "External entities not allowed ({} found)",
463                external_refs
464            ));
465            is_safe = false;
466        }
467
468        // Populate metrics
469        metrics.entity_count = entities.len();
470        metrics.max_depth = max_depth;
471        metrics.total_expanded_size = total_output_size;
472        metrics.expansion_ratio = expansion_ratio;
473        metrics.external_references = external_refs;
474        metrics.network_urls = network_urls;
475        metrics.processing_time_ms = start_time.elapsed().as_millis() as u64;
476
477        // Store metrics for analysis
478        if self.config.collect_metrics {
479            self.metrics_history.push_back(metrics.clone());
480            if self.metrics_history.len() > 100 {
481                self.metrics_history.pop_front();
482            }
483        }
484
485        // Log security events
486        if !is_safe {
487            warn!(
488                "Entity chain validation failed: {} errors, {} warnings",
489                errors.len(),
490                warnings.len()
491            );
492        } else if !warnings.is_empty() {
493            debug!(
494                "Entity chain validation passed with {} warnings",
495                warnings.len()
496            );
497        }
498
499        ValidationResult {
500            is_safe,
501            classification: most_dangerous,
502            metrics,
503            warnings,
504            errors,
505        }
506    }
507
508    /// Get recent security metrics for analysis
509    pub fn get_metrics_history(&self) -> &VecDeque<EntityMetrics> {
510        &self.metrics_history
511    }
512
513    /// Clear the entity classification cache
514    pub fn clear_cache(&mut self) {
515        self.entity_cache.clear();
516    }
517
518    /// Load DDEX entity whitelist from official schemas
519    fn load_ddex_whitelist() -> IndexSet<String> {
520        let mut whitelist = IndexSet::new();
521
522        // Standard DDEX entities that are commonly used and safe
523        // These would typically be loaded from DDEX schema files
524        whitelist.insert("ddex".to_string());
525        whitelist.insert("ern".to_string());
526        whitelist.insert("avs".to_string());
527        whitelist.insert("iso".to_string());
528        whitelist.insert("musicbrainz".to_string());
529        whitelist.insert("isrc".to_string());
530        whitelist.insert("iswc".to_string());
531        whitelist.insert("isni".to_string());
532        whitelist.insert("dpid".to_string());
533        whitelist.insert("grid".to_string());
534        whitelist.insert("mwli".to_string());
535        whitelist.insert("spar".to_string());
536
537        // Common DDEX namespace prefixes
538        whitelist.insert("NewReleaseMessage".to_string());
539        whitelist.insert("MessageHeader".to_string());
540        whitelist.insert("MessageId".to_string());
541        whitelist.insert("MessageSender".to_string());
542        whitelist.insert("SentOnBehalfOf".to_string());
543        whitelist.insert("MessageRecipient".to_string());
544        whitelist.insert("MessageCreatedDateTime".to_string());
545        whitelist.insert("MessageAuditTrail".to_string());
546
547        // Release-specific entities
548        whitelist.insert("ReleaseList".to_string());
549        whitelist.insert("Release".to_string());
550        whitelist.insert("ReleaseId".to_string());
551        whitelist.insert("ReleaseReference".to_string());
552        whitelist.insert("ReferenceTitle".to_string());
553        whitelist.insert("ReleaseDetailsByTerritory".to_string());
554
555        // Resource entities
556        whitelist.insert("ResourceList".to_string());
557        whitelist.insert("SoundRecording".to_string());
558        whitelist.insert("MusicalWork".to_string());
559        whitelist.insert("Image".to_string());
560        whitelist.insert("Text".to_string());
561        whitelist.insert("Video".to_string());
562
563        // Deal/Commercial entities
564        whitelist.insert("DealList".to_string());
565        whitelist.insert("ReleaseDeal".to_string());
566        whitelist.insert("Deal".to_string());
567        whitelist.insert("DealTerms".to_string());
568        whitelist.insert("CommercialModelType".to_string());
569        whitelist.insert("Usage".to_string());
570        whitelist.insert("Territory".to_string());
571
572        debug!("Loaded {} DDEX entities to whitelist", whitelist.len());
573
574        whitelist
575    }
576
577    /// Check if a value has repetitive patterns that might indicate an expansion bomb
578    fn has_repetitive_pattern(&self, value: &str) -> bool {
579        if value.len() < 20 {
580            return false;
581        }
582
583        // Look for repeated substrings
584        let chars: Vec<char> = value.chars().collect();
585        let len = chars.len();
586
587        // Check for patterns of length 2-10
588        for pattern_len in 2..=10.min(len / 4) {
589            let mut matches = 0;
590            let pattern = &chars[0..pattern_len];
591
592            for i in (0..len).step_by(pattern_len) {
593                if i + pattern_len <= len && &chars[i..i + pattern_len] == pattern {
594                    matches += 1;
595                }
596            }
597
598            // If more than 50% of the string is the same pattern, it's suspicious
599            if matches * pattern_len > len / 2 {
600                return true;
601            }
602        }
603
604        false
605    }
606}
607
608impl Default for EntityClassifier {
609    fn default() -> Self {
610        Self::new()
611    }
612}
613
614/// Helper function to create an Entity from name and value
615pub fn create_entity(name: &str, value: &str) -> Entity {
616    Entity {
617        name: name.to_string(),
618        value: value.to_string(),
619        is_parameter: false,
620        system_id: None,
621        public_id: None,
622        depth: 0,
623        size: value.len(),
624    }
625}
626
627/// Helper function to create a parameter entity
628pub fn create_parameter_entity(name: &str, value: &str) -> Entity {
629    Entity {
630        name: name.to_string(),
631        value: value.to_string(),
632        is_parameter: true,
633        system_id: None,
634        public_id: None,
635        depth: 0,
636        size: value.len(),
637    }
638}
639
640/// Helper function to create an external entity
641pub fn create_external_entity(name: &str, system_id: &str) -> Entity {
642    Entity {
643        name: name.to_string(),
644        value: String::new(),
645        is_parameter: false,
646        system_id: Some(system_id.to_string()),
647        public_id: None,
648        depth: 0,
649        size: 0,
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Predefined XML entities are recognised by name alone.
    #[test]
    fn test_builtin_entity_classification() {
        let mut classifier = EntityClassifier::new();

        for (name, value) in [("lt", "<"), ("amp", "&")].iter().copied() {
            assert_eq!(
                classifier.classify_entity(name, value),
                EntityClass::SafeBuiltin
            );
        }
    }

    /// A whitelisted DDEX name is safe even when its value is a URL.
    #[test]
    fn test_ddex_entity_classification() {
        let mut classifier = EntityClassifier::new();

        let class = classifier.classify_entity("ddex", "http://ddex.net/xml/ern/43");
        assert_eq!(class, EntityClass::SafeDdex);
    }

    /// External references and network URLs are reported as malicious.
    #[test]
    fn test_malicious_entity_detection() {
        let mut classifier = EntityClassifier::new();

        // Test external entity
        let external =
            classifier.classify_entity("xxe", "<!ENTITY xxe SYSTEM \"file:///etc/passwd\">");
        assert!(
            matches!(
                external,
                EntityClass::Malicious {
                    attack_type: AttackType::ExternalEntity,
                    ..
                }
            ),
            "Should detect external entity attack"
        );

        // Test network URL
        let network = classifier.classify_entity("evil", "http://attacker.com/evil.xml");
        assert!(
            matches!(
                network,
                EntityClass::Malicious {
                    attack_type: AttackType::NetworkRequest,
                    ..
                }
            ),
            "Should detect network request attack"
        );
    }

    /// A "billion laughs" style chain is rejected with errors.
    #[test]
    fn test_entity_chain_validation() {
        let mut classifier = EntityClassifier::new();

        let entities = [
            create_entity("safe", "content"),
            create_entity("lol", "&lol2;&lol2;&lol2;"),
            create_entity("lol2", "&lol3;&lol3;&lol3;"),
            create_entity("lol3", "haha"),
        ];

        let result = classifier.validate_entity_chain(&entities);
        assert!(!result.is_safe);
        assert!(!result.errors.is_empty());
    }

    /// Ordinary short custom entities pass validation without errors.
    #[test]
    fn test_safe_entity_chain() {
        let mut classifier = EntityClassifier::new();

        let entities = [
            create_entity("title", "My Song"),
            create_entity("artist", "My Artist"),
        ];

        let result = classifier.validate_entity_chain(&entities);
        assert!(result.is_safe);
        assert!(result.errors.is_empty());
    }

    /// A short name that expands to a long value yields a large expansion ratio.
    #[test]
    fn test_expansion_ratio_detection() {
        let mut classifier = EntityClassifier::new();

        // Create an entity that expands significantly
        let payload = "A".repeat(1000);
        let entities = [Entity {
            name: String::from("bomb"),
            value: payload,
            is_parameter: false,
            system_id: None,
            public_id: None,
            depth: 0,
            size: 1000,
        }];

        let result = classifier.validate_entity_chain(&entities);

        // Should trigger expansion ratio warning
        assert!(result.metrics.expansion_ratio > 50.0);
    }
}