use bytes::{BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::error::ReplicationError;
/// Number of bytes `read_message` needs before it can decode a frame's length:
/// one leading byte plus the four-byte big-endian length field at `buf[1..5]`.
const HEADER_LEN: usize = 5;
/// Largest length-field value `read_message` accepts (128 MiB); a frame that
/// declares more than this is rejected with a protocol error.
const MAX_MESSAGE_LEN: usize = 128 * 1024 * 1024;
/// Reads one complete framed message from `reader`, accumulating bytes in `buf`.
///
/// A frame is one leading byte followed by a big-endian `i32` length. The length
/// counts itself and the payload but not the leading byte, so a full frame spans
/// `1 + length` bytes. The returned [`Bytes`] holds the whole frame (leading byte
/// and length field included). Any bytes already received past the end of the
/// frame stay in `buf` for the next call.
///
/// # Errors
///
/// * A protocol error when the declared length is below 4 or above
///   `MAX_MESSAGE_LEN`.
/// * A transient-connection error when the underlying read fails, or when the
///   reader reports end-of-stream before a full frame has arrived.
pub async fn read_message<R: AsyncRead + Unpin>(
    reader: &mut R,
    buf: &mut BytesMut,
) -> Result<Bytes, ReplicationError> {
    loop {
        // The length can only be decoded once the leading byte and all four
        // length bytes are buffered.
        if buf.len() >= HEADER_LEN {
            let body_len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]);
            if body_len < 4 {
                return Err(ReplicationError::protocol(format!(
                    "invalid message length {body_len} (must be >= 4)"
                )));
            }

            // Non-negative at this point, so the cast is lossless.
            let declared_len = body_len as usize;
            if declared_len > MAX_MESSAGE_LEN {
                return Err(ReplicationError::protocol(format!(
                    "message length {} exceeds maximum allowed {} bytes",
                    declared_len, MAX_MESSAGE_LEN
                )));
            }

            // The leading byte is not covered by the length field.
            let frame_len = declared_len + 1;
            let missing = frame_len.saturating_sub(buf.len());
            if missing == 0 {
                return Ok(buf.split_to(frame_len).freeze());
            }

            // Make room for the remainder of the frame up front.
            buf.reserve(missing);
        }

        let n = match reader.read_buf(buf).await {
            Ok(n) => n,
            Err(e) => {
                return Err(ReplicationError::transient_connection(format!(
                    "read error: {e}"
                )));
            }
        };

        // Zero bytes read means the peer closed the stream mid-frame (or before
        // sending anything).
        if n == 0 {
            return Err(ReplicationError::transient_connection(
                "connection closed by server".to_string(),
            ));
        }
    }
}
/// Reads exactly one byte from `reader` and returns it.
///
/// # Errors
///
/// Returns a transient-connection error if the byte cannot be read (including
/// when the stream ends first).
pub async fn read_byte<R: AsyncRead + Unpin>(reader: &mut R) -> Result<u8, ReplicationError> {
    let mut single = [0u8; 1];
    match reader.read_exact(&mut single).await {
        Ok(_) => Ok(single[0]),
        Err(e) => Err(ReplicationError::transient_connection(format!(
            "read_byte error: {e}"
        ))),
    }
}
/// Writes the whole of `data` to `writer`.
///
/// # Errors
///
/// Any I/O failure is reported as a transient-connection error.
pub async fn write_all<W: AsyncWrite + Unpin>(
    writer: &mut W,
    data: &[u8],
) -> Result<(), ReplicationError> {
    match writer.write_all(data).await {
        Ok(()) => Ok(()),
        Err(e) => Err(ReplicationError::transient_connection(format!(
            "write error: {e}"
        ))),
    }
}
/// Flushes `writer`.
///
/// # Errors
///
/// Any I/O failure is reported as a transient-connection error.
pub async fn flush<W: AsyncWrite + Unpin>(writer: &mut W) -> Result<(), ReplicationError> {
    if let Err(e) = writer.flush().await {
        return Err(ReplicationError::transient_connection(format!(
            "flush error: {e}"
        )));
    }
    Ok(())
}
/// Builds a startup message carrying the given `(key, value)` parameters.
///
/// Layout, all integers big-endian:
///
/// 1. `i32` total length of the message, including these four bytes;
/// 2. `i32` protocol field `196608` (`0x0003_0000`);
/// 3. every key and value as a NUL-terminated string, in the order given;
/// 4. one final NUL byte closing the parameter list.
///
/// The total length is written through an `as i32` cast, so a message larger
/// than `i32::MAX` bytes would carry a truncated length.
pub fn build_startup_message(params: &[(&str, &str)]) -> BytesMut {
    // Each key and each value is followed by its own NUL terminator.
    let params_len: usize = params
        .iter()
        .map(|(key, val)| key.len() + 1 + val.len() + 1)
        .sum();
    // length field + protocol field + parameters + list terminator
    let total_len = 4 + 4 + params_len + 1;

    let mut msg = BytesMut::with_capacity(total_len);
    msg.put_i32(total_len as i32);
    msg.put_i32(196608);
    for (key, val) in params {
        for text in [key, val] {
            msg.put_slice(text.as_bytes());
            msg.put_u8(0);
        }
    }
    msg.put_u8(0);
    msg
}
/// Builds the fixed eight-byte SSL request: a big-endian `i32` length of `8`
/// followed by the big-endian `i32` request code `80877103`.
pub fn build_ssl_request() -> BytesMut {
    let mut request = BytesMut::with_capacity(8);
    for word in [8_i32, 80877103] {
        request.put_i32(word);
    }
    request
}
/// Builds a `b'Q'` message whose payload is `sql` as a NUL-terminated string.
///
/// The big-endian `i32` length after the tag byte covers the length field
/// itself, the SQL text and the trailing NUL, but not the tag. It is written
/// through an `as i32` cast, so text longer than `i32::MAX` would yield a
/// truncated length.
pub fn build_query_message(sql: &str) -> BytesMut {
    let text = sql.as_bytes();
    // 4 bytes of length field + text + NUL terminator
    let body_len = 4 + text.len() + 1;

    let mut msg = BytesMut::with_capacity(1 + body_len);
    msg.put_u8(b'Q');
    msg.put_i32(body_len as i32);
    msg.put_slice(text);
    msg.put_u8(0);
    msg
}
/// Builds a `b'p'` message whose payload is `password` as a NUL-terminated
/// string.
///
/// The big-endian `i32` length after the tag byte covers the length field
/// itself, the password bytes and the trailing NUL, but not the tag. It is
/// written through an `as i32` cast.
pub fn build_password_message(password: &str) -> BytesMut {
    let secret = password.as_bytes();
    // 4 bytes of length field + password + NUL terminator
    let body_len = 4 + secret.len() + 1;

    let mut msg = BytesMut::with_capacity(1 + body_len);
    msg.put_u8(b'p');
    msg.put_i32(body_len as i32);
    msg.put_slice(secret);
    msg.put_u8(0);
    msg
}
/// Builds a `b'd'` message wrapping `payload` verbatim (no terminator is added).
///
/// The big-endian `i32` length after the tag byte covers the length field
/// itself plus the payload, but not the tag. It is written through an `as i32`
/// cast.
pub fn build_copy_data(payload: &[u8]) -> BytesMut {
    // 4 bytes of length field + raw payload
    let body_len = payload.len() + 4;

    let mut msg = BytesMut::with_capacity(body_len + 1);
    msg.put_u8(b'd');
    msg.put_i32(body_len as i32);
    msg.put_slice(payload);
    msg
}
/// Builds the five-byte `b'c'` message: the tag followed by a big-endian `i32`
/// length of `4` (the length field alone, i.e. no payload).
pub fn build_copy_done() -> BytesMut {
    let mut msg = BytesMut::with_capacity(5);
    msg.put_u8(b'c');
    msg.put_slice(&4_i32.to_be_bytes());
    msg
}
/// Builds the five-byte `b'X'` message: the tag followed by a big-endian `i32`
/// length of `4` (the length field alone, i.e. no payload).
pub fn build_terminate() -> BytesMut {
    let mut msg = BytesMut::with_capacity(5);
    msg.put_u8(b'X');
    msg.put_slice(&4_i32.to_be_bytes());
    msg
}
/// Extracts a NUL-terminated string from the front of `data`.
///
/// Returns the text before the first NUL byte together with the number of bytes
/// consumed, which includes the NUL. When `data` contains no NUL, the whole
/// slice is treated as the string and the consumed count is `data.len()`.
///
/// Bytes that are not valid UTF-8 produce an empty string; the consumed count is
/// still reported as described above.
pub fn read_cstring(data: &[u8]) -> (&str, usize) {
    let (raw, consumed) = match memchr::memchr(0, data) {
        // Skip past the terminator as well.
        Some(null_pos) => (&data[..null_pos], null_pos + 1),
        None => (data, data.len()),
    };
    (std::str::from_utf8(raw).unwrap_or(""), consumed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncWriteExt, DuplexStream};

    /// Decodes the big-endian `i32` stored in the four bytes starting at `offset`.
    fn be_i32_at(bytes: &[u8], offset: usize) -> i32 {
        i32::from_be_bytes(bytes[offset..offset + 4].try_into().unwrap())
    }

    /// Spawns a task that writes `bytes` to `server`, flushes, and then drops the
    /// stream.
    fn feed(mut server: DuplexStream, bytes: Vec<u8>) {
        tokio::spawn(async move {
            server.write_all(&bytes).await.unwrap();
            server.flush().await.unwrap();
        });
    }

    // ---- message builders ----

    #[test]
    fn test_build_query_message() {
        let encoded = build_query_message("IDENTIFY_SYSTEM");
        assert_eq!(encoded[0], b'Q');
        // length field + 15 bytes of text + NUL terminator
        assert_eq!(be_i32_at(&encoded, 1) as usize, 4 + 15 + 1);
        assert_eq!(encoded[encoded.len() - 1], 0);
    }

    #[test]
    fn test_build_ssl_request() {
        let encoded = build_ssl_request();
        assert_eq!(encoded.len(), 8);
        assert_eq!(be_i32_at(&encoded, 0), 8);
        assert_eq!(be_i32_at(&encoded, 4), 80877103);
    }

    #[test]
    fn test_build_startup_message() {
        let encoded = build_startup_message(&[("user", "test"), ("database", "mydb")]);
        // The leading length field covers the entire message.
        assert_eq!(be_i32_at(&encoded, 0) as usize, encoded.len());
        assert_eq!(be_i32_at(&encoded, 4), 196608);
    }

    #[test]
    fn test_build_password_message() {
        let encoded = build_password_message("secret");
        assert_eq!(encoded[0], b'p');
        // The length field excludes the tag byte.
        assert_eq!(be_i32_at(&encoded, 1) as usize, encoded.len() - 1);
        assert_eq!(&encoded[5..11], b"secret");
        assert_eq!(encoded[11], 0);
    }

    #[test]
    fn test_build_copy_data() {
        let payload = b"hello world";
        let encoded = build_copy_data(payload);
        assert_eq!(encoded[0], b'd');
        assert_eq!(be_i32_at(&encoded, 1) as usize, 4 + payload.len());
        assert_eq!(&encoded[5..], payload);
    }

    #[test]
    fn test_build_copy_done() {
        let encoded = build_copy_done();
        assert_eq!(encoded.len(), 5);
        assert_eq!(encoded[0], b'c');
        assert_eq!(be_i32_at(&encoded, 1), 4);
    }

    #[test]
    fn test_build_terminate() {
        let encoded = build_terminate();
        assert_eq!(encoded.len(), 5);
        assert_eq!(encoded[0], b'X');
        assert_eq!(be_i32_at(&encoded, 1), 4);
    }

    // ---- read_cstring ----

    #[test]
    fn test_read_cstring() {
        let (text, consumed) = read_cstring(b"hello\0world");
        assert_eq!(text, "hello");
        assert_eq!(consumed, 6);
    }

    #[test]
    fn test_read_cstring_no_null_terminator() {
        let (text, _) = read_cstring(b"no null here");
        assert_eq!(text, "no null here");
    }

    // ---- read_message: successful framing ----

    #[tokio::test]
    async fn test_read_message_single_complete() {
        let (mut client, server) = tokio::io::duplex(8192);
        feed(server, vec![b'Z', 0, 0, 0, 5, b'I']);

        let mut buf = BytesMut::new();
        let frame = read_message(&mut client, &mut buf).await.unwrap();
        assert_eq!(frame[0], b'Z');
        assert_eq!(frame.len(), 6);
        assert_eq!(frame[5], b'I');
    }

    #[tokio::test]
    async fn test_read_message_two_messages_sequentially() {
        let (mut client, server) = tokio::io::duplex(8192);
        // Both frames go out in a single write.
        let mut wire = Vec::new();
        wire.extend_from_slice(&[b'Z', 0, 0, 0, 5, b'I']);
        wire.extend_from_slice(&[b'Z', 0, 0, 0, 5, b'T']);
        feed(server, wire);

        let mut buf = BytesMut::new();
        let first = read_message(&mut client, &mut buf).await.unwrap();
        assert_eq!(first[5], b'I');
        let second = read_message(&mut client, &mut buf).await.unwrap();
        assert_eq!(second[5], b'T');
    }

    #[tokio::test]
    async fn test_read_message_partial_then_complete() {
        let (mut client, server) = tokio::io::duplex(8192);
        // The first three bytes of the frame are already buffered...
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&[b'Z', 0, 0]);
        // ...and the remainder arrives over the stream.
        feed(server, vec![0, 5, b'I']);

        let frame = read_message(&mut client, &mut buf).await.unwrap();
        assert_eq!(frame[0], b'Z');
        assert_eq!(frame.len(), 6);
    }

    // ---- read_message: failure paths ----

    #[tokio::test]
    async fn test_read_message_connection_closed() {
        let (mut client, server) = tokio::io::duplex(8192);
        drop(server);

        let mut buf = BytesMut::new();
        let outcome = read_message(&mut client, &mut buf).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_read_message_negative_length() {
        let (mut client, _server) = tokio::io::duplex(8192);
        let mut buf = BytesMut::new();
        // 0xFFFF_FFFF decodes to -1 as an i32.
        buf.extend_from_slice(&[b'Z', 0xFF, 0xFF, 0xFF, 0xFF]);

        let outcome = read_message(&mut client, &mut buf).await;
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("invalid message length"),
            "Expected invalid length error, got: {err}"
        );
    }

    #[tokio::test]
    async fn test_read_message_body_len_too_small() {
        let (mut client, _server) = tokio::io::duplex(8192);
        let mut buf = BytesMut::new();
        // A length of 3 cannot even cover the length field itself.
        buf.extend_from_slice(&[b'Z', 0, 0, 0, 3]);

        let outcome = read_message(&mut client, &mut buf).await;
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("invalid message length"),
            "Expected invalid length error, got: {err}"
        );
    }

    #[tokio::test]
    async fn test_read_message_exceeds_max_len() {
        let (mut client, _server) = tokio::io::duplex(8192);
        let oversized: i32 = (MAX_MESSAGE_LEN as i32) + 1;
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&[b'Z']);
        buf.extend_from_slice(&oversized.to_be_bytes());

        let outcome = read_message(&mut client, &mut buf).await;
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("exceeds maximum"),
            "Expected max length error, got: {err}"
        );
    }

    // ---- read_byte ----

    #[tokio::test]
    async fn test_read_byte_ssl_response() {
        let (mut client, server) = tokio::io::duplex(8192);
        feed(server, vec![b'S']);

        let reply = read_byte(&mut client).await.unwrap();
        assert_eq!(reply, b'S');
    }
}