#![warn(missing_docs)]
#![allow(clippy::style)]
use proc_macro::TokenStream;
use quote::quote;
/// Facts `extract_type` gathers about a single `#[auto_trait(...)]` target type.
struct TypeInfo {
    /// Identifier read from the target's type path.
    ident: syn::Ident,
    /// Angle-bracketed arguments of that path (constraints such as `T: Trait`
    /// included), or `None` when the path carries no `<...>`.
    generics: Option<syn::AngleBracketedGenericArguments>,
    /// Lifetime of an outer reference (`&'a T`). `None` both for non-reference
    /// targets and for references written without a lifetime.
    reference: Option<syn::Lifetime>,
    /// The `mut` of an outer reference (`&'a mut T`), if any.
    mutability: Option<syn::token::Mut>,
}
fn generate_self_trait_bound(generic_name: syn::Ident, trait_name: &syn::Ident) -> syn::GenericArgument {
let mut segments = syn::punctuated::Punctuated::new();
segments.push(syn::PathSegment {
ident: trait_name.clone(),
arguments: syn::PathArguments::None,
});
let mut bounds = syn::punctuated::Punctuated::new();
bounds.push(syn::TypeParamBound::Trait(syn::TraitBound {
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None,
path: syn::Path {
leading_colon: None,
segments
}
}));
syn::GenericArgument::Constraint(syn::Constraint {
ident: generic_name,
generics: None,
colon_token: syn::Token),
bounds
})
}
fn extract_type(typ: &mut syn::Type, trait_name: &syn::Ident, deref_type: &mut Option<syn::Ident>) -> Result<TypeInfo, TokenStream> {
match typ {
syn::Type::Path(ref mut typ) => {
let ident = match typ.path.segments.first() {
Some(path) => path.ident.clone(),
None => return Err(syn::Error::new_spanned(typ, "Type has no path segments").to_compile_error().into()),
};
match typ.path.segments.last_mut().expect("To have at least on type path segment").arguments {
syn::PathArguments::AngleBracketed(ref mut args) => {
let result = args.clone();
for arg in args.args.iter_mut() {
if let syn::GenericArgument::Constraint(constraint) = arg {
for param in constraint.bounds.iter() {
if let syn::TypeParamBound::Trait(bound) = param {
if bound.path.is_ident(trait_name) {
if let Some(ident) = deref_type.replace(constraint.ident.clone()) {
return Err(syn::Error::new_spanned(ident, "Multiple bounds to trait, can be problematic so how about no?").to_compile_error().into());
}
}
}
}
let mut segments = syn::punctuated::Punctuated::new();
segments.push(syn::PathSegment {
ident: constraint.ident.clone(),
arguments: syn::PathArguments::None
});
*arg = syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
qself: None,
path: syn::Path {
leading_colon: None,
segments
},
}));
}
}
Ok(TypeInfo {
ident,
generics: Some(result),
reference: None,
mutability: None,
})
},
syn::PathArguments::None => Ok(TypeInfo {
ident,
generics: None,
reference: None,
mutability: None,
}),
syn::PathArguments::Parenthesized(ref args) => Err(syn::Error::new_spanned(args, "Unsupported type arguments").to_compile_error().into()),
}
},
syn::Type::Reference(reference) => match extract_type(&mut reference.elem, trait_name, deref_type) {
Ok(mut result) => {
result.mutability = reference.mutability;
result.reference = reference.lifetime.clone();
Ok(result)
},
Err(error) => Err(error),
}
other => Err(syn::Error::new_spanned(other, "Unsupported type").to_compile_error().into()),
}
}
#[proc_macro_attribute]
pub fn auto_trait(args: TokenStream, input: TokenStream) -> TokenStream {
let mut input = syn::parse_macro_input!(input as syn::ItemTrait);
let args: syn::Type = match syn::parse(args) {
Ok(args) => args,
Err(error) => {
return syn::Error::new(error.span(), "Argument is required and must be a type").to_compile_error().into()
}
};
let mut args = vec![args];
let mut remaining_attrs = Vec::new();
for attr in input.attrs.drain(..) {
if attr.path().is_ident("auto_trait") {
match attr.parse_args() {
Ok(arg) => match arg {
syn::Type::Paren(arg) => args.push(*arg.elem),
arg => args.push(arg),
},
Err(error) => {
return syn::Error::new(error.span(), "Argument is required and must be a type").to_compile_error().into()
}
}
} else {
remaining_attrs.push(attr)
}
}
input.attrs = remaining_attrs;
let mut impls = Vec::new();
for mut args in args.drain(..) {
let trait_name = input.ident.clone();
let mut deref_type = None;
let type_info = match extract_type(&mut args, &trait_name, &mut deref_type) {
Ok(type_info) => type_info,
Err(error) => return error,
};
let deref_name = deref_type.unwrap_or_else(|| trait_name.clone());
let mut methods = Vec::new();
for item in input.items.iter() {
match item {
syn::TraitItem::Fn(ref method) => {
let method_name = method.sig.ident.clone();
let mut method_args = Vec::new();
for arg in method.sig.inputs.iter() {
match arg {
syn::FnArg::Receiver(arg) => {
if arg.reference.is_some() {
if arg.mutability.is_some() {
if type_info.reference.is_some() {
method_args.push(quote! {
&mut **self
})
} else {
method_args.push(quote! {
core::ops::DerefMut::deref_mut(self)
})
}
} else {
if type_info.reference.is_some() {
method_args.push(quote! {
&**self
})
} else {
method_args.push(quote! {
core::ops::Deref::deref(self)
})
}
}
} else {
method_args.push(quote! {
self.into()
})
}
},
syn::FnArg::Typed(arg) => {
let name = &arg.pat;
method_args.push(quote! {
#name
})
},
}
}
let deref_block: syn::Block = syn::parse2(quote! {
{
#deref_name::#method_name(#(#method_args,)*)
}
}).unwrap();
let mut method = method.clone();
method.default = Some(deref_block);
method.semi_token = None;
methods.push(method);
},
unsupported => return syn::Error::new_spanned(unsupported, "Trait contains non-method definitions which is unsupported").to_compile_error().into(),
}
}
let type_generics = if let Some(lifetime) = type_info.reference {
match type_info.generics {
Some(mut generics) => {
let mut new_args = syn::punctuated::Punctuated::new();
new_args.insert(0, generate_self_trait_bound(type_info.ident, &trait_name));
new_args.insert(0, syn::GenericArgument::Lifetime(lifetime));
while let Some(arg) = generics.args.pop() {
new_args.push(arg.into_tuple().0);
}
generics.args = new_args;
Some(generics)
},
None => {
let mut args = syn::punctuated::Punctuated::new();
args.push(syn::GenericArgument::Lifetime(lifetime));
args.push(generate_self_trait_bound(type_info.ident, &trait_name));
Some(syn::AngleBracketedGenericArguments {
colon2_token: None,
lt_token: syn::Token),
args,
gt_token: syn::Token),
})
}
}
} else {
type_info.generics
};
impls.push(quote! {
impl#type_generics #trait_name for #args {
#(
#methods
)*
}
});
}
let mut result = quote! {
#input
};
result.extend(impls.drain(..));
result.into()
}