1use super::notification::Notification;
16use super::stream::PgStream;
17use super::{
18 AuthSettings, ConnectOptions, EnterpriseAuthMechanism, GssEncMode, GssTokenProvider,
19 GssTokenProviderEx, GssTokenRequest, PgError, PgResult, ScramChannelBindingMode, TlsMode,
20};
21use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
22use bytes::BytesMut;
23use sha2::{Digest, Sha256};
24use std::collections::{HashMap, VecDeque};
25use std::num::NonZeroUsize;
26use std::sync::Arc;
27use std::sync::atomic::{AtomicU64, Ordering};
28use tokio::io::AsyncWriteExt;
29use tokio::net::TcpStream;
30
/// Number of entries each connection's `StatementCache` holds before it starts
/// evicting the least-recently-used one.
const STMT_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(100) {
    Some(capacity) => capacity,
    None => panic!("statement cache capacity must be non-zero"),
};
33
/// Fixed-capacity `u64` → `String` map with least-recently-used eviction.
///
/// `order` lists the cached keys from least- to most-recently used; every key
/// present in `entries` appears in `order` exactly once.
#[derive(Debug)]
pub(crate) struct StatementCache {
    /// Upper bound on the number of entries.
    capacity: NonZeroUsize,
    /// Cached values by key.
    entries: HashMap<u64, String>,
    /// Recency queue: front = least recently used, back = most recently used.
    order: VecDeque<u64>,
}

impl StatementCache {
    /// Creates an empty cache holding at most `capacity` entries.
    pub(crate) fn new(capacity: NonZeroUsize) -> Self {
        let slots = capacity.get();
        Self {
            capacity,
            entries: HashMap::with_capacity(slots),
            order: VecDeque::with_capacity(slots),
        }
    }

    /// Number of entries currently cached.
    pub(crate) fn len(&self) -> usize {
        self.entries.len()
    }

    /// Configured maximum number of entries.
    pub(crate) fn cap(&self) -> NonZeroUsize {
        self.capacity
    }

    /// Reports whether `key` is cached without changing its recency.
    pub(crate) fn contains(&self, key: &u64) -> bool {
        self.entries.contains_key(key)
    }

    /// Returns a copy of the value stored under `key` and, on a hit, marks the
    /// key as most recently used.
    pub(crate) fn get(&mut self, key: &u64) -> Option<String> {
        let hit = self.entries.get(key).cloned();
        if hit.is_some() {
            self.touch(*key);
        }
        hit
    }

    /// Stores `value` under `key` and marks it most recently used.
    ///
    /// An existing key is overwritten in place. A new key arriving while the
    /// cache is full first evicts the least-recently-used entry.
    pub(crate) fn put(&mut self, key: u64, value: String) {
        if let Some(slot) = self.entries.get_mut(&key) {
            *slot = value;
            self.touch(key);
            return;
        }

        let full = self.entries.len() >= self.capacity.get();
        if full {
            let _evicted = self.pop_lru();
        }

        self.entries.insert(key, value);
        self.order.push_back(key);
    }

    /// Removes and returns the least-recently-used entry, or `None` when empty.
    pub(crate) fn pop_lru(&mut self) -> Option<(u64, String)> {
        loop {
            let oldest = self.order.pop_front()?;
            if let Some(value) = self.entries.remove(&oldest) {
                return Some((oldest, value));
            }
        }
    }

    /// Drops every entry.
    pub(crate) fn clear(&mut self) {
        self.entries.clear();
        self.order.clear();
    }

    /// Moves `key` to the most-recently-used end of the recency queue.
    fn touch(&mut self, key: u64) {
        self.order.retain(|&queued| queued != key);
        self.order.push_back(key);
    }
}
106
/// Initial capacity, in bytes, of each connection's read and write buffers.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;

/// SSLRequest packet: big-endian Int32 length (8) followed by the request code
/// 80877103 (`1234 << 16 | 5679`).
const SSL_REQUEST: [u8; 8] = [0x00, 0x00, 0x00, 0x08, 0x04, 0xD2, 0x16, 0x2F];

/// GSSENCRequest packet: big-endian Int32 length (8) followed by the request
/// code 80877104 (`1234 << 16 | 5680`).
const GSSENC_REQUEST: [u8; 8] = [0x00, 0x00, 0x00, 0x08, 0x04, 0xD2, 0x16, 0x30];
116
117#[derive(Debug)]
119enum GssEncNegotiationResult {
120 Accepted(TcpStream),
124 Rejected,
126 ServerError,
129}
130
/// Request code carried by a CancelRequest packet (`1234 << 16 | 5678`).
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;

/// Process-wide counter; each `handle_startup` call takes the next value as its
/// GSS session id.
static GSS_SESSION_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Time limit applied by the connect helpers in this module to each connection
/// attempt (TCP connect + handshake, TLS, mTLS, GSSENC probe and handshake).
pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(10_000);
140
/// Client-side TLS material for mutual-TLS connections, held as PEM bytes.
///
/// `Debug` is implemented by hand so that formatting a config (for example in
/// a log line) never prints the private key.
#[derive(Clone)]
pub struct TlsConfig {
    /// PEM-encoded client certificate chain.
    pub client_cert_pem: Vec<u8>,
    /// PEM-encoded client private key. Never shown by `Debug`.
    pub client_key_pem: Vec<u8>,
    /// Optional PEM-encoded CA bundle for verifying the server. When `None`,
    /// the mTLS connect path uses the platform's native root certificates.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Certificates are public, but dumping raw byte arrays is just noise, so
        // only their sizes are shown. The key is fully redacted.
        f.debug_struct("TlsConfig")
            .field(
                "client_cert_pem",
                &format_args!("<{} bytes>", self.client_cert_pem.len()),
            )
            .field("client_key_pem", &format_args!("<redacted>"))
            .field(
                "ca_cert_pem",
                &self
                    .ca_cert_pem
                    .as_ref()
                    .map(|pem| format!("<{} bytes>", pem.len())),
            )
            .finish()
    }
}

