#![warn(missing_docs)]
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
#[cfg(feature = "alloc")]
use core2::io;
#[cfg(not(feature = "alloc"))]
use std::io;
use core::convert::From;
use core::ops::{AddAssign, Rem};
use super::{
huffman::WriteHuffmanTree, BitQueue, Endianness, Numeric, PhantomData, Primitive, SignedNumeric,
};
/// Writes values to an underlying byte sink at bit granularity.
///
/// Bits that do not yet form a whole byte are held in an internal queue whose
/// bit order is governed by `E`; as soon as eight bits have accumulated the
/// completed byte is handed to the sink.
pub struct BitWriter<W: io::Write, E: Endianness> {
    // Bits written so far that have not yet formed a complete byte.
    bitqueue: BitQueue<E, u8>,
    // Destination for completed bytes.
    writer: W,
}
impl<W: io::Write, E: Endianness> BitWriter<W, E> {
    /// Wraps `writer` in a new `BitWriter` with no pending bits.
    pub fn new(writer: W) -> BitWriter<W, E> {
        Self {
            bitqueue: BitQueue::new(),
            writer,
        }
    }

    /// Same as [`BitWriter::new`], but accepts an endianness value so that
    /// `E` can be inferred from it. The value itself is not used.
    pub fn endian(writer: W, _endian: E) -> BitWriter<W, E> {
        Self::new(writer)
    }

    /// Consumes this writer and returns the underlying sink.
    ///
    /// Bits still pending in a partial byte are dropped along with the
    /// writer; call [`BitWrite::byte_align`] first if they must be emitted.
    #[inline]
    pub fn into_writer(self) -> W {
        let Self { writer, .. } = self;
        writer
    }

    /// Returns the underlying sink, but only while no partial byte is
    /// pending (the stream is byte-aligned). Otherwise returns `None`.
    #[inline]
    pub fn writer(&mut self) -> Option<&mut W> {
        if !self.byte_aligned() {
            return None;
        }
        Some(&mut self.writer)
    }

    /// Consumes this writer and wraps its sink in a [`ByteWriter`] with the
    /// same endianness. Pending partial-byte bits are dropped.
    #[inline]
    pub fn into_bytewriter(self) -> ByteWriter<W, E> {
        let sink = self.into_writer();
        ByteWriter::new(sink)
    }

    /// Borrows the sink as a [`ByteWriter`] if the stream is byte-aligned,
    /// or returns `None` if a partial byte is pending.
    #[inline]
    pub fn bytewriter(&mut self) -> Option<ByteWriter<&mut W, E>> {
        let sink = self.writer()?;
        Some(ByteWriter::new(sink))
    }

    /// Consumes this writer and returns the bits not yet handed to the sink,
    /// as `(bit_count, bits)`.
    #[inline(always)]
    pub fn into_unwritten(self) -> (u32, u8) {
        let pending_len = self.bitqueue.len();
        let pending_value = self.bitqueue.value();
        (pending_len, pending_value)
    }

    /// Flushes the underlying sink.
    ///
    /// This does not emit a pending partial byte; it only flushes what has
    /// already been passed to the sink.
    #[inline(always)]
    pub fn flush(&mut self) -> io::Result<()> {
        io::Write::flush(&mut self.writer)
    }
}
/// A sink that accepts values at bit granularity.
pub trait BitWrite {
    /// Writes a single bit (`true` is 1, `false` is 0).
    ///
    /// # Errors
    /// Propagates any error from the underlying sink.
    fn write_bit(&mut self, bit: bool) -> io::Result<()>;

    /// Writes `value` as an unsigned field exactly `bits` bits wide.
    ///
    /// # Errors
    /// Implementations in this module return `InvalidInput` when `bits`
    /// exceeds the width of `U` or `value` does not fit in `bits` bits, and
    /// otherwise propagate sink errors.
    fn write<U>(&mut self, bits: u32, value: U) -> io::Result<()>
    where
        U: Numeric;

    /// Like [`BitWrite::write`], with the field width fixed at compile time.
    ///
    /// The default implementation forwards to `write(BITS, value)`.
    fn write_out<const BITS: u32, U>(&mut self, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        self.write::<U>(BITS, value)
    }

    /// Writes the signed `value` as a field `bits` bits wide, sign included.
    ///
    /// # Errors
    /// Implementations in this module reject `bits == 0` and `bits` wider
    /// than `S` with `InvalidInput`.
    fn write_signed<S>(&mut self, bits: u32, value: S) -> io::Result<()>
    where
        S: SignedNumeric;

