use std::convert::Infallible;
use std::fmt::Write;
/// Unsigned primitive integers accepted by the [`BitSink`] writers.
///
/// Sealed: only types implementing the private `seal_bits::Sealed` trait
/// (`u8`, `u16`, `u32` and `u64`) satisfy it.
pub trait Bits: seal_bits::Sealed {}

impl<T> Bits for T where T: seal_bits::Sealed {}
/// Signed primitive integers accepted by [`BitSink::write_twoc`].
///
/// Sealed: only types implementing the private `seal_signed_bits::Sealed`
/// trait (`i8`, `i16`, `i32` and `i64`) satisfy it.
pub trait SignedBits: seal_signed_bits::Sealed {}

impl<T> SignedBits for T where T: seal_signed_bits::Sealed {}
/// A destination for a stream of bits, written most-significant-bit first.
///
/// Implementors provide the primitive writers ([`BitSink::write_msbs`],
/// [`BitSink::write_lsbs`], [`BitSink::write`] and
/// [`BitSink::align_to_byte`]). The remaining methods have default
/// implementations built on top of them.
pub trait BitSink: Sized {
    /// Error reported by a failed write.
    type Error: std::error::Error;

    /// Advances the write position to the next byte boundary.
    ///
    /// Returns the number of bits skipped to reach the boundary.
    fn align_to_byte(&mut self) -> Result<usize, Self::Error>;

    /// Aligns to a byte boundary, then appends `bytes` in order.
    ///
    /// Returns the number of bits skipped by the alignment step.
    #[inline]
    fn write_bytes_aligned(&mut self, bytes: &[u8]) -> Result<usize, Self::Error> {
        let ret = self.align_to_byte()?;
        for b in bytes {
            self.write(*b)?;
        }
        Ok(ret)
    }

    /// Appends the `n` least-significant bits of `val` (the highest of those first).
    fn write_lsbs<T: Bits>(&mut self, val: T, n: usize) -> Result<(), Self::Error>;

    /// Appends the `n` most-significant bits of `val`.
    fn write_msbs<T: Bits>(&mut self, val: T, n: usize) -> Result<(), Self::Error>;

    /// Appends every bit of `val`.
    fn write<T: Bits>(&mut self, val: T) -> Result<(), Self::Error>;

    /// Appends `val` as a `bits_per_sample`-bit two's-complement integer.
    ///
    /// Only the low `bits_per_sample` bits of `val`'s two's-complement
    /// representation are written. A width of zero writes nothing, matching
    /// the zero-width behaviour of `write_msbs`/`write_lsbs`.
    ///
    /// # Panics
    ///
    /// In debug builds, if `bits_per_sample > 64`.
    #[inline]
    fn write_twoc<T: SignedBits>(
        &mut self,
        val: T,
        bits_per_sample: usize,
    ) -> Result<(), Self::Error> {
        if bits_per_sample == 0 {
            // `64 - 0` would be an out-of-range shift for a 64-bit value.
            return Ok(());
        }
        debug_assert!(
            bits_per_sample <= 64,
            "bits_per_sample must be at most 64"
        );
        let val: i64 = val.into();
        // Left-justify the low `bits_per_sample` bits so they become the MSBs.
        let shifted = (val << (64 - bits_per_sample)) as u64;
        self.write_msbs(shifted, bits_per_sample)
    }

    /// Appends `n` zero bits.
    #[inline]
    fn write_zeros(&mut self, n: usize) -> Result<(), Self::Error> {
        let mut n = n;
        while n > 64 {
            self.write(0u64)?;
            n -= 64;
        }
        self.write_msbs(0u64, n)?;
        Ok(())
    }
}
/// An in-memory bit sink that packs bits, most-significant first, into a
/// `Vec` of `S` words.
#[derive(Debug, Clone)]
pub struct MemSink<S> {
    /// Packed words; the final word may be only partially filled.
    storage: Vec<S>,
    /// Number of valid bits written so far.
    bitlength: usize,
}

