const_serialize_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, DeriveInput, LitInt};
4use syn::{parse_quote, Generics, WhereClause, WherePredicate};
5
6fn add_bounds(where_clause: &mut Option<WhereClause>, generics: &Generics) {
7    let bounds = generics.params.iter().filter_map(|param| match param {
8        syn::GenericParam::Type(ty) => {
9            Some::<WherePredicate>(parse_quote! { #ty: const_serialize::SerializeConst, })
10        }
11        syn::GenericParam::Lifetime(_) => None,
12        syn::GenericParam::Const(_) => None,
13    });
14    if let Some(clause) = where_clause {
15        clause.predicates.extend(bounds);
16    } else {
17        *where_clause = Some(parse_quote! { where #(#bounds)* });
18    }
19}
20
21/// Derive the const serialize trait for a struct
22#[proc_macro_derive(SerializeConst)]
23pub fn derive_parse(input: TokenStream) -> TokenStream {
24    // Parse the input tokens into a syntax tree
25    let input = parse_macro_input!(input as DeriveInput);
26
27    match input.data {
28        syn::Data::Struct(data) => match data.fields {
29            syn::Fields::Unnamed(_) | syn::Fields::Named(_) => {
30                let ty = &input.ident;
31                let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32                let mut where_clause = where_clause.cloned();
33                add_bounds(&mut where_clause, &input.generics);
34                let field_names = data.fields.iter().enumerate().map(|(i, field)| {
35                    field
36                        .ident
37                        .as_ref()
38                        .map(|ident| ident.to_token_stream())
39                        .unwrap_or_else(|| {
40                            LitInt::new(&i.to_string(), proc_macro2::Span::call_site())
41                                .into_token_stream()
42                        })
43                });
44                let field_types = data.fields.iter().map(|field| &field.ty);
45                quote! {
46                    unsafe impl #impl_generics const_serialize::SerializeConst for #ty #ty_generics #where_clause {
47                        const MEMORY_LAYOUT: const_serialize::Layout = const_serialize::Layout::Struct(const_serialize::StructLayout::new(
48                            std::mem::size_of::<Self>(),
49                            &[#(
50                                const_serialize::StructFieldLayout::new(
51                                    std::mem::offset_of!(#ty, #field_names),
52                                    <#field_types as const_serialize::SerializeConst>::MEMORY_LAYOUT,
53                                ),
54                            )*],
55                        ));
56                    }
57                }.into()
58            }
59            syn::Fields::Unit => {
60                let ty = &input.ident;
61                let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
62                let mut where_clause = where_clause.cloned();
63                add_bounds(&mut where_clause, &input.generics);
64                quote! {
65                    unsafe impl #impl_generics const_serialize::SerializeConst for #ty #ty_generics #where_clause {
66                        const MEMORY_LAYOUT: const_serialize::Layout = const_serialize::Layout::Struct(const_serialize::StructLayout::new(
67                            std::mem::size_of::<Self>(),
68                            &[],
69                        ));
70                    }
71                }.into()
72            }
73        },
74        syn::Data::Enum(data) => match data.variants.len() {
75            0 => syn::Error::new(input.ident.span(), "Enums must have at least one variant")
76                .to_compile_error()
77                .into(),
78            1.. => {
79                let mut repr_c = false;
80                let mut discriminant_size = None;
81                for attr in &input.attrs {
82                    if attr.path().is_ident("repr") {
83                        if let Err(err) = attr.parse_nested_meta(|meta| {
84                            // #[repr(C)]
85                            if meta.path.is_ident("C") {
86                                repr_c = true;
87                                return Ok(());
88                            }
89
90                            // #[repr(u8)]
91                            if meta.path.is_ident("u8") {
92                                discriminant_size = Some(1);
93                                return Ok(());
94                            }
95
96                            // #[repr(u16)]
97                            if meta.path.is_ident("u16") {
98                                discriminant_size = Some(2);
99                                return Ok(());
100                            }
101
102                            // #[repr(u32)]
103                            if meta.path.is_ident("u32") {
104                                discriminant_size = Some(3);
105                                return Ok(());
106                            }
107
108                            // #[repr(u64)]
109                            if meta.path.is_ident("u64") {
110                                discriminant_size = Some(4);
111                                return Ok(());
112                            }
113
114                            Err(meta.error("unrecognized repr"))
115                        }) {
116                            return err.to_compile_error().into();
117                        }
118                    }
119                }
120
121                let variants_have_fields = data
122                    .variants
123                    .iter()
124                    .any(|variant| !variant.fields.is_empty());
125                if !repr_c && variants_have_fields {
126                    return syn::Error::new(input.ident.span(), "Enums must be repr(C, u*)")
127                        .to_compile_error()
128                        .into();
129                }
130
131                if discriminant_size.is_none() {
132                    return syn::Error::new(input.ident.span(), "Enums must be repr(u*)")
133                        .to_compile_error()
134                        .into();
135                }
136
137                let ty = &input.ident;
138                let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
139                let mut where_clause = where_clause.cloned();
140                add_bounds(&mut where_clause, &input.generics);
141                let mut last_discriminant = None;
142                let variants = data.variants.iter().map(|variant| {
143                    let discriminant = variant
144                        .discriminant
145                        .as_ref()
146                        .map(|(_, discriminant)| discriminant.to_token_stream())
147                        .unwrap_or_else(|| match &last_discriminant {
148                            Some(discriminant) => quote! { #discriminant + 1 },
149                            None => {
150                                quote! { 0 }
151                            }
152                        });
153                    last_discriminant = Some(discriminant.clone());
154                    let field_names = variant.fields.iter().enumerate().map(|(i, field)| {
155                        field
156                            .ident
157                            .clone()
158                            .unwrap_or_else(|| quote::format_ident!("__field_{}", i))
159                    });
160                    let field_types = variant.fields.iter().map(|field| &field.ty);
161                    let generics = &input.generics;
162                    quote! {
163                        {
164                            #[allow(unused)]
165                            #[derive(const_serialize::SerializeConst)]
166                            #[repr(C)]
167                            struct VariantStruct #generics {
168                                #(
169                                    #field_names: #field_types,
170                                )*
171                            }
172                            const_serialize::EnumVariant::new(
173                                #discriminant as u32,
174                                match VariantStruct::MEMORY_LAYOUT {
175                                    const_serialize::Layout::Struct(layout) => layout,
176                                    _ => panic!("VariantStruct::MEMORY_LAYOUT must be a struct"),
177                                },
178                                std::mem::align_of::<VariantStruct>(),
179                            )
180                        }
181                    }
182                });
183                quote! {
184                    unsafe impl #impl_generics const_serialize::SerializeConst for #ty #ty_generics #where_clause {
185                        const MEMORY_LAYOUT: const_serialize::Layout = const_serialize::Layout::Enum(const_serialize::EnumLayout::new(
186                            std::mem::size_of::<Self>(),
187                            const_serialize::PrimitiveLayout::new(
188                                #discriminant_size as usize,
189                            ),
190                            {
191                                const DATA: &'static [const_serialize::EnumVariant] = &[
192                                    #(
193                                        #variants,
194                                    )*
195                                ];
196                                DATA
197                            },
198                        ));
199                    }
200                }.into()
201            }
202        },
203        _ => syn::Error::new(input.ident.span(), "Only structs and enums are supported")
204            .to_compile_error()
205            .into(),
206    }
207}