#![doc = include_str!("../README.md")]
use darling::{FromDeriveInput, FromField, ast::Data, util::PathList};
use proc_macro::TokenStream;
use quote::quote;
use syn::{
DeriveInput, Ident, Type, parse_macro_input,
visit::{self, Visit},
};
/// Per-field options parsed by darling from `#[companion(...)]` attributes.
#[derive(FromField, Clone)]
#[darling(attributes(companion))]
struct FieldAttrs {
/// Field name; the derive only accepts named structs, and the expansion unwraps this.
ident: Option<Ident>,
/// Declared type of the field; becomes the payload type of the value-enum variant.
ty: Type,
/// Optional explicit variant name; when absent the PascalCase form of the field name is used.
rename: Option<String>,
/// When `true`, the field gets no variant in either generated enum.
#[darling(default)]
skip: bool,
}
/// Struct-level options parsed by darling from the derive input and its
/// `#[companion(...)]` attributes. Only structs with named fields are accepted.
#[derive(FromDeriveInput)]
#[darling(attributes(companion), supports(struct_named))]
struct CompanionOpts {
/// Name of the annotated struct; the generated enums are named `<ident>Field` and `<ident>Value`.
ident: Ident,
/// Visibility of the annotated struct, reused for both generated enums.
vis: syn::Visibility,
/// Generics of the annotated struct, reused for the value enum and the generated impls.
generics: syn::Generics,
/// Parsed fields of the struct.
data: Data<(), FieldAttrs>,
/// Name of the generated getter method (defaults to `value`).
#[darling(default = "default_value_fn")]
value_fn: String,
/// Name of the generated setter method (defaults to `update`).
#[darling(default = "default_update_fn")]
update_fn: String,
/// Name of the generated associated function listing all fields (defaults to `fields`).
#[darling(default = "default_fields_fn")]
fields_fn: String,
/// Extra derives appended after `Copy, Clone` on the field enum.
#[darling(default)]
derive_field: PathList,
/// Extra derives appended after `Clone` on the value enum.
#[darling(default)]
derive_value: PathList,
/// List-style meta whose inner tokens are re-emitted as `#[serde(...)]` on the field enum.
#[darling(default)]
serde_field: Option<syn::Meta>,
/// List-style meta whose inner tokens are re-emitted as `#[serde(...)]` on the value enum.
#[darling(default)]
serde_value: Option<syn::Meta>,
}
/// Default name of the generated getter method.
fn default_value_fn() -> String {
    String::from("value")
}
/// Default name of the generated setter method.
fn default_update_fn() -> String {
    String::from("update")
}
/// Default name of the generated associated function that lists all fields.
fn default_fields_fn() -> String {
    String::from("fields")
}
#[proc_macro_derive(EnumCompanion, attributes(companion))]
pub fn enum_companion_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let opts = match CompanionOpts::from_derive_input(&input) {
Ok(val) => val,
Err(err) => {
return err.write_errors().into();
}
};
let struct_name = opts.ident;
let vis = opts.vis;
let generics = opts.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let value_fn_name = Ident::new(&opts.value_fn, struct_name.span());
let update_fn_name = Ident::new(&opts.update_fn, struct_name.span());
let fields_fn_name = Ident::new(&opts.fields_fn, struct_name.span());
let derive_field = opts.derive_field;
let derive_value = opts.derive_value;
let serde_field = &opts.serde_field;
let serde_field_attr = if let Some(syn::Meta::List(serde_field)) = serde_field {
let attr_tokens: proc_macro2::TokenStream = serde_field.tokens.clone();
quote! { #[serde(#attr_tokens)] }
} else {
quote! {}
};
let serde_value = &opts.serde_value;
let serde_value_attr = if let Some(syn::Meta::List(serde_value)) = serde_value {
let attr_tokens: proc_macro2::TokenStream = serde_value.tokens.clone();
quote! { #[serde(#attr_tokens)] }
} else {
quote! {}
};
let fields = opts.data.take_struct().unwrap();
let mut field_idents = Vec::new();
let mut field_types = Vec::new();
let mut field_variants = Vec::new();
let mut from_str_arms = Vec::new();
for field in fields.fields {
if field.skip {
continue;
}
let ident = field.ident.clone().unwrap();
let variant_name_str = field
.rename
.clone()
.unwrap_or_else(|| to_pascal_case(&ident.to_string()));
let variant = Ident::new(&variant_name_str, ident.span());
let ident_str = ident.to_string();
let mut patterns = vec![ident_str.clone()];
if variant_name_str != ident_str {
patterns.push(variant_name_str);
}
from_str_arms.push(quote! {
#(#patterns)|* => Ok(Self::#variant)
});
field_idents.push(ident);
field_types.push(field.ty);
field_variants.push(variant);
}
let field_enum_name = syn::Ident::new(&format!("{struct_name}Field"), struct_name.span());
let value_enum_name = syn::Ident::new(&format!("{struct_name}Value"), struct_name.span());
let field_enum_variants = field_variants.iter();
let _field_variants_count = field_variants.len();
let value_enum_variants = field_variants
.iter()
.zip(field_types.iter())
.map(|(variant, ty)| {
quote! { #variant(#ty) }
});
let value_match_arms =
field_idents
.iter()
.zip(field_variants.iter())
.map(|(ident, variant)| {
quote! {
#field_enum_name::#variant => #value_enum_name::#variant(self.#ident.clone())
}
});
let update_match_arms =
field_idents
.iter()
.zip(field_variants.iter())
.map(|(ident, variant)| {
quote! {
#value_enum_name::#variant(value) => self.#ident = value
}
});
let trait_impl = if opts.value_fn == "value"
&& opts.update_fn == "update"
&& opts.fields_fn == "fields"
{
quote! {
impl #impl_generics ::enum_companion::EnumCompanionTrait<#field_enum_name, #value_enum_name #ty_generics> for #struct_name #ty_generics #where_clause {
fn value(&self, field: #field_enum_name) -> #value_enum_name #ty_generics {
self.value(field)
}
fn update(&mut self, value: #value_enum_name #ty_generics) {
self.update(value)
}
fn fields() -> &'static [#field_enum_name] {
&#field_enum_name::FIELDS
}
fn as_values(&self) -> Vec<#value_enum_name #ty_generics> {
self.as_values()
}
}
}
} else {
quote! {}
};
let mut unique_types = std::collections::HashMap::new();
for (ty, variant) in field_types.iter().zip(field_variants.iter()) {
let key = quote!(#ty).to_string();
unique_types
.entry(key)
.or_insert_with(|| (ty.clone(), Vec::new()))
.1
.push(variant.clone());
}
let generic_param_idents: std::collections::HashSet<String> = generics
.params
.iter()
.filter_map(|p| match p {
syn::GenericParam::Type(ty) => Some(ty.ident.to_string()),
_ => None,
})
.collect();
let try_from_impls = unique_types.values().filter_map(|(ty, variants)| {
if type_contains_generic(ty, &generic_param_idents) {
return None;
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
Some(quote! {
impl #impl_generics std::convert::TryFrom<#value_enum_name #ty_generics> for #ty #where_clause {
type Error = #value_enum_name #ty_generics;
fn try_from(value: #value_enum_name #ty_generics) -> Result<Self, Self::Error> {
match value {
#(#value_enum_name::#variants(val) => Ok(val)),*,
_ => Err(value),
}
}
}
})
});
let try_into_impls = unique_types.values().filter_map(|(ty, variants)| {
if type_contains_generic(ty, &generic_param_idents) {
return None;
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let arms = variants.iter().map(|variant| {
quote! {
#field_enum_name::#variant => Ok(#value_enum_name::#variant(value)),
}
});
Some(quote! {
impl #impl_generics std::convert::TryFrom<(#field_enum_name, #ty)> for #value_enum_name #ty_generics #where_clause {
type Error = #field_enum_name;
fn try_from(value: (#field_enum_name, #ty)) -> Result<Self, Self::Error> {
let (field, value) = value;
match field {
#(#arms)*
_ => Err(field),
}
}
}
})
});
let expanded = quote! {
#[allow(dead_code)]
#[derive(Copy, Clone, #(#derive_field),*)]
#serde_field_attr
#vis enum #field_enum_name {
#(#field_enum_variants),*
}
impl #field_enum_name {
pub const FIELDS: &'static [#field_enum_name] = &[#(#field_enum_name::#field_variants),*];
}
#[allow(dead_code)]
#[derive(Clone, #(#derive_value),*)]
#serde_value_attr
#vis enum #value_enum_name #ty_generics {
#(#value_enum_variants),*
}
impl std::str::FromStr for #field_enum_name {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
#(#from_str_arms),*,
_ => Err(format!("Invalid field name: {}", s)),
}
}
}
impl #impl_generics #struct_name #ty_generics #where_clause {
pub fn #fields_fn_name() -> &'static [#field_enum_name] {
#field_enum_name::FIELDS
}
pub fn as_values(&self) -> Vec<#value_enum_name #ty_generics> {
Self::#fields_fn_name()
.iter()
.map(|&field| self.#value_fn_name(field))
.collect()
}
pub fn #value_fn_name(&self, field: #field_enum_name) -> #value_enum_name #ty_generics {
match field {
#(#value_match_arms),*
}
}
pub fn #update_fn_name(&mut self, value: #value_enum_name #ty_generics) {
match value {
#(#update_match_arms),*
}
}
}
#trait_impl
#(#try_from_impls)*
#(#try_into_impls)*
};
TokenStream::from(expanded)
}
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`; the first character of every non-empty piece is
/// upper-cased and the remainder of the piece is copied unchanged. Empty
/// pieces (from leading, trailing or doubled underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for piece in s.split('_') {
        let mut rest = piece.chars();
        if let Some(initial) = rest.next() {
            // `to_uppercase` may yield more than one char (e.g. 'ß' -> "SS").
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
/// `syn` visitor that records whether a type mentions one of a given set of
/// generic parameter names.
struct GenericVisitor<'a> {
/// Names of the generic type parameters to look for.
generic_params: &'a std::collections::HashSet<String>,
/// Set to `true` once a matching type path has been seen.
contains_generic: bool,
}
impl<'ast, 'a> Visit<'ast> for GenericVisitor<'a> {
    /// Flags the visitor when a type path starts with one of the tracked
    /// generic parameter names; otherwise keeps walking nested types.
    ///
    /// A generic parameter can only appear as the *first* segment of an
    /// unqualified path without a leading `::` (`T` or `T::Assoc`). Checking
    /// the first segment therefore catches associated-type projections such
    /// as `T::Item`, and does not misfire on an unrelated `module::T`.
    /// Qualified paths (`<T as Trait>::Assoc`) are covered by the recursive
    /// walk, which visits the qualified-self type.
    fn visit_type_path(&mut self, i: &'ast syn::TypePath) {
        // Nothing more to learn once a match has been found.
        if self.contains_generic {
            return;
        }
        let starts_with_generic = i.qself.is_none()
            && i.path.leading_colon.is_none()
            && i.path
                .segments
                .first()
                .is_some_and(|segment| self.generic_params.contains(&segment.ident.to_string()));
        if starts_with_generic {
            self.contains_generic = true;
            return;
        }
        visit::visit_type_path(self, i);
    }
}
fn type_contains_generic(
ty: &syn::Type,
generic_params: &std::collections::HashSet<String>,
) -> bool {
let mut visitor = GenericVisitor {
generic_params,
contains_generic: false,
};
visitor.visit_type(ty);
visitor.contains_generic
}