Skip to main content

qail_pg/driver/connection/
connect.rs

1//! Connection establishment — connect_*, TLS, mTLS, Unix socket.
2
3#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6    connect_backend_for_stream, connect_error_kind, plain_connect_attempt_backend,
7    record_connect_attempt, record_connect_result,
8};
9#[cfg(all(target_os = "linux", feature = "io_uring"))]
10use super::types::CONNECT_BACKEND_IO_URING;
11use super::types::{
12    BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
13    CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
14    GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
15    StatementCache, TlsConfig, has_logical_replication_startup_mode,
16};
17use crate::driver::stream::PgStream;
18use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
19use crate::protocol::wire::FrontendMessage;
20use bytes::BytesMut;
21use std::collections::{HashMap, VecDeque};
22use std::sync::Arc;
23use std::time::Instant;
24use tokio::io::AsyncWriteExt;
25use tokio::net::TcpStream;
26
27impl PgConnection {
28    /// Connect to PostgreSQL server without authentication (trust mode).
29    ///
30    /// # Arguments
31    ///
32    /// * `host` — PostgreSQL server hostname or IP.
33    /// * `port` — TCP port (typically 5432).
34    /// * `user` — PostgreSQL role name.
35    /// * `database` — Target database name.
36    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
37        Self::connect_with_password(host, port, user, database, None).await
38    }
39
40    /// Connect to PostgreSQL server with optional password authentication.
41    /// Includes a default 10-second timeout covering TCP connect + handshake.
42    pub async fn connect_with_password(
43        host: &str,
44        port: u16,
45        user: &str,
46        database: &str,
47        password: Option<&str>,
48    ) -> PgResult<Self> {
49        Self::connect_with_password_and_auth(
50            host,
51            port,
52            user,
53            database,
54            password,
55            AuthSettings::default(),
56        )
57        .await
58    }
59
60    /// Connect to PostgreSQL with explicit enterprise options.
61    ///
62    /// Negotiation preface order follows libpq:
63    ///   1. If gss_enc_mode != Disable → try GSSENCRequest on fresh TCP
64    ///   2. If GSSENC rejected/unavailable and tls_mode != Disable → try SSLRequest
65    ///   3. If both rejected/unavailable → plain StartupMessage
66    pub async fn connect_with_options(
67        host: &str,
68        port: u16,
69        user: &str,
70        database: &str,
71        password: Option<&str>,
72        options: ConnectOptions,
73    ) -> PgResult<Self> {
74        let ConnectOptions {
75            tls_mode,
76            gss_enc_mode,
77            tls_ca_cert_pem,
78            mtls,
79            gss_token_provider,
80            gss_token_provider_ex,
81            auth,
82            startup_params,
83        } = options;
84
85        if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
86            return Err(PgError::Connection(
87                "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
88            ));
89        }
90
91        // Enforce gss_enc_mode policy before mTLS early-return.
92        // GSSENC and mTLS are both transport-level encryption; using
93        // both simultaneously is not supported by the PostgreSQL protocol.
94        if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
95            return Err(PgError::Connection(
96                "gssencmode=require is incompatible with mTLS — both provide \
97                 transport encryption; use one or the other"
98                    .to_string(),
99            ));
100        }
101
102        if let Some(mtls_config) = mtls {
103            // gss_enc_mode is Disable or Prefer here (Require rejected above).
104            // mTLS already provides transport encryption; skip GSSENC.
105            return Self::connect_mtls_with_password_and_auth_and_gss(
106                ConnectParams {
107                    host,
108                    port,
109                    user,
110                    database,
111                    password,
112                    auth_settings: auth,
113                    gss_token_provider,
114                    gss_token_provider_ex,
115                    startup_params: startup_params.clone(),
116                },
117                mtls_config,
118            )
119            .await;
120        }
121
122        // ── Phase 1: Try GSSENC if requested ──────────────────────────
123        if gss_enc_mode != GssEncMode::Disable {
124            match Self::try_gssenc_request(host, port).await {
125                Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
126                    let connect_started = Instant::now();
127                    record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
128                    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
129                    {
130                        let gssenc_fut = async {
131                            let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, host)
132                                .await
133                                .map_err(PgError::Auth)?;
134                            let mut conn = Self {
135                                stream: PgStream::GssEnc(gss_stream),
136                                buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
137                                write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
138                                sql_buf: BytesMut::with_capacity(512),
139                                params_buf: Vec::with_capacity(16),
140                                prepared_statements: HashMap::new(),
141                                stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
142                                column_info_cache: HashMap::new(),
143                                process_id: 0,
144                                secret_key: 0,
145                                notifications: VecDeque::new(),
146                                replication_stream_active: false,
147                                replication_mode_enabled: has_logical_replication_startup_mode(
148                                    &startup_params,
149                                ),
150                                last_replication_wal_end: None,
151                                io_desynced: false,
152                                pending_statement_closes: Vec::new(),
153                                draining_statement_closes: false,
154                            };
155                            conn.send(FrontendMessage::Startup {
156                                user: user.to_string(),
157                                database: database.to_string(),
158                                startup_params: startup_params.clone(),
159                            })
160                            .await?;
161                            conn.handle_startup(
162                                user,
163                                password,
164                                auth,
165                                gss_token_provider,
166                                gss_token_provider_ex,
167                            )
168                            .await?;
169                            Ok(conn)
170                        };
171                        let result: PgResult<Self> =
172                            tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
173                                .await
174                                .map_err(|_| {
175                                    PgError::Connection(format!(
176                                        "GSSENC connection timeout after {:?} \
177                                 (handshake + auth)",
178                                        DEFAULT_CONNECT_TIMEOUT
179                                    ))
180                                })?;
181                        record_connect_result(
182                            CONNECT_TRANSPORT_GSSENC,
183                            CONNECT_BACKEND_TOKIO,
184                            &result,
185                            connect_started.elapsed(),
186                        );
187                        return result;
188                    }
189                    #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
190                    {
191                        let _ = tcp_stream;
192                        let err = PgError::Connection(
193                            "Server accepted GSSENCRequest but GSSAPI encryption requires \
194                             feature enterprise-gssapi on Linux"
195                                .to_string(),
196                        );
197                        metrics::histogram!(
198                            "qail_pg_connect_duration_seconds",
199                            "transport" => CONNECT_TRANSPORT_GSSENC,
200                            "backend" => CONNECT_BACKEND_TOKIO,
201                            "outcome" => "error"
202                        )
203                        .record(connect_started.elapsed().as_secs_f64());
204                        metrics::counter!(
205                            "qail_pg_connect_failure_total",
206                            "transport" => CONNECT_TRANSPORT_GSSENC,
207                            "backend" => CONNECT_BACKEND_TOKIO,
208                            "error_kind" => connect_error_kind(&err)
209                        )
210                        .increment(1);
211                        return Err(err);
212                    }
213                }
214                Ok(GssEncNegotiationResult::Rejected)
215                | Ok(GssEncNegotiationResult::ServerError) => {
216                    if gss_enc_mode == GssEncMode::Require {
217                        return Err(PgError::Connection(
218                            "gssencmode=require but server rejected GSSENCRequest".to_string(),
219                        ));
220                    }
221                    // gss_enc_mode == Prefer — fall through to TLS / plain
222                }
223                Err(e) => {
224                    if gss_enc_mode == GssEncMode::Require {
225                        return Err(e);
226                    }
227                    // gss_enc_mode == Prefer — connection error, fall through
228                    tracing::debug!(
229                        host = %host,
230                        port = %port,
231                        error = %e,
232                        "gssenc_prefer_fallthrough"
233                    );
234                }
235            }
236        }
237
238        // ── Phase 2: TLS / plain per sslmode ──────────────────────────
239        match tls_mode {
240            TlsMode::Disable => {
241                Self::connect_with_password_and_auth_and_gss(ConnectParams {
242                    host,
243                    port,
244                    user,
245                    database,
246                    password,
247                    auth_settings: auth,
248                    gss_token_provider,
249                    gss_token_provider_ex,
250                    startup_params: startup_params.clone(),
251                })
252                .await
253            }
254            TlsMode::Require => {
255                Self::connect_tls_with_auth_and_gss(
256                    ConnectParams {
257                        host,
258                        port,
259                        user,
260                        database,
261                        password,
262                        auth_settings: auth,
263                        gss_token_provider,
264                        gss_token_provider_ex,
265                        startup_params: startup_params.clone(),
266                    },
267                    tls_ca_cert_pem.as_deref(),
268                )
269                .await
270            }
271            TlsMode::Prefer => {
272                match Self::connect_tls_with_auth_and_gss(
273                    ConnectParams {
274                        host,
275                        port,
276                        user,
277                        database,
278                        password,
279                        auth_settings: auth,
280                        gss_token_provider,
281                        gss_token_provider_ex: gss_token_provider_ex.clone(),
282                        startup_params: startup_params.clone(),
283                    },
284                    tls_ca_cert_pem.as_deref(),
285                )
286                .await
287                {
288                    Ok(conn) => Ok(conn),
289                    Err(PgError::Connection(msg))
290                        if msg.contains("Server does not support TLS") =>
291                    {
292                        Self::connect_with_password_and_auth_and_gss(ConnectParams {
293                            host,
294                            port,
295                            user,
296                            database,
297                            password,
298                            auth_settings: auth,
299                            gss_token_provider,
300                            gss_token_provider_ex,
301                            startup_params: startup_params.clone(),
302                        })
303                        .await
304                    }
305                    Err(e) => Err(e),
306                }
307            }
308        }
309    }
310
311    /// Attempt GSSAPI session encryption negotiation.
312    ///
313    /// Opens a fresh TCP connection, sends GSSENCRequest (80877104),
314    /// reads exactly one byte (CVE-2021-23222 safe), and returns
315    /// the result.  The entire operation is bounded by
316    /// `DEFAULT_CONNECT_TIMEOUT`.
317    async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
318        tokio::time::timeout(
319            DEFAULT_CONNECT_TIMEOUT,
320            Self::try_gssenc_request_inner(host, port),
321        )
322        .await
323        .map_err(|_| {
324            PgError::Connection(format!(
325                "GSSENCRequest timeout after {:?}",
326                DEFAULT_CONNECT_TIMEOUT
327            ))
328        })?
329    }
330
331    /// Inner GSSENCRequest logic without timeout wrapper.
332    async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
333        use tokio::io::AsyncReadExt;
334
335        let addr = format!("{}:{}", host, port);
336        let mut tcp_stream = TcpStream::connect(&addr).await?;
337        tcp_stream.set_nodelay(true)?;
338
339        // Send the 8-byte GSSENCRequest.
340        tcp_stream.write_all(&GSSENC_REQUEST).await?;
341        tcp_stream.flush().await?;
342
343        // CVE-2021-23222: Read exactly one byte.  The server must
344        // respond with a single 'G' or 'N'.  Any additional bytes
345        // in the buffer indicate a buffer-stuffing attack.
346        let mut response = [0u8; 1];
347        tcp_stream.read_exact(&mut response).await?;
348
349        match response[0] {
350            b'G' => {
351                // CVE-2021-23222 check: verify no extra bytes are buffered.
352                // Use a non-blocking peek to detect leftover data.
353                let mut peek_buf = [0u8; 1];
354                match tcp_stream.try_read(&mut peek_buf) {
355                    Ok(0) => {} // EOF — fine (shouldn't happen yet but harmless)
356                    Ok(_n) => {
357                        // Extra bytes after 'G' — possible buffer-stuffing.
358                        return Err(PgError::Connection(
359                            "Protocol violation: extra bytes after GSSENCRequest 'G' response \
360                             (possible CVE-2021-23222 buffer-stuffing attack)"
361                                .to_string(),
362                        ));
363                    }
364                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
365                        // No extra data — this is the expected path.
366                    }
367                    Err(e) => {
368                        return Err(PgError::Io(e));
369                    }
370                }
371                Ok(GssEncNegotiationResult::Accepted(tcp_stream))
372            }
373            b'N' => Ok(GssEncNegotiationResult::Rejected),
374            b'E' => {
375                // Server sent an ErrorMessage.  Per CVE-2024-10977 we
376                // must NOT display this to users since the server has
377                // not been authenticated.  Log at trace only.
378                tracing::trace!(
379                    host = %host,
380                    port = %port,
381                    "gssenc_request_server_error (suppressed per CVE-2024-10977)"
382                );
383                Ok(GssEncNegotiationResult::ServerError)
384            }
385            other => Err(PgError::Connection(format!(
386                "Unexpected response to GSSENCRequest: 0x{:02X} \
387                     (expected 'G'=0x47 or 'N'=0x4E)",
388                other
389            ))),
390        }
391    }
392
393    /// Connect to PostgreSQL server with optional password authentication and auth policy.
394    pub async fn connect_with_password_and_auth(
395        host: &str,
396        port: u16,
397        user: &str,
398        database: &str,
399        password: Option<&str>,
400        auth_settings: AuthSettings,
401    ) -> PgResult<Self> {
402        Self::connect_with_password_and_auth_and_gss(ConnectParams {
403            host,
404            port,
405            user,
406            database,
407            password,
408            auth_settings,
409            gss_token_provider: None,
410            gss_token_provider_ex: None,
411            startup_params: Vec::new(),
412        })
413        .await
414    }
415
416    async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
417        let connect_started = Instant::now();
418        let attempt_backend = plain_connect_attempt_backend();
419        record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
420        let result = tokio::time::timeout(
421            DEFAULT_CONNECT_TIMEOUT,
422            Self::connect_with_password_inner(params),
423        )
424        .await
425        .map_err(|_| {
426            PgError::Connection(format!(
427                "Connection timeout after {:?} (TCP connect + handshake)",
428                DEFAULT_CONNECT_TIMEOUT
429            ))
430        })?;
431        let backend = result
432            .as_ref()
433            .map(|conn| connect_backend_for_stream(&conn.stream))
434            .unwrap_or(attempt_backend);
435        record_connect_result(
436            CONNECT_TRANSPORT_PLAIN,
437            backend,
438            &result,
439            connect_started.elapsed(),
440        );
441        result
442    }
443
444    /// Inner connection logic without timeout wrapper.
445    async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
446        let ConnectParams {
447            host,
448            port,
449            user,
450            database,
451            password,
452            auth_settings,
453            gss_token_provider,
454            gss_token_provider_ex,
455            startup_params,
456        } = params;
457        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
458        let addr = format!("{}:{}", host, port);
459        let stream = Self::connect_plain_stream(&addr).await?;
460
461        let mut conn = Self {
462            stream,
463            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
464            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
465            sql_buf: BytesMut::with_capacity(512),
466            params_buf: Vec::with_capacity(16), // SQL encoding buffer
467            prepared_statements: HashMap::new(),
468            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
469            column_info_cache: HashMap::new(),
470            process_id: 0,
471            secret_key: 0,
472            notifications: VecDeque::new(),
473            replication_stream_active: false,
474            replication_mode_enabled,
475            last_replication_wal_end: None,
476            io_desynced: false,
477            pending_statement_closes: Vec::new(),
478            draining_statement_closes: false,
479        };
480
481        conn.send(FrontendMessage::Startup {
482            user: user.to_string(),
483            database: database.to_string(),
484            startup_params,
485        })
486        .await?;
487
488        conn.handle_startup(
489            user,
490            password,
491            auth_settings,
492            gss_token_provider,
493            gss_token_provider_ex,
494        )
495        .await?;
496
497        Ok(conn)
498    }
499
500    async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
501        let tcp_stream = TcpStream::connect(addr).await?;
502        tcp_stream.set_nodelay(true)?;
503
504        #[cfg(all(target_os = "linux", feature = "io_uring"))]
505        {
506            if should_try_uring_plain() {
507                match super::super::uring::UringTcpStream::from_tokio(tcp_stream) {
508                    Ok(uring_stream) => {
509                        tracing::info!(
510                            addr = %addr,
511                            "qail-pg: using io_uring plain TCP transport"
512                        );
513                        return Ok(PgStream::Uring(uring_stream));
514                    }
515                    Err(e) => {
516                        tracing::warn!(
517                            addr = %addr,
518                            error = %e,
519                            "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
520                        );
521                        let fallback = TcpStream::connect(addr).await?;
522                        fallback.set_nodelay(true)?;
523                        return Ok(PgStream::Tcp(fallback));
524                    }
525                }
526            }
527        }
528
529        Ok(PgStream::Tcp(tcp_stream))
530    }
531
532    /// Connect to PostgreSQL server with TLS encryption.
533    /// Includes a default 10-second timeout covering TCP connect + TLS + handshake.
534    pub async fn connect_tls(
535        host: &str,
536        port: u16,
537        user: &str,
538        database: &str,
539        password: Option<&str>,
540    ) -> PgResult<Self> {
541        Self::connect_tls_with_auth(
542            host,
543            port,
544            user,
545            database,
546            password,
547            AuthSettings::default(),
548            None,
549        )
550        .await
551    }
552
553    /// Connect to PostgreSQL over TLS with explicit auth policy and optional custom CA bundle.
554    pub async fn connect_tls_with_auth(
555        host: &str,
556        port: u16,
557        user: &str,
558        database: &str,
559        password: Option<&str>,
560        auth_settings: AuthSettings,
561        ca_cert_pem: Option<&[u8]>,
562    ) -> PgResult<Self> {
563        Self::connect_tls_with_auth_and_gss(
564            ConnectParams {
565                host,
566                port,
567                user,
568                database,
569                password,
570                auth_settings,
571                gss_token_provider: None,
572                gss_token_provider_ex: None,
573                startup_params: Vec::new(),
574            },
575            ca_cert_pem,
576        )
577        .await
578    }
579
580    async fn connect_tls_with_auth_and_gss(
581        params: ConnectParams<'_>,
582        ca_cert_pem: Option<&[u8]>,
583    ) -> PgResult<Self> {
584        let connect_started = Instant::now();
585        record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
586        let result = tokio::time::timeout(
587            DEFAULT_CONNECT_TIMEOUT,
588            Self::connect_tls_inner(params, ca_cert_pem),
589        )
590        .await
591        .map_err(|_| {
592            PgError::Connection(format!(
593                "TLS connection timeout after {:?}",
594                DEFAULT_CONNECT_TIMEOUT
595            ))
596        })?;
597        record_connect_result(
598            CONNECT_TRANSPORT_TLS,
599            CONNECT_BACKEND_TOKIO,
600            &result,
601            connect_started.elapsed(),
602        );
603        result
604    }
605
606    /// Inner TLS connection logic without timeout wrapper.
607    async fn connect_tls_inner(
608        params: ConnectParams<'_>,
609        ca_cert_pem: Option<&[u8]>,
610    ) -> PgResult<Self> {
611        let ConnectParams {
612            host,
613            port,
614            user,
615            database,
616            password,
617            auth_settings,
618            gss_token_provider,
619            gss_token_provider_ex,
620            startup_params,
621        } = params;
622        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
623        use tokio::io::AsyncReadExt;
624        use tokio_rustls::TlsConnector;
625        use tokio_rustls::rustls::ClientConfig;
626        use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
627
628        let addr = format!("{}:{}", host, port);
629        let mut tcp_stream = TcpStream::connect(&addr).await?;
630
631        // Send SSLRequest
632        tcp_stream.write_all(&SSL_REQUEST).await?;
633
634        // Read response
635        let mut response = [0u8; 1];
636        tcp_stream.read_exact(&mut response).await?;
637
638        if response[0] != b'S' {
639            return Err(PgError::Connection(
640                "Server does not support TLS".to_string(),
641            ));
642        }
643
644        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
645
646        if let Some(ca_pem) = ca_cert_pem {
647            let certs = CertificateDer::pem_slice_iter(ca_pem)
648                .collect::<Result<Vec<_>, _>>()
649                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
650            if certs.is_empty() {
651                return Err(PgError::Connection(
652                    "No CA certificates found in provided PEM".to_string(),
653                ));
654            }
655            for cert in certs {
656                let _ = root_cert_store.add(cert);
657            }
658        } else {
659            let certs = rustls_native_certs::load_native_certs();
660            for cert in certs.certs {
661                let _ = root_cert_store.add(cert);
662            }
663        }
664
665        let config = ClientConfig::builder()
666            .with_root_certificates(root_cert_store)
667            .with_no_client_auth();
668
669        let connector = TlsConnector::from(Arc::new(config));
670        let server_name = ServerName::try_from(host.to_string())
671            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
672
673        let tls_stream = connector
674            .connect(server_name, tcp_stream)
675            .await
676            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
677
678        let mut conn = Self {
679            stream: PgStream::Tls(Box::new(tls_stream)),
680            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
681            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
682            sql_buf: BytesMut::with_capacity(512),
683            params_buf: Vec::with_capacity(16),
684            prepared_statements: HashMap::new(),
685            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
686            column_info_cache: HashMap::new(),
687            process_id: 0,
688            secret_key: 0,
689            notifications: VecDeque::new(),
690            replication_stream_active: false,
691            replication_mode_enabled,
692            last_replication_wal_end: None,
693            io_desynced: false,
694            pending_statement_closes: Vec::new(),
695            draining_statement_closes: false,
696        };
697
698        conn.send(FrontendMessage::Startup {
699            user: user.to_string(),
700            database: database.to_string(),
701            startup_params,
702        })
703        .await?;
704
705        conn.handle_startup(
706            user,
707            password,
708            auth_settings,
709            gss_token_provider,
710            gss_token_provider_ex,
711        )
712        .await?;
713
714        Ok(conn)
715    }
716
717    /// Connect with mutual TLS (client certificate authentication).
718    /// # Arguments
719    /// * `host` - PostgreSQL server hostname
720    /// * `port` - PostgreSQL server port
721    /// * `user` - Database user
722    /// * `database` - Database name
723    /// * `config` - TLS configuration with client cert/key
724    /// # Example
725    /// ```ignore
726    /// let config = TlsConfig {
727    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
728    ///     client_key_pem: include_bytes!("client.key").to_vec(),
729    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
730    /// };
731    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
732    /// ```
733    pub async fn connect_mtls(
734        host: &str,
735        port: u16,
736        user: &str,
737        database: &str,
738        config: TlsConfig,
739    ) -> PgResult<Self> {
740        Self::connect_mtls_with_password_and_auth(
741            host,
742            port,
743            user,
744            database,
745            None,
746            config,
747            AuthSettings::default(),
748        )
749        .await
750    }
751
752    /// Connect with mutual TLS and optional password fallback.
753    pub async fn connect_mtls_with_password_and_auth(
754        host: &str,
755        port: u16,
756        user: &str,
757        database: &str,
758        password: Option<&str>,
759        config: TlsConfig,
760        auth_settings: AuthSettings,
761    ) -> PgResult<Self> {
762        Self::connect_mtls_with_password_and_auth_and_gss(
763            ConnectParams {
764                host,
765                port,
766                user,
767                database,
768                password,
769                auth_settings,
770                gss_token_provider: None,
771                gss_token_provider_ex: None,
772                startup_params: Vec::new(),
773            },
774            config,
775        )
776        .await
777    }
778
779    async fn connect_mtls_with_password_and_auth_and_gss(
780        params: ConnectParams<'_>,
781        config: TlsConfig,
782    ) -> PgResult<Self> {
783        let connect_started = Instant::now();
784        record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
785        let result = tokio::time::timeout(
786            DEFAULT_CONNECT_TIMEOUT,
787            Self::connect_mtls_inner(params, config),
788        )
789        .await
790        .map_err(|_| {
791            PgError::Connection(format!(
792                "mTLS connection timeout after {:?}",
793                DEFAULT_CONNECT_TIMEOUT
794            ))
795        })?;
796        record_connect_result(
797            CONNECT_TRANSPORT_MTLS,
798            CONNECT_BACKEND_TOKIO,
799            &result,
800            connect_started.elapsed(),
801        );
802        result
803    }
804
805    /// Inner mTLS connection logic without timeout wrapper.
806    async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
807        let ConnectParams {
808            host,
809            port,
810            user,
811            database,
812            password,
813            auth_settings,
814            gss_token_provider,
815            gss_token_provider_ex,
816            startup_params,
817        } = params;
818        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
819        use tokio::io::AsyncReadExt;
820        use tokio_rustls::TlsConnector;
821        use tokio_rustls::rustls::{
822            ClientConfig,
823            pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
824        };
825
826        let addr = format!("{}:{}", host, port);
827        let mut tcp_stream = TcpStream::connect(&addr).await?;
828
829        // Send SSLRequest
830        tcp_stream.write_all(&SSL_REQUEST).await?;
831
832        // Read response
833        let mut response = [0u8; 1];
834        tcp_stream.read_exact(&mut response).await?;
835
836        if response[0] != b'S' {
837            return Err(PgError::Connection(
838                "Server does not support TLS".to_string(),
839            ));
840        }
841
842        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
843
844        if let Some(ca_pem) = &config.ca_cert_pem {
845            let certs = CertificateDer::pem_slice_iter(ca_pem)
846                .collect::<Result<Vec<_>, _>>()
847                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
848            if certs.is_empty() {
849                return Err(PgError::Connection(
850                    "No CA certificates found in provided PEM".to_string(),
851                ));
852            }
853            for cert in certs {
854                let _ = root_cert_store.add(cert);
855            }
856        } else {
857            // Use system certs
858            let certs = rustls_native_certs::load_native_certs();
859            for cert in certs.certs {
860                let _ = root_cert_store.add(cert);
861            }
862        }
863
864        let client_certs: Vec<CertificateDer<'static>> =
865            CertificateDer::pem_slice_iter(&config.client_cert_pem)
866                .collect::<Result<Vec<_>, _>>()
867                .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
868        if client_certs.is_empty() {
869            return Err(PgError::Connection(
870                "No client certificates found in PEM".to_string(),
871            ));
872        }
873
874        let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
875            .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
876
877        let tls_config = ClientConfig::builder()
878            .with_root_certificates(root_cert_store)
879            .with_client_auth_cert(client_certs, client_key)
880            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
881
882        let connector = TlsConnector::from(Arc::new(tls_config));
883        let server_name = ServerName::try_from(host.to_string())
884            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
885
886        let tls_stream = connector
887            .connect(server_name, tcp_stream)
888            .await
889            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
890
891        let mut conn = Self {
892            stream: PgStream::Tls(Box::new(tls_stream)),
893            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
894            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
895            sql_buf: BytesMut::with_capacity(512),
896            params_buf: Vec::with_capacity(16),
897            prepared_statements: HashMap::new(),
898            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
899            column_info_cache: HashMap::new(),
900            process_id: 0,
901            secret_key: 0,
902            notifications: VecDeque::new(),
903            replication_stream_active: false,
904            replication_mode_enabled,
905            last_replication_wal_end: None,
906            io_desynced: false,
907            pending_statement_closes: Vec::new(),
908            draining_statement_closes: false,
909        };
910
911        conn.send(FrontendMessage::Startup {
912            user: user.to_string(),
913            database: database.to_string(),
914            startup_params,
915        })
916        .await?;
917
918        conn.handle_startup(
919            user,
920            password,
921            auth_settings,
922            gss_token_provider,
923            gss_token_provider_ex,
924        )
925        .await?;
926
927        Ok(conn)
928    }
929
930    /// Connect to PostgreSQL server via Unix domain socket.
931    #[cfg(unix)]
932    pub async fn connect_unix(
933        socket_path: &str,
934        user: &str,
935        database: &str,
936        password: Option<&str>,
937    ) -> PgResult<Self> {
938        use tokio::net::UnixStream;
939
940        let unix_stream = UnixStream::connect(socket_path).await?;
941
942        let mut conn = Self {
943            stream: PgStream::Unix(unix_stream),
944            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
945            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
946            sql_buf: BytesMut::with_capacity(512),
947            params_buf: Vec::with_capacity(16),
948            prepared_statements: HashMap::new(),
949            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
950            column_info_cache: HashMap::new(),
951            process_id: 0,
952            secret_key: 0,
953            notifications: VecDeque::new(),
954            replication_stream_active: false,
955            replication_mode_enabled: false,
956            last_replication_wal_end: None,
957            io_desynced: false,
958            pending_statement_closes: Vec::new(),
959            draining_statement_closes: false,
960        };
961
962        conn.send(FrontendMessage::Startup {
963            user: user.to_string(),
964            database: database.to_string(),
965            startup_params: Vec::new(),
966        })
967        .await?;
968
969        conn.handle_startup(user, password, AuthSettings::default(), None, None)
970            .await?;
971
972        Ok(conn)
973    }
974}