/// A [`MemSink`] whose storage words are bytes.
pub type ByteSink = MemSink<u8>;
impl<S: Bits> Default for MemSink<S> {
fn default() -> Self {
Self::new()
}
}
impl<S: Bits> MemSink<S> {
    /// Creates an empty sink without allocating.
    pub fn new() -> Self {
        Self {
            storage: vec![],
            bitlength: 0usize,
        }
    }

    /// Creates an empty sink with room for at least `capacity_in_bits` bits.
    pub fn with_capacity(capacity_in_bits: usize) -> Self {
        Self {
            // One extra word covers a partially-filled tail without overflow risk.
            storage: Vec::with_capacity((capacity_in_bits >> S::BITS_LOG2) + 1),
            bitlength: 0usize,
        }
    }

    /// Discards all written bits, keeping the allocated storage.
    pub fn clear(&mut self) {
        self.storage.clear();
        self.bitlength = 0;
    }

    /// Number of bits written so far.
    pub fn len(&self) -> usize {
        self.bitlength
    }

    /// Returns `true` when no bits have been written.
    pub fn is_empty(&self) -> bool {
        self.bitlength == 0
    }

    /// Reserves room for at least `additional_in_bits` more bits.
    pub fn reserve(&mut self, additional_in_bits: usize) {
        self.storage
            .reserve((additional_in_bits >> S::BITS_LOG2) + 1);
    }

    /// Number of unused bits at the end of the last storage word.
    #[inline]
    const fn paddings(&self) -> usize {
        // (-bitlength) mod S::BITS, computed without signed arithmetic.
        ((!self.bitlength).wrapping_add(1)) & (S::BITS - 1)
    }

    /// Number of bits needed to reach the next byte boundary.
    #[inline]
    const fn paddings_to_byte(&self) -> usize {
        ((!self.bitlength).wrapping_add(1)) & 7
    }

    /// Renders the contents as a binary string, for debugging and tests.
    ///
    /// Each storage word is printed big-endian, and words are separated by
    /// `_`. Unused trailing bits of the last word are shown as `*`.
    pub fn to_bitstring(&self) -> String {
        let pad = self.paddings();
        let mut ret = String::with_capacity(self.storage.len() * (S::BITS + 1));
        for v in &self.storage {
            for b in v.to_be_bytes().as_ref() {
                write!(ret, "{b:08b}").unwrap();
            }
            ret.push('_');
        }
        // Drop the trailing separator, then mask the padding bits.
        ret.pop();
        ret.truncate(ret.len().saturating_sub(pad));
        ret.extend(std::iter::repeat('*').take(pad));
        ret
    }

    /// Consumes the sink and returns its storage words.
    ///
    /// The last word may contain unused (padding) bits.
    #[inline]
    pub fn into_inner(self) -> Vec<S> {
        self.storage
    }

    /// Borrows the storage words.
    #[inline]
    pub fn as_slice(&self) -> &[S] {
        &self.storage
    }

