use std::cmp::min;
use crate::encoding::fixed_int::FixedInt;
/// Writes values into a borrowed byte buffer at bit granularity.
///
/// Bits are packed starting at the least-significant bit of each byte, and
/// multi-bit values are emitted least-significant bit first. The byte-level
/// writes (`write_byte`, `write_bytes`, `write_dyn_int`, `write_fixed_int`)
/// first advance to the next byte boundary.
///
/// Every write goes to the current position, overwriting whatever the buffer
/// already holds there and growing the buffer with zero bytes as needed. A
/// writer created over a non-empty buffer, or rewound with
/// [`BitStreamWriter::reset`], therefore overwrites from the first byte.
#[derive(Debug)]
pub struct BitStreamWriter<'a> {
    /// Destination buffer; grown on demand.
    buffer: &'a mut Vec<u8>,
    /// Absolute bit offset of the next bit to be written.
    bit_pos: usize,
}

impl<'a> BitStreamWriter<'a> {
    /// Creates a writer positioned at bit 0 of `buffer`.
    ///
    /// Existing contents of `buffer` are not cleared; subsequent writes
    /// overwrite them starting from the first byte.
    pub fn new(buffer: &'a mut Vec<u8>) -> Self {
        Self { buffer, bit_pos: 0 }
    }

    /// Index of the byte that the next bit will be written into.
    pub fn byte_pos(&self) -> usize {
        self.bit_pos / 8
    }

    /// Writes a single bit at the current position.
    pub fn write_bit(&mut self, val: bool) {
        self.write_small(u8::from(val), 1);
    }

    /// Writes the low `bits` bits of `val`, least-significant bit first.
    ///
    /// The value may straddle a byte boundary. Bits of `val` above `bits` are
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is 0 or greater than 8.
    pub fn write_small(&mut self, mut val: u8, mut bits: u8) {
        assert!(
            (1..=8).contains(&bits),
            "write_small: bit count must be in 1..=8, got {}",
            bits
        );
        while bits > 0 {
            self.ensure_byte();
            let bit_offset = self.bit_pos % 8;
            // How many of the remaining bits still fit into the current byte.
            let chunk = min(8 - bit_offset, usize::from(bits));
            // Built in u16 so a full 8-bit chunk does not overflow the shift;
            // the result is at most 0xFF, so the narrowing cast is lossless.
            let low_mask = ((1u16 << chunk) - 1) as u8;
            let mask = low_mask << bit_offset;
            let byte_pos = self.byte_pos();
            let byte = &mut self.buffer[byte_pos];
            // Clear the target bits first so existing contents are overwritten.
            *byte = (*byte & !mask) | ((val & low_mask) << bit_offset);
            bits -= chunk as u8;
            // `checked_shr` avoids a shift-overflow panic when all 8 bits were consumed.
            val = val.checked_shr(chunk as u32).unwrap_or(0);
            self.bit_pos += chunk;
        }
    }

    /// Aligns to the next byte boundary, then writes `byte` as a whole byte.
    pub fn write_byte(&mut self, byte: u8) {
        self.align_byte();
        self.ensure_byte();
        let byte_pos = self.byte_pos();
        self.buffer[byte_pos] = byte;
        self.bit_pos += 8;
    }

    /// Aligns to the next byte boundary, then writes `data` verbatim.
    ///
    /// Like the other write methods, this writes at the current position,
    /// overwriting bytes already there. The buffer grows only if `data`
    /// extends past its current end.
    pub fn write_bytes(&mut self, data: &[u8]) {
        self.align_byte();
        let start = self.byte_pos();
        let end = start + data.len();
        if self.buffer.len() < end {
            self.buffer.resize(end, 0);
        }
        self.buffer[start..end].copy_from_slice(data);
        self.bit_pos = end * 8;
    }

    /// Writes `val` as an unsigned LEB128-style variable-length integer.
    ///
    /// Each byte carries 7 value bits, lowest group first. The high bit (0x80)
    /// is set on every byte except the last. Zero is encoded as a single
    /// `0x00` byte. Every byte is written with `write_byte`, so the output
    /// starts on a byte boundary.
    pub fn write_dyn_int(&mut self, mut val: u128) {
        loop {
            let group = (val & 0x7F) as u8;
            val >>= 7;
            if val == 0 {
                self.write_byte(group);
                return;
            }
            self.write_byte(group | 0x80);
        }
    }

    /// Aligns to the next byte boundary, then writes the bytes produced by
    /// `FixedInt::serialize` for `val`.
    pub fn write_fixed_int<const S: usize, T: FixedInt<S>>(&mut self, val: T) {
        self.write_bytes(&val.serialize());
    }

    /// Grows the buffer with zero bytes so that `buffer[byte_pos()]` is valid.
    fn ensure_byte(&mut self) {
        let byte_pos = self.byte_pos();
        if byte_pos >= self.buffer.len() {
            self.buffer.resize(byte_pos + 1, 0);
        }
    }

