use axum::extract::FromRequestParts;
use diesel_async::AsyncPgConnection;
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
use diesel_async::pooled_connection::deadpool::Pool;
use std::time::Duration;
use tracing::Instrument as _;
use crate::config::DatabaseConfig;
use crate::error::AutumnError;
/// Application state that can expose a database connection pool.
///
/// The [`Db`] extractor requires this bound on the router state and calls
/// [`DbState::pool`] to find the pool it checks connections out of.
pub trait DbState {
/// Returns the connection pool, or `None` when the state holds no pool.
///
/// When this returns `None`, the [`Db`] extractor rejects the request with
/// a "Database not configured" error.
fn pool(&self) -> Option<&Pool<AsyncPgConnection>>;
}
/// Error returned when building the connection pool fails.
///
/// Alias for `diesel_async::pooled_connection::deadpool::BuildError`; this is
/// the error type of [`create_pool`] and of [`DatabasePoolProvider::create_pool`].
pub type PoolError = diesel_async::pooled_connection::deadpool::BuildError;
/// Builds the connection pool described by `config`.
///
/// Returns `Ok(None)` when `config.url` is unset. Otherwise a pool is built
/// with:
///
/// - a maximum size of `config.pool_size`, raised to at least 1;
/// - `config.connect_timeout_secs` applied as both the wait timeout and the
///   create timeout;
/// - the Tokio runtime.
///
/// # Errors
///
/// Returns [`PoolError`] when the pool builder fails.
pub fn create_pool(config: &DatabaseConfig) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
    config
        .url
        .as_ref()
        .map(|database_url| {
            // One duration feeds both timeouts, so they always agree.
            let connect_timeout = Some(Duration::from_secs(config.connect_timeout_secs));
            // A configured size of zero is bumped up to a single connection.
            let max_connections = config.pool_size.max(1);
            let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(database_url);

            Pool::builder(manager)
                .max_size(max_connections)
                .wait_timeout(connect_timeout)
                .create_timeout(connect_timeout)
                .runtime(deadpool::Runtime::Tokio1)
                .build()
        })
        // Option<Result<_, _>> -> Result<Option<_>, _>: no URL becomes `Ok(None)`.
        .transpose()
}
// Pooled object wrapping an `AsyncPgConnection`; this is what `Pool::get`
// yields in the `Db` extractor and what `Db` stores.
type PooledConnection = diesel_async::pooled_connection::deadpool::Object<AsyncPgConnection>;
/// Request extractor holding a PostgreSQL connection taken from the pool.
///
/// Dereferences to [`AsyncPgConnection`], and carries the `db.connection`
/// tracing span that was created when the connection was acquired.
pub struct Db {
// Connection checked out of the pool by `from_request_parts`.
conn: PooledConnection,
// `db.connection` span under which the connection was acquired.
span: tracing::Span,
}
impl Db {
/// Returns the `db.connection` tracing span that was created when this
/// connection was acquired from the pool.
#[must_use]
pub const fn span(&self) -> &tracing::Span {
&self.span
}
}
/// Lets a [`Db`] be used wherever a shared reference to the underlying
/// [`AsyncPgConnection`] is needed.
impl std::ops::Deref for Db {
type Target = AsyncPgConnection;
fn deref(&self) -> &Self::Target {
// `conn` is the pooled object, not the connection itself; returning it as
// `&AsyncPgConnection` relies on deref coercion through that object.
&self.conn
}
}
/// Lets a [`Db`] be used wherever an exclusive reference to the underlying
/// [`AsyncPgConnection`] is needed.
impl std::ops::DerefMut for Db {
fn deref_mut(&mut self) -> &mut Self::Target {
// Mutable counterpart of `deref`: coerces through the pooled object to the
// connection it wraps.
&mut self.conn
}
}
/// Extracts a [`Db`] by checking a connection out of the pool exposed by the
/// router state.
///
/// The request parts themselves are not inspected; only the state is used.
impl<S> FromRequestParts<S> for Db
where
    S: DbState + Send + Sync,
{
    type Rejection = AutumnError;

    /// Acquires a pooled connection inside a fresh `db.connection` span.
    ///
    /// # Errors
    ///
    /// Rejects with a service-unavailable [`AutumnError`] when
    ///
    /// - the state has no pool (message: "Database not configured"), or
    /// - the pool fails to hand out a connection (message: the pool error's
    ///   text, which is also logged at error level).
    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Some(pool) = state.pool() else {
            return Err(AutumnError::service_unavailable_msg(
                "Database not configured",
            ));
        };

        let span = tracing::info_span!(
            "db.connection",
            otel.kind = "client",
            db.system = "postgresql",
        );

        // The failure event is emitted inside the instrumented future so that
        // it is recorded within the `db.connection` span.
        let checkout = async {
            match pool.get().await {
                Ok(conn) => Ok(conn),
                Err(e) => {
                    tracing::error!("Failed to acquire database connection: {e}");
                    // NOTE(review): the pool error text becomes the rejection
                    // message; confirm it is not exposed verbatim to clients.
                    Err(AutumnError::service_unavailable_msg(e.to_string()))
                }
            }
        };
        let conn = checkout.instrument(span.clone()).await?;

        Ok(Self { conn, span })
    }
}
/// Pluggable strategy for constructing the database connection pool.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait DatabasePoolProvider: Send + Sync + 'static {
/// Builds a pool for `config`.
///
/// The returned future resolves to `Ok(Some(pool))`, to `Ok(None)` when the
/// provider supplies no pool, or to a [`PoolError`].
///
/// Declared as returning `impl Future + Send` so that the future produced
/// by every implementation is required to be `Send`.
fn create_pool(
&self,
config: &DatabaseConfig,
) -> impl std::future::Future<Output = Result<Option<Pool<AsyncPgConnection>>, PoolError>> + Send;
}
/// Stateless [`DatabasePoolProvider`] whose pools are built by the free
/// function [`create_pool`].
#[derive(Debug, Default, Clone, Copy)]
pub struct DieselDeadpoolPoolProvider;
impl DieselDeadpoolPoolProvider {
/// Creates the provider. The type is a unit struct, so there is nothing
/// to configure.
#[must_use]
pub const fn new() -> Self {
Self
}
}
/// Builds pools by delegating to the module-level [`create_pool`] function.
impl DatabasePoolProvider for DieselDeadpoolPoolProvider {
/// Returns whatever the free function [`create_pool`] returns for `config`.
///
/// Nothing is awaited in the body, so the future completes as soon as it
/// is polled.
async fn create_pool(
&self,
config: &DatabaseConfig,
) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
// An unqualified call resolves to the module-level function, not to this
// trait method (which would have to be invoked through `self`).
create_pool(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DatabaseConfig;
    use std::time::Duration;

    /// Provider that never supplies a pool, regardless of the configuration.
    struct NoOpPoolProvider;

    impl DatabasePoolProvider for NoOpPoolProvider {
        async fn create_pool(
            &self,
            _config: &DatabaseConfig,
        ) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
            Ok(None)
        }
    }

