qail_pg/driver/
pool.rs

//! PostgreSQL connection pool.
//!
//! Provides connection pooling for efficient resource management.
//! Connections are reused across queries to avoid reconnection overhead.
5
6use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Configuration for a [`PgPool`].
///
/// Create one with [`PoolConfig::new`] and refine it with the chainable setters:
///
/// ```ignore
/// let config = PoolConfig::new("localhost", 5432, "app", "appdb")
///     .password("secret")
///     .max_connections(20)
///     .acquire_timeout(Duration::from_secs(5));
/// ```
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role used to authenticate.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password for authentication; `None` connects without one.
    pub password: Option<String>,
    /// Maximum number of connections the pool hands out at once (default 10).
    pub max_connections: usize,
    /// Number of connections opened when the pool is created (default 1).
    pub min_connections: usize,
    /// Idle connections unused for longer than this are discarded (default 10 minutes).
    pub idle_timeout: Duration,
    /// Maximum wait time when getting a connection from the pool (default 30 seconds).
    pub acquire_timeout: Duration,
    /// Maximum time to establish a new connection (default 10 seconds).
    pub connect_timeout: Duration,
    /// Maximum lifetime of a connection before recycling (default: no limit).
    pub max_lifetime: Option<Duration>,
    /// Validate connections when they are acquired (default `false`, for performance).
    pub test_on_acquire: bool,
}

impl std::fmt::Debug for PoolConfig {
    /// Formats every field except the password, which is shown only as
    /// `Some("<redacted>")` / `None` so configurations can be logged safely.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PoolConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            .field("database", &self.database)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("idle_timeout", &self.idle_timeout)
            .field("acquire_timeout", &self.acquire_timeout)
            .field("connect_timeout", &self.connect_timeout)
            .field("max_lifetime", &self.max_lifetime)
            .field("test_on_acquire", &self.test_on_acquire)
            .finish()
    }
}

