Skip to main content

mir_extractor/rules/
advanced_async.rs

1//! Advanced async and web rules.
2//!
//! Line-based, intraprocedural dataflow analysis over MIR text for:
4//! - ADV005/RUSTCOLA205: Template/XSS injection
5//! - ADV006/RUSTCOLA206: Non-Send types across async boundaries
6//! - ADV007/RUSTCOLA207: Tracing span guards held across await
7
8use std::collections::{HashMap, HashSet};
9
10use crate::{
11    interprocedural::InterProceduralAnalysis, AttackComplexity, AttackVector, Confidence,
12    Exploitability, Finding, MirPackage, PrivilegesRequired, Rule, RuleMetadata, RuleOrigin,
13    Severity, UserInteraction,
14};
15
16use super::advanced_utils::{
17    contains_var, detect_assignment, detect_drop_calls, detect_storage_dead_vars, detect_var_alias,
18    extract_call_args, TaintTracker,
19};
20
21// ============================================================================
22// RUSTCOLA205: Template/XSS Injection (was ADV005)
23// ============================================================================
24
25/// Detects template/HTML responses built from unescaped user input.
26pub struct TemplateInjectionRule {
27    metadata: RuleMetadata,
28}
29
30impl Default for TemplateInjectionRule {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl TemplateInjectionRule {
37    pub fn new() -> Self {
38        Self {
39            metadata: RuleMetadata {
40                id: "RUSTCOLA205".to_string(),
41                name: "template-injection".to_string(),
42                short_description: "Detects unescaped user input in HTML responses".to_string(),
43                full_description: "User input included in HTML responses without proper escaping \
44                    can lead to Cross-Site Scripting (XSS) attacks. Attackers can inject \
45                    malicious JavaScript that executes in victims' browsers."
46                    .to_string(),
47                help_uri: None,
48                default_severity: Severity::High,
49                origin: RuleOrigin::BuiltIn,
50                cwe_ids: vec!["79".to_string()], // CWE-79: XSS
51                fix_suggestion: Some(
52                    "Use HTML escaping functions like html_escape::encode_safe() or template \
53                    engines with auto-escaping. Consider using ammonia for HTML sanitization."
54                        .to_string(),
55                ),
56                exploitability: Exploitability {
57                    attack_vector: AttackVector::Network,
58                    attack_complexity: AttackComplexity::Low,
59                    privileges_required: PrivilegesRequired::None,
60                    user_interaction: UserInteraction::Required, // User must visit page
61                },
62            },
63        }
64    }
65
66    const UNTRUSTED_PATTERNS: &'static [&'static str] = &[
67        "env::var",
68        "env::var_os",
69        "env::args",
70        "std::env::var",
71        "std::env::args",
72        "stdin",
73        "TcpStream",
74        "read_to_string",
75        "read_to_end",
76        "axum::extract",
77        "warp::filters::path::param",
78        "warp::filters::query::query",
79        "Request::uri",
80        "Request::body",
81        "hyper::body::to_bytes",
82    ];
83
84    const SANITIZER_PATTERNS: &'static [&'static str] = &[
85        "html_escape::encode_safe",
86        "html_escape::encode_double_quoted_attribute",
87        "html_escape::encode_single_quoted_attribute",
88        "ammonia::clean",
89        "maud::Escaped",
90    ];
91
92    const SINK_PATTERNS: &'static [&'static str] = &[
93        "warp::reply::html",
94        "axum::response::Html::from",
95        "axum::response::Html::new",
96        "rocket::response::content::Html::new",
97        "rocket::response::content::Html::from",
98        "HttpResponse::body",
99        "HttpResponse::Ok",
100    ];
101}
102
103impl Rule for TemplateInjectionRule {
104    fn metadata(&self) -> &RuleMetadata {
105        &self.metadata
106    }
107
108    fn evaluate(
109        &self,
110        package: &MirPackage,
111        _inter_analysis: Option<&InterProceduralAnalysis>,
112    ) -> Vec<Finding> {
113        let mut findings = Vec::new();
114
115        for func in &package.functions {
116            let mir_text = func.body.join("\n");
117            let mut tracker = TaintTracker::default();
118
119            for line in mir_text.lines() {
120                let trimmed = line.trim();
121                if trimmed.is_empty() {
122                    continue;
123                }
124
125                if let Some(dest) = detect_assignment(trimmed) {
126                    if Self::UNTRUSTED_PATTERNS.iter().any(|p| trimmed.contains(p)) {
127                        tracker.mark_source(&dest, trimmed);
128                        continue;
129                    }
130
131                    // Check for sanitization
132                    if Self::SANITIZER_PATTERNS.iter().any(|p| trimmed.contains(p)) {
133                        if let Some(source) = tracker.find_tainted_in_line(trimmed) {
134                            if let Some(root) = tracker.taint_roots.get(&source).cloned() {
135                                tracker.sanitize_root(&root);
136                            }
137                        }
138                        continue;
139                    }
140
141                    if let Some(source) = tracker.find_tainted_in_line(trimmed) {
142                        tracker.mark_alias(&dest, &source);
143                    }
144                }
145
146                // Check sinks
147                if let Some(sink) = Self::SINK_PATTERNS.iter().find(|p| trimmed.contains(*p)) {
148                    let args = extract_call_args(trimmed);
149                    for arg in args {
150                        if let Some(root) = tracker.taint_roots.get(&arg).cloned() {
151                            if tracker.sanitized_roots.contains(&root) {
152                                continue;
153                            }
154
155                            let mut message = format!(
156                                "Possible template injection: unescaped input flows into `{}`",
157                                sink
158                            );
159                            if let Some(origin) = tracker.sources.get(&root) {
160                                message.push_str(&format!("\n  source: `{}`", origin));
161                            }
162
163                            findings.push(Finding {
164                                rule_id: self.metadata.id.clone(),
165                                rule_name: self.metadata.name.clone(),
166                                severity: self.metadata.default_severity,
167                                confidence: Confidence::High,
168                                message,
169                                function: func.name.clone(),
170                                function_signature: func.signature.clone(),
171                                evidence: vec![trimmed.to_string()],
172                                span: func.span.clone(),
173                                exploitability: self.metadata.exploitability.clone(),
174                                exploitability_score: self.metadata.exploitability.score(),
175                                ..Default::default()
176                            });
177                            break;
178                        }
179                    }
180                }
181            }
182        }
183
184        findings
185    }
186}
187
188// ============================================================================
189// RUSTCOLA206: Non-Send Types Across Async Boundaries (was ADV006)
190// ============================================================================
191
192/// Detects non-Send types captured by multi-threaded async executors.
193pub struct UnsafeSendAcrossAsyncBoundaryRule {
194    metadata: RuleMetadata,
195}
196
197impl Default for UnsafeSendAcrossAsyncBoundaryRule {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl UnsafeSendAcrossAsyncBoundaryRule {
204    pub fn new() -> Self {
205        Self {
206            metadata: RuleMetadata {
207                id: "RUSTCOLA206".to_string(),
208                name: "unsafe-send-across-async".to_string(),
209                short_description: "Detects non-Send types captured in spawned tasks".to_string(),
210                full_description:
211                    "Types like Rc<T> and RefCell<T> are not thread-safe (non-Send). \
212                    When captured by tokio::spawn or similar multi-threaded executors, they can \
213                    cause undefined behavior if accessed from multiple threads."
214                        .to_string(),
215                help_uri: None,
216                default_severity: Severity::High,
217                origin: RuleOrigin::BuiltIn,
218                cwe_ids: vec!["362".to_string(), "366".to_string()], // Race Condition, Race Condition Within Thread
219                fix_suggestion: Some(
220                    "Use Arc<T> instead of Rc<T> and Arc<Mutex<T>> instead of RefCell<T>. \
221                    Alternatively, use spawn_local() for single-threaded execution."
222                        .to_string(),
223                ),
224                exploitability: Exploitability {
225                    attack_vector: AttackVector::Local,
226                    attack_complexity: AttackComplexity::High,
227                    privileges_required: PrivilegesRequired::None,
228                    user_interaction: UserInteraction::None,
229                },
230            },
231        }
232    }
233
234    const NON_SEND_PATTERNS: &'static [&'static str] = &[
235        "alloc::rc::Rc",
236        "std::rc::Rc",
237        "core::cell::RefCell",
238        "std::cell::RefCell",
239        "alloc::rc::Weak",
240        "std::rc::Weak",
241    ];
242
243    const SAFE_PATTERNS: &'static [&'static str] = &[
244        "std::sync::Arc",
245        "alloc::sync::Arc",
246        "std::sync::Mutex",
247        "tokio::sync::Mutex",
248    ];
249
250    const SPAWN_PATTERNS: &'static [&'static str] = &[
251        "tokio::spawn",
252        "tokio::task::spawn",
253        "tokio::task::spawn_blocking",
254        "async_std::task::spawn",
255        "async_std::task::spawn_blocking",
256        "smol::spawn",
257        "futures::executor::ThreadPool::spawn",
258        "std::thread::spawn",
259    ];
260
261    const SPAWN_LOCAL_PATTERNS: &'static [&'static str] = &[
262        "tokio::task::spawn_local",
263        "async_std::task::spawn_local",
264        "smol::spawn_local",
265        "futures::task::SpawnExt::spawn_local",
266    ];
267}
268
269impl Rule for UnsafeSendAcrossAsyncBoundaryRule {
270    fn metadata(&self) -> &RuleMetadata {
271        &self.metadata
272    }
273
274    fn evaluate(
275        &self,
276        package: &MirPackage,
277        _inter_analysis: Option<&InterProceduralAnalysis>,
278    ) -> Vec<Finding> {
279        let mut findings = Vec::new();
280
281        for func in &package.functions {
282            let mir_text = func.body.join("\n");
283            let mut roots: HashMap<String, String> = HashMap::new();
284            let mut unsafe_roots: HashSet<String> = HashSet::new();
285            let mut safe_roots: HashSet<String> = HashSet::new();
286            let mut sources: HashMap<String, String> = HashMap::new();
287
288            for line in mir_text.lines() {
289                let trimmed = line.trim();
290                if trimmed.is_empty() {
291                    continue;
292                }
293
294                if let Some(dest) = detect_assignment(trimmed) {
295                    roots.remove(&dest);
296
297                    if Self::NON_SEND_PATTERNS.iter().any(|p| trimmed.contains(p)) {
298                        roots.insert(dest.clone(), dest.clone());
299                        unsafe_roots.insert(dest.clone());
300                        safe_roots.remove(&dest);
301                        sources.entry(dest).or_insert_with(|| trimmed.to_string());
302                    } else if Self::SAFE_PATTERNS.iter().any(|p| trimmed.contains(p)) {
303                        roots.insert(dest.clone(), dest.clone());
304                        safe_roots.insert(dest.clone());
305                        unsafe_roots.remove(&dest);
306                    } else if let Some(source) =
307                        roots.keys().find(|v| contains_var(trimmed, v)).cloned()
308                    {
309                        if let Some(root) = roots.get(&source).cloned() {
310                            roots.insert(dest, root);
311                        }
312                    }
313                }
314
315                // Check spawn calls
316                if let Some(spawn) = Self::SPAWN_PATTERNS.iter().find(|p| trimmed.contains(*p)) {
317                    // Skip spawn_local
318                    if Self::SPAWN_LOCAL_PATTERNS
319                        .iter()
320                        .any(|p| trimmed.contains(p))
321                    {
322                        continue;
323                    }
324
325                    let args = extract_call_args(trimmed);
326                    for arg in args {
327                        if let Some(root) = roots.get(&arg).cloned() {
328                            if safe_roots.contains(&root) {
329                                continue;
330                            }
331                            if !unsafe_roots.contains(&root) {
332                                continue;
333                            }
334
335                            let mut message = format!(
336                                "Non-Send type captured in `{}` may cross thread boundary",
337                                spawn
338                            );
339                            if let Some(origin) = sources.get(&root) {
340                                message.push_str(&format!("\n  source: `{}`", origin));
341                            }
342
343                            findings.push(Finding {
344                                rule_id: self.metadata.id.clone(),
345                                rule_name: self.metadata.name.clone(),
346                                severity: self.metadata.default_severity,
347                                confidence: Confidence::High,
348                                message,
349                                function: func.name.clone(),
350                                function_signature: func.signature.clone(),
351                                evidence: vec![trimmed.to_string()],
352                                span: func.span.clone(),
353                                exploitability: self.metadata.exploitability.clone(),
354                                exploitability_score: self.metadata.exploitability.score(),
355                                ..Default::default()
356                            });
357                            break;
358                        }
359                    }
360                }
361            }
362        }
363
364        findings
365    }
366}
367
368// ============================================================================
369// RUSTCOLA207: Span Guard Held Across Await (was ADV007)
370// ============================================================================
371
372/// Detects tracing span guards held across await points.
373pub struct AwaitSpanGuardRule {
374    metadata: RuleMetadata,
375}
376
377impl Default for AwaitSpanGuardRule {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
383impl AwaitSpanGuardRule {
384    pub fn new() -> Self {
385        Self {
386            metadata: RuleMetadata {
387                id: "RUSTCOLA207".to_string(),
388                name: "span-guard-across-await".to_string(),
389                short_description: "Detects tracing span guards held across await points"
390                    .to_string(),
391                full_description: "Tracing span guards (from span.enter()) should not be held \
392                    across await points. When a task is suspended, the span context may be \
393                    incorrect when resumed, leading to confusing or incorrect trace data."
394                    .to_string(),
395                help_uri: None,
396                default_severity: Severity::Low,
397                origin: RuleOrigin::BuiltIn,
398                cwe_ids: vec!["664".to_string()], // CWE-664: Improper Control of a Resource
399                fix_suggestion: Some(
400                    "Use span.in_scope(|| async { ... }) instead of let _guard = span.enter(). \
401                    Or use tracing::Instrument trait: future.instrument(span).await"
402                        .to_string(),
403                ),
404                exploitability: Exploitability {
405                    attack_vector: AttackVector::Local,
406                    attack_complexity: AttackComplexity::High,
407                    privileges_required: PrivilegesRequired::Low,
408                    user_interaction: UserInteraction::None,
409                },
410            },
411        }
412    }
413
414    const GUARD_PATTERNS: &'static [&'static str] = &[
415        "tracing::Span::enter",
416        "tracing::span::Span::enter",
417        "tracing::dispatcher::DefaultGuard::new",
418    ];
419}
420
/// Bookkeeping for one tracked span guard (keyed by its root local).
#[derive(Clone)]
struct GuardState {
    /// MIR statement that created the guard; quoted in findings.
    origin: String,
    /// Number of locals currently referring to the guard (the original plus aliases).
    /// The state is discarded when the last reference is released.
    count: usize,
}
426
427impl Rule for AwaitSpanGuardRule {
428    fn metadata(&self) -> &RuleMetadata {
429        &self.metadata
430    }
431
432    fn evaluate(
433        &self,
434        package: &MirPackage,
435        _inter_analysis: Option<&InterProceduralAnalysis>,
436    ) -> Vec<Finding> {
437        let mut findings = Vec::new();
438
439        for func in &package.functions {
440            let mir_text = func.body.join("\n");
441            let mut var_to_root: HashMap<String, String> = HashMap::new();
442            let mut guard_states: HashMap<String, GuardState> = HashMap::new();
443            let mut reported: HashSet<String> = HashSet::new();
444
445            for line in mir_text.lines() {
446                let trimmed = line.trim();
447                if trimmed.is_empty() {
448                    continue;
449                }
450
451                // Handle aliases
452                if let Some((dest, source)) = detect_var_alias(trimmed) {
453                    if let Some(root) = var_to_root.remove(&dest) {
454                        if let Some(state) = guard_states.get_mut(&root) {
455                            if state.count > 1 {
456                                state.count -= 1;
457                            } else {
458                                guard_states.remove(&root);
459                            }
460                        }
461                    }
462                    if let Some(root) = var_to_root.get(&source).cloned() {
463                        var_to_root.insert(dest, root.clone());
464                        if let Some(state) = guard_states.get_mut(&root) {
465                            state.count += 1;
466                        }
467                    }
468                    continue;
469                }
470
471                // Detect guard creation
472                if let Some(dest) = detect_assignment(trimmed) {
473                    if Self::GUARD_PATTERNS.iter().any(|p| trimmed.contains(p)) {
474                        var_to_root.insert(dest.clone(), dest.clone());
475                        guard_states.insert(
476                            dest,
477                            GuardState {
478                                origin: trimmed.to_string(),
479                                count: 1,
480                            },
481                        );
482                    }
483                }
484
485                // Handle drops
486                for var in detect_drop_calls(trimmed) {
487                    if let Some(root) = var_to_root.remove(&var) {
488                        if let Some(state) = guard_states.get_mut(&root) {
489                            if state.count > 1 {
490                                state.count -= 1;
491                            } else {
492                                guard_states.remove(&root);
493                            }
494                        }
495                    }
496                }
497
498                for var in detect_storage_dead_vars(trimmed) {
499                    if let Some(root) = var_to_root.remove(&var) {
500                        if let Some(state) = guard_states.get_mut(&root) {
501                            if state.count > 1 {
502                                state.count -= 1;
503                            } else {
504                                guard_states.remove(&root);
505                            }
506                        }
507                    }
508                }
509
510                // Check for await with active guards
511                if trimmed.contains(".await")
512                    || trimmed.contains("Await")
513                    || trimmed.contains("await ")
514                {
515                    for (root, state) in guard_states.iter() {
516                        let key = format!("{}::{}", root, trimmed.trim());
517                        if !reported.insert(key) {
518                            continue;
519                        }
520
521                        findings.push(Finding {
522                            rule_id: self.metadata.id.clone(),
523                            rule_name: self.metadata.name.clone(),
524                            severity: self.metadata.default_severity,
525                            confidence: Confidence::Medium,
526                            message: format!(
527                                "Span guard `{}` held across await point\n  created: `{}`",
528                                root, state.origin
529                            ),
530                            function: func.name.clone(),
531                            function_signature: func.signature.clone(),
532                            evidence: vec![trimmed.to_string()],
533                            span: func.span.clone(),
534                            exploitability: self.metadata.exploitability.clone(),
535                            exploitability_score: self.metadata.exploitability.score(),
536                            ..Default::default()
537                        });
538                    }
539                }
540            }
541        }
542
543        findings
544    }
545}
546
547// ============================================================================
548// Registration
549// ============================================================================
550
551/// Register all advanced async/web rules with the rule engine.
552pub fn register_advanced_async_rules(engine: &mut crate::RuleEngine) {
553    engine.register_rule(Box::new(TemplateInjectionRule::new()));
554    engine.register_rule(Box::new(UnsafeSendAcrossAsyncBoundaryRule::new()));
555    engine.register_rule(Box::new(AwaitSpanGuardRule::new()));
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `rule` advertises `expected` as its rule id.
    fn assert_rule_id<R: Rule>(rule: R, expected: &str) {
        assert_eq!(rule.metadata().id.as_str(), expected);
    }

    #[test]
    fn test_template_injection_metadata() {
        assert_rule_id(TemplateInjectionRule::new(), "RUSTCOLA205");
    }

    #[test]
    fn test_unsafe_send_metadata() {
        assert_rule_id(UnsafeSendAcrossAsyncBoundaryRule::new(), "RUSTCOLA206");
    }

    #[test]
    fn test_span_guard_metadata() {
        assert_rule_id(AwaitSpanGuardRule::new(), "RUSTCOLA207");
    }
}