Skip to main content

kellnr_db_testcontainer/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn db_test(_attr: TokenStream, stream: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(stream as ItemFn);
8
9    let ItemFn {
10        attrs: _,
11        vis,
12        sig,
13        block,
14    } = input;
15
16    // Clone the function signature for our implementation function
17    let impl_sig = sig.clone();
18    let fn_name = &sig.ident;
19    let sqlite_fn_name = quote::format_ident!("sqlite_{fn_name}");
20    let postgres_fn_name = quote::format_ident!("postgres_{fn_name}");
21    let _test_impl_name = quote::format_ident!("{fn_name}_impl");
22    let _stmts = &block.stmts;
23
24    let output = quote! {
25        // Keep the original function as the implementation
26        #[allow(unused)]
27        #vis #impl_sig #block
28
29        // SQLite version of the test
30        #[tokio::test]
31        #vis async fn #sqlite_fn_name() {
32            use std::path;
33            use std::ops::Add;
34            use kellnr_common::util::generate_rand_string;
35
36            let path = path::PathBuf::from("/tmp").join(generate_rand_string(8).add(".db"));
37            let con_string = kellnr_db::SqliteConString {
38                path: path.to_owned(),
39                salt: "salt".to_string(),
40                admin_pwd: "123".to_string(),
41                admin_token: Some("token".to_string()),
42                session_age: std::time::Duration::from_secs(1),
43            };
44            let con_string = kellnr_db::ConString::Sqlite(con_string);
45            let test_db = kellnr_db::Database::new(&con_string, 10).await.unwrap();
46
47            // Run the test with SQLite
48            #fn_name(&test_db).await;
49
50            // Clean up
51            rm_rf::remove(&path).expect("Cannot remove test db");
52        }
53
54        // PostgreSQL version of the test
55        #[tokio::test]
56        #vis async fn #postgres_fn_name() {
57            use testcontainers::runners::AsyncRunner;
58
59            let pg_container = image::Postgres::default().start().await.expect("Failed to start postgres container");
60            let port = pg_container.get_host_port_ipv4(image::Postgres::PG_PORT).await.expect("Failed to get port");
61            let admin = kellnr_db::AdminUser::new("123".to_string(), Some("token".to_string()), "salt".to_string());
62            let pg_db = kellnr_db::PgConString::new("localhost", port, "kellnr", "admin", "admin", admin);
63            let pg_db = kellnr_db::ConString::Postgres(pg_db);
64            let test_db = kellnr_db::Database::new(&pg_db, 10).await.unwrap();
65
66            // Run the test with PostgreSQL
67            #fn_name(&test_db).await;
68        }
69    };
70
71    output.into()
72}