const_exhaustive_derive/
lib.rs

1//! Derive macros for [`const-exhaustive`].
2//!
3//! [`const-exhaustive`]: https://docs.rs/const-exhaustive
4
5use {
6    proc_macro2::{Span, TokenStream},
7    quote::{ToTokens, quote},
8    syn::{
9        Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident, Result,
10        parse_macro_input,
11    },
12};
13
14/// Derives `const_exhaustive::Exhaustive` on this type.
15///
16/// This type must be [`Clone`] and [`Copy`], and all types contained within
17/// it must also be `Exhaustive`.
18///
19/// # Limitations
20///
21/// This macro cannot be used on `union`s.
22///
23/// This macro cannot yet be used on types with type parameters. This is
24/// technically possible, but requires the macro to add more explicit `where`
25/// bounds. Pull requests welcome!
26#[proc_macro_derive(Exhaustive)]
27pub fn exhaustive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29    derive(&input)
30        .unwrap_or_else(|err| err.to_compile_error())
31        .into()
32}
33
// Declares a struct with one `TokenStream` field per listed item, together
// with a `Default` impl that fills every field with the absolute path
// (`::path::to::Item`) of that item.
//
// Entries are written as `path::to:::Item`. This tokenizes as
// `path :: to :: : Item`, so the lone `:` separates the module path from the
// item name. The item name doubles as the field name.
macro_rules! shortcuts {
    (
        struct $name:ident {
            $( $( $segment:ident :: )* : $item:ident ),*
        }
    ) => {
        #[allow(non_snake_case, reason = "shortcut items")]
        struct $name {
            $( $item: TokenStream, )*
        }

        impl Default for $name {
            fn default() -> Self {
                Self {
                    $( $item: quote! { :: $( $segment :: )* $item }, )*
                }
            }
        }
    };
}
60
// Absolute paths to every external item referenced by the generated code.
//
// Each entry reads `path::to:::Item`: the segments before `:::` form the
// module path, and the identifier after it becomes both the field name on
// `Shortcuts` and the final path segment. For example,
// `core::cell:::UnsafeCell` yields a field `UnsafeCell` holding the tokens
// `::core::cell::UnsafeCell`.
shortcuts! {
    struct Shortcuts {
        core::cell:::UnsafeCell,
        core::mem:::MaybeUninit,
        const_exhaustive:::Exhaustive,
        const_exhaustive:::const_transmute,
        const_exhaustive::typenum:::U0,
        const_exhaustive::typenum:::U1,
        const_exhaustive::typenum:::Unsigned,
        const_exhaustive::typenum::operator_aliases:::Sum,
        const_exhaustive::typenum::operator_aliases:::Prod,
        const_exhaustive::generic_array:::GenericArray
    }
}
75
76fn derive(input: &DeriveInput) -> Result<TokenStream> {
77    let Shortcuts { Exhaustive, .. } = Shortcuts::default();
78
79    let name = &input.ident;
80    let generics = &input.generics;
81    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
82
83    let ExhaustiveImpl { num, values } = match &input.data {
84        Data::Struct(data) => make_for_struct(data),
85        Data::Enum(data) => make_for_enum(data),
86        Data::Union(_) => {
87            return Err(Error::new_spanned(
88                input,
89                "exhaustive union is not supported",
90            ));
91        }
92    };
93
94    let body = impl_body(num, values);
95    Ok(quote! {
96        unsafe impl #impl_generics #Exhaustive for #name #type_generics #where_clause {
97            #body
98        }
99    })
100}
101
/// Token fragments describing an `Exhaustive` implementation, either for a
/// whole type or for a single set of fields.
struct ExhaustiveImpl {
    /// Type expression counting the values; for a whole type this becomes the
    /// `Num` associated type.
    num: TokenStream,
    /// Statements that write every value into the `all` array being built
    /// inside the `ALL` constant, advancing the counter `i` after each write.
    values: TokenStream,
}
106
107fn make_for_struct(data: &DataStruct) -> ExhaustiveImpl {
108    make_for_fields(&data.fields, quote! { Self })
109}
110
111fn make_for_enum(data: &DataEnum) -> ExhaustiveImpl {
112    struct VariantInfo {
113        num: TokenStream,
114        values: TokenStream,
115    }
116
117    let Shortcuts { U0, Sum, .. } = Shortcuts::default();
118
119    let variants = data
120        .variants
121        .iter()
122        .map(|variant| {
123            let ident = &variant.ident;
124            let ExhaustiveImpl { num, values } =
125                make_for_fields(&variant.fields, quote! { Self::#ident });
126            VariantInfo { num, values }
127        })
128        .collect::<Vec<_>>();
129
130    let num = variants
131        .iter()
132        .fold(quote! { #U0 }, |acc, VariantInfo { num, .. }| {
133            quote! { #Sum<#acc, #num> }
134        });
135
136    let values = variants
137        .iter()
138        .map(|VariantInfo { values, .. }| {
139            quote! {
140                {
141                    #values
142                }
143            }
144        })
145        .collect::<Vec<_>>();
146    let values = quote! {
147        #(#values)*
148    };
149
150    ExhaustiveImpl { num, values }
151}
152
153fn make_for_fields(fields: &Fields, construct_ident: impl ToTokens) -> ExhaustiveImpl {
154    struct FieldInfo<'a> {
155        field: &'a Field,
156        index: Ident,
157    }
158
159    const fn require_ident(field: &Field) -> &Ident {
160        field
161            .ident
162            .as_ref()
163            .expect("named field must have an ident")
164    }
165
166    fn get_value(ty: impl ToTokens, index: impl ToTokens) -> TokenStream {
167        let Shortcuts { Exhaustive, .. } = Shortcuts::default();
168
169        quote! {
170            <#ty as #Exhaustive>::ALL.as_slice()[#index]
171        }
172    }
173
174    let Shortcuts {
175        MaybeUninit,
176        Exhaustive,
177        U1,
178        Unsigned,
179        Prod,
180        ..
181    } = Shortcuts::default();
182
183    let (fields, construct) = match fields {
184        Fields::Unit => (Vec::<FieldInfo>::new(), quote! {}),
185        Fields::Unnamed(fields) => {
186            let fields = fields
187                .unnamed
188                .iter()
189                .enumerate()
190                .map(|(index, field)| {
191                    let index = Ident::new(&format!("i_{index}"), Span::call_site());
192                    FieldInfo { field, index }
193                })
194                .collect::<Vec<_>>();
195            let construct = fields
196                .iter()
197                .map(|FieldInfo { field, index }| get_value(&field.ty, index));
198            let construct = quote! {
199                (
200                    #(#construct),*
201                )
202            };
203            (fields, construct)
204        }
205        Fields::Named(fields) => {
206            let fields = fields
207                .named
208                .iter()
209                .map(|field| {
210                    let ident = require_ident(field);
211                    let index = Ident::new(&format!("i_{ident}"), Span::call_site());
212                    FieldInfo { field, index }
213                })
214                .collect::<Vec<_>>();
215            let construct = fields
216                .iter()
217                .map(|FieldInfo { field, index }| {
218                    let ident = require_ident(field);
219                    let get_value = get_value(&field.ty, index);
220                    quote! { #ident: #get_value }
221                })
222                .collect::<Vec<_>>();
223            let construct = quote! {
224                {
225                    #(#construct),*
226                }
227            };
228            (fields, construct)
229        }
230    };
231
232    let num = fields
233        .iter()
234        .fold(quote! { #U1 }, |acc, FieldInfo { field, .. }| {
235            let ty = &field.ty;
236            quote! {
237                #Prod<#acc, <#ty as #Exhaustive>::Num>
238            }
239        });
240
241    // rfold here so that the value order matches the tuple value order
242    // e.g. we generate i_0 { i_1 { i_2 } }
243    //       instead of i_2 { i_1 { i_0 } }
244    let values = fields.iter().rfold(
245        quote! {
246            unsafe {
247                *all.as_slice()[i].get() = #MaybeUninit::new(#construct_ident #construct);
248            };
249            i += 1;
250        },
251        |acc, FieldInfo { field, index, .. }| {
252            let ty = &field.ty;
253            quote! {
254                let mut #index = 0usize;
255                while #index < <<#ty as #Exhaustive>::Num as #Unsigned>::USIZE {
256                    #acc
257                    #index += 1;
258                };
259            }
260        },
261    );
262
263    ExhaustiveImpl { num, values }
264}
265
266fn impl_body(num: impl ToTokens, values: impl ToTokens) -> TokenStream {
267    let Shortcuts {
268        UnsafeCell,
269        MaybeUninit,
270        GenericArray,
271        const_transmute,
272        ..
273    } = Shortcuts::default();
274
275    quote! {
276        type Num = #num;
277
278        const ALL: #GenericArray<Self, Self::Num> = {
279            let all: #GenericArray<#UnsafeCell<#MaybeUninit<Self>>, Self::Num> = unsafe {
280                #MaybeUninit::uninit().assume_init()
281            };
282
283            let mut i = 0;
284
285            #values
286
287            unsafe {
288                #const_transmute(all)
289            }
290        };
291    }
292}