use super::types::{CallGraph, CallType, FunctionCall, FunctionId, FunctionNode};
use im::{HashMap, HashSet, Vector};
use std::path::PathBuf;
impl CallGraph {
pub fn merge(&mut self, other: CallGraph) {
let mut sorted_nodes: Vec<_> = other.nodes.into_iter().collect();
sorted_nodes.sort_by(|a, b| a.0.cmp(&b.0));
for (id, node) in sorted_nodes {
self.add_function(
id,
node.is_entry_point,
node.is_test,
node.complexity,
node._lines,
);
}
for call in other.edges {
self.add_call(call);
}
}
pub fn add_function(
&mut self,
id: FunctionId,
is_entry_point: bool,
is_test: bool,
complexity: u32,
lines: usize,
) {
let node = FunctionNode {
id: id.clone(),
is_entry_point,
is_test,
complexity,
_lines: lines,
};
self.nodes.insert(id.clone(), node);
let fuzzy_key = id.fuzzy_key();
self.fuzzy_index
.entry(fuzzy_key)
.or_default()
.push(id.clone());
let normalized_name = FunctionId::normalize_name(&id.name);
self.name_index.entry(normalized_name).or_default().push(id);
}
pub fn add_call(&mut self, call: FunctionCall) {
let caller = call.caller.clone();
let callee = call.callee.clone();
self.edges.push_back(call);
self.callee_index
.entry(caller.clone())
.or_default()
.insert(callee.clone());
self.caller_index.entry(callee).or_default().insert(caller);
}
pub fn add_call_parts(&mut self, caller: FunctionId, callee: FunctionId, call_type: CallType) {
self.add_call(FunctionCall {
caller,
callee,
call_type,
});
}
pub fn get_callees(&self, func_id: &FunctionId) -> Vec<FunctionId> {
let canonical_func_id = self
.find_function(func_id)
.unwrap_or_else(|| func_id.clone());
let mut callees: Vec<FunctionId> = self
.callee_index
.get(&canonical_func_id)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
callees.sort();
callees
}
pub fn get_callees_exact(&self, func_id: &FunctionId) -> Vec<FunctionId> {
let mut callees: Vec<FunctionId> = self
.callee_index
.get(func_id)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
callees.sort();
callees
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn get_callers(&self, func_id: &FunctionId) -> Vec<FunctionId> {
let canonical_func_id = self
.find_function(func_id)
.unwrap_or_else(|| func_id.clone());
let mut callers: Vec<FunctionId> = self
.caller_index
.get(&canonical_func_id)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
callers.sort();
callers
}
pub fn get_callers_exact(&self, func_id: &FunctionId) -> Vec<FunctionId> {
let mut callers: Vec<FunctionId> = self
.caller_index
.get(func_id)
.map(|set| set.iter().cloned().collect())
.unwrap_or_default();
callers.sort();
callers
}
pub fn get_dependency_count(&self, func_id: &FunctionId) -> usize {
self.get_callers(func_id).len()
}
pub fn get_all_functions(&self) -> impl Iterator<Item = &FunctionId> {
let mut funcs: Vec<&FunctionId> = self.nodes.keys().collect();
funcs.sort();
funcs.into_iter()
}
pub fn get_function_info(&self, func_id: &FunctionId) -> Option<(bool, bool, u32, usize)> {
self.nodes.get(func_id).map(|node| {
(
node.is_entry_point,
node.is_test,
node.complexity,
node._lines,
)
})
}
pub fn mark_as_trait_dispatch(&mut self, func_id: FunctionId) {
if !self.nodes.contains_key(&func_id) {
self.add_function(func_id.clone(), false, false, 0, 0);
}
if let Some(node) = self.nodes.get_mut(&func_id) {
node.is_entry_point = true;
}
}
pub fn is_entry_point(&self, func_id: &FunctionId) -> bool {
self.nodes
.get(func_id)
.map(|n| n.is_entry_point)
.unwrap_or(false)
}
pub fn is_test_function(&self, func_id: &FunctionId) -> bool {
self.nodes.get(func_id).map(|n| n.is_test).unwrap_or(false)
}
pub fn has_test_function_named(&self, name: &str) -> bool {
let normalized_name = FunctionId::normalize_name(name);
if self
.name_index
.get(&normalized_name)
.is_some_and(|candidates| candidates.iter().any(|id| self.is_test_function(id)))
{
return true;
}
let suffix_pattern = format!("::{}", normalized_name);
self.nodes
.keys()
.any(|id| id.name.ends_with(&suffix_pattern) && self.is_test_function(id))
}
pub fn add_edge_by_name(&mut self, from: String, to: String, file: PathBuf) {
let from_id = FunctionId::new(
file.clone(),
from,
0, );
let to_id = FunctionId::new(file.clone(), to, 0);
if !self.nodes.contains_key(&from_id) {
self.add_function(from_id.clone(), false, false, 0, 0);
}
if !self.nodes.contains_key(&to_id) {
self.add_function(to_id.clone(), false, false, 0, 0);
}
self.add_call(FunctionCall {
caller: from_id,
callee: to_id,
call_type: CallType::Direct,
});
}
pub fn get_callees_by_name(&self, function: &str) -> Vec<String> {
let mut results: Vec<String> = self
.nodes
.keys()
.filter(|id| id.name == function)
.flat_map(|id| self.get_callees(id))
.map(|id| id.name.clone())
.collect::<HashSet<_>>()
.into_iter()
.collect();
results.sort();
results
}
pub fn get_callers_by_name(&self, function: &str) -> Vec<String> {
let mut results: Vec<String> = self
.nodes
.keys()
.filter(|id| id.name == function)
.flat_map(|id| self.get_callers(id))
.map(|id| id.name.clone())
.collect::<HashSet<_>>()
.into_iter()
.collect();
results.sort();
results
}
pub fn get_transitive_callees(
&self,
func_id: &FunctionId,
max_depth: usize,
) -> HashSet<FunctionId> {
let mut visited = HashSet::new();
let mut to_visit = Vector::new();
to_visit.push_back((func_id.clone(), 0));
while let Some((current, depth)) = to_visit.pop_front() {
if depth >= max_depth || visited.contains(¤t) {
continue;
}
visited.insert(current.clone());
for callee in self.get_callees(¤t) {
if !visited.contains(&callee) {
to_visit.push_back((callee, depth + 1));
}
}
}
visited.remove(func_id);
visited
}
pub fn get_transitive_callers(
&self,
func_id: &FunctionId,
max_depth: usize,
) -> HashSet<FunctionId> {
let mut visited = HashSet::new();
let mut to_visit = Vector::new();
to_visit.push_back((func_id.clone(), 0));
while let Some((current, depth)) = to_visit.pop_front() {
if visited.contains(¤t) {
continue;
}
visited.insert(current.clone());
if depth < max_depth {
for caller in self.get_callers(¤t) {
if !visited.contains(&caller) {
to_visit.push_back((caller, depth + 1));
}
}
}
}
visited.remove(func_id);
visited
}
pub fn find_entry_points(&self) -> Vec<FunctionId> {
let mut results: Vec<FunctionId> = self
.nodes
.values()
.filter(|node| node.is_entry_point)
.map(|node| node.id.clone())
.collect();
results.sort();
results
}
pub fn find_all_functions(&self) -> Vec<FunctionId> {
let mut results: Vec<FunctionId> = self.nodes.keys().cloned().collect();
results.sort();
results
}
pub fn get_functions_by_file(&self, file: &PathBuf) -> Vec<FunctionId> {
let mut results: Vec<FunctionId> = self
.nodes
.keys()
.filter(|id| &id.file == file)
.cloned()
.collect();
results.sort();
results
}
pub fn get_functions_by_name(&self, name: &str) -> Vec<FunctionId> {
let mut results: Vec<FunctionId> = self
.nodes
.keys()
.filter(|id| id.name == name)
.cloned()
.collect();
results.sort();
results
}
pub fn get_functions_with_no_callers(&self) -> Vec<FunctionId> {
let mut results: Vec<FunctionId> = self
.nodes
.keys()
.filter(|id| {
!self.caller_index.contains_key(id)
|| self.caller_index.get(id).is_none_or(|set| set.is_empty())
})
.cloned()
.collect();
results.sort();
results
}
pub fn get_function_calls(&self, func_id: &FunctionId) -> Vec<FunctionCall> {
let mut results: Vec<FunctionCall> = self
.edges
.iter()
.filter(|call| &call.caller == func_id)
.cloned()
.collect();
results.sort();
results
}
pub fn get_all_calls(&self) -> Vec<FunctionCall> {
let mut results: Vec<FunctionCall> = self.edges.iter().cloned().collect();
results.sort();
results
}
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
pub fn find_function(&self, query: &FunctionId) -> Option<FunctionId> {
if self.nodes.contains_key(query) {
return Some(query.clone());
}
let fuzzy_key = query.fuzzy_key();
if let Some(candidates) = self.fuzzy_index.get(&fuzzy_key) {
if candidates.len() == 1 {
return Some(candidates[0].clone());
}
if let Some(best) = Self::disambiguate_by_line(candidates, query.line) {
return Some(best);
}
}
let normalized_name = FunctionId::normalize_name(&query.name);
if let Some(candidates) = self.name_index.get(&normalized_name) {
if let Some(best) = Self::disambiguate_by_module(candidates, &query.module_path) {
return Some(best);
}
if let Some(best) = Self::disambiguate_by_line(candidates, query.line) {
return Some(best);
}
}
None
}
fn disambiguate_by_line(candidates: &[FunctionId], target_line: usize) -> Option<FunctionId> {
candidates
.iter()
.min_by_key(|func_id| target_line.abs_diff(func_id.line))
.cloned()
}
fn disambiguate_by_module(
candidates: &[FunctionId],
target_module: &str,
) -> Option<FunctionId> {
candidates
.iter()
.find(|func_id| func_id.module_path == target_module)
.cloned()
}
pub fn find_function_at_location(&self, file: &PathBuf, line: usize) -> Option<FunctionId> {
let functions_in_file = Self::functions_in_file(&self.nodes, file);
Self::find_best_line_match(&functions_in_file, line)
}
pub fn functions_in_file<'a>(
nodes: &'a HashMap<FunctionId, FunctionNode>,
file: &PathBuf,
) -> Vec<&'a FunctionId> {
let mut results: Vec<&'a FunctionId> = nodes.keys().filter(|id| &id.file == file).collect();
results.sort();
results
}
pub fn find_best_line_match(
functions: &[&FunctionId],
target_line: usize,
) -> Option<FunctionId> {
functions
.iter()
.filter(|func_id| func_id.line <= target_line)
.min_by_key(|func_id| target_line - func_id.line)
.map(|&func_id| func_id.clone())
}
pub fn is_recursive(&self, func_id: &FunctionId) -> bool {
let mut visited = HashSet::new();
let mut rec_stack = HashSet::new();
self.has_cycle_dfs_iterative(func_id, &mut visited, &mut rec_stack)
}
fn has_cycle_dfs_iterative(
&self,
start: &FunctionId,
visited: &mut HashSet<FunctionId>,
rec_stack: &mut HashSet<FunctionId>,
) -> bool {
enum CycleState {
Enter(FunctionId),
Exit(FunctionId),
}
let mut stack = Vec::with_capacity(self.nodes.len().min(1024));
stack.push(CycleState::Enter(start.clone()));
while let Some(state) = stack.pop() {
match state {
CycleState::Enter(node) => {
if visited.contains(&node) {
continue;
}
visited.insert(node.clone());
rec_stack.insert(node.clone());
stack.push(CycleState::Exit(node.clone()));
for callee in self.get_callees(&node) {
if rec_stack.contains(&callee) {
return true; }
if !visited.contains(&callee) {
stack.push(CycleState::Enter(callee));
}
}
}
CycleState::Exit(node) => {
rec_stack.remove(&node);
}
}
}
false
}
pub fn topological_sort(&self) -> Result<Vec<FunctionId>, String> {
let mut visited = HashSet::new();
let mut result = Vector::new();
let mut nodes: Vec<_> = self.nodes.keys().collect();
nodes.sort();
for func_id in nodes {
if !visited.contains(func_id) {
self.topo_sort_dfs_iterative(func_id, &mut visited, &mut result);
}
}
Ok(result.iter().cloned().collect())
}
fn topo_sort_dfs_iterative(
&self,
start: &FunctionId,
visited: &mut HashSet<FunctionId>,
result: &mut Vector<FunctionId>,
) {
enum TopoState {
Visit(FunctionId),
Finish(FunctionId),
}
let mut stack = Vec::with_capacity(self.nodes.len().min(1024));
stack.push(TopoState::Visit(start.clone()));
while let Some(state) = stack.pop() {
match state {
TopoState::Visit(node) => {
if visited.contains(&node) {
continue;
}
visited.insert(node.clone());
stack.push(TopoState::Finish(node.clone()));
for callee in self.get_callees(&node) {
if !visited.contains(&callee) {
stack.push(TopoState::Visit(callee));
}
}
}
TopoState::Finish(node) => {
result.push_back(node);
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `FunctionId` without an explicit module path.
    fn fid(file: &str, name: &str, line: usize) -> FunctionId {
        FunctionId::new(PathBuf::from(file), name.to_string(), line)
    }

    /// Builds a `FunctionId` that carries an explicit module path.
    fn fid_in(file: &str, name: &str, line: usize, module: &str) -> FunctionId {
        FunctionId::with_module_path(
            PathBuf::from(file),
            name.to_string(),
            line,
            module.to_string(),
        )
    }

    /// Creates a graph holding `ids`, registered in slice order as plain functions
    /// (not entry points, not tests, complexity 5, 10 lines).
    fn graph_with(ids: &[&FunctionId]) -> CallGraph {
        let mut graph = CallGraph::new();
        for id in ids {
            graph.add_function((*id).clone(), false, false, 5, 10);
        }
        graph
    }

    /// Records a direct call edge `caller -> callee`.
    fn link(graph: &mut CallGraph, caller: &FunctionId, callee: &FunctionId) {
        graph.add_call(FunctionCall {
            caller: caller.clone(),
            callee: callee.clone(),
            call_type: CallType::Direct,
        });
    }

    #[test]
    fn test_exact_lookup() {
        let target = fid("test.rs", "foo", 100);
        let graph = graph_with(&[&target]);
        let found = graph.find_function(&target);
        assert_eq!(found, Some(target));
    }

    #[test]
    fn test_fuzzy_lookup_different_line() {
        let defined = fid("test.rs", "foo", 100);
        let graph = graph_with(&[&defined]);
        let found = graph.find_function(&fid("test.rs", "foo", 150));
        assert_eq!(found, Some(defined));
    }

    #[test]
    fn test_fuzzy_lookup_generic_function() {
        let defined = fid("test.rs", "map", 100);
        let graph = graph_with(&[&defined]);
        let found = graph.find_function(&fid("test.rs", "map<String>", 100));
        assert_eq!(found, Some(defined));
    }

    #[test]
    fn test_name_only_lookup() {
        let defined = fid("test.rs", "foo", 100);
        let graph = graph_with(&[&defined]);
        let found = graph.find_function(&fid("other.rs", "foo", 50));
        assert_eq!(found, Some(defined));
    }

    #[test]
    fn test_disambiguate_by_line_proximity() {
        let near = fid("test.rs", "foo", 100);
        let far = fid("test.rs", "foo", 200);
        let graph = graph_with(&[&near, &far]);
        for (query_line, expected) in [(120, &near), (190, &far)] {
            let found = graph.find_function(&fid("test.rs", "foo", query_line));
            assert_eq!(found, Some(expected.clone()));
        }
    }

    #[test]
    fn test_disambiguate_by_module_path() {
        let in_module1 = fid_in("test.rs", "foo", 100, "module1");
        let in_module2 = fid_in("other.rs", "foo", 100, "module2");
        let graph = graph_with(&[&in_module1, &in_module2]);
        let found = graph.find_function(&fid_in("another.rs", "foo", 50, "module1"));
        assert_eq!(found, Some(in_module1));
    }

    #[test]
    fn test_no_match_returns_none() {
        let graph = CallGraph::new();
        let found = graph.find_function(&fid("test.rs", "nonexistent", 100));
        assert_eq!(found, None);
    }

    #[test]
    fn test_lookup_chain_short_circuits() {
        let target = fid("test.rs", "foo", 100);
        let graph = graph_with(&[&target]);
        let found = graph.find_function(&target);
        assert_eq!(found, Some(target));
    }

    #[test]
    fn test_is_recursive_direct() {
        let factorial = fid("test.rs", "factorial", 100);
        let mut graph = graph_with(&[&factorial]);
        link(&mut graph, &factorial, &factorial);
        assert!(graph.is_recursive(&factorial));
    }

    #[test]
    fn test_is_recursive_indirect() {
        let a = fid("test.rs", "a", 100);
        let b = fid("test.rs", "b", 200);
        let mut graph = graph_with(&[&a, &b]);
        link(&mut graph, &a, &b);
        link(&mut graph, &b, &a);
        assert!(graph.is_recursive(&a));
        assert!(graph.is_recursive(&b));
    }

    #[test]
    fn test_topological_sort_simple() {
        let a = fid("test.rs", "a", 100);
        let b = fid("test.rs", "b", 200);
        let c = fid("test.rs", "c", 300);
        let mut graph = graph_with(&[&a, &b, &c]);
        link(&mut graph, &a, &b);
        link(&mut graph, &b, &c);
        let sorted = graph.topological_sort().unwrap();
        let position_of = |wanted: &FunctionId| sorted.iter().position(|id| id == wanted).unwrap();
        let (pos_a, pos_b, pos_c) = (position_of(&a), position_of(&b), position_of(&c));
        assert!(pos_c < pos_b);
        assert!(pos_b < pos_a);
    }
}