1use super::churn::{
5 PoolStats, decrement_active_count_saturating, pool_churn_record_destroy,
6 pool_churn_remaining_open, record_pool_connection_destroy,
7};
8use super::config::PoolConfig;
9use super::connection::PooledConn;
10use super::connection::PooledConnection;
11use super::gss::*;
12use crate::driver::{
13 ConnectOptions, PgConnection, PgError, PgResult, is_ignorable_session_message,
14 unexpected_backend_message,
15};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
18use std::time::{Duration, Instant};
19use tokio::sync::{Mutex, Semaphore};
20
/// Cap on the number of "hot" statements. Not referenced in this part of the
/// file — presumably it bounds `PgPoolInner::hot_statements` (TODO confirm).
pub(super) const MAX_HOT_STATEMENTS: usize = 32;
23
/// Shared state behind every `PgPool` handle.
pub(super) struct PgPoolInner {
    /// Configuration the pool was built with.
    pub(super) config: PoolConfig,
    /// Idle connections. Returned/backfilled connections are pushed onto the
    /// end and checkout pops from the end.
    pub(super) connections: Mutex<Vec<PooledConn>>,
    /// Bounds concurrent checkouts; created with `max_connections` permits.
    /// A permit is forgotten at checkout and re-added when the connection is returned.
    pub(super) semaphore: Semaphore,
    /// Set by `close_graceful`; acquire and return paths consult it.
    pub(super) closed: AtomicBool,
    /// Number of connections currently checked out.
    pub(super) active_count: AtomicUsize,
    /// Running count of connections opened by this pool (initial ones included).
    pub(super) total_created: AtomicUsize,
    /// Hot statements: `stmt_cache` key (presumably a SQL hash) -> (statement
    /// name, SQL). Checkout pre-prepares entries a connection has not cached yet.
    pub(super) hot_statements: std::sync::RwLock<std::collections::HashMap<u64, (String, String)>>,
}
37
38pub(super) fn handle_hot_preprepare_message(
39 msg: &crate::protocol::BackendMessage,
40 parse_complete_count: &mut usize,
41 error: &mut Option<PgError>,
42) -> PgResult<bool> {
43 match msg {
44 crate::protocol::BackendMessage::ParseComplete => {
45 *parse_complete_count += 1;
46 Ok(false)
47 }
48 crate::protocol::BackendMessage::ErrorResponse(err) => {
49 if error.is_none() {
50 *error = Some(PgError::QueryServer(err.clone().into()));
51 }
52 Ok(false)
53 }
54 crate::protocol::BackendMessage::ReadyForQuery(_) => Ok(true),
55 msg if is_ignorable_session_message(msg) => Ok(false),
56 other => Err(unexpected_backend_message("pool hot pre-prepare", other)),
57 }
58}
59
60impl PgPoolInner {
61 pub(super) async fn return_connection(&self, conn: PgConnection, created_at: Instant) {
62 decrement_active_count_saturating(&self.active_count);
63
64 if conn.is_io_desynced() {
65 tracing::warn!(
66 host = %self.config.host,
67 port = self.config.port,
68 user = %self.config.user,
69 db = %self.config.database,
70 "pool_return_desynced: dropping connection due to prior I/O/protocol desync"
71 );
72 record_pool_connection_destroy("pool_desynced_drop");
73 self.semaphore.add_permits(1);
74 pool_churn_record_destroy(&self.config, "return_desynced");
75 return;
76 }
77
78 if self.closed.load(Ordering::Relaxed) {
79 record_pool_connection_destroy("pool_closed_drop");
80 self.semaphore.add_permits(1);
81 return;
82 }
83
84 let mut connections = self.connections.lock().await;
85 if connections.len() < self.config.max_connections {
86 connections.push(PooledConn {
87 conn,
88 created_at,
89 last_used: Instant::now(),
90 });
91 } else {
92 record_pool_connection_destroy("pool_overflow_drop");
93 }
94
95 self.semaphore.add_permits(1);
96 }
97
98 async fn get_healthy_connection(&self) -> Option<PooledConn> {
100 let mut connections = self.connections.lock().await;
101
102 while let Some(pooled) = connections.pop() {
103 if pooled.last_used.elapsed() > self.config.idle_timeout {
104 tracing::debug!(
105 idle_secs = pooled.last_used.elapsed().as_secs(),
106 timeout_secs = self.config.idle_timeout.as_secs(),
107 "pool_checkout_evict: connection exceeded idle timeout"
108 );
109 record_pool_connection_destroy("idle_timeout_evict");
110 continue;
111 }
112
113 if let Some(max_life) = self.config.max_lifetime
114 && pooled.created_at.elapsed() > max_life
115 {
116 tracing::debug!(
117 age_secs = pooled.created_at.elapsed().as_secs(),
118 max_lifetime_secs = max_life.as_secs(),
119 "pool_checkout_evict: connection exceeded max lifetime"
120 );
121 record_pool_connection_destroy("max_lifetime_evict");
122 continue;
123 }
124
125 return Some(pooled);
126 }
127
128 None
129 }
130}
131
/// Cloneable handle to a PostgreSQL connection pool.
///
/// Every clone shares the same `PgPoolInner` through an `Arc`, so state such
/// as the idle list and the `closed` flag is common to all handles.
#[derive(Clone)]
pub struct PgPool {
    pub(super) inner: Arc<PgPoolInner>,
}
146
147impl PgPool {
148 pub async fn from_config() -> PgResult<Self> {
155 let qail = qail_core::config::QailConfig::load()
156 .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
157 let config = PoolConfig::from_qail_config(&qail)?;
158 Self::connect(config).await
159 }
160
161 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
163 validate_pool_config(&config)?;
164
165 let semaphore = Semaphore::new(config.max_connections);
167
168 let mut initial_connections = Vec::new();
169 for _ in 0..config.min_connections {
170 let conn = Self::create_connection(&config).await?;
171 initial_connections.push(PooledConn {
172 conn,
173 created_at: Instant::now(),
174 last_used: Instant::now(),
175 });
176 }
177
178 let initial_count = initial_connections.len();
179
180 let inner = Arc::new(PgPoolInner {
181 config,
182 connections: Mutex::new(initial_connections),
183 semaphore,
184 closed: AtomicBool::new(false),
185 active_count: AtomicUsize::new(0),
186 total_created: AtomicUsize::new(initial_count),
187 hot_statements: std::sync::RwLock::new(std::collections::HashMap::new()),
188 });
189
190 Ok(Self { inner })
191 }
192
193 pub async fn acquire_raw(&self) -> PgResult<PooledConnection> {
208 if self.inner.closed.load(Ordering::Relaxed) {
209 return Err(PgError::PoolClosed);
210 }
211
212 if let Some(remaining) = pool_churn_remaining_open(&self.inner.config) {
213 metrics::counter!("qail_pg_pool_churn_circuit_reject_total").increment(1);
214 tracing::warn!(
215 host = %self.inner.config.host,
216 port = self.inner.config.port,
217 user = %self.inner.config.user,
218 db = %self.inner.config.database,
219 remaining_ms = remaining.as_millis() as u64,
220 "pool_connection_churn_circuit_open"
221 );
222 return Err(PgError::PoolExhausted {
223 max: self.inner.config.max_connections,
224 });
225 }
226
227 let acquire_timeout = self.inner.config.acquire_timeout;
229 let permit =
230 match tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire()).await {
231 Ok(permit) => permit.map_err(|_| PgError::PoolClosed)?,
232 Err(_) => {
233 metrics::counter!("qail_pg_pool_acquire_timeouts_total").increment(1);
234 return Err(PgError::Timeout(format!(
235 "pool acquire after {}s ({} max connections)",
236 acquire_timeout.as_secs(),
237 self.inner.config.max_connections
238 )));
239 }
240 };
241
242 if self.inner.closed.load(Ordering::Relaxed) {
243 return Err(PgError::PoolClosed);
244 }
245
246 let (mut conn, mut created_at) =
248 if let Some(pooled) = self.inner.get_healthy_connection().await {
249 (pooled.conn, pooled.created_at)
250 } else {
251 let conn = Self::create_connection(&self.inner.config).await?;
252 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
253 (conn, Instant::now())
254 };
255
256 if self.inner.config.test_on_acquire
257 && let Err(e) = execute_simple_with_timeout(
258 &mut conn,
259 "SELECT 1",
260 self.inner.config.connect_timeout,
261 "pool checkout health check",
262 )
263 .await
264 {
265 tracing::warn!(
266 host = %self.inner.config.host,
267 port = self.inner.config.port,
268 user = %self.inner.config.user,
269 db = %self.inner.config.database,
270 error = %e,
271 "pool_health_check_failed: checkout probe failed, creating replacement connection"
272 );
273 pool_churn_record_destroy(&self.inner.config, "health_check_failed");
274 conn = Self::create_connection(&self.inner.config).await?;
275 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
276 created_at = Instant::now();
277 }
278
279 let missing: Vec<(u64, String, String)> = {
282 if let Ok(hot) = self.inner.hot_statements.read() {
283 hot.iter()
284 .filter(|(hash, _)| !conn.stmt_cache.contains(hash))
285 .map(|(hash, (name, sql))| (*hash, name.clone(), sql.clone()))
286 .collect()
287 } else {
288 Vec::new()
289 }
290 }; if !missing.is_empty() {
293 use crate::protocol::PgEncoder;
294 let mut buf = bytes::BytesMut::new();
295 for (_, name, sql) in &missing {
296 let parse_msg = PgEncoder::try_encode_parse(name, sql, &[])?;
297 buf.extend_from_slice(&parse_msg);
298 }
299 PgEncoder::encode_sync_to(&mut buf);
300 let preprepare_timeout = self.inner.config.connect_timeout;
301 let preprepare_result: PgResult<()> = match tokio::time::timeout(
302 preprepare_timeout,
303 async {
304 conn.send_bytes(&buf).await?;
305 let mut parse_complete_count = 0usize;
307 let mut parse_error: Option<PgError> = None;
308 loop {
309 let msg = conn.recv().await?;
310 if handle_hot_preprepare_message(
311 &msg,
312 &mut parse_complete_count,
313 &mut parse_error,
314 )? {
315 if let Some(err) = parse_error {
316 return Err(err);
317 }
318 if parse_complete_count != missing.len() {
319 return Err(PgError::Protocol(format!(
320 "hot pre-prepare completed with {} ParseComplete messages (expected {})",
321 parse_complete_count,
322 missing.len()
323 )));
324 }
325 break;
326 }
327 }
328 Ok::<(), PgError>(())
329 },
330 )
331 .await
332 {
333 Ok(res) => res,
334 Err(_) => Err(PgError::Timeout(format!(
335 "hot statement pre-prepare timeout after {:?} (pool config connect_timeout)",
336 preprepare_timeout
337 ))),
338 };
339
340 if let Err(e) = preprepare_result {
341 tracing::warn!(
342 host = %self.inner.config.host,
343 port = self.inner.config.port,
344 user = %self.inner.config.user,
345 db = %self.inner.config.database,
346 timeout_ms = preprepare_timeout.as_millis() as u64,
347 error = %e,
348 "pool_hot_prepare_failed: replacing connection to avoid handing out uncertain protocol state"
349 );
350 pool_churn_record_destroy(&self.inner.config, "hot_prepare_failed");
351 conn = Self::create_connection(&self.inner.config).await?;
352 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
353 created_at = Instant::now();
354 } else {
355 for (hash, name, sql) in &missing {
357 conn.stmt_cache.put(*hash, name.clone());
358 conn.prepared_statements.insert(name.clone(), sql.clone());
359 }
360 }
361 }
362
363 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
364 permit.forget();
366
367 Ok(PooledConnection {
368 conn: Some(conn),
369 pool: self.inner.clone(),
370 rls_dirty: false,
371 created_at,
372 })
373 }
374
375 pub async fn acquire_with_rls(
391 &self,
392 ctx: qail_core::rls::RlsContext,
393 ) -> PgResult<PooledConnection> {
394 let mut conn = self.acquire_raw().await?;
396
397 let sql = crate::driver::rls::context_to_sql(&ctx);
399 let pg_conn = conn.get_mut()?;
400 if let Err(e) = execute_simple_with_timeout(
401 pg_conn,
402 &sql,
403 self.inner.config.connect_timeout,
404 "pool acquire_with_rls setup",
405 )
406 .await
407 {
408 if let Ok(pg_conn) = conn.get_mut() {
411 let _ = pg_conn.execute_simple("ROLLBACK").await;
412 }
413 conn.release().await;
414 return Err(e);
415 }
416
417 conn.rls_dirty = true;
419
420 Ok(conn)
421 }
422
423 pub async fn acquire_with_rls_timeout(
428 &self,
429 ctx: qail_core::rls::RlsContext,
430 timeout_ms: u32,
431 ) -> PgResult<PooledConnection> {
432 let mut conn = self.acquire_raw().await?;
434
435 let sql = crate::driver::rls::context_to_sql_with_timeout(&ctx, timeout_ms);
437 let pg_conn = conn.get_mut()?;
438 if let Err(e) = execute_simple_with_timeout(
439 pg_conn,
440 &sql,
441 self.inner.config.connect_timeout,
442 "pool acquire_with_rls_timeout setup",
443 )
444 .await
445 {
446 if let Ok(pg_conn) = conn.get_mut() {
447 let _ = pg_conn.execute_simple("ROLLBACK").await;
448 }
449 conn.release().await;
450 return Err(e);
451 }
452
453 conn.rls_dirty = true;
455
456 Ok(conn)
457 }
458
459 pub async fn acquire_with_rls_timeouts(
465 &self,
466 ctx: qail_core::rls::RlsContext,
467 statement_timeout_ms: u32,
468 lock_timeout_ms: u32,
469 ) -> PgResult<PooledConnection> {
470 let mut conn = self.acquire_raw().await?;
472
473 let sql = crate::driver::rls::context_to_sql_with_timeouts(
474 &ctx,
475 statement_timeout_ms,
476 lock_timeout_ms,
477 );
478 let pg_conn = conn.get_mut()?;
479 if let Err(e) = execute_simple_with_timeout(
480 pg_conn,
481 &sql,
482 self.inner.config.connect_timeout,
483 "pool acquire_with_rls_timeouts setup",
484 )
485 .await
486 {
487 if let Ok(pg_conn) = conn.get_mut() {
488 let _ = pg_conn.execute_simple("ROLLBACK").await;
489 }
490 conn.release().await;
491 return Err(e);
492 }
493
494 conn.rls_dirty = true;
495
496 Ok(conn)
497 }
498
499 pub async fn acquire_system(&self) -> PgResult<PooledConnection> {
509 let ctx = qail_core::rls::RlsContext::empty();
510 self.acquire_with_rls(ctx).await
511 }
512
513 pub async fn acquire_global(&self) -> PgResult<PooledConnection> {
519 self.acquire_with_rls(qail_core::rls::RlsContext::global())
520 .await
521 }
522
523 pub async fn acquire_for_tenant(&self, tenant_id: &str) -> PgResult<PooledConnection> {
535 self.acquire_with_rls(qail_core::rls::RlsContext::tenant(tenant_id))
536 .await
537 }
538
539 pub async fn acquire_with_branch(
553 &self,
554 ctx: &qail_core::branch::BranchContext,
555 ) -> PgResult<PooledConnection> {
556 let mut conn = self.acquire_raw().await?;
558
559 if let Some(branch_name) = ctx.branch_name() {
560 let sql = crate::driver::branch_sql::branch_context_sql(branch_name);
561 let pg_conn = conn.get_mut()?;
562 if let Err(e) = execute_simple_with_timeout(
563 pg_conn,
564 &sql,
565 self.inner.config.connect_timeout,
566 "pool acquire_with_branch setup",
567 )
568 .await
569 {
570 if let Ok(pg_conn) = conn.get_mut() {
571 let _ = pg_conn.execute_simple("ROLLBACK").await;
572 }
573 conn.release().await;
574 return Err(e);
575 }
576 conn.rls_dirty = true; }
578
579 Ok(conn)
580 }
581
582 pub async fn idle_count(&self) -> usize {
584 self.inner.connections.lock().await.len()
585 }
586
587 pub fn active_count(&self) -> usize {
589 self.inner.active_count.load(Ordering::Relaxed)
590 }
591
592 pub fn max_connections(&self) -> usize {
594 self.inner.config.max_connections
595 }
596
597 pub async fn stats(&self) -> PoolStats {
599 let idle = self.inner.connections.lock().await.len();
600 let active = self.inner.active_count.load(Ordering::Relaxed);
601 let used_slots = self
602 .inner
603 .config
604 .max_connections
605 .saturating_sub(self.inner.semaphore.available_permits());
606 PoolStats {
607 active,
608 idle,
609 pending: used_slots.saturating_sub(active),
610 max_size: self.inner.config.max_connections,
611 total_created: self.inner.total_created.load(Ordering::Relaxed),
612 }
613 }
614
615 pub fn is_closed(&self) -> bool {
617 self.inner.closed.load(Ordering::Relaxed)
618 }
619
620 pub async fn close(&self) {
627 self.close_graceful(self.inner.config.acquire_timeout).await;
628 }
629
630 pub async fn close_graceful(&self, drain_timeout: Duration) {
632 self.inner.closed.store(true, Ordering::Relaxed);
633 self.inner.semaphore.close();
635
636 let deadline = Instant::now() + drain_timeout;
637 loop {
638 let active = self.inner.active_count.load(Ordering::Relaxed);
639 if active == 0 {
640 break;
641 }
642 if Instant::now() >= deadline {
643 tracing::warn!(
644 active_connections = active,
645 timeout_ms = drain_timeout.as_millis() as u64,
646 "pool_close_drain_timeout: forcing idle cleanup while active connections remain"
647 );
648 break;
649 }
650 tokio::time::sleep(Duration::from_millis(25)).await;
651 }
652
653 let mut connections = self.inner.connections.lock().await;
654 let dropped_idle = connections.len();
655 connections.clear();
656 tracing::info!(
657 dropped_idle_connections = dropped_idle,
658 active_connections = self.inner.active_count.load(Ordering::Relaxed),
659 "pool_closed"
660 );
661 }
662
663 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
665 if !config.auth_settings.has_any_password_method()
666 && config.mtls.is_none()
667 && config.password.is_some()
668 {
669 return Err(PgError::Auth(
670 "Invalid PoolConfig: all password auth methods are disabled".to_string(),
671 ));
672 }
673
674 let options = ConnectOptions {
675 tls_mode: config.tls_mode,
676 gss_enc_mode: config.gss_enc_mode,
677 tls_ca_cert_pem: config.tls_ca_cert_pem.clone(),
678 mtls: config.mtls.clone(),
679 gss_token_provider: config.gss_token_provider,
680 gss_token_provider_ex: config.gss_token_provider_ex.clone(),
681 auth: config.auth_settings,
682 startup_params: Vec::new(),
683 };
684
685 if let Some(remaining) = gss_circuit_remaining_open(config) {
686 metrics::counter!("qail_pg_gss_circuit_open_total").increment(1);
687 tracing::warn!(
688 host = %config.host,
689 port = config.port,
690 user = %config.user,
691 db = %config.database,
692 remaining_ms = remaining.as_millis() as u64,
693 "gss_connect_circuit_open"
694 );
695 return Err(PgError::Connection(format!(
696 "GSS connection circuit is open; retry after {:?}",
697 remaining
698 )));
699 }
700
701 let mut attempt = 0usize;
702 loop {
703 let connect_result = tokio::time::timeout(
704 config.connect_timeout,
705 PgConnection::connect_with_options(
706 &config.host,
707 config.port,
708 &config.user,
709 &config.database,
710 config.password.as_deref(),
711 options.clone(),
712 ),
713 )
714 .await;
715
716 let connect_result = match connect_result {
717 Ok(result) => result,
718 Err(_) => Err(PgError::Timeout(format!(
719 "connect timeout after {:?} (pool config connect_timeout)",
720 config.connect_timeout
721 ))),
722 };
723
724 match connect_result {
725 Ok(conn) => {
726 metrics::counter!("qail_pg_pool_connect_success_total").increment(1);
727 gss_circuit_record_success(config);
728 return Ok(conn);
729 }
730 Err(err) if should_retry_gss_connect_error(config, attempt, &err) => {
731 metrics::counter!("qail_pg_gss_connect_retries_total").increment(1);
732 gss_circuit_record_failure(config);
733 let delay = gss_retry_delay(config.gss_retry_base_delay, attempt);
734 tracing::warn!(
735 host = %config.host,
736 port = config.port,
737 user = %config.user,
738 db = %config.database,
739 attempt = attempt + 1,
740 delay_ms = delay.as_millis() as u64,
741 error = %err,
742 "gss_connect_retry"
743 );
744 tokio::time::sleep(delay).await;
745 attempt += 1;
746 }
747 Err(err) => {
748 metrics::counter!("qail_pg_pool_connect_failures_total").increment(1);
749 if should_track_gss_circuit_error(config, &err) {
750 metrics::counter!("qail_pg_gss_connect_failures_total").increment(1);
751 gss_circuit_record_failure(config);
752 }
753 return Err(err);
754 }
755 }
756 }
757 }
758
759 pub async fn maintain(&self) {
762 if self.inner.closed.load(Ordering::Relaxed) {
763 return;
764 }
765
766 let evicted = {
768 let mut connections = self.inner.connections.lock().await;
769 let before = connections.len();
770 connections.retain(|pooled| {
771 if pooled.last_used.elapsed() > self.inner.config.idle_timeout {
772 record_pool_connection_destroy("idle_sweep_evict");
773 return false;
774 }
775 if let Some(max_life) = self.inner.config.max_lifetime
776 && pooled.created_at.elapsed() > max_life
777 {
778 record_pool_connection_destroy("lifetime_sweep_evict");
779 return false;
780 }
781 true
782 });
783 before - connections.len()
784 };
785
786 if evicted > 0 {
787 tracing::debug!(evicted, "pool_maintenance: evicted stale idle connections");
788 }
789
790 let min = self.inner.config.min_connections;
792 if min == 0 {
793 return;
794 }
795
796 let idle_count = self.inner.connections.lock().await.len();
797 if idle_count >= min {
798 return;
799 }
800
801 let deficit = min - idle_count;
802 let mut created = 0usize;
803 for _ in 0..deficit {
804 match Self::create_connection(&self.inner.config).await {
805 Ok(conn) => {
806 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
807 let mut connections = self.inner.connections.lock().await;
808 if connections.len() < self.inner.config.max_connections {
809 connections.push(PooledConn {
810 conn,
811 created_at: Instant::now(),
812 last_used: Instant::now(),
813 });
814 created += 1;
815 } else {
816 break;
818 }
819 }
820 Err(e) => {
821 tracing::warn!(error = %e, "pool_maintenance: backfill connection failed");
822 break; }
824 }
825 }
826
827 if created > 0 {
828 tracing::debug!(
829 created,
830 min_connections = min,
831 "pool_maintenance: backfilled idle connections"
832 );
833 }
834 }
835}
836
837pub fn spawn_pool_maintenance(pool: PgPool) {
842 let interval_secs = std::cmp::max(pool.inner.config.idle_timeout.as_secs() / 2, 5);
843 tokio::spawn(async move {
844 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
845 loop {
846 interval.tick().await;
847 if pool.is_closed() {
848 break;
849 }
850 pool.maintain().await;
851 }
852 });
853}
854
855pub(super) fn validate_pool_config(config: &PoolConfig) -> PgResult<()> {
856 if config.max_connections == 0 {
857 return Err(PgError::Connection(
858 "Invalid PoolConfig: max_connections must be >= 1".to_string(),
859 ));
860 }
861 if config.min_connections > config.max_connections {
862 return Err(PgError::Connection(format!(
863 "Invalid PoolConfig: min_connections ({}) must be <= max_connections ({})",
864 config.min_connections, config.max_connections
865 )));
866 }
867 if config.acquire_timeout.is_zero() {
868 return Err(PgError::Connection(
869 "Invalid PoolConfig: acquire_timeout must be > 0".to_string(),
870 ));
871 }
872 if config.connect_timeout.is_zero() {
873 return Err(PgError::Connection(
874 "Invalid PoolConfig: connect_timeout must be > 0".to_string(),
875 ));
876 }
877 Ok(())
878}
879
880pub(super) async fn execute_simple_with_timeout(
881 conn: &mut PgConnection,
882 sql: &str,
883 timeout: Duration,
884 operation: &str,
885) -> PgResult<()> {
886 match tokio::time::timeout(timeout, conn.execute_simple(sql)).await {
887 Ok(result) => result,
888 Err(_) => {
889 conn.mark_io_desynced();
890 Err(PgError::Timeout(format!(
891 "{} timeout after {:?} (pool config connect_timeout)",
892 operation, timeout
893 )))
894 }
895 }
896}