use bytes::{BufMut, Bytes, BytesMut};
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::error::{PgWireError, Result};
/// Largest backend message payload (1 GiB) the readers in this module accept.
///
/// Declared lengths above this are rejected before the payload buffer is resized.
pub const MAX_MESSAGE_SIZE: usize = 1 << 30;
/// A single message received from the backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackendMessage {
    /// The one-byte message type (for example `b'Z'` or `b'E'`).
    pub tag: u8,
    /// The message body: everything after the tag and the 4-byte length field.
    pub payload: Bytes,
}

impl BackendMessage {
    /// Returns `true` when the tag is `'E'` (error response).
    #[inline]
    pub fn is_error(&self) -> bool {
        matches!(self.tag, b'E')
    }

    /// Returns `true` when the tag is `'Z'` (ready for query).
    #[inline]
    pub fn is_ready_for_query(&self) -> bool {
        matches!(self.tag, b'Z')
    }

    /// Returns `true` when the tag is `'W'` (copy-both response).
    #[inline]
    pub fn is_copy_both_response(&self) -> bool {
        matches!(self.tag, b'W')
    }

    /// Returns `true` when the tag is `'d'` (copy data).
    #[inline]
    pub fn is_copy_data(&self) -> bool {
        matches!(self.tag, b'd')
    }

