extern crate proc_macro;
use proc_macro2::{Ident, Span};
use quote::quote;
use std::fmt::Display;
use syn::{
parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma,
AttributeArgs, FnArg, Generics, ItemFn, Lit, Meta, NestedMeta, Pat, Type, TypeParam,
};
pub(crate) fn main(
attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let sa = parse_macro_input!(attr as AttributeArgs);
let si = parse_macro_input!(item as ItemFn);
let name = &si.sig.ident;
let mut trait_name = Ident::new(&format!("IntoExtensionMethod{}", &name), name.span());
let mut impl_type_param = Ident::new("FirstArgImplType Param", Span::call_site());
let mut remove_fn = false;
for item in sa {
if let NestedMeta::Meta(Meta::NameValue(x)) = item {
if x.path.is_ident("trait_name") {
if let Lit::Str(lstr) = x.lit {
trait_name = Ident::new(&lstr.value(), lstr.span());
} else {
return syn_err(
x.lit.span(),
"expected string literal (eg: `\"IntoMyExtensionMethod\"`)",
);
}
} else if x.path.is_ident("impl_type_param") {
if let Lit::Str(lstr) = x.lit {
impl_type_param = Ident::new(&lstr.value(), lstr.span());
} else {
return syn_err(
x.lit.span(),
"expected string literal (eg: `\"FirstArgImplTypeParam\"`)",
);
}
} else if x.path.is_ident("remove_fn") {
if let Lit::Bool(lbool) = x.lit {
remove_fn = lbool.value;
} else {
return syn_err(x.lit.span(), "expected bool literal (eg: `true`)");
}
}
}
}
let vis_trait = &si.vis;
let mut generics = si.sig.generics.clone();
let kw_async = si.sig.asyncness;
let kw_const = si.sig.constness;
let kw_unsafe = si.sig.unsafety;
let mut first_arg = if let Some(arg) = si.sig.inputs.first() {
arg.clone()
} else {
return syn_err(
si.sig.paren_token.span,
"extension method must have at least 1 input",
);
};
let mut first_arg_pt = if let FnArg::Typed(arg) = &first_arg {
arg
} else {
return syn_err(first_arg.span(), "first argument must not be `self`");
};
if let Type::Path(p) = first_arg_pt.ty.as_ref() {
if p.path.is_ident("Self") {
return syn_err(p.span(), "first argument cannot be of type `Self`");
}
}
let (fa_mut, fa_ident) = if let Pat::Ident(ident) = first_arg_pt.pat.as_ref() {
(&ident.mutability, &ident.ident)
} else {
return syn_err(first_arg_pt.pat.span(), "expected identifier");
};
if let Type::ImplTrait(p) = first_arg_pt.ty.as_ref() {
let fa_impl_bounds = &p.bounds;
generics
.params
.push(parse_quote!(#impl_type_param: #fa_impl_bounds));
let attrs = &first_arg_pt.attrs;
first_arg = parse_quote!(#(#attrs)* #fa_mut #fa_ident: #impl_type_param);
first_arg_pt = if let FnArg::Typed(pt) = &first_arg {
pt
} else {
panic!("Bro how")
}
}
let trait_tp = if let Some(p) = is_type_param(&first_arg_pt.ty, &generics) {
Some((p, &None))
} else if let Type::Ptr(ptr) = first_arg_pt.ty.as_ref() {
is_type_param(&ptr.elem, &generics).map(|x| (x, &None))
} else if let Type::Reference(rf) = first_arg_pt.ty.as_ref() {
is_type_param(&rf.elem, &generics).map(|x| (x, &rf.lifetime))
} else {
None
};
let trait_generics: Option<Generics> = trait_tp.map(|(tp, lt)| {
if let Some(lt) = lt {
parse_quote!(<#lt, #tp>)
} else {
parse_quote!(<#tp>)
}
});
let remain_generics = if let Some((tp, lto)) = trait_tp {
let p: Punctuated<_, Comma> = generics
.params
.iter()
.filter(|x| match x {
syn::GenericParam::Type(t) => t.ident != tp.ident,
syn::GenericParam::Lifetime(l) => {
if let Some(lt) = lto {
l.lifetime.ident != lt.ident
} else {
false
}
}
_ => true,
})
.collect();
parse_quote!(<#p>)
} else {
generics.clone()
};
let trait_doc =
format!("this trait is autogenerated for extension method `obj.{name}(...)` to work");
let fn_attrs = &si.attrs;
let first_arg_attrs = &first_arg_pt.attrs;
let fn_out = &si.sig.output;
let wcl = &generics.where_clause;
let inp: Punctuated<_, Comma> = si.sig.inputs.iter().skip(1).collect();
let trait_code = quote! {
#[doc = #trait_doc]
#[allow(non_camel_case_types)]
#vis_trait trait #trait_name #trait_generics {
#(#fn_attrs)*
#kw_const #kw_unsafe #kw_async fn #name #remain_generics(#(#first_arg_attrs)* self, #inp) #fn_out #wcl;
}
};
let impl_ty = first_arg_pt.ty.as_ref();
let stmts = &si.block.stmts;
let fa_pat = &first_arg_pt.pat;
let impl_code = quote! {
impl #trait_generics #trait_name #trait_generics for #impl_ty {
#(#fn_attrs)*
#kw_const #kw_unsafe #kw_async fn #name #remain_generics(#(#first_arg_attrs)* self, #inp) #fn_out #wcl {
let #fa_pat = self;
#(#stmts)*
}
}
};
let original_fn = match remove_fn {
true => None,
false => Some(&si),
};
quote! {
#trait_code
#impl_code
#original_fn
}
.into()
}
/// Looks up `ty` among the type parameters declared in `tps`.
///
/// Returns the matching declaration when `ty` is a path type whose path is
/// exactly the identifier of one of those parameters; otherwise returns `None`.
fn is_type_param<'a>(ty: &Type, tps: &'a Generics) -> Option<&'a TypeParam> {
    match ty {
        Type::Path(type_path) => tps
            .type_params()
            .find(|candidate| type_path.path.is_ident(&candidate.ident)),
        _ => None,
    }
}
fn syn_err<T: Display>(span: proc_macro2::Span, message: T) -> proc_macro::TokenStream {
syn::Error::new(span, message).into_compile_error().into()
}