#![warn(missing_docs)]
#[cfg(feature = "alloc")]
use alloc::vec;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
#[cfg(feature = "alloc")]
use core2::io::{self, SeekFrom};
#[cfg(not(feature = "alloc"))]
use std::io::{self, SeekFrom};
use super::{
huffman::ReadHuffmanTree, BitQueue, Endianness, Numeric, PhantomData, Primitive, SignedNumeric,
};
/// A source that can be read a few bits at a time.
///
/// Implementors supply the required bit-level methods; the byte, unary and
/// parsing helpers have default implementations built on top of them.
pub trait BitRead {
    /// Reads a single bit, returning `true` if it is 1 and `false` if it is 0.
    fn read_bit(&mut self) -> io::Result<bool>;

    /// Reads an unsigned value `bits` bits wide.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the underlying source. [`BitReader`] also
    /// returns `io::ErrorKind::InvalidInput` when `bits` exceeds the width
    /// of `U`.
    fn read<U>(&mut self, bits: u32) -> io::Result<U>
    where
        U: Numeric;

    /// Reads an unsigned value whose width `BITS` is fixed at compile time.
    ///
    /// The default forwards to [`BitRead::read`]; implementations may
    /// specialize it.
    fn read_in<const BITS: u32, U>(&mut self) -> io::Result<U>
    where
        U: Numeric,
    {
        let width = BITS;
        self.read::<U>(width)
    }

    /// Reads a signed value `bits` bits wide, sign bit included.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors. [`BitReader`] also returns
    /// `io::ErrorKind::InvalidInput` when `bits` is 0 or exceeds the width
    /// of `S`.
    fn read_signed<S>(&mut self, bits: u32) -> io::Result<S>
    where
        S: SignedNumeric;

    /// Reads a signed value whose width `BITS` is fixed at compile time.
    ///
    /// The default forwards to [`BitRead::read_signed`].
    fn read_signed_in<const BITS: u32, S>(&mut self) -> io::Result<S>
    where
        S: SignedNumeric,
    {
        let width = BITS;
        self.read_signed::<S>(width)
    }

    /// Reads a complete [`Primitive`] value in the stream's own byte order
    /// (for [`BitReader`], its `E` parameter).
    fn read_to<V>(&mut self) -> io::Result<V>
    where
        V: Primitive;

    /// Reads a complete [`Primitive`] value in the byte order given by `F`,
    /// regardless of the stream's own byte order.
    fn read_as_to<F, V>(&mut self) -> io::Result<V>
    where
        F: Endianness,
        V: Primitive;

    /// Discards the next `bits` bits.
    fn skip(&mut self, bits: u32) -> io::Result<()>;

    /// Fills every byte of `buf` from the stream.
    ///
    /// The default assembles each byte from 8 bits, so it works at any bit
    /// alignment; [`BitReader`] overrides it with a bulk read when aligned.
    fn read_bytes(&mut self, buf: &mut [u8]) -> io::Result<()> {
        buf.iter_mut()
            .try_for_each(|slot| self.read_in::<8, u8>().map(|byte| *slot = byte))
    }

    /// Reads `SIZE` bytes into a fixed-size array.
    #[inline(always)]
    #[deprecated(since = "1.8.0", note = "use read_to() method instead")]
    fn read_to_bytes<const SIZE: usize>(&mut self) -> io::Result<[u8; SIZE]> {
        self.read_to::<[u8; SIZE]>()
    }

    /// Reads exactly `bytes` bytes into a new `Vec`.
    ///
    /// Requests larger than one internal chunk are read piecewise rather
    /// than allocating the full length up front.
    fn read_to_vec(&mut self, bytes: usize) -> io::Result<Vec<u8>> {
        let fill = |buf: &mut [u8]| self.read_bytes(buf);
        read_to_vec(fill, bytes)
    }

    /// Counts 1 bits up to the next 0 bit, consuming that 0 as well.
    fn read_unary0(&mut self) -> io::Result<u32> {
        let mut ones = 0;
        loop {
            if !self.read_bit()? {
                return Ok(ones);
            }
            ones += 1;
        }
    }

