1use super::{
7 AuthSettings, ConnectOptions, GssEncMode, GssTokenProvider, GssTokenProviderEx, PgConnection,
8 PgError, PgResult, ResultFormat, ScramChannelBindingMode, TlsConfig, TlsMode,
9};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::sync::OnceLock;
13use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15use tokio::sync::{Mutex, Semaphore};
16
17#[derive(Clone)]
30pub struct PoolConfig {
31 pub host: String,
33 pub port: u16,
35 pub user: String,
37 pub database: String,
39 pub password: Option<String>,
41 pub max_connections: usize,
43 pub min_connections: usize,
45 pub idle_timeout: Duration,
47 pub acquire_timeout: Duration,
49 pub connect_timeout: Duration,
51 pub max_lifetime: Option<Duration>,
53 pub test_on_acquire: bool,
55 pub tls_mode: TlsMode,
57 pub tls_ca_cert_pem: Option<Vec<u8>>,
59 pub mtls: Option<TlsConfig>,
61 pub gss_token_provider: Option<GssTokenProvider>,
63 pub gss_token_provider_ex: Option<GssTokenProviderEx>,
65 pub gss_connect_retries: usize,
67 pub gss_retry_base_delay: Duration,
69 pub gss_circuit_breaker_threshold: usize,
71 pub gss_circuit_breaker_window: Duration,
73 pub gss_circuit_breaker_cooldown: Duration,
75 pub auth_settings: AuthSettings,
77 pub gss_enc_mode: GssEncMode,
79}
80
81impl PoolConfig {
82 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
94 Self {
95 host: host.to_string(),
96 port,
97 user: user.to_string(),
98 database: database.to_string(),
99 password: None,
100 max_connections: 10,
101 min_connections: 1,
102 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, test_on_acquire: false, tls_mode: TlsMode::Prefer,
108 tls_ca_cert_pem: None,
109 mtls: None,
110 gss_token_provider: None,
111 gss_token_provider_ex: None,
112 gss_connect_retries: 2,
113 gss_retry_base_delay: Duration::from_millis(150),
114 gss_circuit_breaker_threshold: 8,
115 gss_circuit_breaker_window: Duration::from_secs(30),
116 gss_circuit_breaker_cooldown: Duration::from_secs(15),
117 auth_settings: AuthSettings::scram_only(),
118 gss_enc_mode: GssEncMode::Disable,
119 }
120 }
121
122 pub fn new_dev(host: &str, port: u16, user: &str, database: &str) -> Self {
127 let mut config = Self::new(host, port, user, database);
128 config.tls_mode = TlsMode::Disable;
129 config.auth_settings = AuthSettings::default();
130 config
131 }
132
133 pub fn password(mut self, password: &str) -> Self {
135 self.password = Some(password.to_string());
136 self
137 }
138
139 pub fn max_connections(mut self, max: usize) -> Self {
141 self.max_connections = max;
142 self
143 }
144
145 pub fn min_connections(mut self, min: usize) -> Self {
147 self.min_connections = min;
148 self
149 }
150
151 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
153 self.idle_timeout = timeout;
154 self
155 }
156
157 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
159 self.acquire_timeout = timeout;
160 self
161 }
162
163 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
165 self.connect_timeout = timeout;
166 self
167 }
168
169 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
171 self.max_lifetime = Some(lifetime);
172 self
173 }
174
175 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
177 self.test_on_acquire = enabled;
178 self
179 }
180
181 pub fn tls_mode(mut self, mode: TlsMode) -> Self {
183 self.tls_mode = mode;
184 self
185 }
186
187 pub fn tls_ca_cert_pem(mut self, ca_cert_pem: Vec<u8>) -> Self {
189 self.tls_ca_cert_pem = Some(ca_cert_pem);
190 self
191 }
192
193 pub fn mtls(mut self, config: TlsConfig) -> Self {
195 self.mtls = Some(config);
196 self.tls_mode = TlsMode::Require;
197 self
198 }
199
200 pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
202 self.gss_token_provider = Some(provider);
203 self
204 }
205
206 pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
208 self.gss_token_provider_ex = Some(provider);
209 self
210 }
211
212 pub fn gss_connect_retries(mut self, retries: usize) -> Self {
214 self.gss_connect_retries = retries;
215 self
216 }
217
218 pub fn gss_retry_base_delay(mut self, delay: Duration) -> Self {
220 self.gss_retry_base_delay = delay;
221 self
222 }
223
224 pub fn gss_circuit_breaker_threshold(mut self, threshold: usize) -> Self {
226 self.gss_circuit_breaker_threshold = threshold;
227 self
228 }
229
230 pub fn gss_circuit_breaker_window(mut self, window: Duration) -> Self {
232 self.gss_circuit_breaker_window = window;
233 self
234 }
235
236 pub fn gss_circuit_breaker_cooldown(mut self, cooldown: Duration) -> Self {
238 self.gss_circuit_breaker_cooldown = cooldown;
239 self
240 }
241
242 pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
244 self.auth_settings = settings;
245 self
246 }
247
248 pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
253 let pg = &qail.postgres;
254 let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
255
256 let mut config = PoolConfig::new(&host, port, &user, &database)
257 .max_connections(pg.max_connections)
258 .min_connections(pg.min_connections)
259 .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
260 .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
261 .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
262 .test_on_acquire(pg.test_on_acquire);
263
264 if let Some(ref pw) = password {
265 config = config.password(pw);
266 }
267
268 if let Some(query) = pg.url.split('?').nth(1) {
270 apply_url_query_params(&mut config, query, &host)?;
271 }
272
273 Ok(config)
274 }
275}
276
277#[allow(unused_variables)]
282pub(crate) fn apply_url_query_params(
283 config: &mut PoolConfig,
284 query: &str,
285 host: &str,
286) -> PgResult<()> {
287 let mut sslcert: Option<String> = None;
288 let mut sslkey: Option<String> = None;
289 let mut gss_provider: Option<String> = None;
290 let mut gss_service = "postgres".to_string();
291 let mut gss_target: Option<String> = None;
292
293 for pair in query.split('&').filter(|p| !p.is_empty()) {
294 let mut kv = pair.splitn(2, '=');
295 let key = kv.next().unwrap_or_default().trim();
296 let value = kv.next().unwrap_or_default().trim();
297
298 match key {
299 "sslmode" => {
300 let mode = TlsMode::parse_sslmode(value).ok_or_else(|| {
301 PgError::Connection(format!("Invalid sslmode value: {}", value))
302 })?;
303 config.tls_mode = mode;
304 }
305 "gssencmode" => {
306 let mode = GssEncMode::parse_gssencmode(value).ok_or_else(|| {
307 PgError::Connection(format!("Invalid gssencmode value: {}", value))
308 })?;
309 config.gss_enc_mode = mode;
310 }
311 "sslrootcert" => {
312 let ca_pem = std::fs::read(value).map_err(|e| {
313 PgError::Connection(format!("Failed to read sslrootcert '{}': {}", value, e))
314 })?;
315 config.tls_ca_cert_pem = Some(ca_pem);
316 }
317 "sslcert" => sslcert = Some(value.to_string()),
318 "sslkey" => sslkey = Some(value.to_string()),
319 "channel_binding" => {
320 let mode = ScramChannelBindingMode::parse(value).ok_or_else(|| {
321 PgError::Connection(format!("Invalid channel_binding value: {}", value))
322 })?;
323 config.auth_settings.channel_binding = mode;
324 }
325 "auth_scram" => {
326 let enabled = parse_bool_param(value).ok_or_else(|| {
327 PgError::Connection(format!("Invalid auth_scram value: {}", value))
328 })?;
329 config.auth_settings.allow_scram_sha_256 = enabled;
330 }
331 "auth_md5" => {
332 let enabled = parse_bool_param(value).ok_or_else(|| {
333 PgError::Connection(format!("Invalid auth_md5 value: {}", value))
334 })?;
335 config.auth_settings.allow_md5_password = enabled;
336 }
337 "auth_cleartext" => {
338 let enabled = parse_bool_param(value).ok_or_else(|| {
339 PgError::Connection(format!("Invalid auth_cleartext value: {}", value))
340 })?;
341 config.auth_settings.allow_cleartext_password = enabled;
342 }
343 "auth_kerberos" => {
344 let enabled = parse_bool_param(value).ok_or_else(|| {
345 PgError::Connection(format!("Invalid auth_kerberos value: {}", value))
346 })?;
347 config.auth_settings.allow_kerberos_v5 = enabled;
348 }
349 "auth_gssapi" => {
350 let enabled = parse_bool_param(value).ok_or_else(|| {
351 PgError::Connection(format!("Invalid auth_gssapi value: {}", value))
352 })?;
353 config.auth_settings.allow_gssapi = enabled;
354 }
355 "auth_sspi" => {
356 let enabled = parse_bool_param(value).ok_or_else(|| {
357 PgError::Connection(format!("Invalid auth_sspi value: {}", value))
358 })?;
359 config.auth_settings.allow_sspi = enabled;
360 }
361 "auth_mode" => {
362 if value.eq_ignore_ascii_case("scram_only") {
363 config.auth_settings = AuthSettings::scram_only();
364 } else if value.eq_ignore_ascii_case("gssapi_only") {
365 config.auth_settings = AuthSettings::gssapi_only();
366 } else if value.eq_ignore_ascii_case("compat")
367 || value.eq_ignore_ascii_case("default")
368 {
369 config.auth_settings = AuthSettings::default();
370 } else {
371 return Err(PgError::Connection(format!(
372 "Invalid auth_mode value: {}",
373 value
374 )));
375 }
376 }
377 "gss_provider" => gss_provider = Some(value.to_string()),
378 "gss_service" => {
379 if value.is_empty() {
380 return Err(PgError::Connection(
381 "gss_service must not be empty".to_string(),
382 ));
383 }
384 gss_service = value.to_string();
385 }
386 "gss_target" => {
387 if value.is_empty() {
388 return Err(PgError::Connection(
389 "gss_target must not be empty".to_string(),
390 ));
391 }
392 gss_target = Some(value.to_string());
393 }
394 "gss_connect_retries" => {
395 let retries = value.parse::<usize>().map_err(|_| {
396 PgError::Connection(format!("Invalid gss_connect_retries value: {}", value))
397 })?;
398 if retries > 20 {
399 return Err(PgError::Connection(
400 "gss_connect_retries must be <= 20".to_string(),
401 ));
402 }
403 config.gss_connect_retries = retries;
404 }
405 "gss_retry_base_ms" => {
406 let delay_ms = value.parse::<u64>().map_err(|_| {
407 PgError::Connection(format!("Invalid gss_retry_base_ms value: {}", value))
408 })?;
409 if delay_ms == 0 {
410 return Err(PgError::Connection(
411 "gss_retry_base_ms must be greater than 0".to_string(),
412 ));
413 }
414 config.gss_retry_base_delay = Duration::from_millis(delay_ms);
415 }
416 "gss_circuit_threshold" => {
417 let threshold = value.parse::<usize>().map_err(|_| {
418 PgError::Connection(format!("Invalid gss_circuit_threshold value: {}", value))
419 })?;
420 if threshold > 100 {
421 return Err(PgError::Connection(
422 "gss_circuit_threshold must be <= 100".to_string(),
423 ));
424 }
425 config.gss_circuit_breaker_threshold = threshold;
426 }
427 "gss_circuit_window_ms" => {
428 let window_ms = value.parse::<u64>().map_err(|_| {
429 PgError::Connection(format!("Invalid gss_circuit_window_ms value: {}", value))
430 })?;
431 if window_ms == 0 {
432 return Err(PgError::Connection(
433 "gss_circuit_window_ms must be greater than 0".to_string(),
434 ));
435 }
436 config.gss_circuit_breaker_window = Duration::from_millis(window_ms);
437 }
438 "gss_circuit_cooldown_ms" => {
439 let cooldown_ms = value.parse::<u64>().map_err(|_| {
440 PgError::Connection(format!("Invalid gss_circuit_cooldown_ms value: {}", value))
441 })?;
442 if cooldown_ms == 0 {
443 return Err(PgError::Connection(
444 "gss_circuit_cooldown_ms must be greater than 0".to_string(),
445 ));
446 }
447 config.gss_circuit_breaker_cooldown = Duration::from_millis(cooldown_ms);
448 }
449 _ => {}
450 }
451 }
452
453 match (sslcert.as_deref(), sslkey.as_deref()) {
454 (Some(cert_path), Some(key_path)) => {
455 let mtls = TlsConfig {
456 client_cert_pem: std::fs::read(cert_path).map_err(|e| {
457 PgError::Connection(format!("Failed to read sslcert '{}': {}", cert_path, e))
458 })?,
459 client_key_pem: std::fs::read(key_path).map_err(|e| {
460 PgError::Connection(format!("Failed to read sslkey '{}': {}", key_path, e))
461 })?,
462 ca_cert_pem: config.tls_ca_cert_pem.clone(),
463 };
464 config.mtls = Some(mtls);
465 config.tls_mode = TlsMode::Require;
466 }
467 (Some(_), None) | (None, Some(_)) => {
468 return Err(PgError::Connection(
469 "Both sslcert and sslkey must be provided together".to_string(),
470 ));
471 }
472 (None, None) => {}
473 }
474
475 if let Some(provider) = gss_provider {
476 if provider.eq_ignore_ascii_case("linux_krb5") || provider.eq_ignore_ascii_case("builtin") {
477 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
478 {
479 let provider =
480 super::gss::linux_krb5_token_provider(super::gss::LinuxKrb5ProviderConfig {
481 host: host.to_string(),
482 service: gss_service.clone(),
483 target_name: gss_target.clone(),
484 })
485 .map_err(PgError::Auth)?;
486 config.gss_token_provider_ex = Some(provider);
487 }
488 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
489 {
490 let _ = gss_service;
491 let _ = gss_target;
492 return Err(PgError::Connection(
493 "gss_provider=linux_krb5 requires qail-pg feature enterprise-gssapi on Linux"
494 .to_string(),
495 ));
496 }
497 } else if provider.eq_ignore_ascii_case("callback")
498 || provider.eq_ignore_ascii_case("custom")
499 {
500 } else {
502 return Err(PgError::Connection(format!(
503 "Invalid gss_provider value: {}",
504 provider
505 )));
506 }
507 }
508
509 Ok(())
510}
511
512fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
514 let url = url.split('?').next().unwrap_or(url);
515 let url = url
516 .trim_start_matches("postgres://")
517 .trim_start_matches("postgresql://");
518
519 let (credentials, host_part) = if url.contains('@') {
520 let mut parts = url.splitn(2, '@');
521 let creds = parts.next().unwrap_or("");
522 let host = parts.next().unwrap_or("localhost/postgres");
523 (Some(creds), host)
524 } else {
525 (None, url)
526 };
527
528 let (host_port, database) = if host_part.contains('/') {
529 let mut parts = host_part.splitn(2, '/');
530 (
531 parts.next().unwrap_or("localhost"),
532 parts.next().unwrap_or("postgres").to_string(),
533 )
534 } else {
535 (host_part, "postgres".to_string())
536 };
537
538 let (host, port) = if host_port.contains(':') {
539 let mut parts = host_port.split(':');
540 let h = parts.next().unwrap_or("localhost").to_string();
541 let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
542 (h, p)
543 } else {
544 (host_port.to_string(), 5432u16)
545 };
546
547 let (user, password) = if let Some(creds) = credentials {
548 if creds.contains(':') {
549 let mut parts = creds.splitn(2, ':');
550 let u = parts.next().unwrap_or("postgres").to_string();
551 let p = parts.next().map(|s| s.to_string());
552 (u, p)
553 } else {
554 (creds.to_string(), None)
555 }
556 } else {
557 ("postgres".to_string(), None)
558 };
559
560 Ok((host, port, user, database, password))
561}
562
/// Interpret a boolean URL parameter, ignoring surrounding whitespace and
/// ASCII case.
///
/// Accepts `1/true/yes/on` as `true` and `0/false/no/off` as `false`; any
/// other input yields `None`.
fn parse_bool_param(value: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let normalized = value.trim().to_ascii_lowercase();
    let word = normalized.as_str();
    if TRUTHY.contains(&word) {
        Some(true)
    } else if FALSY.contains(&word) {
        Some(false)
    } else {
        None
    }
}
570
/// Point-in-time snapshot of pool counters.
///
/// Field meanings below are inferred from their names; the code that fills
/// this struct is not part of this section.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently checked out (presumably).
    pub active: usize,
    /// Connections parked idle in the pool (presumably).
    pub idle: usize,
    /// Callers waiting for a connection (presumably).
    pub pending: usize,
    /// Configured maximum pool size (presumably `max_connections`).
    pub max_size: usize,
    /// Connections created over the pool's lifetime (presumably).
    pub total_created: usize,
}
585
586struct PooledConn {
588 conn: PgConnection,
589 created_at: Instant,
590 last_used: Instant,
591}
592
593pub struct PooledConnection {
599 conn: Option<PgConnection>,
600 pool: Arc<PgPoolInner>,
601 rls_dirty: bool,
602}
603
604impl PooledConnection {
605 fn conn_ref(&self) -> PgResult<&PgConnection> {
608 self.conn
609 .as_ref()
610 .ok_or_else(|| PgError::Connection("Connection already released back to pool".into()))
611 }
612
613 fn conn_mut(&mut self) -> PgResult<&mut PgConnection> {
616 self.conn
617 .as_mut()
618 .ok_or_else(|| PgError::Connection("Connection already released back to pool".into()))
619 }
620
621 pub fn get_mut(&mut self) -> &mut PgConnection {
624 self.conn
627 .as_mut()
628 .expect("Connection should always be present")
629 }
630
631 pub fn cancel_token(&self) -> PgResult<crate::driver::CancelToken> {
633 let conn = self.conn_ref()?;
634 let (process_id, secret_key) = conn.get_cancel_key();
635 Ok(crate::driver::CancelToken {
636 host: self.pool.config.host.clone(),
637 port: self.pool.config.port,
638 process_id,
639 secret_key,
640 })
641 }
642
643 pub async fn release(mut self) {
660 if let Some(mut conn) = self.conn.take() {
661 if let Err(e) = conn.execute_simple(super::rls::reset_sql()).await {
666 eprintln!(
667 "[CRITICAL] pool_release_failed: COMMIT failed — \
668 dropping connection to prevent state leak: {}",
669 e
670 );
671 return; }
673
674 self.pool.return_connection(conn).await;
675 }
676 }
677
678 pub async fn fetch_all_uncached(
681 &mut self,
682 cmd: &qail_core::ast::Qail,
683 ) -> PgResult<Vec<super::PgRow>> {
684 self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
685 .await
686 }
687
688 pub async fn query_raw_with_params(
696 &mut self,
697 sql: &str,
698 params: &[Option<Vec<u8>>],
699 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
700 let conn = self.conn_mut()?;
701 conn.query(sql, params).await
702 }
703
704 pub async fn fetch_all_uncached_with_format(
706 &mut self,
707 cmd: &qail_core::ast::Qail,
708 result_format: ResultFormat,
709 ) -> PgResult<Vec<super::PgRow>> {
710 use super::ColumnInfo;
711 use crate::protocol::AstEncoder;
712
713 let conn = self.conn_mut()?;
714
715 AstEncoder::encode_cmd_reuse_into_with_result_format(
716 cmd,
717 &mut conn.sql_buf,
718 &mut conn.params_buf,
719 &mut conn.write_buf,
720 result_format.as_wire_code(),
721 )
722 .map_err(|e| PgError::Encode(e.to_string()))?;
723
724 conn.flush_write_buf().await?;
725
726 let mut rows: Vec<super::PgRow> = Vec::new();
727 let mut column_info: Option<Arc<ColumnInfo>> = None;
728 let mut error: Option<PgError> = None;
729
730 loop {
731 let msg = conn.recv().await?;
732 match msg {
733 crate::protocol::BackendMessage::ParseComplete
734 | crate::protocol::BackendMessage::BindComplete => {}
735 crate::protocol::BackendMessage::RowDescription(fields) => {
736 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
737 }
738 crate::protocol::BackendMessage::DataRow(data) => {
739 if error.is_none() {
740 rows.push(super::PgRow {
741 columns: data,
742 column_info: column_info.clone(),
743 });
744 }
745 }
746 crate::protocol::BackendMessage::CommandComplete(_) => {}
747 crate::protocol::BackendMessage::ReadyForQuery(_) => {
748 if let Some(err) = error {
749 return Err(err);
750 }
751 return Ok(rows);
752 }
753 crate::protocol::BackendMessage::ErrorResponse(err) => {
754 if error.is_none() {
755 error = Some(PgError::QueryServer(err.into()));
756 }
757 }
758 _ => {}
759 }
760 }
761 }
762
763 pub async fn fetch_all_fast(
767 &mut self,
768 cmd: &qail_core::ast::Qail,
769 ) -> PgResult<Vec<super::PgRow>> {
770 self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
771 .await
772 }
773
774 pub async fn fetch_all_fast_with_format(
776 &mut self,
777 cmd: &qail_core::ast::Qail,
778 result_format: ResultFormat,
779 ) -> PgResult<Vec<super::PgRow>> {
780 use crate::protocol::AstEncoder;
781
782 let conn = self.conn_mut()?;
783
784 AstEncoder::encode_cmd_reuse_into_with_result_format(
785 cmd,
786 &mut conn.sql_buf,
787 &mut conn.params_buf,
788 &mut conn.write_buf,
789 result_format.as_wire_code(),
790 )
791 .map_err(|e| PgError::Encode(e.to_string()))?;
792
793 conn.flush_write_buf().await?;
794
795 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
796 let mut error: Option<PgError> = None;
797
798 loop {
799 let res = conn.recv_with_data_fast().await;
800 match res {
801 Ok((msg_type, data)) => match msg_type {
802 b'D' => {
803 if error.is_none()
804 && let Some(columns) = data
805 {
806 rows.push(super::PgRow {
807 columns,
808 column_info: None,
809 });
810 }
811 }
812 b'Z' => {
813 if let Some(err) = error {
814 return Err(err);
815 }
816 return Ok(rows);
817 }
818 _ => {}
819 },
820 Err(e) => {
821 if error.is_none() {
822 error = Some(e);
823 }
824 }
825 }
826 }
827 }
828
829 pub async fn fetch_all_cached(
834 &mut self,
835 cmd: &qail_core::ast::Qail,
836 ) -> PgResult<Vec<super::PgRow>> {
837 self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
838 .await
839 }
840
841 pub async fn fetch_all_cached_with_format(
843 &mut self,
844 cmd: &qail_core::ast::Qail,
845 result_format: ResultFormat,
846 ) -> PgResult<Vec<super::PgRow>> {
847 let mut retried = false;
848 loop {
849 match self
850 .fetch_all_cached_with_format_once(cmd, result_format)
851 .await
852 {
853 Ok(rows) => return Ok(rows),
854 Err(err) if !retried && err.is_prepared_statement_retryable() => {
855 retried = true;
856 if let Some(conn) = self.conn.as_mut() {
857 conn.clear_prepared_statement_state();
858 }
859 }
860 Err(err) => return Err(err),
861 }
862 }
863 }
864
865 async fn fetch_all_cached_with_format_once(
866 &mut self,
867 cmd: &qail_core::ast::Qail,
868 result_format: ResultFormat,
869 ) -> PgResult<Vec<super::PgRow>> {
870 use super::ColumnInfo;
871 use std::collections::hash_map::DefaultHasher;
872 use std::hash::{Hash, Hasher};
873
874 let conn = self.conn.as_mut().ok_or_else(|| {
875 PgError::Connection("Connection already released back to pool".into())
876 })?;
877
878 conn.sql_buf.clear();
879 conn.params_buf.clear();
880
881 match cmd.action {
883 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
884 crate::protocol::ast_encoder::dml::encode_select(
885 cmd,
886 &mut conn.sql_buf,
887 &mut conn.params_buf,
888 )?;
889 }
890 qail_core::ast::Action::Add => {
891 crate::protocol::ast_encoder::dml::encode_insert(
892 cmd,
893 &mut conn.sql_buf,
894 &mut conn.params_buf,
895 )?;
896 }
897 qail_core::ast::Action::Set => {
898 crate::protocol::ast_encoder::dml::encode_update(
899 cmd,
900 &mut conn.sql_buf,
901 &mut conn.params_buf,
902 )?;
903 }
904 qail_core::ast::Action::Del => {
905 crate::protocol::ast_encoder::dml::encode_delete(
906 cmd,
907 &mut conn.sql_buf,
908 &mut conn.params_buf,
909 )?;
910 }
911 _ => {
912 return self
914 .fetch_all_uncached_with_format(cmd, result_format)
915 .await;
916 }
917 }
918
919 let mut hasher = DefaultHasher::new();
920 conn.sql_buf.hash(&mut hasher);
921 let sql_hash = hasher.finish();
922
923 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
924
925 conn.write_buf.clear();
926
927 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
928 name
929 } else {
930 let name = format!("qail_{:x}", sql_hash);
931
932 conn.evict_prepared_if_full();
933
934 let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
935
936 use crate::protocol::PgEncoder;
937 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
938 let describe_msg = PgEncoder::encode_describe(false, &name);
939 conn.write_buf.extend_from_slice(&parse_msg);
940 conn.write_buf.extend_from_slice(&describe_msg);
941
942 conn.stmt_cache.put(sql_hash, name.clone());
943 conn.prepared_statements
944 .insert(name.clone(), sql_str.to_string());
945
946 if let Ok(mut hot) = self.pool.hot_statements.write()
948 && hot.len() < MAX_HOT_STATEMENTS
949 {
950 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
951 }
952
953 name
954 };
955
956 use crate::protocol::PgEncoder;
957 PgEncoder::encode_bind_to_with_result_format(
958 &mut conn.write_buf,
959 &stmt_name,
960 &conn.params_buf,
961 result_format.as_wire_code(),
962 )
963 .map_err(|e| PgError::Encode(e.to_string()))?;
964 PgEncoder::encode_execute_to(&mut conn.write_buf);
965 PgEncoder::encode_sync_to(&mut conn.write_buf);
966
967 conn.flush_write_buf().await?;
968
969 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
970
971 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
972 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
973 let mut error: Option<PgError> = None;
974
975 loop {
976 let msg = conn.recv().await?;
977 match msg {
978 crate::protocol::BackendMessage::ParseComplete
979 | crate::protocol::BackendMessage::BindComplete => {}
980 crate::protocol::BackendMessage::ParameterDescription(_) => {}
981 crate::protocol::BackendMessage::RowDescription(fields) => {
982 let info = Arc::new(ColumnInfo::from_fields(&fields));
983 if is_cache_miss {
984 conn.column_info_cache.insert(sql_hash, info.clone());
985 }
986 column_info = Some(info);
987 }
988 crate::protocol::BackendMessage::DataRow(data) => {
989 if error.is_none() {
990 rows.push(super::PgRow {
991 columns: data,
992 column_info: column_info.clone(),
993 });
994 }
995 }
996 crate::protocol::BackendMessage::CommandComplete(_) => {}
997 crate::protocol::BackendMessage::ReadyForQuery(_) => {
998 if let Some(err) = error {
999 return Err(err);
1000 }
1001 return Ok(rows);
1002 }
1003 crate::protocol::BackendMessage::ErrorResponse(err) => {
1004 if error.is_none() {
1005 error = Some(PgError::QueryServer(err.into()));
1006 }
1007 }
1008 _ => {}
1009 }
1010 }
1011 }
1012
1013 pub async fn fetch_all_with_rls(
1032 &mut self,
1033 cmd: &qail_core::ast::Qail,
1034 rls_sql: &str,
1035 ) -> PgResult<Vec<super::PgRow>> {
1036 self.fetch_all_with_rls_with_format(cmd, rls_sql, ResultFormat::Text)
1037 .await
1038 }
1039
1040 pub async fn fetch_all_with_rls_with_format(
1042 &mut self,
1043 cmd: &qail_core::ast::Qail,
1044 rls_sql: &str,
1045 result_format: ResultFormat,
1046 ) -> PgResult<Vec<super::PgRow>> {
1047 let mut retried = false;
1048 loop {
1049 match self
1050 .fetch_all_with_rls_with_format_once(cmd, rls_sql, result_format)
1051 .await
1052 {
1053 Ok(rows) => return Ok(rows),
1054 Err(err) if !retried && err.is_prepared_statement_retryable() => {
1055 retried = true;
1056 if let Some(conn) = self.conn.as_mut() {
1057 conn.clear_prepared_statement_state();
1058 let _ = conn.execute_simple("ROLLBACK").await;
1059 }
1060 self.rls_dirty = false;
1061 }
1062 Err(err) => return Err(err),
1063 }
1064 }
1065 }
1066
1067 async fn fetch_all_with_rls_with_format_once(
1068 &mut self,
1069 cmd: &qail_core::ast::Qail,
1070 rls_sql: &str,
1071 result_format: ResultFormat,
1072 ) -> PgResult<Vec<super::PgRow>> {
1073 use super::ColumnInfo;
1074 use std::collections::hash_map::DefaultHasher;
1075 use std::hash::{Hash, Hasher};
1076
1077 let conn = self.conn.as_mut().ok_or_else(|| {
1078 PgError::Connection("Connection already released back to pool".into())
1079 })?;
1080
1081 conn.sql_buf.clear();
1082 conn.params_buf.clear();
1083
1084 if cmd.is_raw_sql() {
1086 conn.sql_buf.clear();
1088 conn.params_buf.clear();
1089 conn.sql_buf.extend_from_slice(cmd.table.as_bytes());
1090 } else {
1091 match cmd.action {
1092 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
1093 crate::protocol::ast_encoder::dml::encode_select(
1094 cmd,
1095 &mut conn.sql_buf,
1096 &mut conn.params_buf,
1097 )?;
1098 }
1099 qail_core::ast::Action::Add => {
1100 crate::protocol::ast_encoder::dml::encode_insert(
1101 cmd,
1102 &mut conn.sql_buf,
1103 &mut conn.params_buf,
1104 )?;
1105 }
1106 qail_core::ast::Action::Set => {
1107 crate::protocol::ast_encoder::dml::encode_update(
1108 cmd,
1109 &mut conn.sql_buf,
1110 &mut conn.params_buf,
1111 )?;
1112 }
1113 qail_core::ast::Action::Del => {
1114 crate::protocol::ast_encoder::dml::encode_delete(
1115 cmd,
1116 &mut conn.sql_buf,
1117 &mut conn.params_buf,
1118 )?;
1119 }
1120 _ => {
1121 conn.execute_simple(rls_sql).await?;
1123 self.rls_dirty = true;
1124 return self
1125 .fetch_all_uncached_with_format(cmd, result_format)
1126 .await;
1127 }
1128 }
1129 }
1130
1131 let mut hasher = DefaultHasher::new();
1132 conn.sql_buf.hash(&mut hasher);
1133 let sql_hash = hasher.finish();
1134
1135 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
1136
1137 conn.write_buf.clear();
1138
1139 let rls_msg = crate::protocol::PgEncoder::encode_query_string(rls_sql);
1143 conn.write_buf.extend_from_slice(&rls_msg);
1144
1145 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
1147 name
1148 } else {
1149 let name = format!("qail_{:x}", sql_hash);
1150
1151 conn.evict_prepared_if_full();
1152
1153 let sql_str = std::str::from_utf8(&conn.sql_buf).unwrap_or("");
1154
1155 use crate::protocol::PgEncoder;
1156 let parse_msg = PgEncoder::encode_parse(&name, sql_str, &[]);
1157 let describe_msg = PgEncoder::encode_describe(false, &name);
1158 conn.write_buf.extend_from_slice(&parse_msg);
1159 conn.write_buf.extend_from_slice(&describe_msg);
1160
1161 conn.stmt_cache.put(sql_hash, name.clone());
1162 conn.prepared_statements
1163 .insert(name.clone(), sql_str.to_string());
1164
1165 if let Ok(mut hot) = self.pool.hot_statements.write()
1166 && hot.len() < MAX_HOT_STATEMENTS
1167 {
1168 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
1169 }
1170
1171 name
1172 };
1173
1174 use crate::protocol::PgEncoder;
1175 PgEncoder::encode_bind_to_with_result_format(
1176 &mut conn.write_buf,
1177 &stmt_name,
1178 &conn.params_buf,
1179 result_format.as_wire_code(),
1180 )
1181 .map_err(|e| PgError::Encode(e.to_string()))?;
1182 PgEncoder::encode_execute_to(&mut conn.write_buf);
1183 PgEncoder::encode_sync_to(&mut conn.write_buf);
1184
1185 conn.flush_write_buf().await?;
1187
1188 self.rls_dirty = true;
1190
1191 let mut rls_error: Option<PgError> = None;
1195 loop {
1196 let msg = conn.recv().await?;
1197 match msg {
1198 crate::protocol::BackendMessage::ReadyForQuery(_) => {
1199 if let Some(err) = rls_error {
1201 return Err(err);
1202 }
1203 break;
1204 }
1205 crate::protocol::BackendMessage::ErrorResponse(err) => {
1206 if rls_error.is_none() {
1207 rls_error = Some(PgError::QueryServer(err.into()));
1208 }
1209 }
1210 _ => {}
1212 }
1213 }
1214
1215 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
1217
1218 let mut rows: Vec<super::PgRow> = Vec::with_capacity(32);
1219 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
1220 let mut error: Option<PgError> = None;
1221
1222 loop {
1223 let msg = conn.recv().await?;
1224 match msg {
1225 crate::protocol::BackendMessage::ParseComplete
1226 | crate::protocol::BackendMessage::BindComplete => {}
1227 crate::protocol::BackendMessage::ParameterDescription(_) => {}
1228 crate::protocol::BackendMessage::RowDescription(fields) => {
1229 let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
1230 if is_cache_miss {
1231 conn.column_info_cache.insert(sql_hash, info.clone());
1232 }
1233 column_info = Some(info);
1234 }
1235 crate::protocol::BackendMessage::DataRow(data) => {
1236 if error.is_none() {
1237 rows.push(super::PgRow {
1238 columns: data,
1239 column_info: column_info.clone(),
1240 });
1241 }
1242 }
1243 crate::protocol::BackendMessage::CommandComplete(_) => {}
1244 crate::protocol::BackendMessage::ReadyForQuery(_) => {
1245 if let Some(err) = error {
1246 return Err(err);
1247 }
1248 return Ok(rows);
1249 }
1250 crate::protocol::BackendMessage::ErrorResponse(err) => {
1251 if error.is_none() {
1252 error = Some(PgError::QueryServer(err.into()));
1253 }
1254 }
1255 _ => {}
1256 }
1257 }
1258 }
1259
1260 pub async fn pipeline_ast(
1268 &mut self,
1269 cmds: &[qail_core::ast::Qail],
1270 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
1271 let conn = self.conn_mut()?;
1272 conn.pipeline_ast(cmds).await
1273 }
1274
1275 pub async fn explain_estimate(
1281 &mut self,
1282 cmd: &qail_core::ast::Qail,
1283 ) -> PgResult<Option<super::explain::ExplainEstimate>> {
1284 use qail_core::transpiler::ToSql;
1285
1286 let sql = cmd.to_sql();
1287 let explain_sql = format!("EXPLAIN (FORMAT JSON) {}", sql);
1288
1289 let rows = self.simple_query(&explain_sql).await?;
1290
1291 let mut json_output = String::new();
1293 for row in &rows {
1294 if let Some(Some(val)) = row.columns.first()
1295 && let Ok(text) = std::str::from_utf8(val)
1296 {
1297 json_output.push_str(text);
1298 }
1299 }
1300
1301 Ok(super::explain::parse_explain_json(&json_output))
1302 }
1303}
1304
1305impl Drop for PooledConnection {
1306 fn drop(&mut self) {
1307 if self.conn.is_some() {
1308 eprintln!(
1322 "[WARN] pool_connection_leaked: PooledConnection dropped without release() — \
1323 connection destroyed to prevent state leak (rls_dirty={}). \
1324 Use conn.release().await for deterministic cleanup.",
1325 self.rls_dirty
1326 );
1327 self.pool
1329 .active_count
1330 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
1331 self.pool.semaphore.add_permits(1);
1334 }
1335 }
1336}
1337
1338impl std::ops::Deref for PooledConnection {
1339 type Target = PgConnection;
1340
1341 fn deref(&self) -> &Self::Target {
1342 self.conn
1345 .as_ref()
1346 .expect("PooledConnection::deref called after release — this is a bug")
1347 }
1348}
1349
1350impl std::ops::DerefMut for PooledConnection {
1351 fn deref_mut(&mut self) -> &mut Self::Target {
1352 self.conn
1355 .as_mut()
1356 .expect("PooledConnection::deref_mut called after release — this is a bug")
1357 }
1358}
1359
// NOTE(review): not referenced in this part of the file. Judging by the name it presumably
// caps the number of entries kept in `PgPoolInner::hot_statements`; confirm at its use site.
const MAX_HOT_STATEMENTS: usize = 32;
1362
1363struct PgPoolInner {
1365 config: PoolConfig,
1366 connections: Mutex<Vec<PooledConn>>,
1367 semaphore: Semaphore,
1368 closed: AtomicBool,
1369 active_count: AtomicUsize,
1370 total_created: AtomicUsize,
1371 hot_statements: std::sync::RwLock<std::collections::HashMap<u64, (String, String)>>,
1375}
1376
1377impl PgPoolInner {
1378 async fn return_connection(&self, conn: PgConnection) {
1379 self.active_count.fetch_sub(1, Ordering::Relaxed);
1380
1381 if self.closed.load(Ordering::Relaxed) {
1382 return;
1383 }
1384
1385 let mut connections = self.connections.lock().await;
1386 if connections.len() < self.config.max_connections {
1387 connections.push(PooledConn {
1388 conn,
1389 created_at: Instant::now(),
1390 last_used: Instant::now(),
1391 });
1392 }
1393
1394 self.semaphore.add_permits(1);
1395 }
1396
1397 async fn get_healthy_connection(&self) -> Option<PgConnection> {
1399 let mut connections = self.connections.lock().await;
1400
1401 while let Some(pooled) = connections.pop() {
1402 if pooled.last_used.elapsed() > self.config.idle_timeout {
1403 continue;
1405 }
1406
1407 if let Some(max_life) = self.config.max_lifetime
1408 && pooled.created_at.elapsed() > max_life
1409 {
1410 continue;
1412 }
1413
1414 return Some(pooled.conn);
1415 }
1416
1417 None
1418 }
1419}
1420
1421#[derive(Clone)]
1432pub struct PgPool {
1433 inner: Arc<PgPoolInner>,
1434}
1435
1436impl PgPool {
1437 pub async fn from_config() -> PgResult<Self> {
1444 let qail = qail_core::config::QailConfig::load()
1445 .map_err(|e| PgError::Connection(format!("Config error: {}", e)))?;
1446 let config = PoolConfig::from_qail_config(&qail)?;
1447 Self::connect(config).await
1448 }
1449
1450 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
1452 let semaphore = Semaphore::new(config.max_connections);
1454
1455 let mut initial_connections = Vec::new();
1456 for _ in 0..config.min_connections {
1457 let conn = Self::create_connection(&config).await?;
1458 initial_connections.push(PooledConn {
1459 conn,
1460 created_at: Instant::now(),
1461 last_used: Instant::now(),
1462 });
1463 }
1464
1465 let initial_count = initial_connections.len();
1466
1467 let inner = Arc::new(PgPoolInner {
1468 config,
1469 connections: Mutex::new(initial_connections),
1470 semaphore,
1471 closed: AtomicBool::new(false),
1472 active_count: AtomicUsize::new(0),
1473 total_created: AtomicUsize::new(initial_count),
1474 hot_statements: std::sync::RwLock::new(std::collections::HashMap::new()),
1475 });
1476
1477 Ok(Self { inner })
1478 }
1479
1480 pub async fn acquire_raw(&self) -> PgResult<PooledConnection> {
1495 if self.inner.closed.load(Ordering::Relaxed) {
1496 return Err(PgError::PoolClosed);
1497 }
1498
1499 let acquire_timeout = self.inner.config.acquire_timeout;
1501 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
1502 .await
1503 .map_err(|_| {
1504 PgError::Timeout(format!(
1505 "pool acquire after {}s ({} max connections)",
1506 acquire_timeout.as_secs(),
1507 self.inner.config.max_connections
1508 ))
1509 })?
1510 .map_err(|_| PgError::PoolClosed)?;
1511
1512 let mut conn = if let Some(conn) = self.inner.get_healthy_connection().await {
1514 conn
1515 } else {
1516 let conn = Self::create_connection(&self.inner.config).await?;
1517 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
1518 conn
1519 };
1520
1521 if self.inner.config.test_on_acquire
1522 && let Err(e) = conn.execute_simple("SELECT 1").await
1523 {
1524 eprintln!(
1525 "[WARN] pool_health_check_failed: checkout probe failed, creating replacement connection: {}",
1526 e
1527 );
1528 conn = Self::create_connection(&self.inner.config).await?;
1529 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
1530 }
1531
1532 let missing: Vec<(u64, String, String)> = {
1535 if let Ok(hot) = self.inner.hot_statements.read() {
1536 hot.iter()
1537 .filter(|(hash, _)| !conn.stmt_cache.contains(hash))
1538 .map(|(hash, (name, sql))| (*hash, name.clone(), sql.clone()))
1539 .collect()
1540 } else {
1541 Vec::new()
1542 }
1543 }; if !missing.is_empty() {
1546 use crate::protocol::PgEncoder;
1547 let mut buf = bytes::BytesMut::new();
1548 for (_, name, sql) in &missing {
1549 let parse_msg = PgEncoder::encode_parse(name, sql, &[]);
1550 buf.extend_from_slice(&parse_msg);
1551 }
1552 PgEncoder::encode_sync_to(&mut buf);
1553 if conn.send_bytes(&buf).await.is_ok() {
1554 loop {
1556 match conn.recv().await {
1557 Ok(crate::protocol::BackendMessage::ReadyForQuery(_)) => break,
1558 Ok(_) => continue,
1559 Err(_) => break,
1560 }
1561 }
1562 for (hash, name, sql) in &missing {
1564 conn.stmt_cache.put(*hash, name.clone());
1565 conn.prepared_statements.insert(name.clone(), sql.clone());
1566 }
1567 }
1568 }
1569
1570 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
1571 permit.forget();
1573
1574 Ok(PooledConnection {
1575 conn: Some(conn),
1576 pool: self.inner.clone(),
1577 rls_dirty: false,
1578 })
1579 }
1580
1581 pub async fn acquire_with_rls(
1597 &self,
1598 ctx: qail_core::rls::RlsContext,
1599 ) -> PgResult<PooledConnection> {
1600 let mut conn = self.acquire_raw().await?;
1602
1603 let sql = super::rls::context_to_sql(&ctx);
1605 let pg_conn = conn.get_mut();
1606 pg_conn.execute_simple(&sql).await?;
1607
1608 conn.rls_dirty = true;
1610
1611 Ok(conn)
1612 }
1613
1614 pub async fn acquire_with_rls_timeout(
1619 &self,
1620 ctx: qail_core::rls::RlsContext,
1621 timeout_ms: u32,
1622 ) -> PgResult<PooledConnection> {
1623 let mut conn = self.acquire_raw().await?;
1625
1626 let sql = super::rls::context_to_sql_with_timeout(&ctx, timeout_ms);
1628 let pg_conn = conn.get_mut();
1629 pg_conn.execute_simple(&sql).await?;
1630
1631 conn.rls_dirty = true;
1633
1634 Ok(conn)
1635 }
1636
1637 pub async fn acquire_with_rls_timeouts(
1643 &self,
1644 ctx: qail_core::rls::RlsContext,
1645 statement_timeout_ms: u32,
1646 lock_timeout_ms: u32,
1647 ) -> PgResult<PooledConnection> {
1648 let mut conn = self.acquire_raw().await?;
1650
1651 let sql =
1652 super::rls::context_to_sql_with_timeouts(&ctx, statement_timeout_ms, lock_timeout_ms);
1653 let pg_conn = conn.get_mut();
1654 pg_conn.execute_simple(&sql).await?;
1655
1656 conn.rls_dirty = true;
1657
1658 Ok(conn)
1659 }
1660
1661 pub async fn acquire_system(&self) -> PgResult<PooledConnection> {
1671 let ctx = qail_core::rls::RlsContext::empty();
1672 self.acquire_with_rls(ctx).await
1673 }
1674
1675 pub async fn acquire_for_tenant(&self, tenant_id: &str) -> PgResult<PooledConnection> {
1687 self.acquire_with_rls(qail_core::rls::RlsContext::tenant(tenant_id))
1688 .await
1689 }
1690
1691 pub async fn acquire_with_branch(
1705 &self,
1706 ctx: &qail_core::branch::BranchContext,
1707 ) -> PgResult<PooledConnection> {
1708 let mut conn = self.acquire_raw().await?;
1710
1711 if let Some(branch_name) = ctx.branch_name() {
1712 let sql = super::branch_sql::branch_context_sql(branch_name);
1713 let pg_conn = conn.get_mut();
1714 pg_conn.execute_simple(&sql).await?;
1715 conn.rls_dirty = true; }
1717
1718 Ok(conn)
1719 }
1720
1721 pub async fn idle_count(&self) -> usize {
1723 self.inner.connections.lock().await.len()
1724 }
1725
1726 pub fn active_count(&self) -> usize {
1728 self.inner.active_count.load(Ordering::Relaxed)
1729 }
1730
1731 pub fn max_connections(&self) -> usize {
1733 self.inner.config.max_connections
1734 }
1735
1736 pub async fn stats(&self) -> PoolStats {
1738 let idle = self.inner.connections.lock().await.len();
1739 PoolStats {
1740 active: self.inner.active_count.load(Ordering::Relaxed),
1741 idle,
1742 pending: self.inner.config.max_connections
1743 - self.inner.semaphore.available_permits()
1744 - self.active_count(),
1745 max_size: self.inner.config.max_connections,
1746 total_created: self.inner.total_created.load(Ordering::Relaxed),
1747 }
1748 }
1749
1750 pub fn is_closed(&self) -> bool {
1752 self.inner.closed.load(Ordering::Relaxed)
1753 }
1754
1755 pub async fn close(&self) {
1757 self.inner.closed.store(true, Ordering::Relaxed);
1758
1759 let mut connections = self.inner.connections.lock().await;
1760 connections.clear();
1761 }
1762
1763 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
1765 if !config.auth_settings.has_any_password_method()
1766 && config.mtls.is_none()
1767 && config.password.is_some()
1768 {
1769 return Err(PgError::Auth(
1770 "Invalid PoolConfig: all password auth methods are disabled".to_string(),
1771 ));
1772 }
1773
1774 let options = ConnectOptions {
1775 tls_mode: config.tls_mode,
1776 gss_enc_mode: config.gss_enc_mode,
1777 tls_ca_cert_pem: config.tls_ca_cert_pem.clone(),
1778 mtls: config.mtls.clone(),
1779 gss_token_provider: config.gss_token_provider,
1780 gss_token_provider_ex: config.gss_token_provider_ex.clone(),
1781 auth: config.auth_settings,
1782 };
1783
1784 if let Some(remaining) = gss_circuit_remaining_open(config) {
1785 metrics::counter!("qail_pg_gss_circuit_open_total").increment(1);
1786 tracing::warn!(
1787 host = %config.host,
1788 port = config.port,
1789 user = %config.user,
1790 db = %config.database,
1791 remaining_ms = remaining.as_millis() as u64,
1792 "gss_connect_circuit_open"
1793 );
1794 return Err(PgError::Connection(format!(
1795 "GSS connection circuit is open; retry after {:?}",
1796 remaining
1797 )));
1798 }
1799
1800 let mut attempt = 0usize;
1801 loop {
1802 match PgConnection::connect_with_options(
1803 &config.host,
1804 config.port,
1805 &config.user,
1806 &config.database,
1807 config.password.as_deref(),
1808 options.clone(),
1809 )
1810 .await
1811 {
1812 Ok(conn) => {
1813 gss_circuit_record_success(config);
1814 return Ok(conn);
1815 }
1816 Err(err) if should_retry_gss_connect_error(config, attempt, &err) => {
1817 metrics::counter!("qail_pg_gss_connect_retries_total").increment(1);
1818 gss_circuit_record_failure(config);
1819 let delay = gss_retry_delay(config.gss_retry_base_delay, attempt);
1820 tracing::warn!(
1821 host = %config.host,
1822 port = config.port,
1823 user = %config.user,
1824 db = %config.database,
1825 attempt = attempt + 1,
1826 delay_ms = delay.as_millis() as u64,
1827 error = %err,
1828 "gss_connect_retry"
1829 );
1830 tokio::time::sleep(delay).await;
1831 attempt += 1;
1832 }
1833 Err(err) => {
1834 if should_track_gss_circuit_error(config, &err) {
1835 metrics::counter!("qail_pg_gss_connect_failures_total").increment(1);
1836 gss_circuit_record_failure(config);
1837 }
1838 return Err(err);
1839 }
1840 }
1841 }
1842 }
1843}
1844
1845fn should_retry_gss_connect_error(config: &PoolConfig, attempt: usize, err: &PgError) -> bool {
1846 if attempt >= config.gss_connect_retries {
1847 return false;
1848 }
1849
1850 if !is_gss_auth_enabled(config) {
1851 return false;
1852 }
1853
1854 match err {
1855 PgError::Auth(msg) | PgError::Connection(msg) => is_transient_gss_message(msg),
1856 PgError::Timeout(_) => true,
1857 PgError::Io(io) => matches!(
1858 io.kind(),
1859 std::io::ErrorKind::TimedOut
1860 | std::io::ErrorKind::ConnectionRefused
1861 | std::io::ErrorKind::ConnectionReset
1862 | std::io::ErrorKind::BrokenPipe
1863 | std::io::ErrorKind::Interrupted
1864 | std::io::ErrorKind::WouldBlock
1865 ),
1866 _ => false,
1867 }
1868}
1869
1870fn is_gss_auth_enabled(config: &PoolConfig) -> bool {
1871 config.gss_token_provider.is_some()
1872 || config.gss_token_provider_ex.is_some()
1873 || config.auth_settings.allow_kerberos_v5
1874 || config.auth_settings.allow_gssapi
1875 || config.auth_settings.allow_sspi
1876}
1877
/// Heuristically classifies an error message as a transient failure.
///
/// Returns `true` if the ASCII-lowercased message contains any of the known markers
/// (timeouts, refused/reset connections, "try again", "unavailable" phrases, ...).
fn is_transient_gss_message(msg: &str) -> bool {
    const TRANSIENT_MARKERS: [&str; 10] = [
        "temporary",
        "temporarily unavailable",
        "try again",
        "timed out",
        "timeout",
        "connection reset",
        "connection refused",
        "network is unreachable",
        "resource temporarily unavailable",
        "service unavailable",
    ];

    let lowered = msg.to_ascii_lowercase();
    for marker in TRANSIENT_MARKERS {
        if lowered.contains(marker) {
            return true;
        }
    }
    false
}
1895
1896fn gss_retry_delay(base: Duration, attempt: usize) -> Duration {
1897 let factor = 1u32 << attempt.min(6);
1898 let delay = base.saturating_mul(factor).min(Duration::from_secs(5));
1899 let jitter_cap_ms = ((delay.as_millis() as u64) / 5).clamp(1, 250);
1900 let jitter_ms = pseudo_random_jitter_ms(jitter_cap_ms);
1901 delay.saturating_add(Duration::from_millis(jitter_ms))
1902}
1903
/// Returns a cheap, non-cryptographic pseudo-random value in `0..=max_inclusive`.
///
/// The value is derived from the sub-second nanoseconds of the system clock, which is
/// enough to spread out retry jitter. Returns 0 when `max_inclusive` is 0.
fn pseudo_random_jitter_ms(max_inclusive: u64) -> u64 {
    match max_inclusive {
        0 => 0,
        cap => {
            let since_epoch = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default();
            u64::from(since_epoch.subsec_nanos()) % (cap + 1)
        }
    }
}
1914
/// Per-endpoint state of the GSS connect circuit breaker.
///
/// Stored in the map returned by `gss_circuit_registry`, keyed by `gss_circuit_key`
/// (host:port:user:database).
#[derive(Debug, Clone)]
struct GssCircuitState {
    /// Start of the current failure-counting window.
    window_started_at: Instant,
    /// Failures recorded since `window_started_at`; reset when the breaker opens.
    failure_count: usize,
    /// While this is `Some` and still in the future, `create_connection` rejects new
    /// attempts for this endpoint.
    open_until: Option<Instant>,
}
1921
1922fn gss_circuit_registry() -> &'static std::sync::Mutex<HashMap<String, GssCircuitState>> {
1923 static REGISTRY: OnceLock<std::sync::Mutex<HashMap<String, GssCircuitState>>> = OnceLock::new();
1924 REGISTRY.get_or_init(|| std::sync::Mutex::new(HashMap::new()))
1925}
1926
1927fn gss_circuit_key(config: &PoolConfig) -> String {
1928 format!(
1929 "{}:{}:{}:{}",
1930 config.host, config.port, config.user, config.database
1931 )
1932}
1933
1934fn gss_circuit_remaining_open(config: &PoolConfig) -> Option<Duration> {
1935 if !is_gss_auth_enabled(config)
1936 || config.gss_circuit_breaker_threshold == 0
1937 || config.gss_circuit_breaker_window.is_zero()
1938 || config.gss_circuit_breaker_cooldown.is_zero()
1939 {
1940 return None;
1941 }
1942
1943 let now = Instant::now();
1944 let key = gss_circuit_key(config);
1945 let Ok(mut registry) = gss_circuit_registry().lock() else {
1946 return None;
1947 };
1948 let state = registry.get_mut(&key)?;
1949 let until = state.open_until?;
1950 if until > now {
1951 return Some(until.duration_since(now));
1952 }
1953 state.open_until = None;
1954 state.failure_count = 0;
1955 state.window_started_at = now;
1956 None
1957}
1958
1959fn should_track_gss_circuit_error(config: &PoolConfig, err: &PgError) -> bool {
1960 if !is_gss_auth_enabled(config) {
1961 return false;
1962 }
1963 matches!(
1964 err,
1965 PgError::Auth(_) | PgError::Connection(_) | PgError::Timeout(_) | PgError::Io(_)
1966 )
1967}
1968
1969fn gss_circuit_record_failure(config: &PoolConfig) {
1970 if !is_gss_auth_enabled(config)
1971 || config.gss_circuit_breaker_threshold == 0
1972 || config.gss_circuit_breaker_window.is_zero()
1973 || config.gss_circuit_breaker_cooldown.is_zero()
1974 {
1975 return;
1976 }
1977
1978 let now = Instant::now();
1979 let key = gss_circuit_key(config);
1980 let Ok(mut registry) = gss_circuit_registry().lock() else {
1981 return;
1982 };
1983 let state = registry
1984 .entry(key.clone())
1985 .or_insert_with(|| GssCircuitState {
1986 window_started_at: now,
1987 failure_count: 0,
1988 open_until: None,
1989 });
1990
1991 if now.duration_since(state.window_started_at) > config.gss_circuit_breaker_window {
1992 state.window_started_at = now;
1993 state.failure_count = 0;
1994 state.open_until = None;
1995 }
1996
1997 state.failure_count += 1;
1998 if state.failure_count >= config.gss_circuit_breaker_threshold {
1999 metrics::counter!("qail_pg_gss_circuit_open_total").increment(1);
2000 state.open_until = Some(now + config.gss_circuit_breaker_cooldown);
2001 state.failure_count = 0;
2002 state.window_started_at = now;
2003 tracing::warn!(
2004 host = %config.host,
2005 port = config.port,
2006 user = %config.user,
2007 db = %config.database,
2008 threshold = config.gss_circuit_breaker_threshold,
2009 cooldown_ms = config.gss_circuit_breaker_cooldown.as_millis() as u64,
2010 "gss_connect_circuit_opened"
2011 );
2012 }
2013}
2014
2015fn gss_circuit_record_success(config: &PoolConfig) {
2016 if !is_gss_auth_enabled(config) {
2017 return;
2018 }
2019 let key = gss_circuit_key(config);
2020 if let Ok(mut registry) = gss_circuit_registry().lock() {
2021 registry.remove(&key);
2022 }
2023}
2024
// Unit tests for pool configuration, URL/query-parameter parsing, the GSS retry and
// circuit-breaker helpers, and `PgError` display/source behaviour. None of them open a
// database connection.
#[cfg(test)]
mod tests {
    use super::*;

