use fs2::FileExt;
use sqlx::{Executor, PgPool};
use std::io::{Read, Write};
use std::path::PathBuf;
use std::time::Duration;
use testcontainers::{
ContainerAsync, GenericImage, ImageExt,
core::{IntoContainerPort, WaitFor},
runners::AsyncRunner,
};
use tokio::sync::OnceCell;
use tokio::time::sleep;
use uuid::Uuid;
/// Runs `operation` until it succeeds or `max_attempts` attempts have been made.
///
/// Between failed attempts it sleeps, starting at `initial_delay` and doubling the
/// delay after each failure. Doubling saturates at `Duration::MAX` instead of
/// panicking on overflow.
///
/// A `max_attempts` of `0` is treated as `1`: the operation is always tried at
/// least once, so there is always a success or an error to return.
///
/// # Errors
/// Returns the error produced by the final attempt if every attempt fails.
async fn retry_with_backoff<F, Fut, T, E>(
    max_attempts: usize,
    initial_delay: Duration,
    mut operation: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    // Guarantee at least one attempt. With 0, the previous loop never ran and the
    // trailing `expect` on the missing error panicked.
    let max_attempts = max_attempts.max(1);
    let mut delay = initial_delay;
    let mut attempt = 1;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            // The last permitted attempt failed: surface its error directly.
            Err(e) if attempt >= max_attempts => return Err(e),
            Err(_) => {
                eprintln!(
                    "[retry] Attempt {}/{} failed, retrying after {:?}",
                    attempt, max_attempts, delay
                );
                sleep(delay).await;
                // `delay *= 2` would panic if the Duration overflowed.
                delay = delay.saturating_mul(2);
                attempt += 1;
            }
        }
    }
}
/// Handle to the PostgreSQL server shared by the tests of this process.
pub struct SharedPostgres {
// Owned container handle. It is never read, hence the `dead_code` allow. It is
// held presumably so the container is not torn down by its `Drop` while tests run.
// It is `None` when this process reused a server that another process started
// (see `get_shared_postgres`).
#[allow(dead_code)]
container: Option<ContainerAsync<GenericImage>>,
/// Server URL without a database path, of the form
/// `postgres://postgres@<host>:<port>`. Callers append `/<db>?sslmode=disable`.
pub base_url: String,
}
/// Process-wide shared server, initialized on first use by `get_shared_postgres`.
/// Being a `static`, the value (and any container handle inside it) is never dropped.
static POSTGRES: OnceCell<SharedPostgres> = OnceCell::const_new();
/// Path of the file, in the system temp directory, that records the base URL
/// of the shared PostgreSQL server so later test processes can reuse it.
fn get_url_file_path() -> PathBuf {
    let mut url_file = std::env::temp_dir();
    url_file.push("reinhardt_test_postgres_url");
    url_file
}

