use crate::error::{Result, ZiporaError};
use crate::succinct::rank_select::SimdCapabilities;
/// Runtime-detected BMI instruction-set capabilities for this process.
///
/// Obtained via [`Bmi2Capabilities::detect`] or the cached [`Bmi2Capabilities::get`].
#[derive(Debug, Clone)]
pub struct Bmi2Capabilities {
    /// Whether BMI1-class instructions are usable.
    /// NOTE(review): populated from the POPCNT CPUID flag, not BMI1 — see `detect`.
    pub has_bmi1: bool,
    /// Whether BMI2 instructions (PDEP/PEXT/BZHI) are usable.
    pub has_bmi2: bool,
    /// Full SIMD capability report this summary was derived from.
    pub simd_caps: SimdCapabilities,
}
impl Bmi2Capabilities {
    /// Probe the CPU once via `SimdCapabilities::detect()` and summarize the
    /// BMI-related flags.
    ///
    /// NOTE(review): `has_bmi1` is filled from `has_popcnt`. POPCNT and BMI1
    /// (and LZCNT) are distinct CPUID features; on a CPU with POPCNT but
    /// without BMI1/LZCNT, callers that gate `_lzcnt_u64`/`_tzcnt_u64` on
    /// `has_bmi1` could take an unsupported path. Confirm which flags
    /// `SimdCapabilities` actually exposes and map a real BMI1 flag if one
    /// exists.
    pub fn detect() -> Self {
        let simd_caps = SimdCapabilities::detect();
        Self {
            // Field order matters: the two flags are read out of `simd_caps`
            // before `simd_caps` itself is moved into the struct.
            has_bmi1: simd_caps.cpu_features.has_popcnt, has_bmi2: simd_caps.cpu_features.has_bmi2,
            simd_caps,
        }
    }
    /// Process-wide cached capabilities; detection runs exactly once
    /// (thread-safe via `OnceLock`).
    #[inline]
    pub fn get() -> &'static Self {
        use std::sync::OnceLock;
        static BMI2_CAPS: OnceLock<Bmi2Capabilities> = OnceLock::new();
        BMI2_CAPS.get_or_init(Self::detect)
    }
}
/// Marker type for hardware-accelerated rank (popcount) primitives; all methods are associated functions.
pub struct Bmi2RankOps;
/// Cache-prefetch helpers for rank/select data structures.
///
/// All functions are hints only: they never fault, never panic on
/// out-of-range indices, and compile to no-ops on non-x86_64 targets.
pub struct Bmi2PrefetchOps;