    /// Returns `true` when the tag is `'R'` (authentication request).
    #[inline]
    pub fn is_auth_request(&self) -> bool {
        matches!(self.tag, b'R')
    }
}
/// Reads one complete backend message from `rd` using a throwaway [`MessageReader`].
///
/// The reader lives only for the duration of this call. If the returned future is
/// dropped part-way through a message, the bytes already consumed are lost. Keep a
/// long-lived [`MessageReader`] when the read may be cancelled.
///
/// # Errors
///
/// - An I/O error of kind `UnexpectedEof` if the stream ends mid-message.
/// - `PgWireError::Protocol` if the declared length is below 4, or if the payload
///   exceeds [`MAX_MESSAGE_SIZE`].
/// - Any I/O error reported by `rd`.
pub async fn read_backend_message<R: AsyncRead + Unpin>(rd: &mut R) -> Result<BackendMessage> {
    // Start from an empty buffer so the payload allocation is sized to the message.
    // The returned `Bytes` shares its allocation with the reader's buffer.
    // Pre-reserving 4 KiB here would keep that whole block alive for as long as the
    // caller holds the message, even for a one-byte payload.
    let mut reader = MessageReader::with_capacity(0);
    reader.read(rd).await
}
pub struct MessageReader {
hdr: [u8; 5],
hdr_filled: usize,
payload: BytesMut,
payload_filled: usize,
payload_len: Option<usize>,
tag: u8,
}
impl MessageReader {
pub fn new() -> Self {
Self::with_capacity(4096)
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
hdr: [0u8; 5],
hdr_filled: 0,
payload: BytesMut::with_capacity(capacity),
payload_filled: 0,
payload_len: None,
tag: 0,
}
}
pub async fn read<R: AsyncRead + Unpin>(&mut self, rd: &mut R) -> Result<BackendMessage> {
while self.hdr_filled < 5 {
let n = rd.read(&mut self.hdr[self.hdr_filled..]).await?;
if n == 0 {
return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new(
io::ErrorKind::UnexpectedEof,
"EOF while reading backend message header",
))));
}
self.hdr_filled += n;
}
if self.payload_len.is_none() {
let len = i32::from_be_bytes([self.hdr[1], self.hdr[2], self.hdr[3], self.hdr[4]]);
if len < 4 {
self.hdr_filled = 0;
return Err(PgWireError::Protocol(format!(
"invalid backend message length: {len}"
)));
}
let payload_len = (len - 4) as usize;
if payload_len > MAX_MESSAGE_SIZE {
self.hdr_filled = 0;
return Err(PgWireError::Protocol(format!(
"backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
)));
}
self.tag = self.hdr[0];
self.payload.clear();
self.payload.resize(payload_len, 0);
self.payload_filled = 0;
self.payload_len = Some(payload_len);
}
let payload_len = self.payload_len.unwrap();
while self.payload_filled < payload_len {
let n = rd.read(&mut self.payload[self.payload_filled..]).await?;
if n == 0 {
return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new(
io::ErrorKind::UnexpectedEof,
"EOF while reading backend message payload",
))));
}
self.payload_filled += n;
}
let payload = self.payload.split().freeze();
let tag = self.tag;
self.hdr_filled = 0;
self.payload_len = None;
self.payload_filled = 0;
Ok(BackendMessage { tag, payload })
}
}
impl Default for MessageReader {
fn default() -> Self {
Self::new()
}
}
/// Reads one backend message from `rd`, using `buf` as payload storage.
///
/// `buf` is cleared, resized to the payload length, filled, and then split off into
/// the returned message. Any spare capacity stays in `buf` for reuse. Uses
/// `read_exact`, so unlike [`MessageReader`] no progress is kept if the future is
/// dropped mid-message.
///
/// # Errors
///
/// - An I/O error if the stream ends early or `rd` fails.
/// - `PgWireError::Protocol` if the declared length is below 4, or if the payload
///   exceeds [`MAX_MESSAGE_SIZE`].
pub async fn read_backend_message_into<R: AsyncRead + Unpin>(
    rd: &mut R,
    buf: &mut BytesMut,
) -> Result<BackendMessage> {
    let mut header = [0u8; 5];
    rd.read_exact(&mut header).await?;
    let [tag, l0, l1, l2, l3] = header;
    let len = i32::from_be_bytes([l0, l1, l2, l3]);
    if len < 4 {
        return Err(PgWireError::Protocol(format!(
            "invalid backend message length: {len}"
        )));
    }
    // Cannot wrap: `len >= 4` was checked above. The length counts itself, not the
    // tag.
    let payload_len = (len - 4) as usize;
    if payload_len > MAX_MESSAGE_SIZE {
        return Err(PgWireError::Protocol(format!(
            "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})"
        )));
    }
    buf.clear();
    buf.resize(payload_len, 0);
    rd.read_exact(&mut buf[..]).await?;
    let payload = buf.split().freeze();
    Ok(BackendMessage { tag, payload })
}
/// Writes an `SSLRequest` packet to `wr` and flushes it.
///
/// The packet is 8 bytes: the big-endian `i32` length (8), followed by the request
/// code 80877103. That code is written here as `(1234 << 16) | 5679`.
pub async fn write_ssl_request<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
    const SSL_REQUEST_CODE: i32 = (1234 << 16) | 5679;
    let mut packet = [0u8; 8];
    let (len_field, code_field) = packet.split_at_mut(4);
    len_field.copy_from_slice(&8i32.to_be_bytes());
    code_field.copy_from_slice(&SSL_REQUEST_CODE.to_be_bytes());
    wr.write_all(&packet).await?;
    wr.flush().await?;
    Ok(())
}
pub async fn write_startup_message<W: AsyncWrite + Unpin>(
wr: &mut W,
protocol_version: i32,
params: &[(&str, &str)],
) -> Result<()> {
let mut buf = BytesMut::with_capacity(256);
buf.put_i32(0); buf.put_i32(protocol_version);
for (k, v) in params {
buf.extend_from_slice(k.as_bytes());
buf.put_u8(0);
buf.extend_from_slice(v.as_bytes());
buf.put_u8(0);
}
buf.put_u8(0);
let len = buf.len() as i32;
buf[0..4].copy_from_slice(&len.to_be_bytes());
wr.write_all(&buf).await?;
wr.flush().await?;
Ok(())
}
pub async fn write_query<W: AsyncWrite + Unpin>(wr: &mut W, sql: &str) -> Result<()> {
let mut buf = BytesMut::with_capacity(sql.len() + 64);
buf.put_u8(b'Q');
buf.put_i32(0);
buf.extend_from_slice(sql.as_bytes());
buf.put_u8(0);
let len = (buf.len() - 1) as i32;
buf[1..5].copy_from_slice(&len.to_be_bytes());
wr.write_all(&buf).await?;
wr.flush().await?;
Ok(())
}
pub async fn write_password_message<W: AsyncWrite + Unpin>(
wr: &mut W,
payload: &[u8],
) -> Result<()> {
let mut buf = BytesMut::with_capacity(payload.len() + 16);
buf.put_u8(b'p');
buf.put_i32(0);
buf.extend_from_slice(payload);
let len = (buf.len() - 1) as i32;
buf[1..5].copy_from_slice(&len.to_be_bytes());
wr.write_all(&buf).await?;
wr.flush().await?;
Ok(())
}
pub async fn write_copy_data<W: AsyncWrite + Unpin>(wr: &mut W, payload: &[u8]) -> Result<()> {
let mut buf = BytesMut::with_capacity(payload.len() + 16);
buf.put_u8(b'd');
buf.put_i32(0);
buf.extend_from_slice(payload);
let len = (buf.len() - 1) as i32;
buf[1..5].copy_from_slice(&len.to_be_bytes());
wr.write_all(&buf).await?;
wr.flush().await?;
Ok(())
}
pub async fn write_copy_done<W: AsyncWrite + Unpin>(wr: &mut W) -> Result<()> {
let mut buf = BytesMut::with_capacity(5);
buf.put_u8(b'c'); buf.put_i32(4); wr.write_all(&buf).await?;
wr.flush().await?;
Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for backend-message reading and frontend-message encoding.
    use super::*;
    use std::io::Cursor;
    use tokio::io::AsyncWriteExt;

    #[tokio::test]
    async fn read_backend_message_parses_valid_message() {
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);
        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
        assert!(msg.is_ready_for_query());
    }

    #[tokio::test]
    async fn read_backend_message_handles_empty_payload() {
        let data = [b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);
        let msg = read_backend_message(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'N');
        assert!(msg.payload.is_empty());
    }

    #[tokio::test]
    async fn read_backend_message_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);
        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn message_reader_reads_complete_message() {
        let data = [b'Z', 0, 0, 0, 5, b'I'];
        let mut cursor = Cursor::new(&data[..]);
        let mut reader = MessageReader::new();
        let msg = reader.read(&mut cursor).await.unwrap();
        assert_eq!(msg.tag, b'Z');
        assert_eq!(&msg.payload[..], b"I");
    }

    #[tokio::test]
    async fn message_reader_reads_back_to_back_messages() {
        let data = [b'Z', 0, 0, 0, 5, b'I', b'N', 0, 0, 0, 4];
        let mut cursor = Cursor::new(&data[..]);
        let mut reader = MessageReader::new();
        let m1 = reader.read(&mut cursor).await.unwrap();
        assert_eq!(m1.tag, b'Z');
        assert_eq!(&m1.payload[..], b"I");
        let m2 = reader.read(&mut cursor).await.unwrap();
        assert_eq!(m2.tag, b'N');
        assert!(m2.payload.is_empty());
    }

    #[tokio::test]
    async fn message_reader_resumes_after_cancellation_mid_header() {
        let (mut writer, mut rd) = tokio::io::duplex(64);
        let mut reader = MessageReader::new();
        let header = [b'd', 0, 0, 0, 8];
        let payload = b"abcd";
        writer.write_all(&header[..3]).await.unwrap();
        let timed_out =
            tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await;
        assert!(
            timed_out.is_err(),
            "read must time out while waiting for remaining header bytes"
        );
        writer.write_all(&header[3..]).await.unwrap();
        writer.write_all(payload).await.unwrap();
        let msg = reader.read(&mut rd).await.unwrap();
        assert_eq!(msg.tag, b'd');
        assert_eq!(&msg.payload[..], payload);
    }

    #[tokio::test]
    async fn message_reader_resumes_after_cancellation_mid_payload() {
        let (mut writer, mut rd) = tokio::io::duplex(64);
        let mut reader = MessageReader::new();
        let payload: [u8; 16] = std::array::from_fn(|i| i as u8);
        let len = (4 + payload.len()) as i32;
        let header = [
            b'd',
            (len >> 24) as u8,
            (len >> 16) as u8,
            (len >> 8) as u8,
            len as u8,
        ];
        writer.write_all(&header).await.unwrap();
        writer.write_all(&payload[..5]).await.unwrap();
        let timed_out =
            tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await;
        assert!(
            timed_out.is_err(),
            "read must time out while waiting for remaining payload bytes"
        );
        writer.write_all(&payload[5..]).await.unwrap();
        let msg = reader.read(&mut rd).await.unwrap();
        assert_eq!(msg.tag, b'd');
        assert_eq!(&msg.payload[..], &payload[..]);
    }

    #[tokio::test]
    async fn message_reader_rejects_invalid_length() {
        let data = [b'Z', 0, 0, 0, 3];
        let mut cursor = Cursor::new(&data[..]);
        let mut reader = MessageReader::new();
        let err = reader.read(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("invalid backend message length"));
    }

    #[tokio::test]
    async fn read_backend_message_rejects_oversized_message() {
        let huge_len = (MAX_MESSAGE_SIZE as i32) + 5;
        let data = [
            b'Z',
            (huge_len >> 24) as u8,
            (huge_len >> 16) as u8,
            (huge_len >> 8) as u8,
            huge_len as u8,
        ];
        let mut cursor = Cursor::new(&data[..]);
        let err = read_backend_message(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[tokio::test]
    async fn write_ssl_request_produces_valid_bytes() {
        let mut buf = Vec::new();
        write_ssl_request(&mut buf).await.unwrap();
        assert_eq!(buf.len(), 8);
        assert_eq!(&buf[0..4], &8i32.to_be_bytes());
        assert_eq!(&buf[4..8], &80877103i32.to_be_bytes());
    }

    #[tokio::test]
    async fn write_startup_message_includes_params() {
        let mut buf = Vec::new();
        let params = [("user", "postgres"), ("database", "test")];
        write_startup_message(&mut buf, 196608, &params)
            .await
            .unwrap();
        let s = String::from_utf8_lossy(&buf);
        assert!(s.contains("user"));
        assert!(s.contains("postgres"));
        assert!(s.contains("database"));
        assert!(s.contains("test"));
        let len = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, buf.len());
    }

    #[tokio::test]
    async fn write_query_produces_valid_message() {
        let mut buf = Vec::new();
        write_query(&mut buf, "SELECT 1").await.unwrap();
        assert_eq!(buf[0], b'Q');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert!(buf[5..].starts_with(b"SELECT 1"));
        assert_eq!(buf[buf.len() - 1], 0);
    }

    #[tokio::test]
    async fn write_password_message_produces_valid_message() {
        let mut buf = Vec::new();
        write_password_message(&mut buf, b"secret").await.unwrap();
        assert_eq!(buf[0], b'p');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"secret");
    }

    #[tokio::test]
    async fn write_copy_data_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_data(&mut buf, b"payload").await.unwrap();
        assert_eq!(buf[0], b'd');
        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
        assert_eq!(len, buf.len() - 1);
        assert_eq!(&buf[5..], b"payload");
    }

    #[tokio::test]
    async fn write_copy_done_produces_valid_message() {
        let mut buf = Vec::new();
        write_copy_done(&mut buf).await.unwrap();
        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], b'c');
        assert_eq!(&buf[1..5], &4i32.to_be_bytes());
    }

    #[test]
    fn backend_message_helper_methods() {
        let error = BackendMessage {
            tag: b'E',
            payload: Bytes::new(),
        };
        assert!(error.is_error());
        assert!(!error.is_ready_for_query());
        let ready = BackendMessage {
            tag: b'Z',
            payload: Bytes::from_static(b"I"),
        };
        assert!(ready.is_ready_for_query());
        assert!(!ready.is_error());
        let copy_both = BackendMessage {
            tag: b'W',
            payload: Bytes::new(),
        };
        assert!(copy_both.is_copy_both_response());
        let copy_data = BackendMessage {
            tag: b'd',
            payload: Bytes::new(),
        };
        assert!(copy_data.is_copy_data());
        let auth = BackendMessage {
            tag: b'R',
            payload: Bytes::new(),
        };
        assert!(auth.is_auth_request());
    }
}