use crate::detectors::taint::TaintCategory;
use crate::parsers::lightweight::Language;
use std::collections::{HashMap, HashSet};
/// Outcome of running [`HeuristicFlow::analyze_intra_function`] over one function body.
///
/// `Default` yields an empty result (nothing tainted, no sink reaches, nothing sanitized).
#[derive(Debug, Clone, Default)]
// NOTE(review): within this file `tainted_vars` and `sanitized_vars` are only read by the
// tests, which is presumably why `dead_code` is allowed here; confirm before removing it.
#[allow(dead_code)]
pub struct IntraFlowResult {
    /// Variables considered tainted, keyed by name, with the source that tainted them.
    /// A variable may appear both here and in `sanitized_vars`.
    pub tainted_vars: HashMap<String, TaintSource>,
    /// Tainted, unsanitized variables observed on lines that matched a sink pattern.
    pub sink_reaches: Vec<SinkReach>,
    /// Variables whose assigned value was judged to have passed through a sanitizer.
    pub sanitized_vars: HashSet<String>,
}
/// Records which source pattern tainted a variable, and on which line.
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): neither field is read outside construction/cloning in this file, which is
// presumably why `dead_code` is allowed; confirm before removing it.
#[allow(dead_code)]
pub struct TaintSource {
    /// The caller-supplied source pattern (original casing) that matched the assignment.
    pub pattern: String,
    /// 1-based line number of the assignment that matched `pattern`.
    pub line: usize,
}
/// A tainted variable observed on a line that matched a sink pattern.
#[derive(Debug, Clone)]
pub struct SinkReach {
    /// Name of the tainted variable found on the sink line.
    pub variable: String,
    /// Where `variable` originally picked up its taint.
    pub taint_source: TaintSource,
    /// The sink pattern that matched, in the casing the caller supplied (not lowercased).
    pub sink_pattern: String,
    /// 1-based line number of the sink within the analyzed function source.
    pub sink_line: usize,
    /// Whether the flow was judged sanitized; `check_sink_call` currently always sets `false`.
    pub is_sanitized: bool,
    /// Heuristic confidence score; `check_sink_call` currently always uses `0.85`.
    pub confidence: f64,
}
/// Stateless, line-oriented heuristic taint tracker for a single function body.
///
/// Deriving `Default` pairs with the inherent `new()` (clippy `new_without_default`).
#[derive(Debug, Clone, Copy, Default)]
pub struct HeuristicFlow;
impl HeuristicFlow {
pub fn new() -> Self {
Self
}
pub fn analyze_intra_function(
&self,
func_source: &str,
language: Language,
category: TaintCategory,
sources: &HashSet<String>,
sinks: &HashSet<String>,
sanitizers: &HashSet<String>,
) -> IntraFlowResult {
let _ = category;
let sources_lower: Vec<(String, String)> = sources
.iter()
.map(|s| (s.clone(), s.to_lowercase()))
.collect();
let sinks_lower: Vec<(String, String)> = sinks
.iter()
.map(|s| (s.clone(), s.to_lowercase()))
.collect();
let sanitizers_lower: Vec<String> = sanitizers.iter().map(|s| s.to_lowercase()).collect();
let mut tainted: HashMap<String, TaintSource> = HashMap::new();
let mut sanitized: HashSet<String> = HashSet::new();
let mut sink_reaches: Vec<SinkReach> = Vec::new();
for (line_idx, line) in func_source.lines().enumerate() {
let line_num = line_idx + 1;
let trimmed = line.trim();
if trimmed.is_empty()
|| trimmed.starts_with('#')
|| trimmed.starts_with("//")
|| trimmed.starts_with('*')
{
continue;
}
if let Some((lhs, rhs)) = self.parse_assignment(line, language) {
let rhs_lower = rhs.to_lowercase();
let matched_source = sources_lower
.iter()
.find(|(_, lowered)| rhs_lower.contains(lowered.as_str()));
if let Some((original, _)) = matched_source {
tainted.insert(
lhs.to_string(),
TaintSource {
pattern: original.clone(),
line: line_num,
},
);
sanitized.remove(lhs);
continue;
}
if let Some(source_var) = self.rhs_references_tainted(rhs, &tainted) {
if self.is_sanitizer_call(rhs, &sanitizers_lower) {
sanitized.insert(lhs.to_string());
} else {
if let Some(source) = tainted.get(&source_var) {
tainted.insert(lhs.to_string(), source.clone());
sanitized.remove(lhs);
}
}
continue;
}
if self.is_sanitizer_call(rhs, &sanitizers_lower) && tainted.contains_key(lhs) {
sanitized.insert(lhs.to_string());
continue;
}
}
let mut reaches =
self.check_sink_call(trimmed, line_num, &tainted, &sinks_lower, &sanitized);
sink_reaches.append(&mut reaches);
}
IntraFlowResult {
tainted_vars: tainted,
sink_reaches,
sanitized_vars: sanitized,
}
}
fn parse_assignment<'a>(&self, line: &'a str, lang: Language) -> Option<(&'a str, &'a str)> {
let trimmed = line.trim();
if trimmed.starts_with('#')
|| trimmed.starts_with("//")
|| trimmed.starts_with('*')
|| trimmed.starts_with("/*")
{
return None;
}
match lang {
Language::Python => {
if let Some(eq_pos) = trimmed.find('=') {
if eq_pos > 0
&& !trimmed[eq_pos..].starts_with("==")
&& !matches!(
trimmed.as_bytes().get(eq_pos - 1),
Some(b'!' | b'<' | b'>' | b'=')
)
{
let lhs = trimmed[..eq_pos].trim();
let rhs = trimmed[eq_pos + 1..].trim();
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
}
}
}
Language::JavaScript | Language::TypeScript => {
let stripped = trimmed
.strip_prefix("const ")
.or_else(|| trimmed.strip_prefix("let "))
.or_else(|| trimmed.strip_prefix("var "))
.unwrap_or(trimmed);
if let Some(eq_pos) = stripped.find('=') {
if eq_pos > 0
&& !stripped[eq_pos..].starts_with("==")
&& !stripped[eq_pos..].starts_with("=>")
&& !matches!(
stripped.as_bytes().get(eq_pos - 1),
Some(b'!' | b'<' | b'>' | b'=')
)
{
let lhs = stripped[..eq_pos].trim();
let lhs = lhs.split(':').next().unwrap_or(lhs).trim();
let rhs = stripped[eq_pos + 1..].trim();
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
}
}
}
Language::Go => {
if let Some(pos) = trimmed.find(":=") {
let lhs = trimmed[..pos].trim();
let rhs = trimmed[pos + 2..].trim();
let lhs = lhs.split(',').next().unwrap_or(lhs).trim();
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
} else if let Some(eq_pos) = trimmed.find('=') {
if eq_pos > 0
&& !trimmed[eq_pos..].starts_with("==")
&& !matches!(
trimmed.as_bytes().get(eq_pos - 1),
Some(b'!' | b'<' | b'>' | b'=')
)
{
let lhs = trimmed[..eq_pos].trim();
let rhs = trimmed[eq_pos + 1..].trim();
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
}
}
}
Language::Rust => {
let stripped = trimmed
.strip_prefix("let ")
.map(|s| s.strip_prefix("mut ").unwrap_or(s));
if let Some(stripped) = stripped {
if let Some(eq_pos) = stripped.find('=') {
if !stripped[eq_pos..].starts_with("==") {
let lhs = stripped[..eq_pos].trim();
let lhs = lhs.split(':').next().unwrap_or(lhs).trim();
let rhs = stripped[eq_pos + 1..].trim();
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
}
}
}
if !trimmed.starts_with("let ") {
if let Some(eq_pos) = trimmed.find('=') {
if eq_pos > 0
&& !trimmed[eq_pos..].starts_with("==")
&& !trimmed[eq_pos..].starts_with("=>")
&& !matches!(
trimmed.as_bytes().get(eq_pos - 1),
Some(b'!' | b'<' | b'>' | b'=')
)
{
let lhs = trimmed[..eq_pos].trim();
let rhs = trimmed[eq_pos + 1..].trim();
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
}
}
}
}
Language::Java | Language::CSharp | Language::Kotlin => {
if let Some(eq_pos) = trimmed.find('=') {
if eq_pos > 0
&& !trimmed[eq_pos..].starts_with("==")
&& !matches!(
trimmed.as_bytes().get(eq_pos - 1),
Some(b'!' | b'<' | b'>' | b'=')
)
{
let lhs_full = trimmed[..eq_pos].trim();
let rhs = trimmed[eq_pos + 1..].trim();
let lhs = lhs_full.split_whitespace().last().unwrap_or(lhs_full);
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
}
}
}
_ => {
if let Some(eq_pos) = trimmed.find('=') {
if eq_pos > 0
&& !trimmed[eq_pos..].starts_with("==")
&& !matches!(
trimmed.as_bytes().get(eq_pos - 1),
Some(b'!' | b'<' | b'>' | b'=')
)
{
let lhs = trimmed[..eq_pos].trim();
let rhs = trimmed[eq_pos + 1..].trim();
if is_simple_var(lhs) && !rhs.is_empty() {
return Some((lhs, rhs));
}
}
}
}
}
None
}
fn rhs_references_tainted(
&self,
rhs: &str,
tainted: &HashMap<String, TaintSource>,
) -> Option<String> {
for var in tainted.keys() {
if rhs_contains_var(rhs, var) {
return Some(var.clone());
}
}
None
}
fn is_sanitizer_call(&self, rhs: &str, sanitizers_lower: &[String]) -> bool {
let rhs_lower = rhs.to_lowercase();
sanitizers_lower
.iter()
.any(|s| rhs_lower.contains(s.as_str()))
}
fn check_sink_call(
&self,
line: &str,
line_num: usize,
tainted: &HashMap<String, TaintSource>,
sinks_lower: &[(String, String)],
sanitized: &HashSet<String>,
) -> Vec<SinkReach> {
let line_lower = line.to_lowercase();
let matching_sinks: Vec<&(String, String)> = sinks_lower
.iter()
.filter(|(_, lowered)| line_lower.contains(lowered.as_str()))
.collect();
if matching_sinks.is_empty() {
return Vec::new();
}
let mut reaches = Vec::new();
for (original, lowered) in matching_sinks {
for (var, source) in tainted {
if sanitized.contains(var) {
continue;
}
if line_contains_var_in_call(line, lowered, var) {
reaches.push(SinkReach {
variable: var.clone(),
taint_source: source.clone(),
sink_pattern: original.clone(),
sink_line: line_num,
is_sanitized: false,
confidence: 0.85,
});
}
}
}
reaches
}
}
/// Reports whether `s` is a bare identifier.
///
/// The first character must be a letter or `_`; the rest must be letters, digits or `_`.
/// Character classes are Unicode-aware. Dotted paths, indexing expressions and the empty
/// string are rejected.
fn is_simple_var(s: &str) -> bool {
    let mut chars = s.chars();
    match chars.next() {
        Some(head) if head.is_alphabetic() || head == '_' => {
            chars.all(|c| c.is_alphanumeric() || c == '_')
        }
        _ => false,
    }
}
/// Returns `true` if `var` occurs in `rhs` as a whole identifier.
///
/// "Whole identifier" means the occurrence is not immediately preceded or followed by an
/// identifier character (alphanumeric or `_`).
///
/// Boundaries are checked per `char` rather than per byte. This means non-ASCII identifiers
/// get correct word boundaries (`éfoo` does not contain `foo`), and the search never slices
/// inside a multi-byte character. An empty `var` never matches.
fn rhs_contains_var(rhs: &str, var: &str) -> bool {
    let Some(first) = var.chars().next() else {
        return false;
    };
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';

    let mut search_from = 0;
    while let Some(pos) = rhs[search_from..].find(var) {
        let start = search_from + pos;
        let end = start + var.len();
        let before_ok = !rhs[..start].chars().next_back().is_some_and(is_ident);
        let after_ok = !rhs[end..].chars().next().is_some_and(is_ident);
        if before_ok && after_ok {
            return true;
        }
        // Advance past one whole character so the next slice starts on a char boundary.
        search_from = start + first.len_utf8();
    }
    false
}
/// Reports whether tainted `var` appears on `line` as a whole identifier.
///
/// `check_sink_call` calls this only for lines that already matched a sink pattern.
///
/// NOTE(review): despite the name, the sink pattern (`_sink_lower`) is not consulted. Any
/// whole-identifier occurrence of `var` anywhere on `line` counts, as decided by
/// `rhs_contains_var`.
fn line_contains_var_in_call(line: &str, _sink_lower: &str, var: &str) -> bool {
    rhs_contains_var(line, var)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the heuristic intra-function taint flow.

    use super::*;

    /// Collects string literals into the owned pattern sets the analyzer expects.
    fn set_of(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| String::from(*item)).collect()
    }

    fn sources() -> HashSet<String> {
        set_of(&[
            "request.args",
            "request.form",
            "req.body",
            "req.query",
            "req.params",
            "params[",
        ])
    }

    fn sql_sinks() -> HashSet<String> {
        set_of(&["execute", "executemany", "raw_sql", "query(", "db.run"])
    }

    fn sanitizers() -> HashSet<String> {
        set_of(&[
            "escape",
            "sanitize",
            "parameterize",
            "prepare",
            "bindparam",
            "html.escape",
        ])
    }

    /// Runs a fresh analyzer over `code` with the shared sanitizer set.
    fn run(
        code: &str,
        language: Language,
        category: TaintCategory,
        srcs: &HashSet<String>,
        sinks: &HashSet<String>,
    ) -> IntraFlowResult {
        HeuristicFlow::new().analyze_intra_function(
            code,
            language,
            category,
            srcs,
            sinks,
            &sanitizers(),
        )
    }

    /// Shorthand for the SQL-injection setup with the default sources and sinks.
    fn run_sql(code: &str, language: Language) -> IntraFlowResult {
        run(
            code,
            language,
            TaintCategory::SqlInjection,
            &sources(),
            &sql_sinks(),
        )
    }

    #[test]
    fn test_python_basic_taint_flow() {
        let code = r#"
user_input = request.args.get("q")
query = "SELECT * FROM t WHERE x = '" + user_input + "'"
cursor.execute(query)
"#;
        let result = run_sql(code, Language::Python);
        assert!(
            !result.sink_reaches.is_empty(),
            "Should detect taint reaching execute()"
        );
        let first = &result.sink_reaches[0];
        assert_eq!(first.sink_pattern, "execute");
        assert!(!first.is_sanitized);
    }

    #[test]
    fn test_python_fstring_propagation() {
        let code = r#"
user_input = request.args.get("q")
query = f"SELECT * FROM t WHERE x = '{user_input}'"
cursor.execute(query)
"#;
        let result = run_sql(code, Language::Python);
        assert!(
            !result.sink_reaches.is_empty(),
            "Should detect f-string taint propagation"
        );
    }

    #[test]
    fn test_python_sanitized_flow() {
        let code = r#"
user_input = request.args.get("q")
clean_input = escape(user_input)
query = f"SELECT * FROM t WHERE x = '{clean_input}'"
cursor.execute(query)
"#;
        let result = run_sql(code, Language::Python);
        assert!(
            result.sink_reaches.iter().all(|r| r.is_sanitized),
            "Sanitized flow should not be flagged"
        );
        assert!(result.sanitized_vars.contains("clean_input"));
    }

    #[test]
    fn test_javascript_taint_flow() {
        let code = r#"
const userInput = req.body.username;
const query = "SELECT * FROM users WHERE name = '" + userInput + "'";
db.run(query);
"#;
        let result = run_sql(code, Language::JavaScript);
        assert!(
            !result.sink_reaches.is_empty(),
            "Should detect JS taint flow"
        );
    }

    #[test]
    fn test_go_taint_flow() {
        let code = r#"
userInput := req.query.Get("name")
query := "SELECT * FROM users WHERE name = '" + userInput + "'"
db.run(query)
"#;
        let result = run_sql(code, Language::Go);
        assert!(
            !result.sink_reaches.is_empty(),
            "Should detect Go taint flow"
        );
    }

    #[test]
    fn test_rust_taint_flow() {
        let code = r#"
let user_input = req.query("name");
let query = format!("SELECT * FROM users WHERE name = '{}'", user_input);
db.run(&query);
"#;
        let mut srcs = sources();
        srcs.insert("req.query".to_string());
        let result = run(
            code,
            Language::Rust,
            TaintCategory::SqlInjection,
            &srcs,
            &sql_sinks(),
        );
        assert!(
            result.tainted_vars.contains_key("user_input"),
            "user_input should be tainted"
        );
    }

    #[test]
    fn test_no_taint_no_findings() {
        let code = r#"
x = 42
y = x + 1
print(y)
"#;
        let result = run_sql(code, Language::Python);
        assert!(
            result.sink_reaches.is_empty(),
            "No taint sources means no findings"
        );
        assert!(result.tainted_vars.is_empty());
    }

    #[test]
    fn test_taint_propagation_chain() {
        let code = r#"
raw = request.args.get("input")
step1 = raw
step2 = step1
step3 = step2
cursor.execute(step3)
"#;
        let result = run_sql(code, Language::Python);
        assert!(
            result.tainted_vars.contains_key("step3"),
            "Taint should propagate through chain"
        );
        assert!(
            !result.sink_reaches.is_empty(),
            "Should detect taint at end of chain"
        );
    }

    #[test]
    fn test_command_injection_flow() {
        let cmd_sinks = set_of(&[
            "system",
            "exec",
            "popen",
            "subprocess.run",
            "subprocess.call",
        ]);
        let code = r#"
filename = request.form.get("file")
cmd = "cat " + filename
os.system(cmd)
"#;
        let result = run(
            code,
            Language::Python,
            TaintCategory::CommandInjection,
            &sources(),
            &cmd_sinks,
        );
        assert!(
            !result.sink_reaches.is_empty(),
            "Should detect command injection flow"
        );
    }

    #[test]
    fn test_is_simple_var() {
        for name in ["x", "user_input", "_private", "camelCase"] {
            assert!(is_simple_var(name), "{name:?} should be a simple var");
        }
        for name in ["obj.field", "arr[0]", "", "123abc"] {
            assert!(!is_simple_var(name), "{name:?} should not be a simple var");
        }
    }

    #[test]
    fn test_rhs_contains_var_word_boundary() {
        let hits = [
            ("foo + bar", "foo"),
            ("func(foo)", "foo"),
            ("f\"{foo}\"", "foo"),
            ("a + foo + b", "foo"),
        ];
        for (rhs, var) in hits {
            assert!(rhs_contains_var(rhs, var), "{var:?} should be found in {rhs:?}");
        }
        let misses = [("foobar", "foo"), ("barfoo", "foo"), ("_foo", "foo")];
        for (rhs, var) in misses {
            assert!(!rhs_contains_var(rhs, var), "{var:?} should not be found in {rhs:?}");
        }
    }
}