use core::iter;
use std::borrow::Cow;
use bitvec::{mem::bits_of, order::Msb0, slice::BitSlice, vec::BitVec, view::AsMutBits};
use impl_tools::autoimpl;
use crate::{
Context, Error, StringError,
adapters::{Checkpoint, Join, MapErr, Tee},
ser::BitWriter,
};
use super::{BitUnpack, r#as::BitUnpackAs};
/// A source of bits that values can be deserialized from.
///
/// Implementors must provide [`bits_left`](BitReader::bits_left) and
/// [`read_bit`](BitReader::read_bit). The remaining methods have default
/// implementations built on `read_bit` that process one bit at a time;
/// implementors may override them with something faster.
///
/// The `autoimpl` attribute also forwards this trait through `&mut R` and
/// `Box<R>`.
#[autoimpl(for <R: trait + ?Sized> &mut R, Box<R>)]
pub trait BitReader<'de> {
    /// Error type produced by this reader.
    type Error: Error;

    /// Returns how many bits this reader reports as still available.
    fn bits_left(&self) -> usize;

    /// Reads a single bit.
    ///
    /// Returns `Ok(None)` when there is nothing left to read.
    fn read_bit(&mut self) -> Result<Option<bool>, Self::Error>;

    /// Fills `dst` with bits taken from the reader, front to back.
    ///
    /// Returns the number of bits actually written, which is less than
    /// `dst.len()` if the reader ran out of input first. Bits of `dst` past
    /// that count are left untouched.
    #[inline]
    fn read_bits_into(&mut self, dst: &mut BitSlice<u8, Msb0>) -> Result<usize, Self::Error> {
        for (i, mut bit) in dst.iter_mut().enumerate() {
            let Some(read) = self.read_bit()? else {
                // Input exhausted: exactly `i` bits were stored before this one.
                return Ok(i);
            };
            *bit = read;
        }
        Ok(dst.len())
    }

    /// Reads up to `n` bits and returns them.
    ///
    /// The result is shorter than `n` bits if the reader ran out of input.
    /// This default implementation always returns an owned buffer.
    #[inline]
    fn read_bits(&mut self, mut n: usize) -> Result<Cow<'de, BitSlice<u8, Msb0>>, Self::Error> {
        let mut buf = BitVec::repeat(false, n);
        n = self.read_bits_into(&mut buf)?;
        // Drop the zero-initialized tail that was never filled.
        buf.truncate(n);
        Ok(Cow::Owned(buf))
    }

    /// Reads and discards up to `n` bits.
    ///
    /// Returns the number of bits actually discarded, which is less than `n`
    /// if the reader ran out of input first.
    #[inline]
    fn skip(&mut self, n: usize) -> Result<usize, Self::Error> {
        // Count from zero so that, when the read fails, the loop variable is
        // exactly the number of bits consumed so far. This keeps the default
        // consistent with the slice implementations, which report
        // `n.min(bits_left())`.
        for skipped in 0..n {
            if self.read_bit()?.is_none() {
                return Ok(skipped);
            }
        }
        Ok(n)
    }
}
pub trait BitReaderExt<'de>: BitReader<'de> {
#[inline]
fn is_empty(&self) -> bool {
self.bits_left() == 0
}
#[inline]
fn read_bytes_into(&mut self, mut dst: impl AsMut<[u8]>) -> Result<usize, Self::Error> {
self.read_bits_into(dst.as_mut_bits())
}
#[inline]
fn read_bytes_array<const N: usize>(&mut self) -> Result<[u8; N], Self::Error> {
let mut arr = [0; N];
let n = self.read_bits_into(arr.as_mut_bits())?;
if n != N * bits_of::<u8>() {
return Err(Error::custom("EOF"));
}
Ok(arr)
}
#[inline]
fn unpack<T>(&mut self, args: T::Args) -> Result<T, Self::Error>
where
T: BitUnpack<'de>,
{
T::unpack(self, args)
}
#[inline]
fn unpack_iter<'a, T>(
&'a mut self,
args: T::Args,
) -> impl Iterator<Item = Result<T, Self::Error>> + 'a
where
T: BitUnpack<'de>,
T::Args: Clone + 'a,
{
iter::repeat_with(move || self.unpack::<T>(args.clone()))
.enumerate()
.map(|(i, v)| v.with_context(|| format!("[{i}]")))
}
#[inline]
fn unpack_as<T, As>(&mut self, args: As::Args) -> Result<T, Self::Error>
where
As: BitUnpackAs<'de, T> + ?Sized,
{
As::unpack_as(self, args)
}
#[inline]
fn unpack_iter_as<'a, T, As>(
&'a mut self,
args: As::Args,
) -> impl Iterator<Item = Result<T, Self::Error>> + 'a
where
As: BitUnpackAs<'de, T> + ?Sized,
As::Args: Clone + 'a,
{
iter::repeat_with(move || self.unpack_as::<_, As>(args.clone()))
.enumerate()
.map(|(i, v)| v.with_context(|| format!("[{i}]")))
}
#[inline]
fn as_mut(&mut self) -> &mut Self {
self
}
#[inline]
fn map_err<F>(self, f: F) -> MapErr<Self, F>
where
Self: Sized,
{
MapErr { inner: self, f }
}
#[inline]
fn tee<W>(self, writer: W) -> Tee<Self, W>
where
Self: Sized,
W: BitWriter,
{
Tee::new(self, writer)
}
#[inline]
fn checkpoint(self) -> Checkpoint<Self>
where
Self: Sized,
{
Checkpoint::new(self)
}
#[inline]
fn join<R>(self, next: R) -> Join<Self, R>
where
Self: Sized,
R: BitReader<'de>,
{
Join::new(self, next)
}
}
/// Blanket implementation: every [`BitReader`], sized or not, is a
/// [`BitReaderExt`].
impl<'de, R: BitReader<'de> + ?Sized> BitReaderExt<'de> for R {}
/// Reads bits from the front of a borrowed bit slice.
///
/// Every read re-points the slice at the part that has not been consumed yet,
/// so the slice itself is the cursor.
impl<'de> BitReader<'de> for &'de BitSlice<u8, Msb0> {
    type Error = StringError;

    /// Remaining length of the slice, in bits.
    #[inline]
    fn bits_left(&self) -> usize {
        self.len()
    }

    /// Pops the first bit, or returns `Ok(None)` when the slice is empty.
    #[inline]
    fn read_bit(&mut self) -> Result<Option<bool>, Self::Error> {
        match self.split_first() {
            Some((head, tail)) => {
                *self = tail;
                Ok(Some(*head))
            }
            None => Ok(None),
        }
    }

    /// Copies as many bits as are both available and fit into the front of
    /// `dst`, returning how many were copied.
    #[inline]
    fn read_bits_into(&mut self, dst: &mut BitSlice<u8, Msb0>) -> Result<usize, Self::Error> {
        // Clamp so that neither the split nor the destination index can go
        // out of bounds.
        let count = usize::min(dst.len(), self.bits_left());
        let (head, tail) = self.split_at(count);
        dst[..count].copy_from_bitslice(head);
        *self = tail;
        Ok(count)
    }

    /// Splits off up to `n` leading bits and returns them borrowed from the
    /// underlying slice, without copying.
    #[inline]
    fn read_bits(&mut self, n: usize) -> Result<Cow<'de, BitSlice<u8, Msb0>>, Self::Error> {
        let count = n.min(self.bits_left());
        let (head, tail) = self.split_at(count);
        *self = tail;
        Ok(Cow::Borrowed(head))
    }

    /// Drops up to `n` leading bits and returns how many were dropped.
    #[inline]
    fn skip(&mut self, n: usize) -> Result<usize, Self::Error> {
        let count = n.min(self.bits_left());
        let (_, tail) = self.split_at(count);
        *self = tail;
        Ok(count)
    }
}
/// Reads bits from the front of a borrowed slice of `bool`s, one element per
/// bit.
///
/// Every read re-points the slice at the part that has not been consumed yet.
impl<'de> BitReader<'de> for &[bool] {
    type Error = StringError;

