ark-api-macros 0.13.0

Macro utilities for the Ark API
Documentation
//! See the documentation on the `ffi_union` macro definition in the root of this crate
//! for more info on the purpose of this macro and how to use it.

use std::collections::HashSet;

use quote::quote;
use syn::parse;

/// Which flavour of field accessor methods the `ffi_union` expansion generates.
#[derive(Clone, Copy)]
enum AccessorKind {
    /// Infallible `as_<field>` accessors built on `bytemuck::cast_ref`.
    Pod,
    /// `try_as_<field>` / `as_<field>` accessors built on `bytemuck::checked`.
    Checked,
}

/// Parsed arguments of the `ffi_union` attribute.
///
/// Written as `#[ffi_union(size = <size>, <pod_accessors | checked_accessors>)]`.
#[derive(Clone, Copy)]
pub struct Args {
    /// Byte size every union field gets padded to (`size = <size>`).
    size: usize,
    /// Accessor flavour to generate (`pod_accessors` or `checked_accessors`).
    accessors: AccessorKind,
}

/// Custom keywords recognized inside the `#[ffi_union(...)]` argument list.
mod kw {
    // `size = <size>`
    syn::custom_keyword!(size);
    // Accessor kinds: `checked_accessors` | `pod_accessors`.
    syn::custom_keyword!(checked_accessors);
    syn::custom_keyword!(pod_accessors);
}

impl syn::parse::Parse for Args {
    fn parse(input: syn::parse::ParseStream<'_>) -> parse::Result<Self> {
        fn try_parse_args(input: syn::parse::ParseStream<'_>) -> parse::Result<Args> {
            let size = if input.peek(kw::size) {
                input.parse::<kw::size>()?;
                input.parse::<syn::Token![=]>()?;
                let size = input.parse::<syn::LitInt>()?;
                size.base10_parse::<usize>()?
            } else {
                return Err(parse::Error::new(
                    proc_macro2::Span::call_site(),
                    "unexpected tokens, expected size specifier",
                ));
            };

            input.parse::<syn::Token![,]>()?;

            let accessors = if input.peek(kw::pod_accessors) {
                input.parse::<kw::pod_accessors>()?;
                AccessorKind::Pod
            } else if input.peek(kw::checked_accessors) {
                input.parse::<kw::checked_accessors>()?;
                AccessorKind::Checked
            } else {
                return Err(parse::Error::new(
                    proc_macro2::Span::call_site(),
                    "unexpected tokens, expected accessor kind",
                ));
            };

            if !input.is_empty() {
                return Err(parse::Error::new(
                    proc_macro2::Span::call_site(),
                    "unexpected tokens at end",
                ));
            }

            Ok(Args { size, accessors })
        }

        try_parse_args(input)
            .map_err(|e| parse::Error::new(
                e.span(),
                format!("{e}:\n\nexpected #[ffi_union(size = <size>, <pod_accessors | checked_accessors>)]"),
            ))
    }
}

/// Additional token streams produced by [`expand`] alongside the rewritten union.
pub struct ExpansionExtras {
    /// Items to emit next to the union: padding constants, size checks and trait impls.
    pub extras: Vec<proc_macro2::TokenStream>,
    /// Accessor methods (they use `&self` / `Self`), meant for an `impl` block of the union.
    pub accessors: Vec<proc_macro2::TokenStream>,
}

