citerne_derive/
lib.rs

1extern crate core;
2
3use proc_macro::TokenStream;
4
5use crate::attributes::CiterneAttributes;
6use quote::quote;
7use syn::{parse_macro_input, parse_quote, FnArg, Ident, Pat, Stmt};
8
9mod attributes;
10
/// Message reported when a `#[database_container_test]` function does not have the single
/// `&mut PgConnection` argument the macro expects.
const PG_SIG_ERROR_MESSAGE: &str = concat!(
    "database_container_test should have ",
    "`fn(conn: &mut PgConnection)` signature"
);
13
14#[proc_macro_attribute]
15pub fn database_container_test(attr: TokenStream, item: TokenStream) -> TokenStream {
16    let mut function = parse_macro_input!(item as syn::ItemFn);
17
18    let arg = function
19        .sig
20        .inputs
21        .pop()
22        .unwrap_or_else(|| panic!("{PG_SIG_ERROR_MESSAGE}"));
23
24    let ident = match arg.value() {
25        FnArg::Receiver(_) => panic!("{PG_SIG_ERROR_MESSAGE}"),
26        FnArg::Typed(typed) => {
27            let arg_type = &typed.ty;
28            if quote!(#arg_type).to_string() != "&mut PgConnection" {
29                panic!("{PG_SIG_ERROR_MESSAGE}");
30            }
31
32            if let Pat::Ident(ident) = typed.pat.as_ref() {
33                ident.ident.clone()
34            } else {
35                panic!("{PG_SIG_ERROR_MESSAGE}");
36            }
37        }
38    };
39
40    let attributes = parse_macro_input!(attr as CiterneAttributes);
41
42    #[cfg(feature = "postgres")]
43    let mut init_container_statements = generate_postgres_container_init(attributes, ident);
44
45    #[cfg(feature = "mysql")]
46    let mut init_container_statements = generate_mysql_container_init(attributes);
47
48    init_container_statements.extend(function.block.stmts.clone());
49    function.block.stmts = init_container_statements;
50
51    let function = quote! {
52        #[test]
53        #function
54    };
55
56    TokenStream::from(function)
57}
58
/// Builds the statements that start a PostgreSQL container and prepare it before the test body.
///
/// The returned statements are meant to be spliced in front of the test's own statements.
/// They do the following, in order:
/// 1. Start a `testcontainers` `Postgres` container and connect to it with diesel at
///    `postgresql://postgres@127.0.0.1:<host port mapped to 5432>/postgres`.
/// 2. Bind that connection as `&mut diesel::PgConnection` under `conn_ident`, the name of the
///    argument removed from the test signature.
/// 3. For each entry of `attributes.migrations`, resolved against `CARGO_MANIFEST_DIR`:
///    execute a file as one SQL batch, or load a directory as diesel file-based migrations
///    and run each of them.
/// 4. Execute `attributes.sql`, if set, exactly once after all migrations.
///
/// Every fallible step uses `?`, so the surrounding test function must return a `Result`
/// whose error type can absorb those errors. NOTE(review): presumably something like
/// `Box<dyn std::error::Error>`; confirm.
#[cfg(feature = "postgres")]
fn generate_postgres_container_init(attributes: CiterneAttributes, conn_ident: Ident) -> Vec<Stmt> {
    let mut init_container_statements: Vec<Stmt> = parse_quote! {
        let docker = testcontainers::clients::Cli::default();
        println!("Starting postgresql container");
        let __postgres_container__: testcontainers::Container<testcontainers::images::postgres::Postgres> = docker.run(testcontainers::images::postgres::Postgres::default());
        let __postgres_addr__ = format!(
            "postgresql://postgres@127.0.0.1:{}/postgres",
            __postgres_container__.get_host_port_ipv4(5432)
        );
        println!("Container started, trying to establish postgresql connection on port 5432");
        let mut #conn_ident = {
            use diesel::Connection;
            diesel::PgConnection::establish(&__postgres_addr__)?
        };
        let #conn_ident = &mut #conn_ident;
    };

    for relative_path in attributes.migrations {
        // Each migration runs in its own block so its helper locals (`crate_path`,
        // `migration_path`, ...) do not leak into the user's test body.
        init_container_statements.push(parse_quote! {
            {
                let crate_path = std::env::var("CARGO_MANIFEST_DIR").map(std::path::PathBuf::from)?;
                let migration_path = crate_path.join(#relative_path);
                if !migration_path.is_dir() {
                    use diesel::connection::SimpleConnection;
                    let sql = std::fs::read_to_string(&migration_path)?;
                    println!("Running single migration {}", migration_path.file_name().unwrap().to_string_lossy());
                    #conn_ident.batch_execute(&sql)?;
                } else {
                    use diesel::migration::MigrationSource;
                    let migrations = diesel_migrations::FileBasedMigrations::from_path(&migration_path)?.migrations()?;
                    for migration in migrations {
                        println!("Running migrations {}", migration.name());
                        // Use the caller's argument name, not a hard-coded `conn`.
                        migration.run(#conn_ident)?;
                    }
                }
            }
        });
    }

    // The seed SQL must run exactly once, after all migrations have been applied. It must
    // also run when no migrations are configured, so it lives outside the loop.
    if let Some(sql) = attributes.sql {
        init_container_statements.push(parse_quote! {
            {
                use diesel::connection::SimpleConnection;
                #conn_ident.batch_execute(#sql)?;
            }
        });
    }

    init_container_statements
}