#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
use sqlx::PgPool;
use std::path::Path;
#[cfg(feature = "testcontainers")]
use std::sync::Arc;
use tracing::{debug, info};
use crate::error::{ForgeError, Result};
/// Shared handle to an optional testcontainers PostgreSQL container.
///
/// `TestDatabase::from_container` stores `Some(container)`, while the URL-based constructor
/// stores `None`. The `Arc` lets `TestDatabase::isolated` hand a clone of the same handle to
/// every `IsolatedTestDb`, so the handle stays referenced while any of them is alive.
#[cfg(feature = "testcontainers")]
type PgContainer =
Arc<Option<testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>>>;
/// A connection to a PostgreSQL server that serves as the base for database tests.
///
/// Construct it with `from_url` or `from_env`, then call `isolated` to obtain a fresh,
/// uniquely named database on the same server for a single test.
pub struct TestDatabase {
/// Pool connected to `url`.
pool: PgPool,
/// URL the pool was opened with; `isolated` reuses it for admin connections and to derive
/// the per-test database URL.
url: String,
/// Container handle (`None` when the server was supplied by URL).
/// NOTE(review): struct fields drop in declaration order, so the pool is dropped before
/// this handle — keep this field last.
#[cfg(feature = "testcontainers")]
_container: PgContainer,
}
impl TestDatabase {
    /// Connects to an existing PostgreSQL server at `url` with a pool of up to 10 connections.
    ///
    /// # Errors
    /// Returns `ForgeError::Sql` if the connection cannot be established.
    pub async fn from_url(url: &str) -> Result<Self> {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(10)
            .connect(url)
            .await
            .map_err(ForgeError::Sql)?;
        Ok(Self {
            pool,
            url: url.to_string(),
            #[cfg(feature = "testcontainers")]
            _container: Arc::new(None),
        })
    }

    /// Connects to the server named by the `TEST_DATABASE_URL` environment variable.
    ///
    /// When the variable is absent, a disposable PostgreSQL container is started if the
    /// `testcontainers` feature is enabled. Otherwise an error explains how to configure
    /// database tests.
    ///
    /// # Errors
    /// - `ForgeError::Database` if `TEST_DATABASE_URL` is set but is not valid Unicode.
    ///   This is reported explicitly instead of being treated as "not set".
    /// - `ForgeError::Database` if the variable is absent and `testcontainers` is disabled.
    /// - Any error from `from_url` or from starting the container.
    pub async fn from_env() -> Result<Self> {
        match std::env::var("TEST_DATABASE_URL") {
            Ok(url) => Self::from_url(&url).await,
            Err(std::env::VarError::NotUnicode(_)) => Err(ForgeError::Database(
                "TEST_DATABASE_URL is set but is not valid Unicode.".to_string(),
            )),
            Err(std::env::VarError::NotPresent) => {
                #[cfg(feature = "testcontainers")]
                {
                    Self::from_container().await
                }
                #[cfg(not(feature = "testcontainers"))]
                {
                    Err(ForgeError::Database(
                        "TEST_DATABASE_URL not set. Set it explicitly for database tests, \
                         or enable the `testcontainers` feature for automatic provisioning."
                            .to_string(),
                    ))
                }
            }
        }
    }

