#![doc(test(
no_crate_inject,
attr(
deny(warnings, rust_2018_idioms, single_use_lifetimes),
allow(dead_code, unused_variables)
)
))]
#![forbid(unsafe_code)]
#![warn(future_incompatible, rust_2018_idioms, unreachable_pub)]
#![cfg_attr(test, warn(single_use_lifetimes))]
#![warn(clippy::all, clippy::default_trait_access)]
#[allow(unused_extern_crates)]
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{format_ident, ToTokens};
use std::{collections::hash_map::DefaultHasher, hash::Hasher, mem};
use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
token,
visit_mut::VisitMut,
Attribute, Error, GenericParam, Generics, Ident, ImplItem, ItemImpl, ItemTrait, PredicateType,
Result, Token, TraitItem, TraitItemConst, TraitItemMethod, Type, TypeParam, TypePath,
Visibility, WherePredicate,
};
// Builds a `syn::Error` spanned over the tokens of `$span` (via `syn::Error::new_spanned`).
//
// Two call forms:
// - `error!(tokens, message)`: the message expression is passed through as-is;
// - `error!(tokens, "fmt {}", args...)`: the remaining tokens are forwarded to `format!`.
// Arms are tried top to bottom, so a lone message expression matches the first arm and
// anything with extra format arguments (or a trailing comma) falls through to the second.
macro_rules! error {
($span:expr, $msg:expr) => {
syn::Error::new_spanned(&$span, $msg)
};
($span:expr, $($tt:tt)*) => {
// Format first, then re-enter the first arm with the resulting `String`.
error!($span, format!($($tt)*))
};
}
/// Attribute macro entry point.
///
/// Parses the attribute arguments (an optional visibility followed by an optional trait
/// name) and the annotated `impl` block. It then derives a trait declaration from that
/// block via `trait_from_impl` and emits the trait followed by the rewritten `impl`.
///
/// When no trait name is supplied, one of the form `__ExtTrait<hash>` is generated from a
/// hash of the annotated item's tokens. Parse failures and conversion errors are reported
/// as compile errors rather than panics.
#[proc_macro_attribute]
pub fn ext(args: TokenStream, input: TokenStream) -> TokenStream {
    let mut args: Args = syn::parse_macro_input!(args);
    // Only compute the hash when the caller did not name the trait.
    args.name.get_or_insert_with(|| format_ident!("__ExtTrait{}", hash(&input)));

    let mut item: ItemImpl = syn::parse_macro_input!(input);
    let expanded = match trait_from_impl(&mut item, args) {
        Ok(trait_def) => {
            // Trait declaration first, then the (now trait-) impl.
            let mut tokens = trait_def.into_token_stream();
            item.to_tokens(&mut tokens);
            tokens
        }
        Err(err) => err.into_compile_error(),
    };
    expanded.into()
}
/// Arguments accepted by `#[ext(...)]`: an optional visibility, then an optional trait name.
struct Args {
    /// Visibility written in the attribute; `None` when it was omitted (inherited).
    vis: Option<Visibility>,
    /// Name for the generated trait; `None` when the next token is not an identifier.
    name: Option<Ident>,
}

