enum_utils/
conv.rs

1use failure::format_err;
2use proc_macro2::{TokenStream, Span};
3use quote::quote;
4
5use crate::attr::{Enum, ErrorList};
6
7pub fn derive_try_from_repr(input: &syn::DeriveInput) -> Result<TokenStream, ErrorList> {
8    let Enum { name, variants, primitive_repr, .. } = Enum::parse(input)?;
9
10    let mut errors = ErrorList::new();
11    let repr = match primitive_repr {
12        Ok(Some((_, repr))) => repr,
13
14        Ok(None) => bail_list!("`#[repr(...)]` must be specified to derive `TryFrom`"),
15        Err(e) => {
16            errors.push_back(e);
17            return Err(errors);
18        }
19    };
20
21    for (v, _) in variants.iter() {
22        if v.fields != syn::Fields::Unit {
23            errors.push_back(format_err!("Variant cannot have fields"));
24            continue;
25        }
26    }
27
28    if !errors.is_empty() {
29        return Err(errors);
30    }
31
32    let consts = variants.iter()
33        .map(|(v, _)| {
34            let s = "DISCRIMINANT_".to_owned() + &v.ident.to_string();
35            syn::Ident::new(s.as_str(), Span::call_site())
36        });
37
38    let ctors = variants.iter()
39        .map(|(v, _)| {
40            let v = &v.ident;
41            quote!(#name::#v)
42        });
43
44    // `as` casts are not valid as part of a pattern, so we need to do define new `consts` to hold
45    // them.
46    let const_defs = consts.clone()
47        .zip(ctors.clone())
48        .map(|(v, ctor)|  quote!(const #v: #repr = #ctor as #repr));
49
50    Ok(quote! {
51        impl ::std::convert::TryFrom<#repr> for #name {
52            type Error = ();
53
54            #[allow(non_upper_case_globals)]
55            fn try_from(d: #repr) -> Result<Self, Self::Error> {
56
57                #( #const_defs; )*
58
59                match d {
60                    #( #consts => Ok(#ctors), )*
61                    _ => Err(())
62                }
63            }
64        }
65    })
66}
67
68pub fn derive_repr_from(input: &syn::DeriveInput) -> Result<TokenStream, ErrorList> {
69    let Enum { name, variants, primitive_repr, .. } = Enum::parse(input)?;
70
71    let mut errors = ErrorList::new();
72    let repr = match primitive_repr {
73        Ok(Some((_, repr))) => repr,
74
75        Ok(None) => bail_list!("`#[repr(...)]` must be specified to derive `TryFrom`"),
76        Err(e) => {
77            errors.push_back(e);
78            return Err(errors);
79        }
80    };
81
82    for (v, _) in variants.iter() {
83        if v.fields != syn::Fields::Unit {
84            errors.push_back(format_err!("Variant cannot have fields"));
85            continue;
86        }
87    }
88
89    if !errors.is_empty() {
90        return Err(errors);
91    }
92
93    Ok(quote! {
94        impl ::std::convert::From<#name> for #repr {
95            fn from(d: #name) -> Self {
96                d as #repr
97            }
98        }
99    })
100}