1use std::collections::HashSet;
2
3use function_name::named;
4use proc_macro as proc;
5use proc_macro2::TokenStream;
6use proc_macro_error::{abort, proc_macro_error};
7use quote::{quote, ToTokens};
8use syn::{parse::*, punctuated::*, *};
9
/// Aborts macro expansion with an error located at `$target`.
///
/// The message is `$message` (a format-string literal, formatted with the optional
/// trailing `$value`s) prefixed by `"<function name>: "`. The function name comes
/// from `function_name!()`, so the enclosing function must carry `#[named]`.
macro_rules! abort_named_fn {
    ($target:expr, $message:expr $(, $value:expr)*) => {
        abort!($target, concat!("{}: ", $message), function_name!() $(, $value)*)
    };
}
16
17#[proc_macro]
32#[proc_macro_error]
33#[named]
34pub fn as_char_results(input: proc::TokenStream) -> proc::TokenStream {
35 let input_literal = parse_macro_input!(input as ExprLit);
36
37 match input_literal.lit {
38 Lit::Str(str_literal) => {
39 let mut ok_wrapped_chars: Punctuated<Expr, Token![,]> = Punctuated::new();
40 for char in str_literal.value().chars() {
41 ok_wrapped_chars.push(
42 parse_quote!(std::result::Result::<char, std::convert::Infallible>::Ok(#char)),
43 )
44 }
45
46 proc::TokenStream::from(quote!([ #ok_wrapped_chars ]))
47 }
48 Lit::Char(char_literal) => {
49 let char = char_literal.value();
50
51 proc::TokenStream::from(
52 quote!([ std::result::Result::<char, std::convert::Infallible>::Ok(#char) ]),
53 )
54 }
55 _ => abort_named_fn!(input_literal, "Input must be a string or char literal."),
56 }
57}
58
59#[proc_macro]
74pub fn as_char_results_and_input(input: proc::TokenStream) -> proc::TokenStream {
75 let input_literal = TokenStream::from(input.clone());
76 let ok_wrapped_chars = TokenStream::from(as_char_results(input));
77
78 proc::TokenStream::from(quote!(
79 (#ok_wrapped_chars , #input_literal)
80 ))
81}
82
83#[proc_macro_attribute]
107#[proc_macro_error]
108#[named]
109pub fn enum_fields(args: proc::TokenStream, input: proc::TokenStream) -> proc::TokenStream {
110 let mut enum_definition = parse_macro_input!(input as ItemEnum);
111
112 let (skip_list, field_list) = parse_macro_input!(args with parse_enum_fields_args);
113 let fields: FieldsNamed = parse_quote!({ #field_list });
114
115 for enum_variant in &mut enum_definition.variants {
116 if skip_list.contains(&enum_variant.ident) {
117 continue;
118 }
119 match &mut enum_variant.fields {
120 Fields::Unit => enum_variant.fields = Fields::Named(fields.clone()),
121 Fields::Named(existing_fields) => existing_fields.named.extend(fields.named.clone()),
122 Fields::Unnamed(_) => abort_named_fn!(
123 enum_variant,
124 "Cannot add a named field to a tuple-like enum variant."
125 ),
126 }
127 }
128
129 proc::TokenStream::from(enum_definition.to_token_stream())
130}
131
/// Set of enum variant names that `enum_fields` leaves untouched.
///
/// Parsed from the optional `![Variant, ...]` prefix of the attribute arguments.
struct SkipList {
    /// Identifiers of the variants to skip; duplicates collapse into one entry.
    to_skip: HashSet<Ident>,
}
140
141impl Parse for SkipList {
142 fn parse(input: ParseStream) -> Result<Self> {
143 let mut skip_list = SkipList::new();
144
145 input.parse::<Token![!]>()?;
146 let bracket_content;
147 bracketed!(bracket_content in input);
148
149 loop {
150 skip_list.insert(bracket_content.parse()?);
151 if bracket_content.is_empty() {
152 break;
153 }
154
155 bracket_content.parse::<Token![,]>()?;
156
157 if bracket_content.is_empty() {
158 break;
159 }
160 }
161
162 Ok(skip_list)
163 }
164}
165
166impl SkipList {
167 pub fn new() -> Self {
168 SkipList {
169 to_skip: HashSet::new(),
170 }
171 }
172
173 pub fn insert(&mut self, ident: Ident) -> bool {
174 self.to_skip.insert(ident)
175 }
176
177 pub fn contains(&self, ident: &Ident) -> bool {
178 self.to_skip.contains(ident)
179 }
180}
181
/// Comma-separated named field definitions (e.g. `name: Type`) that `enum_fields`
/// adds to every variant not in the skip list.
struct FieldList {
    /// The parsed fields, together with the separating commas as written.
    fields: Punctuated<Field, Token![,]>,
}
186
187impl Parse for FieldList {
188 fn parse(input: ParseStream) -> Result<Self> {
189 let mut fields: Punctuated<Field, Token![,]> = Punctuated::new();
190
191 loop {
192 fields.push_value(Field::parse_named(input)?);
193 if input.is_empty() {
194 break;
195 }
196
197 fields.push_punct(input.parse()?);
198 if input.is_empty() {
199 break;
200 }
201 }
202
203 Ok(FieldList { fields })
204 }
205}
206
207impl ToTokens for FieldList {
208 fn to_tokens(&self, tokens: &mut TokenStream) {
209 self.fields.to_tokens(tokens);
210 }
211}
212
213fn parse_enum_fields_args(input: ParseStream) -> Result<(SkipList, FieldList)> {
215 let skip_list = if input.peek(Token![!]) {
216 match SkipList::parse(input) {
217 Ok(list) => list,
218 Err(error) => return Err(error),
219 }
220 } else {
221 SkipList::new()
222 };
223 let field_list = match FieldList::parse(input) {
224 Ok(list) => list,
225 Err(error) => return Err(error),
226 };
227
228 Ok((skip_list, field_list))
229}