    /// Copies the storage words, big-endian, into `dest`.
    ///
    /// Copying stops once `dest` is full, so trailing bytes that do not fit
    /// are dropped. If `dest` is longer than the storage, its excess bytes are
    /// left untouched. The original index arithmetic could underflow (and
    /// panic) when `dest` was shorter than the storage by more than one word.
    pub fn write_to_byte_slice(&self, dest: &mut [u8]) {
        let bytes_per_elem = std::mem::size_of::<S>();
        for (chunk, v) in dest.chunks_mut(bytes_per_elem).zip(&self.storage) {
            let bytes = v.to_be_bytes();
            // Only the final chunk can be shorter than a whole word.
            chunk.copy_from_slice(&bytes.as_ref()[..chunk.len()]);
        }
    }
}
impl BitSink for MemSink<u8> {
type Error = Infallible;
#[inline]
fn write<T: Bits>(&mut self, val: T) -> Result<(), Self::Error> {
let nbitlength = self.bitlength + 8 * std::mem::size_of::<T>();
let tail = self.paddings();
if tail > 0 {
self.write_msbs(val, tail)?;
}
let val = val << tail;
let bytes: T::Bytes = val.to_be_bytes();
self.storage.extend_from_slice(bytes.as_ref());
self.bitlength = nbitlength;
Ok(())
}
#[inline]
fn align_to_byte(&mut self) -> Result<usize, Self::Error> {
let r = self.paddings();
self.bitlength += r;
Ok(r)
}
#[inline]
fn write_bytes_aligned(&mut self, bytes: &[u8]) -> Result<usize, Self::Error> {
let ret = self.align_to_byte()?;
self.storage.extend_from_slice(bytes);
self.bitlength += 8 * bytes.len();
Ok(ret)
}
#[inline]
fn write_msbs<T: Bits>(&mut self, mut val: T, mut n: usize) -> Result<(), Self::Error> {
if n == 0 {
return Ok(());
}
let r = self.paddings();
self.bitlength += n;
val &= !((T::one() << (T::BITS - n)) - T::one());
if r != 0 {
let b = (val >> (T::BITS - r)).as_();
*self.storage.last_mut().unwrap() |= b;
val <<= r;
if r >= n {
return Ok(());
}
n -= r;
}
let bytes_to_write = n >> 3;
if bytes_to_write > 0 {
let bytes = val.to_ne_bytes();
let bytes = bytes.as_ref();
#[cfg(target_endian = "little")]
{
for i in 0..bytes_to_write {
self.storage.push(bytes[std::mem::size_of::<T>() - i - 1]);
}
}
#[cfg(target_endian = "big")]
{
for i in 0..bytes_to_write {
self.storage.push(bytes[i]);
}
}
n &= 7;
}
if n > 0 {
val <<= bytes_to_write << 3;
let tail_byte: u8 = (val >> (T::BITS - 8)).as_();
self.storage.push(tail_byte);
}
Ok(())
}
#[inline]
fn write_lsbs<T: Bits>(&mut self, val: T, n: usize) -> Result<(), Self::Error> {
if n == 0 {
return Ok(());
}
self.write_msbs(val << (T::BITS - n), n)
}
#[inline]
fn write_zeros(&mut self, n: usize) -> Result<(), Self::Error> {
let pad = self.paddings();
if n <= pad {
self.bitlength += n;
return Ok(());
}
self.bitlength += pad;
let n = n - pad;
let bytes = (n + 7) >> 3;
self.storage.resize(self.storage.len() + bytes, 0u8);
self.bitlength += n;
Ok(())
}
}
impl MemSink<u64> {
    /// Appends the `n` most-significant bits of `val` to the stream.
    ///
    /// The caller must already have cleared every bit of `val` below the top
    /// `n` (checked by a `debug_assert!`). `n == 0` is a no-op. Without the
    /// early return, the assertion would shift by `T::BITS`, and release
    /// builds would OR unmasked bits into the last word.
    #[inline]
    fn write_msbs_impl<T: Bits>(&mut self, val: T, n: usize) {
        if n == 0 {
            return;
        }
        debug_assert!((val >> (T::BITS - n)) << (T::BITS - n) == val);
        let r = self.paddings();
        self.bitlength += n;
        let mut val: u64 = val.into();
        // Left-justify `val` within a 64-bit word.
        val <<= 64 - T::BITS;
        // Top `r` bits go into the unused tail of the current last word.
        let last_setter = val.wrapping_shr(64u32 - r as u32);
        val = val.wrapping_shl(r as u32);
        if r != 0 {
            if let Some(p) = self.storage.last_mut() {
                *p |= last_setter;
            }
        }
        // Whatever did not fit starts a new word.
        if r < n {
            self.storage.push(val);
        }
    }
}
/// Sink backed by 64-bit words.
impl BitSink for MemSink<u64> {
    type Error = Infallible;

    /// Appends every bit of `val`.
    #[inline]
    fn write<T: Bits>(&mut self, val: T) -> Result<(), Self::Error> {
        self.write_msbs(val, T::BITS)
    }

    /// Skips to the next byte boundary (which never crosses a word boundary)
    /// and returns how many bits were skipped.
    #[inline]
    fn align_to_byte(&mut self) -> Result<usize, Self::Error> {
        let r = self.paddings_to_byte();
        self.bitlength += r;
        Ok(r)
    }