    /// Like [`BitWrite::write_signed`], with the field width fixed at compile time.
    ///
    /// The default implementation forwards to `write_signed(BITS, value)`.
    fn write_signed_out<const BITS: u32, S>(&mut self, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        self.write_signed::<S>(BITS, value)
    }

    /// Writes a whole primitive value using the stream's own endianness.
    fn write_from<V>(&mut self, value: V) -> io::Result<()>
    where
        V: Primitive;

    /// Writes a whole primitive value using the explicitly chosen endianness `F`.
    fn write_as_from<F, V>(&mut self, value: V) -> io::Result<()>
    where
        F: Endianness,
        V: Primitive;

    /// Writes every byte of `buf` as an 8-bit field, in order.
    ///
    /// The default writes one byte at a time; implementors may override it
    /// with a faster path.
    #[inline]
    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
        for &byte in buf {
            self.write_out::<8, u8>(byte)?;
        }
        Ok(())
    }

    /// Writes `value` in unary: `value` one-bits followed by a single zero-bit.
    fn write_unary0(&mut self, value: u32) -> io::Result<()> {
        let mut remaining = value;
        // Emit full 64-bit runs of ones until the rest fits one write.
        while remaining > 64 {
            self.write(64, 0xFFFF_FFFF_FFFF_FFFFu64)?;
            remaining -= 64;
        }
        // `remaining` is now within 0..=64; pick the narrowest carrier type.
        match remaining {
            0 => {}
            1..=31 => self.write(remaining, (1u32 << remaining) - 1)?,
            32 => self.write(remaining, 0xFFFF_FFFFu32)?,
            33..=63 => self.write(remaining, (1u64 << remaining) - 1)?,
            // Only 64 can reach this arm after the loop above.
            _ => self.write(remaining, 0xFFFF_FFFF_FFFF_FFFFu64)?,
        }
        self.write_bit(false)
    }

    /// Writes `value` in unary: `value` zero-bits followed by a single one-bit.
    fn write_unary1(&mut self, value: u32) -> io::Result<()> {
        let mut remaining = value;
        // Emit full 64-bit runs of zeros until the rest fits one write.
        while remaining > 64 {
            self.write(64, 0u64)?;
            remaining -= 64;
        }
        match remaining {
            0 => {}
            1..=32 => self.write(remaining, 0u32)?,
            // 33..=64 after the loop above.
            _ => self.write(remaining, 0u64)?,
        }
        self.write_bit(true)
    }

    /// Serializes `build` into this stream via [`ToBitStream::to_writer`].
    fn build<T: ToBitStream>(&mut self, build: &T) -> Result<(), T::Error> {
        T::to_writer(build, self)
    }

    /// Serializes `build` into this stream with extra `context`, via
    /// [`ToBitStreamWith::to_writer`].
    fn build_with<'a, T: ToBitStreamWith<'a>>(
        &mut self,
        build: &T,
        context: &T::Context,
    ) -> Result<(), T::Error> {
        T::to_writer(build, self, context)
    }

    /// Returns `true` when the total written so far is a whole number of bytes.
    fn byte_aligned(&self) -> bool;

    /// Pads with zero-bits until [`BitWrite::byte_aligned`] reports `true`.
    fn byte_align(&mut self) -> io::Result<()> {
        loop {
            if self.byte_aligned() {
                return Ok(());
            }
            self.write_bit(false)?;
        }
    }
}
/// A sink that can emit symbols encoded with a prebuilt Huffman tree.
pub trait HuffmanWrite<E: Endianness> {
/// Writes the code that `tree` assigns to `symbol`.
///
/// # Errors
/// The implementations in this module propagate errors from the
/// underlying bit writes.
fn write_huffman<T>(&mut self, tree: &WriteHuffmanTree<E, T>, symbol: T) -> io::Result<()>
where
T: Ord + Copy;
}
impl<W: io::Write, E: Endianness> BitWrite for BitWriter<W, E> {
    /// Queues one bit, emitting a byte to the sink once eight are queued.
    fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        self.bitqueue.push(1, u8::from(bit));
        if !self.bitqueue.is_full() {
            return Ok(());
        }
        let completed = self.bitqueue.pop(8);
        write_byte(&mut self.writer, completed)
    }

    /// Writes `value` as a `bits`-wide unsigned field.
    ///
    /// # Errors
    /// `InvalidInput` if `bits` exceeds `U`'s width or `value` does not fit
    /// in `bits` bits; otherwise any sink error.
    fn write<U>(&mut self, bits: u32, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        if bits > U::BITS_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive bits for type written",
            ));
        }
        if bits < U::BITS_SIZE && value >= (U::ONE << bits) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive value for bits written",
            ));
        }
        if bits < self.bitqueue.remaining_len() {
            // Fits inside the current partial byte without completing it.
            self.bitqueue.push(bits, value.to_u8());
            return Ok(());
        }
        // Complete and emit the partial byte, emit any whole bytes, then
        // queue whatever bits remain.
        let mut pending = BitQueue::from_value(value, bits);
        write_unaligned(&mut self.writer, &mut pending, &mut self.bitqueue)?;
        write_aligned(&mut self.writer, &mut pending)?;
        self.bitqueue.push(pending.len(), pending.value().to_u8());
        Ok(())
    }

    /// Compile-time-width variant of [`BitWrite::write`].
    fn write_out<const BITS: u32, U>(&mut self, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        const {
            assert!(BITS <= U::BITS_SIZE, "excessive bits for type written");
        }
        if BITS < U::BITS_SIZE && value >= (U::ONE << BITS) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive value for bits written",
            ));
        }
        if BITS < self.bitqueue.remaining_len() {
            self.bitqueue.push_fixed::<BITS>(value.to_u8());
            return Ok(());
        }
        let mut pending = BitQueue::from_value(value, BITS);
        write_unaligned(&mut self.writer, &mut pending, &mut self.bitqueue)?;
        write_aligned(&mut self.writer, &mut pending)?;
        self.bitqueue.push(pending.len(), pending.value().to_u8());
        Ok(())
    }

    /// Validates `bits`, then delegates the signed encoding to `E`.
    #[inline]
    fn write_signed<S>(&mut self, bits: u32, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        if bits == 0 {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "signed writes need at least 1 bit for sign",
            ))
        } else if bits <= S::BITS_SIZE {
            E::write_signed(self, bits, value)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive bits for type written",
            ))
        }
    }

    /// Compile-time-width variant of [`BitWrite::write_signed`].
    #[inline]
    fn write_signed_out<const BITS: u32, S>(&mut self, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        const {
            assert!(BITS > 0, "signed writes need at least 1 bit for sign");
            assert!(BITS <= S::BITS_SIZE, "excessive bits for type written");
        }
        E::write_signed_fixed::<_, BITS, S>(self, value)
    }

    /// Writes a primitive in this writer's endianness `E`.
    #[inline]
    fn write_from<V>(&mut self, value: V) -> io::Result<()>
    where
        V: Primitive,
    {
        E::write_primitive(self, value)
    }

    /// Writes a primitive in the explicitly chosen endianness `F`.
    #[inline]
    fn write_as_from<F, V>(&mut self, value: V) -> io::Result<()>
    where
        F: Endianness,
        V: Primitive,
    {
        F::write_primitive(self, value)
    }

    /// Writes `buf`, bypassing the bit queue entirely when byte-aligned.
    #[inline]
    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
        if !self.byte_aligned() {
            return buf
                .iter()
                .try_for_each(|&byte| self.write_out::<8, u8>(byte));
        }
        self.writer.write_all(buf)
    }

    /// `true` when no partial byte is queued.
    #[inline(always)]
    fn byte_aligned(&self) -> bool {
        self.bitqueue.is_empty()
    }
}
impl<W: io::Write, E: Endianness> HuffmanWrite<E> for BitWriter<W, E> {
    /// Writes each `(bit_count, value)` piece that `tree` lists for `symbol`,
    /// stopping at the first error.
    #[inline]
    fn write_huffman<T>(&mut self, tree: &WriteHuffmanTree<E, T>, symbol: T) -> io::Result<()>
    where
        T: Ord + Copy,
    {
        for (bits, value) in tree.get(&symbol) {
            self.write(*bits, *value)?;
        }
        Ok(())
    }
}
/// A [`BitWrite`] sink that discards all data and only counts bits written.
///
/// `N` is the counter type; `E` is the endianness used when primitive values
/// are written through [`BitWrite::write_from`].
#[derive(Default)]
pub struct BitCounter<N, E: Endianness> {
    // Ties the counter to an endianness without storing one.
    phantom: PhantomData<E>,
    // Running total of bits written.
    bits: N,
}
impl<N: Default + Copy, E: Endianness> BitCounter<N, E> {
    /// Creates a counter whose total starts at `N::default()`.
    #[inline]
    pub fn new() -> Self {
        Self {
            phantom: PhantomData,
            bits: Default::default(),
        }
    }