    /// Counts 0 bits up to the next 1 bit, consuming that 1 as well.
    fn read_unary1(&mut self) -> io::Result<u32> {
        let mut zeros = 0;
        loop {
            if self.read_bit()? {
                return Ok(zeros);
            }
            zeros += 1;
        }
    }

    /// Parses a value implementing [`FromBitStream`] from this stream.
    fn parse<F: FromBitStream>(&mut self) -> Result<F, F::Error> {
        F::from_reader(self)
    }

    /// Parses a value implementing [`FromBitStreamWith`], handing it
    /// `context` alongside the stream.
    fn parse_with<'a, F: FromBitStreamWith<'a>>(
        &mut self,
        context: &F::Context,
    ) -> Result<F, F::Error> {
        F::from_reader(self, context)
    }

    /// Returns `true` when the stream is positioned on a byte boundary.
    fn byte_aligned(&self) -> bool;

    /// Moves to the next byte boundary, discarding any bits left in the
    /// current partial byte.
    fn byte_align(&mut self);
}
/// A stream that can decode values by walking a [`ReadHuffmanTree`] table.
pub trait HuffmanRead<E: Endianness> {
    /// Decodes one value using `tree` and returns a clone of the matched leaf.
    fn read_huffman<T: Clone>(&mut self, tree: &[ReadHuffmanTree<E, T>]) -> io::Result<T>;
}
/// Reads individual bits and multi-bit values from an underlying byte
/// reader `R`, using endianness `E`.
#[derive(Clone, Debug)]
pub struct BitReader<R: io::Read, E: Endianness> {
    // The underlying byte source.
    reader: R,
    // Unconsumed bits of the most recently fetched byte; empty exactly when
    // the reader is byte-aligned (see `byte_aligned`).
    bitqueue: BitQueue<E, u8>,
}
impl<R: io::Read, E: Endianness> BitReader<R, E> {
    /// Wraps `reader` in a new `BitReader` that starts byte-aligned.
    pub fn new(reader: R) -> BitReader<R, E> {
        let bitqueue = BitQueue::new();
        BitReader { reader, bitqueue }
    }

    /// Like [`BitReader::new`], but takes an endianness value so `E` can be
    /// inferred at the call site. The value itself is ignored.
    pub fn endian(reader: R, _endian: E) -> BitReader<R, E> {
        Self::new(reader)
    }

    /// Consumes the `BitReader` and returns the underlying reader.
    ///
    /// Any unread bits of a partially consumed byte are discarded.
    #[inline]
    pub fn into_reader(self) -> R {
        let BitReader { reader, .. } = self;
        reader
    }

    /// Borrows the underlying reader, but only while byte-aligned.
    ///
    /// Returns `None` when a partial byte is pending, because reading the
    /// underlying reader directly would bypass those queued bits.
    #[inline]
    pub fn reader(&mut self) -> Option<&mut R> {
        if !self.byte_aligned() {
            return None;
        }
        Some(&mut self.reader)
    }

    /// Consumes the `BitReader`, returning a [`ByteReader`] over the same
    /// reader with the same endianness. Pending partial-byte bits are lost.
    #[inline]
    pub fn into_bytereader(self) -> ByteReader<R, E> {
        let inner = self.into_reader();
        ByteReader::new(inner)
    }

    /// Borrows the underlying reader as a [`ByteReader`], but only while
    /// byte-aligned; returns `None` otherwise.
    #[inline]
    pub fn bytereader(&mut self) -> Option<ByteReader<&mut R, E>> {
        let inner = self.reader()?;
        Some(ByteReader::new(inner))
    }

