use axum::extract::FromRequestParts;
use diesel_async::AsyncPgConnection;
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
use diesel_async::pooled_connection::deadpool::Pool;
use std::time::Duration;
use tracing::Instrument as _;
use crate::config::DatabaseConfig;
use crate::error::AutumnError;
/// Application state that can hand out Postgres connection pools.
///
/// Only [`DbState::pool`] is required. A replica pool is optional, and
/// [`DbState::read_pool`] picks the most appropriate pool for read-only work.
pub trait DbState {
    /// The primary pool, or `None` when no database is configured.
    fn pool(&self) -> Option<&Pool<AsyncPgConnection>>;

    /// An optional read-replica pool. The default implementation has none.
    fn replica_pool(&self) -> Option<&Pool<AsyncPgConnection>> {
        None
    }

    /// Pool for read-only queries.
    ///
    /// Returns the replica when one is available; otherwise falls back to the
    /// primary. The primary is only consulted when there is no replica.
    fn read_pool(&self) -> Option<&Pool<AsyncPgConnection>> {
        match self.replica_pool() {
            Some(replica) => Some(replica),
            None => self.pool(),
        }
    }
}
/// Error produced when building a connection pool: deadpool's `BuildError`,
/// as re-exported by diesel-async.
pub type PoolError = diesel_async::pooled_connection::deadpool::BuildError;
/// A primary Postgres pool, optionally paired with a read-replica pool.
///
/// Reads go to the replica when one exists and to the primary otherwise; see
/// [`DatabaseTopology::read`].
#[derive(Clone)]
pub struct DatabaseTopology {
    /// Read-write pool; always present.
    primary: Pool<AsyncPgConnection>,
    /// Optional read-replica pool.
    replica: Option<Pool<AsyncPgConnection>>,
}

impl DatabaseTopology {
    /// Assembles a topology from pools that have already been built.
    #[must_use]
    pub const fn from_pools(
        primary: Pool<AsyncPgConnection>,
        replica: Option<Pool<AsyncPgConnection>>,
    ) -> Self {
        Self { primary, replica }
    }

    /// A topology with no replica, so every read is served by `primary`.
    #[must_use]
    pub const fn primary_only(primary: Pool<AsyncPgConnection>) -> Self {
        Self::from_pools(primary, None)
    }

    /// The read-write primary pool.
    #[must_use]
    pub const fn primary(&self) -> &Pool<AsyncPgConnection> {
        &self.primary
    }

    /// The replica pool, if one was supplied.
    #[must_use]
    pub const fn replica(&self) -> Option<&Pool<AsyncPgConnection>> {
        self.replica.as_ref()
    }

