enum_str_conv/
lib.rs

1use syn::{DeriveInput, spanned::Spanned as _};
2
3#[proc_macro_derive(EnumStrConv, attributes(enum_str_conv))]
4pub fn enum_str_conv(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
5    let input = syn::parse_macro_input!(input as syn::DeriveInput);
6    let output = my_derive(input).unwrap_or_else(syn::Error::into_compile_error);
7    proc_macro::TokenStream::from(output)
8}
9
10fn my_derive(input: syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
11    Ok(format(parse(input)?))
12}
13
14struct Parsed {
15    enum_ident: syn::Ident,
16    error_ident: syn::Expr,
17    unknown_fn: syn::Expr,
18    variant_attrs: Vec<(syn::Ident, syn::LitStr)>,
19}
20
21fn parse(input: syn::DeriveInput) -> Result<Parsed, syn::Error> {
22    let data_enum = if let syn::Data::Enum(data_enum) = &input.data {
23        Ok(data_enum)
24    } else {
25        Err(syn::Error::new_spanned(
26            &input,
27            "EnumStrConv can only be derived for enums",
28        ))
29    }?;
30    let enum_ident = input.ident.clone();
31    let EnumAttr {
32        error: error_ident,
33        unknown: unknown_fn,
34    } = parse_enum_attr(&input)?;
35    let variant_attrs = data_enum
36        .variants
37        .iter()
38        .map(|variant| {
39            let variant_ident = variant.ident.clone();
40            let VariantAttr { str: variant_str } = parse_variant_attr(&variant)?;
41            Ok((variant_ident, variant_str))
42        })
43        .collect::<Result<Vec<(syn::Ident, syn::LitStr)>, syn::Error>>()?;
44    Ok(Parsed {
45        enum_ident,
46        error_ident,
47        unknown_fn,
48        variant_attrs,
49    })
50}
51
52fn format(
53    Parsed {
54        enum_ident,
55        error_ident,
56        unknown_fn: unknown,
57        variant_attrs,
58    }: Parsed,
59) -> proc_macro2::TokenStream {
60    let display_variants = variant_attrs.iter().map(|(ident, str)| {
61        quote::quote! {
62            Self::#ident => write!(f, #str)
63        }
64    });
65    let from_str_variants = variant_attrs.iter().map(|(ident, str)| {
66        quote::quote! {
67            #str => Ok(Self::#ident)
68        }
69    });
70    let output = quote::quote! {
71        #[automatically_derived]
72        impl ::std::fmt::Display for #enum_ident {
73            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74                match self {
75                    #(#display_variants,)*
76                }
77            }
78        }
79
80        #[automatically_derived]
81        impl ::std::str::FromStr for #enum_ident {
82            type Err = #error_ident;
83
84            fn from_str(s: &str) -> Result<Self, Self::Err> {
85                match s {
86                    #(#from_str_variants,)*
87                    _ => Err(#unknown(s.to_owned())),
88                }
89            }
90        }
91    };
92
93    output
94}
95
96struct EnumAttr {
97    error: syn::Expr,
98    unknown: syn::Expr,
99}
100
101fn parse_enum_attr(input: &DeriveInput) -> Result<EnumAttr, syn::Error> {
102    let attr = input
103        .attrs
104        .iter()
105        .find(|attr| attr.path().is_ident("enum_str_conv"))
106        .ok_or_else(|| {
107            syn::Error::new_spanned(
108                &input,
109                "expected attribute: #[enum_str_conv(error = ..., unknown = ...)]",
110            )
111        })?;
112    let nested = attr
113        .parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
114        .map_err(|_| {
115            syn::Error::new_spanned(
116                &attr.meta,
117                "expected attribute arguments: #[enum_str_conv(error = ..., unknown = ...)]",
118            )
119        })?;
120    let mut error = None;
121    let mut unknown = None;
122    for meta in nested {
123        let meta_name_value = meta.require_name_value()?;
124        if meta_name_value.path.is_ident("error") {
125            error = Some(meta_name_value.value.clone());
126        } else if meta_name_value.path.is_ident("unknown") {
127            unknown = Some(meta_name_value.value.clone());
128        } else {
129            Err(syn::Error::new_spanned(
130                &meta_name_value,
131                "unknown argument: #[enum_str_conv(error = ..., unknown = ...)]",
132            ))?;
133        }
134    }
135    match (error, unknown) {
136        (None, None) | (None, Some(_)) => Err(syn::Error::new_spanned(
137            &attr.meta,
138            "expected `error` argument: #[enum_str_conv(error = ...)]",
139        )),
140        (Some(_), None) => Err(syn::Error::new_spanned(
141            &attr.meta,
142            "expected `unknown` argument: #[enum_str_conv(unknown = ...)]",
143        )),
144        (Some(error), Some(unknown)) => Ok(EnumAttr { error, unknown }),
145    }
146}
147
148struct VariantAttr {
149    str: syn::LitStr,
150}
151
152fn parse_variant_attr(variant: &syn::Variant) -> Result<VariantAttr, syn::Error> {
153    let attr = variant
154        .attrs
155        .iter()
156        .find(|attr| attr.path().is_ident("enum_str_conv"))
157        .ok_or_else(|| {
158            syn::Error::new(
159                variant.span(),
160                "expected attribute: #[enum_str_conv(str = ...)]",
161            )
162        })?;
163    let nested = attr
164        .parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
165        .map_err(|_| {
166            syn::Error::new(
167                attr.meta.span(),
168                "expected attribute arguments: #[enum_str_conv(str = ...)]",
169            )
170        })?;
171    let mut str = None;
172    for meta in nested {
173        let meta_name_value = meta.require_name_value().unwrap();
174        if meta_name_value.path.is_ident("str") {
175            match &meta_name_value.value {
176                syn::Expr::Lit(syn::ExprLit {
177                    lit: syn::Lit::Str(lit_str),
178                    ..
179                }) => {
180                    str = Some(lit_str.to_owned());
181                }
182                _ => {
183                    Err(syn::Error::new_spanned(
184                        &meta_name_value.value,
185                        r#"unknown argument type: #[enum_str_conv(str = "...")]"#,
186                    ))?;
187                }
188            }
189        } else {
190            Err(syn::Error::new_spanned(
191                &meta_name_value,
192                "unknown argument: #[enum_str_conv(str = ...)]",
193            ))?;
194        }
195    }
196    match str {
197        None => Err(syn::Error::new_spanned(
198            &attr.meta,
199            "expected `str` argument: #[enum_str_conv(str = ...)]",
200        )),
201        Some(str) => Ok(VariantAttr { str }),
202    }
203}