pgwire_replication/protocol/
framing.rs1use bytes::{BufMut, Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4use crate::error::{PgWireError, Result};
5
6pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024 * 1024;
9
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct BackendMessage {
12 pub tag: u8,
13 pub payload: Bytes, }
15
16impl BackendMessage {
17 #[inline]
19 pub fn is_error(&self) -> bool {
20 self.tag == b'E'
21 }
22
23 #[inline]
25 pub fn is_ready_for_query(&self) -> bool {
26 self.tag == b'Z'
27 }
28
29 #[inline]
31 pub fn is_copy_both_response(&self) -> bool {
32 self.tag == b'W'
33 }
34
35 #[inline]
37 pub fn is_copy_data(&self) -> bool {
38 self.tag == b'd'
39 }
40
41 #[inline]
43 pub fn is_auth_request(&self) -> bool {
44 self.tag == b'R'
45 }
46}
47
48pub async fn read_backend_message<R: AsyncRead + Unpin>(rd: &mut R) -> Result<BackendMessage> {
49 let mut hdr = [0u8; 5];
50 rd.read_exact(&mut hdr).await?;
51 let tag = hdr[0];
52 let len = i32::from_be_bytes([hdr[1], hdr[2], hdr[3], hdr[4]]);
53
54 if len < 4 {
55 return Err(PgWireError::Protocol(format!(
56 "invalid backend message length: {len}"
57 )));
58 }
59
60 let payload_len = (len - 4) as usize;
61
62 if payload_len > MAX_MESSAGE_SIZE {
63 return Err(PgWireError::Protocol(format!(
64 "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
65 )));
66 }
67
68 let mut buf = vec![0u8; payload_len];
69 rd.read_exact(&mut buf).await?;
70 Ok(BackendMessage {
71 tag,
72 payload: Bytes::from(buf),
73 })
74}
75
76pub async fn write_ssl_request<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
77 let mut buf = [0u8; 8];
78 buf[0..4].copy_from_slice(&(8i32).to_be_bytes());
79 buf[4..8].copy_from_slice(&(80877103i32).to_be_bytes());
80 wr.write_all(&buf).await?;
81 wr.flush().await?;
82 Ok(())
83}
84
85pub async fn write_startup_message<W: AsyncWrite + Unpin>(
86 wr: &mut W,
87 protocol_version: i32,
88 params: &[(&str, &str)],
89) -> Result<()> {
90 let mut buf = BytesMut::with_capacity(256);
91 buf.put_i32(0); buf.put_i32(protocol_version);
93
94 for (k, v) in params {
95 buf.extend_from_slice(k.as_bytes());
96 buf.put_u8(0);
97 buf.extend_from_slice(v.as_bytes());
98 buf.put_u8(0);
99 }
100 buf.put_u8(0); let len = buf.len() as i32;
103 buf[0..4].copy_from_slice(&len.to_be_bytes());
104
105 wr.write_all(&buf).await?;
106 wr.flush().await?;
107 Ok(())
108}
109
110pub async fn write_query<W: AsyncWrite + Unpin>(wr: &mut W, sql: &str) -> Result<()> {
111 let mut buf = BytesMut::with_capacity(sql.len() + 64);
112 buf.put_u8(b'Q');
113 buf.put_i32(0);
114 buf.extend_from_slice(sql.as_bytes());
115 buf.put_u8(0);
116
117 let len = (buf.len() - 1) as i32;
118 buf[1..5].copy_from_slice(&len.to_be_bytes());
119
120 wr.write_all(&buf).await?;
121 wr.flush().await?;
122 Ok(())
123}
124
125pub async fn write_password_message<W: AsyncWrite + Unpin>(
126 wr: &mut W,
127 payload: &[u8],
128) -> Result<()> {
129 let mut buf = BytesMut::with_capacity(payload.len() + 16);
130 buf.put_u8(b'p');
131 buf.put_i32(0);
132 buf.extend_from_slice(payload);
133
134 let len = (buf.len() - 1) as i32;
135 buf[1..5].copy_from_slice(&len.to_be_bytes());
136
137 wr.write_all(&buf).await?;
138 wr.flush().await?;
139 Ok(())
140}
141
142pub async fn write_copy_data<W: AsyncWrite + Unpin>(wr: &mut W, payload: &[u8]) -> Result<()> {
143 let mut buf = BytesMut::with_capacity(payload.len() + 16);
144 buf.put_u8(b'd');
145 buf.put_i32(0);
146 buf.extend_from_slice(payload);
147
148 let len = (buf.len() - 1) as i32;
149 buf[1..5].copy_from_slice(&len.to_be_bytes());
150
151 wr.write_all(&buf).await?;
152 wr.flush().await?;
153 Ok(())
154}
155
156pub async fn write_copy_done<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
157 let mut buf = BytesMut::with_capacity(5);
158 buf.put_u8(b'c'); buf.put_i32(4); wr.write_all(&buf).await?;
161 wr.flush().await?;
162 Ok(())
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[tokio::test]
    async fn read_backend_message_parses_valid_message() {
        // 'Z' with length 5 => exactly one payload byte.
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
        assert!(msg.is_ready_for_query());
    }

    #[tokio::test]
    async fn read_backend_message_handles_empty_payload() {
        // Length 4 covers only the length field itself.
        let data = [b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'N');
        assert!(msg.payload.is_empty());
    }

    #[tokio::test]
    async fn read_backend_message_reads_consecutive_frames() {
        // Two frames back to back: each read must consume exactly one frame.
        let data = [b'Z', 0, 0, 0, 5, b'I', b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let first = read_backend_message(&mut cursor).await.unwrap();
        let second = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(first.tag, b'Z');
        assert_eq!(&first.payload[..], b"I");
        assert_eq!(second.tag, b'N');
        assert!(second.payload.is_empty());
    }

    #[tokio::test]
    async fn read_backend_message_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_negative_length() {
        // 0xFFFFFFFF decodes to -1 as a big-endian i32.
        let data = [b'Z', 0xFF, 0xFF, 0xFF, 0xFF];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_oversized_message() {
        // Declared payload is MAX_MESSAGE_SIZE + 1 bytes.
        let huge_len = (MAX_MESSAGE_SIZE as i32) + 5;
        let data = [
            b'Z',
            (huge_len >> 24) as u8,
            (huge_len >> 16) as u8,
            (huge_len >> 8) as u8,
            huge_len as u8,
        ];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[tokio::test]
    async fn read_backend_message_errors_on_truncated_payload() {
        // Declares 6 payload bytes but only 1 is available.
        let data = [b'Z', 0, 0, 0, 10, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        assert!(read_backend_message(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn write_ssl_request_produces_valid_bytes() {
        let mut buf = Vec::new();
        write_ssl_request(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 8);
        assert_eq!(&buf[0..4], &8i32.to_be_bytes());
        assert_eq!(&buf[4..8], &80877103i32.to_be_bytes());
    }

    #[tokio::test]
    async fn write_startup_message_includes_params() {
        let mut buf = Vec::new();
        let params = [("user", "postgres"), ("database", "test")];
        write_startup_message(&mut buf, 196608, &params)
            .await
            .unwrap();

        let len = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, buf.len());

        // length (37) + version + key\0value\0 pairs + list terminator.
        let mut expected = Vec::new();
        expected.extend_from_slice(&37i32.to_be_bytes());
        expected.extend_from_slice(&196608i32.to_be_bytes());
        expected.extend_from_slice(b"user\0postgres\0database\0test\0\0");
        assert_eq!(buf, expected);
    }

    #[tokio::test]
    async fn write_query_produces_valid_message() {
        let mut buf = Vec::new();
        write_query(&mut buf, "SELECT 1").await.unwrap();

        assert_eq!(buf[0], b'Q');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert!(buf[5..].starts_with(b"SELECT 1"));
        assert_eq!(buf[buf.len() - 1], 0);
        assert_eq!(buf, b"Q\0\0\0\x0dSELECT 1\0");
    }

    #[tokio::test]
    async fn write_password_message_produces_valid_message() {
        let mut buf = Vec::new();
        write_password_message(&mut buf, b"secret").await.unwrap();

        assert_eq!(buf[0], b'p');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"secret");
        assert_eq!(buf, b"p\0\0\0\x0asecret");
    }

    #[tokio::test]
    async fn write_copy_data_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"payload").await.unwrap();

        assert_eq!(buf[0], b'd');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"payload");
        assert_eq!(buf, b"d\0\0\0\x0bpayload");
    }

    #[tokio::test]
    async fn write_copy_done_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_done(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], b'c');
        assert_eq!(&buf[1..5], &4i32.to_be_bytes());
    }

    #[test]
    fn backend_message_helper_methods() {
        let error = BackendMessage {
            tag: b'E',
            payload: Bytes::new(),
        };
        assert!(error.is_error());
        assert!(!error.is_ready_for_query());

        let ready = BackendMessage {
            tag: b'Z',
            payload: Bytes::from_static(b"I"),
        };
        assert!(ready.is_ready_for_query());
        assert!(!ready.is_error());

        let copy_both = BackendMessage {
            tag: b'W',
            payload: Bytes::new(),
        };
        assert!(copy_both.is_copy_both_response());

        let copy_data = BackendMessage {
            tag: b'd',
            payload: Bytes::new(),
        };
        assert!(copy_data.is_copy_data());

        let auth = BackendMessage {
            tag: b'R',
            payload: Bytes::new(),
        };
        assert!(auth.is_auth_request());
        assert!(!auth.is_copy_data());
    }
}