Skip to main content

boltr/chunk/
reader.rs

1//! Reads chunked messages from an async byte stream.
2
3use bytes::BytesMut;
4use tokio::io::{AsyncRead, AsyncReadExt};
5
6use crate::error::BoltError;
7
/// Largest payload a single chunk can carry. The length prefix is a `u16`,
/// so this equals `u16::MAX` (65535).
const MAX_CHUNK_SIZE: usize = u16::MAX as usize;

/// Message size limit used until `set_max_message_size` is called: 16 MiB.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;
13
14/// Reads Bolt-chunked messages from an `AsyncRead` stream.
15///
16/// Each message consists of one or more chunks (2-byte big-endian length prefix
17/// followed by that many data bytes), terminated by a zero-length chunk (0x0000).
18pub struct ChunkReader<R> {
19    reader: R,
20    buf: BytesMut,
21    max_message_size: usize,
22}
23
24impl<R: AsyncRead + Unpin> ChunkReader<R> {
25    pub fn new(reader: R) -> Self {
26        Self {
27            reader,
28            buf: BytesMut::with_capacity(MAX_CHUNK_SIZE),
29            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
30        }
31    }
32
33    /// Sets the maximum allowed message size in bytes.
34    ///
35    /// Messages exceeding this limit will return a protocol error.
36    /// Default: 16 MiB.
37    pub fn set_max_message_size(&mut self, max_bytes: usize) {
38        self.max_message_size = max_bytes;
39    }
40
41    /// Reads a complete message (all chunks until the `0x0000` terminator).
42    pub async fn read_message(&mut self) -> Result<BytesMut, BoltError> {
43        let mut message = BytesMut::new();
44
45        loop {
46            // Read 2-byte chunk length.
47            let mut header = [0u8; 2];
48            self.reader.read_exact(&mut header).await?;
49            let chunk_len = u16::from_be_bytes(header) as usize;
50
51            if chunk_len == 0 {
52                // End of message.
53                break;
54            }
55
56            if message.len() + chunk_len > self.max_message_size {
57                return Err(BoltError::Protocol(format!(
58                    "message size exceeds limit of {} bytes",
59                    self.max_message_size
60                )));
61            }
62
63            // Read chunk data.
64            self.buf.resize(chunk_len, 0);
65            self.reader.read_exact(&mut self.buf[..chunk_len]).await?;
66            message.extend_from_slice(&self.buf[..chunk_len]);
67        }
68
69        Ok(message)
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[tokio::test]
    async fn read_single_chunk_message() {
        // One chunk of 3 bytes + terminator.
        let data: Vec<u8> = vec![
            0x00, 0x03, // chunk length = 3
            0x01, 0x02, 0x03, // data
            0x00, 0x00, // terminator
        ];
        let mut reader = ChunkReader::new(Cursor::new(data));
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0x01, 0x02, 0x03]);
    }

    #[tokio::test]
    async fn read_multi_chunk_message() {
        let data: Vec<u8> = vec![
            0x00, 0x02, 0xAA, 0xBB, // chunk 1: 2 bytes
            0x00, 0x01, 0xCC, // chunk 2: 1 byte
            0x00, 0x00, // terminator
        ];
        let mut reader = ChunkReader::new(Cursor::new(data));
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0xAA, 0xBB, 0xCC]);
    }

    #[tokio::test]
    async fn read_empty_message() {
        // Just a terminator (no data chunks).
        let data: Vec<u8> = vec![0x00, 0x00];
        let mut reader = ChunkReader::new(Cursor::new(data));
        let msg = reader.read_message().await.unwrap();
        assert!(msg.is_empty());
    }

    #[tokio::test]
    async fn read_consecutive_messages() {
        // Two messages back to back on one stream. Both results are held at
        // once to check that the first is not clobbered by the second read.
        let data: Vec<u8> = vec![
            0x00, 0x01, 0xAA, // message 1: 1-byte chunk
            0x00, 0x00, // message 1 terminator
            0x00, 0x02, 0xBB, 0xCC, // message 2: 2-byte chunk
            0x00, 0x00, // message 2 terminator
        ];
        let mut reader = ChunkReader::new(Cursor::new(data));
        let first = reader.read_message().await.unwrap();
        let second = reader.read_message().await.unwrap();
        assert_eq!(&first[..], &[0xAA]);
        assert_eq!(&second[..], &[0xBB, 0xCC]);
    }

    #[tokio::test]
    async fn read_message_exceeds_limit() {
        // A 4-byte chunk, but with a 2-byte limit.
        let data: Vec<u8> = vec![
            0x00, 0x04, // chunk length = 4
            0x01, 0x02, 0x03, 0x04, // data
            0x00, 0x00, // terminator
        ];
        let mut reader = ChunkReader::new(Cursor::new(data));
        reader.set_max_message_size(2);
        let result = reader.read_message().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exceeds limit"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn read_message_at_exact_limit() {
        // A payload exactly equal to the limit is accepted.
        let data: Vec<u8> = vec![
            0x00, 0x04, // chunk length = 4
            0x01, 0x02, 0x03, 0x04, // data
            0x00, 0x00, // terminator
        ];
        let mut reader = ChunkReader::new(Cursor::new(data));
        reader.set_max_message_size(4);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(&msg[..], &[0x01, 0x02, 0x03, 0x04]);
    }

    #[tokio::test]
    async fn read_message_limit_spans_chunks() {
        // Each chunk fits on its own, but together they exceed the limit.
        let data: Vec<u8> = vec![
            0x00, 0x02, 0x01, 0x02, // chunk 1: 2 bytes
            0x00, 0x02, 0x03, 0x04, // chunk 2: 2 bytes
            0x00, 0x00, // terminator
        ];
        let mut reader = ChunkReader::new(Cursor::new(data));
        reader.set_max_message_size(3);
        let result = reader.read_message().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exceeds limit"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn truncated_chunk_is_error() {
        // Header announces 5 bytes, but the stream ends after 2.
        let data: Vec<u8> = vec![0x00, 0x05, 0x01, 0x02];
        let mut reader = ChunkReader::new(Cursor::new(data));
        assert!(reader.read_message().await.is_err());
    }

    #[tokio::test]
    async fn missing_terminator_is_error() {
        // A complete chunk, then EOF before the 0x0000 terminator.
        let data: Vec<u8> = vec![0x00, 0x01, 0xAA];
        let mut reader = ChunkReader::new(Cursor::new(data));
        assert!(reader.read_message().await.is_err());
    }
}