use crate::features::Feature;
/// Marker OR-ed into every cached word of the Windows caches once that word has
/// been populated; a word without it is treated as "not computed yet".
///
/// It occupies bit 63 of each `u64`, so that bit position cannot carry a feature
/// in the cached words (it is masked off whenever a snapshot is produced).
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
const INIT_BIT: u64 = 0x8000_0000_0000_0000;
/// A fixed-size bit set of [`Feature`]s.
///
/// The feature with discriminant `n` is stored at bit `n` of `lo` when `n < 64`,
/// and at bit `n - 64` of `hi` otherwise.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Features {
    /// Membership bits for discriminants 0..64.
    pub(in crate::cache) lo: u64,
    /// Membership bits for discriminants 64 and above.
    pub(in crate::cache) hi: u64,
}
impl Features {
    /// The set that contains no features.
    pub const EMPTY: Self = Self { lo: 0, hi: 0 };

    /// Returns the features detected on the running CPU via the fast path.
    ///
    /// On Windows/AArch64 this reads the cache that is filled solely by
    /// `crate::windows::fill_ipfp`. On every other target it rebuilds the set from
    /// [`is_detected`] on each call.
    #[inline]
    pub fn current() -> Self {
        #[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
        {
            snapshot()
        }
        #[cfg(all(target_os = "windows", target_arch = "aarch64"))]
        {
            windows_cache::snapshot_fast()
        }
    }

    /// Returns the features detected on the running CPU, including slower sources.
    ///
    /// On Windows/AArch64 this additionally consults the registry when the crate is
    /// built with the `registry` feature and lookups have not been switched off via
    /// [`set_registry_enabled`]. On every other target it is equivalent to
    /// [`Features::current`].
    #[inline]
    pub fn current_full() -> Self {
        #[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
        {
            snapshot()
        }
        #[cfg(all(target_os = "windows", target_arch = "aarch64"))]
        {
            windows_cache::snapshot_full()
        }
    }

    /// Returns `true` if `feature` is a member of this set.
    ///
    /// Discriminants below 64 live in `lo`, the rest in `hi`. Discriminants are
    /// assumed to stay below 128; a larger one would overflow the shift.
    #[inline]
    pub const fn has(&self, feature: Feature) -> bool {
        let idx = feature as u8;
        let (word, shift) = if idx >= 64 {
            (self.hi, idx - 64)
        } else {
            (self.lo, idx)
        };
        (word >> shift) & 1 == 1
    }

    /// Returns a copy of this set with `feature` added.
    #[inline]
    pub(crate) const fn with(self, feature: Feature) -> Self {
        let idx = feature as u8;
        if idx >= 64 {
            Self {
                hi: self.hi | (1u64 << (idx - 64)),
                ..self
            }
        } else {
            Self {
                lo: self.lo | (1u64 << idx),
                ..self
            }
        }
    }

