use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::Write as _;
use std::path::Path;

use anyhow::Result;
use serde::Serialize;

use crate::graph::GraphQuery;
use crate::graph::GraphStore;

use super::sinks::TAINT_SINKS;
use super::sources::TAINT_SOURCES;
use super::{FuncInfo, SourceCache};
13
14#[derive(Debug, Clone, Serialize)]
15pub struct InterProcTaintFlow {
16 pub source_symbol: String,
17 pub sink_symbol: String,
18 pub source_kind: String,
19 pub sink_kind: String,
20 pub sink_category: String,
21 pub call_chain: Vec<String>,
22 pub depth: u32,
23}
24
25pub fn detect_interprocedural_taint(
26 store: &GraphStore,
27 root: &Path,
28 max_depth: u32,
29) -> Result<Vec<InterProcTaintFlow>> {
30 let conn = store.connection()?;
31 let gq = GraphQuery::new(&conn);
32
33 let source_functions = find_source_functions(store, root)?;
35
36 let sink_functions = find_sink_functions(store, root)?;
38
39 let mut flows = Vec::new();
41
42 for (src_sym, src_kind) in &source_functions {
43 let mut visited: HashSet<String> = HashSet::new();
44 let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
45
46 visited.insert(src_sym.clone());
47 queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
48
49 while let Some((current, chain, depth)) = queue.pop_front() {
50 if depth > max_depth {
51 continue;
52 }
53
54 if let Some((sink_kind, sink_cat)) = sink_functions.get(¤t) {
56 if current != *src_sym {
57 flows.push(InterProcTaintFlow {
58 source_symbol: src_sym.clone(),
59 sink_symbol: current.clone(),
60 source_kind: src_kind.clone(),
61 sink_kind: sink_kind.clone(),
62 sink_category: sink_cat.clone(),
63 call_chain: chain.clone(),
64 depth,
65 });
66 }
67 }
68
69 if let Ok(callees) = gq.callees_of(¤t) {
71 for callee in callees {
72 if !visited.contains(&callee) {
73 visited.insert(callee.clone());
74 let mut new_chain = chain.clone();
75 new_chain.push(callee.clone());
76 queue.push_back((callee, new_chain, depth + 1));
77 }
78 }
79 }
80 }
81 }
82
83 Ok(flows)
84}
85
86pub fn detect_interprocedural_taint_with_cache(
87 store: &GraphStore,
88 functions: &[FuncInfo],
89 cache: &SourceCache,
90 max_depth: u32,
91) -> Result<Vec<InterProcTaintFlow>> {
92 let conn = store.connection()?;
93
94 let adj = load_call_adjacency(&conn)?;
96
97 let (source_functions, sink_functions) = find_sources_and_sinks_from_cache(functions, cache);
98
99 let mut flows = Vec::new();
100 for (src_sym, src_kind) in &source_functions {
101 let mut visited: HashSet<String> = HashSet::new();
102 let mut queue: VecDeque<(String, Vec<String>, u32)> = VecDeque::new();
103 visited.insert(src_sym.clone());
104 queue.push_back((src_sym.clone(), vec![src_sym.clone()], 0));
105
106 while let Some((current, chain, depth)) = queue.pop_front() {
107 if depth > max_depth {
108 continue;
109 }
110 if let Some((sink_kind, sink_cat)) = sink_functions.get(¤t) {
111 if current != *src_sym {
112 flows.push(InterProcTaintFlow {
113 source_symbol: src_sym.clone(),
114 sink_symbol: current.clone(),
115 source_kind: src_kind.clone(),
116 sink_kind: sink_kind.clone(),
117 sink_category: sink_cat.clone(),
118 call_chain: chain.clone(),
119 depth,
120 });
121 }
122 }
123 if let Some(callees) = adj.get(¤t) {
124 for callee in callees {
125 if !visited.contains(callee) {
126 visited.insert(callee.clone());
127 let mut new_chain = chain.clone();
128 new_chain.push(callee.clone());
129 queue.push_back((callee.clone(), new_chain, depth + 1));
130 }
131 }
132 }
133 }
134 }
135 Ok(flows)
136}
137
138fn load_call_adjacency(conn: &kuzu::Connection<'_>) -> Result<HashMap<String, Vec<String>>> {
139 let result = conn
140 .query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id")
141 .map_err(|e| anyhow::anyhow!("load call adjacency: {e}"))?;
142 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
143 for row in result {
144 if row.len() >= 2 {
145 adj.entry(row[0].to_string())
146 .or_default()
147 .push(row[1].to_string());
148 }
149 }
150 Ok(adj)
151}
152
153#[allow(clippy::type_complexity)]
154fn find_sources_and_sinks_from_cache(
155 functions: &[FuncInfo],
156 cache: &SourceCache,
157) -> (Vec<(String, String)>, HashMap<String, (String, String)>) {
158 let mut sources: Vec<(String, String)> = Vec::new();
159 let mut sinks: HashMap<String, (String, String)> = HashMap::new();
160
161 for func in functions {
162 let lines = match cache.get(&func.file) {
163 Some(l) => l,
164 None => continue,
165 };
166 let start_idx = (func.start_line as usize).saturating_sub(1);
167 let end_idx = (func.end_line as usize).min(lines.len());
168 if start_idx >= end_idx {
169 continue;
170 }
171
172 let mut found_source = false;
173 let mut found_sink = false;
174
175 for line in &lines[start_idx..end_idx] {
176 let lower = line.to_lowercase();
177
178 if !found_source {
179 for src in super::sources::TAINT_SOURCES {
180 for &pat in src.patterns {
181 if lower.contains(&pat.to_lowercase()) {
182 sources.push((func.id.clone(), src.kind.to_string()));
183 found_source = true;
184 break;
185 }
186 }
187 if found_source {
188 break;
189 }
190 }
191 }
192
193 if !found_sink {
194 for sink in TAINT_SINKS {
195 for &pat in sink.patterns {
196 if lower.contains(&pat.to_lowercase()) {
197 sinks.insert(
198 func.id.clone(),
199 (sink.kind.to_string(), sink.category.to_string()),
200 );
201 found_sink = true;
202 break;
203 }
204 }
205 if found_sink {
206 break;
207 }
208 }
209 }
210
211 if found_source && found_sink {
212 break;
213 }
214 }
215 }
216
217 sources.dedup_by(|a, b| a.0 == b.0);
218 (sources, sinks)
219}
220
221fn find_source_functions(store: &GraphStore, root: &Path) -> Result<Vec<(String, String)>> {
222 let conn = store.connection()?;
223 let result = conn
224 .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
225 .map_err(|e| anyhow::anyhow!("query: {e}"))?;
226
227 let mut sources = Vec::new();
228 let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
229
230 for row in result {
231 if row.len() < 4 {
232 continue;
233 }
234 let id = row[0].to_string();
235 let file = row[1].to_string();
236 let start: usize = row[2].to_string().parse().unwrap_or(0);
237 let end: usize = row[3].to_string().parse().unwrap_or(0);
238 if start == 0 || end <= start {
239 continue;
240 }
241
242 let lines = file_cache.entry(file.clone()).or_insert_with(|| {
243 std::fs::read_to_string(root.join(&file))
244 .unwrap_or_default()
245 .lines()
246 .map(String::from)
247 .collect()
248 });
249
250 let start_idx = start.saturating_sub(1);
251 let end_idx = end.min(lines.len());
252 if start_idx >= end_idx {
253 continue;
254 }
255
256 for line in &lines[start_idx..end_idx] {
257 let lower = line.to_lowercase();
258 for src in TAINT_SOURCES {
259 for &pat in src.patterns {
260 if lower.contains(&pat.to_lowercase()) {
261 sources.push((id.clone(), src.kind.to_string()));
262 break;
263 }
264 }
265 if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
266 break;
267 }
268 }
269 if sources.last().map(|(s, _)| s == &id).unwrap_or(false) {
270 break;
271 }
272 }
273 }
274
275 sources.dedup_by(|a, b| a.0 == b.0);
276 Ok(sources)
277}
278
279fn find_sink_functions(
280 store: &GraphStore,
281 root: &Path,
282) -> Result<HashMap<String, (String, String)>> {
283 let conn = store.connection()?;
284 let result = conn
285 .query("MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND s.file IS NOT NULL RETURN s.id, s.file, s.start_line, s.end_line")
286 .map_err(|e| anyhow::anyhow!("query: {e}"))?;
287
288 let mut sinks: HashMap<String, (String, String)> = HashMap::new();
289 let mut file_cache: HashMap<String, Vec<String>> = HashMap::new();
290
291 for row in result {
292 if row.len() < 4 {
293 continue;
294 }
295 let id = row[0].to_string();
296 let file = row[1].to_string();
297 let start: usize = row[2].to_string().parse().unwrap_or(0);
298 let end: usize = row[3].to_string().parse().unwrap_or(0);
299 if start == 0 || end <= start {
300 continue;
301 }
302
303 let lines = file_cache.entry(file.clone()).or_insert_with(|| {
304 std::fs::read_to_string(root.join(&file))
305 .unwrap_or_default()
306 .lines()
307 .map(String::from)
308 .collect()
309 });
310
311 let start_idx = start.saturating_sub(1);
312 let end_idx = end.min(lines.len());
313 if start_idx >= end_idx {
314 continue;
315 }
316
317 'outer: for line in &lines[start_idx..end_idx] {
318 let lower = line.to_lowercase();
319 for sink in TAINT_SINKS {
320 for &pat in sink.patterns {
321 if lower.contains(&pat.to_lowercase()) {
322 sinks.insert(
323 id.clone(),
324 (sink.kind.to_string(), sink.category.to_string()),
325 );
326 break 'outer;
327 }
328 }
329 }
330 }
331 }
332
333 Ok(sinks)
334}
335
336pub fn format_interprocedural_flows(flows: &[InterProcTaintFlow]) -> String {
337 if flows.is_empty() {
338 return "No inter-procedural taint flows detected.".to_string();
339 }
340
341 let mut out = format!("Inter-procedural taint flows: {} total\n\n", flows.len());
342
343 let mut by_category: std::collections::BTreeMap<&str, Vec<&InterProcTaintFlow>> =
344 std::collections::BTreeMap::new();
345 for f in flows {
346 by_category.entry(&f.sink_category).or_default().push(f);
347 }
348
349 for (cat, items) in &by_category {
350 out.push_str(&format!("## {} ({} flows)\n", cat, items.len()));
351 for f in items {
352 out.push_str(&format!(
353 " {} ({}) -> {} ({}) [depth: {}]\n",
354 f.source_symbol, f.source_kind, f.sink_symbol, f.sink_kind, f.depth
355 ));
356 out.push_str(" Chain: ");
357 out.push_str(&f.call_chain.join(" -> "));
358 out.push('\n');
359 }
360 out.push('\n');
361 }
362
363 out
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a direct (one call edge) flow fixture with the given endpoints
    /// and sink category.
    fn direct_flow(source: &str, sink: &str, category: &str) -> InterProcTaintFlow {
        InterProcTaintFlow {
            source_symbol: source.to_string(),
            sink_symbol: sink.to_string(),
            source_kind: "HttpParam".to_string(),
            sink_kind: "SqlQuery".to_string(),
            sink_category: category.to_string(),
            call_chain: vec![source.to_string(), sink.to_string()],
            depth: 1,
        }
    }

    #[test]
    fn test_format_empty() {
        let result = format_interprocedural_flows(&[]);
        assert!(result.contains("No inter-procedural"));
    }

    #[test]
    fn test_format_with_flows() {
        let flows = vec![direct_flow(
            "app.py::handle_request",
            "db.py::run_query",
            "SqlInjection",
        )];
        let result = format_interprocedural_flows(&flows);
        assert!(result.contains("SqlInjection"));
        assert!(result.contains("handle_request"));
        assert!(result.contains("run_query"));
        assert!(result.contains("depth: 1"));
    }

    #[test]
    fn test_format_groups_categories_in_sorted_order() {
        // Input order is deliberately not alphabetical by category.
        let flows = vec![
            direct_flow("a.py::f", "db.py::q", "SqlInjection"),
            direct_flow("a.py::g", "os.py::sh", "CommandInjection"),
            direct_flow("a.py::h", "db.py::q", "SqlInjection"),
        ];
        let result = format_interprocedural_flows(&flows);

        assert!(result.contains("3 total"));
        let cmd = result
            .find("## CommandInjection (1 flows)")
            .expect("CommandInjection section present");
        let sql = result
            .find("## SqlInjection (2 flows)")
            .expect("SqlInjection section present");
        assert!(cmd < sql);
    }
}