Skip to main content

mir_extractor/dataflow/
taint.rs

1// Taint tracking infrastructure for dataflow analysis
2// Tracks untrusted data from sources (env vars, network) to sinks (Command, fs)
3
4use super::MirDataflow;
5use crate::{Finding, MirFunction, RuleMetadata, Severity, SourceSpan};
6use std::collections::{HashMap, HashSet, VecDeque};
7
/// One basic block of a MIR control flow graph, kept as raw text.
#[derive(Debug, Clone)]
struct BasicBlock {
    /// Block label, e.g. "bb0" or "bb1".
    id: String,
    /// Statements that belong to this block.
    statements: Vec<String>,
    /// Terminator text (goto, switchInt, return, etc.), if one was recorded.
    terminator: Option<String>,
    /// Labels of the blocks this block can jump to.
    successors: Vec<String>,
}
16
17/// Control flow graph for a MIR function
18struct ControlFlowGraph {
19    blocks: HashMap<String, BasicBlock>,
20    _entry_block: String, // Usually "bb0"
21}
22
23impl ControlFlowGraph {
24    /// Parse MIR body into a control flow graph
25    fn from_mir(function: &MirFunction) -> Self {
26        let mut blocks = HashMap::new();
27        let mut current_block: Option<BasicBlock> = None;
28        let entry_block = "bb0".to_string();
29
30        for line in &function.body {
31            let trimmed = line.trim();
32
33            // Start of a new basic block
34            if trimmed.starts_with("bb") && trimmed.contains(": {") {
35                // Save previous block
36                if let Some(block) = current_block.take() {
37                    blocks.insert(block.id.clone(), block);
38                }
39
40                // Extract block ID (e.g., "bb0" from "bb0: {")
41                let id = trimmed.split(':').next().unwrap().trim().to_string();
42                current_block = Some(BasicBlock {
43                    id,
44                    statements: Vec::new(),
45                    terminator: None,
46                    successors: Vec::new(),
47                });
48            }
49            // Terminator (goto, switchInt, return, etc.)
50            else if trimmed.contains("goto")
51                || trimmed.contains("switchInt")
52                || trimmed.contains("return")
53                || trimmed.contains("-> [return:")
54            {
55                if let Some(ref mut block) = current_block {
56                    block.terminator = Some(trimmed.to_string());
57                    // Extract successor blocks
58                    block.successors = Self::extract_successors(trimmed);
59                }
60            }
61            // Regular statement in the current block
62            else if !trimmed.is_empty() && !trimmed.starts_with("}") && current_block.is_some() {
63                if let Some(ref mut block) = current_block {
64                    block.statements.push(trimmed.to_string());
65                }
66            }
67        }
68
69        // Save last block
70        if let Some(block) = current_block {
71            blocks.insert(block.id.clone(), block);
72        }
73
74        Self {
75            blocks,
76            _entry_block: entry_block,
77        }
78    }
79
80    /// Extract successor block IDs from a terminator
81    fn extract_successors(terminator: &str) -> Vec<String> {
82        let mut successors = Vec::new();
83
84        // Extract "bbN" patterns
85        let mut i = 0;
86        let chars: Vec<char> = terminator.chars().collect();
87        while i < chars.len() {
88            if i + 1 < chars.len() && chars[i] == 'b' && chars[i + 1] == 'b' {
89                i += 2;
90                let mut num = String::new();
91                while i < chars.len() && chars[i].is_ascii_digit() {
92                    num.push(chars[i]);
93                    i += 1;
94                }
95                if !num.is_empty() {
96                    successors.push(format!("bb{}", num));
97                }
98            } else {
99                i += 1;
100            }
101        }
102
103        successors
104    }
105
106    /// Check if a basic block containing the sink is guarded by a sanitization check
107    /// Returns true if the block is only reachable when the guard variable is true/non-zero
108    fn is_guarded_by(&self, sink_block_id: &str, guard_var: &str) -> bool {
109        // Find the block that contains the switchInt on the guard variable
110        for (_block_id, block) in &self.blocks {
111            if let Some(ref terminator) = block.terminator {
112                // Look for: switchInt(move _3) -> [0: bbX, otherwise: bbY]
113                // where _3 is the guard variable
114                if terminator.contains("switchInt") && terminator.contains(guard_var) {
115                    // Extract the "otherwise" target (where guard is true)
116                    let successors = &block.successors;
117
118                    // The pattern is: switchInt(var) -> [0: bb_false, otherwise: bb_true]
119                    // If there are 2 successors, the second one (or "otherwise") is the true branch
120                    if successors.len() >= 2 {
121                        let true_branch = &successors[1]; // "otherwise" branch
122
123                        // Check if sink_block is reachable from the true branch
124                        return self.is_reachable_from(true_branch, sink_block_id);
125                    }
126                }
127            }
128        }
129
130        false
131    }
132
133    /// Check if target_block is reachable from start_block
134    fn is_reachable_from(&self, start_block: &str, target_block: &str) -> bool {
135        if start_block == target_block {
136            return true;
137        }
138
139        let mut visited = HashSet::new();
140        let mut queue = VecDeque::new();
141        queue.push_back(start_block.to_string());
142        visited.insert(start_block.to_string());
143
144        while let Some(block_id) = queue.pop_front() {
145            if block_id == target_block {
146                return true;
147            }
148
149            if let Some(block) = self.blocks.get(&block_id) {
150                for successor in &block.successors {
151                    if !visited.contains(successor) {
152                        visited.insert(successor.clone());
153                        queue.push_back(successor.clone());
154                    }
155                }
156            }
157        }
158
159        false
160    }
161}
162
/// Categories of taint sources, i.e. where untrusted data originates.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaintSourceKind {
    /// env::var, env::var_os, env::vars_os
    EnvironmentVariable,
    /// TcpStream::read, HttpRequest::body (future)
    NetworkInput,
    /// fs::read, File::read (future)
    FileInput,
    /// Command::output (future)
    CommandOutput,
    /// stdin, readline (future)
    UserInput,
}
172
/// Categories of taint sinks, i.e. security-sensitive operations.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaintSinkKind {
    /// Command::new, Command::arg
    CommandExecution,
    /// fs::write, fs::remove, Path::join
    FileSystemOp,
    /// diesel::sql_query, sqlx::query (future)
    SqlQuery,
    /// Regex::new (future)
    RegexCompile,
    /// TcpStream::write (future)
    NetworkWrite,
}
182
183/// A taint source instance
184#[derive(Debug, Clone)]
185pub struct TaintSource {
186    pub kind: TaintSourceKind,
187    pub variable: String,    // MIR local (_1, _2, etc.)
188    pub source_line: String, // Original code line for reporting
189    pub confidence: f32,     // 0.0-1.0, how certain we are this is tainted
190}
191
192/// A taint sink instance
193#[derive(Debug, Clone)]
194pub struct TaintSink {
195    pub kind: TaintSinkKind,
196    pub variable: String,  // MIR local that reaches sink
197    pub sink_line: String, // Original code line for reporting
198    pub severity: Severity,
199}
200
201/// Registry of patterns that identify taint sources
202pub struct SourceRegistry {
203    patterns: Vec<SourcePattern>,
204}
205
206struct SourcePattern {
207    kind: TaintSourceKind,
208    function_patterns: Vec<&'static str>,
209    _severity: Severity,
210}
211
212impl SourceRegistry {
213    pub fn new() -> Self {
214        Self {
215            patterns: vec![
216                SourcePattern {
217                    kind: TaintSourceKind::EnvironmentVariable,
218                    function_patterns: vec![
219                        " = var::",       // Most common in MIR (fully qualified import)
220                        " = var(",        // Alternative
221                        "std::env::var(", // Full path
222                        "std::env::var_os(",
223                        "core::env::var(",
224                        "core::env::var_os(",
225                    ],
226                    _severity: Severity::Medium,
227                },
228                // Future: Add NetworkInput, FileInput, etc.
229            ],
230        }
231    }
232
233    /// Scan function for taint sources and return detected sources
234    pub fn detect_sources(&self, function: &MirFunction) -> Vec<TaintSource> {
235        let mut sources = Vec::new();
236
237        for line in &function.body {
238            for pattern in &self.patterns {
239                for func_pattern in &pattern.function_patterns {
240                    if line.contains(func_pattern) {
241                        // Extract the target variable (left side of assignment)
242                        if let Some(target) = extract_assignment_target(line) {
243                            sources.push(TaintSource {
244                                kind: pattern.kind.clone(),
245                                variable: target,
246                                source_line: line.trim().to_string(),
247                                confidence: 1.0,
248                            });
249                        }
250                    }
251                }
252            }
253        }
254
255        sources
256    }
257}
258
259/// Registry of patterns that identify taint sinks
260pub struct SinkRegistry {
261    patterns: Vec<SinkPattern>,
262}
263
264struct SinkPattern {
265    kind: TaintSinkKind,
266    function_patterns: Vec<&'static str>,
267    severity: Severity,
268}
269
270impl SinkRegistry {
271    pub fn new() -> Self {
272        Self {
273            patterns: vec![
274                SinkPattern {
275                    kind: TaintSinkKind::CommandExecution,
276                    function_patterns: vec![
277                        "Command::new::",  // With generics in MIR
278                        "Command::arg::",  // With generics in MIR
279                        "Command::args::", // With generics in MIR
280                    ],
281                    severity: Severity::High,
282                },
283                SinkPattern {
284                    kind: TaintSinkKind::FileSystemOp,
285                    function_patterns: vec![
286                        "std::fs::write::",
287                        "std::fs::remove_file::",
288                        "std::fs::remove_dir::",
289                        "std::path::Path::join::",
290                    ],
291                    severity: Severity::Medium,
292                },
293                // Future: Add SqlQuery, RegexCompile, etc.
294            ],
295        }
296    }
297
298    /// Scan function for taint sinks that use specific variables
299    pub fn detect_sinks(
300        &self,
301        function: &MirFunction,
302        tainted_vars: &HashSet<String>,
303    ) -> Vec<TaintSink> {
304        let mut sinks = Vec::new();
305
306        for line in &function.body {
307            for pattern in &self.patterns {
308                for func_pattern in &pattern.function_patterns {
309                    if line.contains(func_pattern) {
310                        // Extract variables used in this sink
311                        let used_vars = super::extract_variables(line);
312
313                        // Check if any tainted variable is used
314                        for var in used_vars {
315                            if tainted_vars.contains(&var) {
316                                sinks.push(TaintSink {
317                                    kind: pattern.kind.clone(),
318                                    variable: var,
319                                    sink_line: line.trim().to_string(),
320                                    severity: pattern.severity,
321                                });
322                                break; // Only report once per line
323                            }
324                        }
325                    }
326                }
327            }
328        }
329
330        sinks
331    }
332}
333
334/// Registry of patterns that sanitize tainted data
335pub struct SanitizerRegistry {
336    pub(crate) patterns: Vec<SanitizerPattern>,
337}
338
339pub(crate) struct SanitizerPattern {
340    pub(crate) function_patterns: Vec<&'static str>,
341    pub(crate) sanitizes: Vec<TaintSinkKind>, // Which sinks does this sanitize for?
342}
343
344impl SanitizerRegistry {
345    pub fn new() -> Self {
346        Self {
347            patterns: vec![
348                SanitizerPattern {
349                    // .parse::<T>() type conversions sanitize for most uses
350                    // Patterns: core::str::<impl str>::parse::<u16>, etc.
351                    function_patterns: vec!["::parse::<"],
352                    sanitizes: vec![TaintSinkKind::CommandExecution, TaintSinkKind::FileSystemOp],
353                },
354                SanitizerPattern {
355                    // .chars().all() validation patterns
356                    // Pattern: <Chars<'_> as Iterator>::all::<{closure@
357                    function_patterns: vec![" as Iterator>::all::<"],
358                    sanitizes: vec![TaintSinkKind::CommandExecution, TaintSinkKind::FileSystemOp],
359                },
360                // Future: Add regex validation, canonicalization, etc.
361            ],
362        }
363    }
364
365    /// Check if a variable is sanitized between source and sink
366    /// Returns true if we detect sanitization patterns in the function body
367    pub fn is_sanitized(
368        &self,
369        function: &MirFunction,
370        var: &str,
371        sink_kind: &TaintSinkKind,
372    ) -> bool {
373        // Look for sanitization patterns that operate on this variable
374        for line in &function.body {
375            // Check if this line involves the variable
376            if line.contains(var) {
377                // Check if it matches any sanitization pattern
378                for pattern in &self.patterns {
379                    // Check if this pattern sanitizes for this sink type
380                    if pattern.sanitizes.contains(sink_kind) {
381                        for func_pattern in &pattern.function_patterns {
382                            if line.contains(func_pattern) {
383                                // Found sanitization!
384                                return true;
385                            }
386                        }
387                    }
388                }
389            }
390        }
391        false
392    }
393}
394
395/// Main taint analysis engine
396pub struct TaintAnalysis {
397    source_registry: SourceRegistry,
398    sink_registry: SinkRegistry,
399    sanitizer_registry: SanitizerRegistry,
400}
401
402impl TaintAnalysis {
403    pub fn new() -> Self {
404        Self {
405            source_registry: SourceRegistry::new(),
406            sink_registry: SinkRegistry::new(),
407            sanitizer_registry: SanitizerRegistry::new(),
408        }
409    }
410
411    /// Perform taint analysis on a function
412    /// Returns (tainted variables, detected flows)
413    pub fn analyze(&self, function: &MirFunction) -> (HashSet<String>, Vec<TaintFlow>) {
414        let is_target_function = function.name.contains("sanitized_parse")
415            || function.name.contains("sanitized_allowlist");
416
417        if is_target_function {
418            eprintln!(
419                "\n========== ANALYZING TARGET FUNCTION: {} ==========",
420                function.name
421            );
422        }
423
424        // Step 1: Detect taint sources
425        let sources = self.source_registry.detect_sources(function);
426
427        if is_target_function {
428            eprintln!("Found {} sources", sources.len());
429
430            // Show basic block structure to understand control flow
431            eprintln!("\n--- MIR Basic Block Structure ---");
432            for line in &function.body {
433                let trimmed = line.trim();
434                if trimmed.starts_with("bb") && trimmed.contains(':') {
435                    eprintln!("{}", trimmed);
436                } else if trimmed.contains("switchInt")
437                    || trimmed.contains("goto")
438                    || trimmed.contains("return")
439                {
440                    eprintln!("  {}", trimmed);
441                }
442            }
443            eprintln!("--- End Basic Blocks ---\n");
444        }
445
446        if sources.is_empty() {
447            return (HashSet::new(), Vec::new());
448        }
449
450        // Step 2: Identify sanitized variables
451        // These are variables that result from sanitizing operations on tainted data
452        let sanitized_vars = self.detect_sanitized_variables(function, &sources);
453
454        // Step 3: Propagate taint through dataflow
455        let dataflow = MirDataflow::new(function);
456
457        let mut tainted_vars = HashSet::new();
458        for source in &sources {
459            tainted_vars.insert(source.variable.clone());
460        }
461
462        // Use existing taint_from to propagate
463        let tainted = dataflow
464            .taint_from(|assignment| sources.iter().any(|src| assignment.target == src.variable));
465        tainted_vars.extend(tainted);
466
467        // Don't remove sanitized vars - we'll check paths instead
468
469        // Step 4: Detect sinks that use tainted data
470        let sinks = self.sink_registry.detect_sinks(function, &tainted_vars);
471
472        // Step 5: Create flows and check if each flow goes through sanitization
473        let mut flows = Vec::new();
474        for sink in sinks {
475            // Find which source(s) contributed to this sink
476            for source in &sources {
477                if tainted_vars.contains(&sink.variable) {
478                    // Check if this sink is sanitized by tracing backward
479                    let is_sanitized = self.is_flow_sanitized(
480                        function,
481                        &sink.variable,
482                        &sanitized_vars,
483                        &tainted_vars,
484                    );
485
486                    flows.push(TaintFlow {
487                        source: source.clone(),
488                        sink: sink.clone(),
489                        sanitized: is_sanitized,
490                        propagation_path: vec![], // Path tracking done at inter-procedural level
491                    });
492                    break; // One source per sink for now
493                }
494            }
495        }
496
497        (tainted_vars, flows)
498    }
499
500    /// Check if a flow from source to sink goes through a sanitization operation
501    /// This includes both dataflow sanitization (parse) and control-flow guards (if checks)
502    fn is_flow_sanitized(
503        &self,
504        function: &MirFunction,
505        sink_var: &str,
506        sanitized_vars: &HashSet<String>,
507        tainted_vars: &HashSet<String>,
508    ) -> bool {
509        // First, check dataflow-based sanitization (e.g., parse())
510        // Build a reverse dependency map: for each variable, track what it depends on
511        let mut depends_on: HashMap<String, HashSet<String>> = HashMap::new();
512
513        for line in &function.body {
514            // Look for assignments: _X = ... _Y ...
515            if let Some(target) = extract_assignment_target(line) {
516                // Extract all variables referenced on the right-hand side
517                let deps = extract_referenced_variables(line);
518                depends_on
519                    .entry(target)
520                    .or_insert_with(HashSet::new)
521                    .extend(deps);
522            }
523        }
524
525        // BFS backward from sink to find if we reach a sanitized variable
526        let mut visited = HashSet::new();
527        let mut queue = VecDeque::new();
528        queue.push_back(sink_var.to_string());
529        visited.insert(sink_var.to_string());
530
531        while let Some(var) = queue.pop_front() {
532            // If this variable is sanitized via dataflow (e.g., parse result), flow is sanitized
533            if sanitized_vars.contains(&var) {
534                return true;
535            }
536
537            // Otherwise, add its dependencies to the queue
538            if let Some(deps) = depends_on.get(&var) {
539                for dep in deps {
540                    // Only follow dependencies that are tainted (part of the taint flow)
541                    if !visited.contains(dep) && tainted_vars.contains(dep) {
542                        visited.insert(dep.clone());
543                        queue.push_back(dep.clone());
544                    }
545                }
546            }
547        }
548
549        // Second, check control-flow-based sanitization (e.g., if chars().all())
550        // Build the control flow graph
551        let cfg = ControlFlowGraph::from_mir(function);
552
553        // Find which basic block contains the sink operation
554        let sink_block = self.find_sink_block(function, sink_var);
555
556        if let Some(sink_bb) = sink_block {
557            // Check if any sanitized variable guards this sink block
558            for sanitized_var in sanitized_vars {
559                if cfg.is_guarded_by(&sink_bb, sanitized_var) {
560                    return true;
561                }
562            }
563        }
564
565        false
566    }
567
568    /// Find which basic block contains the sink operation for the given variable
569    fn find_sink_block(&self, function: &MirFunction, sink_var: &str) -> Option<String> {
570        let mut current_block: Option<String> = None;
571
572        for line in &function.body {
573            let trimmed = line.trim();
574
575            // Track which block we're in
576            if trimmed.starts_with("bb") && trimmed.contains(": {") {
577                current_block = Some(trimmed.split(':').next().unwrap().trim().to_string());
578            }
579            // Look for sink operations that use this variable
580            else if trimmed.contains(sink_var) {
581                // Check if this is a sink operation (Command::arg, fs::write, etc.)
582                if trimmed.contains("Command::arg")
583                    || trimmed.contains("Command::new")
584                    || trimmed.contains("fs::write")
585                    || trimmed.contains("fs::remove")
586                    || trimmed.contains("Path::join")
587                {
588                    return current_block;
589                }
590            }
591        }
592
593        None
594    }
595
596    /// Detect variables that are results of sanitizing operations
597    /// These variables should not propagate taint even if their inputs were tainted
598    fn detect_sanitized_variables(
599        &self,
600        function: &MirFunction,
601        _sources: &[TaintSource],
602    ) -> HashSet<String> {
603        let mut sanitized_vars = HashSet::new();
604
605        // Look for sanitization patterns in the function body
606        for line in &function.body {
607            // Check if this line is a sanitizing operation
608            let is_sanitizing = self
609                .sanitizer_registry
610                .patterns
611                .iter()
612                .any(|pattern| pattern.function_patterns.iter().any(|p| line.contains(p)));
613
614            if is_sanitizing {
615                // Extract the target variable (left side of assignment)
616                if let Some(target) = extract_assignment_target(line) {
617                    sanitized_vars.insert(target);
618                }
619            }
620        }
621
622        sanitized_vars
623    }
624}
625
626/// Represents a complete taint flow from source to sink
627#[derive(Debug, Clone)]
628pub struct TaintFlow {
629    pub source: TaintSource,
630    pub sink: TaintSink,
631    pub sanitized: bool,
632    pub propagation_path: Vec<String>, // Intermediate steps (for debugging)
633}
634
635impl TaintFlow {
636    /// Convert this taint flow into a Finding for reporting
637    pub fn to_finding(
638        &self,
639        rule_metadata: &RuleMetadata,
640        function_name: &str,
641        function_sig: &str,
642        span: Option<SourceSpan>,
643    ) -> Finding {
644        let message = format!(
645            "Tainted data from {} flows to {}{}",
646            format_source_kind(&self.source.kind),
647            format_sink_kind(&self.sink.kind),
648            if self.sanitized {
649                " (sanitized)"
650            } else {
651                " without sanitization"
652            }
653        );
654
655        let evidence = vec![
656            format!("Source: {}", self.source.source_line),
657            format!("Sink: {}", self.sink.sink_line),
658        ];
659
660        Finding::new(
661            rule_metadata.id.clone(),
662            rule_metadata.name.clone(),
663            if self.sanitized {
664                Severity::Low
665            } else {
666                self.sink.severity
667            },
668            message,
669            function_name.to_string(),
670            function_sig.to_string(),
671            evidence,
672            span,
673        )
674    }
675}
676
677// Helper functions
678
/// Extract the MIR local assigned by `line`, if the line is an assignment.
///
/// Recognises `_N = ...` and, for a parenthesised left-hand side such as
/// `(_1, _2) = ...`, returns the first element when it is a local. Returns
/// `None` when the line has no `=`, or when the text before the first `=` is
/// neither form. A bare `_` and names such as `_x` are not MIR locals and are
/// rejected.
fn extract_assignment_target(line: &str) -> Option<String> {
    /// True for `_` followed by one or more ASCII digits.
    fn is_local(token: &str) -> bool {
        token.strip_prefix('_').map_or(false, |digits| {
            !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit())
        })
    }

    let trimmed = line.trim();
    let eq_pos = trimmed.find('=')?;
    let lhs = trimmed[..eq_pos].trim();

    // Simple case: "_1 = ..."
    if is_local(lhs) {
        return Some(lhs.to_string());
    }

    // Tuple form: "(_1, _2) = ..." -- only the first element is reported.
    let inner = lhs.strip_prefix('(')?.strip_suffix(')')?;
    let first = inner.split(',').next()?.trim();
    if is_local(first) {
        Some(first.to_string())
    } else {
        None
    }
}
701
/// Extract all MIR locals referenced on the right-hand side of an assignment.
///
/// E.g., "_1 = add(_2, _3)" returns ["_2", "_3"]. Locals are returned in order
/// of appearance and may repeat. A line without `=` yields an empty vector.
///
/// Only standalone `_N` tokens count: `_2` inside `read_2bytes` or `_000`
/// inside `1_000_u32` are part of a longer identifier or literal and are
/// skipped.
fn extract_referenced_variables(line: &str) -> Vec<String> {
    let trimmed = line.trim();
    let rhs = match trimmed.find('=') {
        Some(eq_pos) => &trimmed[eq_pos + 1..],
        None => return Vec::new(),
    };

    let bytes = rhs.as_bytes();
    let mut vars = Vec::new();
    let mut i = 0;

    while i < bytes.len() {
        let starts_local = bytes[i] == b'_'
            && bytes.get(i + 1).map_or(false, u8::is_ascii_digit)
            && (i == 0 || !(bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_'));
        if !starts_local {
            i += 1;
            continue;
        }

        let start = i;
        i += 1;
        while i < bytes.len() && bytes[i].is_ascii_digit() {
            i += 1;
        }
        // `start` and `i` both border ASCII bytes, so the slice is valid UTF-8.
        vars.push(rhs[start..i].to_string());
    }

    vars
}
732
733fn format_source_kind(kind: &TaintSourceKind) -> &'static str {
734    match kind {
735        TaintSourceKind::EnvironmentVariable => "environment variable",
736        TaintSourceKind::NetworkInput => "network input",
737        TaintSourceKind::FileInput => "file input",
738        TaintSourceKind::CommandOutput => "command output",
739        TaintSourceKind::UserInput => "user input",
740    }
741}
742
743fn format_sink_kind(kind: &TaintSinkKind) -> &'static str {
744    match kind {
745        TaintSinkKind::CommandExecution => "command execution",
746        TaintSinkKind::FileSystemOp => "file system operation",
747        TaintSinkKind::SqlQuery => "SQL query",
748        TaintSinkKind::RegexCompile => "regex compilation",
749        TaintSinkKind::NetworkWrite => "network write",
750    }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `MirFunction` named "test_fn" whose body is `lines`.
    fn make_function(lines: &[&str]) -> MirFunction {
        let body = lines.iter().map(|line| line.to_string()).collect();
        MirFunction {
            name: "test_fn".to_string(),
            signature: "fn test_fn()".to_string(),
            body,
            span: None,
            ..Default::default()
        }
    }

    /// An `std::env::var` call is reported as an environment-variable source.
    #[test]
    fn detects_env_var_source() {
        let function = make_function(&["_1 = std::env::var(move _2);"]);

        let found = SourceRegistry::new().detect_sources(&function);

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].kind, TaintSourceKind::EnvironmentVariable);
        assert_eq!(found[0].variable, "_1");
    }

    /// A `Command::arg` call using a tainted local is reported as a sink.
    #[test]
    fn detects_command_sink() {
        let function = make_function(&[
            "_1 = std::env::var(move _2);",
            "_3 = Command::arg::<&str>(move _4, move _1) -> [return: bb1, unwind: bb2];",
        ]);
        let tainted: HashSet<String> = std::iter::once("_1".to_string()).collect();

        let found = SinkRegistry::new().detect_sinks(&function, &tainted);

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].kind, TaintSinkKind::CommandExecution);
    }

    /// Taint propagates through a copy and produces one unsanitized flow.
    #[test]
    fn full_taint_analysis() {
        let function = make_function(&[
            "_1 = std::env::var(move _2);",
            "_3 = copy _1;",
            "_4 = Command::arg::<&str>(move _5, move _3) -> [return: bb1, unwind: bb2];",
        ]);

        let (tainted_vars, flows) = TaintAnalysis::new().analyze(&function);

        assert!(tainted_vars.contains("_1"));
        assert!(tainted_vars.contains("_3"));
        assert_eq!(flows.len(), 1);
        assert!(!flows[0].sanitized);
    }

    /// A constant argument must not produce any flow.
    #[test]
    fn no_false_positive_on_hardcoded() {
        let function = make_function(&[
            "_1 = const \"hardcoded\";",
            "_2 = Command::arg::<&str>(move _3, move _1) -> [return: bb1, unwind: bb2];",
        ]);

        let (_tainted_vars, flows) = TaintAnalysis::new().analyze(&function);

        assert_eq!(flows.len(), 0, "Hardcoded strings should not be tainted");
    }
}
825}