use crate::error::{Result, ZiporaError};
use std::sync::OnceLock;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{
uint8x8_t, uint8x16_t, vaddvq_u8, vcntq_u8, vcombine_u8, vdup_n_u8, vget_high_u8, vget_low_u8,
vld1q_u8,
};
pub use crate::system::{CpuFeatures, get_cpu_features};
/// Process-wide cache of the detected SIMD capabilities (initialized on first use).
static SIMD_FEATURES: OnceLock<SimdCapabilities> = OnceLock::new();
/// Detected SIMD capabilities plus the tuning parameters derived from them.
#[derive(Debug, Clone)]
pub struct SimdCapabilities {
/// CPU feature flags detected at runtime (shared, `'static`).
pub cpu_features: &'static CpuFeatures,
/// Optimization tier: 0 = scalar fallback up to 5 = AVX-512
/// (see `determine_optimization_strategy` for the mapping).
pub optimization_tier: u8,
/// Preferred chunk size in bytes for bulk operations on this tier.
pub chunk_size: usize,
/// Whether software prefetch hints should be issued on this tier.
pub use_prefetch: bool,
}
impl SimdCapabilities {
/// Detect CPU features and derive the optimization strategy for this host.
/// Prefer [`SimdCapabilities::get`] for the cached, process-wide instance.
pub fn detect() -> Self {
let cpu_features = get_cpu_features();
let (optimization_tier, chunk_size, use_prefetch) =
Self::determine_optimization_strategy(cpu_features);
Self {
cpu_features,
optimization_tier,
chunk_size,
use_prefetch,
}
}
/// Map detected CPU features to `(tier, chunk_size_bytes, use_prefetch)`.
/// Tiers, best first: 5 = AVX-512 VPOPCNTDQ+BW (only in `avx512` feature
/// builds), 4 = AVX2, 3 = BMI2, 2 = POPCNT, 1 = NEON (aarch64 with the
/// `simd` feature), 0 = portable scalar fallback.
fn determine_optimization_strategy(features: &CpuFeatures) -> (u8, usize, bool) {
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
if features.has_avx512vpopcntdq && features.has_avx512bw {
return (5, 64 * 1024, true); }
#[cfg(target_arch = "x86_64")]
if features.has_avx2 {
return (4, 32 * 1024, true); }
#[cfg(target_arch = "x86_64")]
if features.has_bmi2 {
return (3, 16 * 1024, true); }
#[cfg(target_arch = "x86_64")]
if features.has_popcnt {
return (2, 8 * 1024, true); }
#[cfg(target_arch = "aarch64")]
if cfg!(feature = "simd") && std::arch::is_aarch64_feature_detected!("neon") {
return (1, 16 * 1024, false); }
// No SIMD path available: scalar fallback with a small chunk size.
(0, 4 * 1024, false) }
/// Cached capabilities for this process; detection runs at most once.
#[inline]
pub fn get() -> &'static SimdCapabilities {
SIMD_FEATURES.get_or_init(Self::detect)
}
}
/// Bulk SIMD rank/select operations over an implementor's raw bit data.
pub trait SimdOps {
/// Rank1 (number of set bits strictly before each position) for many
/// positions at once; delegates to [`bulk_rank1_simd`].
fn rank1_bulk_simd(&self, positions: &[usize]) -> Vec<usize> {
let bit_data = self.get_bit_data();
bulk_rank1_simd(bit_data, positions)
}
/// Select1 (position of the (i + 1)-th set bit) for many indices at once;
/// delegates to [`bulk_select1_simd`].
fn select1_bulk_simd(&self, indices: &[usize]) -> Result<Vec<usize>> {
let bit_data = self.get_bit_data();
bulk_select1_simd(bit_data, indices)
}
/// Expose the underlying bit data as 64-bit words. Bit `i` of word `w`
/// corresponds to overall bit position `w * 64 + i` (low bits first).
fn get_bit_data(&self) -> &[u64];
}
/// Bulk rank1: number of set bits strictly before each position in `positions`.
///
/// Dispatches once on the cached optimization tier. Every backend returns one
/// rank per input position, in order; positions past the end of `bit_data`
/// produce a rank of 0.
pub fn bulk_rank1_simd(bit_data: &[u64], positions: &[usize]) -> Vec<usize> {
if positions.is_empty() {
return Vec::new();
}
let capabilities = SimdCapabilities::get();
match capabilities.optimization_tier {
// Tier 5 only exists when built with the `avx512` feature.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
5 => bulk_rank1_avx512(bit_data, positions, capabilities.chunk_size),
#[cfg(target_arch = "x86_64")]
4 => bulk_rank1_avx2(bit_data, positions, capabilities.chunk_size),
#[cfg(target_arch = "x86_64")]
3 | 2 => bulk_rank1_popcnt(bit_data, positions, capabilities.use_prefetch),
#[cfg(target_arch = "aarch64")]
1 => bulk_rank1_neon(bit_data, positions),
// Portable scalar fallback (tier 0 or any arch without a SIMD path).
_ => bulk_rank1_scalar(bit_data, positions),
}
}
/// Bulk select1: position of the (i + 1)-th set bit for each index `i`.
///
/// Dispatches once on the cached optimization tier.
///
/// # Errors
/// Every backend returns an error when any index is >= the total number of
/// set bits in `bit_data`.
pub fn bulk_select1_simd(bit_data: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
if indices.is_empty() {
return Ok(Vec::new());
}
let capabilities = SimdCapabilities::get();
match capabilities.optimization_tier {
// PDEP-based select requires BMI2; the guard re-checks it for tiers 5/4/3.
#[cfg(target_arch = "x86_64")]
5 | 4 | 3 if capabilities.cpu_features.has_bmi2 => bulk_select1_bmi2(bit_data, indices),
#[cfg(target_arch = "x86_64")]
4 => bulk_select1_avx2(bit_data, indices),
#[cfg(target_arch = "x86_64")]
2 => bulk_select1_popcnt(bit_data, indices),
#[cfg(target_arch = "aarch64")]
1 => bulk_select1_neon(bit_data, indices),
// Portable scalar fallback.
_ => bulk_select1_scalar(bit_data, indices),
}
}
/// Bulk popcount: the number of set bits in every word of `bit_data`,
/// one result per word, dispatched on the cached optimization tier.
pub fn bulk_popcount_simd(bit_data: &[u64]) -> Vec<usize> {
if bit_data.is_empty() {
return Vec::new();
}
let capabilities = SimdCapabilities::get();
match capabilities.optimization_tier {
// Tier 5 only exists when built with the `avx512` feature.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
5 => bulk_popcount_avx512(bit_data),
#[cfg(target_arch = "x86_64")]
4 => bulk_popcount_avx2(bit_data),
#[cfg(target_arch = "x86_64")]
3 | 2 => bulk_popcount_popcnt(bit_data),
#[cfg(target_arch = "aarch64")]
1 => bulk_popcount_neon(bit_data),
// Portable scalar fallback.
_ => bulk_popcount_scalar(bit_data),
}
}
/// AVX-512 bulk rank1: sums popcounts of complete words eight at a time using
/// `_mm512_popcnt_epi64`, then adds the partial word masked to `bit_offset`.
///
/// Positions past the end of `bit_data` yield a rank of 0. The dispatcher
/// (tier 5) guarantees AVX-512 VPOPCNTDQ support at runtime.
///
/// The chunk-size parameter is currently unused (kept for signature parity
/// with the AVX2 path), hence the underscore prefix.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
fn bulk_rank1_avx512(bit_data: &[u64], positions: &[usize], _chunk_size: usize) -> Vec<usize> {
    let mut results = Vec::with_capacity(positions.len());
    for &pos in positions {
        let word_index = pos / 64;
        let bit_offset = pos % 64;
        if word_index >= bit_data.len() {
            // Out-of-range position: keep output aligned with input, rank 0.
            results.push(0);
            continue;
        }
        let mut rank = 0usize;
        let complete_words = word_index;
        if complete_words >= 8 {
            let chunks = complete_words / 8;
            let remainder = complete_words % 8;
            // SAFETY: each 8-word unaligned load stays inside `bit_data`
            // because chunks * 8 <= complete_words <= bit_data.len().
            unsafe {
                for chunk in 0..chunks {
                    let base_idx = chunk * 8;
                    let data_ptr = bit_data.as_ptr().add(base_idx);
                    let vec = _mm512_loadu_si512(data_ptr as *const __m512i);
                    let popcounts = _mm512_popcnt_epi64(vec);
                    let mut sum_array = [0u64; 8];
                    _mm512_storeu_si512(sum_array.as_mut_ptr() as *mut __m512i, popcounts);
                    rank += sum_array.iter().sum::<u64>() as usize;
                }
            }
            // Leftover complete words that did not fill an 8-word chunk.
            for i in chunks * 8..chunks * 8 + remainder {
                rank += bit_data[i].count_ones() as usize;
            }
        } else {
            for i in 0..complete_words {
                rank += bit_data[i].count_ones() as usize;
            }
        }
        // Partial word: count only the bits strictly below `bit_offset`.
        if bit_offset > 0 {
            let mask = (1u64 << bit_offset) - 1;
            rank += (bit_data[word_index] & mask).count_ones() as usize;
        }
        results.push(rank);
    }
    results
}
/// AVX2-tier bulk rank1: sums popcounts of complete words four at a time
/// (256-bit unaligned loads + scalar POPCNT per lane), then adds the partial
/// word masked to `bit_offset`. Positions past the end of `bit_data` yield 0.
/// The dispatcher (tier 4, `has_avx2`) guarantees AVX2 availability.
#[cfg(target_arch = "x86_64")]
fn bulk_rank1_avx2(bit_data: &[u64], positions: &[usize], _chunk_size: usize) -> Vec<usize> {
let mut results = Vec::with_capacity(positions.len());
for &pos in positions {
let word_index = pos / 64;
let bit_offset = pos % 64;
if word_index >= bit_data.len() {
// Out-of-range position: keep output aligned with input, rank 0.
results.push(0);
continue;
}
let mut rank = 0usize;
let complete_words = word_index;
if complete_words >= 4 {
let chunks = complete_words / 4;
let remainder = complete_words % 4;
// SAFETY: each 4-word load stays within `bit_data`
// (chunks * 4 <= complete_words <= bit_data.len()).
unsafe {
for chunk in 0..chunks {
let base_idx = chunk * 4;
if chunk + 1 < chunks {
// Hint the next 4-word group into cache ahead of use.
let next_ptr = bit_data.as_ptr().add(base_idx + 4);
_mm_prefetch(next_ptr as *const i8, _MM_HINT_T0);
}
let ptr = bit_data.as_ptr().add(base_idx);
let vec = _mm256_loadu_si256(ptr as *const __m256i);
// Reinterpret the 256-bit register as four u64 lanes.
let vals = std::mem::transmute::<__m256i, [u64; 4]>(vec);
for val in vals {
rank += _popcnt64(val as i64) as usize;
}
}
}
// Leftover complete words that did not fill a 4-word chunk.
for i in chunks * 4..chunks * 4 + remainder {
unsafe {
rank += _popcnt64(bit_data[i] as i64) as usize;
}
}
} else {
for i in 0..complete_words {
unsafe {
rank += _popcnt64(bit_data[i] as i64) as usize;
}
}
}
// Partial word: count only the bits strictly below `bit_offset`.
if bit_offset > 0 && word_index < bit_data.len() {
let mask = (1u64 << bit_offset) - 1;
let masked_word = bit_data[word_index] & mask;
unsafe {
rank += _popcnt64(masked_word as i64) as usize;
}
}
results.push(rank);
}
results
}
/// POPCNT-tier bulk rank1: a straight prefix sum of hardware popcounts over
/// the complete words, plus the partial word masked to `bit_offset`.
///
/// Positions past the end of `bit_data` yield a rank of 0. When
/// `use_prefetch` is set, a cache hint is issued every eighth word.
#[cfg(target_arch = "x86_64")]
fn bulk_rank1_popcnt(bit_data: &[u64], positions: &[usize], use_prefetch: bool) -> Vec<usize> {
    let mut ranks = Vec::with_capacity(positions.len());
    for &pos in positions {
        let word_index = pos / 64;
        let bit_offset = pos % 64;
        if word_index >= bit_data.len() {
            // Out-of-range position: keep output aligned with input, rank 0.
            ranks.push(0);
            continue;
        }
        let mut total = 0usize;
        for i in 0..word_index {
            if use_prefetch && i % 8 == 0 && i + 8 < word_index {
                // SAFETY: i + 8 < word_index <= bit_data.len(), so the hinted
                // address is in bounds; prefetch is only a hint.
                unsafe { _mm_prefetch(bit_data.as_ptr().add(i + 8) as *const i8, _MM_HINT_T0) };
            }
            // SAFETY: the dispatcher selects this path only on tiers that
            // detected hardware POPCNT support.
            total += unsafe { _popcnt64(bit_data[i] as i64) } as usize;
        }
        if bit_offset > 0 {
            // Count only the bits strictly below `bit_offset` in the final word.
            let masked = bit_data[word_index] & ((1u64 << bit_offset) - 1);
            total += unsafe { _popcnt64(masked as i64) } as usize;
        }
        ranks.push(total);
    }
    ranks
}
/// BMI2-tier bulk select1: a linear word scan locates the word containing the
/// (index + 1)-th set bit, then PDEP performs an O(1) in-word select.
///
/// For the k-th set bit of `word` (1-based), `_pdep_u64(1 << (k - 1), word)`
/// deposits a single 1 at that bit's position, so `trailing_zeros` recovers
/// its index directly.
///
/// Bug fix: the previous implementation deposited the mask `(1 << k) - 1` and
/// took `trailing_zeros` of the result, which always returns the position of
/// the FIRST set bit of `word` — wrong for every k > 1. It also ran a binary
/// search that invoked `bulk_rank1_simd` (an O(n) scan plus a Vec allocation)
/// per probe; the plain linear scan here is both simpler and cheaper.
///
/// # Errors
/// Returns an error if any index is >= the total number of set bits.
#[cfg(target_arch = "x86_64")]
fn bulk_select1_bmi2(bit_data: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
    let mut results = Vec::with_capacity(indices.len());
    let total_ones: usize = bit_data.iter().map(|w| w.count_ones() as usize).sum();
    for &index in indices {
        if index >= total_ones {
            return Err(ZiporaError::invalid_data(format!(
                "Select index {} exceeds available set bits {}",
                index, total_ones
            )));
        }
        let target_rank = index + 1; // ranks are 1-based
        let mut cumulative_rank = 0usize;
        for (word_idx, &word) in bit_data.iter().enumerate() {
            let word_popcount = word.count_ones() as usize;
            if cumulative_rank + word_popcount >= target_rank {
                let remaining_rank = target_rank - cumulative_rank;
                // SAFETY: the dispatcher guard (`has_bmi2`) guarantees PDEP
                // support. `remaining_rank` is in 1..=word_popcount, so the
                // shift is < 64 and the deposited bit always exists.
                let bit_pos = unsafe {
                    _pdep_u64(1u64 << (remaining_rank - 1), word).trailing_zeros() as usize
                };
                results.push(word_idx * 64 + bit_pos);
                break;
            }
            cumulative_rank += word_popcount;
        }
    }
    Ok(results)
}
/// AVX2-tier bulk select1. No AVX2-specific select kernel is implemented;
/// this currently delegates to the POPCNT linear-scan implementation.
#[cfg(target_arch = "x86_64")]
fn bulk_select1_avx2(bit_data: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
bulk_select1_popcnt(bit_data, indices)
}
/// POPCNT-tier bulk select1: for each index, linearly accumulates hardware
/// popcounts until the word containing the (index + 1)-th set bit is found,
/// then locates the bit inside that word by repeatedly clearing the lowest
/// set bit (`w &= w - 1`) — cheaper than the previous 64-iteration shift loop.
///
/// Also drops a needless `mut` on the target rank and avoids allocating a
/// whole popcount vector just to compute the total.
///
/// # Errors
/// Returns an error if any index is >= the total number of set bits.
#[cfg(target_arch = "x86_64")]
fn bulk_select1_popcnt(bit_data: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
    let mut results = Vec::with_capacity(indices.len());
    // SAFETY (applies to every `_popcnt64` below): the dispatcher selects this
    // path only on tiers that detected hardware POPCNT support.
    let total_ones: usize = bit_data
        .iter()
        .map(|&w| unsafe { _popcnt64(w as i64) } as usize)
        .sum();
    for &index in indices {
        if index >= total_ones {
            return Err(ZiporaError::invalid_data(format!(
                "Select index {} exceeds available set bits {}",
                index, total_ones
            )));
        }
        let target_rank = index + 1; // ranks are 1-based
        let mut current_rank = 0usize;
        for (word_idx, &word) in bit_data.iter().enumerate() {
            let word_popcount = unsafe { _popcnt64(word as i64) } as usize;
            if current_rank + word_popcount >= target_rank {
                // The target bit lives in this word; walk its set bits.
                let mut remaining = target_rank - current_rank;
                let mut w = word;
                loop {
                    let bit_pos = w.trailing_zeros() as usize;
                    remaining -= 1;
                    if remaining == 0 {
                        results.push(word_idx * 64 + bit_pos);
                        break;
                    }
                    w &= w - 1; // clear the lowest set bit
                }
                break;
            }
            current_rank += word_popcount;
        }
    }
    Ok(results)
}
/// AVX-512 bulk popcount: emits the popcount of every word in `bit_data`,
/// eight words per 512-bit vector via `_mm512_popcnt_epi64`; leftover words
/// use scalar POPCNT. Only compiled with the `avx512` feature; the dispatcher
/// (tier 5) guarantees AVX-512 VPOPCNTDQ support at runtime.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
fn bulk_popcount_avx512(bit_data: &[u64]) -> Vec<usize> {
let mut results = Vec::with_capacity(bit_data.len());
let chunks = bit_data.len() / 8;
let remainder = bit_data.len() % 8;
// SAFETY: every 8-word unaligned load below is in bounds (chunks * 8 <= len).
unsafe {
for chunk in 0..chunks {
let base_idx = chunk * 8;
let data_ptr = bit_data.as_ptr().add(base_idx);
let vec = _mm512_loadu_si512(data_ptr as *const __m512i);
let popcounts = _mm512_popcnt_epi64(vec);
let mut result_array = [0u64; 8];
_mm512_storeu_si512(result_array.as_mut_ptr() as *mut __m512i, popcounts);
for count in result_array {
results.push(count as usize);
}
}
// Trailing words that did not fill an 8-word chunk.
for i in chunks * 8..chunks * 8 + remainder {
results.push(_popcnt64(bit_data[i] as i64) as usize);
}
}
results
}
/// AVX2-tier bulk popcount: loads four words per 256-bit vector, then emits a
/// scalar POPCNT per lane; leftover words use POPCNT directly. The dispatcher
/// (tier 4) guarantees AVX2 availability.
#[cfg(target_arch = "x86_64")]
fn bulk_popcount_avx2(bit_data: &[u64]) -> Vec<usize> {
let mut results = Vec::with_capacity(bit_data.len());
let chunks = bit_data.len() / 4;
let remainder = bit_data.len() % 4;
// SAFETY: every 4-word unaligned load below is in bounds (chunks * 4 <= len).
unsafe {
for chunk in 0..chunks {
let base_idx = chunk * 4;
let ptr = bit_data.as_ptr().add(base_idx);
let vec = _mm256_loadu_si256(ptr as *const __m256i);
// Reinterpret the 256-bit register as four u64 lanes.
let vals = std::mem::transmute::<__m256i, [u64; 4]>(vec);
for val in vals {
results.push(_popcnt64(val as i64) as usize);
}
}
// Trailing words that did not fill a 4-word chunk.
for i in chunks * 4..chunks * 4 + remainder {
results.push(_popcnt64(bit_data[i] as i64) as usize);
}
}
results
}
/// POPCNT-tier bulk popcount: one hardware POPCNT per word, one result per word.
#[cfg(target_arch = "x86_64")]
fn bulk_popcount_popcnt(bit_data: &[u64]) -> Vec<usize> {
    let mut counts = Vec::with_capacity(bit_data.len());
    for &word in bit_data {
        // SAFETY: the dispatcher routes here only on tiers that detected
        // hardware POPCNT support.
        counts.push(unsafe { _popcnt64(word as i64) } as usize);
    }
    counts
}
/// NEON bulk rank1: counts set bits of complete words two at a time with
/// `vcntq_u8`/`vaddvq_u8`, then adds the partial word masked to `bit_offset`.
///
/// Positions past the end of `bit_data` yield a rank of 0. The dispatcher
/// (tier 1) only selects this path when NEON was detected.
///
/// Fix: removed a redundant nested `unsafe` block around `vld1q_u8` that was
/// already inside an `unsafe` block (triggered the `unused_unsafe` lint).
#[cfg(target_arch = "aarch64")]
fn bulk_rank1_neon(bit_data: &[u64], positions: &[usize]) -> Vec<usize> {
    let mut results = Vec::with_capacity(positions.len());
    for &pos in positions {
        let word_index = pos / 64;
        let bit_offset = pos % 64;
        if word_index >= bit_data.len() {
            // Out-of-range position: keep output aligned with input, rank 0.
            results.push(0);
            continue;
        }
        let mut rank = 0usize;
        let chunks = word_index / 2;
        let remainder = word_index % 2;
        // SAFETY: each 16-byte load reads exactly two u64 words at
        // base_idx < chunks * 2 <= word_index <= bit_data.len(), in bounds.
        unsafe {
            for chunk in 0..chunks {
                let base_idx = chunk * 2;
                let ptr = bit_data.as_ptr().add(base_idx) as *const u8;
                let vec = vld1q_u8(ptr);
                let popcount_vec = vcntq_u8(vec);
                // Sum of the 16 per-byte counts equals the popcount of both
                // words (max 128, fits the u8 horizontal add).
                rank += vaddvq_u8(popcount_vec) as usize;
            }
        }
        if remainder > 0 {
            // Odd trailing complete word, counted with the portable popcount.
            rank += bit_data[chunks * 2].count_ones() as usize;
        }
        // Partial word: count only the bits strictly below `bit_offset`.
        if bit_offset > 0 {
            let mask = (1u64 << bit_offset) - 1;
            rank += (bit_data[word_index] & mask).count_ones() as usize;
        }
        results.push(rank);
    }
    results
}
/// NEON-tier bulk select1. No NEON-specific select kernel is implemented;
/// this delegates to the portable scalar implementation.
#[cfg(target_arch = "aarch64")]
fn bulk_select1_neon(bit_data: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
bulk_select1_scalar(bit_data, indices)
}
/// NEON bulk popcount: produces the popcount of every word in `bit_data`,
/// processing two words per 128-bit load. The per-word counts are recovered by
/// summing the low and high 8-byte halves of the `vcntq_u8` result separately.
///
/// Fix: removed a redundant nested `unsafe` block around `vld1q_u8` that was
/// already inside an `unsafe` block (triggered the `unused_unsafe` lint).
#[cfg(target_arch = "aarch64")]
fn bulk_popcount_neon(bit_data: &[u64]) -> Vec<usize> {
    let mut results = Vec::with_capacity(bit_data.len());
    let chunks = bit_data.len() / 2;
    let remainder = bit_data.len() % 2;
    // SAFETY: each 16-byte load covers exactly two in-bounds u64 words.
    unsafe {
        for chunk in 0..chunks {
            let base_idx = chunk * 2;
            let ptr = bit_data.as_ptr().add(base_idx) as *const u8;
            let vec = vld1q_u8(ptr);
            let popcount_vec = vcntq_u8(vec);
            // Each 8-byte half holds the per-byte counts of one u64
            // (sum <= 64, so it fits the u8 horizontal add).
            let low_half = vget_low_u8(popcount_vec);
            let high_half = vget_high_u8(popcount_vec);
            let low_sum = vaddvq_u8(vcombine_u8(low_half, vdup_n_u8(0))) as usize;
            let high_sum = vaddvq_u8(vcombine_u8(high_half, vdup_n_u8(0))) as usize;
            results.push(low_sum);
            results.push(high_sum);
        }
    }
    if remainder > 0 {
        // Odd trailing word handled with the portable popcount.
        results.push(bit_data[chunks * 2].count_ones() as usize);
    }
    results
}
/// Portable bulk rank1 fallback: for each position, sums `count_ones` over the
/// complete words before it, then adds the bits below `bit_offset` in the
/// partial word. Positions past the end of `bit_data` yield a rank of 0.
fn bulk_rank1_scalar(bit_data: &[u64], positions: &[usize]) -> Vec<usize> {
    positions
        .iter()
        .map(|&pos| {
            let word_index = pos / 64;
            let bit_offset = pos % 64;
            if word_index >= bit_data.len() {
                // Out-of-range position: rank is defined as 0.
                return 0;
            }
            // Popcount of every complete word before the target word.
            let full: usize = bit_data[..word_index]
                .iter()
                .map(|w| w.count_ones() as usize)
                .sum();
            // Bits strictly below `bit_offset` in the target word.
            let partial = if bit_offset > 0 {
                (bit_data[word_index] & ((1u64 << bit_offset) - 1)).count_ones() as usize
            } else {
                0
            };
            full + partial
        })
        .collect()
}
/// Portable bulk select1 fallback: a linear scan with `count_ones` locates the
/// word containing the (index + 1)-th set bit, then the bit is found inside
/// that word by repeatedly clearing the lowest set bit (`w &= w - 1`) —
/// cheaper than the previous fixed 64-iteration shift loop. Also drops a
/// needless `mut` on the target rank.
///
/// # Errors
/// Returns an error if any index is >= the total number of set bits.
fn bulk_select1_scalar(bit_data: &[u64], indices: &[usize]) -> Result<Vec<usize>> {
    let mut results = Vec::with_capacity(indices.len());
    let total_ones: usize = bit_data.iter().map(|w| w.count_ones() as usize).sum();
    for &index in indices {
        if index >= total_ones {
            return Err(ZiporaError::invalid_data(format!(
                "Select index {} exceeds available set bits {}",
                index, total_ones
            )));
        }
        let target_rank = index + 1; // ranks are 1-based
        let mut current_rank = 0usize;
        for (word_idx, &word) in bit_data.iter().enumerate() {
            let word_popcount = word.count_ones() as usize;
            if current_rank + word_popcount >= target_rank {
                // The target bit lives in this word; walk its set bits.
                let mut remaining = target_rank - current_rank;
                let mut w = word;
                loop {
                    let bit_pos = w.trailing_zeros() as usize;
                    remaining -= 1;
                    if remaining == 0 {
                        results.push(word_idx * 64 + bit_pos);
                        break;
                    }
                    w &= w - 1; // clear the lowest set bit
                }
                break;
            }
            current_rank += word_popcount;
        }
    }
    Ok(results)
}
/// Portable bulk popcount fallback: `u64::count_ones` on every word,
/// one result per word.
fn bulk_popcount_scalar(bit_data: &[u64]) -> Vec<usize> {
    let mut counts = Vec::with_capacity(bit_data.len());
    for &word in bit_data {
        counts.push(word.count_ones() as usize);
    }
    counts
}
#[cfg(test)]
mod tests {
use super::*;
/// Five-word fixture with popcounts 32, 32, 64, 0, 2:
/// odd bits only, even bits only, all ones, all zeros, bits 0 and 63 only.
fn create_test_data() -> Vec<u64> {
vec![
0xAAAAAAAAAAAAAAAAu64, 0x5555555555555555u64, 0xFFFFFFFFFFFFFFFFu64, 0x0000000000000000u64, 0x8000000000000001u64, ]
}
// Detection yields a sane tier and chunk size, and `get` returns the same
// cached values on repeated calls.
#[test]
fn test_simd_capabilities_detection() {
let caps = SimdCapabilities::detect();
assert!(caps.optimization_tier <= 5);
assert!(caps.chunk_size > 0);
assert!(caps.chunk_size <= 64 * 1024);
let caps2 = SimdCapabilities::get();
assert_eq!(caps.optimization_tier, caps2.optimization_tier);
assert_eq!(caps.chunk_size, caps2.chunk_size);
}
// Ranks spot-checked against hand-computed values for the fixture,
// e.g. rank(191) = 32 + 32 + 63 = 127 and rank(192) = 128.
#[test]
fn test_bulk_rank1_simd() {
let bit_data = create_test_data();
let positions = vec![0, 1, 63, 64, 65, 127, 128, 191, 192, 255, 256];
let ranks = bulk_rank1_simd(&bit_data, &positions);
assert_eq!(ranks.len(), positions.len());
assert_eq!(ranks[0], 0); assert_eq!(ranks[1], 0); assert_eq!(ranks[3], 32); assert_eq!(ranks[7], 127); assert_eq!(ranks[8], 128); }
// The first set bit of 0xAAAA... is bit 1, so select(0) == 1.
// NOTE(review): only selects[0] is asserted — the remaining results are
// unchecked, so an in-word select bug for k > 1 would not be caught here.
#[test]
fn test_bulk_select1_simd() {
let bit_data = create_test_data();
let indices = vec![0, 1, 31, 32, 63];
let result = bulk_select1_simd(&bit_data, &indices);
assert!(result.is_ok());
let selects = result.unwrap();
assert_eq!(selects.len(), indices.len());
assert_eq!(selects[0], 1);
}
// Any index >= total number of set bits must be rejected with an error.
#[test]
fn test_bulk_select1_simd_invalid_index() {
let bit_data = create_test_data();
let total_ones: usize = bit_data.iter().map(|w| w.count_ones() as usize).sum();
let indices = vec![total_ones + 1];
let result = bulk_select1_simd(&bit_data, &indices);
assert!(result.is_err());
}
// Per-word popcounts must match the fixture's known counts (32/32/64/0/2).
#[test]
fn test_bulk_popcount_simd() {
let bit_data = create_test_data();
let popcounts = bulk_popcount_simd(&bit_data);
assert_eq!(popcounts.len(), bit_data.len());
assert_eq!(popcounts[0], 32); assert_eq!(popcounts[1], 32); assert_eq!(popcounts[2], 64); assert_eq!(popcounts[3], 0); assert_eq!(popcounts[4], 2); }
// Empty inputs short-circuit to empty outputs in all three entry points.
#[test]
fn test_empty_inputs() {
let bit_data = vec![];
let positions = vec![];
let indices = vec![];
assert_eq!(bulk_rank1_simd(&bit_data, &positions), Vec::<usize>::new());
assert_eq!(bulk_select1_simd(&bit_data, &indices).unwrap(), Vec::<usize>::new());
assert_eq!(bulk_popcount_simd(&bit_data), Vec::<usize>::new());
}
// Whatever tier is dispatched at runtime must agree with the scalar fallback.
#[test]
fn test_scalar_vs_simd_consistency() {
let bit_data = create_test_data();
let positions = vec![0, 1, 32, 64, 100, 200];
let scalar_ranks = bulk_rank1_scalar(&bit_data, &positions);
let simd_ranks = bulk_rank1_simd(&bit_data, &positions);
assert_eq!(scalar_ranks, simd_ranks);
let scalar_popcounts = bulk_popcount_scalar(&bit_data);
let simd_popcounts = bulk_popcount_simd(&bit_data);
assert_eq!(scalar_popcounts, simd_popcounts);
}
// Smoke test on a larger dataset; results are discarded — this only checks
// that the chunked bulk paths run to completion without panicking.
#[test]
fn test_large_dataset_performance() {
let bit_data: Vec<u64> = (0..1000)
.map(|i| match i % 4 {
0 => 0xAAAAAAAAAAAAAAAAu64,
1 => 0x5555555555555555u64,
2 => 0xFFFFFFFFFFFFFFFFu64,
_ => 0x0000000000000000u64,
})
.collect();
let positions: Vec<usize> = (0..500).map(|i| i * 100).collect();
let indices: Vec<usize> = (0..100).map(|i| i * 100).collect();
let _ranks = bulk_rank1_simd(&bit_data, &positions);
let _selects = bulk_select1_simd(&bit_data, &indices);
let _popcounts = bulk_popcount_simd(&bit_data);
}
}