attrsets/
lib.rs

1use itertools::Itertools;
2use proc_macro2::{Delimiter, Ident, Span, TokenStream, TokenTree};
3use quote::ToTokens;
4use std::iter::FromIterator;
5
/// Settings for generating one output definition.
///
/// The attribute macro builds one `Ctx` for the base item (no variant) and one for every
/// variant name passed to `#[attrsets(...)]`, and threads it through the field filters.
#[derive(Copy, Clone)]
struct Ctx<'s> {
    /// Definition currently being generated: `None` for the base item, otherwise the
    /// variant name that is appended to the item's identifier.
    cur_variant: Option<&'s str>,
    /// Every variant name listed in `#[attrsets(...)]`, in the order written.
    all_variants: &'s [String],
}
11
12fn filter_field(ctx: Ctx, field: syn::Field) -> syn::Field {
13    syn::Field {
14        attrs: field
15            .attrs
16            .into_iter()
17            .flat_map(|a| {
18                assert!(a.style == syn::AttrStyle::Outer);
19                if a.path.is_ident("attrset") {
20                    if let Some(TokenTree::Group(g)) = a.tokens.into_iter().next() {
21                        assert!(g.delimiter() == Delimiter::Parenthesis);
22                        let mut tokens = g.stream().into_iter();
23                        let on_variants = tokens
24                            .take_while_ref(|t| match t {
25                                TokenTree::Punct(p) if p.as_char() == ',' => true,
26                                TokenTree::Ident(i)
27                                    if i.to_string() == "_"
28                                        || ctx.all_variants.iter().any(|v| *v == i.to_string()) =>
29                                {
30                                    true
31                                }
32                                _ => false,
33                            })
34                            .flat_map(|t| match t {
35                                TokenTree::Punct(p) if p.as_char() == ',' => None,
36                                TokenTree::Ident(i) => Some(i.to_string()),
37                                _ => unreachable!(),
38                            })
39                            .collect::<Vec<_>>();
40                        let v_matches = if let Some(v) = ctx.cur_variant {
41                            on_variants.iter().any(|vv| vv == v)
42                        } else {
43                            false
44                        };
45                        let plain_matches =
46                            ctx.cur_variant.is_none() && on_variants.iter().any(|vv| vv == "_");
47                        if v_matches || plain_matches {
48                            let path = syn::parse2::<syn::Path>(TokenStream::from_iter(
49                                tokens.take_while_ref(|t| match t {
50                                    TokenTree::Punct(p) if p.as_char() == ':' => true,
51                                    TokenTree::Ident(_) => true,
52                                    _ => false,
53                                }),
54                            ))
55                            .unwrap();
56                            Some(syn::Attribute {
57                                tokens: TokenStream::from_iter(tokens),
58                                path,
59                                ..a
60                            })
61                        } else {
62                            None
63                        }
64                    } else {
65                        panic!("attrset attr should look like attrset(...)");
66                    }
67                } else {
68                    Some(a)
69                }
70            })
71            .collect(),
72        ..field
73    }
74}
75
76fn filter_fields(ctx: Ctx, fields: syn::Fields) -> syn::Fields {
77    match fields {
78        syn::Fields::Named(n) => syn::Fields::Named(syn::FieldsNamed {
79            named: n
80                .named
81                .into_pairs()
82                .map(|p| match p {
83                    syn::punctuated::Pair::Punctuated(f, c) => {
84                        syn::punctuated::Pair::Punctuated(filter_field(ctx, f), c)
85                    }
86                    syn::punctuated::Pair::End(f) => {
87                        syn::punctuated::Pair::End(filter_field(ctx, f))
88                    }
89                })
90                .collect(),
91            ..n
92        }),
93        syn::Fields::Unnamed(u) => syn::Fields::Unnamed(syn::FieldsUnnamed {
94            unnamed: u
95                .unnamed
96                .into_pairs()
97                .map(|p| match p {
98                    syn::punctuated::Pair::Punctuated(f, c) => {
99                        syn::punctuated::Pair::Punctuated(filter_field(ctx, f), c)
100                    }
101                    syn::punctuated::Pair::End(f) => {
102                        syn::punctuated::Pair::End(filter_field(ctx, f))
103                    }
104                })
105                .collect(),
106            ..u
107        }),
108        syn::Fields::Unit => syn::Fields::Unit,
109    }
110}
111
112fn filter_def(ctx: Ctx, inp: syn::DeriveInput) -> syn::DeriveInput {
113    let data = match inp.data {
114        syn::Data::Struct(stru) => syn::Data::Struct(syn::DataStruct {
115            fields: filter_fields(ctx, stru.fields),
116            ..stru
117        }),
118        syn::Data::Enum(enu) => syn::Data::Enum(syn::DataEnum {
119            variants: enu
120                .variants
121                .into_pairs()
122                .map(|p| match p {
123                    syn::punctuated::Pair::Punctuated(v, c) => syn::punctuated::Pair::Punctuated(
124                        syn::Variant {
125                            fields: filter_fields(ctx, v.fields),
126                            ..v
127                        },
128                        c,
129                    ),
130                    syn::punctuated::Pair::End(v) => syn::punctuated::Pair::End(syn::Variant {
131                        fields: filter_fields(ctx, v.fields),
132                        ..v
133                    }),
134                })
135                .collect(),
136            ..enu
137        }),
138        syn::Data::Union(_) => panic!("attrsets does not support union"),
139    };
140    syn::DeriveInput {
141        ident: Ident::new(
142            &format!("{}{}", inp.ident.to_string(), ctx.cur_variant.unwrap_or("")),
143            Span::call_site(),
144        ),
145        data,
146        ..inp
147    }
148}
149
150#[proc_macro_attribute]
151pub fn attrsets(
152    attr: proc_macro::TokenStream,
153    item: proc_macro::TokenStream,
154) -> proc_macro::TokenStream {
155    let item_ast: syn::DeriveInput = syn::parse(item).unwrap();
156
157    let all_variants = attr
158        .into_iter()
159        .flat_map(|t| match t {
160            proc_macro::TokenTree::Punct(p) if p.as_char() == ',' => None,
161            proc_macro::TokenTree::Ident(i) => Some(i.to_string()),
162            _ => panic!("attrsets attr: bad token: {}", t),
163        })
164        .collect::<Vec<_>>();
165
166    let mut tst = filter_def(
167        Ctx {
168            all_variants: &all_variants,
169            cur_variant: None,
170        },
171        item_ast.clone(),
172    )
173    .into_token_stream();
174
175    for v in all_variants.iter() {
176        tst.extend(
177            filter_def(
178                Ctx {
179                    all_variants: &all_variants,
180                    cur_variant: Some(v),
181                },
182                item_ast.clone(),
183            )
184            .into_token_stream(),
185        );
186    }
187
188    tst.into()
189}