#![allow(dead_code)]
/// Returns the lowest `n` bits of `val`, with every higher bit cleared.
///
/// This is the portable implementation that uses only shifts and masks.
/// Any `n` of 64 or more selects the whole word, so `val` is returned as is.
#[inline(always)]
pub fn extract_bits_fallback(val: u64, n: u8) -> u64 {
    // Shifting a `u64` by 64 or more would overflow, so full-width requests
    // are answered before any shift is performed.
    if n > 63 {
        return val;
    }
    // `u64::MAX << n` clears exactly the low `n` bits; inverting it yields the
    // mask that keeps only those bits (an all-zero mask when `n == 0`).
    val & !(u64::MAX << n)
}
/// Reports whether BMI2 instructions can be used.
///
/// Exactly one of the three cfg-gated blocks below survives compilation and
/// becomes the function's value:
///
/// * x86_64 built with the `bmi2` target feature: always `true`.
/// * x86_64 built without it: the result of a runtime
///   `is_x86_feature_detected!("bmi2")` check.
/// * any other architecture: always `false`.
pub fn has_bmi2() -> bool {
    #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
    {
        true
    }
    #[cfg(all(target_arch = "x86_64", not(target_feature = "bmi2")))]
    {
        std::is_x86_feature_detected!("bmi2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Returns `val` with the bits at position `n` and above cleared, using the
/// BMI2 `BZHI` instruction through the `_bzhi_u64` intrinsic.
///
/// Only compiled for x86_64 builds that enable the `bmi2` target feature.
///
/// # Safety
///
/// The CPU executing this function must support BMI2. The build already
/// assumes so (the function exists only when `bmi2` is statically enabled),
/// but the caller remains responsible for not running such a binary on
/// hardware without the instruction set.
#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
#[inline(always)]
pub unsafe fn extract_bits_bmi2(val: u64, n: u32) -> u64 {
    // SAFETY: BMI2 availability is the caller's obligation, as documented in
    // the `# Safety` section above.
    unsafe { std::arch::x86_64::_bzhi_u64(val, n) }
}
/// Returns the lowest `n` bits of `val`, choosing the implementation at
/// compile time.
///
/// On x86_64 builds with the `bmi2` target feature enabled the `_bzhi_u64`
/// intrinsic is used directly; every other build delegates to
/// `extract_bits_fallback`.
#[inline(always)]
pub fn extract_bits(val: u64, n: u8) -> u64 {
    // Exactly one of the two cfg-gated bindings below exists in any build.
    // SAFETY: this binding is compiled only when the `bmi2` target feature is
    // statically enabled, so the build already assumes the instruction exists.
    #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
    let low_bits = unsafe { std::arch::x86_64::_bzhi_u64(val, u32::from(n)) };

    #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
    let low_bits = extract_bits_fallback(val, n);

    low_bits
}
/// Extracts an `extra_bits`-wide field located just above the lowest
/// `codeword_bits` bits of `saved_bitbuf`.
///
/// The buffer is shifted right by `codeword_bits`, then the lowest
/// `extra_bits` bits of the result are returned.
///
/// Out-of-range counts are well defined instead of overflowing the shift:
///
/// * `codeword_bits >= 64` discards every bit, so the result is `0`.
/// * `extra_bits >= 64` keeps the entire shifted value.
#[inline(always)]
pub fn decode_extra_bits(saved_bitbuf: u64, codeword_bits: u8, extra_bits: u8) -> u64 {
    // `checked_shr` returns `None` for shift amounts of 64 or more, where a
    // plain `>>` would panic in debug builds and mask the count in release.
    let shifted = saved_bitbuf
        .checked_shr(u32::from(codeword_bits))
        .unwrap_or(0);
    // Same guard for the mask: a width of 64 or more selects every bit.
    // `bit` is always a power of two (>= 1), so `bit - 1` cannot underflow.
    let mask = 1u64
        .checked_shl(u32::from(extra_bits))
        .map_or(u64::MAX, |bit| bit - 1);
    shifted & mask
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer checks for the portable masking implementation,
    /// including the zero-width and full-width boundaries.
    #[test]
    fn test_extract_bits_fallback() {
        assert_eq!(extract_bits_fallback(0xFFFF, 8), 0xFF);
        assert_eq!(extract_bits_fallback(0xFFFF, 4), 0xF);
        assert_eq!(extract_bits_fallback(0b11111111, 3), 0b111);
        assert_eq!(extract_bits_fallback(0b10101010, 4), 0b1010);
        assert_eq!(extract_bits_fallback(u64::MAX, 0), 0);
        assert_eq!(extract_bits_fallback(u64::MAX, 64), u64::MAX);
    }

    /// Whichever implementation `extract_bits` was compiled to use, it must
    /// agree with the portable fallback for every possible width.
    #[test]
    fn test_extract_bits_matches_fallback() {
        let samples = [0u64, 1, 0xFFFF, 0xDEAD_BEEF_CAFE_F00D, u64::MAX];
        for &val in &samples {
            for n in 0..=u8::MAX {
                assert_eq!(
                    extract_bits(val, n),
                    extract_bits_fallback(val, n),
                    "val={:#x}, n={}",
                    val,
                    n
                );
            }
        }
    }

    /// Known-answer checks for pulling a field that sits above a codeword.
    #[test]
    fn test_decode_extra_bits() {
        let saved = 0b1_0000000u64;
        assert_eq!(decode_extra_bits(saved, 7, 1), 1);
        let saved = 0b101_00000u64;
        assert_eq!(decode_extra_bits(saved, 5, 3), 5);
        assert_eq!(decode_extra_bits(0xFFFF, 8, 0), 0);
    }

    /// Detection must run without panicking and must match what the build
    /// configuration already guarantees.
    #[test]
    fn test_bmi2_detection() {
        let has = has_bmi2();
        eprintln!("BMI2 available: {}", has);
        // Targets other than x86_64 always report `false`.
        if cfg!(not(target_arch = "x86_64")) {
            assert!(!has);
        }
        // A build with the feature statically enabled always reports `true`.
        if cfg!(all(target_arch = "x86_64", target_feature = "bmi2")) {
            assert!(has);
        }
    }

    /// Rough timing comparison between the fallback and the dispatching
    /// entry point.
    ///
    /// This is a benchmark, not a correctness check, so it is skipped by
    /// default to keep the normal test run fast.
    #[test]
    #[ignore = "timing benchmark; run with `cargo test -- --ignored --nocapture`"]
    fn bench_extract_bits() {
        use std::time::Instant;

        let iterations = 10_000_000u64;
        // Accumulated and printed at the end so the loops cannot be removed
        // as dead code.
        let mut sum = 0u64;

        let start = Instant::now();
        for i in 0..iterations {
            sum = sum.wrapping_add(extract_bits_fallback(i, (i & 0x3F) as u8));
        }
        let fallback_time = start.elapsed();

        let start = Instant::now();
        for i in 0..iterations {
            sum = sum.wrapping_add(extract_bits(i, (i & 0x3F) as u8));
        }
        let extract_time = start.elapsed();

        eprintln!("\nExtract bits benchmark ({} iterations):", iterations);
        eprintln!(" Fallback: {:.2}ms", fallback_time.as_secs_f64() * 1000.0);
        eprintln!(
            " extract_bits: {:.2}ms",
            extract_time.as_secs_f64() * 1000.0
        );
        eprintln!(
            " Speedup: {:.1}x",
            fallback_time.as_secs_f64() / extract_time.as_secs_f64()
        );
        eprintln!(" BMI2 available: {}", has_bmi2());
        eprintln!(" (sum to prevent optimization: {})", sum);
    }
}