Skip to main content

pgwire_replication/protocol/
framing.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4use crate::error::{PgWireError, Result};
5
/// Largest backend message payload accepted, in bytes (1 GiB).
///
/// `read_backend_message_into` rejects any message whose length field
/// implies a larger payload before touching its buffer, so a corrupted or
/// malformed length cannot request an arbitrarily large allocation.
pub const MAX_MESSAGE_SIZE: usize = 1 << 30;
9
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct BackendMessage {
12    pub tag: u8,
13    pub payload: Bytes, // payload excludes the 4-byte length field
14}
15
16impl BackendMessage {
17    /// Returns true if this is an ErrorResponse ('E')
18    #[inline]
19    pub fn is_error(&self) -> bool {
20        self.tag == b'E'
21    }
22
23    /// Returns true if this is a ReadyForQuery ('Z')
24    #[inline]
25    pub fn is_ready_for_query(&self) -> bool {
26        self.tag == b'Z'
27    }
28
29    /// Returns true if this is CopyBothResponse ('W')
30    #[inline]
31    pub fn is_copy_both_response(&self) -> bool {
32        self.tag == b'W'
33    }
34
35    /// Returns true if this is CopyData ('d')
36    #[inline]
37    pub fn is_copy_data(&self) -> bool {
38        self.tag == b'd'
39    }
40
41    /// Returns true if this is AuthenticationRequest ('R')
42    #[inline]
43    pub fn is_auth_request(&self) -> bool {
44        self.tag == b'R'
45    }
46}
47
48pub async fn read_backend_message<R: AsyncRead + Unpin>(rd: &mut R) -> Result<BackendMessage> {
49    read_backend_message_into(rd, &mut BytesMut::new()).await
50}
51
52/// Read a backend message, reusing `buf` to avoid per-message allocation.
53pub async fn read_backend_message_into<R: AsyncRead + Unpin>(
54    rd: &mut R,
55    buf: &mut BytesMut,
56) -> Result<BackendMessage> {
57    let mut hdr = [0u8; 5];
58    rd.read_exact(&mut hdr).await?;
59    let tag = hdr[0];
60    let len = i32::from_be_bytes([hdr[1], hdr[2], hdr[3], hdr[4]]);
61
62    if len < 4 {
63        return Err(PgWireError::Protocol(format!(
64            "invalid backend message length: {len}"
65        )));
66    }
67
68    let payload_len = (len - 4) as usize;
69
70    if payload_len > MAX_MESSAGE_SIZE {
71        return Err(PgWireError::Protocol(format!(
72            "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
73        )));
74    }
75
76    buf.clear();
77    buf.resize(payload_len, 0);
78    rd.read_exact(&mut buf[..]).await?;
79    Ok(BackendMessage {
80        tag,
81        payload: buf.split().freeze(),
82    })
83}
84
85pub async fn write_ssl_request<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
86    let mut buf = [0u8; 8];
87    buf[0..4].copy_from_slice(&(8i32).to_be_bytes());
88    buf[4..8].copy_from_slice(&(80877103i32).to_be_bytes());
89    wr.write_all(&buf).await?;
90    wr.flush().await?;
91    Ok(())
92}
93
94pub async fn write_startup_message<W: AsyncWrite + Unpin>(
95    wr: &mut W,
96    protocol_version: i32,
97    params: &[(&str, &str)],
98) -> Result<()> {
99    let mut buf = BytesMut::with_capacity(256);
100    buf.put_i32(0); // length placeholder
101    buf.put_i32(protocol_version);
102
103    for (k, v) in params {
104        buf.extend_from_slice(k.as_bytes());
105        buf.put_u8(0);
106        buf.extend_from_slice(v.as_bytes());
107        buf.put_u8(0);
108    }
109    buf.put_u8(0); // terminator
110
111    let len = buf.len() as i32;
112    buf[0..4].copy_from_slice(&len.to_be_bytes());
113
114    wr.write_all(&buf).await?;
115    wr.flush().await?;
116    Ok(())
117}
118
119pub async fn write_query<W: AsyncWrite + Unpin>(wr: &mut W, sql: &str) -> Result<()> {
120    let mut buf = BytesMut::with_capacity(sql.len() + 64);
121    buf.put_u8(b'Q');
122    buf.put_i32(0);
123    buf.extend_from_slice(sql.as_bytes());
124    buf.put_u8(0);
125
126    let len = (buf.len() - 1) as i32;
127    buf[1..5].copy_from_slice(&len.to_be_bytes());
128
129    wr.write_all(&buf).await?;
130    wr.flush().await?;
131    Ok(())
132}
133
134pub async fn write_password_message<W: AsyncWrite + Unpin>(
135    wr: &mut W,
136    payload: &[u8],
137) -> Result<()> {
138    let mut buf = BytesMut::with_capacity(payload.len() + 16);
139    buf.put_u8(b'p');
140    buf.put_i32(0);
141    buf.extend_from_slice(payload);
142
143    let len = (buf.len() - 1) as i32;
144    buf[1..5].copy_from_slice(&len.to_be_bytes());
145
146    wr.write_all(&buf).await?;
147    wr.flush().await?;
148    Ok(())
149}
150
151pub async fn write_copy_data<W: AsyncWrite + Unpin>(wr: &mut W, payload: &[u8]) -> Result<()> {
152    let mut buf = BytesMut::with_capacity(payload.len() + 16);
153    buf.put_u8(b'd');
154    buf.put_i32(0);
155    buf.extend_from_slice(payload);
156
157    let len = (buf.len() - 1) as i32;
158    buf[1..5].copy_from_slice(&len.to_be_bytes());
159
160    wr.write_all(&buf).await?;
161    wr.flush().await?;
162    Ok(())
163}
164
165pub async fn write_copy_done<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
166    let mut buf = BytesMut::with_capacity(5);
167    buf.put_u8(b'c'); // CopyDone
168    buf.put_i32(4); // length includes itself; CopyDone has no payload
169    wr.write_all(&buf).await?;
170    wr.flush().await?;
171    Ok(())
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[tokio::test]
    async fn read_backend_message_parses_valid_message() {
        // Tag 'Z' (ReadyForQuery), length=5 (4 + 1 byte payload), payload='I' (idle)
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
        assert!(msg.is_ready_for_query());
    }

    #[tokio::test]
    async fn read_backend_message_handles_empty_payload() {
        // Tag 'N' (NoticeResponse placeholder), length=4 (no payload)
        let data = [b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'N');
        assert!(msg.payload.is_empty());
    }

    #[tokio::test]
    async fn read_backend_message_rejects_invalid_length() {
        // length < 4 is invalid
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_negative_length() {
        // 0xFFFFFFFF decodes to -1 as a big-endian i32.
        let data = [b'Z', 0xFF, 0xFF, 0xFF, 0xFF];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_oversized_message() {
        // length = MAX_MESSAGE_SIZE + 5 (over limit)
        let huge_len = (MAX_MESSAGE_SIZE as i32) + 5;
        let data = [
            b'Z',
            (huge_len >> 24) as u8,
            (huge_len >> 16) as u8,
            (huge_len >> 8) as u8,
            huge_len as u8,
        ];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[tokio::test]
    async fn read_backend_message_errors_on_truncated_header() {
        // Only 3 of the 5 header bytes are available.
        let data = [b'Z', 0, 0];
        let mut cursor = Cursor::new(&data[..]);

        assert!(read_backend_message(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn read_backend_message_errors_on_truncated_payload() {
        // Header promises 3 payload bytes (length 7) but only 1 follows.
        let data = [b'd', 0, 0, 0, 7, b'x'];
        let mut cursor = Cursor::new(&data[..]);

        assert!(read_backend_message(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn read_backend_message_into_reads_consecutive_messages() {
        // Two back-to-back frames: ReadyForQuery('I'), then CopyData("abc").
        let data = [
            b'Z', 0, 0, 0, 5, b'I', //
            b'd', 0, 0, 0, 7, b'a', b'b', b'c',
        ];
        let mut cursor = Cursor::new(&data[..]);
        let mut buf = BytesMut::new();

        let first = read_backend_message_into(&mut cursor, &mut buf)
            .await
            .unwrap();
        let second = read_backend_message_into(&mut cursor, &mut buf)
            .await
            .unwrap();

        assert_eq!(first.tag, b'Z');
        assert_eq!(&first.payload[..], b"I");
        assert_eq!(second.tag, b'd');
        assert_eq!(&second.payload[..], b"abc");
        // Each payload is split off, leaving the reusable buffer empty.
        assert!(buf.is_empty());
    }

    #[tokio::test]
    async fn write_ssl_request_produces_valid_bytes() {
        let mut buf = Vec::new();
        write_ssl_request(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 8);
        // length = 8
        assert_eq!(&buf[0..4], &8i32.to_be_bytes());
        // SSL request code = 80877103
        assert_eq!(&buf[4..8], &80877103i32.to_be_bytes());
    }

    #[tokio::test]
    async fn write_startup_message_includes_params() {
        let mut buf = Vec::new();
        let params = [("user", "postgres"), ("database", "test")];
        write_startup_message(&mut buf, 196608, &params)
            .await
            .unwrap();

        // Should contain the parameter strings
        let s = String::from_utf8_lossy(&buf);
        assert!(s.contains("user"));
        assert!(s.contains("postgres"));
        assert!(s.contains("database"));
        assert!(s.contains("test"));

        // Length field should be at start
        let len = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, buf.len());
    }

    #[tokio::test]
    async fn write_startup_message_has_exact_layout() {
        let mut buf = Vec::new();
        write_startup_message(&mut buf, 196608, &[("user", "pg")])
            .await
            .unwrap();

        let mut expected = Vec::new();
        expected.extend_from_slice(&19i32.to_be_bytes()); // 4 + 4 + "user\0" + "pg\0" + "\0"
        expected.extend_from_slice(&196608i32.to_be_bytes());
        expected.extend_from_slice(b"user\0pg\0\0");
        assert_eq!(buf, expected);
    }

    #[tokio::test]
    async fn write_query_produces_valid_message() {
        let mut buf = Vec::new();
        write_query(&mut buf, "SELECT 1").await.unwrap();

        // Should start with 'Q'
        assert_eq!(buf[0], b'Q');

        // Length should be correct (excludes tag byte)
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);

        // Should contain the SQL
        assert!(buf[5..].starts_with(b"SELECT 1"));

        // Should be null-terminated
        assert_eq!(buf[buf.len() - 1], 0);
    }

    #[tokio::test]
    async fn write_password_message_produces_valid_message() {
        let mut buf = Vec::new();
        write_password_message(&mut buf, b"secret").await.unwrap();

        assert_eq!(buf[0], b'p');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"secret");
    }

    #[tokio::test]
    async fn write_copy_data_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"payload").await.unwrap();

        assert_eq!(buf[0], b'd');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"payload");
    }

    #[tokio::test]
    async fn write_copy_data_handles_empty_payload() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"").await.unwrap();

        assert_eq!(buf, [b'd', 0, 0, 0, 4]);
    }

    #[tokio::test]
    async fn write_copy_done_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_done(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], b'c');
        // Length = 4 (just the length field itself, no payload)
        assert_eq!(&buf[1..5], &4i32.to_be_bytes());
    }

    #[test]
    fn backend_message_helper_methods() {
        let error = BackendMessage {
            tag: b'E',
            payload: Bytes::new(),
        };
        assert!(error.is_error());
        assert!(!error.is_ready_for_query());

        let ready = BackendMessage {
            tag: b'Z',
            payload: Bytes::from_static(b"I"),
        };
        assert!(ready.is_ready_for_query());
        assert!(!ready.is_error());

        let copy_both = BackendMessage {
            tag: b'W',
            payload: Bytes::new(),
        };
        assert!(copy_both.is_copy_both_response());

        let copy_data = BackendMessage {
            tag: b'd',
            payload: Bytes::new(),
        };
        assert!(copy_data.is_copy_data());

        let auth = BackendMessage {
            tag: b'R',
            payload: Bytes::new(),
        };
        assert!(auth.is_auth_request());
    }
}