    /// Number of elements remaining in the slice.
    #[inline]
    fn bits_left(&self) -> usize {
        self.len()
    }

    /// Pops the first element, or returns `Ok(None)` when the slice is empty.
    #[inline]
    fn read_bit(&mut self) -> Result<Option<bool>, Self::Error> {
        match self.split_first() {
            Some((&head, tail)) => {
                *self = tail;
                Ok(Some(head))
            }
            None => Ok(None),
        }
    }

    /// Drops up to `n` leading elements and returns how many were dropped.
    #[inline]
    fn skip(&mut self, n: usize) -> Result<usize, Self::Error> {
        // Clamp so the split below can never go past the end of the slice.
        let count = n.min(self.bits_left());
        let (_, tail) = self.split_at(count);
        *self = tail;
        Ok(count)
    }
}
/// Reads bits from a textual binary string, where each character must be
/// either `'0'` or `'1'`.
///
/// Every successful read re-points the string at the part that has not been
/// consumed yet.
impl<'de> BitReader<'de> for &str {
    type Error = StringError;

    /// Remaining length of the string in bytes.
    ///
    /// For valid input (only ASCII `'0'` / `'1'`) this equals the number of
    /// bits left.
    #[inline]
    fn bits_left(&self) -> usize {
        self.len()
    }

    /// Consumes the first character and converts it to a bit.
    ///
    /// Returns `Ok(None)` only when the string is empty.
    ///
    /// # Errors
    ///
    /// Fails when the first character is anything other than `'0'` or `'1'`,
    /// including multi-byte characters. The string is left untouched in that
    /// case.
    #[inline]
    fn read_bit(&mut self) -> Result<Option<bool>, Self::Error> {
        // Decode a whole `char` instead of splitting at byte offset 1.
        // A byte split fails on a multi-byte leading character, which would be
        // indistinguishable from end of input.
        let input: &str = *self;
        let mut chars = input.chars();
        let Some(c) = chars.next() else {
            return Ok(None);
        };
        let bit = match c {
            '0' => false,
            '1' => true,
            _ => {
                return Err(Error::custom(format!(
                    "invalid character: expected '0' or '1', got: {c}",
                )));
            }
        };
        // Advance only after the character has been validated.
        *self = chars.as_str();
        Ok(Some(bit))
    }
}