proc_macro_util/
lib.rs

1//! Utilities to implement procedural macros
2
3#[cfg(test)]
4#[cfg_attr(test, macro_use)]
5extern crate quote;
6
7use proc_macro2::{Delimiter, Span, TokenTree};
8use std::collections::hash_map::Entry;
9use std::collections::{HashMap, HashSet};
10
11mod attr;
12mod ctxt;
13
14pub mod prelude {
15    pub use super::attr::*;
16    pub use super::ctxt::*;
17}
18
19/// An error that is located by a span
20#[derive(Debug)]
21pub struct SpanError {
22    /// The source error
23    pub error: failure::Error,
24    /// The location of the error
25    pub span: Span,
26}
27
28impl SpanError {
29    /// Build a new `SpanError`
30    pub fn from_error(error: failure::Error, span: Span) -> Self {
31        Self { error, span }
32    }
33}
34
35/// Parse keyed arrays of identifiers.
36///
37/// It expects as an input a parenthesis group, and in this group
38/// if it ran into a key in `allowed_key_values`, then it will
39/// parse the following tokens as an array of identifiers.
40///
41/// The values must be in the provided value set.
42///
43/// # Arguments
44/// * `input`: token iterator to consume
45/// * `default_span`: A span that will be used if there is an error that can't be
46///    linked to a specific span.
47/// * `allowed_key_values`: An hash map to define the allowed keys and their corresponding allowed
48///    values.
49///
50/// # Return
51/// An hash map with the values found for each keys.
52///
53/// # Error
54/// * When a key is detected but its value is not a bracket group with identifiers separated by commas
55/// * When an identifier is not in the allowed identifier list for a specific key
56///
57/// # Example
58///
59/// * `(not_array, array_key = [value_1, value_2])`
60///    Here `array_key` is defined as a valid key and its
61///    following tokens will be processed as an array of identifiers.
62///    This will be valid only if both `value_1` and `value_2` are valid identifiers.
63pub fn parse_identifier_arrays(
64    input: impl IntoIterator<Item = TokenTree>,
65    default_span: Span,
66    allowed_key_values: &HashMap<&str, &[&str]>,
67) -> Result<HashMap<String, Vec<String>>, SpanError> {
68    let mut result = HashMap::new();
69    let iter = input.into_iter();
70
71    let mut iter = private::consume_group_parenthesis_token(iter, default_span.clone())?;
72
73    loop {
74        match iter.next() {
75            Some(TokenTree::Ident(ident)) => {
76                let ident_str = ident.to_string();
77                if allowed_key_values.contains_key(std::ops::Deref::deref(&ident_str)) {
78                    private::consume_punct(&mut iter, default_span.clone(), '=')?;
79
80                    // We expect now an array of identifier
81                    let values = parse_identifier_array(
82                        &mut iter,
83                        default_span,
84                        &allowed_key_values[std::ops::Deref::deref(&ident_str)],
85                    )?;
86                    match result.entry(ident_str) {
87                        Entry::Vacant(entry) => {
88                            entry.insert(values);
89                        }
90                        Entry::Occupied(mut entry) => {
91                            entry.get_mut().extend(values);
92                        }
93                    }
94                }
95            }
96            None => break,
97            _ => {}
98        }
99    }
100
101    Ok(result)
102}
103
104/// Parse an array of identifiers.
105///
106/// It expects an array of predefined identifiers.
107///
108/// # Arguments
109/// * `iter`: token iterator to consume
110/// * `default_span`: span to use when an error was found but can't be linked to a span.
111/// * `allowed_values`: identifiers allowed as values for this array
112///
113/// # Return
114/// The identifiers found for this array.
115///
116/// # Error
117/// * When the tokens is not a bracket group with identifiers separated by commas
118/// * When an identifier is not in the allowed identifier list
119pub fn parse_identifier_array<I>(
120    iter: &mut I,
121    default_span: proc_macro2::Span,
122    allowed_values: &[&str],
123) -> Result<Vec<String>, SpanError>
124where
125    I: Iterator<Item = TokenTree>,
126{
127    match iter.next() {
128        Some(TokenTree::Group(ref group)) if group.delimiter() == Delimiter::Bracket => {
129            let mut group_iter = group.stream().into_iter();
130            let mut result = Vec::new();
131
132            loop {
133                match group_iter.next() {
134                    Some(TokenTree::Ident(ident)) => {
135                        let ident_str = ident.to_string();
136                        if allowed_values.contains(&std::ops::Deref::deref(&ident_str)) {
137                            result.push(ident_str);
138                        } else {
139                            return Err(SpanError::from_error(
140                                failure::err_msg(format!(
141                                    "Unexpected value: {:?}, values expected are: {:?}",
142                                    ident_str, allowed_values
143                                )),
144                                ident.span(),
145                            ));
146                        }
147
148                        match group_iter.next() {
149                            Some(TokenTree::Punct(ref punct)) if punct.as_char() == ',' => {}
150                            Some(v) => {
151                                return Err(SpanError::from_error(
152                                    failure::err_msg(format!(
153                                        "Unexpected token: {:?}, only ',' is expected",
154                                        v.to_string()
155                                    )),
156                                    private::token_tree_span(&v),
157                                ))
158                            }
159                            None => {}
160                        }
161                    }
162                    Some(v) => {
163                        return Err(SpanError::from_error(
164                            failure::err_msg(format!(
165                                "Unexpected token: {:?}, only a value in {:?} is expected",
166                                v.to_string(),
167                                allowed_values
168                            )),
169                            private::token_tree_span(&v),
170                        ))
171                    }
172                    None => break,
173                }
174            }
175
176            Ok(result)
177        }
178        Some(v) => {
179            return Err(SpanError::from_error(
180                failure::err_msg(format!(
181                    "Expected an array of values, but received: {}",
182                    v.to_string()
183                )),
184                private::token_tree_span(&v),
185            ))
186        }
187        None => Err(SpanError::from_error(
188            failure::err_msg("Expected an array of values."),
189            default_span,
190        )),
191    }
192}
193
194/// Parse tokens as flags toggles
195///
196/// This will search in the provided token iterator `input` for specific
197/// identifier that will toggle on a flag.
198///
199/// # Arguments
200/// * `input`: token iterator to consume
201/// * `default_span`: span to use when an error occurs and can't be linked to a span
202/// * `allowed_flags`: identifiers that are parsed as flags
203///
204/// # Return
205/// An hash set of toggled flags.
206///
207/// # Errors
208/// if a flag is found but is not parsed as a flag. (Usually followed by a `=` token,
209/// which is invalid for a flag)
210pub fn parse_flags(
211    input: impl IntoIterator<Item = TokenTree>,
212    default_span: proc_macro2::Span,
213    allowed_flags: &[&str],
214) -> Result<HashSet<String>, SpanError> {
215    let mut result = HashSet::new();
216    let iter = input.into_iter();
217
218    let mut iter = private::consume_group_parenthesis_token(iter, default_span.clone())?;
219
220    loop {
221        match iter.next() {
222            Some(TokenTree::Ident(ident)) => {
223                let ident_str = ident.to_string();
224                if allowed_flags.contains(&std::ops::Deref::deref(&ident_str)) {
225                    // Check that following token is either , or EOF
226                    match iter.next() {
227                        Some(TokenTree::Punct(ref punct)) if punct.as_char() == ',' => {},
228                        None => {},
229                        Some(v) => {
230                            return Err(SpanError::from_error(
231                                failure::err_msg(format!(
232                                    "'{}' must be used as a flag, it must be followed by a ',' or another argument.", ident_str
233                                )),
234                                private::token_tree_span(&v)
235                            ))
236                        }
237                    }
238
239                    result.insert(ident.to_string());
240                }
241            }
242            None => break,
243            _ => {}
244        }
245    }
246
247    Ok(result)
248}
249
250/// Parse strings identified by keys
251///
252/// # Argument
253/// * `input`: token iterator to consume
254/// * `default_span`: span to use when an error is detected but can't be linked to a span
255/// * `keys`: keys to search for
256///
257/// # Return
258/// An hash map with the key that were detected with their associated value.
259/// The value contains the quote that makes the string literal.
260///
261/// # Error
262/// * When a key was detected but can't be parsed as a string literal assignment
263///
264/// # Example
265/// `(my_val = "toto"), then the value of identifier `my_val` is `r#""toto""#`.
266pub fn parse_keyed_strings(
267    input: impl IntoIterator<Item = TokenTree>,
268    default_span: proc_macro2::Span,
269    keys: &[&str],
270) -> Result<HashMap<String, String>, SpanError> {
271    let mut result = HashMap::new();
272    let iter = input.into_iter();
273
274    let mut iter = private::consume_group_parenthesis_token(iter, default_span.clone())?;
275
276    loop {
277        match iter.next() {
278            Some(TokenTree::Ident(ident)) => {
279                let ident_str = ident.to_string();
280                if keys.contains(&std::ops::Deref::deref(&ident_str)) {
281                    // Extract value as a string
282
283                    // Extract '=' token
284                    match iter.next() {
285                        Some(TokenTree::Punct(ref punct)) if punct.as_char() == '=' => {}
286                        Some(v) => {
287                            return Err(SpanError::from_error(
288                                failure::err_msg(format!(
289                                    "Expected a string value for {:?}, but received: {:?}",
290                                    ident_str,
291                                    v.to_string()
292                                )),
293                                private::token_tree_span(&v),
294                            ))
295                        }
296                        _ => {
297                            return Err(SpanError::from_error(
298                                failure::err_msg(format!(
299                                    "Expected a string value for {:?}, but received none",
300                                    ident_str
301                                )),
302                                default_span,
303                            ))
304                        }
305                    };
306
307                    // Extract value
308                    match iter.next() {
309                        Some(TokenTree::Literal(ref literal)) => {
310                            let old = result.insert(ident_str.clone(), literal.to_string());
311                            if old.is_some() {
312                                return Err(SpanError::from_error(
313                                    failure::err_msg(format!(
314                                        "{:?} can appear only once",
315                                        ident_str
316                                    )),
317                                    literal.span(),
318                                ));
319                            }
320                        }
321                        Some(v) => {
322                            return Err(SpanError::from_error(
323                                failure::err_msg(format!(
324                                    "Expected a string value for {:?}, but received: {:?}",
325                                    ident_str,
326                                    v.to_string()
327                                )),
328                                private::token_tree_span(&v),
329                            ))
330                        }
331                        _ => {
332                            return Err(SpanError::from_error(
333                                failure::err_msg(format!(
334                                    "Expected a string value for {:?}, but received none",
335                                    ident_str
336                                )),
337                                default_span,
338                            ))
339                        }
340                    };
341                }
342            }
343            None => break,
344            _ => {}
345        }
346    }
347
348    Ok(result)
349}
350
351mod private {
352    use super::SpanError;
353    use proc_macro2::{Delimiter, Span, TokenTree};
354
355    pub fn token_tree_span(tree: &TokenTree) -> Span {
356        match tree {
357            TokenTree::Punct(punct) => punct.span(),
358            TokenTree::Ident(ident) => ident.span(),
359            TokenTree::Group(group) => group.span(),
360            TokenTree::Literal(literal) => literal.span(),
361        }
362    }
363
364    pub fn consume_punct<I: Iterator<Item = TokenTree>>(
365        iter: &mut I,
366        default_span: Span,
367        value: char,
368    ) -> Result<(), SpanError> {
369        match iter.next() {
370            Some(TokenTree::Punct(ref punct)) if punct.as_char() == value => {}
371            Some(v) => {
372                return Err(SpanError::from_error(
373                    failure::err_msg(format!(
374                        "Expected token '{:?}', but received {:?}",
375                        value,
376                        v.to_string()
377                    )),
378                    token_tree_span(&v),
379                ))
380            }
381            _ => {
382                return Err(SpanError::from_error(
383                    failure::err_msg(format!("Expected token '{:?}', but received none", value)),
384                    default_span,
385                ))
386            }
387        };
388        Ok(())
389    }
390
391    pub fn consume_group_parenthesis_token(
392        mut iter: impl Iterator<Item = TokenTree>,
393        default_span: Span,
394    ) -> Result<impl Iterator<Item = TokenTree>, SpanError> {
395        match iter.next() {
396            Some(TokenTree::Group(ref group)) if group.delimiter() == Delimiter::Parenthesis => {
397                Ok(group.stream().into_iter())
398            }
399            Some(v) => Err(SpanError::from_error(
400                failure::err_msg(format!(
401                    "Expected a parenthesis group, but received: {:?}",
402                    v.to_string()
403                )),
404                token_tree_span(&v),
405            )),
406            _ => Err(SpanError::from_error(
407                failure::err_msg("Expected a parenthesis group, but received none"),
408                default_span,
409            )),
410        }
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::{parse_flags, parse_identifier_arrays, parse_keyed_strings};
    use proc_macro2::TokenStream;
    use std::collections::{HashMap, HashSet};
    use syn::spanned::Spanned;
    use syn::DeriveInput;

    /// Values accepted for `array_key` in the identifier-array tests.
    const ARRAY_VALUES: &[&str] = &["entry", "entry2", "entry3"];

    /// Parses `input` as an item and returns its first attribute.
    fn first_attr(input: TokenStream) -> syn::Attribute {
        let derive_input: DeriveInput = syn::parse2(input).unwrap();
        derive_input.attrs.into_iter().next().unwrap()
    }

    /// Builds an owned string set to compare parse results against.
    fn string_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Allowed keys shared by the `parse_identifier_arrays` tests.
    fn array_keys() -> HashMap<&'static str, &'static [&'static str]> {
        let mut keys = HashMap::new();
        keys.insert("array_key", ARRAY_VALUES);
        keys
    }

    #[test]
    fn parse_flag_fails_with_invalid_entry() {
        let attr = first_attr(quote! {
            #[my_attr(my_flag = 3)]
            struct MyAttr;
        });

        let outcome = parse_flags(attr.tokens.clone(), attr.span(), &["my_flag"]);

        assert!(outcome.is_err());
    }

    #[test]
    fn parse_flag_works_with_one_entry() {
        let attr = first_attr(quote! {
            #[my_attr(my_flag)]
            struct MyAttr;
        });

        let flags = parse_flags(attr.tokens.clone(), attr.span(), &["my_flag"]).unwrap();

        assert_eq!(flags, string_set(&["my_flag"]));
    }

    #[test]
    fn parse_flag_works_with_multiple_entries() {
        let attr = first_attr(quote! {
            #[my_attr(my_flag, my_second_flag)]
            struct MyAttr;
        });

        let allowed = ["my_flag", "my_second_flag", "my_third_flag"];
        let flags = parse_flags(attr.tokens.clone(), attr.span(), &allowed).unwrap();

        assert_eq!(flags, string_set(&["my_flag", "my_second_flag"]));
    }

    #[test]
    fn parse_identifier_arrays_works() {
        let attr = first_attr(quote! {
            #[my_attr(array_key = [entry, entry2])]
            struct MyAttr;
        });

        let values = parse_identifier_arrays(attr.tokens.clone(), attr.span(), &array_keys())
            .unwrap();

        assert!(values.contains_key("array_key"));
        let found: HashSet<String> = values["array_key"].iter().cloned().collect();
        assert_eq!(found, string_set(&["entry", "entry2"]));
    }

    #[test]
    fn parse_identifier_arrays_fails_with_invalid_key_entry() {
        let attr = first_attr(quote! {
            #[my_attr(array_key = 3)]
            struct MyAttr;
        });

        let outcome = parse_identifier_arrays(attr.tokens.clone(), attr.span(), &array_keys());

        assert!(outcome.is_err());
    }

    #[test]
    fn parse_identifier_arrays_fails_with_invalid_value_entry() {
        let attr = first_attr(quote! {
            #[my_attr(array_key = [entry, entry4])]
            struct MyAttr;
        });

        let outcome = parse_identifier_arrays(attr.tokens.clone(), attr.span(), &array_keys());

        assert!(outcome.is_err());
    }

    #[test]
    fn parse_keyed_strings_works() {
        let attr = first_attr(quote! {
            #[my_attr(string_key = "my_value")]
            struct MyAttr;
        });

        let values = parse_keyed_strings(attr.tokens.clone(), attr.span(), &["string_key"])
            .unwrap();

        assert_eq!(values["string_key"], r#""my_value""#);
    }

    #[test]
    fn parse_keyed_strings_fails_with_invalid_entry() {
        let attr = first_attr(quote! {
            #[my_attr(string_key)]
            struct MyAttr;
        });

        let outcome = parse_keyed_strings(attr.tokens.clone(), attr.span(), &["string_key"]);

        assert!(outcome.is_err());
    }

    #[test]
    fn parse_keyed_strings_fails_with_invalid_entry2() {
        let attr = first_attr(quote! {
            #[my_attr(string_key = [43])]
            struct MyAttr;
        });

        let outcome = parse_keyed_strings(attr.tokens.clone(), attr.span(), &["string_key"]);

        assert!(outcome.is_err());
    }
}