1use crate::driver::{
4 AuthSettings, GssEncMode, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
5 ScramChannelBindingMode, TlsConfig, TlsMode,
6};
7use std::time::Duration;
8
/// Configuration for the PostgreSQL connection pool.
///
/// Construct with [`PoolConfig::new`] (TLS `Prefer`, SCRAM-only auth) or
/// [`PoolConfig::new_dev`] (TLS disabled, `AuthSettings::default()`), refine with the
/// consuming builder methods, or derive it from a QAIL config with
/// [`PoolConfig::from_qail_config`]. All fields are public and may also be set directly.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server port (URLs without a port use 5432).
    pub port: u16,
    /// Role name to connect as.
    pub user: String,
    /// Database name to connect to.
    pub database: String,
    /// Password for `user`; `None` when none was configured.
    pub password: Option<String>,
    /// Maximum number of pooled connections (default 10).
    pub max_connections: usize,
    /// Minimum number of pooled connections (default 1).
    pub min_connections: usize,
    /// Idle timeout for pooled connections (default 600 s).
    pub idle_timeout: Duration,
    /// Timeout for acquiring a connection from the pool (default 30 s).
    pub acquire_timeout: Duration,
    /// Timeout for establishing a new server connection (default 10 s).
    pub connect_timeout: Duration,
    /// Optional maximum connection lifetime; `None` by default
    /// (presumably meaning "no limit" — confirm against the pool implementation).
    pub max_lifetime: Option<Duration>,
    /// Whether to test connections on acquire (default `false`).
    pub test_on_acquire: bool,
    /// TLS negotiation mode (`Prefer` from `new`, `Disable` from `new_dev`);
    /// set by the `sslmode` URL parameter and forced to `Require` when mTLS is configured.
    pub tls_mode: TlsMode,
    /// CA certificate bytes (PEM) used to verify the server; set by `sslrootcert`.
    pub tls_ca_cert_pem: Option<Vec<u8>>,
    /// Client certificate/key for mutual TLS; set by `mtls()` or by `sslcert` + `sslkey`.
    pub mtls: Option<TlsConfig>,
    /// Caller-supplied GSSAPI token provider.
    pub gss_token_provider: Option<GssTokenProvider>,
    /// Extended GSSAPI token provider; `gss_provider=linux_krb5` fills this in.
    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
    /// GSS connect retry count (default 2; the URL parameter accepts at most 20).
    pub gss_connect_retries: usize,
    /// Base delay between GSS connect retries (default 150 ms).
    pub gss_retry_base_delay: Duration,
    /// GSS circuit-breaker failure threshold (default 8; URL parameter max 100).
    pub gss_circuit_breaker_threshold: usize,
    /// Window over which GSS circuit-breaker failures are counted (default 30 s).
    pub gss_circuit_breaker_window: Duration,
    /// GSS circuit-breaker cooldown (default 15 s).
    pub gss_circuit_breaker_cooldown: Duration,
    /// Which authentication mechanisms are allowed (default `AuthSettings::scram_only()`).
    pub auth_settings: AuthSettings,
    /// GSSAPI transport-encryption mode (default `GssEncMode::Disable`); set by `gssencmode`.
    pub gss_enc_mode: GssEncMode,
}
72
73impl PoolConfig {
74 pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
86 Self {
87 host: host.to_string(),
88 port,
89 user: user.to_string(),
90 database: database.to_string(),
91 password: None,
92 max_connections: 10,
93 min_connections: 1,
94 idle_timeout: Duration::from_secs(600), acquire_timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_lifetime: None, test_on_acquire: false, tls_mode: TlsMode::Prefer,
100 tls_ca_cert_pem: None,
101 mtls: None,
102 gss_token_provider: None,
103 gss_token_provider_ex: None,
104 gss_connect_retries: 2,
105 gss_retry_base_delay: Duration::from_millis(150),
106 gss_circuit_breaker_threshold: 8,
107 gss_circuit_breaker_window: Duration::from_secs(30),
108 gss_circuit_breaker_cooldown: Duration::from_secs(15),
109 auth_settings: AuthSettings::scram_only(),
110 gss_enc_mode: GssEncMode::Disable,
111 }
112 }
113
114 pub fn new_dev(host: &str, port: u16, user: &str, database: &str) -> Self {
119 let mut config = Self::new(host, port, user, database);
120 config.tls_mode = TlsMode::Disable;
121 config.auth_settings = AuthSettings::default();
122 config
123 }
124
125 pub fn password(mut self, password: &str) -> Self {
127 self.password = Some(password.to_string());
128 self
129 }
130
131 pub fn max_connections(mut self, max: usize) -> Self {
133 self.max_connections = max;
134 self
135 }
136
137 pub fn min_connections(mut self, min: usize) -> Self {
139 self.min_connections = min;
140 self
141 }
142
143 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
145 self.idle_timeout = timeout;
146 self
147 }
148
149 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
151 self.acquire_timeout = timeout;
152 self
153 }
154
155 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
157 self.connect_timeout = timeout;
158 self
159 }
160
161 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
163 self.max_lifetime = Some(lifetime);
164 self
165 }
166
167 pub fn test_on_acquire(mut self, enabled: bool) -> Self {
169 self.test_on_acquire = enabled;
170 self
171 }
172
173 pub fn tls_mode(mut self, mode: TlsMode) -> Self {
175 self.tls_mode = mode;
176 self
177 }
178
179 pub fn tls_ca_cert_pem(mut self, ca_cert_pem: Vec<u8>) -> Self {
181 self.tls_ca_cert_pem = Some(ca_cert_pem);
182 self
183 }
184
185 pub fn mtls(mut self, config: TlsConfig) -> Self {
187 self.mtls = Some(config);
188 self.tls_mode = TlsMode::Require;
189 self
190 }
191
192 pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
194 self.gss_token_provider = Some(provider);
195 self
196 }
197
198 pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
200 self.gss_token_provider_ex = Some(provider);
201 self
202 }
203
204 pub fn gss_connect_retries(mut self, retries: usize) -> Self {
206 self.gss_connect_retries = retries;
207 self
208 }
209
210 pub fn gss_retry_base_delay(mut self, delay: Duration) -> Self {
212 self.gss_retry_base_delay = delay;
213 self
214 }
215
216 pub fn gss_circuit_breaker_threshold(mut self, threshold: usize) -> Self {
218 self.gss_circuit_breaker_threshold = threshold;
219 self
220 }
221
222 pub fn gss_circuit_breaker_window(mut self, window: Duration) -> Self {
224 self.gss_circuit_breaker_window = window;
225 self
226 }
227
228 pub fn gss_circuit_breaker_cooldown(mut self, cooldown: Duration) -> Self {
230 self.gss_circuit_breaker_cooldown = cooldown;
231 self
232 }
233
234 pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
236 self.auth_settings = settings;
237 self
238 }
239
240 pub fn from_qail_config(qail: &qail_core::config::QailConfig) -> PgResult<Self> {
245 let pg = &qail.postgres;
246 let (host, port, user, database, password) = parse_pg_url(&pg.url)?;
247
248 let mut config = PoolConfig::new(&host, port, &user, &database)
249 .max_connections(pg.max_connections)
250 .min_connections(pg.min_connections)
251 .idle_timeout(Duration::from_secs(pg.idle_timeout_secs))
252 .acquire_timeout(Duration::from_secs(pg.acquire_timeout_secs))
253 .connect_timeout(Duration::from_secs(pg.connect_timeout_secs))
254 .test_on_acquire(pg.test_on_acquire);
255
256 if let Some(ref pw) = password {
257 config = config.password(pw);
258 }
259
260 if let Some(query) = pg.url.split('?').nth(1) {
262 apply_url_query_params(&mut config, query, &host)?;
263 }
264
265 Ok(config)
266 }
267}
268
269#[allow(unused_variables)]
274pub(crate) fn apply_url_query_params(
275 config: &mut PoolConfig,
276 query: &str,
277 host: &str,
278) -> PgResult<()> {
279 let mut sslcert: Option<String> = None;
280 let mut sslkey: Option<String> = None;
281 let mut gss_provider: Option<String> = None;
282 let mut gss_service = "postgres".to_string();
283 let mut gss_target: Option<String> = None;
284
285 for pair in query.split('&').filter(|p| !p.is_empty()) {
286 let mut kv = pair.splitn(2, '=');
287 let key = kv.next().unwrap_or_default().trim();
288 let value = kv.next().unwrap_or_default().trim();
289
290 match key {
291 "sslmode" => {
292 let mode = TlsMode::parse_sslmode(value).ok_or_else(|| {
293 PgError::Connection(format!("Invalid sslmode value: {}", value))
294 })?;
295 config.tls_mode = mode;
296 }
297 "gssencmode" => {
298 let mode = GssEncMode::parse_gssencmode(value).ok_or_else(|| {
299 PgError::Connection(format!("Invalid gssencmode value: {}", value))
300 })?;
301 config.gss_enc_mode = mode;
302 }
303 "sslrootcert" => {
304 let ca_pem = std::fs::read(value).map_err(|e| {
305 PgError::Connection(format!("Failed to read sslrootcert '{}': {}", value, e))
306 })?;
307 config.tls_ca_cert_pem = Some(ca_pem);
308 }
309 "sslcert" => sslcert = Some(value.to_string()),
310 "sslkey" => sslkey = Some(value.to_string()),
311 "channel_binding" => {
312 let mode = ScramChannelBindingMode::parse(value).ok_or_else(|| {
313 PgError::Connection(format!("Invalid channel_binding value: {}", value))
314 })?;
315 config.auth_settings.channel_binding = mode;
316 }
317 "auth_scram" => {
318 let enabled = parse_bool_param(value).ok_or_else(|| {
319 PgError::Connection(format!("Invalid auth_scram value: {}", value))
320 })?;
321 config.auth_settings.allow_scram_sha_256 = enabled;
322 }
323 "auth_md5" => {
324 let enabled = parse_bool_param(value).ok_or_else(|| {
325 PgError::Connection(format!("Invalid auth_md5 value: {}", value))
326 })?;
327 config.auth_settings.allow_md5_password = enabled;
328 }
329 "auth_cleartext" => {
330 let enabled = parse_bool_param(value).ok_or_else(|| {
331 PgError::Connection(format!("Invalid auth_cleartext value: {}", value))
332 })?;
333 config.auth_settings.allow_cleartext_password = enabled;
334 }
335 "auth_kerberos" => {
336 let enabled = parse_bool_param(value).ok_or_else(|| {
337 PgError::Connection(format!("Invalid auth_kerberos value: {}", value))
338 })?;
339 config.auth_settings.allow_kerberos_v5 = enabled;
340 }
341 "auth_gssapi" => {
342 let enabled = parse_bool_param(value).ok_or_else(|| {
343 PgError::Connection(format!("Invalid auth_gssapi value: {}", value))
344 })?;
345 config.auth_settings.allow_gssapi = enabled;
346 }
347 "auth_sspi" => {
348 let enabled = parse_bool_param(value).ok_or_else(|| {
349 PgError::Connection(format!("Invalid auth_sspi value: {}", value))
350 })?;
351 config.auth_settings.allow_sspi = enabled;
352 }
353 "auth_mode" => {
354 if value.eq_ignore_ascii_case("scram_only") {
355 config.auth_settings = AuthSettings::scram_only();
356 } else if value.eq_ignore_ascii_case("gssapi_only") {
357 config.auth_settings = AuthSettings::gssapi_only();
358 } else if value.eq_ignore_ascii_case("compat")
359 || value.eq_ignore_ascii_case("default")
360 {
361 config.auth_settings = AuthSettings::default();
362 } else {
363 return Err(PgError::Connection(format!(
364 "Invalid auth_mode value: {}",
365 value
366 )));
367 }
368 }
369 "gss_provider" => gss_provider = Some(value.to_string()),
370 "gss_service" => {
371 if value.is_empty() {
372 return Err(PgError::Connection(
373 "gss_service must not be empty".to_string(),
374 ));
375 }
376 gss_service = value.to_string();
377 }
378 "gss_target" => {
379 if value.is_empty() {
380 return Err(PgError::Connection(
381 "gss_target must not be empty".to_string(),
382 ));
383 }
384 gss_target = Some(value.to_string());
385 }
386 "gss_connect_retries" => {
387 let retries = value.parse::<usize>().map_err(|_| {
388 PgError::Connection(format!("Invalid gss_connect_retries value: {}", value))
389 })?;
390 if retries > 20 {
391 return Err(PgError::Connection(
392 "gss_connect_retries must be <= 20".to_string(),
393 ));
394 }
395 config.gss_connect_retries = retries;
396 }
397 "gss_retry_base_ms" => {
398 let delay_ms = value.parse::<u64>().map_err(|_| {
399 PgError::Connection(format!("Invalid gss_retry_base_ms value: {}", value))
400 })?;
401 if delay_ms == 0 {
402 return Err(PgError::Connection(
403 "gss_retry_base_ms must be greater than 0".to_string(),
404 ));
405 }
406 config.gss_retry_base_delay = Duration::from_millis(delay_ms);
407 }
408 "gss_circuit_threshold" => {
409 let threshold = value.parse::<usize>().map_err(|_| {
410 PgError::Connection(format!("Invalid gss_circuit_threshold value: {}", value))
411 })?;
412 if threshold > 100 {
413 return Err(PgError::Connection(
414 "gss_circuit_threshold must be <= 100".to_string(),
415 ));
416 }
417 config.gss_circuit_breaker_threshold = threshold;
418 }
419 "gss_circuit_window_ms" => {
420 let window_ms = value.parse::<u64>().map_err(|_| {
421 PgError::Connection(format!("Invalid gss_circuit_window_ms value: {}", value))
422 })?;
423 if window_ms == 0 {
424 return Err(PgError::Connection(
425 "gss_circuit_window_ms must be greater than 0".to_string(),
426 ));
427 }
428 config.gss_circuit_breaker_window = Duration::from_millis(window_ms);
429 }
430 "gss_circuit_cooldown_ms" => {
431 let cooldown_ms = value.parse::<u64>().map_err(|_| {
432 PgError::Connection(format!("Invalid gss_circuit_cooldown_ms value: {}", value))
433 })?;
434 if cooldown_ms == 0 {
435 return Err(PgError::Connection(
436 "gss_circuit_cooldown_ms must be greater than 0".to_string(),
437 ));
438 }
439 config.gss_circuit_breaker_cooldown = Duration::from_millis(cooldown_ms);
440 }
441 _ => {}
442 }
443 }
444
445 match (sslcert.as_deref(), sslkey.as_deref()) {
446 (Some(cert_path), Some(key_path)) => {
447 let mtls = TlsConfig {
448 client_cert_pem: std::fs::read(cert_path).map_err(|e| {
449 PgError::Connection(format!("Failed to read sslcert '{}': {}", cert_path, e))
450 })?,
451 client_key_pem: std::fs::read(key_path).map_err(|e| {
452 PgError::Connection(format!("Failed to read sslkey '{}': {}", key_path, e))
453 })?,
454 ca_cert_pem: config.tls_ca_cert_pem.clone(),
455 };
456 config.mtls = Some(mtls);
457 config.tls_mode = TlsMode::Require;
458 }
459 (Some(_), None) | (None, Some(_)) => {
460 return Err(PgError::Connection(
461 "Both sslcert and sslkey must be provided together".to_string(),
462 ));
463 }
464 (None, None) => {}
465 }
466
467 if let Some(provider) = gss_provider {
468 if provider.eq_ignore_ascii_case("linux_krb5") || provider.eq_ignore_ascii_case("builtin") {
469 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
470 {
471 let provider = crate::driver::gss::linux_krb5_token_provider(
472 crate::driver::gss::LinuxKrb5ProviderConfig {
473 host: host.to_string(),
474 service: gss_service.clone(),
475 target_name: gss_target.clone(),
476 },
477 )
478 .map_err(PgError::Auth)?;
479 config.gss_token_provider_ex = Some(provider);
480 }
481 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
482 {
483 let _ = gss_service;
484 let _ = gss_target;
485 return Err(PgError::Connection(
486 "gss_provider=linux_krb5 requires qail-pg feature enterprise-gssapi on Linux"
487 .to_string(),
488 ));
489 }
490 } else if provider.eq_ignore_ascii_case("callback")
491 || provider.eq_ignore_ascii_case("custom")
492 {
493 } else {
495 return Err(PgError::Connection(format!(
496 "Invalid gss_provider value: {}",
497 provider
498 )));
499 }
500 }
501
502 Ok(())
503}
504
505pub(super) fn parse_pg_url(url: &str) -> PgResult<(String, u16, String, String, Option<String>)> {
507 let url = url.split('?').next().unwrap_or(url);
508 let url = url
509 .trim_start_matches("postgres://")
510 .trim_start_matches("postgresql://");
511
512 let (credentials, host_part) = if url.contains('@') {
513 let mut parts = url.splitn(2, '@');
514 let creds = parts.next().unwrap_or("");
515 let host = parts.next().unwrap_or("localhost/postgres");
516 (Some(creds), host)
517 } else {
518 (None, url)
519 };
520
521 let (host_port, database) = if host_part.contains('/') {
522 let mut parts = host_part.splitn(2, '/');
523 (
524 parts.next().unwrap_or("localhost"),
525 parts.next().unwrap_or("postgres").to_string(),
526 )
527 } else {
528 (host_part, "postgres".to_string())
529 };
530
531 let (host, port) = if host_port.contains(':') {
532 let mut parts = host_port.split(':');
533 let h = parts.next().unwrap_or("localhost").to_string();
534 let p = parts.next().and_then(|s| s.parse().ok()).unwrap_or(5432u16);
535 (h, p)
536 } else {
537 (host_port.to_string(), 5432u16)
538 };
539
540 let (user, password) = if let Some(creds) = credentials {
541 if creds.contains(':') {
542 let mut parts = creds.splitn(2, ':');
543 let u = parts.next().unwrap_or("postgres").to_string();
544 let p = parts.next().map(|s| s.to_string());
545 (u, p)
546 } else {
547 (creds.to_string(), None)
548 }
549 } else {
550 ("postgres".to_string(), None)
551 };
552
553 Ok((host, port, user, database, password))
554}
555
556pub(super) fn parse_bool_param(value: &str) -> Option<bool> {
557 match value.trim().to_ascii_lowercase().as_str() {
558 "1" | "true" | "yes" | "on" => Some(true),
559 "0" | "false" | "no" | "off" => Some(false),
560 _ => None,
561 }
562}