const_serialize_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, DeriveInput, LitInt, Path};
4use syn::{parse_quote, Generics, WhereClause, WherePredicate};
5
6fn add_bounds(where_clause: &mut Option<WhereClause>, generics: &Generics, krate: &Path) {
7    let bounds = generics.params.iter().filter_map(|param| match param {
8        syn::GenericParam::Type(ty) => {
9            Some::<WherePredicate>(parse_quote! { #ty: #krate::SerializeConst, })
10        }
11        syn::GenericParam::Lifetime(_) => None,
12        syn::GenericParam::Const(_) => None,
13    });
14    if let Some(clause) = where_clause {
15        clause.predicates.extend(bounds);
16    } else {
17        *where_clause = Some(parse_quote! { where #(#bounds)* });
18    }
19}
20
21/// Derive the const serialize trait for a struct
22#[proc_macro_derive(SerializeConst, attributes(const_serialize))]
23pub fn derive_parse(raw_input: TokenStream) -> TokenStream {
24    // Parse the input tokens into a syntax tree
25    let input = parse_macro_input!(raw_input as DeriveInput);
26    let krate = input.attrs.iter().find_map(|attr| {
27        attr.path()
28            .is_ident("const_serialize")
29            .then(|| {
30                let mut path = None;
31                if let Err(err) = attr.parse_nested_meta(|meta| {
32                    if meta.path.is_ident("crate") {
33                        let ident: Path = meta.value()?.parse()?;
34                        path = Some(ident);
35                    }
36                    Ok(())
37                }) {
38                    return Some(Err(err));
39                }
40                path.map(Ok)
41            })
42            .flatten()
43    });
44    let krate = match krate {
45        Some(Ok(path)) => path,
46        Some(Err(err)) => return err.into_compile_error().into(),
47        None => parse_quote! { const_serialize },
48    };
49
50    match input.data {
51        syn::Data::Struct(data) => match data.fields {
52            syn::Fields::Unnamed(_) | syn::Fields::Named(_) => {
53                let ty = &input.ident;
54                let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
55                let mut where_clause = where_clause.cloned();
56                add_bounds(&mut where_clause, &input.generics, &krate);
57                let field_names = data.fields.iter().enumerate().map(|(i, field)| {
58                    field
59                        .ident
60                        .as_ref()
61                        .map(|ident| ident.to_token_stream())
62                        .unwrap_or_else(|| {
63                            LitInt::new(&i.to_string(), proc_macro2::Span::call_site())
64                                .into_token_stream()
65                        })
66                });
67                let field_types = data.fields.iter().map(|field| &field.ty);
68                quote! {
69                    unsafe impl #impl_generics #krate::SerializeConst for #ty #ty_generics #where_clause {
70                        const MEMORY_LAYOUT: #krate::Layout = #krate::Layout::Struct(#krate::StructLayout::new(
71                            std::mem::size_of::<Self>(),
72                            &[#(
73                                #krate::StructFieldLayout::new(
74                                    stringify!(#field_names),
75                                    std::mem::offset_of!(#ty, #field_names),
76                                    <#field_types as #krate::SerializeConst>::MEMORY_LAYOUT,
77                                ),
78                            )*],
79                        ));
80                    }
81                }.into()
82            }
83            syn::Fields::Unit => {
84                let ty = &input.ident;
85                let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
86                let mut where_clause = where_clause.cloned();
87                add_bounds(&mut where_clause, &input.generics, &krate);
88                quote! {
89                    unsafe impl #impl_generics #krate::SerializeConst for #ty #ty_generics #where_clause {
90                        const MEMORY_LAYOUT: #krate::Layout = #krate::Layout::Struct(#krate::StructLayout::new(
91                            std::mem::size_of::<Self>(),
92                            &[],
93                        ));
94                    }
95                }.into()
96            }
97        },
98        syn::Data::Enum(data) => match data.variants.len() {
99            0 => syn::Error::new(input.ident.span(), "Enums must have at least one variant")
100                .to_compile_error()
101                .into(),
102            1.. => {
103                let mut repr_c = false;
104                let mut discriminant_size = None;
105                for attr in &input.attrs {
106                    if attr.path().is_ident("repr") {
107                        if let Err(err) = attr.parse_nested_meta(|meta| {
108                            // #[repr(C)]
109                            if meta.path.is_ident("C") {
110                                repr_c = true;
111                                return Ok(());
112                            }
113
114                            // #[repr(u8)]
115                            if meta.path.is_ident("u8") {
116                                discriminant_size = Some(1);
117                                return Ok(());
118                            }
119
120                            // #[repr(u16)]
121                            if meta.path.is_ident("u16") {
122                                discriminant_size = Some(2);
123                                return Ok(());
124                            }
125
126                            // #[repr(u32)]
127                            if meta.path.is_ident("u32") {
128                                discriminant_size = Some(3);
129                                return Ok(());
130                            }
131
132                            // #[repr(u64)]
133                            if meta.path.is_ident("u64") {
134                                discriminant_size = Some(4);
135                                return Ok(());
136                            }
137
138                            Err(meta.error("unrecognized repr"))
139                        }) {
140                            return err.to_compile_error().into();
141                        }
142                    }
143                }
144
145                let variants_have_fields = data
146                    .variants
147                    .iter()
148                    .any(|variant| !variant.fields.is_empty());
149                if !repr_c && variants_have_fields {
150                    return syn::Error::new(input.ident.span(), "Enums must be repr(C, u*)")
151                        .to_compile_error()
152                        .into();
153                }
154
155                if discriminant_size.is_none() {
156                    return syn::Error::new(input.ident.span(), "Enums must be repr(u*)")
157                        .to_compile_error()
158                        .into();
159                }
160
161                let ty = &input.ident;
162                let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
163                let mut where_clause = where_clause.cloned();
164                add_bounds(&mut where_clause, &input.generics, &krate);
165                let mut last_discriminant = None;
166                let variants = data.variants.iter().map(|variant| {
167                    let discriminant = variant
168                        .discriminant
169                        .as_ref()
170                        .map(|(_, discriminant)| discriminant.to_token_stream())
171                        .unwrap_or_else(|| match &last_discriminant {
172                            Some(discriminant) => quote! { #discriminant + 1 },
173                            None => {
174                                quote! { 0 }
175                            }
176                        });
177                    last_discriminant = Some(discriminant.clone());
178                    let variant_name = &variant.ident;
179                    let field_names = variant.fields.iter().enumerate().map(|(i, field)| {
180                        field
181                            .ident
182                            .clone()
183                            .unwrap_or_else(|| quote::format_ident!("__field_{}", i))
184                    });
185                    let field_types = variant.fields.iter().map(|field| &field.ty);
186                    let generics = &input.generics;
187                    quote! {
188                        {
189                            #[allow(unused)]
190                            #[derive(#krate::SerializeConst)]
191                            #[const_serialize(crate = #krate)]
192                            #[repr(C)]
193                            struct VariantStruct #generics {
194                                #(
195                                    #field_names: #field_types,
196                                )*
197                            }
198                            #krate::EnumVariant::new(
199                                stringify!(#variant_name),
200                                #discriminant as u32,
201                                match <VariantStruct #generics as #krate::SerializeConst>::MEMORY_LAYOUT {
202                                    #krate::Layout::Struct(layout) => layout,
203                                    _ => panic!("VariantStruct::MEMORY_LAYOUT must be a struct"),
204                                },
205                                ::std::mem::align_of::<VariantStruct>(),
206                            )
207                        }
208                    }
209                });
210                quote! {
211                    unsafe impl #impl_generics #krate::SerializeConst for #ty #ty_generics #where_clause {
212                        const MEMORY_LAYOUT: #krate::Layout = #krate::Layout::Enum(#krate::EnumLayout::new(
213                            ::std::mem::size_of::<Self>(),
214                            #krate::PrimitiveLayout::new(
215                                #discriminant_size as usize,
216                            ),
217                            {
218                                const DATA: &'static [#krate::EnumVariant] = &[
219                                    #(
220                                        #variants,
221                                    )*
222                                ];
223                                DATA
224                            },
225                        ));
226                    }
227                }.into()
228            }
229        },
230        _ => syn::Error::new(input.ident.span(), "Only structs and enums are supported")
231            .to_compile_error()
232            .into(),
233    }
234}