use std::cmp::min;
use std::fmt;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::mem::size_of;
use std::ops::{BitOrAssign, BitXor, Index, Range, RangeFrom};
use num_traits::{Float, PrimInt};
use crate::endianness::Endianness;
use crate::num_traits::{IsSigned, UncheckedPrimitiveFloat, UncheckedPrimitiveInt};
use crate::{BitError, Result};
use std::borrow::{Borrow, Cow};
use std::convert::TryInto;
use std::rc::Rc;
/// Size of a `usize` in bytes on the compilation target.
const USIZE_SIZE: usize = size_of::<usize>();
/// Size of a `usize` in bits on the compilation target.
const USIZE_BIT_SIZE: usize = USIZE_SIZE * 8;
/// Backing storage for a bit buffer: either bytes borrowed from the caller or
/// a reference-counted copy that the buffer keeps alive itself.
pub(crate) enum Data<'a> {
    /// Bytes borrowed for the lifetime `'a`.
    Borrowed(&'a [u8]),
    /// Shared, immutable bytes; cloning only bumps the reference count.
    Owned(Rc<[u8]>),
}

impl<'a> Data<'a> {
    /// Returns the stored bytes as a plain slice.
    pub fn as_slice(&self) -> &[u8] {
        match self {
            Data::Borrowed(bytes) => *bytes,
            Data::Owned(bytes) => bytes.borrow(),
        }
    }

    /// Returns the number of stored bytes.
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// Returns storage that no longer borrows from the caller.
    ///
    /// Borrowed bytes are copied into a fresh `Rc<[u8]>`; already owned bytes
    /// are shared by cloning the `Rc`.
    pub fn to_owned(&self) -> Data<'static> {
        let bytes: Rc<[u8]> = match self {
            // Copy straight from the slice into the `Rc` allocation. Going
            // through `to_vec()` first would allocate and copy the data twice.
            Data::Borrowed(bytes) => Rc::from(*bytes),
            Data::Owned(bytes) => Rc::clone(bytes),
        };
        Data::Owned(bytes)
    }
}
// Byte-range indexing, forwarded to the underlying slice.
impl<'a> Index<Range<usize>> for Data<'a> {
    type Output = [u8];

    /// Returns the bytes in `index`; panics like slice indexing when out of range.
    fn index(&self, index: Range<usize>) -> &Self::Output {
        &self.as_slice()[index]
    }
}
// Open-ended byte-range indexing, forwarded to the underlying slice.
impl<'a> Index<RangeFrom<usize>> for Data<'a> {
    type Output = [u8];

    /// Returns the bytes from `index.start` to the end; panics like slice
    /// indexing when the start is out of range.
    fn index(&self, index: RangeFrom<usize>) -> &Self::Output {
        &self.as_slice()[index]
    }
}
// Single-byte indexing, forwarded to the underlying slice.
impl<'a> Index<usize> for Data<'a> {
    type Output = u8;

