// rns_core/buffer/mod.rs

1pub mod types;
2
3use alloc::vec::Vec;
4
5use crate::constants::STREAM_DATA_OVERHEAD;
6#[cfg(test)]
7use crate::constants::STREAM_ID_MAX;
8
9pub use types::{BufferError, Compressor, NoopCompressor, StreamId};
10
11/// Stream data message: 2-byte header + data.
12///
13/// Header format: `(stream_id & 0x3FFF) | (eof << 15) | (compressed << 14)`
14#[derive(Debug, Clone, PartialEq)]
15pub struct StreamDataMessage {
16    pub stream_id: StreamId,
17    pub compressed: bool,
18    pub eof: bool,
19    pub data: Vec<u8>,
20}
21
22impl StreamDataMessage {
23    /// Create a new stream data message.
24    pub fn new(stream_id: StreamId, data: Vec<u8>, eof: bool, compressed: bool) -> Self {
25        StreamDataMessage {
26            stream_id,
27            compressed,
28            eof,
29            data,
30        }
31    }
32
33    /// Pack the message: `[header:2 BE][data]`.
34    pub fn pack(&self) -> Vec<u8> {
35        let mut header_val: u16 = self.stream_id & 0x3FFF;
36        if self.eof {
37            header_val |= 0x8000;
38        }
39        if self.compressed {
40            header_val |= 0x4000;
41        }
42
43        let mut packed = Vec::with_capacity(2 + self.data.len());
44        packed.extend_from_slice(&header_val.to_be_bytes());
45        packed.extend_from_slice(&self.data);
46        packed
47    }
48
49    /// Unpack from raw bytes (decompresses if compressed flag is set).
50    pub fn unpack(raw: &[u8], compressor: &dyn Compressor) -> Result<Self, BufferError> {
51        if raw.len() < 2 {
52            return Err(BufferError::InvalidData);
53        }
54
55        let header = u16::from_be_bytes([raw[0], raw[1]]);
56        let eof = (header & 0x8000) != 0;
57        let compressed = (header & 0x4000) != 0;
58        let stream_id = header & 0x3FFF;
59
60        let mut data = raw[2..].to_vec();
61
62        if compressed {
63            data = compressor
64                .decompress(&data)
65                .ok_or(BufferError::DecompressionFailed)?;
66        }
67
68        Ok(StreamDataMessage {
69            stream_id,
70            compressed,
71            eof,
72            data,
73        })
74    }
75
76    /// Maximum data length for a given link MDU.
77    pub fn max_data_len(link_mdu: usize) -> usize {
78        link_mdu.saturating_sub(STREAM_DATA_OVERHEAD)
79    }
80}
81
82/// Chunks data into StreamDataMessages.
83pub struct BufferWriter {
84    stream_id: StreamId,
85    closed: bool,
86}
87
88impl BufferWriter {
89    pub fn new(stream_id: StreamId) -> Self {
90        BufferWriter {
91            stream_id,
92            closed: false,
93        }
94    }
95
96    /// Write data → one or more StreamDataMessages.
97    ///
98    /// Tries compression if data > 32 bytes and compression reduces size.
99    pub fn write(
100        &mut self,
101        data: &[u8],
102        link_mdu: usize,
103        compressor: &dyn Compressor,
104    ) -> Vec<StreamDataMessage> {
105        if self.closed || data.is_empty() {
106            return Vec::new();
107        }
108
109        let max_data = StreamDataMessage::max_data_len(link_mdu);
110        if max_data == 0 {
111            return Vec::new();
112        }
113
114        let mut messages = Vec::new();
115        let mut offset = 0;
116
117        while offset < data.len() {
118            let end = (offset + max_data).min(data.len());
119            let chunk = &data[offset..end];
120
121            // Try compression for larger chunks
122            let (msg_data, compressed) = if chunk.len() > 32 {
123                if let Some(compressed_data) = compressor.compress(chunk) {
124                    if compressed_data.len() < chunk.len() && compressed_data.len() <= max_data {
125                        (compressed_data, true)
126                    } else {
127                        (chunk.to_vec(), false)
128                    }
129                } else {
130                    (chunk.to_vec(), false)
131                }
132            } else {
133                (chunk.to_vec(), false)
134            };
135
136            messages.push(StreamDataMessage::new(
137                self.stream_id,
138                msg_data,
139                false,
140                compressed,
141            ));
142
143            offset = end;
144        }
145
146        messages
147    }
148
149    /// Signal EOF → final StreamDataMessage with eof=true.
150    pub fn close(&mut self) -> StreamDataMessage {
151        self.closed = true;
152        StreamDataMessage::new(self.stream_id, Vec::new(), true, false)
153    }
154
155    pub fn is_closed(&self) -> bool {
156        self.closed
157    }
158}
159
160/// Reassembles a stream from messages.
161pub struct BufferReader {
162    stream_id: StreamId,
163    buffer: Vec<u8>,
164    eof: bool,
165}
166
167impl BufferReader {
168    pub fn new(stream_id: StreamId) -> Self {
169        BufferReader {
170            stream_id,
171            buffer: Vec::new(),
172            eof: false,
173        }
174    }
175
176    /// Receive a stream data message.
177    pub fn receive(&mut self, msg: &StreamDataMessage) {
178        if msg.stream_id != self.stream_id {
179            return;
180        }
181        if !msg.data.is_empty() {
182            self.buffer.extend_from_slice(&msg.data);
183        }
184        if msg.eof {
185            self.eof = true;
186        }
187    }
188
189    /// Read up to `max_bytes` from the buffer.
190    pub fn read(&mut self, max_bytes: usize) -> Vec<u8> {
191        let n = max_bytes.min(self.buffer.len());
192        let data: Vec<u8> = self.buffer.drain(..n).collect();
193        data
194    }
195
196    /// Number of bytes available to read.
197    pub fn available(&self) -> usize {
198        self.buffer.len()
199    }
200
201    /// Whether EOF has been received.
202    pub fn is_eof(&self) -> bool {
203        self.eof
204    }
205
206    /// Whether all data has been consumed (EOF received and buffer empty).
207    pub fn is_done(&self) -> bool {
208        self.eof && self.buffer.is_empty()
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // StreamDataMessage wire format
    // ---------------------------------------------------------------

    #[test]
    fn test_pack_unpack_roundtrip() {
        let msg = StreamDataMessage::new(42, b"hello".to_vec(), false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 42);
        assert_eq!(unpacked.data, b"hello");
        assert!(!unpacked.eof);
        assert!(!unpacked.compressed);
    }

    #[test]
    fn test_pack_unpack_eof() {
        let msg = StreamDataMessage::new(0, Vec::new(), true, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 0);
        assert!(unpacked.eof);
        assert!(unpacked.data.is_empty());
    }

    #[test]
    fn test_header_bit_layout() {
        // stream_id = 0x1234, eof = true, compressed = true
        let msg = StreamDataMessage::new(0x1234, vec![0xFF], true, true);
        let packed = msg.pack();
        let header = u16::from_be_bytes([packed[0], packed[1]]);
        assert_eq!(header & 0x3FFF, 0x1234);
        assert!(header & 0x8000 != 0); // eof
        assert!(header & 0x4000 != 0); // compressed
    }

    #[test]
    fn test_max_stream_id() {
        let msg = StreamDataMessage::new(STREAM_ID_MAX, vec![0x42], false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, STREAM_ID_MAX);
    }

    #[test]
    fn test_stream_id_overflow() {
        // If stream_id > STREAM_ID_MAX, only lower 14 bits are used
        let msg = StreamDataMessage::new(0xFFFF, vec![], false, false);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &NoopCompressor).unwrap();
        assert_eq!(unpacked.stream_id, 0x3FFF);
    }

    #[test]
    fn test_unpack_too_short() {
        assert_eq!(
            StreamDataMessage::unpack(&[0x00], &NoopCompressor),
            Err(BufferError::InvalidData)
        );
    }

    #[test]
    fn test_unpack_empty_input() {
        assert_eq!(
            StreamDataMessage::unpack(&[], &NoopCompressor),
            Err(BufferError::InvalidData)
        );
    }

    #[test]
    fn test_unpack_decompression_failure() {
        // Compressed flag set, but the compressor cannot decode the payload.
        let packed = StreamDataMessage::new(1, vec![1, 2, 3], false, true).pack();
        assert_eq!(
            StreamDataMessage::unpack(&packed, &FailingCompressor),
            Err(BufferError::DecompressionFailed)
        );
    }

    #[test]
    fn test_max_data_len() {
        let mdl = StreamDataMessage::max_data_len(431);
        assert_eq!(mdl, 431 - STREAM_DATA_OVERHEAD);
    }

    // ---------------------------------------------------------------
    // BufferWriter
    // ---------------------------------------------------------------

    #[test]
    fn test_writer_single_chunk() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100];
        let msgs = writer.write(&data, 431, &NoopCompressor);
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].data, data);
        assert_eq!(msgs[0].stream_id, 1);
        assert!(!msgs[0].eof);
    }

    #[test]
    fn test_writer_chunking() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 1000];
        // Use small MDU to force multiple chunks
        let msgs = writer.write(&data, 50, &NoopCompressor);
        let max_data = StreamDataMessage::max_data_len(50);
        assert!(msgs.len() > 1);

        // Verify total data equals original
        let total: Vec<u8> = msgs.iter().flat_map(|m| m.data.clone()).collect();
        assert_eq!(total, data);

        // Each chunk should be at most max_data
        for msg in &msgs {
            assert!(msg.data.len() <= max_data);
        }
    }

    #[test]
    fn test_writer_close() {
        let mut writer = BufferWriter::new(5);
        let msg = writer.close();
        assert!(msg.eof);
        assert!(msg.data.is_empty());
        assert_eq!(msg.stream_id, 5);
        assert!(writer.is_closed());
    }

    #[test]
    fn test_writer_no_write_after_close() {
        let mut writer = BufferWriter::new(1);
        writer.close();
        let msgs = writer.write(b"test", 431, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_writer_empty_data() {
        let mut writer = BufferWriter::new(1);
        let msgs = writer.write(&[], 431, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_writer_mdu_below_overhead() {
        // An MDU that only fits the overhead leaves no room for payload.
        let mut writer = BufferWriter::new(1);
        let msgs = writer.write(b"data", STREAM_DATA_OVERHEAD, &NoopCompressor);
        assert!(msgs.is_empty());
    }

    // ---------------------------------------------------------------
    // BufferReader
    // ---------------------------------------------------------------

    #[test]
    fn test_reader_reassembly() {
        let mut reader = BufferReader::new(1);
        let msg1 = StreamDataMessage::new(1, b"hello ".to_vec(), false, false);
        let msg2 = StreamDataMessage::new(1, b"world".to_vec(), false, false);
        let eof = StreamDataMessage::new(1, Vec::new(), true, false);

        reader.receive(&msg1);
        reader.receive(&msg2);
        assert_eq!(reader.available(), 11);
        assert!(!reader.is_eof());

        reader.receive(&eof);
        assert!(reader.is_eof());

        let data = reader.read(100);
        assert_eq!(data, b"hello world");
        assert!(reader.is_done());
    }

    #[test]
    fn test_reader_partial_read() {
        let mut reader = BufferReader::new(1);
        let msg = StreamDataMessage::new(1, b"abcdefgh".to_vec(), false, false);
        reader.receive(&msg);

        let first = reader.read(4);
        assert_eq!(first, b"abcd");
        assert_eq!(reader.available(), 4);

        let rest = reader.read(100);
        assert_eq!(rest, b"efgh");
        assert_eq!(reader.available(), 0);
    }

    #[test]
    fn test_reader_ignores_wrong_stream() {
        let mut reader = BufferReader::new(1);
        let msg = StreamDataMessage::new(2, b"wrong".to_vec(), false, false);
        reader.receive(&msg);
        assert_eq!(reader.available(), 0);
    }

    #[test]
    fn test_reader_not_done_until_drained() {
        // EOF alone is not "done": buffered bytes must be consumed first.
        let mut reader = BufferReader::new(1);
        reader.receive(&StreamDataMessage::new(1, b"ab".to_vec(), true, false));
        assert!(reader.is_eof());
        assert!(!reader.is_done());
        assert_eq!(reader.read(1), b"a");
        assert!(!reader.is_done());
        assert_eq!(reader.read(1), b"b");
        assert!(reader.is_done());
    }

    // ---------------------------------------------------------------
    // Mock compressors
    // ---------------------------------------------------------------

    /// "Compresses" by keeping the first half and "decompresses" by doubling.
    /// The round trip is lossless only for data whose halves are identical
    /// (e.g. a run of a single byte value).
    struct HalfCompressor;
    impl Compressor for HalfCompressor {
        fn compress(&self, data: &[u8]) -> Option<Vec<u8>> {
            // "Compress" by taking first half
            Some(data[..data.len() / 2].to_vec())
        }
        fn decompress(&self, data: &[u8]) -> Option<Vec<u8>> {
            // "Decompress" by doubling
            let mut out = data.to_vec();
            out.extend_from_slice(data);
            Some(out)
        }
    }

    /// Returns its input unchanged, so compression never shrinks a chunk.
    struct IdentityCompressor;
    impl Compressor for IdentityCompressor {
        fn compress(&self, data: &[u8]) -> Option<Vec<u8>> {
            Some(data.to_vec())
        }
        fn decompress(&self, data: &[u8]) -> Option<Vec<u8>> {
            Some(data.to_vec())
        }
    }

    /// Fails every operation.
    struct FailingCompressor;
    impl Compressor for FailingCompressor {
        fn compress(&self, _data: &[u8]) -> Option<Vec<u8>> {
            None
        }
        fn decompress(&self, _data: &[u8]) -> Option<Vec<u8>> {
            None
        }
    }

    // ---------------------------------------------------------------
    // Compression
    // ---------------------------------------------------------------

    #[test]
    fn test_compression_used_when_smaller() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100]; // > 32 bytes, compression will be tried
        let msgs = writer.write(&data, 431, &HalfCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(msgs[0].compressed);
        assert_eq!(msgs[0].data.len(), 50); // half
    }

    #[test]
    fn test_compression_not_tried_for_small_chunks() {
        // 32 bytes is at the threshold: compression must not be applied.
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 32];
        let msgs = writer.write(&data, 431, &HalfCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(!msgs[0].compressed);
        assert_eq!(msgs[0].data, data);
    }

    #[test]
    fn test_compression_skipped_when_not_smaller() {
        let mut writer = BufferWriter::new(1);
        let data = vec![0x42; 100];
        let msgs = writer.write(&data, 431, &IdentityCompressor);
        assert_eq!(msgs.len(), 1);
        assert!(!msgs[0].compressed);
        assert_eq!(msgs[0].data, data);
    }

    #[test]
    fn test_compressed_unpack() {
        let msg = StreamDataMessage::new(1, b"compressed".to_vec(), false, true);
        let packed = msg.pack();
        let unpacked = StreamDataMessage::unpack(&packed, &HalfCompressor).unwrap();
        // HalfCompressor doubles data on decompress
        assert_eq!(unpacked.data, b"compressedcompressed");
    }

    // ---------------------------------------------------------------
    // End to end: writer -> wire bytes -> reader
    // ---------------------------------------------------------------

    #[test]
    fn test_writer_reader_end_to_end() {
        let data: Vec<u8> = (0..=255u8).cycle().take(1000).collect();
        let mut writer = BufferWriter::new(3);
        let mut wire: Vec<Vec<u8>> = writer
            .write(&data, 64, &NoopCompressor)
            .iter()
            .map(StreamDataMessage::pack)
            .collect();
        wire.push(writer.close().pack());

        let mut reader = BufferReader::new(3);
        for raw in &wire {
            let msg = StreamDataMessage::unpack(raw, &NoopCompressor).unwrap();
            reader.receive(&msg);
        }

        assert!(reader.is_eof());
        assert_eq!(reader.read(usize::MAX), data);
        assert!(reader.is_done());
    }

    #[test]
    fn test_compressed_end_to_end() {
        // A run of one byte survives HalfCompressor's lossy-looking round trip.
        let data = vec![0x42; 100];
        let mut writer = BufferWriter::new(1);
        let msgs = writer.write(&data, 431, &HalfCompressor);
        assert!(msgs.iter().all(|m| m.compressed));

        let mut reader = BufferReader::new(1);
        for msg in &msgs {
            let decoded = StreamDataMessage::unpack(&msg.pack(), &HalfCompressor).unwrap();
            reader.receive(&decoded);
        }
        assert_eq!(reader.read(usize::MAX), data);
    }
}