1use crate::error::Result;
2use crate::trace::{CallExtractor, FunctionDef, FunctionFinder};
3use std::collections::HashSet;
4
/// Which way a call trace walks away from the starting function.
///
/// The enum carries no data, so it also derives `Copy`: it can be passed and
/// stored by value without an explicit `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraceDirection {
    /// Follow the calls made by each function (its callees).
    Forward,
    /// Follow the functions that call each function (its callers).
    Backward,
}
13
14#[derive(Debug, Clone)]
16pub struct CallNode {
17 pub def: FunctionDef,
19 pub children: Vec<CallNode>,
21 pub truncated: bool,
23}
24
25#[derive(Debug, Clone)]
27pub struct CallTree {
28 pub root: CallNode,
30}
31
32pub struct CallGraphBuilder<'a> {
34 direction: TraceDirection,
35 max_depth: usize,
36 finder: &'a mut FunctionFinder,
37 extractor: &'a CallExtractor,
38}
39
40impl<'a> CallGraphBuilder<'a> {
41 pub fn new(
49 direction: TraceDirection,
50 max_depth: usize,
51 finder: &'a mut FunctionFinder,
52 extractor: &'a CallExtractor,
53 ) -> Self {
54 Self {
55 direction,
56 max_depth,
57 finder,
58 extractor,
59 }
60 }
61
62 pub fn build_trace(&mut self, start_fn: &FunctionDef) -> Result<Option<CallTree>> {
64 let mut current_path = HashSet::new();
65
66 match self.build_node(start_fn, 0, &mut current_path) {
67 Some(root) => Ok(Some(CallTree { root })),
68 None => Ok(None),
69 }
70 }
71
72 fn build_node(
77 &mut self,
78 func: &FunctionDef,
79 depth: usize,
80 current_path: &mut HashSet<FunctionDef>,
81 ) -> Option<CallNode> {
82 if depth >= self.max_depth {
84 return Some(CallNode {
85 def: func.clone(),
86 children: vec![],
87 truncated: true,
88 });
89 }
90
91 if current_path.contains(func) {
93 return Some(CallNode {
94 def: func.clone(),
95 children: vec![],
96 truncated: false, });
98 }
99
100 current_path.insert(func.clone());
102
103 let children = match self.direction {
104 TraceDirection::Forward => self.build_forward_children(func, depth, current_path),
105 TraceDirection::Backward => self.build_backward_children(func, depth, current_path),
106 };
107
108 current_path.remove(func);
110
111 Some(CallNode {
112 def: func.clone(),
113 children,
114 truncated: false,
115 })
116 }
117
118 fn build_forward_children(
120 &mut self,
121 func: &FunctionDef,
122 depth: usize,
123 current_path: &mut HashSet<FunctionDef>,
124 ) -> Vec<CallNode> {
125 let call_names = match self.extractor.extract_calls(func) {
127 Ok(calls) => calls,
128 Err(_) => return vec![], };
130
131 let mut children = Vec::new();
132
133 for call_name in call_names {
134 if let Some(called_func) = self.finder.find_function(&call_name) {
136 if let Some(child_node) = self.build_node(&called_func, depth + 1, current_path) {
138 children.push(child_node);
139 }
140 }
141 }
143
144 children
145 }
146
147 fn build_backward_children(
149 &mut self,
150 func: &FunctionDef,
151 depth: usize,
152 current_path: &mut HashSet<FunctionDef>,
153 ) -> Vec<CallNode> {
154 let callers = match self.extractor.find_callers(&func.name) {
156 Ok(caller_infos) => caller_infos,
157 Err(_) => return vec![], };
159
160 let mut children = Vec::new();
161
162 for caller_info in callers {
163 if let Some(caller_func) = self.finder.find_function(&caller_info.caller_name) {
165 if !children.iter().any(|child: &CallNode| {
167 child.def.name == caller_func.name && child.def.file == caller_func.file
168 }) {
169 if let Some(child_node) = self.build_node(&caller_func, depth + 1, current_path)
171 {
172 children.push(child_node);
173 }
174 }
175 }
176 }
178
179 children
180 }
181}
182
183impl CallTree {
184 pub fn node_count(&self) -> usize {
186 Self::count_nodes(&self.root)
187 }
188
189 pub fn max_depth(&self) -> usize {
191 Self::calculate_depth(&self.root, 0)
192 }
193
194 pub fn has_cycles(&self) -> bool {
196 let mut visited = HashSet::new();
197 let mut path = HashSet::new();
198 Self::has_cycle_helper(&self.root, &mut visited, &mut path)
199 }
200
201 fn count_nodes(node: &CallNode) -> usize {
202 1 + node.children.iter().map(Self::count_nodes).sum::<usize>()
203 }
204
205 fn calculate_depth(node: &CallNode, current_depth: usize) -> usize {
206 if node.children.is_empty() {
207 current_depth
208 } else {
209 node.children
210 .iter()
211 .map(|child| Self::calculate_depth(child, current_depth + 1))
212 .max()
213 .unwrap_or(current_depth)
214 }
215 }
216
217 fn has_cycle_helper(
218 node: &CallNode,
219 visited: &mut HashSet<FunctionDef>,
220 path: &mut HashSet<FunctionDef>,
221 ) -> bool {
222 if path.contains(&node.def) {
223 return true; }
225
226 if visited.contains(&node.def) {
227 return false; }
229
230 visited.insert(node.def.clone());
231 path.insert(node.def.clone());
232
233 for child in &node.children {
234 if Self::has_cycle_helper(child, visited, path) {
235 return true;
236 }
237 }
238
239 path.remove(&node.def);
240 false
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Builds a `FunctionDef` with a placeholder JavaScript body.
    fn create_test_function(name: &str, file: &str, line: usize) -> FunctionDef {
        FunctionDef {
            name: name.to_string(),
            file: PathBuf::from(file),
            line,
            body: format!("function {}() {{}}", name),
        }
    }

    /// Wraps `def` in a non-truncated node with the given children.
    fn node_with(def: FunctionDef, children: Vec<CallNode>) -> CallNode {
        CallNode {
            def,
            children,
            truncated: false,
        }
    }

    #[test]
    fn test_trace_direction_equality() {
        assert_eq!(TraceDirection::Forward, TraceDirection::Forward);
        assert_eq!(TraceDirection::Backward, TraceDirection::Backward);
        assert_ne!(TraceDirection::Forward, TraceDirection::Backward);
    }

    #[test]
    fn test_call_node_creation() {
        let func = create_test_function("test_func", "test.js", 10);
        let node = CallNode {
            def: func.clone(),
            children: vec![],
            truncated: false,
        };

        assert_eq!(node.def.name, "test_func");
        assert_eq!(node.children.len(), 0);
        assert!(!node.truncated);
    }

    #[test]
    fn test_call_tree_creation() {
        let func = create_test_function("main", "main.js", 1);
        let tree = CallTree {
            root: node_with(func, vec![]),
        };

        assert_eq!(tree.root.def.name, "main");
    }

    #[test]
    fn test_call_tree_node_count() {
        let main_func = create_test_function("main", "main.js", 1);
        let helper_func = create_test_function("helper", "utils.js", 5);

        let tree = CallTree {
            root: node_with(main_func, vec![node_with(helper_func, vec![])]),
        };
        assert_eq!(tree.node_count(), 2);
    }

    #[test]
    fn test_call_tree_max_depth() {
        let func1 = create_test_function("func1", "test.js", 1);
        let func2 = create_test_function("func2", "test.js", 10);
        let func3 = create_test_function("func3", "test.js", 20);

        // func1 -> func2 -> func3: two edges on the longest path.
        let node3 = node_with(func3, vec![]);
        let node2 = node_with(func2, vec![node3]);
        let tree = CallTree {
            root: node_with(func1, vec![node2]),
        };
        assert_eq!(tree.max_depth(), 2);
    }

    #[test]
    fn test_single_node_tree_metrics() {
        let tree = CallTree {
            root: node_with(create_test_function("solo", "solo.js", 1), vec![]),
        };
        assert_eq!(tree.node_count(), 1);
        assert_eq!(tree.max_depth(), 0);
        assert!(!tree.has_cycles());
    }

    #[test]
    fn test_has_cycles_detects_direct_recursion() {
        let recursive = create_test_function("recursive", "test.js", 1);
        let tree = CallTree {
            root: node_with(recursive.clone(), vec![node_with(recursive, vec![])]),
        };
        assert!(tree.has_cycles());
    }

    #[test]
    fn test_shared_callee_in_sibling_branches_is_not_a_cycle() {
        let shared = create_test_function("shared", "utils.js", 5);
        let other = create_test_function("other", "main.js", 10);
        let tree = CallTree {
            root: node_with(
                create_test_function("main", "main.js", 1),
                vec![
                    node_with(shared.clone(), vec![]),
                    node_with(other, vec![node_with(shared, vec![])]),
                ],
            ),
        };
        assert!(!tree.has_cycles());
        assert_eq!(tree.node_count(), 4);
    }

    #[test]
    fn test_call_graph_builder_creation() {
        let base_dir = env::current_dir().unwrap();
        let mut finder = FunctionFinder::new(base_dir.clone());
        let extractor = CallExtractor::new(base_dir);

        let builder = CallGraphBuilder::new(TraceDirection::Forward, 5, &mut finder, &extractor);

        assert_eq!(builder.direction, TraceDirection::Forward);
        assert_eq!(builder.max_depth, 5);
    }

    #[test]
    fn test_depth_limit_handling() {
        let base_dir = env::current_dir().unwrap();
        let mut finder = FunctionFinder::new(base_dir.clone());
        let extractor = CallExtractor::new(base_dir);

        // max_depth 0: even the root is at the limit and must not be expanded.
        let mut builder =
            CallGraphBuilder::new(TraceDirection::Forward, 0, &mut finder, &extractor);

        let test_func = create_test_function("test", "test.js", 1);
        let mut path = HashSet::new();
        let result = builder.build_node(&test_func, 0, &mut path);

        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.def.name, "test");
        assert_eq!(node.children.len(), 0);
        assert!(node.truncated);
        // The depth-limit early return never touches the path.
        assert!(path.is_empty());
    }

    #[test]
    fn test_cycle_detection() {
        let base_dir = env::current_dir().unwrap();
        let mut finder = FunctionFinder::new(base_dir.clone());
        let extractor = CallExtractor::new(base_dir);

        let mut builder =
            CallGraphBuilder::new(TraceDirection::Forward, 10, &mut finder, &extractor);

        let test_func = create_test_function("recursive", "test.js", 1);
        let mut path = HashSet::new();

        // Pretend `recursive` is already an ancestor, as it would be mid-recursion.
        path.insert(test_func.clone());

        let result = builder.build_node(&test_func, 0, &mut path);

        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.def.name, "recursive");
        assert_eq!(node.children.len(), 0);
        // A recursion stop is distinct from a depth-limit stop.
        assert!(!node.truncated);
        // The caller's path is left as it was.
        assert!(path.contains(&test_func));
    }

    #[test]
    fn test_function_def_equality() {
        let func1 = create_test_function("test", "file.js", 10);
        let func2 = create_test_function("test", "file.js", 10);
        let func3 = create_test_function("test", "file.js", 20);

        assert_eq!(func1, func2);
        // Definitions that differ only by line are distinct.
        assert_ne!(func1, func3);
    }
}