use sea_orm::{ConnectionTrait, Database, DatabaseConnection, DbErr, Statement};
use sea_orm_migration::MigratorTrait;
#[cfg(feature = "test-containers")]
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use url::Url;
#[cfg(feature = "test-containers")]
mod postgres_container;
/// Process-wide sequence number, starting at 0, that
/// `TestDb::create_postgres_db_from_base_url` increments once per call and
/// combines with the process id to name each database it creates.
static TEST_DB_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Database backend a [`TestDb`] can be created against.
///
/// Derives `PartialEq`/`Eq` so a selected backend can be compared directly
/// (for example with `assert_eq!`) instead of requiring `matches!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TestDbBackend {
    /// In-memory SQLite database; this is the default backend.
    #[default]
    SqliteMemory,
    /// Postgres server reached through a connection URL.
    Postgres,
    /// Postgres server started through the `postgres_container` module.
    /// Only available with the `test-containers` feature.
    #[cfg(feature = "test-containers")]
    PostgresContainer,
}
/// Options consumed by `TestDb::new_with_config`.
///
/// The `Default` value selects the default backend and no database URL.
#[derive(Debug, Default)]
pub struct TestDbConfig {
    /// Backend the test database is created on.
    pub backend: TestDbBackend,
    /// Base connection URL. `TestDb::new_with_config` only reads it for the
    /// `Postgres` backend; when it is `None` that backend falls back to
    /// `TestDb::new_postgres`.
    pub database_url: Option<String>,
}
/// Handle to a database prepared for use in a test.
pub struct TestDb {
    /// Connection to the test database.
    pub connection: DatabaseConnection,
    /// Container handle stored by `new_postgres_container`; `None` for every
    /// other constructor.
    // The field is never read after construction, hence the `dead_code`
    // allowance. NOTE(review): presumably it is held so the container lives
    // as long as this value — confirm against `postgres_container`.
    #[cfg(feature = "test-containers")]
    #[allow(dead_code)]
    container: Option<Arc<postgres_container::PostgresContainer>>,
}
impl TestDb {
    /// Builds a test database for the backend selected in `config`.
    ///
    /// `config.database_url` is only consulted for [`TestDbBackend::Postgres`];
    /// when it is `None` that backend falls back to [`TestDb::new_postgres`].
    ///
    /// # Errors
    /// Returns the error produced by the selected constructor.
    pub async fn new_with_config(config: TestDbConfig) -> Result<Self, DbErr> {
        match config.backend {
            TestDbBackend::SqliteMemory => Self::new().await,
            TestDbBackend::Postgres => match config.database_url.as_deref() {
                Some(url) => Self::create_postgres_db_from_base_url(url).await,
                None => Self::new_postgres().await,
            },
            #[cfg(feature = "test-containers")]
            TestDbBackend::PostgresContainer => Self::new_postgres_container().await,
        }
    }

    /// Alias for [`TestDb::new`]: creates an in-memory SQLite test database.
    ///
    /// # Errors
    /// See [`TestDb::new`].
    pub async fn new_with_sqlite() -> Result<Self, DbErr> {
        Self::new().await
    }

    /// Creates a fresh Postgres database on the server addressed by
    /// `database_url` and connects to it.
    ///
    /// # Errors
    /// Fails if the server cannot be reached, the database cannot be created,
    /// or the new database cannot be connected to.
    pub async fn new_with_postgres_url(database_url: &str) -> Result<Self, DbErr> {
        Self::create_postgres_db_from_base_url(database_url).await
    }

    /// Creates an in-memory SQLite test database and applies all of `M`'s
    /// pending migrations to it.
    ///
    /// # Errors
    /// Fails if the connection cannot be set up or a migration fails.
    pub async fn new_with_migrator<M: MigratorTrait>() -> Result<Self, DbErr> {
        // Reuse the plain constructor so the SQLite setup lives in one place,
        // mirroring `new_postgres_with_migrator`.
        let instance = Self::new().await?;
        M::up(&instance.connection, None).await?;
        Ok(instance)
    }

    /// Opens an in-memory SQLite connection and issues the
    /// `journal_mode=WAL` and `busy_timeout=5000` pragmas on it.
    ///
    /// # Errors
    /// Fails if connecting or either pragma statement fails.
    pub async fn new() -> Result<Self, DbErr> {
        let connection = Database::connect("sqlite::memory:?mode=rwc&cache=shared").await?;
        for pragma in ["PRAGMA journal_mode=WAL;", "PRAGMA busy_timeout=5000;"] {
            connection.execute_unprepared(pragma).await?;
        }
        Ok(Self {
            connection,
            #[cfg(feature = "test-containers")]
            container: None,
        })
    }

