#[cfg(test)]
mod tests;
use std::io::Read;
use std::io::Write;
use std::mem::size_of;
use super::Result;
use super::Error::{Eof, IoError, InvalidInput};
/// Exposes a running byte total for a bit-level stream adapter.
pub trait ByteCount {
/// Returns the current byte total as a `u64`.
///
/// Both implementors in this file (`BitReader` and `BitWriter`) return the
/// counter that they increment once per byte successfully read from, or
/// written to, the wrapped stream.
fn get_count(&self) -> u64;
}
/// A source from which values can be read a chosen number of bits at a time.
pub trait BitRead {
/// Reads `bits` bits and returns them packed into a `usize`.
///
/// The `BitReader` implementation in this file consumes each input byte from
/// its most significant bit downwards and places the first bit read in the
/// highest position of the returned value. It reports `InvalidInput` when
/// `bits` exceeds the width of `usize`, `Eof` when the input has no more
/// bytes, and `IoError` when the underlying read fails.
fn read_bits(&mut self, bits: usize) -> Result<usize>;
}
/// A sink to which values can be written a chosen number of bits at a time.
pub trait BitWrite {
/// Writes the low `bits` bits of `symbol`.
///
/// The `BitWriter` implementation in this file emits the most significant of
/// those bits first. It reports `InvalidInput` when `bits` exceeds the width
/// of `usize` or when `symbol` does not fit in `bits` bits, and `IoError`
/// when writing a completed byte fails.
fn write_bits(&mut self, symbol: usize, bits: usize) -> Result<()>;
/// Forces out any partially filled byte.
///
/// The `BitWriter` implementation in this file pads the pending bits with
/// zero bits on the right to form a whole byte, writes it, and does nothing
/// when no bits are pending.
fn flush_bits(&mut self) -> Result<()>;
}
/// One-byte staging area used by both the bit reader and the bit writer.
struct BitBuffer {
    /// The single byte currently being consumed or assembled.
    bytes: [u8; 1],
    /// Number of bits of `bytes[0]` that are currently in use.
    bits: usize,
    /// Running total of whole bytes moved through the wrapped stream.
    count: u64,
}

impl BitBuffer {
    /// Creates an empty buffer: a zeroed byte, no pending bits, and a byte
    /// count of zero.
    fn new() -> Self {
        BitBuffer {
            count: 0,
            bits: 0,
            bytes: [0; 1],
        }
    }
}
/// Reads values bit by bit from a borrowed byte source, holding one byte at a
/// time in an internal buffer.
pub struct BitReader<'a> {
    /// The byte most recently pulled from `input` and how many of its bits
    /// are still unread.
    buffer: BitBuffer,
    /// Borrowed byte source, read on demand one byte at a time.
    input: &'a mut Read,
}

impl<'a> BitReader<'a> {
    /// Wraps `reader` in a `BitReader` whose bit buffer starts out empty.
    ///
    /// Nothing is read from `reader` until bits are requested.
    pub fn new(reader: &'a mut Read) -> BitReader<'a> {
        let empty = BitBuffer::new();
        BitReader {
            input: reader,
            buffer: empty,
        }
    }
}
impl<'a> ByteCount for BitReader<'a> {
    /// Returns the number of bytes pulled from the input so far.
    fn get_count(&self) -> u64 {
        let state = &self.buffer;
        state.count
    }
}
impl<'a> BitRead for BitReader<'a> {
    /// Reads `bits` bits from the input and returns them as a `usize`.
    ///
    /// Each input byte is consumed from its most significant bit downwards,
    /// and the first bit read ends up in the highest position of the result.
    /// Requesting zero bits returns `Ok(0)` without touching the input.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `bits` is larger than the width of `usize`.
    /// * `Eof` if the input runs out of bytes before `bits` bits were read.
    ///   Bits already consumed by this call are not restored.
    /// * `IoError` if the underlying read fails for any reason other than
    ///   `ErrorKind::Interrupted`; interrupted reads are retried.
    fn read_bits(&mut self, mut bits: usize) -> Result<usize> {
        if bits > size_of::<usize>() * 8 {
            return Err(InvalidInput);
        }
        let mut result = 0usize;
        while bits > 0 {
            if self.buffer.bits >= bits {
                // The buffered byte covers the rest of the request: take its
                // top `bits` bits and keep the remainder for the next call.
                result <<= bits;
                result |= self.buffer.bytes[0] as usize >> (self.buffer.bits - bits);
                self.buffer.bits -= bits;
                self.buffer.bytes[0] &= (1 << self.buffer.bits) - 1;
                bits = 0;
            } else if self.buffer.bits > 0 {
                // Not enough buffered bits: drain what is there, then loop
                // around to fetch another byte.
                result <<= self.buffer.bits;
                result |= self.buffer.bytes[0] as usize;
                bits -= self.buffer.bits;
                self.buffer.bytes[0] = 0;
                self.buffer.bits = 0;
            } else {
                match self.input.read(&mut self.buffer.bytes) {
                    Ok(0) => return Err(Eof),
                    Ok(_) => {
                        self.buffer.count += 1;
                        self.buffer.bits = 8;
                    }
                    // An interrupted read is not a failure of the stream; the
                    // buffer is still empty, so the next iteration retries.
                    Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(IoError(e)),
                }
            }
        }
        Ok(result)
    }
}
/// Packs values bit by bit into bytes and hands each completed byte to a
/// borrowed byte sink.
pub struct BitWriter<'a> {
    /// The byte under construction and how many bits it already holds.
    buffer: BitBuffer,
    /// Borrowed byte sink that receives every completed byte.
    output: &'a mut Write,
}