    /// Consumes the `BitReader`, returning the not-yet-read bits of the
    /// current partial byte as `(bit_count, value)`.
    #[inline]
    pub fn into_unread(self) -> (u32, u8) {
        let queue = self.bitqueue;
        (queue.len(), queue.value())
    }
}
/// Bit-level reading for `BitReader`.
///
/// Bits of a partially consumed byte live in `bitqueue`; whole bytes are
/// fetched from the underlying reader only when the queue cannot satisfy a
/// request.
impl<R: io::Read, E: Endianness> BitRead for BitReader<R, E> {
    /// Pops one bit from the queue, first refilling the queue with a fresh
    /// byte from the underlying reader if it is empty.
    #[inline(always)]
    fn read_bit(&mut self) -> io::Result<bool> {
        if self.bitqueue.is_empty() {
            self.bitqueue.set(read_byte(&mut self.reader)?, 8);
        }
        Ok(self.bitqueue.pop(1) == 1)
    }
    /// Reads `bits` bits into a `U`.
    ///
    /// Returns `io::ErrorKind::InvalidInput` if `bits` exceeds
    /// `U::BITS_SIZE`.
    ///
    /// NOTE(review): on an I/O error in the slow path, the bits that were
    /// already queued have been drained and are not restored.
    fn read<U>(&mut self, mut bits: u32) -> io::Result<U>
    where
        U: Numeric,
    {
        if bits <= U::BITS_SIZE {
            let bitqueue_len = self.bitqueue.len();
            if bits <= bitqueue_len {
                // Fast path: the request is satisfied from the queued bits.
                Ok(U::from_u8(self.bitqueue.pop(bits)))
            } else {
                // Seed the accumulator with everything still queued, then
                // read whole bytes, then one trailing partial byte whose
                // unused bits are left in `self.bitqueue`.
                let mut acc =
                    BitQueue::from_value(U::from_u8(self.bitqueue.pop_all()), bitqueue_len);
                bits -= bitqueue_len;
                read_aligned(&mut self.reader, bits / 8, &mut acc)?;
                read_unaligned(&mut self.reader, bits % 8, &mut acc, &mut self.bitqueue)?;
                Ok(acc.value())
            }
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive bits for type read",
            ))
        }
    }
    /// Compile-time-width version of `read`.
    ///
    /// A width larger than `U::BITS_SIZE` is rejected by a `const`
    /// assertion when the method is instantiated, not at run time.
    #[inline]
    fn read_in<const BITS: u32, U>(&mut self) -> io::Result<U>
    where
        U: Numeric,
    {
        const {
            assert!(BITS <= U::BITS_SIZE, "excessive bits for type read");
        }
        let bitqueue_len = self.bitqueue.len();
        if BITS <= bitqueue_len {
            // Fast path: satisfied entirely from the queued bits.
            Ok(U::from_u8(self.bitqueue.pop_fixed::<BITS>()))
        } else {
            // Same strategy as `read`: drain the queue, then whole bytes,
            // then a trailing partial byte.
            let mut bits = BITS;
            let mut acc = BitQueue::from_value(U::from_u8(self.bitqueue.pop_all()), bitqueue_len);
            bits -= bitqueue_len;
            read_aligned(&mut self.reader, bits / 8, &mut acc)?;
            read_unaligned(&mut self.reader, bits % 8, &mut acc, &mut self.bitqueue)?;
            Ok(acc.value())
        }
    }
    /// Reads a signed value of `bits` bits, one of which is the sign bit.
    ///
    /// Returns `io::ErrorKind::InvalidInput` if `bits` is 0 or exceeds
    /// `S::BITS_SIZE`; otherwise the decoding is delegated to `E`.
    #[inline]
    fn read_signed<S>(&mut self, bits: u32) -> io::Result<S>
    where
        S: SignedNumeric,
    {
        match bits {
            0 => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "signed reads need at least 1 bit for sign",
            )),
            bits if bits <= S::BITS_SIZE => E::read_signed(self, bits),
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive bits for type read",
            )),
        }
    }
    /// Compile-time-width version of `read_signed`; both width checks are
    /// `const` assertions.
    #[inline]
    fn read_signed_in<const BITS: u32, S>(&mut self) -> io::Result<S>
    where
        S: SignedNumeric,
    {
        const {
            assert!(BITS > 0, "signed reads need at least 1 bit for sign");
            assert!(BITS <= S::BITS_SIZE, "excessive bits for type read");
        }
        E::read_signed_fixed::<_, BITS, S>(self)
    }
    /// Reads a whole primitive using this reader's endianness `E`.
    #[inline]
    fn read_to<V>(&mut self) -> io::Result<V>
    where
        V: Primitive,
    {
        E::read_primitive(self)
    }
    /// Reads a whole primitive using endianness `F` instead of `E`.
    #[inline]
    fn read_as_to<F, V>(&mut self) -> io::Result<V>
    where
        F: Endianness,
        V: Primitive,
    {
        F::read_primitive(self)
    }
    /// Discards the next `bits` bits: first from the queue, then whole bytes
    /// from the reader, then part of one more byte whose remainder is queued.
    fn skip(&mut self, mut bits: u32) -> io::Result<()> {
        use core::cmp::min;
        let to_drop = min(self.bitqueue.len(), bits);
        if to_drop != 0 {
            self.bitqueue.drop(to_drop);
            bits -= to_drop;
        }
        skip_aligned(&mut self.reader, bits / 8)?;
        skip_unaligned(&mut self.reader, bits % 8, &mut self.bitqueue)
    }
    /// Fills `buf`. When byte-aligned this is one `read_exact` on the
    /// underlying reader; otherwise each byte is assembled from 8 bits.
    fn read_bytes(&mut self, buf: &mut [u8]) -> io::Result<()> {
        if self.byte_aligned() {
            self.reader.read_exact(buf)
        } else {
            for b in buf.iter_mut() {
                *b = self.read_in::<8, _>()?;
            }
            Ok(())
        }
    }
    /// Counts 1 bits up to the next 0 bit, skipping whole `0xFF` bytes at a
    /// time where possible.
    ///
    /// `pop_1` is assumed to count and consume the queued 1 bits through the
    /// terminating 0, matching the trait's default behavior — confirm
    /// against `BitQueue`.
    fn read_unary0(&mut self) -> io::Result<u32> {
        if self.bitqueue.is_empty() {
            // Nothing queued: skip whole 0xFF bytes, then count within the
            // first byte that contains a 0 (now loaded into the queue).
            read_aligned_unary(&mut self.reader, 0b1111_1111, &mut self.bitqueue)
                .map(|u| u + self.bitqueue.pop_1())
        } else if self.bitqueue.all_1() {
            // Every queued bit is 1, so the terminator lies further on:
            // count the queued bits, discard them, and keep scanning.
            let base = self.bitqueue.len();
            self.bitqueue.clear();
            read_aligned_unary(&mut self.reader, 0b1111_1111, &mut self.bitqueue)
                .map(|u| base + u + self.bitqueue.pop_1())
        } else {
            // The terminating 0 is among the queued bits.
            Ok(self.bitqueue.pop_1())
        }
    }
    /// Counts 0 bits up to the next 1 bit, skipping whole `0x00` bytes at a
    /// time where possible. Mirror image of `read_unary0`.
    fn read_unary1(&mut self) -> io::Result<u32> {
        if self.bitqueue.is_empty() {
            // Nothing queued: skip whole 0x00 bytes, then count within the
            // first byte that contains a 1.
            read_aligned_unary(&mut self.reader, 0b0000_0000, &mut self.bitqueue)
                .map(|u| u + self.bitqueue.pop_0())
        } else if self.bitqueue.all_0() {
            // Every queued bit is 0: count them, discard, keep scanning.
            let base = self.bitqueue.len();
            self.bitqueue.clear();
            read_aligned_unary(&mut self.reader, 0b0000_0000, &mut self.bitqueue)
                .map(|u| base + u + self.bitqueue.pop_0())
        } else {
            // The terminating 1 is among the queued bits.
            Ok(self.bitqueue.pop_0())
        }
    }
    /// Byte-aligned exactly when no bits of a partial byte are queued.
    #[inline]
    fn byte_aligned(&self) -> bool {
        self.bitqueue.is_empty()
    }
    /// Discards any queued bits of the current partial byte.
    #[inline]
    fn byte_align(&mut self) {
        self.bitqueue.clear()
    }
}
impl<R, E> BitReader<R, E>
where
    E: Endianness,
    R: io::Read + io::Seek,
{
    /// Seeks to a position measured in bits and returns the new absolute bit
    /// position from the start of the stream.
    ///
    /// * `SeekFrom::Start(n)` moves to bit `n`.
    /// * `SeekFrom::End(n)` moves to `n` bits *before* the end of the stream.
    ///   This is the opposite sign convention from `std::io::Seek`, where the
    ///   offset is added to the end; it is kept for compatibility.
    /// * `SeekFrom::Current(n)` moves `n` bits relative to the current
    ///   position (negative values move backwards).
    ///
    /// Any queued bits of a partial byte are discarded before seeking.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::InvalidInput` if the target position would be
    /// negative or not representable as a `u64`. This replaces the previous
    /// panic and mirrors how `std::io::Seek` rejects such seeks. Also
    /// propagates errors from the underlying `seek`, or from reading the byte
    /// that contains a non-byte-aligned target. If an error is returned, the
    /// reader's position is unspecified.
    pub fn seek_bits(&mut self, from: io::SeekFrom) -> io::Result<u64> {
        match from {
            io::SeekFrom::Start(from_start_pos) => {
                let (bytes, bits) = (from_start_pos / 8, (from_start_pos % 8) as u32);
                self.byte_align();
                self.reader.seek(io::SeekFrom::Start(bytes))?;
                // Land mid-byte by loading the containing byte and dropping
                // its leading `bits` bits.
                self.skip(bits)?;
                Ok(from_start_pos)
            }
            io::SeekFrom::End(from_end_pos) => {
                let reader_end = self.reader.seek(io::SeekFrom::End(0))?;
                // i128 holds any u64 byte length times 8 combined with any
                // i64 offset without overflowing.
                let new_pos = i128::from(reader_end) * 8 - i128::from(from_end_pos);
                self.seek_bits(io::SeekFrom::Start(Self::absolute_bit_position(new_pos)?))
            }
            io::SeekFrom::Current(offset) => {
                let new_pos = i128::from(self.position_in_bits()?) + i128::from(offset);
                self.seek_bits(io::SeekFrom::Start(Self::absolute_bit_position(new_pos)?))
            }
        }
    }

    /// Converts a signed bit position into an absolute `u64` position.
    ///
    /// Rejects positions before the start of the stream, or too large for a
    /// `u64`, with `io::ErrorKind::InvalidInput`.
    fn absolute_bit_position(pos: i128) -> io::Result<u64> {
        if (0..=i128::from(u64::MAX)).contains(&pos) {
            // In range, so the cast is lossless.
            Ok(pos as u64)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing bit position",
            ))
        }
    }

    /// Returns the current position in bits from the start of the stream.
    ///
    /// Takes `&mut self` because it queries the underlying reader with a
    /// zero-offset relative `seek`.
    #[inline]
    pub fn position_in_bits(&mut self) -> io::Result<u64> {
        let bytes = self.reader.seek(SeekFrom::Current(0))?;
        // The underlying reader is already past the partially consumed byte,
        // so back off by the bits still waiting in the queue.
        Ok(bytes * 8 - (self.bitqueue.len() as u64))
    }
}
impl<R: io::Read, E: Endianness> HuffmanRead<E> for BitReader<R, E> {
    /// Decodes one value by walking `tree`.
    ///
    /// The walk starts at the entry selected by the current queue state.
    /// Each `Continue` node consumes one whole byte from the underlying
    /// reader to pick the next entry. A `Done` node supplies the decoded
    /// value plus the bits to leave queued afterwards.
    ///
    /// # Panics
    ///
    /// Panics with "invalid state" if the walk reaches an `InvalidState`
    /// entry.
    fn read_huffman<T>(&mut self, tree: &[ReadHuffmanTree<E, T>]) -> io::Result<T>
    where
        T: Clone,
    {
        let mut node = &tree[self.bitqueue.to_state()];
        loop {
            node = match node {
                ReadHuffmanTree::Done(value, queue_val, queue_bits, _) => {
                    self.bitqueue.set(*queue_val, *queue_bits);
                    return Ok(value.clone());
                }
                ReadHuffmanTree::Continue(next) => {
                    let byte = read_byte(&mut self.reader)?;
                    &next[byte as usize]
                }
                ReadHuffmanTree::InvalidState => panic!("invalid state"),
            };
        }
    }
}
/// Reads exactly one byte from `reader`.
///
/// # Errors
///
/// Propagates whatever `read_exact` returns. For std readers, an exhausted
/// source yields `io::ErrorKind::UnexpectedEof`.
#[inline]
fn read_byte<R>(mut reader: R) -> io::Result<u8>
where
    R: io::Read,
{
    let mut single = [0u8; 1];
    reader.read_exact(&mut single)?;
    Ok(single[0])
}
/// Reads `bytes` whole bytes from `reader` and pushes each onto `acc` as 8
/// bits, in the order they were read.
///
/// Does nothing when `bytes` is 0. `bytes` must fit in the scratch buffer
/// returned by `N::buffer()`, or the slice indexing panics. Presumably that
/// buffer is `N`'s byte width; confirm against `Numeric`.
fn read_aligned<R, E, N>(mut reader: R, bytes: u32, acc: &mut BitQueue<E, N>) -> io::Result<()>
where
    R: io::Read,
    E: Endianness,
    N: Numeric,
{
    if bytes == 0 {
        return Ok(());
    }
    let count = bytes as usize;
    let mut scratch = N::buffer();
    reader.read_exact(&mut scratch.as_mut()[..count])?;
    for &byte in &scratch.as_ref()[..count] {
        acc.push_fixed::<8>(N::from_u8(byte));
    }
    Ok(())
}
/// Reads and discards exactly `bytes` bytes from `reader`.
///
/// Uses a small fixed scratch buffer, consuming 8 bytes at a time and then
/// any remainder, so no allocation is needed regardless of `bytes`.
///
/// # Errors
///
/// Propagates the first error from `read_exact` (e.g. running out of input).
fn skip_aligned<R>(mut reader: R, mut bytes: u32) -> io::Result<()>
where
    R: io::Read,
{
    let mut scratch = [0u8; 8];
    let (full_chunks, tail) = (bytes / 8, (bytes % 8) as usize);
    for _ in 0..full_chunks {
        reader.read_exact(&mut scratch)?;
        bytes -= 8;
    }
    if tail > 0 {
        reader.read_exact(&mut scratch[..tail])?;
        bytes -= tail as u32;
    }
    debug_assert_eq!(bytes, 0);
    Ok(())
}
/// Completes a read that ends partway through a byte.
///
/// When `bits` is nonzero, loads one fresh byte into `rem`, pops `bits` bits
/// from it onto `acc`, and leaves the rest of that byte queued in `rem` for
/// later reads. `bits` must be at most 8.
#[inline]
fn read_unaligned<R, E, N>(
    reader: R,
    bits: u32,
    acc: &mut BitQueue<E, N>,
    rem: &mut BitQueue<E, u8>,
) -> io::Result<()>
where
    R: io::Read,
    E: Endianness,
    N: Numeric,
{
    debug_assert!(bits <= 8);
    if bits == 0 {
        return Ok(());
    }
    let byte = read_byte(reader)?;
    rem.set(byte, 8);
    let head = rem.pop(bits);
    acc.push(bits, N::from_u8(head));
    Ok(())
}
/// Discards the first `bits` bits of the next byte.
///
/// When `bits` is nonzero, loads one fresh byte into `rem` and pops `bits`
/// bits from it, leaving the rest queued. `bits` must be at most 8.
#[inline]
fn skip_unaligned<R, E>(reader: R, bits: u32, rem: &mut BitQueue<E, u8>) -> io::Result<()>
where
    R: io::Read,
    E: Endianness,
{
    debug_assert!(bits <= 8);
    if bits == 0 {
        return Ok(());
    }
    let byte = read_byte(reader)?;
    rem.set(byte, 8);
    rem.pop(bits);
    Ok(())
}
/// Skips whole bytes equal to `continue_val` and counts 8 for each one.
///
/// Callers pass all-ones or all-zeros for unary reads. The first byte that
/// differs is loaded into `rem` in full (8 bits) so the caller can finish
/// counting inside it. The returned count covers only the skipped whole
/// bytes.
#[inline]
fn read_aligned_unary<R, E>(
    mut reader: R,
    continue_val: u8,
    rem: &mut BitQueue<E, u8>,
) -> io::Result<u32>
where
    R: io::Read,
    E: Endianness,
{
    let mut run_bits = 0;
    loop {
        let byte = read_byte(reader.by_ref())?;
        if byte != continue_val {
            rem.set(byte, 8);
            return Ok(run_bits);
        }
        run_bits += 8;
    }
}
/// A source that can be read one whole value (byte or multi-byte primitive)
/// at a time.
pub trait ByteRead {
    /// Reads a [`Primitive`] value in the stream's own byte order (for
    /// [`ByteReader`], its `E` parameter).
    fn read<V>(&mut self) -> Result<V, io::Error>
    where
        V: Primitive;

