use crate::constants::*;
use crate::errors::*;
use crate::debug_test;
/// Bit-level encoder.
///
/// Values are packed into `bits_container` starting at bit `bit_pos` (least-significant end
/// first), and whole bytes are flushed little-endian into `stream`.
#[derive(Debug, Default)]
pub struct BitEstream {
    /// Pending bits that have not been flushed to `stream` yet.
    pub(crate) bits_container: BitContainer,
    /// Bytes flushed so far.
    pub(crate) stream: Vec<u8>,
    /// Number of valid bits currently held in `bits_container` (position of the next write).
    pub(crate) bit_pos: u8,
}
/// Bit-level decoder that reads an encoded byte stream from its end.
///
/// `bits_container` holds the last bytes of the stream, and bits are consumed from its
/// most-significant end downwards.
#[derive(Debug)]
pub struct BitDstream {
    /// Currently loaded bits.
    pub(crate) bits_container: BitContainer,
    /// Bytes not yet loaded into the container (the front part of the encoded data).
    pub(crate) stream: Vec<u8>,
    /// Number of bits already consumed from the top of `bits_container`.
    pub(crate) bits_consumed: u8,
}
/// Sink for fixed-width bit fields of type `T`.
///
/// `unchecked_write` has no error channel. Implementations report misuse by panicking; for
/// example, the `usize` implementation for `BitEstream` asserts on the requested width.
pub trait BitWriter<T> {
    /// Writes `width` bits taken from `bits` into the stream.
    fn unchecked_write(&mut self, bits: T, width: u8);
}
/// Source of fixed-width bit fields of type `T`.
pub trait BitReader<T> {
    /// Reads the next `width` bits as a `T`, or returns an error when they cannot be read.
    fn read(&mut self, width: u8) -> Result<T>;
}
impl BitWriter<usize> for BitEstream {
    /// Writes the `nb_bits` low bits of `value` (masked with `BIT_MASK`), flushing whole
    /// containers to the byte stream whenever the current one fills up.
    ///
    /// A field that does not fit in the remaining room is split. Its low bits complete the
    /// current container, which is then flushed, and the remaining high bits start the next
    /// container.
    ///
    /// # Panics
    /// If `nb_bits` exceeds `usize::BITS`.
    fn unchecked_write(&mut self, value: usize, nb_bits: u8) {
        assert!(nb_bits <= usize::BITS as u8);

        let mut pending = value;
        let mut left = nb_bits;
        loop {
            if self.bit_pos == CTNR_SIZE {
                self.flush_bits();
            }
            let room = CTNR_SIZE - self.bit_pos;

            if left <= room {
                // The rest of the field fits: place it and stop.
                self.bits_container |= (pending & BIT_MASK[left as usize]) << self.bit_pos;
                self.bit_pos += left;
                return;
            }

            // Fill the container with the low `room` bits, flush it, continue with the rest.
            self.bits_container |= (pending & BIT_MASK[room as usize]) << self.bit_pos;
            self.bit_pos = CTNR_SIZE;
            self.flush_bits();
            pending >>= room;
            left -= room;
        }
    }
}
impl BitReader<usize> for BitDstream {
    /// Reads `nb_bits` bits (at most `usize::BITS`) from the stream, reloading the container
    /// from the remaining bytes when it runs dry.
    ///
    /// A field that straddles two containers is reassembled with the bits read first (from the
    /// exhausted container) as the high part.
    ///
    /// # Errors
    /// - `ReadOverflow` if the byte stream is exhausted before `nb_bits` bits could be read.
    /// - `BitsOverflow` (from `reload_container`) if more bits than the container holds have
    ///   been marked as consumed.
    ///
    /// # Panics
    /// If `nb_bits` exceeds `usize::BITS`.
    fn read(&mut self, nb_bits: u8) -> Result<usize> {
        assert!(usize::BITS as u8 >= nb_bits);

        // `>=` rather than `==`: an over-consumed container is routed to `reload_container`,
        // which reports it as an error instead of the subtraction below underflowing.
        if self.bits_consumed >= CTNR_SIZE && self.reload_container()? {
            throw!(ReadOverflow);
        }

        let available = CTNR_SIZE - self.bits_consumed;
        if available >= nb_bits {
            return self.read_bits(nb_bits);
        }

        let low_len = nb_bits - available;
        let high = self.read_bits(available)? << low_len;

        // After a reload, a short tail container may hold fewer bits than are still needed.
        // Report that as a read overflow instead of panicking inside `look_bits`.
        if self.reload_container()? || CTNR_SIZE - self.bits_consumed < low_len {
            throw!(ReadOverflow);
        }
        Ok(high | self.read_bits(low_len)?)
    }
}
impl BitEstream {
    /// Creates an encoder with an empty container and no flushed bytes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `value`, masked with `BIT_MASK[nb_bits]` (presumably its low `nb_bits` bits),
    /// to the container without flushing.
    ///
    /// # Errors
    /// `AddBits` if `nb_bits > BIT_MASK_SIZE`, or if the container has no room for `nb_bits`
    /// more bits.
    #[inline]
    pub fn add_bits<T: Into<BitContainer>>(&mut self, value: T, nb_bits: u8) -> Result<()> {
        if nb_bits > BIT_MASK_SIZE {
            throw!(AddBits, format!("MASK size smaller than {}", nb_bits))
        }
        // Widen before adding so large `nb_bits`/`bit_pos` values cannot overflow `u8`.
        if u16::from(nb_bits) + u16::from(self.bit_pos) > u16::from(CTNR_SIZE) {
            throw!(AddBits, "Container overflow")
        }
        let masked = value.into() & BIT_MASK[usize::from(nb_bits)];
        // `bit_pos == CTNR_SIZE` can reach this point (with `nb_bits == 0`). A full-width
        // shift contributes no bits, so map it to 0 instead of overflowing the shift.
        self.bits_container |= masked.checked_shl(u32::from(self.bit_pos)).unwrap_or(0);
        self.bit_pos += nb_bits;
        Ok(())
    }