impl TlsConfig {
    /// Reads the client certificate, client key and optional CA bundle from disk.
    ///
    /// # Errors
    /// Returns the first I/O error encountered. Files are read in the order
    /// certificate, key, CA.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        Ok(Self {
            client_cert_pem: std::fs::read(cert_path)?,
            client_key_pem: std::fs::read(key_path)?,
            ca_cert_pem: ca_path.map(std::fs::read).transpose()?,
        })
    }
}
166
167struct ConnectParams<'a> {
172 host: &'a str,
173 port: u16,
174 user: &'a str,
175 database: &'a str,
176 password: Option<&'a str>,
177 auth_settings: AuthSettings,
178 gss_token_provider: Option<GssTokenProvider>,
179 gss_token_provider_ex: Option<GssTokenProviderEx>,
180}
181
182pub struct PgConnection {
184 pub(crate) stream: PgStream,
185 pub(crate) buffer: BytesMut,
186 pub(crate) write_buf: BytesMut,
187 pub(crate) sql_buf: BytesMut,
188 pub(crate) params_buf: Vec<Option<Vec<u8>>>,
189 pub(crate) prepared_statements: HashMap<String, String>,
190 pub(crate) stmt_cache: StatementCache,
191 pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
195 pub(crate) process_id: i32,
196 pub(crate) secret_key: i32,
197 pub(crate) notifications: VecDeque<Notification>,
200}
201
202impl PgConnection {
203 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
212 Self::connect_with_password(host, port, user, database, None).await
213 }
214
215 pub async fn connect_with_password(
218 host: &str,
219 port: u16,
220 user: &str,
221 database: &str,
222 password: Option<&str>,
223 ) -> PgResult<Self> {
224 Self::connect_with_password_and_auth(
225 host,
226 port,
227 user,
228 database,
229 password,
230 AuthSettings::default(),
231 )
232 .await
233 }
234
235 pub async fn connect_with_options(
242 host: &str,
243 port: u16,
244 user: &str,
245 database: &str,
246 password: Option<&str>,
247 options: ConnectOptions,
248 ) -> PgResult<Self> {
249 let ConnectOptions {
250 tls_mode,
251 gss_enc_mode,
252 tls_ca_cert_pem,
253 mtls,
254 gss_token_provider,
255 gss_token_provider_ex,
256 auth,
257 } = options;
258
259 if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
260 return Err(PgError::Connection(
261 "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
262 ));
263 }
264
265 if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
269 return Err(PgError::Connection(
270 "gssencmode=require is incompatible with mTLS — both provide \
271 transport encryption; use one or the other"
272 .to_string(),
273 ));
274 }
275
276 if let Some(mtls_config) = mtls {
277 return Self::connect_mtls_with_password_and_auth_and_gss(
280 ConnectParams {
281 host,
282 port,
283 user,
284 database,
285 password,
286 auth_settings: auth,
287 gss_token_provider,
288 gss_token_provider_ex,
289 },
290 mtls_config,
291 )
292 .await;
293 }
294
295 if gss_enc_mode != GssEncMode::Disable {
297 match Self::try_gssenc_request(host, port).await {
298 Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
299 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
300 {
301 let gssenc_fut = async {
302 let gss_stream = super::gss::gssenc_handshake(tcp_stream, host)
303 .await
304 .map_err(PgError::Auth)?;
305 let mut conn = Self {
306 stream: PgStream::GssEnc(gss_stream),
307 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
308 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
309 sql_buf: BytesMut::with_capacity(512),
310 params_buf: Vec::with_capacity(16),
311 prepared_statements: HashMap::new(),
312 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
313 column_info_cache: HashMap::new(),
314 process_id: 0,
315 secret_key: 0,
316 notifications: VecDeque::new(),
317 };
318 conn.send(FrontendMessage::Startup {
319 user: user.to_string(),
320 database: database.to_string(),
321 })
322 .await?;
323 conn.handle_startup(
324 user,
325 password,
326 auth,
327 gss_token_provider,
328 gss_token_provider_ex,
329 )
330 .await?;
331 Ok(conn)
332 };
333 return tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
334 .await
335 .map_err(|_| {
336 PgError::Connection(format!(
337 "GSSENC connection timeout after {:?} \
338 (handshake + auth)",
339 DEFAULT_CONNECT_TIMEOUT
340 ))
341 })?;
342 }
343 #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
344 {
345 let _ = tcp_stream;
346 return Err(PgError::Connection(
347 "Server accepted GSSENCRequest but GSSAPI encryption requires \
348 feature enterprise-gssapi on Linux"
349 .to_string(),
350 ));
351 }
352 }
353 Ok(GssEncNegotiationResult::Rejected)
354 | Ok(GssEncNegotiationResult::ServerError) => {
355 if gss_enc_mode == GssEncMode::Require {
356 return Err(PgError::Connection(
357 "gssencmode=require but server rejected GSSENCRequest".to_string(),
358 ));
359 }
360 }
362 Err(e) => {
363 if gss_enc_mode == GssEncMode::Require {
364 return Err(e);
365 }
366 tracing::debug!(
368 host = %host,
369 port = %port,
370 error = %e,
371 "gssenc_prefer_fallthrough"
372 );
373 }
374 }
375 }
376
377 match tls_mode {
379 TlsMode::Disable => {
380 Self::connect_with_password_and_auth_and_gss(ConnectParams {
381 host,
382 port,
383 user,
384 database,
385 password,
386 auth_settings: auth,
387 gss_token_provider,
388 gss_token_provider_ex,
389 })
390 .await
391 }
392 TlsMode::Require => {
393 Self::connect_tls_with_auth_and_gss(
394 ConnectParams {
395 host,
396 port,
397 user,
398 database,
399 password,
400 auth_settings: auth,
401 gss_token_provider,
402 gss_token_provider_ex,
403 },
404 tls_ca_cert_pem.as_deref(),
405 )
406 .await
407 }
408 TlsMode::Prefer => {
409 match Self::connect_tls_with_auth_and_gss(
410 ConnectParams {
411 host,
412 port,
413 user,
414 database,
415 password,
416 auth_settings: auth,
417 gss_token_provider,
418 gss_token_provider_ex: gss_token_provider_ex.clone(),
419 },
420 tls_ca_cert_pem.as_deref(),
421 )
422 .await
423 {
424 Ok(conn) => Ok(conn),
425 Err(PgError::Connection(msg))
426 if msg.contains("Server does not support TLS") =>
427 {
428 Self::connect_with_password_and_auth_and_gss(ConnectParams {
429 host,
430 port,
431 user,
432 database,
433 password,
434 auth_settings: auth,
435 gss_token_provider,
436 gss_token_provider_ex,
437 })
438 .await
439 }
440 Err(e) => Err(e),
441 }
442 }
443 }
444 }
445
446 async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
453 tokio::time::timeout(
454 DEFAULT_CONNECT_TIMEOUT,
455 Self::try_gssenc_request_inner(host, port),
456 )
457 .await
458 .map_err(|_| {
459 PgError::Connection(format!(
460 "GSSENCRequest timeout after {:?}",
461 DEFAULT_CONNECT_TIMEOUT
462 ))
463 })?
464 }
465
466 async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
468 use tokio::io::AsyncReadExt;
469
470 let addr = format!("{}:{}", host, port);
471 let mut tcp_stream = TcpStream::connect(&addr).await?;
472 tcp_stream.set_nodelay(true)?;
473
474 tcp_stream.write_all(&GSSENC_REQUEST).await?;
476 tcp_stream.flush().await?;
477
478 let mut response = [0u8; 1];
482 tcp_stream.read_exact(&mut response).await?;
483
484 match response[0] {
485 b'G' => {
486 let mut peek_buf = [0u8; 1];
489 match tcp_stream.try_read(&mut peek_buf) {
490 Ok(0) => {} Ok(_n) => {
492 return Err(PgError::Connection(
494 "Protocol violation: extra bytes after GSSENCRequest 'G' response \
495 (possible CVE-2021-23222 buffer-stuffing attack)"
496 .to_string(),
497 ));
498 }
499 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
500 }
502 Err(e) => {
503 return Err(PgError::Io(e));
504 }
505 }
506 Ok(GssEncNegotiationResult::Accepted(tcp_stream))
507 }
508 b'N' => Ok(GssEncNegotiationResult::Rejected),
509 b'E' => {
510 tracing::trace!(
514 host = %host,
515 port = %port,
516 "gssenc_request_server_error (suppressed per CVE-2024-10977)"
517 );
518 Ok(GssEncNegotiationResult::ServerError)
519 }
520 other => Err(PgError::Connection(format!(
521 "Unexpected response to GSSENCRequest: 0x{:02X} \
522 (expected 'G'=0x47 or 'N'=0x4E)",
523 other
524 ))),
525 }
526 }
527
528 pub async fn connect_with_password_and_auth(
530 host: &str,
531 port: u16,
532 user: &str,
533 database: &str,
534 password: Option<&str>,
535 auth_settings: AuthSettings,
536 ) -> PgResult<Self> {
537 Self::connect_with_password_and_auth_and_gss(ConnectParams {
538 host,
539 port,
540 user,
541 database,
542 password,
543 auth_settings,
544 gss_token_provider: None,
545 gss_token_provider_ex: None,
546 })
547 .await
548 }
549
550 async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
551 tokio::time::timeout(
552 DEFAULT_CONNECT_TIMEOUT,
553 Self::connect_with_password_inner(params),
554 )
555 .await
556 .map_err(|_| {
557 PgError::Connection(format!(
558 "Connection timeout after {:?} (TCP connect + handshake)",
559 DEFAULT_CONNECT_TIMEOUT
560 ))
561 })?
562 }
563
564 async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
566 let ConnectParams {
567 host,
568 port,
569 user,
570 database,
571 password,
572 auth_settings,
573 gss_token_provider,
574 gss_token_provider_ex,
575 } = params;
576 let addr = format!("{}:{}", host, port);
577 let tcp_stream = TcpStream::connect(&addr).await?;
578
579 tcp_stream.set_nodelay(true)?;
581
582 let mut conn = Self {
583 stream: PgStream::Tcp(tcp_stream),
584 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
585 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
587 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
589 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
590 column_info_cache: HashMap::new(),
591 process_id: 0,
592 secret_key: 0,
593 notifications: VecDeque::new(),
594 };
595
596 conn.send(FrontendMessage::Startup {
597 user: user.to_string(),
598 database: database.to_string(),
599 })
600 .await?;
601
602 conn.handle_startup(
603 user,
604 password,
605 auth_settings,
606 gss_token_provider,
607 gss_token_provider_ex,
608 )
609 .await?;
610
611 Ok(conn)
612 }
613
614 pub async fn connect_tls(
617 host: &str,
618 port: u16,
619 user: &str,
620 database: &str,
621 password: Option<&str>,
622 ) -> PgResult<Self> {
623 Self::connect_tls_with_auth(
624 host,
625 port,
626 user,
627 database,
628 password,
629 AuthSettings::default(),
630 None,
631 )
632 .await
633 }
634
635 pub async fn connect_tls_with_auth(
637 host: &str,
638 port: u16,
639 user: &str,
640 database: &str,
641 password: Option<&str>,
642 auth_settings: AuthSettings,
643 ca_cert_pem: Option<&[u8]>,
644 ) -> PgResult<Self> {
645 Self::connect_tls_with_auth_and_gss(
646 ConnectParams {
647 host,
648 port,
649 user,
650 database,
651 password,
652 auth_settings,
653 gss_token_provider: None,
654 gss_token_provider_ex: None,
655 },
656 ca_cert_pem,
657 )
658 .await
659 }
660
661 async fn connect_tls_with_auth_and_gss(
662 params: ConnectParams<'_>,
663 ca_cert_pem: Option<&[u8]>,
664 ) -> PgResult<Self> {
665 tokio::time::timeout(
666 DEFAULT_CONNECT_TIMEOUT,
667 Self::connect_tls_inner(params, ca_cert_pem),
668 )
669 .await
670 .map_err(|_| {
671 PgError::Connection(format!(
672 "TLS connection timeout after {:?}",
673 DEFAULT_CONNECT_TIMEOUT
674 ))
675 })?
676 }
677
678 async fn connect_tls_inner(
680 params: ConnectParams<'_>,
681 ca_cert_pem: Option<&[u8]>,
682 ) -> PgResult<Self> {
683 let ConnectParams {
684 host,
685 port,
686 user,
687 database,
688 password,
689 auth_settings,
690 gss_token_provider,
691 gss_token_provider_ex,
692 } = params;
693 use tokio::io::AsyncReadExt;
694 use tokio_rustls::TlsConnector;
695 use tokio_rustls::rustls::ClientConfig;
696 use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
697
698 let addr = format!("{}:{}", host, port);
699 let mut tcp_stream = TcpStream::connect(&addr).await?;
700
701 tcp_stream.write_all(&SSL_REQUEST).await?;
703
704 let mut response = [0u8; 1];
706 tcp_stream.read_exact(&mut response).await?;
707
708 if response[0] != b'S' {
709 return Err(PgError::Connection(
710 "Server does not support TLS".to_string(),
711 ));
712 }
713
714 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
715
716 if let Some(ca_pem) = ca_cert_pem {
717 let certs = CertificateDer::pem_slice_iter(ca_pem)
718 .collect::<Result<Vec<_>, _>>()
719 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
720 if certs.is_empty() {
721 return Err(PgError::Connection(
722 "No CA certificates found in provided PEM".to_string(),
723 ));
724 }
725 for cert in certs {
726 let _ = root_cert_store.add(cert);
727 }
728 } else {
729 let certs = rustls_native_certs::load_native_certs();
730 for cert in certs.certs {
731 let _ = root_cert_store.add(cert);
732 }
733 }
734
735 let config = ClientConfig::builder()
736 .with_root_certificates(root_cert_store)
737 .with_no_client_auth();
738
739 let connector = TlsConnector::from(Arc::new(config));
740 let server_name = ServerName::try_from(host.to_string())
741 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
742
743 let tls_stream = connector
744 .connect(server_name, tcp_stream)
745 .await
746 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
747
748 let mut conn = Self {
749 stream: PgStream::Tls(Box::new(tls_stream)),
750 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
751 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
752 sql_buf: BytesMut::with_capacity(512),
753 params_buf: Vec::with_capacity(16),
754 prepared_statements: HashMap::new(),
755 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
756 column_info_cache: HashMap::new(),
757 process_id: 0,
758 secret_key: 0,
759 notifications: VecDeque::new(),
760 };
761
762 conn.send(FrontendMessage::Startup {
763 user: user.to_string(),
764 database: database.to_string(),
765 })
766 .await?;
767
768 conn.handle_startup(
769 user,
770 password,
771 auth_settings,
772 gss_token_provider,
773 gss_token_provider_ex,
774 )
775 .await?;
776
777 Ok(conn)
778 }
779
780 pub async fn connect_mtls(
797 host: &str,
798 port: u16,
799 user: &str,
800 database: &str,
801 config: TlsConfig,
802 ) -> PgResult<Self> {
803 Self::connect_mtls_with_password_and_auth(
804 host,
805 port,
806 user,
807 database,
808 None,
809 config,
810 AuthSettings::default(),
811 )
812 .await
813 }
814
815 pub async fn connect_mtls_with_password_and_auth(
817 host: &str,
818 port: u16,
819 user: &str,
820 database: &str,
821 password: Option<&str>,
822 config: TlsConfig,
823 auth_settings: AuthSettings,
824 ) -> PgResult<Self> {
825 Self::connect_mtls_with_password_and_auth_and_gss(
826 ConnectParams {
827 host,
828 port,
829 user,
830 database,
831 password,
832 auth_settings,
833 gss_token_provider: None,
834 gss_token_provider_ex: None,
835 },
836 config,
837 )
838 .await
839 }
840
841 async fn connect_mtls_with_password_and_auth_and_gss(
842 params: ConnectParams<'_>,
843 config: TlsConfig,
844 ) -> PgResult<Self> {
845 tokio::time::timeout(
846 DEFAULT_CONNECT_TIMEOUT,
847 Self::connect_mtls_inner(params, config),
848 )
849 .await
850 .map_err(|_| {
851 PgError::Connection(format!(
852 "mTLS connection timeout after {:?}",
853 DEFAULT_CONNECT_TIMEOUT
854 ))
855 })?
856 }
857
858 async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
860 let ConnectParams {
861 host,
862 port,
863 user,
864 database,
865 password,
866 auth_settings,
867 gss_token_provider,
868 gss_token_provider_ex,
869 } = params;
870 use tokio::io::AsyncReadExt;
871 use tokio_rustls::TlsConnector;
872 use tokio_rustls::rustls::{
873 ClientConfig,
874 pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
875 };
876
877 let addr = format!("{}:{}", host, port);
878 let mut tcp_stream = TcpStream::connect(&addr).await?;
879
880 tcp_stream.write_all(&SSL_REQUEST).await?;
882
883 let mut response = [0u8; 1];
885 tcp_stream.read_exact(&mut response).await?;
886
887 if response[0] != b'S' {
888 return Err(PgError::Connection(
889 "Server does not support TLS".to_string(),
890 ));
891 }
892
893 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
894
895 if let Some(ca_pem) = &config.ca_cert_pem {
896 let certs = CertificateDer::pem_slice_iter(ca_pem)
897 .collect::<Result<Vec<_>, _>>()
898 .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
899 if certs.is_empty() {
900 return Err(PgError::Connection(
901 "No CA certificates found in provided PEM".to_string(),
902 ));
903 }
904 for cert in certs {
905 let _ = root_cert_store.add(cert);
906 }
907 } else {
908 let certs = rustls_native_certs::load_native_certs();
910 for cert in certs.certs {
911 let _ = root_cert_store.add(cert);
912 }
913 }
914
915 let client_certs: Vec<CertificateDer<'static>> =
916 CertificateDer::pem_slice_iter(&config.client_cert_pem)
917 .collect::<Result<Vec<_>, _>>()
918 .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
919 if client_certs.is_empty() {
920 return Err(PgError::Connection(
921 "No client certificates found in PEM".to_string(),
922 ));
923 }
924
925 let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
926 .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
927
928 let tls_config = ClientConfig::builder()
929 .with_root_certificates(root_cert_store)
930 .with_client_auth_cert(client_certs, client_key)
931 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
932
933 let connector = TlsConnector::from(Arc::new(tls_config));
934 let server_name = ServerName::try_from(host.to_string())
935 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
936
937 let tls_stream = connector
938 .connect(server_name, tcp_stream)
939 .await
940 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
941
942 let mut conn = Self {
943 stream: PgStream::Tls(Box::new(tls_stream)),
944 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
945 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
946 sql_buf: BytesMut::with_capacity(512),
947 params_buf: Vec::with_capacity(16),
948 prepared_statements: HashMap::new(),
949 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
950 column_info_cache: HashMap::new(),
951 process_id: 0,
952 secret_key: 0,
953 notifications: VecDeque::new(),
954 };
955
956 conn.send(FrontendMessage::Startup {
957 user: user.to_string(),
958 database: database.to_string(),
959 })
960 .await?;
961
962 conn.handle_startup(
963 user,
964 password,
965 auth_settings,
966 gss_token_provider,
967 gss_token_provider_ex,
968 )
969 .await?;
970
971 Ok(conn)
972 }
973
974 #[cfg(unix)]
976 pub async fn connect_unix(
977 socket_path: &str,
978 user: &str,
979 database: &str,
980 password: Option<&str>,
981 ) -> PgResult<Self> {
982 use tokio::net::UnixStream;
983
984 let unix_stream = UnixStream::connect(socket_path).await?;
985
986 let mut conn = Self {
987 stream: PgStream::Unix(unix_stream),
988 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
989 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
990 sql_buf: BytesMut::with_capacity(512),
991 params_buf: Vec::with_capacity(16),
992 prepared_statements: HashMap::new(),
993 stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
994 column_info_cache: HashMap::new(),
995 process_id: 0,
996 secret_key: 0,
997 notifications: VecDeque::new(),
998 };
999
1000 conn.send(FrontendMessage::Startup {
1001 user: user.to_string(),
1002 database: database.to_string(),
1003 })
1004 .await?;
1005
1006 conn.handle_startup(user, password, AuthSettings::default(), None, None)
1007 .await?;
1008
1009 Ok(conn)
1010 }
1011
1012 async fn handle_startup(
1014 &mut self,
1015 user: &str,
1016 password: Option<&str>,
1017 auth_settings: AuthSettings,
1018 gss_token_provider: Option<GssTokenProvider>,
1019 gss_token_provider_ex: Option<GssTokenProviderEx>,
1020 ) -> PgResult<()> {
1021 let mut scram_client: Option<ScramClient> = None;
1022 let mut gss_mechanism: Option<EnterpriseAuthMechanism> = None;
1023 let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
1024 let mut gss_roundtrips: u32 = 0;
1025 const MAX_GSS_ROUNDTRIPS: u32 = 32;
1026
1027 loop {
1028 let msg = self.recv().await?;
1029 match msg {
1030 BackendMessage::AuthenticationOk => {}
1031 BackendMessage::AuthenticationKerberosV5 => {
1032 if !auth_settings.allow_kerberos_v5 {
1033 return Err(PgError::Auth(
1034 "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
1035 ));
1036 }
1037
1038 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1039 return Err(PgError::Auth(
1040 "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1041 ));
1042 }
1043
1044 let token = generate_gss_token(
1045 gss_session_id,
1046 EnterpriseAuthMechanism::KerberosV5,
1047 None,
1048 gss_token_provider,
1049 gss_token_provider_ex.as_ref(),
1050 )
1051 .map_err(|e| {
1052 PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
1053 })?;
1054
1055 self.send(FrontendMessage::GSSResponse(token)).await?;
1056 gss_mechanism = Some(EnterpriseAuthMechanism::KerberosV5);
1057 }
1058 BackendMessage::AuthenticationGSS => {
1059 if !auth_settings.allow_gssapi {
1060 return Err(PgError::Auth(
1061 "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
1062 ));
1063 }
1064
1065 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1066 return Err(PgError::Auth(
1067 "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1068 ));
1069 }
1070
1071 let token = generate_gss_token(
1072 gss_session_id,
1073 EnterpriseAuthMechanism::GssApi,
1074 None,
1075 gss_token_provider,
1076 gss_token_provider_ex.as_ref(),
1077 )
1078 .map_err(|e| {
1079 PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
1080 })?;
1081
1082 self.send(FrontendMessage::GSSResponse(token)).await?;
1083 gss_mechanism = Some(EnterpriseAuthMechanism::GssApi);
1084 }
1085 BackendMessage::AuthenticationSSPI => {
1086 if !auth_settings.allow_sspi {
1087 return Err(PgError::Auth(
1088 "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
1089 ));
1090 }
1091
1092 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1093 return Err(PgError::Auth(
1094 "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1095 ));
1096 }
1097
1098 let token = generate_gss_token(
1099 gss_session_id,
1100 EnterpriseAuthMechanism::Sspi,
1101 None,
1102 gss_token_provider,
1103 gss_token_provider_ex.as_ref(),
1104 )
1105 .map_err(|e| {
1106 PgError::Auth(format!("SSPI initial token generation failed: {}", e))
1107 })?;
1108
1109 self.send(FrontendMessage::GSSResponse(token)).await?;
1110 gss_mechanism = Some(EnterpriseAuthMechanism::Sspi);
1111 }
1112 BackendMessage::AuthenticationGSSContinue(server_token) => {
1113 gss_roundtrips += 1;
1114 if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
1115 return Err(PgError::Auth(format!(
1116 "GSS handshake exceeded {} roundtrips — aborting",
1117 MAX_GSS_ROUNDTRIPS
1118 )));
1119 }
1120
1121 let mechanism = gss_mechanism.ok_or_else(|| {
1122 PgError::Auth(
1123 "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
1124 .to_string(),
1125 )
1126 })?;
1127
1128 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1129 return Err(PgError::Auth(
1130 "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1131 ));
1132 }
1133
1134 let token = generate_gss_token(
1135 gss_session_id,
1136 mechanism,
1137 Some(&server_token),
1138 gss_token_provider,
1139 gss_token_provider_ex.as_ref(),
1140 )
1141 .map_err(|e| {
1142 PgError::Auth(format!("GSS continue token generation failed: {}", e))
1143 })?;
1144
1145 if !token.is_empty() {
1152 self.send(FrontendMessage::GSSResponse(token)).await?;
1153 }
1154 }
1155 BackendMessage::AuthenticationCleartextPassword => {
1156 if !auth_settings.allow_cleartext_password {
1157 return Err(PgError::Auth(
1158 "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
1159 .to_string(),
1160 ));
1161 }
1162 let password = password.ok_or_else(|| {
1163 PgError::Auth("Password required for cleartext authentication".to_string())
1164 })?;
1165 self.send(FrontendMessage::PasswordMessage(password.to_string()))
1166 .await?;
1167 }
1168 BackendMessage::AuthenticationMD5Password(salt) => {
1169 if !auth_settings.allow_md5_password {
1170 return Err(PgError::Auth(
1171 "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
1172 .to_string(),
1173 ));
1174 }
1175 let password = password.ok_or_else(|| {
1176 PgError::Auth("Password required for MD5 authentication".to_string())
1177 })?;
1178 let md5_password = md5_password_message(user, password, salt);
1179 self.send(FrontendMessage::PasswordMessage(md5_password))
1180 .await?;
1181 }
1182 BackendMessage::AuthenticationSASL(mechanisms) => {
1183 if !auth_settings.allow_scram_sha_256 {
1184 return Err(PgError::Auth(
1185 "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
1186 .to_string(),
1187 ));
1188 }
1189 let password = password.ok_or_else(|| {
1190 PgError::Auth("Password required for SCRAM authentication".to_string())
1191 })?;
1192
1193 let tls_binding = self.tls_server_end_point_channel_binding();
1194 let (mechanism, channel_binding_data) = select_scram_mechanism(
1195 &mechanisms,
1196 tls_binding,
1197 auth_settings.channel_binding,
1198 )
1199 .map_err(PgError::Auth)?;
1200
1201 let client = if let Some(binding_data) = channel_binding_data {
1202 ScramClient::new_with_tls_server_end_point(user, password, binding_data)
1203 } else {
1204 ScramClient::new(user, password)
1205 };
1206 let first_message = client.client_first_message();
1207
1208 self.send(FrontendMessage::SASLInitialResponse {
1209 mechanism,
1210 data: first_message,
1211 })
1212 .await?;
1213
1214 scram_client = Some(client);
1215 }
1216 BackendMessage::AuthenticationSASLContinue(server_data) => {
1217 let client = scram_client.as_mut().ok_or_else(|| {
1218 PgError::Auth("Received SASL Continue without SASL init".to_string())
1219 })?;
1220
1221 let final_message = client
1222 .process_server_first(&server_data)
1223 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
1224
1225 self.send(FrontendMessage::SASLResponse(final_message))
1226 .await?;
1227 }
1228 BackendMessage::AuthenticationSASLFinal(server_signature) => {
1229 if let Some(client) = scram_client.as_ref() {
1230 client.verify_server_final(&server_signature).map_err(|e| {
1231 PgError::Auth(format!("Server verification failed: {}", e))
1232 })?;
1233 }
1234 }
1235 BackendMessage::ParameterStatus { .. } => {}
1236 BackendMessage::BackendKeyData {
1237 process_id,
1238 secret_key,
1239 } => {
1240 self.process_id = process_id;
1241 self.secret_key = secret_key;
1242 }
1243 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
1244 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
1245 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
1246 return Ok(());
1247 }
1248 BackendMessage::ErrorResponse(err) => {
1249 return Err(PgError::Connection(err.message));
1250 }
1251 _ => {}
1252 }
1253 }
1254 }
1255
1256 fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
1261 let PgStream::Tls(tls) = &self.stream else {
1262 return None;
1263 };
1264
1265 let (_, conn) = tls.get_ref();
1266 let certs = conn.peer_certificates()?;
1267 let leaf_cert = certs.first()?;
1268
1269 let mut hasher = Sha256::new();
1270 hasher.update(leaf_cert.as_ref());
1271 Some(hasher.finalize().to_vec())
1272 }
1273
1274 pub async fn close(mut self) -> PgResult<()> {
1277 use crate::protocol::PgEncoder;
1278
1279 let terminate = PgEncoder::encode_terminate();
1281 self.stream.write_all(&terminate).await?;
1282 self.stream.flush().await?;
1283
1284 Ok(())
1285 }
1286
1287 pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
1293
1294 pub(crate) fn evict_prepared_if_full(&mut self) {
1300 if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
1301 if let Some((_hash, evicted_name)) = self.stmt_cache.pop_lru() {
1303 self.prepared_statements.remove(&evicted_name);
1304 } else {
1305 if let Some(key) = self.prepared_statements.keys().next().cloned() {
1309 self.prepared_statements.remove(&key);
1310 }
1311 }
1312 }
1313 }
1314
1315 pub(crate) fn clear_prepared_statement_state(&mut self) {
1320 self.stmt_cache.clear();
1321 self.prepared_statements.clear();
1322 self.column_info_cache.clear();
1323 }
1324}
1325
1326fn generate_gss_token(
1327 session_id: u64,
1328 mechanism: EnterpriseAuthMechanism,
1329 server_token: Option<&[u8]>,
1330 legacy_provider: Option<GssTokenProvider>,
1331 stateful_provider: Option<&GssTokenProviderEx>,
1332) -> Result<Vec<u8>, String> {
1333 if let Some(provider) = stateful_provider {
1334 return provider(GssTokenRequest {
1335 session_id,
1336 mechanism,
1337 server_token,
1338 });
1339 }
1340
1341 if let Some(provider) = legacy_provider {
1342 return provider(mechanism, server_token);
1343 }
1344
1345 Err("No GSS token provider configured".to_string())
1346}
1347
1348fn select_scram_mechanism(
1349 mechanisms: &[String],
1350 tls_server_end_point_binding: Option<Vec<u8>>,
1351 channel_binding_mode: ScramChannelBindingMode,
1352) -> Result<(String, Option<Vec<u8>>), String> {
1353 let has_scram = mechanisms.iter().any(|m| m == "SCRAM-SHA-256");
1354 let has_scram_plus = mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS");
1355
1356 match channel_binding_mode {
1357 ScramChannelBindingMode::Disable => {
1358 if has_scram {
1359 return Ok(("SCRAM-SHA-256".to_string(), None));
1360 }
1361 Err(format!(
1362 "channel_binding=disable, but server does not advertise SCRAM-SHA-256. Available: {:?}",
1363 mechanisms
1364 ))
1365 }
1366 ScramChannelBindingMode::Prefer => {
1367 if has_scram_plus {
1368 if let Some(binding) = tls_server_end_point_binding {
1369 return Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)));
1370 }
1371
1372 if has_scram {
1373 return Ok(("SCRAM-SHA-256".to_string(), None));
1374 }
1375
1376 return Err(
1377 "Server requires SCRAM-SHA-256-PLUS but TLS channel binding is unavailable"
1378 .to_string(),
1379 );
1380 }
1381
1382 if has_scram {
1383 return Ok(("SCRAM-SHA-256".to_string(), None));
1384 }
1385
1386 Err(format!(
1387 "Server doesn't support SCRAM-SHA-256. Available: {:?}",
1388 mechanisms
1389 ))
1390 }
1391 ScramChannelBindingMode::Require => {
1392 if !has_scram_plus {
1393 return Err(
1394 "channel_binding=require, but server does not advertise SCRAM-SHA-256-PLUS"
1395 .to_string(),
1396 );
1397 }
1398 let binding = tls_server_end_point_binding.ok_or_else(|| {
1399 "channel_binding=require, but TLS channel binding data is unavailable".to_string()
1400 })?;
1401 Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)))
1402 }
1403 }
1404}
1405
1406fn md5_password_message(user: &str, password: &str, salt: [u8; 4]) -> String {
1408 use md5::{Digest, Md5};
1409
1410 let mut inner = Md5::new();
1411 inner.update(password.as_bytes());
1412 inner.update(user.as_bytes());
1413 let inner_hex = format!("{:x}", inner.finalize());
1414
1415 let mut outer = Md5::new();
1416 outer.update(inner_hex.as_bytes());
1417 outer.update(salt);
1418 format!("md5{:x}", outer.finalize())
1419}
1420
1421impl Drop for PgConnection {
1424 fn drop(&mut self) {
1425 let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
1428
1429 match &mut self.stream {
1430 PgStream::Tcp(tcp) => {
1431 let _ = tcp.try_write(&terminate);
1433 }
1434 PgStream::Tls(_) => {
1435 }
1439 #[cfg(unix)]
1440 PgStream::Unix(unix) => {
1441 let _ = unix.try_write(&terminate);
1442 }
1443 }
1444 }
1445}
1446
/// Extracts the row count from a CommandComplete tag such as `"INSERT 0 5"` or `"UPDATE 3"`.
///
/// Takes the last whitespace-separated token and parses it as `u64`. Returns 0 when the tag
/// is empty or that token is not a valid `u64` (e.g. `"BEGIN"`).
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    let last_token = tag.split_whitespace().next_back();
    match last_token.map(str::parse::<u64>) {
        Some(Ok(rows)) => rows,
        _ => 0,
    }
}
1453
#[cfg(test)]
mod tests {
    use super::{md5_password_message, parse_affected_rows, select_scram_mechanism};
    use crate::driver::ScramChannelBindingMode;

