ntest_timeout/
lib.rs

//! Part of the ntest library. Adds a `timeout` attribute to the Rust test framework.

3extern crate proc_macro;
4extern crate syn;
5
6use proc_macro::TokenStream;
7use quote::quote;
8
9use syn::parse_macro_input;
10
11/// The timeout attribute can be used for tests to let them fail if they exceed a certain execution time.
12/// With the `#[timeout]` attribute a timeout in milliseconds is added to a test.
13///
14/// The function input must be of type `int`. For example `#[timeout(10)]` will fail if the test takes longer than 10 milliseconds.
15///
16/// # Examples
17///
18/// This example will not panic
19///
20/// ```
21/// #[test]
22/// #[timeout(100)]
23/// fn no_timeout() {
24///     let fifty_millis = time::Duration::from_millis(50);
25///     thread::sleep(fifty_millis);
26/// }
27/// ```
28///
29/// This example will panic and break the infinite loop after 10 milliseconds.
30///
31/// ```
32/// #[test]
33/// #[timeout(10)]
34/// #[should_panic]
35/// fn timeout() {
36///     loop {};
37/// }
38/// ```
39///
40/// Also works with test functions using a Result:
41///
42/// ```
43/// #[test]
44/// #[timeout(100)]
45/// fn timeout_with_result() -> Result<(), String> {
46///     let ten_millis = time::Duration::from_millis(10);
47///     thread::sleep(ten_millis);
48///     Ok(())
49/// }
50/// ```
51#[proc_macro_attribute]
52pub fn timeout(attr: TokenStream, item: TokenStream) -> TokenStream {
53    let input = syn::parse_macro_input!(item as syn::ItemFn);
54    let time_ms = get_timeout(&parse_macro_input!(attr as syn::AttributeArgs));
55    let vis = &input.vis;
56    let sig = &input.sig;
57    let output = &sig.output;
58    let body = &input.block;
59    let attrs = &input.attrs;
60    check_other_attributes(&input);
61    let result = quote! {
62        #(#attrs)*
63        #vis #sig {
64            fn ntest_callback() #output
65            #body
66            let ntest_timeout_now = std::time::Instant::now();
67            
68            type NtestPanicPayload = std::boxed::Box<dyn std::any::Any + std::marker::Send + 'static>;
69            // Channel sends Result: Ok for success, Err for panic payload
70            let (sender, receiver) = std::sync::mpsc::channel::<std::result::Result<_, NtestPanicPayload>>();
71            std::thread::spawn(move || {
72                let panic_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
73                    ntest_callback()
74                }));
75                // Send will fail if receiver has already timed out or dropped - this is expected
76                let _ = sender.send(panic_result);
77            });
78            match receiver.recv_timeout(std::time::Duration::from_millis(#time_ms)) {
79                std::result::Result::Ok(std::result::Result::Ok(t)) => return t,
80                std::result::Result::Ok(std::result::Result::Err(panic_payload)) => {
81                    // Resume the panic with the original payload to preserve panic message
82                    std::panic::resume_unwind(panic_payload);
83                },
84                Err(std::sync::mpsc::RecvTimeoutError::Timeout) => panic!("timeout: the function call took {} ms. Max time {} ms", ntest_timeout_now.elapsed().as_millis(), #time_ms),
85                Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => panic!("Thread disconnected unexpectedly"),
86            }
87        }
88    };
89    result.into()
90}
91
92fn check_other_attributes(input: &syn::ItemFn) {
93    for attribute in &input.attrs {
94        let meta = attribute.parse_meta();
95        match meta {
96            std::result::Result::Ok(m) => match m {
97                syn::Meta::Path(p) => {
98                    if p.segments.iter().any(|ps| ps.ident == "timeout") {
99                        panic!("Timeout attribute is only allowed once");
100                    }
101                }
102                syn::Meta::List(ml) => {
103                    if ml.path.segments.iter().any(|ps| ps.ident == "timeout") {
104                        panic!("Timeout attribute is only allowed once");
105                    }
106                }
107                syn::Meta::NameValue(nv) => {
108                    if nv.path.segments.iter().any(|ps| ps.ident == "timeout") {
109                        panic!("Timeout attribute is only allowed once");
110                    }
111                }
112            },
113            Err(e) => panic!("Could not determine meta data. Error {}.", e),
114        }
115    }
116}
117
118fn get_timeout(attribute_args: &syn::AttributeArgs) -> u64 {
119    if attribute_args.len() > 1 {
120        panic!("Only one integer expected. Example: #[timeout(10)]");
121    }
122    match &attribute_args[0] {
123        syn::NestedMeta::Meta(_) => {
124            panic!("Integer expected. Example: #[timeout(10)]");
125        }
126        syn::NestedMeta::Lit(lit) => match lit {
127            syn::Lit::Int(int) => int.base10_parse::<u64>().expect("Integer expected"),
128            _ => {
129                panic!("Integer as timeout in ms expected. Example: #[timeout(10)]");
130            }
131        },
132    }
133}