    /// Iterates over the members of this set, in [`Feature::all`] order.
    pub fn iter(&self) -> impl Iterator<Item = Feature> + '_ {
        Feature::all().filter(move |&feat| self.has(feat))
    }
}
/// Builds a fresh [`Features`] set by probing every known feature with
/// [`is_detected`] (no caching on these targets).
#[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
#[inline]
fn snapshot() -> Features {
    Feature::all()
        .filter(|&feat| is_detected(feat))
        .fold(Features::EMPTY, Features::with)
}
/// Reports whether `feature` is available on the running CPU.
///
/// - Non-AArch64 targets: always `false`.
/// - AArch64 (not Windows): delegates to `std::arch::is_aarch64_feature_detected!`.
/// - Windows/AArch64: answers from the fast cache, which is filled only by
///   `crate::windows::fill_ipfp` (registry lookups are never made here).
#[doc(hidden)]
#[inline]
pub fn is_detected(feature: Feature) -> bool {
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Nothing to probe on other architectures; silence the unused parameter.
        let _ = feature;
        false
    }
    #[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
    {
        stdarch_dispatch(feature)
    }
    #[cfg(all(target_arch = "aarch64", target_os = "windows"))]
    {
        windows_cache::query_fast(feature)
    }
}
/// Maps a [`Feature`] onto the matching `std::arch::is_aarch64_feature_detected!`
/// probe. Compiled only for AArch64 targets other than Windows.
///
/// Any variant without an explicit arm below hits the `_` fallback and is
/// reported as not detected.
#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
#[inline]
fn stdarch_dispatch(feature: Feature) -> bool {
    // The detection macro only accepts a string literal, so every variant needs
    // its own arm rather than a lookup table.
    match feature {
        Feature::Asimd => std::arch::is_aarch64_feature_detected!("asimd"),
        Feature::Fp => std::arch::is_aarch64_feature_detected!("fp"),
        Feature::Fp16 => std::arch::is_aarch64_feature_detected!("fp16"),
        Feature::Fhm => std::arch::is_aarch64_feature_detected!("fhm"),
        Feature::Fcma => std::arch::is_aarch64_feature_detected!("fcma"),
        Feature::Bf16 => std::arch::is_aarch64_feature_detected!("bf16"),
        Feature::I8mm => std::arch::is_aarch64_feature_detected!("i8mm"),
        Feature::JsConv => std::arch::is_aarch64_feature_detected!("jsconv"),
        Feature::FrintTs => std::arch::is_aarch64_feature_detected!("frintts"),
        Feature::Rdm => std::arch::is_aarch64_feature_detected!("rdm"),
        Feature::Dotprod => std::arch::is_aarch64_feature_detected!("dotprod"),
        Feature::Aes => std::arch::is_aarch64_feature_detected!("aes"),
        Feature::Pmull => std::arch::is_aarch64_feature_detected!("pmull"),
        Feature::Sha2 => std::arch::is_aarch64_feature_detected!("sha2"),
        Feature::Sha3 => std::arch::is_aarch64_feature_detected!("sha3"),
        Feature::Sm4 => std::arch::is_aarch64_feature_detected!("sm4"),
        Feature::Crc => std::arch::is_aarch64_feature_detected!("crc"),
        Feature::Lse => std::arch::is_aarch64_feature_detected!("lse"),
        Feature::Lse2 => std::arch::is_aarch64_feature_detected!("lse2"),
        Feature::Rcpc => std::arch::is_aarch64_feature_detected!("rcpc"),
        Feature::Rcpc2 => std::arch::is_aarch64_feature_detected!("rcpc2"),
        Feature::Paca => std::arch::is_aarch64_feature_detected!("paca"),
        Feature::Pacg => std::arch::is_aarch64_feature_detected!("pacg"),
        Feature::Bti => std::arch::is_aarch64_feature_detected!("bti"),
        Feature::Dpb => std::arch::is_aarch64_feature_detected!("dpb"),
        Feature::Dpb2 => std::arch::is_aarch64_feature_detected!("dpb2"),
        Feature::Mte => std::arch::is_aarch64_feature_detected!("mte"),
        Feature::Dit => std::arch::is_aarch64_feature_detected!("dit"),
        Feature::Sb => std::arch::is_aarch64_feature_detected!("sb"),
        Feature::Ssbs => std::arch::is_aarch64_feature_detected!("ssbs"),
        Feature::FlagM => std::arch::is_aarch64_feature_detected!("flagm"),
        Feature::Rand => std::arch::is_aarch64_feature_detected!("rand"),
        Feature::Tme => std::arch::is_aarch64_feature_detected!("tme"),
        Feature::Sve => std::arch::is_aarch64_feature_detected!("sve"),
        Feature::Sve2 => std::arch::is_aarch64_feature_detected!("sve2"),
        Feature::Sve2Aes => std::arch::is_aarch64_feature_detected!("sve2-aes"),
        Feature::Sve2Bitperm => std::arch::is_aarch64_feature_detected!("sve2-bitperm"),
        Feature::Sve2Sha3 => std::arch::is_aarch64_feature_detected!("sve2-sha3"),
        Feature::Sve2Sm4 => std::arch::is_aarch64_feature_detected!("sve2-sm4"),
        Feature::F32mm => std::arch::is_aarch64_feature_detected!("f32mm"),
        Feature::F64mm => std::arch::is_aarch64_feature_detected!("f64mm"),
        // NOTE(review): variants such as `SmeF64f64` land here and always read as
        // absent on these targets — presumably because std offers no probe for them
        // at this crate's MSRV; confirm before relying on a `false` result.
        _ => false,
    }
}
/// Turns registry-based detection for [`Features::current_full`] on or off at runtime.
///
/// On Windows/AArch64 this updates the runtime switch and invalidates the full
/// cache, so the next `current_full` call recomputes. Registry lookups only ever
/// happen when the crate is built with the `registry` feature. On every other
/// target this function does nothing.
#[inline]
pub fn set_registry_enabled(enabled: bool) {
    #[cfg(not(all(target_os = "windows", target_arch = "aarch64")))]
    let _ = enabled;
    #[cfg(all(target_os = "windows", target_arch = "aarch64"))]
    {
        windows_cache::set_registry_enabled(enabled);
    }
}
#[cfg(all(target_os = "windows", target_arch = "aarch64"))]
mod windows_cache {
    //! Lazily populated, lock-free caches for the Windows/AArch64 detection results.
    //!
    //! Each cache is a pair of `AtomicU64` words mirroring `Features { lo, hi }`.
    //! `INIT_BIT` is OR-ed into both words when a cache is populated; a word without
    //! it means "not computed yet" (or "invalidated"), and readers then repopulate.
    //! Population is not serialized: racing threads may each compute and store.

