1extern crate proc_macro;
8
9use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{format_ident, quote};
12use syn::{Fields, ItemStruct, parse_macro_input};
13
14#[proc_macro_attribute]
53pub fn assert_no_padding(_attr: TokenStream, item: TokenStream) -> TokenStream {
54 let input = parse_macro_input!(item as ItemStruct);
55 let expanded = emit_no_padding_assertion(&input);
56 TokenStream::from(quote! {
57 #input
58 #expanded
59 })
60}
61
62fn emit_no_padding_assertion(input: &ItemStruct) -> TokenStream2 {
63 let struct_name = &input.ident;
64 let const_ident = format_ident!(
65 "_PADLOCK_ASSERT_NO_PADDING_{}",
66 struct_name.to_string().to_uppercase()
67 );
68
69 let field_types: Vec<_> = match &input.fields {
70 Fields::Named(nf) => nf.named.iter().map(|f| &f.ty).collect(),
71 Fields::Unnamed(uf) => uf.unnamed.iter().map(|f| &f.ty).collect(),
72 Fields::Unit => {
73 return quote! {
75 const #const_ident: () = ();
76 };
77 }
78 };
79
80 if field_types.is_empty() {
81 return quote! {
82 const #const_ident: () = ();
83 };
84 }
85
86 let field_sizes = field_types.iter().map(|ty| {
88 quote! { ::std::mem::size_of::<#ty>() }
89 });
90
91 quote! {
98 const #const_ident: () = {
99 let struct_size = ::std::mem::size_of::<#struct_name>();
100 let field_sum: usize = 0 #( + #field_sizes )*;
101 assert!(
102 struct_size == field_sum,
103 concat!(
104 "padlock: struct `",
105 stringify!(#struct_name),
106 "` has padding — size_of != sum of field sizes. ",
107 "Reorder fields by descending alignment or add #[repr(packed)]."
108 )
109 );
110 };
111 }
112}
113
114#[proc_macro_attribute]
130pub fn assert_size(attr: TokenStream, item: TokenStream) -> TokenStream {
131 let input = parse_macro_input!(item as ItemStruct);
132 let expected: syn::LitInt = match syn::parse(attr) {
133 Ok(n) => n,
134 Err(e) => return e.to_compile_error().into(),
135 };
136
137 let struct_name = &input.ident;
138 let const_ident = format_ident!(
139 "_PADLOCK_ASSERT_SIZE_{}",
140 struct_name.to_string().to_uppercase()
141 );
142
143 let expanded = quote! {
144 #input
145
146 const #const_ident: () = {
147 let actual = ::std::mem::size_of::<#struct_name>();
148 let expected: usize = #expected;
149 assert!(
150 actual == expected,
151 concat!(
152 "padlock: struct `",
153 stringify!(#struct_name),
154 "` has unexpected size. Check for accidental padding or field additions."
155 )
156 );
157 };
158 };
159
160 TokenStream::from(expanded)
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Renders the no-padding assertion generated for `item` as a string.
    fn render(item: &ItemStruct) -> String {
        emit_no_padding_assertion(item).to_string()
    }

    #[test]
    fn no_padding_assertion_for_unit_struct_is_empty_const() {
        let item: ItemStruct = parse_quote! { struct Unit; };
        let s = render(&item);
        assert!(s.contains("()"), "expected trivial `()` const in: {s}");
        // Unit structs have nothing to sum, so no size check is emitted.
        assert!(!s.contains("size_of"), "unit struct should not generate size_of check: {s}");
    }

    #[test]
    fn no_padding_assertion_contains_struct_name() {
        let item: ItemStruct = parse_quote! {
            struct MyStruct {
                a: u64,
                b: u32,
            }
        };
        let s = render(&item);
        // The struct is referenced by name in the size_of call...
        assert!(s.contains("MyStruct"), "expected struct name reference in: {s}");
        // ...and the hidden const is named after it (upper-cased, no separators).
        assert!(
            s.contains("_PADLOCK_ASSERT_NO_PADDING_MYSTRUCT"),
            "expected derived const name in: {s}"
        );
    }

    #[test]
    fn no_padding_assertion_includes_size_of_fields() {
        let item: ItemStruct = parse_quote! {
            struct Foo {
                a: u8,
                b: u64,
            }
        };
        let s = render(&item);
        assert!(s.contains("size_of"), "expected size_of in: {s}");
        assert!(s.contains("u8"), "expected u8 in: {s}");
        assert!(s.contains("u64"), "expected u64 in: {s}");
    }

    #[test]
    fn no_padding_assertion_handles_tuple_structs() {
        let item: ItemStruct = parse_quote! { struct Pair(u16, u32); };
        let s = render(&item);
        assert!(s.contains("size_of"), "expected size_of in: {s}");
        assert!(s.contains("u16"), "expected u16 in: {s}");
        assert!(s.contains("u32"), "expected u32 in: {s}");
    }

    #[test]
    fn no_padding_assertion_empty_named_fields_is_trivial() {
        let item: ItemStruct = parse_quote! { struct Empty {} };
        let s = render(&item);
        assert!(
            !s.contains("size_of"),
            "empty struct should not generate size_of check"
        );
    }

    #[test]
    fn no_padding_assertion_empty_tuple_fields_is_trivial() {
        let item: ItemStruct = parse_quote! { struct Nothing(); };
        let s = render(&item);
        assert!(
            !s.contains("size_of"),
            "empty tuple struct should not generate size_of check"
        );
    }

    #[test]
    fn no_padding_const_name_is_uppercase() {
        let item: ItemStruct = parse_quote! {
            struct FooBar { x: u32 }
        };
        let s = render(&item);
        assert!(
            s.contains("_PADLOCK_ASSERT_NO_PADDING_FOOBAR"),
            "expected FOOBAR in const name: {s}"
        );
    }

    #[test]
    fn assert_message_contains_struct_name() {
        let item: ItemStruct = parse_quote! {
            struct Suspect { a: u8, b: u64 }
        };
        let s = render(&item);
        assert!(
            s.contains("Suspect"),
            "expected struct name in assertion message: {s}"
        );
        assert!(
            s.contains("padlock: struct"),
            "expected padlock diagnostic message in: {s}"
        );
    }
}