    /// Aligns to a byte boundary, then appends `bytes` one at a time.
    #[inline]
    fn write_bytes_aligned(&mut self, bytes: &[u8]) -> Result<usize, Self::Error> {
        let r = self.align_to_byte()?;
        for b in bytes {
            self.write(*b)?;
        }
        Ok(r)
    }

    /// Appends the `n` most-significant bits of `val`; `n == 0` is a no-op.
    ///
    /// The zero-width guard matches the byte sink. Without it, the mask below
    /// shifts by `T::BITS`, which panics in debug builds and leaves `val`
    /// unmasked in release builds.
    #[inline]
    fn write_msbs<T: Bits>(&mut self, val: T, n: usize) -> Result<(), Self::Error> {
        if n == 0 {
            return Ok(());
        }
        let mut val = val;
        val &= !((T::one() << (T::BITS - n)) - T::one());
        self.write_msbs_impl(val, n);
        Ok(())
    }

    /// Appends the `n` least-significant bits of `val`; `n == 0` is a no-op.
    #[inline]
    fn write_lsbs<T: Bits>(&mut self, val: T, n: usize) -> Result<(), Self::Error> {
        if n == 0 {
            // `val << T::BITS` would be an out-of-range shift.
            return Ok(());
        }
        self.write_msbs_impl(val << (T::BITS - n), n);
        Ok(())
    }

    /// Appends `n` zero bits. It uses the (already zero) padding of the last
    /// word first, then pushes fresh zeroed words as needed.
    #[inline]
    fn write_zeros(&mut self, n: usize) -> Result<(), Self::Error> {
        let pad = self.paddings();
        self.bitlength += n;
        let n = n.saturating_sub(pad);
        let elems: usize =
            (n + <u64 as seal_bits::Sealed>::BITS - 1) >> <u64 as seal_bits::Sealed>::BITS_LOG2;
        if elems > 0 {
            self.storage.resize(self.storage.len() + elems, 0u64);
        }
        Ok(())
    }
}
mod seal_bits {
    use num_traits::{AsPrimitive, One, PrimInt, ToBytes, WrappingShl};

    /// Private supertrait that seals [`super::Bits`].
    ///
    /// It bundles the integer operations the bit sinks need and derives the
    /// bit/byte width constants from the type's size.
    pub trait Sealed:
        std::ops::BitAndAssign
        + std::ops::ShlAssign<usize>
        + From<u8>
        + Into<u64>
        + AsPrimitive<u8>
        + One
        + PrimInt
        + ToBytes
        + WrappingShl
    {
        /// Base-2 logarithm of [`Sealed::BITS`].
        #[rustversion::since(1.67)]
        const BITS_LOG2: usize = (8 * std::mem::size_of::<Self>()).ilog2() as usize;
        /// Fallback for toolchains without `ilog2` (stabilized in 1.67).
        #[rustversion::before(1.67)]
        const BITS_LOG2: usize = 3 + std::mem::size_of::<Self>().trailing_zeros() as usize;
        /// Width of the type in bits.
        const BITS: usize = 1usize << Self::BITS_LOG2;
        /// Width of the type in bytes.
        const BYTES: usize = Self::BITS >> 3;
    }

    impl Sealed for u8 {}
    impl Sealed for u16 {}
    impl Sealed for u32 {}
    impl Sealed for u64 {}
}
mod seal_signed_bits {
    /// Private supertrait that seals [`super::SignedBits`]: signed
    /// primitives that widen losslessly into `i64`.
    pub trait Sealed: Into<i64> {}

    macro_rules! seal {
        ($($ty:ty),+ $(,)?) => {
            $(impl Sealed for $ty {})+
        };
    }

    seal!(i8, i16, i32, i64);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn align_to_byte_with_bitvec() -> Result<(), Infallible> {
        let mut sink: MemSink<u8> = MemSink::new();
        sink.write_lsbs(0x01u8, 1)?;
        sink.align_to_byte()?;
        assert_eq!(sink.len(), 8);
        sink.align_to_byte()?;
        assert_eq!(sink.len(), 8);
        sink.write_lsbs(0x01u8, 2)?;
        assert_eq!(sink.len(), 10);
        sink.align_to_byte()?;
        assert_eq!(sink.len(), 16);
        Ok(())
    }

