Skip to main content

qail_pg/driver/connection/
connect.rs

1//! Connection establishment — connect_*, TLS, mTLS, Unix socket.
2
3#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6    connect_backend_for_stream, plain_connect_attempt_backend, record_connect_attempt,
7    record_connect_result,
8};
9use super::types::{
10    BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
11    CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
12    GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
13    StatementCache, TlsConfig, has_logical_replication_startup_mode,
14};
15use crate::driver::stream::PgStream;
16use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
17use crate::protocol::PROTOCOL_VERSION_3_0;
18use crate::protocol::wire::FrontendMessage;
19use bytes::BytesMut;
20use std::collections::{HashMap, VecDeque};
21use std::sync::Arc;
22use std::time::Instant;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
/// Build the 32-bit startup protocol version for major version 3 and the given minor.
///
/// The major number occupies the upper 16 bits and `minor` the lower 16 bits,
/// so minor `0` yields `0x0003_0000` and minor `2` yields `0x0003_0002`.
#[inline]
fn protocol_version_from_minor(minor: u16) -> i32 {
    const MAJOR_3_SHIFTED: i32 = 3 << 16;
    let minor_bits = i32::from(minor);
    MAJOR_3_SHIFTED | minor_bits
}
30
/// Format `host` and `port` as a `host:port` string suitable for `TcpStream::connect`.
///
/// A host containing `:` (an IPv6 literal) that is not already wrapped in
/// brackets is emitted as `[host]:port`; every other host is emitted verbatim.
fn socket_addr(host: &str, port: u16) -> String {
    let already_bracketed = host.starts_with('[');
    let needs_brackets = !already_bracketed && host.contains(':');
    if needs_brackets {
        return format!("[{host}]:{port}");
    }
    format!("{host}:{port}")
}
38
39fn is_explicit_protocol_version_rejection(err: &PgError) -> bool {
40    let msg = match err {
41        PgError::Connection(msg) | PgError::Protocol(msg) | PgError::Auth(msg) => msg,
42        PgError::Query(msg) => msg,
43        PgError::QueryServer(server) => &server.message,
44        _ => return false,
45    };
46
47    let lower = msg.to_ascii_lowercase();
48    lower.contains("unsupported frontend protocol")
49        || lower.contains("frontend protocol") && lower.contains("unsupported")
50        || lower.contains("protocol version") && lower.contains("not support")
51}
52
53impl PgConnection {
54    /// Connect to PostgreSQL server without authentication (trust mode).
55    ///
56    /// # Arguments
57    ///
58    /// * `host` — PostgreSQL server hostname or IP.
59    /// * `port` — TCP port (typically 5432).
60    /// * `user` — PostgreSQL role name.
61    /// * `database` — Target database name.
62    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
63        Self::connect_with_password(host, port, user, database, None).await
64    }
65
66    /// Connect to PostgreSQL server with optional password authentication.
67    /// Includes a default 10-second timeout covering TCP connect + handshake.
68    ///
69    /// Startup requests protocol 3.2 by default and performs a one-shot retry
70    /// with protocol 3.0 only when startup fails due to explicit
71    /// protocol-version rejection from the server.
72    pub async fn connect_with_password(
73        host: &str,
74        port: u16,
75        user: &str,
76        database: &str,
77        password: Option<&str>,
78    ) -> PgResult<Self> {
79        Self::connect_with_password_and_auth(
80            host,
81            port,
82            user,
83            database,
84            password,
85            AuthSettings::default(),
86        )
87        .await
88    }
89
90    /// Connect to PostgreSQL with explicit enterprise options.
91    ///
92    /// Negotiation preface order follows libpq:
93    ///   1. If gss_enc_mode != Disable → try GSSENCRequest on fresh TCP
94    ///   2. If GSSENC rejected/unavailable and tls_mode != Disable → try SSLRequest
95    ///   3. If both rejected/unavailable → plain StartupMessage
96    ///
97    /// The StartupMessage protocol version behavior is the same as
98    /// `connect_with_password`: request protocol 3.2 first, then retry once
99    /// with 3.0 only on explicit protocol-version rejection.
100    pub async fn connect_with_options(
101        host: &str,
102        port: u16,
103        user: &str,
104        database: &str,
105        password: Option<&str>,
106        options: ConnectOptions,
107    ) -> PgResult<Self> {
108        let ConnectOptions {
109            tls_mode,
110            gss_enc_mode,
111            tls_ca_cert_pem,
112            mtls,
113            gss_token_provider,
114            gss_token_provider_ex,
115            auth,
116            startup_params,
117        } = options;
118
119        if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
120            return Err(PgError::Connection(
121                "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
122            ));
123        }
124
125        // Enforce gss_enc_mode policy before mTLS early-return.
126        // GSSENC and mTLS are both transport-level encryption; using
127        // both simultaneously is not supported by the PostgreSQL protocol.
128        if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
129            return Err(PgError::Connection(
130                "gssencmode=require is incompatible with mTLS — both provide \
131                 transport encryption; use one or the other"
132                    .to_string(),
133            ));
134        }
135
136        if let Some(mtls_config) = mtls {
137            // gss_enc_mode is Disable or Prefer here (Require rejected above).
138            // mTLS already provides transport encryption; skip GSSENC.
139            return Self::connect_mtls_with_password_and_auth_and_gss(
140                ConnectParams {
141                    host,
142                    port,
143                    user,
144                    database,
145                    password,
146                    auth_settings: auth,
147                    gss_token_provider,
148                    gss_token_provider_ex,
149                    protocol_minor: Self::default_protocol_minor(),
150                    startup_params: startup_params.clone(),
151                },
152                mtls_config,
153            )
154            .await;
155        }
156
157        // ── Phase 1: Try GSSENC if requested ──────────────────────────
158        if gss_enc_mode != GssEncMode::Disable {
159            match Self::try_gssenc_request(host, port).await {
160                Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
161                    let connect_started = Instant::now();
162                    record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
163                    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
164                    {
165                        let default_minor = Self::default_protocol_minor();
166                        let gss_params = ConnectParams {
167                            host,
168                            port,
169                            user,
170                            database,
171                            password,
172                            auth_settings: auth,
173                            gss_token_provider,
174                            gss_token_provider_ex: gss_token_provider_ex.clone(),
175                            protocol_minor: default_minor,
176                            startup_params: startup_params.clone(),
177                        };
178                        let mut result = Self::connect_gssenc_accepted_with_timeout(
179                            tcp_stream,
180                            gss_params.clone(),
181                        )
182                        .await;
183                        if let Err(err) = &result
184                            && default_minor > 0
185                            && is_explicit_protocol_version_rejection(err)
186                        {
187                            let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
188                            let retry_stream = match Self::try_gssenc_request(host, port).await {
189                                Ok(GssEncNegotiationResult::Accepted(stream)) => stream,
190                                Ok(GssEncNegotiationResult::Rejected) => {
191                                    return Err(PgError::Connection(
192                                        "Protocol downgrade retry failed: server rejected GSSENCRequest"
193                                            .to_string(),
194                                    ));
195                                }
196                                Ok(GssEncNegotiationResult::ServerError) => {
197                                    return Err(PgError::Connection(
198                                        "Protocol downgrade retry failed: server returned error to GSSENCRequest"
199                                            .to_string(),
200                                    ));
201                                }
202                                Err(e) => {
203                                    return Err(e);
204                                }
205                            };
206                            let mut retry_params = gss_params;
207                            retry_params.protocol_minor = downgrade_minor;
208                            result = Self::connect_gssenc_accepted_with_timeout(
209                                retry_stream,
210                                retry_params,
211                            )
212                            .await;
213                        }
214                        record_connect_result(
215                            CONNECT_TRANSPORT_GSSENC,
216                            CONNECT_BACKEND_TOKIO,
217                            &result,
218                            connect_started.elapsed(),
219                        );
220                        return result;
221                    }
222                    #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
223                    {
224                        let _ = tcp_stream;
225                        let err = PgError::Connection(
226                            "Server accepted GSSENCRequest but GSSAPI encryption requires \
227                             feature enterprise-gssapi on Linux"
228                                .to_string(),
229                        );
230                        metrics::histogram!(
231                            "qail_pg_connect_duration_seconds",
232                            "transport" => CONNECT_TRANSPORT_GSSENC,
233                            "backend" => CONNECT_BACKEND_TOKIO,
234                            "outcome" => "error"
235                        )
236                        .record(connect_started.elapsed().as_secs_f64());
237                        metrics::counter!(
238                            "qail_pg_connect_failure_total",
239                            "transport" => CONNECT_TRANSPORT_GSSENC,
240                            "backend" => CONNECT_BACKEND_TOKIO,
241                            "error_kind" => super::helpers::connect_error_kind(&err)
242                        )
243                        .increment(1);
244                        return Err(err);
245                    }
246                }
247                Ok(GssEncNegotiationResult::Rejected)
248                | Ok(GssEncNegotiationResult::ServerError) => {
249                    if gss_enc_mode == GssEncMode::Require {
250                        return Err(PgError::Connection(
251                            "gssencmode=require but server rejected GSSENCRequest".to_string(),
252                        ));
253                    }
254                    // gss_enc_mode == Prefer — fall through to TLS / plain
255                }
256                Err(e) => {
257                    if gss_enc_mode == GssEncMode::Require {
258                        return Err(e);
259                    }
260                    // gss_enc_mode == Prefer — connection error, fall through
261                    tracing::debug!(
262                        host = %host,
263                        port = %port,
264                        error = %e,
265                        "gssenc_prefer_fallthrough"
266                    );
267                }
268            }
269        }
270
271        // ── Phase 2: TLS / plain per sslmode ──────────────────────────
272        match tls_mode {
273            TlsMode::Disable => {
274                Self::connect_with_password_and_auth_and_gss(ConnectParams {
275                    host,
276                    port,
277                    user,
278                    database,
279                    password,
280                    auth_settings: auth,
281                    gss_token_provider,
282                    gss_token_provider_ex,
283                    protocol_minor: Self::default_protocol_minor(),
284                    startup_params: startup_params.clone(),
285                })
286                .await
287            }
288            TlsMode::Require => {
289                Self::connect_tls_with_auth_and_gss(
290                    ConnectParams {
291                        host,
292                        port,
293                        user,
294                        database,
295                        password,
296                        auth_settings: auth,
297                        gss_token_provider,
298                        gss_token_provider_ex,
299                        protocol_minor: Self::default_protocol_minor(),
300                        startup_params: startup_params.clone(),
301                    },
302                    tls_ca_cert_pem.as_deref(),
303                )
304                .await
305            }
306            TlsMode::Prefer => {
307                match Self::connect_tls_with_auth_and_gss(
308                    ConnectParams {
309                        host,
310                        port,
311                        user,
312                        database,
313                        password,
314                        auth_settings: auth,
315                        gss_token_provider,
316                        gss_token_provider_ex: gss_token_provider_ex.clone(),
317                        protocol_minor: Self::default_protocol_minor(),
318                        startup_params: startup_params.clone(),
319                    },
320                    tls_ca_cert_pem.as_deref(),
321                )
322                .await
323                {
324                    Ok(conn) => Ok(conn),
325                    Err(PgError::Connection(msg))
326                        if msg.contains("Server does not support TLS") =>
327                    {
328                        Self::connect_with_password_and_auth_and_gss(ConnectParams {
329                            host,
330                            port,
331                            user,
332                            database,
333                            password,
334                            auth_settings: auth,
335                            gss_token_provider,
336                            gss_token_provider_ex,
337                            protocol_minor: Self::default_protocol_minor(),
338                            startup_params: startup_params.clone(),
339                        })
340                        .await
341                    }
342                    Err(e) => Err(e),
343                }
344            }
345        }
346    }
347
348    /// Attempt GSSAPI session encryption negotiation.
349    ///
350    /// Opens a fresh TCP connection, sends GSSENCRequest (80877104),
351    /// reads exactly one byte (CVE-2021-23222 safe), and returns
352    /// the result.  The entire operation is bounded by
353    /// `DEFAULT_CONNECT_TIMEOUT`.
354    async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
355        tokio::time::timeout(
356            DEFAULT_CONNECT_TIMEOUT,
357            Self::try_gssenc_request_inner(host, port),
358        )
359        .await
360        .map_err(|_| {
361            PgError::Connection(format!(
362                "GSSENCRequest timeout after {:?}",
363                DEFAULT_CONNECT_TIMEOUT
364            ))
365        })?
366    }
367
368    /// Inner GSSENCRequest logic without timeout wrapper.
369    async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
370        use tokio::io::AsyncReadExt;
371
372        let addr = socket_addr(host, port);
373        let mut tcp_stream = TcpStream::connect(&addr).await?;
374        tcp_stream.set_nodelay(true)?;
375
376        // Send the 8-byte GSSENCRequest.
377        tcp_stream.write_all(&GSSENC_REQUEST).await?;
378        tcp_stream.flush().await?;
379
380        // CVE-2021-23222: Read exactly one byte.  The server must
381        // respond with a single 'G' or 'N'.  Any additional bytes
382        // in the buffer indicate a buffer-stuffing attack.
383        let mut response = [0u8; 1];
384        tcp_stream.read_exact(&mut response).await?;
385
386        match response[0] {
387            b'G' => {
388                // CVE-2021-23222 check: verify no extra bytes are buffered.
389                // Use a non-blocking peek to detect leftover data.
390                let mut peek_buf = [0u8; 1];
391                match tcp_stream.try_read(&mut peek_buf) {
392                    Ok(0) => {} // EOF — fine (shouldn't happen yet but harmless)
393                    Ok(_n) => {
394                        // Extra bytes after 'G' — possible buffer-stuffing.
395                        return Err(PgError::Connection(
396                            "Protocol violation: extra bytes after GSSENCRequest 'G' response \
397                             (possible CVE-2021-23222 buffer-stuffing attack)"
398                                .to_string(),
399                        ));
400                    }
401                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
402                        // No extra data — this is the expected path.
403                    }
404                    Err(e) => {
405                        return Err(PgError::Io(e));
406                    }
407                }
408                Ok(GssEncNegotiationResult::Accepted(tcp_stream))
409            }
410            b'N' => Ok(GssEncNegotiationResult::Rejected),
411            b'E' => {
412                // Server sent an ErrorMessage.  Per CVE-2024-10977 we
413                // must NOT display this to users since the server has
414                // not been authenticated.  Log at trace only.
415                tracing::trace!(
416                    host = %host,
417                    port = %port,
418                    "gssenc_request_server_error (suppressed per CVE-2024-10977)"
419                );
420                Ok(GssEncNegotiationResult::ServerError)
421            }
422            other => Err(PgError::Connection(format!(
423                "Unexpected response to GSSENCRequest: 0x{:02X} \
424                     (expected 'G'=0x47 or 'N'=0x4E)",
425                other
426            ))),
427        }
428    }
429
430    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
431    async fn connect_gssenc_accepted_with_timeout(
432        tcp_stream: TcpStream,
433        params: ConnectParams<'_>,
434    ) -> PgResult<Self> {
435        let gssenc_fut = async {
436            let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, params.host)
437                .await
438                .map_err(PgError::Auth)?;
439            let mut conn = Self {
440                stream: PgStream::GssEnc(gss_stream),
441                buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
442                write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
443                sql_buf: BytesMut::with_capacity(512),
444                params_buf: Vec::with_capacity(16),
445                prepared_statements: HashMap::new(),
446                stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
447                column_info_cache: HashMap::new(),
448                process_id: 0,
449                cancel_key_bytes: Vec::new(),
450                requested_protocol_minor: params.protocol_minor,
451                negotiated_protocol_minor: params.protocol_minor,
452                notifications: VecDeque::new(),
453                replication_stream_active: false,
454                replication_mode_enabled: has_logical_replication_startup_mode(
455                    &params.startup_params,
456                ),
457                last_replication_wal_end: None,
458                io_desynced: false,
459                pending_statement_closes: Vec::new(),
460                draining_statement_closes: false,
461            };
462            conn.send(FrontendMessage::Startup {
463                user: params.user.to_string(),
464                database: params.database.to_string(),
465                protocol_version: protocol_version_from_minor(params.protocol_minor),
466                startup_params: params.startup_params.clone(),
467            })
468            .await?;
469            conn.handle_startup(
470                params.user,
471                params.password,
472                params.auth_settings,
473                params.gss_token_provider,
474                params.gss_token_provider_ex,
475            )
476            .await?;
477            Ok(conn)
478        };
479        tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
480            .await
481            .map_err(|_| {
482                PgError::Connection(format!(
483                    "GSSENC connection timeout after {:?} (handshake + auth)",
484                    DEFAULT_CONNECT_TIMEOUT
485                ))
486            })?
487    }
488
489    /// Connect to PostgreSQL server with optional password authentication and auth policy.
490    pub async fn connect_with_password_and_auth(
491        host: &str,
492        port: u16,
493        user: &str,
494        database: &str,
495        password: Option<&str>,
496        auth_settings: AuthSettings,
497    ) -> PgResult<Self> {
498        Self::connect_with_password_and_auth_and_gss(ConnectParams {
499            host,
500            port,
501            user,
502            database,
503            password,
504            auth_settings,
505            gss_token_provider: None,
506            gss_token_provider_ex: None,
507            protocol_minor: Self::default_protocol_minor(),
508            startup_params: Vec::new(),
509        })
510        .await
511    }
512
513    async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
514        let first = Self::connect_with_password_and_auth_and_gss_once(params.clone()).await;
515        if let Err(err) = &first
516            && params.protocol_minor > 0
517            && is_explicit_protocol_version_rejection(err)
518        {
519            let mut downgraded = params;
520            downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
521            return Self::connect_with_password_and_auth_and_gss_once(downgraded).await;
522        }
523        first
524    }
525
526    async fn connect_with_password_and_auth_and_gss_once(
527        params: ConnectParams<'_>,
528    ) -> PgResult<Self> {
529        let connect_started = Instant::now();
530        let attempt_backend = plain_connect_attempt_backend();
531        record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
532        let result = tokio::time::timeout(
533            DEFAULT_CONNECT_TIMEOUT,
534            Self::connect_with_password_inner(params),
535        )
536        .await
537        .map_err(|_| {
538            PgError::Connection(format!(
539                "Connection timeout after {:?} (TCP connect + handshake)",
540                DEFAULT_CONNECT_TIMEOUT
541            ))
542        })?;
543        let backend = result
544            .as_ref()
545            .map(|conn| connect_backend_for_stream(&conn.stream))
546            .unwrap_or(attempt_backend);
547        record_connect_result(
548            CONNECT_TRANSPORT_PLAIN,
549            backend,
550            &result,
551            connect_started.elapsed(),
552        );
553        result
554    }
555
556    /// Inner connection logic without timeout wrapper.
557    async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
558        let ConnectParams {
559            host,
560            port,
561            user,
562            database,
563            password,
564            auth_settings,
565            gss_token_provider,
566            gss_token_provider_ex,
567            protocol_minor,
568            startup_params,
569        } = params;
570        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
571        let addr = socket_addr(host, port);
572        let stream = Self::connect_plain_stream(&addr).await?;
573
574        let mut conn = Self {
575            stream,
576            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
577            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
578            sql_buf: BytesMut::with_capacity(512),
579            params_buf: Vec::with_capacity(16), // SQL encoding buffer
580            prepared_statements: HashMap::new(),
581            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
582            column_info_cache: HashMap::new(),
583            process_id: 0,
584            cancel_key_bytes: Vec::new(),
585            requested_protocol_minor: protocol_minor,
586            negotiated_protocol_minor: protocol_minor,
587            notifications: VecDeque::new(),
588            replication_stream_active: false,
589            replication_mode_enabled,
590            last_replication_wal_end: None,
591            io_desynced: false,
592            pending_statement_closes: Vec::new(),
593            draining_statement_closes: false,
594        };
595
596        conn.send(FrontendMessage::Startup {
597            user: user.to_string(),
598            database: database.to_string(),
599            protocol_version: protocol_version_from_minor(protocol_minor),
600            startup_params,
601        })
602        .await?;
603
604        conn.handle_startup(
605            user,
606            password,
607            auth_settings,
608            gss_token_provider,
609            gss_token_provider_ex,
610        )
611        .await?;
612
613        Ok(conn)
614    }
615
616    async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
617        let tcp_stream = TcpStream::connect(addr).await?;
618        tcp_stream.set_nodelay(true)?;
619
620        #[cfg(all(target_os = "linux", feature = "io_uring"))]
621        {
622            if should_try_uring_plain() {
623                let std_stream = tcp_stream.into_std()?;
624                let fallback_std = std_stream.try_clone()?;
625                match super::super::uring::UringTcpStream::from_std(std_stream) {
626                    Ok(uring_stream) => {
627                        tracing::info!(
628                            addr = %addr,
629                            "qail-pg: using io_uring plain TCP transport"
630                        );
631                        return Ok(PgStream::Uring(uring_stream));
632                    }
633                    Err(e) => {
634                        tracing::warn!(
635                            addr = %addr,
636                            error = %e,
637                            "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
638                        );
639                        fallback_std.set_nonblocking(true)?;
640                        let fallback = TcpStream::from_std(fallback_std)?;
641                        return Ok(PgStream::Tcp(fallback));
642                    }
643                }
644            }
645        }
646
647        Ok(PgStream::Tcp(tcp_stream))
648    }
649
650    /// Connect to PostgreSQL server with TLS encryption.
651    /// Includes a default 10-second timeout covering TCP connect + TLS + handshake.
652    pub async fn connect_tls(
653        host: &str,
654        port: u16,
655        user: &str,
656        database: &str,
657        password: Option<&str>,
658    ) -> PgResult<Self> {
659        Self::connect_tls_with_auth(
660            host,
661            port,
662            user,
663            database,
664            password,
665            AuthSettings::default(),
666            None,
667        )
668        .await
669    }
670
671    /// Connect to PostgreSQL over TLS with explicit auth policy and optional custom CA bundle.
672    pub async fn connect_tls_with_auth(
673        host: &str,
674        port: u16,
675        user: &str,
676        database: &str,
677        password: Option<&str>,
678        auth_settings: AuthSettings,
679        ca_cert_pem: Option<&[u8]>,
680    ) -> PgResult<Self> {
681        Self::connect_tls_with_auth_and_gss(
682            ConnectParams {
683                host,
684                port,
685                user,
686                database,
687                password,
688                auth_settings,
689                gss_token_provider: None,
690                gss_token_provider_ex: None,
691                protocol_minor: Self::default_protocol_minor(),
692                startup_params: Vec::new(),
693            },
694            ca_cert_pem,
695        )
696        .await
697    }
698
699    async fn connect_tls_with_auth_and_gss(
700        params: ConnectParams<'_>,
701        ca_cert_pem: Option<&[u8]>,
702    ) -> PgResult<Self> {
703        let first = Self::connect_tls_with_auth_and_gss_once(params.clone(), ca_cert_pem).await;
704        if let Err(err) = &first
705            && params.protocol_minor > 0
706            && is_explicit_protocol_version_rejection(err)
707        {
708            let mut downgraded = params;
709            downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
710            return Self::connect_tls_with_auth_and_gss_once(downgraded, ca_cert_pem).await;
711        }
712        first
713    }
714
715    async fn connect_tls_with_auth_and_gss_once(
716        params: ConnectParams<'_>,
717        ca_cert_pem: Option<&[u8]>,
718    ) -> PgResult<Self> {
719        let connect_started = Instant::now();
720        record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
721        let result = tokio::time::timeout(
722            DEFAULT_CONNECT_TIMEOUT,
723            Self::connect_tls_inner(params, ca_cert_pem),
724        )
725        .await
726        .map_err(|_| {
727            PgError::Connection(format!(
728                "TLS connection timeout after {:?}",
729                DEFAULT_CONNECT_TIMEOUT
730            ))
731        })?;
732        record_connect_result(
733            CONNECT_TRANSPORT_TLS,
734            CONNECT_BACKEND_TOKIO,
735            &result,
736            connect_started.elapsed(),
737        );
738        result
739    }
740
741    /// Inner TLS connection logic without timeout wrapper.
742    async fn connect_tls_inner(
743        params: ConnectParams<'_>,
744        ca_cert_pem: Option<&[u8]>,
745    ) -> PgResult<Self> {
746        let ConnectParams {
747            host,
748            port,
749            user,
750            database,
751            password,
752            auth_settings,
753            gss_token_provider,
754            gss_token_provider_ex,
755            protocol_minor,
756            startup_params,
757        } = params;
758        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
759        use tokio::io::AsyncReadExt;
760        use tokio_rustls::TlsConnector;
761        use tokio_rustls::rustls::ClientConfig;
762        use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
763
764        let addr = socket_addr(host, port);
765        let mut tcp_stream = TcpStream::connect(&addr).await?;
766
767        // Send SSLRequest
768        tcp_stream.write_all(&SSL_REQUEST).await?;
769
770        // Read response
771        let mut response = [0u8; 1];
772        tcp_stream.read_exact(&mut response).await?;
773
774        if response[0] != b'S' {
775            return Err(PgError::Connection(
776                "Server does not support TLS".to_string(),
777            ));
778        }
779
780        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
781
782        if let Some(ca_pem) = ca_cert_pem {
783            let certs = CertificateDer::pem_slice_iter(ca_pem)
784                .collect::<Result<Vec<_>, _>>()
785                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
786            if certs.is_empty() {
787                return Err(PgError::Connection(
788                    "No CA certificates found in provided PEM".to_string(),
789                ));
790            }
791            for cert in certs {
792                let _ = root_cert_store.add(cert);
793            }
794        } else {
795            let certs = rustls_native_certs::load_native_certs();
796            for cert in certs.certs {
797                let _ = root_cert_store.add(cert);
798            }
799        }
800
801        let config = ClientConfig::builder()
802            .with_root_certificates(root_cert_store)
803            .with_no_client_auth();
804
805        let connector = TlsConnector::from(Arc::new(config));
806        let server_name = ServerName::try_from(host.to_string())
807            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
808
809        let tls_stream = connector
810            .connect(server_name, tcp_stream)
811            .await
812            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
813
814        let mut conn = Self {
815            stream: PgStream::Tls(Box::new(tls_stream)),
816            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
817            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
818            sql_buf: BytesMut::with_capacity(512),
819            params_buf: Vec::with_capacity(16),
820            prepared_statements: HashMap::new(),
821            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
822            column_info_cache: HashMap::new(),
823            process_id: 0,
824            cancel_key_bytes: Vec::new(),
825            requested_protocol_minor: protocol_minor,
826            negotiated_protocol_minor: protocol_minor,
827            notifications: VecDeque::new(),
828            replication_stream_active: false,
829            replication_mode_enabled,
830            last_replication_wal_end: None,
831            io_desynced: false,
832            pending_statement_closes: Vec::new(),
833            draining_statement_closes: false,
834        };
835
836        conn.send(FrontendMessage::Startup {
837            user: user.to_string(),
838            database: database.to_string(),
839            protocol_version: protocol_version_from_minor(protocol_minor),
840            startup_params,
841        })
842        .await?;
843
844        conn.handle_startup(
845            user,
846            password,
847            auth_settings,
848            gss_token_provider,
849            gss_token_provider_ex,
850        )
851        .await?;
852
853        Ok(conn)
854    }
855
856    /// Connect with mutual TLS (client certificate authentication).
857    /// # Arguments
858    /// * `host` - PostgreSQL server hostname
859    /// * `port` - PostgreSQL server port
860    /// * `user` - Database user
861    /// * `database` - Database name
862    /// * `config` - TLS configuration with client cert/key
863    /// # Example
864    /// ```ignore
865    /// let config = TlsConfig {
866    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
867    ///     client_key_pem: include_bytes!("client.key").to_vec(),
868    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
869    /// };
870    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
871    /// ```
872    pub async fn connect_mtls(
873        host: &str,
874        port: u16,
875        user: &str,
876        database: &str,
877        config: TlsConfig,
878    ) -> PgResult<Self> {
879        Self::connect_mtls_with_password_and_auth(
880            host,
881            port,
882            user,
883            database,
884            None,
885            config,
886            AuthSettings::default(),
887        )
888        .await
889    }
890
891    /// Connect with mutual TLS and optional password fallback.
892    pub async fn connect_mtls_with_password_and_auth(
893        host: &str,
894        port: u16,
895        user: &str,
896        database: &str,
897        password: Option<&str>,
898        config: TlsConfig,
899        auth_settings: AuthSettings,
900    ) -> PgResult<Self> {
901        Self::connect_mtls_with_password_and_auth_and_gss(
902            ConnectParams {
903                host,
904                port,
905                user,
906                database,
907                password,
908                auth_settings,
909                gss_token_provider: None,
910                gss_token_provider_ex: None,
911                protocol_minor: Self::default_protocol_minor(),
912                startup_params: Vec::new(),
913            },
914            config,
915        )
916        .await
917    }
918
919    async fn connect_mtls_with_password_and_auth_and_gss(
920        params: ConnectParams<'_>,
921        config: TlsConfig,
922    ) -> PgResult<Self> {
923        let first =
924            Self::connect_mtls_with_password_and_auth_and_gss_once(params.clone(), config.clone())
925                .await;
926        if let Err(err) = &first
927            && params.protocol_minor > 0
928            && is_explicit_protocol_version_rejection(err)
929        {
930            let mut downgraded = params;
931            downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
932            return Self::connect_mtls_with_password_and_auth_and_gss_once(downgraded, config)
933                .await;
934        }
935        first
936    }
937
938    async fn connect_mtls_with_password_and_auth_and_gss_once(
939        params: ConnectParams<'_>,
940        config: TlsConfig,
941    ) -> PgResult<Self> {
942        let connect_started = Instant::now();
943        record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
944        let result = tokio::time::timeout(
945            DEFAULT_CONNECT_TIMEOUT,
946            Self::connect_mtls_inner(params, config),
947        )
948        .await
949        .map_err(|_| {
950            PgError::Connection(format!(
951                "mTLS connection timeout after {:?}",
952                DEFAULT_CONNECT_TIMEOUT
953            ))
954        })?;
955        record_connect_result(
956            CONNECT_TRANSPORT_MTLS,
957            CONNECT_BACKEND_TOKIO,
958            &result,
959            connect_started.elapsed(),
960        );
961        result
962    }
963
964    /// Inner mTLS connection logic without timeout wrapper.
965    async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
966        let ConnectParams {
967            host,
968            port,
969            user,
970            database,
971            password,
972            auth_settings,
973            gss_token_provider,
974            gss_token_provider_ex,
975            protocol_minor,
976            startup_params,
977        } = params;
978        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
979        use tokio::io::AsyncReadExt;
980        use tokio_rustls::TlsConnector;
981        use tokio_rustls::rustls::{
982            ClientConfig,
983            pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
984        };
985
986        let addr = socket_addr(host, port);
987        let mut tcp_stream = TcpStream::connect(&addr).await?;
988
989        // Send SSLRequest
990        tcp_stream.write_all(&SSL_REQUEST).await?;
991
992        // Read response
993        let mut response = [0u8; 1];
994        tcp_stream.read_exact(&mut response).await?;
995
996        if response[0] != b'S' {
997            return Err(PgError::Connection(
998                "Server does not support TLS".to_string(),
999            ));
1000        }
1001
1002        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
1003
1004        if let Some(ca_pem) = &config.ca_cert_pem {
1005            let certs = CertificateDer::pem_slice_iter(ca_pem)
1006                .collect::<Result<Vec<_>, _>>()
1007                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
1008            if certs.is_empty() {
1009                return Err(PgError::Connection(
1010                    "No CA certificates found in provided PEM".to_string(),
1011                ));
1012            }
1013            for cert in certs {
1014                let _ = root_cert_store.add(cert);
1015            }
1016        } else {
1017            // Use system certs
1018            let certs = rustls_native_certs::load_native_certs();
1019            for cert in certs.certs {
1020                let _ = root_cert_store.add(cert);
1021            }
1022        }
1023
1024        let client_certs: Vec<CertificateDer<'static>> =
1025            CertificateDer::pem_slice_iter(&config.client_cert_pem)
1026                .collect::<Result<Vec<_>, _>>()
1027                .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
1028        if client_certs.is_empty() {
1029            return Err(PgError::Connection(
1030                "No client certificates found in PEM".to_string(),
1031            ));
1032        }
1033
1034        let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
1035            .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
1036
1037        let tls_config = ClientConfig::builder()
1038            .with_root_certificates(root_cert_store)
1039            .with_client_auth_cert(client_certs, client_key)
1040            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
1041
1042        let connector = TlsConnector::from(Arc::new(tls_config));
1043        let server_name = ServerName::try_from(host.to_string())
1044            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
1045
1046        let tls_stream = connector
1047            .connect(server_name, tcp_stream)
1048            .await
1049            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
1050
1051        let mut conn = Self {
1052            stream: PgStream::Tls(Box::new(tls_stream)),
1053            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1054            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1055            sql_buf: BytesMut::with_capacity(512),
1056            params_buf: Vec::with_capacity(16),
1057            prepared_statements: HashMap::new(),
1058            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1059            column_info_cache: HashMap::new(),
1060            process_id: 0,
1061            cancel_key_bytes: Vec::new(),
1062            requested_protocol_minor: protocol_minor,
1063            negotiated_protocol_minor: protocol_minor,
1064            notifications: VecDeque::new(),
1065            replication_stream_active: false,
1066            replication_mode_enabled,
1067            last_replication_wal_end: None,
1068            io_desynced: false,
1069            pending_statement_closes: Vec::new(),
1070            draining_statement_closes: false,
1071        };
1072
1073        conn.send(FrontendMessage::Startup {
1074            user: user.to_string(),
1075            database: database.to_string(),
1076            protocol_version: protocol_version_from_minor(protocol_minor),
1077            startup_params,
1078        })
1079        .await?;
1080
1081        conn.handle_startup(
1082            user,
1083            password,
1084            auth_settings,
1085            gss_token_provider,
1086            gss_token_provider_ex,
1087        )
1088        .await?;
1089
1090        Ok(conn)
1091    }
1092
1093    /// Connect to PostgreSQL server via Unix domain socket.
1094    #[cfg(unix)]
1095    pub async fn connect_unix(
1096        socket_path: &str,
1097        user: &str,
1098        database: &str,
1099        password: Option<&str>,
1100    ) -> PgResult<Self> {
1101        let default_minor = Self::default_protocol_minor();
1102        let first =
1103            Self::connect_unix_with_protocol(socket_path, user, database, password, default_minor)
1104                .await;
1105        if let Err(err) = &first
1106            && default_minor > 0
1107            && is_explicit_protocol_version_rejection(err)
1108        {
1109            let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
1110            return Self::connect_unix_with_protocol(
1111                socket_path,
1112                user,
1113                database,
1114                password,
1115                downgrade_minor,
1116            )
1117            .await;
1118        }
1119        first
1120    }
1121
1122    #[cfg(unix)]
1123    async fn connect_unix_with_protocol(
1124        socket_path: &str,
1125        user: &str,
1126        database: &str,
1127        password: Option<&str>,
1128        protocol_minor: u16,
1129    ) -> PgResult<Self> {
1130        use tokio::net::UnixStream;
1131
1132        let unix_stream = UnixStream::connect(socket_path).await?;
1133
1134        let mut conn = Self {
1135            stream: PgStream::Unix(unix_stream),
1136            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1137            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1138            sql_buf: BytesMut::with_capacity(512),
1139            params_buf: Vec::with_capacity(16),
1140            prepared_statements: HashMap::new(),
1141            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1142            column_info_cache: HashMap::new(),
1143            process_id: 0,
1144            cancel_key_bytes: Vec::new(),
1145            requested_protocol_minor: protocol_minor,
1146            negotiated_protocol_minor: protocol_minor,
1147            notifications: VecDeque::new(),
1148            replication_stream_active: false,
1149            replication_mode_enabled: false,
1150            last_replication_wal_end: None,
1151            io_desynced: false,
1152            pending_statement_closes: Vec::new(),
1153            draining_statement_closes: false,
1154        };
1155
1156        conn.send(FrontendMessage::Startup {
1157            user: user.to_string(),
1158            database: database.to_string(),
1159            protocol_version: protocol_version_from_minor(protocol_minor),
1160            startup_params: Vec::new(),
1161        })
1162        .await?;
1163
1164        conn.handle_startup(user, password, AuthSettings::default(), None, None)
1165            .await?;
1166
1167        Ok(conn)
1168    }
1169}
1170
#[cfg(test)]
mod tests {
    use super::{is_explicit_protocol_version_rejection, protocol_version_from_minor, socket_addr};
    use crate::driver::PgError;

