pubky_test_utils_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro_crate::{crate_name, FoundCrate};
3use quote::quote;
4use syn::{parse_macro_input, ItemFn};
5
6/// A macro that wraps a test function and makes sure the postgres test
7/// database(s) are dropped after the test completes/panics.
8///
9/// Usage:
10/// ```no_run
11/// #[tokio::test]
12/// #[pubky_testnet::test]
13/// async fn test_function() {
14///     // test code
15/// }
16/// ```
17///
18/// Important: The test function must be async and `#[tokio::test]` must be present above the macro.
19#[proc_macro_attribute]
20pub fn pubky_testcase(_attr: TokenStream, item: TokenStream) -> TokenStream {
21    let input_fn = parse_macro_input!(item as ItemFn);
22    let fn_block = &input_fn.block;
23    let fn_vis = &input_fn.vis;
24    let fn_attrs = &input_fn.attrs;
25    let fn_sig = &input_fn.sig;
26
27    /// Because this macro can be used in any crate, we need to get the crate name dynamically.
28    /// We support 3 crates: pubky_test_utils, pubky-testnet, pubky_test_utils_macro.
29    /// If one of them is not found, we try the next one.
30    /// If all of them are not found, we panic.
31    fn get_crate_name() -> FoundCrate {
32        let lib_names = [
33            "pubky_test_utils",
34            "pubky-testnet",
35            "pubky_test_utils_macro",
36        ];
37        for lib_name in lib_names.iter() {
38            match crate_name(lib_name) {
39                Ok(found) => return found,
40                Err(_e) => {
41                    continue;
42                }
43            };
44        }
45        panic!(
46            "Failed to get crate name. Tested crates: {}",
47            lib_names.join(", ").as_str()
48        );
49    }
50
51    // Get the crate name
52    let found = get_crate_name();
53    let my_crate = match found {
54        FoundCrate::Itself => quote!(crate),
55        FoundCrate::Name(name) => {
56            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
57            quote!(::#ident)
58        }
59    };
60
61    // Check if the function is async
62    let is_async = input_fn.sig.asyncness.is_some();
63
64    let expanded = if is_async {
65        // Handle async functions
66        // For async functions, we use a panic hook approach since catch_unwind doesn't work with async
67        quote! {
68            #(#fn_attrs)*
69            #fn_vis #fn_sig {
70                // Set up a panic hook to ensure cleanup happens
71                let original_hook = std::panic::take_hook();
72                std::panic::set_hook(Box::new(move |panic_info| {
73                // Execute cleanup in a blocking way since we're in a panic handler
74                if let Ok(rt) = tokio::runtime::Handle::try_current() {
75                    rt.block_on(#my_crate::drop_test_databases());
76                } else {
77                    // Fallback: create a new runtime if we're not in a tokio context
78                    if let Ok(rt) = tokio::runtime::Runtime::new() {
79                        rt.block_on(#my_crate::drop_test_databases());
80                    }
81                }
82                    // Call the original panic hook
83                    original_hook(panic_info);
84                }));
85
86                // Execute the test body
87                #fn_block
88
89                // Restore the original panic hook
90                std::panic::set_hook(original_hook);
91
92                // Always execute drop_test_databases() after the test completes normally
93                #my_crate::drop_test_databases().await;
94            }
95        }
96    } else {
97        // Handle sync functions - use std::panic::catch_unwind to ensure cleanup
98        quote! {
99            #(#fn_attrs)*
100            #fn_vis #fn_sig {
101                // Execute the test body and catch any panics
102                let result = std::panic::catch_unwind(|| {
103                    #fn_block
104                });
105
106                // Always execute drop_dbs() after the test, regardless of outcome
107                // Use tokio::runtime to handle the async call in sync context
108                if let Ok(rt) = tokio::runtime::Handle::try_current() {
109                    rt.block_on(#my_crate::drop_test_databases());
110                } else {
111                    // Fallback: create a new runtime if we're not in a tokio context
112                    let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
113                    rt.block_on(#my_crate::drop_test_databases());
114                }
115
116                // Re-panic if the test panicked
117                if let Err(panic) = result {
118                    std::panic::resume_unwind(panic);
119                }
120            }
121        }
122    };
123
124    TokenStream::from(expanded)
125}
126
#[cfg(test)]
mod tests {
    /// `pubky_testcase` parses its input as a `syn::ItemFn` and picks the async or
    /// sync expansion from `sig.asyncness`. This checks that the documented usage
    /// shape parses that way. The macro's actual behavior is covered by
    /// integration tests.
    #[test]
    fn documented_usage_parses_as_async_item_fn() {
        let parsed: syn::ItemFn = syn::parse_str("#[tokio::test] async fn test_function() {}")
            .expect("the documented usage should parse as a function item");
        assert!(parsed.sig.asyncness.is_some());
        assert_eq!(parsed.attrs.len(), 1);
    }
}