impl Bmi2PrefetchOps {
    /// Issue a T0 (all-cache-levels) prefetch hint for `ptr`.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    fn hint_t0(ptr: *const i8) {
        // SAFETY: _mm_prefetch is purely advisory and never faults; `ptr`
        // is always derived from a live, in-bounds slice element here.
        unsafe {
            std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr);
        }
    }

    /// Prefetch the bit-vector word at `word_index`, if in bounds.
    #[inline]
    pub fn prefetch_bit_data(bit_data: &[u64], word_index: usize) {
        #[cfg(target_arch = "x86_64")]
        {
            if let Some(word) = bit_data.get(word_index) {
                Self::hint_t0(word as *const u64 as *const i8);
            }
        }
    }

    /// Prefetch the rank-cache entry at `block_index`, if in bounds.
    #[inline]
    pub fn prefetch_rank_cache(rank_cache: &[u32], block_index: usize) {
        #[cfg(target_arch = "x86_64")]
        {
            if let Some(entry) = rank_cache.get(block_index) {
                Self::hint_t0(entry as *const u32 as *const i8);
            }
        }
    }

    /// Prefetch `prefetch_distance` words ahead of `start_word`, one hint per
    /// 8-word (64-byte cache line) stride.
    #[inline]
    pub fn prefetch_sequential(bit_data: &[u64], start_word: usize, prefetch_distance: usize) {
        #[cfg(target_arch = "x86_64")]
        {
            let end = (start_word + prefetch_distance).min(bit_data.len());
            let mut idx = start_word;
            while idx < end {
                Self::hint_t0(&bit_data[idx] as *const u64 as *const i8);
                idx += 8;
            }
        }
    }
}
impl Bmi2RankOps {
    /// Population count of `x`, using the POPCNT instruction when the runtime
    /// capability check passes, otherwise the portable `count_ones`.
    #[inline]
    pub fn popcount_u64(x: u64) -> u32 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            // `has_bmi1` is derived from the POPCNT flag in detect(), so this
            // gate matches the instruction used below.
            if caps.has_bmi1 {
                return unsafe { Self::popcount_hardware(x) };
            }
        }
        x.count_ones()
    }
    /// Hardware POPCNT.
    ///
    /// # Safety
    /// Caller must have verified POPCNT support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "popcnt")]
    #[inline]
    unsafe fn popcount_hardware(x: u64) -> u32 {
        std::arch::x86_64::_popcnt64(x as i64) as u32
    }
    /// Count 1-bits among the low `n` bits of `x` (i.e. rank within a word).
    /// `n == 0` returns 0; `n >= 64` counts the whole word.
    #[inline]
    pub fn popcount_trail(x: u64, n: u32) -> u32 {
        if n == 0 {
            return 0;
        }
        if n >= 64 {
            return Self::popcount_u64(x);
        }
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                // BZHI zeroes bits >= n in a single instruction.
                return unsafe { Self::popcount_trail_bzhi(x, n) };
            }
        }
        // Safe: n < 64 here, so the shift cannot overflow.
        let mask = (1u64 << n) - 1;
        Self::popcount_u64(x & mask)
    }
    /// BZHI + POPCNT fused path for `popcount_trail`.
    ///
    /// # Safety
    /// Caller must have verified BMI2 and POPCNT support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi2,popcnt")]
    #[inline]
    unsafe fn popcount_trail_bzhi(x: u64, n: u32) -> u32 {
        let masked = std::arch::x86_64::_bzhi_u64(x, n);
        unsafe { Self::popcount_hardware(masked) }
    }
    /// Leading-zero count; LZCNT instruction when available, portable fallback
    /// otherwise. The `x != 0` guard keeps the hardware path off the edge case
    /// even though LZCNT itself defines lzcnt(0) == 64.
    ///
    /// NOTE(review): the gate is `has_bmi1`, which detect() fills from the
    /// POPCNT flag — LZCNT is a separate CPUID feature; confirm the mapping.
    #[inline]
    pub fn leading_zeros(x: u64) -> u32 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi1 && x != 0 {
                return unsafe { std::arch::x86_64::_lzcnt_u64(x) as u32 };
            }
        }
        x.leading_zeros()
    }
    /// Trailing-zero count; TZCNT when available, portable fallback otherwise.
    #[inline]
    pub fn trailing_zeros(x: u64) -> u32 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi1 && x != 0 {
                return unsafe { std::arch::x86_64::_tzcnt_u64(x) as u32 };
            }
        }
        x.trailing_zeros()
    }
    /// Popcount of every word in `words`. Uses the AVX2 batch path for 4+
    /// words when available, otherwise a scalar map.
    pub fn popcount_bulk(words: &[u64]) -> Vec<u32> {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.simd_caps.cpu_features.has_avx2 && words.len() >= 4 {
                return unsafe { Self::popcount_bulk_avx2(words) };
            }
        }
        words.iter().map(|&w| Self::popcount_u64(w)).collect()
    }
    /// AVX2-assisted bulk popcount: loads 4 words per iteration, then counts
    /// each extracted lane with hardware POPCNT (there is no 64-bit vector
    /// popcount in AVX2 itself).
    ///
    /// # Safety
    /// Caller must have verified AVX2 and POPCNT support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,popcnt")]
    unsafe fn popcount_bulk_avx2(words: &[u64]) -> Vec<u32> {
        use std::arch::x86_64::*;
        let mut result = Vec::with_capacity(words.len());
        let chunks = words.chunks_exact(4);
        let remainder = chunks.remainder();
        for chunk in chunks {
            // SAFETY: chunks_exact(4) guarantees 4 readable u64s (32 bytes);
            // loadu tolerates unaligned addresses.
            let vec = unsafe { _mm256_loadu_si256(chunk.as_ptr() as *const __m256i) };
            let mut chunk_results = [0u32; 4];
            for i in 0..4 {
                // _mm256_extract_epi64 requires a const lane index, hence the
                // match over the runtime loop variable.
                let word = unsafe {
                    match i {
                        0 => _mm256_extract_epi64::<0>(vec) as u64,
                        1 => _mm256_extract_epi64::<1>(vec) as u64,
                        2 => _mm256_extract_epi64::<2>(vec) as u64,
                        3 => _mm256_extract_epi64::<3>(vec) as u64,
                        _ => unreachable!(),
                    }
                };
                chunk_results[i] = unsafe { Self::popcount_hardware(word) };
            }
            result.extend_from_slice(&chunk_results);
        }
        for &word in remainder {
            result.push(unsafe { Self::popcount_hardware(word) });
        }
        result
    }
}
/// Select (k-th set bit) primitives, accelerated with PDEP+TZCNT where BMI2
/// is available.
pub struct Bmi2SelectOps;
impl Bmi2SelectOps {
    /// Position (0-63) of the k-th (0-based) 1-bit of `word`, or `None` if
    /// `word` has fewer than `k + 1` set bits.
    #[inline]
    pub fn select1_u64(word: u64, k: u32) -> Option<u32> {
        if word == 0 || k >= word.count_ones() {
            return None;
        }
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return Some(unsafe { Self::select1_pdep_optimized(word, k) });
            }
        }
        Self::select1_binary_search(word, k)
    }
    /// Same contract as `select1_u64`, but the software fallback uses the
    /// binary-search variant instead of the linear scan.
    #[inline]
    pub fn select1_u64_enhanced(word: u64, k: u32) -> Option<u32> {
        if word == 0 || k >= word.count_ones() {
            return None;
        }
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return Some(unsafe { Self::select1_pdep_optimized(word, k) });
            }
        }
        Self::select1_binary_search_optimized(word, k)
    }
    /// Classic PDEP+TZCNT select: deposit a single bit at the k-th 1-slot of
    /// `word`, then find its position.
    ///
    /// NOTE(review): appears to have no callers in this file (the
    /// `_optimized` variant is used everywhere) — confirm before removing.
    ///
    /// # Safety
    /// Caller must have verified BMI1+BMI2 support and `k < popcount(word)`.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi1,bmi2")]
    #[inline]
    unsafe fn select1_pdep(word: u64, k: u32) -> u32 {
        let deposited = std::arch::x86_64::_pdep_u64(1u64 << k, word);
        std::arch::x86_64::_tzcnt_u64(deposited) as u32
    }
    /// PDEP-based select with a per-target-env choice of trailing-zero
    /// implementation (both branches are semantically identical here since
    /// the callers guarantee `deposited != 0`).
    ///
    /// # Safety
    /// Caller must have verified BMI1+BMI2 support and `k < popcount(word)`.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi1,bmi2")]
    #[inline]
    unsafe fn select1_pdep_optimized(word: u64, k: u32) -> u32 {
        debug_assert!(k < word.count_ones());
        debug_assert!(word != 0);
        #[cfg(any(target_env = "gnu", target_env = ""))]
        {
            let deposited = std::arch::x86_64::_pdep_u64(1u64 << k, word);
            deposited.trailing_zeros()
        }
        #[cfg(not(any(target_env = "gnu", target_env = "")))]
        {
            let deposited = std::arch::x86_64::_pdep_u64(1u64 << k, word);
            std::arch::x86_64::_tzcnt_u64(deposited) as u32
        }
    }
    /// Linear-scan fallback select (despite the name, this is a plain scan).
    fn select1_binary_search(word: u64, k: u32) -> Option<u32> {
        let mut ones_seen = 0;
        for bit_pos in 0..64 {
            if (word >> bit_pos) & 1 == 1 {
                if ones_seen == k {
                    return Some(bit_pos);
                }
                ones_seen += 1;
            }
        }
        None
    }
    /// Fallback select that linear-scans sparse words (<= 8 ones) and binary
    /// searches on rank for denser words.
    fn select1_binary_search_optimized(word: u64, k: u32) -> Option<u32> {
        let ones_count = word.count_ones();
        if k >= ones_count {
            return None;
        }
        if ones_count <= 8 {
            // Sparse word: linear scan is cheaper than repeated popcounts.
            let mut ones_seen = 0;
            for bit_pos in 0..64 {
                if (word >> bit_pos) & 1 == 1 {
                    if ones_seen == k {
                        return Some(bit_pos);
                    }
                    ones_seen += 1;
                }
            }
            return None;
        }
        // Invariant: after the loop, `low` is the smallest prefix length with
        // rank > k, so bit (low - 1) is the k-th set bit.
        let mut low = 0u32;
        let mut high = 64u32;
        while low < high {
            let mid = (low + high) / 2;
            let rank_at_mid = Self::popcount_range(word, 0, mid);
            if rank_at_mid <= k {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        if low > 0 && (word >> (low - 1)) & 1 == 1 {
            let rank_before = if low == 1 { 0 } else { Self::popcount_range(word, 0, low - 1) };
            if rank_before == k {
                return Some(low - 1);
            }
        }
        None
    }
    /// Count 1-bits of `word` in bit positions [start, end) (end clamped to 64).
    #[inline]
    fn popcount_range(word: u64, start: u32, end: u32) -> u32 {
        if start >= end || start >= 64 {
            return 0;
        }
        let end = end.min(64);
        let mask = if end == 64 {
            u64::MAX
        } else {
            (1u64 << end) - 1
        };
        let start_mask = if start == 0 {
            0
        } else {
            (1u64 << start) - 1
        };
        let range_mask = mask & !start_mask;
        (word & range_mask).count_ones()
    }
    /// Position of the k-th 0-bit: select1 on the bitwise complement.
    #[inline]
    pub fn select0_u64(word: u64, k: u32) -> Option<u32> {
        let inverted = !word;
        Self::select1_u64(inverted, k)
    }
    /// Select using a sampled select cache to narrow the block range, then a
    /// linear or binary search over the rank cache, then an in-block scan.
    ///
    /// NOTE(review): panics if `cache_density == 0` (division below) —
    /// callers presumably guarantee a positive density; confirm.
    /// NOTE(review): the range uses `select_cache[cache_idx - 1]` as start
    /// and `select_cache[cache_idx + 1]` as end, skipping `cache_idx` itself;
    /// verify this matches the cache layout produced by the builder.
    pub fn select1_hybrid_cache(
        rank_cache: &[u32],
        select_cache: &[u32],
        bit_data: &[u64],
        target_rank: u32,
        cache_density: usize,
    ) -> Result<u32> {
        if select_cache.is_empty() {
            return Err(ZiporaError::invalid_data("Empty select cache"));
        }
        let cache_idx = (target_rank / cache_density as u32) as usize;
        if cache_idx >= select_cache.len() {
            return Err(ZiporaError::invalid_data("Cache index out of bounds"));
        }
        let start_block = if cache_idx == 0 { 0 } else { select_cache[cache_idx - 1] };
        let end_block = if cache_idx + 1 < select_cache.len() {
            select_cache[cache_idx + 1]
        } else {
            rank_cache.len() as u32
        };
        let search_range = end_block - start_block;
        if search_range <= 32 {
            // Small range: linear scan for the first block whose cumulative
            // rank exceeds the target.
            for block_idx in start_block..end_block.min(rank_cache.len() as u32) {
                if rank_cache[block_idx as usize] > target_rank {
                    return Self::find_select_in_block(bit_data, block_idx as usize, target_rank, rank_cache);
                }
            }
        } else {
            // Large range: binary search for the same boundary.
            let mut low = start_block;
            let mut high = end_block.min(rank_cache.len() as u32);
            while low < high {
                let mid = (low + high) / 2;
                if rank_cache[mid as usize] <= target_rank {
                    low = mid + 1;
                } else {
                    high = mid;
                }
            }
            if low < rank_cache.len() as u32 {
                return Self::find_select_in_block(bit_data, low as usize, target_rank, rank_cache);
            }
        }
        Err(ZiporaError::invalid_data("Select target not found"))
    }
    /// Locate the bit position for `target_rank` inside one 256-bit block
    /// (4 consecutive words), given the cumulative rank cache.
    fn find_select_in_block(
        bit_data: &[u64],
        block_idx: usize,
        target_rank: u32,
        rank_cache: &[u32],
    ) -> Result<u32> {
        // Blocks are 256 bits = 4 u64 words.
        let block_start_bit = block_idx * 256; let block_start_word = block_start_bit / 64;
        #[cfg(target_arch = "x86_64")]
        {
            if block_start_word < bit_data.len() {
                // SAFETY: index checked above; prefetch is advisory only.
                unsafe {
                    let ptr = bit_data.as_ptr().add(block_start_word) as *const i8;
                    std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr);
                }
            }
        }
        // rank_cache[i] is presumably the cumulative rank through block i;
        // rank before block 0 is 0.
        let rank_before_block = if block_idx == 0 { 0 } else { rank_cache[block_idx - 1] };
        let target_in_block = target_rank - rank_before_block;
        let mut current_rank = 0;
        for word_in_block in 0..4 {
            let word_idx = block_start_word + word_in_block;
            if word_idx >= bit_data.len() {
                break;
            }
            let word = bit_data[word_idx];
            let word_popcount = word.count_ones();
            if current_rank + word_popcount > target_in_block {
                let target_in_word = target_in_block - current_rank;
                if let Some(bit_pos) = Self::select1_u64_enhanced(word, target_in_word) {
                    return Ok((word_idx * 64) as u32 + bit_pos);
                }
            }
            current_rank += word_popcount;
        }
        Err(ZiporaError::invalid_data("Select position not found in block"))
    }
    /// Select each index in `indices` over the concatenated bit vector
    /// `words`. O(words * indices): each query rescans from the first word.
    pub fn select1_bulk(words: &[u64], indices: &[u32]) -> Result<Vec<u32>> {
        let mut results = Vec::with_capacity(indices.len());
        let _global_ones_count = 0u32;
        for &target_k in indices {
            let mut remaining_k = target_k;
            let mut found = false;
            for (word_idx, &word) in words.iter().enumerate() {
                let word_ones = Bmi2RankOps::popcount_u64(word);
                if remaining_k < word_ones {
                    if let Some(bit_pos) = Self::select1_u64(word, remaining_k) {
                        results.push((word_idx * 64) as u32 + bit_pos);
                        found = true;
                        break;
                    }
                }
                remaining_k -= word_ones;
            }
            if !found {
                return Err(ZiporaError::invalid_data(format!(
                    "Select index {} not found",
                    target_k
                )));
            }
        }
        Ok(results)
    }
}
/// PEXT/PDEP-style bit extraction and deposit, plus BMI1 single-bit
/// manipulation helpers, each with a portable software fallback.
pub struct Bmi2BitOps;
impl Bmi2BitOps {
    /// Gather the bits of `src` selected by `mask` into the low bits of the
    /// result (PEXT semantics).
    #[inline]
    pub fn extract_bits(src: u64, mask: u64) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return unsafe { std::arch::x86_64::_pext_u64(src, mask) };
            }
        }
        Self::extract_bits_manual(src, mask)
    }
    /// Bit-by-bit software PEXT.
    fn extract_bits_manual(src: u64, mask: u64) -> u64 {
        let mut result = 0u64;
        let mut result_pos = 0;
        for bit_pos in 0..64 {
            if (mask >> bit_pos) & 1 == 1 {
                if (src >> bit_pos) & 1 == 1 {
                    result |= 1u64 << result_pos;
                }
                result_pos += 1;
            }
        }
        result
    }
    /// Scatter the low bits of `src` to the positions selected by `mask`
    /// (PDEP semantics).
    #[inline]
    pub fn deposit_bits(src: u64, mask: u64) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return unsafe { std::arch::x86_64::_pdep_u64(src, mask) };
            }
        }
        Self::deposit_bits_manual(src, mask)
    }
    /// Bit-by-bit software PDEP.
    fn deposit_bits_manual(src: u64, mask: u64) -> u64 {
        let mut result = 0u64;
        let mut src_pos = 0;
        for bit_pos in 0..64 {
            if (mask >> bit_pos) & 1 == 1 {
                if (src >> src_pos) & 1 == 1 {
                    result |= 1u64 << bit_pos;
                }
                src_pos += 1;
            }
        }
        result
    }
    /// Clear the lowest set bit (BLSR): `x & (x - 1)`.
    #[inline]
    pub fn reset_lowest_bit(x: u64) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi1 {
                return unsafe { std::arch::x86_64::_blsr_u64(x) };
            }
        }
        x & (x.wrapping_sub(1))
    }
    /// Isolate the lowest set bit (BLSI): `x & -x`.
    #[inline]
    pub fn isolate_lowest_bit(x: u64) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi1 {
                return unsafe { std::arch::x86_64::_blsi_u64(x) };
            }
        }
        x & x.wrapping_neg()
    }
    /// Mask from bit 0 up to and including the lowest set bit (BLSMSK):
    /// `x ^ (x - 1)`.
    #[inline]
    pub fn mask_up_to_lowest_bit(x: u64) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi1 {
                return unsafe { std::arch::x86_64::_blsmsk_u64(x) };
            }
        }
        x ^ (x.wrapping_sub(1))
    }
}
/// BEXTR-style contiguous bit-field extraction.
///
/// NOTE(review): BEXTR is a BMI1 instruction (the `#[target_feature]`
/// attributes below correctly say "bmi1"), but the runtime gates check
/// `has_bmi2`. That is conservative rather than wrong on CPUs where BMI2
/// implies BMI1, but confirm the intended gating.
pub struct Bmi2BextrOps;
impl Bmi2BextrOps {
    /// Extract `length` bits of `src` starting at bit `start`, returned in
    /// the low bits. Out-of-range requests are clamped; `length == 0` or
    /// `start >= 64` yields 0.
    #[inline]
    pub fn extract_bits_bextr(src: u64, start: u32, length: u32) -> u64 {
        if length == 0 || start >= 64 {
            return 0;
        }
        let effective_length = length.min(64 - start);
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return unsafe { Self::extract_bits_bextr_hardware(src, start, effective_length) };
            }
        }
        let shifted = src >> start;
        // effective_length == 64 only when start == 0 and length >= 64.
        if effective_length >= 64 {
            shifted
        } else {
            shifted & ((1u64 << effective_length) - 1)
        }
    }
    /// Hardware BEXTR.
    ///
    /// # Safety
    /// Caller must have verified BMI1 support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi1")]
    #[inline]
    unsafe fn extract_bits_bextr_hardware(src: u64, start: u32, length: u32) -> u64 {
        std::arch::x86_64::_bextr_u64(src, start, length)
    }
    /// Extract the same bit field from each source word.
    pub fn extract_bits_multi(sources: &[u64], start: u32, length: u32) -> Vec<u64> {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 && sources.len() >= 4 {
                return unsafe { Self::extract_bits_multi_vectorized(sources, start, length) };
            }
        }
        sources.iter()
            .map(|&src| Self::extract_bits_bextr(src, start, length))
            .collect()
    }
    /// Batch BEXTR loop. Despite the name, this is scalar BEXTR per element
    /// (there is no vector BEXTR); the win is skipping per-element capability
    /// checks and clamping.
    ///
    /// # Safety
    /// Caller must have verified BMI1 support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi1")]
    unsafe fn extract_bits_multi_vectorized(sources: &[u64], start: u32, length: u32) -> Vec<u64> {
        let mut results = Vec::with_capacity(sources.len());
        for &src in sources {
            results.push(std::arch::x86_64::_bextr_u64(src, start, length));
        }
        results
    }
    /// Low `bucket_bits` bits of `hash` as a hash-table bucket index.
    #[inline]
    pub fn extract_hash_bucket(hash: u64, bucket_bits: u32) -> u32 {
        Self::extract_bits_bextr(hash, 0, bucket_bits) as u32
    }
    /// Bucket indices for a batch of hashes.
    pub fn extract_hash_buckets_bulk(hashes: &[u64], bucket_bits: u32) -> Vec<u32> {
        Self::extract_bits_multi(hashes, 0, bucket_bits)
            .into_iter()
            .map(|x| x as u32)
            .collect()
    }
}
/// BZHI (zero high bits) based masking and masked-popcount helpers.
pub struct Bmi2BzhiOps;
impl Bmi2BzhiOps {
    /// Popcount of the low `bit_count` bits of `word`; `bit_count >= 64`
    /// counts the whole word.
    #[inline]
    pub fn popcount_bzhi_enhanced(word: u64, bit_count: u32) -> u32 {
        if bit_count == 0 {
            return 0;
        }
        if bit_count >= 64 {
            return Bmi2RankOps::popcount_u64(word);
        }
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return unsafe { Self::popcount_bzhi_hardware(word, bit_count) };
            }
        }
        // Safe: bit_count < 64 here.
        let mask = (1u64 << bit_count) - 1;
        Bmi2RankOps::popcount_u64(word & mask)
    }
    /// BZHI + POPCNT fused path.
    ///
    /// # Safety
    /// Caller must have verified BMI2 and POPCNT support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi2,popcnt")]
    #[inline]
    unsafe fn popcount_bzhi_hardware(word: u64, bit_count: u32) -> u32 {
        let masked = std::arch::x86_64::_bzhi_u64(word, bit_count);
        unsafe { Bmi2RankOps::popcount_hardware(masked) }
    }
    /// For each boundary b, return `word` with bits >= b cleared.
    /// BZHI itself returns `word` unchanged for b >= 64, matching the
    /// software fallback. The capability check runs once per element; cheap
    /// since `get()` is a cached static.
    pub fn bzhi_mask_operations(word: u64, boundaries: &[u32]) -> Vec<u64> {
        boundaries.iter()
            .map(|&boundary| {
                #[cfg(target_arch = "x86_64")]
                {
                    let caps = Bmi2Capabilities::get();
                    if caps.has_bmi2 {
                        return unsafe { std::arch::x86_64::_bzhi_u64(word, boundary) };
                    }
                }
                if boundary >= 64 {
                    word
                } else {
                    word & ((1u64 << boundary) - 1)
                }
            })
            .collect()
    }
    /// Apply the same low-`bit_count` mask to every word.
    pub fn bzhi_bulk_process(words: &[u64], bit_count: u32) -> Vec<u64> {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 && words.len() >= 4 {
                return unsafe { Self::bzhi_bulk_vectorized(words, bit_count) };
            }
        }
        let mask = if bit_count >= 64 {
            u64::MAX
        } else {
            (1u64 << bit_count) - 1
        };
        words.iter().map(|&word| word & mask).collect()
    }
    /// Scalar BZHI per element (no vector BZHI exists); avoids per-element
    /// capability checks.
    ///
    /// # Safety
    /// Caller must have verified BMI2 support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi2")]
    unsafe fn bzhi_bulk_vectorized(words: &[u64], bit_count: u32) -> Vec<u64> {
        words.iter()
            .map(|&word| std::arch::x86_64::_bzhi_u64(word, bit_count))
            .collect()
    }
}
/// Composite BMI2 instruction patterns (PDEP+CTZ select, parallel PEXT,
/// PDEP+BZHI fusion).
pub struct Bmi2AdvancedPatterns;
impl Bmi2AdvancedPatterns {
    /// k-th set bit of `word` via the PDEP+CTZ pattern, falling back to
    /// `Bmi2SelectOps::select1_u64` without BMI2.
    #[inline]
    pub fn pdep_ctz_select(word: u64, k: u32) -> Option<u32> {
        if word == 0 || k >= word.count_ones() {
            return None;
        }
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return Some(unsafe { Self::pdep_ctz_select_hardware(word, k) });
            }
        }
        Bmi2SelectOps::select1_u64(word, k)
    }
    /// Deposit a lone bit into the k-th 1-slot of `word`, then count trailing
    /// zeros to recover its position.
    ///
    /// # Safety
    /// Caller must have verified BMI1+BMI2 support and `k < popcount(word)`.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi1,bmi2")]
    #[inline]
    unsafe fn pdep_ctz_select_hardware(word: u64, k: u32) -> u32 {
        let deposited = std::arch::x86_64::_pdep_u64(1u64 << k, word);
        std::arch::x86_64::_tzcnt_u64(deposited) as u32
    }
    /// PEXT the same `mask` out of each source word.
    pub fn pext_parallel_extract(sources: &[u64], mask: u64) -> Vec<u64> {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return unsafe { Self::pext_parallel_hardware(sources, mask) };
            }
        }
        sources.iter()
            .map(|&src| Bmi2BitOps::extract_bits(src, mask))
            .collect()
    }
    /// Scalar PEXT per element; single upfront capability check.
    ///
    /// # Safety
    /// Caller must have verified BMI2 support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi2")]
    unsafe fn pext_parallel_hardware(sources: &[u64], mask: u64) -> Vec<u64> {
        sources.iter()
            .map(|&src| std::arch::x86_64::_pext_u64(src, mask))
            .collect()
    }
    /// PDEP `src` into `mask`, then clear bits >= `bit_limit`.
    pub fn pdep_bzhi_composite(src: u64, mask: u64, bit_limit: u32) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return unsafe { Self::pdep_bzhi_composite_hardware(src, mask, bit_limit) };
            }
        }
        let deposited = Bmi2BitOps::deposit_bits(src, mask);
        if bit_limit >= 64 {
            deposited
        } else {
            deposited & ((1u64 << bit_limit) - 1)
        }
    }
    /// Fused PDEP+BZHI. BZHI with bit_limit >= 64 returns the input
    /// unchanged, matching the software fallback.
    ///
    /// # Safety
    /// Caller must have verified BMI2 support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi2")]
    #[inline]
    unsafe fn pdep_bzhi_composite_hardware(src: u64, mask: u64, bit_limit: u32) -> u64 {
        let deposited = std::arch::x86_64::_pdep_u64(src, mask);
        std::arch::x86_64::_bzhi_u64(deposited, bit_limit)
    }
    /// Select several indices within a single word, with bounds validation.
    pub fn pdep_ctz_select_bulk(word: u64, indices: &[u32]) -> Result<Vec<u32>> {
        if word == 0 {
            return Err(ZiporaError::invalid_data("Cannot select from zero word"));
        }
        let word_popcount = word.count_ones();
        let mut results = Vec::with_capacity(indices.len());
        for &k in indices {
            if k >= word_popcount {
                return Err(ZiporaError::invalid_data(format!(
                    "Select index {} out of bounds (word has {} ones)", k, word_popcount
                )));
            }
            if let Some(pos) = Self::pdep_ctz_select(word, k) {
                results.push(pos);
            } else {
                // Unreachable in practice: k < popcount guarantees Some.
                return Err(ZiporaError::invalid_data(format!(
                    "Select failed for index {}", k
                )));
            }
        }
        Ok(results)
    }
}
/// Popcount over arbitrary bit ranges of a word.
pub struct Bmi2RangeOps;
impl Bmi2RangeOps {
    /// Count 1-bits of `word` in positions [start, start + len), clamped to
    /// the word boundary. Returns 0 for empty or fully out-of-range requests.
    pub fn count_ones_range(word: u64, start: u32, len: u32) -> u32 {
        if len == 0 {
            return 0;
        }
        if start >= 64 {
            return 0;
        }
        let effective_len = len.min(64 - start);
        #[cfg(target_arch = "x86_64")]
        {
            let caps = Bmi2Capabilities::get();
            if caps.has_bmi2 {
                return unsafe { Self::count_ones_range_bmi2(word, start, effective_len) };
            }
        }
        let shifted = word >> start;
        // effective_len == 64 only when start == 0 and len >= 64.
        let mask = if effective_len >= 64 {
            u64::MAX
        } else {
            (1u64 << effective_len) - 1
        };
        Bmi2RankOps::popcount_u64(shifted & mask)
    }
    /// Shift + BZHI + POPCNT hardware path.
    ///
    /// # Safety
    /// Caller must have verified BMI2 and POPCNT support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "bmi2,popcnt")]
    unsafe fn count_ones_range_bmi2(word: u64, start: u32, len: u32) -> u32 {
        use std::arch::x86_64::*;
        let shifted = word >> start;
        let masked = _bzhi_u64(shifted, len);
        unsafe { Bmi2RankOps::popcount_hardware(masked) }
    }
    /// `count_ones_range` over several (start, len) pairs of the same word.
    pub fn count_ones_multi_range(word: u64, ranges: &[(u32, u32)]) -> Vec<u32> {
        ranges
            .iter()
            .map(|&(start, len)| Self::count_ones_range(word, start, len))
            .collect()
    }
    /// PEXT the same field mask out of each word.
    pub fn extract_rank_fields(words: &[u64], field_mask: u64) -> Vec<u64> {
        words
            .iter()
            .map(|&word| Bmi2BitOps::extract_bits(word, field_mask))
            .collect()
    }
}
/// Run-length helpers over a bit vector stored as `u64` words, where bit `i`
/// of the vector is bit `(i % 64)` of word `(i / 64)`.
pub struct Bmi2SequenceOps;