    /// Pool for read-only queries: the replica when present, else the primary.
    #[must_use]
    pub fn read(&self) -> &Pool<AsyncPgConnection> {
        match &self.replica {
            Some(replica) => replica,
            None => &self.primary,
        }
    }
}
/// Builds a deadpool pool of async Postgres connections for `url`.
///
/// * `url` - connection URL handed to diesel-async's connection manager.
/// * `pool_size` - maximum number of pooled connections; `0` is clamped to `1`.
/// * `connect_timeout_secs` - limit, in seconds, applied to each of these:
///   waiting for a free slot, creating a new connection, and recycling
///   (health-checking) an idle connection on checkout. `0` disables these
///   timeouts.
///
/// # Errors
///
/// Returns [`PoolError`] if deadpool rejects the pool configuration.
fn build_pool(
    url: &str,
    pool_size: usize,
    connect_timeout_secs: u64,
) -> Result<Pool<AsyncPgConnection>, PoolError> {
    // A zero-length timeout would fail every connection attempt on its first
    // poll and make `get()` non-blocking, leaving the pool unusable. Treat `0`
    // as "no timeout" instead, which is libpq's convention for
    // `connect_timeout`.
    let timeout = match connect_timeout_secs {
        0 => None,
        secs => Some(Duration::from_secs(secs)),
    };
    let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(url);
    Pool::builder(manager)
        .max_size(pool_size.max(1))
        .wait_timeout(timeout)
        .create_timeout(timeout)
        // Bound the health check run on idle connections during `get()`.
        // Without this limit, a half-open socket could block checkout
        // indefinitely, even though the wait and create phases are bounded.
        .recycle_timeout(timeout)
        .runtime(deadpool::Runtime::Tokio1)
        .build()
}
/// Builds the primary connection pool described by `config`.
///
/// Returns `Ok(None)` when `config` has no effective primary URL, meaning no
/// database is configured. Otherwise the pool is sized by
/// `effective_primary_pool_size()` and uses `connect_timeout_secs`.
///
/// # Errors
///
/// Returns [`PoolError`] if the pool cannot be built.
pub fn create_pool(config: &DatabaseConfig) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
    config
        .effective_primary_url()
        .map(|primary_url| {
            build_pool(
                primary_url,
                config.effective_primary_pool_size(),
                config.connect_timeout_secs,
            )
        })
        .transpose()
}
pub fn create_topology(config: &DatabaseConfig) -> Result<Option<DatabaseTopology>, PoolError> {
let Some(primary_url) = config.effective_primary_url() else {
return Ok(None);
};
let primary = build_pool(
primary_url,
config.effective_primary_pool_size(),
config.connect_timeout_secs,
)?;
let replica = config
.replica_url
.as_deref()
.map(|url| {
build_pool(
url,
config.effective_replica_pool_size(),
config.connect_timeout_secs,
)
})
.transpose()?;
Ok(Some(DatabaseTopology { primary, replica }))
}
/// A connection checked out of the deadpool-managed async Postgres pool
/// (deadpool's `Object` wrapper around an `AsyncPgConnection`).
type PooledConnection = diesel_async::pooled_connection::deadpool::Object<AsyncPgConnection>;
/// Drop guard that keeps a transaction-depth counter balanced and flags
/// operations that were abandoned before finishing.
///
/// Dropping the guard always decrements `*depth`. If `disarmed` is still
/// `false` at that point, the drop also sets `*poisoned` to `true`. An
/// existing `true` value is never cleared.
struct TxDepthGuard<'a> {
    /// Counter decremented when the guard is dropped.
    depth: &'a mut usize,
    /// Flag raised when the guard is dropped while still armed.
    poisoned: &'a mut bool,
    /// Set to `true` once the guarded operation has completed normally.
    disarmed: bool,
}

impl Drop for TxDepthGuard<'_> {
    fn drop(&mut self) {
        *self.depth -= 1;
        // An armed guard means the guarded future was dropped or unwound
        // mid-operation; OR-ing keeps any earlier poisoning intact.
        *self.poisoned |= !self.disarmed;
    }
}
/// A Postgres connection checked out from the application state's primary
/// pool, usable as an axum extractor.
///
/// Derefs to [`AsyncPgConnection`] so diesel queries can run on it directly,
/// and offers [`Db::tx`] for transactional work.
pub struct Db {
    /// The checked-out deadpool object wrapping the connection.
    // NOTE(review): field order also fixes drop order (conn, then span);
    // keep it if the span is meant to outlive the connection release.
    conn: PooledConnection,
    /// The `db.connection` tracing span under which the connection was acquired.
    span: tracing::Span,
    /// Number of `Db::tx` calls currently in progress on this connection.
    tx_depth: usize,
    /// Set when a `Db::tx` future was dropped (or unwound) before finishing.
    /// Once set, `tx` returns an error and `Deref`/`DerefMut` panic.
    tx_poisoned: bool,
}
impl Db {
    /// The `db.connection` tracing span created when this connection was
    /// acquired.
    #[must_use]
    pub const fn span(&self) -> &tracing::Span {
        &self.span
    }

