use std::io;
use better_io::BetterBufRead;
use crate::bits;
use crate::constants::Bitlen;
use crate::errors::{PcoError, PcoResult};
use crate::read_write_uint::ReadWriteUint;
/// Loads the 8 bytes starting at `byte_idx` and interprets them as a
/// little-endian `u64`.
///
/// # Safety
/// The caller must guarantee `byte_idx + 8 <= src.len()`; no bounds check is
/// performed.
#[inline]
pub unsafe fn u64_at(src: &[u8], byte_idx: usize) -> u64 {
    // SAFETY: the caller guarantees 8 readable bytes at `byte_idx`;
    // `read_unaligned` imposes no alignment requirement on the pointer.
    let native = unsafe { std::ptr::read_unaligned(src.as_ptr().add(byte_idx).cast::<u64>()) };
    u64::from_le(native)
}
/// Loads the 4 bytes starting at `byte_idx` and interprets them as a
/// little-endian `u32`.
///
/// # Safety
/// The caller must guarantee `byte_idx + 4 <= src.len()`; no bounds check is
/// performed.
#[inline]
pub unsafe fn u32_at(src: &[u8], byte_idx: usize) -> u32 {
    // SAFETY: the caller guarantees 4 readable bytes at `byte_idx`;
    // `read_unaligned` imposes no alignment requirement on the pointer.
    let native = unsafe { std::ptr::read_unaligned(src.as_ptr().add(byte_idx).cast::<u32>()) };
    u32::from_le(native)
}
/// Reads `n` bits starting `bits_past_byte` bits into byte `byte_idx`.
///
/// `READ_BYTES` selects the width of the underlying unaligned load: 4 (a
/// `u32`), 8 (a `u64`), or 15 (two overlapping `u64` loads). Any other value
/// panics.
///
/// # Safety
/// `src` must hold at least `READ_BYTES` readable bytes starting at
/// `byte_idx`; no bounds check is performed.
#[inline]
pub unsafe fn read_uint_at<U: ReadWriteUint, const READ_BYTES: usize>(
    src: &[u8],
    byte_idx: usize,
    bits_past_byte: Bitlen,
    n: Bitlen,
) -> U {
    if READ_BYTES == 4 {
        read_u32_at(src, byte_idx, bits_past_byte, n)
    } else if READ_BYTES == 8 {
        read_u64_at(src, byte_idx, bits_past_byte, n)
    } else if READ_BYTES == 15 {
        read_almost_u64x2_at(src, byte_idx, bits_past_byte, n)
    } else {
        unreachable!("invalid read bytes: {}", READ_BYTES)
    }
}
/// Reads `n <= 25` bits using a single 4-byte little-endian load.
///
/// # Safety
/// `src` must hold at least 4 readable bytes starting at `byte_idx`.
#[inline]
unsafe fn read_u32_at<U: ReadWriteUint>(
    src: &[u8],
    byte_idx: usize,
    bits_past_byte: Bitlen,
    n: Bitlen,
) -> U {
    // With `bits_past_byte` at most 7 (as `BitReader::read_uint` ensures via
    // `refill`), a 32-bit load still has at least 32 - 7 = 25 usable bits.
    debug_assert!(n <= 25);
    let word = u32_at(src, byte_idx);
    let aligned_word = word >> bits_past_byte;
    U::from_u32(bits::lowest_bits_fast(aligned_word, n))
}
/// Reads `n <= 57` bits using a single 8-byte little-endian load.
///
/// # Safety
/// `src` must hold at least 8 readable bytes starting at `byte_idx`.
#[inline]
unsafe fn read_u64_at<U: ReadWriteUint>(
    src: &[u8],
    byte_idx: usize,
    bits_past_byte: Bitlen,
    n: Bitlen,
) -> U {
    // With `bits_past_byte` at most 7 (as `BitReader::read_uint` ensures via
    // `refill`), a 64-bit load still has at least 64 - 7 = 57 usable bits.
    debug_assert!(n <= 57);
    let word = u64_at(src, byte_idx);
    let aligned_word = word >> bits_past_byte;
    U::from_u64(bits::lowest_bits_fast(aligned_word, n))
}
/// Reads `n <= 113` bits using two overlapping 8-byte little-endian loads
/// (bytes `byte_idx..byte_idx + 15`).
///
/// # Safety
/// `src` must hold at least 15 readable bytes starting at `byte_idx`.
#[inline]
unsafe fn read_almost_u64x2_at<U: ReadWriteUint>(
    src: &[u8],
    byte_idx: usize,
    bits_past_byte: Bitlen,
    n: Bitlen,
) -> U {
    debug_assert!(n <= 113);
    // Low word: bits [bits_past_byte, 64) of the load at `byte_idx`, moved down to bit 0.
    let low = U::from_u64(u64_at(src, byte_idx) >> bits_past_byte);
    // The high word is loaded at `byte_idx + 7`, i.e. bit offset 56, so it is
    // shifted up by 56 - bits_past_byte to line up with `low`. The overlapping
    // byte lands on the same positions in both words, so OR-ing them is exact.
    let high_shift = 56 - bits_past_byte;
    let high = U::from_u64(u64_at(src, byte_idx + 7)) << high_shift;
    bits::lowest_bits(low | high, n)
}
/// Reads little-endian, least-significant-bit-first packed values out of a
/// byte slice.
///
/// The read position is `stale_byte_idx * 8 + bits_past_byte`. Reads only
/// advance `bits_past_byte` (so it may exceed 7); `refill` folds whole bytes
/// back into `stale_byte_idx` before the next load.
pub struct BitReader<'a> {
    /// Bytes to read from. This may be longer than the unpadded data: the
    /// unchecked loads in `read_uint` touch up to 15 bytes starting at
    /// `stale_byte_idx` and rely on that slack.
    pub src: &'a [u8],
    /// Number of bits of real data at the start of `src`; positions past this
    /// are reported as out of bounds by `check_in_bounds`.
    unpadded_bit_size: usize,
    /// Byte part of the position, as of the last `refill`.
    pub stale_byte_idx: usize,
    /// Bits consumed beyond `stale_byte_idx`; may exceed 7 between refills.
    pub bits_past_byte: Bitlen,
}