    /// Reads a [`Primitive`] value in the byte order given by `F`, regardless
    /// of the stream's own byte order.
    fn read_as<F, V>(&mut self) -> Result<V, io::Error>
    where
        F: Endianness,
        V: Primitive;

    /// Fills every byte of `buf` from the stream.
    ///
    /// The default reads one `u8` at a time; [`ByteReader`] overrides it with
    /// a single bulk read.
    fn read_bytes(&mut self, buf: &mut [u8]) -> io::Result<()> {
        buf.iter_mut()
            .try_for_each(|slot| self.read::<u8>().map(|byte| *slot = byte))
    }

    /// Reads `SIZE` bytes into a fixed-size array.
    #[inline(always)]
    #[deprecated(since = "1.8.0", note = "use read() method instead")]
    fn read_to_bytes<const SIZE: usize>(&mut self) -> io::Result<[u8; SIZE]> {
        self.read::<[u8; SIZE]>()
    }

    /// Reads exactly `bytes` bytes into a new `Vec`.
    ///
    /// Requests larger than one internal chunk are read piecewise rather
    /// than allocating the full length up front.
    fn read_to_vec(&mut self, bytes: usize) -> io::Result<Vec<u8>> {
        let fill = |buf: &mut [u8]| self.read_bytes(buf);
        read_to_vec(fill, bytes)
    }

