use core::mem;
use bitvec::{
mem::bits_of,
order::Msb0,
view::{AsBits, AsMutBits},
};
use crate::{
Error,
r#as::{AsBytes, NBits},
de::{BitReader, BitReaderExt, BitUnpack, r#as::BitUnpackAs},
ser::{BitPack, BitWriter, BitWriterExt, r#as::BitPackAs},
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ConstBit<const VALUE: bool>;
impl<const VALUE: bool> BitPack for ConstBit<VALUE> {
#[inline]
fn pack<W>(&self, writer: W) -> Result<(), W::Error>
where
W: BitWriter,
{
VALUE.pack(writer)
}
}
impl<'de, const VALUE: bool> BitUnpack<'de> for ConstBit<VALUE> {
#[inline]
fn unpack<R>(mut reader: R) -> Result<Self, R::Error>
where
R: BitReader<'de>,
{
if VALUE != reader.unpack::<bool>()? {
Err(Error::custom(format!(
"expected {:#b}, got {:#b}",
VALUE as u8, !VALUE as u8
)))
} else {
Ok(Self)
}
}
}
macro_rules! impl_bit_serde_for_integers {
($($t:tt)+) => {$(
impl BitPack for $t {
#[inline]
fn pack<W>(&self, mut writer: W) -> Result<(), W::Error>
where
W: BitWriter,
{
writer.pack_as::<_, AsBytes>(self.to_be_bytes())?;
Ok(())
}
}
impl<'de> BitUnpack<'de> for $t {
#[inline]
fn unpack<R>(mut reader: R) -> Result<Self, R::Error>
where
R: BitReader<'de>,
{
reader.read_bytes_array().map(Self::from_be_bytes)
}
}
impl<const BITS: usize> BitPackAs<$t> for NBits<BITS> {
#[inline]
fn pack_as<W>(source: &$t, mut writer: W) -> Result<(), W::Error>
where
W: BitWriter,
{
const BITS_SIZE: usize = bits_of::<$t>();
assert!(BITS <= BITS_SIZE, "excessive bits for type");
if BITS < BITS_SIZE - source.leading_zeros() as usize {
return Err(Error::custom(
format!("{source:#b} cannot be packed into {BITS} bits"),
));
}
let bytes = source.to_be_bytes();
let bits = bytes.as_bits::<Msb0>();
writer.write_bitslice(&bits[bits.len() - BITS..])?;
Ok(())
}
}
impl<'de, const BITS: usize> BitUnpackAs<'de, $t> for NBits<BITS> {
#[inline]
fn unpack_as<R>(mut reader: R) -> Result<$t, R::Error>
where
R: BitReader<'de>,
{
const BITS_SIZE: usize = bits_of::<$t>();
assert!(BITS <= BITS_SIZE, "excessive bits for type");
let mut arr = [0u8; mem::size_of::<$t>()];
let arr_bits = &mut arr.as_mut_bits()[BITS_SIZE - BITS..];
if reader.read_bits_into(arr_bits)? != arr_bits.len() {
return Err(Error::custom("EOF"));
}
Ok($t::from_be_bytes(arr))
}
}
)+};
}
impl_bit_serde_for_integers! {
u8 u16 u32 u64 u128 usize
i8 i16 i32 i64 i128 isize
}
macro_rules! const_uint {
($($vis:vis $name:ident<$typ:tt>)+) => {$(
#[doc = concat!("Constant version of `", stringify!($typ), "`")]
#[doc = concat!(
"Reads `", stringify!($typ), "` and returns an error
if it didn't match the type parameter.",
)]
#[doc = concat!("# integer::", stringify!($name), ",")]
#[doc = concat!("# buff.pack::<[", stringify!($typ), "; 2]>([123; 2])?;")]
#[doc = concat!("reader.unpack::<", stringify!($name), "<123>>()?;")]
#[doc = concat!("let got: ", stringify!($typ), " = reader.unpack()?;")]
#[doc = concat!(
"Writes `", stringify!($typ), "` as specified in type parameter."
)]
#[doc = concat!("# integer::", stringify!($name), ",")]
#[doc = concat!("writer.pack(", stringify!($name), "::<123>)?;")]
#[doc = concat!("writer.pack::<", stringify!($typ), ">(123)?;")]
#[doc = concat!(
"# assert_eq!(reader.unpack::<[", stringify!($typ), "; 2]>()?, [123; 2]);"
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
$vis struct $name<const VALUE: $typ, const BITS: usize = { bits_of::<$typ>() }>;
impl<const VALUE: $typ, const BITS: usize> BitPack for $name<VALUE, BITS> {
#[inline]
fn pack<W>(&self, mut writer: W) -> Result<(), W::Error>
where
W: BitWriter,
{
writer.pack_as::<_, NBits<BITS>>(VALUE)?;
Ok(())
}
}
impl<'de, const VALUE: $typ, const BITS: usize> BitUnpack<'de> for $name<VALUE, BITS> {
#[inline]
fn unpack<R>(mut reader: R) -> Result<Self, R::Error>
where
R: BitReader<'de>,
{
let v = reader.unpack_as::<$typ, NBits<BITS>>()?;
if v != VALUE {
return Err(Error::custom(format!(
"expected {VALUE:#b}, got: {v:#b}"
)));
}
Ok(Self)
}
}
)+};
}
const_uint! {
pub ConstU8 <u8>
pub ConstI8 <i8>
pub ConstU16 <u16>
pub ConstI16 <i16>
pub ConstU32 <u32>
pub ConstI32 <i32>
pub ConstU64 <u64>
pub ConstI64 <i64>
pub ConstU128<u128>
pub ConstI128<i128>
}
#[cfg(test)]
mod tests {
use bitvec::{bits, order::Msb0};
use num_bigint::BigUint;
use crate::{
ser::{r#as::pack_as, pack},
tests::{assert_pack_unpack_as_eq, assert_pack_unpack_eq},
};
use super::*;
#[test]
fn store_uint() {
assert_eq!(
pack(0xFD_FE_u16).unwrap(),
bits![u8, Msb0; 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
)
}
#[test]
fn serde_uint() {
assert_pack_unpack_eq(12345_u32);
}
#[test]
fn store_nbits_uint() {
assert_eq!(
pack_as::<_, NBits<7>>(0x7E).unwrap(),
bits![u8, Msb0; 1, 1, 1, 1, 1, 1, 0],
)
}
#[test]
fn nbits_one_bit() {
assert_eq!(pack_as::<_, NBits<1>>(0b1).unwrap(), pack(true).unwrap())
}
#[test]
fn store_nbits_same_uint() {
const N: u8 = 231;
assert_eq!(pack(N).unwrap(), pack_as::<_, NBits<8>>(N).unwrap())
}
#[test]
fn serde_nbits_uint() {
assert_pack_unpack_as_eq::<u8, NBits<7>>(0x7E);
}
#[test]
fn serde_big_nbits() {
assert_pack_unpack_as_eq::<BigUint, NBits<100>>(12345_u64.into());
}
}