    /// Runs `f` inside a database transaction on this connection.
    ///
    /// `f` receives the pooled connection and returns a boxed scoped future.
    /// Execution is delegated to diesel-async's `AsyncConnection::transaction`.
    /// Per diesel's documented contract, that commits when the future yields
    /// `Ok` and rolls back when it yields `Err`.
    ///
    /// If the future returned by `tx` is dropped before it completes, or if it
    /// unwinds, this `Db` is marked poisoned. After that, further `tx` calls
    /// fail and `Deref`/`DerefMut` panic.
    ///
    /// # Errors
    ///
    /// - A service-unavailable error if an earlier `tx` on this `Db` was
    ///   abandoned before finishing.
    /// - A bad-request error if a transaction is already in progress on this
    ///   `Db`.
    /// - Otherwise, any error `E` from the closure or from diesel, converted
    ///   into [`AutumnError`](crate::error::AutumnError).
    pub async fn tx<'a, T, E, F>(&'a mut self, f: F) -> Result<T, crate::error::AutumnError>
    where
        T: Send + 'a,
        E: From<diesel::result::Error> + Send + Sync + 'a,
        crate::error::AutumnError: From<E>,
        F: for<'r> FnOnce(
                &'r mut PooledConnection,
            ) -> scoped_futures::ScopedBoxFuture<'a, 'r, Result<T, E>>
            + Send
            + 'a,
    {
        use diesel_async::AsyncConnection as _;
        // Refuse to reuse a connection whose previous transaction was
        // abandoned mid-flight: its transaction state is unknown.
        if self.tx_poisoned {
            return Err(crate::error::AutumnError::service_unavailable_msg(
                "Database connection is in an invalid transaction state",
            ));
        }
        // NOTE(review): `tx` holds `&mut self` for its whole duration, and `f`
        // only sees the inner connection, not `Db`. This branch therefore
        // appears unreachable through the public API; it is a defensive check.
        if self.tx_depth > 0 {
            return Err(crate::error::AutumnError::bad_request_msg(
                "Nested Db::tx calls are not supported",
            ));
        }
        // Mark a transaction as in progress. The guard decrements the depth on
        // every exit path. Unless disarmed, it also poisons the connection,
        // which happens if this future is dropped or unwinds during the
        // `.await` below.
        self.tx_depth += 1;
        let mut guard = TxDepthGuard {
            depth: &mut self.tx_depth,
            poisoned: &mut self.tx_poisoned,
            disarmed: false,
        };
        let result = self
            .conn
            .transaction::<T, E, _>(f)
            .await
            .map_err(Into::into);
        // The transaction future ran to completion (Ok or Err): do not poison.
        guard.disarmed = true;
        result
    }
}
/// Lets diesel queries run directly on a [`Db`] by exposing the underlying
/// [`AsyncPgConnection`].
///
/// # Panics
///
/// Panics if an earlier [`Db::tx`] on this `Db` was cancelled or dropped
/// before completing, because the connection is then poisoned.
impl std::ops::Deref for Db {
    type Target = AsyncPgConnection;

