use core::convert::TryInto;
#[cfg(all(feature = "std", target_arch = "x86_64"))]
use std::sync::OnceLock;
/// `BIT_MASK[w]` has exactly the low `w` bits set, for every width `w` in `0..=64`.
///
/// Index 0 is the empty mask and index 64 is all ones, so a lookup never needs a special
/// case for the full-width shift that `(1 << 64) - 1` would require.
#[cfg_attr(all(target_arch = "x86_64", target_feature = "bmi2"), allow(dead_code))]
const BIT_MASK: [u64; 65] = {
    let mut masks = [0u64; 65];
    // Walk the widths downwards; `masks[0]` keeps its zero initialiser.
    let mut width = 64usize;
    while width > 0 {
        // Shifting all-ones right by `64 - width` leaves exactly `width` low bits set.
        masks[width] = u64::MAX >> (64 - width);
        width -= 1;
    }
    masks
};
/// Keeps only the low `n` bits of `value` (`n` in `0..=64`; `n == 64` returns `value`).
///
/// This variant is compiled when BMI2 is statically enabled and uses the `BZHI` instruction.
#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
#[inline(always)]
fn mask_lower_bits(value: u64, n: u8) -> u64 {
    debug_assert!(n <= 64, "mask_lower_bits: n must be <= 64, got {}", n);
    // SAFETY: this definition only exists when the `bmi2` target feature is enabled at
    // compile time, which is exactly what `_bzhi_u64` requires.
    unsafe { core::arch::x86_64::_bzhi_u64(value, u32::from(n)) }
}

/// Keeps only the low `n` bits of `value` (`n` in `0..=64`; `n == 64` returns `value`).
///
/// Portable variant: ANDs with the precomputed `BIT_MASK` entry for width `n`.
#[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
#[inline(always)]
fn mask_lower_bits(value: u64, n: u8) -> u64 {
    debug_assert!(n <= 64, "mask_lower_bits: n must be <= 64, got {}", n);
    let mask = BIT_MASK[usize::from(n)];
    value & mask
}
/// Outcome of the one-time CPU probe that chooses how `peek_bits_triple` splits fields.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[derive(Clone, Copy)]
struct TripleExtractDispatch {
    /// `true` when the BMI2 `PEXT` path should be used instead of shift-and-mask.
    use_pext: bool,
}

