use rustc_hir::def_id::DefId;
use rustc_middle::mir::{Body, TerminatorKind};
use rustc_middle::ty::TyCtxt;
use rustc_span::Span;
use std::collections::HashMap;
/// A call site in MIR that was found to (possibly) heap-allocate.
#[derive(Debug, Clone)]
pub struct AllocationViolation {
    /// Source span of the offending `Call` terminator.
    pub span: Span,
    /// Human-readable explanation; names the callee path involved.
    pub reason: String,
}
/// Scans the `Call` terminators of `mir` for calls that (may) heap-allocate.
///
/// For each call whose callee is a compile-time-known function
/// (`func.const_fn_def()`), the following are checked in order:
/// 1. every argument whose type is a closure or fn item (after peeling any
///    references) and that is defined in this crate with available MIR is
///    analysed transitively; an allocating one is reported as
///    "passes allocating closure to …";
/// 2. the callee's printed path is matched against the known-allocating list
///    (`is_allocating_function`);
/// 3. a local callee with available MIR is analysed transitively.
///
/// Calls whose callee is not a constant fn item (e.g. through function
/// pointers or trait objects) are skipped.
///
/// # Parameters
/// - `tcx`: compiler context used for path printing and MIR lookup.
/// - `mir`: the body to scan.
/// - `_fn_def_id`: owner of `mir`; currently unused.
/// - `cache`: memoised per-`DefId` results shared with transitive analyses.
///
/// # Returns
/// The first violation found, in basic-block order, or `None`.
pub fn detect_allocation_in_mir<'tcx>(
    tcx: TyCtxt<'tcx>,
    mir: &Body<'tcx>,
    _fn_def_id: DefId,
    cache: &mut HashMap<DefId, bool>,
) -> Option<AllocationViolation> {
    use rustc_middle::ty::TyKind;

    for bb_data in mir.basic_blocks.iter() {
        let Some(terminator) = &bb_data.terminator else { continue };
        let TerminatorKind::Call { func, args, .. } = &terminator.kind else { continue };
        let Some((callee_def_id, _generics)) = func.const_fn_def() else { continue };

        let span = terminator.source_info.span;
        let path = tcx.def_path_str(callee_def_id);

        for arg in args.iter() {
            // `Operand::ty` handles constants, moves/copies and runtime checks
            // uniformly. References are peeled because closures are usually
            // reached by reference: `f()` lowers to `Fn::call(&f, ..)` /
            // `FnMut::call_mut(&mut f, ..)`, and callers often pass `&f`.
            let arg_ty = arg.node.ty(mir, tcx).peel_refs();
            let passed_def_id = match *arg_ty.kind() {
                TyKind::Closure(def_id, _) | TyKind::FnDef(def_id, _) => def_id,
                _ => continue,
            };
            if should_analyze_transitively(tcx, passed_def_id)
                && function_allocates(tcx, passed_def_id, cache)
            {
                return Some(AllocationViolation {
                    span,
                    reason: format!("passes allocating closure to {path}"),
                });
            }
        }

        if is_allocating_function(&path) {
            return Some(AllocationViolation {
                span,
                reason: format!("calls allocating function: {path}"),
            });
        }

        if should_analyze_transitively(tcx, callee_def_id)
            && function_allocates(tcx, callee_def_id, cache)
        {
            return Some(AllocationViolation {
                span,
                reason: format!("calls function that allocates: {path}"),
            });
        }
    }
    None
}
/// Heuristically classifies a callee as allocating from its printed def-path.
///
/// The path is matched by substring against a fixed rule table. Each rule
/// pairs a list of "subject" needles (module/type names) with a list of
/// "operation" needles; a path is allocating when, for at least one rule, it
/// contains some subject needle AND some operation needle. A rule whose only
/// subject is the empty string matches on its operations alone, because every
/// string contains `""`.
///
/// NOTE(review): the subject `"alloc::alloc::"` already contains `"::alloc"`,
/// so the first rule's operation list is always satisfied and every path under
/// `alloc::alloc::` (including e.g. `dealloc`) is flagged. Confirm intended.
fn is_allocating_function(path: &str) -> bool {
    /// True when `haystack` contains at least one of `needles`.
    fn mentions_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }

    // (subjects, operations): each side is OR-ed, the two sides are AND-ed.
    const RULES: &[(&[&str], &[&str])] = &[
        (
            &["alloc::alloc::"],
            &["::alloc", "::allocate", "::exchange_malloc", "::box_free"],
        ),
        (&["::Box::", "::Box::<"], &["::new"]),
        (
            &["::Vec::", "::Vec::<"],
            &[
                "::new",
                "::with_capacity",
                "::push",
                "::insert",
                "::extend",
                "::append",
                "::resize",
                "::from_elem",
            ],
        ),
        (
            &["::String::"],
            &[
                "::new",
                "::from",
                "::from_utf8",
                "::from_utf16",
                "::push_str",
                "::push",
                "::insert",
                "::insert_str",
            ],
        ),
        (&[""], &["::format", "fmt::format"]),
        (
            &["::Rc::", "::Rc::<", "::Arc::", "::Arc::<"],
            &["::new", "::clone"],
        ),
        (
            &[
                "HashMap",
                "BTreeMap",
                "HashSet",
                "BTreeSet",
                "VecDeque",
                "LinkedList",
                "BinaryHeap",
            ],
            &[">::new", ">::with_capacity", ">::insert", ">::push"],
        ),
        (&[""], &["::to_string", "::to_owned"]),
        (&["RawVec"], &["::new", "::allocate"]),
    ];

    RULES.iter().any(|&(subjects, operations)| {
        mentions_any(path, subjects) && mentions_any(path, operations)
    })
}
fn should_analyze_transitively(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
def_id.krate == rustc_hir::def_id::LOCAL_CRATE && tcx.is_mir_available(def_id)
}
/// Returns whether the function `def_id` (transitively) allocates, memoising
/// the answer in `cache`.
///
/// Before the body is analysed, a provisional `false` is stored for `def_id`.
/// Any re-entry through a recursive call cycle therefore terminates instead of
/// recursing forever. Functions without available MIR are recorded and
/// reported as non-allocating.
///
/// NOTE(review): the provisional `false` can be read by other functions in the
/// same call cycle. Those functions then cache a final `false` even if `def_id`
/// later turns out to allocate, so mutually recursive functions may be
/// under-reported depending on traversal order.
fn function_allocates<'tcx>(
    tcx: TyCtxt<'tcx>,
    def_id: DefId,
    cache: &mut HashMap<DefId, bool>,
) -> bool {
    if let Some(known) = cache.get(&def_id).copied() {
        return known;
    }

    // Cycle breaker: re-entrant queries for `def_id` observe `false`.
    cache.insert(def_id, false);

    let allocates = if tcx.is_mir_available(def_id) {
        let body = tcx.optimized_mir(def_id);
        detect_allocation_in_mir(tcx, body, def_id, cache).is_some()
    } else {
        false
    };

    cache.insert(def_id, allocates);
    allocates
}