#![warn(missing_docs)]
use std::io;
use super::{huffman::WriteHuffmanTree, BitQueue, Endianness, Numeric, SignedNumeric};
/// Writes individual bits and multi-bit values to an underlying
/// [`io::Write`] stream, with bit ordering governed by the `E`
/// [`Endianness`] type parameter.
///
/// Bits that do not yet make up a whole byte are held in an internal
/// one-byte queue. They reach the underlying writer only once that byte is
/// complete, or when the stream is padded with `byte_align`.
pub struct BitWriter<W: io::Write, E: Endianness> {
    /// Destination for completed bytes.
    writer: W,
    /// Pending bits of the current, not-yet-complete byte.
    bitqueue: BitQueue<E, u8>,
}
impl<W: io::Write, E: Endianness> BitWriter<W, E> {
    /// Wraps `writer` in a new `BitWriter` with no pending bits.
    ///
    /// The endianness is selected through the `E` type parameter.
    pub fn new(writer: W) -> BitWriter<W, E> {
        BitWriter {
            writer,
            bitqueue: BitQueue::new(),
        }
    }

    /// Wraps `writer` in a new `BitWriter` with no pending bits, taking the
    /// endianness from a value of type `E` instead of a type annotation.
    ///
    /// The `_endian` value itself is unused; it only fixes the type `E`.
    pub fn endian(writer: W, _endian: E) -> BitWriter<W, E> {
        BitWriter {
            writer,
            bitqueue: BitQueue::new(),
        }
    }

    /// Consumes the `BitWriter` and returns the wrapped writer.
    ///
    /// Any bits still queued in an incomplete byte are dropped without being
    /// written. Call [`byte_align`](Self::byte_align) first to flush them, or
    /// use [`into_unwritten`](Self::into_unwritten) to recover them.
    #[inline]
    pub fn into_writer(self) -> W {
        self.writer
    }

