Skip to main content

tldr_cli/commands/
taint.rs

1//! Taint analysis CLI command
2//!
3//! Provides CFG-based taint analysis to detect security vulnerabilities
4//! such as SQL injection, command injection, and code injection.
5//!
6//! # Usage
7//!
8//! ```bash
9//! tldr taint <file> <function> [-f json|text]
10//! ```
11//!
12//! # Output
13//!
14//! - JSON: Full TaintInfo structure with sources, sinks, flows
15//! - Text: Human-readable summary with vulnerability highlights
16//!
17//! # Reference
18//! - session11-taint-spec.md
19
20use std::collections::HashMap;
21use std::path::PathBuf;
22
23use anyhow::Result;
24use clap::Args;
25use colored::Colorize;
26
27use tldr_core::ast::ParserPool;
28use tldr_core::{compute_taint_with_tree, get_cfg_context, get_dfg_context, Language, TaintInfo};
29
30use crate::output::OutputFormat;
31
/// Analyze taint flows in a function to detect security vulnerabilities
// NOTE: this struct derives clap's `Args`, so the `///` comments on the struct
// and on each field are picked up by clap as CLI help text. Edit them as
// user-facing strings; maintainer notes belong in plain `//` comments.
#[derive(Debug, Args)]
pub struct TaintArgs {
    /// Source file to analyze
    // Declared without `#[arg(long)]`/`#[arg(short)]`, so it is a positional
    // argument (matches the `tldr taint <file> <function>` usage in the module
    // docs). `run` checks that the path exists before reading it.
    pub file: PathBuf,

    /// Function name to analyze
    // Second positional argument; forwarded unchanged to `get_cfg_context`
    // and `get_dfg_context` in `run`.
    pub function: String,

    /// Programming language (auto-detected from file extension if not specified)
    // Exposed as `--lang` / `-l`. When absent, `run` tries
    // `Language::from_path(&self.file)` and falls back to `Language::Python`.
    #[arg(long, short = 'l')]
    pub lang: Option<Language>,

