maelstrom-macro 0.12.0

Macros for internal use within Maelstrom
Documentation
use heck::ToSnakeCase as _;
use proc_macro2::Span;
use syn::{
    parenthesized,
    parse::{Parse, ParseStream},
    parse_quote,
    punctuated::Punctuated,
    token::Paren,
    Attribute, Data, DeriveInput, Error, Ident, ItemMacro, Meta, Result, Token,
};

/// An attribute aimed at one named field (of a struct) or variant (of an enum), written as
/// `@name: meta`.
struct DeriveFieldAttribute {
    at_token: Token![@],
    /// Name of the field or variant that should receive the attribute.
    field_or_variant: Ident,
    colon_token: Token![:],
    /// The attribute body, without the surrounding `#[...]`.
    attr: Meta,
}

impl Parse for DeriveFieldAttribute {
    fn parse(input: ParseStream) -> Result<Self> {
        // Consume the tokens strictly left to right: `@`, name, `:`, meta.
        let at_token: Token![@] = input.parse()?;
        let field_or_variant: Ident = input.parse()?;
        let colon_token: Token![:] = input.parse()?;
        let attr: Meta = input.parse()?;
        Ok(Self {
            at_token,
            field_or_variant,
            colon_token,
            attr,
        })
    }
}

impl quote::ToTokens for DeriveFieldAttribute {
    /// Re-emits the attribute in its source form `@name: meta`, so it can be forwarded to
    /// another macro unchanged.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let Self {
            at_token,
            field_or_variant,
            colon_token,
            attr,
        } = self;
        at_token.to_tokens(tokens);
        field_or_variant.to_tokens(tokens);
        colon_token.to_tokens(tokens);
        attr.to_tokens(tokens);
    }
}

/// One entry of a derive-attribute list. Entries that start with `@` target a specific field or
/// variant; all others apply to the container type itself.
enum DeriveAttribute {
    /// An attribute for the type as a whole.
    Container(Meta),
    /// An attribute for a single named field or variant (`@name: meta`).
    FieldOrVariant(DeriveFieldAttribute),
}

impl Parse for DeriveAttribute {
    fn parse(input: ParseStream) -> Result<Self> {
        // Without a leading `@`, the whole entry is a plain container `Meta`.
        if !input.peek(Token![@]) {
            return input.parse().map(Self::Container);
        }
        input.parse().map(Self::FieldOrVariant)
    }
}

impl quote::ToTokens for DeriveAttribute {
    /// Emits the entry in the same syntax it was parsed from.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        match self {
            DeriveAttribute::Container(meta) => meta.to_tokens(tokens),
            DeriveAttribute::FieldOrVariant(field_attr) => field_attr.to_tokens(tokens),
        }
    }
}

/// The optional tail of a macro invocation: a comma followed by a comma-separated list of
/// [`DeriveAttribute`] entries.
struct DeriveAttrs {
    _comma_token: Token![,],
    attrs: Punctuated<DeriveAttribute, Token![,]>,
}

impl DeriveAttrs {
    /// Iterates over the container-level entries, in the order they were written.
    fn container_attrs(&self) -> impl Iterator<Item = &Meta> {
        self.attrs.iter().filter_map(|entry| match entry {
            DeriveAttribute::Container(meta) => Some(meta),
            DeriveAttribute::FieldOrVariant(_) => None,
        })
    }

    /// Iterates over the `@name: meta` entries, in the order they were written.
    fn field_or_variant_attrs(&self) -> impl Iterator<Item = &DeriveFieldAttribute> {
        self.attrs.iter().filter_map(|entry| match entry {
            DeriveAttribute::FieldOrVariant(field_attr) => Some(field_attr),
            DeriveAttribute::Container(_) => None,
        })
    }
}

pub struct Arguments {
    remote_type: Ident,
    _comma_token: Token![,],
    derive_macros: Punctuated<Ident, Token![,]>,
    derive_attrs: Option<DeriveAttrs>,
}

