1use crate::parser::get_parser_for_extension;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::Path;
9use walkdir::WalkDir;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct CallGraph {
13 pub nodes: HashMap<String, CallNode>,
14 pub edges: Vec<CallEdge>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CallNode {
19 pub function_name: String,
20 pub file_path: String,
21 pub line: usize,
22 pub is_recursive: bool,
23 pub call_count: usize,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CallEdge {
28 pub caller: String,
29 pub callee: String,
30 pub call_site_line: usize,
31 pub is_direct: bool,
32}
33
34impl CallGraph {
35 pub fn new() -> Self {
36 Self {
37 nodes: HashMap::new(),
38 edges: Vec::new(),
39 }
40 }
41
42 pub fn add_node(&mut self, node: CallNode) {
43 self.nodes.insert(node.function_name.clone(), node);
44 }
45
46 pub fn add_edge(
47 &mut self,
48 caller: String,
49 callee: String,
50 call_site_line: usize,
51 is_direct: bool,
52 ) {
53 self.edges.push(CallEdge {
54 caller,
55 callee,
56 call_site_line,
57 is_direct,
58 });
59 }
60
61 pub fn get_callers(&self, function: &str) -> Vec<String> {
62 self.edges
63 .iter()
64 .filter(|e| e.callee == function)
65 .map(|e| e.caller.clone())
66 .collect()
67 }
68
69 pub fn get_callees(&self, function: &str) -> Vec<String> {
70 self.edges
71 .iter()
72 .filter(|e| e.caller == function)
73 .map(|e| e.callee.clone())
74 .collect()
75 }
76
77 pub fn find_recursive_functions(&self) -> Vec<String> {
78 let mut recursive = Vec::new();
79
80 for (func_name, _) in &self.nodes {
81 if self.is_recursive(func_name) {
82 recursive.push(func_name.clone());
83 }
84 }
85
86 recursive
87 }
88
89 fn is_recursive(&self, function: &str) -> bool {
90 let mut visited = HashSet::new();
91 let mut stack = vec![function.to_string()];
92
93 while let Some(current) = stack.pop() {
94 if current == function && !visited.is_empty() {
95 return true;
96 }
97
98 if visited.insert(current.clone()) {
99 for callee in self.get_callees(¤t) {
100 stack.push(callee);
101 }
102 }
103 }
104
105 false
106 }
107
108 pub fn find_dead_functions(&self) -> Vec<String> {
109 let mut called_functions = HashSet::new();
110
111 for edge in &self.edges {
112 called_functions.insert(edge.callee.clone());
113 }
114
115 self.nodes
116 .keys()
117 .filter(|func| !called_functions.contains(*func) && *func != "main")
118 .cloned()
119 .collect()
120 }
121
122 pub fn calculate_call_depth(&self, function: &str) -> usize {
123 let mut max_depth = 0;
124 let mut visited = HashSet::new();
125 self.calculate_depth_recursive(function, 0, &mut visited, &mut max_depth);
126 max_depth
127 }
128
129 fn calculate_depth_recursive(
130 &self,
131 function: &str,
132 depth: usize,
133 visited: &mut HashSet<String>,
134 max_depth: &mut usize,
135 ) {
136 if visited.contains(function) {
137 return;
138 }
139
140 visited.insert(function.to_string());
141 *max_depth = (*max_depth).max(depth);
142
143 for callee in self.get_callees(function) {
144 self.calculate_depth_recursive(&callee, depth + 1, visited, max_depth);
145 }
146
147 visited.remove(function);
148 }
149
150 pub fn find_call_chains(&self, from: &str, to: &str) -> Vec<Vec<String>> {
151 let mut chains = Vec::new();
152 let mut current_path = vec![from.to_string()];
153 let mut visited = HashSet::new();
154
155 self.find_chains_recursive(from, to, &mut current_path, &mut visited, &mut chains);
156
157 chains
158 }
159
160 fn find_chains_recursive(
161 &self,
162 current: &str,
163 target: &str,
164 path: &mut Vec<String>,
165 visited: &mut HashSet<String>,
166 chains: &mut Vec<Vec<String>>,
167 ) {
168 if current == target {
169 chains.push(path.clone());
170 return;
171 }
172
173 if visited.contains(current) {
174 return;
175 }
176
177 visited.insert(current.to_string());
178
179 for callee in self.get_callees(current) {
180 path.push(callee.clone());
181 self.find_chains_recursive(&callee, target, path, visited, chains);
182 path.pop();
183 }
184
185 visited.remove(current);
186 }
187
188 pub fn to_dot(&self) -> String {
189 let mut dot = String::from("digraph CallGraph {\n");
190 dot.push_str(" rankdir=LR;\n");
191 dot.push_str(" node [shape=box];\n\n");
192
193 for (func_name, node) in &self.nodes {
194 let color = if node.is_recursive {
195 "lightcoral"
196 } else if self.get_callers(func_name).is_empty() {
197 "lightgreen"
198 } else {
199 "lightblue"
200 };
201
202 dot.push_str(&format!(
203 " \"{}\" [label=\"{}\\n({}:{})\", fillcolor={}, style=filled];\n",
204 func_name, func_name, node.file_path, node.line, color
205 ));
206 }
207
208 dot.push_str("\n");
209
210 for edge in &self.edges {
211 let style = if edge.is_direct {
212 ""
213 } else {
214 " [style=dashed]"
215 };
216 dot.push_str(&format!(
217 " \"{}\" -> \"{}\"{};\n",
218 edge.caller, edge.callee, style
219 ));
220 }
221
222 dot.push_str("}\n");
223 dot
224 }
225}
226
227pub fn build_call_graph(
228 path: &Path,
229 extensions: Option<&[String]>,
230 exclude: Option<&[String]>,
231) -> Result<CallGraph, Box<dyn std::error::Error>> {
232 let mut graph = CallGraph::new();
233 let mut function_definitions: HashMap<String, (String, usize)> = HashMap::new();
234
235 let walker = WalkDir::new(path)
236 .into_iter()
237 .filter_entry(|e| {
238 if let Some(name) = e.file_name().to_str() {
239 if let Some(exclude_dirs) = exclude {
240 for exclude_dir in exclude_dirs {
241 if name == exclude_dir {
242 return false;
243 }
244 }
245 }
246 }
247 true
248 })
249 .filter_map(|e| e.ok())
250 .filter(|e| e.file_type().is_file());
251
252 let files: Vec<_> = walker
253 .filter(|entry| {
254 let file_path = entry.path();
255 if let Some(exts) = extensions {
256 if let Some(ext) = file_path.extension().and_then(|s| s.to_str()) {
257 exts.iter().any(|e| e == ext)
258 } else {
259 false
260 }
261 } else {
262 true
263 }
264 })
265 .collect();
266
267 for entry in &files {
268 let file_path = entry.path();
269 let content = std::fs::read_to_string(file_path)?;
270
271 if let Some(ext) = file_path.extension().and_then(|s| s.to_str()) {
272 if let Some(parser) = get_parser_for_extension(ext) {
273 if let Ok(analysis) = parser.parse_content(&content) {
274 for func in analysis.functions {
275 function_definitions.insert(
276 func.name.clone(),
277 (file_path.to_string_lossy().to_string(), func.line),
278 );
279
280 let node = CallNode {
281 function_name: func.name,
282 file_path: file_path.to_string_lossy().to_string(),
283 line: func.line,
284 is_recursive: false,
285 call_count: 0,
286 };
287 graph.add_node(node);
288 }
289 }
290 } else {
291 let func_def_pattern = regex::Regex::new(r"(?:fn|def|function|func)\s+(\w+)")?;
292 for (line_num, line) in content.lines().enumerate() {
293 if let Some(caps) = func_def_pattern.captures(line) {
294 if let Some(func_name) = caps.get(1) {
295 let func_name_str = func_name.as_str().to_string();
296 function_definitions.insert(
297 func_name_str.clone(),
298 (file_path.to_string_lossy().to_string(), line_num + 1),
299 );
300
301 let node = CallNode {
302 function_name: func_name_str,
303 file_path: file_path.to_string_lossy().to_string(),
304 line: line_num + 1,
305 is_recursive: false,
306 call_count: 0,
307 };
308 graph.add_node(node);
309 }
310 }
311 }
312 }
313 }
314 }
315
316 let func_call_pattern = regex::Regex::new(r"(\w+)\s*\(")?;
317
318 for entry in &files {
319 let file_path = entry.path();
320 let content = std::fs::read_to_string(file_path)?;
321 let mut current_function = None;
322
323 if let Some(ext) = file_path.extension().and_then(|s| s.to_str()) {
324 if let Some(parser) = get_parser_for_extension(ext) {
325 if let Ok(analysis) = parser.parse_content(&content) {
326 for func in &analysis.functions {
327 for (line_num, line) in content.lines().enumerate() {
328 if line_num + 1 >= func.line {
329 for cap in func_call_pattern.captures_iter(line) {
330 if let Some(callee_match) = cap.get(1) {
331 let callee = callee_match.as_str().to_string();
332
333 if function_definitions.contains_key(&callee)
334 && callee != func.name
335 {
336 graph.add_edge(
337 func.name.clone(),
338 callee,
339 line_num + 1,
340 true,
341 );
342 }
343 }
344 }
345 }
346 }
347 }
348 continue;
349 }
350 }
351 }
352
353 let func_def_pattern = regex::Regex::new(r"(?:fn|def|function|func)\s+(\w+)")?;
354 for (line_num, line) in content.lines().enumerate() {
355 if let Some(caps) = func_def_pattern.captures(line) {
356 if let Some(func_name) = caps.get(1) {
357 current_function = Some(func_name.as_str().to_string());
358 }
359 }
360
361 if let Some(caller) = ¤t_function {
362 for cap in func_call_pattern.captures_iter(line) {
363 if let Some(callee_match) = cap.get(1) {
364 let callee = callee_match.as_str().to_string();
365
366 if function_definitions.contains_key(&callee) && callee != *caller {
367 graph.add_edge(caller.clone(), callee, line_num + 1, true);
368 }
369 }
370 }
371 }
372 }
373 }
374
375 for func_name in graph.nodes.keys().cloned().collect::<Vec<_>>() {
376 if graph.is_recursive(&func_name) {
377 if let Some(node) = graph.nodes.get_mut(&func_name) {
378 node.is_recursive = true;
379 }
380 }
381 }
382
383 Ok(graph)
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a node for `name` located at `test.rs:line` with default flags.
    fn node(name: &str, line: usize) -> CallNode {
        CallNode {
            function_name: name.to_string(),
            file_path: "test.rs".to_string(),
            line,
            is_recursive: false,
            call_count: 0,
        }
    }

    /// Builds a graph with one node per entry of `names` and one direct edge per
    /// `(caller, callee)` pair in `edges`.
    fn graph_with(names: &[&str], edges: &[(&str, &str)]) -> CallGraph {
        let mut graph = CallGraph::new();
        for (idx, name) in names.iter().enumerate() {
            graph.add_node(node(name, idx + 1));
        }
        for (idx, (caller, callee)) in edges.iter().enumerate() {
            graph.add_edge(caller.to_string(), callee.to_string(), idx + 1, true);
        }
        graph
    }

    #[test]
    fn test_call_graph_creation() {
        let graph = CallGraph::new();
        assert_eq!(graph.nodes.len(), 0);
        assert_eq!(graph.edges.len(), 0);
    }

    #[test]
    fn test_add_node() {
        let mut graph = CallGraph::new();
        graph.add_node(node("test", 1));
        assert_eq!(graph.nodes.len(), 1);

        // Re-adding a name replaces the existing node rather than duplicating it.
        graph.add_node(node("test", 7));
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.nodes["test"].line, 7);
    }

    #[test]
    fn test_get_callees() {
        let graph = graph_with(&["main", "helper"], &[("main", "helper")]);

        let callees = graph.get_callees("main");
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0], "helper");
    }

    #[test]
    fn test_get_callers() {
        let graph = graph_with(&["main", "helper"], &[("main", "helper")]);

        assert_eq!(graph.get_callers("helper"), vec!["main"]);
        assert!(graph.get_callers("main").is_empty());
    }

    #[test]
    fn test_find_dead_functions() {
        let graph = graph_with(&["main", "unused"], &[]);

        let dead = graph.find_dead_functions();
        assert!(dead.contains(&"unused".to_string()));
        // `main` is an entry point and is never reported as dead.
        assert!(!dead.contains(&"main".to_string()));
    }

    #[test]
    fn test_mutual_recursion_detected() {
        let graph = graph_with(&["a", "b", "c"], &[("a", "b"), ("b", "a"), ("c", "a")]);

        let recursive = graph.find_recursive_functions();
        assert!(recursive.contains(&"a".to_string()));
        assert!(recursive.contains(&"b".to_string()));
        assert!(!recursive.contains(&"c".to_string()));
    }

    #[test]
    fn test_find_call_chains() {
        let graph = graph_with(&["main", "a", "b"], &[("main", "a"), ("a", "b")]);

        assert_eq!(graph.find_call_chains("main", "b"), vec![vec!["main", "a", "b"]]);
        assert!(graph.find_call_chains("b", "main").is_empty());
    }

    #[test]
    fn test_calculate_call_depth() {
        let graph = graph_with(
            &["main", "a", "b"],
            &[("main", "a"), ("a", "b"), ("b", "main")],
        );

        // The b -> main back edge must not cause unbounded recursion.
        assert_eq!(graph.calculate_call_depth("main"), 2);
    }

    #[test]
    fn test_to_dot_contains_edges() {
        let graph = graph_with(&["main", "helper"], &[("main", "helper")]);

        let dot = graph.to_dot();
        assert!(dot.starts_with("digraph CallGraph {"));
        assert!(dot.contains("\"main\" -> \"helper\";"));
        assert!(dot.trim_end().ends_with('}'));
    }
}