#![allow(clippy::similar_names)]
use std::sync::{Arc, OnceLock};
use crate::visitors::{Descend, for_each_expr_without_closures};
use crate::{get_unique_builtin_attr, sym};
use arrayvec::ArrayVec;
use rustc_ast::{FormatArgs, FormatArgument, FormatPlaceholder};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::{self as hir, Expr, ExprKind, HirId, Node, QPath};
use rustc_lint::{LateContext, LintContext};
use rustc_span::def_id::DefId;
use rustc_span::hygiene::{self, MacroKind, SyntaxContext};
use rustc_span::{BytePos, ExpnData, ExpnId, ExpnKind, Span, SpanData, Symbol};
use std::ops::ControlFlow;
/// Diagnostic-item names of the standard macros that take `format_args!`-style arguments.
///
/// Consulted by [`is_format_macro`] when the macro being checked has a diagnostic name.
const FORMAT_MACRO_DIAG_ITEMS: &[Symbol] = &[
    sym::assert_eq_macro,
    sym::assert_macro,
    sym::assert_ne_macro,
    sym::core_panic_macro,
    sym::debug_assert_eq_macro,
    sym::debug_assert_macro,
    sym::debug_assert_ne_macro,
    sym::eprint_macro,
    sym::eprintln_macro,
    sym::format_args_macro,
    sym::format_macro,
    sym::print_macro,
    sym::println_macro,
    sym::std_panic_macro,
    sym::todo_macro,
    sym::unimplemented_macro,
    sym::write_macro,
    sym::writeln_macro,
];
/// Returns `true` if the macro `macro_def_id` accepts `format_args!`-style arguments.
///
/// The check depends on whether the macro has a diagnostic name:
/// - If it has one, that name must appear in [`FORMAT_MACRO_DIAG_ITEMS`].
/// - If it has none, the result is whether `get_unique_builtin_attr` finds a `format_args`
///   attribute on the macro. This is presumably how users mark their own macros as
///   format-like; confirm.
pub fn is_format_macro(cx: &LateContext<'_>, macro_def_id: DefId) -> bool {
    match cx.tcx.get_diagnostic_name(macro_def_id) {
        Some(diag_name) => FORMAT_MACRO_DIAG_ITEMS.contains(&diag_name),
        None => {
            #[allow(deprecated)]
            let attrs = cx.tcx.get_all_attrs(macro_def_id);
            get_unique_builtin_attr(cx.sess(), attrs, sym::format_args).is_some()
        },
    }
}
/// One macro invocation taken from a span's expansion backtrace.
///
/// Instances are produced by [`macro_backtrace`] from expansions whose kind is
/// `ExpnKind::Macro` and which carry a macro definition id.
#[derive(Debug)]
pub struct MacroCall {
    /// `DefId` of the invoked macro's definition.
    pub def_id: DefId,
    /// The `MacroKind` recorded for the expansion.
    pub kind: MacroKind,
    /// The expansion produced by this invocation.
    pub expn: ExpnId,
    /// The expansion's call site, i.e. where the macro was invoked.
    pub span: Span,
}
impl MacroCall {
    /// Returns `true` when this call's call-site span is local, as decided by
    /// [`span_is_local`].
    pub fn is_local(&self) -> bool {
        let Self { span, .. } = *self;
        span_is_local(span)
    }
}
/// Walks the chain of expansions that led to `span`.
///
/// The iteration proceeds as follows:
/// 1. Yield the outer expansion of `span`'s syntax context, together with its data.
/// 2. Continue from that expansion's call site and repeat.
/// 3. Stop once a span with the root syntax context is reached.
///
/// Every expansion kind is yielded, not only macros.
pub fn expn_backtrace(mut span: Span) -> impl Iterator<Item = (ExpnId, ExpnData)> {
    std::iter::from_fn(move || {
        let ctxt = span.ctxt();
        (ctxt != SyntaxContext::root()).then(|| {
            let expn = ctxt.outer_expn();
            let data = expn.expn_data();
            // Move one level out: the next step inspects where this expansion was invoked.
            span = data.call_site;
            (expn, data)
        })
    })
}
/// Returns `true` if `span` is local.
///
/// A span that does not come from an expansion is always local. Otherwise the answer comes
/// from [`expn_is_local`] applied to the span's outer expansion.
pub fn span_is_local(span: Span) -> bool {
    if span.from_expansion() {
        expn_is_local(span.ctxt().outer_expn())
    } else {
        true
    }
}
/// Returns `true` if the expansion `expn` is local.
///
/// The root expansion is local. Otherwise, starting at `expn` and following call sites
/// outward, the function finds the first expansion that carries a macro `DefId`:
/// - If one is found, the result is whether that `DefId` is local.
/// - If none is found, the expansion counts as local.
pub fn expn_is_local(expn: ExpnId) -> bool {
    if expn == ExpnId::root() {
        return true;
    }
    let data = expn.expn_data();
    let call_site = data.call_site;
    let nearest_macro = std::iter::once((expn, data))
        .chain(expn_backtrace(call_site))
        .find_map(|(_, expn_data)| expn_data.macro_def_id);
    match nearest_macro {
        Some(def_id) => def_id.is_local(),
        None => true,
    }
}
/// Like [`expn_backtrace`], but yields only macro invocations.
///
/// Each yielded [`MacroCall`] comes from an expansion of kind `ExpnKind::Macro` that has a
/// macro definition id. Other expansions, such as desugarings, are skipped. Calls are
/// yielded starting from the expansion nearest to `span`.
pub fn macro_backtrace(span: Span) -> impl Iterator<Item = MacroCall> {
    expn_backtrace(span).filter_map(|(expn, data)| {
        let ExpnKind::Macro(kind, _) = data.kind else {
            return None;
        };
        let def_id = data.macro_def_id?;
        Some(MacroCall {
            def_id,
            kind,
            expn,
            span: data.call_site,
        })
    })
}
/// Returns the outermost macro call in `span`'s backtrace, or `None` if `span` does not
/// come from any macro.
pub fn root_macro_call(span: Span) -> Option<MacroCall> {
    // `macro_backtrace` moves outward through call sites, so its final item is the root call.
    let outward_calls = macro_backtrace(span);
    outward_calls.last()
}
/// Returns the root macro call of `span`, but only if that macro is the diagnostic item
/// `name`.
pub fn matching_root_macro_call(cx: &LateContext<'_>, span: Span, name: Symbol) -> Option<MacroCall> {
    let root = root_macro_call(span)?;
    cx.tcx.is_diagnostic_item(name, root.def_id).then_some(root)
}
/// Returns the root macro call of `node`'s span.
///
/// The call is returned only when [`first_node_in_macro`] reports `Some(ExpnId::root())` for
/// `node`; any other result yields `None`.
pub fn root_macro_call_first_node(cx: &LateContext<'_>, node: &impl HirNode) -> Option<MacroCall> {
    match first_node_in_macro(cx, node) {
        Some(boundary) if boundary == ExpnId::root() => root_macro_call(node.span()),
        _ => None,
    }
}
/// Yields the macro calls in `node`'s backtrace, stopping before the expansion returned by
/// [`first_node_in_macro`].
///
/// If `first_node_in_macro` returns `None`, the iterator is empty.
pub fn first_node_macro_backtrace(cx: &LateContext<'_>, node: &impl HirNode) -> impl Iterator<Item = MacroCall> {
    let span = node.span();
    first_node_in_macro(cx, node)
        .map(|boundary| macro_backtrace(span).take_while(move |call| call.expn != boundary))
        .into_iter()
        .flatten()
}
/// Determines whether `node` is the first (outermost) node of the macro expansion that
/// produced it, and if so, which expansion encloses that macro.
///
/// Returns:
/// - `None` if `node` has no macro call in its [`macro_backtrace`].
/// - `None` if the innermost macro expansion of `node`'s HIR parent `is_descendant_of`
///   `node`'s own innermost macro expansion. In that case `node` is not the outermost node
///   of its expansion.
/// - `Some(ExpnId::root())` if the parent (after skipping one enclosing `Stmt`) is missing,
///   or is not inside any macro expansion.
/// - Otherwise `Some(expn)`, where `expn` is the parent's innermost macro expansion.
pub fn first_node_in_macro(cx: &LateContext<'_>, node: &impl HirNode) -> Option<ExpnId> {
    // Innermost macro expansion that produced `node`. `macro_backtrace` skips non-macro
    // expansions, and the `?` bails out if there is none.
    let expn = macro_backtrace(node.span()).next()?.expn;
    // Find the HIR parent, stepping over a single statement wrapper if present. A missing
    // parent is treated as "not inside a macro".
    let mut parent_iter = cx.tcx.hir_parent_iter(node.hir_id());
    let (parent_id, _) = match parent_iter.next() {
        None => return Some(ExpnId::root()),
        Some((_, Node::Stmt(_))) => match parent_iter.next() {
            None => return Some(ExpnId::root()),
            Some(next) => next,
        },
        Some(next) => next,
    };
    let parent_span = cx.tcx.hir_span(parent_id);
    // The parent is not produced by any macro.
    let Some(parent_macro_call) = macro_backtrace(parent_span).next() else {
        return Some(ExpnId::root());
    };
    // The parent belongs to `node`'s own expansion (or one nested within it), so `node` is
    // not the outermost node of that expansion.
    if parent_macro_call.expn.is_descendant_of(expn) {
        return None;
    }
    Some(parent_macro_call.expn)
}
/// Returns `true` if `def_id` has the diagnostic name of one of the panic macros.
///
/// The accepted names are: `core_panic`, `std_panic`, the 2015-edition variants of both,
/// and `core_panic_2021`.
pub fn is_panic(cx: &LateContext<'_>, def_id: DefId) -> bool {
    cx.tcx.get_diagnostic_name(def_id).is_some_and(|name| {
        matches!(
            name,
            sym::core_panic_macro
                | sym::std_panic_macro
                | sym::core_panic_2015_macro
                | sym::std_panic_2015_macro
                | sym::core_panic_2021_macro
        )
    })
}
/// Returns `true` if `def_id` is the `assert!` or `debug_assert!` macro, according to its
/// diagnostic name.
pub fn is_assert_macro(cx: &LateContext<'_>, def_id: DefId) -> bool {
    cx.tcx
        .get_diagnostic_name(def_id)
        .is_some_and(|name| matches!(name, sym::assert_macro | sym::debug_assert_macro))
}
/// A panic call found inside a panic or assertion expansion, classified by how its message
/// is supplied. Produced by [`PanicCall::parse`].
#[derive(Debug)]
pub enum PanicCall<'a> {
    /// No user-supplied message; the default message applies.
    DefaultMessage,
    /// The argument of a `panic`/`begin_panic`/`panic_str_2015` call that was not recognized
    /// as macro-generated. This is presumably the 2015-edition `panic!(expr)` form.
    Str2015(&'a Expr<'a>),
    /// The expression whose reference is passed to `panic_display`.
    Display(&'a Expr<'a>),
    /// A formatted-message argument, taken from `panic_fmt` or from the message slot of
    /// `assert_failed`.
    Format(&'a Expr<'a>),
}
impl<'a> PanicCall<'a> {
    /// Classifies `expr` as a call to one of the known panic entry points.
    ///
    /// `expr` must be a call whose callee is a resolved path. The name of the path's last
    /// segment decides how the arguments are interpreted. Returns `None` in these cases:
    /// - `expr` is any other kind of expression;
    /// - the call has no arguments;
    /// - the callee name is not recognized;
    /// - the arguments do not have the expected shape.
    pub fn parse(expr: &'a Expr<'a>) -> Option<Self> {
        let ExprKind::Call(callee, args) = &expr.kind else {
            return None;
        };
        let ExprKind::Path(QPath::Resolved(_, path)) = &callee.kind else {
            return None;
        };
        // A resolved path is expected to have at least one segment.
        let name = path.segments.last().unwrap().ident.name;
        let [arg, rest @ ..] = args else {
            return None;
        };
        let result = match name {
            sym::panic | sym::begin_panic | sym::panic_str_2015 => {
                // An argument that shares the call's syntax context, or has a dummy span, is
                // presumably macro-generated rather than user-written. It is therefore
                // reported as the default message.
                if arg.span.eq_ctxt(expr.span) || arg.span.is_dummy() {
                    Self::DefaultMessage
                } else {
                    Self::Str2015(arg)
                }
            },
            sym::panic_display => {
                // The displayed value is passed by reference; report the referenced expression.
                let ExprKind::AddrOf(_, _, e) = &arg.kind else {
                    return None;
                };
                Self::Display(e)
            },
            sym::panic_fmt => Self::Format(arg),
            sym::assert_failed => {
                // Exactly four arguments are expected. `arg` is the first, and the fourth
                // carries the optional custom message.
                if rest.len() != 3 {
                    return None;
                }
                let msg_arg = &rest[2];
                // A one-argument call (presumably `Some(format_args!(..))`) carries a custom
                // message. Anything else (presumably `None`) means the default message.
                match msg_arg.kind {
                    ExprKind::Call(_, [fmt_arg]) => Self::Format(fmt_arg),
                    _ => Self::DefaultMessage,
                }
            },
            _ => return None,
        };
        Some(result)
    }