    /// Writes a single bit: `1` for `true`, `0` for `false`.
    ///
    /// When the bit completes the pending byte, that byte is written to the
    /// underlying writer.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer.
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        self.bitqueue.push(1, u8::from(bit));
        if self.bitqueue.is_full() {
            write_byte(&mut self.writer, self.bitqueue.pop(8))
        } else {
            Ok(())
        }
    }

    /// Writes the unsigned `value` using exactly `bits` bits.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] in two cases:
    /// - `bits` exceeds the bit width of `U`;
    /// - `value` does not fit in `bits` bits.
    ///
    /// Otherwise propagates any error from the underlying writer.
    pub fn write<U>(&mut self, bits: u32, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        if bits > U::bits_size() {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive bits for type written",
            ))
        } else if (bits < U::bits_size()) && (value >= (U::one() << bits)) {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive value for bits written",
            ))
        } else if bits < self.bitqueue.remaining_len() {
            // Fast path: the value fits in the pending byte without
            // completing it. The queue holds at most one byte, so `bits` is
            // below 8 here. Since `value < 2^bits`, `to_u8` loses nothing.
            self.bitqueue.push(bits, value.to_u8());
            Ok(())
        } else {
            // Slow path, in three steps:
            // 1. top up and flush the pending byte;
            // 2. write any whole bytes that remain directly;
            // 3. queue the leftover bits.
            let mut acc = BitQueue::from_value(value, bits);
            write_unaligned(&mut self.writer, &mut acc, &mut self.bitqueue)?;
            write_aligned(&mut self.writer, &mut acc)?;
            self.bitqueue.push(acc.len(), acc.value().to_u8());
            Ok(())
        }
    }

    /// Writes the signed `value` using `bits` bits.
    ///
    /// The encoding is delegated to `E::write_signed`, so it depends on the
    /// endianness `E`. NOTE(review): presumably two's complement; confirm in
    /// the `Endianness` implementations.
    #[inline]
    pub fn write_signed<S>(&mut self, bits: u32, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        E::write_signed(self, bits, value)
    }

    /// Writes every byte of `buf`.
    ///
    /// If the stream is byte-aligned, `buf` goes straight to the underlying
    /// writer. Otherwise each byte is written as an 8-bit value.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer.
    pub fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
        if self.byte_aligned() {
            self.writer.write_all(buf)
        } else {
            for b in buf {
                self.write(8, *b)?;
            }
            Ok(())
        }
    }

    /// Writes the code for `symbol` from the Huffman `tree`.
    ///
    /// Each `(bits, value)` pair returned by `tree.get(&symbol)` is written
    /// in order with [`write`](Self::write). NOTE(review): the behavior for a
    /// symbol absent from the tree depends on `WriteHuffmanTree::get`;
    /// confirm whether it panics.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`write`](Self::write).
    pub fn write_huffman<T>(&mut self, tree: &WriteHuffmanTree<E, T>, symbol: T) -> io::Result<()>
    where
        T: Ord + Copy,
    {
        for &(bits, value) in tree.get(&symbol) {
            self.write(bits, value)?;
        }
        Ok(())
    }

    /// Writes `value` in unary: `value` one-bits followed by a single
    /// terminating zero bit.
    ///
    /// Runs longer than 64 bits are emitted in 64-bit chunks.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer.
    pub fn write_unary0(&mut self, value: u32) -> io::Result<()> {
        match value {
            0 => self.write_bit(false),
            bits @ 1..=31 => self
                .write(value, (1u32 << bits) - 1)
                .and_then(|()| self.write_bit(false)),
            // `1u32 << 32` would overflow, so a full-width u32 run gets its
            // own arm.
            32 => self
                .write(value, 0xFFFF_FFFFu32)
                .and_then(|()| self.write_bit(false)),
            bits @ 33..=63 => self
                .write(value, (1u64 << bits) - 1)
                .and_then(|()| self.write_bit(false)),
            // Likewise, `1u64 << 64` would overflow.
            64 => self
                .write(value, 0xFFFF_FFFF_FFFF_FFFFu64)
                .and_then(|()| self.write_bit(false)),
            mut bits => {
                while bits > 64 {
                    self.write(64, 0xFFFF_FFFF_FFFF_FFFFu64)?;
                    bits -= 64;
                }
                // The remaining 1..=64 bits plus the terminator.
                self.write_unary0(bits)
            }
        }
    }

    /// Writes `value` in unary: `value` zero-bits followed by a single
    /// terminating one bit.
    ///
    /// Runs longer than 64 bits are emitted in 64-bit chunks.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer.
    pub fn write_unary1(&mut self, value: u32) -> io::Result<()> {
        match value {
            0 => self.write_bit(true),
            1..=32 => self.write(value, 0u32).and_then(|()| self.write_bit(true)),
            33..=64 => self.write(value, 0u64).and_then(|()| self.write_bit(true)),
            mut bits => {
                while bits > 64 {
                    self.write(64, 0u64)?;
                    bits -= 64;
                }
                // The remaining 1..=64 bits plus the terminator.
                self.write_unary1(bits)
            }
        }
    }

    /// Returns `true` when no bits are pending in an incomplete byte.
    #[inline(always)]
    pub fn byte_aligned(&self) -> bool {
        self.bitqueue.is_empty()
    }

    /// Pads the stream with zero bits until it is byte-aligned.
    ///
    /// Does nothing if the stream is already aligned.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing the completed byte.
    pub fn byte_align(&mut self) -> io::Result<()> {
        while !self.byte_aligned() {
            self.write_bit(false)?;
        }
        Ok(())
    }

    /// Consumes the `BitWriter` and returns `(bit_count, value)` for the
    /// bits still pending in the incomplete byte.
    ///
    /// The underlying writer is dropped.
    #[inline(always)]
    pub fn into_unwritten(self) -> (u32, u8) {
        (self.bitqueue.len(), self.bitqueue.value())
    }
}
/// Writes the single byte `byte` to `writer`.
///
/// # Errors
///
/// Propagates any error returned by `writer.write_all`.
#[inline]
fn write_byte<W>(mut writer: W, byte: u8) -> io::Result<()>
where
    W: io::Write,
{
    writer.write_all(&[byte])
}
/// Moves bits from `acc` into the partially filled byte `rem`, then writes
/// `rem` to `writer` if that completes it.
///
/// If `rem` is empty, nothing is moved or written. Otherwise it takes as
/// many bits as `rem` has room for, capped at the number `acc` holds.
///
/// # Errors
///
/// Propagates any error from writing the completed byte.
fn write_unaligned<W, E, N>(
    writer: W,
    acc: &mut BitQueue<E, N>,
    rem: &mut BitQueue<E, u8>,
) -> io::Result<()>
where
    W: io::Write,
    E: Endianness,
    N: Numeric,
{
    if rem.is_empty() {
        return Ok(());
    }

    let free_bits = 8 - rem.len();
    let moved = free_bits.min(acc.len());
    let chunk = acc.pop(moved).to_u8();
    rem.push(moved, chunk);

    match rem.len() {
        8 => write_byte(writer, rem.pop(8)),
        _ => Ok(()),
    }
}
/// Writes every whole byte currently held in `acc` to `writer`, leaving the
/// remaining (fewer than eight) bits in `acc`.
///
/// Nothing is written when `acc` holds fewer than eight bits. Bytes are
/// staged in a fixed 16-byte stack buffer, so at most 128 bits are expected
/// per call; this is checked in debug builds.
///
/// # Errors
///
/// Propagates any error from `writer.write_all`.
fn write_aligned<W, E, N>(mut writer: W, acc: &mut BitQueue<E, N>) -> io::Result<()>
where
    W: io::Write,
    E: Endianness,
    N: Numeric,
{
    let whole_bytes = (acc.len() / 8) as usize;
    if whole_bytes == 0 {
        return Ok(());
    }
    debug_assert!(whole_bytes <= 16);

    let mut staging = [0u8; 16];
    let filled = &mut staging[..whole_bytes];
    for slot in filled.iter_mut() {
        *slot = acc.pop(8).to_u8();
    }
    writer.write_all(filled)
}