Skip to main content

mir_extractor/rules/
advanced_input.rs

1//! Advanced input validation rules.
2//!
3//! Deep dataflow analysis for:
4//! - ADV003/RUSTCOLA201: Insecure binary deserialization (bincode, postcard)
5//! - ADV004/RUSTCOLA202: Regex catastrophic backtracking
6//! - ADV008/RUSTCOLA203: Uncontrolled allocation size
7//! - ADV009/RUSTCOLA204: Integer overflow on untrusted input
8
9use std::collections::{HashMap, HashSet};
10
11use crate::{
12    interprocedural::InterProceduralAnalysis, AttackComplexity, AttackVector, Confidence,
13    Exploitability, Finding, MirFunction, MirPackage, PrivilegesRequired, Rule, RuleMetadata,
14    RuleOrigin, Severity, UserInteraction,
15};
16
17use super::advanced_utils::{
18    detect_assignment, detect_const_string_assignment, detect_len_call, detect_len_comparison,
19    detect_var_alias, extract_call_args, extract_const_literals, is_untrusted_source,
20    pattern_is_high_risk, unescape_rust_literal, TaintTracker,
21};
22
23// ============================================================================
24// RUSTCOLA201: Insecure Binary Deserialization (was ADV003)
25// ============================================================================
26
27/// Detects binary deserialization (bincode, postcard) on untrusted input without size checks.
28pub struct InsecureBinaryDeserializationRule {
29    metadata: RuleMetadata,
30}
31
32impl Default for InsecureBinaryDeserializationRule {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl InsecureBinaryDeserializationRule {
39    pub fn new() -> Self {
40        Self {
41            metadata: RuleMetadata {
42                id: "RUSTCOLA201".to_string(),
43                name: "insecure-binary-deserialization".to_string(),
44                short_description: "Detects binary deserialization on untrusted input".to_string(),
45                full_description: "Binary deserialization libraries like bincode and postcard \
46                    can deserialize arbitrary data structures. When processing untrusted input \
47                    without size validation, attackers can craft payloads that cause excessive \
48                    memory allocation or trigger other vulnerabilities."
49                    .to_string(),
50                help_uri: None,
51                default_severity: Severity::High,
52                origin: RuleOrigin::BuiltIn,
53                cwe_ids: vec!["502".to_string()], // CWE-502: Deserialization of Untrusted Data
54                fix_suggestion: Some(
55                    "Validate input size before deserialization. Use deserialize_with_limit \
56                    or check buffer length against a maximum before calling deserialize."
57                        .to_string(),
58                ),
59                exploitability: Exploitability {
60                    attack_vector: AttackVector::Network,
61                    attack_complexity: AttackComplexity::Low,
62                    privileges_required: PrivilegesRequired::None,
63                    user_interaction: UserInteraction::None,
64                },
65            },
66        }
67    }
68
69    const SINK_PATTERNS: &'static [&'static str] = &[
70        "bincode::deserialize",
71        "bincode::deserialize_from",
72        "bincode::config::deserialize",
73        "bincode::config::deserialize_from",
74        "postcard::from_bytes",
75        "postcard::from_bytes_cobs",
76        "postcard::take_from_bytes",
77        "postcard::take_from_bytes_cobs",
78    ];
79}
80
81impl Rule for InsecureBinaryDeserializationRule {
82    fn metadata(&self) -> &RuleMetadata {
83        &self.metadata
84    }
85
86    fn evaluate(
87        &self,
88        package: &MirPackage,
89        _inter_analysis: Option<&InterProceduralAnalysis>,
90    ) -> Vec<Finding> {
91        let mut findings = Vec::new();
92
93        for func in &package.functions {
94            let mir_text = func.body.join("\n");
95            let mut tracker = TaintTracker::default();
96            let mut pending_len_checks: HashMap<String, String> = HashMap::new();
97
98            for line in mir_text.lines() {
99                let trimmed = line.trim();
100                if trimmed.is_empty() {
101                    continue;
102                }
103
104                // Track taint sources
105                if let Some(dest) = detect_assignment(trimmed) {
106                    if is_untrusted_source(trimmed) {
107                        tracker.mark_source(&dest, trimmed);
108                    } else if let Some(source) = tracker.find_tainted_in_line(trimmed) {
109                        tracker.mark_alias(&dest, &source);
110                    }
111                }
112
113                // Track length checks as sanitization
114                if let Some((len_var, src_var)) = detect_len_call(trimmed) {
115                    if let Some(root) = tracker.taint_roots.get(&src_var).cloned() {
116                        pending_len_checks.insert(len_var, root);
117                    }
118                }
119
120                if let Some(len_var) = detect_len_comparison(trimmed) {
121                    if let Some(root) = pending_len_checks.remove(&len_var) {
122                        tracker.sanitize_root(&root);
123                    }
124                }
125
126                // Check sinks
127                if let Some(sink_name) = Self::SINK_PATTERNS.iter().find(|p| trimmed.contains(*p)) {
128                    let args = extract_call_args(trimmed);
129                    for arg in args {
130                        if let Some(root) = tracker.taint_roots.get(&arg).cloned() {
131                            if tracker.sanitized_roots.contains(&root) {
132                                continue;
133                            }
134
135                            let mut message = format!(
136                                "Insecure binary deserialization: untrusted data flows into `{}`",
137                                sink_name
138                            );
139                            if let Some(origin) = tracker.sources.get(&root) {
140                                message.push_str(&format!("\n  source: `{}`", origin));
141                            }
142
143                            findings.push(Finding {
144                                rule_id: self.metadata.id.clone(),
145                                rule_name: self.metadata.name.clone(),
146                                severity: self.metadata.default_severity,
147                                confidence: Confidence::High,
148                                message,
149                                function: func.name.clone(),
150                                function_signature: func.signature.clone(),
151                                evidence: vec![trimmed.to_string()],
152                                span: func.span.clone(),
153                                exploitability: self.metadata.exploitability.clone(),
154                                exploitability_score: self.metadata.exploitability.score(),
155                                ..Default::default()
156                            });
157                            break;
158                        }
159                    }
160                }
161            }
162        }
163
164        findings
165    }
166}
167
168// ============================================================================
169// RUSTCOLA202: Regex Catastrophic Backtracking (was ADV004)
170// ============================================================================
171
172/// Detects regex patterns with nested quantifiers that trigger catastrophic backtracking.
173pub struct RegexBacktrackingDosRule {
174    metadata: RuleMetadata,
175}
176
177impl Default for RegexBacktrackingDosRule {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183impl RegexBacktrackingDosRule {
184    pub fn new() -> Self {
185        Self {
186            metadata: RuleMetadata {
187                id: "RUSTCOLA202".to_string(),
188                name: "regex-backtracking-dos".to_string(),
189                short_description: "Detects regex patterns vulnerable to catastrophic backtracking"
190                    .to_string(),
191                full_description: "Regex patterns with nested quantifiers like (a+)+ can cause \
192                    exponential backtracking on certain inputs. This can be exploited for \
193                    denial-of-service attacks (ReDoS) by sending specially crafted input."
194                    .to_string(),
195                help_uri: None,
196                default_severity: Severity::Medium,
197                origin: RuleOrigin::BuiltIn,
198                cwe_ids: vec!["1333".to_string()], // CWE-1333: Inefficient Regular Expression
199                fix_suggestion: Some(
200                    "Avoid nested quantifiers. Use atomic groups or possessive quantifiers. \
201                    Consider using regex crate's built-in protections or set match limits."
202                        .to_string(),
203                ),
204                exploitability: Exploitability {
205                    attack_vector: AttackVector::Network,
206                    attack_complexity: AttackComplexity::Low,
207                    privileges_required: PrivilegesRequired::None,
208                    user_interaction: UserInteraction::None,
209                },
210            },
211        }
212    }
213
214    const SINK_PATTERNS: &'static [&'static str] = &[
215        "regex::Regex::new",
216        "regex::RegexSet::new",
217        "regex::builders::RegexBuilder::new",
218        "regex::RegexBuilder::new",
219    ];
220}
221
222impl Rule for RegexBacktrackingDosRule {
223    fn metadata(&self) -> &RuleMetadata {
224        &self.metadata
225    }
226
227    fn evaluate(
228        &self,
229        package: &MirPackage,
230        _inter_analysis: Option<&InterProceduralAnalysis>,
231    ) -> Vec<Finding> {
232        let mut findings = Vec::new();
233
234        for func in &package.functions {
235            let mir_text = func.body.join("\n");
236            let mut const_strings: HashMap<String, String> = HashMap::new();
237            let mut reported_lines: HashSet<String> = HashSet::new();
238
239            for line in mir_text.lines() {
240                let trimmed = line.trim();
241                if trimmed.is_empty() {
242                    continue;
243                }
244
245                // Track constant string assignments
246                if let Some((var, literal)) = detect_const_string_assignment(trimmed) {
247                    const_strings.insert(var, unescape_rust_literal(&literal));
248                    continue;
249                }
250
251                // Track variable aliases
252                if let Some((dest, src)) = detect_var_alias(trimmed) {
253                    if let Some(value) = const_strings.get(&src).cloned() {
254                        const_strings.insert(dest, value);
255                    }
256                }
257
258                // Check for regex compilation
259                if let Some(sink) = Self::SINK_PATTERNS.iter().find(|p| trimmed.contains(*p)) {
260                    // Check inline literals
261                    for literal in extract_const_literals(trimmed) {
262                        let unescaped = unescape_rust_literal(&literal);
263                        if pattern_is_high_risk(&unescaped) {
264                            let key = format!("{}::{}", sink, trimmed.trim());
265                            if reported_lines.insert(key) {
266                                findings.push(self.create_finding(func, sink, trimmed, &unescaped));
267                            }
268                        }
269                    }
270
271                    // Check tracked variables
272                    let args = extract_call_args(trimmed);
273                    for arg in args {
274                        if let Some(pattern) = const_strings.get(&arg).cloned() {
275                            if pattern_is_high_risk(&pattern) {
276                                let key = format!("{}::{}", sink, trimmed.trim());
277                                if reported_lines.insert(key) {
278                                    findings
279                                        .push(self.create_finding(func, sink, trimmed, &pattern));
280                                }
281                            }
282                        }
283                    }
284                }
285            }
286        }
287
288        findings
289    }
290}
291
292impl RegexBacktrackingDosRule {
293    fn create_finding(&self, func: &MirFunction, sink: &str, line: &str, pattern: &str) -> Finding {
294        let display = if pattern.len() > 60 {
295            format!("{}...", &pattern[..57])
296        } else {
297            pattern.to_string()
298        };
299
300        Finding {
301            rule_id: self.metadata.id.clone(),
302            rule_name: self.metadata.name.clone(),
303            severity: self.metadata.default_severity,
304            confidence: Confidence::Medium,
305            message: format!(
306                "Potential regex DoS: pattern `{}` compiled via `{}` may trigger catastrophic backtracking",
307                display, sink
308            ),
309            function: func.name.clone(),
310            function_signature: func.signature.clone(),
311            evidence: vec![line.trim().to_string()],
312            span: func.span.clone(),
313            exploitability: self.metadata.exploitability.clone(),
314            exploitability_score: self.metadata.exploitability.score(),
315            ..Default::default()
316        }
317    }
318}
319
320// ============================================================================
321// RUSTCOLA203: Uncontrolled Allocation Size (was ADV008)
322// ============================================================================
323
324/// Detects allocations sized from untrusted sources without upper bound validation.
325pub struct UncontrolledAllocationSizeRule {
326    metadata: RuleMetadata,
327}
328
329impl Default for UncontrolledAllocationSizeRule {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
335impl UncontrolledAllocationSizeRule {
336    pub fn new() -> Self {
337        Self {
338            metadata: RuleMetadata {
339                id: "RUSTCOLA203".to_string(),
340                name: "uncontrolled-allocation-size".to_string(),
341                short_description: "Detects allocations sized from untrusted sources".to_string(),
342                full_description: "Using untrusted input to control allocation size without \
343                    validation can lead to denial-of-service through memory exhaustion. \
344                    Attackers can send large values to trigger excessive memory allocation."
345                    .to_string(),
346                help_uri: None,
347                default_severity: Severity::High,
348                origin: RuleOrigin::BuiltIn,
349                cwe_ids: vec!["789".to_string()], // CWE-789: Memory Allocation with Excessive Size
350                fix_suggestion: Some(
351                    "Validate allocation size against a reasonable maximum before allocating. \
352                    Use min() or clamp() to enforce upper bounds."
353                        .to_string(),
354                ),
355                exploitability: Exploitability {
356                    attack_vector: AttackVector::Network,
357                    attack_complexity: AttackComplexity::Low,
358                    privileges_required: PrivilegesRequired::None,
359                    user_interaction: UserInteraction::None,
360                },
361            },
362        }
363    }
364
365    const ALLOC_PATTERNS: &'static [&'static str] = &[
366        "Vec::with_capacity",
367        "vec::with_capacity",
368        "String::with_capacity",
369        "string::with_capacity",
370        "HashMap::with_capacity",
371        "hashmap::with_capacity",
372        "HashSet::with_capacity",
373        "VecDeque::with_capacity",
374        "::reserve",
375        "::reserve_exact",
376        "alloc::alloc",
377        "alloc::alloc_zeroed",
378        "alloc::realloc",
379        "Box::new_uninit_slice",
380        "vec![",
381    ];
382}
383
384impl Rule for UncontrolledAllocationSizeRule {
385    fn metadata(&self) -> &RuleMetadata {
386        &self.metadata
387    }
388
389    fn evaluate(
390        &self,
391        package: &MirPackage,
392        _inter_analysis: Option<&InterProceduralAnalysis>,
393    ) -> Vec<Finding> {
394        let mut findings = Vec::new();
395
396        for func in &package.functions {
397            let mir_text = func.body.join("\n");
398            let mut tracker = TaintTracker::default();
399            let mut checked_vars: HashSet<String> = HashSet::new();
400
401            for line in mir_text.lines() {
402                let trimmed = line.trim();
403                if trimmed.is_empty() {
404                    continue;
405                }
406
407                // Track taint sources
408                if let Some(dest) = detect_assignment(trimmed) {
409                    if is_untrusted_source(trimmed) {
410                        tracker.mark_source(&dest, trimmed);
411                    } else if let Some(source) = tracker.find_tainted_in_line(trimmed) {
412                        tracker.mark_alias(&dest, &source);
413                    }
414                }
415
416                // Detect bounds checks (min, clamp, comparisons)
417                if trimmed.contains("::min(") || trimmed.contains("::clamp(") {
418                    let args = extract_call_args(trimmed);
419                    for arg in &args {
420                        checked_vars.insert(arg.clone());
421                        if let Some(root) = tracker.taint_roots.get(arg).cloned() {
422                            tracker.sanitize_root(&root);
423                        }
424                    }
425                }
426
427                // Check allocation sinks
428                if let Some(sink) = Self::ALLOC_PATTERNS.iter().find(|p| trimmed.contains(*p)) {
429                    let args = extract_call_args(trimmed);
430                    for arg in args {
431                        if checked_vars.contains(&arg) {
432                            continue;
433                        }
434
435                        if let Some(root) = tracker.taint_roots.get(&arg).cloned() {
436                            if tracker.sanitized_roots.contains(&root) {
437                                continue;
438                            }
439
440                            let mut message = format!(
441                                "Uncontrolled allocation size: untrusted value flows into `{}`",
442                                sink
443                            );
444                            if let Some(origin) = tracker.sources.get(&root) {
445                                message.push_str(&format!("\n  source: `{}`", origin));
446                            }
447
448                            findings.push(Finding {
449                                rule_id: self.metadata.id.clone(),
450                                rule_name: self.metadata.name.clone(),
451                                severity: self.metadata.default_severity,
452                                confidence: Confidence::High,
453                                message,
454                                function: func.name.clone(),
455                                function_signature: func.signature.clone(),
456                                evidence: vec![trimmed.to_string()],
457                                span: func.span.clone(),
458                                exploitability: self.metadata.exploitability.clone(),
459                                exploitability_score: self.metadata.exploitability.score(),
460                                ..Default::default()
461                            });
462                            break;
463                        }
464                    }
465                }
466            }
467        }
468
469        findings
470    }
471}
472
473// ============================================================================
474// RUSTCOLA204: Integer Overflow on Untrusted Input (was ADV009)
475// ============================================================================
476
477/// Detects arithmetic operations on untrusted input without overflow protection.
478pub struct IntegerOverflowRule {
479    metadata: RuleMetadata,
480}
481
482impl Default for IntegerOverflowRule {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
488impl IntegerOverflowRule {
489    pub fn new() -> Self {
490        Self {
491            metadata: RuleMetadata {
492                id: "RUSTCOLA204".to_string(),
493                name: "integer-overflow-untrusted".to_string(),
494                short_description:
495                    "Detects arithmetic on untrusted input without overflow protection".to_string(),
496                full_description: "Arithmetic operations on values derived from untrusted sources \
497                    can overflow in release builds. This can lead to incorrect calculations, \
498                    buffer overflows, or denial of service."
499                    .to_string(),
500                help_uri: None,
501                default_severity: Severity::Medium,
502                origin: RuleOrigin::BuiltIn,
503                cwe_ids: vec!["190".to_string()], // CWE-190: Integer Overflow
504                fix_suggestion: Some(
505                    "Use checked_*, saturating_*, or wrapping_* methods for arithmetic on \
506                    untrusted input. Validate input ranges before arithmetic operations."
507                        .to_string(),
508                ),
509                exploitability: Exploitability {
510                    attack_vector: AttackVector::Network,
511                    attack_complexity: AttackComplexity::High,
512                    privileges_required: PrivilegesRequired::None,
513                    user_interaction: UserInteraction::None,
514                },
515            },
516        }
517    }
518
519    const UNSAFE_OPS: &'static [(&'static str, &'static str)] = &[
520        ("Add(", "addition"),
521        ("Sub(", "subtraction"),
522        ("Mul(", "multiplication"),
523    ];
524
525    const SAFE_METHODS: &'static [&'static str] = &[
526        "checked_add",
527        "checked_sub",
528        "checked_mul",
529        "checked_div",
530        "saturating_add",
531        "saturating_sub",
532        "saturating_mul",
533        "wrapping_add",
534        "wrapping_sub",
535        "wrapping_mul",
536        "overflowing_add",
537        "overflowing_sub",
538        "overflowing_mul",
539    ];
540}
541
542impl Rule for IntegerOverflowRule {
543    fn metadata(&self) -> &RuleMetadata {
544        &self.metadata
545    }
546
547    fn evaluate(
548        &self,
549        package: &MirPackage,
550        _inter_analysis: Option<&InterProceduralAnalysis>,
551    ) -> Vec<Finding> {
552        let mut findings = Vec::new();
553
554        for func in &package.functions {
555            let mir_text = func.body.join("\n");
556            let mut tracker = TaintTracker::default();
557            let mut safe_vars: HashSet<String> = HashSet::new();
558
559            for line in mir_text.lines() {
560                let trimmed = line.trim();
561                if trimmed.is_empty() {
562                    continue;
563                }
564
565                // Track taint sources
566                if let Some(dest) = detect_assignment(trimmed) {
567                    if is_untrusted_source(trimmed) {
568                        tracker.mark_source(&dest, trimmed);
569                    } else if let Some(source) = tracker.find_tainted_in_line(trimmed) {
570                        tracker.mark_alias(&dest, &source);
571                    }
572
573                    // Track safe arithmetic results
574                    if Self::SAFE_METHODS.iter().any(|m| trimmed.contains(m)) {
575                        safe_vars.insert(dest);
576                    }
577                }
578
579                // Check for unsafe arithmetic operations
580                for (op_pattern, op_name) in Self::UNSAFE_OPS {
581                    if trimmed.contains(op_pattern) {
582                        // Skip if result of safe method
583                        if Self::SAFE_METHODS.iter().any(|m| trimmed.contains(m)) {
584                            continue;
585                        }
586
587                        let args = extract_call_args(trimmed);
588                        for arg in args {
589                            if safe_vars.contains(&arg) {
590                                continue;
591                            }
592
593                            if let Some(root) = tracker.taint_roots.get(&arg).cloned() {
594                                if tracker.sanitized_roots.contains(&root) {
595                                    continue;
596                                }
597
598                                let mut message = format!(
599                                    "Potential integer overflow: untrusted value in {} without overflow protection",
600                                    op_name
601                                );
602                                if let Some(origin) = tracker.sources.get(&root) {
603                                    message.push_str(&format!("\n  source: `{}`", origin));
604                                }
605
606                                findings.push(Finding {
607                                    rule_id: self.metadata.id.clone(),
608                                    rule_name: self.metadata.name.clone(),
609                                    severity: self.metadata.default_severity,
610                                    confidence: Confidence::Medium,
611                                    message,
612                                    function: func.name.clone(),
613                                    function_signature: func.signature.clone(),
614                                    evidence: vec![trimmed.to_string()],
615                                    span: func.span.clone(),
616                                    exploitability: self.metadata.exploitability.clone(),
617                                    exploitability_score: self.metadata.exploitability.score(),
618                                    ..Default::default()
619                                });
620                                break;
621                            }
622                        }
623                    }
624                }
625            }
626        }
627
628        findings
629    }
630}
631
632// ============================================================================
633// Registration
634// ============================================================================
635
636/// Register all advanced input rules with the rule engine.
637pub fn register_advanced_input_rules(engine: &mut crate::RuleEngine) {
638    engine.register_rule(Box::new(InsecureBinaryDeserializationRule::new()));
639    engine.register_rule(Box::new(RegexBacktrackingDosRule::new()));
640    engine.register_rule(Box::new(UncontrolledAllocationSizeRule::new()));
641    engine.register_rule(Box::new(IntegerOverflowRule::new()));
642}
643
#[cfg(test)]
mod tests {
    //! Smoke tests pinning the public id of each rule.
    use super::*;

    #[test]
    fn test_binary_deser_metadata() {
        assert_eq!(
            InsecureBinaryDeserializationRule::new().metadata().id,
            "RUSTCOLA201"
        );
    }

    #[test]
    fn test_regex_dos_metadata() {
        assert_eq!(RegexBacktrackingDosRule::new().metadata().id, "RUSTCOLA202");
    }

    #[test]
    fn test_allocation_size_metadata() {
        assert_eq!(
            UncontrolledAllocationSizeRule::new().metadata().id,
            "RUSTCOLA203"
        );
    }

    #[test]
    fn test_integer_overflow_metadata() {
        assert_eq!(IntegerOverflowRule::new().metadata().id, "RUSTCOLA204");
    }
}