1use serde::Serialize;
4use sqlx::{postgres::PgPoolOptions, PgPool, Row};
5use std::time::Duration;
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum PoolError {
10 #[error("Failed to create connection pool: {0}")]
11 ConnectionFailed(#[from] sqlx::Error),
12
13 #[error("Connection failed after {attempts} retry attempts: {last_error}")]
14 RetryExhausted { attempts: u32, last_error: String },
15}
16
/// Exponential-backoff settings used when establishing the connection pool.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of connection attempts before giving up.
    pub max_attempts: u32,
    /// Delay applied after the first failed attempt, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound for any single delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor by which the delay grows after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Five attempts, starting at 100 ms and doubling up to a 10 s cap.
    fn default() -> Self {
        Self {
            max_attempts: 5,
            initial_delay_ms: 100,
            max_delay_ms: 10000,
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns the delay to wait after the zero-based `attempt` has failed.
    ///
    /// The delay is `initial_delay_ms * backoff_multiplier^attempt`, capped at
    /// `max_delay_ms`.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // Clamp before converting: a plain `attempt as i32` wraps to a negative
        // exponent for attempts above `i32::MAX`, which would collapse the delay
        // towards zero instead of holding it at the cap.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let growth = self.backoff_multiplier.powi(exponent);
        let capped_ms = (self.initial_delay_ms as f64 * growth).min(self.max_delay_ms as f64);
        // Float-to-integer `as` saturates, so a negative product becomes 0 ms.
        Duration::from_millis(capped_ms as u64)
    }
}
50
51pub async fn create_pool(database_url: &str) -> Result<PgPool, PoolError> {
53 create_pool_with_retry(database_url, &RetryConfig::default()).await
54}
55
56pub async fn create_pool_with_retry(
58 database_url: &str,
59 retry_config: &RetryConfig,
60) -> Result<PgPool, PoolError> {
61 let mut last_error = String::new();
62
63 for attempt in 0..retry_config.max_attempts {
64 match try_create_pool(database_url).await {
65 Ok(pool) => {
66 if attempt > 0 {
67 tracing::info!(
68 attempt = attempt + 1,
69 "Database connection pool created after retry"
70 );
71 } else {
72 tracing::info!("Database connection pool created successfully");
73 }
74 return Ok(pool);
75 }
76 Err(e) => {
77 last_error = e.to_string();
78 let remaining = retry_config.max_attempts - attempt - 1;
79
80 if remaining > 0 {
81 let delay = retry_config.delay_for_attempt(attempt);
82 tracing::warn!(
83 attempt = attempt + 1,
84 remaining_attempts = remaining,
85 delay_ms = delay.as_millis(),
86 error = %e,
87 "Database connection failed, retrying..."
88 );
89 tokio::time::sleep(delay).await;
90 } else {
91 tracing::error!(
92 attempts = retry_config.max_attempts,
93 error = %e,
94 "Database connection failed, no retries remaining"
95 );
96 }
97 }
98 }
99 }
100
101 Err(PoolError::RetryExhausted {
102 attempts: retry_config.max_attempts,
103 last_error,
104 })
105}
106
107async fn try_create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
109 PgPoolOptions::new()
110 .max_connections(20)
111 .min_connections(5)
112 .acquire_timeout(Duration::from_secs(5))
113 .idle_timeout(Duration::from_secs(600))
114 .connect(database_url)
115 .await
116}
117
118pub async fn create_pool_with_options(
120 database_url: &str,
121 max_connections: u32,
122 min_connections: u32,
123 acquire_timeout_secs: u64,
124) -> Result<PgPool, PoolError> {
125 let pool = PgPoolOptions::new()
126 .max_connections(max_connections)
127 .min_connections(min_connections)
128 .acquire_timeout(Duration::from_secs(acquire_timeout_secs))
129 .idle_timeout(Duration::from_secs(600))
130 .connect(database_url)
131 .await?;
132
133 tracing::info!(
134 max_connections = max_connections,
135 min_connections = min_connections,
136 "Database connection pool created with custom settings"
137 );
138 Ok(pool)
139}
140
141#[derive(Debug, Serialize)]
143pub struct HealthCheck {
144 pub status: HealthStatus,
145 pub database_connected: bool,
146 pub pool_size: u32,
147 pub pool_idle: u32,
148 pub latency_ms: Option<u64>,
149 pub version: Option<String>,
150}
151
152#[derive(Debug, Serialize, PartialEq, Eq)]
153pub enum HealthStatus {
154 Healthy,
155 Degraded,
156 Unhealthy,
157}
158
159pub async fn health_check(pool: &PgPool) -> HealthCheck {
161 let pool_size = pool.size();
162 let pool_idle = pool.num_idle() as u32;
163
164 let start = std::time::Instant::now();
165 let query_result = sqlx::query("SELECT version()").fetch_optional(pool).await;
166 let latency = start.elapsed().as_millis() as u64;
167
168 match query_result {
169 Ok(Some(row)) => {
170 let version: String = row.get(0);
171 let status = if latency > 1000 {
172 HealthStatus::Degraded
173 } else {
174 HealthStatus::Healthy
175 };
176
177 HealthCheck {
178 status,
179 database_connected: true,
180 pool_size,
181 pool_idle,
182 latency_ms: Some(latency),
183 version: Some(version),
184 }
185 }
186 Ok(None) => HealthCheck {
187 status: HealthStatus::Degraded,
188 database_connected: true,
189 pool_size,
190 pool_idle,
191 latency_ms: Some(latency),
192 version: None,
193 },
194 Err(e) => {
195 tracing::error!(error = %e, "Database health check failed");
196 HealthCheck {
197 status: HealthStatus::Unhealthy,
198 database_connected: false,
199 pool_size,
200 pool_idle,
201 latency_ms: None,
202 version: None,
203 }
204 }
205 }
206}
207
208#[derive(Debug, Serialize)]
210pub struct PoolStats {
211 pub size: u32,
212 pub idle: u32,
213 pub in_use: u32,
214}
215
216pub fn pool_stats(pool: &PgPool) -> PoolStats {
218 let size = pool.size();
219 let idle = pool.num_idle() as u32;
220 PoolStats {
221 size,
222 idle,
223 in_use: size.saturating_sub(idle),
224 }
225}
226
/// How many connections a pool warm-up should open up front.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WarmupStrategy {
    /// Skip warm-up entirely.
    None,
    /// Warm up the pool's minimum-connection baseline.
    MinConnections,
    /// Warm up half of the pool.
    HalfCapacity,
    /// Warm up the whole pool.
    FullCapacity,
}
239
240pub async fn warmup_pool(pool: &PgPool, strategy: WarmupStrategy) -> Result<(), PoolError> {
245 let target = match strategy {
246 WarmupStrategy::None => return Ok(()),
247 WarmupStrategy::MinConnections => 5, WarmupStrategy::HalfCapacity => pool.size() / 2,
249 WarmupStrategy::FullCapacity => pool.size(),
250 };
251
252 tracing::info!(
253 strategy = ?strategy,
254 target = target,
255 "Warming up connection pool"
256 );
257
258 let mut connections = Vec::new();
260 for i in 0..target {
261 match pool.acquire().await {
262 Ok(conn) => {
263 connections.push(conn);
264 tracing::debug!(acquired = i + 1, target = target, "Pool warm-up progress");
265 }
266 Err(e) => {
267 tracing::error!(
268 error = %e,
269 acquired = i,
270 target = target,
271 "Failed to warm up pool"
272 );
273 return Err(PoolError::ConnectionFailed(e));
274 }
275 }
276 }
277
278 drop(connections);
280
281 tracing::info!(
282 warmed_connections = target,
283 "Connection pool warm-up completed"
284 );
285
286 Ok(())
287}
288
289pub async fn validate_pool_connections(pool: &PgPool) -> Result<u32, PoolError> {
293 let pool_size = pool.size();
294 let mut valid_count = 0;
295
296 tracing::info!(pool_size = pool_size, "Validating pool connections");
297
298 for i in 0..pool_size {
299 match sqlx::query("SELECT 1").execute(pool).await {
300 Ok(_) => {
301 valid_count += 1;
302 }
303 Err(e) => {
304 tracing::warn!(
305 connection = i,
306 error = %e,
307 "Connection validation failed"
308 );
309 }
310 }
311 }
312
313 tracing::info!(
314 valid_count = valid_count,
315 total = pool_size,
316 "Connection validation completed"
317 );
318
319 Ok(valid_count)
320}
321
322pub async fn refresh_pool_connections(pool: &PgPool, interval_secs: u64) -> Result<(), PoolError> {
327 let interval = Duration::from_secs(interval_secs);
328
329 loop {
330 tokio::time::sleep(interval).await;
331
332 tracing::debug!("Refreshing pool connections");
333
334 match pool.acquire().await {
336 Ok(mut conn) => {
337 if let Err(e) = sqlx::query("SELECT 1").execute(&mut *conn).await {
339 tracing::warn!(error = %e, "Connection refresh validation failed");
340 }
341 drop(conn);
342 }
343 Err(e) => {
344 tracing::warn!(error = %e, "Failed to acquire connection for refresh");
345 }
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the connection string required by the database-backed tests.
    fn database_url() -> String {
        std::env::var("DATABASE_URL").expect("DATABASE_URL must be set")
    }

    #[tokio::test]
    #[ignore = "requires database"]
    async fn test_create_pool() {
        let result = create_pool(&database_url()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    #[ignore = "requires database"]
    async fn test_health_check() {
        let pool = create_pool(&database_url()).await.unwrap();
        let report = health_check(&pool).await;
        assert_eq!(report.status, HealthStatus::Healthy);
        assert!(report.database_connected);
    }

    #[test]
    fn test_warmup_strategy() {
        assert_eq!(WarmupStrategy::None, WarmupStrategy::None);
        assert_ne!(WarmupStrategy::None, WarmupStrategy::MinConnections);
    }

    #[test]
    fn test_retry_config_delay() {
        let config = RetryConfig::default();

        // (attempt, expected delay in milliseconds) for the default doubling schedule.
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(config.delay_for_attempt(attempt).as_millis(), expected_ms);
        }
    }

    #[test]
    fn test_retry_config_max_delay() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_delay_ms: 1000,
            max_delay_ms: 5000,
            backoff_multiplier: 10.0,
        };

        // 1000 ms * 10^5 far exceeds the cap, so the cap is what comes back.
        let capped = config.delay_for_attempt(5);
        assert_eq!(capped.as_millis(), 5000);
    }
}