// mirsa-domains 0.2.0
//
// Abstract interpretation domains for mirsa.
use rustc_hir::def_id::DefId;
use rustc_middle::mir::{Body, Terminator, TerminatorKind};
use rustc_middle::ty::{TyCtxt, TyKind};

/// Kinds of calls that `classify_path` recognizes from the segments of a
/// callee's definition path.
///
/// Each variant's comment states the path shape that `classify_path` maps to
/// it. The matching is purely name-based.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ContractCall {
    /// `new_unchecked` whose preceding segment starts with `NonZero`.
    NonZeroNewUnchecked,
    /// `NonNull` followed by `new_unchecked`.
    NonNullNewUnchecked,
    /// `CStr` followed by `from_ptr`.
    CStrFromPtr,
    /// `from_raw_parts` on a path that contains a `slice` segment.
    SliceFromRawParts,
    /// `from_raw_parts_mut` on a path that contains a `slice` segment.
    SliceFromRawPartsMut,
    /// `from_raw_parts` directly under a `Vec` or `vec` segment.
    VecFromRawParts,
    /// `read` on a path that contains a `ptr` segment.
    PtrRead,
    /// `write` on a path that contains a `ptr` segment.
    PtrWrite,
    /// Any path ending in `copy_nonoverlapping`.
    PtrCopyNonoverlapping,
    /// Any path ending in `get_unchecked`.
    SliceGetUnchecked,
    /// Any path ending in `get_unchecked_mut`.
    SliceGetUncheckedMut,
    /// Any path ending in `split_at_unchecked`.
    SliceSplitAtUnchecked,
    /// Any path ending in `split_at_mut_unchecked`.
    SliceSplitAtMutUnchecked,
}

impl ContractCall {
    pub(crate) fn has_internval_contract(self) -> bool {
        matches!(
            self,
            ContractCall::NonZeroNewUnchecked
                | ContractCall::SliceGetUnchecked
                | ContractCall::SliceGetUncheckedMut
                | ContractCall::SliceSplitAtUnchecked
                | ContractCall::SliceSplitAtMutUnchecked
                | ContractCall::PtrCopyNonoverlapping
        )
    }

    pub(crate) fn has_nullptr_contract(self) -> bool {
        matches!(
            self,
            ContractCall::NonNullNewUnchecked
                | ContractCall::CStrFromPtr
                | ContractCall::SliceFromRawParts
                | ContractCall::SliceFromRawPartsMut
                | ContractCall::VecFromRawParts
                | ContractCall::PtrRead
                | ContractCall::PtrWrite
                | ContractCall::PtrCopyNonoverlapping
        )
    }
}

/// Extracts the `DefId` of the function called by `term`.
///
/// Returns `None` when `term` is not a `TerminatorKind::Call`, or when the
/// type of the call's `func` operand (resolved against `body.local_decls`) is
/// not `TyKind::FnDef`.
fn call_def_id<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    term: &Terminator<'tcx>,
) -> Option<DefId> {
    match &term.kind {
        TerminatorKind::Call { func, .. } => {
            match func.ty(&body.local_decls, tcx).kind() {
                TyKind::FnDef(callee, _) => Some(*callee),
                _ => None,
            }
        }
        _ => None,
    }
}

/// Returns the definition-path segments of the function called by `term`.
///
/// Yields `None` whenever `call_def_id` cannot resolve a callee for `term`;
/// otherwise the segments come from `def_path_segments`.
pub(crate) fn call_path_segments<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    term: &Terminator<'tcx>,
) -> Option<Vec<String>> {
    call_def_id(tcx, body, term).map(|callee| def_path_segments(tcx, callee))
}

/// Classifies the call made by `term` as one of the known `ContractCall` kinds.
///
/// Returns `None` when the callee's path cannot be determined
/// (`call_path_segments`) or when `classify_path` does not recognize it.
pub(crate) fn classify_call<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    term: &Terminator<'tcx>,
) -> Option<ContractCall> {
    call_path_segments(tcx, body, term).and_then(|segments| classify_path(&segments))
}

