binary_codec/bitstream/
writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9}
10
11impl<'a> BitStreamWriter<'a> {
12    /// Create a new LSB-first writer
13    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
14        Self {
15            buffer,
16            bit_pos: 0,
17            crypto: None,
18        }
19    }
20
21    /// Return slice of buffer
22    pub fn slice(&self) -> &[u8] {
23        &self.buffer
24    }
25
26    /// Set crypto stream
27    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
28        self.crypto = crypto;
29    }
30
31    /// Get byte position of writer
32    pub fn byte_pos(&self) -> usize {
33        self.bit_pos / 8
34    }
35
36    /// Write a single bit
37    pub fn write_bit(&mut self, val: bool) {
38        self.write_small(val as u8, 1);
39    }
40
41    /// Write 1-8 bits (LSB-first)
42    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
43        assert!(bits > 0 && bits < 8);
44
45        while bits > 0 {
46            self.ensure_byte();
47
48            // Determine the current bit position inside the current byte (0..7)
49            let bit_offset = self.bit_pos % 8;
50
51            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
52            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
53
54            // Create a mask for the bits we are going to modify.
55            // Example: bit_offset = 2, bits_in_current_byte = 3
56            // mask = 00011100 (we will only touch bits 2,3,4)
57            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
58
59            // Prepare the bits to write:
60            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
61            // 2️⃣ Shift them left by bit_offset to align inside the byte
62            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
63            // Step 1: val & 0b111 = 0b101
64            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
65            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
66
67            let byte_pos = self.byte_pos();
68
69            // Clear the bits in this byte where we will write
70            self.buffer[byte_pos] &= !mask;
71
72            // Write the new bits into the cleared positions
73            self.buffer[byte_pos] |= shifted_val & mask;
74
75            // Decrease the number of bits remaining to write
76            bits -= bits_in_current_byte;
77
78            // Shift val right to remove the bits we've already written
79            val >>= bits_in_current_byte;
80
81            self.bit_pos += bits_in_current_byte as usize;
82
83            // If full byte, encrypt it (if needed)
84            if self.bit_pos % 8 == 0 {
85                if let Some(crypto) = self.crypto.as_mut() {
86                    let b = self.buffer[byte_pos];
87                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
88                }
89            }
90        }
91    }
92
93    /// Write a full byte, starting at the next byte boundary
94    pub fn write_byte(&mut self, byte: u8) {
95        self.align_byte();
96        self.ensure_byte();
97
98        let byte_pos = self.byte_pos();
99        let byte = if let Some(crypto) = self.crypto.as_mut() {
100            crypto.apply_keystream_byte(byte)
101        } else {
102            byte
103        };
104
105        self.buffer[byte_pos] = byte;
106        self.bit_pos += 8;
107    }
108
109    /// Write a slice of bytes, starting at the next byte boundary
110    pub fn write_bytes(&mut self, data: &[u8]) {
111        self.align_byte();
112
113        if let Some(crypto) = self.crypto.as_mut() {
114            let encrypted = crypto.apply_keystream(data);
115            self.buffer.extend_from_slice(encrypted);
116        } else {
117            self.buffer.extend_from_slice(data);
118        }
119
120        self.bit_pos += 8 * data.len();
121    }
122
123    /// Write a dynamic int, starting at the next byte bounary
124    /// The last bit is used as a continuation flag for the next byte
125    pub fn write_dyn_int(&mut self, mut val: u128) {
126        while val > 0 {
127            let mut encoded = val % 128;
128            val /= 128;
129            if val > 0 {
130                encoded |= 128;
131            }
132            self.write_byte(encoded as u8);
133        }
134    }
135
136    /// Write a integer of fixed size to the buffer
137    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
138        self.write_bytes(&val.serialize());
139    }
140
141    /// Ensure the buffer has at least `byte_pos + 1` bytes
142    fn ensure_byte(&mut self) {
143        let byte_pos = self.byte_pos();
144        if byte_pos >= self.buffer.len() {
145            self.buffer.resize(byte_pos + 1, 0);
146        }
147    }
148
149    /// Align the stream to the next byte boundary
150    pub fn align_byte(&mut self) {
151        let rem = self.bit_pos % 8;
152        if rem != 0 {
153            let byte_pos = self.byte_pos();
154            self.bit_pos += 8 - rem;
155
156            // Encrypt byte
157            if let Some(crypto) = self.crypto.as_mut() {
158                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
159            }
160        }
161    }
162
163    /// Reset writing position
164    pub fn reset(&mut self) {
165        self.bit_pos = 0;
166    }
167
168    /// Get buffer length
169    pub fn len(&self) -> usize {
170        self.buffer.len()
171    }
172}
173
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy cipher used by the tests: adds one to every byte and keeps a copy
    /// of everything it produced.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // Wrapping add: a plain `b + 1` panics in debug builds for 0xFF.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }
    }

    /// Helper to format a buffer as binary strings (used in failure messages).
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // Three bits -> 0b100 = 4, encrypted to 5 when the byte gets aligned.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 4 bits in total
        for bit in [true, false, true, true] {
            stream.write_bit(bit);
        }

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101); // LSB-first: first bit is lowest
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // 3 bits
        stream.write_small(0b11, 2); // 2 bits
        stream.write_small(0b111, 3); // 3 bits -> exactly one byte

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 11 bits in total: 7 bits first, then 4 bits of which 1 completes the
        // first byte and the remaining 3 spill into the second byte.
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_byte(0xAA); // aligns first, then writes the full byte

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // first bit written, rest padded
        assert_eq!(buf[1], 0xAA); // full byte
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]); // aligns, then writes the slice

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // single bit plus padding
        assert_eq!(&buf[1..], &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2); // 2 bits
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // 2 bits written, rest padded
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        // The binary dump is only shown when the assertion fails.
        assert_eq!(buf.len(), 5, "buffer: {:?}", buffer_to_bin(&buf));
        assert_eq!(buf[0], 0b00001011); // first 4 bits
        assert_eq!(buf[1], 0xAA); // write_byte
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011); // last 2 bits
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127); // largest value that fits in one byte
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128); // first value that needs a second byte
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455); // largest value that fits in 4 bytes
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }
}