1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
//! Internal lib for [variant](https://docs.rs/extract-variant).

use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::{quote, ToTokens};
use syn::visit::Visit;
use syn::{parse_macro_input, Error, Pat, PatIdent, PatOr};

/// Syntax-tree visitor that gathers the identifiers a pattern binds and
/// remembers whether the pattern contained an or-pattern.
#[derive(Default, Debug)]
struct VisitPatIdent<'ast> {
    /// Identifiers bound by the visited pattern, in the order they were seen.
    idents: Vec<&'ast Ident>,
    /// Tokens of an or-pattern found in the input, if any. When set, the
    /// macro emits a compile error instead of an expansion.
    pat_or: Option<proc_macro2::TokenStream>,
}

impl<'ast> Visit<'ast> for VisitPatIdent<'ast> {
    /// Records the identifier bound by `node`, then continues the default
    /// traversal into its children (e.g. the sub-pattern of `x @ Some(y)`).
    fn visit_pat_ident(&mut self, node: &'ast PatIdent) {
        // A bare `None` parses as an identifier pattern, exactly like a
        // binding would, so it is recognised and skipped by name here.
        if node.ident != "None" {
            self.idents.push(&node.ident);
        }
        syn::visit::visit_pat_ident(self, node)
    }

    /// Remembers the first or-pattern encountered so the macro can report it.
    ///
    /// The alternatives inside the or-pattern are intentionally not visited.
    /// Any input containing an or-pattern is rejected, so bindings inside it
    /// are never used. Traversal of sibling patterns continues as normal.
    fn visit_pat_or(&mut self, node: &'ast PatOr) {
        // Keep only the first occurrence, so the diagnostic points at the
        // leftmost offending pattern rather than the last one visited.
        if self.pat_or.is_none() {
            self.pat_or = Some(node.to_token_stream());
        }
    }
}

#[proc_macro]
pub fn extract_variant_assign(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as Pat);
    let mut visitor = VisitPatIdent::default();
    visitor.visit_pat(&input);

    if let Some(tokens) = visitor.pat_or {
        let msg = "`variant` cannot match `or` patterns";
        let err = Error::new_spanned(tokens, msg).to_compile_error();
        return TokenStream::from(err);
    }

    let tokens = match visitor.idents.as_slice() {
        [id] => quote! {
            #id
        },
        ids => quote! {
            (#(#ids),*)
        },
    };

    TokenStream::from(tokens)
}