#[cfg(feature = "std")]
use std::sync::OnceLock;
/// A bit-manipulation backend selected per CPU.
///
/// Every implementor visible in this module is a zero-sized unit struct. The
/// methods are associated functions (no `self`), so generic code is
/// monomorphised over the kernel type without carrying an instance around.
pub trait CpuKernel: Copy + 'static {
    /// Keeps the low `n` bits of `value` and clears bit `n` and everything above it.
    ///
    /// `n == 0` yields `0` and `n == 64` yields `value` unchanged.
    /// NOTE(review): all implementations must agree for every `n: u8`, including
    /// `n > 64`. The x86 kernels use `BZHI`, which leaves `value` unchanged when
    /// the index is >= 64 (per the Intel SDM).
    fn mask_lower_bits(value: u64, n: u8) -> u64;
}
/// Portable [`CpuKernel`] with no architecture requirements.
///
/// It is the fallback when no faster kernel is selected. It is also the reference
/// the hardware-backed kernels are tested against, and the aarch64 kernels
/// delegate to it.
#[derive(Copy, Clone, Default)]
pub struct ScalarKernel;

impl CpuKernel for ScalarKernel {
    /// Keeps the low `n` bits of `value`.
    ///
    /// An index of 64 or more keeps every bit. This matches the x86 `BZHI`
    /// instruction used by the BMI2-backed kernels, so every kernel returns the
    /// same result for every `n: u8`. (The previous shift-based version returned
    /// 0 for `n` in 65..=255.)
    #[inline(always)]
    fn mask_lower_bits(value: u64, n: u8) -> u64 {
        if n >= 64 {
            // BZHI copies the source unchanged when the index is >= the operand width.
            value
        } else {
            // Here n < 64, so the shift cannot overflow; n == 0 gives a mask of 0.
            value & ((1u64 << n) - 1)
        }
    }
}
/// [`CpuKernel`] backed by the BMI2 `BZHI` instruction.
///
/// Only sound on CPUs that support BMI2. `select_x86_kernel` returns
/// `CpuKernelTag::Bmi2` only when `has_bmi2` is set.
#[cfg(target_arch = "x86_64")]
#[derive(Copy, Clone, Default)]
pub(crate) struct Bmi2Kernel;

#[cfg(target_arch = "x86_64")]
impl CpuKernel for Bmi2Kernel {
    #[inline(always)]
    fn mask_lower_bits(value: u64, n: u8) -> u64 {
        // Catch misuse, such as calling this kernel directly without going through
        // detection, in debug builds instead of executing an unsupported instruction.
        #[cfg(feature = "std")]
        debug_assert!(
            std::arch::is_x86_feature_detected!("bmi2"),
            "Bmi2Kernel used on a CPU without BMI2"
        );
        // SAFETY: this kernel is only selected when BMI2 is available
        // (`select_x86_kernel` requires `has_bmi2` for `CpuKernelTag::Bmi2`).
        // BMI2 is the only target feature `mask_lower_bits_bmi2_impl` enables.
        unsafe { mask_lower_bits_bmi2_impl(value, n) }
    }
}
/// [`CpuKernel`] selected on CPUs with both AVX2 and BMI2.
///
/// `mask_lower_bits` uses the same BMI2 `BZHI` path as `Bmi2Kernel`.
#[cfg(target_arch = "x86_64")]
#[derive(Copy, Clone, Default)]
pub(crate) struct Avx2Kernel;

#[cfg(target_arch = "x86_64")]
impl CpuKernel for Avx2Kernel {
    #[inline(always)]
    fn mask_lower_bits(value: u64, n: u8) -> u64 {
        // Only BMI2 is needed by the call below, so that is all we check in debug builds.
        #[cfg(feature = "std")]
        debug_assert!(
            std::arch::is_x86_feature_detected!("bmi2"),
            "Avx2Kernel used on a CPU without BMI2"
        );
        // SAFETY: `select_x86_kernel` returns `CpuKernelTag::Avx2` only when both
        // `has_avx2` and `has_bmi2` are set. BMI2 is the only target feature
        // `mask_lower_bits_bmi2_impl` enables.
        unsafe { mask_lower_bits_bmi2_impl(value, n) }
    }
}
/// [`CpuKernel`] selected on CPUs with AVX-512 (F/VL/BW/VBMI2), AVX2 and BMI2.
///
/// `mask_lower_bits` uses the same BMI2 `BZHI` path as `Bmi2Kernel`.
#[cfg(target_arch = "x86_64")]
#[derive(Copy, Clone, Default)]
pub(crate) struct Vbmi2Kernel;

