Skip to main content

qail_pg/driver/pool/
config.rs

1//! Pool configuration, URL parsing, and builder.
2
3use crate::driver::{
4    AuthSettings, GssEncMode, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
5    ScramChannelBindingMode, TlsConfig, TlsMode,
6};
7use std::time::Duration;
8
9/// Configuration for a PostgreSQL connection pool.
10///
11/// Use the builder pattern to customise settings:
12///
13/// ```ignore
14/// use std::time::Duration;
15/// use qail_pg::driver::pool::PoolConfig;
16/// let config = PoolConfig::new("localhost", 5432, "app", "mydb")
17///     .password("secret")
18///     .max_connections(20)
19///     .acquire_timeout(Duration::from_secs(5));
20/// ```
21#[derive(Clone)]
22pub struct PoolConfig {
23    /// PostgreSQL server hostname or IP address.
24    pub host: String,
25    /// PostgreSQL server port (default: 5432).
26    pub port: u16,
27    /// Database role / user name.
28    pub user: String,
29    /// Target database name.
30    pub database: String,
31    /// Optional password for authentication.
32    pub password: Option<String>,
33    /// Hard upper limit on simultaneous connections (default: 10).
34    pub max_connections: usize,
35    /// Minimum idle connections kept warm in the pool (default: 1).
36    pub min_connections: usize,
37    /// Close idle connections after this duration (default: 10 min).
38    pub idle_timeout: Duration,
39    /// Maximum time to wait when acquiring a connection (default: 30s).
40    pub acquire_timeout: Duration,
41    /// TCP connect timeout for new connections (default: 10s).
42    pub connect_timeout: Duration,
43    /// Optional maximum lifetime of any connection in the pool.
44    pub max_lifetime: Option<Duration>,
45    /// Maximum number of leaked-connection cleanup tasks that may run concurrently.
46    ///
47    /// When a `PooledConnection` is dropped without calling `release()`, the pool
48    /// can attempt async reset-and-return. This bound prevents unbounded cleanup fanout.
49    pub leaked_cleanup_queue: usize,
50    /// When `true`, run a health check (`SELECT 1`) before handing out a connection.
51    pub test_on_acquire: bool,
52    /// TLS mode for new connections.
53    pub tls_mode: TlsMode,
54    /// Optional custom CA bundle (PEM) for server certificate validation.
55    pub tls_ca_cert_pem: Option<Vec<u8>>,
56    /// Optional mTLS client certificate/key configuration.
57    pub mtls: Option<TlsConfig>,
58    /// Optional callback for Kerberos/GSS/SSPI token generation.
59    pub gss_token_provider: Option<GssTokenProvider>,
60    /// Optional stateful callback for Kerberos/GSS/SSPI token generation.
61    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
62    /// Number of retries for transient GSS/Kerberos connection failures.
63    pub gss_connect_retries: usize,
64    /// Base delay for GSS/Kerberos connect retry backoff.
65    pub gss_retry_base_delay: Duration,
66    /// Transient GSS failures in one window before opening the local circuit.
67    pub gss_circuit_breaker_threshold: usize,
68    /// Rolling window used to count transient GSS failures.
69    pub gss_circuit_breaker_window: Duration,
70    /// Cooldown duration while the local GSS circuit stays open.
71    pub gss_circuit_breaker_cooldown: Duration,
72    /// Password-auth policy.
73    pub auth_settings: AuthSettings,
74    /// GSSAPI session encryption mode (`gssencmode` URL parameter).
75    pub gss_enc_mode: GssEncMode,
76}
77
78impl PoolConfig {
79    /// Create a new pool configuration with **production-safe** defaults.
80    ///
81    /// Defaults: `tls_mode = Require`, `auth_settings = scram_only()`.
82    /// For local development without TLS, use [`PoolConfig::new_dev`].
83    ///
84    /// # Arguments
85    ///
86    /// * `host` — PostgreSQL server hostname or IP.
87    /// * `port` — TCP port (typically 5432).
88    /// * `user` — PostgreSQL role name.
89    /// * `database` — Target database name.
90    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
91        Self {
92            host: host.to_string(),
93            port,
94            user: user.to_string(),
95            database: database.to_string(),
96            password: None,
97            max_connections: 10,
98            min_connections: 1,
99            idle_timeout: Duration::from_secs(600), // 10 minutes
100            acquire_timeout: Duration::from_secs(30), // 30 seconds
101            connect_timeout: Duration::from_secs(10), // 10 seconds
102            max_lifetime: None,                     // No limit by default
103            leaked_cleanup_queue: 64,               // Bounded cleanup fanout
104            test_on_acquire: false,                 // Disabled by default for performance
105            tls_mode: TlsMode::Prefer,
106            tls_ca_cert_pem: None,
107            mtls: None,
108            gss_token_provider: None,
109            gss_token_provider_ex: None,
110            gss_connect_retries: 2,
111            gss_retry_base_delay: Duration::from_millis(150),
112            gss_circuit_breaker_threshold: 8,
113            gss_circuit_breaker_window: Duration::from_secs(30),
114            gss_circuit_breaker_cooldown: Duration::from_secs(15),
115            auth_settings: AuthSettings::scram_only(),
116            gss_enc_mode: GssEncMode::Disable,
117        }
118    }
119
120    /// Create a pool configuration with **permissive** defaults for local development.
121    ///
122    /// Defaults: `tls_mode = Disable`, `auth_settings = default()` (accepts any auth).
123    /// Do NOT use in production.
124    pub fn new_dev(host: &str, port: u16, user: &str, database: &str) -> Self {
125        let mut config = Self::new(host, port, user, database);
126        config.tls_mode = TlsMode::Disable;
127        config.auth_settings = AuthSettings::default();
128        config
129    }
130
131    /// Set password for authentication.
132    pub fn password(mut self, password: &str) -> Self {
133        self.password = Some(password.to_string());
134        self
135    }
136
137    /// Set maximum simultaneous connections.
138    pub fn max_connections(mut self, max: usize) -> Self {
139        self.max_connections = max;
140        self
141    }
142
143    /// Set minimum idle connections.
144    pub fn min_connections(mut self, min: usize) -> Self {
145        self.min_connections = min;
146        self
147    }
148
149    /// Set idle timeout (connections idle longer than this are closed).
150    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
151        self.idle_timeout = timeout;
152        self
153    }
154
155    /// Set acquire timeout (max wait time when getting a connection).
156    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
157        self.acquire_timeout = timeout;
158        self
159    }
160
161    /// Set connect timeout (max time to establish new connection).
162    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
163        self.connect_timeout = timeout;
164        self
165    }
166
167    /// Set maximum lifetime of a connection before recycling.
168    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
169        self.max_lifetime = Some(lifetime);
170        self
171    }
172
173    /// Set max concurrent leaked-connection cleanup tasks.
174    ///
175    /// Values <= 1 force strict fallback-destroy behavior under burst leaks.
176    pub fn leaked_cleanup_queue(mut self, max: usize) -> Self {
177        self.leaked_cleanup_queue = max;
178        self
179    }
180
181    /// Enable connection validation on acquire.
182    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
183        self.test_on_acquire = enabled;
184        self
185    }
186
187    /// Set TLS mode for pool connections.
188    pub fn tls_mode(mut self, mode: TlsMode) -> Self {
189        self.tls_mode = mode;
190        self
191    }
192
193    /// Set custom CA bundle (PEM) for TLS validation.
194    pub fn tls_ca_cert_pem(mut self, ca_cert_pem: Vec<u8>) -> Self {
195        self.tls_ca_cert_pem = Some(ca_cert_pem);
196        self
197    }
198
199    /// Enable mTLS for pool connections.
200    pub fn mtls(mut self, config: TlsConfig) -> Self {
201        self.mtls = Some(config);
202        self.tls_mode = TlsMode::Require;
203        self
204    }
205
206    /// Set Kerberos/GSS/SSPI token provider callback.
207    pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
208        self.gss_token_provider = Some(provider);
209        self
210    }
211
212    /// Set a stateful Kerberos/GSS/SSPI token provider.
213    pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
214        self.gss_token_provider_ex = Some(provider);
215        self
216    }
217
218    /// Set retry count for transient GSS/Kerberos connection failures.
219    pub fn gss_connect_retries(mut self, retries: usize) -> Self {
220        self.gss_connect_retries = retries;
221        self
222    }
223
224    /// Set base backoff delay for GSS/Kerberos connection retry.
225    pub fn gss_retry_base_delay(mut self, delay: Duration) -> Self {
226        self.gss_retry_base_delay = delay;
227        self
228    }
229
230    /// Set failure threshold for opening local GSS circuit breaker.
231    pub fn gss_circuit_breaker_threshold(mut self, threshold: usize) -> Self {
232        self.gss_circuit_breaker_threshold = threshold;
233        self
234    }
235
236    /// Set rolling failure window for GSS circuit breaker.
237    pub fn gss_circuit_breaker_window(mut self, window: Duration) -> Self {
238        self.gss_circuit_breaker_window = window;
239        self
240    }
241
242    /// Set cooldown duration for open GSS circuit breaker.
243    pub fn gss_circuit_breaker_cooldown(mut self, cooldown: Duration) -> Self {
244        self.gss_circuit_breaker_cooldown = cooldown;
245        self
246    }
247
248    /// Set authentication policy.
249    pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
250        self.auth_settings = settings;
251        self
252    }
253
254    /// Create a `PoolConfig` from a centralized `QailConfig`.
255    ///
256    /// Parses `postgres.url` for host/port/user/database/password
257    /// and applies pool tuning from `[postgres]` section.
258    pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
259        let pg = &qail.postgres;
260        let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
261
262        let mut config = PoolConfig::new(&host, port, &user, &database)
263            .max_connections(pg.max_connections)
264            .min_connections(pg.min_connections)
265            .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
266            .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
267            .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
268            .test_on_acquire(pg.test_on_acquire);
269
270        if let Some(ref pw) = password {
271            config = config.password(pw);
272        }
273
274        // Optional URL query params for enterprise auth/TLS settings.
275        if let Some(query) = pg.url.split('?').nth(1) {
276            apply_url_query_params(&mut config, query, &host)?;
277        }
278
279        Ok(config)
280    }
281}
282
283/// Apply enterprise auth/TLS query parameters to a `PoolConfig`.
284///
285/// Shared between `PoolConfig::from_qail_config` and `PgDriver::connect_url`
286/// so that both paths support the same set of URL knobs.
287#[allow(unused_variables)]
288pub(crate) fn apply_url_query_params(
289    config: &mut PoolConfig,
290    query: &str,
291    host: &str,
292) -> PgResult<()> {
293    let mut sslcert: Option<String> = None;
294    let mut sslkey: Option<String> = None;
295    let mut gss_provider: Option<String> = None;
296    let mut gss_service = "postgres".to_string();
297    let mut gss_target: Option<String> = None;
298
299    for pair in query.split('&').filter(|p| !p.is_empty()) {
300        let mut kv = pair.splitn(2, '=');
301        let key = kv.next().unwrap_or_default().trim();
302        let value = kv.next().unwrap_or_default().trim();
303
304        match key {
305            "sslmode" => {
306                let mode = TlsMode::parse_sslmode(value).ok_or_else(|| {
307                    PgError::Connection(format!("Invalid sslmode value: {}", value))
308                })?;
309                config.tls_mode = mode;
310            }
311            "gssencmode" => {
312                let mode = GssEncMode::parse_gssencmode(value).ok_or_else(|| {
313                    PgError::Connection(format!("Invalid gssencmode value: {}", value))
314                })?;
315                config.gss_enc_mode = mode;
316            }
317            "sslrootcert" => {
318                let ca_pem = std::fs::read(value).map_err(|e| {
319                    PgError::Connection(format!("Failed to read sslrootcert '{}': {}", value, e))
320                })?;
321                config.tls_ca_cert_pem = Some(ca_pem);
322            }
323            "sslcert" => sslcert = Some(value.to_string()),
324            "sslkey" => sslkey = Some(value.to_string()),
325            "channel_binding" => {
326                let mode = ScramChannelBindingMode::parse(value).ok_or_else(|| {
327                    PgError::Connection(format!("Invalid channel_binding value: {}", value))
328                })?;
329                config.auth_settings.channel_binding = mode;
330            }
331            "auth_scram" => {
332                let enabled = parse_bool_param(value).ok_or_else(|| {
333                    PgError::Connection(format!("Invalid auth_scram value: {}", value))
334                })?;
335                config.auth_settings.allow_scram_sha_256 = enabled;
336            }
337            "auth_md5" => {
338                let enabled = parse_bool_param(value).ok_or_else(|| {
339                    PgError::Connection(format!("Invalid auth_md5 value: {}", value))
340                })?;
341                config.auth_settings.allow_md5_password = enabled;
342            }
343            "auth_cleartext" => {
344                let enabled = parse_bool_param(value).ok_or_else(|| {
345                    PgError::Connection(format!("Invalid auth_cleartext value: {}", value))
346                })?;
347                config.auth_settings.allow_cleartext_password = enabled;
348            }
349            "auth_kerberos" => {
350                let enabled = parse_bool_param(value).ok_or_else(|| {
351                    PgError::Connection(format!("Invalid auth_kerberos value: {}", value))
352                })?;
353                config.auth_settings.allow_kerberos_v5 = enabled;
354            }
355            "auth_gssapi" => {
356                let enabled = parse_bool_param(value).ok_or_else(|| {
357                    PgError::Connection(format!("Invalid auth_gssapi value: {}", value))
358                })?;
359                config.auth_settings.allow_gssapi = enabled;
360            }
361            "auth_sspi" => {
362                let enabled = parse_bool_param(value).ok_or_else(|| {
363                    PgError::Connection(format!("Invalid auth_sspi value: {}", value))
364                })?;
365                config.auth_settings.allow_sspi = enabled;
366            }
367            "auth_mode" => {
368                if value.eq_ignore_ascii_case("scram_only") {
369                    config.auth_settings = AuthSettings::scram_only();
370                } else if value.eq_ignore_ascii_case("gssapi_only") {
371                    config.auth_settings = AuthSettings::gssapi_only();
372                } else if value.eq_ignore_ascii_case("compat")
373                    || value.eq_ignore_ascii_case("default")
374                {
375                    config.auth_settings = AuthSettings::default();
376                } else {
377                    return Err(PgError::Connection(format!(
378                        "Invalid auth_mode value: {}",
379                        value
380                    )));
381                }
382            }
383            "gss_provider" => gss_provider = Some(value.to_string()),
384            "gss_service" => {
385                if value.is_empty() {
386                    return Err(PgError::Connection(
387                        "gss_service must not be empty".to_string(),
388                    ));
389                }
390                gss_service = value.to_string();
391            }
392            // libpq alias for kerberos service principal name component.
393            "krbsrvname" => {
394                if value.is_empty() {
395                    return Err(PgError::Connection(
396                        "gss_service must not be empty".to_string(),
397                    ));
398                }
399                gss_service = value.to_string();
400            }
401            "gss_target" => {
402                if value.is_empty() {
403                    return Err(PgError::Connection(
404                        "gss_target must not be empty".to_string(),
405                    ));
406                }
407                gss_target = Some(value.to_string());
408            }
409            // libpq alias for GSS target hostname override.
410            "gsshostname" => {
411                if value.is_empty() {
412                    return Err(PgError::Connection(
413                        "gss_target must not be empty".to_string(),
414                    ));
415                }
416                gss_target = Some(value.to_string());
417            }
418            // libpq compatibility knob; accepted values are validated but
419            // provider selection remains controlled by qail `gss_provider`.
420            "gsslib" => match value.trim().to_ascii_lowercase().as_str() {
421                "gssapi" | "sspi" => {}
422                _ => {
423                    return Err(PgError::Connection(format!(
424                        "Invalid gsslib value: {} (expected gssapi or sspi)",
425                        value
426                    )));
427                }
428            },
429            "gss_connect_retries" => {
430                let retries = value.parse::<usize>().map_err(|_| {
431                    PgError::Connection(format!("Invalid gss_connect_retries value: {}", value))
432                })?;
433                if retries > 20 {
434                    return Err(PgError::Connection(
435                        "gss_connect_retries must be <= 20".to_string(),
436                    ));
437                }
438                config.gss_connect_retries = retries;
439            }
440            "gss_retry_base_ms" => {
441                let delay_ms = value.parse::<u64>().map_err(|_| {
442                    PgError::Connection(format!("Invalid gss_retry_base_ms value: {}", value))
443                })?;
444                if delay_ms == 0 {
445                    return Err(PgError::Connection(
446                        "gss_retry_base_ms must be greater than 0".to_string(),
447                    ));
448                }
449                config.gss_retry_base_delay = Duration::from_millis(delay_ms);
450            }
451            "gss_circuit_threshold" => {
452                let threshold = value.parse::<usize>().map_err(|_| {
453                    PgError::Connection(format!("Invalid gss_circuit_threshold value: {}", value))
454                })?;
455                if threshold > 100 {
456                    return Err(PgError::Connection(
457                        "gss_circuit_threshold must be <= 100".to_string(),
458                    ));
459                }
460                config.gss_circuit_breaker_threshold = threshold;
461            }
462            "gss_circuit_window_ms" => {
463                let window_ms = value.parse::<u64>().map_err(|_| {
464                    PgError::Connection(format!("Invalid gss_circuit_window_ms value: {}", value))
465                })?;
466                if window_ms == 0 {
467                    return Err(PgError::Connection(
468                        "gss_circuit_window_ms must be greater than 0".to_string(),
469                    ));
470                }
471                config.gss_circuit_breaker_window = Duration::from_millis(window_ms);
472            }
473            "gss_circuit_cooldown_ms" => {
474                let cooldown_ms = value.parse::<u64>().map_err(|_| {
475                    PgError::Connection(format!("Invalid gss_circuit_cooldown_ms value: {}", value))
476                })?;
477                if cooldown_ms == 0 {
478                    return Err(PgError::Connection(
479                        "gss_circuit_cooldown_ms must be greater than 0".to_string(),
480                    ));
481                }
482                config.gss_circuit_breaker_cooldown = Duration::from_millis(cooldown_ms);
483            }
484            _ => {}
485        }
486    }
487
488    match (sslcert.as_deref(), sslkey.as_deref()) {
489        (Some(cert_path), Some(key_path)) => {
490            let mtls = TlsConfig {
491                client_cert_pem: std::fs::read(cert_path).map_err(|e| {
492                    PgError::Connection(format!("Failed to read sslcert '{}': {}", cert_path, e))
493                })?,
494                client_key_pem: std::fs::read(key_path).map_err(|e| {
495                    PgError::Connection(format!("Failed to read sslkey '{}': {}", key_path, e))
496                })?,
497                ca_cert_pem: config.tls_ca_cert_pem.clone(),
498            };
499            config.mtls = Some(mtls);
500            config.tls_mode = TlsMode::Require;
501        }
502        (Some(_), None) | (None, Some(_)) => {
503            return Err(PgError::Connection(
504                "Both sslcert and sslkey must be provided together".to_string(),
505            ));
506        }
507        (None, None) => {}
508    }
509
510    if let Some(provider) = gss_provider {
511        if provider.eq_ignore_ascii_case("linux_krb5") || provider.eq_ignore_ascii_case("builtin") {
512            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
513            {
514                let provider = crate::driver::gss::linux_krb5_token_provider(
515                    crate::driver::gss::LinuxKrb5ProviderConfig {
516                        host: host.to_string(),
517                        service: gss_service.clone(),
518                        target_name: gss_target.clone(),
519                    },
520                )
521                .map_err(PgError::Auth)?;
522                config.gss_token_provider_ex = Some(provider);
523            }
524            #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
525            {
526                let _ = gss_service;
527                let _ = gss_target;
528                return Err(PgError::Connection(
529                    "gss_provider=linux_krb5 requires qail-pg feature enterprise-gssapi on Linux"
530                        .to_string(),
531                ));
532            }
533        } else if provider.eq_ignore_ascii_case("callback")
534            || provider.eq_ignore_ascii_case("custom")
535        {
536            // External callback wiring is handled by application code.
537        } else {
538            return Err(PgError::Connection(format!(
539                "Invalid gss_provider value: {}",
540                provider
541            )));
542        }
543    }
544
545    Ok(())
546}
547
548/// Parse a postgres URL into (host, port, user, database, password).
549pub(super) fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
550    let url = url.split('?').next().unwrap_or(url);
551    let url = url
552        .trim_start_matches("postgres://")
553        .trim_start_matches("postgresql://");
554
555    let (credentials, host_part) = if url.contains('@') {
556        let mut parts = url.splitn(2, '@');
557        let creds = parts.next().unwrap_or("");
558        let host = parts.next().unwrap_or("localhost/postgres");
559        (Some(creds), host)
560    } else {
561        (None, url)
562    };
563
564    let (host_port, database) = if host_part.contains('/') {
565        let mut parts = host_part.splitn(2, '/');
566        (
567            parts.next().unwrap_or("localhost"),
568            parts.next().unwrap_or("postgres").to_string(),
569        )
570    } else {
571        (host_part, "postgres".to_string())
572    };
573
574    let (host, port) = if host_port.contains(':') {
575        let mut parts = host_port.split(':');
576        let h = parts.next().unwrap_or("localhost").to_string();
577        let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
578        (h, p)
579    } else {
580        (host_port.to_string(), 5432u16)
581    };
582
583    let (user, password) = if let Some(creds) = credentials {
584        if creds.contains(':') {
585            let mut parts = creds.splitn(2, ':');
586            let u = parts.next().unwrap_or("postgres").to_string();
587            let p = parts.next().map(|s| s.to_string());
588            (u, p)
589        } else {
590            (creds.to_string(), None)
591        }
592    } else {
593        ("postgres".to_string(), None)
594    };
595
596    Ok((host, port, user, database, password))
597}
598
599pub(super) fn parse_bool_param(value: &str) -> Option<bool> {
600    match value.trim().to_ascii_lowercase().as_str() {
601        "1" | "true" | "yes" | "on" => Some(true),
602        "0" | "false" | "no" | "off" => Some(false),
603        _ => None,
604    }
605}