    /// Config with a URL set and every other field left at its default.
    ///
    /// The tests using it only inspect the built pool; none of them asks the
    /// pool for a connection.
    fn config_with_url() -> DatabaseConfig {
        DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            ..Default::default()
        }
    }

    /// Builds a pool via the free function, panicking unless one is returned.
    fn build_pool(config: &DatabaseConfig) -> Pool<AsyncPgConnection> {
        create_pool(config)
            .expect("should build pool")
            .expect("should be Some")
    }

    // A custom provider's answer is used as-is, even when a URL is configured.
    #[tokio::test]
    async fn pool_provider_trait_returns_supplied_pool() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/ignored".to_owned()),
            ..Default::default()
        };

        let result = NoOpPoolProvider.create_pool(&config).await;
        let pool = result.expect("no-op provider should succeed");

        assert!(
            pool.is_none(),
            "no-op provider must override default behaviour"
        );
    }

    // With a default config, the provider and the free function agree.
    #[tokio::test]
    async fn default_pool_provider_matches_free_function() {
        let config = DatabaseConfig::default();

        let provider = DieselDeadpoolPoolProvider::new();
        let via_provider = provider
            .create_pool(&config)
            .await
            .expect("default provider should succeed");
        let via_function = create_pool(&config).expect("free fn should succeed");

        assert_eq!(via_provider.is_none(), via_function.is_none());
    }

    #[tokio::test]
    async fn default_pool_provider_respects_url_config() {
        let config = config_with_url();

        let outcome = DieselDeadpoolPoolProvider::new()
            .create_pool(&config)
            .await;
        let pool = outcome.expect("default provider should succeed");

        assert!(
            pool.is_some(),
            "default provider should return Some when url is provided"
        );
    }

    #[test]
    fn create_pool_with_no_url_returns_none() {
        let built = create_pool(&DatabaseConfig::default());
        let pool = built.expect("should not fail with no URL");
        assert!(pool.is_none());
    }

    #[test]
    fn create_pool_with_url_returns_some() {
        let built = create_pool(&config_with_url());
        let pool = built.expect("should build pool from valid config");
        assert!(pool.is_some());
    }

    #[test]
    fn pool_respects_max_size() {
        let config = DatabaseConfig {
            pool_size: 5,
            ..config_with_url()
        };

        let status = build_pool(&config).status();

        assert_eq!(status.max_size, 5);
    }

    #[test]
    fn pool_clamps_size_to_one_if_zero() {
        let config = DatabaseConfig {
            pool_size: 0,
            ..config_with_url()
        };

        let status = build_pool(&config).status();

        assert_eq!(status.max_size, 1, "Pool size should be clamped to 1");
    }

    // The single configured timeout must reach both pool timeouts.
    #[test]
    fn config_runtime_drift_pool_applies_connect_timeout_to_wait_and_create() {
        let config = DatabaseConfig {
            connect_timeout_secs: 7,
            ..config_with_url()
        };
        let expected = Some(Duration::from_secs(7));

        let timeouts = build_pool(&config).timeouts();

        assert_eq!(timeouts.wait, expected);
        assert_eq!(timeouts.create, expected);
    }

    // A state without a pool makes the extractor reject with 503.
    #[tokio::test]
    async fn db_extractor_rejects_when_no_pool() {
        use crate::state::AppState;
        use axum::Router;
        use axum::body::Body;
        use axum::http::{Request, StatusCode};
        use axum::routing::get;
        use tower::ServiceExt;

        async fn handler(_db: Db) -> &'static str {
            "ok"
        }

        let state = AppState {
            extensions: std::sync::Arc::new(std::sync::RwLock::new(
                std::collections::HashMap::new(),
            )),
            pool: None,
            profile: None,
            started_at: std::time::Instant::now(),
            health_detailed: true,
            probes: crate::probe::ProbeState::ready_for_test(),
            metrics: crate::middleware::MetricsCollector::new(),
            log_levels: crate::actuator::LogLevels::new("info"),
            task_registry: crate::actuator::TaskRegistry::new(),
            config_props: crate::actuator::ConfigProperties::default(),
            #[cfg(feature = "ws")]
            channels: crate::channels::Channels::new(32),
            #[cfg(feature = "ws")]
            shutdown: tokio_util::sync::CancellationToken::new(),
        };
        let app = Router::new().route("/", get(handler)).with_state(state);
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }
}