    /// Starts a `postgres:18-alpine` container and connects to its default `postgres` database.
    ///
    /// # Errors
    /// Returns `ForgeError::Database` if the container cannot be started or its address cannot
    /// be resolved, and `ForgeError::Sql` if connecting to it fails.
    #[cfg(feature = "testcontainers")]
    async fn from_container() -> Result<Self> {
        use testcontainers::ImageExt;
        use testcontainers::runners::AsyncRunner;
        use testcontainers_modules::postgres::Postgres;

        let container = Postgres::default()
            .with_tag("18-alpine")
            .start()
            .await
            .map_err(|e| ForgeError::Database(format!("Failed to start PG container: {e}")))?;
        // Ask testcontainers where the container is reachable instead of assuming `localhost`.
        // With a remote Docker daemon (e.g. `DOCKER_HOST=tcp://...`) it is not on this machine.
        let host = container
            .get_host()
            .await
            .map_err(|e| ForgeError::Database(format!("Failed to get container host: {e}")))?;
        let port = container
            .get_host_port_ipv4(5432)
            .await
            .map_err(|e| ForgeError::Database(format!("Failed to get container port: {e}")))?;
        let url = format!("postgres://postgres:postgres@{host}:{port}/postgres");
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(10)
            .acquire_timeout(std::time::Duration::from_secs(30))
            .connect(&url)
            .await
            .map_err(ForgeError::Sql)?;
        Ok(Self {
            pool,
            url,
            _container: Arc::new(Some(container)),
        })
    }

    /// Borrows the connection pool for the base database.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns the connection URL this handle was created with.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Executes a single SQL statement against the base database.
    ///
    /// # Errors
    /// Returns `ForgeError::Sql` if the statement fails.
    pub async fn execute(&self, sql: &str) -> Result<()> {
        sqlx::query(sql)
            .execute(&self.pool)
            .await
            .map_err(ForgeError::Sql)?;
        Ok(())
    }

    /// Creates a fresh, uniquely named database on the same server and connects to it.
    ///
    /// The name is `forge_test_<sanitized test name>_<16 random hex chars>`. It is kept within
    /// PostgreSQL's 63-byte identifier limit so the server never silently truncates it. A
    /// truncated name would no longer equal `db_name` in comparisons such as the
    /// `pg_stat_activity` lookup done by `IsolatedTestDb::cleanup`.
    ///
    /// # Errors
    /// Returns `ForgeError::Sql` if connecting to the server, creating the database, or
    /// connecting to the new database fails. In the last case the new database is dropped
    /// (best effort) so it is not leaked.
    pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
        // PostgreSQL's NAMEDATALEN is 64, so identifiers hold at most 63 bytes.
        const MAX_IDENT_BYTES: usize = 63;
        const PREFIX: &str = "forge_test_";

        let base_url = self.url.clone();

        // 16 hex chars of a v4 UUID carry 60 random bits, ample to keep test databases apart.
        let unique = uuid::Uuid::new_v4().simple().to_string();
        let unique = &unique[..16];
        let mut stem = sanitize_db_name(test_name);
        let budget = MAX_IDENT_BYTES - PREFIX.len() - 1 - unique.len();
        while stem.len() > budget {
            // `pop` removes whole chars, so the name never ends in a split UTF-8 sequence.
            stem.pop();
        }
        let db_name = format!("{}{}_{}", PREFIX, stem, unique);