    #[test]
    fn twoc_writing() -> Result<(), Infallible> {
        let mut sink: MemSink<u8> = MemSink::new();
        sink.write_twoc(-7, 4)?;
        assert_eq!(sink.to_bitstring(), "1001****");
        Ok(())
    }

    #[test]
    fn bytevec_write_msb() -> Result<(), Infallible> {
        let mut bv = ByteSink::new();
        bv.write_msbs(0xFFu8, 3)?;
        bv.write_msbs(0x0u64, 12)?;
        bv.write_msbs(0xFFFF_FFFFu32, 9)?;
        bv.write_msbs(0x0u16, 8)?;
        assert_eq!(bv.to_bitstring(), "11100000_00000001_11111111_00000000");
        let mut bv = ByteSink::new();
        bv.write_msbs(0xA0u8, 3)?;
        assert_eq!(bv.to_bitstring(), "101*****");
        let mut bv = ByteSink::new();
        bv.write_msbs(0x00u8, 2)?;
        bv.write_msbs(0xFFu8, 3)?;
        bv.write_msbs(0x00u8, 2)?;
        assert_eq!(bv.to_bitstring(), "0011100*");
        Ok(())
    }

    #[test]
    fn bytevec_zero_length_writes() -> Result<(), Infallible> {
        let mut bv = ByteSink::new();
        bv.write_msbs(0xFFu8, 0)?;
        bv.write_lsbs(0xFFFFu16, 0)?;
        assert!(bv.is_empty());
        bv.write_msbs(0xFFu8, 3)?;
        bv.write_lsbs(0xFFu8, 0)?;
        bv.write_msbs(0xFFu8, 0)?;
        assert_eq!(bv.to_bitstring(), "111*****");
        Ok(())
    }

    #[test]
    fn u64vec_write_msb() -> Result<(), Infallible> {
        let mut u64v = MemSink::<u64>::new();
        u64v.write_msbs(0xFFu8, 3)?;
        assert_eq!(
            u64v.to_bitstring(),
            "111*************************************************************"
        );
        u64v.write_msbs(0u16, 15)?;
        assert_eq!(
            u64v.to_bitstring(),
            "111000000000000000**********************************************"
        );
        u64v.write_msbs(0u64.wrapping_sub(1u64), 45)?;
        assert_eq!(
            u64v.to_bitstring(),
            "111000000000000000111111111111111111111111111111111111111111111*"
        );
        u64v.write_msbs(0xAAAA_AAAA_AAAA_AAAAu64, 60)?;
        assert_eq!(
            u64v.to_bitstring(),
            concat!(
                "1110000000000000001111111111111111111111111111111111111111111111_",
                "01010101010101010101010101010101010101010101010101010101010*****"
            )
        );
        u64v.align_to_byte()?;
        assert_eq!(
            u64v.to_bitstring(),
            concat!(
                "1110000000000000001111111111111111111111111111111111111111111111_",
                "0101010101010101010101010101010101010101010101010101010101000000"
            )
        );
        u64v.write_msbs(0xAAAA_AAAA_AAAA_AAAAu64, 60)?;
        assert_eq!(
            u64v.to_bitstring(),
            concat!(
                "1110000000000000001111111111111111111111111111111111111111111111_",
                "0101010101010101010101010101010101010101010101010101010101000000_",
                "101010101010101010101010101010101010101010101010101010101010****",
            )
        );
        Ok(())
    }

    #[test]
    fn bytevec_write_lsb() -> Result<(), Infallible> {
        let mut bv = ByteSink::new();
        bv.write_lsbs(0xFFu8, 3)?;
        bv.write_lsbs(0x0u64, 12)?;
        bv.write_lsbs(0xFFFF_FFFFu32, 9)?;
        bv.write_lsbs(0x0u16, 8)?;
        assert_eq!(bv.to_bitstring(), "11100000_00000001_11111111_00000000");
        let mut bv = ByteSink::new();
        bv.write_lsbs(0xFFu8, 3)?;
        bv.write_lsbs(0x0u64, 12)?;
        bv.write_lsbs(0xFFFF_FFFFu32, 9)?;
        bv.write_lsbs(0x0u16, 5)?;
        assert_eq!(bv.to_bitstring(), "11100000_00000001_11111111_00000***");
        Ok(())
    }

