1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{parse::Parse, visit_mut::VisitMut, Token};
5
6struct LifetimeAdder {
7 lifetime: syn::Lifetime,
8}
9
10impl VisitMut for LifetimeAdder {
11 fn visit_type_reference_mut(&mut self, i: &mut syn::TypeReference) {
12 if i.lifetime.is_none() {
13 i.lifetime = Some(self.lifetime.clone())
14 }
15 }
16}
17
18struct ReplaceCoroutineAwait {
19 resume_type: syn::Type,
20}
21
22impl VisitMut for ReplaceCoroutineAwait {
23 fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
24 match i {
25 syn::Expr::Await(ei) => {
26 assert!(ei.attrs.is_empty());
27
28 let resume_type = self.resume_type.clone();
29 let base = &ei.base;
30
31 *i = syn::parse_quote! {
32 {
33 let mut __coroutine = #base;
34 let mut __response: #resume_type = Default::default();
35
36 loop {
37 use ::core::{pin::Pin, ops::{Coroutine, CoroutineState}};
38
39 match unsafe { Pin::new_unchecked(&mut __coroutine) }.resume(__response) {
40 CoroutineState::Yielded(__request) => __response = yield __request.into(),
41 CoroutineState::Complete(__result) => break __result,
42 }
43 }
44 }
45 };
46 }
47 _ => syn::visit_mut::visit_expr_mut(self, i),
48 }
49 }
50}
51
52mod kw {
53 syn::custom_keyword!(lifetime);
54}
55
56struct CoroutineInput {
57 is_static: bool,
58
59 yield_type: Option<syn::Type>,
60 resume_type: Option<syn::Type>,
61 capture: CaptureMode,
62}
63
64enum CaptureMode {
65 Implicit,
66 Explicit(syn::PreciseCapture),
67 None,
68}
69
70impl Parse for CoroutineInput {
71 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
72 let mut output = Self {
73 is_static: false,
74 yield_type: None,
75 resume_type: None,
76 capture: CaptureMode::Implicit,
77 };
78
79 {
80 let lk = input.lookahead1();
81
82 if lk.peek(Token![static]) {
83 output.is_static = true;
84 input.parse::<Token![static]>()?;
85 if input.is_empty() {
86 return Ok(output);
87 } else {
88 input.parse::<Token![,]>()?;
89 input.parse::<Token![yield]>()?;
90 }
91 } else if lk.peek(Token![yield]) {
92 } else {
93 return Err(lk.error());
94 }
95 }
96
97 if input.is_empty() {
98 return Ok(output);
99 }
100
101 output.yield_type = Some(input.parse::<syn::Type>()?);
102
103 if input.is_empty() {
104 return Ok(output);
105 }
106
107 input.parse::<Token![->]>()?;
108 output.resume_type = Some(input.parse::<syn::Type>()?);
109
110 if input.is_empty() {
111 return Ok(output);
112 }
113
114 input.parse::<Token![,]>()?;
115
116 if input.parse::<Option<Token![!]>>()?.is_some() {
117 input.parse::<Token![use]>()?;
118 output.capture = CaptureMode::None;
119 } else {
120 output.capture = CaptureMode::Explicit(input.parse::<syn::PreciseCapture>()?);
121 }
122
123 Ok(output)
124 }
125}
126
127struct BareItemFn {
128 attrs: Vec<syn::Attribute>,
129 vis: syn::Visibility,
130 sig: syn::Signature,
131 block: TokenStream2,
132}
133
134impl Parse for BareItemFn {
135 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
136 Ok(Self {
137 attrs: syn::Attribute::parse_outer(input)?,
138 vis: syn::Visibility::parse(input)?,
139 sig: syn::Signature::parse(input)?,
140 block: input.parse::<TokenStream2>()?,
141 })
142 }
143}
144
145#[proc_macro_attribute]
146pub fn generator(attr_ts: TokenStream, ts: TokenStream) -> TokenStream {
147 let func_result = syn::parse::<BareItemFn>(ts.clone());
148 let input_result = match attr_ts.is_empty() {
149 true => Ok(None),
150 false => syn::parse::<CoroutineInput>(attr_ts).map(Some),
151 };
152
153 if let Err(err) = &input_result {
154 panic!("{err}");
155 }
156
157 if let (Ok(func), Ok(input)) = (func_result, input_result) {
158 let unit_type: syn::Type = syn::parse_quote!(());
159
160 let attrs = func.attrs;
161 let vis = func.vis;
162 let name = func.sig.ident;
163 let mut generics = func.sig.generics;
164 let mut args = func.sig.inputs;
165 let return_type = match func.sig.output {
166 syn::ReturnType::Default => &unit_type,
167 syn::ReturnType::Type(_, ref tp) => tp,
168 };
169
170 let implicit_lifetime = if input
171 .as_ref()
172 .is_none_or(|x| matches!(x.capture, CaptureMode::Implicit))
173 {
174 let lt = syn::Lifetime::new("'__coroutine", Span::call_site());
175 generics.params.insert(0, syn::parse_quote!(#lt));
176 Some(lt)
177 } else {
178 None
179 };
180
181 if let Some(implicit_lifetime) = implicit_lifetime.clone() {
182 let mut ladder = LifetimeAdder {
183 lifetime: implicit_lifetime.clone(),
184 };
185 for arg in args.iter_mut() {
186 match arg {
187 syn::FnArg::Receiver(recv) => {
188 if let Some((_, lifetime @ None)) = &mut recv.reference {
189 *lifetime = Some(implicit_lifetime.clone())
190 }
191 }
192 syn::FnArg::Typed(pat) => ladder.visit_pat_type_mut(pat),
193 }
194 }
195 }
196
197 let (yield_type, resume_type) = {
198 let opts = input
199 .as_ref()
200 .map(|x| (x.yield_type.clone(), x.resume_type.clone()))
201 .unwrap_or_default();
202
203 (
204 opts.0.unwrap_or_else(|| unit_type.clone()),
205 opts.1.unwrap_or_else(|| unit_type.clone()),
206 )
207 };
208
209 let generic_params = generics.params;
210 let where_clause = generics.where_clause;
211 let precise_captures = match input
212 .as_ref()
213 .map(|x| &x.capture)
214 .unwrap_or(&CaptureMode::Implicit)
215 {
216 CaptureMode::Implicit => {
217 let lifetime = implicit_lifetime.as_ref().unwrap();
218 quote! { + use<#lifetime> }
219 }
220 CaptureMode::Explicit(precise_capture) => {
221 quote! { + #precise_capture }
222 }
223 CaptureMode::None => TokenStream2::new(),
224 };
225
226 let new_body = if let Ok(mut block) = syn::parse2::<syn::Block>(func.block.clone()) {
227 ReplaceCoroutineAwait {
228 resume_type: resume_type.clone(),
229 }
230 .visit_block_mut(&mut block);
231
232 let maybe_static = input
233 .map(|x| {
234 if x.is_static {
235 quote!(static)
236 } else {
237 quote!()
238 }
239 })
240 .unwrap_or(quote!());
241
242 quote!({
243 #[coroutine] #maybe_static move |_: #resume_type| #block
244 })
245 } else {
246 func.block
247 };
248
249 quote! {
250 #(#attrs)*
251 #vis fn #name<#generic_params>(#args) -> impl ::core::ops::Coroutine<
252 #resume_type,
253 Yield = #yield_type,
254 Return = #return_type
255 > #precise_captures #where_clause #new_body
256 }
257 .into()
258 } else {
259 ts
260 }
261}