#[cfg(target_arch = "x86_64")]
impl CpuKernel for Vbmi2Kernel {
    #[inline(always)]
    fn mask_lower_bits(value: u64, n: u8) -> u64 {
        // Only BMI2 is needed by the call below, so that is all we check in debug builds.
        #[cfg(feature = "std")]
        debug_assert!(
            std::arch::is_x86_feature_detected!("bmi2"),
            "Vbmi2Kernel used on a CPU without BMI2"
        );
        // SAFETY: `select_x86_kernel` returns `CpuKernelTag::Vbmi2` only when every
        // listed feature is set, including `has_bmi2`. BMI2 is the only target
        // feature `mask_lower_bits_bmi2_impl` enables.
        unsafe { mask_lower_bits_bmi2_impl(value, n) }
    }
}
/// aarch64 NEON kernel; currently every operation forwards to [`ScalarKernel`].
// NOTE(review): `dead_code` is allowed presumably because this type is not
// instantiated in every build configuration — confirm against the dispatch site.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
#[derive(Clone, Copy, Default)]
pub(crate) struct NeonKernel;

#[cfg(target_arch = "aarch64")]
impl CpuKernel for NeonKernel {
    #[inline(always)]
    fn mask_lower_bits(value: u64, n: u8) -> u64 {
        <ScalarKernel as CpuKernel>::mask_lower_bits(value, n)
    }
}
/// aarch64 SVE kernel; currently every operation forwards to [`ScalarKernel`].
// NOTE(review): `dead_code` is allowed presumably because this type is not
// instantiated in every build configuration — confirm against the dispatch site.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
#[derive(Clone, Copy, Default)]
pub(crate) struct SveKernel;

