Skip to main content

binary_codec/bitstream/
writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9    marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13    /// Create a new LSB-first writer
14    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15        Self {
16            buffer,
17            bit_pos: 0,
18            crypto: None,
19            marker: None,
20        }
21    }
22
23    /// Return slice of buffer
24    pub fn slice(&self) -> &[u8] {
25        &self.buffer
26    }
27
28    /// Set marker at specific position or current byte
29    pub fn set_marker(&mut self, pos: Option<usize>) {
30        self.marker = Some(pos.unwrap_or(self.byte_pos()));
31    }
32
33    /// Unset marker
34    pub fn reset_marker(&mut self) {
35        self.marker = None;
36    }
37
38    /// Return slice from marker (or start if marker is unset) to specific position or current byte
39    /// If crypto is set, it will return the unencrypted (cached) slice. This is different from other `slice` methods which always returns the raw buffer slice.
40    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41        let start = self.marker.unwrap_or(0);
42        let end = to.unwrap_or(self.byte_pos());
43
44        if let Some(crypto) = self.crypto.as_ref() {
45            return &crypto.get_cached(true)[start..end];
46        }
47
48        &self.buffer[start..end]
49    }
50
51    /// Set crypto stream
52    pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
53        if let Some(new) = crypto.as_mut() {
54            if let Some(existing) = self.crypto.as_ref() {
55                new.replace(existing);
56            } else {
57                // Initialize new crypto stream with the data written so far, so that it can cache full plaintext
58                new.set_cached(self.slice());
59            }
60        }
61
62        self.crypto = crypto;
63    }
64
65     /// Remove crypto stream
66    pub fn reset_crypto(&mut self) {
67        self.crypto = None;
68    }
69
70    /// Get byte position of writer
71    pub fn byte_pos(&self) -> usize {
72        self.bit_pos / 8
73    }
74
75    /// Write a single bit
76    pub fn write_bit(&mut self, val: bool) {
77        self.write_small(val as u8, 1);
78    }
79
80    /// Write 1-8 bits (LSB-first)
81    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
82        assert!(bits > 0 && bits < 8);
83
84        while bits > 0 {
85            self.ensure_byte();
86
87            // Determine the current bit position inside the current byte (0..7)
88            let bit_offset = self.bit_pos % 8;
89
90            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
91            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
92
93            // Create a mask for the bits we are going to modify.
94            // Example: bit_offset = 2, bits_in_current_byte = 3
95            // mask = 00011100 (we will only touch bits 2,3,4)
96            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
97
98            // Prepare the bits to write:
99            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
100            // 2️⃣ Shift them left by bit_offset to align inside the byte
101            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
102            // Step 1: val & 0b111 = 0b101
103            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
104            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
105
106            let byte_pos = self.byte_pos();
107
108            // Clear the bits in this byte where we will write
109            self.buffer[byte_pos] &= !mask;
110
111            // Write the new bits into the cleared positions
112            self.buffer[byte_pos] |= shifted_val & mask;
113
114            // Decrease the number of bits remaining to write
115            bits -= bits_in_current_byte;
116
117            // Shift val right to remove the bits we've already written
118            val >>= bits_in_current_byte;
119
120            self.bit_pos += bits_in_current_byte as usize;
121
122            // If full byte, encrypt it (if needed)
123            if self.bit_pos % 8 == 0 {
124                if let Some(crypto) = self.crypto.as_mut() {
125                    let b = self.buffer[byte_pos];
126                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
127                }
128            }
129        }
130    }
131
132    /// Write a full byte, starting at the next byte boundary
133    pub fn write_byte(&mut self, byte: u8) {
134        self.align_byte();
135        self.ensure_byte();
136
137        let byte_pos = self.byte_pos();
138        let byte = if let Some(crypto) = self.crypto.as_mut() {
139            crypto.apply_keystream_byte(byte)
140        } else {
141            byte
142        };
143
144        self.buffer[byte_pos] = byte;
145        self.bit_pos += 8;
146    }
147
148    /// Write a slice of bytes, starting at the next byte boundary
149    pub fn write_bytes(&mut self, data: &[u8]) {
150        self.align_byte();
151
152        if let Some(crypto) = self.crypto.as_mut() {
153            let encrypted = crypto.apply_keystream(data);
154            self.buffer.extend_from_slice(encrypted);
155        } else {
156            self.buffer.extend_from_slice(data);
157        }
158
159        self.bit_pos += 8 * data.len();
160    }
161
162    /// Write a dynamic int, starting at the next byte bounary
163    /// The last bit is used as a continuation flag for the next byte
164    pub fn write_dyn_int(&mut self, mut val: u128) {
165        while val > 0 {
166            let mut encoded = val % 128;
167            val /= 128;
168            if val > 0 {
169                encoded |= 128;
170            }
171            self.write_byte(encoded as u8);
172        }
173    }
174
175    /// Write a integer of fixed size to the buffer
176    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
177        self.write_bytes(&val.serialize());
178    }
179
180    /// Ensure the buffer has at least `byte_pos + 1` bytes
181    fn ensure_byte(&mut self) {
182        let byte_pos = self.byte_pos();
183        if byte_pos >= self.buffer.len() {
184            self.buffer.resize(byte_pos + 1, 0);
185        }
186    }
187
188    /// Align the stream to the next byte boundary
189    pub fn align_byte(&mut self) {
190        let rem = self.bit_pos % 8;
191        if rem != 0 {
192            let byte_pos = self.byte_pos();
193            self.bit_pos += 8 - rem;
194
195            // Encrypt byte
196            if let Some(crypto) = self.crypto.as_mut() {
197                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
198            }
199        }
200    }
201
202    /// Reset writing position
203    pub fn reset(&mut self) {
204        self.bit_pos = 0;
205    }
206
207    /// Get buffer length
208    pub fn len(&self) -> usize {
209        self.buffer.len()
210    }
211}
212
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Mock crypto stream: "encrypts" by adding one to each byte and records its output.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add so a 0xFF plaintext byte cannot panic in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }

        fn set_cached(&mut self, data: &[u8]) {
            self.ciphertext = data.to_vec();
        }
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.crypto = Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        }));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    /// Helper to format buffer as binary strings
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true); // 4 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101); // LSB-first: first bit is lowest
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // write 3 bits
        stream.write_small(0b11, 2); // write 2 bits
        stream.write_small(0b111, 3); // write 3 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // Write 11 bits: the first 7 bits almost fill one byte; of the next 4 bits,
        // 1 completes the first byte and 3 spill into the second byte.
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_byte(0xAA); // should align and write full byte

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // first bit written, rest padded
        assert_eq!(buf[1], 0xAA); // full byte
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]); // align and write slice

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // padding of first bit
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2); // 2 bits
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // 2 bits written, rest padded
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        println!("{:?}", bin);

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], 0b00001011); // first 4 bits
        assert_eq!(buf[1], 0xAA); // write_byte
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011); // last 2 bits
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128); // Crossed 127 = boundary of first byte
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455); // 2^28 - 1: largest value that fits in 4 bytes
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[] as &[u8]);
    }
}