use crate::error::Result;
use crate::trace::{CallExtractor, FunctionDef, FunctionFinder};
use std::collections::HashSet;
/// Direction in which [`CallGraphBuilder`] expands the call graph.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TraceDirection {
    /// Expand the calls made *by* each function (`CallExtractor::extract_calls`).
    Forward,
    /// Expand the functions that *call* each function (`CallExtractor::find_callers`).
    Backward,
}
/// One function in a call tree, together with the nodes expanded beneath it.
#[derive(Debug, Clone)]
pub struct CallNode {
    /// The function this node represents.
    pub def: FunctionDef,
    /// Callees (forward trace) or callers (backward trace) of `def`.
    pub children: Vec<CallNode>,
    /// `true` when expansion stopped because the builder's depth limit was
    /// reached; such nodes have no children. A node cut off because its
    /// function was already on the current path (recursion) is also a
    /// childless leaf, but has `truncated == false`.
    pub truncated: bool,
}
/// A call tree produced by [`CallGraphBuilder::build_trace`].
#[derive(Debug, Clone)]
pub struct CallTree {
    /// Node for the function the trace started from.
    pub root: CallNode,
}
/// Builds a [`CallTree`] by repeatedly resolving the calls (or callers) of a
/// starting function, up to a fixed depth.
pub struct CallGraphBuilder<'a> {
    /// Whether callees (`Forward`) or callers (`Backward`) are expanded.
    direction: TraceDirection,
    /// Nodes at this depth (the root is depth 0) are emitted as truncated
    /// leaves instead of being expanded.
    max_depth: usize,
    /// Resolves a call/caller name to a function definition via `find_function`.
    finder: &'a mut FunctionFinder,
    /// Supplies the calls made by a function (`extract_calls`) and the
    /// callers of a name (`find_callers`).
    extractor: &'a CallExtractor,
}
impl<'a> CallGraphBuilder<'a> {
pub fn new(
direction: TraceDirection,
max_depth: usize,
finder: &'a mut FunctionFinder,
extractor: &'a CallExtractor,
) -> Self {
Self {
direction,
max_depth,
finder,
extractor,
}
}
pub fn build_trace(&mut self, start_fn: &FunctionDef) -> Result<Option<CallTree>> {
let mut current_path = HashSet::new();
match self.build_node(start_fn, 0, &mut current_path) {
Some(root) => Ok(Some(CallTree { root })),
None => Ok(None),
}
}
fn build_node(
&mut self,
func: &FunctionDef,
depth: usize,
current_path: &mut HashSet<FunctionDef>,
) -> Option<CallNode> {
if depth >= self.max_depth {
return Some(CallNode {
def: func.clone(),
children: vec![],
truncated: true,
});
}
if current_path.contains(func) {
return Some(CallNode {
def: func.clone(),
children: vec![],
truncated: false, });
}
current_path.insert(func.clone());
let children = match self.direction {
TraceDirection::Forward => self.build_forward_children(func, depth, current_path),
TraceDirection::Backward => self.build_backward_children(func, depth, current_path),
};
current_path.remove(func);
Some(CallNode {
def: func.clone(),
children,
truncated: false,
})
}
fn build_forward_children(
&mut self,
func: &FunctionDef,
depth: usize,
current_path: &mut HashSet<FunctionDef>,
) -> Vec<CallNode> {
let call_names = match self.extractor.extract_calls(func) {
Ok(calls) => calls,
Err(_) => return vec![], };
let mut children = Vec::new();
for call_name in call_names {
if let Some(called_func) = self.finder.find_function(&call_name) {
if let Some(child_node) = self.build_node(&called_func, depth + 1, current_path) {
children.push(child_node);
}
}
}
children
}
fn build_backward_children(
&mut self,
func: &FunctionDef,
depth: usize,
current_path: &mut HashSet<FunctionDef>,
) -> Vec<CallNode> {
let callers = match self.extractor.find_callers(&func.name) {
Ok(caller_infos) => caller_infos,
Err(_) => return vec![], };
let mut children = Vec::new();
for caller_info in callers {
if let Some(caller_func) = self.finder.find_function(&caller_info.caller_name) {
if !children.iter().any(|child: &CallNode| {
child.def.name == caller_func.name && child.def.file == caller_func.file
}) {
if let Some(child_node) = self.build_node(&caller_func, depth + 1, current_path)
{
children.push(child_node);
}
}
}
}
children
}
}
impl CallTree {
    /// Returns the total number of nodes in the tree, root included.
    pub fn node_count(&self) -> usize {
        Self::count_nodes(&self.root)
    }

