1use super::churn::{
5 PoolStats, decrement_active_count_saturating, pool_churn_record_destroy,
6 pool_churn_remaining_open, record_pool_connection_destroy,
7};
8use super::config::PoolConfig;
9use super::connection::PooledConn;
10use super::connection::PooledConnection;
11use super::gss::*;
12use crate::driver::{
13 ConnectOptions, PgConnection, PgError, PgResult, is_ignorable_session_message,
14 unexpected_backend_message,
15};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
18use std::time::{Duration, Instant};
19use tokio::sync::{Mutex, Semaphore};
20
/// Upper bound on the number of pool-wide "hot" statements (see `PgPoolInner::hot_statements`).
///
/// NOTE(review): not referenced in this file; presumably enforced wherever hot statements are
/// registered — confirm against the caller.
pub(super) const MAX_HOT_STATEMENTS: usize = 32;
23
/// Shared state behind every `PgPool` handle.
pub(super) struct PgPoolInner {
    /// Configuration the pool was built from (checked by `validate_pool_config` in `connect`).
    pub(super) config: PoolConfig,
    /// Idle connections available for checkout; `get_healthy_connection` pops from the end.
    pub(super) connections: Mutex<Vec<PooledConn>>,
    /// Bounds concurrent checkouts to `max_connections`. A permit is forgotten on successful
    /// checkout and re-added by `return_connection`.
    pub(super) semaphore: Semaphore,
    /// Set by `close_graceful`; `acquire_raw` checks it before and after waiting for a permit.
    pub(super) closed: AtomicBool,
    /// Number of connections currently checked out.
    pub(super) active_count: AtomicUsize,
    /// Count of connections opened by this pool (initial, on-demand, replacement, backfill).
    pub(super) total_created: AtomicUsize,
    /// Pool-wide hot statements: u64 key (matched against each connection's `stmt_cache`)
    /// mapped to `(statement name, SQL)`. `acquire_raw` pre-prepares any entries the
    /// checked-out connection has not cached yet.
    pub(super) hot_statements: std::sync::RwLock<std::collections::HashMap<u64, (String, String)>>,
}
37
38pub(super) fn handle_hot_preprepare_message(
39 msg: &crate::protocol::BackendMessage,
40 parse_complete_count: &mut usize,
41 error: &mut Option<PgError>,
42) -> PgResult<bool> {
43 match msg {
44 crate::protocol::BackendMessage::ParseComplete => {
45 *parse_complete_count += 1;
46 Ok(false)
47 }
48 crate::protocol::BackendMessage::ErrorResponse(err) => {
49 if error.is_none() {
50 *error = Some(PgError::QueryServer(err.clone().into()));
51 }
52 Ok(false)
53 }
54 crate::protocol::BackendMessage::ReadyForQuery(_) => Ok(true),
55 msg if is_ignorable_session_message(msg) => Ok(false),
56 other => Err(unexpected_backend_message("pool hot pre-prepare", other)),
57 }
58}
59
60impl PgPoolInner {
61 pub(super) async fn return_connection(&self, conn: PgConnection, created_at: Instant) {
62 decrement_active_count_saturating(&self.active_count);
63
64 if conn.is_io_desynced() {
65 tracing::warn!(
66 host = %self.config.host,
67 port = self.config.port,
68 user = %self.config.user,
69 db = %self.config.database,
70 "pool_return_desynced: dropping connection due to prior I/O/protocol desync"
71 );
72 record_pool_connection_destroy("pool_desynced_drop");
73 self.semaphore.add_permits(1);
74 pool_churn_record_destroy(&self.config, "return_desynced");
75 return;
76 }
77
78 if self.closed.load(Ordering::Relaxed) {
79 record_pool_connection_destroy("pool_closed_drop");
80 self.semaphore.add_permits(1);
81 return;
82 }
83
84 let mut connections = self.connections.lock().await;
85 if connections.len() < self.config.max_connections {
86 connections.push(PooledConn {
87 conn,
88 created_at,
89 last_used: Instant::now(),
90 });
91 } else {
92 record_pool_connection_destroy("pool_overflow_drop");
93 }
94
95 self.semaphore.add_permits(1);
96 }
97
98 async fn get_healthy_connection(&self) -> Option<PooledConn> {
100 let mut connections = self.connections.lock().await;
101
102 while let Some(pooled) = connections.pop() {
103 if pooled.last_used.elapsed() > self.config.idle_timeout {
104 tracing::debug!(
105 idle_secs = pooled.last_used.elapsed().as_secs(),
106 timeout_secs = self.config.idle_timeout.as_secs(),
107 "pool_checkout_evict: connection exceeded idle timeout"
108 );
109 record_pool_connection_destroy("idle_timeout_evict");
110 continue;
111 }
112
113 if let Some(max_life) = self.config.max_lifetime
114 && pooled.created_at.elapsed() > max_life
115 {
116 tracing::debug!(
117 age_secs = pooled.created_at.elapsed().as_secs(),
118 max_lifetime_secs = max_life.as_secs(),
119 "pool_checkout_evict: connection exceeded max lifetime"
120 );
121 record_pool_connection_destroy("max_lifetime_evict");
122 continue;
123 }
124
125 return Some(pooled);
126 }
127
128 None
129 }
130}
131
/// Handle to an async PostgreSQL connection pool.
///
/// Cloning is cheap: every clone shares the same `PgPoolInner` through an `Arc`.
#[derive(Clone)]
pub struct PgPool {
    pub(super) inner: Arc<PgPoolInner>,
}
146
147impl PgPool {
148 pub async fn from_config() -> PgResult<Self> {
155 let qail = qail_core::config::QailConfig::load()
156 .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
157 let config = PoolConfig::from_qail_config(&qail)?;
158 Self::connect(config).await
159 }
160
161 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
163 validate_pool_config(&config)?;
164
165 let semaphore = Semaphore::new(config.max_connections);
167
168 let mut initial_connections = Vec::new();
169 for _ in 0..config.min_connections {
170 let conn = Self::create_connection(&config).await?;
171 initial_connections.push(PooledConn {
172 conn,
173 created_at: Instant::now(),
174 last_used: Instant::now(),
175 });
176 }
177
178 let initial_count = initial_connections.len();
179
180 let inner = Arc::new(PgPoolInner {
181 config,
182 connections: Mutex::new(initial_connections),
183 semaphore,
184 closed: AtomicBool::new(false),
185 active_count: AtomicUsize::new(0),
186 total_created: AtomicUsize::new(initial_count),
187 hot_statements: std::sync::RwLock::new(std::collections::HashMap::new()),
188 });
189
190 Ok(Self { inner })
191 }
192
193 pub async fn acquire_raw(&self) -> PgResult<PooledConnection> {
208 if self.inner.closed.load(Ordering::Relaxed) {
209 return Err(PgError::PoolClosed);
210 }
211
212 if let Some(remaining) = pool_churn_remaining_open(&self.inner.config) {
213 metrics::counter!("qail_pg_pool_churn_circuit_reject_total").increment(1);
214 tracing::warn!(
215 host = %self.inner.config.host,
216 port = self.inner.config.port,
217 user = %self.inner.config.user,
218 db = %self.inner.config.database,
219 remaining_ms = remaining.as_millis() as u64,
220 "pool_connection_churn_circuit_open"
221 );
222 return Err(PgError::PoolExhausted {
223 max: self.inner.config.max_connections,
224 });
225 }
226
227 let acquire_timeout = self.inner.config.acquire_timeout;
229 let permit =
230 match tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire()).await {
231 Ok(permit) => permit.map_err(|_| PgError::PoolClosed)?,
232 Err(_) => {
233 metrics::counter!("qail_pg_pool_acquire_timeouts_total").increment(1);
234 return Err(PgError::Timeout(format!(
235 "pool acquire after {}s ({} max connections)",
236 acquire_timeout.as_secs(),
237 self.inner.config.max_connections
238 )));
239 }
240 };
241
242 if self.inner.closed.load(Ordering::Relaxed) {
243 return Err(PgError::PoolClosed);
244 }
245
246 let (mut conn, mut created_at) =
248 if let Some(pooled) = self.inner.get_healthy_connection().await {
249 (pooled.conn, pooled.created_at)
250 } else {
251 let conn = Self::create_connection(&self.inner.config).await?;
252 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
253 (conn, Instant::now())
254 };
255
256 if self.inner.config.test_on_acquire
257 && let Err(e) = execute_simple_with_timeout(
258 &mut conn,
259 "SELECT 1",
260 self.inner.config.connect_timeout,
261 "pool checkout health check",
262 )
263 .await
264 {
265 tracing::warn!(
266 host = %self.inner.config.host,
267 port = self.inner.config.port,
268 user = %self.inner.config.user,
269 db = %self.inner.config.database,
270 error = %e,
271 "pool_health_check_failed: checkout probe failed, creating replacement connection"
272 );
273 pool_churn_record_destroy(&self.inner.config, "health_check_failed");
274 conn = Self::create_connection(&self.inner.config).await?;
275 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
276 created_at = Instant::now();
277 }
278
279 let missing: Vec<(u64, String, String)> = {
282 if let Ok(hot) = self.inner.hot_statements.read() {
283 hot.iter()
284 .filter(|(hash, _)| !conn.stmt_cache.contains(hash))
285 .map(|(hash, (name, sql))| (*hash, name.clone(), sql.clone()))
286 .collect()
287 } else {
288 Vec::new()
289 }
290 }; if !missing.is_empty() {
293 use crate::protocol::PgEncoder;
294 let mut buf = bytes::BytesMut::new();
295 for (_, name, sql) in &missing {
296 let parse_msg = PgEncoder::try_encode_parse(name, sql, &[])?;
297 buf.extend_from_slice(&parse_msg);
298 }
299 PgEncoder::encode_sync_to(&mut buf);
300 let preprepare_timeout = self.inner.config.connect_timeout;
301 let preprepare_result: PgResult<()> = match tokio::time::timeout(
302 preprepare_timeout,
303 async {
304 conn.send_bytes(&buf).await?;
305 let mut parse_complete_count = 0usize;
307 let mut parse_error: Option<PgError> = None;
308 loop {
309 let msg = conn.recv().await?;
310 if handle_hot_preprepare_message(
311 &msg,
312 &mut parse_complete_count,
313 &mut parse_error,
314 )? {
315 if let Some(err) = parse_error {
316 return Err(err);
317 }
318 if parse_complete_count != missing.len() {
319 return Err(PgError::Protocol(format!(
320 "hot pre-prepare completed with {} ParseComplete messages (expected {})",
321 parse_complete_count,
322 missing.len()
323 )));
324 }
325 break;
326 }
327 }
328 Ok::<(), PgError>(())
329 },
330 )
331 .await
332 {
333 Ok(res) => res,
334 Err(_) => Err(PgError::Timeout(format!(
335 "hot statement pre-prepare timeout after {:?} (pool config connect_timeout)",
336 preprepare_timeout
337 ))),
338 };
339
340 if let Err(e) = preprepare_result {
341 tracing::warn!(
342 host = %self.inner.config.host,
343 port = self.inner.config.port,
344 user = %self.inner.config.user,
345 db = %self.inner.config.database,
346 timeout_ms = preprepare_timeout.as_millis() as u64,
347 error = %e,
348 "pool_hot_prepare_failed: replacing connection to avoid handing out uncertain protocol state"
349 );
350 pool_churn_record_destroy(&self.inner.config, "hot_prepare_failed");
351 conn = Self::create_connection(&self.inner.config).await?;
352 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
353 created_at = Instant::now();
354 } else {
355 for (hash, name, sql) in &missing {
357 conn.stmt_cache.put(*hash, name.clone());
358 conn.prepared_statements.insert(name.clone(), sql.clone());
359 }
360 }
361 }
362
363 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
364 permit.forget();
366
367 Ok(PooledConnection {
368 conn: Some(conn),
369 pool: self.inner.clone(),
370 rls_dirty: false,
371 created_at,
372 })
373 }
374
375 pub async fn acquire_with_rls(
391 &self,
392 ctx: qail_core::rls::RlsContext,
393 ) -> PgResult<PooledConnection> {
394 let mut conn = self.acquire_raw().await?;
396
397 let sql = crate::driver::rls::context_to_sql(&ctx);
399 let pg_conn = conn.get_mut()?;
400 if let Err(e) = execute_simple_with_timeout(
401 pg_conn,
402 &sql,
403 self.inner.config.connect_timeout,
404 "pool acquire_with_rls setup",
405 )
406 .await
407 {
408 if let Ok(pg_conn) = conn.get_mut() {
411 let _ = pg_conn.execute_simple("ROLLBACK").await;
412 }
413 conn.release().await;
414 return Err(e);
415 }
416
417 conn.rls_dirty = true;
419
420 Ok(conn)
421 }
422
423 pub async fn acquire_with_rls_timeout(
428 &self,
429 ctx: qail_core::rls::RlsContext,
430 timeout_ms: u32,
431 ) -> PgResult<PooledConnection> {
432 let mut conn = self.acquire_raw().await?;
434
435 let sql = crate::driver::rls::context_to_sql_with_timeout(&ctx, timeout_ms);
437 let pg_conn = conn.get_mut()?;
438 if let Err(e) = execute_simple_with_timeout(
439 pg_conn,
440 &sql,
441 self.inner.config.connect_timeout,
442 "pool acquire_with_rls_timeout setup",
443 )
444 .await
445 {
446 if let Ok(pg_conn) = conn.get_mut() {
447 let _ = pg_conn.execute_simple("ROLLBACK").await;
448 }
449 conn.release().await;
450 return Err(e);
451 }
452
453 conn.rls_dirty = true;
455
456 Ok(conn)
457 }
458
459 pub async fn acquire_with_rls_timeouts(
465 &self,
466 ctx: qail_core::rls::RlsContext,
467 statement_timeout_ms: u32,
468 lock_timeout_ms: u32,
469 ) -> PgResult<PooledConnection> {
470 let mut conn = self.acquire_raw().await?;
472
473 let sql = crate::driver::rls::context_to_sql_with_timeouts(
474 &ctx,
475 statement_timeout_ms,
476 lock_timeout_ms,
477 );
478 let pg_conn = conn.get_mut()?;
479 if let Err(e) = execute_simple_with_timeout(
480 pg_conn,
481 &sql,
482 self.inner.config.connect_timeout,
483 "pool acquire_with_rls_timeouts setup",
484 )
485 .await
486 {
487 if let Ok(pg_conn) = conn.get_mut() {
488 let _ = pg_conn.execute_simple("ROLLBACK").await;
489 }
490 conn.release().await;
491 return Err(e);
492 }
493
494 conn.rls_dirty = true;
495
496 Ok(conn)
497 }
498
499 pub async fn acquire_system(&self) -> PgResult<PooledConnection> {
509 let ctx = qail_core::rls::RlsContext::empty();
510 self.acquire_with_rls(ctx).await
511 }
512
513 pub async fn acquire_for_tenant(&self, tenant_id: &str) -> PgResult<PooledConnection> {
525 self.acquire_with_rls(qail_core::rls::RlsContext::tenant(tenant_id))
526 .await
527 }
528
529 pub async fn acquire_with_branch(
543 &self,
544 ctx: &qail_core::branch::BranchContext,
545 ) -> PgResult<PooledConnection> {
546 let mut conn = self.acquire_raw().await?;
548
549 if let Some(branch_name) = ctx.branch_name() {
550 let sql = crate::driver::branch_sql::branch_context_sql(branch_name);
551 let pg_conn = conn.get_mut()?;
552 if let Err(e) = execute_simple_with_timeout(
553 pg_conn,
554 &sql,
555 self.inner.config.connect_timeout,
556 "pool acquire_with_branch setup",
557 )
558 .await
559 {
560 if let Ok(pg_conn) = conn.get_mut() {
561 let _ = pg_conn.execute_simple("ROLLBACK").await;
562 }
563 conn.release().await;
564 return Err(e);
565 }
566 conn.rls_dirty = true; }
568
569 Ok(conn)
570 }
571
572 pub async fn idle_count(&self) -> usize {
574 self.inner.connections.lock().await.len()
575 }
576
577 pub fn active_count(&self) -> usize {
579 self.inner.active_count.load(Ordering::Relaxed)
580 }
581
582 pub fn max_connections(&self) -> usize {
584 self.inner.config.max_connections
585 }
586
587 pub async fn stats(&self) -> PoolStats {
589 let idle = self.inner.connections.lock().await.len();
590 let active = self.inner.active_count.load(Ordering::Relaxed);
591 let used_slots = self
592 .inner
593 .config
594 .max_connections
595 .saturating_sub(self.inner.semaphore.available_permits());
596 PoolStats {
597 active,
598 idle,
599 pending: used_slots.saturating_sub(active),
600 max_size: self.inner.config.max_connections,
601 total_created: self.inner.total_created.load(Ordering::Relaxed),
602 }
603 }
604
605 pub fn is_closed(&self) -> bool {
607 self.inner.closed.load(Ordering::Relaxed)
608 }
609
610 pub async fn close(&self) {
617 self.close_graceful(self.inner.config.acquire_timeout).await;
618 }
619
620 pub async fn close_graceful(&self, drain_timeout: Duration) {
622 self.inner.closed.store(true, Ordering::Relaxed);
623 self.inner.semaphore.close();
625
626 let deadline = Instant::now() + drain_timeout;
627 loop {
628 let active = self.inner.active_count.load(Ordering::Relaxed);
629 if active == 0 {
630 break;
631 }
632 if Instant::now() >= deadline {
633 tracing::warn!(
634 active_connections = active,
635 timeout_ms = drain_timeout.as_millis() as u64,
636 "pool_close_drain_timeout: forcing idle cleanup while active connections remain"
637 );
638 break;
639 }
640 tokio::time::sleep(Duration::from_millis(25)).await;
641 }
642
643 let mut connections = self.inner.connections.lock().await;
644 let dropped_idle = connections.len();
645 connections.clear();
646 tracing::info!(
647 dropped_idle_connections = dropped_idle,
648 active_connections = self.inner.active_count.load(Ordering::Relaxed),
649 "pool_closed"
650 );
651 }
652
653 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
655 if !config.auth_settings.has_any_password_method()
656 && config.mtls.is_none()
657 && config.password.is_some()
658 {
659 return Err(PgError::Auth(
660 "Invalid PoolConfig: all password auth methods are disabled".to_string(),
661 ));
662 }
663
664 let options = ConnectOptions {
665 tls_mode: config.tls_mode,
666 gss_enc_mode: config.gss_enc_mode,
667 tls_ca_cert_pem: config.tls_ca_cert_pem.clone(),
668 mtls: config.mtls.clone(),
669 gss_token_provider: config.gss_token_provider,
670 gss_token_provider_ex: config.gss_token_provider_ex.clone(),
671 auth: config.auth_settings,
672 startup_params: Vec::new(),
673 };
674
675 if let Some(remaining) = gss_circuit_remaining_open(config) {
676 metrics::counter!("qail_pg_gss_circuit_open_total").increment(1);
677 tracing::warn!(
678 host = %config.host,
679 port = config.port,
680 user = %config.user,
681 db = %config.database,
682 remaining_ms = remaining.as_millis() as u64,
683 "gss_connect_circuit_open"
684 );
685 return Err(PgError::Connection(format!(
686 "GSS connection circuit is open; retry after {:?}",
687 remaining
688 )));
689 }
690
691 let mut attempt = 0usize;
692 loop {
693 let connect_result = tokio::time::timeout(
694 config.connect_timeout,
695 PgConnection::connect_with_options(
696 &config.host,
697 config.port,
698 &config.user,
699 &config.database,
700 config.password.as_deref(),
701 options.clone(),
702 ),
703 )
704 .await;
705
706 let connect_result = match connect_result {
707 Ok(result) => result,
708 Err(_) => Err(PgError::Timeout(format!(
709 "connect timeout after {:?} (pool config connect_timeout)",
710 config.connect_timeout
711 ))),
712 };
713
714 match connect_result {
715 Ok(conn) => {
716 metrics::counter!("qail_pg_pool_connect_success_total").increment(1);
717 gss_circuit_record_success(config);
718 return Ok(conn);
719 }
720 Err(err) if should_retry_gss_connect_error(config, attempt, &err) => {
721 metrics::counter!("qail_pg_gss_connect_retries_total").increment(1);
722 gss_circuit_record_failure(config);
723 let delay = gss_retry_delay(config.gss_retry_base_delay, attempt);
724 tracing::warn!(
725 host = %config.host,
726 port = config.port,
727 user = %config.user,
728 db = %config.database,
729 attempt = attempt + 1,
730 delay_ms = delay.as_millis() as u64,
731 error = %err,
732 "gss_connect_retry"
733 );
734 tokio::time::sleep(delay).await;
735 attempt += 1;
736 }
737 Err(err) => {
738 metrics::counter!("qail_pg_pool_connect_failures_total").increment(1);
739 if should_track_gss_circuit_error(config, &err) {
740 metrics::counter!("qail_pg_gss_connect_failures_total").increment(1);
741 gss_circuit_record_failure(config);
742 }
743 return Err(err);
744 }
745 }
746 }
747 }
748
749 pub async fn maintain(&self) {
752 if self.inner.closed.load(Ordering::Relaxed) {
753 return;
754 }
755
756 let evicted = {
758 let mut connections = self.inner.connections.lock().await;
759 let before = connections.len();
760 connections.retain(|pooled| {
761 if pooled.last_used.elapsed() > self.inner.config.idle_timeout {
762 record_pool_connection_destroy("idle_sweep_evict");
763 return false;
764 }
765 if let Some(max_life) = self.inner.config.max_lifetime
766 && pooled.created_at.elapsed() > max_life
767 {
768 record_pool_connection_destroy("lifetime_sweep_evict");
769 return false;
770 }
771 true
772 });
773 before - connections.len()
774 };
775
776 if evicted > 0 {
777 tracing::debug!(evicted, "pool_maintenance: evicted stale idle connections");
778 }
779
780 let min = self.inner.config.min_connections;
782 if min == 0 {
783 return;
784 }
785
786 let idle_count = self.inner.connections.lock().await.len();
787 if idle_count >= min {
788 return;
789 }
790
791 let deficit = min - idle_count;
792 let mut created = 0usize;
793 for _ in 0..deficit {
794 match Self::create_connection(&self.inner.config).await {
795 Ok(conn) => {
796 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
797 let mut connections = self.inner.connections.lock().await;
798 if connections.len() < self.inner.config.max_connections {
799 connections.push(PooledConn {
800 conn,
801 created_at: Instant::now(),
802 last_used: Instant::now(),
803 });
804 created += 1;
805 } else {
806 break;
808 }
809 }
810 Err(e) => {
811 tracing::warn!(error = %e, "pool_maintenance: backfill connection failed");
812 break; }
814 }
815 }
816
817 if created > 0 {
818 tracing::debug!(
819 created,
820 min_connections = min,
821 "pool_maintenance: backfilled idle connections"
822 );
823 }
824 }
825}
826
827pub fn spawn_pool_maintenance(pool: PgPool) {
832 let interval_secs = std::cmp::max(pool.inner.config.idle_timeout.as_secs() / 2, 5);
833 tokio::spawn(async move {
834 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
835 loop {
836 interval.tick().await;
837 if pool.is_closed() {
838 break;
839 }
840 pool.maintain().await;
841 }
842 });
843}
844
845pub(super) fn validate_pool_config(config: &PoolConfig) -> PgResult<()> {
846 if config.max_connections == 0 {
847 return Err(PgError::Connection(
848 "Invalid PoolConfig: max_connections must be >= 1".to_string(),
849 ));
850 }
851 if config.min_connections > config.max_connections {
852 return Err(PgError::Connection(format!(
853 "Invalid PoolConfig: min_connections ({}) must be <= max_connections ({})",
854 config.min_connections, config.max_connections
855 )));
856 }
857 if config.acquire_timeout.is_zero() {
858 return Err(PgError::Connection(
859 "Invalid PoolConfig: acquire_timeout must be > 0".to_string(),
860 ));
861 }
862 if config.connect_timeout.is_zero() {
863 return Err(PgError::Connection(
864 "Invalid PoolConfig: connect_timeout must be > 0".to_string(),
865 ));
866 }
867 Ok(())
868}
869
870pub(super) async fn execute_simple_with_timeout(
871 conn: &mut PgConnection,
872 sql: &str,
873 timeout: Duration,
874 operation: &str,
875) -> PgResult<()> {
876 match tokio::time::timeout(timeout, conn.execute_simple(sql)).await {
877 Ok(result) => result,
878 Err(_) => {
879 conn.mark_io_desynced();
880 Err(PgError::Timeout(format!(
881 "{} timeout after {:?} (pool config connect_timeout)",
882 operation, timeout
883 )))
884 }
885 }
886}