use common_traits::*;
use crate::codes::params::{DefaultReadParams, ReadParams};
use crate::traits::*;
use core::convert::Infallible;
use core::{mem, ptr};
use std::error::Error;
/// Double-width counterpart of the backend's word type, as given by
/// [`DoubleType`]. It is the type of the reader's bit buffer and of the
/// values returned by `peek_bits`.
type BB<WR> = <<WR as WordRead>::Word as DoubleType>::DoubleType;
/// Buffered bit reader layered on top of a word-based backend.
///
/// Words are fetched from `backend` and accumulated in `buffer`, whose type
/// is twice as wide as a backend word. `E` selects the bit order (separate
/// implementations exist for `BE` and `LE`); `RP` is only carried as a
/// type-level marker by the code in this file.
#[derive(Debug)]
pub struct BufBitReader<E: Endianness, WR: WordRead, RP: ReadParams = DefaultReadParams>
where
WR::Word: DoubleType,
{
/// Source of words.
backend: WR,
/// Bits fetched from the backend but not yet consumed. The `BE`
/// implementation keeps them in the most significant positions, the `LE`
/// implementation in the least significant ones.
buffer: BB<WR>,
/// Number of valid (unread) bits currently held in `buffer`.
bits_in_buffer: usize,
/// Ties the otherwise unused type parameters `E` and `RP` to the struct.
_marker: core::marker::PhantomData<(E, RP)>,
}
/// Manual `Clone` implementation: only the backend is required to be
/// `Clone`; no `Clone` bound is placed on the marker parameters `E` and `RP`.
impl<E: Endianness, WR: WordRead + Clone, RP: ReadParams> core::clone::Clone
    for BufBitReader<E, WR, RP>
