genawaiter2_proc_macro/
lib.rs

1#![warn(future_incompatible, rust_2018_compatibility, rust_2018_idioms, unused)]
2#![warn(clippy::cargo, clippy::pedantic)]
3#![cfg_attr(feature = "strict", deny(warnings))]
4
5use crate::visit::YieldReplace;
6use proc_macro::TokenStream;
7use proc_macro_error::{abort, abort_call_site, proc_macro_error};
8use quote::quote;
9use std::string::ToString;
10use syn::{
11    self,
12    parse_macro_input,
13    parse_str,
14    spanned::Spanned,
15    visit_mut::VisitMut,
16    ExprBlock,
17    FnArg,
18    Ident,
19    ItemFn,
20    Type,
21};
22
23mod visit;
24
25#[proc_macro_attribute]
26#[proc_macro_error]
27pub fn stack_producer_fn(args: TokenStream, input: TokenStream) -> TokenStream {
28    let a = args.clone();
29    // make sure it is a valid type
30    let _ = parse_macro_input!(a as Type);
31    let mut function = parse_macro_input!(input as ItemFn);
32
33    let co_arg = format!("{}{}>", stack::CO_ARG_FN, args);
34    add_coroutine_arg(&mut function, &co_arg);
35
36    YieldReplace.visit_item_fn_mut(&mut function);
37
38    let tokens = quote! { #function };
39    tokens.into()
40}
41
42#[proc_macro]
43#[proc_macro_error]
44pub fn stack_producer(input: TokenStream) -> TokenStream {
45    let mut input = parse_macro_input!(input as ExprBlock);
46
47    YieldReplace.visit_expr_block_mut(&mut input);
48    // for some reason parsing as a PatType (correct for closures) fails
49    // the only way around is to destructure.
50    let arg = match parse_str::<FnArg>(stack::CO_ARG) {
51        Ok(FnArg::Typed(x)) => x,
52        _ => abort_call_site!("string Pat parse failed Co<...>"),
53    };
54
55    let tokens = quote! { |#arg| async move #input };
56    tokens.into()
57}
58
59#[proc_macro_attribute]
60#[proc_macro_error]
61pub fn sync_producer_fn(args: TokenStream, input: TokenStream) -> TokenStream {
62    let a = args.clone();
63    // make sure it is a valid type
64    let _ = parse_macro_input!(a as Type);
65    let mut function = parse_macro_input!(input as ItemFn);
66
67    let co_arg = format!("{}{}>", sync::CO_ARG_FN, args);
68    add_coroutine_arg(&mut function, &co_arg);
69
70    YieldReplace.visit_item_fn_mut(&mut function);
71
72    let tokens = quote! { #function };
73    tokens.into()
74}
75
76#[proc_macro]
77#[proc_macro_error]
78pub fn sync_producer(input: TokenStream) -> TokenStream {
79    let mut input = parse_macro_input!(input as ExprBlock);
80
81    YieldReplace.visit_expr_block_mut(&mut input);
82    // for some reason parsing as a PatType (correct for closures) fails
83    let arg = match parse_str::<FnArg>(sync::CO_ARG) {
84        Ok(FnArg::Typed(x)) => x,
85        _ => abort_call_site!("string Pat parse failed Co<...>"),
86    };
87
88    let tokens = quote! { |#arg| async move #input };
89    tokens.into()
90}
91
92#[proc_macro_attribute]
93#[proc_macro_error]
94pub fn rc_producer_fn(args: TokenStream, input: TokenStream) -> TokenStream {
95    let a = args.clone();
96    // make sure it is a valid type
97    let _ = parse_macro_input!(a as Type);
98    let mut function = parse_macro_input!(input as ItemFn);
99
100    let co_arg = format!("{}{}>", rc::CO_ARG_FN, args);
101    add_coroutine_arg(&mut function, &co_arg);
102
103    YieldReplace.visit_item_fn_mut(&mut function);
104
105    let tokens = quote! { #function };
106    tokens.into()
107}
108
109#[proc_macro]
110#[proc_macro_error]
111pub fn rc_producer(input: TokenStream) -> TokenStream {
112    let mut input = parse_macro_input!(input as ExprBlock);
113
114    YieldReplace.visit_expr_block_mut(&mut input);
115    // for some reason parsing as a PatType (correct for closures) fails
116    let arg = match parse_str::<FnArg>(rc::CO_ARG) {
117        Ok(FnArg::Typed(x)) => x,
118        _ => abort_call_site!("string Pat parse failed Co<...>"),
119    };
120
121    let tokens = quote! { |#arg| async move #input };
122    tokens.into()
123}
124
/// Source text of the hidden coroutine parameter for `genawaiter2::stack`.
mod stack {
    /// Open-ended prefix used by `stack_producer_fn`. The caller appends the
    /// attribute's type argument and a closing `>`.
    pub(crate) const CO_ARG_FN: &str =
        concat!("mut __private_co_arg__: ", "genawaiter2::stack::Co<'_, ");

    /// Complete closure parameter used by `stack_producer`, with both type
    /// arguments left as `_` for inference.
    pub(crate) const CO_ARG: &str =
        concat!("mut __private_co_arg__: ", "genawaiter2::stack::Co<'_, _, _>");
}
131
/// Source text of the hidden coroutine parameter for `genawaiter2::sync`.
mod sync {
    /// Open-ended prefix used by `sync_producer_fn`. The caller appends the
    /// attribute's type argument and a closing `>`.
    pub(crate) const CO_ARG_FN: &str =
        concat!("mut __private_co_arg__: ", "genawaiter2::sync::Co<");

    /// Complete closure parameter used by `sync_producer`, with both type
    /// arguments left as `_` for inference.
    pub(crate) const CO_ARG: &str =
        concat!("mut __private_co_arg__: ", "genawaiter2::sync::Co<_, _>");
}
138
/// Source text of the hidden coroutine parameter for `genawaiter2::rc`.
mod rc {
    /// Open-ended prefix used by `rc_producer_fn`. The caller appends the
    /// attribute's type argument and a closing `>`.
    pub(crate) const CO_ARG_FN: &str =
        concat!("mut __private_co_arg__: ", "genawaiter2::rc::Co<");

    /// Complete closure parameter used by `rc_producer`, with both type
    /// arguments left as `_` for inference.
    pub(crate) const CO_ARG: &str =
        concat!("mut __private_co_arg__: ", "genawaiter2::rc::Co<_, _>");
}
144
145/// Mutates the input `Punctuated<FnArg, Comma>` to a lifetimeless `co:
146/// Co<{type}>`.
147fn add_coroutine_arg(func: &mut ItemFn, co_ty: &str) {
148    let co_arg_found = func.sig.inputs.iter().any(|input| {
149        match input {
150            FnArg::Receiver(_) => false,
151            FnArg::Typed(arg) => {
152                match &*arg.ty {
153                    Type::Path(ty) => {
154                        ty.path.segments.iter().any(|seg| {
155                            seg.ident
156                                == parse_str::<Ident>("Co").expect("Ident parse failed")
157                        })
158                    }
159                    _ => false,
160                }
161            }
162        }
163    });
164    if !co_arg_found {
165        let co_arg: FnArg = match parse_str::<FnArg>(co_ty) {
166            Ok(s) => s,
167            Err(err) => abort_call_site!(format!("invalid type for Co yield {}", err)),
168        };
169        func.sig.inputs.push_value(co_arg)
170    } else {
171        abort!(
172            func.sig.span(),
173            "A generator producer cannot accept any arguments. Instead, consider \
174             using a closure and capturing the values you need.",
175        )
176    }
177}