enum_discriminant_macros/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, parse_quote, AttrStyle};
5
/// Unwraps a `Result` inside a function that returns a token stream.
///
/// When the expression is `Ok`, the macro evaluates to the contained value.
/// When it is `Err`, the enclosing function returns early with the error's
/// `to_compile_error()` output converted via `.into()` into the function's
/// return type.
macro_rules! compile_error_unless_ok {
    ($result:expr) => {
        match $result {
            Err(failure) => return failure.to_compile_error().into(),
            Ok(unwrapped) => unwrapped,
        }
    };
}
14
15#[proc_macro_attribute]
23pub fn discriminant(arguments: TokenStream, item: TokenStream) -> TokenStream {
24 let enum_item = parse_macro_input!(item as syn::ItemEnum);
25 let enum_name = &enum_item.ident;
26
27 let arguments: TokenStream2 = arguments.into();
28
29 let repr_type = compile_error_unless_ok!(get_repr_type(arguments.clone()));
30
31 let from_discriminant_code = generate_from_discriminant_function(&repr_type, &enum_item);
32 let discriminant_code = generate_discriminant_function(&repr_type);
33
34 quote! {
35 #[repr(#arguments)]
36 #enum_item
37
38 impl #enum_name {
39 #from_discriminant_code
40
41 #discriminant_code
42 }
43 }
44 .into()
45}
46
47#[proc_macro_derive(IntoDiscriminant)]
53pub fn derive_into_discriminant(item: TokenStream) -> TokenStream {
54 let input = parse_macro_input!(item as syn::DeriveInput);
55 let enum_name = &input.ident;
56
57 let repr_args = compile_error_unless_ok!(get_repr_args("IntoDiscriminant", &input));
58 let repr_type = compile_error_unless_ok!(get_repr_type(repr_args));
59
60 let discriminant_code = generate_discriminant_function(&repr_type);
61
62 quote! {
63 impl IntoDiscriminant for #enum_name {
64 type DiscriminantType = #repr_type;
65
66 #discriminant_code
67 }
68 }
69 .into()
70}
71
72#[proc_macro_derive(FromDiscriminant)]
79pub fn derive_from_discriminant(item: TokenStream) -> TokenStream {
80 let cloned_item = item.clone();
81 let input = parse_macro_input!(item as syn::DeriveInput);
82 let enum_item = parse_macro_input!(cloned_item as syn::ItemEnum);
83 let enum_name = &enum_item.ident;
84
85 let repr_args = compile_error_unless_ok!(get_repr_args("FromDiscriminant", &input));
86 let repr_type = compile_error_unless_ok!(get_repr_type(repr_args));
87
88 let from_discriminant_code = generate_from_discriminant_function(&repr_type, &enum_item);
89
90 quote! {
91 impl FromDiscriminant for #enum_name {
92 type DiscriminantType = #repr_type;
93
94 #from_discriminant_code
95 }
96 }
97 .into()
98}
99
100fn get_repr_type(arguments: TokenStream2) -> Result<syn::Path, syn::Error> {
103 let allowed_types = [
104 "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", "isize",
105 ];
106
107 arguments
108 .clone()
109 .into_iter()
110 .filter_map(|token_tree| {
112 if let proc_macro2::TokenTree::Ident(ident) = token_tree {
113 let ident_str = ident.to_string();
114 if allowed_types.contains(&ident_str.as_str()) {
115 return Some(syn::parse_str::<syn::Path>(&ident_str).unwrap());
116 }
117 }
118 None
119 })
120 .next()
121 .ok_or_else(|| {
123 syn::Error::new_spanned(
124 arguments,
125 "Valid enum representation type expected as argument to the discriminant \
126 macro, e.g., #[discriminant(u8)]",
127 )
128 })
129}
130
131fn get_repr_args(macro_name: &str, input: &syn::DeriveInput) -> Result<TokenStream2, syn::Error> {
134 let x = input
135 .attrs
136 .iter()
137 .filter(|attr| matches!(attr.style, AttrStyle::Outer))
138 .filter(|attr| {
139 let path = attr.path();
140 path.is_ident("repr") || path.is_ident("discriminant")
141 })
142 .filter_map(|attr| attr.meta.require_list().ok())
143 .next()
144 .ok_or_else(|| {
145 syn::Error::new_spanned(
146 input,
147 format!(
148 "When deriving {} on an enum, you also need to specify \
149 representation type with #[repr()] or #[discriminant()]",
150 macro_name
151 ),
152 )
153 })?;
154 Ok(x.tokens.clone())
155}
156
157fn enum_unit_variants(enum_item: &syn::ItemEnum) -> (Vec<proc_macro2::Ident>, Vec<syn::Expr>) {
161 let mut previous_expr: Option<syn::Expr> = None;
162 enum_item
163 .variants
164 .iter()
165 .filter(|variant| matches!(variant.fields, syn::Fields::Unit))
166 .map(|variant| {
167 let expr = if let Some(discriminant) = &variant.discriminant {
168 discriminant.1.clone()
169 } else if let Some(ref old_expr) = previous_expr {
170 parse_quote!( 1 + #old_expr )
171 } else {
172 parse_quote!(0)
173 };
174 previous_expr = Some(expr.clone());
175 (variant.ident.clone(), expr)
176 })
177 .unzip()
178}
179
180fn generate_from_discriminant_function(
181 repr_type: &syn::Path,
182 enum_item: &syn::ItemEnum,
183) -> TokenStream2 {
184 let (variant_names, discriminants) = enum_unit_variants(enum_item);
185 let enum_name = &enum_item.ident;
186
187 quote! {
188 fn from_discriminant(discriminant: #repr_type) -> Option<Self> {
191 match discriminant {
192 #( discriminant if discriminant == #discriminants =>
195 Some(#enum_name::#variant_names), )*
196 _ => None,
197 }
198 }
199 }
200}
201
202fn generate_discriminant_function(repr_type: &syn::Path) -> TokenStream2 {
203 quote! {
204 fn discriminant(&self) -> #repr_type {
206 unsafe {
208 *<*const _>::from(self).cast::<#repr_type>()
209 }
210 }
211 }
212}