/// Process-wide cache for the probe result; filled on first use.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
static TRIPLE_EXTRACT_DISPATCH: OnceLock<TripleExtractDispatch> = OnceLock::new();
/// Decides whether the BMI2 `PEXT` path is worth using, given that BMI2 is present.
///
/// `vendor` is the 12-byte CPUID leaf-0 vendor string and `family` the effective CPU family.
/// PDEP/PEXT are microcoded and far slower than a few shifts on AMD family 0x17
/// (Zen, Zen+, Zen 2) and on Hygon family 0x18 (Dhyana, a Zen 1 derivative). Those CPUs
/// return `false`, and every other CPU returns `true`.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[inline(always)]
fn should_use_pext(vendor: [u8; 12], family: u32) -> bool {
    match &vendor {
        b"AuthenticAMD" => family != 0x17,
        b"HygonGenuine" => family != 0x18,
        _ => true,
    }
}
/// Returns the cached extraction strategy, running the CPU probe on the first call.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[inline(always)]
fn triple_extract_dispatch() -> &'static TripleExtractDispatch {
    let cell: &'static OnceLock<TripleExtractDispatch> = &TRIPLE_EXTRACT_DISPATCH;
    cell.get_or_init(detect_triple_extract_dispatch)
}
/// Probes the running CPU and decides whether the PEXT extraction path should be used.
///
/// Returns `use_pext: false` without issuing CPUID when BMI2 is not available. Otherwise
/// it reads the vendor string and effective family and defers to `should_use_pext`.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
fn detect_triple_extract_dispatch() -> TripleExtractDispatch {
    use core::arch::x86_64::__cpuid;
    use std::arch::is_x86_feature_detected;

    // `&&` short-circuits, so the CPUID probe below only runs on BMI2-capable CPUs.
    let use_pext = is_x86_feature_detected!("bmi2") && {
        // Leaf 0 spells the 12-byte vendor string across EBX, EDX, ECX, in that order.
        let vendor_leaf = __cpuid(0);
        let registers = [vendor_leaf.ebx, vendor_leaf.edx, vendor_leaf.ecx];
        let mut vendor = [0u8; 12];
        for (slot, register) in registers.iter().enumerate() {
            vendor[slot * 4..slot * 4 + 4].copy_from_slice(&register.to_le_bytes());
        }

        // Leaf 1 EAX: base family in bits 8..12, extended family in bits 20..28. The
        // extended field is only added when the base family is 0xF.
        let signature = __cpuid(1).eax;
        let base_family = (signature >> 8) & 0xF;
        let family = match base_family {
            0xF => base_family + ((signature >> 20) & 0xFF),
            other => other,
        };
        should_use_pext(vendor, family)
    };
    TripleExtractDispatch { use_pext }
}
/// Splits `all_three` with PEXT when the runtime dispatch allows it.
///
/// Returns `None` when the caller should fall back to the scalar shift-and-mask path.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[inline(always)]
fn try_extract_triple_with_pext(all_three: u64, n1: u8, n2: u8, n3: u8) -> Option<(u64, u64, u64)> {
    let dispatch = triple_extract_dispatch();
    // SAFETY: `use_pext` is only ever `true` when `is_x86_feature_detected!("bmi2")`
    // succeeded during detection, which satisfies `extract_triple_pext`'s BMI2 requirement.
    dispatch
        .use_pext
        .then(|| unsafe { extract_triple_pext(all_three, n1, n2, n3) })
}
/// Splits the low `n1 + n2 + n3` bits of `all_three` into three right-aligned fields using
/// BMI2 `PEXT`.
///
/// Layout, from most to least significant: `n1` bits, then `n2` bits, then `n3` bits in the
/// lowest positions. Returns `(field_n1, field_n2, field_n3)`. Each width indexes
/// `BIT_MASK`, so a width above 64 panics.
///
/// # Safety
///
/// The executing CPU must support BMI2.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[target_feature(enable = "bmi2")]
unsafe fn extract_triple_pext(all_three: u64, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
    use core::arch::x86_64::_pext_u64;
    // Bit offsets of the middle and top fields inside `all_three`.
    let low_shift = u32::from(n3);
    let mid_shift = low_shift + u32::from(n2);
    let low_mask = BIT_MASK[usize::from(n3)];
    let mid_mask = BIT_MASK[usize::from(n2)].wrapping_shl(low_shift);
    let high_mask = BIT_MASK[usize::from(n1)].wrapping_shl(mid_shift);
    (
        _pext_u64(all_three, high_mask),
        _pext_u64(all_three, mid_mask),
        _pext_u64(all_three, low_mask),
    )
}
/// Reads a byte slice as a bit stream running from its last byte back to its first.
///
/// Up to 64 bits are cached in `bit_container` and handed out from its most significant end.
pub struct BitReaderReversed<'s> {
    /// The bytes being read.
    source: &'s [u8],
    /// Position in `source` of the lowest byte loaded into `bit_container`. Bytes before it
    /// have not been loaded yet.
    index: usize,
    /// Cached little-endian window of `source`, consumed from its top bit downwards.
    bit_container: u64,
    /// Number of top bits of `bit_container` already handed out (0..=64).
    bits_consumed: u8,
    /// Number of zero padding bits shifted in after the start of `source` was reached.
    extra_bits: usize,
}
impl<'s> BitReaderReversed<'s> {
    /// Returns how many bits of `source` are still unread.
    ///
    /// Reading past the start of `source` yields zero bits (see `refill`). Each such padding
    /// bit lowers this count by one, so the count goes negative once more bits have been read
    /// than `source` contains.
    pub fn bits_remaining(&self) -> isize {
        self.index as isize * 8 + (64 - self.bits_consumed as isize) - self.extra_bits as isize
    }

