use std::collections::HashMap;
use std::io::Write;
use tempfile::NamedTempFile;
use tldr_core::cfg::get_cfg_context;
use tldr_core::dataflow::compute_available_exprs;
use tldr_core::dfg::{
build_def_use_chains, build_use_def_chains, compute_reaching_definitions, get_dfg_context,
};
use tldr_core::pdg::{get_slice, get_slice_rich};
use tldr_core::security::compute_taint;
use tldr_core::ssa::{construct_ssa, find_dead_code, SsaType};
use tldr_core::types::RefType;
use tldr_core::{Language, SliceDirection};
fn write_temp(source: &str, ext: &str) -> NamedTempFile {
let mut tmp = NamedTempFile::with_suffix(ext).expect("create tempfile");
tmp.write_all(source.as_bytes()).expect("write tempfile");
tmp.flush().expect("flush tempfile");
tmp
}
const PYTHON_SOURCE: &str = r#"
def example(user_input):
x = 10
y = x + 1
z = x + 1
unused = 42
query = "SELECT * FROM users WHERE name = '" + user_input + "'"
return y
"#;
const JAVASCRIPT_SOURCE: &str = r#"
function example(userInput) {
let x = 10;
let y = x + 1;
let z = x + 1;
let unused = 42;
let query = "SELECT * FROM users WHERE name = '" + userInput + "'";
return y;
}
"#;
const TYPESCRIPT_SOURCE: &str = r#"
function example(userInput: string): number {
let x: number = 10;
let y: number = x + 1;
let z: number = x + 1;
let unused: number = 42;
let query: string = "SELECT * FROM users WHERE name = '" + userInput + "'";
return y;
}
"#;
const GO_SOURCE: &str = r#"
package main
func example(userInput string) int {
x := 10
y := x + 1
z := x + 1
unused := 42
_ = unused
_ = z
query := "SELECT * FROM users WHERE name = '" + userInput + "'"
_ = query
return y
}
"#;
const RUST_SOURCE: &str = r#"
fn example(user_input: &str) -> i32 {
let x = 10;
let y = x + 1;
let z = x + 1;
let unused = 42;
let query = format!("SELECT * FROM users WHERE name = '{}'", user_input);
let _ = (z, unused, query);
y
}
"#;
const JAVA_SOURCE: &str = r#"
public class Example {
public static int example(String userInput) {
int x = 10;
int y = x + 1;
int z = x + 1;
int unused = 42;
String query = "SELECT * FROM users WHERE name = '" + userInput + "'";
return y;
}
}
"#;
const C_SOURCE: &str = r#"
int example(const char* user_input) {
int x = 10;
int y = x + 1;
int z = x + 1;
int unused = 42;
return y;
}
"#;
const CPP_SOURCE: &str = r#"
#include <string>
int example(const std::string& user_input) {
int x = 10;
int y = x + 1;
int z = x + 1;
int unused = 42;
std::string query = "SELECT * FROM users WHERE name = '" + user_input + "'";
return y;
}
"#;
const RUBY_SOURCE: &str = r#"
def example(user_input)
x = 10
y = x + 1
z = x + 1
unused = 42
query = "SELECT * FROM users WHERE name = '#{user_input}'"
return y
end
"#;
const PHP_SOURCE: &str = r#"<?php
function example($user_input) {
$x = 10;
$y = $x + 1;
$z = $x + 1;
$unused = 42;
$query = "SELECT * FROM users WHERE name = '" . $user_input . "'";
return $y;
}
"#;
const PYTHON_TAINT_SOURCE: &str = r#"
def taint_example(user_input):
data = input()
query = "SELECT * FROM users WHERE name = '" + data + "'"
cursor.execute(query)
return data
"#;
fn run_reaching_defs(
source: &str,
func: &str,
lang: Language,
ext: &str,
) -> (
Vec<tldr_core::dfg::DefUseChain>,
Vec<tldr_core::dfg::UseDefChain>,
Vec<tldr_core::types::VarRef>,
) {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
let cfg = get_cfg_context(path, func, lang).expect("CFG extraction failed");
let dfg = get_dfg_context(path, func, lang).expect("DFG extraction failed");
let reaching = compute_reaching_definitions(&cfg, &dfg.refs);
let du_chains = build_def_use_chains(&reaching, &cfg, &dfg.refs);
let ud_chains = build_use_def_chains(&reaching, &cfg, &dfg.refs);
(du_chains, ud_chains, dfg.refs)
}
#[test]
fn test_reaching_defs_python() {
let (du_chains, ud_chains, refs) = run_reaching_defs(PYTHON_SOURCE, "example", Language::Python, ".py");
assert!(
!du_chains.is_empty(),
"Python: expected at least one def-use chain, got 0"
);
let x_defs: Vec<_> = du_chains.iter().filter(|c| c.definition.var == "x").collect();
assert!(
!x_defs.is_empty(),
"Python: expected def-use chain for variable 'x'"
);
let x_has_uses = x_defs.iter().any(|c| !c.uses.is_empty());
assert!(
x_has_uses,
"Python: variable 'x' should have at least one use reached by its definition"
);
assert!(
!ud_chains.is_empty(),
"Python: expected at least one use-def chain"
);
let def_count = refs
.iter()
.filter(|r| matches!(r.ref_type, RefType::Definition))
.count();
assert!(
def_count >= 2,
"Python: expected at least 2 definitions, got {}",
def_count
);
}
#[test]
fn test_reaching_defs_javascript() {
let (du_chains, ud_chains, _) =
run_reaching_defs(JAVASCRIPT_SOURCE, "example", Language::JavaScript, ".js");
assert!(
!du_chains.is_empty(),
"JavaScript: expected at least one def-use chain"
);
let x_chains: Vec<_> = du_chains.iter().filter(|c| c.definition.var == "x").collect();
assert!(
!x_chains.is_empty(),
"JavaScript: expected def-use chain for 'x'"
);
assert!(
x_chains.iter().any(|c| !c.uses.is_empty()),
"JavaScript: 'x' definition should reach uses"
);
assert!(
!ud_chains.is_empty(),
"JavaScript: expected use-def chains"
);
}
#[test]
fn test_reaching_defs_typescript() {
let (du_chains, ud_chains, _) =
run_reaching_defs(TYPESCRIPT_SOURCE, "example", Language::TypeScript, ".ts");
assert!(
!du_chains.is_empty(),
"TypeScript: expected def-use chains"
);
assert!(
!ud_chains.is_empty(),
"TypeScript: expected use-def chains"
);
}
#[test]
fn test_reaching_defs_go() {
let (du_chains, _ud_chains, refs) =
run_reaching_defs(GO_SOURCE, "example", Language::Go, ".go");
assert!(
!du_chains.is_empty(),
"Go: expected def-use chains"
);
let def_count = refs
.iter()
.filter(|r| matches!(r.ref_type, RefType::Definition))
.count();
assert!(
def_count >= 2,
"Go: expected at least 2 definitions, got {}",
def_count
);
}
#[test]
fn test_reaching_defs_rust() {
let (du_chains, ud_chains, _) =
run_reaching_defs(RUST_SOURCE, "example", Language::Rust, ".rs");
assert!(
!du_chains.is_empty(),
"Rust: expected def-use chains"
);
assert!(
!ud_chains.is_empty(),
"Rust: expected use-def chains"
);
}
#[test]
fn test_reaching_defs_java() {
let (du_chains, _ud_chains, refs) =
run_reaching_defs(JAVA_SOURCE, "example", Language::Java, ".java");
assert!(
!du_chains.is_empty(),
"Java: expected def-use chains"
);
let def_count = refs
.iter()
.filter(|r| matches!(r.ref_type, RefType::Definition))
.count();
assert!(
def_count >= 2,
"Java: expected at least 2 definitions, got {}",
def_count
);
}
#[test]
fn test_reaching_defs_c() {
let (du_chains, _ud_chains, _) =
run_reaching_defs(C_SOURCE, "example", Language::C, ".c");
assert!(
!du_chains.is_empty(),
"C: expected def-use chains"
);
}
#[test]
fn test_reaching_defs_cpp() {
let (du_chains, _ud_chains, _) =
run_reaching_defs(CPP_SOURCE, "example", Language::Cpp, ".cpp");
assert!(
!du_chains.is_empty(),
"C++: expected def-use chains"
);
}
#[test]
fn test_reaching_defs_ruby() {
let (du_chains, _ud_chains, _) =
run_reaching_defs(RUBY_SOURCE, "example", Language::Ruby, ".rb");
assert!(
!du_chains.is_empty(),
"Ruby: expected def-use chains"
);
}
#[test]
fn test_reaching_defs_php() {
let (du_chains, _ud_chains, _) =
run_reaching_defs(PHP_SOURCE, "example", Language::Php, ".php");
assert!(
!du_chains.is_empty(),
"PHP: expected def-use chains"
);
}
fn run_available_exprs(
source: &str,
func: &str,
lang: Language,
ext: &str,
) -> tldr_core::dataflow::AvailableExprsInfo {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
let cfg = get_cfg_context(path, func, lang).expect("CFG extraction failed");
let dfg = get_dfg_context(path, func, lang).expect("DFG extraction failed");
compute_available_exprs(&cfg, &dfg).expect("available expressions analysis failed")
}
#[test]
fn test_available_python() {
let avail = run_available_exprs(PYTHON_SOURCE, "example", Language::Python, ".py");
let redundant = avail.redundant_computations();
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"Python: available expressions analysis should produce non-empty results"
);
if !avail.all_exprs.is_empty() {
assert!(
!avail.all_exprs.is_empty(),
"Python: should detect at least one expression, got 0"
);
}
if !redundant.is_empty() {
let (ref expr_text, first_line, second_line) = redundant[0];
assert!(
second_line > first_line,
"Python: redundant computation second occurrence (line {}) should be after first (line {})",
second_line,
first_line
);
assert!(
!expr_text.is_empty(),
"Python: redundant expression text should not be empty"
);
}
}
#[test]
fn test_available_javascript() {
let avail = run_available_exprs(JAVASCRIPT_SOURCE, "example", Language::JavaScript, ".js");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"JavaScript: available expressions should produce results"
);
}
#[test]
fn test_available_typescript() {
let avail = run_available_exprs(TYPESCRIPT_SOURCE, "example", Language::TypeScript, ".ts");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"TypeScript: available expressions should produce results"
);
}
#[test]
fn test_available_go() {
let avail = run_available_exprs(GO_SOURCE, "example", Language::Go, ".go");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"Go: available expressions should produce results"
);
}
#[test]
fn test_available_rust() {
let avail = run_available_exprs(RUST_SOURCE, "example", Language::Rust, ".rs");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"Rust: available expressions should produce results"
);
}
#[test]
fn test_available_java() {
let avail = run_available_exprs(JAVA_SOURCE, "example", Language::Java, ".java");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"Java: available expressions should produce results"
);
}
#[test]
fn test_available_c() {
let avail = run_available_exprs(C_SOURCE, "example", Language::C, ".c");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"C: available expressions should produce results"
);
}
#[test]
fn test_available_cpp() {
let avail = run_available_exprs(CPP_SOURCE, "example", Language::Cpp, ".cpp");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"C++: available expressions should produce results"
);
}
#[test]
fn test_available_ruby() {
let avail = run_available_exprs(RUBY_SOURCE, "example", Language::Ruby, ".rb");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"Ruby: available expressions should produce results"
);
}
#[test]
fn test_available_php() {
let avail = run_available_exprs(PHP_SOURCE, "example", Language::Php, ".php");
assert!(
!avail.avail_in.is_empty() || !avail.avail_out.is_empty() || !avail.all_exprs.is_empty(),
"PHP: available expressions should produce results"
);
}
fn run_dead_stores(
source: &str,
func: &str,
lang: Language,
ext: &str,
) -> Vec<tldr_core::ssa::SsaNameId> {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
let code = std::fs::read_to_string(path).expect("read tempfile");
let ssa = construct_ssa(&code, func, lang, SsaType::Minimal)
.expect("SSA construction failed");
find_dead_code(&ssa).expect("dead code detection failed")
}
#[test]
fn test_dead_stores_python() {
let dead = run_dead_stores(PYTHON_SOURCE, "example", Language::Python, ".py");
assert!(
!dead.is_empty(),
"Python: expected at least one dead store (unused = 42), got 0 dead names"
);
}
#[test]
fn test_dead_stores_javascript() {
let dead = run_dead_stores(JAVASCRIPT_SOURCE, "example", Language::JavaScript, ".js");
assert!(
!dead.is_empty(),
"JavaScript: expected at least one dead store"
);
}
#[test]
fn test_dead_stores_typescript() {
let dead = run_dead_stores(TYPESCRIPT_SOURCE, "example", Language::TypeScript, ".ts");
assert!(
!dead.is_empty(),
"TypeScript: expected at least one dead store"
);
}
#[test]
fn test_dead_stores_go() {
let dead = run_dead_stores(GO_SOURCE, "example", Language::Go, ".go");
let _ = dead; }
#[test]
fn test_dead_stores_rust() {
let dead = run_dead_stores(RUST_SOURCE, "example", Language::Rust, ".rs");
let _ = dead;
}
#[test]
fn test_dead_stores_java() {
let dead = run_dead_stores(JAVA_SOURCE, "example", Language::Java, ".java");
assert!(
!dead.is_empty(),
"Java: expected at least one dead store (unused = 42)"
);
}
#[test]
fn test_dead_stores_c() {
let dead = run_dead_stores(C_SOURCE, "example", Language::C, ".c");
assert!(
!dead.is_empty(),
"C: expected at least one dead store (unused = 42)"
);
}
#[test]
fn test_dead_stores_cpp() {
let dead = run_dead_stores(CPP_SOURCE, "example", Language::Cpp, ".cpp");
assert!(
!dead.is_empty(),
"C++: expected at least one dead store"
);
}
#[test]
fn test_dead_stores_ruby() {
let dead = run_dead_stores(RUBY_SOURCE, "example", Language::Ruby, ".rb");
assert!(
!dead.is_empty(),
"Ruby: expected at least one dead store"
);
}
#[test]
fn test_dead_stores_php() {
let dead = run_dead_stores(PHP_SOURCE, "example", Language::Php, ".php");
assert!(
!dead.is_empty(),
"PHP: expected at least one dead store"
);
}
fn run_backward_slice(
source: &str,
func: &str,
line: u32,
lang: Language,
ext: &str,
) -> std::collections::HashSet<u32> {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
get_slice(path, func, line, SliceDirection::Backward, None, lang)
.expect("backward slice failed")
}
fn run_forward_slice(
source: &str,
func: &str,
line: u32,
lang: Language,
ext: &str,
) -> std::collections::HashSet<u32> {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
get_slice(path, func, line, SliceDirection::Forward, None, lang)
.expect("forward slice failed")
}
#[test]
fn test_slice_python() {
let slice = run_backward_slice(PYTHON_SOURCE, "example", 8, Language::Python, ".py");
assert!(
!slice.is_empty(),
"Python backward slice from return should be non-empty"
);
if slice.contains(&4) {
assert!(
slice.contains(&3),
"Python: backward slice from return y includes y=x+1 (line 4) but not x=10 (line 3)"
);
}
let fwd = run_forward_slice(PYTHON_SOURCE, "example", 3, Language::Python, ".py");
assert!(
!fwd.is_empty(),
"Python forward slice from x = 10 should be non-empty"
);
}
#[test]
fn test_slice_python_rich() {
let tmp = write_temp(PYTHON_SOURCE, ".py");
let path = tmp.path().to_str().unwrap();
let rich = get_slice_rich(path, "example", 8, SliceDirection::Backward, None, Language::Python)
.expect("rich slice failed");
assert!(
!rich.nodes.is_empty(),
"Python: rich backward slice should have nodes"
);
for node in &rich.nodes {
assert!(node.line > 0, "Node line should be positive");
}
if !rich.edges.is_empty() {
for edge in &rich.edges {
assert!(
edge.dep_type == "data" || edge.dep_type == "control",
"Edge dep_type should be 'data' or 'control', got '{}'",
edge.dep_type
);
}
}
}
#[test]
fn test_slice_javascript() {
let slice = run_backward_slice(JAVASCRIPT_SOURCE, "example", 8, Language::JavaScript, ".js");
assert!(
!slice.is_empty(),
"JavaScript: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_typescript() {
let slice = run_backward_slice(TYPESCRIPT_SOURCE, "example", 8, Language::TypeScript, ".ts");
assert!(
!slice.is_empty(),
"TypeScript: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_go() {
let slice = run_backward_slice(GO_SOURCE, "example", 13, Language::Go, ".go");
assert!(
!slice.is_empty(),
"Go: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_rust() {
let slice = run_backward_slice(RUST_SOURCE, "example", 8, Language::Rust, ".rs");
assert!(
!slice.is_empty(),
"Rust: backward slice from return expression should be non-empty"
);
}
#[test]
fn test_slice_java() {
let slice = run_backward_slice(JAVA_SOURCE, "example", 8, Language::Java, ".java");
assert!(
!slice.is_empty(),
"Java: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_c() {
let slice = run_backward_slice(C_SOURCE, "example", 6, Language::C, ".c");
assert!(
!slice.is_empty(),
"C: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_cpp() {
let slice = run_backward_slice(CPP_SOURCE, "example", 9, Language::Cpp, ".cpp");
assert!(
!slice.is_empty(),
"C++: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_ruby() {
let slice = run_backward_slice(RUBY_SOURCE, "example", 8, Language::Ruby, ".rb");
assert!(
!slice.is_empty(),
"Ruby: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_php() {
let slice = run_backward_slice(PHP_SOURCE, "example", 8, Language::Php, ".php");
assert!(
!slice.is_empty(),
"PHP: backward slice from return should be non-empty"
);
}
#[test]
fn test_slice_forward_javascript() {
let fwd = run_forward_slice(JAVASCRIPT_SOURCE, "example", 3, Language::JavaScript, ".js");
assert!(
!fwd.is_empty(),
"JavaScript: forward slice from x definition should be non-empty"
);
}
#[test]
fn test_slice_forward_go() {
let fwd = run_forward_slice(GO_SOURCE, "example", 5, Language::Go, ".go");
assert!(
!fwd.is_empty(),
"Go: forward slice from x definition should be non-empty"
);
}
fn run_chop(
source: &str,
func: &str,
source_line: u32,
target_line: u32,
lang: Language,
ext: &str,
) -> std::collections::HashSet<u32> {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
let forward = get_slice(path, func, source_line, SliceDirection::Forward, None, lang)
.expect("forward slice for chop failed");
let backward = get_slice(path, func, target_line, SliceDirection::Backward, None, lang)
.expect("backward slice for chop failed");
forward.intersection(&backward).copied().collect()
}
#[test]
fn test_chop_python() {
let chop = run_chop(PYTHON_SOURCE, "example", 3, 8, Language::Python, ".py");
assert!(
!chop.is_empty(),
"Python: chop from x definition to return y should be non-empty"
);
}
#[test]
fn test_chop_javascript() {
let chop = run_chop(JAVASCRIPT_SOURCE, "example", 3, 8, Language::JavaScript, ".js");
assert!(
!chop.is_empty(),
"JavaScript: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_typescript() {
let chop = run_chop(TYPESCRIPT_SOURCE, "example", 3, 8, Language::TypeScript, ".ts");
assert!(
!chop.is_empty(),
"TypeScript: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_go() {
let chop = run_chop(GO_SOURCE, "example", 5, 13, Language::Go, ".go");
assert!(
!chop.is_empty(),
"Go: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_rust() {
let chop = run_chop(RUST_SOURCE, "example", 3, 8, Language::Rust, ".rs");
assert!(
!chop.is_empty(),
"Rust: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_java() {
let chop = run_chop(JAVA_SOURCE, "example", 4, 8, Language::Java, ".java");
assert!(
!chop.is_empty(),
"Java: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_c() {
let chop = run_chop(C_SOURCE, "example", 3, 6, Language::C, ".c");
assert!(
!chop.is_empty(),
"C: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_cpp() {
let chop = run_chop(CPP_SOURCE, "example", 5, 9, Language::Cpp, ".cpp");
assert!(
!chop.is_empty(),
"C++: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_ruby() {
let chop = run_chop(RUBY_SOURCE, "example", 3, 8, Language::Ruby, ".rb");
assert!(
!chop.is_empty(),
"Ruby: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_php() {
let chop = run_chop(PHP_SOURCE, "example", 3, 8, Language::Php, ".php");
assert!(
!chop.is_empty(),
"PHP: chop from x to return should be non-empty"
);
}
#[test]
fn test_chop_is_subset_of_both_slices_python() {
let tmp = write_temp(PYTHON_SOURCE, ".py");
let path = tmp.path().to_str().unwrap();
let forward = get_slice(path, "example", 3, SliceDirection::Forward, None, Language::Python)
.expect("forward slice failed");
let backward = get_slice(path, "example", 8, SliceDirection::Backward, None, Language::Python)
.expect("backward slice failed");
let chop: std::collections::HashSet<u32> =
forward.intersection(&backward).copied().collect();
for line in &chop {
assert!(
forward.contains(line),
"Chop line {} not in forward slice",
line
);
}
for line in &chop {
assert!(
backward.contains(line),
"Chop line {} not in backward slice",
line
);
}
}
fn run_taint(
source: &str,
func: &str,
lang: Language,
ext: &str,
) -> tldr_core::security::TaintInfo {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
let cfg = get_cfg_context(path, func, lang).expect("CFG extraction for taint failed");
let dfg = get_dfg_context(path, func, lang).expect("DFG extraction for taint failed");
let code = std::fs::read_to_string(path).expect("read source for taint");
let (fn_start, fn_end) = if cfg.blocks.is_empty() {
(1u32, code.lines().count() as u32)
} else {
let start = cfg.blocks.iter().map(|b| b.lines.0).min().unwrap_or(1);
let end = cfg
.blocks
.iter()
.map(|b| b.lines.1)
.max()
.unwrap_or(code.lines().count() as u32);
(start, end)
};
let statements: HashMap<u32, String> = code
.lines()
.enumerate()
.filter(|(i, _)| {
let line_num = (i + 1) as u32;
line_num >= fn_start && line_num <= fn_end
})
.map(|(i, line)| ((i + 1) as u32, line.to_string()))
.collect();
compute_taint(&cfg, &dfg.refs, &statements, lang).expect("taint analysis failed")
}
#[test]
fn test_taint_python() {
let result = run_taint(PYTHON_TAINT_SOURCE, "taint_example", Language::Python, ".py");
assert!(
!result.sources.is_empty(),
"Python taint: expected at least one source (input()), got 0. \
Sources: {:?}",
result.sources
);
let has_user_input = result
.sources
.iter()
.any(|s| matches!(s.source_type, tldr_core::security::TaintSourceType::UserInput));
assert!(
has_user_input,
"Python taint: expected UserInput source type, got: {:?}",
result.sources.iter().map(|s| &s.source_type).collect::<Vec<_>>()
);
assert!(
!result.sinks.is_empty(),
"Python taint: expected at least one sink (cursor.execute), got 0"
);
if !result.flows.is_empty() {
let flow = &result.flows[0];
assert!(
flow.source.line > 0,
"Taint flow source line should be positive"
);
assert!(
flow.sink.line > 0,
"Taint flow sink line should be positive"
);
assert!(
flow.sink.line > flow.source.line,
"Taint flow: sink (line {}) should be after source (line {})",
flow.sink.line,
flow.source.line
);
}
}
#[test]
fn test_taint_python_tainted_vars() {
let result = run_taint(PYTHON_TAINT_SOURCE, "taint_example", Language::Python, ".py");
let data_tainted = result
.tainted_vars
.values()
.any(|vars| vars.contains("data"));
assert!(
data_tainted,
"Python taint: variable 'data' should be tainted in at least one block. \
tainted_vars: {:?}",
result.tainted_vars
);
}
#[test]
fn test_taint_python_sql_injection_pattern() {
let sql_source = r#"
def sql_vuln():
name = input()
q = "SELECT * FROM users WHERE name = '" + name + "'"
cursor.execute(q)
"#;
let result = run_taint(sql_source, "sql_vuln", Language::Python, ".py");
assert!(
!result.sources.is_empty(),
"SQL injection pattern: should detect input() as source"
);
assert!(
!result.sinks.is_empty(),
"SQL injection pattern: should detect cursor.execute as sink"
);
}
#[test]
fn test_taint_javascript() {
let js_taint = r#"
function vuln(req) {
let userInput = req.query.name;
let query = "SELECT * FROM users WHERE name = '" + userInput + "'";
return query;
}
"#;
let result = run_taint(js_taint, "vuln", Language::JavaScript, ".js");
let _ = result;
}
#[test]
fn test_taint_python_no_vuln() {
let safe_source = r#"
def safe_func():
x = 42
y = x + 1
return y
"#;
let result = run_taint(safe_source, "safe_func", Language::Python, ".py");
assert!(
result.flows.is_empty(),
"Safe function should have no taint flows, got: {:?}",
result.flows
);
}
#[test]
fn test_taint_python_sanitizer() {
let sanitized_source = r#"
def sanitized_func():
data = input()
safe_data = int(data)
return safe_data
"#;
let result = run_taint(sanitized_source, "sanitized_func", Language::Python, ".py");
let has_source = !result.sources.is_empty();
if has_source {
assert!(
result.flows.is_empty(),
"Sanitized function should have no taint flows"
);
}
}
#[test]
fn test_dfg_extraction_all_languages() {
let cases: Vec<(&str, Language, &str, &str)> = vec![
(PYTHON_SOURCE, Language::Python, ".py", "Python"),
(JAVASCRIPT_SOURCE, Language::JavaScript, ".js", "JavaScript"),
(TYPESCRIPT_SOURCE, Language::TypeScript, ".ts", "TypeScript"),
(GO_SOURCE, Language::Go, ".go", "Go"),
(RUST_SOURCE, Language::Rust, ".rs", "Rust"),
(JAVA_SOURCE, Language::Java, ".java", "Java"),
(C_SOURCE, Language::C, ".c", "C"),
(CPP_SOURCE, Language::Cpp, ".cpp", "C++"),
(RUBY_SOURCE, Language::Ruby, ".rb", "Ruby"),
(PHP_SOURCE, Language::Php, ".php", "PHP"),
];
for (source, lang, ext, name) in &cases {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
let dfg = get_dfg_context(path, "example", *lang);
assert!(
dfg.is_ok(),
"{}: DFG extraction failed: {:?}",
name,
dfg.err()
);
let dfg = dfg.unwrap();
assert!(
!dfg.refs.is_empty(),
"{}: DFG should have at least one variable reference",
name
);
}
}
#[test]
fn test_cfg_dfg_pair_all_languages() {
let cases: Vec<(&str, Language, &str, &str)> = vec![
(PYTHON_SOURCE, Language::Python, ".py", "Python"),
(JAVASCRIPT_SOURCE, Language::JavaScript, ".js", "JavaScript"),
(TYPESCRIPT_SOURCE, Language::TypeScript, ".ts", "TypeScript"),
(GO_SOURCE, Language::Go, ".go", "Go"),
(RUST_SOURCE, Language::Rust, ".rs", "Rust"),
(JAVA_SOURCE, Language::Java, ".java", "Java"),
(C_SOURCE, Language::C, ".c", "C"),
(CPP_SOURCE, Language::Cpp, ".cpp", "C++"),
(RUBY_SOURCE, Language::Ruby, ".rb", "Ruby"),
(PHP_SOURCE, Language::Php, ".php", "PHP"),
];
for (source, lang, ext, name) in &cases {
let tmp = write_temp(source, ext);
let path = tmp.path().to_str().unwrap();
let cfg = get_cfg_context(path, "example", *lang);
assert!(
cfg.is_ok(),
"{}: CFG extraction failed: {:?}",
name,
cfg.err()
);
let cfg = cfg.unwrap();
assert!(
!cfg.blocks.is_empty(),
"{}: CFG should have at least one block",
name
);
let dfg = get_dfg_context(path, "example", *lang);
assert!(
dfg.is_ok(),
"{}: DFG extraction failed: {:?}",
name,
dfg.err()
);
}
}
#[test]
fn test_reaching_defs_variable_names_python() {
let (du_chains, _ud_chains, _) =
run_reaching_defs(PYTHON_SOURCE, "example", Language::Python, ".py");
let defined_vars: Vec<&str> = du_chains.iter().map(|c| c.definition.var.as_str()).collect();
assert!(
defined_vars.contains(&"x"),
"Python reaching-defs: should define 'x', got: {:?}",
defined_vars
);
assert!(
defined_vars.contains(&"y"),
"Python reaching-defs: should define 'y', got: {:?}",
defined_vars
);
}
#[test]
fn test_dead_stores_cross_check_python() {
let tmp = write_temp(PYTHON_SOURCE, ".py");
let path = tmp.path().to_str().unwrap();
let code = std::fs::read_to_string(path).expect("read tempfile");
let ssa = construct_ssa(&code, "example", Language::Python, SsaType::Minimal)
.expect("SSA construction failed");
let dead = find_dead_code(&ssa).expect("dead code detection failed");
let dead_var_names: Vec<&str> = dead
.iter()
.filter_map(|id| {
ssa.ssa_names
.iter()
.find(|n| n.id == *id)
.map(|n| n.variable.as_str())
})
.collect();
assert!(
dead_var_names.contains(&"unused"),
"Python dead-stores: 'unused' should be dead, got dead vars: {:?}",
dead_var_names
);
assert!(
!dead_var_names.is_empty(),
"Python dead-stores: should have at least one dead variable"
);
for name in &dead_var_names {
assert!(
!name.is_empty(),
"Dead variable names should be non-empty strings"
);
}
}
#[test]
fn test_slice_transitive_deps_python() {
let slice = run_backward_slice(PYTHON_SOURCE, "example", 8, Language::Python, ".py");
assert!(
!slice.is_empty(),
"Python: backward slice from return y should not be empty"
);
assert!(
slice.contains(&3),
"Python: backward slice from return y should include x = 10 (line 3). \
Slice contains: {:?}",
slice
);
assert!(
slice.contains(&4),
"Python: backward slice from return y should include y = x + 1 (line 4). \
Slice contains: {:?}",
slice
);
assert!(
slice.contains(&8),
"Python: backward slice should include the return line itself (line 8). \
Slice contains: {:?}",
slice
);
}
#[test]
fn test_slice_direction_difference_python() {
let tmp = write_temp(PYTHON_SOURCE, ".py");
let path = tmp.path().to_str().unwrap();
let forward = get_slice(path, "example", 3, SliceDirection::Forward, None, Language::Python)
.expect("forward slice failed");
let backward = get_slice(path, "example", 8, SliceDirection::Backward, None, Language::Python)
.expect("backward slice failed");
assert!(!forward.is_empty(), "Forward slice should be non-empty");
assert!(!backward.is_empty(), "Backward slice should be non-empty");
let intersection: std::collections::HashSet<u32> =
forward.intersection(&backward).copied().collect();
assert!(
!intersection.is_empty(),
"Forward and backward slices should overlap (chop should be non-empty)"
);
}