Skip to main content

qail_pg/driver/connection/
connect.rs

1//! Connection establishment — connect_*, TLS, mTLS, Unix socket.
2
3#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::helpers::should_try_uring_plain;
5use super::helpers::{
6    connect_backend_for_stream, connect_error_kind, plain_connect_attempt_backend,
7    record_connect_attempt, record_connect_result,
8};
9use super::types::{
10    BUFFER_CAPACITY, CONNECT_BACKEND_TOKIO, CONNECT_TRANSPORT_GSSENC, CONNECT_TRANSPORT_MTLS,
11    CONNECT_TRANSPORT_PLAIN, CONNECT_TRANSPORT_TLS, ConnectParams, DEFAULT_CONNECT_TIMEOUT,
12    GSSENC_REQUEST, GssEncNegotiationResult, PgConnection, SSL_REQUEST, STMT_CACHE_CAPACITY,
13    StatementCache, TlsConfig, has_logical_replication_startup_mode,
14};
15use crate::driver::stream::PgStream;
16use crate::driver::{AuthSettings, ConnectOptions, GssEncMode, PgError, PgResult, TlsMode};
17use crate::protocol::PROTOCOL_VERSION_3_0;
18use crate::protocol::wire::FrontendMessage;
19use bytes::BytesMut;
20use std::collections::{HashMap, VecDeque};
21use std::sync::Arc;
22use std::time::Instant;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
/// Build a PostgreSQL 3.x protocol version number from a minor version.
///
/// The major version (3) occupies the high 16 bits and `minor` the low 16 bits,
/// e.g. `minor = 2` yields `0x0003_0002`.
#[inline]
fn protocol_version_from_minor(minor: u16) -> i32 {
    const MAJOR: i32 = 3;
    // The low 16 bits of `MAJOR << 16` are zero and `minor` fits in 16 bits,
    // so addition is equivalent to bitwise OR here.
    (MAJOR << 16) + i32::from(minor)
}
30
31fn is_explicit_protocol_version_rejection(err: &PgError) -> bool {
32    let msg = match err {
33        PgError::Connection(msg) | PgError::Protocol(msg) | PgError::Auth(msg) => msg,
34        PgError::Query(msg) => msg,
35        PgError::QueryServer(server) => &server.message,
36        _ => return false,
37    };
38
39    let lower = msg.to_ascii_lowercase();
40    lower.contains("unsupported frontend protocol")
41        || lower.contains("frontend protocol") && lower.contains("unsupported")
42        || lower.contains("protocol version") && lower.contains("not support")
43}
44
45impl PgConnection {
46    /// Connect to PostgreSQL server without authentication (trust mode).
47    ///
48    /// # Arguments
49    ///
50    /// * `host` — PostgreSQL server hostname or IP.
51    /// * `port` — TCP port (typically 5432).
52    /// * `user` — PostgreSQL role name.
53    /// * `database` — Target database name.
54    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
55        Self::connect_with_password(host, port, user, database, None).await
56    }
57
58    /// Connect to PostgreSQL server with optional password authentication.
59    /// Includes a default 10-second timeout covering TCP connect + handshake.
60    ///
61    /// Startup requests protocol 3.2 by default and performs a one-shot retry
62    /// with protocol 3.0 only when startup fails due to explicit
63    /// protocol-version rejection from the server.
64    pub async fn connect_with_password(
65        host: &str,
66        port: u16,
67        user: &str,
68        database: &str,
69        password: Option<&str>,
70    ) -> PgResult<Self> {
71        Self::connect_with_password_and_auth(
72            host,
73            port,
74            user,
75            database,
76            password,
77            AuthSettings::default(),
78        )
79        .await
80    }
81
82    /// Connect to PostgreSQL with explicit enterprise options.
83    ///
84    /// Negotiation preface order follows libpq:
85    ///   1. If gss_enc_mode != Disable → try GSSENCRequest on fresh TCP
86    ///   2. If GSSENC rejected/unavailable and tls_mode != Disable → try SSLRequest
87    ///   3. If both rejected/unavailable → plain StartupMessage
88    ///
89    /// The StartupMessage protocol version behavior is the same as
90    /// `connect_with_password`: request protocol 3.2 first, then retry once
91    /// with 3.0 only on explicit protocol-version rejection.
92    pub async fn connect_with_options(
93        host: &str,
94        port: u16,
95        user: &str,
96        database: &str,
97        password: Option<&str>,
98        options: ConnectOptions,
99    ) -> PgResult<Self> {
100        let ConnectOptions {
101            tls_mode,
102            gss_enc_mode,
103            tls_ca_cert_pem,
104            mtls,
105            gss_token_provider,
106            gss_token_provider_ex,
107            auth,
108            startup_params,
109        } = options;
110
111        if mtls.is_some() && matches!(tls_mode, TlsMode::Disable) {
112            return Err(PgError::Connection(
113                "Invalid connect options: mTLS requires tls_mode=Prefer or Require".to_string(),
114            ));
115        }
116
117        // Enforce gss_enc_mode policy before mTLS early-return.
118        // GSSENC and mTLS are both transport-level encryption; using
119        // both simultaneously is not supported by the PostgreSQL protocol.
120        if gss_enc_mode == GssEncMode::Require && mtls.is_some() {
121            return Err(PgError::Connection(
122                "gssencmode=require is incompatible with mTLS — both provide \
123                 transport encryption; use one or the other"
124                    .to_string(),
125            ));
126        }
127
128        if let Some(mtls_config) = mtls {
129            // gss_enc_mode is Disable or Prefer here (Require rejected above).
130            // mTLS already provides transport encryption; skip GSSENC.
131            return Self::connect_mtls_with_password_and_auth_and_gss(
132                ConnectParams {
133                    host,
134                    port,
135                    user,
136                    database,
137                    password,
138                    auth_settings: auth,
139                    gss_token_provider,
140                    gss_token_provider_ex,
141                    protocol_minor: Self::default_protocol_minor(),
142                    startup_params: startup_params.clone(),
143                },
144                mtls_config,
145            )
146            .await;
147        }
148
149        // ── Phase 1: Try GSSENC if requested ──────────────────────────
150        if gss_enc_mode != GssEncMode::Disable {
151            match Self::try_gssenc_request(host, port).await {
152                Ok(GssEncNegotiationResult::Accepted(tcp_stream)) => {
153                    let connect_started = Instant::now();
154                    record_connect_attempt(CONNECT_TRANSPORT_GSSENC, CONNECT_BACKEND_TOKIO);
155                    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
156                    {
157                        let default_minor = Self::default_protocol_minor();
158                        let mut result = Self::connect_gssenc_accepted_with_timeout(
159                            tcp_stream,
160                            host,
161                            user,
162                            database,
163                            password,
164                            auth,
165                            gss_token_provider,
166                            gss_token_provider_ex.clone(),
167                            startup_params.clone(),
168                            default_minor,
169                        )
170                        .await;
171                        if let Err(err) = &result {
172                            if default_minor > 0 && is_explicit_protocol_version_rejection(err) {
173                                let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
174                                let retry_stream = match Self::try_gssenc_request(host, port).await
175                                {
176                                    Ok(GssEncNegotiationResult::Accepted(stream)) => stream,
177                                    Ok(GssEncNegotiationResult::Rejected) => {
178                                        return Err(PgError::Connection(
179                                            "Protocol downgrade retry failed: server rejected GSSENCRequest"
180                                                .to_string(),
181                                        ));
182                                    }
183                                    Ok(GssEncNegotiationResult::ServerError) => {
184                                        return Err(PgError::Connection(
185                                            "Protocol downgrade retry failed: server returned error to GSSENCRequest"
186                                                .to_string(),
187                                        ));
188                                    }
189                                    Err(e) => {
190                                        return Err(e);
191                                    }
192                                };
193                                result = Self::connect_gssenc_accepted_with_timeout(
194                                    retry_stream,
195                                    host,
196                                    user,
197                                    database,
198                                    password,
199                                    auth,
200                                    gss_token_provider,
201                                    gss_token_provider_ex,
202                                    startup_params.clone(),
203                                    downgrade_minor,
204                                )
205                                .await;
206                            }
207                        }
208                        record_connect_result(
209                            CONNECT_TRANSPORT_GSSENC,
210                            CONNECT_BACKEND_TOKIO,
211                            &result,
212                            connect_started.elapsed(),
213                        );
214                        return result;
215                    }
216                    #[cfg(not(all(feature = "enterprise-gssapi", target_os = "linux")))]
217                    {
218                        let _ = tcp_stream;
219                        let err = PgError::Connection(
220                            "Server accepted GSSENCRequest but GSSAPI encryption requires \
221                             feature enterprise-gssapi on Linux"
222                                .to_string(),
223                        );
224                        metrics::histogram!(
225                            "qail_pg_connect_duration_seconds",
226                            "transport" => CONNECT_TRANSPORT_GSSENC,
227                            "backend" => CONNECT_BACKEND_TOKIO,
228                            "outcome" => "error"
229                        )
230                        .record(connect_started.elapsed().as_secs_f64());
231                        metrics::counter!(
232                            "qail_pg_connect_failure_total",
233                            "transport" => CONNECT_TRANSPORT_GSSENC,
234                            "backend" => CONNECT_BACKEND_TOKIO,
235                            "error_kind" => connect_error_kind(&err)
236                        )
237                        .increment(1);
238                        return Err(err);
239                    }
240                }
241                Ok(GssEncNegotiationResult::Rejected)
242                | Ok(GssEncNegotiationResult::ServerError) => {
243                    if gss_enc_mode == GssEncMode::Require {
244                        return Err(PgError::Connection(
245                            "gssencmode=require but server rejected GSSENCRequest".to_string(),
246                        ));
247                    }
248                    // gss_enc_mode == Prefer — fall through to TLS / plain
249                }
250                Err(e) => {
251                    if gss_enc_mode == GssEncMode::Require {
252                        return Err(e);
253                    }
254                    // gss_enc_mode == Prefer — connection error, fall through
255                    tracing::debug!(
256                        host = %host,
257                        port = %port,
258                        error = %e,
259                        "gssenc_prefer_fallthrough"
260                    );
261                }
262            }
263        }
264
265        // ── Phase 2: TLS / plain per sslmode ──────────────────────────
266        match tls_mode {
267            TlsMode::Disable => {
268                Self::connect_with_password_and_auth_and_gss(ConnectParams {
269                    host,
270                    port,
271                    user,
272                    database,
273                    password,
274                    auth_settings: auth,
275                    gss_token_provider,
276                    gss_token_provider_ex,
277                    protocol_minor: Self::default_protocol_minor(),
278                    startup_params: startup_params.clone(),
279                })
280                .await
281            }
282            TlsMode::Require => {
283                Self::connect_tls_with_auth_and_gss(
284                    ConnectParams {
285                        host,
286                        port,
287                        user,
288                        database,
289                        password,
290                        auth_settings: auth,
291                        gss_token_provider,
292                        gss_token_provider_ex,
293                        protocol_minor: Self::default_protocol_minor(),
294                        startup_params: startup_params.clone(),
295                    },
296                    tls_ca_cert_pem.as_deref(),
297                )
298                .await
299            }
300            TlsMode::Prefer => {
301                match Self::connect_tls_with_auth_and_gss(
302                    ConnectParams {
303                        host,
304                        port,
305                        user,
306                        database,
307                        password,
308                        auth_settings: auth,
309                        gss_token_provider,
310                        gss_token_provider_ex: gss_token_provider_ex.clone(),
311                        protocol_minor: Self::default_protocol_minor(),
312                        startup_params: startup_params.clone(),
313                    },
314                    tls_ca_cert_pem.as_deref(),
315                )
316                .await
317                {
318                    Ok(conn) => Ok(conn),
319                    Err(PgError::Connection(msg))
320                        if msg.contains("Server does not support TLS") =>
321                    {
322                        Self::connect_with_password_and_auth_and_gss(ConnectParams {
323                            host,
324                            port,
325                            user,
326                            database,
327                            password,
328                            auth_settings: auth,
329                            gss_token_provider,
330                            gss_token_provider_ex,
331                            protocol_minor: Self::default_protocol_minor(),
332                            startup_params: startup_params.clone(),
333                        })
334                        .await
335                    }
336                    Err(e) => Err(e),
337                }
338            }
339        }
340    }
341
342    /// Attempt GSSAPI session encryption negotiation.
343    ///
344    /// Opens a fresh TCP connection, sends GSSENCRequest (80877104),
345    /// reads exactly one byte (CVE-2021-23222 safe), and returns
346    /// the result.  The entire operation is bounded by
347    /// `DEFAULT_CONNECT_TIMEOUT`.
348    async fn try_gssenc_request(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
349        tokio::time::timeout(
350            DEFAULT_CONNECT_TIMEOUT,
351            Self::try_gssenc_request_inner(host, port),
352        )
353        .await
354        .map_err(|_| {
355            PgError::Connection(format!(
356                "GSSENCRequest timeout after {:?}",
357                DEFAULT_CONNECT_TIMEOUT
358            ))
359        })?
360    }
361
362    /// Inner GSSENCRequest logic without timeout wrapper.
363    async fn try_gssenc_request_inner(host: &str, port: u16) -> PgResult<GssEncNegotiationResult> {
364        use tokio::io::AsyncReadExt;
365
366        let addr = format!("{}:{}", host, port);
367        let mut tcp_stream = TcpStream::connect(&addr).await?;
368        tcp_stream.set_nodelay(true)?;
369
370        // Send the 8-byte GSSENCRequest.
371        tcp_stream.write_all(&GSSENC_REQUEST).await?;
372        tcp_stream.flush().await?;
373
374        // CVE-2021-23222: Read exactly one byte.  The server must
375        // respond with a single 'G' or 'N'.  Any additional bytes
376        // in the buffer indicate a buffer-stuffing attack.
377        let mut response = [0u8; 1];
378        tcp_stream.read_exact(&mut response).await?;
379
380        match response[0] {
381            b'G' => {
382                // CVE-2021-23222 check: verify no extra bytes are buffered.
383                // Use a non-blocking peek to detect leftover data.
384                let mut peek_buf = [0u8; 1];
385                match tcp_stream.try_read(&mut peek_buf) {
386                    Ok(0) => {} // EOF — fine (shouldn't happen yet but harmless)
387                    Ok(_n) => {
388                        // Extra bytes after 'G' — possible buffer-stuffing.
389                        return Err(PgError::Connection(
390                            "Protocol violation: extra bytes after GSSENCRequest 'G' response \
391                             (possible CVE-2021-23222 buffer-stuffing attack)"
392                                .to_string(),
393                        ));
394                    }
395                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
396                        // No extra data — this is the expected path.
397                    }
398                    Err(e) => {
399                        return Err(PgError::Io(e));
400                    }
401                }
402                Ok(GssEncNegotiationResult::Accepted(tcp_stream))
403            }
404            b'N' => Ok(GssEncNegotiationResult::Rejected),
405            b'E' => {
406                // Server sent an ErrorMessage.  Per CVE-2024-10977 we
407                // must NOT display this to users since the server has
408                // not been authenticated.  Log at trace only.
409                tracing::trace!(
410                    host = %host,
411                    port = %port,
412                    "gssenc_request_server_error (suppressed per CVE-2024-10977)"
413                );
414                Ok(GssEncNegotiationResult::ServerError)
415            }
416            other => Err(PgError::Connection(format!(
417                "Unexpected response to GSSENCRequest: 0x{:02X} \
418                     (expected 'G'=0x47 or 'N'=0x4E)",
419                other
420            ))),
421        }
422    }
423
424    #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
425    async fn connect_gssenc_accepted_with_timeout(
426        tcp_stream: TcpStream,
427        host: &str,
428        user: &str,
429        database: &str,
430        password: Option<&str>,
431        auth_settings: AuthSettings,
432        gss_token_provider: Option<super::super::GssTokenProvider>,
433        gss_token_provider_ex: Option<super::super::GssTokenProviderEx>,
434        startup_params: Vec<(String, String)>,
435        protocol_minor: u16,
436    ) -> PgResult<Self> {
437        let gssenc_fut = async {
438            let gss_stream = super::super::gss::gssenc_handshake(tcp_stream, host)
439                .await
440                .map_err(PgError::Auth)?;
441            let mut conn = Self {
442                stream: PgStream::GssEnc(gss_stream),
443                buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
444                write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
445                sql_buf: BytesMut::with_capacity(512),
446                params_buf: Vec::with_capacity(16),
447                prepared_statements: HashMap::new(),
448                stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
449                column_info_cache: HashMap::new(),
450                process_id: 0,
451                secret_key: 0,
452                cancel_key_bytes: Vec::new(),
453                requested_protocol_minor: protocol_minor,
454                negotiated_protocol_minor: protocol_minor,
455                notifications: VecDeque::new(),
456                replication_stream_active: false,
457                replication_mode_enabled: has_logical_replication_startup_mode(&startup_params),
458                last_replication_wal_end: None,
459                io_desynced: false,
460                pending_statement_closes: Vec::new(),
461                draining_statement_closes: false,
462            };
463            conn.send(FrontendMessage::Startup {
464                user: user.to_string(),
465                database: database.to_string(),
466                protocol_version: protocol_version_from_minor(protocol_minor),
467                startup_params: startup_params.clone(),
468            })
469            .await?;
470            conn.handle_startup(
471                user,
472                password,
473                auth_settings,
474                gss_token_provider,
475                gss_token_provider_ex,
476            )
477            .await?;
478            Ok(conn)
479        };
480        tokio::time::timeout(DEFAULT_CONNECT_TIMEOUT, gssenc_fut)
481            .await
482            .map_err(|_| {
483                PgError::Connection(format!(
484                    "GSSENC connection timeout after {:?} (handshake + auth)",
485                    DEFAULT_CONNECT_TIMEOUT
486                ))
487            })?
488    }
489
490    /// Connect to PostgreSQL server with optional password authentication and auth policy.
491    pub async fn connect_with_password_and_auth(
492        host: &str,
493        port: u16,
494        user: &str,
495        database: &str,
496        password: Option<&str>,
497        auth_settings: AuthSettings,
498    ) -> PgResult<Self> {
499        Self::connect_with_password_and_auth_and_gss(ConnectParams {
500            host,
501            port,
502            user,
503            database,
504            password,
505            auth_settings,
506            gss_token_provider: None,
507            gss_token_provider_ex: None,
508            protocol_minor: Self::default_protocol_minor(),
509            startup_params: Vec::new(),
510        })
511        .await
512    }
513
514    async fn connect_with_password_and_auth_and_gss(params: ConnectParams<'_>) -> PgResult<Self> {
515        let first = Self::connect_with_password_and_auth_and_gss_once(params.clone()).await;
516        if let Err(err) = &first
517            && params.protocol_minor > 0
518            && is_explicit_protocol_version_rejection(err)
519        {
520            let mut downgraded = params;
521            downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
522            return Self::connect_with_password_and_auth_and_gss_once(downgraded).await;
523        }
524        first
525    }
526
527    async fn connect_with_password_and_auth_and_gss_once(
528        params: ConnectParams<'_>,
529    ) -> PgResult<Self> {
530        let connect_started = Instant::now();
531        let attempt_backend = plain_connect_attempt_backend();
532        record_connect_attempt(CONNECT_TRANSPORT_PLAIN, attempt_backend);
533        let result = tokio::time::timeout(
534            DEFAULT_CONNECT_TIMEOUT,
535            Self::connect_with_password_inner(params),
536        )
537        .await
538        .map_err(|_| {
539            PgError::Connection(format!(
540                "Connection timeout after {:?} (TCP connect + handshake)",
541                DEFAULT_CONNECT_TIMEOUT
542            ))
543        })?;
544        let backend = result
545            .as_ref()
546            .map(|conn| connect_backend_for_stream(&conn.stream))
547            .unwrap_or(attempt_backend);
548        record_connect_result(
549            CONNECT_TRANSPORT_PLAIN,
550            backend,
551            &result,
552            connect_started.elapsed(),
553        );
554        result
555    }
556
557    /// Inner connection logic without timeout wrapper.
558    async fn connect_with_password_inner(params: ConnectParams<'_>) -> PgResult<Self> {
559        let ConnectParams {
560            host,
561            port,
562            user,
563            database,
564            password,
565            auth_settings,
566            gss_token_provider,
567            gss_token_provider_ex,
568            protocol_minor,
569            startup_params,
570        } = params;
571        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
572        let addr = format!("{}:{}", host, port);
573        let stream = Self::connect_plain_stream(&addr).await?;
574
575        let mut conn = Self {
576            stream,
577            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
578            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
579            sql_buf: BytesMut::with_capacity(512),
580            params_buf: Vec::with_capacity(16), // SQL encoding buffer
581            prepared_statements: HashMap::new(),
582            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
583            column_info_cache: HashMap::new(),
584            process_id: 0,
585            secret_key: 0,
586            cancel_key_bytes: Vec::new(),
587            requested_protocol_minor: protocol_minor,
588            negotiated_protocol_minor: protocol_minor,
589            notifications: VecDeque::new(),
590            replication_stream_active: false,
591            replication_mode_enabled,
592            last_replication_wal_end: None,
593            io_desynced: false,
594            pending_statement_closes: Vec::new(),
595            draining_statement_closes: false,
596        };
597
598        conn.send(FrontendMessage::Startup {
599            user: user.to_string(),
600            database: database.to_string(),
601            protocol_version: protocol_version_from_minor(protocol_minor),
602            startup_params,
603        })
604        .await?;
605
606        conn.handle_startup(
607            user,
608            password,
609            auth_settings,
610            gss_token_provider,
611            gss_token_provider_ex,
612        )
613        .await?;
614
615        Ok(conn)
616    }
617
618    async fn connect_plain_stream(addr: &str) -> PgResult<PgStream> {
619        let tcp_stream = TcpStream::connect(addr).await?;
620        tcp_stream.set_nodelay(true)?;
621
622        #[cfg(all(target_os = "linux", feature = "io_uring"))]
623        {
624            if should_try_uring_plain() {
625                let std_stream = tcp_stream.into_std()?;
626                let fallback_std = std_stream.try_clone()?;
627                match super::super::uring::UringTcpStream::from_std(std_stream) {
628                    Ok(uring_stream) => {
629                        tracing::info!(
630                            addr = %addr,
631                            "qail-pg: using io_uring plain TCP transport"
632                        );
633                        return Ok(PgStream::Uring(uring_stream));
634                    }
635                    Err(e) => {
636                        tracing::warn!(
637                            addr = %addr,
638                            error = %e,
639                            "qail-pg: io_uring stream conversion failed; falling back to tokio TCP"
640                        );
641                        fallback_std.set_nonblocking(true)?;
642                        let fallback = TcpStream::from_std(fallback_std)?;
643                        return Ok(PgStream::Tcp(fallback));
644                    }
645                }
646            }
647        }
648
649        Ok(PgStream::Tcp(tcp_stream))
650    }
651
652    /// Connect to PostgreSQL server with TLS encryption.
653    /// Includes a default 10-second timeout covering TCP connect + TLS + handshake.
654    pub async fn connect_tls(
655        host: &str,
656        port: u16,
657        user: &str,
658        database: &str,
659        password: Option<&str>,
660    ) -> PgResult<Self> {
661        Self::connect_tls_with_auth(
662            host,
663            port,
664            user,
665            database,
666            password,
667            AuthSettings::default(),
668            None,
669        )
670        .await
671    }
672
673    /// Connect to PostgreSQL over TLS with explicit auth policy and optional custom CA bundle.
674    pub async fn connect_tls_with_auth(
675        host: &str,
676        port: u16,
677        user: &str,
678        database: &str,
679        password: Option<&str>,
680        auth_settings: AuthSettings,
681        ca_cert_pem: Option<&[u8]>,
682    ) -> PgResult<Self> {
683        Self::connect_tls_with_auth_and_gss(
684            ConnectParams {
685                host,
686                port,
687                user,
688                database,
689                password,
690                auth_settings,
691                gss_token_provider: None,
692                gss_token_provider_ex: None,
693                protocol_minor: Self::default_protocol_minor(),
694                startup_params: Vec::new(),
695            },
696            ca_cert_pem,
697        )
698        .await
699    }
700
701    async fn connect_tls_with_auth_and_gss(
702        params: ConnectParams<'_>,
703        ca_cert_pem: Option<&[u8]>,
704    ) -> PgResult<Self> {
705        let first = Self::connect_tls_with_auth_and_gss_once(params.clone(), ca_cert_pem).await;
706        if let Err(err) = &first
707            && params.protocol_minor > 0
708            && is_explicit_protocol_version_rejection(err)
709        {
710            let mut downgraded = params;
711            downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
712            return Self::connect_tls_with_auth_and_gss_once(downgraded, ca_cert_pem).await;
713        }
714        first
715    }
716
717    async fn connect_tls_with_auth_and_gss_once(
718        params: ConnectParams<'_>,
719        ca_cert_pem: Option<&[u8]>,
720    ) -> PgResult<Self> {
721        let connect_started = Instant::now();
722        record_connect_attempt(CONNECT_TRANSPORT_TLS, CONNECT_BACKEND_TOKIO);
723        let result = tokio::time::timeout(
724            DEFAULT_CONNECT_TIMEOUT,
725            Self::connect_tls_inner(params, ca_cert_pem),
726        )
727        .await
728        .map_err(|_| {
729            PgError::Connection(format!(
730                "TLS connection timeout after {:?}",
731                DEFAULT_CONNECT_TIMEOUT
732            ))
733        })?;
734        record_connect_result(
735            CONNECT_TRANSPORT_TLS,
736            CONNECT_BACKEND_TOKIO,
737            &result,
738            connect_started.elapsed(),
739        );
740        result
741    }
742
743    /// Inner TLS connection logic without timeout wrapper.
744    async fn connect_tls_inner(
745        params: ConnectParams<'_>,
746        ca_cert_pem: Option<&[u8]>,
747    ) -> PgResult<Self> {
748        let ConnectParams {
749            host,
750            port,
751            user,
752            database,
753            password,
754            auth_settings,
755            gss_token_provider,
756            gss_token_provider_ex,
757            protocol_minor,
758            startup_params,
759        } = params;
760        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
761        use tokio::io::AsyncReadExt;
762        use tokio_rustls::TlsConnector;
763        use tokio_rustls::rustls::ClientConfig;
764        use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, pem::PemObject};
765
766        let addr = format!("{}:{}", host, port);
767        let mut tcp_stream = TcpStream::connect(&addr).await?;
768
769        // Send SSLRequest
770        tcp_stream.write_all(&SSL_REQUEST).await?;
771
772        // Read response
773        let mut response = [0u8; 1];
774        tcp_stream.read_exact(&mut response).await?;
775
776        if response[0] != b'S' {
777            return Err(PgError::Connection(
778                "Server does not support TLS".to_string(),
779            ));
780        }
781
782        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
783
784        if let Some(ca_pem) = ca_cert_pem {
785            let certs = CertificateDer::pem_slice_iter(ca_pem)
786                .collect::<Result<Vec<_>, _>>()
787                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
788            if certs.is_empty() {
789                return Err(PgError::Connection(
790                    "No CA certificates found in provided PEM".to_string(),
791                ));
792            }
793            for cert in certs {
794                let _ = root_cert_store.add(cert);
795            }
796        } else {
797            let certs = rustls_native_certs::load_native_certs();
798            for cert in certs.certs {
799                let _ = root_cert_store.add(cert);
800            }
801        }
802
803        let config = ClientConfig::builder()
804            .with_root_certificates(root_cert_store)
805            .with_no_client_auth();
806
807        let connector = TlsConnector::from(Arc::new(config));
808        let server_name = ServerName::try_from(host.to_string())
809            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
810
811        let tls_stream = connector
812            .connect(server_name, tcp_stream)
813            .await
814            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
815
816        let mut conn = Self {
817            stream: PgStream::Tls(Box::new(tls_stream)),
818            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
819            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
820            sql_buf: BytesMut::with_capacity(512),
821            params_buf: Vec::with_capacity(16),
822            prepared_statements: HashMap::new(),
823            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
824            column_info_cache: HashMap::new(),
825            process_id: 0,
826            secret_key: 0,
827            cancel_key_bytes: Vec::new(),
828            requested_protocol_minor: protocol_minor,
829            negotiated_protocol_minor: protocol_minor,
830            notifications: VecDeque::new(),
831            replication_stream_active: false,
832            replication_mode_enabled,
833            last_replication_wal_end: None,
834            io_desynced: false,
835            pending_statement_closes: Vec::new(),
836            draining_statement_closes: false,
837        };
838
839        conn.send(FrontendMessage::Startup {
840            user: user.to_string(),
841            database: database.to_string(),
842            protocol_version: protocol_version_from_minor(protocol_minor),
843            startup_params,
844        })
845        .await?;
846
847        conn.handle_startup(
848            user,
849            password,
850            auth_settings,
851            gss_token_provider,
852            gss_token_provider_ex,
853        )
854        .await?;
855
856        Ok(conn)
857    }
858
859    /// Connect with mutual TLS (client certificate authentication).
860    /// # Arguments
861    /// * `host` - PostgreSQL server hostname
862    /// * `port` - PostgreSQL server port
863    /// * `user` - Database user
864    /// * `database` - Database name
865    /// * `config` - TLS configuration with client cert/key
866    /// # Example
867    /// ```ignore
868    /// let config = TlsConfig {
869    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
870    ///     client_key_pem: include_bytes!("client.key").to_vec(),
871    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
872    /// };
873    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
874    /// ```
875    pub async fn connect_mtls(
876        host: &str,
877        port: u16,
878        user: &str,
879        database: &str,
880        config: TlsConfig,
881    ) -> PgResult<Self> {
882        Self::connect_mtls_with_password_and_auth(
883            host,
884            port,
885            user,
886            database,
887            None,
888            config,
889            AuthSettings::default(),
890        )
891        .await
892    }
893
894    /// Connect with mutual TLS and optional password fallback.
895    pub async fn connect_mtls_with_password_and_auth(
896        host: &str,
897        port: u16,
898        user: &str,
899        database: &str,
900        password: Option<&str>,
901        config: TlsConfig,
902        auth_settings: AuthSettings,
903    ) -> PgResult<Self> {
904        Self::connect_mtls_with_password_and_auth_and_gss(
905            ConnectParams {
906                host,
907                port,
908                user,
909                database,
910                password,
911                auth_settings,
912                gss_token_provider: None,
913                gss_token_provider_ex: None,
914                protocol_minor: Self::default_protocol_minor(),
915                startup_params: Vec::new(),
916            },
917            config,
918        )
919        .await
920    }
921
922    async fn connect_mtls_with_password_and_auth_and_gss(
923        params: ConnectParams<'_>,
924        config: TlsConfig,
925    ) -> PgResult<Self> {
926        let first =
927            Self::connect_mtls_with_password_and_auth_and_gss_once(params.clone(), config.clone())
928                .await;
929        if let Err(err) = &first
930            && params.protocol_minor > 0
931            && is_explicit_protocol_version_rejection(err)
932        {
933            let mut downgraded = params;
934            downgraded.protocol_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
935            return Self::connect_mtls_with_password_and_auth_and_gss_once(downgraded, config)
936                .await;
937        }
938        first
939    }
940
941    async fn connect_mtls_with_password_and_auth_and_gss_once(
942        params: ConnectParams<'_>,
943        config: TlsConfig,
944    ) -> PgResult<Self> {
945        let connect_started = Instant::now();
946        record_connect_attempt(CONNECT_TRANSPORT_MTLS, CONNECT_BACKEND_TOKIO);
947        let result = tokio::time::timeout(
948            DEFAULT_CONNECT_TIMEOUT,
949            Self::connect_mtls_inner(params, config),
950        )
951        .await
952        .map_err(|_| {
953            PgError::Connection(format!(
954                "mTLS connection timeout after {:?}",
955                DEFAULT_CONNECT_TIMEOUT
956            ))
957        })?;
958        record_connect_result(
959            CONNECT_TRANSPORT_MTLS,
960            CONNECT_BACKEND_TOKIO,
961            &result,
962            connect_started.elapsed(),
963        );
964        result
965    }
966
967    /// Inner mTLS connection logic without timeout wrapper.
968    async fn connect_mtls_inner(params: ConnectParams<'_>, config: TlsConfig) -> PgResult<Self> {
969        let ConnectParams {
970            host,
971            port,
972            user,
973            database,
974            password,
975            auth_settings,
976            gss_token_provider,
977            gss_token_provider_ex,
978            protocol_minor,
979            startup_params,
980        } = params;
981        let replication_mode_enabled = has_logical_replication_startup_mode(&startup_params);
982        use tokio::io::AsyncReadExt;
983        use tokio_rustls::TlsConnector;
984        use tokio_rustls::rustls::{
985            ClientConfig,
986            pki_types::{CertificateDer, PrivateKeyDer, ServerName, pem::PemObject},
987        };
988
989        let addr = format!("{}:{}", host, port);
990        let mut tcp_stream = TcpStream::connect(&addr).await?;
991
992        // Send SSLRequest
993        tcp_stream.write_all(&SSL_REQUEST).await?;
994
995        // Read response
996        let mut response = [0u8; 1];
997        tcp_stream.read_exact(&mut response).await?;
998
999        if response[0] != b'S' {
1000            return Err(PgError::Connection(
1001                "Server does not support TLS".to_string(),
1002            ));
1003        }
1004
1005        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
1006
1007        if let Some(ca_pem) = &config.ca_cert_pem {
1008            let certs = CertificateDer::pem_slice_iter(ca_pem)
1009                .collect::<Result<Vec<_>, _>>()
1010                .map_err(|e| PgError::Connection(format!("Invalid CA certificate PEM: {}", e)))?;
1011            if certs.is_empty() {
1012                return Err(PgError::Connection(
1013                    "No CA certificates found in provided PEM".to_string(),
1014                ));
1015            }
1016            for cert in certs {
1017                let _ = root_cert_store.add(cert);
1018            }
1019        } else {
1020            // Use system certs
1021            let certs = rustls_native_certs::load_native_certs();
1022            for cert in certs.certs {
1023                let _ = root_cert_store.add(cert);
1024            }
1025        }
1026
1027        let client_certs: Vec<CertificateDer<'static>> =
1028            CertificateDer::pem_slice_iter(&config.client_cert_pem)
1029                .collect::<Result<Vec<_>, _>>()
1030                .map_err(|e| PgError::Connection(format!("Invalid client cert PEM: {}", e)))?;
1031        if client_certs.is_empty() {
1032            return Err(PgError::Connection(
1033                "No client certificates found in PEM".to_string(),
1034            ));
1035        }
1036
1037        let client_key = PrivateKeyDer::from_pem_slice(&config.client_key_pem)
1038            .map_err(|e| PgError::Connection(format!("Invalid client key PEM: {}", e)))?;
1039
1040        let tls_config = ClientConfig::builder()
1041            .with_root_certificates(root_cert_store)
1042            .with_client_auth_cert(client_certs, client_key)
1043            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
1044
1045        let connector = TlsConnector::from(Arc::new(tls_config));
1046        let server_name = ServerName::try_from(host.to_string())
1047            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
1048
1049        let tls_stream = connector
1050            .connect(server_name, tcp_stream)
1051            .await
1052            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
1053
1054        let mut conn = Self {
1055            stream: PgStream::Tls(Box::new(tls_stream)),
1056            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1057            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1058            sql_buf: BytesMut::with_capacity(512),
1059            params_buf: Vec::with_capacity(16),
1060            prepared_statements: HashMap::new(),
1061            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1062            column_info_cache: HashMap::new(),
1063            process_id: 0,
1064            secret_key: 0,
1065            cancel_key_bytes: Vec::new(),
1066            requested_protocol_minor: protocol_minor,
1067            negotiated_protocol_minor: protocol_minor,
1068            notifications: VecDeque::new(),
1069            replication_stream_active: false,
1070            replication_mode_enabled,
1071            last_replication_wal_end: None,
1072            io_desynced: false,
1073            pending_statement_closes: Vec::new(),
1074            draining_statement_closes: false,
1075        };
1076
1077        conn.send(FrontendMessage::Startup {
1078            user: user.to_string(),
1079            database: database.to_string(),
1080            protocol_version: protocol_version_from_minor(protocol_minor),
1081            startup_params,
1082        })
1083        .await?;
1084
1085        conn.handle_startup(
1086            user,
1087            password,
1088            auth_settings,
1089            gss_token_provider,
1090            gss_token_provider_ex,
1091        )
1092        .await?;
1093
1094        Ok(conn)
1095    }
1096
1097    /// Connect to PostgreSQL server via Unix domain socket.
1098    #[cfg(unix)]
1099    pub async fn connect_unix(
1100        socket_path: &str,
1101        user: &str,
1102        database: &str,
1103        password: Option<&str>,
1104    ) -> PgResult<Self> {
1105        let default_minor = Self::default_protocol_minor();
1106        let first =
1107            Self::connect_unix_with_protocol(socket_path, user, database, password, default_minor)
1108                .await;
1109        if let Err(err) = &first
1110            && default_minor > 0
1111            && is_explicit_protocol_version_rejection(err)
1112        {
1113            let downgrade_minor = (PROTOCOL_VERSION_3_0 & 0xFFFF) as u16;
1114            return Self::connect_unix_with_protocol(
1115                socket_path,
1116                user,
1117                database,
1118                password,
1119                downgrade_minor,
1120            )
1121            .await;
1122        }
1123        first
1124    }
1125
1126    #[cfg(unix)]
1127    async fn connect_unix_with_protocol(
1128        socket_path: &str,
1129        user: &str,
1130        database: &str,
1131        password: Option<&str>,
1132        protocol_minor: u16,
1133    ) -> PgResult<Self> {
1134        use tokio::net::UnixStream;
1135
1136        let unix_stream = UnixStream::connect(socket_path).await?;
1137
1138        let mut conn = Self {
1139            stream: PgStream::Unix(unix_stream),
1140            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
1141            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
1142            sql_buf: BytesMut::with_capacity(512),
1143            params_buf: Vec::with_capacity(16),
1144            prepared_statements: HashMap::new(),
1145            stmt_cache: StatementCache::new(STMT_CACHE_CAPACITY),
1146            column_info_cache: HashMap::new(),
1147            process_id: 0,
1148            secret_key: 0,
1149            cancel_key_bytes: Vec::new(),
1150            requested_protocol_minor: protocol_minor,
1151            negotiated_protocol_minor: protocol_minor,
1152            notifications: VecDeque::new(),
1153            replication_stream_active: false,
1154            replication_mode_enabled: false,
1155            last_replication_wal_end: None,
1156            io_desynced: false,
1157            pending_statement_closes: Vec::new(),
1158            draining_statement_closes: false,
1159        };
1160
1161        conn.send(FrontendMessage::Startup {
1162            user: user.to_string(),
1163            database: database.to_string(),
1164            protocol_version: protocol_version_from_minor(protocol_minor),
1165            startup_params: Vec::new(),
1166        })
1167        .await?;
1168
1169        conn.handle_startup(user, password, AuthSettings::default(), None, None)
1170            .await?;
1171
1172        Ok(conn)
1173    }
1174}
1175
1176#[cfg(test)]
1177mod tests {
1178    use super::{is_explicit_protocol_version_rejection, protocol_version_from_minor};
1179    use crate::driver::PgError;
1180
1181    #[test]
1182    fn protocol_version_from_minor_encodes_major_3() {
1183        assert_eq!(protocol_version_from_minor(2), 196610);
1184        assert_eq!(protocol_version_from_minor(0), 196608);
1185    }
1186
1187    #[test]
1188    fn explicit_protocol_rejection_detection_is_case_insensitive() {
1189        let err = PgError::Connection("Unsupported frontend protocol 3.2".to_string());
1190        assert!(is_explicit_protocol_version_rejection(&err));
1191
1192        let err = PgError::Protocol("server: Protocol VERSION not supported".to_string());
1193        assert!(is_explicit_protocol_version_rejection(&err));
1194    }
1195
1196    #[test]
1197    fn explicit_protocol_rejection_does_not_match_unrelated_errors() {
1198        let err = PgError::Connection("connection reset by peer".to_string());
1199        assert!(!is_explicit_protocol_version_rejection(&err));
1200    }
1201}