use core::convert::TryInto;
/// Lookup table whose entry `n` has exactly the lowest `n` bits set, so
/// `BIT_MASK[0] == 0` and `BIT_MASK[64] == u64::MAX`.
///
/// Only `mask_lower_bits` reads it, and only on targets built without BMI2 (tests
/// also read it). That is why `dead_code` is allowed when BMI2 is enabled.
#[cfg_attr(all(target_arch = "x86_64", target_feature = "bmi2"), allow(dead_code))]
const BIT_MASK: [u64; 65] = {
    let mut masks = [0u64; 65];
    let mut width: u32 = 1;
    while width <= 64 {
        // Shifting the all-ones word right leaves exactly `width` low bits set;
        // width == 64 shifts by 0 and keeps u64::MAX.
        masks[width as usize] = u64::MAX >> (64 - width);
        width += 1;
    }
    masks
};
/// Keeps the lowest `n` bits of `value` and clears the rest (`n` must be in `0..=64`).
///
/// BMI2 variant: a single BZHI instruction.
#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
#[inline(always)]
fn mask_lower_bits(value: u64, n: u8) -> u64 {
    debug_assert!(n <= 64, "mask_lower_bits: n must be <= 64, got {}", n);
    // SAFETY: this definition is only compiled when the `bmi2` target feature is
    // statically enabled, so the BZHI instruction is available. BZHI clears no bits
    // when the index is >= 64, which gives the required result for `n == 64`.
    unsafe { core::arch::x86_64::_bzhi_u64(value, u32::from(n)) }
}

/// Keeps the lowest `n` bits of `value` and clears the rest (`n` must be in `0..=64`).
///
/// Portable variant: AND with a precomputed mask.
#[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
#[inline(always)]
fn mask_lower_bits(value: u64, n: u8) -> u64 {
    debug_assert!(n <= 64, "mask_lower_bits: n must be <= 64, got {}", n);
    // BIT_MASK has 65 entries, so every n in 0..=64 is a valid index.
    value & BIT_MASK[usize::from(n)]
}
/// Reads bits from a byte slice, starting at its end and moving towards its start.
///
/// Within each byte, bits come out most-significant first, so the first bit read is
/// the highest bit of the last byte of `source`. Reads that go past the start of
/// `source` yield zero bits and make `bits_remaining` negative.
#[derive(Debug)]
pub struct BitReaderReversed<'s> {
    /// Bytes of `source` before this offset have not yet been loaded into
    /// `bit_container`.
    index: usize,
    /// How many bits of `bit_container`, counted from its most-significant end, have
    /// already been read (at most 64).
    bits_consumed: u8,
    /// How many zero bits were shifted into `bit_container` that do not come from
    /// `source`, because they lie before its start.
    extra_bits: usize,
    /// The input being read.
    source: &'s [u8],
    /// Buffered bits; the next unread bit sits at position `63 - bits_consumed`.
    bit_container: u64,
}
impl<'s> BitReaderReversed<'s> {
    /// Returns the number of bits of `source` that have not been read yet.
    ///
    /// Reads past the start of `source` produce zero bits. Once such padding bits
    /// have been read, the result is negative by their count.
    pub fn bits_remaining(&self) -> isize {
        self.index as isize * 8 + (64 - self.bits_consumed as isize) - self.extra_bits as isize
    }

    /// Creates a reader whose first read starts at the most-significant bit of the
    /// last byte of `source`.
    ///
    /// Nothing is buffered yet. `bits_consumed == 64` marks the container as fully
    /// consumed, so the first non-empty read triggers a refill.
    pub fn new(source: &'s [u8]) -> Self {
        BitReaderReversed {
            index: source.len(),
            bits_consumed: 64,
            source,
            bit_container: 0,
            extra_bits: 0,
        }
    }