    /// Returns the number of edges on the longest root-to-leaf path.
    ///
    /// A tree consisting only of its root has depth `0`.
    pub fn max_depth(&self) -> usize {
        Self::calculate_depth(&self.root, 0)
    }

    /// Returns `true` if some function definition occurs twice on a single
    /// root-to-leaf path.
    ///
    /// The builder stops expanding a function that is already on the current
    /// path and emits it as a leaf, so recursion appears in the tree as a
    /// definition repeated beneath itself.
    pub fn has_cycles(&self) -> bool {
        let mut path = HashSet::new();
        Self::has_cycle_helper(&self.root, &mut path)
    }

    /// Recursively counts `node` and all of its descendants.
    fn count_nodes(node: &CallNode) -> usize {
        1 + node.children.iter().map(Self::count_nodes).sum::<usize>()
    }

    /// Returns the deepest level reached at or below `node`, where `node`
    /// itself sits at `current_depth`.
    fn calculate_depth(node: &CallNode, current_depth: usize) -> usize {
        // For a leaf, `max()` is `None` and we fall back to the node's own depth.
        node.children
            .iter()
            .map(|child| Self::calculate_depth(child, current_depth + 1))
            .max()
            .unwrap_or(current_depth)
    }

    /// Depth-first search for a definition repeated along the current path.
    ///
    /// `path` holds the definitions of `node`'s ancestors.
    fn has_cycle_helper(node: &CallNode, path: &mut HashSet<FunctionDef>) -> bool {
        // `insert` returns false when the definition is already an ancestor.
        if !path.insert(node.def.clone()) {
            return true;
        }
        // Deliberately no global "visited" set. This is a tree, and the same
        // definition can have different subtrees in different branches (one
        // copy may be depth-truncated, another fully expanded with recursion).
        // Skipping a previously seen definition would therefore hide cycles.
        // Each tree node is still visited at most once.
        let found = node
            .children
            .iter()
            .any(|child| Self::has_cycle_helper(child, path));
        path.remove(&node.def);
        found
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a `FunctionDef` with a trivial JS-looking body for tests.
    fn create_test_function(name: &str, file: &str, line: usize) -> FunctionDef {
        FunctionDef {
            name: name.to_string(),
            file: PathBuf::from(file),
            line,
            body: format!("function {}() {{}}", name),
        }
    }

    /// Wraps `def` in a childless, untruncated node.
    fn leaf(def: FunctionDef) -> CallNode {
        CallNode {
            def,
            children: vec![],
            truncated: false,
        }
    }

    #[test]
    fn test_trace_direction_equality() {
        assert_eq!(TraceDirection::Forward, TraceDirection::Forward);
        assert_eq!(TraceDirection::Backward, TraceDirection::Backward);
        assert_ne!(TraceDirection::Forward, TraceDirection::Backward);
    }

    #[test]
    fn test_call_node_creation() {
        let func = create_test_function("test_func", "test.js", 10);
        let node = CallNode {
            def: func.clone(),
            children: vec![],
            truncated: false,
        };
        assert_eq!(node.def.name, "test_func");
        assert_eq!(node.children.len(), 0);
        assert!(!node.truncated);
    }

    #[test]
    fn test_call_tree_creation() {
        let func = create_test_function("main", "main.js", 1);
        let root = CallNode {
            def: func,
            children: vec![],
            truncated: false,
        };
        let tree = CallTree { root };
        assert_eq!(tree.root.def.name, "main");
    }

    #[test]
    fn test_call_tree_node_count() {
        let main_func = create_test_function("main", "main.js", 1);
        let helper_func = create_test_function("helper", "utils.js", 5);
        let helper_node = CallNode {
            def: helper_func,
            children: vec![],
            truncated: false,
        };
        let root = CallNode {
            def: main_func,
            children: vec![helper_node],
            truncated: false,
        };
        let tree = CallTree { root };
        assert_eq!(tree.node_count(), 2);
    }

    #[test]
    fn test_call_tree_max_depth() {
        let func1 = create_test_function("func1", "test.js", 1);
        let func2 = create_test_function("func2", "test.js", 10);
        let func3 = create_test_function("func3", "test.js", 20);
        let node3 = CallNode {
            def: func3,
            children: vec![],
            truncated: false,
        };
        let node2 = CallNode {
            def: func2,
            children: vec![node3],
            truncated: false,
        };
        let root = CallNode {
            def: func1,
            children: vec![node2],
            truncated: false,
        };
        let tree = CallTree { root };
        // Three nodes in a chain: two edges from root to deepest leaf.
        assert_eq!(tree.max_depth(), 2);
    }

    #[test]
    fn test_has_cycles_false_for_acyclic_tree() {
        let root = CallNode {
            def: create_test_function("main", "main.js", 1),
            children: vec![
                leaf(create_test_function("a", "main.js", 5)),
                leaf(create_test_function("b", "main.js", 9)),
            ],
            truncated: false,
        };
        let tree = CallTree { root };
        assert!(!tree.has_cycles());
        assert_eq!(tree.node_count(), 3);
        assert_eq!(tree.max_depth(), 1);
    }

    #[test]
    fn test_has_cycles_detects_direct_recursion() {
        let recursive = create_test_function("recurse", "main.js", 1);
        let root = CallNode {
            def: recursive.clone(),
            children: vec![leaf(recursive)],
            truncated: false,
        };
        assert!(CallTree { root }.has_cycles());
    }

    #[test]
    fn test_shared_callee_in_sibling_branches_is_not_a_cycle() {
        let shared = create_test_function("shared", "util.js", 3);
        let root = CallNode {
            def: create_test_function("main", "main.js", 1),
            children: vec![
                CallNode {
                    def: create_test_function("a", "main.js", 5),
                    children: vec![leaf(shared.clone())],
                    truncated: false,
                },
                CallNode {
                    def: create_test_function("b", "main.js", 9),
                    children: vec![leaf(shared)],
                    truncated: false,
                },
            ],
            truncated: false,
        };
        assert!(!CallTree { root }.has_cycles());
    }

    #[test]
    fn test_call_graph_builder_creation() {
        use crate::trace::{CallExtractor, FunctionFinder};
        use std::env;
        let base_dir = env::current_dir().unwrap();
        let mut finder = FunctionFinder::new(base_dir.clone());
        let extractor = CallExtractor::new(base_dir);
        let builder = CallGraphBuilder::new(TraceDirection::Forward, 5, &mut finder, &extractor);
        assert_eq!(builder.direction, TraceDirection::Forward);
        assert_eq!(builder.max_depth, 5);
    }

    #[test]
    fn test_depth_limit_handling() {
        use crate::trace::{CallExtractor, FunctionFinder};
        use std::env;
        let base_dir = env::current_dir().unwrap();
        let mut finder = FunctionFinder::new(base_dir.clone());
        let extractor = CallExtractor::new(base_dir);
        // A depth limit of 0 truncates the root itself.
        let mut builder = CallGraphBuilder::new(TraceDirection::Forward, 0, &mut finder, &extractor);
        let test_func = create_test_function("test", "test.js", 1);
        let mut path = HashSet::new();
        let result = builder.build_node(&test_func, 0, &mut path);
        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.def.name, "test");
        assert_eq!(node.children.len(), 0);
        assert!(node.truncated);
    }

    #[test]
    fn test_cycle_detection() {
        use crate::trace::{CallExtractor, FunctionFinder};
        use std::env;
        let base_dir = env::current_dir().unwrap();
        let mut finder = FunctionFinder::new(base_dir.clone());
        let extractor = CallExtractor::new(base_dir);
        let mut builder =
            CallGraphBuilder::new(TraceDirection::Forward, 10, &mut finder, &extractor);
        let test_func = create_test_function("recursive", "test.js", 1);
        // Pre-seed the path so the function is treated as already being expanded.
        let mut path = HashSet::new();
        path.insert(test_func.clone());
        let result = builder.build_node(&test_func, 0, &mut path);
        assert!(result.is_some());
        let node = result.unwrap();
        assert_eq!(node.children.len(), 0);
    }

    #[test]
    fn test_function_def_equality() {
        let func1 = create_test_function("test", "file.js", 10);
        let func2 = create_test_function("test", "file.js", 10);
        let func3 = create_test_function("test", "file.js", 20);
        assert_eq!(func1, func2);
        // Same name and file but a different line is a different definition.
        assert_ne!(func1, func3);
    }
}