use proc_macro::TokenStream;
use quote::ToTokens;
use syn::{
parse::{Error, Parse, ParseStream},
spanned::Spanned,
Fields, FnArg, ItemEnum, ItemFn, Pat,
};
#[proc_macro_attribute]
pub fn with_methods(arg: TokenStream, input: TokenStream) -> TokenStream {
let mut input_methods = syn::parse_macro_input!(arg as Methods).0;
let input_enum = syn::parse_macro_input!(input as ItemEnum);
let mut errors = vec![];
for method in input_methods.iter_mut() {
if let Err(error) = add_block_to_fn(method, &input_enum) {
let span = error.span;
let message = error.message;
errors.push(quote::quote_spanned! {
span.span() => compile_error!(#message);
})
}
}
let enum_ident = &input_enum.ident;
let enum_impl = quote::quote! {
impl #enum_ident {
#(#input_methods)*
}
};
TokenStream::from(quote::quote! {
#input_enum
#enum_impl
#(#errors)*
})
}
/// The attribute argument of `with_methods`: zero or more complete function
/// items (signature and block) written one after another.
struct Methods(Vec<ItemFn>);

impl Parse for Methods {
    /// Parses function items until the stream is exhausted. The first item
    /// that fails to parse aborts with that item's error.
    fn parse(input: ParseStream) -> Result<Self, Error> {
        let mut items = Vec::new();
        loop {
            if input.is_empty() {
                break Ok(Self(items));
            }
            let item: ItemFn = input.parse()?;
            items.push(item);
        }
    }
}
fn add_block_to_fn(input_method: &mut ItemFn, input_enum: &ItemEnum) -> Result<(), MacroError> {
let method_ident = &input_method.sig.ident;
let mut has_self_arg = false;
let method_arg_idents: Vec<_> = input_method
.sig
.inputs
.iter()
.filter_map(|i| match i {
FnArg::Typed(t) => match &*t.pat {
Pat::Ident(i) => {
if i.ident == "self" {
has_self_arg = true;
None
} else {
Some(i.ident.to_token_stream())
}
}
_ => None,
},
FnArg::Receiver(_) => {
has_self_arg = true;
None
}
})
.collect();
let and = syn::Token);
let self_token = syn::Token);
if !has_self_arg {
input_method.sig.inputs.insert(
0,
FnArg::Receiver(syn::Receiver {
attrs: vec![],
reference: Some((and, None)),
mutability: None,
self_token,
}),
);
}
let mut match_arms = vec![];
for variant in &input_enum.variants {
let variant_ident = &variant.ident;
match &variant.fields {
Fields::Named(fields) => {
let mut first_field = fields
.named
.first()
.ok_or_else(|| MacroError {
span: Box::new(fields.clone()),
message: "variants must have at least one field".to_string(),
})?
.clone();
let path = if let syn::Type::Path(path) = &mut first_field.ty {
path
} else {
panic!();
};
for seg in &mut path.path.segments {
if let syn::PathArguments::AngleBracketed(gen) = &mut seg.arguments {
let colon2 = syn::Token);
gen.colon2_token = Some(colon2);
}
}
let first_field_ident = first_field.ident.as_ref().unwrap();
let first_field_type = &first_field.ty;
let match_arm = if has_self_arg {
quote::quote! {
Self::#variant_ident { #first_field_ident, .. } => #first_field_type :: #method_ident (#first_field_ident, #(#method_arg_idents,)* )
}
} else {
quote::quote! {
Self::#variant_ident { .. } => #first_field_type :: #method_ident (#(#method_arg_idents,)* )
}
};
match_arms.push(match_arm);
}
Fields::Unnamed(fields) => {
let mut first_field = fields
.unnamed
.first()
.ok_or_else(|| MacroError {
span: Box::new(fields.clone()),
message: "variants must have at least one field".to_string(),
})?
.clone();
let path = if let syn::Type::Path(path) = &mut first_field.ty {
path
} else {
panic!();
};
for seg in &mut path.path.segments {
if let syn::PathArguments::AngleBracketed(gen) = &mut seg.arguments {
let colon2 = syn::Token);
gen.colon2_token = Some(colon2);
}
}
let match_arm = if has_self_arg {
quote::quote! {
Self::#variant_ident ( f_1, .. ) => #first_field :: #method_ident (f_1, #(#method_arg_idents,)* )
}
} else {
quote::quote! {
Self::#variant_ident ( .. ) => #first_field :: #method_ident ( #(#method_arg_idents,)* )
}
};
match_arms.push(match_arm);
}
Fields::Unit => {
return Err(MacroError {
span: Box::new(variant.clone()),
message: "variants must have at least one field".to_string(),
})
}
};
}
input_method.block = syn::parse(
quote::quote!(
{
match self {
#(#match_arms),*
}
}
)
.into(),
)
.map_err(|e| MacroError {
message: e.to_string(),
span: Box::new(e.span()),
})?;
Ok(())
}
/// An error produced while generating a method body. `with_methods` reports
/// it as a `compile_error!` spanned at `span`.
struct MacroError {
    /// Text emitted inside the `compile_error!`.
    message: String,
    /// Syntax whose span the error is attached to.
    span: Box<dyn Spanned>,
}