    use super::{Features, INIT_BIT};
    use crate::features::Feature;
    use core::sync::atomic::{AtomicBool, AtomicU64, Ordering};

    // Cache filled only by `fill_ipfp`. Never invalidated once populated.
    static FAST_LO: AtomicU64 = AtomicU64::new(0);
    static FAST_HI: AtomicU64 = AtomicU64::new(0);
    // Cache filled by `fill_ipfp` plus (optionally) `fill_registry`.
    // Invalidated by `set_registry_enabled`.
    static FULL_LO: AtomicU64 = AtomicU64::new(0);
    static FULL_HI: AtomicU64 = AtomicU64::new(0);
    // Runtime switch for registry lookups; only consulted with the `registry` feature.
    static REGISTRY_RUNTIME_ENABLED: AtomicBool = AtomicBool::new(true);

    /// Returns whether `feature` is set in the fast cache, populating it on first use.
    #[inline]
    pub(super) fn query_fast(feature: Feature) -> bool {
        let bit = feature as u8;
        let (atomic, pos) = if bit < 64 {
            (&FAST_LO, bit)
        } else {
            (&FAST_HI, bit - 64)
        };
        loop {
            let word = atomic.load(Ordering::Acquire);
            if word & INIT_BIT != 0 {
                // Strip the marker before testing, so a feature that maps onto the
                // reserved bit reads as absent, exactly as `snapshot_fast` reports it.
                return ((word & !INIT_BIT) >> pos) & 1 != 0;
            }
            populate_fast();
        }
    }

    /// Returns the fast cache as a `Features` value, populating it on first use.
    #[inline]
    pub(super) fn snapshot_fast() -> Features {
        loop {
            let lo = FAST_LO.load(Ordering::Acquire);
            if lo & INIT_BIT == 0 {
                populate_fast();
                continue;
            }
            // `populate_fast` publishes HI before LO (Release) and this cache is never
            // invalidated, so an initialized LO (Acquire) implies an initialized HI.
            let hi = FAST_HI.load(Ordering::Relaxed);
            return Features {
                lo: lo & !INIT_BIT,
                hi: hi & !INIT_BIT,
            };
        }
    }

    /// Returns the full cache as a `Features` value, (re)populating it when needed.
    ///
    /// Unlike the fast cache, this one can be invalidated concurrently by
    /// `set_registry_enabled`. Therefore both words must carry `INIT_BIT`, and LO
    /// must be unchanged after HI has been read; otherwise the read is retried.
    /// A HI word cleared by a reset is never returned as "no features".
    #[inline]
    pub(super) fn snapshot_full() -> Features {
        loop {
            let lo = FULL_LO.load(Ordering::Acquire);
            if lo & INIT_BIT == 0 {
                populate_full();
                continue;
            }
            let hi = FULL_HI.load(Ordering::Acquire);
            if hi & INIT_BIT == 0 {
                // A reset cleared HI after we read LO, or interleaved with a
                // population and left LO initialized but HI cleared. Recompute.
                populate_full();
                continue;
            }
            if FULL_LO.load(Ordering::Acquire) != lo {
                // LO changed underneath us; start over to avoid mixing two fills.
                continue;
            }
            return Features {
                lo: lo & !INIT_BIT,
                hi: hi & !INIT_BIT,
            };
        }
    }