impl<'a> BitReader<'a> {
    /// Creates a reader over `src`, starting `bits_past_byte` bits into byte 0.
    ///
    /// `unpadded_byte_size` is the number of bytes of real data at the start
    /// of `src`; it only affects bounds checking.
    pub fn new(src: &'a [u8], unpadded_byte_size: usize, bits_past_byte: Bitlen) -> Self {
        Self {
            src,
            unpadded_bit_size: unpadded_byte_size * 8,
            stale_byte_idx: 0,
            bits_past_byte,
        }
    }

    /// Absolute bit position, counted from the start of `src`.
    #[inline]
    pub fn bit_idx(&self) -> usize {
        self.stale_byte_idx * 8 + self.bits_past_byte as usize
    }

    /// Index of the byte that contains the current bit position.
    fn byte_idx(&self) -> usize {
        self.bit_idx() / 8
    }

    /// Current byte index, or an invalid-argument error if the position is
    /// not on a byte boundary.
    fn aligned_byte_idx(&self) -> PcoResult<usize> {
        if self.bits_past_byte.is_multiple_of(8) {
            Ok(self.byte_idx())
        } else {
            Err(PcoError::invalid_argument(format!(
                "cannot get aligned byte index on misaligned bit reader (byte {} + {} bits)",
                self.stale_byte_idx, self.bits_past_byte,
            )))
        }
    }

    /// Moves whole bytes out of `bits_past_byte` into `stale_byte_idx`,
    /// leaving `bits_past_byte` in `0..8`. The position is unchanged.
    #[inline]
    fn refill(&mut self) {
        self.stale_byte_idx += (self.bits_past_byte / 8) as usize;
        self.bits_past_byte %= 8;
    }

    /// Advances the position by `n` bits without reading anything.
    #[inline]
    fn consume(&mut self, n: Bitlen) {
        self.bits_past_byte += n;
    }

