use sqlx::PgPool;
#[cfg(feature = "embedded-test-db")]
use tokio::sync::OnceCell;
use crate::error::{ForgeError, Result};
/// Shared embedded Postgres server, only compiled with the `embedded-test-db`
/// feature. The cell starts empty and is filled by `TestDatabase::embedded`
/// the first time setup and start both succeed; later calls reuse the same
/// instance.
#[cfg(feature = "embedded-test-db")]
static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
/// A connection to a Postgres server used by tests: a pool plus the URL it
/// was opened with.
pub struct TestDatabase {
    /// Pool connected to `url`.
    pool: PgPool,
    /// Connection URL the pool was created from; `isolated` derives the URLs
    /// of per-test databases from it.
    url: String,
}
impl TestDatabase {
    /// Connects to the Postgres server at `url` with a pool of at most 10
    /// connections.
    ///
    /// # Errors
    ///
    /// Returns [`ForgeError::Sql`] if the connection cannot be established.
    pub async fn from_url(url: &str) -> Result<Self> {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(10)
            .connect(url)
            .await
            .map_err(ForgeError::Sql)?;
        Ok(Self {
            pool,
            url: url.to_string(),
        })
    }

    /// Connects using the URL stored in the `TEST_DATABASE_URL` environment
    /// variable.
    ///
    /// # Errors
    ///
    /// Returns [`ForgeError::Database`] if the variable cannot be read, or
    /// whatever [`TestDatabase::from_url`] returns for the URL it holds.
    pub async fn from_env() -> Result<Self> {
        let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
            ForgeError::Database(
                "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
            )
        })?;
        Self::from_url(&url).await
    }

    /// Connects to the shared embedded Postgres server, setting it up and
    /// starting it on the first successful call.
    ///
    /// The server handle lives in the `EMBEDDED_PG` static, so later calls
    /// only open a new pool against the already-running instance.
    ///
    /// # Errors
    ///
    /// Returns [`ForgeError::Database`] if setup or start fails, or whatever
    /// [`TestDatabase::from_url`] returns for the embedded server's URL.
    #[cfg(feature = "embedded-test-db")]
    pub async fn embedded() -> Result<Self> {
        let pg = EMBEDDED_PG
            .get_or_try_init(|| async {
                let mut pg = postgresql_embedded::PostgreSQL::default();
                pg.setup().await.map_err(|e| {
                    ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
                })?;
                pg.start().await.map_err(|e| {
                    ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
                })?;
                Ok::<_, ForgeError>(pg)
            })
            .await?;
        let url = pg.settings().url("postgres");
        Self::from_url(&url).await
    }

    /// Returns the underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns the URL this database was connected with.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Runs a single SQL statement against the pool, discarding its result.
    ///
    /// # Errors
    ///
    /// Returns [`ForgeError::Sql`] if the statement fails.
    pub async fn execute(&self, sql: &str) -> Result<()> {
        sqlx::query(sql)
            .execute(&self.pool)
            .await
            .map_err(ForgeError::Sql)?;
        Ok(())
    }

    /// Creates a fresh database for one test and returns a handle to it.
    ///
    /// The database is named `forge_test_<sanitized test_name>_<uuid>` and is
    /// reached through its own pool of at most 5 connections. Call
    /// [`IsolatedTestDb::cleanup`] to drop it again.
    ///
    /// NOTE(review): for long test names the generated name can exceed
    /// 63 bytes, which PostgreSQL documents as its identifier limit — confirm
    /// that server-side truncation is acceptable here.
    ///
    /// # Errors
    ///
    /// Returns [`ForgeError::Sql`] if the database cannot be created or the
    /// connection to it fails.
    pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
        let db_name = format!(
            "forge_test_{}_{}",
            sanitize_db_name(test_name),
            uuid::Uuid::new_v4().simple()
        );

        // `self.pool` is already connected to `self.url`, so it doubles as the
        // admin connection. Opening and dropping a throwaway pool per call
        // leaves server-side connections lingering until the server notices.
        sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
            .execute(&self.pool)
            .await
            .map_err(ForgeError::Sql)?;

        let test_url = replace_db_name(&self.url, &db_name);
        let connected = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&test_url)
            .await;

        let test_pool = match connected {
            Ok(pool) => pool,
            Err(err) => {
                // Best effort: do not leave an orphaned database behind when
                // the caller never receives a handle it could clean up.
                let _ = sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", db_name))
                    .execute(&self.pool)
                    .await;
                return Err(ForgeError::Sql(err));
            }
        };

        Ok(IsolatedTestDb {
            pool: test_pool,
            db_name,
            base_url: self.url.clone(),
        })
    }
}
/// Handle to a database created for a single test by
/// `TestDatabase::isolated`.
pub struct IsolatedTestDb {
    /// Pool connected to the isolated database itself.
    pool: PgPool,
    /// Name the isolated database was created with.
    db_name: String,
    /// URL of the server the database was created from; `cleanup` reconnects
    /// to it in order to drop the database.
    base_url: String,
}
impl IsolatedTestDb {
    /// Returns the pool connected to the isolated database.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns the name the isolated database was created with.
    pub fn db_name(&self) -> &str {
        &self.db_name
    }

