use crate::error::BuffError;
/// Largest bit width accepted by a single `read` or `write` call
/// (the width of the `u32` values those calls exchange).
pub const MAX_BITS: usize = 32;
/// Number of bits in one byte of the backing buffer.
const BYTE_BITS: usize = 8;
/// A bit-level cursor over a byte buffer `B`.
///
/// The same type is used as a reader (`B = &[u8]`), as a fixed-size writer
/// (`B = &mut [u8]`) and as a growable writer (`B = Vec<u8>`). Bits are packed
/// LSB-first within each byte.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BitPack<B> {
/// Backing byte storage.
buff: B,
/// Byte index into `buff`; bit reads and writes operate on `buff[cursor]`.
cursor: usize,
/// Number of low bits of `buff[cursor]` already consumed. A bit read that
/// ends exactly on a byte boundary leaves this at 8 instead of advancing
/// `cursor`.
bits: usize,
}
/// Position bookkeeping shared by every backing store.
impl<B> BitPack<B> {
    /// Wraps `buff`, placing the position at bit 0 of byte 0.
    #[inline]
    pub fn new(buff: B) -> Self {
        Self {
            cursor: 0,
            bits: 0,
            buff,
        }
    }

    /// Returns the absolute bit offset: `cursor` whole bytes plus the `bits`
    /// already consumed inside the current byte.
    #[inline]
    pub fn sum_bits(&self) -> usize {
        let whole_bytes_in_bits = self.cursor * BYTE_BITS;
        whole_bytes_in_bits + self.bits
    }

    /// Overwrites the byte cursor (no bounds check) and returns `self` so
    /// calls can be chained.
    #[inline]
    pub fn with_cursor(&mut self, cursor: usize) -> &mut Self {
        {
            self.cursor = cursor;
        }
        self
    }

    /// Overwrites the in-byte bit offset (no range check) and returns `self`
    /// so calls can be chained.
    #[inline]
    pub fn with_bits(&mut self, bits: usize) -> &mut Self {
        {
            self.bits = bits;
        }
        self
    }
}
/// Byte-level view, available whenever the backing store borrows as `[u8]`.
impl<B> BitPack<B>
where
    B: AsRef<[u8]>,
{
    /// Returns the whole backing buffer, regardless of the current position.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        <B as AsRef<[u8]>>::as_ref(&self.buff)
    }
}
/// Read-side operations over a borrowed byte slice.
///
/// Bits are taken LSB-first from each byte. `cursor` indexes the byte being
/// read and `bits` counts how many of its low bits are already consumed. A bit
/// read that ends exactly on a byte boundary leaves `bits == 8` rather than
/// advancing `cursor`; the next bit read steps forward by itself.
///
/// The byte helpers (`read_byte`, `read_n_byte`, `read_n_byte_unmut`) treat
/// `cursor` as the index of the last consumed byte and start at `cursor + 1`.
impl BitPack<&[u8]> {
    /// Reads the next `bits` bits and returns them in the low bits of a `u32`,
    /// with earlier bits in lower positions.
    ///
    /// A zero-width read returns `Ok(0)` without moving the position.
    ///
    /// # Errors
    ///
    /// - [`BuffError::BitWidthExceeded`] if `bits > MAX_BITS`.
    /// - [`BuffError::BufferOverflow`] if fewer than `bits` bits remain; the
    ///   position is left unchanged.
    pub fn read(&mut self, mut bits: usize) -> Result<u32, BuffError> {
        if bits > MAX_BITS {
            return Err(BuffError::BitWidthExceeded(bits));
        }
        let total_bits = self.buff.len() * BYTE_BITS;
        if total_bits < self.sum_bits() + bits {
            return Err(BuffError::BufferOverflow {
                attempted: bits,
                // Saturate: the position can already lie past the end (e.g. after
                // `skip_n_byte`), where a plain subtraction would panic.
                available: total_bits.saturating_sub(self.sum_bits()),
            });
        }
        // Nothing to consume. Returning here also avoids indexing
        // `buff[cursor]` when the cursor sits one past the last byte.
        if bits == 0 {
            return Ok(0);
        }
        let mut shift = 0u32;
        let mut output = 0u32;
        loop {
            let byte_left = BYTE_BITS - self.bits;
            // Unconsumed high bits of the current byte, shifted down to bit 0.
            let pending = u32::from(self.buff[self.cursor]) >> self.bits;
            if bits <= byte_left {
                // The rest of the request fits inside the current byte.
                output |= (pending & ((1u32 << bits) - 1)) << shift;
                self.bits += bits;
                break;
            }
            // Drain what is left of this byte, then continue with the next one.
            output |= (pending & ((1u32 << byte_left) - 1)) << shift;
            shift += byte_left as u32;
            bits -= byte_left;
            self.cursor += 1;
            self.bits = 0;
        }
        Ok(output)
    }

