Skip to main content

qail_pg/driver/connection/
connect.rs

1//! Connection establishment — connect_*, TLS, mTLS, Unix socket.
2
3#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6    connect_backend_for_stream, connect_error_kind, plain_connect_attempt_backend,
7    record_connect_attempt, record_connect_result,
8};
9#[cfg(all(target_os = "linux", feature = "io_uring"))]
10use super::types::CONNECT_BACKEND_IO_URING;
11use super::types::{
12    BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
13    CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
14    GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
15    StatementCache, TlsConfig, has_logical_replication_startup_mode,
16};
17use crate::driver::stream::PgStream;
18use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
19use crate::protocol::PROTOCOL_VERSION_3_0;
20use crate::protocol::wire::FrontendMessage;
21use bytes::BytesMut;
22use std::collections::{HashMap, VecDeque};
23use std::sync::Arc;
24use std::time::Instant;
25use tokio::io::AsyncWriteExt;
26use tokio::net::TcpStream;
27
/// Packs protocol major version 3 together with `minor` into the 32-bit
/// version code sent in the StartupMessage (`major << 16 | minor`).
#[inline]
fn protocol_version_from_minor(minor: u16) -> i32 {
    const MAJOR: i32 = 3;
    let minor = i32::from(minor);
    (MAJOR << 16) | minor
}
32
33fn is_explicit_protocol_version_rejection(err: &PgError) -> bool {
34    let msg = match err {
35        PgError::Connection(msg) | PgError::Protocol(msg) | PgError::Auth(msg) => msg,
36        PgError::Query(msg) => msg,
37        PgError::QueryServer(server) => &server.message,
38        _ => return false,
39    };
40
41    let lower = msg.to_ascii_lowercase();
42    lower.contains("unsupported frontend protocol")
43        || lower.contains("frontend protocol") && lower.contains("unsupported")
44        || lower.contains("protocol version") && lower.contains("not support")
45}
46
47impl PgConnection {
48    /// Connect to PostgreSQL server without authentication (trust mode).
49    ///
50    /// # Arguments
51    ///
52    /// * `host` — PostgreSQL server hostname or IP.
53    /// * `port` — TCP port (typically 5432).
54    /// * `user` — PostgreSQL role name.
55    /// * `database` — Target database name.
56    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
57        Self::connect_with_password(host, port, user, database, None).await
58    }
59
60    /// Connect to PostgreSQL server with optional password authentication.
61    /// Includes a default 10-second timeout covering TCP connect + handshake.
62    ///
63    /// Startup requests protocol 3.2 by default and performs a one-shot retry
64    /// with protocol 3.0 only when startup fails due to explicit
65    /// protocol-version rejection from the server.
66    pub async fn connect_with_password(
67        host: &str,
68        port: u16,
69        user: &str,
70        database: &str,
71        password: Option<&str>,
72    ) -> PgResult<Self> {
73        Self::connect_with_password_and_auth(
74            host,
75            port,
76            user,
77            database,
78            password,
79            AuthSettings::default(),
80        )
81        .await
82    }
83
84    /// Connect to PostgreSQL with explicit enterprise options.
85    ///
86    /// Negotiation preface order follows libpq:
87    ///   1. If gss_enc_mode != Disable → try GSSENCRequest on fresh TCP
88    ///   2. If GSSENC rejected/unavailable and tls_mode != Disable → try SSLRequest
89    ///   3. If both rejected/unavailable → plain StartupMessage
90    ///
91    /// The StartupMessage protocol version behavior is the same as
92    /// `connect_with_password`: request protocol 3.2 first, then retry once
93    /// with 3.0 only on explicit protocol-version rejection.
94    pub async fn connect_with_options(
95        host: &str,
96        port: u16,
97        user: &str,
98        database: &str,
99        password: Option<&str>,
100        options: ConnectOptions,
101    ) -> PgResult<Self> {
102        let ConnectOptions {
103            tls_mode,
104            gss_enc_mode,
105            tls_ca_cert_pem,
106            mtls,
107            gss_token_provider,
108            gss_token_provider_ex,
109            auth,
110            startup_params,
111        } = options;
112
113        if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
114            return Err(PgError::Connection(
115                "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
116            ));
117        }
118
119        // Enforce gss_enc_mode policy before mTLS early-return.
120        // GSSENC and mTLS are both transport-level encryption; using
121        // both simultaneously is not supported by the PostgreSQL protocol.
122        if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
123            return Err(PgError::Connection(
124                "gssencmode=require is incompatible with mTLS — both provide \
125                 transport encryption; use one or the other"
126                    .to_string(),
127            ));
128        }
129
130        if let Some(mtls_config) = mtls {
131            // gss_enc_mode is Disable or Prefer here (Require rejected above).
132            // mTLS already provides transport encryption; skip GSSENC.
133            return Self::connect_mtls_with_password_and_auth_and_gss(
134                ConnectParams {
135                    host,
136                    port,
137                    user,
138                    database,
139                    password,
140                    auth_settings: auth,
141                    gss_token_provider,
142                    gss_token_provider_ex,
143                    protocol_minor: Self::default_protocol_minor(),
144                    startup_params: startup_params.clone(),
145                },
146                mtls_config,
147            )
148            .await;
149        }
150
151        // ── Phase 1: Try GSSENC if requested ──────────────────────────
152        if gss_enc_mode != GssEncMode::Disable {
153            match Self::try_gssenc_request(host, port).await {
154                Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
155                    let connect_started = Instant::now();
156                    record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
157                    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
158                    {
159                        let default_minor = Self::default_protocol_minor();
160                        let mut result = Self::connect_gssenc_accepted_with_timeout(
161                            tcp_stream,
162                            host,
163                            user,
164                            database,
165                            password,
166                            auth,
167                            gss_token_provider,
168                            gss_token_provider_ex.clone(),
169                            startup_params.clone(),
170                            default_minor,
171                        )
172                        .await;
173                        if let Err(err) = &result {
174                            if default_minor > 0 && is_explicit_protocol_version_rejection(err) {
175                                let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
176                                let retry_stream = match Self::try_gssenc_request(host, port).await
177                                {
178                                    Ok(GssEncNegotiationResult::Accepted(stream)) => stream,
179                                    Ok(GssEncNegotiationResult::Rejected) => {
180                                        return Err(PgError::Connection(
181                                            "Protocol downgrade retry failed: server rejected GSSENCRequest"
182                                                .to_string(),
183                                        ));
184                                    }
185                                    Ok(GssEncNegotiationResult::ServerError) => {
186                                        return Err(PgError::Connection(
187                                            "Protocol downgrade retry failed: server returned error to GSSENCRequest"
188                                                .to_string(),
189                                        ));
190                                    }
191                                    Err(e) => {
192                                        return Err(e);
193                                    }
194                                };
195                                result = Self::connect_gssenc_accepted_with_timeout(
196                                    retry_stream,
197                                    host,
198                                    user,
199                                    database,
200                                    password,
201                                    auth,
202                                    gss_token_provider,
203                                    gss_token_provider_ex,
204                                    startup_params.clone(),
205                                    downgrade_minor,
206                                )
207                                .await;
208                            }
209                        }
210                        record_connect_result(
211                            CONNECT_TRANSPORT_GSSENC,
212                            CONNECT_BACKEND_TOKIO,
213                            &result,
214                            connect_started.elapsed(),
215                        );
216                        return result;
217                    }
218                    #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
219                    {
220                        let _ = tcp_stream;
221                        let err = PgError::Connection(
222                            "Server accepted GSSENCRequest but GSSAPI encryption requires \
223                             feature enterprise-gssapi on Linux"
224                                .to_string(),
225                        );
226                        metrics::histogram!(
227                            "qail_pg_connect_duration_seconds",
228                            "transport" => CONNECT_TRANSPORT_GSSENC,
229                            "backend" => CONNECT_BACKEND_TOKIO,
230                            "outcome" => "error"
231                        )
232                        .record(connect_started.elapsed().as_secs_f64());
233                        metrics::counter!(
234                            "qail_pg_connect_failure_total",
235                            "transport" => CONNECT_TRANSPORT_GSSENC,
236                            "backend" => CONNECT_BACKEND_TOKIO,
237                            "error_kind" => connect_error_kind(&err)
238                        )
239                        .increment(1);
240                        return Err(err);
241                    }
242                }
243                Ok(GssEncNegotiationResult::Rejected)
244                | Ok(GssEncNegotiationResult::ServerError) => {
245                    if gss_enc_mode == GssEncMode::Require {
246                        return Err(PgError::Connection(
247                            "gssencmode=require but server rejected GSSENCRequest".to_string(),
248                        ));
249                    }
250                    // gss_enc_mode == Prefer — fall through to TLS / plain
251                }
252                Err(e) => {
253                    if gss_enc_mode == GssEncMode::Require {
254                        return Err(e);
255                    }
256                    // gss_enc_mode == Prefer — connection error, fall through
257                    tracing::debug!(
258                        host = %host,
259                        port = %port,
260                        error = %e,
261                        "gssenc_prefer_fallthrough"
262                    );
263                }
264            }
265        }
266
267        // ── Phase 2: TLS / plain per sslmode ──────────────────────────
268        match tls_mode {
269            TlsMode::Disable => {
270                Self::connect_with_password_and_auth_and_gss(ConnectParams {
271                    host,
272                    port,
273                    user,
274                    database,
275                    password,
276                    auth_settings: auth,
277                    gss_token_provider,
278                    gss_token_provider_ex,
279                    protocol_minor: Self::default_protocol_minor(),
280                    startup_params: startup_params.clone(),
281                })
282                .await
283            }
284            TlsMode::Require => {
285                Self::connect_tls_with_auth_and_gss(
286                    ConnectParams {
287                        host,
288                        port,
289                        user,
290                        database,
291                        password,
292                        auth_settings: auth,
293                        gss_token_provider,
294                        gss_token_provider_ex,
295                        protocol_minor: Self::default_protocol_minor(),
296                        startup_params: startup_params.clone(),
297                    },
298                    tls_ca_cert_pem.as_deref(),
299                )
300                .await
301            }
302            TlsMode::Prefer => {
303                match Self::connect_tls_with_auth_and_gss(
304                    ConnectParams {
305                        host,
306                        port,
307                        user,
308                        database,
309                        password,
310                        auth_settings: auth,
311                        gss_token_provider,
312                        gss_token_provider_ex: gss_token_provider_ex.clone(),
313                        protocol_minor: Self::default_protocol_minor(),
314                        startup_params: startup_params.clone(),
315                    },
316                    tls_ca_cert_pem.as_deref(),
317                )
318                .await
319                {
320                    Ok(conn) => Ok(conn),
321                    Err(PgError::Connection(msg))
322                        if msg.contains("Server does not support TLS") =>
323                    {
324                        Self::connect_with_password_and_auth_and_gss(ConnectParams {
325                            host,
326                            port,
327                            user,
328                            database,
329                            password,
330                            auth_settings: auth,
331                            gss_token_provider,
332                            gss_token_provider_ex,
333                            protocol_minor: Self::default_protocol_minor(),
334                            startup_params: startup_params.clone(),
335                        })
336                        .await
337                    }
338                    Err(e) => Err(e),
339                }
340            }
341        }
342    }
343
344    /// Attempt GSSAPI session encryption negotiation.
345    ///
346    /// Opens a fresh TCP connection, sends GSSENCRequest (80877104),
347    /// reads exactly one byte (CVE-2021-23222 safe), and returns
348    /// the result.  The entire operation is bounded by
349    /// `DEFAULT_CONNECT_TIMEOUT`.
350    async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
351        tokio::time::timeout(
352            DEFAULT_CONNECT_TIMEOUT,
353            Self::try_gssenc_request_inner(host, port),
354        )
355        .await
356        .map_err(|_| {
357            PgError::Connection(format!(
358                "GSSENCRequest timeout after {:?}",
359                DEFAULT_CONNECT_TIMEOUT
360            ))
361        })?
362    }
363
364    /// Inner GSSENCRequest logic without timeout wrapper.
365    async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
366        use tokio::io::AsyncReadExt;
367
368        let addr = format!("{}:{}", host, port);
369        let mut tcp_stream = TcpStream::connect(&addr).await?;
370        tcp_stream.set_nodelay(true)?;
371
372        // Send the 8-byte GSSENCRequest.
373        tcp_stream.write_all(&GSSENC_REQUEST).await?;
374        tcp_stream.flush().await?;
375
376        // CVE-2021-23222: Read exactly one byte.  The server must
377        // respond with a single 'G' or 'N'.  Any additional bytes
378        // in the buffer indicate a buffer-stuffing attack.
379        let mut response = [0u8; 1];
380        tcp_stream.read_exact(&mut response).await?;
381
382        match response[0] {
383            b'G' => {
384                // CVE-2021-23222 check: verify no extra bytes are buffered.
385                // Use a non-blocking peek to detect leftover data.
386                let mut peek_buf = [0u8; 1];
387                match tcp_stream.try_read(&mut peek_buf) {
388                    Ok(0) => {} // EOF — fine (shouldn't happen yet but harmless)
389                    Ok(_n) => {
390                        // Extra bytes after 'G' — possible buffer-stuffing.
391                        return Err(PgError::Connection(
392                            "Protocol violation: extra bytes after GSSENCRequest 'G' response \
393                             (possible CVE-2021-23222 buffer-stuffing attack)"
394                                .to_string(),
395                        ));
396                    }
397                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
398                        // No extra data — this is the expected path.
399                    }
400                    Err(e) => {
401                        return Err(PgError::Io(e));
402                    }
403                }
404                Ok(GssEncNegotiationResult::Accepted(tcp_stream))
405            }
406            b'N' => Ok(GssEncNegotiationResult::Rejected),
407            b'E' => {
408                // Server sent an ErrorMessage.  Per CVE-2024-10977 we
409                // must NOT display this to users since the server has
410                // not been authenticated.  Log at trace only.
411                tracing::trace!(
412                    host = %host,
413                    port = %port,
414                    "gssenc_request_server_error (suppressed per CVE-2024-10977)"
415                );
416                Ok(GssEncNegotiationResult::ServerError)
417            }
418            other => Err(PgError::Connection(format!(
419                "Unexpected response to GSSENCRequest: 0x{:02X} \
420                     (expected 'G'=0x47 or 'N'=0x4E)",
421                other
422            ))),
423        }
424    }
425
426    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
427    async fn connect_gssenc_accepted_with_timeout(
428        tcp_stream: TcpStream,
429        host: &str,
430        user: &str,
431        database: &str,
432        password: Option<&str>,
433        auth_settings: AuthSettings,
434        gss_token_provider: Option<super::super::GssTokenProvider>,
435        gss_token_provider_ex: Option<super::super::GssTokenProviderEx>,
436        startup_params: Vec<(String, String)>,
437        protocol_minor: u16,
438    ) -> PgResult<Self> {
439        let gssenc_fut = async {
440            let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, host)
441                .await
442                .map_err(PgError::Auth)?;
443            let mut conn = Self {
444                stream: PgStream::GssEnc(gss_stream),
445                buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
446                write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
447                sql_buf: BytesMut::with_capacity(512),
448                params_buf: Vec::with_capacity(16),
449                prepared_statements: HashMap::new(),
450                stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
451                column_info_cache: HashMap::new(),
452                process_id: 0,
453                secret_key: 0,
454                cancel_key_bytes: Vec::new(),
455                requested_protocol_minor: protocol_minor,
456                negotiated_protocol_minor: protocol_minor,
457                notifications: VecDeque::new(),
458                replication_stream_active: false,
459                replication_mode_enabled: has_logical_replication_startup_mode(&startup_params),
460                last_replication_wal_end: None,
461                io_desynced: false,
462                pending_statement_closes: Vec::new(),
463                draining_statement_closes: false,
464            };
465            conn.send(FrontendMessage::Startup {
466                user: user.to_string(),
467                database: database.to_string(),
468                protocol_version: protocol_version_from_minor(protocol_minor),
469                startup_params: startup_params.clone(),
470            })
471            .await?;
472            conn.handle_startup(
473                user,
474                password,
475                auth_settings,
476                gss_token_provider,
477                gss_token_provider_ex,
478            )
479            .await?;
480            Ok(conn)
481        };
482        tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
483            .await
484            .map_err(|_| {
485                PgError::Connection(format!(
486                    "GSSENC connection timeout after {:?} (handshake + auth)",
487                    DEFAULT_CONNECT_TIMEOUT
488                ))
489            })?
490    }
491
492    /// Connect to PostgreSQL server with optional password authentication and auth policy.
493    pub async fn connect_with_password_and_auth(
494        host: &str,
495        port: u16,
496        user: &str,
497        database: &str,
498        password: Option<&str>,
499        auth_settings: AuthSettings,
500    ) -> PgResult<Self> {
501        Self::connect_with_password_and_auth_and_gss(ConnectParams {
502            host,
503            port,
504            user,
505            database,
506            password,
507            auth_settings,
508            gss_token_provider: None,
509            gss_token_provider_ex: None,
510            protocol_minor: Self::default_protocol_minor(),
511            startup_params: Vec::new(),
512        })
513        .await
514    }
515
516    async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
517        let first = Self::connect_with_password_and_auth_and_gss_once(params.clone()).await;
518        if let Err(err) = &first {
519            if params.protocol_minor > 0 && is_explicit_protocol_version_rejection(err) {
520                let mut downgraded = params;
521                downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
522                return Self::connect_with_password_and_auth_and_gss_once(downgraded).await;
523            }
524        }
525        first
526    }
527
528    async fn connect_with_password_and_auth_and_gss_once(
529        params: ConnectParams<'_>,
530    ) -> PgResult<Self> {
531        let connect_started = Instant::now();
532        let attempt_backend = plain_connect_attempt_backend();
533        record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
534        let result = tokio::time::timeout(
535            DEFAULT_CONNECT_TIMEOUT,
536            Self::connect_with_password_inner(params),
537        )
538        .await
539        .map_err(|_| {
540            PgError::Connection(format!(
541                "Connection timeout after {:?} (TCP connect + handshake)",
542                DEFAULT_CONNECT_TIMEOUT
543            ))
544        })?;
545        let backend = result
546            .as_ref()
547            .map(|conn| connect_backend_for_stream(&conn.stream))
548            .unwrap_or(attempt_backend);
549        record_connect_result(
550            CONNECT_TRANSPORT_PLAIN,
551            backend,
552            &result,
553            connect_started.elapsed(),
554        );
555        result
556    }
557
558    /// Inner connection logic without timeout wrapper.
559    async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
560        let ConnectParams {
561            host,
562            port,
563            user,
564            database,
565            password,
566            auth_settings,
567            gss_token_provider,
568            gss_token_provider_ex,
569            protocol_minor,
570            startup_params,
571        } = params;
572        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
573        let addr = format!("{}:{}", host, port);
574        let stream = Self::connect_plain_stream(&addr).await?;
575
576        let mut conn = Self {
577            stream,
578            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
579            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
580            sql_buf: BytesMut::with_capacity(512),
581            params_buf: Vec::with_capacity(16), // SQL encoding buffer
582            prepared_statements: HashMap::new(),
583            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
584            column_info_cache: HashMap::new(),
585            process_id: 0,
586            secret_key: 0,
587            cancel_key_bytes: Vec::new(),
588            requested_protocol_minor: protocol_minor,
589            negotiated_protocol_minor: protocol_minor,
590            notifications: VecDeque::new(),
591            replication_stream_active: false,
592            replication_mode_enabled,
593            last_replication_wal_end: None,
594            io_desynced: false,
595            pending_statement_closes: Vec::new(),
596            draining_statement_closes: false,
597        };
598
599        conn.send(FrontendMessage::Startup {
600            user: user.to_string(),
601            database: database.to_string(),
602            protocol_version: protocol_version_from_minor(protocol_minor),
603            startup_params,
604        })
605        .await?;
606
607        conn.handle_startup(
608            user,
609            password,
610            auth_settings,
611            gss_token_provider,
612            gss_token_provider_ex,
613        )
614        .await?;
615
616        Ok(conn)
617    }
618
619    async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
620        let tcp_stream = TcpStream::connect(addr).await?;
621        tcp_stream.set_nodelay(true)?;
622
623        #[cfg(all(target_os = "linux", feature = "io_uring"))]
624        {
625            if should_try_uring_plain() {
626                match super::super::uring::UringTcpStream::from_tokio(tcp_stream) {
627                    Ok(uring_stream) => {
628                        tracing::info!(
629                            addr = %addr,
630                            "qail-pg: using io_uring plain TCP transport"
631                        );
632                        return Ok(PgStream::Uring(uring_stream));
633                    }
634                    Err(e) => {
635                        tracing::warn!(
636                            addr = %addr,
637                            error = %e,
638                            "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
639                        );
640                        let fallback = TcpStream::connect(addr).await?;
641                        fallback.set_nodelay(true)?;
642                        return Ok(PgStream::Tcp(fallback));
643                    }
644                }
645            }
646        }
647
648        Ok(PgStream::Tcp(tcp_stream))
649    }
650
651    /// Connect to PostgreSQL server with TLS encryption.
652    /// Includes a default 10-second timeout covering TCP connect + TLS + handshake.
653    pub async fn connect_tls(
654        host: &str,
655        port: u16,
656        user: &str,
657        database: &str,
658        password: Option<&str>,
659    ) -> PgResult<Self> {
660        Self::connect_tls_with_auth(
661            host,
662            port,
663            user,
664            database,
665            password,
666            AuthSettings::default(),
667            None,
668        )
669        .await
670    }
671
672    /// Connect to PostgreSQL over TLS with explicit auth policy and optional custom CA bundle.
673    pub async fn connect_tls_with_auth(
674        host: &str,
675        port: u16,
676        user: &str,
677        database: &str,
678        password: Option<&str>,
679        auth_settings: AuthSettings,
680        ca_cert_pem: Option<&[u8]>,
681    ) -> PgResult<Self> {
682        Self::connect_tls_with_auth_and_gss(
683            ConnectParams {
684                host,
685                port,
686                user,
687                database,
688                password,
689                auth_settings,
690                gss_token_provider: None,
691                gss_token_provider_ex: None,
692                protocol_minor: Self::default_protocol_minor(),
693                startup_params: Vec::new(),
694            },
695            ca_cert_pem,
696        )
697        .await
698    }
699
700    async fn connect_tls_with_auth_and_gss(
701        params: ConnectParams<'_>,
702        ca_cert_pem: Option<&[u8]>,
703    ) -> PgResult<Self> {
704        let first = Self::connect_tls_with_auth_and_gss_once(params.clone(), ca_cert_pem).await;
705        if let Err(err) = &first {
706            if params.protocol_minor > 0 && is_explicit_protocol_version_rejection(err) {
707                let mut downgraded = params;
708                downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
709                return Self::connect_tls_with_auth_and_gss_once(downgraded, ca_cert_pem).await;
710            }
711        }
712        first
713    }
714
715    async fn connect_tls_with_auth_and_gss_once(
716        params: ConnectParams<'_>,
717        ca_cert_pem: Option<&[u8]>,
718    ) -> PgResult<Self> {
719        let connect_started = Instant::now();
720        record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
721        let result = tokio::time::timeout(
722            DEFAULT_CONNECT_TIMEOUT,
723            Self::connect_tls_inner(params, ca_cert_pem),
724        )
725        .await
726        .map_err(|_| {
727            PgError::Connection(format!(
728                "TLS connection timeout after {:?}",
729                DEFAULT_CONNECT_TIMEOUT
730            ))
731        })?;
732        record_connect_result(
733            CONNECT_TRANSPORT_TLS,
734            CONNECT_BACKEND_TOKIO,
735            &result,
736            connect_started.elapsed(),
737        );
738        result
739    }
740
741    /// Inner TLS connection logic without timeout wrapper.
742    async fn connect_tls_inner(
743        params: ConnectParams<'_>,
744        ca_cert_pem: Option<&[u8]>,
745    ) -> PgResult<Self> {
746        let ConnectParams {
747            host,
748            port,
749            user,
750            database,
751            password,
752            auth_settings,
753            gss_token_provider,
754            gss_token_provider_ex,
755            protocol_minor,
756            startup_params,
757        } = params;
758        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
759        use tokio::io::AsyncReadExt;
760        use tokio_rustls::TlsConnector;
761        use tokio_rustls::rustls::ClientConfig;
762        use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
763
764        let addr = format!("{}:{}", host, port);
765        let mut tcp_stream = TcpStream::connect(&addr).await?;
766
767        // Send SSLRequest
768        tcp_stream.write_all(&SSL_REQUEST).await?;
769
770        // Read response
771        let mut response = [0u8; 1];
772        tcp_stream.read_exact(&mut response).await?;
773
774        if response[0] != b'S' {
775            return Err(PgError::Connection(
776                "Server does not support TLS".to_string(),
777            ));
778        }
779
780        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
781
782        if let Some(ca_pem) = ca_cert_pem {
783            let certs = CertificateDer::pem_slice_iter(ca_pem)
784                .collect::<Result<Vec<_>, _>>()
785                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
786            if certs.is_empty() {
787                return Err(PgError::Connection(
788                    "No CA certificates found in provided PEM".to_string(),
789                ));
790            }
791            for cert in certs {
792                let _ = root_cert_store.add(cert);
793            }
794        } else {
795            let certs = rustls_native_certs::load_native_certs();
796            for cert in certs.certs {
797                let _ = root_cert_store.add(cert);
798            }
799        }
800
801        let config = ClientConfig::builder()
802            .with_root_certificates(root_cert_store)
803            .with_no_client_auth();
804
805        let connector = TlsConnector::from(Arc::new(config));
806        let server_name = ServerName::try_from(host.to_string())
807            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
808
809        let tls_stream = connector
810            .connect(server_name, tcp_stream)
811            .await
812            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
813
814        let mut conn = Self {
815            stream: PgStream::Tls(Box::new(tls_stream)),
816            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
817            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
818            sql_buf: BytesMut::with_capacity(512),
819            params_buf: Vec::with_capacity(16),
820            prepared_statements: HashMap::new(),
821            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
822            column_info_cache: HashMap::new(),
823            process_id: 0,
824            secret_key: 0,
825            cancel_key_bytes: Vec::new(),
826            requested_protocol_minor: protocol_minor,
827            negotiated_protocol_minor: protocol_minor,
828            notifications: VecDeque::new(),
829            replication_stream_active: false,
830            replication_mode_enabled,
831            last_replication_wal_end: None,
832            io_desynced: false,
833            pending_statement_closes: Vec::new(),
834            draining_statement_closes: false,
835        };
836
837        conn.send(FrontendMessage::Startup {
838            user: user.to_string(),
839            database: database.to_string(),
840            protocol_version: protocol_version_from_minor(protocol_minor),
841            startup_params,
842        })
843        .await?;
844
845        conn.handle_startup(
846            user,
847            password,
848            auth_settings,
849            gss_token_provider,
850            gss_token_provider_ex,
851        )
852        .await?;
853
854        Ok(conn)
855    }
856
857    /// Connect with mutual TLS (client certificate authentication).
858    /// # Arguments
859    /// * `host` - PostgreSQL server hostname
860    /// * `port` - PostgreSQL server port
861    /// * `user` - Database user
862    /// * `database` - Database name
863    /// * `config` - TLS configuration with client cert/key
864    /// # Example
865    /// ```ignore
866    /// let config = TlsConfig {
867    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
868    ///     client_key_pem: include_bytes!("client.key").to_vec(),
869    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
870    /// };
871    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
872    /// ```
873    pub async fn connect_mtls(
874        host: &str,
875        port: u16,
876        user: &str,
877        database: &str,
878        config: TlsConfig,
879    ) -> PgResult<Self> {
880        Self::connect_mtls_with_password_and_auth(
881            host,
882            port,
883            user,
884            database,
885            None,
886            config,
887            AuthSettings::default(),
888        )
889        .await
890    }
891
892    /// Connect with mutual TLS and optional password fallback.
893    pub async fn connect_mtls_with_password_and_auth(
894        host: &str,
895        port: u16,
896        user: &str,
897        database: &str,
898        password: Option<&str>,
899        config: TlsConfig,
900        auth_settings: AuthSettings,
901    ) -> PgResult<Self> {
902        Self::connect_mtls_with_password_and_auth_and_gss(
903            ConnectParams {
904                host,
905                port,
906                user,
907                database,
908                password,
909                auth_settings,
910                gss_token_provider: None,
911                gss_token_provider_ex: None,
912                protocol_minor: Self::default_protocol_minor(),
913                startup_params: Vec::new(),
914            },
915            config,
916        )
917        .await
918    }
919
920    async fn connect_mtls_with_password_and_auth_and_gss(
921        params: ConnectParams<'_>,
922        config: TlsConfig,
923    ) -> PgResult<Self> {
924        let first =
925            Self::connect_mtls_with_password_and_auth_and_gss_once(params.clone(), config.clone())
926                .await;
927        if let Err(err) = &first {
928            if params.protocol_minor > 0 && is_explicit_protocol_version_rejection(err) {
929                let mut downgraded = params;
930                downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
931                return Self::connect_mtls_with_password_and_auth_and_gss_once(downgraded, config)
932                    .await;
933            }
934        }
935        first
936    }
937
938    async fn connect_mtls_with_password_and_auth_and_gss_once(
939        params: ConnectParams<'_>,
940        config: TlsConfig,
941    ) -> PgResult<Self> {
942        let connect_started = Instant::now();
943        record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
944        let result = tokio::time::timeout(
945            DEFAULT_CONNECT_TIMEOUT,
946            Self::connect_mtls_inner(params, config),
947        )
948        .await
949        .map_err(|_| {
950            PgError::Connection(format!(
951                "mTLS connection timeout after {:?}",
952                DEFAULT_CONNECT_TIMEOUT
953            ))
954        })?;
955        record_connect_result(
956            CONNECT_TRANSPORT_MTLS,
957            CONNECT_BACKEND_TOKIO,
958            &result,
959            connect_started.elapsed(),
960        );
961        result
962    }
963
964    /// Inner mTLS connection logic without timeout wrapper.
965    async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
966        let ConnectParams {
967            host,
968            port,
969            user,
970            database,
971            password,
972            auth_settings,
973            gss_token_provider,
974            gss_token_provider_ex,
975            protocol_minor,
976            startup_params,
977        } = params;
978        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
979        use tokio::io::AsyncReadExt;
980        use tokio_rustls::TlsConnector;
981        use tokio_rustls::rustls::{
982            ClientConfig,
983            pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
984        };
985
986        let addr = format!("{}:{}", host, port);
987        let mut tcp_stream = TcpStream::connect(&addr).await?;
988
989        // Send SSLRequest
990        tcp_stream.write_all(&SSL_REQUEST).await?;
991
992        // Read response
993        let mut response = [0u8; 1];
994        tcp_stream.read_exact(&mut response).await?;
995
996        if response[0] != b'S' {
997            return Err(PgError::Connection(
998                "Server does not support TLS".to_string(),
999            ));
1000        }
1001
1002        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
1003
1004        if let Some(ca_pem) = &config.ca_cert_pem {
1005            let certs = CertificateDer::pem_slice_iter(ca_pem)
1006                .collect::<Result<Vec<_>, _>>()
1007                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
1008            if certs.is_empty() {
1009                return Err(PgError::Connection(
1010                    "No CA certificates found in provided PEM".to_string(),
1011                ));
1012            }
1013            for cert in certs {
1014                let _ = root_cert_store.add(cert);
1015            }
1016        } else {
1017            // Use system certs
1018            let certs = rustls_native_certs::load_native_certs();
1019            for cert in certs.certs {
1020                let _ = root_cert_store.add(cert);
1021            }
1022        }
1023
1024        let client_certs: Vec<CertificateDer<'static>> =
1025            CertificateDer::pem_slice_iter(&config.client_cert_pem)
1026                .collect::<Result<Vec<_>, _>>()
1027                .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
1028        if client_certs.is_empty() {
1029            return Err(PgError::Connection(
1030                "No client certificates found in PEM".to_string(),
1031            ));
1032        }
1033
1034        let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
1035            .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
1036
1037        let tls_config = ClientConfig::builder()
1038            .with_root_certificates(root_cert_store)
1039            .with_client_auth_cert(client_certs, client_key)
1040            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
1041
1042        let connector = TlsConnector::from(Arc::new(tls_config));
1043        let server_name = ServerName::try_from(host.to_string())
1044            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
1045
1046        let tls_stream = connector
1047            .connect(server_name, tcp_stream)
1048            .await
1049            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
1050
1051        let mut conn = Self {
1052            stream: PgStream::Tls(Box::new(tls_stream)),
1053            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1054            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1055            sql_buf: BytesMut::with_capacity(512),
1056            params_buf: Vec::with_capacity(16),
1057            prepared_statements: HashMap::new(),
1058            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1059            column_info_cache: HashMap::new(),
1060            process_id: 0,
1061            secret_key: 0,
1062            cancel_key_bytes: Vec::new(),
1063            requested_protocol_minor: protocol_minor,
1064            negotiated_protocol_minor: protocol_minor,
1065            notifications: VecDeque::new(),
1066            replication_stream_active: false,
1067            replication_mode_enabled,
1068            last_replication_wal_end: None,
1069            io_desynced: false,
1070            pending_statement_closes: Vec::new(),
1071            draining_statement_closes: false,
1072        };
1073
1074        conn.send(FrontendMessage::Startup {
1075            user: user.to_string(),
1076            database: database.to_string(),
1077            protocol_version: protocol_version_from_minor(protocol_minor),
1078            startup_params,
1079        })
1080        .await?;
1081
1082        conn.handle_startup(
1083            user,
1084            password,
1085            auth_settings,
1086            gss_token_provider,
1087            gss_token_provider_ex,
1088        )
1089        .await?;
1090
1091        Ok(conn)
1092    }
1093
1094    /// Connect to PostgreSQL server via Unix domain socket.
1095    #[cfg(unix)]
1096    pub async fn connect_unix(
1097        socket_path: &str,
1098        user: &str,
1099        database: &str,
1100        password: Option<&str>,
1101    ) -> PgResult<Self> {
1102        let default_minor = Self::default_protocol_minor();
1103        let first =
1104            Self::connect_unix_with_protocol(socket_path, user, database, password, default_minor)
1105                .await;
1106        if let Err(err) = &first {
1107            if default_minor > 0 && is_explicit_protocol_version_rejection(err) {
1108                let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
1109                return Self::connect_unix_with_protocol(
1110                    socket_path,
1111                    user,
1112                    database,
1113                    password,
1114                    downgrade_minor,
1115                )
1116                .await;
1117            }
1118        }
1119        first
1120    }
1121
1122    #[cfg(unix)]
1123    async fn connect_unix_with_protocol(
1124        socket_path: &str,
1125        user: &str,
1126        database: &str,
1127        password: Option<&str>,
1128        protocol_minor: u16,
1129    ) -> PgResult<Self> {
1130        use tokio::net::UnixStream;
1131
1132        let unix_stream = UnixStream::connect(socket_path).await?;
1133
1134        let mut conn = Self {
1135            stream: PgStream::Unix(unix_stream),
1136            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1137            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1138            sql_buf: BytesMut::with_capacity(512),
1139            params_buf: Vec::with_capacity(16),
1140            prepared_statements: HashMap::new(),
1141            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1142            column_info_cache: HashMap::new(),
1143            process_id: 0,
1144            secret_key: 0,
1145            cancel_key_bytes: Vec::new(),
1146            requested_protocol_minor: protocol_minor,
1147            negotiated_protocol_minor: protocol_minor,
1148            notifications: VecDeque::new(),
1149            replication_stream_active: false,
1150            replication_mode_enabled: false,
1151            last_replication_wal_end: None,
1152            io_desynced: false,
1153            pending_statement_closes: Vec::new(),
1154            draining_statement_closes: false,
1155        };
1156
1157        conn.send(FrontendMessage::Startup {
1158            user: user.to_string(),
1159            database: database.to_string(),
1160            protocol_version: protocol_version_from_minor(protocol_minor),
1161            startup_params: Vec::new(),
1162        })
1163        .await?;
1164
1165        conn.handle_startup(user, password, AuthSettings::default(), None, None)
1166            .await?;
1167
1168        Ok(conn)
1169    }
1170}
1171
#[cfg(test)]
mod tests {
    use super::{is_explicit_protocol_version_rejection, protocol_version_from_minor};
    use crate::driver::PgError;

    #[test]
    fn protocol_version_from_minor_encodes_major_3() {
        // Major version 3 occupies the high 16 bits: 3 << 16 == 196_608.
        let cases: [(u16, i32); 2] = [(2, 196_610), (0, 196_608)];
        for (minor, expected) in cases {
            assert_eq!(protocol_version_from_minor(minor), expected);
        }
    }

    #[test]
    fn explicit_protocol_rejection_detection_is_case_insensitive() {
        let rejections = [
            PgError::Connection("Unsupported frontend protocol 3.2".to_string()),
            PgError::Protocol("server: Protocol VERSION not supported".to_string()),
        ];
        for err in &rejections {
            assert!(is_explicit_protocol_version_rejection(err));
        }
    }

    #[test]
    fn explicit_protocol_rejection_does_not_match_unrelated_errors() {
        let unrelated = PgError::Connection("connection reset by peer".to_string());
        assert!(!is_explicit_protocol_version_rejection(&unrelated));
    }
}