    /// Returns the next `n` whole bytes and advances past them.
    ///
    /// # Errors
    /// - Invalid argument, if the reader is not byte-aligned.
    /// - Insufficient data, if the requested range extends past the end of
    ///   `src`.
    ///
    /// On error the reader's position is left unchanged.
    pub fn read_aligned_bytes(&mut self, n: usize) -> PcoResult<&'a [u8]> {
        let byte_idx = self.aligned_byte_idx()?;
        let src: &'a [u8] = self.src;
        // Validate before mutating any state. `n` may come from untrusted
        // input, so guard against both overflow and slicing past the end of
        // `src` (which would otherwise panic).
        let Some(bytes) = byte_idx
            .checked_add(n)
            .and_then(|end| src.get(byte_idx..end))
        else {
            return Err(PcoError::insufficient_data(format!(
                "[BitReader] cannot read {} aligned bytes at byte {} / {}",
                n,
                byte_idx,
                src.len(),
            )));
        };
        self.stale_byte_idx = byte_idx + n;
        self.bits_past_byte = 0;
        Ok(bytes)
    }

    /// Reads the next `n` bits as a `U` and advances past them.
    ///
    /// # Safety
    /// No bounds checking is done. After folding whole bytes into
    /// `stale_byte_idx`, `src` must have at least 4, 8, or 15 readable bytes
    /// starting there when `U::MAX_BYTES` is in `1..=4`, `5..=8`, or `9..=15`
    /// respectively. `n` is only debug-asserted to be at most 25, 57, or 113
    /// in those cases. Reading past `unpadded_bit_size` (but within `src`)
    /// can be detected afterwards with `check_in_bounds`.
    pub unsafe fn read_uint<U: ReadWriteUint>(&mut self, n: Bitlen) -> U {
        self.refill();
        let res = match U::MAX_BYTES {
            1..=4 => read_uint_at::<U, 4>(
                self.src,
                self.stale_byte_idx,
                self.bits_past_byte,
                n,
            ),
            5..=8 => read_uint_at::<U, 8>(
                self.src,
                self.stale_byte_idx,
                self.bits_past_byte,
                n,
            ),
            9..=15 => read_uint_at::<U, 15>(
                self.src,
                self.stale_byte_idx,
                self.bits_past_byte,
                n,
            ),
            _ => unreachable!(
                "[BitReader] unsupported max bytes: {}",
                U::MAX_BYTES
            ),
        };
        self.consume(n);
        res
    }

    /// Reads the next `n` bits as a `usize`.
    ///
    /// # Safety
    /// Same contract as `read_uint`.
    pub unsafe fn read_usize(&mut self, n: Bitlen) -> usize {
        self.read_uint(n)
    }

    /// Reads the next `n` bits as a `Bitlen`.
    ///
    /// # Safety
    /// Same contract as `read_uint`.
    pub unsafe fn read_bitlen(&mut self, n: Bitlen) -> Bitlen {
        self.read_uint(n)
    }

    /// Reads a single bit; `true` if it is set.
    ///
    /// # Safety
    /// Same contract as `read_uint` (with a 4-byte load).
    pub unsafe fn read_bool(&mut self) -> bool {
        self.read_uint::<u32>(1) > 0
    }

    /// Current bit position, or an insufficient-data error if it lies past
    /// `unpadded_bit_size`. Ending exactly at the end of the data is allowed.
    #[inline]
    fn bit_idx_safe(&self) -> PcoResult<usize> {
        let bit_idx = self.bit_idx();
        if bit_idx > self.unpadded_bit_size {
            return Err(PcoError::insufficient_data(format!(
                "[BitReader] out of bounds at bit {} / {}",
                bit_idx, self.unpadded_bit_size
            )));
        }
        Ok(bit_idx)
    }

    /// Errors with insufficient data if the reader has moved past the real
    /// (unpadded) data.
    pub fn check_in_bounds(&self) -> PcoResult<()> {
        self.bit_idx_safe()?;
        Ok(())
    }

    /// Advances to the next byte boundary, requiring every skipped bit to be 0.
    ///
    /// # Errors
    /// - Insufficient data, if the reader is already out of bounds.
    /// - Corruption (with `message`), if any remaining bit of the current
    ///   byte is set.
    pub fn drain_empty_byte(&mut self, message: &str) -> PcoResult<()> {
        self.check_in_bounds()?;
        self.refill();
        if self.bits_past_byte != 0 {
            if (self.src[self.stale_byte_idx] >> self.bits_past_byte) > 0 {
                return Err(PcoError::corruption(message));
            }
            self.consume(8 - self.bits_past_byte);
        }
        Ok(())
    }
}
/// Creates `BitReader`s over a buffered reader and carries the read position
/// from one reader to the next.
///
/// Every reader's `src` holds at least `padding` bytes. If the inner buffer
/// holds at least `padding` bytes after `fill_or_eof`, it is used directly.
/// Otherwise the remaining bytes are copied into `eof_buffer` with `padding`
/// zero bytes appended, and all later readers use that copy.
pub struct BitReaderBuilder<R: BetterBufRead> {
    /// Minimum number of bytes each built reader's `src` contains.
    padding: usize,
    /// The underlying buffered reader.
    inner: R,
    /// Tail of `inner`'s data plus `padding` zeros; only used once `reached_eof`.
    eof_buffer: Vec<u8>,
    /// Set once `inner` could not supply `padding` bytes.
    reached_eof: bool,
    /// Number of bytes at the front of `eof_buffer` already consumed.
    bytes_into_eof_buffer: usize,
    /// Bit offset into the first byte for the next reader.
    bits_past_byte: Bitlen,
}

