1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9 marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 marker: None,
20 }
21 }
22
23 pub fn slice(&self) -> &[u8] {
25 &self.buffer
26 }
27
28 pub fn set_marker(&mut self, pos: Option<usize>) {
30 self.marker = Some(pos.unwrap_or(self.byte_pos()));
31 }
32
33 pub fn reset_marker(&mut self) {
35 self.marker = None;
36 }
37
38 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41 let start = self.marker.unwrap_or(0);
42 let end = to.unwrap_or(self.byte_pos());
43
44 if let Some(crypto) = self.crypto.as_ref() {
45 return &crypto.get_cached(true)[start..end];
46 }
47
48 &self.buffer[start..end]
49 }
50
51 pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
53 if let Some(existing) = self.crypto.as_ref()
54 && let Some(new_crypto) = crypto.as_mut()
55 {
56 new_crypto.replace(existing);
57 self.crypto = crypto;
58 } else {
59 self.crypto = crypto;
60 }
61 }
62
63 pub fn reset_crypto(&mut self) {
65 self.crypto = None;
66 }
67
68 pub fn byte_pos(&self) -> usize {
70 self.bit_pos / 8
71 }
72
73 pub fn write_bit(&mut self, val: bool) {
75 self.write_small(val as u8, 1);
76 }
77
78 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
80 assert!(bits > 0 && bits < 8);
81
82 while bits > 0 {
83 self.ensure_byte();
84
85 let bit_offset = self.bit_pos % 8;
87
88 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
90
91 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
95
96 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
103
104 let byte_pos = self.byte_pos();
105
106 self.buffer[byte_pos] &= !mask;
108
109 self.buffer[byte_pos] |= shifted_val & mask;
111
112 bits -= bits_in_current_byte;
114
115 val >>= bits_in_current_byte;
117
118 self.bit_pos += bits_in_current_byte as usize;
119
120 if self.bit_pos % 8 == 0 {
122 if let Some(crypto) = self.crypto.as_mut() {
123 let b = self.buffer[byte_pos];
124 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
125 }
126 }
127 }
128 }
129
130 pub fn write_byte(&mut self, byte: u8) {
132 self.align_byte();
133 self.ensure_byte();
134
135 let byte_pos = self.byte_pos();
136 let byte = if let Some(crypto) = self.crypto.as_mut() {
137 crypto.apply_keystream_byte(byte)
138 } else {
139 byte
140 };
141
142 self.buffer[byte_pos] = byte;
143 self.bit_pos += 8;
144 }
145
146 pub fn write_bytes(&mut self, data: &[u8]) {
148 self.align_byte();
149
150 if let Some(crypto) = self.crypto.as_mut() {
151 let encrypted = crypto.apply_keystream(data);
152 self.buffer.extend_from_slice(encrypted);
153 } else {
154 self.buffer.extend_from_slice(data);
155 }
156
157 self.bit_pos += 8 * data.len();
158 }
159
160 pub fn write_dyn_int(&mut self, mut val: u128) {
163 while val > 0 {
164 let mut encoded = val % 128;
165 val /= 128;
166 if val > 0 {
167 encoded |= 128;
168 }
169 self.write_byte(encoded as u8);
170 }
171 }
172
173 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
175 self.write_bytes(&val.serialize());
176 }
177
178 fn ensure_byte(&mut self) {
180 let byte_pos = self.byte_pos();
181 if byte_pos >= self.buffer.len() {
182 self.buffer.resize(byte_pos + 1, 0);
183 }
184 }
185
186 pub fn align_byte(&mut self) {
188 let rem = self.bit_pos % 8;
189 if rem != 0 {
190 let byte_pos = self.byte_pos();
191 self.bit_pos += 8 - rem;
192
193 if let Some(crypto) = self.crypto.as_mut() {
195 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
196 }
197 }
198 }
199
200 pub fn reset(&mut self) {
202 self.bit_pos = 0;
203 }
204
205 pub fn len(&self) -> usize {
207 self.buffer.len()
208 }
209}
210
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy keystream for tests: "encrypts" each byte by adding one (wrapping) and records
    /// every output byte in `ciphertext`.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext.extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        // This test double keeps no plaintext cache.
        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }
    }

    /// Renders each byte as an 8-digit binary string, for readable assertion failures.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // These three bits (0b100) are padded to a full byte by the following write_bytes.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        let bin = buffer_to_bin(&buf);
        assert_eq!(buf.len(), 5, "buffer: {bin:?}");
        assert_eq!(buf[0], 0b00001011, "buffer: {bin:?}");
        assert_eq!(buf[1], 0xAA, "buffer: {bin:?}");
        assert_eq!(buf[2], 0xBB, "buffer: {bin:?}");
        assert_eq!(buf[3], 0xCC, "buffer: {bin:?}");
        assert_eq!(buf[4], 0b00000011, "buffer: {bin:?}");
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[]);
    }
}