1use std::cmp::min;
2use std::fmt;
3use std::fmt::Debug;
4use std::marker::PhantomData;
5use std::mem::size_of;
6use std::ops::{BitOrAssign, BitXor, Index, Range, RangeFrom};
7
8use num_traits::{Float, PrimInt, WrappingSub};
9
10use crate::endianness::Endianness;
11use crate::num_traits::{IsSigned, UncheckedPrimitiveFloat, UncheckedPrimitiveInt};
12use crate::{BitError, Result};
13use std::borrow::{Borrow, Cow};
14use std::convert::TryInto;
15use std::rc::Rc;
16
17const USIZE_SIZE: usize = size_of::<usize>();
18const USIZE_BIT_SIZE: usize = USIZE_SIZE * 8;
19
/// Backing storage for a bit buffer.
///
/// The bytes are either borrowed from the caller for `'a`, or held in a
/// reference-counted allocation that clones can share.
pub(crate) enum Data<'a> {
    /// Bytes borrowed for the lifetime `'a`.
    Borrowed(&'a [u8]),
    /// Reference-counted bytes owned by the buffer (and any of its clones).
    Owned(Rc<[u8]>),
}
25
26impl Data<'_> {
27 pub fn as_slice(&self) -> &[u8] {
28 match self {
29 Data::Borrowed(bytes) => bytes,
30 Data::Owned(bytes) => bytes.borrow(),
31 }
32 }
33
34 pub fn len(&self) -> usize {
35 self.as_slice().len()
36 }
37
38 pub fn to_owned(&self) -> Data<'static> {
39 let bytes = match self {
40 Data::Borrowed(bytes) => Rc::from(bytes.to_vec()),
41 Data::Owned(bytes) => Rc::clone(bytes),
42 };
43 Data::Owned(bytes)
44 }
45}
46
47impl Index<Range<usize>> for Data<'_> {
48 type Output = [u8];
49
50 fn index(&self, index: Range<usize>) -> &Self::Output {
51 self.as_slice().index(index)
52 }
53}
54
55impl Index<RangeFrom<usize>> for Data<'_> {
56 type Output = [u8];
57
58 fn index(&self, index: RangeFrom<usize>) -> &Self::Output {
59 self.as_slice().index(index)
60 }
61}
62
63impl Index<usize> for Data<'_> {
64 type Output = u8;
65
66 fn index(&self, index: usize) -> &Self::Output {
67 self.as_slice().index(index)
68 }
69}
70
71impl Clone for Data<'_> {
72 fn clone(&self) -> Self {
73 match self {
74 Data::Borrowed(bytes) => Data::Borrowed(bytes),
75 Data::Owned(bytes) => Data::Owned(Rc::clone(bytes)),
76 }
77 }
78}
79
80pub struct BitReadBuffer<'a, E>
100where
101 E: Endianness,
102{
103 pub(crate) bytes: Data<'a>,
104 bit_len: usize,
105 endianness: PhantomData<E>,
106 slice: &'a [u8],
107}
108
109impl<'a, E> BitReadBuffer<'a, E>
110where
111 E: Endianness,
112{
113 pub fn new(bytes: &'a [u8], _endianness: E) -> Self {
127 let byte_len = bytes.len();
128
129 BitReadBuffer {
130 bytes: Data::Borrowed(bytes),
131 bit_len: byte_len * 8,
132 endianness: PhantomData,
133 slice: bytes,
134 }
135 }
136
137 pub fn to_owned(&self) -> BitReadBuffer<'static, E> {
141 let bytes = self.bytes.to_owned();
142 let byte_len = bytes.len();
143
144 let slice = unsafe { std::slice::from_raw_parts(bytes.as_slice().as_ptr(), bytes.len()) };
149
150 BitReadBuffer {
151 bytes,
152 bit_len: byte_len * 8,
153 endianness: PhantomData,
154 slice,
155 }
156 }
157}
158
159impl<E> BitReadBuffer<'static, E>
160where
161 E: Endianness,
162{
163 pub fn new_owned(bytes: Vec<u8>, _endianness: E) -> Self {
177 let byte_len = bytes.len();
178 let bytes = Data::Owned(Rc::from(bytes));
179
180 let slice = unsafe { std::slice::from_raw_parts(bytes.as_slice().as_ptr(), bytes.len()) };
185
186 BitReadBuffer {
187 bytes,
188 bit_len: byte_len * 8,
189 endianness: PhantomData,
190 slice,
191 }
192 }
193}
194
195pub(crate) fn get_bits_from_usize<E: Endianness>(
196 val: usize,
197 bit_offset: usize,
198 count: usize,
199) -> usize {
200 let shifted = if E::is_le() {
201 val >> bit_offset
202 } else {
203 val >> (usize::BITS as usize - bit_offset - count)
204 };
205 let mask = !(usize::MAX << count);
206 shifted & mask
207}
208
209impl<'a, E> BitReadBuffer<'a, E>
210where
211 E: Endianness,
212{
213 pub fn bit_len(&self) -> usize {
215 self.bit_len
216 }
217
218 pub fn byte_len(&self) -> usize {
220 self.slice.len()
221 }
222
223 unsafe fn read_usize_bytes(&self, byte_index: usize, end: bool) -> [u8; USIZE_SIZE] {
224 if end {
225 let mut bytes = [0; USIZE_SIZE];
226 let count = min(USIZE_SIZE, self.slice.len() - byte_index);
227 bytes[0..count]
228 .copy_from_slice(self.slice.get_unchecked(byte_index..byte_index + count));
229 bytes
230 } else {
231 debug_assert!(byte_index + USIZE_SIZE <= self.slice.len());
232 self.slice
236 .get_unchecked(byte_index..byte_index + USIZE_SIZE)
237 .try_into()
238 .unwrap()
239 }
240 }
241
242 unsafe fn read_shifted_usize(&self, byte_index: usize, shift: usize, end: bool) -> usize {
244 let raw_bytes: [u8; USIZE_SIZE] = self.read_usize_bytes(byte_index, end);
245 let raw_usize: usize = usize::from_le_bytes(raw_bytes);
246 raw_usize >> shift
247 }
248
249 unsafe fn read_usize(&self, position: usize, count: usize, end: bool) -> usize {
250 let byte_index = position / 8;
251 let bit_offset = position & 7;
252
253 let bytes: [u8; USIZE_SIZE] = self.read_usize_bytes(byte_index, end);
254
255 let container = if E::is_le() {
256 usize::from_le_bytes(bytes)
257 } else {
258 usize::from_be_bytes(bytes)
259 };
260
261 get_bits_from_usize::<E>(container, bit_offset, count)
262 }
263
264 #[inline]
290 pub fn read_bool(&self, position: usize) -> Result<bool> {
291 let byte_index = position / 8;
292 let bit_offset = position & 7;
293
294 if position >= self.bit_len() {
296 return Err(BitError::NotEnoughData {
297 requested: 1,
298 bits_left: 0,
299 });
300 }
301 if let Some(byte) = self.slice.get(byte_index) {
302 if E::is_le() {
303 let shifted = byte >> bit_offset as u8;
304 Ok(shifted & 1u8 == 1)
305 } else {
306 let shifted = byte << bit_offset as u8;
307 Ok(shifted & 0b1000_0000u8 == 0b1000_0000u8)
308 }
309 } else {
310 Err(BitError::NotEnoughData {
311 requested: 1,
312 bits_left: 0,
313 })
314 }
315 }
316
317 #[doc(hidden)]
318 #[inline]
319 pub unsafe fn read_bool_unchecked(&self, position: usize) -> bool {
320 let byte_index = position / 8;
321 let bit_offset = position & 7;
322
323 let byte = self.slice.get_unchecked(byte_index);
324 if E::is_le() {
325 let shifted = byte >> bit_offset;
326 shifted & 1u8 == 1
327 } else {
328 let shifted = byte << bit_offset;
329 shifted & 0b1000_0000u8 == 0b1000_0000u8
330 }
331 }
332
333 #[inline]
361 pub fn read_int<T>(&self, position: usize, count: usize) -> Result<T>
362 where
363 T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt + BitXor + WrappingSub,
364 {
365 let type_bit_size = size_of::<T>() * 8;
366
367 if type_bit_size < count {
368 return Err(BitError::TooManyBits {
369 requested: count,
370 max: type_bit_size,
371 });
372 }
373
374 if position + count + USIZE_BIT_SIZE > self.bit_len() {
375 if position + count > self.bit_len() {
376 return if position > self.bit_len() {
377 Err(BitError::IndexOutOfBounds {
378 pos: position,
379 size: self.bit_len(),
380 })
381 } else {
382 Err(BitError::NotEnoughData {
383 requested: count,
384 bits_left: self.bit_len() - position,
385 })
386 };
387 }
388 Ok(unsafe { self.read_int_unchecked(position, count, true) })
389 } else {
390 Ok(unsafe { self.read_int_unchecked(position, count, false) })
391 }
392 }
393
394 #[doc(hidden)]
395 #[inline]
396 pub unsafe fn read_int_unchecked<T>(&self, position: usize, count: usize, end: bool) -> T
397 where
398 T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt + BitXor + WrappingSub,
399 {
400 let type_bit_size = size_of::<T>() * 8;
401
402 let bit_offset = position & 7;
403
404 let fit_usize = count + bit_offset < usize::BITS as usize;
405 let value = if fit_usize {
406 self.read_fit_usize(position, count, end)
407 } else {
408 self.read_no_fit_usize(position, count, end)
409 };
410
411 if count == type_bit_size {
412 value
413 } else {
414 self.make_signed(value, count)
415 }
416 }
417
418 #[inline]
419 unsafe fn read_fit_usize<T>(&self, position: usize, count: usize, end: bool) -> T
420 where
421 T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
422 {
423 let raw = self.read_usize(position, count, end);
424 T::from_unchecked(raw)
425 }
426
427 unsafe fn read_no_fit_usize<T>(&self, position: usize, count: usize, end: bool) -> T
428 where
429 T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
430 {
431 let mut left_to_read = count;
432 let mut acc = T::zero();
433 let max_read = (size_of::<usize>() - 1) * 8;
434 let mut read_pos = position;
435 let mut bit_offset = 0;
436 while left_to_read > 0 {
437 let bits_left = self.bit_len() - read_pos;
438 let read = min(min(left_to_read, max_read), bits_left);
439 let data = T::from_unchecked(self.read_usize(read_pos, read, end));
440 if E::is_le() {
441 acc |= data << bit_offset;
442 } else {
443 acc = acc << read;
444 acc |= data;
445 }
446 bit_offset += read;
447 read_pos += read;
448 left_to_read -= read;
449 }
450
451 acc
452 }
453
454 fn make_signed<T>(&self, value: T, count: usize) -> T
455 where
456 T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt + BitXor + WrappingSub,
457 {
458 if count == 0 {
459 T::zero()
460 } else if T::is_signed() {
461 let sign_bit = (value >> (count - 1)) & T::one();
462 if sign_bit == T::one() {
463 value | (T::zero() - T::one()) ^ (T::one() << count).wrapping_sub(&T::one())
464 } else {
465 value
466 }
467 } else {
468 value
469 }
470 }
471
472 #[inline]
501 pub fn read_bytes(&self, position: usize, byte_count: usize) -> Result<Cow<'a, [u8]>> {
502 if position + byte_count * 8 > self.bit_len() {
503 if position > self.bit_len() {
504 return Err(BitError::IndexOutOfBounds {
505 pos: position,
506 size: self.bit_len(),
507 });
508 } else {
509 return Err(BitError::NotEnoughData {
510 requested: byte_count * 8,
511 bits_left: self.bit_len() - position,
512 });
513 }
514 }
515
516 Ok(unsafe { self.read_bytes_unchecked(position, byte_count) })
517 }
518
519 #[doc(hidden)]
520 #[inline]
521 pub unsafe fn read_bytes_unchecked(&self, position: usize, byte_count: usize) -> Cow<'a, [u8]> {
522 let shift = position & 7;
523
524 if shift == 0 {
525 let byte_pos = position / 8;
526 return Cow::Borrowed(&self.slice[byte_pos..byte_pos + byte_count]);
527 }
528
529 let mut data = Vec::with_capacity(byte_count);
530 let mut byte_left = byte_count;
531 let mut read_pos = position / 8;
532
533 if E::is_le() {
534 while byte_left > USIZE_SIZE - 1 {
535 let raw = self.read_shifted_usize(read_pos, shift, false);
536 let bytes = if E::is_le() {
537 raw.to_le_bytes()
538 } else {
539 raw.to_be_bytes()
540 };
541 let read_bytes = USIZE_SIZE - 1;
542 let usable_bytes = &bytes[0..read_bytes];
543 data.extend_from_slice(usable_bytes);
544
545 read_pos += read_bytes;
546 byte_left -= read_bytes;
547 }
548
549 let bytes = self.read_shifted_usize(read_pos, shift, true).to_le_bytes();
550 let usable_bytes = &bytes[0..byte_left];
551 data.extend_from_slice(usable_bytes);
552 } else {
553 let mut pos = position;
554 while byte_left > 0 {
555 data.push(self.read_int_unchecked::<u8>(pos, 8, true));
556 byte_left -= 1;
557 pos += 8;
558 }
559 }
560
561 Cow::Owned(data)
562 }
563
564 #[inline]
600 pub fn read_string(&self, position: usize, byte_len: Option<usize>) -> Result<Cow<'a, str>> {
601 match byte_len {
602 Some(byte_len) => {
603 let bytes = self.read_bytes(position, byte_len)?;
604
605 let string = match bytes {
606 Cow::Owned(bytes) => Cow::Owned(
607 String::from_utf8(bytes)?
608 .trim_end_matches(char::from(0))
609 .to_string(),
610 ),
611 Cow::Borrowed(bytes) => Cow::Borrowed(
612 std::str::from_utf8(bytes)
613 .map_err(|err| BitError::Utf8Error(err, bytes.len()))?
614 .trim_end_matches(char::from(0)),
615 ),
616 };
617 Ok(string)
618 }
619 None => {
620 let bytes = self.read_string_bytes(position)?;
621 let string = match bytes {
622 Cow::Owned(bytes) => Cow::Owned(String::from_utf8(bytes)?),
623 Cow::Borrowed(bytes) => Cow::Borrowed(
624 std::str::from_utf8(bytes)
625 .map_err(|err| BitError::Utf8Error(err, bytes.len()))?,
626 ),
627 };
628 Ok(string)
629 }
630 }
631 }
632
633 #[inline]
634 fn find_null_byte(&self, byte_index: usize) -> usize {
635 memchr::memchr(0, &self.slice[byte_index..])
636 .map(|index| index + byte_index)
637 .unwrap_or(self.slice.len()) }
639
640 #[inline]
641 fn read_string_bytes(&self, position: usize) -> Result<Cow<'a, [u8]>> {
642 let shift = position & 7;
643 if shift == 0 {
644 let byte_index = position / 8;
645 Ok(Cow::Borrowed(
646 &self.slice[byte_index..self.find_null_byte(byte_index)],
647 ))
648 } else {
649 let mut acc = Vec::with_capacity(32);
650 if E::is_le() {
651 let mut byte_index = position / 8;
652 loop {
653 let shifted = unsafe { self.read_shifted_usize(byte_index, shift, true) };
660
661 let has_null = contains_zero_byte_non_top(shifted);
662 let bytes: [u8; USIZE_SIZE] = shifted.to_le_bytes();
663 let usable_bytes = &bytes[0..USIZE_SIZE - 1];
664
665 if has_null {
666 for i in 0..USIZE_SIZE - 1 {
667 if usable_bytes[i] == 0 {
668 acc.extend_from_slice(&usable_bytes[0..i]);
669 return Ok(Cow::Owned(acc));
670 }
671 }
672 }
673
674 acc.extend_from_slice(&usable_bytes[0..USIZE_SIZE - 1]);
675
676 byte_index += USIZE_SIZE - 1;
677 }
678 } else {
679 let mut pos = position;
680 loop {
681 let byte = self.read_int::<u8>(pos, 8)?;
682 pos += 8;
683 if byte == 0 {
684 return Ok(Cow::Owned(acc));
685 } else {
686 acc.push(byte);
687 }
688 }
689 }
690 }
691 }
692
693 #[inline]
718 pub fn read_float<T>(&self, position: usize) -> Result<T>
719 where
720 T: Float + UncheckedPrimitiveFloat,
721 {
722 let type_bit_size = size_of::<T>() * 8;
723 if position + type_bit_size + USIZE_BIT_SIZE > self.bit_len() {
724 if position + type_bit_size > self.bit_len() {
725 if position > self.bit_len() {
726 return Err(BitError::IndexOutOfBounds {
727 pos: position,
728 size: self.bit_len(),
729 });
730 } else {
731 return Err(BitError::NotEnoughData {
732 requested: size_of::<T>() * 8,
733 bits_left: self.bit_len() - position,
734 });
735 }
736 }
737 Ok(unsafe { self.read_float_unchecked(position, true) })
738 } else {
739 Ok(unsafe { self.read_float_unchecked(position, false) })
740 }
741 }
742
743 #[doc(hidden)]
744 #[inline]
745 pub unsafe fn read_float_unchecked<T>(&self, position: usize, end: bool) -> T
746 where
747 T: Float + UncheckedPrimitiveFloat,
748 {
749 if position & 7 == 0 {
750 let byte_pos = position / 8;
751 let bytes = self.slice[byte_pos..byte_pos + size_of::<T>()]
752 .try_into()
753 .unwrap();
754 T::from_bytes::<E>(bytes)
755 } else {
756 let int = self.read_int_unchecked(position, size_of::<T>() * 8, end);
757 T::from_int(int)
758 }
759 }
760
761 pub(crate) fn get_sub_buffer(&self, bit_len: usize) -> Result<Self> {
762 if bit_len > self.bit_len() {
763 return Err(BitError::NotEnoughData {
764 requested: bit_len,
765 bits_left: self.bit_len(),
766 });
767 }
768
769 Ok(BitReadBuffer {
770 bytes: self.bytes.clone(),
771 bit_len,
772 endianness: PhantomData,
773 slice: self.slice,
774 })
775 }
776
777 pub fn truncate(&mut self, bit_len: usize) -> Result<()> {
779 if bit_len > self.bit_len() {
780 return Err(BitError::NotEnoughData {
781 requested: bit_len,
782 bits_left: self.bit_len(),
783 });
784 }
785
786 self.bit_len = bit_len;
787 Ok(())
788 }
789}
790
791impl<'a, E: Endianness> From<&'a [u8]> for BitReadBuffer<'a, E> {
792 fn from(bytes: &'a [u8]) -> Self {
793 BitReadBuffer::new(bytes, E::endianness())
794 }
795}
796
797impl<E: Endianness> From<Vec<u8>> for BitReadBuffer<'_, E> {
798 fn from(bytes: Vec<u8>) -> Self {
799 BitReadBuffer::new_owned(bytes, E::endianness())
800 }
801}
802
803impl<E: Endianness> Clone for BitReadBuffer<'_, E> {
804 fn clone(&self) -> Self {
805 BitReadBuffer {
806 bytes: self.bytes.clone(),
807 bit_len: self.bit_len(),
808 endianness: PhantomData,
809 slice: self.slice,
810 }
811 }
812}
813
814impl<E: Endianness> Debug for BitReadBuffer<'_, E> {
815 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
816 write!(
817 f,
818 "BitBuffer {{ bit_len: {}, endianness: {} }}",
819 self.bit_len(),
820 E::as_string()
821 )
822 }
823}
824
825impl<E: Endianness> PartialEq for BitReadBuffer<'_, E> {
826 fn eq(&self, other: &Self) -> bool {
827 if self.bit_len != other.bit_len {
828 return false;
829 }
830 if self.bit_len % 8 == 0 {
831 self.slice == other.slice
832 } else {
833 let bytes = self.bit_len / 8;
834 let bits_left = self.bit_len % 8;
835 if self.slice[0..bytes] != other.slice[0..bytes] {
836 return false;
837 }
838 let rest_self = self.read_int::<u8>(bytes * 8, bits_left).unwrap();
839 let rest_other = other.read_int::<u8>(bytes * 8, bits_left).unwrap();
840 rest_self == rest_other
841 }
842 }
843}
844
/// Returns `true` when any byte of `x` other than the most significant one
/// is zero.
///
/// The test uses a word-at-a-time bit trick: subtracting 1 from every
/// checked byte only sets that byte's high bit (while it was clear in `x`)
/// if a zero byte is present.
#[inline(always)]
fn contains_zero_byte_non_top(x: usize) -> bool {
    // Every byte except the top one set to 0xFF.
    const LOW_BYTES: usize = usize::MAX >> 8;
    // 0x01 in every checked byte (0x0001_0101_..._0101).
    const LO_USIZE: usize = LOW_BYTES / 0xFF;
    // 0x80 in every checked byte (0x0080_8080_..._8080).
    const HI_USIZE: usize = LO_USIZE << 7;

    (x.wrapping_sub(LO_USIZE) & !x & HI_USIZE) != 0
}
866
867#[cfg(feature = "serde")]
868use serde::{de, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
869
/// Serializes the buffer as a struct with two fields: `data` (the readable
/// bytes, with a final partial byte holding the remaining bits) and
/// `bit_length` (the readable length in bits).
#[cfg(feature = "serde")]
impl<'a, E: Endianness> Serialize for BitReadBuffer<'a, E> {
    // `std::result::Result` is spelled out because `Result` in this file is
    // the crate alias, which is used with a single type parameter elsewhere.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let full_bytes = self.bit_len() / 8;
        let bits_left = self.bit_len() % 8;

        let mut data = self
            .read_bytes(0, full_bytes)
            .expect("whole bytes below bit_len are always readable")
            .into_owned();
        if bits_left > 0 {
            data.push(
                self.read_int(full_bytes * 8, bits_left)
                    .expect("trailing bits below bit_len are always readable"),
            );
        }

        // The declared field count must match the fields written below;
        // length-prefixed formats rely on it.
        let mut s = serializer.serialize_struct("BitReadBuffer", 2)?;
        s.serialize_field("data", &data)?;
        s.serialize_field("bit_length", &self.bit_len())?;
        s.end()
    }
}
888
/// Deserializes a buffer from the `data` / `bit_length` struct written by the
/// `Serialize` impl. The result owns its bytes.
///
/// Fails with a custom error when `bit_length` is larger than the number of
/// bits in `data`.
#[cfg(feature = "serde")]
impl<'de, E: Endianness> Deserialize<'de> for BitReadBuffer<'static, E> {
    // `std::result::Result` is spelled out because `Result` in this file is
    // the crate alias, which is used with a single type parameter elsewhere.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct BitData {
            data: Vec<u8>,
            bit_length: usize,
        }

        let BitData { data, bit_length } = BitData::deserialize(deserializer)?;
        let mut buffer = BitReadBuffer::new_owned(data, E::endianness());
        // `truncate` rejects lengths beyond the deserialized data.
        buffer.truncate(bit_length).map_err(de::Error::custom)?;
        Ok(buffer)
    }
}
909
/// A truncated, non byte-aligned buffer must survive a JSON round trip.
#[cfg(feature = "serde")]
#[test]
fn test_serde_roundtrip() {
    use crate::LittleEndian;

    let original = {
        let mut buffer = BitReadBuffer::new_owned(vec![55; 8], LittleEndian);
        buffer.truncate(61).unwrap();
        buffer
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: BitReadBuffer<LittleEndian> = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded, original);
}