1use axum::extract::FromRequestParts;
29use diesel_async::AsyncPgConnection;
30use diesel_async::pooled_connection::AsyncDieselConnectionManager;
31use diesel_async::pooled_connection::deadpool::Pool;
32
33use crate::AppState;
34use crate::config::DatabaseConfig;
35use crate::error::AutumnError;
36
/// Error returned by [`create_pool`] when the connection pool cannot be built.
pub type PoolError = diesel_async::pooled_connection::deadpool::BuildError;
42
43pub fn create_pool(config: &DatabaseConfig) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
53 let Some(url) = &config.url else {
54 return Ok(None);
55 };
56
57 let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(url);
58 let pool = Pool::builder(manager).max_size(config.pool_size).build()?;
59
60 Ok(Some(pool))
61}
62
/// A single connection checked out of the deadpool-managed `AsyncPgConnection` pool.
type PooledConnection = diesel_async::pooled_connection::deadpool::Object<AsyncPgConnection>;
67
/// Request extractor that gives a handler a connection checked out of `AppState::pool`.
///
/// Extraction (see the `FromRequestParts` impl) is rejected with a service-unavailable
/// `AutumnError` when no pool is configured or a connection cannot be acquired. The wrapper
/// implements `Deref`/`DerefMut` to `AsyncPgConnection`, so it can be reborrowed as
/// `&mut AsyncPgConnection` (e.g. `&mut *db`) for queries.
pub struct Db(PooledConnection);
95
96impl std::ops::Deref for Db {
97 type Target = AsyncPgConnection;
98 fn deref(&self) -> &Self::Target {
99 &self.0
100 }
101}
102
103impl std::ops::DerefMut for Db {
104 fn deref_mut(&mut self) -> &mut Self::Target {
105 &mut self.0
106 }
107}
108
109impl FromRequestParts<AppState> for Db {
110 type Rejection = AutumnError;
111
112 async fn from_request_parts(
113 _parts: &mut axum::http::request::Parts,
114 state: &AppState,
115 ) -> Result<Self, Self::Rejection> {
116 let pool = state
117 .pool
118 .as_ref()
119 .ok_or_else(|| AutumnError::service_unavailable_msg("Database not configured"))?;
120
121 let conn = pool.get().await.map_err(|e| {
122 tracing::error!("Failed to acquire database connection: {e}");
123 AutumnError::service_unavailable_msg(e.to_string())
124 })?;
125
126 Ok(Self(conn))
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DatabaseConfig;

    /// URL shared by the tests that need a configured (but not necessarily reachable) database.
    const TEST_DATABASE_URL: &str = "postgres://localhost/test";

    /// A default config has no URL, so no pool is produced.
    #[test]
    fn create_pool_with_no_url_returns_none() {
        let built = create_pool(&DatabaseConfig::default()).expect("should not fail with no URL");
        assert!(built.is_none());
    }

    /// Supplying a URL yields a pool.
    #[test]
    fn create_pool_with_url_returns_some() {
        let cfg = DatabaseConfig {
            url: Some(TEST_DATABASE_URL.into()),
            ..Default::default()
        };
        let built = create_pool(&cfg).expect("should build pool from valid config");
        assert!(built.is_some());
    }

    /// `pool_size` from the config becomes the pool's maximum size.
    #[test]
    fn pool_respects_max_size() {
        let cfg = DatabaseConfig {
            url: Some(TEST_DATABASE_URL.into()),
            pool_size: 5,
            ..Default::default()
        };
        let Some(pool) = create_pool(&cfg).expect("should build pool") else {
            panic!("should be Some");
        };
        assert_eq!(pool.status().max_size, 5);
    }

    /// A handler requiring `Db` is rejected with 503 when the app state has no pool.
    #[tokio::test]
    async fn db_extractor_rejects_when_no_pool() {
        use axum::Router;
        use axum::body::Body;
        use axum::http::{Request, StatusCode};
        use axum::routing::get;
        use tower::ServiceExt;

        // Merely asking for `Db` is enough to run the extractor.
        async fn needs_db(_db: Db) -> &'static str {
            "ok"
        }

        let state = AppState { pool: None };
        let router = Router::new().route("/", get(needs_db)).with_state(state);

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = router.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }
}