Skip to main content

infigraph_core/taint/
interprocedural.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2use std::path::Path;
3
4use anyhow::Result;
5use serde::Serialize;
6
7use crate::graph::GraphQuery;
8use crate::graph::GraphStore;
9
10use super::sinks::TAINT_SINKS;
11use super::sources::TAINT_SOURCES;
12use super::{FuncInfo, SourceCache};
13
14#[derive(Debug, Clone, Serialize)]
15pub struct InterProcTaintFlow {
16    pub source_symbol: String,
17    pub sink_symbol: String,
18    pub source_kind: String,
19    pub sink_kind: String,
20    pub sink_category: String,
21    pub call_chain: Vec<String>,
22    pub depth: u32,
23}
24
25pub fn detect_interprocedural_taint(
26    store: &GraphStore,
27    root: &Path,
28    max_depth: u32,
29) -> Result<Vec<InterProcTaintFlow>> {
30    let conn = store.connection()?;
31    let gq = GraphQuery::new(&conn);
32
33    // Step 1: Find functions that contain taint sources (entry points)
34    let source_functions = find_source_functions(store, root)?;
35
36    // Step 2: Find functions that contain taint sinks
37    let sink_functions = find_sink_functions(store, root)?;
38
39    // Step 3: BFS from source functions through CALLS edges to sink functions
40    let mut flows = Vec::new();
41
42    for (src_sym, src_kind) in &source_functions {
43        let mut visited: HashSet<String> = HashSet::new();
44        let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
45
46        visited.insert(src_sym.clone());
47        queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
48
49        while let Some((current, chain, depth)) = queue.pop_front() {
50            if depth > max_depth {
51                continue;
52            }
53
54            // Check if current function is a sink
55            if let Some((sink_kind, sink_cat)) = sink_functions.get(&current) {
56                if current != *src_sym {
57                    flows.push(InterProcTaintFlow {
58                        source_symbol: src_sym.clone(),
59                        sink_symbol: current.clone(),
60                        source_kind: src_kind.clone(),
61                        sink_kind: sink_kind.clone(),
62                        sink_category: sink_cat.clone(),
63                        call_chain: chain.clone(),
64                        depth,
65                    });
66                }
67            }
68
69            // Traverse callees
70            if let Ok(callees) = gq.callees_of(&current) {
71                for callee in callees {
72                    if !visited.contains(&callee) {
73                        visited.insert(callee.clone());
74                        let mut new_chain = chain.clone();
75                        new_chain.push(callee.clone());
76                        queue.push_back((callee, new_chain, depth + 1));
77                    }
78                }
79            }
80        }
81    }
82
83    Ok(flows)
84}
85
86pub fn detect_interprocedural_taint_with_cache(
87    store: &GraphStore,
88    functions: &[FuncInfo],
89    cache: &SourceCache,
90    max_depth: u32,
91) -> Result<Vec<InterProcTaintFlow>> {
92    let conn = store.connection()?;
93
94    // Preload entire CALLS adjacency list in one query instead of per-node queries
95    let adj = load_call_adjacency(&conn)?;
96
97    let (source_functions, sink_functions) = find_sources_and_sinks_from_cache(functions, cache);
98
99    let mut flows = Vec::new();
100    for (src_sym, src_kind) in &source_functions {
101        let mut visited: HashSet<String> = HashSet::new();
102        let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
103        visited.insert(src_sym.clone());
104        queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
105
106        while let Some((current, chain, depth)) = queue.pop_front() {
107            if depth > max_depth {
108                continue;
109            }
110            if let Some((sink_kind, sink_cat)) = sink_functions.get(&current) {
111                if current != *src_sym {
112                    flows.push(InterProcTaintFlow {
113                        source_symbol: src_sym.clone(),
114                        sink_symbol: current.clone(),
115                        source_kind: src_kind.clone(),
116                        sink_kind: sink_kind.clone(),
117                        sink_category: sink_cat.clone(),
118                        call_chain: chain.clone(),
119                        depth,
120                    });
121                }
122            }
123            if let Some(callees) = adj.get(&current) {
124                for callee in callees {
125                    if !visited.contains(callee) {
126                        visited.insert(callee.clone());
127                        let mut new_chain = chain.clone();
128                        new_chain.push(callee.clone());
129                        queue.push_back((callee.clone(), new_chain, depth + 1));
130                    }
131                }
132            }
133        }
134    }
135    Ok(flows)
136}
137
138fn load_call_adjacency(conn: &kuzu::Connection<'_>) -> Result<HashMap<String, Vec<String>>> {
139    let result = conn
140        .query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id")
141        .map_err(|e| anyhow::anyhow!("load call adjacency: {e}"))?;
142    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
143    for row in result {
144        if row.len() >= 2 {
145            adj.entry(row[0].to_string())
146                .or_default()
147                .push(row[1].to_string());
148        }
149    }
150    Ok(adj)
151}
152
153fn find_sources_and_sinks_from_cache(
154    functions: &[FuncInfo],
155    cache: &SourceCache,
156) -> (Vec<(String, String)>, HashMap<String, (String, String)>) {
157    let mut sources: Vec<(String, String)> = Vec::new();
158    let mut sinks: HashMap<String, (String, String)> = HashMap::new();
159
160    for func in functions {
161        let lines = match cache.get(&func.file) {
162            Some(l) => l,
163            None => continue,
164        };
165        let start_idx = (func.start_line as usize).saturating_sub(1);
166        let end_idx = (func.end_line as usize).min(lines.len());
167        if start_idx >= end_idx {
168            continue;
169        }
170
171        let mut found_source = false;
172        let mut found_sink = false;
173
174        for line in &lines[start_idx..end_idx] {
175            let lower = line.to_lowercase();
176
177            if !found_source {
178                for src in super::sources::TAINT_SOURCES {
179                    for &pat in src.patterns {
180                        if lower.contains(&pat.to_lowercase()) {
181                            sources.push((func.id.clone(), src.kind.to_string()));
182                            found_source = true;
183                            break;
184                        }
185                    }
186                    if found_source {
187                        break;
188                    }
189                }
190            }
191
192            if !found_sink {
193                for sink in TAINT_SINKS {
194                    for &pat in sink.patterns {
195                        if lower.contains(&pat.to_lowercase()) {
196                            sinks.insert(
197                                func.id.clone(),
198                                (sink.kind.to_string(), sink.category.to_string()),
199                            );
200                            found_sink = true;
201                            break;
202                        }
203                    }
204                    if found_sink {
205                        break;
206                    }
207                }
208            }
209
210            if found_source && found_sink {
211                break;
212            }
213        }
214    }
215
216    sources.dedup_by(|a, b| a.0 == b.0);
217    (sources, sinks)
218}
219
220fn find_source_functions(store: &GraphStore, root: &Path) -> Result<Vec<(String, String)>> {
221    let conn = store.connection()?;
222    let result = conn
223        .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
224        .map_err(|e| anyhow::anyhow!("query: {e}"))?;
225
226    let mut sources = Vec::new();
227    let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
228
229    for row in result {
230        if row.len() < 4 {
231            continue;
232        }
233        let id = row[0].to_string();
234        let file = row[1].to_string();
235        let start: usize = row[2].to_string().parse().unwrap_or(0);
236        let end: usize = row[3].to_string().parse().unwrap_or(0);
237        if start == 0 || end <= start {
238            continue;
239        }
240
241        let lines = file_cache.entry(file.clone()).or_insert_with(|| {
242            std::fs::read_to_string(root.join(&file))
243                .unwrap_or_default()
244                .lines()
245                .map(String::from)
246                .collect()
247        });
248
249        let start_idx = start.saturating_sub(1);
250        let end_idx = end.min(lines.len());
251        if start_idx >= end_idx {
252            continue;
253        }
254
255        for line in &lines[start_idx..end_idx] {
256            let lower = line.to_lowercase();
257            for src in TAINT_SOURCES {
258                for &pat in src.patterns {
259                    if lower.contains(&pat.to_lowercase()) {
260                        sources.push((id.clone(), src.kind.to_string()));
261                        break;
262                    }
263                }
264                if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
265                    break;
266                }
267            }
268            if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
269                break;
270            }
271        }
272    }
273
274    sources.dedup_by(|a, b| a.0 == b.0);
275    Ok(sources)
276}
277
278fn find_sink_functions(
279    store: &GraphStore,
280    root: &Path,
281) -> Result<HashMap<String, (String, String)>> {
282    let conn = store.connection()?;
283    let result = conn
284        .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
285        .map_err(|e| anyhow::anyhow!("query: {e}"))?;
286
287    let mut sinks: HashMap<String, (String, String)> = HashMap::new();
288    let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
289
290    for row in result {
291        if row.len() < 4 {
292            continue;
293        }
294        let id = row[0].to_string();
295        let file = row[1].to_string();
296        let start: usize = row[2].to_string().parse().unwrap_or(0);
297        let end: usize = row[3].to_string().parse().unwrap_or(0);
298        if start == 0 || end <= start {
299            continue;
300        }
301
302        let lines = file_cache.entry(file.clone()).or_insert_with(|| {
303            std::fs::read_to_string(root.join(&file))
304                .unwrap_or_default()
305                .lines()
306                .map(String::from)
307                .collect()
308        });
309
310        let start_idx = start.saturating_sub(1);
311        let end_idx = end.min(lines.len());
312        if start_idx >= end_idx {
313            continue;
314        }
315
316        'outer: for line in &lines[start_idx..end_idx] {
317            let lower = line.to_lowercase();
318            for sink in TAINT_SINKS {
319                for &pat in sink.patterns {
320                    if lower.contains(&pat.to_lowercase()) {
321                        sinks.insert(
322                            id.clone(),
323                            (sink.kind.to_string(), sink.category.to_string()),
324                        );
325                        break 'outer;
326                    }
327                }
328            }
329        }
330    }
331
332    Ok(sinks)
333}
334
335pub fn format_interprocedural_flows(flows: &[InterProcTaintFlow]) -> String {
336    if flows.is_empty() {
337        return "No inter-procedural taint flows detected.".to_string();
338    }
339
340    let mut out = format!("Inter-procedural taint flows: {} total\n\n", flows.len());
341
342    let mut by_category: std::collections::BTreeMap<&str, Vec<&InterProcTaintFlow>> =
343        std::collections::BTreeMap::new();
344    for f in flows {
345        by_category.entry(&f.sink_category).or_default().push(f);
346    }
347
348    for (cat, items) in &by_category {
349        out.push_str(&format!("## {} ({} flows)\n", cat, items.len()));
350        for f in items {
351            out.push_str(&format!(
352                "  {} ({}) -> {} ({}) [depth: {}]\n",
353                f.source_symbol, f.source_kind, f.sink_symbol, f.sink_kind, f.depth
354            ));
355            out.push_str("    Chain: ");
356            out.push_str(&f.call_chain.join(" -> "));
357            out.push('\n');
358        }
359        out.push('\n');
360    }
361
362    out
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a flow with a `HttpParam` source and a two-element call chain.
    fn flow(
        source: &str,
        sink: &str,
        sink_kind: &str,
        category: &str,
        depth: u32,
    ) -> InterProcTaintFlow {
        InterProcTaintFlow {
            source_symbol: source.to_string(),
            sink_symbol: sink.to_string(),
            source_kind: "HttpParam".to_string(),
            sink_kind: sink_kind.to_string(),
            sink_category: category.to_string(),
            call_chain: vec![source.to_string(), sink.to_string()],
            depth,
        }
    }

    #[test]
    fn test_format_empty() {
        let result = format_interprocedural_flows(&[]);
        assert!(result.contains("No inter-procedural"));
    }

    #[test]
    fn test_format_with_flows() {
        let flows = vec![InterProcTaintFlow {
            source_symbol: "app.py::handle_request".to_string(),
            sink_symbol: "db.py::run_query".to_string(),
            source_kind: "HttpParam".to_string(),
            sink_kind: "SqlQuery".to_string(),
            sink_category: "SqlInjection".to_string(),
            call_chain: vec![
                "app.py::handle_request".to_string(),
                "db.py::run_query".to_string(),
            ],
            depth: 1,
        }];
        let result = format_interprocedural_flows(&flows);
        assert!(result.contains("SqlInjection"));
        assert!(result.contains("handle_request"));
        assert!(result.contains("run_query"));
        assert!(result.contains("depth: 1"));
    }

    /// Pins the exact layout: sorted categories, per-category counts, input
    /// order inside a category, and a blank line after each section.
    #[test]
    fn test_format_groups_categories_in_sorted_order() {
        let flows = vec![
            flow("a.py::x", "web.py::render", "HtmlOutput", "Xss", 1),
            flow("a.py::y", "db.py::query", "SqlQuery", "SqlInjection", 2),
            flow("a.py::z", "web.py::echo", "HtmlOutput", "Xss", 1),
        ];
        let expected = concat!(
            "Inter-procedural taint flows: 3 total\n\n",
            "## SqlInjection (1 flows)\n",
            "  a.py::y (HttpParam) -> db.py::query (SqlQuery) [depth: 2]\n",
            "    Chain: a.py::y -> db.py::query\n",
            "\n",
            "## Xss (2 flows)\n",
            "  a.py::x (HttpParam) -> web.py::render (HtmlOutput) [depth: 1]\n",
            "    Chain: a.py::x -> web.py::render\n",
            "  a.py::z (HttpParam) -> web.py::echo (HtmlOutput) [depth: 1]\n",
            "    Chain: a.py::z -> web.py::echo\n",
            "\n",
        );
        assert_eq!(format_interprocedural_flows(&flows), expected);
    }
}