    /// Returns the number of bits counted so far.
    #[inline]
    pub fn written(&self) -> N {
        let total = self.bits;
        total
    }
}
impl<N, E> BitWrite for BitCounter<N, E>
where
    E: Endianness,
    N: Copy + AddAssign + From<u32> + Rem<Output = N> + PartialEq,
{
    /// Counts one bit.
    #[inline]
    fn write_bit(&mut self, _bit: bool) -> io::Result<()> {
        self.bits += 1.into();
        Ok(())
    }

    /// Counts `bits` bits after applying the same validation as a real writer.
    ///
    /// # Errors
    /// `InvalidInput` if `bits` exceeds `U`'s width or `value` does not fit
    /// in `bits` bits; nothing is counted in that case.
    #[inline]
    fn write<U>(&mut self, bits: u32, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        if bits > U::BITS_SIZE {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive bits for type written",
            ))
        } else if (bits < U::BITS_SIZE) && (value >= (U::ONE << bits)) {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive value for bits written",
            ))
        } else {
            self.bits += bits.into();
            Ok(())
        }
    }

    /// Compile-time-width variant of [`BitWrite::write`].
    fn write_out<const BITS: u32, U>(&mut self, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        const {
            assert!(BITS <= U::BITS_SIZE, "excessive bits for type written");
        }
        if (BITS < U::BITS_SIZE) && (value >= (U::ONE << BITS)) {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive value for bits written",
            ))
        } else {
            self.bits += BITS.into();
            Ok(())
        }
    }

    /// Validates `bits`, then lets `E` drive the signed encoding, which is
    /// counted through this counter's own write methods.
    #[inline]
    fn write_signed<S>(&mut self, bits: u32, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        match bits {
            0 => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "signed writes need at least 1 bit for sign",
            )),
            bits if bits <= S::BITS_SIZE => E::write_signed(self, bits, value),
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "excessive bits for type written",
            )),
        }
    }

    /// Compile-time-width variant of [`BitWrite::write_signed`].
    #[inline]
    fn write_signed_out<const BITS: u32, S>(&mut self, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        const {
            assert!(BITS > 0, "signed writes need at least 1 bit for sign");
            assert!(BITS <= S::BITS_SIZE, "excessive bits for type written");
        }
        E::write_signed_fixed::<_, BITS, S>(self, value)
    }

    /// Counts a primitive written in endianness `E`.
    #[inline]
    fn write_from<V>(&mut self, value: V) -> io::Result<()>
    where
        V: Primitive,
    {
        E::write_primitive(self, value)
    }

    /// Counts a primitive written in endianness `F`.
    #[inline]
    fn write_as_from<F, V>(&mut self, value: V) -> io::Result<()>
    where
        F: Endianness,
        V: Primitive,
    {
        F::write_primitive(self, value)
    }

    /// Counts a unary value: `value` bits plus one stop bit.
    #[inline]
    fn write_unary1(&mut self, value: u32) -> io::Result<()> {
        // Add the stop bit separately: `value + 1` overflows `u32` when
        // `value == u32::MAX`, even if `N` itself is wide enough.
        self.bits += value.into();
        self.bits += 1.into();
        Ok(())
    }

    /// Counts a unary value: `value` bits plus one stop bit.
    #[inline]
    fn write_unary0(&mut self, value: u32) -> io::Result<()> {
        // See `write_unary1` for why the stop bit is added separately.
        self.bits += value.into();
        self.bits += 1.into();
        Ok(())
    }

    /// Counts `8 * buf.len()` bits.
    #[inline]
    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
        // The counter only accepts `u32` increments. `buf.len() as u32 * 8`
        // overflows from 512 MiB and the cast truncates from 4 GiB, so count
        // in chunks whose bit total is guaranteed to fit in a `u32`.
        const MAX_CHUNK_BYTES: usize = (u32::MAX / 8) as usize;
        for chunk in buf.chunks(MAX_CHUNK_BYTES) {
            // chunk.len() <= u32::MAX / 8, so neither the cast nor `* 8` overflows.
            self.bits += (chunk.len() as u32 * 8).into();
        }
        Ok(())
    }

    /// `true` when the bit total is a multiple of 8.
    #[inline]
    fn byte_aligned(&self) -> bool {
        self.bits % 8.into() == 0.into()
    }
}
impl<N, E> HuffmanWrite<E> for BitCounter<N, E>
where
    E: Endianness,
    N: AddAssign + From<u32>,
{
    /// Adds the total bit length of `symbol`'s code to the counter;
    /// nothing is written and this never fails.
    fn write_huffman<T>(&mut self, tree: &WriteHuffmanTree<E, T>, symbol: T) -> io::Result<()>
    where
        T: Ord + Copy,
    {
        tree.get(&symbol)
            .for_each(|&(bits, _)| self.bits += N::from(bits));
        Ok(())
    }
}
/// An integer of any supported width, captured so that an unsigned-style
/// [`BitWrite::write`] call can later be replayed with its original type.
pub struct UnsignedValue(InnerUnsignedValue);

