use std::io;
use super::{Numeric, SignedNumeric, BitQueueBE, BitQueueLE, BitQueue};
use huffman::ReadHuffmanTree;
/// A source of bits and bit-packed values; reads need not be byte-aligned.
pub trait BitRead {
    /// Reads a single bit, returning `true` for a 1 bit.
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying reader.
    fn read_bit(&mut self) -> Result<bool, io::Error>;

    /// Reads the next `bits` bits into a value of type `U`.
    fn read<U>(&mut self, bits: u32) -> Result<U, io::Error>
        where U: Numeric;

    /// Reads a `bits`-wide signed value; `bits` includes the sign bit.
    fn read_signed<S>(&mut self, bits: u32) -> Result<S, io::Error>
        where S: SignedNumeric;

    /// Discards the next `bits` bits.
    fn skip(&mut self, bits: u32) -> Result<(), io::Error>;

    /// Fills `buf` with the next `buf.len()` whole bytes, aligned or not.
    fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), io::Error>;

    /// Reads a unary value: the number of 1 bits preceding a 0 bit.
    fn read_unary0(&mut self) -> Result<u32, io::Error>;

    /// Reads a unary value: the number of 0 bits preceding a 1 bit.
    fn read_unary1(&mut self) -> Result<u32, io::Error>;

    /// Returns `true` when no bits of a partially consumed byte remain.
    fn byte_aligned(&self) -> bool;

    /// Discards any remaining bits of the current partially consumed byte.
    fn byte_align(&mut self);

    /// Decodes one symbol by walking `tree` from its root.
    ///
    /// Each step reads one bit. A 0 bit follows the `zero` branch and a 1 bit
    /// follows the `one` branch, until a leaf is reached. The result is a
    /// clone of the leaf's value.
    ///
    /// # Errors
    /// Propagates any I/O error raised while reading bits.
    fn read_huffman<T>(&mut self, mut tree: &ReadHuffmanTree<T>) ->
        Result<T, io::Error> where T: Clone {
        loop {
            match *tree {
                ReadHuffmanTree::Leaf(ref value) => return Ok(value.clone()),
                ReadHuffmanTree::Tree(ref zero, ref one) => {
                    // `?` forwards read errors unchanged.
                    tree = if self.read_bit()? { one } else { zero };
                }
            }
        }
    }
}
/// Generates `read_bit` for a reader with `reader` and `bitqueue` fields.
macro_rules! define_read_bit {
    () => {
        /// Pops the next bit, refilling the queue with a fresh byte when empty.
        #[inline(always)]
        fn read_bit(&mut self) -> Result<bool, io::Error> {
            if self.bitqueue.is_empty() {
                let next = read_byte(self.reader)?;
                self.bitqueue.set(next, 8);
            }
            let bit = self.bitqueue.pop(1);
            Ok(bit == 1)
        }
    }
}
/// Generates `read` for a reader, accumulating into a `$bitqueue<U>`.
macro_rules! define_read {
    ($bitqueue:ident) => {
        /// Reads `bits` bits in three stages:
        /// 1. leftover queued bits,
        /// 2. whole bytes straight from the reader,
        /// 3. a final partial byte whose unused bits are kept in the queue.
        fn read<U>(&mut self, bits: u32) -> Result<U, io::Error>
            where U: Numeric {
            use std::cmp::min;
            let mut acc: $bitqueue<U> = $bitqueue::new();
            let mut remaining = bits;
            let queued = self.bitqueue.len();
            if queued > 0 {
                let taken = min(queued, remaining);
                let leftover = self.bitqueue.pop(taken);
                acc.push(taken, U::from_u8(leftover));
                remaining -= taken;
            }
            read_aligned(&mut self.reader, remaining / 8, &mut acc)?;
            read_unaligned(&mut self.reader,
                           remaining % 8,
                           &mut acc,
                           &mut self.bitqueue)?;
            Ok(acc.value())
        }
    }
}
/// Generates `skip` for a reader with `reader` and `bitqueue` fields.
macro_rules! define_skip {
    () => {
        /// Drops queued bits first, then whole bytes, then a partial byte.
        /// The partial byte's unused bits stay queued.
        fn skip(&mut self, bits: u32) -> Result<(), io::Error> {
            use std::cmp::min;
            let mut remaining = bits;
            let queued = self.bitqueue.len();
            if queued > 0 {
                let dropped = min(queued, remaining);
                self.bitqueue.drop(dropped);
                remaining -= dropped;
            }
            skip_aligned(&mut self.reader, remaining / 8)?;
            skip_unaligned(&mut self.reader, remaining % 8, &mut self.bitqueue)
        }
    }
}
/// Generates `read_bytes` in terms of `byte_aligned` and `read`.
macro_rules! define_read_bytes {
    () => {
        fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), io::Error> {
            if !self.byte_aligned() {
                // Mid-byte: assemble each output byte through the bit reader.
                for slot in buf.iter_mut() {
                    *slot = self.read::<u8>(8)?;
                }
                return Ok(());
            }
            // On a byte boundary, copy straight from the underlying reader.
            self.reader.read_exact(buf)
        }
    }
}
/// Generates a unary-reading method.
///
/// - `$aligned_cont_val`: the byte value that continues the run when it is
///   scanned as a whole byte.
/// - `$bitqueue_check`: tells whether every queued bit continues the run.
/// - `$bitqueue_pop`: counts and consumes the end of the run from the queue.
macro_rules! define_read_unary {
    ($method_name:ident,
     $aligned_cont_val: expr,
     $bitqueue_check: ident,
     $bitqueue_pop: ident) => {
        fn $method_name(&mut self) -> Result<u32, io::Error> {
            // Run length already contributed by the partially consumed byte.
            let carried = if self.bitqueue.is_empty() {
                0
            } else if self.bitqueue.$bitqueue_check() {
                // Every queued bit continues the run: count them, then keep
                // scanning whole bytes.
                let pending = self.bitqueue.len();
                self.bitqueue.clear();
                pending
            } else {
                // The run ends inside the queued bits.
                return Ok(self.bitqueue.$bitqueue_pop());
            };
            let scanned = read_aligned_unary(&mut self.reader,
                                             $aligned_cont_val,
                                             &mut self.bitqueue)?;
            Ok(carried + scanned + self.bitqueue.$bitqueue_pop())
        }
    }
}
/// Bit reader over a byte source using the `BitQueueBE` bit ordering.
pub struct BitReaderBE<'a> {
    /// Underlying byte source.
    reader: &'a mut io::Read,
    /// Bits of the most recently read byte that are not yet consumed.
    bitqueue: BitQueueBE<u8>
}
impl<'a> BitReaderBE<'a> {
    /// Wraps `reader`, starting with an empty (byte-aligned) bit queue.
    pub fn new(reader: &mut io::Read) -> BitReaderBE {
        let bitqueue = BitQueueBE::new();
        BitReaderBE {
            reader: reader,
            bitqueue: bitqueue,
        }
    }
}
impl<'a> BitRead for BitReaderBE<'a> {
    // Shared bit-level primitives, generated by this module's macros.
    define_read_bit!();
    define_read!(BitQueueBE);
    define_skip!();
    define_read_bytes!();
    define_read_unary!(read_unary0, 0xFF, all_1, pop_1);
    define_read_unary!(read_unary1, 0x00, all_0, pop_0);

