qail_pg/driver/
pool.rs

//! PostgreSQL connection pool.
//!
//! Provides connection pooling for efficient resource management: connections
//! are reused across queries to avoid the overhead of reconnecting.
5
6use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Settings for a [`PgPool`].
///
/// Usually built with `PoolConfig::new` and refined through its builder-style
/// setters.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role to authenticate as.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password for authentication; `None` connects without one.
    pub password: Option<String>,
    /// Number of connection slots; bounds how many connections can be checked out at once.
    pub max_connections: usize,
    /// Number of connections opened when the pool is created.
    pub min_connections: usize,
    /// Idle connections unused for longer than this are discarded instead of reused.
    pub idle_timeout: Duration,
    /// Longest time `acquire` waits for a free slot.
    pub acquire_timeout: Duration,
    /// Intended time budget for establishing a new connection.
    pub connect_timeout: Duration,
    /// Optional age limit after which a pooled connection is recycled.
    pub max_lifetime: Option<Duration>,
    /// Whether connections should be validated when they are acquired.
    pub test_on_acquire: bool,
}
27
28impl PoolConfig {
29    /// Create a new pool configuration with sensible defaults.
30    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
31        Self {
32            host: host.to_string(),
33            port,
34            user: user.to_string(),
35            database: database.to_string(),
36            password: None,
37            max_connections: 10,
38            min_connections: 1,
39            idle_timeout: Duration::from_secs(600), // 10 minutes
40            acquire_timeout: Duration::from_secs(30), // 30 seconds
41            connect_timeout: Duration::from_secs(10), // 10 seconds
42            max_lifetime: None,                      // No limit by default
43            test_on_acquire: false,                  // Disabled by default for performance
44        }
45    }
46
47    /// Set password for authentication.
48    pub fn password(mut self, password: &str) -> Self {
49        self.password = Some(password.to_string());
50        self
51    }
52
53    pub fn max_connections(mut self, max: usize) -> Self {
54        self.max_connections = max;
55        self
56    }
57
58    /// Set minimum idle connections.
59    pub fn min_connections(mut self, min: usize) -> Self {
60        self.min_connections = min;
61        self
62    }
63
64    /// Set idle timeout (connections idle longer than this are closed).
65    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
66        self.idle_timeout = timeout;
67        self
68    }
69
70    /// Set acquire timeout (max wait time when getting a connection).
71    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
72        self.acquire_timeout = timeout;
73        self
74    }
75
76    /// Set connect timeout (max time to establish new connection).
77    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
78        self.connect_timeout = timeout;
79        self
80    }
81
82    /// Set maximum lifetime of a connection before recycling.
83    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
84        self.max_lifetime = Some(lifetime);
85        self
86    }
87
88    /// Enable connection validation on acquire.
89    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
90        self.test_on_acquire = enabled;
91        self
92    }
93}
94
/// Snapshot of pool usage, for monitoring.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently checked out of the pool.
    pub active: usize,
    /// Connections waiting idle in the pool.
    pub idle: usize,
    /// Connection slots in use that are not counted as `active`.
    pub pending: usize,
    /// Maximum connections configured.
    pub max_size: usize,
    /// Connections created over the pool's lifetime.
    pub total_created: usize,
}
105
106/// A pooled connection with creation timestamp for idle tracking.
107struct PooledConn {
108    conn: PgConnection,
109    created_at: Instant,
110    last_used: Instant,
111}
112
113/// A pooled connection that returns to the pool when dropped.
114pub struct PooledConnection {
115    conn: Option<PgConnection>,
116    pool: Arc<PgPoolInner>,
117}
118
119impl PooledConnection {
120    /// Get a mutable reference to the underlying connection.
121    pub fn get_mut(&mut self) -> &mut PgConnection {
122        self.conn
123            .as_mut()
124            .expect("Connection should always be present")
125    }
126
127    /// Get a token to cancel the currently running query.
128    pub fn cancel_token(&self) -> crate::driver::CancelToken {
129        let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
130        crate::driver::CancelToken {
131            host: self.pool.config.host.clone(),
132            port: self.pool.config.port,
133            process_id,
134            secret_key,
135        }
136    }
137}
138
139impl Drop for PooledConnection {
140    fn drop(&mut self) {
141        if let Some(conn) = self.conn.take() {
142            let pool = self.pool.clone();
143            tokio::spawn(async move {
144                pool.return_connection(conn).await;
145            });
146        }
147    }
148}
149
150impl std::ops::Deref for PooledConnection {
151    type Target = PgConnection;
152
153    fn deref(&self) -> &Self::Target {
154        self.conn
155            .as_ref()
156            .expect("Connection should always be present")
157    }
158}
159
160impl std::ops::DerefMut for PooledConnection {
161    fn deref_mut(&mut self) -> &mut Self::Target {
162        self.conn
163            .as_mut()
164            .expect("Connection should always be present")
165    }
166}
167
168/// Inner pool state (shared across clones).
169struct PgPoolInner {
170    config: PoolConfig,
171    connections: Mutex<Vec<PooledConn>>,
172    semaphore: Semaphore,
173    closed: AtomicBool,
174    active_count: AtomicUsize,
175    total_created: AtomicUsize,
176}
177
178impl PgPoolInner {
179    async fn return_connection(&self, conn: PgConnection) {
180
181        self.active_count.fetch_sub(1, Ordering::Relaxed);
182        
183
184        if self.closed.load(Ordering::Relaxed) {
185            return;
186        }
187        
188        let mut connections = self.connections.lock().await;
189        if connections.len() < self.config.max_connections {
190            connections.push(PooledConn {
191                conn,
192                created_at: Instant::now(),
193                last_used: Instant::now(),
194            });
195        }
196
197        self.semaphore.add_permits(1);
198    }
199
200    /// Get a healthy connection from the pool, or None if pool is empty.
201    async fn get_healthy_connection(&self) -> Option<PgConnection> {
202        let mut connections = self.connections.lock().await;
203
204        while let Some(pooled) = connections.pop() {
205            if pooled.last_used.elapsed() > self.config.idle_timeout {
206                // Connection is stale, drop it
207                continue;
208            }
209
210            if let Some(max_life) = self.config.max_lifetime
211                && pooled.created_at.elapsed() > max_life
212            {
213                // Connection exceeded max lifetime, recycle it
214                continue;
215            }
216
217            return Some(pooled.conn);
218        }
219
220        None
221    }
222}
223
224/// # Example
225/// ```ignore
226/// let config = PoolConfig::new("localhost", 5432, "user", "db")
227///     .password("secret")
228///     .max_connections(20);
229/// let pool = PgPool::connect(config).await?;
230/// // Get a connection from the pool
231/// let mut conn = pool.acquire().await?;
232/// conn.simple_query("SELECT 1").await?;
233/// ```
234#[derive(Clone)]
235pub struct PgPool {
236    inner: Arc<PgPoolInner>,
237}
238
239impl PgPool {
240    /// Create a new connection pool.
241    pub async fn connect(config: PoolConfig) -> PgResult<Self> {
242        // Semaphore starts with max_connections permits
243        let semaphore = Semaphore::new(config.max_connections);
244
245        let mut initial_connections = Vec::new();
246        for _ in 0..config.min_connections {
247            let conn = Self::create_connection(&config).await?;
248            initial_connections.push(PooledConn {
249                conn,
250                created_at: Instant::now(),
251                last_used: Instant::now(),
252            });
253        }
254
255        let initial_count = initial_connections.len();
256
257        let inner = Arc::new(PgPoolInner {
258            config,
259            connections: Mutex::new(initial_connections),
260            semaphore,
261            closed: AtomicBool::new(false),
262            active_count: AtomicUsize::new(0),
263            total_created: AtomicUsize::new(initial_count),
264        });
265
266        Ok(Self { inner })
267    }
268
269    /// Acquire a connection from the pool.
270    pub async fn acquire(&self) -> PgResult<PooledConnection> {
271        if self.inner.closed.load(Ordering::Relaxed) {
272            return Err(PgError::Connection("Pool is closed".to_string()));
273        }
274
275        // Wait for available slot with timeout
276        let acquire_timeout = self.inner.config.acquire_timeout;
277        let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
278            .await
279            .map_err(|_| {
280                PgError::Connection(format!(
281                    "Timed out waiting for connection ({}s)",
282                    acquire_timeout.as_secs()
283                ))
284            })?
285            .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
286        permit.forget();
287
288        // Try to get existing healthy connection
289        let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
290            conn
291        } else {
292            let conn = Self::create_connection(&self.inner.config).await?;
293            self.inner.total_created.fetch_add(1, Ordering::Relaxed);
294            conn
295        };
296
297
298        self.inner.active_count.fetch_add(1, Ordering::Relaxed);
299
300        Ok(PooledConnection {
301            conn: Some(conn),
302            pool: self.inner.clone(),
303        })
304    }
305
306    /// Get the current number of idle connections.
307    pub async fn idle_count(&self) -> usize {
308        self.inner.connections.lock().await.len()
309    }
310
311    /// Get the number of connections currently in use.
312    pub fn active_count(&self) -> usize {
313        self.inner.active_count.load(Ordering::Relaxed)
314    }
315
316    /// Get the maximum number of connections.
317    pub fn max_connections(&self) -> usize {
318        self.inner.config.max_connections
319    }
320
321    /// Get comprehensive pool statistics.
322    pub async fn stats(&self) -> PoolStats {
323        let idle = self.inner.connections.lock().await.len();
324        PoolStats {
325            active: self.inner.active_count.load(Ordering::Relaxed),
326            idle,
327            pending: self.inner.config.max_connections
328                - self.inner.semaphore.available_permits()
329                - self.active_count(),
330            max_size: self.inner.config.max_connections,
331            total_created: self.inner.total_created.load(Ordering::Relaxed),
332        }
333    }
334
335    /// Check if the pool is closed.
336    pub fn is_closed(&self) -> bool {
337        self.inner.closed.load(Ordering::Relaxed)
338    }
339
340    /// Close the pool gracefully.
341    pub async fn close(&self) {
342        self.inner.closed.store(true, Ordering::Relaxed);
343
344        let mut connections = self.inner.connections.lock().await;
345        connections.clear();
346    }
347
348    /// Create a new connection using the pool configuration.
349    async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
350        match &config.password {
351            Some(password) => {
352                PgConnection::connect_with_password(
353                    &config.host,
354                    config.port,
355                    &config.user,
356                    &config.database,
357                    Some(password),
358                )
359                .await
360            }
361            None => {
362                PgConnection::connect(&config.host, config.port, &config.user, &config.database)
363                    .await
364            }
365        }
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters used together override the corresponding fields.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    /// `PoolConfig::new` fills in the documented defaults.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    /// The timeout, lifetime and validation setters override their fields.
    #[test]
    fn test_pool_config_timeouts() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .idle_timeout(Duration::from_secs(5))
            .acquire_timeout(Duration::from_secs(6))
            .connect_timeout(Duration::from_secs(7))
            .max_lifetime(Duration::from_secs(8))
            .test_on_acquire(true);

        assert_eq!(config.idle_timeout, Duration::from_secs(5));
        assert_eq!(config.acquire_timeout, Duration::from_secs(6));
        assert_eq!(config.connect_timeout, Duration::from_secs(7));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(8)));
        assert!(config.test_on_acquire);
    }

    /// A default `PoolStats` reports an empty pool.
    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();

        assert_eq!(stats.active, 0);
        assert_eq!(stats.idle, 0);
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.max_size, 0);
        assert_eq!(stats.total_created, 0);
    }
}