1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
extern crate proc_macro;

use proc_macro2::{TokenStream, TokenTree};
use quote::{quote, quote_spanned};
use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Ident};

/// Derives the `BytesCast` trait; see the trait's documentation for the contract.
///
/// When the type cannot be supported, the expansion is a `compile_error!`
/// carrying the reason instead of an impl.
#[proc_macro_derive(BytesCast)]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as DeriveInput);
    let expansion = match derive_inner(parsed) {
        Ok(tokens) => tokens,
        // Surface the static message to the user as a compiler error.
        Err(message) => quote!( compile_error!(#message); ),
    };
    proc_macro::TokenStream::from(expansion)
}

/// Builds the expansion of `#[derive(BytesCast)]` for `input`.
///
/// On success the returned tokens contain a block of compile-time assertions
/// followed by `unsafe impl BytesCast for <Type> {}`. The assertions make
/// compilation fail unless:
///
/// * `align_of::<Type>()` is 1,
/// * `size_of::<Type>()` equals the sum of its fields' sizes (no padding), and
/// * every field type implements `BytesCast`.
///
/// Both `BytesCast` references in the output are unqualified paths, so the
/// trait must be in scope where the derive is used.
///
/// # Errors
///
/// Returns a static message (turned into `compile_error!` by `derive`) when the
/// type has any generic parameter (lifetimes included), is not a struct, or is
/// not marked `#[repr(C)]` / `#[repr(transparent)]`.
fn derive_inner(input: DeriveInput) -> Result<TokenStream, &'static str> {
    // The layout assertions below need a concrete type: a generic struct's
    // size and alignment depend on its parameters.
    if !input.generics.params.is_empty() {
        return Err(
            "BytesCast cannot be derived for structs containing generic parameters \
                    because the alignment requirement can’t be verified for generic structs",
        );
    }

    let fields = match &input.data {
        Data::Struct(DataStruct { fields, .. }) => fields,
        _ => return Err("Deriving BytesCast is only supported for structs"),
    };

    check_repr(&input).map_err(|()| {
        "BytesCast may give unexpected results without #[repr(C)] or #[repr(transparent)]"
    })?;

    let name = &input.ident;
    let span = name.span();
    let field_types = fields.iter().map(|field| &field.ty);
    // A second copy of the iterator feeds the per-field trait checks.
    let field_types2 = field_types.clone();
    let asserts = quote_spanned!(span =>
        const _: fn() = || {
            let _static_assert_align_1: [(); 1] =
                [(); ::core::mem::align_of::<#name>()];

            // A struct whose size is the sum of the sizes of its fields cannot
            // have padding between fields. This check is somewhat redundant
            // since a `#[repr(C)]` struct with `align_of() == 1` is also known
            // not to have padding
            // (since all fields must also have `align_of() == 1`),
            // but it doesn’t hurt to check.
            #[allow(clippy::identity_op)] // There could be nothing to expand
            let _static_assert_no_padding: [(); ::core::mem::size_of::<#name>()] =
                [(); 0 #( + ::core::mem::size_of::<#field_types>() )*];

            fn _static_assert_is_bytes_cast<T: BytesCast>() {}
            #(
                _static_assert_is_bytes_cast::<#field_types2>();
            )*
        };
    );
    // SAFETY (of the generated impl): `#asserts` fails to compile unless the
    // type has alignment 1, no padding, a C-like repr, and only `BytesCast`
    // fields.
    Ok(quote! {
        #asserts
        unsafe impl BytesCast for #name {}
    })
}

fn check_repr(input: &DeriveInput) -> Result<(), ()> {
    for attr in &input.attrs {
        if let (Some(path_ident), Some(inner_ident)) = (
            attr.path.get_ident(),
            get_ident_from_stream(attr.tokens.clone()),
        ) {
            if path_ident == "repr" {
                return if inner_ident == "C" || inner_ident == "transparent" {
                    Ok(())
                } else {
                    Err(())
                };
            }
        }
    }
    Err(())
}

/// Returns the first identifier of `tokens`, unwrapping leading delimited groups.
///
/// Yields `None` if the stream (or an unwrapped group) is empty, or if its
/// first token is punctuation or a literal.
fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
    let mut stream = tokens;
    loop {
        stream = match stream.into_iter().next()? {
            TokenTree::Ident(ident) => return Some(ident),
            // Look inside `( … )`, `[ … ]`, `{ … }` for its leading token.
            TokenTree::Group(group) => group.stream(),
            _ => return None,
        };
    }
}