use rustc_hir::OwnerId;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, TypingMode};
use rustc_span::symbol::sym;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::{Obligation, ObligationCause};
use rustc_type_ir::TypeVisitableExt;
/// Returns the crate-qualified path of the module identified by `module_def_id`.
///
/// The result has the form `<crate>::<path>`, where `<crate>` is the identifier form of
/// the owning crate's name and `<path>` is what `TyCtxt::def_path_str` prints for the
/// module. When that printed path is empty (presumably the crate root), only the crate
/// name is returned.
pub fn get_full_module_name(tcx: &TyCtxt<'_>, module_def_id: &OwnerId) -> String {
    let def_id = module_def_id.to_def_id();
    let crate_ident = tcx.crate_name(def_id.krate).to_ident_string();
    match tcx.def_path_str(def_id) {
        path if path.is_empty() => crate_ident,
        path => format!("{crate_ident}::{path}"),
    }
}
/// Returns `true` if `ty` may implement the trait `trait_def_id` under `param_env`.
///
/// This is a "may hold" query. It returns `true` when the trait solver either proves
/// `ty: Trait` or cannot rule it out (ambiguity). It returns `false` for types with
/// escaping bound variables, because those cannot appear in a standalone obligation.
///
/// The query always runs in `TypingMode::Analysis` with no defining opaque types.
/// `TypingMode::Coherence` is deliberately not used. In that mode the solver treats
/// impls that upstream crates could add in the future as possible. A foreign type
/// checked against a foreign trait therefore comes back ambiguous, and
/// `predicate_may_hold` would wrongly report that the trait is implemented.
pub fn implements_trait<'tcx>(
    tcx: TyCtxt<'tcx>,
    param_env: ParamEnv<'tcx>,
    ty: Ty<'tcx>,
    trait_def_id: DefId,
) -> bool {
    // Bound variables escaping `ty` would have no binder in the obligation built below;
    // the trait solver does not accept such predicates.
    if ty.has_escaping_bound_vars() {
        return false;
    }
    let trait_ref = ty::TraitRef::new(tcx, trait_def_id, [ty]);
    let obligation = Obligation::new(tcx, ObligationCause::dummy(), param_env, trait_ref);
    let infcx = tcx.infer_ctxt().build(TypingMode::Analysis {
        defining_opaque_types_and_generators: Default::default(),
    });
    infcx.predicate_may_hold(&obligation)
}
pub fn implements_error_trait<'tcx>(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
ty: Ty<'tcx>,
) -> bool {
match ty.kind() {
ty::TyKind::Int(_)
| ty::TyKind::Uint(_)
| ty::TyKind::Float(_)
| ty::TyKind::Bool
| ty::TyKind::Char => return false,
_ => {}
}
if let Some(error_trait_def_id) = tcx.get_diagnostic_item(sym::Error) {
implements_trait(tcx, param_env, ty, error_trait_def_id)
} else {
!ty.is_primitive()
}
}
/// Strips generic arguments from a trait path string.
///
/// Everything from the first `<` onwards is dropped, so `"Into<String>"` becomes
/// `"Into"`. A name that contains no `<` is returned unchanged.
pub fn get_canonical_trait_name(trait_name: &str) -> String {
    match trait_name.split_once('<') {
        Some((base, _generics)) => base.to_owned(),
        None => trait_name.to_owned(),
    }
}
/// Joins `crate_name` and `trait_name` with `::` after removing any generic arguments
/// from `trait_name` (via `get_canonical_trait_name`).
pub fn get_full_canonical_trait_name(crate_name: &str, trait_name: &str) -> String {
    let stripped = get_canonical_trait_name(trait_name);
    [crate_name, stripped.as_str()].join("::")
}
/// Returns the path of `def_id` as printed by `TyCtxt::def_path_str`, with any generic
/// arguments (from the first `<` onwards) removed.
pub fn get_canonical_trait_name_from_def_id(tcx: &TyCtxt<'_>, def_id: DefId) -> String {
    get_canonical_trait_name(&tcx.def_path_str(def_id))
}
pub fn get_full_canonical_trait_name_from_def_id(tcx: &TyCtxt<'_>, def_id: DefId) -> String {
let crate_name = tcx.crate_name(def_id.krate).to_string();
let raw_trait_name = tcx.def_path_str(def_id);
get_full_canonical_trait_name(&crate_name, &raw_trait_name)
}
/// Strips generic arguments from a type path string.
///
/// Returns the text before the first `<` (e.g. `"Vec<u8>"` becomes `"Vec"`), or the
/// whole input when it contains no `<`.
pub fn get_canonical_type_name(type_name: &str) -> String {
    let end = type_name.find('<').unwrap_or(type_name.len());
    String::from(&type_name[..end])
}