cs/trace/
graph_builder.rs

1use crate::error::Result;
2use crate::trace::{CallExtractor, FunctionDef, FunctionFinder};
3use std::collections::HashSet;
4
/// Which way a call-graph trace walks from its starting function.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TraceDirection {
    /// Follow outgoing calls: the functions that the traced function invokes.
    Forward,
    /// Follow incoming calls: the functions that invoke the traced function.
    Backward,
}
13
14/// A node in the call graph tree
15#[derive(Debug, Clone)]
16pub struct CallNode {
17    /// The function definition for this node
18    pub def: FunctionDef,
19    /// Child nodes (called functions or callers, depending on direction)
20    pub children: Vec<CallNode>,
21    /// Whether the tree was truncated at this node due to depth limit
22    pub truncated: bool,
23}
24
25/// Represents a complete call graph tree
26#[derive(Debug, Clone)]
27pub struct CallTree {
28    /// The root node of the tree (the starting function)
29    pub root: CallNode,
30}
31
32/// Builds a call graph by recursively tracing function calls
33pub struct CallGraphBuilder<'a> {
34    direction: TraceDirection,
35    max_depth: usize,
36    finder: &'a FunctionFinder,
37    extractor: &'a CallExtractor,
38}
39
40impl<'a> CallGraphBuilder<'a> {
41    /// Create a new CallGraphBuilder
42    ///
43    /// # Arguments
44    /// * `direction` - Whether to trace forward or backward
45    /// * `max_depth` - Maximum depth of the call tree
46    /// * `finder` - Service to find function definitions
47    /// * `extractor` - Service to extract calls from code
48    pub fn new(
49        direction: TraceDirection,
50        max_depth: usize,
51        finder: &'a FunctionFinder,
52        extractor: &'a CallExtractor,
53    ) -> Self {
54        Self {
55            direction,
56            max_depth,
57            finder,
58            extractor,
59        }
60    }
61
62    /// Build a call trace tree starting from the given function
63    pub fn build_trace(&self, start_fn: &FunctionDef) -> Result<Option<CallTree>> {
64        let mut current_path = HashSet::new();
65
66        match self.build_node(start_fn, 0, &mut current_path) {
67            Some(root) => Ok(Some(CallTree { root })),
68            None => Ok(None),
69        }
70    }
71
72    /// Recursively build a call tree node
73    ///
74    /// Uses proper cycle detection with current_path to prevent infinite recursion
75    /// while still allowing the same function to appear in different branches.
76    fn build_node(
77        &self,
78        func: &FunctionDef,
79        depth: usize,
80        current_path: &mut HashSet<FunctionDef>,
81    ) -> Option<CallNode> {
82        // Check depth limit
83        if depth >= self.max_depth {
84            return Some(CallNode {
85                def: func.clone(),
86                children: vec![],
87                truncated: true,
88            });
89        }
90
91        // Check for cycles in current path (prevents infinite recursion)
92        if current_path.contains(func) {
93            return Some(CallNode {
94                def: func.clone(),
95                children: vec![],
96                truncated: false, // Not truncated by depth, but by cycle
97            });
98        }
99
100        // Add to current path for cycle detection
101        current_path.insert(func.clone());
102
103        let children = match self.direction {
104            TraceDirection::Forward => self.build_forward_children(func, depth, current_path),
105            TraceDirection::Backward => self.build_backward_children(func, depth, current_path),
106        };
107
108        // Remove from current path (allows same function in different branches)
109        current_path.remove(func);
110
111        Some(CallNode {
112            def: func.clone(),
113            children,
114            truncated: false,
115        })
116    }
117
118    /// Build children for forward tracing (what does this function call?)
119    fn build_forward_children(
120        &self,
121        func: &FunctionDef,
122        depth: usize,
123        current_path: &mut HashSet<FunctionDef>,
124    ) -> Vec<CallNode> {
125        // Extract function calls from this function's body
126        let call_names = match self.extractor.extract_calls(func) {
127            Ok(calls) => calls,
128            Err(_) => return vec![], // If extraction fails, return empty children
129        };
130
131        let mut children = Vec::new();
132
133        for call_name in call_names {
134            // Find the definition of the called function
135            if let Some(called_func) = self.finder.find_function(&call_name) {
136                // Recursively build the child node
137                if let Some(child_node) = self.build_node(&called_func, depth + 1, current_path) {
138                    children.push(child_node);
139                }
140            }
141            // If function not found, we simply don't include it (graceful handling)
142        }
143
144        children
145    }
146
147    /// Build children for backward tracing (who calls this function?)
148    fn build_backward_children(
149        &self,
150        func: &FunctionDef,
151        depth: usize,
152        current_path: &mut HashSet<FunctionDef>,
153    ) -> Vec<CallNode> {
154        // Find all functions that call this function
155        let callers = match self.extractor.find_callers(&func.name) {
156            Ok(caller_infos) => caller_infos,
157            Err(_) => return vec![], // If finding callers fails, return empty children
158        };
159
160        let mut children = Vec::new();
161
162        for caller_info in callers {
163            // Try to find the caller function definition
164            if let Some(caller_func) = self.finder.find_function(&caller_info.caller_name) {
165                // Avoid adding the same caller multiple times
166                if !children.iter().any(|child: &CallNode| {
167                    child.def.name == caller_func.name && child.def.file == caller_func.file
168                }) {
169                    // Recursively build the child node
170                    if let Some(child_node) = self.build_node(&caller_func, depth + 1, current_path)
171                    {
172                        children.push(child_node);
173                    }
174                }
175            }
176            // If caller function not found, we simply don't include it (graceful handling)
177        }
178
179        children
180    }
181}
182
183impl CallTree {
184    /// Get the total number of nodes in the tree
185    pub fn node_count(&self) -> usize {
186        Self::count_nodes(&self.root)
187    }
188
189    /// Get the maximum depth of the tree
190    pub fn max_depth(&self) -> usize {
191        Self::calculate_depth(&self.root, 0)
192    }
193
194    /// Check if the tree contains cycles
195    pub fn has_cycles(&self) -> bool {
196        let mut visited = HashSet::new();
197        let mut path = HashSet::new();
198        Self::has_cycle_helper(&self.root, &mut visited, &mut path)
199    }
200
201    fn count_nodes(node: &CallNode) -> usize {
202        1 + node.children.iter().map(Self::count_nodes).sum::<usize>()
203    }
204
205    fn calculate_depth(node: &CallNode, current_depth: usize) -> usize {
206        if node.children.is_empty() {
207            current_depth
208        } else {
209            node.children
210                .iter()
211                .map(|child| Self::calculate_depth(child, current_depth + 1))
212                .max()
213                .unwrap_or(current_depth)
214        }
215    }
216
217    fn has_cycle_helper(
218        node: &CallNode,
219        visited: &mut HashSet<FunctionDef>,
220        path: &mut HashSet<FunctionDef>,
221    ) -> bool {
222        if path.contains(&node.def) {
223            return true; // Found a cycle
224        }
225
226        if visited.contains(&node.def) {
227            return false; // Already processed this node
228        }
229
230        visited.insert(node.def.clone());
231        path.insert(node.def.clone());
232
233        for child in &node.children {
234            if Self::has_cycle_helper(child, visited, path) {
235                return true;
236            }
237        }
238
239        path.remove(&node.def);
240        false
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Fixture: a `FunctionDef` with a trivial JavaScript body.
    fn create_test_function(name: &str, file: &str, line: usize) -> FunctionDef {
        FunctionDef {
            name: String::from(name),
            file: PathBuf::from(file),
            line,
            body: format!("function {}() {{}}", name),
        }
    }

    /// Fixture: a non-truncated node for `def` with the given children.
    fn node(def: FunctionDef, children: Vec<CallNode>) -> CallNode {
        CallNode {
            def,
            children,
            truncated: false,
        }
    }

    /// Fixture: a finder/extractor pair rooted at the current working directory.
    fn services() -> (FunctionFinder, CallExtractor) {
        let base_dir = env::current_dir().unwrap();
        let finder = FunctionFinder::new(base_dir.clone());
        (finder, CallExtractor::new(base_dir))
    }

    #[test]
    fn test_trace_direction_equality() {
        let (forward, backward) = (TraceDirection::Forward, TraceDirection::Backward);
        assert_eq!(forward, TraceDirection::Forward);
        assert_eq!(backward, TraceDirection::Backward);
        assert_ne!(forward, backward);
    }

    #[test]
    fn test_call_node_creation() {
        let leaf = node(create_test_function("test_func", "test.js", 10), Vec::new());

        assert_eq!(leaf.def.name, "test_func");
        assert!(leaf.children.is_empty());
        assert!(!leaf.truncated);
    }

    #[test]
    fn test_call_tree_creation() {
        let tree = CallTree {
            root: node(create_test_function("main", "main.js", 1), Vec::new()),
        };

        assert_eq!(tree.root.def.name, "main");
    }

    #[test]
    fn test_call_tree_node_count() {
        let helper = node(create_test_function("helper", "utils.js", 5), Vec::new());
        let tree = CallTree {
            root: node(create_test_function("main", "main.js", 1), vec![helper]),
        };

        assert_eq!(tree.node_count(), 2);
    }

    #[test]
    fn test_call_tree_max_depth() {
        // Chain func1 -> func2 -> func3, assembled from the innermost node outwards.
        let third = node(create_test_function("func3", "test.js", 20), Vec::new());
        let second = node(create_test_function("func2", "test.js", 10), vec![third]);
        let tree = CallTree {
            root: node(create_test_function("func1", "test.js", 1), vec![second]),
        };

        assert_eq!(tree.max_depth(), 2); // depth counts edges from the root
    }

    #[test]
    fn test_call_graph_builder_creation() {
        let (finder, extractor) = services();
        let builder = CallGraphBuilder::new(TraceDirection::Forward, 5, &finder, &extractor);

        assert_eq!(builder.direction, TraceDirection::Forward);
        assert_eq!(builder.max_depth, 5);
    }

    #[test]
    fn test_depth_limit_handling() {
        let (finder, extractor) = services();
        // A max depth of 0 must yield only the root.
        let builder = CallGraphBuilder::new(TraceDirection::Forward, 0, &finder, &extractor);

        let start = create_test_function("test", "test.js", 1);
        let mut path = HashSet::new();
        let root = builder
            .build_node(&start, 0, &mut path)
            .expect("build_node should return a node");

        assert_eq!(root.def.name, "test");
        assert!(root.children.is_empty()); // cut off by the depth limit
        assert!(root.truncated);
    }

    #[test]
    fn test_cycle_detection() {
        let (finder, extractor) = services();
        let builder = CallGraphBuilder::new(TraceDirection::Forward, 10, &finder, &extractor);

        let recursive = create_test_function("recursive", "test.js", 1);
        // Pre-seed the path so the function looks like its own ancestor.
        let mut path = HashSet::new();
        path.insert(recursive.clone());

        let stopped = builder
            .build_node(&recursive, 0, &mut path)
            .expect("build_node should return a node");

        assert!(stopped.children.is_empty()); // expansion stopped by the cycle check
    }

    #[test]
    fn test_function_def_equality() {
        let reference = create_test_function("test", "file.js", 10);

        assert_eq!(reference, create_test_function("test", "file.js", 10));
        assert_ne!(reference, create_test_function("test", "file.js", 20)); // different line numbers
    }
}