    /// Major version 3 occupies the upper 16 bits; the minor fills the lower 16.
    #[test]
    fn protocol_version_from_minor_encodes_major_3() {
        for (minor, expected) in [(2u16, 196610i32), (0u16, 196608i32)] {
            assert_eq!(protocol_version_from_minor(minor), expected);
        }
    }

    /// Bare IPv6 literals gain brackets; IPv4 and already-bracketed hosts pass through.
    #[test]
    fn socket_addr_brackets_ipv6_hosts() {
        let cases = [
            ("127.0.0.1", "127.0.0.1:5432"),
            ("::1", "[::1]:5432"),
            ("[::1]", "[::1]:5432"),
        ];
        for (host, expected) in cases {
            assert_eq!(socket_addr(host, 5432), expected);
        }
    }

    /// Rejection messages are recognised regardless of letter case.
    #[test]
    fn explicit_protocol_rejection_detection_is_case_insensitive() {
        let rejections = [
            PgError::Connection("Unsupported frontend protocol 3.2".to_string()),
            PgError::Protocol("server: Protocol VERSION not supported".to_string()),
        ];
        for err in &rejections {
            assert!(is_explicit_protocol_version_rejection(err));
        }
    }

    /// Ordinary connection failures must not trigger a protocol downgrade.
    #[test]
    fn explicit_protocol_rejection_does_not_match_unrelated_errors() {
        let unrelated = PgError::Connection("connection reset by peer".to_string());
        assert!(!is_explicit_protocol_version_rejection(&unrelated));
    }
}