1use proc_macro::TokenStream;
12use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
13use quote::{quote, quote_spanned, ToTokens};
14use syn::{
15 parse_macro_input, spanned::Spanned, Data, DeriveInput, Error, Fields, FieldsNamed,
16 FieldsUnnamed, LitStr,
17};
18
/// Expands to a compile error carrying the given message, spanned at the macro call site and
/// converted (via `Into`) into whatever token-stream type the surrounding expression expects.
macro_rules! derive_error {
    ($string:tt) => {{
        let error = Error::new(Span::call_site(), $string);
        error.to_compile_error().into()
    }};
}
26
27fn has_only_unit_variants(data: &syn::DataEnum) -> bool {
28 data.variants
29 .iter()
30 .all(|variant| matches!(variant.fields, Fields::Unit))
31}
32
33fn find_duplicate_strings(data: &syn::DataEnum) -> Vec<(String, Vec<String>)> {
34 let mut string_to_variants = std::collections::HashMap::new();
35
36 for variant in data.variants.iter() {
37 if let Fields::Unit = variant.fields {
38 let mut string = variant.ident.to_string();
39 let variant_name = variant.ident.to_string();
40
41 for attr in &variant.attrs {
43 if attr.path.is_ident("enum2str") {
44 if let Ok(literal) = attr.parse_args::<syn::LitStr>() {
45 string = literal.value();
46 }
47 }
48 }
49
50 string_to_variants
51 .entry(string)
52 .or_insert_with(Vec::new)
53 .push(variant_name);
54 }
55 }
56
57 string_to_variants
58 .into_iter()
59 .filter(|(_, variants)| variants.len() > 1)
60 .collect()
61}
62
63#[proc_macro_derive(EnumStr, attributes(enum2str))]
64pub fn derive_enum2str(input: TokenStream) -> TokenStream {
65 let input: DeriveInput = parse_macro_input!(input as DeriveInput);
66 let name = &input.ident;
67
68 let data = match input.data {
69 Data::Enum(data) => data,
70 _ => return derive_error!("enum2str only supports enums"),
71 };
72
73 let mut match_arms = TokenStream2::new();
74 let mut variant_names = TokenStream2::new();
75 let mut template_arms = TokenStream2::new();
76 let mut arg_arms = TokenStream2::new();
77 let mut from_str_arms = TokenStream2::new();
78
79 for variant in data.variants.iter() {
80 let variant_name = &variant.ident;
81
82 match &variant.fields {
83 Fields::Unit => {
84 let mut display_ident = variant_name.to_string().to_token_stream();
85 let mut from_str_pattern = variant_name.to_string();
86
87 for attr in &variant.attrs {
88 if attr.path.is_ident("enum2str") && attr.path.segments.first().is_some() {
89 match attr.parse_args::<syn::LitStr>() {
90 Ok(literal) => {
91 display_ident = literal.to_token_stream();
92 from_str_pattern = literal.value();
93 }
94 Err(_) => {
95 return derive_error!(
96 r#"The 'enum2str' attribute is missing a String argument. Example: #[enum2str("Listening on: {} {}")] "#
97 );
98 }
99 }
100 }
101 }
102
103 match_arms.extend(quote_spanned! {
104 variant.span() =>
105 #name::#variant_name => write!(f, "{}", #display_ident),
106 });
107
108 template_arms.extend(quote_spanned! {
109 variant.span() =>
110 #name::#variant_name => #display_ident.to_string(),
111 });
112
113 variant_names.extend(quote_spanned! {
114 variant.span() =>
115 stringify!(#variant_name).to_string(),
116 });
117
118 arg_arms.extend(quote_spanned! {
119 variant.span() =>
120 #name::#variant_name => vec![],
121 });
122
123 from_str_arms.extend(quote_spanned! {
124 variant.span() =>
125 s if s == #from_str_pattern => Ok(#name::#variant_name),
126 });
127 }
128 Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) => {
129 let mut format_ident = "{}".to_string().to_token_stream();
130
131 for attr in &variant.attrs {
132 if attr.path.is_ident("enum2str") && attr.path.segments.first().is_some() {
133 match attr.parse_args::<LitStr>() {
134 Ok(literal) => format_ident = literal.to_token_stream(),
135 Err(_) => {
136 return derive_error!(
137 r#"The 'enum2str' attribute is missing a String argument. Example: #[enum2str("Listening on: {} {}")] "#
138 );
139 }
140 }
141 }
142 }
143
144 if format_ident.to_string().contains("{}") {
145 let fields = unnamed.iter().len();
146 let args = ('a'..='z')
147 .take(fields)
148 .map(|letter| Ident::new(&letter.to_string(), variant.span()))
149 .collect::<Vec<_>>();
150 match_arms.extend(quote_spanned! {
151 variant.span() =>
152 #name::#variant_name(#(#args),*) => write!(f, #format_ident, #(#args),*),
153 });
154
155 template_arms.extend(quote_spanned! {
156 variant.span() =>
157 #name::#variant_name(..) => #format_ident.to_string(),
158 });
159
160 variant_names.extend(quote_spanned! {
161 variant.span() =>
162 stringify!(#variant_name).to_string(),
163 });
164
165 arg_arms.extend(quote_spanned! {
166 variant.span() =>
167 #name::#variant_name(#(#args),*) => vec![#(#args.to_string()),*],
168 });
169 } else {
170 match_arms.extend(quote_spanned! {
171 variant.span() =>
172 #name::#variant_name(..) => write!(f, #format_ident),
173 });
174
175 variant_names.extend(quote_spanned! {
176 variant.span() =>
177 stringify!(#variant_name).to_string(),
178 });
179
180 template_arms.extend(quote_spanned! {
181 variant.span() =>
182 #name::#variant_name(..) => #format_ident.to_string(),
183 });
184
185 arg_arms.extend(quote_spanned! {
186 variant.span() =>
187 #name::#variant_name(..) => vec![],
188 });
189 }
190 }
191 Fields::Named(FieldsNamed { named, .. }) => {
192 let mut format_ident = variant_name.to_string().to_token_stream();
193 let mut field_idents = Vec::new();
194
195 let mut has_attribute = false;
196 for attr in &variant.attrs {
197 if attr.path.is_ident("enum2str") {
198 has_attribute = true;
199 match attr.parse_args::<LitStr>() {
200 Ok(literal) => {
201 format_ident = literal.clone().to_token_stream();
202 let literal_str = literal.value().clone();
203 let mut start_indices =
204 literal_str.match_indices('{').map(|(i, _)| i);
205 let mut end_indices =
206 literal_str.match_indices('}').map(|(i, _)| i);
207
208 while let (Some(start), Some(end)) =
209 (start_indices.next(), end_indices.next())
210 {
211 let field_name = &literal_str[(start + 1)..end];
212 field_idents.push(Ident::new(field_name, Span::call_site()));
213 }
214 }
215 Err(_) => {
216 return derive_error!(
217 r#"The 'enum2str' attribute is missing a String argument. Example: #[enum2str("Listening on: {} {}")] "#
218 );
219 }
220 }
221 }
222 }
223
224 let field_names: Vec<_> = named.iter().map(|f| f.ident.as_ref().unwrap()).collect();
225
226 if !field_idents.is_empty() {
227 let arg_pattern = field_idents
229 .iter()
230 .map(|ident| quote!(#ident = #ident))
231 .collect::<Vec<_>>();
232
233 match_arms.extend(quote_spanned! {
234 variant.span() =>
235 #name::#variant_name { #(#field_names),* } => write!(f, #format_ident, #(#arg_pattern),*),
236 });
237
238 arg_arms.extend(quote_spanned! {
239 variant.span() =>
240 #name::#variant_name { #(#field_names),* } => vec![#(#field_names.to_string()),*],
241 });
242 } else {
243 match_arms.extend(quote_spanned! {
245 variant.span() =>
246 #name::#variant_name { .. } => write!(f, "{}", if #has_attribute { #format_ident.to_string() } else { stringify!(#variant_name).to_string() }),
247 });
248
249 arg_arms.extend(quote_spanned! {
250 variant.span() =>
251 #name::#variant_name { .. } => vec![],
252 });
253 }
254
255 template_arms.extend(quote_spanned! {
256 variant.span() =>
257 #name::#variant_name { .. } => #format_ident.to_string(),
258 });
259
260 variant_names.extend(quote_spanned! {
261 variant.span() =>
262 stringify!(#variant_name).to_string(),
263 });
264
265 if field_names.is_empty() && has_attribute {
266 let display_str = format_ident.to_string();
267 from_str_arms.extend(quote_spanned! {
268 variant.span() =>
269 s if s == #display_str => Ok(#name::#variant_name { }),
270 });
271 }
272 }
273 };
274 }
275
276 let expanded = quote! {
277 impl core::fmt::Display for #name {
278 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
279 match self {
280 #match_arms
281 }
282 }
283 }
284
285 impl core::str::FromStr for #name {
286 type Err = String;
287
288 fn from_str(s: &str) -> Result<Self, Self::Err> {
289 match s {
290 #from_str_arms
291 _ => Err(format!("Invalid {} variant: {}", stringify!(#name), s))
292 }
293 }
294 }
295
296 impl #name {
297 pub fn variant_names() -> Vec<String> {
299 vec![
300 #variant_names
301 ]
302 }
303
304 pub fn template(&self) -> String {
306 match self {
307 #template_arms
308 }
309 }
310
311 pub fn arguments(&self) -> Vec<String> {
313 match self {
314 #arg_arms
315 }
316 }
317 }
318 };
319
320 let mut expanded = TokenStream::from(expanded);
321
322 if has_only_unit_variants(&data) {
324 let duplicates = find_duplicate_strings(&data);
325 let try_from_impl = if duplicates.is_empty() {
326 quote! {
328 impl core::convert::TryFrom<String> for #name {
329 type Error = String;
330
331 fn try_from(value: String) -> Result<Self, Self::Error> {
332 Self::from_str(&value)
333 }
334 }
335 }
336 } else {
337 let error_msg = format!(
339 "Ambiguous string representation. The following strings are used by multiple variants: {}",
340 duplicates
341 .iter()
342 .map(|(s, v)| format!("'{}' (used by {})", s, v.join(", ")))
343 .collect::<Vec<_>>()
344 .join(", ")
345 );
346
347 let duplicate_strings: Vec<_> = duplicates.iter().map(|(s, _)| s).collect();
348
349 quote! {
350 impl core::convert::TryFrom<String> for #name {
351 type Error = String;
352
353 fn try_from(value: String) -> Result<Self, Self::Error> {
354 if [#(#duplicate_strings),*].contains(&value.as_str()) {
356 return Err(#error_msg.to_string());
357 }
358 Self::from_str(&value)
360 }
361 }
362 }
363 };
364 expanded.extend(TokenStream::from(try_from_impl));
365 }
366
367 expanded
368}