1extern crate core;
2
3use proc_macro::TokenStream;
4
5use crate::attributes::CiterneAttributes;
6use quote::quote;
7use syn::{parse_macro_input, parse_quote, FnArg, Ident, Pat, Stmt};
8
9mod attributes;
10
/// Error reported when a function annotated with `#[database_container_test]` does not have
/// the single-connection-argument signature the macro expects.
const PG_SIG_ERROR_MESSAGE: &str = concat!(
    "database_container_test should have ",
    "`fn(conn: &mut PgConnection)` signature"
);
13
14#[proc_macro_attribute]
15pub fn database_container_test(attr: TokenStream, item: TokenStream) -> TokenStream {
16 let mut function = parse_macro_input!(item as syn::ItemFn);
17
18 let arg = function
19 .sig
20 .inputs
21 .pop()
22 .unwrap_or_else(|| panic!("{PG_SIG_ERROR_MESSAGE}"));
23
24 let ident = match arg.value() {
25 FnArg::Receiver(_) => panic!("{PG_SIG_ERROR_MESSAGE}"),
26 FnArg::Typed(typed) => {
27 let arg_type = &typed.ty;
28 if quote!(#arg_type).to_string() != "&mut PgConnection" {
29 panic!("{PG_SIG_ERROR_MESSAGE}");
30 }
31
32 if let Pat::Ident(ident) = typed.pat.as_ref() {
33 ident.ident.clone()
34 } else {
35 panic!("{PG_SIG_ERROR_MESSAGE}");
36 }
37 }
38 };
39
40 let attributes = parse_macro_input!(attr as CiterneAttributes);
41
42 #[cfg(feature = "postgres")]
43 let mut init_container_statements = generate_postgres_container_init(attributes, ident);
44
45 #[cfg(feature = "mysql")]
46 let mut init_container_statements = generate_mysql_container_init(attributes);
47
48 init_container_statements.extend(function.block.stmts.clone());
49 function.block.stmts = init_container_statements;
50
51 let function = quote! {
52 #[test]
53 #function
54 };
55
56 TokenStream::from(function)
57}
58
/// Builds the statements that start a PostgreSQL container, connect to it and prepare the
/// database for the test.
///
/// The connection is bound to `conn_ident` as a `&mut diesel::PgConnection`, so the test body
/// can keep using the name of the argument removed from its signature.
///
/// Each entry of `attributes.migrations` is joined onto `CARGO_MANIFEST_DIR`:
/// - a directory is loaded with `diesel_migrations::FileBasedMigrations`, and every
///   migration it contains is run;
/// - any other path is read as a single SQL file and executed as a batch.
///
/// The optional `attributes.sql` is executed exactly once, after all migrations.
///
/// The generated statements use `?`, so the enclosing test must return a `Result`.
#[cfg(feature = "postgres")]
fn generate_postgres_container_init(attributes: CiterneAttributes, conn_ident: Ident) -> Vec<Stmt> {
    let mut init_container_statements: Vec<Stmt> = parse_quote! {
        let docker = testcontainers::clients::Cli::default();
        println!("Starting postgresql container");
        let __postgres_container__: testcontainers::Container<testcontainers::images::postgres::Postgres> = docker.run(testcontainers::images::postgres::Postgres::default());
        let __postgres_addr__ = format!(
            "postgresql://postgres@127.0.0.1:{}/postgres",
            __postgres_container__.get_host_port_ipv4(5432)
        );
        println!("Container started, trying to establish postgresql connection on port 5432");
        let mut #conn_ident = {
            use diesel::Connection;
            diesel::PgConnection::establish(&__postgres_addr__)?
        };
        let #conn_ident = &mut #conn_ident;
    };

    for migration in attributes.migrations {
        // Each migration gets its own block so its helper locals do not leak into the test body.
        init_container_statements.push(parse_quote! {
            {
                let crate_path = std::env::var("CARGO_MANIFEST_DIR").map(|dir| std::path::PathBuf::from(dir))?;
                let migration_path = crate_path.join(#migration);
                if !migration_path.is_dir() {
                    let sql = std::fs::read_to_string(&migration_path)?;
                    {
                        use diesel::connection::SimpleConnection;
                        println!("Running single migration {}", migration_path.file_name().unwrap().to_string_lossy());
                        #conn_ident.batch_execute(&sql)?;
                    }
                } else {
                    let source = diesel_migrations::FileBasedMigrations::from_path(&migration_path)?;
                    // `FileBasedMigrations` implements `MigrationSource<DB>` for every backend,
                    // so the backend is named explicitly instead of relying on inference.
                    let migrations = diesel::migration::MigrationSource::<diesel::pg::Pg>::migrations(&source)?;
                    for migration in migrations {
                        println!("Running migrations {}", migration.name());
                        // Use the caller's connection name, not a hardcoded `conn`.
                        migration.run(#conn_ident)?;
                    }
                }
            }
        });
    }

    // The extra SQL runs once, after every migration has been applied. It must not run once
    // per migration, and it must still run when no migrations are configured.
    if let Some(sql) = attributes.sql {
        init_container_statements.push(parse_quote! {
            {
                use diesel::connection::SimpleConnection;
                #conn_ident.batch_execute(#sql)?;
            }
        });
    }

    init_container_statements
}