Skip to main content

pg_wired/
connection.rs

1//! Synchronous-style PostgreSQL connection driving the v3 wire protocol on
2//! a single owned [`tokio::net::TcpStream`]. Use [`crate::AsyncConn`] for the
3//! shared, multi-task connection wrapper most callers want.
4
5use bytes::BytesMut;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8
9use crate::error::PgWireError;
10use crate::protocol::backend;
11use crate::protocol::frontend;
12use crate::protocol::types::{BackendMsg, FrontendMsg, RawRow};
13use crate::scram::ScramClient;
14use crate::tls::{MaybeTlsStream, TlsMode};
15
16/// Raw PostgreSQL wire connection.
17/// Handles TCP I/O, buffered reading, and authentication.
18pub struct WireConn {
19    pub(crate) stream: MaybeTlsStream,
20    recv_buf: BytesMut,
21    pub(crate) pid: i32,
22    pub(crate) secret: i32,
23    /// Server parameters reported via `ParameterStatus` during startup.
24    ///
25    /// PostgreSQL reports a small set of GUCs it considers useful to clients:
26    /// typically `server_version`, `server_encoding`, `client_encoding`,
27    /// `application_name`, `is_superuser`, `session_authorization`,
28    /// `DateStyle`, `IntervalStyle`, `TimeZone`, `integer_datetimes`, and
29    /// `standard_conforming_strings`. The exact set depends on the server
30    /// version and its `GUC_REPORT` configuration.
31    ///
32    /// This map is populated once on connect and is not kept in sync when
33    /// the server emits later `ParameterStatus` messages (for example after
34    /// `SET TimeZone = ...`). Treat these values as startup defaults, not a
35    /// live view of the session state.
36    pub params: std::collections::HashMap<String, String>,
37    /// Authentication mechanism the server selected during startup.
38    ///
39    /// One of `"trust"`, `"cleartext"`, `"md5"`, `"SCRAM-SHA-256"`, or
40    /// `"SCRAM-SHA-256-PLUS"`. Useful for tests that need to verify channel
41    /// binding actually fired and for operational logging.
42    pub(crate) auth_mechanism: &'static str,
43}
44
45impl WireConn {
46    /// Backend process ID assigned by the server. Useful for logging and for
47    /// building a cancel token. The secret key that pairs with this PID is
48    /// intentionally not exposed; use `cancel_token()` to obtain a token that
49    /// can send a cancel request.
50    pub fn pid(&self) -> i32 {
51        self.pid
52    }
53
54    /// Authentication mechanism the server selected during startup.
55    ///
56    /// Returns one of `"trust"`, `"cleartext"`, `"md5"`, `"SCRAM-SHA-256"`,
57    /// or `"SCRAM-SHA-256-PLUS"`. `"SCRAM-SHA-256-PLUS"` confirms that
58    /// `tls-server-end-point` channel binding was negotiated.
59    pub fn auth_mechanism(&self) -> &'static str {
60        self.auth_mechanism
61    }
62}
63
64impl std::fmt::Debug for WireConn {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("WireConn")
67            .field("pid", &self.pid)
68            .field("params", &self.params)
69            .finish_non_exhaustive()
70    }
71}
72
/// Initial capacity, in bytes, of a connection's receive buffer (32 KiB).
const RECV_BUF_SIZE: usize = 32 << 10;
74
75impl WireConn {
76    /// Choose the best SCRAM mechanism based on TLS state and server support.
77    /// Returns (ChannelBinding, mechanism_name_bytes).
78    #[allow(clippy::result_large_err)]
79    fn choose_scram_mechanism(
80        &self,
81        mechanisms: &[String],
82    ) -> Result<(crate::scram::ChannelBinding, &'static [u8], &'static str), PgWireError> {
83        // If TLS is active and server supports SCRAM-SHA-256-PLUS, use channel binding.
84        #[cfg(feature = "tls")]
85        if let MaybeTlsStream::Tls(ref tls) = self.stream {
86            if mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS") {
87                if let Some(certs) = tls.get_ref().1.peer_certificates() {
88                    if let Some(cert) = certs.first() {
89                        let hash = crate::cert_hash::cert_signature_hash(cert.as_ref());
90                        return Ok((
91                            crate::scram::ChannelBinding::TlsServerEndPoint(hash),
92                            b"SCRAM-SHA-256-PLUS",
93                            "SCRAM-SHA-256-PLUS",
94                        ));
95                    }
96                }
97            }
98        }
99
100        // Fall back to plain SCRAM-SHA-256.
101        if mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
102            Ok((
103                crate::scram::ChannelBinding::None,
104                b"SCRAM-SHA-256",
105                "SCRAM-SHA-256",
106            ))
107        } else {
108            Err(PgWireError::Protocol(format!(
109                "No supported SASL mechanism: {:?}",
110                mechanisms
111            )))
112        }
113    }
114
115    /// Check if the connection has unconsumed data in the receive buffer.
116    pub fn has_pending_data(&self) -> bool {
117        !self.recv_buf.is_empty()
118    }
119
120    /// Connect to PostgreSQL and perform authentication.
121    pub async fn connect(
122        addr: &str,
123        user: &str,
124        password: &str,
125        database: &str,
126    ) -> Result<Self, PgWireError> {
127        Self::connect_with_options(addr, user, password, database, &[], TlsMode::default()).await
128    }
129
130    /// Connect with additional startup parameters.
131    ///
132    /// Parameters are sent in the startup message and appear in `pg_stat_activity`.
133    /// Common parameters: `application_name`, `client_encoding`, `options`.
134    ///
135    /// ```no_run
136    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
137    /// use pg_wired::WireConn;
138    /// let _conn = WireConn::connect_with_params(
139    ///     "127.0.0.1:5432", "user", "pass", "mydb",
140    ///     &[("application_name", "my-service")],
141    /// ).await?;
142    /// # Ok(()) }
143    /// ```
144    pub async fn connect_with_params(
145        addr: &str,
146        user: &str,
147        password: &str,
148        database: &str,
149        startup_params: &[(&str, &str)],
150    ) -> Result<Self, PgWireError> {
151        Self::connect_with_options(
152            addr,
153            user,
154            password,
155            database,
156            startup_params,
157            TlsMode::default(),
158        )
159        .await
160    }
161
162    /// Connect with startup parameters and an explicit TLS mode.
163    ///
164    /// Uses the system root trust store (`webpki-roots`) for certificate
165    /// verification. To override the trust store, supply a client
166    /// certificate, or otherwise customize TLS, use
167    /// [`Self::connect_with_tls_config`].
168    pub async fn connect_with_options(
169        addr: &str,
170        user: &str,
171        password: &str,
172        database: &str,
173        startup_params: &[(&str, &str)],
174        tls_mode: TlsMode,
175    ) -> Result<Self, PgWireError> {
176        #[cfg(feature = "tls")]
177        {
178            Self::connect_with_tls_config(
179                addr,
180                user,
181                password,
182                database,
183                startup_params,
184                tls_mode,
185                &crate::tls::TlsConfig::default(),
186            )
187            .await
188        }
189        #[cfg(not(feature = "tls"))]
190        {
191            Self::connect_inner(addr, user, password, database, startup_params, tls_mode).await
192        }
193    }
194
195    /// Connect with startup parameters, an explicit TLS mode, and a custom
196    /// TLS configuration (custom trust roots and/or a client certificate).
197    ///
198    /// Only available when the `tls` feature is enabled.
199    #[cfg(feature = "tls")]
200    pub async fn connect_with_tls_config(
201        addr: &str,
202        user: &str,
203        password: &str,
204        database: &str,
205        startup_params: &[(&str, &str)],
206        tls_mode: TlsMode,
207        tls_config: &crate::tls::TlsConfig,
208    ) -> Result<Self, PgWireError> {
209        Self::connect_inner(
210            addr,
211            user,
212            password,
213            database,
214            startup_params,
215            tls_mode,
216            tls_config,
217        )
218        .await
219    }
220
221    #[cfg(feature = "tls")]
222    async fn connect_inner(
223        addr: &str,
224        user: &str,
225        password: &str,
226        database: &str,
227        startup_params: &[(&str, &str)],
228        tls_mode: TlsMode,
229        tls_config: &crate::tls::TlsConfig,
230    ) -> Result<Self, PgWireError> {
231        let stream = TcpStream::connect(addr).await?;
232        stream.set_nodelay(true)?;
233
234        let socket = socket2::SockRef::from(&stream);
235        let keepalive = socket2::TcpKeepalive::new()
236            .with_time(std::time::Duration::from_secs(60))
237            .with_interval(std::time::Duration::from_secs(15));
238        let _ = socket.set_tcp_keepalive(&keepalive);
239
240        let hostname = parse_hostname(addr);
241        let stream =
242            crate::tls::negotiate_tls_with_config(stream, &hostname, tls_config, tls_mode).await?;
243
244        Self::finish_startup(stream, user, password, database, startup_params).await
245    }
246
247    #[cfg(not(feature = "tls"))]
248    async fn connect_inner(
249        addr: &str,
250        user: &str,
251        password: &str,
252        database: &str,
253        startup_params: &[(&str, &str)],
254        tls_mode: TlsMode,
255    ) -> Result<Self, PgWireError> {
256        let stream = TcpStream::connect(addr).await?;
257        stream.set_nodelay(true)?;
258
259        let socket = socket2::SockRef::from(&stream);
260        let keepalive = socket2::TcpKeepalive::new()
261            .with_time(std::time::Duration::from_secs(60))
262            .with_interval(std::time::Duration::from_secs(15));
263        let _ = socket.set_tcp_keepalive(&keepalive);
264
265        if tls_mode == TlsMode::Require {
266            return Err(PgWireError::Protocol(
267                "sslmode=require but pg-wired was built without the `tls` feature".into(),
268            ));
269        }
270        let stream = MaybeTlsStream::Plain(stream);
271
272        Self::finish_startup(stream, user, password, database, startup_params).await
273    }
274
275    async fn finish_startup(
276        stream: MaybeTlsStream,
277        user: &str,
278        password: &str,
279        database: &str,
280        startup_params: &[(&str, &str)],
281    ) -> Result<Self, PgWireError> {
282        let mut conn = WireConn {
283            stream,
284            recv_buf: BytesMut::with_capacity(RECV_BUF_SIZE),
285            pid: 0,
286            secret: 0,
287            params: std::collections::HashMap::new(),
288            // Default to "trust": if the server sends AuthenticationOk without
289            // a prior challenge, no real auth method ran.
290            auth_mechanism: "trust",
291        };
292
293        // Send startup message with optional extra parameters.
294        let mut buf = BytesMut::new();
295        frontend::encode_startup_with_params(user, database, startup_params, &mut buf);
296        conn.send_raw(&buf).await?;
297
298        // Authentication loop.
299        loop {
300            let msg = conn.recv_msg().await?;
301            match msg {
302                BackendMsg::AuthenticationOk => {}
303                BackendMsg::AuthenticationCleartextPassword => {
304                    conn.auth_mechanism = "cleartext";
305                    let mut buf = BytesMut::new();
306                    frontend::encode_password(password.as_bytes(), &mut buf);
307                    conn.send_raw(&buf).await?;
308                }
309                BackendMsg::AuthenticationMd5Password { salt } => {
310                    conn.auth_mechanism = "md5";
311                    let hash = frontend::md5_password(user, password, &salt);
312                    let mut buf = BytesMut::new();
313                    frontend::encode_password(&hash, &mut buf);
314                    conn.send_raw(&buf).await?;
315                }
316                BackendMsg::AuthenticationSASL { mechanisms } => {
317                    // Prefer SCRAM-SHA-256-PLUS (with channel binding) when TLS is active.
318                    let (cb, mechanism, name) = conn.choose_scram_mechanism(&mechanisms)?;
319                    conn.auth_mechanism = name;
320                    let (scram, client_first) = ScramClient::new(password, cb);
321                    let mut buf = BytesMut::new();
322                    frontend::encode_message(
323                        &FrontendMsg::SASLInitialResponse {
324                            mechanism,
325                            data: &client_first,
326                        },
327                        &mut buf,
328                    );
329                    conn.send_raw(&buf).await?;
330
331                    // Wait for server-first.
332                    let server_first = loop {
333                        match conn.recv_msg().await? {
334                            BackendMsg::AuthenticationSASLContinue { data } => break data,
335                            BackendMsg::ErrorResponse { fields } => {
336                                return Err(PgWireError::Pg(fields));
337                            }
338                            _ => {}
339                        }
340                    };
341
342                    let client_final = scram
343                        .process_server_first(&server_first)
344                        .map_err(PgWireError::Protocol)?;
345                    let mut buf = BytesMut::new();
346                    frontend::encode_message(&FrontendMsg::SASLResponse(&client_final), &mut buf);
347                    conn.send_raw(&buf).await?;
348
349                    // Wait for server-final + AuthenticationOk.
350                    loop {
351                        match conn.recv_msg().await? {
352                            BackendMsg::AuthenticationSASLFinal { .. } => {}
353                            BackendMsg::AuthenticationOk => break,
354                            BackendMsg::ErrorResponse { fields } => {
355                                return Err(PgWireError::Pg(fields));
356                            }
357                            _ => {}
358                        }
359                    }
360                }
361                BackendMsg::ParameterStatus { name, value } => {
362                    tracing::debug!(name = %name, value = %value, "server parameter");
363                    conn.params.insert(name, value);
364                }
365                BackendMsg::BackendKeyData { pid, secret } => {
366                    conn.pid = pid;
367                    conn.secret = secret;
368                }
369                BackendMsg::ReadyForQuery { .. } => break,
370                BackendMsg::ErrorResponse { fields } => {
371                    return Err(PgWireError::Pg(fields));
372                }
373                BackendMsg::NoticeResponse { .. } => {}
374                other => {
375                    tracing::debug!("Startup: ignoring {:?}", other);
376                }
377            }
378        }
379
380        Ok(conn)
381    }
382
383    /// Send a raw buffer to the server (one write syscall).
384    pub async fn send_raw(&mut self, buf: &[u8]) -> Result<(), PgWireError> {
385        self.stream.write_all(buf).await?;
386        Ok(())
387    }
388
389    /// Read one complete backend message from the connection.
390    /// Uses an internal buffer to minimize read() syscalls.
391    pub async fn recv_msg(&mut self) -> Result<BackendMsg, PgWireError> {
392        loop {
393            // Try to parse a message from the buffer.
394            if let Some(msg) =
395                backend::parse_message(&mut self.recv_buf).map_err(PgWireError::Protocol)?
396            {
397                return Ok(msg);
398            }
399
400            // Not enough data — read more from the socket.
401            let n = self.stream.read_buf(&mut self.recv_buf).await?;
402            if n == 0 {
403                // EOF — try to parse any remaining buffered data before giving up.
404                if let Some(msg) =
405                    backend::parse_message(&mut self.recv_buf).map_err(PgWireError::Protocol)?
406                {
407                    return Ok(msg);
408                }
409                return Err(PgWireError::ConnectionClosed);
410            }
411        }
412    }
413
414    /// Receive messages until ReadyForQuery, collecting DataRows.
415    /// Returns (rows, command_tag).
416    pub async fn collect_rows(&mut self) -> Result<(Vec<RawRow>, String), PgWireError> {
417        let mut rows = Vec::new();
418        let mut tag = String::new();
419
420        loop {
421            let msg = self.recv_msg().await?;
422            match msg {
423                BackendMsg::DataRow(row) => {
424                    tracing::trace!("collect_rows: DataRow with {} cols", row.len());
425                    rows.push(row);
426                }
427                BackendMsg::CommandComplete { tag: t } => tag = t,
428                BackendMsg::ReadyForQuery { .. } => return Ok((rows, tag)),
429                BackendMsg::ParseComplete | BackendMsg::BindComplete | BackendMsg::NoData => {}
430                BackendMsg::RowDescription { .. } => {}
431                BackendMsg::ErrorResponse { fields } => {
432                    // Drain until ReadyForQuery.
433                    self.drain_until_ready().await?;
434                    return Err(PgWireError::Pg(fields));
435                }
436                BackendMsg::NoticeResponse { .. } => {}
437                BackendMsg::EmptyQueryResponse => {}
438                _ => {}
439            }
440        }
441    }
442
443    /// Describe a SQL statement: sends Parse + Describe Statement + Sync,
444    /// returns (parameter type OIDs, column field descriptions).
445    /// Used by compile-time query checking macros.
446    pub async fn describe_statement(
447        &mut self,
448        sql: &str,
449    ) -> Result<(Vec<u32>, Vec<crate::protocol::types::FieldDescription>), PgWireError> {
450        use crate::protocol::frontend;
451        use crate::protocol::types::FrontendMsg;
452        let mut buf = bytes::BytesMut::with_capacity(256);
453
454        // Parse (unnamed statement).
455        frontend::encode_message(
456            &FrontendMsg::Parse {
457                name: b"",
458                sql: sql.as_bytes(),
459                param_oids: &[],
460            },
461            &mut buf,
462        );
463        // Describe statement.
464        frontend::encode_message(
465            &FrontendMsg::Describe {
466                kind: b'S',
467                name: b"",
468            },
469            &mut buf,
470        );
471        // Sync.
472        frontend::encode_message(&FrontendMsg::Sync, &mut buf);
473
474        self.send_raw(&buf).await?;
475
476        let mut param_oids = Vec::new();
477        let mut fields = Vec::new();
478
479        loop {
480            let msg = self.recv_msg().await?;
481            match msg {
482                BackendMsg::ParseComplete => {}
483                BackendMsg::ParameterDescription { type_oids } => {
484                    param_oids = type_oids;
485                }
486                BackendMsg::RowDescription { fields: f } => {
487                    fields = f;
488                }
489                BackendMsg::NoData => {} // query returns no rows
490                BackendMsg::ReadyForQuery { .. } => {
491                    return Ok((param_oids, fields));
492                }
493                BackendMsg::ErrorResponse { fields } => {
494                    self.drain_until_ready().await?;
495                    return Err(PgWireError::Pg(fields));
496                }
497                _ => {}
498            }
499        }
500    }
501
502    /// Drain messages until ReadyForQuery (error recovery).
503    /// Also attempts to parse any remaining data in the receive buffer
504    /// before declaring the connection closed.
505    pub async fn drain_until_ready(&mut self) -> Result<(), PgWireError> {
506        loop {
507            let msg = self.recv_msg().await?;
508            if matches!(msg, BackendMsg::ReadyForQuery { .. }) {
509                return Ok(());
510            }
511            // ErrorResponse inside a simple query — absorb it, keep draining.
512            if let BackendMsg::ErrorResponse { ref fields } = msg {
513                tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
514            }
515        }
516    }
517}
518
/// Returns the host portion of a `host:port` address string.
///
/// A bracketed IPv6 literal is unwrapped, so `"[::1]:5432"` gives `"::1"`.
/// Anything else is cut at the first `:`, and a string with no `:` comes back
/// unchanged. Examples: `"localhost:5432"` gives `"localhost"`, and `"host"`
/// gives `"host"`.
#[cfg(any(feature = "tls", test))]
fn parse_hostname(addr: &str) -> String {
    let bracketed = addr
        .strip_prefix('[')
        .and_then(|rest| rest.find(']').map(|close| &rest[..close]));

    let host = match bracketed {
        Some(inner) => inner,
        None => addr.split_once(':').map_or(addr, |(head, _)| head),
    };
    host.to_owned()
}
532
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_hostname_ipv4() {
        assert_eq!(parse_hostname("127.0.0.1:5432"), "127.0.0.1");
    }

    #[test]
    fn test_parse_hostname_ipv4_no_port() {
        assert_eq!(parse_hostname("10.0.0.1"), "10.0.0.1");
    }

    #[test]
    fn test_parse_hostname_name() {
        assert_eq!(parse_hostname("localhost:5432"), "localhost");
    }

    #[test]
    fn test_parse_hostname_fqdn() {
        assert_eq!(parse_hostname("db.example.com:6543"), "db.example.com");
    }

    #[test]
    fn test_parse_hostname_ipv6() {
        assert_eq!(parse_hostname("[::1]:5432"), "::1");
    }

    #[test]
    fn test_parse_hostname_ipv6_full() {
        assert_eq!(parse_hostname("[2001:db8::1]:5432"), "2001:db8::1");
    }

    #[test]
    fn test_parse_hostname_ipv6_no_port() {
        // Brackets are stripped even when no port follows them.
        assert_eq!(parse_hostname("[::1]"), "::1");
    }

    #[test]
    fn test_parse_hostname_no_port() {
        assert_eq!(parse_hostname("myhost"), "myhost");
    }
}