phenotype_macro/
lib.rs

1use proc_macro2::TokenStream;
2use proc_macro_error::{abort, proc_macro_error};
3use quote::{format_ident, quote};
4use std::collections::HashMap;
5use syn::{parse_macro_input, DeriveInput, FieldsNamed, FieldsUnnamed, Generics, Ident, Variant};
6
7const NOTE: &str = "can only derive phenotype on enums";
8
9type Tag = usize;
10
11/// Holds the logic for parsing generics
12mod generic;
13
14/// Condensed derive input; just the stuff we need
15struct Condensed<'a> {
16    name: Ident,
17    variants: HashMap<Tag, Variant>,
18    generics: &'a Generics,
19}
/// Width of `T` in bits.
///
/// Used by `log2` below so the logarithm can be computed without relying on
/// an unstable standard-library feature.
const fn num_bits<T>() -> usize {
    let bytes = std::mem::size_of::<T>();
    bytes * 8
}

/// Floor of the base-2 logarithm of `x`, i.e. the index of its highest set bit.
///
/// # Panics
///
/// Panics if `x` is zero.
fn log2(x: usize) -> u32 {
    assert!(x > 0);
    // Index of the most significant bit position of a `usize`.
    let top_bit_index = num_bits::<usize>() as u32 - 1;
    top_bit_index - x.leading_zeros()
}
29
30#[proc_macro_derive(Phenotype)]
31#[proc_macro_error]
32pub fn phenotype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
33    let ast = parse_macro_input!(input as DeriveInput);
34
35    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
36    let ident = ast.ident.clone();
37
38    // Verify we have an enum
39    let enumb = match ast.data {
40        syn::Data::Enum(e) => e,
41        syn::Data::Struct(data) => {
42            abort!(data.struct_token, "struct `{}` is not an enum", ast.ident; note=NOTE)
43        }
44        syn::Data::Union(data) => {
45            abort!(data.union_token, "union `{}` is not an enum", ast.ident; note=NOTE)
46        }
47    };
48
49    let data = Condensed {
50        variants: enumb
51            .variants
52            .into_iter()
53            .enumerate()
54            .collect::<HashMap<Tag, Variant>>(),
55        name: ident.clone(),
56        generics: &ast.generics,
57    };
58
59    // Make sure there are variants!
60    if data.variants.is_empty() {
61        abort!(data.name, "enum `{}` has no variants", data.name)
62    }
63
64    // Abort if there are const generics - works funky with the way we deal with generics
65    if ast.generics.const_params().next().is_some() {
66        abort!(
67            ty_generics,
68            "const generics are not supported for `#[derive(Phenotype)]`";
69            note = "it may be possible to implement `Phenotype` by hand"
70        )
71    }
72
73    let auxiliaries = make_auxiliaries(&data);
74
75    let cleave_impl = cleave_impl(&data);
76
77    let reknit_impl = reknit_impl(&data);
78
79    let bits = {
80        if data.variants.is_empty() {
81            0
82        } else if data.variants.len() == 1 {
83            // This avoids having to check everywhere if T::BITS == 1,
84            // which is easy to forget and can easily cause panics,
85            // for the cheap cost of one bit
86            1
87        } else {
88            let log = log2(data.variants.len());
89            let pow = 2usize.pow(log);
90
91            // if 2 ** log is less than the number of variants, that means
92            // the log rounded down (i.e. the float version was something like
93            // 1.4, which became 1)
94            //
95            // We round up because we always carry the extra bits, i.e.
96            // 7 variants needs 2.8 bits but we carry 3
97            (if pow < data.variants.len() {
98                log + 1
99            } else {
100                log
101            }) as usize
102        }
103    };
104
105    let num_variants = data.variants.len();
106
107    let union_ident = format_ident!("__PhenotypeInternal{}Data", data.name);
108
109    let peapod_size = match data.generics.type_params().next() {
110        Some(_) => quote!(None),
111        // No generics
112        None => {
113            let bytes = bits / 8
114                + if bits % 8 == 0 {
115                    0
116                } else {
117                    // Add an extra byte if there are remaining bits (a partial byte)
118                    1
119                };
120            quote!(Some({ #bytes + ::core::mem::size_of::<#union_ident>() }))
121        }
122    };
123
124    let is_more_compact = match data.generics.type_params().next() {
125        Some(_) => quote!(None),
126        // No generics
127        None => {
128            quote!(
129                Some(
130                    // unwrap isn't const
131                    match <Self as Phenotype>::PEAPOD_SIZE {
132                        Some(size) => size <= ::core::mem::size_of::<#ident>(),
133                        // Unreachable as if there are not generics, PEAPOD_SIZE
134                        // is `Some`
135                        None => unreachable!()
136                    }
137
138                )
139            )
140        }
141    };
142
143    quote! {
144        #auxiliaries
145        unsafe impl #impl_generics Phenotype for #ident #ty_generics
146            #where_clause
147        {
148            const NUM_VARIANTS: usize = #num_variants;
149            const BITS: usize = #bits;
150            const PEAPOD_SIZE: Option<usize> = #peapod_size;
151            const IS_MORE_COMPACT: Option<bool> = #is_more_compact;
152            #cleave_impl
153            #reknit_impl
154        }
155    }
156    .into()
157}
158
159fn reknit_impl(data: &Condensed) -> TokenStream {
160    let mut arms = Vec::with_capacity(data.variants.len());
161
162    let ident = &data.name;
163
164    // We're going to turn each variant into a match that handles that variant's case
165    for (tag, var) in &data.variants {
166        let struct_name = format_ident!("__PhenotypeInternal{}{}Data", data.name, var.ident);
167        let var_ident = &var.ident;
168        let var_generics = generic::variant_generics(data.generics, var);
169        arms.push(match &var.fields {
170            syn::Fields::Named(FieldsNamed { named, .. }) => {
171                let struct_fields = named
172                    .iter()
173                    .map(|f| f.ident.clone().unwrap())
174                    .collect::<Vec<_>>();
175                quote! {
176                    #tag => {
177                        // SAFETY: Safe because the tag guarantees that we are reading the correct field
178                        let data = ::core::mem::ManuallyDrop::<#struct_name :: #var_generics>::into_inner(
179                            unsafe { value.#var_ident }
180                        );
181                        #ident::#var_ident { #(#struct_fields: data.#struct_fields),* }
182                    }
183                }
184            }
185            syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
186                // This produces the indexes we use to extract the data from the struct
187                let struct_field_placeholders = (0..unnamed.len()).map(syn::Index::from);
188                quote! {
189                    #tag => {
190                        // SAFETY: Safe because the tag guarantees that we are reading the correct field
191                        let data = ::core::mem::ManuallyDrop::<#struct_name :: #var_generics>::into_inner(
192                            unsafe { value.#var_ident }
193                        );
194                        #ident::#var_ident ( #(data.#struct_field_placeholders),* )
195                    }
196                }
197            }
198            syn::Fields::Unit => {
199                quote! {
200                    #tag => {
201                        #ident::#var_ident
202                    }
203                }
204            }
205        })
206    }
207
208    let generics = data.generics.split_for_impl().1;
209    quote! {
210        unsafe fn reknit(tag: usize, value: <Self as Phenotype>::Value) -> #ident #generics {
211            match tag {
212                #(#arms),*
213                // There should be no other cases, as there are no other variants
214                _ => ::core::unreachable!()
215            }
216        }
217    }
218}
219
220/// Implement the `value` trait method
221fn cleave_impl(data: &Condensed) -> proc_macro2::TokenStream {
222    let ident = &data.name;
223    let union_ident = format_ident!("__PhenotypeInternal{ident}Data");
224
225    // Snippet to extract data out of each field
226    let mut arms: Vec<proc_macro2::TokenStream> = Vec::with_capacity(data.variants.len());
227
228    let generics = data.generics.split_for_impl().1;
229
230    // Like `reknit_impl`, we produce a match arm for each variant
231    for (tag, var) in &data.variants {
232        let var_ident = &var.ident;
233        let struct_name = format_ident!("__PhenotypeInternal{ident}{var_ident}Data");
234
235        let var_generics = generic::variant_generics(data.generics, var);
236        arms.push(match &var.fields {
237            syn::Fields::Named(FieldsNamed { named, .. }) => {
238                // Capture each enum field (named), use it's ident to capture it's value
239                let fields = named.iter().map(|f| f.ident.clone()).collect::<Vec<_>>();
240                quote! {
241                    #ident::#var_ident {#(#fields),*} => (#tag,
242                        #union_ident {
243                            #var_ident: ::core::mem::ManuallyDrop::new(#struct_name :: #var_generics {
244                                // We've wrapped the enum that was passed in in a ManuallyDrop,
245                                // and now we read each field with ptr::read.
246
247                                // We wrap the enum that was passed in a ManuallyDrop to prevent
248                                // double drops.
249
250                                // We have to ptr::read because you can't move out of a
251                                // type that implements `Drop`
252                                // SAFETY: we are reading from a reference
253                                #(#fields: unsafe { ::core::ptr::read(#fields) }),*
254                            })
255                        }
256                    )
257                }
258            }
259            syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
260                // For each field (unnamed), produce an ident like _0, _1, ... so we can capture the value
261                let fields = (0..unnamed.iter().len())
262                    .map(|i| format_ident!("_{i}"))
263                    .collect::<Vec<_>>();
264                quote! {
265                    #ident::#var_ident(#(#fields),*) => (#tag,
266                        #union_ident {
267                            #var_ident: ::core::mem::ManuallyDrop::new(
268                                #struct_name :: #var_generics (
269                                    // We've wrapped the enum that was passed in in a ManuallyDrop,
270                                    // and now we read each field with ptr::read.
271
272                                    // We wrap the enum that was passed in a ManuallyDrop to prevent
273                                    // double drops.
274
275                                    // We have to ptr::read because you can't move out of a
276                                    // type that implements `Drop`
277                                    // SAFETY: we are reading from a reference
278                                    #( unsafe { ::core::ptr::read(#fields) }),*
279                                )
280                            )
281                        }
282                    )
283                }
284            }
285            syn::Fields::Unit => quote! {
286                #ident::#var_ident => (#tag, #union_ident { #var_ident: () }) // Doesn't contain data
287            },
288        })
289    }
290    quote! {
291        type Value = #union_ident #generics;
292        fn cleave(self) -> (usize, <Self as Phenotype>::Value) {
293            match &*::core::mem::ManuallyDrop::new(self) {
294                #(#arms),*
295            }
296        }
297    }
298}
299
300/// A struct that represents the data found in an enum
301struct Auxiliary {
302    ident: Ident,
303    // Tokens for the actual code of the struct
304    tokens: proc_macro2::TokenStream,
305}
306
307/// Return an auxiliary struct that can hold the data from an enum variant.
308/// Returns `None` if the variant doesn't contain any data
309fn def_auxiliary_struct(
310    variant: &Variant,
311    enum_name: &Ident,
312    all_generics: &Generics,
313) -> Option<Auxiliary> {
314    let field = &variant.ident;
315
316    let struct_name = format_ident!("__PhenotypeInternal{}{}Data", enum_name, field);
317
318    let generics = generic::variant_generics(all_generics, variant);
319
320    match &variant.fields {
321        // Create a dummy struct that contains the named fields
322        // We need the field idents and types so we can make pairs like:
323        // ident1: type1
324        // ident2: type2
325        // ...
326        syn::Fields::Named(FieldsNamed { named, .. }) => {
327            // Get the names of the fields
328            let idents = named.iter().map(|field| field.ident.as_ref().unwrap());
329            let types = named.iter().map(|field| &field.ty);
330            Some(Auxiliary {
331                ident: struct_name.clone(),
332                tokens: quote! {
333                    #[repr(packed)]
334                    struct #struct_name #generics {
335                        #(#idents: #types,)*
336                    }
337                },
338            })
339        }
340
341        // Create a dummy tuple struct that contains the fields
342        // We only need the types so we can produce output like
343        // type1, type2, ...
344        syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
345            let types = unnamed.iter().map(|field| &field.ty);
346            Some(Auxiliary {
347                ident: struct_name.clone(),
348                tokens: quote! { #[repr(packed)] struct #struct_name #generics (#(#types,)*); },
349            })
350        }
351
352        // No fields so we don't need to do anything
353        syn::Fields::Unit => None,
354    }
355}
356
357/// Define all auxiliary structs and the data enum
358fn make_auxiliaries(data: &Condensed) -> proc_macro2::TokenStream {
359    // Define the union that holds the data
360    let union_ident = format_ident!("__PhenotypeInternal{}Data", data.name);
361
362    // Assorted data that goes into defining all the machinery
363    let (
364        mut struct_idents,
365        mut struct_defs,
366        mut field_idents,
367        mut empty_field_idents,
368        mut struct_generics,
369    ) = (vec![], vec![], vec![], vec![], vec![]);
370
371    for var in data.variants.values() {
372        if let Some(aux) = def_auxiliary_struct(var, &data.name, data.generics) {
373            struct_idents.push(aux.ident);
374            struct_defs.push(aux.tokens);
375            field_idents.push(var.ident.clone());
376            struct_generics.push(generic::variant_generics(data.generics, var));
377        } else {
378            empty_field_idents.push(var.ident.clone())
379        }
380    }
381
382    let union_generics = data.generics.split_for_impl().1;
383
384    quote! {
385        #(#struct_defs)*
386        #[allow(non_snake_case)]
387        union #union_ident #union_generics {
388            #(#field_idents: ::core::mem::ManuallyDrop<#struct_idents #struct_generics>,)*
389            #(#empty_field_idents: (),)*
390        }
391    }
392}
393
394#[proc_macro_derive(PhenotypeDebug)]
395#[proc_macro_error]
396pub fn phenotype_debug(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
397    let ast = parse_macro_input!(input as DeriveInput);
398
399    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
400    let ident = ast.ident.clone();
401
402    // Verify we have an enum
403    let enumb = match ast.data {
404        syn::Data::Enum(e) => e,
405        syn::Data::Struct(data) => {
406            abort!(data.struct_token, "struct `{}` is not an enum", ast.ident; note=NOTE)
407        }
408        syn::Data::Union(data) => {
409            abort!(data.union_token, "union `{}` is not an enum", ast.ident; note=NOTE)
410        }
411    };
412
413    let data = Condensed {
414        variants: enumb
415            .variants
416            .into_iter()
417            .enumerate()
418            .collect::<HashMap<Tag, Variant>>(),
419        name: ident.clone(),
420        generics: &ast.generics,
421    };
422
423    // Make sure there are variants!
424    if data.variants.is_empty() {
425        abort!(data.name, "enum `{}` has no variants", data.name)
426    }
427
428    // Abort if there are const generics - works funky with the way we deal with generics
429    if ast.generics.const_params().next().is_some() {
430        abort!(
431            ty_generics,
432            "const generics are not supported for `#[derive(Phenotype)]`";
433            note = "it may be possible to implement `Phenotype` by hand"
434        )
435    }
436
437    let discriminant_impl = discriminant_impl(&data);
438    let debug_tag_impl = debug_tag_impl(&data);
439    quote! {
440        impl #impl_generics PhenotypeDebug for #ident #ty_generics
441            #where_clause
442        {
443            #discriminant_impl
444            #debug_tag_impl
445        }
446    }
447    .into()
448}
449
450/// Code for the discriminant trait method
451fn discriminant_impl(data: &Condensed) -> proc_macro2::TokenStream {
452    let enum_name = &data.name;
453
454    // Zip variants together with discriminants
455    // Each quote! looks something like `ident::variant => number,`
456    let arms = data.variants.iter().map(|(tag, variant)| {
457        let var_ident = &variant.ident;
458        // Make sure we have the proper destructuring syntax
459        match variant.fields {
460            syn::Fields::Named(_) => quote! { #enum_name::#var_ident {..} => #tag,},
461            syn::Fields::Unnamed(_) => quote! { #enum_name::#var_ident (..) => #tag,},
462            syn::Fields::Unit => quote! { #enum_name::#var_ident => #tag,},
463        }
464    });
465
466    quote! {
467        fn discriminant(&self) -> usize {
468            match &self {
469                #(#arms)*
470            }
471        }
472    }
473}
474
475/// Code for the debug_tag trait method
476fn debug_tag_impl(data: &Condensed) -> proc_macro2::TokenStream {
477    let enum_name = &data.name;
478
479    // Zip variants together with discriminants
480    // Each quote! looks something like `ident::variant => number,`
481    let arms = data.variants.iter().map(|(tag, variant)| {
482        let var_ident = &variant.ident;
483        let stringified = format!("{}::{}", enum_name, var_ident);
484        quote! {
485            #tag => #stringified,
486        }
487    });
488
489    quote! {
490        fn debug_tag(tag: usize) -> &'static str {
491            match tag {
492                #(#arms)*
493                _ => ::core::panic!("invalid tag")
494            }
495        }
496    }
497}