kaccy_db/
pool.rs

1//! Database connection pool management
2
3use serde::Serialize;
4use sqlx::{postgres::PgPoolOptions, PgPool, Row};
5use std::time::Duration;
6use thiserror::Error;
7
/// Errors produced by the pool helpers in this module.
#[derive(Error, Debug)]
pub enum PoolError {
    /// A single sqlx operation failed; the original `sqlx::Error` is kept.
    ///
    /// `#[from]` lets `?` convert a `sqlx::Error` into this variant. Note that the
    /// message always says "create connection pool", even when the failure came from
    /// a later operation such as acquiring a connection during warm-up.
    #[error("Failed to create connection pool: {0}")]
    ConnectionFailed(#[from] sqlx::Error),

    /// Every attempt made by `create_pool_with_retry` failed.
    ///
    /// `attempts` is the attempt budget that was used up; `last_error` holds only the
    /// text of the final error (the error value itself is not preserved).
    #[error("Connection failed after {attempts} retry attempts: {last_error}")]
    RetryExhausted { attempts: u32, last_error: String },
}
16
/// Retry configuration for database connections.
///
/// Backoff is exponential: after the 0-based attempt `k` fails, the caller waits
/// `initial_delay_ms * backoff_multiplier^k` milliseconds, capped at `max_delay_ms`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of connection attempts, counting the first one
    pub max_attempts: u32,
    /// Initial delay between retries in milliseconds
    pub initial_delay_ms: u64,
    /// Maximum delay between retries in milliseconds
    pub max_delay_ms: u64,
    /// Multiplier for exponential backoff
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// 5 attempts, starting at 100 ms and doubling each time up to a 10 s cap.
    fn default() -> Self {
        Self {
            max_attempts: 5,
            initial_delay_ms: 100,
            max_delay_ms: 10000,
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Calculate the delay to wait after the 0-based `attempt` has failed.
    ///
    /// The result is always within `[0, max_delay_ms]` milliseconds. Attempt numbers
    /// above `i32::MAX` saturate instead of wrapping, so the delay never falls back
    /// towards zero for very large attempt counts.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `attempt as i32` would wrap values above i32::MAX into negative exponents,
        // which turns the backoff into a (near-)zero delay. Saturate instead.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let base_delay = self.initial_delay_ms as f64;
        let multiplier = self.backoff_multiplier.powi(exponent);
        // `f64::min` discards a NaN operand and `as u64` saturates (negatives become 0),
        // so overflowed or degenerate products still yield a valid, capped delay.
        let delay_ms = (base_delay * multiplier).min(self.max_delay_ms as f64) as u64;
        Duration::from_millis(delay_ms)
    }
}
50
51/// Create a PostgreSQL connection pool with optimized settings
52pub async fn create_pool(database_url: &str) -> Result<PgPool, PoolError> {
53    create_pool_with_retry(database_url, &RetryConfig::default()).await
54}
55
56/// Create a PostgreSQL connection pool with retry logic
57pub async fn create_pool_with_retry(
58    database_url: &str,
59    retry_config: &RetryConfig,
60) -> Result<PgPool, PoolError> {
61    let mut last_error = String::new();
62
63    for attempt in 0..retry_config.max_attempts {
64        match try_create_pool(database_url).await {
65            Ok(pool) => {
66                if attempt > 0 {
67                    tracing::info!(
68                        attempt = attempt + 1,
69                        "Database connection pool created after retry"
70                    );
71                } else {
72                    tracing::info!("Database connection pool created successfully");
73                }
74                return Ok(pool);
75            }
76            Err(e) => {
77                last_error = e.to_string();
78                let remaining = retry_config.max_attempts - attempt - 1;
79
80                if remaining > 0 {
81                    let delay = retry_config.delay_for_attempt(attempt);
82                    tracing::warn!(
83                        attempt = attempt + 1,
84                        remaining_attempts = remaining,
85                        delay_ms = delay.as_millis(),
86                        error = %e,
87                        "Database connection failed, retrying..."
88                    );
89                    tokio::time::sleep(delay).await;
90                } else {
91                    tracing::error!(
92                        attempts = retry_config.max_attempts,
93                        error = %e,
94                        "Database connection failed, no retries remaining"
95                    );
96                }
97            }
98        }
99    }
100
101    Err(PoolError::RetryExhausted {
102        attempts: retry_config.max_attempts,
103        last_error,
104    })
105}
106
107/// Internal function to attempt pool creation
108async fn try_create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
109    PgPoolOptions::new()
110        .max_connections(20)
111        .min_connections(5)
112        .acquire_timeout(Duration::from_secs(5))
113        .idle_timeout(Duration::from_secs(600))
114        .connect(database_url)
115        .await
116}
117
118/// Create a pool with custom settings
119pub async fn create_pool_with_options(
120    database_url: &str,
121    max_connections: u32,
122    min_connections: u32,
123    acquire_timeout_secs: u64,
124) -> Result<PgPool, PoolError> {
125    let pool = PgPoolOptions::new()
126        .max_connections(max_connections)
127        .min_connections(min_connections)
128        .acquire_timeout(Duration::from_secs(acquire_timeout_secs))
129        .idle_timeout(Duration::from_secs(600))
130        .connect(database_url)
131        .await?;
132
133    tracing::info!(
134        max_connections = max_connections,
135        min_connections = min_connections,
136        "Database connection pool created with custom settings"
137    );
138    Ok(pool)
139}
140
/// Health check result
///
/// Produced by `health_check`. It derives `Serialize`, so the field names below are
/// part of the serialized output.
#[derive(Debug, Serialize)]
pub struct HealthCheck {
    /// Overall classification of the check.
    pub status: HealthStatus,
    /// Whether the probe query completed without returning an error.
    pub database_connected: bool,
    /// Connection count reported by `PgPool::size()` before the probe query ran.
    pub pool_size: u32,
    /// Idle connection count reported by `PgPool::num_idle()` before the probe query ran.
    pub pool_idle: u32,
    /// Probe query round-trip time in milliseconds; `None` when the query failed.
    pub latency_ms: Option<u64>,
    /// Server version text returned by `SELECT version()`, when available.
    pub version: Option<String>,
}
151
/// Coarse health classification stored in `HealthCheck::status`.
#[derive(Debug, Serialize, PartialEq, Eq)]
pub enum HealthStatus {
    /// The probe query returned a version row within the latency threshold.
    Healthy,
    /// The database answered, but slowly or without a usable version row.
    Degraded,
    /// The probe query failed.
    Unhealthy,
}
158
159/// Perform a health check on the database connection
160pub async fn health_check(pool: &PgPool) -> HealthCheck {
161    let pool_size = pool.size();
162    let pool_idle = pool.num_idle() as u32;
163
164    let start = std::time::Instant::now();
165    let query_result = sqlx::query("SELECT version()").fetch_optional(pool).await;
166    let latency = start.elapsed().as_millis() as u64;
167
168    match query_result {
169        Ok(Some(row)) => {
170            let version: String = row.get(0);
171            let status = if latency > 1000 {
172                HealthStatus::Degraded
173            } else {
174                HealthStatus::Healthy
175            };
176
177            HealthCheck {
178                status,
179                database_connected: true,
180                pool_size,
181                pool_idle,
182                latency_ms: Some(latency),
183                version: Some(version),
184            }
185        }
186        Ok(None) => HealthCheck {
187            status: HealthStatus::Degraded,
188            database_connected: true,
189            pool_size,
190            pool_idle,
191            latency_ms: Some(latency),
192            version: None,
193        },
194        Err(e) => {
195            tracing::error!(error = %e, "Database health check failed");
196            HealthCheck {
197                status: HealthStatus::Unhealthy,
198                database_connected: false,
199                pool_size,
200                pool_idle,
201                latency_ms: None,
202                version: None,
203            }
204        }
205    }
206}
207
/// Pool statistics
///
/// Snapshot returned by `pool_stats`. It derives `Serialize`, so the field names are
/// part of the serialized output.
#[derive(Debug, Serialize)]
pub struct PoolStats {
    /// Connection count reported by `PgPool::size()`.
    pub size: u32,
    /// Idle connection count reported by `PgPool::num_idle()`.
    pub idle: u32,
    /// `size - idle`, saturating at zero.
    pub in_use: u32,
}
215
216/// Get current pool statistics
217pub fn pool_stats(pool: &PgPool) -> PoolStats {
218    let size = pool.size();
219    let idle = pool.num_idle() as u32;
220    PoolStats {
221        size,
222        idle,
223        in_use: size.saturating_sub(idle),
224    }
225}
226
/// Warm-up strategy for connection pool
///
/// Chooses how many connections `warmup_pool` should establish ahead of time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WarmupStrategy {
    /// No warm-up, connections created on-demand
    None,
    /// Pre-populate pool to minimum connections
    MinConnections,
    /// Pre-populate pool to half capacity
    HalfCapacity,
    /// Pre-populate pool to full capacity
    FullCapacity,
}
239
240/// Warm up a connection pool by pre-creating connections
241///
242/// This helps reduce initial latency by establishing connections
243/// before they're needed.
244pub async fn warmup_pool(pool: &PgPool, strategy: WarmupStrategy) -> Result<(), PoolError> {
245    let target = match strategy {
246        WarmupStrategy::None => return Ok(()),
247        WarmupStrategy::MinConnections => 5, // Default min connections
248        WarmupStrategy::HalfCapacity => pool.size() / 2,
249        WarmupStrategy::FullCapacity => pool.size(),
250    };
251
252    tracing::info!(
253        strategy = ?strategy,
254        target = target,
255        "Warming up connection pool"
256    );
257
258    // Acquire and immediately release connections to populate the pool
259    let mut connections = Vec::new();
260    for i in 0..target {
261        match pool.acquire().await {
262            Ok(conn) => {
263                connections.push(conn);
264                tracing::debug!(acquired = i + 1, target = target, "Pool warm-up progress");
265            }
266            Err(e) => {
267                tracing::error!(
268                    error = %e,
269                    acquired = i,
270                    target = target,
271                    "Failed to warm up pool"
272                );
273                return Err(PoolError::ConnectionFailed(e));
274            }
275        }
276    }
277
278    // Release all connections back to the pool
279    drop(connections);
280
281    tracing::info!(
282        warmed_connections = target,
283        "Connection pool warm-up completed"
284    );
285
286    Ok(())
287}
288
289/// Validate pool connections by running a simple query
290///
291/// This ensures all connections in the pool are healthy
292pub async fn validate_pool_connections(pool: &PgPool) -> Result<u32, PoolError> {
293    let pool_size = pool.size();
294    let mut valid_count = 0;
295
296    tracing::info!(pool_size = pool_size, "Validating pool connections");
297
298    for i in 0..pool_size {
299        match sqlx::query("SELECT 1").execute(pool).await {
300            Ok(_) => {
301                valid_count += 1;
302            }
303            Err(e) => {
304                tracing::warn!(
305                    connection = i,
306                    error = %e,
307                    "Connection validation failed"
308                );
309            }
310        }
311    }
312
313    tracing::info!(
314        valid_count = valid_count,
315        total = pool_size,
316        "Connection validation completed"
317    );
318
319    Ok(valid_count)
320}
321
322/// Periodically refresh pool connections
323///
324/// This helps maintain connection health by cycling through
325/// idle connections
326pub async fn refresh_pool_connections(pool: &PgPool, interval_secs: u64) -> Result<(), PoolError> {
327    let interval = Duration::from_secs(interval_secs);
328
329    loop {
330        tokio::time::sleep(interval).await;
331
332        tracing::debug!("Refreshing pool connections");
333
334        // Acquire and release a connection to ensure freshness
335        match pool.acquire().await {
336            Ok(mut conn) => {
337                // Simple validation query
338                if let Err(e) = sqlx::query("SELECT 1").execute(&mut *conn).await {
339                    tracing::warn!(error = %e, "Connection refresh validation failed");
340                }
341                drop(conn);
342            }
343            Err(e) => {
344                tracing::warn!(error = %e, "Failed to acquire connection for refresh");
345            }
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string for the database-backed tests (which are `#[ignore]`d by default).
    fn database_url() -> String {
        std::env::var("DATABASE_URL").expect("DATABASE_URL must be set")
    }

    #[tokio::test]
    #[ignore = "requires database"]
    async fn test_create_pool() {
        let outcome = create_pool(&database_url()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    #[ignore = "requires database"]
    async fn test_health_check() {
        let pool = create_pool(&database_url()).await.unwrap();
        let report = health_check(&pool).await;
        assert_eq!(report.status, HealthStatus::Healthy);
        assert!(report.database_connected);
    }

    #[test]
    fn test_warmup_strategy() {
        let none = WarmupStrategy::None;
        assert_eq!(none, WarmupStrategy::None);
        assert_ne!(none, WarmupStrategy::MinConnections);
    }

    #[test]
    fn test_retry_config_delay() {
        let config = RetryConfig::default();

        // Default backoff doubles from 100 ms.
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(config.delay_for_attempt(attempt).as_millis(), expected_ms);
        }
    }

    #[test]
    fn test_retry_config_max_delay() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_delay_ms: 1000,
            max_delay_ms: 5000,
            backoff_multiplier: 10.0,
        };

        // 1000 * 10^5 ms far exceeds the cap, so max_delay_ms wins.
        assert_eq!(config.delay_for_attempt(5).as_millis(), 5000);
    }
}