    /// Discards the next `bytes` bytes.
    fn skip(&mut self, bytes: u32) -> io::Result<()>;

    /// Parses a value implementing [`FromByteStream`] from this stream.
    fn parse<F: FromByteStream>(&mut self) -> Result<F, F::Error> {
        F::from_reader(self)
    }

    /// Parses a value implementing [`FromByteStreamWith`], handing it
    /// `context` alongside the stream.
    fn parse_with<'a, F: FromByteStreamWith<'a>>(
        &mut self,
        context: &F::Context,
    ) -> Result<F, F::Error> {
        F::from_reader(self, context)
    }

    /// Returns the underlying reader as an `io::Read` trait object.
    fn reader_ref(&mut self) -> &mut dyn io::Read;
}
/// Reads whole bytes and multi-byte primitive values from an underlying
/// reader `R`, using endianness `E`.
#[derive(Debug)]
pub struct ByteReader<R: io::Read, E: Endianness> {
    // Zero-sized marker that carries the endianness type `E`.
    phantom: PhantomData<E>,
    // The underlying byte source.
    reader: R,
}
impl<R: io::Read, E: Endianness> ByteReader<R, E> {
    /// Wraps `reader` in a new `ByteReader`.
    pub fn new(reader: R) -> ByteReader<R, E> {
        ByteReader {
            reader,
            phantom: PhantomData,
        }
    }

