1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9 marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 marker: None,
20 }
21 }
22
23 pub fn slice(&self) -> &[u8] {
25 &self.buffer
26 }
27
28 pub fn set_marker(&mut self, pos: Option<usize>) {
30 self.marker = Some(pos.unwrap_or(self.byte_pos()));
31 }
32
33 pub fn reset_marker(&mut self) {
35 self.marker = None;
36 }
37
38 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
41 let start = self.marker.unwrap_or(0);
42 let end = to.unwrap_or(self.byte_pos());
43
44 if let Some(crypto) = self.crypto.as_ref() {
45 return &crypto.get_cached(true)[start..end];
46 }
47
48 &self.buffer[start..end]
49 }
50
51 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
53 self.crypto = crypto;
54 }
55
56 pub fn byte_pos(&self) -> usize {
58 self.bit_pos / 8
59 }
60
61 pub fn write_bit(&mut self, val: bool) {
63 self.write_small(val as u8, 1);
64 }
65
66 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
68 assert!(bits > 0 && bits < 8);
69
70 while bits > 0 {
71 self.ensure_byte();
72
73 let bit_offset = self.bit_pos % 8;
75
76 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
78
79 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
83
84 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
91
92 let byte_pos = self.byte_pos();
93
94 self.buffer[byte_pos] &= !mask;
96
97 self.buffer[byte_pos] |= shifted_val & mask;
99
100 bits -= bits_in_current_byte;
102
103 val >>= bits_in_current_byte;
105
106 self.bit_pos += bits_in_current_byte as usize;
107
108 if self.bit_pos % 8 == 0 {
110 if let Some(crypto) = self.crypto.as_mut() {
111 let b = self.buffer[byte_pos];
112 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
113 }
114 }
115 }
116 }
117
118 pub fn write_byte(&mut self, byte: u8) {
120 self.align_byte();
121 self.ensure_byte();
122
123 let byte_pos = self.byte_pos();
124 let byte = if let Some(crypto) = self.crypto.as_mut() {
125 crypto.apply_keystream_byte(byte)
126 } else {
127 byte
128 };
129
130 self.buffer[byte_pos] = byte;
131 self.bit_pos += 8;
132 }
133
134 pub fn write_bytes(&mut self, data: &[u8]) {
136 self.align_byte();
137
138 if let Some(crypto) = self.crypto.as_mut() {
139 let encrypted = crypto.apply_keystream(data);
140 self.buffer.extend_from_slice(encrypted);
141 } else {
142 self.buffer.extend_from_slice(data);
143 }
144
145 self.bit_pos += 8 * data.len();
146 }
147
148 pub fn write_dyn_int(&mut self, mut val: u128) {
151 while val > 0 {
152 let mut encoded = val % 128;
153 val /= 128;
154 if val > 0 {
155 encoded |= 128;
156 }
157 self.write_byte(encoded as u8);
158 }
159 }
160
161 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
163 self.write_bytes(&val.serialize());
164 }
165
166 fn ensure_byte(&mut self) {
168 let byte_pos = self.byte_pos();
169 if byte_pos >= self.buffer.len() {
170 self.buffer.resize(byte_pos + 1, 0);
171 }
172 }
173
174 pub fn align_byte(&mut self) {
176 let rem = self.bit_pos % 8;
177 if rem != 0 {
178 let byte_pos = self.byte_pos();
179 self.bit_pos += 8 - rem;
180
181 if let Some(crypto) = self.crypto.as_mut() {
183 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
184 }
185 }
186 }
187
188 pub fn reset(&mut self) {
190 self.bit_pos = 0;
191 }
192
193 pub fn len(&self) -> usize {
195 self.buffer.len()
196 }
197}
198
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy cipher for tests: adds one (wrapping) to every byte and records
    /// the ciphertext it has produced so far.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add keeps the mock panic-free for 0xFF in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }
    }

    /// Whole bytes, a partial byte flushed by alignment, and a bulk write
    /// must all pass through the installed cipher.
    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // Three bits (LSB first) give 0b100 = 4, encrypted to 5 on alignment.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    /// Renders each byte as an 8-digit binary string, for debugging output.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);

        assert_eq!(buf.len(), 1);
        // Bits are packed LSB first: 1, 0, 1, 1 -> 0b1101.
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b00101011, 7);
        // Only one bit is left in the first byte; the other three spill over.
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        // write_byte aligns first, so 0xAA lands in the second byte.
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        // Only shown when the test fails (test output is captured otherwise).
        let bin = buffer_to_bin(&buf);
        println!("{:?}", bin);

        assert_eq!(buf.len(), 5);
        assert_eq!(buf[0], 0b00001011);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
        assert_eq!(buf[4], 0b00000011);
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // Fits in 7 bits: one byte.
        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        // Needs 8 bits: two bytes.
        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        // 2^28 - 1 fills exactly four 7-bit groups.
        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        // No marker set: the slice starts at byte 0.
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        // Marker at the current position: nothing has been written after it.
        stream.set_marker(None);
        assert!(stream.slice_marker(None).is_empty());
    }
}