    // Previously this test called `write_msbs`, so `write_lsbs` on the u64
    // sink was never exercised here.
    #[test]
    fn u64vec_write_lsb() -> Result<(), Infallible> {
        let mut u64v = MemSink::<u64>::new();
        u64v.write_lsbs(0xFFu8, 3)?;
        assert_eq!(
            u64v.to_bitstring(),
            "111*************************************************************"
        );
        u64v.write_lsbs(0u16, 15)?;
        assert_eq!(
            u64v.to_bitstring(),
            "111000000000000000**********************************************"
        );
        // Only the low four bits (0b0101) of this value must be written.
        u64v.write_lsbs(0xFFFF_FFF5u32, 4)?;
        assert_eq!(
            u64v.to_bitstring(),
            format!(
                "{}{}",
                concat!("111", "000000000000000", "0101"),
                "*".repeat(42)
            )
        );
        Ok(())
    }

    #[test]
    fn u64vec_write_zeros() -> Result<(), Infallible> {
        let mut u64v = MemSink::<u64>::new();
        u64v.write_lsbs(0xFFu8, 3)?;
        assert_eq!(
            u64v.to_bitstring(),
            "111*************************************************************"
        );
        u64v.write_zeros(15)?;
        assert_eq!(
            u64v.to_bitstring(),
            "111000000000000000**********************************************"
        );
        Ok(())
    }

    #[test]
    fn u64vec_write_to_byte_slice() -> Result<(), Infallible> {
        let mut sink = MemSink::<u64>::new();
        sink.write_bytes_aligned(&[0x01, 0x02, 0x03])?;
        let mut dest = [0u8; 3];
        sink.write_to_byte_slice(&mut dest);
        assert_eq!(dest, [0x01, 0x02, 0x03]);
        let mut wide = [0xEEu8; 10];
        sink.write_to_byte_slice(&mut wide);
        assert_eq!(wide, [0x01, 0x02, 0x03, 0, 0, 0, 0, 0, 0xEE, 0xEE]);
        Ok(())
    }

    #[test]
    fn u64vec() -> Result<(), Infallible> {
        let mut sink = MemSink::<u64>::new();
        sink.write_msbs(0xFFFF_FFFFu32, 17)?;
        assert_eq!(
            sink.to_bitstring(),
            "11111111111111111***********************************************"
        );
        assert_eq!(sink.len(), 17);
        sink.write_bytes_aligned(&[0xCA, 0xFE])?;
        assert_eq!(
            sink.to_bitstring(),
            "1111111111111111100000001100101011111110************************"
        );
        assert_eq!(sink.len(), 40);
        sink.write_lsbs(1u16, 2)?;
        assert_eq!(
            sink.to_bitstring(),
            "111111111111111110000000110010101111111001**********************"
        );
        assert_eq!(sink.len(), 42);
        sink.write_lsbs(0xAAAA_AAAAu32, 31)?;
        assert_eq!(
            sink.to_bitstring(),
            concat!(
                "1111111111111111100000001100101011111110010101010101010101010101_",
                "010101010*******************************************************"
            )
        );
        assert_eq!(sink.len(), 73);
        Ok(())
    }
}
#[cfg(all(test, feature = "simd-nightly"))]
mod bench {
    use super::*;

    extern crate test;

    use test::bench::Bencher;
    use test::black_box;

    /// Measures `write_to_byte_slice` on a `u64`-backed sink holding 8191 bytes.
    #[bench]
    fn u64sink_to_byte(b: &mut Bencher) {
        let mut memsink = MemSink::<u64>::new();
        (0..8191usize).for_each(|t| {
            memsink
                .write_bytes_aligned(&[(t % 256) as u8])
                .expect("should never fail.");
        });
        let mut bytes = [0u8; 8191];
        b.iter(|| black_box(&mut memsink).write_to_byte_slice(&mut bytes));
    }
}