bytes_cast_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{TokenStream, TokenTree};
4use quote::{quote, quote_spanned};
5use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Ident};
6
7/// Derive the `BytesCast` trait. See trait documentation.
8#[proc_macro_derive(BytesCast)]
9pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
10    derive_inner(parse_macro_input!(input as DeriveInput))
11        .unwrap_or_else(|error| quote!( compile_error!(#error); ))
12        .into()
13}
14
15fn derive_inner(input: DeriveInput) -> Result<TokenStream, &'static str> {
16    if !input.generics.params.is_empty() {
17        return Err(
18            "BytesCast cannot be derived for structs containing generic parameters \
19                    because the alignment requirement can’t be verified for generic structs",
20        )?;
21    }
22
23    let fields = if let Data::Struct(DataStruct { fields, .. }) = &input.data {
24        fields
25    } else {
26        return Err("Deriving BytesCast is only supported for structs");
27    };
28
29    check_repr(&input).map_err(|()| {
30        "BytesCast may give unexpected results without #[repr(C)] or #[repr(transparent)]"
31    })?;
32
33    let name = &input.ident;
34    let span = input.ident.span();
35    let field_types = fields.iter().map(|field| &field.ty);
36    let field_types2 = field_types.clone();
37    let asserts = quote_spanned!(span =>
38        const _: fn() = || {
39            let _static_assert_align_1: [(); 1] =
40                [(); ::core::mem::align_of::<#name>()];
41
42            // A struct whose size is the sum of the sizes of its fields cannot
43            // have padding between fields. This check is somewhat redundant
44            // since a `#[repr(C)]` struct with `align_of() == 1` is also know
45            // not to have padding
46            // (since all fields must also have `align_of() == 1`),
47            // but it doesn’t hurt to check.
48            #[allow(clippy::identity_op)] // There could be nothing to expand
49            let _static_assert_no_padding: [(); ::core::mem::size_of::<#name>()] =
50                [(); 0 #( + ::core::mem::size_of::<#field_types>() )*];
51
52            fn _static_assert_is_bytes_cast<T: BytesCast>() {}
53            #(
54                _static_assert_is_bytes_cast::<#field_types2>();
55            )*
56        };
57    );
58    Ok(quote! {
59        #asserts
60        unsafe impl BytesCast for #name {}
61    })
62}
63
64fn check_repr(input: &DeriveInput) -> Result<(), ()> {
65    for attr in &input.attrs {
66        if let (Some(path_ident), Some(inner_ident)) = (
67            attr.path.get_ident(),
68            get_ident_from_stream(attr.tokens.clone()),
69        ) {
70            if path_ident == "repr" {
71                return if inner_ident == "C" || inner_ident == "transparent" {
72                    Ok(())
73                } else {
74                    Err(())
75                };
76            }
77        }
78    }
79    Err(())
80}
81
82fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
83    match tokens.into_iter().next() {
84        Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
85        Some(TokenTree::Ident(ident)) => Some(ident),
86        _ => None,
87    }
88}