1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9 marker: Option<usize>,
10}
11
12impl<'a> BitStreamWriter<'a> {
13 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 marker: None,
20 }
21 }
22
23 pub fn slice(&self) -> &[u8] {
25 &self.buffer
26 }
27
28 pub fn set_marker(&mut self, pos: Option<usize>) {
30 self.marker = Some(pos.unwrap_or(self.byte_pos()));
31 }
32
33 pub fn reset_marker(&mut self) {
35 self.marker = None;
36 }
37
38 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
40 &self.buffer[self.marker.unwrap_or(0)..to.unwrap_or(self.byte_pos())]
41 }
42
43 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
45 self.crypto = crypto;
46 }
47
48 pub fn byte_pos(&self) -> usize {
50 self.bit_pos / 8
51 }
52
53 pub fn write_bit(&mut self, val: bool) {
55 self.write_small(val as u8, 1);
56 }
57
58 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
60 assert!(bits > 0 && bits < 8);
61
62 while bits > 0 {
63 self.ensure_byte();
64
65 let bit_offset = self.bit_pos % 8;
67
68 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
70
71 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
75
76 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
83
84 let byte_pos = self.byte_pos();
85
86 self.buffer[byte_pos] &= !mask;
88
89 self.buffer[byte_pos] |= shifted_val & mask;
91
92 bits -= bits_in_current_byte;
94
95 val >>= bits_in_current_byte;
97
98 self.bit_pos += bits_in_current_byte as usize;
99
100 if self.bit_pos % 8 == 0 {
102 if let Some(crypto) = self.crypto.as_mut() {
103 let b = self.buffer[byte_pos];
104 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
105 }
106 }
107 }
108 }
109
110 pub fn write_byte(&mut self, byte: u8) {
112 self.align_byte();
113 self.ensure_byte();
114
115 let byte_pos = self.byte_pos();
116 let byte = if let Some(crypto) = self.crypto.as_mut() {
117 crypto.apply_keystream_byte(byte)
118 } else {
119 byte
120 };
121
122 self.buffer[byte_pos] = byte;
123 self.bit_pos += 8;
124 }
125
126 pub fn write_bytes(&mut self, data: &[u8]) {
128 self.align_byte();
129
130 if let Some(crypto) = self.crypto.as_mut() {
131 let encrypted = crypto.apply_keystream(data);
132 self.buffer.extend_from_slice(encrypted);
133 } else {
134 self.buffer.extend_from_slice(data);
135 }
136
137 self.bit_pos += 8 * data.len();
138 }
139
140 pub fn write_dyn_int(&mut self, mut val: u128) {
143 while val > 0 {
144 let mut encoded = val % 128;
145 val /= 128;
146 if val > 0 {
147 encoded |= 128;
148 }
149 self.write_byte(encoded as u8);
150 }
151 }
152
153 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
155 self.write_bytes(&val.serialize());
156 }
157
158 fn ensure_byte(&mut self) {
160 let byte_pos = self.byte_pos();
161 if byte_pos >= self.buffer.len() {
162 self.buffer.resize(byte_pos + 1, 0);
163 }
164 }
165
166 pub fn align_byte(&mut self) {
168 let rem = self.bit_pos % 8;
169 if rem != 0 {
170 let byte_pos = self.byte_pos();
171 self.bit_pos += 8 - rem;
172
173 if let Some(crypto) = self.crypto.as_mut() {
175 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
176 }
177 }
178 }
179
180 pub fn reset(&mut self) {
182 self.bit_pos = 0;
183 }
184
185 pub fn len(&self) -> usize {
187 self.buffer.len()
188 }
189}
190
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy keystream for tests: adds 1 (mod 256) to every byte and keeps a copy
    /// of everything it produced in `ciphertext`.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // `wrapping_add` so 0xFF maps to 0x00 instead of panicking on
            // overflow in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }
    }

    /// Formats each byte as an 8-digit binary string, for readable assertion messages.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // Three bits (0b100) form a partial byte; write_bytes aligns first,
        // which completes and encrypts that byte.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);

        // Bits fill the byte from the least-significant end.
        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3); // bits 0..=2
        stream.write_small(0b11, 2); // bits 3..=4
        stream.write_small(0b111, 3); // bits 5..=7

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 7 bits, then 4 bits: the low bit of the second value completes byte 0,
        // and the remaining 3 bits spill into byte 1.
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // write_byte aligns first, so the single bit keeps a byte to itself.
        stream.write_bit(true);
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // write_bytes also aligns before writing.
        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        assert_eq!(
            buf,
            vec![0b00001011, 0xAA, 0xBB, 0xCC, 0b00000011],
            "buffer bits: {:?}",
            buffer_to_bin(&buf)
        );
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);

        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);

        // A marker at the current position yields an empty slice.
        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[]);
    }
}