1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::Parser as _;
4use syn::punctuated::Punctuated;
5use syn::Attribute;
6use syn::ItemFn;
7use syn::Lit;
8use syn::Meta;
9use syn::MetaList;
10use syn::MetaNameValue;
11use syn::NestedMeta;
12use syn::Token;
13
14struct FlakyTestArgs {
15 times: usize,
16 runtime: Runtime,
17}
18
19enum Runtime {
20 Sync,
21 Tokio(Option<Punctuated<NestedMeta, Token![,]>>),
22}
23
24impl Default for FlakyTestArgs {
25 fn default() -> Self {
26 FlakyTestArgs {
27 times: 3,
28 runtime: Runtime::Sync,
29 }
30 }
31}
32
33fn parse_attr(attr: proc_macro2::TokenStream) -> syn::Result<FlakyTestArgs> {
34 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
35 let punctuated = parser.parse2(attr)?;
36
37 let mut ret = FlakyTestArgs::default();
38
39 for meta in punctuated {
40 match meta {
41 Meta::Path(path) => {
42 if path.is_ident("tokio") {
43 ret.runtime = Runtime::Tokio(None);
44 } else {
45 return Err(syn::Error::new_spanned(path, "expected `tokio`"));
46 }
47 }
48 Meta::NameValue(MetaNameValue {
49 path,
50 lit: Lit::Int(lit_int),
51 ..
52 }) => {
53 if path.is_ident("times") {
54 ret.times = lit_int.base10_parse::<usize>()?;
55 } else {
56 return Err(syn::Error::new_spanned(
57 path,
58 "expected `times = <int>`",
59 ));
60 }
61 }
62 Meta::List(MetaList { path, nested, .. }) => {
63 if path.is_ident("tokio") {
64 ret.runtime = Runtime::Tokio(Some(nested));
65 } else {
66 return Err(syn::Error::new_spanned(path, "expected `tokio`"));
67 }
68 }
69 _ => {
70 return Err(syn::Error::new_spanned(
71 meta,
72 "expected `times = <int>` or `tokio`",
73 ));
74 }
75 }
76 }
77
78 Ok(ret)
79}
80
81#[proc_macro_attribute]
127pub fn flaky_test(attr: TokenStream, input: TokenStream) -> TokenStream {
128 let attr = proc_macro2::TokenStream::from(attr);
129 let mut input = proc_macro2::TokenStream::from(input);
130
131 match inner(attr, input.clone()) {
132 Err(e) => {
133 input.extend(e.into_compile_error());
134 input.into()
135 }
136 Ok(t) => t.into(),
137 }
138}
139
140fn inner(
141 attr: proc_macro2::TokenStream,
142 input: proc_macro2::TokenStream,
143) -> syn::Result<proc_macro2::TokenStream> {
144 let args = parse_attr(attr)?;
145 let input_fn: ItemFn = syn::parse2(input)?;
146 let attrs = input_fn.attrs.clone();
147
148 match args.runtime {
149 Runtime::Sync => sync(input_fn, attrs, args.times),
150 Runtime::Tokio(tokio_args) => {
151 tokio(input_fn, attrs, args.times, tokio_args)
152 }
153 }
154}
155
156fn sync(
157 input_fn: ItemFn,
158 attrs: Vec<Attribute>,
159 times: usize,
160) -> syn::Result<proc_macro2::TokenStream> {
161 let fn_name = input_fn.sig.ident.clone();
162
163 Ok(quote! {
164 #[test]
165 #(#attrs)*
166 fn #fn_name() {
167 #input_fn
168
169 for i in 0..#times {
170 println!("flaky_test retry {}", i);
171 let r = ::std::panic::catch_unwind(|| {
172 #fn_name();
173 });
174 if r.is_ok() {
175 return;
176 }
177 if i == #times - 1 {
178 ::std::panic::resume_unwind(r.unwrap_err());
179 }
180 }
181 }
182 })
183}
184
185fn tokio(
186 input_fn: ItemFn,
187 attrs: Vec<Attribute>,
188 times: usize,
189 tokio_args: Option<Punctuated<NestedMeta, Token![,]>>,
190) -> syn::Result<proc_macro2::TokenStream> {
191 if input_fn.sig.asyncness.is_none() {
192 return Err(syn::Error::new_spanned(input_fn.sig, "must be `async fn`"));
193 }
194
195 let fn_name = input_fn.sig.ident.clone();
196 let tokio_macro = match tokio_args {
197 Some(args) => quote! { #[::tokio::test(#args)] },
198 None => quote! { #[::tokio::test] },
199 };
200
201 Ok(quote! {
202 #tokio_macro
203 #(#attrs)*
204 async fn #fn_name() {
205 #input_fn
206
207 for i in 0..#times {
208 println!("flaky_test retry {}", i);
209 let fut = ::std::panic::AssertUnwindSafe(#fn_name());
210 let r = <_ as ::flaky_test::futures_util::future::FutureExt>::catch_unwind(fut).await;
211 if r.is_ok() {
212 return;
213 }
214 if i == #times - 1 {
215 ::std::panic::resume_unwind(r.unwrap_err());
216 }
217 }
218 }
219 })
220}