    /// Appends `value` as an `nb_bits`-wide field without masking it; `value` must already fit.
    ///
    /// # Errors
    /// `AddBits` if `value` has bits set at or above position `nb_bits`, or if the container
    /// has no room for `nb_bits` more bits.
    #[inline]
    pub fn unchecked_add_bits(&mut self, value: BitContainer, nb_bits: u8) -> Result<()> {
        // `checked_shr` returns `None` when `nb_bits >= width`, and every value fits such a
        // field. A plain `>>` would panic (debug) or wrap the shift amount (release) there.
        let too_wide = matches!(value.checked_shr(u32::from(nb_bits)), Some(high) if high != 0);
        if too_wide || u16::from(nb_bits) + u16::from(self.bit_pos) > u16::from(CTNR_SIZE) {
            throw!(
                AddBits,
                format!(
                    "Error add bits fast (nb_bits {}) (bit_pos {})",
                    nb_bits, self.bit_pos
                )
            )
        }
        // Same full-width-shift guard as in `add_bits`.
        self.bits_container |= value.checked_shl(u32::from(self.bit_pos)).unwrap_or(0);
        self.bit_pos += nb_bits;
        Ok(())
    }

    /// Moves every complete byte of the container (little-endian order) to `stream`.
    ///
    /// The at most 7 leftover bits remain in the container, shifted down to bit 0.
    #[inline]
    pub fn flush_bits(&mut self) {
        let nb_bytes = usize::from(self.bit_pos >> 3);
        self.stream
            .extend_from_slice(&self.bits_container.to_le_bytes()[..nb_bytes]);
        self.bit_pos &= 7;
        // Flushing the whole container would be a full-width shift; the result is then empty.
        self.bits_container = self
            .bits_container
            .checked_shr((nb_bytes << 3) as u32)
            .unwrap_or(0);
    }

    /// Terminates the stream.
    ///
    /// Appends a single `1` end-marker bit, flushes, and pushes the final partial byte (if
    /// any). Afterwards the container is empty and `bit_pos` is 0.
    ///
    /// # Errors
    /// Propagates `unchecked_add_bits` errors. These cannot occur while
    /// `bit_pos <= CTNR_SIZE`, because a full container is flushed first.
    #[inline]
    pub fn close_stream(&mut self) -> Result<()> {
        if self.bit_pos == CTNR_SIZE {
            self.flush_bits();
        }
        self.unchecked_add_bits(1, 1)?;
        self.flush_bits();
        if self.bits_container > 0 {
            debug_test!("Push {:010b} into the stream", self.bits_container as u8);
            self.stream.push(self.bits_container as u8);
            self.bits_container = 0;
            // The pending bits were just written out; keep `bit_pos` consistent with the
            // now-empty container.
            self.bit_pos = 0;
        }
        Ok(())
    }
}
impl TryFrom<BitEstream> for BitDstream {
    type Error = BitStreamError;