    /// Discards whole consumed bytes from `bit_container` and loads the preceding
    /// bytes of `source`. Afterwards fewer than 8 bits are marked as consumed.
    ///
    /// Once `source` is exhausted, the container is shifted left instead and zero
    /// bits are fed in. Their count is added to `extra_bits`.
    #[cold]
    fn refill(&mut self) {
        let bytes_consumed = self.bits_consumed as usize / 8;
        if bytes_consumed == 0 {
            return;
        }
        if self.index >= bytes_consumed {
            // Enough unread bytes remain. Slide the 8-byte window back by the number
            // of fully consumed bytes and reload it, keeping only the sub-byte
            // remainder marked as consumed. The first refill happens with
            // bits_consumed == 64, so this branch is first taken with
            // index == len - 8. After that, index only decreases and `index..index + 8`
            // stays in bounds.
            self.index -= bytes_consumed;
            self.bits_consumed &= 7;
            self.bit_container =
                u64::from_le_bytes((&self.source[self.index..][..8]).try_into().unwrap());
        } else if self.index > 0 {
            // Fewer unread bytes remain than were consumed. Load the first (up to) 8
            // bytes of `source`, zero-padding the high bytes if it is shorter.
            if self.source.len() >= 8 {
                self.bit_container = u64::from_le_bytes((&self.source[..8]).try_into().unwrap());
            } else {
                let mut value = [0; 8];
                value[..self.source.len()].copy_from_slice(self.source);
                self.bit_container = u64::from_le_bytes(value);
            }
            // The window moved down by `index` bytes, so that many fewer bits count
            // as consumed. index < bytes_consumed <= 8, so `8 * index` fits in a u8.
            self.bits_consumed -= 8 * self.index as u8;
            self.index = 0;
            // Drop the consumed bits. The zeros shifted in lie before the start of
            // `source` and are recorded as padding.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else if self.bits_consumed < 64 {
            // `source` is exhausted: shift out the consumed bits and pad with zeros.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else {
            // Everything was consumed. A 64-bit shift would overflow, so clear the
            // container directly.
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
            self.bit_container = 0;
        }
        debug_assert!(self.bits_consumed < 8);
    }

