use sqlx::PgPool;
use std::sync::Arc;
pub struct TestDatabase {
pool: Arc<PgPool>,
database_name: String,
postgres_url: String,
}
impl TestDatabase {
pub async fn new() -> anyhow::Result<Self> {
Self::with_migrations(true).await
}
pub async fn without_migrations() -> anyhow::Result<Self> {
Self::with_migrations(false).await
}
async fn with_migrations(run_migrations: bool) -> anyhow::Result<Self> {
let database_name = format!("test_db_{}", uuid::Uuid::new_v4().simple());
let postgres_url = std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost/postgres".to_string());
let pool = PgPool::connect(&postgres_url).await?;
sqlx::query(&format!("CREATE DATABASE {database_name}"))
.execute(&pool)
.await?;
let test_db_url = postgres_url.replace("/postgres", &format!("/{database_name}"));
let test_pool = PgPool::connect(&test_db_url).await?;
if run_migrations {
sqlx::migrate!("../migrations")
.run(&test_pool)
.await?;
}
Ok(Self {
pool: Arc::new(test_pool),
database_name,
postgres_url,
})
}
#[must_use]
pub fn pool(&self) -> &PgPool {
&self.pool
}
#[must_use]
pub fn name(&self) -> &str {
&self.database_name
}
}
impl Drop for TestDatabase {
    /// Drops the per-test database on the server.
    ///
    /// `drop` cannot `.await`, and `Runtime::block_on` panics when called from inside an
    /// async execution context. Cleanup therefore runs on a dedicated OS thread with its
    /// own single-threaded runtime.
    ///
    /// The thread is joined so the database is gone by the time the owning test returns.
    /// A detached thread can be killed when the test binary exits, leaking the database.
    /// All failures are logged with `tracing` rather than panicking inside `Drop`.
    fn drop(&mut self) {
        let database_name = self.database_name.clone();
        let postgres_url = self.postgres_url.clone();

        // `self.pool` is only released after this method returns, so its connections to
        // the test database are still open here. They are terminated server-side below
        // so they do not block `DROP DATABASE`.
        let cleanup = std::thread::spawn(move || {
            let rt = match tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
            {
                Ok(rt) => rt,
                Err(e) => {
                    tracing::warn!("Failed to create runtime for cleanup of {database_name}: {e}");
                    return;
                }
            };

            rt.block_on(async {
                let admin_pool = match PgPool::connect(&postgres_url).await {
                    Ok(pool) => pool,
                    Err(e) => {
                        tracing::warn!("Failed to connect for cleanup of {database_name}: {e}");
                        return;
                    }
                };

                // Best effort: if this fails, the DROP below reports the actual problem.
                let _ = sqlx::query(
                    "SELECT pg_terminate_backend(pid) FROM pg_stat_activity \
                     WHERE datname = $1 AND pid <> pg_backend_pid()",
                )
                .bind(database_name.as_str())
                .execute(&admin_pool)
                .await;

                // Identifiers cannot be bound; the name is generated from a hex UUID.
                let drop_query = format!("DROP DATABASE IF EXISTS {database_name}");
                match sqlx::query(&drop_query).execute(&admin_pool).await {
                    Ok(_) => {
                        tracing::debug!("Successfully dropped test database: {database_name}");
                    }
                    Err(e) => {
                        tracing::warn!("Failed to drop test database {database_name}: {e}");
                    }
                }

                admin_pool.close().await;
            });
        });

        if cleanup.join().is_err() {
            tracing::warn!(
                "Cleanup thread for test database {} panicked",
                self.database_name
            );
        }
    }
}
/// Creates a connection pool backed by a fresh in-memory SQLite database.
///
/// An in-memory SQLite database exists only while a connection to it is open. The pool
/// is therefore capped at one connection, and sqlx's idle-timeout and max-lifetime
/// recycling are disabled. Otherwise the pool could close that connection between
/// queries and silently discard everything written so far.
///
/// # Errors
///
/// Returns an error if the in-memory connection cannot be opened.
#[cfg(feature = "sqlite")]
pub async fn create_sqlite_pool() -> anyhow::Result<sqlx::SqlitePool> {
    use sqlx::sqlite::SqlitePoolOptions;

    let pool = SqlitePoolOptions::new()
        .max_connections(1)
        .idle_timeout(None::<std::time::Duration>)
        .max_lifetime(None::<std::time::Duration>)
        .connect(":memory:")
        .await?;
    Ok(pool)
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "sqlite")]
    use super::create_sqlite_pool;

    /// Smoke test: the in-memory SQLite pool can run a trivial query.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_pool() {
        let pool = create_sqlite_pool().await.unwrap();

        let (value,): (i32,) = sqlx::query_as("SELECT 1")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(value, 1);
    }
}