impl<R: BetterBufRead> BitReaderBuilder<R> {
    /// Wraps `inner`. The first reader starts `bits_past_byte` bits into the
    /// first buffered byte.
    pub fn new(inner: R, padding: usize, bits_past_byte: Bitlen) -> Self {
        Self {
            inner,
            padding,
            bits_past_byte,
            eof_buffer: Vec::new(),
            reached_eof: false,
            bytes_into_eof_buffer: 0,
        }
    }

    /// Builds a reader over the currently buffered bytes, switching to the
    /// zero-padded `eof_buffer` the first time `inner` comes up short.
    fn build<'a>(&'a mut self) -> io::Result<BitReader<'a>> {
        let padding = self.padding;
        if !self.reached_eof {
            self.inner.fill_or_eof(padding)?;
            let buffered = self.inner.buffer();
            if buffered.len() < padding {
                // Copy what is left and append `padding` zero bytes.
                let mut tail = buffered.to_vec();
                tail.resize(buffered.len() + padding, 0);
                self.eof_buffer = tail;
                self.reached_eof = true;
            }
        }

        // Real (unpadded) byte count: everything left in the inner buffer, or
        // the unconsumed part of `eof_buffer` excluding its zero padding.
        let (src, unpadded_bytes) = if self.reached_eof {
            let skip = self.bytes_into_eof_buffer;
            let remaining = &self.eof_buffer[skip..];
            (remaining, self.eof_buffer.len() - padding - skip)
        } else {
            let buffered = self.inner.buffer();
            (buffered, buffered.len())
        };
        Ok(BitReader::new(src, unpadded_bytes, self.bits_past_byte))
    }

    /// Returns the wrapped reader.
    pub fn into_inner(self) -> R {
        self.inner
    }

    /// Consumes the whole bytes covered by `final_bit_idx` from `inner` (and
    /// from `eof_buffer`, if active) and keeps the leftover bit offset.
    fn update(&mut self, final_bit_idx: usize) {
        let whole_bytes = final_bit_idx / 8;
        self.inner.consume(whole_bytes);
        if self.reached_eof {
            self.bytes_into_eof_buffer += whole_bytes;
        }
        self.bits_past_byte = (final_bit_idx % 8) as Bitlen;
    }

    /// Runs `f` on a freshly built reader and then commits its position.
    ///
    /// Nothing is committed if building the reader fails, if `f` returns an
    /// error, or if the reader ends up past the real data (insufficient data).
    pub fn with_reader<Y, F: FnOnce(&mut BitReader) -> PcoResult<Y>>(
        &mut self,
        f: F,
    ) -> PcoResult<Y> {
        let (output, final_bit_idx) = {
            let mut reader = self.build()?;
            let output = f(&mut reader)?;
            let final_bit_idx = reader.bit_idx_safe()?;
            (output, final_bit_idx)
        };
        self.update(final_bit_idx);
        Ok(output)
    }
}
/// Grows `src`'s buffer capacity to `required` bytes if it reports a smaller
/// capacity. Readers whose `capacity()` is `None` are left untouched.
pub fn ensure_buf_read_capacity<R: BetterBufRead>(src: &mut R, required: usize) {
    let needs_growth = matches!(src.capacity(), Some(current) if current < required);
    if needs_growth {
        src.resize_capacity(required);
    }
}
// Unit tests for `BitReader` (bit-level and aligned reads, draining) and for
// `BitReaderBuilder` (position carried across successive `with_reader` calls).
#[cfg(test)]
mod tests {
use crate::constants::OVERSHOOT_PADDING;
use crate::errors::{ErrorKind, PcoResult};
use super::*;
// Payload bytes are 137 (0b1000_1001), 38, 255, 65; `src` is zero-padded to
// 20 bytes, while only the first 5 are declared as real data.
#[test]
fn test_bit_reader() -> PcoResult<()> {
let mut src = vec![137, 38, 255, 65];
src.resize(20, 0);
let mut reader = BitReader::new(&src, 5, 0);
unsafe {
// Low nibble of 137 (0b1001).
assert_eq!(reader.read_bitlen(4), 9);
// The position is 4 bits into a byte, so aligned reads are rejected.
assert!(reader.read_aligned_bytes(1).is_err());
// High nibble of 137 (0b1000); the reader is byte-aligned again.
assert_eq!(reader.read_bitlen(4), 8);
assert_eq!(reader.read_aligned_bytes(1)?, vec![38]);
// 15 bits of the little-endian pair [255, 65]; the dropped top bit of 65 is 0.
assert_eq!(reader.read_usize(15), 255 + 65 * 256);
// The one remaining bit of byte 3 is 0, so draining succeeds and moves to byte 4.
reader.drain_empty_byte("should be empty")?;
assert_eq!(reader.aligned_byte_idx()?, 4);
}
Ok(())
}
// Source bytes are [0, 1, ..., 6]. Each reader starts where the previous one
// stopped, including a partial-byte offset carried in `bits_past_byte`.
#[test]
fn test_bit_reader_builder() -> PcoResult<()> {
let src = (0..7).collect::<Vec<_>>();
let mut reader_builder = BitReaderBuilder::new(src.as_slice(), 4 + OVERSHOOT_PADDING, 1);
// First reader starts 1 bit into byte 0. The 16 bits read come from bytes
// 0..=2: (0x020100 >> 1) & 0xFFFF == 0x0080. The reader ends at bit 17.
reader_builder.with_reader(|reader| unsafe {
assert_eq!(&reader.src[0..4], &vec![0, 1, 2, 3]);
assert_eq!(reader.bit_idx(), 1);
assert_eq!(reader.read_usize(16), 1 << 7); Ok(())
})?;
// Bit 17 = 2 whole bytes + 1 bit, so the second reader starts at byte 2, bit 1.
reader_builder.with_reader(|reader| unsafe {
assert_eq!(&reader.src[0..4], &vec![2, 3, 4, 5]);
assert_eq!(reader.bit_idx(), 1);
// Remaining 7 bits of byte 2: 0b10 >> 1 == 1.
assert_eq!(reader.read_usize(7), 1);
assert_eq!(reader.bit_idx(), 8);
assert_eq!(reader.read_aligned_bytes(3)?, &vec![3, 4, 5]);
Ok(())
})?;
// Only byte 6 (8 bits of real data) remains, so reading 9 bits ends out of
// bounds and `with_reader` reports insufficient data.
let err = reader_builder
.with_reader(|reader| unsafe {
assert!(reader.src.len() >= 4); reader.read_usize(9); Ok(())
})
.unwrap_err();
assert!(matches!(
err.kind,
ErrorKind::InsufficientData
));
Ok(())
}
}