use std::collections::HashMap;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::quote;
use syn;
/// Derives `<field>()` / `<field>_mut()` accessors for the named fields of an enum.
///
/// See `impl_for_enum` for the exact shape of the generated methods. Input that cannot be
/// parsed as a derive input is reported as a compile error at the offending tokens instead
/// of panicking inside the macro.
#[proc_macro_derive(EnumFields)]
pub fn enum_fields_macro_derive(input: TokenStream) -> TokenStream {
    match syn::parse::<syn::DeriveInput>(input) {
        Ok(ast) => impl_for_input(&ast),
        Err(error) => error.to_compile_error().into(),
    }
}
/// Groups the enum's named fields by name across all variants.
///
/// Each entry maps a field name (as returned by `Ident::to_string`) to that field's
/// declarations, in variant order. Fields without a name (tuple-variant fields) are skipped.
fn collect_available_fields(enum_data: &syn::DataEnum) -> HashMap<String, Vec<&syn::Field>> {
    let named_fields = enum_data
        .variants
        .iter()
        .flat_map(|variant| variant.fields.iter())
        .filter_map(|field| field.ident.as_ref().map(|ident| (ident.to_string(), field)));

    let mut by_name: HashMap<String, Vec<&syn::Field>> = HashMap::new();
    for (name, field) in named_fields {
        by_name.entry(name).or_default().push(field);
    }
    by_name
}
fn impl_for_input(ast: &syn::DeriveInput) -> TokenStream {
let fail_message = "`EnumFields` is only applicable to `enum`s";
match &ast.data {
syn::Data::Enum(data_enum) => impl_for_enum(ast, &data_enum),
syn::Data::Union(data_union) => syn::Error::new(data_union.union_token.span, fail_message).to_compile_error().into(),
syn::Data::Struct(data_struct) => syn::Error::new(data_struct.struct_token.span, fail_message).to_compile_error().into(),
}
}
fn impl_for_enum(ast: &syn::DeriveInput, enum_data: &syn::DataEnum) -> TokenStream {
let name = &ast.ident;
let fields = collect_available_fields(enum_data);
let mut data = proc_macro2::TokenStream::new();
for (field_name, fields) in fields {
let field_present_everywhere = fields.len() == enum_data.variants.len()
&& fields.iter().all(|x| x.ty == fields[0].ty);
let generics = &ast.generics;
let field_type = &fields[0].ty;
let field_name_ident = Ident::new(&field_name, Span::call_site());
let field_name_ident_mut = Ident::new(&format!("{field_name}_mut"), Span::call_site());
let mut variants = proc_macro2::TokenStream::new();
for variant in &enum_data.variants {
let name = &variant.ident;
let variant_field = variant.fields.iter()
.find(|variant_field| {
if let Some(variant_field_ident) = &variant_field.ident {
if variant_field_ident.to_string() == field_name {
true
} else {
false
}
} else {
false
}
});
let variant_field_ident = variant_field.as_ref().and_then(|field| field.ident.as_ref());
match variant_field_ident {
Some(variant_field_ident) => {
variants.extend(quote! {
Self::#name{ #variant_field_ident, .. } => (#variant_field_ident).into(),
});
}
None => {
if let Some(first_field) = variant.fields.iter().next() {
if first_field.ident.is_some() {
variants.extend(quote! {
Self::#name{ .. } => None,
});
} else {
variants.extend(quote! {
Self::#name(..) => None,
});
}
} else {
variants.extend(quote! {
Self::#name => None,
});
}
}
}
}
let ty = if field_present_everywhere {
quote! {
& #field_type
}
} else {
quote! {
Option<& #field_type>
}
};
let ty_mut = if field_present_everywhere {
quote! {
&mut #field_type
}
} else {
quote! {
Option<&mut #field_type>
}
};
data.extend(quote! {
impl #generics #name #generics {
pub fn #field_name_ident(&self) -> #ty {
match self {
#variants
}
}
pub fn #field_name_ident_mut(&mut self) -> #ty_mut {
match self {
#variants
}
}
}
});
}
data.into()
}