    #[test]
    fn test_md5_password_message_known_vector() {
        let hash = md5_password_message("postgres", "secret", [0x12, 0x34, 0x56, 0x78]);
        assert_eq!(hash, "md521561af64619ca746c2a6c4d6cbedb30");
    }

    #[test]
    fn test_md5_password_message_is_stable() {
        let a = md5_password_message("user_a", "pw", [1, 2, 3, 4]);
        let b = md5_password_message("user_a", "pw", [1, 2, 3, 4]);
        assert_eq!(a, b);
        assert!(a.starts_with("md5"));
        assert_eq!(a.len(), 35);
    }

    #[test]
    fn test_select_scram_plus_when_binding_available() {
        let mechanisms = vec![
            "SCRAM-SHA-256".to_string(),
            "SCRAM-SHA-256-PLUS".to_string(),
        ];
        let binding = vec![1, 2, 3];
        let (mechanism, selected_binding) = select_scram_mechanism(
            &mechanisms,
            Some(binding.clone()),
            ScramChannelBindingMode::Prefer,
        )
        .unwrap();
        assert_eq!(mechanism, "SCRAM-SHA-256-PLUS");
        assert_eq!(selected_binding, Some(binding));
    }

    #[test]
    fn test_select_scram_fallback_without_binding() {
        let mechanisms = vec![
            "SCRAM-SHA-256".to_string(),
            "SCRAM-SHA-256-PLUS".to_string(),
        ];
        let (mechanism, selected_binding) =
            select_scram_mechanism(&mechanisms, None, ScramChannelBindingMode::Prefer).unwrap();
        assert_eq!(mechanism, "SCRAM-SHA-256");
        assert_eq!(selected_binding, None);
    }

