primitive_enum_derive/
lib.rs

1extern crate proc_macro2;
2extern crate quote;
3extern crate syn;
4
5use proc_macro2::TokenStream;
6use quote::quote;
7use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Ident};
8
9fn get_primitive_name(attrs: &[Attribute]) -> (Ident, String) {
10    attrs
11        .iter()
12        .find_map(|attr| {
13            if !attr.path().is_ident("primitive") {
14                return None;
15            }
16
17            let ident: Ident = attr.parse_args().unwrap();
18            let name = ident.to_string();
19
20            Some((ident, name))
21        })
22        .expect("complex enums must include primitive type name")
23}
24
25/// #[primitive(PrimitiveName)]
26#[proc_macro_derive(PrimitiveFromEnum, attributes(primitive))]
27pub fn derive_primitive_from_enum(stream: proc_macro::TokenStream) -> proc_macro::TokenStream {
28    let ast = parse_macro_input!(stream as DeriveInput);
29
30    let name = &ast.ident;
31    let data = &ast.data;
32
33    match data {
34        Data::Enum(data_enum) => {
35            let is_simple_enum = data_enum.variants.iter().all(|item| item.fields.is_empty());
36
37            if is_simple_enum {
38                panic!("PrimitiveFromEnum only for non simple enum allow");
39            } else {
40                let (primitive_name, primitive_name_s) = get_primitive_name(&ast.attrs);
41
42                let len = data_enum.variants.len();
43
44                let mut get_primitive_enum: Vec<TokenStream> = Vec::with_capacity(len);
45
46                for variant in &data_enum.variants {
47                    let variant_name = &variant.ident;
48
49                    match &variant.fields {
50                        Fields::Unit => {
51                            get_primitive_enum.push(quote! {
52                                #name::#variant_name => #primitive_name::#variant_name,
53                            });
54                        }
55                        Fields::Unnamed(fields) => {
56                            let len = fields.unnamed.len();
57                            if len == 1 {
58                                get_primitive_enum.push(quote! {
59                                    #name::#variant_name(_) => #primitive_name::#variant_name,
60                                });
61                            } else {
62                                let underscores = vec![quote! { ,_ }; len - 1];
63                                get_primitive_enum.push(quote! {
64                                    #name::#variant_name(_ #(#underscores)*) => #primitive_name::#variant_name,
65                                });
66                            }
67                        }
68                        Fields::Named(fields) => {
69                            let fields = &fields
70                                .named
71                                .iter()
72                                .map(|f| {
73                                    let ident = f.ident.as_ref().unwrap();
74                                    quote! { #ident: _, }
75                                })
76                                .collect::<Vec<_>>();
77                            get_primitive_enum.push(quote! {
78                                #name::#variant_name{ #(#fields)* } => #primitive_name::#variant_name,
79                            });
80                        }
81                    };
82                }
83
84                let gen = quote! {
85                    impl primitive_enum::PrimitiveFromEnum for #name {
86                        type PrimitiveEnum = #primitive_name;
87                        #[inline]
88                        fn get_primitive_enum(&self) -> Self::PrimitiveEnum {
89                            match self {
90                                #(#get_primitive_enum)*
91                            }
92                        }
93                        #[inline]
94                        fn primitive_name() -> &'static str {
95                            #primitive_name_s
96                        }
97                    }
98                };
99
100                proc_macro::TokenStream::from(gen)
101            }
102        }
103        _ => {
104            panic!("PrimitiveFromEnum only for enum allow");
105        }
106    }
107}
108
109#[proc_macro_derive(FromU8, attributes(primitive))]
110pub fn derive_from_u8(stream: proc_macro::TokenStream) -> proc_macro::TokenStream {
111    let ast = parse_macro_input!(stream as DeriveInput);
112
113    let name = &ast.ident;
114    let name_s = &ast.ident.to_string();
115    let data = &ast.data;
116
117    match data {
118        Data::Enum(data_enum) => {
119            let is_simple_enum = data_enum.variants.iter().all(|item| item.fields.is_empty());
120            if is_simple_enum {
121                let mut variants: Vec<TokenStream> = Vec::with_capacity(data_enum.variants.len());
122                let mut try_variants: Vec<TokenStream> =
123                    Vec::with_capacity(data_enum.variants.len());
124
125                for variant in &data_enum.variants {
126                    let ident = &variant.ident;
127                    let var = quote! {
128                        u if #name::#ident == u => #name::#ident,
129                    };
130                    variants.push(var);
131                    try_variants.push(quote! {
132                        u if #name::#ident == u => Ok(#name::#ident),
133                    });
134                }
135
136                let gen = quote! {
137                    impl PartialEq<u8> for #name {
138                        fn eq(&self, other: &u8) -> bool {
139                            *self as u8 == *other
140                        }
141                    }
142                    impl From<#name> for u8 {
143                        fn from(e: #name) -> u8 {
144                            e as u8
145                        }
146                    }
147                    impl primitive_enum::UnsafeFromU8 for #name {
148                        #[inline]
149                        unsafe fn from_unsafe(u: u8) -> Self {
150                            match u {
151                                #(#variants)*
152                                _ => panic!("UnsafeFromU8 from_unsafe undefined value: {}", u),
153                            }
154                        }
155                        #[inline]
156                        fn name() -> &'static str {
157                            #name_s
158                        }
159                    }
160                    impl core::convert::TryFrom<u8> for #name {
161                        type Error = primitive_enum::EnumFromU8Error;
162                        fn try_from(value: u8) -> Result<Self, Self::Error> {
163                            match value {
164                                #(#try_variants)*
165                                _ => Err(primitive_enum::EnumFromU8Error),
166                            }
167                        }
168                    }
169                };
170                proc_macro::TokenStream::from(gen)
171            } else {
172                panic!("FromU8 only for simple enum allow (without nested data)");
173            }
174        }
175        _ => {
176            panic!("FromU8 only for enum allow");
177        }
178    }
179}