impl Parse for Arguments {
    fn parse(input: ParseStream) -> Result<Self> {
        Ok(Self {
            remote_type: input.parse()?,
            _comma_token: input.parse()?,
            derive_macros: if input.peek(Paren) {
                let content;
                parenthesized!(content in input);
                content.parse_terminated(Ident::parse, Token![,])?
            } else {
                [input.parse::<Ident>()?].into_iter().collect()
            },
            derive_attrs: input
                .peek(Token![,])
                .then(|| -> Result<_> {
                    Ok(DeriveAttrs {
                        _comma_token: input.parse()?,
                        attrs: input.parse_terminated(DeriveAttribute::parse, Token![,])?,
                    })
                })
                .transpose()?,
        })
    }
}

pub struct InnerArguments {
    input: DeriveInput,
    _comma_token: Token![,],
    remote_derive: Ident,
    derive_attrs: Option<DeriveAttrs>,
}

impl Parse for InnerArguments {
    fn parse(input: ParseStream) -> Result<Self> {
        Ok(Self {
            input: input.parse()?,
            _comma_token: input.parse()?,
            remote_derive: input.parse()?,
            derive_attrs: input
                .peek(Token![,])
                .then(|| -> Result<_> {
                    Ok(DeriveAttrs {
                        _comma_token: input.parse()?,
                        attrs: input.parse_terminated(
                            |s| {
                                let content;
                                parenthesized!(content in s);
                                DeriveAttribute::parse(&content)
                            },
                            Token![,],
                        )?,
                    })
                })
                .transpose()?,
        })
    }
}

fn add_field_or_variant_attribute(
    data: &mut Data,
    field_or_variant_to_find: &Ident,
    attr: &Meta,
) -> Result<()> {
    match data {
        Data::Struct(s) => {
            let mut found = false;
            for f in s.fields.iter_mut() {
                let field_ident = f.ident.as_ref().ok_or_else(|| {
                    Error::new(
                        Span::call_site(),
                        "field attributes not supported for unnamed fields",
                    )
                })?;
                if field_ident == field_or_variant_to_find {
                    f.attrs.push(parse_quote!(#[#attr]));
                    found = true;
                    break;
                }
            }
            if !found {
                Err(Error::new(
                    field_or_variant_to_find.span(),
                    "failed to find field",
                ))
            } else {
                Ok(())
            }
        }
        Data::Enum(e) => {
            let mut found = false;
            for v in &mut e.variants {
                if &v.ident == field_or_variant_to_find {
                    v.attrs.push(parse_quote!(#[#attr]));
                    found = true;
                    break;
                }
            }
            if !found {
                Err(Error::new(
                    field_or_variant_to_find.span(),
                    "failed to find variant",
                ))
            } else {
                Ok(())
            }
        }
        _ => Err(Error::new(
            Span::call_site(),
            "field attributes not supported for enum or union",
        )),
    }
}

pub fn inner_main(args: InnerArguments) -> Result<ItemMacro> {
    let remote_derive = &args.remote_derive;
    let mut input = args.input;
    if let Some(attrs) = args.derive_attrs {
        let container_attrs = attrs
            .container_attrs()
            .map(|a| -> Attribute { parse_quote!(#[#a]) });
        for f in attrs.field_or_variant_attrs() {
            add_field_or_variant_attribute(&mut input.data, &f.field_or_variant, &f.attr)?;
        }
        Ok(parse_quote! {
            #remote_derive!(
                #(#container_attrs)*
                #input
            );
        })
    } else {
        Ok(parse_quote! {
            #remote_derive!(#input);
        })
    }
}

pub fn main(args: Arguments) -> Result<Vec<ItemMacro>> {
    let mut items = vec![];
    for derive_macro in args.derive_macros {
        let pocket_definition = Ident::new(
            &format!(
                "{}_pocket_definition",
                args.remote_type.to_string().to_snake_case()
            ),
            Span::call_site(),
        );
        let remote_derive = Ident::new(
            &format!("{}_remote_derive", derive_macro.to_string().to_snake_case(),),
            Span::call_site(),
        );
        if let Some(attrs) = &args.derive_attrs {
            let attrs = attrs.attrs.iter();
            items.push(parse_quote! {
                #pocket_definition!(
                    maelstrom_macro::remote_derive_inner, #remote_derive, #((#attrs)),*
                );
            });
        } else {
            items.push(parse_quote! {
                #pocket_definition!(maelstrom_macro::remote_derive_inner, #remote_derive);
            });
        }
    }
    Ok(items)
}