    /// Reads the byte after `cursor` and moves `cursor` onto it.
    ///
    /// # Errors
    ///
    /// [`BuffError::InvalidData`] when no byte follows `cursor`; the cursor is
    /// left unchanged.
    #[inline]
    pub fn read_byte(&mut self) -> Result<u8, BuffError> {
        let next = self.cursor + 1;
        match self.buff.get(next) {
            Some(&byte) => {
                self.cursor = next;
                Ok(byte)
            }
            None => Err(BuffError::InvalidData("unexpected end of buffer".into())),
        }
    }

    /// Returns the `n` bytes following `cursor` and moves `cursor` onto the
    /// last of them (it is unchanged when `n == 0`).
    ///
    /// # Errors
    ///
    /// [`BuffError::BufferOverflow`] when fewer than `n` bytes follow `cursor`;
    /// the cursor is left unchanged.
    #[inline]
    pub fn read_n_byte(&mut self, n: usize) -> Result<&[u8], BuffError> {
        let start = self.cursor + 1;
        let len = self.buff.len();
        // `checked_add` keeps a huge (e.g. untrusted) `n` from wrapping into an
        // inverted slice range.
        match start.checked_add(n) {
            Some(end) if end <= len => {
                self.cursor += n;
                Ok(&self.buff[start..end])
            }
            _ => Err(BuffError::BufferOverflow {
                attempted: n,
                available: len.saturating_sub(start),
            }),
        }
    }

    /// Peeks at `n` bytes beginning `start` bytes after the byte following
    /// `cursor`, without moving the position.
    ///
    /// # Errors
    ///
    /// [`BuffError::BufferOverflow`] when the requested range runs past the end.
    #[inline]
    pub fn read_n_byte_unmut(&self, start: usize, n: usize) -> Result<&[u8], BuffError> {
        let s = start.saturating_add(self.cursor).saturating_add(1);
        match s.checked_add(n) {
            Some(end) if end <= self.buff.len() => Ok(&self.buff[s..end]),
            _ => Err(BuffError::BufferOverflow {
                attempted: n,
                available: self.buff.len().saturating_sub(s),
            }),
        }
    }

    /// Advances `cursor` by `n` bytes without a bounds check; `bits` is kept.
    #[inline]
    pub fn skip_n_byte(&mut self, n: usize) {
        self.cursor += n;
    }

    /// Advances the position by `bits` bits.
    ///
    /// # Errors
    ///
    /// [`BuffError::BufferOverflow`] if fewer than `bits` bits remain; the
    /// position is left unchanged.
    #[inline]
    pub fn skip(&mut self, bits: usize) -> Result<(), BuffError> {
        let total_bits = self.buff.len() * BYTE_BITS;
        if total_bits < self.sum_bits() + bits {
            return Err(BuffError::BufferOverflow {
                attempted: bits,
                available: total_bits.saturating_sub(self.sum_bits()),
            });
        }
        let cur_bits = self.bits + bits % BYTE_BITS;
        self.cursor += bits / BYTE_BITS + cur_bits / BYTE_BITS;
        self.bits = cur_bits % BYTE_BITS;
        Ok(())
    }

