1use convert_case::{Case, Casing};
2use proc_macro::{self, TokenStream};
3use proc_macro2::Span;
4use quote::quote;
5use regex::Regex;
6use syn::{parse_macro_input, Attribute, DeriveInput, FieldsNamed, FieldsUnnamed, Ident, Variant};
7
8struct EnumAttrs {
10 case_transform: Option<Case>,
11}
12
13impl EnumAttrs {
14 fn from_attrs(attrs: Vec<Attribute>) -> Self {
15 let mut case_transform: Option<Case> = None;
16
17 for attr in attrs.into_iter() {
18 if attr.path.is_ident("enum_display") {
19 let meta = attr.parse_meta().unwrap();
20 if let syn::Meta::List(list) = meta {
21 for nested in list.nested {
22 if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested {
23 if name_value.path.is_ident("case") {
24 if let syn::Lit::Str(lit_str) = name_value.lit {
25 case_transform =
26 Some(Self::parse_case_name(lit_str.value().as_str()));
27 }
28 }
29 }
30 }
31 }
32 }
33 }
34
35 Self { case_transform }
36 }
37
38 fn parse_case_name(case_name: &str) -> Case {
39 match case_name {
40 "Upper" => Case::Upper,
41 "Lower" => Case::Lower,
42 "Title" => Case::Title,
43 "Toggle" => Case::Toggle,
44 "Camel" => Case::Camel,
45 "Pascal" => Case::Pascal,
46 "UpperCamel" => Case::UpperCamel,
47 "Snake" => Case::Snake,
48 "UpperSnake" => Case::UpperSnake,
49 "ScreamingSnake" => Case::ScreamingSnake,
50 "Kebab" => Case::Kebab,
51 "Cobol" => Case::Cobol,
52 "UpperKebab" => Case::UpperKebab,
53 "Train" => Case::Train,
54 "Flat" => Case::Flat,
55 "UpperFlat" => Case::UpperFlat,
56 "Alternating" => Case::Alternating,
57 _ => panic!("Unrecognized case name: {case_name}"),
58 }
59 }
60
61 fn transform_case(&self, ident: String) -> String {
62 if let Some(case) = self.case_transform {
63 ident.to_case(case)
64 } else {
65 ident
66 }
67 }
68}
69
70struct VariantAttrs {
72 format: Option<String>,
73}
74
75impl VariantAttrs {
76 fn from_attrs(attrs: Vec<Attribute>) -> Self {
77 let mut format = None;
78
79 for attr in attrs.into_iter() {
81 if attr.path.is_ident("display") {
82 let meta = attr.parse_meta().unwrap();
83 if let syn::Meta::List(list) = meta {
84 if let Some(first_nested) = list.nested.first() {
85 match first_nested {
86 syn::NestedMeta::Lit(syn::Lit::Str(lit_str)) => {
88 format =
89 Some(Self::translate_numeric_placeholders(&lit_str.value()));
90 }
91 syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) => {
93 if let syn::Lit::Str(lit_str) = &name_value.lit {
94 format = Some(Self::translate_numeric_placeholders(
95 &lit_str.value(),
96 ));
97 }
98 }
99 _ => {}
100 }
101 }
102 }
103 }
104 }
105
106 Self { format }
107 }
108
109 fn translate_numeric_placeholders(fmt: &str) -> String {
111 let re = Regex::new(r"\{\s*(\d+)\s*([^}]*)\}").unwrap();
112 re.replace_all(fmt, |caps: ®ex::Captures| {
113 let idx = &caps[1];
114 let fmt_spec = &caps[2];
115 format!("{{_unnamed_{idx}{fmt_spec}}}")
116 })
117 .to_string()
118 }
119}
120
121struct VariantInfo {
123 ident: Ident,
124 ident_transformed: String,
125 attrs: VariantAttrs,
126}
127
128struct NamedVariantIR {
130 info: VariantInfo,
131 fields: Vec<Ident>,
132}
133
134impl NamedVariantIR {
135 fn from_fields_named(fields_named: FieldsNamed, info: VariantInfo) -> Self {
136 let fields = fields_named
137 .named
138 .into_iter()
139 .filter_map(|field| field.ident)
140 .collect();
141 Self { info, fields }
142 }
143
144 fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
145 let VariantInfo {
146 ident,
147 ident_transformed,
148 attrs,
149 } = self.info;
150 let fields = self.fields;
151 match (any_has_format, attrs.format) {
152 (true, Some(fmt)) => {
153 quote! { #ident { #(#fields),* } => {
154 let variant = #ident_transformed;
155 ::core::write!(f, #fmt)
156 } }
157 }
158 (true, None) => {
159 quote! { #ident { .. } => ::core::fmt::Formatter::write_str(f, #ident_transformed), }
160 }
161 (false, None) => quote! { #ident { .. } => #ident_transformed, },
162 _ => unreachable!(
163 "`any_has_format` should never be false when a variant has format string"
164 ),
165 }
166 }
167}
168
169struct UnnamedVariantIR {
171 info: VariantInfo,
172 fields: Vec<Ident>,
173}
174
175impl UnnamedVariantIR {
176 fn from_fields_unnamed(fields_unnamed: FieldsUnnamed, info: VariantInfo) -> Self {
177 let fields: Vec<Ident> = fields_unnamed
178 .unnamed
179 .into_iter()
180 .enumerate()
181 .map(|(i, _)| Ident::new(format!("_unnamed_{i}").as_str(), Span::call_site()))
182 .collect();
183 Self { info, fields }
184 }
185
186 fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
187 let VariantInfo {
188 ident,
189 ident_transformed,
190 attrs,
191 } = self.info;
192 let fields = self.fields;
193 match (any_has_format, attrs.format) {
194 (true, Some(fmt)) => {
195 quote! { #ident(#(#fields),*) => {
196 let variant = #ident_transformed;
197 ::core::write!(f, #fmt)
198 } }
199 }
200 (true, None) => {
201 quote! { #ident(..) => ::core::fmt::Formatter::write_str(f, #ident_transformed), }
202 }
203 (false, None) => quote! { #ident(..) => #ident_transformed, },
204 _ => unreachable!(
205 "`any_has_format` should never be false when a variant has format string"
206 ),
207 }
208 }
209}
210
211struct UnitVariantIR {
213 info: VariantInfo,
214}
215
216impl UnitVariantIR {
217 fn new(info: VariantInfo) -> Self {
218 Self { info }
219 }
220
221 fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
222 let VariantInfo {
223 ident,
224 ident_transformed,
225 attrs,
226 } = self.info;
227 match (any_has_format, attrs.format) {
228 (true, Some(fmt)) => {
229 quote! { #ident => {
230 let variant = #ident_transformed;
231 ::core::write!(f, #fmt)
232 } }
233 }
234 (true, None) => {
235 quote! { #ident => ::core::fmt::Formatter::write_str(f, #ident_transformed), }
236 }
237 (false, None) => quote! { #ident => #ident_transformed, },
238 _ => unreachable!(
239 "`any_has_format` should never be false when a variant has format string"
240 ),
241 }
242 }
243}
244
245enum VariantIR {
247 Named(NamedVariantIR),
248 Unnamed(UnnamedVariantIR),
249 Unit(UnitVariantIR),
250}
251
252impl VariantIR {
253 fn from_variant(variant: Variant, enum_attrs: &EnumAttrs) -> Self {
254 let ident_str = variant.ident.to_string();
255 let info = VariantInfo {
256 ident: variant.ident,
257 ident_transformed: enum_attrs.transform_case(ident_str),
258 attrs: VariantAttrs::from_attrs(variant.attrs),
259 };
260 match variant.fields {
261 syn::Fields::Named(fields_named) => {
262 Self::Named(NamedVariantIR::from_fields_named(fields_named, info))
263 }
264 syn::Fields::Unnamed(fields_unnamed) => {
265 Self::Unnamed(UnnamedVariantIR::from_fields_unnamed(fields_unnamed, info))
266 }
267 syn::Fields::Unit => Self::Unit(UnitVariantIR::new(info)),
268 }
269 }
270
271 fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream {
272 match self {
273 VariantIR::Named(named_variant) => named_variant.generate(any_has_format),
274 VariantIR::Unnamed(unnamed_variant) => unnamed_variant.generate(any_has_format),
275 VariantIR::Unit(unit_variant) => unit_variant.generate(any_has_format),
276 }
277 }
278
279 fn has_format(&self) -> bool {
280 match self {
281 VariantIR::Named(named_variant) => &named_variant.info,
282 VariantIR::Unnamed(unnamed_variant) => &unnamed_variant.info,
283 VariantIR::Unit(unit_variant) => &unit_variant.info,
284 }
285 .attrs
286 .format
287 .is_some()
288 }
289}
290
291#[proc_macro_derive(EnumDisplay, attributes(enum_display, display))]
292pub fn derive(input: TokenStream) -> TokenStream {
293 let DeriveInput {
295 ident,
296 data,
297 attrs,
298 generics,
299 ..
300 } = parse_macro_input!(input);
301
302 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
304
305 let enum_attrs = EnumAttrs::from_attrs(attrs);
307
308 let intermediate_variants: Vec<VariantIR> = match data {
310 syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,
311 _ => panic!("EnumDisplay can only be derived for enums"),
312 }
313 .into_iter()
314 .map(|variant| VariantIR::from_variant(variant, &enum_attrs))
315 .collect();
316
317 let any_has_format = intermediate_variants.iter().any(|v| v.has_format());
319
320 let variants = intermediate_variants
322 .into_iter()
323 .map(|v| v.generate(any_has_format));
324
325 let output = if any_has_format {
326 quote! {
328 #[automatically_derived]
329 #[allow(unused_qualifications)]
330 impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause {
331 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
332 match self {
333 #(Self::#variants)*
334 }
335 }
336 }
337 }
338 } else {
339 quote! {
341 #[automatically_derived]
342 #[allow(unused_qualifications)]
343 impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause {
344 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
345 ::core::fmt::Formatter::write_str(
346 f,
347 match self {
348 #(Self::#variants)*
349 }
350 )
351 }
352 }
353 }
354 };
355
356 output.into()
357}