binary_codec/bitstream/
writer.rs

1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6    buffer: &'a mut Vec<u8>,
7    bit_pos: usize,
8    crypto: Option<Box<dyn CryptoStream>>,
9}
10
11impl<'a> BitStreamWriter<'a> {
12    /// Create a new LSB-first writer
13    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
14        Self {
15            buffer,
16            bit_pos: 0,
17            crypto: None,
18        }
19    }
20
21    /// Set crypto stream
22    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
23        self.crypto = crypto;
24    }
25
26    /// Get byte position of writer
27    pub fn byte_pos(&self) -> usize {
28        self.bit_pos / 8
29    }
30
31    /// Write a single bit
32    pub fn write_bit(&mut self, val: bool) {
33        self.write_small(val as u8, 1);
34    }
35
36    /// Write 1-8 bits (LSB-first)
37    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
38        assert!(bits > 0 && bits < 8);
39
40        while bits > 0 {
41            self.ensure_byte();
42
43            // Determine the current bit position inside the current byte (0..7)
44            let bit_offset = self.bit_pos % 8;
45
46            // Determine how many bytes left to read. Min(bits left in this byte, bits to write)
47            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
48
49            // Create a mask for the bits we are going to modify.
50            // Example: bit_offset = 2, bits_in_current_byte = 3
51            // mask = 00011100 (we will only touch bits 2,3,4)
52            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
53
54            // Prepare the bits to write:
55            // 1️⃣ Take only the lowest 'bits_in_current_byte' bits from val
56            // 2️⃣ Shift them left by bit_offset to align inside the byte
57            // Example: val = 0b110101, bits_in_current_byte = 3, bit_offset = 2
58            // Step 1: val & 0b111 = 0b101
59            // Step 2: 0b101 << 2 = 0b10100 → shifted_val
60            let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
61
62            let byte_pos = self.byte_pos();
63
64            // Clear the bits in this byte where we will write
65            self.buffer[byte_pos] &= !mask;
66
67            // Write the new bits into the cleared positions
68            self.buffer[byte_pos] |= shifted_val & mask;
69
70            // Decrease the number of bits remaining to write
71            bits -= bits_in_current_byte;
72
73            // Shift val right to remove the bits we've already written
74            val >>= bits_in_current_byte;
75
76            self.bit_pos += bits_in_current_byte as usize;
77
78            // If full byte, encrypt it (if needed)
79            if self.bit_pos % 8 == 0 {
80                if let Some(crypto) = self.crypto.as_mut() {
81                    let b = self.buffer[byte_pos];
82                    self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
83                }
84            }
85        }
86    }
87
88    /// Write a full byte, starting at the next byte boundary
89    pub fn write_byte(&mut self, byte: u8) {
90        self.align_byte();
91        self.ensure_byte();
92
93        let byte_pos = self.byte_pos();
94        let byte = if let Some(crypto) = self.crypto.as_mut() {
95            crypto.apply_keystream_byte(byte)
96        } else {
97            byte
98        };
99
100        self.buffer[byte_pos] = byte;
101        self.bit_pos += 8;
102    }
103
104    /// Write a slice of bytes, starting at the next byte boundary
105    pub fn write_bytes(&mut self, data: &[u8]) {
106        self.align_byte();
107
108        if let Some(crypto) = self.crypto.as_mut() {
109            let encrypted = crypto.apply_keystream(data);
110            self.buffer.extend_from_slice(encrypted);
111        } else {
112            self.buffer.extend_from_slice(data);
113        }
114
115        self.bit_pos += 8 * data.len();
116    }
117
118    /// Write a dynamic int, starting at the next byte bounary
119    /// The last bit is used as a continuation flag for the next byte
120    pub fn write_dyn_int(&mut self, mut val: u128) {
121        while val > 0 {
122            let mut encoded = val % 128;
123            val /= 128;
124            if val > 0 {
125                encoded |= 128;
126            }
127            self.write_byte(encoded as u8);
128        }
129    }
130
131    /// Write a integer of fixed size to the buffer
132    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
133        self.write_bytes(&val.serialize());
134    }
135
136    /// Ensure the buffer has at least `byte_pos + 1` bytes
137    fn ensure_byte(&mut self) {
138        let byte_pos = self.byte_pos();
139        if byte_pos >= self.buffer.len() {
140            self.buffer.resize(byte_pos + 1, 0);
141        }
142    }
143
144    /// Align the stream to the next byte boundary
145    pub fn align_byte(&mut self) {
146        let rem = self.bit_pos % 8;
147        if rem != 0 {
148            let byte_pos = self.byte_pos();
149            self.bit_pos += 8 - rem;
150
151            // Encrypt byte
152            if let Some(crypto) = self.crypto.as_mut() {
153                self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
154            }
155        }
156    }
157
158    /// Reset writing position
159    pub fn reset(&mut self) {
160        self.bit_pos = 0;
161    }
162
163    /// Get buffer length
164    pub fn len(&self) -> usize {
165        self.buffer.len()
166    }
167}
168
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy cipher for tests: every byte is incremented by one (wrapping at
    /// 0xFF) and the produced ciphertext is recorded.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add keeps the helper panic-free for 0xFF in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }
    }

    /// Helper to format buffer as binary strings
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        // Go through the public setter rather than poking the private field.
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // Three bits (value 0b100 = 4); the partial byte is encrypted on alignment.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true); // 4 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101); // LSB-first: first bit is lowest
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // write 3 bits
        stream.write_small(0b11, 2); // write 2 bits
        stream.write_small(0b111, 3); // write 3 bits

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101); // bits packed LSB-first
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // Write 11 bits: the first 7 bits almost fill one byte; of the next
        // 4 bits, 1 completes the first byte and 3 spill into the second.
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_byte(0xAA); // should align and write full byte

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001); // first bit written, rest padded
        assert_eq!(buf[1], 0xAA); // full byte
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true); // 1 bit
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]); // align and write slice

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001); // padding of first bit
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2); // 2 bits
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011); // 2 bits written, rest padded
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        // Show the binary dump only when the length check fails.
        assert_eq!(buf.len(), 5, "unexpected buffer: {:?}", buffer_to_bin(&buf));
        assert_eq!(buf[0], 0b00001011); // first 4 bits
        assert_eq!(buf[1], 0xAA); // write_byte
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011); // last 2 bits
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128); // Crossed 127 = boundary of first byte
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455); // 4 bytes boundary
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }
}