1#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6 connect_backend_for_stream, plain_connect_attempt_backend, record_connect_attempt,
7 record_connect_result,
8};
9use super::types::{
10 BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
11 CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
12 GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
13 StatementCache, TlsConfig, has_logical_replication_startup_mode,
14};
15use crate::driver::stream::PgStream;
16use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
17use crate::protocol::PROTOCOL_VERSION_3_0;
18use crate::protocol::wire::FrontendMessage;
19use bytes::BytesMut;
20use std::collections::{HashMap, VecDeque};
21use std::sync::Arc;
22use std::time::Instant;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
/// Packs a protocol version for major version 3 into the 32-bit StartupMessage form.
///
/// The major version occupies the high 16 bits and `minor` the low 16 bits, so
/// `protocol_version_from_minor(0)` is `196608` (3.0).
#[inline]
fn protocol_version_from_minor(minor: u16) -> i32 {
    const MAJOR: i32 = 3;
    let high = MAJOR << 16;
    high | i32::from(minor)
}
30
/// Formats `host` and `port` as a `host:port` string for `TcpStream::connect`.
///
/// A host that contains `:` and is not already wrapped in `[` is treated as a bare
/// IPv6 literal and bracketed, so the port separator stays unambiguous.
fn socket_addr(host: &str, port: u16) -> String {
    let is_bare_ipv6 = !host.starts_with('[') && host.contains(':');
    if !is_bare_ipv6 {
        return format!("{}:{}", host, port);
    }
    format!("[{}]:{}", host, port)
}
38
39fn is_explicit_protocol_version_rejection(err: &PgError) -> bool {
40 let msg = match err {
41 PgError::Connection(msg) | PgError::Protocol(msg) | PgError::Auth(msg) => msg,
42 PgError::Query(msg) => msg,
43 PgError::QueryServer(server) => &server.message,
44 _ => return false,
45 };
46
47 let lower = msg.to_ascii_lowercase();
48 lower.contains("unsupported frontend protocol")
49 || lower.contains("frontend protocol") && lower.contains("unsupported")
50 || lower.contains("protocol version") && lower.contains("not support")
51}
52
53impl PgConnection {
54 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
63 Self::connect_with_password(host, port, user, database, None).await
64 }
65
66 pub async fn connect_with_password(
73 host: &str,
74 port: u16,
75 user: &str,
76 database: &str,
77 password: Option<&str>,
78 ) -> PgResult<Self> {
79 Self::connect_with_password_and_auth(
80 host,
81 port,
82 user,
83 database,
84 password,
85 AuthSettings::default(),
86 )
87 .await
88 }
89
90 pub async fn connect_with_options(
101 host: &str,
102 port: u16,
103 user: &str,
104 database: &str,
105 password: Option<&str>,
106 options: ConnectOptions,
107 ) -> PgResult<Self> {
108 let ConnectOptions {
109 tls_mode,
110 gss_enc_mode,
111 tls_ca_cert_pem,
112 mtls,
113 gss_token_provider,
114 gss_token_provider_ex,
115 auth,
116 startup_params,
117 } = options;
118
119 if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
120 return Err(PgError::Connection(
121 "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
122 ));
123 }
124
125 if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
129 return Err(PgError::Connection(
130 "gssencmode=require is incompatible with mTLS — both provide \
131 transport encryption; use one or the other"
132 .to_string(),
133 ));
134 }
135
136 if let Some(mtls_config) = mtls {
137 return Self::connect_mtls_with_password_and_auth_and_gss(
140 ConnectParams {
141 host,
142 port,
143 user,
144 database,
145 password,
146 auth_settings: auth,
147 gss_token_provider,
148 gss_token_provider_ex,
149 protocol_minor: Self::default_protocol_minor(),
150 startup_params: startup_params.clone(),
151 },
152 mtls_config,
153 )
154 .await;
155 }
156
157 if gss_enc_mode != GssEncMode::Disable {
159 match Self::try_gssenc_request(host, port).await {
160 Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
161 let connect_started = Instant::now();
162 record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
163 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
164 {
165 let default_minor = Self::default_protocol_minor();
166 let gss_params = ConnectParams {
167 host,
168 port,
169 user,
170 database,
171 password,
172 auth_settings: auth,
173 gss_token_provider,
174 gss_token_provider_ex: gss_token_provider_ex.clone(),
175 protocol_minor: default_minor,
176 startup_params: startup_params.clone(),
177 };
178 let mut result = Self::connect_gssenc_accepted_with_timeout(
179 tcp_stream,
180 gss_params.clone(),
181 )
182 .await;
183 if let Err(err) = &result
184 && default_minor > 0
185 && is_explicit_protocol_version_rejection(err)
186 {
187 let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
188 let retry_stream = match Self::try_gssenc_request(host, port).await {
189 Ok(GssEncNegotiationResult::Accepted(stream)) => stream,
190 Ok(GssEncNegotiationResult::Rejected) => {
191 return Err(PgError::Connection(
192 "Protocol downgrade retry failed: server rejected GSSENCRequest"
193 .to_string(),
194 ));
195 }
196 Ok(GssEncNegotiationResult::ServerError) => {
197 return Err(PgError::Connection(
198 "Protocol downgrade retry failed: server returned error to GSSENCRequest"
199 .to_string(),
200 ));
201 }
202 Err(e) => {
203 return Err(e);
204 }
205 };
206 let mut retry_params = gss_params;
207 retry_params.protocol_minor = downgrade_minor;
208 result = Self::connect_gssenc_accepted_with_timeout(
209 retry_stream,
210 retry_params,
211 )
212 .await;
213 }
214 record_connect_result(
215 CONNECT_TRANSPORT_GSSENC,
216 CONNECT_BACKEND_TOKIO,
217 &result,
218 connect_started.elapsed(),
219 );
220 return result;
221 }
222 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
223 {
224 let _ = tcp_stream;
225 let err = PgError::Connection(
226 "Server accepted GSSENCRequest but GSSAPI encryption requires \
227 feature enterprise-gssapi on Linux"
228 .to_string(),
229 );
230 metrics::histogram!(
231 "qail_pg_connect_duration_seconds",
232 "transport" => CONNECT_TRANSPORT_GSSENC,
233 "backend" => CONNECT_BACKEND_TOKIO,
234 "outcome" => "error"
235 )
236 .record(connect_started.elapsed().as_secs_f64());
237 metrics::counter!(
238 "qail_pg_connect_failure_total",
239 "transport" => CONNECT_TRANSPORT_GSSENC,
240 "backend" => CONNECT_BACKEND_TOKIO,
241 "error_kind" => super::helpers::connect_error_kind(&err)
242 )
243 .increment(1);
244 return Err(err);
245 }
246 }
247 Ok(GssEncNegotiationResult::Rejected)
248 | Ok(GssEncNegotiationResult::ServerError) => {
249 if gss_enc_mode == GssEncMode::Require {
250 return Err(PgError::Connection(
251 "gssencmode=require but server rejected GSSENCRequest".to_string(),
252 ));
253 }
254 }
256 Err(e) => {
257 if gss_enc_mode == GssEncMode::Require {
258 return Err(e);
259 }
260 tracing::debug!(
262 host = %host,
263 port = %port,
264 error = %e,
265 "gssenc_prefer_fallthrough"
266 );
267 }
268 }
269 }
270
271 match tls_mode {
273 TlsMode::Disable => {
274 Self::connect_with_password_and_auth_and_gss(ConnectParams {
275 host,
276 port,
277 user,
278 database,
279 password,
280 auth_settings: auth,
281 gss_token_provider,
282 gss_token_provider_ex,
283 protocol_minor: Self::default_protocol_minor(),
284 startup_params: startup_params.clone(),
285 })
286 .await
287 }
288 TlsMode::Require => {
289 Self::connect_tls_with_auth_and_gss(
290 ConnectParams {
291 host,
292 port,
293 user,
294 database,
295 password,
296 auth_settings: auth,
297 gss_token_provider,
298 gss_token_provider_ex,
299 protocol_minor: Self::default_protocol_minor(),
300 startup_params: startup_params.clone(),
301 },
302 tls_ca_cert_pem.as_deref(),
303 )
304 .await
305 }
306 TlsMode::Prefer => {
307 match Self::connect_tls_with_auth_and_gss(
308 ConnectParams {
309 host,
310 port,
311 user,
312 database,
313 password,
314 auth_settings: auth,
315 gss_token_provider,
316 gss_token_provider_ex: gss_token_provider_ex.clone(),
317 protocol_minor: Self::default_protocol_minor(),
318 startup_params: startup_params.clone(),
319 },
320 tls_ca_cert_pem.as_deref(),
321 )
322 .await
323 {
324 Ok(conn) => Ok(conn),
325 Err(PgError::Connection(msg))
326 if msg.contains("Server does not support TLS") =>
327 {
328 Self::connect_with_password_and_auth_and_gss(ConnectParams {
329 host,
330 port,
331 user,
332 database,
333 password,
334 auth_settings: auth,
335 gss_token_provider,
336 gss_token_provider_ex,
337 protocol_minor: Self::default_protocol_minor(),
338 startup_params: startup_params.clone(),
339 })
340 .await
341 }
342 Err(e) => Err(e),
343 }
344 }
345 }
346 }
347
348 async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
355 tokio::time::timeout(
356 DEFAULT_CONNECT_TIMEOUT,
357 Self::try_gssenc_request_inner(host, port),
358 )
359 .await
360 .map_err(|_| {
361 PgError::Connection(format!(
362 "GSSENCRequest timeout after {:?}",
363 DEFAULT_CONNECT_TIMEOUT
364 ))
365 })?
366 }
367
368 async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
370 use tokio::io::AsyncReadExt;
371
372 let addr = socket_addr(host, port);
373 let mut tcp_stream = TcpStream::connect(&addr).await?;
374 tcp_stream.set_nodelay(true)?;
375
376 tcp_stream.write_all(&GSSENC_REQUEST).await?;
378 tcp_stream.flush().await?;
379
380 let mut response = [0u8; 1];
384 tcp_stream.read_exact(&mut response).await?;
385
386 match response[0] {
387 b'G' => {
388 let mut peek_buf = [0u8; 1];
391 match tcp_stream.try_read(&mut peek_buf) {
392 Ok(0) => {} Ok(_n) => {
394 return Err(PgError::Connection(
396 "Protocol violation: extra bytes after GSSENCRequest 'G' response \
397 (possible CVE-2021-23222 buffer-stuffing attack)"
398 .to_string(),
399 ));
400 }
401 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
402 }
404 Err(e) => {
405 return Err(PgError::Io(e));
406 }
407 }
408 Ok(GssEncNegotiationResult::Accepted(tcp_stream))
409 }
410 b'N' => Ok(GssEncNegotiationResult::Rejected),
411 b'E' => {
412 tracing::trace!(
416 host = %host,
417 port = %port,
418 "gssenc_request_server_error (suppressed per CVE-2024-10977)"
419 );
420 Ok(GssEncNegotiationResult::ServerError)
421 }
422 other => Err(PgError::Connection(format!(
423 "Unexpected response to GSSENCRequest: 0x{:02X} \
424 (expected 'G'=0x47 or 'N'=0x4E)",
425 other
426 ))),
427 }
428 }
429
430 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
431 async fn connect_gssenc_accepted_with_timeout(
432 tcp_stream: TcpStream,
433 params: ConnectParams<'_>,
434 ) -> PgResult<Self> {
435 let gssenc_fut = async {
436 let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, params.host)
437 .await
438 .map_err(PgError::Auth)?;
439 let mut conn = Self {
440 stream: PgStream::GssEnc(gss_stream),
441 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
442 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
443 sql_buf: BytesMut::with_capacity(512),
444 params_buf: Vec::with_capacity(16),
445 prepared_statements: HashMap::new(),
446 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
447 column_info_cache: HashMap::new(),
448 process_id: 0,
449 cancel_key_bytes: Vec::new(),
450 requested_protocol_minor: params.protocol_minor,
451 negotiated_protocol_minor: params.protocol_minor,
452 notifications: VecDeque::new(),
453 replication_stream_active: false,
454 replication_mode_enabled: has_logical_replication_startup_mode(
455 ¶ms.startup_params,
456 ),
457 last_replication_wal_end: None,
458 io_desynced: false,
459 pending_statement_closes: Vec::new(),
460 draining_statement_closes: false,
461 };
462 conn.send(FrontendMessage::Startup {
463 user: params.user.to_string(),
464 database: params.database.to_string(),
465 protocol_version: protocol_version_from_minor(params.protocol_minor),
466 startup_params: params.startup_params.clone(),
467 })
468 .await?;
469 conn.handle_startup(
470 params.user,
471 params.password,
472 params.auth_settings,
473 params.gss_token_provider,
474 params.gss_token_provider_ex,
475 )
476 .await?;
477 Ok(conn)
478 };
479 tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
480 .await
481 .map_err(|_| {
482 PgError::Connection(format!(
483 "GSSENC connection timeout after {:?} (handshake + auth)",
484 DEFAULT_CONNECT_TIMEOUT
485 ))
486 })?
487 }
488
489 pub async fn connect_with_password_and_auth(
491 host: &str,
492 port: u16,
493 user: &str,
494 database: &str,
495 password: Option<&str>,
496 auth_settings: AuthSettings,
497 ) -> PgResult<Self> {
498 Self::connect_with_password_and_auth_and_gss(ConnectParams {
499 host,
500 port,
501 user,
502 database,
503 password,
504 auth_settings,
505 gss_token_provider: None,
506 gss_token_provider_ex: None,
507 protocol_minor: Self::default_protocol_minor(),
508 startup_params: Vec::new(),
509 })
510 .await
511 }
512
513 async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
514 let first = Self::connect_with_password_and_auth_and_gss_once(params.clone()).await;
515 if let Err(err) = &first
516 && params.protocol_minor > 0
517 && is_explicit_protocol_version_rejection(err)
518 {
519 let mut downgraded = params;
520 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
521 return Self::connect_with_password_and_auth_and_gss_once(downgraded).await;
522 }
523 first
524 }
525
526 async fn connect_with_password_and_auth_and_gss_once(
527 params: ConnectParams<'_>,
528 ) -> PgResult<Self> {
529 let connect_started = Instant::now();
530 let attempt_backend = plain_connect_attempt_backend();
531 record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
532 let result = tokio::time::timeout(
533 DEFAULT_CONNECT_TIMEOUT,
534 Self::connect_with_password_inner(params),
535 )
536 .await
537 .map_err(|_| {
538 PgError::Connection(format!(
539 "Connection timeout after {:?} (TCP connect + handshake)",
540 DEFAULT_CONNECT_TIMEOUT
541 ))
542 })?;
543 let backend = result
544 .as_ref()
545 .map(|conn| connect_backend_for_stream(&conn.stream))
546 .unwrap_or(attempt_backend);
547 record_connect_result(
548 CONNECT_TRANSPORT_PLAIN,
549 backend,
550 &result,
551 connect_started.elapsed(),
552 );
553 result
554 }
555
556 async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
558 let ConnectParams {
559 host,
560 port,
561 user,
562 database,
563 password,
564 auth_settings,
565 gss_token_provider,
566 gss_token_provider_ex,
567 protocol_minor,
568 startup_params,
569 } = params;
570 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
571 let addr = socket_addr(host, port);
572 let stream = Self::connect_plain_stream(&addr).await?;
573
574 let mut conn = Self {
575 stream,
576 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
577 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
579 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
581 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
582 column_info_cache: HashMap::new(),
583 process_id: 0,
584 cancel_key_bytes: Vec::new(),
585 requested_protocol_minor: protocol_minor,
586 negotiated_protocol_minor: protocol_minor,
587 notifications: VecDeque::new(),
588 replication_stream_active: false,
589 replication_mode_enabled,
590 last_replication_wal_end: None,
591 io_desynced: false,
592 pending_statement_closes: Vec::new(),
593 draining_statement_closes: false,
594 };
595
596 conn.send(FrontendMessage::Startup {
597 user: user.to_string(),
598 database: database.to_string(),
599 protocol_version: protocol_version_from_minor(protocol_minor),
600 startup_params,
601 })
602 .await?;
603
604 conn.handle_startup(
605 user,
606 password,
607 auth_settings,
608 gss_token_provider,
609 gss_token_provider_ex,
610 )
611 .await?;
612
613 Ok(conn)
614 }
615
616 async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
617 let tcp_stream = TcpStream::connect(addr).await?;
618 tcp_stream.set_nodelay(true)?;
619
620 #[cfg(all(target_os = "linux", feature = "io_uring"))]
621 {
622 if should_try_uring_plain() {
623 let std_stream = tcp_stream.into_std()?;
624 let fallback_std = std_stream.try_clone()?;
625 match super::super::uring::UringTcpStream::from_std(std_stream) {
626 Ok(uring_stream) => {
627 tracing::info!(
628 addr = %addr,
629 "qail-pg: using io_uring plain TCP transport"
630 );
631 return Ok(PgStream::Uring(uring_stream));
632 }
633 Err(e) => {
634 tracing::warn!(
635 addr = %addr,
636 error = %e,
637 "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
638 );
639 fallback_std.set_nonblocking(true)?;
640 let fallback = TcpStream::from_std(fallback_std)?;
641 return Ok(PgStream::Tcp(fallback));
642 }
643 }
644 }
645 }
646
647 Ok(PgStream::Tcp(tcp_stream))
648 }
649
650 pub async fn connect_tls(
653 host: &str,
654 port: u16,
655 user: &str,
656 database: &str,
657 password: Option<&str>,
658 ) -> PgResult<Self> {
659 Self::connect_tls_with_auth(
660 host,
661 port,
662 user,
663 database,
664 password,
665 AuthSettings::default(),
666 None,
667 )
668 .await
669 }
670
671 pub async fn connect_tls_with_auth(
673 host: &str,
674 port: u16,
675 user: &str,
676 database: &str,
677 password: Option<&str>,
678 auth_settings: AuthSettings,
679 ca_cert_pem: Option<&[u8]>,
680 ) -> PgResult<Self> {
681 Self::connect_tls_with_auth_and_gss(
682 ConnectParams {
683 host,
684 port,
685 user,
686 database,
687 password,
688 auth_settings,
689 gss_token_provider: None,
690 gss_token_provider_ex: None,
691 protocol_minor: Self::default_protocol_minor(),
692 startup_params: Vec::new(),
693 },
694 ca_cert_pem,
695 )
696 .await
697 }
698
699 async fn connect_tls_with_auth_and_gss(
700 params: ConnectParams<'_>,
701 ca_cert_pem: Option<&[u8]>,
702 ) -> PgResult<Self> {
703 let first = Self::connect_tls_with_auth_and_gss_once(params.clone(), ca_cert_pem).await;
704 if let Err(err) = &first
705 && params.protocol_minor > 0
706 && is_explicit_protocol_version_rejection(err)
707 {
708 let mut downgraded = params;
709 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
710 return Self::connect_tls_with_auth_and_gss_once(downgraded, ca_cert_pem).await;
711 }
712 first
713 }
714
715 async fn connect_tls_with_auth_and_gss_once(
716 params: ConnectParams<'_>,
717 ca_cert_pem: Option<&[u8]>,
718 ) -> PgResult<Self> {
719 let connect_started = Instant::now();
720 record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
721 let result = tokio::time::timeout(
722 DEFAULT_CONNECT_TIMEOUT,
723 Self::connect_tls_inner(params, ca_cert_pem),
724 )
725 .await
726 .map_err(|_| {
727 PgError::Connection(format!(
728 "TLS connection timeout after {:?}",
729 DEFAULT_CONNECT_TIMEOUT
730 ))
731 })?;
732 record_connect_result(
733 CONNECT_TRANSPORT_TLS,
734 CONNECT_BACKEND_TOKIO,
735 &result,
736 connect_started.elapsed(),
737 );
738 result
739 }
740
741 async fn connect_tls_inner(
743 params: ConnectParams<'_>,
744 ca_cert_pem: Option<&[u8]>,
745 ) -> PgResult<Self> {
746 let ConnectParams {
747 host,
748 port,
749 user,
750 database,
751 password,
752 auth_settings,
753 gss_token_provider,
754 gss_token_provider_ex,
755 protocol_minor,
756 startup_params,
757 } = params;
758 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
759 use tokio::io::AsyncReadExt;
760 use tokio_rustls::TlsConnector;
761 use tokio_rustls::rustls::ClientConfig;
762 use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
763
764 let addr = socket_addr(host, port);
765 let mut tcp_stream = TcpStream::connect(&addr).await?;
766
767 tcp_stream.write_all(&SSL_REQUEST).await?;
769
770 let mut response = [0u8; 1];
772 tcp_stream.read_exact(&mut response).await?;
773
774 if response[0] != b'S' {
775 return Err(PgError::Connection(
776 "Server does not support TLS".to_string(),
777 ));
778 }
779
780 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
781
782 if let Some(ca_pem) = ca_cert_pem {
783 let certs = CertificateDer::pem_slice_iter(ca_pem)
784 .collect::<Result<Vec<_>, _>>()
785 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
786 if certs.is_empty() {
787 return Err(PgError::Connection(
788 "No CA certificates found in provided PEM".to_string(),
789 ));
790 }
791 for cert in certs {
792 let _ = root_cert_store.add(cert);
793 }
794 } else {
795 let certs = rustls_native_certs::load_native_certs();
796 for cert in certs.certs {
797 let _ = root_cert_store.add(cert);
798 }
799 }
800
801 let config = ClientConfig::builder()
802 .with_root_certificates(root_cert_store)
803 .with_no_client_auth();
804
805 let connector = TlsConnector::from(Arc::new(config));
806 let server_name = ServerName::try_from(host.to_string())
807 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
808
809 let tls_stream = connector
810 .connect(server_name, tcp_stream)
811 .await
812 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
813
814 let mut conn = Self {
815 stream: PgStream::Tls(Box::new(tls_stream)),
816 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
817 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
818 sql_buf: BytesMut::with_capacity(512),
819 params_buf: Vec::with_capacity(16),
820 prepared_statements: HashMap::new(),
821 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
822 column_info_cache: HashMap::new(),
823 process_id: 0,
824 cancel_key_bytes: Vec::new(),
825 requested_protocol_minor: protocol_minor,
826 negotiated_protocol_minor: protocol_minor,
827 notifications: VecDeque::new(),
828 replication_stream_active: false,
829 replication_mode_enabled,
830 last_replication_wal_end: None,
831 io_desynced: false,
832 pending_statement_closes: Vec::new(),
833 draining_statement_closes: false,
834 };
835
836 conn.send(FrontendMessage::Startup {
837 user: user.to_string(),
838 database: database.to_string(),
839 protocol_version: protocol_version_from_minor(protocol_minor),
840 startup_params,
841 })
842 .await?;
843
844 conn.handle_startup(
845 user,
846 password,
847 auth_settings,
848 gss_token_provider,
849 gss_token_provider_ex,
850 )
851 .await?;
852
853 Ok(conn)
854 }
855
856 pub async fn connect_mtls(
873 host: &str,
874 port: u16,
875 user: &str,
876 database: &str,
877 config: TlsConfig,
878 ) -> PgResult<Self> {
879 Self::connect_mtls_with_password_and_auth(
880 host,
881 port,
882 user,
883 database,
884 None,
885 config,
886 AuthSettings::default(),
887 )
888 .await
889 }
890
891 pub async fn connect_mtls_with_password_and_auth(
893 host: &str,
894 port: u16,
895 user: &str,
896 database: &str,
897 password: Option<&str>,
898 config: TlsConfig,
899 auth_settings: AuthSettings,
900 ) -> PgResult<Self> {
901 Self::connect_mtls_with_password_and_auth_and_gss(
902 ConnectParams {
903 host,
904 port,
905 user,
906 database,
907 password,
908 auth_settings,
909 gss_token_provider: None,
910 gss_token_provider_ex: None,
911 protocol_minor: Self::default_protocol_minor(),
912 startup_params: Vec::new(),
913 },
914 config,
915 )
916 .await
917 }
918
919 async fn connect_mtls_with_password_and_auth_and_gss(
920 params: ConnectParams<'_>,
921 config: TlsConfig,
922 ) -> PgResult<Self> {
923 let first =
924 Self::connect_mtls_with_password_and_auth_and_gss_once(params.clone(), config.clone())
925 .await;
926 if let Err(err) = &first
927 && params.protocol_minor > 0
928 && is_explicit_protocol_version_rejection(err)
929 {
930 let mut downgraded = params;
931 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
932 return Self::connect_mtls_with_password_and_auth_and_gss_once(downgraded, config)
933 .await;
934 }
935 first
936 }
937
938 async fn connect_mtls_with_password_and_auth_and_gss_once(
939 params: ConnectParams<'_>,
940 config: TlsConfig,
941 ) -> PgResult<Self> {
942 let connect_started = Instant::now();
943 record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
944 let result = tokio::time::timeout(
945 DEFAULT_CONNECT_TIMEOUT,
946 Self::connect_mtls_inner(params, config),
947 )
948 .await
949 .map_err(|_| {
950 PgError::Connection(format!(
951 "mTLS connection timeout after {:?}",
952 DEFAULT_CONNECT_TIMEOUT
953 ))
954 })?;
955 record_connect_result(
956 CONNECT_TRANSPORT_MTLS,
957 CONNECT_BACKEND_TOKIO,
958 &result,
959 connect_started.elapsed(),
960 );
961 result
962 }
963
964 async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
966 let ConnectParams {
967 host,
968 port,
969 user,
970 database,
971 password,
972 auth_settings,
973 gss_token_provider,
974 gss_token_provider_ex,
975 protocol_minor,
976 startup_params,
977 } = params;
978 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
979 use tokio::io::AsyncReadExt;
980 use tokio_rustls::TlsConnector;
981 use tokio_rustls::rustls::{
982 ClientConfig,
983 pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
984 };
985
986 let addr = socket_addr(host, port);
987 let mut tcp_stream = TcpStream::connect(&addr).await?;
988
989 tcp_stream.write_all(&SSL_REQUEST).await?;
991
992 let mut response = [0u8; 1];
994 tcp_stream.read_exact(&mut response).await?;
995
996 if response[0] != b'S' {
997 return Err(PgError::Connection(
998 "Server does not support TLS".to_string(),
999 ));
1000 }
1001
1002 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
1003
1004 if let Some(ca_pem) = &config.ca_cert_pem {
1005 let certs = CertificateDer::pem_slice_iter(ca_pem)
1006 .collect::<Result<Vec<_>, _>>()
1007 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
1008 if certs.is_empty() {
1009 return Err(PgError::Connection(
1010 "No CA certificates found in provided PEM".to_string(),
1011 ));
1012 }
1013 for cert in certs {
1014 let _ = root_cert_store.add(cert);
1015 }
1016 } else {
1017 let certs = rustls_native_certs::load_native_certs();
1019 for cert in certs.certs {
1020 let _ = root_cert_store.add(cert);
1021 }
1022 }
1023
1024 let client_certs: Vec<CertificateDer<'static>> =
1025 CertificateDer::pem_slice_iter(&config.client_cert_pem)
1026 .collect::<Result<Vec<_>, _>>()
1027 .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
1028 if client_certs.is_empty() {
1029 return Err(PgError::Connection(
1030 "No client certificates found in PEM".to_string(),
1031 ));
1032 }
1033
1034 let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
1035 .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
1036
1037 let tls_config = ClientConfig::builder()
1038 .with_root_certificates(root_cert_store)
1039 .with_client_auth_cert(client_certs, client_key)
1040 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
1041
1042 let connector = TlsConnector::from(Arc::new(tls_config));
1043 let server_name = ServerName::try_from(host.to_string())
1044 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
1045
1046 let tls_stream = connector
1047 .connect(server_name, tcp_stream)
1048 .await
1049 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
1050
1051 let mut conn = Self {
1052 stream: PgStream::Tls(Box::new(tls_stream)),
1053 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1054 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1055 sql_buf: BytesMut::with_capacity(512),
1056 params_buf: Vec::with_capacity(16),
1057 prepared_statements: HashMap::new(),
1058 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1059 column_info_cache: HashMap::new(),
1060 process_id: 0,
1061 cancel_key_bytes: Vec::new(),
1062 requested_protocol_minor: protocol_minor,
1063 negotiated_protocol_minor: protocol_minor,
1064 notifications: VecDeque::new(),
1065 replication_stream_active: false,
1066 replication_mode_enabled,
1067 last_replication_wal_end: None,
1068 io_desynced: false,
1069 pending_statement_closes: Vec::new(),
1070 draining_statement_closes: false,
1071 };
1072
1073 conn.send(FrontendMessage::Startup {
1074 user: user.to_string(),
1075 database: database.to_string(),
1076 protocol_version: protocol_version_from_minor(protocol_minor),
1077 startup_params,
1078 })
1079 .await?;
1080
1081 conn.handle_startup(
1082 user,
1083 password,
1084 auth_settings,
1085 gss_token_provider,
1086 gss_token_provider_ex,
1087 )
1088 .await?;
1089
1090 Ok(conn)
1091 }
1092
1093 #[cfg(unix)]
1095 pub async fn connect_unix(
1096 socket_path: &str,
1097 user: &str,
1098 database: &str,
1099 password: Option<&str>,
1100 ) -> PgResult<Self> {
1101 let default_minor = Self::default_protocol_minor();
1102 let first =
1103 Self::connect_unix_with_protocol(socket_path, user, database, password, default_minor)
1104 .await;
1105 if let Err(err) = &first
1106 && default_minor > 0
1107 && is_explicit_protocol_version_rejection(err)
1108 {
1109 let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
1110 return Self::connect_unix_with_protocol(
1111 socket_path,
1112 user,
1113 database,
1114 password,
1115 downgrade_minor,
1116 )
1117 .await;
1118 }
1119 first
1120 }
1121
1122 #[cfg(unix)]
1123 async fn connect_unix_with_protocol(
1124 socket_path: &str,
1125 user: &str,
1126 database: &str,
1127 password: Option<&str>,
1128 protocol_minor: u16,
1129 ) -> PgResult<Self> {
1130 use tokio::net::UnixStream;
1131
1132 let unix_stream = UnixStream::connect(socket_path).await?;
1133
1134 let mut conn = Self {
1135 stream: PgStream::Unix(unix_stream),
1136 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1137 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1138 sql_buf: BytesMut::with_capacity(512),
1139 params_buf: Vec::with_capacity(16),
1140 prepared_statements: HashMap::new(),
1141 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1142 column_info_cache: HashMap::new(),
1143 process_id: 0,
1144 cancel_key_bytes: Vec::new(),
1145 requested_protocol_minor: protocol_minor,
1146 negotiated_protocol_minor: protocol_minor,
1147 notifications: VecDeque::new(),
1148 replication_stream_active: false,
1149 replication_mode_enabled: false,
1150 last_replication_wal_end: None,
1151 io_desynced: false,
1152 pending_statement_closes: Vec::new(),
1153 draining_statement_closes: false,
1154 };
1155
1156 conn.send(FrontendMessage::Startup {
1157 user: user.to_string(),
1158 database: database.to_string(),
1159 protocol_version: protocol_version_from_minor(protocol_minor),
1160 startup_params: Vec::new(),
1161 })
1162 .await?;
1163
1164 conn.handle_startup(user, password, AuthSettings::default(), None, None)
1165 .await?;
1166
1167 Ok(conn)
1168 }
1169}
1170
#[cfg(test)]
mod tests {
    use super::{is_explicit_protocol_version_rejection, protocol_version_from_minor, socket_addr};
    use crate::driver::PgError;

