use rustc_middle::mir::Body;
use rustc_middle::mir::{BasicBlock, StatementKind, TerminatorKind};
use rustc_middle::ty::TyCtxt;
use rustc_hir::def_id::DefId;
use crate::analysis::dataflow::graph::build_dataflow_graph;
use crate::graphs::dataflow::DataflowGraph;
use super::super::{
contract,
def_use::{
RelevantPlaces, bind_callsite_roots, operand_uses, terminator_use_def,
},
helpers::{Callsite, CallsiteLocation},
path::{Path, PathStep},
};
use super::{
types::{BackwardItem, KeepReason, ForgetReason, RelevantMirItems},
call_visit,
};
/// Visitor that walks the steps of a path in reverse and collects the MIR
/// statements and terminators it decides to keep into a [`RelevantMirItems`].
pub struct BackwardVisitor<'tcx> {
/// Type context used to fetch optimized MIR and handed to the helper routines.
tcx: TyCtxt<'tcx>,
}
impl<'tcx> BackwardVisitor<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>) -> Self {
Self { tcx }
}
pub fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}
/// Builds the initial, item-less result record for one traversal.
///
/// The record stores `callsite`, clones of `property` and `path`, an empty
/// item list, and roots obtained from `RelevantPlaces::from_property`.
pub fn start_visit(
    &self,
    callsite: CallsiteLocation,
    path: &Path,
    property: &contract::Property<'tcx>,
) -> RelevantMirItems<'tcx> {
    let property_copy = property.clone();
    let path_copy = path.clone();
    let items = Vec::new();
    let roots = RelevantPlaces::from_property(property);
    RelevantMirItems {
        callsite,
        property: property_copy,
        path: path_copy,
        items,
        roots,
    }
}
/// Runs the backward traversal for `callsite` along `path`.
///
/// The roots derived from `property` are first passed to
/// `bind_callsite_roots` together with the callsite; the path is then walked
/// inside the body of `callsite.caller`.
pub fn visit(
    &self,
    callsite: &Callsite<'tcx>,
    path: &Path,
    property: &contract::Property<'tcx>,
) -> RelevantMirItems<'tcx> {
    let location = callsite.location();
    let mut result = self.start_visit(location, path, property);
    bind_callsite_roots(self.tcx, &mut result.roots, callsite);
    self.visit_path(callsite.caller, &location, path, &mut result);
    result
}
/// Runs the backward traversal towards `checkpoint` inside the body of `caller`.
///
/// The roots come straight from `property`; no `bind_callsite_roots` step is
/// applied on this entry point.
pub fn visit_for_checkpoint(
    &self,
    caller: DefId,
    checkpoint: CallsiteLocation,
    path: &Path,
    property: &contract::Property<'tcx>,
) -> RelevantMirItems<'tcx> {
    let mut result = self.start_visit(checkpoint, path, property);
    self.visit_path(caller, &checkpoint, path, &mut result);
    result
}
/// Walks `path` from its last step to its first and stores the kept items in
/// `visit.items`.
///
/// The working set of relevant places starts as a clone of `visit.roots`.
/// Items are collected in visiting (reverse) order and flipped at the end, so
/// `visit.items` follows the path's own step order.
fn visit_path(
    &self,
    caller: DefId,
    callsite_loc: &CallsiteLocation,
    path: &Path,
    visit: &mut RelevantMirItems<'tcx>,
) {
    let mut tracked = visit.roots.clone();
    let mut collected: Vec<BackwardItem<'tcx>> = Vec::new();

    let tcx = self.tcx;
    let mir = tcx.optimized_mir(caller);
    let dataflow = build_dataflow_graph(tcx, caller);

    for step in path.steps.iter().rev() {
        self.visit_path_step_inner(
            step,
            callsite_loc,
            mir,
            &dataflow,
            &mut tracked,
            &mut collected,
        );
    }

    collected.reverse();
    visit.items = collected;
}
/// Processes a single path step during the backward walk.
///
/// * `PathStep::Callsite`: records a `KeepReason::Callsite` terminator item
///   when the step's location equals `callsite_loc`; any other callsite step
///   is ignored.
/// * `PathStep::Block`: visits the block's terminator (skipped for the block
///   named by `callsite_loc.block`) and then its statements from last to first.
fn visit_path_step_inner(
    &self,
    step: &PathStep,
    callsite_loc: &CallsiteLocation,
    body: &'tcx Body<'tcx>,
    flow: &DataflowGraph,
    relevant: &mut RelevantPlaces,
    items: &mut Vec<BackwardItem<'tcx>>,
) {
    let block = match step {
        PathStep::Callsite(location) => {
            if *location == *callsite_loc {
                items.push(BackwardItem::Terminator {
                    block: location.block,
                    kind: KeepReason::Callsite,
                });
            }
            return;
        }
        PathStep::Block(block) => *block,
    };

    let block_data = &body.basic_blocks[block];

    // The terminator of the callsite's own block is not visited here.
    let is_callsite_block = block == callsite_loc.block;
    if !is_callsite_block {
        self.visit_terminator(block, block_data.terminator(), flow, body, relevant, items);
    }

    // Statements are handled bottom-up to match the backward direction.
    for (statement_index, statement) in block_data.statements.iter().enumerate().rev() {
        self.visit_statement(block, statement_index, statement, flow, relevant, items);
    }
}
/// Decides whether one statement is kept and updates the relevant set.
///
/// 1. If the statement writes (`Assign`) or kills (`StorageDead`) a place that
///    intersects `relevant`, it is kept with the reason from
///    `statement_keep_reason`; the written places are removed from `relevant`
///    and the statement's uses are added.
/// 2. Otherwise, if `statement_invalidates_relevant` holds, it is kept as
///    `KeepReason::Invalidation`.
/// 3. Otherwise, if `statement_can_refine` holds and a dataflow edge recorded
///    for this exact statement starts at a relevant local, it is kept as
///    `KeepReason::RuntimeCheck`. `relevant` is left untouched in cases 2 and 3.
fn visit_statement(
    &self,
    block: BasicBlock,
    statement_index: usize,
    statement: &'tcx rustc_middle::mir::Statement<'tcx>,
    flow: &DataflowGraph,
    relevant: &mut RelevantPlaces,
    items: &mut Vec<BackwardItem<'tcx>>,
) {
    // Places this statement defines; stays empty for every other statement kind.
    let mut written = RelevantPlaces::new();
    if let StatementKind::Assign(box (place, _)) = &statement.kind {
        written.insert_mir_place(place);
    } else if let StatementKind::StorageDead(local) = &statement.kind {
        written.insert_local(*local);
    }

    if written.intersects(relevant) {
        let inputs = collect_statement_uses(statement, block, statement_index, flow);
        items.push(BackwardItem::Statement {
            block,
            statement_index,
            kind: statement_keep_reason(statement),
        });
        relevant.remove_all(&written);
        relevant.extend(inputs);
        return;
    }

    // NOTE(review): a `StorageDead` of a relevant local presumably already
    // matched the intersection test above, which would make this branch
    // unreachable; confirm against `RelevantPlaces::intersects`.
    if statement_invalidates_relevant(statement, relevant) {
        items.push(BackwardItem::Statement {
            block,
            statement_index,
            kind: KeepReason::Invalidation,
        });
        return;
    }

    if !statement_can_refine(statement) {
        return;
    }

    // Sources of the dataflow edges attached to this statement.
    let mut inputs = RelevantPlaces::new();
    for &local in &written.locals {
        for &edge_idx in &flow.node(local).in_edges {
            let edge = &flow.edges[edge_idx];
            if edge.block != block.as_usize() || edge.statement_index != statement_index {
                continue;
            }
            inputs.insert_local(edge.src);
        }
    }

    if inputs.intersects(relevant) {
        items.push(BackwardItem::Statement {
            block,
            statement_index,
            kind: KeepReason::RuntimeCheck,
        });
    }
}
/// Decides whether one terminator is kept and updates the relevant set.
///
/// * `Call` terminators are delegated entirely to `call_visit::visit`.
/// * `SwitchInt` / `Assert` are always kept as `KeepReason::PathCondition`,
///   and their uses become relevant.
/// * A terminator whose defs intersect `relevant` is kept with the reason from
///   `terminator_definition_reason`; its defs are removed from `relevant` and
///   its uses are added.
/// * A terminator that only uses relevant places is kept with the reason from
///   `terminator_use_reason`; `relevant` is left unchanged.
fn visit_terminator(
    &self,
    block: BasicBlock,
    terminator: &rustc_middle::mir::Terminator<'tcx>,
    flow: &DataflowGraph,
    body: &Body<'tcx>,
    relevant: &mut RelevantPlaces,
    items: &mut Vec<BackwardItem<'tcx>>,
) {
    if let TerminatorKind::Call {
        func,
        args,
        destination,
        ..
    } = &terminator.kind
    {
        call_visit::visit(self.tcx, block, func, args, destination, flow, body, relevant, items);
        return;
    }

    let use_def = terminator_use_def(terminator);

    if terminator_is_path_condition(terminator) {
        items.push(BackwardItem::Terminator {
            block,
            kind: KeepReason::PathCondition,
        });
        // `use_def` is not read again on this path, so the uses are moved
        // rather than cloned.
        relevant.extend(use_def.uses);
        return;
    }

    // Calls returned above and `terminator_may_havoc` only matches calls, so
    // the `Forget` pushes below cannot currently fire; they are kept in case
    // the havoc set grows.
    if use_def.defs.intersects(relevant) {
        if terminator_may_havoc(terminator) {
            items.push(BackwardItem::Forget {
                reason: ForgetReason::UnknownCall,
            });
        }
        items.push(BackwardItem::Terminator {
            block,
            kind: terminator_definition_reason(terminator),
        });
        relevant.remove_all(&use_def.defs);
        relevant.extend(use_def.uses);
        return;
    }

    if use_def.uses.intersects(relevant) {
        if terminator_may_havoc(terminator) {
            items.push(BackwardItem::Forget {
                reason: ForgetReason::UnknownCall,
            });
        }
        items.push(BackwardItem::Terminator {
            block,
            kind: terminator_use_reason(terminator),
        });
    }
}
}
/// Picks the `KeepReason` recorded for a statement that defines a relevant place.
///
/// `StorageDead` maps to `Invalidation`; an assignment whose rvalue is `Ref`,
/// `RawPtr`, `Cast`, `CopyForDeref` or `BinaryOp` maps to `PointerFlow`;
/// everything else maps to `Definition`.
fn statement_keep_reason(statement: &rustc_middle::mir::Statement<'_>) -> KeepReason {
    match &statement.kind {
        StatementKind::StorageDead(_) => KeepReason::Invalidation,
        StatementKind::Assign(box (
            _,
            rustc_middle::mir::Rvalue::Ref(..)
            | rustc_middle::mir::Rvalue::RawPtr(..)
            | rustc_middle::mir::Rvalue::Cast(..)
            | rustc_middle::mir::Rvalue::CopyForDeref(..)
            | rustc_middle::mir::Rvalue::BinaryOp(..),
        )) => KeepReason::PointerFlow,
        _ => KeepReason::Definition,
    }
}
/// Returns `true` for assignments whose rvalue is a `BinaryOp`, `UnaryOp` or
/// `Cast`, and `false` for every other statement.
fn statement_can_refine(statement: &rustc_middle::mir::Statement<'_>) -> bool {
    match &statement.kind {
        StatementKind::Assign(box (
            _,
            rustc_middle::mir::Rvalue::BinaryOp(..)
            | rustc_middle::mir::Rvalue::UnaryOp(..)
            | rustc_middle::mir::Rvalue::Cast(..),
        )) => true,
        _ => false,
    }
}
/// Returns `true` only for a `StorageDead` whose local is in `relevant.locals`.
fn statement_invalidates_relevant(
    statement: &rustc_middle::mir::Statement<'_>,
    relevant: &RelevantPlaces,
) -> bool {
    matches!(
        &statement.kind,
        StatementKind::StorageDead(local) if relevant.locals.contains(local)
    )
}
/// Returns `true` when the terminator is a `SwitchInt` or an `Assert`.
fn terminator_is_path_condition(terminator: &rustc_middle::mir::Terminator<'_>) -> bool {
    match terminator.kind {
        TerminatorKind::SwitchInt { .. } | TerminatorKind::Assert { .. } => true,
        _ => false,
    }
}
/// Picks the `KeepReason` for a terminator that defines a relevant place:
/// `UnknownEffect` for calls, `Definition` for everything else.
fn terminator_definition_reason(terminator: &rustc_middle::mir::Terminator<'_>) -> KeepReason {
    if matches!(terminator.kind, TerminatorKind::Call { .. }) {
        KeepReason::UnknownEffect
    } else {
        KeepReason::Definition
    }
}
/// Picks the `KeepReason` for a terminator that uses a relevant place.
///
/// `SwitchInt` / `Assert` map to `PathCondition`, `Drop` maps to
/// `Invalidation`, and every other terminator (calls included) maps to
/// `UnknownEffect`.
fn terminator_use_reason(terminator: &rustc_middle::mir::Terminator<'_>) -> KeepReason {
    match terminator.kind {
        TerminatorKind::SwitchInt { .. } | TerminatorKind::Assert { .. } => {
            KeepReason::PathCondition
        }
        TerminatorKind::Drop { .. } => KeepReason::Invalidation,
        // Calls fall in here as well; they share the generic classification.
        _ => KeepReason::UnknownEffect,
    }
}
/// Returns `true` only for `Call` terminators.
fn terminator_may_havoc(terminator: &rustc_middle::mir::Terminator<'_>) -> bool {
    match terminator.kind {
        TerminatorKind::Call { .. } => true,
        _ => false,
    }
}
/// Collects the places a statement reads from.
///
/// Two sources are merged:
/// * the source locals of the dataflow edges that enter the statement's
///   written local (`Assign` destination local or `StorageDead` local) and are
///   recorded for exactly this `block` / `statement_index`;
/// * for assignments, the uses of every operand returned by `rvalue_operands`.
fn collect_statement_uses<'tcx>(
    statement: &'tcx rustc_middle::mir::Statement<'tcx>,
    block: BasicBlock,
    statement_index: usize,
    flow: &DataflowGraph,
) -> RelevantPlaces {
    let mut inputs = RelevantPlaces::new();

    // A statement defines at most one local here, so an `Option` is enough.
    let written_local = match &statement.kind {
        StatementKind::Assign(box (place, _)) => Some(place.local),
        StatementKind::StorageDead(local) => Some(*local),
        _ => None,
    };

    if let Some(local) = written_local {
        for &edge_idx in &flow.node(local).in_edges {
            let edge = &flow.edges[edge_idx];
            if edge.block != block.as_usize() || edge.statement_index != statement_index {
                continue;
            }
            inputs.insert_local(edge.src);
        }
    }

    if let StatementKind::Assign(box (_, rvalue)) = &statement.kind {
        for operand in super::super::def_use::rvalue_operands(rvalue) {
            inputs.extend(operand_uses(operand));
        }
    }

    inputs
}