        let admin = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&base_url)
            .await
            .map_err(ForgeError::Sql)?;
        if let Err(e) = sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
            .execute(&admin)
            .await
        {
            admin.close().await;
            return Err(ForgeError::Sql(e));
        }

        let test_url = replace_db_name(&base_url, &db_name);
        let connected = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&test_url)
            .await;
        let test_pool = match connected {
            Ok(pool) => pool,
            Err(e) => {
                // Best effort: don't leave the just-created database behind. The connect
                // error is the one worth reporting.
                let _ = sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", db_name))
                    .execute(&admin)
                    .await;
                admin.close().await;
                return Err(ForgeError::Sql(e));
            }
        };
        admin.close().await;

        Ok(IsolatedTestDb {
            pool: test_pool,
            db_name,
            base_url,
            #[cfg(feature = "testcontainers")]
            _container: self._container.clone(),
        })
    }
}
/// A uniquely named PostgreSQL database created for a single test.
///
/// Obtained from `TestDatabase::isolated` or `IsolatedTestDb::setup`. No `Drop` impl is
/// visible in this file, so the database is only removed when `cleanup` is awaited.
pub struct IsolatedTestDb {
/// Pool connected to the isolated database.
pool: PgPool,
/// Name of the isolated database; `cleanup` drops it.
db_name: String,
/// URL of the server's original database; `cleanup` connects here to drop `db_name`.
base_url: String,
/// Container handle cloned from the parent `TestDatabase`, so the handle stays referenced
/// while this isolated database exists.
#[cfg(feature = "testcontainers")]
_container: PgContainer,
}
impl IsolatedTestDb {
pub async fn setup(test_name: &str, internal_sql: &str, migrations_dir: &Path) -> Result<Self> {
let base = TestDatabase::from_env().await?;
let db = base.isolated(test_name).await?;
db.run_sql(internal_sql).await?;
db.migrate(migrations_dir).await?;
Ok(db)
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
pub fn db_name(&self) -> &str {
&self.db_name
}
pub async fn execute(&self, sql: &str) -> Result<()> {
sqlx::query(sql)
.execute(&self.pool)
.await
.map_err(ForgeError::Sql)?;
Ok(())
}
pub async fn run_sql(&self, sql: &str) -> Result<()> {
for stmt in split_sql_statements(sql) {
let stmt = stmt.trim();
if is_blank_sql(stmt) {
continue;
}
sqlx::query(stmt)
.execute(&self.pool)
.await
.map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {e}")))?;
}
Ok(())
}
pub async fn cleanup(self) -> Result<()> {
self.pool.close().await;
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(1)
.connect(&self.base_url)
.await
.map_err(ForgeError::Sql)?;
let _ = sqlx::query(
"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1",
)
.bind(&self.db_name)
.execute(&pool)
.await;
sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
.execute(&pool)
.await
.map_err(ForgeError::Sql)?;
Ok(())
}
pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
if !migrations_dir.exists() {
debug!("Migrations directory does not exist: {:?}", migrations_dir);
return Ok(());
}
let mut migrations = Vec::new();
let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
for entry in entries {
let entry = entry.map_err(ForgeError::Io)?;
let path = entry.path();
if path.extension().map(|e| e == "sql").unwrap_or(false) {
let name = path
.file_stem()
.and_then(|s| s.to_str())
.ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
.to_string();
let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
migrations.push((name, content));
}
}
migrations.sort_by(|a, b| a.0.cmp(&b.0));
debug!("Running {} migrations for test", migrations.len());
for (name, content) in migrations {
info!("Applying test migration: {}", name);
let up_sql = parse_up_sql(&content);
for stmt in split_sql_statements(&up_sql) {
let stmt = stmt.trim();
if is_blank_sql(stmt) {
continue;
}
sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
ForgeError::Database(format!("Failed to apply migration '{name}': {e}"))
})?;
}
}
Ok(())
}
}
/// Reports whether `sql` has nothing executable in it.
///
/// True for empty or whitespace-only text and for text whose non-empty lines are all `--`
/// line comments.
fn is_blank_sql(sql: &str) -> bool {
    let has_code = sql
        .lines()
        .map(str::trim)
        .any(|line| !line.is_empty() && !line.starts_with("--"));
    !has_code
}
/// Turns an arbitrary test name into a fragment that is safe inside a database name.
///
/// Every character other than an ASCII letter or digit becomes `_`, and the result is capped
/// at 32 characters. Restricting to ASCII has two benefits:
/// - Each character is one byte, so the cap is also a byte cap. This matters for
///   PostgreSQL's 63-byte identifier limit.
/// - No raw non-ASCII text ends up in the connection URL built from the name.
fn sanitize_db_name(name: &str) -> String {
    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .take(32)
        .collect()
}
/// Returns `url` with its database (path) component replaced by `new_db`.
///
/// The query string (everything from the first `?`) is preserved verbatim. The path is
/// located after the URL's authority, so these cases work:
/// - Slashes inside query parameters such as `sslrootcert=/etc/ca.pem` or
///   `host=/var/run/postgresql` are never mistaken for the path separator.
/// - A URL without a database (`postgres://host:5432`) gains `/<new_db>` instead of losing
///   its host.
/// - A string with no `/` at all gets `/<new_db>` appended.
fn replace_db_name(url: &str, new_db: &str) -> String {
    let (base, query) = match url.find('?') {
        Some(idx) => url.split_at(idx),
        None => (url, ""),
    };
    // Skip "scheme://" so its slashes are not taken as the start of the path.
    let authority_start = base.find("://").map_or(0, |idx| idx + 3);
    let prefix = match base[authority_start..].find('/') {
        Some(idx) => &base[..authority_start + idx],
        None => base,
    };
    format!("{}/{}{}", prefix, new_db, query)
}
/// Returns the "up" half of a migration file, run through `strip_up_markers`.
///
/// The "up" half is everything before the first down marker. Recognised down markers are
/// `-- @down`, `--@down`, `-- @DOWN` and `--@DOWN`. The split happens at the *earliest*
/// occurrence of any of them, so mixing spellings in one file cannot let down-migration SQL
/// leak into the up section.
fn parse_up_sql(content: &str) -> String {
    const DOWN_MARKERS: [&str; 4] = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];
    let up_end = DOWN_MARKERS
        .iter()
        .filter_map(|marker| content.find(marker))
        .min()
        .unwrap_or(content.len());
    strip_up_markers(&content[..up_end])
}
/// Removes `@up` marker lines from migration SQL and trims the result.
///
/// A marker line is one whose trimmed text starts with `-- @up`, `--@up`, `-- @UP` or
/// `--@UP`, followed by end-of-line or whitespace. The whole line is dropped, so any trailing
/// description on it cannot end up as bare SQL. Every other line is kept byte-for-byte,
/// including its line ending. In particular, comments like `-- @upsert` and markers that
/// appear mid-line are left alone; they are ordinary SQL comments.
fn strip_up_markers(sql: &str) -> String {
    const UP_MARKERS: [&str; 4] = ["-- @up", "--@up", "-- @UP", "--@UP"];
    let is_marker_line = |line: &str| {
        let line = line.trim();
        UP_MARKERS.iter().any(|marker| {
            line.strip_prefix(marker)
                .is_some_and(|rest| rest.is_empty() || rest.starts_with(char::is_whitespace))
        })
    };
    sql.split_inclusive('\n')
        .filter(|line| !is_marker_line(line))
        .collect::<String>()
        .trim()
        .to_string()
}
/// Splits a SQL script into individual statements at top-level `;` separators.
///
/// A `;` ends a statement only when it lies outside every construct PostgreSQL treats as
/// opaque text:
/// - single-quoted strings (`'it''s'`, plus backslash escapes in `E'...'` strings);
/// - double-quoted identifiers;
/// - `--` line comments and nested `/* ... */` block comments;
/// - dollar-quoted bodies (`$$ ... $$`, `$tag$ ... $tag$`).
///
/// Each statement is trimmed and returned without its terminating `;`, and empty statements
/// are dropped. Comments stay attached to the statement that follows them. An unterminated
/// quote or comment extends to the end of the input.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let chars: Vec<char> = sql.chars().collect();
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut i = 0;

    while i < chars.len() {
        // `end` is the exclusive end of the lexical unit starting at `i`.
        let end = match chars[i] {
            ';' => {
                push_sql_statement(&mut statements, &current);
                current.clear();
                i += 1;
                continue;
            }
            '\'' => sql_quoted_end(&chars, i, '\'', is_sql_escape_string_start(&chars, i)),
            '"' => sql_quoted_end(&chars, i, '"', false),
            '-' if chars.get(i + 1) == Some(&'-') => sql_line_comment_end(&chars, i),
            '/' if chars.get(i + 1) == Some(&'*') => sql_block_comment_end(&chars, i),
            '$' => sql_dollar_quoted_end(&chars, i).unwrap_or(i + 1),
            _ => i + 1,
        };
        current.extend(&chars[i..end]);
        i = end;
    }
    push_sql_statement(&mut statements, &current);
    statements
}

