use std::panic::{AssertUnwindSafe, catch_unwind};
use rustc_hir::def_id::DefId;
use rustc_middle::{
mir::{Local, Operand, Rvalue, StatementKind, TerminatorKind},
ty::{PseudoCanonicalInput, Ty, TyCtxt, TyKind},
};
use crate::analysis::dataflow::{DataflowAnalysis, default::DataflowAnalyzer};
use super::path_refine::ForgetReason;
/// Per-call-site summary of which arguments feed the return value and which
/// arguments the call may write through.
#[derive(Debug, Clone)]
pub struct CallDependencySummary {
    /// `DefId` of the statically known callee, if the call target is a `FnDef` constant.
    pub callee: Option<DefId>,
    /// Printable callee path used for name-based classification.
    pub name: String,
    /// Zero-based argument indices the return value depends on.
    pub return_depends_on_args: Vec<usize>,
    /// Zero-based argument indices the call may write through.
    pub may_write_args: Vec<usize>,
    /// `true` when the call was not recognised and this is a conservative fallback.
    pub unsupported: bool,
}

impl CallDependencySummary {
    /// Fallback summary for an unrecognised call: the return value is assumed to
    /// depend on every one of the `arg_count` arguments, and the summary is
    /// flagged as unsupported.
    ///
    /// NOTE(review): no argument writes are recorded here; confirm consumers treat
    /// `unsupported` as "may write anything".
    fn unknown(callee: Option<DefId>, name: String, arg_count: usize) -> Self {
        let every_arg: Vec<usize> = (0..arg_count).collect();
        Self {
            unsupported: true,
            may_write_args: Vec::new(),
            return_depends_on_args: every_arg,
            name,
            callee,
        }
    }
}
/// Per-call-site list of [`CallEffect`] facts, together with the local that
/// receives the call's result.
#[derive(Debug, Clone)]
pub struct CallEffectSummary {
    /// `DefId` of the statically known callee, if the call target is a `FnDef` constant.
    pub callee: Option<DefId>,
    /// Printable callee path used for name-based classification.
    pub name: String,
    /// Caller-side local that receives the call's return value.
    pub destination: Option<Local>,
    /// Facts derived for this call; empty when nothing is known.
    pub effects: Vec<CallEffect>,
    /// `true` when the call was not recognised and this is a fallback.
    pub unsupported: bool,
}