    #[test]
    fn protocol_version_from_minor_encodes_major_3() {
        // (minor, expected wire value): 3.2 and 3.0.
        for (minor, expected) in [(2u16, 196_610), (0u16, 196_608)] {
            assert_eq!(protocol_version_from_minor(minor), expected, "minor {minor}");
        }
    }

    #[test]
    fn socket_addr_brackets_ipv6_hosts() {
        let cases = [
            ("127.0.0.1", "127.0.0.1:5432"),
            ("::1", "[::1]:5432"),
            ("[::1]", "[::1]:5432"),
        ];
        for (host, expected) in cases {
            assert_eq!(socket_addr(host, 5432), expected, "host {host}");
        }
    }

    #[test]
    fn explicit_protocol_rejection_detection_is_case_insensitive() {
        let rejections = [
            PgError::Connection("Unsupported frontend protocol 3.2".to_string()),
            PgError::Protocol("server: Protocol VERSION not supported".to_string()),
        ];
        for err in &rejections {
            assert!(is_explicit_protocol_version_rejection(err));
        }
    }

    #[test]
    fn explicit_protocol_rejection_does_not_match_unrelated_errors() {
        let unrelated = PgError::Connection("connection reset by peer".to_string());
        assert!(!is_explicit_protocol_version_rejection(&unrelated));
    }
}