discrim_codegen/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    self,
5    parenthesized,
6    parse::{Parser, ParseStream},
7    Attribute, Data, DeriveInput, Generics, Ident,
8};
9
// TODO: nicer error handling — e.g. report failures as spanned compile errors using the
// error types from `syn::parse` instead of panicking.
12
13/// Automatically implement `FromDiscriminant` for enums.
14///
15/// **Important:** Implementations can only be derived for types that are fieldless enums without generics,
16/// and which specify a `#[repr(...)]` attribute.
17///
18/// # Example
19/// ```rust
20/// use discrim::FromDiscriminant;
21///
22/// #[derive(Debug, FromDiscriminant, PartialEq)]
23/// #[repr(u8)]
24/// enum Opcode {
25///     Add, Sub, Mul, Div
26/// }
27///
28/// assert_eq!(Opcode::from_discriminant(2), Ok(Opcode::Mul));
29/// assert_eq!(Opcode::from_discriminant(5), Err(5));
30/// ```
31#[proc_macro_derive(FromDiscriminant)]
32pub fn derive_from_discriminant(input: TokenStream) -> TokenStream {
33    let input = syn::parse(input).expect("failed to parse macro input");
34    let (ty, repr, variants) = unpack_input(input);
35
36    // Declare a constant value per discriminant to match against
37    let discriminants = variants.iter().map(|v| {
38        let name = format_ident!("D_{}", v);
39        quote! {
40            const #name: #repr = #ty::#v as #repr;
41        }
42    });
43
44    // Define match arms for each variant
45    let match_arms = variants.iter().map(|v| {
46        let name = format_ident!("D_{}", v);
47        quote! {
48            #name => Ok(#ty::#v),
49        }
50    });
51
52    quote! {
53        impl discrim::FromDiscriminant<#repr> for #ty {
54            #[allow(non_upper_case_globals)]
55            fn from_discriminant(tag: #repr) -> Result<Self, #repr> {
56                #(#discriminants)*
57
58                match tag {
59                    #(#match_arms)*
60                    other => Err(other),
61                }
62            }
63        }
64    }.into()
65}
66
67fn unpack_input(input: DeriveInput) -> (Ident, Ident, Vec<Ident>) {
68    let data = match input.data {
69        Data::Enum(data) => data,
70        _ => panic!("input must be an enum"),
71    };
72
73    // check that there is at least one variant, and that they're all unit variants
74    if data.variants.is_empty() {
75        panic!("enum must have at least one variant");
76    }
77
78    let variants: Vec<_> = data.variants.into_iter().map(|v| v.ident).collect();
79
80    // disallow generics
81    if has_generics(&input.generics) {
82        panic!("generic enums are not supported");
83    }
84
85    // find and require the repr attribute
86    let repr = detect_repr(input.attrs).expect("#[repr(...)] attribute is required");
87
88    // return (ty, repr, variants)
89    (input.ident, repr, variants)
90}
91
92fn detect_repr(attrs: Vec<Attribute>) -> Option<Ident> {
93    // if an attr is the ident "repr", extract its contents and parse them into an ident
94    attrs.into_iter()
95        .find_map(|attr| {
96            if attr.path.is_ident("repr") {
97                Some(extract_repr.parse2(attr.tokens).expect("failed to parse tokens in #[repr(...)] attribute"))
98            } else {
99                None
100            }
101        })
102}
103
104fn extract_repr(input: ParseStream) -> syn::parse::Result<Ident> {
105    let repr;
106    parenthesized!(repr in input);
107    repr.parse()
108}
109
110fn has_generics(generics: &Generics) -> bool {
111    !generics.params.is_empty() || generics.where_clause.is_some()
112}