    /// Reads a `bits`-wide signed value whose sign bit comes first, followed
    /// by `bits - 1` magnitude bits.
    ///
    /// # Errors
    /// Returns `InvalidInput` when `bits` is 0. In release builds, a
    /// `debug_assert!` alone would let `bits - 1` wrap to `u32::MAX`.
    /// Also propagates any I/O error from the underlying reader.
    fn read_signed<S>(&mut self, bits: u32) -> Result<S, io::Error>
        where S: SignedNumeric {
        if bits == 0 {
            return Err(io::Error::new(io::ErrorKind::InvalidInput,
                                      "signed values require at least 1 bit"));
        }
        let is_negative = self.read_bit()?;
        let unsigned = self.read::<S>(bits - 1)?;
        Ok(if is_negative {unsigned.as_negative(bits)} else {unsigned})
    }

    /// Byte-aligned exactly when no bits of a partial byte remain queued.
    #[inline]
    fn byte_aligned(&self) -> bool {self.bitqueue.is_empty()}

    /// Discards the queued bits of the current partial byte.
    #[inline]
    fn byte_align(&mut self) {self.bitqueue.clear()}
}
/// Bit reader over a byte source using the `BitQueueLE` bit ordering.
pub struct BitReaderLE<'a> {
    /// Underlying byte source.
    reader: &'a mut io::Read,
    /// Bits of the most recently read byte that are not yet consumed.
    bitqueue: BitQueueLE<u8>
}
impl<'a> BitReaderLE<'a> {
    /// Wraps `reader`, starting with an empty (byte-aligned) bit queue.
    pub fn new(reader: &mut io::Read) -> BitReaderLE {
        let bitqueue = BitQueueLE::new();
        BitReaderLE {
            reader: reader,
            bitqueue: bitqueue,
        }
    }
}
impl<'a> BitRead for BitReaderLE<'a> {
    // Shared bit-level primitives, generated by this module's macros.
    define_read_bit!();
    define_read!(BitQueueLE);
    define_skip!();
    define_read_bytes!();
    define_read_unary!(read_unary0, 0xFF, all_1, pop_1);
    define_read_unary!(read_unary1, 0x00, all_0, pop_0);

