use rustc_hir::{def::DefKind, def_id::DefId};
use rustc_middle::{
mir::{self, Body},
ty::TyCtxt,
};
use std::collections::HashMap;
use std::collections::HashSet;
use super::visitor::CallGraphVisitor;
use crate::{
Analysis,
analysis::core::callgraph::{CallGraphAnalysis, FnCallMap},
};
/// Call-graph analysis pass: walks the MIR bodies of the local definitions
/// reported by `tcx.mir_keys` and records caller → callee edges in a
/// [`CallGraph`].
pub struct CallGraphAnalyzer<'tcx> {
    /// Compiler type context used to query MIR and definition kinds.
    pub tcx: TyCtxt<'tcx>,
    /// The call graph being built; populated by `start`.
    pub graph: CallGraph<'tcx>,
}
impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
fn name(&self) -> &'static str {
"Default call graph analysis algorithm."
}
fn run(&mut self) {
self.start();
}
fn reset(&mut self) {
todo!();
}
}
impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
    /// Returns the call graph as a caller → callee `DefId` map.
    ///
    /// Call-site terminators are dropped. A callee appears once per recorded
    /// call, so it may be listed more than once for the same caller.
    fn get_fn_calls(&self) -> FnCallMap {
        // Iterate by reference: the previous version cloned the whole map
        // (including every edge vector) only to discard the terminators.
        self.graph
            .fn_calls
            .iter()
            .map(|(&caller, callees)| {
                let callee_ids = callees.iter().map(|(did, _)| *did).collect::<Vec<DefId>>();
                (caller, callee_ids)
            })
            .collect()
    }
}
impl<'tcx> CallGraphAnalyzer<'tcx> {
    /// Creates an analyzer with an empty call graph.
    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
        Self {
            tcx,
            graph: CallGraph::new(tcx),
        }
    }

    /// Builds the call graph by visiting every local definition reported by
    /// `tcx.mir_keys` that has MIR available.
    ///
    /// Functions, associated functions and closures are visited through their
    /// optimized MIR. Const-like items (consts, statics, associated/inline/
    /// anonymous consts) are visited through their CTFE MIR. Any other
    /// definition kind is skipped with a debug log.
    pub fn start(&mut self) {
        for local_def_id in self.tcx.mir_keys(()) {
            let def_id = local_def_id.to_def_id();
            if !self.tcx.is_mir_available(def_id) {
                continue;
            }
            let def_kind = self.tcx.def_kind(def_id);
            // Both queries already return references, so no extra borrow is needed.
            let body: &Body<'_> = match def_kind {
                DefKind::Fn | DefKind::AssocFn | DefKind::Closure => {
                    self.tcx.optimized_mir(def_id)
                }
                DefKind::Const
                | DefKind::Static { .. }
                | DefKind::AssocConst
                | DefKind::InlineConst
                | DefKind::AnonConst => self.tcx.mir_for_ctfe(def_id),
                _ => {
                    rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
                    continue;
                }
            };
            let mut call_graph_visitor =
                CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
            call_graph_visitor.visit();
        }
    }
}
/// Adjacency map keyed by a function's `DefId`. Each value lists the related
/// functions, each paired with the terminator of the call site when one was
/// recorded.
pub type CallMap<'tcx> = HashMap<DefId, Vec<(DefId, Option<&'tcx mir::Terminator<'tcx>>)>>;

/// Call graph over `DefId`s.
pub struct CallGraph<'tcx> {
    /// Compiler type context this graph was built with.
    pub tcx: TyCtxt<'tcx>,
    /// Functions registered through `register_fn`.
    pub functions: HashSet<DefId>,
    /// Outgoing edges keyed by caller. Each entry is `(callee, call-site
    /// terminator)`, one per recorded call, so a callee may repeat.
    pub fn_calls: CallMap<'tcx>,
}
impl<'tcx> CallGraph<'tcx> {
    /// Creates an empty call graph bound to `tcx`.
    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
        Self {
            tcx,
            functions: HashSet::new(),
            fn_calls: HashMap::new(),
        }
    }

    /// Registers `def_id` as a known function.
    ///
    /// Returns `true` if it was newly registered, `false` if it was already present.
    pub fn register_fn(&mut self, def_id: DefId) -> bool {
        // `HashSet::insert` already reports whether the value was new; the
        // previous linear scan over the set made each call O(n).
        self.functions.insert(def_id)
    }

    /// Records one call from `caller_id` to `callee_id`, optionally with the
    /// terminator of the call site.
    ///
    /// Edges are appended without deduplication, so repeated calls produce
    /// repeated entries. (The method name's spelling is kept for compatibility
    /// with existing callers.)
    pub fn add_funciton_call(
        &mut self,
        caller_id: DefId,
        callee_id: DefId,
        terminator_stmt: Option<&'tcx mir::Terminator<'tcx>>,
    ) {
        self.fn_calls
            .entry(caller_id)
            .or_default()
            .push((callee_id, terminator_stmt));
    }
}
impl<'tcx> CallGraph<'tcx> {
    /// Returns the reverse of [`Self::get_post_order`].
    ///
    /// On an acyclic graph every caller appears before its callees.
    pub fn get_reverse_post_order(&self) -> Vec<DefId> {
        let mut result = self.get_post_order();
        result.reverse();
        result
    }