/// One variant per integer type accepted by `UnsignedValue::from`.
enum InnerUnsignedValue {
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    U128(u128),
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    I128(i128),
}

// Generates `impl From<$t> for UnsignedValue`, tagging the value with `$n`.
macro_rules! define_unsigned_value {
    ($t:ty, $n:ident) => {
        impl From<$t> for UnsignedValue {
            #[inline]
            fn from(value: $t) -> Self {
                Self(InnerUnsignedValue::$n(value))
            }
        }
    };
}

define_unsigned_value!(u8, U8);
define_unsigned_value!(u16, U16);
define_unsigned_value!(u32, U32);
define_unsigned_value!(u64, U64);
define_unsigned_value!(u128, U128);
define_unsigned_value!(i8, I8);
define_unsigned_value!(i16, I16);
define_unsigned_value!(i32, I32);
define_unsigned_value!(i64, I64);
define_unsigned_value!(i128, I128);
/// A signed integer of any supported width, captured so that a
/// [`BitWrite::write_signed`] call can later be replayed with its original type.
pub struct SignedValue(InnerSignedValue);

/// One variant per signed integer type accepted by `SignedValue::from`.
enum InnerSignedValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    I128(i128),
}

// Generates `impl From<$t> for SignedValue`, tagging the value with `$n`.
macro_rules! define_signed_value {
    ($t:ty, $n:ident) => {
        impl From<$t> for SignedValue {
            #[inline]
            fn from(value: $t) -> Self {
                Self(InnerSignedValue::$n(value))
            }
        }
    };
}