    fn deref(&self) -> &Self::Target {
        if self.tx_poisoned {
            panic!("Db connection is poisoned due to a cancelled/dropped transaction");
        }
        &self.conn
    }
}
/// Mutable access to the underlying [`AsyncPgConnection`], which diesel-async
/// query execution requires.
///
/// # Panics
///
/// Panics if an earlier [`Db::tx`] on this `Db` was cancelled or dropped
/// before completing, because the connection is then poisoned.
impl std::ops::DerefMut for Db {
    fn deref_mut(&mut self) -> &mut Self::Target {
        if self.tx_poisoned {
            panic!("Db connection is poisoned due to a cancelled/dropped transaction");
        }
        &mut self.conn
    }
}
/// Axum extractor that checks out a connection from the state's primary pool.
///
/// Rejects with a service-unavailable [`AutumnError`] in two cases: the state
/// has no pool configured, or a connection cannot be acquired. In the second
/// case the failure is also logged. Acquisition runs inside a `db.connection`
/// tracing span, which is kept on the returned [`Db`].
impl<S> FromRequestParts<S> for Db
where
    S: DbState + Send + Sync,
{
    type Rejection = AutumnError;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Some(pool) = state.pool() else {
            return Err(AutumnError::service_unavailable_msg("Database not configured"));
        };

        let span = tracing::info_span!(
            "db.connection",
            otel.kind = "client",
            db.system = "postgresql",
        );

        // Run the checkout inside the span so the wait and any failure log are
        // attributed to it.
        let acquire = async {
            match pool.get().await {
                Ok(checked_out) => Ok(checked_out),
                Err(e) => {
                    tracing::error!("Failed to acquire database connection: {e}");
                    Err(AutumnError::service_unavailable_msg(e.to_string()))
                }
            }
        };
        let conn = acquire.instrument(span.clone()).await?;

        Ok(Self {
            conn,
            span,
            tx_depth: 0,
            tx_poisoned: false,
        })
    }
}
/// Pluggable factory for the database pools the framework uses.
///
/// Implementors must supply [`DatabasePoolProvider::create_pool`] (the primary
/// pool). [`DatabasePoolProvider::create_topology`] has a default
/// implementation built on top of it.
pub trait DatabasePoolProvider: Send + Sync + 'static {
    /// Builds the primary pool for `config`, or `None` when no database should
    /// be used.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError`] when the pool cannot be built.
    fn create_pool(
        &self,
        config: &DatabaseConfig,
    ) -> impl std::future::Future<Output = Result<Option<Pool<AsyncPgConnection>>, PoolError>> + Send;

    /// Builds the primary pool via [`create_pool`](Self::create_pool), plus a
    /// replica pool when `config.replica_url` is set.
    ///
    /// Returns `Ok(None)` when `create_pool` yields `None`. The replica URL is
    /// not consulted in that case.
    ///
    /// NOTE(review): the replica pool is always built by the module's built-in
    /// deadpool builder, not through `self`. A custom provider must override
    /// this method to control how replicas are constructed.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError`] if either pool fails to build.
    fn create_topology(
        &self,
        config: &DatabaseConfig,
    ) -> impl std::future::Future<Output = Result<Option<DatabaseTopology>, PoolError>> + Send {
        async move {
            let primary = match self.create_pool(config).await? {
                Some(pool) => pool,
                None => return Ok(None),
            };

            let replica = match config.replica_url.as_deref() {
                None => None,
                Some(replica_url) => Some(build_pool(
                    replica_url,
                    config.effective_replica_pool_size(),
                    config.connect_timeout_secs,
                )?),
            };

            Ok(Some(DatabaseTopology::from_pools(primary, replica)))
        }
    }
}
/// The stock [`DatabasePoolProvider`]: builds diesel-async pools backed by
/// deadpool from a [`DatabaseConfig`].
#[derive(Debug, Default, Clone, Copy)]
pub struct DieselDeadpoolPoolProvider;

impl DieselDeadpoolPoolProvider {
    /// Creates the provider. It is stateless, so this is equivalent to
    /// `Default::default()` and usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        DieselDeadpoolPoolProvider
    }
}
impl DatabasePoolProvider for DieselDeadpoolPoolProvider {
async fn create_pool(
&self,
config: &DatabaseConfig,
) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
create_pool(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DatabaseConfig;
    use std::time::Duration;

    /// Provider that never builds a pool. Used to check that custom providers
    /// can replace the default pool construction.
    struct NoOpPoolProvider;

    impl DatabasePoolProvider for NoOpPoolProvider {
        async fn create_pool(
            &self,
            _config: &DatabaseConfig,
        ) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
            Ok(None)
        }
    }

