use core::any::TypeId;
use core::{mem, ptr};
use crate::codes::params::{DefaultWriteParams, WriteParams};
use crate::traits::*;
use common_traits::{AsBytes, CastableInto, FiniteRangeNumber, Integer, Number};
#[cfg(feature = "mem_dbg")]
use mem_dbg::{MemDbg, MemSize};
/// A bit writer that accumulates bits in a single word-sized buffer and
/// hands complete words to a [`WordWrite`] backend.
///
/// `E` selects the bit order of the stream; `WP` is carried only as a type
/// marker in this struct (presumably it selects code-writing parameters used
/// by trait impls elsewhere — TODO confirm).
#[derive(Debug)]
#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
pub struct BufBitWriter<E: Endianness, WW: WordWrite, WP: WriteParams = DefaultWriteParams> {
/// Destination of completed words.
backend: WW,
/// Bits written but not yet sent to the backend. Only the last
/// `WW::Word::BITS - space_left_in_buffer` bits written are meaningful; the
/// remaining positions may hold stale bits that are shifted out before the
/// word is emitted.
buffer: WW::Word,
/// Number of free bit positions in `buffer`; `WW::Word::BITS` means empty.
space_left_in_buffer: usize,
/// Binds the endianness and write-parameter type parameters.
_marker_endianness: core::marker::PhantomData<(E, WP)>,
}
impl<E: Endianness, WW: WordWrite, WP: WriteParams> BufBitWriter<E, WW, WP>
where
    BufBitWriter<E, WW, WP>: BitWrite<E>,
{
    /// Creates a writer with an empty bit buffer on top of `backend`.
    pub fn new(backend: WW) -> Self {
        Self {
            backend,
            buffer: WW::Word::ZERO,
            space_left_in_buffer: WW::Word::BITS,
            _marker_endianness: core::marker::PhantomData,
        }
    }

    /// Flushes any buffered bits and returns the underlying word writer.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the flush. In that case the backend is
    /// dropped. The writer's `Drop` implementation is not run, so no second
    /// flush is attempted and the error cannot turn into a panic.
    pub fn into_inner(mut self) -> Result<WW, <Self as BitWrite<E>>::Error> {
        let flushed = self.flush();
        // SAFETY: `self.backend` is initialized. We take a bitwise copy and then
        // `mem::forget(self)`, so neither `Drop for BufBitWriter` nor the field's
        // own destructor runs on the original; the backend is owned, and
        // eventually dropped, exactly once through `backend`.
        let backend = unsafe { ptr::read(&self.backend) };
        mem::forget(self);
        // On error the closure (which owns `backend`) is dropped, releasing it.
        flushed.map(|_| backend)
    }
}
impl<E: Endianness, WW: WordWrite, WP: WriteParams> core::ops::Drop for BufBitWriter<E, WW, WP> {
    /// Writes out any pending bits using the bit order selected by `E`.
    ///
    /// # Panics
    ///
    /// Panics if the backend reports an error while flushing, because `drop`
    /// has no way to return it.
    fn drop(&mut self) {
        // A `Drop` impl must stay generic over `E`, so the concrete flush routine
        // is chosen at run time; every type other than `LE` takes the
        // big-endian path.
        let is_little_endian = TypeId::of::<E>() == TypeId::of::<LE>();
        let outcome = if is_little_endian {
            flush_le(self)
        } else {
            flush_be(self)
        };
        outcome.unwrap();
    }
}
/// Writes the buffered bits of a big-endian writer as one word and then
/// flushes the backend.
///
/// The pending bits are moved to the most significant end of the word, and
/// the unused least significant positions are zero-filled.
///
/// Returns the number of bits that were pending in the buffer (0 if it was
/// empty, in which case no word is written).
///
/// # Errors
///
/// Propagates errors from `write_word` and from the backend's `flush`. If
/// `write_word` fails, the writer's buffer and fill level are left untouched,
/// so a later flush retries with the same data.
fn flush_be<E: Endianness, WW: WordWrite, WP: WriteParams>(
    buf_bit_writer: &mut BufBitWriter<E, WW, WP>,
) -> Result<usize, WW::Error> {
    let to_flush = WW::Word::BITS - buf_bit_writer.space_left_in_buffer;
    if to_flush != 0 {
        // Compute the aligned word in a local instead of shifting the buffer in
        // place: if the write fails, a retry (e.g. from `Drop`) must not shift
        // the buffer a second time.
        let word = buf_bit_writer.buffer << buf_bit_writer.space_left_in_buffer;
        buf_bit_writer.backend.write_word(word.to_be())?;
        buf_bit_writer.buffer = word;
        buf_bit_writer.space_left_in_buffer = WW::Word::BITS;
    }
    buf_bit_writer.backend.flush()?;
    Ok(to_flush)
}
/// Big-endian bit writing: new bits enter `buffer` at its least significant
/// end (the buffer shifts left), and each full word is written with `to_be()`.
impl<WW: WordWrite, WP: WriteParams> BitWrite<BE> for BufBitWriter<BE, WW, WP>
where
    u64: CastableInto<WW::Word>,
{
    type Error = <WW as WordWrite>::Error;

    /// Writes out pending bits (zero-padded to a whole word) and flushes the
    /// backend; returns the number of bits that were pending.
    fn flush(&mut self) -> Result<usize, Self::Error> {
        flush_be(self)
    }

    // `value` is only mutated under `cfg(test)`, hence the `unused_mut` allowance.
    #[allow(unused_mut)]
    #[inline]
    fn write_bits(&mut self, mut value: u64, n_bits: usize) -> Result<usize, Self::Error> {
        debug_assert!(n_bits <= 64);
        #[cfg(feature = "checks")]
        assert!(
            value & (1_u128 << n_bits).wrapping_sub(1) as u64 == value,
            "Error: value {} does not fit in {} bits",
            value,
            n_bits
        );
        debug_assert!(self.space_left_in_buffer > 0);
        // In tests, set every bit above `n_bits` so that any missing masking
        // below corrupts the stream and is caught by the round-trip tests.
        #[cfg(test)]
        if n_bits < 64 {
            value |= u64::MAX << n_bits;
        }
        // Fast path: the value fits without filling the buffer.
        if n_bits < self.space_left_in_buffer {
            self.buffer <<= n_bits;
            self.buffer |= value.cast() & !(WW::Word::MAX << n_bits as u32);
            self.space_left_in_buffer -= n_bits;
            return Ok(n_bits);
        }
        // Fill the buffer with the top `space_left_in_buffer` bits of the value.
        // The shift is split in two because `space_left_in_buffer` may equal the
        // word width, and a full-width shift overflows.
        self.buffer = self.buffer << (self.space_left_in_buffer - 1) << 1;
        self.buffer |= (value << (64 - n_bits) >> (64 - self.space_left_in_buffer)).cast();
        self.backend.write_word(self.buffer.to_be())?;
        let mut to_write = n_bits - self.space_left_in_buffer;
        // Emit whole words directly, most significant bits first.
        for _ in 0..to_write / WW::Word::BITS {
            to_write -= WW::Word::BITS;
            self.backend
                .write_word((value >> to_write).cast().to_be())?;
        }
        // The remaining `to_write` bits sit at the low end of the buffer; the
        // higher positions hold stale bits that are shifted out before the next
        // word is emitted.
        self.space_left_in_buffer = WW::Word::BITS - to_write;
        self.buffer = value.cast();
        Ok(n_bits)
    }

    #[inline]
    #[allow(clippy::collapsible_if)]
    fn write_unary(&mut self, mut value: u64) -> Result<usize, Self::Error> {
        debug_assert_ne!(value, u64::MAX);
        debug_assert!(self.space_left_in_buffer > 0);
        // The code is `value` zeros followed by a single one.
        let code_length = value + 1;
        if code_length <= self.space_left_in_buffer as u64 {
            self.space_left_in_buffer -= code_length as usize;
            self.buffer = self.buffer << value << 1;
            self.buffer |= WW::Word::ONE;
            if self.space_left_in_buffer == 0 {
                self.backend.write_word(self.buffer.to_be())?;
                self.space_left_in_buffer = WW::Word::BITS;
            }
            return Ok(code_length as usize);
        }
        // Pad the current buffer with zeros and emit it.
        self.buffer = self.buffer << (self.space_left_in_buffer - 1) << 1;
        self.backend.write_word(self.buffer.to_be())?;
        value -= self.space_left_in_buffer as u64;
        // Whole words made only of zeros.
        for _ in 0..value / WW::Word::BITS as u64 {
            self.backend.write_word(WW::Word::ZERO)?;
        }
        value %= WW::Word::BITS as u64;
        if value == WW::Word::BITS as u64 - 1 {
            // The last zeros plus the terminating one fill exactly one word.
            self.backend.write_word(WW::Word::ONE.to_be())?;
            self.space_left_in_buffer = WW::Word::BITS;
        } else {
            self.buffer = WW::Word::ONE;
            self.space_left_in_buffer = WW::Word::BITS - (value as usize + 1);
        }
        Ok(code_length as usize)
    }

    #[cfg(not(feature = "no_copy_impls"))]
    fn copy_from<F: Endianness, R: BitRead<F>>(
        &mut self,
        bit_read: &mut R,
        mut n: u64,
    ) -> Result<(), CopyError<R::Error, Self::Error>> {
        // `read_bits` yields a `u64`, so the word-at-a-time path below (which
        // reads up to `WW::Word::BITS` bits at once) is only valid when a word
        // fits in 64 bits. For wider words (e.g. `u128`), copy in chunks of at
        // most 64 bits through `write_bits`, which is equivalent.
        if WW::Word::BITS > 64 {
            while n != 0 {
                let chunk = n.min(64) as usize;
                let bits = bit_read.read_bits(chunk).map_err(CopyError::ReadError)?;
                <Self as BitWrite<BE>>::write_bits(self, bits, chunk)
                    .map_err(CopyError::WriteError)?;
                n -= chunk as u64;
            }
            return Ok(());
        }
        // Fast path: everything fits without filling the buffer.
        if n < self.space_left_in_buffer as u64 {
            self.buffer = self.buffer << n
                | bit_read
                    .read_bits(n as usize)
                    .map_err(CopyError::ReadError)?
                    .cast();
            self.space_left_in_buffer -= n as usize;
            return Ok(());
        }
        // Complete the current buffer (split shift: `space_left_in_buffer` may
        // equal the word width) and emit it.
        self.buffer = self.buffer << (self.space_left_in_buffer - 1) << 1
            | bit_read
                .read_bits(self.space_left_in_buffer)
                .map_err(CopyError::ReadError)?
                .cast();
        n -= self.space_left_in_buffer as u64;
        self.backend
            .write_word(self.buffer.to_be())
            .map_err(CopyError::WriteError)?;
        // Copy whole words straight through.
        for _ in 0..n / WW::Word::BITS as u64 {
            self.backend
                .write_word(
                    bit_read
                        .read_bits(WW::Word::BITS)
                        .map_err(CopyError::ReadError)?
                        .cast()
                        .to_be(),
                )
                .map_err(CopyError::WriteError)?;
        }
        // Keep the leftover bits in the buffer.
        n %= WW::Word::BITS as u64;
        self.buffer = bit_read
            .read_bits(n as usize)
            .map_err(CopyError::ReadError)?
            .cast();
        self.space_left_in_buffer = WW::Word::BITS - n as usize;
        Ok(())
    }
}
/// Writes the buffered bits of a little-endian writer as one word and then
/// flushes the backend.
///
/// The pending bits are moved to the least significant end of the word, and
/// the unused most significant positions are zero-filled.
///
/// Returns the number of bits that were pending in the buffer (0 if it was
/// empty, in which case no word is written).
///
/// # Errors
///
/// Propagates errors from `write_word` and from the backend's `flush`. If
/// `write_word` fails, the writer's buffer and fill level are left untouched,
/// so a later flush retries with the same data.
fn flush_le<E: Endianness, WW: WordWrite, WP: WriteParams>(
    buf_bit_writer: &mut BufBitWriter<E, WW, WP>,
) -> Result<usize, WW::Error> {
    let to_flush = WW::Word::BITS - buf_bit_writer.space_left_in_buffer;
    if to_flush != 0 {
        // Compute the aligned word in a local instead of shifting the buffer in
        // place: if the write fails, a retry (e.g. from `Drop`) must not shift
        // the buffer a second time.
        let word = buf_bit_writer.buffer >> buf_bit_writer.space_left_in_buffer;
        buf_bit_writer.backend.write_word(word.to_le())?;
        buf_bit_writer.buffer = word;
        buf_bit_writer.space_left_in_buffer = WW::Word::BITS;
    }
    buf_bit_writer.backend.flush()?;
    Ok(to_flush)
}
/// Little-endian bit writing: new bits enter `buffer` at its most significant
/// end (the buffer shifts right), and each full word is written with `to_le()`.
impl<WW: WordWrite, WP: WriteParams> BitWrite<LE> for BufBitWriter<LE, WW, WP>
where
    u64: CastableInto<WW::Word>,
{
    type Error = <WW as WordWrite>::Error;

    /// Writes out pending bits (zero-padded to a whole word) and flushes the
    /// backend; returns the number of bits that were pending.
    fn flush(&mut self) -> Result<usize, Self::Error> {
        flush_le(self)
    }

    #[inline]
    fn write_bits(&mut self, mut value: u64, n_bits: usize) -> Result<usize, Self::Error> {
        debug_assert!(n_bits <= 64);
        #[cfg(feature = "checks")]
        assert!(
            value & (1_u128 << n_bits).wrapping_sub(1) as u64 == value,
            "Error: value {} does not fit in {} bits",
            value,
            n_bits
        );
        debug_assert!(self.space_left_in_buffer > 0);
        // In tests, set every bit above `n_bits` so that any missing masking
        // below corrupts the stream and is caught by the round-trip tests.
        #[cfg(test)]
        if n_bits < 64 {
            value |= u64::MAX << n_bits;
        }
        // Fast path: the value fits without filling the buffer. The masked value
        // is rotated so that its `n_bits` bits land at the top of the word.
        if n_bits < self.space_left_in_buffer {
            self.buffer >>= n_bits;
            self.buffer |=
                (value.cast() & !(WW::Word::MAX << n_bits as u32)).rotate_right(n_bits as u32);
            self.space_left_in_buffer -= n_bits;
            return Ok(n_bits);
        }
        // Fill the top `space_left_in_buffer` positions with the low bits of the
        // value. Shifts are split in two because `space_left_in_buffer` may equal
        // the word width (or 64), and a full-width shift overflows.
        self.buffer = self.buffer >> (self.space_left_in_buffer - 1) >> 1;
        self.buffer |= value.cast() << (WW::Word::BITS - self.space_left_in_buffer);
        self.backend.write_word(self.buffer.to_le())?;
        let to_write = n_bits - self.space_left_in_buffer;
        value = value >> (self.space_left_in_buffer - 1) >> 1;
        // Emit whole words directly, least significant bits first.
        for _ in 0..to_write / WW::Word::BITS {
            self.backend.write_word(value.cast().to_le())?;
            value >>= WW::Word::BITS;
        }
        // Rotate the leftover low bits to the top of the buffer; the lower
        // positions hold stale bits that are shifted out before the next word
        // is emitted.
        self.space_left_in_buffer = WW::Word::BITS - to_write % WW::Word::BITS;
        self.buffer = value.cast().rotate_right(to_write as u32);
        Ok(n_bits)
    }

    #[inline]
    #[allow(clippy::collapsible_if)]
    fn write_unary(&mut self, mut value: u64) -> Result<usize, Self::Error> {
        debug_assert_ne!(value, u64::MAX);
        debug_assert!(self.space_left_in_buffer > 0);
        // The code is `value` zeros followed by a single one.
        let code_length = value + 1;
        if code_length <= self.space_left_in_buffer as u64 {
            self.space_left_in_buffer -= code_length as usize;
            self.buffer = self.buffer >> value >> 1;
            self.buffer |= WW::Word::ONE << (WW::Word::BITS - 1);
            if self.space_left_in_buffer == 0 {
                self.backend.write_word(self.buffer.to_le())?;
                self.space_left_in_buffer = WW::Word::BITS;
            }
            return Ok(code_length as usize);
        }
        // Pad the current buffer with zeros and emit it.
        self.buffer = self.buffer >> (self.space_left_in_buffer - 1) >> 1;
        self.backend.write_word(self.buffer.to_le())?;
        value -= self.space_left_in_buffer as u64;
        // Whole words made only of zeros.
        for _ in 0..value / WW::Word::BITS as u64 {
            self.backend.write_word(WW::Word::ZERO)?;
        }
        value %= WW::Word::BITS as u64;
        if value == WW::Word::BITS as u64 - 1 {
            // The last zeros plus the terminating one fill exactly one word.
            self.backend
                .write_word((WW::Word::ONE << (WW::Word::BITS - 1)).to_le())?;
            self.space_left_in_buffer = WW::Word::BITS;
        } else {
            self.buffer = WW::Word::ONE << (WW::Word::BITS - 1);
            self.space_left_in_buffer = WW::Word::BITS - (value as usize + 1);
        }
        Ok(code_length as usize)
    }

    #[cfg(not(feature = "no_copy_impls"))]
    fn copy_from<F: Endianness, R: BitRead<F>>(
        &mut self,
        bit_read: &mut R,
        mut n: u64,
    ) -> Result<(), CopyError<R::Error, Self::Error>> {
        // `read_bits` yields a `u64`, so the word-at-a-time path below (which
        // reads up to `WW::Word::BITS` bits at once) is only valid when a word
        // fits in 64 bits. For wider words (e.g. `u128`), copy in chunks of at
        // most 64 bits through `write_bits`, which is equivalent.
        if WW::Word::BITS > 64 {
            while n != 0 {
                let chunk = n.min(64) as usize;
                let bits = bit_read.read_bits(chunk).map_err(CopyError::ReadError)?;
                <Self as BitWrite<LE>>::write_bits(self, bits, chunk)
                    .map_err(CopyError::WriteError)?;
                n -= chunk as u64;
            }
            return Ok(());
        }
        // Fast path: everything fits without filling the buffer.
        if n < self.space_left_in_buffer as u64 {
            self.buffer = self.buffer >> n
                | (bit_read
                    .read_bits(n as usize)
                    .map_err(CopyError::ReadError)?)
                .cast()
                .rotate_right(n as u32);
            self.space_left_in_buffer -= n as usize;
            return Ok(());
        }
        // Complete the current buffer (split shift: `space_left_in_buffer` may
        // equal the word width) and emit it.
        self.buffer = self.buffer >> (self.space_left_in_buffer - 1) >> 1
            | (bit_read
                .read_bits(self.space_left_in_buffer)
                .map_err(CopyError::ReadError)?
                .cast())
            .rotate_right(self.space_left_in_buffer as u32);
        n -= self.space_left_in_buffer as u64;
        self.backend
            .write_word(self.buffer.to_le())
            .map_err(CopyError::WriteError)?;
        // Copy whole words straight through.
        for _ in 0..n / WW::Word::BITS as u64 {
            self.backend
                .write_word(
                    bit_read
                        .read_bits(WW::Word::BITS)
                        .map_err(CopyError::ReadError)?
                        .cast()
                        .to_le(),
                )
                .map_err(CopyError::WriteError)?;
        }
        // Keep the leftover bits at the top of the buffer.
        n %= WW::Word::BITS as u64;
        self.buffer = bit_read
            .read_bits(n as usize)
            .map_err(CopyError::ReadError)?
            .cast()
            .rotate_right(n as u32);
        self.space_left_in_buffer = WW::Word::BITS - n as usize;
        Ok(())
    }
}
// Generates a round-trip test named `$f` for the backend word type `$word`.
//
// The test writes a pseudo-random mix of gamma, delta, fixed-width and unary
// codes through a big-endian and a little-endian `BufBitWriter` (each backed by
// a `Vec<$word>`). It then reads both streams back with `BufBitReader` over
// `u16` words and checks every code against the same pseudo-random sequence,
// regenerated from the same seed.
macro_rules! test_buf_bit_writer {
($f: ident, $word:ty) => {
#[test]
fn $f() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
use super::MemWordWriterVec;
use crate::{
codes::{GammaRead, GammaWrite},
prelude::{
len_delta, len_gamma, BufBitReader, DeltaRead, DeltaWrite, MemWordReader,
},
};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
let mut buffer_be: Vec<$word> = vec![];
let mut buffer_le: Vec<$word> = vec![];
let mut big = BufBitWriter::<BE, _>::new(MemWordWriterVec::new(&mut buffer_be));
let mut little = BufBitWriter::<LE, _>::new(MemWordWriterVec::new(&mut buffer_le));
// Fixed seed: the read phase below replays exactly the same sequence of draws.
let mut r = SmallRng::seed_from_u64(0);
const ITER: usize = 1_000_000;
// Write phase. The order of RNG draws here must match the read phase exactly.
for _ in 0..ITER {
let value = r.gen_range(0..128);
assert_eq!(big.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(big.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_gamma(value)?, len_gamma(value));
let value = r.gen_range(0..128);
assert_eq!(big.write_delta(value)?, len_delta(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_delta(value)?, len_delta(value));
let value = r.gen_range(0..128);
assert_eq!(big.write_delta(value)?, len_delta(value));
let value = r.gen_range(0..128);
assert_eq!(little.write_delta(value)?, len_delta(value));
// `n_bits == 0` is special-cased because `u64::MAX >> 64` would overflow.
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
big.write_bits(0, 0)?;
} else {
big.write_bits(r.gen::<u64>() & u64::MAX >> 64 - n_bits, n_bits)?;
}
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
little.write_bits(0, 0)?;
} else {
little.write_bits(r.gen::<u64>() & u64::MAX >> 64 - n_bits, n_bits)?;
}
let value = r.gen_range(0..128);
assert_eq!(big.write_unary(value)?, value as usize + 1);
let value = r.gen_range(0..128);
assert_eq!(little.write_unary(value)?, value as usize + 1);
}
// Dropping the writers flushes any partial word (the `Drop` impl unwraps the
// flush result, so a backend error panics here) and ends the mutable borrows
// of `buffer_be` / `buffer_le`.
drop(big);
drop(little);
// Read the streams back with `u16` words, independently of the writer's word
// type. The writers store words through `to_be()` / `to_le()`, so the byte
// layout does not depend on the host; presumably the readers undo the
// conversion per `u16` word — confirm in `MemWordReader` / `BufBitReader`.
type ReadWord = u16;
// SAFETY: `buffer_be` is a live, fully initialized `Vec<$word>` that is not
// modified while `be_trans` is in use. `u16` has an alignment no stricter than
// `$word`, and the length is converted from `$word` units to `u16` units, so
// the slice covers exactly the vector's bytes. (Assumes `size_of::<$word>()`
// is a multiple of 2, which holds for every invocation in this file.)
let be_trans: &[ReadWord] = unsafe {
core::slice::from_raw_parts(
buffer_be.as_ptr() as *const ReadWord,
buffer_be.len()
* (core::mem::size_of::<$word>() / core::mem::size_of::<ReadWord>()),
)
};
// SAFETY: same reasoning as for `be_trans`, applied to `buffer_le`.
let le_trans: &[ReadWord] = unsafe {
core::slice::from_raw_parts(
buffer_le.as_ptr() as *const ReadWord,
buffer_le.len()
* (core::mem::size_of::<$word>() / core::mem::size_of::<ReadWord>()),
)
};
let mut big_buff = BufBitReader::<BE, _>::new(MemWordReader::new(be_trans));
let mut little_buff = BufBitReader::<LE, _>::new(MemWordReader::new(le_trans));
// Read phase: same seed and same draw order as the write phase.
let mut r = SmallRng::seed_from_u64(0);
for _ in 0..ITER {
assert_eq!(big_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(little_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(big_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(little_buff.read_gamma()?, r.gen_range(0..128));
assert_eq!(big_buff.read_delta()?, r.gen_range(0..128));
assert_eq!(little_buff.read_delta()?, r.gen_range(0..128));
assert_eq!(big_buff.read_delta()?, r.gen_range(0..128));
assert_eq!(little_buff.read_delta()?, r.gen_range(0..128));
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
assert_eq!(big_buff.read_bits(0)?, 0);
} else {
assert_eq!(
big_buff.read_bits(n_bits)?,
r.gen::<u64>() & u64::MAX >> 64 - n_bits
);
}
let n_bits = r.gen_range(0..=64);
if n_bits == 0 {
assert_eq!(little_buff.read_bits(0)?, 0);
} else {
assert_eq!(
little_buff.read_bits(n_bits)?,
r.gen::<u64>() & u64::MAX >> 64 - n_bits
);
}
assert_eq!(big_buff.read_unary()?, r.gen_range(0..128));
assert_eq!(little_buff.read_unary()?, r.gen_range(0..128));
}
Ok(())
}
};
}
// Instantiate the round-trip test once per backend word type, narrowest first.
test_buf_bit_writer!(test_u16, u16);
test_buf_bit_writer!(test_u32, u32);
test_buf_bit_writer!(test_u64, u64);
test_buf_bit_writer!(test_u128, u128);
test_buf_bit_writer!(test_usize, usize);