Skip to main content

infigraph_core/taint/
interprocedural.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2use std::path::Path;
3
4use anyhow::Result;
5use serde::Serialize;
6
7use crate::graph::GraphQuery;
8use crate::graph::GraphStore;
9
10use super::sinks::TAINT_SINKS;
11use super::sources::TAINT_SOURCES;
12use super::{FuncInfo, SourceCache};
13
14#[derive(Debug, Clone, Serialize)]
15pub struct InterProcTaintFlow {
16    pub source_symbol: String,
17    pub sink_symbol: String,
18    pub source_kind: String,
19    pub sink_kind: String,
20    pub sink_category: String,
21    pub call_chain: Vec<String>,
22    pub depth: u32,
23}
24
25pub fn detect_interprocedural_taint(
26    store: &GraphStore,
27    root: &Path,
28    max_depth: u32,
29) -> Result<Vec<InterProcTaintFlow>> {
30    let conn = store.connection()?;
31    let gq = GraphQuery::new(&conn);
32
33    // Step 1: Find functions that contain taint sources (entry points)
34    let source_functions = find_source_functions(store, root)?;
35
36    // Step 2: Find functions that contain taint sinks
37    let sink_functions = find_sink_functions(store, root)?;
38
39    // Step 3: BFS from source functions through CALLS edges to sink functions
40    let mut flows = Vec::new();
41
42    for (src_sym, src_kind) in &source_functions {
43        let mut visited: HashSet<String> = HashSet::new();
44        let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
45
46        visited.insert(src_sym.clone());
47        queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
48
49        while let Some((current, chain, depth)) = queue.pop_front() {
50            if depth > max_depth {
51                continue;
52            }
53
54            // Check if current function is a sink
55            if let Some((sink_kind, sink_cat)) = sink_functions.get(&current) {
56                if current != *src_sym {
57                    flows.push(InterProcTaintFlow {
58                        source_symbol: src_sym.clone(),
59                        sink_symbol: current.clone(),
60                        source_kind: src_kind.clone(),
61                        sink_kind: sink_kind.clone(),
62                        sink_category: sink_cat.clone(),
63                        call_chain: chain.clone(),
64                        depth,
65                    });
66                }
67            }
68
69            // Traverse callees
70            if let Ok(callees) = gq.callees_of(&current) {
71                for callee in callees {
72                    if !visited.contains(&callee) {
73                        visited.insert(callee.clone());
74                        let mut new_chain = chain.clone();
75                        new_chain.push(callee.clone());
76                        queue.push_back((callee, new_chain, depth + 1));
77                    }
78                }
79            }
80        }
81    }
82
83    Ok(flows)
84}
85
86pub fn detect_interprocedural_taint_with_cache(
87    store: &GraphStore,
88    functions: &[FuncInfo],
89    cache: &SourceCache,
90    max_depth: u32,
91) -> Result<Vec<InterProcTaintFlow>> {
92    let conn = store.connection()?;
93
94    // Preload entire CALLS adjacency list in one query instead of per-node queries
95    let adj = load_call_adjacency(&conn)?;
96
97    let (source_functions, sink_functions) = find_sources_and_sinks_from_cache(functions, cache);
98
99    let mut flows = Vec::new();
100    for (src_sym, src_kind) in &source_functions {
101        let mut visited: HashSet<String> = HashSet::new();
102        let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
103        visited.insert(src_sym.clone());
104        queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
105
106        while let Some((current, chain, depth)) = queue.pop_front() {
107            if depth > max_depth {
108                continue;
109            }
110            if let Some((sink_kind, sink_cat)) = sink_functions.get(&current) {
111                if current != *src_sym {
112                    flows.push(InterProcTaintFlow {
113                        source_symbol: src_sym.clone(),
114                        sink_symbol: current.clone(),
115                        source_kind: src_kind.clone(),
116                        sink_kind: sink_kind.clone(),
117                        sink_category: sink_cat.clone(),
118                        call_chain: chain.clone(),
119                        depth,
120                    });
121                }
122            }
123            if let Some(callees) = adj.get(&current) {
124                for callee in callees {
125                    if !visited.contains(callee) {
126                        visited.insert(callee.clone());
127                        let mut new_chain = chain.clone();
128                        new_chain.push(callee.clone());
129                        queue.push_back((callee.clone(), new_chain, depth + 1));
130                    }
131                }
132            }
133        }
134    }
135    Ok(flows)
136}
137
138fn load_call_adjacency(conn: &kuzu::Connection<'_>) -> Result<HashMap<String, Vec<String>>> {
139    let result = conn
140        .query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id")
141        .map_err(|e| anyhow::anyhow!("load call adjacency: {e}"))?;
142    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
143    for row in result {
144        if row.len() >= 2 {
145            adj.entry(row[0].to_string())
146                .or_default()
147                .push(row[1].to_string());
148        }
149    }
150    Ok(adj)
151}
152
153#[allow(clippy::type_complexity)]
154fn find_sources_and_sinks_from_cache(
155    functions: &[FuncInfo],
156    cache: &SourceCache,
157) -> (Vec<(String, String)>, HashMap<String, (String, String)>) {
158    let mut sources: Vec<(String, String)> = Vec::new();
159    let mut sinks: HashMap<String, (String, String)> = HashMap::new();
160
161    for func in functions {
162        let lines = match cache.get(&func.file) {
163            Some(l) => l,
164            None => continue,
165        };
166        let start_idx = (func.start_line as usize).saturating_sub(1);
167        let end_idx = (func.end_line as usize).min(lines.len());
168        if start_idx >= end_idx {
169            continue;
170        }
171
172        let mut found_source = false;
173        let mut found_sink = false;
174
175        for line in &lines[start_idx..end_idx] {
176            let lower = line.to_lowercase();
177
178            if !found_source {
179                for src in super::sources::TAINT_SOURCES {
180                    for &pat in src.patterns {
181                        if lower.contains(&pat.to_lowercase()) {
182                            sources.push((func.id.clone(), src.kind.to_string()));
183                            found_source = true;
184                            break;
185                        }
186                    }
187                    if found_source {
188                        break;
189                    }
190                }
191            }
192
193            if !found_sink {
194                for sink in TAINT_SINKS {
195                    for &pat in sink.patterns {
196                        if lower.contains(&pat.to_lowercase()) {
197                            sinks.insert(
198                                func.id.clone(),
199                                (sink.kind.to_string(), sink.category.to_string()),
200                            );
201                            found_sink = true;
202                            break;
203                        }
204                    }
205                    if found_sink {
206                        break;
207                    }
208                }
209            }
210
211            if found_source && found_sink {
212                break;
213            }
214        }
215    }
216
217    sources.dedup_by(|a, b| a.0 == b.0);
218    (sources, sinks)
219}
220
221fn find_source_functions(store: &GraphStore, root: &Path) -> Result<Vec<(String, String)>> {
222    let conn = store.connection()?;
223    let result = conn
224        .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
225        .map_err(|e| anyhow::anyhow!("query: {e}"))?;
226
227    let mut sources = Vec::new();
228    let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
229
230    for row in result {
231        if row.len() < 4 {
232            continue;
233        }
234        let id = row[0].to_string();
235        let file = row[1].to_string();
236        let start: usize = row[2].to_string().parse().unwrap_or(0);
237        let end: usize = row[3].to_string().parse().unwrap_or(0);
238        if start == 0 || end <= start {
239            continue;
240        }
241
242        let lines = file_cache.entry(file.clone()).or_insert_with(|| {
243            std::fs::read_to_string(root.join(&file))
244                .unwrap_or_default()
245                .lines()
246                .map(String::from)
247                .collect()
248        });
249
250        let start_idx = start.saturating_sub(1);
251        let end_idx = end.min(lines.len());
252        if start_idx >= end_idx {
253            continue;
254        }
255
256        for line in &lines[start_idx..end_idx] {
257            let lower = line.to_lowercase();
258            for src in TAINT_SOURCES {
259                for &pat in src.patterns {
260                    if lower.contains(&pat.to_lowercase()) {
261                        sources.push((id.clone(), src.kind.to_string()));
262                        break;
263                    }
264                }
265                if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
266                    break;
267                }
268            }
269            if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
270                break;
271            }
272        }
273    }
274
275    sources.dedup_by(|a, b| a.0 == b.0);
276    Ok(sources)
277}
278
279fn find_sink_functions(
280    store: &GraphStore,
281    root: &Path,
282) -> Result<HashMap<String, (String, String)>> {
283    let conn = store.connection()?;
284    let result = conn
285        .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
286        .map_err(|e| anyhow::anyhow!("query: {e}"))?;
287
288    let mut sinks: HashMap<String, (String, String)> = HashMap::new();
289    let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
290
291    for row in result {
292        if row.len() < 4 {
293            continue;
294        }
295        let id = row[0].to_string();
296        let file = row[1].to_string();
297        let start: usize = row[2].to_string().parse().unwrap_or(0);
298        let end: usize = row[3].to_string().parse().unwrap_or(0);
299        if start == 0 || end <= start {
300            continue;
301        }
302
303        let lines = file_cache.entry(file.clone()).or_insert_with(|| {
304            std::fs::read_to_string(root.join(&file))
305                .unwrap_or_default()
306                .lines()
307                .map(String::from)
308                .collect()
309        });
310
311        let start_idx = start.saturating_sub(1);
312        let end_idx = end.min(lines.len());
313        if start_idx >= end_idx {
314            continue;
315        }
316
317        'outer: for line in &lines[start_idx..end_idx] {
318            let lower = line.to_lowercase();
319            for sink in TAINT_SINKS {
320                for &pat in sink.patterns {
321                    if lower.contains(&pat.to_lowercase()) {
322                        sinks.insert(
323                            id.clone(),
324                            (sink.kind.to_string(), sink.category.to_string()),
325                        );
326                        break 'outer;
327                    }
328                }
329            }
330        }
331    }
332
333    Ok(sinks)
334}
335
336pub fn format_interprocedural_flows(flows: &[InterProcTaintFlow]) -> String {
337    if flows.is_empty() {
338        return "No inter-procedural taint flows detected.".to_string();
339    }
340
341    let mut out = format!("Inter-procedural taint flows: {} total\n\n", flows.len());
342
343    let mut by_category: std::collections::BTreeMap<&str, Vec<&InterProcTaintFlow>> =
344        std::collections::BTreeMap::new();
345    for f in flows {
346        by_category.entry(&f.sink_category).or_default().push(f);
347    }
348
349    for (cat, items) in &by_category {
350        out.push_str(&format!("## {} ({} flows)\n", cat, items.len()));
351        for f in items {
352            out.push_str(&format!(
353                "  {} ({}) -> {} ({}) [depth: {}]\n",
354                f.source_symbol, f.source_kind, f.sink_symbol, f.sink_kind, f.depth
355            ));
356            out.push_str("    Chain: ");
357            out.push_str(&f.call_chain.join(" -> "));
358            out.push('\n');
359        }
360        out.push('\n');
361    }
362
363    out
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty flow list produces the fixed "nothing detected" message.
    #[test]
    fn test_format_empty() {
        let rendered = format_interprocedural_flows(&[]);
        assert!(rendered.contains("No inter-procedural"));
    }

    /// A single flow shows up with its category, both symbols and its depth.
    #[test]
    fn test_format_with_flows() {
        let handler = String::from("app.py::handle_request");
        let query_fn = String::from("db.py::run_query");
        let flow = InterProcTaintFlow {
            source_symbol: handler.clone(),
            sink_symbol: query_fn.clone(),
            source_kind: String::from("HttpParam"),
            sink_kind: String::from("SqlQuery"),
            sink_category: String::from("SqlInjection"),
            call_chain: vec![handler, query_fn],
            depth: 1,
        };

        let rendered = format_interprocedural_flows(std::slice::from_ref(&flow));

        assert!(rendered.contains("SqlInjection"));
        assert!(rendered.contains("handle_request"));
        assert!(rendered.contains("run_query"));
        assert!(rendered.contains("depth: 1"));
    }
}