serde_dis 0.1.3

derive macro for enums
Documentation
// copyright 2023 Remi Bernotavicius

#![doc = include_str!("../README.md")]

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::spanned::Spanned as _;
use syn::{
    parse_macro_input, parse_quote, Arm, Attribute, ConstParam, Data, DataEnum, DeriveInput, Error,
    Expr, FieldValue, Fields, GenericArgument, GenericParam, Generics, Ident, ItemImpl,
    LifetimeParam, LitStr, Pat, Result, Token, Type, TypeParam, TypeParamBound, WhereClause,
};

fn find_repr(attrs: &[Attribute]) -> Result<Ident> {
    let repr_attr = attrs
        .iter()
        .find(|a| a.path().is_ident("repr"))
        .ok_or(Error::new(Span::call_site(), "missing repr attribute"))?;

    repr_attr.parse_args()
}

/// Looks for `#[serde(rename = "...")]` among `attrs`.
///
/// Returns the string literal from the first `serde` attribute that carries a `rename` key,
/// or `None` when no such key is found. Keys other than `rename` are left untouched.
///
/// # Errors
///
/// Returns the parse error (instead of panicking inside the macro) when a `serde` attribute
/// cannot be parsed, e.g. when `rename` is not followed by `= "literal"` or when an
/// unrecognized key carries a value.
fn find_rename(attrs: &[Attribute]) -> Result<Option<LitStr>> {
    for attr in attrs.iter().filter(|a| a.path().is_ident("serde")) {
        let mut rename = None;
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("rename") {
                let _: Token![=] = meta.input.parse()?;
                rename = Some(meta.input.parse()?);
            }
            Ok(())
        })?;
        // Stop at the first attribute that supplied a name; later attributes are not parsed.
        if rename.is_some() {
            return Ok(rename);
        }
    }
    Ok(None)
}

/// Reports whether any `#[serde(...)]` attribute in `attrs` contains the `other` key.
///
/// # Panics
///
/// Panics when a `serde` attribute cannot be parsed (for example when one of its keys
/// carries a value), because the `bool` return type leaves no way to report the error.
fn has_other_attr(attrs: &[Attribute]) -> bool {
    attrs
        .iter()
        .filter(|a| a.path().is_ident("serde"))
        .any(|a| {
            let mut is_other = false;
            a.parse_nested_meta(|meta| {
                // Accumulate rather than assign, so a key listed after `other`
                // (e.g. `#[serde(other, foo)]`) cannot reset the flag.
                is_other |= meta.path.is_ident("other");
                Ok(())
            })
            .unwrap_or_else(|e| panic!("invalid `serde` attribute: {e}"));
            is_other
        })
}

