use std::cmp::min;
use crate::bit_words::BitWords;
use crate::bits;
use crate::constants::{BITS_TO_ENCODE_N_ENTRIES, BYTES_PER_WORD, WORD_SIZE};
use crate::data_types::UnsignedLike;
use crate::errors::{QCompressError, QCompressResult};
#[derive(Clone)]
pub struct BitReader<'a> {
word: usize,
words: &'a [usize],
i: usize,
j: usize,
total_bits: usize,
}
impl<'a> From<&'a BitWords> for BitReader<'a> {
fn from(bit_words: &'a BitWords) -> Self {
let word = bit_words.words.first().copied().unwrap_or_default();
BitReader {
word,
words: &bit_words.words,
i: 0,
j: 0,
total_bits: bit_words.total_bits,
}
}
}
impl<'a> BitReader<'a> {
pub fn aligned_byte_idx(&self) -> QCompressResult<usize> {
if self.j % 8 == 0 {
Ok(self.i * BYTES_PER_WORD + self.j / 8)
} else {
Err(QCompressError::invalid_argument(format!(
"cannot get aligned byte index on misaligned bit reader at word {} bit {}",
self.i,
self.j,
)))
}
}
pub fn bit_idx(&self) -> usize {
WORD_SIZE * self.i + self.j
}
pub fn bits_remaining(&self) -> usize {
self.total_bits - self.bit_idx()
}
pub fn byte_size(&self) -> usize {
bits::ceil_div(self.total_bits, 8)
}
fn increment_i(&mut self) {
self.i += 1;
self.update_unsafe_word();
}
fn update_unsafe_word(&mut self) {
self.word = self.words[self.i];
}
#[inline]
fn refresh_if_needed(&mut self) {
if self.j == WORD_SIZE {
self.increment_i();
self.j = 0;
}
}
fn insufficient_data_check(&self, name: &str, n: usize) -> QCompressResult<()> {
let bit_idx = self.bit_idx();
if bit_idx + n > self.total_bits {
Err(QCompressError::insufficient_data_recipe(
name,
n,
bit_idx,
self.total_bits,
))
} else {
Ok(())
}
}
pub fn read_aligned_bytes(&mut self, n: usize) -> QCompressResult<Vec<u8>> {
let byte_idx = self.aligned_byte_idx()?;
let new_byte_idx = byte_idx + n;
let byte_size = self.byte_size();
if new_byte_idx > byte_size {
Err(QCompressError::insufficient_data(format!(
"cannot read {} aligned bytes at byte idx {} out of {}",
n,
byte_idx,
byte_size,
)))
} else {
self.refresh_if_needed();
let end_word_idx = bits::ceil_div(new_byte_idx, BYTES_PER_WORD);
let padded_bytes = bits::words_to_bytes(&self.words[byte_idx / BYTES_PER_WORD..end_word_idx]);
self.seek(n * 8);
let padded_start_idx = byte_idx % BYTES_PER_WORD;
Ok(padded_bytes[padded_start_idx..padded_start_idx + n].to_vec())
}
}
pub fn read_one(&mut self) -> QCompressResult<bool> {
self.insufficient_data_check("read_one", 1)?;
Ok(self.unchecked_read_one())
}
pub fn read(&mut self, n: usize) -> QCompressResult<Vec<bool>> {
self.insufficient_data_check("read", n)?;
let mut res = Vec::with_capacity(n);
for _ in 0..n {
if self.j == WORD_SIZE {
self.increment_i();
self.j = 0;
}
res.push(bits::bit_from_word(self.word, self.j));
self.j += 1;
}
Ok(res)
}
pub fn read_diff<U: UnsignedLike>(&mut self, n: usize) -> QCompressResult<U> {
self.insufficient_data_check("read_diff", n)?;
Ok(self.unchecked_read_diff::<U>(n))
}
pub fn read_usize(&mut self, n: usize) -> QCompressResult<usize> {
self.insufficient_data_check("read_usize", n)?;
Ok(self.unchecked_read_usize(n))
}
pub fn read_prefix_table_idx(
&mut self,
table_size_log: usize,
) -> QCompressResult<(usize, usize)> {
let bit_idx = self.bit_idx();
if bit_idx >= self.total_bits {
return Err(QCompressError::insufficient_data_recipe(
"read_prefix_table_idx",
1,
bit_idx,
self.total_bits,
));
}
self.refresh_if_needed();
let n_plus_j = table_size_log + self.j;
if n_plus_j <= WORD_SIZE {
let rshift = WORD_SIZE - n_plus_j;
let res = (self.word & (usize::MAX >> self.j)) >> rshift;
let bits_read = min(table_size_log, self.total_bits - bit_idx);
self.j += bits_read;
Ok((bits_read, res))
} else {
let remaining = n_plus_j - WORD_SIZE;
let mut res = (self.word & (usize::MAX >> self.j)) << remaining;
if self.i + 1 < self.words.len() {
self.increment_i();
let shift = WORD_SIZE - remaining;
res |= self.word >> shift;
self.j = remaining;
Ok((table_size_log, res))
} else {
self.j = WORD_SIZE;
Ok((table_size_log - remaining, res))
}
}
}
pub fn read_varint(&mut self, jumpstart: usize) -> QCompressResult<usize> {
let mut res = self.read_usize(jumpstart)?;
for i in jumpstart..BITS_TO_ENCODE_N_ENTRIES {
if self.read_one()? {
if self.read_one()? {
res |= 1 << i
}
} else {
break;
}
}
Ok(res)
}
pub fn unchecked_read_one(&mut self) -> bool {
self.refresh_if_needed();
let res = bits::bit_from_word(self.word, self.j);
self.j += 1;
res
}
pub fn unchecked_read_diff<U: UnsignedLike>(&mut self, n: usize) -> U {
if n == 0 {
return U::ZERO;
}
self.refresh_if_needed();
let n_plus_j = n + self.j;
let first_masked_word = self.word & (usize::MAX >> self.j);
if n_plus_j <= WORD_SIZE {
let shift = WORD_SIZE - n_plus_j;
self.j = n_plus_j;
U::from_word(first_masked_word >> shift)
} else {
let mut remaining = n_plus_j - WORD_SIZE;
let mut res = U::from_word(first_masked_word << remaining);
self.increment_i();
for _ in 0..(U::BITS - 1) / WORD_SIZE {
if remaining <= WORD_SIZE {
break;
}
remaining -= WORD_SIZE;
res |= U::from_word(self.word) << remaining;
self.increment_i();
}
self.j = remaining;
let shift = WORD_SIZE - remaining;
res | U::from_word(self.word >> shift)
}
}
pub fn unchecked_read_usize(
&mut self,
n: usize,
) -> usize {
self.refresh_if_needed();
let n_plus_j = n + self.j;
let first_word = self.word & (usize::MAX >> self.j);
if n_plus_j <= WORD_SIZE {
let shift = WORD_SIZE - n_plus_j;
self.j = n_plus_j;
first_word >> shift
} else {
let remaining = n_plus_j - WORD_SIZE;
let shift = WORD_SIZE - remaining;
self.increment_i();
self.j = remaining;
(first_word << remaining) | (self.word >> shift)
}
}
pub fn unchecked_read_varint(&mut self, jumpstart: usize) -> usize {
let mut res = self.unchecked_read_usize(jumpstart);
for i in jumpstart..BITS_TO_ENCODE_N_ENTRIES {
if self.unchecked_read_one() {
if self.unchecked_read_one() {
res |= 1 << i
}
} else {
break;
}
}
res
}
pub fn drain_empty_byte(&mut self, message: &str) -> QCompressResult<()> {
if self.j % 8 != 0 {
let end_j = 8 * bits::ceil_div(self.j, 8);
if self.word & (usize::MAX >> self.j) & (usize::MAX << (WORD_SIZE - end_j)) > 0 {
return Err(QCompressError::corruption(message));
}
self.j = end_j;
}
Ok(())
}
pub fn seek_to(&mut self, bit_idx: usize) {
self.i = bit_idx.div_euclid(WORD_SIZE);
self.j = bit_idx.rem_euclid(WORD_SIZE);
self.word = self.words.get(self.i).copied().unwrap_or(0);
}
pub fn seek(&mut self, n: usize) {
self.seek_to(self.bit_idx() + n);
}
pub fn rewind_prefix_overshoot(&mut self, n: usize) {
if n <= self.j {
self.j -= n;
} else {
self.i -= 1;
self.j = self.j + WORD_SIZE - n;
self.update_unsafe_word();
}
}
}
#[cfg(test)]
mod tests {
use crate::bit_words::BitWords;
use super::BitReader;
use crate::errors::QCompressResult;
#[test]
fn test_bit_reader() -> QCompressResult<()>{
let bytes = vec![0x9a, 0x6b, 0x2d];
let words = BitWords::from(&bytes);
let mut bit_reader = BitReader::from(&words);
assert_eq!(
bit_reader.read_aligned_bytes(1)?,
vec![0x9a],
);
assert!(!bit_reader.unchecked_read_one());
assert!(bit_reader.read_one()?);
assert_eq!(
bit_reader.read(3)?,
vec![true, false, true],
);
assert_eq!(
bit_reader.unchecked_read_diff::<u64>(2),
1_u64
);
assert_eq!(
bit_reader.unchecked_read_diff::<u32>(3),
4_u32
);
assert_eq!(
bit_reader.unchecked_read_varint(2),
6
);
Ok(())
}
#[test]
fn test_seek_rewind() {
let bytes = vec![0; 6];
let words = BitWords::from(&bytes);
let mut reader = BitReader::from(&words);
reader.seek(43);
reader.rewind_prefix_overshoot(2);
assert_eq!(reader.bit_idx(), 41);
reader.rewind_prefix_overshoot(2);
assert_eq!(reader.bit_idx(), 39);
reader.rewind_prefix_overshoot(7);
assert_eq!(reader.bit_idx(), 32);
reader.rewind_prefix_overshoot(8);
assert_eq!(reader.bit_idx(), 24);
reader.rewind_prefix_overshoot(17);
assert_eq!(reader.bit_idx(), 7);
}
}