define_signed_value!(i8, I8);
define_signed_value!(i16, I16);
define_signed_value!(i32, I32);
define_signed_value!(i64, I64);
define_signed_value!(i128, I128);
/// One [`BitWrite`] call captured by [`BitRecorder`] for later playback.
enum WriteRecord {
    /// `write_bit(bit)`.
    Bit(bool),
    /// `write_unary0(value)`.
    Unary0(u32),
    /// `write_unary1(value)`.
    Unary1(u32),
    /// `write` / `write_out` with the field width and the typed value.
    Unsigned { value: UnsignedValue, bits: u32 },
    /// `write_signed` / `write_signed_out` with the field width and the typed value.
    Signed { value: SignedValue, bits: u32 },
    /// `write_bytes(buf)`, holding a copy of `buf`.
    Bytes(Box<[u8]>),
}
impl WriteRecord {
fn playback<W: BitWrite>(&self, writer: &mut W) -> io::Result<()> {
match self {
WriteRecord::Bit(v) => writer.write_bit(*v),
WriteRecord::Unsigned {
bits,
value: UnsignedValue(value),
} => match value {
InnerUnsignedValue::U8(v) => writer.write(*bits, *v),
InnerUnsignedValue::U16(v) => writer.write(*bits, *v),
InnerUnsignedValue::U32(v) => writer.write(*bits, *v),
InnerUnsignedValue::U64(v) => writer.write(*bits, *v),
InnerUnsignedValue::U128(v) => writer.write(*bits, *v),
InnerUnsignedValue::I8(v) => writer.write(*bits, *v),
InnerUnsignedValue::I16(v) => writer.write(*bits, *v),
InnerUnsignedValue::I32(v) => writer.write(*bits, *v),
InnerUnsignedValue::I64(v) => writer.write(*bits, *v),
InnerUnsignedValue::I128(v) => writer.write(*bits, *v),
},
WriteRecord::Signed {
bits,
value: SignedValue(value),
} => match value {
InnerSignedValue::I8(v) => writer.write_signed(*bits, *v),
InnerSignedValue::I16(v) => writer.write_signed(*bits, *v),
InnerSignedValue::I32(v) => writer.write_signed(*bits, *v),
InnerSignedValue::I64(v) => writer.write_signed(*bits, *v),
InnerSignedValue::I128(v) => writer.write_signed(*bits, *v),
},
WriteRecord::Unary0(v) => writer.write_unary0(*v),
WriteRecord::Unary1(v) => writer.write_unary1(*v),
WriteRecord::Bytes(bytes) => writer.write_bytes(bytes),
}
}
}
/// A [`BitWrite`] sink that records every call for later playback while
/// also counting the bits written.
#[derive(Default)]
pub struct BitRecorder<N, E: Endianness> {
    // Recorded calls, in the order they were made.
    records: Vec<WriteRecord>,
    // Tracks the bit total and byte alignment of the recorded calls.
    counter: BitCounter<N, E>,
}
impl<N: Default + Copy, E: Endianness> BitRecorder<N, E> {
    /// Creates an empty recorder.
    #[inline]
    pub fn new() -> Self {
        // A zero-capacity Vec does not allocate, matching `Vec::new()`.
        Self::with_capacity(0)
    }

