attrsets 0.1.2

Proc macro for defining multiple variants of a struct/enum with different attribute annotations, e.g. for multiple Serde serializations.
Documentation
use itertools::Itertools;
use proc_macro2::{Delimiter, Ident, Span, TokenStream, TokenTree};
use quote::ToTokens;
use std::iter::FromIterator;

/// Shared state threaded through the filtering functions while one output
/// definition is being generated.
#[derive(Clone, Copy)]
struct Ctx<'a> {
    /// Every variant name listed in the `#[attrsets(...)]` attribute.
    all_variants: &'a [String],
    /// The variant currently being generated, or `None` while generating the
    /// base definition that keeps the original (unsuffixed) name.
    cur_variant: Option<&'a str>,
}

/// Rewrites the attributes of one field for the definition selected by `ctx`.
///
/// Attributes other than `#[attrset(...)]` are kept untouched. Each
/// `#[attrset(...)]` attribute is either replaced by the attribute it wraps
/// (when it applies to the current variant) or removed.
///
/// # Panics
///
/// Panics if the field carries an inner attribute, or if an `attrset`
/// attribute is malformed (see `expand_attrset`).
fn filter_field(ctx: Ctx, field: syn::Field) -> syn::Field {
    syn::Field {
        attrs: field
            .attrs
            .into_iter()
            .filter_map(|a| {
                assert!(a.style == syn::AttrStyle::Outer);
                if a.path.is_ident("attrset") {
                    expand_attrset(ctx, a)
                } else {
                    Some(a)
                }
            })
            .collect(),
        ..field
    }
}

/// Expands a single `#[attrset(V1, V2, path<rest>)]` attribute.
///
/// The leading comma-separated identifiers that are either `_` or one of
/// `ctx.all_variants` form the selector list. `_` selects the base definition
/// (`ctx.cur_variant == None`); a variant name selects that variant. When the
/// current definition is selected the result is `#[path<rest>]`, reusing the
/// original attribute's `#`/bracket tokens and style; otherwise it is `None`.
///
/// # Panics
///
/// Panics if the attribute's arguments are not a single parenthesised group,
/// or if no attribute path follows the selector list of a selected attribute.
fn expand_attrset(ctx: Ctx, a: syn::Attribute) -> Option<syn::Attribute> {
    let group = match a.tokens.into_iter().next() {
        Some(TokenTree::Group(g)) => g,
        _ => panic!("attrset attr should look like attrset(...)"),
    };
    assert!(group.delimiter() == Delimiter::Parenthesis);
    let mut tokens = group.stream().into_iter();

    // Consume the selector prefix. `take_while_ref` leaves the first
    // non-matching token in `tokens`, so the wrapped attribute stays intact.
    let on_variants = tokens
        .take_while_ref(|t| match t {
            TokenTree::Punct(p) => p.as_char() == ',',
            TokenTree::Ident(i) => {
                // Stringify once instead of once per known variant.
                let name = i.to_string();
                name == "_" || ctx.all_variants.contains(&name)
            }
            _ => false,
        })
        .filter_map(|t| match t {
            TokenTree::Punct(_) => None,
            TokenTree::Ident(i) => Some(i.to_string()),
            // The predicate above only lets commas and identifiers through.
            _ => unreachable!(),
        })
        .collect::<Vec<_>>();

    let selected = match ctx.cur_variant {
        Some(v) => on_variants.iter().any(|vv| vv == v),
        None => on_variants.iter().any(|vv| vv == "_"),
    };
    if !selected {
        return None;
    }

    // The wrapped attribute's path is the following run of identifiers and
    // `:` punctuation; whatever remains becomes the new attribute's tokens.
    let path = syn::parse2::<syn::Path>(TokenStream::from_iter(tokens.take_while_ref(
        |t| match t {
            TokenTree::Punct(p) => p.as_char() == ':',
            TokenTree::Ident(_) => true,
            _ => false,
        },
    )))
    .expect("attrset attr should contain an attribute path after the variant list");

    Some(syn::Attribute {
        tokens: TokenStream::from_iter(tokens),
        path,
        ..a
    })
}

/// Applies `filter_field` to every field of a struct or enum-variant body.
///
/// Named and tuple-style field lists are rebuilt with their delimiters and
/// punctuation preserved; unit bodies are returned unchanged.
fn filter_fields(ctx: Ctx, fields: syn::Fields) -> syn::Fields {
    match fields {
        syn::Fields::Named(n) => syn::Fields::Named(syn::FieldsNamed {
            named: filter_field_list(ctx, n.named),
            ..n
        }),
        syn::Fields::Unnamed(u) => syn::Fields::Unnamed(syn::FieldsUnnamed {
            unnamed: filter_field_list(ctx, u.unnamed),
            ..u
        }),
        syn::Fields::Unit => syn::Fields::Unit,
    }
}

/// Runs `filter_field` over a comma-separated field list, keeping each
/// field's trailing comma (or lack of one) exactly as it was.
fn filter_field_list(
    ctx: Ctx,
    fields: syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
) -> syn::punctuated::Punctuated<syn::Field, syn::token::Comma> {
    fields
        .into_pairs()
        .map(|pair| {
            let (field, comma) = pair.into_tuple();
            syn::punctuated::Pair::new(filter_field(ctx, field), comma)
        })
        .collect()
}

/// Produces the definition for the variant selected by `ctx`.
///
/// The fields of the struct (or of every enum variant) are passed through
/// `filter_fields`, and the item is renamed to the original identifier
/// followed by the current variant name (no suffix for the base definition).
///
/// # Panics
///
/// Panics when applied to a `union`.
fn filter_def(ctx: Ctx, inp: syn::DeriveInput) -> syn::DeriveInput {
    use syn::punctuated::Pair;

    let data = match inp.data {
        syn::Data::Union(_) => panic!("attrsets does not support union"),
        syn::Data::Struct(s) => {
            let fields = filter_fields(ctx, s.fields);
            syn::Data::Struct(syn::DataStruct { fields, ..s })
        }
        syn::Data::Enum(e) => {
            let variants = e
                .variants
                .into_pairs()
                .map(|pair| {
                    // Split off the optional trailing comma, rewrite the
                    // variant, then put the punctuation back.
                    let (variant, comma) = pair.into_tuple();
                    let fields = filter_fields(ctx, variant.fields);
                    Pair::new(syn::Variant { fields, ..variant }, comma)
                })
                .collect();
            syn::Data::Enum(syn::DataEnum { variants, ..e })
        }
    };

    let suffix = ctx.cur_variant.unwrap_or("");
    let name = format!("{}{}", inp.ident, suffix);
    syn::DeriveInput {
        ident: Ident::new(&name, Span::call_site()),
        data,
        ..inp
    }
}

#[proc_macro_attribute]
pub fn attrsets(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let item_ast: syn::DeriveInput = syn::parse(item).unwrap();

    let all_variants = attr
        .into_iter()
        .flat_map(|t| match t {
            proc_macro::TokenTree::Punct(p) if p.as_char() == ',' => None,
            proc_macro::TokenTree::Ident(i) => Some(i.to_string()),
            _ => panic!("attrsets attr: bad token: {}", t),
        })
        .collect::<Vec<_>>();

    let mut tst = filter_def(
        Ctx {
            all_variants: &all_variants,
            cur_variant: None,
        },
        item_ast.clone(),
    )
    .into_token_stream();

    for v in all_variants.iter() {
        tst.extend(
            filter_def(
                Ctx {
                    all_variants: &all_variants,
                    cur_variant: Some(v),
                },
                item_ast.clone(),
            )
            .into_token_stream(),
        );
    }

    tst.into()
}