Skip to main content

sqlx_testcontainers_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input, Ident, ItemFn, LitStr, Token,
6};
7
/// Options accepted inside the parentheses of the `test(...)` attribute.
///
/// Each field stays `None` unless the matching `key = "value"` pair was written.
struct MacroArgs {
    /// Docker image tag for the Postgres container, from `tag = "..."`.
    tag: Option<String>,
    /// Migrations path forwarded to `sqlx::migrate!`, from `migrations = "..."`.
    migrations: Option<String>,
}
12
13impl Parse for MacroArgs {
14    fn parse(input: ParseStream) -> syn::Result<Self> {
15        let mut tag = None;
16        let mut migrations = None;
17
18        while !input.is_empty() {
19            let ident: Ident = input.parse()?;
20            input.parse::<Token![=]>()?;
21            let value: LitStr = input.parse()?;
22
23            match ident.to_string().as_str() {
24                "tag" => tag = Some(value.value()),
25                "migrations" => migrations = Some(value.value()),
26                _ => {
27                    return Err(syn::Error::new(
28                        ident.span(),
29                        "expected `tag` or `migrations`",
30                    ))
31                }
32            }
33
34            if !input.is_empty() {
35                input.parse::<Token![,]>()?;
36            }
37        }
38
39        Ok(MacroArgs { tag, migrations })
40    }
41}
42
43#[proc_macro_attribute]
44pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
45    let args = parse_macro_input!(args as MacroArgs);
46    let mut input = parse_macro_input!(input as ItemFn);
47
48    let test_name = input.sig.ident.clone();
49    let inner_name = quote::format_ident!("__inner_{}", test_name);
50
51    // Rename the original function to __inner_...
52    input.sig.ident = inner_name.clone();
53
54    let tag_expr = if let Some(t) = args.tag {
55        quote! { Some(#t.to_string()) }
56    } else {
57        quote! { None }
58    };
59
60    let migrate_expr = if let Some(m) = args.migrations {
61        quote! { ::sqlx::migrate!(#m).run(&pool).await.expect("Failed to run migrations"); }
62    } else {
63        quote! { ::sqlx::migrate!().run(&pool).await.expect("Failed to run migrations"); }
64    };
65
66    let expanded = quote! {
67        #[::tokio::test]
68        async fn #test_name() {
69            use ::testcontainers_modules::testcontainers::ImageExt;
70            use ::testcontainers_modules::testcontainers::runners::AsyncRunner;
71            use ::testcontainers_modules::postgres::Postgres;
72
73            let mut image = Postgres::default();
74            let tag: Option<String> = #tag_expr;
75            let container = if let Some(tag) = tag {
76                image.with_tag(tag).start().await.expect("Failed to start postgres container")
77            } else {
78                image.start().await.expect("Failed to start postgres container")
79            };
80
81            let host = container.get_host().await.expect("Failed to get host");
82            let host_port = container.get_host_port_ipv4(5432).await.expect("Failed to get port");
83
84            let conn_str = format!("postgres://postgres:postgres@{}:{}/postgres", host, host_port);
85
86            let pool = ::sqlx::postgres::PgPoolOptions::new()
87                .max_connections(1)
88                .connect(&conn_str)
89                .await
90                .expect("Failed to connect to postgres");
91
92            #migrate_expr
93
94            let mut conn = pool.acquire().await.expect("Failed to acquire connection").detach();
95
96            #input
97
98            #inner_name(conn).await;
99        }
100    };
101
102    TokenStream::from(expanded)
103}