    #[test]
    fn test_select_scram_plus_only_requires_binding() {
        let mechanisms = vec!["SCRAM-SHA-256-PLUS".to_string()];
        let err =
            select_scram_mechanism(&mechanisms, None, ScramChannelBindingMode::Prefer).unwrap_err();
        assert!(err.contains("SCRAM-SHA-256-PLUS"));
    }

    #[test]
    fn test_select_scram_prefer_without_plus_ignores_binding() {
        // The server does not advertise PLUS, so available binding data must not be used.
        let mechanisms = vec!["SCRAM-SHA-256".to_string()];
        let (mechanism, selected_binding) = select_scram_mechanism(
            &mechanisms,
            Some(vec![4, 5, 6]),
            ScramChannelBindingMode::Prefer,
        )
        .unwrap();
        assert_eq!(mechanism, "SCRAM-SHA-256");
        assert_eq!(selected_binding, None);
    }

    #[test]
    fn test_select_scram_prefer_rejects_server_without_scram() {
        let mechanisms = vec!["UNKNOWN-MECH".to_string()];
        let err = select_scram_mechanism(
            &mechanisms,
            Some(vec![1, 2, 3]),
            ScramChannelBindingMode::Prefer,
        )
        .unwrap_err();
        assert!(err.contains("doesn't support SCRAM-SHA-256"));
        assert!(err.contains("UNKNOWN-MECH"));
    }