/// Builds the path segments used to identify `def_id` by name.
///
/// * If `tcx.impl_of_method(def_id)` yields an impl, the segments are those of
///   the impl's self type when that type is an ADT (otherwise those of the
///   impl itself), followed by the item's own name when it has one.
/// * Otherwise, if `tcx.trait_of_item(def_id)` yields a trait, the segments
///   are the trait's, followed by the item's own name when it has one.
/// * Otherwise the segments are simply those of `def_id`.
fn def_path_segments(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<String> {
    // Segments of the impl's self type, the impl, or the trait that owns the item.
    let owner_segments = if let Some(impl_def_id) = tcx.impl_of_method(def_id) {
        let self_ty = tcx.type_of(impl_def_id).instantiate_identity();
        let anchor = if let TyKind::Adt(adt, _) = self_ty.kind() {
            adt.did()
        } else {
            impl_def_id
        };
        Some(item_def_path_segments(tcx, anchor))
    } else {
        tcx.trait_of_item(def_id)
            .map(|trait_def_id| item_def_path_segments(tcx, trait_def_id))
    };

    match owner_segments {
        Some(mut segments) => {
            // `extend` with an `Option` appends the name only when one exists.
            segments.extend(def_path_item_name(tcx, def_id));
            segments
        }
        None => item_def_path_segments(tcx, def_id),
    }
}

/// Returns the crate name followed by every named component of `def_id`'s
/// definition path.
///
/// Components for which `get_opt_name()` returns `None` are skipped.
fn item_def_path_segments(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<String> {
    let def_path = tcx.def_path(def_id);

    let mut segments = Vec::with_capacity(def_path.data.len() + 1);
    segments.push(tcx.crate_name(def_path.krate).as_str().to_string());

    for component in &def_path.data {
        if let Some(symbol) = component.data.get_opt_name() {
            segments.push(symbol.as_str().to_string());
        }
    }

    segments
}

/// Returns the name stored in `def_id`'s own definition-key entry, or `None`
/// when `get_opt_name()` reports no name for it.
fn def_path_item_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
    let key = tcx.def_key(def_id);
    let symbol = key.disambiguated_data.data.get_opt_name()?;
    Some(symbol.as_str().to_string())
}

/// Maps a callee's path segments to a `ContractCall`, if the path is recognized.
///
/// The decision is driven by the final segment (the item name). Some names
/// additionally require a specific preceding segment (`NonZero*`, `NonNull`,
/// `CStr`, `Vec`/`vec`) or a `slice`/`ptr` segment anywhere in the path.
/// Returns `None` for an empty or unrecognized path.
fn classify_path(path: &[String]) -> Option<ContractCall> {
    let (name, parents) = path.split_last()?;
    // Segment immediately before the item name, when there is one.
    let owner = parents.last().map(String::as_str);
    // True when any segment of the whole path equals `module`.
    let under = |module: &str| path.iter().any(|segment| segment.as_str() == module);

    let call = match name.as_str() {
        "copy_nonoverlapping" => ContractCall::PtrCopyNonoverlapping,
        "new_unchecked" if matches!(owner, Some(ty) if ty.starts_with("NonZero")) => {
            ContractCall::NonZeroNewUnchecked
        }
        "new_unchecked" if owner == Some("NonNull") => ContractCall::NonNullNewUnchecked,
        "from_ptr" if owner == Some("CStr") => ContractCall::CStrFromPtr,
        "from_raw_parts_mut" if under("slice") => ContractCall::SliceFromRawPartsMut,
        // The `Vec` arm must stay ahead of the slice arm: it wins when both apply.
        "from_raw_parts" if matches!(owner, Some("Vec" | "vec")) => {
            ContractCall::VecFromRawParts
        }
        "from_raw_parts" if under("slice") => ContractCall::SliceFromRawParts,
        "read" if under("ptr") => ContractCall::PtrRead,
        "write" if under("ptr") => ContractCall::PtrWrite,
        "get_unchecked_mut" => ContractCall::SliceGetUncheckedMut,
        "get_unchecked" => ContractCall::SliceGetUnchecked,
        "split_at_mut_unchecked" => ContractCall::SliceSplitAtMutUnchecked,
        "split_at_unchecked" => ContractCall::SliceSplitAtUnchecked,
        _ => return None,
    };
    Some(call)
}