    /// Returns `true` if this is [`PanicCall::DefaultMessage`].
    pub fn is_default_message(&self) -> bool {
        matches!(self, Self::DefaultMessage)
    }
}
/// Finds the single argument (the condition) and the panic call of the `assert!`-like
/// expansion `expn`, searching within `expr`.
///
/// Returns `None` when the argument cannot be found. The full rules are those of
/// `find_assert_args_inner`.
pub fn find_assert_args<'a>(
    cx: &LateContext<'_>,
    expr: &'a Expr<'a>,
    expn: ExpnId,
) -> Option<(&'a Expr<'a>, PanicCall<'a>)> {
    let ([condition], panic_call) = find_assert_args_inner::<1>(cx, expr, expn)?;
    Some((condition, panic_call))
}
/// Finds the two compared arguments and the panic call of the `assert_eq!`-like expansion
/// `expn`, searching within `expr`.
///
/// Returns `None` when both arguments cannot be found. The full rules are those of
/// `find_assert_args_inner`.
pub fn find_assert_eq_args<'a>(
    cx: &LateContext<'_>,
    expr: &'a Expr<'a>,
    expn: ExpnId,
) -> Option<(&'a Expr<'a>, &'a Expr<'a>, PanicCall<'a>)> {
    let ([left, right], panic_call) = find_assert_args_inner::<2>(cx, expr, expn)?;
    Some((left, right, panic_call))
}
/// Shared implementation of [`find_assert_args`] and [`find_assert_eq_args`].
///
/// The search runs in two phases:
/// 1. Collect the first `N` expressions under `expr` that `is_assert_arg` accepts for the
///    assertion expansion `expn`. These are the user-written arguments.
/// 2. Look for the first [`PanicCall`] after them.
///
/// For macros whose name starts with `debug_`, the search first switches to the expansion of
/// the macro named without that prefix, found inside `expr`.
///
/// Returns `None` in these cases:
/// - the expansion has no macro definition id;
/// - a `debug_*` macro does not contain the expected inner macro;
/// - fewer than `N` arguments are found.
///
/// When all arguments are found but no recognizable panic call follows them, the result
/// uses [`PanicCall::DefaultMessage`] rather than dropping the arguments.
fn find_assert_args_inner<'a, const N: usize>(
    cx: &LateContext<'_>,
    expr: &'a Expr<'a>,
    expn: ExpnId,
) -> Option<([&'a Expr<'a>; N], PanicCall<'a>)> {
    let macro_id = expn.expn_data().macro_def_id?;
    // For `debug_<name>` macros, continue from the inner `<name>` expansion.
    let (expr, expn) = match cx.tcx.item_name(macro_id).as_str().strip_prefix("debug_") {
        None => (expr, expn),
        Some(inner_name) => find_assert_within_debug_assert(cx, expr, expn, Symbol::intern(inner_name))?,
    };
    let mut args = ArrayVec::new();
    let panic_call = for_each_expr_without_closures(expr, |e| {
        if args.is_full() {
            // Every argument is collected; the next parseable call is the panic.
            match PanicCall::parse(e) {
                Some(call) => ControlFlow::Break(call),
                None => ControlFlow::Continue(Descend::Yes),
            }
        } else if is_assert_arg(cx, e, expn) {
            // Arguments are taken whole; do not look inside them.
            args.push(e);
            ControlFlow::Continue(Descend::No)
        } else {
            ControlFlow::Continue(Descend::Yes)
        }
    });
    let args = args.into_inner().ok()?;
    // The arguments were found, so a missing panic call must not discard them. Report the
    // default message instead of failing the whole lookup.
    let panic_call = panic_call.unwrap_or(PanicCall::DefaultMessage);
    Some((args, panic_call))
}
/// Searches `expr` for the expression produced by a macro named `assert_name`, nested inside
/// the `debug_*` expansion `expn`.
///
/// How the traversal treats each expression:
/// - Not from an expansion: skipped, and its children are not visited.
/// - From `expn` itself: its children are visited.
/// - From a macro named `assert_name`: returned together with that expansion.
/// - Anything else: its children are not visited.
fn find_assert_within_debug_assert<'a>(
    cx: &LateContext<'_>,
    expr: &'a Expr<'a>,
    expn: ExpnId,
    assert_name: Symbol,
) -> Option<(&'a Expr<'a>, ExpnId)> {
    for_each_expr_without_closures(expr, |node| {
        if node.span.from_expansion() {
            let node_expn = node.span.ctxt().outer_expn();
            if node_expn == expn {
                return ControlFlow::Continue(Descend::Yes);
            }
            let from_inner_assert = node_expn
                .expn_data()
                .macro_def_id
                .is_some_and(|id| cx.tcx.item_name(id) == assert_name);
            if from_inner_assert {
                return ControlFlow::Break((node, node_expn));
            }
        }
        ControlFlow::Continue(Descend::No)
    })
}
/// Decides whether `expr` is a user-written argument of the assertion expansion
/// `assert_expn`.
///
/// Expressions not from any expansion are arguments. Otherwise the macro backtrace is walked
/// outward and the first decisive entry wins:
/// - Reaching `assert_expn` itself means the expression is assertion machinery, not an
///   argument.
/// - A `cfg` macro is skipped.
/// - Any other macro means the argument was itself a macro call.
///
/// If no decisive entry is found, the expression counts as an argument.
fn is_assert_arg(cx: &LateContext<'_>, expr: &Expr<'_>, assert_expn: ExpnId) -> bool {
    if !expr.span.from_expansion() {
        return true;
    }
    macro_backtrace(expr.span)
        .find_map(|call| {
            if call.expn == assert_expn {
                Some(false)
            } else if cx.tcx.item_name(call.def_id) == sym::cfg {
                None
            } else {
                Some(true)
            }
        })
        .unwrap_or(true)
}
/// Shared, write-once store of AST `FormatArgs`, keyed by `Span`.
///
/// Clones share the same underlying map through the `Arc`. The map is filled once with
/// [`FormatArgsStorage::set`] and read with [`FormatArgsStorage::get`], which looks entries
/// up by span with the span's parent cleared.
#[derive(Default, Clone)]
pub struct FormatArgsStorage(Arc<OnceLock<FxHashMap<Span, FormatArgs>>>);
impl FormatArgsStorage {
    /// Returns the stored [`FormatArgs`] for the first `format_args!`-like expression inside
    /// `start` that belongs to `expn_id` or an expansion descending from it.
    ///
    /// Search rules:
    /// - An expression qualifies when any macro in its backtrace is named `format_args`,
    ///   `format_args_nl` or `const_format_args`.
    /// - Expressions outside `expn_id` are not entered.
    ///
    /// Returns `None` if no qualifying expression exists, or if nothing is stored for its
    /// span. The lookup uses the span with its parent cleared. Debug builds assert that the
    /// storage has already been populated.
    pub fn get(&self, cx: &LateContext<'_>, start: &Expr<'_>, expn_id: ExpnId) -> Option<&FormatArgs> {
        let from_format_args_macro = |span: Span| {
            macro_backtrace(span).any(|call| {
                matches!(
                    cx.tcx.item_name(call.def_id),
                    sym::const_format_args | sym::format_args | sym::format_args_nl
                )
            })
        };

        let format_args_expr = for_each_expr_without_closures(start, |expr| {
            if !expr.span.ctxt().outer_expn().is_descendant_of(expn_id) {
                ControlFlow::Continue(Descend::No)
            } else if from_format_args_macro(expr.span) {
                ControlFlow::Break(expr)
            } else {
                ControlFlow::Continue(Descend::Yes)
            }
        })?;

        debug_assert!(self.0.get().is_some(), "`FormatArgsStorage` not yet populated");

        let lookup_key = format_args_expr.span.with_parent(None);
        self.0.get()?.get(&lookup_key)
    }

