Skip to main content

pgwire_replication/protocol/
framing.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use std::io;
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use crate::error::{PgWireError, Result};
6
/// Upper bound on a backend message payload, in bytes (1 GiB).
///
/// A length field above this limit is treated as malformed, so that a bogus
/// length cannot be used to exhaust memory. Real messages are far smaller.
pub const MAX_MESSAGE_SIZE: usize = 1 << 30;
10
/// A single framed message received from the backend: a one-byte type tag
/// plus the message body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackendMessage {
    /// Message type byte (for example `b'E'`, `b'Z'`, `b'd'`).
    pub tag: u8,
    /// Message body.
    pub payload: Bytes, // payload excludes the 4-byte length field
}
16
17impl BackendMessage {
18    /// Returns true if this is an ErrorResponse ('E')
19    #[inline]
20    pub fn is_error(&self) -> bool {
21        self.tag == b'E'
22    }
23
24    /// Returns true if this is a ReadyForQuery ('Z')
25    #[inline]
26    pub fn is_ready_for_query(&self) -> bool {
27        self.tag == b'Z'
28    }
29
30    /// Returns true if this is CopyBothResponse ('W')
31    #[inline]
32    pub fn is_copy_both_response(&self) -> bool {
33        self.tag == b'W'
34    }
35
36    /// Returns true if this is CopyData ('d')
37    #[inline]
38    pub fn is_copy_data(&self) -> bool {
39        self.tag == b'd'
40    }
41
42    /// Returns true if this is AuthenticationRequest ('R')
43    #[inline]
44    pub fn is_auth_request(&self) -> bool {
45        self.tag == b'R'
46    }
47}
48
49pub async fn read_backend_message<R: AsyncRead + Unpin>(rd: &mut R) -> Result<BackendMessage> {
50    let mut reader = MessageReader::new();
51    reader.read(rd).await
52}
53
/// Cancellation-safe backend message reader.
///
/// PostgreSQL backend messages span multiple `read` operations (5-byte header,
/// then a variable payload). A naive implementation using `read_exact` is
/// **not** cancellation-safe: if the future is dropped between reads (e.g. by
/// `tokio::select!` or `tokio::time::timeout`), bytes already pulled from the
/// underlying stream are lost and the next read mis-parses the wire stream.
///
/// `MessageReader` externalizes the partial-read state so it survives across
/// dropped futures. Each call to [`read`](Self::read) uses one-shot
/// `AsyncReadExt::read` (which **is** cancel-safe) and accumulates progress
/// on `self`. If the returned future is dropped, no bytes are lost; the next
/// invocation resumes from where the previous one left off.
pub struct MessageReader {
    /// Raw header bytes: 1-byte tag followed by the 4-byte big-endian length.
    hdr: [u8; 5],
    /// Number of bytes of `hdr` received so far for the current message.
    hdr_filled: usize,
    /// Payload buffer, sized to the declared payload length once the header
    /// has been parsed.
    payload: BytesMut,
    /// Number of bytes of `payload` received so far for the current message.
    payload_filled: usize,
    /// `Some` once the header has been fully read and parsed; reset to
    /// `None` after each completed message.
    payload_len: Option<usize>,
    /// Tag byte copied out of `hdr` when the header is parsed.
    tag: u8,
}
77
78impl MessageReader {
79    pub fn new() -> Self {
80        Self::with_capacity(4096)
81    }
82
83    pub fn with_capacity(capacity: usize) -> Self {
84        Self {
85            hdr: [0u8; 5],
86            hdr_filled: 0,
87            payload: BytesMut::with_capacity(capacity),
88            payload_filled: 0,
89            payload_len: None,
90            tag: 0,
91        }
92    }
93
94    /// Read the next complete backend message.
95    ///
96    /// Cancellation-safe: dropping the returned future preserves all progress
97    /// so far on `self`. Re-call to resume.
98    pub async fn read<R: AsyncRead + Unpin>(&mut self, rd: &mut R) -> Result<BackendMessage> {
99        // Phase 1: fill the 5-byte header
100        while self.hdr_filled < 5 {
101            let n = rd.read(&mut self.hdr[self.hdr_filled..]).await?;
102            if n == 0 {
103                return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new(
104                    io::ErrorKind::UnexpectedEof,
105                    "EOF while reading backend message header",
106                ))));
107            }
108            self.hdr_filled += n;
109        }
110
111        // Phase 2: parse the header (idempotent — runs once per message)
112        if self.payload_len.is_none() {
113            let len = i32::from_be_bytes([self.hdr[1], self.hdr[2], self.hdr[3], self.hdr[4]]);
114
115            if len < 4 {
116                // Reset so the reader is reusable after a protocol error is
117                // surfaced (callers typically tear down on this anyway).
118                self.hdr_filled = 0;
119                return Err(PgWireError::Protocol(format!(
120                    "invalid backend message length: {len}"
121                )));
122            }
123
124            let payload_len = (len - 4) as usize;
125
126            if payload_len > MAX_MESSAGE_SIZE {
127                self.hdr_filled = 0;
128                return Err(PgWireError::Protocol(format!(
129                    "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
130                )));
131            }
132
133            self.tag = self.hdr[0];
134            self.payload.clear();
135            self.payload.resize(payload_len, 0);
136            self.payload_filled = 0;
137            self.payload_len = Some(payload_len);
138        }
139
140        let payload_len = self.payload_len.unwrap();
141
142        // Phase 3: fill the payload
143        while self.payload_filled < payload_len {
144            let n = rd.read(&mut self.payload[self.payload_filled..]).await?;
145            if n == 0 {
146                return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new(
147                    io::ErrorKind::UnexpectedEof,
148                    "EOF while reading backend message payload",
149                ))));
150            }
151            self.payload_filled += n;
152        }
153
154        // Phase 4: take payload, reset state for next message
155        let payload = self.payload.split().freeze();
156        let tag = self.tag;
157        self.hdr_filled = 0;
158        self.payload_len = None;
159        self.payload_filled = 0;
160
161        Ok(BackendMessage { tag, payload })
162    }
163}
164
165impl Default for MessageReader {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171/// Read a single backend message, reusing the provided buffer.
172///
173/// **Not** cancellation-safe — see [`MessageReader`] for a cancel-safe
174/// alternative used in the streaming loop.
175pub async fn read_backend_message_into<R: AsyncRead + Unpin>(
176    rd: &mut R,
177    buf: &mut BytesMut,
178) -> Result<BackendMessage> {
179    let mut hdr = [0u8; 5];
180    rd.read_exact(&mut hdr).await?;
181    let tag = hdr[0];
182    let len = i32::from_be_bytes([hdr[1], hdr[2], hdr[3], hdr[4]]);
183
184    if len < 4 {
185        return Err(PgWireError::Protocol(format!(
186            "invalid backend message length: {len}"
187        )));
188    }
189
190    let payload_len = (len - 4) as usize;
191
192    if payload_len > MAX_MESSAGE_SIZE {
193        return Err(PgWireError::Protocol(format!(
194            "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
195        )));
196    }
197
198    buf.clear();
199    buf.resize(payload_len, 0);
200    rd.read_exact(&mut buf[..]).await?;
201    Ok(BackendMessage {
202        tag,
203        payload: buf.split().freeze(),
204    })
205}
206
207pub async fn write_ssl_request<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
208    let mut buf = [0u8; 8];
209    buf[0..4].copy_from_slice(&(8i32).to_be_bytes());
210    buf[4..8].copy_from_slice(&(80877103i32).to_be_bytes());
211    wr.write_all(&buf).await?;
212    wr.flush().await?;
213    Ok(())
214}
215
216pub async fn write_startup_message<W: AsyncWrite + Unpin>(
217    wr: &mut W,
218    protocol_version: i32,
219    params: &[(&str, &str)],
220) -> Result<()> {
221    let mut buf = BytesMut::with_capacity(256);
222    buf.put_i32(0); // length placeholder
223    buf.put_i32(protocol_version);
224
225    for (k, v) in params {
226        buf.extend_from_slice(k.as_bytes());
227        buf.put_u8(0);
228        buf.extend_from_slice(v.as_bytes());
229        buf.put_u8(0);
230    }
231    buf.put_u8(0); // terminator
232
233    let len = buf.len() as i32;
234    buf[0..4].copy_from_slice(&len.to_be_bytes());
235
236    wr.write_all(&buf).await?;
237    wr.flush().await?;
238    Ok(())
239}
240
241pub async fn write_query<W: AsyncWrite + Unpin>(wr: &mut W, sql: &str) -> Result<()> {
242    let mut buf = BytesMut::with_capacity(sql.len() + 64);
243    buf.put_u8(b'Q');
244    buf.put_i32(0);
245    buf.extend_from_slice(sql.as_bytes());
246    buf.put_u8(0);
247
248    let len = (buf.len() - 1) as i32;
249    buf[1..5].copy_from_slice(&len.to_be_bytes());
250
251    wr.write_all(&buf).await?;
252    wr.flush().await?;
253    Ok(())
254}
255
256pub async fn write_password_message<W: AsyncWrite + Unpin>(
257    wr: &mut W,
258    payload: &[u8],
259) -> Result<()> {
260    let mut buf = BytesMut::with_capacity(payload.len() + 16);
261    buf.put_u8(b'p');
262    buf.put_i32(0);
263    buf.extend_from_slice(payload);
264
265    let len = (buf.len() - 1) as i32;
266    buf[1..5].copy_from_slice(&len.to_be_bytes());
267
268    wr.write_all(&buf).await?;
269    wr.flush().await?;
270    Ok(())
271}
272
273pub async fn write_copy_data<W: AsyncWrite + Unpin>(wr: &mut W, payload: &[u8]) -> Result<()> {
274    let mut buf = BytesMut::with_capacity(payload.len() + 16);
275    buf.put_u8(b'd');
276    buf.put_i32(0);
277    buf.extend_from_slice(payload);
278
279    let len = (buf.len() - 1) as i32;
280    buf[1..5].copy_from_slice(&len.to_be_bytes());
281
282    wr.write_all(&buf).await?;
283    wr.flush().await?;
284    Ok(())
285}
286
287pub async fn write_copy_done<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
288    let mut buf = BytesMut::with_capacity(5);
289    buf.put_u8(b'c'); // CopyDone
290    buf.put_i32(4); // length includes itself; CopyDone has no payload
291    wr.write_all(&buf).await?;
292    wr.flush().await?;
293    Ok(())
294}
295
// Unit tests for backend-message reading (one-shot and resumable) and for
// the byte layout produced by each frontend-message writer.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::AsyncWriteExt;

    /// A well-formed message with a one-byte payload is parsed into tag + payload.
    #[tokio::test]
    async fn read_backend_message_parses_valid_message() {
        // Tag 'Z' (ReadyForQuery), length=5 (4 + 1 byte payload), payload='I' (idle)
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
        assert!(msg.is_ready_for_query());
    }

    /// A length of exactly 4 means "no payload" and must yield an empty body.
    #[tokio::test]
    async fn read_backend_message_handles_empty_payload() {
        // Tag 'N' (NoticeResponse placeholder), length=4 (no payload)
        let data = [b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'N');
        assert!(msg.payload.is_empty());
    }

    /// A length field smaller than the field itself is a protocol error.
    #[tokio::test]
    async fn read_backend_message_rejects_invalid_length() {
        // length < 4 is invalid
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    /// `MessageReader::read` parses a message that is fully available up front.
    #[tokio::test]
    async fn message_reader_reads_complete_message() {
        // Tag 'Z' (ReadyForQuery), length=5 (4 + 1 byte payload), payload='I'
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let mut reader = MessageReader::new();
        let msg = reader.read(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
    }

    /// The same reader can be reused for consecutive messages on one stream.
    #[tokio::test]
    async fn message_reader_reads_back_to_back_messages() {
        // Two messages on one stream: ReadyForQuery + NoticeResponse w/ empty payload
        let data = [b'Z', 0, 0, 0, 5, b'I', b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let mut reader = MessageReader::new();

        let m1 = reader.read(&mut cursor).await.unwrap();
        assert_eq!(m1.tag, b'Z');
        assert_eq!(&m1.payload[..], b"I");

        let m2 = reader.read(&mut cursor).await.unwrap();
        assert_eq!(m2.tag, b'N');
        assert!(m2.payload.is_empty());
    }

    /// Regression test for issue #5: reading a backend message must be
    /// cancellation-safe so that `tokio::select!` / `tokio::time::timeout`
    /// dropping the read future mid-message does not corrupt the stream.
    ///
    /// With the old `read_backend_message_into`, dropping the future after
    /// 3 of 5 header bytes were consumed would lose those 3 bytes and
    /// re-parse the next bytes as a new header, producing a bogus length
    /// and a Protocol error (or worse, a silent desync).
    #[tokio::test]
    async fn message_reader_resumes_after_cancellation_mid_header() {
        let (mut writer, mut rd) = tokio::io::duplex(64);
        let mut reader = MessageReader::new();

        // Tag 'd' (CopyData), length = 8 (4 + 4-byte payload), payload b"abcd"
        let header = [b'd', 0, 0, 0, 8];
        let payload = b"abcd";

        // Deliver only the first 3 header bytes, then cancel.
        writer.write_all(&header[..3]).await.unwrap();

        let timed_out =
            tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await;
        assert!(
            timed_out.is_err(),
            "read must time out while waiting for remaining header bytes"
        );

        // Deliver the remaining bytes. A correct cancel-safe reader resumes
        // and returns the original message intact.
        writer.write_all(&header[3..]).await.unwrap();
        writer.write_all(payload).await.unwrap();

        let msg = reader.read(&mut rd).await.unwrap();
        assert_eq!(msg.tag, b'd');
        assert_eq!(&msg.payload[..], payload);
    }

    /// Ensures partial-payload cancellation also resumes correctly.
    #[tokio::test]
    async fn message_reader_resumes_after_cancellation_mid_payload() {
        let (mut writer, mut rd) = tokio::io::duplex(64);
        let mut reader = MessageReader::new();

        // 16-byte payload to ensure we can split it.
        let payload: [u8; 16] = std::array::from_fn(|i| i as u8);
        let len = (4 + payload.len()) as i32;
        let header = [
            b'd',
            (len >> 24) as u8,
            (len >> 16) as u8,
            (len >> 8) as u8,
            len as u8,
        ];

        // Full header + first 5 bytes of payload, then cancel.
        writer.write_all(&header).await.unwrap();
        writer.write_all(&payload[..5]).await.unwrap();

        let timed_out =
            tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await;
        assert!(
            timed_out.is_err(),
            "read must time out while waiting for remaining payload bytes"
        );

        // Deliver the rest.
        writer.write_all(&payload[5..]).await.unwrap();

        let msg = reader.read(&mut rd).await.unwrap();
        assert_eq!(msg.tag, b'd');
        assert_eq!(&msg.payload[..], &payload[..]);
    }

    /// `MessageReader::read` reports a length below 4 as a protocol error.
    #[tokio::test]
    async fn message_reader_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let mut reader = MessageReader::new();
        let err = reader.read(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    /// A declared payload larger than `MAX_MESSAGE_SIZE` is rejected from the
    /// header alone; no payload bytes are supplied.
    #[tokio::test]
    async fn read_backend_message_rejects_oversized_message() {
        // length = MAX_MESSAGE_SIZE + 5 (over limit)
        let huge_len = (MAX_MESSAGE_SIZE as i32) + 5;
        let data = [
            b'Z',
            (huge_len >> 24) as u8,
            (huge_len >> 16) as u8,
            (huge_len >> 8) as u8,
            huge_len as u8,
        ];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    /// SSLRequest is exactly 8 bytes: length 8, then the request code.
    #[tokio::test]
    async fn write_ssl_request_produces_valid_bytes() {
        let mut buf = Vec::new();
        write_ssl_request(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 8);
        // length = 8
        assert_eq!(&buf[0..4], &8i32.to_be_bytes());
        // SSL request code = 80877103
        assert_eq!(&buf[4..8], &80877103i32.to_be_bytes());
    }

    /// StartupMessage carries every parameter and a length equal to the
    /// whole message size.
    #[tokio::test]
    async fn write_startup_message_includes_params() {
        let mut buf = Vec::new();
        let params = [("user", "postgres"), ("database", "test")];
        write_startup_message(&mut buf, 196608, &params)
            .await
            .unwrap();

        // Should contain the parameter strings
        let s = String::from_utf8_lossy(&buf);
        assert!(s.contains("user"));
        assert!(s.contains("postgres"));
        assert!(s.contains("database"));
        assert!(s.contains("test"));

        // Length field should be at start
        let len = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, buf.len());
    }

    /// Query message: 'Q' tag, correct length, SQL text, trailing NUL.
    #[tokio::test]
    async fn write_query_produces_valid_message() {
        let mut buf = Vec::new();
        write_query(&mut buf, "SELECT 1").await.unwrap();

        // Should start with 'Q'
        assert_eq!(buf[0], b'Q');

        // Length should be correct (excludes tag byte)
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);

        // Should contain the SQL
        assert!(buf[5..].starts_with(b"SELECT 1"));

        // Should be null-terminated
        assert_eq!(buf[buf.len() - 1], 0);
    }

    /// Password message: 'p' tag, correct length, payload written verbatim.
    #[tokio::test]
    async fn write_password_message_produces_valid_message() {
        let mut buf = Vec::new();
        write_password_message(&mut buf, b"secret").await.unwrap();

        assert_eq!(buf[0], b'p');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"secret");
    }

    /// CopyData message: 'd' tag, correct length, payload written verbatim.
    #[tokio::test]
    async fn write_copy_data_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"payload").await.unwrap();

        assert_eq!(buf[0], b'd');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"payload");
    }

    /// CopyDone message: 'c' tag followed by a length of 4 and nothing else.
    #[tokio::test]
    async fn write_copy_done_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_done(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], b'c');
        // Length = 4 (just the length field itself, no payload)
        assert_eq!(&buf[1..5], &4i32.to_be_bytes());
    }

    /// Each tag predicate on `BackendMessage` matches its own tag byte.
    #[test]
    fn backend_message_helper_methods() {
        let error = BackendMessage {
            tag: b'E',
            payload: Bytes::new(),
        };
        assert!(error.is_error());
        assert!(!error.is_ready_for_query());

        let ready = BackendMessage {
            tag: b'Z',
            payload: Bytes::from_static(b"I"),
        };
        assert!(ready.is_ready_for_query());
        assert!(!ready.is_error());

        let copy_both = BackendMessage {
            tag: b'W',
            payload: Bytes::new(),
        };
        assert!(copy_both.is_copy_both_response());

        let copy_data = BackendMessage {
            tag: b'd',
            payload: Bytes::new(),
        };
        assert!(copy_data.is_copy_data());

        let auth = BackendMessage {
            tag: b'R',
            payload: Bytes::new(),
        };
        assert!(auth.is_auth_request());
    }
}