    /// Like [`ByteReader::new`], but takes an endianness value so `E` can be
    /// inferred at the call site. The value itself is ignored.
    pub fn endian(reader: R, _endian: E) -> ByteReader<R, E> {
        Self::new(reader)
    }

    /// Consumes the `ByteReader` and returns the underlying reader.
    #[inline]
    pub fn into_reader(self) -> R {
        let ByteReader { reader, .. } = self;
        reader
    }

    /// Borrows the underlying reader.
    #[inline]
    pub fn reader(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Consumes the `ByteReader`, returning a byte-aligned [`BitReader`] over
    /// the same reader with the same endianness.
    #[inline]
    pub fn into_bitreader(self) -> BitReader<R, E> {
        BitReader::new(self.reader)
    }

    /// Borrows the underlying reader as a freshly created, byte-aligned
    /// [`BitReader`].
    ///
    /// Any partial-byte bits left in that `BitReader` are lost when it is
    /// dropped.
    #[inline]
    pub fn bitreader(&mut self) -> BitReader<&mut R, E> {
        BitReader::new(&mut self.reader)
    }
}
impl<R: io::Read, E: Endianness> ByteRead for ByteReader<R, E> {
    /// Reads a `V` from the underlying reader using endianness `E`.
    #[inline]
    fn read<V>(&mut self) -> Result<V, io::Error>
    where
        V: Primitive,
    {
        E::read_numeric(self.reader())
    }