    /// Show verbose output with tainted variables per block
    // Exposed as `--verbose` / `-v`. Only consulted on the text output path,
    // where it is passed through to `format_taint_text`.
    #[arg(long, short = 'v')]
    pub verbose: bool,
}
49
50impl TaintArgs {
51    /// Run the taint analysis command
52    pub fn run(&self, format: OutputFormat, quiet: bool) -> Result<()> {
53        use crate::output::OutputWriter;
54
55        let writer = OutputWriter::new(format, quiet);
56
57        // Determine language from file extension or argument
58        let language = self
59            .lang
60            .unwrap_or_else(|| Language::from_path(&self.file).unwrap_or(Language::Python));
61
62        writer.progress(&format!(
63            "Analyzing taint flows for {} in {}...",
64            self.function,
65            self.file.display()
66        ));
67
68        // Read source file - ensure it exists
69        if !self.file.exists() {
70            return Err(anyhow::anyhow!("File not found: {}", self.file.display()));
71        }
72
73        let source = std::fs::read_to_string(&self.file)?;
74
75        // Get CFG for the function
76        let cfg = get_cfg_context(
77            self.file.to_str().unwrap_or_default(),
78            &self.function,
79            language,
80        )?;
81
82        // Get DFG for variable references
83        let dfg = get_dfg_context(
84            self.file.to_str().unwrap_or_default(),
85            &self.function,
86            language,
87        )?;
88
89        // Compute function line range from CFG blocks to scope statements
90        // to only the target function (avoids leaking sources/sinks from
91        // other functions in the same file).
92        let (fn_start, fn_end) = if cfg.blocks.is_empty() {
93            (1u32, source.lines().count() as u32)
94        } else {
95            let start = cfg.blocks.iter().map(|b| b.lines.0).min().unwrap_or(1);
96            let end = cfg
97                .blocks
98                .iter()
99                .map(|b| b.lines.1)
100                .max()
101                .unwrap_or(source.lines().count() as u32);
102            (start, end)
103        };
104
105        // Build statements map scoped to function line range
106        let statements: HashMap<u32, String> = source
107            .lines()
108            .enumerate()
109            .filter(|(i, _)| {
110                let line_num = (i + 1) as u32;
111                line_num >= fn_start && line_num <= fn_end
112            })
113            .map(|(i, line)| ((i + 1) as u32, line.to_string()))
114            .collect();
115
116        // Parse source with tree-sitter for AST-based taint detection
117        let pool = ParserPool::new();
118        let tree = pool.parse(&source, language).ok();
119
120        // Run taint analysis (AST-based when tree available, regex fallback otherwise)
121        let result = compute_taint_with_tree(
122            &cfg,
123            &dfg.refs,
124            &statements,
125            tree.as_ref(),
126            Some(source.as_bytes()),
127            language,
128        )?;
129
130        // Output based on format
131        match format {
132            OutputFormat::Text => {
133                let text = format_taint_text(&result, self.verbose);
134                writer.write_text(&text)?;
135            }
136            OutputFormat::Json | OutputFormat::Compact => {
137                let json = serde_json::to_string_pretty(&result)
138                    .map_err(|e| anyhow::anyhow!("JSON serialization failed: {}", e))?;
139                writer.write_text(&json)?;
140            }
141            OutputFormat::Dot => {
142                // DOT not supported for taint analysis, fall back to JSON
143                let json = serde_json::to_string_pretty(&result)
144                    .map_err(|e| anyhow::anyhow!("JSON serialization failed: {}", e))?;
145                writer.write_text(&json)?;
146            }
147            OutputFormat::Sarif => {
148                // SARIF not supported, fall back to JSON
149                let json = serde_json::to_string_pretty(&result)
150                    .map_err(|e| anyhow::anyhow!("JSON serialization failed: {}", e))?;
151                writer.write_text(&json)?;
152            }
153        }
154
155        Ok(())
156    }
157}
158
159/// Format taint analysis results for human-readable text output
160fn format_taint_text(result: &TaintInfo, verbose: bool) -> String {
161    let mut output = String::new();
162
163    // Header
164    output.push_str(&format!(
165        "{}\n",
166        format!("Taint Analysis: {}", result.function_name)
167            .bold()
168            .cyan()
169    ));
170    output.push_str(&"=".repeat(50));
171    output.push('\n');
172
173    // Sources section
174    output.push_str(&format!(
175        "\n{} ({}):\n",
176        "Sources".bold(),
177        result.sources.len()
178    ));
179    if result.sources.is_empty() {
180        output.push_str("  No taint sources detected.\n");
181    } else {
182        for source in &result.sources {
183            output.push_str(&format!(
184                "  Line {}: {} ({})\n",
185                source.line.to_string().yellow(),
186                source.var.green(),
187                format!("{:?}", source.source_type).cyan()
188            ));
189            if let Some(ref stmt) = source.statement {
190                output.push_str(&format!("    {}\n", stmt.trim().dimmed()));
191            }
192        }
193    }
194
195    // Sinks section
196    output.push_str(&format!("\n{} ({}):\n", "Sinks".bold(), result.sinks.len()));
197    if result.sinks.is_empty() {
198        output.push_str("  No sinks detected.\n");
199    } else {
200        for sink in &result.sinks {
201            let status = if sink.tainted {
202                "TAINTED".red().bold().to_string()
203            } else {
204                "safe".green().to_string()
205            };
206            output.push_str(&format!(
207                "  Line {}: {} ({}) - {}\n",
208                sink.line.to_string().yellow(),
209                sink.var.green(),
210                format!("{:?}", sink.sink_type).cyan(),
211                status
212            ));
213            if let Some(ref stmt) = sink.statement {
214                output.push_str(&format!("    {}\n", stmt.trim().dimmed()));
215            }
216        }
217    }
218
219    // Vulnerabilities section (tainted sinks)
220    let vulns: Vec<_> = result.sinks.iter().filter(|s| s.tainted).collect();
221    output.push_str(&format!(
222        "\n{} ({}):\n",
223        "Vulnerabilities".bold().red(),
224        vulns.len()
225    ));
226    if vulns.is_empty() {
227        output.push_str(&format!("  {}\n", "No vulnerabilities found.".green()));
228    } else {
229        for sink in vulns {
230            output.push_str(&format!(
231                "  {} Line {}: {} flows to {} sink\n",
232                "[!]".red().bold(),
233                sink.line.to_string().yellow(),
234                sink.var.red(),
235                format!("{:?}", sink.sink_type).cyan()
236            ));
237        }
238    }
239
240    // Flows section
241    if !result.flows.is_empty() {
242        output.push_str(&format!(
243            "\n{} ({}):\n",
244            "Taint Flows".bold(),
245            result.flows.len()
246        ));
247        for flow in &result.flows {
248            output.push_str(&format!(
249                "  {} (line {}) -> {} (line {})\n",
250                flow.source.var.green(),
251                flow.source.line,
252                flow.sink.var.red(),
253                flow.sink.line
254            ));
255            if !flow.path.is_empty() {
256                output.push_str(&format!(
257                    "    Path: {}\n",
258                    flow.path
259                        .iter()
260                        .map(|b| b.to_string())
261                        .collect::<Vec<_>>()
262                        .join(" -> ")
263                        .dimmed()
264                ));
265            }
266        }
267    }
268
269    // Verbose: tainted variables per block
270    if verbose && !result.tainted_vars.is_empty() {
271        output.push_str(&format!("\n{}:\n", "Tainted Variables per Block".bold()));
272        let mut blocks: Vec<_> = result.tainted_vars.keys().collect();
273        blocks.sort();
274        for block_id in blocks {
275            if let Some(vars) = result.tainted_vars.get(block_id) {
276                if !vars.is_empty() {
277                    output.push_str(&format!(
278                        "  Block {}: {}\n",
279                        block_id,
280                        vars.iter()
281                            .map(|v| v.as_str())
282                            .collect::<Vec<_>>()
283                            .join(", ")
284                            .yellow()
285                    ));
286                }
287            }
288        }
289    }
290
291    // Sanitized variables
292    if !result.sanitized_vars.is_empty() {
293        output.push_str(&format!(
294            "\n{}: {}\n",
295            "Sanitized Variables".bold(),
296            result
297                .sanitized_vars
298                .iter()
299                .map(|v| v.as_str())
300                .collect::<Vec<_>>()
301                .join(", ")
302                .green()
303        ));
304    }
305
306    output
307}
308
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};
    use std::io::Write;
    use tempfile::NamedTempFile;

    use tldr_core::ast::ParserPool;
    use tldr_core::{
        compute_taint_with_tree, get_cfg_context, get_dfg_context, Language, TaintSinkType,
    };

    // Two functions in one file: `safe_func` spans lines 3-5 and
    // `vulnerable_func` spans lines 7-11. The tests below rely on these exact
    // line numbers, so do not add or remove lines from the fixture.
    const PYTHON_FIXTURE: &str = r#"import os

