Skip to main content

qail_pg/driver/pool/
config.rs

1//! Pool configuration, URL parsing, and builder.
2
3use crate::driver::{
4    AuthSettings, GssEncMode, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
5    ScramChannelBindingMode, TlsConfig, TlsMode,
6};
7use std::time::Duration;
8
9/// Configuration for a PostgreSQL connection pool.
10///
11/// Use the builder pattern to customise settings:
12///
13/// ```ignore
14/// use std::time::Duration;
15/// use qail_pg::driver::pool::PoolConfig;
16/// let config = PoolConfig::new("localhost", 5432, "app", "mydb")
17///     .password("secret")
18///     .max_connections(20)
19///     .acquire_timeout(Duration::from_secs(5));
20/// ```
21#[derive(Clone)]
22pub struct PoolConfig {
23    /// PostgreSQL server hostname or IP address.
24    pub host: String,
25    /// PostgreSQL server port (default: 5432).
26    pub port: u16,
27    /// Database role / user name.
28    pub user: String,
29    /// Target database name.
30    pub database: String,
31    /// Optional password for authentication.
32    pub password: Option<String>,
33    /// Hard upper limit on simultaneous connections (default: 10).
34    pub max_connections: usize,
35    /// Minimum idle connections kept warm in the pool (default: 1).
36    pub min_connections: usize,
37    /// Close idle connections after this duration (default: 10 min).
38    pub idle_timeout: Duration,
39    /// Maximum time to wait when acquiring a connection (default: 30s).
40    pub acquire_timeout: Duration,
41    /// TCP connect timeout for new connections (default: 10s).
42    pub connect_timeout: Duration,
43    /// Optional maximum lifetime of any connection in the pool.
44    pub max_lifetime: Option<Duration>,
45    /// When `true`, run a health check (`SELECT 1`) before handing out a connection.
46    pub test_on_acquire: bool,
47    /// TLS mode for new connections.
48    pub tls_mode: TlsMode,
49    /// Optional custom CA bundle (PEM) for server certificate validation.
50    pub tls_ca_cert_pem: Option<Vec<u8>>,
51    /// Optional mTLS client certificate/key configuration.
52    pub mtls: Option<TlsConfig>,
53    /// Optional callback for Kerberos/GSS/SSPI token generation.
54    pub gss_token_provider: Option<GssTokenProvider>,
55    /// Optional stateful callback for Kerberos/GSS/SSPI token generation.
56    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
57    /// Number of retries for transient GSS/Kerberos connection failures.
58    pub gss_connect_retries: usize,
59    /// Base delay for GSS/Kerberos connect retry backoff.
60    pub gss_retry_base_delay: Duration,
61    /// Transient GSS failures in one window before opening the local circuit.
62    pub gss_circuit_breaker_threshold: usize,
63    /// Rolling window used to count transient GSS failures.
64    pub gss_circuit_breaker_window: Duration,
65    /// Cooldown duration while the local GSS circuit stays open.
66    pub gss_circuit_breaker_cooldown: Duration,
67    /// Password-auth policy.
68    pub auth_settings: AuthSettings,
69    /// GSSAPI session encryption mode (`gssencmode` URL parameter).
70    pub gss_enc_mode: GssEncMode,
71}
72
73impl PoolConfig {
74    /// Create a new pool configuration with **production-safe** defaults.
75    ///
76    /// Defaults: `tls_mode = Require`, `auth_settings = scram_only()`.
77    /// For local development without TLS, use [`PoolConfig::new_dev`].
78    ///
79    /// # Arguments
80    ///
81    /// * `host` — PostgreSQL server hostname or IP.
82    /// * `port` — TCP port (typically 5432).
83    /// * `user` — PostgreSQL role name.
84    /// * `database` — Target database name.
85    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
86        Self {
87            host: host.to_string(),
88            port,
89            user: user.to_string(),
90            database: database.to_string(),
91            password: None,
92            max_connections: 10,
93            min_connections: 1,
94            idle_timeout: Duration::from_secs(600), // 10 minutes
95            acquire_timeout: Duration::from_secs(30), // 30 seconds
96            connect_timeout: Duration::from_secs(10), // 10 seconds
97            max_lifetime: None,                     // No limit by default
98            test_on_acquire: false,                 // Disabled by default for performance
99            tls_mode: TlsMode::Prefer,
100            tls_ca_cert_pem: None,
101            mtls: None,
102            gss_token_provider: None,
103            gss_token_provider_ex: None,
104            gss_connect_retries: 2,
105            gss_retry_base_delay: Duration::from_millis(150),
106            gss_circuit_breaker_threshold: 8,
107            gss_circuit_breaker_window: Duration::from_secs(30),
108            gss_circuit_breaker_cooldown: Duration::from_secs(15),
109            auth_settings: AuthSettings::scram_only(),
110            gss_enc_mode: GssEncMode::Disable,
111        }
112    }
113
114    /// Create a pool configuration with **permissive** defaults for local development.
115    ///
116    /// Defaults: `tls_mode = Disable`, `auth_settings = default()` (accepts any auth).
117    /// Do NOT use in production.
118    pub fn new_dev(host: &str, port: u16, user: &str, database: &str) -> Self {
119        let mut config = Self::new(host, port, user, database);
120        config.tls_mode = TlsMode::Disable;
121        config.auth_settings = AuthSettings::default();
122        config
123    }
124
125    /// Set password for authentication.
126    pub fn password(mut self, password: &str) -> Self {
127        self.password = Some(password.to_string());
128        self
129    }
130
131    /// Set maximum simultaneous connections.
132    pub fn max_connections(mut self, max: usize) -> Self {
133        self.max_connections = max;
134        self
135    }
136
137    /// Set minimum idle connections.
138    pub fn min_connections(mut self, min: usize) -> Self {
139        self.min_connections = min;
140        self
141    }
142
143    /// Set idle timeout (connections idle longer than this are closed).
144    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
145        self.idle_timeout = timeout;
146        self
147    }
148
149    /// Set acquire timeout (max wait time when getting a connection).
150    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
151        self.acquire_timeout = timeout;
152        self
153    }
154
155    /// Set connect timeout (max time to establish new connection).
156    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
157        self.connect_timeout = timeout;
158        self
159    }
160
161    /// Set maximum lifetime of a connection before recycling.
162    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
163        self.max_lifetime = Some(lifetime);
164        self
165    }
166
167    /// Enable connection validation on acquire.
168    pub fn test_on_acquire(mut self, enabled: bool) -> Self {
169        self.test_on_acquire = enabled;
170        self
171    }
172
173    /// Set TLS mode for pool connections.
174    pub fn tls_mode(mut self, mode: TlsMode) -> Self {
175        self.tls_mode = mode;
176        self
177    }
178
179    /// Set custom CA bundle (PEM) for TLS validation.
180    pub fn tls_ca_cert_pem(mut self, ca_cert_pem: Vec<u8>) -> Self {
181        self.tls_ca_cert_pem = Some(ca_cert_pem);
182        self
183    }
184
185    /// Enable mTLS for pool connections.
186    pub fn mtls(mut self, config: TlsConfig) -> Self {
187        self.mtls = Some(config);
188        self.tls_mode = TlsMode::Require;
189        self
190    }
191
192    /// Set Kerberos/GSS/SSPI token provider callback.
193    pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
194        self.gss_token_provider = Some(provider);
195        self
196    }
197
198    /// Set a stateful Kerberos/GSS/SSPI token provider.
199    pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
200        self.gss_token_provider_ex = Some(provider);
201        self
202    }
203
204    /// Set retry count for transient GSS/Kerberos connection failures.
205    pub fn gss_connect_retries(mut self, retries: usize) -> Self {
206        self.gss_connect_retries = retries;
207        self
208    }
209
210    /// Set base backoff delay for GSS/Kerberos connection retry.
211    pub fn gss_retry_base_delay(mut self, delay: Duration) -> Self {
212        self.gss_retry_base_delay = delay;
213        self
214    }
215
216    /// Set failure threshold for opening local GSS circuit breaker.
217    pub fn gss_circuit_breaker_threshold(mut self, threshold: usize) -> Self {
218        self.gss_circuit_breaker_threshold = threshold;
219        self
220    }
221
222    /// Set rolling failure window for GSS circuit breaker.
223    pub fn gss_circuit_breaker_window(mut self, window: Duration) -> Self {
224        self.gss_circuit_breaker_window = window;
225        self
226    }
227
228    /// Set cooldown duration for open GSS circuit breaker.
229    pub fn gss_circuit_breaker_cooldown(mut self, cooldown: Duration) -> Self {
230        self.gss_circuit_breaker_cooldown = cooldown;
231        self
232    }
233
234    /// Set authentication policy.
235    pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
236        self.auth_settings = settings;
237        self
238    }
239
240    /// Create a `PoolConfig` from a centralized `QailConfig`.
241    ///
242    /// Parses `postgres.url` for host/port/user/database/password
243    /// and applies pool tuning from `[postgres]` section.
244    pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
245        let pg = &qail.postgres;
246        let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
247
248        let mut config = PoolConfig::new(&host, port, &user, &database)
249            .max_connections(pg.max_connections)
250            .min_connections(pg.min_connections)
251            .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
252            .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
253            .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
254            .test_on_acquire(pg.test_on_acquire);
255
256        if let Some(ref pw) = password {
257            config = config.password(pw);
258        }
259
260        // Optional URL query params for enterprise auth/TLS settings.
261        if let Some(query) = pg.url.split('?').nth(1) {
262            apply_url_query_params(&mut config, query, &host)?;
263        }
264
265        Ok(config)
266    }
267}
268
269/// Apply enterprise auth/TLS query parameters to a `PoolConfig`.
270///
271/// Shared between `PoolConfig::from_qail_config` and `PgDriver::connect_url`
272/// so that both paths support the same set of URL knobs.
273#[allow(unused_variables)]
274pub(crate) fn apply_url_query_params(
275    config: &mut PoolConfig,
276    query: &str,
277    host: &str,
278) -> PgResult<()> {
279    let mut sslcert: Option<String> = None;
280    let mut sslkey: Option<String> = None;
281    let mut gss_provider: Option<String> = None;
282    let mut gss_service = "postgres".to_string();
283    let mut gss_target: Option<String> = None;
284
285    for pair in query.split('&').filter(|p| !p.is_empty()) {
286        let mut kv = pair.splitn(2, '=');
287        let key = kv.next().unwrap_or_default().trim();
288        let value = kv.next().unwrap_or_default().trim();
289
290        match key {
291            "sslmode" => {
292                let mode = TlsMode::parse_sslmode(value).ok_or_else(|| {
293                    PgError::Connection(format!("Invalid sslmode value: {}", value))
294                })?;
295                config.tls_mode = mode;
296            }
297            "gssencmode" => {
298                let mode = GssEncMode::parse_gssencmode(value).ok_or_else(|| {
299                    PgError::Connection(format!("Invalid gssencmode value: {}", value))
300                })?;
301                config.gss_enc_mode = mode;
302            }
303            "sslrootcert" => {
304                let ca_pem = std::fs::read(value).map_err(|e| {
305                    PgError::Connection(format!("Failed to read sslrootcert '{}': {}", value, e))
306                })?;
307                config.tls_ca_cert_pem = Some(ca_pem);
308            }
309            "sslcert" => sslcert = Some(value.to_string()),
310            "sslkey" => sslkey = Some(value.to_string()),
311            "channel_binding" => {
312                let mode = ScramChannelBindingMode::parse(value).ok_or_else(|| {
313                    PgError::Connection(format!("Invalid channel_binding value: {}", value))
314                })?;
315                config.auth_settings.channel_binding = mode;
316            }
317            "auth_scram" => {
318                let enabled = parse_bool_param(value).ok_or_else(|| {
319                    PgError::Connection(format!("Invalid auth_scram value: {}", value))
320                })?;
321                config.auth_settings.allow_scram_sha_256 = enabled;
322            }
323            "auth_md5" => {
324                let enabled = parse_bool_param(value).ok_or_else(|| {
325                    PgError::Connection(format!("Invalid auth_md5 value: {}", value))
326                })?;
327                config.auth_settings.allow_md5_password = enabled;
328            }
329            "auth_cleartext" => {
330                let enabled = parse_bool_param(value).ok_or_else(|| {
331                    PgError::Connection(format!("Invalid auth_cleartext value: {}", value))
332                })?;
333                config.auth_settings.allow_cleartext_password = enabled;
334            }
335            "auth_kerberos" => {
336                let enabled = parse_bool_param(value).ok_or_else(|| {
337                    PgError::Connection(format!("Invalid auth_kerberos value: {}", value))
338                })?;
339                config.auth_settings.allow_kerberos_v5 = enabled;
340            }
341            "auth_gssapi" => {
342                let enabled = parse_bool_param(value).ok_or_else(|| {
343                    PgError::Connection(format!("Invalid auth_gssapi value: {}", value))
344                })?;
345                config.auth_settings.allow_gssapi = enabled;
346            }
347            "auth_sspi" => {
348                let enabled = parse_bool_param(value).ok_or_else(|| {
349                    PgError::Connection(format!("Invalid auth_sspi value: {}", value))
350                })?;
351                config.auth_settings.allow_sspi = enabled;
352            }
353            "auth_mode" => {
354                if value.eq_ignore_ascii_case("scram_only") {
355                    config.auth_settings = AuthSettings::scram_only();
356                } else if value.eq_ignore_ascii_case("gssapi_only") {
357                    config.auth_settings = AuthSettings::gssapi_only();
358                } else if value.eq_ignore_ascii_case("compat")
359                    || value.eq_ignore_ascii_case("default")
360                {
361                    config.auth_settings = AuthSettings::default();
362                } else {
363                    return Err(PgError::Connection(format!(
364                        "Invalid auth_mode value: {}",
365                        value
366                    )));
367                }
368            }
369            "gss_provider" => gss_provider = Some(value.to_string()),
370            "gss_service" => {
371                if value.is_empty() {
372                    return Err(PgError::Connection(
373                        "gss_service must not be empty".to_string(),
374                    ));
375                }
376                gss_service = value.to_string();
377            }
378            // libpq alias for kerberos service principal name component.
379            "krbsrvname" => {
380                if value.is_empty() {
381                    return Err(PgError::Connection(
382                        "gss_service must not be empty".to_string(),
383                    ));
384                }
385                gss_service = value.to_string();
386            }
387            "gss_target" => {
388                if value.is_empty() {
389                    return Err(PgError::Connection(
390                        "gss_target must not be empty".to_string(),
391                    ));
392                }
393                gss_target = Some(value.to_string());
394            }
395            // libpq alias for GSS target hostname override.
396            "gsshostname" => {
397                if value.is_empty() {
398                    return Err(PgError::Connection(
399                        "gss_target must not be empty".to_string(),
400                    ));
401                }
402                gss_target = Some(value.to_string());
403            }
404            // libpq compatibility knob; accepted values are validated but
405            // provider selection remains controlled by qail `gss_provider`.
406            "gsslib" => match value.trim().to_ascii_lowercase().as_str() {
407                "gssapi" | "sspi" => {}
408                _ => {
409                    return Err(PgError::Connection(format!(
410                        "Invalid gsslib value: {} (expected gssapi or sspi)",
411                        value
412                    )));
413                }
414            },
415            "gss_connect_retries" => {
416                let retries = value.parse::<usize>().map_err(|_| {
417                    PgError::Connection(format!("Invalid gss_connect_retries value: {}", value))
418                })?;
419                if retries > 20 {
420                    return Err(PgError::Connection(
421                        "gss_connect_retries must be <= 20".to_string(),
422                    ));
423                }
424                config.gss_connect_retries = retries;
425            }
426            "gss_retry_base_ms" => {
427                let delay_ms = value.parse::<u64>().map_err(|_| {
428                    PgError::Connection(format!("Invalid gss_retry_base_ms value: {}", value))
429                })?;
430                if delay_ms == 0 {
431                    return Err(PgError::Connection(
432                        "gss_retry_base_ms must be greater than 0".to_string(),
433                    ));
434                }
435                config.gss_retry_base_delay = Duration::from_millis(delay_ms);
436            }
437            "gss_circuit_threshold" => {
438                let threshold = value.parse::<usize>().map_err(|_| {
439                    PgError::Connection(format!("Invalid gss_circuit_threshold value: {}", value))
440                })?;
441                if threshold > 100 {
442                    return Err(PgError::Connection(
443                        "gss_circuit_threshold must be <= 100".to_string(),
444                    ));
445                }
446                config.gss_circuit_breaker_threshold = threshold;
447            }
448            "gss_circuit_window_ms" => {
449                let window_ms = value.parse::<u64>().map_err(|_| {
450                    PgError::Connection(format!("Invalid gss_circuit_window_ms value: {}", value))
451                })?;
452                if window_ms == 0 {
453                    return Err(PgError::Connection(
454                        "gss_circuit_window_ms must be greater than 0".to_string(),
455                    ));
456                }
457                config.gss_circuit_breaker_window = Duration::from_millis(window_ms);
458            }
459            "gss_circuit_cooldown_ms" => {
460                let cooldown_ms = value.parse::<u64>().map_err(|_| {
461                    PgError::Connection(format!("Invalid gss_circuit_cooldown_ms value: {}", value))
462                })?;
463                if cooldown_ms == 0 {
464                    return Err(PgError::Connection(
465                        "gss_circuit_cooldown_ms must be greater than 0".to_string(),
466                    ));
467                }
468                config.gss_circuit_breaker_cooldown = Duration::from_millis(cooldown_ms);
469            }
470            _ => {}
471        }
472    }
473
474    match (sslcert.as_deref(), sslkey.as_deref()) {
475        (Some(cert_path), Some(key_path)) => {
476            let mtls = TlsConfig {
477                client_cert_pem: std::fs::read(cert_path).map_err(|e| {
478                    PgError::Connection(format!("Failed to read sslcert '{}': {}", cert_path, e))
479                })?,
480                client_key_pem: std::fs::read(key_path).map_err(|e| {
481                    PgError::Connection(format!("Failed to read sslkey '{}': {}", key_path, e))
482                })?,
483                ca_cert_pem: config.tls_ca_cert_pem.clone(),
484            };
485            config.mtls = Some(mtls);
486            config.tls_mode = TlsMode::Require;
487        }
488        (Some(_), None) | (None, Some(_)) => {
489            return Err(PgError::Connection(
490                "Both sslcert and sslkey must be provided together".to_string(),
491            ));
492        }
493        (None, None) => {}
494    }
495
496    if let Some(provider) = gss_provider {
497        if provider.eq_ignore_ascii_case("linux_krb5") || provider.eq_ignore_ascii_case("builtin") {
498            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
499            {
500                let provider = crate::driver::gss::linux_krb5_token_provider(
501                    crate::driver::gss::LinuxKrb5ProviderConfig {
502                        host: host.to_string(),
503                        service: gss_service.clone(),
504                        target_name: gss_target.clone(),
505                    },
506                )
507                .map_err(PgError::Auth)?;
508                config.gss_token_provider_ex = Some(provider);
509            }
510            #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
511            {
512                let _ = gss_service;
513                let _ = gss_target;
514                return Err(PgError::Connection(
515                    "gss_provider=linux_krb5 requires qail-pg feature enterprise-gssapi on Linux"
516                        .to_string(),
517                ));
518            }
519        } else if provider.eq_ignore_ascii_case("callback")
520            || provider.eq_ignore_ascii_case("custom")
521        {
522            // External callback wiring is handled by application code.
523        } else {
524            return Err(PgError::Connection(format!(
525                "Invalid gss_provider value: {}",
526                provider
527            )));
528        }
529    }
530
531    Ok(())
532}
533
534/// Parse a postgres URL into (host, port, user, database, password).
535pub(super) fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
536    let url = url.split('?').next().unwrap_or(url);
537    let url = url
538        .trim_start_matches("postgres://")
539        .trim_start_matches("postgresql://");
540
541    let (credentials, host_part) = if url.contains('@') {
542        let mut parts = url.splitn(2, '@');
543        let creds = parts.next().unwrap_or("");
544        let host = parts.next().unwrap_or("localhost/postgres");
545        (Some(creds), host)
546    } else {
547        (None, url)
548    };
549
550    let (host_port, database) = if host_part.contains('/') {
551        let mut parts = host_part.splitn(2, '/');
552        (
553            parts.next().unwrap_or("localhost"),
554            parts.next().unwrap_or("postgres").to_string(),
555        )
556    } else {
557        (host_part, "postgres".to_string())
558    };
559
560    let (host, port) = if host_port.contains(':') {
561        let mut parts = host_port.split(':');
562        let h = parts.next().unwrap_or("localhost").to_string();
563        let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
564        (h, p)
565    } else {
566        (host_port.to_string(), 5432u16)
567    };
568
569    let (user, password) = if let Some(creds) = credentials {
570        if creds.contains(':') {
571            let mut parts = creds.splitn(2, ':');
572            let u = parts.next().unwrap_or("postgres").to_string();
573            let p = parts.next().map(|s| s.to_string());
574            (u, p)
575        } else {
576            (creds.to_string(), None)
577        }
578    } else {
579        ("postgres".to_string(), None)
580    };
581
582    Ok((host, port, user, database, password))
583}
584
585pub(super) fn parse_bool_param(value: &str) -> Option<bool> {
586    match value.trim().to_ascii_lowercase().as_str() {
587        "1" | "true" | "yes" | "on" => Some(true),
588        "0" | "false" | "no" | "off" => Some(false),
589        _ => None,
590    }
591}