Skip to main content

qail_pg/driver/
connection.rs

1//! PostgreSQL Connection
2//!
3//! Low-level TCP connection with wire protocol handling.
4//! This is Layer 3 (async I/O).
5//!
6//! Methods are split across modules for easier maintenance:
7//! - `io.rs` - Core I/O (send, recv)
8//! - `query.rs` - Query execution
9//! - `transaction.rs` - Transaction control
10//! - `cursor.rs` - Streaming cursors
11//! - `copy.rs` - COPY protocol
12//! - `pipeline.rs` - High-performance pipelining
13//! - `cancel.rs` - Query cancellation
14
15use super::notification::Notification;
16use super::stream::PgStream;
17use super::{
18    AuthSettings, ConnectOptions, EnterpriseAuthMechanism, GssEncMode, GssTokenProvider,
19    GssTokenProviderEx, GssTokenRequest, PgError, PgResult, ScramChannelBindingMode, TlsMode,
20};
21use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
22use bytes::BytesMut;
23use sha2::{Digest, Sha256};
24use std::collections::{HashMap, VecDeque};
25use std::num::NonZeroUsize;
26use std::sync::Arc;
27use std::sync::atomic::{AtomicU64, Ordering};
28use tokio::io::AsyncWriteExt;
29use tokio::net::TcpStream;
30
/// Maximum number of entries each connection's [`StatementCache`] retains
/// before `put` evicts the least recently used one.
const STMT_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(100) {
    Some(capacity) => capacity,
    None => panic!("statement cache capacity must be non-zero"),
};
33
/// Bounded least-recently-used cache of prepared statements.
///
/// Implements only the slice of the `lru::LruCache` API the driver relies on,
/// on top of `HashMap` + `VecDeque`. This sidesteps the external unsoundness
/// advisories about `IterMut`, which the driver never needs.
///
/// Every method keeps `order` in sync with `entries`: the same keys, each
/// exactly once, with the least recently used key at the front.
#[derive(Debug)]
pub(crate) struct StatementCache {
    /// Maximum number of entries kept before the oldest one is evicted.
    capacity: NonZeroUsize,
    /// Cached values by key.
    entries: HashMap<u64, String>,
    /// Keys ordered from least (front) to most (back) recently used.
    order: VecDeque<u64>,
}

impl StatementCache {
    /// Creates an empty cache holding at most `capacity` entries.
    pub(crate) fn new(capacity: NonZeroUsize) -> Self {
        let slots = capacity.get();
        Self {
            entries: HashMap::with_capacity(slots),
            order: VecDeque::with_capacity(slots),
            capacity,
        }
    }

    /// Number of entries currently cached.
    pub(crate) fn len(&self) -> usize {
        self.entries.len()
    }

    /// Configured maximum number of entries.
    pub(crate) fn cap(&self) -> NonZeroUsize {
        self.capacity
    }

    /// Reports whether `key` is cached, without changing its recency.
    pub(crate) fn contains(&self, key: &u64) -> bool {
        self.entries.contains_key(key)
    }

    /// Returns a clone of the value stored under `key` and marks it most
    /// recently used; `None` (with no recency change) on a miss.
    pub(crate) fn get(&mut self, key: &u64) -> Option<String> {
        let cached = self.entries.get(key).cloned();
        if cached.is_some() {
            self.touch(*key);
        }
        cached
    }

    /// Inserts or replaces the value for `key`, marking it most recently used.
    ///
    /// Inserting a *new* key into a full cache first evicts the least recently
    /// used entry.
    pub(crate) fn put(&mut self, key: u64, value: String) {
        if let Some(slot) = self.entries.get_mut(&key) {
            *slot = value;
            self.touch(key);
            return;
        }

        let full = self.entries.len() >= self.capacity.get();
        if full {
            let _ = self.pop_lru();
        }

        self.order.push_back(key);
        self.entries.insert(key, value);
    }

    /// Removes and returns the least recently used entry, if any.
    pub(crate) fn pop_lru(&mut self) -> Option<(u64, String)> {
        loop {
            let oldest = self.order.pop_front()?;
            // Defensive: skip a queued key that has no entry behind it.
            if let Some(value) = self.entries.remove(&oldest) {
                return Some((oldest, value));
            }
        }
    }

    /// Drops every entry; the configured capacity is unchanged.
    pub(crate) fn clear(&mut self) {
        self.order.clear();
        self.entries.clear();
    }

    /// Moves `key` to the most-recently-used end of `order`.
    fn touch(&mut self, key: u64) {
        self.order.retain(|&queued| queued != key);
        self.order.push_back(key);
    }
}
106
/// Initial read/write buffer capacity in bytes (64 KiB, sized for pipelining).
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;

/// Encodes an 8-byte protocol "special request" preface: a big-endian
/// `Int32` length of 8 followed by the big-endian `Int32` request code.
const fn special_request_preface(code: i32) -> [u8; 8] {
    let len = 8i32.to_be_bytes();
    let code = code.to_be_bytes();
    [len[0], len[1], len[2], len[3], code[0], code[1], code[2], code[3]]
}

/// SSLRequest message bytes (request code 80877103 = `1234 << 16 | 5679`).
const SSL_REQUEST: [u8; 8] = special_request_preface((1234 << 16) | 5679);

/// GSSENCRequest message bytes (request code 80877104 = `1234 << 16 | 5680`),
/// i.e. `00 00 00 08` followed by `04 D2 16 30`.
const GSSENC_REQUEST: [u8; 8] = special_request_preface((1234 << 16) | 5680);
116
/// Result of sending a GSSENCRequest to the server.
///
/// Produced by `try_gssenc_request_inner`, which maps the single response
/// byte (`'G'`, `'N'` or `'E'`) onto one of these variants; any other byte is
/// reported as a `PgError::Connection` instead.
#[derive(Debug)]
enum GssEncNegotiationResult {
    /// Server responded 'G' — willing to perform GSSAPI encryption.
    /// The TCP stream is returned for the caller to establish the
    /// GSSAPI security context and wrap all subsequent traffic.
    Accepted(TcpStream),
    /// Server responded 'N' — unwilling to perform GSSAPI encryption.
    Rejected,
    /// Server sent an ErrorMessage — must not be displayed to user
    /// (CVE-2024-10977: server not yet authenticated).
    /// Only the leading 'E' byte is inspected; the message body is never read.
    ServerError,
}
130
/// Protocol code carried by a CancelRequest message.
///
/// PostgreSQL defines it as `1234` in the high 16 bits and `5678` in the low
/// 16 bits, which equals 80877102.
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;

/// Process-wide, monotonically increasing session-id source intended for
/// stateful GSS token provider callbacks; the first id handed out is 1.
static GSS_SESSION_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Upper bound on TCP connect plus the PostgreSQL startup handshake.
///
/// Guards against a Slowloris-style peer that accepts the TCP connection but
/// never answers.
pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration =
    std::time::Duration::from_secs(10);
