use core::convert::TryInto;
use std::io;
use super::{
skip_aligned, BitCount, BitRead, ByteReader, Endianness, PhantomData, Primitive,
SignedBitCount, SignedInteger, UnsignedInteger,
};
/// Wraps a byte source `R` so that it can be read bit-by-bit (see the
/// `BitRead` implementation); the `Endianness` parameter `E` governs how bits
/// are extracted from each byte.
#[derive(Clone, Debug)]
pub struct BitReader<R, E: Endianness> {
/// The wrapped byte source.
reader: R,
/// Partial byte already fetched from `reader` whose remaining bits have not
/// been consumed yet; it is only manipulated through `E`'s helpers.
value: u8,
/// Number of not-yet-consumed bits buffered in `value`; 0 means the reader
/// is byte-aligned.
bits: u32,
/// Zero-sized marker carrying the endianness type `E`.
phantom: PhantomData<E>,
}
impl<R, E: Endianness> BitReader<R, E> {
    /// Wraps `reader` in a new `BitReader` whose endianness is given by the
    /// type parameter `E`.
    ///
    /// The new reader starts byte-aligned with no buffered bits.
    pub fn new(reader: R) -> BitReader<R, E> {
        Self {
            reader,
            value: 0,
            bits: 0,
            phantom: PhantomData,
        }
    }

    /// Same as [`BitReader::new`], but lets the endianness be inferred from
    /// the `_endian` value instead of a type annotation.
    pub fn endian(reader: R, _endian: E) -> BitReader<R, E> {
        Self::new(reader)
    }

    /// Consumes the `BitReader` and returns the wrapped reader.
    ///
    /// Any bits already fetched from the reader but not yet consumed are lost.
    #[inline]
    pub fn into_reader(self) -> R {
        let Self { reader, .. } = self;
        reader
    }
}
impl<R: io::Read, E: Endianness> BitReader<R, E> {
    /// Borrows the underlying reader, but only while the stream is byte-aligned.
    ///
    /// Returns `None` when bits of a partially-read byte are still buffered.
    #[inline]
    pub fn reader(&mut self) -> Option<&mut R> {
        if !BitRead::byte_aligned(self) {
            return None;
        }
        Some(&mut self.reader)
    }

    /// Discards any buffered bits of a partially-read byte, then borrows the
    /// underlying reader.
    #[inline]
    pub fn aligned_reader(&mut self) -> &mut R {
        BitRead::byte_align(self);
        &mut self.reader
    }

    /// Converts this reader into a [`ByteReader`] over the same underlying
    /// reader; buffered bits of a partially-read byte are dropped.
    #[inline]
    pub fn into_bytereader(self) -> ByteReader<R, E> {
        ByteReader::new(self.reader)
    }

    /// Borrows the underlying reader as a [`ByteReader`], or returns `None`
    /// if the stream is not currently byte-aligned.
    #[inline]
    pub fn bytereader(&mut self) -> Option<ByteReader<&mut R, E>> {
        let inner = self.reader()?;
        Some(ByteReader::new(inner))
    }
}
impl<R: io::Read, E: Endianness> BitRead for BitReader<R, E> {
    /// Reads a single bit, delegating queue management to `E::pop_bit_refill`.
    #[inline(always)]
    fn read_bit(&mut self) -> io::Result<bool> {
        E::pop_bit_refill(&mut self.reader, &mut self.value, &mut self.bits)
    }

    /// Reads an unsigned value whose width is given at runtime by `bits`.
    #[inline(always)]
    fn read_unsigned_counted<const BITS: u32, U>(&mut self, bits: BitCount<BITS>) -> io::Result<U>
    where
        U: UnsignedInteger,
    {
        E::read_bits(&mut self.reader, &mut self.value, &mut self.bits, bits)
    }

    /// Reads an unsigned value whose width is the compile-time constant `BITS`.
    #[inline]
    fn read_unsigned<const BITS: u32, U>(&mut self) -> io::Result<U>
    where
        U: UnsignedInteger,
    {
        E::read_bits_fixed::<BITS, R, U>(&mut self.reader, &mut self.value, &mut self.bits)
    }

