Skip to main content

qail_pg/driver/
pool.rs

1//! PostgreSQL Connection Pool
2//!
3//! Provides connection pooling for efficient resource management.
4//! Connections are reused across queries to avoid reconnection overhead.
5
6use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Settings that control how a [`PgPool`] opens, reuses and retires connections.
///
/// Build one with [`PoolConfig::new`] and refine it with the builder methods.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role used to authenticate.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password; when `None` the pool connects without one.
    pub password: Option<String>,
    /// Upper bound on connections checked out at the same time.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Idle connections unused for longer than this are discarded on acquire.
    pub idle_timeout: Duration,
    /// Longest time `acquire` waits for a free slot.
    pub acquire_timeout: Duration,
    /// Longest time allowed to establish a new connection.
    pub connect_timeout: Duration,
    /// Optional age limit after which a pooled connection is recycled.
    pub max_lifetime: Option<Duration>,
    /// Whether idle connections are validated before being handed out.
    pub test_on_acquire: bool,
}
27
28impl PoolConfig {
29    /// Create a new pool configuration with sensible defaults.
30    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
31        Self {
32            host: host.to_string(),
33            port,
34            user: user.to_string(),
35            database: database.to_string(),
36            password: None,
37            max_connections: 10,
38            min_connections: 1,
39            idle_timeout: Duration::from_secs(600), // 10 minutes
40            acquire_timeout: Duration::from_secs(30), // 30 seconds
41            connect_timeout: Duration::from_secs(10), // 10 seconds
42            max_lifetime: None,                      // No limit by default
43            test_on_acquire: false,                  // Disabled by default for performance
44        }
45    }
46
47    /// Set password for authentication.
48    pub fn password(mut self, password: &str) -> Self {
49        self.password = Some(password.to_string());
50        self
51    }
52
53    pub fn max_connections(mut self, max: usize) -> Self {
54        self.max_connections = max;
55        self
56    }
57
58    /// Set minimum idle connections.
59    pub fn min_connections(mut self, min: usize) -> Self {
60        self.min_connections = min;
61        self
62    }
63
64    /// Set idle timeout (connections idle longer than this are closed).
65    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
66        self.idle_timeout = timeout;
67        self
68    }
69
70    /// Set acquire timeout (max wait time when getting a connection).
71    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
72        self.acquire_timeout = timeout;
73        self
74    }
75
76    /// Set connect timeout (max time to establish new connection).
77    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
78        self.connect_timeout = timeout;
79        self
80    }
81
82    /// Set maximum lifetime of a connection before recycling.
83    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
84        self.max_lifetime = Some(lifetime);
85        self
86    }
87
88    /// Enable connection validation on acquire.
89    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
90        self.test_on_acquire = enabled;
91        self
92    }
93}
94
/// Point-in-time snapshot of pool usage, intended for monitoring.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently checked out by callers.
    pub active: usize,
    /// Connections parked in the idle list.
    pub idle: usize,
    /// Slots reserved on the semaphore but not reflected in `active`.
    pub pending: usize,
    /// Maximum connections configured
    pub max_size: usize,
    /// Connections opened over the pool's lifetime.
    pub total_created: usize,
}
105
106/// A pooled connection with creation timestamp for idle tracking.
107struct PooledConn {
108    conn: PgConnection,
109    created_at: Instant,
110    last_used: Instant,
111}
112
113/// A pooled connection that returns to the pool when dropped.
114///
115/// When `rls_dirty` is true (set by `acquire_with_rls`), the connection
116/// will automatically reset RLS session variables before returning to
117/// the pool. This prevents cross-tenant data leakage.
118pub struct PooledConnection {
119    conn: Option<PgConnection>,
120    pool: Arc<PgPoolInner>,
121    rls_dirty: bool,
122}
123
124impl PooledConnection {
125    /// Get a mutable reference to the underlying connection.
126    pub fn get_mut(&mut self) -> &mut PgConnection {
127        self.conn
128            .as_mut()
129            .expect("Connection should always be present")
130    }
131
132    /// Get a token to cancel the currently running query.
133    pub fn cancel_token(&self) -> crate::driver::CancelToken {
134        let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
135        crate::driver::CancelToken {
136            host: self.pool.config.host.clone(),
137            port: self.pool.config.port,
138            process_id,
139            secret_key,
140        }
141    }
142
143    /// Execute a QAIL command and fetch all rows (UNCACHED).
144    /// Returns rows with column metadata for JSON serialization.
145    pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
146        use crate::protocol::AstEncoder;
147        use super::ColumnInfo;
148
149        let conn = self.conn.as_mut().expect("Connection should always be present");
150
151        let wire_bytes = AstEncoder::encode_cmd_reuse(
152            cmd,
153            &mut conn.sql_buf,
154            &mut conn.params_buf,
155        );
156
157        conn.send_bytes(&wire_bytes).await?;
158
159        let mut rows: Vec<super::PgRow> = Vec::new();
160        let mut column_info: Option<Arc<ColumnInfo>> = None;
161        let mut error: Option<PgError> = None;
162
163        loop {
164            let msg = conn.recv().await?;
165            match msg {
166                crate::protocol::BackendMessage::ParseComplete
167                | crate::protocol::BackendMessage::BindComplete => {}
168                crate::protocol::BackendMessage::RowDescription(fields) => {
169                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
170                }
171                crate::protocol::BackendMessage::DataRow(data) => {
172                    if error.is_none() {
173                        rows.push(super::PgRow {
174                            columns: data,
175                            column_info: column_info.clone(),
176                        });
177                    }
178                }
179                crate::protocol::BackendMessage::CommandComplete(_) => {}
180                crate::protocol::BackendMessage::ReadyForQuery(_) => {
181                    if let Some(err) = error {
182                        return Err(err);
183                    }
184                    return Ok(rows);
185                }
186                crate::protocol::BackendMessage::ErrorResponse(err) => {
187                    if error.is_none() {
188                        error = Some(PgError::Query(err.message));
189                    }
190                }
191                _ => {}
192            }
193        }
194    }
195}
196
197impl Drop for PooledConnection {
198    fn drop(&mut self) {
199        if let Some(conn) = self.conn.take() {
200            let pool = self.pool.clone();
201            let rls_dirty = self.rls_dirty;
202            tokio::spawn(async move {
203                if rls_dirty {
204                    // Reset RLS session variables before returning to pool.
205                    // This prevents the next acquire() from inheriting
206                    // a stale tenant context from a different request.
207                    let mut conn = conn;
208                    let _ = conn.execute_simple(super::rls::reset_sql()).await;
209                    pool.return_connection(conn).await;
210                } else {
211                    pool.return_connection(conn).await;
212                }
213            });
214        }
215    }
216}
217
218impl std::ops::Deref for PooledConnection {
219    type Target = PgConnection;
220
221    fn deref(&self) -> &Self::Target {
222        self.conn
223            .as_ref()
224            .expect("Connection should always be present")
225    }
226}
227
228impl std::ops::DerefMut for PooledConnection {
229    fn deref_mut(&mut self) -> &mut Self::Target {
230        self.conn
231            .as_mut()
232            .expect("Connection should always be present")
233    }
234}
235
236/// Inner pool state (shared across clones).
237struct PgPoolInner {
238    config: PoolConfig,
239    connections: Mutex<Vec<PooledConn>>,
240    semaphore: Semaphore,
241    closed: AtomicBool,
242    active_count: AtomicUsize,
243    total_created: AtomicUsize,
244}
245
246impl PgPoolInner {
247    async fn return_connection(&self, conn: PgConnection) {
248
249        self.active_count.fetch_sub(1, Ordering::Relaxed);
250        
251
252        if self.closed.load(Ordering::Relaxed) {
253            return;
254        }
255        
256        let mut connections = self.connections.lock().await;
257        if connections.len() < self.config.max_connections {
258            connections.push(PooledConn {
259                conn,
260                created_at: Instant::now(),
261                last_used: Instant::now(),
262            });
263        }
264
265        self.semaphore.add_permits(1);
266    }
267
268    /// Get a healthy connection from the pool, or None if pool is empty.
269    async fn get_healthy_connection(&self) -> Option<PgConnection> {
270        let mut connections = self.connections.lock().await;
271
272        while let Some(pooled) = connections.pop() {
273            if pooled.last_used.elapsed() > self.config.idle_timeout {
274                // Connection is stale, drop it
275                continue;
276            }
277
278            if let Some(max_life) = self.config.max_lifetime
279                && pooled.created_at.elapsed() > max_life
280            {
281                // Connection exceeded max lifetime, recycle it
282                continue;
283            }
284
285            return Some(pooled.conn);
286        }
287
288        None
289    }
290}
291
292/// # Example
293/// ```ignore
294/// let config = PoolConfig::new("localhost", 5432, "user", "db")
295///     .password("secret")
296///     .max_connections(20);
297/// let pool = PgPool::connect(config).await?;
298/// // Get a connection from the pool
299/// let mut conn = pool.acquire().await?;
300/// conn.simple_query("SELECT 1").await?;
301/// ```
302#[derive(Clone)]
303pub struct PgPool {
304    inner: Arc<PgPoolInner>,
305}
306
307impl PgPool {
308    /// Create a new connection pool.
309    pub async fn connect(config: PoolConfig) -> PgResult<Self> {
310        // Semaphore starts with max_connections permits
311        let semaphore = Semaphore::new(config.max_connections);
312
313        let mut initial_connections = Vec::new();
314        for _ in 0..config.min_connections {
315            let conn = Self::create_connection(&config).await?;
316            initial_connections.push(PooledConn {
317                conn,
318                created_at: Instant::now(),
319                last_used: Instant::now(),
320            });
321        }
322
323        let initial_count = initial_connections.len();
324
325        let inner = Arc::new(PgPoolInner {
326            config,
327            connections: Mutex::new(initial_connections),
328            semaphore,
329            closed: AtomicBool::new(false),
330            active_count: AtomicUsize::new(0),
331            total_created: AtomicUsize::new(initial_count),
332        });
333
334        Ok(Self { inner })
335    }
336
337    /// Acquire a connection from the pool.
338    pub async fn acquire(&self) -> PgResult<PooledConnection> {
339        if self.inner.closed.load(Ordering::Relaxed) {
340            return Err(PgError::Connection("Pool is closed".to_string()));
341        }
342
343        // Wait for available slot with timeout
344        let acquire_timeout = self.inner.config.acquire_timeout;
345        let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
346            .await
347            .map_err(|_| {
348                PgError::Connection(format!(
349                    "Timed out waiting for connection ({}s)",
350                    acquire_timeout.as_secs()
351                ))
352            })?
353            .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
354        permit.forget();
355
356        // Try to get existing healthy connection
357        let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
358            conn
359        } else {
360            let conn = Self::create_connection(&self.inner.config).await?;
361            self.inner.total_created.fetch_add(1, Ordering::Relaxed);
362            conn
363        };
364
365
366        self.inner.active_count.fetch_add(1, Ordering::Relaxed);
367
368        Ok(PooledConnection {
369            conn: Some(conn),
370            pool: self.inner.clone(),
371            rls_dirty: false,
372        })
373    }
374
375    /// Acquire a connection with RLS context pre-configured.
376    ///
377    /// Sets PostgreSQL session variables for tenant isolation before
378    /// returning the connection. When the connection is dropped, it
379    /// automatically clears the RLS context before returning to the pool.
380    ///
381    /// # Example
382    /// ```ignore
383    /// use qail_core::rls::RlsContext;
384    ///
385    /// let mut conn = pool.acquire_with_rls(
386    ///     RlsContext::operator("550e8400-e29b-41d4-a716-446655440000")
387    /// ).await?;
388    /// // All queries through `conn` are now scoped to this operator
389    /// ```
390    pub async fn acquire_with_rls(
391        &self,
392        ctx: qail_core::rls::RlsContext,
393    ) -> PgResult<PooledConnection> {
394        let mut conn = self.acquire().await?;
395
396        // Set RLS context on the raw connection
397        let sql = super::rls::context_to_sql(&ctx);
398        let pg_conn = conn.get_mut();
399        pg_conn.execute_simple(&sql).await?;
400
401        // Mark dirty so Drop resets context before pool return
402        conn.rls_dirty = true;
403
404        Ok(conn)
405    }
406
407    /// Get the current number of idle connections.
408    pub async fn idle_count(&self) -> usize {
409        self.inner.connections.lock().await.len()
410    }
411
412    /// Get the number of connections currently in use.
413    pub fn active_count(&self) -> usize {
414        self.inner.active_count.load(Ordering::Relaxed)
415    }
416
417    /// Get the maximum number of connections.
418    pub fn max_connections(&self) -> usize {
419        self.inner.config.max_connections
420    }
421
422    /// Get comprehensive pool statistics.
423    pub async fn stats(&self) -> PoolStats {
424        let idle = self.inner.connections.lock().await.len();
425        PoolStats {
426            active: self.inner.active_count.load(Ordering::Relaxed),
427            idle,
428            pending: self.inner.config.max_connections
429                - self.inner.semaphore.available_permits()
430                - self.active_count(),
431            max_size: self.inner.config.max_connections,
432            total_created: self.inner.total_created.load(Ordering::Relaxed),
433        }
434    }
435
436    /// Check if the pool is closed.
437    pub fn is_closed(&self) -> bool {
438        self.inner.closed.load(Ordering::Relaxed)
439    }
440
441    /// Close the pool gracefully.
442    pub async fn close(&self) {
443        self.inner.closed.store(true, Ordering::Relaxed);
444
445        let mut connections = self.inner.connections.lock().await;
446        connections.clear();
447    }
448
449    /// Create a new connection using the pool configuration.
450    async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
451        match &config.password {
452            Some(password) => {
453                PgConnection::connect_with_password(
454                    &config.host,
455                    config.port,
456                    &config.user,
457                    &config.database,
458                    Some(password),
459                )
460                .await
461            }
462            None => {
463                PgConnection::connect(&config.host, config.port, &config.user, &config.database)
464                    .await
465            }
466        }
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    /// Pins the defaults documented on `PoolConfig::new`.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    /// Covers the builders not exercised by `test_pool_config`.
    #[test]
    fn test_pool_config_timeouts_and_validation() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .idle_timeout(Duration::from_secs(5))
            .acquire_timeout(Duration::from_millis(250))
            .connect_timeout(Duration::from_secs(2))
            .max_lifetime(Duration::from_secs(3600))
            .test_on_acquire(true);

        assert_eq!(config.idle_timeout, Duration::from_secs(5));
        assert_eq!(config.acquire_timeout, Duration::from_millis(250));
        assert_eq!(config.connect_timeout, Duration::from_secs(2));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(3600)));
        assert!(config.test_on_acquire);
    }

    #[test]
    fn test_pool_stats_default_is_zeroed() {
        let stats = PoolStats::default();
        assert_eq!(stats.active, 0);
        assert_eq!(stats.idle, 0);
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.max_size, 0);
        assert_eq!(stats.total_created, 0);
    }
}