use core::convert::Infallible;
#[cfg(feature = "mem_dbg")]
use mem_dbg::{MemDbg, MemSize};
use crate::codes::params::{DefaultReadParams, ReadParams};
use crate::traits::*;
/// A seekable bit reader over a word-addressable backend.
///
/// The reader keeps only an absolute bit cursor (`bit_index`). The read
/// methods in this file seek the backend to word `bit_index / 64` before
/// fetching words. Type parameters:
/// - `E`: bit order. This file implements `BitRead` for `BE` and `LE`.
/// - `WR`: the word backend. The read impls require 64-bit words.
/// - `RP`: read parameters. No code in this file uses them beyond the type
///   marker below (NOTE(review): presumably consumed by code-reading
///   extensions elsewhere — confirm).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
pub struct BitReader<E: Endianness, WR, RP: ReadParams = DefaultReadParams> {
    // Word source; repositioned by each read before fetching words.
    backend: WR,
    // Absolute position (in bits) of the next bit to be read.
    bit_index: u64,
    // Ties `E` and `RP` to the type without storing values of them.
    _marker: core::marker::PhantomData<(E, RP)>,
}
impl<E, WR, RP> BitReader<E, WR, RP>
where
    E: Endianness,
    RP: ReadParams,
{
    /// Wraps `backend` in a reader whose cursor starts at bit 0.
    ///
    /// The backend is not touched here; it is positioned lazily by the
    /// first read.
    #[must_use]
    pub const fn new(backend: WR) -> Self {
        Self {
            _marker: core::marker::PhantomData,
            bit_index: 0,
            backend,
        }
    }
}
impl<WR: WordRead<Word = u64> + WordSeek<Error = <WR as WordRead>::Error>, RP: ReadParams>
BitRead<BE> for BitReader<BE, WR, RP>
{
type Error = <WR as WordRead>::Error;
type PeekWord = u32;
const PEEK_BITS: usize = 32;
#[inline]
fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
self.bit_index += n_bits as u64;
Ok(())
}
#[inline]
fn read_bits(&mut self, num_bits: usize) -> Result<u64, Self::Error> {
#[cfg(feature = "checks")]
assert!(num_bits <= 64);
if num_bits == 0 {
return Ok(0);
}
self.backend.set_word_pos(self.bit_index / 64)?;
let in_word_offset = (self.bit_index % 64) as usize;
let res = if (in_word_offset + num_bits) <= 64 {
let word = self.backend.read_word()?.to_be();
(word << in_word_offset) >> (64 - num_bits)
} else {
let high_word = self.backend.read_word()?.to_be();
let low_word = self.backend.read_word()?.to_be();
let shamt1 = 64 - num_bits;
let shamt2 = 128 - in_word_offset - num_bits;
((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
};
self.bit_index += num_bits as u64;
Ok(res)
}
#[inline]
fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
if n_bits == 0 {
return Ok(0);
}
#[cfg(feature = "checks")]
assert!(n_bits <= 32);
self.backend.set_word_pos(self.bit_index / 64)?;
let in_word_offset = (self.bit_index % 64) as usize;
let res = if (in_word_offset + n_bits) <= 64 {
let word = self.backend.read_word()?.to_be();
(word << in_word_offset) >> (64 - n_bits)
} else {
let high_word = self.backend.read_word()?.to_be();
let low_word = self.backend.read_word()?.to_be();
let shamt1 = 64 - n_bits;
let shamt2 = 128 - in_word_offset - n_bits;
((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
};
Ok(res as u32)
}
#[inline]
fn read_unary(&mut self) -> Result<u64, Self::Error> {
self.backend.set_word_pos(self.bit_index / 64)?;
let in_word_offset = self.bit_index % 64;
let mut bits_in_word = 64 - in_word_offset;
let mut total = 0;
let mut word = self.backend.read_word()?.to_be();
word <<= in_word_offset;
loop {
let zeros = word.leading_zeros() as u64;
if zeros < bits_in_word {
self.bit_index += total + zeros + 1;
return Ok(total + zeros);
}
total += bits_in_word;
bits_in_word = 64;
word = self.backend.read_word()?.to_be();
}
}
#[inline(always)]
fn skip_bits_after_peek(&mut self, n: usize) {
self.bit_index += n as u64;
}
}
impl<E, WR, RP> BitSeek for BitReader<E, WR, RP>
where
    E: Endianness,
    WR: WordSeek,
    RP: ReadParams,
{
    /// Seeking only updates the in-memory cursor, so it cannot fail.
    type Error = Infallible;

    /// Returns the absolute position, in bits, of the next bit to be read.
    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
        let position = self.bit_index;
        Ok(position)
    }

    /// Moves the cursor to `bit_index`.
    ///
    /// The backend is not touched here; reads reposition it themselves.
    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
        self.bit_index = bit_index;
        Ok(())
    }
}
impl<WR: WordRead<Word = u64> + WordSeek<Error = <WR as WordRead>::Error>, RP: ReadParams>
BitRead<LE> for BitReader<LE, WR, RP>
{
type Error = <WR as WordRead>::Error;
type PeekWord = u32;
const PEEK_BITS: usize = 32;
#[inline]
fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
self.bit_index += n_bits as u64;
Ok(())
}
#[inline]
fn read_bits(&mut self, num_bits: usize) -> Result<u64, Self::Error> {
#[cfg(feature = "checks")]
assert!(num_bits <= 64);
if num_bits == 0 {
return Ok(0);
}
self.backend.set_word_pos(self.bit_index / 64)?;
let in_word_offset = (self.bit_index % 64) as usize;
let res = if (in_word_offset + num_bits) <= 64 {
let word = self.backend.read_word()?.to_le();
let shamt = 64 - num_bits;
(word << (shamt - in_word_offset)) >> shamt
} else {
let low_word = self.backend.read_word()?.to_le();
let high_word = self.backend.read_word()?.to_le();
let shamt1 = 128 - in_word_offset - num_bits;
let shamt2 = 64 - num_bits;
((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
};
self.bit_index += num_bits as u64;
Ok(res)
}
#[inline]
fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
if n_bits == 0 {
return Ok(0);
}
#[cfg(feature = "checks")]
assert!(n_bits <= 32);
self.backend.set_word_pos(self.bit_index / 64)?;
let in_word_offset = (self.bit_index % 64) as usize;
let res = if (in_word_offset + n_bits) <= 64 {
let word = self.backend.read_word()?.to_le();
let shamt = 64 - n_bits;
(word << (shamt - in_word_offset)) >> shamt
} else {
let low_word = self.backend.read_word()?.to_le();
let high_word = self.backend.read_word()?.to_le();
let shamt1 = 128 - in_word_offset - n_bits;
let shamt2 = 64 - n_bits;
((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
};
Ok(res as u32)
}
#[inline]
fn read_unary(&mut self) -> Result<u64, Self::Error> {
self.backend.set_word_pos(self.bit_index / 64)?;
let in_word_offset = self.bit_index % 64;
let mut bits_in_word = 64 - in_word_offset;
let mut total = 0;
let mut word = self.backend.read_word()?.to_le();
word >>= in_word_offset;
loop {
let zeros = word.trailing_zeros() as u64;
if zeros < bits_in_word {
self.bit_index += total + zeros + 1;
return Ok(total + zeros);
}
total += bits_in_word;
bits_in_word = 64;
word = self.backend.read_word()?.to_le();
}
}
#[inline(always)]
fn skip_bits_after_peek(&mut self, n: usize) {
self.bit_index += n as u64;
}
}
#[cfg(feature = "std")]
impl<WR: WordRead<Word = u64> + WordSeek<Error = <WR as WordRead>::Error>, RP: ReadParams>
    std::io::Read for BitReader<LE, WR, RP>
{
    /// Fills `buf` with bytes taken from the bit stream, 64 bits at a time.
    ///
    /// Each 64-bit word is stored in little-endian byte order, so the first
    /// bits read end up in the first byte. A final chunk shorter than 8 bytes
    /// is read with a single `read_bits(len * 8)` call.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::UnexpectedEof` if the very first read fails. If a
    /// later read fails, the bytes already consumed are reported as a short
    /// read (`Ok(n)`) instead. `Read::read` must not return an error after
    /// consuming data, because those bytes would be lost.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // `read_bits` advances the cursor only on success, so `filled` is
        // exactly the number of bytes consumed from the stream so far.
        let mut filled = 0;
        let mut iter = buf.chunks_exact_mut(8);
        for chunk in &mut iter {
            match self.read_bits(64) {
                Ok(word) => chunk.copy_from_slice(&word.to_le_bytes()),
                Err(_) if filled > 0 => return Ok(filled),
                Err(_) => return Err(std::io::ErrorKind::UnexpectedEof.into()),
            }
            filled += 8;
        }
        let rem = iter.into_remainder();
        if !rem.is_empty() {
            match self.read_bits(rem.len() * 8) {
                // Low-order bytes hold the earliest bits in LE order.
                Ok(word) => rem.copy_from_slice(&word.to_le_bytes()[..rem.len()]),
                Err(_) if filled > 0 => return Ok(filled),
                Err(_) => return Err(std::io::ErrorKind::UnexpectedEof.into()),
            }
            filled += rem.len();
        }
        Ok(filled)
    }
}
#[cfg(feature = "std")]
impl<WR: WordRead<Word = u64> + WordSeek<Error = <WR as WordRead>::Error>, RP: ReadParams>
    std::io::Read for BitReader<BE, WR, RP>
{
    /// Fills `buf` with bytes taken from the bit stream, 64 bits at a time.
    ///
    /// Each 64-bit word is stored in big-endian byte order, so the first bits
    /// read end up in the first byte. A final chunk shorter than 8 bytes is
    /// read with a single `read_bits(len * 8)` call.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::UnexpectedEof` if the very first read fails. If a
    /// later read fails, the bytes already consumed are reported as a short
    /// read (`Ok(n)`) instead. `Read::read` must not return an error after
    /// consuming data, because those bytes would be lost.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // `read_bits` advances the cursor only on success, so `filled` is
        // exactly the number of bytes consumed from the stream so far.
        let mut filled = 0;
        let mut iter = buf.chunks_exact_mut(8);
        for chunk in &mut iter {
            match self.read_bits(64) {
                Ok(word) => chunk.copy_from_slice(&word.to_be_bytes()),
                Err(_) if filled > 0 => return Ok(filled),
                Err(_) => return Err(std::io::ErrorKind::UnexpectedEof.into()),
            }
            filled += 8;
        }
        let rem = iter.into_remainder();
        if !rem.is_empty() {
            match self.read_bits(rem.len() * 8) {
                // The value is right-aligned, so its meaningful bytes are the
                // last `rem.len()` bytes of the big-endian encoding.
                Ok(word) => rem.copy_from_slice(&word.to_be_bytes()[8 - rem.len()..]),
                Err(_) if filled > 0 => return Ok(filled),
                Err(_) => return Err(std::io::ErrorKind::UnexpectedEof.into()),
            }
            filled += rem.len();
        }
        Ok(filled)
    }
}