    /// Computes the fast cache via `fill_ipfp` and publishes it (HI first, then LO).
    fn populate_fast() {
        let mut f = Features::EMPTY;
        crate::windows::fill_ipfp(&mut f);
        FAST_HI.store(f.hi | INIT_BIT, Ordering::Release);
        FAST_LO.store(f.lo | INIT_BIT, Ordering::Release);
    }

    /// Computes the full cache and publishes it (HI first, then LO).
    ///
    /// If the registry switch changed while the result was being computed, the
    /// stored value may reflect the old setting. In that case LO is cleared again,
    /// so the next reader recomputes instead of caching a stale result forever.
    fn populate_full() {
        let mut f = Features::EMPTY;
        crate::windows::fill_ipfp(&mut f);
        #[cfg(feature = "registry")]
        let registry_used = REGISTRY_RUNTIME_ENABLED.load(Ordering::SeqCst);
        #[cfg(feature = "registry")]
        if registry_used {
            crate::windows::fill_registry(&mut f);
        }
        FULL_HI.store(f.hi | INIT_BIT, Ordering::SeqCst);
        FULL_LO.store(f.lo | INIT_BIT, Ordering::SeqCst);
        // SeqCst (not merely Release/Acquire) is required. Either this re-check sees
        // the new switch value, or `set_registry_enabled`'s clear of FULL_LO is
        // ordered after our store above. Either way, a stale fill cannot survive.
        #[cfg(feature = "registry")]
        if REGISTRY_RUNTIME_ENABLED.load(Ordering::SeqCst) != registry_used {
            FULL_LO.store(0, Ordering::SeqCst);
        }
    }

    /// Updates the registry switch and invalidates the full cache (LO first, then HI).
    #[inline]
    pub(super) fn set_registry_enabled(enabled: bool) {
        REGISTRY_RUNTIME_ENABLED.store(enabled, Ordering::SeqCst);
        FULL_LO.store(0, Ordering::SeqCst);
        FULL_HI.store(0, Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Features::EMPTY` must not report any feature as present.
    #[test]
    fn empty_has_nothing() {
        let empty = Features::EMPTY;
        Feature::all().for_each(|feat| {
            assert!(!empty.has(feat), "{}", feat.name());
        });
    }

    /// Adding one low-word feature must not make an unrelated one appear.
    #[test]
    fn with_sets_only_target_bit() {
        let only_rdm = Features::EMPTY.with(Feature::Rdm);
        assert!(only_rdm.has(Feature::Rdm));
        assert!(!only_rdm.has(Feature::Sve));
    }

    /// A feature stored in the high word round-trips and leaves the low word empty.
    #[test]
    fn high_bit_features_round_trip() {
        let high = Features::EMPTY.with(Feature::SmeF64f64);
        assert!(high.has(Feature::SmeF64f64));
        Feature::all()
            .filter(|&feat| (feat as u8) < 64)
            .for_each(|feat| {
                assert!(!high.has(feat), "{} unexpectedly set", feat.name());
            });
    }

    /// For features answered by IPFP, the fast and full snapshots must agree.
    #[test]
    fn full_snapshot_implies_fast_for_ipfp_features() {
        use crate::features::DetectionMethod;
        let (fast, full) = (Features::current(), Features::current_full());
        let ipfp = Feature::all().filter(|&f| f.detection_method() == DetectionMethod::Ipfp);
        for f in ipfp {
            assert_eq!(
                fast.has(f),
                full.has(f),
                "fast/full disagree for IPFP feature {}",
                f.name()
            );
        }
    }

    /// `Features::current()` must agree with per-feature `is_detected` for every
    /// feature that is not registry-only.
    #[test]
    fn current_snapshot_matches_individual_calls() {
        use crate::features::DetectionMethod;
        let snap = Features::current();
        Feature::all()
            .filter(|&f| f.detection_method() != DetectionMethod::Registry)
            .for_each(|f| {
                assert_eq!(snap.has(f), is_detected(f), "fast {} disagrees", f.name());
            });
    }
}