1use {
6 proc_macro2::{Span, TokenStream},
7 quote::{ToTokens, quote},
8 syn::{
9 Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident, Result,
10 parse_macro_input,
11 },
12};
13
14#[proc_macro_derive(Exhaustive)]
27pub fn exhaustive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
28 let input = parse_macro_input!(input as DeriveInput);
29 derive(&input)
30 .unwrap_or_else(|err| err.to_compile_error())
31 .into()
32}
33
/// Declares a struct with one `TokenStream` field per listed item, plus a `Default` impl
/// that fills each field with the item's fully-qualified (`::`-prefixed) path.
///
/// Entries are written as `seg::seg:::Item`. The leading `ident::` segments form the path,
/// the lone `:` separates them from the final item name, and that item name doubles as the
/// field name.
macro_rules! shortcuts {
    {
        struct $name:ident {
            $( $($segment:ident::)* : $leaf:ident ),*
        }
    } => {
        // Field names intentionally mirror the item names (e.g. `UnsafeCell`), so
        // destructuring a `$name` reads like a list of imports.
        #[allow(non_snake_case, reason = "shortcut items")]
        struct $name {
            $( $leaf: TokenStream, )*
        }

        impl Default for $name {
            fn default() -> Self {
                Self {
                    $( $leaf: quote!(:: $($segment ::)* $leaf), )*
                }
            }
        }
    };
}
60
61shortcuts! {
62 struct Shortcuts {
63 core::cell:::UnsafeCell,
64 core::mem:::MaybeUninit,
65 const_exhaustive:::Exhaustive,
66 const_exhaustive:::const_transmute,
67 const_exhaustive::typenum:::U0,
68 const_exhaustive::typenum:::U1,
69 const_exhaustive::typenum:::Unsigned,
70 const_exhaustive::typenum::operator_aliases:::Sum,
71 const_exhaustive::typenum::operator_aliases:::Prod,
72 const_exhaustive::generic_array:::GenericArray
73 }
74}
75
76fn derive(input: &DeriveInput) -> Result<TokenStream> {
77 let Shortcuts { Exhaustive, .. } = Shortcuts::default();
78
79 let name = &input.ident;
80 let generics = &input.generics;
81 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
82
83 let ExhaustiveImpl { num, values } = match &input.data {
84 Data::Struct(data) => make_for_struct(data),
85 Data::Enum(data) => make_for_enum(data),
86 Data::Union(_) => {
87 return Err(Error::new_spanned(
88 input,
89 "exhaustive union is not supported",
90 ));
91 }
92 };
93
94 let body = impl_body(num, values);
95 Ok(quote! {
96 unsafe impl #impl_generics #Exhaustive for #name #type_generics #where_clause {
97 #body
98 }
99 })
100}
101
102struct ExhaustiveImpl {
103 num: TokenStream,
104 values: TokenStream,
105}
106
107fn make_for_struct(data: &DataStruct) -> ExhaustiveImpl {
108 make_for_fields(&data.fields, quote! { Self })
109}
110
111fn make_for_enum(data: &DataEnum) -> ExhaustiveImpl {
112 struct VariantInfo {
113 num: TokenStream,
114 values: TokenStream,
115 }
116
117 let Shortcuts { U0, Sum, .. } = Shortcuts::default();
118
119 let variants = data
120 .variants
121 .iter()
122 .map(|variant| {
123 let ident = &variant.ident;
124 let ExhaustiveImpl { num, values } =
125 make_for_fields(&variant.fields, quote! { Self::#ident });
126 VariantInfo { num, values }
127 })
128 .collect::<Vec<_>>();
129
130 let num = variants
131 .iter()
132 .fold(quote! { #U0 }, |acc, VariantInfo { num, .. }| {
133 quote! { #Sum<#acc, #num> }
134 });
135
136 let values = variants
137 .iter()
138 .map(|VariantInfo { values, .. }| {
139 quote! {
140 {
141 #values
142 }
143 }
144 })
145 .collect::<Vec<_>>();
146 let values = quote! {
147 #(#values)*
148 };
149
150 ExhaustiveImpl { num, values }
151}
152
153fn make_for_fields(fields: &Fields, construct_ident: impl ToTokens) -> ExhaustiveImpl {
154 struct FieldInfo<'a> {
155 field: &'a Field,
156 index: Ident,
157 }
158
159 const fn require_ident(field: &Field) -> &Ident {
160 field
161 .ident
162 .as_ref()
163 .expect("named field must have an ident")
164 }
165
166 fn get_value(ty: impl ToTokens, index: impl ToTokens) -> TokenStream {
167 let Shortcuts { Exhaustive, .. } = Shortcuts::default();
168
169 quote! {
170 <#ty as #Exhaustive>::ALL.as_slice()[#index]
171 }
172 }
173
174 let Shortcuts {
175 MaybeUninit,
176 Exhaustive,
177 U1,
178 Unsigned,
179 Prod,
180 ..
181 } = Shortcuts::default();
182
183 let (fields, construct) = match fields {
184 Fields::Unit => (Vec::<FieldInfo>::new(), quote! {}),
185 Fields::Unnamed(fields) => {
186 let fields = fields
187 .unnamed
188 .iter()
189 .enumerate()
190 .map(|(index, field)| {
191 let index = Ident::new(&format!("i_{index}"), Span::call_site());
192 FieldInfo { field, index }
193 })
194 .collect::<Vec<_>>();
195 let construct = fields
196 .iter()
197 .map(|FieldInfo { field, index }| get_value(&field.ty, index));
198 let construct = quote! {
199 (
200 #(#construct),*
201 )
202 };
203 (fields, construct)
204 }
205 Fields::Named(fields) => {
206 let fields = fields
207 .named
208 .iter()
209 .map(|field| {
210 let ident = require_ident(field);
211 let index = Ident::new(&format!("i_{ident}"), Span::call_site());
212 FieldInfo { field, index }
213 })
214 .collect::<Vec<_>>();
215 let construct = fields
216 .iter()
217 .map(|FieldInfo { field, index }| {
218 let ident = require_ident(field);
219 let get_value = get_value(&field.ty, index);
220 quote! { #ident: #get_value }
221 })
222 .collect::<Vec<_>>();
223 let construct = quote! {
224 {
225 #(#construct),*
226 }
227 };
228 (fields, construct)
229 }
230 };
231
232 let num = fields
233 .iter()
234 .fold(quote! { #U1 }, |acc, FieldInfo { field, .. }| {
235 let ty = &field.ty;
236 quote! {
237 #Prod<#acc, <#ty as #Exhaustive>::Num>
238 }
239 });
240
241 let values = fields.iter().rfold(
245 quote! {
246 unsafe {
247 *all.as_slice()[i].get() = #MaybeUninit::new(#construct_ident #construct);
248 };
249 i += 1;
250 },
251 |acc, FieldInfo { field, index, .. }| {
252 let ty = &field.ty;
253 quote! {
254 let mut #index = 0usize;
255 while #index < <<#ty as #Exhaustive>::Num as #Unsigned>::USIZE {
256 #acc
257 #index += 1;
258 };
259 }
260 },
261 );
262
263 ExhaustiveImpl { num, values }
264}
265
266fn impl_body(num: impl ToTokens, values: impl ToTokens) -> TokenStream {
267 let Shortcuts {
268 UnsafeCell,
269 MaybeUninit,
270 GenericArray,
271 const_transmute,
272 ..
273 } = Shortcuts::default();
274
275 quote! {
276 type Num = #num;
277
278 const ALL: #GenericArray<Self, Self::Num> = {
279 let all: #GenericArray<#UnsafeCell<#MaybeUninit<Self>>, Self::Num> = unsafe {
280 #MaybeUninit::uninit().assume_init()
281 };
282
283 let mut i = 0;
284
285 #values
286
287 unsafe {
288 #const_transmute(all)
289 }
290 };
291 }
292}