1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9 marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 marker: None,
20 }
21 }
22
23 pub fn slice(&self) -> &[u8] {
25 &self.buffer
26 }
27
28 pub fn set_marker(&mut self, pos: Option<usize>) {
30 self.marker = Some(pos.unwrap_or(self.byte_pos()));
31 }
32
33 pub fn reset_marker(&mut self) {
35 self.marker = None;
36 }
37
38 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41 let start = self.marker.unwrap_or(0);
42 let end = to.unwrap_or(self.byte_pos());
43
44 if let Some(crypto) = self.crypto.as_ref() {
45 return &crypto.get_cached(true)[start..end];
46 }
47
48 &self.buffer[start..end]
49 }
50
51 pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
53 if let Some(new) = crypto.as_mut() {
54 if let Some(existing) = self.crypto.as_ref() {
55 new.replace(existing);
56 } else {
57 new.set_cached(self.slice());
59 }
60 }
61
62 self.crypto = crypto;
63 }
64
65 pub fn reset_crypto(&mut self) {
67 self.crypto = None;
68 }
69
70 pub fn byte_pos(&self) -> usize {
72 self.bit_pos / 8
73 }
74
75 pub fn write_bit(&mut self, val: bool) {
77 self.write_small(val as u8, 1);
78 }
79
80 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
82 assert!(bits > 0 && bits < 8);
83
84 while bits > 0 {
85 self.ensure_byte();
86
87 let bit_offset = self.bit_pos % 8;
89
90 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
92
93 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
97
98 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
105
106 let byte_pos = self.byte_pos();
107
108 self.buffer[byte_pos] &= !mask;
110
111 self.buffer[byte_pos] |= shifted_val & mask;
113
114 bits -= bits_in_current_byte;
116
117 val >>= bits_in_current_byte;
119
120 self.bit_pos += bits_in_current_byte as usize;
121
122 if self.bit_pos % 8 == 0 {
124 if let Some(crypto) = self.crypto.as_mut() {
125 let b = self.buffer[byte_pos];
126 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
127 }
128 }
129 }
130 }
131
132 pub fn write_byte(&mut self, byte: u8) {
134 self.align_byte();
135 self.ensure_byte();
136
137 let byte_pos = self.byte_pos();
138 let byte = if let Some(crypto) = self.crypto.as_mut() {
139 crypto.apply_keystream_byte(byte)
140 } else {
141 byte
142 };
143
144 self.buffer[byte_pos] = byte;
145 self.bit_pos += 8;
146 }
147
148 pub fn write_bytes(&mut self, data: &[u8]) {
150 self.align_byte();
151
152 if let Some(crypto) = self.crypto.as_mut() {
153 let encrypted = crypto.apply_keystream(data);
154 self.buffer.extend_from_slice(encrypted);
155 } else {
156 self.buffer.extend_from_slice(data);
157 }
158
159 self.bit_pos += 8 * data.len();
160 }
161
162 pub fn write_dyn_int(&mut self, mut val: u128) {
165 while val > 0 {
166 let mut encoded = val % 128;
167 val /= 128;
168 if val > 0 {
169 encoded |= 128;
170 }
171 self.write_byte(encoded as u8);
172 }
173 }
174
175 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
177 self.write_bytes(&val.serialize());
178 }
179
180 fn ensure_byte(&mut self) {
182 let byte_pos = self.byte_pos();
183 if byte_pos >= self.buffer.len() {
184 self.buffer.resize(byte_pos + 1, 0);
185 }
186 }
187
188 pub fn align_byte(&mut self) {
190 let rem = self.bit_pos % 8;
191 if rem != 0 {
192 let byte_pos = self.byte_pos();
193 self.bit_pos += 8 - rem;
194
195 if let Some(crypto) = self.crypto.as_mut() {
197 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
198 }
199 }
200 }
201
202 pub fn reset(&mut self) {
204 self.bit_pos = 0;
205 }
206
207 pub fn len(&self) -> usize {
209 self.buffer.len()
210 }
211}
212
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Test cipher that "encrypts" by adding one (wrapping) to every byte and
    /// records everything it produced in `ciphertext`.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let out = b.wrapping_add(1);
            self.ciphertext.push(out);
            out
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }

        fn set_cached(&mut self, data: &[u8]) {
            self.ciphertext = data.to_vec();
        }
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.crypto = Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        }));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    /// Renders each byte as an 8-digit binary string (debug aid for failures).
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        println!("{:?}", bin);

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], 0b00001011);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011);
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[]);
    }
}