    /// Builds a decoder over the bytes the encoder has already flushed.
    ///
    /// Only `encoder.stream` is used. Bits still pending in the encoder's container are not
    /// included, so callers presumably close the stream first.
    ///
    /// # Errors
    /// `EmptyStream` if no byte has been flushed.
    fn try_from(encoder: BitEstream) -> Result<Self> {
        Self::try_from(encoder.stream)
    }
}
impl From<BitEstream> for Vec<u8> {
    /// Finishes the encoder with `close_stream` (end-marker bit plus final partial byte) and
    /// returns the encoded bytes.
    ///
    /// # Panics
    /// Only if `close_stream` fails. That requires `bit_pos > CTNR_SIZE`, an invalid encoder
    /// state.
    fn from(mut encoder: BitEstream) -> Self {
        encoder
            .close_stream()
            .expect("BitEstream::close_stream cannot fail while bit_pos <= CTNR_SIZE");
        encoder.stream
    }
}
impl From<&BitEstream> for Vec<u8> {
    /// Copies the bytes the encoder has flushed so far.
    ///
    /// Unlike the owned `From<BitEstream>` conversion, this does not close the stream. Bits
    /// still pending in the container and the end marker are not included.
    fn from(encoder: &BitEstream) -> Self {
        encoder.stream.to_vec()
    }
}
impl TryFrom<Vec<u8>> for BitDstream {
    type Error = BitStreamError;

    /// Builds a decoder over an encoded byte stream.
    ///
    /// The last container's worth of bytes is loaded immediately. Its leading zero bits
    /// (above the encoder's end marker) start out as consumed.
    ///
    /// # Errors
    /// `EmptyStream` if `bytes` is empty.
    fn try_from(bytes: Vec<u8>) -> Result<Self> {
        if bytes.is_empty() {
            throw!(EmptyStream)
        }
        let mut remaining = bytes;
        let (bits_container, bits_consumed, _) = build_dstream_from_vec(&mut remaining)?;
        Ok(Self {
            bits_container,
            bits_consumed,
            stream: remaining,
        })
    }
}
impl BitDstream {
    /// Returns, without consuming them, the `nb_bits` bits directly below the `bits_consumed`
    /// bits already taken from the top of the container.
    ///
    /// # Errors
    /// Propagates the `BitsOverflow` error of [`get_middle_bits`] when `nb_bits > BIT_MASK_SIZE`.
    ///
    /// # Panics
    /// If `bits_consumed > CTNR_SIZE`, or if fewer than `nb_bits` unread bits remain.
    #[inline]
    pub fn look_bits(&mut self, nb_bits: u8) -> Result<BitContainer> {
        debug_test!("Looking at bits {} {}", self.bits_consumed, nb_bits);
        let unread = CTNR_SIZE.checked_sub(self.bits_consumed).unwrap_or_else(|| {
            panic!(
                "attempt to subtract with overflow: \
                 subtract the buffer container size {CTNR_SIZE} with current \
                 bits consumed by the reader {}",
                self.bits_consumed
            )
        });
        let start = unread.checked_sub(nb_bits).unwrap_or_else(|| {
            panic!(
                "attempt to subtract with overflow: \
                 subtract the buffer container size {CTNR_SIZE} with current \
                 bits consumed by the reader {} and a given number of bits: {}",
                self.bits_consumed, nb_bits
            )
        });
        get_middle_bits(self.bits_container, start, nb_bits)
    }

    /// Marks `nb_bits` more bits as consumed.
    ///
    /// No bounds check is performed here: an overshoot is reported later by `look_bits`
    /// (panic) or `reload_container` (`BitsOverflow`). The `u8` addition itself panics on
    /// overflow in debug builds.
    #[inline]
    pub fn skip_bits(&mut self, nb_bits: u8) {
        self.bits_consumed += nb_bits
    }

    /// Returns the next `nb_bits` bits (see `look_bits`) and consumes them.
    #[inline]
    pub fn read_bits(&mut self, nb_bits: u8) -> Result<BitContainer> {
        let bits = self.look_bits(nb_bits)?;
        self.skip_bits(nb_bits);
        Ok(bits)
    }