    /// Creates a reader that walks `source` from its last byte towards its first, starting
    /// with the most significant bit of the last byte.
    ///
    /// Nothing is loaded yet. `bits_consumed == 64` marks the container as fully used, so the
    /// first read triggers a refill.
    pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            index: source.len(),
            bits_consumed: 64,
            source,
            bit_container: 0,
            extra_bits: 0,
        }
    }

    /// Drops the whole bytes already consumed from `bit_container` and loads the preceding
    /// bytes of `source`, leaving `bits_consumed < 8`.
    ///
    /// When fewer bytes remain before `index` than were consumed, the container is shifted
    /// left so the next unread bit is its most significant one. The vacated low bits are
    /// zero, and each of those padding bits is counted in `extra_bits`.
    #[cold]
    fn refill(&mut self) {
        let bytes_consumed = self.bits_consumed as usize / 8;
        if bytes_consumed == 0 {
            // Fewer than 8 bits consumed: no whole byte can be shifted out.
            return;
        }
        if self.index >= bytes_consumed {
            // Enough bytes remain: step back and reload the 8-byte window at the new
            // `index`. The partially consumed byte (if any) becomes the window's top byte.
            self.index -= bytes_consumed;
            self.bits_consumed &= 7;
            self.bit_container =
                u64::from_le_bytes((&self.source[self.index..][..8]).try_into().unwrap());
        } else if self.index > 0 {
            // Only `index` (< bytes_consumed) bytes are left.
            // 1. Load the first 8 bytes, zero-extended when `source` is shorter.
            // 2. Re-express `bits_consumed` relative to that window.
            // 3. Shift the consumed bits out and pad the bottom with zeros.
            if self.source.len() >= 8 {
                self.bit_container = u64::from_le_bytes((&self.source[..8]).try_into().unwrap());
            } else {
                let mut value = [0; 8];
                value[..self.source.len()].copy_from_slice(self.source);
                self.bit_container = u64::from_le_bytes(value);
            }
            self.bits_consumed -= 8 * self.index as u8;
            self.index = 0;
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else if self.bits_consumed < 64 {
            // Source exhausted: shift the consumed bits out and pad with zeros.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else {
            // Source exhausted and the container fully consumed. A `<<= 64` would overflow,
            // so clear the container explicitly.
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
            self.bit_container = 0;
        }
        debug_assert!(self.bits_consumed < 8);
    }

    /// Reads and consumes the next `n` bits (`n <= 64`), refilling as needed.
    ///
    /// The first bit read becomes the most significant bit of the result. Bits past the start
    /// of `source` read as zero.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        debug_assert!(n <= 64, "get_bits: n must be <= 64, got {}", n);
        if self.bits_consumed + n > 64 {
            if n > 56 {
                // A refill can leave up to 7 bits consumed, so a read this wide may still not
                // fit in the container afterwards; split it into two narrower reads.
                return self.get_bits_wide(n);
            }
            self.refill();
        }
        let value = self.peek_bits(n);
        self.consume(n);
        value
    }

    /// Reads `n` in `57..=64` bits as an `n - 32`-bit read followed by a 32-bit read.
    #[cold]
    fn get_bits_wide(&mut self, n: u8) -> u64 {
        let high = self.get_bits(n - 32);
        let low = self.get_bits(32);
        (high << 32) | low
    }

    /// Guarantees that at least `n` (`<= 56`) unread bits are cached. Afterwards they can be
    /// taken with [`Self::get_bits_unchecked`] or [`Self::peek_bits`] without a refill check.
    #[inline(always)]
    pub fn ensure_bits(&mut self, n: u8) {
        debug_assert!(n <= 56);
        if self.bits_consumed + n > 64 {
            self.refill();
        }
    }

    /// Reads and consumes `n` (`<= 56`) bits without refilling. The caller must have made them
    /// available first, e.g. with [`Self::ensure_bits`].
    #[inline(always)]
    pub fn get_bits_unchecked(&mut self, n: u8) -> u64 {
        debug_assert!(n <= 56);
        debug_assert!(
            self.bits_consumed + n <= 64,
            "get_bits_unchecked: not enough bits (consumed={}, requested={})",
            self.bits_consumed,
            n
        );
        let value = self.peek_bits(n);
        self.consume(n);
        value
    }

    /// Returns the next `n` bits without consuming them; `n == 0` always yields 0.
    ///
    /// The caller must ensure `bits_consumed + n <= 64`; no refill is performed.
    #[inline(always)]
    pub fn peek_bits(&mut self, n: u8) -> u64 {
        debug_assert!(
            n == 0 || self.bits_consumed + n <= 64,
            "peek_bits: not enough bits (consumed={}, requested={})",
            self.bits_consumed,
            n
        );
        // `wrapping_shr` takes the shift modulo 64. The only in-contract case where
        // `shift_by == 64` is `bits_consumed == 0 && n == 0`, and the zero-width mask then
        // yields 0.
        let shift_by = (64u8 - self.bits_consumed).wrapping_sub(n);
        mask_lower_bits(self.bit_container.wrapping_shr(shift_by as u32), n)
    }

    /// Returns the next three fields of `n1`, `n2` and `n3` bits (in reading order) without
    /// consuming them.
    ///
    /// `sum` must equal `n1 + n2 + n3`, and those bits must already be cached. Uses the PEXT
    /// path when the runtime dispatch selects it, otherwise shift-and-mask.
    #[inline(always)]
    pub fn peek_bits_triple(&mut self, sum: u8, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        debug_assert_eq!(
            u16::from(sum),
            u16::from(n1) + u16::from(n2) + u16::from(n3),
            "peek_bits_triple: sum ({}) must equal n1+n2+n3 ({}+{}+{})",
            sum,
            n1,
            n2,
            n3
        );
        debug_assert!(
            sum == 0 || self.bits_consumed + sum <= 64,
            "peek_bits_triple: not enough bits (consumed={}, requested={})",
            self.bits_consumed,
            sum
        );
        // Right-align the `sum` next bits: field 1 on top, field 3 in the lowest bits.
        let shift_by = (64u8 - self.bits_consumed).wrapping_sub(sum);
        let all_three = self.bit_container.wrapping_shr(shift_by as u32);
        #[cfg(all(feature = "std", target_arch = "x86_64"))]
        if let Some(values) = try_extract_triple_with_pext(all_three, n1, n2, n3) {
            return values;
        }
        let val1 = mask_lower_bits(all_three.wrapping_shr(u32::from(n3) + u32::from(n2)), n1);
        let val2 = mask_lower_bits(all_three.wrapping_shr(u32::from(n3)), n2);
        let val3 = mask_lower_bits(all_three, n3);
        (val1, val2, val3)
    }

    /// Marks `n` cached bits as read. Does not refill; the caller must not exceed 64
    /// consumed bits.
    #[inline(always)]
    pub fn consume(&mut self, n: u8) {
        self.bits_consumed += n;
        debug_assert!(self.bits_consumed <= 64);
    }

    /// Reads three consecutive fields of `n1`, `n2` and `n3` bits, in that order.
    ///
    /// The result equals three [`Self::get_bits`] calls. When the combined width fits in 56
    /// bits it uses a single refill check and one extraction instead.
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        // Sum in u16 so widths whose total exceeds 255 cannot wrap.
        let sum_wide = u16::from(n1) + u16::from(n2) + u16::from(n3);
        if sum_wide <= 56 {
            let sum = sum_wide as u8;
            self.ensure_bits(sum);
            let triple = self.peek_bits_triple(sum, n1, n2, n3);
            self.consume(sum);
            return triple;
        }
        (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3))
    }
}
#[cfg(test)]
mod test {
    // Runtime BMI2 check, used to skip the PEXT comparison test on CPUs without BMI2.
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    use std::arch::is_x86_feature_detected;

