Skip to main content

binary_codec/bitstream/
writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9    marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13    /// Create a new LSB-first writer
14    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15        Self {
16            buffer,
17            bit_pos: 0,
18            crypto: None,
19            marker: None,
20        }
21    }
22
23    /// Return slice of buffer
24    pub fn slice(&self) -> &[u8] {
25        &self.buffer
26    }
27
28    /// Set marker at specific position or current byte
29    pub fn set_marker(&mut self, pos: Option<usize>) {
30        self.marker = Some(pos.unwrap_or(self.byte_pos()));
31    }
32
33    /// Unset marker
34    pub fn reset_marker(&mut self) {
35        self.marker = None;
36    }
37
38    /// Return slice from marker (or start if marker is unset) to specific position or current byte
39    /// If crypto is set, it will return the unencrypted (cached) slice. This is different from other `slice` methods which always returns the raw buffer slice.
40    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41        let start = self.marker.unwrap_or(0);
42        let end = to.unwrap_or(self.byte_pos());
43
44        if let Some(crypto) = self.crypto.as_ref() {
45            return &crypto.get_cached(true)[start..end];
46        }
47
48        &self.buffer[start..end]
49    }
50
51    /// Set crypto stream
52    pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
53        if let Some(existing) = self.crypto.as_ref()
54            && let Some(new_crypto) = crypto.as_mut()
55        {
56            new_crypto.replace(existing);
57            self.crypto = crypto;
58        } else {
59            self.crypto = crypto;
60        }
61    }
62
63     /// Remove crypto stream
64    pub fn reset_crypto(&mut self) {
65        self.crypto = None;
66    }
67
68    /// Get byte position of writer
69    pub fn byte_pos(&self) -> usize {
70        self.bit_pos / 8
71    }
72
73    /// Write a single bit
74    pub fn write_bit(&mut self, val: bool) {
75        self.write_small(val as u8, 1);
76    }
77
78    /// Write 1-8 bits (LSB-first)
79    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
80        assert!(bits > 0 && bits < 8);
81
82        while bits > 0 {
83            self.ensure_byte();
84
85            // Determine the current bit position inside the current byte (0..7)
86            let bit_offset = self.bit_pos % 8;
87
88            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
89            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
90
91            // Create a mask for the bits we are going to modify.
92            // Example: bit_offset = 2, bits_in_current_byte = 3
93            // mask = 00011100 (we will only touch bits 2,3,4)
94            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
95
96            // Prepare the bits to write:
97            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
98            // 2️⃣ Shift them left by bit_offset to align inside the byte
99            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
100            // Step 1: val & 0b111 = 0b101
101            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
102            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
103
104            let byte_pos = self.byte_pos();
105
106            // Clear the bits in this byte where we will write
107            self.buffer[byte_pos] &= !mask;
108
109            // Write the new bits into the cleared positions
110            self.buffer[byte_pos] |= shifted_val & mask;
111
112            // Decrease the number of bits remaining to write
113            bits -= bits_in_current_byte;
114
115            // Shift val right to remove the bits we've already written
116            val >>= bits_in_current_byte;
117
118            self.bit_pos += bits_in_current_byte as usize;
119
120            // If full byte, encrypt it (if needed)
121            if self.bit_pos % 8 == 0 {
122                if let Some(crypto) = self.crypto.as_mut() {
123                    let b = self.buffer[byte_pos];
124                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
125                }
126            }
127        }
128    }
129
130    /// Write a full byte, starting at the next byte boundary
131    pub fn write_byte(&mut self, byte: u8) {
132        self.align_byte();
133        self.ensure_byte();
134
135        let byte_pos = self.byte_pos();
136        let byte = if let Some(crypto) = self.crypto.as_mut() {
137            crypto.apply_keystream_byte(byte)
138        } else {
139            byte
140        };
141
142        self.buffer[byte_pos] = byte;
143        self.bit_pos += 8;
144    }
145
146    /// Write a slice of bytes, starting at the next byte boundary
147    pub fn write_bytes(&mut self, data: &[u8]) {
148        self.align_byte();
149
150        if let Some(crypto) = self.crypto.as_mut() {
151            let encrypted = crypto.apply_keystream(data);
152            self.buffer.extend_from_slice(encrypted);
153        } else {
154            self.buffer.extend_from_slice(data);
155        }
156
157        self.bit_pos += 8 * data.len();
158    }
159
160    /// Write a dynamic int, starting at the next byte bounary
161    /// The last bit is used as a continuation flag for the next byte
162    pub fn write_dyn_int(&mut self, mut val: u128) {
163        while val > 0 {
164            let mut encoded = val % 128;
165            val /= 128;
166            if val > 0 {
167                encoded |= 128;
168            }
169            self.write_byte(encoded as u8);
170        }
171    }
172
173    /// Write a integer of fixed size to the buffer
174    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
175        self.write_bytes(&val.serialize());
176    }
177
178    /// Ensure the buffer has at least `byte_pos + 1` bytes
179    fn ensure_byte(&mut self) {
180        let byte_pos = self.byte_pos();
181        if byte_pos >= self.buffer.len() {
182            self.buffer.resize(byte_pos + 1, 0);
183        }
184    }
185
186    /// Align the stream to the next byte boundary
187    pub fn align_byte(&mut self) {
188        let rem = self.bit_pos % 8;
189        if rem != 0 {
190            let byte_pos = self.byte_pos();
191            self.bit_pos += 8 - rem;
192
193            // Encrypt byte
194            if let Some(crypto) = self.crypto.as_mut() {
195                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
196            }
197        }
198    }
199
200    /// Reset writing position
201    pub fn reset(&mut self) {
202        self.bit_pos = 0;
203    }
204
205    /// Get buffer length
206    pub fn len(&self) -> usize {
207        self.buffer.len()
208    }
209}
210
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy "cipher" for tests: adds one (wrapping) to every byte and records
    /// everything it has produced.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add keeps the mock usable for 0xFF instead of panicking in debug builds
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            // The mock keeps no plaintext cache
            &[]
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }
    }

    /// Helper to format buffer as binary strings
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.crypto = Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        }));

        for byte in [1, 2, 3] {
            writer.write_byte(byte);
        }
        for bit in [false, false, true] {
            writer.write_bit(bit);
        }
        // Aligning closes the partial byte (0b100 = 4), which is then encrypted to 5
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 4 bits in total
        for bit in [true, false, true, true] {
            stream.write_bit(bit);
        }

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101); // LSB-first: first bit is lowest
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // 3 bits
        stream.write_small(0b11, 2); // 2 bits
        stream.write_small(0b111, 3); // 3 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 11 bits: the first 7 almost fill one byte, the next 4 finish it
        // (1 bit) and spill into the second byte (3 bits)
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_byte(0xAA); // should align and write full byte

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // first bit written, rest padded
        assert_eq!(buf[1], 0xAA); // full byte
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]); // align and write slice

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // padding of first bit
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2); // 2 bits
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // 2 bits written, rest padded
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        println!("{:?}", bin);

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], 0b00001011); // first 4 bits
        assert_eq!(buf[1], 0xAA); // write_byte
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011); // last 2 bits
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128); // Crossed 127 = boundary of first byte
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455); // 4 bytes boundary
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        // Marker at the current position: nothing between marker and cursor
        stream.set_marker(None);
        assert!(stream.slice_marker(None).is_empty());
    }
}