use sqlx::postgres::{PgPool, PgPoolOptions};
use thiserror::Error;
use tracing::{error, info, warn};
use url::Url;
/// Errors produced by the SQLx pool helpers in this file.
#[derive(Debug, Error)]
pub enum SqlxError {
    /// An error reported by sqlx itself (connecting, querying, beginning,
    /// committing or rolling back a transaction).
    #[error("Database error: {0}")]
    Database(#[from] sqlx::Error),
    /// Applying migrations via a `sqlx::migrate::Migrator` failed.
    #[error("Migration error: {0}")]
    Migration(#[from] sqlx::migrate::MigrateError),
    /// The database URL could not be parsed.
    #[error("Database URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),
    /// The database URL's path does not name a database.
    #[error("Database name missing in URL")]
    DatabaseNameMissing,
    /// Missing or invalid configuration (e.g. `DATABASE_URL` not set).
    #[error("Configuration error: {0}")]
    Config(String),
    /// Any other error; `Display` and `source` are forwarded to the inner error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
/// Result alias whose error type is [`SqlxError`].
pub type SqlxResult<T> = std::result::Result<T, SqlxError>;
/// A Postgres transaction; `'conn` is the lifetime of the connection it runs on.
pub type SqlxTx<'conn> = sqlx::Transaction<'conn, sqlx::Postgres>;
async fn ensure_database_exists(url: &str) -> SqlxResult<()> {
let parsed = Url::parse(url)?;
let db_name = parsed
.path_segments()
.and_then(|segs| segs.filter(|s| !s.is_empty()).last())
.map(str::to_string)
.filter(|s| !s.trim().is_empty())
.ok_or(SqlxError::DatabaseNameMissing)?;
let mut system_url = parsed.clone();
system_url.set_path("/postgres");
let system_pool = PgPoolOptions::new()
.max_connections(1)
.connect(system_url.as_str())
.await?;
let exists: bool =
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)")
.bind(&db_name)
.fetch_one(&system_pool)
.await?;
if exists {
info!("Database '{}' already exists", db_name);
} else {
let sanitized = db_name.replace('"', "\"\"");
sqlx::query(&format!("CREATE DATABASE \"{}\"", sanitized))
.execute(&system_pool)
.await?;
info!("Database '{}' created", db_name);
}
system_pool.close().await;
Ok(())
}
/// Wrapper around a Postgres [`PgPool`].
///
/// Adds database bootstrapping on construction, a commit/rollback transaction
/// helper, migration running and a health check. It also derefs to the inner
/// `PgPool`.
#[derive(Clone, Debug)]
pub struct SqlxPool {
    // The wrapped sqlx pool; exposed read-only via `pool()` and `Deref`.
    pool: PgPool,
}
impl SqlxPool {
    /// Pool size used by [`SqlxPool::from_env`] when `DATABASE_POOL_SIZE` is
    /// unset or is not a positive integer.
    const DEFAULT_POOL_SIZE: u32 = 10;

    /// Creates the target database if needed, then opens a pool to `url`.
    ///
    /// Each new connection runs `SET TIME ZONE 'UTC'` right after it is
    /// established, so session time-zone conversions happen in UTC.
    ///
    /// # Errors
    ///
    /// - [`SqlxError::Config`] if `max_connections` is 0, because a pool with
    ///   no connection slots cannot serve any query.
    /// - Errors from `ensure_database_exists`: URL parsing, a missing database
    ///   name, or failures talking to the `postgres` maintenance database.
    /// - [`SqlxError::Database`] if the pool cannot be established.
    pub async fn new(url: &str, max_connections: u32) -> SqlxResult<Self> {
        if max_connections == 0 {
            return Err(SqlxError::Config(
                "max_connections must be at least 1".into(),
            ));
        }
        ensure_database_exists(url).await?;
        let pool = PgPoolOptions::new()
            .max_connections(max_connections)
            .after_connect(|conn, _meta| {
                Box::pin(async move {
                    sqlx::query("SET TIME ZONE 'UTC'")
                        .execute(&mut *conn)
                        .await
                        .map(|_| ())
                })
            })
            .connect(url)
            .await?;
        Ok(Self { pool })
    }

    /// Builds a pool from environment variables.
    ///
    /// - `DATABASE_URL` (required) is the connection URL passed to [`SqlxPool::new`].
    /// - `DATABASE_POOL_SIZE` (optional) is the maximum pool size. If it is
    ///   unset, the default of 10 is used. If it is set but is not a positive
    ///   integer, a warning is logged and the default of 10 is used.
    ///
    /// # Errors
    ///
    /// [`SqlxError::Config`] if `DATABASE_URL` is unset or not valid Unicode,
    /// plus anything [`SqlxPool::new`] returns.
    pub async fn from_env() -> SqlxResult<Self> {
        let url = std::env::var("DATABASE_URL").map_err(|e| match e {
            std::env::VarError::NotPresent => {
                SqlxError::Config("DATABASE_URL env var not set".into())
            }
            std::env::VarError::NotUnicode(_) => {
                SqlxError::Config("DATABASE_URL env var is not valid unicode".into())
            }
        })?;
        let max_connections = match std::env::var("DATABASE_POOL_SIZE").ok() {
            None => Self::DEFAULT_POOL_SIZE,
            Some(raw) => match raw.trim().parse::<u32>() {
                Ok(n) if n > 0 => n,
                // Previously an invalid value was silently ignored; surface it.
                _ => {
                    warn!(
                        "Invalid DATABASE_POOL_SIZE '{}'; falling back to {}",
                        raw,
                        Self::DEFAULT_POOL_SIZE
                    );
                    Self::DEFAULT_POOL_SIZE
                }
            },
        };
        Self::new(&url, max_connections).await
    }

    /// Borrows the underlying [`PgPool`].
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Begins a new transaction on a connection taken from the pool.
    ///
    /// The transaction owns its pooled connection, so it is not tied to the
    /// borrow of `self`. This matches the [`SqlxTx<'static>`](SqlxTx) that
    /// [`SqlxPool::with_transaction`] hands to its closure.
    ///
    /// # Errors
    ///
    /// [`SqlxError::Database`] if no connection can be acquired or `BEGIN` fails.
    pub async fn begin(&self) -> SqlxResult<SqlxTx<'static>> {
        self.pool.begin().await.map_err(SqlxError::from)
    }

    /// Runs `f` inside a new transaction, committing on `Ok` and rolling back on `Err`.
    ///
    /// `f` receives a mutable borrow of the transaction and returns a boxed
    /// future that may hold that borrow. The `sqlx_with_tx!` macro builds such
    /// closures. If `f` fails, a rollback is attempted. A rollback failure is
    /// only logged, and `f`'s original error is returned.
    ///
    /// NOTE(review): the boxed future has no `Send` bound, so the future
    /// returned by this method is not `Send` either (e.g. it cannot be given
    /// to `tokio::spawn`). Adding `+ Send` would break existing callers.
    ///
    /// # Errors
    ///
    /// Failure to begin or commit ([`SqlxError::Database`]), or the error
    /// returned by `f`.
    pub async fn with_transaction<F, T>(&self, f: F) -> SqlxResult<T>
    where
        F: for<'c> FnOnce(
            &'c mut SqlxTx<'static>,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = SqlxResult<T>> + 'c>,
        >,
    {
        let mut tx = self.pool.begin().await?;
        match f(&mut tx).await {
            Ok(val) => {
                tx.commit().await?;
                Ok(val)
            }
            Err(e) => {
                // Best-effort rollback: the caller cares about `e`, not the rollback.
                if let Err(rb_err) = tx.rollback().await {
                    warn!("SQLx transaction rollback failed: {}", rb_err);
                }
                Err(e)
            }
        }
    }

    /// Applies all pending migrations from `migrator` to this pool.
    ///
    /// # Errors
    ///
    /// [`SqlxError::Migration`] if any migration fails.
    pub async fn run_migrations(
        &self,
        migrator: &sqlx::migrate::Migrator,
    ) -> SqlxResult<()> {
        migrator.run(&self.pool).await.map_err(SqlxError::from)
    }

    /// Number of connections the pool currently holds, as reported by `PgPool::size`.
    pub fn size(&self) -> u32 {
        self.pool.size()
    }

    /// Number of idle connections, as reported by `PgPool::num_idle`.
    ///
    /// Saturates at `u32::MAX` instead of silently truncating the `usize` count.
    pub fn idle(&self) -> u32 {
        u32::try_from(self.pool.num_idle()).unwrap_or(u32::MAX)
    }

    /// Verifies connectivity by running `SELECT 1` on the pool.
    ///
    /// Logs the outcome at `error` level on failure and `info` level on success.
    ///
    /// # Errors
    ///
    /// [`SqlxError::Database`] if the query fails.
    pub async fn health_check(&self) -> SqlxResult<()> {
        sqlx::query("SELECT 1")
            .execute(&self.pool)
            .await
            .map_err(|e| {
                error!("SQLx health check failed: {}", e);
                SqlxError::from(e)
            })?;
        info!("SQLx health check passed: db connection test successful");
        Ok(())
    }
}
/// Exposes the wrapped [`PgPool`], so `PgPool` methods can be called directly
/// on a `SqlxPool` and `&*pool` yields a `&PgPool`.
impl std::ops::Deref for SqlxPool {
    type Target = PgPool;

    fn deref(&self) -> &PgPool {
        let Self { pool } = self;
        pool
    }
}
/// Runs a block inside a transaction via a `with_transaction` method.
///
/// `sqlx_with_tx!(pool, |tx| { ... })` expands to
/// `pool.with_transaction(|tx| Box::pin(async move { ... }))`. The block is
/// moved into an `async move` future and must evaluate to the output type the
/// callee expects. For `SqlxPool::with_transaction` that is `SqlxResult<T>`.
/// A trailing comma after the block is accepted.
///
/// ```ignore
/// let id: i64 = sqlx_with_tx!(pool, |tx| {
///     let row: (i64,) = sqlx::query_as("INSERT INTO t DEFAULT VALUES RETURNING id")
///         .fetch_one(&mut **tx)
///         .await?;
///     Ok(row.0)
/// })
/// .await?;
/// ```
#[macro_export]
macro_rules! sqlx_with_tx {
    ($pool:expr, |$tx:ident| $body:block $(,)?) => {
        // Fully qualified so the expansion still works where `Box` is shadowed.
        $pool.with_transaction(|$tx| ::std::boxed::Box::pin(async move { $body }))
    };
}