    /// Creates an empty recorder with room for `writes` recorded calls.
    #[inline]
    pub fn with_capacity(writes: usize) -> Self {
        Self {
            records: Vec::with_capacity(writes),
            counter: BitCounter::new(),
        }
    }

    /// Same as [`BitRecorder::new`], but accepts an endianness value so that
    /// `E` can be inferred from it. The value itself is not used.
    #[inline]
    pub fn endian(_endian: E) -> Self {
        Self::new()
    }

    /// Returns the total number of bits recorded so far.
    #[inline]
    pub fn written(&self) -> N {
        self.counter.written()
    }

    /// Replays every recorded call, in order, against `writer`, stopping at
    /// the first error.
    #[inline]
    pub fn playback<W: BitWrite>(&self, writer: &mut W) -> io::Result<()> {
        for record in &self.records {
            record.playback(writer)?;
        }
        Ok(())
    }
}
impl<N, E> BitWrite for BitRecorder<N, E>
where
    E: Endianness,
    N: Copy + From<u32> + AddAssign + Rem<Output = N> + Eq,
{
    /// Records and counts one bit.
    #[inline]
    fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        self.records.push(WriteRecord::Bit(bit));
        self.counter.write_bit(bit)
    }

    /// Validates and counts via the inner counter; records only on success,
    /// so rejected calls never appear in the playback.
    #[inline]
    fn write<U>(&mut self, bits: u32, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        self.counter.write(bits, value).map(|()| {
            self.records.push(WriteRecord::Unsigned {
                bits,
                value: value.unsigned_value(),
            })
        })
    }

    /// Compile-time-width variant of [`BitWrite::write`]; records only on success.
    #[inline]
    fn write_out<const BITS: u32, U>(&mut self, value: U) -> io::Result<()>
    where
        U: Numeric,
    {
        self.counter.write_out::<BITS, U>(value).map(|()| {
            self.records.push(WriteRecord::Unsigned {
                bits: BITS,
                value: value.unsigned_value(),
            })
        })
    }

    /// Validates and counts a signed write; records only on success.
    #[inline]
    fn write_signed<S>(&mut self, bits: u32, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        self.counter.write_signed(bits, value).map(|()| {
            self.records.push(WriteRecord::Signed {
                bits,
                value: value.signed_value(),
            })
        })
    }

    /// Compile-time-width variant of [`BitWrite::write_signed`]; records only on success.
    #[inline]
    fn write_signed_out<const BITS: u32, S>(&mut self, value: S) -> io::Result<()>
    where
        S: SignedNumeric,
    {
        self.counter.write_signed_out::<BITS, S>(value).map(|()| {
            self.records.push(WriteRecord::Signed {
                bits: BITS,
                value: value.signed_value(),
            })
        })
    }

    /// Writes a primitive in endianness `E`; recorded via the calls it makes.
    #[inline]
    fn write_from<V>(&mut self, value: V) -> io::Result<()>
    where
        V: Primitive,
    {
        E::write_primitive(self, value)
    }

    /// Writes a primitive in endianness `F`; recorded via the calls it makes.
    #[inline]
    fn write_as_from<F, V>(&mut self, value: V) -> io::Result<()>
    where
        F: Endianness,
        V: Primitive,
    {
        F::write_primitive(self, value)
    }

    /// Records and counts a unary value with a zero stop bit.
    #[inline]
    fn write_unary0(&mut self, value: u32) -> io::Result<()> {
        self.records.push(WriteRecord::Unary0(value));
        self.counter.write_unary0(value)
    }

    /// Records and counts a unary value with a one stop bit.
    #[inline]
    fn write_unary1(&mut self, value: u32) -> io::Result<()> {
        self.records.push(WriteRecord::Unary1(value));
        self.counter.write_unary1(value)
    }

    /// Records a copy of `buf` and counts its bits.
    #[inline]
    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
        let copy: Box<[u8]> = buf.into();
        self.records.push(WriteRecord::Bytes(copy));
        self.counter.write_bytes(buf)
    }

    /// `true` when the recorded bit total is a multiple of 8.
    #[inline]
    fn byte_aligned(&self) -> bool {
        self.counter.byte_aligned()
    }
}
impl<N, E> HuffmanWrite<E> for BitRecorder<N, E>
where
    E: Endianness,
    N: Copy + From<u32> + AddAssign + Rem<Output = N> + Eq,
{
    /// Records each `(bit_count, value)` piece of `symbol`'s code as an
    /// ordinary unsigned write, stopping at the first error.
    #[inline]
    fn write_huffman<T>(&mut self, tree: &WriteHuffmanTree<E, T>, symbol: T) -> io::Result<()>
    where
        T: Ord + Copy,
    {
        for (bits, value) in tree.get(&symbol) {
            self.write(*bits, *value)?;
        }
        Ok(())
    }
}
/// Writes the single byte `byte` to `writer`.
///
/// # Errors
/// Fails if `writer` cannot accept the whole byte (via `write_all`).
#[inline]
fn write_byte<W>(mut writer: W, byte: u8) -> io::Result<()>
where
    W: io::Write,
{
    let one = [byte];
    writer.write_all(&one)
}
/// Tops up the partial byte `rem` with bits taken from `acc`.
///
/// Does nothing if `rem` is empty (already byte-aligned). Otherwise moves as
/// many bits as `rem` has room for (or as `acc` holds, if fewer), and writes
/// `rem` to `writer` if it now holds a complete byte.
fn write_unaligned<W, E, N>(
    writer: W,
    acc: &mut BitQueue<E, N>,
    rem: &mut BitQueue<E, u8>,
) -> io::Result<()>
where
    W: io::Write,
    E: Endianness,
    N: Numeric,
{
    if rem.is_empty() {
        return Ok(());
    }
    let room = 8 - rem.len();
    let moved = room.min(acc.len());
    rem.push(moved, acc.pop(moved).to_u8());
    if rem.len() != 8 {
        return Ok(());
    }
    write_byte(writer, rem.pop(8))
}
/// Pops every whole byte from `acc` and writes them to `writer` in one call.
///
/// Leftover bits (fewer than 8) stay in `acc`. The bytes are staged in the
/// scratch buffer returned by `N::buffer()`.
fn write_aligned<W, E, N>(mut writer: W, acc: &mut BitQueue<E, N>) -> io::Result<()>
where
    W: io::Write,
    E: Endianness,
    N: Numeric,
{
    let whole_bytes = (acc.len() / 8) as usize;
    if whole_bytes == 0 {
        return Ok(());
    }
    let mut scratch = N::buffer();
    let all: &mut [u8] = scratch.as_mut();
    let staged = &mut all[..whole_bytes];
    for slot in staged.iter_mut() {
        *slot = acc.pop_fixed::<8>().to_u8();
    }
    writer.write_all(staged)
}
/// Writes whole primitive values and raw bytes to an underlying sink,
/// using endianness `E` unless another is requested per call.
pub struct ByteWriter<W: io::Write, E: Endianness> {
    // Destination for all bytes.
    writer: W,
    // Marker for the default endianness; stores nothing.
    phantom: PhantomData<E>,
}
impl<W: io::Write, E: Endianness> ByteWriter<W, E> {
    /// Wraps `writer` in a new `ByteWriter`.
    pub fn new(writer: W) -> ByteWriter<W, E> {
        Self {
            writer,
            phantom: PhantomData,
        }
    }

