pgwire_replication/protocol/
framing.rs1use bytes::{BufMut, Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4use crate::error::{PgWireError, Result};
5
/// Largest backend message payload, in bytes, that the reader will accept (1 GiB).
///
/// A declared payload length above this limit is rejected by
/// [`read_backend_message_into`] before any buffer is sized for it.
pub const MAX_MESSAGE_SIZE: usize = 1 << 30;
9
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct BackendMessage {
12 pub tag: u8,
13 pub payload: Bytes, }
15
16impl BackendMessage {
17 #[inline]
19 pub fn is_error(&self) -> bool {
20 self.tag == b'E'
21 }
22
23 #[inline]
25 pub fn is_ready_for_query(&self) -> bool {
26 self.tag == b'Z'
27 }
28
29 #[inline]
31 pub fn is_copy_both_response(&self) -> bool {
32 self.tag == b'W'
33 }
34
35 #[inline]
37 pub fn is_copy_data(&self) -> bool {
38 self.tag == b'd'
39 }
40
41 #[inline]
43 pub fn is_auth_request(&self) -> bool {
44 self.tag == b'R'
45 }
46}
47
48pub async fn read_backend_message<R: AsyncRead + Unpin>(rd: &mut R) -> Result<BackendMessage> {
49 read_backend_message_into(rd, &mut BytesMut::new()).await
50}
51
52pub async fn read_backend_message_into<R: AsyncRead + Unpin>(
54 rd: &mut R,
55 buf: &mut BytesMut,
56) -> Result<BackendMessage> {
57 let mut hdr = [0u8; 5];
58 rd.read_exact(&mut hdr).await?;
59 let tag = hdr[0];
60 let len = i32::from_be_bytes([hdr[1], hdr[2], hdr[3], hdr[4]]);
61
62 if len < 4 {
63 return Err(PgWireError::Protocol(format!(
64 "invalid backend message length: {len}"
65 )));
66 }
67
68 let payload_len = (len - 4) as usize;
69
70 if payload_len > MAX_MESSAGE_SIZE {
71 return Err(PgWireError::Protocol(format!(
72 "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
73 )));
74 }
75
76 buf.clear();
77 buf.resize(payload_len, 0);
78 rd.read_exact(&mut buf[..]).await?;
79 Ok(BackendMessage {
80 tag,
81 payload: buf.split().freeze(),
82 })
83}
84
85pub async fn write_ssl_request<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
86 let mut buf = [0u8; 8];
87 buf[0..4].copy_from_slice(&(8i32).to_be_bytes());
88 buf[4..8].copy_from_slice(&(80877103i32).to_be_bytes());
89 wr.write_all(&buf).await?;
90 wr.flush().await?;
91 Ok(())
92}
93
94pub async fn write_startup_message<W: AsyncWrite + Unpin>(
95 wr: &mut W,
96 protocol_version: i32,
97 params: &[(&str, &str)],
98) -> Result<()> {
99 let mut buf = BytesMut::with_capacity(256);
100 buf.put_i32(0); buf.put_i32(protocol_version);
102
103 for (k, v) in params {
104 buf.extend_from_slice(k.as_bytes());
105 buf.put_u8(0);
106 buf.extend_from_slice(v.as_bytes());
107 buf.put_u8(0);
108 }
109 buf.put_u8(0); let len = buf.len() as i32;
112 buf[0..4].copy_from_slice(&len.to_be_bytes());
113
114 wr.write_all(&buf).await?;
115 wr.flush().await?;
116 Ok(())
117}
118
119pub async fn write_query<W: AsyncWrite + Unpin>(wr: &mut W, sql: &str) -> Result<()> {
120 let mut buf = BytesMut::with_capacity(sql.len() + 64);
121 buf.put_u8(b'Q');
122 buf.put_i32(0);
123 buf.extend_from_slice(sql.as_bytes());
124 buf.put_u8(0);
125
126 let len = (buf.len() - 1) as i32;
127 buf[1..5].copy_from_slice(&len.to_be_bytes());
128
129 wr.write_all(&buf).await?;
130 wr.flush().await?;
131 Ok(())
132}
133
134pub async fn write_password_message<W: AsyncWrite + Unpin>(
135 wr: &mut W,
136 payload: &[u8],
137) -> Result<()> {
138 let mut buf = BytesMut::with_capacity(payload.len() + 16);
139 buf.put_u8(b'p');
140 buf.put_i32(0);
141 buf.extend_from_slice(payload);
142
143 let len = (buf.len() - 1) as i32;
144 buf[1..5].copy_from_slice(&len.to_be_bytes());
145
146 wr.write_all(&buf).await?;
147 wr.flush().await?;
148 Ok(())
149}
150
151pub async fn write_copy_data<W: AsyncWrite + Unpin>(wr: &mut W, payload: &[u8]) -> Result<()> {
152 let mut buf = BytesMut::with_capacity(payload.len() + 16);
153 buf.put_u8(b'd');
154 buf.put_i32(0);
155 buf.extend_from_slice(payload);
156
157 let len = (buf.len() - 1) as i32;
158 buf[1..5].copy_from_slice(&len.to_be_bytes());
159
160 wr.write_all(&buf).await?;
161 wr.flush().await?;
162 Ok(())
163}
164
165pub async fn write_copy_done<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
166 let mut buf = BytesMut::with_capacity(5);
167 buf.put_u8(b'c'); buf.put_i32(4); wr.write_all(&buf).await?;
170 wr.flush().await?;
171 Ok(())
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// A ReadyForQuery frame with a one-byte payload is parsed intact.
    #[tokio::test]
    async fn read_backend_message_parses_valid_message() {
        // tag 'Z', length 5 (4 for the length field + 1 payload byte), payload 'I'
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
        assert!(msg.is_ready_for_query());
    }

    /// A length of exactly 4 means "no payload" and must be accepted.
    #[tokio::test]
    async fn read_backend_message_handles_empty_payload() {
        let data = [b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'N');
        assert!(msg.payload.is_empty());
    }

    /// A length below 4 cannot even cover the length field itself.
    #[tokio::test]
    async fn read_backend_message_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    /// A payload one byte over the limit is rejected from the header alone.
    #[tokio::test]
    async fn read_backend_message_rejects_oversized_message() {
        // +4 for the length field, +1 to exceed MAX_MESSAGE_SIZE.
        let huge_len = (MAX_MESSAGE_SIZE as i32) + 5;
        let [b0, b1, b2, b3] = huge_len.to_be_bytes();
        let data = [b'Z', b0, b1, b2, b3];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    /// SSLRequest is the fixed 8-byte packet: length 8, code 80877103.
    #[tokio::test]
    async fn write_ssl_request_produces_valid_bytes() {
        let mut buf = Vec::new();
        write_ssl_request(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 8);
        assert_eq!(&buf[0..4], &8i32.to_be_bytes());
        assert_eq!(&buf[4..8], &80877103i32.to_be_bytes());
    }

    /// The startup message carries every parameter and a self-inclusive length.
    #[tokio::test]
    async fn write_startup_message_includes_params() {
        let mut buf = Vec::new();
        let params = [("user", "postgres"), ("database", "test")];
        write_startup_message(&mut buf, 196608, &params)
            .await
            .unwrap();

        let s = String::from_utf8_lossy(&buf);
        assert!(s.contains("user"));
        assert!(s.contains("postgres"));
        assert!(s.contains("database"));
        assert!(s.contains("test"));

        // No tag byte here, so the length field covers the whole buffer.
        let len = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, buf.len());
    }

    /// Query messages are tagged 'Q', length-prefixed and NUL-terminated.
    #[tokio::test]
    async fn write_query_produces_valid_message() {
        let mut buf = Vec::new();
        write_query(&mut buf, "SELECT 1").await.unwrap();

        assert_eq!(buf[0], b'Q');

        // The length field excludes the tag byte.
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);

        assert!(buf[5..].starts_with(b"SELECT 1"));

        assert_eq!(buf[buf.len() - 1], 0);
    }

    /// Password messages are tagged 'p' and carry the payload verbatim.
    #[tokio::test]
    async fn write_password_message_produces_valid_message() {
        let mut buf = Vec::new();
        write_password_message(&mut buf, b"secret").await.unwrap();

        assert_eq!(buf[0], b'p');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"secret");
    }

    /// CopyData messages are tagged 'd' and carry the payload verbatim.
    #[tokio::test]
    async fn write_copy_data_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"payload").await.unwrap();

        assert_eq!(buf[0], b'd');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"payload");
    }

    /// CopyDone is the fixed 5-byte frame: tag 'c' and length 4.
    #[tokio::test]
    async fn write_copy_done_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_done(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], b'c');
        assert_eq!(&buf[1..5], &4i32.to_be_bytes());
    }

    /// Each tag predicate matches its own tag (and not an unrelated one).
    #[test]
    fn backend_message_helper_methods() {
        let error = BackendMessage {
            tag: b'E',
            payload: Bytes::new(),
        };
        assert!(error.is_error());
        assert!(!error.is_ready_for_query());

        let ready = BackendMessage {
            tag: b'Z',
            payload: Bytes::from_static(b"I"),
        };
        assert!(ready.is_ready_for_query());
        assert!(!ready.is_error());

        let copy_both = BackendMessage {
            tag: b'W',
            payload: Bytes::new(),
        };
        assert!(copy_both.is_copy_both_response());

        let copy_data = BackendMessage {
            tag: b'd',
            payload: Bytes::new(),
        };
        assert!(copy_data.is_copy_data());

        let auth = BackendMessage {
            tag: b'R',
            payload: Bytes::new(),
        };
        assert!(auth.is_auth_request());
    }
}