silpkg_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{parse::Parse, visit_mut::VisitMut, Token};
5
6struct LifetimeAdder {
7    lifetime: syn::Lifetime,
8}
9
10impl VisitMut for LifetimeAdder {
11    fn visit_type_reference_mut(&mut self, i: &mut syn::TypeReference) {
12        if i.lifetime.is_none() {
13            i.lifetime = Some(self.lifetime.clone())
14        }
15    }
16}
17
18struct ReplaceCoroutineAwait {
19    resume_type: syn::Type,
20}
21
22impl VisitMut for ReplaceCoroutineAwait {
23    fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
24        match i {
25            syn::Expr::Await(ei) => {
26                assert!(ei.attrs.is_empty());
27
28                let resume_type = self.resume_type.clone();
29                let base = &ei.base;
30
31                *i = syn::parse_quote! {
32                    {
33                        let mut __coroutine = #base;
34                        let mut __response: #resume_type = Default::default();
35
36                        loop {
37                            use ::core::{pin::Pin, ops::{Coroutine, CoroutineState}};
38
39                            match unsafe { Pin::new_unchecked(&mut __coroutine) }.resume(__response) {
40                                CoroutineState::Yielded(__request) => __response = yield __request.into(),
41                                CoroutineState::Complete(__result) => break __result,
42                            }
43                        }
44                    }
45                };
46            }
47            _ => syn::visit_mut::visit_expr_mut(self, i),
48        }
49    }
50}
51
52mod kw {
53    syn::custom_keyword!(lifetime);
54}
55
56struct CoroutineInput {
57    is_static: bool,
58
59    yield_type: Option<syn::Type>,
60    resume_type: Option<syn::Type>,
61    capture: CaptureMode,
62}
63
64enum CaptureMode {
65    Implicit,
66    Explicit(syn::PreciseCapture),
67    None,
68}
69
70impl Parse for CoroutineInput {
71    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
72        let mut output = Self {
73            is_static: false,
74            yield_type: None,
75            resume_type: None,
76            capture: CaptureMode::Implicit,
77        };
78
79        {
80            let lk = input.lookahead1();
81
82            if lk.peek(Token![static]) {
83                output.is_static = true;
84                input.parse::<Token![static]>()?;
85                if input.is_empty() {
86                    return Ok(output);
87                } else {
88                    input.parse::<Token![,]>()?;
89                    input.parse::<Token![yield]>()?;
90                }
91            } else if lk.peek(Token![yield]) {
92            } else {
93                return Err(lk.error());
94            }
95        }
96
97        if input.is_empty() {
98            return Ok(output);
99        }
100
101        output.yield_type = Some(input.parse::<syn::Type>()?);
102
103        if input.is_empty() {
104            return Ok(output);
105        }
106
107        input.parse::<Token![->]>()?;
108        output.resume_type = Some(input.parse::<syn::Type>()?);
109
110        if input.is_empty() {
111            return Ok(output);
112        }
113
114        input.parse::<Token![,]>()?;
115
116        if input.parse::<Option<Token![!]>>()?.is_some() {
117            input.parse::<Token![use]>()?;
118            output.capture = CaptureMode::None;
119        } else {
120            output.capture = CaptureMode::Explicit(input.parse::<syn::PreciseCapture>()?);
121        }
122
123        Ok(output)
124    }
125}
126
127struct BareItemFn {
128    attrs: Vec<syn::Attribute>,
129    vis: syn::Visibility,
130    sig: syn::Signature,
131    block: TokenStream2,
132}
133
134impl Parse for BareItemFn {
135    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
136        Ok(Self {
137            attrs: syn::Attribute::parse_outer(input)?,
138            vis: syn::Visibility::parse(input)?,
139            sig: syn::Signature::parse(input)?,
140            block: input.parse::<TokenStream2>()?,
141        })
142    }
143}
144
145#[proc_macro_attribute]
146pub fn generator(attr_ts: TokenStream, ts: TokenStream) -> TokenStream {
147    let func_result = syn::parse::<BareItemFn>(ts.clone());
148    let input_result = match attr_ts.is_empty() {
149        true => Ok(None),
150        false => syn::parse::<CoroutineInput>(attr_ts).map(Some),
151    };
152
153    if let Err(err) = &input_result {
154        panic!("{err}");
155    }
156
157    if let (Ok(func), Ok(input)) = (func_result, input_result) {
158        let unit_type: syn::Type = syn::parse_quote!(());
159
160        let attrs = func.attrs;
161        let vis = func.vis;
162        let name = func.sig.ident;
163        let mut generics = func.sig.generics;
164        let mut args = func.sig.inputs;
165        let return_type = match func.sig.output {
166            syn::ReturnType::Default => &unit_type,
167            syn::ReturnType::Type(_, ref tp) => tp,
168        };
169
170        let implicit_lifetime = if input
171            .as_ref()
172            .is_none_or(|x| matches!(x.capture, CaptureMode::Implicit))
173        {
174            let lt = syn::Lifetime::new("'__coroutine", Span::call_site());
175            generics.params.insert(0, syn::parse_quote!(#lt));
176            Some(lt)
177        } else {
178            None
179        };
180
181        if let Some(implicit_lifetime) = implicit_lifetime.clone() {
182            let mut ladder = LifetimeAdder {
183                lifetime: implicit_lifetime.clone(),
184            };
185            for arg in args.iter_mut() {
186                match arg {
187                    syn::FnArg::Receiver(recv) => {
188                        if let Some((_, lifetime @ None)) = &mut recv.reference {
189                            *lifetime = Some(implicit_lifetime.clone())
190                        }
191                    }
192                    syn::FnArg::Typed(pat) => ladder.visit_pat_type_mut(pat),
193                }
194            }
195        }
196
197        let (yield_type, resume_type) = {
198            let opts = input
199                .as_ref()
200                .map(|x| (x.yield_type.clone(), x.resume_type.clone()))
201                .unwrap_or_default();
202
203            (
204                opts.0.unwrap_or_else(|| unit_type.clone()),
205                opts.1.unwrap_or_else(|| unit_type.clone()),
206            )
207        };
208
209        let generic_params = generics.params;
210        let where_clause = generics.where_clause;
211        let precise_captures = match input
212            .as_ref()
213            .map(|x| &x.capture)
214            .unwrap_or(&CaptureMode::Implicit)
215        {
216            CaptureMode::Implicit => {
217                let lifetime = implicit_lifetime.as_ref().unwrap();
218                quote! { + use<#lifetime> }
219            }
220            CaptureMode::Explicit(precise_capture) => {
221                quote! { + #precise_capture }
222            }
223            CaptureMode::None => TokenStream2::new(),
224        };
225
226        let new_body = if let Ok(mut block) = syn::parse2::<syn::Block>(func.block.clone()) {
227            ReplaceCoroutineAwait {
228                resume_type: resume_type.clone(),
229            }
230            .visit_block_mut(&mut block);
231
232            let maybe_static = input
233                .map(|x| {
234                    if x.is_static {
235                        quote!(static)
236                    } else {
237                        quote!()
238                    }
239                })
240                .unwrap_or(quote!());
241
242            quote!({
243                #[coroutine] #maybe_static move |_: #resume_type| #block
244            })
245        } else {
246            func.block
247        };
248
249        quote! {
250            #(#attrs)*
251            #vis fn #name<#generic_params>(#args) -> impl ::core::ops::Coroutine<
252                #resume_type,
253                Yield = #yield_type,
254                Return = #return_type
255            > #precise_captures #where_clause #new_body
256        }
257        .into()
258    } else {
259        ts
260    }
261}