cs/trace/
graph_builder.rs

1use crate::error::Result;
2use crate::trace::{CallExtractor, FunctionDef, FunctionFinder};
3use std::collections::HashSet;
4
/// Which way a call-graph trace walks away from the starting function.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TraceDirection {
    /// Follow outgoing edges: list the functions that this function calls.
    Forward,
    /// Follow incoming edges: list the functions that call this function.
    Backward,
}
13
14/// A node in the call graph tree
15#[derive(Debug, Clone)]
16pub struct CallNode {
17    /// The function definition for this node
18    pub def: FunctionDef,
19    /// Child nodes (called functions or callers, depending on direction)
20    pub children: Vec<CallNode>,
21    /// Whether the tree was truncated at this node due to depth limit
22    pub truncated: bool,
23}
24
25/// Represents a complete call graph tree
26#[derive(Debug, Clone)]
27pub struct CallTree {
28    /// The root node of the tree (the starting function)
29    pub root: CallNode,
30}
31
32/// Builds a call graph by recursively tracing function calls
33pub struct CallGraphBuilder<'a> {
34    direction: TraceDirection,
35    max_depth: usize,
36    finder: &'a FunctionFinder,
37    extractor: &'a CallExtractor,
38}
39
40impl<'a> CallGraphBuilder<'a> {
41    /// Create a new CallGraphBuilder
42    ///
43    /// # Arguments
44    /// * `direction` - Whether to trace forward or backward
45    /// * `max_depth` - Maximum depth of the call tree
46    /// * `finder` - Service to find function definitions
47    /// * `extractor` - Service to extract calls from code
48    pub fn new(
49        direction: TraceDirection,
50        max_depth: usize,
51        finder: &'a FunctionFinder,
52        extractor: &'a CallExtractor,
53    ) -> Self {
54        Self {
55            direction,
56            max_depth,
57            finder,
58            extractor,
59        }
60    }
61
62    /// Build a call trace tree starting from the given function
63    pub fn build_trace(&self, start_fn: &FunctionDef) -> Result<Option<CallTree>> {
64        let mut current_path = HashSet::new();
65
66        match self.build_node(start_fn, 0, &mut current_path) {
67            Some(root) => Ok(Some(CallTree { root })),
68            None => Ok(None),
69        }
70    }
71
72    /// Recursively build a call tree node
73    ///
74    /// Uses proper cycle detection with current_path to prevent infinite recursion
75    /// while still allowing the same function to appear in different branches.
76    fn build_node(
77        &self,
78        func: &FunctionDef,
79        depth: usize,
80        current_path: &mut HashSet<FunctionDef>,
81    ) -> Option<CallNode> {
82        // Check depth limit
83        if depth >= self.max_depth {
84            return Some(CallNode {
85                def: func.clone(),
86                children: vec![],
87                truncated: true,
88            });
89        }
90
91        // Check for cycles in current path (prevents infinite recursion)
92        if current_path.contains(func) {
93            return Some(CallNode {
94                def: func.clone(),
95                children: vec![],
96                truncated: false, // Not truncated by depth, but by cycle
97            });
98        }
99
100        // Add to current path for cycle detection
101        current_path.insert(func.clone());
102
103        let children = match self.direction {
104            TraceDirection::Forward => self.build_forward_children(func, depth, current_path),
105            TraceDirection::Backward => self.build_backward_children(func, depth, current_path),
106        };
107
108        // Remove from current path (allows same function in different branches)
109        current_path.remove(func);
110
111        Some(CallNode {
112            def: func.clone(),
113            children,
114            truncated: false,
115        })
116    }
117
118    /// Build children for forward tracing (what does this function call?)
119    fn build_forward_children(
120        &self,
121        func: &FunctionDef,
122        depth: usize,
123        current_path: &mut HashSet<FunctionDef>,
124    ) -> Vec<CallNode> {
125        // Extract function calls from this function's body
126        let call_names = match self.extractor.extract_calls(func) {
127            Ok(calls) => calls,
128            Err(_) => return vec![], // If extraction fails, return empty children
129        };
130
131        let mut children = Vec::new();
132
133        for call_name in call_names {
134            // Find the definition of the called function
135            if let Some(called_func) = self.finder.find_function(&call_name) {
136                // Recursively build the child node
137                if let Some(child_node) = self.build_node(&called_func, depth + 1, current_path) {
138                    children.push(child_node);
139                }
140            }
141            // If function not found, we simply don't include it (graceful handling)
142        }
143
144        children
145    }
146
147    /// Build children for backward tracing (who calls this function?)
148    fn build_backward_children(
149        &self,
150        func: &FunctionDef,
151        depth: usize,
152        current_path: &mut HashSet<FunctionDef>,
153    ) -> Vec<CallNode> {
154        // Find all functions that call this function
155        let callers = match self.extractor.find_callers(&func.name) {
156            Ok(caller_infos) => caller_infos,
157            Err(_) => return vec![], // If finding callers fails, return empty children
158        };
159
160        let mut children = Vec::new();
161
162        for caller_info in callers {
163            // Try to find the caller function definition
164            if let Some(caller_func) = self.finder.find_function(&caller_info.caller_name) {
165                // Avoid adding the same caller multiple times
166                if !children.iter().any(|child: &CallNode| {
167                    child.def.name == caller_func.name && child.def.file == caller_func.file
168                }) {
169                    // Recursively build the child node
170                    if let Some(child_node) = self.build_node(&caller_func, depth + 1, current_path)
171                    {
172                        children.push(child_node);
173                    }
174                }
175            }
176            // If caller function not found, we simply don't include it (graceful handling)
177        }
178
179        children
180    }
181}
182
183impl CallTree {
184    /// Get the total number of nodes in the tree
185    pub fn node_count(&self) -> usize {
186        self.count_nodes(&self.root)
187    }
188
189    /// Get the maximum depth of the tree
190    pub fn max_depth(&self) -> usize {
191        self.calculate_depth(&self.root, 0)
192    }
193
194    /// Check if the tree contains cycles
195    pub fn has_cycles(&self) -> bool {
196        let mut visited = HashSet::new();
197        let mut path = HashSet::new();
198        self.has_cycle_helper(&self.root, &mut visited, &mut path)
199    }
200
201    fn count_nodes(&self, node: &CallNode) -> usize {
202        1 + node
203            .children
204            .iter()
205            .map(|child| self.count_nodes(child))
206            .sum::<usize>()
207    }
208
209    fn calculate_depth(&self, node: &CallNode, current_depth: usize) -> usize {
210        if node.children.is_empty() {
211            current_depth
212        } else {
213            node.children
214                .iter()
215                .map(|child| self.calculate_depth(child, current_depth + 1))
216                .max()
217                .unwrap_or(current_depth)
218        }
219    }
220
221    fn has_cycle_helper(
222        &self,
223        node: &CallNode,
224        visited: &mut HashSet<FunctionDef>,
225        path: &mut HashSet<FunctionDef>,
226    ) -> bool {
227        if path.contains(&node.def) {
228            return true; // Found a cycle
229        }
230
231        if visited.contains(&node.def) {
232            return false; // Already processed this node
233        }
234
235        visited.insert(node.def.clone());
236        path.insert(node.def.clone());
237
238        for child in &node.children {
239            if self.has_cycle_helper(child, visited, path) {
240                return true;
241            }
242        }
243
244        path.remove(&node.def);
245        false
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Build a `FunctionDef` with a placeholder JS-style body.
    fn create_test_function(name: &str, file: &str, line: usize) -> FunctionDef {
        FunctionDef {
            name: name.to_string(),
            file: PathBuf::from(file),
            line,
            body: format!("function {}() {{}}", name),
        }
    }

    /// Wrap `def` in a non-truncated node with the given children.
    fn make_node(def: FunctionDef, children: Vec<CallNode>) -> CallNode {
        CallNode {
            def,
            children,
            truncated: false,
        }
    }

    /// Real finder/extractor pair rooted at the current working directory.
    fn services() -> (FunctionFinder, CallExtractor) {
        let base_dir = env::current_dir().unwrap();
        (
            FunctionFinder::new(base_dir.clone()),
            CallExtractor::new(base_dir),
        )
    }

    #[test]
    fn test_trace_direction_equality() {
        assert_eq!(TraceDirection::Forward, TraceDirection::Forward);
        assert_eq!(TraceDirection::Backward, TraceDirection::Backward);
        assert_ne!(TraceDirection::Forward, TraceDirection::Backward);
    }

    #[test]
    fn test_call_node_creation() {
        let func = create_test_function("test_func", "test.js", 10);
        let node = make_node(func, vec![]);

        assert_eq!(node.def.name, "test_func");
        assert_eq!(node.children.len(), 0);
        assert!(!node.truncated);
    }

    #[test]
    fn test_call_tree_creation() {
        let func = create_test_function("main", "main.js", 1);
        let tree = CallTree {
            root: make_node(func, vec![]),
        };

        assert_eq!(tree.root.def.name, "main");
    }

    #[test]
    fn test_single_node_tree() {
        let func = create_test_function("main", "main.js", 1);
        let tree = CallTree {
            root: make_node(func, vec![]),
        };

        assert_eq!(tree.node_count(), 1);
        assert_eq!(tree.max_depth(), 0);
        assert!(!tree.has_cycles());
    }

    #[test]
    fn test_call_tree_node_count() {
        let main_func = create_test_function("main", "main.js", 1);
        let helper_func = create_test_function("helper", "utils.js", 5);

        let root = make_node(main_func, vec![make_node(helper_func, vec![])]);
        let tree = CallTree { root };

        assert_eq!(tree.node_count(), 2);
    }

    #[test]
    fn test_call_tree_max_depth() {
        let func1 = create_test_function("func1", "test.js", 1);
        let func2 = create_test_function("func2", "test.js", 10);
        let func3 = create_test_function("func3", "test.js", 20);

        // Chain: func1 -> func2 -> func3
        let node3 = make_node(func3, vec![]);
        let node2 = make_node(func2, vec![node3]);
        let tree = CallTree {
            root: make_node(func1, vec![node2]),
        };

        assert_eq!(tree.max_depth(), 2); // counted in edges; the root is depth 0
    }

    #[test]
    fn test_has_cycles_detects_repeat_on_path() {
        let a = create_test_function("a", "test.js", 1);
        let b = create_test_function("b", "test.js", 10);

        // a -> b -> a
        let tree = CallTree {
            root: make_node(
                a.clone(),
                vec![make_node(b, vec![make_node(a, vec![])])],
            ),
        };

        assert!(tree.has_cycles());
    }

    #[test]
    fn test_has_cycles_allows_repeat_in_sibling_branches() {
        let root = create_test_function("root", "test.js", 1);
        let shared = create_test_function("shared", "test.js", 10);
        let other = create_test_function("other", "test.js", 20);

        // root -> shared; root -> other -> shared (no ancestor repeats)
        let tree = CallTree {
            root: make_node(
                root,
                vec![
                    make_node(shared.clone(), vec![]),
                    make_node(other, vec![make_node(shared, vec![])]),
                ],
            ),
        };

        assert!(!tree.has_cycles());
    }

    #[test]
    fn test_call_graph_builder_creation() {
        let (finder, extractor) = services();

        let builder = CallGraphBuilder::new(TraceDirection::Forward, 5, &finder, &extractor);

        assert_eq!(builder.direction, TraceDirection::Forward);
        assert_eq!(builder.max_depth, 5);
    }

    #[test]
    fn test_depth_limit_handling() {
        let (finder, extractor) = services();

        // A max depth of 0 should return only the root.
        let builder = CallGraphBuilder::new(TraceDirection::Forward, 0, &finder, &extractor);

        let test_func = create_test_function("test", "test.js", 1);
        let mut path = HashSet::new();
        let result = builder.build_node(&test_func, 0, &mut path);

        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.def.name, "test");
        assert_eq!(node.children.len(), 0); // no children because of the depth limit
        assert!(node.truncated); // marked truncated
    }

    #[test]
    fn test_cycle_detection() {
        let (finder, extractor) = services();

        let builder = CallGraphBuilder::new(TraceDirection::Forward, 10, &finder, &extractor);

        let test_func = create_test_function("recursive", "test.js", 1);
        let mut path = HashSet::new();

        // Pre-seed the path so the function looks like its own ancestor.
        path.insert(test_func.clone());

        let result = builder.build_node(&test_func, 0, &mut path);

        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.children.len(), 0); // stops because of the cycle
        assert!(!node.truncated); // a cycle cut is not a depth truncation
    }

    #[test]
    fn test_function_def_equality() {
        let func1 = create_test_function("test", "file.js", 10);
        let func2 = create_test_function("test", "file.js", 10);
        let func3 = create_test_function("test", "file.js", 20);

        assert_eq!(func1, func2);
        assert_ne!(func1, func3); // different line numbers
    }
}