use ares::db::PostgresClient;
use ares::types::Result;
use std::sync::Once;
/// Guards the one-time `.env` load performed by `ensure_env_loaded`.
static LOAD_ENV: Once = Once::new();
/// Claimed (via `set`) by the first `create_test_db` call; whoever claims it runs the one-time
/// table cleanup and migrations. Holds no data — only "has this been claimed" matters.
static INIT_SCHEMA: std::sync::OnceLock<()> = std::sync::OnceLock::new();
/// Loads variables from a `.env` file into the process environment, at most once per process.
///
/// Any error returned by `dotenvy::dotenv` (for example, no `.env` file present) is discarded,
/// so variables already set in the real environment keep working on their own.
fn ensure_env_loaded() {
    // Named loader instead of an inline closure; `call_once` accepts any `FnOnce()`.
    fn load_dotenv_file() {
        drop(dotenvy::dotenv());
    }

    LOAD_ENV.call_once(load_dotenv_file);
}
/// Resolves the connection URL for the integration-test database.
///
/// Resolution order:
/// 1. `TEST_DATABASE_URL`, used verbatim.
/// 2. `DATABASE_URL`. If its database name starts with `ares` and does not already contain
///    `ares_test`, the name is switched to the test database (`.../ares` -> `.../ares_test`).
///    Only the database-name path segment is rewritten. A user name, host, or query parameter
///    that happens to contain `ares` is left intact. Any other `DATABASE_URL` is returned
///    unchanged.
///    Caution: [`create_test_db`] truncates tables in whichever database this returns.
/// 3. The local default `postgres://dirmacs@localhost:5432/ares_test`.
///
/// A `.env` file is loaded (once) before the environment is consulted.
pub fn test_db_url() -> String {
    ensure_env_loaded();
    if let Ok(url) = std::env::var("TEST_DATABASE_URL") {
        return url;
    }
    if let Ok(url) = std::env::var("DATABASE_URL") {
        return rewrite_database_to_test(&url).unwrap_or(url);
    }
    "postgres://dirmacs@localhost:5432/ares_test".to_string()
}

/// Returns `url` with its database name switched from `ares…` to `ares_test…`.
///
/// Returns `None` in three cases:
/// - the URL has no database path segment;
/// - the name does not start with `ares`;
/// - the name already refers to an `ares_test` database.
fn rewrite_database_to_test(url: &str) -> Option<String> {
    // Skip past `scheme://` so an `ares` user name in the authority can never match.
    let authority_start = url.find("://").map_or(0, |i| i + 3);
    let path_start = authority_start + url[authority_start..].find('/')? + 1;
    // The database name ends where the query string (e.g. `?sslmode=disable`) begins, if any.
    let path_end = url[path_start..]
        .find('?')
        .map_or(url.len(), |i| path_start + i);
    let db_name = &url[path_start..path_end];
    if db_name.contains("ares_test") {
        return None;
    }
    let rest = db_name.strip_prefix("ares")?;
    Some(format!(
        "{}ares_test{}{}",
        &url[..path_start],
        rest,
        &url[path_end..]
    ))
}
/// Connects to the database returned by [`test_db_url`]. On the first call in this process, it
/// also truncates the known application tables and then runs the pending migrations.
///
/// # Panics
///
/// Panics if the connection cannot be established, or if migrations fail during the one-time
/// initialization.
pub async fn create_test_db() -> PostgresClient {
    let url = test_db_url();
    let db = PostgresClient::new_remote(url, String::new())
        .await
        .expect("Failed to connect to ares_test. Ensure it exists and migrations are applied.");
    // `OnceLock::set` succeeds for exactly one caller per process, so only that caller runs
    // the cleanup + migration step below.
    // NOTE(review): the flag is claimed *before* the async work completes.
    // - Concurrent callers (e.g. tests run in parallel by the harness) skip this block at once.
    //   They may then use the database while its tables are still being truncated or migrated.
    // - If migrations panic, later callers will not retry them.
    // An async once-cell would close this gap; confirm whether the suite relies on serial
    // execution.
    // NOTE(review): tables are truncated before migrations run. On a database whose tables do
    // not exist yet, the truncates presumably fail, and `cleanup_tables` only logs warnings.
    if INIT_SCHEMA.set(()).is_ok() {
        cleanup_tables(&db).await;
        ensure_schema(&db).await.expect("Failed to run migrations on ares_test");
    }
    db
}
/// Best-effort wipe of the application tables.
///
/// Each table is truncated by its own `TRUNCATE TABLE … CASCADE` statement, so one failure
/// (such as a table that does not exist) does not stop the remaining truncations. Failures are
/// reported on stderr and otherwise ignored.
async fn cleanup_tables(db: &PostgresClient) {
    const TABLES: [&str; 7] = [
        "messages",
        "conversations",
        "sessions",
        "memory_facts",
        "preferences",
        "user_agents",
        "users",
    ];

    for name in TABLES {
        let statement = format!("TRUNCATE TABLE {} CASCADE", name);
        let failure = sqlx::query(&statement).execute(&db.pool).await.err();
        if let Some(err) = failure {
            eprintln!("Warning: failed to truncate {}: {}", name, err);
        }
    }
}
/// Applies any pending migrations from the crate's `./migrations` directory, which are
/// embedded at compile time by `sqlx::migrate!`.
///
/// # Errors
///
/// Returns `AppError::Database` with a `"Migration failed: …"` message if the migrator reports
/// an error.
pub async fn ensure_schema(db: &PostgresClient) -> Result<()> {
    let migrator = sqlx::migrate!("./migrations");
    let outcome = migrator.run(&db.pool).await;
    outcome.map_err(|migrate_err| {
        ares::types::AppError::Database(format!("Migration failed: {}", migrate_err))
    })?;
    Ok(())
}