1use bytes::BytesMut;
4use tokio::io::{AsyncRead, AsyncReadExt};
5
6use crate::error::BoltError;
7
/// Largest payload a single chunk can carry, since the chunk length header is a
/// big-endian `u16`. Also used as the initial capacity of each reader's scratch buffer.
const MAX_CHUNK_SIZE: usize = u16::MAX as usize;

/// Default upper bound, in bytes, on the payload of one reassembled message (16 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;
13
14pub struct ChunkReader<R> {
19 reader: R,
20 buf: BytesMut,
21 max_message_size: usize,
22}
23
24impl<R: AsyncRead + Unpin> ChunkReader<R> {
25 pub fn new(reader: R) -> Self {
26 Self {
27 reader,
28 buf: BytesMut::with_capacity(MAX_CHUNK_SIZE),
29 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
30 }
31 }
32
33 pub fn set_max_message_size(&mut self, max_bytes: usize) {
38 self.max_message_size = max_bytes;
39 }
40
41 pub async fn read_message(&mut self) -> Result<BytesMut, BoltError> {
43 let mut message = BytesMut::new();
44
45 loop {
46 let mut header = [0u8; 2];
48 self.reader.read_exact(&mut header).await?;
49 let chunk_len = u16::from_be_bytes(header) as usize;
50
51 if chunk_len == 0 {
52 break;
54 }
55
56 if message.len() + chunk_len > self.max_message_size {
57 return Err(BoltError::Protocol(format!(
58 "message size exceeds limit of {} bytes",
59 self.max_message_size
60 )));
61 }
62
63 self.buf.resize(chunk_len, 0);
65 self.reader.read_exact(&mut self.buf[..chunk_len]).await?;
66 message.extend_from_slice(&self.buf[..chunk_len]);
67 }
68
69 Ok(message)
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a `ChunkReader` over an in-memory copy of `bytes`.
    fn reader_over(bytes: &[u8]) -> ChunkReader<Cursor<Vec<u8>>> {
        ChunkReader::new(Cursor::new(bytes.to_vec()))
    }

    #[tokio::test]
    async fn read_single_chunk_message() {
        let mut reader = reader_over(&[
            0x00, 0x03, // chunk header: 3 bytes
            0x01, 0x02, 0x03, // chunk body
            0x00, 0x00, // end-of-message marker
        ]);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0x01, 0x02, 0x03]);
    }

    #[tokio::test]
    async fn read_multi_chunk_message() {
        let mut reader = reader_over(&[
            0x00, 0x02, 0xAA, 0xBB, // first chunk
            0x00, 0x01, 0xCC, // second chunk
            0x00, 0x00, // end-of-message marker
        ]);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0xAA, 0xBB, 0xCC]);
    }

    #[tokio::test]
    async fn read_empty_message() {
        let mut reader = reader_over(&[0x00, 0x00]);
        let msg = reader.read_message().await.unwrap();
        assert!(msg.is_empty());
    }

    #[tokio::test]
    async fn read_message_exceeds_limit() {
        let mut reader = reader_over(&[
            0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // 4-byte chunk
            0x00, 0x00, // end-of-message marker
        ]);
        reader.set_max_message_size(2);
        let result = reader.read_message().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exceeds limit"), "unexpected error: {err}");
    }

    /// The limit is inclusive: a payload exactly at the cap is accepted.
    #[tokio::test]
    async fn read_message_at_exact_limit() {
        let mut reader = reader_over(&[0x00, 0x02, 0xAA, 0xBB, 0x00, 0x00]);
        reader.set_max_message_size(2);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0xAA, 0xBB]);
    }

    /// The limit applies to the running total, not to each chunk on its own.
    #[tokio::test]
    async fn read_message_limit_is_cumulative_across_chunks() {
        let mut reader = reader_over(&[
            0x00, 0x02, 0xAA, 0xBB, // fits (2 <= 3)
            0x00, 0x02, 0xCC, 0xDD, // pushes total to 4 > 3
            0x00, 0x00,
        ]);
        reader.set_max_message_size(3);
        let err = reader.read_message().await.unwrap_err().to_string();
        assert!(err.contains("exceeds limit"), "unexpected error: {err}");
    }

    /// Successive calls on one stream return successive messages, including empty ones.
    #[tokio::test]
    async fn read_consecutive_messages() {
        let mut reader = reader_over(&[
            0x00, 0x01, 0x11, 0x00, 0x00, // message 1
            0x00, 0x00, // message 2: empty
            0x00, 0x02, 0x22, 0x33, 0x00, 0x00, // message 3
        ]);
        assert_eq!(&reader.read_message().await.unwrap()[..], &[0x11]);
        assert!(reader.read_message().await.unwrap().is_empty());
        assert_eq!(&reader.read_message().await.unwrap()[..], &[0x22, 0x33]);
    }

    #[tokio::test]
    async fn empty_stream_is_error() {
        let mut reader = reader_over(&[]);
        assert!(reader.read_message().await.is_err());
    }

    #[tokio::test]
    async fn truncated_chunk_body_is_error() {
        // The header promises 3 bytes but only 1 follows.
        let mut reader = reader_over(&[0x00, 0x03, 0x01]);
        assert!(reader.read_message().await.is_err());
    }

    #[tokio::test]
    async fn missing_end_marker_is_error() {
        // A complete chunk, then EOF where the next header should be.
        let mut reader = reader_over(&[0x00, 0x01, 0xAB]);
        assert!(reader.read_message().await.is_err());
    }
}