jrest_hooks/
lib.rs

1//! This code is currently an (almost) exact copy of <https://github.com/TomPridham/test-env-helpers> source code.
2//!
3//! ## Getting Started
4//! Using these macros is fairly simple. The four after/before functions all require a function
5//! with the same name as the attribute and are only valid when applied to a mod. They are all used
6//! like in the below example. Replace `before_each` with whichever method you want to use. The
7//! code in the matching function will be inserted into every fn in the containing mod that has an
8//! attribute with the word "test" in it. This is to allow for use with not just normal `#[test]`
9//! attributes, but also other flavors like `#[tokio::test]` and `#[test_case(0)]`.
10//! ```
11//! #[cfg(test)]
12//! use test_env_helpers::*;
13//!
14//! #[before_each]
15//! #[cfg(test)]
16//! mod my_tests{
17//!     fn before_each(){println!("I'm in every test!")}
18//!     #[test]
19//!     fn test_1(){}
20//!     #[test]
21//!     fn test_2(){}
22//!     #[test]
23//!     fn test_3(){}
24//! }
25//! ```
26//!
27//! The `skip` macro is valid on either a mod or an individual test and will remove the mod or test
28//! it is applied to. You can use it to skip tests that aren't working correctly or that you don't
29//! want to run for some reason.
30//!
31//! ```
32//! #[cfg(test)]
33//! use test_env_helpers::*;
34//!
35//! #[cfg(test)]
36//! mod my_tests{
37//!     #[skip]
38//!     #[test]
39//!     fn broken_test(){panic!("I'm hella broke")}
40//!     #[skip]
41//!     mod broken_mod{
42//!         #[test]
43//!         fn i_will_not_be_run(){panic!("I get skipped too")}
44//!     }
45//!     #[test]
46//!     fn test_2(){}
47//!     #[test]
48//!     fn test_3(){}
49//! }
50//! ```
51
52extern crate proc_macro;
53mod utils;
54
55use crate::utils::traverse_use_item;
56
57use proc_macro::TokenStream;
58use quote::quote;
59use syn::{parse_macro_input, parse_quote, Item, Stmt};
60
61/// Will run the code in the matching `after_all` function exactly once when all of the tests have
62/// run. This works by counting the number of `#[test]` attributes and decrementing a counter at
63/// the beginning of every test. Once the counter reaches 0, it will run the code in `after_all`.
64/// It uses [std::sync::Once] internally to ensure that the code is run at maximum one time.
65///
66/// ```
67/// #[cfg(test)]
68/// use test_env_helpers::*;
69///
70/// #[after_all]
71/// #[cfg(test)]
72/// mod my_tests{
73///     fn after_all(){println!("I only get run once at the very end")}
74///     #[test]
75///     fn test_1(){}
76///     #[test]
77///     fn test_2(){}
78///     #[test]
79///     fn test_3(){}
80/// }
81/// ```
82#[proc_macro_attribute]
83pub fn after_all(_metadata: TokenStream, input: TokenStream) -> TokenStream {
84    let input: Item = match parse_macro_input!(input as Item) {
85        Item::Mod(mut m) => {
86            let (brace, items) = m.content.unwrap();
87            let (after_all_fn, everything_else): (Vec<Item>, Vec<Item>) =
88                items.into_iter().partition(|t| match t {
89                    Item::Fn(f) => f.sig.ident == "after_all",
90                    _ => false,
91                });
92            let after_all_fn_block = if after_all_fn.len() != 1 {
93                panic!("The `after_all` macro attribute requires a single function named `after_all` in the body of the module it is called on.")
94            } else {
95                match after_all_fn.into_iter().next().unwrap() {
96                    Item::Fn(f) => f.block,
97                    _ => unreachable!(),
98                }
99            };
100            let inc_count: Stmt = parse_quote!(
101                REMAINING_TESTS.fetch_sub(1, Ordering::SeqCst);
102            );
103            let after_all_if: Stmt = parse_quote! {
104                if REMAINING_TESTS.load(Ordering::SeqCst) <= 0{
105                    AFTER_ALL.call_once(|| {
106                        #after_all_fn_block
107                    });
108                }
109            };
110
111            let mut count: usize = 0;
112            let mut has_once: bool = false;
113            let mut has_atomic_usize: bool = false;
114            let mut has_ordering: bool = false;
115
116            let mut e: Vec<Item> = everything_else
117                .into_iter()
118                .map(|t| match t {
119                    Item::Fn(mut f) => {
120                        let test_count = f
121                            .attrs
122                            .iter()
123                            .filter(|attr| {
124                                attr.path()
125                                    .segments
126                                    .iter()
127                                    .any(|segment| segment.ident.to_string().contains("test"))
128                            })
129                            .count();
130                        if test_count > 0 {
131                            count += test_count;
132                            let mut stmts = vec![inc_count.clone()];
133                            stmts.append(&mut f.block.stmts);
134                            stmts.push(after_all_if.clone());
135                            f.block.stmts = stmts;
136                            Item::Fn(f)
137                        } else {
138                            Item::Fn(f)
139                        }
140                    }
141                    Item::Use(use_stmt) => {
142                        if traverse_use_item(&use_stmt.tree, vec!["std", "sync", "Once"]).is_some()
143                        {
144                            has_once = true;
145                        }
146                        if traverse_use_item(
147                            &use_stmt.tree,
148                            vec!["std", "sync", "atomic", "AtomicUsize"],
149                        )
150                        .is_some()
151                        {
152                            has_atomic_usize = true;
153                        }
154                        if traverse_use_item(
155                            &use_stmt.tree,
156                            vec!["std", "sync", "atomic", "Ordering"],
157                        )
158                        .is_some()
159                        {
160                            has_ordering = true;
161                        }
162                        Item::Use(use_stmt)
163                    }
164                    el => el,
165                })
166                .collect();
167
168            let use_once: Item = parse_quote!(
169                use std::sync::Once;
170            );
171            let use_atomic_usize: Item = parse_quote!(
172                use std::sync::atomic::AtomicUsize;
173            );
174            let use_ordering: Item = parse_quote!(
175                use std::sync::atomic::Ordering;
176            );
177            let static_once: Item = parse_quote!(
178                static AFTER_ALL: Once = Once::new();
179            );
180            let static_count: Item = parse_quote!(
181                static REMAINING_TESTS: AtomicUsize = AtomicUsize::new(#count);
182            );
183
184            let mut once_content = vec![];
185
186            if !has_once {
187                once_content.push(use_once);
188            }
189            if !has_atomic_usize {
190                once_content.push(use_atomic_usize);
191            }
192            if !has_ordering {
193                once_content.push(use_ordering);
194            }
195            once_content.append(&mut vec![static_once, static_count]);
196            once_content.append(&mut e);
197
198            m.content = Some((brace, once_content));
199            Item::Mod(m)
200        }
201        _ => {
202            panic!("The `after_all` macro attribute is only valid when called on a module.")
203        }
204    };
205    TokenStream::from(quote! (#input))
206}
207
208/// Will run the code in the matching `after_each` function at the end of every `#[test]` function.
209/// Useful if you want to cleanup after a test or reset some external state. If the test panics,
210/// this code will not be run. If you need something that is infallible, you should use
211/// `before_each` instead.
212/// ```
213/// #[cfg(test)]
214/// use test_env_helpers::*;
215///
216/// #[after_each]
217/// #[cfg(test)]
218/// mod my_tests{
219///     fn after_each(){println!("I get run at the very end of each function")}
220///     #[test]
221///     fn test_1(){}
222///     #[test]
223///     fn test_2(){}
224///     #[test]
225///     fn test_3(){}
226/// }
227/// ```
228#[proc_macro_attribute]
229pub fn after_each(_metadata: TokenStream, input: TokenStream) -> TokenStream {
230    let input: Item = match parse_macro_input!(input as Item) {
231        Item::Mod(mut m) => {
232            let (brace, items) = m.content.unwrap();
233            let (after_each_fn, everything_else): (Vec<Item>, Vec<Item>) =
234                items.into_iter().partition(|t| match t {
235                    Item::Fn(f) => f.sig.ident == "after_each",
236                    _ => false,
237                });
238            let after_each_fn_block = if after_each_fn.len() != 1 {
239                panic!("The `after_each` macro attribute requires a single function named `after_each` in the body of the module it is called on.")
240            } else {
241                match after_each_fn.into_iter().next().unwrap() {
242                    Item::Fn(f) => f.block,
243                    _ => unreachable!(),
244                }
245            };
246
247            let e: Vec<Item> = everything_else
248                .into_iter()
249                .map(|t| match t {
250                    Item::Fn(mut f) => {
251                        if f.attrs.iter().any(|attr| {
252                            attr.path()
253                                .segments
254                                .iter()
255                                .any(|segment| segment.ident.to_string().contains("test"))
256                        }) {
257                            f.block.stmts.append(&mut after_each_fn_block.stmts.clone());
258                            Item::Fn(f)
259                        } else {
260                            Item::Fn(f)
261                        }
262                    }
263                    e => e,
264                })
265                .collect();
266            m.content = Some((brace, e));
267            Item::Mod(m)
268        }
269
270        _ => {
271            panic!("The `after_each` macro attribute is only valid when called on a module.")
272        }
273    };
274    TokenStream::from(quote! {#input})
275}
276
277/// Will run the code in the matching `before_all` function exactly once at the very beginning of a
278/// test run. It uses [std::sync::Once](https://doc.rust-lang.org/std/sync/struct.Once.html) internally
279/// to ensure that the code is run at maximum one time. Useful for setting up some external state
280/// that will be reused in multiple tests.
281/// ```
282/// #[cfg(test)]
283/// use test_env_helpers::*;
284///
285/// #[before_all]
286/// #[cfg(test)]
287/// mod my_tests{
288///     fn before_all(){println!("I get run at the very beginning of the test suite")}
289///     #[test]
290///     fn test_1(){}
291///     #[test]
292///     fn test_2(){}
293///     #[test]
294///     fn test_3(){}
295/// }
296/// ```
297#[proc_macro_attribute]
298pub fn before_all(_metadata: TokenStream, input: TokenStream) -> TokenStream {
299    let input: Item = match parse_macro_input!(input as Item) {
300        Item::Mod(mut m) => {
301            let (brace, items) = m.content.unwrap();
302            let (before_all_fn, everything_else): (Vec<Item>, Vec<Item>) =
303                items.into_iter().partition(|t| match t {
304                    Item::Fn(f) => f.sig.ident == "before_all",
305                    _ => false,
306                });
307            let before_all_fn_block = if before_all_fn.len() != 1 {
308                panic!("The `before_all` macro attribute requires a single function named `before_all` in the body of the module it is called on.")
309            } else {
310                match before_all_fn.into_iter().next().unwrap() {
311                    Item::Fn(f) => f.block,
312                    _ => unreachable!(),
313                }
314            };
315            let q: Stmt = parse_quote! {
316                BEFORE_ALL.call_once(|| {
317                    #before_all_fn_block
318                });
319            };
320
321            let mut has_once: bool = false;
322            let mut e: Vec<Item> = everything_else
323                .into_iter()
324                .map(|t| match t {
325                    Item::Fn(mut f) => {
326                        if f.attrs.iter().any(|attr| {
327                            attr.path()
328                                .segments
329                                .iter()
330                                .any(|segment| segment.ident.to_string().contains("test"))
331                        }) {
332                            let mut stmts = vec![q.clone()];
333                            stmts.append(&mut f.block.stmts);
334                            f.block.stmts = stmts;
335                            Item::Fn(f)
336                        } else {
337                            Item::Fn(f)
338                        }
339                    }
340                    Item::Use(use_stmt) => {
341                        if traverse_use_item(&use_stmt.tree, vec!["std", "sync", "Once"]).is_some()
342                        {
343                            has_once = true;
344                        }
345                        Item::Use(use_stmt)
346                    }
347                    e => e,
348                })
349                .collect();
350            let use_once: Item = parse_quote!(
351                use std::sync::Once;
352            );
353            let static_once: Item = parse_quote!(
354                static BEFORE_ALL: Once = Once::new();
355            );
356
357            let mut once_content = vec![];
358            if !has_once {
359                once_content.push(use_once);
360            }
361            once_content.push(static_once);
362            once_content.append(&mut e);
363
364            m.content = Some((brace, once_content));
365            Item::Mod(m)
366        }
367
368        _ => {
369            panic!("The `before_all` macro attribute is only valid when called on a module.")
370        }
371    };
372    TokenStream::from(quote! (#input))
373}
374
375/// Will run the code in the matching `before_each` function at the beginning of every test. Useful
376/// to reset state to ensure that a test has a clean slate.
377/// ```
378/// #[cfg(test)]
379/// use test_env_helpers::*;
380///
381/// #[before_each]
382/// #[cfg(test)]
383/// mod my_tests{
384///     fn before_each(){println!("I get run at the very beginning of every test")}
385///     #[test]
386///     fn test_1(){}
387///     #[test]
388///     fn test_2(){}
389///     #[test]
390///     fn test_3(){}
391/// }
392/// ```
393///
394/// Can be used to reduce the amount of boilerplate setup code that needs to be copied into each test.
395/// For example, if you need to ensure that tests in a single test suite are not run in parallel, this can
396/// easily be done with a [Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html).
397/// However, remembering to copy and paste the code to acquire a lock on the `Mutex` in every test
398/// is tedious and error prone.
399/// ```
400/// #[cfg(test)]
401/// mod without_before_each{
402///     lazy_static! {
403///         static ref MTX: Mutex<()> = Mutex::new(());
404///     }
405///     #[test]
406///     fn test_1(){let _m = MTX.lock();}
407///     #[test]
408///     fn test_2(){let _m = MTX.lock();}
409///     #[test]
410///     fn test_3(){let _m = MTX.lock();}
411/// }
412/// ```
413/// Using `before_each` removes the need to copy and paste so much and makes making changes easier
414/// because they only need to be made in a single location instead of once for every test.
415/// ```
416/// #[cfg(test)]
417/// use test_env_helpers::*;
418///
419/// #[before_each]
420/// #[cfg(test)]
421/// mod with_before_each{
422///     lazy_static! {
423///         static ref MTX: Mutex<()> = Mutex::new(());
424///     }
425///     fn before_each(){let _m = MTX.lock();}
426///     #[test]
427///     fn test_1(){}
428///     #[test]
429///     fn test_2(){}
430///     #[test]
431///     fn test_3(){}
432/// }
433/// ```
434#[proc_macro_attribute]
435pub fn before_each(_metadata: TokenStream, input: TokenStream) -> TokenStream {
436    let input: Item = match parse_macro_input!(input as Item) {
437        Item::Mod(mut m) => {
438            let (brace, items) = m.content.unwrap();
439            let (before_each_fn, everything_else): (Vec<Item>, Vec<Item>) =
440                items.into_iter().partition(|t| match t {
441                    Item::Fn(f) => f.sig.ident == "before_each",
442                    _ => false,
443                });
444            let before_each_fn_block = if before_each_fn.len() != 1 {
445                panic!("The `before_each` macro attribute requires a single function named `before_each` in the body of the module it is called on.")
446            } else {
447                match before_each_fn.into_iter().next().unwrap() {
448                    Item::Fn(f) => f.block,
449                    _ => unreachable!(),
450                }
451            };
452
453            let e: Vec<Item> = everything_else
454                .into_iter()
455                .map(|t| match t {
456                    Item::Fn(mut f) => {
457                        if f.attrs.iter().any(|attr| {
458                            attr.path()
459                                .segments
460                                .iter()
461                                .any(|segment| segment.ident.to_string().contains("test"))
462                        }) {
463                            let mut b = before_each_fn_block.stmts.clone();
464                            b.append(&mut f.block.stmts);
465                            f.block.stmts = b;
466                            Item::Fn(f)
467                        } else {
468                            Item::Fn(f)
469                        }
470                    }
471                    e => e,
472                })
473                .collect();
474            m.content = Some((brace, e));
475            Item::Mod(m)
476        }
477
478        _ => {
479            panic!("The `before_each` macro attribute is only valid when called on a module.")
480        }
481    };
482    TokenStream::from(quote! {#input})
483}
484
485/// Will skip running the code it is applied on. You can use it to skip tests that aren't working
486/// correctly or that you don't want to run for some reason. There are no checks to make sure it's
487/// applied to a `#[test]` or mod. It will remove whatever it is applied to from the final AST.
488///
489/// ```
490/// #[cfg(test)]
491/// use test_env_helpers::*;
492///
493/// #[cfg(test)]
494/// mod my_tests{
495///     #[skip]
496///     #[test]
497///     fn broken_test(){panic!("I'm hella broke")}
498///     #[skip]
499///     mod broken_mod{
500///         #[test]
501///         fn i_will_not_be_run(){panic!("I get skipped too")}
502///     }
503///     #[test]
504///     fn test_2(){}
505///     #[test]
506///     fn test_3(){}
507/// }
508/// ```
509#[proc_macro_attribute]
510pub fn skip(_metadata: TokenStream, _input: TokenStream) -> TokenStream {
511    TokenStream::from(quote! {})
512}