pgwire_replication/protocol/
framing.rs1use bytes::{BufMut, Bytes, BytesMut};
2use std::io;
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use crate::error::{PgWireError, Result};
6
/// Largest backend message payload (in bytes) the readers in this module accept: 1 GiB.
///
/// The payload length is taken from the header sent by the server and the payload buffer
/// is sized from it before any payload bytes arrive, so this cap bounds that allocation.
pub const MAX_MESSAGE_SIZE: usize = 1 << 30;
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct BackendMessage {
13 pub tag: u8,
14 pub payload: Bytes, }
16
17impl BackendMessage {
18 #[inline]
20 pub fn is_error(&self) -> bool {
21 self.tag == b'E'
22 }
23
24 #[inline]
26 pub fn is_ready_for_query(&self) -> bool {
27 self.tag == b'Z'
28 }
29
30 #[inline]
32 pub fn is_copy_both_response(&self) -> bool {
33 self.tag == b'W'
34 }
35
36 #[inline]
38 pub fn is_copy_data(&self) -> bool {
39 self.tag == b'd'
40 }
41
42 #[inline]
44 pub fn is_auth_request(&self) -> bool {
45 self.tag == b'R'
46 }
47}
48
49pub async fn read_backend_message<R: AsyncRead + Unpin>(rd: &mut R) -> Result<BackendMessage> {
50 let mut reader = MessageReader::new();
51 reader.read(rd).await
52}
53
54pub struct MessageReader {
68 hdr: [u8; 5],
69 hdr_filled: usize,
70 payload: BytesMut,
71 payload_filled: usize,
72 payload_len: Option<usize>,
75 tag: u8,
76}
77
78impl MessageReader {
79 pub fn new() -> Self {
80 Self::with_capacity(4096)
81 }
82
83 pub fn with_capacity(capacity: usize) -> Self {
84 Self {
85 hdr: [0u8; 5],
86 hdr_filled: 0,
87 payload: BytesMut::with_capacity(capacity),
88 payload_filled: 0,
89 payload_len: None,
90 tag: 0,
91 }
92 }
93
94 pub async fn read<R: AsyncRead + Unpin>(&mut self, rd: &mut R) -> Result<BackendMessage> {
99 while self.hdr_filled < 5 {
101 let n = rd.read(&mut self.hdr[self.hdr_filled..]).await?;
102 if n == 0 {
103 return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new(
104 io::ErrorKind::UnexpectedEof,
105 "EOF while reading backend message header",
106 ))));
107 }
108 self.hdr_filled += n;
109 }
110
111 if self.payload_len.is_none() {
113 let len = i32::from_be_bytes([self.hdr[1], self.hdr[2], self.hdr[3], self.hdr[4]]);
114
115 if len < 4 {
116 self.hdr_filled = 0;
119 return Err(PgWireError::Protocol(format!(
120 "invalid backend message length: {len}"
121 )));
122 }
123
124 let payload_len = (len - 4) as usize;
125
126 if payload_len > MAX_MESSAGE_SIZE {
127 self.hdr_filled = 0;
128 return Err(PgWireError::Protocol(format!(
129 "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
130 )));
131 }
132
133 self.tag = self.hdr[0];
134 self.payload.clear();
135 self.payload.resize(payload_len, 0);
136 self.payload_filled = 0;
137 self.payload_len = Some(payload_len);
138 }
139
140 let payload_len = self.payload_len.unwrap();
141
142 while self.payload_filled < payload_len {
144 let n = rd.read(&mut self.payload[self.payload_filled..]).await?;
145 if n == 0 {
146 return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new(
147 io::ErrorKind::UnexpectedEof,
148 "EOF while reading backend message payload",
149 ))));
150 }
151 self.payload_filled += n;
152 }
153
154 let payload = self.payload.split().freeze();
156 let tag = self.tag;
157 self.hdr_filled = 0;
158 self.payload_len = None;
159 self.payload_filled = 0;
160
161 Ok(BackendMessage { tag, payload })
162 }
163}
164
165impl Default for MessageReader {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171pub async fn read_backend_message_into<R: AsyncRead + Unpin>(
176 rd: &mut R,
177 buf: &mut BytesMut,
178) -> Result<BackendMessage> {
179 let mut hdr = [0u8; 5];
180 rd.read_exact(&mut hdr).await?;
181 let tag = hdr[0];
182 let len = i32::from_be_bytes([hdr[1], hdr[2], hdr[3], hdr[4]]);
183
184 if len < 4 {
185 return Err(PgWireError::Protocol(format!(
186 "invalid backend message length: {len}"
187 )));
188 }
189
190 let payload_len = (len - 4) as usize;
191
192 if payload_len > MAX_MESSAGE_SIZE {
193 return Err(PgWireError::Protocol(format!(
194 "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
195 )));
196 }
197
198 buf.clear();
199 buf.resize(payload_len, 0);
200 rd.read_exact(&mut buf[..]).await?;
201 Ok(BackendMessage {
202 tag,
203 payload: buf.split().freeze(),
204 })
205}
206
207pub async fn write_ssl_request<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
208 let mut buf = [0u8; 8];
209 buf[0..4].copy_from_slice(&(8i32).to_be_bytes());
210 buf[4..8].copy_from_slice(&(80877103i32).to_be_bytes());
211 wr.write_all(&buf).await?;
212 wr.flush().await?;
213 Ok(())
214}
215
216pub async fn write_startup_message<W: AsyncWrite + Unpin>(
217 wr: &mut W,
218 protocol_version: i32,
219 params: &[(&str, &str)],
220) -> Result<()> {
221 let mut buf = BytesMut::with_capacity(256);
222 buf.put_i32(0); buf.put_i32(protocol_version);
224
225 for (k, v) in params {
226 buf.extend_from_slice(k.as_bytes());
227 buf.put_u8(0);
228 buf.extend_from_slice(v.as_bytes());
229 buf.put_u8(0);
230 }
231 buf.put_u8(0); let len = buf.len() as i32;
234 buf[0..4].copy_from_slice(&len.to_be_bytes());
235
236 wr.write_all(&buf).await?;
237 wr.flush().await?;
238 Ok(())
239}
240
241pub async fn write_query<W: AsyncWrite + Unpin>(wr: &mut W, sql: &str) -> Result<()> {
242 let mut buf = BytesMut::with_capacity(sql.len() + 64);
243 buf.put_u8(b'Q');
244 buf.put_i32(0);
245 buf.extend_from_slice(sql.as_bytes());
246 buf.put_u8(0);
247
248 let len = (buf.len() - 1) as i32;
249 buf[1..5].copy_from_slice(&len.to_be_bytes());
250
251 wr.write_all(&buf).await?;
252 wr.flush().await?;
253 Ok(())
254}
255
256pub async fn write_password_message<W: AsyncWrite + Unpin>(
257 wr: &mut W,
258 payload: &[u8],
259) -> Result<()> {
260 let mut buf = BytesMut::with_capacity(payload.len() + 16);
261 buf.put_u8(b'p');
262 buf.put_i32(0);
263 buf.extend_from_slice(payload);
264
265 let len = (buf.len() - 1) as i32;
266 buf[1..5].copy_from_slice(&len.to_be_bytes());
267
268 wr.write_all(&buf).await?;
269 wr.flush().await?;
270 Ok(())
271}
272
273pub async fn write_copy_data<W: AsyncWrite + Unpin>(wr: &mut W, payload: &[u8]) -> Result<()> {
274 let mut buf = BytesMut::with_capacity(payload.len() + 16);
275 buf.put_u8(b'd');
276 buf.put_i32(0);
277 buf.extend_from_slice(payload);
278
279 let len = (buf.len() - 1) as i32;
280 buf[1..5].copy_from_slice(&len.to_be_bytes());
281
282 wr.write_all(&buf).await?;
283 wr.flush().await?;
284 Ok(())
285}
286
287pub async fn write_copy_done<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
288 let mut buf = BytesMut::with_capacity(5);
289 buf.put_u8(b'c'); buf.put_i32(4); wr.write_all(&buf).await?;
292 wr.flush().await?;
293 Ok(())
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::AsyncWriteExt;

    #[tokio::test]
    async fn read_backend_message_parses_valid_message() {
        // ReadyForQuery, length 5 (4 + 1 payload byte), status 'I'.
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
        assert!(msg.is_ready_for_query());
    }

    #[tokio::test]
    async fn read_backend_message_handles_empty_payload() {
        let data = [b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'N');
        assert!(msg.payload.is_empty());
    }

    #[tokio::test]
    async fn read_backend_message_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn message_reader_reads_complete_message() {
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);

        let mut reader = MessageReader::new();
        let msg = reader.read(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
    }

    #[tokio::test]
    async fn message_reader_reads_back_to_back_messages() {
        let data = [b'Z', 0, 0, 0, 5, b'I', b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);

        let mut reader = MessageReader::new();

        let m1 = reader.read(&mut cursor).await.unwrap();
        assert_eq!(m1.tag, b'Z');
        assert_eq!(&m1.payload[..], b"I");

        let m2 = reader.read(&mut cursor).await.unwrap();
        assert_eq!(m2.tag, b'N');
        assert!(m2.payload.is_empty());
    }

    #[tokio::test]
    async fn message_reader_resumes_after_cancellation_mid_header() {
        let (mut writer, mut rd) = tokio::io::duplex(64);
        let mut reader = MessageReader::new();

        let header = [b'd', 0, 0, 0, 8];
        let payload = b"abcd";

        // Deliver only part of the header so the first read stalls and is cancelled.
        writer.write_all(&header[..3]).await.unwrap();

        let timed_out =
            tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await;
        assert!(
            timed_out.is_err(),
            "read must time out while waiting for remaining header bytes"
        );

        writer.write_all(&header[3..]).await.unwrap();
        writer.write_all(payload).await.unwrap();

        let msg = reader.read(&mut rd).await.unwrap();
        assert_eq!(msg.tag, b'd');
        assert_eq!(&msg.payload[..], payload);
    }

    #[tokio::test]
    async fn message_reader_resumes_after_cancellation_mid_payload() {
        let (mut writer, mut rd) = tokio::io::duplex(64);
        let mut reader = MessageReader::new();

        let payload: [u8; 16] = std::array::from_fn(|i| i as u8);
        let len = (4 + payload.len()) as i32;
        let header = [
            b'd',
            (len >> 24) as u8,
            (len >> 16) as u8,
            (len >> 8) as u8,
            len as u8,
        ];

        // Full header but only part of the payload, so the first read stalls mid-payload.
        writer.write_all(&header).await.unwrap();
        writer.write_all(&payload[..5]).await.unwrap();

        let timed_out =
            tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await;
        assert!(
            timed_out.is_err(),
            "read must time out while waiting for remaining payload bytes"
        );

        writer.write_all(&payload[5..]).await.unwrap();

        let msg = reader.read(&mut rd).await.unwrap();
        assert_eq!(msg.tag, b'd');
        assert_eq!(&msg.payload[..], &payload[..]);
    }

    #[tokio::test]
    async fn message_reader_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);

        let mut reader = MessageReader::new();
        let err = reader.read(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn message_reader_reports_eof_mid_payload() {
        // Declares 4 payload bytes but only 2 arrive before EOF.
        let data = [b'd', 0, 0, 0, 8, b'a', b'b'];
        let mut cursor = Cursor::new(&data[..]);

        let mut reader = MessageReader::new();
        let err = reader.read(&mut cursor).await.unwrap_err();
        assert!(matches!(err, PgWireError::Io(_)));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_oversized_message() {
        let huge_len = (MAX_MESSAGE_SIZE as i32) + 5;
        let data = [
            b'Z',
            (huge_len >> 24) as u8,
            (huge_len >> 16) as u8,
            (huge_len >> 8) as u8,
            huge_len as u8,
        ];
        let mut cursor = Cursor::new(&data[..]);

        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[tokio::test]
    async fn read_backend_message_into_reuses_buffer() {
        let data = [b'Z', 0, 0, 0, 5, b'I', b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);
        let mut buf = BytesMut::new();

        let m1 = read_backend_message_into(&mut cursor, &mut buf)
            .await
            .unwrap();
        assert_eq!(m1.tag, b'Z');
        assert_eq!(&m1.payload[..], b"I");

        let m2 = read_backend_message_into(&mut cursor, &mut buf)
            .await
            .unwrap();
        assert_eq!(m2.tag, b'N');
        assert!(m2.payload.is_empty());
        // The first payload must be unaffected by reusing the buffer.
        assert_eq!(&m1.payload[..], b"I");
    }

    #[tokio::test]
    async fn read_backend_message_into_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);
        let mut buf = BytesMut::new();

        let err = read_backend_message_into(&mut cursor, &mut buf)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn write_ssl_request_produces_valid_bytes() {
        let mut buf = Vec::new();
        write_ssl_request(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 8);
        assert_eq!(&buf[0..4], &8i32.to_be_bytes());
        assert_eq!(&buf[4..8], &80877103i32.to_be_bytes());
    }

    #[tokio::test]
    async fn write_startup_message_includes_params() {
        let mut buf = Vec::new();
        let params = [("user", "postgres"), ("database", "test")];
        write_startup_message(&mut buf, 196608, &params)
            .await
            .unwrap();

        let s = String::from_utf8_lossy(&buf);
        assert!(s.contains("user"));
        assert!(s.contains("postgres"));
        assert!(s.contains("database"));
        assert!(s.contains("test"));

        // The length prefix covers the whole packet, including itself.
        let len = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, buf.len());
    }

    #[tokio::test]
    async fn write_query_produces_valid_message() {
        let mut buf = Vec::new();
        write_query(&mut buf, "SELECT 1").await.unwrap();

        assert_eq!(buf[0], b'Q');

        // Length excludes the tag byte.
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);

        assert!(buf[5..].starts_with(b"SELECT 1"));

        // Query text is NUL-terminated.
        assert_eq!(buf[buf.len() - 1], 0);
    }

    #[tokio::test]
    async fn write_password_message_produces_valid_message() {
        let mut buf = Vec::new();
        write_password_message(&mut buf, b"secret").await.unwrap();

        assert_eq!(buf[0], b'p');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"secret");
    }

    #[tokio::test]
    async fn write_copy_data_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"payload").await.unwrap();

        assert_eq!(buf[0], b'd');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"payload");
    }

    #[tokio::test]
    async fn write_copy_done_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_done(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], b'c');
        assert_eq!(&buf[1..5], &4i32.to_be_bytes());
    }

    #[test]
    fn backend_message_helper_methods() {
        let error = BackendMessage {
            tag: b'E',
            payload: Bytes::new(),
        };
        assert!(error.is_error());
        assert!(!error.is_ready_for_query());

        let ready = BackendMessage {
            tag: b'Z',
            payload: Bytes::from_static(b"I"),
        };
        assert!(ready.is_ready_for_query());
        assert!(!ready.is_error());

        let copy_both = BackendMessage {
            tag: b'W',
            payload: Bytes::new(),
        };
        assert!(copy_both.is_copy_both_response());

        let copy_data = BackendMessage {
            tag: b'd',
            payload: Bytes::new(),
        };
        assert!(copy_data.is_copy_data());

        let auth = BackendMessage {
            tag: b'R',
            payload: Bytes::new(),
        };
        assert!(auth.is_auth_request());
    }
}