Skip to main content

smol_potat_derive/
lib.rs

1#![forbid(unsafe_code, future_incompatible, rust_2018_idioms)]
2#![deny(missing_debug_implementations, nonstandard_style)]
3#![recursion_limit = "512"]
4
5use proc_macro::TokenStream;
6use quote::{quote, quote_spanned};
7use syn::spanned::Spanned;
8
9/// Enables an async main function.
10///
11/// # Examples
12///
13/// ## Single-Threaded
14///
15/// By default, this spawns the single thread executor.
16///
17/// ```ignore
18/// #[smol_potat::main]
19/// async fn main() -> std::io::Result<()> {
20///     Ok(())
21/// }
22/// ```
23///
24/// ## Automatic Threadpool
25///
26/// Alternatively, `smol_potat::main` can used to automatically
27/// set the number of threads by adding the `auto` feature (off
28/// by default).
29///
30/// ```ignore
31/// #[smol_potat::main] // with 'auto' feature enabled
32/// async fn main() -> std::io::Result<()> {
33///     Ok(())
34/// }
35/// ```
36///
37/// ## Manually Configure Threads
38///
39/// To manually set the number of threads, add this to the attribute:
40///
41/// ```ignore
42/// #[smol_potat::main(threads=3)]
43/// async fn main() -> std::io::Result<()> {
44///     Ok(())
45/// }
46/// ```
47#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
48#[proc_macro_attribute]
49pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
50    let input = syn::parse_macro_input!(item as syn::ItemFn);
51    let args = syn::parse_macro_input!(attr as syn::AttributeArgs);
52
53    let ret = &input.sig.output;
54    let inputs = &input.sig.inputs;
55    let name = &input.sig.ident;
56    let body = &input.block;
57    let attrs = &input.attrs;
58    let mut threads = None;
59
60    for arg in args {
61        match arg {
62            syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => {
63                let ident = namevalue.path.get_ident();
64                if ident.is_none() {
65                    return TokenStream::from(quote_spanned! { ident.span() =>
66                        compile_error!("Must have specified ident"),
67                    });
68                }
69                match ident.unwrap().to_string().to_lowercase().as_str() {
70                    "threads" => {
71                        match &namevalue.lit {
72                            syn::Lit::Int(expr) => {
73                                let num = expr.base10_parse::<u32>().unwrap();
74                                if num > 1 {
75                                    threads = Some(num);
76                                }
77                            }
78                            _ => {
79                                return TokenStream::from(quote_spanned! { namevalue.span() =>
80                                    compile_error!("threads argument must be an int"),
81                                });
82                            }
83                        }
84                    }
85                    name => {
86                        return TokenStream::from(quote_spanned! { name.span() =>
87                            compile_error!("Unknown attribute pair {} is specified; expected: `threads`"),
88                        });
89                    }
90                }
91            }
92            other => {
93                return TokenStream::from(quote_spanned! { other.span() =>
94                    compile_error!("Unknown attribute inside the macro"),
95                });
96            }
97        }
98    }
99
100    if name != "main" {
101        return TokenStream::from(quote_spanned! { name.span() =>
102            compile_error!("only the main function can be tagged with #[smol::main]"),
103        });
104    }
105
106    if input.sig.asyncness.is_none() {
107        return TokenStream::from(quote_spanned! { input.span() =>
108            compile_error!("the async keyword is missing from the function declaration"),
109        });
110    }
111
112    let result = match threads {
113        Some(num) => quote! {
114            fn main() #ret {
115                #(#attrs)*
116                async fn main(#inputs) #ret {
117                    #body
118                }
119
120                struct Pending;
121
122                impl std::future::Future for Pending {
123                    type Output = ();
124                    fn poll(
125                        self: std::pin::Pin<&mut Self>,
126                        _cx: &mut std::task::Context<'_>,
127                    ) -> std::task::Poll<Self::Output> {
128                        std::task::Poll::Pending
129                    }
130                }
131
132                for _ in 0..#num {
133                    std::thread::spawn(|| smol_potat::run(Pending));
134                }
135    
136                smol_potat::block_on(async {
137                    main().await
138                })
139            }
140        },
141        #[cfg(feature = "auto")]
142        _ => quote! {
143            fn main() #ret {
144                #(#attrs)*
145                async fn main(#inputs) #ret {
146                    #body
147                }
148
149                struct Pending;
150
151                impl std::future::Future for Pending {
152                    type Output = ();
153                    fn poll(
154                        self: std::pin::Pin<&mut Self>,
155                        _cx: &mut std::task::Context<'_>,
156                    ) -> std::task::Poll<Self::Output> {
157                        std::task::Poll::Pending
158                    }
159                }
160
161                let num_cpus = smol_potat::num_cpus::get().max(1);
162
163                for _ in 0..num_cpus {
164                    std::thread::spawn(|| smol_potat::run(Pending));
165                }
166    
167                smol_potat::block_on(async {
168                    main().await
169                })
170            }
171        },
172        #[cfg(not(feature = "auto"))]
173        _ => quote! {
174            fn main() #ret {
175                #(#attrs)*
176                async fn main(#inputs) #ret {
177                    #body
178                }
179
180                smol_potat::run(async {
181                    main().await
182                })
183            }
184        }
185    };
186
187    result.into()
188}
189
190/// Enables an async test function.
191///
192/// # Examples
193///
194/// ```ignore
195/// #[smol_potat::test]
196/// async fn my_test() -> std::io::Result<()> {
197///     assert_eq!(2 * 2, 4);
198///     Ok(())
199/// }
200/// ```
201#[proc_macro_attribute]
202pub fn test(_attr: TokenStream, item: TokenStream) -> TokenStream {
203    let input = syn::parse_macro_input!(item as syn::ItemFn);
204
205    let ret = &input.sig.output;
206    let name = &input.sig.ident;
207    let body = &input.block;
208    let attrs = &input.attrs;
209
210    if input.sig.asyncness.is_none() {
211        return TokenStream::from(quote_spanned! { input.span() =>
212            compile_error!("the async keyword is missing from the function declaration"),
213        });
214    }
215
216    let result = quote! {
217        #[test]
218        #(#attrs)*
219        fn #name() #ret {
220            smol::run(async { #body })
221        }
222    };
223
224    result.into()
225}
226
227/// Enables an async benchmark function.
228///
229/// # Examples
230///
231/// ```ignore
232/// #![feature(test)]
233/// extern crate test;
234///
235/// #[smol_potat::bench]
236/// async fn bench() {
237///     println!("hello world");
238/// }
239/// ```
240#[proc_macro_attribute]
241pub fn bench(_attr: TokenStream, item: TokenStream) -> TokenStream {
242    let input = syn::parse_macro_input!(item as syn::ItemFn);
243
244    let ret = &input.sig.output;
245    let args = &input.sig.inputs;
246    let name = &input.sig.ident;
247    let body = &input.block;
248    let attrs = &input.attrs;
249
250    if input.sig.asyncness.is_none() {
251        return TokenStream::from(quote_spanned! { input.span() =>
252            compile_error!("the async keyword is missing from the function declaration"),
253        });
254    }
255
256    if !args.is_empty() {
257        return TokenStream::from(quote_spanned! { args.span() =>
258            compile_error!("async benchmarks don't take any arguments"),
259        });
260    }
261
262    let result = quote! {
263        #[bench]
264        #(#attrs)*
265        fn #name(b: &mut test::Bencher) #ret {
266            let _ = b.iter(|| {
267                smol::block_on(async {
268                    #body
269                })
270            });
271        }
272    };
273
274    result.into()
275}