1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9 marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 marker: None,
20 }
21 }
22
23 pub fn slice(&self) -> &[u8] {
25 &self.buffer
26 }
27
28 pub fn set_marker(&mut self, pos: Option<usize>) {
30 self.marker = Some(pos.unwrap_or(self.byte_pos()));
31 }
32
33 pub fn reset_marker(&mut self) {
35 self.marker = None;
36 }
37
38 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41 let start = self.marker.unwrap_or(0);
42 let end = to.unwrap_or(self.byte_pos());
43
44 if let Some(crypto) = self.crypto.as_ref() {
45 return &crypto.get_cached(true)[start..end];
46 }
47
48 &self.buffer[start..end]
49 }
50
51 pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
53 if let Some(new) = crypto.as_mut() {
54 if let Some(existing) = self.crypto.as_ref() {
55 new.replace(existing);
56 } else {
57 new.set_cached(self.slice());
59 }
60 }
61
62 self.crypto = crypto;
63 }
64
65 pub fn reset_crypto(&mut self) {
67 self.crypto = None;
68 }
69
70 pub fn byte_pos(&self) -> usize {
72 self.bit_pos / 8
73 }
74
75 pub fn write_bit(&mut self, val: bool) {
77 self.write_small(val as u8, 1);
78 }
79
80 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
82 assert!(bits > 0 && bits < 8);
83
84 while bits > 0 {
85 self.ensure_byte();
86
87 let bit_offset = self.bit_pos % 8;
89
90 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
92
93 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
97
98 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
105
106 let byte_pos = self.byte_pos();
107
108 self.buffer[byte_pos] &= !mask;
110
111 self.buffer[byte_pos] |= shifted_val & mask;
113
114 bits -= bits_in_current_byte;
116
117 val >>= bits_in_current_byte;
119
120 self.bit_pos += bits_in_current_byte as usize;
121
122 if self.bit_pos % 8 == 0 {
124 if let Some(crypto) = self.crypto.as_mut() {
125 let b = self.buffer[byte_pos];
126 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
127 }
128 }
129 }
130 }
131
132 pub fn write_byte(&mut self, byte: u8) {
134 self.align_byte();
135 self.ensure_byte();
136
137 let byte_pos = self.byte_pos();
138 let byte = if let Some(crypto) = self.crypto.as_mut() {
139 crypto.apply_keystream_byte(byte)
140 } else {
141 byte
142 };
143
144 self.buffer[byte_pos] = byte;
145 self.bit_pos += 8;
146 }
147
148 pub fn write_bytes(&mut self, data: &[u8]) {
150 self.align_byte();
151
152 if let Some(crypto) = self.crypto.as_mut() {
153 let encrypted = crypto.apply_keystream(data);
154 self.buffer.extend_from_slice(encrypted);
155 } else {
156 self.buffer.extend_from_slice(data);
157 }
158
159 self.bit_pos += 8 * data.len();
160 }
161
162 pub fn write_dyn_int(&mut self, mut val: u128) {
165 if val == 0 {
166 self.write_byte(0);
167 return;
168 }
169
170 while val > 0 {
171 let mut encoded = val % 128;
172 val /= 128;
173 if val > 0 {
174 encoded |= 128;
175 }
176 self.write_byte(encoded as u8);
177 }
178 }
179
180 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
182 self.write_bytes(&val.serialize());
183 }
184
185 fn ensure_byte(&mut self) {
187 let byte_pos = self.byte_pos();
188 if byte_pos >= self.buffer.len() {
189 self.buffer.resize(byte_pos + 1, 0);
190 }
191 }
192
193 pub fn align_byte(&mut self) {
195 let rem = self.bit_pos % 8;
196 if rem != 0 {
197 let byte_pos = self.byte_pos();
198 self.bit_pos += 8 - rem;
199
200 if let Some(crypto) = self.crypto.as_mut() {
202 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
203 }
204 }
205 }
206
207 pub fn reset(&mut self) {
209 self.bit_pos = 0;
210 }
211
212 pub fn len(&self) -> usize {
214 self.buffer.len()
215 }
216}
217
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy "cipher" for tests: adds 1 (wrapping) to every byte and records the output.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add: a byte of 255 must not panic in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }

        fn set_cached(&mut self, data: &[u8]) {
            self.ciphertext = data.to_vec();
        }
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.crypto = Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        }));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // Three bits (0b100 = 4) form a partial byte that is encrypted on alignment.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    /// Renders each byte as an 8-digit binary string for readable assertion failures.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        assert_eq!(buf.len(), 5, "buffer: {:?}", buffer_to_bin(&buf));
        assert_eq!(buf[0], 0b00001011);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011);
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[]);
    }

    #[test]
    fn test_write_0_dynint() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(0);
        assert_eq!(1, stream.len());
    }
}