    /// Like [`TestDb::new_postgres`], then applies all of `M`'s pending
    /// migrations to the new database.
    ///
    /// # Errors
    /// Fails if the database cannot be created or a migration fails.
    pub async fn new_postgres_with_migrator<M: MigratorTrait>() -> Result<Self, DbErr> {
        let instance = Self::create_postgres_db().await?;
        M::up(&instance.connection, None).await?;
        Ok(instance)
    }

    /// Creates a fresh Postgres database on the server named by the
    /// `TEST_DATABASE_URL` environment variable (or the built-in localhost
    /// default) and connects to it.
    ///
    /// # Errors
    /// See [`TestDb::new_with_postgres_url`].
    pub async fn new_postgres() -> Result<Self, DbErr> {
        Self::create_postgres_db().await
    }

    /// Starts a Postgres container through `postgres_container` and connects
    /// to the URL it reports. The container handle is stored in the returned
    /// value.
    ///
    /// # Errors
    /// Fails if the container cannot be started or connected to.
    #[cfg(feature = "test-containers")]
    pub async fn new_postgres_container() -> Result<Self, DbErr> {
        let container = postgres_container::PostgresContainer::start().await?;
        let connection = Database::connect(&container.connection_url).await?;
        Ok(Self {
            connection,
            container: Some(Arc::new(container)),
        })
    }

    /// Like [`TestDb::new_postgres_container`], then applies all of `M`'s
    /// pending migrations.
    ///
    /// # Errors
    /// Fails if the container database cannot be set up or a migration fails.
    #[cfg(feature = "test-containers")]
    pub async fn new_postgres_container_with_migrator<M: MigratorTrait>() -> Result<Self, DbErr> {
        let instance = Self::new_postgres_container().await?;
        M::up(&instance.connection, None).await?;
        Ok(instance)
    }

    /// Resolves the base Postgres URL from `TEST_DATABASE_URL`, falling back
    /// to a local default server, and creates a test database on it.
    async fn create_postgres_db() -> Result<Self, DbErr> {
        let base_url = std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/postgres".to_string());
        Self::create_postgres_db_from_base_url(&base_url).await
    }

    /// Creates a database named `test_db_<pid>_<counter>` through `base_url`,
    /// then connects to that database using `base_url` with its database name
    /// replaced.
    ///
    /// NOTE(review): nothing in this impl drops the created database
    /// afterwards — confirm cleanup happens elsewhere.
    ///
    /// # Errors
    /// Returns `DbErr::Custom` when the database cannot be created, the admin
    /// connection cannot be closed, or the new database cannot be connected
    /// to; other errors from connecting to `base_url` are passed through.
    async fn create_postgres_db_from_base_url(base_url: &str) -> Result<Self, DbErr> {
        let admin_connection = Database::connect(base_url).await?;

        let counter = TEST_DB_COUNTER.fetch_add(1, Ordering::SeqCst);
        let db_name = format!("test_db_{}_{}", std::process::id(), counter);
        let create_db_stmt = format!("CREATE DATABASE \"{}\"", escape_identifier(&db_name));

        let create_result = admin_connection
            .execute(Statement::from_string(
                sea_orm::DatabaseBackend::Postgres,
                create_db_stmt,
            ))
            .await;
        // Close the admin connection before inspecting the outcome so it is
        // released explicitly even when CREATE DATABASE failed.
        let close_result = admin_connection.close().await;

        // A creation failure is the more useful error, so report it first.
        create_result.map_err(|e| {
            DbErr::Custom(format!(
                "Failed to create test database '{}': {}",
                db_name, e
            ))
        })?;
        close_result
            .map_err(|e| DbErr::Custom(format!("Failed to close admin connection: {}", e)))?;

        let test_db_url = build_test_db_url(base_url, &db_name)?;
        let connection = Database::connect(&test_db_url).await.map_err(|e| {
            DbErr::Custom(format!(
                "Failed to connect to test database '{}': {}",
                db_name, e
            ))
        })?;

        Ok(Self {
            connection,
            #[cfg(feature = "test-containers")]
            container: None,
        })
    }

    /// Returns a clone of the underlying connection handle.
    pub fn connection(&self) -> DatabaseConnection {
        self.connection.clone()
    }

    /// Executes each SQL string in `statements`, in order, on the test
    /// database.
    ///
    /// # Errors
    /// Stops at and returns the first statement error; later statements are
    /// not executed.
    pub async fn seed(&self, statements: &[&str]) -> Result<(), DbErr> {
        for statement in statements {
            self.connection.execute_unprepared(statement).await?;
        }
        Ok(())
    }

