1use std::collections::{HashMap, HashSet, VecDeque};
2use std::path::Path;
3
4use anyhow::Result;
5use serde::Serialize;
6
7use crate::graph::GraphQuery;
8use crate::graph::GraphStore;
9
10use super::sinks::TAINT_SINKS;
11use super::sources::TAINT_SOURCES;
12use super::{FuncInfo, SourceCache};
13
14#[derive(Debug, Clone, Serialize)]
15pub struct InterProcTaintFlow {
16 pub source_symbol: String,
17 pub sink_symbol: String,
18 pub source_kind: String,
19 pub sink_kind: String,
20 pub sink_category: String,
21 pub call_chain: Vec<String>,
22 pub depth: u32,
23}
24
25pub fn detect_interprocedural_taint(
26 store: &GraphStore,
27 root: &Path,
28 max_depth: u32,
29) -> Result<Vec<InterProcTaintFlow>> {
30 let conn = store.connection()?;
31 let gq = GraphQuery::new(&conn);
32
33 let source_functions = find_source_functions(store, root)?;
35
36 let sink_functions = find_sink_functions(store, root)?;
38
39 let mut flows = Vec::new();
41
42 for (src_sym, src_kind) in &source_functions {
43 let mut visited: HashSet<String> = HashSet::new();
44 let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
45
46 visited.insert(src_sym.clone());
47 queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
48
49 while let Some((current, chain, depth)) = queue.pop_front() {
50 if depth > max_depth {
51 continue;
52 }
53
54 if let Some((sink_kind, sink_cat)) = sink_functions.get(¤t) {
56 if current != *src_sym {
57 flows.push(InterProcTaintFlow {
58 source_symbol: src_sym.clone(),
59 sink_symbol: current.clone(),
60 source_kind: src_kind.clone(),
61 sink_kind: sink_kind.clone(),
62 sink_category: sink_cat.clone(),
63 call_chain: chain.clone(),
64 depth,
65 });
66 }
67 }
68
69 if let Ok(callees) = gq.callees_of(¤t) {
71 for callee in callees {
72 if !visited.contains(&callee) {
73 visited.insert(callee.clone());
74 let mut new_chain = chain.clone();
75 new_chain.push(callee.clone());
76 queue.push_back((callee, new_chain, depth + 1));
77 }
78 }
79 }
80 }
81 }
82
83 Ok(flows)
84}
85
86pub fn detect_interprocedural_taint_with_cache(
87 store: &GraphStore,
88 functions: &[FuncInfo],
89 cache: &SourceCache,
90 max_depth: u32,
91) -> Result<Vec<InterProcTaintFlow>> {
92 let conn = store.connection()?;
93
94 let adj = load_call_adjacency(&conn)?;
96
97 let (source_functions, sink_functions) = find_sources_and_sinks_from_cache(functions, cache);
98
99 let mut flows = Vec::new();
100 for (src_sym, src_kind) in &source_functions {
101 let mut visited: HashSet<String> = HashSet::new();
102 let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
103 visited.insert(src_sym.clone());
104 queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
105
106 while let Some((current, chain, depth)) = queue.pop_front() {
107 if depth > max_depth {
108 continue;
109 }
110 if let Some((sink_kind, sink_cat)) = sink_functions.get(¤t) {
111 if current != *src_sym {
112 flows.push(InterProcTaintFlow {
113 source_symbol: src_sym.clone(),
114 sink_symbol: current.clone(),
115 source_kind: src_kind.clone(),
116 sink_kind: sink_kind.clone(),
117 sink_category: sink_cat.clone(),
118 call_chain: chain.clone(),
119 depth,
120 });
121 }
122 }
123 if let Some(callees) = adj.get(¤t) {
124 for callee in callees {
125 if !visited.contains(callee) {
126 visited.insert(callee.clone());
127 let mut new_chain = chain.clone();
128 new_chain.push(callee.clone());
129 queue.push_back((callee.clone(), new_chain, depth + 1));
130 }
131 }
132 }
133 }
134 }
135 Ok(flows)
136}
137
138fn load_call_adjacency(conn: &kuzu::Connection<'_>) -> Result<HashMap<String, Vec<String>>> {
139 let result = conn
140 .query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id")
141 .map_err(|e| anyhow::anyhow!("load call adjacency: {e}"))?;
142 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
143 for row in result {
144 if row.len() >= 2 {
145 adj.entry(row[0].to_string())
146 .or_default()
147 .push(row[1].to_string());
148 }
149 }
150 Ok(adj)
151}
152
153fn find_sources_and_sinks_from_cache(
154 functions: &[FuncInfo],
155 cache: &SourceCache,
156) -> (Vec<(String, String)>, HashMap<String, (String, String)>) {
157 let mut sources: Vec<(String, String)> = Vec::new();
158 let mut sinks: HashMap<String, (String, String)> = HashMap::new();
159
160 for func in functions {
161 let lines = match cache.get(&func.file) {
162 Some(l) => l,
163 None => continue,
164 };
165 let start_idx = (func.start_line as usize).saturating_sub(1);
166 let end_idx = (func.end_line as usize).min(lines.len());
167 if start_idx >= end_idx {
168 continue;
169 }
170
171 let mut found_source = false;
172 let mut found_sink = false;
173
174 for line in &lines[start_idx..end_idx] {
175 let lower = line.to_lowercase();
176
177 if !found_source {
178 for src in super::sources::TAINT_SOURCES {
179 for &pat in src.patterns {
180 if lower.contains(&pat.to_lowercase()) {
181 sources.push((func.id.clone(), src.kind.to_string()));
182 found_source = true;
183 break;
184 }
185 }
186 if found_source {
187 break;
188 }
189 }
190 }
191
192 if !found_sink {
193 for sink in TAINT_SINKS {
194 for &pat in sink.patterns {
195 if lower.contains(&pat.to_lowercase()) {
196 sinks.insert(
197 func.id.clone(),
198 (sink.kind.to_string(), sink.category.to_string()),
199 );
200 found_sink = true;
201 break;
202 }
203 }
204 if found_sink {
205 break;
206 }
207 }
208 }
209
210 if found_source && found_sink {
211 break;
212 }
213 }
214 }
215
216 sources.dedup_by(|a, b| a.0 == b.0);
217 (sources, sinks)
218}
219
220fn find_source_functions(store: &GraphStore, root: &Path) -> Result<Vec<(String, String)>> {
221 let conn = store.connection()?;
222 let result = conn
223 .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
224 .map_err(|e| anyhow::anyhow!("query: {e}"))?;
225
226 let mut sources = Vec::new();
227 let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
228
229 for row in result {
230 if row.len() < 4 {
231 continue;
232 }
233 let id = row[0].to_string();
234 let file = row[1].to_string();
235 let start: usize = row[2].to_string().parse().unwrap_or(0);
236 let end: usize = row[3].to_string().parse().unwrap_or(0);
237 if start == 0 || end <= start {
238 continue;
239 }
240
241 let lines = file_cache.entry(file.clone()).or_insert_with(|| {
242 std::fs::read_to_string(root.join(&file))
243 .unwrap_or_default()
244 .lines()
245 .map(String::from)
246 .collect()
247 });
248
249 let start_idx = start.saturating_sub(1);
250 let end_idx = end.min(lines.len());
251 if start_idx >= end_idx {
252 continue;
253 }
254
255 for line in &lines[start_idx..end_idx] {
256 let lower = line.to_lowercase();
257 for src in TAINT_SOURCES {
258 for &pat in src.patterns {
259 if lower.contains(&pat.to_lowercase()) {
260 sources.push((id.clone(), src.kind.to_string()));
261 break;
262 }
263 }
264 if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
265 break;
266 }
267 }
268 if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
269 break;
270 }
271 }
272 }
273
274 sources.dedup_by(|a, b| a.0 == b.0);
275 Ok(sources)
276}
277
278fn find_sink_functions(
279 store: &GraphStore,
280 root: &Path,
281) -> Result<HashMap<String, (String, String)>> {
282 let conn = store.connection()?;
283 let result = conn
284 .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
285 .map_err(|e| anyhow::anyhow!("query: {e}"))?;
286
287 let mut sinks: HashMap<String, (String, String)> = HashMap::new();
288 let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
289
290 for row in result {
291 if row.len() < 4 {
292 continue;
293 }
294 let id = row[0].to_string();
295 let file = row[1].to_string();
296 let start: usize = row[2].to_string().parse().unwrap_or(0);
297 let end: usize = row[3].to_string().parse().unwrap_or(0);
298 if start == 0 || end <= start {
299 continue;
300 }
301
302 let lines = file_cache.entry(file.clone()).or_insert_with(|| {
303 std::fs::read_to_string(root.join(&file))
304 .unwrap_or_default()
305 .lines()
306 .map(String::from)
307 .collect()
308 });
309
310 let start_idx = start.saturating_sub(1);
311 let end_idx = end.min(lines.len());
312 if start_idx >= end_idx {
313 continue;
314 }
315
316 'outer: for line in &lines[start_idx..end_idx] {
317 let lower = line.to_lowercase();
318 for sink in TAINT_SINKS {
319 for &pat in sink.patterns {
320 if lower.contains(&pat.to_lowercase()) {
321 sinks.insert(
322 id.clone(),
323 (sink.kind.to_string(), sink.category.to_string()),
324 );
325 break 'outer;
326 }
327 }
328 }
329 }
330 }
331
332 Ok(sinks)
333}
334
335pub fn format_interprocedural_flows(flows: &[InterProcTaintFlow]) -> String {
336 if flows.is_empty() {
337 return "No inter-procedural taint flows detected.".to_string();
338 }
339
340 let mut out = format!("Inter-procedural taint flows: {} total\n\n", flows.len());
341
342 let mut by_category: std::collections::BTreeMap<&str, Vec<&InterProcTaintFlow>> =
343 std::collections::BTreeMap::new();
344 for f in flows {
345 by_category.entry(&f.sink_category).or_default().push(f);
346 }
347
348 for (cat, items) in &by_category {
349 out.push_str(&format!("## {} ({} flows)\n", cat, items.len()));
350 for f in items {
351 out.push_str(&format!(
352 " {} ({}) -> {} ({}) [depth: {}]\n",
353 f.source_symbol, f.source_kind, f.sink_symbol, f.sink_kind, f.depth
354 ));
355 out.push_str(" Chain: ");
356 out.push_str(&f.call_chain.join(" -> "));
357 out.push('\n');
358 }
359 out.push('\n');
360 }
361
362 out
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a direct (one call edge) flow from `source` to `sink` in `category`.
    fn direct_flow(source: &str, sink: &str, category: &str) -> InterProcTaintFlow {
        InterProcTaintFlow {
            source_symbol: source.to_string(),
            sink_symbol: sink.to_string(),
            source_kind: "HttpParam".to_string(),
            sink_kind: "SqlQuery".to_string(),
            sink_category: category.to_string(),
            call_chain: vec![source.to_string(), sink.to_string()],
            depth: 1,
        }
    }

    #[test]
    fn test_format_empty() {
        let result = format_interprocedural_flows(&[]);
        assert!(result.contains("No inter-procedural"));
    }

    #[test]
    fn test_format_with_flows() {
        let flows = vec![direct_flow(
            "app.py::handle_request",
            "db.py::run_query",
            "SqlInjection",
        )];
        let result = format_interprocedural_flows(&flows);
        assert!(result.contains("SqlInjection"));
        assert!(result.contains("handle_request"));
        assert!(result.contains("run_query"));
        assert!(result.contains("depth: 1"));
        // The call chain is printed in order, joined with arrows.
        assert!(result.contains("Chain: app.py::handle_request -> db.py::run_query"));
    }

    #[test]
    fn test_format_groups_by_sorted_category() {
        let flows = vec![
            direct_flow("a.py::f", "b.py::g", "SqlInjection"),
            direct_flow("c.py::h", "d.py::i", "CommandInjection"),
        ];
        let result = format_interprocedural_flows(&flows);
        assert!(result.contains("2 total"));

        // Sections appear in lexicographic category order, not in input order.
        let command_pos = result.find("## CommandInjection");
        let sql_pos = result.find("## SqlInjection");
        assert!(command_pos.is_some());
        assert!(sql_pos.is_some());
        assert!(command_pos < sql_pos);
    }
}