/// Trims `raw`, drops a trailing `;`, and records it if anything is left.
fn push_sql_statement(statements: &mut Vec<String>, raw: &str) {
    let stmt = raw.trim().trim_end_matches(';').trim();
    if !stmt.is_empty() {
        statements.push(stmt.to_string());
    }
}

/// True for characters that may continue a PostgreSQL identifier (which allows `$`).
fn is_sql_ident_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '$'
}

/// True when the `'` at `quote_idx` opens an `E'...'` string, where backslash escapes apply.
fn is_sql_escape_string_start(chars: &[char], quote_idx: usize) -> bool {
    match quote_idx {
        0 => false,
        1 => matches!(chars[0], 'E' | 'e'),
        _ => {
            matches!(chars[quote_idx - 1], 'E' | 'e') && !is_sql_ident_char(chars[quote_idx - 2])
        }
    }
}

/// Index just past the quoted token opening at `start`.
///
/// A doubled quote character is an escaped quote. When `backslash_escapes` is set, a
/// backslash also escapes the character after it.
fn sql_quoted_end(chars: &[char], start: usize, quote: char, backslash_escapes: bool) -> usize {
    let mut j = start + 1;
    while j < chars.len() {
        let c = chars[j];
        if backslash_escapes && c == '\\' {
            j += 2;
        } else if c == quote {
            if chars.get(j + 1) == Some(&quote) {
                j += 2;
            } else {
                return j + 1;
            }
        } else {
            j += 1;
        }
    }
    chars.len()
}

