1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9}
10
11impl<'a> BitStreamWriter<'a> {
12 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
14 Self {
15 buffer,
16 bit_pos: 0,
17 crypto: None,
18 }
19 }
20
21 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
23 self.crypto = crypto;
24 }
25
26 pub fn byte_pos(&self) -> usize {
28 self.bit_pos / 8
29 }
30
31 pub fn write_bit(&mut self, val: bool) {
33 self.write_small(val as u8, 1);
34 }
35
36 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
38 assert!(bits > 0 && bits < 8);
39
40 while bits > 0 {
41 self.ensure_byte();
42
43 let bit_offset = self.bit_pos % 8;
45
46 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
48
49 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
53
54 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
61
62 let byte_pos = self.byte_pos();
63
64 self.buffer[byte_pos] &= !mask;
66
67 self.buffer[byte_pos] |= shifted_val & mask;
69
70 bits -= bits_in_current_byte;
72
73 val >>= bits_in_current_byte;
75
76 self.bit_pos += bits_in_current_byte as usize;
77
78 if self.bit_pos % 8 == 0 {
80 if let Some(crypto) = self.crypto.as_mut() {
81 let b = self.buffer[byte_pos];
82 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
83 }
84 }
85 }
86 }
87
88 pub fn write_byte(&mut self, byte: u8) {
90 self.align_byte();
91 self.ensure_byte();
92
93 let byte_pos = self.byte_pos();
94 let byte = if let Some(crypto) = self.crypto.as_mut() {
95 crypto.apply_keystream_byte(byte)
96 } else {
97 byte
98 };
99
100 self.buffer[byte_pos] = byte;
101 self.bit_pos += 8;
102 }
103
104 pub fn write_bytes(&mut self, data: &[u8]) {
106 self.align_byte();
107
108 if let Some(crypto) = self.crypto.as_mut() {
109 let encrypted = crypto.apply_keystream(data);
110 self.buffer.extend_from_slice(encrypted);
111 } else {
112 self.buffer.extend_from_slice(data);
113 }
114
115 self.bit_pos += 8 * data.len();
116 }
117
118 pub fn write_dyn_int(&mut self, mut val: u128) {
121 while val > 0 {
122 let mut encoded = val % 128;
123 val /= 128;
124 if val > 0 {
125 encoded |= 128;
126 }
127 self.write_byte(encoded as u8);
128 }
129 }
130
131 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
133 self.write_bytes(&val.serialize());
134 }
135
136 fn ensure_byte(&mut self) {
138 let byte_pos = self.byte_pos();
139 if byte_pos >= self.buffer.len() {
140 self.buffer.resize(byte_pos + 1, 0);
141 }
142 }
143
144 pub fn align_byte(&mut self) {
146 let rem = self.bit_pos % 8;
147 if rem != 0 {
148 let byte_pos = self.byte_pos();
149 self.bit_pos += 8 - rem;
150
151 if let Some(crypto) = self.crypto.as_mut() {
153 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
154 }
155 }
156 }
157
158 pub fn reset(&mut self) {
160 self.bit_pos = 0;
161 }
162
163 pub fn len(&self) -> usize {
165 self.buffer.len()
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::BitStreamWriter;
    use crate::CryptoStream;

    /// Keystream stub that adds one to every byte and keeps a log of its output.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let encrypted = b + 1;
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let first_new = self.ciphertext.len();
            for &byte in slice {
                self.ciphertext.push(byte + 1);
            }
            &self.ciphertext[first_new..]
        }
    }

    /// Renders each byte as an 8-digit binary string, for debug output.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        let mut rendered = Vec::with_capacity(buffer.len());
        for byte in buffer {
            rendered.push(format!("{:08b}", byte));
        }
        rendered
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        for b in 1..=3 {
            writer.write_byte(b);
        }
        for &bit in &[false, false, true] {
            writer.write_bit(bit);
        }
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        // Every byte, including the padded partial one, comes out shifted up by one.
        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        for &bit in &[true, false, true, true] {
            stream.write_bit(bit);
        }

        assert_eq!(buf, vec![0b0000_1101]);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);

        assert_eq!(buf, vec![0b1111_1101]);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b0010_1011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf, vec![0b1010_1011, 0b0000_0110]);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_byte(0xAA);

        assert_eq!(buf, vec![0b0000_0001, 0xAA]);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf, vec![0b0000_0001, 0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf, vec![0b0000_0011, 0xFF]);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        println!("{:?}", buffer_to_bin(&buf));

        assert_eq!(buf, vec![0b0000_1011, 0xAA, 0xBB, 0xCC, 0b0000_0011]);
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        stream.write_dyn_int(0x0FFF_FFFF);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        let expected: Vec<u8> = vec![
            1, // 1u8
            2, // 1i8
            0, 2, // 2u16
            0, 4, // 2i16
            0, 0, 0, 3, // 3u32
            0, 0, 0, 6, // 3i32
            0, 0, 0, 0, 0, 0, 0, 4, // 4u64
            0, 0, 0, 0, 0, 0, 0, 8, // 4i64
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, // 5u128
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, // 5i128
        ];
        assert_eq!(expected, buf);
    }
}