    /// Reads a `V` from the underlying reader using endianness `F`.
    #[inline]
    fn read_as<F, V>(&mut self) -> Result<V, io::Error>
    where
        F: Endianness,
        V: Primitive,
    {
        F::read_numeric(self.reader())
    }

    /// Fills `buf` with one `read_exact` on the underlying reader.
    #[inline]
    fn read_bytes(&mut self, buf: &mut [u8]) -> io::Result<()> {
        self.reader().read_exact(buf)
    }

    /// Reads and discards `bytes` bytes from the underlying reader.
    #[inline]
    fn skip(&mut self, bytes: u32) -> io::Result<()> {
        skip_aligned(self.reader(), bytes)
    }

    /// Returns the underlying reader as an `io::Read` trait object.
    #[inline]
    fn reader_ref(&mut self) -> &mut dyn io::Read {
        self.reader()
    }
}
/// A type that can construct itself by reading from a [`BitRead`] stream.
pub trait FromBitStream {
    /// The error type returned when parsing fails.
    type Error;
    /// Parses an instance of `Self` from `r`.
    fn from_reader<R: BitRead + ?Sized>(r: &mut R) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
/// Like [`FromBitStream`], but parsing also receives a borrowed context
/// value.
pub trait FromBitStreamWith<'a> {
    /// Extra data made available to `from_reader` while parsing.
    type Context: 'a;
    /// The error type returned when parsing fails.
    type Error;
    /// Parses an instance of `Self` from `r`, consulting `context`.
    fn from_reader<R: BitRead + ?Sized>(
        r: &mut R,
        context: &Self::Context,
    ) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
/// A type that can construct itself by reading from a [`ByteRead`] stream.
pub trait FromByteStream {
    /// The error type returned when parsing fails.
    type Error;
    /// Parses an instance of `Self` from `r`.
    fn from_reader<R: ByteRead + ?Sized>(r: &mut R) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
/// Like [`FromByteStream`], but parsing also receives a borrowed context
/// value.
pub trait FromByteStreamWith<'a> {
    /// Extra data made available to `from_reader` while parsing.
    type Context: 'a;
    /// The error type returned when parsing fails.
    type Error;
    /// Parses an instance of `Self` from `r`, consulting `context`.
    fn from_reader<R: ByteRead + ?Sized>(
        r: &mut R,
        context: &Self::Context,
    ) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
/// Reads exactly `bytes` bytes into a new `Vec` using the fill callback
/// `read`.
///
/// Requests of up to `MAX_CHUNK` bytes are read in a single call into a
/// buffer of exactly that size. Larger requests are read through a fixed
/// stack buffer in `MAX_CHUNK`-sized pieces and appended to a growing `Vec`.
/// This avoids allocating an untrusted length up front. A zero-length
/// request never calls `read`.
///
/// # Errors
///
/// Returns the first error produced by `read`; data read before it is
/// discarded.
fn read_to_vec(
    mut read: impl FnMut(&mut [u8]) -> io::Result<()>,
    bytes: usize,
) -> io::Result<Vec<u8>> {
    const MAX_CHUNK: usize = 4096;

    if bytes == 0 {
        return Ok(Vec::new());
    }

    if bytes <= MAX_CHUNK {
        let mut exact = vec![0; bytes];
        read(&mut exact)?;
        return Ok(exact);
    }

    let mut whole = Vec::with_capacity(MAX_CHUNK);
    let mut scratch = [0u8; MAX_CHUNK];
    let mut remaining = bytes;
    while remaining > 0 {
        let piece = &mut scratch[..remaining.min(MAX_CHUNK)];
        read(piece)?;
        whole.extend_from_slice(piece);
        remaining -= piece.len();
    }
    Ok(whole)
}