Skip to main content

infigraph_core/taint/
mod.rs

1pub mod dynamic_urls;
2pub mod interprocedural;
3pub mod path_traversal;
4pub mod sinks;
5pub mod sources;
6
7use std::collections::{HashMap, HashSet};
8use std::path::Path;
9
10use anyhow::Result;
11use rayon::prelude::*;
12use serde::Serialize;
13
14use crate::graph::GraphStore;
15use sinks::{TAINT_SANITIZERS, TAINT_SINKS};
16use sources::TAINT_SOURCES;
17
/// File contents keyed by the file path recorded on the graph symbol.
///
/// Each value holds the lines of that file, in order, as produced by
/// [`build_source_cache`]. A file that could not be read maps to an empty list.
pub type SourceCache = HashMap<String, Vec<String>>;

/// Location of one function-like symbol (`Function`, `Method` or `Test`)
/// inside a source file.
#[derive(Debug, Clone)]
pub struct FuncInfo {
    /// Symbol id as stored in the graph.
    pub id: String,
    /// File path as stored on the symbol; used as the key into [`SourceCache`].
    pub file: String,
    /// First line of the symbol, 1-based.
    pub start_line: u32,
    /// Last line of the symbol, 1-based and inclusive.
    pub end_line: u32,
}
27
28pub fn build_source_cache(store: &GraphStore, root: &Path) -> Result<(Vec<FuncInfo>, SourceCache)> {
29    let conn = store.connection()?;
30    let result = conn
31        .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method', 'Test'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
32        .map_err(|e| anyhow::anyhow!("query failed: {e}"))?;
33
34    let mut functions = Vec::new();
35    let mut files_needed: HashSet<String> = HashSet::new();
36    for row in result {
37        if row.len() < 4 {
38            continue;
39        }
40        let id = row[0].to_string();
41        let file = row[1].to_string();
42        let start: u32 = row[2].to_string().parse().unwrap_or(0);
43        let end: u32 = row[3].to_string().parse().unwrap_or(0);
44        if start > 0 && end > start {
45            files_needed.insert(file.clone());
46            functions.push(FuncInfo {
47                id,
48                file,
49                start_line: start,
50                end_line: end,
51            });
52        }
53    }
54
55    let files_vec: Vec<String> = files_needed.into_iter().collect();
56    let cache: SourceCache = files_vec
57        .par_iter()
58        .map(|file| {
59            let content = std::fs::read_to_string(root.join(file))
60                .unwrap_or_default()
61                .lines()
62                .map(String::from)
63                .collect();
64            (file.clone(), content)
65        })
66        .collect();
67
68    Ok((functions, cache))
69}
70
71#[derive(Debug, Clone, Serialize)]
72pub struct TaintFlow {
73    pub symbol_id: String,
74    pub file: String,
75    pub source_kind: String,
76    pub source_line: u32,
77    pub source_var: String,
78    pub sink_kind: String,
79    pub sink_line: u32,
80    pub sink_category: String,
81    pub path: Vec<String>,
82    pub sanitized: bool,
83    pub sanitizer: Option<String>,
84}
85
86pub fn detect_taint_flows(store: &GraphStore, root: &Path) -> Result<Vec<TaintFlow>> {
87    let _lock = store.write_lock()?;
88    let conn = store.connection()?;
89
90    let result = conn
91        .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method', 'Test'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
92        .map_err(|e| anyhow::anyhow!("query failed: {e}"))?;
93
94    let mut functions: Vec<(String, String, u32, u32)> = Vec::new();
95    for row in result {
96        if row.len() < 4 {
97            continue;
98        }
99        let id = row[0].to_string();
100        let file = row[1].to_string();
101        let start: u32 = row[2].to_string().parse().unwrap_or(0);
102        let end: u32 = row[3].to_string().parse().unwrap_or(0);
103        if start > 0 && end > start {
104            functions.push((id, file, start, end));
105        }
106    }
107
108    let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
109    let mut all_flows = Vec::new();
110
111    for (symbol_id, file, start_line, end_line) in &functions {
112        let lines = file_cache.entry(file.clone()).or_insert_with(|| {
113            let abs = root.join(file);
114            std::fs::read_to_string(&abs)
115                .unwrap_or_default()
116                .lines()
117                .map(String::from)
118                .collect()
119        });
120
121        let start_idx = (*start_line as usize).saturating_sub(1);
122        let end_idx = (*end_line as usize).min(lines.len());
123        if start_idx >= end_idx {
124            continue;
125        }
126
127        let func_lines = &lines[start_idx..end_idx];
128        let flows = analyze_function(symbol_id, file, *start_line, func_lines);
129        all_flows.extend(flows);
130    }
131
132    if !all_flows.is_empty() {
133        write_taint_flows(store, &all_flows)?;
134    }
135
136    Ok(all_flows)
137}
138
139pub fn detect_taint_flows_with_cache(
140    store: &GraphStore,
141    functions: &[FuncInfo],
142    cache: &SourceCache,
143) -> Result<Vec<TaintFlow>> {
144    let _lock = store.write_lock()?;
145    let mut all_flows = Vec::new();
146
147    for func in functions {
148        let lines = match cache.get(&func.file) {
149            Some(l) => l,
150            None => continue,
151        };
152        let start_idx = (func.start_line as usize).saturating_sub(1);
153        let end_idx = (func.end_line as usize).min(lines.len());
154        if start_idx >= end_idx {
155            continue;
156        }
157
158        let func_lines = &lines[start_idx..end_idx];
159        let flows = analyze_function(&func.id, &func.file, func.start_line, func_lines);
160        all_flows.extend(flows);
161    }
162
163    if !all_flows.is_empty() {
164        write_taint_flows(store, &all_flows)?;
165    }
166
167    Ok(all_flows)
168}
169
170fn analyze_function(
171    symbol_id: &str,
172    file: &str,
173    base_line: u32,
174    lines: &[String],
175) -> Vec<TaintFlow> {
176    let mut tainted: HashMap<String, TaintInfo> = HashMap::new();
177    let mut flows = Vec::new();
178
179    for (offset, line) in lines.iter().enumerate() {
180        let line_no = base_line + offset as u32;
181        let trimmed = line.trim();
182        let lower = trimmed.to_lowercase();
183
184        // Check for taint sources
185        for source in TAINT_SOURCES {
186            for &pattern in source.patterns {
187                if lower.contains(&pattern.to_lowercase()) {
188                    if let Some(var) = extract_lhs(trimmed) {
189                        tainted.insert(
190                            var.clone(),
191                            TaintInfo {
192                                source_kind: source.kind.to_string(),
193                                source_line: line_no,
194                                path: vec![format!("L{}: {} <- {}", line_no, var, source.kind)],
195                                original_var: var,
196                            },
197                        );
198                    }
199                }
200            }
201        }
202
203        // Propagate taint through assignments
204        if let Some((lhs, rhs)) = parse_assignment(trimmed) {
205            let rhs_lower = rhs.to_lowercase();
206            let mut propagated = false;
207            for (tvar, info) in tainted.clone() {
208                if rhs_lower.contains(&tvar.to_lowercase()) {
209                    let mut new_path = info.path.clone();
210                    new_path.push(format!("L{}: {} = ...{}...", line_no, lhs, tvar));
211                    tainted.insert(
212                        lhs.clone(),
213                        TaintInfo {
214                            source_kind: info.source_kind.clone(),
215                            source_line: info.source_line,
216                            path: new_path,
217                            original_var: info.original_var.clone(),
218                        },
219                    );
220                    propagated = true;
221                    break;
222                }
223            }
224            if !propagated {
225                // Check if RHS has a sanitizer — clears taint from LHS
226                for san in TAINT_SANITIZERS {
227                    for &pat in san.patterns {
228                        if rhs_lower.contains(&pat.to_lowercase()) {
229                            tainted.remove(&lhs);
230                        }
231                    }
232                }
233            }
234        }
235
236        // Check for taint sinks
237        for sink in TAINT_SINKS {
238            for &pattern in sink.patterns {
239                if lower.contains(&pattern.to_lowercase()) {
240                    let sink_vars = extract_args_from_call(trimmed);
241                    for svar in &sink_vars {
242                        if let Some(info) = tainted.get(&svar.to_lowercase()).or_else(|| {
243                            tainted
244                                .iter()
245                                .find(|(k, _)| svar.to_lowercase().contains(&k.to_lowercase()))
246                                .map(|(_, v)| v)
247                        }) {
248                            let sanitized = is_sanitized_nearby(lines, offset, sink.category);
249                            let sanitizer = if sanitized {
250                                find_sanitizer_name(lines, offset, sink.category)
251                            } else {
252                                None
253                            };
254
255                            let mut path = info.path.clone();
256                            path.push(format!(
257                                "L{}: {}({}) [SINK: {}]",
258                                line_no,
259                                pattern.trim_end_matches('('),
260                                svar,
261                                sink.kind
262                            ));
263
264                            flows.push(TaintFlow {
265                                symbol_id: symbol_id.to_string(),
266                                file: file.to_string(),
267                                source_kind: info.source_kind.clone(),
268                                source_line: info.source_line,
269                                source_var: info.original_var.clone(),
270                                sink_kind: sink.kind.to_string(),
271                                sink_line: line_no,
272                                sink_category: sink.category.to_string(),
273                                path,
274                                sanitized,
275                                sanitizer,
276                            });
277                        }
278                    }
279                }
280            }
281        }
282    }
283
284    flows
285}
286
/// What is known about one tainted variable while a function is analysed.
#[derive(Debug, Clone)]
struct TaintInfo {
    /// Kind of the source the taint originated from.
    source_kind: String,
    /// Line on which that source was matched.
    source_line: u32,
    /// Steps taken so far, starting at the source.
    path: Vec<String>,
    /// Variable that first received the tainted value.
    original_var: String,
}
294
/// Returns the lowercased name of the variable assigned on `line`, if the
/// line looks like a simple assignment.
///
/// Handles an optional leading `let `/`var `/`const `/`mut ` keyword, type
/// prefixes (`String name = ...`), type annotations (`x: int = ...`) and
/// `:=`. Returns `None` when the first `=` belongs to a non-assignment
/// operator (`==`, `!=`, `<=`, `>=`, `=>`), when nothing precedes the `=`,
/// or when the target is not a plain identifier (e.g. `obj.field`).
fn extract_lhs(line: &str) -> Option<String> {
    let line = line.trim();
    // Python/JS/Go/Rust: var = expr or let/var/const var = expr
    let stripped = ["let ", "var ", "const ", "mut "]
        .iter()
        .find_map(|kw| line.strip_prefix(*kw))
        .unwrap_or(line);

    let eq_pos = stripped.find('=')?;
    if eq_pos == 0 {
        return None;
    }

    // `=` is a single byte, so `eq_pos + 1` is always a valid boundary.
    let after = &stripped[eq_pos + 1..];
    // `==` is a comparison; `=>` is an arrow / match arm / hash rocket.
    if after.starts_with('=') || after.starts_with('>') {
        return None;
    }

    let before = stripped[..eq_pos].trim();
    // `!=`, `<=`, `>=`
    if before.ends_with(|c: char| c == '!' || c == '<' || c == '>') {
        return None;
    }

    // Drop a type annotation (`x: int`), then take the last word so that a
    // leading type or `mut` is ignored.
    let var = before.split(':').next()?.split_whitespace().last()?;
    if var.chars().all(|c| c.is_alphanumeric() || c == '_') {
        Some(var.to_lowercase())
    } else {
        None
    }
}
325
/// Splits a simple assignment into `(variable, right-hand side)`.
///
/// The variable name is lowercased; the right-hand side is returned trimmed
/// but otherwise verbatim. Accepts the same shapes as `extract_lhs`: an
/// optional leading `let `/`var `/`const `/`mut `, type prefixes, type
/// annotations and `:=`. Returns `None` when the first `=` belongs to a
/// non-assignment operator (`==`, `!=`, `<=`, `>=`, `=>`), when nothing
/// precedes the `=`, or when the target is not a plain identifier.
fn parse_assignment(line: &str) -> Option<(String, String)> {
    let line = line.trim();
    let stripped = ["let ", "var ", "const ", "mut "]
        .iter()
        .find_map(|kw| line.strip_prefix(*kw))
        .unwrap_or(line);

    let eq_pos = stripped.find('=')?;
    if eq_pos == 0 {
        return None;
    }

    // `=` is a single byte, so `eq_pos + 1` is always a valid boundary.
    let after = &stripped[eq_pos + 1..];
    // `==` is a comparison; `=>` is an arrow / match arm / hash rocket.
    if after.starts_with('=') || after.starts_with('>') {
        return None;
    }

    let before = stripped[..eq_pos].trim();
    // `!=`, `<=`, `>=`
    if before.ends_with(|c: char| c == '!' || c == '<' || c == '>') {
        return None;
    }

    // Drop a type annotation (`x: int`), then take the last word so that a
    // leading type or `mut` is ignored.
    let var = before.split(':').next()?.split_whitespace().last()?;
    if !var.chars().all(|c| c.is_alphanumeric() || c == '_') {
        return None;
    }

    Some((var.to_lowercase(), after.trim().to_string()))
}
351
/// Returns every distinct identifier-like token on `line`, lowercased and
/// sorted.
///
/// A token is a maximal run of alphanumeric characters and underscores. This
/// covers plain call arguments (`f(a, b)`), the receiver of a dotted argument
/// (`f(a.b)`), quoted arguments and names used in string concatenation
/// (`"..." + var`), since each of those is delimited by non-identifier
/// characters.
///
/// The line is tokenised directly; no byte offsets are carried over from a
/// lowercased copy, because lowercasing can change the byte length of
/// non-ASCII text and such offsets may not be valid for the original line.
fn extract_args_from_call(line: &str) -> Vec<String> {
    let mut names: Vec<String> = line
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        .filter(|token| !token.is_empty())
        .map(|token| token.to_lowercase())
        .collect();

    // Sort + dedup gives a unique list in a stable order.
    names.sort_unstable();
    names.dedup();
    names
}
383
384fn is_sanitized_nearby(lines: &[String], current_offset: usize, category: &str) -> bool {
385    let start = current_offset.saturating_sub(5);
386    let end = (current_offset + 6).min(lines.len());
387
388    for san in TAINT_SANITIZERS {
389        if san.category != category {
390            continue;
391        }
392        for line in &lines[start..end] {
393            let lower = line.to_lowercase();
394            for &pat in san.patterns {
395                if lower.contains(&pat.to_lowercase()) {
396                    return true;
397                }
398            }
399        }
400    }
401    false
402}
403
404fn find_sanitizer_name(lines: &[String], current_offset: usize, category: &str) -> Option<String> {
405    let start = current_offset.saturating_sub(5);
406    let end = (current_offset + 6).min(lines.len());
407
408    for san in TAINT_SANITIZERS {
409        if san.category != category {
410            continue;
411        }
412        for line in &lines[start..end] {
413            let lower = line.to_lowercase();
414            for &pat in san.patterns {
415                if lower.contains(&pat.to_lowercase()) {
416                    return Some(pat.to_string());
417                }
418            }
419        }
420    }
421    None
422}
423
424fn write_taint_flows(store: &GraphStore, flows: &[TaintFlow]) -> Result<()> {
425    let conn = store.connection()?;
426
427    conn.query("BEGIN TRANSACTION")
428        .map_err(|e| anyhow::anyhow!("begin txn: {e}"))?;
429
430    let _ = conn.query("MATCH ()-[r:TAINT_FLOW]->() DELETE r");
431
432    for flow in flows {
433        if flow.sanitized {
434            continue;
435        }
436        let sym_esc = crate::escape_str(&flow.symbol_id);
437        let src_esc = crate::escape_str(&flow.source_kind);
438        let sink_esc = crate::escape_str(&flow.sink_kind);
439        let path_str = flow.path.join(" -> ");
440        let path_esc = crate::escape_str(&path_str);
441
442        // Self-edge: function taints itself (intra-procedural)
443        let _ = conn.query(&format!(
444            "MATCH (s:Symbol) WHERE s.id = '{sym_esc}' \
445             CREATE (s)-[:TAINT_FLOW {{source_kind: '{src_esc}', sink_kind: '{sink_esc}', path: '{path_esc}'}}]->(s)"
446        ));
447    }
448
449    conn.query("COMMIT")
450        .map_err(|e| anyhow::anyhow!("commit txn: {e}"))?;
451
452    Ok(())
453}
454
455pub fn format_taint_flows(flows: &[TaintFlow]) -> String {
456    if flows.is_empty() {
457        return "No taint flows detected.".to_string();
458    }
459
460    let active: Vec<_> = flows.iter().filter(|f| !f.sanitized).collect();
461    let sanitized_count = flows.len() - active.len();
462
463    let mut out = format!(
464        "Taint flows: {} total ({} active, {} sanitized)\n\n",
465        flows.len(),
466        active.len(),
467        sanitized_count
468    );
469
470    if !active.is_empty() {
471        let mut by_category: std::collections::BTreeMap<&str, Vec<&&TaintFlow>> =
472            std::collections::BTreeMap::new();
473        for f in &active {
474            by_category.entry(&f.sink_category).or_default().push(f);
475        }
476
477        for (category, items) in &by_category {
478            out.push_str(&format!("## {} ({} flows)\n", category, items.len()));
479            for f in items {
480                out.push_str(&format!(
481                    "  {}:{} -> {}:{}\n    {} -> {}\n",
482                    f.file, f.source_line, f.file, f.sink_line, f.source_kind, f.sink_kind,
483                ));
484                out.push_str("    Path: ");
485                for (i, step) in f.path.iter().enumerate() {
486                    if i > 0 {
487                        out.push_str(" -> ");
488                    }
489                    out.push_str(step);
490                }
491                out.push('\n');
492            }
493            out.push('\n');
494        }
495    }
496
497    if sanitized_count > 0 {
498        out.push_str(&format!("\n--- {} flows sanitized ---\n", sanitized_count));
499        for f in flows.iter().filter(|f| f.sanitized) {
500            out.push_str(&format!(
501                "  {}:L{} -> L{} ({} -> {}) sanitized by: {}\n",
502                f.file,
503                f.source_line,
504                f.sink_line,
505                f.source_kind,
506                f.sink_kind,
507                f.sanitizer.as_deref().unwrap_or("unknown"),
508            ));
509        }
510    }
511
512    out
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyses `code` as if it were the body of a function starting on line 1.
    fn run_analysis(code: &str) -> Vec<TaintFlow> {
        let lines: Vec<String> = code.lines().map(String::from).collect();
        analyze_function("test::func", "test.py", 1, &lines)
    }

    /// True when at least one flow ends in a sink of the given category.
    fn has_category(flows: &[TaintFlow], category: &str) -> bool {
        flows.iter().any(|f| f.sink_category == category)
    }

    /// All flows that end in a sink of the given category.
    fn flows_in<'a>(flows: &'a [TaintFlow], category: &str) -> Vec<&'a TaintFlow> {
        flows
            .iter()
            .filter(|f| f.sink_category == category)
            .collect()
    }

    #[test]
    fn test_simple_sql_injection() {
        let flows = run_analysis(
            r#"
user_input = request.GET.get('name')
cursor.execute("SELECT * FROM users WHERE name = " + user_input)
"#,
        );
        assert!(!flows.is_empty(), "should detect taint flow");
        assert!(
            has_category(&flows, "SqlInjection"),
            "should be SQL injection"
        );
        assert!(
            flows.iter().any(|f| f.source_kind == "HttpParam"),
            "source should be HttpParam"
        );
    }

    #[test]
    fn test_multi_step_propagation() {
        let flows = run_analysis(
            r#"
a = request.GET.get('q')
b = a
c = b
cursor.execute(c)
"#,
        );
        assert!(!flows.is_empty(), "should detect multi-step taint");
        let flow = flows
            .iter()
            .find(|f| f.sink_category == "SqlInjection")
            .unwrap();
        assert!(
            flow.path.len() >= 3,
            "path should have multiple steps: {:?}",
            flow.path
        );
    }

    #[test]
    fn test_sanitizer_clears_taint() {
        let flows = run_analysis(
            r#"
user_input = request.GET.get('name')
safe_input = html.escape(user_input)
el.innerHTML = safe_input
"#,
        );
        // safe_input is sanitized, so innerHTML should not have active taint from it.
        // user_input itself is still tainted and might match, so only flows
        // attributed to safe_input are checked.
        for f in flows_in(&flows, "XssRisk") {
            if f.source_var == "safe_input" {
                assert!(f.sanitized, "safe_input should be sanitized");
            }
        }
    }

    #[test]
    fn test_command_injection() {
        let flows = run_analysis(
            r#"
cmd = request.POST.get('command')
os.system(cmd)
"#,
        );
        assert!(has_category(&flows, "CommandInjection"));
    }

    #[test]
    fn test_path_traversal() {
        let flows = run_analysis(
            r#"
filename = req.params.filename
content = open(filename)
"#,
        );
        assert!(has_category(&flows, "PathTraversal"), "flows: {:?}", flows);
    }

    #[test]
    fn test_open_redirect() {
        let flows = run_analysis(
            r#"
url = request.GET.get('next')
redirect(url)
"#,
        );
        assert!(has_category(&flows, "OpenRedirect"), "flows: {:?}", flows);
    }

    #[test]
    fn test_no_taint_without_source() {
        let flows = run_analysis(
            r#"
name = "hardcoded"
cursor.execute("SELECT * FROM users WHERE name = " + name)
"#,
        );
        assert!(
            flows.is_empty(),
            "hardcoded string should not be tainted: {:?}",
            flows
        );
    }

    #[test]
    fn test_sanitized_sql() {
        let flows = run_analysis(
            r#"
user_input = request.GET.get('id')
safe = sanitize_sql(user_input)
cursor.execute(safe)
"#,
        );
        // Should either be empty (taint cleared) or sanitized
        for f in flows_in(&flows, "SqlInjection") {
            if f.source_var != "user_input" {
                assert!(
                    f.sanitized || f.path.is_empty(),
                    "sanitized input should not produce active flow"
                );
            }
        }
    }

    #[test]
    fn test_java_request_param() {
        let flows = run_analysis(
            r#"
String name = request.getParameter("name");
stmt.executeQuery("SELECT * FROM users WHERE name = '" + name + "'");
"#,
        );
        assert!(has_category(&flows, "SqlInjection"), "flows: {:?}", flows);
    }

    #[test]
    fn test_nodejs_req_body() {
        let flows = run_analysis(
            r#"
const data = req.body.username;
res.send(`<h1>${data}</h1>`);
"#,
        );
        // Whether any flow is produced depends on sink matching; the only
        // requirement is that the analysis does not crash and, if it does
        // report flows, at least one of them originates from HttpBody.
        let from_body = flows.iter().any(|f| f.source_kind == "HttpBody");
        assert!(flows.is_empty() || from_body);
    }

    #[test]
    fn test_extract_lhs_simple() {
        assert_eq!(extract_lhs("x = foo()"), Some("x".to_string()));
        assert_eq!(extract_lhs("let y = bar()"), Some("y".to_string()));
        assert_eq!(extract_lhs("const z = baz()"), Some("z".to_string()));
    }

    #[test]
    fn test_extract_lhs_no_match() {
        assert_eq!(extract_lhs("if x == y:"), None);
        assert_eq!(extract_lhs("x != y"), None);
        assert_eq!(extract_lhs("foo()"), None);
    }

    #[test]
    fn test_parse_assignment() {
        let (lhs, rhs) = parse_assignment("data = request.GET.get('q')").unwrap();
        assert_eq!(lhs, "data");
        assert!(rhs.contains("request.GET"));
    }

    #[test]
    fn test_deserialization_taint() {
        let flows = run_analysis(
            r#"
data = request.body
obj = pickle.loads(data)
"#,
        );
        assert!(
            has_category(&flows, "InsecureDeserialization"),
            "flows: {:?}",
            flows
        );
    }
}