use std::collections::HashSet;
use quote::quote;
use syn::parse;
/// Flavour of accessor methods generated for each union field.
#[derive(Clone, Copy)]
enum AccessorKind {
    /// `pod_accessors`: a single infallible `as_<field>` built on `bytemuck::cast_ref`.
    Pod,
    /// `checked_accessors`: a fallible `try_as_<field>` plus a panicking `as_<field>`,
    /// both built on `bytemuck::checked`.
    Checked,
}
/// Parsed arguments of the `#[ffi_union(...)]` attribute.
#[derive(Clone, Copy)]
pub struct Args {
    /// Byte size every field is padded to (`size = <size>`).
    size: usize,
    /// Which accessor methods to generate for every field.
    accessors: AccessorKind,
}
/// Custom keywords recognised inside `#[ffi_union(...)]`.
mod kw {
    syn::custom_keyword!(checked_accessors);
    syn::custom_keyword!(pod_accessors);
    syn::custom_keyword!(size);
}
impl syn::parse::Parse for Args {
fn parse(input: syn::parse::ParseStream<'_>) -> parse::Result<Self> {
fn try_parse_args(input: syn::parse::ParseStream<'_>) -> parse::Result<Args> {
let size = if input.peek(kw::size) {
input.parse::<kw::size>()?;
input.parse::<syn::Token![=]>()?;
let size = input.parse::<syn::LitInt>()?;
size.base10_parse::<usize>()?
} else {
return Err(parse::Error::new(
proc_macro2::Span::call_site(),
"unexpected tokens, expected size specifier",
));
};
input.parse::<syn::Token![,]>()?;
let accessors = if input.peek(kw::pod_accessors) {
input.parse::<kw::pod_accessors>()?;
AccessorKind::Pod
} else if input.peek(kw::checked_accessors) {
input.parse::<kw::checked_accessors>()?;
AccessorKind::Checked
} else {
return Err(parse::Error::new(
proc_macro2::Span::call_site(),
"unexpected tokens, expected accessor kind",
));
};
if !input.is_empty() {
return Err(parse::Error::new(
proc_macro2::Span::call_site(),
"unexpected tokens at end",
));
}
Ok(Args { size, accessors })
}
try_parse_args(input)
.map_err(|e| parse::Error::new(
e.span(),
format!("{e}:\n\nexpected #[ffi_union(size = <size>, <pod_accessors | checked_accessors>)]"),
))
}
}
pub struct ExpansionExtras {
pub extras: Vec<proc_macro2::TokenStream>,
pub accessors: Vec<proc_macro2::TokenStream>,
}
pub fn expand(input: &mut syn::ItemUnion, args: Args) -> parse::Result<ExpansionExtras> {
let size = args.size;
let mut extras = Vec::new();
let mut accessors = Vec::new();
let mut processed_types = HashSet::new();
for field in input.fields.named.iter_mut() {
let orig_field_ty = field.ty.clone();
let should_impl_no_padding = processed_types.insert(orig_field_ty.clone());
let field_ident = field.ident.as_ref().unwrap();
let const_ident = quote::format_ident!("__{}_{}_PADDING", &input.ident, field_ident);
let new_field_ty =
syn::parse2(quote!(crate::TransparentPad<#orig_field_ty, #const_ident>))?;
extras.push(quote! {
#[allow(non_upper_case_globals)]
const #const_ident: usize = #size - ::core::mem::size_of::<#orig_field_ty>();
});
if should_impl_no_padding {
let size_check_type_ident =
quote::format_ident!("PaddedField_{}_PaddedToAlign", field_ident);
extras.push(quote! {
const _: fn() = || {
#[repr(transparent)]
#[allow(non_camel_case_types)]
struct #size_check_type_ident(#new_field_ty);
let _ = ::core::mem::transmute::<#size_check_type_ident, [u8; #size]>;
};
unsafe impl ::bytemuck::NoUninit for #new_field_ty {}
});
}
let accessor_ident = quote::format_ident!("as_{}", field_ident);
match args.accessors {
AccessorKind::Pod => {
let accessor_doc = format!("Access `self` as `{field_ident}`.");
accessors.push(quote!(
#[doc = #accessor_doc]
#[inline]
pub fn #accessor_ident(&self) -> &#orig_field_ty {
&(::bytemuck::cast_ref::<Self, #new_field_ty>(self).0)
}
));
}
AccessorKind::Checked => {
let try_accessor_ident = quote::format_ident!("try_as_{}", field_ident);
let try_accessor_doc = format!("Attempt to access `self` as `{field_ident}`.\n\n\
Will succeed if the cast is safe, even if `{field_ident}` was not the logically inhabited form\
of `self`, i.e. does not check that `self` was last written as `{field_ident}` as long\
as the underlying memory is *safe* (not necessarily logical) to interpret as `{field_ident}`.");
let accessor_doc = format!(
"Access `self` as `{field_ident}`. Same conditions as [`{try_accessor_ident}`]\
but panics on failure."
);
accessors.push(quote!(
#[doc = #try_accessor_doc]
#[inline]
pub fn #try_accessor_ident(&self) -> ::core::result::Result<&#orig_field_ty, ::bytemuck::checked::CheckedCastError> {
::bytemuck::checked::try_cast_ref::<Self, #new_field_ty>(self).map(|padded| &padded.0)
}
#[doc = #accessor_doc]
#[inline]
pub fn #accessor_ident(&self) -> &#orig_field_ty {
&::bytemuck::checked::cast_ref::<Self, #new_field_ty>(self).0
}
));
}
}
field.ty = new_field_ty;
}
let union_ident = &input.ident;
extras.push(quote! {
unsafe impl ::bytemuck::AnyBitPattern for #union_ident {}
unsafe impl ::bytemuck::Zeroable for #union_ident {}
unsafe impl ::bytemuck::NoUninit for #union_ident {}
});
Ok(ExpansionExtras { extras, accessors })
}