    // Builder setters override the corresponding `PoolConfig::new` fields.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    // Pins every default chosen by `PoolConfig::new`.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert!(config.password.is_none());
        assert_eq!(config.tls_mode, TlsMode::Prefer);
        assert!(config.tls_ca_cert_pem.is_none());
        assert!(config.mtls.is_none());
        assert!(config.auth_settings.allow_scram_sha_256);
        assert!(!config.auth_settings.allow_md5_password);
        assert!(!config.auth_settings.allow_cleartext_password);
        assert_eq!(config.gss_connect_retries, 2);
        assert_eq!(config.gss_retry_base_delay, Duration::from_millis(150));
        assert_eq!(config.gss_circuit_breaker_threshold, 8);
        assert_eq!(config.gss_circuit_breaker_window, Duration::from_secs(30));
        assert_eq!(config.gss_circuit_breaker_cooldown, Duration::from_secs(15));
        assert_eq!(config.gss_enc_mode, GssEncMode::Disable);
    }

    // `parse_gssencmode` is case-insensitive, trims whitespace, and rejects "allow".
    #[test]
    fn test_gss_enc_mode_parse() {
        assert_eq!(
            GssEncMode::parse_gssencmode("disable"),
            Some(GssEncMode::Disable)
        );
        assert_eq!(
            GssEncMode::parse_gssencmode("prefer"),
            Some(GssEncMode::Prefer)
        );
        assert_eq!(
            GssEncMode::parse_gssencmode("require"),
            Some(GssEncMode::Require)
        );
        assert_eq!(
            GssEncMode::parse_gssencmode("PREFER"),
            Some(GssEncMode::Prefer)
        );
        assert_eq!(
            GssEncMode::parse_gssencmode(" Require "),
            Some(GssEncMode::Require)
        );
        assert_eq!(GssEncMode::parse_gssencmode(""), None);
        assert_eq!(GssEncMode::parse_gssencmode("invalid"), None);
        assert_eq!(GssEncMode::parse_gssencmode("allow"), None);
    }

    #[test]
    fn test_gss_enc_mode_default() {
        assert_eq!(GssEncMode::default(), GssEncMode::Disable);
    }

    // The `gssencmode` URL query parameter maps onto `PoolConfig::gss_enc_mode`.
    #[test]
    fn test_url_gssencmode_disable() {
        let mut config = PoolConfig::new("localhost", 5432, "u", "db");
        apply_url_query_params(&mut config, "gssencmode=disable", "localhost").unwrap();
        assert_eq!(config.gss_enc_mode, GssEncMode::Disable);
    }

    #[test]
    fn test_url_gssencmode_prefer() {
        let mut config = PoolConfig::new("localhost", 5432, "u", "db");
        apply_url_query_params(&mut config, "gssencmode=prefer", "localhost").unwrap();
        assert_eq!(config.gss_enc_mode, GssEncMode::Prefer);
    }

    #[test]
    fn test_url_gssencmode_require() {
        let mut config = PoolConfig::new("localhost", 5432, "u", "db");
        apply_url_query_params(&mut config, "gssencmode=require", "localhost").unwrap();
        assert_eq!(config.gss_enc_mode, GssEncMode::Require);
    }

    #[test]
    fn test_url_gssencmode_invalid() {
        let mut config = PoolConfig::new("localhost", 5432, "u", "db");
        let err = apply_url_query_params(&mut config, "gssencmode=bogus", "localhost");
        assert!(err.is_err());
    }

    // `gssencmode` and `sslmode` are parsed independently from the same query string.
    #[test]
    fn test_url_gssencmode_with_sslmode() {
        let mut config = PoolConfig::new("localhost", 5432, "u", "db");
        apply_url_query_params(
            &mut config,
            "gssencmode=prefer&sslmode=require",
            "localhost",
        )
        .unwrap();
        assert_eq!(config.gss_enc_mode, GssEncMode::Prefer);
        assert_eq!(config.tls_mode, TlsMode::Require);
    }

    #[test]
    fn test_url_gssencmode_require_sslmode_require_is_valid() {
        let mut config = PoolConfig::new("localhost", 5432, "u", "db");
        // Requiring both GSS encryption and TLS together is accepted, not rejected as a
        // conflict.
        apply_url_query_params(
            &mut config,
            "gssencmode=require&sslmode=require",
            "localhost",
        )
        .unwrap();
        assert_eq!(config.gss_enc_mode, GssEncMode::Require);
        assert_eq!(config.tls_mode, TlsMode::Require);
    }

    // Every chained builder setter lands in its field.
    #[test]
    fn test_pool_config_builder_chaining() {
        let config = PoolConfig::new("db.example.com", 5433, "admin", "prod")
            .password("p@ss")
            .max_connections(50)
            .min_connections(10)
            .idle_timeout(Duration::from_secs(300))
            .acquire_timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(3))
            .max_lifetime(Duration::from_secs(3600))
            .gss_connect_retries(4)
            .gss_retry_base_delay(Duration::from_millis(250))
            .gss_circuit_breaker_threshold(12)
            .gss_circuit_breaker_window(Duration::from_secs(45))
            .gss_circuit_breaker_cooldown(Duration::from_secs(20))
            .test_on_acquire(false);

        assert_eq!(config.host, "db.example.com");
        assert_eq!(config.port, 5433);
        assert_eq!(config.max_connections, 50);
        assert_eq!(config.min_connections, 10);
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.acquire_timeout, Duration::from_secs(5));
        assert_eq!(config.connect_timeout, Duration::from_secs(3));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(3600)));
        assert_eq!(config.gss_connect_retries, 4);
        assert_eq!(config.gss_retry_base_delay, Duration::from_millis(250));
        assert_eq!(config.gss_circuit_breaker_threshold, 12);
        assert_eq!(config.gss_circuit_breaker_window, Duration::from_secs(45));
        assert_eq!(config.gss_circuit_breaker_cooldown, Duration::from_secs(20));
        assert!(!config.test_on_acquire);
    }

    // The database name must not include the `?…` query string.
    #[test]
    fn test_parse_pg_url_strips_query_string() {
        let (host, port, user, db, password) = parse_pg_url(
            "postgresql://alice:secret@db.internal:5433/app?sslmode=require&channel_binding=require",
        )
        .unwrap();
        assert_eq!(host, "db.internal");
        assert_eq!(port, 5433);
        assert_eq!(user, "alice");
        assert_eq!(db, "app");
        assert_eq!(password, Some("secret".to_string()));
    }

    #[test]
    fn test_parse_bool_param_variants() {
        assert_eq!(parse_bool_param("true"), Some(true));
        assert_eq!(parse_bool_param("YES"), Some(true));
        assert_eq!(parse_bool_param("0"), Some(false));
        assert_eq!(parse_bool_param("off"), Some(false));
        assert_eq!(parse_bool_param("invalid"), None);
    }

    // The following `from_qail_config` tests check the validation of GSS-related URL
    // parameters through their error messages.
    #[test]
    fn test_from_qail_config_rejects_invalid_gss_provider() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url =
            "postgres://alice:secret@db.internal:5432/app?gss_provider=unknown".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected invalid gss_provider error"),
            Err(e) => e,
        };
        assert!(err.to_string().contains("Invalid gss_provider value"));
    }

    #[test]
    fn test_from_qail_config_rejects_empty_gss_service() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url = "postgres://alice:secret@db.internal:5432/app?gss_service=".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected empty gss_service error"),
            Err(e) => e,
        };
        assert!(err.to_string().contains("gss_service must not be empty"));
    }

    // The `_ms` parameters are in milliseconds (60000 → 60 s, 12000 → 12 s).
    #[test]
    fn test_from_qail_config_parses_gss_retry_settings() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url =
            "postgres://alice@db.internal:5432/app?gss_connect_retries=5&gss_retry_base_ms=400&gss_circuit_threshold=9&gss_circuit_window_ms=60000&gss_circuit_cooldown_ms=12000".to_string();

        let cfg = PoolConfig::from_qail_config(&qail).expect("expected valid config");
        assert_eq!(cfg.gss_connect_retries, 5);
        assert_eq!(cfg.gss_retry_base_delay, Duration::from_millis(400));
        assert_eq!(cfg.gss_circuit_breaker_threshold, 9);
        assert_eq!(cfg.gss_circuit_breaker_window, Duration::from_secs(60));
        assert_eq!(cfg.gss_circuit_breaker_cooldown, Duration::from_secs(12));
    }

    #[test]
    fn test_from_qail_config_rejects_invalid_gss_retry_base() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url = "postgres://alice@db.internal:5432/app?gss_retry_base_ms=0".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected invalid gss_retry_base_ms error"),
            Err(e) => e,
        };
        assert!(
            err.to_string()
                .contains("gss_retry_base_ms must be greater than 0")
        );
    }

    #[test]
    fn test_from_qail_config_rejects_invalid_gss_connect_retries() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url =
            "postgres://alice@db.internal:5432/app?gss_connect_retries=100".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected invalid gss_connect_retries error"),
            Err(e) => e,
        };
        assert!(
            err.to_string()
                .contains("gss_connect_retries must be <= 20")
        );
    }

    #[test]
    fn test_from_qail_config_rejects_invalid_gss_circuit_threshold() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url =
            "postgres://alice@db.internal:5432/app?gss_circuit_threshold=500".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected invalid gss_circuit_threshold error"),
            Err(e) => e,
        };
        assert!(
            err.to_string()
                .contains("gss_circuit_threshold must be <= 100")
        );
    }

    #[test]
    fn test_from_qail_config_rejects_invalid_gss_circuit_window() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url =
            "postgres://alice@db.internal:5432/app?gss_circuit_window_ms=0".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected invalid gss_circuit_window_ms error"),
            Err(e) => e,
        };
        assert!(
            err.to_string()
                .contains("gss_circuit_window_ms must be greater than 0")
        );
    }

    #[test]
    fn test_from_qail_config_rejects_invalid_gss_circuit_cooldown() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url =
            "postgres://alice@db.internal:5432/app?gss_circuit_cooldown_ms=0".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected invalid gss_circuit_cooldown_ms error"),
            Err(e) => e,
        };
        assert!(
            err.to_string()
                .contains("gss_circuit_cooldown_ms must be greater than 0")
        );
    }

    // Compiled whenever the Linux + `enterprise-gssapi` combination is NOT active, so it
    // also runs on non-Linux targets despite its name.
    // NOTE(review): confirm the same error text is produced on non-Linux platforms.
    #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
    #[test]
    fn test_from_qail_config_linux_krb5_requires_feature_on_linux() {
        let mut qail = qail_core::config::QailConfig::default();
        qail.postgres.url =
            "postgres://alice@db.internal:5432/app?gss_provider=linux_krb5".to_string();

        let err = match PoolConfig::from_qail_config(&qail) {
            Ok(_) => panic!("expected linux_krb5 feature-gate error"),
            Err(e) => e,
        };
        assert!(
            err.to_string()
                .contains("requires qail-pg feature enterprise-gssapi on Linux")
        );
    }

    #[test]
    fn test_timeout_error_display() {
        let err = PgError::Timeout("pool acquire after 30s (10 max connections)".to_string());
        let msg = err.to_string();
        assert!(msg.contains("Timeout"));
        assert!(msg.contains("30s"));
        assert!(msg.contains("10 max connections"));
    }

    // Retry classification: transient auth messages are retried while attempts remain.
    #[test]
    fn test_should_retry_gss_connect_error_transient_auth() {
        let config = PoolConfig::new("localhost", 5432, "user", "db")
            .auth_settings(AuthSettings::gssapi_only())
            .gss_connect_retries(3);
        let err = PgError::Auth("temporary kerberos service unavailable".to_string());
        assert!(should_retry_gss_connect_error(&config, 0, &err));
    }

    // A configuration error is not transient and must not be retried.
    #[test]
    fn test_should_retry_gss_connect_error_non_transient_auth() {
        let config = PoolConfig::new("localhost", 5432, "user", "db")
            .auth_settings(AuthSettings::gssapi_only())
            .gss_connect_retries(3);
        let err = PgError::Auth(
            "Kerberos V5 authentication requested but no GSS token provider is configured"
                .to_string(),
        );
        assert!(!should_retry_gss_connect_error(&config, 0, &err));
    }

    // With `gss_connect_retries(1)` only attempt 0 may be retried.
    #[test]
    fn test_should_retry_gss_connect_error_respects_retry_limit() {
        let config = PoolConfig::new("localhost", 5432, "user", "db")
            .auth_settings(AuthSettings::gssapi_only())
            .gss_connect_retries(1);
        let err = PgError::Connection("temporary network is unreachable".to_string());
        assert!(should_retry_gss_connect_error(&config, 0, &err));
        assert!(!should_retry_gss_connect_error(&config, 1, &err));
    }

    // 100 ms * 2^2 = 400 ms base, plus at most 400 / 5 = 80 ms of jitter.
    #[test]
    fn test_gss_retry_delay_has_bounded_jitter() {
        let delay = gss_retry_delay(Duration::from_millis(100), 2);
        assert!(delay >= Duration::from_millis(400));
        assert!(delay <= Duration::from_millis(480));
    }

    // Uses a unique host/database so the process-wide registry key does not collide with
    // other tests. The leading `record_success` clears any leftover state.
    #[test]
    fn test_gss_circuit_opens_and_resets_on_success() {
        let config = PoolConfig::new("circuit.test", 5432, "user", "db_circuit")
            .auth_settings(AuthSettings::gssapi_only())
            .gss_circuit_breaker_threshold(2)
            .gss_circuit_breaker_window(Duration::from_secs(30))
            .gss_circuit_breaker_cooldown(Duration::from_secs(5));

        gss_circuit_record_success(&config);
        assert!(gss_circuit_remaining_open(&config).is_none());

        gss_circuit_record_failure(&config);
        assert!(gss_circuit_remaining_open(&config).is_none());

        gss_circuit_record_failure(&config);
        assert!(gss_circuit_remaining_open(&config).is_some());

        gss_circuit_record_success(&config);
        assert!(gss_circuit_remaining_open(&config).is_none());
    }

    #[test]
    fn test_pool_closed_error_display() {
        let err = PgError::PoolClosed;
        assert_eq!(err.to_string(), "Connection pool is closed");
    }

    #[test]
    fn test_pool_exhausted_error_display() {
        let err = PgError::PoolExhausted { max: 20 };
        let msg = err.to_string();
        assert!(msg.contains("exhausted"));
        assert!(msg.contains("20"));
    }

    #[test]
    fn test_io_error_source_chaining() {
        use std::error::Error;
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "peer reset");
        let pg_err = PgError::Io(io_err);
        let source = pg_err.source().expect("Io variant should have source");
        // The source must be the wrapped io::Error, carrying its original message.
        assert!(source.to_string().contains("peer reset"));
    }

    #[test]
    fn test_non_io_errors_have_no_source() {
        use std::error::Error;
        assert!(PgError::Connection("test".into()).source().is_none());
        assert!(PgError::Query("test".into()).source().is_none());
        assert!(PgError::Timeout("test".into()).source().is_none());
        assert!(PgError::PoolClosed.source().is_none());
        assert!(PgError::NoRows.source().is_none());
    }

    #[test]
    fn test_io_error_from_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "broken");
        let pg_err: PgError = io_err.into();
        assert!(matches!(pg_err, PgError::Io(_)));
        assert!(pg_err.to_string().contains("broken"));
    }

    // Constructs one value per listed variant and checks each renders a non-empty message.
    // NOTE(review): despite the name, distinctness of the messages is not asserted.
    #[test]
    fn test_error_variants_are_distinct() {
        let errors: Vec<PgError> = vec![
            PgError::Connection("conn".into()),
            PgError::Protocol("proto".into()),
            PgError::Auth("auth".into()),
            PgError::Query("query".into()),
            PgError::QueryServer(crate::driver::PgServerError {
                severity: "ERROR".to_string(),
                code: "23505".to_string(),
                message: "duplicate key value violates unique constraint".to_string(),
                detail: None,
                hint: None,
            }),
            PgError::NoRows,
            PgError::Io(std::io::Error::other("io")),
            PgError::Encode("enc".into()),
            PgError::Timeout("timeout".into()),
            PgError::PoolExhausted { max: 10 },
            PgError::PoolClosed,
        ];
        for err in &errors {
            // Every variant must produce a non-empty Display message.
            assert!(!err.to_string().is_empty());
        }
        assert_eq!(errors.len(), 11);
    }
}