where
    WR::Word: DoubleType,
{
    /// Returns a reader with a cloned backend and a copy of the current
    /// buffer state.
    fn clone(&self) -> Self {
        let Self {
            backend,
            buffer,
            bits_in_buffer,
            ..
        } = self;
        Self {
            backend: backend.clone(),
            buffer: *buffer,
            bits_in_buffer: *bits_in_buffer,
            _marker: core::marker::PhantomData,
        }
    }
}
impl<E: Endianness, WR: WordRead, RP: ReadParams> BufBitReader<E, WR, RP>
where
    WR::Word: DoubleType,
{
    /// Creates a reader with an empty bit buffer on top of `backend`.
    ///
    /// `check_tables` is called first with the backend word width plus one.
    // NOTE(review): presumably this validates the decoding tables against the
    // number of bits a peek can provide - confirm against `check_tables`.
    #[must_use]
    pub fn new(backend: WR) -> Self {
        let word_bits_plus_one = WR::Word::BITS + 1;
        check_tables(word_bits_plus_one);
        Self {
            backend,
            bits_in_buffer: 0,
            buffer: BB::<WR>::ZERO,
            _marker: core::marker::PhantomData,
        }
    }

    /// Consumes the reader and returns the underlying backend.
    ///
    /// Any bits still held in the buffer are discarded. This never fails;
    /// the `Result` return type has an uninhabited error.
    pub fn into_inner(self) -> Result<WR, Infallible> {
        let this = mem::ManuallyDrop::new(self);
        // SAFETY: `this` is wrapped in `ManuallyDrop`, so the reader is never
        // dropped and `backend` is read out exactly once; no double drop of
        // the backend can occur.
        Ok(unsafe { ptr::read(&this.backend) })
    }
}
impl<WR: WordRead, RP: ReadParams> BufBitReader<BE, WR, RP>
where
    WR::Word: DoubleType,
{
    /// Fetches one word from the backend and places it in the buffer
    /// immediately below the bits that are already there.
    ///
    /// The caller must make sure that the buffer has room for a whole word
    /// (checked in debug builds only).
    ///
    /// # Errors
    ///
    /// Propagates the backend error; in that case the buffer is untouched.
    #[inline(always)]
    fn refill(&mut self) -> Result<(), <WR as WordRead>::Error> {
        debug_assert!(BB::<WR>::BITS - self.bits_in_buffer >= WR::Word::BITS);
        let fetched = self.backend.read_word()?.to_be();
        let widened: BB<WR> = fetched.upcast();
        // Number of valid bits once the new word has been appended.
        let filled = self.bits_in_buffer + WR::Word::BITS;
        self.buffer |= widened << (BB::<WR>::BITS - filled);
        self.bits_in_buffer = filled;
        Ok(())
    }
}
/// Big-endian bit reading: the unread bits occupy the most significant
/// `bits_in_buffer` positions of `buffer` and are consumed by shifting the
/// buffer to the left. Every method leaves the positions below the unread
/// bits zeroed, because `refill` merges new words into the buffer with a
/// bitwise OR.
impl<WR: WordRead, RP: ReadParams> BitRead<BE> for BufBitReader<BE, WR, RP>
where
    WR::Word: DoubleType + UpcastableInto<u64>,
    BB<WR>: CastableInto<u64>,
{
    type Error = <WR as WordRead>::Error;
    type PeekWord = BB<WR>;

    /// Returns the next `n_bits` bits, right-aligned, without consuming them.
    ///
    /// At most one word is fetched from the backend; `n_bits` must be
    /// positive and must not exceed the bits available after that refill
    /// (checked in debug builds only).
    #[inline(always)]
    fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
        debug_assert!(n_bits > 0);
        debug_assert!(n_bits <= Self::PeekWord::BITS);
        if n_bits > self.bits_in_buffer {
            self.refill()?;
        }
        debug_assert!(n_bits <= self.bits_in_buffer);
        // Bring the `n_bits` most significant bits down to the bottom.
        Ok(self.buffer >> (BB::<WR>::BITS - n_bits))
    }

    /// Consumes `n_bits` bits that are already in the buffer; no refill and
    /// no bounds check is performed.
    #[inline(always)]
    fn skip_bits_after_table_lookup(&mut self, n_bits: usize) {
        self.bits_in_buffer -= n_bits;
        self.buffer <<= n_bits;
    }

    /// Reads `n_bits` (at most 64) bits and returns them right-aligned.
    #[inline]
    fn read_bits(&mut self, mut n_bits: usize) -> Result<u64, Self::Error> {
        debug_assert!(n_bits <= 64);
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        // Fast path: everything requested is already buffered.
        if n_bits <= self.bits_in_buffer {
            // The shift is split in two so that a total shift of
            // `BITS - n_bits` stays valid when `n_bits` is zero.
            let result: u64 = (self.buffer >> (BB::<WR>::BITS - n_bits - 1) >> 1_u32).cast();
            self.bits_in_buffer -= n_bits;
            self.buffer <<= n_bits;
            return Ok(result);
        }
        // Slow path: start from all the buffered bits...
        let mut result: u64 =
            (self.buffer >> (BB::<WR>::BITS - 1 - self.bits_in_buffer) >> 1_u8).cast();
        n_bits -= self.bits_in_buffer;
        // ...append whole words straight from the backend, bypassing the
        // buffer...
        while n_bits > WR::Word::BITS {
            let new_word: u64 = self.backend.read_word()?.to_be().upcast();
            result = (result << WR::Word::BITS) | new_word;
            n_bits -= WR::Word::BITS;
        }
        debug_assert!(n_bits > 0);
        debug_assert!(n_bits <= WR::Word::BITS);
        // ...and split the last word between the result and the buffer.
        let new_word = self.backend.read_word()?.to_be();
        self.bits_in_buffer = WR::Word::BITS - n_bits;
        let upcasted: u64 = new_word.upcast();
        let final_bits: u64 = (upcasted >> self.bits_in_buffer).downcast();
        // Two-step shift: `n_bits` may be as large as 64 here.
        result = (result << (n_bits - 1) << 1) | final_bits;
        // Left-align the unread part of the word; the two-step shift also
        // covers `bits_in_buffer == 0`.
        self.buffer = (UpcastableInto::<BB<WR>>::upcast(new_word)
            << (BB::<WR>::BITS - self.bits_in_buffer - 1))
            << 1;
        Ok(result)
    }

    /// Reads a unary code: counts the zeros preceding the next one bit and
    /// consumes the zeros together with the terminating one.
    #[inline]
    fn read_unary(&mut self) -> Result<u64, Self::Error> {
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        let zeros: usize = self.buffer.leading_zeros() as _;
        // The terminating one is among the buffered bits.
        if zeros < self.bits_in_buffer {
            self.buffer = self.buffer << zeros << 1;
            self.bits_in_buffer -= zeros + 1;
            return Ok(zeros as u64);
        }
        // All buffered bits are zeros: keep fetching words until a non-zero
        // one shows up.
        let mut result: u64 = self.bits_in_buffer as _;
        loop {
            let new_word = self.backend.read_word()?.to_be();
            if new_word != WR::Word::ZERO {
                let zeros: usize = new_word.leading_zeros() as _;
                // Drop the zeros and the one, left-aligning what follows.
                self.buffer =
                    UpcastableInto::<BB<WR>>::upcast(new_word) << (WR::Word::BITS + zeros) << 1;
                self.bits_in_buffer = WR::Word::BITS - zeros - 1;
                return Ok(result + zeros as u64);
            }
            result += WR::Word::BITS as u64;
        }
    }

    /// Skips `n_bits` bits.
    #[inline]
    fn skip_bits(&mut self, mut n_bits: usize) -> Result<(), Self::Error> {
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        // Fast path: just shift the buffer.
        if n_bits <= self.bits_in_buffer {
            self.bits_in_buffer -= n_bits;
            self.buffer <<= n_bits;
            return Ok(());
        }
        n_bits -= self.bits_in_buffer;
        // Discard whole words.
        while n_bits > WR::Word::BITS {
            let _ = self.backend.read_word()?;
            n_bits -= WR::Word::BITS;
        }
        // Keep the unread tail of the last word, left-aligned.
        let new_word = self.backend.read_word()?.to_be();
        self.bits_in_buffer = WR::Word::BITS - n_bits;
        self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word)
            << (BB::<WR>::BITS - 1 - self.bits_in_buffer)
            << 1;
        Ok(())
    }

    /// Copies `n` bits from this reader to `bit_write`.
    ///
    /// The buffered bits are written first, then whole words read from the
    /// backend, and finally the leading part of one last word whose
    /// remainder is kept in the buffer.
    ///
    /// The buffer is updated with shifts, exactly as in `read_bits` and
    /// `skip_bits`, so that no already-consumed bit is left below the unread
    /// ones: `refill` ORs new words into the buffer and would otherwise
    /// merge such stale bits into the data returned by a later `peek_bits`.
    /// The values handed to `write_bits` likewise contain only the bits
    /// being written.
    ///
    /// # Errors
    ///
    /// Returns the boxed backend or writer error. Bits written before the
    /// failure have already been consumed from this reader.
    #[cfg(not(feature = "no_copy_impls"))]
    fn copy_to<F: Endianness>(
        &mut self,
        bit_write: &mut impl BitWrite<F>,
        mut n: u64,
    ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        // Lossless cast: the minimum is bounded by `bits_in_buffer`.
        let from_buffer = Ord::min(n, self.bits_in_buffer as u64) as usize;
        // Two-step shift so that `from_buffer == 0` is a valid total shift.
        let head: u64 = (self.buffer >> (BB::<WR>::BITS - from_buffer - 1) >> 1_u32).cast();
        bit_write.write_bits(head, from_buffer)?;
        // Consume the written bits; zeros are shifted in from the bottom.
        self.buffer <<= from_buffer;
        self.bits_in_buffer -= from_buffer;
        n -= from_buffer as u64;
        if n == 0 {
            return Ok(());
        }
        // The buffer is empty from here on: forward whole words directly.
        while n > WR::Word::BITS as u64 {
            bit_write.write_bits(self.backend.read_word()?.to_be().upcast(), WR::Word::BITS)?;
            n -= WR::Word::BITS as u64;
        }
        assert!(n > 0);
        // Split the last word between the writer and the buffer.
        let new_word = self.backend.read_word()?.to_be();
        self.bits_in_buffer = WR::Word::BITS - n as usize;
        bit_write.write_bits((new_word >> self.bits_in_buffer).upcast(), n as usize)?;
        // Left-align the unread tail and clear everything below it.
        self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word)
            << (BB::<WR>::BITS - 1 - self.bits_in_buffer)
            << 1;
        Ok(())
    }
}
/// Bit-level positioning for the big-endian reader, built on the word-level
/// positioning of the backend.
impl<
        E: Error + Send + Sync + 'static,
        WR: WordRead<Error = E> + WordSeek<Error = E>,
        RP: ReadParams,
    > BitSeek for BufBitReader<BE, WR, RP>
