use std::collections::BTreeSet;
use clippy_utils::diagnostics::span_lint_and_then;
use clippy_utils::source::indent_of;
use clippy_utils::sym;
use clippy_utils::ty::implements_trait;
use rustc_errors::Applicability;
use rustc_hir as hir;
use rustc_hir::attrs::AttributeKind;
use rustc_lint::{LateContext, LateLintPass, LintStore};
use rustc_middle::middle::privacy::Level;
use rustc_middle::ty::{self, TyCtxt};
use rustc_session::{declare_tool_lint, impl_lint_pass};
use rustc_span::def_id::{CRATE_DEF_ID, LocalDefId};
use crate::common::{DefaultState, resolve_string_set, resolved_state};
declare_tool_lint! {
/// ### What it does
/// Flags error-shaped types that lack `#[non_exhaustive]`. Candidates are enums, and
/// structs with exactly one field whose type is an enum. A candidate is linted when:
/// - its name ends with a configured suffix (default `Error`), or it implements
///   `std::error::Error`;
/// - it passes the configured visibility policy (`require_for`).
///
/// ### Why is this bad?
/// Without `#[non_exhaustive]`, adding a new variant later can be a SemVer-breaking
/// change for downstream code that matches exhaustively on the type.
///
/// Items produced by external macros are not reported (`report_in_external_macro: false`).
pub perfectionist::NON_EXHAUSTIVE_ERROR,
Warn,
"error-shaped type is missing `#[non_exhaustive]`",
report_in_external_macro: false
}
/// Default activation state passed to `resolved_state` in `register_pass`.
///
/// When the resolved state is `Inactive`, the late pass is never registered. Presumably
/// this makes the lint opt-in unless overridden; confirm in `crate::common::resolved_state`.
pub(crate) const DEFAULT_STATE: DefaultState = DefaultState::Inactive;
/// Key whose table is deserialized into `Config` by `dylint_linting::config_or_default`.
const CONFIG_KEY: &str = "perfectionist::non_exhaustive_error";
/// Visibility policy selecting which items the lint checks (config key `require_for`).
///
/// Spelled in snake_case in the config: `"pub"`, `"pub_crate"`, `"all"`.
#[derive(Debug, Clone, Copy, Default, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
enum RequireFor {
/// Default: only items reachable from other crates.
#[default]
Pub,
/// Externally reachable items plus items with crate-wide declared visibility
/// (the exact rule lives in `NonExhaustiveError::visibility_qualifies`).
PubCrate,
/// Every item, regardless of visibility.
All,
}
/// Built-in name suffixes that mark a type as error-shaped.
///
/// The set is adjusted by `extra_suffixes` / `ignore_suffixes`. A type that implements
/// `std::error::Error` is checked even when its name matches none of these.
const DEFAULT_SUFFIXES: &[&str] = &["Error"];
/// User configuration read from the `perfectionist::non_exhaustive_error` table.
///
/// `#[serde(default)]` fills any omitted key from `Config::default()`.
/// `deny_unknown_fields` turns an unrecognized key into a deserialization error.
#[derive(Debug, Default, serde::Deserialize)]
#[serde(default, deny_unknown_fields, rename_all = "snake_case")]
struct Config {
/// Visibility policy; defaults to `RequireFor::Pub`.
require_for: RequireFor,
/// Passed to `resolve_string_set` together with `DEFAULT_SUFFIXES`.
/// Presumably these are added to the suffix set; confirm in `crate::common`.
extra_suffixes: Vec<String>,
/// Passed to `resolve_string_set` as the ignore list.
/// Presumably these are removed from the set, defaults included; confirm in `crate::common`.
ignore_suffixes: Vec<String>,
}
/// State of the late lint pass, built once from the user config by `NonExhaustiveError::new`.
pub struct NonExhaustiveError {
    /// Effective name suffixes: `DEFAULT_SUFFIXES` adjusted by the configured extra/ignore lists.
    suffixes: BTreeSet<String>,
    /// Visibility policy deciding which items are checked at all.
    require_for: RequireFor,
}
impl NonExhaustiveError {
    /// Builds the pass from the `perfectionist::non_exhaustive_error` config table.
    ///
    /// The table is read with `dylint_linting::config_or_default`, falling back to
    /// `Config::default()` when it is absent. The effective suffix set is
    /// `DEFAULT_SUFFIXES` combined with the configured `extra_suffixes` /
    /// `ignore_suffixes` by `resolve_string_set`.
    fn new() -> Self {
        let config: Config = dylint_linting::config_or_default(CONFIG_KEY);
        let suffixes = resolve_string_set(
            DEFAULT_SUFFIXES,
            config.extra_suffixes,
            config.ignore_suffixes,
        );
        Self {
            require_for: config.require_for,
            suffixes,
        }
    }

