use std::rc::Rc;
use either::Either;
use indexical::ToIndex;
use rustc_data_structures::{graph::Successors, work_queue::WorkQueue};
use rustc_index::IndexVec;
use rustc_middle::{
mir::{Body, Location, traversal},
ty::TyCtxt,
};
use rustc_mir_dataflow::{Analysis, Direction, JoinSemiLattice};
use rustc_utils::{
BodyExt,
mir::location_or_arg::{
LocationOrArg,
index::{LocationOrArgDomain, LocationOrArgIndex},
},
};
/// Per-location results of a dataflow run.
///
/// Bundles the analysis object with one reference-counted domain value per
/// dense `LocationOrArgIndex`, resolved through `location_domain`.
pub struct AnalysisResults<'tcx, A>
where
    A: Analysis<'tcx>,
{
    /// The analysis that produced these results, returned to the caller.
    pub analysis: A,
    /// Maps `LocationOrArg` values to and from their dense indices.
    location_domain: Rc<LocationOrArgDomain>,
    /// Stored dataflow state, indexed by `LocationOrArgIndex`.
    state: IndexVec<LocationOrArgIndex, Rc<A::Domain>>,
}
impl<'tcx, A: Analysis<'tcx>> AnalysisResults<'tcx, A> {
pub fn state_at(&self, location: Location) -> &A::Domain {
&self.state[location.to_index(&self.location_domain)]
}
}
/// Runs `analysis` over `body` to a fixpoint, keeping a separate domain value
/// for every MIR location rather than one per basic block.
///
/// Processing a location applies its primary statement/terminator effect to
/// that location's entry in place. The result is then joined into the entry
/// of each successor location, and any successor whose entry changes is queued
/// for (re-)processing. When the queue drains, every processed entry holds the
/// state *after* its location's effect.
///
/// Only forward analyses are driven. The worklist is seeded only when
/// `A::Direction::IS_FORWARD`, and propagation always follows successor edges.
/// For a backward analysis nothing is processed. The only entry touched is the
/// start location, by `initialize_start_block`.
///
/// Locations in blocks not visited by `traversal::reverse_postorder` are not
/// seeded into the worklist, and their entries stay at `bottom_value`.
///
/// NOTE(review): `state` and the worklist are sized by the number of MIR
/// locations but indexed by `LocationOrArgIndex`. This assumes
/// `location_domain` numbers every `Location` before any argument. Confirm
/// against how the domain is constructed.
///
/// `_tcx` is currently unused.
pub fn iterate_to_fixpoint<'tcx, A: Analysis<'tcx>>(
    _tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    location_domain: Rc<LocationOrArgDomain>,
    mut analysis: A,
) -> AnalysisResults<'tcx, A> {
    let bottom_value = analysis.bottom_value(body);

    // One materialized state per MIR location.
    let num_locs = body.all_locations().count();
    let mut state = IndexVec::from_elem_n(bottom_value, num_locs);

    // Seed the entry state.
    analysis
        .initialize_start_block(body, &mut state[Location::START.to_index(&location_domain)]);

    // Enqueue every location of every block in reverse postorder, so that
    // predecessors tend to be processed before their successors.
    let mut dirty_queue: WorkQueue<LocationOrArgIndex> = WorkQueue::with_none(num_locs);
    if A::Direction::IS_FORWARD {
        for (block, data) in traversal::reverse_postorder(body) {
            // `..=` also covers the terminator, at index `statements.len()`.
            for statement_index in 0 ..= data.statements.len() {
                let location = Location {
                    block,
                    statement_index,
                };
                dirty_queue.insert(location.to_index(&location_domain));
            }
        }
    }

    while let Some(loc_index) = dirty_queue.pop() {
        // Only `Location` indices are ever inserted into the queue.
        let LocationOrArg::Location(location) = *location_domain.value(loc_index) else {
            unreachable!()
        };

        // Apply this location's effect in place, then enumerate its successors
        // without allocating.
        let next_locs = match body.stmt_at(location) {
            Either::Left(statement) => {
                analysis.apply_primary_statement_effect(
                    &mut state[loc_index],
                    statement,
                    location,
                );
                Either::Left(std::iter::once(location.successor_within_block()))
            }
            Either::Right(terminator) => {
                analysis.apply_primary_terminator_effect(
                    &mut state[loc_index],
                    terminator,
                    location,
                );
                Either::Right(body.basic_blocks.successors(location.block).map(|block| {
                    Location {
                        block,
                        statement_index: 0,
                    }
                }))
            }
        };

        for next_loc in next_locs {
            let next_loc_index = location_domain.index(&LocationOrArg::Location(next_loc));
            // A block consisting only of a terminator that targets its own block
            // makes this location its own successor. `pick2_mut` panics on equal
            // indices, and joining a state into itself cannot change it, so there
            // is nothing to propagate.
            if next_loc_index == loc_index {
                continue;
            }
            let (cur_state, next_state) = state.pick2_mut(loc_index, next_loc_index);
            if next_state.join(cur_state) {
                dirty_queue.insert(next_loc_index);
            }
        }
    }

    let state = state.into_iter().map(Rc::new).collect();
    AnalysisResults {
        analysis,
        location_domain,
        state,
    }
}