use std::io::{self, ErrorKind, Read, Seek, SeekFrom, Write};
#[cfg(feature = "async")]
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
#[cfg(feature = "async")]
use crate::async_io::{AsyncRead, AsyncSeek, AsyncWrite};
/// Message carried by the `InvalidInput` error returned when a byte-level operation is
/// attempted while a bit reader or writer is not positioned on a byte boundary.
pub const INVALID_ALIGNMENT_MESSAGE: &str = "invalid alignment";
/// Message carried by the `InvalidInput` error returned when `write_bits` is asked to
/// emit more bits than the supplied buffer contains.
pub const INVALID_BIT_WIDTH_MESSAGE: &str = "bit width exceeds input buffer";
#[derive(Debug)]
pub struct BitReader<R> {
inner: R,
octet: u8,
remaining_bits: u8,
}
impl<R> BitReader<R> {
pub const fn new(inner: R) -> Self {
Self {
inner,
octet: 0,
remaining_bits: 0,
}
}
pub const fn is_aligned(&self) -> bool {
self.remaining_bits == 0
}
}
impl<R: Read> BitReader<R> {
pub fn read_bits(&mut self, width: usize) -> io::Result<Vec<u8>> {
let byte_len = width.div_ceil(8);
let bit_offset = (byte_len * 8) - width;
let mut data = vec![0_u8; byte_len];
for index in 0..width {
if self.read_bit()? {
let bit_index = bit_offset + index;
let byte_index = bit_index / 8;
let within_byte = 7 - (bit_index % 8);
data[byte_index] |= 1 << within_byte;
}
}
Ok(data)
}
pub fn read_bit(&mut self) -> io::Result<bool> {
if self.remaining_bits == 0 {
let mut buf = [0_u8; 1];
let read = self.inner.read(&mut buf)?;
if read == 0 {
return Err(io::Error::new(
ErrorKind::UnexpectedEof,
"failed to fill whole buffer",
));
}
self.octet = buf[0];
self.remaining_bits = 8;
}
self.remaining_bits -= 1;
Ok((self.octet >> self.remaining_bits) & 0x01 != 0)
}
}
impl<R: Read> Read for BitReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if !self.is_aligned() {
return Err(invalid_alignment());
}
self.inner.read(buf)
}
}
impl<R: Read + Seek> Seek for BitReader<R> {
fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
if matches!(pos, SeekFrom::Current(_)) && !self.is_aligned() {
return Err(invalid_alignment());
}
let next = self.inner.seek(pos)?;
self.remaining_bits = 0;
Ok(next)
}
}
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
/// Asynchronous bit reader: yields individual bits, most significant bit of each byte
/// first, from an underlying async byte source.
///
/// `read_exact`, `stream_position`, and seeks relative to the current position are
/// rejected with an `InvalidInput` error while a byte is only partially consumed.
#[derive(Debug)]
pub struct AsyncBitReader<R> {
    /// Underlying async byte source.
    inner: R,
    /// Byte currently being consumed bit by bit.
    octet: u8,
    /// Number of bits of `octet` not yet returned (0 means byte-aligned).
    remaining_bits: u8,
}

#[cfg(feature = "async")]
impl<R> AsyncBitReader<R> {
    /// Wraps `inner`, starting in the byte-aligned state.
    pub const fn new(inner: R) -> Self {
        Self {
            remaining_bits: 0,
            octet: 0,
            inner,
        }
    }

    /// Returns `true` when no partially consumed byte is pending.
    pub const fn is_aligned(&self) -> bool {
        matches!(self.remaining_bits, 0)
    }
}

#[cfg(feature = "async")]
impl<R: AsyncRead + Unpin> AsyncBitReader<R> {
    /// Reads `width` bits and returns them right-aligned in `width.div_ceil(8)`
    /// big-endian bytes (high bits of the first byte zero-padded).
    ///
    /// # Errors
    ///
    /// Propagates any error from `read_bit`; bits consumed before it are not returned.
    pub async fn read_bits(&mut self, width: usize) -> io::Result<Vec<u8>> {
        let mut data = vec![0_u8; width.div_ceil(8)];
        // Number of leading zero bits that right-align the value inside `data`.
        let pad = data.len() * 8 - width;
        for position in pad..pad + width {
            if self.read_bit().await? {
                data[position / 8] |= 0x80 >> (position % 8);
            }
        }
        Ok(data)
    }

    /// Reads the next bit, refilling from the inner reader with `read_exact` when the
    /// current byte is used up.
    pub async fn read_bit(&mut self) -> io::Result<bool> {
        if self.is_aligned() {
            let mut byte = [0_u8];
            self.inner.read_exact(&mut byte).await?;
            self.octet = byte[0];
            self.remaining_bits = 8;
        }
        self.remaining_bits -= 1;
        let mask = 1 << self.remaining_bits;
        Ok(self.octet & mask != 0)
    }

    /// Fills `buf` with whole bytes from the inner reader; requires byte alignment.
    pub async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        if self.is_aligned() {
            self.inner.read_exact(buf).await?;
            Ok(())
        } else {
            Err(invalid_alignment())
        }
    }
}

#[cfg(feature = "async")]
impl<R: AsyncRead + AsyncSeek + Unpin> AsyncBitReader<R> {
    /// Returns the inner stream position; requires byte alignment.
    pub async fn stream_position(&mut self) -> io::Result<u64> {
        if self.is_aligned() {
            self.inner.stream_position().await
        } else {
            Err(invalid_alignment())
        }
    }

