use proc_macro::TokenStream;
use syn::{
parse::{Error, Parse, ParseStream},
spanned::Spanned,
Fields, FnArg, ItemEnum, ItemFn, Pat,
};
#[proc_macro_attribute]
pub fn with_methods(arg: TokenStream, input: TokenStream) -> TokenStream {
let mut input_methods = syn::parse_macro_input!(arg as Methods).0;
let input_enum = syn::parse_macro_input!(input as ItemEnum);
let mut errors = vec![];
for method in input_methods.iter_mut() {
if let Err(error) = add_block_to_fn(method, &input_enum) {
let span = error.span;
let message = error.message;
errors.push(quote::quote_spanned! {
span.span() => compile_error!(#message);
})
}
}
let enum_ident = &input_enum.ident;
let enum_impl = quote::quote! {
impl #enum_ident {
#(#input_methods)*
}
};
TokenStream::from(quote::quote! {
#input_enum
#enum_impl
#(#errors)*
})
}
/// The parsed attribute argument of `with_methods`: zero or more function
/// items written back to back.
struct Methods(Vec<ItemFn>);

impl Parse for Methods {
    /// Parses `ItemFn`s until the token stream is exhausted, failing on the
    /// first item that does not parse as a function.
    fn parse(input: ParseStream) -> Result<Self, Error> {
        let mut parsed = Vec::new();
        loop {
            if input.is_empty() {
                break Ok(Self(parsed));
            }
            let item: ItemFn = input.parse()?;
            parsed.push(item);
        }
    }
}
/// Replaces the body of `input_method` with a `match self` that forwards the
/// call to the first field of every variant of `input_enum`.
///
/// For a method `fn name(&self, a: A, b: B)`, each generated arm looks like
/// `Self::Variant { first: inner, .. } => inner.name(a, b)` for named fields,
/// or `Self::Variant(inner, ..) => inner.name(a, b)` for tuple fields.
/// `.await` is appended when the method is `async`.
///
/// # Errors
///
/// Returns a [`MacroError`] and leaves the body untouched when:
/// - the method has no `self` receiver, so there is nothing to match on;
/// - an argument pattern is not a plain identifier and so cannot be forwarded;
/// - a variant is a unit variant or has an empty field list;
/// - the generated body fails to parse.
fn add_block_to_fn(input_method: &mut ItemFn, input_enum: &ItemEnum) -> Result<(), MacroError> {
    let method_ident = &input_method.sig.ident;

    // Collect the argument names to forward and check that a receiver exists.
    let mut has_receiver = false;
    let mut method_arg_idents = Vec::new();
    for arg in &input_method.sig.inputs {
        match arg {
            FnArg::Receiver(_) => has_receiver = true,
            FnArg::Typed(typed) => match &*typed.pat {
                // Older syn versions represent `self: Box<Self>` as a typed
                // argument named `self`; it is the receiver, not an argument.
                Pat::Ident(pat) if pat.ident == "self" => has_receiver = true,
                Pat::Ident(pat) => method_arg_idents.push(&pat.ident),
                other => {
                    return Err(MacroError {
                        span: Box::new(other.clone()),
                        message: "method arguments must be plain identifiers".to_string(),
                    })
                }
            },
        }
    }
    if !has_receiver {
        return Err(MacroError {
            span: Box::new(method_ident.clone()),
            message: "methods must take a `self` receiver".to_string(),
        });
    }

    // An async method must await the delegated future to yield its declared output.
    let await_suffix = if input_method.sig.asyncness.is_some() {
        quote::quote!(.await)
    } else {
        quote::quote!()
    };

    let mut match_arms = Vec::with_capacity(input_enum.variants.len());
    for variant in &input_enum.variants {
        let variant_ident = &variant.ident;
        // Bind the delegate to a dedicated name rather than the field's own
        // name, so a method argument sharing that name is not shadowed.
        let pattern = match &variant.fields {
            Fields::Named(fields) => {
                let first_field = fields
                    .named
                    .first()
                    .and_then(|field| field.ident.as_ref())
                    .ok_or_else(|| MacroError {
                        span: Box::new(fields.clone()),
                        message: "variants must have at least one field".to_string(),
                    })?;
                quote::quote! { Self::#variant_ident { #first_field: __with_methods_inner, .. } }
            }
            Fields::Unnamed(fields) => {
                if fields.unnamed.is_empty() {
                    return Err(MacroError {
                        span: Box::new(fields.clone()),
                        message: "variants must have at least one field".to_string(),
                    });
                }
                quote::quote! { Self::#variant_ident(__with_methods_inner, ..) }
            }
            Fields::Unit => {
                return Err(MacroError {
                    span: Box::new(variant.clone()),
                    message: "variants must have at least one field".to_string(),
                })
            }
        };
        match_arms.push(quote::quote! {
            #pattern => __with_methods_inner.#method_ident(#(#method_arg_idents),*) #await_suffix
        });
    }

    // parse2 works on proc_macro2 tokens directly, avoiding a round-trip
    // through proc_macro::TokenStream.
    input_method.block = syn::parse2(quote::quote! {
        {
            match self {
                #(#match_arms),*
            }
        }
    })
    .map_err(|e| MacroError {
        message: e.to_string(),
        span: Box::new(e.span()),
    })?;
    Ok(())
}
/// An error produced while generating a delegating method body.
///
/// `with_methods` turns each one into a `compile_error!` located at `span`.
struct MacroError {
    /// Human-readable description of the problem.
    message: String,
    /// Syntax node (or raw span) that the error should point at.
    span: Box<dyn Spanned>,
}