    /// Runs a single SQL statement against the isolated database, discarding
    /// its result.
    ///
    /// # Errors
    ///
    /// Returns [`ForgeError::Sql`] if the statement fails.
    pub async fn execute(&self, sql: &str) -> Result<()> {
        sqlx::query(sql)
            .execute(&self.pool)
            .await
            .map_err(ForgeError::Sql)?;
        Ok(())
    }

    /// Closes the pool and drops the isolated database.
    ///
    /// Reconnects to the base URL with a single-connection pool, asks the
    /// server to terminate any sessions still attached to the database (best
    /// effort, failures ignored), then issues `DROP DATABASE IF EXISTS`.
    ///
    /// # Errors
    ///
    /// Returns [`ForgeError::Sql`] if reconnecting to the base URL or dropping
    /// the database fails.
    pub async fn cleanup(self) -> Result<()> {
        self.pool.close().await;

        let admin = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&self.base_url)
            .await
            .map_err(ForgeError::Sql)?;

        // Deliberately best effort: a failure here surfaces through the DROP
        // below if sessions are in fact still attached.
        let _ = sqlx::query(&format!(
            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
            self.db_name
        ))
        .execute(&admin)
        .await;

        let dropped = sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
            .execute(&admin)
            .await;

        // Close the admin pool on both the success and the failure path.
        // Merely dropping it leaves the server-side connection open until the
        // server notices, which adds up when many tests clean up in a row.
        admin.close().await;

        dropped.map_err(ForgeError::Sql)?;
        Ok(())
    }
}
/// Turns an arbitrary test name into a fragment usable inside a database name.
///
/// Only the first 32 characters of `name` are considered. ASCII letters and
/// digits are kept as they are; every other character, including non-ASCII
/// letters and digits, becomes `_`. Because the output is pure ASCII it is
/// also at most 32 bytes long, so the cap holds in bytes and not merely in
/// characters.
fn sanitize_db_name(name: &str) -> String {
    name.chars()
        .take(32)
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect()
}
/// Returns `url` with its database path component replaced by `new_db`.
///
/// The query string (everything from the first `?`) is carried over
/// unchanged. The path is located as the first `/` after the `scheme://`
/// authority, so a URL without any database path gets `/new_db` appended, and
/// slashes inside query values (for example a certificate path) are never
/// mistaken for the path separator.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Split the query off first so its contents cannot influence path lookup.
    let (head, query) = match url.find('?') {
        Some(query_idx) => url.split_at(query_idx),
        None => (url, ""),
    };

    // Skip past "scheme://" so its slashes are not taken for the path start.
    let authority_start = head.find("://").map_or(0, |scheme_end| scheme_end + 3);

    // Everything before the first slash after the authority is kept verbatim.
    let prefix = match head[authority_start..].find('/') {
        Some(slash) => &head[..authority_start + slash],
        None => head,
    };

    format!("{}/{}{}", prefix, new_db, query)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Characters other than letters and digits each collapse to one underscore.
    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(sanitize_db_name(input), expected);
        }
    }

    /// The database segment is swapped while host, credentials, port and
    /// query string are carried over.
    #[test]
    fn test_replace_db_name() {
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];
        for &(url, expected) in &cases {
            assert_eq!(replace_db_name(url, "newdb"), expected);
        }
    }
}