    /// Same as [`ByteWriter::new`], but accepts an endianness value so that
    /// `E` can be inferred from it. The value itself is not used.
    pub fn endian(writer: W, _endian: E) -> ByteWriter<W, E> {
        Self::new(writer)
    }

    /// Consumes this writer and returns the underlying sink.
    #[inline]
    pub fn into_writer(self) -> W {
        let Self { writer, .. } = self;
        writer
    }

    /// Returns a mutable reference to the underlying sink.
    #[inline]
    pub fn writer(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Consumes this writer and wraps its sink in a [`BitWriter`] with the
    /// same endianness.
    #[inline]
    pub fn into_bitwriter(self) -> BitWriter<W, E> {
        let sink = self.into_writer();
        BitWriter::new(sink)
    }

    /// Borrows the sink as a [`BitWriter`] with the same endianness.
    #[inline]
    pub fn bitwriter(&mut self) -> BitWriter<&mut W, E> {
        let sink = &mut self.writer;
        BitWriter::new(sink)
    }
}
/// A sink that accepts whole primitive values and raw bytes.
pub trait ByteWrite {
    /// Writes `value` in the sink's own endianness.
    fn write<V>(&mut self, value: V) -> io::Result<()>
    where
        V: Primitive;

    /// Writes `value` in the explicitly chosen endianness `F`.
    fn write_as<F, V>(&mut self, value: V) -> io::Result<()>
    where
        F: Endianness,
        V: Primitive;

