generate_derive/
lib.rs

1use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
2use proc_macro2::Span;
3use quote::{quote, quote_spanned};
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, parse_quote, visit,
7    visit::Visit,
8    visit_mut,
9    visit_mut::VisitMut,
10    Block, Expr, ExprAsync, ExprAwait, ExprYield, ItemFn, Pat, PatWild, Stmt, Token,
11};
12
13#[proc_macro]
14pub fn generator(item: TokenStream) -> TokenStream {
15    let (mut block, pat) = {
16        let arg_version = item.clone();
17
18        match syn::parse::<ArgGenerator>(arg_version) {
19            Ok(gen) => (gen.block, gen.pattern),
20            Err(_) => {
21                let group = Group::new(Delimiter::Brace, item);
22                let stream = TokenTree::Group(group).into();
23                (
24                    parse_macro_input!(stream as Block),
25                    Pat::Wild(PatWild {
26                        attrs: vec![],
27                        underscore_token: <Token!(_)>::default(),
28                    }),
29                )
30            }
31        }
32    };
33
34    let mut visitor = YieldVisitor::default();
35    visitor.visit_block(&block);
36
37    if visitor.errors.len() > 0 {
38        let errors = visitor.errors.into_iter().map(|(error, span)| {
39            quote_spanned! { span =>
40                compile_error!(#error);
41            }
42        });
43        let out = quote! {
44            {
45                #(#errors)*
46            }
47        };
48        return out.into();
49    }
50
51    let type_hint = if visitor.found_exprs > visitor.found_statement_exprs {
52        quote! { _ }
53    } else {
54        quote! { () }
55    };
56
57    let mut visitor = BlockVisitor {};
58    visitor.visit_block_mut(&mut block);
59
60    let tokens = quote! {
61        {
62            use ::generate::{Generator, GeneratorState, __support};
63
64            let (mut __resume, mut __yield) = __support::generator_mem::<#type_hint, _>();
65
66            let __await_resume = __resume.clone();
67            let __await_yield = __yield.clone();
68            let __yield_awaiter = move |val| __support::yield_future(__await_resume.clone(), __await_yield.clone(), val);
69
70            let build = move |#pat| async move {
71                #block
72            };
73
74            __support::generator_for(__resume, __yield, build)
75        }
76    };
77
78    tokens.into()
79}
80
81#[allow(unused)]
82struct ArgGenerator {
83    left_or: Token![|],
84    pattern: Pat,
85    right_or: Token![|],
86    block: Block,
87}
88
89impl Parse for ArgGenerator {
90    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
91        Ok(ArgGenerator {
92            left_or: input.parse()?,
93            pattern: input.parse()?,
94            right_or: input.parse()?,
95            block: Block {
96                brace_token: Default::default(),
97                stmts: Block::parse_within(input)?,
98            },
99        })
100    }
101}
102
103#[derive(Default)]
104struct YieldVisitor {
105    found_exprs: usize,
106    found_statement_exprs: usize,
107    errors: Vec<(String, Span)>,
108}
109
110impl<'a> Visit<'a> for YieldVisitor {
111    fn visit_stmt(&mut self, i: &'a Stmt) {
112        if let Stmt::Semi(expr, _) = i {
113            if let Expr::Yield(_) = expr {
114                self.found_statement_exprs += 1
115            }
116        }
117
118        visit::visit_stmt(self, i)
119    }
120
121    fn visit_expr_yield(&mut self, i: &'a ExprYield) {
122        self.found_exprs += 1;
123
124        visit::visit_expr_yield(self, i)
125    }
126
127    fn visit_expr_await(&mut self, i: &'a ExprAwait) {
128        self.errors.push((
129            format!("Await must not be used inside of a generator"),
130            i.await_token.span,
131        ))
132    }
133
134    fn visit_expr_async(&mut self, _i: &'a ExprAsync) {
135        // Don't defer to the standard implementation.
136        // This is so that `await`s within `async`s aren't
137        // caught above.
138    }
139
140    fn visit_item_fn(&mut self, _i: &'a ItemFn) {
141        // Likewise, we don't care about the contents
142        // of locally defined functions.
143    }
144}
145
/// Stateless mutating visitor that rewrites `yield` expressions (see its
/// `VisitMut` implementation).
struct BlockVisitor;
147
148impl VisitMut for BlockVisitor {
149    fn visit_expr_mut(&mut self, i: &mut Expr) {
150        if let Expr::Yield(expr) = i {
151            let yield_expr = expr
152                .expr
153                .take()
154                .unwrap_or_else(|| Box::new(parse_quote! {()}));
155
156            *i = Expr::Await(ExprAwait {
157                attrs: std::mem::replace(&mut expr.attrs, vec![]),
158                await_token: <Token!(await)>::default(),
159                dot_token: <Token!(.)>::default(),
160                base: parse_quote! {
161                    __yield_awaiter(#yield_expr)
162                },
163            })
164        }
165
166        visit_mut::visit_expr_mut(self, i)
167    }
168}