/// Minimum slice length, in `u64` words, at which `popcount_slice` tries a
/// hardware-accelerated tier; shorter slices are counted by the scalar loop.
const SIMD_THRESHOLD: usize = 16;
/// Counts the set bits across every word of `words`.
///
/// Slices shorter than `SIMD_THRESHOLD` words go straight to the scalar loop.
/// Longer slices are dispatched, using runtime CPU feature detection where the
/// architecture needs it, to the first tier that is available:
///
/// 1. AVX-512 VPOPCNTDQ (x86_64, compiled only with the `avx512` cargo feature)
/// 2. AVX2 nibble lookup (x86_64)
/// 3. scalar `POPCNT` instruction (x86_64)
/// 4. NEON (aarch64)
/// 5. portable scalar fallback
///
/// Every tier computes the same count; only throughput differs.
// On aarch64 the NEON branch returns unconditionally, which makes the portable
// fallback at the end of the function unreachable on that target only.
#[allow(unreachable_code)]
#[inline]
pub fn popcount_slice(words: &[u64]) -> usize {
    if words.len() < SIMD_THRESHOLD {
        return popcount_scalar(words);
    }

    #[cfg(all(feature = "avx512", target_arch = "x86_64"))]
    {
        if std::arch::is_x86_feature_detected!("avx512vpopcntdq")
            && std::arch::is_x86_feature_detected!("avx512f")
        {
            // SAFETY: both target features required by `popcount_avx512` were
            // just detected on the running CPU.
            return unsafe { popcount_avx512(words) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // AVX2 has to be probed before POPCNT: in practice every AVX2-capable
        // CPU also reports POPCNT, so probing POPCNT first would leave the
        // AVX2 tier unreachable.
        if std::arch::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just detected on the running CPU.
            return unsafe { popcount_avx2(words) };
        }
        if std::arch::is_x86_feature_detected!("popcnt") {
            // SAFETY: POPCNT support was just detected on the running CPU.
            return unsafe { popcount_hw(words) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: no runtime probe is performed here; this relies on NEON being
        // part of the baseline feature set of the aarch64 target being built.
        return unsafe { popcount_neon(words) };
    }

    popcount_scalar(words)
}
/// Portable bit count: adds up `count_ones()` of each word in `words`.
///
/// Returns 0 for an empty slice.
#[inline]
fn popcount_scalar(words: &[u64]) -> usize {
    let mut total = 0usize;
    for word in words {
        total += word.count_ones() as usize;
    }
    total
}
/// Counts set bits with AVX-512 `VPOPCNTQ`, eight words per iteration.
///
/// Each full group of eight words is loaded unaligned, counted per 64-bit lane
/// with `_mm512_popcnt_epi64` and added into an eight-lane accumulator. The
/// lanes are summed at the end, and the `words.len() % 8` leftover words are
/// counted with `count_ones`.
///
/// # Safety
///
/// The running CPU must support `avx512f` and `avx512vpopcntdq`.
#[cfg(all(feature = "avx512", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
unsafe fn popcount_avx512(words: &[u64]) -> usize {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let groups = words.chunks_exact(LANES);
    let leftover = groups.remainder();

    // SAFETY: every load reads exactly the 64 bytes of one `chunks_exact(8)`
    // group, which lies inside `words`; the store writes 64 bytes into the
    // 8-element `lanes` array. The caller guarantees the target features.
    let vector_bits = unsafe {
        let mut lane_totals = _mm512_setzero_si512();
        for group in groups {
            let v = _mm512_loadu_si512(group.as_ptr() as *const __m512i);
            lane_totals = _mm512_add_epi64(lane_totals, _mm512_popcnt_epi64(v));
        }
        let mut lanes = [0u64; LANES];
        _mm512_storeu_si512(lanes.as_mut_ptr() as *mut _, lane_totals);
        lanes.iter().sum::<u64>() as usize
    };

    let mut total = vector_bits;
    for word in leftover {
        total += word.count_ones() as usize;
    }
    total
}
/// Counts set bits with an AVX2 nibble-lookup popcount, four words per vector.
///
/// Each 256-bit vector is split into low and high nibbles, and
/// `_mm256_shuffle_epi8` looks up the bit count of every nibble in a 16-entry
/// table. The per-byte counts are accumulated in 8-bit lanes and periodically
/// widened into four 64-bit sums with `_mm256_sad_epu8`. Words that do not fill
/// a whole vector (`words.len() % 4`) are counted with `count_ones`.
///
/// # Safety
///
/// The running CPU must support `avx2`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn popcount_avx2(words: &[u64]) -> usize {
    use std::arch::x86_64::*;

    // Number of `u64` words in one 256-bit vector.
    const LANES: usize = 4;
    // One vector adds at most 8 to each byte lane (4 per nibble), so 31 vectors
    // add at most 248 and a `u8` accumulator cannot wrap before it is flushed.
    const MAX_UNFLUSHED_VECTORS: usize = 31;

    let (body, tail) = words.split_at(words.len() / LANES * LANES);

    // SAFETY: the only memory access is the unaligned 32-byte load, and it reads
    // exactly one `chunks_exact(4)` group of `body`, which lies inside `words`.
    // The caller guarantees AVX2 is available.
    let vector_bits = unsafe {
        let zero = _mm256_setzero_si256();
        let nibble_counts = _mm256_setr_epi8(
            0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
            0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
        );
        let nibble_mask = _mm256_set1_epi8(0x0F);

        let mut wide_sums = zero;
        for batch in body.chunks(LANES * MAX_UNFLUSHED_VECTORS) {
            let mut byte_sums = zero;
            for group in batch.chunks_exact(LANES) {
                let v = _mm256_loadu_si256(group.as_ptr() as *const __m256i);
                let low = _mm256_and_si256(v, nibble_mask);
                let high = _mm256_and_si256(_mm256_srli_epi16(v, 4), nibble_mask);
                let per_byte = _mm256_add_epi8(
                    _mm256_shuffle_epi8(nibble_counts, low),
                    _mm256_shuffle_epi8(nibble_counts, high),
                );
                byte_sums = _mm256_add_epi8(byte_sums, per_byte);
            }
            // Widen the byte lanes into four 64-bit partial sums.
            wide_sums = _mm256_add_epi64(wide_sums, _mm256_sad_epu8(byte_sums, zero));
        }

        // Horizontal add of the four 64-bit lanes.
        let halves = _mm_add_epi64(
            _mm256_castsi256_si128(wide_sums),
            _mm256_extracti128_si256(wide_sums, 1),
        );
        let folded = _mm_add_epi64(halves, _mm_unpackhi_epi64(halves, halves));
        _mm_cvtsi128_si64(folded) as usize
    };

    let mut total = vector_bits;
    for word in tail {
        total += word.count_ones() as usize;
    }
    total
}
/// Counts set bits with the x86 `POPCNT` instruction, one word at a time.
///
/// # Safety
///
/// The running CPU must support the `popcnt` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
unsafe fn popcount_hw(words: &[u64]) -> usize {
    use std::arch::x86_64::_popcnt64;

    let mut total = 0usize;
    for word in words.iter().copied() {
        // `_popcnt64` takes an `i64`; the cast only reinterprets the bit pattern.
        let ones = unsafe { _popcnt64(word as i64) };
        total += ones as usize;
    }
    total
}
/// Counts set bits with NEON, two words per 128-bit vector.
///
/// `vcntq_u8` produces a per-byte bit count, which three pairwise widening adds
/// (`u8 -> u16 -> u32 -> u64`) collapse into one count per 64-bit lane. The two
/// lanes are summed at the end, and an odd trailing word is counted with
/// `count_ones`.
///
/// # Safety
///
/// The running CPU must support the `neon` target feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn popcount_neon(words: &[u64]) -> usize {
    use std::arch::aarch64::*;

    let pairs = words.chunks_exact(2);
    let leftover = pairs.remainder();

    // SAFETY: `vld1q_u64` reads exactly the two words of one `chunks_exact(2)`
    // pair, which lies inside `words`. The caller guarantees NEON is available.
    let vector_bits = unsafe {
        let mut lane_totals = vdupq_n_u64(0);
        for pair in pairs {
            let v = vld1q_u64(pair.as_ptr());
            let per_byte = vcntq_u8(vreinterpretq_u8_u64(v));
            let per_lane = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(per_byte)));
            lane_totals = vaddq_u64(lane_totals, per_lane);
        }
        vgetq_lane_u64(lane_totals, 0) + vgetq_lane_u64(lane_totals, 1)
    };

    let mut total = vector_bits as usize;
    for word in leftover {
        total += word.count_ones() as usize;
    }
    total
}
/// Reports whether BMI2 is available and not flagged as slow on this CPU.
///
/// The answer is produced by `has_fast_bmi2_detect` on the first call and
/// cached in a process-wide `OnceLock`; later calls only read the cached value.
#[inline]
pub fn has_fast_bmi2() -> bool {
    use std::sync::OnceLock;

    static FAST_BMI2: OnceLock<bool> = OnceLock::new();
    let cached: &bool = FAST_BMI2.get_or_init(has_fast_bmi2_detect);
    *cached
}
/// Uncached probe behind `has_fast_bmi2`.
///
/// On x86_64 it returns:
/// - `false` when the `bmi2` feature is not detected;
/// - `true` for any vendor other than `AuthenticAMD`;
/// - for `AuthenticAMD`, `true` only when the CPUID family is `0x19` or newer.
///
/// On every other architecture it returns `false`.
fn has_fast_bmi2_detect() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::__cpuid;

        if !std::arch::is_x86_feature_detected!("bmi2") {
            return false;
        }

        // CPUID leaf 0 spells the 12-byte vendor string across EBX, EDX, ECX
        // (in that order), each register holding four little-endian bytes.
        let leaf0 = unsafe { __cpuid(0) };
        let mut vendor = [0u8; 12];
        vendor[0..4].copy_from_slice(&leaf0.ebx.to_le_bytes());
        vendor[4..8].copy_from_slice(&leaf0.edx.to_le_bytes());
        vendor[8..12].copy_from_slice(&leaf0.ecx.to_le_bytes());
        if &vendor != b"AuthenticAMD" {
            return true;
        }

        // CPUID leaf 1, EAX: base family in bits 8..=11, extended family in
        // bits 20..=27. The extended field only counts when the base is 0xF.
        let signature = unsafe { __cpuid(1) }.eax;
        let family = match (signature >> 8) & 0xF {
            0xF => 0xF + ((signature >> 20) & 0xFF),
            base => base,
        };
        family >= 0x19
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Returns the bit index of the `rank`-th (0-based) set bit of `word`.
///
/// On x86_64 the implementation is picked at runtime: `PDEP` when
/// `has_fast_bmi2()` is true, otherwise the popcount binary search when
/// `popcnt` is detected. Everything else uses the portable bit-clearing loop.
#[inline(always)]
pub fn select_in_word(word: u64, rank: usize) -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::_pdep_u64;

        if has_fast_bmi2() {
            // Depositing the single bit `1 << rank` through `word` as the mask
            // lands it on the `rank`-th set bit of `word`.
            // SAFETY: `has_fast_bmi2()` is only true when BMI2 was detected.
            let deposited = unsafe { _pdep_u64(1u64 << rank, word) };
            return deposited.trailing_zeros() as usize;
        }
        if std::arch::is_x86_feature_detected!("popcnt") {
            return select_in_word_popcnt(word, rank);
        }
    }
    select_in_word_scalar(word, rank)
}
/// Select via binary search over popcounts of shrinking windows.
///
/// Starting at bit 0, the function looks at a window of 32, 16, 8, 4, 2 and
/// finally 1 bit(s). If the window holds no more than the remaining rank in
/// set bits, the target lies above it: the window's count is subtracted from
/// the rank and the offset moves past the window. The final offset is the
/// answer when `rank` is smaller than the number of set bits in `word`.
#[inline(always)]
fn select_in_word_popcnt(word: u64, rank: usize) -> usize {
    let mut remaining = rank;
    let mut offset = 0usize;
    for width in [32usize, 16, 8, 4, 2, 1] {
        let window = (word >> offset) & ((1u64 << width) - 1);
        let ones = window.count_ones() as usize;
        if ones <= remaining {
            remaining -= ones;
            offset += width;
        }
    }
    offset
}
/// Portable select: bit index of the `rank`-th (0-based) set bit of `word`.
///
/// Clears the lowest set bit `rank` times, then reports the position of the
/// lowest bit that is left. When `word` has `rank` or fewer set bits, nothing
/// is left and the result is 64 (`trailing_zeros` of zero), in every build
/// profile.
#[inline(always)]
fn select_in_word_scalar(word: u64, rank: usize) -> usize {
    let mut w = word;
    // A `u64` has at most 64 set bits, so clearing more often than that cannot
    // change the result; the cap keeps an oversized `rank` from spinning.
    for _ in 0..rank.min(u64::BITS as usize) {
        // `wrapping_sub` keeps the step defined once `w` is already zero: a
        // plain `w - 1` would panic there in builds with overflow checks.
        w &= w.wrapping_sub(1);
    }
    w.trailing_zeros() as usize
}
/// Unit tests for the popcount tiers, BMI2 detection and select-in-word.
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `popcount_avx2` directly when the CPU supports AVX2, so that
    /// AVX2-specific tests do not depend on which tier `popcount_slice` picks.
    /// Does nothing on CPUs or architectures without AVX2.
    fn assert_avx2_popcount(words: &[u64], expected: usize, label: &str) {
        #[cfg(target_arch = "x86_64")]
        {
            if std::arch::is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 support was just detected on the running CPU.
                let got = unsafe { popcount_avx2(words) };
                assert_eq!(got, expected, "AVX2 mismatch: {label}");
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            let _ = (words, expected, label);
        }
    }

    /// An empty slice has no set bits.
    #[test]
    fn test_empty_slice() {
        assert_eq!(popcount_slice(&[]), 0);
    }

    /// Single-word inputs with known bit counts.
    #[test]
    fn test_single_word() {
        assert_eq!(popcount_slice(&[0]), 0);
        assert_eq!(popcount_slice(&[1]), 1);
        assert_eq!(popcount_slice(&[u64::MAX]), 64);
        assert_eq!(popcount_slice(&[0xFF]), 8);
        assert_eq!(popcount_slice(&[0xAAAA_AAAA_AAAA_AAAA]), 32);
    }

    /// A slice long enough for the accelerated tiers, with no bits set.
    #[test]
    fn test_all_zeros() {
        let words = vec![0u64; 100];
        assert_eq!(popcount_slice(&words), 0);
    }

    /// A slice long enough for the accelerated tiers, with every bit set.
    #[test]
    fn test_all_ones() {
        let words = vec![u64::MAX; 100];
        assert_eq!(popcount_slice(&words), 6400);
    }

    /// Every length below `SIMD_THRESHOLD` agrees with `count_ones`.
    #[test]
    fn test_matches_scalar_small() {
        for len in 0..SIMD_THRESHOLD {
            let words: Vec<u64> = (0..len as u64)
                .map(|i| i.wrapping_mul(0x1234_5678_9ABC_DEF0))
                .collect();
            let expected: usize = words.iter().map(|w| w.count_ones() as usize).sum();
            assert_eq!(popcount_slice(&words), expected, "mismatch at len={len}");
        }
    }

    /// Lengths at and around vector-width multiples agree with `count_ones`.
    #[test]
    fn test_matches_scalar_simd_range() {
        for len in [16, 17, 31, 32, 33, 63, 64, 100, 127, 128, 255, 256, 500, 1000] {
            let words: Vec<u64> = (0..len as u64)
                .map(|i| i.wrapping_mul(0xDEAD_BEEF_CAFE_BABE).wrapping_add(i))
                .collect();
            let expected: usize = words.iter().map(|w| w.count_ones() as usize).sum();
            assert_eq!(popcount_slice(&words), expected, "mismatch at len={len}");
        }
    }

    /// Alternating-bit pattern: 32 set bits per word.
    #[test]
    fn test_alternating_bits() {
        let words = vec![0x5555_5555_5555_5555u64; 64];
        assert_eq!(popcount_slice(&words), 64 * 32);
    }

    /// Exactly one set bit per word, at a different position in each word.
    #[test]
    fn test_single_bit_per_word() {
        let words: Vec<u64> = (0..64).map(|i| 1u64 << i).collect();
        assert_eq!(popcount_slice(&words), 64);
    }

    /// 124 words are 31 AVX2 vectors, the point where the byte accumulators
    /// are flushed; 125 words add one scalar tail word. All-ones input gives
    /// the largest possible per-byte counts.
    #[test]
    fn test_boundary_at_31_iterations() {
        let words = vec![u64::MAX; 124];
        assert_eq!(popcount_slice(&words), 124 * 64);
        assert_avx2_popcount(&words, 124 * 64, "n=124");

        let words = vec![u64::MAX; 125];
        assert_eq!(popcount_slice(&words), 125 * 64);
        assert_avx2_popcount(&words, 125 * 64, "n=125");
    }

    /// Multiples of 31 vectors with all-ones input: the 8-bit accumulators
    /// must be flushed before they can wrap.
    #[test]
    fn test_avx2_reduction_overflow_boundary() {
        for n in [124, 248, 372, 496] {
            let words = vec![u64::MAX; n];
            assert_eq!(popcount_slice(&words), n * 64, "mismatch at n={n}");
            assert_avx2_popcount(&words, n * 64, &format!("n={n}"));
        }
    }

    /// A 10,000-word pseudo-random slice agrees with `count_ones`.
    #[test]
    fn test_large_slice() {
        let words: Vec<u64> = (0..10_000u64)
            .map(|i| i.wrapping_mul(0x0123_4567_89AB_CDEF))
            .collect();
        let expected: usize = words.iter().map(|w| w.count_ones() as usize).sum();
        assert_eq!(popcount_slice(&words), expected);
    }

    /// The dispatcher and each directly callable x86_64 tier agree with the
    /// scalar implementation on the same input.
    #[test]
    fn test_tier_consistency() {
        let words: Vec<u64> = (0..256u64)
            .map(|i| i.wrapping_mul(0xFEDC_BA98_7654_3210).wrapping_add(i * 17))
            .collect();
        let scalar = popcount_scalar(&words);
        let dispatch = popcount_slice(&words);
        assert_eq!(dispatch, scalar, "dispatch vs scalar mismatch");
        #[cfg(target_arch = "x86_64")]
        {
            if std::arch::is_x86_feature_detected!("avx2") {
                let avx2 = unsafe { popcount_avx2(&words) };
                assert_eq!(avx2, scalar, "AVX2 vs scalar mismatch");
            }
            if std::arch::is_x86_feature_detected!("popcnt") {
                let hw = unsafe { popcount_hw(&words) };
                assert_eq!(hw, scalar, "POPCNT vs scalar mismatch");
            }
        }
    }

    /// A bitmap built by setting 1000 ids below 50,000, then counted.
    #[test]
    fn test_union_counting_workload() {
        let num_words = (50_000 >> 6) + 1;
        let mut bits = vec![0u64; num_words];
        let doc_ids: Vec<u32> = (0..1000).map(|i| (i * 47) % 50_000).collect();
        for &doc_id in &doc_ids {
            let w = doc_id as usize >> 6;
            let b = doc_id as usize & 63;
            bits[w] |= 1u64 << b;
        }
        let expected: usize = bits.iter().map(|w| w.count_ones() as usize).sum();
        assert_eq!(popcount_slice(&bits), expected);
    }

    /// Detection must run without panicking; the value depends on the host.
    #[test]
    fn test_has_fast_bmi2_no_panic() {
        let result = has_fast_bmi2();
        eprintln!("has_fast_bmi2() = {result}");
    }

    /// 0xAA has set bits at positions 1, 3, 5 and 7.
    #[test]
    fn test_select_in_word_basic() {
        let word = 0xAAu64;
        assert_eq!(select_in_word(word, 0), 1);
        assert_eq!(select_in_word(word, 1), 3);
        assert_eq!(select_in_word(word, 2), 5);
        assert_eq!(select_in_word(word, 3), 7);
    }

    /// Rank 0 returns the lowest set bit.
    #[test]
    fn test_select_in_word_rank_zero() {
        assert_eq!(select_in_word(1, 0), 0);
        assert_eq!(select_in_word(0x80, 0), 7);
        assert_eq!(select_in_word(u64::MAX, 0), 0);
    }

    /// With every bit set, the k-th set bit is bit k.
    #[test]
    fn test_select_in_word_high_rank() {
        for k in 0..64 {
            assert_eq!(select_in_word(u64::MAX, k), k, "MAX rank={k}");
        }
    }

    /// A single set bit is found at every possible position.
    #[test]
    fn test_select_in_word_single_bit() {
        for pos in 0..64 {
            assert_eq!(select_in_word(1u64 << pos, 0), pos, "1<<{pos} rank=0");
        }
    }

    /// Set bits spread 16 positions apart.
    #[test]
    fn test_select_in_word_sparse() {
        let word = 1u64 | (1u64 << 16) | (1u64 << 32) | (1u64 << 48);
        assert_eq!(select_in_word(word, 0), 0);
        assert_eq!(select_in_word(word, 1), 16);
        assert_eq!(select_in_word(word, 2), 32);
        assert_eq!(select_in_word(word, 3), 48);
    }

    /// Eight consecutive low bits.
    #[test]
    fn test_select_in_word_consecutive() {
        let word = 0xFFu64;
        for k in 0..8 {
            assert_eq!(select_in_word(word, k), k, "0xFF rank={k}");
        }
    }

    /// The popcount binary search and the bit-clearing loop agree for every
    /// valid rank of a set of fixed patterns.
    #[test]
    fn test_select_in_word_popcnt_matches_scalar() {
        let test_words: Vec<u64> = vec![
            0,
            1,
            0xFF,
            0xAAAA_AAAA_AAAA_AAAA,
            0x5555_5555_5555_5555,
            u64::MAX,
            0x8000_0000_0000_0001,
            0x0123_4567_89AB_CDEF,
            0xFEDC_BA98_7654_3210,
            0x0F0F_0F0F_0F0F_0F0F,
        ];
        for &word in &test_words {
            let ones = word.count_ones() as usize;
            for rank in 0..ones {
                let popcnt_result = select_in_word_popcnt(word, rank);
                let scalar_result = select_in_word_scalar(word, rank);
                assert_eq!(
                    popcnt_result, scalar_result,
                    "mismatch word=0x{word:016X} rank={rank}: popcnt={popcnt_result} scalar={scalar_result}"
                );
            }
        }
    }

    /// The dispatched select agrees with the bit-clearing loop for every valid
    /// rank of 100 xorshift-generated words.
    #[test]
    fn test_select_in_word_random_patterns() {
        let mut rng = 0x1234_5678_9ABC_DEF0u64;
        for _ in 0..100 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let word = rng;
            let ones = word.count_ones() as usize;
            for rank in 0..ones {
                let result = select_in_word(word, rank);
                let scalar = select_in_word_scalar(word, rank);
                assert_eq!(
                    result, scalar,
                    "mismatch word=0x{word:016X} rank={rank}"
                );
            }
        }
    }
}
/// Timing benchmarks that run through the test harness.
///
/// Compiled only for test builds with the `simd` cargo feature enabled. Each
/// benchmark returns early in builds with debug assertions, where the timings
/// would not be meaningful. Results are printed to stderr.
#[cfg(test)]
#[cfg(feature = "simd")]
mod benchmarks {
    use super::*;
    use std::hint::black_box;
    use std::time::Instant;