    /// Shift-and-mask reference for `extract_triple_pext`. Splits the low `n1 + n2 + n3` bits
    /// of `all_three` into (top `n1` bits, middle `n2` bits, low `n3` bits).
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    #[inline]
    fn scalar_extract_triple(all_three: u64, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let val3 = all_three & super::BIT_MASK[n3 as usize];
        let val2 = all_three.wrapping_shr(u32::from(n3)) & super::BIT_MASK[n2 as usize];
        let val1 =
            all_three.wrapping_shr(u32::from(n2) + u32::from(n3)) & super::BIT_MASK[n1 as usize];
        (val1, val2, val3)
    }

    /// Deterministic xorshift-style generator (shifts 13, 7, 17) for pseudo-random test words.
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    #[inline]
    fn next_test_value(state: &mut u64) -> u64 {
        let mut x = *state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *state = x;
        x
    }

    /// Bits come out starting at the most significant bit of the last byte. Reads past the
    /// start return zeros, and `bits_remaining` then goes negative (23 bits read from 16).
    #[test]
    fn it_works() {
        let data = [0b10101010, 0b01010101];
        let mut br = super::BitReaderReversed::new(&data);
        assert_eq!(br.get_bits(1), 0);
        assert_eq!(br.get_bits(1), 1);
        assert_eq!(br.get_bits(1), 0);
        assert_eq!(br.get_bits(4), 0b1010);
        assert_eq!(br.get_bits(4), 0b1101);
        assert_eq!(br.get_bits(4), 0b0101);
        assert_eq!(br.get_bits(4), 0b0000);
        assert_eq!(br.get_bits(4), 0b0000);
        assert_eq!(br.bits_remaining(), -7);
    }

