use crate::error::{Result, ZiporaError};
use std::sync::atomic::{AtomicU64, Ordering};
/// Runtime-detected CPU bit-manipulation features plus derived tuning knobs.
///
/// Obtain a process-wide cached instance via `Bmi2Capabilities::get()`.
#[derive(Debug, Clone)]
pub struct Bmi2Capabilities {
/// CPU reports BMI1 support (TZCNT, ANDN, ...).
pub has_bmi1: bool,
/// CPU reports BMI2 support (PDEP, PEXT, BZHI).
pub has_bmi2: bool,
/// CPU reports POPCNT support.
pub has_popcnt: bool,
/// CPU reports AVX2 support (used only to pick the optimization tier here).
pub has_avx2: bool,
/// Derived tier, 0 (portable only) through 4 (all features present).
pub optimization_tier: u8,
/// Suggested batch size for bulk operations, derived from the tier (64..=1024).
pub chunk_size: usize,
}
impl Bmi2Capabilities {
    /// Probes the running CPU and derives the optimization tier and the
    /// bulk-operation chunk size.
    ///
    /// Tier mapping: 4 = POPCNT+BMI1+BMI2+AVX2, 3 = POPCNT+BMI1+BMI2,
    /// 2 = POPCNT+BMI1, 1 = POPCNT only, 0 = anything else (including the
    /// unrealistic BMI2-without-BMI1 combination).
    pub fn detect() -> Self {
        let has_bmi1 = Self::detect_bmi1();
        let has_bmi2 = Self::detect_bmi2();
        let has_popcnt = Self::detect_popcnt();
        let has_avx2 = Self::detect_avx2();
        let optimization_tier: u8 = if has_popcnt && has_bmi1 && has_bmi2 {
            if has_avx2 { 4 } else { 3 }
        } else if has_popcnt && has_bmi1 {
            2
        } else if has_popcnt && !has_bmi2 {
            1
        } else {
            0
        };
        // The chunk size doubles per tier: 64, 128, 256, 512, 1024.
        let chunk_size = 64usize << optimization_tier;
        Self {
            has_bmi1,
            has_bmi2,
            has_popcnt,
            has_avx2,
            optimization_tier,
            chunk_size,
        }
    }
    /// Returns the process-wide capability record, detecting lazily on first use.
    #[inline]
    pub fn get() -> &'static Self {
        static CACHED: std::sync::OnceLock<Bmi2Capabilities> = std::sync::OnceLock::new();
        CACHED.get_or_init(Self::detect)
    }
    /// Runtime BMI1 check; always `false` off x86_64 where the feature-detect
    /// macro does not exist.
    fn detect_bmi1() -> bool {
        #[cfg(target_arch = "x86_64")]
        {
            is_x86_feature_detected!("bmi1")
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    }
    /// Runtime BMI2 check; always `false` off x86_64.
    fn detect_bmi2() -> bool {
        #[cfg(target_arch = "x86_64")]
        {
            is_x86_feature_detected!("bmi2")
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    }
    /// Runtime POPCNT check; always `false` off x86_64.
    fn detect_popcnt() -> bool {
        #[cfg(target_arch = "x86_64")]
        {
            is_x86_feature_detected!("popcnt")
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    }
    /// Runtime AVX2 check; always `false` off x86_64.
    fn detect_avx2() -> bool {
        #[cfg(target_arch = "x86_64")]
        {
            is_x86_feature_detected!("avx2")
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    }
}
/// Namespace for single-word rank/select and bit-manipulation primitives with
/// BMI1/BMI2 fast paths and portable fallbacks.
pub struct Bmi2BitOps;
impl Bmi2BitOps {
    /// Finds the bit position of the `rank`-th set bit (1-based) in `word`.
    ///
    /// Uses the BMI2 PDEP+TZCNT sequence when the CPU supports it, otherwise
    /// delegates to [`Self::select1_fallback`]. Returns `None` when `rank` is
    /// 0, when `rank` exceeds the number of set bits, or when `word` is 0.
    #[cfg(target_arch = "x86_64")]
    pub fn select1_ultra_fast(word: u64, rank: usize) -> Option<usize> {
        // A u64 has at most 64 set bits, so rank > 64 can never be satisfied;
        // the explicit bound also prevents `1u64 << (rank - 1)` below from
        // overflowing the shift (which panics in debug builds).
        if rank == 0 || rank > 64 || word == 0 {
            return None;
        }
        let caps = Bmi2Capabilities::get();
        if !caps.has_bmi2 {
            return Self::select1_fallback(word, rank);
        }
        // SAFETY: the BMI2 intrinsics run only after the runtime capability
        // check above confirmed PDEP/TZCNT support on this CPU.
        unsafe {
            // PDEP deposits the single bit of `mask` at the position of the
            // rank-th set bit of `word`; TZCNT then reads that position off.
            let mask = 1u64 << (rank - 1);
            let deposited = std::arch::x86_64::_pdep_u64(mask, word);
            if deposited != 0 {
                Some(std::arch::x86_64::_tzcnt_u64(deposited) as usize)
            } else {
                None
            }
        }
    }
    /// Portable `select1` for non-x86_64 targets, mirroring the x86_64 API so
    /// callers do not need per-architecture `cfg` guards.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn select1_ultra_fast(word: u64, rank: usize) -> Option<usize> {
        Self::select1_fallback(word, rank)
    }
    /// Portable select: repeatedly clears the lowest set bit until the
    /// `rank`-th (1-based) one is reached. Returns `None` for rank 0 or when
    /// fewer than `rank` bits are set.
    pub fn select1_fallback(word: u64, rank: usize) -> Option<usize> {
        if rank == 0 || word == 0 {
            return None;
        }
        let mut count = 0;
        let mut current = word;
        while current != 0 {
            let pos = current.trailing_zeros() as usize;
            // Kernighan's trick: clear the lowest set bit.
            current &= current - 1;
            count += 1;
            if count == rank {
                return Some(pos);
            }
        }
        None
    }
    /// Counts the set bits of `word` strictly below bit index `pos` (rank over
    /// the half-open range `[0, pos)`); `pos >= 64` counts the whole word.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    pub fn rank1_optimized(word: u64, pos: usize) -> usize {
        if pos >= 64 {
            return word.count_ones() as usize;
        }
        let caps = Bmi2Capabilities::get();
        if caps.has_bmi2 {
            // SAFETY: BZHI/POPCNT are used only after the runtime checks.
            unsafe {
                // BZHI zeroes every bit at index >= pos in one instruction,
                // replacing the mask construction of the portable path.
                let masked = std::arch::x86_64::_bzhi_u64(word, pos as u32);
                if caps.has_popcnt {
                    std::arch::x86_64::_popcnt64(masked as i64) as usize
                } else {
                    masked.count_ones() as usize
                }
            }
        } else if caps.has_popcnt {
            let mask = if pos == 0 { 0 } else { (1u64 << pos) - 1 };
            let masked = word & mask;
            // SAFETY: POPCNT is used only after the runtime popcnt check.
            unsafe { std::arch::x86_64::_popcnt64(masked as i64) as usize }
        } else {
            let mask = if pos == 0 { 0 } else { (1u64 << pos) - 1 };
            (word & mask).count_ones() as usize
        }
    }
    /// Portable rank: mask off bits at index >= `pos`, then popcount.
    #[cfg(not(target_arch = "x86_64"))]
    #[inline]
    pub fn rank1_optimized(word: u64, pos: usize) -> usize {
        if pos >= 64 {
            return word.count_ones() as usize;
        }
        let mask = if pos == 0 { 0 } else { (1u64 << pos) - 1 };
        (word & mask).count_ones() as usize
    }
    /// Gathers the bits of `data` selected by `mask` into the low bits of the
    /// result (parallel bit extract). Hardware PEXT when available, otherwise
    /// the portable bit-by-bit loop.
    #[cfg(target_arch = "x86_64")]
    pub fn extract_bits_pext(data: u64, mask: u64) -> u64 {
        let caps = Bmi2Capabilities::get();
        if caps.has_bmi2 {
            // SAFETY: PEXT is used only after the runtime BMI2 check.
            unsafe { std::arch::x86_64::_pext_u64(data, mask) }
        } else {
            Self::extract_bits_software(data, mask)
        }
    }
    /// Portable parallel bit extract for non-x86_64 targets.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn extract_bits_pext(data: u64, mask: u64) -> u64 {
        Self::extract_bits_software(data, mask)
    }
    /// Software PEXT: walks `mask` from the low end, packing each selected
    /// data bit into successive low bits of the result. Shared by the x86_64
    /// no-BMI2 path and non-x86_64 builds (previously duplicated verbatim).
    fn extract_bits_software(data: u64, mask: u64) -> u64 {
        let mut result = 0u64;
        let mut result_pos = 0;
        let mut mask_copy = mask;
        let mut data_copy = data;
        while mask_copy != 0 {
            if mask_copy & 1 != 0 {
                result |= (data_copy & 1) << result_pos;
                result_pos += 1;
            }
            mask_copy >>= 1;
            data_copy >>= 1;
        }
        result
    }
    /// Count of trailing zero bits; returns 64 for `word == 0`.
    #[cfg(target_arch = "x86_64")]
    pub fn trailing_zeros_optimized(word: u64) -> u32 {
        if word == 0 {
            return 64;
        }
        let caps = Bmi2Capabilities::get();
        if caps.has_bmi1 {
            // SAFETY: word != 0 here, and on CPUs without true TZCNT the
            // encoding degrades to BSF, which agrees for nonzero input.
            unsafe { std::arch::x86_64::_tzcnt_u64(word) as u32 }
        } else {
            word.trailing_zeros()
        }
    }
    /// Portable trailing-zero count (`u64::trailing_zeros` returns 64 for 0).
    #[cfg(not(target_arch = "x86_64"))]
    pub fn trailing_zeros_optimized(word: u64) -> u32 {
        word.trailing_zeros()
    }
    /// Count of leading zero bits; returns 64 for `word == 0`.
    #[cfg(target_arch = "x86_64")]
    pub fn leading_zeros_optimized(word: u64) -> u32 {
        if word == 0 {
            return 64;
        }
        let caps = Bmi2Capabilities::get();
        if caps.has_bmi1 {
            // NOTE(review): LZCNT is advertised by the lzcnt/abm CPUID flag,
            // not BMI1; every known BMI1 CPU also has LZCNT, but verify the
            // gate if very old AMD parts must be supported.
            // SAFETY: guarded by the runtime capability check above.
            unsafe { std::arch::x86_64::_lzcnt_u64(word) as u32 }
        } else {
            word.leading_zeros()
        }
    }
    /// Portable leading-zero count (`u64::leading_zeros` returns 64 for 0).
    #[cfg(not(target_arch = "x86_64"))]
    pub fn leading_zeros_optimized(word: u64) -> u32 {
        word.leading_zeros()
    }
}
/// Namespace for chunked bulk rank/select operations over slices of words.
pub struct Bmi2BlockOps;
impl Bmi2BlockOps {
/// Computes a rank for each position in `positions` over the bit array `words`,
/// processing positions in capability-sized chunks with prefetching.
///
/// NOTE(review): the returned value is *word-local* — only the bits below
/// `pos % 64` within `words[pos / 64]` are counted, not the set bits of the
/// preceding words. Confirm whether callers expect a cumulative rank.
/// Out-of-range positions silently yield 0 rather than an error.
pub fn bulk_rank1(words: &[u64], positions: &[usize]) -> Vec<usize> {
let caps = Bmi2Capabilities::get();
// Cap the chunk size so the prefetch window stays cache-friendly.
let chunk_size = caps.chunk_size.min(256);
let mut results = Vec::with_capacity(positions.len());
for chunk in positions.chunks(chunk_size) {
// Warm the cache for every word this chunk will touch.
Self::prefetch_words(words, chunk);
for &pos in chunk {
let word_idx = pos / 64;
let bit_offset = pos % 64;
if word_idx < words.len() {
let rank = Bmi2BitOps::rank1_optimized(words[word_idx], bit_offset);
results.push(rank);
} else {
results.push(0);
}
}
}
results
}
/// Finds the global bit position of each 1-based `rank` across `words`.
///
/// Returns an error if any requested rank exceeds the total number of set
/// bits. NOTE(review): each rank rescans `words` from the start
/// (O(|words| * |ranks|)); acceptable for small inputs, worth revisiting if
/// either slice grows large.
pub fn bulk_select1(words: &[u64], ranks: &[usize]) -> Result<Vec<usize>> {
let caps = Bmi2Capabilities::get();
let chunk_size = caps.chunk_size.min(128);
let mut results = Vec::with_capacity(ranks.len());
for chunk in ranks.chunks(chunk_size) {
for &rank in chunk {
// Running count of set bits in the words already skipped.
let mut total_ones = 0;
let mut found = false;
for (word_idx, &word) in words.iter().enumerate() {
let word_ones = word.count_ones() as usize;
// The target bit lies in this word once the cumulative count
// reaches `rank`; select within the word using the local rank.
if total_ones + word_ones >= rank {
let local_rank = rank - total_ones;
#[cfg(target_arch = "x86_64")]
{
if let Some(bit_pos) = Bmi2BitOps::select1_ultra_fast(word, local_rank) {
results.push(word_idx * 64 + bit_pos);
found = true;
break;
}
}
#[cfg(not(target_arch = "x86_64"))]
{
if let Some(bit_pos) = Bmi2BitOps::select1_fallback(word, local_rank) {
results.push(word_idx * 64 + bit_pos);
found = true;
break;
}
}
}
total_ones += word_ones;
}
if !found {
return Err(ZiporaError::invalid_data(format!("Select rank {} not found", rank)));
}
}
}
Ok(results)
}
/// Issues T0 (all-cache-levels) prefetch hints for the words addressed by
/// `positions`. Purely a performance hint; no observable side effects.
#[cfg(target_arch = "x86_64")]
fn prefetch_words(words: &[u64], positions: &[usize]) {
for &pos in positions {
let word_idx = pos / 64;
if word_idx < words.len() {
// SAFETY: word_idx is bounds-checked above, so the pointer stays
// within the slice; prefetch itself cannot fault.
unsafe {
let ptr = words.as_ptr().add(word_idx);
std::arch::x86_64::_mm_prefetch(
ptr as *const i8,
std::arch::x86_64::_MM_HINT_T0,
);
}
}
}
}
/// No-op on architectures without an exposed prefetch intrinsic.
#[cfg(not(target_arch = "x86_64"))]
fn prefetch_words(_words: &[u64], _positions: &[usize]) {
}
}
/// Telemetry counters for hardware-accelerated vs. fallback operation usage.
#[derive(Debug)]
pub struct Bmi2Stats {
/// Total operations recorded (hardware + fallback).
pub total_operations: AtomicU64,
/// Operations served by a hardware-accelerated path.
pub hardware_accelerated: AtomicU64,
/// Operations served by a portable fallback path.
pub fallback_operations: AtomicU64,
/// Estimated cache hit rate; set at construction and never updated within
/// this module (presumably a tuning placeholder — TODO confirm).
pub cache_hit_rate: f64,
}
impl Clone for Bmi2Stats {
    /// Produces a point-in-time copy of the counters. Each field is loaded
    /// independently with relaxed ordering, so the copy is not an atomic
    /// snapshot across all three counters.
    fn clone(&self) -> Self {
        let snapshot = |counter: &AtomicU64| AtomicU64::new(counter.load(Ordering::Relaxed));
        Self {
            total_operations: snapshot(&self.total_operations),
            hardware_accelerated: snapshot(&self.hardware_accelerated),
            fallback_operations: snapshot(&self.fallback_operations),
            cache_hit_rate: self.cache_hit_rate,
        }
    }
}
impl Bmi2Stats {
    /// Creates a zeroed counter set. `cache_hit_rate` starts at 0.95
    /// (presumably an empirical estimate — TODO confirm its origin).
    pub fn new() -> Self {
        Self {
            total_operations: AtomicU64::new(0),
            hardware_accelerated: AtomicU64::new(0),
            fallback_operations: AtomicU64::new(0),
            cache_hit_rate: 0.95,
        }
    }
    /// Records one operation served by a hardware-accelerated path.
    pub fn record_hardware_operation(&self) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        self.hardware_accelerated.fetch_add(1, Ordering::Relaxed);
    }
    /// Records one operation served by a portable fallback path.
    pub fn record_fallback_operation(&self) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        self.fallback_operations.fetch_add(1, Ordering::Relaxed);
    }
    /// Fraction of recorded operations that used hardware acceleration;
    /// 0.0 when nothing has been recorded yet.
    pub fn hardware_acceleration_ratio(&self) -> f64 {
        let total = self.total_operations.load(Ordering::Relaxed);
        if total == 0 {
            return 0.0;
        }
        self.hardware_accelerated.load(Ordering::Relaxed) as f64 / total as f64
    }
    /// Zeroes all counters (relaxed stores; `cache_hit_rate` is untouched).
    pub fn reset(&self) {
        for counter in [
            &self.total_operations,
            &self.hardware_accelerated,
            &self.fallback_operations,
        ] {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
impl Default for Bmi2Stats {
fn default() -> Self {
Self::new()
}
}
/// Namespace for whole-sequence bit-pattern analysis and strategy selection.
pub struct Bmi2SequenceOps;
impl Bmi2SequenceOps {
    /// Scans `words` and summarizes its bit-pattern statistics, recommending a
    /// rank/select optimization strategy and a bulk chunk size.
    ///
    /// Returns [`SequenceAnalysis::default`] for empty input.
    pub fn analyze_bit_patterns(words: &[u64]) -> SequenceAnalysis {
        if words.is_empty() {
            return SequenceAnalysis::default();
        }
        let mut total_ones = 0;
        let mut sparse_words = 0;
        let mut dense_words = 0;
        let mut consecutive_patterns = 0;
        for (i, &word) in words.iter().enumerate() {
            let ones = word.count_ones() as usize;
            total_ones += ones;
            // Classification thresholds: < 8 set bits is "sparse", > 56 is
            // "dense"; words in between count as neither.
            if ones < 8 {
                sparse_words += 1;
            } else if ones > 56 {
                dense_words += 1;
            }
            if i > 0 && Self::has_consecutive_pattern(words[i - 1], word) {
                consecutive_patterns += 1;
            }
        }
        let total_bits = words.len() * 64;
        let density = total_ones as f64 / total_bits as f64;
        let sparsity_ratio = sparse_words as f64 / words.len() as f64;
        let density_ratio = dense_words as f64 / words.len() as f64;
        // Guard the single-word case: there are no adjacent pairs, and the
        // previous unconditional division computed 0.0 / 0.0 = NaN, which
        // silently poisoned the strategy comparisons downstream.
        let consecutive_ratio = if words.len() > 1 {
            consecutive_patterns as f64 / (words.len() - 1) as f64
        } else {
            0.0
        };
        SequenceAnalysis {
            total_words: words.len(),
            total_ones,
            density,
            sparsity_ratio,
            density_ratio,
            consecutive_ratio,
            recommended_strategy: Self::recommend_strategy(density, sparsity_ratio, consecutive_ratio),
            optimal_chunk_size: Self::recommend_chunk_size(words.len(), density),
        }
    }
    /// Heuristic: two adjacent words form a "consecutive pattern" when they
    /// differ in at most 8 bit positions (small Hamming distance).
    fn has_consecutive_pattern(word1: u64, word2: u64) -> bool {
        let diff = word1 ^ word2;
        diff.count_ones() <= 8
    }
    /// Maps the measured ratios onto a strategy hint. Arms are ordered from
    /// most to least specific; the extremes (very sparse, very dense with
    /// consecutive runs) are matched before the broad density bands.
    fn recommend_strategy(density: f64, sparsity_ratio: f64, consecutive_ratio: f64) -> OptimizationStrategy {
        match (density, sparsity_ratio, consecutive_ratio) {
            (d, s, _) if d < 0.1 && s > 0.7 => OptimizationStrategy::SparseLinear,
            (d, _, c) if d > 0.8 && c > 0.5 => OptimizationStrategy::DenseSequential,
            (d, _, _) if d < 0.3 => OptimizationStrategy::SparseBinary,
            (d, _, _) if d > 0.6 => OptimizationStrategy::DenseBinary,
            _ => OptimizationStrategy::Balanced,
        }
    }
    /// Scales the CPU-derived base chunk size by corpus size and density:
    /// small corpora get smaller chunks, dense large corpora get larger ones.
    fn recommend_chunk_size(total_words: usize, density: f64) -> usize {
        let caps = Bmi2Capabilities::get();
        let base_chunk = caps.chunk_size;
        match (total_words, density) {
            (n, d) if n < 100 && d < 0.1 => base_chunk / 4,
            (n, d) if n < 100 && d > 0.9 => base_chunk / 2,
            (_, d) if d < 0.1 => base_chunk,
            (_, d) if d > 0.9 => base_chunk * 2,
            _ => base_chunk,
        }
    }
}
/// Summary statistics produced by `Bmi2SequenceOps::analyze_bit_patterns`.
#[derive(Debug, Clone)]
pub struct SequenceAnalysis {
/// Number of 64-bit words analyzed.
pub total_words: usize,
/// Total set bits across all words.
pub total_ones: usize,
/// Fraction of set bits: total_ones / (total_words * 64).
pub density: f64,
/// Fraction of words with fewer than 8 set bits.
pub sparsity_ratio: f64,
/// Fraction of words with more than 56 set bits.
pub density_ratio: f64,
/// Fraction of adjacent word pairs with Hamming distance <= 8.
pub consecutive_ratio: f64,
/// Strategy hint derived from the ratios above.
pub recommended_strategy: OptimizationStrategy,
/// Suggested chunk size for bulk processing (scaled from
/// `Bmi2Capabilities::chunk_size`).
pub optimal_chunk_size: usize,
}
impl Default for SequenceAnalysis {
fn default() -> Self {
Self {
total_words: 0,
total_ones: 0,
density: 0.0,
sparsity_ratio: 0.0,
density_ratio: 0.0,
consecutive_ratio: 0.0,
recommended_strategy: OptimizationStrategy::Balanced,
optimal_chunk_size: 256,
}
}
}
/// Processing-strategy hint selected by `Bmi2SequenceOps::analyze_bit_patterns`
/// from the measured density/sparsity ratios.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationStrategy {
/// Chosen for very sparse data (density < 0.1 and sparsity ratio > 0.7).
SparseLinear,
/// Chosen for sparse data (density < 0.3).
SparseBinary,
/// Default when no other rule matches.
Balanced,
/// Chosen for dense data (density > 0.6).
DenseBinary,
/// Chosen for dense data with consecutive patterns (density > 0.8 and
/// consecutive ratio > 0.5).
DenseSequential,
}
#[cfg(test)]
mod tests {
use super::*;
// Capability detection must always yield a coherent tier/chunk pairing,
// whatever the host CPU reports.
#[test]
fn test_bmi2_capabilities_detection() {
let caps = Bmi2Capabilities::detect();
assert!(caps.optimization_tier <= 4);
assert!(caps.chunk_size >= 64);
assert!(caps.chunk_size <= 1024);
// The tier logic assumes BMI2-capable CPUs also report BMI1; this assert
// documents that hardware assumption.
if caps.has_bmi2 {
assert!(caps.has_bmi1, "BMI2 should imply BMI1");
}
if caps.optimization_tier >= 3 {
assert!(caps.has_bmi2, "Tier 3+ should have BMI2");
}
println!("BMI2 Capabilities: {:?}", caps);
}
// select1 takes a 1-based rank; rank1 counts bits strictly below `pos`.
#[test]
fn test_bmi2_bit_operations() {
let word = 0b1010101010101010u64;
// The first set bit of ...1010 is at position 1.
let select_result = Bmi2BitOps::select1_fallback(word, 1);
assert_eq!(select_result, Some(1));
// The fourth set bit is at position 7.
let select_result = Bmi2BitOps::select1_fallback(word, 4);
assert_eq!(select_result, Some(7));
// Bits 0..8 of the pattern hold 4 ones; bits 0..16 hold 8.
let rank_result = Bmi2BitOps::rank1_optimized(word, 8);
assert_eq!(rank_result, 4);
let rank_result = Bmi2BitOps::rank1_optimized(word, 16);
assert_eq!(rank_result, 8);
// Degenerate inputs: zero word, rank 0, and rank over an empty word.
assert_eq!(Bmi2BitOps::select1_fallback(0, 1), None);
assert_eq!(Bmi2BitOps::select1_fallback(word, 0), None);
assert_eq!(Bmi2BitOps::rank1_optimized(0, 32), 0);
}
// Bulk operations: only shape and range invariants are asserted here, not
// the individual rank/select values.
#[test]
fn test_bmi2_block_operations() {
let words = vec![
0b1111000011110000u64,
0b0000111100001111u64,
0b1010101010101010u64,
0b0101010101010101u64,
];
let positions = vec![4, 12, 20, 28];
let ranks = Bmi2BlockOps::bulk_rank1(&words, &positions);
assert_eq!(ranks.len(), positions.len());
let target_ranks = vec![1, 2, 4, 8];
let selects = Bmi2BlockOps::bulk_select1(&words, &target_ranks);
assert!(selects.is_ok());
let select_results = selects.unwrap();
assert_eq!(select_results.len(), target_ranks.len());
// Every selected position must fall inside the bit array.
assert!(select_results[0] < 64 * words.len());
}
#[test]
fn test_bmi2_sequence_analysis() {
// Four one-bit words: density 4/256, every word classified sparse.
let sparse_words = vec![
0b0000000000000001u64,
0b0000000000000010u64,
0b0000000000000100u64,
0b0000000000001000u64,
];
let analysis = Bmi2SequenceOps::analyze_bit_patterns(&sparse_words);
assert!(analysis.density < 0.1);
assert!(analysis.sparsity_ratio > 0.5);
assert_eq!(analysis.recommended_strategy, OptimizationStrategy::SparseLinear);
// Four 63-bit words: density 252/256, every word classified dense.
let dense_words = vec![
0xFFFFFFFFFFFFFFFEu64,
0xFFFFFFFFFFFFFFFDu64,
0xFFFFFFFFFFFFFFFBu64,
0xFFFFFFFFFFFFFFF7u64,
];
let analysis = Bmi2SequenceOps::analyze_bit_patterns(&dense_words);
assert!(analysis.density > 0.9);
assert!(analysis.density_ratio > 0.5);
// Empty input takes the SequenceAnalysis::default() early-return path.
let empty_words = vec![];
let analysis = Bmi2SequenceOps::analyze_bit_patterns(&empty_words);
assert_eq!(analysis.total_words, 0);
assert_eq!(analysis.density, 0.0);
}
// Counters are relaxed atomics; the ratio and reset must be consistent.
#[test]
fn test_bmi2_stats() {
let stats = Bmi2Stats::new();
assert_eq!(stats.hardware_acceleration_ratio(), 0.0);
stats.record_hardware_operation();
stats.record_hardware_operation();
stats.record_fallback_operation();
assert_eq!(stats.total_operations.load(Ordering::Relaxed), 3);
assert_eq!(stats.hardware_accelerated.load(Ordering::Relaxed), 2);
assert_eq!(stats.fallback_operations.load(Ordering::Relaxed), 1);
let ratio = stats.hardware_acceleration_ratio();
assert!((ratio - 2.0/3.0).abs() < 0.001);
stats.reset();
assert_eq!(stats.total_operations.load(Ordering::Relaxed), 0);
assert_eq!(stats.hardware_acceleration_ratio(), 0.0);
}
// tzcnt/lzcnt wrappers return 64 for zero input; PEXT gathers masked bits.
#[test]
fn test_bmi2_bit_manipulation() {
assert_eq!(Bmi2BitOps::trailing_zeros_optimized(0b1000), 3);
assert_eq!(Bmi2BitOps::trailing_zeros_optimized(0), 64);
assert_eq!(Bmi2BitOps::leading_zeros_optimized(0b1000), 60);
assert_eq!(Bmi2BitOps::leading_zeros_optimized(0), 64);
// Mask selects bits 2,3,6,7 of data (1,0,1,1) -> packed result 0b1101.
let data = 0b11010110u64;
let mask = 0b11001100u64;
let extracted = Bmi2BitOps::extract_bits_pext(data, mask);
assert_eq!(extracted, 0b1101);
}
}