Skip to main content

semantic/analysis/
analysis_graph.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Call graph and blast radius analysis.
3//!
4//! Extracts function call relationships from tree-sitter ASTs and computes
5//! downstream impact ("blast radius") for changed functions.
6
7use std::{
8    collections::{HashMap, HashSet, VecDeque},
9    path::PathBuf,
10};
11
12use crate::parser::{Language, ParsedFile};
13
/// Metadata for a single function definition stored in the call graph.
#[derive(Debug, Clone)]
pub struct CallGraphNode {
    /// Source file that defines the function.
    pub file: PathBuf,
    /// Bare function name.
    pub name: String,
    /// First line of the definition in `file`.
    pub start_line: usize,
    /// Last line of the definition in `file`.
    pub end_line: usize,
}
25
/// Stable identity of a function definition: the file it lives in, its bare
/// name, and the line on which its definition starts.
///
/// Including the start line keeps same-named definitions within one file
/// distinct. The derived ordering compares `file`, then `name`, then
/// `start_line` (field declaration order).
#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)]
pub struct FunctionKey {
    /// File containing the function.
    pub file: PathBuf,
    /// Bare function name as extracted from the source.
    pub name: String,
    /// Line on which the definition starts.
    pub start_line: usize,
}
33
34/// A directed edge: `caller` calls `callee`.
35#[derive(Clone, Debug)]
36pub struct CallEdge {
37    pub caller: FunctionKey,
38    pub callee: FunctionKey,
39}
40
41/// The call graph for a set of files.
42#[derive(Clone, Debug, Default)]
43pub struct CallGraph {
44    /// Function definitions: name → node.
45    pub nodes: HashMap<FunctionKey, CallGraphNode>,
46    /// Forward edges: function → set of functions it calls.
47    pub calls: HashMap<FunctionKey, HashSet<FunctionKey>>,
48    /// Reverse edges: function → set of functions that call it.
49    pub called_by: HashMap<FunctionKey, HashSet<FunctionKey>>,
50}
51
52/// Blast radius result for a set of changed functions.
53#[derive(Clone, Debug)]
54pub struct BlastRadius {
55    /// The changed functions that triggered the analysis.
56    pub changed_functions: Vec<String>,
57    /// Downstream functions affected (transitive callers).
58    pub affected: Vec<CallGraphNode>,
59    /// Total number of affected functions.
60    pub affected_count: usize,
61}
62
63impl CallGraph {
64    /// Build a call graph from a set of files and their contents.
65    pub fn build(files: &[(PathBuf, String)]) -> Self {
66        let mut graph = CallGraph::default();
67
68        for (path, content) in files {
69            let language = Language::from_path(path);
70            let Some(parsed) = ParsedFile::parse(content, language) else {
71                continue;
72            };
73
74            let functions: Vec<_> = parsed
75                .extract_functions()
76                .into_iter()
77                .map(|func| {
78                    let key = FunctionKey {
79                        file: path.clone(),
80                        name: func.name.clone(),
81                        start_line: func.start_line,
82                    };
83                    (key, func)
84                })
85                .collect();
86
87            for (key, func) in &functions {
88                graph.nodes.insert(
89                    key.clone(),
90                    CallGraphNode {
91                        file: path.clone(),
92                        name: func.name.clone(),
93                        start_line: func.start_line,
94                        end_line: func.end_line,
95                    },
96                );
97            }
98
99            let edges = extract_call_edges(&functions, parsed.language);
100            for edge in edges {
101                graph
102                    .calls
103                    .entry(edge.caller.clone())
104                    .or_default()
105                    .insert(edge.callee.clone());
106                graph
107                    .called_by
108                    .entry(edge.callee.clone())
109                    .or_default()
110                    .insert(edge.caller.clone());
111            }
112        }
113
114        graph
115    }
116
117    /// Return stable function identities for a bare function name.
118    pub fn keys_for_name(&self, name: &str) -> Vec<FunctionKey> {
119        let mut keys: Vec<_> = self
120            .nodes
121            .keys()
122            .filter(|key| key.name == name)
123            .cloned()
124            .collect();
125        keys.sort();
126        keys
127    }
128
129    /// Compute the blast radius for a set of changed function names.
130    /// Returns all transitive callers (upstream functions that depend on the changed ones).
131    pub fn blast_radius(&self, changed_functions: &[String]) -> BlastRadius {
132        let mut affected: HashSet<FunctionKey> = HashSet::new();
133        let mut queue: VecDeque<FunctionKey> = VecDeque::new();
134        let changed_names: HashSet<_> = changed_functions.iter().map(String::as_str).collect();
135        let changed_keys: HashSet<_> = self
136            .nodes
137            .keys()
138            .filter(|key| changed_names.contains(key.name.as_str()))
139            .cloned()
140            .collect();
141
142        for key in &changed_keys {
143            queue.push_back(key.clone());
144        }
145
146        while let Some(current) = queue.pop_front() {
147            if let Some(callers) = self.called_by.get(&current) {
148                for caller in callers {
149                    if !changed_keys.contains(caller) && affected.insert(caller.clone()) {
150                        queue.push_back(caller.clone());
151                    }
152                }
153            }
154        }
155
156        let mut affected_nodes: Vec<CallGraphNode> = affected
157            .iter()
158            .filter_map(|key| self.nodes.get(key).cloned())
159            .collect();
160        affected_nodes.sort_by(|left, right| {
161            left.file
162                .cmp(&right.file)
163                .then_with(|| left.start_line.cmp(&right.start_line))
164                .then_with(|| left.name.cmp(&right.name))
165        });
166        let count = affected_nodes.len();
167
168        BlastRadius {
169            changed_functions: changed_functions.to_vec(),
170            affected: affected_nodes,
171            affected_count: count,
172        }
173    }
174}
175
176/// Extract call edges from a parsed file by walking function bodies
177/// and looking for identifier references that match known patterns.
178fn extract_call_edges(
179    functions: &[(FunctionKey, crate::parser::FunctionDef)],
180    language: Language,
181) -> Vec<CallEdge> {
182    let mut edges = Vec::new();
183    let mut functions_by_name: HashMap<String, Vec<FunctionKey>> = HashMap::new();
184
185    for (key, func) in functions {
186        functions_by_name
187            .entry(func.name.clone())
188            .or_default()
189            .push(key.clone());
190    }
191
192    for (caller_key, func) in functions {
193        let calls = extract_calls_from_text(&func.content, language);
194        for callee_name in calls {
195            if callee_name == func.name {
196                continue;
197            }
198
199            if let Some(callees) = functions_by_name.get(&callee_name) {
200                for callee_key in callees {
201                    edges.push(CallEdge {
202                        caller: caller_key.clone(),
203                        callee: callee_key.clone(),
204                    });
205                }
206            }
207        }
208    }
209
210    edges
211}
212
213/// Extract function call names from a source text snippet.
214/// Simple heuristic: look for `identifier(` patterns.
215fn extract_calls_from_text(text: &str, _language: Language) -> Vec<String> {
216    let mut calls = HashSet::new();
217    let bytes = text.as_bytes();
218    let len = bytes.len();
219    let mut i = 0;
220
221    while i < len {
222        // Find '(' and look backwards for an identifier.
223        if bytes[i] == b'(' {
224            let end = i;
225            // Skip backwards past whitespace.
226            let mut j = end;
227            while j > 0 && bytes[j - 1] == b' ' {
228                j -= 1;
229            }
230            // Collect identifier characters backwards.
231            let ident_end = j;
232            while j > 0 && (bytes[j - 1].is_ascii_alphanumeric() || bytes[j - 1] == b'_') {
233                j -= 1;
234            }
235            if j < ident_end {
236                let ident = &text[j..ident_end];
237                // Filter out language keywords.
238                if !is_keyword(ident) && !ident.is_empty() {
239                    calls.insert(ident.to_string());
240                }
241            }
242        }
243        i += 1;
244    }
245
246    calls.into_iter().collect()
247}
248
/// Reports whether `s` is a reserved word (or one of Rust's prelude variant
/// constructors) that must not be treated as a called function's name.
fn is_keyword(s: &str) -> bool {
    // Rust keywords and literals, plus `Some`/`None`/`Ok`/`Err`.
    const RUST_WORDS: &[&str] = &[
        "if", "else", "for", "while", "match", "return", "let", "mut", "fn", "pub", "struct",
        "enum", "impl", "trait", "use", "mod", "type", "where", "async", "await", "loop", "break",
        "continue", "self", "Self", "super", "crate", "as", "in", "ref", "move", "dyn", "unsafe",
        "extern", "const", "static", "true", "false", "Some", "None", "Ok", "Err",
    ];
    // Common control-flow and declaration words from other languages.
    const OTHER_LANGUAGE_WORDS: &[&str] = &[
        "def", "class", "import", "from", "try", "catch", "throw", "new", "var", "function",
        "switch", "case", "default", "typeof", "instanceof", "void", "delete",
    ];

    RUST_WORDS.contains(&s) || OTHER_LANGUAGE_WORDS.contains(&s)
}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_call_graph() {
        let files = vec![(
            PathBuf::from("test.rs"),
            concat!(
                "fn main() {\n",
                "    let x = helper();\n",
                "    process(x);\n",
                "}\n\n",
                "fn helper() -> i32 {\n",
                "    42\n",
                "}\n\n",
                "fn process(x: i32) {\n",
                "    output(x);\n",
                "}\n\n",
                "fn output(x: i32) {\n",
                "    println!(\"{}\", x);\n",
                "}\n",
            )
            .to_string(),
        )];

        let graph = CallGraph::build(&files);
        assert_eq!(graph.nodes.len(), 4); // main, helper, process, output

        // main calls helper and process.
        let main_key = graph.keys_for_name("main").pop().unwrap();
        let helper_key = graph.keys_for_name("helper").pop().unwrap();
        let process_key = graph.keys_for_name("process").pop().unwrap();
        let output_key = graph.keys_for_name("output").pop().unwrap();
        let main_calls = graph.calls.get(&main_key).unwrap();
        assert!(main_calls.contains(&helper_key));
        assert!(main_calls.contains(&process_key));

        // process calls output.
        let proc_calls = graph.calls.get(&process_key).unwrap();
        assert!(proc_calls.contains(&output_key));
    }

    #[test]
    fn test_blast_radius() {
        let files = vec![(
            PathBuf::from("test.rs"),
            concat!(
                "fn main() {\n    run();\n}\n\n",
                "fn run() {\n    compute();\n}\n\n",
                "fn compute() {\n    42\n}\n\n",
                "fn unrelated() {\n    99\n}\n",
            )
            .to_string(),
        )];

        let graph = CallGraph::build(&files);
        let blast = graph.blast_radius(&["compute".to_string()]);

        // compute is called by run, run is called by main.
        assert_eq!(blast.affected_count, 2);
        let names: HashSet<_> = blast.affected.iter().map(|n| n.name.as_str()).collect();
        assert!(names.contains("run"));
        assert!(names.contains("main"));
        // unrelated is not affected.
        assert!(!names.contains("unrelated"));
    }

    #[test]
    fn test_no_blast_radius_for_leaf() {
        let files = vec![(
            PathBuf::from("test.rs"),
            "fn leaf() {\n    42\n}\n".to_string(),
        )];

        let graph = CallGraph::build(&files);
        let blast = graph.blast_radius(&["leaf".to_string()]);
        assert_eq!(blast.affected_count, 0);
    }

    #[test]
    fn test_duplicate_function_names_stay_isolated_by_file() {
        let files = vec![
            (
                PathBuf::from("a.rs"),
                concat!(
                    "fn run() {\n    target();\n}\n\n",
                    "fn target() {\n    42\n}\n",
                )
                .to_string(),
            ),
            (
                PathBuf::from("b.rs"),
                concat!(
                    "fn run() {\n    other();\n}\n\n",
                    "fn other() {\n    99\n}\n",
                )
                .to_string(),
            ),
        ];

        let graph = CallGraph::build(&files);
        let run_keys = graph.keys_for_name("run");
        assert_eq!(run_keys.len(), 2);

        let blast = graph.blast_radius(&["target".to_string()]);
        assert_eq!(blast.affected_count, 1);
        assert_eq!(blast.affected[0].file, PathBuf::from("a.rs"));
        assert_eq!(blast.affected[0].name, "run");
    }

    #[test]
    fn test_blast_radius_terminates_on_cycles() {
        // ping and pong call each other: the walk must terminate, and the
        // changed function itself must not be reported as affected.
        let files = vec![(
            PathBuf::from("cycle.rs"),
            concat!(
                "fn ping() {\n    pong();\n}\n\n",
                "fn pong() {\n    ping();\n}\n",
            )
            .to_string(),
        )];

        let graph = CallGraph::build(&files);
        let blast = graph.blast_radius(&["ping".to_string()]);
        assert_eq!(blast.affected_count, 1);
        assert_eq!(blast.affected[0].name, "pong");
    }

    #[test]
    fn test_recursion_creates_no_self_edge() {
        let files = vec![(
            PathBuf::from("rec.rs"),
            concat!(
                "fn fact(n: u64) -> u64 {\n",
                "    if n == 0 { 1 } else { n * fact(n - 1) }\n",
                "}\n",
            )
            .to_string(),
        )];

        let graph = CallGraph::build(&files);
        let fact_key = graph.keys_for_name("fact").pop().unwrap();
        assert!(graph.calls.get(&fact_key).is_none());
        assert_eq!(graph.blast_radius(&["fact".to_string()]).affected_count, 0);
    }

    #[test]
    fn test_unknown_changed_name_has_empty_blast_radius() {
        let files = vec![(
            PathBuf::from("test.rs"),
            "fn main() {\n    helper();\n}\n\nfn helper() {\n    42\n}\n".to_string(),
        )];

        let graph = CallGraph::build(&files);
        let blast = graph.blast_radius(&["missing".to_string()]);
        assert_eq!(blast.affected_count, 0);
        assert!(blast.affected.is_empty());
        assert_eq!(blast.changed_functions, vec!["missing".to_string()]);
    }
}
424}