use darling::FromMeta;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, AttributeArgs, Ident};
mod generators;
/// One `pool(...)` entry from the `#[test(...)]` attribute arguments.
///
/// Parsed via `darling::FromMeta`. `MacroArgs::pool` is marked `#[darling(multiple)]`, so each
/// `pool(...)` occurrence yields one `Pool`.
#[derive(Debug, FromMeta)]
pub(crate) struct Pool {
/// Required identifier (it has no `#[darling(default)]`). `database_name_var` derives the
/// `__<variable>_db_name` identifier from it.
variable: Ident,
/// Optional identifier, `None` when absent. It is consumed by `generators` (not visible here);
/// NOTE(review): presumably the name bound to a transaction on this pool — confirm there.
#[darling(default)]
transaction_variable: Option<Ident>,
/// Optional string, `None` when absent. It is consumed by `generators`;
/// NOTE(review): presumably a migrations location — confirm there.
#[darling(default)]
migrations: Option<String>,
/// Flag that defaults to `false` when absent. It is consumed by `generators`;
/// NOTE(review): presumably suppresses running migrations — confirm there.
#[darling(default)]
skip_migrations: bool,
}
impl Pool {
    /// Builds the identifier `__<variable>_db_name`.
    ///
    /// NOTE(review): presumably the generated code (see `generators`) binds this pool's database
    /// name to a local with this identifier — confirm there.
    fn database_name_var(&self) -> Ident {
        let base = &self.variable;
        format_ident!("__{}_db_name", base)
    }
}
/// Parsed arguments of the `#[test(...)]` attribute macro.
#[derive(Debug, FromMeta)]
pub(crate) struct MacroArgs {
/// Defaults to the empty string when absent. It is spliced as a string literal into the generated
/// `sqlx_database_tester::connect_options(..., level)` call, and its meaning is defined there.
#[darling(default)]
level: String,
/// One element per `pool(...)` occurrence in the attribute arguments.
#[darling(multiple)]
pool: Vec<Pool>,
}
/// Attribute macro that turns an `async fn` into a synchronous `#[test]` running against
/// PostgreSQL (`sqlx::PgPool`) test databases.
///
/// Arguments are parsed into `MacroArgs`:
/// - an optional `level = "..."` string, forwarded to `sqlx_database_tester::connect_options`;
/// - any number of `pool(...)` entries.
///
/// The generated test function:
/// 1. loads `.env` through `sqlx_database_tester::dotenv`, ignoring failures;
/// 2. connects to the server and runs the database creators emitted by `generators`.
///    Connecting makes up to `MAX_RETRIES + 1` attempts, `TIME_BETWEEN_RETRIES` seconds apart;
/// 3. runs the original body on the selected runtime inside `catch_unwind`, followed by the
///    generated closers (skipped if the body panics);
/// 4. reconnects and runs the database destructors, whether or not the body panicked;
/// 5. returns the body's value, or re-raises the body's *original* panic payload. Re-raising
///    the original payload keeps `#[should_panic(expected = "...")]` working.
///
/// # Errors
///
/// Expands to a compile error when:
/// - the attribute arguments cannot be parsed;
/// - `generators::runtime()` yields no runtime (the message asks for `runtime-actix` or
///   `runtime-tokio`);
/// - the annotated function is not `async`.
#[proc_macro_attribute]
pub fn test(test_attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = syn::parse_macro_input!(item as syn::ItemFn);
    let test_attr_args = parse_macro_input!(test_attr as AttributeArgs);
    let test_attr: MacroArgs = match MacroArgs::from_list(&test_attr_args) {
        Ok(v) => v,
        Err(e) => return TokenStream::from(e.write_errors()),
    };
    let level = test_attr.level.as_str();

    let runtime = match generators::runtime() {
        Some(runtime) => runtime,
        None => {
            return syn::Error::new(
                Span::call_site(),
                "One of 'runtime-actix' and 'runtime-tokio' features needs to be enabled",
            )
            .into_compile_error()
            .into();
        }
    };

    let attrs = &input.attrs;
    let vis = &input.vis;
    let sig = &mut input.sig;
    let body = &input.block;

    if sig.asyncness.is_none() {
        return syn::Error::new_spanned(
            sig.fn_token,
            "the async keyword is missing from the function declaration",
        )
        .into_compile_error()
        .into();
    }
    // The wrapper drives the async body on an explicit runtime (`block_on`), so the emitted
    // function itself must be synchronous for the built-in test harness.
    sig.asyncness = None;

    let database_name_vars = generators::database_name_vars(&test_attr);
    let database_creators = generators::database_creators(&test_attr);
    let database_migrations_exposures = generators::database_migrations_exposures(&test_attr);
    let database_closers = generators::database_closers(&test_attr);
    let database_destructors = generators::database_destructors(&test_attr);
    let sleep = generators::sleep();

    (quote! {
        #[::core::prelude::v1::test]
        #(#attrs)*
        #vis #sig {
            const MAX_RETRIES: u8 = 30;
            const TIME_BETWEEN_RETRIES: u64 = 10;

            // Connects to the server, retrying because the database may not be reachable yet.
            #[allow(clippy::expect_used)]
            async fn connect_with_retry() -> Result<sqlx::PgPool, sqlx::Error> {
                let mut i = 0;
                loop {
                    let db_pool = sqlx::PgPool::connect_with(sqlx_database_tester::connect_options(
                        sqlx_database_tester::derive_db_prefix(&sqlx_database_tester::get_database_uri())
                            .expect("Getting database name")
                            .as_deref()
                            .unwrap_or_default(),
                        #level,
                    ))
                    .await;
                    match db_pool {
                        Ok(pool) => break Ok(pool),
                        Err(e) => {
                            if i >= MAX_RETRIES {
                                break Err(e);
                            }
                        }
                    }
                    #sleep(std::time::Duration::from_secs(TIME_BETWEEN_RETRIES)).await;
                    i += 1;
                }
            }

            sqlx_database_tester::dotenv::dotenv().ok();
            #(#database_name_vars)*

            #runtime.block_on(async {
                #[allow(clippy::expect_used)]
                let db_pool = connect_with_retry().await.expect("connecting to db for creation");
                #(#database_creators)*
            });

            // Catch a panicking body so the databases below are always dropped.
            let result = std::panic::catch_unwind(|| {
                #runtime.block_on(async {
                    #(#database_migrations_exposures)*
                    let res = #body;
                    #(#database_closers)*
                    res
                })
            });

            #runtime.block_on(async {
                #[allow(clippy::expect_used)]
                let db_pool = connect_with_retry().await.expect("connecting to db for deletion");
                #(#database_destructors)*
            });

            match result {
                std::result::Result::Ok(o) => o,
                std::result::Result::Err(payload) => {
                    // Re-raise the original payload (instead of panicking with a new message)
                    // so `#[should_panic(expected = ...)]` and the harness see the real cause.
                    std::eprintln!("The main test function crashed, the test database got cleaned");
                    std::panic::resume_unwind(payload)
                }
            }
        }
    })
    .into()
}