use sqlx::postgres::PgPool;
use crate::persistence::sql::connection_string::ConnectionString;
/// PostgreSQL database handle wrapping an sqlx [`PgPool`].
///
/// Regular handles are created with [`SqlDb::connect`]. When compiled with
/// `test` or the `testing` feature, handles created for a temporary test
/// database additionally carry a shared cleanup guard (`TestDbDropper`). That
/// guard registers the database for dropping once the last clone of the handle
/// goes away.
#[derive(Clone)]
pub struct SqlDb {
    /// Connection pool used for all queries; exposed through [`SqlDb::pool`].
    pool: PgPool,
    /// Cleanup guard for temporary test databases; `None` for regular
    /// connections. Wrapped in `Arc` so every clone of the handle shares it.
    #[cfg(any(test, feature = "testing"))]
    db_dropper: Option<std::sync::Arc<TestDbDropper>>,
}
/// Opaque `Debug` output: only the type name is printed and the pool's fields
/// are deliberately omitted (presumably to keep connection details out of logs).
impl std::fmt::Debug for SqlDb {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report the real type name instead of the stale "DbConnection", and
        // render `SqlDb { .. }` to signal that fields are intentionally hidden.
        f.debug_struct("SqlDb").finish_non_exhaustive()
    }
}
impl SqlDb {
    /// Opens a connection pool for `con_string`.
    ///
    /// With `test` or the `testing` feature enabled, a connection string that
    /// reports `is_test_db()` is routed to `SqlDb::test_postgres_db` instead.
    /// That helper provisions a fresh, uniquely named database on the same server.
    ///
    /// # Errors
    /// Returns the underlying [`sqlx::Error`] if connecting (or provisioning the
    /// test database) fails.
    pub async fn connect(con_string: &ConnectionString) -> Result<Self, sqlx::Error> {
        #[cfg(any(test, feature = "testing"))]
        {
            if con_string.is_test_db() {
                let admin = Some(con_string.clone());
                return Self::test_postgres_db(admin).await;
            }
        }
        Self::connect_inner(con_string).await
    }

    /// Connects a pool and wraps it in a plain `SqlDb` that has no
    /// test-database cleanup attached.
    async fn connect_inner(con_string: &ConnectionString) -> Result<Self, sqlx::Error> {
        PgPool::connect(con_string.as_str())
            .await
            .map(|connected_pool| Self {
                pool: connected_pool,
                #[cfg(any(test, feature = "testing"))]
                db_dropper: None,
            })
    }

    /// Borrows the underlying connection pool for running queries.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
}
/// Drop guard for a temporary test database.
///
/// When dropped, it passes the database name and connection string to
/// `pubky_test_utils::register_db_to_drop`. That helper is presumably what
/// actually drops the database later (TODO confirm in `pubky_test_utils`).
#[cfg(any(test, feature = "testing"))]
struct TestDbDropper {
    /// Name of the temporary database to clean up.
    db_name: String,
    /// Connection string used to reach the server for cleanup. The caller in
    /// this file passes the admin connection string, not the test database's.
    connection_string: String,
}
#[cfg(any(test, feature = "testing"))]
impl TestDbDropper {
    /// Creates a cleanup guard for the database `db_name`, reachable through
    /// `connection_string`.
    pub fn new(db_name: String, connection_string: String) -> Self {
        TestDbDropper {
            connection_string,
            db_name,
        }
    }
}
#[cfg(any(test, feature = "testing"))]
impl Drop for TestDbDropper {
    /// Hands the database name and connection string to
    /// `pubky_test_utils::register_db_to_drop`. Whatever that call returns is
    /// intentionally discarded, because `drop` has no way to report it.
    fn drop(&mut self) {
        // Nothing reads these fields after `drop`, so move them out instead of
        // cloning them.
        let db_name = std::mem::take(&mut self.db_name);
        let connection_string = std::mem::take(&mut self.connection_string);
        let _ = pubky_test_utils::register_db_to_drop(db_name, connection_string);
    }
}
/// Admin connection string used by the test helpers as a last resort. It applies
/// when neither an explicit connection string nor a valid
/// `TEST_PUBKY_CONNECTION_STRING` environment variable is provided.
#[cfg(any(test, feature = "testing"))]
const DEFAULT_TEST_CONNECTION_STRING: &str = "postgres://localhost:5432/postgres";
#[cfg(any(test, feature = "testing"))]
impl SqlDb {
    /// Creates a new, uniquely named database (`pubky_test_<uuid>`) on the
    /// server reached by `admin_con_string`. Returns a connection string that
    /// points at the new database.
    ///
    /// The temporary admin pool is closed before returning, whether or not
    /// `CREATE DATABASE` succeeded.
    ///
    /// # Errors
    /// Returns the [`sqlx::Error`] from connecting to the admin database or
    /// from executing `CREATE DATABASE`.
    async fn create_test_database(
        admin_con_string: ConnectionString,
    ) -> Result<ConnectionString, sqlx::Error> {
        use uuid::Uuid;

        let admin_con = Self::connect_inner(&admin_con_string).await?;
        // `as_simple()` renders the UUID without hyphens, keeping the name a
        // plain identifier. Identifiers cannot be bound as query parameters,
        // hence the `format!`; the name is generated here, never user input.
        let test_db_name = format!("pubky_test_{}", Uuid::new_v4().as_simple());
        let query = format!("CREATE DATABASE {}", test_db_name);
        let created = sqlx::query(&query).execute(admin_con.pool()).await;
        // Shut the admin pool down gracefully on both the success and the error
        // path instead of relying on it being dropped.
        admin_con.pool().close().await;
        created?;

        // `admin_con_string` is owned and no longer borrowed, so reuse it.
        let mut test_db_con_string = admin_con_string;
        test_db_con_string.set_database_name(&test_db_name);
        Ok(test_db_con_string)
    }