    /// Reads a `bits`-wide signed value: `bits - 1` magnitude bits followed
    /// by the sign bit.
    ///
    /// # Errors
    /// Returns `InvalidInput` when `bits` is 0. In release builds, a
    /// `debug_assert!` alone would let `bits - 1` wrap to `u32::MAX`.
    /// Also propagates any I/O error from the underlying reader.
    fn read_signed<S>(&mut self, bits: u32) -> Result<S, io::Error>
        where S: SignedNumeric {
        if bits == 0 {
            return Err(io::Error::new(io::ErrorKind::InvalidInput,
                                      "signed values require at least 1 bit"));
        }
        let unsigned = self.read::<S>(bits - 1)?;
        let is_negative = self.read_bit()?;
        Ok(if is_negative {unsigned.as_negative(bits)} else {unsigned})
    }

    /// Byte-aligned exactly when no bits of a partial byte remain queued.
    #[inline]
    fn byte_aligned(&self) -> bool {self.bitqueue.is_empty()}

    /// Discards the queued bits of the current partial byte.
    #[inline]
    fn byte_align(&mut self) {self.bitqueue.clear()}
}
/// Reads exactly one byte from `reader`.
///
/// # Errors
/// Propagates the `read_exact` error, e.g. `UnexpectedEof` at end of input.
#[inline]
fn read_byte(reader: &mut io::Read) -> Result<u8,io::Error> {
    let mut single = [0u8; 1];
    reader.read_exact(&mut single)?;
    Ok(single[0])
}
/// Reads `bytes` whole bytes from `reader` and pushes each one, 8 bits at a
/// time, into `acc`. Reads happen in chunks of at most 8 bytes.
///
/// # Errors
/// Propagates the first `read_exact` failure; bytes pushed before the
/// failing chunk remain in `acc`.
fn read_aligned<N>(reader: &mut io::Read,
                   mut bytes: u32,
                   acc: &mut BitQueue<N>) -> Result<(), io::Error>
    where N: Numeric {
    let mut chunk = [0u8; 8];
    while bytes > 0 {
        let len = std::cmp::min(8, bytes) as usize;
        reader.read_exact(&mut chunk[..len])?;
        for &byte in &chunk[..len] {
            acc.push(8, N::from_u8(byte));
        }
        bytes -= len as u32;
    }
    Ok(())
}
/// Consumes and discards `bytes` whole bytes from `reader`, reading through
/// an 8-byte scratch buffer.
///
/// # Errors
/// Propagates the first `read_exact` failure, e.g. `UnexpectedEof`.
fn skip_aligned(reader: &mut io::Read,
                mut bytes: u32) -> Result<(), io::Error> {
    let mut scratch = [0u8; 8];
    while bytes != 0 {
        let step = if bytes < 8 { bytes } else { 8 };
        reader.read_exact(&mut scratch[..step as usize])?;
        bytes -= step;
    }
    Ok(())
}
/// Completes a read with a final partial byte.
///
/// When `bits > 0`, one fresh byte is loaded into `rem` and `bits` of its
/// bits are moved into `acc`. The unused bits stay in `rem` for later reads.
/// This assumes `rem` is empty on entry, since `set` overwrites it.
#[inline]
fn read_unaligned<N>(reader: &mut io::Read,
                     bits: u32,
                     acc: &mut BitQueue<N>,
                     rem: &mut BitQueue<u8>) -> Result<(), io::Error>
    where N: Numeric {
    debug_assert!(bits <= 8);
    if bits == 0 {
        return Ok(());
    }
    let fresh = read_byte(reader)?;
    rem.set(fresh, 8);
    let head = rem.pop(bits);
    acc.push(bits, N::from_u8(head));
    Ok(())
}
/// Skips a final partial byte.
///
/// When `bits > 0`, one fresh byte is loaded into `rem` and its first `bits`
/// bits are discarded. The rest stay queued.
#[inline]
fn skip_unaligned(reader: &mut io::Read,
                  bits: u32,
                  rem: &mut BitQueue<u8>) -> Result<(), io::Error> {
    debug_assert!(bits <= 8);
    if bits == 0 {
        return Ok(());
    }
    let fresh = read_byte(reader)?;
    rem.set(fresh, 8);
    rem.pop(bits);
    Ok(())
}
/// Scans whole bytes while each equals `continue_val`, adding 8 per byte.
///
/// The first byte that differs is loaded into `rem` so the caller can count
/// the rest of the run from it. Returns the number of bits skipped in whole
/// bytes.
///
/// # Errors
/// Propagates read errors, e.g. `UnexpectedEof` if the run hits end of input.
#[inline]
fn read_aligned_unary(reader: &mut io::Read,
                      continue_val: u8,
                      rem: &mut BitQueue<u8>) -> Result<u32,io::Error> {
    let mut skipped_bits = 0;
    loop {
        let byte = read_byte(reader)?;
        if byte != continue_val {
            rem.set(byte, 8);
            return Ok(skipped_bits);
        }
        skipped_bits += 8;
    }
}