impl Bmi2SequenceOps {
    /// Length of the run of consecutive 1-bits starting at `bit_pos`
    /// (inclusive). Returns 0 if `bit_pos` is out of range or that bit is 0.
    #[inline]
    pub fn one_seq_len(bit_data: &[u64], bit_pos: usize) -> usize {
        if bit_pos >= bit_data.len() * 64 {
            return 0;
        }
        let word_idx = bit_pos / 64;
        let bit_offset = bit_pos % 64;
        if word_idx >= bit_data.len() {
            return 0;
        }
        let word = bit_data[word_idx];
        if (word >> bit_offset) & 1 == 0 {
            return 0;
        }
        // Ones from bit_offset upward within the first word.
        let remaining_word = word >> bit_offset;
        let trailing_ones = Self::count_trailing_ones(remaining_word);
        if trailing_ones < 64 - bit_offset {
            // Run ends inside the first word.
            return trailing_ones;
        }
        // Run reaches the word boundary: keep counting across later words.
        let mut total_ones = 64 - bit_offset;
        for next_word_idx in (word_idx + 1)..bit_data.len() {
            let ones_in_word = Self::count_trailing_ones(bit_data[next_word_idx]);
            total_ones += ones_in_word;
            if ones_in_word < 64 {
                break;
            }
        }
        total_ones
    }

