enum_display_macro/
lib.rs1use convert_case::{Case, Casing};
2use proc_macro::{self, TokenStream};
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput};
5
6fn parse_case_name(case_name: &str) -> Case {
7 match case_name {
8 "Upper" => Case::Upper,
9 "Lower" => Case::Lower,
10 "Title" => Case::Title,
11 "Toggle" => Case::Toggle,
12 "Camel" => Case::Camel,
13 "Pascal" => Case::Pascal,
14 "UpperCamel" => Case::UpperCamel,
15 "Snake" => Case::Snake,
16 "UpperSnake" => Case::UpperSnake,
17 "ScreamingSnake" => Case::ScreamingSnake,
18 "Kebab" => Case::Kebab,
19 "Cobol" => Case::Cobol,
20 "UpperKebab" => Case::UpperKebab,
21 "Train" => Case::Train,
22 "Flat" => Case::Flat,
23 "UpperFlat" => Case::UpperFlat,
24 "Alternating" => Case::Alternating,
25 _ => panic!("Unrecognized case name: {}", case_name),
26 }
27}
28
29#[proc_macro_derive(EnumDisplay, attributes(enum_display))]
30pub fn derive(input: TokenStream) -> TokenStream {
31 let DeriveInput {
33 ident, data, attrs, ..
34 } = parse_macro_input!(input);
35
36 let mut case_transform: Option<Case> = None;
38
39 for attr in attrs.into_iter() {
41 if attr.path.is_ident("enum_display") {
42 let meta = attr.parse_meta().unwrap();
43 if let syn::Meta::List(list) = meta {
44 for nested in list.nested {
45 if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested {
46 if name_value.path.is_ident("case") {
47 if let syn::Lit::Str(lit_str) = name_value.lit {
48 case_transform = Some(parse_case_name(lit_str.value().as_str()));
50 }
51 }
52 }
53 }
54 }
55 }
56 }
57
58 let variants = match data {
60 syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,
61 _ => panic!("EnumDisplay can only be derived for enums"),
62 }
63 .into_iter()
64 .map(|variant| {
65 let ident = variant.ident;
66 let ident_str = if case_transform.is_some() {
67 ident.to_string().to_case(case_transform.unwrap())
68 } else {
69 ident.to_string()
70 };
71
72 match variant.fields {
73 syn::Fields::Named(_) => quote! {
74 #ident { .. } => #ident_str,
75 },
76 syn::Fields::Unnamed(_) => quote! {
77 #ident(..) => #ident_str,
78 },
79 syn::Fields::Unit => quote! {
80 #ident => #ident_str,
81 },
82 }
83 });
84
85 let output = quote! {
89 #[automatically_derived]
90 #[allow(unused_qualifications)]
91 impl ::core::fmt::Display for #ident {
92 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
93 ::core::fmt::Formatter::write_str(
94 f,
95 match self {
96 #(#ident::#variants)*
97 },
98 )
99 }
100 }
101 };
102 output.into()
103}