1use crate::error::Result;
2use crate::trace::{CallExtractor, FunctionDef, FunctionFinder};
3use std::collections::HashSet;
4
/// Which way a call trace walks the call graph from its starting function.
///
/// The enum carries no data, so it is `Copy` and can be passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraceDirection {
    /// Follow the calls made *by* each function (its callees).
    Forward,
    /// Follow the calls made *to* each function (its callers).
    Backward,
}
13
14#[derive(Debug, Clone)]
16pub struct CallNode {
17 pub def: FunctionDef,
19 pub children: Vec<CallNode>,
21 pub truncated: bool,
23}
24
25#[derive(Debug, Clone)]
27pub struct CallTree {
28 pub root: CallNode,
30}
31
32pub struct CallGraphBuilder<'a> {
34 direction: TraceDirection,
35 max_depth: usize,
36 finder: &'a FunctionFinder,
37 extractor: &'a CallExtractor,
38}
39
40impl<'a> CallGraphBuilder<'a> {
41 pub fn new(
49 direction: TraceDirection,
50 max_depth: usize,
51 finder: &'a FunctionFinder,
52 extractor: &'a CallExtractor,
53 ) -> Self {
54 Self {
55 direction,
56 max_depth,
57 finder,
58 extractor,
59 }
60 }
61
62 pub fn build_trace(&self, start_fn: &FunctionDef) -> Result<Option<CallTree>> {
64 let mut current_path = HashSet::new();
65
66 match self.build_node(start_fn, 0, &mut current_path) {
67 Some(root) => Ok(Some(CallTree { root })),
68 None => Ok(None),
69 }
70 }
71
72 fn build_node(
77 &self,
78 func: &FunctionDef,
79 depth: usize,
80 current_path: &mut HashSet<FunctionDef>,
81 ) -> Option<CallNode> {
82 if depth >= self.max_depth {
84 return Some(CallNode {
85 def: func.clone(),
86 children: vec![],
87 truncated: true,
88 });
89 }
90
91 if current_path.contains(func) {
93 return Some(CallNode {
94 def: func.clone(),
95 children: vec![],
96 truncated: false, });
98 }
99
100 current_path.insert(func.clone());
102
103 let children = match self.direction {
104 TraceDirection::Forward => self.build_forward_children(func, depth, current_path),
105 TraceDirection::Backward => self.build_backward_children(func, depth, current_path),
106 };
107
108 current_path.remove(func);
110
111 Some(CallNode {
112 def: func.clone(),
113 children,
114 truncated: false,
115 })
116 }
117
118 fn build_forward_children(
120 &self,
121 func: &FunctionDef,
122 depth: usize,
123 current_path: &mut HashSet<FunctionDef>,
124 ) -> Vec<CallNode> {
125 let call_names = match self.extractor.extract_calls(func) {
127 Ok(calls) => calls,
128 Err(_) => return vec![], };
130
131 let mut children = Vec::new();
132
133 for call_name in call_names {
134 if let Some(called_func) = self.finder.find_function(&call_name) {
136 if let Some(child_node) = self.build_node(&called_func, depth + 1, current_path) {
138 children.push(child_node);
139 }
140 }
141 }
143
144 children
145 }
146
147 fn build_backward_children(
149 &self,
150 func: &FunctionDef,
151 depth: usize,
152 current_path: &mut HashSet<FunctionDef>,
153 ) -> Vec<CallNode> {
154 let callers = match self.extractor.find_callers(&func.name) {
156 Ok(caller_infos) => caller_infos,
157 Err(_) => return vec![], };
159
160 let mut children = Vec::new();
161
162 for caller_info in callers {
163 if let Some(caller_func) = self.finder.find_function(&caller_info.caller_name) {
165 if !children.iter().any(|child: &CallNode| {
167 child.def.name == caller_func.name && child.def.file == caller_func.file
168 }) {
169 if let Some(child_node) = self.build_node(&caller_func, depth + 1, current_path)
171 {
172 children.push(child_node);
173 }
174 }
175 }
176 }
178
179 children
180 }
181}
182
183impl CallTree {
184 pub fn node_count(&self) -> usize {
186 self.count_nodes(&self.root)
187 }
188
189 pub fn max_depth(&self) -> usize {
191 self.calculate_depth(&self.root, 0)
192 }
193
194 pub fn has_cycles(&self) -> bool {
196 let mut visited = HashSet::new();
197 let mut path = HashSet::new();
198 self.has_cycle_helper(&self.root, &mut visited, &mut path)
199 }
200
201 fn count_nodes(&self, node: &CallNode) -> usize {
202 1 + node
203 .children
204 .iter()
205 .map(|child| self.count_nodes(child))
206 .sum::<usize>()
207 }
208
209 fn calculate_depth(&self, node: &CallNode, current_depth: usize) -> usize {
210 if node.children.is_empty() {
211 current_depth
212 } else {
213 node.children
214 .iter()
215 .map(|child| self.calculate_depth(child, current_depth + 1))
216 .max()
217 .unwrap_or(current_depth)
218 }
219 }
220
221 fn has_cycle_helper(
222 &self,
223 node: &CallNode,
224 visited: &mut HashSet<FunctionDef>,
225 path: &mut HashSet<FunctionDef>,
226 ) -> bool {
227 if path.contains(&node.def) {
228 return true; }
230
231 if visited.contains(&node.def) {
232 return false; }
234
235 visited.insert(node.def.clone());
236 path.insert(node.def.clone());
237
238 for child in &node.children {
239 if self.has_cycle_helper(child, visited, path) {
240 return true;
241 }
242 }
243
244 path.remove(&node.def);
245 false
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Builds a `FunctionDef` with a trivial JS-style body for use in tests.
    fn create_test_function(name: &str, file: &str, line: usize) -> FunctionDef {
        FunctionDef {
            name: name.to_string(),
            file: PathBuf::from(file),
            line,
            body: format!("function {}() {{}}", name),
        }
    }

    /// Creates a finder/extractor pair rooted at the current directory.
    fn test_tools() -> (FunctionFinder, CallExtractor) {
        let base_dir = env::current_dir().unwrap();
        (FunctionFinder::new(base_dir.clone()), CallExtractor::new(base_dir))
    }

    /// Wraps `def` in a non-truncated node with the given children.
    fn node(def: FunctionDef, children: Vec<CallNode>) -> CallNode {
        CallNode {
            def,
            children,
            truncated: false,
        }
    }

    #[test]
    fn test_trace_direction_equality() {
        assert_eq!(TraceDirection::Forward, TraceDirection::Forward);
        assert_eq!(TraceDirection::Backward, TraceDirection::Backward);
        assert_ne!(TraceDirection::Forward, TraceDirection::Backward);
    }

    #[test]
    fn test_call_node_creation() {
        let func = create_test_function("test_func", "test.js", 10);
        let node = CallNode {
            def: func,
            children: vec![],
            truncated: false,
        };

        assert_eq!(node.def.name, "test_func");
        assert_eq!(node.children.len(), 0);
        assert!(!node.truncated);
    }

    #[test]
    fn test_call_tree_creation() {
        let func = create_test_function("main", "main.js", 1);
        let tree = CallTree {
            root: node(func, vec![]),
        };

        assert_eq!(tree.root.def.name, "main");
    }

    #[test]
    fn test_call_tree_node_count() {
        let main_func = create_test_function("main", "main.js", 1);
        let helper_func = create_test_function("helper", "utils.js", 5);

        let tree = CallTree {
            root: node(main_func, vec![node(helper_func, vec![])]),
        };
        assert_eq!(tree.node_count(), 2);
    }

    #[test]
    fn test_call_tree_max_depth() {
        let func1 = create_test_function("func1", "test.js", 1);
        let func2 = create_test_function("func2", "test.js", 10);
        let func3 = create_test_function("func3", "test.js", 20);

        let tree = CallTree {
            root: node(func1, vec![node(func2, vec![node(func3, vec![])])]),
        };
        // Depth counts edges: func1 -> func2 -> func3.
        assert_eq!(tree.max_depth(), 2);
    }

    #[test]
    fn test_single_node_tree_metrics() {
        let tree = CallTree {
            root: node(create_test_function("main", "main.js", 1), vec![]),
        };
        assert_eq!(tree.node_count(), 1);
        assert_eq!(tree.max_depth(), 0);
        assert!(!tree.has_cycles());
    }

    #[test]
    fn test_has_cycles_detects_recursion() {
        // main -> helper -> main
        let main_func = create_test_function("main", "main.js", 1);
        let helper_func = create_test_function("helper", "utils.js", 5);

        let tree = CallTree {
            root: node(
                main_func.clone(),
                vec![node(helper_func, vec![node(main_func, vec![])])],
            ),
        };
        assert!(tree.has_cycles());
    }

    #[test]
    fn test_has_cycles_ignores_repeated_siblings() {
        // main calls `a` from two places; repetition across branches is not recursion.
        let main_func = create_test_function("main", "main.js", 1);
        let a_func = create_test_function("a", "a.js", 3);

        let tree = CallTree {
            root: node(
                main_func,
                vec![node(a_func.clone(), vec![]), node(a_func, vec![])],
            ),
        };
        assert!(!tree.has_cycles());
    }

    #[test]
    fn test_call_graph_builder_creation() {
        let (finder, extractor) = test_tools();

        let builder = CallGraphBuilder::new(TraceDirection::Forward, 5, &finder, &extractor);

        assert_eq!(builder.direction, TraceDirection::Forward);
        assert_eq!(builder.max_depth, 5);
    }

    #[test]
    fn test_depth_limit_handling() {
        let (finder, extractor) = test_tools();

        // A limit of 0 means even the starting function is cut off.
        let builder = CallGraphBuilder::new(TraceDirection::Forward, 0, &finder, &extractor);

        let test_func = create_test_function("test", "test.js", 1);
        let mut path = HashSet::new();
        let result = builder.build_node(&test_func, 0, &mut path);

        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.def.name, "test");
        assert_eq!(node.children.len(), 0);
        assert!(node.truncated);
    }

    #[test]
    fn test_cycle_detection() {
        let (finder, extractor) = test_tools();

        let builder = CallGraphBuilder::new(TraceDirection::Forward, 10, &finder, &extractor);

        let test_func = create_test_function("recursive", "test.js", 1);
        let mut path = HashSet::new();

        // Pretend `test_func` is already on the current path.
        path.insert(test_func.clone());

        let result = builder.build_node(&test_func, 0, &mut path);

        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.def, test_func);
        assert_eq!(node.children.len(), 0);
        // A recursion leaf is not a depth cut-off.
        assert!(!node.truncated);
    }

    #[test]
    fn test_function_def_equality() {
        let func1 = create_test_function("test", "file.js", 10);
        let func2 = create_test_function("test", "file.js", 10);
        let func3 = create_test_function("test", "file.js", 20);

        assert_eq!(func1, func2);
        // Same name and file but a different line is a different definition.
        assert_ne!(func1, func3);
    }
}