/// Path of the lock file, in the system temp directory, used to serialize
/// shared-server setup across test processes.
fn get_lock_file_path() -> PathBuf {
    let mut lock_file = std::env::temp_dir();
    lock_file.push("reinhardt_test_postgres.lock");
    lock_file
}
/// Probes `url`: returns `true` only if a single-connection pool can be opened
/// (waiting at most three seconds for a connection) and `SELECT 1` succeeds.
/// The probe pool is closed before returning.
async fn test_connection(url: &str) -> bool {
    let connect_result = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .acquire_timeout(Duration::from_secs(3))
        .connect(url)
        .await;
    let Ok(probe_pool) = connect_result else {
        return false;
    };
    let ping_ok = sqlx::query("SELECT 1")
        .fetch_one(&probe_pool)
        .await
        .is_ok();
    probe_pool.close().await;
    ping_ok
}
fn read_url_from_file() -> Option<String> {
let path = get_url_file_path();
if !path.exists() {
return None;
}
let mut file = std::fs::File::open(&path).ok()?;
let mut url = String::new();
file.read_to_string(&mut url).ok()?;
if url.trim().is_empty() {
None
} else {
Some(url.trim().to_string())
}
}
/// Records `url` as the shared-server URL, replacing any previous contents,
/// and syncs the file to disk before returning.
///
/// # Errors
/// Propagates any I/O error from creating, writing, or syncing the file.
fn write_url_to_file(url: &str) -> std::io::Result<()> {
    let mut out = std::fs::File::create(get_url_file_path())?;
    out.write_all(url.as_bytes())?;
    out.sync_all()
}
/// Starts a new `postgres:17-alpine` container with `trust` host authentication.
///
/// Returns the container handle together with its base URL of the form
/// `postgres://postgres@<host>:<port>` (no database path).
///
/// # Panics
/// Panics if the container fails to start, if its host cannot be determined,
/// or if the mapped port cannot be retrieved within three attempts.
async fn start_postgres_container() -> (ContainerAsync<GenericImage>, String) {
    let container = GenericImage::new("postgres", "17-alpine")
        .with_exposed_port(5432.tcp())
        .with_wait_for(WaitFor::message_on_stderr(
            "database system is ready to accept connections",
        ))
        .with_env_var("POSTGRES_HOST_AUTH_METHOD", "trust")
        .start()
        .await
        .expect("Failed to start PostgreSQL container");
    // Previously a bare `unwrap()`. Give the panic context like the other steps.
    let host = container
        .get_host()
        .await
        .expect("Failed to get container host");
    // Port lookup is retried with a short backoff before giving up.
    let port = retry_with_backoff(3, Duration::from_millis(100), || async {
        container
            .get_host_port_ipv4(5432.tcp())
            .await
            .map_err(|e| format!("Port retrieval failed: {}", e))
    })
    .await
    .expect("Failed to get container port after retries");
    let base_url = format!("postgres://postgres@{}:{}", host, port);
    eprintln!(
        "[shared_postgres] Started new PostgreSQL container at {}:{}",
        host, port
    );
    (container, base_url)
}
/// Creates the `test_template` database on the server at `base_url` and marks
/// it as a template, so per-test databases can be cloned from it.
///
/// Setup is best-effort. A `test_template` that already exists
/// (`duplicate_database`, SQLSTATE 42P04) is silently accepted. Any other failure
/// is logged to stderr rather than silently discarded. Either way the function
/// returns normally, and the second statement is still attempted.
///
/// # Panics
/// Panics if the admin connection to the `postgres` database cannot be opened.
async fn init_template_database(base_url: &str) {
    let admin_pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(5)
        .acquire_timeout(Duration::from_secs(60))
        .test_before_acquire(false)
        .idle_timeout(Some(Duration::from_secs(30)))
        .connect(&format!("{}/postgres?sslmode=disable", base_url))
        .await
        .expect("Failed to connect to PostgreSQL for template setup");

    for statement in [
        "CREATE DATABASE test_template",
        "ALTER DATABASE test_template IS_TEMPLATE true",
    ] {
        if let Err(e) = admin_pool.execute(statement).await {
            // 42P04 = duplicate_database: the template is already there, which is fine.
            let already_exists = e
                .as_database_error()
                .and_then(|db| db.code())
                .as_deref()
                == Some("42P04");
            if !already_exists {
                // Previously swallowed by `.ok()`. Without this, a failure only shows up
                // later as a confusing error when cloning from the template.
                eprintln!(
                    "[shared_postgres] Template setup statement `{}` failed: {}",
                    statement, e
                );
            }
        }
    }

    // Release the admin connections now instead of leaving them to the pool's drop.
    admin_pool.close().await;
}
/// Returns the process-wide shared PostgreSQL server, starting one if needed.
///
/// Initialization runs at most once per process, guarded by the `POSTGRES` `OnceCell`.
/// Across processes, an exclusive lock on a file in the system temp directory
/// serializes the reuse-or-start decision:
/// - If a URL was recorded by an earlier process and `SELECT 1` succeeds against
///   it, that server is reused and no container handle is held.
/// - Otherwise a new container is started, `test_template` is set up, and the
///   new URL is recorded for later processes.
///
/// # Panics
/// Panics if the lock file cannot be created or locked, if a new container
/// cannot be started, or if the URL file cannot be written.
pub async fn get_shared_postgres() -> &'static SharedPostgres {
POSTGRES
.get_or_init(|| async {
let lock_path = get_lock_file_path();
// `truncate(false)`: the file is only a lock target; its contents are irrelevant.
let lock_file = std::fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(false)
.open(&lock_path)
.expect("Failed to create lock file");
// NOTE(review): fs2's `lock_exclusive` blocks the calling thread until the lock is
// free, and this runs on the async executor thread. Consider `spawn_blocking` if
// other tasks must make progress meanwhile.
lock_file.lock_exclusive().expect("Failed to acquire lock");
// Fast path: reuse a server recorded by an earlier process if it still answers.
if let Some(url) = read_url_from_file() {
let postgres_url = format!("{}/postgres?sslmode=disable", url);
if test_connection(&postgres_url).await {
eprintln!("[shared_postgres] Reusing existing container at {}", url);
lock_file.unlock().ok();
return SharedPostgres {
container: None, base_url: url,
};
} else {
eprintln!(
"[shared_postgres] Existing container not reachable, starting new one"
);
}
}
// Slow path, still under the lock: start a server, prepare the template DB, and
// publish the URL before other processes can proceed.
let (container, base_url) = start_postgres_container().await;
init_template_database(&base_url).await;
write_url_to_file(&base_url).expect("Failed to write URL to file");
lock_file.unlock().ok();
SharedPostgres {
container: Some(container),
base_url,
}
})
.await
}
/// Creates a fresh, uniquely named database cloned from `test_template` on the
/// server at `base_url`, and returns its name.
///
/// The short-lived admin pool used for `CREATE DATABASE` is closed before
/// returning, even if creation fails.
///
/// # Panics
/// Panics if the admin connection or the `CREATE DATABASE` statement fails.
async fn create_database_from_template(base_url: &str) -> String {
    let db_name = format!("test_{}", Uuid::now_v7().simple());
    let admin_pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .acquire_timeout(Duration::from_secs(10))
        .connect(&format!("{}/postgres?sslmode=disable", base_url))
        .await
        .expect("Failed to connect to postgres for test database creation");
    // `CREATE DATABASE` cannot take a bind parameter for the name. The name is built
    // only from a fixed prefix and a hex UUID, so interpolating it is safe.
    let create_sql = format!("CREATE DATABASE {} TEMPLATE test_template", db_name);
    let created = sqlx::query(&create_sql).execute(&admin_pool).await;
    // Previously the admin pool was just dropped at the end of each caller.
    admin_pool.close().await;
    created.expect("Failed to create test database from template");
    db_name
}

