const_enum_tools_derive/
lib.rs1#![allow(incomplete_features)]
2#![feature(generic_const_exprs)]
3extern crate proc_macro;
4extern crate const_enum_tools;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput};
9
10#[proc_macro_derive(VariantCount)]
11pub fn derive_variant_count(enum_item: TokenStream) -> TokenStream {
12 let ast: syn::DeriveInput = parse_macro_input!(enum_item as DeriveInput);
13
14 match ast.data {
15 syn::Data::Union(union_data) => {
16 let err = syn::Error::new_spanned(union_data.union_token, "Unexpected union declaration: VariantList can only be derived for enums.");
17 err.into_compile_error().into()
18 },
19 syn::Data::Struct(struct_data) => {
20 let err = syn::Error::new_spanned(struct_data.struct_token, "Unexpected union declaration: VariantList can only be derived for enums.");
21 err.into_compile_error().into()
22 },
23 syn::Data::Enum(enum_field_data) => {
24 let variants = enum_field_data.variants;
25 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
26 let name = ast.ident;
27 let variant_count = variants.len();
28
29 quote!(
30 #[automatically_derived]
31 impl #impl_generics ::const_enum_tools::VariantCount for #name #ty_generics #where_clause {
32 const VARIANT_COUNT: usize = #variant_count;
33 }
34 ).into()
35 }
36 }
37}
38
39const DISALLOW_INSTANCE_BITCOPY: &str = "disallow_instance_bitcopy";
40
41#[proc_macro_derive(VariantList, attributes(disallow_instance_bitcopy))]
42pub fn derive_variant_list(enum_item: TokenStream) -> TokenStream {
43 let ast: syn::DeriveInput = parse_macro_input!(enum_item as DeriveInput);
44
45 match ast.data {
46 syn::Data::Union(union_data) => {
47 let err = syn::Error::new_spanned(union_data.union_token, "Unexpected union declaration: VariantList can only be derived for enums.");
48 err.into_compile_error().into()
49 },
50 syn::Data::Struct(struct_data) => {
51 let err = syn::Error::new_spanned(struct_data.struct_token, "Unexpected union declaration: VariantList can only be derived for enums.");
52 err.into_compile_error().into()
53 },
54 syn::Data::Enum(enum_field_data) => {
55 let variants = enum_field_data.variants;
56 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
57 let name = ast.ident;
58 let variant_count = variants.len();
59
60 let mut variant_index_match_arms = Vec::new();
61 let mut variant_names = Vec::new();
62 let mut all_unit_no_discriminant = true;
63 let mut disallow_instance_bitcopy = false;
64
65 for attr in &ast.attrs {
66 if attr.path.is_ident(DISALLOW_INSTANCE_BITCOPY) {
67 disallow_instance_bitcopy = true;
68 }
69 }
70
71 for (index, variant) in variants.iter().enumerate() {
72 let variant_name = &variant.ident;
73 if !disallow_instance_bitcopy {
74 for attr in &variant.attrs {
75 if attr.path.is_ident(DISALLOW_INSTANCE_BITCOPY) {
76 disallow_instance_bitcopy = true;
77 }
78 }
79 }
80
81 variant_index_match_arms.push(
82 match &variant.fields {
83 syn::Fields::Named(fields) => {
84 all_unit_no_discriminant = false;
85 let mapped = fields.named.iter().map(|_| { quote!(_) });
86 quote!(
87 Self::#variant_name(#(#mapped),*) => {
88 #index
89 }
90 )
91 },
92 syn::Fields::Unnamed(fields) => {
93 all_unit_no_discriminant = false;
94 let mapped = fields.unnamed.iter().map(|_| { quote!(_) });
95 quote!(
96 Self::#variant_name(#(#mapped),*) => {
97 #index
98 }
99 )
100 },
101 syn::Fields::Unit => {
102 if let Some(discriminant) = &variant.discriminant {
105 match discriminant.1.clone() {
106 syn::Expr::Lit(lit) => {
108 match lit.lit {
109 syn::Lit::Int(int_lit) => {
110 if int_lit.base10_digits() != index.to_string().as_str() {
113 all_unit_no_discriminant = false;
114 }
115 },
116 _ => {
117 all_unit_no_discriminant = false;
118 }
119 }
120 },
121 _ => {
124 all_unit_no_discriminant = false;
125 },
126 }
127 }
128 quote!(
129 Self::#variant_name => {
130 #index
131 }
132 )
133 },
134 }
135 );
136
137 variant_names.push({
138 let variant_name_string = variant_name.to_string();
139 quote!(
140 #variant_name_string
141 )
142 });
143
144 }
145
146 let variant_index_body = if all_unit_no_discriminant && !disallow_instance_bitcopy {
152 quote!(
153 unsafe {
154 (self as *const Self).read() as usize
155 }
156 )
157 }
158 else {
159 quote!(
160 match self {
161 #(
162 #variant_index_match_arms
163 ),*
164 }
165 )
166 };
167
168 quote!(
169 #[automatically_derived]
170 impl #impl_generics ::const_enum_tools::VariantList for #name #ty_generics #where_clause {
171 #[inline]
172 fn variant_index (&self) -> usize {
173 #variant_index_body
174 }
175
176 const VARIANTS: [&'static str; #variant_count] = [#(#variant_names),*];
177 }
178 ).into()
179 }
180 }
181
182}