enum_variants_strings_derive/
macro.rs

1#![doc = include_str!("./README.md")]
2
3use std::iter;
4
5use either_n::Either2;
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::quote;
9use string_cases::StringCasesExt;
10use syn::{
11    parse::ParseStream, parse_macro_input, parse_quote, punctuated::Punctuated, Arm, Data,
12    DeriveInput, Expr, ExprLit, Lit, LitStr, Meta, MetaNameValue, Token,
13};
14
/// Name of the variant-level attribute that replaces the variant's transformed (by default
/// snake-cased) name with an explicit list of strings.
///
/// Every listed string is accepted by the generated `from_str`; the generated `to_str`
/// returns the last one.
///
/// ```ignore
/// #[derive(EnumVariantsStrings)]
/// enum X {
///     #[enum_variants_strings_mappings("z", "zed", "zee")]
///     Z,
/// }
/// ```
const CUSTOM_VARIANT_STRING_MAPPING: &str = "enum_variants_strings_mappings";

/// Name of the enum-level attribute that selects how variant identifiers are turned into
/// strings when a variant has no explicit mapping.
///
/// It takes a single `transform = "<name>"` pair, where `<name>` is one of the names accepted
/// by `Transform::from_str`. When the attribute is absent, snake case is used.
///
/// ```ignore
/// #[derive(EnumVariantsStrings)]
/// #[enum_variants_strings_transform(transform = "kebab_case")]
/// enum Y {
///     SomeVariant,
/// }
/// ```
const CUSTOM_VARIANT_STRING_TRANSFORM: &str = "enum_variants_strings_transform";
26
/// How a variant identifier is converted into its string when the variant has no explicit
/// `enum_variants_strings_mappings` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum Transform {
    /// Snake case (via `to_snake_case`); used when no transform attribute is given.
    #[default]
    SnakeCase,
    /// Every character upper-cased (`str::to_uppercase`).
    UpperCase,
    /// Every character lower-cased (`str::to_lowercase`).
    LowerCase,
    /// Kebab case (via `to_kebab_case`).
    KebabCase,
    /// The identifier is used unchanged.
    None,
}
36
37/// Ironically this is what this proc macro should generate
38impl Transform {
39    pub(crate) fn from_str(s: &str) -> Result<Self, UnknownCustomTransformError> {
40        match s {
41            "snake_case" => Ok(Self::SnakeCase),
42            "upper_case" => Ok(Self::UpperCase),
43            "lower_case" => Ok(Self::LowerCase),
44            "kebab_case" | "kebab-case" => Ok(Self::KebabCase),
45            "none" => Ok(Self::None),
46            s => Err(UnknownCustomTransformError { transform: s }),
47        }
48    }
49
50    pub(crate) fn apply_transform(&self, s: &str) -> String {
51        match self {
52            Transform::SnakeCase => s.to_snake_case(),
53            Transform::KebabCase => s.to_kebab_case(),
54            Transform::UpperCase => s.to_uppercase(),
55            Transform::LowerCase => s.to_lowercase(),
56            Transform::None => s.to_owned(),
57        }
58    }
59}
60
61struct UnknownCustomTransformError<'a> {
62    transform: &'a str,
63}
64
65#[allow(clippy::from_over_into)]
66impl<'a> Into<TokenStream> for UnknownCustomTransformError<'a> {
67    fn into(self) -> TokenStream {
68        let message = format!("Unknown transform '{}'", self.transform);
69        quote!(compile_error!(#message)).into()
70    }
71}
72
73#[proc_macro_derive(
74    EnumVariantsStrings,
75    attributes(enum_variants_strings_mappings, enum_variants_strings_transform)
76)]
77pub fn enum_variants_strings(input: TokenStream) -> TokenStream {
78    let input = parse_macro_input!(input as DeriveInput);
79
80    if let Data::Enum(r#enum) = &input.data {
81        let ident = &input.ident;
82
83        let mapping: Option<Result<String, ()>> = input.attrs.iter().find_map(|attr| {
84            attr.path()
85                .is_ident(CUSTOM_VARIANT_STRING_TRANSFORM)
86                .then(|| {
87                    if let Meta::List(ref meta_list) = attr.meta {
88                        let inner = meta_list.parse_args::<MetaNameValue>();
89
90                        if let Ok(MetaNameValue {
91                            path,
92                            value:
93                                Expr::Lit(ExprLit {
94                                    lit: Lit::Str(lit_str),
95                                    ..
96                                }),
97                            ..
98                        }) = inner
99                        {
100                            if path.is_ident("transform") {
101                                return Ok(lit_str.value());
102                            }
103                        }
104                    }
105                    Err(())
106                })
107        });
108
109        let mapping = match mapping.transpose() {
110            Ok(mapping) => mapping,
111            Err(_) => {
112                return quote!(
113                    compile_error!("Invalid usage of \"enum_variants_strings_transform\", check docs for usage");
114                ).into();
115            }
116        };
117
118        let transform = match mapping.as_deref().map(Transform::from_str).transpose() {
119            Ok(transform) => transform.unwrap_or_default(),
120            Err(err) => {
121                return err.into();
122            }
123        };
124
125        let (mut to_string_arms, mut from_string_arms) = (
126            Vec::<Arm>::with_capacity(r#enum.variants.len()),
127            Vec::<Arm>::with_capacity(r#enum.variants.len()),
128        );
129        let mut possible_matches = Vec::<String>::new();
130
131        for variant in r#enum.variants.iter() {
132            // A list of LitStr which match the variant
133            let variant_names = if let Some(attr) = variant
134                .attrs
135                .iter()
136                .find(|attr| attr.path().is_ident(CUSTOM_VARIANT_STRING_MAPPING))
137            {
138                let parse_args_result = attr.parse_args_with(|stream: ParseStream| {
139                    stream
140                        .parse_terminated(|stream: ParseStream| stream.parse::<LitStr>(), Token![,])
141                });
142                let args: Punctuated<LitStr, Token![,]> = match parse_args_result {
143                    Ok(args) => args,
144                    Err(_) => {
145                        return quote!(compile_error!(
146                            "Failed to parse string arguments in custom mapping"
147                        ))
148                        .into();
149                    }
150                };
151
152                Either2::One(args.into_iter())
153            } else {
154                Either2::Two(iter::once(LitStr::new(
155                    &transform.apply_transform(&variant.ident.to_string()),
156                    Span::call_site(),
157                )))
158            };
159
160            let variant_name = &variant.ident;
161
162            // Build default constructor for field.
163            let variant_default_body = match &variant.fields {
164                syn::Fields::Named(named) => {
165                    let fields = named.named.iter().map(|field| {
166                        let field_ident = &field.ident;
167                        quote!(#field_ident: Default::default())
168                    });
169                    quote!( Self::#variant_name { #(#fields),* } )
170                }
171                syn::Fields::Unnamed(unnamed) => {
172                    let fields = unnamed
173                        .unnamed
174                        .iter()
175                        .map(|_field| quote!(Default::default()));
176                    quote!( Self::#variant_name ( #(#fields),* ) )
177                }
178                syn::Fields::Unit => quote!( Self::#variant_name ),
179            };
180
181            possible_matches.extend(variant_names.clone().map(|lit_str| lit_str.value()));
182
183            // If multiple output names use last
184            let last_str = variant_names.clone().last().unwrap();
185
186            let to_string_pattern = match &variant.fields {
187                syn::Fields::Named(_) => {
188                    quote!( Self::#variant_name {..} )
189                }
190                syn::Fields::Unnamed(_) => {
191                    quote!( Self::#variant_name (..) )
192                }
193                syn::Fields::Unit => quote!( Self::#variant_name ),
194            };
195
196            to_string_arms.push(parse_quote! {
197                #to_string_pattern => #last_str
198            });
199
200            from_string_arms.push(parse_quote! {
201                #(#variant_names)|* => Ok(#variant_default_body)
202            });
203        }
204
205        quote! {
206            impl ::enum_variants_strings::EnumVariantsStrings for #ident {
207                fn from_str(input: &str) -> Result<Self, &[&str]> {
208                    match input {
209                        #(#from_string_arms),*,
210                        _ => Err(&[#(#possible_matches),*])
211                    }
212                }
213
214                fn to_str(&self) -> &'static str {
215                    match self {
216                        #(#to_string_arms),*
217                    }
218                }
219            }
220        }
221        .into()
222    } else {
223        quote!(
224            compile_error!("Can only implement 'EnumVariantsStrings' on a enum");
225        )
226        .into()
227    }
228}