    /// Seeks the inner reader and resets to the byte-aligned state.
    ///
    /// `SeekFrom::Current` is rejected while unaligned; other seeks discard any
    /// pending bits.
    pub async fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        let relative = matches!(pos, SeekFrom::Current(_));
        if relative && !self.is_aligned() {
            return Err(invalid_alignment());
        }
        let offset = self.inner.seek(pos).await?;
        self.remaining_bits = 0;
        Ok(offset)
    }
}
#[derive(Debug)]
pub struct BitWriter<W> {
inner: W,
octet: u8,
written_bits: u8,
}
impl<W> BitWriter<W> {
pub const fn new(inner: W) -> Self {
Self {
inner,
octet: 0,
written_bits: 0,
}
}
pub const fn is_aligned(&self) -> bool {
self.written_bits == 0
}
pub fn into_inner(self) -> io::Result<W> {
if !self.is_aligned() {
return Err(invalid_alignment());
}
Ok(self.inner)
}
}
impl<W: Write> BitWriter<W> {
pub fn write_bits(&mut self, data: &[u8], width: usize) -> io::Result<()> {
let total_bits = data.len() * 8;
if width > total_bits {
return Err(io::Error::new(
ErrorKind::InvalidInput,
INVALID_BIT_WIDTH_MESSAGE,
));
}
for index in (total_bits - width)..total_bits {
let byte_index = index / 8;
let within_byte = 7 - (index % 8);
self.write_bit((data[byte_index] >> within_byte) & 0x01 != 0)?;
}
Ok(())
}
pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
if bit {
self.octet |= 1 << (7 - self.written_bits);
}
self.written_bits += 1;
if self.written_bits == 8 {
self.inner.write_all(&[self.octet])?;
self.octet = 0;
self.written_bits = 0;
}
Ok(())
}
}
impl<W: Write> Write for BitWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if !self.is_aligned() {
return Err(invalid_alignment());
}
self.inner.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
}
}
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
/// Asynchronous bit writer: accumulates bits, most significant bit of each byte first,
/// and writes each completed byte to an async byte sink.
///
/// `write_all` and `into_inner` require the writer to be byte-aligned.
#[derive(Debug)]
pub struct AsyncBitWriter<W> {
    /// Underlying async byte sink.
    inner: W,
    /// Partially assembled byte; bits are filled from the most significant end.
    octet: u8,
    /// Number of bits already stored in `octet` (always `< 8` between calls).
    written_bits: u8,
}

#[cfg(feature = "async")]
impl<W> AsyncBitWriter<W> {
    /// Wraps `inner`, starting in the byte-aligned state.
    pub const fn new(inner: W) -> Self {
        Self {
            inner,
            octet: 0,
            written_bits: 0,
        }
    }

    /// Returns `true` when no partially assembled byte is pending.
    pub const fn is_aligned(&self) -> bool {
        self.written_bits == 0
    }

    /// Unwraps the writer, returning the inner sink.
    ///
    /// This needs no I/O, so it lives on the unbounded impl (like the sync writer's).
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error if pending bits have not yet formed a complete
    /// byte. `self` is consumed either way, so on error the inner writer and the pending
    /// bits are dropped.
    pub fn into_inner(self) -> io::Result<W> {
        if !self.is_aligned() {
            return Err(invalid_alignment());
        }
        Ok(self.inner)
    }
}

#[cfg(feature = "async")]
impl<W: AsyncWrite + Unpin> AsyncBitWriter<W> {
    /// Writes the last `width` bits of `data` (big-endian), most significant first.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `width` exceeds `data.len() * 8`, in which case nothing
    /// is written. Errors from the inner writer are propagated; bits emitted before the
    /// failure stay written.
    pub async fn write_bits(&mut self, data: &[u8], width: usize) -> io::Result<()> {
        let byte_len = width.div_ceil(8);
        if byte_len > data.len() {
            return Err(io::Error::new(
                ErrorKind::InvalidInput,
                INVALID_BIT_WIDTH_MESSAGE,
            ));
        }
        // Only the last `byte_len` bytes contribute, and their leading `pad` bits lie
        // outside the requested width. This avoids computing `data.len() * 8`, which can
        // overflow for large slices on 32-bit targets.
        let tail = &data[data.len() - byte_len..];
        let pad = (8 - width % 8) % 8;
        let bits = tail
            .iter()
            .flat_map(|&byte| (0..8_u8).rev().map(move |shift| (byte >> shift) & 0x01 != 0))
            .skip(pad);
        for bit in bits {
            self.write_bit(bit).await?;
        }
        Ok(())
    }

    /// Appends one bit, writing the assembled byte to the inner writer once 8 bits
    /// are pending.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner writer. The writer's own state is updated only
    /// after `write_all` completes, so a failed (or abandoned) call keeps the previously
    /// pending bits and may be retried.
    pub async fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        let octet = self.octet | (u8::from(bit) << (7 - self.written_bits));
        let written_bits = self.written_bits + 1;
        if written_bits == 8 {
            // Updating state before the write would leave `written_bits == 8` on
            // failure and make the next call underflow `7 - written_bits`.
            self.inner.write_all(&[octet]).await?;
            self.octet = 0;
            self.written_bits = 0;
        } else {
            self.octet = octet;
            self.written_bits = written_bits;
        }
        Ok(())
    }

    /// Writes whole bytes to the inner writer; requires byte alignment.
    pub async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        if !self.is_aligned() {
            return Err(invalid_alignment());
        }
        self.inner.write_all(buf).await
    }

    /// Flushes the inner writer. Pending bits of an incomplete byte are NOT emitted.
    pub async fn flush(&mut self) -> io::Result<()> {
        self.inner.flush().await
    }
}
fn invalid_alignment() -> io::Error {
io::Error::new(ErrorKind::InvalidInput, INVALID_ALIGNMENT_MESSAGE)
}