    /// Refills the container from the end of the remaining byte stream.
    ///
    /// Fully consumed bytes at the top of the container are dropped, and the unconsumed ones
    /// are pushed back onto the stream. A new container is then built from the last
    /// `CTNR_BYTES_SIZE` bytes, or from all of them if fewer remain. Bits already read from a
    /// partially consumed top byte stay marked as consumed.
    ///
    /// Returns `Ok(true)` when the byte stream is already empty; nothing is reloaded and the
    /// container is left untouched. Returns `Ok(false)` after a reload.
    ///
    /// # Errors
    /// `BitsOverflow` if `bits_consumed > CTNR_SIZE`.
    #[inline]
    pub fn reload_container(&mut self) -> Result<bool> {
        if self.bits_consumed > CTNR_SIZE {
            throw!(BitsOverflow, "Stream consume overflow")
        }
        if self.stream.is_empty() {
            return Ok(true);
        }
        let consumed_bytes = usize::from(self.bits_consumed >> 3);
        // Bits already read from the partially consumed byte. That byte is pushed back last,
        // so it becomes the top byte of the rebuilt container and these bits must stay
        // consumed.
        let partial_bits = self.bits_consumed & 7;
        let kept_bytes = CTNR_BYTES_SIZE - consumed_bytes;
        self.stream
            .extend_from_slice(&self.bits_container.to_le_bytes()[..kept_bytes]);
        let (bits_container, _, loaded_bytes) = build_dstream_from_vec(&mut self.stream)?;
        // A short tail lands in the low bytes of the container, so the empty high bytes count
        // as already consumed.
        let empty_bits = if loaded_bytes < CTNR_BYTES_SIZE {
            CTNR_SIZE - (loaded_bytes << 3) as u8
        } else {
            0
        };
        self.bits_container = bits_container;
        self.bits_consumed = empty_bits + partial_bits;
        Ok(false)
    }
}
/// Decodes exactly `CTNR_BYTES_SIZE` little-endian bytes into a container value.
///
/// # Panics
/// If `v.len() != CTNR_BYTES_SIZE`.
#[inline]
fn read_le(v: &[u8]) -> BitContainer {
    debug_test!("Read LITTLE-ENDIAN : {:?}", v);
    let expected_len = CTNR_BYTES_SIZE;
    assert_eq!(
        v.len(),
        expected_len,
        "Unexpected size of container, cannot transmute value"
    );
    let raw = v.try_into().unwrap();
    BitContainer::from_le_bytes(raw)
}
/// Returns the bits of `bits_container` from bit index `start` upward, shifted down to bit 0.
///
/// A `start` at or beyond the container width selects no bits and yields 0. A plain `>>`
/// would panic (debug) or wrap the shift amount (release) in that case.
#[inline]
pub fn get_upper_bits(bits_container: BitContainer, start: u8) -> BitContainer {
    bits_container.checked_shr(u32::from(start)).unwrap_or(0)
}
/// Extracts `nb_bits` bits of `bit_container` starting at bit index `start`.
///
/// `start` is masked with `CTNR_SIZE - 1`. Assuming `CTNR_SIZE` is a power of two, this
/// reduces it modulo the container width, so the shift never overflows.
///
/// # Errors
/// `BitsOverflow` if `nb_bits > BIT_MASK_SIZE`.
#[inline]
pub fn get_middle_bits(
    bit_container: BitContainer,
    start: u8,
    nb_bits: u8,
) -> Result<BitContainer> {
    const WIDTH_MASK: u8 = CTNR_SIZE - 1;
    if nb_bits > BIT_MASK_SIZE {
        throw!(BitsOverflow, format!("at get_middle_bits() : {}", nb_bits))
    }
    let shifted = bit_container >> (start & WIDTH_MASK);
    let mask = BIT_MASK[usize::from(nb_bits)];
    Ok(shifted & mask)
}
/// Keeps only the bits of `bit_container` selected by `BIT_MASK[nb_bits]` (presumably the
/// low `nb_bits` bits).
///
/// # Errors
/// `BitsOverflow` if `nb_bits > BIT_MASK_SIZE`.
#[inline]
pub fn get_lower_bits(bit_container: BitContainer, nb_bits: u8) -> Result<BitContainer> {
    let mask_index = usize::from(nb_bits);
    if nb_bits > BIT_MASK_SIZE {
        throw!(BitsOverflow, format!("at get_lower_bits() : {}", nb_bits))
    }
    let mask = BIT_MASK[mask_index];
    Ok(mask & bit_container)
}
#[inline]
fn build_dstream_from_vec(stream: &mut Vec<u8>) -> Result<(BitContainer, u8, usize)> {
debug_test!(
"\nStart building a new build_dstream_from_vec of len : {}",
stream.len()
);
let bytes: usize;
let mut bits_container: BitContainer;
if stream.len() >= CTNR_BYTES_SIZE {
let bits: Vec<u8> = stream.drain(stream.len() - CTNR_BYTES_SIZE..).collect();
bits_container = read_le(&bits);
bytes = CTNR_BYTES_SIZE;
} else {
bytes = stream.len();
bits_container = stream[0] as BitContainer;
#[cfg(target_pointer_width = "64")]
{
if stream.len() == 7 {
bits_container += (stream.pop().unwrap() as BitContainer) << (CTNR_SIZE - 16);
}
if stream.len() == 6 {
bits_container += (stream.pop().unwrap() as BitContainer) << (CTNR_SIZE - 24);
}
if stream.len() == 5 {
bits_container += (stream.pop().unwrap() as BitContainer) << (CTNR_SIZE - 32);
}
if stream.len() == 4 {
bits_container += (stream.pop().unwrap() as BitContainer) << 24;
}
}
if stream.len() == 3 {
bits_container += (stream.pop().unwrap() as BitContainer) << 16;
}
if stream.len() == 2 {
bits_container += (stream.pop().unwrap() as BitContainer) << 8;
}
stream.clear();
}
Ok((bits_container, bits_container.leading_zeros() as u8, bytes))
}