use std::cmp::min;
use crate::{CryptoStream, encoding::fixed_int::FixedInt};
/// Writes bits, bytes and integers into a mutably borrowed byte vector, optionally
/// routing the written bytes through a [`CryptoStream`].
pub struct BitStreamWriter<'a> {
    /// Write cursor, counted in bits from the start of `buffer`.
    bit_pos: usize,
    /// Destination bytes, borrowed for the whole lifetime of the writer.
    buffer: &'a mut Vec<u8>,
    /// Byte offset that `slice_marker` starts from; `None` means offset 0.
    marker: Option<usize>,
    /// Optional stream cipher applied to bytes as they are written.
    crypto: Option<Box<dyn CryptoStream>>,
}
impl<'a> BitStreamWriter<'a> {
    /// Creates a writer whose cursor is at bit 0 of `buffer`, with no cipher and no marker.
    ///
    /// Writes start at the beginning of `buffer` and overwrite whatever is already there.
    /// The vector is only grown (zero-filled) when a write runs past its end.
    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
        Self {
            buffer,
            bit_pos: 0,
            crypto: None,
            marker: None,
        }
    }

    /// Returns the whole buffer as stored.
    ///
    /// Bytes completed while a cipher was set are ciphertext. A trailing, partially written
    /// byte is still plaintext, because a byte is only encrypted once it is completed.
    pub fn slice(&self) -> &[u8] {
        self.buffer.as_slice()
    }

    /// Sets the byte offset that [`Self::slice_marker`] starts from.
    ///
    /// `None` uses the current byte position.
    pub fn set_marker(&mut self, pos: Option<usize>) {
        self.marker = Some(pos.unwrap_or(self.byte_pos()));
    }

    /// Clears the marker, so [`Self::slice_marker`] starts from byte 0 again.
    pub fn reset_marker(&mut self) {
        self.marker = None;
    }

    /// Returns the bytes from the marker (byte 0 if unset) up to `to`.
    ///
    /// If `to` is `None`, the slice ends at the current byte position.
    ///
    /// When a cipher is set, the bytes are taken from its `get_cached(true)` view rather
    /// than from the buffer. NOTE(review): this is presumably the cached plaintext; confirm
    /// against the `CryptoStream` implementation.
    ///
    /// # Panics
    /// Panics if the range is reversed or extends past the data being sliced.
    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
        let start = self.marker.unwrap_or(0);
        let end = to.unwrap_or(self.byte_pos());
        match self.crypto.as_ref() {
            Some(crypto) => &crypto.get_cached(true)[start..end],
            None => &self.buffer[start..end],
        }
    }

    /// Installs `crypto` as the cipher for subsequent writes; `None` removes the cipher.
    ///
    /// Before being installed, a new cipher is passed either:
    /// - the cipher it supersedes, via `replace`, or
    /// - the current buffer contents, via `set_cached`, when no cipher was set.
    pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
        if let Some(new) = crypto.as_mut() {
            match self.crypto.as_ref() {
                Some(existing) => new.replace(existing),
                None => new.set_cached(self.slice()),
            }
        }
        self.crypto = crypto;
    }

    /// Removes the cipher; later writes are stored unchanged.
    pub fn reset_crypto(&mut self) {
        self.crypto = None;
    }

    /// Index of the byte that holds the next bit to be written.
    pub fn byte_pos(&self) -> usize {
        self.bit_pos / 8
    }

    /// Writes a single bit at the cursor (see [`Self::write_small`] for bit order).
    pub fn write_bit(&mut self, val: bool) {
        self.write_small(val as u8, 1);
    }

    /// Writes the low `bits` bits of `val` at the cursor, least-significant bit first.
    ///
    /// Bits fill each byte from its low end upwards and may span a byte boundary. Only the
    /// bits being written are replaced; other bits of an existing byte are kept. Each byte
    /// that becomes complete is encrypted if a cipher is set.
    ///
    /// # Panics
    /// Panics unless `bits` is in `1..=8`.
    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
        assert!(
            (1..=8).contains(&bits),
            "write_small: bit count must be in 1..=8, got {}",
            bits
        );
        while bits > 0 {
            self.ensure_byte();
            let pos = self.byte_pos();
            // `bit_pos % 8` is always < 8, so the cast cannot truncate.
            let bit_offset = (self.bit_pos % 8) as u8;
            // Number of bits that still fit in the current byte; always in 1..=8.
            let chunk = min(8 - bit_offset, bits);
            // Build a mask of `chunk` low bits. The naive `(1 << chunk) - 1` would overflow
            // a u8 when chunk == 8.
            let low_mask = u8::MAX >> (8 - chunk);
            let mask = low_mask << bit_offset;
            self.buffer[pos] = (self.buffer[pos] & !mask) | ((val & low_mask) << bit_offset);
            bits -= chunk;
            // A shift by 8 is out of range for u8; checked_shr maps it to "no bits left".
            val = val.checked_shr(u32::from(chunk)).unwrap_or(0);
            self.bit_pos += usize::from(chunk);
            if self.bit_pos % 8 == 0 {
                self.encrypt_completed_byte(pos);
            }
        }
    }

    /// Aligns to the next byte boundary, then writes `byte`, encrypting it if a cipher is set.
    pub fn write_byte(&mut self, byte: u8) {
        self.align_byte();
        self.ensure_byte();
        let pos = self.byte_pos();
        self.buffer[pos] = byte;
        self.encrypt_completed_byte(pos);
        self.bit_pos += 8;
    }

    /// Aligns to the next byte boundary, then writes `data` at the cursor.
    ///
    /// Existing bytes are overwritten and the buffer is grown only if `data` runs past its
    /// end, which matches [`Self::write_byte`]. If a cipher is set, `data` is passed through
    /// `apply_keystream`.
    ///
    /// # Panics
    /// Panics if the cipher returns a different number of bytes than `data.len()`.
    pub fn write_bytes(&mut self, data: &[u8]) {
        self.align_byte();
        let start = self.byte_pos();
        let end = start + data.len();
        if self.buffer.len() < end {
            self.buffer.resize(end, 0);
        }
        let dest = &mut self.buffer[start..end];
        match self.crypto.as_mut() {
            Some(crypto) => dest.copy_from_slice(crypto.apply_keystream(data)),
            None => dest.copy_from_slice(data),
        }
        self.bit_pos += 8 * data.len();
    }

    /// Writes `val` as unsigned LEB128, byte-aligned.
    ///
    /// Each byte carries 7 bits, least-significant group first, with the high bit set on
    /// every byte except the last. `0` is written as a single `0x00` byte.
    pub fn write_dyn_int(&mut self, mut val: u128) {
        loop {
            let group = (val & 0x7f) as u8;
            val >>= 7;
            if val == 0 {
                self.write_byte(group);
                return;
            }
            self.write_byte(group | 0x80);
        }
    }

    /// Writes the `S` bytes produced by `val.serialize()`, byte-aligned (see [`Self::write_bytes`]).
    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
        self.write_bytes(&val.serialize());
    }

    /// Grows the buffer with zero bytes so that `byte_pos()` is a valid index.
    fn ensure_byte(&mut self) {
        let byte_pos = self.byte_pos();
        if byte_pos >= self.buffer.len() {
            self.buffer.resize(byte_pos + 1, 0);
        }
    }

    /// Encrypts the byte at `pos` in place if a cipher is set; otherwise does nothing.
    fn encrypt_completed_byte(&mut self, pos: usize) {
        if let Some(crypto) = self.crypto.as_mut() {
            let plain = self.buffer[pos];
            self.buffer[pos] = crypto.apply_keystream_byte(plain);
        }
    }

    /// Moves the cursor to the next byte boundary if it is part-way through a byte.
    ///
    /// The unwritten high bits of that byte keep their current value (zero for a freshly
    /// grown byte). The byte is then treated as complete and encrypted if a cipher is set.
    pub fn align_byte(&mut self) {
        let rem = self.bit_pos % 8;
        if rem != 0 {
            let pos = self.byte_pos();
            self.bit_pos += 8 - rem;
            self.encrypt_completed_byte(pos);
        }
    }

    /// Rewinds the cursor to bit 0.
    ///
    /// The buffer is not truncated, and the marker and cipher are left as they are.
    /// Subsequent writes overwrite the existing bytes in place.
    pub fn reset(&mut self) {
        self.bit_pos = 0;
    }

    /// Length of the underlying buffer in bytes, including a partially written last byte.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if the underlying buffer holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use crate::CryptoStream;

    use super::BitStreamWriter;

    /// Test double for [`CryptoStream`]: "encrypts" by adding 1 (mod 256) to each byte and
    /// records every ciphertext byte it has produced.
    struct PlusOneEncrypter {
        ciphertext: Vec<u8>,
    }

    impl CryptoStream for PlusOneEncrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // `wrapping_add` so 0xFF maps to 0x00 instead of overflowing, which panics in
            // debug builds.
            let encrypted = b.wrapping_add(1);
            self.ciphertext.push(encrypted);
            encrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.ciphertext.len();
            self.ciphertext.extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.ciphertext[start..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &[]
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.ciphertext = other.get_cached(true).to_vec();
        }

        fn set_cached(&mut self, data: &[u8]) {
            self.ciphertext = data.to_vec();
        }
    }

    /// Renders each byte as an 8-digit binary string, for readable assertion messages.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        buffer.iter().map(|b| format!("{:08b}", b)).collect()
    }

    #[test]
    fn test_encrypt_bytes() {
        let mut buf = Vec::new();
        let mut writer = BitStreamWriter::new(&mut buf);
        writer.crypto = Some(Box::new(PlusOneEncrypter { ciphertext: Vec::new() }));
        writer.write_byte(1);
        writer.write_byte(2);
        writer.write_byte(3);
        writer.write_bit(false);
        writer.write_bit(false);
        writer.write_bit(true);
        writer.write_bytes(&[5, 6, 7, 8, 9]);
        writer.write_byte(10);
        assert_eq!(buf, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_write_bit() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_bit(true);
        stream.write_bit(false);
        stream.write_bit(true);
        stream.write_bit(true);
        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b00001101);
    }

    #[test]
    fn test_write_small() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_small(0b101, 3);
        stream.write_small(0b11, 2);
        stream.write_small(0b111, 3);
        assert_eq!(buf.len(), 1);
        assert_eq!(buf[0], 0b11111101);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_small(0b00101011, 7);
        stream.write_small(0b1101, 4);
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b10101011);
        assert_eq!(buf[1], 0b00000110);
    }

    #[test]
    fn test_write_byte() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_bit(true);
        stream.write_byte(0xAA);
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_bit(true);
        stream.write_bytes(&[0xAA, 0xBB, 0xCC]);
        assert_eq!(buf.len(), 4);
        assert_eq!(buf[0], 0b00000001);
        assert_eq!(buf[1], 0xAA);
        assert_eq!(buf[2], 0xBB);
        assert_eq!(buf[3], 0xCC);
    }

    #[test]
    fn test_alignment() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_small(0b11, 2);
        stream.align_byte();
        stream.write_byte(0xFF);
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0b00000011);
        assert_eq!(buf[1], 0xFF);
    }

    #[test]
    fn test_multiple_operations() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_bit(true);
        stream.write_small(0b101, 3);
        stream.write_byte(0xAA);
        stream.write_bytes(&[0xBB, 0xCC]);
        stream.write_small(0b11, 2);
        // Binary dump is only shown when an assertion fails, instead of always printing.
        let bits = buffer_to_bin(&buf);
        assert_eq!(buf.len(), 5, "buffer: {:?}", bits);
        assert_eq!(buf[0], 0b00001011, "buffer: {:?}", bits);
        assert_eq!(buf[1], 0xAA, "buffer: {:?}", bits);
        assert_eq!(buf[2], 0xBB, "buffer: {:?}", bits);
        assert_eq!(buf[3], 0xCC, "buffer: {:?}", bits);
        assert_eq!(buf[4], 0b00000011, "buffer: {:?}", bits);
    }

    #[test]
    fn test_write_dyn_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_dyn_int(127);
        assert_eq!(1, stream.len());
        stream.write_dyn_int(128);
        assert_eq!(3, stream.len());
        stream.write_dyn_int(268435455);
        assert_eq!(7, stream.len());
        assert_eq!(vec![127, 128, 1, 255, 255, 255, 127], buf);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_fixed_int(1u8);
        stream.write_fixed_int(1i8);
        stream.write_fixed_int(2u16);
        stream.write_fixed_int(2i16);
        stream.write_fixed_int(3u32);
        stream.write_fixed_int(3i32);
        stream.write_fixed_int(4u64);
        stream.write_fixed_int(4i64);
        stream.write_fixed_int(5u128);
        stream.write_fixed_int(5i128);
        assert_eq!(
            vec![
                1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 10
            ],
            buf
        );
    }

    #[test]
    fn test_slice_marker() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_bytes(&[10, 20, 30, 40, 50]);
        assert_eq!(stream.slice_marker(Some(4)), &[10, 20, 30, 40]);
        stream.set_marker(Some(2));
        assert_eq!(stream.slice_marker(None), &[30, 40, 50]);
        stream.set_marker(None);
        assert_eq!(stream.slice_marker(None), &[]);
    }

    #[test]
    fn test_write_0_dynint() {
        let mut buf = Vec::new();
        let mut stream = BitStreamWriter::new(&mut buf);
        stream.write_dyn_int(0);
        assert_eq!(1, stream.len());
    }
}