use core::{
fmt::{Binary, LowerHex},
mem::size_of,
};
use bitvec::{mem::bits_of, order::Msb0, vec::BitVec, view::AsBits};
use num_bigint::{BigInt, BigUint};
use num_traits::{PrimInt, ToBytes};
use crate::{
Error,
r#as::{Same, VarLen},
de::{BitReader, BitReaderExt, BitUnpackAs},
ser::{BitPackAs, BitWriter, BitWriterExt},
};
use super::NBits;
/// Serializes a [`BigUint`] as an unsigned integer of exactly `BITS` bits,
/// most significant bit first.
impl<const BITS: usize> BitPackAs<BigUint> for NBits<BITS> {
    type Args = ();

    /// Writes `source` left-padded with zero bits up to `BITS` bits.
    ///
    /// # Errors
    ///
    /// Returns a custom error when `source` has more than `BITS` significant
    /// bits, and forwards any error produced by `writer`.
    #[inline]
    fn pack_as<W>(source: &BigUint, writer: &mut W, _: Self::Args) -> Result<(), W::Error>
    where
        W: BitWriter + ?Sized,
    {
        let significant = source.bits() as usize;
        let padding = match BITS.checked_sub(significant) {
            Some(padding) => padding,
            None => {
                return Err(Error::custom(format!(
                    "{source:#b} cannot be packed into {BITS} bits"
                )));
            }
        };
        writer.repeat_bit(padding, false)?;

        // `to_bytes_be` is byte-aligned: keep only the low `significant` bits,
        // dropping the zero bits above the most significant set bit.
        let raw = source.to_bytes_be();
        let all_bits = raw.as_bits::<Msb0>();
        let payload = &all_bits[all_bits.len() - significant..];
        writer.write_bitslice(payload)?;
        Ok(())
    }
}
/// Deserializes a [`BigUint`] from exactly `BITS` bits, most significant bit first.
impl<'de, const BITS: usize> BitUnpackAs<'de, BigUint> for NBits<BITS> {
    type Args = ();

    /// Reads `BITS` bits into the low end of a zeroed, byte-aligned buffer and
    /// interprets the buffer as a big-endian unsigned integer.
    ///
    /// # Errors
    ///
    /// Forwards any error produced by `reader`.
    #[inline]
    fn unpack_as<R>(reader: &mut R, _: Self::Args) -> Result<BigUint, R::Error>
    where
        R: BitReader<'de> + ?Sized,
    {
        // Number of zero bits needed in front to reach a whole number of bytes.
        let padding = (8 - BITS % 8) % 8;
        let mut buffer = BitVec::<u8, Msb0>::repeat(false, BITS + padding);
        reader.read_bits_into(&mut buffer[padding..])?;
        let value = BigUint::from_bytes_be(buffer.as_raw_slice());
        Ok(value)
    }
}
impl<const BITS: usize> BitPackAs<BigInt> for NBits<BITS> {
type Args = ();
#[inline]
fn pack_as<W>(source: &BigInt, writer: &mut W, _: Self::Args) -> Result<(), W::Error>
where
W: BitWriter + ?Sized,
{
let used_bits = source.bits() as usize;
if BITS < used_bits {
return Err(Error::custom(format!(
"{source:#b} cannot be packed into {BITS} bits"
)));
}
writer.repeat_bit(BITS - used_bits, false)?;
let bytes = source.to_signed_bytes_be();
let mut bits = bytes.as_bits::<Msb0>();
bits = &bits[bits.len() - used_bits..];
writer.write_bitslice(bits)?;
Ok(())
}
}
impl<'de, const BITS: usize> BitUnpackAs<'de, BigInt> for NBits<BITS> {
type Args = ();
#[inline]
fn unpack_as<R>(reader: &mut R, _: Self::Args) -> Result<BigInt, R::Error>
where
R: BitReader<'de> + ?Sized,
{
let total_bits = (BITS + 7) & !7;
let mut bits = BitVec::<u8, Msb0>::repeat(false, total_bits);
reader.read_bits_into(&mut bits[total_bits - BITS..])?;
Ok(BigInt::from_signed_bytes_be(bits.as_raw_slice()))
}
}
/// Marker adapter that (de)serializes [`BigUint`] / [`BigInt`] as a
/// variable-length big-endian byte string.
///
/// The bytes are passed through `VarLen<_, BITS_FOR_BYTES_LEN>`; presumably
/// `BITS_FOR_BYTES_LEN` is the width in bits of the byte-count prefix —
/// confirm against `VarLen`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt<const BITS_FOR_BYTES_LEN: usize>;
/// Serializes a [`BigUint`] as its big-endian bytes wrapped in
/// `VarLen<_, BITS_FOR_BYTES_LEN>`.
impl<const BITS_FOR_BYTES_LEN: usize> BitPackAs<BigUint> for VarInt<BITS_FOR_BYTES_LEN> {
    type Args = ();

