const_enum_tools_derive/
lib.rs

1#![allow(incomplete_features)]
2#![feature(generic_const_exprs)]
3extern crate proc_macro;
4extern crate const_enum_tools;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput};
9
10#[proc_macro_derive(VariantCount)]
11pub fn derive_variant_count(enum_item: TokenStream) -> TokenStream {
12    let ast: syn::DeriveInput = parse_macro_input!(enum_item as DeriveInput);
13
14    match ast.data {
15        syn::Data::Union(union_data) => {
16            let err = syn::Error::new_spanned(union_data.union_token, "Unexpected union declaration: VariantList can only be derived for enums.");
17            err.into_compile_error().into()
18        },
19        syn::Data::Struct(struct_data) => {
20            let err = syn::Error::new_spanned(struct_data.struct_token, "Unexpected union declaration: VariantList can only be derived for enums.");
21            err.into_compile_error().into()
22        },
23        syn::Data::Enum(enum_field_data) => {
24            let variants = enum_field_data.variants;
25            let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
26            let name = ast.ident;
27            let variant_count = variants.len();
28
29            quote!(
30                #[automatically_derived]
31                impl #impl_generics ::const_enum_tools::VariantCount for #name #ty_generics #where_clause {
32                    const VARIANT_COUNT: usize = #variant_count;
33                }
34            ).into()
35        }
36    }
37}
38
39const DISALLOW_INSTANCE_BITCOPY: &str = "disallow_instance_bitcopy";
40
41#[proc_macro_derive(VariantList, attributes(disallow_instance_bitcopy))]
42pub fn derive_variant_list(enum_item: TokenStream) -> TokenStream {
43    let ast: syn::DeriveInput = parse_macro_input!(enum_item as DeriveInput);
44
45    match ast.data {
46        syn::Data::Union(union_data) => {
47            let err = syn::Error::new_spanned(union_data.union_token, "Unexpected union declaration: VariantList can only be derived for enums.");
48            err.into_compile_error().into()
49        },
50        syn::Data::Struct(struct_data) => {
51            let err = syn::Error::new_spanned(struct_data.struct_token, "Unexpected union declaration: VariantList can only be derived for enums.");
52            err.into_compile_error().into()
53        },
54        syn::Data::Enum(enum_field_data) => {
55            let variants = enum_field_data.variants;
56            let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
57            let name = ast.ident;
58            let variant_count = variants.len();
59
60            let mut variant_index_match_arms = Vec::new();
61            let mut variant_names = Vec::new();
62            let mut all_unit_no_discriminant = true;
63            let mut disallow_instance_bitcopy = false;
64
65            for attr in &ast.attrs {
66                if attr.path.is_ident(DISALLOW_INSTANCE_BITCOPY) {
67                    disallow_instance_bitcopy = true;
68                }
69            }
70
71            for (index, variant) in variants.iter().enumerate() {
72                let variant_name = &variant.ident;
73                if !disallow_instance_bitcopy {
74                    for attr in &variant.attrs {
75                        if attr.path.is_ident(DISALLOW_INSTANCE_BITCOPY) {
76                            disallow_instance_bitcopy = true;
77                        }
78                    }
79                }
80
81                variant_index_match_arms.push(
82                    match &variant.fields {
83                        syn::Fields::Named(fields) => {
84                            all_unit_no_discriminant = false;
85                            let mapped = fields.named.iter().map(|_| { quote!(_) });
86                            quote!(
87                                Self::#variant_name(#(#mapped),*) => {
88                                    #index
89                                }
90                            )
91                        },
92                        syn::Fields::Unnamed(fields) => {
93                            all_unit_no_discriminant = false;
94                            let mapped = fields.unnamed.iter().map(|_| { quote!(_) });
95                            quote!(
96                                Self::#variant_name(#(#mapped),*) => {
97                                    #index
98                                }
99                            )
100                        },
101                        syn::Fields::Unit => {
102                            // If there is an explicit discriminant, we might not be able to perform the bitwise copy
103                            // optimization.
104                            if let Some(discriminant) = &variant.discriminant {
105                                match discriminant.1.clone() {
106                                    // If the discriminant expression is a literal, we can check if it is equal to the default value.
107                                    syn::Expr::Lit(lit) => {
108                                        match lit.lit {
109                                            syn::Lit::Int(int_lit) => {
110                                                // If the first part of the literal before the type is the same as what it would be
111                                                // because of the position in the enum, we're good. Otherwise, no optimization.
112                                                if int_lit.base10_digits() != index.to_string().as_str() {
113                                                    all_unit_no_discriminant = false;
114                                                }
115                                            },
116                                            _ => {
117                                                all_unit_no_discriminant = false;
118                                            }
119                                        }
120                                    },
121                                    // Otherwise, since we cannot evaluate arbitrary const expressions, we will not be able to optimize.
122                                    // This involves using the long match arms list.
123                                    _ => {
124                                        all_unit_no_discriminant = false;
125                                    },
126                                }
127                            }
128                            quote!(
129                                Self::#variant_name => {
130                                    #index
131                                }
132                            )
133                        },
134                    }
135                );
136
137                variant_names.push({
138                    let variant_name_string = variant_name.to_string();
139                    quote!(
140                        #variant_name_string
141                    )
142                });
143
144            }
145
146            // If there are no explicit discriminants
147            // This enum will be represented as a number type. Cast the reference
148            // to a raw pointer and read the bits from it (allows this optimization to be performed even when self =/= Copy).
149            // This is effectively a clone. Then cast to usize for index.
150            // I would love a better way of doing this that doesn't require an unsafe block. Alas, I can't think of any.
151            let variant_index_body = if all_unit_no_discriminant && !disallow_instance_bitcopy {
152                quote!(
153                    unsafe {
154                        (self as *const Self).read() as usize
155                    }
156                )
157            }
158            else {
159                quote!(
160                    match self {
161                        #(
162                            #variant_index_match_arms
163                        ),*
164                    }
165                )
166            };
167
168            quote!(
169                #[automatically_derived]
170                impl #impl_generics ::const_enum_tools::VariantList for #name #ty_generics #where_clause {
171                    #[inline]
172                    fn variant_index (&self) -> usize {
173                        #variant_index_body
174                    }
175
176                    const VARIANTS: [&'static str; #variant_count] = [#(#variant_names),*];
177                }
178            ).into()
179        }
180    }
181
182}