extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{Fields, ItemStruct, parse_macro_input};
#[proc_macro_attribute]
pub fn assert_no_padding(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let expanded = emit_no_padding_assertion(&input);
TokenStream::from(quote! {
#input
#expanded
})
}
fn emit_no_padding_assertion(input: &ItemStruct) -> TokenStream2 {
let struct_name = &input.ident;
let const_ident = format_ident!(
"_PADLOCK_ASSERT_NO_PADDING_{}",
struct_name.to_string().to_uppercase()
);
let field_types: Vec<_> = match &input.fields {
Fields::Named(nf) => nf.named.iter().map(|f| &f.ty).collect(),
Fields::Unnamed(uf) => uf.unnamed.iter().map(|f| &f.ty).collect(),
Fields::Unit => {
return quote! {
const #const_ident: () = ();
};
}
};
if field_types.is_empty() {
return quote! {
const #const_ident: () = ();
};
}
let field_sizes = field_types.iter().map(|ty| {
quote! { ::std::mem::size_of::<#ty>() }
});
quote! {
const #const_ident: () = {
let struct_size = ::std::mem::size_of::<#struct_name>();
let field_sum: usize = 0 #( + #field_sizes )*;
assert!(
struct_size == field_sum,
concat!(
"padlock: struct `",
stringify!(#struct_name),
"` has padding — size_of != sum of field sizes. ",
"Reorder fields by descending alignment or add #[repr(packed)]."
)
);
};
}
}
#[proc_macro_attribute]
pub fn assert_size(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let expected: syn::LitInt = match syn::parse(attr) {
Ok(n) => n,
Err(e) => return e.to_compile_error().into(),
};
let struct_name = &input.ident;
let const_ident = format_ident!(
"_PADLOCK_ASSERT_SIZE_{}",
struct_name.to_string().to_uppercase()
);
let expanded = quote! {
#input
const #const_ident: () = {
let actual = ::std::mem::size_of::<#struct_name>();
let expected: usize = #expected;
assert!(
actual == expected,
concat!(
"padlock: struct `",
stringify!(#struct_name),
"` has unexpected size. Check for accidental padding or field additions."
)
);
};
};
TokenStream::from(expanded)
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Expands the padding assertion for `item` and renders the tokens as a string.
    /// Every test below inspects this string.
    fn expand(item: &ItemStruct) -> String {
        emit_no_padding_assertion(item).to_string()
    }

    #[test]
    fn no_padding_assertion_for_unit_struct_is_empty_const() {
        let unit: ItemStruct = parse_quote! { struct Unit; };
        let s = expand(&unit);
        // A unit struct should yield the trivial `()` constant with no size comparison.
        assert!(s.contains("()"));
        assert!(!s.contains("size_of"));
    }

    #[test]
    fn no_padding_assertion_contains_struct_name() {
        let item: ItemStruct = parse_quote! {
            struct MyStruct {
                a: u64,
                b: u32,
            }
        };
        let s = expand(&item);
        // Any casing counts: const ident, type path, or assertion message.
        let names_struct = ["MY_STRUCT", "MyStruct", "my_struct"]
            .iter()
            .any(|needle| s.contains(*needle));
        assert!(names_struct, "expected struct name reference in: {s}");
    }

    #[test]
    fn no_padding_assertion_includes_size_of_fields() {
        let item: ItemStruct = parse_quote! {
            struct Foo {
                a: u8,
                b: u64,
            }
        };
        let s = expand(&item);
        assert!(s.contains("size_of"), "expected size_of in: {s}");
        assert!(s.contains("u8"), "expected u8 in: {s}");
        assert!(s.contains("u64"), "expected u64 in: {s}");
    }

    #[test]
    fn no_padding_assertion_empty_named_fields_is_trivial() {
        let empty: ItemStruct = parse_quote! { struct Empty {} };
        let s = expand(&empty);
        assert!(
            !s.contains("size_of"),
            "empty struct should not generate size_of check"
        );
    }

    #[test]
    fn no_padding_const_name_is_uppercase() {
        let item: ItemStruct = parse_quote! {
            struct FooBar { x: u32 }
        };
        let s = expand(&item);
        assert!(s.contains("FOOBAR"), "expected FOOBAR in const name: {s}");
    }

    #[test]
    fn assert_message_contains_struct_name() {
        let item: ItemStruct = parse_quote! {
            struct Suspect { a: u8, b: u64 }
        };
        let s = expand(&item);
        assert!(
            s.contains("Suspect"),
            "expected struct name in assertion message: {s}"
        );
    }
}