use sea_orm::{ConnectionTrait, Database, DatabaseConnection, DbErr, Statement};
use sea_orm_migration::MigratorTrait;
use std::sync::atomic::{AtomicU64, Ordering};
use url::Url;
/// Process-wide sequence number combined with the process id to give every PostgreSQL test
/// database created by [`TestDb`] a distinct name.
static TEST_DB_COUNTER: AtomicU64 = AtomicU64::new(0_u64);
/// A database created for use in tests, together with a live connection to it.
///
/// Construct one with `TestDb::new` / `TestDb::new_with_migrator` (in-memory SQLite) or
/// `TestDb::new_postgres` / `TestDb::new_postgres_with_migrator` (a newly created PostgreSQL
/// database).
pub struct TestDb {
    /// Connection to the test database.
    pub connection: DatabaseConnection,
}
impl TestDb {
    /// Creates an in-memory SQLite database and applies every migration from `M`.
    ///
    /// # Errors
    ///
    /// Returns any [`DbErr`] from connecting, configuring the connection (see
    /// [`TestDb::new`]), or running `M::up`.
    pub async fn new_with_migrator<M: MigratorTrait>() -> Result<Self, DbErr> {
        // Build on `new` so both SQLite constructors share one connection setup, mirroring how
        // `new_postgres_with_migrator` builds on `create_postgres_db`.
        let instance = Self::new().await?;
        M::up(&instance.connection, None).await?;
        Ok(instance)
    }

    /// Creates an in-memory SQLite database with no schema.
    ///
    /// Connects to `sqlite::memory:?mode=rwc&cache=shared`, then issues
    /// `PRAGMA journal_mode=WAL` and `PRAGMA busy_timeout=5000`.
    ///
    /// # Errors
    ///
    /// Returns any [`DbErr`] from connecting or from executing the pragmas.
    pub async fn new() -> Result<Self, DbErr> {
        let connection = Database::connect("sqlite::memory:?mode=rwc&cache=shared").await?;
        // NOTE: SQLite keeps in-memory databases in MEMORY (or OFF) journal mode and does not
        // switch them to WAL; the pragma is harmless and keeps setup identical should the URL
        // ever point at a file-backed database.
        // NOTE(review): pragmas only reach the pooled connection that runs them — confirm the
        // SQLite pool is limited to a single connection.
        connection
            .execute_unprepared("PRAGMA journal_mode=WAL;")
            .await?;
        // Wait up to 5000 ms on a locked database instead of failing immediately with BUSY.
        connection
            .execute_unprepared("PRAGMA busy_timeout=5000;")
            .await?;
        Ok(Self { connection })
    }

    /// Creates a new, uniquely named PostgreSQL database and applies every migration from `M`.
    ///
    /// See [`TestDb::new_postgres`] for how the server is located and the database is named.
    ///
    /// # Errors
    ///
    /// Returns any [`DbErr`] from creating or connecting to the database, or from `M::up`.
    pub async fn new_postgres_with_migrator<M: MigratorTrait>() -> Result<Self, DbErr> {
        let instance = Self::create_postgres_db().await?;
        M::up(&instance.connection, None).await?;
        Ok(instance)
    }

    /// Creates a new, uniquely named, empty PostgreSQL database and connects to it.
    ///
    /// The server URL comes from the `TEST_DATABASE_URL` environment variable, falling back
    /// to `postgres://postgres:postgres@localhost:5432/postgres`. The database is named
    /// `test_db_<pid>_<counter>`. Nothing in this type drops the database afterwards, so it
    /// remains on the server when the test finishes.
    ///
    /// # Errors
    ///
    /// Returns a [`DbErr`] if the server is unreachable, the database cannot be created, or
    /// the new database cannot be connected to.
    pub async fn new_postgres() -> Result<Self, DbErr> {
        Self::create_postgres_db().await
    }

    /// Shared implementation of the PostgreSQL constructors: creates the database through an
    /// admin connection to the base URL, then opens a connection to the new database.
    async fn create_postgres_db() -> Result<Self, DbErr> {
        let base_url = std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/postgres".to_string());
        let admin_connection = Database::connect(&base_url).await?;

        // The process id separates concurrently running test binaries; the counter separates
        // databases created within one process.
        let counter = TEST_DB_COUNTER.fetch_add(1, Ordering::SeqCst);
        let db_name = format!("test_db_{}_{}", std::process::id(), counter);
        // CREATE DATABASE does not accept a bind parameter for the name, so it is quoted as
        // an identifier instead.
        let create_db_stmt = format!("CREATE DATABASE \"{}\"", escape_identifier(&db_name));
        let create_result = admin_connection
            .execute(Statement::from_string(
                sea_orm::DatabaseBackend::Postgres,
                create_db_stmt,
            ))
            .await;
        // Close the admin connection whether or not creation succeeded; a creation error
        // still takes precedence over a close error.
        let close_result = admin_connection.close().await;
        create_result.map_err(|e| {
            DbErr::Custom(format!(
                "Failed to create test database '{}': {}",
                db_name, e
            ))
        })?;
        close_result
            .map_err(|e| DbErr::Custom(format!("Failed to close admin connection: {}", e)))?;

        let test_db_url = build_test_db_url(&base_url, &db_name)?;
        let connection = Database::connect(&test_db_url).await.map_err(|e| {
            DbErr::Custom(format!(
                "Failed to connect to test database '{}': {}",
                db_name, e
            ))
        })?;
        Ok(Self { connection })
    }

