// mirsa-framework 0.2.3
//
// Generic analysis framework for mirsa
use crate::forward::{
    ForwardSemantics, PathForwardAnalysisConfig, PathForwardAnalysisResult,
    run_path_sensitive_forward_analysis_with_config,
};
use mirsa_core::cfg::Cfg;
use rustc_hir::def_id::DefId;
use rustc_middle::mir::{
    Body, Local, Location, Place, ProjectionElem, Terminator, TerminatorKind, VarDebugInfoContents,
};
use rustc_middle::ty::TyCtxt;
use std::collections::HashMap;

/// Read-only view of an analysis state as a list of printable entries.
pub trait StateEntries<'tcx> {
    /// Returns the contents of the state as `(place, text)` pairs; the text
    /// is what the printing helpers in this file show after `=>`.
    fn entries(&self) -> Vec<(Place<'tcx>, String)>;
    /// Reports whether the entry for `_place` should be shown when printing.
    ///
    /// The default implementation accepts every place.
    fn should_print_entry(&self, _place: Place<'tcx>) -> bool {
        true
    }
}

/// Writes a one-line banner to stdout naming the function identified by
/// `def_id`, using the path string that `tcx` produces for it.
pub fn print_function_header<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) {
    println!("== Function: {} ==", tcx.def_path_str(def_id));
}

/// Chooses the index of the basic block whose state should be reported.
///
/// Returns the index of the first block in `body` that contains no
/// statements and ends in a `Return` terminator. When no such block exists,
/// the result is `states_len - 1`, saturating at `0` for an empty state list.
pub fn pick_return_or_last_bb<'tcx>(body: &Body<'tcx>, states_len: usize) -> usize {
    for (bb, data) in body.basic_blocks.iter_enumerated() {
        // Only a statement-free block counts as the return block.
        if !data.statements.is_empty() {
            continue;
        }
        let ends_in_return = data
            .terminator
            .as_ref()
            .is_some_and(|term| matches!(term.kind, TerminatorKind::Return));
        if ends_in_return {
            return bb.index();
        }
    }
    states_len.saturating_sub(1)
}

/// Builds a map from MIR locals to names taken from the body's debug info.
///
/// Only debug-info records whose value is a place with an empty projection
/// are used. If several records refer to the same local, the first one
/// encountered keeps its name.
pub fn collect_local_names<'tcx>(body: &Body<'tcx>) -> HashMap<Local, String> {
    let mut names = HashMap::new();
    for info in &body.var_debug_info {
        let VarDebugInfoContents::Place(place) = &info.value else {
            continue;
        };
        // A projected place does not name a whole local; skip it.
        if !place.projection.is_empty() {
            continue;
        }
        names
            .entry(place.local)
            .or_insert_with(|| info.name.to_string());
    }
    names
}

/// Renders `place` as a short label for printing.
///
/// The base local is shown by its entry in `local_names`, or by its `Debug`
/// form when it has none. Each projection element then appends a suffix:
/// `.N` for a field, `[name]` for an index by local, and `[offset]` for a
/// constant index counted from the start. If any other projection element
/// is encountered, the whole place is rendered with its `Debug` form instead.
pub fn format_place_label<'tcx>(
    place: Place<'tcx>,
    local_names: &HashMap<Local, String>,
) -> String {
    // Shared lookup for the base local and for index locals.
    let name_of = |local: Local| match local_names.get(&local) {
        Some(name) => name.clone(),
        None => format!("{local:?}"),
    };

    let mut label = name_of(place.local);
    for elem in place.projection.iter() {
        let suffix = match elem {
            ProjectionElem::Field(field, _) => format!(".{}", field.as_usize()),
            ProjectionElem::Index(index_local) => format!("[{}]", name_of(index_local)),
            ProjectionElem::ConstantIndex {
                offset,
                from_end: false,
                ..
            } => format!("[{offset}]"),
            // Anything else has no compact rendering here: fall back to Debug.
            _ => return format!("{place:?}"),
        };
        label.push_str(&suffix);
    }

    label
}

/// Collects the printable `(label, value)` pairs of `state`, sorted and
/// de-duplicated.
///
/// Entries are first filtered through [`StateEntries::should_print_entry`],
/// then their places are turned into labels with [`format_place_label`].
/// Labels that start with `_` are dropped; this is typically the `_N` form
/// produced for locals that have no entry in the name map.
///
/// The result is ordered by label and then by value. Sorting on the full
/// pair (rather than the label alone) matters because several places can
/// map to the same label: `dedup` only removes adjacent duplicates, so
/// identical pairs must be made adjacent, and the output order must not
/// depend on the order in which `entries()` happened to yield them.
pub fn visible_entries<'tcx, S>(body: &Body<'tcx>, state: &S) -> Vec<(String, String)>
where
    S: StateEntries<'tcx>,
{
    let local_names = collect_local_names(body);
    let mut entries: Vec<(String, String)> = state
        .entries()
        .into_iter()
        .filter(|(place, _)| state.should_print_entry(*place))
        .map(|(place, value)| (format_place_label(place, &local_names), value))
        .filter(|(label, _)| !label.starts_with('_'))
        .collect();
    // Equal pairs are indistinguishable, so an unstable sort is sufficient.
    entries.sort_unstable();
    entries.dedup();
    entries
}

