1#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6 connect_backend_for_stream, connect_error_kind, plain_connect_attempt_backend,
7 record_connect_attempt, record_connect_result,
8};
9#[cfg(all(target_os = "linux", feature = "io_uring"))]
10use super::types::CONNECT_BACKEND_IO_URING;
11use super::types::{
12 BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
13 CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
14 GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
15 StatementCache, TlsConfig, has_logical_replication_startup_mode,
16};
17use crate::driver::stream::PgStream;
18use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
19use crate::protocol::PROTOCOL_VERSION_3_0;
20use crate::protocol::wire::FrontendMessage;
21use bytes::BytesMut;
22use std::collections::{HashMap, VecDeque};
23use std::sync::Arc;
24use std::time::Instant;
25use tokio::io::AsyncWriteExt;
26use tokio::net::TcpStream;
27
/// Encodes the 32-bit protocol version sent in the StartupMessage: major
/// version 3 in the upper 16 bits and `minor` in the lower 16 bits
/// (e.g. protocol 3.2 encodes as 196610).
#[inline]
fn protocol_version_from_minor(minor: u16) -> i32 {
    const MAJOR: i32 = 3;
    let high_half = MAJOR << 16;
    high_half | i32::from(minor)
}
32
33fn is_explicit_protocol_version_rejection(err: &PgError) -> bool {
34 let msg = match err {
35 PgError::Connection(msg) | PgError::Protocol(msg) | PgError::Auth(msg) => msg,
36 PgError::Query(msg) => msg,
37 PgError::QueryServer(server) => &server.message,
38 _ => return false,
39 };
40
41 let lower = msg.to_ascii_lowercase();
42 lower.contains("unsupported frontend protocol")
43 || lower.contains("frontend protocol") && lower.contains("unsupported")
44 || lower.contains("protocol version") && lower.contains("not support")
45}
46
47impl PgConnection {
48 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
57 Self::connect_with_password(host, port, user, database, None).await
58 }
59
60 pub async fn connect_with_password(
67 host: &str,
68 port: u16,
69 user: &str,
70 database: &str,
71 password: Option<&str>,
72 ) -> PgResult<Self> {
73 Self::connect_with_password_and_auth(
74 host,
75 port,
76 user,
77 database,
78 password,
79 AuthSettings::default(),
80 )
81 .await
82 }
83
84 pub async fn connect_with_options(
95 host: &str,
96 port: u16,
97 user: &str,
98 database: &str,
99 password: Option<&str>,
100 options: ConnectOptions,
101 ) -> PgResult<Self> {
102 let ConnectOptions {
103 tls_mode,
104 gss_enc_mode,
105 tls_ca_cert_pem,
106 mtls,
107 gss_token_provider,
108 gss_token_provider_ex,
109 auth,
110 startup_params,
111 } = options;
112
113 if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
114 return Err(PgError::Connection(
115 "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
116 ));
117 }
118
119 if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
123 return Err(PgError::Connection(
124 "gssencmode=require is incompatible with mTLS — both provide \
125 transport encryption; use one or the other"
126 .to_string(),
127 ));
128 }
129
130 if let Some(mtls_config) = mtls {
131 return Self::connect_mtls_with_password_and_auth_and_gss(
134 ConnectParams {
135 host,
136 port,
137 user,
138 database,
139 password,
140 auth_settings: auth,
141 gss_token_provider,
142 gss_token_provider_ex,
143 protocol_minor: Self::default_protocol_minor(),
144 startup_params: startup_params.clone(),
145 },
146 mtls_config,
147 )
148 .await;
149 }
150
151 if gss_enc_mode != GssEncMode::Disable {
153 match Self::try_gssenc_request(host, port).await {
154 Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
155 let connect_started = Instant::now();
156 record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
157 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
158 {
159 let default_minor = Self::default_protocol_minor();
160 let mut result = Self::connect_gssenc_accepted_with_timeout(
161 tcp_stream,
162 host,
163 user,
164 database,
165 password,
166 auth,
167 gss_token_provider,
168 gss_token_provider_ex.clone(),
169 startup_params.clone(),
170 default_minor,
171 )
172 .await;
173 if let Err(err) = &result {
174 if default_minor > 0 && is_explicit_protocol_version_rejection(err) {
175 let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
176 let retry_stream = match Self::try_gssenc_request(host, port).await
177 {
178 Ok(GssEncNegotiationResult::Accepted(stream)) => stream,
179 Ok(GssEncNegotiationResult::Rejected) => {
180 return Err(PgError::Connection(
181 "Protocol downgrade retry failed: server rejected GSSENCRequest"
182 .to_string(),
183 ));
184 }
185 Ok(GssEncNegotiationResult::ServerError) => {
186 return Err(PgError::Connection(
187 "Protocol downgrade retry failed: server returned error to GSSENCRequest"
188 .to_string(),
189 ));
190 }
191 Err(e) => {
192 return Err(e);
193 }
194 };
195 result = Self::connect_gssenc_accepted_with_timeout(
196 retry_stream,
197 host,
198 user,
199 database,
200 password,
201 auth,
202 gss_token_provider,
203 gss_token_provider_ex,
204 startup_params.clone(),
205 downgrade_minor,
206 )
207 .await;
208 }
209 }
210 record_connect_result(
211 CONNECT_TRANSPORT_GSSENC,
212 CONNECT_BACKEND_TOKIO,
213 &result,
214 connect_started.elapsed(),
215 );
216 return result;
217 }
218 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
219 {
220 let _ = tcp_stream;
221 let err = PgError::Connection(
222 "Server accepted GSSENCRequest but GSSAPI encryption requires \
223 feature enterprise-gssapi on Linux"
224 .to_string(),
225 );
226 metrics::histogram!(
227 "qail_pg_connect_duration_seconds",
228 "transport" => CONNECT_TRANSPORT_GSSENC,
229 "backend" => CONNECT_BACKEND_TOKIO,
230 "outcome" => "error"
231 )
232 .record(connect_started.elapsed().as_secs_f64());
233 metrics::counter!(
234 "qail_pg_connect_failure_total",
235 "transport" => CONNECT_TRANSPORT_GSSENC,
236 "backend" => CONNECT_BACKEND_TOKIO,
237 "error_kind" => connect_error_kind(&err)
238 )
239 .increment(1);
240 return Err(err);
241 }
242 }
243 Ok(GssEncNegotiationResult::Rejected)
244 | Ok(GssEncNegotiationResult::ServerError) => {
245 if gss_enc_mode == GssEncMode::Require {
246 return Err(PgError::Connection(
247 "gssencmode=require but server rejected GSSENCRequest".to_string(),
248 ));
249 }
250 }
252 Err(e) => {
253 if gss_enc_mode == GssEncMode::Require {
254 return Err(e);
255 }
256 tracing::debug!(
258 host = %host,
259 port = %port,
260 error = %e,
261 "gssenc_prefer_fallthrough"
262 );
263 }
264 }
265 }
266
267 match tls_mode {
269 TlsMode::Disable => {
270 Self::connect_with_password_and_auth_and_gss(ConnectParams {
271 host,
272 port,
273 user,
274 database,
275 password,
276 auth_settings: auth,
277 gss_token_provider,
278 gss_token_provider_ex,
279 protocol_minor: Self::default_protocol_minor(),
280 startup_params: startup_params.clone(),
281 })
282 .await
283 }
284 TlsMode::Require => {
285 Self::connect_tls_with_auth_and_gss(
286 ConnectParams {
287 host,
288 port,
289 user,
290 database,
291 password,
292 auth_settings: auth,
293 gss_token_provider,
294 gss_token_provider_ex,
295 protocol_minor: Self::default_protocol_minor(),
296 startup_params: startup_params.clone(),
297 },
298 tls_ca_cert_pem.as_deref(),
299 )
300 .await
301 }
302 TlsMode::Prefer => {
303 match Self::connect_tls_with_auth_and_gss(
304 ConnectParams {
305 host,
306 port,
307 user,
308 database,
309 password,
310 auth_settings: auth,
311 gss_token_provider,
312 gss_token_provider_ex: gss_token_provider_ex.clone(),
313 protocol_minor: Self::default_protocol_minor(),
314 startup_params: startup_params.clone(),
315 },
316 tls_ca_cert_pem.as_deref(),
317 )
318 .await
319 {
320 Ok(conn) => Ok(conn),
321 Err(PgError::Connection(msg))
322 if msg.contains("Server does not support TLS") =>
323 {
324 Self::connect_with_password_and_auth_and_gss(ConnectParams {
325 host,
326 port,
327 user,
328 database,
329 password,
330 auth_settings: auth,
331 gss_token_provider,
332 gss_token_provider_ex,
333 protocol_minor: Self::default_protocol_minor(),
334 startup_params: startup_params.clone(),
335 })
336 .await
337 }
338 Err(e) => Err(e),
339 }
340 }
341 }
342 }
343
344 async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
351 tokio::time::timeout(
352 DEFAULT_CONNECT_TIMEOUT,
353 Self::try_gssenc_request_inner(host, port),
354 )
355 .await
356 .map_err(|_| {
357 PgError::Connection(format!(
358 "GSSENCRequest timeout after {:?}",
359 DEFAULT_CONNECT_TIMEOUT
360 ))
361 })?
362 }
363
364 async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
366 use tokio::io::AsyncReadExt;
367
368 let addr = format!("{}:{}", host, port);
369 let mut tcp_stream = TcpStream::connect(&addr).await?;
370 tcp_stream.set_nodelay(true)?;
371
372 tcp_stream.write_all(&GSSENC_REQUEST).await?;
374 tcp_stream.flush().await?;
375
376 let mut response = [0u8; 1];
380 tcp_stream.read_exact(&mut response).await?;
381
382 match response[0] {
383 b'G' => {
384 let mut peek_buf = [0u8; 1];
387 match tcp_stream.try_read(&mut peek_buf) {
388 Ok(0) => {} Ok(_n) => {
390 return Err(PgError::Connection(
392 "Protocol violation: extra bytes after GSSENCRequest 'G' response \
393 (possible CVE-2021-23222 buffer-stuffing attack)"
394 .to_string(),
395 ));
396 }
397 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
398 }
400 Err(e) => {
401 return Err(PgError::Io(e));
402 }
403 }
404 Ok(GssEncNegotiationResult::Accepted(tcp_stream))
405 }
406 b'N' => Ok(GssEncNegotiationResult::Rejected),
407 b'E' => {
408 tracing::trace!(
412 host = %host,
413 port = %port,
414 "gssenc_request_server_error (suppressed per CVE-2024-10977)"
415 );
416 Ok(GssEncNegotiationResult::ServerError)
417 }
418 other => Err(PgError::Connection(format!(
419 "Unexpected response to GSSENCRequest: 0x{:02X} \
420 (expected 'G'=0x47 or 'N'=0x4E)",
421 other
422 ))),
423 }
424 }
425
426 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
427 async fn connect_gssenc_accepted_with_timeout(
428 tcp_stream: TcpStream,
429 host: &str,
430 user: &str,
431 database: &str,
432 password: Option<&str>,
433 auth_settings: AuthSettings,
434 gss_token_provider: Option<super::super::GssTokenProvider>,
435 gss_token_provider_ex: Option<super::super::GssTokenProviderEx>,
436 startup_params: Vec<(String, String)>,
437 protocol_minor: u16,
438 ) -> PgResult<Self> {
439 let gssenc_fut = async {
440 let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, host)
441 .await
442 .map_err(PgError::Auth)?;
443 let mut conn = Self {
444 stream: PgStream::GssEnc(gss_stream),
445 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
446 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
447 sql_buf: BytesMut::with_capacity(512),
448 params_buf: Vec::with_capacity(16),
449 prepared_statements: HashMap::new(),
450 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
451 column_info_cache: HashMap::new(),
452 process_id: 0,
453 secret_key: 0,
454 cancel_key_bytes: Vec::new(),
455 requested_protocol_minor: protocol_minor,
456 negotiated_protocol_minor: protocol_minor,
457 notifications: VecDeque::new(),
458 replication_stream_active: false,
459 replication_mode_enabled: has_logical_replication_startup_mode(&startup_params),
460 last_replication_wal_end: None,
461 io_desynced: false,
462 pending_statement_closes: Vec::new(),
463 draining_statement_closes: false,
464 };
465 conn.send(FrontendMessage::Startup {
466 user: user.to_string(),
467 database: database.to_string(),
468 protocol_version: protocol_version_from_minor(protocol_minor),
469 startup_params: startup_params.clone(),
470 })
471 .await?;
472 conn.handle_startup(
473 user,
474 password,
475 auth_settings,
476 gss_token_provider,
477 gss_token_provider_ex,
478 )
479 .await?;
480 Ok(conn)
481 };
482 tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
483 .await
484 .map_err(|_| {
485 PgError::Connection(format!(
486 "GSSENC connection timeout after {:?} (handshake + auth)",
487 DEFAULT_CONNECT_TIMEOUT
488 ))
489 })?
490 }
491
492 pub async fn connect_with_password_and_auth(
494 host: &str,
495 port: u16,
496 user: &str,
497 database: &str,
498 password: Option<&str>,
499 auth_settings: AuthSettings,
500 ) -> PgResult<Self> {
501 Self::connect_with_password_and_auth_and_gss(ConnectParams {
502 host,
503 port,
504 user,
505 database,
506 password,
507 auth_settings,
508 gss_token_provider: None,
509 gss_token_provider_ex: None,
510 protocol_minor: Self::default_protocol_minor(),
511 startup_params: Vec::new(),
512 })
513 .await
514 }
515
516 async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
517 let first = Self::connect_with_password_and_auth_and_gss_once(params.clone()).await;
518 if let Err(err) = &first {
519 if params.protocol_minor > 0 && is_explicit_protocol_version_rejection(err) {
520 let mut downgraded = params;
521 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
522 return Self::connect_with_password_and_auth_and_gss_once(downgraded).await;
523 }
524 }
525 first
526 }
527
528 async fn connect_with_password_and_auth_and_gss_once(
529 params: ConnectParams<'_>,
530 ) -> PgResult<Self> {
531 let connect_started = Instant::now();
532 let attempt_backend = plain_connect_attempt_backend();
533 record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
534 let result = tokio::time::timeout(
535 DEFAULT_CONNECT_TIMEOUT,
536 Self::connect_with_password_inner(params),
537 )
538 .await
539 .map_err(|_| {
540 PgError::Connection(format!(
541 "Connection timeout after {:?} (TCP connect + handshake)",
542 DEFAULT_CONNECT_TIMEOUT
543 ))
544 })?;
545 let backend = result
546 .as_ref()
547 .map(|conn| connect_backend_for_stream(&conn.stream))
548 .unwrap_or(attempt_backend);
549 record_connect_result(
550 CONNECT_TRANSPORT_PLAIN,
551 backend,
552 &result,
553 connect_started.elapsed(),
554 );
555 result
556 }
557
558 async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
560 let ConnectParams {
561 host,
562 port,
563 user,
564 database,
565 password,
566 auth_settings,
567 gss_token_provider,
568 gss_token_provider_ex,
569 protocol_minor,
570 startup_params,
571 } = params;
572 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
573 let addr = format!("{}:{}", host, port);
574 let stream = Self::connect_plain_stream(&addr).await?;
575
576 let mut conn = Self {
577 stream,
578 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
579 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
581 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
583 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
584 column_info_cache: HashMap::new(),
585 process_id: 0,
586 secret_key: 0,
587 cancel_key_bytes: Vec::new(),
588 requested_protocol_minor: protocol_minor,
589 negotiated_protocol_minor: protocol_minor,
590 notifications: VecDeque::new(),
591 replication_stream_active: false,
592 replication_mode_enabled,
593 last_replication_wal_end: None,
594 io_desynced: false,
595 pending_statement_closes: Vec::new(),
596 draining_statement_closes: false,
597 };
598
599 conn.send(FrontendMessage::Startup {
600 user: user.to_string(),
601 database: database.to_string(),
602 protocol_version: protocol_version_from_minor(protocol_minor),
603 startup_params,
604 })
605 .await?;
606
607 conn.handle_startup(
608 user,
609 password,
610 auth_settings,
611 gss_token_provider,
612 gss_token_provider_ex,
613 )
614 .await?;
615
616 Ok(conn)
617 }
618
619 async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
620 let tcp_stream = TcpStream::connect(addr).await?;
621 tcp_stream.set_nodelay(true)?;
622
623 #[cfg(all(target_os = "linux", feature = "io_uring"))]
624 {
625 if should_try_uring_plain() {
626 match super::super::uring::UringTcpStream::from_tokio(tcp_stream) {
627 Ok(uring_stream) => {
628 tracing::info!(
629 addr = %addr,
630 "qail-pg: using io_uring plain TCP transport"
631 );
632 return Ok(PgStream::Uring(uring_stream));
633 }
634 Err(e) => {
635 tracing::warn!(
636 addr = %addr,
637 error = %e,
638 "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
639 );
640 let fallback = TcpStream::connect(addr).await?;
641 fallback.set_nodelay(true)?;
642 return Ok(PgStream::Tcp(fallback));
643 }
644 }
645 }
646 }
647
648 Ok(PgStream::Tcp(tcp_stream))
649 }
650
651 pub async fn connect_tls(
654 host: &str,
655 port: u16,
656 user: &str,
657 database: &str,
658 password: Option<&str>,
659 ) -> PgResult<Self> {
660 Self::connect_tls_with_auth(
661 host,
662 port,
663 user,
664 database,
665 password,
666 AuthSettings::default(),
667 None,
668 )
669 .await
670 }
671
672 pub async fn connect_tls_with_auth(
674 host: &str,
675 port: u16,
676 user: &str,
677 database: &str,
678 password: Option<&str>,
679 auth_settings: AuthSettings,
680 ca_cert_pem: Option<&[u8]>,
681 ) -> PgResult<Self> {
682 Self::connect_tls_with_auth_and_gss(
683 ConnectParams {
684 host,
685 port,
686 user,
687 database,
688 password,
689 auth_settings,
690 gss_token_provider: None,
691 gss_token_provider_ex: None,
692 protocol_minor: Self::default_protocol_minor(),
693 startup_params: Vec::new(),
694 },
695 ca_cert_pem,
696 )
697 .await
698 }
699
700 async fn connect_tls_with_auth_and_gss(
701 params: ConnectParams<'_>,
702 ca_cert_pem: Option<&[u8]>,
703 ) -> PgResult<Self> {
704 let first = Self::connect_tls_with_auth_and_gss_once(params.clone(), ca_cert_pem).await;
705 if let Err(err) = &first {
706 if params.protocol_minor > 0 && is_explicit_protocol_version_rejection(err) {
707 let mut downgraded = params;
708 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
709 return Self::connect_tls_with_auth_and_gss_once(downgraded, ca_cert_pem).await;
710 }
711 }
712 first
713 }
714
715 async fn connect_tls_with_auth_and_gss_once(
716 params: ConnectParams<'_>,
717 ca_cert_pem: Option<&[u8]>,
718 ) -> PgResult<Self> {
719 let connect_started = Instant::now();
720 record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
721 let result = tokio::time::timeout(
722 DEFAULT_CONNECT_TIMEOUT,
723 Self::connect_tls_inner(params, ca_cert_pem),
724 )
725 .await
726 .map_err(|_| {
727 PgError::Connection(format!(
728 "TLS connection timeout after {:?}",
729 DEFAULT_CONNECT_TIMEOUT
730 ))
731 })?;
732 record_connect_result(
733 CONNECT_TRANSPORT_TLS,
734 CONNECT_BACKEND_TOKIO,
735 &result,
736 connect_started.elapsed(),
737 );
738 result
739 }
740
741 async fn connect_tls_inner(
743 params: ConnectParams<'_>,
744 ca_cert_pem: Option<&[u8]>,
745 ) -> PgResult<Self> {
746 let ConnectParams {
747 host,
748 port,
749 user,
750 database,
751 password,
752 auth_settings,
753 gss_token_provider,
754 gss_token_provider_ex,
755 protocol_minor,
756 startup_params,
757 } = params;
758 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
759 use tokio::io::AsyncReadExt;
760 use tokio_rustls::TlsConnector;
761 use tokio_rustls::rustls::ClientConfig;
762 use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
763
764 let addr = format!("{}:{}", host, port);
765 let mut tcp_stream = TcpStream::connect(&addr).await?;
766
767 tcp_stream.write_all(&SSL_REQUEST).await?;
769
770 let mut response = [0u8; 1];
772 tcp_stream.read_exact(&mut response).await?;
773
774 if response[0] != b'S' {
775 return Err(PgError::Connection(
776 "Server does not support TLS".to_string(),
777 ));
778 }
779
780 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
781
782 if let Some(ca_pem) = ca_cert_pem {
783 let certs = CertificateDer::pem_slice_iter(ca_pem)
784 .collect::<Result<Vec<_>, _>>()
785 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
786 if certs.is_empty() {
787 return Err(PgError::Connection(
788 "No CA certificates found in provided PEM".to_string(),
789 ));
790 }
791 for cert in certs {
792 let _ = root_cert_store.add(cert);
793 }
794 } else {
795 let certs = rustls_native_certs::load_native_certs();
796 for cert in certs.certs {
797 let _ = root_cert_store.add(cert);
798 }
799 }
800
801 let config = ClientConfig::builder()
802 .with_root_certificates(root_cert_store)
803 .with_no_client_auth();
804
805 let connector = TlsConnector::from(Arc::new(config));
806 let server_name = ServerName::try_from(host.to_string())
807 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
808
809 let tls_stream = connector
810 .connect(server_name, tcp_stream)
811 .await
812 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
813
814 let mut conn = Self {
815 stream: PgStream::Tls(Box::new(tls_stream)),
816 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
817 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
818 sql_buf: BytesMut::with_capacity(512),
819 params_buf: Vec::with_capacity(16),
820 prepared_statements: HashMap::new(),
821 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
822 column_info_cache: HashMap::new(),
823 process_id: 0,
824 secret_key: 0,
825 cancel_key_bytes: Vec::new(),
826 requested_protocol_minor: protocol_minor,
827 negotiated_protocol_minor: protocol_minor,
828 notifications: VecDeque::new(),
829 replication_stream_active: false,
830 replication_mode_enabled,
831 last_replication_wal_end: None,
832 io_desynced: false,
833 pending_statement_closes: Vec::new(),
834 draining_statement_closes: false,
835 };
836
837 conn.send(FrontendMessage::Startup {
838 user: user.to_string(),
839 database: database.to_string(),
840 protocol_version: protocol_version_from_minor(protocol_minor),
841 startup_params,
842 })
843 .await?;
844
845 conn.handle_startup(
846 user,
847 password,
848 auth_settings,
849 gss_token_provider,
850 gss_token_provider_ex,
851 )
852 .await?;
853
854 Ok(conn)
855 }
856
857 pub async fn connect_mtls(
874 host: &str,
875 port: u16,
876 user: &str,
877 database: &str,
878 config: TlsConfig,
879 ) -> PgResult<Self> {
880 Self::connect_mtls_with_password_and_auth(
881 host,
882 port,
883 user,
884 database,
885 None,
886 config,
887 AuthSettings::default(),
888 )
889 .await
890 }
891
892 pub async fn connect_mtls_with_password_and_auth(
894 host: &str,
895 port: u16,
896 user: &str,
897 database: &str,
898 password: Option<&str>,
899 config: TlsConfig,
900 auth_settings: AuthSettings,
901 ) -> PgResult<Self> {
902 Self::connect_mtls_with_password_and_auth_and_gss(
903 ConnectParams {
904 host,
905 port,
906 user,
907 database,
908 password,
909 auth_settings,
910 gss_token_provider: None,
911 gss_token_provider_ex: None,
912 protocol_minor: Self::default_protocol_minor(),
913 startup_params: Vec::new(),
914 },
915 config,
916 )
917 .await
918 }
919
920 async fn connect_mtls_with_password_and_auth_and_gss(
921 params: ConnectParams<'_>,
922 config: TlsConfig,
923 ) -> PgResult<Self> {
924 let first =
925 Self::connect_mtls_with_password_and_auth_and_gss_once(params.clone(), config.clone())
926 .await;
927 if let Err(err) = &first {
928 if params.protocol_minor > 0 && is_explicit_protocol_version_rejection(err) {
929 let mut downgraded = params;
930 downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
931 return Self::connect_mtls_with_password_and_auth_and_gss_once(downgraded, config)
932 .await;
933 }
934 }
935 first
936 }
937
938 async fn connect_mtls_with_password_and_auth_and_gss_once(
939 params: ConnectParams<'_>,
940 config: TlsConfig,
941 ) -> PgResult<Self> {
942 let connect_started = Instant::now();
943 record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
944 let result = tokio::time::timeout(
945 DEFAULT_CONNECT_TIMEOUT,
946 Self::connect_mtls_inner(params, config),
947 )
948 .await
949 .map_err(|_| {
950 PgError::Connection(format!(
951 "mTLS connection timeout after {:?}",
952 DEFAULT_CONNECT_TIMEOUT
953 ))
954 })?;
955 record_connect_result(
956 CONNECT_TRANSPORT_MTLS,
957 CONNECT_BACKEND_TOKIO,
958 &result,
959 connect_started.elapsed(),
960 );
961 result
962 }
963
964 async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
966 let ConnectParams {
967 host,
968 port,
969 user,
970 database,
971 password,
972 auth_settings,
973 gss_token_provider,
974 gss_token_provider_ex,
975 protocol_minor,
976 startup_params,
977 } = params;
978 let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
979 use tokio::io::AsyncReadExt;
980 use tokio_rustls::TlsConnector;
981 use tokio_rustls::rustls::{
982 ClientConfig,
983 pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
984 };
985
986 let addr = format!("{}:{}", host, port);
987 let mut tcp_stream = TcpStream::connect(&addr).await?;
988
989 tcp_stream.write_all(&SSL_REQUEST).await?;
991
992 let mut response = [0u8; 1];
994 tcp_stream.read_exact(&mut response).await?;
995
996 if response[0] != b'S' {
997 return Err(PgError::Connection(
998 "Server does not support TLS".to_string(),
999 ));
1000 }
1001
1002 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
1003
1004 if let Some(ca_pem) = &config.ca_cert_pem {
1005 let certs = CertificateDer::pem_slice_iter(ca_pem)
1006 .collect::<Result<Vec<_>, _>>()
1007 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
1008 if certs.is_empty() {
1009 return Err(PgError::Connection(
1010 "No CA certificates found in provided PEM".to_string(),
1011 ));
1012 }
1013 for cert in certs {
1014 let _ = root_cert_store.add(cert);
1015 }
1016 } else {
1017 let certs = rustls_native_certs::load_native_certs();
1019 for cert in certs.certs {
1020 let _ = root_cert_store.add(cert);
1021 }
1022 }
1023
1024 let client_certs: Vec<CertificateDer<'static>> =
1025 CertificateDer::pem_slice_iter(&config.client_cert_pem)
1026 .collect::<Result<Vec<_>, _>>()
1027 .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
1028 if client_certs.is_empty() {
1029 return Err(PgError::Connection(
1030 "No client certificates found in PEM".to_string(),
1031 ));
1032 }
1033
1034 let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
1035 .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
1036
1037 let tls_config = ClientConfig::builder()
1038 .with_root_certificates(root_cert_store)
1039 .with_client_auth_cert(client_certs, client_key)
1040 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
1041
1042 let connector = TlsConnector::from(Arc::new(tls_config));
1043 let server_name = ServerName::try_from(host.to_string())
1044 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
1045
1046 let tls_stream = connector
1047 .connect(server_name, tcp_stream)
1048 .await
1049 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
1050
1051 let mut conn = Self {
1052 stream: PgStream::Tls(Box::new(tls_stream)),
1053 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1054 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1055 sql_buf: BytesMut::with_capacity(512),
1056 params_buf: Vec::with_capacity(16),
1057 prepared_statements: HashMap::new(),
1058 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1059 column_info_cache: HashMap::new(),
1060 process_id: 0,
1061 secret_key: 0,
1062 cancel_key_bytes: Vec::new(),
1063 requested_protocol_minor: protocol_minor,
1064 negotiated_protocol_minor: protocol_minor,
1065 notifications: VecDeque::new(),
1066 replication_stream_active: false,
1067 replication_mode_enabled,
1068 last_replication_wal_end: None,
1069 io_desynced: false,
1070 pending_statement_closes: Vec::new(),
1071 draining_statement_closes: false,
1072 };
1073
1074 conn.send(FrontendMessage::Startup {
1075 user: user.to_string(),
1076 database: database.to_string(),
1077 protocol_version: protocol_version_from_minor(protocol_minor),
1078 startup_params,
1079 })
1080 .await?;
1081
1082 conn.handle_startup(
1083 user,
1084 password,
1085 auth_settings,
1086 gss_token_provider,
1087 gss_token_provider_ex,
1088 )
1089 .await?;
1090
1091 Ok(conn)
1092 }
1093
1094 #[cfg(unix)]
1096 pub async fn connect_unix(
1097 socket_path: &str,
1098 user: &str,
1099 database: &str,
1100 password: Option<&str>,
1101 ) -> PgResult<Self> {
1102 let default_minor = Self::default_protocol_minor();
1103 let first =
1104 Self::connect_unix_with_protocol(socket_path, user, database, password, default_minor)
1105 .await;
1106 if let Err(err) = &first {
1107 if default_minor > 0 && is_explicit_protocol_version_rejection(err) {
1108 let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
1109 return Self::connect_unix_with_protocol(
1110 socket_path,
1111 user,
1112 database,
1113 password,
1114 downgrade_minor,
1115 )
1116 .await;
1117 }
1118 }
1119 first
1120 }
1121
1122 #[cfg(unix)]
1123 async fn connect_unix_with_protocol(
1124 socket_path: &str,
1125 user: &str,
1126 database: &str,
1127 password: Option<&str>,
1128 protocol_minor: u16,
1129 ) -> PgResult<Self> {
1130 use tokio::net::UnixStream;
1131
1132 let unix_stream = UnixStream::connect(socket_path).await?;
1133
1134 let mut conn = Self {
1135 stream: PgStream::Unix(unix_stream),
1136 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1137 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1138 sql_buf: BytesMut::with_capacity(512),
1139 params_buf: Vec::with_capacity(16),
1140 prepared_statements: HashMap::new(),
1141 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1142 column_info_cache: HashMap::new(),
1143 process_id: 0,
1144 secret_key: 0,
1145 cancel_key_bytes: Vec::new(),
1146 requested_protocol_minor: protocol_minor,
1147 negotiated_protocol_minor: protocol_minor,
1148 notifications: VecDeque::new(),
1149 replication_stream_active: false,
1150 replication_mode_enabled: false,
1151 last_replication_wal_end: None,
1152 io_desynced: false,
1153 pending_statement_closes: Vec::new(),
1154 draining_statement_closes: false,
1155 };
1156
1157 conn.send(FrontendMessage::Startup {
1158 user: user.to_string(),
1159 database: database.to_string(),
1160 protocol_version: protocol_version_from_minor(protocol_minor),
1161 startup_params: Vec::new(),
1162 })
1163 .await?;
1164
1165 conn.handle_startup(user, password, AuthSettings::default(), None, None)
1166 .await?;
1167
1168 Ok(conn)
1169 }
1170}
1171
#[cfg(test)]
mod tests {
    use super::{is_explicit_protocol_version_rejection, protocol_version_from_minor};
    use crate::driver::PgError;

    #[test]
    fn protocol_version_from_minor_encodes_major_3() {
        // (minor, expected wire value): 3 << 16 == 196608.
        let cases = [(2u16, 196_610i32), (0u16, 196_608i32)];
        for &(minor, expected) in cases.iter() {
            assert_eq!(protocol_version_from_minor(minor), expected);
        }
    }

    #[test]
    fn explicit_protocol_rejection_detection_is_case_insensitive() {
        let rejections = [
            PgError::Connection("Unsupported frontend protocol 3.2".to_string()),
            PgError::Protocol("server: Protocol VERSION not supported".to_string()),
        ];
        for err in rejections.iter() {
            assert!(is_explicit_protocol_version_rejection(err));
        }
    }

    #[test]
    fn explicit_protocol_rejection_does_not_match_unrelated_errors() {
        let unrelated = PgError::Connection("connection reset by peer".to_string());
        assert!(!is_explicit_protocol_version_rejection(&unrelated));
    }
}