1use crate::driver::{
4 AuthSettings, GssEncMode, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
5 ScramChannelBindingMode, TlsConfig, TlsMode,
6};
7use std::time::Duration;
8
/// Configuration for the PostgreSQL connection pool.
///
/// Build one with [`PoolConfig::new`] (or `new_dev` / `from_qail_config`) and refine it
/// with the chained builder methods. All fields are public, so they may also be set directly.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host. When parsed from a URL, a bracketed IPv6 literal keeps its brackets.
    pub host: String,
    /// Server TCP port (5432 when a URL omits it).
    pub port: u16,
    /// User name to connect as.
    pub user: String,
    /// Database name to connect to.
    pub database: String,
    /// Optional password; `None` in `PoolConfig::new`.
    pub password: Option<String>,
    /// Maximum number of pooled connections (10 in `PoolConfig::new`).
    pub max_connections: usize,
    /// Minimum number of pooled connections (1 in `PoolConfig::new`).
    pub min_connections: usize,
    /// Idle timeout (600 s in `PoolConfig::new`).
    pub idle_timeout: Duration,
    /// Timeout for acquiring a connection from the pool (30 s in `PoolConfig::new`).
    pub acquire_timeout: Duration,
    /// Timeout for establishing a new connection (10 s in `PoolConfig::new`).
    pub connect_timeout: Duration,
    /// Optional maximum connection lifetime; `None` (unbounded) in `PoolConfig::new`.
    pub max_lifetime: Option<Duration>,
    /// Size of the leaked-connection cleanup queue (64 in `PoolConfig::new`).
    /// NOTE(review): exact semantics live in the pool implementation — confirm there.
    pub leaked_cleanup_queue: usize,
    /// Whether connections are tested when acquired (`false` in `PoolConfig::new`).
    pub test_on_acquire: bool,
    /// TLS negotiation mode: `Prefer` in `new`, `Disable` in `new_dev`; forced to `Require`
    /// when client certificates are configured.
    pub tls_mode: TlsMode,
    /// Custom CA certificate bytes (read from the URL's `sslrootcert` file path).
    pub tls_ca_cert_pem: Option<Vec<u8>>,
    /// Client certificate/key material for mutual TLS (`mtls()` or `sslcert` + `sslkey`).
    pub mtls: Option<TlsConfig>,
    /// Optional caller-supplied GSS token provider; `None` by default.
    pub gss_token_provider: Option<GssTokenProvider>,
    /// Optional extended GSS token provider; `None` by default. The URL option
    /// `gss_provider=linux_krb5` (or `builtin`) fills this field on supported builds.
    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
    /// GSS connect retry count (2 by default; the URL parser caps it at 20).
    pub gss_connect_retries: usize,
    /// Base delay between GSS connect retries (150 ms by default).
    pub gss_retry_base_delay: Duration,
    /// GSS circuit-breaker threshold (8 by default; the URL parser caps it at 100).
    pub gss_circuit_breaker_threshold: usize,
    /// GSS circuit-breaker window (30 s by default).
    pub gss_circuit_breaker_window: Duration,
    /// GSS circuit-breaker cooldown (15 s by default).
    pub gss_circuit_breaker_cooldown: Duration,
    /// Allowed authentication methods: `AuthSettings::scram_only()` in `new`,
    /// `AuthSettings::default()` in `new_dev`.
    pub auth_settings: AuthSettings,
    /// GSS encryption mode (`GssEncMode::Disable` by default; URL key `gssencmode`).
    pub gss_enc_mode: GssEncMode,
}
77
78impl PoolConfig {
79 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
91 Self {
92 host: host.to_string(),
93 port,
94 user: user.to_string(),
95 database: database.to_string(),
96 password: None,
97 max_connections: 10,
98 min_connections: 1,
99 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, leaked_cleanup_queue: 64, test_on_acquire: false, tls_mode: TlsMode::Prefer,
106 tls_ca_cert_pem: None,
107 mtls: None,
108 gss_token_provider: None,
109 gss_token_provider_ex: None,
110 gss_connect_retries: 2,
111 gss_retry_base_delay: Duration::from_millis(150),
112 gss_circuit_breaker_threshold: 8,
113 gss_circuit_breaker_window: Duration::from_secs(30),
114 gss_circuit_breaker_cooldown: Duration::from_secs(15),
115 auth_settings: AuthSettings::scram_only(),
116 gss_enc_mode: GssEncMode::Disable,
117 }
118 }
119
120 pub fn new_dev(host: &str, port: u16, user: &str, database: &str) -> Self {
125 let mut config = Self::new(host, port, user, database);
126 config.tls_mode = TlsMode::Disable;
127 config.auth_settings = AuthSettings::default();
128 config
129 }
130
131 pub fn password(mut self, password: &str) -> Self {
133 self.password = Some(password.to_string());
134 self
135 }
136
137 pub fn max_connections(mut self, max: usize) -> Self {
139 self.max_connections = max;
140 self
141 }
142
143 pub fn min_connections(mut self, min: usize) -> Self {
145 self.min_connections = min;
146 self
147 }
148
149 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
151 self.idle_timeout = timeout;
152 self
153 }
154
155 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
157 self.acquire_timeout = timeout;
158 self
159 }
160
161 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
163 self.connect_timeout = timeout;
164 self
165 }
166
167 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
169 self.max_lifetime = Some(lifetime);
170 self
171 }
172
173 pub fn leaked_cleanup_queue(mut self, max: usize) -> Self {
177 self.leaked_cleanup_queue = max;
178 self
179 }
180
181 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
183 self.test_on_acquire = enabled;
184 self
185 }
186
187 pub fn tls_mode(mut self, mode: TlsMode) -> Self {
189 self.tls_mode = mode;
190 self
191 }
192
193 pub fn tls_ca_cert_pem(mut self, ca_cert_pem: Vec<u8>) -> Self {
195 self.tls_ca_cert_pem = Some(ca_cert_pem);
196 self
197 }
198
199 pub fn mtls(mut self, config: TlsConfig) -> Self {
201 self.mtls = Some(config);
202 self.tls_mode = TlsMode::Require;
203 self
204 }
205
206 pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
208 self.gss_token_provider = Some(provider);
209 self
210 }
211
212 pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
214 self.gss_token_provider_ex = Some(provider);
215 self
216 }
217
218 pub fn gss_connect_retries(mut self, retries: usize) -> Self {
220 self.gss_connect_retries = retries;
221 self
222 }
223
224 pub fn gss_retry_base_delay(mut self, delay: Duration) -> Self {
226 self.gss_retry_base_delay = delay;
227 self
228 }
229
230 pub fn gss_circuit_breaker_threshold(mut self, threshold: usize) -> Self {
232 self.gss_circuit_breaker_threshold = threshold;
233 self
234 }
235
236 pub fn gss_circuit_breaker_window(mut self, window: Duration) -> Self {
238 self.gss_circuit_breaker_window = window;
239 self
240 }
241
242 pub fn gss_circuit_breaker_cooldown(mut self, cooldown: Duration) -> Self {
244 self.gss_circuit_breaker_cooldown = cooldown;
245 self
246 }
247
248 pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
250 self.auth_settings = settings;
251 self
252 }
253
254 pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
259 let pg = &qail.postgres;
260 let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
261
262 let mut config = PoolConfig::new(&host, port, &user, &database)
263 .max_connections(pg.max_connections)
264 .min_connections(pg.min_connections)
265 .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
266 .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
267 .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
268 .test_on_acquire(pg.test_on_acquire);
269
270 if let Some(ref pw) = password {
271 config = config.password(pw);
272 }
273
274 if let Some((_, query)) = pg.url.split_once('?') {
276 apply_url_query_params(&mut config, query, &host)?;
277 }
278
279 Ok(config)
280 }
281}
282
283#[allow(unused_variables)]
288pub(crate) fn apply_url_query_params(
289 config: &mut PoolConfig,
290 query: &str,
291 host: &str,
292) -> PgResult<()> {
293 let mut sslcert: Option<String> = None;
294 let mut sslkey: Option<String> = None;
295 let mut gss_provider: Option<String> = None;
296 let mut gss_service = "postgres".to_string();
297 let mut gss_target: Option<String> = None;
298
299 for pair in query.split('&').filter(|p| !p.is_empty()) {
300 let mut kv = pair.splitn(2, '=');
301 let key = percent_decode(kv.next().unwrap_or_default().trim())?;
302 let value = percent_decode(kv.next().unwrap_or_default().trim())?;
303
304 match key.as_str() {
305 "sslmode" => {
306 let mode = TlsMode::parse_sslmode(&value).ok_or_else(|| {
307 PgError::Connection(format!("Invalid sslmode value: {}", value))
308 })?;
309 config.tls_mode = mode;
310 }
311 "gssencmode" => {
312 let mode = GssEncMode::parse_gssencmode(&value).ok_or_else(|| {
313 PgError::Connection(format!("Invalid gssencmode value: {}", value))
314 })?;
315 config.gss_enc_mode = mode;
316 }
317 "sslrootcert" => {
318 let ca_pem = std::fs::read(&value).map_err(|e| {
319 PgError::Connection(format!("Failed to read sslrootcert '{}': {}", value, e))
320 })?;
321 config.tls_ca_cert_pem = Some(ca_pem);
322 }
323 "sslcert" => sslcert = Some(value.clone()),
324 "sslkey" => sslkey = Some(value.clone()),
325 "channel_binding" => {
326 let mode = ScramChannelBindingMode::parse(&value).ok_or_else(|| {
327 PgError::Connection(format!("Invalid channel_binding value: {}", value))
328 })?;
329 config.auth_settings.channel_binding = mode;
330 }
331 "auth_scram" => {
332 let enabled = parse_bool_param(&value).ok_or_else(|| {
333 PgError::Connection(format!("Invalid auth_scram value: {}", value))
334 })?;
335 config.auth_settings.allow_scram_sha_256 = enabled;
336 }
337 "auth_md5" => {
338 let enabled = parse_bool_param(&value).ok_or_else(|| {
339 PgError::Connection(format!("Invalid auth_md5 value: {}", value))
340 })?;
341 config.auth_settings.allow_md5_password = enabled;
342 }
343 "auth_cleartext" => {
344 let enabled = parse_bool_param(&value).ok_or_else(|| {
345 PgError::Connection(format!("Invalid auth_cleartext value: {}", value))
346 })?;
347 config.auth_settings.allow_cleartext_password = enabled;
348 }
349 "auth_kerberos" => {
350 let enabled = parse_bool_param(&value).ok_or_else(|| {
351 PgError::Connection(format!("Invalid auth_kerberos value: {}", value))
352 })?;
353 config.auth_settings.allow_kerberos_v5 = enabled;
354 }
355 "auth_gssapi" => {
356 let enabled = parse_bool_param(&value).ok_or_else(|| {
357 PgError::Connection(format!("Invalid auth_gssapi value: {}", value))
358 })?;
359 config.auth_settings.allow_gssapi = enabled;
360 }
361 "auth_sspi" => {
362 let enabled = parse_bool_param(&value).ok_or_else(|| {
363 PgError::Connection(format!("Invalid auth_sspi value: {}", value))
364 })?;
365 config.auth_settings.allow_sspi = enabled;
366 }
367 "auth_mode" => {
368 if value.eq_ignore_ascii_case("scram_only") {
369 config.auth_settings = AuthSettings::scram_only();
370 } else if value.eq_ignore_ascii_case("gssapi_only") {
371 config.auth_settings = AuthSettings::gssapi_only();
372 } else if value.eq_ignore_ascii_case("compat")
373 || value.eq_ignore_ascii_case("default")
374 {
375 config.auth_settings = AuthSettings::default();
376 } else {
377 return Err(PgError::Connection(format!(
378 "Invalid auth_mode value: {}",
379 value
380 )));
381 }
382 }
383 "gss_provider" => gss_provider = Some(value.clone()),
384 "gss_service" => {
385 if value.is_empty() {
386 return Err(PgError::Connection(
387 "gss_service must not be empty".to_string(),
388 ));
389 }
390 gss_service = value.clone();
391 }
392 "krbsrvname" => {
394 if value.is_empty() {
395 return Err(PgError::Connection(
396 "gss_service must not be empty".to_string(),
397 ));
398 }
399 gss_service = value.clone();
400 }
401 "gss_target" => {
402 if value.is_empty() {
403 return Err(PgError::Connection(
404 "gss_target must not be empty".to_string(),
405 ));
406 }
407 gss_target = Some(value.clone());
408 }
409 "gsshostname" => {
411 if value.is_empty() {
412 return Err(PgError::Connection(
413 "gss_target must not be empty".to_string(),
414 ));
415 }
416 gss_target = Some(value.clone());
417 }
418 "gsslib" => match value.trim().to_ascii_lowercase().as_str() {
421 "gssapi" | "sspi" => {}
422 _ => {
423 return Err(PgError::Connection(format!(
424 "Invalid gsslib value: {} (expected gssapi or sspi)",
425 value
426 )));
427 }
428 },
429 "gss_connect_retries" => {
430 let retries = value.parse::<usize>().map_err(|_| {
431 PgError::Connection(format!("Invalid gss_connect_retries value: {}", value))
432 })?;
433 if retries > 20 {
434 return Err(PgError::Connection(
435 "gss_connect_retries must be <= 20".to_string(),
436 ));
437 }
438 config.gss_connect_retries = retries;
439 }
440 "gss_retry_base_ms" => {
441 let delay_ms = value.parse::<u64>().map_err(|_| {
442 PgError::Connection(format!("Invalid gss_retry_base_ms value: {}", value))
443 })?;
444 if delay_ms == 0 {
445 return Err(PgError::Connection(
446 "gss_retry_base_ms must be greater than 0".to_string(),
447 ));
448 }
449 config.gss_retry_base_delay = Duration::from_millis(delay_ms);
450 }
451 "gss_circuit_threshold" => {
452 let threshold = value.parse::<usize>().map_err(|_| {
453 PgError::Connection(format!("Invalid gss_circuit_threshold value: {}", value))
454 })?;
455 if threshold > 100 {
456 return Err(PgError::Connection(
457 "gss_circuit_threshold must be <= 100".to_string(),
458 ));
459 }
460 config.gss_circuit_breaker_threshold = threshold;
461 }
462 "gss_circuit_window_ms" => {
463 let window_ms = value.parse::<u64>().map_err(|_| {
464 PgError::Connection(format!("Invalid gss_circuit_window_ms value: {}", value))
465 })?;
466 if window_ms == 0 {
467 return Err(PgError::Connection(
468 "gss_circuit_window_ms must be greater than 0".to_string(),
469 ));
470 }
471 config.gss_circuit_breaker_window = Duration::from_millis(window_ms);
472 }
473 "gss_circuit_cooldown_ms" => {
474 let cooldown_ms = value.parse::<u64>().map_err(|_| {
475 PgError::Connection(format!("Invalid gss_circuit_cooldown_ms value: {}", value))
476 })?;
477 if cooldown_ms == 0 {
478 return Err(PgError::Connection(
479 "gss_circuit_cooldown_ms must be greater than 0".to_string(),
480 ));
481 }
482 config.gss_circuit_breaker_cooldown = Duration::from_millis(cooldown_ms);
483 }
484 _ => {}
485 }
486 }
487
488 match (sslcert.as_deref(), sslkey.as_deref()) {
489 (Some(cert_path), Some(key_path)) => {
490 let mtls = TlsConfig {
491 client_cert_pem: std::fs::read(cert_path).map_err(|e| {
492 PgError::Connection(format!("Failed to read sslcert '{}': {}", cert_path, e))
493 })?,
494 client_key_pem: std::fs::read(key_path).map_err(|e| {
495 PgError::Connection(format!("Failed to read sslkey '{}': {}", key_path, e))
496 })?,
497 ca_cert_pem: config.tls_ca_cert_pem.clone(),
498 };
499 config.mtls = Some(mtls);
500 config.tls_mode = TlsMode::Require;
501 }
502 (Some(_), None) | (None, Some(_)) => {
503 return Err(PgError::Connection(
504 "Both sslcert and sslkey must be provided together".to_string(),
505 ));
506 }
507 (None, None) => {}
508 }
509
510 if let Some(provider) = gss_provider {
511 if provider.eq_ignore_ascii_case("linux_krb5") || provider.eq_ignore_ascii_case("builtin") {
512 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
513 {
514 let provider = crate::driver::gss::linux_krb5_token_provider(
515 crate::driver::gss::LinuxKrb5ProviderConfig {
516 host: host.to_string(),
517 service: gss_service,
518 target_name: gss_target,
519 },
520 )
521 .map_err(PgError::Auth)?;
522 config.gss_token_provider_ex = Some(provider);
523 }
524 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
525 {
526 let _ = gss_service;
527 let _ = gss_target;
528 return Err(PgError::Connection(
529 "gss_provider=linux_krb5 requires qail-pg feature enterprise-gssapi on Linux"
530 .to_string(),
531 ));
532 }
533 } else if provider.eq_ignore_ascii_case("callback")
534 || provider.eq_ignore_ascii_case("custom")
535 {
536 } else {
538 return Err(PgError::Connection(format!(
539 "Invalid gss_provider value: {}",
540 provider
541 )));
542 }
543 }
544
545 Ok(())
546}
547
548pub(super) fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
550 let url = url.split('?').next().unwrap_or(url);
551 let url = if let Some(rest) = url.strip_prefix("postgres://") {
552 rest
553 } else if let Some(rest) = url.strip_prefix("postgresql://") {
554 rest
555 } else {
556 return Err(PgError::Connection(
557 "PostgreSQL URL must start with postgres:// or postgresql://".to_string(),
558 ));
559 };
560
561 let (credentials, host_part) = if url.contains('@') {
562 let mut parts = url.splitn(2, '@');
563 let creds = parts.next().unwrap_or("");
564 let host = parts.next().unwrap_or("localhost/postgres");
565 (Some(creds), host)
566 } else {
567 (None, url)
568 };
569
570 let (host_port, database) = if host_part.contains('/') {
571 let mut parts = host_part.splitn(2, '/');
572 (
573 parts.next().unwrap_or("localhost"),
574 percent_decode(parts.next().unwrap_or("postgres"))?,
575 )
576 } else {
577 (host_part, "postgres".to_string())
578 };
579
580 let (host, port) = if host_port.starts_with('[') {
581 let end = host_port.find(']').ok_or_else(|| {
582 PgError::Connection("Invalid PostgreSQL URL IPv6 host: missing ']'".to_string())
583 })?;
584 let host = &host_port[..=end];
585 if host == "[]" {
586 return Err(PgError::Connection(
587 "Invalid PostgreSQL URL host: missing host".to_string(),
588 ));
589 }
590 let suffix = &host_port[end + 1..];
591 let port = if suffix.is_empty() {
592 5432u16
593 } else if let Some(port_str) = suffix.strip_prefix(':') {
594 if port_str.is_empty() {
595 return Err(PgError::Connection(
596 "Invalid PostgreSQL URL port: missing port after ':'".to_string(),
597 ));
598 }
599 let p = port_str.parse::<u16>().map_err(|_| {
600 PgError::Connection(format!(
601 "Invalid PostgreSQL URL port '{}': expected a number from 1 to 65535",
602 port_str
603 ))
604 })?;
605 if p == 0 {
606 return Err(PgError::Connection(
607 "Invalid PostgreSQL URL port '0': expected a number from 1 to 65535"
608 .to_string(),
609 ));
610 }
611 p
612 } else {
613 return Err(PgError::Connection(
614 "Invalid PostgreSQL URL IPv6 host: unexpected characters after ']'".to_string(),
615 ));
616 };
617 (host.to_string(), port)
618 } else if host_port.contains(':') {
619 let mut parts = host_port.splitn(2, ':');
620 let h = parts.next().unwrap_or("localhost").to_string();
621 if h.is_empty() {
622 return Err(PgError::Connection(
623 "Invalid PostgreSQL URL host: missing host".to_string(),
624 ));
625 }
626 let port_str = parts.next().unwrap_or("");
627 if port_str.is_empty() {
628 return Err(PgError::Connection(
629 "Invalid PostgreSQL URL port: missing port after ':'".to_string(),
630 ));
631 }
632 let p = port_str.parse::<u16>().map_err(|_| {
633 PgError::Connection(format!(
634 "Invalid PostgreSQL URL port '{}': expected a number from 1 to 65535",
635 port_str
636 ))
637 })?;
638 if p == 0 {
639 return Err(PgError::Connection(
640 "Invalid PostgreSQL URL port '0': expected a number from 1 to 65535".to_string(),
641 ));
642 }
643 (h, p)
644 } else {
645 if host_port.is_empty() {
646 return Err(PgError::Connection(
647 "Invalid PostgreSQL URL host: missing host".to_string(),
648 ));
649 }
650 (host_port.to_string(), 5432u16)
651 };
652
653 let (user, password) = if let Some(creds) = credentials {
654 if creds.contains(':') {
655 let mut parts = creds.splitn(2, ':');
656 let u = percent_decode(parts.next().unwrap_or("postgres"))?;
657 let p = parts.next().map(percent_decode).transpose()?;
658 (u, p)
659 } else {
660 (percent_decode(creds)?, None)
661 }
662 } else {
663 ("postgres".to_string(), None)
664 };
665
666 Ok((host, port, user, database, password))
667}
668
669fn percent_decode(s: &str) -> PgResult<String> {
670 fn hex_value(byte: u8) -> Option<u8> {
671 match byte {
672 b'0'..=b'9' => Some(byte - b'0'),
673 b'a'..=b'f' => Some(byte - b'a' + 10),
674 b'A'..=b'F' => Some(byte - b'A' + 10),
675 _ => None,
676 }
677 }
678
679 let bytes = s.as_bytes();
680 let mut decoded = Vec::with_capacity(bytes.len());
681 let mut i = 0;
682
683 while i < bytes.len() {
684 if bytes[i] == b'%'
685 && i + 2 < bytes.len()
686 && let (Some(hi), Some(lo)) = (hex_value(bytes[i + 1]), hex_value(bytes[i + 2]))
687 {
688 decoded.push((hi << 4) | lo);
689 i += 3;
690 } else {
691 decoded.push(bytes[i]);
692 i += 1;
693 }
694 }
695
696 String::from_utf8(decoded).map_err(|_| {
697 PgError::Connection(
698 "Invalid PostgreSQL URL percent-encoding: decoded value is not UTF-8".to_string(),
699 )
700 })
701}
702
703pub(super) fn parse_bool_param(value: &str) -> Option<bool> {
704 match value.trim().to_ascii_lowercase().as_str() {
705 "1" | "true" | "yes" | "on" => Some(true),
706 "0" | "false" | "no" | "off" => Some(false),
707 _ => None,
708 }
709}