1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
self,
parenthesized,
parse::{Parser, ParseStream},
Attribute, Data, DeriveInput, Generics, Ident,
};
#[proc_macro_derive(FromDiscriminant)]
pub fn derive_from_discriminant(input: TokenStream) -> TokenStream {
let input = syn::parse(input).expect("failed to parse macro input");
let (ty, repr, variants) = unpack_input(input);
let discriminants = variants.iter().map(|v| {
let name = format_ident!("D_{}", v);
quote! {
const #name: #repr = #ty::#v as #repr;
}
});
let match_arms = variants.iter().map(|v| {
let name = format_ident!("D_{}", v);
quote! {
#name => Ok(#ty::#v),
}
});
quote! {
impl discrim::FromDiscriminant<#repr> for #ty {
#[allow(non_upper_case_globals)]
fn from_discriminant(tag: #repr) -> Result<Self, #repr> {
#(#discriminants)*
match tag {
#(#match_arms)*
other => Err(other),
}
}
}
}.into()
}
fn unpack_input(input: DeriveInput) -> (Ident, Ident, Vec<Ident>) {
let data = match input.data {
Data::Enum(data) => data,
_ => panic!("input must be an enum"),
};
if data.variants.is_empty() {
panic!("enum must have at least one variant");
}
let variants: Vec<_> = data.variants.into_iter().map(|v| v.ident).collect();
if has_generics(&input.generics) {
panic!("generic enums are not supported");
}
let repr = detect_repr(input.attrs).expect("#[repr(...)] attribute is required");
(input.ident, repr, variants)
}
fn detect_repr(attrs: Vec<Attribute>) -> Option<Ident> {
attrs.into_iter()
.find_map(|attr| {
if attr.path.is_ident("repr") {
Some(extract_repr.parse2(attr.tokens).expect("failed to parse tokens in #[repr(...)] attribute"))
} else {
None
}
})
}
fn extract_repr(input: ParseStream) -> syn::parse::Result<Ident> {
let repr;
parenthesized!(repr in input);
repr.parse()
}
/// Reports whether the item declares any generic parameters or carries a
/// `where` clause.
fn has_generics(generics: &Generics) -> bool {
    let no_params = generics.params.is_empty();
    let no_where_clause = generics.where_clause.is_none();
    !(no_params && no_where_clause)
}