#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
/// Returns the number of set bits in a 512-bit block laid out as eight
/// consecutive `u64` words starting at `ptr`, using the hardware `POPCNT`
/// instruction for each word.
///
/// # Safety
///
/// - The running CPU must support the `popcnt` feature.
/// - `ptr` must be non-null, aligned for `u64`, and valid for reads of eight
///   consecutive `u64` values.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[inline]
// Kept compiled even when no caller in the current build references it.
#[allow(dead_code)]
pub unsafe fn popcount_512_popcnt(ptr: *const u64) -> u32 {
    const WORDS: usize = 8;
    unsafe {
        // Each `_popcnt64` result is 0..=64, so eight of them fit easily in i32.
        let mut bits = 0i32;
        for offset in 0..WORDS {
            bits += _popcnt64(*ptr.add(offset) as i64);
        }
        bits as u32
    }
}
/// Returns the total number of set bits in `word_count` consecutive `u64`
/// words starting at `ptr`, using the hardware `POPCNT` instruction.
///
/// The running total is accumulated directly in `u32` (the return type). A
/// signed `i32` accumulator would overflow after about 33.5 million all-ones
/// words, half of what the `u32` result can represent. Totals above
/// `u32::MAX` bits still overflow: debug builds panic and release builds wrap.
///
/// # Safety
///
/// - The running CPU must support the `popcnt` feature.
/// - If `word_count > 0`, `ptr` must be non-null, aligned for `u64`, and valid
///   for reads of `word_count` consecutive `u64` values. When
///   `word_count == 0`, `ptr` is never offset or dereferenced.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[inline]
// Kept compiled even when no caller in the current build references it.
#[allow(dead_code)]
pub unsafe fn popcount_words_popcnt(ptr: *const u64, word_count: usize) -> u32 {
    unsafe {
        let mut total = 0u32;
        for i in 0..word_count {
            // `_popcnt64` returns an i32 in 0..=64, so the cast to u32 is lossless.
            total += _popcnt64(*ptr.add(i) as i64) as u32;
        }
        total
    }
}
/// Returns the bit index (0..=63, counted from the least-significant bit) of
/// the `k`-th set bit of `x`, where `k` is 0-based. Returns 64 when `x` has
/// `k` or fewer set bits (including `x == 0` and any `k >= 64`).
///
/// Implemented as `tzcnt(pdep(1 << k, x))`. This needs neither a popcount
/// (which is not hardware-accelerated here, since only `bmi2` is enabled) nor
/// a mask special case for `k == 63`.
///
/// # Safety
///
/// The running CPU must support the `bmi2` feature. On some CPUs PDEP is
/// microcoded and slow; `has_fast_bmi2` (where compiled in) exists to detect
/// that case.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
#[inline]
pub unsafe fn select_in_word_pdep(x: u64, k: u32) -> u32 {
    // `x` has at most 64 set bits, and `1 << k` would overflow for k >= 64.
    if k >= 64 {
        return 64;
    }
    // PDEP scatters the low bits of the source onto the set-bit positions of
    // `x` in order. Source bit k therefore lands exactly on the k-th set bit of
    // `x`. If `x` has k or fewer set bits, that source bit is discarded, the
    // result is 0, and `trailing_zeros` of 0 yields the 64 "not found" value.
    _pdep_u64(1u64 << k, x).trailing_zeros()
}
/// Reports whether the running CPU has BMI2 with a fast (non-microcoded) PDEP.
///
/// The answer comes from `detect_fast_bmi2` and is memoised in a process-wide
/// atomic byte. Relaxed ordering is sufficient because the byte is the only
/// shared data and detection is deterministic. If several threads race on
/// the first call, each computes and stores the same value.
#[cfg(all(target_arch = "x86_64", any(feature = "std", test)))]
pub fn has_fast_bmi2() -> bool {
    use core::sync::atomic::{AtomicU8, Ordering};

    // Cache states. Any value other than FAST/SLOW means "not yet detected".
    const UNKNOWN: u8 = 0;
    const FAST: u8 = 1;
    const SLOW: u8 = 2;

    static CACHED: AtomicU8 = AtomicU8::new(UNKNOWN);

    let state = CACHED.load(Ordering::Relaxed);
    if state == FAST {
        return true;
    }
    if state == SLOW {
        return false;
    }

    let fast = detect_fast_bmi2();
    let encoded = if fast { FAST } else { SLOW };
    CACHED.store(encoded, Ordering::Relaxed);
    fast
}
#[cfg(all(target_arch = "x86_64", any(feature = "std", test)))]
fn detect_fast_bmi2() -> bool {
if !is_x86_feature_detected!("bmi2") {
return false;
}
if is_x86_feature_detected!("avx512f") {
return true;
}
let cpuid0 = core::arch::x86_64::__cpuid(0);
let is_amd = cpuid0.ebx == 0x6874_7541 && cpuid0.edx == 0x6974_6E65 && cpuid0.ecx == 0x444D_4163;
if !is_amd {
return true;
}
let cpuid1 = core::arch::x86_64::__cpuid(1);
let family = ((cpuid1.eax >> 8) & 0xF) + ((cpuid1.eax >> 20) & 0xFF);
family >= 0x19
}
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    /// Whether the running CPU supports BMI2. BMI2-dependent tests return
    /// early (passing vacuously) when it does not.
    fn has_bmi2() -> bool {
        is_x86_feature_detected!("bmi2")
    }

    /// Portable reference for `select_in_word_pdep`: clears the lowest set bit
    /// `k` times, then returns the index of the lowest remaining set bit, or 64
    /// if the word runs out of set bits first.
    fn select_ctz(x: u64, k: u32) -> u32 {
        let mut val = x;
        let mut remaining = k;
        loop {
            if val == 0 {
                return 64;
            }
            let t = val.trailing_zeros();
            if remaining == 0 {
                return t;
            }
            remaining -= 1;
            val &= val - 1;
        }
    }

    #[test]
    fn test_popcount_512_popcnt() {
        if !is_x86_feature_detected!("popcnt") {
            return;
        }
        let data = [u64::MAX; 8];
        let result = unsafe { popcount_512_popcnt(data.as_ptr()) };
        assert_eq!(result, 512);
        let data = [0u64; 8];
        let result = unsafe { popcount_512_popcnt(data.as_ptr()) };
        assert_eq!(result, 0);
        let data = [0xAAAA_AAAA_AAAA_AAAAu64; 8];
        let result = unsafe { popcount_512_popcnt(data.as_ptr()) };
        assert_eq!(result, 256);
    }

    #[test]
    fn test_popcount_words_popcnt() {
        if !is_x86_feature_detected!("popcnt") {
            return;
        }
        let data = [u64::MAX; 16];
        unsafe {
            assert_eq!(popcount_words_popcnt(data.as_ptr(), 0), 0);
            assert_eq!(popcount_words_popcnt(data.as_ptr(), 1), 64);
            assert_eq!(popcount_words_popcnt(data.as_ptr(), 8), 512);
            assert_eq!(popcount_words_popcnt(data.as_ptr(), 16), 1024);
        }
        // Non-uniform words: 64 + 0 + 1 + 4 + 32 + 2 + 1 + 3 = 107 bits.
        let mixed = [u64::MAX, 0, 1, 0xF0, 0xAAAA_AAAA_AAAA_AAAA, 3, 1 << 63, 7];
        unsafe {
            assert_eq!(popcount_words_popcnt(mixed.as_ptr(), 3), 65);
            assert_eq!(popcount_words_popcnt(mixed.as_ptr(), 8), 107);
            assert_eq!(popcount_512_popcnt(mixed.as_ptr()), 107);
        }
    }

    #[test]
    fn test_select_in_word_pdep_basic() {
        if !has_bmi2() {
            eprintln!("Skipping BMI2 test: CPU doesn't support BMI2");
            return;
        }
        unsafe {
            assert_eq!(select_in_word_pdep(0, 0), 64);
            assert_eq!(select_in_word_pdep(1, 0), 0);
            assert_eq!(select_in_word_pdep(1, 1), 64);
            assert_eq!(select_in_word_pdep(1 << 63, 0), 63);
            assert_eq!(select_in_word_pdep(1 << 63, 1), 64);

            let word = 0b1010_1010u64;
            assert_eq!(select_in_word_pdep(word, 0), 1);
            assert_eq!(select_in_word_pdep(word, 1), 3);
            assert_eq!(select_in_word_pdep(word, 2), 5);
            assert_eq!(select_in_word_pdep(word, 3), 7);
            assert_eq!(select_in_word_pdep(word, 4), 64);
        }
    }

    #[test]
    fn test_select_in_word_pdep_all_ones() {
        if !has_bmi2() {
            return;
        }
        unsafe {
            let word = u64::MAX;
            for k in 0..64 {
                assert_eq!(select_in_word_pdep(word, k), k, "k={}", k);
            }
            assert_eq!(select_in_word_pdep(word, 64), 64);
        }
    }

    #[test]
    fn test_select_in_word_pdep_sparse() {
        if !has_bmi2() {
            return;
        }
        unsafe {
            let word = 1u64 | (1u64 << 31) | (1u64 << 63);
            assert_eq!(select_in_word_pdep(word, 0), 0);
            assert_eq!(select_in_word_pdep(word, 1), 31);
            assert_eq!(select_in_word_pdep(word, 2), 63);
            assert_eq!(select_in_word_pdep(word, 3), 64);
        }
    }

    #[test]
    fn test_select_in_word_pdep_dense() {
        if !has_bmi2() {
            return;
        }
        unsafe {
            let word = 0xFFFF_FFFFu64;
            for k in 0..32 {
                assert_eq!(select_in_word_pdep(word, k), k, "k={}", k);
            }
            assert_eq!(select_in_word_pdep(word, 32), 64);
        }
    }

    #[test]
    fn test_select_in_word_pdep_k_out_of_range() {
        if !has_bmi2() {
            return;
        }
        let words = [0u64, 1, u64::MAX, 0xAAAA_AAAA_AAAA_AAAA];
        let ks = [64u32, 65, 100, u32::MAX];
        for &word in &words {
            for &k in &ks {
                let got = unsafe { select_in_word_pdep(word, k) };
                assert_eq!(got, 64, "word={:#x}, k={}", word, k);
            }
        }
    }

    #[test]
    fn test_select_in_word_pdep_matches_ctz() {
        if !has_bmi2() {
            return;
        }
        let patterns = [
            0u64,
            1,
            0xFF,
            0x8000_0000_0000_0000,
            u64::MAX,
            0xAAAA_AAAA_AAAA_AAAA,
            0x5555_5555_5555_5555,
            0x1234_5678_9ABC_DEF0,
            0x00FF_00FF_00FF_00FF,
            0xF0F0_F0F0_F0F0_F0F0,
        ];
        for &word in &patterns {
            let pop = word.count_ones();
            for k in 0..=pop {
                let pdep_result = unsafe { select_in_word_pdep(word, k) };
                let ctz_result = select_ctz(word, k);
                assert_eq!(
                    pdep_result, ctz_result,
                    "Mismatch for word={:#x}, k={}",
                    word, k
                );
            }
        }
    }

    #[test]
    fn test_select_in_word_pdep_exhaustive_small() {
        if !has_bmi2() {
            return;
        }
        for word in 0u64..=0xFFFF {
            let pop = word.count_ones();
            for k in 0..=pop {
                let pdep_result = unsafe { select_in_word_pdep(word, k) };
                let ctz_result = select_ctz(word, k);
                assert_eq!(
                    pdep_result, ctz_result,
                    "Mismatch for word={:#x}, k={}",
                    word, k
                );
            }
        }
    }

    #[test]
    fn test_has_fast_bmi2_detection() {
        let result = has_fast_bmi2();
        eprintln!("has_fast_bmi2() = {}", result);
        if !is_x86_feature_detected!("bmi2") {
            assert!(
                !result,
                "has_fast_bmi2 should be false when BMI2 unsupported"
            );
        }
    }
}