1use syn::{DeriveInput, spanned::Spanned as _};
2
3#[proc_macro_derive(EnumStrConv, attributes(enum_str_conv))]
4pub fn enum_str_conv(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
5 let input = syn::parse_macro_input!(input as syn::DeriveInput);
6 let output = my_derive(input).unwrap_or_else(syn::Error::into_compile_error);
7 proc_macro::TokenStream::from(output)
8}
9
10fn my_derive(input: syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
11 Ok(format(parse(input)?))
12}
13
14struct Parsed {
15 enum_ident: syn::Ident,
16 error_ident: syn::Expr,
17 unknown_fn: syn::Expr,
18 variant_attrs: Vec<(syn::Ident, syn::LitStr)>,
19}
20
21fn parse(input: syn::DeriveInput) -> Result<Parsed, syn::Error> {
22 let data_enum = if let syn::Data::Enum(data_enum) = &input.data {
23 Ok(data_enum)
24 } else {
25 Err(syn::Error::new_spanned(
26 &input,
27 "EnumStrConv can only be derived for enums",
28 ))
29 }?;
30 let enum_ident = input.ident.clone();
31 let EnumAttr {
32 error: error_ident,
33 unknown: unknown_fn,
34 } = parse_enum_attr(&input)?;
35 let variant_attrs = data_enum
36 .variants
37 .iter()
38 .map(|variant| {
39 let variant_ident = variant.ident.clone();
40 let VariantAttr { str: variant_str } = parse_variant_attr(&variant)?;
41 Ok((variant_ident, variant_str))
42 })
43 .collect::<Result<Vec<(syn::Ident, syn::LitStr)>, syn::Error>>()?;
44 Ok(Parsed {
45 enum_ident,
46 error_ident,
47 unknown_fn,
48 variant_attrs,
49 })
50}
51
52fn format(
53 Parsed {
54 enum_ident,
55 error_ident,
56 unknown_fn: unknown,
57 variant_attrs,
58 }: Parsed,
59) -> proc_macro2::TokenStream {
60 let display_variants = variant_attrs.iter().map(|(ident, str)| {
61 quote::quote! {
62 Self::#ident => write!(f, #str)
63 }
64 });
65 let from_str_variants = variant_attrs.iter().map(|(ident, str)| {
66 quote::quote! {
67 #str => Ok(Self::#ident)
68 }
69 });
70 let output = quote::quote! {
71 #[automatically_derived]
72 impl ::std::fmt::Display for #enum_ident {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 match self {
75 #(#display_variants,)*
76 }
77 }
78 }
79
80 #[automatically_derived]
81 impl ::std::str::FromStr for #enum_ident {
82 type Err = #error_ident;
83
84 fn from_str(s: &str) -> Result<Self, Self::Err> {
85 match s {
86 #(#from_str_variants,)*
87 _ => Err(#unknown(s.to_owned())),
88 }
89 }
90 }
91 };
92
93 output
94}
95
96struct EnumAttr {
97 error: syn::Expr,
98 unknown: syn::Expr,
99}
100
101fn parse_enum_attr(input: &DeriveInput) -> Result<EnumAttr, syn::Error> {
102 let attr = input
103 .attrs
104 .iter()
105 .find(|attr| attr.path().is_ident("enum_str_conv"))
106 .ok_or_else(|| {
107 syn::Error::new_spanned(
108 &input,
109 "expected attribute: #[enum_str_conv(error = ..., unknown = ...)]",
110 )
111 })?;
112 let nested = attr
113 .parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
114 .map_err(|_| {
115 syn::Error::new_spanned(
116 &attr.meta,
117 "expected attribute arguments: #[enum_str_conv(error = ..., unknown = ...)]",
118 )
119 })?;
120 let mut error = None;
121 let mut unknown = None;
122 for meta in nested {
123 let meta_name_value = meta.require_name_value()?;
124 if meta_name_value.path.is_ident("error") {
125 error = Some(meta_name_value.value.clone());
126 } else if meta_name_value.path.is_ident("unknown") {
127 unknown = Some(meta_name_value.value.clone());
128 } else {
129 Err(syn::Error::new_spanned(
130 &meta_name_value,
131 "unknown argument: #[enum_str_conv(error = ..., unknown = ...)]",
132 ))?;
133 }
134 }
135 match (error, unknown) {
136 (None, None) | (None, Some(_)) => Err(syn::Error::new_spanned(
137 &attr.meta,
138 "expected `error` argument: #[enum_str_conv(error = ...)]",
139 )),
140 (Some(_), None) => Err(syn::Error::new_spanned(
141 &attr.meta,
142 "expected `unknown` argument: #[enum_str_conv(unknown = ...)]",
143 )),
144 (Some(error), Some(unknown)) => Ok(EnumAttr { error, unknown }),
145 }
146}
147
148struct VariantAttr {
149 str: syn::LitStr,
150}
151
152fn parse_variant_attr(variant: &syn::Variant) -> Result<VariantAttr, syn::Error> {
153 let attr = variant
154 .attrs
155 .iter()
156 .find(|attr| attr.path().is_ident("enum_str_conv"))
157 .ok_or_else(|| {
158 syn::Error::new(
159 variant.span(),
160 "expected attribute: #[enum_str_conv(str = ...)]",
161 )
162 })?;
163 let nested = attr
164 .parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
165 .map_err(|_| {
166 syn::Error::new(
167 attr.meta.span(),
168 "expected attribute arguments: #[enum_str_conv(str = ...)]",
169 )
170 })?;
171 let mut str = None;
172 for meta in nested {
173 let meta_name_value = meta.require_name_value().unwrap();
174 if meta_name_value.path.is_ident("str") {
175 match &meta_name_value.value {
176 syn::Expr::Lit(syn::ExprLit {
177 lit: syn::Lit::Str(lit_str),
178 ..
179 }) => {
180 str = Some(lit_str.to_owned());
181 }
182 _ => {
183 Err(syn::Error::new_spanned(
184 &meta_name_value.value,
185 r#"unknown argument type: #[enum_str_conv(str = "...")]"#,
186 ))?;
187 }
188 }
189 } else {
190 Err(syn::Error::new_spanned(
191 &meta_name_value,
192 "unknown argument: #[enum_str_conv(str = ...)]",
193 ))?;
194 }
195 }
196 match str {
197 None => Err(syn::Error::new_spanned(
198 &attr.meta,
199 "expected `str` argument: #[enum_str_conv(str = ...)]",
200 )),
201 Some(str) => Ok(VariantAttr { str }),
202 }
203}