    /// Reads a signed value whose width is given at runtime.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `bits` cannot be converted into a
    /// `SignedBitCount` (a signed read needs at least one bit for the sign).
    #[inline(always)]
    fn read_signed_counted<const MAX: u32, S>(
        &mut self,
        bits: impl TryInto<SignedBitCount<MAX>>,
    ) -> io::Result<S>
    where
        S: SignedInteger,
    {
        let count: SignedBitCount<MAX> = bits.try_into().map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "signed reads need at least 1 bit for sign",
            )
        })?;
        E::read_signed_counted(self, count)
    }

    /// Reads a signed value whose width is the compile-time constant `BITS`.
    #[inline]
    fn read_signed<const BITS: u32, S>(&mut self) -> io::Result<S>
    where
        S: SignedInteger,
    {
        // Evaluated in a `const` block, so an invalid `BITS` is rejected at
        // compile time rather than at runtime.
        let count = const {
            assert!(BITS <= S::BITS_SIZE, "excessive bits for type read");
            match BitCount::<BITS>::new::<BITS>().signed_count() {
                Some(signed) => signed,
                None => panic!("signed reads need at least 1 bit for sign"),
            }
        };
        E::read_signed_counted(self, count)
    }

    /// Reads a primitive value by filling its byte buffer and decoding it
    /// with this reader's endianness `E`.
    #[inline]
    fn read_to<V>(&mut self) -> io::Result<V>
    where
        V: Primitive,
    {
        let mut raw = V::buffer();
        let Self {
            reader, value, bits, ..
        } = self;
        E::read_bytes::<8, _>(reader, value, *bits, raw.as_mut())?;
        Ok(E::bytes_to_primitive(raw))
    }

    /// Like `read_to`, but reads and decodes using the endianness `F`
    /// instead of `E`.
    #[inline]
    fn read_as_to<F, V>(&mut self) -> io::Result<V>
    where
        F: Endianness,
        V: Primitive,
    {
        let mut raw = V::buffer();
        let Self {
            reader, value, bits, ..
        } = self;
        F::read_bytes::<8, _>(reader, value, *bits, raw.as_mut())?;
        Ok(F::bytes_to_primitive(raw))
    }

    /// Skips `bits` bits of input.
    fn skip(&mut self, bits: u32) -> io::Result<()> {
        // Fast path: a whole number of bytes from a byte-aligned position is
        // skipped directly on the underlying reader.
        if bits % 8 == 0 && BitRead::byte_aligned(self) {
            return skip_aligned(self.reader.by_ref(), bits / 8);
        }

        // Slow path: consume the bits through the bit queue, 64 at a time,
        // then whatever remainder is left.
        let mut remaining = bits;
        while remaining >= 64 {
            let _ = BitRead::read::<64, u64>(self)?;
            remaining -= 64;
        }
        if remaining > 0 {
            let _: u64 = self.read_var(remaining)?;
        }
        Ok(())
    }

    /// Fills `buf` with bytes, which need not start on a byte boundary.
    #[inline]
    fn read_bytes(&mut self, buf: &mut [u8]) -> io::Result<()> {
        let Self {
            reader, value, bits, ..
        } = self;
        E::read_bytes::<1024, _>(reader, value, *bits, buf)
    }

    /// Reads a unary-coded value terminated by `STOP_BIT`, delegating to
    /// `E::pop_unary`.
    fn read_unary<const STOP_BIT: u8>(&mut self) -> io::Result<u32> {
        E::pop_unary::<STOP_BIT, R>(&mut self.reader, &mut self.value, &mut self.bits)
    }

    /// True when no bits of a partially-read byte are buffered.
    #[inline]
    fn byte_aligned(&self) -> bool {
        self.bits == 0
    }

    /// Discards any buffered bits so the next read starts on a byte boundary.
    #[inline]
    fn byte_align(&mut self) {
        self.bits = 0;
        self.value = 0;
    }
}
impl<R, E> BitReader<R, E>
where
E: Endianness,
R: io::Read + io::Seek,
{
pub fn seek_bits(&mut self, from: io::SeekFrom) -> io::Result<u64> {
match from {
io::SeekFrom::Start(from_start_pos) => {
let (bytes, bits) = (from_start_pos / 8, (from_start_pos % 8) as u32);
BitRead::byte_align(self);
self.reader.seek(io::SeekFrom::Start(bytes))?;
BitRead::skip(self, bits)?;
Ok(from_start_pos)
}
io::SeekFrom::End(from_end_pos) => {
let reader_end = self.reader.seek(io::SeekFrom::End(0))?;
let new_pos = (reader_end * 8) as i64 - from_end_pos;
assert!(new_pos >= 0, "The final position should be greater than 0");
self.seek_bits(io::SeekFrom::Start(new_pos as u64))
}
io::SeekFrom::Current(offset) => {
let new_pos = self.position_in_bits()? as i64 + offset;
assert!(new_pos >= 0, "The final position should be greater than 0");
self.seek_bits(io::SeekFrom::Start(new_pos as u64))
}
}
}
#[inline]
#[allow(clippy::seek_from_current)]
pub fn position_in_bits(&mut self) -> io::Result<u64> {
let bytes = self.reader.seek(io::SeekFrom::Current(0))?;
Ok(bytes * 8 - (self.bits as u64))
}
}