    #[test]
    fn test_select_scram_require_fails_without_plus() {
        let mechanisms = vec!["SCRAM-SHA-256".to_string()];
        let err = select_scram_mechanism(
            &mechanisms,
            Some(vec![1, 2, 3]),
            ScramChannelBindingMode::Require,
        )
        .unwrap_err();
        assert!(err.contains("channel_binding=require"));
        assert!(err.contains("SCRAM-SHA-256-PLUS"));
    }

    #[test]
    fn test_select_scram_disable_rejects_plus_only() {
        let mechanisms = vec!["SCRAM-SHA-256-PLUS".to_string()];
        let err = select_scram_mechanism(&mechanisms, None, ScramChannelBindingMode::Disable)
            .unwrap_err();
        assert!(err.contains("channel_binding=disable"));
    }

    #[test]
    fn test_select_scram_disable_ignores_available_binding() {
        let mechanisms = vec![
            "SCRAM-SHA-256".to_string(),
            "SCRAM-SHA-256-PLUS".to_string(),
        ];
        let (mechanism, selected_binding) = select_scram_mechanism(
            &mechanisms,
            Some(vec![7, 8, 9]),
            ScramChannelBindingMode::Disable,
        )
        .unwrap();
        assert_eq!(mechanism, "SCRAM-SHA-256");
        assert_eq!(selected_binding, None);
    }

