Skip to main content

rns_core/buffer/
mod.rs

1pub mod types;
2
3use alloc::vec::Vec;
4
5use crate::constants::STREAM_DATA_OVERHEAD;
6#[cfg(test)]
7use crate::constants::STREAM_ID_MAX;
8
9pub use types::{BufferError, Compressor, DecompressError, NoopCompressor, StreamId};
10
11/// Stream data message: 2-byte header + data.
12///
13/// Header format: `(stream_id & 0x3FFF) | (eof << 15) | (compressed << 14)`
14#[derive(Debug, Clone, PartialEq)]
15pub struct StreamDataMessage {
16    pub stream_id: StreamId,
17    pub compressed: bool,
18    pub eof: bool,
19    pub data: Vec<u8>,
20}
21
22impl StreamDataMessage {
23    /// Create a new stream data message.
24    pub fn new(stream_id: StreamId, data: Vec<u8>, eof: bool, compressed: bool) -> Self {
25        StreamDataMessage {
26            stream_id,
27            compressed,
28            eof,
29            data,
30        }
31    }
32
33    /// Pack the message: `[header:2 BE][data]`.
34    pub fn pack(&self) -> Vec<u8> {
35        let mut header_val: u16 = self.stream_id & 0x3FFF;
36        if self.eof {
37            header_val |= 0x8000;
38        }
39        if self.compressed {
40            header_val |= 0x4000;
41        }
42
43        let mut packed = Vec::with_capacity(2 + self.data.len());
44        packed.extend_from_slice(&header_val.to_be_bytes());
45        packed.extend_from_slice(&self.data);
46        packed
47    }
48
49    /// Unpack from raw bytes (decompresses if compressed flag is set).
50    pub fn unpack(raw: &[u8], compressor: &dyn Compressor) -> Result<Self, BufferError> {
51        Self::unpack_bounded(raw, compressor, usize::MAX)
52    }
53
54    /// Unpack from raw bytes with an explicit decompressed size limit.
55    pub fn unpack_bounded(
56        raw: &[u8],
57        compressor: &dyn Compressor,
58        max_decompressed_size: usize,
59    ) -> Result<Self, BufferError> {
60        if raw.len() < 2 {
61            return Err(BufferError::InvalidData);
62        }
63
64        let header = u16::from_be_bytes([raw[0], raw[1]]);
65        let eof = (header & 0x8000) != 0;
66        let compressed = (header & 0x4000) != 0;
67        let stream_id = header & 0x3FFF;
68
69        let mut data = raw[2..].to_vec();
70
71        if compressed {
72            data = compressor
73                .decompress_bounded(&data, max_decompressed_size)
74                .map_err(|_| BufferError::DecompressionFailed)?;
75        }
76
77        Ok(StreamDataMessage {
78            stream_id,
79            compressed,
80            eof,
81            data,
82        })
83    }
84
85    /// Maximum data length for a given link MDU.
86    pub fn max_data_len(link_mdu: usize) -> usize {
87        link_mdu.saturating_sub(STREAM_DATA_OVERHEAD)
88    }
89}
90
91/// Chunks data into StreamDataMessages.
92pub struct BufferWriter {
93    stream_id: StreamId,
94    closed: bool,
95}
96
97impl BufferWriter {
98    pub fn new(stream_id: StreamId) -> Self {
99        BufferWriter {
100            stream_id,
101            closed: false,
102        }
103    }
104
105    /// Write data → one or more StreamDataMessages.
106    ///
107    /// Tries compression if data > 32 bytes and compression reduces size.
108    pub fn write(
109        &mut self,
110        data: &[u8],
111        link_mdu: usize,
112        compressor: &dyn Compressor,
113    ) -> Vec<StreamDataMessage> {
114        if self.closed || data.is_empty() {
115            return Vec::new();
116        }
117
118        let max_data = StreamDataMessage::max_data_len(link_mdu);
119        if max_data == 0 {
120            return Vec::new();
121        }
122
123        let mut messages = Vec::new();
124        let mut offset = 0;
125
126        while offset < data.len() {
127            let end = (offset + max_data).min(data.len());
128            let chunk = &data[offset..end];
129
130            // Try compression for larger chunks
131            let (msg_data, compressed) = if chunk.len() > 32 {
132                if let Some(compressed_data) = compressor.compress(chunk) {
133                    if compressed_data.len() < chunk.len() && compressed_data.len() <= max_data {
134                        (compressed_data, true)
135                    } else {
136                        (chunk.to_vec(), false)
137                    }
138                } else {
139                    (chunk.to_vec(), false)
140                }
141            } else {
142                (chunk.to_vec(), false)
143            };
144
145            messages.push(StreamDataMessage::new(
146                self.stream_id,
147                msg_data,
148                false,
149                compressed,
150            ));
151
152            offset = end;
153        }
154
155        messages
156    }
157
158    /// Signal EOF → final StreamDataMessage with eof=true.
159    pub fn close(&mut self) -> StreamDataMessage {
160        self.closed = true;
161        StreamDataMessage::new(self.stream_id, Vec::new(), true, false)
162    }
163
164    pub fn is_closed(&self) -> bool {
165        self.closed
166    }
167}
168
169/// Reassembles a stream from messages.
170pub struct BufferReader {
171    stream_id: StreamId,
172    buffer: Vec<u8>,
173    eof: bool,
174}
175
176impl BufferReader {
177    pub fn new(stream_id: StreamId) -> Self {
178        BufferReader {
179            stream_id,
180            buffer: Vec::new(),
181            eof: false,
182        }
183    }
184
185    /// Receive a stream data message.
186    pub fn receive(&mut self, msg: &StreamDataMessage) {
187        if msg.stream_id != self.stream_id {
188            return;
189        }
190        if !msg.data.is_empty() {
191            self.buffer.extend_from_slice(&msg.data);
192        }
193        if msg.eof {
194            self.eof = true;
195        }
196    }
197
198    /// Read up to `max_bytes` from the buffer.
199    pub fn read(&mut self, max_bytes: usize) -> Vec<u8> {
200        let n = max_bytes.min(self.buffer.len());
201        let data: Vec<u8> = self.buffer.drain(..n).collect();
202        data
203    }
204
205    /// Number of bytes available to read.
206    pub fn available(&self) -> usize {
207        self.buffer.len()
208    }
209
210    /// Whether EOF has been received.
211    pub fn is_eof(&self) -> bool {
212        self.eof
213    }
214
215    /// Whether all data has been consumed (EOF received and buffer empty).
216    pub fn is_done(&self) -> bool {
217        self.eof && self.buffer.is_empty()
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_roundtrip() {
        let msg = StreamDataMessage::new(42, b"hello".to_vec(), false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 42);
        assert_eq!(unpacked.data, b"hello");
        assert!(!unpacked.eof);
        assert!(!unpacked.compressed);
    }

    #[test]
    fn test_pack_unpack_eof() {
        let msg = StreamDataMessage::new(0, Vec::new(), true, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 0);
        assert!(unpacked.eof);
        assert!(unpacked.data.is_empty());
    }

    #[test]
    fn test_header_bit_layout() {
        // stream_id = 0x1234, eof = true, compressed = true
        let msg = StreamDataMessage::new(0x1234, vec![0xFF], true, true);
        let packed = msg.pack();
        let header = u16::from_be_bytes([packed[0], packed[1]]);
        assert_eq!(header & 0x3FFF, 0x1234);
        assert!(header & 0x8000 != 0); // eof
        assert!(header & 0x4000 != 0); // compressed
    }

    #[test]
    fn test_max_stream_id() {
        let msg = StreamDataMessage::new(STREAM_ID_MAX, vec![0x42], false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, STREAM_ID_MAX);
    }

    #[test]
    fn test_stream_id_overflow() {
        // If stream_id > STREAM_ID_MAX, only the lower 14 bits are used.
        let msg = StreamDataMessage::new(0xFFFF, vec![], false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 0x3FFF);
    }

    #[test]
    fn test_unpack_too_short() {
        assert_eq!(
            StreamDataMessage::unpack(&[0x00], &NoopCompressor),
            Err(BufferError::InvalidData)
        );
    }

    #[test]
    fn test_max_data_len() {
        let mdl = StreamDataMessage::max_data_len(431);
        assert_eq!(mdl, 431 - STREAM_DATA_OVERHEAD);
    }

    #[test]
    fn test_writer_single_chunk() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100];
        let msgs = writer.write(&data, 431, &NoopCompressor);
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].data, data);
        assert_eq!(msgs[0].stream_id, 1);
        assert!(!msgs[0].eof);
    }

    #[test]
    fn test_writer_chunking() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 1000];
        // Use a small MDU to force multiple chunks.
        let msgs = writer.write(&data, 50, &NoopCompressor);
        let max_data = StreamDataMessage::max_data_len(50);
        assert!(msgs.len() > 1);

        // Total data equals the original.
        let total: Vec<u8> = msgs.iter().flat_map(|m| m.data.clone()).collect();
        assert_eq!(total, data);

        // Each chunk is at most max_data.
        for msg in &msgs {
            assert!(msg.data.len() <= max_data);
        }
    }

    #[test]
    fn test_writer_mdu_without_payload_room() {
        let mut writer = BufferWriter::new(1);
        assert!(writer
            .write(b"data", STREAM_DATA_OVERHEAD, &NoopCompressor)
            .is_empty());
        assert!(writer.write(b"data", 0, &NoopCompressor).is_empty());
    }

    #[test]
    fn test_writer_close() {
        let mut writer = BufferWriter::new(5);
        let msg = writer.close();
        assert!(msg.eof);
        assert!(msg.data.is_empty());
        assert_eq!(msg.stream_id, 5);
        assert!(writer.is_closed());
    }

    #[test]
    fn test_writer_no_write_after_close() {
        let mut writer = BufferWriter::new(1);
        writer.close();
        let msgs = writer.write(b"test", 431, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_reader_reassembly() {
        let mut reader = BufferReader::new(1);
        let msg1 = StreamDataMessage::new(1, b"hello ".to_vec(), false, false);
        let msg2 = StreamDataMessage::new(1, b"world".to_vec(), false, false);
        let eof = StreamDataMessage::new(1, Vec::new(), true, false);

        reader.receive(&msg1);
        reader.receive(&msg2);
        assert_eq!(reader.available(), 11);
        assert!(!reader.is_eof());

        reader.receive(&eof);
        assert!(reader.is_eof());

        let data = reader.read(100);
        assert_eq!(data, b"hello world");
        assert!(reader.is_done());
    }

    #[test]
    fn test_reader_partial_read() {
        let mut reader = BufferReader::new(1);
        let msg = StreamDataMessage::new(1, b"abcdefgh".to_vec(), false, false);
        reader.receive(&msg);

        let first = reader.read(4);
        assert_eq!(first, b"abcd");
        assert_eq!(reader.available(), 4);

        let rest = reader.read(100);
        assert_eq!(rest, b"efgh");
        assert_eq!(reader.available(), 0);
    }

    #[test]
    fn test_reader_ignores_wrong_stream() {
        let mut reader = BufferReader::new(1);
        let msg = StreamDataMessage::new(2, b"wrong".to_vec(), false, false);
        reader.receive(&msg);
        assert_eq!(reader.available(), 0);
    }

    #[test]
    fn test_reader_buffers_payload_of_eof_message() {
        let mut reader = BufferReader::new(2);
        reader.receive(&StreamDataMessage::new(2, b"tail".to_vec(), true, false));
        assert!(reader.is_eof());
        assert!(!reader.is_done());
        assert_eq!(reader.read(10), b"tail");
        assert!(reader.is_done());
    }

    #[test]
    fn test_writer_empty_data() {
        let mut writer = BufferWriter::new(1);
        let msgs = writer.write(&[], 431, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_writer_pack_unpack_reader_roundtrip() {
        let data: Vec<u8> = (0..1000u32).map(|i| (i % 251) as u8).collect();
        let mut writer = BufferWriter::new(7);
        let mut reader = BufferReader::new(7);

        // Small MDU forces several chunks; the close message terminates the stream.
        let mut wire: Vec<Vec<u8>> = writer
            .write(&data, 50, &NoopCompressor)
            .iter()
            .map(StreamDataMessage::pack)
            .collect();
        wire.push(writer.close().pack());

        for packet in &wire {
            let msg = StreamDataMessage::unpack(packet, &NoopCompressor).unwrap();
            reader.receive(&msg);
        }

        assert!(reader.is_eof());
        assert_eq!(reader.read(usize::MAX), data);
        assert!(reader.is_done());
    }

    // Test with a mock compressor.
    struct HalfCompressor;
    impl Compressor for HalfCompressor {
        fn compress(&self, data: &[u8]) -> Option<Vec<u8>> {
            // "Compress" by taking the first half.
            Some(data[..data.len() / 2].to_vec())
        }
        fn decompress(&self, data: &[u8]) -> Option<Vec<u8>> {
            // "Decompress" by doubling.
            let mut out = data.to_vec();
            out.extend_from_slice(data);
            Some(out)
        }
        fn decompress_bounded(
            &self,
            data: &[u8],
            max_output_size: usize,
        ) -> Result<Vec<u8>, DecompressError> {
            let out = self.decompress(data).ok_or(DecompressError::InvalidData)?;
            if out.len() > max_output_size {
                return Err(DecompressError::TooLarge);
            }
            Ok(out)
        }
    }

    #[test]
    fn test_compression_used_when_smaller() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100]; // > 32 bytes, compression will be tried
        let msgs = writer.write(&data, 431, &HalfCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(msgs[0].compressed);
        assert_eq!(msgs[0].data.len(), 50); // half
    }

    #[test]
    fn test_compression_skipped_at_threshold() {
        // Chunks of 32 bytes or fewer are never compressed.
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 32];
        let msgs = writer.write(&data, 431, &HalfCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(!msgs[0].compressed);
        assert_eq!(msgs[0].data, data);
    }

    #[test]
    fn test_compressed_roundtrip_through_reader() {
        // HalfCompressor round-trips uniform, even-length data exactly.
        let data = vec![0x42; 100];
        let mut writer = BufferWriter::new(3);
        let mut reader = BufferReader::new(3);
        for msg in writer.write(&data, 431, &HalfCompressor) {
            assert!(msg.compressed);
            let unpacked = StreamDataMessage::unpack(&msg.pack(), &HalfCompressor).unwrap();
            reader.receive(&unpacked);
        }
        assert_eq!(reader.read(usize::MAX), data);
    }

    #[test]
    fn test_compressed_unpack() {
        let msg = StreamDataMessage::new(1, b"compressed".to_vec(), false, true);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &HalfCompressor).unwrap();
        // HalfCompressor doubles data on decompress.
        assert_eq!(unpacked.data, b"compressedcompressed");
    }

    #[test]
    fn test_compressed_unpack_bounded_rejects_oversized_output() {
        let msg = StreamDataMessage::new(1, b"compressed".to_vec(), false, true);
        let packed = msg.pack();
        assert_eq!(
            StreamDataMessage::unpack_bounded(&packed, &HalfCompressor, 8),
            Err(BufferError::DecompressionFailed)
        );
    }

    #[test]
    fn test_compressed_unpack_bounded_accepts_exact_limit() {
        let msg = StreamDataMessage::new(1, b"compressed".to_vec(), false, true);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack_bounded(&packed, &HalfCompressor, 20).unwrap();
        assert_eq!(unpacked.data, b"compressedcompressed");
    }
}