mm1-proc-macros 0.7.20

M/M/1! Queueing, do you speak it?!
Documentation
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{Generics, Ident, Item, LitBool, Path, PredicateType, Token, parse_macro_input};

struct MessageArgs {
    base_path:          Option<Path>,
    derive_serialize:   bool,
    derive_deserialize: bool,
}
pub(crate) fn message(attr: TokenStream, input: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as MessageArgs);
    let input = parse_macro_input!(input as Item);

    let (ident, mut generics): (Ident, Generics) = match input {
        Item::Struct(ref s) => (s.ident.clone(), s.generics.clone()),
        Item::Enum(ref e) => (e.ident.clone(), e.generics.clone()),
        _ => {
            return syn::Error::new_spanned(
                input,
                "#[message] can only be applied to a struct or an enum",
            )
            .to_compile_error()
            .into()
        },
    };

    let where_clause = generics.make_where_clause();
    where_clause.predicates.push(
        PredicateType {
            lifetimes:   None,
            bounded_ty:  syn::parse_quote!(Self),
            colon_token: Default::default(),
            bounds:      {
                let mut b = Punctuated::new();
                b.push(syn::parse_quote!(::serde::Serialize));
                b.push(syn::parse_quote!(Send));
                b.push(syn::parse_quote!('static));
                b
            },
        }
        .into(),
    );
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let base_path: Path = args
        .base_path
        .unwrap_or_else(|| syn::parse_quote!(::mm1::proto));

    let derive_serialize = args
        .derive_serialize
        .then_some(quote!( #[derive(::serde::Serialize)] ));
    let derive_deserialize = args
        .derive_deserialize
        .then_some(quote!( #[derive(::serde::Deserialize)] ));

    let automatically_derived = quote! { #[automatically_derived] };

    quote! {
        #derive_serialize
        #derive_deserialize
        #input

        #automatically_derived
        impl #impl_generics #base_path :: Message for #ident #ty_generics #where_clause {}
    }
    .into()
}

enum MessageArg {
    BasePath(Path),
    DeriveSerialize(bool),
    DeriveDeserialize(bool),
}

impl Parse for MessageArgs {
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        if input.is_empty() {
            return Ok(Self {
                base_path:          None,
                derive_serialize:   true,
                derive_deserialize: true,
            });
        }

        let mut base_path = None;
        let mut derive_serialize = true;
        let mut derive_deserialize = true;
        let args = Punctuated::<MessageArg, Comma>::parse_terminated(input)?;

        for arg in args {
            match arg {
                MessageArg::BasePath(path) => {
                    base_path = Some(path);
                },
                MessageArg::DeriveSerialize(value) => {
                    derive_serialize = value;
                },
                MessageArg::DeriveDeserialize(value) => {
                    derive_deserialize = value;
                },
            }
        }

        Ok(Self {
            base_path,
            derive_serialize,
            derive_deserialize,
        })
    }
}

impl Parse for MessageArg {
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let ident: Ident = input.parse()?;
        let _eq_token: Token![=] = input.parse()?;

        if ident == "base_path" {
            let path: Path = input.parse()?;
            Ok(MessageArg::BasePath(path))
        } else if ident == "derive_serialize" {
            let derive_serialize: LitBool = input.parse()?;
            Ok(MessageArg::DeriveSerialize(derive_serialize.value))
        } else if ident == "derive_deserialize" {
            let derive_deserialize: LitBool = input.parse()?;
            Ok(MessageArg::DeriveDeserialize(derive_deserialize.value))
        } else {
            Err(syn::Error::new_spanned(
                ident,
                "unexpected argument, expected `base_path`",
            ))
        }
    }
}