Skip to main content

qail_pg/driver/
pool.rs

1//! PostgreSQL Connection Pool
2//!
3//! Provides connection pooling for efficient resource management.
4//! Connections are reused across queries to avoid reconnection overhead.
5
6use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Connection and tuning settings for a PostgreSQL connection pool.
///
/// Start from [`PoolConfig::new`] and adjust with the builder methods:
///
/// ```ignore
/// use std::time::Duration;
/// use qail_pg::driver::pool::PoolConfig;
/// let config = PoolConfig::new("localhost", 5432, "app", "mydb")
///     .password("secret")
///     .max_connections(20)
///     .acquire_timeout(Duration::from_secs(5));
/// ```
#[derive(Clone)]
pub struct PoolConfig {
    /// Hostname or IP address of the PostgreSQL server.
    pub host: String,
    /// TCP port of the server (conventionally 5432).
    pub port: u16,
    /// Role / user name to authenticate as.
    pub user: String,
    /// Name of the database to connect to.
    pub database: String,
    /// Password for authentication, if the server requires one.
    pub password: Option<String>,
    /// Upper bound on connections open at the same time (`new` uses 10).
    pub max_connections: usize,
    /// Number of idle connections the pool aims to keep warm (`new` uses 1).
    pub min_connections: usize,
    /// Idle connections older than this are closed (`new` uses 10 minutes).
    pub idle_timeout: Duration,
    /// Longest a caller may wait to acquire a connection (`new` uses 30s).
    pub acquire_timeout: Duration,
    /// Longest a new TCP connection may take to establish (`new` uses 10s).
    pub connect_timeout: Duration,
    /// Optional cap on how long any single connection may live.
    pub max_lifetime: Option<Duration>,
    /// If `true`, a health check (`SELECT 1`) runs before a connection is handed out.
    pub test_on_acquire: bool,
}
51
52impl PoolConfig {
53    /// Create a new pool configuration with sensible defaults.
54    ///
55    /// # Arguments
56    ///
57    /// * `host` — PostgreSQL server hostname or IP.
58    /// * `port` — TCP port (typically 5432).
59    /// * `user` — PostgreSQL role name.
60    /// * `database` — Target database name.
61    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
62        Self {
63            host: host.to_string(),
64            port,
65            user: user.to_string(),
66            database: database.to_string(),
67            password: None,
68            max_connections: 10,
69            min_connections: 1,
70            idle_timeout: Duration::from_secs(600), // 10 minutes
71            acquire_timeout: Duration::from_secs(30), // 30 seconds
72            connect_timeout: Duration::from_secs(10), // 10 seconds
73            max_lifetime: None,                      // No limit by default
74            test_on_acquire: false,                  // Disabled by default for performance
75        }
76    }
77
78    /// Set password for authentication.
79    pub fn password(mut self, password: &str) -> Self {
80        self.password = Some(password.to_string());
81        self
82    }
83
84    /// Set maximum simultaneous connections.
85    pub fn max_connections(mut self, max: usize) -> Self {
86        self.max_connections = max;
87        self
88    }
89
90    /// Set minimum idle connections.
91    pub fn min_connections(mut self, min: usize) -> Self {
92        self.min_connections = min;
93        self
94    }
95
96    /// Set idle timeout (connections idle longer than this are closed).
97    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
98        self.idle_timeout = timeout;
99        self
100    }
101
102    /// Set acquire timeout (max wait time when getting a connection).
103    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
104        self.acquire_timeout = timeout;
105        self
106    }
107
108    /// Set connect timeout (max time to establish new connection).
109    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
110        self.connect_timeout = timeout;
111        self
112    }
113
114    /// Set maximum lifetime of a connection before recycling.
115    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
116        self.max_lifetime = Some(lifetime);
117        self
118    }
119
120    /// Enable connection validation on acquire.
121    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
122        self.test_on_acquire = enabled;
123        self
124    }
125
126    /// Create a `PoolConfig` from a centralized `QailConfig`.
127    ///
128    /// Parses `postgres.url` for host/port/user/database/password
129    /// and applies pool tuning from `[postgres]` section.
130    pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
131        let pg = &qail.postgres;
132        let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
133
134        let mut config = PoolConfig::new(&host, port, &user, &database)
135            .max_connections(pg.max_connections)
136            .min_connections(pg.min_connections)
137            .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
138            .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
139            .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
140            .test_on_acquire(pg.test_on_acquire);
141
142        if let Some(ref pw) = password {
143            config = config.password(pw);
144        }
145
146        Ok(config)
147    }
148}
149
150/// Parse a postgres URL into (host, port, user, database, password).
151fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
152    let url = url.trim_start_matches("postgres://").trim_start_matches("postgresql://");
153
154    let (credentials, host_part) = if url.contains('@') {
155        let mut parts = url.splitn(2, '@');
156        let creds = parts.next().unwrap_or("");
157        let host = parts.next().unwrap_or("localhost/postgres");
158        (Some(creds), host)
159    } else {
160        (None, url)
161    };
162
163    let (host_port, database) = if host_part.contains('/') {
164        let mut parts = host_part.splitn(2, '/');
165        (parts.next().unwrap_or("localhost"), parts.next().unwrap_or("postgres").to_string())
166    } else {
167        (host_part, "postgres".to_string())
168    };
169
170    let (host, port) = if host_port.contains(':') {
171        let mut parts = host_port.split(':');
172        let h = parts.next().unwrap_or("localhost").to_string();
173        let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
174        (h, p)
175    } else {
176        (host_port.to_string(), 5432u16)
177    };
178
179    let (user, password) = if let Some(creds) = credentials {
180        if creds.contains(':') {
181            let mut parts = creds.splitn(2, ':');
182            let u = parts.next().unwrap_or("postgres").to_string();
183            let p = parts.next().map(|s| s.to_string());
184            (u, p)
185        } else {
186            (creds.to_string(), None)
187        }
188    } else {
189        ("postgres".to_string(), None)
190    };
191
192    Ok((host, port, user, database, password))
193}
194
/// Snapshot of pool occupancy, intended for monitoring and metrics.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of connections currently handed out to callers.
    pub active: usize,
    /// Number of connections sitting idle in the pool, ready for reuse.
    pub idle: usize,
    /// Number of callers currently waiting for a connection.
    pub pending: usize,
    /// Configured maximum number of connections.
    pub max_size: usize,
    /// Total connections opened since the pool was created.
    pub total_created: usize,
}
209
210/// A pooled connection with creation timestamp for idle tracking.
211struct PooledConn {
212    conn: PgConnection,
213    created_at: Instant,
214    last_used: Instant,
215}
216
217/// A pooled connection that returns to the pool when dropped.
218///
219/// When `rls_dirty` is true (set by `acquire_with_rls`), the connection
220/// will automatically reset RLS session variables before returning to
221/// the pool. This prevents cross-tenant data leakage.
222pub struct PooledConnection {
223    conn: Option<PgConnection>,
224    pool: Arc<PgPoolInner>,
225    rls_dirty: bool,
226}
227
228impl PooledConnection {
229    /// Get a reference to the underlying connection, returning an error
230    /// if the connection has already been released.
231    fn conn_ref(&self) -> PgResult<&PgConnection> {
232        self.conn.as_ref().ok_or_else(|| PgError::Connection(
233            "Connection already released back to pool".into()
234        ))
235    }
236
237    /// Get a mutable reference to the underlying connection, returning an error
238    /// if the connection has already been released.
239    fn conn_mut(&mut self) -> PgResult<&mut PgConnection> {
240        self.conn.as_mut().ok_or_else(|| PgError::Connection(
241            "Connection already released back to pool".into()
242        ))
243    }
244
245    /// Get a mutable reference to the underlying connection.
246    /// Panics if the connection has been released (use `conn_mut()` for fallible access).
247    pub fn get_mut(&mut self) -> &mut PgConnection {
248        // SAFETY: Connection is always Some while PooledConnection is in use.
249        // Only becomes None after release() or Drop, after which no methods should be called.
250        self.conn
251            .as_mut()
252            .expect("Connection should always be present")
253    }
254
255    /// Get a token to cancel the currently running query.
256    pub fn cancel_token(&self) -> PgResult<crate::driver::CancelToken> {
257        let conn = self.conn_ref()?;
258        let (process_id, secret_key) = conn.get_cancel_key();
259        Ok(crate::driver::CancelToken {
260            host: self.pool.config.host.clone(),
261            port: self.pool.config.port,
262            process_id,
263            secret_key,
264        })
265    }
266
267    /// Deterministic connection cleanup and pool return.
268    ///
269    /// This is the **correct** way to return a connection to the pool.
270    /// COMMITs the transaction (which auto-resets transaction-local RLS
271    /// session variables) and returns the connection to the pool with
272    /// prepared statement caches intact.
273    ///
274    /// If cleanup fails, the connection is destroyed (not returned to pool).
275    ///
276    /// # Usage
277    /// ```ignore
278    /// let mut conn = pool.acquire_with_rls(ctx).await?;
279    /// let result = conn.fetch_all_cached(&cmd).await;
280    /// conn.release().await; // COMMIT + return to pool
281    /// result
282    /// ```
283    pub async fn release(mut self) {
284        if let Some(mut conn) = self.conn.take() {
285            // COMMIT the transaction opened by acquire_with_rls.
286            // Transaction-local set_config values auto-reset on COMMIT,
287            // so no explicit RLS cleanup is needed.
288            // Prepared statements survive — they are NOT transaction-scoped.
289            if let Err(e) = conn.execute_simple(super::rls::reset_sql()).await {
290                eprintln!(
291                    "[CRITICAL] pool_release_failed: COMMIT failed — \
292                     dropping connection to prevent state leak: {}",
293                    e
294                );
295                return; // Connection destroyed — not returned to pool
296            }
297
298            self.pool.return_connection(conn).await;
299        }
300    }
301
302    /// Execute a QAIL command and fetch all rows (UNCACHED).
303    /// Returns rows with column metadata for JSON serialization.
304    pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
305        use crate::protocol::AstEncoder;
306        use super::ColumnInfo;
307
308        let conn = self.conn_mut()?;
309
310        let wire_bytes = AstEncoder::encode_cmd_reuse(
311            cmd,
312            &mut conn.sql_buf,
313            &mut conn.params_buf,
314        )
315        .map_err(|e| PgError::Encode(e.to_string()))?;
316
317        conn.send_bytes(&wire_bytes).await?;
318
319        let mut rows: Vec<super::PgRow> = Vec::new();
320        let mut column_info: Option<Arc<ColumnInfo>> = None;
321        let mut error: Option<PgError> = None;
322
323        loop {
324            let msg = conn.recv().await?;
325            match msg {
326                crate::protocol::BackendMessage::ParseComplete
327                | crate::protocol::BackendMessage::BindComplete => {}
328                crate::protocol::BackendMessage::RowDescription(fields) => {
329                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
330                }
331                crate::protocol::BackendMessage::DataRow(data) => {
332                    if error.is_none() {
333                        rows.push(super::PgRow {
334                            columns: data,
335                            column_info: column_info.clone(),
336                        });
337                    }
338                }
339                crate::protocol::BackendMessage::CommandComplete(_) => {}
340                crate::protocol::BackendMessage::ReadyForQuery(_) => {
341                    if let Some(err) = error {
342                        return Err(err);
343                    }
344                    return Ok(rows);
345                }
346                crate::protocol::BackendMessage::ErrorResponse(err) => {
347                    if error.is_none() {
348                        error = Some(PgError::Query(err.message));
349                    }
350                }
351                _ => {}
352            }
353        }
354    }
355
356    /// Execute a QAIL command and fetch all rows (FAST VERSION).
357    /// Uses native AST-to-wire encoding and optimized recv_with_data_fast.
358    /// Skips column metadata for maximum speed.
359    pub async fn fetch_all_fast(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
360        use crate::protocol::AstEncoder;
361
362        let conn = self.conn_mut()?;
363
364        AstEncoder::encode_cmd_reuse_into(
365            cmd,
366            &mut conn.sql_buf,
367            &mut conn.params_buf,
368            &mut conn.write_buf,
369        )
370        .map_err(|e| PgError::Encode(e.to_string()))?;
371
372        conn.flush_write_buf().await?;
373
374        let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
375        let mut error: Option<PgError> = None;
376
377        loop {
378            let res = conn.recv_with_data_fast().await;
379            match res {
380                Ok((msg_type, data)) => {
381                    match msg_type {
382                        b'D' => {
383                            if error.is_none() && let Some(columns) = data {
384                                rows.push(super::PgRow {
385                                    columns,
386                                    column_info: None,
387                                });
388                            }
389                        }
390                        b'Z' => {
391                            if let Some(err) = error {
392                                return Err(err);
393                            }
394                            return Ok(rows);
395                        }
396                        _ => {}
397                    }
398                }
399                Err(e) => {
400                    if error.is_none() {
401                        error = Some(e);
402                    }
403                }
404            }
405        }
406    }
407
408    /// Execute a QAIL command and fetch all rows (CACHED).
409    /// Uses prepared statement caching: Parse+Describe on first call,
410    /// then Bind+Execute only on subsequent calls with the same SQL shape.
411    /// This matches PostgREST's behavior for fair benchmarks.
412    pub async fn fetch_all_cached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
413        use super::ColumnInfo;
414        use std::collections::hash_map::DefaultHasher;
415        use std::hash::{Hash, Hasher};
416
417        let conn = self.conn.as_mut().ok_or_else(|| PgError::Connection(
418            "Connection already released back to pool".into()
419        ))?;
420
421        conn.sql_buf.clear();
422        conn.params_buf.clear();
423
424        // Encode SQL + params to reusable buffers
425        match cmd.action {
426            qail_core::ast::Action::Get | qail_core::ast::Action::With => {
427                crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
428            }
429            qail_core::ast::Action::Add => {
430                crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
431            }
432            qail_core::ast::Action::Set => {
433                crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
434            }
435            qail_core::ast::Action::Del => {
436                crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
437            }
438            _ => {
439                // Fallback: unsupported actions go through uncached path
440                return self.fetch_all_uncached(cmd).await;
441            }
442        }
443
444        let mut hasher = DefaultHasher::new();
445        conn.sql_buf.hash(&mut hasher);
446        let sql_hash = hasher.finish();
447
448        let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
449
450        conn.write_buf.clear();
451
452        let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
453            name.clone()
454        } else {
455            let name = format!("qail_{:x}", sql_hash);
456
457            conn.evict_prepared_if_full();
458
459            let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
460
461            use crate::protocol::PgEncoder;
462            let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
463            let describe_msg = PgEncoder::encode_describe(false, &name);
464            conn.write_buf.extend_from_slice(&parse_msg);
465            conn.write_buf.extend_from_slice(&describe_msg);
466
467            conn.stmt_cache.put(sql_hash, name.clone());
468            conn.prepared_statements.insert(name.clone(), sql_str.to_string());
469
470            // Register in global hot-statement registry for cross-connection sharing
471            if let Ok(mut hot) = self.pool.hot_statements.write()
472                && hot.len() < MAX_HOT_STATEMENTS
473            {
474                hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
475            }
476
477            name
478        };
479
480        use crate::protocol::PgEncoder;
481        PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
482            .map_err(|e| PgError::Encode(e.to_string()))?;
483        PgEncoder::encode_execute_to(&mut conn.write_buf);
484        PgEncoder::encode_sync_to(&mut conn.write_buf);
485
486        conn.flush_write_buf().await?;
487
488        let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
489
490        let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
491        let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
492        let mut error: Option<PgError> = None;
493
494        loop {
495            let msg = conn.recv().await?;
496            match msg {
497                crate::protocol::BackendMessage::ParseComplete
498                | crate::protocol::BackendMessage::BindComplete => {}
499                crate::protocol::BackendMessage::ParameterDescription(_) => {}
500                crate::protocol::BackendMessage::RowDescription(fields) => {
501                    let info = Arc::new(ColumnInfo::from_fields(&fields));
502                    if is_cache_miss {
503                        conn.column_info_cache.insert(sql_hash, info.clone());
504                    }
505                    column_info = Some(info);
506                }
507                crate::protocol::BackendMessage::DataRow(data) => {
508                    if error.is_none() {
509                        rows.push(super::PgRow {
510                            columns: data,
511                            column_info: column_info.clone(),
512                        });
513                    }
514                }
515                crate::protocol::BackendMessage::CommandComplete(_) => {}
516                crate::protocol::BackendMessage::ReadyForQuery(_) => {
517                    if let Some(err) = error {
518                        return Err(err);
519                    }
520                    return Ok(rows);
521                }
522                crate::protocol::BackendMessage::ErrorResponse(err) => {
523                    if error.is_none() {
524                        error = Some(PgError::Query(err.message));
525                    }
526                }
527                _ => {}
528            }
529        }
530    }
531
532    /// Execute a QAIL command with RLS context in a SINGLE roundtrip.
533    ///
534    /// Pipelines the RLS setup (BEGIN + set_config) and the query
535    /// (Parse/Bind/Execute/Sync) into one `write_all` syscall.
536    /// PG processes messages in order, so the BEGIN + set_config
537    /// completes before the query executes — security is preserved.
538    ///
539    /// Wire layout:
540    /// ```text
541    /// [SimpleQuery: "BEGIN; SET LOCAL...; SELECT set_config(...)"]
542    /// [Parse (if cache miss)]
543    /// [Describe (if cache miss)]
544    /// [Bind]
545    /// [Execute]
546    /// [Sync]
547    /// ```
548    ///
549    /// Response processing: consume 2× ReadyForQuery (SimpleQuery + Sync).
550    pub async fn fetch_all_with_rls(
551        &mut self,
552        cmd: &qail_core::ast::Qail,
553        rls_sql: &str,
554    ) -> PgResult<Vec<super::PgRow>> {
555        use super::ColumnInfo;
556        use std::collections::hash_map::DefaultHasher;
557        use std::hash::{Hash, Hasher};
558
559        let conn = self.conn.as_mut().ok_or_else(|| PgError::Connection(
560            "Connection already released back to pool".into()
561        ))?;
562
563        conn.sql_buf.clear();
564        conn.params_buf.clear();
565
566        // Encode SQL + params to reusable buffers
567        match cmd.action {
568            qail_core::ast::Action::Get | qail_core::ast::Action::With => {
569                crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
570            }
571            qail_core::ast::Action::Add => {
572                crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
573            }
574            qail_core::ast::Action::Set => {
575                crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
576            }
577            qail_core::ast::Action::Del => {
578                crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
579            }
580            _ => {
581                // Fallback: RLS setup must happen synchronously for unsupported actions
582                conn.execute_simple(rls_sql).await?;
583                self.rls_dirty = true;
584                return self.fetch_all_uncached(cmd).await;
585            }
586        }
587
588        let mut hasher = DefaultHasher::new();
589        conn.sql_buf.hash(&mut hasher);
590        let sql_hash = hasher.finish();
591
592        let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
593
594        conn.write_buf.clear();
595
596        // ── Prepend RLS Simple Query message ─────────────────────────
597        // This is the key optimization: RLS setup bytes go first in the
598        // same buffer as the query messages.
599        let rls_msg = crate::protocol::PgEncoder::encode_query_string(rls_sql);
600        conn.write_buf.extend_from_slice(&rls_msg);
601
602        // ── Then append the query messages (same as fetch_all_cached) ──
603        let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
604            name.clone()
605        } else {
606            let name = format!("qail_{:x}", sql_hash);
607
608            conn.evict_prepared_if_full();
609
610            let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
611
612            use crate::protocol::PgEncoder;
613            let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
614            let describe_msg = PgEncoder::encode_describe(false, &name);
615            conn.write_buf.extend_from_slice(&parse_msg);
616            conn.write_buf.extend_from_slice(&describe_msg);
617
618            conn.stmt_cache.put(sql_hash, name.clone());
619            conn.prepared_statements.insert(name.clone(), sql_str.to_string());
620
621            if let Ok(mut hot) = self.pool.hot_statements.write()
622                && hot.len() < MAX_HOT_STATEMENTS
623            {
624                hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
625            }
626
627            name
628        };
629
630        use crate::protocol::PgEncoder;
631        PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
632            .map_err(|e| PgError::Encode(e.to_string()))?;
633        PgEncoder::encode_execute_to(&mut conn.write_buf);
634        PgEncoder::encode_sync_to(&mut conn.write_buf);
635
636        // ── Single write_all for RLS + Query ────────────────────────
637        conn.flush_write_buf().await?;
638
639        // Mark connection as RLS-dirty (needs COMMIT on release)
640        self.rls_dirty = true;
641
642        // ── Phase 1: Consume Simple Query responses (RLS setup) ─────
643        // Simple Query produces: CommandComplete × N, then ReadyForQuery.
644        // set_config results and BEGIN/SET LOCAL responses are all here.
645        let mut rls_error: Option<PgError> = None;
646        loop {
647            let msg = conn.recv().await?;
648            match msg {
649                crate::protocol::BackendMessage::ReadyForQuery(_) => {
650                    // RLS setup done — break to Extended Query phase
651                    if let Some(err) = rls_error {
652                        return Err(err);
653                    }
654                    break;
655                }
656                crate::protocol::BackendMessage::ErrorResponse(err) => {
657                    if rls_error.is_none() {
658                        rls_error = Some(PgError::Query(err.message));
659                    }
660                }
661                // CommandComplete, DataRow (from set_config), RowDescription — ignore
662                _ => {}
663            }
664        }
665
666        // ── Phase 2: Consume Extended Query responses (actual data) ──
667        let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
668
669        let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
670        let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
671        let mut error: Option<PgError> = None;
672
673        loop {
674            let msg = conn.recv().await?;
675            match msg {
676                crate::protocol::BackendMessage::ParseComplete
677                | crate::protocol::BackendMessage::BindComplete => {}
678                crate::protocol::BackendMessage::ParameterDescription(_) => {}
679                crate::protocol::BackendMessage::RowDescription(fields) => {
680                    let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
681                    if is_cache_miss {
682                        conn.column_info_cache.insert(sql_hash, info.clone());
683                    }
684                    column_info = Some(info);
685                }
686                crate::protocol::BackendMessage::DataRow(data) => {
687                    if error.is_none() {
688                        rows.push(super::PgRow {
689                            columns: data,
690                            column_info: column_info.clone(),
691                        });
692                    }
693                }
694                crate::protocol::BackendMessage::CommandComplete(_) => {}
695                crate::protocol::BackendMessage::ReadyForQuery(_) => {
696                    if let Some(err) = error {
697                        return Err(err);
698                    }
699                    return Ok(rows);
700                }
701                crate::protocol::BackendMessage::ErrorResponse(err) => {
702                    if error.is_none() {
703                        error = Some(PgError::Query(err.message));
704                    }
705                }
706                _ => {}
707            }
708        }
709    }
710
711    /// Execute multiple QAIL commands in a single PG pipeline round-trip.
712    ///
713    /// Sends all queries as Parse+Bind+Execute in one write, receives all
714    /// responses in one read. Returns raw column data per query per row.
715    ///
716    /// This is the fastest path for batch operations — amortizes TCP
717    /// overhead across N queries into a single syscall pair.
718    pub async fn pipeline_ast(
719        &mut self,
720        cmds: &[qail_core::ast::Qail],
721    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
722        let conn = self.conn_mut()?;
723        conn.pipeline_ast(cmds).await
724    }
725
726    /// Run `EXPLAIN (FORMAT JSON)` on a Qail command and return cost estimates.
727    ///
728    /// Uses `simple_query` under the hood — no additional round-trips beyond
729    /// the single EXPLAIN statement. Returns `None` if parsing fails or
730    /// the EXPLAIN output is unexpected.
731    pub async fn explain_estimate(
732        &mut self,
733        cmd: &qail_core::ast::Qail,
734    ) -> PgResult<Option<super::explain::ExplainEstimate>> {
735        use qail_core::transpiler::ToSql;
736
737        let sql = cmd.to_sql();
738        let explain_sql = format!("EXPLAIN (FORMAT JSON) {}", sql);
739
740        let rows = self.simple_query(&explain_sql).await?;
741
742        // PostgreSQL returns the JSON plan as a single text column across one or more rows
743        let mut json_output = String::new();
744        for row in &rows {
745            if let Some(Some(val)) = row.columns.first()
746                && let Ok(text) = std::str::from_utf8(val)
747            {
748                json_output.push_str(text);
749            }
750        }
751
752        Ok(super::explain::parse_explain_json(&json_output))
753    }
754}
755
756impl Drop for PooledConnection {
757    fn drop(&mut self) {
758        if self.conn.is_some() {
759            // Safety net: connection was NOT released via `release()`.
760            // This happens when:
761            //   - Handler panicked
762            //   - Early return without calling release()
763            //   - Missed release() call (programming error)
764            //
765            // We DESTROY the connection (don't return to pool) to prevent
766            // dirty session state from being reused. This costs a pool slot
767            // but guarantees no cross-tenant leakage.
768            //
769            // The `conn` field is dropped here, closing the TCP socket.
770            eprintln!(
771                "[WARN] pool_connection_leaked: PooledConnection dropped without release() — \
772                 connection destroyed to prevent state leak (rls_dirty={}). \
773                 Use conn.release().await for deterministic cleanup.",
774                self.rls_dirty
775            );
776            // Decrement active count so pool can create a replacement
777            self.pool.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
778        }
779    }
780}
781
782impl std::ops::Deref for PooledConnection {
783    type Target = PgConnection;
784
785    fn deref(&self) -> &Self::Target {
786        // SAFETY: Connection is always Some while PooledConnection is alive and in use.
787        // Only becomes None after release() consumes self, or during Drop.
788        self.conn
789            .as_ref()
790            .expect("PooledConnection::deref called after release — this is a bug")
791    }
792}
793
794impl std::ops::DerefMut for PooledConnection {
795    fn deref_mut(&mut self) -> &mut Self::Target {
796        // SAFETY: Connection is always Some while PooledConnection is alive and in use.
797        // Only becomes None after release() consumes self, or during Drop.
798        self.conn
799            .as_mut()
800            .expect("PooledConnection::deref_mut called after release — this is a bug")
801    }
802}
803
/// Upper bound on how many "hot" prepared statements the pool tracks
/// globally (the pool-wide hot-statement registry).
const MAX_HOT_STATEMENTS: usize = 32;
806
807/// Inner pool state (shared across clones).
808struct PgPoolInner {
809    config: PoolConfig,
810    connections: Mutex<Vec<PooledConn>>,
811    semaphore: Semaphore,
812    closed: AtomicBool,
813    active_count: AtomicUsize,
814    total_created: AtomicUsize,
815    /// Global registry of frequently-used prepared statements.
816    /// Maps sql_hash → (stmt_name, sql_text).
817    /// New connections pre-prepare these on checkout for instant cache hits.
818    hot_statements: std::sync::RwLock<std::collections::HashMap<u64, (String, String)>>,
819}
820
821impl PgPoolInner {
822    async fn return_connection(&self, conn: PgConnection) {
823
824        self.active_count.fetch_sub(1, Ordering::Relaxed);
825        
826
827        if self.closed.load(Ordering::Relaxed) {
828            return;
829        }
830        
831        let mut connections = self.connections.lock().await;
832        if connections.len() < self.config.max_connections {
833            connections.push(PooledConn {
834                conn,
835                created_at: Instant::now(),
836                last_used: Instant::now(),
837            });
838        }
839
840        self.semaphore.add_permits(1);
841    }
842
843    /// Get a healthy connection from the pool, or None if pool is empty.
844    async fn get_healthy_connection(&self) -> Option<PgConnection> {
845        let mut connections = self.connections.lock().await;
846
847        while let Some(pooled) = connections.pop() {
848            if pooled.last_used.elapsed() > self.config.idle_timeout {
849                // Connection is stale, drop it
850                continue;
851            }
852
853            if let Some(max_life) = self.config.max_lifetime
854                && pooled.created_at.elapsed() > max_life
855            {
856                // Connection exceeded max lifetime, recycle it
857                continue;
858            }
859
860            return Some(pooled.conn);
861        }
862
863        None
864    }
865}
866
867/// # Example
868/// ```ignore
869/// let config = PoolConfig::new("localhost", 5432, "user", "db")
870///     .password("secret")
871///     .max_connections(20);
872/// let pool = PgPool::connect(config).await?;
873/// // Get a connection from the pool
874/// let mut conn = pool.acquire_raw().await?;
875/// conn.simple_query("SELECT 1").await?;
876/// ```
877#[derive(Clone)]
878pub struct PgPool {
879    inner: Arc<PgPoolInner>,
880}
881
882impl PgPool {
883    /// Create a pool from `qail.toml` (loads and parses automatically).
884    ///
885    /// # Example
886    /// ```ignore
887    /// let pool = PgPool::from_config().await?;
888    /// ```
889    pub async fn from_config() -> PgResult<Self> {
890        let qail = qail_core::config::QailConfig::load()
891            .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
892        let config = PoolConfig::from_qail_config(&qail)?;
893        Self::connect(config).await
894    }
895
896    /// Create a new connection pool.
897    pub async fn connect(config: PoolConfig) -> PgResult<Self> {
898        // Semaphore starts with max_connections permits
899        let semaphore = Semaphore::new(config.max_connections);
900
901        let mut initial_connections = Vec::new();
902        for _ in 0..config.min_connections {
903            let conn = Self::create_connection(&config).await?;
904            initial_connections.push(PooledConn {
905                conn,
906                created_at: Instant::now(),
907                last_used: Instant::now(),
908            });
909        }
910
911        let initial_count = initial_connections.len();
912
913        let inner = Arc::new(PgPoolInner {
914            config,
915            connections: Mutex::new(initial_connections),
916            semaphore,
917            closed: AtomicBool::new(false),
918            active_count: AtomicUsize::new(0),
919            total_created: AtomicUsize::new(initial_count),
920            hot_statements: std::sync::RwLock::new(std::collections::HashMap::new()),
921        });
922
923        Ok(Self { inner })
924    }
925
926    /// Acquire a raw connection from the pool (crate-internal only).
927    ///
928    /// # Safety (not `unsafe` in the Rust sense, but security-critical)
929    ///
930    /// This returns a connection with **no RLS context**. All tenant data
931    /// queries on this connection will bypass row-level security.
932    ///
933    /// **Safe usage**: Pair with `fetch_all_with_rls()` for pipelined
934    /// RLS+query execution (single roundtrip). Or use `acquire_with_rls()`
935    /// / `acquire_with_rls_timeout()` for the 2-roundtrip path.
936    ///
937    /// **Unsafe usage**: Running queries directly on a raw connection
938    /// without RLS context. Every call site MUST include a `// SAFETY:`
939    /// comment explaining why raw acquisition is justified.
940    pub async fn acquire_raw(&self) -> PgResult<PooledConnection> {
941        if self.inner.closed.load(Ordering::Relaxed) {
942            return Err(PgError::PoolClosed);
943        }
944
945        // Wait for available slot with timeout
946        let acquire_timeout = self.inner.config.acquire_timeout;
947        let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
948            .await
949            .map_err(|_| {
950                PgError::Timeout(format!(
951                    "pool acquire after {}s ({} max connections)",
952                    acquire_timeout.as_secs(),
953                    self.inner.config.max_connections
954                ))
955            })?
956            .map_err(|_| PgError::PoolClosed)?;
957        permit.forget();
958
959        // Try to get existing healthy connection
960        let mut conn = if let Some(conn) = self.inner.get_healthy_connection().await {
961            conn
962        } else {
963            let conn = Self::create_connection(&self.inner.config).await?;
964            self.inner.total_created.fetch_add(1, Ordering::Relaxed);
965            conn
966        };
967
968        // Pre-prepare hot statements that this connection doesn't have yet.
969        // Collect data synchronously (guard dropped before async work).
970        let missing: Vec<(u64, String, String)> = {
971            if let Ok(hot) = self.inner.hot_statements.read() {
972                hot.iter()
973                    .filter(|(hash, _)| !conn.stmt_cache.contains(hash))
974                    .map(|(hash, (name, sql))| (*hash, name.clone(), sql.clone()))
975                    .collect()
976            } else {
977                Vec::new()
978            }
979        }; // RwLockReadGuard dropped here — safe across .await
980
981        if !missing.is_empty() {
982            use crate::protocol::PgEncoder;
983            let mut buf = bytes::BytesMut::new();
984            for (_, name, sql) in &missing {
985                let parse_msg = PgEncoder::encode_parse(name, sql, &[]);
986                buf.extend_from_slice(&parse_msg);
987            }
988            PgEncoder::encode_sync_to(&mut buf);
989            if conn.send_bytes(&buf).await.is_ok() {
990                // Drain responses (ParseComplete + ReadyForQuery)
991                loop {
992                    match conn.recv().await {
993                        Ok(crate::protocol::BackendMessage::ReadyForQuery(_)) => break,
994                        Ok(_) => continue,
995                        Err(_) => break,
996                    }
997                }
998                // Register in local cache
999                for (hash, name, sql) in &missing {
1000                    conn.stmt_cache.put(*hash, name.clone());
1001                    conn.prepared_statements.insert(name.clone(), sql.clone());
1002                }
1003            }
1004        }
1005
1006        self.inner.active_count.fetch_add(1, Ordering::Relaxed);
1007
1008        Ok(PooledConnection {
1009            conn: Some(conn),
1010            pool: self.inner.clone(),
1011            rls_dirty: false,
1012        })
1013    }
1014
1015    /// Acquire a connection with RLS context pre-configured.
1016    ///
1017    /// Sets PostgreSQL session variables for tenant isolation before
1018    /// returning the connection. When the connection is dropped, it
1019    /// automatically clears the RLS context before returning to the pool.
1020    ///
1021    /// # Example
1022    /// ```ignore
1023    /// use qail_core::rls::RlsContext;
1024    ///
1025    /// let mut conn = pool.acquire_with_rls(
1026    ///     RlsContext::operator("550e8400-e29b-41d4-a716-446655440000")
1027    /// ).await?;
1028    /// // All queries through `conn` are now scoped to this operator
1029    /// ```
1030    pub async fn acquire_with_rls(
1031        &self,
1032        ctx: qail_core::rls::RlsContext,
1033    ) -> PgResult<PooledConnection> {
1034        // SAFETY: RLS context is set immediately below via context_to_sql().
1035        let mut conn = self.acquire_raw().await?;
1036
1037        // Set RLS context on the raw connection
1038        let sql = super::rls::context_to_sql(&ctx);
1039        let pg_conn = conn.get_mut();
1040        pg_conn.execute_simple(&sql).await?;
1041
1042        // Mark dirty so Drop resets context before pool return
1043        conn.rls_dirty = true;
1044
1045        Ok(conn)
1046    }
1047
1048    /// Acquire a connection with RLS context AND statement timeout.
1049    ///
1050    /// Like `acquire_with_rls()`, but also sets `statement_timeout` to prevent
1051    /// runaway queries from holding pool connections indefinitely.
1052    pub async fn acquire_with_rls_timeout(
1053        &self,
1054        ctx: qail_core::rls::RlsContext,
1055        timeout_ms: u32,
1056    ) -> PgResult<PooledConnection> {
1057        // SAFETY: RLS context + timeout set immediately below via context_to_sql_with_timeout().
1058        let mut conn = self.acquire_raw().await?;
1059
1060        // Set RLS context + statement_timeout atomically
1061        let sql = super::rls::context_to_sql_with_timeout(&ctx, timeout_ms);
1062        let pg_conn = conn.get_mut();
1063        pg_conn.execute_simple(&sql).await?;
1064
1065        // Mark dirty so Drop resets context + timeout before pool return
1066        conn.rls_dirty = true;
1067
1068        Ok(conn)
1069    }
1070
1071    /// Acquire a connection for system-level operations (no tenant context).
1072    ///
1073    /// Sets RLS session variables to maximally restrictive values:
1074    /// - `app.current_operator_id = ''`
1075    /// - `app.current_agent_id = ''`  
1076    /// - `app.is_super_admin = false`
1077    ///
1078    /// Use this for startup introspection, migrations, and health checks
1079    /// that must not operate within any tenant scope.
1080    pub async fn acquire_system(&self) -> PgResult<PooledConnection> {
1081        let ctx = qail_core::rls::RlsContext::empty();
1082        self.acquire_with_rls(ctx).await
1083    }
1084
1085    /// Acquire a connection with branch context pre-configured.
1086    ///
1087    /// Sets PostgreSQL session variable `app.branch_id` for data virtualization.
1088    /// When the connection is dropped, it automatically clears the branch context.
1089    ///
1090    /// # Example
1091    /// ```ignore
1092    /// use qail_core::branch::BranchContext;
1093    ///
1094    /// let ctx = BranchContext::branch("feature-auth");
1095    /// let mut conn = pool.acquire_with_branch(&ctx).await?;
1096    /// // All queries through `conn` are now branch-aware
1097    /// ```
1098    pub async fn acquire_with_branch(
1099        &self,
1100        ctx: &qail_core::branch::BranchContext,
1101    ) -> PgResult<PooledConnection> {
1102        // SAFETY: Branch context is set immediately below via branch_context_sql().
1103        let mut conn = self.acquire_raw().await?;
1104
1105        if let Some(branch_name) = ctx.branch_name() {
1106            let sql = super::branch_sql::branch_context_sql(branch_name);
1107            let pg_conn = conn.get_mut();
1108            pg_conn.execute_simple(&sql).await?;
1109            conn.rls_dirty = true; // Reuse dirty flag for auto-reset
1110        }
1111
1112        Ok(conn)
1113    }
1114
1115    /// Get the current number of idle connections.
1116    pub async fn idle_count(&self) -> usize {
1117        self.inner.connections.lock().await.len()
1118    }
1119
1120    /// Get the number of connections currently in use.
1121    pub fn active_count(&self) -> usize {
1122        self.inner.active_count.load(Ordering::Relaxed)
1123    }
1124
1125    /// Get the maximum number of connections.
1126    pub fn max_connections(&self) -> usize {
1127        self.inner.config.max_connections
1128    }
1129
1130    /// Get comprehensive pool statistics.
1131    pub async fn stats(&self) -> PoolStats {
1132        let idle = self.inner.connections.lock().await.len();
1133        PoolStats {
1134            active: self.inner.active_count.load(Ordering::Relaxed),
1135            idle,
1136            pending: self.inner.config.max_connections
1137                - self.inner.semaphore.available_permits()
1138                - self.active_count(),
1139            max_size: self.inner.config.max_connections,
1140            total_created: self.inner.total_created.load(Ordering::Relaxed),
1141        }
1142    }
1143
1144    /// Check if the pool is closed.
1145    pub fn is_closed(&self) -> bool {
1146        self.inner.closed.load(Ordering::Relaxed)
1147    }
1148
1149    /// Close the pool gracefully.
1150    pub async fn close(&self) {
1151        self.inner.closed.store(true, Ordering::Relaxed);
1152
1153        let mut connections = self.inner.connections.lock().await;
1154        connections.clear();
1155    }
1156
1157    /// Create a new connection using the pool configuration.
1158    async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
1159        match &config.password {
1160            Some(password) => {
1161                PgConnection::connect_with_password(
1162                    &config.host,
1163                    config.port,
1164                    &config.user,
1165                    &config.database,
1166                    Some(password),
1167                )
1168                .await
1169            }
1170            None => {
1171                PgConnection::connect(&config.host, config.port, &config.user, &config.database)
1172                    .await
1173            }
1174        }
1175    }
1176}
1177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert!(config.password.is_none());
    }

    #[test]
    fn test_pool_config_builder_chaining() {
        let config = PoolConfig::new("db.example.com", 5433, "admin", "prod")
            .password("p@ss")
            .max_connections(50)
            .min_connections(10)
            .idle_timeout(Duration::from_secs(300))
            .acquire_timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(3))
            .max_lifetime(Duration::from_secs(3600))
            .test_on_acquire(false);

        assert_eq!(config.host, "db.example.com");
        assert_eq!(config.port, 5433);
        assert_eq!(config.max_connections, 50);
        assert_eq!(config.min_connections, 10);
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.acquire_timeout, Duration::from_secs(5));
        assert_eq!(config.connect_timeout, Duration::from_secs(3));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(3600)));
        assert!(!config.test_on_acquire);
    }

    #[test]
    fn test_timeout_error_display() {
        let err = PgError::Timeout("pool acquire after 30s (10 max connections)".to_string());
        let msg = err.to_string();
        assert!(msg.contains("Timeout"));
        assert!(msg.contains("30s"));
        assert!(msg.contains("10 max connections"));
    }

    #[test]
    fn test_pool_closed_error_display() {
        let err = PgError::PoolClosed;
        assert_eq!(err.to_string(), "Connection pool is closed");
    }

    #[test]
    fn test_pool_exhausted_error_display() {
        let err = PgError::PoolExhausted { max: 20 };
        let msg = err.to_string();
        assert!(msg.contains("exhausted"));
        assert!(msg.contains("20"));
    }

    #[test]
    fn test_io_error_source_chaining() {
        use std::error::Error;
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "peer reset");
        let pg_err = PgError::Io(io_err);
        // source() should return the inner io::Error
        let source = pg_err.source().expect("Io variant should have source");
        assert!(source.to_string().contains("peer reset"));
    }

    #[test]
    fn test_non_io_errors_have_no_source() {
        use std::error::Error;
        assert!(PgError::Connection("test".into()).source().is_none());
        assert!(PgError::Query("test".into()).source().is_none());
        assert!(PgError::Timeout("test".into()).source().is_none());
        assert!(PgError::PoolClosed.source().is_none());
        assert!(PgError::NoRows.source().is_none());
    }

    #[test]
    fn test_io_error_from_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "broken");
        let pg_err: PgError = io_err.into();
        assert!(matches!(pg_err, PgError::Io(_)));
        assert!(pg_err.to_string().contains("broken"));
    }

    #[test]
    fn test_error_variants_are_distinct() {
        use std::collections::HashSet;
        use std::mem::discriminant;

        // One value per variant, so callers can match on each for
        // programmatic error handling.
        let errors: Vec<PgError> = vec![
            PgError::Connection("conn".into()),
            PgError::Protocol("proto".into()),
            PgError::Auth("auth".into()),
            PgError::Query("query".into()),
            PgError::NoRows,
            PgError::Io(std::io::Error::other("io")),
            PgError::Encode("enc".into()),
            PgError::Timeout("timeout".into()),
            PgError::PoolExhausted { max: 10 },
            PgError::PoolClosed,
        ];

        // Every variant produces a non-empty display string.
        for err in &errors {
            assert!(!err.to_string().is_empty(), "{err:?} has an empty Display");
        }

        // Every entry above really is a different variant.
        let variants: HashSet<_> = errors.iter().map(discriminant).collect();
        assert_eq!(variants.len(), errors.len());
        assert_eq!(variants.len(), 10);
    }
}
1305