use rustc_hir::def_id::DefId;
use rustc_middle::mir::{AssertKind, Body, Rvalue, StatementKind, TerminatorKind};
use rustc_middle::ty::TyCtxt;
use rustc_span::Span;
use std::collections::{HashMap, HashSet};
fn resolve_span_to_callsite(span: Span) -> Span {
span.source_callsite()
}
/// Kinds of panic sources the detector can be asked to look for.
///
/// Values are stored in a `HashSet<PanicCategory>` by callers, hence the
/// `Eq + Hash` derives.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum PanicCategory {
    /// `unwrap`/`expect` (and the `_err` variants) on `Option` or `Result`.
    Unwrap,
    /// Direct calls into the panic machinery (`panic!`, `unreachable!`, failed asserts, ...).
    ExplicitPanic,
    /// Slice/array indexing that may fail its bounds check.
    IndexBounds,
}
/// Classifies a callee path (as rendered by `TyCtxt::def_path_str`) into the
/// [`PanicCategory`] it belongs to, or returns `None` if the path is not a
/// recognised panic source.
///
/// Matching is purely textual:
/// - `Option`/`Result` unwraps are recognised in both full
///   (`std::option::Option::<T>::unwrap`) and trimmed (`Option::<T>::unwrap`)
///   form. The type name must be a whole `::`-separated path segment, so
///   user types like `MyOption::<T>::unwrap` are not mistaken for `Option`.
/// - Index-bounds and explicit-panic sources are recognised by substrings of
///   the path.
fn _classify_doc_anchor() {}
pub fn classify_panic_source(path: &str) -> Option<PanicCategory> {
    /// Method suffixes on `Option::<T>` that panic on `None`.
    const OPTION_PANICKING_METHODS: &[&str] = &[">::unwrap", ">::expect"];
    /// Method suffixes on `Result::<T, E>` that panic on the "wrong" variant.
    const RESULT_PANICKING_METHODS: &[&str] =
        &[">::unwrap", ">::expect", ">::unwrap_err", ">::expect_err"];
    /// Substrings of the core slice-indexing failure helpers.
    const INDEX_MARKERS: &[&str] = &[
        "slice_index",
        "index_len_fail",
        "slice_start_index",
        "slice_end_index",
    ];
    /// Substrings of the panic entry points. Bare names cover trimmed paths.
    // NOTE(review): "unreachable" also matches names that do not unwind, such
    // as `core::hint::unreachable_unchecked`, and any user item containing
    // the word. Confirm this breadth is intended.
    const PANIC_MARKERS: &[&str] = &[
        "core::panicking::",
        "std::panicking::",
        "begin_panic",
        "panic_fmt",
        "unreachable",
        "assert_failed",
    ];

    /// True when `name` occurs as a complete `::`-separated segment of `path`.
    fn has_segment(path: &str, name: &str) -> bool {
        path.split("::").any(|segment| segment == name)
    }
    fn ends_with_any(path: &str, suffixes: &[&str]) -> bool {
        suffixes.iter().any(|&suffix| path.ends_with(suffix))
    }
    fn contains_any(path: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| path.contains(needle))
    }

    if has_segment(path, "Option") && ends_with_any(path, OPTION_PANICKING_METHODS) {
        return Some(PanicCategory::Unwrap);
    }
    if has_segment(path, "Result") && ends_with_any(path, RESULT_PANICKING_METHODS) {
        return Some(PanicCategory::Unwrap);
    }
    if contains_any(path, INDEX_MARKERS) {
        return Some(PanicCategory::IndexBounds);
    }
    if contains_any(path, PANIC_MARKERS) {
        return Some(PanicCategory::ExplicitPanic);
    }
    None
}
/// The first panic site found by [`detect_panic_in_mir`].
#[derive(Debug)]
pub struct PanicViolation {
    /// Location of the offending terminator, mapped out of macro expansions.
    pub span: Span,
    /// Human-readable explanation of why this site may panic.
    pub reason: String,
}
/// Scans `mir` for the first construct that may panic with one of the
/// requested `categories`, following local callees, closures and coroutines.
///
/// Returns `None` if no matching panic site was found.
pub fn detect_panic_in_mir<'tcx>(
    tcx: TyCtxt<'tcx>,
    mir: &Body<'tcx>,
    categories: &HashSet<PanicCategory>,
) -> Option<PanicViolation> {
    // Memoised per-function answers depend on `categories`, so every top-level
    // query starts from an empty table.
    analyze_mir(tcx, mir, &mut HashMap::new(), categories)
}
fn analyze_mir<'tcx>(
tcx: TyCtxt<'tcx>,
mir: &Body<'tcx>,
cache: &mut HashMap<DefId, bool>,
categories: &HashSet<PanicCategory>,
) -> Option<PanicViolation> {
for bb_data in mir.basic_blocks.iter() {
for stmt in &bb_data.statements {
if let StatementKind::Assign(assign) = &stmt.kind
&& let Rvalue::Aggregate(kind, _) = &assign.1
&& let rustc_middle::mir::AggregateKind::Coroutine(def_id, _) = &**kind
&& def_id.krate == rustc_hir::def_id::LOCAL_CRATE
&& tcx.is_mir_available(*def_id)
{
let coroutine_mir = tcx.optimized_mir(*def_id);
if let Some(violation) = analyze_mir(tcx, coroutine_mir, cache, categories) {
return Some(violation);
}
}
}
}
for (_bb, bb_data) in mir.basic_blocks.iter_enumerated() {
let Some(terminator) = &bb_data.terminator else {
continue;
};
match &terminator.kind {
TerminatorKind::Assert { msg, .. } => {
if let AssertKind::BoundsCheck { .. } = &**msg
&& categories.contains(&PanicCategory::IndexBounds)
{
return Some(PanicViolation {
span: resolve_span_to_callsite(terminator.source_info.span),
reason: "index bounds check may panic".to_string(),
});
}
}
TerminatorKind::Call { func, args, .. } => {
if let Some((callee_def_id, _generics)) = func.const_fn_def() {
let path = tcx.def_path_str(callee_def_id);
for arg in args.iter() {
use rustc_middle::mir::Operand;
let closure_def_id = match &arg.node {
Operand::Constant(constant) => {
let ty = constant.const_.ty();
if let rustc_middle::ty::TyKind::Closure(def_id, _) = ty.kind() {
Some(*def_id)
} else if let rustc_middle::ty::TyKind::FnDef(def_id, _) = ty.kind()
{
Some(*def_id)
} else {
None
}
}
Operand::Move(place) | Operand::Copy(place) => {
let ty = place.ty(mir, tcx).ty;
if let rustc_middle::ty::TyKind::Closure(def_id, _) = ty.kind() {
Some(*def_id)
} else if let rustc_middle::ty::TyKind::FnDef(def_id, _) = ty.kind()
{
Some(*def_id)
} else {
None
}
}
Operand::RuntimeChecks(_) => None,
};
if let Some(closure_def_id) = closure_def_id {
if closure_def_id.krate == rustc_hir::def_id::LOCAL_CRATE
&& tcx.is_mir_available(closure_def_id)
&& function_panics_with_categories(
tcx,
closure_def_id,
cache,
categories,
)
{
return Some(PanicViolation {
span: resolve_span_to_callsite(terminator.source_info.span),
reason: format!("passes panicking closure to {path}"),
});
}
}
}
if let Some(category) = classify_panic_source(&path)
&& categories.contains(&category)
{
return Some(PanicViolation {
span: resolve_span_to_callsite(terminator.source_info.span),
reason: format!("calls panicking function: {path}"),
});
}
if should_analyze_transitively(tcx, callee_def_id)
&& function_panics_with_categories(tcx, callee_def_id, cache, categories)
{
return Some(PanicViolation {
span: resolve_span_to_callsite(terminator.source_info.span),
reason: format!("calls function that may panic: {path}"),
});
}
}
}
_ => {}
}
}
None
}
fn should_analyze_transitively(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
def_id.krate == rustc_hir::def_id::LOCAL_CRATE && tcx.is_mir_available(def_id)
}
/// Memoised check of whether the function `def_id` may panic with one of the
/// requested `categories`.
///
/// Before its body is scanned, the function is provisionally recorded as
/// non-panicking (`false`). Recursive or mutually recursive calls therefore
/// terminate and treat the in-progress function as not panicking. The
/// provisional entry is overwritten with the final answer afterwards.
/// Functions without available MIR are reported as `false`.
fn function_panics_with_categories<'tcx>(
    tcx: TyCtxt<'tcx>,
    def_id: DefId,
    cache: &mut HashMap<DefId, bool>,
    categories: &HashSet<PanicCategory>,
) -> bool {
    match cache.get(&def_id) {
        Some(&known) => return known,
        None => {}
    }
    // Provisional answer that breaks cycles through this function.
    cache.insert(def_id, false);
    let may_panic = tcx.is_mir_available(def_id)
        && analyze_mir(tcx, tcx.optimized_mir(def_id), cache, categories).is_some();
    cache.insert(def_id, may_panic);
    may_panic
}