    /// Measures `popcount_slice` throughput for several slice sizes.
    #[test]
    fn bench_popcount_slice_throughput() {
        if cfg!(debug_assertions) {
            eprintln!("Skipping benchmark in debug mode");
            return;
        }
        let sizes = [16, 100, 781, 1_000, 10_000];
        let iterations = 100_000;
        for &size in &sizes {
            let words: Vec<u64> = (0..size as u64)
                .map(|i| i.wrapping_mul(0xDEAD_BEEF_CAFE_BABE))
                .collect();
            let mut sink = 0usize;
            // Warm-up, excluded from the measurement.
            for _ in 0..1000 {
                sink += black_box(popcount_slice(black_box(&words)));
            }
            let start = Instant::now();
            for _ in 0..iterations {
                // `black_box` hides the input from the optimizer; without it
                // the loop-invariant call could be hoisted out of the loop.
                sink += black_box(popcount_slice(black_box(&words)));
            }
            let elapsed = start.elapsed();
            let ns_per_call = elapsed.as_nanos() as f64 / iterations as f64;
            let words_per_ns = size as f64 / ns_per_call;
            eprintln!(
                "popcount_slice({size:>6} words = {:>6} bytes): {ns_per_call:>8.1} ns/call, \
                 {words_per_ns:.2} words/ns ({:.0} Mwords/s) [sink={sink}]",
                size * 8,
                words_per_ns * 1000.0,
            );
        }
    }

