//! mirsa-analysis 0.2.3
//!
//! Analysis orchestration for mirsa.
use rustc_middle::mir::*;
use rustc_middle::ty::{TyCtxt, TyKind};

use super::state::CombinedState;
use mirsa_domains::interval::transfer::interval_of_const;

/// Records in `out.symbolic` the facts about `discr` implied by taking one `SwitchInt` edge.
///
/// * `values_for_succ` — the explicit switch values whose target is the successor.
/// * `all_values` — every explicit switch value of the terminator.
/// * `is_otherwise` — whether the successor is (also) the `otherwise` target.
///
/// If the successor is reached through `otherwise`, `discr` differs from every explicit value
/// that does *not* lead to the successor. Values that do lead there are deliberately skipped.
/// A block can be both the `otherwise` target and an explicit target, and then the edge is also
/// taken when `discr` equals one of those values. Recording `!=` for them would be unsound.
///
/// If the successor is reached only through exactly one explicit value, `discr` equals that
/// value. With several values no single equality holds, so nothing is recorded.
fn record_switch_fact<'tcx>(
    out: &mut CombinedState<'tcx>,
    discr: Operand<'tcx>,
    values_for_succ: &[u128],
    all_values: &[u128],
    is_otherwise: bool,
) {
    if is_otherwise {
        for value in all_values
            .iter()
            .filter(|value| !values_for_succ.contains(*value))
        {
            out.symbolic.assume_ne_const(discr.clone(), *value);
        }
    } else if let [only] = values_for_succ {
        out.symbolic.assume_eq_const(discr, *only);
    }
}

/// Computes the abstract state that flows along the CFG edge `pred -> succ`.
///
/// The result starts as a clone of `in_state`. For a `SwitchInt` on a place, facts about the
/// discriminant implied by taking this particular edge are recorded first. In every surviving
/// case the state is then passed through `refine_with_path_facts`.
///
/// Returns `None` when:
/// * `pred` has no terminator;
/// * the terminator is a `Goto` or `SwitchInt` and `succ` is not one of its targets;
/// * the switch provably cannot take this edge. This covers two situations:
///   * a single-valued constant discriminant that selects another target;
///   * a `bool` switch whose only route to `succ` is an unreachable `otherwise`;
/// * `refine_with_path_facts` returns `false`. Presumably this signals an infeasible refined
///   state; confirm against `CombinedState`.
///
/// For any other terminator kind the state is refined without checking that `succ` is actually
/// a successor of `pred`.
pub fn refine_edge<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    pred: BasicBlock,
    succ: BasicBlock,
    in_state: &CombinedState<'tcx>,
) -> Option<CombinedState<'tcx>> {
    let term = body.basic_blocks[pred].terminator.as_ref()?;
    in_state.interval.debug(format_args!(
        "refine edge bb{} -> bb{} by {:?}",
        pred.index(),
        succ.index(),
        term.kind
    ));

    match &term.kind {
        TerminatorKind::Goto { target } => {
            if *target != succ {
                return None;
            }
            let mut out = in_state.clone();
            out.refine_with_path_facts(tcx, &body.local_decls)
                .then_some(out)
        }
        TerminatorKind::SwitchInt { discr, targets } => {
            // Collect every explicit switch value, and the subset whose target is `succ`.
            let mut values_for_succ: Vec<u128> = Vec::new();
            let mut all_values: Vec<u128> = Vec::new();
            for (val, target) in targets.iter() {
                all_values.push(val);
                if target == succ {
                    values_for_succ.push(val);
                }
            }
            // `succ` may be reached via `otherwise`, via explicit values, or via both.
            let is_otherwise = targets.otherwise() == succ;
            if values_for_succ.is_empty() && !is_otherwise {
                return None;
            }

            let mut out = in_state.clone();
            match discr {
                Operand::Copy(_) | Operand::Move(_) => {
                    // A `bool` switch listing both 0 and 1 can never fall through to
                    // `otherwise`. The edge is only dead when `otherwise` is its sole route.
                    // If `succ` is also an explicit target, the edge is still taken for those
                    // values.
                    let discr_ty = discr.ty(&body.local_decls, tcx);
                    if matches!(discr_ty.kind(), TyKind::Bool)
                        && is_otherwise
                        && values_for_succ.is_empty()
                        && all_values.contains(&0)
                        && all_values.contains(&1)
                    {
                        return None;
                    }

                    record_switch_fact(
                        &mut out,
                        discr.clone(),
                        &values_for_succ,
                        &all_values,
                        is_otherwise,
                    );
                    out.refine_with_path_facts(tcx, &body.local_decls)
                        .then_some(out)
                }
                Operand::Constant(c) => {
                    // Only a constant with a single known value can decide the branch.
                    // Any other constant is refined without pruning.
                    let c_iv = interval_of_const(c);
                    if c_iv.is_empty() || c_iv.low != c_iv.high {
                        return out
                            .refine_with_path_facts(tcx, &body.local_decls)
                            .then_some(out);
                    }
                    let c_val = c_iv.low;
                    let discr_ty = discr.ty(&body.local_decls, tcx);
                    // Convert the `u128` switch values to `i128` using the discriminant type,
                    // so they can be compared with the constant's value. Values the conversion
                    // rejects are skipped.
                    let decode = |value: &u128| {
                        mirsa_domains::interval::transfer::switch_value_to_i128(
                            tcx, discr_ty, *value,
                        )
                    };
                    let matches_succ_value =
                        values_for_succ.iter().filter_map(decode).any(|v| v == c_val);
                    let matches_any_value =
                        all_values.iter().filter_map(decode).any(|v| v == c_val);
                    // The edge is live in two cases:
                    // * the constant selects `succ` explicitly;
                    // * the constant matches no explicit value and `succ` is `otherwise`.
                    if matches_succ_value || (is_otherwise && !matches_any_value) {
                        out.refine_with_path_facts(tcx, &body.local_decls)
                            .then_some(out)
                    } else {
                        None
                    }
                }
            }
        }
        _ => {
            let mut out = in_state.clone();
            out.refine_with_path_facts(tcx, &body.local_decls)
                .then_some(out)
        }
    }
}