enum2str/
lib.rs

//! enum2str is a Rust derive macro that generates `Display` and `FromStr` implementations for
//! enums. This is useful for strongly typing composable sets of strings.
//!
//! ## Usage
//!
//! Add this to your `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! enum2str = "0.1.11"
//! ```
10
11use proc_macro::TokenStream;
12use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
13use quote::{quote, quote_spanned, ToTokens};
14use syn::{
15    parse_macro_input, spanned::Spanned, Data, DeriveInput, Error, Fields, FieldsNamed,
16    FieldsUnnamed, LitStr,
17};
18
/// Expands to a `compile_error!` token stream carrying `$message`, anchored at the call site of
/// the derive. The expansion converts into whatever `TokenStream` type the caller returns.
macro_rules! derive_error {
    ($message: tt) => {{
        let error = Error::new(Span::call_site(), $message);
        error.to_compile_error().into()
    }};
}
26
27fn has_only_unit_variants(data: &syn::DataEnum) -> bool {
28    data.variants
29        .iter()
30        .all(|variant| matches!(variant.fields, Fields::Unit))
31}
32
33fn find_duplicate_strings(data: &syn::DataEnum) -> Vec<(String, Vec<String>)> {
34    let mut string_to_variants = std::collections::HashMap::new();
35
36    for variant in data.variants.iter() {
37        if let Fields::Unit = variant.fields {
38            let mut string = variant.ident.to_string();
39            let variant_name = variant.ident.to_string();
40
41            // Check for enum2str attribute
42            for attr in &variant.attrs {
43                if attr.path.is_ident("enum2str") {
44                    if let Ok(literal) = attr.parse_args::<syn::LitStr>() {
45                        string = literal.value();
46                    }
47                }
48            }
49
50            string_to_variants
51                .entry(string)
52                .or_insert_with(Vec::new)
53                .push(variant_name);
54        }
55    }
56
57    string_to_variants
58        .into_iter()
59        .filter(|(_, variants)| variants.len() > 1)
60        .collect()
61}
62
63#[proc_macro_derive(EnumStr, attributes(enum2str))]
64pub fn derive_enum2str(input: TokenStream) -> TokenStream {
65    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
66    let name = &input.ident;
67
68    let data = match input.data {
69        Data::Enum(data) => data,
70        _ => return derive_error!("enum2str only supports enums"),
71    };
72
73    let mut match_arms = TokenStream2::new();
74    let mut variant_names = TokenStream2::new();
75    let mut template_arms = TokenStream2::new();
76    let mut arg_arms = TokenStream2::new();
77    let mut from_str_arms = TokenStream2::new();
78
79    for variant in data.variants.iter() {
80        let variant_name = &variant.ident;
81
82        match &variant.fields {
83            Fields::Unit => {
84                let mut display_ident = variant_name.to_string().to_token_stream();
85                let mut from_str_pattern = variant_name.to_string();
86
87                for attr in &variant.attrs {
88                    if attr.path.is_ident("enum2str") && attr.path.segments.first().is_some() {
89                        match attr.parse_args::<syn::LitStr>() {
90                            Ok(literal) => {
91                                display_ident = literal.to_token_stream();
92                                from_str_pattern = literal.value();
93                            }
94                            Err(_) => {
95                                return derive_error!(
96                                    r#"The 'enum2str' attribute is missing a String argument. Example: #[enum2str("Listening on: {} {}")] "#
97                                );
98                            }
99                        }
100                    }
101                }
102
103                match_arms.extend(quote_spanned! {
104                    variant.span() =>
105                        #name::#variant_name =>  write!(f, "{}", #display_ident),
106                });
107
108                template_arms.extend(quote_spanned! {
109                    variant.span() =>
110                        #name::#variant_name => #display_ident.to_string(),
111                });
112
113                variant_names.extend(quote_spanned! {
114                    variant.span() =>
115                        stringify!(#variant_name).to_string(),
116                });
117
118                arg_arms.extend(quote_spanned! {
119                    variant.span() =>
120                        #name::#variant_name => vec![],
121                });
122
123                from_str_arms.extend(quote_spanned! {
124                    variant.span() =>
125                        s if s == #from_str_pattern => Ok(#name::#variant_name),
126                });
127            }
128            Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) => {
129                let mut format_ident = "{}".to_string().to_token_stream();
130
131                for attr in &variant.attrs {
132                    if attr.path.is_ident("enum2str") && attr.path.segments.first().is_some() {
133                        match attr.parse_args::<LitStr>() {
134                            Ok(literal) => format_ident = literal.to_token_stream(),
135                            Err(_) => {
136                                return derive_error!(
137                                    r#"The 'enum2str' attribute is missing a String argument. Example: #[enum2str("Listening on: {} {}")] "#
138                                );
139                            }
140                        }
141                    }
142                }
143
144                if format_ident.to_string().contains("{}") {
145                    let fields = unnamed.iter().len();
146                    let args = ('a'..='z')
147                        .take(fields)
148                        .map(|letter| Ident::new(&letter.to_string(), variant.span()))
149                        .collect::<Vec<_>>();
150                    match_arms.extend(quote_spanned! {
151                        variant.span() =>
152                            #name::#variant_name(#(#args),*) => write!(f, #format_ident, #(#args),*),
153                    });
154
155                    template_arms.extend(quote_spanned! {
156                        variant.span() =>
157                            #name::#variant_name(..) => #format_ident.to_string(),
158                    });
159
160                    variant_names.extend(quote_spanned! {
161                        variant.span() =>
162                            stringify!(#variant_name).to_string(),
163                    });
164
165                    arg_arms.extend(quote_spanned! {
166                        variant.span() =>
167                            #name::#variant_name(#(#args),*) => vec![#(#args.to_string()),*],
168                    });
169                } else {
170                    match_arms.extend(quote_spanned! {
171                        variant.span() =>
172                            #name::#variant_name(..) => write!(f, #format_ident),
173                    });
174
175                    variant_names.extend(quote_spanned! {
176                        variant.span() =>
177                            stringify!(#variant_name).to_string(),
178                    });
179
180                    template_arms.extend(quote_spanned! {
181                        variant.span() =>
182                            #name::#variant_name(..) => #format_ident.to_string(),
183                    });
184
185                    arg_arms.extend(quote_spanned! {
186                        variant.span() =>
187                            #name::#variant_name(..) => vec![],
188                    });
189                }
190            }
191            Fields::Named(FieldsNamed { named, .. }) => {
192                let mut format_ident = variant_name.to_string().to_token_stream();
193                let mut field_idents = Vec::new();
194
195                let mut has_attribute = false;
196                for attr in &variant.attrs {
197                    if attr.path.is_ident("enum2str") {
198                        has_attribute = true;
199                        match attr.parse_args::<LitStr>() {
200                            Ok(literal) => {
201                                format_ident = literal.clone().to_token_stream();
202                                let literal_str = literal.value().clone();
203                                let mut start_indices =
204                                    literal_str.match_indices('{').map(|(i, _)| i);
205                                let mut end_indices =
206                                    literal_str.match_indices('}').map(|(i, _)| i);
207
208                                while let (Some(start), Some(end)) =
209                                    (start_indices.next(), end_indices.next())
210                                {
211                                    let field_name = &literal_str[(start + 1)..end];
212                                    field_idents.push(Ident::new(field_name, Span::call_site()));
213                                }
214                            }
215                            Err(_) => {
216                                return derive_error!(
217                                    r#"The 'enum2str' attribute is missing a String argument. Example: #[enum2str("Listening on: {} {}")] "#
218                                );
219                            }
220                        }
221                    }
222                }
223
224                let field_names: Vec<_> = named.iter().map(|f| f.ident.as_ref().unwrap()).collect();
225
226                if !field_idents.is_empty() {
227                    // Use named arguments in format string
228                    let arg_pattern = field_idents
229                        .iter()
230                        .map(|ident| quote!(#ident = #ident))
231                        .collect::<Vec<_>>();
232
233                    match_arms.extend(quote_spanned! {
234                        variant.span() =>
235                            #name::#variant_name { #(#field_names),* } => write!(f, #format_ident, #(#arg_pattern),*),
236                    });
237
238                    arg_arms.extend(quote_spanned! {
239                        variant.span() =>
240                            #name::#variant_name { #(#field_names),* } => vec![#(#field_names.to_string()),*],
241                    });
242                } else {
243                    // Just use variant name or custom string
244                    match_arms.extend(quote_spanned! {
245                        variant.span() =>
246                            #name::#variant_name { .. } => write!(f, "{}", if #has_attribute { #format_ident.to_string() } else { stringify!(#variant_name).to_string() }),
247                    });
248
249                    arg_arms.extend(quote_spanned! {
250                        variant.span() =>
251                            #name::#variant_name { .. } => vec![],
252                    });
253                }
254
255                template_arms.extend(quote_spanned! {
256                    variant.span() =>
257                        #name::#variant_name { .. } => #format_ident.to_string(),
258                });
259
260                variant_names.extend(quote_spanned! {
261                    variant.span() =>
262                        stringify!(#variant_name).to_string(),
263                });
264
265                if field_names.is_empty() && has_attribute {
266                    let display_str = format_ident.to_string();
267                    from_str_arms.extend(quote_spanned! {
268                        variant.span() =>
269                            s if s == #display_str => Ok(#name::#variant_name { }),
270                    });
271                }
272            }
273        };
274    }
275
276    let expanded = quote! {
277        impl core::fmt::Display for #name {
278            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
279                match self {
280                    #match_arms
281                }
282            }
283        }
284
285        impl core::str::FromStr for #name {
286            type Err = String;
287
288            fn from_str(s: &str) -> Result<Self, Self::Err> {
289                match s {
290                    #from_str_arms
291                    _ => Err(format!("Invalid {} variant: {}", stringify!(#name), s))
292                }
293            }
294        }
295
296        impl #name {
297            /// Get the names of this enum's variants
298            pub fn variant_names() -> Vec<String> {
299                vec![
300                    #variant_names
301                ]
302            }
303
304            /// Get the format specifier used to display a variant
305            pub fn template(&self) -> String {
306                match self {
307                    #template_arms
308                }
309            }
310
311            /// Gets the names of a variant's arguments
312            pub fn arguments(&self) -> Vec<String> {
313                match self {
314                    #arg_arms
315                }
316            }
317        }
318    };
319
320    let mut expanded = TokenStream::from(expanded);
321
322    // Add TryFrom<String> implementation for enums with only unit variants
323    if has_only_unit_variants(&data) {
324        let duplicates = find_duplicate_strings(&data);
325        let try_from_impl = if duplicates.is_empty() {
326            // Simple implementation when no duplicates
327            quote! {
328                impl core::convert::TryFrom<String> for #name {
329                    type Error = String;
330
331                    fn try_from(value: String) -> Result<Self, Self::Error> {
332                        Self::from_str(&value)
333                    }
334                }
335            }
336        } else {
337            // Implementation that handles duplicates
338            let error_msg = format!(
339                "Ambiguous string representation. The following strings are used by multiple variants: {}",
340                duplicates
341                    .iter()
342                    .map(|(s, v)| format!("'{}' (used by {})", s, v.join(", ")))
343                    .collect::<Vec<_>>()
344                    .join(", ")
345            );
346
347            let duplicate_strings: Vec<_> = duplicates.iter().map(|(s, _)| s).collect();
348
349            quote! {
350                impl core::convert::TryFrom<String> for #name {
351                    type Error = String;
352
353                    fn try_from(value: String) -> Result<Self, Self::Error> {
354                        // First check if this is an ambiguous string
355                        if [#(#duplicate_strings),*].contains(&value.as_str()) {
356                            return Err(#error_msg.to_string());
357                        }
358                        // If not ambiguous, try normal conversion
359                        Self::from_str(&value)
360                    }
361                }
362            }
363        };
364        expanded.extend(TokenStream::from(try_from_impl));
365    }
366
367    expanded
368}