    /// Compares the scalar loop with the dispatched implementation on a
    /// 1000-word slice.
    #[test]
    fn bench_popcount_simd_vs_scalar() {
        if cfg!(debug_assertions) {
            eprintln!("Skipping benchmark in debug mode");
            return;
        }
        let size = 1000;
        let iterations = 200_000;
        let words: Vec<u64> = (0..size as u64)
            .map(|i| i.wrapping_mul(0xCAFE_BABE_DEAD_BEEF))
            .collect();
        let mut sink = 0usize;
        // Warm-up for both code paths.
        for _ in 0..1000 {
            sink += black_box(popcount_scalar(black_box(&words)));
            sink += black_box(popcount_slice(black_box(&words)));
        }
        let start = Instant::now();
        for _ in 0..iterations {
            sink += black_box(popcount_scalar(black_box(&words)));
        }
        let scalar_ns = start.elapsed().as_nanos() as f64 / iterations as f64;
        let start = Instant::now();
        for _ in 0..iterations {
            sink += black_box(popcount_slice(black_box(&words)));
        }
        let simd_ns = start.elapsed().as_nanos() as f64 / iterations as f64;
        let speedup = scalar_ns / simd_ns;
        eprintln!(
            "popcount {size} words: scalar={scalar_ns:.1}ns, simd={simd_ns:.1}ns, \
             speedup={speedup:.1}× [sink={sink}]"
        );
    }