/// Modifies the input `syn::ItemUnion` definition and also creates the returned `ExpansionExtras`.
pub fn expand(input: &mut syn::ItemUnion, args: Args) -> parse::Result<ExpansionExtras> {
    let size = args.size;
    let mut extras = Vec::new();
    let mut accessors = Vec::new();
    let mut processed_types = HashSet::new();
    for field in input.fields.named.iter_mut() {
        let orig_field_ty = field.ty.clone();
        let should_impl_no_padding = processed_types.insert(orig_field_ty.clone());
        let field_ident = field.ident.as_ref().unwrap();

        let const_ident = quote::format_ident!("__{}_{}_PADDING", &input.ident, field_ident);
        let new_field_ty =
            syn::parse2(quote!(crate::TransparentPad<#orig_field_ty, #const_ident>))?;
        extras.push(quote! {
            #[allow(non_upper_case_globals)]
            const #const_ident: usize = #size - ::core::mem::size_of::<#orig_field_ty>();

        });

        if should_impl_no_padding {
            let size_check_type_ident =
                quote::format_ident!("PaddedField_{}_PaddedToAlign", field_ident);
            extras.push(quote! {
                const _: fn() = || {
                    #[repr(transparent)]
                    #[allow(non_camel_case_types)]
                    struct #size_check_type_ident(#new_field_ty);
                    let _ = ::core::mem::transmute::<#size_check_type_ident, [u8; #size]>;
                };

                // SAFETY: Derived as part of ffi_union macro, needed requirements (that
                // there is no extra automatic padding added due to alignment rules) are statically
                // asserted by the macro
                unsafe impl ::bytemuck::NoUninit for #new_field_ty {}
            });
        }

        let accessor_ident = quote::format_ident!("as_{}", field_ident);
        match args.accessors {
            AccessorKind::Pod => {
                let accessor_doc = format!("Access `self` as `{field_ident}`.");

                accessors.push(quote!(
                    #[doc = #accessor_doc]
                    #[inline]
                    pub fn #accessor_ident(&self) -> &#orig_field_ty {
                        &(::bytemuck::cast_ref::<Self, #new_field_ty>(self).0)
                    }
                ));
            }
            AccessorKind::Checked => {
                let try_accessor_ident = quote::format_ident!("try_as_{}", field_ident);

                let try_accessor_doc = format!("Attempt to access `self` as `{field_ident}`.\n\n\
                    Will succeed if  the cast is safe, even if `{field_ident}` was not the logically inhabited form\
                    of `self`, i.e. does not check that `self` was last written as `{field_ident}` as long\
                    as the underlying memory is *safe* (not necessarily logical) to interpret as `{field_ident}`.");

                let accessor_doc = format!(
                    "Access `self` as `{field_ident}`. Same conditions as [`{try_accessor_ident}`]\
                    but panics on failure."
                );

                accessors.push(quote!(
                    #[doc = #try_accessor_doc]
                    #[inline]
                    pub fn #try_accessor_ident(&self) -> ::core::result::Result<&#orig_field_ty, ::bytemuck::checked::CheckedCastError> {
                        ::bytemuck::checked::try_cast_ref::<Self, #new_field_ty>(self).map(|padded| &padded.0)
                    }

                    #[doc = #accessor_doc]
                    #[inline]
                    pub fn #accessor_ident(&self) -> &#orig_field_ty {
                        &::bytemuck::checked::cast_ref::<Self, #new_field_ty>(self).0
                    }
                ));
            }
        }

        field.ty = new_field_ty;
    }

    let union_ident = &input.ident;
    extras.push(quote! {
        // SAFETY: All unions are always AnyBitPattern, this is manually implemented
        // just to avoid an extra recursion of proc macros for the derive
        unsafe impl ::bytemuck::AnyBitPattern for #union_ident {}

        // SAFETY: All unions are always Zeroable, this is manually implemented
        // just to avoid an extra recursion of proc macros for the derive
        unsafe impl ::bytemuck::Zeroable for #union_ident {}

        // SAFETY: The derived padded fields will all have the same size and no padding
        // as long as the static asserts below work. Since they are all the same size and
        // have no padding, the union will also never have any padding. It's technically still
        // possible to get uninit memory in the union in a sound manner, but only through using
        // unsafe code. Use of the ffi_union macro makes you assert that you'll never do this,
        // so we can assume it won't happen.
        unsafe impl ::bytemuck::NoUninit for #union_ident {}
    });

    Ok(ExpansionExtras { extras, accessors })
}