    /// Packs the big-endian bytes of `source`; zero is packed as an empty
    /// byte string.
    ///
    /// # Errors
    ///
    /// Forwards any error produced by `writer`.
    #[inline]
    fn pack_as<W>(source: &BigUint, writer: &mut W, _: Self::Args) -> Result<(), W::Error>
    where
        W: BitWriter + ?Sized,
    {
        // `to_bytes_be` yields `[0]` for zero, so zero is special-cased to
        // keep its encoding empty. A value has zero significant bits exactly
        // when it is zero.
        let bytes = match source.bits() {
            0 => Vec::new(),
            _ => source.to_bytes_be(),
        };
        writer.pack_as::<_, VarLen<Vec<Same>, BITS_FOR_BYTES_LEN>>(bytes, ())?;
        Ok(())
    }
}
/// Deserializes a [`BigUint`] from big-endian bytes wrapped in
/// `VarLen<_, BITS_FOR_BYTES_LEN>`.
impl<'de, const BITS_FOR_BYTES_LEN: usize> BitUnpackAs<'de, BigUint>
    for VarInt<BITS_FOR_BYTES_LEN>
{
    type Args = ();

    /// Unpacks the byte string and interprets it as a big-endian unsigned
    /// integer; an empty byte string decodes to zero.
    ///
    /// # Errors
    ///
    /// Forwards any error produced by `reader`.
    #[inline]
    fn unpack_as<R>(reader: &mut R, _: Self::Args) -> Result<BigUint, R::Error>
    where
        R: BitReader<'de> + ?Sized,
    {
        // The payload is already whole bytes, so no bit-level realignment is
        // required before handing it to `from_bytes_be`.
        let bytes: Vec<u8> =
            reader.unpack_as::<_, VarLen<Vec<Same>, BITS_FOR_BYTES_LEN>>(())?;
        Ok(BigUint::from_bytes_be(&bytes))
    }
}
/// Serializes a [`BigInt`] as its big-endian two's-complement bytes wrapped
/// in `VarLen<_, BITS_FOR_BYTES_LEN>`.
impl<const BITS_FOR_BYTES_LEN: usize> BitPackAs<BigInt> for VarInt<BITS_FOR_BYTES_LEN> {
    type Args = ();

    /// Packs the minimal two's-complement bytes of `source`; zero is packed
    /// as an empty byte string, matching the [`BigUint`] implementation.
    ///
    /// # Errors
    ///
    /// Forwards any error produced by `writer`.
    #[inline]
    fn pack_as<W>(source: &BigInt, writer: &mut W, _: Self::Args) -> Result<(), W::Error>
    where
        W: BitWriter + ?Sized,
    {
        // `to_signed_bytes_be` yields `[0]` for zero; emit no bytes instead so
        // that zero has a single, shortest encoding. The decoder maps an
        // empty byte string back to zero.
        let bytes = if source.bits() == 0 {
            Vec::new()
        } else {
            source.to_signed_bytes_be()
        };
        writer.pack_as::<_, VarLen<Same, BITS_FOR_BYTES_LEN>>(bytes, ())?;
        Ok(())
    }
}
/// Deserializes a [`BigInt`] from big-endian two's-complement bytes wrapped
/// in `VarLen<_, BITS_FOR_BYTES_LEN>`.
impl<'de, const BITS_FOR_BYTES_LEN: usize> BitUnpackAs<'de, BigInt> for VarInt<BITS_FOR_BYTES_LEN> {
    type Args = ();

