enumorph/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote_spanned, ToTokens};
3use syn::{parse_macro_input, parse_quote, spanned::Spanned, DeriveInput, Ident, Type};
4
5mod test;
6
7#[proc_macro_derive(Enumorph, attributes(enumorph))]
8pub fn enumorph(input: TokenStream) -> TokenStream {
9    // Parse the input tokens into a syntax tree
10    let input = parse_macro_input!(input as DeriveInput);
11
12    let enm = match input.data {
13        syn::Data::Struct(strct) => {
14            return syn::Error::new_spanned(
15                strct.struct_token,
16                "enum conversions only work on enums",
17            )
18            .into_compile_error()
19            .into()
20        }
21        syn::Data::Enum(enm) => enm,
22        syn::Data::Union(union) => {
23            return syn::Error::new_spanned(
24                union.union_token,
25                "enum conversions only work on enums",
26            )
27            .into_compile_error()
28            .into()
29        }
30    };
31
32    let impl_generics: (
33        syn::ImplGenerics<'_>,
34        syn::TypeGenerics<'_>,
35        Option<&syn::WhereClause>,
36    ) = input.generics.split_for_impl();
37
38    let impls = enm
39        .variants
40        .into_iter()
41        .filter(|x| {
42            x.attrs
43                .iter()
44                .all(|x| x.meta != parse_quote!(enumorph(ignore)))
45        })
46        .map(|x| match x.fields {
47            syn::Fields::Named(mut named) => {
48                let fields_span = named.span();
49                if named.named.len() == 1 {
50                    let field = named.named.pop().unwrap().into_value();
51                    Ok(mk_impls(
52                        &input.ident,
53                        &x.ident,
54                        &FieldName::Ident(field.ident.as_ref().unwrap()),
55                        &field.ty,
56                        &impl_generics,
57                    ))
58                } else {
59                    Err(syn::Error::new(
60                        fields_span,
61                        "only variants with one field are supported",
62                    ))
63                }
64            }
65            syn::Fields::Unnamed(mut unnamed) => {
66                let fields_span = unnamed.span();
67                if unnamed.unnamed.len() == 1 {
68                    let field = unnamed.unnamed.pop().unwrap().into_value();
69                    Ok(mk_impls(
70                        &input.ident,
71                        &x.ident,
72                        &FieldName::Index(syn::Index::from(0)),
73                        &field.ty,
74                        &impl_generics,
75                    ))
76                } else {
77                    Err(syn::Error::new(
78                        fields_span,
79                        "only variants with one field are supported",
80                    ))
81                }
82            }
83            syn::Fields::Unit => Err(syn::Error::new(
84                x.ident.span(),
85                "unit variants don't have any data to convert to/from; try `#[enumorph(ignore)]`-ing it",
86            )),
87        })
88        .fold(
89            (proc_macro2::TokenStream::new(), None::<syn::Error>),
90            |mut acc, curr| {
91                match curr {
92                    Ok(ok) => acc.0.extend(ok),
93                    Err(err) => match &mut acc.1 {
94                        Some(errs) => {
95                            errs.combine(err);
96                        }
97                        None => acc.1 = Some(err),
98                    },
99                }
100                acc
101            },
102        );
103
104    match impls.1 {
105        Some(errs) => errs.into_compile_error().into(),
106        None => impls.0.into(),
107    }
108}
109
110fn mk_impls(
111    enum_ident: &Ident,
112    variant_name: &Ident,
113    field_name: &FieldName,
114    field_type: &Type,
115    (impl_generics, ty_generics, where_clause): &(
116        syn::ImplGenerics<'_>,
117        syn::TypeGenerics<'_>,
118        Option<&syn::WhereClause>,
119    ),
120) -> proc_macro2::TokenStream {
121    quote_spanned! {field_type.span()=>
122        #[automatically_derived]
123        impl #impl_generics ::std::convert::TryFrom<#enum_ident #ty_generics> for #field_type #where_clause {
124            type Error = #enum_ident #ty_generics;
125
126            fn try_from(value: #enum_ident #ty_generics) -> ::std::result::Result<Self, Self::Error> {
127                match value {
128                    #enum_ident::#variant_name { #field_name: t, .. } => ::std::result::Result::Ok(t),
129                    #[allow(unreachable_patterns)] // triggers on enums with one variant
130                    _ => ::std::result::Result::Err(value),
131                }
132            }
133        }
134
135        #[automatically_derived]
136        impl #impl_generics ::std::convert::From<#field_type> for #enum_ident #ty_generics #where_clause {
137            fn from(value: #field_type) -> Self {
138                #[allow(clippy::init_numbered_fields)]
139                #enum_ident::#variant_name { #field_name: value }
140            }
141        }
142    }
143}
144
145enum FieldName<'a> {
146    Index(syn::Index),
147    Ident(&'a syn::Ident),
148}
149
150impl<'a> ToTokens for FieldName<'a> {
151    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
152        match self {
153            FieldName::Index(i) => i.to_tokens(tokens),
154            FieldName::Ident(i) => i.to_tokens(tokens),
155        }
156    }
157}