140
/// TLS configuration for mutual TLS (client certificate authentication).
///
/// `Debug` is implemented by hand rather than derived, so formatting a config
/// (for example in a log line or a panic message) never reveals the client
/// private key. Only PEM sizes are shown.
#[derive(Clone)]
pub struct TlsConfig {
    /// Client certificate in PEM format
    pub client_cert_pem: Vec<u8>,
    /// Client private key in PEM format
    pub client_key_pem: Vec<u8>,
    /// Optional CA certificate for server verification (uses system certs if None)
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsConfig")
            .field("client_cert_pem_len", &self.client_cert_pem.len())
            .field("client_key_pem", &format_args!("<redacted>"))
            .field("ca_cert_pem_len", &self.ca_cert_pem.as_ref().map(Vec::len))
            .finish()
    }
}

impl TlsConfig {
    /// Create a new TLS config by reading PEM files from disk.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error hit while reading `cert_path`, then
    /// `key_path`, then (when given) `ca_path`.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        Ok(Self {
            client_cert_pem: std::fs::read(cert_path)?,
            client_key_pem: std::fs::read(key_path)?,
            ca_cert_pem: ca_path.map(std::fs::read).transpose()?,
        })
    }
}
166
/// Bundled connection parameters for internal functions.
///
/// Groups the 8 common arguments to avoid exceeding clippy's
/// `too_many_arguments` threshold. Built by the `connect*` wrappers in this
/// file and destructured by the `*_inner` connect functions.
struct ConnectParams<'a> {
    /// Server hostname or IP; the TLS paths also use it as the server name.
    host: &'a str,
    /// Server TCP port.
    port: u16,
    /// Role name sent in the StartupMessage.
    user: &'a str,
    /// Database name sent in the StartupMessage.
    database: &'a str,
    /// Optional password, passed through to `handle_startup`.
    password: Option<&'a str>,
    /// Authentication policy, passed through to `handle_startup`.
    auth_settings: AuthSettings,
    /// Optional GSS token callback, passed through to `handle_startup`.
    gss_token_provider: Option<GssTokenProvider>,
    /// Optional extended GSS token callback, passed through to `handle_startup`.
    gss_token_provider_ex: Option<GssTokenProviderEx>,
}
181
/// A raw PostgreSQL connection.
///
/// The `connect*` constructors in this file open the transport, send the
/// StartupMessage and run `handle_startup` before returning the connection.
/// Query, transaction, cursor, COPY, pipeline and cancel methods live in
/// sibling modules (see the module docs).
pub struct PgConnection {
    /// Transport stream (this file constructs the `Tcp`, `Tls` and `GssEnc` variants).
    pub(crate) stream: PgStream,
    /// Pre-sized to `BUFFER_CAPACITY` by the constructors here.
    /// NOTE(review): presumably the receive buffer used by `recv()` — confirm in `io.rs`.
    pub(crate) buffer: BytesMut,
    /// Outgoing message buffer, pre-sized to `BUFFER_CAPACITY` (64 KiB).
    pub(crate) write_buf: BytesMut,
    /// Scratch buffer (512 bytes initial capacity).
    /// NOTE(review): presumably used for SQL text encoding — confirm.
    pub(crate) sql_buf: BytesMut,
    /// Scratch vector for bind parameters (16 slots initial capacity).
    /// NOTE(review): presumably `None` encodes SQL NULL — confirm.
    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
    /// NOTE(review): string-keyed prepared-statement map; the meaning of key
    /// and value is not visible in this file — confirm in `query.rs`.
    pub(crate) prepared_statements: HashMap<String, String>,
    /// Bounded LRU cache (`STMT_CACHE_CAPACITY` entries) keyed by `u64`,
    /// presumably the same statement hash as `column_info_cache` — confirm.
    pub(crate) stmt_cache: StatementCache,
    /// Cache of column metadata (RowDescription) per statement hash.
    /// PostgreSQL only sends RowDescription after Parse, not on subsequent Bind+Execute.
    /// This cache ensures by-name column access works even for cached prepared statements.
    pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
    /// Backend process id; 0 until the server reports it
    /// (presumably via BackendKeyData during `handle_startup` — confirm).
    pub(crate) process_id: i32,
    /// Cancellation secret key; 0 until reported, like `process_id`.
    /// NOTE(review): presumably sent alongside `CANCEL_REQUEST_CODE` by `cancel.rs` — confirm.
    pub(crate) secret_key: i32,
    /// Buffer for asynchronous LISTEN/NOTIFY notifications.
    /// Populated by `recv()` when it encounters NotificationResponse messages.
    pub(crate) notifications: VecDeque<Notification>,
}
201
202impl PgConnection {
203    /// Connect to PostgreSQL server without authentication (trust mode).
204    ///
205    /// # Arguments
206    ///
207    /// * `host` — PostgreSQL server hostname or IP.
208    /// * `port` — TCP port (typically 5432).
209    /// * `user` — PostgreSQL role name.
210    /// * `database` — Target database name.
211    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
212        Self::connect_with_password(host, port, user, database, None).await
213    }
214
215    /// Connect to PostgreSQL server with optional password authentication.
216    /// Includes a default 10-second timeout covering TCP connect + handshake.
217    pub async fn connect_with_password(
218        host: &str,
219        port: u16,
220        user: &str,
221        database: &str,
222        password: Option<&str>,
223    ) -> PgResult<Self> {
224        Self::connect_with_password_and_auth(
225            host,
226            port,
227            user,
228            database,
229            password,
230            AuthSettings::default(),
231        )
232        .await
233    }
234
235    /// Connect to PostgreSQL with explicit enterprise options.
236    ///
237    /// Negotiation preface order follows libpq:
238    ///   1. If gss_enc_mode != Disable → try GSSENCRequest on fresh TCP
239    ///   2. If GSSENC rejected/unavailable and tls_mode != Disable → try SSLRequest
240    ///   3. If both rejected/unavailable → plain StartupMessage
241    pub async fn connect_with_options(
242        host: &str,
243        port: u16,
244        user: &str,
245        database: &str,
246        password: Option<&str>,
247        options: ConnectOptions,
248    ) -> PgResult<Self> {
249        let ConnectOptions {
250            tls_mode,
251            gss_enc_mode,
252            tls_ca_cert_pem,
253            mtls,
254            gss_token_provider,
255            gss_token_provider_ex,
256            auth,
257        } = options;
258
259        if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
260            return Err(PgError::Connection(
261                "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
262            ));
263        }
264
265        // Enforce gss_enc_mode policy before mTLS early-return.
266        // GSSENC and mTLS are both transport-level encryption; using
267        // both simultaneously is not supported by the PostgreSQL protocol.
268        if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
269            return Err(PgError::Connection(
270                "gssencmode=require is incompatible with mTLS — both provide \
271                 transport encryption; use one or the other"
272                    .to_string(),
273            ));
274        }
275
276        if let Some(mtls_config) = mtls {
277            // gss_enc_mode is Disable or Prefer here (Require rejected above).
278            // mTLS already provides transport encryption; skip GSSENC.
279            return Self::connect_mtls_with_password_and_auth_and_gss(
280                ConnectParams {
281                    host,
282                    port,
283                    user,
284                    database,
285                    password,
286                    auth_settings: auth,
287                    gss_token_provider,
288                    gss_token_provider_ex,
289                },
290                mtls_config,
291            )
292            .await;
293        }
294
295        // ── Phase 1: Try GSSENC if requested ──────────────────────────
296        if gss_enc_mode != GssEncMode::Disable {
297            match Self::try_gssenc_request(host, port).await {
298                Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
299                    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
300                    {
301                        let gssenc_fut = async {
302                            let gss_stream = super::gss::gssenc_handshake(tcp_stream, host)
303                                .await
304                                .map_err(PgError::Auth)?;
305                            let mut conn = Self {
306                                stream: PgStream::GssEnc(gss_stream),
307                                buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
308                                write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
309                                sql_buf: BytesMut::with_capacity(512),
310                                params_buf: Vec::with_capacity(16),
311                                prepared_statements: HashMap::new(),
312                                stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
313                                column_info_cache: HashMap::new(),
314                                process_id: 0,
315                                secret_key: 0,
316                                notifications: VecDeque::new(),
317                            };
318                            conn.send(FrontendMessage::Startup {
319                                user: user.to_string(),
320                                database: database.to_string(),
321                            })
322                            .await?;
323                            conn.handle_startup(
324                                user,
325                                password,
326                                auth,
327                                gss_token_provider,
328                                gss_token_provider_ex,
329                            )
330                            .await?;
331                            Ok(conn)
332                        };
333                        return tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
334                            .await
335                            .map_err(|_| {
336                                PgError::Connection(format!(
337                                    "GSSENC connection timeout after {:?} \
338                                 (handshake + auth)",
339                                    DEFAULT_CONNECT_TIMEOUT
340                                ))
341                            })?;
342                    }
343                    #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
344                    {
345                        let _ = tcp_stream;
346                        return Err(PgError::Connection(
347                            "Server accepted GSSENCRequest but GSSAPI encryption requires \
348                             feature enterprise-gssapi on Linux"
349                                .to_string(),
350                        ));
351                    }
352                }
353                Ok(GssEncNegotiationResult::Rejected)
354                | Ok(GssEncNegotiationResult::ServerError) => {
355                    if gss_enc_mode == GssEncMode::Require {
356                        return Err(PgError::Connection(
357                            "gssencmode=require but server rejected GSSENCRequest".to_string(),
358                        ));
359                    }
360                    // gss_enc_mode == Prefer — fall through to TLS / plain
361                }
362                Err(e) => {
363                    if gss_enc_mode == GssEncMode::Require {
364                        return Err(e);
365                    }
366                    // gss_enc_mode == Prefer — connection error, fall through
367                    tracing::debug!(
368                        host = %host,
369                        port = %port,
370                        error = %e,
371                        "gssenc_prefer_fallthrough"
372                    );
373                }
374            }
375        }
376
377        // ── Phase 2: TLS / plain per sslmode ──────────────────────────
378        match tls_mode {
379            TlsMode::Disable => {
380                Self::connect_with_password_and_auth_and_gss(ConnectParams {
381                    host,
382                    port,
383                    user,
384                    database,
385                    password,
386                    auth_settings: auth,
387                    gss_token_provider,
388                    gss_token_provider_ex,
389                })
390                .await
391            }
392            TlsMode::Require => {
393                Self::connect_tls_with_auth_and_gss(
394                    ConnectParams {
395                        host,
396                        port,
397                        user,
398                        database,
399                        password,
400                        auth_settings: auth,
401                        gss_token_provider,
402                        gss_token_provider_ex,
403                    },
404                    tls_ca_cert_pem.as_deref(),
405                )
406                .await
407            }
408            TlsMode::Prefer => {
409                match Self::connect_tls_with_auth_and_gss(
410                    ConnectParams {
411                        host,
412                        port,
413                        user,
414                        database,
415                        password,
416                        auth_settings: auth,
417                        gss_token_provider,
418                        gss_token_provider_ex: gss_token_provider_ex.clone(),
419                    },
420                    tls_ca_cert_pem.as_deref(),
421                )
422                .await
423                {
424                    Ok(conn) => Ok(conn),
425                    Err(PgError::Connection(msg))
426                        if msg.contains("Server does not support TLS") =>
427                    {
428                        Self::connect_with_password_and_auth_and_gss(ConnectParams {
429                            host,
430                            port,
431                            user,
432                            database,
433                            password,
434                            auth_settings: auth,
435                            gss_token_provider,
436                            gss_token_provider_ex,
437                        })
438                        .await
439                    }
440                    Err(e) => Err(e),
441                }
442            }
443        }
444    }
445
446    /// Attempt GSSAPI session encryption negotiation.
447    ///
448    /// Opens a fresh TCP connection, sends GSSENCRequest (80877104),
449    /// reads exactly one byte (CVE-2021-23222 safe), and returns
450    /// the result.  The entire operation is bounded by
451    /// `DEFAULT_CONNECT_TIMEOUT`.
452    async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
453        tokio::time::timeout(
454            DEFAULT_CONNECT_TIMEOUT,
455            Self::try_gssenc_request_inner(host, port),
456        )
457        .await
458        .map_err(|_| {
459            PgError::Connection(format!(
460                "GSSENCRequest timeout after {:?}",
461                DEFAULT_CONNECT_TIMEOUT
462            ))
463        })?
464    }
465
466    /// Inner GSSENCRequest logic without timeout wrapper.
467    async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
468        use tokio::io::AsyncReadExt;
469
470        let addr = format!("{}:{}", host, port);
471        let mut tcp_stream = TcpStream::connect(&addr).await?;
472        tcp_stream.set_nodelay(true)?;
473
474        // Send the 8-byte GSSENCRequest.
475        tcp_stream.write_all(&GSSENC_REQUEST).await?;
476        tcp_stream.flush().await?;
477
478        // CVE-2021-23222: Read exactly one byte.  The server must
479        // respond with a single 'G' or 'N'.  Any additional bytes
480        // in the buffer indicate a buffer-stuffing attack.
481        let mut response = [0u8; 1];
482        tcp_stream.read_exact(&mut response).await?;
483
484        match response[0] {
485            b'G' => {
486                // CVE-2021-23222 check: verify no extra bytes are buffered.
487                // Use a non-blocking peek to detect leftover data.
488                let mut peek_buf = [0u8; 1];
489                match tcp_stream.try_read(&mut peek_buf) {
490                    Ok(0) => {} // EOF — fine (shouldn't happen yet but harmless)
491                    Ok(_n) => {
492                        // Extra bytes after 'G' — possible buffer-stuffing.
493                        return Err(PgError::Connection(
494                            "Protocol violation: extra bytes after GSSENCRequest 'G' response \
495                             (possible CVE-2021-23222 buffer-stuffing attack)"
496                                .to_string(),
497                        ));
498                    }
499                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
500                        // No extra data — this is the expected path.
501                    }
502                    Err(e) => {
503                        return Err(PgError::Io(e));
504                    }
505                }
506                Ok(GssEncNegotiationResult::Accepted(tcp_stream))
507            }
508            b'N' => Ok(GssEncNegotiationResult::Rejected),
509            b'E' => {
510                // Server sent an ErrorMessage.  Per CVE-2024-10977 we
511                // must NOT display this to users since the server has
512                // not been authenticated.  Log at trace only.
513                tracing::trace!(
514                    host = %host,
515                    port = %port,
516                    "gssenc_request_server_error (suppressed per CVE-2024-10977)"
517                );
518                Ok(GssEncNegotiationResult::ServerError)
519            }
520            other => Err(PgError::Connection(format!(
521                "Unexpected response to GSSENCRequest: 0x{:02X} \
522                     (expected 'G'=0x47 or 'N'=0x4E)",
523                other
524            ))),
525        }
526    }
527
528    /// Connect to PostgreSQL server with optional password authentication and auth policy.
529    pub async fn connect_with_password_and_auth(
530        host: &str,
531        port: u16,
532        user: &str,
533        database: &str,
534        password: Option<&str>,
535        auth_settings: AuthSettings,
536    ) -> PgResult<Self> {
537        Self::connect_with_password_and_auth_and_gss(ConnectParams {
538            host,
539            port,
540            user,
541            database,
542            password,
543            auth_settings,
544            gss_token_provider: None,
545            gss_token_provider_ex: None,
546        })
547        .await
548    }
549
550    async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
551        tokio::time::timeout(
552            DEFAULT_CONNECT_TIMEOUT,
553            Self::connect_with_password_inner(params),
554        )
555        .await
556        .map_err(|_| {
557            PgError::Connection(format!(
558                "Connection timeout after {:?} (TCP connect + handshake)",
559                DEFAULT_CONNECT_TIMEOUT
560            ))
561        })?
562    }
563
564    /// Inner connection logic without timeout wrapper.
565    async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
566        let ConnectParams {
567            host,
568            port,
569            user,
570            database,
571            password,
572            auth_settings,
573            gss_token_provider,
574            gss_token_provider_ex,
575        } = params;
576        let addr = format!("{}:{}", host, port);
577        let tcp_stream = TcpStream::connect(&addr).await?;
578
579        // Disable Nagle's algorithm for lower latency
580        tcp_stream.set_nodelay(true)?;
581
582        let mut conn = Self {
583            stream: PgStream::Tcp(tcp_stream),
584            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
585            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
586            sql_buf: BytesMut::with_capacity(512),
587            params_buf: Vec::with_capacity(16), // SQL encoding buffer
588            prepared_statements: HashMap::new(),
589            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
590            column_info_cache: HashMap::new(),
591            process_id: 0,
592            secret_key: 0,
593            notifications: VecDeque::new(),
594        };
595
596        conn.send(FrontendMessage::Startup {
597            user: user.to_string(),
598            database: database.to_string(),
599        })
600        .await?;
601
602        conn.handle_startup(
603            user,
604            password,
605            auth_settings,
606            gss_token_provider,
607            gss_token_provider_ex,
608        )
609        .await?;
610
611        Ok(conn)
612    }
613
614    /// Connect to PostgreSQL server with TLS encryption.
615    /// Includes a default 10-second timeout covering TCP connect + TLS + handshake.
616    pub async fn connect_tls(
617        host: &str,
618        port: u16,
619        user: &str,
620        database: &str,
621        password: Option<&str>,
622    ) -> PgResult<Self> {
623        Self::connect_tls_with_auth(
624            host,
625            port,
626            user,
627            database,
628            password,
629            AuthSettings::default(),
630            None,
631        )
632        .await
633    }
634
635    /// Connect to PostgreSQL over TLS with explicit auth policy and optional custom CA bundle.
636    pub async fn connect_tls_with_auth(
637        host: &str,
638        port: u16,
639        user: &str,
640        database: &str,
641        password: Option<&str>,
642        auth_settings: AuthSettings,
643        ca_cert_pem: Option<&[u8]>,
644    ) -> PgResult<Self> {
645        Self::connect_tls_with_auth_and_gss(
646            ConnectParams {
647                host,
648                port,
649                user,
650                database,
651                password,
652                auth_settings,
653                gss_token_provider: None,
654                gss_token_provider_ex: None,
655            },
656            ca_cert_pem,
657        )
658        .await
659    }
660
661    async fn connect_tls_with_auth_and_gss(
662        params: ConnectParams<'_>,
663        ca_cert_pem: Option<&[u8]>,
664    ) -> PgResult<Self> {
665        tokio::time::timeout(
666            DEFAULT_CONNECT_TIMEOUT,
667            Self::connect_tls_inner(params, ca_cert_pem),
668        )
669        .await
670        .map_err(|_| {
671            PgError::Connection(format!(
672                "TLS connection timeout after {:?}",
673                DEFAULT_CONNECT_TIMEOUT
674            ))
675        })?
676    }
677
678    /// Inner TLS connection logic without timeout wrapper.
679    async fn connect_tls_inner(
680        params: ConnectParams<'_>,
681        ca_cert_pem: Option<&[u8]>,
682    ) -> PgResult<Self> {
683        let ConnectParams {
684            host,
685            port,
686            user,
687            database,
688            password,
689            auth_settings,
690            gss_token_provider,
691            gss_token_provider_ex,
692        } = params;
693        use tokio::io::AsyncReadExt;
694        use tokio_rustls::TlsConnector;
695        use tokio_rustls::rustls::ClientConfig;
696        use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
697
698        let addr = format!("{}:{}", host, port);
699        let mut tcp_stream = TcpStream::connect(&addr).await?;
700
701        // Send SSLRequest
702        tcp_stream.write_all(&SSL_REQUEST).await?;
703
704        // Read response
705        let mut response = [0u8; 1];
706        tcp_stream.read_exact(&mut response).await?;
707
708        if response[0] != b'S' {
709            return Err(PgError::Connection(
710                "Server does not support TLS".to_string(),
711            ));
712        }
713
714        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
715
716        if let Some(ca_pem) = ca_cert_pem {
717            let certs = CertificateDer::pem_slice_iter(ca_pem)
718                .collect::<Result<Vec<_>, _>>()
719                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
720            if certs.is_empty() {
721                return Err(PgError::Connection(
722                    "No CA certificates found in provided PEM".to_string(),
723                ));
724            }
725            for cert in certs {
726                let _ = root_cert_store.add(cert);
727            }
728        } else {
729            let certs = rustls_native_certs::load_native_certs();
730            for cert in certs.certs {
731                let _ = root_cert_store.add(cert);
732            }
733        }
734
735        let config = ClientConfig::builder()
736            .with_root_certificates(root_cert_store)
737            .with_no_client_auth();
738
739        let connector = TlsConnector::from(Arc::new(config));
740        let server_name = ServerName::try_from(host.to_string())
741            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
742
743        let tls_stream = connector
744            .connect(server_name, tcp_stream)
745            .await
746            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
747
748        let mut conn = Self {
749            stream: PgStream::Tls(Box::new(tls_stream)),
750            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
751            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
752            sql_buf: BytesMut::with_capacity(512),
753            params_buf: Vec::with_capacity(16),
754            prepared_statements: HashMap::new(),
755            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
756            column_info_cache: HashMap::new(),
757            process_id: 0,
758            secret_key: 0,
759            notifications: VecDeque::new(),
760        };
761
762        conn.send(FrontendMessage::Startup {
763            user: user.to_string(),
764            database: database.to_string(),
765        })
766        .await?;
767
768        conn.handle_startup(
769            user,
770            password,
771            auth_settings,
772            gss_token_provider,
773            gss_token_provider_ex,
774        )
775        .await?;
776
777        Ok(conn)
778    }
779
780    /// Connect with mutual TLS (client certificate authentication).
781    /// # Arguments
782    /// * `host` - PostgreSQL server hostname
783    /// * `port` - PostgreSQL server port
784    /// * `user` - Database user
785    /// * `database` - Database name
786    /// * `config` - TLS configuration with client cert/key
787    /// # Example
788    /// ```ignore
789    /// let config = TlsConfig {
790    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
791    ///     client_key_pem: include_bytes!("client.key").to_vec(),
792    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
793    /// };
794    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
795    /// ```
796    pub async fn connect_mtls(
797        host: &str,
798        port: u16,
799        user: &str,
800        database: &str,
801        config: TlsConfig,
802    ) -> PgResult<Self> {
803        Self::connect_mtls_with_password_and_auth(
804            host,
805            port,
806            user,
807            database,
808            None,
809            config,
810            AuthSettings::default(),
811        )
812        .await
813    }
814
815    /// Connect with mutual TLS and optional password fallback.
816    pub async fn connect_mtls_with_password_and_auth(
817        host: &str,
818        port: u16,
819        user: &str,
820        database: &str,
821        password: Option<&str>,
822        config: TlsConfig,
823        auth_settings: AuthSettings,
824    ) -> PgResult<Self> {
825        Self::connect_mtls_with_password_and_auth_and_gss(
826            ConnectParams {
827                host,
828                port,
829                user,
830                database,
831                password,
832                auth_settings,
833                gss_token_provider: None,
834                gss_token_provider_ex: None,
835            },
836            config,
837        )
838        .await
839    }
840
841    async fn connect_mtls_with_password_and_auth_and_gss(
842        params: ConnectParams<'_>,
843        config: TlsConfig,
844    ) -> PgResult<Self> {
845        tokio::time::timeout(
846            DEFAULT_CONNECT_TIMEOUT,
847            Self::connect_mtls_inner(params, config),
848        )
849        .await
850        .map_err(|_| {
851            PgError::Connection(format!(
852                "mTLS connection timeout after {:?}",
853                DEFAULT_CONNECT_TIMEOUT
854            ))
855        })?
856    }
857
858    /// Inner mTLS connection logic without timeout wrapper.
859    async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
860        let ConnectParams {
861            host,
862            port,
863            user,
864            database,
865            password,
866            auth_settings,
867            gss_token_provider,
868            gss_token_provider_ex,
869        } = params;
870        use tokio::io::AsyncReadExt;
871        use tokio_rustls::TlsConnector;
872        use tokio_rustls::rustls::{
873            ClientConfig,
874            pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
875        };
876
877        let addr = format!("{}:{}", host, port);
878        let mut tcp_stream = TcpStream::connect(&addr).await?;
879
880        // Send SSLRequest
881        tcp_stream.write_all(&SSL_REQUEST).await?;
882
883        // Read response
884        let mut response = [0u8; 1];
885        tcp_stream.read_exact(&mut response).await?;
886
887        if response[0] != b'S' {
888            return Err(PgError::Connection(
889                "Server does not support TLS".to_string(),
890            ));
891        }
892
893        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
894
895        if let Some(ca_pem) = &config.ca_cert_pem {
896            let certs = CertificateDer::pem_slice_iter(ca_pem)
897                .collect::<Result<Vec<_>, _>>()
898                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
899            if certs.is_empty() {
900                return Err(PgError::Connection(
901                    "No CA certificates found in provided PEM".to_string(),
902                ));
903            }
904            for cert in certs {
905                let _ = root_cert_store.add(cert);
906            }
907        } else {
908            // Use system certs
909            let certs = rustls_native_certs::load_native_certs();
910            for cert in certs.certs {
911                let _ = root_cert_store.add(cert);
912            }
913        }
914
915        let client_certs: Vec<CertificateDer<'static>> =
916            CertificateDer::pem_slice_iter(&config.client_cert_pem)
917                .collect::<Result<Vec<_>, _>>()
918                .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
919        if client_certs.is_empty() {
920            return Err(PgError::Connection(
921                "No client certificates found in PEM".to_string(),
922            ));
923        }
924
925        let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
926            .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
927
928        let tls_config = ClientConfig::builder()
929            .with_root_certificates(root_cert_store)
930            .with_client_auth_cert(client_certs, client_key)
931            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
932
933        let connector = TlsConnector::from(Arc::new(tls_config));
934        let server_name = ServerName::try_from(host.to_string())
935            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
936
937        let tls_stream = connector
938            .connect(server_name, tcp_stream)
939            .await
940            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
941
942        let mut conn = Self {
943            stream: PgStream::Tls(Box::new(tls_stream)),
944            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
945            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
946            sql_buf: BytesMut::with_capacity(512),
947            params_buf: Vec::with_capacity(16),
948            prepared_statements: HashMap::new(),
949            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
950            column_info_cache: HashMap::new(),
951            process_id: 0,
952            secret_key: 0,
953            notifications: VecDeque::new(),
954        };
955
956        conn.send(FrontendMessage::Startup {
957            user: user.to_string(),
958            database: database.to_string(),
959        })
960        .await?;
961
962        conn.handle_startup(
963            user,
964            password,
965            auth_settings,
966            gss_token_provider,
967            gss_token_provider_ex,
968        )
969        .await?;
970
971        Ok(conn)
972    }
973
974    /// Connect to PostgreSQL server via Unix domain socket.
975    #[cfg(unix)]
976    pub async fn connect_unix(
977        socket_path: &str,
978        user: &str,
979        database: &str,
980        password: Option<&str>,
981    ) -> PgResult<Self> {
982        use tokio::net::UnixStream;
983
984        let unix_stream = UnixStream::connect(socket_path).await?;
985
986        let mut conn = Self {
987            stream: PgStream::Unix(unix_stream),
988            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
989            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
990            sql_buf: BytesMut::with_capacity(512),
991            params_buf: Vec::with_capacity(16),
992            prepared_statements: HashMap::new(),
993            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
994            column_info_cache: HashMap::new(),
995            process_id: 0,
996            secret_key: 0,
997            notifications: VecDeque::new(),
998        };
999
1000        conn.send(FrontendMessage::Startup {
1001            user: user.to_string(),
1002            database: database.to_string(),
1003        })
1004        .await?;
1005
1006        conn.handle_startup(user, password, AuthSettings::default(), None, None)
1007            .await?;
1008
1009        Ok(conn)
1010    }
1011
1012    /// Handle startup sequence (auth + params).
1013    async fn handle_startup(
1014        &mut self,
1015        user: &str,
1016        password: Option<&str>,
1017        auth_settings: AuthSettings,
1018        gss_token_provider: Option<GssTokenProvider>,
1019        gss_token_provider_ex: Option<GssTokenProviderEx>,
1020    ) -> PgResult<()> {
1021        let mut scram_client: Option<ScramClient> = None;
1022        let mut gss_mechanism: Option<EnterpriseAuthMechanism> = None;
1023        let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
1024        let mut gss_roundtrips: u32 = 0;
1025        const MAX_GSS_ROUNDTRIPS: u32 = 32;
1026
1027        loop {
1028            let msg = self.recv().await?;
1029            match msg {
1030                BackendMessage::AuthenticationOk => {}
1031                BackendMessage::AuthenticationKerberosV5 => {
1032                    if !auth_settings.allow_kerberos_v5 {
1033                        return Err(PgError::Auth(
1034                            "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
1035                        ));
1036                    }
1037
1038                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1039                        return Err(PgError::Auth(
1040                            "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1041                        ));
1042                    }
1043
1044                    let token = generate_gss_token(
1045                        gss_session_id,
1046                        EnterpriseAuthMechanism::KerberosV5,
1047                        None,
1048                        gss_token_provider,
1049                        gss_token_provider_ex.as_ref(),
1050                    )
1051                    .map_err(|e| {
1052                        PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
1053                    })?;
1054
1055                    self.send(FrontendMessage::GSSResponse(token)).await?;
1056                    gss_mechanism = Some(EnterpriseAuthMechanism::KerberosV5);
1057                }
1058                BackendMessage::AuthenticationGSS => {
1059                    if !auth_settings.allow_gssapi {
1060                        return Err(PgError::Auth(
1061                            "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
1062                        ));
1063                    }
1064
1065                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1066                        return Err(PgError::Auth(
1067                            "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1068                        ));
1069                    }
1070
1071                    let token = generate_gss_token(
1072                        gss_session_id,
1073                        EnterpriseAuthMechanism::GssApi,
1074                        None,
1075                        gss_token_provider,
1076                        gss_token_provider_ex.as_ref(),
1077                    )
1078                    .map_err(|e| {
1079                        PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
1080                    })?;
1081
1082                    self.send(FrontendMessage::GSSResponse(token)).await?;
1083                    gss_mechanism = Some(EnterpriseAuthMechanism::GssApi);
1084                }
1085                BackendMessage::AuthenticationSSPI => {
1086                    if !auth_settings.allow_sspi {
1087                        return Err(PgError::Auth(
1088                            "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
1089                        ));
1090                    }
1091
1092                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1093                        return Err(PgError::Auth(
1094                            "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1095                        ));
1096                    }
1097
1098                    let token = generate_gss_token(
1099                        gss_session_id,
1100                        EnterpriseAuthMechanism::Sspi,
1101                        None,
1102                        gss_token_provider,
1103                        gss_token_provider_ex.as_ref(),
1104                    )
1105                    .map_err(|e| {
1106                        PgError::Auth(format!("SSPI initial token generation failed: {}", e))
1107                    })?;
1108
1109                    self.send(FrontendMessage::GSSResponse(token)).await?;
1110                    gss_mechanism = Some(EnterpriseAuthMechanism::Sspi);
1111                }
1112                BackendMessage::AuthenticationGSSContinue(server_token) => {
1113                    gss_roundtrips += 1;
1114                    if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
1115                        return Err(PgError::Auth(format!(
1116                            "GSS handshake exceeded {} roundtrips — aborting",
1117                            MAX_GSS_ROUNDTRIPS
1118                        )));
1119                    }
1120
1121                    let mechanism = gss_mechanism.ok_or_else(|| {
1122                        PgError::Auth(
1123                            "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
1124                                .to_string(),
1125                        )
1126                    })?;
1127
1128                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
1129                        return Err(PgError::Auth(
1130                            "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
1131                        ));
1132                    }
1133
1134                    let token = generate_gss_token(
1135                        gss_session_id,
1136                        mechanism,
1137                        Some(&server_token),
1138                        gss_token_provider,
1139                        gss_token_provider_ex.as_ref(),
1140                    )
1141                    .map_err(|e| {
1142                        PgError::Auth(format!("GSS continue token generation failed: {}", e))
1143                    })?;
1144
1145                    // Only send the response if there is actually a token to
1146                    // send.  When gss_init_sec_context returns GSS_S_COMPLETE
1147                    // on the final round, the token may be empty.  Sending an
1148                    // empty GSSResponse ('p') after the server already
1149                    // considers auth complete trips the "invalid frontend
1150                    // message type 112" FATAL in PostgreSQL.
1151                    if !token.is_empty() {
1152                        self.send(FrontendMessage::GSSResponse(token)).await?;
1153                    }
1154                }
1155                BackendMessage::AuthenticationCleartextPassword => {
1156                    if !auth_settings.allow_cleartext_password {
1157                        return Err(PgError::Auth(
1158                            "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
1159                                .to_string(),
1160                        ));
1161                    }
1162                    let password = password.ok_or_else(|| {
1163                        PgError::Auth("Password required for cleartext authentication".to_string())
1164                    })?;
1165                    self.send(FrontendMessage::PasswordMessage(password.to_string()))
1166                        .await?;
1167                }
1168                BackendMessage::AuthenticationMD5Password(salt) => {
1169                    if !auth_settings.allow_md5_password {
1170                        return Err(PgError::Auth(
1171                            "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
1172                                .to_string(),
1173                        ));
1174                    }
1175                    let password = password.ok_or_else(|| {
1176                        PgError::Auth("Password required for MD5 authentication".to_string())
1177                    })?;
1178                    let md5_password = md5_password_message(user, password, salt);
1179                    self.send(FrontendMessage::PasswordMessage(md5_password))
1180                        .await?;
1181                }
1182                BackendMessage::AuthenticationSASL(mechanisms) => {
1183                    if !auth_settings.allow_scram_sha_256 {
1184                        return Err(PgError::Auth(
1185                            "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
1186                                .to_string(),
1187                        ));
1188                    }
1189                    let password = password.ok_or_else(|| {
1190                        PgError::Auth("Password required for SCRAM authentication".to_string())
1191                    })?;
1192
1193                    let tls_binding = self.tls_server_end_point_channel_binding();
1194                    let (mechanism, channel_binding_data) = select_scram_mechanism(
1195                        &mechanisms,
1196                        tls_binding,
1197                        auth_settings.channel_binding,
1198                    )
1199                    .map_err(PgError::Auth)?;
1200
1201                    let client = if let Some(binding_data) = channel_binding_data {
1202                        ScramClient::new_with_tls_server_end_point(user, password, binding_data)
1203                    } else {
1204                        ScramClient::new(user, password)
1205                    };
1206                    let first_message = client.client_first_message();
1207
1208                    self.send(FrontendMessage::SASLInitialResponse {
1209                        mechanism,
1210                        data: first_message,
1211                    })
1212                    .await?;
1213
1214                    scram_client = Some(client);
1215                }
1216                BackendMessage::AuthenticationSASLContinue(server_data) => {
1217                    let client = scram_client.as_mut().ok_or_else(|| {
1218                        PgError::Auth("Received SASL Continue without SASL init".to_string())
1219                    })?;
1220
1221                    let final_message = client
1222                        .process_server_first(&server_data)
1223                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
1224
1225                    self.send(FrontendMessage::SASLResponse(final_message))
1226                        .await?;
1227                }
1228                BackendMessage::AuthenticationSASLFinal(server_signature) => {
1229                    if let Some(client) = scram_client.as_ref() {
1230                        client.verify_server_final(&server_signature).map_err(|e| {
1231                            PgError::Auth(format!("Server verification failed: {}", e))
1232                        })?;
1233                    }
1234                }
1235                BackendMessage::ParameterStatus { .. } => {}
1236                BackendMessage::BackendKeyData {
1237                    process_id,
1238                    secret_key,
1239                } => {
1240                    self.process_id = process_id;
1241                    self.secret_key = secret_key;
1242                }
1243                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
1244                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
1245                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
1246                    return Ok(());
1247                }
1248                BackendMessage::ErrorResponse(err) => {
1249                    return Err(PgError::Connection(err.message));
1250                }
1251                _ => {}
1252            }
1253        }
1254    }
1255
1256    /// Build SCRAM `tls-server-end-point` channel-binding bytes from the server leaf cert.
1257    ///
1258    /// PostgreSQL expects the hash of the peer certificate DER for
1259    /// `SCRAM-SHA-256-PLUS` channel binding. We currently use SHA-256 here.
1260    fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
1261        let PgStream::Tls(tls) = &self.stream else {
1262            return None;
1263        };
1264
1265        let (_, conn) = tls.get_ref();
1266        let certs = conn.peer_certificates()?;
1267        let leaf_cert = certs.first()?;
1268
1269        let mut hasher = Sha256::new();
1270        hasher.update(leaf_cert.as_ref());
1271        Some(hasher.finalize().to_vec())
1272    }
1273
1274    /// Gracefully close the connection by sending a Terminate message.
1275    /// This tells the server we're done and allows proper cleanup.
1276    pub async fn close(mut self) -> PgResult<()> {
1277        use crate::protocol::PgEncoder;
1278
1279        // Send Terminate packet ('X')
1280        let terminate = PgEncoder::encode_terminate();
1281        self.stream.write_all(&terminate).await?;
1282        self.stream.flush().await?;
1283
1284        Ok(())
1285    }
1286
1287    /// Maximum prepared statements per connection before LRU eviction kicks in.
1288    ///
1289    /// This prevents memory spikes from dynamic batch filters generating
1290    /// thousands of unique SQL shapes within a single request. Using LRU
1291    /// eviction instead of nuclear `.clear()` preserves hot statements.
1292    pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
1293
1294    /// Evict the least-recently-used prepared statement if at capacity.
1295    ///
1296    /// Called before every new statement registration to enforce
1297    /// `MAX_PREPARED_PER_CONN`. Both `stmt_cache` (LRU ordering) and
1298    /// `prepared_statements` (name→SQL map) are kept in sync.
1299    pub(crate) fn evict_prepared_if_full(&mut self) {
1300        if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
1301            // Pop the LRU entry from the cache
1302            if let Some((_hash, evicted_name)) = self.stmt_cache.pop_lru() {
1303                self.prepared_statements.remove(&evicted_name);
1304            } else {
1305                // stmt_cache is empty but prepared_statements is full —
1306                // shouldn't happen in normal flow, but handle defensively
1307                // by clearing the oldest entry from the HashMap.
1308                if let Some(key) = self.prepared_statements.keys().next().cloned() {
1309                    self.prepared_statements.remove(&key);
1310                }
1311            }
1312        }
1313    }
1314
1315    /// Clear all local prepared-statement state for this connection.
1316    ///
1317    /// Used by one-shot self-heal paths when server-side statement state
1318    /// becomes invalid after DDL or failover.
1319    pub(crate) fn clear_prepared_statement_state(&mut self) {
1320        self.stmt_cache.clear();
1321        self.prepared_statements.clear();
1322        self.column_info_cache.clear();
1323    }
1324}
1325
1326fn generate_gss_token(
1327    session_id: u64,
1328    mechanism: EnterpriseAuthMechanism,
1329    server_token: Option<&[u8]>,
1330    legacy_provider: Option<GssTokenProvider>,
1331    stateful_provider: Option<&GssTokenProviderEx>,
1332) -> Result<Vec<u8>, String> {
1333    if let Some(provider) = stateful_provider {
1334        return provider(GssTokenRequest {
1335            session_id,
1336            mechanism,
1337            server_token,
1338        });
1339    }
1340
1341    if let Some(provider) = legacy_provider {
1342        return provider(mechanism, server_token);
1343    }
1344
1345    Err("No GSS token provider configured".to_string())
1346}
1347
1348fn select_scram_mechanism(
1349    mechanisms: &[String],
1350    tls_server_end_point_binding: Option<Vec<u8>>,
1351    channel_binding_mode: ScramChannelBindingMode,
1352) -> Result<(String, Option<Vec<u8>>), String> {
1353    let has_scram = mechanisms.iter().any(|m| m == "SCRAM-SHA-256");
1354    let has_scram_plus = mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS");
1355
1356    match channel_binding_mode {
1357        ScramChannelBindingMode::Disable => {
1358            if has_scram {
1359                return Ok(("SCRAM-SHA-256".to_string(), None));
1360            }
1361            Err(format!(
1362                "channel_binding=disable, but server does not advertise SCRAM-SHA-256. Available: {:?}",
1363                mechanisms
1364            ))
1365        }
1366        ScramChannelBindingMode::Prefer => {
1367            if has_scram_plus {
1368                if let Some(binding) = tls_server_end_point_binding {
1369                    return Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)));
1370                }
1371
1372                if has_scram {
1373                    return Ok(("SCRAM-SHA-256".to_string(), None));
1374                }
1375
1376                return Err(
1377                    "Server requires SCRAM-SHA-256-PLUS but TLS channel binding is unavailable"
1378                        .to_string(),
1379                );
1380            }
1381
1382            if has_scram {
1383                return Ok(("SCRAM-SHA-256".to_string(), None));
1384            }
1385
1386            Err(format!(
1387                "Server doesn't support SCRAM-SHA-256. Available: {:?}",
1388                mechanisms
1389            ))
1390        }
1391        ScramChannelBindingMode::Require => {
1392            if !has_scram_plus {
1393                return Err(
1394                    "channel_binding=require, but server does not advertise SCRAM-SHA-256-PLUS"
1395                        .to_string(),
1396                );
1397            }
1398            let binding = tls_server_end_point_binding.ok_or_else(|| {
1399                "channel_binding=require, but TLS channel binding data is unavailable".to_string()
1400            })?;
1401            Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)))
1402        }
1403    }
1404}
1405
1406/// PostgreSQL MD5 password response: `md5` + md5(hex(md5(password + user)) + 4-byte salt).
1407fn md5_password_message(user: &str, password: &str, salt: [u8; 4]) -> String {
1408    use md5::{Digest, Md5};
1409
1410    let mut inner = Md5::new();
1411    inner.update(password.as_bytes());
1412    inner.update(user.as_bytes());
1413    let inner_hex = format!("{:x}", inner.finalize());
1414
1415    let mut outer = Md5::new();
1416    outer.update(inner_hex.as_bytes());
1417    outer.update(salt);
1418    format!("md5{:x}", outer.finalize())
1419}
1420
1421/// Drop implementation sends Terminate packet if possible.
1422/// This ensures proper cleanup even without explicit close() call.
1423impl Drop for PgConnection {
1424    fn drop(&mut self) {
1425        // Try to send Terminate packet synchronously using try_write
1426        // This is best-effort - if it fails, TCP RST will handle cleanup
1427        let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
1428
1429        match &mut self.stream {
1430            PgStream::Tcp(tcp) => {
1431                // try_write is non-blocking
1432                let _ = tcp.try_write(&terminate);
1433            }
1434            PgStream::Tls(_) => {
1435                // TLS requires async write which we can't do in Drop.
1436                // The TCP connection close will still notify the server.
1437                // For graceful TLS shutdown, use connection.close() explicitly.
1438            }
1439            #[cfg(unix)]
1440            PgStream::Unix(unix) => {
1441                let _ = unix.try_write(&terminate);
1442            }
1443        }
1444    }
1445}
1446
/// Extract the row count from a command tag such as `INSERT 0 5` or `UPDATE 3`.
///
/// The count is the last whitespace-separated word. An empty tag, or one whose
/// last word is not an unsigned integer (e.g. `CREATE TABLE`), yields 0.
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back() {
        Some(word) => word.parse().unwrap_or(0),
        None => 0,
    }
}
1453
#[cfg(test)]
mod tests {
    use super::{md5_password_message, select_scram_mechanism};
    use crate::driver::ScramChannelBindingMode;