    /// Creates a fresh temporary database and connects to it.
    ///
    /// The admin connection string is resolved by
    /// [`SqlDb::derive_connection_string`]. The returned handle carries a
    /// cleanup guard that registers the database for dropping once the last
    /// clone of the handle is dropped.
    ///
    /// # Errors
    /// Returns the [`sqlx::Error`] from creating the database or connecting to it.
    pub async fn test_postgres_db(
        admin_con_string: Option<ConnectionString>,
    ) -> Result<Self, sqlx::Error> {
        let admin_con_string = Self::derive_connection_string(admin_con_string);
        let test_db_con_string = Self::create_test_database(admin_con_string.clone()).await?;
        // Build the guard *before* connecting. If the connection below fails,
        // the guard is dropped on the early return, so the freshly created
        // database is still registered for cleanup instead of leaking.
        let dropper = std::sync::Arc::new(TestDbDropper::new(
            test_db_con_string.database_name().to_string(),
            admin_con_string.to_string(),
        ));
        let mut con = Self::connect_inner(&test_db_con_string).await?;
        con.db_dropper = Some(dropper);
        Ok(con)
    }

    /// Resolves the admin connection string used to create test databases.
    ///
    /// Precedence:
    /// 1. the explicit `admin_con_string` argument;
    /// 2. the `TEST_PUBKY_CONNECTION_STRING` environment variable, if set and
    ///    valid (an invalid value is logged and ignored);
    /// 3. `DEFAULT_TEST_CONNECTION_STRING`.
    pub fn derive_connection_string(
        admin_con_string: Option<ConnectionString>,
    ) -> ConnectionString {
        if let Some(con_string) = admin_con_string {
            // Already owned; hand it back without cloning.
            return con_string;
        }
        if let Ok(raw_con_string) = std::env::var("TEST_PUBKY_CONNECTION_STRING") {
            match ConnectionString::new(&raw_con_string) {
                Ok(con_string) => return con_string,
                Err(e) => {
                    tracing::warn!("Invalid database connection string in TEST_PUBKY_CONNECTION_STRING environment variable: {}. Fallback to default test connection string. Error: {e}", raw_con_string);
                }
            }
        }
        ConnectionString::new(DEFAULT_TEST_CONNECTION_STRING)
            .expect("Default test connection string is valid")
    }

    /// Creates a temporary test database without running migrations.
    ///
    /// # Panics
    /// Panics if the test database cannot be created or connected to.
    pub async fn test_without_migrations() -> Self {
        Self::test_postgres_db(None)
            .await
            .expect("Failed to create test database")
    }

    /// Creates a temporary test database and runs all migrations on it.
    ///
    /// # Panics
    /// Panics if the database cannot be created or a migration fails. When the
    /// handle is dropped during that panic, its guard still registers the
    /// database for cleanup.
    pub async fn test() -> Self {
        use crate::persistence::sql::migrator::Migrator;
        let db = Self::test_without_migrations().await;
        let migrator = Migrator::new(&db);
        migrator.run().await.expect("Failed to run migrations");
        db
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: a temporary test database can be created and connected to.
    // Needs a reachable PostgreSQL server; the admin connection string comes
    // from `SqlDb::derive_connection_string` (argument, env var, or default).
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_pg_db_available() {
        let _db = SqlDb::test_postgres_db(None).await.unwrap();
    }
}