/// Converts the parameters declared in `generics` into the arguments needed to name the
/// type: each lifetime, type, and const parameter is reduced to its bare name, in
/// declaration order.
fn generics_to_args(generics: &Generics) -> Vec<GenericArgument> {
    let mut args = Vec::with_capacity(generics.params.len());
    for param in &generics.params {
        let arg: GenericArgument = match param {
            GenericParam::Lifetime(LifetimeParam { lifetime: lt, .. }) => parse_quote!(#lt),
            GenericParam::Type(TypeParam { ident: name, .. }) => parse_quote!(#name),
            GenericParam::Const(ConstParam { ident: name, .. }) => parse_quote!(#name),
        };
        args.push(arg);
    }
    args
}

/// Builds the type `ident<args...>`, where the arguments are the names of the parameters
/// declared in `generics`.
fn type_with_generics(ident: &Ident, generics: &Generics) -> Type {
    let args = generics_to_args(generics);
    parse_quote! { #ident < #(#args),* > }
}

fn generate_deserialize(
    self_ident: Ident,
    self_ser_name: String,
    self_generics: Generics,
    en: DataEnum,
    int_type: Ident,
) -> Result<ItemImpl> {
    let self_: Type = type_with_generics(&self_ident, &self_generics);

    let mut arms: Vec<Arm> = vec![];

    let mut default_arm: Arm = parse_quote! {
        _ => ::std::result::Result::Err(serde::de::Error::custom(
            ::std::format!(
                "unexpected value {disc:?} for {}",
                #self_ser_name
            )
        ))
    };

    for v in &en.variants {
        let name = &v.ident;
        let (_, disc) = v
            .discriminant
            .as_ref()
            .ok_or(Error::new(v.span(), "variant missing discriminant"))?;

        if has_other_attr(&v.attrs) {
            if !matches!(&v.fields, Fields::Unit) {
                return Err(Error::new(v.span(), "other must be used without fields"));
            }
            default_arm = parse_quote! {
                _ => ::std::result::Result::Ok(#self_ident::#name)
            };
            continue;
        }

        match &v.fields {
            Fields::Unit => arms.push(parse_quote! {
                v if v == #disc => ::std::result::Result::Ok(#self_ident::#name)
            }),
            Fields::Named(f) => {
                let fields = f.named.iter().map(|f| -> FieldValue {
                    let ident = &f.ident;
                    parse_quote! {
                        #ident: ::serde::de::SeqAccess::next_element(&mut seq)?
                            .ok_or(::serde::de::Error::custom("expected field"))?
                    }
                });
                arms.push(parse_quote! {
                    v if v == #disc => ::std::result::Result::Ok(#self_ident::#name {
                        #(#fields),*
                    })
                });
            }
            Fields::Unnamed(f) => {
                let fields = f.unnamed.iter().map(|_| -> Expr {
                    parse_quote! {
                        ::serde::de::SeqAccess::next_element(&mut seq)?
                            .ok_or(::serde::de::Error::custom("expected field"))?
                    }
                });
                arms.push(parse_quote! {
                    v if v == #disc => ::std::result::Result::Ok(#self_ident::#name(
                        #(#fields),*
                    ))
                });
            }
        }
    }

    let mut impl_generics = self_generics.clone();
    impl_generics.params.push(parse_quote!('de));

    let impl_where_clause =
        generate_where_clause(&self_generics, parse_quote!(::serde::Deserialize<'de>));

    let visitor_params = self_generics.params.iter().map(|p| -> Type {
        match p {
            GenericParam::Lifetime(LifetimeParam { lifetime, .. }) => parse_quote!(&#lifetime ()),
            GenericParam::Type(TypeParam { ident, .. }) => parse_quote!(#ident),
            GenericParam::Const(ConstParam { ident, .. }) => parse_quote!([(); #ident]),
        }
    });

    let self_generic_args = generics_to_args(&self_generics);

    Ok(parse_quote! {
        impl #impl_generics ::serde::Deserialize<'de> for #self_ #impl_where_clause {
            fn deserialize<D>(deserializer: D) -> ::std::result::Result<#self_, D::Error>
                where
                D: ::serde::de::Deserializer<'de>,
            {
                struct Visitor #self_generics (::std::marker::PhantomData<(#(#visitor_params),*)>);

                impl #impl_generics ::serde::de::Visitor<'de> for Visitor <#(#self_generic_args),*>
                    #impl_where_clause
                {
                    type Value = #self_;

                    fn expecting(
                        &self, formatter: &mut ::std::fmt::Formatter
                    ) -> ::std::fmt::Result {
                        ::std::fmt::Formatter::write_str(formatter, #self_ser_name)
                    }

                    fn visit_seq<A>(
                        self, mut seq: A
                    ) -> ::std::result::Result<Self::Value, A::Error>
                        where
                            A: ::serde::de::SeqAccess<'de>
                    {
                        let disc: #int_type = ::serde::de::SeqAccess::next_element(&mut seq)?
                            .ok_or(::serde::de::Error::custom("expected discriminant"))?;
                        match disc {
                            #(#arms,)*
                            #default_arm
                        }
                    }
                }

                ::serde::de::Deserializer::deserialize_struct(
                    deserializer,
                    #self_ser_name,
                    &["disc", "value"],
                    Visitor(::std::marker::PhantomData)
                )
            }
        }
    })
}

fn generate_where_clause(self_generics: &Generics, bound: TypeParamBound) -> WhereClause {
    let predicates = self_generics
        .params
        .iter()
        .filter_map(|p| -> Option<TypeParam> {
            match p {
                GenericParam::Type(TypeParam { ident, .. }) => Some(parse_quote!(#ident: #bound)),
                _ => None,
            }
        });
    let impl_where_clause: WhereClause = parse_quote! {
        where
            #(#predicates),*
    };
    impl_where_clause
}

fn deserialize_with_discriminant_inner(input: DeriveInput) -> Result<ItemImpl> {
    if let Data::Enum(en) = input.data {
        let int_type = find_repr(&input.attrs)?;
        let ser_name = find_rename(&input.attrs)?
            .map(|n| n.value())
            .unwrap_or_else(|| input.ident.to_string());
        generate_deserialize(input.ident, ser_name, input.generics, en, int_type)
    } else {
        Err(Error::new(
            input.ident.span(),
            "Must be applied to `enum`s only",
        ))
    }
}

#[proc_macro_derive(DeserializeWithDiscriminant, attributes(serde))]
pub fn deserialize_with_discriminant(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    match deserialize_with_discriminant_inner(input) {
        Err(e) => e.into_compile_error().into(),
        Ok(v) => quote!(#v).into(),
    }
}

fn generate_serialize(
    self_ident: Ident,
    self_ser_name: String,
    self_generics: Generics,
    en: DataEnum,
    int_type: Ident,
) -> Result<ItemImpl> {
    let self_: Type = type_with_generics(&self_ident, &self_generics);

    let mut match_arms: Vec<Arm> = vec![];

    for v in &en.variants {
        let name = &v.ident;
        let (_, disc) = v
            .discriminant
            .as_ref()
            .ok_or(Error::new(v.span(), "variant missing discriminant"))?;

        let mut num_fields: usize = 1;
        let mut statements: Vec<Expr> = vec![];
        statements.push(parse_quote! {
            ::serde::ser::SerializeStruct::serialize_field::<#int_type>(
                &mut state, "discriminant", &(#disc)
            )?
        });
        let pattern: Pat = match &v.fields {
            Fields::Unit => parse_quote!(#self_ident::#name),
            Fields::Named(fields) => {
                num_fields += fields.named.len();
                for f in &fields.named {
                    let f_ident = f.ident.as_ref().unwrap();
                    let f_ident_name = f_ident.to_string();

                    statements.push(parse_quote! {
                        ::serde::ser::SerializeStruct::serialize_field(
                            &mut state, #f_ident_name, #f_ident
                        )?
                    });
                }
                let field_pattern = fields.named.iter().map(|f| &f.ident);
                parse_quote!(#self_ident::#name { #(#field_pattern),* })
            }
            Fields::Unnamed(fields) => {
                num_fields += fields.unnamed.len();
                let mut field_pattern = vec![];
                for i in 0..fields.unnamed.len() {
                    let f_ident = Ident::new(&format!("field{i}"), Span::call_site());
                    let f_ident_name = f_ident.to_string();

                    statements.push(parse_quote! {
                        ::serde::ser::SerializeStruct::serialize_field(
                            &mut state, #f_ident_name, #f_ident
                        )?
                    });
                    field_pattern.push(f_ident);
                }
                parse_quote!(#self_ident::#name ( #(#field_pattern),* ))
            }
        };

        match_arms.push(parse_quote! {
            #pattern => {
                let mut state = ::serde::ser::Serializer::serialize_struct(
                    serializer, #self_ser_name, #num_fields
                )?;
                #(#statements;)*
                ::serde::ser::SerializeStruct::end(state)
            }
        });
    }

    let impl_where_clause = generate_where_clause(&self_generics, parse_quote!(::serde::Serialize));

    Ok(parse_quote! {
        impl #self_generics ::serde::Serialize for #self_ #impl_where_clause {
            fn serialize<__S>(&self, serializer: __S) -> ::std::result::Result<__S::Ok, __S::Error>
                where
                    __S: ::serde::ser::Serializer
            {
                match self {
                    #(#match_arms),*
                }
            }
        }
    })
}

fn serialize_with_discriminant_inner(input: DeriveInput) -> Result<ItemImpl> {
    if let Data::Enum(en) = input.data {
        let int_type = find_repr(&input.attrs)?;
        let ser_name = find_rename(&input.attrs)?
            .map(|n| n.value())
            .unwrap_or_else(|| input.ident.to_string());
        generate_serialize(input.ident, ser_name, input.generics, en, int_type)
    } else {
        Err(Error::new(
            input.ident.span(),
            "Must be applied to `enum`s only",
        ))
    }
}

#[proc_macro_derive(SerializeWithDiscriminant, attributes(serde))]
pub fn serialize_with_discriminant(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    match serialize_with_discriminant_inner(input) {
        Err(e) => e.into_compile_error().into(),
        Ok(v) => quote!(#v).into(),
    }
}