pgwire_replication/protocol/
framing.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4use crate::error::{PgWireError, Result};
5
/// Largest backend message payload accepted by [`read_backend_message`], in bytes (1 GiB).
///
/// The limit applies to the payload only (the tag byte and 4-byte length field are not
/// counted). Messages whose length field exceeds it are rejected as protocol errors, so a
/// malformed or hostile length cannot request an unbounded payload.
pub const MAX_MESSAGE_SIZE: usize = 1 << 30;
9
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct BackendMessage {
12    pub tag: u8,
13    pub payload: Bytes, // payload excludes the 4-byte length field
14}
15
16impl BackendMessage {
17    /// Returns true if this is an ErrorResponse ('E')
18    #[inline]
19    pub fn is_error(&self) -> bool {
20        self.tag == b'E'
21    }
22
23    /// Returns true if this is a ReadyForQuery ('Z')
24    #[inline]
25    pub fn is_ready_for_query(&self) -> bool {
26        self.tag == b'Z'
27    }
28
29    /// Returns true if this is CopyBothResponse ('W')
30    #[inline]
31    pub fn is_copy_both_response(&self) -> bool {
32        self.tag == b'W'
33    }
34
35    /// Returns true if this is CopyData ('d')
36    #[inline]
37    pub fn is_copy_data(&self) -> bool {
38        self.tag == b'd'
39    }
40
41    /// Returns true if this is AuthenticationRequest ('R')
42    #[inline]
43    pub fn is_auth_request(&self) -> bool {
44        self.tag == b'R'
45    }
46}
47
48pub async fn read_backend_message<R: AsyncRead + Unpin>(rd: &mut R) -> Result<BackendMessage> {
49    let mut hdr = [0u8; 5];
50    rd.read_exact(&mut hdr).await?;
51    let tag = hdr[0];
52    let len = i32::from_be_bytes([hdr[1], hdr[2], hdr[3], hdr[4]]);
53
54    if len < 4 {
55        return Err(PgWireError::Protocol(format!(
56            "invalid backend message length: {len}"
57        )));
58    }
59
60    let payload_len = (len - 4) as usize;
61
62    if payload_len > MAX_MESSAGE_SIZE {
63        return Err(PgWireError::Protocol(format!(
64            "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
65        )));
66    }
67
68    let mut buf = vec![0u8; payload_len];
69    rd.read_exact(&mut buf).await?;
70    Ok(BackendMessage {
71        tag,
72        payload: Bytes::from(buf),
73    })
74}
75
76pub async fn write_ssl_request<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
77    let mut buf = [0u8; 8];
78    buf[0..4].copy_from_slice(&(8i32).to_be_bytes());
79    buf[4..8].copy_from_slice(&(80877103i32).to_be_bytes());
80    wr.write_all(&buf).await?;
81    wr.flush().await?;
82    Ok(())
83}
84
85pub async fn write_startup_message<W: AsyncWrite + Unpin>(
86    wr: &mut W,
87    protocol_version: i32,
88    params: &[(&str, &str)],
89) -> Result<()> {
90    let mut buf = BytesMut::with_capacity(256);
91    buf.put_i32(0); // length placeholder
92    buf.put_i32(protocol_version);
93
94    for (k, v) in params {
95        buf.extend_from_slice(k.as_bytes());
96        buf.put_u8(0);
97        buf.extend_from_slice(v.as_bytes());
98        buf.put_u8(0);
99    }
100    buf.put_u8(0); // terminator
101
102    let len = buf.len() as i32;
103    buf[0..4].copy_from_slice(&len.to_be_bytes());
104
105    wr.write_all(&buf).await?;
106    wr.flush().await?;
107    Ok(())
108}
109
110pub async fn write_query<W: AsyncWrite + Unpin>(wr: &mut W, sql: &str) -> Result<()> {
111    let mut buf = BytesMut::with_capacity(sql.len() + 64);
112    buf.put_u8(b'Q');
113    buf.put_i32(0);
114    buf.extend_from_slice(sql.as_bytes());
115    buf.put_u8(0);
116
117    let len = (buf.len() - 1) as i32;
118    buf[1..5].copy_from_slice(&len.to_be_bytes());
119
120    wr.write_all(&buf).await?;
121    wr.flush().await?;
122    Ok(())
123}
124
125pub async fn write_password_message<W: AsyncWrite + Unpin>(
126    wr: &mut W,
127    payload: &[u8],
128) -> Result<()> {
129    let mut buf = BytesMut::with_capacity(payload.len() + 16);
130    buf.put_u8(b'p');
131    buf.put_i32(0);
132    buf.extend_from_slice(payload);
133
134    let len = (buf.len() - 1) as i32;
135    buf[1..5].copy_from_slice(&len.to_be_bytes());
136
137    wr.write_all(&buf).await?;
138    wr.flush().await?;
139    Ok(())
140}
141
142pub async fn write_copy_data<W: AsyncWrite + Unpin>(wr: &mut W, payload: &[u8]) -> Result<()> {
143    let mut buf = BytesMut::with_capacity(payload.len() + 16);
144    buf.put_u8(b'd');
145    buf.put_i32(0);
146    buf.extend_from_slice(payload);
147
148    let len = (buf.len() - 1) as i32;
149    buf[1..5].copy_from_slice(&len.to_be_bytes());
150
151    wr.write_all(&buf).await?;
152    wr.flush().await?;
153    Ok(())
154}
155
156pub async fn write_copy_done<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
157    let mut buf = BytesMut::with_capacity(5);
158    buf.put_u8(b'c'); // CopyDone
159    buf.put_i32(4); // length includes itself; CopyDone has no payload
160    wr.write_all(&buf).await?;
161    wr.flush().await?;
162    Ok(())
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[tokio::test]
    async fn read_backend_message_parses_valid_message() {
        // Tag 'Z' (ReadyForQuery), length=5 (4 + 1 byte payload), payload='I' (idle)
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
        assert!(msg.is_ready_for_query());
    }

    #[tokio::test]
    async fn read_backend_message_handles_empty_payload() {
        // Tag 'N' (NoticeResponse placeholder), length=4 (no payload)
        let data = [b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'N');
        assert!(msg.payload.is_empty());
    }

    #[tokio::test]
    async fn read_backend_message_rejects_invalid_length() {
        // length < 4 is invalid
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_negative_length() {
        // 0xFFFFFFFF decodes to -1 as a big-endian i32
        let data = [b'Z', 0xFF, 0xFF, 0xFF, 0xFF];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_oversized_message() {
        // length = MAX_MESSAGE_SIZE + 5 (over limit)
        let huge_len = (MAX_MESSAGE_SIZE as i32) + 5;
        let data = [
            b'Z',
            (huge_len >> 24) as u8,
            (huge_len >> 16) as u8,
            (huge_len >> 8) as u8,
            huge_len as u8,
        ];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[tokio::test]
    async fn read_backend_message_errors_on_truncated_payload() {
        // Header claims a 6-byte payload (length 10) but only 3 bytes follow.
        let data = [b'd', 0, 0, 0, 10, 1, 2, 3];
        let mut cursor = Cursor::new(&data[..]);

        assert!(read_backend_message(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn read_backend_message_does_not_consume_following_message() {
        // CopyData with a 2-byte payload, immediately followed by ReadyForQuery.
        let data = [b'd', 0, 0, 0, 6, 0xAB, 0xCD, b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let first = read_backend_message(&mut cursor).await.unwrap();
        assert!(first.is_copy_data());
        assert_eq!(&first.payload[..], &[0xAB_u8, 0xCD][..]);

        let second = read_backend_message(&mut cursor).await.unwrap();
        assert!(second.is_ready_for_query());
        assert_eq!(&second.payload[..], b"I");
    }

    #[tokio::test]
    async fn write_ssl_request_produces_valid_bytes() {
        let mut buf = Vec::new();
        write_ssl_request(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 8);
        // length = 8
        assert_eq!(&buf[0..4], &8i32.to_be_bytes());
        // SSL request code = 80877103
        assert_eq!(&buf[4..8], &80877103i32.to_be_bytes());
    }

    #[tokio::test]
    async fn write_startup_message_includes_params() {
        let mut buf = Vec::new();
        let params = [("user", "postgres"), ("database", "test")];
        write_startup_message(&mut buf, 196608, &params)
            .await
            .unwrap();

        // Should contain the parameter strings
        let s = String::from_utf8_lossy(&buf);
        assert!(s.contains("user"));
        assert!(s.contains("postgres"));
        assert!(s.contains("database"));
        assert!(s.contains("test"));

        // Length field should be at start
        let len = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, buf.len());
    }

    #[tokio::test]
    async fn write_startup_message_exact_layout() {
        let mut buf = Vec::new();
        write_startup_message(&mut buf, 196608, &[("user", "pg")])
            .await
            .unwrap();

        // 4 (length) + 4 (version) + "user\0" (5) + "pg\0" (3) + terminator (1) = 17
        let mut expected = Vec::new();
        expected.extend_from_slice(&17i32.to_be_bytes());
        expected.extend_from_slice(&196608i32.to_be_bytes());
        expected.extend_from_slice(b"user\0pg\0\0");
        assert_eq!(buf, expected);
    }

    #[tokio::test]
    async fn write_query_produces_valid_message() {
        let mut buf = Vec::new();
        write_query(&mut buf, "SELECT 1").await.unwrap();

        // Should start with 'Q'
        assert_eq!(buf[0], b'Q');

        // Length should be correct (excludes tag byte)
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);

        // Should contain the SQL
        assert!(buf[5..].starts_with(b"SELECT 1"));

        // Should be null-terminated
        assert_eq!(buf[buf.len() - 1], 0);
    }

    #[tokio::test]
    async fn write_query_handles_empty_sql() {
        let mut buf = Vec::new();
        write_query(&mut buf, "").await.unwrap();

        // Tag, length 5 (itself + terminator), NUL terminator.
        assert_eq!(buf, [b'Q', 0, 0, 0, 5, 0]);
    }

    #[tokio::test]
    async fn write_password_message_produces_valid_message() {
        let mut buf = Vec::new();
        write_password_message(&mut buf, b"secret").await.unwrap();

        assert_eq!(buf[0], b'p');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"secret");
    }

    #[tokio::test]
    async fn write_copy_data_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"payload").await.unwrap();

        assert_eq!(buf[0], b'd');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"payload");
    }

    #[tokio::test]
    async fn write_copy_done_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_done(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], b'c');
        // Length = 4 (just the length field itself, no payload)
        assert_eq!(&buf[1..5], &4i32.to_be_bytes());
    }

    #[test]
    fn backend_message_helper_methods() {
        let error = BackendMessage {
            tag: b'E',
            payload: Bytes::new(),
        };
        assert!(error.is_error());
        assert!(!error.is_ready_for_query());

        let ready = BackendMessage {
            tag: b'Z',
            payload: Bytes::from_static(b"I"),
        };
        assert!(ready.is_ready_for_query());
        assert!(!ready.is_error());

        let copy_both = BackendMessage {
            tag: b'W',
            payload: Bytes::new(),
        };
        assert!(copy_both.is_copy_both_response());

        let copy_data = BackendMessage {
            tag: b'd',
            payload: Bytes::new(),
        };
        assert!(copy_data.is_copy_data());

        let auth = BackendMessage {
            tag: b'R',
            payload: Bytes::new(),
        };
        assert!(auth.is_auth_request());
    }
}