1use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
2use proc_macro2::Span;
3use quote::{quote, quote_spanned};
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input, parse_quote, visit,
7 visit::Visit,
8 visit_mut,
9 visit_mut::VisitMut,
10 Block, Expr, ExprAsync, ExprAwait, ExprYield, ItemFn, Pat, PatWild, Stmt, Token,
11};
12
13#[proc_macro]
14pub fn generator(item: TokenStream) -> TokenStream {
15 let (mut block, pat) = {
16 let arg_version = item.clone();
17
18 match syn::parse::<ArgGenerator>(arg_version) {
19 Ok(gen) => (gen.block, gen.pattern),
20 Err(_) => {
21 let group = Group::new(Delimiter::Brace, item);
22 let stream = TokenTree::Group(group).into();
23 (
24 parse_macro_input!(stream as Block),
25 Pat::Wild(PatWild {
26 attrs: vec![],
27 underscore_token: <Token!(_)>::default(),
28 }),
29 )
30 }
31 }
32 };
33
34 let mut visitor = YieldVisitor::default();
35 visitor.visit_block(&block);
36
37 if visitor.errors.len() > 0 {
38 let errors = visitor.errors.into_iter().map(|(error, span)| {
39 quote_spanned! { span =>
40 compile_error!(#error);
41 }
42 });
43 let out = quote! {
44 {
45 #(#errors)*
46 }
47 };
48 return out.into();
49 }
50
51 let type_hint = if visitor.found_exprs > visitor.found_statement_exprs {
52 quote! { _ }
53 } else {
54 quote! { () }
55 };
56
57 let mut visitor = BlockVisitor {};
58 visitor.visit_block_mut(&mut block);
59
60 let tokens = quote! {
61 {
62 use ::generate::{Generator, GeneratorState, __support};
63
64 let (mut __resume, mut __yield) = __support::generator_mem::<#type_hint, _>();
65
66 let __await_resume = __resume.clone();
67 let __await_yield = __yield.clone();
68 let __yield_awaiter = move |val| __support::yield_future(__await_resume.clone(), __await_yield.clone(), val);
69
70 let build = move |#pat| async move {
71 #block
72 };
73
74 __support::generator_for(__resume, __yield, build)
75 }
76 };
77
78 tokens.into()
79}
80
81#[allow(unused)]
82struct ArgGenerator {
83 left_or: Token![|],
84 pattern: Pat,
85 right_or: Token![|],
86 block: Block,
87}
88
89impl Parse for ArgGenerator {
90 fn parse(input: ParseStream) -> syn::parse::Result<Self> {
91 Ok(ArgGenerator {
92 left_or: input.parse()?,
93 pattern: input.parse()?,
94 right_or: input.parse()?,
95 block: Block {
96 brace_token: Default::default(),
97 stmts: Block::parse_within(input)?,
98 },
99 })
100 }
101}
102
103#[derive(Default)]
104struct YieldVisitor {
105 found_exprs: usize,
106 found_statement_exprs: usize,
107 errors: Vec<(String, Span)>,
108}
109
110impl<'a> Visit<'a> for YieldVisitor {
111 fn visit_stmt(&mut self, i: &'a Stmt) {
112 if let Stmt::Semi(expr, _) = i {
113 if let Expr::Yield(_) = expr {
114 self.found_statement_exprs += 1
115 }
116 }
117
118 visit::visit_stmt(self, i)
119 }
120
121 fn visit_expr_yield(&mut self, i: &'a ExprYield) {
122 self.found_exprs += 1;
123
124 visit::visit_expr_yield(self, i)
125 }
126
127 fn visit_expr_await(&mut self, i: &'a ExprAwait) {
128 self.errors.push((
129 format!("Await must not be used inside of a generator"),
130 i.await_token.span,
131 ))
132 }
133
134 fn visit_expr_async(&mut self, _i: &'a ExprAsync) {
135 }
139
140 fn visit_item_fn(&mut self, _i: &'a ItemFn) {
141 }
144}
145
/// Stateless mutating pass that rewrites each `yield` in the generator body
/// into an `.await` on `__yield_awaiter`.
struct BlockVisitor;
147
148impl VisitMut for BlockVisitor {
149 fn visit_expr_mut(&mut self, i: &mut Expr) {
150 if let Expr::Yield(expr) = i {
151 let yield_expr = expr
152 .expr
153 .take()
154 .unwrap_or_else(|| Box::new(parse_quote! {()}));
155
156 *i = Expr::Await(ExprAwait {
157 attrs: std::mem::replace(&mut expr.attrs, vec![]),
158 await_token: <Token!(await)>::default(),
159 dot_token: <Token!(.)>::default(),
160 base: parse_quote! {
161 __yield_awaiter(#yield_expr)
162 },
163 })
164 }
165
166 visit_mut::visit_expr_mut(self, i)
167 }
168}