flaky_test_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::Parser as _;
4use syn::punctuated::Punctuated;
5use syn::Attribute;
6use syn::ItemFn;
7use syn::Lit;
8use syn::Meta;
9use syn::MetaList;
10use syn::MetaNameValue;
11use syn::NestedMeta;
12use syn::Token;
13
14struct FlakyTestArgs {
15  times: usize,
16  runtime: Runtime,
17}
18
19enum Runtime {
20  Sync,
21  Tokio(Option<Punctuated<NestedMeta, Token![,]>>),
22}
23
24impl Default for FlakyTestArgs {
25  fn default() -> Self {
26    FlakyTestArgs {
27      times: 3,
28      runtime: Runtime::Sync,
29    }
30  }
31}
32
33fn parse_attr(attr: proc_macro2::TokenStream) -> syn::Result<FlakyTestArgs> {
34  let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
35  let punctuated = parser.parse2(attr)?;
36
37  let mut ret = FlakyTestArgs::default();
38
39  for meta in punctuated {
40    match meta {
41      Meta::Path(path) => {
42        if path.is_ident("tokio") {
43          ret.runtime = Runtime::Tokio(None);
44        } else {
45          return Err(syn::Error::new_spanned(path, "expected `tokio`"));
46        }
47      }
48      Meta::NameValue(MetaNameValue {
49        path,
50        lit: Lit::Int(lit_int),
51        ..
52      }) => {
53        if path.is_ident("times") {
54          ret.times = lit_int.base10_parse::<usize>()?;
55        } else {
56          return Err(syn::Error::new_spanned(
57            path,
58            "expected `times = <int>`",
59          ));
60        }
61      }
62      Meta::List(MetaList { path, nested, .. }) => {
63        if path.is_ident("tokio") {
64          ret.runtime = Runtime::Tokio(Some(nested));
65        } else {
66          return Err(syn::Error::new_spanned(path, "expected `tokio`"));
67        }
68      }
69      _ => {
70        return Err(syn::Error::new_spanned(
71          meta,
72          "expected `times = <int>` or `tokio`",
73        ));
74      }
75    }
76  }
77
78  Ok(ret)
79}
80
81/// A flaky test will be run multiple times until it passes.
82///
83/// # Example
84///
85/// ```rust
86/// use flaky_test::flaky_test;
87///
88/// // By default it will be retried up to 3 times.
89/// #[flaky_test]
90/// fn test_default() {
91///  println!("should pass");
92/// }
93///
94/// // The number of max attempts can be adjusted via `times`.
95/// #[flaky_test(times = 5)]
96/// fn usage_with_named_args() {
97///   println!("should pass");
98/// }
99///
100/// # use std::convert::Infallible;
101/// # async fn async_operation() -> Result<i32, Infallible> {
102/// #   Ok(42)
103/// # }
104/// // Async tests can be run by passing `tokio`.
105/// // Make sure `tokio` is added in your `Cargo.toml`.
106/// #[flaky_test(tokio)]
107/// async fn async_test() {
108///   let res = async_operation().await.unwrap();
109///   assert_eq!(res, 42);
110/// }
111///
112/// // `tokio` and `times` can be combined.
113/// #[flaky_test(tokio, times = 5)]
114/// async fn async_test_five_times() {
115///   let res = async_operation().await.unwrap();
116///   assert_eq!(res, 42);
117/// }
118///
119/// // Any arguments that `#[tokio::test]` supports can be specified.
120/// #[flaky_test(tokio(flavor = "multi_thraed", worker_threads = 2))]
121/// async fn async_test_complex() {
122///   let res = async_operation().await.unwrap();
123///   assert_eq!(res, 42);
124/// }
125/// ```
126#[proc_macro_attribute]
127pub fn flaky_test(attr: TokenStream, input: TokenStream) -> TokenStream {
128  let attr = proc_macro2::TokenStream::from(attr);
129  let mut input = proc_macro2::TokenStream::from(input);
130
131  match inner(attr, input.clone()) {
132    Err(e) => {
133      input.extend(e.into_compile_error());
134      input.into()
135    }
136    Ok(t) => t.into(),
137  }
138}
139
140fn inner(
141  attr: proc_macro2::TokenStream,
142  input: proc_macro2::TokenStream,
143) -> syn::Result<proc_macro2::TokenStream> {
144  let args = parse_attr(attr)?;
145  let input_fn: ItemFn = syn::parse2(input)?;
146  let attrs = input_fn.attrs.clone();
147
148  match args.runtime {
149    Runtime::Sync => sync(input_fn, attrs, args.times),
150    Runtime::Tokio(tokio_args) => {
151      tokio(input_fn, attrs, args.times, tokio_args)
152    }
153  }
154}
155
156fn sync(
157  input_fn: ItemFn,
158  attrs: Vec<Attribute>,
159  times: usize,
160) -> syn::Result<proc_macro2::TokenStream> {
161  let fn_name = input_fn.sig.ident.clone();
162
163  Ok(quote! {
164    #[test]
165    #(#attrs)*
166    fn #fn_name() {
167      #input_fn
168
169      for i in 0..#times {
170        println!("flaky_test retry {}", i);
171        let r = ::std::panic::catch_unwind(|| {
172          #fn_name();
173        });
174        if r.is_ok() {
175          return;
176        }
177        if i == #times - 1 {
178          ::std::panic::resume_unwind(r.unwrap_err());
179        }
180      }
181    }
182  })
183}
184
185fn tokio(
186  input_fn: ItemFn,
187  attrs: Vec<Attribute>,
188  times: usize,
189  tokio_args: Option<Punctuated<NestedMeta, Token![,]>>,
190) -> syn::Result<proc_macro2::TokenStream> {
191  if input_fn.sig.asyncness.is_none() {
192    return Err(syn::Error::new_spanned(input_fn.sig, "must be `async fn`"));
193  }
194
195  let fn_name = input_fn.sig.ident.clone();
196  let tokio_macro = match tokio_args {
197    Some(args) => quote! { #[::tokio::test(#args)] },
198    None => quote! { #[::tokio::test] },
199  };
200
201  Ok(quote! {
202    #tokio_macro
203    #(#attrs)*
204    async fn #fn_name() {
205      #input_fn
206
207      for i in 0..#times {
208        println!("flaky_test retry {}", i);
209        let fut = ::std::panic::AssertUnwindSafe(#fn_name());
210        let r = <_ as ::flaky_test::futures_util::future::FutureExt>::catch_unwind(fut).await;
211        if r.is_ok() {
212          return;
213        }
214        if i == #times - 1 {
215          ::std::panic::resume_unwind(r.unwrap_err());
216        }
217      }
218    }
219  })
220}