use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt};
use crate::error::BoltError;
/// Largest payload a single chunk can carry: the length header is a `u16`.
const MAX_CHUNK_SIZE: usize = u16::MAX as usize;
/// Default upper bound on a reassembled message's payload (16 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Reassembles chunked messages from an async byte source.
///
/// Each message arrives as a sequence of length-prefixed chunks and ends
/// with an empty chunk; see `read_message`.
pub struct ChunkReader<R> {
    /// Source of the chunked byte stream.
    reader: R,
    /// Maximum accepted payload size, in bytes, for one reassembled message.
    max_message_size: usize,
    /// Working buffer owned by the reader and reused between reads.
    buf: BytesMut,
}
impl<R: AsyncRead + Unpin> ChunkReader<R> {
    /// Creates a reader over `reader` with the default message-size limit
    /// (`DEFAULT_MAX_MESSAGE_SIZE`, 16 MiB).
    ///
    /// The internal buffer is pre-sized to hold one maximum-size chunk.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            buf: BytesMut::with_capacity(MAX_CHUNK_SIZE),
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }

    /// Sets the maximum total payload size, in bytes, that `read_message`
    /// accepts for a single message. Takes effect on the next read.
    pub fn set_max_message_size(&mut self, max_bytes: usize) {
        self.max_message_size = max_bytes;
    }

    /// Reads one chunked message and returns its concatenated payload.
    ///
    /// The stream is consumed as a series of chunks. Each chunk is a 2-byte
    /// big-endian length followed by that many payload bytes. A zero-length
    /// header ends the message, so a bare `0x00 0x00` yields an empty message.
    ///
    /// # Errors
    ///
    /// - I/O errors from the underlying reader are converted into
    ///   `BoltError` via `?`. This includes hitting end-of-stream before a
    ///   header or payload is complete.
    /// - `BoltError::Protocol` if the accumulated payload would exceed the
    ///   configured maximum. The check runs before the offending chunk's
    ///   payload is read, so that payload stays unconsumed in the stream and
    ///   the reader is no longer aligned on a chunk boundary. Callers should
    ///   treat the underlying connection as unusable after this error.
    pub async fn read_message(&mut self) -> Result<BytesMut, BoltError> {
        // Drop anything left behind by a previous read that failed midway.
        self.buf.clear();
        loop {
            let mut header = [0u8; 2];
            self.reader.read_exact(&mut header).await?;
            let chunk_len = usize::from(u16::from_be_bytes(header));
            if chunk_len == 0 {
                break;
            }

            let filled = self.buf.len();
            // Compare against the remaining budget instead of computing
            // `filled + chunk_len`, so the check cannot overflow even when
            // the limit is close to `usize::MAX`.
            if chunk_len > self.max_message_size.saturating_sub(filled) {
                return Err(BoltError::Protocol(format!(
                    "message size exceeds limit of {} bytes",
                    self.max_message_size
                )));
            }

            // Read the payload directly into the message buffer rather than
            // staging it in a scratch buffer and copying it a second time.
            self.buf.resize(filled + chunk_len, 0);
            self.reader.read_exact(&mut self.buf[filled..]).await?;
        }
        // Hand the assembled bytes to the caller. `self.buf` is left empty
        // but keeps any spare capacity for the next message. The returned
        // handle may share the underlying allocation.
        Ok(self.buf.split())
    }
}
#[cfg(test)]
mod tests {
    //! Wire-level tests for `ChunkReader::read_message` using in-memory input.
    use super::*;
    use std::io::Cursor;

    /// Wraps an owned copy of `bytes` in a `ChunkReader`.
    fn reader_over(bytes: &[u8]) -> ChunkReader<Cursor<Vec<u8>>> {
        ChunkReader::new(Cursor::new(bytes.to_vec()))
    }

    #[tokio::test]
    async fn read_single_chunk_message() {
        let mut reader = reader_over(&[
            0x00, 0x03, 0x01, 0x02, 0x03, // chunk: 3 bytes
            0x00, 0x00, // end of message
        ]);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0x01, 0x02, 0x03]);
    }

    #[tokio::test]
    async fn read_multi_chunk_message() {
        let mut reader = reader_over(&[
            0x00, 0x02, 0xAA, 0xBB, // chunk: 2 bytes
            0x00, 0x01, 0xCC, // chunk: 1 byte
            0x00, 0x00, // end of message
        ]);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0xAA, 0xBB, 0xCC]);
    }

    #[tokio::test]
    async fn read_empty_message() {
        // A lone terminator is a message with no payload.
        let mut reader = reader_over(&[0x00, 0x00]);
        let msg = reader.read_message().await.unwrap();
        assert!(msg.is_empty());
    }

    #[tokio::test]
    async fn read_message_exceeds_limit() {
        let mut reader = reader_over(&[
            0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // chunk: 4 bytes (over the limit)
            0x00, 0x00, // end of message
        ]);
        reader.set_max_message_size(2);
        let err = reader.read_message().await.unwrap_err().to_string();
        assert!(err.contains("exceeds limit"), "unexpected error: {err}");
    }
}