Skip to main content

qail_pg/driver/pool/
config.rs

1//! Pool configuration, URL parsing, and builder.
2
3use crate::driver::{
4    AuthSettings, GssEncMode, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
5    ScramChannelBindingMode, TlsConfig, TlsMode,
6};
7use std::time::Duration;
8
9/// Configuration for a PostgreSQL connection pool.
10///
11/// Use the builder pattern to customise settings:
12///
13/// ```ignore
14/// use std::time::Duration;
15/// use qail_pg::driver::pool::PoolConfig;
16/// let config = PoolConfig::new("localhost", 5432, "app", "mydb")
17///     .password("secret")
18///     .max_connections(20)
19///     .acquire_timeout(Duration::from_secs(5));
20/// ```
21#[derive(Clone)]
22pub struct PoolConfig {
23    /// PostgreSQL server hostname or IP address.
24    pub host: String,
25    /// PostgreSQL server port (default: 5432).
26    pub port: u16,
27    /// Database role / user name.
28    pub user: String,
29    /// Target database name.
30    pub database: String,
31    /// Optional password for authentication.
32    pub password: Option<String>,
33    /// Hard upper limit on simultaneous connections (default: 10).
34    pub max_connections: usize,
35    /// Minimum idle connections kept warm in the pool (default: 1).
36    pub min_connections: usize,
37    /// Close idle connections after this duration (default: 10 min).
38    pub idle_timeout: Duration,
39    /// Maximum time to wait when acquiring a connection (default: 30s).
40    pub acquire_timeout: Duration,
41    /// TCP connect timeout for new connections (default: 10s).
42    pub connect_timeout: Duration,
43    /// Optional maximum lifetime of any connection in the pool.
44    pub max_lifetime: Option<Duration>,
45    /// When `true`, run a health check (`SELECT 1`) before handing out a connection.
46    pub test_on_acquire: bool,
47    /// TLS mode for new connections.
48    pub tls_mode: TlsMode,
49    /// Optional custom CA bundle (PEM) for server certificate validation.
50    pub tls_ca_cert_pem: Option<Vec<u8>>,
51    /// Optional mTLS client certificate/key configuration.
52    pub mtls: Option<TlsConfig>,
53    /// Optional callback for Kerberos/GSS/SSPI token generation.
54    pub gss_token_provider: Option<GssTokenProvider>,
55    /// Optional stateful callback for Kerberos/GSS/SSPI token generation.
56    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
57    /// Number of retries for transient GSS/Kerberos connection failures.
58    pub gss_connect_retries: usize,
59    /// Base delay for GSS/Kerberos connect retry backoff.
60    pub gss_retry_base_delay: Duration,
61    /// Transient GSS failures in one window before opening the local circuit.
62    pub gss_circuit_breaker_threshold: usize,
63    /// Rolling window used to count transient GSS failures.
64    pub gss_circuit_breaker_window: Duration,
65    /// Cooldown duration while the local GSS circuit stays open.
66    pub gss_circuit_breaker_cooldown: Duration,
67    /// Password-auth policy.
68    pub auth_settings: AuthSettings,
69    /// GSSAPI session encryption mode (`gssencmode` URL parameter).
70    pub gss_enc_mode: GssEncMode,
71}
72
73impl PoolConfig {
74    /// Create a new pool configuration with **production-safe** defaults.
75    ///
76    /// Defaults: `tls_mode = Require`, `auth_settings = scram_only()`.
77    /// For local development without TLS, use [`PoolConfig::new_dev`].
78    ///
79    /// # Arguments
80    ///
81    /// * `host` — PostgreSQL server hostname or IP.
82    /// * `port` — TCP port (typically 5432).
83    /// * `user` — PostgreSQL role name.
84    /// * `database` — Target database name.
85    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
86        Self {
87            host: host.to_string(),
88            port,
89            user: user.to_string(),
90            database: database.to_string(),
91            password: None,
92            max_connections: 10,
93            min_connections: 1,
94            idle_timeout: Duration::from_secs(600), // 10 minutes
95            acquire_timeout: Duration::from_secs(30), // 30 seconds
96            connect_timeout: Duration::from_secs(10), // 10 seconds
97            max_lifetime: None,                     // No limit by default
98            test_on_acquire: false,                 // Disabled by default for performance
99            tls_mode: TlsMode::Prefer,
100            tls_ca_cert_pem: None,
101            mtls: None,
102            gss_token_provider: None,
103            gss_token_provider_ex: None,
104            gss_connect_retries: 2,
105            gss_retry_base_delay: Duration::from_millis(150),
106            gss_circuit_breaker_threshold: 8,
107            gss_circuit_breaker_window: Duration::from_secs(30),
108            gss_circuit_breaker_cooldown: Duration::from_secs(15),
109            auth_settings: AuthSettings::scram_only(),
110            gss_enc_mode: GssEncMode::Disable,
111        }
112    }
113
114    /// Create a pool configuration with **permissive** defaults for local development.
115    ///
116    /// Defaults: `tls_mode = Disable`, `auth_settings = default()` (accepts any auth).
117    /// Do NOT use in production.
118    pub fn new_dev(host: &str, port: u16, user: &str, database: &str) -> Self {
119        let mut config = Self::new(host, port, user, database);
120        config.tls_mode = TlsMode::Disable;
121        config.auth_settings = AuthSettings::default();
122        config
123    }
124
125    /// Set password for authentication.
126    pub fn password(mut self, password: &str) -> Self {
127        self.password = Some(password.to_string());
128        self
129    }
130
131    /// Set maximum simultaneous connections.
132    pub fn max_connections(mut self, max: usize) -> Self {
133        self.max_connections = max;
134        self
135    }
136
137    /// Set minimum idle connections.
138    pub fn min_connections(mut self, min: usize) -> Self {
139        self.min_connections = min;
140        self
141    }
142
143    /// Set idle timeout (connections idle longer than this are closed).
144    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
145        self.idle_timeout = timeout;
146        self
147    }
148
149    /// Set acquire timeout (max wait time when getting a connection).
150    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
151        self.acquire_timeout = timeout;
152        self
153    }
154
155    /// Set connect timeout (max time to establish new connection).
156    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
157        self.connect_timeout = timeout;
158        self
159    }
160
161    /// Set maximum lifetime of a connection before recycling.
162    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
163        self.max_lifetime = Some(lifetime);
164        self
165    }
166
167    /// Enable connection validation on acquire.
168    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
169        self.test_on_acquire = enabled;
170        self
171    }
172
173    /// Set TLS mode for pool connections.
174    pub fn tls_mode(mut self, mode: TlsMode) -> Self {
175        self.tls_mode = mode;
176        self
177    }
178
179    /// Set custom CA bundle (PEM) for TLS validation.
180    pub fn tls_ca_cert_pem(mut self, ca_cert_pem: Vec<u8>) -> Self {
181        self.tls_ca_cert_pem = Some(ca_cert_pem);
182        self
183    }
184
185    /// Enable mTLS for pool connections.
186    pub fn mtls(mut self, config: TlsConfig) -> Self {
187        self.mtls = Some(config);
188        self.tls_mode = TlsMode::Require;
189        self
190    }
191
192    /// Set Kerberos/GSS/SSPI token provider callback.
193    pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
194        self.gss_token_provider = Some(provider);
195        self
196    }
197
198    /// Set a stateful Kerberos/GSS/SSPI token provider.
199    pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
200        self.gss_token_provider_ex = Some(provider);
201        self
202    }
203
204    /// Set retry count for transient GSS/Kerberos connection failures.
205    pub fn gss_connect_retries(mut self, retries: usize) -> Self {
206        self.gss_connect_retries = retries;
207        self
208    }
209
210    /// Set base backoff delay for GSS/Kerberos connection retry.
211    pub fn gss_retry_base_delay(mut self, delay: Duration) -> Self {
212        self.gss_retry_base_delay = delay;
213        self
214    }
215
216    /// Set failure threshold for opening local GSS circuit breaker.
217    pub fn gss_circuit_breaker_threshold(mut self, threshold: usize) -> Self {
218        self.gss_circuit_breaker_threshold = threshold;
219        self
220    }
221
222    /// Set rolling failure window for GSS circuit breaker.
223    pub fn gss_circuit_breaker_window(mut self, window: Duration) -> Self {
224        self.gss_circuit_breaker_window = window;
225        self
226    }
227
228    /// Set cooldown duration for open GSS circuit breaker.
229    pub fn gss_circuit_breaker_cooldown(mut self, cooldown: Duration) -> Self {
230        self.gss_circuit_breaker_cooldown = cooldown;
231        self
232    }
233
234    /// Set authentication policy.
235    pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
236        self.auth_settings = settings;
237        self
238    }
239
240    /// Create a `PoolConfig` from a centralized `QailConfig`.
241    ///
242    /// Parses `postgres.url` for host/port/user/database/password
243    /// and applies pool tuning from `[postgres]` section.
244    pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
245        let pg = &qail.postgres;
246        let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
247
248        let mut config = PoolConfig::new(&host, port, &user, &database)
249            .max_connections(pg.max_connections)
250            .min_connections(pg.min_connections)
251            .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
252            .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
253            .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
254            .test_on_acquire(pg.test_on_acquire);
255
256        if let Some(ref pw) = password {
257            config = config.password(pw);
258        }
259
260        // Optional URL query params for enterprise auth/TLS settings.
261        if let Some(query) = pg.url.split('?').nth(1) {
262            apply_url_query_params(&mut config, query, &host)?;
263        }
264
265        Ok(config)
266    }
267}
268
269/// Apply enterprise auth/TLS query parameters to a `PoolConfig`.
270///
271/// Shared between `PoolConfig::from_qail_config` and `PgDriver::connect_url`
272/// so that both paths support the same set of URL knobs.
273#[allow(unused_variables)]
274pub(crate) fn apply_url_query_params(
275    config: &mut PoolConfig,
276    query: &str,
277    host: &str,
278) -> PgResult<()> {
279    let mut sslcert: Option<String> = None;
280    let mut sslkey: Option<String> = None;
281    let mut gss_provider: Option<String> = None;
282    let mut gss_service = "postgres".to_string();
283    let mut gss_target: Option<String> = None;
284
285    for pair in query.split('&').filter(|p| !p.is_empty()) {
286        let mut kv = pair.splitn(2, '=');
287        let key = kv.next().unwrap_or_default().trim();
288        let value = kv.next().unwrap_or_default().trim();
289
290        match key {
291            "sslmode" => {
292                let mode = TlsMode::parse_sslmode(value).ok_or_else(|| {
293                    PgError::Connection(format!("Invalid sslmode value: {}", value))
294                })?;
295                config.tls_mode = mode;
296            }
297            "gssencmode" => {
298                let mode = GssEncMode::parse_gssencmode(value).ok_or_else(|| {
299                    PgError::Connection(format!("Invalid gssencmode value: {}", value))
300                })?;
301                config.gss_enc_mode = mode;
302            }
303            "sslrootcert" => {
304                let ca_pem = std::fs::read(value).map_err(|e| {
305                    PgError::Connection(format!("Failed to read sslrootcert '{}': {}", value, e))
306                })?;
307                config.tls_ca_cert_pem = Some(ca_pem);
308            }
309            "sslcert" => sslcert = Some(value.to_string()),
310            "sslkey" => sslkey = Some(value.to_string()),
311            "channel_binding" => {
312                let mode = ScramChannelBindingMode::parse(value).ok_or_else(|| {
313                    PgError::Connection(format!("Invalid channel_binding value: {}", value))
314                })?;
315                config.auth_settings.channel_binding = mode;
316            }
317            "auth_scram" => {
318                let enabled = parse_bool_param(value).ok_or_else(|| {
319                    PgError::Connection(format!("Invalid auth_scram value: {}", value))
320                })?;
321                config.auth_settings.allow_scram_sha_256 = enabled;
322            }
323            "auth_md5" => {
324                let enabled = parse_bool_param(value).ok_or_else(|| {
325                    PgError::Connection(format!("Invalid auth_md5 value: {}", value))
326                })?;
327                config.auth_settings.allow_md5_password = enabled;
328            }
329            "auth_cleartext" => {
330                let enabled = parse_bool_param(value).ok_or_else(|| {
331                    PgError::Connection(format!("Invalid auth_cleartext value: {}", value))
332                })?;
333                config.auth_settings.allow_cleartext_password = enabled;
334            }
335            "auth_kerberos" => {
336                let enabled = parse_bool_param(value).ok_or_else(|| {
337                    PgError::Connection(format!("Invalid auth_kerberos value: {}", value))
338                })?;
339                config.auth_settings.allow_kerberos_v5 = enabled;
340            }
341            "auth_gssapi" => {
342                let enabled = parse_bool_param(value).ok_or_else(|| {
343                    PgError::Connection(format!("Invalid auth_gssapi value: {}", value))
344                })?;
345                config.auth_settings.allow_gssapi = enabled;
346            }
347            "auth_sspi" => {
348                let enabled = parse_bool_param(value).ok_or_else(|| {
349                    PgError::Connection(format!("Invalid auth_sspi value: {}", value))
350                })?;
351                config.auth_settings.allow_sspi = enabled;
352            }
353            "auth_mode" => {
354                if value.eq_ignore_ascii_case("scram_only") {
355                    config.auth_settings = AuthSettings::scram_only();
356                } else if value.eq_ignore_ascii_case("gssapi_only") {
357                    config.auth_settings = AuthSettings::gssapi_only();
358                } else if value.eq_ignore_ascii_case("compat")
359                    || value.eq_ignore_ascii_case("default")
360                {
361                    config.auth_settings = AuthSettings::default();
362                } else {
363                    return Err(PgError::Connection(format!(
364                        "Invalid auth_mode value: {}",
365                        value
366                    )));
367                }
368            }
369            "gss_provider" => gss_provider = Some(value.to_string()),
370            "gss_service" => {
371                if value.is_empty() {
372                    return Err(PgError::Connection(
373                        "gss_service must not be empty".to_string(),
374                    ));
375                }
376                gss_service = value.to_string();
377            }
378            "gss_target" => {
379                if value.is_empty() {
380                    return Err(PgError::Connection(
381                        "gss_target must not be empty".to_string(),
382                    ));
383                }
384                gss_target = Some(value.to_string());
385            }
386            "gss_connect_retries" => {
387                let retries = value.parse::<usize>().map_err(|_| {
388                    PgError::Connection(format!("Invalid gss_connect_retries value: {}", value))
389                })?;
390                if retries > 20 {
391                    return Err(PgError::Connection(
392                        "gss_connect_retries must be <= 20".to_string(),
393                    ));
394                }
395                config.gss_connect_retries = retries;
396            }
397            "gss_retry_base_ms" => {
398                let delay_ms = value.parse::<u64>().map_err(|_| {
399                    PgError::Connection(format!("Invalid gss_retry_base_ms value: {}", value))
400                })?;
401                if delay_ms == 0 {
402                    return Err(PgError::Connection(
403                        "gss_retry_base_ms must be greater than 0".to_string(),
404                    ));
405                }
406                config.gss_retry_base_delay = Duration::from_millis(delay_ms);
407            }
408            "gss_circuit_threshold" => {
409                let threshold = value.parse::<usize>().map_err(|_| {
410                    PgError::Connection(format!("Invalid gss_circuit_threshold value: {}", value))
411                })?;
412                if threshold > 100 {
413                    return Err(PgError::Connection(
414                        "gss_circuit_threshold must be <= 100".to_string(),
415                    ));
416                }
417                config.gss_circuit_breaker_threshold = threshold;
418            }
419            "gss_circuit_window_ms" => {
420                let window_ms = value.parse::<u64>().map_err(|_| {
421                    PgError::Connection(format!("Invalid gss_circuit_window_ms value: {}", value))
422                })?;
423                if window_ms == 0 {
424                    return Err(PgError::Connection(
425                        "gss_circuit_window_ms must be greater than 0".to_string(),
426                    ));
427                }
428                config.gss_circuit_breaker_window = Duration::from_millis(window_ms);
429            }
430            "gss_circuit_cooldown_ms" => {
431                let cooldown_ms = value.parse::<u64>().map_err(|_| {
432                    PgError::Connection(format!("Invalid gss_circuit_cooldown_ms value: {}", value))
433                })?;
434                if cooldown_ms == 0 {
435                    return Err(PgError::Connection(
436                        "gss_circuit_cooldown_ms must be greater than 0".to_string(),
437                    ));
438                }
439                config.gss_circuit_breaker_cooldown = Duration::from_millis(cooldown_ms);
440            }
441            _ => {}
442        }
443    }
444
445    match (sslcert.as_deref(), sslkey.as_deref()) {
446        (Some(cert_path), Some(key_path)) => {
447            let mtls = TlsConfig {
448                client_cert_pem: std::fs::read(cert_path).map_err(|e| {
449                    PgError::Connection(format!("Failed to read sslcert '{}': {}", cert_path, e))
450                })?,
451                client_key_pem: std::fs::read(key_path).map_err(|e| {
452                    PgError::Connection(format!("Failed to read sslkey '{}': {}", key_path, e))
453                })?,
454                ca_cert_pem: config.tls_ca_cert_pem.clone(),
455            };
456            config.mtls = Some(mtls);
457            config.tls_mode = TlsMode::Require;
458        }
459        (Some(_), None) | (None, Some(_)) => {
460            return Err(PgError::Connection(
461                "Both sslcert and sslkey must be provided together".to_string(),
462            ));
463        }
464        (None, None) => {}
465    }
466
467    if let Some(provider) = gss_provider {
468        if provider.eq_ignore_ascii_case("linux_krb5") || provider.eq_ignore_ascii_case("builtin") {
469            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
470            {
471                let provider = crate::driver::gss::linux_krb5_token_provider(
472                    crate::driver::gss::LinuxKrb5ProviderConfig {
473                        host: host.to_string(),
474                        service: gss_service.clone(),
475                        target_name: gss_target.clone(),
476                    },
477                )
478                .map_err(PgError::Auth)?;
479                config.gss_token_provider_ex = Some(provider);
480            }
481            #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
482            {
483                let _ = gss_service;
484                let _ = gss_target;
485                return Err(PgError::Connection(
486                    "gss_provider=linux_krb5 requires qail-pg feature enterprise-gssapi on Linux"
487                        .to_string(),
488                ));
489            }
490        } else if provider.eq_ignore_ascii_case("callback")
491            || provider.eq_ignore_ascii_case("custom")
492        {
493            // External callback wiring is handled by application code.
494        } else {
495            return Err(PgError::Connection(format!(
496                "Invalid gss_provider value: {}",
497                provider
498            )));
499        }
500    }
501
502    Ok(())
503}
504
505/// Parse a postgres URL into (host, port, user, database, password).
506pub(super) fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
507    let url = url.split('?').next().unwrap_or(url);
508    let url = url
509        .trim_start_matches("postgres://")
510        .trim_start_matches("postgresql://");
511
512    let (credentials, host_part) = if url.contains('@') {
513        let mut parts = url.splitn(2, '@');
514        let creds = parts.next().unwrap_or("");
515        let host = parts.next().unwrap_or("localhost/postgres");
516        (Some(creds), host)
517    } else {
518        (None, url)
519    };
520
521    let (host_port, database) = if host_part.contains('/') {
522        let mut parts = host_part.splitn(2, '/');
523        (
524            parts.next().unwrap_or("localhost"),
525            parts.next().unwrap_or("postgres").to_string(),
526        )
527    } else {
528        (host_part, "postgres".to_string())
529    };
530
531    let (host, port) = if host_port.contains(':') {
532        let mut parts = host_port.split(':');
533        let h = parts.next().unwrap_or("localhost").to_string();
534        let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
535        (h, p)
536    } else {
537        (host_port.to_string(), 5432u16)
538    };
539
540    let (user, password) = if let Some(creds) = credentials {
541        if creds.contains(':') {
542            let mut parts = creds.splitn(2, ':');
543            let u = parts.next().unwrap_or("postgres").to_string();
544            let p = parts.next().map(|s| s.to_string());
545            (u, p)
546        } else {
547            (creds.to_string(), None)
548        }
549    } else {
550        ("postgres".to_string(), None)
551    };
552
553    Ok((host, port, user, database, password))
554}
555
556pub(super) fn parse_bool_param(value: &str) -> Option<bool> {
557    match value.trim().to_ascii_lowercase().as_str() {
558        "1" | "true" | "yes" | "on" => Some(true),
559        "0" | "false" | "no" | "off" => Some(false),
560        _ => None,
561    }
562}