Skip to main content

binary_codec/bitstream/
writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9    marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13    /// Create a new LSB-first writer
14    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15        Self {
16            buffer,
17            bit_pos: 0,
18            crypto: None,
19            marker: None,
20        }
21    }
22
23    /// Return slice of buffer
24    pub fn slice(&self) -> &[u8] {
25        &self.buffer
26    }
27
28    /// Set marker at specific position or current byte
29    pub fn set_marker(&mut self, pos: Option<usize>) {
30        self.marker = Some(pos.unwrap_or(self.byte_pos()));
31    }
32
33    /// Unset marker
34    pub fn reset_marker(&mut self) {
35        self.marker = None;
36    }
37
38    /// Return slice from marker (or start if marker is unset) to specific position or current byte
39    /// If crypto is set, it will return the unencrypted (cached) slice. This is different from other `slice` methods which always returns the raw buffer slice.
40    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41        let start = self.marker.unwrap_or(0);
42        let end = to.unwrap_or(self.byte_pos());
43
44        if let Some(crypto) = self.crypto.as_ref() {
45            return &crypto.get_cached(true)[start..end];
46        }
47
48        &self.buffer[start..end]
49    }
50
51    /// Set crypto stream
52    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
53        self.crypto = crypto;
54    }
55
56     /// Remove crypto stream
57    pub fn reset_crypto(&mut self) {
58        self.crypto = None;
59    }
60
61    /// Get byte position of writer
62    pub fn byte_pos(&self) -> usize {
63        self.bit_pos / 8
64    }
65
66    /// Write a single bit
67    pub fn write_bit(&mut self, val: bool) {
68        self.write_small(val as u8, 1);
69    }
70
71    /// Write 1-8 bits (LSB-first)
72    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
73        assert!(bits > 0 && bits < 8);
74
75        while bits > 0 {
76            self.ensure_byte();
77
78            // Determine the current bit position inside the current byte (0..7)
79            let bit_offset = self.bit_pos % 8;
80
81            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
82            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
83
84            // Create a mask for the bits we are going to modify.
85            // Example: bit_offset = 2, bits_in_current_byte = 3
86            // mask = 00011100 (we will only touch bits 2,3,4)
87            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
88
89            // Prepare the bits to write:
90            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
91            // 2️⃣ Shift them left by bit_offset to align inside the byte
92            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
93            // Step 1: val & 0b111 = 0b101
94            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
95            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
96
97            let byte_pos = self.byte_pos();
98
99            // Clear the bits in this byte where we will write
100            self.buffer[byte_pos] &= !mask;
101
102            // Write the new bits into the cleared positions
103            self.buffer[byte_pos] |= shifted_val & mask;
104
105            // Decrease the number of bits remaining to write
106            bits -= bits_in_current_byte;
107
108            // Shift val right to remove the bits we've already written
109            val >>= bits_in_current_byte;
110
111            self.bit_pos += bits_in_current_byte as usize;
112
113            // If full byte, encrypt it (if needed)
114            if self.bit_pos % 8 == 0 {
115                if let Some(crypto) = self.crypto.as_mut() {
116                    let b = self.buffer[byte_pos];
117                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
118                }
119            }
120        }
121    }
122
123    /// Write a full byte, starting at the next byte boundary
124    pub fn write_byte(&mut self, byte: u8) {
125        self.align_byte();
126        self.ensure_byte();
127
128        let byte_pos = self.byte_pos();
129        let byte = if let Some(crypto) = self.crypto.as_mut() {
130            crypto.apply_keystream_byte(byte)
131        } else {
132            byte
133        };
134
135        self.buffer[byte_pos] = byte;
136        self.bit_pos += 8;
137    }
138
139    /// Write a slice of bytes, starting at the next byte boundary
140    pub fn write_bytes(&mut self, data: &[u8]) {
141        self.align_byte();
142
143        if let Some(crypto) = self.crypto.as_mut() {
144            let encrypted = crypto.apply_keystream(data);
145            self.buffer.extend_from_slice(encrypted);
146        } else {
147            self.buffer.extend_from_slice(data);
148        }
149
150        self.bit_pos += 8 * data.len();
151    }
152
153    /// Write a dynamic int, starting at the next byte bounary
154    /// The last bit is used as a continuation flag for the next byte
155    pub fn write_dyn_int(&mut self, mut val: u128) {
156        while val > 0 {
157            let mut encoded = val % 128;
158            val /= 128;
159            if val > 0 {
160                encoded |= 128;
161            }
162            self.write_byte(encoded as u8);
163        }
164    }
165
166    /// Write a integer of fixed size to the buffer
167    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
168        self.write_bytes(&val.serialize());
169    }
170
171    /// Ensure the buffer has at least `byte_pos + 1` bytes
172    fn ensure_byte(&mut self) {
173        let byte_pos = self.byte_pos();
174        if byte_pos >= self.buffer.len() {
175            self.buffer.resize(byte_pos + 1, 0);
176        }
177    }
178
179    /// Align the stream to the next byte boundary
180    pub fn align_byte(&mut self) {
181        let rem = self.bit_pos % 8;
182        if rem != 0 {
183            let byte_pos = self.byte_pos();
184            self.bit_pos += 8 - rem;
185
186            // Encrypt byte
187            if let Some(crypto) = self.crypto.as_mut() {
188                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
189            }
190        }
191    }
192
193    /// Reset writing position
194    pub fn reset(&mut self) {
195        self.bit_pos = 0;
196    }
197
198    /// Get buffer length
199    pub fn len(&self) -> usize {
200        self.buffer.len()
201    }
202}
203
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Test-only "cipher": adds 1 (wrapping) to every byte and records everything it emits.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext.extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        // No plaintext cache is kept; no test in this module calls `slice_marker`
        // while this cipher is installed.
        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }

        fn replace(&mut self, other: Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // Three bits (0b100) form a partial byte that is encrypted when write_bytes aligns.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    /// Helper to format a buffer as binary strings (used in assertion messages).
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true); // 4 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101); // LSB-first: first bit is lowest
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // write 3 bits
        stream.write_small(0b11, 2); // write 2 bits
        stream.write_small(0b111, 3); // write 3 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 11 bits in total: 7 bits, then 4 more of which 1 fills the first byte
        // and 3 spill into the second byte.
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_byte(0xAA); // should align and write full byte

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // first bit written, rest padded
        assert_eq!(buf[1], 0xAA); // full byte
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]); // align and write slice

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // padding of first bit
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2); // 2 bits
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // 2 bits written, rest padded
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);

        assert_eq!(buf.len(), 5, "buffer: {:?}", bin);
        assert_eq!(buf[0], 0b00001011, "buffer: {:?}", bin); // first 4 bits
        assert_eq!(buf[1], 0xAA); // write_byte
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011); // last 2 bits
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128); // Crossed 127 = boundary of first byte
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455); // 4 bytes boundary
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        // Marker at the current position: nothing between marker and cursor.
        stream.set_marker(None);
        assert!(stream.slice_marker(None).is_empty());
    }
}