const_serialize_macro/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, DeriveInput, LitInt};
4use syn::{parse_quote, Generics, WhereClause, WherePredicate};
5
6fn add_bounds(where_clause: &mut Option<WhereClause>, generics: &Generics) {
7 let bounds = generics.params.iter().filter_map(|param| match param {
8 syn::GenericParam::Type(ty) => {
9 Some::<WherePredicate>(parse_quote! { #ty: const_serialize::SerializeConst, })
10 }
11 syn::GenericParam::Lifetime(_) => None,
12 syn::GenericParam::Const(_) => None,
13 });
14 if let Some(clause) = where_clause {
15 clause.predicates.extend(bounds);
16 } else {
17 *where_clause = Some(parse_quote! { where #(#bounds)* });
18 }
19}
20
21#[proc_macro_derive(SerializeConst)]
23pub fn derive_parse(input: TokenStream) -> TokenStream {
24 let input = parse_macro_input!(input as DeriveInput);
26
27 match input.data {
28 syn::Data::Struct(data) => match data.fields {
29 syn::Fields::Unnamed(_) | syn::Fields::Named(_) => {
30 let ty = &input.ident;
31 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32 let mut where_clause = where_clause.cloned();
33 add_bounds(&mut where_clause, &input.generics);
34 let field_names = data.fields.iter().enumerate().map(|(i, field)| {
35 field
36 .ident
37 .as_ref()
38 .map(|ident| ident.to_token_stream())
39 .unwrap_or_else(|| {
40 LitInt::new(&i.to_string(), proc_macro2::Span::call_site())
41 .into_token_stream()
42 })
43 });
44 let field_types = data.fields.iter().map(|field| &field.ty);
45 quote! {
46 unsafe impl #impl_generics const_serialize::SerializeConst for #ty #ty_generics #where_clause {
47 const MEMORY_LAYOUT: const_serialize::Layout = const_serialize::Layout::Struct(const_serialize::StructLayout::new(
48 std::mem::size_of::<Self>(),
49 &[#(
50 const_serialize::StructFieldLayout::new(
51 std::mem::offset_of!(#ty, #field_names),
52 <#field_types as const_serialize::SerializeConst>::MEMORY_LAYOUT,
53 ),
54 )*],
55 ));
56 }
57 }.into()
58 }
59 syn::Fields::Unit => {
60 let ty = &input.ident;
61 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
62 let mut where_clause = where_clause.cloned();
63 add_bounds(&mut where_clause, &input.generics);
64 quote! {
65 unsafe impl #impl_generics const_serialize::SerializeConst for #ty #ty_generics #where_clause {
66 const MEMORY_LAYOUT: const_serialize::Layout = const_serialize::Layout::Struct(const_serialize::StructLayout::new(
67 std::mem::size_of::<Self>(),
68 &[],
69 ));
70 }
71 }.into()
72 }
73 },
74 syn::Data::Enum(data) => match data.variants.len() {
75 0 => syn::Error::new(input.ident.span(), "Enums must have at least one variant")
76 .to_compile_error()
77 .into(),
78 1.. => {
79 let mut repr_c = false;
80 let mut discriminant_size = None;
81 for attr in &input.attrs {
82 if attr.path().is_ident("repr") {
83 if let Err(err) = attr.parse_nested_meta(|meta| {
84 if meta.path.is_ident("C") {
86 repr_c = true;
87 return Ok(());
88 }
89
90 if meta.path.is_ident("u8") {
92 discriminant_size = Some(1);
93 return Ok(());
94 }
95
96 if meta.path.is_ident("u16") {
98 discriminant_size = Some(2);
99 return Ok(());
100 }
101
102 if meta.path.is_ident("u32") {
104 discriminant_size = Some(3);
105 return Ok(());
106 }
107
108 if meta.path.is_ident("u64") {
110 discriminant_size = Some(4);
111 return Ok(());
112 }
113
114 Err(meta.error("unrecognized repr"))
115 }) {
116 return err.to_compile_error().into();
117 }
118 }
119 }
120
121 let variants_have_fields = data
122 .variants
123 .iter()
124 .any(|variant| !variant.fields.is_empty());
125 if !repr_c && variants_have_fields {
126 return syn::Error::new(input.ident.span(), "Enums must be repr(C, u*)")
127 .to_compile_error()
128 .into();
129 }
130
131 if discriminant_size.is_none() {
132 return syn::Error::new(input.ident.span(), "Enums must be repr(u*)")
133 .to_compile_error()
134 .into();
135 }
136
137 let ty = &input.ident;
138 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
139 let mut where_clause = where_clause.cloned();
140 add_bounds(&mut where_clause, &input.generics);
141 let mut last_discriminant = None;
142 let variants = data.variants.iter().map(|variant| {
143 let discriminant = variant
144 .discriminant
145 .as_ref()
146 .map(|(_, discriminant)| discriminant.to_token_stream())
147 .unwrap_or_else(|| match &last_discriminant {
148 Some(discriminant) => quote! { #discriminant + 1 },
149 None => {
150 quote! { 0 }
151 }
152 });
153 last_discriminant = Some(discriminant.clone());
154 let field_names = variant.fields.iter().enumerate().map(|(i, field)| {
155 field
156 .ident
157 .clone()
158 .unwrap_or_else(|| quote::format_ident!("__field_{}", i))
159 });
160 let field_types = variant.fields.iter().map(|field| &field.ty);
161 let generics = &input.generics;
162 quote! {
163 {
164 #[allow(unused)]
165 #[derive(const_serialize::SerializeConst)]
166 #[repr(C)]
167 struct VariantStruct #generics {
168 #(
169 #field_names: #field_types,
170 )*
171 }
172 const_serialize::EnumVariant::new(
173 #discriminant as u32,
174 match VariantStruct::MEMORY_LAYOUT {
175 const_serialize::Layout::Struct(layout) => layout,
176 _ => panic!("VariantStruct::MEMORY_LAYOUT must be a struct"),
177 },
178 std::mem::align_of::<VariantStruct>(),
179 )
180 }
181 }
182 });
183 quote! {
184 unsafe impl #impl_generics const_serialize::SerializeConst for #ty #ty_generics #where_clause {
185 const MEMORY_LAYOUT: const_serialize::Layout = const_serialize::Layout::Enum(const_serialize::EnumLayout::new(
186 std::mem::size_of::<Self>(),
187 const_serialize::PrimitiveLayout::new(
188 #discriminant_size as usize,
189 ),
190 {
191 const DATA: &'static [const_serialize::EnumVariant] = &[
192 #(
193 #variants,
194 )*
195 ];
196 DATA
197 },
198 ));
199 }
200 }.into()
201 }
202 },
203 _ => syn::Error::new(input.ident.span(), "Only structs and enums are supported")
204 .to_compile_error()
205 .into(),
206 }
207}