/// Prints the state immediately before selected call terminators.
///
/// Every basic block of `body` that ends in a `Call` terminator accepted by
/// `should_print_call` is considered. For each one, `state_before` is asked
/// for the state at the terminator's location (statement index equal to the
/// number of statements in the block). Blocks for which it returns `None`,
/// or whose state has no visible entries, are skipped silently. Otherwise
/// the entries are printed with their labels padded to a common width.
pub fn print_call_pre_states<'tcx, S>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    result: &PathForwardAnalysisResult<S>,
    mut state_before: impl FnMut(
        TyCtxt<'tcx>,
        &Body<'tcx>,
        &PathForwardAnalysisResult<S>,
        Location,
    ) -> Option<S>,
    mut should_print_call: impl FnMut(TyCtxt<'tcx>, &Body<'tcx>, &Terminator<'tcx>) -> bool,
) where
    S: StateEntries<'tcx>,
{
    for (bb, bbdata) in body.basic_blocks.iter_enumerated() {
        // Keep only blocks that end in a call.
        let call_term = match bbdata.terminator.as_ref() {
            Some(term) if matches!(term.kind, TerminatorKind::Call { .. }) => term,
            _ => continue,
        };
        if !should_print_call(tcx, body, call_term) {
            continue;
        }

        // The terminator sits just past the last statement of its block.
        let call_site = Location {
            block: bb,
            statement_index: bbdata.statements.len(),
        };
        let state = match state_before(tcx, body, result, call_site) {
            Some(state) => state,
            None => continue,
        };

        let entries = visible_entries(body, &state);
        if entries.is_empty() {
            continue;
        }

        println!("  unsafe pre-state @ bb{}:", bb.index());
        let width = entries
            .iter()
            .fold(0usize, |widest, (label, _)| widest.max(label.len()));
        for (label, value) in &entries {
            println!("    {label:width$} => {value}");
        }
    }
}

/// Runs the path-sensitive analysis on `body` and prints one block's
/// out-state.
///
/// The analysis is run first, then the function header is printed. The block
/// to report is chosen by [`pick_return_or_last_bb`]; if the result has no
/// out-state at that index, nothing further is printed. Otherwise the block
/// index, the body's debug-info records, and the visible entries (labels
/// padded to a common width) are written to stdout.
pub fn run_and_print_path_sensitive_analysis<'tcx, A>(
    tcx: TyCtxt<'tcx>,
    def_id: DefId,
    body: &Body<'tcx>,
    cfg: &Cfg,
    semantics: &A,
    config: PathForwardAnalysisConfig,
) where
    A: ForwardSemantics<'tcx>,
    A::State: StateEntries<'tcx>,
{
    let result = run_path_sensitive_analysis(tcx, body, cfg, semantics, config);
    print_function_header(tcx, def_id);

    let picked_bb_idx = pick_return_or_last_bb(body, result.out_states.len());
    let Some(state) = result.out_states.get(picked_bb_idx) else {
        return;
    };

    println!("  bb{picked_bb_idx}:");
    println!("  locals: {:?}", body.var_debug_info);

    let entries = visible_entries(body, state);
    let place_width = entries
        .iter()
        .fold(0usize, |widest, (place, _)| widest.max(place.len()));
    for (place, value) in &entries {
        println!("    {place:place_width$} => {value}");
    }
}

/// Runs the path-sensitive forward analysis for `body`.
///
/// This is a thin wrapper: all arguments are forwarded unchanged to
/// `run_path_sensitive_forward_analysis_with_config`, and its result is
/// returned as-is.
pub fn run_path_sensitive_analysis<'tcx, A>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    cfg: &Cfg,
    semantics: &A,
    config: PathForwardAnalysisConfig,
) -> PathForwardAnalysisResult<A::State>
where
    A: ForwardSemantics<'tcx>,
{
    run_path_sensitive_forward_analysis_with_config(tcx, body, cfg, semantics, config)
}

/// Prints the visible entries of every basic block's state.
///
/// The body's debug-info records are printed once, followed by one section
/// per basic block in block order. `states` is indexed by block index;
/// blocks with no corresponding element are skipped. Within a section,
/// labels are padded to the width of the longest label.
pub fn print_all_bb_states<'tcx, S>(body: &Body<'tcx>, states: &[S])
where
    S: StateEntries<'tcx>,
{
    println!("  locals: {:?}", body.var_debug_info);
    for (bb, _) in body.basic_blocks.iter_enumerated() {
        let Some(state) = states.get(bb.index()) else {
            continue;
        };

        println!("  bb{}:", bb.index());
        let entries = visible_entries(body, state);
        let place_width = entries
            .iter()
            .fold(0usize, |widest, (place, _)| widest.max(place.len()));
        for (place, value) in &entries {
            println!("    {place:place_width$} => {value}");
        }
    }
}

/// Prints the in-states and then the out-states of `result` for every basic
/// block of `body`, each group preceded by a title line.
pub fn print_final_analysis_result<'tcx, S>(
    body: &Body<'tcx>,
    result: &PathForwardAnalysisResult<S>,
) where
    S: StateEntries<'tcx>,
{
    // One titled section: the title line followed by the per-block dump.
    let print_section = |title: &str, states: &[S]| {
        println!("{title}");
        print_all_bb_states(body, states);
    };

    print_section("  final in-states:", &result.in_states);

    print_section("  final out-states:", &result.out_states);
}