1use std::collections::{HashMap, HashSet};
9
10use crate::{
11 interprocedural::InterProceduralAnalysis, AttackComplexity, AttackVector, Confidence,
12 Exploitability, Finding, MirPackage, PrivilegesRequired, Rule, RuleMetadata, RuleOrigin,
13 Severity, UserInteraction,
14};
15
16use super::advanced_utils::{
17 contains_var, detect_assignment, detect_drop_calls, detect_storage_dead_vars, detect_var_alias,
18 extract_call_args, TaintTracker,
19};
20
21pub struct TemplateInjectionRule {
27 metadata: RuleMetadata,
28}
29
30impl Default for TemplateInjectionRule {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl TemplateInjectionRule {
37 pub fn new() -> Self {
38 Self {
39 metadata: RuleMetadata {
40 id: "RUSTCOLA205".to_string(),
41 name: "template-injection".to_string(),
42 short_description: "Detects unescaped user input in HTML responses".to_string(),
43 full_description: "User input included in HTML responses without proper escaping \
44 can lead to Cross-Site Scripting (XSS) attacks. Attackers can inject \
45 malicious JavaScript that executes in victims' browsers."
46 .to_string(),
47 help_uri: None,
48 default_severity: Severity::High,
49 origin: RuleOrigin::BuiltIn,
50 cwe_ids: vec!["79".to_string()], fix_suggestion: Some(
52 "Use HTML escaping functions like html_escape::encode_safe() or template \
53 engines with auto-escaping. Consider using ammonia for HTML sanitization."
54 .to_string(),
55 ),
56 exploitability: Exploitability {
57 attack_vector: AttackVector::Network,
58 attack_complexity: AttackComplexity::Low,
59 privileges_required: PrivilegesRequired::None,
60 user_interaction: UserInteraction::Required, },
62 },
63 }
64 }
65
66 const UNTRUSTED_PATTERNS: &'static [&'static str] = &[
67 "env::var",
68 "env::var_os",
69 "env::args",
70 "std::env::var",
71 "std::env::args",
72 "stdin",
73 "TcpStream",
74 "read_to_string",
75 "read_to_end",
76 "axum::extract",
77 "warp::filters::path::param",
78 "warp::filters::query::query",
79 "Request::uri",
80 "Request::body",
81 "hyper::body::to_bytes",
82 ];
83
84 const SANITIZER_PATTERNS: &'static [&'static str] = &[
85 "html_escape::encode_safe",
86 "html_escape::encode_double_quoted_attribute",
87 "html_escape::encode_single_quoted_attribute",
88 "ammonia::clean",
89 "maud::Escaped",
90 ];
91
92 const SINK_PATTERNS: &'static [&'static str] = &[
93 "warp::reply::html",
94 "axum::response::Html::from",
95 "axum::response::Html::new",
96 "rocket::response::content::Html::new",
97 "rocket::response::content::Html::from",
98 "HttpResponse::body",
99 "HttpResponse::Ok",
100 ];
101}
102
103impl Rule for TemplateInjectionRule {
104 fn metadata(&self) -> &RuleMetadata {
105 &self.metadata
106 }
107
108 fn evaluate(
109 &self,
110 package: &MirPackage,
111 _inter_analysis: Option<&InterProceduralAnalysis>,
112 ) -> Vec<Finding> {
113 let mut findings = Vec::new();
114
115 for func in &package.functions {
116 let mir_text = func.body.join("\n");
117 let mut tracker = TaintTracker::default();
118
119 for line in mir_text.lines() {
120 let trimmed = line.trim();
121 if trimmed.is_empty() {
122 continue;
123 }
124
125 if let Some(dest) = detect_assignment(trimmed) {
126 if Self::UNTRUSTED_PATTERNS.iter().any(|p| trimmed.contains(p)) {
127 tracker.mark_source(&dest, trimmed);
128 continue;
129 }
130
131 if Self::SANITIZER_PATTERNS.iter().any(|p| trimmed.contains(p)) {
133 if let Some(source) = tracker.find_tainted_in_line(trimmed) {
134 if let Some(root) = tracker.taint_roots.get(&source).cloned() {
135 tracker.sanitize_root(&root);
136 }
137 }
138 continue;
139 }
140
141 if let Some(source) = tracker.find_tainted_in_line(trimmed) {
142 tracker.mark_alias(&dest, &source);
143 }
144 }
145
146 if let Some(sink) = Self::SINK_PATTERNS.iter().find(|p| trimmed.contains(*p)) {
148 let args = extract_call_args(trimmed);
149 for arg in args {
150 if let Some(root) = tracker.taint_roots.get(&arg).cloned() {
151 if tracker.sanitized_roots.contains(&root) {
152 continue;
153 }
154
155 let mut message = format!(
156 "Possible template injection: unescaped input flows into `{}`",
157 sink
158 );
159 if let Some(origin) = tracker.sources.get(&root) {
160 message.push_str(&format!("\n source: `{}`", origin));
161 }
162
163 findings.push(Finding {
164 rule_id: self.metadata.id.clone(),
165 rule_name: self.metadata.name.clone(),
166 severity: self.metadata.default_severity,
167 confidence: Confidence::High,
168 message,
169 function: func.name.clone(),
170 function_signature: func.signature.clone(),
171 evidence: vec![trimmed.to_string()],
172 span: func.span.clone(),
173 exploitability: self.metadata.exploitability.clone(),
174 exploitability_score: self.metadata.exploitability.score(),
175 ..Default::default()
176 });
177 break;
178 }
179 }
180 }
181 }
182 }
183
184 findings
185 }
186}
187
188pub struct UnsafeSendAcrossAsyncBoundaryRule {
194 metadata: RuleMetadata,
195}
196
197impl Default for UnsafeSendAcrossAsyncBoundaryRule {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl UnsafeSendAcrossAsyncBoundaryRule {
204 pub fn new() -> Self {
205 Self {
206 metadata: RuleMetadata {
207 id: "RUSTCOLA206".to_string(),
208 name: "unsafe-send-across-async".to_string(),
209 short_description: "Detects non-Send types captured in spawned tasks".to_string(),
210 full_description:
211 "Types like Rc<T> and RefCell<T> are not thread-safe (non-Send). \
212 When captured by tokio::spawn or similar multi-threaded executors, they can \
213 cause undefined behavior if accessed from multiple threads."
214 .to_string(),
215 help_uri: None,
216 default_severity: Severity::High,
217 origin: RuleOrigin::BuiltIn,
218 cwe_ids: vec!["362".to_string(), "366".to_string()], fix_suggestion: Some(
220 "Use Arc<T> instead of Rc<T> and Arc<Mutex<T>> instead of RefCell<T>. \
221 Alternatively, use spawn_local() for single-threaded execution."
222 .to_string(),
223 ),
224 exploitability: Exploitability {
225 attack_vector: AttackVector::Local,
226 attack_complexity: AttackComplexity::High,
227 privileges_required: PrivilegesRequired::None,
228 user_interaction: UserInteraction::None,
229 },
230 },
231 }
232 }
233
234 const NON_SEND_PATTERNS: &'static [&'static str] = &[
235 "alloc::rc::Rc",
236 "std::rc::Rc",
237 "core::cell::RefCell",
238 "std::cell::RefCell",
239 "alloc::rc::Weak",
240 "std::rc::Weak",
241 ];
242
243 const SAFE_PATTERNS: &'static [&'static str] = &[
244 "std::sync::Arc",
245 "alloc::sync::Arc",
246 "std::sync::Mutex",
247 "tokio::sync::Mutex",
248 ];
249
250 const SPAWN_PATTERNS: &'static [&'static str] = &[
251 "tokio::spawn",
252 "tokio::task::spawn",
253 "tokio::task::spawn_blocking",
254 "async_std::task::spawn",
255 "async_std::task::spawn_blocking",
256 "smol::spawn",
257 "futures::executor::ThreadPool::spawn",
258 "std::thread::spawn",
259 ];
260
261 const SPAWN_LOCAL_PATTERNS: &'static [&'static str] = &[
262 "tokio::task::spawn_local",
263 "async_std::task::spawn_local",
264 "smol::spawn_local",
265 "futures::task::SpawnExt::spawn_local",
266 ];
267}
268
269impl Rule for UnsafeSendAcrossAsyncBoundaryRule {
270 fn metadata(&self) -> &RuleMetadata {
271 &self.metadata
272 }
273
274 fn evaluate(
275 &self,
276 package: &MirPackage,
277 _inter_analysis: Option<&InterProceduralAnalysis>,
278 ) -> Vec<Finding> {
279 let mut findings = Vec::new();
280
281 for func in &package.functions {
282 let mir_text = func.body.join("\n");
283 let mut roots: HashMap<String, String> = HashMap::new();
284 let mut unsafe_roots: HashSet<String> = HashSet::new();
285 let mut safe_roots: HashSet<String> = HashSet::new();
286 let mut sources: HashMap<String, String> = HashMap::new();
287
288 for line in mir_text.lines() {
289 let trimmed = line.trim();
290 if trimmed.is_empty() {
291 continue;
292 }
293
294 if let Some(dest) = detect_assignment(trimmed) {
295 roots.remove(&dest);
296
297 if Self::NON_SEND_PATTERNS.iter().any(|p| trimmed.contains(p)) {
298 roots.insert(dest.clone(), dest.clone());
299 unsafe_roots.insert(dest.clone());
300 safe_roots.remove(&dest);
301 sources.entry(dest).or_insert_with(|| trimmed.to_string());
302 } else if Self::SAFE_PATTERNS.iter().any(|p| trimmed.contains(p)) {
303 roots.insert(dest.clone(), dest.clone());
304 safe_roots.insert(dest.clone());
305 unsafe_roots.remove(&dest);
306 } else if let Some(source) =
307 roots.keys().find(|v| contains_var(trimmed, v)).cloned()
308 {
309 if let Some(root) = roots.get(&source).cloned() {
310 roots.insert(dest, root);
311 }
312 }
313 }
314
315 if let Some(spawn) = Self::SPAWN_PATTERNS.iter().find(|p| trimmed.contains(*p)) {
317 if Self::SPAWN_LOCAL_PATTERNS
319 .iter()
320 .any(|p| trimmed.contains(p))
321 {
322 continue;
323 }
324
325 let args = extract_call_args(trimmed);
326 for arg in args {
327 if let Some(root) = roots.get(&arg).cloned() {
328 if safe_roots.contains(&root) {
329 continue;
330 }
331 if !unsafe_roots.contains(&root) {
332 continue;
333 }
334
335 let mut message = format!(
336 "Non-Send type captured in `{}` may cross thread boundary",
337 spawn
338 );
339 if let Some(origin) = sources.get(&root) {
340 message.push_str(&format!("\n source: `{}`", origin));
341 }
342
343 findings.push(Finding {
344 rule_id: self.metadata.id.clone(),
345 rule_name: self.metadata.name.clone(),
346 severity: self.metadata.default_severity,
347 confidence: Confidence::High,
348 message,
349 function: func.name.clone(),
350 function_signature: func.signature.clone(),
351 evidence: vec![trimmed.to_string()],
352 span: func.span.clone(),
353 exploitability: self.metadata.exploitability.clone(),
354 exploitability_score: self.metadata.exploitability.score(),
355 ..Default::default()
356 });
357 break;
358 }
359 }
360 }
361 }
362 }
363
364 findings
365 }
366}
367
368pub struct AwaitSpanGuardRule {
374 metadata: RuleMetadata,
375}
376
377impl Default for AwaitSpanGuardRule {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
383impl AwaitSpanGuardRule {
384 pub fn new() -> Self {
385 Self {
386 metadata: RuleMetadata {
387 id: "RUSTCOLA207".to_string(),
388 name: "span-guard-across-await".to_string(),
389 short_description: "Detects tracing span guards held across await points"
390 .to_string(),
391 full_description: "Tracing span guards (from span.enter()) should not be held \
392 across await points. When a task is suspended, the span context may be \
393 incorrect when resumed, leading to confusing or incorrect trace data."
394 .to_string(),
395 help_uri: None,
396 default_severity: Severity::Low,
397 origin: RuleOrigin::BuiltIn,
398 cwe_ids: vec!["664".to_string()], fix_suggestion: Some(
400 "Use span.in_scope(|| async { ... }) instead of let _guard = span.enter(). \
401 Or use tracing::Instrument trait: future.instrument(span).await"
402 .to_string(),
403 ),
404 exploitability: Exploitability {
405 attack_vector: AttackVector::Local,
406 attack_complexity: AttackComplexity::High,
407 privileges_required: PrivilegesRequired::Low,
408 user_interaction: UserInteraction::None,
409 },
410 },
411 }
412 }
413
414 const GUARD_PATTERNS: &'static [&'static str] = &[
415 "tracing::Span::enter",
416 "tracing::span::Span::enter",
417 "tracing::dispatcher::DefaultGuard::new",
418 ];
419}
420
/// Bookkeeping for one span guard tracked by `AwaitSpanGuardRule`.
#[derive(Clone)]
struct GuardState {
    /// The MIR statement that created the guard, quoted in findings.
    origin: String,
    /// How many tracked variables currently refer to this guard; the guard is forgotten
    /// when the count would drop to zero.
    count: usize,
}
426
427impl Rule for AwaitSpanGuardRule {
428 fn metadata(&self) -> &RuleMetadata {
429 &self.metadata
430 }
431
432 fn evaluate(
433 &self,
434 package: &MirPackage,
435 _inter_analysis: Option<&InterProceduralAnalysis>,
436 ) -> Vec<Finding> {
437 let mut findings = Vec::new();
438
439 for func in &package.functions {
440 let mir_text = func.body.join("\n");
441 let mut var_to_root: HashMap<String, String> = HashMap::new();
442 let mut guard_states: HashMap<String, GuardState> = HashMap::new();
443 let mut reported: HashSet<String> = HashSet::new();
444
445 for line in mir_text.lines() {
446 let trimmed = line.trim();
447 if trimmed.is_empty() {
448 continue;
449 }
450
451 if let Some((dest, source)) = detect_var_alias(trimmed) {
453 if let Some(root) = var_to_root.remove(&dest) {
454 if let Some(state) = guard_states.get_mut(&root) {
455 if state.count > 1 {
456 state.count -= 1;
457 } else {
458 guard_states.remove(&root);
459 }
460 }
461 }
462 if let Some(root) = var_to_root.get(&source).cloned() {
463 var_to_root.insert(dest, root.clone());
464 if let Some(state) = guard_states.get_mut(&root) {
465 state.count += 1;
466 }
467 }
468 continue;
469 }
470
471 if let Some(dest) = detect_assignment(trimmed) {
473 if Self::GUARD_PATTERNS.iter().any(|p| trimmed.contains(p)) {
474 var_to_root.insert(dest.clone(), dest.clone());
475 guard_states.insert(
476 dest,
477 GuardState {
478 origin: trimmed.to_string(),
479 count: 1,
480 },
481 );
482 }
483 }
484
485 for var in detect_drop_calls(trimmed) {
487 if let Some(root) = var_to_root.remove(&var) {
488 if let Some(state) = guard_states.get_mut(&root) {
489 if state.count > 1 {
490 state.count -= 1;
491 } else {
492 guard_states.remove(&root);
493 }
494 }
495 }
496 }
497
498 for var in detect_storage_dead_vars(trimmed) {
499 if let Some(root) = var_to_root.remove(&var) {
500 if let Some(state) = guard_states.get_mut(&root) {
501 if state.count > 1 {
502 state.count -= 1;
503 } else {
504 guard_states.remove(&root);
505 }
506 }
507 }
508 }
509
510 if trimmed.contains(".await")
512 || trimmed.contains("Await")
513 || trimmed.contains("await ")
514 {
515 for (root, state) in guard_states.iter() {
516 let key = format!("{}::{}", root, trimmed.trim());
517 if !reported.insert(key) {
518 continue;
519 }
520
521 findings.push(Finding {
522 rule_id: self.metadata.id.clone(),
523 rule_name: self.metadata.name.clone(),
524 severity: self.metadata.default_severity,
525 confidence: Confidence::Medium,
526 message: format!(
527 "Span guard `{}` held across await point\n created: `{}`",
528 root, state.origin
529 ),
530 function: func.name.clone(),
531 function_signature: func.signature.clone(),
532 evidence: vec![trimmed.to_string()],
533 span: func.span.clone(),
534 exploitability: self.metadata.exploitability.clone(),
535 exploitability_score: self.metadata.exploitability.score(),
536 ..Default::default()
537 });
538 }
539 }
540 }
541 }
542
543 findings
544 }
545}
546
547pub fn register_advanced_async_rules(engine: &mut crate::RuleEngine) {
553 engine.register_rule(Box::new(TemplateInjectionRule::new()));
554 engine.register_rule(Box::new(UnsafeSendAcrossAsyncBoundaryRule::new()));
555 engine.register_rule(Box::new(AwaitSpanGuardRule::new()));
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the identifier a rule advertises in its metadata.
    fn rule_id<R: Rule>(rule: &R) -> &str {
        &rule.metadata().id
    }

    #[test]
    fn test_template_injection_metadata() {
        assert_eq!(rule_id(&TemplateInjectionRule::new()), "RUSTCOLA205");
    }

    #[test]
    fn test_unsafe_send_metadata() {
        assert_eq!(rule_id(&UnsafeSendAcrossAsyncBoundaryRule::new()), "RUSTCOLA206");
    }

    #[test]
    fn test_span_guard_metadata() {
        assert_eq!(rule_id(&AwaitSpanGuardRule::new()), "RUSTCOLA207");
    }
}