enum_display_macro/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::{self, TokenStream};
3use proc_macro2::Span;
4use quote::quote;
5use regex::Regex;
6use syn::{parse_macro_input, Attribute, DeriveInput, FieldsNamed, FieldsUnnamed, Ident, Variant};
7
8// Enum attributes
9struct EnumAttrs {
10    case_transform: Option<Case>,
11}
12
13impl EnumAttrs {
14    fn from_attrs(attrs: Vec<Attribute>) -> Self {
15        let mut case_transform: Option<Case> = None;
16
17        for attr in attrs.into_iter() {
18            if attr.path.is_ident("enum_display") {
19                let meta = attr.parse_meta().unwrap();
20                if let syn::Meta::List(list) = meta {
21                    for nested in list.nested {
22                        if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested {
23                            if name_value.path.is_ident("case") {
24                                if let syn::Lit::Str(lit_str) = name_value.lit {
25                                    case_transform =
26                                        Some(Self::parse_case_name(lit_str.value().as_str()));
27                                }
28                            }
29                        }
30                    }
31                }
32            }
33        }
34
35        Self { case_transform }
36    }
37
38    fn parse_case_name(case_name: &str) -> Case {
39        match case_name {
40            "Upper" => Case::Upper,
41            "Lower" => Case::Lower,
42            "Title" => Case::Title,
43            "Toggle" => Case::Toggle,
44            "Camel" => Case::Camel,
45            "Pascal" => Case::Pascal,
46            "UpperCamel" => Case::UpperCamel,
47            "Snake" => Case::Snake,
48            "UpperSnake" => Case::UpperSnake,
49            "ScreamingSnake" => Case::ScreamingSnake,
50            "Kebab" => Case::Kebab,
51            "Cobol" => Case::Cobol,
52            "UpperKebab" => Case::UpperKebab,
53            "Train" => Case::Train,
54            "Flat" => Case::Flat,
55            "UpperFlat" => Case::UpperFlat,
56            "Alternating" => Case::Alternating,
57            _ => panic!("Unrecognized case name: {case_name}"),
58        }
59    }
60
61    fn transform_case(&self, ident: String) -> String {
62        if let Some(case) = self.case_transform {
63            ident.to_case(case)
64        } else {
65            ident
66        }
67    }
68}
69
/// Options read from a variant-level `#[display(...)]` attribute.
struct VariantAttrs {
    /// Format string for the variant, with numeric placeholders already rewritten
    /// to `_unnamed_N`; `None` when the variant has no usable `display` attribute.
    format: Option<String>,
}
74
75impl VariantAttrs {
76    fn from_attrs(attrs: Vec<Attribute>) -> Self {
77        let mut format = None;
78
79        // Find the display attribute
80        for attr in attrs.into_iter() {
81            if attr.path.is_ident("display") {
82                let meta = attr.parse_meta().unwrap();
83                if let syn::Meta::List(list) = meta {
84                    if let Some(first_nested) = list.nested.first() {
85                        match first_nested {
86                            // Handle literal string: #[display("format string")]
87                            syn::NestedMeta::Lit(syn::Lit::Str(lit_str)) => {
88                                format =
89                                    Some(Self::translate_numeric_placeholders(&lit_str.value()));
90                            }
91                            // Handle named value: #[display(format = "format string")]
92                            syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) => {
93                                if let syn::Lit::Str(lit_str) = &name_value.lit {
94                                    format = Some(Self::translate_numeric_placeholders(
95                                        &lit_str.value(),
96                                    ));
97                                }
98                            }
99                            _ => {}
100                        }
101                    }
102                }
103            }
104        }
105
106        Self { format }
107    }
108
109    // Translates {123:?} to {_unnamed_123:?} for safer format arg usage
110    fn translate_numeric_placeholders(fmt: &str) -> String {
111        let re = Regex::new(r"\{\s*(\d+)\s*([^}]*)\}").unwrap();
112        re.replace_all(fmt, |caps: &regex::Captures| {
113            let idx = &caps[1];
114            let fmt_spec = &caps[2];
115            format!("{{_unnamed_{idx}{fmt_spec}}}")
116        })
117        .to_string()
118    }
119}
120
121// Shared intermediate variant info
122struct VariantInfo {
123    ident: Ident,
124    ident_transformed: String,
125    attrs: VariantAttrs,
126}
127
128// Intermediate Named variant info
129struct NamedVariantIR {
130    info: VariantInfo,
131    fields: Vec<Ident>,
132}
133
134impl NamedVariantIR {
135    fn from_fields_named(fields_named: FieldsNamed, info: VariantInfo) -> Self {
136        let fields = fields_named
137            .named
138            .into_iter()
139            .filter_map(|field| field.ident)
140            .collect();
141        Self { info, fields }
142    }
143
144    fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
145        let VariantInfo {
146            ident,
147            ident_transformed,
148            attrs,
149        } = self.info;
150        let fields = self.fields;
151        match (any_has_format, attrs.format) {
152            (true, Some(fmt)) => {
153                quote! { #ident { #(#fields),* } => {
154                    let variant = #ident_transformed;
155                    ::core::write!(f, #fmt)
156                } }
157            }
158            (true, None) => {
159                quote! { #ident { .. } => ::core::fmt::Formatter::write_str(f, #ident_transformed), }
160            }
161            (false, None) => quote! { #ident { .. } => #ident_transformed, },
162            _ => unreachable!(
163                "`any_has_format` should never be false when a variant has format string"
164            ),
165        }
166    }
167}
168
169// Intermediate Unnamed variant info
170struct UnnamedVariantIR {
171    info: VariantInfo,
172    fields: Vec<Ident>,
173}
174
175impl UnnamedVariantIR {
176    fn from_fields_unnamed(fields_unnamed: FieldsUnnamed, info: VariantInfo) -> Self {
177        let fields: Vec<Ident> = fields_unnamed
178            .unnamed
179            .into_iter()
180            .enumerate()
181            .map(|(i, _)| Ident::new(format!("_unnamed_{i}").as_str(), Span::call_site()))
182            .collect();
183        Self { info, fields }
184    }
185
186    fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
187        let VariantInfo {
188            ident,
189            ident_transformed,
190            attrs,
191        } = self.info;
192        let fields = self.fields;
193        match (any_has_format, attrs.format) {
194            (true, Some(fmt)) => {
195                quote! { #ident(#(#fields),*) => {
196                    let variant = #ident_transformed;
197                    ::core::write!(f, #fmt)
198                } }
199            }
200            (true, None) => {
201                quote! { #ident(..) => ::core::fmt::Formatter::write_str(f, #ident_transformed), }
202            }
203            (false, None) => quote! { #ident(..) => #ident_transformed, },
204            _ => unreachable!(
205                "`any_has_format` should never be false when a variant has format string"
206            ),
207        }
208    }
209}
210
211// Intermediate Unit variant info
212struct UnitVariantIR {
213    info: VariantInfo,
214}
215
216impl UnitVariantIR {
217    fn new(info: VariantInfo) -> Self {
218        Self { info }
219    }
220
221    fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
222        let VariantInfo {
223            ident,
224            ident_transformed,
225            attrs,
226        } = self.info;
227        match (any_has_format, attrs.format) {
228            (true, Some(fmt)) => {
229                quote! { #ident => {
230                    let variant = #ident_transformed;
231                    ::core::write!(f, #fmt)
232                } }
233            }
234            (true, None) => {
235                quote! { #ident => ::core::fmt::Formatter::write_str(f, #ident_transformed), }
236            }
237            (false, None) => quote! { #ident => #ident_transformed, },
238            _ => unreachable!(
239                "`any_has_format` should never be false when a variant has format string"
240            ),
241        }
242    }
243}
244
245// Intermediate version of Variant
246enum VariantIR {
247    Named(NamedVariantIR),
248    Unnamed(UnnamedVariantIR),
249    Unit(UnitVariantIR),
250}
251
252impl VariantIR {
253    fn from_variant(variant: Variant, enum_attrs: &EnumAttrs) -> Self {
254        let ident_str = variant.ident.to_string();
255        let info = VariantInfo {
256            ident: variant.ident,
257            ident_transformed: enum_attrs.transform_case(ident_str),
258            attrs: VariantAttrs::from_attrs(variant.attrs),
259        };
260        match variant.fields {
261            syn::Fields::Named(fields_named) => {
262                Self::Named(NamedVariantIR::from_fields_named(fields_named, info))
263            }
264            syn::Fields::Unnamed(fields_unnamed) => {
265                Self::Unnamed(UnnamedVariantIR::from_fields_unnamed(fields_unnamed, info))
266            }
267            syn::Fields::Unit => Self::Unit(UnitVariantIR::new(info)),
268        }
269    }
270
271    fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
272        match self {
273            VariantIR::Named(named_variant) => named_variant.generate(any_has_format),
274            VariantIR::Unnamed(unnamed_variant) => unnamed_variant.generate(any_has_format),
275            VariantIR::Unit(unit_variant) => unit_variant.generate(any_has_format),
276        }
277    }
278
279    fn has_format(&self) -> bool {
280        match self {
281            VariantIR::Named(named_variant) => &named_variant.info,
282            VariantIR::Unnamed(unnamed_variant) => &unnamed_variant.info,
283            VariantIR::Unit(unit_variant) => &unit_variant.info,
284        }
285        .attrs
286        .format
287        .is_some()
288    }
289}
290
291#[proc_macro_derive(EnumDisplay, attributes(enum_display, display))]
292pub fn derive(input: TokenStream) -> TokenStream {
293    // Parse the input tokens into a syntax tree
294    let DeriveInput {
295        ident,
296        data,
297        attrs,
298        generics,
299        ..
300    } = parse_macro_input!(input);
301
302    // Copy generics and bounds
303    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
304
305    // Read enum attrs
306    let enum_attrs = EnumAttrs::from_attrs(attrs);
307
308    // Read variants and variant attrs into an intermediate format
309    let intermediate_variants: Vec<VariantIR> = match data {
310        syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,
311        _ => panic!("EnumDisplay can only be derived for enums"),
312    }
313    .into_iter()
314    .map(|variant| VariantIR::from_variant(variant, &enum_attrs))
315    .collect();
316
317    // If any variants have a format string, we need to handle formatting differently
318    let any_has_format = intermediate_variants.iter().any(|v| v.has_format());
319
320    // Build the match arms
321    let variants = intermediate_variants
322        .into_iter()
323        .map(|v| v.generate(any_has_format));
324
325    let output = if any_has_format {
326        // When format strings are present, we write directly to the formatter
327        quote! {
328            #[automatically_derived]
329            #[allow(unused_qualifications)]
330            impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause {
331                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
332                    match self {
333                        #(Self::#variants)*
334                    }
335                }
336            }
337        }
338    } else {
339        // When no format strings, we can return &str directly
340        quote! {
341            #[automatically_derived]
342            #[allow(unused_qualifications)]
343            impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause {
344                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
345                    ::core::fmt::Formatter::write_str(
346                        f,
347                        match self {
348                            #(Self::#variants)*
349                        }
350                    )
351                }
352            }
353        }
354    };
355
356    output.into()
357}