    /// Length of the run of consecutive 0-bits starting at `bit_pos`
    /// (inclusive). Returns 0 if `bit_pos` is out of range or that bit is 1.
    #[inline]
    pub fn zero_seq_len(bit_data: &[u64], bit_pos: usize) -> usize {
        if bit_pos >= bit_data.len() * 64 {
            return 0;
        }
        let word_idx = bit_pos / 64;
        let bit_offset = bit_pos % 64;
        if word_idx >= bit_data.len() {
            return 0;
        }
        let word = bit_data[word_idx];
        if (word >> bit_offset) & 1 == 1 {
            return 0;
        }
        // A zero-run in `word` is a one-run in `!word`.
        let remaining_word = (!word) >> bit_offset;
        let trailing_zeros = Self::count_trailing_ones(remaining_word);
        if trailing_zeros < 64 - bit_offset {
            return trailing_zeros;
        }
        let mut total_zeros = 64 - bit_offset;
        for next_word_idx in (word_idx + 1)..bit_data.len() {
            let zeros_in_word = Self::count_trailing_ones(!bit_data[next_word_idx]);
            total_zeros += zeros_in_word;
            if zeros_in_word < 64 {
                break;
            }
        }
        total_zeros
    }

    /// Length of the run of consecutive 1-bits ending at `end_pos`
    /// (exclusive), scanning backwards. Returns 0 if `end_pos` is 0, out of
    /// range, or bit `end_pos - 1` is 0.
    #[inline]
    pub fn one_seq_revlen(bit_data: &[u64], end_pos: usize) -> usize {
        if end_pos == 0 || end_pos > bit_data.len() * 64 {
            return 0;
        }
        let bit_pos = end_pos - 1;
        let word_idx = bit_pos / 64;
        let bit_offset = bit_pos % 64;
        if word_idx >= bit_data.len() {
            return 0;
        }
        let word = bit_data[word_idx];
        if (word >> bit_offset) & 1 == 0 {
            return 0;
        }
        // Mask selecting the low (bit_offset + 1) bits.
        // BUGFIX: the previous `(1u64 << (bit_offset + 1)) - 1` overflowed
        // when bit_offset == 63 (end_pos a multiple of 64) — shift-overflow
        // panic in debug builds, wrong mask in release.
        let mask = if bit_offset == 63 {
            u64::MAX
        } else {
            (1u64 << (bit_offset + 1)) - 1
        };
        let masked_word = word & mask;
        let leading_ones = Self::count_leading_ones_in_mask(masked_word, bit_offset + 1);
        if leading_ones < bit_offset + 1 {
            // Run ends inside this word.
            return leading_ones;
        }
        // Run covers the whole prefix of this word: continue into earlier
        // words, counting their high-order ones.
        let mut total_ones = bit_offset + 1;
        for prev_word_idx in (0..word_idx).rev() {
            let ones_in_word = Self::count_leading_ones_in_mask(bit_data[prev_word_idx], 64);
            total_ones += ones_in_word;
            if ones_in_word < 64 {
                break;
            }
        }
        total_ones
    }

    /// Number of consecutive 1-bits starting at bit 0 of `word`.
    #[inline]
    fn count_trailing_ones(word: u64) -> usize {
        if word == u64::MAX {
            return 64;
        }
        let inverted = !word;
        if inverted == 0 {
            64
        } else {
            inverted.trailing_zeros() as usize
        }
    }