/// Index just past the `--` comment opening at `start`, including its newline.
fn sql_line_comment_end(chars: &[char], start: usize) -> usize {
    chars[start..]
        .iter()
        .position(|&c| c == '\n')
        .map_or(chars.len(), |offset| start + offset + 1)
}

/// Index just past the (possibly nested) `/* ... */` comment opening at `start`.
fn sql_block_comment_end(chars: &[char], start: usize) -> usize {
    let mut depth = 0usize;
    let mut j = start;
    while j + 1 < chars.len() {
        match (chars[j], chars[j + 1]) {
            ('/', '*') => {
                depth += 1;
                j += 2;
            }
            ('*', '/') => {
                depth -= 1;
                j += 2;
                if depth == 0 {
                    return j;
                }
            }
            _ => j += 1,
        }
    }
    chars.len()
}

/// If a dollar-quoted body opens at `start`, returns the index just past its closing tag.
///
/// Returns `None` when the `$` does not start a dollar quote: a positional parameter such
/// as `$1`, or a `$` inside an identifier.
fn sql_dollar_quoted_end(chars: &[char], start: usize) -> Option<usize> {
    if start > 0 && is_sql_ident_char(chars[start - 1]) {
        return None;
    }
    let tag_body_len = chars[start + 1..]
        .iter()
        .take_while(|&&c| c.is_alphanumeric() || c == '_')
        .count();
    // A tag may not begin with a digit; this keeps `$1`-style parameters out.
    if tag_body_len > 0 && !(chars[start + 1].is_alphabetic() || chars[start + 1] == '_') {
        return None;
    }
    let tag_close = start + 1 + tag_body_len;
    if chars.get(tag_close) != Some(&'$') {
        return None;
    }
    let tag = &chars[start..=tag_close];
    let body_start = tag_close + 1;
    let end = chars[body_start..]
        .windows(tag.len())
        .position(|window| window == tag)
        .map_or(chars.len(), |offset| body_start + offset + tag.len());
    Some(end)
}
// Unit tests for the pure string helpers in this module; none of them need a database.
#[cfg(test)]
mod tests {
use super::*;
// sanitize_db_name / replace_db_name: building the per-test database name and URL.
#[test]
fn test_sanitize_db_name() {
assert_eq!(sanitize_db_name("my_test"), "my_test");
assert_eq!(sanitize_db_name("my-test"), "my_test");
assert_eq!(sanitize_db_name("my test"), "my_test");
assert_eq!(sanitize_db_name("test::function"), "test__function");
}
#[test]
fn test_replace_db_name() {
assert_eq!(
replace_db_name("postgres://localhost/olddb", "newdb"),
"postgres://localhost/newdb"
);
assert_eq!(
replace_db_name("postgres://user:pass@localhost:5432/olddb", "newdb"),
"postgres://user:pass@localhost:5432/newdb"
);
assert_eq!(
replace_db_name("postgres://localhost/olddb?sslmode=disable", "newdb"),
"postgres://localhost/newdb?sslmode=disable"
);
}
// split_sql_statements: statement boundaries, dollar quoting, and blank handling.
#[test]
fn split_simple_statements() {
let stmts = split_sql_statements("CREATE TABLE a (id int); CREATE TABLE b (id int);");
assert_eq!(stmts.len(), 2);
assert!(stmts[0].starts_with("CREATE TABLE a"));
assert!(stmts[1].starts_with("CREATE TABLE b"));
}
#[test]
fn split_preserves_dollar_quoted_content() {
// The `;` characters inside the `$$ ... $$` body must not end the statement.
let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
INSERT INTO logs (msg) VALUES ('hello; world');
END;
$$ LANGUAGE plpgsql;
SELECT 1;
"#;
let stmts = split_sql_statements(sql);
assert_eq!(
stmts.len(),
2,
"Should split into function + SELECT, not more"
);
assert!(
stmts[0].contains("$$"),
"Function body must include dollar quotes"
);
}
#[test]
fn split_handles_empty_input() {
let stmts = split_sql_statements("");
assert!(stmts.is_empty());
}
#[test]
fn split_handles_no_trailing_semicolon() {
let stmts = split_sql_statements("SELECT 1");
assert_eq!(stmts.len(), 1);
assert_eq!(stmts[0], "SELECT 1");
}
#[test]
fn split_skips_blank_statements() {
let stmts = split_sql_statements("; ; SELECT 1; ;");
assert_eq!(stmts.len(), 1);
assert_eq!(stmts[0], "SELECT 1");
}
// parse_up_sql: selecting the "up" half of a migration file.
#[test]
fn parse_up_sql_strips_down_section() {
let content = "CREATE TABLE a (id int);\n-- @down\nDROP TABLE a;";
let up = parse_up_sql(content);
assert!(up.contains("CREATE TABLE"));
assert!(!up.contains("DROP TABLE"), "Down SQL should be excluded");
}
#[test]
fn parse_up_sql_handles_no_down_marker() {
let content = "CREATE TABLE a (id int);";
let up = parse_up_sql(content);
assert!(up.contains("CREATE TABLE"));
}
#[test]
fn parse_up_sql_strips_up_markers() {
let content = "-- @up\nCREATE TABLE a (id int);";
let up = parse_up_sql(content);
assert!(!up.contains("@up"), "Up marker should be stripped");
assert!(up.contains("CREATE TABLE"));
}
// is_blank_sql: whitespace- and comment-only input counts as blank.
#[test]
fn blank_sql_detection() {
assert!(is_blank_sql(""));
assert!(is_blank_sql(" "));
assert!(is_blank_sql("-- just a comment"));
assert!(is_blank_sql("-- comment\n-- another"));
assert!(!is_blank_sql("SELECT 1"));
assert!(!is_blank_sql("-- comment\nSELECT 1"));
}
// sanitize_db_name: length cap and replacement of non-alphanumeric characters.
#[test]
fn sanitize_truncates_long_names() {
let long_name = "a".repeat(100);
let sanitized = sanitize_db_name(&long_name);
assert_eq!(sanitized.len(), 32);
}
#[test]
fn sanitize_handles_special_characters() {
assert_eq!(
sanitize_db_name("test/with:special!chars"),
"test_with_special_chars"
);
assert_eq!(sanitize_db_name(""), "");
}
}