1use super::MirDataflow;
5use crate::{Finding, MirFunction, RuleMetadata, Severity, SourceSpan};
6use std::collections::{HashMap, HashSet, VecDeque};
7
/// One basic block recovered from the textual MIR dump.
#[derive(Debug, Clone)]
struct BasicBlock {
    /// Label parsed from the block's header line (e.g. `bb3`).
    id: String,
    /// Non-empty, non-terminator lines of the block in source order.
    statements: Vec<String>,
    /// The last line of the block recognised as a terminator, if any.
    terminator: Option<String>,
    /// Block labels referenced by the terminator, in order of appearance.
    successors: Vec<String>,
}
16
17struct ControlFlowGraph {
19 blocks: HashMap<String, BasicBlock>,
20 _entry_block: String, }
22
23impl ControlFlowGraph {
24 fn from_mir(function: &MirFunction) -> Self {
26 let mut blocks = HashMap::new();
27 let mut current_block: Option<BasicBlock> = None;
28 let entry_block = "bb0".to_string();
29
30 for line in &function.body {
31 let trimmed = line.trim();
32
33 if trimmed.starts_with("bb") && trimmed.contains(": {") {
35 if let Some(block) = current_block.take() {
37 blocks.insert(block.id.clone(), block);
38 }
39
40 let id = trimmed.split(':').next().unwrap().trim().to_string();
42 current_block = Some(BasicBlock {
43 id,
44 statements: Vec::new(),
45 terminator: None,
46 successors: Vec::new(),
47 });
48 }
49 else if trimmed.contains("goto")
51 || trimmed.contains("switchInt")
52 || trimmed.contains("return")
53 || trimmed.contains("-> [return:")
54 {
55 if let Some(ref mut block) = current_block {
56 block.terminator = Some(trimmed.to_string());
57 block.successors = Self::extract_successors(trimmed);
59 }
60 }
61 else if !trimmed.is_empty() && !trimmed.starts_with("}") && current_block.is_some() {
63 if let Some(ref mut block) = current_block {
64 block.statements.push(trimmed.to_string());
65 }
66 }
67 }
68
69 if let Some(block) = current_block {
71 blocks.insert(block.id.clone(), block);
72 }
73
74 Self {
75 blocks,
76 _entry_block: entry_block,
77 }
78 }
79
80 fn extract_successors(terminator: &str) -> Vec<String> {
82 let mut successors = Vec::new();
83
84 let mut i = 0;
86 let chars: Vec<char> = terminator.chars().collect();
87 while i < chars.len() {
88 if i + 1 < chars.len() && chars[i] == 'b' && chars[i + 1] == 'b' {
89 i += 2;
90 let mut num = String::new();
91 while i < chars.len() && chars[i].is_ascii_digit() {
92 num.push(chars[i]);
93 i += 1;
94 }
95 if !num.is_empty() {
96 successors.push(format!("bb{}", num));
97 }
98 } else {
99 i += 1;
100 }
101 }
102
103 successors
104 }
105
106 fn is_guarded_by(&self, sink_block_id: &str, guard_var: &str) -> bool {
109 for (_block_id, block) in &self.blocks {
111 if let Some(ref terminator) = block.terminator {
112 if terminator.contains("switchInt") && terminator.contains(guard_var) {
115 let successors = &block.successors;
117
118 if successors.len() >= 2 {
121 let true_branch = &successors[1]; return self.is_reachable_from(true_branch, sink_block_id);
125 }
126 }
127 }
128 }
129
130 false
131 }
132
133 fn is_reachable_from(&self, start_block: &str, target_block: &str) -> bool {
135 if start_block == target_block {
136 return true;
137 }
138
139 let mut visited = HashSet::new();
140 let mut queue = VecDeque::new();
141 queue.push_back(start_block.to_string());
142 visited.insert(start_block.to_string());
143
144 while let Some(block_id) = queue.pop_front() {
145 if block_id == target_block {
146 return true;
147 }
148
149 if let Some(block) = self.blocks.get(&block_id) {
150 for successor in &block.successors {
151 if !visited.contains(successor) {
152 visited.insert(successor.clone());
153 queue.push_back(successor.clone());
154 }
155 }
156 }
157 }
158
159 false
160 }
161}
162
/// Category of untrusted input that introduces taint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaintSourceKind {
    /// Value read from a process environment variable.
    EnvironmentVariable,
    /// Data received over the network.
    NetworkInput,
    /// Data read from a file.
    FileInput,
    /// Output captured from a spawned command.
    CommandOutput,
    /// Direct user input.
    UserInput,
}
172
/// Category of sensitive operation that tainted data must not reach unchecked.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaintSinkKind {
    /// Spawning or configuring an external command.
    CommandExecution,
    /// A file-system write, removal or path construction.
    FileSystemOp,
    /// Building or executing an SQL query.
    SqlQuery,
    /// Compiling a regular expression.
    RegexCompile,
    /// Writing data to the network.
    NetworkWrite,
}
182
183#[derive(Debug, Clone)]
185pub struct TaintSource {
186 pub kind: TaintSourceKind,
187 pub variable: String, pub source_line: String, pub confidence: f32, }
191
192#[derive(Debug, Clone)]
194pub struct TaintSink {
195 pub kind: TaintSinkKind,
196 pub variable: String, pub sink_line: String, pub severity: Severity,
199}
200
201pub struct SourceRegistry {
203 patterns: Vec<SourcePattern>,
204}
205
206struct SourcePattern {
207 kind: TaintSourceKind,
208 function_patterns: Vec<&'static str>,
209 _severity: Severity,
210}
211
212impl SourceRegistry {
213 pub fn new() -> Self {
214 Self {
215 patterns: vec![
216 SourcePattern {
217 kind: TaintSourceKind::EnvironmentVariable,
218 function_patterns: vec![
219 " = var::", " = var(", "std::env::var(", "std::env::var_os(",
223 "core::env::var(",
224 "core::env::var_os(",
225 ],
226 _severity: Severity::Medium,
227 },
228 ],
230 }
231 }
232
233 pub fn detect_sources(&self, function: &MirFunction) -> Vec<TaintSource> {
235 let mut sources = Vec::new();
236
237 for line in &function.body {
238 for pattern in &self.patterns {
239 for func_pattern in &pattern.function_patterns {
240 if line.contains(func_pattern) {
241 if let Some(target) = extract_assignment_target(line) {
243 sources.push(TaintSource {
244 kind: pattern.kind.clone(),
245 variable: target,
246 source_line: line.trim().to_string(),
247 confidence: 1.0,
248 });
249 }
250 }
251 }
252 }
253 }
254
255 sources
256 }
257}
258
259pub struct SinkRegistry {
261 patterns: Vec<SinkPattern>,
262}
263
264struct SinkPattern {
265 kind: TaintSinkKind,
266 function_patterns: Vec<&'static str>,
267 severity: Severity,
268}
269
270impl SinkRegistry {
271 pub fn new() -> Self {
272 Self {
273 patterns: vec![
274 SinkPattern {
275 kind: TaintSinkKind::CommandExecution,
276 function_patterns: vec![
277 "Command::new::", "Command::arg::", "Command::args::", ],
281 severity: Severity::High,
282 },
283 SinkPattern {
284 kind: TaintSinkKind::FileSystemOp,
285 function_patterns: vec![
286 "std::fs::write::",
287 "std::fs::remove_file::",
288 "std::fs::remove_dir::",
289 "std::path::Path::join::",
290 ],
291 severity: Severity::Medium,
292 },
293 ],
295 }
296 }
297
298 pub fn detect_sinks(
300 &self,
301 function: &MirFunction,
302 tainted_vars: &HashSet<String>,
303 ) -> Vec<TaintSink> {
304 let mut sinks = Vec::new();
305
306 for line in &function.body {
307 for pattern in &self.patterns {
308 for func_pattern in &pattern.function_patterns {
309 if line.contains(func_pattern) {
310 let used_vars = super::extract_variables(line);
312
313 for var in used_vars {
315 if tainted_vars.contains(&var) {
316 sinks.push(TaintSink {
317 kind: pattern.kind.clone(),
318 variable: var,
319 sink_line: line.trim().to_string(),
320 severity: pattern.severity,
321 });
322 break; }
324 }
325 }
326 }
327 }
328 }
329
330 sinks
331 }
332}
333
334pub struct SanitizerRegistry {
336 pub(crate) patterns: Vec<SanitizerPattern>,
337}
338
339pub(crate) struct SanitizerPattern {
340 pub(crate) function_patterns: Vec<&'static str>,
341 pub(crate) sanitizes: Vec<TaintSinkKind>, }
343
344impl SanitizerRegistry {
345 pub fn new() -> Self {
346 Self {
347 patterns: vec![
348 SanitizerPattern {
349 function_patterns: vec!["::parse::<"],
352 sanitizes: vec![TaintSinkKind::CommandExecution, TaintSinkKind::FileSystemOp],
353 },
354 SanitizerPattern {
355 function_patterns: vec![" as Iterator>::all::<"],
358 sanitizes: vec![TaintSinkKind::CommandExecution, TaintSinkKind::FileSystemOp],
359 },
360 ],
362 }
363 }
364
365 pub fn is_sanitized(
368 &self,
369 function: &MirFunction,
370 var: &str,
371 sink_kind: &TaintSinkKind,
372 ) -> bool {
373 for line in &function.body {
375 if line.contains(var) {
377 for pattern in &self.patterns {
379 if pattern.sanitizes.contains(sink_kind) {
381 for func_pattern in &pattern.function_patterns {
382 if line.contains(func_pattern) {
383 return true;
385 }
386 }
387 }
388 }
389 }
390 }
391 false
392 }
393}
394
395pub struct TaintAnalysis {
397 source_registry: SourceRegistry,
398 sink_registry: SinkRegistry,
399 sanitizer_registry: SanitizerRegistry,
400}
401
402impl TaintAnalysis {
403 pub fn new() -> Self {
404 Self {
405 source_registry: SourceRegistry::new(),
406 sink_registry: SinkRegistry::new(),
407 sanitizer_registry: SanitizerRegistry::new(),
408 }
409 }
410
411 pub fn analyze(&self, function: &MirFunction) -> (HashSet<String>, Vec<TaintFlow>) {
414 let is_target_function = function.name.contains("sanitized_parse")
415 || function.name.contains("sanitized_allowlist");
416
417 if is_target_function {
418 eprintln!(
419 "\n========== ANALYZING TARGET FUNCTION: {} ==========",
420 function.name
421 );
422 }
423
424 let sources = self.source_registry.detect_sources(function);
426
427 if is_target_function {
428 eprintln!("Found {} sources", sources.len());
429
430 eprintln!("\n--- MIR Basic Block Structure ---");
432 for line in &function.body {
433 let trimmed = line.trim();
434 if trimmed.starts_with("bb") && trimmed.contains(':') {
435 eprintln!("{}", trimmed);
436 } else if trimmed.contains("switchInt")
437 || trimmed.contains("goto")
438 || trimmed.contains("return")
439 {
440 eprintln!(" {}", trimmed);
441 }
442 }
443 eprintln!("--- End Basic Blocks ---\n");
444 }
445
446 if sources.is_empty() {
447 return (HashSet::new(), Vec::new());
448 }
449
450 let sanitized_vars = self.detect_sanitized_variables(function, &sources);
453
454 let dataflow = MirDataflow::new(function);
456
457 let mut tainted_vars = HashSet::new();
458 for source in &sources {
459 tainted_vars.insert(source.variable.clone());
460 }
461
462 let tainted = dataflow
464 .taint_from(|assignment| sources.iter().any(|src| assignment.target == src.variable));
465 tainted_vars.extend(tainted);
466
467 let sinks = self.sink_registry.detect_sinks(function, &tainted_vars);
471
472 let mut flows = Vec::new();
474 for sink in sinks {
475 for source in &sources {
477 if tainted_vars.contains(&sink.variable) {
478 let is_sanitized = self.is_flow_sanitized(
480 function,
481 &sink.variable,
482 &sanitized_vars,
483 &tainted_vars,
484 );
485
486 flows.push(TaintFlow {
487 source: source.clone(),
488 sink: sink.clone(),
489 sanitized: is_sanitized,
490 propagation_path: vec![], });
492 break; }
494 }
495 }
496
497 (tainted_vars, flows)
498 }
499
500 fn is_flow_sanitized(
503 &self,
504 function: &MirFunction,
505 sink_var: &str,
506 sanitized_vars: &HashSet<String>,
507 tainted_vars: &HashSet<String>,
508 ) -> bool {
509 let mut depends_on: HashMap<String, HashSet<String>> = HashMap::new();
512
513 for line in &function.body {
514 if let Some(target) = extract_assignment_target(line) {
516 let deps = extract_referenced_variables(line);
518 depends_on
519 .entry(target)
520 .or_insert_with(HashSet::new)
521 .extend(deps);
522 }
523 }
524
525 let mut visited = HashSet::new();
527 let mut queue = VecDeque::new();
528 queue.push_back(sink_var.to_string());
529 visited.insert(sink_var.to_string());
530
531 while let Some(var) = queue.pop_front() {
532 if sanitized_vars.contains(&var) {
534 return true;
535 }
536
537 if let Some(deps) = depends_on.get(&var) {
539 for dep in deps {
540 if !visited.contains(dep) && tainted_vars.contains(dep) {
542 visited.insert(dep.clone());
543 queue.push_back(dep.clone());
544 }
545 }
546 }
547 }
548
549 let cfg = ControlFlowGraph::from_mir(function);
552
553 let sink_block = self.find_sink_block(function, sink_var);
555
556 if let Some(sink_bb) = sink_block {
557 for sanitized_var in sanitized_vars {
559 if cfg.is_guarded_by(&sink_bb, sanitized_var) {
560 return true;
561 }
562 }
563 }
564
565 false
566 }
567
568 fn find_sink_block(&self, function: &MirFunction, sink_var: &str) -> Option<String> {
570 let mut current_block: Option<String> = None;
571
572 for line in &function.body {
573 let trimmed = line.trim();
574
575 if trimmed.starts_with("bb") && trimmed.contains(": {") {
577 current_block = Some(trimmed.split(':').next().unwrap().trim().to_string());
578 }
579 else if trimmed.contains(sink_var) {
581 if trimmed.contains("Command::arg")
583 || trimmed.contains("Command::new")
584 || trimmed.contains("fs::write")
585 || trimmed.contains("fs::remove")
586 || trimmed.contains("Path::join")
587 {
588 return current_block;
589 }
590 }
591 }
592
593 None
594 }
595
596 fn detect_sanitized_variables(
599 &self,
600 function: &MirFunction,
601 _sources: &[TaintSource],
602 ) -> HashSet<String> {
603 let mut sanitized_vars = HashSet::new();
604
605 for line in &function.body {
607 let is_sanitizing = self
609 .sanitizer_registry
610 .patterns
611 .iter()
612 .any(|pattern| pattern.function_patterns.iter().any(|p| line.contains(p)));
613
614 if is_sanitizing {
615 if let Some(target) = extract_assignment_target(line) {
617 sanitized_vars.insert(target);
618 }
619 }
620 }
621
622 sanitized_vars
623 }
624}
625
626#[derive(Debug, Clone)]
628pub struct TaintFlow {
629 pub source: TaintSource,
630 pub sink: TaintSink,
631 pub sanitized: bool,
632 pub propagation_path: Vec<String>, }
634
635impl TaintFlow {
636 pub fn to_finding(
638 &self,
639 rule_metadata: &RuleMetadata,
640 function_name: &str,
641 function_sig: &str,
642 span: Option<SourceSpan>,
643 ) -> Finding {
644 let message = format!(
645 "Tainted data from {} flows to {}{}",
646 format_source_kind(&self.source.kind),
647 format_sink_kind(&self.sink.kind),
648 if self.sanitized {
649 " (sanitized)"
650 } else {
651 " without sanitization"
652 }
653 );
654
655 let evidence = vec![
656 format!("Source: {}", self.source.source_line),
657 format!("Sink: {}", self.sink.sink_line),
658 ];
659
660 Finding::new(
661 rule_metadata.id.clone(),
662 rule_metadata.name.clone(),
663 if self.sanitized {
664 Severity::Low
665 } else {
666 self.sink.severity
667 },
668 message,
669 function_name.to_string(),
670 function_sig.to_string(),
671 evidence,
672 span,
673 )
674 }
675}
676
/// Returns the MIR local assigned by `line`, if any.
///
/// Two forms are recognised:
/// - Plain assignments (`_3 = ...`): the whole left-hand side is returned.
/// - Parenthesised places such as projections (`(_3.0: u32) = ...`) or tuples
///   (`(_3, _4) = ...`): the base local of the first element (`_3`) is returned.
///
/// A local is `_` followed by at least one ASCII digit, so a bare `_` is never
/// returned. Lines without `=` yield `None`.
fn extract_assignment_target(line: &str) -> Option<String> {
    /// Length of the `_<digits>` local at the start of `s`, or 0 if there is none.
    fn local_prefix_len(s: &str) -> usize {
        match s.strip_prefix('_') {
            Some(rest) => {
                let digits = rest.chars().take_while(|c| c.is_ascii_digit()).count();
                if digits == 0 {
                    0
                } else {
                    1 + digits
                }
            }
            None => 0,
        }
    }

    let trimmed = line.trim();
    let lhs = trimmed[..trimmed.find('=')?].trim();

    // Plain assignment: the entire left-hand side is one local.
    let len = local_prefix_len(lhs);
    if len > 0 && len == lhs.len() {
        return Some(lhs.to_string());
    }

    // Parenthesised place: take the base local of the first element.
    let inner = lhs.strip_prefix('(')?.strip_suffix(')')?;
    let first = inner.split(',').next()?.trim();
    let len = local_prefix_len(first);
    if len > 0 {
        Some(first[..len].to_string())
    } else {
        None
    }
}
701
/// Returns every MIR local (`_` followed by digits) referenced on the right-hand
/// side of an assignment line, in order of appearance (duplicates kept).
///
/// A `_` only starts a local when it is not glued to a preceding identifier or
/// digit character. Literals like `1_000_i32` and identifiers like `foo_2` are
/// therefore not mistaken for locals. Lines without `=` yield an empty vector.
fn extract_referenced_variables(line: &str) -> Vec<String> {
    let trimmed = line.trim();
    let rhs = match trimmed.find('=') {
        Some(eq_pos) => &trimmed[eq_pos + 1..],
        None => return Vec::new(),
    };

    let bytes = rhs.as_bytes();
    let mut vars = Vec::new();
    let mut i = 0;

    while i < bytes.len() {
        let glued = i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_');
        if bytes[i] == b'_' && !glued {
            let digits = bytes[i + 1..].iter().take_while(|b| b.is_ascii_digit()).count();
            if digits > 0 {
                // All bytes in this range are ASCII, so the slice is char-boundary safe.
                vars.push(rhs[i..i + 1 + digits].to_string());
                i += 1 + digits;
                continue;
            }
        }
        i += 1;
    }

    vars
}
732
733fn format_source_kind(kind: &TaintSourceKind) -> &'static str {
734 match kind {
735 TaintSourceKind::EnvironmentVariable => "environment variable",
736 TaintSourceKind::NetworkInput => "network input",
737 TaintSourceKind::FileInput => "file input",
738 TaintSourceKind::CommandOutput => "command output",
739 TaintSourceKind::UserInput => "user input",
740 }
741}
742
743fn format_sink_kind(kind: &TaintSinkKind) -> &'static str {
744 match kind {
745 TaintSinkKind::CommandExecution => "command execution",
746 TaintSinkKind::FileSystemOp => "file system operation",
747 TaintSinkKind::SqlQuery => "SQL query",
748 TaintSinkKind::RegexCompile => "regex compilation",
749 TaintSinkKind::NetworkWrite => "network write",
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MirFunction` named `test_fn` whose body is the given MIR lines.
    fn mir_function(body: &[&str]) -> MirFunction {
        MirFunction {
            name: String::from("test_fn"),
            signature: String::from("fn test_fn()"),
            body: body.iter().map(|text| String::from(*text)).collect(),
            span: None,
            ..Default::default()
        }
    }

    #[test]
    fn detects_env_var_source() {
        let func = mir_function(&["_1 = std::env::var(move _2);"]);
        let sources = SourceRegistry::new().detect_sources(&func);

        assert_eq!(sources.len(), 1);
        let only = &sources[0];
        assert_eq!(only.kind, TaintSourceKind::EnvironmentVariable);
        assert_eq!(only.variable, "_1");
    }

    #[test]
    fn detects_command_sink() {
        let func = mir_function(&[
            "_1 = std::env::var(move _2);",
            "_3 = Command::arg::<&str>(move _4, move _1) -> [return: bb1, unwind: bb2];",
        ]);
        let mut tainted: HashSet<String> = HashSet::new();
        tainted.insert(String::from("_1"));

        let sinks = SinkRegistry::new().detect_sinks(&func, &tainted);

        assert_eq!(sinks.len(), 1);
        assert_eq!(sinks[0].kind, TaintSinkKind::CommandExecution);
    }

    #[test]
    fn full_taint_analysis() {
        let func = mir_function(&[
            "_1 = std::env::var(move _2);",
            "_3 = copy _1;",
            "_4 = Command::arg::<&str>(move _5, move _3) -> [return: bb1, unwind: bb2];",
        ]);

        let (tainted_vars, flows) = TaintAnalysis::new().analyze(&func);

        assert!(tainted_vars.contains("_1"));
        assert!(tainted_vars.contains("_3"));
        assert_eq!(flows.len(), 1);
        assert!(!flows[0].sanitized);
    }

    #[test]
    fn no_false_positive_on_hardcoded() {
        let func = mir_function(&[
            "_1 = const \"hardcoded\";",
            "_2 = Command::arg::<&str>(move _3, move _1) -> [return: bb1, unwind: bb2];",
        ]);

        let (_tainted_vars, flows) = TaintAnalysis::new().analyze(&func);

        assert!(flows.is_empty(), "Hardcoded strings should not be tainted");
    }
}