// binary_codec/bitstream/writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9    marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13    /// Create a new LSB-first writer
14    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15        Self {
16            buffer,
17            bit_pos: 0,
18            crypto: None,
19            marker: None,
20        }
21    }
22
23    /// Return slice of buffer
24    pub fn slice(&self) -> &[u8] {
25        &self.buffer
26    }
27
28    /// Set marker at specific position or current byte
29    pub fn set_marker(&mut self, pos: Option<usize>) {
30        self.marker = Some(pos.unwrap_or(self.byte_pos()));
31    }
32
33    /// Unset marker
34    pub fn reset_marker(&mut self) {
35        self.marker = None;
36    }
37
38    /// Return slice from marker (or start if marker is unset) to specific position or current byte
39    /// If crypto is set, it will return the unencrypted (cached) slice. This is different from other `slice` methods which always returns the raw buffer slice.
40    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41        let start = self.marker.unwrap_or(0);
42        let end = to.unwrap_or(self.byte_pos());
43
44        if let Some(crypto) = self.crypto.as_ref() {
45            return &crypto.get_cached(true)[start..end];
46        }
47
48        &self.buffer[start..end]
49    }
50
51    /// Set crypto stream
52    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
53        self.crypto = crypto;
54    }
55
56    /// Get byte position of writer
57    pub fn byte_pos(&self) -> usize {
58        self.bit_pos / 8
59    }
60
61    /// Write a single bit
62    pub fn write_bit(&mut self, val: bool) {
63        self.write_small(val as u8, 1);
64    }
65
66    /// Write 1-8 bits (LSB-first)
67    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
68        assert!(bits > 0 && bits < 8);
69
70        while bits > 0 {
71            self.ensure_byte();
72
73            // Determine the current bit position inside the current byte (0..7)
74            let bit_offset = self.bit_pos % 8;
75
76            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
77            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
78
79            // Create a mask for the bits we are going to modify.
80            // Example: bit_offset = 2, bits_in_current_byte = 3
81            // mask = 00011100 (we will only touch bits 2,3,4)
82            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
83
84            // Prepare the bits to write:
85            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
86            // 2️⃣ Shift them left by bit_offset to align inside the byte
87            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
88            // Step 1: val & 0b111 = 0b101
89            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
90            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
91
92            let byte_pos = self.byte_pos();
93
94            // Clear the bits in this byte where we will write
95            self.buffer[byte_pos] &= !mask;
96
97            // Write the new bits into the cleared positions
98            self.buffer[byte_pos] |= shifted_val & mask;
99
100            // Decrease the number of bits remaining to write
101            bits -= bits_in_current_byte;
102
103            // Shift val right to remove the bits we've already written
104            val >>= bits_in_current_byte;
105
106            self.bit_pos += bits_in_current_byte as usize;
107
108            // If full byte, encrypt it (if needed)
109            if self.bit_pos % 8 == 0 {
110                if let Some(crypto) = self.crypto.as_mut() {
111                    let b = self.buffer[byte_pos];
112                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
113                }
114            }
115        }
116    }
117
118    /// Write a full byte, starting at the next byte boundary
119    pub fn write_byte(&mut self, byte: u8) {
120        self.align_byte();
121        self.ensure_byte();
122
123        let byte_pos = self.byte_pos();
124        let byte = if let Some(crypto) = self.crypto.as_mut() {
125            crypto.apply_keystream_byte(byte)
126        } else {
127            byte
128        };
129
130        self.buffer[byte_pos] = byte;
131        self.bit_pos += 8;
132    }
133
134    /// Write a slice of bytes, starting at the next byte boundary
135    pub fn write_bytes(&mut self, data: &[u8]) {
136        self.align_byte();
137
138        if let Some(crypto) = self.crypto.as_mut() {
139            let encrypted = crypto.apply_keystream(data);
140            self.buffer.extend_from_slice(encrypted);
141        } else {
142            self.buffer.extend_from_slice(data);
143        }
144
145        self.bit_pos += 8 * data.len();
146    }
147
148    /// Write a dynamic int, starting at the next byte bounary
149    /// The last bit is used as a continuation flag for the next byte
150    pub fn write_dyn_int(&mut self, mut val: u128) {
151        while val > 0 {
152            let mut encoded = val % 128;
153            val /= 128;
154            if val > 0 {
155                encoded |= 128;
156            }
157            self.write_byte(encoded as u8);
158        }
159    }
160
161    /// Write a integer of fixed size to the buffer
162    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
163        self.write_bytes(&val.serialize());
164    }
165
166    /// Ensure the buffer has at least `byte_pos + 1` bytes
167    fn ensure_byte(&mut self) {
168        let byte_pos = self.byte_pos();
169        if byte_pos >= self.buffer.len() {
170            self.buffer.resize(byte_pos + 1, 0);
171        }
172    }
173
174    /// Align the stream to the next byte boundary
175    pub fn align_byte(&mut self) {
176        let rem = self.bit_pos % 8;
177        if rem != 0 {
178            let byte_pos = self.byte_pos();
179            self.bit_pos += 8 - rem;
180
181            // Encrypt byte
182            if let Some(crypto) = self.crypto.as_mut() {
183                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
184            }
185        }
186    }
187
188    /// Reset writing position
189    pub fn reset(&mut self) {
190        self.bit_pos = 0;
191    }
192
193    /// Get buffer length
194    pub fn len(&self) -> usize {
195        self.buffer.len()
196    }
197}
198
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Test cipher: adds one (wrapping) to every byte and records the produced ciphertext.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add: a plain `b + 1` would panic in debug builds for 255.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            // Not exercised by these tests.
            &[]
        }
    }

    /// Fresh `PlusOneEncrypter` boxed for `set_crypto`.
    fn plus_one() -> Option<Box<dyn CryptoStream>> {
        Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        }))
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(plus_one());

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_align_encrypts_partial_byte_once() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(plus_one());

        writer.write_small(0b11, 2);
        writer.align_byte();
        // Already aligned: must not encrypt the byte a second time.
        writer.align_byte();

        assert_eq!(buf, vec![0b100]);
    }

    /// Helper to format buffer as binary strings (used in assertion messages).
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true); // 4 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101); // LSB-first: first bit is lowest
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // write 3 bits
        stream.write_small(0b11, 2); // write 2 bits
        stream.write_small(0b111, 3); // write 3 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 11 bits: the first 7 almost fill one byte, then 1 more bit completes it
        // and the remaining 3 bits spill into the second byte.
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_byte(0xAA); // should align and write full byte

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // first bit written, rest padded
        assert_eq!(buf[1], 0xAA); // full byte
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]); // align and write slice

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // padding of first bit
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2); // 2 bits
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // 2 bits written, rest padded
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        assert_eq!(buf.len(), 5, "buffer: {:?}", bin);
        assert_eq!(buf[0], 0b00001011, "buffer: {:?}", bin); // first 4 bits
        assert_eq!(buf[1], 0xAA, "buffer: {:?}", bin); // write_byte
        assert_eq!(buf[2], 0xBB, "buffer: {:?}", bin);
        assert_eq!(buf[3], 0xCC, "buffer: {:?}", bin);
        assert_eq!(buf[4], 0b00000011, "buffer: {:?}", bin); // last 2 bits
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128); // Crossed 127 = boundary of first byte
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455); // 4 bytes boundary
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10u8, 20, 30, 40][..]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30u8, 40, 50][..]);

        stream.set_marker(None);
        assert!(stream.slice_marker(None).is_empty());
    }
}