enumorph_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote_spanned, ToTokens};
3use syn::{parse_macro_input, parse_quote, spanned::Spanned, DeriveInput, Ident, Type};
4
5#[proc_macro_derive(Enumorph, attributes(enumorph))]
6pub fn enumorph(input: TokenStream) -> TokenStream {
7    // Parse the input tokens into a syntax tree
8    let input = parse_macro_input!(input as DeriveInput);
9
10    let enm = match input.data {
11        syn::Data::Struct(strct) => {
12            return syn::Error::new_spanned(
13                strct.struct_token,
14                "enum conversions only work on enums",
15            )
16            .into_compile_error()
17            .into()
18        }
19        syn::Data::Enum(enm) => enm,
20        syn::Data::Union(union) => {
21            return syn::Error::new_spanned(
22                union.union_token,
23                "enum conversions only work on enums",
24            )
25            .into_compile_error()
26            .into()
27        }
28    };
29
30    let impl_generics: (
31        syn::ImplGenerics<'_>,
32        syn::TypeGenerics<'_>,
33        Option<&syn::WhereClause>,
34    ) = input.generics.split_for_impl();
35
36    let impls = enm
37        .variants
38        .into_iter()
39        .filter(|x| {
40            x.attrs
41                .iter()
42                .all(|x| x.meta != parse_quote!(enumorph(ignore)))
43        })
44        .map(|x| match x.fields {
45            syn::Fields::Named(mut named) => {
46                let fields_span = named.span();
47                if named.named.len() == 1 {
48                    let field = named.named.pop().unwrap().into_value();
49                    Ok(mk_impls(
50                        &input.ident,
51                        &x.ident,
52                        &FieldName::Ident(field.ident.as_ref().unwrap()),
53                        &field.ty,
54                        &impl_generics,
55                    ))
56                } else {
57                    Err(syn::Error::new(
58                        fields_span,
59                        "only variants with one field are supported",
60                    ))
61                }
62            }
63            syn::Fields::Unnamed(mut unnamed) => {
64                let fields_span = unnamed.span();
65                if unnamed.unnamed.len() == 1 {
66                    let field = unnamed.unnamed.pop().unwrap().into_value();
67                    Ok(mk_impls(
68                        &input.ident,
69                        &x.ident,
70                        &FieldName::Index(syn::Index::from(0)),
71                        &field.ty,
72                        &impl_generics,
73                    ))
74                } else {
75                    Err(syn::Error::new(
76                        fields_span,
77                        "only variants with one field are supported",
78                    ))
79                }
80            }
81            syn::Fields::Unit => Err(syn::Error::new(
82                x.ident.span(),
83                "unit variants don't have any data to convert to/from; try `#[enumorph(ignore)]`-ing it",
84            )),
85        })
86        .fold(
87            (proc_macro2::TokenStream::new(), None::<syn::Error>),
88            |mut acc, curr| {
89                match curr {
90                    Ok(ok) => acc.0.extend(ok),
91                    Err(err) => match &mut acc.1 {
92                        Some(errs) => {
93                            errs.combine(err);
94                        }
95                        None => acc.1 = Some(err),
96                    },
97                }
98                acc
99            },
100        );
101
102    match impls {
103        (_, Some(errs)) => errs.into_compile_error().into(),
104        (impls, None) => impls.into(),
105    }
106}
107
108fn mk_impls(
109    enum_ident: &Ident,
110    variant_name: &Ident,
111    field_name: &FieldName,
112    field_type: &Type,
113    (impl_generics, ty_generics, where_clause): &(
114        syn::ImplGenerics<'_>,
115        syn::TypeGenerics<'_>,
116        Option<&syn::WhereClause>,
117    ),
118) -> proc_macro2::TokenStream {
119    quote_spanned! {field_type.span()=>
120        #[automatically_derived]
121        impl #impl_generics ::core::convert::TryFrom<#enum_ident #ty_generics> for #field_type #where_clause {
122            type Error = #enum_ident #ty_generics;
123
124            fn try_from(value: #enum_ident #ty_generics) -> ::core::result::Result<Self, Self::Error> {
125                match value {
126                    #enum_ident::#variant_name { #field_name: t, .. } => ::core::result::Result::Ok(t),
127                    #[allow(unreachable_patterns)] // triggers on enums with one variant
128                    _ => ::core::result::Result::Err(value),
129                }
130            }
131        }
132
133        #[automatically_derived]
134        impl #impl_generics ::core::convert::From<#field_type> for #enum_ident #ty_generics #where_clause {
135            fn from(value: #field_type) -> Self {
136                #[allow(clippy::init_numbered_fields)]
137                #enum_ident::#variant_name { #field_name: value }
138            }
139        }
140    }
141}
142
143enum FieldName<'a> {
144    Index(syn::Index),
145    Ident(&'a syn::Ident),
146}
147
148impl<'a> ToTokens for FieldName<'a> {
149    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
150        match self {
151            FieldName::Index(i) => i.to_tokens(tokens),
152            FieldName::Ident(i) => i.to_tokens(tokens),
153        }
154    }
155}