    /// Abandons the rest of the current byte: moves `cursor` forward by one and
    /// resets `bits` to 0.
    #[inline]
    pub fn finish_read_byte(&mut self) {
        self.cursor += 1;
        self.bits = 0;
    }
}
/// Write-side operations over a caller-provided, fixed-size mutable slice.
impl BitPack<&mut [u8]> {
    /// Writes the low `bits` bits of `value` at the current position, LSB-first.
    ///
    /// Bits are OR-ed into the slice, so the destination region should already
    /// be zeroed. A write that ends exactly on a byte boundary moves `cursor` to
    /// the next byte and resets `bits` to 0. A zero-width write is a no-op.
    ///
    /// # Errors
    ///
    /// - [`BuffError::BitWidthExceeded`] if `bits > MAX_BITS`.
    /// - [`BuffError::BufferOverflow`] if fewer than `bits` bits of space
    ///   remain; nothing is written.
    pub fn write(&mut self, mut value: u32, mut bits: usize) -> Result<(), BuffError> {
        if bits > MAX_BITS {
            return Err(BuffError::BitWidthExceeded(bits));
        }
        let total_bits = self.buff.len() * BYTE_BITS;
        if total_bits < self.sum_bits() + bits {
            return Err(BuffError::BufferOverflow {
                attempted: bits,
                // Saturate: the position may already be past the end.
                available: total_bits.saturating_sub(self.sum_bits()),
            });
        }
        // Nothing to write. Returning here also avoids indexing `buff[cursor]`
        // when the cursor is one past the last byte (full or empty slice).
        if bits == 0 {
            return Ok(());
        }
        if bits < MAX_BITS {
            // Use a `u32` literal: an unsuffixed `1` defaults to `i32`, and
            // `(1i32 << 31) - 1` overflows, panicking in debug builds.
            value &= (1u32 << bits) - 1;
        }
        loop {
            let room = BYTE_BITS - self.bits;
            if bits <= room {
                // The remaining `value` fits in the free high bits of this byte.
                self.buff[self.cursor] |= (value as u8) << self.bits;
                self.bits += bits;
                if self.bits >= BYTE_BITS {
                    self.cursor += 1;
                    self.bits = 0;
                }
                break;
            }
            // Fill the rest of this byte with the lowest `room` bits.
            let chunk = value & ((1u32 << room) - 1);
            self.buff[self.cursor] |= (chunk as u8) << self.bits;
            self.cursor += 1;
            self.bits = 0;
            value >>= room;
            bits -= room;
        }
        Ok(())
    }
}
impl Default for BitPack<Vec<u8>> {
fn default() -> Self {
Self::new(Vec::new())
}
}
/// Growable writer backed by an owned `Vec<u8>`.
impl BitPack<Vec<u8>> {
    /// Creates an empty writer whose buffer has room for `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::new(Vec::with_capacity(capacity))
    }

    /// Writes the low `bits` bits of `value` at the current bit position,
    /// growing the buffer with zero bytes as needed.
    ///
    /// A zero-width write is a no-op.
    ///
    /// # Errors
    ///
    /// [`BuffError::BitWidthExceeded`] if `bits > MAX_BITS`; any error from the
    /// underlying slice writer is propagated.
    #[inline]
    pub fn write(&mut self, value: u32, bits: usize) -> Result<(), BuffError> {
        if bits > MAX_BITS {
            return Err(BuffError::BitWidthExceeded(bits));
        }
        // Without this guard the slice writer would index `buff[cursor]` on an
        // empty buffer (or right after a completed byte) and panic.
        if bits == 0 {
            return Ok(());
        }
        // Grow (zero-filled) so every bit up to the end of this write is backed.
        let needed_bytes = (self.sum_bits() + bits).div_ceil(BYTE_BITS);
        if needed_bytes > self.buff.len() {
            self.buff.resize(needed_bytes, 0x0);
        }
        let mut bitpack = BitPack {
            buff: self.buff.as_mut_slice(),
            cursor: self.cursor,
            bits: self.bits,
        };
        bitpack.write(value, bits)?;
        self.bits = bitpack.bits;
        self.cursor = bitpack.cursor;
        Ok(())
    }

    /// Pushes `value` onto the end of the buffer.
    ///
    /// This appends after the last allocated byte and does not move the bit
    /// cursor. It always returns `Ok(())`.
    #[inline]
    pub fn write_byte(&mut self, value: u8) -> Result<(), BuffError> {
        self.buff.push(value);
        Ok(())
    }

    /// Appends `values` to the end of the buffer without moving the bit cursor.
    #[inline]
    pub fn write_bytes(&mut self, values: &[u8]) {
        self.buff.extend_from_slice(values);
    }

    /// Appends one zero byte and positions the cursor at its start.
    #[inline]
    pub fn finish_write_byte(&mut self) {
        let len = self.buff.len();
        self.buff.resize(len + 1, 0x0);
        self.bits = 0;
        self.cursor = len;
    }

    /// Consumes the writer and returns the underlying bytes.
    #[inline]
    pub fn into_vec(self) -> Vec<u8> {
        self.buff
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `BitPack`: round-trips between the `Vec` writer and the
    //! slice reader, exact byte layout (LSB-first), error reporting, and cursor
    //! bookkeeping.
    use super::*;

    #[test]
    fn test_write_read_roundtrip() {
        let mut bitpack_vec = BitPack::<Vec<u8>>::with_capacity(8);
        bitpack_vec.write(10, 4).unwrap();
        bitpack_vec.write(1021, 10).unwrap();
        bitpack_vec.write(3, 2).unwrap();
        // 0xA | (0x3FD << 4) | (3 << 14), packed LSB-first.
        assert_eq!(bitpack_vec.as_slice(), &[0xDA, 0xFF]);
        let mut bitpack = BitPack::<&[u8]>::new(bitpack_vec.as_slice());
        assert_eq!(bitpack.read(4).unwrap(), 10);
        assert_eq!(bitpack.read(10).unwrap(), 1021);
        assert_eq!(bitpack.read(2).unwrap(), 3);
        assert_eq!(bitpack.sum_bits(), 16);
    }

    #[test]
    fn test_single_bits() {
        let mut bitpack_vec = BitPack::<Vec<u8>>::with_capacity(1);
        bitpack_vec.write(1, 1).unwrap();
        bitpack_vec.write(0, 1).unwrap();
        bitpack_vec.write(0, 1).unwrap();
        bitpack_vec.write(1, 1).unwrap();
        assert_eq!(bitpack_vec.as_slice(), &[0x09]);
        let mut bitpack = BitPack::<&[u8]>::new(bitpack_vec.as_slice());
        assert_eq!(bitpack.read(1).unwrap(), 1);
        assert_eq!(bitpack.read(1).unwrap(), 0);
        assert_eq!(bitpack.read(1).unwrap(), 0);
        assert_eq!(bitpack.read(1).unwrap(), 1);
    }

    #[test]
    fn test_full_bytes() {
        let mut bitpack_vec = BitPack::<Vec<u8>>::with_capacity(8);
        bitpack_vec.write(255, 8).unwrap();
        bitpack_vec.write(65535, 16).unwrap();
        bitpack_vec.write(255, 8).unwrap();
        assert_eq!(bitpack_vec.as_slice(), &[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut bitpack = BitPack::<&[u8]>::new(bitpack_vec.as_slice());
        assert_eq!(bitpack.read(8).unwrap(), 255);
        assert_eq!(bitpack.read(16).unwrap(), 65535);
        assert_eq!(bitpack.read(8).unwrap(), 255);
    }

    #[test]
    fn test_bit_width_exceeded() {
        let mut bitpack_vec = BitPack::<Vec<u8>>::with_capacity(8);
        let result = bitpack_vec.write(0, 33);
        assert!(matches!(result, Err(BuffError::BitWidthExceeded(33))));
    }

    #[test]
    fn test_read_bit_width_exceeded() {
        let data = vec![0u8; 10];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        let result = bitpack.read(33);
        assert!(matches!(result, Err(BuffError::BitWidthExceeded(33))));
    }

    #[test]
    fn test_read_buffer_overflow() {
        let data = vec![0u8; 2];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        // Consume the whole buffer first; this read itself must succeed.
        assert_eq!(bitpack.read(16).unwrap(), 0);
        let result = bitpack.read(8);
        assert!(matches!(
            result,
            Err(BuffError::BufferOverflow {
                attempted: 8,
                available: 0
            })
        ));
    }

    #[test]
    fn test_sum_bits() {
        let mut bitpack_vec = BitPack::<Vec<u8>>::with_capacity(8);
        assert_eq!(bitpack_vec.sum_bits(), 0);
        bitpack_vec.write(15, 4).unwrap();
        assert_eq!(bitpack_vec.sum_bits(), 4);
        bitpack_vec.write(255, 8).unwrap();
        assert_eq!(bitpack_vec.sum_bits(), 12);
    }

    #[test]
    fn test_with_cursor_and_bits() {
        let data = vec![0xAB, 0xCD, 0xEF];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        bitpack.with_cursor(1);
        assert_eq!(bitpack.cursor, 1);
        bitpack.with_bits(4);
        assert_eq!(bitpack.bits, 4);
        assert_eq!(bitpack.sum_bits(), 12);
    }

    #[test]
    fn test_as_slice() {
        let data = vec![1u8, 2, 3, 4];
        let bitpack = BitPack::<&[u8]>::new(&data);
        assert_eq!(bitpack.as_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_read_n_byte() {
        let data = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        let bytes = bitpack.read_n_byte(3).unwrap();
        assert_eq!(bytes, &[1, 2, 3]);
        // The cursor lands on the last byte returned.
        assert_eq!(bitpack.cursor, 3);
    }

    #[test]
    fn test_read_n_byte_overflow() {
        let data = vec![0u8, 1, 2];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        let result = bitpack.read_n_byte(10);
        assert!(matches!(result, Err(BuffError::BufferOverflow { .. })));
    }

    #[test]
    fn test_read_n_byte_unmut() {
        let data = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let bitpack = BitPack::<&[u8]>::new(&data);
        let bytes = bitpack.read_n_byte_unmut(0, 3).unwrap();
        assert_eq!(bytes, &[1, 2, 3]);
    }

    #[test]
    fn test_read_n_byte_unmut_overflow() {
        let data = vec![0u8, 1, 2];
        let bitpack = BitPack::<&[u8]>::new(&data);
        let result = bitpack.read_n_byte_unmut(0, 10);
        assert!(matches!(result, Err(BuffError::BufferOverflow { .. })));
    }

    #[test]
    fn test_skip_n_byte() {
        let data = vec![0u8; 10];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        assert_eq!(bitpack.cursor, 0);
        bitpack.skip_n_byte(5);
        assert_eq!(bitpack.cursor, 5);
    }

    #[test]
    fn test_skip_bits() {
        let data = vec![0u8; 10];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        bitpack.skip(20).unwrap();
        assert_eq!(bitpack.cursor, 2);
        assert_eq!(bitpack.bits, 4);
    }

    #[test]
    fn test_skip_overflow() {
        let data = vec![0u8; 2];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        let result = bitpack.skip(100);
        assert!(matches!(result, Err(BuffError::BufferOverflow { .. })));
    }

    #[test]
    fn test_finish_read_byte() {
        let data = vec![0u8; 10];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        bitpack.bits = 5;
        bitpack.finish_read_byte();
        assert_eq!(bitpack.cursor, 1);
        assert_eq!(bitpack.bits, 0);
    }

    #[test]
    fn test_finish_write_byte() {
        let mut bitpack = BitPack::<Vec<u8>>::with_capacity(10);
        bitpack.write(0xAB, 8).unwrap();
        bitpack.finish_write_byte();
        assert_eq!(bitpack.bits, 0);
        assert_eq!(bitpack.cursor, 1);
        assert_eq!(bitpack.buff, vec![0xAB, 0x00]);
    }

    #[test]
    fn test_write_bytes() {
        let mut bitpack = BitPack::<Vec<u8>>::with_capacity(10);
        bitpack.write_bytes(&[1, 2, 3, 4]);
        assert_eq!(bitpack.buff, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_write_byte() {
        let mut bitpack = BitPack::<Vec<u8>>::with_capacity(10);
        bitpack.write_byte(0xAB).unwrap();
        bitpack.write_byte(0xCD).unwrap();
        assert_eq!(bitpack.buff, vec![0xAB, 0xCD]);
    }

    #[test]
    fn test_into_vec() {
        let mut bitpack = BitPack::<Vec<u8>>::with_capacity(10);
        bitpack.write(0xABCD, 16).unwrap();
        let vec = bitpack.into_vec();
        // Low byte first.
        assert_eq!(vec, vec![0xCD, 0xAB]);
    }

    #[test]
    fn test_default() {
        let bitpack = BitPack::<Vec<u8>>::default();
        assert!(bitpack.buff.is_empty());
        assert_eq!(bitpack.cursor, 0);
        assert_eq!(bitpack.bits, 0);
    }

    #[test]
    fn test_write_to_mut_slice() {
        let mut buffer = [0u8; 4];
        let mut bitpack = BitPack::new(&mut buffer[..]);
        bitpack.write(0xAB, 8).unwrap();
        bitpack.write(0xCD, 8).unwrap();
        assert_eq!(buffer, [0xAB, 0xCD, 0x00, 0x00]);
    }

    #[test]
    fn test_write_to_mut_slice_overflow() {
        let mut buffer = [0u8; 2];
        let mut bitpack = BitPack::new(&mut buffer[..]);
        bitpack.write(0xFFFF, 16).unwrap();
        let result = bitpack.write(0xFF, 8);
        assert!(matches!(result, Err(BuffError::BufferOverflow { .. })));
    }

    #[test]
    fn test_write_max_bits() {
        let mut bitpack = BitPack::<Vec<u8>>::with_capacity(8);
        bitpack.write(0xDEADBEEF, 32).unwrap();
        assert_eq!(bitpack.sum_bits(), 32);
        assert_eq!(bitpack.as_slice(), &[0xEF, 0xBE, 0xAD, 0xDE]);
        let mut reader = BitPack::<&[u8]>::new(bitpack.as_slice());
        assert_eq!(reader.read(32).unwrap(), 0xDEADBEEF);
    }

    #[test]
    fn test_read_byte() {
        let data = vec![0xAA, 0xBB, 0xCC];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        assert_eq!(bitpack.read_byte().unwrap(), 0xBB);
        assert_eq!(bitpack.read_byte().unwrap(), 0xCC);
    }

    #[test]
    fn test_read_byte_eof() {
        let data = vec![0xAA];
        let mut bitpack = BitPack::<&[u8]>::new(&data);
        let result = bitpack.read_byte();
        assert!(matches!(result, Err(BuffError::InvalidData(_))));
    }

    #[test]
    fn test_write_mut_slice_bit_width_exceeded() {
        let mut buffer = [0u8; 10];
        let mut bitpack = BitPack::new(&mut buffer[..]);
        let result = bitpack.write(0, 33);
        assert!(matches!(result, Err(BuffError::BitWidthExceeded(33))));
    }

    #[test]
    fn test_write_mut_slice_full_32_bits() {
        let mut buffer = [0u8; 8];
        let mut bitpack = BitPack::new(&mut buffer[..]);
        bitpack.write(0xDEADBEEF, 32).unwrap();
        assert_eq!(bitpack.sum_bits(), 32);
        assert_eq!(buffer, [0xEF, 0xBE, 0xAD, 0xDE, 0, 0, 0, 0]);
    }

    #[test]
    fn test_write_mut_slice_value_masking() {
        let mut buffer = [0u8; 4];
        let mut bitpack = BitPack::new(&mut buffer[..]);
        bitpack.write(0xFF, 4).unwrap();
        // Only the low nibble may be set; unmasked high bits would leak into it.
        assert_eq!(buffer[0], 0x0F);
    }
}