    /// `ensure_bits` + `get_bits_unchecked` must produce the same values and final position
    /// as plain `get_bits`, including one `ensure_bits(26)` covering three later reads.
    #[test]
    fn ensure_and_unchecked_match_get_bits() {
        let data: [u8; 10] = [0xDE, 0xAD, 0xBE, 0xEF, 0x42, 0x13, 0x37, 0xCA, 0xFE, 0x01];
        let mut ref_br = super::BitReaderReversed::new(&data);
        let r1 = ref_br.get_bits(0);
        let r2 = ref_br.get_bits(7);
        let r3 = ref_br.get_bits(13);
        let r4 = ref_br.get_bits(9);
        let r5 = ref_br.get_bits(8);
        let r5b = ref_br.get_bits(2);
        let r6 = ref_br.get_bits(9);
        let r7 = ref_br.get_bits(9);
        let r8 = ref_br.get_bits(8);
        let mut fast_br = super::BitReaderReversed::new(&data);
        fast_br.ensure_bits(0);
        assert_eq!(fast_br.get_bits_unchecked(0), r1);
        fast_br.ensure_bits(7);
        assert_eq!(fast_br.get_bits_unchecked(7), r2);
        fast_br.ensure_bits(13);
        assert_eq!(fast_br.get_bits_unchecked(13), r3);
        fast_br.ensure_bits(9);
        assert_eq!(fast_br.get_bits_unchecked(9), r4);
        fast_br.ensure_bits(8);
        assert_eq!(fast_br.get_bits_unchecked(8), r5);
        fast_br.ensure_bits(2);
        assert_eq!(fast_br.get_bits_unchecked(2), r5b);
        // 26 = 9 + 9 + 8: one ensure covers the next three unchecked reads.
        fast_br.ensure_bits(26);
        assert_eq!(fast_br.get_bits_unchecked(9), r6);
        assert_eq!(fast_br.get_bits_unchecked(9), r7);
        assert_eq!(fast_br.get_bits_unchecked(8), r8);
        assert_eq!(ref_br.bits_remaining(), fast_br.bits_remaining());
    }

    /// `BIT_MASK[n]` has exactly the low `n` bits set, including the 0 and 64 endpoints.
    #[test]
    fn mask_table_correctness() {
        assert_eq!(super::BIT_MASK[0], 0);
        assert_eq!(super::BIT_MASK[1], 1);
        assert_eq!(super::BIT_MASK[8], 0xFF);
        assert_eq!(super::BIT_MASK[16], 0xFFFF);
        assert_eq!(super::BIT_MASK[32], 0xFFFF_FFFF);
        assert_eq!(super::BIT_MASK[63], (1u64 << 63) - 1);
        assert_eq!(super::BIT_MASK[64], u64::MAX);
        for n in 0..64u32 {
            assert_eq!(
                super::BIT_MASK[n as usize],
                (1u64 << n) - 1,
                "BIT_MASK[{n}] mismatch"
            );
        }
    }

    /// `mask_lower_bits` handles zero width, full 64-bit width and typical widths.
    #[test]
    fn mask_lower_bits_edge_cases() {
        assert_eq!(super::mask_lower_bits(u64::MAX, 0), 0);
        assert_eq!(super::mask_lower_bits(u64::MAX, 1), 1);
        assert_eq!(
            super::mask_lower_bits(0xABCD_1234_5678_9ABC, 64),
            0xABCD_1234_5678_9ABC
        );
        assert_eq!(super::mask_lower_bits(0xABCD_1234_5678_9ABC, 8), 0xBC);
        assert_eq!(super::mask_lower_bits(0xABCD_1234_5678_9ABC, 16), 0x9ABC);
    }

    /// A zero-width peek returns 0 in every container state, including `bits_consumed`
    /// 64 (fresh reader), after a read, and 0 (where the computed shift is 64).
    #[test]
    fn peek_bits_zero_is_always_zero() {
        let data = [0xFF; 8];
        let mut br = super::BitReaderReversed::new(&data);
        assert_eq!(br.peek_bits(0), 0);
        br.get_bits(7);
        assert_eq!(br.peek_bits(0), 0);
        br.bits_consumed = 0;
        assert_eq!(br.peek_bits(0), 0);
    }