    /// Returns a clone of the underlying connection handle.
    pub fn connection(&self) -> DatabaseConnection {
        self.connection.clone()
    }

    /// Executes each raw SQL statement in `statements`, in order, stopping at the first error.
    ///
    /// Statements are run unprepared and no transaction is opened here, so statements that
    /// ran before a failing one are not undone.
    ///
    /// # Errors
    ///
    /// Returns the [`DbErr`] of the first statement that fails.
    pub async fn seed(&self, statements: &[&str]) -> Result<(), DbErr> {
        for statement in statements {
            self.connection.execute_unprepared(statement).await?;
        }
        Ok(())
    }

    /// Drops every user table in the database, leaving an empty schema.
    ///
    /// - SQLite: tables listed in `sqlite_master`, excluding SQLite's internal `sqlite_*`
    ///   tables.
    /// - PostgreSQL: tables in `current_schema()` listed from `pg_catalog.pg_tables`, dropped
    ///   with `CASCADE` so foreign-key dependencies do not block the drop.
    ///
    /// # Errors
    ///
    /// Returns a [`DbErr`] if the backend is neither SQLite nor PostgreSQL, or if listing or
    /// dropping any table fails. (Previously a failed listing, e.g. on PostgreSQL, was
    /// silently ignored and nothing was dropped.)
    pub async fn reset(&self) -> Result<(), DbErr> {
        use sea_orm::DatabaseBackend;

        let backend = self.connection.get_database_backend();
        // Both queries expose the table name as a `name` column so decoding is shared.
        let (list_tables_sql, drop_suffix) = match backend {
            DatabaseBackend::Sqlite => (
                // `_` is a LIKE wildcard, so escape it to exclude only the `sqlite_` prefix.
                "SELECT name FROM sqlite_master \
                 WHERE type = 'table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'",
                "",
            ),
            DatabaseBackend::Postgres => (
                "SELECT tablename::text AS name FROM pg_catalog.pg_tables \
                 WHERE schemaname = current_schema()",
                " CASCADE",
            ),
            _ => {
                return Err(DbErr::Custom(format!(
                    "TestDb::reset does not support the {:?} backend",
                    backend
                )));
            }
        };

        let rows = self
            .connection
            .query_all(Statement::from_string(backend, list_tables_sql.to_string()))
            .await?;
        for row in rows {
            let table_name: String = row.try_get("", "name")?;
            let drop_stmt = format!(
                "DROP TABLE IF EXISTS \"{}\"{}",
                escape_identifier(&table_name),
                drop_suffix
            );
            self.connection.execute_unprepared(&drop_stmt).await?;
        }
        Ok(())
    }

    /// Runs `f` inside a transaction that is always rolled back afterwards, discarding any
    /// writes made through it.
    ///
    /// Returns `f`'s result when the rollback succeeds; a rollback error takes precedence.
    ///
    /// NOTE(review): `Fut` is declared outside the `for<'a>` bound, so the returned future
    /// cannot borrow the `&DatabaseTransaction` it receives. A closure such as
    /// `|txn| async move { txn.execute_unprepared("...").await?; Ok(()) }` will not
    /// type-check. Consider a boxed-future signature if callers need the transaction inside
    /// the future.
    pub async fn with_transaction_rollback<F, Fut>(&self, f: F) -> Result<(), DbErr>
    where
        F: for<'a> FnOnce(&'a sea_orm::DatabaseTransaction) -> Fut,
        Fut: std::future::Future<Output = Result<(), DbErr>>,
    {
        use sea_orm::TransactionTrait;
        let txn = self.connection.begin().await?;
        let result = f(&txn).await;
        txn.rollback().await?;
        result
    }
}
/// Escapes `identifier` for use inside a double-quoted SQL identifier by doubling every
/// embedded `"` character.
///
/// The surrounding quotes are not added; callers wrap the result in `"…"` themselves.
fn escape_identifier(identifier: &str) -> String {
    let mut escaped = String::with_capacity(identifier.len());
    for ch in identifier.chars() {
        if ch == '"' {
            escaped.push_str("\"\"");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
/// Returns `base_url` with everything after the last `/` of its path (the database name)
/// replaced by `new_db_name`; all other URL components are left unchanged.
///
/// A URL with no `/` in its path gets the path `/<new_db_name>`.
///
/// # Errors
///
/// Returns [`DbErr::Custom`] if `base_url` cannot be parsed as a URL.
fn build_test_db_url(base_url: &str, new_db_name: &str) -> Result<String, DbErr> {
    let mut parsed = match Url::parse(base_url) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(DbErr::Custom(format!(
                "Invalid database URL '{}': {}",
                base_url, e
            )));
        }
    };
    let rewritten_path = match parsed.path().rsplit_once('/') {
        Some((parent, _old_name)) => format!("{}/{}", parent, new_db_name),
        None => format!("/{}", new_db_name),
    };
    parsed.set_path(&rewritten_path);
    Ok(parsed.to_string())
}
/// Creates an in-memory SQLite `TestDb` from within an async context, panicking with
/// "Failed to create test database" if setup fails.
///
/// - `test_db!()` calls `TestDb::new()`.
/// - `test_db!(SomeMigrator)` calls `TestDb::new_with_migrator::<SomeMigrator>()`.
#[macro_export]
macro_rules! test_db {
    () => {{
        let created = $crate::testing::TestDb::new().await;
        created.expect("Failed to create test database")
    }};
    ($migrator:ty) => {{
        let created = $crate::testing::TestDb::new_with_migrator::<$migrator>().await;
        created.expect("Failed to create test database")
    }};
}