memory_layout_codegen/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, quote_spanned};
4use syn::{
5 parse::Parse, parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Error as SynError,
6 Field, LitInt, Result as SynResult, Type
7};
8
9struct FieldInfo {
10 field: Field,
11 previous_type: Option<Type>,
12 relative_offset: usize,
13 absolute_offset: usize
14}
15
16struct StructInfo {
17 derived: DeriveInput,
18 fields: Vec<FieldInfo>
19}
20
21impl StructInfo {
22 fn get_data_struct(input: &DeriveInput) -> SynResult<&DataStruct> {
23 match &input.data {
24 Data::Struct(data) => Ok(data),
25 Data::Enum(data) => {
26 Err(SynError::new_spanned(
27 data.enum_token,
28 "Expected struct but found enum."
29 ))
30 }
31 Data::Union(data) => {
32 Err(SynError::new_spanned(
33 data.union_token,
34 "Expected struct but found union."
35 ))
36 }
37 }
38 }
39
40 fn get_field_offset_value(attr: &Attribute) -> SynResult<usize> {
41 attr
42 .parse_args::<LitInt>()
43 .and_then(|lit| lit.base10_parse::<usize>())
44 .map_err(|_| SynError::new_spanned(attr, "Field offset must be an integer literal."))
45 }
46
47 fn get_fields(data: &DataStruct) -> SynResult<Vec<FieldInfo>> {
48 let mut result = Vec::<FieldInfo>::new();
49
50 let mut current_offset = 0usize;
51 let mut previous_type: Option<Type> = None;
52 for field in &data.fields {
53 let field_offset = field
54 .attrs
55 .iter()
56 .find(|attr| attr.path().is_ident("field_offset"));
57
58 let offset = field_offset
59 .ok_or_else(|| SynError::new_spanned(field, "Field is missing a field_offset."))
60 .and_then(Self::get_field_offset_value)?;
61
62 if current_offset > offset {
63 return Err(SynError::new_spanned(
64 field_offset,
65 "Field offset can't be lower than its predecessor."
66 ));
67 }
68
69 result.push(FieldInfo {
70 field: field.clone(),
71 previous_type: previous_type.clone(),
72 relative_offset: offset - current_offset,
73 absolute_offset: offset
74 });
75
76 previous_type = Some(field.ty.clone());
77 current_offset = offset
78 }
79
80 Ok(result)
81 }
82}
83
84impl Parse for StructInfo {
85 fn parse(input: syn::parse::ParseStream) -> SynResult<Self> {
86 let input: DeriveInput = input.parse()?;
87
88 let data = Self::get_data_struct(&input)?;
89 let fields = Self::get_fields(data)?;
90
91 Ok(StructInfo {
92 derived: input,
93 fields
94 })
95 }
96}
97
/// Builds one compile-time assertion per field, checking that
/// `::core::mem::offset_of!` agrees with the offset declared for that field.
///
/// The struct is referenced by its bare identifier only; no generic arguments are supplied.
///
/// A field without an identifier produces a spanned `compile_error!` in place of its
/// assertion instead of panicking inside the macro.
#[cfg(feature = "offset_of")]
fn generate_field_offset_checks(struct_info: &StructInfo) -> Vec<proc_macro2::TokenStream> {
    let struct_ident = &struct_info.derived.ident;
    struct_info
        .fields
        .iter()
        .map(|field| {
            let Some(ident) = field.field.ident.as_ref() else {
                return SynError::new_spanned(
                    &field.field,
                    "Field offset checks require named fields."
                )
                .to_compile_error();
            };
            let offset = field.absolute_offset;
            quote! {
                const _:() = assert!(::core::mem::offset_of!(#struct_ident, #ident) == #offset);
            }
        })
        .collect::<Vec<_>>()
}
115
116#[cfg(not(feature = "offset_of"))]
117fn generate_field_offset_checks(_: &StructInfo) -> Vec<proc_macro2::TokenStream> {
118 vec![]
119}
120
121#[proc_macro_attribute]
163pub fn memory_layout(attr: TokenStream, input: TokenStream) -> TokenStream {
164 let struct_info = parse_macro_input!(input as StructInfo);
165
166 let attr_value = parse_macro_input!(attr as Option<LitInt>);
167
168 let desired_size = if let Some(lit) = attr_value {
169 let Ok(r) = lit.base10_parse::<usize>() else {
170 return quote_spanned!(
171 lit.span() =>
172 compile_error!("Desired size must be a valid usize");
173 )
174 .into();
175 };
176 Some(r)
177 } else {
178 None
179 };
180
181 let mut fields = struct_info
182 .fields
183 .iter()
184 .enumerate()
185 .map(|(i, f)| {
186 let ident = f.field.ident.as_ref().unwrap();
187 let typename = &f.field.ty;
188 let vis = &f.field.vis;
189 let relative_offset = f.relative_offset;
190 let previous_type = &f.previous_type;
191 let pad_ident = syn::Ident::new(&format!("__pad{}", i), ident.span());
192 let attrs = f
193 .field
194 .attrs
195 .iter()
196 .filter(|attr| !attr.path().is_ident("field_offset"));
197 match previous_type {
198 Some(ty) => {
199 quote! {
200 #[doc(hidden)]
201 #pad_ident: [u8; #relative_offset - ::core::mem::size_of::<#ty>()],
202 #(#attrs)*
203 #vis #ident: #typename
204 }
205 }
206 None => {
207 quote! {
208 #[doc(hidden)]
209 #pad_ident: [u8; #relative_offset],
210 #(#attrs)*
211 #vis #ident: #typename
212 }
213 }
214 }
215 })
216 .collect::<Vec<_>>();
217
218 if let Some(size) = desired_size {
219 let pad_ident = syn::Ident::new(
220 &format!("__pad{}", struct_info.fields.len()),
221 Span::call_site()
222 );
223 if let Some(last_field) = struct_info.fields.last() {
224 let last_offset = last_field.absolute_offset;
225 let prev_type = last_field.field.ty.clone();
226 let Some(required_padding) = size.checked_sub(last_offset) else {
227 return quote!(
228 compile_error!("Desired struct size is lower than the highest field offset.");
229 )
230 .into();
231 };
232
233 fields.push(quote! {
234 #[doc(hidden)]
235 #pad_ident: [u8; #required_padding - ::core::mem::size_of::<#prev_type>()],
236 })
237 } else {
238 fields.push(quote! {
239 #[doc(hidden)]
240 #pad_ident: [u8; #size],
241 })
242 }
243 }
244
245 let struct_ident = &struct_info.derived.ident;
246 let struct_size_check = desired_size.map(|size| {
247 quote! {
248 const _:() = assert!(::core::mem::size_of::<#struct_ident>() == #size);
249 }
250 });
251
252 let field_offset_checks = generate_field_offset_checks(&struct_info);
253
254 let name = struct_info.derived.ident;
255 let vis = struct_info.derived.vis;
256 let attrs = struct_info.derived.attrs;
257 let generics = struct_info.derived.generics;
258
259 quote! {
260 #(#attrs)*
261 #vis struct #name #generics {
262 #(#fields),*
263 }
264
265 #(#field_offset_checks)*
266 #struct_size_check
267 }
268 .into()
269}