    /// Returns `true` when the item `def_id` is visible enough, under the configured
    /// [`RequireFor`] policy, to be checked by this lint.
    ///
    /// * `All`: every item qualifies.
    /// * `Pub`: only items reachable from other crates (`is_externally_reachable`).
    /// * `PubCrate`: externally reachable items, plus items whose *declared* visibility
    ///   is at least crate-wide. That means plain `pub`, or anything restricted to the
    ///   crate root (`pub(crate)`, and private items declared directly at the crate root).
    fn visibility_qualifies(&self, tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
        match self.require_for {
            RequireFor::All => true,
            RequireFor::Pub => is_externally_reachable(tcx, def_id),
            RequireFor::PubCrate => {
                is_externally_reachable(tcx, def_id)
                    || match tcx.visibility(def_id.to_def_id()) {
                        // A `pub` item inside a private module is not externally reachable,
                        // yet it is strictly more visible than a `pub(crate)` sibling, so it
                        // must qualify as well. Previously only `Restricted(crate root)`
                        // matched, which silently skipped such items.
                        ty::Visibility::Public => true,
                        ty::Visibility::Restricted(scope) => scope == CRATE_DEF_ID.to_def_id(),
                    }
            }
        }
    }

    /// Returns `true` when `name` ends with any configured suffix (case-sensitive).
    fn name_matches(&self, name: &str) -> bool {
        self.suffixes
            .iter()
            .any(|suffix| name.ends_with(suffix.as_str()))
    }
}
// Declares `NonExhaustiveError` as a lint pass whose only lint is `NON_EXHAUSTIVE_ERROR`.
impl_lint_pass!(NonExhaustiveError => [NON_EXHAUSTIVE_ERROR]);
/// Registers `NON_EXHAUSTIVE_ERROR` with the lint store.
///
/// This is separate from `register_pass`, presumably so the lint name stays known
/// (e.g. for `#[allow]` attributes) even when the pass itself is not installed.
pub fn register_lint(lint_store: &mut LintStore) {
    let lints = [NON_EXHAUSTIVE_ERROR];
    lint_store.register_lints(&lints);
}
/// Installs the late pass unless the lint resolves to `DefaultState::Inactive`.
///
/// `DEFAULT_STATE` is handed to `resolved_state` for the `non_exhaustive_error` name.
pub fn register_pass(lint_store: &mut LintStore) {
    let state = resolved_state("non_exhaustive_error", DEFAULT_STATE);
    if !matches!(state, DefaultState::Inactive) {
        lint_store.register_late_pass(|_| Box::new(NonExhaustiveError::new()));
    }
}
impl<'tcx> LateLintPass<'tcx> for NonExhaustiveError {
    /// Lints an item when all of the following hold:
    /// - it is an enum, or a struct whose only field is an enum (`is_sum_like`);
    /// - it passes the visibility policy;
    /// - it looks like an error (name suffix or an `std::error::Error` impl);
    /// - it does not already carry `#[non_exhaustive]`.
    fn check_item(&mut self, cx: &LateContext<'tcx>, item: &'tcx hir::Item<'tcx>) {
        let (ident, kind_label) = match item.kind {
            hir::ItemKind::Enum(ident, ..) => (ident, "enum"),
            hir::ItemKind::Struct(ident, _, ref data) if is_sum_like(cx, data) => {
                (ident, "struct")
            }
            _ => return,
        };

        let def_id = item.owner_id.def_id;
        if !self.visibility_qualifies(cx.tcx, def_id) {
            return;
        }

        // The cheap name test runs first; the trait query only runs when the name misses.
        let looks_like_error =
            self.name_matches(ident.name.as_str()) || implements_error_trait(cx, def_id);
        if !looks_like_error {
            return;
        }

        let already_non_exhaustive = cx.tcx.hir_attrs(item.hir_id()).iter().any(|attr| {
            matches!(attr, hir::Attribute::Parsed(AttributeKind::NonExhaustive(_)))
        });
        if !already_non_exhaustive {
            emit(cx, item, kind_label, ident);
        }
    }
}
/// Returns `true` when `def_id` is public at rustc's `privacy::Level::Reexported` level.
///
/// That level means the item is accessible from other crates, either directly or
/// through `use` re-exports.
fn is_externally_reachable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
    let effective_visibilities = tcx.effective_visibilities(());
    effective_visibilities.is_public_at_level(def_id, Level::Reexported)
}
fn implements_error_trait(cx: &LateContext<'_>, def_id: LocalDefId) -> bool {
let Some(error_trait) = cx.tcx.get_diagnostic_item(sym::Error) else {
return false;
};
let ty = cx
.tcx
.type_of(def_id)
.instantiate_identity()
.skip_normalization();
implements_trait(cx, ty, error_trait, &[])
}
/// Returns `true` when `data` has exactly one field and that field's type is an enum,
/// e.g. `struct Error(ErrorKind)`.
///
/// A wrapped enum such as `Box<ErrorKind>` does not count: `Box` is not an enum ADT.
/// The field's own visibility is not considered.
fn is_sum_like(cx: &LateContext<'_>, data: &hir::VariantData<'_>) -> bool {
    let [only_field] = data.fields() else {
        return false;
    };
    let field_ty = cx
        .tcx
        .type_of(only_field.def_id)
        .instantiate_identity()
        .skip_normalization();
    match field_ty.kind() {
        ty::Adt(adt_def, _) => adt_def.is_enum(),
        _ => false,
    }
}
/// Reports `NON_EXHAUSTIVE_ERROR` on the item's name.
///
/// It also suggests inserting a `#[non_exhaustive]` line directly in front of the item,
/// padded so the item keeps its original column. The suggestion is `MaybeIncorrect`,
/// presumably because the attribute changes how other crates may construct and match
/// the type.
///
/// NOTE(review): the help text mentions "variants" even when `kind_label` is "struct".
fn emit(cx: &LateContext<'_>, item: &hir::Item<'_>, kind_label: &str, ident: rustc_span::Ident) {
    let message = format!(
        "{kind_label} `{}` is missing `#[non_exhaustive]`",
        ident.name.as_str()
    );
    // The new attribute goes right before the item's first token. The trailing padding
    // restores the indentation of the item line that now follows the inserted newline.
    let padding = " ".repeat(indent_of(cx, item.span).unwrap_or(0));
    let attribute_line = format!("#[non_exhaustive]\n{padding}");
    let insertion_point = item.span.shrink_to_lo();
    span_lint_and_then(cx, NON_EXHAUSTIVE_ERROR, ident.span, message, |diag| {
        diag.span_suggestion(
            insertion_point,
            "add `#[non_exhaustive]` to keep new variants from being a SemVer break",
            attribute_line,
            Applicability::MaybeIncorrect,
        );
    });
}