use std::{future::Future, sync::LazyLock};
use tokio::{sync::Mutex, task};
use tokio_postgres::NoTls;
/// Process-wide lock that serializes tests sharing the test database.
///
/// `with_db` holds the guard for its whole reset/setup/test/teardown sequence, so only one
/// such sequence touches the `public` schema at a time. An async (tokio) mutex is used
/// because the guard is held across `.await` points.
static DB_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
/// Runs the test body `f` against a freshly reset database schema.
///
/// The database is taken from the `TEST_DB_URL` environment variable. If it is unset, the
/// test is skipped by returning `Ok(())`, except on GitHub Actions, where that is a
/// configuration error.
///
/// Steps, all performed while holding [`DB_LOCK`] so database tests do not interleave:
/// 1. drop and recreate the `public` schema, then load `../schema.sql`;
/// 2. call `f(connection_string, client)` and run the returned future with
///    `spawn_local` on a `LocalSet`, so it does not need to be `Send`
///    (it must be `'static`);
/// 3. reconnect and drop the `public` schema again.
///
/// Step 3 runs whether the test returned `Ok`, returned `Err`, or panicked.
///
/// # Errors
/// Returns any error from connecting or from the setup/teardown statements. Otherwise it
/// returns the test future's own result.
///
/// # Panics
/// - Panics if `TEST_DB_URL` is unset while `GITHUB_ACTIONS` is set.
/// - If the test future panics, teardown runs first and the original panic is then resumed.
pub(crate) async fn with_db<F, FUT>(f: F) -> anyhow::Result<()>
where
    F: FnOnce(String, tokio_postgres::Client) -> FUT,
    FUT: Future<Output = anyhow::Result<()>> + 'static,
{
    let _ = env_logger::builder().is_test(true).try_init();
    let Ok(connection_string) = std::env::var("TEST_DB_URL") else {
        if std::env::var("GITHUB_ACTIONS").is_ok() {
            panic!("TEST_DB_URL must be set in GitHub actions");
        }
        return Ok(());
    };
    let _db_guard = DB_LOCK.lock().await;
    let local_set = task::LocalSet::new();
    local_set
        .run_until(async move {
            let (client, conn_join_handle) = connect_db(&connection_string).await?;
            client
                .execute("drop schema if exists public cascade", &[])
                .await?;
            client.execute("create schema public", &[]).await?;
            client.simple_query(include_str!("../schema.sql")).await?;

            // Keep the join result instead of `?`-ing it. A panicking test must not skip
            // the teardown below.
            let test_outcome =
                task::spawn_local(f(connection_string.clone(), client)).await;
            // The client was moved into the test future, which has now been dropped.
            // That lets the connection driver finish.
            conn_join_handle.await?;

            // Teardown uses a fresh connection, because the test consumed the first client.
            let (client, conn_join_handle) = connect_db(&connection_string).await?;
            client
                .execute("drop schema if exists public cascade", &[])
                .await?;
            // Close the client so the driver task can complete, then wait for it.
            drop(client);
            conn_join_handle.await?;

            match test_outcome {
                Ok(test_res) => test_res,
                // Re-raise the test's own panic so the harness reports it as such
                // (this also makes `#[should_panic]` tests work).
                Err(join_err) if join_err.is_panic() => {
                    std::panic::resume_unwind(join_err.into_panic())
                }
                Err(join_err) => Err(join_err.into()),
            }
        })
        .await
}

/// Connects to `connection_string` and spawns the task that drives the connection.
///
/// Returns the client and the driver's join handle. Connection errors reported by the
/// driver are logged at warn level rather than propagated.
async fn connect_db(
    connection_string: &str,
) -> anyhow::Result<(tokio_postgres::Client, task::JoinHandle<()>)> {
    let (client, connection) = tokio_postgres::connect(connection_string, NoTls).await?;
    let driver = tokio::spawn(async move {
        if let Err(e) = connection.await {
            log::warn!("connection error: {e}");
        }
    });
    Ok((client, driver))
}