    /// Number of consecutive 1-bits at the top of the low `bit_count` bits of
    /// `word` (i.e. counting downward from bit `bit_count - 1`).
    #[inline]
    fn count_leading_ones_in_mask(word: u64, bit_count: usize) -> usize {
        if bit_count == 0 {
            return 0;
        }
        let bit_count = bit_count.min(64);
        // BUGFIX: the previous `(1u64 << bit_count) - 1` overflowed for
        // bit_count == 64 (both call sites can pass 64).
        let full = if bit_count == 64 {
            u64::MAX
        } else {
            (1u64 << bit_count) - 1
        };
        if word & full == full {
            return bit_count;
        }
        // Move the window's top bit to bit 63 so leading_zeros of the
        // inverted word counts the run of ones.
        let inverted = !word;
        let shift_amount = 64 - bit_count;
        let shifted = inverted << shift_amount;
        if shifted == 0 {
            bit_count
        } else {
            (shifted.leading_zeros() as usize).min(bit_count)
        }
    }
}
/// Rank/select over multi-word blocks.
pub struct Bmi2BlockOps;
impl Bmi2BlockOps {
    /// Rank (number of 1-bits strictly before each bit position) for a batch
    /// of positions. O(blocks) per position: each query rescans the prefix.
    pub fn rank_bulk(blocks: &[u64], positions: &[usize]) -> Vec<usize> {
        let mut results = Vec::with_capacity(positions.len());
        for &pos in positions {
            let mut rank = 0;
            let block_idx = pos / 64;
            let bit_offset = pos % 64;
            // Full words before the target word.
            for i in 0..block_idx.min(blocks.len()) {
                rank += Bmi2RankOps::popcount_u64(blocks[i]) as usize;
            }
            // Partial word: bits below bit_offset only.
            if block_idx < blocks.len() {
                rank += Bmi2RankOps::popcount_trail(blocks[block_idx], bit_offset as u32) as usize;
            }
            results.push(rank);
        }
        results
    }
    /// Global select for a batch of 0-based indices, using a cumulative
    /// popcount prefix table + binary search per query.
    pub fn select_bulk(blocks: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
        let mut results = Vec::with_capacity(indices.len());
        // cumulative_counts[i] = ones in blocks[..i]; len == blocks.len() + 1.
        let mut cumulative_counts = Vec::with_capacity(blocks.len() + 1);
        cumulative_counts.push(0);
        for &block in blocks {
            let last_count = cumulative_counts.last().expect("cumulative_counts non-empty");
            cumulative_counts.push(last_count + Bmi2RankOps::popcount_u64(block) as usize);
        }
        for &k in indices {
            // Find the block whose cumulative count first reaches k + 1.
            let block_idx = match cumulative_counts.binary_search(&(k + 1)) {
                Ok(idx) => {
                    if idx == 0 {
                        return Err(ZiporaError::invalid_data(format!(
                            "Select index {} out of bounds - no ones before first block",
                            k
                        )));
                    }
                    idx - 1
                }
                Err(idx) => {
                    if idx == 0 {
                        return Err(ZiporaError::invalid_data(format!(
                            "Select index {} out of bounds - before first block",
                            k
                        )));
                    }
                    idx - 1
                }
            };
            if block_idx >= blocks.len() {
                return Err(ZiporaError::invalid_data(format!(
                    "Select index {} out of bounds",
                    k
                )));
            }
            let ones_before_block = cumulative_counts[block_idx];
            let k_in_block = k - ones_before_block;
            if let Some(bit_pos) = Bmi2SelectOps::select1_u64(blocks[block_idx], k_in_block as u32)
            {
                results.push(block_idx * 64 + bit_pos as usize);
            } else {
                return Err(ZiporaError::invalid_data(format!(
                    "Select failed for index {}",
                    k
                )));
            }
        }
        Ok(results)
    }
    /// (popcount, leading_zeros) per block. x86_64-only: this function does
    /// not exist on other targets.
    #[cfg(target_arch = "x86_64")]
    pub fn process_blocks_simd(blocks: &[u64]) -> Vec<(u32, u32)> {
        let caps = Bmi2Capabilities::get();
        if caps.simd_caps.cpu_features.has_avx2 && blocks.len() >= 4 {
            unsafe { Self::process_blocks_avx2_bmi2(blocks) }
        } else {
            blocks
                .iter()
                .map(|&block| {
                    (
                        Bmi2RankOps::popcount_u64(block),
                        Bmi2RankOps::leading_zeros(block),
                    )
                })
                .collect()
        }
    }
    /// Hardware POPCNT/LZCNT per block (scalar despite the avx2 gate; the
    /// chunking has no vector effect but keeps the access pattern regular).
    ///
    /// # Safety
    /// Caller must have verified AVX2, POPCNT and LZCNT support at runtime.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,popcnt,lzcnt")]
    unsafe fn process_blocks_avx2_bmi2(blocks: &[u64]) -> Vec<(u32, u32)> {
        use std::arch::x86_64::*;
        let mut results = Vec::with_capacity(blocks.len());
        for chunk in blocks.chunks(4) {
            for &block in chunk {
                let popcount = unsafe { Bmi2RankOps::popcount_hardware(block) };
                // lzcnt(0) is well-defined as 64, but keep the explicit
                // branch for clarity / parity with the software path.
                let lzcnt = if block != 0 {
                    _lzcnt_u64(block) as u32
                } else {
                    64
                };
                results.push((popcount, lzcnt));
            }
        }
        results
    }
}
/// Common interface for batch bit operations that may be BMI2-accelerated.
pub trait Bmi2OperationTrait {
    /// Run the operation over `input`, one result per input word.
    fn execute_bmi2(&self, input: &[u64]) -> Vec<u64>;
    /// Human-readable operation name.
    fn operation_name(&self) -> &'static str;
    /// Whether the current CPU lets this operation use BMI2.
    fn uses_bmi2_acceleration(&self) -> bool {
        Bmi2Capabilities::get().has_bmi2
    }
    /// Rough speedup factor versus the software fallback.
    fn estimated_speedup(&self) -> f64;
}
/// Dispatches individual operations to hardware or software paths based on
/// capabilities detected at construction time (not the cached static).
pub struct Bmi2Dispatcher {
    // Capabilities snapshot taken in `new()`.
    capabilities: Bmi2Capabilities,
}
impl Bmi2Dispatcher {
    /// Build a dispatcher with a fresh capability probe.
    pub fn new() -> Self {
        Self {
            capabilities: Bmi2Capabilities::detect(),
        }
    }
    /// Popcount via the hardware-dispatching path when available.
    /// Note: `Bmi2RankOps::popcount_u64` re-checks capabilities internally,
    /// so both branches are behaviorally equivalent.
    pub fn dispatch_popcount(&self, word: u64) -> u32 {
        if self.capabilities.has_bmi1 {
            Bmi2RankOps::popcount_u64(word)
        } else {
            word.count_ones()
        }
    }
    /// Select the k-th set bit; PDEP+CTZ pattern when BMI2 is present.
    pub fn dispatch_select(&self, word: u64, k: u32) -> Option<u32> {
        if self.capabilities.has_bmi2 {
            Bmi2AdvancedPatterns::pdep_ctz_select(word, k)
        } else {
            Bmi2SelectOps::select1_u64(word, k)
        }
    }
    /// Extract `length` bits at `start`; BEXTR-backed when BMI2 is present.
    pub fn dispatch_extract_bits(&self, src: u64, start: u32, length: u32) -> u64 {
        if self.capabilities.has_bmi2 {
            Bmi2BextrOps::extract_bits_bextr(src, start, length)
        } else {
            let shifted = src >> start;
            if length >= 64 {
                shifted
            } else {
                shifted & ((1u64 << length) - 1)
            }
        }
    }
    /// Clear bits >= `bit_count` of `word` (BZHI when available).
    pub fn dispatch_mask_operation(&self, word: u64, bit_count: u32) -> u64 {
        if self.capabilities.has_bmi2 {
            #[cfg(target_arch = "x86_64")]
            {
                if bit_count >= 64 {
                    return word;
                }
                // SAFETY: has_bmi2 verified above.
                return unsafe { std::arch::x86_64::_bzhi_u64(word, bit_count) };
            }
        }
        if bit_count >= 64 {
            word
        } else {
            word & ((1u64 << bit_count) - 1)
        }
    }
    /// Summary of which tiers/operations/speedups apply on this CPU.
    pub fn optimization_report(&self) -> Bmi2OptimizationReport {
        Bmi2OptimizationReport {
            has_bmi1: self.capabilities.has_bmi1,
            has_bmi2: self.capabilities.has_bmi2,
            has_popcnt: self.capabilities.simd_caps.cpu_features.has_popcnt,
            optimization_tier: if self.capabilities.has_bmi2 {
                "Tier 3: Full BMI2 Acceleration"
            } else if self.capabilities.has_bmi1 {
                "Tier 2: BMI1 Acceleration"
            } else {
                "Tier 1: Software Fallback"
            },
            available_operations: self.get_available_operations(),
            estimated_speedups: self.get_estimated_speedups(),
        }
    }
    /// Operation names unlocked by the detected feature tiers.
    fn get_available_operations(&self) -> Vec<&'static str> {
        let mut ops = vec!["Basic bit operations"];
        if self.capabilities.has_bmi1 {
            ops.extend_from_slice(&[
                "Hardware POPCNT",
                "LZCNT/TZCNT",
                "BLSR/BLSI/BLSMSK",
                "BEXTR",
            ]);
        }
        if self.capabilities.has_bmi2 {
            ops.extend_from_slice(&[
                "PDEP/PEXT",
                "BZHI",
                "PDEP+CTZ select",
                "Advanced patterns",
            ]);
        }
        ops
    }
    /// Rough per-operation speedup estimates (heuristic constants).
    fn get_estimated_speedups(&self) -> std::collections::HashMap<&'static str, f64> {
        let mut speedups = std::collections::HashMap::new();
        if self.capabilities.has_bmi1 {
            speedups.insert("POPCNT", 2.0);
            speedups.insert("LZCNT/TZCNT", 1.5);
            speedups.insert("BEXTR", 2.5);
        }
        if self.capabilities.has_bmi2 {
            speedups.insert("Select (PDEP+CTZ)", 8.0);
            speedups.insert("Bit extraction (PEXT)", 3.0);
            speedups.insert("Masking (BZHI)", 2.0);
            speedups.insert("Complex patterns", 4.0);
        }
        speedups
    }
}
impl Default for Bmi2Dispatcher {
    /// Equivalent to `Bmi2Dispatcher::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Snapshot of available BMI acceleration, produced by
/// `Bmi2Dispatcher::optimization_report`.
#[derive(Debug, Clone)]
pub struct Bmi2OptimizationReport {
    /// BMI1-class availability (see `Bmi2Capabilities` caveat on detection).
    pub has_bmi1: bool,
    /// BMI2 availability.
    pub has_bmi2: bool,
    /// POPCNT availability.
    pub has_popcnt: bool,
    /// Human-readable tier label.
    pub optimization_tier: &'static str,
    /// Names of operations unlocked on this CPU.
    pub available_operations: Vec<&'static str>,
    /// Heuristic speedup factors per operation.
    pub estimated_speedups: std::collections::HashMap<&'static str, f64>,
}
/// Hash-table helpers built on the BEXTR/select primitives.
pub struct Bmi2HashOps;
impl Bmi2HashOps {
    /// Low `bucket_bits` bits of `hash` as a bucket index.
    #[inline]
    pub fn hash_bucket_extract(hash: u64, bucket_bits: u32) -> u32 {
        Bmi2BextrOps::extract_hash_bucket(hash, bucket_bits)
    }
    /// First free slot (0-bit) in a 64-slot occupancy mask, or `None` when
    /// all slots are occupied. `_hash` is currently unused; it is kept in the
    /// signature for interface stability.
    pub fn collision_resolve_pattern(_hash: u64, occupied_mask: u64) -> Option<u32> {
        if occupied_mask == u64::MAX {
            return None;
        }
        // select1(k = 0) on the inverted mask is the lowest free slot.
        // (Previously wrapped in a redundant `if let Some(..) { Some(..) }`.)
        Bmi2SelectOps::select1_u64(!occupied_mask, 0)
    }
    /// Bucket indices for a batch of hashes.
    pub fn hash_buckets_bulk(hashes: &[u64], bucket_bits: u32) -> Vec<u32> {
        Bmi2BextrOps::extract_hash_buckets_bulk(hashes, bucket_bits)
    }
}
/// Bit-field helpers for entropy-coding style packed data.
pub struct Bmi2CompressionOps;
impl Bmi2CompressionOps {
    /// Extract `field_length` bits at `start_bit` from a packed bit stream.
    pub fn extract_entropy_field(bit_stream: u64, start_bit: u32, field_length: u32) -> u32 {
        Bmi2BextrOps::extract_bits_bextr(bit_stream, start_bit, field_length) as u32
    }
    /// PEXT each symbol mask out of the packed word.
    pub fn decode_variable_length(packed_data: u64, symbol_masks: &[u64]) -> Vec<u32> {
        symbol_masks.iter()
            .map(|&mask| Bmi2BitOps::extract_bits(packed_data, mask) as u32)
            .collect()
    }
    /// OR together each value deposited into its field mask.
    /// `zip` stops at the shorter slice, so extra values or masks are
    /// ignored (the previous explicit index check inside the loop was dead
    /// code for the same reason and has been removed).
    pub fn pack_compression_fields(values: &[u32], field_masks: &[u64]) -> u64 {
        values
            .iter()
            .zip(field_masks.iter())
            .fold(0u64, |packed, (&value, &mask)| {
                packed | Bmi2BitOps::deposit_bits(value as u64, mask)
            })
    }
}
/// String/byte-group helpers built on BEXTR and PEXT.
pub struct Bmi2StringOps;
impl Bmi2StringOps {
    /// Extract `char_len` bytes starting at byte index `char_start` from an
    /// 8-byte group, returned in the low bits of the result.
    pub fn extract_utf8_char(byte_data: u64, char_start: u32, char_len: u32) -> u32 {
        Bmi2BextrOps::extract_bits_bextr(byte_data, char_start * 8, char_len * 8) as u32
    }
    /// PEXT the pattern mask out of the text word.
    pub fn pattern_match_extract(text_data: u64, pattern_mask: u64) -> u64 {
        Bmi2BitOps::extract_bits(text_data, pattern_mask)
    }
    /// Extract one character per stream, pairing stream i with position i.
    /// `zip` stops at the shorter input, matching the original bounds check.
    pub fn extract_chars_bulk(byte_streams: &[u64], char_positions: &[(u32, u32)]) -> Vec<u32> {
        byte_streams
            .iter()
            .zip(char_positions.iter())
            .map(|(&stream, &(start, len))| Self::extract_utf8_char(stream, start, len))
            .collect()
    }
}
/// Memory-layout helpers built on BEXTR/PDEP.
pub struct Bmi2MemoryOps;
impl Bmi2MemoryOps {
    /// Extract `field_size` bits at `field_offset` from a packed word.
    pub fn extract_cache_aligned_field(cache_line: u64, field_offset: u32, field_size: u32) -> u64 {
        Bmi2BextrOps::extract_bits_bextr(cache_line, field_offset, field_size)
    }
    /// Base address plus an offset formed by depositing `offset_pattern`
    /// into `layout_mask`. Note: plain `+`, so an overflowing sum panics in
    /// debug builds.
    pub fn calculate_address_pattern(base_addr: u64, offset_pattern: u64, layout_mask: u64) -> u64 {
        base_addr + Bmi2BitOps::deposit_bits(offset_pattern, layout_mask)
    }
    /// Extract the same bit field from each memory word.
    pub fn extract_memory_fields_bulk(
        memory_words: &[u64],
        field_offset: u32,
        field_size: u32
    ) -> Vec<u64> {
        Bmi2BextrOps::extract_bits_multi(memory_words, field_offset, field_size)
    }
}
/// High-level facade bundling capabilities plus a dispatcher for all BMI2
/// operations in this module.
pub struct Bmi2Accelerator {
    // Capability snapshot taken at construction.
    capabilities: Bmi2Capabilities,
    // Per-operation dispatcher (holds its own capability snapshot).
    dispatcher: Bmi2Dispatcher,
}
impl Bmi2Accelerator {
    /// Build an accelerator; probes CPU features twice (once directly, once
    /// inside the dispatcher).
    pub fn new() -> Self {
        Self {
            capabilities: Bmi2Capabilities::detect(),
            dispatcher: Bmi2Dispatcher::new(),
        }
    }
    /// The capability snapshot taken at construction.
    pub fn capabilities(&self) -> &Bmi2Capabilities {
        &self.capabilities
    }
    /// True when full BMI2 acceleration is available.
    pub fn is_available(&self) -> bool {
        self.capabilities.has_bmi2
    }
    /// Rank within a single word: ones among the low `pos` bits.
    #[inline]
    pub fn rank1(&self, word: u64, pos: u32) -> u32 {
        Bmi2RankOps::popcount_trail(word, pos)
    }
    /// k-th set bit of a single word.
    #[inline]
    pub fn select1(&self, word: u64, k: u32) -> Option<u32> {
        Bmi2SelectOps::select1_u64(word, k)
    }
    /// Batch rank over a multi-word bit vector.
    pub fn rank_bulk(&self, blocks: &[u64], positions: &[usize]) -> Vec<usize> {
        Bmi2BlockOps::rank_bulk(blocks, positions)
    }
    /// Batch select over a multi-word bit vector.
    pub fn select_bulk(&self, blocks: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
        Bmi2BlockOps::select_bulk(blocks, indices)
    }
    /// Contiguous bit-field extraction via the dispatcher.
    pub fn extract_bits(&self, src: u64, start: u32, length: u32) -> u64 {
        self.dispatcher.dispatch_extract_bits(src, start, length)
    }
    /// Select via the dispatcher's PDEP+CTZ path.
    pub fn select1_enhanced(&self, word: u64, k: u32) -> Option<u32> {
        self.dispatcher.dispatch_select(word, k)
    }
    /// Hash-bucket extraction shortcut.
    pub fn hash_bucket_extract(&self, hash: u64, bucket_bits: u32) -> u32 {
        Bmi2HashOps::hash_bucket_extract(hash, bucket_bits)
    }
    /// Entropy-field extraction shortcut.
    pub fn extract_entropy_field(&self, bit_stream: u64, start_bit: u32, field_length: u32) -> u32 {
        Bmi2CompressionOps::extract_entropy_field(bit_stream, start_bit, field_length)
    }
    /// Packed-character extraction shortcut.
    pub fn extract_utf8_char(&self, byte_data: u64, char_start: u32, char_len: u32) -> u32 {
        Bmi2StringOps::extract_utf8_char(byte_data, char_start, char_len)
    }
    /// Packed-field extraction shortcut.
    pub fn extract_cache_aligned_field(&self, cache_line: u64, field_offset: u32, field_size: u32) -> u64 {
        Bmi2MemoryOps::extract_cache_aligned_field(cache_line, field_offset, field_size)
    }
    /// Full optimization report from the dispatcher.
    pub fn optimization_report(&self) -> Bmi2OptimizationReport {
        self.dispatcher.optimization_report()
    }
    /// Compact numeric stats (heuristic speedup constants).
    /// Note: `has_lzcnt` mirrors `has_bmi1` — see the detection caveat on
    /// `Bmi2Capabilities`.
    pub fn stats(&self) -> Bmi2Stats {
        Bmi2Stats {
            has_bmi1: self.capabilities.has_bmi1,
            has_bmi2: self.capabilities.has_bmi2,
            has_popcnt: self.capabilities.simd_caps.cpu_features.has_popcnt,
            has_lzcnt: self.capabilities.has_bmi1, optimization_tier: if self.capabilities.has_bmi2 {
                3
            } else if self.capabilities.has_bmi1 {
                2
            } else {
                1
            },
            estimated_speedup_rank: if self.capabilities.has_bmi2 {
                3.5
            } else if self.capabilities.has_bmi1 {
                2.0
            } else {
                1.0
            },
            estimated_speedup_select: if self.capabilities.has_bmi2 { 8.0 } else { 1.0 },
        }
    }
}
impl Default for Bmi2Accelerator {
    /// Equivalent to [`Bmi2Accelerator::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Marker type selecting the population-count operation for
/// [`Bmi2OperationTrait`] dispatch.
pub struct PopcountOperation;
/// Marker type selecting the select (k-th set bit) operation for
/// [`Bmi2OperationTrait`] dispatch.
pub struct SelectOperation;
/// Marker type selecting the bit-field extraction operation for
/// [`Bmi2OperationTrait`] dispatch.
pub struct BitExtractionOperation;
impl Bmi2OperationTrait for PopcountOperation {
    /// Computes the population count of each input word, widening the
    /// `u32` result to `u64`.
    fn execute_bmi2(&self, input: &[u64]) -> Vec<u64> {
        let mut counts = Vec::with_capacity(input.len());
        for &word in input {
            counts.push(u64::from(Bmi2RankOps::popcount_u64(word)));
        }
        counts
    }