    /// Populates the storage with `format_args`.
    ///
    /// # Panics
    ///
    /// Panics if the storage was already populated.
    pub fn set(&self, format_args: FxHashMap<Span, FormatArgs>) {
        let outcome = self.0.set(format_args);
        outcome.expect("`FormatArgsStorage::set` should only be called once");
    }
}
/// Finds the HIR expression under `start` whose span matches the AST span of `target`'s
/// expression.
///
/// Only `lo`, `hi` and the syntax context are compared. The span `parent` is deliberately
/// ignored, presumably because AST and HIR spans can carry different parents; confirm.
pub fn find_format_arg_expr<'hir>(start: &'hir Expr<'hir>, target: &FormatArgument) -> Option<&'hir Expr<'hir>> {
    let wanted = target.expr.span.data();
    for_each_expr_without_closures(start, |expr| {
        let candidate = expr.span.data();
        if (candidate.lo, candidate.hi, candidate.ctxt) == (wanted.lo, wanted.hi, wanted.ctxt) {
            ControlFlow::Break(expr)
        } else {
            ControlFlow::Continue(())
        }
    })
}
pub fn format_placeholder_format_span(placeholder: &FormatPlaceholder) -> Option<Span> {
let base = placeholder.span?.data();
Some(Span::new(
placeholder.argument.span?.hi(),
base.hi - BytePos(1),
base.ctxt,
base.parent,
))
}
/// Returns the span covering the `format_args` call through its last explicit argument.
///
/// The last argument's span is first mapped back to the call's context with
/// `hygiene::walk_chain`. Without explicit arguments, the call's own span is returned.
pub fn format_args_inputs_span(format_args: &FormatArgs) -> Span {
    let whole = format_args.span;
    let Some(last) = format_args.arguments.explicit_args().last() else {
        return whole;
    };
    whole.to(hygiene::walk_chain(last.expr.span, whole.ctxt()))
}
/// Returns the span to delete in order to remove argument `index` from `format_args`.
///
/// The span runs from the end of the previous argument up to the end of argument `index`.
/// For the first argument, it starts at the end of the `format_args` span. Argument spans
/// are mapped to the call's context with `hygiene::walk_chain`. Returns `None` if either
/// argument lookup fails.
pub fn format_arg_removal_span(format_args: &FormatArgs, index: usize) -> Option<Span> {
    let ctxt = format_args.span.ctxt();
    let arg_span = |i: usize| {
        format_args
            .arguments
            .by_index(i)
            .map(|arg| hygiene::walk_chain(arg.expr.span, ctxt))
    };

    let current = arg_span(index)?;
    let prev = match index.checked_sub(1) {
        None => format_args.span,
        Some(prev_index) => arg_span(prev_index)?,
    };
    Some(current.with_lo(prev.hi()))
}
/// The role in which a format parameter is referenced.
///
/// Per the variant names, a parameter serves as the formatted argument itself, as a width,
/// or as a precision.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum FormatParamUsage {
    /// Used as the value being formatted.
    Argument,
    /// Used as a width.
    Width,
    /// Used as a precision.
    Precision,
}
/// Minimal interface over HIR nodes that expose both a `HirId` and a `Span`.
///
/// Used by the "first node in macro" helpers in this module.
pub trait HirNode {
    /// Returns the node's `HirId`.
    fn hir_id(&self) -> HirId;
    /// Returns the node's span.
    fn span(&self) -> Span;
}
// Implements `HirNode` for each listed `hir::$t<'_>` type that stores its id and span in
// `hir_id` and `span` fields.
macro_rules! impl_hir_node {
    ($($t:ident),*) => {
        $(impl HirNode for hir::$t<'_> {
            fn hir_id(&self) -> HirId {
                self.hir_id
            }
            fn span(&self) -> Span {
                self.span
            }
        })*
    };
}
impl_hir_node!(Expr, Pat);
// `Item` is implemented by hand instead of through `impl_hir_node!`: its id comes from a
// `hir_id()` method call rather than a `hir_id` field.
impl HirNode for hir::Item<'_> {
    fn hir_id(&self) -> HirId {
        // Presumably resolves to an inherent `Item::hir_id` method (inherent methods take
        // precedence over trait methods); otherwise this call would recurse. Confirm.
        self.hir_id()
    }
    fn span(&self) -> Span {
        self.span
    }
}