Skip to main content

binary_codec/bitstream/
writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9    marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13    /// Create a new LSB-first writer
14    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15        Self {
16            buffer,
17            bit_pos: 0,
18            crypto: None,
19            marker: None,
20        }
21    }
22
23    /// Return slice of buffer
24    pub fn slice(&self) -> &[u8] {
25        &self.buffer
26    }
27
28    /// Set marker at specific position or current byte
29    pub fn set_marker(&mut self, pos: Option<usize>) {
30        self.marker = Some(pos.unwrap_or(self.byte_pos()));
31    }
32
33    /// Unset marker
34    pub fn reset_marker(&mut self) {
35        self.marker = None;
36    }
37
38    /// Return slice from marker (or start if marker is unset) to specific position or current byte
39    /// If crypto is set, it will return the unencrypted (cached) slice. This is different from other `slice` methods which always returns the raw buffer slice.
40    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41        let start = self.marker.unwrap_or(0);
42        let end = to.unwrap_or(self.byte_pos());
43
44        if let Some(crypto) = self.crypto.as_ref() {
45            return &crypto.get_cached(true)[start..end];
46        }
47
48        &self.buffer[start..end]
49    }
50
51    /// Set crypto stream
52    pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
53        if let Some(new) = crypto.as_mut() {
54            if let Some(existing) = self.crypto.as_ref() {
55                new.replace(existing);
56            } else {
57                // Initialize new crypto stream with the data written so far, so that it can cache full plaintext
58                new.set_cached(self.slice());
59            }
60        }
61
62        self.crypto = crypto;
63    }
64
65     /// Remove crypto stream
66    pub fn reset_crypto(&mut self) {
67        self.crypto = None;
68    }
69
70    /// Get byte position of writer
71    pub fn byte_pos(&self) -> usize {
72        self.bit_pos / 8
73    }
74
75    /// Write a single bit
76    pub fn write_bit(&mut self, val: bool) {
77        self.write_small(val as u8, 1);
78    }
79
80    /// Write 1-8 bits (LSB-first)
81    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
82        assert!(bits > 0 && bits < 8);
83
84        while bits > 0 {
85            self.ensure_byte();
86
87            // Determine the current bit position inside the current byte (0..7)
88            let bit_offset = self.bit_pos % 8;
89
90            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
91            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
92
93            // Create a mask for the bits we are going to modify.
94            // Example: bit_offset = 2, bits_in_current_byte = 3
95            // mask = 00011100 (we will only touch bits 2,3,4)
96            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
97
98            // Prepare the bits to write:
99            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
100            // 2️⃣ Shift them left by bit_offset to align inside the byte
101            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
102            // Step 1: val & 0b111 = 0b101
103            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
104            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
105
106            let byte_pos = self.byte_pos();
107
108            // Clear the bits in this byte where we will write
109            self.buffer[byte_pos] &= !mask;
110
111            // Write the new bits into the cleared positions
112            self.buffer[byte_pos] |= shifted_val & mask;
113
114            // Decrease the number of bits remaining to write
115            bits -= bits_in_current_byte;
116
117            // Shift val right to remove the bits we've already written
118            val >>= bits_in_current_byte;
119
120            self.bit_pos += bits_in_current_byte as usize;
121
122            // If full byte, encrypt it (if needed)
123            if self.bit_pos % 8 == 0 {
124                if let Some(crypto) = self.crypto.as_mut() {
125                    let b = self.buffer[byte_pos];
126                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
127                }
128            }
129        }
130    }
131
132    /// Write a full byte, starting at the next byte boundary
133    pub fn write_byte(&mut self, byte: u8) {
134        self.align_byte();
135        self.ensure_byte();
136
137        let byte_pos = self.byte_pos();
138        let byte = if let Some(crypto) = self.crypto.as_mut() {
139            crypto.apply_keystream_byte(byte)
140        } else {
141            byte
142        };
143
144        self.buffer[byte_pos] = byte;
145        self.bit_pos += 8;
146    }
147
148    /// Write a slice of bytes, starting at the next byte boundary
149    pub fn write_bytes(&mut self, data: &[u8]) {
150        self.align_byte();
151
152        if let Some(crypto) = self.crypto.as_mut() {
153            let encrypted = crypto.apply_keystream(data);
154            self.buffer.extend_from_slice(encrypted);
155        } else {
156            self.buffer.extend_from_slice(data);
157        }
158
159        self.bit_pos += 8 * data.len();
160    }
161
162    /// Write a dynamic int, starting at the next byte bounary
163    /// The last bit is used as a continuation flag for the next byte
164    pub fn write_dyn_int(&mut self, mut val: u128) {
165        if val == 0 {
166            self.write_byte(0);
167            return;
168        }
169        
170        while val > 0 {
171            let mut encoded = val % 128;
172            val /= 128;
173            if val > 0 {
174                encoded |= 128;
175            }
176            self.write_byte(encoded as u8);
177        }
178    }
179
180    /// Write a integer of fixed size to the buffer
181    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
182        self.write_bytes(&val.serialize());
183    }
184
185    /// Ensure the buffer has at least `byte_pos + 1` bytes
186    fn ensure_byte(&mut self) {
187        let byte_pos = self.byte_pos();
188        if byte_pos >= self.buffer.len() {
189            self.buffer.resize(byte_pos + 1, 0);
190        }
191    }
192
193    /// Align the stream to the next byte boundary
194    pub fn align_byte(&mut self) {
195        let rem = self.bit_pos % 8;
196        if rem != 0 {
197            let byte_pos = self.byte_pos();
198            self.bit_pos += 8 - rem;
199
200            // Encrypt byte
201            if let Some(crypto) = self.crypto.as_mut() {
202                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
203            }
204        }
205    }
206
207    /// Reset writing position
208    pub fn reset(&mut self) {
209        self.bit_pos = 0;
210    }
211
212    /// Get buffer length
213    pub fn len(&self) -> usize {
214        self.buffer.len()
215    }
216}
217
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy "cipher" for tests: every byte is incremented by one (wrapping) and
    /// the produced bytes are recorded in `ciphertext`.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add: a plain `+ 1` would panic on 0xFF in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            // This stub keeps no plaintext cache.
            &[]
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }

        fn set_cached(&mut self, data: &[u8]) {
            self.ciphertext = data.to_vec();
        }
    }

    /// Every completed byte (including an aligned partial byte) goes through the cipher.
    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.crypto = Some(Box::new(PlusOneEncrypter { ciphertext: Vec::new() }));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    /// Helper to format buffer as binary strings
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    /// Single bits are packed LSB-first into one byte.
    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true); // 4 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101); // LSB-first: first bit is lowest
    }

    /// Several small writes that together fill exactly one byte.
    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // write 3 bits
        stream.write_small(0b11, 2); // write 2 bits
        stream.write_small(0b111, 3); // write 3 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    /// A small write that straddles a byte boundary.
    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // write 11 bits: the first 7 almost fill byte 0; of the next 4,
        // one completes byte 0 and the remaining 3 spill into byte 1
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    /// `write_byte` aligns to the next byte boundary first.
    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_byte(0xAA); // should align and write full byte

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // first bit written, rest padded
        assert_eq!(buf[1], 0xAA); // full byte
    }

    /// `write_bytes` aligns to the next byte boundary first.
    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]); // align and write slice

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // padding of first bit
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    /// Explicit alignment pads the partial byte with zero bits.
    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2); // 2 bits
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // 2 bits written, rest padded
        assert_eq!(buf[1], 0xFF);
    }

    /// Mixed bit, byte and slice writes in sequence.
    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        println!("{:?}", bin);

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], 0b00001011); // first 4 bits
        assert_eq!(buf[1], 0xAA); // write_byte
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011); // last 2 bits
    }

    /// Dynamic ints use 7 payload bits per byte plus a continuation flag.
    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128); // Crossed 127 = boundary of first byte
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455); // 4 bytes boundary
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    /// Fixed-size ints of every supported width are written back to back.
    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    /// `slice_marker` honours both an explicit end and the marker position.
    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        // Marker at the current byte: the slice from marker to position is empty.
        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[]);
    }

    /// Zero is still encoded as one byte.
    #[test]
    fn test_write_0_dynint() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(0);
        assert_eq!(1, stream.len());
    }
}