    /// Reads and consumes the next `n` bits (`n <= 64`).
    ///
    /// The bits are returned in the low `n` bits of the result, with the first bit
    /// read as the most significant. Bits past the start of `source` read as zero.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        debug_assert!(n <= 64, "get_bits: n must be <= 64, got {}", n);
        if self.bits_consumed + n > 64 {
            self.refill();
            // A refill leaves fewer than 8 bits consumed, so only requests wider than
            // 56 bits can still overrun the container. Those are split in two
            // instead of producing a wrong value.
            if self.bits_consumed + n > 64 {
                return self.get_bits_wide(n);
            }
        }
        let value = self.peek_bits(n);
        self.consume(n);
        value
    }

    /// Slow path of [`Self::get_bits`] for requests wider than 56 bits that do not fit
    /// in the container even after a refill.
    ///
    /// Reads the high `n - 32` bits and then the low 32 bits, and joins them.
    #[cold]
    #[inline(never)]
    fn get_bits_wide(&mut self, n: u8) -> u64 {
        debug_assert!(n > 32);
        let high = self.get_bits(n - 32);
        let low = self.get_bits(32);
        (high << 32) | low
    }

    /// Makes sure the container can serve the next `n` bits (`n <= 56`) without
    /// another refill.
    ///
    /// After this call, `get_bits_unchecked` / `peek_bits` calls totalling at most `n`
    /// bits are valid. Once `source` is exhausted, the bits served are zero padding.
    #[inline(always)]
    pub fn ensure_bits(&mut self, n: u8) {
        debug_assert!(n <= 56);
        if self.bits_consumed + n > 64 {
            self.refill();
        }
    }

    /// Reads and consumes the next `n` bits (`n <= 56`) without checking for a refill.
    ///
    /// The caller must have reserved the bits with [`Self::ensure_bits`]. This is only
    /// verified in debug builds.
    #[inline(always)]
    pub fn get_bits_unchecked(&mut self, n: u8) -> u64 {
        debug_assert!(n <= 56);
        debug_assert!(
            self.bits_consumed + n <= 64,
            "get_bits_unchecked: not enough bits (consumed={}, requested={})",
            self.bits_consumed,
            n
        );
        let value = self.peek_bits(n);
        self.consume(n);
        value
    }

    /// Returns the next `n` bits without consuming them; returns `0` when `n == 0`.
    ///
    /// The caller must make sure at least `n` bits are buffered (for example via
    /// [`Self::ensure_bits`]). This is only verified in debug builds.
    #[inline(always)]
    pub fn peek_bits(&self, n: u8) -> u64 {
        debug_assert!(
            n == 0 || self.bits_consumed + n <= 64,
            "peek_bits: not enough bits (consumed={}, requested={})",
            self.bits_consumed,
            n
        );
        // For n == 0 the shift may reach 64. `wrapping_shr` keeps that defined, and
        // the zero-width mask then yields 0.
        let shift_by = (64u8 - self.bits_consumed).wrapping_sub(n);
        mask_lower_bits(self.bit_container.wrapping_shr(shift_by as u32), n)
    }

    /// Peeks `sum == n1 + n2 + n3` bits without consuming them and splits them into
    /// three fields, in read order (`n1` first).
    ///
    /// Requires at least `sum` buffered bits unless `sum == 0`. This is only verified
    /// in debug builds.
    #[inline(always)]
    pub fn peek_bits_triple(&self, sum: u8, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        debug_assert_eq!(
            u16::from(sum),
            u16::from(n1) + u16::from(n2) + u16::from(n3),
            "peek_bits_triple: sum ({}) must equal n1+n2+n3 ({}+{}+{})",
            sum,
            n1,
            n2,
            n3
        );
        debug_assert!(
            sum == 0 || self.bits_consumed + sum <= 64,
            "peek_bits_triple: not enough bits (consumed={}, requested={})",
            self.bits_consumed,
            sum
        );
        // Right-align all `sum` bits once, then carve out each field. Wrapping shifts
        // keep the zero-width edge cases (shift of 64) defined; the masks zero them.
        let shift_by = (64u8 - self.bits_consumed).wrapping_sub(sum);
        let all_three = self.bit_container.wrapping_shr(shift_by as u32);
        let val1 = mask_lower_bits(all_three.wrapping_shr(u32::from(n3) + u32::from(n2)), n1);
        let val2 = mask_lower_bits(all_three.wrapping_shr(u32::from(n3)), n2);
        let val3 = mask_lower_bits(all_three, n3);
        (val1, val2, val3)
    }

    /// Marks `n` bits as read. Does not refill.
    #[inline(always)]
    pub fn consume(&mut self, n: u8) {
        self.bits_consumed += n;
        debug_assert!(self.bits_consumed <= 64);
    }

    /// Reads `n1`, `n2` and `n3` bits in that order. The result is equivalent to three
    /// consecutive [`Self::get_bits`] calls.
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum_wide = u16::from(n1) + u16::from(n2) + u16::from(n3);
        if sum_wide <= 56 {
            // Fast path: one refill check and one shift serve all three fields.
            let sum = sum_wide as u8;
            self.ensure_bits(sum);
            let triple = self.peek_bits_triple(sum, n1, n2, n3);
            self.consume(sum);
            triple
        } else {
            // Too wide for a single container load: read the fields one by one.
            (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3))
        }
    }
}
#[cfg(test)]
mod test {
    use super::BitReaderReversed;

    #[test]
    fn it_works() {
        let data = [0b10101010, 0b01010101];
        let mut reader = BitReaderReversed::new(&data);
        // (width, expected value). The trailing reads run past the start of `data`
        // and therefore see zeros.
        let script: [(u8, u64); 8] = [
            (1, 0),
            (1, 1),
            (1, 0),
            (4, 0b1010),
            (4, 0b1101),
            (4, 0b0101),
            (4, 0b0000),
            (4, 0b0000),
        ];
        for &(width, expected) in script.iter() {
            assert_eq!(reader.get_bits(width), expected);
        }
        assert_eq!(reader.bits_remaining(), -7);
    }

