Skip to main content

qail_pg/driver/
pool.rs

//! PostgreSQL connection pool.
//!
//! Provides connection pooling for efficient resource management.
//! Connections are reused across queries to avoid reconnection overhead.

6use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Settings used by [`PgPool`] to open and manage PostgreSQL connections.
///
/// Build one with [`PoolConfig::new`] and the chained setter methods.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role used to authenticate.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Password, if the server requires one.
    pub password: Option<String>,
    /// Upper bound on connections checked out at the same time.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Idle connections unused for longer than this are discarded on acquire.
    pub idle_timeout: Duration,
    /// How long `acquire` waits for a free slot before failing.
    pub acquire_timeout: Duration,
    /// Maximum time allowed to establish a new connection.
    pub connect_timeout: Duration,
    /// Optional age after which an idle connection is recycled instead of reused.
    pub max_lifetime: Option<Duration>,
    /// Requests connection validation when a connection is acquired.
    pub test_on_acquire: bool,
}
27
28impl PoolConfig {
29    /// Create a new pool configuration with sensible defaults.
30    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
31        Self {
32            host: host.to_string(),
33            port,
34            user: user.to_string(),
35            database: database.to_string(),
36            password: None,
37            max_connections: 10,
38            min_connections: 1,
39            idle_timeout: Duration::from_secs(600), // 10 minutes
40            acquire_timeout: Duration::from_secs(30), // 30 seconds
41            connect_timeout: Duration::from_secs(10), // 10 seconds
42            max_lifetime: None,                      // No limit by default
43            test_on_acquire: false,                  // Disabled by default for performance
44        }
45    }
46
47    /// Set password for authentication.
48    pub fn password(mut self, password: &str) -> Self {
49        self.password = Some(password.to_string());
50        self
51    }
52
53    pub fn max_connections(mut self, max: usize) -> Self {
54        self.max_connections = max;
55        self
56    }
57
58    /// Set minimum idle connections.
59    pub fn min_connections(mut self, min: usize) -> Self {
60        self.min_connections = min;
61        self
62    }
63
64    /// Set idle timeout (connections idle longer than this are closed).
65    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
66        self.idle_timeout = timeout;
67        self
68    }
69
70    /// Set acquire timeout (max wait time when getting a connection).
71    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
72        self.acquire_timeout = timeout;
73        self
74    }
75
76    /// Set connect timeout (max time to establish new connection).
77    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
78        self.connect_timeout = timeout;
79        self
80    }
81
82    /// Set maximum lifetime of a connection before recycling.
83    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
84        self.max_lifetime = Some(lifetime);
85        self
86    }
87
88    /// Enable connection validation on acquire.
89    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
90        self.test_on_acquire = enabled;
91        self
92    }
93
94    /// Create a `PoolConfig` from a centralized `QailConfig`.
95    ///
96    /// Parses `postgres.url` for host/port/user/database/password
97    /// and applies pool tuning from `[postgres]` section.
98    pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
99        let pg = &qail.postgres;
100        let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
101
102        let mut config = PoolConfig::new(&host, port, &user, &database)
103            .max_connections(pg.max_connections)
104            .min_connections(pg.min_connections)
105            .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
106            .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
107            .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
108            .test_on_acquire(pg.test_on_acquire);
109
110        if let Some(ref pw) = password {
111            config = config.password(pw);
112        }
113
114        Ok(config)
115    }
116}
117
118/// Parse a postgres URL into (host, port, user, database, password).
119fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
120    let url = url.trim_start_matches("postgres://").trim_start_matches("postgresql://");
121
122    let (credentials, host_part) = if url.contains('@') {
123        let mut parts = url.splitn(2, '@');
124        let creds = parts.next().unwrap_or("");
125        let host = parts.next().unwrap_or("localhost/postgres");
126        (Some(creds), host)
127    } else {
128        (None, url)
129    };
130
131    let (host_port, database) = if host_part.contains('/') {
132        let mut parts = host_part.splitn(2, '/');
133        (parts.next().unwrap_or("localhost"), parts.next().unwrap_or("postgres").to_string())
134    } else {
135        (host_part, "postgres".to_string())
136    };
137
138    let (host, port) = if host_port.contains(':') {
139        let mut parts = host_port.split(':');
140        let h = parts.next().unwrap_or("localhost").to_string();
141        let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
142        (h, p)
143    } else {
144        (host_port.to_string(), 5432u16)
145    };
146
147    let (user, password) = if let Some(creds) = credentials {
148        if creds.contains(':') {
149            let mut parts = creds.splitn(2, ':');
150            let u = parts.next().unwrap_or("postgres").to_string();
151            let p = parts.next().map(|s| s.to_string());
152            (u, p)
153        } else {
154            (creds.to_string(), None)
155        }
156    } else {
157        ("postgres".to_string(), None)
158    };
159
160    Ok((host, port, user, database, password))
161}
162
/// Point-in-time snapshot of pool usage, returned by `PgPool::stats` for monitoring.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently checked out of the pool.
    pub active: usize,
    /// Connections sitting idle in the pool, ready for reuse.
    pub idle: usize,
    /// Pool slots that are reserved but not counted as active.
    pub pending: usize,
    /// Configured maximum number of connections.
    pub max_size: usize,
    /// Connections opened since the pool was created.
    pub total_created: usize,
}
173
174/// A pooled connection with creation timestamp for idle tracking.
175struct PooledConn {
176    conn: PgConnection,
177    created_at: Instant,
178    last_used: Instant,
179}
180
181/// A pooled connection that returns to the pool when dropped.
182///
183/// When `rls_dirty` is true (set by `acquire_with_rls`), the connection
184/// will automatically reset RLS session variables before returning to
185/// the pool. This prevents cross-tenant data leakage.
186pub struct PooledConnection {
187    conn: Option<PgConnection>,
188    pool: Arc<PgPoolInner>,
189    rls_dirty: bool,
190}
191
192impl PooledConnection {
193    /// Get a mutable reference to the underlying connection.
194    pub fn get_mut(&mut self) -> &mut PgConnection {
195        self.conn
196            .as_mut()
197            .expect("Connection should always be present")
198    }
199
200    /// Get a token to cancel the currently running query.
201    pub fn cancel_token(&self) -> crate::driver::CancelToken {
202        let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
203        crate::driver::CancelToken {
204            host: self.pool.config.host.clone(),
205            port: self.pool.config.port,
206            process_id,
207            secret_key,
208        }
209    }
210
211    /// Execute a QAIL command and fetch all rows (UNCACHED).
212    /// Returns rows with column metadata for JSON serialization.
213    pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
214        use crate::protocol::AstEncoder;
215        use super::ColumnInfo;
216
217        let conn = self.conn.as_mut().expect("Connection should always be present");
218
219        let wire_bytes = AstEncoder::encode_cmd_reuse(
220            cmd,
221            &mut conn.sql_buf,
222            &mut conn.params_buf,
223        );
224
225        conn.send_bytes(&wire_bytes).await?;
226
227        let mut rows: Vec<super::PgRow> = Vec::new();
228        let mut column_info: Option<Arc<ColumnInfo>> = None;
229        let mut error: Option<PgError> = None;
230
231        loop {
232            let msg = conn.recv().await?;
233            match msg {
234                crate::protocol::BackendMessage::ParseComplete
235                | crate::protocol::BackendMessage::BindComplete => {}
236                crate::protocol::BackendMessage::RowDescription(fields) => {
237                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
238                }
239                crate::protocol::BackendMessage::DataRow(data) => {
240                    if error.is_none() {
241                        rows.push(super::PgRow {
242                            columns: data,
243                            column_info: column_info.clone(),
244                        });
245                    }
246                }
247                crate::protocol::BackendMessage::CommandComplete(_) => {}
248                crate::protocol::BackendMessage::ReadyForQuery(_) => {
249                    if let Some(err) = error {
250                        return Err(err);
251                    }
252                    return Ok(rows);
253                }
254                crate::protocol::BackendMessage::ErrorResponse(err) => {
255                    if error.is_none() {
256                        error = Some(PgError::Query(err.message));
257                    }
258                }
259                _ => {}
260            }
261        }
262    }
263}
264
265impl Drop for PooledConnection {
266    fn drop(&mut self) {
267        if let Some(conn) = self.conn.take() {
268            let pool = self.pool.clone();
269            let rls_dirty = self.rls_dirty;
270            tokio::spawn(async move {
271                if rls_dirty {
272                    // Reset RLS session variables before returning to pool.
273                    // This prevents the next acquire() from inheriting
274                    // a stale tenant context from a different request.
275                    let mut conn = conn;
276                    let _ = conn.execute_simple(super::rls::reset_sql()).await;
277                    pool.return_connection(conn).await;
278                } else {
279                    pool.return_connection(conn).await;
280                }
281            });
282        }
283    }
284}
285
286impl std::ops::Deref for PooledConnection {
287    type Target = PgConnection;
288
289    fn deref(&self) -> &Self::Target {
290        self.conn
291            .as_ref()
292            .expect("Connection should always be present")
293    }
294}
295
296impl std::ops::DerefMut for PooledConnection {
297    fn deref_mut(&mut self) -> &mut Self::Target {
298        self.conn
299            .as_mut()
300            .expect("Connection should always be present")
301    }
302}
303
304/// Inner pool state (shared across clones).
305struct PgPoolInner {
306    config: PoolConfig,
307    connections: Mutex<Vec<PooledConn>>,
308    semaphore: Semaphore,
309    closed: AtomicBool,
310    active_count: AtomicUsize,
311    total_created: AtomicUsize,
312}
313
314impl PgPoolInner {
315    async fn return_connection(&self, conn: PgConnection) {
316
317        self.active_count.fetch_sub(1, Ordering::Relaxed);
318        
319
320        if self.closed.load(Ordering::Relaxed) {
321            return;
322        }
323        
324        let mut connections = self.connections.lock().await;
325        if connections.len() < self.config.max_connections {
326            connections.push(PooledConn {
327                conn,
328                created_at: Instant::now(),
329                last_used: Instant::now(),
330            });
331        }
332
333        self.semaphore.add_permits(1);
334    }
335
336    /// Get a healthy connection from the pool, or None if pool is empty.
337    async fn get_healthy_connection(&self) -> Option<PgConnection> {
338        let mut connections = self.connections.lock().await;
339
340        while let Some(pooled) = connections.pop() {
341            if pooled.last_used.elapsed() > self.config.idle_timeout {
342                // Connection is stale, drop it
343                continue;
344            }
345
346            if let Some(max_life) = self.config.max_lifetime
347                && pooled.created_at.elapsed() > max_life
348            {
349                // Connection exceeded max lifetime, recycle it
350                continue;
351            }
352
353            return Some(pooled.conn);
354        }
355
356        None
357    }
358}
359
360/// # Example
361/// ```ignore
362/// let config = PoolConfig::new("localhost", 5432, "user", "db")
363///     .password("secret")
364///     .max_connections(20);
365/// let pool = PgPool::connect(config).await?;
366/// // Get a connection from the pool
367/// let mut conn = pool.acquire().await?;
368/// conn.simple_query("SELECT 1").await?;
369/// ```
370#[derive(Clone)]
371pub struct PgPool {
372    inner: Arc<PgPoolInner>,
373}
374
375impl PgPool {
376    /// Create a pool from `qail.toml` (loads and parses automatically).
377    ///
378    /// # Example
379    /// ```ignore
380    /// let pool = PgPool::from_config().await?;
381    /// ```
382    pub async fn from_config() -> PgResult<Self> {
383        let qail = qail_core::config::QailConfig::load()
384            .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
385        let config = PoolConfig::from_qail_config(&qail)?;
386        Self::connect(config).await
387    }
388
389    /// Create a new connection pool.
390    pub async fn connect(config: PoolConfig) -> PgResult<Self> {
391        // Semaphore starts with max_connections permits
392        let semaphore = Semaphore::new(config.max_connections);
393
394        let mut initial_connections = Vec::new();
395        for _ in 0..config.min_connections {
396            let conn = Self::create_connection(&config).await?;
397            initial_connections.push(PooledConn {
398                conn,
399                created_at: Instant::now(),
400                last_used: Instant::now(),
401            });
402        }
403
404        let initial_count = initial_connections.len();
405
406        let inner = Arc::new(PgPoolInner {
407            config,
408            connections: Mutex::new(initial_connections),
409            semaphore,
410            closed: AtomicBool::new(false),
411            active_count: AtomicUsize::new(0),
412            total_created: AtomicUsize::new(initial_count),
413        });
414
415        Ok(Self { inner })
416    }
417
418    /// Acquire a connection from the pool.
419    pub async fn acquire(&self) -> PgResult<PooledConnection> {
420        if self.inner.closed.load(Ordering::Relaxed) {
421            return Err(PgError::Connection("Pool is closed".to_string()));
422        }
423
424        // Wait for available slot with timeout
425        let acquire_timeout = self.inner.config.acquire_timeout;
426        let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
427            .await
428            .map_err(|_| {
429                PgError::Connection(format!(
430                    "Timed out waiting for connection ({}s)",
431                    acquire_timeout.as_secs()
432                ))
433            })?
434            .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
435        permit.forget();
436
437        // Try to get existing healthy connection
438        let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
439            conn
440        } else {
441            let conn = Self::create_connection(&self.inner.config).await?;
442            self.inner.total_created.fetch_add(1, Ordering::Relaxed);
443            conn
444        };
445
446
447        self.inner.active_count.fetch_add(1, Ordering::Relaxed);
448
449        Ok(PooledConnection {
450            conn: Some(conn),
451            pool: self.inner.clone(),
452            rls_dirty: false,
453        })
454    }
455
456    /// Acquire a connection with RLS context pre-configured.
457    ///
458    /// Sets PostgreSQL session variables for tenant isolation before
459    /// returning the connection. When the connection is dropped, it
460    /// automatically clears the RLS context before returning to the pool.
461    ///
462    /// # Example
463    /// ```ignore
464    /// use qail_core::rls::RlsContext;
465    ///
466    /// let mut conn = pool.acquire_with_rls(
467    ///     RlsContext::operator("550e8400-e29b-41d4-a716-446655440000")
468    /// ).await?;
469    /// // All queries through `conn` are now scoped to this operator
470    /// ```
471    pub async fn acquire_with_rls(
472        &self,
473        ctx: qail_core::rls::RlsContext,
474    ) -> PgResult<PooledConnection> {
475        let mut conn = self.acquire().await?;
476
477        // Set RLS context on the raw connection
478        let sql = super::rls::context_to_sql(&ctx);
479        let pg_conn = conn.get_mut();
480        pg_conn.execute_simple(&sql).await?;
481
482        // Mark dirty so Drop resets context before pool return
483        conn.rls_dirty = true;
484
485        Ok(conn)
486    }
487
488    /// Get the current number of idle connections.
489    pub async fn idle_count(&self) -> usize {
490        self.inner.connections.lock().await.len()
491    }
492
493    /// Get the number of connections currently in use.
494    pub fn active_count(&self) -> usize {
495        self.inner.active_count.load(Ordering::Relaxed)
496    }
497
498    /// Get the maximum number of connections.
499    pub fn max_connections(&self) -> usize {
500        self.inner.config.max_connections
501    }
502
503    /// Get comprehensive pool statistics.
504    pub async fn stats(&self) -> PoolStats {
505        let idle = self.inner.connections.lock().await.len();
506        PoolStats {
507            active: self.inner.active_count.load(Ordering::Relaxed),
508            idle,
509            pending: self.inner.config.max_connections
510                - self.inner.semaphore.available_permits()
511                - self.active_count(),
512            max_size: self.inner.config.max_connections,
513            total_created: self.inner.total_created.load(Ordering::Relaxed),
514        }
515    }
516
517    /// Check if the pool is closed.
518    pub fn is_closed(&self) -> bool {
519        self.inner.closed.load(Ordering::Relaxed)
520    }
521
522    /// Close the pool gracefully.
523    pub async fn close(&self) {
524        self.inner.closed.store(true, Ordering::Relaxed);
525
526        let mut connections = self.inner.connections.lock().await;
527        connections.clear();
528    }
529
530    /// Create a new connection using the pool configuration.
531    async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
532        match &config.password {
533            Some(password) => {
534                PgConnection::connect_with_password(
535                    &config.host,
536                    config.port,
537                    &config.user,
538                    &config.database,
539                    Some(password),
540                )
541                .await
542            }
543            None => {
544                PgConnection::connect(&config.host, config.port, &config.user, &config.database)
545                    .await
546            }
547        }
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder chain should land every value in the matching field.
    #[test]
    fn test_pool_config() {
        let cfg = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        let PoolConfig {
            host,
            port,
            user,
            database,
            password,
            max_connections,
            min_connections,
            ..
        } = cfg;

        assert_eq!((host.as_str(), port), ("localhost", 5432));
        assert_eq!((user.as_str(), database.as_str()), ("user", "testdb"));
        assert_eq!(password.as_deref(), Some("secret123"));
        assert_eq!((max_connections, min_connections), (20, 5));
    }
}