    // Renamed from `pool_provider_trait_returns_supplied_pool`: the test
    // asserts that a custom provider can return *no* pool, overriding the
    // default construction.
    #[tokio::test]
    async fn pool_provider_trait_can_override_default_behaviour() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/ignored".to_owned()),
            ..Default::default()
        };
        let provider = NoOpPoolProvider;
        let pool = provider
            .create_pool(&config)
            .await
            .expect("no-op provider should succeed");
        assert!(
            pool.is_none(),
            "no-op provider must override default behaviour"
        );
    }

    #[tokio::test]
    async fn default_pool_provider_matches_free_function() {
        let config = DatabaseConfig::default();
        let via_provider = DieselDeadpoolPoolProvider::new()
            .create_pool(&config)
            .await
            .expect("default provider should succeed");
        let via_function = create_pool(&config).expect("free fn should succeed");
        assert_eq!(via_provider.is_none(), via_function.is_none());
    }

    #[tokio::test]
    async fn default_pool_provider_respects_url_config() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            ..Default::default()
        };
        let provider = DieselDeadpoolPoolProvider::new();
        let pool = provider
            .create_pool(&config)
            .await
            .expect("default provider should succeed");
        assert!(
            pool.is_some(),
            "default provider should return Some when url is provided"
        );
    }

    #[tokio::test]
    async fn default_provider_topology_builds_primary_and_replica() {
        let config = DatabaseConfig {
            primary_url: Some("postgres://localhost/primary".into()),
            replica_url: Some("postgres://localhost/replica".into()),
            primary_pool_size: Some(6),
            replica_pool_size: Some(2),
            ..Default::default()
        };
        let topology = DieselDeadpoolPoolProvider::new()
            .create_topology(&config)
            .await
            .expect("topology should build")
            .expect("topology should be configured");
        assert_eq!(topology.primary().status().max_size, 6);
        assert_eq!(
            topology.replica().expect("replica pool").status().max_size,
            2
        );
    }

    #[tokio::test]
    async fn provider_topology_is_none_without_primary_pool() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/ignored".to_owned()),
            replica_url: Some("postgres://localhost/replica".into()),
            ..Default::default()
        };
        let provider = NoOpPoolProvider;
        let topology = provider
            .create_topology(&config)
            .await
            .expect("no-op provider should succeed");
        assert!(
            topology.is_none(),
            "no primary pool means no topology, even with a replica URL"
        );
    }

    #[test]
    fn create_pool_with_no_url_returns_none() {
        let config = DatabaseConfig::default();
        let pool = create_pool(&config).expect("should not fail with no URL");
        assert!(pool.is_none());
    }

    #[test]
    fn create_pool_with_url_returns_some() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            ..Default::default()
        };
        let pool = create_pool(&config).expect("should build pool from valid config");
        assert!(pool.is_some());
    }

    #[test]
    fn pool_respects_max_size() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            pool_size: 5,
            ..Default::default()
        };
        let pool = create_pool(&config)
            .expect("should build pool")
            .expect("should be Some");
        assert_eq!(pool.status().max_size, 5);
    }

    #[test]
    fn pool_clamps_size_to_one_if_zero() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            pool_size: 0,
            ..Default::default()
        };
        let pool = create_pool(&config)
            .expect("should build pool")
            .expect("should be Some");
        assert_eq!(
            pool.status().max_size,
            1,
            "Pool size should be clamped to 1"
        );
    }

    #[test]
    fn database_topology_builds_primary_and_replica_pools() {
        let config = DatabaseConfig {
            primary_url: Some("postgres://localhost/primary".into()),
            replica_url: Some("postgres://localhost/replica".into()),
            primary_pool_size: Some(6),
            replica_pool_size: Some(2),
            ..Default::default()
        };
        let topology = create_topology(&config)
            .expect("topology should build")
            .expect("topology should be configured");
        assert_eq!(topology.primary().status().max_size, 6);
        assert_eq!(
            topology.replica().expect("replica pool").status().max_size,
            2
        );
        assert_eq!(topology.read().status().max_size, 2);
    }

    #[test]
    fn database_topology_single_url_builds_only_primary_pool() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/single".into()),
            pool_size: 5,
            ..Default::default()
        };
        let topology = create_topology(&config)
            .expect("topology should build")
            .expect("topology should be configured");
        assert_eq!(topology.primary().status().max_size, 5);
        assert!(topology.replica().is_none());
        assert_eq!(topology.read().status().max_size, 5);
    }

    #[test]
    fn database_topology_primary_only_reads_from_primary() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/primary-only".into()),
            pool_size: 4,
            ..Default::default()
        };
        let primary = create_pool(&config)
            .expect("should build pool")
            .expect("should be Some");
        let topology = DatabaseTopology::primary_only(primary);
        assert!(topology.replica().is_none());
        assert_eq!(topology.primary().status().max_size, 4);
        assert_eq!(topology.read().status().max_size, 4);
    }

    #[test]
    fn config_runtime_drift_pool_applies_connect_timeout_to_wait_and_create() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            connect_timeout_secs: 7,
            ..Default::default()
        };
        let pool = create_pool(&config)
            .expect("should build pool")
            .expect("should be Some");
        let timeouts = pool.timeouts();
        assert_eq!(timeouts.wait, Some(Duration::from_secs(7)));
        assert_eq!(timeouts.create, Some(Duration::from_secs(7)));
    }

    #[test]
    fn tx_depth_guard_poisons_when_dropped_armed() {
        let mut depth = 1_usize;
        let mut poisoned = false;
        drop(TxDepthGuard {
            depth: &mut depth,
            poisoned: &mut poisoned,
            disarmed: false,
        });
        assert_eq!(depth, 0);
        assert!(poisoned, "an armed guard must poison the connection");
    }

    #[test]
    fn tx_depth_guard_disarmed_only_decrements() {
        let mut depth = 1_usize;
        let mut poisoned = false;
        drop(TxDepthGuard {
            depth: &mut depth,
            poisoned: &mut poisoned,
            disarmed: true,
        });
        assert_eq!(depth, 0);
        assert!(!poisoned, "a disarmed guard must not poison the connection");
    }

    #[derive(Clone)]
    struct TestDbState;

    impl DbState for TestDbState {
        fn pool(&self) -> Option<&Pool<AsyncPgConnection>> {
            None
        }
    }

    #[derive(Clone)]
    struct TestReadState {
        primary: Pool<AsyncPgConnection>,
    }

    impl DbState for TestReadState {
        fn pool(&self) -> Option<&Pool<AsyncPgConnection>> {
            Some(&self.primary)
        }
    }

    /// State exposing both a primary and a replica pool.
    struct TestReplicaState {
        primary: Pool<AsyncPgConnection>,
        replica: Pool<AsyncPgConnection>,
    }

    impl DbState for TestReplicaState {
        fn pool(&self) -> Option<&Pool<AsyncPgConnection>> {
            Some(&self.primary)
        }

        fn replica_pool(&self) -> Option<&Pool<AsyncPgConnection>> {
            Some(&self.replica)
        }
    }

    #[test]
    fn database_topology_read_pool_falls_back_to_primary() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/read-fallback".into()),
            pool_size: 3,
            ..Default::default()
        };
        let primary = create_pool(&config).unwrap().unwrap();
        let state = TestReadState { primary };
        assert_eq!(state.read_pool().expect("read pool").status().max_size, 3);
    }

    #[test]
    fn db_state_read_pool_prefers_replica() {
        let primary_config = DatabaseConfig {
            url: Some("postgres://localhost/primary".into()),
            pool_size: 3,
            ..Default::default()
        };
        let replica_config = DatabaseConfig {
            url: Some("postgres://localhost/replica".into()),
            pool_size: 2,
            ..Default::default()
        };
        let state = TestReplicaState {
            primary: create_pool(&primary_config).unwrap().unwrap(),
            replica: create_pool(&replica_config).unwrap().unwrap(),
        };
        assert_eq!(state.pool().expect("primary pool").status().max_size, 3);
        assert_eq!(state.read_pool().expect("read pool").status().max_size, 2);
    }

    #[tokio::test]
    async fn db_extractor_rejects_when_no_pool() {
        use axum::Router;
        use axum::body::Body;
        use axum::http::{Request, StatusCode};
        use axum::routing::get;
        use tower::ServiceExt;

        async fn handler(_db: Db) -> &'static str {
            "ok"
        }

        let app = Router::new()
            .route("/", get(handler))
            .with_state(TestDbState);
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }
}