1#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6 connect_backend_for_stream, connect_error_kind, plain_connect_attempt_backend,
7 record_connect_attempt, record_connect_result,
8};
9use super::types::{
10 BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
11 CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
12 GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
13 StatementCache, TlsConfig, has_logical_replication_startup_mode,
14};
15use crate::driver::stream::PgStream;
16use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
17use crate::protocol::PROTOCOL_VERSION_3_0;
18use crate::protocol::wire::FrontendMessage;
19use bytes::BytesMut;
20use std::collections::{HashMap, VecDeque};
21use std::sync::Arc;
22use std::time::Instant;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
26#[inline]
27fn protocol_version_from_minor(minor: u16) -> i32 {
28 ((3i32) << 16) | i32::from(minor)
29}
30
31fn is_explicit_protocol_version_rejection(err: &PgError) -> bool {
32 let msg = match err {
33 PgError::Connection(msg) | PgError::Protocol(msg) | PgError::Auth(msg) => msg,
34 PgError::Query(msg) => msg,
35 PgError::QueryServer(server) => &server.message,
36 _ => return false,
37 };
38
39 let lower = msg.to_ascii_lowercase();
40 lower.contains("unsupported frontend protocol")
41 || lower.contains("frontend protocol") && lower.contains("unsupported")
42 || lower.contains("protocol version") && lower.contains("not support")
43}
44
45impl PgConnection {
46 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
55 Self::connect_with_password(host, port, user, database, None).await
56 }
57
58 pub async fn connect_with_password(
65 host: &str,
66 port: u16,
67 user: &str,
68 database: &str,
69 password: Option<&str>,
70 ) -> PgResult<Self> {
71 Self::connect_with_password_and_auth(
72 host,
73 port,
74 user,
75 database,
76 password,
77 AuthSettings::default(),
78 )
79 .await
80 }
81
82 pub async fn connect_with_options(
93 host: &str,
94 port: u16,
95 user: &str,
96 database: &str,
97 password: Option<&str>,
98 options: ConnectOptions,
99 ) -> PgResult<Self> {
100 let ConnectOptions {
101 tls_mode,
102 gss_enc_mode,
103 tls_ca_cert_pem,
104 mtls,
105 gss_token_provider,
106 gss_token_provider_ex,
107 auth,
108 startup_params,
109 } = options;
110
111 if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
112 return Err(PgError::Connection(
113 "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
114 ));
115 }
116
117 if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
121 return Err(PgError::Connection(
122 "gssencmode=require is incompatible with mTLS — both provide \
123 transport encryption; use one or the other"
124 .to_string(),
125 ));
126 }
127
128 if let Some(mtls_config) = mtls {
129 return Self::connect_mtls_with_password_and_auth_and_gss(
132 ConnectParams {
133 host,
134 port,
135 user,
136 database,
137 password,
138 auth_settings: auth,
139 gss_token_provider,
140 gss_token_provider_ex,
141 protocol_minor: Self::default_protocol_minor(),
142 startup_params: startup_params.clone(),
143 },
144 mtls_config,
145 )
146 .await;
147 }
148
149 if gss_enc_mode != GssEncMode::Disable {
151 match Self::try_gssenc_request(host, port).await {
152 Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
153 let connect_started = Instant::now();
154 record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
155 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
156 {
157 let default_minor = Self::default_protocol_minor();
158 let mut result = Self::connect_gssenc_accepted_with_timeout(
159 tcp_stream,
160 host,
161 user,
162 database,
163 password,
164 auth,
165 gss_token_provider,
166 gss_token_provider_ex.clone(),
167 startup_params.clone(),
168 default_minor,
169 )
170 .await;
171 if let Err(err) = &result {
172 if default_minor > 0 && is_explicit_protocol_version_rejection(err) {
173 let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
174 let retry_stream = match Self::try_gssenc_request(host, port).await
175 {
176 Ok(GssEncNegotiationResult::Accepted(stream)) => stream,
177 Ok(GssEncNegotiationResult::Rejected) => {
178 return Err(PgError::Connection(
179 "Protocol downgrade retry failed: server rejected GSSENCRequest"
180 .to_string(),
181 ));
182 }
183 Ok(GssEncNegotiationResult::ServerError) => {
184 return Err(PgError::Connection(
185 "Protocol downgrade retry failed: server returned error to GSSENCRequest"
186 .to_string(),
187 ));
188 }
189 Err(e) => {
190 return Err(e);
191 }
192 };
193 result = Self::connect_gssenc_accepted_with_timeout(
194 retry_stream,
195 host,
196 user,
197 database,
198 password,
199 auth,
200 gss_token_provider,
201 gss_token_provider_ex,
202 startup_params.clone(),
203 downgrade_minor,
204 )
205 .await;
206 }
207 }
208 record_connect_result(
209 CONNECT_TRANSPORT_GSSENC,
210 CONNECT_BACKEND_TOKIO,
211 &result,
212 connect_started.elapsed(),
213 );
214 return result;
215 }
216 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
217 {
218 let _ = tcp_stream;
219 let err = PgError::Connection(
220 "Server accepted GSSENCRequest but GSSAPI encryption requires \
221 feature enterprise-gssapi on Linux"
222 .to_string(),
223 );
224 metrics::histogram!(
225 "qail_pg_connect_duration_seconds",
226 "transport" => CONNECT_TRANSPORT_GSSENC,
227 "backend" => CONNECT_BACKEND_TOKIO,
228 "outcome" => "error"
229 )
230 .record(connect_started.elapsed().as_secs_f64());
231 metrics::counter!(
232 "qail_pg_connect_failure_total",
233 "transport" => CONNECT_TRANSPORT_GSSENC,
234 "backend" => CONNECT_BACKEND_TOKIO,
235 "error_kind" => connect_error_kind(&err)
236 )
237 .increment(1);
238 return Err(err);
239 }
240 }
241 Ok(GssEncNegotiationResult::Rejected)
242 | Ok(GssEncNegotiationResult::ServerError) => {
243 if gss_enc_mode == GssEncMode::Require {
244 return Err(PgError::Connection(
245 "gssencmode=require but server rejected GSSENCRequest".to_string(),
246 ));
247 }
248 }
250 Err(e) => {
251 if gss_enc_mode == GssEncMode::Require {
252 return Err(e);
253 }
254 tracing::debug!(
256 host = %host,
257 port = %port,
258 error = %e,
259 "gssenc_prefer_fallthrough"
260 );
261 }
262 }
263 }
264
265 match tls_mode {
267 TlsMode::Disable => {
268 Self::connect_with_password_and_auth_and_gss(ConnectParams {
269 host,
270 port,
271 user,
272 database,
273 password,
274 auth_settings: auth,
275 gss_token_provider,
276 gss_token_provider_ex,
277 protocol_minor: Self::default_protocol_minor(),
278 startup_params: startup_params.clone(),
279 })
280 .await
281 }
282 TlsMode::Require => {
283 Self::connect_tls_with_auth_and_gss(
284 ConnectParams {
285 host,
286 port,
287 user,
288 database,
289 password,
290 auth_settings: auth,
291 gss_token_provider,
292 gss_token_provider_ex,
293 protocol_minor: Self::default_protocol_minor(),
294 startup_params: startup_params.clone(),
295 },
296 tls_ca_cert_pem.as_deref(),
297 )
298 .await
299 }
300 TlsMode::Prefer => {
301 match Self::connect_tls_with_auth_and_gss(
302 ConnectParams {
303 host,
304 port,
305 user,
306 database,
307 password,
308 auth_settings: auth,
309 gss_token_provider,
310 gss_token_provider_ex: gss_token_provider_ex.clone(),
311 protocol_minor: Self::default_protocol_minor(),
312 startup_params: startup_params.clone(),
313 },
314 tls_ca_cert_pem.as_deref(),
315 )
316 .await
317 {
318 Ok(conn) => Ok(conn),
319 Err(PgError::Connection(msg))
320 if msg.contains("Server does not support TLS") =>
321 {
322 Self::connect_with_password_and_auth_and_gss(ConnectParams {
323 host,
324 port,
325 user,
326 database,
327 password,
328 auth_settings: auth,
329 gss_token_provider,
330 gss_token_provider_ex,
331 protocol_minor: Self::default_protocol_minor(),
332 startup_params: startup_params.clone(),
333 })
334 .await
335 }
336 Err(e) => Err(e),
337 }
338 }
339 }
340 }
341
342 async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
349 tokio::time::timeout(
350 DEFAULT_CONNECT_TIMEOUT,
351 Self::try_gssenc_request_inner(host, port),
352 )
353 .await
354 .map_err(|_| {
355 PgError::Connection(format!(
356 "GSSENCRequest timeout after {:?}",
357 DEFAULT_CONNECT_TIMEOUT
358 ))
359 })?
360 }
361
362 async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
364 use tokio::io::AsyncReadExt;
365
366 let addr = format!("{}:{}", host, port);
367 let mut tcp_stream = TcpStream::connect(&addr).await?;
368 tcp_stream.set_nodelay(true)?;
369
370 tcp_stream.write_all(&GSSENC_REQUEST).await?;
372 tcp_stream.flush().await?;
373
374 let mut response = [0u8; 1];
378 tcp_stream.read_exact(&mut response).await?;
379
380 match response[0] {
381 b'G' => {
382 let mut peek_buf = [0u8; 1];
385 match tcp_stream.try_read(&mut peek_buf) {
386 Ok(0) => {} Ok(_n) => {
388 return Err(PgError::Connection(
390 "Protocol violation: extra bytes after GSSENCRequest 'G' response \
391 (possible CVE-2021-23222 buffer-stuffing attack)"
392 .to_string(),
393 ));
394 }
395 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
396 }
398 Err(e) => {
399 return Err(PgError::Io(e));
400 }
401 }
402 Ok(GssEncNegotiationResult::Accepted(tcp_stream))
403 }
404 b'N' => Ok(GssEncNegotiationResult::Rejected),
405 b'E' => {
406 tracing::trace!(
410 host = %host,
411 port = %port,
412 "gssenc_request_server_error (suppressed per CVE-2024-10977)"
413 );
414 Ok(GssEncNegotiationResult::ServerError)
415 }
416 other => Err(PgError::Connection(format!(
417 "Unexpected response to GSSENCRequest: 0x{:02X} \
418 (expected 'G'=0x47 or 'N'=0x4E)",
419 other
420 ))),
421 }
422 }
423
424 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
425 async fn connect_gssenc_accepted_with_timeout(
426 tcp_stream: TcpStream,
427 host: &str,
428 user: &str,
429 database: &str,
430 password: Option<&str>,
431 auth_settings: AuthSettings,
432 gss_token_provider: Option<super::super::GssTokenProvider>,
433 gss_token_provider_ex: Option<super::super::GssTokenProviderEx>,
434 startup_params: Vec<(String, String)>,
435 protocol_minor: u16,
436 ) -> PgResult<Self> {
437 let gssenc_fut = async {
438 let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, host)
439 .await
440 .map_err(PgError::Auth)?;
441 let mut conn = Self {
442 stream: PgStream::GssEnc(gss_stream),
443 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
444 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
445 sql_buf: BytesMut::with_capacity(512),
446 params_buf: Vec::with_capacity(16),
447 prepared_statements: HashMap::new(),
448 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
449 column_info_cache: HashMap::new(),
450 process_id: 0,
451 secret_key: 0,
452 cancel_key_bytes: Vec::new(),
453 requested_protocol_minor: protocol_minor,
454 negotiated_protocol_minor: protocol_minor,
455 notifications: VecDeque::new(),
456 replication_stream_active: false,
457 replication_mode_enabled: has_logical_replication_startup_mode(&startup_params),
458 last_replication_wal_end: None,
459 io_desynced: false,
460 pending_statement_closes: Vec::new(),
461 draining_statement_closes: false,
462 };
463 conn.send(FrontendMessage::Startup {
464 user: user.to_string(),
465 database: database.to_string(),
466 protocol_version: protocol_version_from_minor(protocol_minor),
467 startup_params: startup_params.clone(),
468 })
469 .await?;
470 conn.handle_startup(
471 user,
472 password,
473 auth_settings,
474 gss_token_provider,
475 gss_token_provider_ex,
476 )
477 .await?;
478 Ok(conn)
479 };
480 tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
481 .await
482 .map_err(|_| {
483 PgError::Connection(format!(
484 "GSSENC connection timeout after {:?} (handshake + auth)",
485 DEFAULT_CONNECT_TIMEOUT
486 ))
487 })?
488 }
489
490 pub async fn connect_with_password_and_auth(
492 host: &str,
493 port: u16,
494 user: &str,
495 database: &str,
496 password: Option<&str>,
497 auth_settings: AuthSettings,
498 ) -> PgResult<Self> {
499 Self::connect_with_password_and_auth_and_gss(ConnectParams {
500 host,
501 port,
502 user,
503 database,
504 password,
505 auth_settings,
506 gss_token_provider: None,
507 gss_token_provider_ex: None,
508 protocol_minor: Self::default_protocol_minor(),
509 startup_params: Vec::new(),
510 })
511 .await
512 }
513
514 async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
515 let first = Self::connect_with_password_and_auth_and_gss_once(params.clone()).await;
516 if let Err(err) = &first
517 && params.protocol_minor > 0
518 && is_explicit_protocol_version_rejection(err)
519 {
520 let mut downgraded = params;
521 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
522 return Self::connect_with_password_and_auth_and_gss_once(downgraded).await;
523 }
524 first
525 }
526
527 async fn connect_with_password_and_auth_and_gss_once(
528 params: ConnectParams<'_>,
529 ) -> PgResult<Self> {
530 let connect_started = Instant::now();
531 let attempt_backend = plain_connect_attempt_backend();
532 record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
533 let result = tokio::time::timeout(
534 DEFAULT_CONNECT_TIMEOUT,
535 Self::connect_with_password_inner(params),
536 )
537 .await
538 .map_err(|_| {
539 PgError::Connection(format!(
540 "Connection timeout after {:?} (TCP connect + handshake)",
541 DEFAULT_CONNECT_TIMEOUT
542 ))
543 })?;
544 let backend = result
545 .as_ref()
546 .map(|conn| connect_backend_for_stream(&conn.stream))
547 .unwrap_or(attempt_backend);
548 record_connect_result(
549 CONNECT_TRANSPORT_PLAIN,
550 backend,
551 &result,
552 connect_started.elapsed(),
553 );
554 result
555 }
556
557 async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
559 let ConnectParams {
560 host,
561 port,
562 user,
563 database,
564 password,
565 auth_settings,
566 gss_token_provider,
567 gss_token_provider_ex,
568 protocol_minor,
569 startup_params,
570 } = params;
571 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
572 let addr = format!("{}:{}", host, port);
573 let stream = Self::connect_plain_stream(&addr).await?;
574
575 let mut conn = Self {
576 stream,
577 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
578 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
580 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
582 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
583 column_info_cache: HashMap::new(),
584 process_id: 0,
585 secret_key: 0,
586 cancel_key_bytes: Vec::new(),
587 requested_protocol_minor: protocol_minor,
588 negotiated_protocol_minor: protocol_minor,
589 notifications: VecDeque::new(),
590 replication_stream_active: false,
591 replication_mode_enabled,
592 last_replication_wal_end: None,
593 io_desynced: false,
594 pending_statement_closes: Vec::new(),
595 draining_statement_closes: false,
596 };
597
598 conn.send(FrontendMessage::Startup {
599 user: user.to_string(),
600 database: database.to_string(),
601 protocol_version: protocol_version_from_minor(protocol_minor),
602 startup_params,
603 })
604 .await?;
605
606 conn.handle_startup(
607 user,
608 password,
609 auth_settings,
610 gss_token_provider,
611 gss_token_provider_ex,
612 )
613 .await?;
614
615 Ok(conn)
616 }
617
618 async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
619 let tcp_stream = TcpStream::connect(addr).await?;
620 tcp_stream.set_nodelay(true)?;
621
622 #[cfg(all(target_os = "linux", feature = "io_uring"))]
623 {
624 if should_try_uring_plain() {
625 let std_stream = tcp_stream.into_std()?;
626 let fallback_std = std_stream.try_clone()?;
627 match super::super::uring::UringTcpStream::from_std(std_stream) {
628 Ok(uring_stream) => {
629 tracing::info!(
630 addr = %addr,
631 "qail-pg: using io_uring plain TCP transport"
632 );
633 return Ok(PgStream::Uring(uring_stream));
634 }
635 Err(e) => {
636 tracing::warn!(
637 addr = %addr,
638 error = %e,
639 "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
640 );
641 fallback_std.set_nonblocking(true)?;
642 let fallback = TcpStream::from_std(fallback_std)?;
643 return Ok(PgStream::Tcp(fallback));
644 }
645 }
646 }
647 }
648
649 Ok(PgStream::Tcp(tcp_stream))
650 }
651
652 pub async fn connect_tls(
655 host: &str,
656 port: u16,
657 user: &str,
658 database: &str,
659 password: Option<&str>,
660 ) -> PgResult<Self> {
661 Self::connect_tls_with_auth(
662 host,
663 port,
664 user,
665 database,
666 password,
667 AuthSettings::default(),
668 None,
669 )
670 .await
671 }
672
673 pub async fn connect_tls_with_auth(
675 host: &str,
676 port: u16,
677 user: &str,
678 database: &str,
679 password: Option<&str>,
680 auth_settings: AuthSettings,
681 ca_cert_pem: Option<&[u8]>,
682 ) -> PgResult<Self> {
683 Self::connect_tls_with_auth_and_gss(
684 ConnectParams {
685 host,
686 port,
687 user,
688 database,
689 password,
690 auth_settings,
691 gss_token_provider: None,
692 gss_token_provider_ex: None,
693 protocol_minor: Self::default_protocol_minor(),
694 startup_params: Vec::new(),
695 },
696 ca_cert_pem,
697 )
698 .await
699 }
700
701 async fn connect_tls_with_auth_and_gss(
702 params: ConnectParams<'_>,
703 ca_cert_pem: Option<&[u8]>,
704 ) -> PgResult<Self> {
705 let first = Self::connect_tls_with_auth_and_gss_once(params.clone(), ca_cert_pem).await;
706 if let Err(err) = &first
707 && params.protocol_minor > 0
708 && is_explicit_protocol_version_rejection(err)
709 {
710 let mut downgraded = params;
711 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
712 return Self::connect_tls_with_auth_and_gss_once(downgraded, ca_cert_pem).await;
713 }
714 first
715 }
716
717 async fn connect_tls_with_auth_and_gss_once(
718 params: ConnectParams<'_>,
719 ca_cert_pem: Option<&[u8]>,
720 ) -> PgResult<Self> {
721 let connect_started = Instant::now();
722 record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
723 let result = tokio::time::timeout(
724 DEFAULT_CONNECT_TIMEOUT,
725 Self::connect_tls_inner(params, ca_cert_pem),
726 )
727 .await
728 .map_err(|_| {
729 PgError::Connection(format!(
730 "TLS connection timeout after {:?}",
731 DEFAULT_CONNECT_TIMEOUT
732 ))
733 })?;
734 record_connect_result(
735 CONNECT_TRANSPORT_TLS,
736 CONNECT_BACKEND_TOKIO,
737 &result,
738 connect_started.elapsed(),
739 );
740 result
741 }
742
743 async fn connect_tls_inner(
745 params: ConnectParams<'_>,
746 ca_cert_pem: Option<&[u8]>,
747 ) -> PgResult<Self> {
748 let ConnectParams {
749 host,
750 port,
751 user,
752 database,
753 password,
754 auth_settings,
755 gss_token_provider,
756 gss_token_provider_ex,
757 protocol_minor,
758 startup_params,
759 } = params;
760 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
761 use tokio::io::AsyncReadExt;
762 use tokio_rustls::TlsConnector;
763 use tokio_rustls::rustls::ClientConfig;
764 use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
765
766 let addr = format!("{}:{}", host, port);
767 let mut tcp_stream = TcpStream::connect(&addr).await?;
768
769 tcp_stream.write_all(&SSL_REQUEST).await?;
771
772 let mut response = [0u8; 1];
774 tcp_stream.read_exact(&mut response).await?;
775
776 if response[0] != b'S' {
777 return Err(PgError::Connection(
778 "Server does not support TLS".to_string(),
779 ));
780 }
781
782 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
783
784 if let Some(ca_pem) = ca_cert_pem {
785 let certs = CertificateDer::pem_slice_iter(ca_pem)
786 .collect::<Result<Vec<_>, _>>()
787 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
788 if certs.is_empty() {
789 return Err(PgError::Connection(
790 "No CA certificates found in provided PEM".to_string(),
791 ));
792 }
793 for cert in certs {
794 let _ = root_cert_store.add(cert);
795 }
796 } else {
797 let certs = rustls_native_certs::load_native_certs();
798 for cert in certs.certs {
799 let _ = root_cert_store.add(cert);
800 }
801 }
802
803 let config = ClientConfig::builder()
804 .with_root_certificates(root_cert_store)
805 .with_no_client_auth();
806
807 let connector = TlsConnector::from(Arc::new(config));
808 let server_name = ServerName::try_from(host.to_string())
809 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
810
811 let tls_stream = connector
812 .connect(server_name, tcp_stream)
813 .await
814 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
815
816 let mut conn = Self {
817 stream: PgStream::Tls(Box::new(tls_stream)),
818 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
819 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
820 sql_buf: BytesMut::with_capacity(512),
821 params_buf: Vec::with_capacity(16),
822 prepared_statements: HashMap::new(),
823 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
824 column_info_cache: HashMap::new(),
825 process_id: 0,
826 secret_key: 0,
827 cancel_key_bytes: Vec::new(),
828 requested_protocol_minor: protocol_minor,
829 negotiated_protocol_minor: protocol_minor,
830 notifications: VecDeque::new(),
831 replication_stream_active: false,
832 replication_mode_enabled,
833 last_replication_wal_end: None,
834 io_desynced: false,
835 pending_statement_closes: Vec::new(),
836 draining_statement_closes: false,
837 };
838
839 conn.send(FrontendMessage::Startup {
840 user: user.to_string(),
841 database: database.to_string(),
842 protocol_version: protocol_version_from_minor(protocol_minor),
843 startup_params,
844 })
845 .await?;
846
847 conn.handle_startup(
848 user,
849 password,
850 auth_settings,
851 gss_token_provider,
852 gss_token_provider_ex,
853 )
854 .await?;
855
856 Ok(conn)
857 }
858
859 pub async fn connect_mtls(
876 host: &str,
877 port: u16,
878 user: &str,
879 database: &str,
880 config: TlsConfig,
881 ) -> PgResult<Self> {
882 Self::connect_mtls_with_password_and_auth(
883 host,
884 port,
885 user,
886 database,
887 None,
888 config,
889 AuthSettings::default(),
890 )
891 .await
892 }
893
894 pub async fn connect_mtls_with_password_and_auth(
896 host: &str,
897 port: u16,
898 user: &str,
899 database: &str,
900 password: Option<&str>,
901 config: TlsConfig,
902 auth_settings: AuthSettings,
903 ) -> PgResult<Self> {
904 Self::connect_mtls_with_password_and_auth_and_gss(
905 ConnectParams {
906 host,
907 port,
908 user,
909 database,
910 password,
911 auth_settings,
912 gss_token_provider: None,
913 gss_token_provider_ex: None,
914 protocol_minor: Self::default_protocol_minor(),
915 startup_params: Vec::new(),
916 },
917 config,
918 )
919 .await
920 }
921
922 async fn connect_mtls_with_password_and_auth_and_gss(
923 params: ConnectParams<'_>,
924 config: TlsConfig,
925 ) -> PgResult<Self> {
926 let first =
927 Self::connect_mtls_with_password_and_auth_and_gss_once(params.clone(), config.clone())
928 .await;
929 if let Err(err) = &first
930 && params.protocol_minor > 0
931 && is_explicit_protocol_version_rejection(err)
932 {
933 let mut downgraded = params;
934 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
935 return Self::connect_mtls_with_password_and_auth_and_gss_once(downgraded, config)
936 .await;
937 }
938 first
939 }
940
941 async fn connect_mtls_with_password_and_auth_and_gss_once(
942 params: ConnectParams<'_>,
943 config: TlsConfig,
944 ) -> PgResult<Self> {
945 let connect_started = Instant::now();
946 record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
947 let result = tokio::time::timeout(
948 DEFAULT_CONNECT_TIMEOUT,
949 Self::connect_mtls_inner(params, config),
950 )
951 .await
952 .map_err(|_| {
953 PgError::Connection(format!(
954 "mTLS connection timeout after {:?}",
955 DEFAULT_CONNECT_TIMEOUT
956 ))
957 })?;
958 record_connect_result(
959 CONNECT_TRANSPORT_MTLS,
960 CONNECT_BACKEND_TOKIO,
961 &result,
962 connect_started.elapsed(),
963 );
964 result
965 }
966
967 async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
969 let ConnectParams {
970 host,
971 port,
972 user,
973 database,
974 password,
975 auth_settings,
976 gss_token_provider,
977 gss_token_provider_ex,
978 protocol_minor,
979 startup_params,
980 } = params;
981 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
982 use tokio::io::AsyncReadExt;
983 use tokio_rustls::TlsConnector;
984 use tokio_rustls::rustls::{
985 ClientConfig,
986 pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
987 };
988
989 let addr = format!("{}:{}", host, port);
990 let mut tcp_stream = TcpStream::connect(&addr).await?;
991
992 tcp_stream.write_all(&SSL_REQUEST).await?;
994
995 let mut response = [0u8; 1];
997 tcp_stream.read_exact(&mut response).await?;
998
999 if response[0] != b'S' {
1000 return Err(PgError::Connection(
1001 "Server does not support TLS".to_string(),
1002 ));
1003 }
1004
1005 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
1006
1007 if let Some(ca_pem) = &config.ca_cert_pem {
1008 let certs = CertificateDer::pem_slice_iter(ca_pem)
1009 .collect::<Result<Vec<_>, _>>()
1010 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
1011 if certs.is_empty() {
1012 return Err(PgError::Connection(
1013 "No CA certificates found in provided PEM".to_string(),
1014 ));
1015 }
1016 for cert in certs {
1017 let _ = root_cert_store.add(cert);
1018 }
1019 } else {
1020 let certs = rustls_native_certs::load_native_certs();
1022 for cert in certs.certs {
1023 let _ = root_cert_store.add(cert);
1024 }
1025 }
1026
1027 let client_certs: Vec<CertificateDer<'static>> =
1028 CertificateDer::pem_slice_iter(&config.client_cert_pem)
1029 .collect::<Result<Vec<_>, _>>()
1030 .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
1031 if client_certs.is_empty() {
1032 return Err(PgError::Connection(
1033 "No client certificates found in PEM".to_string(),
1034 ));
1035 }
1036
1037 let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
1038 .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
1039
1040 let tls_config = ClientConfig::builder()
1041 .with_root_certificates(root_cert_store)
1042 .with_client_auth_cert(client_certs, client_key)
1043 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
1044
1045 let connector = TlsConnector::from(Arc::new(tls_config));
1046 let server_name = ServerName::try_from(host.to_string())
1047 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
1048
1049 let tls_stream = connector
1050 .connect(server_name, tcp_stream)
1051 .await
1052 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
1053
1054 let mut conn = Self {
1055 stream: PgStream::Tls(Box::new(tls_stream)),
1056 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1057 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1058 sql_buf: BytesMut::with_capacity(512),
1059 params_buf: Vec::with_capacity(16),
1060 prepared_statements: HashMap::new(),
1061 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1062 column_info_cache: HashMap::new(),
1063 process_id: 0,
1064 secret_key: 0,
1065 cancel_key_bytes: Vec::new(),
1066 requested_protocol_minor: protocol_minor,
1067 negotiated_protocol_minor: protocol_minor,
1068 notifications: VecDeque::new(),
1069 replication_stream_active: false,
1070 replication_mode_enabled,
1071 last_replication_wal_end: None,
1072 io_desynced: false,
1073 pending_statement_closes: Vec::new(),
1074 draining_statement_closes: false,
1075 };
1076
1077 conn.send(FrontendMessage::Startup {
1078 user: user.to_string(),
1079 database: database.to_string(),
1080 protocol_version: protocol_version_from_minor(protocol_minor),
1081 startup_params,
1082 })
1083 .await?;
1084
1085 conn.handle_startup(
1086 user,
1087 password,
1088 auth_settings,
1089 gss_token_provider,
1090 gss_token_provider_ex,
1091 )
1092 .await?;
1093
1094 Ok(conn)
1095 }
1096
1097 #[cfg(unix)]
1099 pub async fn connect_unix(
1100 socket_path: &str,
1101 user: &str,
1102 database: &str,
1103 password: Option<&str>,
1104 ) -> PgResult<Self> {
1105 let default_minor = Self::default_protocol_minor();
1106 let first =
1107 Self::connect_unix_with_protocol(socket_path, user, database, password, default_minor)
1108 .await;
1109 if let Err(err) = &first
1110 && default_minor > 0
1111 && is_explicit_protocol_version_rejection(err)
1112 {
1113 let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
1114 return Self::connect_unix_with_protocol(
1115 socket_path,
1116 user,
1117 database,
1118 password,
1119 downgrade_minor,
1120 )
1121 .await;
1122 }
1123 first
1124 }
1125
1126 #[cfg(unix)]
1127 async fn connect_unix_with_protocol(
1128 socket_path: &str,
1129 user: &str,
1130 database: &str,
1131 password: Option<&str>,
1132 protocol_minor: u16,
1133 ) -> PgResult<Self> {
1134 use tokio::net::UnixStream;
1135
1136 let unix_stream = UnixStream::connect(socket_path).await?;
1137
1138 let mut conn = Self {
1139 stream: PgStream::Unix(unix_stream),
1140 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1141 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1142 sql_buf: BytesMut::with_capacity(512),
1143 params_buf: Vec::with_capacity(16),
1144 prepared_statements: HashMap::new(),
1145 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1146 column_info_cache: HashMap::new(),
1147 process_id: 0,
1148 secret_key: 0,
1149 cancel_key_bytes: Vec::new(),
1150 requested_protocol_minor: protocol_minor,
1151 negotiated_protocol_minor: protocol_minor,
1152 notifications: VecDeque::new(),
1153 replication_stream_active: false,
1154 replication_mode_enabled: false,
1155 last_replication_wal_end: None,
1156 io_desynced: false,
1157 pending_statement_closes: Vec::new(),
1158 draining_statement_closes: false,
1159 };
1160
1161 conn.send(FrontendMessage::Startup {
1162 user: user.to_string(),
1163 database: database.to_string(),
1164 protocol_version: protocol_version_from_minor(protocol_minor),
1165 startup_params: Vec::new(),
1166 })
1167 .await?;
1168
1169 conn.handle_startup(user, password, AuthSettings::default(), None, None)
1170 .await?;
1171
1172 Ok(conn)
1173 }
1174}
1175
1176#[cfg(test)]
1177mod tests {
1178 use super::{is_explicit_protocol_version_rejection, protocol_version_from_minor};
1179 use crate::driver::PgError;
1180
1181 #[test]
1182 fn protocol_version_from_minor_encodes_major_3() {
1183 assert_eq!(protocol_version_from_minor(2), 196610);
1184 assert_eq!(protocol_version_from_minor(0), 196608);
1185 }
1186
1187 #[test]
1188 fn explicit_protocol_rejection_detection_is_case_insensitive() {
1189 let err = PgError::Connection("Unsupported frontend protocol 3.2".to_string());
1190 assert!(is_explicit_protocol_version_rejection(&err));
1191
1192 let err = PgError::Protocol("server: Protocol VERSION not supported".to_string());
1193 assert!(is_explicit_protocol_version_rejection(&err));
1194 }
1195
1196 #[test]
1197 fn explicit_protocol_rejection_does_not_match_unrelated_errors() {
1198 let err = PgError::Connection("connection reset by peer".to_string());
1199 assert!(!is_explicit_protocol_version_rejection(&err));
1200 }
1201}