1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 self,
5 parenthesized,
6 parse::{Parser, ParseStream},
7 Attribute, Data, DeriveInput, Generics, Ident,
8};
9
10#[proc_macro_derive(FromDiscriminant)]
32pub fn derive_from_discriminant(input: TokenStream) -> TokenStream {
33 let input = syn::parse(input).expect("failed to parse macro input");
34 let (ty, repr, variants) = unpack_input(input);
35
36 let discriminants = variants.iter().map(|v| {
38 let name = format_ident!("D_{}", v);
39 quote! {
40 const #name: #repr = #ty::#v as #repr;
41 }
42 });
43
44 let match_arms = variants.iter().map(|v| {
46 let name = format_ident!("D_{}", v);
47 quote! {
48 #name => Ok(#ty::#v),
49 }
50 });
51
52 quote! {
53 impl discrim::FromDiscriminant<#repr> for #ty {
54 #[allow(non_upper_case_globals)]
55 fn from_discriminant(tag: #repr) -> Result<Self, #repr> {
56 #(#discriminants)*
57
58 match tag {
59 #(#match_arms)*
60 other => Err(other),
61 }
62 }
63 }
64 }.into()
65}
66
67fn unpack_input(input: DeriveInput) -> (Ident, Ident, Vec<Ident>) {
68 let data = match input.data {
69 Data::Enum(data) => data,
70 _ => panic!("input must be an enum"),
71 };
72
73 if data.variants.is_empty() {
75 panic!("enum must have at least one variant");
76 }
77
78 let variants: Vec<_> = data.variants.into_iter().map(|v| v.ident).collect();
79
80 if has_generics(&input.generics) {
82 panic!("generic enums are not supported");
83 }
84
85 let repr = detect_repr(input.attrs).expect("#[repr(...)] attribute is required");
87
88 (input.ident, repr, variants)
90}
91
92fn detect_repr(attrs: Vec<Attribute>) -> Option<Ident> {
93 attrs.into_iter()
95 .find_map(|attr| {
96 if attr.path.is_ident("repr") {
97 Some(extract_repr.parse2(attr.tokens).expect("failed to parse tokens in #[repr(...)] attribute"))
98 } else {
99 None
100 }
101 })
102}
103
104fn extract_repr(input: ParseStream) -> syn::parse::Result<Ident> {
105 let repr;
106 parenthesized!(repr in input);
107 repr.parse()
108}
109
110fn has_generics(generics: &Generics) -> bool {
111 !generics.params.is_empty() || generics.where_clause.is_some()
112}