    /// Advances the position to the next byte boundary; a no-op if already aligned.
    ///
    /// Skipped bits are not written. They keep whatever value the buffer
    /// already holds.
    pub fn align_byte(&mut self) {
        let rem = self.bit_pos % 8;
        if rem != 0 {
            self.bit_pos += 8 - rem;
        }
    }

    /// Rewinds the position to bit 0 without clearing the buffer, so later
    /// writes overwrite the existing contents.
    pub fn reset(&mut self) {
        self.bit_pos = 0;
    }

    /// Length of the underlying buffer in bytes.
    ///
    /// This includes a partially written final byte. It also includes bytes
    /// beyond the current position, e.g. after [`BitStreamWriter::reset`].
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if the underlying buffer holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::BitStreamWriter;

    /// Formats every byte as a zero-padded 8-digit binary string, for readable debug output.
    fn buffer_to_bin(buffer: &[u8]) -> Vec<String> {
        let mut rendered = Vec::with_capacity(buffer.len());
        for byte in buffer {
            rendered.push(format!("{:08b}", byte));
        }
        rendered
    }

    #[test]
    fn test_write_bit() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        for &bit in &[true, false, true, true] {
            writer.write_bit(bit);
        }
        // The first bit written lands in the least-significant position.
        assert_eq!(out, [0b0000_1101u8]);
    }

    #[test]
    fn test_write_small() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        let fields: [(u8, u8); 3] = [(0b101, 3), (0b11, 2), (0b111, 3)];
        for &(value, width) in &fields {
            writer.write_small(value, width);
        }
        // 3 + 2 + 3 bits exactly fill one byte.
        assert_eq!(out, [0b1111_1101u8]);
    }

    #[test]
    fn test_write_cross_byte() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        writer.write_small(0b0010_1011, 7);
        // Only one bit of this value fits in the first byte; the rest spill into the second.
        writer.write_small(0b1101, 4);
        assert_eq!(out, [0b1010_1011u8, 0b0000_0110]);
    }

    #[test]
    fn test_write_byte() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        writer.write_bit(true);
        // write_byte skips the remaining bits of the partially filled byte.
        writer.write_byte(0xAA);
        assert_eq!(out, [0b0000_0001u8, 0xAA]);
    }

    #[test]
    fn test_write_bytes() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        writer.write_bit(true);
        writer.write_bytes(&[0xAA, 0xBB, 0xCC]);
        assert_eq!(out, [0b0000_0001u8, 0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_alignment() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        writer.write_small(0b11, 2);
        writer.align_byte();
        writer.write_byte(0xFF);
        assert_eq!(out, [0b0000_0011u8, 0xFF]);
    }

    #[test]
    fn test_multiple_operations() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        writer.write_bit(true);
        writer.write_small(0b101, 3);
        writer.write_byte(0xAA);
        writer.write_bytes(&[0xBB, 0xCC]);
        writer.write_small(0b11, 2);
        println!("{:?}", buffer_to_bin(&out));
        assert_eq!(out, [0b0000_1011u8, 0xAA, 0xBB, 0xCC, 0b0000_0011]);
    }

    #[test]
    fn test_write_dyn_int() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        // (value, total buffer length expected right after writing it)
        let cases: [(u128, usize); 3] = [(127, 1), (128, 3), (268_435_455, 7)];
        for &(value, expected_len) in &cases {
            writer.write_dyn_int(value);
            assert_eq!(expected_len, writer.len());
        }
        assert_eq!(out, [127u8, 128, 1, 255, 255, 255, 127]);
    }

    #[test]
    fn test_write_fixed_int() {
        let mut out = Vec::new();
        let mut writer = BitStreamWriter::new(&mut out);
        writer.write_fixed_int(1u8);
        writer.write_fixed_int(1i8);
        writer.write_fixed_int(2u16);
        writer.write_fixed_int(2i16);
        writer.write_fixed_int(3u32);
        writer.write_fixed_int(3i32);
        writer.write_fixed_int(4u64);
        writer.write_fixed_int(4i64);
        writer.write_fixed_int(5u128);
        writer.write_fixed_int(5i128);
        // One row (or pair of rows) per value written above, in the same order.
        let expected: Vec<u8> = vec![
            1, // 1u8
            2, // 1i8
            0, 2, // 2u16
            0, 4, // 2i16
            0, 0, 0, 3, // 3u32
            0, 0, 0, 6, // 3i32
            0, 0, 0, 0, 0, 0, 0, 4, // 4u64
            0, 0, 0, 0, 0, 0, 0, 8, // 4i64
            0, 0, 0, 0, 0, 0, 0, 0, // 5u128
            0, 0, 0, 0, 0, 0, 0, 5,
            0, 0, 0, 0, 0, 0, 0, 0, // 5i128
            0, 0, 0, 0, 0, 0, 0, 10,
        ];
        assert_eq!(expected, out);
    }
}