    /// Measures the average cost of `select_in_word` for ranks 0 and 1.
    #[test]
    fn bench_select_in_word() {
        if cfg!(debug_assertions) {
            eprintln!("Skipping benchmark in debug mode");
            return;
        }
        // Bits 0 and 63 are forced on, so ranks 0 and 1 are valid for every word.
        let words: Vec<u64> = (0..1000u64)
            .map(|i| i.wrapping_mul(0xDEAD_BEEF_CAFE_BABE) | 0x8000_0000_0000_0001)
            .collect();
        let iterations = 100_000;
        let mut sink = 0usize;
        // Warm-up, excluded from the measurement.
        for &w in &words {
            let ones = w.count_ones() as usize;
            for r in 0..ones.min(4) {
                sink += black_box(select_in_word(black_box(w), r));
            }
        }
        let start = Instant::now();
        for _ in 0..iterations {
            for &w in &words {
                sink += black_box(select_in_word(black_box(w), 0));
                sink += black_box(select_in_word(black_box(w), 1));
            }
        }
        // Two select calls per word per iteration.
        let calls_per_iteration = 2 * words.len();
        let ns = start.elapsed().as_nanos() as f64
            / (iterations as f64 * calls_per_iteration as f64);
        eprintln!("select_in_word: {ns:.1} ns/call [sink={sink}]");
    }
}