Skip to main content

binary_codec/bitstream/
writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9    marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13    /// Create a new LSB-first writer
14    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15        Self {
16            buffer,
17            bit_pos: 0,
18            crypto: None,
19            marker: None,
20        }
21    }
22
23    /// Return slice of buffer
24    pub fn slice(&self) -> &[u8] {
25        &self.buffer
26    }
27
28    /// Set marker at specific position or current byte
29    pub fn set_marker(&mut self, pos: Option<usize>) {
30        self.marker = Some(pos.unwrap_or(self.byte_pos()));
31    }
32
33    /// Unset marker
34    pub fn reset_marker(&mut self) {
35        self.marker = None;
36    }
37
38    /// Return slice from marker (or start if marker is unset) to specific position or current byte
39    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
40        &self.buffer[self.marker.unwrap_or(0)..to.unwrap_or(self.byte_pos())]
41    }
42
43    /// Set crypto stream
44    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
45        self.crypto = crypto;
46    }
47
48    /// Get byte position of writer
49    pub fn byte_pos(&self) -> usize {
50        self.bit_pos / 8
51    }
52
53    /// Write a single bit
54    pub fn write_bit(&mut self, val: bool) {
55        self.write_small(val as u8, 1);
56    }
57
58    /// Write 1-8 bits (LSB-first)
59    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
60        assert!(bits > 0 && bits < 8);
61
62        while bits > 0 {
63            self.ensure_byte();
64
65            // Determine the current bit position inside the current byte (0..7)
66            let bit_offset = self.bit_pos % 8;
67
68            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
69            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
70
71            // Create a mask for the bits we are going to modify.
72            // Example: bit_offset = 2, bits_in_current_byte = 3
73            // mask = 00011100 (we will only touch bits 2,3,4)
74            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
75
76            // Prepare the bits to write:
77            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
78            // 2️⃣ Shift them left by bit_offset to align inside the byte
79            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
80            // Step 1: val & 0b111 = 0b101
81            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
82            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
83
84            let byte_pos = self.byte_pos();
85
86            // Clear the bits in this byte where we will write
87            self.buffer[byte_pos] &= !mask;
88
89            // Write the new bits into the cleared positions
90            self.buffer[byte_pos] |= shifted_val & mask;
91
92            // Decrease the number of bits remaining to write
93            bits -= bits_in_current_byte;
94
95            // Shift val right to remove the bits we've already written
96            val >>= bits_in_current_byte;
97
98            self.bit_pos += bits_in_current_byte as usize;
99
100            // If full byte, encrypt it (if needed)
101            if self.bit_pos % 8 == 0 {
102                if let Some(crypto) = self.crypto.as_mut() {
103                    let b = self.buffer[byte_pos];
104                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
105                }
106            }
107        }
108    }
109
110    /// Write a full byte, starting at the next byte boundary
111    pub fn write_byte(&mut self, byte: u8) {
112        self.align_byte();
113        self.ensure_byte();
114
115        let byte_pos = self.byte_pos();
116        let byte = if let Some(crypto) = self.crypto.as_mut() {
117            crypto.apply_keystream_byte(byte)
118        } else {
119            byte
120        };
121
122        self.buffer[byte_pos] = byte;
123        self.bit_pos += 8;
124    }
125
126    /// Write a slice of bytes, starting at the next byte boundary
127    pub fn write_bytes(&mut self, data: &[u8]) {
128        self.align_byte();
129
130        if let Some(crypto) = self.crypto.as_mut() {
131            let encrypted = crypto.apply_keystream(data);
132            self.buffer.extend_from_slice(encrypted);
133        } else {
134            self.buffer.extend_from_slice(data);
135        }
136
137        self.bit_pos += 8 * data.len();
138    }
139
140    /// Write a dynamic int, starting at the next byte bounary
141    /// The last bit is used as a continuation flag for the next byte
142    pub fn write_dyn_int(&mut self, mut val: u128) {
143        while val > 0 {
144            let mut encoded = val % 128;
145            val /= 128;
146            if val > 0 {
147                encoded |= 128;
148            }
149            self.write_byte(encoded as u8);
150        }
151    }
152
153    /// Write a integer of fixed size to the buffer
154    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
155        self.write_bytes(&val.serialize());
156    }
157
158    /// Ensure the buffer has at least `byte_pos + 1` bytes
159    fn ensure_byte(&mut self) {
160        let byte_pos = self.byte_pos();
161        if byte_pos >= self.buffer.len() {
162            self.buffer.resize(byte_pos + 1, 0);
163        }
164    }
165
166    /// Align the stream to the next byte boundary
167    pub fn align_byte(&mut self) {
168        let rem = self.bit_pos % 8;
169        if rem != 0 {
170            let byte_pos = self.byte_pos();
171            self.bit_pos += 8 - rem;
172
173            // Encrypt byte
174            if let Some(crypto) = self.crypto.as_mut() {
175                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
176            }
177        }
178    }
179
180    /// Reset writing position
181    pub fn reset(&mut self) {
182        self.bit_pos = 0;
183    }
184
185    /// Get buffer length
186    pub fn len(&self) -> usize {
187        self.buffer.len()
188    }
189}
190
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy cipher for the tests: adds one to every byte and keeps a log of
    /// everything it has produced.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let encrypted = b + 1;
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            // Everything appended from here on is the output for this call.
            let first_new = self.ciphertext.len();
            self.ciphertext.extend(slice.iter().map(|plain| plain + 1));
            &self.ciphertext[first_new..]
        }
    }

    /// Render every byte of `buffer` as an 8-digit binary string.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        let mut rendered = Vec::with_capacity(buffer.len());
        for byte in buffer {
            rendered.push(format!("{:08b}", byte));
        }
        rendered
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        for &plain in &[1u8, 2, 3] {
            writer.write_byte(plain);
        }
        for &bit in &[false, false, true] {
            writer.write_bit(bit);
        }
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        // The three loose bits form the byte 4, which is encrypted on alignment.
        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // Four bits in total.
        for &bit in &[true, false, true, true] {
            stream.write_bit(bit);
        }

        assert_eq!(buf.len(), 1);
        // LSB-first: the first bit written is the lowest bit of the byte.
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 3 + 2 + 3 bits fill exactly one byte.
        for &(value, width) in &[(0b101u8, 3u8), (0b11, 2), (0b111, 3)] {
            stream.write_small(value, width);
        }

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 11 bits in total: the first 7 almost fill byte 0; of the next 4,
        // one completes byte 0 and the other three spill into byte 1.
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        // Must skip the rest of the first byte before storing the full byte.
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // single bit, remaining bits are padding
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        // Aligns first, then stores the slice.
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // single bit, remaining bits are padding
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // two bits, remaining bits are padding
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        println!("{:?}", bin);

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], 0b00001011); // the first four bits
        assert_eq!(buf[1], 0xAA); // from write_byte
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011); // the trailing two bits
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // Largest value that fits in one byte.
        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        // One past the single-byte limit: needs two bytes.
        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        // Largest value that fits in four bytes.
        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        // One row per call above, in the same order.
        #[rustfmt::skip]
        let expected: Vec<u8> = vec![
            1,
            2,
            0, 2,
            0, 4,
            0, 0, 0, 3,
            0, 0, 0, 6,
            0, 0, 0, 0, 0, 0, 0, 4,
            0, 0, 0, 0, 0, 0, 0, 8,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10,
        ];

        assert_eq!(expected, buf);
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);

        // No marker yet: the slice starts at the beginning of the buffer.
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        // Explicit marker, end defaults to the current byte position.
        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        // Marker at the current position: nothing lies between marker and position.
        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[]);
    }
}