use crate::{
disasm::Disassembler,
types::{
CallContext, CallGraph, CallGraphConfig, CallGraphEdge, CallGraphNode, CallGraphStatistics,
CallSite, CallType, Function, Instruction, NodeType,
},
BinaryError, BinaryFile, Result,
};
use std::collections::{HashMap, HashSet, VecDeque};
#[derive(Clone)]
pub struct CallGraphAnalyzer {
config: CallGraphConfig,
}
impl CallGraphAnalyzer {
pub fn new(config: CallGraphConfig) -> Self {
Self { config }
}
pub fn new_default() -> Self {
Self {
config: CallGraphConfig::default(),
}
}
pub fn analyze_binary(&self, binary: &BinaryFile) -> Result<CallGraph> {
let functions = self.extract_functions(binary)?;
let mut nodes = Vec::new();
let mut address_to_node: HashMap<u64, usize> = HashMap::new();
for (i, function) in functions.iter().enumerate() {
let node = CallGraphNode {
function_address: function.start_address,
function_name: function.name.clone(),
node_type: self.classify_node_type(function, binary),
complexity: 0, in_degree: 0, out_degree: 0, is_recursive: false, call_depth: None, };
nodes.push(node);
address_to_node.insert(function.start_address, i);
}
let edges = self.extract_function_calls(binary, &functions, &address_to_node)?;
self.update_node_degrees(&mut nodes, &edges);
self.detect_recursion(&mut nodes, &edges);
let entry_points = self.find_entry_points(binary, &functions);
let mut call_graph = CallGraph {
nodes,
edges,
entry_points: entry_points.clone(),
unreachable_functions: Vec::new(),
statistics: CallGraphStatistics::default(),
};
self.compute_call_depths(&mut call_graph)?;
call_graph.unreachable_functions = self.find_unreachable_functions(&call_graph);
call_graph.statistics = self.compute_statistics(&call_graph);
Ok(call_graph)
}
fn extract_functions(&self, binary: &BinaryFile) -> Result<Vec<Function>> {
let mut functions = Vec::new();
let mut seen_addresses = HashSet::new();
for symbol in binary.symbols() {
if matches!(symbol.symbol_type, crate::types::SymbolType::Function)
&& symbol.size > 0
&& !seen_addresses.contains(&symbol.address)
{
let function = Function {
name: symbol.name.clone(),
start_address: symbol.address,
end_address: symbol.address + symbol.size,
size: symbol.size,
function_type: crate::types::FunctionType::Normal,
calling_convention: None,
parameters: Vec::new(),
return_type: None,
};
functions.push(function);
seen_addresses.insert(symbol.address);
}
}
if let Some(entry_point) = binary.entry_point() {
if !seen_addresses.contains(&entry_point) {
let function = Function {
name: "_start".to_string(),
start_address: entry_point,
end_address: entry_point + 1000, size: 1000,
function_type: crate::types::FunctionType::Entrypoint,
calling_convention: None,
parameters: Vec::new(),
return_type: None,
};
functions.push(function);
}
}
Ok(functions)
}
fn classify_node_type(&self, function: &Function, binary: &BinaryFile) -> NodeType {
if matches!(
function.function_type,
crate::types::FunctionType::Entrypoint
) || matches!(function.function_type, crate::types::FunctionType::Main)
|| function.name == "_start"
|| function.name == "main"
|| function.name == "DllMain"
{
return NodeType::EntryPoint;
}
for import in binary.imports() {
if let Some(addr) = import.address {
if addr == function.start_address {
return NodeType::External;
}
}
}
if self.is_library_function(&function.name) {
return NodeType::Library;
}
NodeType::Internal
}
fn is_library_function(&self, name: &str) -> bool {
const LIBRARY_PREFIXES: &[&str] =
&["libc_", "libm_", "__", "_GLOBAL_", "std::", "_ZN", "_Z"];
const LIBRARY_NAMES: &[&str] = &[
"printf", "scanf", "malloc", "free", "strlen", "strcpy", "strcat", "memcpy", "memset",
"fopen", "fclose", "fread", "fwrite", "msvcrt", "kernel32", "ntdll",
];
for prefix in LIBRARY_PREFIXES {
if name.starts_with(prefix) {
return true;
}
}
for lib_name in LIBRARY_NAMES {
if name == *lib_name || name.contains(lib_name) {
return true;
}
}
false
}
fn extract_function_calls(
&self,
binary: &BinaryFile,
functions: &[Function],
address_to_node: &HashMap<u64, usize>,
) -> Result<Vec<CallGraphEdge>> {
let mut edges = Vec::new();
let disassembler = Disassembler::new(binary.architecture())?;
for function in functions {
if let Ok(instructions) =
self.get_function_instructions(binary, function, &disassembler)
{
for instruction in &instructions {
if let Some(edge) = self.analyze_call_instruction(
instruction,
function.start_address,
address_to_node,
) {
edges.push(edge);
}
}
}
}
if self.config.detect_tail_calls {
let tail_call_edges = self.detect_tail_calls(binary, functions, address_to_node)?;
edges.extend(tail_call_edges);
}
Ok(edges)
}
fn get_function_instructions(
&self,
binary: &BinaryFile,
function: &Function,
disassembler: &Disassembler,
) -> Result<Vec<Instruction>> {
for section in binary.sections() {
let start = section.address;
let end = start + section.size;
if function.start_address >= start && function.start_address < end {
let data = section.data.as_ref().ok_or_else(|| {
BinaryError::invalid_data("Section data not available for disassembly")
})?;
let offset = (function.start_address - start) as usize;
if offset >= data.len() {
return Ok(Vec::new());
}
let available = data.len() - offset;
let length = std::cmp::min(function.size as usize, available);
if length == 0 {
return Ok(Vec::new());
}
let slice = &data[offset..offset + length];
return disassembler.disassemble_at(slice, function.start_address, length);
}
}
Ok(Vec::new())
}
fn analyze_call_instruction(
&self,
instruction: &Instruction,
caller_address: u64,
address_to_node: &HashMap<u64, usize>,
) -> Option<CallGraphEdge> {
match &instruction.flow {
crate::types::ControlFlow::Call(target_address) => {
if address_to_node.contains_key(target_address) {
let call_site = CallSite {
address: instruction.address,
instruction_bytes: instruction.bytes.clone(),
context: CallContext::Normal,
};
let call_type = if *target_address == caller_address {
CallType::Recursive
} else {
CallType::Direct
};
return Some(CallGraphEdge {
caller: caller_address,
callee: *target_address,
call_type,
call_sites: vec![call_site],
});
}
}
_ => {
if self.config.analyze_indirect_calls {
if let Some(edge) = self.analyze_indirect_call(instruction, caller_address) {
return Some(edge);
}
}
}
}
None
}
fn analyze_indirect_call(
&self,
instruction: &Instruction,
caller_address: u64,
) -> Option<CallGraphEdge> {
if instruction.mnemonic.starts_with("call") && instruction.operands.contains('[') {
let call_site = CallSite {
address: instruction.address,
instruction_bytes: instruction.bytes.clone(),
context: CallContext::Normal,
};
return Some(CallGraphEdge {
caller: caller_address,
callee: 0, call_type: CallType::Indirect,
call_sites: vec![call_site],
});
}
None
}
fn detect_tail_calls(
&self,
binary: &BinaryFile,
functions: &[Function],
address_to_node: &HashMap<u64, usize>,
) -> Result<Vec<CallGraphEdge>> {
let mut tail_call_edges = Vec::new();
let disassembler = Disassembler::new(binary.architecture())?;
for function in functions {
if let Ok(instructions) =
self.get_function_instructions(binary, function, &disassembler)
{
if let Some(last_instruction) = instructions.last() {
if let crate::types::ControlFlow::Jump(target) = &last_instruction.flow {
if address_to_node.contains_key(target) && *target != function.start_address
{
let call_site = CallSite {
address: last_instruction.address,
instruction_bytes: last_instruction.bytes.clone(),
context: CallContext::Normal,
};
tail_call_edges.push(CallGraphEdge {
caller: function.start_address,
callee: *target,
call_type: CallType::TailCall,
call_sites: vec![call_site],
});
}
}
}
}
}
Ok(tail_call_edges)
}
fn update_node_degrees(&self, nodes: &mut [CallGraphNode], edges: &[CallGraphEdge]) {
let mut in_degrees: HashMap<u64, u32> = HashMap::new();
let mut out_degrees: HashMap<u64, u32> = HashMap::new();
for edge in edges {
*out_degrees.entry(edge.caller).or_insert(0) += 1;
*in_degrees.entry(edge.callee).or_insert(0) += 1;
}
for node in nodes {
node.in_degree = in_degrees.get(&node.function_address).copied().unwrap_or(0);
node.out_degree = out_degrees
.get(&node.function_address)
.copied()
.unwrap_or(0);
}
}
fn detect_recursion(&self, nodes: &mut [CallGraphNode], edges: &[CallGraphEdge]) {
let mut recursive_functions = HashSet::new();
for edge in edges {
if edge.caller == edge.callee {
recursive_functions.insert(edge.caller);
}
}
let mut adjacency: HashMap<u64, Vec<u64>> = HashMap::new();
for edge in edges {
adjacency.entry(edge.caller).or_default().push(edge.callee);
}
for node in nodes.iter() {
if self.has_cycle_from_node(node.function_address, &adjacency) {
recursive_functions.insert(node.function_address);
}
}
for node in nodes {
node.is_recursive = recursive_functions.contains(&node.function_address);
}
}
fn has_cycle_from_node(&self, start: u64, adjacency: &HashMap<u64, Vec<u64>>) -> bool {
let mut visited = HashSet::new();
let mut rec_stack = HashSet::new();
self.dfs_has_cycle(start, &mut visited, &mut rec_stack, adjacency)
}
#[allow(clippy::only_used_in_recursion)]
fn dfs_has_cycle(
&self,
node: u64,
visited: &mut HashSet<u64>,
rec_stack: &mut HashSet<u64>,
adjacency: &HashMap<u64, Vec<u64>>,
) -> bool {
visited.insert(node);
rec_stack.insert(node);
if let Some(neighbors) = adjacency.get(&node) {
for &neighbor in neighbors {
if !visited.contains(&neighbor) {
if self.dfs_has_cycle(neighbor, visited, rec_stack, adjacency) {
return true;
}
} else if rec_stack.contains(&neighbor) {
return true;
}
}
}
rec_stack.remove(&node);
false
}
fn find_entry_points(&self, binary: &BinaryFile, functions: &[Function]) -> Vec<u64> {
let mut entry_points = Vec::new();
if let Some(entry) = binary.entry_point() {
entry_points.push(entry);
}
for function in functions {
match function.function_type {
crate::types::FunctionType::Entrypoint | crate::types::FunctionType::Main => {
if !entry_points.contains(&function.start_address) {
entry_points.push(function.start_address);
}
}
_ => {}
}
}
entry_points
}
fn compute_call_depths(&self, call_graph: &mut CallGraph) -> Result<()> {
let mut address_to_node: HashMap<u64, usize> = HashMap::new();
for (i, node) in call_graph.nodes.iter().enumerate() {
address_to_node.insert(node.function_address, i);
}
let mut adjacency: HashMap<u64, Vec<u64>> = HashMap::new();
for edge in &call_graph.edges {
adjacency.entry(edge.caller).or_default().push(edge.callee);
}
for &entry_point in &call_graph.entry_points {
let mut queue = VecDeque::new();
let mut visited = HashSet::new();
queue.push_back((entry_point, 0));
visited.insert(entry_point);
while let Some((current_addr, depth)) = queue.pop_front() {
if let Some(&node_index) = address_to_node.get(¤t_addr) {
let current_depth = call_graph.nodes[node_index].call_depth;
if current_depth.is_none() || current_depth.unwrap() > depth {
call_graph.nodes[node_index].call_depth = Some(depth);
}
}
if let Some(neighbors) = adjacency.get(¤t_addr) {
for &neighbor in neighbors {
if !visited.contains(&neighbor) {
visited.insert(neighbor);
queue.push_back((neighbor, depth + 1));
}
}
}
}
}
Ok(())
}
fn find_unreachable_functions(&self, call_graph: &CallGraph) -> Vec<u64> {
let reachable: HashSet<u64> = call_graph
.nodes
.iter()
.filter(|node| node.call_depth.is_some())
.map(|node| node.function_address)
.collect();
call_graph
.nodes
.iter()
.filter(|node| !reachable.contains(&node.function_address))
.map(|node| node.function_address)
.collect()
}
fn compute_statistics(&self, call_graph: &CallGraph) -> CallGraphStatistics {
let total_functions = call_graph.nodes.len();
let total_calls = call_graph.edges.len();
let direct_calls = call_graph
.edges
.iter()
.filter(|edge| matches!(edge.call_type, CallType::Direct))
.count();
let indirect_calls = call_graph
.edges
.iter()
.filter(|edge| matches!(edge.call_type, CallType::Indirect))
.count();
let recursive_functions = call_graph
.nodes
.iter()
.filter(|node| node.is_recursive)
.count();
let leaf_functions = call_graph
.nodes
.iter()
.filter(|node| node.out_degree == 0)
.count();
let entry_points = call_graph.entry_points.len();
let unreachable_functions = call_graph.unreachable_functions.len();
let max_call_depth = call_graph
.nodes
.iter()
.filter_map(|node| node.call_depth)
.max()
.unwrap_or(0);
let depths: Vec<u32> = call_graph
.nodes
.iter()
.filter_map(|node| node.call_depth)
.collect();
let average_call_depth = if !depths.is_empty() {
depths.iter().sum::<u32>() as f64 / depths.len() as f64
} else {
0.0
};
let cyclic_dependencies = self.count_cyclic_dependencies(call_graph);
CallGraphStatistics {
total_functions,
total_calls,
direct_calls,
indirect_calls,
recursive_functions,
leaf_functions,
entry_points,
unreachable_functions,
max_call_depth,
average_call_depth,
cyclic_dependencies,
}
}
fn count_cyclic_dependencies(&self, call_graph: &CallGraph) -> usize {
call_graph
.nodes
.iter()
.filter(|node| node.is_recursive)
.count()
}
}
#[derive(Debug, Clone)]
pub struct DotConfig {
pub include_addresses: bool,
pub color_by_type: bool,
pub show_call_counts: bool,
pub cluster_by_module: bool,
pub max_nodes: Option<usize>,
}
impl Default for DotConfig {
fn default() -> Self {
Self {
include_addresses: true,
color_by_type: true,
show_call_counts: false,
cluster_by_module: false,
max_nodes: Some(1000),
}
}
}
pub trait CallGraphExporter {
fn export(&self, graph: &CallGraph) -> Result<String>;
}
pub struct DotExporter {
config: DotConfig,
}
impl DotExporter {
pub fn new(config: DotConfig) -> Self {
Self { config }
}
pub fn new_default() -> Self {
Self {
config: DotConfig::default(),
}
}
}
impl CallGraphExporter for DotExporter {
fn export(&self, graph: &CallGraph) -> Result<String> {
let mut dot = String::new();
dot.push_str("digraph CallGraph {\n");
dot.push_str(" rankdir=TB;\n");
dot.push_str(" node [shape=box, style=filled];\n");
dot.push_str(" edge [arrowhead=normal];\n\n");
let nodes_to_include = if let Some(max_nodes) = self.config.max_nodes {
&graph.nodes[..std::cmp::min(max_nodes, graph.nodes.len())]
} else {
&graph.nodes
};
for node in nodes_to_include {
let label = if self.config.include_addresses {
format!("{}\\n0x{:x}", node.function_name, node.function_address)
} else {
node.function_name.clone()
};
let color = if self.config.color_by_type {
match node.node_type {
NodeType::EntryPoint => "lightgreen",
NodeType::Library => "lightblue",
NodeType::External => "lightyellow",
NodeType::Internal => "lightgray",
NodeType::Indirect => "orange",
NodeType::Virtual => "purple",
NodeType::Unknown => "pink",
}
} else {
"lightgray"
};
dot.push_str(&format!(
" \"0x{:x}\" [label=\"{}\", fillcolor=\"{}\"];\n",
node.function_address, label, color
));
}
dot.push('\n');
let node_addresses: HashSet<u64> = nodes_to_include
.iter()
.map(|n| n.function_address)
.collect();
for edge in &graph.edges {
if node_addresses.contains(&edge.caller) && node_addresses.contains(&edge.callee) {
let style = match edge.call_type {
CallType::Direct => "",
CallType::Indirect => ", style=dashed",
CallType::TailCall => ", color=red",
CallType::Virtual => ", color=purple",
CallType::Recursive => ", color=green, style=bold",
CallType::Conditional => ", color=orange",
};
let label = if self.config.show_call_counts {
format!(" [label=\"{}\"{}", edge.call_sites.len(), style)
} else {
format!(" [{}]", &style[2..]) };
dot.push_str(&format!(
" \"0x{:x}\" -> \"0x{:x}\"{};\n",
edge.caller, edge.callee, label
));
}
}
dot.push_str("}\n");
Ok(dot)
}
}
pub struct JsonExporter;
impl CallGraphExporter for JsonExporter {
fn export(&self, graph: &CallGraph) -> Result<String> {
#[cfg(feature = "serde-support")]
{
serde_json::to_string_pretty(graph)
.map_err(|e| BinaryError::invalid_data(format!("JSON serialization failed: {}", e)))
}
#[cfg(not(feature = "serde-support"))]
{
let _ = graph; Err(BinaryError::invalid_data(
"JSON export requires 'serde-support' feature",
))
}
}
}
impl CallGraph {
pub fn to_dot(&self) -> String {
let exporter = DotExporter::new_default();
exporter.export(self).unwrap_or_default()
}
pub fn to_dot_with_config(&self, config: DotConfig) -> String {
let exporter = DotExporter::new(config);
exporter.export(self).unwrap_or_default()
}
pub fn to_json(&self) -> String {
let exporter = JsonExporter;
exporter.export(self).unwrap_or_default()
}
pub fn detect_cycles(&self) -> Vec<Vec<u64>> {
let mut cycles = Vec::new();
let mut visited = HashSet::new();
let mut rec_stack = HashSet::new();
let mut current_path = Vec::new();
let mut adjacency: HashMap<u64, Vec<u64>> = HashMap::new();
for edge in &self.edges {
adjacency.entry(edge.caller).or_default().push(edge.callee);
}
for node in &self.nodes {
if !visited.contains(&node.function_address) {
self.dfs_find_cycles(
node.function_address,
&mut visited,
&mut rec_stack,
&mut current_path,
&mut cycles,
&adjacency,
);
}
}
cycles
}
#[allow(clippy::only_used_in_recursion)]
fn dfs_find_cycles(
&self,
node: u64,
visited: &mut HashSet<u64>,
rec_stack: &mut HashSet<u64>,
current_path: &mut Vec<u64>,
cycles: &mut Vec<Vec<u64>>,
adjacency: &HashMap<u64, Vec<u64>>,
) {
visited.insert(node);
rec_stack.insert(node);
current_path.push(node);
if let Some(neighbors) = adjacency.get(&node) {
for &neighbor in neighbors {
if !visited.contains(&neighbor) {
self.dfs_find_cycles(
neighbor,
visited,
rec_stack,
current_path,
cycles,
adjacency,
);
} else if rec_stack.contains(&neighbor) {
if let Some(cycle_start) = current_path.iter().position(|&x| x == neighbor) {
let cycle = current_path[cycle_start..].to_vec();
cycles.push(cycle);
}
}
}
}
current_path.pop();
rec_stack.remove(&node);
}
}
pub fn analyze_binary(binary: &BinaryFile) -> Result<CallGraph> {
let analyzer = CallGraphAnalyzer::new_default();
analyzer.analyze_binary(binary)
}
pub fn analyze_binary_with_config(
binary: &BinaryFile,
config: CallGraphConfig,
) -> Result<CallGraph> {
let analyzer = CallGraphAnalyzer::new(config);
analyzer.analyze_binary(binary)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::*;
#[test]
fn test_analyzer_creation() {
let analyzer = CallGraphAnalyzer::new_default();
assert!(analyzer.config.analyze_indirect_calls);
assert!(analyzer.config.detect_tail_calls);
}
#[test]
fn test_library_function_detection() {
let analyzer = CallGraphAnalyzer::new_default();
assert!(analyzer.is_library_function("printf"));
assert!(analyzer.is_library_function("libc_start_main"));
assert!(analyzer.is_library_function("__stack_chk_fail"));
assert!(!analyzer.is_library_function("user_function"));
assert!(!analyzer.is_library_function("main"));
}
#[test]
fn test_node_type_classification() {
let analyzer = CallGraphAnalyzer::new_default();
let entry_function = Function {
name: "_start".to_string(),
start_address: 0x1000,
end_address: 0x1100,
size: 256,
function_type: FunctionType::Entrypoint,
calling_convention: None,
parameters: Vec::new(),
return_type: None,
};
let elf_data = vec![
0x7f, 0x45, 0x4c, 0x46, 0x02, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x3e, 0x00, ];
match crate::BinaryFile::parse(&elf_data) {
Ok(binary) => {
let node_type = analyzer.classify_node_type(&entry_function, &binary);
assert_eq!(node_type, NodeType::EntryPoint);
}
Err(_) => {
assert_eq!(entry_function.function_type, FunctionType::Entrypoint);
}
}
}
}