def safe_func():
    x = "hardcoded"
    os.system(x)

def vulnerable_func(user_input):
    data = input("Enter: ")
    query = "SELECT * FROM users WHERE id = " + data
    os.system(user_input)
    eval(data)
"#;

    /// Helper: write `code` to a temporary `.py` file, build the CFG and DFG
    /// for `function`, and run the taint analysis with the statement map
    /// restricted to the function's line range.
    fn run_taint_on_function(code: &str, function: &str) -> tldr_core::TaintInfo {
        let mut fixture_file = NamedTempFile::with_suffix(".py").unwrap();
        fixture_file.write_all(code.as_bytes()).unwrap();
        fixture_file.flush().unwrap();
        let fixture_path = fixture_file.path().to_str().unwrap();

        let cfg = get_cfg_context(fixture_path, function, Language::Python).unwrap();
        let dfg = get_dfg_context(fixture_path, function, Language::Python).unwrap();

        // Function line range from the CFG blocks; with no blocks, `min`/`max`
        // return `None` and the range covers the whole fixture.
        let total_lines = code.lines().count() as u32;
        let first_line = cfg.blocks.iter().map(|b| b.lines.0).min().unwrap_or(1);
        let last_line = cfg
            .blocks
            .iter()
            .map(|b| b.lines.1)
            .max()
            .unwrap_or(total_lines);

        let mut statements: HashMap<u32, String> = HashMap::new();
        for (idx, text) in code.lines().enumerate() {
            let line_num = (idx + 1) as u32;
            if line_num >= first_line && line_num <= last_line {
                statements.insert(line_num, text.to_string());
            }
        }

        let pool = ParserPool::new();
        let tree = pool.parse(code, Language::Python).ok();

        compute_taint_with_tree(
            &cfg,
            &dfg.refs,
            &statements,
            tree.as_ref(),
            Some(code.as_bytes()),
            Language::Python,
        )
        .unwrap()
    }

    #[test]
    fn test_scoped_to_function() {
        let result = run_taint_on_function(PYTHON_FIXTURE, "vulnerable_func");

        // Every reported source must lie inside vulnerable_func (lines 7-11);
        // anything else leaked from safe_func (lines 3-5) or module level.
        for source in &result.sources {
            assert!(
                (7..=11).contains(&source.line),
                "Source on line {} is outside vulnerable_func's range (7-11). \
                 Leaking from another function! var={}, type={:?}",
                source.line,
                source.var,
                source.source_type
            );
        }

        // Same constraint for sinks.
        for sink in &result.sinks {
            assert!(
                (7..=11).contains(&sink.line),
                "Sink on line {} is outside vulnerable_func's range (7-11). \
                 Leaking from another function! var={}, type={:?}",
                sink.line,
                sink.var,
                sink.sink_type
            );
        }

        // The scoping must not filter out the function's own sources.
        assert!(
            !result.sources.is_empty(),
            "Should detect sources in vulnerable_func"
        );
    }

    #[test]
    fn test_sinks_detected() {
        let result = run_taint_on_function(PYTHON_FIXTURE, "vulnerable_func");

        let found: Vec<_> = result.sinks.iter().map(|s| s.sink_type).collect();

        assert!(
            found.contains(&TaintSinkType::ShellExec),
            "Should detect os.system as ShellExec sink, got: {:?}",
            found
        );
        assert!(
            found.contains(&TaintSinkType::CodeEval),
            "Should detect eval as CodeEval sink, got: {:?}",
            found
        );
    }

    #[test]
    fn test_sources_are_deduplicated() {
        let result = run_taint_on_function(PYTHON_FIXTURE, "vulnerable_func");

        // A source is identified by (line, source type variant, variable).
        let mut source_keys = HashSet::new();
        for source in &result.sources {
            let is_new = source_keys.insert((
                source.line,
                std::mem::discriminant(&source.source_type),
                source.var.as_str(),
            ));
            assert!(
                is_new,
                "Duplicate source: line={}, var={}, type={:?}",
                source.line,
                source.var,
                source.source_type
            );
        }

        // A sink is identified by (line, sink type variant, variable).
        let mut sink_keys = HashSet::new();
        for sink in &result.sinks {
            let is_new = sink_keys.insert((
                sink.line,
                std::mem::discriminant(&sink.sink_type),
                sink.var.as_str(),
            ));
            assert!(
                is_new,
                "Duplicate sink: line={}, var={}, type={:?}",
                sink.line,
                sink.var,
                sink.sink_type
            );
        }
    }
}
474}