    /// Returns every registered function, plus every function reachable from
    /// one through recorded calls, in depth-first post-order.
    ///
    /// On an acyclic graph every callee appears before its callers. DFS roots
    /// are taken in `HashSet` iteration order, which is unspecified.
    pub fn get_post_order(&self) -> Vec<DefId> {
        let mut visited = HashSet::new();
        let mut post_order_ids = Vec::new();
        for &func_def_id in self.functions.iter() {
            if !visited.contains(&func_def_id) {
                self.dfs_post_order(func_def_id, &mut visited, &mut post_order_ids);
            }
        }
        post_order_ids
    }

    /// Depth-first traversal from `func_def_id` that appends each newly visited
    /// function to `post_order_ids` after all of its unvisited callees.
    ///
    /// This uses an explicit stack instead of recursion so that very deep call
    /// chains cannot overflow the thread stack. The visit and emission order is
    /// identical to the recursive formulation.
    fn dfs_post_order(
        &self,
        func_def_id: DefId,
        visited: &mut HashSet<DefId>,
        post_order_ids: &mut Vec<DefId>,
    ) {
        visited.insert(func_def_id);
        // Each frame holds (function, index of the next callee edge to examine).
        let mut stack: Vec<(DefId, usize)> = vec![(func_def_id, 0)];
        while let Some(frame) = stack.last_mut() {
            let current = frame.0;
            let callees = self
                .fn_calls
                .get(&current)
                .map(Vec::as_slice)
                .unwrap_or_default();
            let mut next_unvisited = None;
            while let Some((callee_id, _terminator)) = callees.get(frame.1) {
                frame.1 += 1;
                // `insert` returns true only for not-yet-visited callees,
                // mirroring the "check, then mark on entry" of recursive DFS.
                if visited.insert(*callee_id) {
                    next_unvisited = Some(*callee_id);
                    break;
                }
            }
            match next_unvisited {
                Some(callee_id) => stack.push((callee_id, 0)),
                None => {
                    // All callees handled: emit this function in post-order.
                    stack.pop();
                    post_order_ids.push(current);
                }
            }
        }
    }

    /// Inverts `fn_calls` into a callee → callers map.
    ///
    /// Each entry pairs a caller with the call-site terminator; there is one
    /// entry per recorded call. The map is rebuilt on every invocation.
    pub fn get_callers_map(&self) -> CallMap<'tcx> {
        let mut callers_map: CallMap<'tcx> = HashMap::new();
        for (&caller_id, calls_vec) in &self.fn_calls {
            for (callee_id, terminator) in calls_vec {
                callers_map
                    .entry(*callee_id)
                    .or_default()
                    .push((caller_id, *terminator));
            }
        }
        callers_map
    }

    /// Returns the direct callees of `caller_def_id`, one entry per recorded
    /// call, or an empty vector if it has no recorded calls.
    pub fn get_callees(&self, caller_def_id: DefId) -> Vec<DefId> {
        self.fn_calls
            .get(&caller_def_id)
            .map(|callees| callees.iter().map(|(did, _)| *did).collect())
            .unwrap_or_default()
    }

    /// Returns every function transitively reachable from `caller_def_id`,
    /// in depth-first post-order.
    ///
    /// NOTE: `caller_def_id` itself is always included, as the last element,
    /// even when it does not call itself.
    pub fn get_callees_recursive(&self, caller_def_id: DefId) -> Vec<DefId> {
        let mut visited = HashSet::new();
        let mut result = Vec::new();
        self.dfs_post_order(caller_def_id, &mut visited, &mut result);
        result
    }

    /// Returns the direct callers of `callee_def_id`, one entry per recorded
    /// call.
    ///
    /// A caller that calls it several times is listed several times. The order
    /// follows `HashMap` iteration over `fn_calls`.
    pub fn get_callers(&self, callee_def_id: DefId) -> Vec<DefId> {
        // Scan the edges directly instead of materialising the whole inverted
        // map just to read a single entry; the result order and duplicates are
        // unchanged.
        self.fn_calls
            .iter()
            .flat_map(|(&caller_id, calls)| {
                calls
                    .iter()
                    .filter(move |(callee_id, _)| *callee_id == callee_def_id)
                    .map(move |_| caller_id)
            })
            .collect()
    }
}