    /// Returns the byte at `index`; panics like slice indexing when out of range.
    fn index(&self, index: usize) -> &Self::Output {
        &self.as_slice()[index]
    }
}
// Cloning never copies the bytes: a borrow is re-used and owned data is shared.
impl<'a> Clone for Data<'a> {
    fn clone(&self) -> Self {
        match *self {
            Data::Borrowed(borrowed) => Data::Borrowed(borrowed),
            Data::Owned(ref shared) => Data::Owned(Rc::clone(shared)),
        }
    }
}
/// Read-only buffer over a byte sequence that can be read at arbitrary bit
/// positions, with the bit order selected by the `E` endianness marker.
pub struct BitReadBuffer<'a, E>
where
E: Endianness,
{
// Backing storage; owns or borrows the bytes that `slice` points into.
pub(crate) bytes: Data<'a>,
// Number of readable bits; starts at `byte length * 8` and is only ever
// lowered (`truncate`, `get_sub_buffer`).
bit_len: usize,
// Zero-sized marker carrying the endianness type parameter.
endianness: PhantomData<E>,
// View of the bytes held by `bytes`; every constructor builds it from the
// same storage. For owned data it is created through
// `slice::from_raw_parts`, so it is only valid while `bytes` is alive.
slice: &'a [u8],
}
impl<'a, E> BitReadBuffer<'a, E>
where
    E: Endianness,
{
    /// Creates a buffer that reads from the borrowed `bytes`.
    ///
    /// The `_endianness` value is only used to pick the `E` type parameter.
    pub fn new(bytes: &'a [u8], _endianness: E) -> Self {
        let byte_len = bytes.len();

        BitReadBuffer {
            bytes: Data::Borrowed(bytes),
            bit_len: byte_len * 8,
            endianness: PhantomData,
            slice: bytes,
        }
    }

    /// Creates a buffer with the same content that no longer borrows from `'a`.
    ///
    /// Borrowed bytes are copied; owned bytes are shared. The readable length
    /// in bits is carried over, so a truncated buffer stays truncated.
    pub fn to_owned(&self) -> BitReadBuffer<'static, E> {
        let bytes = self.bytes.to_owned();

        // SAFETY: the pointer and length describe the `Rc<[u8]>` allocation held
        // by `bytes`, which is moved into the returned struct next to `slice`,
        // is never mutated, and does not move when the `Rc` handle moves.
        // NOTE(review): the `'static` lifetime is only upheld as long as nothing
        // borrowed from `slice` outlives the buffer.
        let slice = unsafe { std::slice::from_raw_parts(bytes.as_slice().as_ptr(), bytes.len()) };

        BitReadBuffer {
            bytes,
            // Keep the logical length: recomputing it from the byte count
            // would silently undo `truncate` / `get_sub_buffer`.
            bit_len: self.bit_len,
            endianness: PhantomData,
            slice,
        }
    }
}
impl<E> BitReadBuffer<'static, E>
where
    E: Endianness,
{
    /// Creates a buffer that takes ownership of `bytes`.
    ///
    /// The `_endianness` value is only used to pick the `E` type parameter.
    pub fn new_owned(bytes: Vec<u8>, _endianness: E) -> Self {
        let storage = Data::Owned(Rc::from(bytes));
        let byte_len = storage.len();

        // SAFETY: the pointer and length describe the `Rc<[u8]>` allocation held
        // by `storage`, which is stored in the returned struct next to `slice`,
        // is never mutated, and does not move when the `Rc` handle moves.
        let slice = unsafe { std::slice::from_raw_parts(storage.as_slice().as_ptr(), byte_len) };

        BitReadBuffer {
            bytes: storage,
            bit_len: byte_len * 8,
            endianness: PhantomData,
            slice,
        }
    }
}
/// Extracts `count` bits from `val`, starting `bit_offset` bits into the word.
///
/// For little-endian the offset is counted from the least significant bit, for
/// big-endian from the most significant bit. The selected bits are returned in
/// the low `count` bits of the result.
///
/// `bit_offset + count` must not exceed the bit width of `usize`.
pub(crate) fn get_bits_from_usize<E: Endianness>(
    val: usize,
    bit_offset: usize,
    count: usize,
) -> usize {
    let word_bits = usize::BITS as usize;

    // Zero bits requested: answer directly. On big-endian with a zero offset
    // the shift amount below would be the full word width, which overflows.
    if count == 0 {
        return 0;
    }

    let shifted = if E::is_le() {
        val >> bit_offset
    } else {
        val >> (word_bits - bit_offset - count)
    };

    // `usize::MAX << count` overflows when the whole word is requested.
    let mask = if count >= word_bits {
        usize::MAX
    } else {
        !(usize::MAX << count)
    };
    shifted & mask
}
impl<'a, E> BitReadBuffer<'a, E>
where
E: Endianness,
{
/// Returns the number of bits that can be read from this buffer.
pub fn bit_len(&self) -> usize {
self.bit_len
}
/// Returns the length in bytes of the underlying slice.
///
/// This is the length of the stored bytes, not `bit_len` rounded to bytes;
/// `truncate` does not change it.
pub fn byte_len(&self) -> usize {
self.slice.len()
}
/// Reads one `usize` worth of raw bytes starting at `byte_index`.
///
/// With `end` set, fewer than `USIZE_SIZE` bytes may be available and the
/// missing bytes are filled with zeroes.
///
/// # Safety
///
/// `byte_index` must not exceed the slice length, and when `end` is `false`
/// `byte_index + USIZE_SIZE` must not exceed it either.
unsafe fn read_usize_bytes(&self, byte_index: usize, end: bool) -> [u8; USIZE_SIZE] {
    if !end {
        debug_assert!(byte_index + USIZE_SIZE <= self.slice.len());
        return self
            .slice
            .get_unchecked(byte_index..byte_index + USIZE_SIZE)
            .try_into()
            .unwrap();
    }

    // Near the end of the data: copy what is there and zero-pad the rest.
    let available = min(USIZE_SIZE, self.slice.len() - byte_index);
    let mut padded = [0; USIZE_SIZE];
    let source = self.slice.get_unchecked(byte_index..byte_index + available);
    padded[..available].copy_from_slice(source);
    padded
}
unsafe fn read_shifted_usize(&self, byte_index: usize, shift: usize, end: bool) -> usize {
let raw_bytes: [u8; USIZE_SIZE] = self.read_usize_bytes(byte_index, end);
let raw_usize: usize = usize::from_le_bytes(raw_bytes);
raw_usize >> shift
}
/// Reads `count` bits starting at bit `position` and returns them in the low
/// bits of a `usize`.
///
/// # Safety
///
/// `position / 8` must satisfy the requirements of `read_usize_bytes` for the
/// given `end` flag.
unsafe fn read_usize(&self, position: usize, count: usize, end: bool) -> usize {
    let raw = self.read_usize_bytes(position / 8, end);
    let container = if E::is_le() {
        usize::from_le_bytes(raw)
    } else {
        usize::from_be_bytes(raw)
    };
    get_bits_from_usize::<E>(container, position & 7, count)
}
/// Reads the single bit at `position` as a boolean.
///
/// # Errors
///
/// Returns `BitError::NotEnoughData` when `position` is not inside the
/// readable bits.
#[inline]
pub fn read_bool(&self, position: usize) -> Result<bool> {
    let out_of_data = || BitError::NotEnoughData {
        requested: 1,
        bits_left: 0,
    };

    if position >= self.bit_len() {
        return Err(out_of_data());
    }

    let byte = *self.slice.get(position / 8).ok_or_else(out_of_data)?;
    let bit_offset = (position & 7) as u8;

    // Little-endian counts bits from the least significant end of the byte,
    // big-endian from the most significant end.
    let is_set = if E::is_le() {
        (byte >> bit_offset) & 1u8 == 1
    } else {
        (byte << bit_offset) & 0b1000_0000u8 == 0b1000_0000u8
    };
    Ok(is_set)
}
/// Reads the single bit at `position` without any bounds check.
///
/// # Safety
///
/// `position / 8` must be a valid index into the underlying bytes.
#[doc(hidden)]
#[inline]
pub unsafe fn read_bool_unchecked(&self, position: usize) -> bool {
    let byte = *self.slice.get_unchecked(position / 8);
    let bit_offset = position & 7;

    // Select the bit with a mask instead of shifting the byte itself.
    let mask = if E::is_le() {
        1u8 << bit_offset
    } else {
        0b1000_0000u8 >> bit_offset
    };
    (byte & mask) != 0
}
/// Reads `count` bits starting at bit `position` as an integer of type `T`.
///
/// # Errors
///
/// - `BitError::TooManyBits` when `count` is larger than the bit size of `T`.
/// - `BitError::IndexOutOfBounds` when `position` lies past the end of the buffer.
/// - `BitError::NotEnoughData` when fewer than `count` bits remain.
#[inline]
pub fn read_int<T>(&self, position: usize, count: usize) -> Result<T>
where
    T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt + BitXor,
{
    let type_bit_size = size_of::<T>() * 8;
    if type_bit_size < count {
        return Err(BitError::TooManyBits {
            requested: count,
            max: type_bit_size,
        });
    }

    // All checks are phrased as subtractions from known-larger values, so a
    // huge `position` cannot wrap around and slip past them.
    let bit_len = self.bit_len();
    if position > bit_len {
        return Err(BitError::IndexOutOfBounds {
            pos: position,
            size: bit_len,
        });
    }
    let bits_left = bit_len - position;
    if count > bits_left {
        return Err(BitError::NotEnoughData {
            requested: count,
            bits_left,
        });
    }

    // SAFETY: `position + count <= bit_len` was verified above. The padded
    // (`end = true`) path is taken whenever less than a full `usize` of data
    // follows the requested bits.
    if bits_left - count < USIZE_BIT_SIZE {
        Ok(unsafe { self.read_int_unchecked(position, count, true) })
    } else {
        Ok(unsafe { self.read_int_unchecked(position, count, false) })
    }
}
/// Reads `count` bits starting at bit `position` as an integer of type `T`
/// without bounds checks, sign-extending the result for signed types.
///
/// # Safety
///
/// `position + count` must lie within the buffer, `count` must not exceed the
/// bit size of `T`, and `end` must be `true` unless a full `usize` can be read
/// at every byte position touched.
#[doc(hidden)]
#[inline]
pub unsafe fn read_int_unchecked<T>(&self, position: usize, count: usize, end: bool) -> T
where
    T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt + BitXor,
{
    let bit_offset = position & 7;

    // A single word read suffices when offset plus count fits in one `usize`.
    let value: T = if count + bit_offset < usize::BITS as usize {
        self.read_fit_usize(position, count, end)
    } else {
        self.read_no_fit_usize(position, count, end)
    };

    // A full-width read already has the sign bit in place.
    if count == size_of::<T>() * 8 {
        return value;
    }
    self.make_signed(value, count)
}
/// Reads `count` bits with a single `usize` read and converts them to `T`.
///
/// # Safety
///
/// Same requirements as `read_usize`; the bit offset plus `count` must fit in
/// one `usize`.
#[inline]
unsafe fn read_fit_usize<T>(&self, position: usize, count: usize, end: bool) -> T
where
    T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
{
    T::from_unchecked(self.read_usize(position, count, end))
}
/// Reads `count` bits that do not fit in a single `usize` read by combining
/// several smaller reads.
///
/// # Safety
///
/// `position + count` must lie within the buffer and `end` must satisfy the
/// requirements of `read_usize` for every chunk.
unsafe fn read_no_fit_usize<T>(&self, position: usize, count: usize, end: bool) -> T
where
    T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
{
    // Leave one byte of headroom so each chunk plus its bit offset still fits
    // in a `usize`.
    let chunk_limit = (size_of::<usize>() - 1) * 8;
    let mut result = T::zero();
    let mut consumed = 0;

    while consumed < count {
        let cursor = position + consumed;
        let available = self.bit_len() - cursor;
        let chunk = min(min(count - consumed, chunk_limit), available);
        let part = T::from_unchecked(self.read_usize(cursor, chunk, end));

        if E::is_le() {
            // Later chunks hold more significant bits.
            result |= part << consumed;
        } else {
            // Earlier chunks hold more significant bits: make room, then append.
            result = (result << chunk) | part;
        }
        consumed += chunk;
    }

    result
}
/// Sign-extends a `count`-bit value to the full width of `T`.
///
/// Unsigned types are returned unchanged and a zero-bit value is always zero.
/// `count` must be smaller than the bit size of `T`.
fn make_signed<T>(&self, value: T, count: usize) -> T
where
    T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt + BitXor,
{
    if count == 0 {
        return T::zero();
    }
    if !T::is_signed() {
        return value;
    }

    let sign_bit = (value >> (count - 1)) & T::one();
    if sign_bit == T::one() {
        // Set every bit above the low `count` bits. Shifting `-1` left builds
        // that mask directly; computing it as `(1 << count) - 1` overflows for
        // `count == bits - 1` (e.g. `i8` with 7 bits: `-128 - 1`).
        let all_ones = T::zero() - T::one();
        value | (all_ones << count)
    } else {
        value
    }
}
/// Reads `byte_count` bytes starting at bit `position`.
///
/// The result borrows from the buffer when `position` is byte-aligned and is
/// an owned copy otherwise.
///
/// # Errors
///
/// - `BitError::IndexOutOfBounds` when `position` lies past the end of the buffer.
/// - `BitError::NotEnoughData` when fewer than `byte_count * 8` bits remain.
#[inline]
pub fn read_bytes(&self, position: usize, byte_count: usize) -> Result<Cow<'a, [u8]>> {
    let bit_len = self.bit_len();
    if position > bit_len {
        return Err(BitError::IndexOutOfBounds {
            pos: position,
            size: bit_len,
        });
    }

    // Compare in whole bytes: `position + byte_count * 8` could wrap around
    // for huge arguments and slip past the check.
    let bits_left = bit_len - position;
    if byte_count > bits_left / 8 {
        return Err(BitError::NotEnoughData {
            requested: byte_count.saturating_mul(8),
            bits_left,
        });
    }

    // SAFETY: the requested range was verified to lie inside the buffer.
    Ok(unsafe { self.read_bytes_unchecked(position, byte_count) })
}
/// Reads `byte_count` bytes starting at bit `position` without checking that
/// the range lies inside the buffer.
///
/// Byte-aligned reads borrow from the buffer; unaligned reads assemble an
/// owned copy.
///
/// # Safety
///
/// `position + byte_count * 8` must not exceed the readable bits of the buffer.
#[doc(hidden)]
#[inline]
pub unsafe fn read_bytes_unchecked(&self, position: usize, byte_count: usize) -> Cow<'a, [u8]> {
    let shift = position & 7;
    let first_byte = position / 8;
    if shift == 0 {
        return Cow::Borrowed(&self.slice[first_byte..first_byte + byte_count]);
    }

    let mut out = Vec::with_capacity(byte_count);
    if E::is_le() {
        // Each shifted word yields one byte less than a `usize`, because the
        // top byte is incomplete after the shift.
        const CHUNK: usize = USIZE_SIZE - 1;
        let mut cursor = first_byte;
        let mut remaining = byte_count;

        while remaining > CHUNK {
            let word = self.read_shifted_usize(cursor, shift, false).to_le_bytes();
            out.extend_from_slice(&word[..CHUNK]);
            cursor += CHUNK;
            remaining -= CHUNK;
        }

        // The last partial chunk uses the zero-padded read.
        let tail = self.read_shifted_usize(cursor, shift, true).to_le_bytes();
        out.extend_from_slice(&tail[..remaining]);
    } else {
        for index in 0..byte_count {
            out.push(self.read_int_unchecked::<u8>(position + index * 8, 8, true));
        }
    }

    Cow::Owned(out)
}
/// Reads a UTF-8 string starting at bit `position`.
///
/// With `byte_len` set, exactly that many bytes are read and trailing NUL
/// characters are stripped. Without it, bytes are read up to the first NUL
/// byte.
///
/// # Errors
///
/// Propagates the errors of the underlying byte read and returns an error
/// when the bytes are not valid UTF-8.
#[inline]
pub fn read_string(&self, position: usize, byte_len: Option<usize>) -> Result<Cow<'a, str>> {
    if let Some(len) = byte_len {
        // Fixed-size field: decode, then drop the NUL padding.
        let padded = match self.read_bytes(position, len)? {
            Cow::Borrowed(raw) => {
                let text = std::str::from_utf8(raw)
                    .map_err(|err| BitError::Utf8Error(err, raw.len()))?;
                Cow::Borrowed(text.trim_end_matches(char::from(0)))
            }
            Cow::Owned(raw) => {
                let text = String::from_utf8(raw)?;
                Cow::Owned(text.trim_end_matches(char::from(0)).to_string())
            }
        };
        return Ok(padded);
    }

    // NUL-terminated string: the terminator is already excluded from the bytes.
    let terminated = match self.read_string_bytes(position)? {
        Cow::Borrowed(raw) => Cow::Borrowed(
            std::str::from_utf8(raw).map_err(|err| BitError::Utf8Error(err, raw.len()))?,
        ),
        Cow::Owned(raw) => Cow::Owned(String::from_utf8(raw)?),
    };
    Ok(terminated)
}
/// Returns the index of the first zero byte at or after `byte_index`, or the
/// slice length when there is none.
///
/// Panics when `byte_index` is larger than the slice length.
#[inline]
fn find_null_byte(&self, byte_index: usize) -> usize {
    let tail = &self.slice[byte_index..];
    match memchr::memchr(0, tail) {
        Some(offset) => byte_index + offset,
        None => self.slice.len(),
    }
}
#[inline]
fn read_string_bytes(&self, position: usize) -> Result<Cow<'a, [u8]>> {
let shift = position & 7;
if shift == 0 {
let byte_index = position / 8;
Ok(Cow::Borrowed(
&self.slice[byte_index..self.find_null_byte(byte_index)],
))
} else {
let mut acc = Vec::with_capacity(32);
if E::is_le() {
let mut byte_index = position / 8;
loop {
let shifted = unsafe { self.read_shifted_usize(byte_index, shift, true) };
let has_null = contains_zero_byte_non_top(shifted);
let bytes: [u8; USIZE_SIZE] = shifted.to_le_bytes();
let usable_bytes = &bytes[0..USIZE_SIZE - 1];
if has_null {
for i in 0..USIZE_SIZE - 1 {
if usable_bytes[i] == 0 {
acc.extend_from_slice(&usable_bytes[0..i]);
return Ok(Cow::Owned(acc));
}
}
}
acc.extend_from_slice(&usable_bytes[0..USIZE_SIZE - 1]);
byte_index += USIZE_SIZE - 1;
}
} else {
let mut pos = position;
loop {
let byte = self.read_int::<u8>(pos, 8)?;
pos += 8;
if byte == 0 {
return Ok(Cow::Owned(acc));
} else {
acc.push(byte);
}
}
}
}
}
/// Reads a floating point number of type `T` starting at bit `position`.
///
/// # Errors
///
/// - `BitError::IndexOutOfBounds` when `position` lies past the end of the buffer.
/// - `BitError::NotEnoughData` when fewer bits remain than `T` needs.
#[inline]
pub fn read_float<T>(&self, position: usize) -> Result<T>
where
    T: Float + UncheckedPrimitiveFloat,
{
    let type_bit_size = size_of::<T>() * 8;

    // All checks are phrased as subtractions from known-larger values, so a
    // huge `position` cannot wrap around and slip past them.
    let bit_len = self.bit_len();
    if position > bit_len {
        return Err(BitError::IndexOutOfBounds {
            pos: position,
            size: bit_len,
        });
    }
    let bits_left = bit_len - position;
    if type_bit_size > bits_left {
        return Err(BitError::NotEnoughData {
            requested: type_bit_size,
            bits_left,
        });
    }

    // SAFETY: `position + type_bit_size <= bit_len` was verified above. The
    // padded (`end = true`) path is taken whenever less than a full `usize`
    // of data follows the value.
    if bits_left - type_bit_size < USIZE_BIT_SIZE {
        Ok(unsafe { self.read_float_unchecked(position, true) })
    } else {
        Ok(unsafe { self.read_float_unchecked(position, false) })
    }
}
/// Reads a floating point number of type `T` starting at bit `position`
/// without checking that enough data is available.
///
/// # Safety
///
/// The bits of `T` starting at `position` must lie inside the buffer, and
/// `end` must satisfy the requirements of `read_int_unchecked`.
#[doc(hidden)]
#[inline]
pub unsafe fn read_float_unchecked<T>(&self, position: usize, end: bool) -> T
where
    T: Float + UncheckedPrimitiveFloat,
{
    let byte_size = size_of::<T>();

    // Unaligned: read the raw bits as an integer and reinterpret them.
    if position & 7 != 0 {
        let int = self.read_int_unchecked(position, byte_size * 8, end);
        return T::from_int(int);
    }

    // Aligned: convert the bytes directly.
    let start = position / 8;
    let bytes = self.slice[start..start + byte_size].try_into().unwrap();
    T::from_bytes::<E>(bytes)
}
/// Creates a buffer over the same bytes that is limited to the first
/// `bit_len` bits.
///
/// # Errors
///
/// Returns `BitError::NotEnoughData` when `bit_len` exceeds the current length.
pub(crate) fn get_sub_buffer(&self, bit_len: usize) -> Result<Self> {
    let available = self.bit_len();
    if bit_len <= available {
        Ok(BitReadBuffer {
            bytes: self.bytes.clone(),
            bit_len,
            endianness: PhantomData,
            slice: self.slice,
        })
    } else {
        Err(BitError::NotEnoughData {
            requested: bit_len,
            bits_left: available,
        })
    }
}
/// Shortens the readable length of the buffer to `bit_len` bits.
///
/// # Errors
///
/// Returns `BitError::NotEnoughData` when `bit_len` exceeds the current
/// length; the buffer is left unchanged in that case.
pub fn truncate(&mut self, bit_len: usize) -> Result<()> {
    let current = self.bit_len();
    if bit_len <= current {
        self.bit_len = bit_len;
        Ok(())
    } else {
        Err(BitError::NotEnoughData {
            requested: bit_len,
            bits_left: current,
        })
    }
}
}
// Builds a borrowing buffer from a byte slice.
impl<'a, E: Endianness> From<&'a [u8]> for BitReadBuffer<'a, E> {
    fn from(bytes: &'a [u8]) -> Self {
        let endianness = E::endianness();
        Self::new(bytes, endianness)
    }
}
// Builds an owning buffer from a byte vector.
impl<'a, E: Endianness> From<Vec<u8>> for BitReadBuffer<'a, E> {
    fn from(bytes: Vec<u8>) -> Self {
        let endianness = E::endianness();
        BitReadBuffer::new_owned(bytes, endianness)
    }
}
// Cloning shares the underlying bytes and keeps the readable length.
impl<'a, E: Endianness> Clone for BitReadBuffer<'a, E> {
    fn clone(&self) -> Self {
        let bytes = self.bytes.clone();
        BitReadBuffer {
            slice: self.slice,
            endianness: PhantomData,
            bit_len: self.bit_len,
            bytes,
        }
    }
}
// Debug output shows only the length and endianness, not the byte content.
impl<E: Endianness> Debug for BitReadBuffer<'_, E> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let bit_len = self.bit_len();
        let endianness = E::as_string();
        write!(
            f,
            "BitBuffer {{ bit_len: {}, endianness: {} }}",
            bit_len, endianness
        )
    }
}
// Two buffers are equal when they have the same readable length and the same
// readable bits; bytes beyond `bit_len` are ignored.
impl<'a, E: Endianness> PartialEq for BitReadBuffer<'a, E> {
    fn eq(&self, other: &Self) -> bool {
        if self.bit_len != other.bit_len {
            return false;
        }

        let full_bytes = self.bit_len / 8;
        let trailing_bits = self.bit_len % 8;

        // Compare only the bytes covered by `bit_len`. The slices themselves
        // may be longer (after `truncate` / `get_sub_buffer`) and differ in
        // the part that is not readable.
        if self.slice[..full_bytes] != other.slice[..full_bytes] {
            return false;
        }
        if trailing_bits == 0 {
            return true;
        }

        // Compare the remaining bits of the last, partial byte.
        let rest_self = self
            .read_int::<u8>(full_bytes * 8, trailing_bits)
            .expect("trailing bits lie within bit_len");
        let rest_other = other
            .read_int::<u8>(full_bytes * 8, trailing_bits)
            .expect("trailing bits lie within bit_len");
        rest_self == rest_other
    }
}
/// Returns `true` when any byte of `x` other than the most significant one is
/// zero.
///
/// Subtracting 1 from every tracked byte sets its high bit through a borrow
/// only if the byte was zero (or had its high bit set, which `& !x` filters out).
#[inline(always)]
fn contains_zero_byte_non_top(x: usize) -> bool {
    // 0x01 in every byte except the most significant one.
    const LOW_ONES: usize = (usize::MAX / 0xFF) >> 8;
    // 0x80 in every byte except the most significant one.
    const HIGH_BITS: usize = LOW_ONES << 7;

    let borrow_marks = x.wrapping_sub(LOW_ONES) & !x;
    (borrow_marks & HIGH_BITS) != 0
}
#[cfg(feature = "serde")]
use serde::{de, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
// Serializes the buffer as a struct with the readable bytes (`data`) and the
// readable length in bits (`bit_length`).
#[cfg(feature = "serde")]
impl<'a, E: Endianness> Serialize for BitReadBuffer<'a, E> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let bit_len = self.bit_len();
        let full_bytes = bit_len / 8;
        let bits_left = bit_len % 8;

        let mut data = self
            .read_bytes(0, full_bytes)
            .expect("whole bytes lie within bit_len")
            .into_owned();
        if bits_left > 0 {
            // Store the trailing partial byte as its own, final byte.
            data.push(
                self.read_int(full_bytes * 8, bits_left)
                    .expect("trailing bits lie within bit_len"),
            );
        }

        // Exactly two fields follow; the length passed here must match, since
        // some formats write it into the output.
        let mut s = serializer.serialize_struct("BitReadBuffer", 2)?;
        s.serialize_field("data", &data)?;
        s.serialize_field("bit_length", &bit_len)?;
        s.end()
    }
}
// Deserializes an owned buffer from the `data` / `bit_length` struct written by
// the `Serialize` impl.
#[cfg(feature = "serde")]
impl<'de, E: Endianness> Deserialize<'de> for BitReadBuffer<'static, E> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct BitData {
            data: Vec<u8>,
            bit_length: usize,
        }

        let BitData { data, bit_length } = BitData::deserialize(deserializer)?;
        let mut buffer = BitReadBuffer::new_owned(data, E::endianness());

        // A `bit_length` larger than the data is reported as a
        // deserialization error.
        match buffer.truncate(bit_length) {
            Ok(()) => Ok(buffer),
            Err(err) => Err(de::Error::custom(err)),
        }
    }
}
// A buffer truncated to a non-byte-aligned length must survive a JSON round trip.
#[cfg(feature = "serde")]
#[test]
fn test_serde_roundtrip() {
    use crate::LittleEndian;

    let original = {
        let mut buffer = BitReadBuffer::new_owned(vec![55; 8], LittleEndian);
        buffer.truncate(61).unwrap();
        buffer
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: BitReadBuffer<LittleEndian> = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded, original);
}