1use itertools::Itertools;
2use proc_macro2::{Delimiter, Ident, Span, TokenStream, TokenTree};
3use quote::ToTokens;
4use std::iter::FromIterator;
5
/// Settings shared by every rewrite step while generating one copy of the annotated item.
#[derive(Copy, Clone)]
struct Ctx<'v> {
    /// Every name passed in the `#[attrsets(...)]` argument list.
    all_variants: &'v [String],
    /// Name of the copy currently being generated, or `None` for the base definition.
    cur_variant: Option<&'v str>,
}
11
12fn filter_field(ctx: Ctx, field: syn::Field) -> syn::Field {
13 syn::Field {
14 attrs: field
15 .attrs
16 .into_iter()
17 .flat_map(|a| {
18 assert!(a.style == syn::AttrStyle::Outer);
19 if a.path.is_ident("attrset") {
20 if let Some(TokenTree::Group(g)) = a.tokens.into_iter().next() {
21 assert!(g.delimiter() == Delimiter::Parenthesis);
22 let mut tokens = g.stream().into_iter();
23 let on_variants = tokens
24 .take_while_ref(|t| match t {
25 TokenTree::Punct(p) if p.as_char() == ',' => true,
26 TokenTree::Ident(i)
27 if i.to_string() == "_"
28 || ctx.all_variants.iter().any(|v| *v == i.to_string()) =>
29 {
30 true
31 }
32 _ => false,
33 })
34 .flat_map(|t| match t {
35 TokenTree::Punct(p) if p.as_char() == ',' => None,
36 TokenTree::Ident(i) => Some(i.to_string()),
37 _ => unreachable!(),
38 })
39 .collect::<Vec<_>>();
40 let v_matches = if let Some(v) = ctx.cur_variant {
41 on_variants.iter().any(|vv| vv == v)
42 } else {
43 false
44 };
45 let plain_matches =
46 ctx.cur_variant.is_none() && on_variants.iter().any(|vv| vv == "_");
47 if v_matches || plain_matches {
48 let path = syn::parse2::<syn::Path>(TokenStream::from_iter(
49 tokens.take_while_ref(|t| match t {
50 TokenTree::Punct(p) if p.as_char() == ':' => true,
51 TokenTree::Ident(_) => true,
52 _ => false,
53 }),
54 ))
55 .unwrap();
56 Some(syn::Attribute {
57 tokens: TokenStream::from_iter(tokens),
58 path,
59 ..a
60 })
61 } else {
62 None
63 }
64 } else {
65 panic!("attrset attr should look like attrset(...)");
66 }
67 } else {
68 Some(a)
69 }
70 })
71 .collect(),
72 ..field
73 }
74}
75
76fn filter_fields(ctx: Ctx, fields: syn::Fields) -> syn::Fields {
77 match fields {
78 syn::Fields::Named(n) => syn::Fields::Named(syn::FieldsNamed {
79 named: n
80 .named
81 .into_pairs()
82 .map(|p| match p {
83 syn::punctuated::Pair::Punctuated(f, c) => {
84 syn::punctuated::Pair::Punctuated(filter_field(ctx, f), c)
85 }
86 syn::punctuated::Pair::End(f) => {
87 syn::punctuated::Pair::End(filter_field(ctx, f))
88 }
89 })
90 .collect(),
91 ..n
92 }),
93 syn::Fields::Unnamed(u) => syn::Fields::Unnamed(syn::FieldsUnnamed {
94 unnamed: u
95 .unnamed
96 .into_pairs()
97 .map(|p| match p {
98 syn::punctuated::Pair::Punctuated(f, c) => {
99 syn::punctuated::Pair::Punctuated(filter_field(ctx, f), c)
100 }
101 syn::punctuated::Pair::End(f) => {
102 syn::punctuated::Pair::End(filter_field(ctx, f))
103 }
104 })
105 .collect(),
106 ..u
107 }),
108 syn::Fields::Unit => syn::Fields::Unit,
109 }
110}
111
112fn filter_def(ctx: Ctx, inp: syn::DeriveInput) -> syn::DeriveInput {
113 let data = match inp.data {
114 syn::Data::Struct(stru) => syn::Data::Struct(syn::DataStruct {
115 fields: filter_fields(ctx, stru.fields),
116 ..stru
117 }),
118 syn::Data::Enum(enu) => syn::Data::Enum(syn::DataEnum {
119 variants: enu
120 .variants
121 .into_pairs()
122 .map(|p| match p {
123 syn::punctuated::Pair::Punctuated(v, c) => syn::punctuated::Pair::Punctuated(
124 syn::Variant {
125 fields: filter_fields(ctx, v.fields),
126 ..v
127 },
128 c,
129 ),
130 syn::punctuated::Pair::End(v) => syn::punctuated::Pair::End(syn::Variant {
131 fields: filter_fields(ctx, v.fields),
132 ..v
133 }),
134 })
135 .collect(),
136 ..enu
137 }),
138 syn::Data::Union(_) => panic!("attrsets does not support union"),
139 };
140 syn::DeriveInput {
141 ident: Ident::new(
142 &format!("{}{}", inp.ident.to_string(), ctx.cur_variant.unwrap_or("")),
143 Span::call_site(),
144 ),
145 data,
146 ..inp
147 }
148}
149
150#[proc_macro_attribute]
151pub fn attrsets(
152 attr: proc_macro::TokenStream,
153 item: proc_macro::TokenStream,
154) -> proc_macro::TokenStream {
155 let item_ast: syn::DeriveInput = syn::parse(item).unwrap();
156
157 let all_variants = attr
158 .into_iter()
159 .flat_map(|t| match t {
160 proc_macro::TokenTree::Punct(p) if p.as_char() == ',' => None,
161 proc_macro::TokenTree::Ident(i) => Some(i.to_string()),
162 _ => panic!("attrsets attr: bad token: {}", t),
163 })
164 .collect::<Vec<_>>();
165
166 let mut tst = filter_def(
167 Ctx {
168 all_variants: &all_variants,
169 cur_variant: None,
170 },
171 item_ast.clone(),
172 )
173 .into_token_stream();
174
175 for v in all_variants.iter() {
176 tst.extend(
177 filter_def(
178 Ctx {
179 all_variants: &all_variants,
180 cur_variant: Some(v),
181 },
182 item_ast.clone(),
183 )
184 .into_token_stream(),
185 );
186 }
187
188 tst.into()
189}