    const SCRAM: &str = "SCRAM-SHA-256";
    const SCRAM_PLUS: &str = "SCRAM-SHA-256-PLUS";

    /// Owned mechanism list, in the shape `AuthenticationSASL` delivers it.
    fn mechanisms(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    #[test]
    fn test_md5_password_message_known_vector() {
        assert_eq!(
            md5_password_message("postgres", "secret", [0x12, 0x34, 0x56, 0x78]),
            "md521561af64619ca746c2a6c4d6cbedb30"
        );
    }

    #[test]
    fn test_md5_password_message_is_stable() {
        let first = md5_password_message("user_a", "pw", [1, 2, 3, 4]);
        let second = md5_password_message("user_a", "pw", [1, 2, 3, 4]);
        assert_eq!(first, second);
        assert!(first.starts_with("md5"));
        // "md5" prefix followed by 32 hex digits.
        assert_eq!(first.len(), 35);
    }

    #[test]
    fn test_select_scram_plus_when_binding_available() {
        let offered = mechanisms(&[SCRAM, SCRAM_PLUS]);
        let binding = vec![1, 2, 3];
        let selected = select_scram_mechanism(
            &offered,
            Some(binding.clone()),
            ScramChannelBindingMode::Prefer,
        )
        .unwrap();
        assert_eq!(selected, (SCRAM_PLUS.to_string(), Some(binding)));
    }

    #[test]
    fn test_select_scram_fallback_without_binding() {
        let offered = mechanisms(&[SCRAM, SCRAM_PLUS]);
        let selected =
            select_scram_mechanism(&offered, None, ScramChannelBindingMode::Prefer).unwrap();
        assert_eq!(selected, (SCRAM.to_string(), None));
    }

    #[test]
    fn test_select_scram_plus_only_requires_binding() {
        let offered = mechanisms(&[SCRAM_PLUS]);
        let err =
            select_scram_mechanism(&offered, None, ScramChannelBindingMode::Prefer).unwrap_err();
        assert!(err.contains(SCRAM_PLUS));
    }

    #[test]
    fn test_select_scram_require_fails_without_plus() {
        let offered = mechanisms(&[SCRAM]);
        let err = select_scram_mechanism(
            &offered,
            Some(vec![1, 2, 3]),
            ScramChannelBindingMode::Require,
        )
        .unwrap_err();
        assert!(err.contains("channel_binding=require"));
        assert!(err.contains(SCRAM_PLUS));
    }

    #[test]
    fn test_select_scram_disable_rejects_plus_only() {
        let offered = mechanisms(&[SCRAM_PLUS]);
        let err =
            select_scram_mechanism(&offered, None, ScramChannelBindingMode::Disable).unwrap_err();
        assert!(err.contains("channel_binding=disable"));
    }

    #[test]
    fn test_select_scram_require_fails_without_tls_binding() {
        let offered = mechanisms(&[SCRAM, SCRAM_PLUS]);
        let err =
            select_scram_mechanism(&offered, None, ScramChannelBindingMode::Require).unwrap_err();
        assert!(err.contains("channel_binding=require"));
        assert!(err.contains("unavailable"));
    }

    #[test]
    fn test_select_scram_require_succeeds_with_plus_and_binding() {
        let offered = mechanisms(&[SCRAM, SCRAM_PLUS]);
        let binding = vec![10, 20, 30];
        let selected = select_scram_mechanism(
            &offered,
            Some(binding.clone()),
            ScramChannelBindingMode::Require,
        )
        .unwrap();
        assert_eq!(selected, (SCRAM_PLUS.to_string(), Some(binding)));
    }
}