    fn operation_name(&self) -> &'static str {
        "Population Count"
    }

    /// Heuristic factor only; not a measured benchmark result.
    fn estimated_speedup(&self) -> f64 {
        match self.uses_bmi2_acceleration() {
            true => 2.0,
            false => 1.0,
        }
    }
}
impl Bmi2OperationTrait for SelectOperation {
    /// Finds the position of the first set bit in each word via the
    /// PDEP+CTZ select pattern; words with no set bit map to
    /// `u32::MAX as u64`.
    fn execute_bmi2(&self, input: &[u64]) -> Vec<u64> {
        let mut positions = Vec::with_capacity(input.len());
        for &word in input {
            let pos = match Bmi2AdvancedPatterns::pdep_ctz_select(word, 0) {
                Some(p) => p,
                None => u32::MAX,
            };
            positions.push(pos as u64);
        }
        positions
    }

    fn operation_name(&self) -> &'static str {
        "Select Operation"
    }

    /// Heuristic factor only; not a measured benchmark result.
    fn estimated_speedup(&self) -> f64 {
        match self.uses_bmi2_acceleration() {
            true => 8.0,
            false => 1.0,
        }
    }
}
impl Bmi2OperationTrait for BitExtractionOperation {
    /// Extracts the low 32 bits of every input word via the BEXTR helper.
    fn execute_bmi2(&self, input: &[u64]) -> Vec<u64> {
        let mut fields = Vec::with_capacity(input.len());
        for &word in input {
            fields.push(Bmi2BextrOps::extract_bits_bextr(word, 0, 32));
        }
        fields
    }

    fn operation_name(&self) -> &'static str {
        "Bit Extraction"
    }

    /// Heuristic factor only; not a measured benchmark result.
    fn estimated_speedup(&self) -> f64 {
        if !self.uses_bmi2_acceleration() {
            return 1.0;
        }
        2.5
    }
}
/// Snapshot of detected CPU features and heuristic speedup estimates,
/// produced by `Bmi2Accelerator::stats`.
#[derive(Debug, Clone)]
pub struct Bmi2Stats {
    /// BMI1 instruction set available.
    pub has_bmi1: bool,
    /// BMI2 instruction set available.
    pub has_bmi2: bool,
    /// Hardware POPCNT available.
    pub has_popcnt: bool,
    /// LZCNT availability — currently mirrored from the BMI1 flag in
    /// `stats()`; there is no dedicated LZCNT probe.
    pub has_lzcnt: bool,
    /// 1 = scalar fallback, 2 = BMI1, 3 = BMI2.
    pub optimization_tier: u8,
    /// Estimated rank-operation speedup factor vs. scalar (heuristic).
    pub estimated_speedup_rank: f64,
    /// Estimated select-operation speedup factor vs. scalar (heuristic).
    pub estimated_speedup_select: f64,
}
#[cfg(test)]
mod tests {
    // Fix: removed `use std::collections::HashMap;` — it was never used
    // anywhere in this module and triggers an unused-import warning.
    use super::*;

    // Feature detection: the cached singleton must agree with a fresh probe.
    #[test]
    fn test_bmi2_capabilities() {
        let caps = Bmi2Capabilities::detect();
        println!("BMI1: {}, BMI2: {}", caps.has_bmi1, caps.has_bmi2);
        let caps2 = Bmi2Capabilities::get();
        assert_eq!(caps.has_bmi1, caps2.has_bmi1);
        assert_eq!(caps.has_bmi2, caps2.has_bmi2);
    }

    // popcount_u64 must match the scalar `count_ones` reference.
    #[test]
    fn test_popcount_operations() {
        let test_words = vec![
            0x0000000000000000u64,
            0xFFFFFFFFFFFFFFFFu64,
            0xAAAAAAAAAAAAAAAAu64,
            0x5555555555555555u64,
            0x0000000000000001u64,
            0x8000000000000000u64,
        ];
        for &word in &test_words {
            let expected = word.count_ones();
            let actual = Bmi2RankOps::popcount_u64(word);
            assert_eq!(
                actual, expected,
                "Popcount mismatch for word {:#018x}",
                word
            );
        }
    }

    // popcount_trail(word, n) counts 1-bits strictly below bit n.
    #[test]
    fn test_popcount_trail() {
        let word = 0xAAAAAAAAAAAAAAAAu64;
        for n in [0, 1, 4, 8, 16, 32, 63, 64] {
            let expected = if n >= 64 {
                word.count_ones()
            } else {
                let mask = (1u64 << n) - 1;
                (word & mask).count_ones()
            };
            let actual = Bmi2RankOps::popcount_trail(word, n);
            assert_eq!(actual, expected, "Popcount trail mismatch for n={}", n);
        }
    }