impl<'a> BitWriter<'a> {
    /// Wraps `writer` in a `BitWriter` whose bit buffer starts out empty.
    ///
    /// Nothing is written to `writer` until a byte is completed or flushed.
    pub fn new(writer: &'a mut Write) -> BitWriter<'a> {
        let empty = BitBuffer::new();
        BitWriter {
            output: writer,
            buffer: empty,
        }
    }
}
impl<'a> ByteCount for BitWriter<'a> {
    /// Returns the number of bytes handed to the output so far.
    fn get_count(&self) -> u64 {
        let state = &self.buffer;
        state.count
    }
}
impl<'a> BitWrite for BitWriter<'a> {
    /// Appends the low `bits` bits of `symbol` to the output, most significant
    /// bit first.
    ///
    /// Bits accumulate in a one-byte buffer; every time the buffer fills, the
    /// byte is written to the output. Writing zero bits is a no-op.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `bits` is larger than the width of `usize`, or if
    ///   `symbol` has any bit set above its low `bits` bits.
    /// * `IoError` if writing a completed byte fails.
    fn write_bits(&mut self, mut symbol: usize, mut bits: usize) -> Result<()> {
        let word_bits = size_of::<usize>() * 8;
        // Shifting by the full word width is an overflow, so only test for
        // stray high bits when `bits` is narrower than the word; a full-width
        // request accepts every `symbol` by definition.
        if bits > word_bits || (bits < word_bits && (symbol >> bits) > 0) {
            return Err(InvalidInput);
        }
        while bits > 0 {
            if self.buffer.bits + bits <= 8 {
                // Everything left fits in the current byte.
                if self.buffer.bits > 0 {
                    self.buffer.bytes[0] <<= bits;
                }
                self.buffer.bytes[0] |= symbol as u8;
                self.buffer.bits += bits;
                bits = 0;
                symbol = 0;
            } else if self.buffer.bits < 8 {
                // Top up the current byte with the highest remaining bits and
                // keep the rest of `symbol` for the next iteration.
                let num = 8 - self.buffer.bits;
                if self.buffer.bits > 0 {
                    self.buffer.bytes[0] <<= num;
                }
                self.buffer.bytes[0] |= (symbol >> (bits - num)) as u8;
                self.buffer.bits += num;
                bits -= num;
                symbol &= (1 << bits) - 1;
            }
            // Checked separately from the branches above so that a full byte
            // left behind by an earlier failed write is retried here.
            if self.buffer.bits == 8 {
                try!(self.flush_bits());
            }
        }
        Ok(())
    }

    /// Writes out any pending bits as one byte, padding with zero bits on the
    /// right, and does nothing when no bits are pending.
    ///
    /// The buffered bits are left untouched when the write fails, so the call
    /// can be retried without corrupting the pending byte.
    ///
    /// # Errors
    ///
    /// `IoError` if the underlying write fails.
    fn flush_bits(&mut self) -> Result<()> {
        if self.buffer.bits > 0 {
            // Build the left-aligned byte in a temporary; the buffer is only
            // modified once the write has succeeded.
            let padded = [self.buffer.bytes[0] << (8 - self.buffer.bits)];
            if let Err(e) = self.output.write_all(&padded) {
                return Err(IoError(e));
            }
            self.buffer.count += 1;
            self.buffer.bytes[0] = 0;
            self.buffer.bits = 0;
        }
        Ok(())
    }
}