use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::{
Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, GenericParam, Generics,
Ident, Member, Path, Token, TraitBound, Type, TypeParamBound, Visibility, parse_macro_input,
parse_quote,
};
/// The two struct layouts the derives know how to compute offsets for.
#[derive(Copy, Clone)]
enum Repr {
    /// `#[repr(C)]`: fields are laid out in declaration order with padding.
    C,
    /// `#[repr(transparent)]`: every field sits at offset 0.
    Transparent,
}
fn get_repr(attrs: &[Attribute]) -> syn::Result<Repr> {
let mut repr = None;
for attr in attrs {
if !attr.path().is_ident("repr") {
continue;
}
if repr.is_some() {
return Err(syn::Error::new_spanned(
attr,
"only one #[repr(...)] allowed",
));
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("C") {
repr = Some(Repr::C);
Ok(())
} else if meta.path.is_ident("transparent") {
repr = Some(Repr::Transparent);
Ok(())
} else {
Err(meta.error("only #[repr(C)] and #[repr(transparent)] are supported"))
}
})?;
}
let Some(repr) = repr else {
return Err(syn::Error::new(
Span::call_site(),
"type must be #[repr(C)] or #[repr(transparent)]",
));
};
Ok(repr)
}
fn get_fields(
data: &Data,
) -> syn::Result<(
impl Iterator<Item = Member> + Clone,
impl Iterator<Item = &Type> + Clone,
usize,
)> {
Ok(match data {
Data::Struct(DataStruct { fields, .. }) => {
(fields.members(), fields.iter().map(|f| &f.ty), fields.len())
}
Data::Enum(DataEnum { enum_token, .. }) => {
return Err(Error::new_spanned(enum_token, "only structs are supported"));
}
Data::Union(DataUnion { union_token, .. }) => {
return Err(Error::new_spanned(
union_token,
"only structs are supported",
));
}
})
}
/// Options collected from `#[dst(...)]` attributes.
struct DstAttrs {
    /// Path used to refer to the runtime crate in generated code
    /// (`simple_dst_path = ...`; `::simple_dst` when absent).
    simple_dst_path: Path,
    /// Visibility given to the generated `new_unchecked` constructor
    /// (`new_unchecked_vis = ...`; inherited/private when absent).
    new_unchecked_vis: Visibility,
}
/// Parses every `#[dst(...)]` attribute into a [`DstAttrs`].
///
/// Recognised keys are `simple_dst_path = <path>` and
/// `new_unchecked_vis = <visibility>`. Each may be given at most once across
/// all `#[dst]` attributes.
///
/// # Errors
/// A key is repeated, an unknown key is used, or a value fails to parse.
fn get_dst_attrs(attrs: &[Attribute]) -> syn::Result<DstAttrs> {
    /// Consumes `= <value>` following a key inside the attribute.
    fn value_of<T: syn::parse::Parse>(meta: &syn::meta::ParseNestedMeta<'_>) -> syn::Result<T> {
        meta.input.parse::<Token![=]>()?;
        meta.input.parse()
    }

    let mut path_arg: Option<Path> = None;
    let mut vis_arg: Option<Visibility> = None;

    for attr in attrs.iter().filter(|a| a.path().is_ident("dst")) {
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("simple_dst_path") {
                if path_arg.is_some() {
                    return Err(meta.error("only one #[dst(simple_dst_path = ...)] is allowed"));
                }
                path_arg = Some(value_of(&meta)?);
                Ok(())
            } else if meta.path.is_ident("new_unchecked_vis") {
                if vis_arg.is_some() {
                    return Err(meta.error("only one #[dst(new_unchecked_vis = ...)] is allowed"));
                }
                vis_arg = Some(value_of(&meta)?);
                Ok(())
            } else {
                Err(meta.error("unrecognised #[dst(...)] argument"))
            }
        })?;
    }

    Ok(DstAttrs {
        simple_dst_path: path_arg.unwrap_or_else(|| parse_quote! { ::simple_dst }),
        new_unchecked_vis: vis_arg.unwrap_or(Visibility::Inherited),
    })
}
/// Returns `true` if any of `bounds` is a `?Sized` relaxation.
///
/// Bare `?Sized` is recognised, and so are path-qualified spellings such as
/// `?::core::marker::Sized` or `?std::marker::Sized`. Only the final path
/// segment is inspected, because `?` relaxations are only meaningful for
/// `Sized`. Higher-ranked forms (`?for<'a> ...`) are not treated as
/// relaxations.
fn has_unsized_bound<'a>(mut bounds: impl Iterator<Item = &'a TypeParamBound>) -> bool {
    bounds.any(|bound| {
        matches!(
            bound,
            TypeParamBound::Trait(TraitBound {
                modifier: syn::TraitBoundModifier::Maybe(_),
                lifetimes: None,
                path,
                ..
            }) if path
                .segments
                .last()
                .is_some_and(|seg| seg.ident == "Sized" && seg.arguments.is_none())
        )
    })
}
fn add_dst_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(type_param) = param
&& has_unsized_bound(type_param.bounds.iter())
{
type_param
.bounds
.push(parse_quote! { #simple_dst_path::Dst });
type_param
.bounds
.push(parse_quote! { #simple_dst_path::CloneToUninit });
}
}
generics
}
/// `#[derive(Dst)]` entry point. Parse failures and derive errors are turned
/// into `compile_error!` output rather than panics.
#[proc_macro_derive(Dst, attributes(dst))]
pub fn derive_dst(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as DeriveInput);
    match derive_dst_impl(parsed) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// Generates the body (a braced block) of `__dst_impl_layout_offsets(len)`.
///
/// The generated code returns the overall `Layout` of the type for a tail of
/// length `len`, plus the byte offset of each of the `n_fields` fields.
///
/// - `Repr::C` starts from an empty layout and `extend`s it with each field's
///   layout in declaration order, recording offsets. It then pads the result
///   to its alignment.
/// - `Repr::Transparent` takes the tail's layout as-is; all offsets are 0.
///
/// `idxs` must be `0..n_fields`. `first_tys` holds the types of all fields
/// but the last, and `last_ty` is the tail type, which must implement `Dst`.
fn get_internal_layout_fn(
    simple_dst_path: &Path,
    repr: Repr,
    n_fields: usize,
    idxs: &[usize],
    first_tys: &[&Type],
    last_ty: Option<&Type>,
) -> TokenStream {
    match repr {
        // The `,` separator lives *inside* the repetition. A struct whose only
        // field is the tail then expands to `[tail]` instead of the invalid
        // `[, tail]`.
        Repr::C => quote!(
            {
                let layouts = [
                    #(::core::alloc::Layout::new::<#first_tys>(),)*
                    <#last_ty as #simple_dst_path::Dst>::layout(len)?
                ];
                let mut offsets = [0; #n_fields];
                let layout = ::core::alloc::Layout::new::<()>();
                #(
                    let (layout, offset) = layout.extend(layouts[#idxs])?;
                    offsets[#idxs] = offset;
                )*
                ::core::result::Result::Ok((layout.pad_to_align(), offsets))
            }
        ),
        Repr::Transparent => quote!(
            {
                ::core::result::Result::Ok((<#last_ty as #simple_dst_path::Dst>::layout(len)?, [0; #n_fields]))
            }
        ),
    }
}
/// Expands `#[derive(Dst)]` for a `#[repr(C)]` / `#[repr(transparent)]` struct.
///
/// Generates:
/// - `unsafe impl Dst`, where `len` delegates to the last field, `layout`
///   comes from `__dst_impl_layout_offsets`, and `retype` is a slice-pointer
///   metadata cast;
/// - an inherent, hidden `__dst_impl_layout_offsets(len)` helper;
/// - an inherent `unsafe fn new_unchecked`. It takes every leading field by
///   value and the tail by reference, then builds the value through
///   `AllocDst::new_dst`.
///
/// The parameters of `new_unchecked` are named after the struct's fields
/// (`var_N` for tuple fields). Every local the generated body introduces is
/// therefore given a `__dst_` prefix, so a field called e.g. `offsets` or
/// `dest` can never be shadowed by, or shadow, the macro's own variables.
///
/// # Errors
/// Missing/unsupported `repr`, bad `#[dst]` arguments, non-struct input, or
/// a struct with no fields.
fn derive_dst_impl(input: DeriveInput) -> syn::Result<TokenStream> {
    let repr = get_repr(&input.attrs)?;
    let name = input.ident;
    let DstAttrs {
        simple_dst_path,
        new_unchecked_vis,
    } = get_dst_attrs(&input.attrs)?;
    let generics = add_dst_trait_bounds(input.generics, &simple_dst_path);
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let (members, tys, n_fields) = get_fields(&input.data)?;
    if n_fields == 0 {
        return Err(Error::new_spanned(
            name,
            "type must have at least one field",
        ));
    }
    let idxs: Vec<_> = (0..n_fields).collect();
    let last_idx = n_fields - 1;
    let first_idxs: Vec<_> = (0..last_idx).collect();
    let last_member = members.clone().nth(last_idx);
    // Parameter names for `new_unchecked`: field names, or `var_N` for tuple fields.
    let member_var_names: Vec<_> = members
        .clone()
        .map(|m| match m {
            Member::Named(ident) => ident,
            Member::Unnamed(index) => format_ident!("var_{}", index),
        })
        .collect();
    let first_member_var_names: Vec<_> = member_var_names.iter().take(last_idx).collect();
    let last_member_var_name = member_var_names.get(last_idx);
    let first_tys: Vec<_> = tys.clone().take(last_idx).collect();
    let last_ty = tys.clone().nth(last_idx);
    let internal_layout_fn =
        get_internal_layout_fn(&simple_dst_path, repr, n_fields, &idxs, &first_tys, last_ty);
    Ok(quote! {
        #[automatically_derived]
        unsafe impl #impl_generics #simple_dst_path::Dst for #name #ty_generics #where_clause {
            fn len(&self) -> usize {
                #simple_dst_path::Dst::len(&self.#last_member)
            }
            fn layout(len: usize) -> ::core::result::Result<::core::alloc::Layout, ::core::alloc::LayoutError> {
                let (layout, _) = Self::__dst_impl_layout_offsets(len)?;
                ::core::result::Result::Ok(layout)
            }
            fn retype(ptr: ::core::ptr::NonNull<u8>, len: usize) -> ::core::ptr::NonNull<Self> {
                unsafe {
                    #[allow(
                        clippy::cast_ptr_alignment,
                        reason = "the responsibility to provide a pointer with the correct alignment is on the caller"
                    )]
                    ::core::ptr::NonNull::new_unchecked(::core::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len) as *mut Self)
                }
            }
        }
        #[automatically_derived]
        impl #impl_generics #name #ty_generics #where_clause {
            #[doc(hidden)]
            #[inline]
            fn __dst_impl_layout_offsets(len: usize) -> ::core::result::Result<(::core::alloc::Layout, [usize; #n_fields]), ::core::alloc::LayoutError>
            #internal_layout_fn
            #new_unchecked_vis unsafe fn new_unchecked<A: #simple_dst_path::AllocDst<Self>>(
                #( #first_member_var_names: #first_tys, )*
                #last_member_var_name: &#last_ty
            ) -> ::core::result::Result<A, ::core::alloc::LayoutError> {
                // One `Dst::len` value is used both to size the layout and for
                // `new_dst`, so the two can never disagree.
                let __dst_len = <#last_ty as #simple_dst_path::Dst>::len(#last_member_var_name);
                let (__dst_layout, __dst_offsets) = Self::__dst_impl_layout_offsets(__dst_len)?;
                ::core::result::Result::Ok(unsafe {
                    A::new_dst(__dst_len, __dst_layout, |__dst_ptr| {
                        let __dst_dest = __dst_ptr.cast::<u8>();
                        <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(
                            #last_member_var_name,
                            __dst_dest.add(__dst_offsets[#last_idx]).as_ptr(),
                        );
                        #(
                            __dst_dest.add(__dst_offsets[#first_idxs]).cast::<#first_tys>().write(#first_member_var_names);
                        )*
                    })
                })
            }
        }
    })
}
fn add_clone_to_uninit_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(type_param) = param {
let bound = if has_unsized_bound(type_param.bounds.iter()) {
parse_quote! { #simple_dst_path::CloneToUninit }
} else {
parse_quote! { ::core::clone::Clone }
};
type_param.bounds.push(bound);
}
}
generics
}
/// `#[derive(CloneToUninit)]` entry point. Parse failures and derive errors
/// are turned into `compile_error!` output rather than panics.
#[proc_macro_derive(CloneToUninit, attributes(dst))]
pub fn derive_clone_to_uninit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as DeriveInput);
    match derive_clone_to_uninit_impl(parsed) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// Expands `#[derive(CloneToUninit)]` for a struct whose last field
/// implements `CloneToUninit`.
///
/// The generated `clone_to_uninit` does the following:
/// 1. clones every leading field with `Clone`;
/// 2. clones the tail into `dest` at the tail's offset within `self`;
/// 3. writes each cloned leading field at its `offset_of!` position.
///
/// Field offsets are addressed by `Member`, so tuple structs use `0`, `1`, …
/// (a synthesized `var_N` name is not a real field and breaks `offset_of!`).
/// Cloned values are held in positional `__dst_field_N` locals, so a field
/// named `dest` or similar cannot shadow the method's parameter.
///
/// # Errors
/// Bad `#[dst]` arguments, non-struct input, or a struct with no fields.
fn derive_clone_to_uninit_impl(input: DeriveInput) -> syn::Result<TokenStream> {
    let name = input.ident;
    let DstAttrs {
        simple_dst_path, ..
    } = get_dst_attrs(&input.attrs)?;
    let generics = add_clone_to_uninit_trait_bounds(input.generics, &simple_dst_path);
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let (members, tys, n_fields) = get_fields(&input.data)?;
    if n_fields == 0 {
        return Err(Error::new_spanned(
            name,
            "type must have at least one field",
        ));
    }
    let last_idx = n_fields - 1;
    let first_members: Vec<_> = members.clone().take(last_idx).collect();
    let last_member = members.clone().nth(last_idx);
    let first_locals: Vec<_> = (0..last_idx)
        .map(|i| format_ident!("__dst_field_{}", i))
        .collect();
    let first_tys: Vec<_> = tys.clone().take(last_idx).collect();
    let last_ty = tys.clone().nth(last_idx);
    Ok(quote! {
        #[automatically_derived]
        unsafe impl #impl_generics #simple_dst_path::CloneToUninit for #name #ty_generics #where_clause {
            unsafe fn clone_to_uninit(&self, dest: *mut u8) {
                let __dst_last_offset = unsafe { (&raw const self.#last_member).byte_offset_from_unsigned(self) };
                #(
                    let #first_locals = <#first_tys as ::core::clone::Clone>::clone(&self.#first_members);
                )*
                unsafe {
                    <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(&self.#last_member, dest.add(__dst_last_offset));
                    #(
                        dest.add(::core::mem::offset_of!(Self, #first_members)).cast::<#first_tys>().write(#first_locals);
                    )*
                }
            }
        }
    })
}
/// Options collected from `#[to_owned(...)]` attributes.
struct ToOwnedAttrs {
    /// Path to the crate providing `borrow::ToOwned` and `boxed::Box`
    /// (`alloc_path = ...`; `::std` when absent).
    alloc_path: Path,
    /// The `ToOwned::Owned` type (`owned = ...`; a `Box` of the deriving type
    /// when absent).
    owned: Type,
}
/// Parses every `#[to_owned(...)]` attribute into a [`ToOwnedAttrs`].
///
/// Recognised keys (each allowed at most once):
/// - `alloc_path = <path>`: defaults to `::std`;
/// - `owned = <type>`: defaults to `<alloc_path>::boxed::Box<Self>`.
///
/// The default owned type is spelled with `Self` rather than the type's
/// bare name. A bare name would lack generic arguments and fail to compile
/// for generic types. As a result `_name` is no longer needed; it is kept so
/// existing call sites stay unchanged.
///
/// # Errors
/// A key is repeated, an unknown key is used, or a value fails to parse.
fn get_to_owned_attrs(attrs: &[Attribute], _name: &Ident) -> syn::Result<ToOwnedAttrs> {
    /// Consumes `= <value>` following a key inside the attribute.
    fn value_of<T: syn::parse::Parse>(meta: &syn::meta::ParseNestedMeta<'_>) -> syn::Result<T> {
        meta.input.parse::<Token![=]>()?;
        meta.input.parse()
    }

    let mut alloc_path: Option<Path> = None;
    let mut owned: Option<Type> = None;

    for attr in attrs.iter().filter(|a| a.path().is_ident("to_owned")) {
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("alloc_path") {
                if alloc_path.is_some() {
                    return Err(meta.error("only one #[to_owned(alloc_path = ...)] is allowed"));
                }
                alloc_path = Some(value_of(&meta)?);
                Ok(())
            } else if meta.path.is_ident("owned") {
                if owned.is_some() {
                    return Err(meta.error("only one #[to_owned(owned = ...)] is allowed"));
                }
                owned = Some(value_of(&meta)?);
                Ok(())
            } else {
                Err(meta.error("unrecognised #[to_owned(...)] argument"))
            }
        })?;
    }

    let alloc_path = alloc_path.unwrap_or_else(|| parse_quote! { ::std });
    let owned = owned.unwrap_or_else(|| parse_quote! { #alloc_path::boxed::Box<Self> });
    Ok(ToOwnedAttrs { alloc_path, owned })
}
/// `#[derive(ToOwned)]` entry point. Parse failures and derive errors are
/// turned into `compile_error!` output rather than panics.
#[proc_macro_derive(ToOwned, attributes(dst, to_owned))]
pub fn derive_to_owned(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as DeriveInput);
    match derive_to_owned_impl(parsed) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
fn derive_to_owned_impl(input: DeriveInput) -> syn::Result<TokenStream> {
let name = input.ident;
let DstAttrs {
simple_dst_path, ..
} = get_dst_attrs(&input.attrs)?;
let ToOwnedAttrs { alloc_path, owned } = get_to_owned_attrs(&input.attrs, &name)?;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Ok(quote! {
#[automatically_derived]
impl #impl_generics #alloc_path::borrow::ToOwned for #name #ty_generics #where_clause {
type Owned = #owned;
fn to_owned(&self) -> Self::Owned {
let layout = ::core::alloc::Layout::for_value(self);
unsafe {
<#owned as #simple_dst_path::AllocDst<#name>>::new_dst(
<#name as #simple_dst_path::Dst>::len(self),
layout,
|ptr| {
let dest = ptr.cast::<u8>();
<#name as #simple_dst_path::CloneToUninit>::clone_to_uninit(self, dest.as_ptr())
},
)
}
}
}
})
}