extern crate proc_macro;
mod traits;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Result};
use crate::traits::{
bytemuck_crate_name, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable,
NoUninit, Pod, TransparentWrapper, Zeroable,
};
/// Entry point for `#[derive(Pod)]`.
///
/// Parses the annotated item as a [`DeriveInput`] (a parse failure is
/// returned as a compile error by `parse_macro_input!`) and hands it to
/// `derive_marker_trait` specialised for the `Pod` marker.
#[proc_macro_derive(Pod, attributes(bytemuck))]
pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  let ast = parse_macro_input!(input as DeriveInput);
  derive_marker_trait::<Pod>(ast).into()
}
/// Entry point for `#[derive(AnyBitPattern)]`.
///
/// Parses the annotated item as a [`DeriveInput`] (a parse failure is
/// returned as a compile error by `parse_macro_input!`) and hands it to
/// `derive_marker_trait` specialised for the `AnyBitPattern` marker.
#[proc_macro_derive(AnyBitPattern, attributes(bytemuck))]
pub fn derive_anybitpattern(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let ast = parse_macro_input!(input as DeriveInput);
  derive_marker_trait::<AnyBitPattern>(ast).into()
}
/// Entry point for `#[derive(Zeroable)]`.
///
/// Accepts the `bytemuck` and `zeroable` helper attributes on the item.
/// Parses the annotated item as a [`DeriveInput`] (a parse failure is
/// returned as a compile error by `parse_macro_input!`) and hands it to
/// `derive_marker_trait` specialised for the `Zeroable` marker.
#[proc_macro_derive(Zeroable, attributes(bytemuck, zeroable))]
pub fn derive_zeroable(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let ast = parse_macro_input!(input as DeriveInput);
  derive_marker_trait::<Zeroable>(ast).into()
}
/// Entry point for `#[derive(NoUninit)]`.
///
/// Parses the annotated item as a [`DeriveInput`] (a parse failure is
/// returned as a compile error by `parse_macro_input!`) and hands it to
/// `derive_marker_trait` specialised for the `NoUninit` marker.
///
/// The `bytemuck` helper attribute is declared here, as it already is for
/// `Pod`, `AnyBitPattern`, `Zeroable` and `TransparentWrapper`, because this
/// derive runs through the same `bytemuck_crate_name` lookup.
/// NOTE(review): presumably that lookup reads a `#[bytemuck(crate = "...")]`
/// override from the item's attributes -- confirm in `traits`. Without the
/// declaration such an attribute is rejected by the compiler whenever this is
/// the only bytemuck derive on the item.
#[proc_macro_derive(NoUninit, attributes(bytemuck))]
pub fn derive_no_uninit(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let ast = parse_macro_input!(input as DeriveInput);
  derive_marker_trait::<NoUninit>(ast).into()
}
/// Entry point for `#[derive(CheckedBitPattern)]`.
///
/// Parses the annotated item as a [`DeriveInput`] (a parse failure is
/// returned as a compile error by `parse_macro_input!`) and hands it to
/// `derive_marker_trait` specialised for the `CheckedBitPattern` marker.
///
/// The `bytemuck` helper attribute is declared here, as it already is for
/// `Pod`, `AnyBitPattern`, `Zeroable` and `TransparentWrapper`, because this
/// derive runs through the same `bytemuck_crate_name` lookup.
/// NOTE(review): presumably that lookup reads a `#[bytemuck(crate = "...")]`
/// override from the item's attributes -- confirm in `traits`. Without the
/// declaration such an attribute is rejected by the compiler whenever this is
/// the only bytemuck derive on the item.
#[proc_macro_derive(CheckedBitPattern, attributes(bytemuck))]
pub fn derive_maybe_pod(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let ast = parse_macro_input!(input as DeriveInput);
  derive_marker_trait::<CheckedBitPattern>(ast).into()
}
/// Entry point for `#[derive(TransparentWrapper)]`.
///
/// Accepts the `bytemuck` and `transparent` helper attributes on the item.
/// Parses the annotated item as a [`DeriveInput`] (a parse failure is
/// returned as a compile error by `parse_macro_input!`) and hands it to
/// `derive_marker_trait` specialised for the `TransparentWrapper` marker.
#[proc_macro_derive(TransparentWrapper, attributes(bytemuck, transparent))]
pub fn derive_transparent(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let ast = parse_macro_input!(input as DeriveInput);
  derive_marker_trait::<TransparentWrapper>(ast).into()
}
/// Entry point for `#[derive(Contiguous)]`.
///
/// Parses the annotated item as a [`DeriveInput`] (a parse failure is
/// returned as a compile error by `parse_macro_input!`) and hands it to
/// `derive_marker_trait` specialised for the `Contiguous` marker.
///
/// The `bytemuck` helper attribute is declared here, as it already is for
/// `Pod`, `AnyBitPattern`, `Zeroable` and `TransparentWrapper`, because this
/// derive runs through the same `bytemuck_crate_name` lookup.
/// NOTE(review): presumably that lookup reads a `#[bytemuck(crate = "...")]`
/// override from the item's attributes -- confirm in `traits`. Without the
/// declaration such an attribute is rejected by the compiler whenever this is
/// the only bytemuck derive on the item.
#[proc_macro_derive(Contiguous, attributes(bytemuck))]
pub fn derive_contiguous(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let ast = parse_macro_input!(input as DeriveInput);
  derive_marker_trait::<Contiguous>(ast).into()
}
/// Entry point for `#[derive(ByteEq)]`.
///
/// Emits `PartialEq` and `Eq` impls for the annotated type in which two
/// values are equal exactly when the byte slices returned by the bytemuck
/// crate's `bytes_of` for each of them are equal.
///
/// Only the type's identifier is used in the generated `impl` headers, so
/// generic parameters on the annotated type are not carried over.
///
/// The `bytemuck` helper attribute is declared because this derive calls
/// `bytemuck_crate_name` on the input, just like the marker-trait derives
/// that already declare it.
/// NOTE(review): presumably that lookup reads a `#[bytemuck(crate = "...")]`
/// override from the item's attributes -- confirm in `traits`. Without the
/// declaration such an attribute is rejected by the compiler whenever this is
/// the only bytemuck derive on the item.
#[proc_macro_derive(ByteEq, attributes(bytemuck))]
pub fn derive_byte_eq(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let input = parse_macro_input!(input as DeriveInput);
  // Resolve the crate path before `input.ident` is moved out of `input`.
  let crate_name = bytemuck_crate_name(&input);
  let ident = input.ident;

  let expanded = quote! {
    impl ::core::cmp::PartialEq for #ident {
      #[inline]
      #[must_use]
      fn eq(&self, other: &Self) -> bool {
        #crate_name::bytes_of(self) == #crate_name::bytes_of(other)
      }
    }
    impl ::core::cmp::Eq for #ident { }
  };
  expanded.into()
}
/// Entry point for `#[derive(ByteHash)]`.
///
/// Emits a `Hash` impl for the annotated type that feeds the value's bytes
/// (via the bytemuck crate's `bytes_of`) to the hasher. `hash_slice` is
/// overridden as well so that a whole slice of values is hashed through a
/// single `cast_slice::<_, u8>` view.
///
/// Only the type's identifier is used in the generated `impl` header, so
/// generic parameters on the annotated type are not carried over.
///
/// The `bytemuck` helper attribute is declared because this derive calls
/// `bytemuck_crate_name` on the input, just like the marker-trait derives
/// that already declare it.
/// NOTE(review): presumably that lookup reads a `#[bytemuck(crate = "...")]`
/// override from the item's attributes -- confirm in `traits`. Without the
/// declaration such an attribute is rejected by the compiler whenever this is
/// the only bytemuck derive on the item.
#[proc_macro_derive(ByteHash, attributes(bytemuck))]
pub fn derive_byte_hash(
  input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
  let input = parse_macro_input!(input as DeriveInput);
  // Resolve the crate path before `input.ident` is moved out of `input`.
  let crate_name = bytemuck_crate_name(&input);
  let ident = input.ident;

  let expanded = quote! {
    impl ::core::hash::Hash for #ident {
      #[inline]
      fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
        ::core::hash::Hash::hash_slice(#crate_name::bytes_of(self), state)
      }
      #[inline]
      fn hash_slice<H: ::core::hash::Hasher>(data: &[Self], state: &mut H) {
        ::core::hash::Hash::hash_slice(#crate_name::cast_slice::<_, u8>(data), state)
      }
    }
  };
  expanded.into()
}
/// Runs the fallible derive for `Trait` and always yields a token stream.
///
/// On success the generated impl tokens are returned; on failure the
/// `syn::Error` is turned into `compile_error!` tokens so the problem is
/// reported at the offending span instead of panicking inside the macro.
fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
  match derive_marker_trait_inner::<Trait>(input) {
    Ok(tokens) => tokens,
    Err(error) => error.into_compile_error(),
  }
}
fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
example_value: &str, invalid_value_msg: &str,
) -> Result<Vec<P::Output>> {
let invalid_format_msg =
format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
syn::Meta::Path(path) => path
.is_ident(name)
.then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
syn::Meta::NameValue(namevalue) => {
namevalue.path.is_ident(name).then(|| {
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
})
}
syn::Meta::List(list) => list.path.is_ident(name).then(|| {
let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
.map_err(|_| {
syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
})?;
if namevalue.path.is_ident(key) {
match namevalue.value {
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(strlit), ..
}) => Ok(strlit),
_ => {
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
}
}
} else {
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
}
}),
});
values_to_check
.map(|lit| {
let lit = lit?;
lit.parse_with(parser).map_err(|err| {
syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
})
})
.collect()
}
/// Builds the `unsafe impl` for `Trait` on the derive input, or returns the
/// first error found along the way.
///
/// Where-clause handling:
/// * If `Trait` names an explicit-bounds attribute and the input carries at
///   least one such attribute, the user-written predicates are appended to
///   the where clause, followed by a `FieldType: Trait` predicate for every
///   field. This mode is only available for structs; unions and enums are
///   rejected.
/// * Otherwise every generic type parameter gets a `T: Trait` predicate via
///   `add_trait_marker`.
///
/// The emitted tokens are, in order: the trait's assertions, any extra items
/// from `Trait::trait_impl`, the `unsafe impl` itself (its where clause is
/// dropped when `Trait::requires_where_clause()` is false) and, when
/// `Trait::implies_trait` returns a trait, an additional empty `unsafe impl`
/// of that trait which always keeps the where clause.
fn derive_marker_trait_inner<Trait: Derivable>(
  mut input: DeriveInput,
) -> Result<TokenStream> {
  let crate_name = bytemuck_crate_name(&input);
  let trait_ = Trait::ident(&input, &crate_name)?;

  // Traits without an explicit-bounds attribute behave exactly like an input
  // that carries no such attribute: there are no explicit bounds.
  let explicit_bounds = match Trait::explicit_bounds_attribute_name() {
    Some(attr_name) => find_and_parse_helper_attributes(
      &input.attrs,
      attr_name,
      "bound",
      <syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
      "Type: Trait",
      "invalid where predicate",
    )?,
    None => Vec::new(),
  };

  if explicit_bounds.is_empty() {
    add_trait_marker(&mut input.generics, &trait_);
  } else {
    let predicates = &mut input.generics.make_where_clause().predicates;
    // User-supplied predicates come first...
    predicates.extend(explicit_bounds.into_iter().flatten());

    let fields = match &input.data {
      syn::Data::Struct(data) => data.fields.clone(),
      syn::Data::Union(_) => {
        return Err(syn::Error::new_spanned(
          &trait_,
          "perfect derive is not supported for unions",
        ));
      }
      syn::Data::Enum(_) => {
        return Err(syn::Error::new_spanned(
          &trait_,
          "perfect derive is not supported for enums",
        ));
      }
    };
    // ...followed by one `FieldType: Trait` predicate per field.
    for syn::Field { ty, .. } in fields {
      predicates.push(syn::parse_quote!(
        #ty: #trait_
      ));
    }
  }

  Trait::check_attributes(&input.data, &input.attrs)?;

  let name = &input.ident;
  let (impl_generics, ty_generics, where_clause) =
    input.generics.split_for_impl();

  let asserts = Trait::asserts(&input, &crate_name)?;
  let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input, &crate_name)?;

  // The implied trait, if any, is implemented with the full where clause.
  let implies_trait = match Trait::implies_trait(&crate_name) {
    Some(implied) => {
      quote!(unsafe impl #impl_generics #implied for #name #ty_generics #where_clause {})
    }
    None => quote!(),
  };

  // The main impl only keeps the where clause when the trait asks for it.
  let where_clause =
    if Trait::requires_where_clause() { where_clause } else { None };

  Ok(quote! {
    #asserts
    #trait_impl_extras
    unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
      #trait_impl
    }
    #implies_trait
  })
}
fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
let type_params = generics
.type_params()
.map(|param| ¶m.ident)
.map(|param| {
syn::parse_quote!(
#param: #trait_name
)
})
.collect::<Vec<syn::WherePredicate>>();
generics.make_where_clause().predicates.extend(type_params);
}