1use crate::driver::{
4 AuthSettings, GssEncMode, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
5 ScramChannelBindingMode, TlsConfig, TlsMode,
6};
7use std::time::Duration;
8
/// Configuration for a PostgreSQL connection pool.
///
/// Construct with `PoolConfig::new` (production-leaning defaults),
/// `PoolConfig::new_dev` (TLS disabled, compat auth) or
/// `PoolConfig::from_qail_config` (parsed from a connection URL), then adjust
/// with the chained builder methods. All fields are public.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host (parsed from the URL by `from_qail_config`).
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Login user name.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Login password, if any (`None` in `new`).
    pub password: Option<String>,
    /// Maximum number of pooled connections (default 10).
    pub max_connections: usize,
    /// Minimum number of pooled connections (default 1). Not validated against
    /// `max_connections` in this file.
    pub min_connections: usize,
    /// Idle timeout for pooled connections (default 600 s).
    pub idle_timeout: Duration,
    /// Timeout for acquiring a connection from the pool (default 30 s).
    pub acquire_timeout: Duration,
    /// Timeout for establishing a new connection (default 10 s).
    pub connect_timeout: Duration,
    /// Optional maximum connection lifetime (default `None`).
    pub max_lifetime: Option<Duration>,
    /// Size setting for leaked-connection cleanup (default 64).
    /// NOTE(review): exact semantics live in the pool implementation — confirm there.
    pub leaked_cleanup_queue: usize,
    /// Whether connections are tested when acquired (default `false`).
    pub test_on_acquire: bool,
    /// TLS negotiation mode: `Prefer` in `new`, `Disable` in `new_dev`, set by the
    /// `sslmode` URL parameter, and possibly raised when a client certificate is
    /// configured (see `mtls`).
    pub tls_mode: TlsMode,
    /// PEM-encoded CA certificate(s), e.g. read from the `sslrootcert` URL file.
    pub tls_ca_cert_pem: Option<Vec<u8>>,
    /// Client certificate/key for mutual TLS (set by `mtls` or by the
    /// `sslcert` + `sslkey` URL parameters).
    pub mtls: Option<TlsConfig>,
    /// Caller-supplied GSS token provider.
    pub gss_token_provider: Option<GssTokenProvider>,
    /// Extended GSS token provider; also installed by `gss_provider=linux_krb5`
    /// (Linux builds with the `enterprise-gssapi` feature).
    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
    /// GSS connect retry count (default 2; the URL parameter caps it at 20).
    pub gss_connect_retries: usize,
    /// Base delay between GSS connect retries (default 150 ms).
    pub gss_retry_base_delay: Duration,
    /// GSS circuit-breaker failure threshold (default 8; URL cap 100).
    pub gss_circuit_breaker_threshold: usize,
    /// GSS circuit-breaker observation window (default 30 s).
    pub gss_circuit_breaker_window: Duration,
    /// GSS circuit-breaker cooldown (default 15 s).
    pub gss_circuit_breaker_cooldown: Duration,
    /// Allowed authentication methods: `AuthSettings::scram_only()` in `new`,
    /// `AuthSettings::default()` in `new_dev`.
    pub auth_settings: AuthSettings,
    /// GSSAPI encryption mode (default `Disable`; set by the `gssencmode` URL
    /// parameter).
    pub gss_enc_mode: GssEncMode,
}
77
78impl PoolConfig {
79 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
91 Self {
92 host: host.to_string(),
93 port,
94 user: user.to_string(),
95 database: database.to_string(),
96 password: None,
97 max_connections: 10,
98 min_connections: 1,
99 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, leaked_cleanup_queue: 64, test_on_acquire: false, tls_mode: TlsMode::Prefer,
106 tls_ca_cert_pem: None,
107 mtls: None,
108 gss_token_provider: None,
109 gss_token_provider_ex: None,
110 gss_connect_retries: 2,
111 gss_retry_base_delay: Duration::from_millis(150),
112 gss_circuit_breaker_threshold: 8,
113 gss_circuit_breaker_window: Duration::from_secs(30),
114 gss_circuit_breaker_cooldown: Duration::from_secs(15),
115 auth_settings: AuthSettings::scram_only(),
116 gss_enc_mode: GssEncMode::Disable,
117 }
118 }
119
120 pub fn new_dev(host: &str, port: u16, user: &str, database: &str) -> Self {
125 let mut config = Self::new(host, port, user, database);
126 config.tls_mode = TlsMode::Disable;
127 config.auth_settings = AuthSettings::default();
128 config
129 }
130
131 pub fn password(mut self, password: &str) -> Self {
133 self.password = Some(password.to_string());
134 self
135 }
136
137 pub fn max_connections(mut self, max: usize) -> Self {
139 self.max_connections = max;
140 self
141 }
142
143 pub fn min_connections(mut self, min: usize) -> Self {
145 self.min_connections = min;
146 self
147 }
148
149 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
151 self.idle_timeout = timeout;
152 self
153 }
154
155 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
157 self.acquire_timeout = timeout;
158 self
159 }
160
161 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
163 self.connect_timeout = timeout;
164 self
165 }
166
167 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
169 self.max_lifetime = Some(lifetime);
170 self
171 }
172
173 pub fn leaked_cleanup_queue(mut self, max: usize) -> Self {
177 self.leaked_cleanup_queue = max;
178 self
179 }
180
181 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
183 self.test_on_acquire = enabled;
184 self
185 }
186
187 pub fn tls_mode(mut self, mode: TlsMode) -> Self {
189 self.tls_mode = mode;
190 self
191 }
192
193 pub fn tls_ca_cert_pem(mut self, ca_cert_pem: Vec<u8>) -> Self {
195 self.tls_ca_cert_pem = Some(ca_cert_pem);
196 self
197 }
198
199 pub fn mtls(mut self, config: TlsConfig) -> Self {
201 self.mtls = Some(config);
202 self.tls_mode = TlsMode::Require;
203 self
204 }
205
206 pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
208 self.gss_token_provider = Some(provider);
209 self
210 }
211
212 pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
214 self.gss_token_provider_ex = Some(provider);
215 self
216 }
217
218 pub fn gss_connect_retries(mut self, retries: usize) -> Self {
220 self.gss_connect_retries = retries;
221 self
222 }
223
224 pub fn gss_retry_base_delay(mut self, delay: Duration) -> Self {
226 self.gss_retry_base_delay = delay;
227 self
228 }
229
230 pub fn gss_circuit_breaker_threshold(mut self, threshold: usize) -> Self {
232 self.gss_circuit_breaker_threshold = threshold;
233 self
234 }
235
236 pub fn gss_circuit_breaker_window(mut self, window: Duration) -> Self {
238 self.gss_circuit_breaker_window = window;
239 self
240 }
241
242 pub fn gss_circuit_breaker_cooldown(mut self, cooldown: Duration) -> Self {
244 self.gss_circuit_breaker_cooldown = cooldown;
245 self
246 }
247
248 pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
250 self.auth_settings = settings;
251 self
252 }
253
254 pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
259 let pg = &qail.postgres;
260 let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
261
262 let mut config = PoolConfig::new(&host, port, &user, &database)
263 .max_connections(pg.max_connections)
264 .min_connections(pg.min_connections)
265 .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
266 .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
267 .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
268 .test_on_acquire(pg.test_on_acquire);
269
270 if let Some(ref pw) = password {
271 config = config.password(pw);
272 }
273
274 if let Some(query) = pg.url.split('?').nth(1) {
276 apply_url_query_params(&mut config, query, &host)?;
277 }
278
279 Ok(config)
280 }
281}
282
283#[allow(unused_variables)]
288pub(crate) fn apply_url_query_params(
289 config: &mut PoolConfig,
290 query: &str,
291 host: &str,
292) -> PgResult<()> {
293 let mut sslcert: Option<String> = None;
294 let mut sslkey: Option<String> = None;
295 let mut gss_provider: Option<String> = None;
296 let mut gss_service = "postgres".to_string();
297 let mut gss_target: Option<String> = None;
298
299 for pair in query.split('&').filter(|p| !p.is_empty()) {
300 let mut kv = pair.splitn(2, '=');
301 let key = kv.next().unwrap_or_default().trim();
302 let value = kv.next().unwrap_or_default().trim();
303
304 match key {
305 "sslmode" => {
306 let mode = TlsMode::parse_sslmode(value).ok_or_else(|| {
307 PgError::Connection(format!("Invalid sslmode value: {}", value))
308 })?;
309 config.tls_mode = mode;
310 }
311 "gssencmode" => {
312 let mode = GssEncMode::parse_gssencmode(value).ok_or_else(|| {
313 PgError::Connection(format!("Invalid gssencmode value: {}", value))
314 })?;
315 config.gss_enc_mode = mode;
316 }
317 "sslrootcert" => {
318 let ca_pem = std::fs::read(value).map_err(|e| {
319 PgError::Connection(format!("Failed to read sslrootcert '{}': {}", value, e))
320 })?;
321 config.tls_ca_cert_pem = Some(ca_pem);
322 }
323 "sslcert" => sslcert = Some(value.to_string()),
324 "sslkey" => sslkey = Some(value.to_string()),
325 "channel_binding" => {
326 let mode = ScramChannelBindingMode::parse(value).ok_or_else(|| {
327 PgError::Connection(format!("Invalid channel_binding value: {}", value))
328 })?;
329 config.auth_settings.channel_binding = mode;
330 }
331 "auth_scram" => {
332 let enabled = parse_bool_param(value).ok_or_else(|| {
333 PgError::Connection(format!("Invalid auth_scram value: {}", value))
334 })?;
335 config.auth_settings.allow_scram_sha_256 = enabled;
336 }
337 "auth_md5" => {
338 let enabled = parse_bool_param(value).ok_or_else(|| {
339 PgError::Connection(format!("Invalid auth_md5 value: {}", value))
340 })?;
341 config.auth_settings.allow_md5_password = enabled;
342 }
343 "auth_cleartext" => {
344 let enabled = parse_bool_param(value).ok_or_else(|| {
345 PgError::Connection(format!("Invalid auth_cleartext value: {}", value))
346 })?;
347 config.auth_settings.allow_cleartext_password = enabled;
348 }
349 "auth_kerberos" => {
350 let enabled = parse_bool_param(value).ok_or_else(|| {
351 PgError::Connection(format!("Invalid auth_kerberos value: {}", value))
352 })?;
353 config.auth_settings.allow_kerberos_v5 = enabled;
354 }
355 "auth_gssapi" => {
356 let enabled = parse_bool_param(value).ok_or_else(|| {
357 PgError::Connection(format!("Invalid auth_gssapi value: {}", value))
358 })?;
359 config.auth_settings.allow_gssapi = enabled;
360 }
361 "auth_sspi" => {
362 let enabled = parse_bool_param(value).ok_or_else(|| {
363 PgError::Connection(format!("Invalid auth_sspi value: {}", value))
364 })?;
365 config.auth_settings.allow_sspi = enabled;
366 }
367 "auth_mode" => {
368 if value.eq_ignore_ascii_case("scram_only") {
369 config.auth_settings = AuthSettings::scram_only();
370 } else if value.eq_ignore_ascii_case("gssapi_only") {
371 config.auth_settings = AuthSettings::gssapi_only();
372 } else if value.eq_ignore_ascii_case("compat")
373 || value.eq_ignore_ascii_case("default")
374 {
375 config.auth_settings = AuthSettings::default();
376 } else {
377 return Err(PgError::Connection(format!(
378 "Invalid auth_mode value: {}",
379 value
380 )));
381 }
382 }
383 "gss_provider" => gss_provider = Some(value.to_string()),
384 "gss_service" => {
385 if value.is_empty() {
386 return Err(PgError::Connection(
387 "gss_service must not be empty".to_string(),
388 ));
389 }
390 gss_service = value.to_string();
391 }
392 "krbsrvname" => {
394 if value.is_empty() {
395 return Err(PgError::Connection(
396 "gss_service must not be empty".to_string(),
397 ));
398 }
399 gss_service = value.to_string();
400 }
401 "gss_target" => {
402 if value.is_empty() {
403 return Err(PgError::Connection(
404 "gss_target must not be empty".to_string(),
405 ));
406 }
407 gss_target = Some(value.to_string());
408 }
409 "gsshostname" => {
411 if value.is_empty() {
412 return Err(PgError::Connection(
413 "gss_target must not be empty".to_string(),
414 ));
415 }
416 gss_target = Some(value.to_string());
417 }
418 "gsslib" => match value.trim().to_ascii_lowercase().as_str() {
421 "gssapi" | "sspi" => {}
422 _ => {
423 return Err(PgError::Connection(format!(
424 "Invalid gsslib value: {} (expected gssapi or sspi)",
425 value
426 )));
427 }
428 },
429 "gss_connect_retries" => {
430 let retries = value.parse::<usize>().map_err(|_| {
431 PgError::Connection(format!("Invalid gss_connect_retries value: {}", value))
432 })?;
433 if retries > 20 {
434 return Err(PgError::Connection(
435 "gss_connect_retries must be <= 20".to_string(),
436 ));
437 }
438 config.gss_connect_retries = retries;
439 }
440 "gss_retry_base_ms" => {
441 let delay_ms = value.parse::<u64>().map_err(|_| {
442 PgError::Connection(format!("Invalid gss_retry_base_ms value: {}", value))
443 })?;
444 if delay_ms == 0 {
445 return Err(PgError::Connection(
446 "gss_retry_base_ms must be greater than 0".to_string(),
447 ));
448 }
449 config.gss_retry_base_delay = Duration::from_millis(delay_ms);
450 }
451 "gss_circuit_threshold" => {
452 let threshold = value.parse::<usize>().map_err(|_| {
453 PgError::Connection(format!("Invalid gss_circuit_threshold value: {}", value))
454 })?;
455 if threshold > 100 {
456 return Err(PgError::Connection(
457 "gss_circuit_threshold must be <= 100".to_string(),
458 ));
459 }
460 config.gss_circuit_breaker_threshold = threshold;
461 }
462 "gss_circuit_window_ms" => {
463 let window_ms = value.parse::<u64>().map_err(|_| {
464 PgError::Connection(format!("Invalid gss_circuit_window_ms value: {}", value))
465 })?;
466 if window_ms == 0 {
467 return Err(PgError::Connection(
468 "gss_circuit_window_ms must be greater than 0".to_string(),
469 ));
470 }
471 config.gss_circuit_breaker_window = Duration::from_millis(window_ms);
472 }
473 "gss_circuit_cooldown_ms" => {
474 let cooldown_ms = value.parse::<u64>().map_err(|_| {
475 PgError::Connection(format!("Invalid gss_circuit_cooldown_ms value: {}", value))
476 })?;
477 if cooldown_ms == 0 {
478 return Err(PgError::Connection(
479 "gss_circuit_cooldown_ms must be greater than 0".to_string(),
480 ));
481 }
482 config.gss_circuit_breaker_cooldown = Duration::from_millis(cooldown_ms);
483 }
484 _ => {}
485 }
486 }
487
488 match (sslcert.as_deref(), sslkey.as_deref()) {
489 (Some(cert_path), Some(key_path)) => {
490 let mtls = TlsConfig {
491 client_cert_pem: std::fs::read(cert_path).map_err(|e| {
492 PgError::Connection(format!("Failed to read sslcert '{}': {}", cert_path, e))
493 })?,
494 client_key_pem: std::fs::read(key_path).map_err(|e| {
495 PgError::Connection(format!("Failed to read sslkey '{}': {}", key_path, e))
496 })?,
497 ca_cert_pem: config.tls_ca_cert_pem.clone(),
498 };
499 config.mtls = Some(mtls);
500 config.tls_mode = TlsMode::Require;
501 }
502 (Some(_), None) | (None, Some(_)) => {
503 return Err(PgError::Connection(
504 "Both sslcert and sslkey must be provided together".to_string(),
505 ));
506 }
507 (None, None) => {}
508 }
509
510 if let Some(provider) = gss_provider {
511 if provider.eq_ignore_ascii_case("linux_krb5") || provider.eq_ignore_ascii_case("builtin") {
512 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
513 {
514 let provider = crate::driver::gss::linux_krb5_token_provider(
515 crate::driver::gss::LinuxKrb5ProviderConfig {
516 host: host.to_string(),
517 service: gss_service.clone(),
518 target_name: gss_target.clone(),
519 },
520 )
521 .map_err(PgError::Auth)?;
522 config.gss_token_provider_ex = Some(provider);
523 }
524 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
525 {
526 let _ = gss_service;
527 let _ = gss_target;
528 return Err(PgError::Connection(
529 "gss_provider=linux_krb5 requires qail-pg feature enterprise-gssapi on Linux"
530 .to_string(),
531 ));
532 }
533 } else if provider.eq_ignore_ascii_case("callback")
534 || provider.eq_ignore_ascii_case("custom")
535 {
536 } else {
538 return Err(PgError::Connection(format!(
539 "Invalid gss_provider value: {}",
540 provider
541 )));
542 }
543 }
544
545 Ok(())
546}
547
548pub(super) fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
550 let url = url.split('?').next().unwrap_or(url);
551 let url = url
552 .trim_start_matches("postgres://")
553 .trim_start_matches("postgresql://");
554
555 let (credentials, host_part) = if url.contains('@') {
556 let mut parts = url.splitn(2, '@');
557 let creds = parts.next().unwrap_or("");
558 let host = parts.next().unwrap_or("localhost/postgres");
559 (Some(creds), host)
560 } else {
561 (None, url)
562 };
563
564 let (host_port, database) = if host_part.contains('/') {
565 let mut parts = host_part.splitn(2, '/');
566 (
567 parts.next().unwrap_or("localhost"),
568 parts.next().unwrap_or("postgres").to_string(),
569 )
570 } else {
571 (host_part, "postgres".to_string())
572 };
573
574 let (host, port) = if host_port.contains(':') {
575 let mut parts = host_port.split(':');
576 let h = parts.next().unwrap_or("localhost").to_string();
577 let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
578 (h, p)
579 } else {
580 (host_port.to_string(), 5432u16)
581 };
582
583 let (user, password) = if let Some(creds) = credentials {
584 if creds.contains(':') {
585 let mut parts = creds.splitn(2, ':');
586 let u = parts.next().unwrap_or("postgres").to_string();
587 let p = parts.next().map(|s| s.to_string());
588 (u, p)
589 } else {
590 (creds.to_string(), None)
591 }
592 } else {
593 ("postgres".to_string(), None)
594 };
595
596 Ok((host, port, user, database, password))
597}
598
599pub(super) fn parse_bool_param(value: &str) -> Option<bool> {
600 match value.trim().to_ascii_lowercase().as_str() {
601 "1" | "true" | "yes" | "on" => Some(true),
602 "0" | "false" | "no" | "off" => Some(false),
603 _ => None,
604 }
605}