1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9 marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 marker: None,
20 }
21 }
22
23 pub fn slice(&self) -> &[u8] {
25 &self.buffer
26 }
27
28 pub fn set_marker(&mut self, pos: Option<usize>) {
30 self.marker = Some(pos.unwrap_or(self.byte_pos()));
31 }
32
33 pub fn reset_marker(&mut self) {
35 self.marker = None;
36 }
37
38 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41 let start = self.marker.unwrap_or(0);
42 let end = to.unwrap_or(self.byte_pos());
43
44 if let Some(crypto) = self.crypto.as_ref() {
45 return &crypto.get_cached(true)[start..end];
46 }
47
48 &self.buffer[start..end]
49 }
50
51 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
53 self.crypto = crypto;
54 }
55
56 pub fn reset_crypto(&mut self) {
58 self.crypto = None;
59 }
60
61 pub fn byte_pos(&self) -> usize {
63 self.bit_pos / 8
64 }
65
66 pub fn write_bit(&mut self, val: bool) {
68 self.write_small(val as u8, 1);
69 }
70
71 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
73 assert!(bits > 0 && bits < 8);
74
75 while bits > 0 {
76 self.ensure_byte();
77
78 let bit_offset = self.bit_pos % 8;
80
81 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
83
84 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
88
89 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
96
97 let byte_pos = self.byte_pos();
98
99 self.buffer[byte_pos] &= !mask;
101
102 self.buffer[byte_pos] |= shifted_val & mask;
104
105 bits -= bits_in_current_byte;
107
108 val >>= bits_in_current_byte;
110
111 self.bit_pos += bits_in_current_byte as usize;
112
113 if self.bit_pos % 8 == 0 {
115 if let Some(crypto) = self.crypto.as_mut() {
116 let b = self.buffer[byte_pos];
117 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
118 }
119 }
120 }
121 }
122
123 pub fn write_byte(&mut self, byte: u8) {
125 self.align_byte();
126 self.ensure_byte();
127
128 let byte_pos = self.byte_pos();
129 let byte = if let Some(crypto) = self.crypto.as_mut() {
130 crypto.apply_keystream_byte(byte)
131 } else {
132 byte
133 };
134
135 self.buffer[byte_pos] = byte;
136 self.bit_pos += 8;
137 }
138
139 pub fn write_bytes(&mut self, data: &[u8]) {
141 self.align_byte();
142
143 if let Some(crypto) = self.crypto.as_mut() {
144 let encrypted = crypto.apply_keystream(data);
145 self.buffer.extend_from_slice(encrypted);
146 } else {
147 self.buffer.extend_from_slice(data);
148 }
149
150 self.bit_pos += 8 * data.len();
151 }
152
153 pub fn write_dyn_int(&mut self, mut val: u128) {
156 while val > 0 {
157 let mut encoded = val % 128;
158 val /= 128;
159 if val > 0 {
160 encoded |= 128;
161 }
162 self.write_byte(encoded as u8);
163 }
164 }
165
166 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
168 self.write_bytes(&val.serialize());
169 }
170
171 fn ensure_byte(&mut self) {
173 let byte_pos = self.byte_pos();
174 if byte_pos >= self.buffer.len() {
175 self.buffer.resize(byte_pos + 1, 0);
176 }
177 }
178
179 pub fn align_byte(&mut self) {
181 let rem = self.bit_pos % 8;
182 if rem != 0 {
183 let byte_pos = self.byte_pos();
184 self.bit_pos += 8 - rem;
185
186 if let Some(crypto) = self.crypto.as_mut() {
188 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
189 }
190 }
191 }
192
193 pub fn reset(&mut self) {
195 self.bit_pos = 0;
196 }
197
198 pub fn len(&self) -> usize {
200 self.buffer.len()
201 }
202}
203
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Test cipher: adds 1 (wrapping) to every byte and records the ciphertext.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // `wrapping_add`: a plain `b + 1` panics on 0xFF in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }

        fn replace(&mut self, other: Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }
    }

    /// Renders each byte as an 8-digit binary string (used in failure messages).
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        assert_eq!(buf.len(), 5, "buffer contents: {:?}", bin);
        assert_eq!(buf[0], 0b00001011, "buffer contents: {:?}", bin);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011, "buffer contents: {:?}", bin);
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10u8, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30u8, 40, 50]);

        // Marker at the current position: the slice up to the cursor is empty.
        stream.set_marker(None);
        assert!(stream.slice_marker(None).is_empty());
    }
}