Skip to main content

codesearch/
callgraph.rs

//! Call Graph Module
//!
//! Analyzes function call relationships in code.
4
5use crate::parser::get_parser_for_extension;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::Path;
9use walkdir::WalkDir;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct CallGraph {
13    pub nodes: HashMap<String, CallNode>,
14    pub edges: Vec<CallEdge>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CallNode {
19    pub function_name: String,
20    pub file_path: String,
21    pub line: usize,
22    pub is_recursive: bool,
23    pub call_count: usize,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CallEdge {
28    pub caller: String,
29    pub callee: String,
30    pub call_site_line: usize,
31    pub is_direct: bool,
32}
33
34impl CallGraph {
35    pub fn new() -> Self {
36        Self {
37            nodes: HashMap::new(),
38            edges: Vec::new(),
39        }
40    }
41
42    pub fn add_node(&mut self, node: CallNode) {
43        self.nodes.insert(node.function_name.clone(), node);
44    }
45
46    pub fn add_edge(
47        &mut self,
48        caller: String,
49        callee: String,
50        call_site_line: usize,
51        is_direct: bool,
52    ) {
53        self.edges.push(CallEdge {
54            caller,
55            callee,
56            call_site_line,
57            is_direct,
58        });
59    }
60
61    pub fn get_callers(&self, function: &str) -> Vec<String> {
62        self.edges
63            .iter()
64            .filter(|e| e.callee == function)
65            .map(|e| e.caller.clone())
66            .collect()
67    }
68
69    pub fn get_callees(&self, function: &str) -> Vec<String> {
70        self.edges
71            .iter()
72            .filter(|e| e.caller == function)
73            .map(|e| e.callee.clone())
74            .collect()
75    }
76
77    pub fn find_recursive_functions(&self) -> Vec<String> {
78        let mut recursive = Vec::new();
79
80        for (func_name, _) in &self.nodes {
81            if self.is_recursive(func_name) {
82                recursive.push(func_name.clone());
83            }
84        }
85
86        recursive
87    }
88
89    fn is_recursive(&self, function: &str) -> bool {
90        let mut visited = HashSet::new();
91        let mut stack = vec![function.to_string()];
92
93        while let Some(current) = stack.pop() {
94            if current == function && !visited.is_empty() {
95                return true;
96            }
97
98            if visited.insert(current.clone()) {
99                for callee in self.get_callees(&current) {
100                    stack.push(callee);
101                }
102            }
103        }
104
105        false
106    }
107
108    pub fn find_dead_functions(&self) -> Vec<String> {
109        let mut called_functions = HashSet::new();
110
111        for edge in &self.edges {
112            called_functions.insert(edge.callee.clone());
113        }
114
115        self.nodes
116            .keys()
117            .filter(|func| !called_functions.contains(*func) && *func != "main")
118            .cloned()
119            .collect()
120    }
121
122    pub fn calculate_call_depth(&self, function: &str) -> usize {
123        let mut max_depth = 0;
124        let mut visited = HashSet::new();
125        self.calculate_depth_recursive(function, 0, &mut visited, &mut max_depth);
126        max_depth
127    }
128
129    fn calculate_depth_recursive(
130        &self,
131        function: &str,
132        depth: usize,
133        visited: &mut HashSet<String>,
134        max_depth: &mut usize,
135    ) {
136        if visited.contains(function) {
137            return;
138        }
139
140        visited.insert(function.to_string());
141        *max_depth = (*max_depth).max(depth);
142
143        for callee in self.get_callees(function) {
144            self.calculate_depth_recursive(&callee, depth + 1, visited, max_depth);
145        }
146
147        visited.remove(function);
148    }
149
150    pub fn find_call_chains(&self, from: &str, to: &str) -> Vec<Vec<String>> {
151        let mut chains = Vec::new();
152        let mut current_path = vec![from.to_string()];
153        let mut visited = HashSet::new();
154
155        self.find_chains_recursive(from, to, &mut current_path, &mut visited, &mut chains);
156
157        chains
158    }
159
160    fn find_chains_recursive(
161        &self,
162        current: &str,
163        target: &str,
164        path: &mut Vec<String>,
165        visited: &mut HashSet<String>,
166        chains: &mut Vec<Vec<String>>,
167    ) {
168        if current == target {
169            chains.push(path.clone());
170            return;
171        }
172
173        if visited.contains(current) {
174            return;
175        }
176
177        visited.insert(current.to_string());
178
179        for callee in self.get_callees(current) {
180            path.push(callee.clone());
181            self.find_chains_recursive(&callee, target, path, visited, chains);
182            path.pop();
183        }
184
185        visited.remove(current);
186    }
187
188    pub fn to_dot(&self) -> String {
189        let mut dot = String::from("digraph CallGraph {\n");
190        dot.push_str("  rankdir=LR;\n");
191        dot.push_str("  node [shape=box];\n\n");
192
193        for (func_name, node) in &self.nodes {
194            let color = if node.is_recursive {
195                "lightcoral"
196            } else if self.get_callers(func_name).is_empty() {
197                "lightgreen"
198            } else {
199                "lightblue"
200            };
201
202            dot.push_str(&format!(
203                "  \"{}\" [label=\"{}\\n({}:{})\", fillcolor={}, style=filled];\n",
204                func_name, func_name, node.file_path, node.line, color
205            ));
206        }
207
208        dot.push_str("\n");
209
210        for edge in &self.edges {
211            let style = if edge.is_direct {
212                ""
213            } else {
214                " [style=dashed]"
215            };
216            dot.push_str(&format!(
217                "  \"{}\" -> \"{}\"{};\n",
218                edge.caller, edge.callee, style
219            ));
220        }
221
222        dot.push_str("}\n");
223        dot
224    }
225}
226
227pub fn build_call_graph(
228    path: &Path,
229    extensions: Option<&[String]>,
230    exclude: Option<&[String]>,
231) -> Result<CallGraph, Box<dyn std::error::Error>> {
232    let mut graph = CallGraph::new();
233    let mut function_definitions: HashMap<String, (String, usize)> = HashMap::new();
234
235    let walker = WalkDir::new(path)
236        .into_iter()
237        .filter_entry(|e| {
238            if let Some(name) = e.file_name().to_str() {
239                if let Some(exclude_dirs) = exclude {
240                    for exclude_dir in exclude_dirs {
241                        if name == exclude_dir {
242                            return false;
243                        }
244                    }
245                }
246            }
247            true
248        })
249        .filter_map(|e| e.ok())
250        .filter(|e| e.file_type().is_file());
251
252    let files: Vec<_> = walker
253        .filter(|entry| {
254            let file_path = entry.path();
255            if let Some(exts) = extensions {
256                if let Some(ext) = file_path.extension().and_then(|s| s.to_str()) {
257                    exts.iter().any(|e| e == ext)
258                } else {
259                    false
260                }
261            } else {
262                true
263            }
264        })
265        .collect();
266
267    for entry in &files {
268        let file_path = entry.path();
269        let content = std::fs::read_to_string(file_path)?;
270
271        if let Some(ext) = file_path.extension().and_then(|s| s.to_str()) {
272            if let Some(parser) = get_parser_for_extension(ext) {
273                if let Ok(analysis) = parser.parse_content(&content) {
274                    for func in analysis.functions {
275                        function_definitions.insert(
276                            func.name.clone(),
277                            (file_path.to_string_lossy().to_string(), func.line),
278                        );
279
280                        let node = CallNode {
281                            function_name: func.name,
282                            file_path: file_path.to_string_lossy().to_string(),
283                            line: func.line,
284                            is_recursive: false,
285                            call_count: 0,
286                        };
287                        graph.add_node(node);
288                    }
289                }
290            } else {
291                let func_def_pattern = regex::Regex::new(r"(?:fn|def|function|func)\s+(\w+)")?;
292                for (line_num, line) in content.lines().enumerate() {
293                    if let Some(caps) = func_def_pattern.captures(line) {
294                        if let Some(func_name) = caps.get(1) {
295                            let func_name_str = func_name.as_str().to_string();
296                            function_definitions.insert(
297                                func_name_str.clone(),
298                                (file_path.to_string_lossy().to_string(), line_num + 1),
299                            );
300
301                            let node = CallNode {
302                                function_name: func_name_str,
303                                file_path: file_path.to_string_lossy().to_string(),
304                                line: line_num + 1,
305                                is_recursive: false,
306                                call_count: 0,
307                            };
308                            graph.add_node(node);
309                        }
310                    }
311                }
312            }
313        }
314    }
315
316    let func_call_pattern = regex::Regex::new(r"(\w+)\s*\(")?;
317
318    for entry in &files {
319        let file_path = entry.path();
320        let content = std::fs::read_to_string(file_path)?;
321        let mut current_function = None;
322
323        if let Some(ext) = file_path.extension().and_then(|s| s.to_str()) {
324            if let Some(parser) = get_parser_for_extension(ext) {
325                if let Ok(analysis) = parser.parse_content(&content) {
326                    for func in &analysis.functions {
327                        for (line_num, line) in content.lines().enumerate() {
328                            if line_num + 1 >= func.line {
329                                for cap in func_call_pattern.captures_iter(line) {
330                                    if let Some(callee_match) = cap.get(1) {
331                                        let callee = callee_match.as_str().to_string();
332
333                                        if function_definitions.contains_key(&callee)
334                                            && callee != func.name
335                                        {
336                                            graph.add_edge(
337                                                func.name.clone(),
338                                                callee,
339                                                line_num + 1,
340                                                true,
341                                            );
342                                        }
343                                    }
344                                }
345                            }
346                        }
347                    }
348                    continue;
349                }
350            }
351        }
352
353        let func_def_pattern = regex::Regex::new(r"(?:fn|def|function|func)\s+(\w+)")?;
354        for (line_num, line) in content.lines().enumerate() {
355            if let Some(caps) = func_def_pattern.captures(line) {
356                if let Some(func_name) = caps.get(1) {
357                    current_function = Some(func_name.as_str().to_string());
358                }
359            }
360
361            if let Some(caller) = &current_function {
362                for cap in func_call_pattern.captures_iter(line) {
363                    if let Some(callee_match) = cap.get(1) {
364                        let callee = callee_match.as_str().to_string();
365
366                        if function_definitions.contains_key(&callee) && callee != *caller {
367                            graph.add_edge(caller.clone(), callee, line_num + 1, true);
368                        }
369                    }
370                }
371            }
372        }
373    }
374
375    for func_name in graph.nodes.keys().cloned().collect::<Vec<_>>() {
376        if graph.is_recursive(&func_name) {
377            if let Some(node) = graph.nodes.get_mut(&func_name) {
378                node.is_recursive = true;
379            }
380        }
381    }
382
383    Ok(graph)
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-recursive node located in `test.rs` with a zero call count.
    fn sample_node(name: &str, line: usize) -> CallNode {
        CallNode {
            function_name: name.to_string(),
            file_path: "test.rs".to_string(),
            line,
            is_recursive: false,
            call_count: 0,
        }
    }

    #[test]
    fn test_call_graph_creation() {
        let graph = CallGraph::new();
        assert!(graph.nodes.is_empty());
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_add_node() {
        let mut graph = CallGraph::new();
        graph.add_node(sample_node("test", 1));
        assert_eq!(graph.nodes.len(), 1);
    }

    #[test]
    fn test_get_callees() {
        let mut graph = CallGraph::new();
        graph.add_node(sample_node("main", 1));
        graph.add_node(sample_node("helper", 5));
        graph.add_edge("main".to_string(), "helper".to_string(), 2, true);

        assert_eq!(graph.get_callees("main"), vec!["helper".to_string()]);
    }

    #[test]
    fn test_find_dead_functions() {
        let mut graph = CallGraph::new();
        graph.add_node(sample_node("main", 1));
        graph.add_node(sample_node("unused", 10));

        let dead = graph.find_dead_functions();
        assert!(dead.iter().any(|name| name == "unused"));
    }
}