1use std::cmp::min;
2
3use crate::{CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamWriter<'a> {
6 buffer: &'a mut Vec<u8>,
7 bit_pos: usize,
8 crypto: Option<Box<dyn CryptoStream>>,
9}
10
11impl<'a> BitStreamWriter<'a> {
12 pub fn new(buffer: &'a mut Vec<u8>) -> Self {
14 Self {
15 buffer,
16 bit_pos: 0,
17 crypto: None,
18 }
19 }
20
21 pub fn slice(&self) -> &[u8] {
23 &self.buffer
24 }
25
26 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
28 self.crypto = crypto;
29 }
30
31 pub fn byte_pos(&self) -> usize {
33 self.bit_pos / 8
34 }
35
36 pub fn write_bit(&mut self, val: bool) {
38 self.write_small(val as u8, 1);
39 }
40
41 pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
43 assert!(bits > 0 && bits < 8);
44
45 while bits > 0 {
46 self.ensure_byte();
47
48 let bit_offset = self.bit_pos % 8;
50
51 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
53
54 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
58
59 let shifted_val = (val & ((1 << bits_in_current_byte) - 1)) << bit_offset;
66
67 let byte_pos = self.byte_pos();
68
69 self.buffer[byte_pos] &= !mask;
71
72 self.buffer[byte_pos] |= shifted_val & mask;
74
75 bits -= bits_in_current_byte;
77
78 val >>= bits_in_current_byte;
80
81 self.bit_pos += bits_in_current_byte as usize;
82
83 if self.bit_pos % 8 == 0 {
85 if let Some(crypto) = self.crypto.as_mut() {
86 let b = self.buffer[byte_pos];
87 self.buffer[byte_pos] = crypto.apply_keystream_byte(b);
88 }
89 }
90 }
91 }
92
93 pub fn write_byte(&mut self, byte: u8) {
95 self.align_byte();
96 self.ensure_byte();
97
98 let byte_pos = self.byte_pos();
99 let byte = if let Some(crypto) = self.crypto.as_mut() {
100 crypto.apply_keystream_byte(byte)
101 } else {
102 byte
103 };
104
105 self.buffer[byte_pos] = byte;
106 self.bit_pos += 8;
107 }
108
109 pub fn write_bytes(&mut self, data: &[u8]) {
111 self.align_byte();
112
113 if let Some(crypto) = self.crypto.as_mut() {
114 let encrypted = crypto.apply_keystream(data);
115 self.buffer.extend_from_slice(encrypted);
116 } else {
117 self.buffer.extend_from_slice(data);
118 }
119
120 self.bit_pos += 8 * data.len();
121 }
122
123 pub fn write_dyn_int(&mut self, mut val: u128) {
126 while val > 0 {
127 let mut encoded = val % 128;
128 val /= 128;
129 if val > 0 {
130 encoded |= 128;
131 }
132 self.write_byte(encoded as u8);
133 }
134 }
135
136 pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
138 self.write_bytes(&val.serialize());
139 }
140
141 fn ensure_byte(&mut self) {
143 let byte_pos = self.byte_pos();
144 if byte_pos >= self.buffer.len() {
145 self.buffer.resize(byte_pos + 1, 0);
146 }
147 }
148
149 pub fn align_byte(&mut self) {
151 let rem = self.bit_pos % 8;
152 if rem != 0 {
153 let byte_pos = self.byte_pos();
154 self.bit_pos += 8 - rem;
155
156 if let Some(crypto) = self.crypto.as_mut() {
158 self.buffer[byte_pos] = crypto.apply_keystream_byte(self.buffer[byte_pos]);
159 }
160 }
161 }
162
163 pub fn reset(&mut self) {
165 self.bit_pos = 0;
166 }
167
168 pub fn len(&self) -> usize {
170 self.buffer.len()
171 }
172}
173
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Toy cipher for tests: adds one to every byte and records everything it emits.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // `wrapping_add` keeps a 0xFF input from panicking in debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext
                .extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }
    }

    /// Renders each byte as an 8-digit binary string, for assertion messages.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.set_crypto(Some(Box::new(PlusOneEncrypter {
            ciphertext: Vec::new(),
        })));

        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        // A partial byte (0b100) that gets padded and encrypted by the next aligned write.
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);

        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);

        // Bits fill the byte from the least-significant end.
        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b0000_1101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);

        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b1111_1101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        // 7 bits, then 4 bits that straddle the byte boundary (1 bit + 3 bits).
        stream.write_small(0b0010_1011, 7);
        stream.write_small(0b1101, 4);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b1010_1011);
        assert_eq!(buf[1], 0b0000_0110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        // Whole-byte writes start on the next byte boundary.
        stream.write_byte(0xAA);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b0000_0001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);

        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b0000_0001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);

        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b0000_0011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);

        assert_eq!(
            buf,
            [0b0000_1011, 0xAA, 0xBB, 0xCC, 0b0000_0011],
            "buffer bits: {:?}",
            buffer_to_bin(&buf)
        );
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());

        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());

        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());

        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);

        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);

        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }
}