rkyv_derive 0.8.16

Derive macro for rkyv
Documentation
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
    parse_quote, punctuated::Punctuated, spanned::Spanned, Data, DeriveInput,
    Error, Fields, Generics, Ident, Index, Path, WhereClause,
};

use crate::{
    attributes::{Attributes, FieldAttributes, VariantAttributes},
    util::{strip_generics_from_path, strip_raw},
};

/// Entry point for `#[derive(Serialize)]`.
///
/// Parses the container-level `#[rkyv(...)]` attributes from `input` and, if
/// that succeeds, generates the serialize impl for the type.
///
/// # Errors
///
/// Propagates any attribute-parsing or code-generation error.
pub fn derive(input: DeriveInput) -> Result<TokenStream, Error> {
    Attributes::parse(&input)
        .and_then(|attributes| derive_serialize_impl(input, &attributes))
}

fn derive_serialize_impl(
    mut input: DeriveInput,
    attributes: &Attributes,
) -> Result<TokenStream, Error> {
    let rkyv_path = attributes.crate_path();

    let where_clause = input.generics.make_where_clause();
    if let Some(ref bounds) = attributes.archive_bounds {
        for bound in bounds {
            where_clause.predicates.push(bound.clone());
        }
    }
    if let Some(ref bounds) = attributes.serialize_bounds {
        for bound in bounds {
            where_clause.predicates.push(bound.clone());
        }
    }

    let mut impl_input_params = Punctuated::default();
    impl_input_params
        .push(parse_quote! { __S: #rkyv_path::rancor::Fallible + ?Sized });
    for param in input.generics.params.iter() {
        impl_input_params.push(param.clone());
    }
    let impl_input_generics = Generics {
        lt_token: Some(Default::default()),
        params: impl_input_params,
        gt_token: Some(Default::default()),
        where_clause: input.generics.where_clause.clone(),
    };

    let name = &input.ident;
    let (impl_generics, ..) = impl_input_generics.split_for_impl();
    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
    let where_clause = where_clause.unwrap();

    let resolver = attributes.resolver.as_ref().map_or_else(
        || Ident::new(&format!("{}Resolver", strip_raw(name)), name.span()),
        |value| value.clone(),
    );

    let mut serialize_where = where_clause.clone();

    if let Some(ref remote) = attributes.remote {
        let body = generate_serialize_body(
            &input,
            attributes,
            &mut serialize_where,
            &rkyv_path,
            resolver,
            strip_generics_from_path(remote.clone()),
        )?;

        Ok(quote! {
            #[automatically_derived]
            impl #impl_generics #rkyv_path::with::SerializeWith<#remote, __S>
                for #name #ty_generics
            #serialize_where
            {
                fn serialize_with(
                    field: &#remote,
                    serializer: &mut __S,
                ) -> ::core::result::Result<
                    <Self as #rkyv_path::with::ArchiveWith<#remote>>::Resolver,
                    <__S as #rkyv_path::rancor::Fallible>::Error,
                > {
                    let __this = field;
                    #body
                }
            }
        })
    } else {
        let body = generate_serialize_body(
            &input,
            attributes,
            &mut serialize_where,
            &rkyv_path,
            resolver,
            parse_quote!(#name),
        )?;

        Ok(quote! {
            #[automatically_derived]
            impl #impl_generics #rkyv_path::Serialize<__S>
                for #name #ty_generics
            #serialize_where
            {
                fn serialize(
                    &self,
                    serializer: &mut __S,
                ) -> ::core::result::Result<
                    <Self as #rkyv_path::Archive>::Resolver,
                    <__S as #rkyv_path::rancor::Fallible>::Error,
                > {
                    let __this = self;
                    #body
                }
            }
        })
    }
}

fn generate_serialize_body(
    input: &DeriveInput,
    attributes: &Attributes,
    serialize_where: &mut WhereClause,
    rkyv_path: &Path,
    resolver: Ident,
    name: Path,
) -> Result<TokenStream, Error> {
    let this = Ident::new("__this", Span::call_site());
    let body = match input.data {
        Data::Struct(ref data) => match data.fields {
            Fields::Named(ref fields) => {
                let resolver_values = fields
                    .named
                    .iter()
                    .map(|field| {
                        let field_attrs =
                            FieldAttributes::parse(attributes, field)?;

                        serialize_where.predicates.extend(
                            field_attrs.serialize_bound(rkyv_path, field),
                        );

                        let name = &field.ident;
                        let access_field =
                            field_attrs.access_field(&this, name);
                        let serialize = field_attrs.serialize(rkyv_path, field);
                        Ok(quote! {
                            #name: #serialize(#access_field, serializer)?
                        })
                    })
                    .collect::<Result<Vec<_>, Error>>()?;

                quote! { #resolver { #(#resolver_values,)* } }
            }
            Fields::Unnamed(ref fields) => {
                let resolver_values = fields
                    .unnamed
                    .iter()
                    .enumerate()
                    .map(|(i, field)| {
                        let field_attrs =
                            FieldAttributes::parse(attributes, field)?;

                        serialize_where.predicates.extend(
                            field_attrs.serialize_bound(rkyv_path, field),
                        );

                        let index = Index::from(i);
                        let access_field =
                            field_attrs.access_field(&this, &index);
                        let serialize = field_attrs.serialize(rkyv_path, field);
                        Ok(quote! { #serialize(#access_field, serializer)? })
                    })
                    .collect::<Result<Vec<_>, Error>>()?;

                quote! { #resolver(#(#resolver_values,)*) }
            }
            Fields::Unit => quote! { #resolver },
        },
        Data::Enum(ref data) => {
            let mut other: Option<Path> = None;
            let serialize_arms = data
                .variants
                .iter()
                .map(|v| {
                    if let Some(ref other) = other {
                        return Err(Error::new_spanned(
                            other,
                            "Only the very last variant may be denoted with \
                             `#[rkyv(other)]`",
                        ));
                    }
                    let variant_attrs =
                        VariantAttributes::parse(attributes, v)?;
                    let variant = &v.ident;
                    match v.fields {
                        Fields::Named(ref fields) => {
                            let bindings =
                                fields.named.iter().map(|f| &f.ident);
                            let fields = fields
                                .named
                                .iter()
                                .map(|field| {
                                    let field_attrs = FieldAttributes::parse(
                                        attributes, field,
                                    )?;

                                    serialize_where.predicates.extend(
                                        field_attrs
                                            .serialize_bound(rkyv_path, field),
                                    );

                                    let name = &field.ident;
                                    let serialize =
                                        field_attrs.serialize(rkyv_path, field);
                                    Ok(quote! {
                                        #name: #serialize(#name, serializer)?
                                    })
                                })
                                .collect::<Result<Vec<_>, Error>>()?;
                            Ok(quote! {
                                #name::#variant {
                                    #(#bindings,)*..
                                } => #resolver::#variant {
                                    #(#fields,)*
                                }
                            })
                        }
                        Fields::Unnamed(ref fields) => {
                            let bindings =
                                fields.unnamed.iter().enumerate().map(
                                    |(i, f)| {
                                        Ident::new(&format!("_{i}"), f.span())
                                    },
                                );

                            let fields = fields
                                .unnamed
                                .iter()
                                .enumerate()
                                .map(|(i, field)| {
                                    let field_attrs = FieldAttributes::parse(
                                        attributes, field,
                                    )?;

                                    serialize_where.predicates.extend(
                                        field_attrs
                                            .serialize_bound(rkyv_path, field),
                                    );

                                    let binding = Ident::new(
                                        &format!("_{i}"),
                                        field.span(),
                                    );
                                    let serialize =
                                        field_attrs.serialize(rkyv_path, field);
                                    Ok(quote! {
                                        #serialize(#binding, serializer)?
                                    })
                                })
                                .collect::<Result<Vec<_>, Error>>()?;
                            Ok(quote! {
                                #name::#variant(
                                    #(#bindings,)*..
                                ) => #resolver::#variant(#(#fields,)*)
                            })
                        }
                        Fields::Unit => {
                            if variant_attrs.other.is_some() {
                                other = variant_attrs.other;
                                Ok(quote! { _ => #resolver::#variant })
                            } else {
                                Ok(quote! {
                                    #name::#variant => #resolver::#variant
                                })
                            }
                        }
                    }
                })
                .collect::<Result<Vec<_>, Error>>()?;

            quote! {
                match __this {
                    #(#serialize_arms,)*
                }
            }
        }
        Data::Union(_) => {
            return Err(Error::new_spanned(
                input,
                "Serialize cannot be derived for unions",
            ))
        }
    };

    Ok(quote! { ::core::result::Result::Ok(#body) })
}