where
    WR::Word: DoubleType,
{
    type Error = <WR as WordSeek>::Error;

    /// Returns the position of the next unread bit: the backend position in
    /// bits, minus the bits that were fetched but are still in the buffer.
    #[inline]
    fn get_bit_pos(&mut self) -> Result<u64, Self::Error> {
        let words_fetched = self.backend.get_word_pos()?;
        Ok(words_fetched * WR::Word::BITS as u64 - self.bits_in_buffer as u64)
    }

    /// Moves the reader to the bit at `bit_index`.
    ///
    /// The backend is moved to the word containing the bit; when the target
    /// is not word-aligned that word is fetched and its unread tail is
    /// placed, left-aligned, in the buffer.
    #[inline]
    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
        let word_bits = WR::Word::BITS as u64;
        let word_index = bit_index / word_bits;
        let bit_offset = (bit_index % word_bits) as usize;
        self.backend.set_word_pos(word_index)?;
        // Whatever was buffered belongs to the old position.
        self.buffer = BB::<WR>::ZERO;
        self.bits_in_buffer = 0;
        if bit_offset == 0 {
            return Ok(());
        }
        let word: BB<WR> = self.backend.read_word()?.to_be().upcast();
        let remaining = WR::Word::BITS - bit_offset;
        // Shifting out of the top drops the `bit_offset` leading bits.
        self.buffer = word << (BB::<WR>::BITS - remaining);
        self.bits_in_buffer = remaining;
        Ok(())
    }
}
impl<WR: WordRead, RP: ReadParams> BufBitReader<LE, WR, RP>
where
    WR::Word: DoubleType,
{
    /// Fetches one word from the backend and places it in the buffer
    /// immediately above the bits that are already there.
    ///
    /// The caller must make sure that the buffer has room for a whole word
    /// (checked in debug builds only).
    ///
    /// # Errors
    ///
    /// Propagates the backend error; in that case the buffer is untouched.
    #[inline(always)]
    fn refill(&mut self) -> Result<(), <WR as WordRead>::Error> {
        debug_assert!(BB::<WR>::BITS - self.bits_in_buffer >= WR::Word::BITS);
        let fetched = self.backend.read_word()?.to_le();
        let widened: BB<WR> = fetched.upcast();
        // The new word starts right after the last valid bit.
        let offset = self.bits_in_buffer;
        self.buffer |= widened << offset;
        self.bits_in_buffer = offset + WR::Word::BITS;
        Ok(())
    }
}
/// Little-endian bit reading: the unread bits occupy the least significant
/// `bits_in_buffer` positions of `buffer` and are consumed by shifting the
/// buffer to the right.
impl<WR: WordRead, RP: ReadParams> BitRead<LE> for BufBitReader<LE, WR, RP>
where
WR::Word: DoubleType + UpcastableInto<u64>,
BB<WR>: CastableInto<u64>,
{
type Error = <WR as WordRead>::Error;
type PeekWord = BB<WR>;
/// Returns the next `n_bits` bits without consuming them.
///
/// At most one word is fetched from the backend; `n_bits` must be positive
/// and must not exceed the bits available after that refill (checked in
/// debug builds only).
#[inline(always)]
fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
debug_assert!(n_bits > 0);
debug_assert!(n_bits <= Self::PeekWord::BITS);
if n_bits > self.bits_in_buffer {
self.refill()?;
}
debug_assert!(n_bits <= self.bits_in_buffer);
// Shifting left and then right by the same amount clears every bit above
// the `n_bits` lowest ones.
let shamt = BB::<WR>::BITS - n_bits;
Ok((self.buffer << shamt) >> shamt)
}
/// Consumes `n_bits` bits that are already in the buffer; no refill and no
/// bounds check is performed.
#[inline(always)]
fn skip_bits_after_table_lookup(&mut self, n_bits: usize) {
self.bits_in_buffer -= n_bits;
self.buffer >>= n_bits;
}
/// Reads `n_bits` (at most 64) bits.
#[inline]
fn read_bits(&mut self, mut n_bits: usize) -> Result<u64, Self::Error> {
debug_assert!(n_bits <= 64);
debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
// Fast path: everything requested is already buffered; mask out the
// `n_bits` lowest bits.
if n_bits <= self.bits_in_buffer {
let result: u64 = (self.buffer & ((BB::<WR>::ONE << n_bits) - BB::<WR>::ONE)).cast();
self.bits_in_buffer -= n_bits;
self.buffer >>= n_bits;
return Ok(result);
}
// Slow path: start from the whole buffer.
// NOTE(review): this assumes the bits above `bits_in_buffer` are zero - confirm.
let mut result: u64 = self.buffer.cast();
// Number of bits of `result` filled so far.
let mut bits_in_res = self.bits_in_buffer;
// Place whole words straight from the backend above the bits gathered so
// far, bypassing the buffer.
while n_bits > WR::Word::BITS + bits_in_res {
let new_word: u64 = self.backend.read_word()?.to_le().upcast();
result |= new_word << bits_in_res;
bits_in_res += WR::Word::BITS;
}
n_bits -= bits_in_res;
debug_assert!(n_bits > 0);
debug_assert!(n_bits <= WR::Word::BITS);
// Split the last word between the result and the buffer.
let new_word = self.backend.read_word()?.to_le();
self.bits_in_buffer = WR::Word::BITS - n_bits;
// Keep only the `n_bits` lowest bits of the word; `n_bits > 0` keeps the
// shift amount below 64.
let shamt = 64 - n_bits;
let upcasted: u64 = new_word.upcast();
let final_bits: u64 = ((upcasted << shamt) >> shamt).downcast();
result |= final_bits << bits_in_res;
// The unread part of the word becomes the new buffer content.
self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> n_bits;
Ok(result)
}
/// Reads a unary code: counts the zeros preceding the next one bit and
/// consumes the zeros together with the terminating one.
#[inline]
fn read_unary(&mut self) -> Result<u64, Self::Error> {
debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
let zeros: usize = self.buffer.trailing_zeros() as usize;
// The terminating one is among the buffered bits.
if zeros < self.bits_in_buffer {
self.buffer = self.buffer >> zeros >> 1;
self.bits_in_buffer -= zeros + 1;
return Ok(zeros as u64);
}
// All buffered bits are zeros: keep fetching words until a non-zero one
// shows up.
let mut result: u64 = self.bits_in_buffer as _;
loop {
let new_word = self.backend.read_word()?.to_le();
if new_word != WR::Word::ZERO {
let zeros: usize = new_word.trailing_zeros() as _;
// Drop the zeros and the one; what follows stays in the buffer.
self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> zeros >> 1;
self.bits_in_buffer = WR::Word::BITS - zeros - 1;
return Ok(result + zeros as u64);
}
result += WR::Word::BITS as u64;
}
}
/// Skips `n_bits` bits.
#[inline]
fn skip_bits(&mut self, mut n_bits: usize) -> Result<(), Self::Error> {
debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
// Fast path: just shift the buffer.
if n_bits <= self.bits_in_buffer {
self.bits_in_buffer -= n_bits;
self.buffer >>= n_bits;
return Ok(());
}
n_bits -= self.bits_in_buffer;
// Discard whole words.
while n_bits > WR::Word::BITS {
let _ = self.backend.read_word()?;
n_bits -= WR::Word::BITS;
}
// Keep the unread part of the last word.
let new_word = self.backend.read_word()?.to_le();
self.bits_in_buffer = WR::Word::BITS - n_bits;
self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> n_bits;
Ok(())
}
/// Copies `n` bits from this reader to `bit_write`.
///
/// The buffered bits are written first, then whole words read from the
/// backend, and finally the lowest part of one last word whose remainder is
/// kept in the buffer.
#[cfg(not(feature = "no_copy_impls"))]
fn copy_to<F: Endianness>(
&mut self,
bit_write: &mut impl BitWrite<F>,
mut n: u64,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
let from_buffer = Ord::min(n, self.bits_in_buffer as _);
// NOTE(review): the value passed here may have bits set above
// `from_buffer`; this presumably relies on `write_bits` ignoring them - confirm.
bit_write.write_bits(self.buffer.cast(), from_buffer as usize)?;
self.buffer >>= from_buffer;
n -= from_buffer;
if n == 0 {
self.bits_in_buffer -= from_buffer as usize;
return Ok(());
}
// The buffered bits are exhausted: forward whole words directly.
while n > WR::Word::BITS as u64 {
bit_write.write_bits(self.backend.read_word()?.to_le().upcast(), WR::Word::BITS)?;
n -= WR::Word::BITS as u64;
}
assert!(n > 0);
// Split the last word between the writer and the buffer.
let new_word = self.backend.read_word()?.to_le();
self.bits_in_buffer = WR::Word::BITS - n as usize;
// NOTE(review): the whole word is passed although only `n` bits are
// requested; same assumption on `write_bits` as above - confirm.
bit_write.write_bits(new_word.upcast(), n as usize)?;
self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> n;
Ok(())
}
}
/// Bit-level positioning for the little-endian reader, built on the
/// word-level positioning of the backend.
impl<
        E: Error + Send + Sync + 'static,
        WR: WordRead<Error = E> + WordSeek<Error = E>,
        RP: ReadParams,
    > BitSeek for BufBitReader<LE, WR, RP>
