use deadpool_diesel::postgres::{Manager, Pool, Runtime};
use diesel::connection::Connection as DieselConnection;
use diesel::{PgConnection, QueryableByName, RunQueryDsl, sql_query, sql_types::Text};
use thiserror::Error;
use tracing::{error, info};
use url::Url;
/// Errors produced while setting up, connecting to, or querying the database.
#[derive(Error, Debug)]
pub enum DatabaseError {
    /// A connection could not be obtained from the pool.
    #[error("Database connection error: {0}")]
    ConnectionError(#[from] deadpool_diesel::PoolError),

    /// Diesel reported an error while running a query.
    #[error("Database query error: {0}")]
    QueryError(#[from] diesel::result::Error),

    /// The database URL could not be parsed.
    #[error("Database URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// The database URL does not name a database in its path.
    #[error("Database name missing in URL")]
    DatabaseNameMissing,

    /// A pooled connection's `interact` call failed
    /// (see `deadpool_diesel::InteractError` for the possible causes).
    #[error("Database interaction error: {0}")]
    InteractionError(#[from] deadpool_diesel::InteractError),

    /// A setup step failed; the message says which one.
    #[error("Database initialization error: {0}")]
    InitializationError(String),

    /// Wraps an arbitrary `anyhow::Error`, forwarding its message and source.
    #[error(transparent)]
    UserError(#[from] anyhow::Error),
}

/// Shorthand for results whose error type is [`DatabaseError`].
pub type DatabaseResult<T> = Result<T, DatabaseError>;
/// Row type for raw `SELECT datname FROM pg_database ...` queries.
#[derive(QueryableByName)]
pub struct DbRow {
    /// The `datname` column, loaded as SQL `Text`.
    #[diesel(sql_type = Text)]
    pub datname: String,
}
async fn ensure_database_exists(database_url: &str) -> DatabaseResult<()> {
let parsed = Url::parse(database_url)?;
let db_name = parsed
.path_segments()
.and_then(|segments| segments.filter(|s| !s.is_empty()).last())
.map(str::to_string)
.filter(|s| !s.trim().is_empty())
.ok_or(DatabaseError::DatabaseNameMissing)?;
let mut default_url = parsed.clone();
default_url.set_path("/postgres");
let default_url_string = default_url.to_string();
let sanitized_db_name = db_name.replace('"', "\"\"");
tokio::task::spawn_blocking(move || -> DatabaseResult<()> {
let mut conn = PgConnection::establish(&default_url_string).map_err(|e| {
DatabaseError::InitializationError(format!("Failed to connect to default db: {}", e))
})?;
let exists = !sql_query("SELECT datname FROM pg_database WHERE datname = $1")
.bind::<Text, _>(db_name.clone())
.load::<DbRow>(&mut conn)
.map_err(DatabaseError::QueryError)?
.is_empty();
if exists {
info!("Database '{}' already exists", db_name);
return Ok(());
}
let create_query = format!("CREATE DATABASE \"{}\"", sanitized_db_name);
sql_query(create_query)
.execute(&mut conn)
.map_err(DatabaseError::QueryError)?;
info!("Database '{}' created", db_name);
Ok(())
})
.await
.map_err(|e| {
DatabaseError::InitializationError(format!("Failed to ensure database exists: {}", e))
})?
}
/// Cloneable handle to a `deadpool_diesel` PostgreSQL connection pool.
#[derive(Clone)]
pub struct DieselPool {
    pool: Pool,
}

impl DieselPool {
    /// Builds a pool of at most `max_size` connections to `url`.
    ///
    /// The target database is created first if it does not exist. Every
    /// connection the pool opens runs `SET TIME ZONE 'UTC'` in a `post_create`
    /// hook. One connection is checked out before returning, so connectivity or
    /// hook failures surface here rather than on first use.
    ///
    /// # Errors
    ///
    /// * Any error from ensuring the database exists.
    /// * [`DatabaseError::InitializationError`] if `max_size` is 0 or the pool
    ///   cannot be built.
    /// * [`DatabaseError::ConnectionError`] if the initial checkout fails,
    ///   including when the time-zone hook fails.
    pub async fn new(url: impl Into<String>, max_size: usize) -> DatabaseResult<Self> {
        // With zero slots the eager `pool.get()` below would wait forever.
        if max_size == 0 {
            return Err(DatabaseError::InitializationError(
                "Failed to build pool: max_size must be at least 1".to_string(),
            ));
        }
        let url = url.into();
        ensure_database_exists(&url).await?;
        let manager = Manager::new(url, Runtime::Tokio1);
        let pool = Pool::builder(manager)
            .max_size(max_size)
            // `SET TIME ZONE` is session-scoped: it must run on every connection
            // the pool creates, not just on one connection checked out at startup.
            .post_create(deadpool::managed::Hook::async_fn(|conn, _metrics| {
                Box::pin(async move {
                    conn.interact(|conn| {
                        sql_query("SET TIME ZONE 'UTC'").execute(conn).map(|_| ())
                    })
                    .await
                    .map_err(|e| e.to_string())
                    .and_then(|result| result.map_err(|e| e.to_string()))
                    .map_err(|msg| {
                        deadpool::managed::HookError::Message(
                            format!("Failed to execute timezone query: {}", msg).into(),
                        )
                    })
                })
            }))
            .build()
            .map_err(|e| {
                DatabaseError::InitializationError(format!("Failed to build pool: {}", e))
            })?;
        // Eagerly open (and immediately return) one connection so an unreachable
        // server or failing hook is reported by `new`.
        pool.get().await.map_err(DatabaseError::ConnectionError)?;
        Ok(Self { pool })
    }

    /// Returns a reference to the underlying pool.
    pub fn pool(&self) -> &Pool {
        &self.pool
    }

    /// Checks a connection out of the pool.
    ///
    /// # Errors
    ///
    /// [`DatabaseError::ConnectionError`] if the pool cannot provide a connection.
    pub async fn connection(&self) -> DatabaseResult<deadpool::managed::Object<Manager>> {
        self.pool
            .get()
            .await
            .map_err(DatabaseError::ConnectionError)
    }

    /// Returns the pool's current status.
    pub fn status(&self) -> deadpool::Status {
        self.pool.status()
    }

    /// Runs `SELECT 1` on a pooled connection.
    ///
    /// Logs at `info` on success and at `error` if the query fails.
    ///
    /// # Errors
    ///
    /// Connection, interaction, or query errors from the check.
    pub async fn health_check(&self) -> DatabaseResult<()> {
        let conn = self.connection().await?;
        conn.interact(|conn| {
            sql_query("SELECT 1")
                .execute(conn)
                .map(|_| ())
                .map_err(DatabaseError::from)
        })
        .await?
        .map_err(|e| {
            error!("Diesel health check query failed: {}", e);
            e
        })?;
        info!("Diesel health check executed: db connection test successful");
        Ok(())
    }

    /// Runs `f` with a pooled connection through the connection's `interact`.
    ///
    /// # Errors
    ///
    /// * [`DatabaseError::ConnectionError`] if no connection can be checked out.
    /// * [`DatabaseError::InteractionError`] if `interact` itself fails.
    /// * Whatever `f` returns, converted into [`DatabaseError`].
    pub async fn interact<F, T, E>(&self, f: F) -> DatabaseResult<T>
    where
        F: FnOnce(&mut PgConnection) -> Result<T, E> + Send + 'static,
        T: Send + 'static,
        E: Send + 'static + Into<DatabaseError>,
    {
        let conn = self.connection().await?;
        conn.interact(f)
            .await
            .map_err(DatabaseError::InteractionError)?
            .map_err(Into::into)
    }

    /// Runs `f` inside `diesel::Connection::transaction` on a pooled connection.
    ///
    /// See Diesel's `Connection::transaction` for commit/rollback semantics.
    ///
    /// # Errors
    ///
    /// Same as [`DieselPool::interact`].
    pub async fn transaction<F, T>(&self, f: F) -> DatabaseResult<T>
    where
        F: FnOnce(&mut PgConnection) -> diesel::result::QueryResult<T> + Send + 'static,
        T: Send + 'static,
    {
        self.interact(|conn| DieselConnection::transaction(conn, f))
            .await
    }

    /// Alias for [`DieselPool::interact`].
    pub async fn run<F, T, E>(&self, f: F) -> DatabaseResult<T>
    where
        F: FnOnce(&mut PgConnection) -> Result<T, E> + Send + 'static,
        T: Send + 'static,
        E: Send + 'static + Into<DatabaseError>,
    {
        self.interact(f).await
    }
}