    #[test]
    fn test_leading_trailing_zeros() {
        let test_cases = vec![
            (0x0000000000000001u64, 63, 0),
            (0x8000000000000000u64, 0, 63),
            (0x0000000000000100u64, 55, 8),
            (0x0000000000000000u64, 64, 64),
            (0xFFFFFFFFFFFFFFFFu64, 0, 0),
        ];
        for (word, expected_lz, expected_tz) in test_cases {
            // Zero is special-cased here so the helpers are only queried
            // for non-zero words.
            let actual_lz = if word == 0 {
                64
            } else {
                Bmi2RankOps::leading_zeros(word)
            };
            let actual_tz = if word == 0 {
                64
            } else {
                Bmi2RankOps::trailing_zeros(word)
            };
            assert_eq!(
                actual_lz, expected_lz,
                "Leading zeros mismatch for {:#018x}",
                word
            );
            assert_eq!(
                actual_tz, expected_tz,
                "Trailing zeros mismatch for {:#018x}",
                word
            );
        }
    }

    // select1_u64 returns the position of the k-th (0-based) set bit.
    #[test]
    fn test_select_operations() {
        let word = 0x0000000000000015u64; // set bits at 0, 2, 4
        assert_eq!(Bmi2SelectOps::select1_u64(word, 0), Some(0));
        assert_eq!(Bmi2SelectOps::select1_u64(word, 1), Some(2));
        assert_eq!(Bmi2SelectOps::select1_u64(word, 2), Some(4));
        assert_eq!(Bmi2SelectOps::select1_u64(word, 3), None);
        assert_eq!(Bmi2SelectOps::select1_u64(0, 0), None);
    }

    // select0_u64 returns the position of the k-th (0-based) zero bit.
    #[test]
    fn test_select0_operations() {
        let word = 0xFFFFFFFFFFFFFFEAu64; // zero bits at 0, 2, 4
        assert_eq!(Bmi2SelectOps::select0_u64(word, 0), Some(0));
        assert_eq!(Bmi2SelectOps::select0_u64(word, 1), Some(2));
        assert_eq!(Bmi2SelectOps::select0_u64(word, 2), Some(4));
    }

    // Smoke test for the extract/deposit and lowest-bit helpers
    // (results are printed, not asserted).
    #[test]
    fn test_bit_manipulation() {
        let src = 0b11110000u64;
        let mask = 0b10101010u64;
        let extracted = Bmi2BitOps::extract_bits(src, mask);
        let deposited = Bmi2BitOps::deposit_bits(0b1111u64, mask);
        println!("PEXT: {:#010x} -> {:#010x}", src, extracted);
        println!("PDEP: 0b1111 -> {:#010x}", deposited);
        let word = 0b10101000u64;
        let reset = Bmi2BitOps::reset_lowest_bit(word);
        let isolate = Bmi2BitOps::isolate_lowest_bit(word);
        let mask_to_lowest = Bmi2BitOps::mask_up_to_lowest_bit(word);
        println!("Original: {:#010x}", word);
        println!("Reset lowest: {:#010x}", reset);
        println!("Isolate lowest: {:#010x}", isolate);
        println!("Mask to lowest: {:#010x}", mask_to_lowest);
    }

    // count_ones_range counts 1-bits in [start, start + len), clamped to 64.
    #[test]
    fn test_range_operations() {
        let word = 0xAAAAAAAAAAAAAAAAu64;
        let count_0_8 = Bmi2RangeOps::count_ones_range(word, 0, 8);
        let count_8_8 = Bmi2RangeOps::count_ones_range(word, 8, 8);
        let count_60_8 = Bmi2RangeOps::count_ones_range(word, 60, 8);
        assert_eq!(count_0_8, 4);
        assert_eq!(count_8_8, 4);
        // Range extends past bit 63, so only bits 60..64 contribute.
        assert_eq!(count_60_8, 2);
        let ranges = vec![(0, 8), (8, 8), (16, 8), (24, 8)];
        let counts = Bmi2RangeOps::count_ones_multi_range(word, &ranges);
        assert_eq!(counts, vec![4, 4, 4, 4]);
    }

    #[test]
    fn test_bulk_operations() {
        let blocks = vec![
            0x0000000000000001u64,
            0x0000000000000003u64,
            0x0000000000000007u64,
            0x000000000000000Fu64,
        ];
        // Word boundaries: cumulative popcounts are 0, 1, 3, 6, 10.
        let positions = vec![0, 64, 128, 192, 256];
        let ranks = Bmi2BlockOps::rank_bulk(&blocks, &positions);
        assert_eq!(ranks, vec![0, 1, 3, 6, 10]);
        let indices = vec![0, 2, 5, 9];
        let selects = Bmi2BlockOps::select_bulk(&blocks, &indices).unwrap();
        println!("Bulk select results: {:?}", selects);
        // NOTE(review): this is a loose sanity bound, not an exact
        // rank/select round-trip check.
        for (i, &select_result) in selects.iter().enumerate() {
            let rank_at_result = if i < ranks.len() - 1 {
                Bmi2BlockOps::rank_bulk(&blocks, &[select_result])[0]
            } else {
                0
            };
            assert!(
                rank_at_result <= indices[i] + 1,
                "Select verification failed"
            );
        }
    }

    #[test]
    fn test_accelerator_interface() {
        let accel = Bmi2Accelerator::new();
        let stats = accel.stats();
        println!("BMI2 Accelerator stats: {:?}", stats);
        let word = 0x0000000000000055u64;
        let rank = accel.rank1(word, 8);
        let expected_rank = (word & ((1u64 << 8) - 1)).count_ones();
        assert_eq!(rank, expected_rank);
        if word.count_ones() > 0 {
            let select = accel.select1(word, 0);
            assert!(select.is_some());
            let pos = select.unwrap();
            assert!(
                (word >> pos) & 1 == 1,
                "Selected position should have a 1 bit"
            );
        }
    }

    // Bulk popcount must agree element-wise with the scalar path.
    #[test]
    fn test_popcount_bulk() {
        let words = vec![
            0x0000000000000000u64,
            0xFFFFFFFFFFFFFFFFu64,
            0xAAAAAAAAAAAAAAAAu64,
            0x5555555555555555u64,
            0x0F0F0F0F0F0F0F0Fu64,
        ];
        let bulk_result = Bmi2RankOps::popcount_bulk(&words);
        let individual_result: Vec<u32> = words
            .iter()
            .map(|&w| Bmi2RankOps::popcount_u64(w))
            .collect();
        assert_eq!(bulk_result, individual_result);
        assert_eq!(bulk_result, vec![0, 64, 32, 32, 32]);
    }

    // extract_bits_bextr(word, start, len) yields (word >> start) masked
    // to len bits, with out-of-range start/len saturating safely.
    #[test]
    fn test_bextr_operations() {
        let word = 0xABCDEF1234567890u64;
        let lower_8 = Bmi2BextrOps::extract_bits_bextr(word, 0, 8);
        assert_eq!(lower_8, 0x90);
        let bits_8_15 = Bmi2BextrOps::extract_bits_bextr(word, 8, 8);
        assert_eq!(bits_8_15, 0x78);
        let upper_16 = Bmi2BextrOps::extract_bits_bextr(word, 48, 16);
        assert_eq!(upper_16, 0xABCD);
        assert_eq!(Bmi2BextrOps::extract_bits_bextr(word, 0, 0), 0);
        assert_eq!(Bmi2BextrOps::extract_bits_bextr(word, 64, 8), 0);
        assert_eq!(Bmi2BextrOps::extract_bits_bextr(word, 60, 8), word >> 60);
    }

    #[test]
    fn test_bextr_multi_extraction() {
        let sources = vec![
            0x1111111111111111u64,
            0x2222222222222222u64,
            0x3333333333333333u64,
            0x4444444444444444u64,
        ];
        let results = Bmi2BextrOps::extract_bits_multi(&sources, 4, 4);
        assert_eq!(results, vec![0x1, 0x2, 0x3, 0x4]);
    }

    // Bucket extraction keeps the low `bucket_bits` bits of the hash.
    #[test]
    fn test_hash_operations() {
        let hash = 0xABCDEF1234567890u64;
        let bucket_8 = Bmi2HashOps::hash_bucket_extract(hash, 8);
        assert_eq!(bucket_8, 0x90);
        let bucket_4 = Bmi2HashOps::hash_bucket_extract(hash, 4);
        assert_eq!(bucket_4, 0x0);
        let hashes = vec![0x1, 0x2, 0x3, 0x4];
        let buckets = Bmi2HashOps::hash_buckets_bulk(&hashes, 2);
        assert_eq!(buckets, vec![1, 2, 3, 0]);
    }

    #[test]
    fn test_collision_resolution() {
        let occupied_mask = 0b11110000u64;
        let first_free = Bmi2HashOps::collision_resolve_pattern(0, occupied_mask);
        assert_eq!(first_free, Some(0));
        // A fully-occupied table has no free slot.
        let all_occupied = Bmi2HashOps::collision_resolve_pattern(0, u64::MAX);
        assert_eq!(all_occupied, None);
    }