where
    WR::Word: DoubleType,
{
    type Error = <WR as WordSeek>::Error;

    /// Returns the position of the next unread bit: the backend position in
    /// bits, minus the bits that were fetched but are still in the buffer.
    #[inline]
    fn get_bit_pos(&mut self) -> Result<u64, Self::Error> {
        let words_fetched = self.backend.get_word_pos()?;
        Ok(words_fetched * WR::Word::BITS as u64 - self.bits_in_buffer as u64)
    }

    /// Moves the reader to the bit at `bit_index`.
    ///
    /// The backend is moved to the word containing the bit; when the target
    /// is not word-aligned that word is fetched and its unread part is kept
    /// in the buffer.
    #[inline]
    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
        let word_bits = WR::Word::BITS as u64;
        let word_index = bit_index / word_bits;
        let bit_offset = (bit_index % word_bits) as usize;
        self.backend.set_word_pos(word_index)?;
        // Whatever was buffered belongs to the old position.
        self.buffer = BB::<WR>::ZERO;
        self.bits_in_buffer = 0;
        if bit_offset == 0 {
            return Ok(());
        }
        let word: BB<WR> = self.backend.read_word()?.to_le().upcast();
        // Shifting right drops the `bit_offset` lowest bits of the word.
        self.buffer = word >> bit_offset;
        self.bits_in_buffer = WR::Word::BITS - bit_offset;
        Ok(())
    }
}
// Generates a round-trip test named `$f` for readers and writers whose
// backend word type is `$word`: pseudo-random gamma codes, delta codes,
// fixed-width values and unary codes are written to a big-endian and a
// little-endian stream, then read back and compared with the values produced
// by an identically seeded generator.
macro_rules! test_buf_bit_reader {
($f: ident, $word:ty) => {
#[test]
fn $f() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
use super::MemWordWriterVec;
use crate::{
codes::{GammaRead, GammaWrite},
prelude::{
len_delta, len_gamma, BufBitWriter, DeltaRead, DeltaWrite, MemWordReader,
},
};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
let mut buffer_be: Vec<$word> = vec![];
let mut buffer_le: Vec<$word> = vec![];
let mut big = super::BufBitWriter::<BE, _>::new(MemWordWriterVec::new(&mut buffer_be));
let mut little = BufBitWriter::<LE, _>::new(MemWordWriterVec::new(&mut buffer_le));
// Fixed seed: the read phase below replays exactly the same sequence.
let mut r = SmallRng::seed_from_u64(0);
const ITER: usize = 1_000_000;
// Write phase. The order of the generator draws must match the read
// phase one to one.
for _ in 0..ITER {
let value = r.gen_range(0..128);
assert_eq!(big.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(big.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(big.write_delta(value)?, len_delta(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_delta(value)?, len_delta(value));
let value = r.gen_range(0..128);
assert_eq!(big.write_delta(value)?, len_delta(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_delta(value)?, len_delta(value));
// Fixed-width field of random width holding the value 1 (or an empty
// field when the width is zero).
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
big.write_bits(0, 0)?;
} else {
big.write_bits(1, n_bits)?;
}
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
little.write_bits(0, 0)?;
} else {
little.write_bits(1, n_bits)?;
}
let value = r.gen_range(0..128);
assert_eq!(big.write_unary(value)?, value as usize + 1);
let value = r.gen_range(0..128);
assert_eq!(little.write_unary(value)?, value as usize + 1);
}
// Dropping the writers ends their mutable borrows of the two vectors.
// NOTE(review): presumably this is also what flushes pending bits - confirm.
drop(big);
drop(little);
type ReadWord = $word;
// Reinterpret the written words as a slice of `ReadWord`. As `ReadWord` is
// `$word` itself, the element type and the length are unchanged.
let be_trans: &[ReadWord] = unsafe {
core::slice::from_raw_parts(
buffer_be.as_ptr() as *const ReadWord,
buffer_be.len()
* (core::mem::size_of::<$word>() / core::mem::size_of::<ReadWord>()),
)
};
let le_trans: &[ReadWord] = unsafe {
core::slice::from_raw_parts(
buffer_le.as_ptr() as *const ReadWord,
buffer_le.len()
* (core::mem::size_of::<$word>() / core::mem::size_of::<ReadWord>()),
)
};
let mut big_buff = BufBitReader::<BE, _>::new(MemWordReader::new(be_trans));
let mut little_buff = BufBitReader::<LE, _>::new(MemWordReader::new(le_trans));
// Read phase: re-seed the generator and check every value read back.
let mut r = SmallRng::seed_from_u64(0);
for _ in 0..ITER {
assert_eq!(big_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(little_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(big_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(little_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(big_buff.read_delta()?, r.gen_range(0..128));
assert_eq!(little_buff.read_delta()?, r.gen_range(0..128));
assert_eq!(big_buff.read_delta()?, r.gen_range(0..128));
assert_eq!(little_buff.read_delta()?, r.gen_range(0..128));
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
assert_eq!(big_buff.read_bits(0)?, 0);
} else {
assert_eq!(big_buff.read_bits(n_bits)?, 1);
}
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
assert_eq!(little_buff.read_bits(0)?, 0);
} else {
assert_eq!(little_buff.read_bits(n_bits)?, 1);
}
assert_eq!(big_buff.read_unary()?, r.gen_range(0..128));
assert_eq!(little_buff.read_unary()?, r.gen_range(0..128));
}
Ok(())
}
};
}
// Instantiate the round-trip test for 64-, 32- and 16-bit backend words.
test_buf_bit_reader!(test_u64, u64);
test_buf_bit_reader!(test_u32, u32);
test_buf_bit_reader!(test_u16, u16);