    /// Writes all of `buf` unchanged.
    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()>;

    /// Serializes `build` into this sink via [`ToByteStream::to_writer`].
    fn build<T: ToByteStream>(&mut self, build: &T) -> Result<(), T::Error> {
        T::to_writer(build, self)
    }

    /// Serializes `build` into this sink with extra `context`, via
    /// [`ToByteStreamWith::to_writer`].
    fn build_with<'a, T: ToByteStreamWith<'a>>(
        &mut self,
        build: &T,
        context: &T::Context,
    ) -> Result<(), T::Error> {
        T::to_writer(build, self, context)
    }

    /// Returns the underlying sink as an `io::Write` trait object.
    fn writer_ref(&mut self) -> &mut dyn io::Write;
}
impl<W: io::Write, E: Endianness> ByteWrite for ByteWriter<W, E> {
    /// Writes `value` in this writer's endianness `E`.
    #[inline]
    fn write<V>(&mut self, value: V) -> io::Result<()>
    where
        V: Primitive,
    {
        <Self as ByteWrite>::write_as::<E, V>(self, value)
    }

    /// Writes `value` in the explicitly chosen endianness `F`.
    #[inline]
    fn write_as<F, V>(&mut self, value: V) -> io::Result<()>
    where
        F: Endianness,
        V: Primitive,
    {
        let sink = &mut self.writer;
        F::write_numeric(sink, value)
    }

    /// Writes all of `buf` directly to the sink.
    #[inline]
    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
        io::Write::write_all(&mut self.writer, buf)
    }

    /// Returns the underlying sink as an `io::Write` trait object.
    #[inline]
    fn writer_ref(&mut self) -> &mut dyn io::Write {
        &mut self.writer
    }
}
/// A type that can serialize itself into any [`BitWrite`] stream.
pub trait ToBitStream {
/// The error type returned by [`ToBitStream::to_writer`].
type Error;
/// Writes `self` to the bit stream `w`.
fn to_writer<W: BitWrite + ?Sized>(&self, w: &mut W) -> Result<(), Self::Error>
where
Self: Sized;
}
/// A type that can serialize itself into any [`BitWrite`] stream given some
/// additional context.
pub trait ToBitStreamWith<'a> {
/// Extra data passed by reference to [`ToBitStreamWith::to_writer`].
type Context: 'a;
/// The error type returned by [`ToBitStreamWith::to_writer`].
type Error;
/// Writes `self` to the bit stream `w`, consulting `context`.
fn to_writer<W: BitWrite + ?Sized>(
&self,
w: &mut W,
context: &Self::Context,
) -> Result<(), Self::Error>
where
Self: Sized;
}
/// A type that can serialize itself into any [`ByteWrite`] sink.
pub trait ToByteStream {
/// The error type returned by [`ToByteStream::to_writer`].
type Error;
/// Writes `self` to the byte sink `w`.
fn to_writer<W: ByteWrite + ?Sized>(&self, w: &mut W) -> Result<(), Self::Error>
where
Self: Sized;
}
/// A type that can serialize itself into any [`ByteWrite`] sink given some
/// additional context.
pub trait ToByteStreamWith<'a> {
/// Extra data passed by reference to [`ToByteStreamWith::to_writer`].
type Context: 'a;
/// The error type returned by [`ToByteStreamWith::to_writer`].
type Error;
/// Writes `self` to the byte sink `w`, consulting `context`.
fn to_writer<W: ByteWrite + ?Sized>(
&self,
w: &mut W,
context: &Self::Context,
) -> Result<(), Self::Error>
where
Self: Sized;
}