    #[test]
    fn ensure_and_unchecked_match_get_bits() {
        let data: [u8; 10] = [0xDE, 0xAD, 0xBE, 0xEF, 0x42, 0x13, 0x37, 0xCA, 0xFE, 0x01];
        let widths: [u8; 9] = [0, 7, 13, 9, 8, 2, 9, 9, 8];

        // Reference values from the checked API.
        let mut reference = BitReaderReversed::new(&data);
        let mut expected = [0u64; 9];
        for (slot, &width) in expected.iter_mut().zip(widths.iter()) {
            *slot = reference.get_bits(width);
        }

        let mut fast = BitReaderReversed::new(&data);
        // The first six reads each reserve their own bits.
        for (&width, &want) in widths[..6].iter().zip(expected[..6].iter()) {
            fast.ensure_bits(width);
            assert_eq!(fast.get_bits_unchecked(width), want);
        }
        // The last three reads (9 + 9 + 8 = 26 bits) share one reservation.
        fast.ensure_bits(26);
        for (&width, &want) in widths[6..].iter().zip(expected[6..].iter()) {
            assert_eq!(fast.get_bits_unchecked(width), want);
        }
        assert_eq!(reference.bits_remaining(), fast.bits_remaining());
    }

    #[test]
    fn mask_table_correctness() {
        use super::BIT_MASK;
        let spot_checks: [(usize, u64); 7] = [
            (0, 0),
            (1, 1),
            (8, 0xFF),
            (16, 0xFFFF),
            (32, 0xFFFF_FFFF),
            (63, (1u64 << 63) - 1),
            (64, u64::MAX),
        ];
        for &(index, want) in spot_checks.iter() {
            assert_eq!(BIT_MASK[index], want);
        }
        for n in 0..64u32 {
            assert_eq!(
                BIT_MASK[n as usize],
                (1u64 << n) - 1,
                "BIT_MASK[{n}] mismatch"
            );
        }
    }

    #[test]
    fn mask_lower_bits_edge_cases() {
        use super::mask_lower_bits;
        const PATTERN: u64 = 0xABCD_1234_5678_9ABC;
        assert_eq!(mask_lower_bits(u64::MAX, 0), 0);
        assert_eq!(mask_lower_bits(u64::MAX, 1), 1);
        assert_eq!(mask_lower_bits(PATTERN, 64), PATTERN);
        assert_eq!(mask_lower_bits(PATTERN, 8), 0xBC);
        assert_eq!(mask_lower_bits(PATTERN, 16), 0x9ABC);
    }

    #[test]
    fn peek_bits_zero_is_always_zero() {
        let data = [0xFF; 8];
        let mut reader = BitReaderReversed::new(&data);
        // Fresh reader: nothing has been loaded yet.
        assert_eq!(reader.peek_bits(0), 0);
        // Partway through a loaded container.
        let _ = reader.get_bits(7);
        assert_eq!(reader.peek_bits(0), 0);
        // Nothing consumed: the internal shift amount reaches 64.
        reader.bits_consumed = 0;
        assert_eq!(reader.peek_bits(0), 0);
    }

    /// Checks that `get_bits_triple(n1, n2, n3)` agrees with three `get_bits` calls,
    /// optionally after both readers first skip `lead_in` bits.
    fn assert_triple_matches(data: &[u8], lead_in: Option<u8>, (n1, n2, n3): (u8, u8, u8)) {
        let mut reference = BitReaderReversed::new(data);
        let mut triple = BitReaderReversed::new(data);
        if let Some(skip) = lead_in {
            let _ = reference.get_bits(skip);
            let _ = triple.get_bits(skip);
        }
        let expected = (
            reference.get_bits(n1),
            reference.get_bits(n2),
            reference.get_bits(n3),
        );
        assert_eq!(expected, triple.get_bits_triple(n1, n2, n3));
        assert_eq!(reference.bits_remaining(), triple.bits_remaining());
    }

    #[test]
    fn get_bits_triple_matches_individual() {
        let data: [u8; 16] = [
            0xDE, 0xAD, 0xBE, 0xEF, 0x42, 0x13, 0x37, 0xCA, 0xFE, 0x01, 0x99, 0x88, 0x77, 0x66,
            0x55, 0x44,
        ];
        assert_triple_matches(&data, None, (8, 9, 9));
        assert_triple_matches(&data, Some(8), (8, 9, 9));
        assert_triple_matches(&data, None, (5, 0, 4));
    }
}