1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
    self,
    parenthesized,
    parse::{Parser, ParseStream},
    Attribute, Data, DeriveInput, Generics, Ident,
};

// TODO: nicer error handling: report problems as compile errors instead of panicking,
// maybe using the error types from `syn::parse`.

#[proc_macro_derive(FromDiscriminant)]
pub fn derive_from_discriminant(input: TokenStream) -> TokenStream {
    let input = syn::parse(input).expect("failed to parse macro input");
    let (ty, repr, variants) = unpack_input(input);

    // Declare a constant value per discriminant to match against
    let discriminants = variants.iter().map(|v| {
        let name = format_ident!("D_{}", v);
        quote! {
            const #name: #repr = #ty::#v as #repr;
        }
    });

    // Define match arms for each variant
    let match_arms = variants.iter().map(|v| {
        let name = format_ident!("D_{}", v);
        quote! {
            #name => Ok(#ty::#v),
        }
    });

    quote! {
        impl discrim::FromDiscriminant<#repr> for #ty {
            #[allow(non_upper_case_globals)]
            fn from_discriminant(tag: #repr) -> Result<Self, #repr> {
                #(#discriminants)*

                match tag {
                    #(#match_arms)*
                    other => Err(other),
                }
            }
        }
    }.into()
}

fn unpack_input(input: DeriveInput) -> (Ident, Ident, Vec<Ident>) {
    let data = match input.data {
        Data::Enum(data) => data,
        _ => panic!("input must be an enum"),
    };

    // check that there is at least one variant, and that they're all unit variants
    if data.variants.is_empty() {
        panic!("enum must have at least one variant");
    }

    let variants: Vec<_> = data.variants.into_iter().map(|v| v.ident).collect();

    // disallow generics
    if has_generics(&input.generics) {
        panic!("generic enums are not supported");
    }

    // find and require the repr attribute
    let repr = detect_repr(input.attrs).expect("#[repr(...)] attribute is required");

    // return (ty, repr, variants)
    (input.ident, repr, variants)
}

/// Returns the identifier inside the first `#[repr(...)]` attribute in `attrs`, or
/// `None` when no `repr` attribute is present.
///
/// Only the first `repr` attribute is looked at; its tokens are parsed by
/// `extract_repr`.
///
/// # Panics
///
/// Panics if that first `repr` attribute's tokens fail to parse.
fn detect_repr(attrs: Vec<Attribute>) -> Option<Ident> {
    for attr in attrs {
        if !attr.path.is_ident("repr") {
            continue;
        }
        let repr_ident = extract_repr
            .parse2(attr.tokens)
            .expect("failed to parse tokens in #[repr(...)] attribute");
        return Some(repr_ident);
    }
    None
}

/// Parses an attribute's parenthesized token group (such as `(u8)`) and returns the
/// identifier it contains.
fn extract_repr(input: ParseStream) -> syn::parse::Result<Ident> {
    let inner;
    parenthesized!(inner in input);
    let repr_ty: Ident = inner.parse()?;
    Ok(repr_ty)
}

/// Reports whether `generics` declares any generic parameters or carries a `where`
/// clause.
fn has_generics(generics: &Generics) -> bool {
    let no_params = generics.params.is_empty();
    let no_where_clause = generics.where_clause.is_none();
    !(no_params && no_where_clause)
}