impl CallEffectSummary {
    /// Fallback summary for an unrecognised call: no effects, flagged unsupported.
    fn unknown(callee: Option<DefId>, name: String, destination: Option<Local>) -> Self {
        Self {
            unsupported: true,
            effects: Vec::new(),
            destination,
            name,
            callee,
        }
    }
}
/// A single fact about a call's behaviour. Argument indices are zero-based
/// positions in the call's argument list.
#[derive(Debug, Clone)]
pub enum CallEffect {
    /// The return value depends on (and may alias) argument `arg`; produced from
    /// local argument-to-return dataflow.
    ReturnAliasArg { arg: usize },
    /// The returned pointer is derived from argument `arg` (e.g. `as_ptr`).
    ReturnPointerFromArg { arg: usize },
    /// The returned pointer is `args[base_arg]` advanced by `args[offset_arg]`
    /// elements; `stride` is the pointee size in bytes when it could be computed.
    ReturnPointerAdd { base_arg: usize, offset_arg: usize, stride: Option<u64> },
    /// Same as `ReturnPointerAdd`, but the offset is subtracted.
    ReturnPointerSub { base_arg: usize, offset_arg: usize, stride: Option<u64> },
    /// The returned value is non-zero (emitted for `as_ptr` / `as_mut_ptr`).
    ReturnNonZero,
    /// The returned pointer is aligned to `align` bytes, the ABI alignment of `ty_name`.
    ReturnAligned { align: u64, ty_name: String },
    /// The return value is the constant `value`; `label` is a printable description
    /// such as `size_of::<T>()`.
    ReturnConst { value: u64, label: String },
    /// The call reads memory through pointer argument `arg`.
    ReadMemory { arg: usize },
    /// The call writes memory through pointer argument `pointer_arg`.
    WriteMemory { pointer_arg: usize },
    /// The return value is the length of argument `arg`.
    ReturnLengthOfArg { arg: usize },
    /// Presumably: facts known about argument `arg` are invalidated for `reason`.
    /// Not constructed in this file — confirm semantics with its consumers.
    ForgetArgFacts { arg: usize, reason: ForgetReason },
}
/// Computes which arguments of a call to `func` (called with `arg_count`
/// arguments) flow into its return value, and which it may write through.
///
/// Known primitives are classified by callee path name; the first matching
/// predicate wins. Otherwise, for a local callee, the dataflow-derived
/// argument-to-return dependencies are used (clipped to `arg_count`). Anything
/// else gets the conservative [`CallDependencySummary::unknown`] fallback.
pub fn dependency_summary<'tcx>(
    tcx: TyCtxt<'tcx>,
    func: &Operand<'tcx>,
    arg_count: usize,
) -> CallDependencySummary {
    let callee = callee_def_id(func);
    let name = call_name(tcx, func);

    // (return_depends_on_args, may_write_args) for calls recognised by name.
    let builtin: Option<(Vec<usize>, Vec<usize>)> =
        if is_as_ptr_call(&name) || is_as_mut_ptr_call(&name) {
            Some((vec![0], Vec::new()))
        } else if is_pointer_add_call(&name)
            || is_pointer_sub_call(&name)
            || is_pointer_offset_call(&name)
        {
            Some((vec![0, 1], Vec::new()))
        } else if is_pointer_read_call(&name) {
            Some((vec![0], Vec::new()))
        } else if is_pointer_write_call(&name) {
            Some((Vec::new(), vec![0]))
        } else if is_len_call(&name) {
            Some((vec![0], Vec::new()))
        } else if is_maybe_uninit_uninit_call(&name) || is_layout_constant_call(&name) {
            // Result does not depend on any argument.
            Some((Vec::new(), Vec::new()))
        } else {
            None
        };

    if let Some((return_depends_on_args, may_write_args)) = builtin {
        return CallDependencySummary {
            callee,
            name,
            return_depends_on_args,
            may_write_args,
            unsupported: false,
        };
    }

    match callee.and_then(|def_id| local_return_dependencies(tcx, def_id)) {
        Some(deps) => CallDependencySummary {
            callee,
            name,
            // Drop indices beyond the arguments actually passed at this call site.
            return_depends_on_args: deps.into_iter().filter(|&i| i < arg_count).collect(),
            may_write_args: Vec::new(),
            unsupported: false,
        },
        None => CallDependencySummary::unknown(callee, name, arg_count),
    }
}
pub fn effect_summary<'tcx>(
tcx: TyCtxt<'tcx>,
caller: DefId,
func: &Operand<'tcx>,
destination: Local,
) -> CallEffectSummary {
let callee = callee_def_id(func);
let name = call_name(tcx, func);
let destination = Some(destination);
if is_as_ptr_call(&name) || is_as_mut_ptr_call(&name) {
let mut effects = vec![
CallEffect::ReturnPointerFromArg { arg: 0 },
CallEffect::ReturnNonZero,
];
if let Some((align, ty_name)) = destination_pointee_alignment(tcx, caller, destination) {
effects.push(CallEffect::ReturnAligned { align, ty_name });
}
return CallEffectSummary {
callee,
name,
destination,
effects,
unsupported: false,
};
}
if is_pointer_add_call(&name) || is_pointer_offset_call(&name) {
return CallEffectSummary {
callee,
name,
destination,
effects: vec![CallEffect::ReturnPointerAdd {
base_arg: 0,
offset_arg: 1,
stride: destination_stride(tcx, caller, destination),
}],
unsupported: false,
};
}
if is_pointer_sub_call(&name) {
return CallEffectSummary {
callee,
name,
destination,
effects: vec![CallEffect::ReturnPointerSub {
base_arg: 0,
offset_arg: 1,
stride: destination_stride(tcx, caller, destination),
}],
unsupported: false,
};
}
if is_pointer_read_call(&name) {
return CallEffectSummary {
callee,
name,
destination,
effects: vec![CallEffect::ReadMemory { arg: 0 }],
unsupported: false,
};
}
if is_pointer_write_call(&name) {
return CallEffectSummary {
callee,
name,
destination,
effects: vec![CallEffect::WriteMemory { pointer_arg: 0 }],
unsupported: false,
};
}
if is_len_call(&name) {
return CallEffectSummary {
callee,
name,
destination,
effects: vec![CallEffect::ReturnLengthOfArg { arg: 0 }],
unsupported: false,
};
}
if is_maybe_uninit_uninit_call(&name) {
return CallEffectSummary {
callee,
name,
destination,
effects: Vec::new(),
unsupported: false,
};
}
if is_layout_constant_call(&name) {
let effects = layout_constant_effect(tcx, caller, func, &name)
.into_iter()
.collect();
return CallEffectSummary {
callee,
name,
destination,
effects,
unsupported: false,
};
}
if let Some(callee) = callee {
if let Some(effect) = try_pointer_arith_wrapper_effect(tcx, callee, destination) {
return CallEffectSummary {
callee: Some(callee),
name,
destination,
effects: vec![effect],
unsupported: false,
};
}
if let Some(return_deps) = local_return_dependencies(tcx, callee) {
return CallEffectSummary {
callee: Some(callee),
name,
destination,
effects: return_deps
.into_iter()
.map(|arg| CallEffect::ReturnAliasArg { arg })
.collect(),
unsupported: false,
};
}
}
CallEffectSummary::unknown(callee, name, destination)
}
/// Returns the `DefId` of a directly called function, i.e. when `func` is a
/// constant operand of `FnDef` type; `None` for indirect calls.
pub fn callee_def_id(func: &Operand<'_>) -> Option<DefId> {
    match func {
        Operand::Constant(constant) => match constant.const_.ty().kind() {
            TyKind::FnDef(def_id, _) => Some(*def_id),
            _ => None,
        },
        _ => None,
    }
}
/// Printable name of the call target: the `def_path_str` of a direct callee, or
/// the operand's `Debug` rendering when the callee is not statically known.
pub fn call_name(tcx: TyCtxt<'_>, func: &Operand<'_>) -> String {
    match callee_def_id(func) {
        Some(def_id) => tcx.def_path_str(def_id),
        None => format!("{func:?}"),
    }
}
/// Returns `true` when the final path segment of `name` is exactly `as_ptr`
/// (e.g. `alloc::vec::Vec::<T, A>::as_ptr`, `core::ptr::NonNull::<T>::as_ptr`).
///
/// The whole last segment is compared rather than searching for a `::as_ptr`
/// substring. A substring search would also accept look-alikes such as
/// `<[T]>::as_ptr_range`, which returns a `Range`, not a pointer.
pub fn is_as_ptr_call(name: &str) -> bool {
    name.rsplit("::").next() == Some("as_ptr")
}
/// Returns `true` when the final path segment of `name` is exactly `as_mut_ptr`
/// (e.g. `alloc::vec::Vec::<T, A>::as_mut_ptr`).
///
/// The whole last segment is compared so that `<[T]>::as_mut_ptr_range`, which
/// returns a `Range`, is not mistaken for a pointer-returning call.
pub fn is_as_mut_ptr_call(name: &str) -> bool {
    name.rsplit("::").next() == Some("as_mut_ptr")
}
/// Returns `true` when the final path segment of `name` is `add` or
/// `wrapping_add` (e.g. `core::ptr::mut_ptr::<impl *mut T>::add`).
///
/// Matching the whole segment excludes look-alikes that merely start with `add`.
/// For example, `<*const T>::addr` takes one argument and returns an integer, and
/// callers index the offset operand (argument 1) of any matched call.
pub fn is_pointer_add_call(name: &str) -> bool {
    matches!(name.rsplit("::").next(), Some("add" | "wrapping_add"))
}
/// Returns `true` when the final path segment of `name` is `sub` or
/// `wrapping_sub` (e.g. `core::ptr::const_ptr::<impl *const T>::sub`).
///
/// Matching the whole segment excludes look-alikes such as `sub_ptr`, which
/// returns an element distance rather than a pointer.
pub fn is_pointer_sub_call(name: &str) -> bool {
    matches!(name.rsplit("::").next(), Some("sub" | "wrapping_sub"))
}
/// Returns `true` when the final path segment of `name` is `offset` or
/// `wrapping_offset` (e.g. `core::ptr::const_ptr::<impl *const T>::offset`).
///
/// Matching the whole segment excludes `offset_from`, which returns an `isize`
/// distance between two pointers rather than a pointer.
pub fn is_pointer_offset_call(name: &str) -> bool {
    matches!(name.rsplit("::").next(), Some("offset" | "wrapping_offset"))
}
/// Heuristic: `true` when `name` contains `::read` anywhere or ends with `read`.
/// This covers `ptr::read`, `read_unaligned`, `read_volatile` and a bare
/// `read` path.
///
/// NOTE(review): this also accepts unrelated names such as `RwLock::read` or any
/// path ending in `...read` (e.g. `spawn_thread`); confirm that is intended.
pub fn is_pointer_read_call(name: &str) -> bool {
    let has_read_segment = name.contains("::read");
    has_read_segment || name.ends_with("read")
}
/// Heuristic: `true` when `name` contains `::write` or ends with `write`, and
/// mentions none of `write_bytes`, `write_unaligned` or `write_volatile`.
///
/// NOTE(review): the unaligned/volatile exclusions have no counterpart in
/// `is_pointer_read_call`; confirm the asymmetry is intended.
pub fn is_pointer_write_call(name: &str) -> bool {
    const EXCLUDED: [&str; 3] = ["write_bytes", "write_unaligned", "write_volatile"];
    let looks_like_write = name.contains("::write") || name.ends_with("write");
    looks_like_write && EXCLUDED.iter().all(|excluded| !name.contains(excluded))
}
/// Returns `true` when the final path segment of `name` is exactly `len`
/// (e.g. `alloc::vec::Vec::<T, A>::len`, `core::slice::<impl [T]>::len`).
///
/// The whole last segment is compared so that names like `char::len_utf8` are
/// not summarised as "returns the length of argument 0".
pub fn is_len_call(name: &str) -> bool {
    name.rsplit("::").next() == Some("len")
}
/// `true` for paths ending in `::uninit` that also mention `MaybeUninit`,
/// e.g. `core::mem::MaybeUninit::<T>::uninit`.
pub fn is_maybe_uninit_uninit_call(name: &str) -> bool {
    name.ends_with("::uninit") && name.contains("MaybeUninit")
}
/// `true` when `name` mentions `align_of` or `size_of` anywhere. This also
/// covers `size_of_val` / `align_of_val`.
pub fn is_layout_constant_call(name: &str) -> bool {
    ["align_of", "size_of"]
        .iter()
        .any(|needle| name.contains(needle))
}
/// Builds a `ReturnConst` effect for an `align_of`/`size_of`-style call.
///
/// The first type argument of the call's `FnDef` is taken as the queried type,
/// and its layout is computed in `caller`'s typing environment. Returns `None`
/// when there is no type argument, the layout cannot be computed, or `name`
/// mentions neither `align_of` nor `size_of`.
fn layout_constant_effect<'tcx>(
    tcx: TyCtxt<'tcx>,
    caller: DefId,
    func: &Operand<'tcx>,
    name: &str,
) -> Option<CallEffect> {
    let ty = layout_call_ty(func)?;
    let (align, size) = type_layout(tcx, caller, ty)?;
    let effect = if name.contains("align_of") {
        CallEffect::ReturnConst {
            value: align,
            label: format!("align_of::<{ty:?}>()"),
        }
    } else if name.contains("size_of") {
        CallEffect::ReturnConst {
            value: size,
            label: format!("size_of::<{ty:?}>()"),
        }
    } else {
        return None;
    };
    Some(effect)
}
/// Returns the first type among the generic arguments of a directly called
/// `FnDef` (e.g. `T` in `size_of::<T>`), or `None` for indirect calls.
fn layout_call_ty<'tcx>(func: &Operand<'tcx>) -> Option<Ty<'tcx>> {
    let Operand::Constant(constant) = func else {
        return None;
    };
    match constant.const_.ty().kind() {
        TyKind::FnDef(_, generic_args) => {
            generic_args.iter().find_map(|generic| generic.as_type())
        }
        _ => None,
    }
}
fn trace_to_callee_arg<'tcx>(
body: &rustc_middle::mir::Body<'tcx>,
operand: &Operand<'_>,
) -> Option<usize> {
use std::collections::{HashSet, VecDeque};
let local = match operand {
Operand::Copy(place) | Operand::Move(place) => place.local,
_ => return None,
};
let idx = local.as_usize();
if idx >= 1 && idx <= body.arg_count {
return Some(idx - 1);
}
let mut queue = VecDeque::from([local]);
let mut seen = HashSet::from([local]);
while let Some(current) = queue.pop_front() {
let cidx = current.as_usize();
if cidx >= 1 && cidx <= body.arg_count {
return Some(cidx - 1);
}
for bb in body.basic_blocks.iter() {
for stmt in &bb.statements {
let StatementKind::Assign(assign) = &stmt.kind else { continue };
let dest = assign.0.local;
if dest != current {
continue;
}
let source = match &assign.1 {
Rvalue::Use(Operand::Copy(place))
| Rvalue::Use(Operand::Move(place))
| Rvalue::Cast(_, Operand::Copy(place), _)
| Rvalue::Cast(_, Operand::Move(place), _) => place.local,
_ => continue,
};
if !seen.contains(&source) {
seen.insert(source);
queue.push_back(source);
}
}
}
}
None
}
fn try_pointer_arith_wrapper_effect<'tcx>(
tcx: TyCtxt<'tcx>,
callee: DefId,
_destination: Option<Local>,
) -> Option<CallEffect> {
use std::collections::{HashSet, VecDeque};
let body = tcx.optimized_mir(callee);
let ret = Local::from_usize(0);
for bb in body.basic_blocks.iter() {
let Some(terminator) = &bb.terminator else { continue };
let TerminatorKind::Call {
func,
args,
destination: call_dest,
..
} = &terminator.kind
else {
continue;
};
let name = call_name(tcx, func);
let is_add = is_pointer_add_call(&name);
let is_sub = is_pointer_sub_call(&name);
let inner_effect = if !is_add && !is_sub {
callee_def_id(func).and_then(|inner_callee| {
try_pointer_arith_wrapper_effect(tcx, inner_callee, Some(call_dest.local))
})
} else {
None
};
if !is_add && !is_sub && inner_effect.is_none() {
continue;
}
let mut queue = VecDeque::from([call_dest.local]);
let mut seen = HashSet::from([call_dest.local]);
let mut reaches_ret = false;
while let Some(current) = queue.pop_front() {
if current == ret {
reaches_ret = true;
break;
}
for bb2 in body.basic_blocks.iter() {
for stmt in &bb2.statements {
let StatementKind::Assign(assign) = &stmt.kind else { continue };
let dest = assign.0.local;
if seen.contains(&dest) {
continue;
}
match &assign.1 {
Rvalue::Use(Operand::Copy(place))
| Rvalue::Use(Operand::Move(place)) => {
if place.local == current {
queue.push_back(dest);
seen.insert(dest);
}
}
Rvalue::Cast(_, Operand::Copy(place), _)
| Rvalue::Cast(_, Operand::Move(place), _) => {
if place.local == current {
queue.push_back(dest);
seen.insert(dest);
}
}
_ => {}
}
}
}
}
if !reaches_ret {
continue;
}
if let Some(effect) = inner_effect {
match effect {
CallEffect::ReturnPointerAdd {
base_arg: inner_base,
offset_arg: inner_offset,
stride,
}
| CallEffect::ReturnPointerSub {
base_arg: inner_base,
offset_arg: inner_offset,
stride,
} => {
let base_arg =
trace_to_callee_arg(body, &args.get(inner_base)?.node)?;
let offset_arg =
trace_to_callee_arg(body, &args.get(inner_offset)?.node)?;
return Some(match effect {
CallEffect::ReturnPointerSub { .. } => {
CallEffect::ReturnPointerSub {
base_arg,
offset_arg,
stride,
}
}
_ => CallEffect::ReturnPointerAdd {
base_arg,
offset_arg,
stride,
},
});
}
_ => {}
}
continue;
}
let map_arg = |operand: &Operand<'_>| -> Option<usize> {
let local = match operand {
Operand::Copy(place) | Operand::Move(place) => place.local,
_ => return None,
};
let idx = local.as_usize();
if idx >= 1 && idx <= body.arg_count {
return Some(idx - 1);
}
let mut queue = VecDeque::from([local]);
let mut seen = HashSet::from([local]);
while let Some(current) = queue.pop_front() {
let cidx = current.as_usize();
if cidx >= 1 && cidx <= body.arg_count {
return Some(cidx - 1);
}
for bb2 in body.basic_blocks.iter() {
for stmt in &bb2.statements {
let StatementKind::Assign(assign) = &stmt.kind else { continue };
let dest = assign.0.local;
if dest != current {
continue;
}
let source = match &assign.1 {
Rvalue::Use(Operand::Copy(place))
| Rvalue::Use(Operand::Move(place))
| Rvalue::Cast(_, Operand::Copy(place), _)
| Rvalue::Cast(_, Operand::Move(place), _) => {
place.local
}
_ => continue,
};
if !seen.contains(&source) {
seen.insert(source);
queue.push_back(source);
}
}
}
}
None
};
let base_arg = map_arg(&args[0].node)?;
let offset_arg = map_arg(&args[1].node)?;
let stride = destination_stride(tcx, callee, Some(call_dest.local));
return if is_sub {
Some(CallEffect::ReturnPointerSub {
base_arg,
offset_arg,
stride,
})
} else {
Some(CallEffect::ReturnPointerAdd {
base_arg,
offset_arg,
stride,
})
};
}
None
}
/// Runs the dataflow analyzer on a callee from the current crate and returns the
/// zero-based parameter indices its return value depends on.
///
/// Returns `None` for non-local callees, or if the analysis panics. The panic is
/// caught with `catch_unwind`, so a failed analysis degrades to "unknown".
fn local_return_dependencies(tcx: TyCtxt<'_>, callee: DefId) -> Option<Vec<usize>> {
    if callee.as_local().is_none() {
        return None;
    }
    let analyze = || {
        let mut analyzer = DataflowAnalyzer::new(tcx, false);
        analyzer.build_graph(callee);
        let deps = analyzer.get_fn_arg2ret(callee);
        // Entry 0 is skipped; entry `i > 0` maps to parameter index `i - 1`.
        deps.iter_enumerated()
            .filter_map(|(local, depends)| {
                (*depends && local.as_usize() > 0).then(|| local.as_usize() - 1)
            })
            .collect()
    };
    catch_unwind(AssertUnwindSafe(analyze)).ok()
}
/// Size in bytes of the pointee of `destination`'s type in `caller`'s MIR.
///
/// Returns `None` when there is no destination, its type is not a raw pointer or
/// reference, or the pointee layout cannot be computed.
fn destination_stride<'tcx>(
    tcx: TyCtxt<'tcx>,
    caller: DefId,
    destination: Option<Local>,
) -> Option<u64> {
    let local = destination?;
    let local_ty = tcx.optimized_mir(caller).local_decls[local].ty;
    pointee_ty(local_ty)
        .and_then(|pointee| type_layout(tcx, caller, pointee))
        .map(|(_, size)| size)
}
/// ABI alignment and `Debug` name of the type pointed to by `destination`.
///
/// If `destination` is not a raw pointer or reference, the destination type
/// itself is used. Returns `None` when there is no destination or the layout
/// cannot be computed.
fn destination_pointee_alignment<'tcx>(
    tcx: TyCtxt<'tcx>,
    caller: DefId,
    destination: Option<Local>,
) -> Option<(u64, String)> {
    let local = destination?;
    let local_ty = tcx.optimized_mir(caller).local_decls[local].ty;
    let pointee = pointee_ty(local_ty).unwrap_or(local_ty);
    let (align, _) = type_layout(tcx, caller, pointee)?;
    Some((align, format!("{pointee:?}")))
}
/// The target type of a raw pointer or reference; `None` for any other type.
fn pointee_ty<'tcx>(ty: Ty<'tcx>) -> Option<Ty<'tcx>> {
    if let TyKind::RawPtr(inner, _) | TyKind::Ref(_, inner, _) = ty.kind() {
        Some(*inner)
    } else {
        None
    }
}
/// Computes `(abi_alignment, size)` in bytes for `ty`, using `caller`'s
/// post-analysis typing environment. Returns `None` when `layout_of` fails
/// (e.g. unsized or ill-formed types).
fn type_layout<'tcx>(tcx: TyCtxt<'tcx>, caller: DefId, ty: Ty<'tcx>) -> Option<(u64, u64)> {
    let query = PseudoCanonicalInput {
        typing_env: rustc_middle::ty::TypingEnv::post_analysis(tcx, caller),
        value: ty,
    };
    tcx.layout_of(query)
        .ok()
        .map(|layout| (layout.align.abi.bytes(), layout.size.bytes()))
}