    /// `get_bits_triple` must match three sequential `get_bits` calls: on a fresh reader,
    /// after a preceding 8-bit read, and with a zero-width middle field.
    #[test]
    fn get_bits_triple_matches_individual() {
        let data: [u8; 16] = [
            0xDE, 0xAD, 0xBE, 0xEF, 0x42, 0x13, 0x37, 0xCA, 0xFE, 0x01, 0x99, 0x88, 0x77, 0x66,
            0x55, 0x44,
        ];
        let mut ref_br = super::BitReaderReversed::new(&data);
        let r1 = ref_br.get_bits(8);
        let r2 = ref_br.get_bits(9);
        let r3 = ref_br.get_bits(9);
        let mut triple_br = super::BitReaderReversed::new(&data);
        let (t1, t2, t3) = triple_br.get_bits_triple(8, 9, 9);
        assert_eq!((r1, r2, r3), (t1, t2, t3));
        assert_eq!(ref_br.bits_remaining(), triple_br.bits_remaining());
        let mut ref_br = super::BitReaderReversed::new(&data);
        let mut triple_br = super::BitReaderReversed::new(&data);
        let _ = ref_br.get_bits(8);
        let _ = triple_br.get_bits(8);
        let r1 = ref_br.get_bits(8);
        let r2 = ref_br.get_bits(9);
        let r3 = ref_br.get_bits(9);
        let (t1, t2, t3) = triple_br.get_bits_triple(8, 9, 9);
        assert_eq!((r1, r2, r3), (t1, t2, t3));
        assert_eq!(ref_br.bits_remaining(), triple_br.bits_remaining());
        let mut ref_br = super::BitReaderReversed::new(&data);
        let mut triple_br = super::BitReaderReversed::new(&data);
        let r1 = ref_br.get_bits(5);
        let r2 = ref_br.get_bits(0);
        let r3 = ref_br.get_bits(4);
        let (t1, t2, t3) = triple_br.get_bits_triple(5, 0, 4);
        assert_eq!((r1, r2, r3), (t1, t2, t3));
        assert_eq!(ref_br.bits_remaining(), triple_br.bits_remaining());
    }

    /// Pins the vendor/family policy: only AMD family 0x17 is steered away from PEXT here.
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    #[test]
    fn should_use_pext_policy_table() {
        let cases = [
            (*b"AuthenticAMD", 0x17, false),
            (*b"AuthenticAMD", 0x19, true),
            (*b"GenuineIntel", 0x06, true),
        ];
        for (vendor, family, expected) in cases {
            assert_eq!(super::should_use_pext(vendor, family), expected);
        }
    }

    /// On BMI2-capable CPUs the PEXT extraction (direct and via the dispatcher, when it
    /// selects PEXT) must agree with the scalar reference. Checked for fixed and
    /// pseudo-random words across widths that include zero-width fields and totals up to 64.
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    #[test]
    fn bmi2_triple_extract_matches_scalar_reference() {
        if !is_x86_feature_detected!("bmi2") {
            return;
        }
        let widths = [
            (0, 0, 0),
            (1, 1, 1),
            (3, 5, 7),
            (8, 8, 8),
            (15, 16, 17),
            (21, 21, 21),
            (0, 13, 27),
            (31, 0, 1),
            (1, 31, 0),
            (20, 20, 24),
        ];
        let fixed_values = [
            0,
            1,
            u64::MAX,
            0x0123_4567_89AB_CDEF,
            0xFEDC_BA98_7654_3210,
            0xAAAA_AAAA_AAAA_AAAA,
            0x5555_5555_5555_5555,
            1u64 << 63,
            (1u64 << 32) - 1,
        ];
        for &(n1, n2, n3) in &widths {
            for &all_three in &fixed_values {
                let expected = scalar_extract_triple(all_three, n1, n2, n3);
                let pext = unsafe { super::extract_triple_pext(all_three, n1, n2, n3) };
                assert_eq!(pext, expected);
                if let Some(dispatched) = super::try_extract_triple_with_pext(all_three, n1, n2, n3)
                {
                    assert_eq!(dispatched, expected);
                }
            }
        }
        let mut state = 0xD6E8_FD9D_5A2C_19B7u64;
        for &(n1, n2, n3) in &widths {
            for _ in 0..64 {
                let all_three = next_test_value(&mut state);
                let expected = scalar_extract_triple(all_three, n1, n2, n3);
                let pext = unsafe { super::extract_triple_pext(all_three, n1, n2, n3) };
                assert_eq!(pext, expected);
                if let Some(dispatched) = super::try_extract_triple_with_pext(all_three, n1, n2, n3)
                {
                    assert_eq!(dispatched, expected);
                }
            }
        }
    }
}