    /// Unpacks the byte string and interprets it as a big-endian
    /// two's-complement integer; an empty byte string decodes to zero.
    ///
    /// # Errors
    ///
    /// Forwards any error produced by `reader`.
    #[inline]
    fn unpack_as<R>(reader: &mut R, _: Self::Args) -> Result<BigInt, R::Error>
    where
        R: BitReader<'de> + ?Sized,
    {
        // The payload is already whole bytes, so no bit-level realignment is
        // required before handing it to `from_signed_bytes_be`.
        let bytes: Vec<u8> = reader.unpack_as::<_, VarLen<Same, BITS_FOR_BYTES_LEN>>(())?;
        Ok(BigInt::from_signed_bytes_be(&bytes))
    }
}
/// Marker adapter that (de)serializes primitive integers using a bit width
/// chosen at runtime and passed as the `u32` argument of
/// `pack_as` / `unpack_as`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarNBits;
impl<T> BitPackAs<T> for VarNBits
where
T: PrimInt + Binary + ToBytes,
{
type Args = u32;
#[inline]
fn pack_as<W>(source: &T, writer: &mut W, num_bits: Self::Args) -> Result<(), W::Error>
where
W: BitWriter + ?Sized,
{
let size_bits: u32 = bits_of::<T>() as u32;
let leading_zeroes = source.leading_zeros();
let used_bits = size_bits - leading_zeroes;
if num_bits < used_bits {
return Err(Error::custom(format!(
"{source:0b} cannot be packed into {num_bits} bits",
)));
}
let arr = source.to_be_bytes();
let bits = arr.as_bits();
writer.write_bitslice(&bits[bits.len() - num_bits as usize..])?;
Ok(())
}
}
/// Deserializes a primitive integer from a caller-chosen number of bits,
/// most significant bit first.
impl<'de, T> BitUnpackAs<'de, T> for VarNBits
where
    T: PrimInt,
{
    type Args = u32;

    /// Reads `num_bits` bits and accumulates them, most significant first,
    /// into a value that starts at zero.
    ///
    /// # Errors
    ///
    /// Returns a custom error when `num_bits` exceeds the bit width of `T`,
    /// and forwards any error produced by `reader`.
    #[inline]
    fn unpack_as<R>(reader: &mut R, num_bits: Self::Args) -> Result<T, R::Error>
    where
        R: BitReader<'de> + ?Sized,
    {
        if num_bits > bits_of::<T>() as u32 {
            return Err(Error::custom("excessive bits for the type"));
        }
        reader
            .unpack_iter::<bool>(())
            .take(num_bits as usize)
            .try_fold(T::zero(), |acc, bit| -> Result<T, R::Error> {
                // Make room for the next bit, then place it in the lowest position.
                let lowest = if bit? { T::one() } else { T::zero() };
                Ok((acc << 1) | lowest)
            })
    }
}
/// Marker adapter that (de)serializes primitive integers using a byte width
/// chosen at runtime and passed as the `u32` argument of
/// `pack_as` / `unpack_as`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarNBytes;
impl<T> BitPackAs<T> for VarNBytes
where
T: PrimInt + LowerHex + ToBytes,
{
type Args = u32;
#[inline]
fn pack_as<W>(source: &T, writer: &mut W, num_bytes: Self::Args) -> Result<(), W::Error>
where
W: BitWriter + ?Sized,
{
let size_bytes: u32 = size_of::<T>() as u32;
let leading_zeroes = source.leading_zeros();
let used_bytes = size_bytes - leading_zeroes / 8;
if num_bytes < used_bytes {
return Err(Error::custom(format!(
"{source:0x} cannot be packed into {num_bytes} bytes",
)));
}
let arr = source.to_be_bytes();
let bytes = arr.as_ref();
writer.write_bitslice((&bytes[bytes.len() - num_bytes as usize..]).as_bits())?;
Ok(())
}
}
impl<'de, T> BitUnpackAs<'de, T> for VarNBytes
where
T: PrimInt,
{
type Args = u32;
#[inline]
fn unpack_as<R>(reader: &mut R, num_bytes: Self::Args) -> Result<T, R::Error>
where
R: BitReader<'de> + ?Sized,
{
let size_bytes: u32 = size_of::<T>() as u32;
if num_bytes > size_bytes {
return Err(Error::custom("excessive bits for type"));
}
let mut v: T = T::zero();
for byte in reader.unpack_iter::<u8>(()).take(num_bytes as usize) {
v = v << 8;
v = v | T::from(byte?).unwrap();
}
Ok(v)
}
}