impl PoolConfig {
    /// Create a pool configuration for `user@host:port/database` with sensible defaults.
    ///
    /// Defaults:
    /// - no password;
    /// - at most 10 connections, with 1 opened up front;
    /// - 10 minute idle timeout, 30 second acquire timeout and 10 second connect timeout;
    /// - no maximum lifetime and no validation on acquire.
    #[must_use]
    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
        Self {
            host: host.to_string(),
            port,
            user: user.to_string(),
            database: database.to_string(),
            password: None,
            max_connections: 10,
            min_connections: 1,
            idle_timeout: Duration::from_secs(600),   // 10 minutes
            acquire_timeout: Duration::from_secs(30), // 30 seconds
            connect_timeout: Duration::from_secs(10), // 10 seconds
            max_lifetime: None,                       // No limit by default
            test_on_acquire: false,                   // Disabled by default for performance
        }
    }

    /// Set password for authentication.
    #[must_use]
    pub fn password(mut self, password: &str) -> Self {
        self.password = Some(password.to_string());
        self
    }

    /// Set the maximum number of connections handed out at once.
    #[must_use]
    pub fn max_connections(mut self, max: usize) -> Self {
        self.max_connections = max;
        self
    }

    /// Set minimum idle connections (opened when the pool is created).
    #[must_use]
    pub fn min_connections(mut self, min: usize) -> Self {
        self.min_connections = min;
        self
    }

    /// Set idle timeout (connections idle longer than this are closed).
    #[must_use]
    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Set acquire timeout (max wait time when getting a connection).
    #[must_use]
    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Set connect timeout (max time to establish new connection).
    #[must_use]
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Set maximum lifetime of a connection before recycling.
    #[must_use]
    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = Some(lifetime);
        self
    }

    /// Enable or disable connection validation on acquire.
    #[must_use]
    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
        self.test_on_acquire = enabled;
        self
    }
}
94
/// Snapshot of pool counters, intended for monitoring and metrics export.
#[derive(Default, Clone, Debug)]
pub struct PoolStats {
    /// Connections currently checked out by callers.
    pub active: usize,
    /// Connections parked in the idle list, ready to be reused.
    pub idle: usize,
    /// Semaphore slots that are taken but not counted in `active`.
    /// Example: a connection that is still being opened or handed back.
    pub pending: usize,
    /// Maximum connections configured
    pub max_size: usize,
    /// Total number of connections the pool has opened so far.
    pub total_created: usize,
}
105
106/// A pooled connection with creation timestamp for idle tracking.
107struct PooledConn {
108    conn: PgConnection,
109    created_at: Instant,
110    last_used: Instant,
111}
112
113/// A pooled connection that returns to the pool when dropped.
114pub struct PooledConnection {
115    conn: Option<PgConnection>,
116    pool: Arc<PgPoolInner>,
117}
118
119impl PooledConnection {
120    /// Get a mutable reference to the underlying connection.
121    pub fn get_mut(&mut self) -> &mut PgConnection {
122        self.conn
123            .as_mut()
124            .expect("Connection should always be present")
125    }
126
127    /// Get a token to cancel the currently running query.
128    pub fn cancel_token(&self) -> crate::driver::CancelToken {
129        let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
130        crate::driver::CancelToken {
131            host: self.pool.config.host.clone(),
132            port: self.pool.config.port,
133            process_id,
134            secret_key,
135        }
136    }
137
138    /// Execute a QAIL command and fetch all rows (UNCACHED).
139    /// Returns rows with column metadata for JSON serialization.
140    pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
141        use crate::protocol::AstEncoder;
142        use super::ColumnInfo;
143
144        let conn = self.conn.as_mut().expect("Connection should always be present");
145
146        let wire_bytes = AstEncoder::encode_cmd_reuse(
147            cmd,
148            &mut conn.sql_buf,
149            &mut conn.params_buf,
150        );
151
152        conn.send_bytes(&wire_bytes).await?;
153
154        let mut rows: Vec<super::PgRow> = Vec::new();
155        let mut column_info: Option<Arc<ColumnInfo>> = None;
156        let mut error: Option<PgError> = None;
157
158        loop {
159            let msg = conn.recv().await?;
160            match msg {
161                crate::protocol::BackendMessage::ParseComplete
162                | crate::protocol::BackendMessage::BindComplete => {}
163                crate::protocol::BackendMessage::RowDescription(fields) => {
164                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
165                }
166                crate::protocol::BackendMessage::DataRow(data) => {
167                    if error.is_none() {
168                        rows.push(super::PgRow {
169                            columns: data,
170                            column_info: column_info.clone(),
171                        });
172                    }
173                }
174                crate::protocol::BackendMessage::CommandComplete(_) => {}
175                crate::protocol::BackendMessage::ReadyForQuery(_) => {
176                    if let Some(err) = error {
177                        return Err(err);
178                    }
179                    return Ok(rows);
180                }
181                crate::protocol::BackendMessage::ErrorResponse(err) => {
182                    if error.is_none() {
183                        error = Some(PgError::Query(err.message));
184                    }
185                }
186                _ => {}
187            }
188        }
189    }
190}
191
192impl Drop for PooledConnection {
193    fn drop(&mut self) {
194        if let Some(conn) = self.conn.take() {
195            let pool = self.pool.clone();
196            tokio::spawn(async move {
197                pool.return_connection(conn).await;
198            });
199        }
200    }
201}
202
203impl std::ops::Deref for PooledConnection {
204    type Target = PgConnection;
205
206    fn deref(&self) -> &Self::Target {
207        self.conn
208            .as_ref()
209            .expect("Connection should always be present")
210    }
211}
212
213impl std::ops::DerefMut for PooledConnection {
214    fn deref_mut(&mut self) -> &mut Self::Target {
215        self.conn
216            .as_mut()
217            .expect("Connection should always be present")
218    }
219}
220
221/// Inner pool state (shared across clones).
222struct PgPoolInner {
223    config: PoolConfig,
224    connections: Mutex<Vec<PooledConn>>,
225    semaphore: Semaphore,
226    closed: AtomicBool,
227    active_count: AtomicUsize,
228    total_created: AtomicUsize,
229}
230
231impl PgPoolInner {
232    async fn return_connection(&self, conn: PgConnection) {
233
234        self.active_count.fetch_sub(1, Ordering::Relaxed);
235        
236
237        if self.closed.load(Ordering::Relaxed) {
238            return;
239        }
240        
241        let mut connections = self.connections.lock().await;
242        if connections.len() < self.config.max_connections {
243            connections.push(PooledConn {
244                conn,
245                created_at: Instant::now(),
246                last_used: Instant::now(),
247            });
248        }
249
250        self.semaphore.add_permits(1);
251    }
252
253    /// Get a healthy connection from the pool, or None if pool is empty.
254    async fn get_healthy_connection(&self) -> Option<PgConnection> {
255        let mut connections = self.connections.lock().await;
256
257        while let Some(pooled) = connections.pop() {
258            if pooled.last_used.elapsed() > self.config.idle_timeout {
259                // Connection is stale, drop it
260                continue;
261            }
262
263            if let Some(max_life) = self.config.max_lifetime
264                && pooled.created_at.elapsed() > max_life
265            {
266                // Connection exceeded max lifetime, recycle it
267                continue;
268            }
269
270            return Some(pooled.conn);
271        }
272
273        None
274    }
275}
276
277/// # Example
278/// ```ignore
279/// let config = PoolConfig::new("localhost", 5432, "user", "db")
280///     .password("secret")
281///     .max_connections(20);
282/// let pool = PgPool::connect(config).await?;
283/// // Get a connection from the pool
284/// let mut conn = pool.acquire().await?;
285/// conn.simple_query("SELECT 1").await?;
286/// ```
287#[derive(Clone)]
288pub struct PgPool {
289    inner: Arc<PgPoolInner>,
290}
291
292impl PgPool {
293    /// Create a new connection pool.
294    pub async fn connect(config: PoolConfig) -> PgResult<Self> {
295        // Semaphore starts with max_connections permits
296        let semaphore = Semaphore::new(config.max_connections);
297
298        let mut initial_connections = Vec::new();
299        for _ in 0..config.min_connections {
300            let conn = Self::create_connection(&config).await?;
301            initial_connections.push(PooledConn {
302                conn,
303                created_at: Instant::now(),
304                last_used: Instant::now(),
305            });
306        }
307
308        let initial_count = initial_connections.len();
309
310        let inner = Arc::new(PgPoolInner {
311            config,
312            connections: Mutex::new(initial_connections),
313            semaphore,
314            closed: AtomicBool::new(false),
315            active_count: AtomicUsize::new(0),
316            total_created: AtomicUsize::new(initial_count),
317        });
318
319        Ok(Self { inner })
320    }
321
322    /// Acquire a connection from the pool.
323    pub async fn acquire(&self) -> PgResult<PooledConnection> {
324        if self.inner.closed.load(Ordering::Relaxed) {
325            return Err(PgError::Connection("Pool is closed".to_string()));
326        }
327
328        // Wait for available slot with timeout
329        let acquire_timeout = self.inner.config.acquire_timeout;
330        let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
331            .await
332            .map_err(|_| {
333                PgError::Connection(format!(
334                    "Timed out waiting for connection ({}s)",
335                    acquire_timeout.as_secs()
336                ))
337            })?
338            .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
339        permit.forget();
340
341        // Try to get existing healthy connection
342        let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
343            conn
344        } else {
345            let conn = Self::create_connection(&self.inner.config).await?;
346            self.inner.total_created.fetch_add(1, Ordering::Relaxed);
347            conn
348        };
349
350
351        self.inner.active_count.fetch_add(1, Ordering::Relaxed);
352
353        Ok(PooledConnection {
354            conn: Some(conn),
355            pool: self.inner.clone(),
356        })
357    }
358
359    /// Get the current number of idle connections.
360    pub async fn idle_count(&self) -> usize {
361        self.inner.connections.lock().await.len()
362    }
363
364    /// Get the number of connections currently in use.
365    pub fn active_count(&self) -> usize {
366        self.inner.active_count.load(Ordering::Relaxed)
367    }
368
369    /// Get the maximum number of connections.
370    pub fn max_connections(&self) -> usize {
371        self.inner.config.max_connections
372    }
373
374    /// Get comprehensive pool statistics.
375    pub async fn stats(&self) -> PoolStats {
376        let idle = self.inner.connections.lock().await.len();
377        PoolStats {
378            active: self.inner.active_count.load(Ordering::Relaxed),
379            idle,
380            pending: self.inner.config.max_connections
381                - self.inner.semaphore.available_permits()
382                - self.active_count(),
383            max_size: self.inner.config.max_connections,
384            total_created: self.inner.total_created.load(Ordering::Relaxed),
385        }
386    }
387
388    /// Check if the pool is closed.
389    pub fn is_closed(&self) -> bool {
390        self.inner.closed.load(Ordering::Relaxed)
391    }
392
393    /// Close the pool gracefully.
394    pub async fn close(&self) {
395        self.inner.closed.store(true, Ordering::Relaxed);
396
397        let mut connections = self.inner.connections.lock().await;
398        connections.clear();
399    }
400
401    /// Create a new connection using the pool configuration.
402    async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
403        match &config.password {
404            Some(password) => {
405                PgConnection::connect_with_password(
406                    &config.host,
407                    config.port,
408                    &config.user,
409                    &config.database,
410                    Some(password),
411                )
412                .await
413            }
414            None => {
415                PgConnection::connect(&config.host, config.port, &config.user, &config.database)
416                    .await
417            }
418        }
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters should override the corresponding fields.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    /// `PoolConfig::new` should fill in the documented defaults.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    /// Timeout, lifetime and validation setters should override their fields.
    #[test]
    fn test_pool_config_timeouts() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .idle_timeout(Duration::from_secs(5))
            .acquire_timeout(Duration::from_millis(250))
            .connect_timeout(Duration::from_secs(3))
            .max_lifetime(Duration::from_secs(60))
            .test_on_acquire(true);

        assert_eq!(config.idle_timeout, Duration::from_secs(5));
        assert_eq!(config.acquire_timeout, Duration::from_millis(250));
        assert_eq!(config.connect_timeout, Duration::from_secs(3));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(60)));
        assert!(config.test_on_acquire);
    }
}