1#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6 connect_backend_for_stream, connect_error_kind, plain_connect_attempt_backend,
7 record_connect_attempt, record_connect_result,
8};
9#[cfg(all(target_os = "linux", feature = "io_uring"))]
10use super::types::CONNECT_BACKEND_IO_URING;
11use super::types::{
12 BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
13 CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
14 GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
15 StatementCache, TlsConfig, has_logical_replication_startup_mode,
16};
17use crate::driver::stream::PgStream;
18use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
19use crate::protocol::wire::FrontendMessage;
20use bytes::BytesMut;
21use std::collections::{HashMap, VecDeque};
22use std::sync::Arc;
23use std::time::Instant;
24use tokio::io::AsyncWriteExt;
25use tokio::net::TcpStream;
26
27impl PgConnection {
28 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
37 Self::connect_with_password(host, port, user, database, None).await
38 }
39
40 pub async fn connect_with_password(
43 host: &str,
44 port: u16,
45 user: &str,
46 database: &str,
47 password: Option<&str>,
48 ) -> PgResult<Self> {
49 Self::connect_with_password_and_auth(
50 host,
51 port,
52 user,
53 database,
54 password,
55 AuthSettings::default(),
56 )
57 .await
58 }
59
60 pub async fn connect_with_options(
67 host: &str,
68 port: u16,
69 user: &str,
70 database: &str,
71 password: Option<&str>,
72 options: ConnectOptions,
73 ) -> PgResult<Self> {
74 let ConnectOptions {
75 tls_mode,
76 gss_enc_mode,
77 tls_ca_cert_pem,
78 mtls,
79 gss_token_provider,
80 gss_token_provider_ex,
81 auth,
82 startup_params,
83 } = options;
84
85 if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
86 return Err(PgError::Connection(
87 "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
88 ));
89 }
90
91 if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
95 return Err(PgError::Connection(
96 "gssencmode=require is incompatible with mTLS — both provide \
97 transport encryption; use one or the other"
98 .to_string(),
99 ));
100 }
101
102 if let Some(mtls_config) = mtls {
103 return Self::connect_mtls_with_password_and_auth_and_gss(
106 ConnectParams {
107 host,
108 port,
109 user,
110 database,
111 password,
112 auth_settings: auth,
113 gss_token_provider,
114 gss_token_provider_ex,
115 startup_params: startup_params.clone(),
116 },
117 mtls_config,
118 )
119 .await;
120 }
121
122 if gss_enc_mode != GssEncMode::Disable {
124 match Self::try_gssenc_request(host, port).await {
125 Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
126 let connect_started = Instant::now();
127 record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
128 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
129 {
130 let gssenc_fut = async {
131 let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, host)
132 .await
133 .map_err(PgError::Auth)?;
134 let mut conn = Self {
135 stream: PgStream::GssEnc(gss_stream),
136 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
137 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
138 sql_buf: BytesMut::with_capacity(512),
139 params_buf: Vec::with_capacity(16),
140 prepared_statements: HashMap::new(),
141 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
142 column_info_cache: HashMap::new(),
143 process_id: 0,
144 secret_key: 0,
145 notifications: VecDeque::new(),
146 replication_stream_active: false,
147 replication_mode_enabled: has_logical_replication_startup_mode(
148 &startup_params,
149 ),
150 last_replication_wal_end: None,
151 io_desynced: false,
152 pending_statement_closes: Vec::new(),
153 draining_statement_closes: false,
154 };
155 conn.send(FrontendMessage::Startup {
156 user: user.to_string(),
157 database: database.to_string(),
158 startup_params: startup_params.clone(),
159 })
160 .await?;
161 conn.handle_startup(
162 user,
163 password,
164 auth,
165 gss_token_provider,
166 gss_token_provider_ex,
167 )
168 .await?;
169 Ok(conn)
170 };
171 let result: PgResult<Self> =
172 tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
173 .await
174 .map_err(|_| {
175 PgError::Connection(format!(
176 "GSSENC connection timeout after {:?} \
177 (handshake + auth)",
178 DEFAULT_CONNECT_TIMEOUT
179 ))
180 })?;
181 record_connect_result(
182 CONNECT_TRANSPORT_GSSENC,
183 CONNECT_BACKEND_TOKIO,
184 &result,
185 connect_started.elapsed(),
186 );
187 return result;
188 }
189 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
190 {
191 let _ = tcp_stream;
192 let err = PgError::Connection(
193 "Server accepted GSSENCRequest but GSSAPI encryption requires \
194 feature enterprise-gssapi on Linux"
195 .to_string(),
196 );
197 metrics::histogram!(
198 "qail_pg_connect_duration_seconds",
199 "transport" => CONNECT_TRANSPORT_GSSENC,
200 "backend" => CONNECT_BACKEND_TOKIO,
201 "outcome" => "error"
202 )
203 .record(connect_started.elapsed().as_secs_f64());
204 metrics::counter!(
205 "qail_pg_connect_failure_total",
206 "transport" => CONNECT_TRANSPORT_GSSENC,
207 "backend" => CONNECT_BACKEND_TOKIO,
208 "error_kind" => connect_error_kind(&err)
209 )
210 .increment(1);
211 return Err(err);
212 }
213 }
214 Ok(GssEncNegotiationResult::Rejected)
215 | Ok(GssEncNegotiationResult::ServerError) => {
216 if gss_enc_mode == GssEncMode::Require {
217 return Err(PgError::Connection(
218 "gssencmode=require but server rejected GSSENCRequest".to_string(),
219 ));
220 }
221 }
223 Err(e) => {
224 if gss_enc_mode == GssEncMode::Require {
225 return Err(e);
226 }
227 tracing::debug!(
229 host = %host,
230 port = %port,
231 error = %e,
232 "gssenc_prefer_fallthrough"
233 );
234 }
235 }
236 }
237
238 match tls_mode {
240 TlsMode::Disable => {
241 Self::connect_with_password_and_auth_and_gss(ConnectParams {
242 host,
243 port,
244 user,
245 database,
246 password,
247 auth_settings: auth,
248 gss_token_provider,
249 gss_token_provider_ex,
250 startup_params: startup_params.clone(),
251 })
252 .await
253 }
254 TlsMode::Require => {
255 Self::connect_tls_with_auth_and_gss(
256 ConnectParams {
257 host,
258 port,
259 user,
260 database,
261 password,
262 auth_settings: auth,
263 gss_token_provider,
264 gss_token_provider_ex,
265 startup_params: startup_params.clone(),
266 },
267 tls_ca_cert_pem.as_deref(),
268 )
269 .await
270 }
271 TlsMode::Prefer => {
272 match Self::connect_tls_with_auth_and_gss(
273 ConnectParams {
274 host,
275 port,
276 user,
277 database,
278 password,
279 auth_settings: auth,
280 gss_token_provider,
281 gss_token_provider_ex: gss_token_provider_ex.clone(),
282 startup_params: startup_params.clone(),
283 },
284 tls_ca_cert_pem.as_deref(),
285 )
286 .await
287 {
288 Ok(conn) => Ok(conn),
289 Err(PgError::Connection(msg))
290 if msg.contains("Server does not support TLS") =>
291 {
292 Self::connect_with_password_and_auth_and_gss(ConnectParams {
293 host,
294 port,
295 user,
296 database,
297 password,
298 auth_settings: auth,
299 gss_token_provider,
300 gss_token_provider_ex,
301 startup_params: startup_params.clone(),
302 })
303 .await
304 }
305 Err(e) => Err(e),
306 }
307 }
308 }
309 }
310
311 async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
318 tokio::time::timeout(
319 DEFAULT_CONNECT_TIMEOUT,
320 Self::try_gssenc_request_inner(host, port),
321 )
322 .await
323 .map_err(|_| {
324 PgError::Connection(format!(
325 "GSSENCRequest timeout after {:?}",
326 DEFAULT_CONNECT_TIMEOUT
327 ))
328 })?
329 }
330
331 async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
333 use tokio::io::AsyncReadExt;
334
335 let addr = format!("{}:{}", host, port);
336 let mut tcp_stream = TcpStream::connect(&addr).await?;
337 tcp_stream.set_nodelay(true)?;
338
339 tcp_stream.write_all(&GSSENC_REQUEST).await?;
341 tcp_stream.flush().await?;
342
343 let mut response = [0u8; 1];
347 tcp_stream.read_exact(&mut response).await?;
348
349 match response[0] {
350 b'G' => {
351 let mut peek_buf = [0u8; 1];
354 match tcp_stream.try_read(&mut peek_buf) {
355 Ok(0) => {} Ok(_n) => {
357 return Err(PgError::Connection(
359 "Protocol violation: extra bytes after GSSENCRequest 'G' response \
360 (possible CVE-2021-23222 buffer-stuffing attack)"
361 .to_string(),
362 ));
363 }
364 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
365 }
367 Err(e) => {
368 return Err(PgError::Io(e));
369 }
370 }
371 Ok(GssEncNegotiationResult::Accepted(tcp_stream))
372 }
373 b'N' => Ok(GssEncNegotiationResult::Rejected),
374 b'E' => {
375 tracing::trace!(
379 host = %host,
380 port = %port,
381 "gssenc_request_server_error (suppressed per CVE-2024-10977)"
382 );
383 Ok(GssEncNegotiationResult::ServerError)
384 }
385 other => Err(PgError::Connection(format!(
386 "Unexpected response to GSSENCRequest: 0x{:02X} \
387 (expected 'G'=0x47 or 'N'=0x4E)",
388 other
389 ))),
390 }
391 }
392
393 pub async fn connect_with_password_and_auth(
395 host: &str,
396 port: u16,
397 user: &str,
398 database: &str,
399 password: Option<&str>,
400 auth_settings: AuthSettings,
401 ) -> PgResult<Self> {
402 Self::connect_with_password_and_auth_and_gss(ConnectParams {
403 host,
404 port,
405 user,
406 database,
407 password,
408 auth_settings,
409 gss_token_provider: None,
410 gss_token_provider_ex: None,
411 startup_params: Vec::new(),
412 })
413 .await
414 }
415
416 async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
417 let connect_started = Instant::now();
418 let attempt_backend = plain_connect_attempt_backend();
419 record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
420 let result = tokio::time::timeout(
421 DEFAULT_CONNECT_TIMEOUT,
422 Self::connect_with_password_inner(params),
423 )
424 .await
425 .map_err(|_| {
426 PgError::Connection(format!(
427 "Connection timeout after {:?} (TCP connect + handshake)",
428 DEFAULT_CONNECT_TIMEOUT
429 ))
430 })?;
431 let backend = result
432 .as_ref()
433 .map(|conn| connect_backend_for_stream(&conn.stream))
434 .unwrap_or(attempt_backend);
435 record_connect_result(
436 CONNECT_TRANSPORT_PLAIN,
437 backend,
438 &result,
439 connect_started.elapsed(),
440 );
441 result
442 }
443
444 async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
446 let ConnectParams {
447 host,
448 port,
449 user,
450 database,
451 password,
452 auth_settings,
453 gss_token_provider,
454 gss_token_provider_ex,
455 startup_params,
456 } = params;
457 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
458 let addr = format!("{}:{}", host, port);
459 let stream = Self::connect_plain_stream(&addr).await?;
460
461 let mut conn = Self {
462 stream,
463 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
464 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
466 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
468 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
469 column_info_cache: HashMap::new(),
470 process_id: 0,
471 secret_key: 0,
472 notifications: VecDeque::new(),
473 replication_stream_active: false,
474 replication_mode_enabled,
475 last_replication_wal_end: None,
476 io_desynced: false,
477 pending_statement_closes: Vec::new(),
478 draining_statement_closes: false,
479 };
480
481 conn.send(FrontendMessage::Startup {
482 user: user.to_string(),
483 database: database.to_string(),
484 startup_params,
485 })
486 .await?;
487
488 conn.handle_startup(
489 user,
490 password,
491 auth_settings,
492 gss_token_provider,
493 gss_token_provider_ex,
494 )
495 .await?;
496
497 Ok(conn)
498 }
499
500 async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
501 let tcp_stream = TcpStream::connect(addr).await?;
502 tcp_stream.set_nodelay(true)?;
503
504 #[cfg(all(target_os = "linux", feature = "io_uring"))]
505 {
506 if should_try_uring_plain() {
507 match super::super::uring::UringTcpStream::from_tokio(tcp_stream) {
508 Ok(uring_stream) => {
509 tracing::info!(
510 addr = %addr,
511 "qail-pg: using io_uring plain TCP transport"
512 );
513 return Ok(PgStream::Uring(uring_stream));
514 }
515 Err(e) => {
516 tracing::warn!(
517 addr = %addr,
518 error = %e,
519 "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
520 );
521 let fallback = TcpStream::connect(addr).await?;
522 fallback.set_nodelay(true)?;
523 return Ok(PgStream::Tcp(fallback));
524 }
525 }
526 }
527 }
528
529 Ok(PgStream::Tcp(tcp_stream))
530 }
531
532 pub async fn connect_tls(
535 host: &str,
536 port: u16,
537 user: &str,
538 database: &str,
539 password: Option<&str>,
540 ) -> PgResult<Self> {
541 Self::connect_tls_with_auth(
542 host,
543 port,
544 user,
545 database,
546 password,
547 AuthSettings::default(),
548 None,
549 )
550 .await
551 }
552
553 pub async fn connect_tls_with_auth(
555 host: &str,
556 port: u16,
557 user: &str,
558 database: &str,
559 password: Option<&str>,
560 auth_settings: AuthSettings,
561 ca_cert_pem: Option<&[u8]>,
562 ) -> PgResult<Self> {
563 Self::connect_tls_with_auth_and_gss(
564 ConnectParams {
565 host,
566 port,
567 user,
568 database,
569 password,
570 auth_settings,
571 gss_token_provider: None,
572 gss_token_provider_ex: None,
573 startup_params: Vec::new(),
574 },
575 ca_cert_pem,
576 )
577 .await
578 }
579
580 async fn connect_tls_with_auth_and_gss(
581 params: ConnectParams<'_>,
582 ca_cert_pem: Option<&[u8]>,
583 ) -> PgResult<Self> {
584 let connect_started = Instant::now();
585 record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
586 let result = tokio::time::timeout(
587 DEFAULT_CONNECT_TIMEOUT,
588 Self::connect_tls_inner(params, ca_cert_pem),
589 )
590 .await
591 .map_err(|_| {
592 PgError::Connection(format!(
593 "TLS connection timeout after {:?}",
594 DEFAULT_CONNECT_TIMEOUT
595 ))
596 })?;
597 record_connect_result(
598 CONNECT_TRANSPORT_TLS,
599 CONNECT_BACKEND_TOKIO,
600 &result,
601 connect_started.elapsed(),
602 );
603 result
604 }
605
606 async fn connect_tls_inner(
608 params: ConnectParams<'_>,
609 ca_cert_pem: Option<&[u8]>,
610 ) -> PgResult<Self> {
611 let ConnectParams {
612 host,
613 port,
614 user,
615 database,
616 password,
617 auth_settings,
618 gss_token_provider,
619 gss_token_provider_ex,
620 startup_params,
621 } = params;
622 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
623 use tokio::io::AsyncReadExt;
624 use tokio_rustls::TlsConnector;
625 use tokio_rustls::rustls::ClientConfig;
626 use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
627
628 let addr = format!("{}:{}", host, port);
629 let mut tcp_stream = TcpStream::connect(&addr).await?;
630
631 tcp_stream.write_all(&SSL_REQUEST).await?;
633
634 let mut response = [0u8; 1];
636 tcp_stream.read_exact(&mut response).await?;
637
638 if response[0] != b'S' {
639 return Err(PgError::Connection(
640 "Server does not support TLS".to_string(),
641 ));
642 }
643
644 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
645
646 if let Some(ca_pem) = ca_cert_pem {
647 let certs = CertificateDer::pem_slice_iter(ca_pem)
648 .collect::<Result<Vec<_>, _>>()
649 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
650 if certs.is_empty() {
651 return Err(PgError::Connection(
652 "No CA certificates found in provided PEM".to_string(),
653 ));
654 }
655 for cert in certs {
656 let _ = root_cert_store.add(cert);
657 }
658 } else {
659 let certs = rustls_native_certs::load_native_certs();
660 for cert in certs.certs {
661 let _ = root_cert_store.add(cert);
662 }
663 }
664
665 let config = ClientConfig::builder()
666 .with_root_certificates(root_cert_store)
667 .with_no_client_auth();
668
669 let connector = TlsConnector::from(Arc::new(config));
670 let server_name = ServerName::try_from(host.to_string())
671 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
672
673 let tls_stream = connector
674 .connect(server_name, tcp_stream)
675 .await
676 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
677
678 let mut conn = Self {
679 stream: PgStream::Tls(Box::new(tls_stream)),
680 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
681 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
682 sql_buf: BytesMut::with_capacity(512),
683 params_buf: Vec::with_capacity(16),
684 prepared_statements: HashMap::new(),
685 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
686 column_info_cache: HashMap::new(),
687 process_id: 0,
688 secret_key: 0,
689 notifications: VecDeque::new(),
690 replication_stream_active: false,
691 replication_mode_enabled,
692 last_replication_wal_end: None,
693 io_desynced: false,
694 pending_statement_closes: Vec::new(),
695 draining_statement_closes: false,
696 };
697
698 conn.send(FrontendMessage::Startup {
699 user: user.to_string(),
700 database: database.to_string(),
701 startup_params,
702 })
703 .await?;
704
705 conn.handle_startup(
706 user,
707 password,
708 auth_settings,
709 gss_token_provider,
710 gss_token_provider_ex,
711 )
712 .await?;
713
714 Ok(conn)
715 }
716
717 pub async fn connect_mtls(
734 host: &str,
735 port: u16,
736 user: &str,
737 database: &str,
738 config: TlsConfig,
739 ) -> PgResult<Self> {
740 Self::connect_mtls_with_password_and_auth(
741 host,
742 port,
743 user,
744 database,
745 None,
746 config,
747 AuthSettings::default(),
748 )
749 .await
750 }
751
752 pub async fn connect_mtls_with_password_and_auth(
754 host: &str,
755 port: u16,
756 user: &str,
757 database: &str,
758 password: Option<&str>,
759 config: TlsConfig,
760 auth_settings: AuthSettings,
761 ) -> PgResult<Self> {
762 Self::connect_mtls_with_password_and_auth_and_gss(
763 ConnectParams {
764 host,
765 port,
766 user,
767 database,
768 password,
769 auth_settings,
770 gss_token_provider: None,
771 gss_token_provider_ex: None,
772 startup_params: Vec::new(),
773 },
774 config,
775 )
776 .await
777 }
778
779 async fn connect_mtls_with_password_and_auth_and_gss(
780 params: ConnectParams<'_>,
781 config: TlsConfig,
782 ) -> PgResult<Self> {
783 let connect_started = Instant::now();
784 record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
785 let result = tokio::time::timeout(
786 DEFAULT_CONNECT_TIMEOUT,
787 Self::connect_mtls_inner(params, config),
788 )
789 .await
790 .map_err(|_| {
791 PgError::Connection(format!(
792 "mTLS connection timeout after {:?}",
793 DEFAULT_CONNECT_TIMEOUT
794 ))
795 })?;
796 record_connect_result(
797 CONNECT_TRANSPORT_MTLS,
798 CONNECT_BACKEND_TOKIO,
799 &result,
800 connect_started.elapsed(),
801 );
802 result
803 }
804
805 async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
807 let ConnectParams {
808 host,
809 port,
810 user,
811 database,
812 password,
813 auth_settings,
814 gss_token_provider,
815 gss_token_provider_ex,
816 startup_params,
817 } = params;
818 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
819 use tokio::io::AsyncReadExt;
820 use tokio_rustls::TlsConnector;
821 use tokio_rustls::rustls::{
822 ClientConfig,
823 pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
824 };
825
826 let addr = format!("{}:{}", host, port);
827 let mut tcp_stream = TcpStream::connect(&addr).await?;
828
829 tcp_stream.write_all(&SSL_REQUEST).await?;
831
832 let mut response = [0u8; 1];
834 tcp_stream.read_exact(&mut response).await?;
835
836 if response[0] != b'S' {
837 return Err(PgError::Connection(
838 "Server does not support TLS".to_string(),
839 ));
840 }
841
842 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
843
844 if let Some(ca_pem) = &config.ca_cert_pem {
845 let certs = CertificateDer::pem_slice_iter(ca_pem)
846 .collect::<Result<Vec<_>, _>>()
847 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
848 if certs.is_empty() {
849 return Err(PgError::Connection(
850 "No CA certificates found in provided PEM".to_string(),
851 ));
852 }
853 for cert in certs {
854 let _ = root_cert_store.add(cert);
855 }
856 } else {
857 let certs = rustls_native_certs::load_native_certs();
859 for cert in certs.certs {
860 let _ = root_cert_store.add(cert);
861 }
862 }
863
864 let client_certs: Vec<CertificateDer<'static>> =
865 CertificateDer::pem_slice_iter(&config.client_cert_pem)
866 .collect::<Result<Vec<_>, _>>()
867 .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
868 if client_certs.is_empty() {
869 return Err(PgError::Connection(
870 "No client certificates found in PEM".to_string(),
871 ));
872 }
873
874 let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
875 .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
876
877 let tls_config = ClientConfig::builder()
878 .with_root_certificates(root_cert_store)
879 .with_client_auth_cert(client_certs, client_key)
880 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
881
882 let connector = TlsConnector::from(Arc::new(tls_config));
883 let server_name = ServerName::try_from(host.to_string())
884 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
885
886 let tls_stream = connector
887 .connect(server_name, tcp_stream)
888 .await
889 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
890
891 let mut conn = Self {
892 stream: PgStream::Tls(Box::new(tls_stream)),
893 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
894 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
895 sql_buf: BytesMut::with_capacity(512),
896 params_buf: Vec::with_capacity(16),
897 prepared_statements: HashMap::new(),
898 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
899 column_info_cache: HashMap::new(),
900 process_id: 0,
901 secret_key: 0,
902 notifications: VecDeque::new(),
903 replication_stream_active: false,
904 replication_mode_enabled,
905 last_replication_wal_end: None,
906 io_desynced: false,
907 pending_statement_closes: Vec::new(),
908 draining_statement_closes: false,
909 };
910
911 conn.send(FrontendMessage::Startup {
912 user: user.to_string(),
913 database: database.to_string(),
914 startup_params,
915 })
916 .await?;
917
918 conn.handle_startup(
919 user,
920 password,
921 auth_settings,
922 gss_token_provider,
923 gss_token_provider_ex,
924 )
925 .await?;
926
927 Ok(conn)
928 }
929
930 #[cfg(unix)]
932 pub async fn connect_unix(
933 socket_path: &str,
934 user: &str,
935 database: &str,
936 password: Option<&str>,
937 ) -> PgResult<Self> {
938 use tokio::net::UnixStream;
939
940 let unix_stream = UnixStream::connect(socket_path).await?;
941
942 let mut conn = Self {
943 stream: PgStream::Unix(unix_stream),
944 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
945 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
946 sql_buf: BytesMut::with_capacity(512),
947 params_buf: Vec::with_capacity(16),
948 prepared_statements: HashMap::new(),
949 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
950 column_info_cache: HashMap::new(),
951 process_id: 0,
952 secret_key: 0,
953 notifications: VecDeque::new(),
954 replication_stream_active: false,
955 replication_mode_enabled: false,
956 last_replication_wal_end: None,
957 io_desynced: false,
958 pending_statement_closes: Vec::new(),
959 draining_statement_closes: false,
960 };
961
962 conn.send(FrontendMessage::Startup {
963 user: user.to_string(),
964 database: database.to_string(),
965 startup_params: Vec::new(),
966 })
967 .await?;
968
969 conn.handle_startup(user, password, AuthSettings::default(), None, None)
970 .await?;
971
972 Ok(conn)
973 }
974}