    /// Drops every user table in the test database.
    ///
    /// The table list is read from the catalog that matches the connection's
    /// backend: `sqlite_master` (excluding `sqlite_%` tables) for SQLite, and
    /// `pg_tables` restricted to the current schema for Postgres, where
    /// tables are dropped with `CASCADE`.
    ///
    /// # Errors
    /// Returns an error if the table list cannot be read, a table name cannot
    /// be decoded, a `DROP TABLE` fails, or the backend is neither SQLite nor
    /// Postgres. Failures are reported instead of being treated as "nothing
    /// to drop".
    pub async fn reset(&self) -> Result<(), DbErr> {
        let backend = self.connection.get_database_backend();
        let (list_tables_sql, drop_suffix) = match backend {
            sea_orm::DatabaseBackend::Sqlite => (
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
                "",
            ),
            sea_orm::DatabaseBackend::Postgres => (
                "SELECT tablename::text AS name FROM pg_catalog.pg_tables \
                 WHERE schemaname = current_schema()",
                // Tables may reference each other; CASCADE makes the drop
                // order irrelevant.
                " CASCADE",
            ),
            other => {
                return Err(DbErr::Custom(format!(
                    "TestDb::reset does not support the {:?} backend",
                    other
                )));
            }
        };

        let rows = self
            .connection
            .query_all(Statement::from_string(backend, list_tables_sql.to_string()))
            .await?;
        for row in rows {
            let table_name: String = row.try_get("", "name")?;
            // Quote the name the same way database names are quoted so an
            // embedded `"` cannot break out of the identifier.
            let drop_stmt = format!(
                "DROP TABLE IF EXISTS \"{}\"{}",
                escape_identifier(&table_name),
                drop_suffix
            );
            self.connection.execute_unprepared(&drop_stmt).await?;
        }
        Ok(())
    }

    /// Runs `f` inside a transaction that is always rolled back afterwards.
    ///
    /// NOTE(review): `Fut` is not parameterised by the lifetime of the
    /// borrowed transaction, so the returned future likely cannot keep using
    /// the `&DatabaseTransaction` across `.await` points — confirm this is
    /// usable as intended by callers.
    ///
    /// # Errors
    /// Returns the error from beginning the transaction, otherwise the error
    /// returned by `f`, otherwise the error from the rollback. An error from
    /// `f` takes precedence over a rollback error so the original failure is
    /// not masked.
    pub async fn with_transaction_rollback<F, Fut>(&self, f: F) -> Result<(), DbErr>
    where
        F: for<'a> FnOnce(&'a sea_orm::DatabaseTransaction) -> Fut,
        Fut: std::future::Future<Output = Result<(), DbErr>>,
    {
        use sea_orm::TransactionTrait;

        let txn = self.connection.begin().await?;
        let outcome = f(&txn).await;
        // Always attempt the rollback, even when `f` failed.
        let rollback = txn.rollback().await;
        outcome.and(rollback)
    }
}
/// Escapes `identifier` for use inside a double-quoted SQL identifier by
/// doubling every `"` character. The surrounding quotes are not added.
fn escape_identifier(identifier: &str) -> String {
    let mut escaped = String::with_capacity(identifier.len());
    for ch in identifier.chars() {
        if ch == '"' {
            // Emit the extra quote that turns `"` into `""`.
            escaped.push('"');
        }
        escaped.push(ch);
    }
    escaped
}
/// Returns `base_url` with the part of its path after the last `/` replaced
/// by `new_db_name`; when the path contains no `/`, the path becomes
/// `/<new_db_name>`. All other URL components are left as parsed.
///
/// # Errors
/// Returns `DbErr::Custom` when `base_url` cannot be parsed as a URL.
fn build_test_db_url(base_url: &str, new_db_name: &str) -> Result<String, DbErr> {
    let mut parsed = match Url::parse(base_url) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(DbErr::Custom(format!(
                "Invalid database URL '{}': {}",
                base_url, e
            )));
        }
    };

    // Keep everything before the final `/` and swap in the new database name.
    let replaced_path = match parsed.path().rsplit_once('/') {
        Some((parent, _old_db_name)) => format!("{}/{}", parent, new_db_name),
        None => format!("/{}", new_db_name),
    };
    parsed.set_path(&replaced_path);

    Ok(parsed.as_str().to_owned())
}
/// Creates a `TestDb` inline in a test.
///
/// * `test_db!()` expands to `TestDb::new()`.
/// * `test_db!(MyMigrator)` expands to `TestDb::new_with_migrator::<MyMigrator>()`.
///
/// The expansion uses `.await`, so the macro must be invoked from an async
/// context, and it resolves `TestDb` through `$crate::testing`.
///
/// # Panics
/// Panics with "Failed to create test database" when the constructor returns
/// an error.
#[macro_export]
macro_rules! test_db {
    () => {{
        $crate::testing::TestDb::new()
            .await
            .expect("Failed to create test database")
    }};
    ($migrator:ty) => {{
        $crate::testing::TestDb::new_with_migrator::<$migrator>()
            .await
            .expect("Failed to create test database")
    }};
}