/// Opens the connection pool a test uses against the database at `db_url`.
///
/// # Panics
/// Panics if the pool cannot connect.
async fn connect_test_database(db_url: &str) -> PgPool {
    sqlx::postgres::PgPoolOptions::new()
        .max_connections(5)
        .acquire_timeout(Duration::from_secs(10))
        .test_before_acquire(false)
        .idle_timeout(Some(Duration::from_secs(30)))
        .connect(db_url)
        .await
        .expect("Failed to connect to test database")
}

/// Returns a pool connected to a brand-new database cloned from `test_template`,
/// isolated from every other test's database.
///
/// # Panics
/// Panics if the shared server, database creation, or connection fails.
pub async fn get_test_pool() -> PgPool {
    let pg = get_shared_postgres().await;
    let db_name = create_database_from_template(&pg.base_url).await;
    connect_test_database(&format!("{}/{}?sslmode=disable", pg.base_url, db_name)).await
}

/// Like [`get_test_pool`], then executes `table_sql` (e.g. a `CREATE TABLE`)
/// in the new database.
///
/// # Panics
/// Panics if `table_sql` fails, or for any reason [`get_test_pool`] panics.
pub async fn get_test_pool_with_table(table_sql: &str) -> PgPool {
    let pool = get_test_pool().await;
    sqlx::query(table_sql)
        .execute(&pool)
        .await
        .expect("Failed to create table in test database");
    pool
}

/// Like [`get_test_pool`], but also points the `reinhardt_db` ORM at the new
/// database, and returns the pool together with the database URL.
///
/// # Panics
/// Panics if ORM reinitialization fails, or for any reason [`get_test_pool`] panics.
pub async fn get_test_pool_with_orm() -> (PgPool, String) {
    let pg = get_shared_postgres().await;
    let db_name = create_database_from_template(&pg.base_url).await;
    let db_url = format!("{}/{}?sslmode=disable", pg.base_url, db_name);
    let pool = connect_test_database(&db_url).await;
    reinhardt_db::orm::reinitialize_database(&db_url)
        .await
        .expect("Failed to reinitialize ORM database");
    (pool, db_url)
}
/// Deletes the shared-server URL file and the lock file, in that order, if they
/// exist. Removal errors are ignored.
///
/// This only forgets the recorded server; it does not stop any container.
///
/// NOTE(review): deleting the lock file while another process holds a lock on
/// it would let a later process lock a new file at the same path, so the two
/// would no longer exclude each other. Call only when no tests are running.
pub fn cleanup_shared_postgres() {
    for stale in [get_url_file_path(), get_lock_file_path()] {
        if stale.exists() {
            let _ = std::fs::remove_file(&stale);
        }
    }
}
/// rstest fixture that delegates to `get_test_pool_with_orm`: a pool on a fresh
/// database cloned from `test_template`, plus that database's URL.
#[rstest::fixture]
pub async fn shared_db_pool() -> (PgPool, String) {
get_test_pool_with_orm().await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The shared server initializes and exposes a non-empty base URL.
    #[tokio::test]
    async fn test_shared_postgres_initialization() {
        let pg = get_shared_postgres().await;
        assert!(!pg.base_url.is_empty());
    }

    /// Two pools from `get_test_pool` point at distinct databases. A table created
    /// through one is visible there but missing through the other.
    #[tokio::test]
    async fn test_isolated_databases() {
        let pool1 = get_test_pool().await;
        let pool2 = get_test_pool().await;
        sqlx::query("CREATE TABLE test_table (id SERIAL PRIMARY KEY)")
            .execute(&pool1)
            .await
            .expect("Failed to create table");
        // Positive control: the same query must work where the table exists.
        // Otherwise the failure on pool2 below proves nothing.
        sqlx::query("SELECT 1 FROM test_table")
            .fetch_optional(&pool1)
            .await
            .expect("Table should be visible in the database that created it");
        let result = sqlx::query("SELECT 1 FROM test_table")
            .fetch_optional(&pool2)
            .await;
        match result {
            Ok(_) => panic!("Databases should be isolated - table should not exist in pool2"),
            Err(err) => {
                // The failure must be "relation does not exist" (42P01, undefined_table),
                // not e.g. a connection error.
                let code = err
                    .as_database_error()
                    .and_then(|db| db.code())
                    .map(|c| c.into_owned());
                assert_eq!(
                    code.as_deref(),
                    Some("42P01"),
                    "expected undefined_table (42P01), got: {}",
                    err
                );
            }
        }
    }
}