    #[test]
    fn test_select_scram_require_fails_without_tls_binding() {
        let mechanisms = vec![
            "SCRAM-SHA-256".to_string(),
            "SCRAM-SHA-256-PLUS".to_string(),
        ];
        let err = select_scram_mechanism(&mechanisms, None, ScramChannelBindingMode::Require)
            .unwrap_err();
        assert!(err.contains("channel_binding=require"));
        assert!(err.contains("unavailable"));
    }

    #[test]
    fn test_select_scram_require_succeeds_with_plus_and_binding() {
        let mechanisms = vec![
            "SCRAM-SHA-256".to_string(),
            "SCRAM-SHA-256-PLUS".to_string(),
        ];
        let binding = vec![10, 20, 30];
        let (mechanism, selected_binding) = select_scram_mechanism(
            &mechanisms,
            Some(binding.clone()),
            ScramChannelBindingMode::Require,
        )
        .unwrap();
        assert_eq!(mechanism, "SCRAM-SHA-256-PLUS");
        assert_eq!(selected_binding, Some(binding));
    }

    #[test]
    fn test_parse_affected_rows_command_tags() {
        assert_eq!(parse_affected_rows("INSERT 0 5"), 5);
        assert_eq!(parse_affected_rows("UPDATE 12"), 12);
        assert_eq!(parse_affected_rows("DELETE 0"), 0);
        assert_eq!(parse_affected_rows("SELECT 3  "), 3);
    }

    #[test]
    fn test_parse_affected_rows_non_numeric_tags_yield_zero() {
        assert_eq!(parse_affected_rows("BEGIN"), 0);
        assert_eq!(parse_affected_rows("CREATE TABLE"), 0);
        assert_eq!(parse_affected_rows(""), 0);
    }
}