Skip to main content

qail_pg/driver/
pool.rs

//! PostgreSQL Connection Pool
//!
//! Provides connection pooling for efficient resource management.
//! Connections are reused across queries to avoid reconnection overhead.
5
6use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
12#[derive(Clone)]
13pub struct PoolConfig {
14    pub host: String,
15    pub port: u16,
16    pub user: String,
17    pub database: String,
18    pub password: Option<String>,
19    pub max_connections: usize,
20    pub min_connections: usize,
21    pub idle_timeout: Duration,
22    pub acquire_timeout: Duration,
23    pub connect_timeout: Duration,
24    pub max_lifetime: Option<Duration>,
25    pub test_on_acquire: bool,
26}
27
28impl PoolConfig {
29    /// Create a new pool configuration with sensible defaults.
30    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
31        Self {
32            host: host.to_string(),
33            port,
34            user: user.to_string(),
35            database: database.to_string(),
36            password: None,
37            max_connections: 10,
38            min_connections: 1,
39            idle_timeout: Duration::from_secs(600), // 10 minutes
40            acquire_timeout: Duration::from_secs(30), // 30 seconds
41            connect_timeout: Duration::from_secs(10), // 10 seconds
42            max_lifetime: None,                      // No limit by default
43            test_on_acquire: false,                  // Disabled by default for performance
44        }
45    }
46
47    /// Set password for authentication.
48    pub fn password(mut self, password: &str) -> Self {
49        self.password = Some(password.to_string());
50        self
51    }
52
53    pub fn max_connections(mut self, max: usize) -> Self {
54        self.max_connections = max;
55        self
56    }
57
58    /// Set minimum idle connections.
59    pub fn min_connections(mut self, min: usize) -> Self {
60        self.min_connections = min;
61        self
62    }
63
64    /// Set idle timeout (connections idle longer than this are closed).
65    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
66        self.idle_timeout = timeout;
67        self
68    }
69
70    /// Set acquire timeout (max wait time when getting a connection).
71    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
72        self.acquire_timeout = timeout;
73        self
74    }
75
76    /// Set connect timeout (max time to establish new connection).
77    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
78        self.connect_timeout = timeout;
79        self
80    }
81
82    /// Set maximum lifetime of a connection before recycling.
83    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
84        self.max_lifetime = Some(lifetime);
85        self
86    }
87
88    /// Enable connection validation on acquire.
89    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
90        self.test_on_acquire = enabled;
91        self
92    }
93
94    /// Create a `PoolConfig` from a centralized `QailConfig`.
95    ///
96    /// Parses `postgres.url` for host/port/user/database/password
97    /// and applies pool tuning from `[postgres]` section.
98    pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
99        let pg = &qail.postgres;
100        let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
101
102        let mut config = PoolConfig::new(&host, port, &user, &database)
103            .max_connections(pg.max_connections)
104            .min_connections(pg.min_connections)
105            .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
106            .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
107            .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
108            .test_on_acquire(pg.test_on_acquire);
109
110        if let Some(ref pw) = password {
111            config = config.password(pw);
112        }
113
114        Ok(config)
115    }
116}
117
118/// Parse a postgres URL into (host, port, user, database, password).
119fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
120    let url = url.trim_start_matches("postgres://").trim_start_matches("postgresql://");
121
122    let (credentials, host_part) = if url.contains('@') {
123        let mut parts = url.splitn(2, '@');
124        let creds = parts.next().unwrap_or("");
125        let host = parts.next().unwrap_or("localhost/postgres");
126        (Some(creds), host)
127    } else {
128        (None, url)
129    };
130
131    let (host_port, database) = if host_part.contains('/') {
132        let mut parts = host_part.splitn(2, '/');
133        (parts.next().unwrap_or("localhost"), parts.next().unwrap_or("postgres").to_string())
134    } else {
135        (host_part, "postgres".to_string())
136    };
137
138    let (host, port) = if host_port.contains(':') {
139        let mut parts = host_port.split(':');
140        let h = parts.next().unwrap_or("localhost").to_string();
141        let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
142        (h, p)
143    } else {
144        (host_port.to_string(), 5432u16)
145    };
146
147    let (user, password) = if let Some(creds) = credentials {
148        if creds.contains(':') {
149            let mut parts = creds.splitn(2, ':');
150            let u = parts.next().unwrap_or("postgres").to_string();
151            let p = parts.next().map(|s| s.to_string());
152            (u, p)
153        } else {
154            (creds.to_string(), None)
155        }
156    } else {
157        ("postgres".to_string(), None)
158    };
159
160    Ok((host, port, user, database, password))
161}
162
/// Snapshot of pool counters, intended for monitoring.
///
/// All counters start at zero via `Default`.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Connections currently in use — presumably those checked out of the
    /// pool; TODO confirm against the code that fills this struct.
    pub active: usize,
    /// Connections currently idle — presumably those parked in the pool;
    /// TODO confirm against the code that fills this struct.
    pub idle: usize,
    /// Pending count — presumably callers waiting for a connection;
    /// TODO confirm against the code that fills this struct.
    pub pending: usize,
    /// Maximum connections configured
    pub max_size: usize,
    /// Total-created counter — presumably connections opened over the pool's
    /// lifetime; TODO confirm against the code that fills this struct.
    pub total_created: usize,
}
173
174/// A pooled connection with creation timestamp for idle tracking.
175struct PooledConn {
176    conn: PgConnection,
177    created_at: Instant,
178    last_used: Instant,
179}
180
181/// A pooled connection that returns to the pool when dropped.
182///
183/// When `rls_dirty` is true (set by `acquire_with_rls`), the connection
184/// will automatically reset RLS session variables before returning to
185/// the pool. This prevents cross-tenant data leakage.
186pub struct PooledConnection {
187    conn: Option<PgConnection>,
188    pool: Arc<PgPoolInner>,
189    rls_dirty: bool,
190}
191
192impl PooledConnection {
193    /// Get a mutable reference to the underlying connection.
194    pub fn get_mut(&mut self) -> &mut PgConnection {
195        self.conn
196            .as_mut()
197            .expect("Connection should always be present")
198    }
199
200    /// Get a token to cancel the currently running query.
201    pub fn cancel_token(&self) -> crate::driver::CancelToken {
202        let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
203        crate::driver::CancelToken {
204            host: self.pool.config.host.clone(),
205            port: self.pool.config.port,
206            process_id,
207            secret_key,
208        }
209    }
210
211    /// Deterministic connection cleanup and pool return.
212    ///
213    /// This is the **correct** way to return a connection to the pool.
214    /// COMMITs the transaction (which auto-resets transaction-local RLS
215    /// session variables) and returns the connection to the pool with
216    /// prepared statement caches intact.
217    ///
218    /// If cleanup fails, the connection is destroyed (not returned to pool).
219    ///
220    /// # Usage
221    /// ```ignore
222    /// let mut conn = pool.acquire_with_rls(ctx).await?;
223    /// let result = conn.fetch_all_cached(&cmd).await;
224    /// conn.release().await; // COMMIT + return to pool
225    /// result
226    /// ```
227    pub async fn release(mut self) {
228        if let Some(mut conn) = self.conn.take() {
229            // COMMIT the transaction opened by acquire_with_rls.
230            // Transaction-local set_config values auto-reset on COMMIT,
231            // so no explicit RLS cleanup is needed.
232            // Prepared statements survive — they are NOT transaction-scoped.
233            if let Err(e) = conn.execute_simple(super::rls::reset_sql()).await {
234                eprintln!(
235                    "[CRITICAL] pool_release_failed: COMMIT failed — \
236                     dropping connection to prevent state leak: {}",
237                    e
238                );
239                return; // Connection destroyed — not returned to pool
240            }
241
242            self.pool.return_connection(conn).await;
243        }
244    }
245
246    /// Execute a QAIL command and fetch all rows (UNCACHED).
247    /// Returns rows with column metadata for JSON serialization.
248    pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
249        use crate::protocol::AstEncoder;
250        use super::ColumnInfo;
251
252        let conn = self.conn.as_mut().expect("Connection should always be present");
253
254        let wire_bytes = AstEncoder::encode_cmd_reuse(
255            cmd,
256            &mut conn.sql_buf,
257            &mut conn.params_buf,
258        )
259        .map_err(|e| PgError::Encode(e.to_string()))?;
260
261        conn.send_bytes(&wire_bytes).await?;
262
263        let mut rows: Vec<super::PgRow> = Vec::new();
264        let mut column_info: Option<Arc<ColumnInfo>> = None;
265        let mut error: Option<PgError> = None;
266
267        loop {
268            let msg = conn.recv().await?;
269            match msg {
270                crate::protocol::BackendMessage::ParseComplete
271                | crate::protocol::BackendMessage::BindComplete => {}
272                crate::protocol::BackendMessage::RowDescription(fields) => {
273                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
274                }
275                crate::protocol::BackendMessage::DataRow(data) => {
276                    if error.is_none() {
277                        rows.push(super::PgRow {
278                            columns: data,
279                            column_info: column_info.clone(),
280                        });
281                    }
282                }
283                crate::protocol::BackendMessage::CommandComplete(_) => {}
284                crate::protocol::BackendMessage::ReadyForQuery(_) => {
285                    if let Some(err) = error {
286                        return Err(err);
287                    }
288                    return Ok(rows);
289                }
290                crate::protocol::BackendMessage::ErrorResponse(err) => {
291                    if error.is_none() {
292                        error = Some(PgError::Query(err.message));
293                    }
294                }
295                _ => {}
296            }
297        }
298    }
299
300    /// Execute a QAIL command and fetch all rows (FAST VERSION).
301    /// Uses native AST-to-wire encoding and optimized recv_with_data_fast.
302    /// Skips column metadata for maximum speed.
303    pub async fn fetch_all_fast(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
304        use crate::protocol::AstEncoder;
305
306        let conn = self.conn.as_mut().expect("Connection should always be present");
307
308        AstEncoder::encode_cmd_reuse_into(
309            cmd,
310            &mut conn.sql_buf,
311            &mut conn.params_buf,
312            &mut conn.write_buf,
313        )
314        .map_err(|e| PgError::Encode(e.to_string()))?;
315
316        conn.flush_write_buf().await?;
317
318        let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
319        let mut error: Option<PgError> = None;
320
321        loop {
322            let res = conn.recv_with_data_fast().await;
323            match res {
324                Ok((msg_type, data)) => {
325                    match msg_type {
326                        b'D' => {
327                            if error.is_none() && let Some(columns) = data {
328                                rows.push(super::PgRow {
329                                    columns,
330                                    column_info: None,
331                                });
332                            }
333                        }
334                        b'Z' => {
335                            if let Some(err) = error {
336                                return Err(err);
337                            }
338                            return Ok(rows);
339                        }
340                        _ => {}
341                    }
342                }
343                Err(e) => {
344                    if error.is_none() {
345                        error = Some(e);
346                    }
347                }
348            }
349        }
350    }
351
352    /// Execute a QAIL command and fetch all rows (CACHED).
353    /// Uses prepared statement caching: Parse+Describe on first call,
354    /// then Bind+Execute only on subsequent calls with the same SQL shape.
355    /// This matches PostgREST's behavior for fair benchmarks.
356    pub async fn fetch_all_cached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
357        use super::ColumnInfo;
358        use std::collections::hash_map::DefaultHasher;
359        use std::hash::{Hash, Hasher};
360
361        let conn = self.conn.as_mut().expect("Connection should always be present");
362
363        conn.sql_buf.clear();
364        conn.params_buf.clear();
365
366        // Encode SQL + params to reusable buffers
367        match cmd.action {
368            qail_core::ast::Action::Get | qail_core::ast::Action::With => {
369                crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
370            }
371            qail_core::ast::Action::Add => {
372                crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
373            }
374            qail_core::ast::Action::Set => {
375                crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
376            }
377            qail_core::ast::Action::Del => {
378                crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
379            }
380            _ => {
381                // Fallback: unsupported actions go through uncached path
382                return self.fetch_all_uncached(cmd).await;
383            }
384        }
385
386        let mut hasher = DefaultHasher::new();
387        conn.sql_buf.hash(&mut hasher);
388        let sql_hash = hasher.finish();
389
390        let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
391
392        conn.write_buf.clear();
393
394        let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
395            name.clone()
396        } else {
397            let name = format!("qail_{:x}", sql_hash);
398
399            conn.evict_prepared_if_full();
400
401            let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
402
403            use crate::protocol::PgEncoder;
404            let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
405            let describe_msg = PgEncoder::encode_describe(false, &name);
406            conn.write_buf.extend_from_slice(&parse_msg);
407            conn.write_buf.extend_from_slice(&describe_msg);
408
409            conn.stmt_cache.put(sql_hash, name.clone());
410            conn.prepared_statements.insert(name.clone(), sql_str.to_string());
411
412            // Register in global hot-statement registry for cross-connection sharing
413            if let Ok(mut hot) = self.pool.hot_statements.write()
414                && hot.len() < MAX_HOT_STATEMENTS
415            {
416                hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
417            }
418
419            name
420        };
421
422        use crate::protocol::PgEncoder;
423        PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
424            .map_err(|e| PgError::Encode(e.to_string()))?;
425        PgEncoder::encode_execute_to(&mut conn.write_buf);
426        PgEncoder::encode_sync_to(&mut conn.write_buf);
427
428        conn.flush_write_buf().await?;
429
430        let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
431
432        let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
433        let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
434        let mut error: Option<PgError> = None;
435
436        loop {
437            let msg = conn.recv().await?;
438            match msg {
439                crate::protocol::BackendMessage::ParseComplete
440                | crate::protocol::BackendMessage::BindComplete => {}
441                crate::protocol::BackendMessage::ParameterDescription(_) => {}
442                crate::protocol::BackendMessage::RowDescription(fields) => {
443                    let info = Arc::new(ColumnInfo::from_fields(&fields));
444                    if is_cache_miss {
445                        conn.column_info_cache.insert(sql_hash, info.clone());
446                    }
447                    column_info = Some(info);
448                }
449                crate::protocol::BackendMessage::DataRow(data) => {
450                    if error.is_none() {
451                        rows.push(super::PgRow {
452                            columns: data,
453                            column_info: column_info.clone(),
454                        });
455                    }
456                }
457                crate::protocol::BackendMessage::CommandComplete(_) => {}
458                crate::protocol::BackendMessage::ReadyForQuery(_) => {
459                    if let Some(err) = error {
460                        return Err(err);
461                    }
462                    return Ok(rows);
463                }
464                crate::protocol::BackendMessage::ErrorResponse(err) => {
465                    if error.is_none() {
466                        error = Some(PgError::Query(err.message));
467                    }
468                }
469                _ => {}
470            }
471        }
472    }
473
474    /// Execute a QAIL command with RLS context in a SINGLE roundtrip.
475    ///
476    /// Pipelines the RLS setup (BEGIN + set_config) and the query
477    /// (Parse/Bind/Execute/Sync) into one `write_all` syscall.
478    /// PG processes messages in order, so the BEGIN + set_config
479    /// completes before the query executes — security is preserved.
480    ///
481    /// Wire layout:
482    /// ```text
483    /// [SimpleQuery: "BEGIN; SET LOCAL...; SELECT set_config(...)"]
484    /// [Parse (if cache miss)]
485    /// [Describe (if cache miss)]
486    /// [Bind]
487    /// [Execute]
488    /// [Sync]
489    /// ```
490    ///
491    /// Response processing: consume 2× ReadyForQuery (SimpleQuery + Sync).
492    pub async fn fetch_all_with_rls(
493        &mut self,
494        cmd: &qail_core::ast::Qail,
495        rls_sql: &str,
496    ) -> PgResult<Vec<super::PgRow>> {
497        use super::ColumnInfo;
498        use std::collections::hash_map::DefaultHasher;
499        use std::hash::{Hash, Hasher};
500
501        let conn = self.conn.as_mut().expect("Connection should always be present");
502
503        conn.sql_buf.clear();
504        conn.params_buf.clear();
505
506        // Encode SQL + params to reusable buffers
507        match cmd.action {
508            qail_core::ast::Action::Get | qail_core::ast::Action::With => {
509                crate::protocol::ast_encoder::dml::encode_select(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
510            }
511            qail_core::ast::Action::Add => {
512                crate::protocol::ast_encoder::dml::encode_insert(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
513            }
514            qail_core::ast::Action::Set => {
515                crate::protocol::ast_encoder::dml::encode_update(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
516            }
517            qail_core::ast::Action::Del => {
518                crate::protocol::ast_encoder::dml::encode_delete(cmd, &mut conn.sql_buf, &mut conn.params_buf).ok();
519            }
520            _ => {
521                // Fallback: RLS setup must happen synchronously for unsupported actions
522                conn.execute_simple(rls_sql).await?;
523                self.rls_dirty = true;
524                return self.fetch_all_uncached(cmd).await;
525            }
526        }
527
528        let mut hasher = DefaultHasher::new();
529        conn.sql_buf.hash(&mut hasher);
530        let sql_hash = hasher.finish();
531
532        let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
533
534        conn.write_buf.clear();
535
536        // ── Prepend RLS Simple Query message ─────────────────────────
537        // This is the key optimization: RLS setup bytes go first in the
538        // same buffer as the query messages.
539        let rls_msg = crate::protocol::PgEncoder::encode_query_string(rls_sql);
540        conn.write_buf.extend_from_slice(&rls_msg);
541
542        // ── Then append the query messages (same as fetch_all_cached) ──
543        let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
544            name.clone()
545        } else {
546            let name = format!("qail_{:x}", sql_hash);
547
548            conn.evict_prepared_if_full();
549
550            let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
551
552            use crate::protocol::PgEncoder;
553            let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
554            let describe_msg = PgEncoder::encode_describe(false, &name);
555            conn.write_buf.extend_from_slice(&parse_msg);
556            conn.write_buf.extend_from_slice(&describe_msg);
557
558            conn.stmt_cache.put(sql_hash, name.clone());
559            conn.prepared_statements.insert(name.clone(), sql_str.to_string());
560
561            if let Ok(mut hot) = self.pool.hot_statements.write()
562                && hot.len() < MAX_HOT_STATEMENTS
563            {
564                hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
565            }
566
567            name
568        };
569
570        use crate::protocol::PgEncoder;
571        PgEncoder::encode_bind_to(&mut conn.write_buf, &stmt_name, &conn.params_buf)
572            .map_err(|e| PgError::Encode(e.to_string()))?;
573        PgEncoder::encode_execute_to(&mut conn.write_buf);
574        PgEncoder::encode_sync_to(&mut conn.write_buf);
575
576        // ── Single write_all for RLS + Query ────────────────────────
577        conn.flush_write_buf().await?;
578
579        // Mark connection as RLS-dirty (needs COMMIT on release)
580        self.rls_dirty = true;
581
582        // ── Phase 1: Consume Simple Query responses (RLS setup) ─────
583        // Simple Query produces: CommandComplete × N, then ReadyForQuery.
584        // set_config results and BEGIN/SET LOCAL responses are all here.
585        let mut rls_error: Option<PgError> = None;
586        loop {
587            let msg = conn.recv().await?;
588            match msg {
589                crate::protocol::BackendMessage::ReadyForQuery(_) => {
590                    // RLS setup done — break to Extended Query phase
591                    if let Some(err) = rls_error {
592                        return Err(err);
593                    }
594                    break;
595                }
596                crate::protocol::BackendMessage::ErrorResponse(err) => {
597                    if rls_error.is_none() {
598                        rls_error = Some(PgError::Query(err.message));
599                    }
600                }
601                // CommandComplete, DataRow (from set_config), RowDescription — ignore
602                _ => {}
603            }
604        }
605
606        // ── Phase 2: Consume Extended Query responses (actual data) ──
607        let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
608
609        let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
610        let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
611        let mut error: Option<PgError> = None;
612
613        loop {
614            let msg = conn.recv().await?;
615            match msg {
616                crate::protocol::BackendMessage::ParseComplete
617                | crate::protocol::BackendMessage::BindComplete => {}
618                crate::protocol::BackendMessage::ParameterDescription(_) => {}
619                crate::protocol::BackendMessage::RowDescription(fields) => {
620                    let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
621                    if is_cache_miss {
622                        conn.column_info_cache.insert(sql_hash, info.clone());
623                    }
624                    column_info = Some(info);
625                }
626                crate::protocol::BackendMessage::DataRow(data) => {
627                    if error.is_none() {
628                        rows.push(super::PgRow {
629                            columns: data,
630                            column_info: column_info.clone(),
631                        });
632                    }
633                }
634                crate::protocol::BackendMessage::CommandComplete(_) => {}
635                crate::protocol::BackendMessage::ReadyForQuery(_) => {
636                    if let Some(err) = error {
637                        return Err(err);
638                    }
639                    return Ok(rows);
640                }
641                crate::protocol::BackendMessage::ErrorResponse(err) => {
642                    if error.is_none() {
643                        error = Some(PgError::Query(err.message));
644                    }
645                }
646                _ => {}
647            }
648        }
649    }
650
651    /// Execute multiple QAIL commands in a single PG pipeline round-trip.
652    ///
653    /// Sends all queries as Parse+Bind+Execute in one write, receives all
654    /// responses in one read. Returns raw column data per query per row.
655    ///
656    /// This is the fastest path for batch operations — amortizes TCP
657    /// overhead across N queries into a single syscall pair.
658    pub async fn pipeline_ast(
659        &mut self,
660        cmds: &[qail_core::ast::Qail],
661    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
662        let conn = self.conn.as_mut().expect("Connection should always be present");
663        conn.pipeline_ast(cmds).await
664    }
665
666    /// Run `EXPLAIN (FORMAT JSON)` on a Qail command and return cost estimates.
667    ///
668    /// Uses `simple_query` under the hood — no additional round-trips beyond
669    /// the single EXPLAIN statement. Returns `None` if parsing fails or
670    /// the EXPLAIN output is unexpected.
671    pub async fn explain_estimate(
672        &mut self,
673        cmd: &qail_core::ast::Qail,
674    ) -> PgResult<Option<super::explain::ExplainEstimate>> {
675        use qail_core::transpiler::ToSql;
676
677        let sql = cmd.to_sql();
678        let explain_sql = format!("EXPLAIN (FORMAT JSON) {}", sql);
679
680        let rows = self.simple_query(&explain_sql).await?;
681
682        // PostgreSQL returns the JSON plan as a single text column across one or more rows
683        let mut json_output = String::new();
684        for row in &rows {
685            if let Some(Some(val)) = row.columns.first()
686                && let Ok(text) = std::str::from_utf8(val)
687            {
688                json_output.push_str(text);
689            }
690        }
691
692        Ok(super::explain::parse_explain_json(&json_output))
693    }
694}
695
696impl Drop for PooledConnection {
697    fn drop(&mut self) {
698        if self.conn.is_some() {
699            // Safety net: connection was NOT released via `release()`.
700            // This happens when:
701            //   - Handler panicked
702            //   - Early return without calling release()
703            //   - Missed release() call (programming error)
704            //
705            // We DESTROY the connection (don't return to pool) to prevent
706            // dirty session state from being reused. This costs a pool slot
707            // but guarantees no cross-tenant leakage.
708            //
709            // The `conn` field is dropped here, closing the TCP socket.
710            eprintln!(
711                "[WARN] pool_connection_leaked: PooledConnection dropped without release() — \
712                 connection destroyed to prevent state leak (rls_dirty={}). \
713                 Use conn.release().await for deterministic cleanup.",
714                self.rls_dirty
715            );
716            // Decrement active count so pool can create a replacement
717            self.pool.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
718        }
719    }
720}
721
722impl std::ops::Deref for PooledConnection {
723    type Target = PgConnection;
724
725    fn deref(&self) -> &Self::Target {
726        self.conn
727            .as_ref()
728            .expect("Connection should always be present")
729    }
730}
731
732impl std::ops::DerefMut for PooledConnection {
733    fn deref_mut(&mut self) -> &mut Self::Target {
734        self.conn
735            .as_mut()
736            .expect("Connection should always be present")
737    }
738}
739
/// Cap on the number of entries kept in the pool-wide hot-statement
/// registry; once it is reached, further statements are not registered.
const MAX_HOT_STATEMENTS: usize = 32;
742
743/// Inner pool state (shared across clones).
744struct PgPoolInner {
745    config: PoolConfig,
746    connections: Mutex<Vec<PooledConn>>,
747    semaphore: Semaphore,
748    closed: AtomicBool,
749    active_count: AtomicUsize,
750    total_created: AtomicUsize,
751    /// Global registry of frequently-used prepared statements.
752    /// Maps sql_hash → (stmt_name, sql_text).
753    /// New connections pre-prepare these on checkout for instant cache hits.
754    hot_statements: std::sync::RwLock<std::collections::HashMap<u64, (String, String)>>,
755}
756
757impl PgPoolInner {
758    async fn return_connection(&self, conn: PgConnection) {
759
760        self.active_count.fetch_sub(1, Ordering::Relaxed);
761        
762
763        if self.closed.load(Ordering::Relaxed) {
764            return;
765        }
766        
767        let mut connections = self.connections.lock().await;
768        if connections.len() < self.config.max_connections {
769            connections.push(PooledConn {
770                conn,
771                created_at: Instant::now(),
772                last_used: Instant::now(),
773            });
774        }
775
776        self.semaphore.add_permits(1);
777    }
778
779    /// Get a healthy connection from the pool, or None if pool is empty.
780    async fn get_healthy_connection(&self) -> Option<PgConnection> {
781        let mut connections = self.connections.lock().await;
782
783        while let Some(pooled) = connections.pop() {
784            if pooled.last_used.elapsed() > self.config.idle_timeout {
785                // Connection is stale, drop it
786                continue;
787            }
788
789            if let Some(max_life) = self.config.max_lifetime
790                && pooled.created_at.elapsed() > max_life
791            {
792                // Connection exceeded max lifetime, recycle it
793                continue;
794            }
795
796            return Some(pooled.conn);
797        }
798
799        None
800    }
801}
802
/// Cloneable handle to a PostgreSQL connection pool.
///
/// The handle only wraps an `Arc` around the shared pool state, so cloning
/// it is cheap and every clone operates on the same set of connections.
///
/// # Example
/// ```ignore
/// let config = PoolConfig::new("localhost", 5432, "user", "db")
///     .password("secret")
///     .max_connections(20);
/// let pool = PgPool::connect(config).await?;
/// // Get a connection from the pool
/// let mut conn = pool.acquire_raw().await?;
/// conn.simple_query("SELECT 1").await?;
/// ```
#[derive(Clone)]
pub struct PgPool {
    /// Shared pool state; all clones point at the same instance.
    inner: Arc<PgPoolInner>,
}
817
818impl PgPool {
819    /// Create a pool from `qail.toml` (loads and parses automatically).
820    ///
821    /// # Example
822    /// ```ignore
823    /// let pool = PgPool::from_config().await?;
824    /// ```
825    pub async fn from_config() -> PgResult<Self> {
826        let qail = qail_core::config::QailConfig::load()
827            .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
828        let config = PoolConfig::from_qail_config(&qail)?;
829        Self::connect(config).await
830    }
831
832    /// Create a new connection pool.
833    pub async fn connect(config: PoolConfig) -> PgResult<Self> {
834        // Semaphore starts with max_connections permits
835        let semaphore = Semaphore::new(config.max_connections);
836
837        let mut initial_connections = Vec::new();
838        for _ in 0..config.min_connections {
839            let conn = Self::create_connection(&config).await?;
840            initial_connections.push(PooledConn {
841                conn,
842                created_at: Instant::now(),
843                last_used: Instant::now(),
844            });
845        }
846
847        let initial_count = initial_connections.len();
848
849        let inner = Arc::new(PgPoolInner {
850            config,
851            connections: Mutex::new(initial_connections),
852            semaphore,
853            closed: AtomicBool::new(false),
854            active_count: AtomicUsize::new(0),
855            total_created: AtomicUsize::new(initial_count),
856            hot_statements: std::sync::RwLock::new(std::collections::HashMap::new()),
857        });
858
859        Ok(Self { inner })
860    }
861
862    /// Acquire a raw connection from the pool (crate-internal only).
863    ///
864    /// # Safety (not `unsafe` in the Rust sense, but security-critical)
865    ///
866    /// This returns a connection with **no RLS context**. All tenant data
867    /// queries on this connection will bypass row-level security.
868    ///
869    /// **Safe usage**: Pair with `fetch_all_with_rls()` for pipelined
870    /// RLS+query execution (single roundtrip). Or use `acquire_with_rls()`
871    /// / `acquire_with_rls_timeout()` for the 2-roundtrip path.
872    ///
873    /// **Unsafe usage**: Running queries directly on a raw connection
874    /// without RLS context. Every call site MUST include a `// SAFETY:`
875    /// comment explaining why raw acquisition is justified.
876    pub async fn acquire_raw(&self) -> PgResult<PooledConnection> {
877        if self.inner.closed.load(Ordering::Relaxed) {
878            return Err(PgError::Connection("Pool is closed".to_string()));
879        }
880
881        // Wait for available slot with timeout
882        let acquire_timeout = self.inner.config.acquire_timeout;
883        let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
884            .await
885            .map_err(|_| {
886                PgError::Connection(format!(
887                    "Timed out waiting for connection ({}s)",
888                    acquire_timeout.as_secs()
889                ))
890            })?
891            .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
892        permit.forget();
893
894        // Try to get existing healthy connection
895        let mut conn = if let Some(conn) = self.inner.get_healthy_connection().await {
896            conn
897        } else {
898            let conn = Self::create_connection(&self.inner.config).await?;
899            self.inner.total_created.fetch_add(1, Ordering::Relaxed);
900            conn
901        };
902
903        // Pre-prepare hot statements that this connection doesn't have yet.
904        // Collect data synchronously (guard dropped before async work).
905        let missing: Vec<(u64, String, String)> = {
906            if let Ok(hot) = self.inner.hot_statements.read() {
907                hot.iter()
908                    .filter(|(hash, _)| !conn.stmt_cache.contains(hash))
909                    .map(|(hash, (name, sql))| (*hash, name.clone(), sql.clone()))
910                    .collect()
911            } else {
912                Vec::new()
913            }
914        }; // RwLockReadGuard dropped here — safe across .await
915
916        if !missing.is_empty() {
917            use crate::protocol::PgEncoder;
918            let mut buf = bytes::BytesMut::new();
919            for (_, name, sql) in &missing {
920                let parse_msg = PgEncoder::encode_parse(name, sql, &[]);
921                buf.extend_from_slice(&parse_msg);
922            }
923            PgEncoder::encode_sync_to(&mut buf);
924            if conn.send_bytes(&buf).await.is_ok() {
925                // Drain responses (ParseComplete + ReadyForQuery)
926                loop {
927                    match conn.recv().await {
928                        Ok(crate::protocol::BackendMessage::ReadyForQuery(_)) => break,
929                        Ok(_) => continue,
930                        Err(_) => break,
931                    }
932                }
933                // Register in local cache
934                for (hash, name, sql) in &missing {
935                    conn.stmt_cache.put(*hash, name.clone());
936                    conn.prepared_statements.insert(name.clone(), sql.clone());
937                }
938            }
939        }
940
941        self.inner.active_count.fetch_add(1, Ordering::Relaxed);
942
943        Ok(PooledConnection {
944            conn: Some(conn),
945            pool: self.inner.clone(),
946            rls_dirty: false,
947        })
948    }
949
950    /// Acquire a connection with RLS context pre-configured.
951    ///
952    /// Sets PostgreSQL session variables for tenant isolation before
953    /// returning the connection. When the connection is dropped, it
954    /// automatically clears the RLS context before returning to the pool.
955    ///
956    /// # Example
957    /// ```ignore
958    /// use qail_core::rls::RlsContext;
959    ///
960    /// let mut conn = pool.acquire_with_rls(
961    ///     RlsContext::operator("550e8400-e29b-41d4-a716-446655440000")
962    /// ).await?;
963    /// // All queries through `conn` are now scoped to this operator
964    /// ```
965    pub async fn acquire_with_rls(
966        &self,
967        ctx: qail_core::rls::RlsContext,
968    ) -> PgResult<PooledConnection> {
969        // SAFETY: RLS context is set immediately below via context_to_sql().
970        let mut conn = self.acquire_raw().await?;
971
972        // Set RLS context on the raw connection
973        let sql = super::rls::context_to_sql(&ctx);
974        let pg_conn = conn.get_mut();
975        pg_conn.execute_simple(&sql).await?;
976
977        // Mark dirty so Drop resets context before pool return
978        conn.rls_dirty = true;
979
980        Ok(conn)
981    }
982
983    /// Acquire a connection with RLS context AND statement timeout.
984    ///
985    /// Like `acquire_with_rls()`, but also sets `statement_timeout` to prevent
986    /// runaway queries from holding pool connections indefinitely.
987    pub async fn acquire_with_rls_timeout(
988        &self,
989        ctx: qail_core::rls::RlsContext,
990        timeout_ms: u32,
991    ) -> PgResult<PooledConnection> {
992        // SAFETY: RLS context + timeout set immediately below via context_to_sql_with_timeout().
993        let mut conn = self.acquire_raw().await?;
994
995        // Set RLS context + statement_timeout atomically
996        let sql = super::rls::context_to_sql_with_timeout(&ctx, timeout_ms);
997        let pg_conn = conn.get_mut();
998        pg_conn.execute_simple(&sql).await?;
999
1000        // Mark dirty so Drop resets context + timeout before pool return
1001        conn.rls_dirty = true;
1002
1003        Ok(conn)
1004    }
1005
1006    /// Acquire a connection for system-level operations (no tenant context).
1007    ///
1008    /// Sets RLS session variables to maximally restrictive values:
1009    /// - `app.current_operator_id = ''`
1010    /// - `app.current_agent_id = ''`  
1011    /// - `app.is_super_admin = false`
1012    ///
1013    /// Use this for startup introspection, migrations, and health checks
1014    /// that must not operate within any tenant scope.
1015    pub async fn acquire_system(&self) -> PgResult<PooledConnection> {
1016        let ctx = qail_core::rls::RlsContext::empty();
1017        self.acquire_with_rls(ctx).await
1018    }
1019
1020    /// Acquire a connection with branch context pre-configured.
1021    ///
1022    /// Sets PostgreSQL session variable `app.branch_id` for data virtualization.
1023    /// When the connection is dropped, it automatically clears the branch context.
1024    ///
1025    /// # Example
1026    /// ```ignore
1027    /// use qail_core::branch::BranchContext;
1028    ///
1029    /// let ctx = BranchContext::branch("feature-auth");
1030    /// let mut conn = pool.acquire_with_branch(&ctx).await?;
1031    /// // All queries through `conn` are now branch-aware
1032    /// ```
1033    pub async fn acquire_with_branch(
1034        &self,
1035        ctx: &qail_core::branch::BranchContext,
1036    ) -> PgResult<PooledConnection> {
1037        // SAFETY: Branch context is set immediately below via branch_context_sql().
1038        let mut conn = self.acquire_raw().await?;
1039
1040        if let Some(branch_name) = ctx.branch_name() {
1041            let sql = super::branch_sql::branch_context_sql(branch_name);
1042            let pg_conn = conn.get_mut();
1043            pg_conn.execute_simple(&sql).await?;
1044            conn.rls_dirty = true; // Reuse dirty flag for auto-reset
1045        }
1046
1047        Ok(conn)
1048    }
1049
1050    /// Get the current number of idle connections.
1051    pub async fn idle_count(&self) -> usize {
1052        self.inner.connections.lock().await.len()
1053    }
1054
1055    /// Get the number of connections currently in use.
1056    pub fn active_count(&self) -> usize {
1057        self.inner.active_count.load(Ordering::Relaxed)
1058    }
1059
1060    /// Get the maximum number of connections.
1061    pub fn max_connections(&self) -> usize {
1062        self.inner.config.max_connections
1063    }
1064
1065    /// Get comprehensive pool statistics.
1066    pub async fn stats(&self) -> PoolStats {
1067        let idle = self.inner.connections.lock().await.len();
1068        PoolStats {
1069            active: self.inner.active_count.load(Ordering::Relaxed),
1070            idle,
1071            pending: self.inner.config.max_connections
1072                - self.inner.semaphore.available_permits()
1073                - self.active_count(),
1074            max_size: self.inner.config.max_connections,
1075            total_created: self.inner.total_created.load(Ordering::Relaxed),
1076        }
1077    }
1078
1079    /// Check if the pool is closed.
1080    pub fn is_closed(&self) -> bool {
1081        self.inner.closed.load(Ordering::Relaxed)
1082    }
1083
1084    /// Close the pool gracefully.
1085    pub async fn close(&self) {
1086        self.inner.closed.store(true, Ordering::Relaxed);
1087
1088        let mut connections = self.inner.connections.lock().await;
1089        connections.clear();
1090    }
1091
1092    /// Create a new connection using the pool configuration.
1093    async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
1094        match &config.password {
1095            Some(password) => {
1096                PgConnection::connect_with_password(
1097                    &config.host,
1098                    config.port,
1099                    &config.user,
1100                    &config.database,
1101                    Some(password),
1102                )
1103                .await
1104            }
1105            None => {
1106                PgConnection::connect(&config.host, config.port, &config.user, &config.database)
1107                    .await
1108            }
1109        }
1110    }
1111}
1112
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor arguments are stored as given and the builder setters
    /// override the corresponding defaults.
    #[test]
    fn test_pool_config() {
        let built = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        // Values passed to `new`.
        assert_eq!(
            (
                built.host.as_str(),
                built.port,
                built.user.as_str(),
                built.database.as_str(),
            ),
            ("localhost", 5432, "user", "testdb"),
        );

        // Values set through the builder methods.
        assert_eq!(built.password.as_deref(), Some("secret123"));
        assert_eq!((built.max_connections, built.min_connections), (20, 5));
    }
}