1extern crate proc_macro;
2
3use proc_macro2::{Ident, Span, TokenStream};
4use quote::quote;
5use syn::parse::Parser;
6use syn::punctuated::Punctuated;
7use syn::Attribute;
8
9type AttributeArgs = Punctuated<syn::Meta, syn::Token![,]>;
10
11#[derive(Default)]
12struct Configuration {
13 crate_name: Option<Ident>,
14 parallelism: Option<usize>,
15 send: Option<bool>,
16}
17
18impl Configuration {
19 fn set_send(&mut self, lit: &syn::Lit) -> Result<(), syn::Error> {
20 let span = lit.span();
21 if self.send.is_some() {
22 return Err(syn::Error::new(span, "`send` already set"));
23 }
24 if let syn::Lit::Bool(lit) = lit {
25 self.send = Some(lit.value);
26 return Ok(());
27 }
28 Err(syn::Error::new(span, "invalid `send` value, bool required"))
29 }
30
31 fn set_crate_name(&mut self, lit: &syn::Lit) -> Result<(), syn::Error> {
32 let span = lit.span();
33 if self.crate_name.is_some() {
34 return Err(syn::Error::new(span, "crate name already set"));
35 }
36 if let syn::Lit::Str(s) = lit {
37 if let Ok(path) = s.parse::<syn::Path>() {
38 if let Some(ident) = path.get_ident() {
39 self.crate_name = Some(ident.clone());
40 return Ok(());
41 }
42 }
43 return Err(syn::Error::new(span, format!("invalid crate name: {}", s.value())));
44 }
45 Err(syn::Error::new(span, "invalid crate name"))
46 }
47
48 fn set_parallelism(&mut self, lit: &syn::Lit) -> Result<(), syn::Error> {
49 let span = lit.span();
50 if self.parallelism.is_some() {
51 return Err(syn::Error::new(span, "parallelism already set"));
52 }
53 if let syn::Lit::Int(lit) = lit {
54 let parallelism = lit.base10_parse::<isize>()?;
55 if parallelism >= 0 {
56 self.parallelism = Some(parallelism as usize);
57 return Ok(());
58 }
59 }
60 Err(syn::Error::new(span, "parallelism should be non negative integer"))
61 }
62}
63
64fn parse_config(args: AttributeArgs) -> Result<Configuration, syn::Error> {
65 let mut config = Configuration::default();
66 for arg in args {
67 match arg {
68 syn::Meta::NameValue(name_value) => {
69 let name = name_value
70 .path
71 .get_ident()
72 .ok_or_else(|| syn::Error::new_spanned(&name_value, "invalid attribute name"))?
73 .to_string();
74 let lit = match &name_value.value {
75 syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
76 expr => return Err(syn::Error::new_spanned(expr, format!("{name} expect literal value"))),
77 };
78 match name.as_str() {
79 "parallelism" => config.set_parallelism(lit)?,
80 "crate" => config.set_crate_name(lit)?,
81 "send" => config.set_send(lit)?,
82 _ => return Err(syn::Error::new_spanned(&name_value, "unknown attribute name")),
83 }
84 },
85 _ => return Err(syn::Error::new_spanned(arg, "unknown attribute")),
86 }
87 }
88 Ok(config)
89}
90
91fn is_test_attribute(attr: &Attribute) -> bool {
96 let path = match &attr.meta {
97 syn::Meta::Path(path) => path,
98 _ => return false,
99 };
100 let candidates = [["core", "prelude", "*", "test"], ["std", "prelude", "*", "test"]];
101 if path.leading_colon.is_none()
102 && path.segments.len() == 1
103 && path.segments[0].arguments.is_none()
104 && path.segments[0].ident == "test"
105 {
106 return true;
107 } else if path.segments.len() != candidates[0].len() {
108 return false;
109 }
110 candidates.into_iter().any(|segments| {
111 path.segments
112 .iter()
113 .zip(segments)
114 .all(|(segment, path)| segment.arguments.is_none() && (path == "*" || segment.ident == path))
115 })
116}
117
118fn generate(attr: TokenStream, item: TokenStream) -> TokenStream {
119 let config = AttributeArgs::parse_terminated.parse2(attr).and_then(parse_config).unwrap();
120
121 let input = syn::parse2::<syn::ItemFn>(item).unwrap();
122
123 let ret = &input.sig.output;
124 let name = &input.sig.ident;
125 let body = &input.block;
126 let attrs = &input.attrs;
127 let vis = &input.vis;
128
129 let crate_name = config.crate_name.unwrap_or_else(|| Ident::new("asyncs", Span::call_site()));
130 let macro_name = format!("#[{crate_name}:test]");
131
132 if input.sig.asyncness.is_none() {
133 let err = syn::Error::new_spanned(input, format!("only asynchronous function can be tagged with {macro_name}"));
134 return err.into_compile_error();
135 }
136
137 if let Some(attr) = attrs.clone().into_iter().find(is_test_attribute) {
138 let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes";
139 return syn::Error::new_spanned(attr, msg).into_compile_error();
140 };
141
142 let prefer_env_parallelism = config.parallelism.is_none();
143 let parallelism = config.parallelism.unwrap_or(2);
144 let parallelism = quote! {
145 let parallelism = match (#prefer_env_parallelism, #parallelism) {
146 (true, parallelism) => match ::std::env::var("ASYNCS_TEST_PARALLELISM") {
147 ::std::result::Result::Err(_) => parallelism,
148 ::std::result::Result::Ok(val) => match val.parse::<usize>() {
149 ::std::result::Result::Err(_) => parallelism,
150 ::std::result::Result::Ok(n) => n,
151 }
152 }
153 (false, parallelism) => parallelism,
154 };
155 };
156
157 let send = config.send.unwrap_or(true);
158 if send {
159 quote! {
160 #(#attrs)*
161 #[::core::prelude::v1::test]
162 #vis fn #name() #ret {
163 #parallelism
164 #crate_name::__executor::Blocking::new(parallelism).block_on(async move #body)
165 }
166 }
167 } else {
168 quote! {
169 #(#attrs)*
170 #[::core::prelude::v1::test]
171 #vis fn #name() #ret {
172 struct _Sendable<T>(T);
173
174 unsafe impl<T> Send for _Sendable<T> {}
175
176 impl<T: ::core::future::Future> ::core::future::Future for _Sendable<T> {
177 type Output = T::Output;
178
179 fn poll(self: ::core::pin::Pin<&mut Self>, cx: &mut ::core::task::Context<'_>) -> ::core::task::Poll<Self::Output> {
180 let future = unsafe { ::core::pin::Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
181 future.poll(cx)
182 }
183 }
184
185 #parallelism
186 #crate_name::__executor::Blocking::new(parallelism).block_on(_Sendable(async move #body))
187 }
188 }
189 }
190}
191
192#[proc_macro_attribute]
214pub fn test(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
215 generate(attr.into(), item.into()).into()
216}