#[cfg(target_arch = "aarch64")]
impl CpuKernel for SveKernel {
    #[inline(always)]
    fn mask_lower_bits(value: u64, n: u8) -> u64 {
        <ScalarKernel as CpuKernel>::mask_lower_bits(value, n)
    }
}
/// Keeps the low `n` bits of `value` using the BMI2 `BZHI` instruction.
///
/// `n` is widened losslessly to the `u32` index `BZHI` expects.
///
/// # Safety
///
/// The CPU executing this call must support BMI2. The function is compiled with
/// `#[target_feature(enable = "bmi2")]`, and running it elsewhere is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
#[inline]
unsafe fn mask_lower_bits_bmi2_impl(value: u64, n: u8) -> u64 {
    let index = u32::from(n);
    core::arch::x86_64::_bzhi_u64(value, index)
}
/// Picks the most capable x86-64 kernel allowed by the given CPU feature flags.
///
/// Every non-scalar x86 kernel executes BMI2 `BZHI`, so BMI2 gates everything
/// except `Scalar`. `Avx2` additionally needs AVX2. `Vbmi2` needs AVX2 plus the
/// full AVX-512 F/VL/BW/VBMI2 set. The function is a `const fn`, so it can also
/// be evaluated in const contexts.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
const fn select_x86_kernel(
    has_avx512vbmi2: bool,
    has_avx512f: bool,
    has_avx512vl: bool,
    has_avx512bw: bool,
    has_bmi2: bool,
    has_avx2: bool,
) -> CpuKernelTag {
    let full_avx512 = has_avx512vbmi2 && has_avx512f && has_avx512vl && has_avx512bw;
    match (has_bmi2, has_avx2, full_avx512) {
        (true, true, true) => CpuKernelTag::Vbmi2,
        (true, true, false) => CpuKernelTag::Avx2,
        (true, false, _) => CpuKernelTag::Bmi2,
        (false, _, _) => CpuKernelTag::Scalar,
    }
}
/// Names the kernel chosen by `detect_cpu_kernel`.
///
/// Each architecture-specific variant carries the same `cfg` gate as the kernel
/// it names, so a tag can only exist on a target where that kernel is compiled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum CpuKernelTag {
    /// Portable fallback ([`ScalarKernel`]); available on every target.
    Scalar,
    /// x86-64 with BMI2.
    #[cfg(target_arch = "x86_64")]
    Bmi2,
    /// x86-64 with AVX2 and BMI2.
    #[cfg(target_arch = "x86_64")]
    Avx2,
    /// x86-64 with AVX-512 F/VL/BW/VBMI2, AVX2 and BMI2.
    #[cfg(target_arch = "x86_64")]
    Vbmi2,
    /// aarch64 with NEON.
    #[cfg(target_arch = "aarch64")]
    Neon,
    /// aarch64 with SVE. It exists when runtime detection is available (`std`)
    /// or when SVE is enabled at compile time.
    #[cfg(all(target_arch = "aarch64", any(feature = "std", target_feature = "sve")))]
    Sve,
}
/// Returns the best kernel for the running CPU.
///
/// The first call runs `detect_cpu_kernel_uncached` and stores the result in a
/// process-wide `OnceLock`. Every later call returns the stored tag without
/// probing the hardware again.
#[cfg(feature = "std")]
pub(crate) fn detect_cpu_kernel() -> CpuKernelTag {
    static SELECTED_KERNEL: OnceLock<CpuKernelTag> = OnceLock::new();
    let tag: &CpuKernelTag = SELECTED_KERNEL.get_or_init(detect_cpu_kernel_uncached);
    *tag
}
/// Probes the running CPU's features and picks a kernel, without caching.
///
/// On x86-64, the flags reported by `is_x86_feature_detected!` go to
/// `select_x86_kernel`. On aarch64 the preference order is SVE, then NEON, then
/// scalar. Every other architecture gets the scalar kernel. Exactly one of the
/// `cfg` blocks below survives compilation and becomes the function's tail
/// expression.
#[cfg(feature = "std")]
fn detect_cpu_kernel_uncached() -> CpuKernelTag {
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::is_x86_feature_detected;
        select_x86_kernel(
            is_x86_feature_detected!("avx512vbmi2"),
            is_x86_feature_detected!("avx512f"),
            is_x86_feature_detected!("avx512vl"),
            is_x86_feature_detected!("avx512bw"),
            is_x86_feature_detected!("bmi2"),
            is_x86_feature_detected!("avx2"),
        )
    }
    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::is_aarch64_feature_detected;
        if is_aarch64_feature_detected!("sve") {
            CpuKernelTag::Sve
        } else if is_aarch64_feature_detected!("neon") {
            CpuKernelTag::Neon
        } else {
            CpuKernelTag::Scalar
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        CpuKernelTag::Scalar
    }
}
/// Picks a kernel from the target features this crate was compiled with.
///
/// Without `std` there is no runtime feature probing, so the choice depends only
/// on `cfg!(target_feature = ...)` / `#[cfg(target_feature = ...)]` and is fixed
/// per build.
///
/// The `cfg` gates below are mutually exclusive, so exactly one block survives
/// and becomes the tail expression. The previous `return`-based version left
/// `{ return Neon; }` as an unreachable statement when building for aarch64 with
/// `+sve`, and the tail-only `#[allow(unreachable_code)]` did not cover it.
#[cfg(not(feature = "std"))]
pub(crate) fn detect_cpu_kernel() -> CpuKernelTag {
    #[cfg(target_arch = "x86_64")]
    {
        select_x86_kernel(
            cfg!(target_feature = "avx512vbmi2"),
            cfg!(target_feature = "avx512f"),
            cfg!(target_feature = "avx512vl"),
            cfg!(target_feature = "avx512bw"),
            cfg!(target_feature = "bmi2"),
            cfg!(target_feature = "avx2"),
        )
    }
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        CpuKernelTag::Sve
    }
    #[cfg(all(
        target_arch = "aarch64",
        not(target_feature = "sve"),
        target_feature = "neon"
    ))]
    {
        CpuKernelTag::Neon
    }
    #[cfg(not(any(
        target_arch = "x86_64",
        all(
            target_arch = "aarch64",
            any(target_feature = "sve", target_feature = "neon")
        )
    )))]
    {
        CpuKernelTag::Scalar
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_mask_lower_bits_zero_n_returns_zero() {
        assert_eq!(ScalarKernel::mask_lower_bits(0xDEADBEEF, 0), 0);
    }

    #[test]
    fn scalar_mask_lower_bits_full_64_returns_full_value() {
        assert_eq!(
            ScalarKernel::mask_lower_bits(0xFFFF_FFFF_FFFF_FFFF, 64),
            0xFFFF_FFFF_FFFF_FFFF
        );
    }

    #[test]
    fn scalar_mask_lower_bits_mid_keeps_low_n_bits() {
        assert_eq!(ScalarKernel::mask_lower_bits(0xDEAD_BEEF, 8), 0xEF);
        assert_eq!(
            ScalarKernel::mask_lower_bits(0x0102_0304_0506_0708, 16),
            0x0708
        );
    }

    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn avx2_mask_lower_bits_matches_scalar_on_bmi2_hw() {
        if !std::arch::is_x86_feature_detected!("bmi2") {
            return;
        }
        for n in 0..=64u8 {
            let v = 0x1234_5678_9ABC_DEF0u64;
            assert_eq!(
                Avx2Kernel::mask_lower_bits(v, n),
                ScalarKernel::mask_lower_bits(v, n),
                "mismatch at n={}",
                n
            );
        }
    }

    // Every BZHI-backed kernel must agree with the scalar reference, not just Avx2Kernel.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn bmi2_and_vbmi2_mask_lower_bits_match_scalar_on_bmi2_hw() {
        if !std::arch::is_x86_feature_detected!("bmi2") {
            return;
        }
        let v = 0x1234_5678_9ABC_DEF0u64;
        for n in 0..=64u8 {
            let expected = ScalarKernel::mask_lower_bits(v, n);
            assert_eq!(Bmi2Kernel::mask_lower_bits(v, n), expected, "Bmi2 mismatch at n={}", n);
            assert_eq!(Vbmi2Kernel::mask_lower_bits(v, n), expected, "Vbmi2 mismatch at n={}", n);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn select_x86_kernel_vbmi2_without_avx2_does_not_pick_vbmi2() {
        let tag = select_x86_kernel(true, true, true, true, true, false);
        assert_ne!(
            tag,
            CpuKernelTag::Vbmi2,
            "selecting Vbmi2 without AVX2 would call AVX2 instructions and SIGILL"
        );
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn select_x86_kernel_full_x86_v4_picks_vbmi2() {
        let tag = select_x86_kernel(true, true, true, true, true, true);
        assert_eq!(tag, CpuKernelTag::Vbmi2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn select_x86_kernel_avx2_baseline_picks_avx2() {
        let tag = select_x86_kernel(false, false, false, false, true, true);
        assert_eq!(tag, CpuKernelTag::Avx2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn select_x86_kernel_bmi2_only_picks_bmi2() {
        let tag = select_x86_kernel(false, false, false, false, true, false);
        assert_eq!(tag, CpuKernelTag::Bmi2);
    }

    // All non-scalar x86 kernels execute BZHI, so missing BMI2 must force Scalar.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn select_x86_kernel_without_bmi2_picks_scalar() {
        assert_eq!(
            select_x86_kernel(false, false, false, false, false, true),
            CpuKernelTag::Scalar
        );
        assert_eq!(
            select_x86_kernel(true, true, true, true, false, true),
            CpuKernelTag::Scalar
        );
        assert_eq!(
            select_x86_kernel(false, false, false, false, false, false),
            CpuKernelTag::Scalar
        );
    }

    // On x86_64 every non-Scalar tag maps to a BZHI-backed kernel.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn detected_non_scalar_tag_implies_bmi2_hw() {
        let tag = detect_cpu_kernel();
        if tag != CpuKernelTag::Scalar {
            assert!(
                std::arch::is_x86_feature_detected!("bmi2"),
                "detected {:?} on a CPU without BMI2",
                tag
            );
        }
    }

    #[test]
    fn detect_returns_consistent_tag() {
        let first = detect_cpu_kernel();
        let second = detect_cpu_kernel();
        assert_eq!(
            first, second,
            "cached detect must return same tag on repeated calls"
        );
    }
}