asyncs_test/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Ident, Span, TokenStream};
4use quote::quote;
5use syn::parse::Parser;
6use syn::punctuated::Punctuated;
7use syn::Attribute;
8
9type AttributeArgs = Punctuated<syn::Meta, syn::Token![,]>;
10
11#[derive(Default)]
12struct Configuration {
13    crate_name: Option<Ident>,
14    parallelism: Option<usize>,
15    send: Option<bool>,
16}
17
18impl Configuration {
19    fn set_send(&mut self, lit: &syn::Lit) -> Result<(), syn::Error> {
20        let span = lit.span();
21        if self.send.is_some() {
22            return Err(syn::Error::new(span, "`send` already set"));
23        }
24        if let syn::Lit::Bool(lit) = lit {
25            self.send = Some(lit.value);
26            return Ok(());
27        }
28        Err(syn::Error::new(span, "invalid `send` value, bool required"))
29    }
30
31    fn set_crate_name(&mut self, lit: &syn::Lit) -> Result<(), syn::Error> {
32        let span = lit.span();
33        if self.crate_name.is_some() {
34            return Err(syn::Error::new(span, "crate name already set"));
35        }
36        if let syn::Lit::Str(s) = lit {
37            if let Ok(path) = s.parse::<syn::Path>() {
38                if let Some(ident) = path.get_ident() {
39                    self.crate_name = Some(ident.clone());
40                    return Ok(());
41                }
42            }
43            return Err(syn::Error::new(span, format!("invalid crate name: {}", s.value())));
44        }
45        Err(syn::Error::new(span, "invalid crate name"))
46    }
47
48    fn set_parallelism(&mut self, lit: &syn::Lit) -> Result<(), syn::Error> {
49        let span = lit.span();
50        if self.parallelism.is_some() {
51            return Err(syn::Error::new(span, "parallelism already set"));
52        }
53        if let syn::Lit::Int(lit) = lit {
54            let parallelism = lit.base10_parse::<isize>()?;
55            if parallelism >= 0 {
56                self.parallelism = Some(parallelism as usize);
57                return Ok(());
58            }
59        }
60        Err(syn::Error::new(span, "parallelism should be non negative integer"))
61    }
62}
63
64fn parse_config(args: AttributeArgs) -> Result<Configuration, syn::Error> {
65    let mut config = Configuration::default();
66    for arg in args {
67        match arg {
68            syn::Meta::NameValue(name_value) => {
69                let name = name_value
70                    .path
71                    .get_ident()
72                    .ok_or_else(|| syn::Error::new_spanned(&name_value, "invalid attribute name"))?
73                    .to_string();
74                let lit = match &name_value.value {
75                    syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
76                    expr => return Err(syn::Error::new_spanned(expr, format!("{name} expect literal value"))),
77                };
78                match name.as_str() {
79                    "parallelism" => config.set_parallelism(lit)?,
80                    "crate" => config.set_crate_name(lit)?,
81                    "send" => config.set_send(lit)?,
82                    _ => return Err(syn::Error::new_spanned(&name_value, "unknown attribute name")),
83                }
84            },
85            _ => return Err(syn::Error::new_spanned(arg, "unknown attribute")),
86        }
87    }
88    Ok(config)
89}
90
91// Check whether given attribute is a test attribute of forms:
92// * `#[test]`
93// * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]`
94// * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]`
95fn is_test_attribute(attr: &Attribute) -> bool {
96    let path = match &attr.meta {
97        syn::Meta::Path(path) => path,
98        _ => return false,
99    };
100    let candidates = [["core", "prelude", "*", "test"], ["std", "prelude", "*", "test"]];
101    if path.leading_colon.is_none()
102        && path.segments.len() == 1
103        && path.segments[0].arguments.is_none()
104        && path.segments[0].ident == "test"
105    {
106        return true;
107    } else if path.segments.len() != candidates[0].len() {
108        return false;
109    }
110    candidates.into_iter().any(|segments| {
111        path.segments
112            .iter()
113            .zip(segments)
114            .all(|(segment, path)| segment.arguments.is_none() && (path == "*" || segment.ident == path))
115    })
116}
117
118fn generate(attr: TokenStream, item: TokenStream) -> TokenStream {
119    let config = AttributeArgs::parse_terminated.parse2(attr).and_then(parse_config).unwrap();
120
121    let input = syn::parse2::<syn::ItemFn>(item).unwrap();
122
123    let ret = &input.sig.output;
124    let name = &input.sig.ident;
125    let body = &input.block;
126    let attrs = &input.attrs;
127    let vis = &input.vis;
128
129    let crate_name = config.crate_name.unwrap_or_else(|| Ident::new("asyncs", Span::call_site()));
130    let macro_name = format!("#[{crate_name}:test]");
131
132    if input.sig.asyncness.is_none() {
133        let err = syn::Error::new_spanned(input, format!("only asynchronous function can be tagged with {macro_name}"));
134        return err.into_compile_error();
135    }
136
137    if let Some(attr) = attrs.clone().into_iter().find(is_test_attribute) {
138        let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes";
139        return syn::Error::new_spanned(attr, msg).into_compile_error();
140    };
141
142    let prefer_env_parallelism = config.parallelism.is_none();
143    let parallelism = config.parallelism.unwrap_or(2);
144    let parallelism = quote! {
145        let parallelism = match (#prefer_env_parallelism, #parallelism) {
146            (true, parallelism) => match ::std::env::var("ASYNCS_TEST_PARALLELISM") {
147                ::std::result::Result::Err(_) => parallelism,
148                ::std::result::Result::Ok(val) => match val.parse::<usize>() {
149                    ::std::result::Result::Err(_) => parallelism,
150                    ::std::result::Result::Ok(n) => n,
151                }
152            }
153            (false, parallelism) => parallelism,
154        };
155    };
156
157    let send = config.send.unwrap_or(true);
158    if send {
159        quote! {
160            #(#attrs)*
161            #[::core::prelude::v1::test]
162            #vis fn #name() #ret {
163                #parallelism
164                #crate_name::__executor::Blocking::new(parallelism).block_on(async move #body)
165            }
166        }
167    } else {
168        quote! {
169            #(#attrs)*
170            #[::core::prelude::v1::test]
171            #vis fn #name() #ret {
172                struct _Sendable<T>(T);
173
174                unsafe impl<T> Send for _Sendable<T> {}
175
176                impl<T: ::core::future::Future> ::core::future::Future for _Sendable<T> {
177                    type Output = T::Output;
178
179                    fn poll(self: ::core::pin::Pin<&mut Self>, cx: &mut ::core::task::Context<'_>) -> ::core::task::Poll<Self::Output> {
180                        let future = unsafe { ::core::pin::Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
181                        future.poll(cx)
182                    }
183                }
184
185                #parallelism
186                #crate_name::__executor::Blocking::new(parallelism).block_on(_Sendable(async move #body))
187            }
188        }
189    }
190}
191
192/// Converts async function to test against a sample runtime.
193///
194/// ## Options
195/// * `parallelism`: non negative integer to specify parallelism for executor. Defaults to
196///   environment variable `ASYNCS_TEST_PARALLELISM` and `2` in fallback. `0` means available
197///   cores.
198/// * `send`: whether the async function need to be `Send`. Defaults to `true`.
199///
200/// ## Examples
201/// ```ignore
202/// use std::future::pending;
203///
204/// #[asyncs::test]
205/// async fn pending_default() {
206///     let v = select! {
207///         default => 5,
208///         i = pending() => i,
209///     };
210///     assert_eq!(v, 5);
211/// }
212/// ```
213#[proc_macro_attribute]
214pub fn test(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
215    generate(attr.into(), item.into()).into()
216}