    // popcount_bzhi_enhanced counts bits below n, clamping n to 64.
    #[test]
    fn test_bzhi_enhanced_operations() {
        let word = 0xFFFFFFFFFFFFFFFFu64;
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(word, 0), 0);
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(word, 32), 32);
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(word, 64), 64);
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(word, 100), 64);
        let alt_word = 0xAAAAAAAAAAAAAAAAu64;
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(alt_word, 8), 4);
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(alt_word, 16), 8);
    }

    // bzhi_bulk_process zeroes all bits at and above the given index.
    #[test]
    fn test_bzhi_bulk_operations() {
        let words = vec![
            0xFFFFFFFFFFFFFFFFu64,
            0xAAAAAAAAAAAAAAAAu64,
            0x5555555555555555u64,
            0x0F0F0F0F0F0F0F0Fu64,
        ];
        let results = Bmi2BzhiOps::bzhi_bulk_process(&words, 32);
        assert_eq!(results[0], 0xFFFFFFFFu64);
        assert_eq!(results[1], 0xAAAAAAAAu64);
        assert_eq!(results[2], 0x55555555u64);
        assert_eq!(results[3], 0x0F0F0F0Fu64);
    }

    // The PDEP+CTZ select pattern must agree with select1_u64 semantics.
    #[test]
    fn test_pdep_ctz_select() {
        let word = 0x0000000000000015u64; // set bits at 0, 2, 4
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(word, 0), Some(0));
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(word, 1), Some(2));
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(word, 2), Some(4));
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(word, 3), None);
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(0, 0), None);
    }

    #[test]
    fn test_pdep_ctz_bulk_select() {
        let word = 0x000000000000001Fu64; // low five bits set
        let indices = vec![0, 1, 2, 3, 4];
        let results = Bmi2AdvancedPatterns::pdep_ctz_select_bulk(word, &indices).unwrap();
        assert_eq!(results, vec![0, 1, 2, 3, 4]);
        // Index 5 is out of range (only five set bits) and must error.
        let invalid_indices = vec![0, 5];
        let error_result = Bmi2AdvancedPatterns::pdep_ctz_select_bulk(word, &invalid_indices);
        assert!(error_result.is_err());
    }

    // Parallel PEXT must agree element-wise with the scalar extract.
    #[test]
    fn test_pext_parallel_extract() {
        let sources = vec![
            0b11110000u64,
            0b10101010u64,
            0b11001100u64,
            0b11111111u64,
        ];
        let mask = 0b10101010u64;
        let results = Bmi2AdvancedPatterns::pext_parallel_extract(&sources, mask);
        for (i, result) in results.iter().enumerate() {
            let expected = Bmi2BitOps::extract_bits(sources[i], mask);
            assert_eq!(*result, expected);
        }
    }

    // PDEP followed by a bit-limit must equal deposit + low-bit mask.
    #[test]
    fn test_pdep_bzhi_composite() {
        let src = 0b1111u64;
        let mask = 0b10101010u64;
        let bit_limit = 8;
        let result = Bmi2AdvancedPatterns::pdep_bzhi_composite(src, mask, bit_limit);
        let expected_deposited = Bmi2BitOps::deposit_bits(src, mask);
        let expected = expected_deposited & ((1u64 << bit_limit) - 1);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_compression_operations() {
        let bit_stream = 0xABCDEF12u64;
        let field1 = Bmi2CompressionOps::extract_entropy_field(bit_stream, 0, 8);
        assert_eq!(field1, 0x12);
        let field2 = Bmi2CompressionOps::extract_entropy_field(bit_stream, 8, 8);
        assert_eq!(field2, 0xEF);
        // One mask per symbol; decoding walks the masks in order.
        let symbol_masks = vec![0x0000000Fu64, 0x000000F0u64, 0x00000F00u64];
        let symbols = Bmi2CompressionOps::decode_variable_length(0x123u64, &symbol_masks);
        assert_eq!(symbols[0], 0x3);
        assert_eq!(symbols[1], 0x2);
        assert_eq!(symbols[2], 0x1);
    }

    #[test]
    fn test_string_operations() {
        let byte_data = 0x41424344u64; // "ABCD" packed little-endian
        let char1 = Bmi2StringOps::extract_utf8_char(byte_data, 0, 1);
        assert_eq!(char1, 0x44);
        let char2 = Bmi2StringOps::extract_utf8_char(byte_data, 1, 1);
        assert_eq!(char2, 0x43);
        let pattern_mask = 0xFF000000u64;
        let pattern_result = Bmi2StringOps::pattern_match_extract(byte_data, pattern_mask);
        assert_eq!(pattern_result, 0x41);
    }

    #[test]
    fn test_memory_operations() {
        let cache_line = 0xABCDEF1234567890u64;
        let field1 = Bmi2MemoryOps::extract_cache_aligned_field(cache_line, 0, 16);
        assert_eq!(field1, 0x7890);
        let field2 = Bmi2MemoryOps::extract_cache_aligned_field(cache_line, 16, 16);
        assert_eq!(field2, 0x3456);
        // Address pattern = base + PDEP(offset, layout mask).
        let base_addr = 0x10000000u64;
        let offset_pattern = 0x123u64;
        let layout_mask = 0xFFFu64;
        let addr = Bmi2MemoryOps::calculate_address_pattern(base_addr, offset_pattern, layout_mask);
        assert_eq!(addr, base_addr + Bmi2BitOps::deposit_bits(offset_pattern, layout_mask));
    }

    #[test]
    fn test_dispatcher_functionality() {
        let dispatcher = Bmi2Dispatcher::new();
        let word = 0xAAAAAAAAAAAAAAAAu64;
        let popcount = dispatcher.dispatch_popcount(word);
        assert_eq!(popcount, 32);
        // First set bit of 0xAAAA... is bit 1.
        let select = dispatcher.dispatch_select(word, 0);
        assert_eq!(select, Some(1));
        let extracted = dispatcher.dispatch_extract_bits(word, 4, 4);
        assert_eq!(extracted, 0xA);
        let masked = dispatcher.dispatch_mask_operation(word, 32);
        assert_eq!(masked, 0xAAAAAAAAu64);
    }

    #[test]
    fn test_optimization_report() {
        let dispatcher = Bmi2Dispatcher::new();
        let report = dispatcher.optimization_report();
        assert!(report.available_operations.contains(&"Basic bit operations"));
        assert!(report.optimization_tier.starts_with("Tier"));
        println!("BMI2 Optimization Report: {:?}", report);
    }

    #[test]
    fn test_operation_trait_implementations() {
        let popcount_op = PopcountOperation;
        let select_op = SelectOperation;
        let extract_op = BitExtractionOperation;
        let test_data = vec![0xAAAAAAAAAAAAAAAAu64, 0x5555555555555555u64];
        let popcount_results = popcount_op.execute_bmi2(&test_data);
        assert_eq!(popcount_results, vec![32, 32]);
        assert_eq!(popcount_op.operation_name(), "Population Count");
        let select_results = select_op.execute_bmi2(&test_data);
        assert_eq!(select_results[0], 1);
        assert_eq!(select_results[1], 0);
        assert_eq!(select_op.operation_name(), "Select Operation");
        let extract_results = extract_op.execute_bmi2(&test_data);
        assert_eq!(extract_results, vec![0xAAAAAAAAu64, 0x55555555u64]);
        assert_eq!(extract_op.operation_name(), "Bit Extraction");
    }

    #[test]
    fn test_accelerator_enhanced_interface() {
        let accel = Bmi2Accelerator::new();
        let word = 0xABCDEF1234567890u64;
        let extracted = accel.extract_bits(word, 16, 16);
        assert_eq!(extracted, 0x3456);
        let select_word = 0x0000000000000015u64;
        let select_result = accel.select1_enhanced(select_word, 1);
        assert_eq!(select_result, Some(2));
        let hash_bucket = accel.hash_bucket_extract(0xABCDEF12u64, 8);
        assert_eq!(hash_bucket, 0x12);
        let entropy_field = accel.extract_entropy_field(0xABCDu64, 4, 8);
        assert_eq!(entropy_field, 0xBC);
        let utf8_char = accel.extract_utf8_char(0x41424344u64, 0, 1);
        assert_eq!(utf8_char, 0x44);
        let cache_field = accel.extract_cache_aligned_field(0xABCDEF12u64, 8, 8);
        assert_eq!(cache_field, 0xEF);
        let report = accel.optimization_report();
        println!("Enhanced Accelerator Report: {:?}", report);
    }

    // Boundary coverage: all-zero, all-one, and single-bit words.
    #[test]
    fn test_edge_cases() {
        assert_eq!(Bmi2RankOps::popcount_u64(0), 0);
        assert_eq!(Bmi2RankOps::popcount_trail(0, 32), 0);
        assert_eq!(Bmi2SelectOps::select1_u64(0, 0), None);
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(0, 0), None);
        let all_ones = u64::MAX;
        assert_eq!(Bmi2RankOps::popcount_u64(all_ones), 64);
        assert_eq!(Bmi2RankOps::popcount_trail(all_ones, 32), 32);
        assert_eq!(Bmi2SelectOps::select1_u64(all_ones, 0), Some(0));
        assert_eq!(Bmi2SelectOps::select1_u64(all_ones, 63), Some(63));
        assert_eq!(Bmi2SelectOps::select1_u64(all_ones, 64), None);
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(all_ones, 0), Some(0));
        assert_eq!(Bmi2AdvancedPatterns::pdep_ctz_select(all_ones, 63), Some(63));
        for bit_pos in [0, 1, 31, 32, 63] {
            let single_bit = 1u64 << bit_pos;
            assert_eq!(Bmi2RankOps::popcount_u64(single_bit), 1);
            assert_eq!(
                Bmi2SelectOps::select1_u64(single_bit, 0),
                Some(bit_pos as u32)
            );
            assert_eq!(Bmi2SelectOps::select1_u64(single_bit, 1), None);
            assert_eq!(
                Bmi2AdvancedPatterns::pdep_ctz_select(single_bit, 0),
                Some(bit_pos as u32)
            );
        }
        let word = 0xABCDEF1234567890u64;
        assert_eq!(Bmi2BextrOps::extract_bits_bextr(word, 0, 0), 0);
        assert_eq!(Bmi2BextrOps::extract_bits_bextr(word, 64, 8), 0);
        assert_eq!(Bmi2BextrOps::extract_bits_bextr(word, 60, 8), word >> 60);
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(word, 0), 0);
        assert_eq!(Bmi2BzhiOps::popcount_bzhi_enhanced(word, 100), word.count_ones());
    }

    // Timing smoke test only — prints durations, asserts nothing.
    #[test]
    fn test_performance_comparison() {
        let test_words = (0..1000).map(|i| (i as u64).wrapping_mul(0x123456789ABCDEFu64)).collect::<Vec<_>>();
        let start = std::time::Instant::now();
        let _popcount_results: Vec<u32> = test_words.iter()
            .map(|&w| Bmi2RankOps::popcount_u64(w))
            .collect();
        let popcount_time = start.elapsed();
        let start = std::time::Instant::now();
        let _select_results: Vec<Option<u32>> = test_words.iter()
            .filter(|&&w| w != 0)
            .map(|&w| Bmi2AdvancedPatterns::pdep_ctz_select(w, 0))
            .collect();
        let select_time = start.elapsed();
        let start = std::time::Instant::now();
        let _extract_results: Vec<u64> = test_words.iter()
            .map(|&w| Bmi2BextrOps::extract_bits_bextr(w, 8, 16))
            .collect();
        let extract_time = start.elapsed();
        println!("Performance test results:");
        println!("  Popcount: {:?}", popcount_time);
        println!("  Select: {:?}", select_time);
        println!("  Extract: {:?}", extract_time);
    }

    // Out-of-range starts and zero/over-length ranges must clamp safely.
    #[test]
    fn test_range_edge_cases() {
        let word = 0xFFFFFFFFFFFFFFFFu64;
        assert_eq!(Bmi2RangeOps::count_ones_range(word, 10, 0), 0);
        assert_eq!(Bmi2RangeOps::count_ones_range(word, 64, 10), 0);
        assert_eq!(Bmi2RangeOps::count_ones_range(word, 100, 10), 0);
        assert_eq!(Bmi2RangeOps::count_ones_range(word, 60, 10), 4);
        assert_eq!(Bmi2RangeOps::count_ones_range(word, 0, 64), 64);
        assert_eq!(Bmi2RangeOps::count_ones_range(word, 0, 100), 64);
    }
}