impl Parse for Args {
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        // An omitted visibility parses as `Visibility::Inherited`; record that as `None`.
        let vis = match input.parse::<Visibility>()? {
            Visibility::Inherited => None,
            explicit => Some(explicit),
        };
        let name = input.parse::<Option<Ident>>()?;
        Ok(Self { vis, name })
    }
}
fn determine_trait_generics<'a>(generics: &mut Generics, self_ty: &'a Type) -> Option<&'a Ident> {
if let Type::Path(TypePath { path, qself: None }) = self_ty {
if let Some(ident) = path.get_ident() {
let i = generics.params.iter().position(|param| {
if let GenericParam::Type(param) = param { param.ident == *ident } else { false }
});
if let Some(i) = i {
let mut params = mem::replace(&mut generics.params, Punctuated::new())
.into_iter()
.collect::<Vec<_>>();
let param = params.remove(i);
generics.params = params.into_iter().collect();
if let GenericParam::Type(TypeParam {
colon_token: Some(colon_token),
bounds,
..
}) = param
{
generics.make_where_clause().predicates.push(WherePredicate::Type(
PredicateType {
lifetimes: None,
bounded_ty: parse_quote!(Self),
colon_token,
bounds,
},
));
}
return Some(ident);
}
}
}
None
}
fn trait_from_impl(item: &mut ItemImpl, args: Args) -> Result<ItemTrait> {
struct ReplaceParam<'a> {
self_ty: &'a Ident,
}
impl VisitMut for ReplaceParam<'_> {
fn visit_ident_mut(&mut self, ident: &mut Ident) {
if *ident == *self.self_ty {
*ident = Ident::new("Self", ident.span());
}
}
}
let name = args.name.unwrap();
let mut generics = item.generics.clone();
let mut visitor = determine_trait_generics(&mut generics, &item.self_ty)
.map(|self_ty| ReplaceParam { self_ty });
if let Some(visitor) = &mut visitor {
visitor.visit_generics_mut(&mut generics);
}
let ty_generics = generics.split_for_impl().1;
let trait_ = parse_quote!(#name #ty_generics);
item.trait_ = Some((None, trait_, <Token![for]>::default()));
let impl_vis = args.vis;
let mut assoc_vis = None;
let mut items = Vec::with_capacity(item.items.len());
item.items.iter_mut().try_for_each(|item| {
trait_item_from_impl_item(item, &mut assoc_vis, &impl_vis).map(|mut item| {
if let Some(visitor) = &mut visitor {
visitor.visit_trait_item_mut(&mut item);
}
items.push(item)
})
})?;
let mut attrs = item.attrs.clone();
find_remove(&mut item.attrs, "doc");
attrs.push(parse_quote!(#[allow(patterns_in_fns_without_body)]));
Ok(ItemTrait {
attrs,
vis: impl_vis.unwrap_or_else(|| assoc_vis.unwrap_or(Visibility::Inherited)),
unsafety: item.unsafety,
auto_token: None,
trait_token: <Token![trait]>::default(),
ident: name,
generics,
colon_token: None,
supertraits: Punctuated::new(),
brace_token: token::Brace::default(),
items,
})
}
/// Converts one associated item of the annotated `impl` into the matching trait item.
///
/// The impl item is modified in place: its visibility is reset to inherited, because
/// items in a trait impl may not carry one, and its `doc` attributes are removed. The
/// returned trait item keeps the original attributes (minus `#[inline]` for methods),
/// copies the signature/type, and has no default body or value.
///
/// Visibility rules:
/// - if the attribute itself specified a visibility (`impl_vis` is `Some`), every
///   associated item must have inherited visibility;
/// - otherwise all associated items must share one visibility. The first item stores it
///   in `prev_vis` and later items are compared against it.
///
/// # Errors
/// Returns an error on a visibility mismatch, and for any associated item that is not a
/// const or a method.
fn trait_item_from_impl_item(
    impl_item: &mut ImplItem,
    prev_vis: &mut Option<Visibility>,
    impl_vis: &Option<Visibility>,
) -> Result<TraitItem> {
    /// Structural visibility equality. `pub(...)` (restricted) visibilities are compared
    /// by their printed tokens; for the other kinds only the variant matters.
    fn compare_visibility(x: &Visibility, y: &Visibility) -> bool {
        match (x, y) {
            (Visibility::Public(_), Visibility::Public(_))
            | (Visibility::Crate(_), Visibility::Crate(_))
            | (Visibility::Inherited, Visibility::Inherited) => true,
            (Visibility::Restricted(x), Visibility::Restricted(y)) => {
                x.to_token_stream().to_string() == y.to_token_stream().to_string()
            }
            _ => false,
        }
    }

    /// Enforces the visibility rules described on the outer function.
    /// `span` is used for errors when `current` is inherited and so has no tokens.
    fn check_visibility(
        current: Visibility,
        prev: &mut Option<Visibility>,
        impl_vis: &Option<Visibility>,
        span: &dyn ToTokens,
    ) -> Result<()> {
        if impl_vis.is_some() {
            return if let Visibility::Inherited = current {
                Ok(())
            } else {
                Err(error!(current, "all associated items must have inherited visibility"))
            };
        }
        match prev {
            None => *prev = Some(current),
            Some(prev) if compare_visibility(prev, &current) => {}
            Some(prev) => {
                return if let Visibility::Inherited = prev {
                    Err(error!(current, "all associated items must have inherited visibility"))
                } else {
                    Err(error!(
                        // An inherited visibility has no tokens to point at; use the item name.
                        if let Visibility::Inherited = current { span } else { &current },
                        "all associated items must have a visibility of `{}`",
                        prev.to_token_stream(),
                    ))
                };
            }
        }
        Ok(())
    }

    match impl_item {
        ImplItem::Const(impl_const) => {
            let vis = mem::replace(&mut impl_const.vis, Visibility::Inherited);
            check_visibility(vis, prev_vis, impl_vis, &impl_const.ident)?;
            // Docs move to the trait declaration.
            let attrs = impl_const.attrs.clone();
            find_remove(&mut impl_const.attrs, "doc");
            Ok(TraitItem::Const(TraitItemConst {
                attrs,
                const_token: impl_const.const_token,
                ident: impl_const.ident.clone(),
                colon_token: impl_const.colon_token,
                ty: impl_const.ty.clone(),
                default: None,
                semi_token: impl_const.semi_token,
            }))
        }
        ImplItem::Method(impl_method) => {
            let vis = mem::replace(&mut impl_method.vis, Visibility::Inherited);
            check_visibility(vis, prev_vis, impl_vis, &impl_method.sig.ident)?;
            // Docs move to the trait declaration; `#[inline]` stays only on the impl body.
            let mut attrs = impl_method.attrs.clone();
            find_remove(&mut impl_method.attrs, "doc");
            find_remove(&mut attrs, "inline");
            Ok(TraitItem::Method(TraitItemMethod {
                attrs,
                sig: impl_method.sig.clone(),
                default: None,
                // Reuse the body's span so diagnostics on the `;` point at the method.
                semi_token: Some(Token![;](impl_method.block.brace_token.span)),
            }))
        }
        _ => Err(error!(impl_item, "unsupported item")),
    }
}
/// Removes, in place, every attribute whose path is exactly the single identifier `ident`.
///
/// For example, `"doc"` matches `#[doc ...]` attributes, which is how syn represents `///`
/// doc comments. The relative order of the remaining attributes is preserved.
fn find_remove(attrs: &mut Vec<Attribute>, ident: &str) {
    // One O(n) pass; repeatedly searching from the start and removing was O(n^2).
    attrs.retain(|attr| !attr.path.is_ident(ident));
}
/// Hashes the string rendering of `input` to a `u64`.
///
/// `ext` uses this to build the name of a trait the user did not name. `DefaultHasher`'s
/// algorithm is not guaranteed stable across Rust releases, so the value is only
/// meaningful within one compilation.
fn hash(input: &TokenStream) -> u64 {
    let rendered = input.to_string();
    let mut state = DefaultHasher::new();
    state.write(rendered.as_bytes());
    state.finish()
}