Skip to main content

rns_core/buffer/
mod.rs

1pub mod types;
2
3use alloc::vec::Vec;
4
5use crate::constants::STREAM_DATA_OVERHEAD;
6#[cfg(test)]
7use crate::constants::STREAM_ID_MAX;
8
9pub use types::{BufferError, Compressor, NoopCompressor, StreamId};
10
11/// Stream data message: 2-byte header + data.
12///
13/// Header format: `(stream_id & 0x3FFF) | (eof << 15) | (compressed << 14)`
14#[derive(Debug, Clone, PartialEq)]
15pub struct StreamDataMessage {
16    pub stream_id: StreamId,
17    pub compressed: bool,
18    pub eof: bool,
19    pub data: Vec<u8>,
20}
21
22impl StreamDataMessage {
23    /// Create a new stream data message.
24    pub fn new(stream_id: StreamId, data: Vec<u8>, eof: bool, compressed: bool) -> Self {
25        StreamDataMessage {
26            stream_id,
27            compressed,
28            eof,
29            data,
30        }
31    }
32
33    /// Pack the message: `[header:2 BE][data]`.
34    pub fn pack(&self) -> Vec<u8> {
35        let mut header_val: u16 = self.stream_id & 0x3FFF;
36        if self.eof {
37            header_val |= 0x8000;
38        }
39        if self.compressed {
40            header_val |= 0x4000;
41        }
42
43        let mut packed = Vec::with_capacity(2 + self.data.len());
44        packed.extend_from_slice(&header_val.to_be_bytes());
45        packed.extend_from_slice(&self.data);
46        packed
47    }
48
49    /// Unpack from raw bytes (decompresses if compressed flag is set).
50    pub fn unpack(raw: &[u8], compressor: &dyn Compressor) -> Result<Self, BufferError> {
51        if raw.len() < 2 {
52            return Err(BufferError::InvalidData);
53        }
54
55        let header = u16::from_be_bytes([raw[0], raw[1]]);
56        let eof = (header & 0x8000) != 0;
57        let compressed = (header & 0x4000) != 0;
58        let stream_id = header & 0x3FFF;
59
60        let mut data = raw[2..].to_vec();
61
62        if compressed {
63            data = compressor.decompress(&data)
64                .ok_or(BufferError::DecompressionFailed)?;
65        }
66
67        Ok(StreamDataMessage {
68            stream_id,
69            compressed,
70            eof,
71            data,
72        })
73    }
74
75    /// Maximum data length for a given link MDU.
76    pub fn max_data_len(link_mdu: usize) -> usize {
77        link_mdu.saturating_sub(STREAM_DATA_OVERHEAD)
78    }
79}
80
81/// Chunks data into StreamDataMessages.
82pub struct BufferWriter {
83    stream_id: StreamId,
84    closed: bool,
85}
86
87impl BufferWriter {
88    pub fn new(stream_id: StreamId) -> Self {
89        BufferWriter {
90            stream_id,
91            closed: false,
92        }
93    }
94
95    /// Write data → one or more StreamDataMessages.
96    ///
97    /// Tries compression if data > 32 bytes and compression reduces size.
98    pub fn write(
99        &mut self,
100        data: &[u8],
101        link_mdu: usize,
102        compressor: &dyn Compressor,
103    ) -> Vec<StreamDataMessage> {
104        if self.closed || data.is_empty() {
105            return Vec::new();
106        }
107
108        let max_data = StreamDataMessage::max_data_len(link_mdu);
109        if max_data == 0 {
110            return Vec::new();
111        }
112
113        let mut messages = Vec::new();
114        let mut offset = 0;
115
116        while offset < data.len() {
117            let end = (offset + max_data).min(data.len());
118            let chunk = &data[offset..end];
119
120            // Try compression for larger chunks
121            let (msg_data, compressed) = if chunk.len() > 32 {
122                if let Some(compressed_data) = compressor.compress(chunk) {
123                    if compressed_data.len() < chunk.len() && compressed_data.len() <= max_data {
124                        (compressed_data, true)
125                    } else {
126                        (chunk.to_vec(), false)
127                    }
128                } else {
129                    (chunk.to_vec(), false)
130                }
131            } else {
132                (chunk.to_vec(), false)
133            };
134
135            messages.push(StreamDataMessage::new(
136                self.stream_id,
137                msg_data,
138                false,
139                compressed,
140            ));
141
142            offset = end;
143        }
144
145        messages
146    }
147
148    /// Signal EOF → final StreamDataMessage with eof=true.
149    pub fn close(&mut self) -> StreamDataMessage {
150        self.closed = true;
151        StreamDataMessage::new(self.stream_id, Vec::new(), true, false)
152    }
153
154    pub fn is_closed(&self) -> bool {
155        self.closed
156    }
157}
158
159/// Reassembles a stream from messages.
160pub struct BufferReader {
161    stream_id: StreamId,
162    buffer: Vec<u8>,
163    eof: bool,
164}
165
166impl BufferReader {
167    pub fn new(stream_id: StreamId) -> Self {
168        BufferReader {
169            stream_id,
170            buffer: Vec::new(),
171            eof: false,
172        }
173    }
174
175    /// Receive a stream data message.
176    pub fn receive(&mut self, msg: &StreamDataMessage) {
177        if msg.stream_id != self.stream_id {
178            return;
179        }
180        if !msg.data.is_empty() {
181            self.buffer.extend_from_slice(&msg.data);
182        }
183        if msg.eof {
184            self.eof = true;
185        }
186    }
187
188    /// Read up to `max_bytes` from the buffer.
189    pub fn read(&mut self, max_bytes: usize) -> Vec<u8> {
190        let n = max_bytes.min(self.buffer.len());
191        let data: Vec<u8> = self.buffer.drain(..n).collect();
192        data
193    }
194
195    /// Number of bytes available to read.
196    pub fn available(&self) -> usize {
197        self.buffer.len()
198    }
199
200    /// Whether EOF has been received.
201    pub fn is_eof(&self) -> bool {
202        self.eof
203    }
204
205    /// Whether all data has been consumed (EOF received and buffer empty).
206    pub fn is_done(&self) -> bool {
207        self.eof && self.buffer.is_empty()
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Mock compressor: "compresses" by keeping the first half of the input
    /// and "decompresses" by repeating the input twice.
    struct HalfCompressor;
    impl Compressor for HalfCompressor {
        fn compress(&self, data: &[u8]) -> Option<Vec<u8>> {
            Some(data[..data.len() / 2].to_vec())
        }
        fn decompress(&self, data: &[u8]) -> Option<Vec<u8>> {
            let mut out = data.to_vec();
            out.extend_from_slice(data);
            Some(out)
        }
    }

    /// Mock compressor that never helps: compression makes the data one byte
    /// longer and decompression always fails.
    struct FailingCompressor;
    impl Compressor for FailingCompressor {
        fn compress(&self, data: &[u8]) -> Option<Vec<u8>> {
            let mut out = data.to_vec();
            out.push(0);
            Some(out)
        }
        fn decompress(&self, _data: &[u8]) -> Option<Vec<u8>> {
            None
        }
    }

    // ---- StreamDataMessage -------------------------------------------------

    #[test]
    fn test_pack_unpack_roundtrip() {
        let msg = StreamDataMessage::new(42, b"hello".to_vec(), false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 42);
        assert_eq!(unpacked.data, b"hello");
        assert!(!unpacked.eof);
        assert!(!unpacked.compressed);
    }

    #[test]
    fn test_pack_unpack_eof() {
        let msg = StreamDataMessage::new(0, Vec::new(), true, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 0);
        assert!(unpacked.eof);
        assert!(unpacked.data.is_empty());
    }

    #[test]
    fn test_header_bit_layout() {
        // stream_id = 0x1234, eof = true, compressed = true
        let msg = StreamDataMessage::new(0x1234, vec![0xFF], true, true);
        let packed = msg.pack();
        let header = u16::from_be_bytes([packed[0], packed[1]]);
        assert_eq!(header & 0x3FFF, 0x1234);
        assert!(header & 0x8000 != 0); // eof
        assert!(header & 0x4000 != 0); // compressed
        // The payload follows the header untouched.
        assert_eq!(&packed[2..], &[0xFF]);
    }

    #[test]
    fn test_max_stream_id() {
        let msg = StreamDataMessage::new(STREAM_ID_MAX, vec![0x42], false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, STREAM_ID_MAX);
    }

    #[test]
    fn test_stream_id_overflow() {
        // If stream_id > STREAM_ID_MAX, only lower 14 bits are used
        let msg = StreamDataMessage::new(0xFFFF, vec![], false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 0x3FFF);
        // The overflowing bits must not leak into the flag bits.
        assert!(!unpacked.eof);
        assert!(!unpacked.compressed);
    }

    #[test]
    fn test_unpack_too_short() {
        assert_eq!(
            StreamDataMessage::unpack(&[], &NoopCompressor),
            Err(BufferError::InvalidData)
        );
        assert_eq!(
            StreamDataMessage::unpack(&[0x00], &NoopCompressor),
            Err(BufferError::InvalidData)
        );
    }

    #[test]
    fn test_unpack_header_only() {
        // A bare header is a valid message with an empty payload.
        let unpacked = StreamDataMessage::unpack(&[0x80, 0x05], &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 5);
        assert!(unpacked.eof);
        assert!(!unpacked.compressed);
        assert!(unpacked.data.is_empty());
    }

    #[test]
    fn test_unpack_decompression_failure() {
        let msg = StreamDataMessage::new(1, b"garbage".to_vec(), false, true);
        let packed = msg.pack();
        assert_eq!(
            StreamDataMessage::unpack(&packed, &FailingCompressor),
            Err(BufferError::DecompressionFailed)
        );
    }

    #[test]
    fn test_compressed_unpack() {
        let msg = StreamDataMessage::new(1, b"compressed".to_vec(), false, true);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &HalfCompressor).unwrap();
        // HalfCompressor doubles data on decompress
        assert_eq!(unpacked.data, b"compressedcompressed");
    }

    #[test]
    fn test_max_data_len() {
        let mdl = StreamDataMessage::max_data_len(431);
        assert_eq!(mdl, 431 - STREAM_DATA_OVERHEAD);
        // An MDU too small for the overhead leaves no room at all.
        assert_eq!(StreamDataMessage::max_data_len(0), 0);
    }

    // ---- BufferWriter ------------------------------------------------------

    #[test]
    fn test_writer_single_chunk() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100];
        let msgs = writer.write(&data, 431, &NoopCompressor);
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].data, data);
        assert_eq!(msgs[0].stream_id, 1);
        assert!(!msgs[0].eof);
    }

    #[test]
    fn test_writer_chunking() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 1000];
        // Use small MDU to force multiple chunks
        let msgs = writer.write(&data, 50, &NoopCompressor);
        let max_data = StreamDataMessage::max_data_len(50);
        assert!(msgs.len() > 1);

        // Verify total data equals original
        let total: Vec<u8> = msgs.iter().flat_map(|m| m.data.clone()).collect();
        assert_eq!(total, data);

        // Each chunk should be at most max_data
        for msg in &msgs {
            assert!(msg.data.len() <= max_data);
        }
    }

    #[test]
    fn test_writer_close() {
        let mut writer = BufferWriter::new(5);
        let msg = writer.close();
        assert!(msg.eof);
        assert!(msg.data.is_empty());
        assert_eq!(msg.stream_id, 5);
        assert!(writer.is_closed());
    }

    #[test]
    fn test_writer_no_write_after_close() {
        let mut writer = BufferWriter::new(1);
        writer.close();
        let msgs = writer.write(b"test", 431, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_writer_empty_data() {
        let mut writer = BufferWriter::new(1);
        let msgs = writer.write(&[], 431, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_writer_zero_mdu() {
        // With no room for payload the writer must return instead of looping.
        let mut writer = BufferWriter::new(1);
        let msgs = writer.write(b"data", 0, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_compression_used_when_smaller() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100]; // > 32 bytes, compression will be tried
        let msgs = writer.write(&data, 431, &HalfCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(msgs[0].compressed);
        assert_eq!(msgs[0].data.len(), 50); // half
    }

    #[test]
    fn test_compression_skipped_for_small_chunk() {
        // Exactly 32 bytes is not "> 32", so compression is never attempted.
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 32];
        let msgs = writer.write(&data, 431, &HalfCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(!msgs[0].compressed);
        assert_eq!(msgs[0].data, data);
    }

    #[test]
    fn test_compression_skipped_when_not_smaller() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100];
        let msgs = writer.write(&data, 431, &FailingCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(!msgs[0].compressed);
        assert_eq!(msgs[0].data, data);
    }

    // ---- BufferReader ------------------------------------------------------

    #[test]
    fn test_reader_reassembly() {
        let mut reader = BufferReader::new(1);
        let msg1 = StreamDataMessage::new(1, b"hello ".to_vec(), false, false);
        let msg2 = StreamDataMessage::new(1, b"world".to_vec(), false, false);
        let eof = StreamDataMessage::new(1, Vec::new(), true, false);

        reader.receive(&msg1);
        reader.receive(&msg2);
        assert_eq!(reader.available(), 11);
        assert!(!reader.is_eof());

        reader.receive(&eof);
        assert!(reader.is_eof());
        // EOF alone is not "done" while unread bytes remain.
        assert!(!reader.is_done());

        let data = reader.read(100);
        assert_eq!(data, b"hello world");
        assert!(reader.is_done());
    }

    #[test]
    fn test_reader_partial_read() {
        let mut reader = BufferReader::new(1);
        let msg = StreamDataMessage::new(1, b"abcdefgh".to_vec(), false, false);
        reader.receive(&msg);

        let first = reader.read(4);
        assert_eq!(first, b"abcd");
        assert_eq!(reader.available(), 4);

        let rest = reader.read(100);
        assert_eq!(rest, b"efgh");
        assert_eq!(reader.available(), 0);
    }

    #[test]
    fn test_reader_ignores_wrong_stream() {
        let mut reader = BufferReader::new(1);
        let msg = StreamDataMessage::new(2, b"wrong".to_vec(), false, false);
        reader.receive(&msg);
        assert_eq!(reader.available(), 0);
    }

    #[test]
    fn test_reader_ignores_eof_from_wrong_stream() {
        let mut reader = BufferReader::new(1);
        let foreign_eof = StreamDataMessage::new(2, Vec::new(), true, false);
        reader.receive(&foreign_eof);
        assert!(!reader.is_eof());
        assert!(!reader.is_done());
    }

    // ---- End to end --------------------------------------------------------

    #[test]
    fn test_writer_to_reader_roundtrip_with_compression() {
        let data = vec![0x42; 100];
        let mut writer = BufferWriter::new(7);
        let mut msgs = writer.write(&data, 431, &HalfCompressor);
        msgs.push(writer.close());

        let mut reader = BufferReader::new(7);
        for msg in &msgs {
            let wire = msg.pack();
            let received = StreamDataMessage::unpack(&wire, &HalfCompressor).unwrap();
            reader.receive(&received);
        }

        assert!(reader.is_eof());
        assert_eq!(reader.read(usize::MAX), data);
        assert!(reader.is_done());
    }
}