#[cfg(feature = "std")]
use std::sync::OnceLock;
use crate::simd::dispatch::{SimdCapabilities as LegacyCaps, SimdLevel as LegacyLevel};
/// Snapshot of CPU SIMD feature flags plus two derived sizing hints.
///
/// Instances are produced by `SimdCapabilityInfo::detect`, which converts the
/// legacy capability record returned by `simd_caps()`. Every flag except
/// `has_sse42` is copied verbatim from that record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdCapabilityInfo {
/// SSE4.2 support. With `std` on x86_64 this comes from
/// `is_x86_feature_detected!("sse4.2")`; with `std` on other architectures it
/// is `false`; without `std` it is `cfg!(target_feature = "sse4.2")`.
pub has_sse42: bool,
/// AVX flag as reported by the legacy capability record.
pub has_avx: bool,
/// AVX2 flag as reported by the legacy capability record.
pub has_avx2: bool,
/// FMA flag as reported by the legacy capability record.
pub has_fma: bool,
/// AVX-512F flag as reported by the legacy capability record.
pub has_avx512f: bool,
/// AVX-512BW flag as reported by the legacy capability record.
pub has_avx512bw: bool,
/// AVX-512VL flag as reported by the legacy capability record.
pub has_avx512vl: bool,
/// NEON flag as reported by the legacy capability record.
pub has_neon: bool,
/// SVE flag as reported by the legacy capability record.
pub has_sve: bool,
/// Cache-line size hint in bytes. The constructors in this module always
/// store the constant 64; it is not probed from the hardware.
pub cache_line_bytes: usize,
/// Vector register width hint in bytes: 64 when `has_avx512f`, else 32 when
/// `has_avx2`, else 16 when `has_sse42` or `has_neon`, else 8.
pub vector_width_bytes: usize,
}
impl SimdCapabilityInfo {
    /// Returns the shared capability snapshot, computing it on first use.
    ///
    /// The snapshot lives in a function-local `OnceLock`, so every call hands
    /// back the same `&'static` reference.
    #[cfg(feature = "std")]
    #[inline]
    pub fn detect() -> &'static Self {
        static CACHED: OnceLock<SimdCapabilityInfo> = OnceLock::new();
        CACHED.get_or_init(Self::compute)
    }

    /// Builds a capability snapshot by value; nothing is cached without `std`.
    #[cfg(not(feature = "std"))]
    #[inline]
    pub fn detect() -> Self {
        Self::compute()
    }

    /// Queries the legacy detector (`simd_caps`) and converts its record.
    fn compute() -> Self {
        Self::from_legacy(simd_caps())
    }

    /// Converts a legacy record, probing SSE4.2 at run time on x86_64.
    ///
    /// On every other architecture SSE4.2 is reported as absent.
    #[cfg(feature = "std")]
    fn from_legacy(legacy: &LegacyCaps) -> Self {
        // The whole function is already gated on `std`, so only the
        // architecture needs to be checked here.
        #[cfg(target_arch = "x86_64")]
        let sse42 = is_x86_feature_detected!("sse4.2");
        #[cfg(not(target_arch = "x86_64"))]
        let sse42 = false;
        Self::assemble(legacy, sse42)
    }

    /// Converts a legacy record, deciding SSE4.2 at compile time from the
    /// enabled target features.
    #[cfg(not(feature = "std"))]
    fn from_legacy(legacy: LegacyCaps) -> Self {
        let sse42 = cfg!(target_feature = "sse4.2");
        Self::assemble(&legacy, sse42)
    }

    /// Shared tail of both `from_legacy` variants: copies the legacy flags and
    /// derives the vector-width hint from the widest extension present.
    fn assemble(legacy: &LegacyCaps, has_sse42: bool) -> Self {
        let vector_width_bytes: usize = match (legacy.has_avx512f, legacy.has_avx2) {
            (true, _) => 64,
            (false, true) => 32,
            (false, false) if has_sse42 || legacy.has_neon => 16,
            (false, false) => 8,
        };
        Self {
            has_sse42,
            has_avx: legacy.has_avx,
            has_avx2: legacy.has_avx2,
            has_fma: legacy.has_fma,
            has_avx512f: legacy.has_avx512f,
            has_avx512bw: legacy.has_avx512bw,
            has_avx512vl: legacy.has_avx512vl,
            has_neon: legacy.has_neon,
            has_sve: legacy.has_sve,
            // Fixed hint; this module does not probe the real line size.
            cache_line_bytes: 64,
            vector_width_bytes,
        }
    }

    /// `true` only when AVX-512F, AVX-512BW and AVX-512VL are all present.
    #[inline]
    pub fn has_avx512_full(&self) -> bool {
        let Self {
            has_avx512f,
            has_avx512bw,
            has_avx512vl,
            ..
        } = *self;
        has_avx512f && has_avx512bw && has_avx512vl
    }

    /// `true` when both AVX2 and FMA are present.
    #[inline]
    pub fn has_avx2_fma(&self) -> bool {
        let Self {
            has_avx2, has_fma, ..
        } = *self;
        has_avx2 && has_fma
    }

    /// Number of `f64` values that fit in `vector_width_bytes`.
    #[inline]
    pub fn f64_simd_width(&self) -> usize {
        const LANE_BYTES: usize = core::mem::size_of::<f64>();
        self.vector_width_bytes / LANE_BYTES
    }

    /// Number of `f32` values that fit in `vector_width_bytes`.
    #[inline]
    pub fn f32_simd_width(&self) -> usize {
        const LANE_BYTES: usize = core::mem::size_of::<f32>();
        self.vector_width_bytes / LANE_BYTES
    }

    /// Maps the flags onto a legacy SIMD level.
    ///
    /// Priority, first match wins: full AVX-512, AVX2+FMA, AVX, SSE4.2, NEON,
    /// SVE, and finally scalar.
    #[inline]
    pub fn optimal_level(&self) -> LegacyLevel {
        if self.has_avx512_full() {
            return LegacyLevel::Avx512;
        }
        if self.has_avx2_fma() {
            return LegacyLevel::Avx2;
        }
        if self.has_avx {
            return LegacyLevel::Avx;
        }
        if self.has_sse42 {
            return LegacyLevel::Sse42;
        }
        if self.has_neon {
            return LegacyLevel::Neon;
        }
        if self.has_sve {
            return LegacyLevel::Sve;
        }
        LegacyLevel::Scalar
    }
}
/// Evaluates exactly one of five expressions depending on a capability value.
///
/// `$caps` is evaluated once. It must expose `has_avx512_full()` and
/// `has_avx2_fma()` methods plus `has_sse42` and `has_neon` fields. Arms are
/// tried in the order avx512, avx2, sse42, neon; `scalar` is the fallback.
/// All five arms are required, and a trailing comma is accepted.
#[macro_export]
macro_rules! simd_dispatch_caps {
    (
        $caps:expr,
        avx512 => $avx512:expr,
        avx2 => $avx2:expr,
        sse42 => $sse42:expr,
        neon => $neon:expr,
        scalar => $scalar:expr $(,)?
    ) => {{
        // Bind once so a side-effecting `$caps` expression runs a single time.
        let detected = $caps;
        if detected.has_avx512_full() {
            $avx512
        } else if detected.has_avx2_fma() {
            $avx2
        } else if detected.has_sse42 {
            $sse42
        } else if detected.has_neon {
            $neon
        } else {
            $scalar
        }
    }};
}
pub use simd_dispatch_caps;
/// An operation that has one implementation per SIMD tier.
///
/// Implementors supply the four tier-specific entry points; the provided
/// `dispatch` method chooses among them using `SimdCapabilityInfo::detect`.
pub trait SimdDispatcher {
    /// Result type shared by every tier-specific implementation.
    type Output;

    /// Implementation used when AVX-512F, BW and VL are all available.
    fn dispatch_avx512(&self) -> Self::Output;

    /// Implementation used when AVX2 and FMA are both available.
    fn dispatch_avx2(&self) -> Self::Output;

    /// Implementation used when NEON is available.
    fn dispatch_neon(&self) -> Self::Output;

    /// Fallback implementation used when no tier above matches.
    fn dispatch_scalar(&self) -> Self::Output;

    /// Runs the first matching tier: AVX-512, then AVX2+FMA, then NEON,
    /// otherwise scalar.
    fn dispatch(&self) -> Self::Output {
        // `detect()` yields a reference with `std` and a value without it;
        // method and field access work on both, so one statement suffices.
        let caps = SimdCapabilityInfo::detect();
        if caps.has_avx512_full() {
            return self.dispatch_avx512();
        }
        if caps.has_avx2_fma() {
            return self.dispatch_avx2();
        }
        if caps.has_neon {
            return self.dispatch_neon();
        }
        self.dispatch_scalar()
    }
}
/// Identifies which GEMM kernel family has been selected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmKernelKind {
    /// Kernel built on AVX-512.
    Avx512,
    /// Kernel built on AVX2 with FMA.
    Avx2,
    /// Kernel built on NEON.
    Neon,
    /// Plain scalar kernel with no SIMD.
    Scalar,
}

impl GemmKernelKind {
    /// Human-readable label for this kernel kind.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Avx512 => "AVX-512",
            Self::Avx2 => "AVX2+FMA",
            Self::Neon => "NEON",
            Self::Scalar => "scalar",
        }
    }

    /// `true` for every kind except [`GemmKernelKind::Scalar`].
    #[inline]
    pub const fn is_simd(self) -> bool {
        // Listed exhaustively so adding a variant forces a decision here.
        match self {
            Self::Scalar => false,
            Self::Avx512 | Self::Avx2 | Self::Neon => true,
        }
    }
}
/// Records which GEMM kernel kind to use for each floating-point precision.
///
/// `KernelSelector::from_caps` currently assigns the same kind to both fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KernelSelector {
/// Kernel kind chosen for `f64` GEMM.
pub gemm_f64_kernel: GemmKernelKind,
/// Kernel kind chosen for `f32` GEMM.
pub gemm_f32_kernel: GemmKernelKind,
}
impl KernelSelector {
    /// Picks one kernel kind from `caps` and applies it to both precisions.
    ///
    /// Priority, first match wins: full AVX-512, AVX2+FMA, NEON, then scalar.
    fn from_caps(caps: &SimdCapabilityInfo) -> Self {
        let chosen = match caps {
            c if c.has_avx512_full() => GemmKernelKind::Avx512,
            c if c.has_avx2_fma() => GemmKernelKind::Avx2,
            c if c.has_neon => GemmKernelKind::Neon,
            _ => GemmKernelKind::Scalar,
        };
        Self {
            gemm_f64_kernel: chosen,
            gemm_f32_kernel: chosen,
        }
    }

    /// Returns the shared selection, computing it on first use.
    ///
    /// The value lives in a function-local `OnceLock`, so every call hands
    /// back the same `&'static` reference.
    #[cfg(feature = "std")]
    pub fn select() -> &'static Self {
        static SELECTED: OnceLock<KernelSelector> = OnceLock::new();
        SELECTED.get_or_init(|| {
            let caps = SimdCapabilityInfo::detect();
            Self::from_caps(caps)
        })
    }

    /// Computes a selection by value; nothing is cached without `std`.
    #[cfg(not(feature = "std"))]
    pub fn select() -> Self {
        let caps = SimdCapabilityInfo::detect();
        Self::from_caps(&caps)
    }
}
pub use crate::simd::dispatch::{
SimdCapabilities, SimdLevel, has_avx2_fma, has_avx512, has_neon, optimal_simd_level, simd_caps,
};
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection must succeed and report power-of-two sizing hints.
    #[test]
    fn test_detect_does_not_panic_and_cache_line_sane() {
        let caps = SimdCapabilityInfo::detect();
        assert!(caps.cache_line_bytes >= 8);
        assert!(caps.cache_line_bytes.is_power_of_two());
        assert!(caps.vector_width_bytes.is_power_of_two());
    }

    /// On AArch64 NEON must be reported and no x86 AVX flags may appear.
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_neon_always_true() {
        let caps = SimdCapabilityInfo::detect();
        assert!(caps.has_neon, "NEON is mandatory on AArch64");
        assert!(!caps.has_avx2, "AVX2 must not appear on AArch64");
        assert!(!caps.has_avx512f, "AVX-512 must not appear on AArch64");
    }

    /// On x86-64 the reported flags must respect the ISA extension hierarchy.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_x86_64_flag_hierarchy() {
        let caps = SimdCapabilityInfo::detect();
        assert!(!caps.has_neon, "NEON must not appear on x86-64");
        if caps.has_avx2 {
            assert!(caps.has_avx, "AVX2 requires AVX");
        }
        if caps.has_avx512f {
            assert!(caps.has_sse42, "AVX-512 implies SSE4.2");
        }
    }

    /// Lane counts are the vector width divided by the element size.
    #[test]
    fn test_simd_width_derivation() {
        let caps = SimdCapabilityInfo::detect();
        assert_eq!(
            caps.f64_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f64>()
        );
        assert_eq!(
            caps.f32_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f32>()
        );
        assert_eq!(caps.f32_simd_width(), caps.f64_simd_width() * 2);
    }

    /// The width hint must match the widest reported tier, including the
    /// 16-byte (SSE4.2 / NEON) and 8-byte fallback tiers.
    #[test]
    fn test_vector_width_matches_capability_tier() {
        let caps = SimdCapabilityInfo::detect();
        let expected: usize = if caps.has_avx512f {
            64
        } else if caps.has_avx2 {
            32
        } else if caps.has_sse42 || caps.has_neon {
            16
        } else {
            8
        };
        assert_eq!(caps.vector_width_bytes, expected);
    }

    /// With `std`, the legacy accessor returns one shared static record.
    #[cfg(feature = "std")]
    #[test]
    fn test_simd_caps_stable_pointer() {
        let a = simd_caps();
        let b = simd_caps();
        assert!(
            core::ptr::eq(a, b),
            "simd_caps() must return a stable &'static"
        );
    }

    /// With `std`, `detect()` returns one shared static snapshot.
    #[cfg(feature = "std")]
    #[test]
    fn test_capability_info_stable_pointer() {
        let a = SimdCapabilityInfo::detect();
        let b = SimdCapabilityInfo::detect();
        assert!(
            core::ptr::eq(a, b),
            "detect() must return a stable &'static"
        );
    }

    /// Each level returned by `optimal_level` requires its own flag and
    /// excludes every higher-priority tier.
    #[test]
    fn test_optimal_level_consistent_with_flags() {
        let caps = SimdCapabilityInfo::detect();
        let level = caps.optimal_level();
        match level {
            LegacyLevel::Avx512 => assert!(caps.has_avx512_full()),
            LegacyLevel::Avx2 => {
                assert!(!caps.has_avx512_full());
                assert!(caps.has_avx2_fma());
            }
            LegacyLevel::Avx => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(caps.has_avx);
            }
            LegacyLevel::Sse42 => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(!caps.has_avx);
                assert!(caps.has_sse42);
            }
            LegacyLevel::Neon => {
                assert!(caps.has_neon);
                assert!(!caps.has_avx);
                assert!(!caps.has_sse42);
            }
            LegacyLevel::Sve => {
                assert!(caps.has_sve);
                assert!(!caps.has_neon);
                assert!(!caps.has_avx);
                assert!(!caps.has_sse42);
            }
            LegacyLevel::Scalar => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(!caps.has_avx);
                assert!(!caps.has_sse42);
                assert!(!caps.has_neon);
                assert!(!caps.has_sve);
            }
        }
    }

    /// Smoke test: selection runs and yields a known kernel kind for both
    /// precisions.
    #[test]
    fn test_kernel_selector_valid_kinds() {
        // `select()` is a reference with `std` and a value without it; field
        // access auto-derefs, so no cfg split is needed.
        let sel = KernelSelector::select();
        assert!(matches!(
            sel.gemm_f64_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
        assert!(matches!(
            sel.gemm_f32_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
    }

    /// The selected kernel must follow the same priority as the capability
    /// flags: AVX-512, AVX2+FMA, NEON, scalar.
    #[test]
    fn test_kernel_selector_agrees_with_capability_info() {
        let caps = SimdCapabilityInfo::detect();
        let sel = KernelSelector::select();
        let expected = if caps.has_avx512_full() {
            GemmKernelKind::Avx512
        } else if caps.has_avx2_fma() {
            GemmKernelKind::Avx2
        } else if caps.has_neon {
            GemmKernelKind::Neon
        } else {
            GemmKernelKind::Scalar
        };
        assert_eq!(sel.gemm_f64_kernel, expected);
        assert_eq!(sel.gemm_f32_kernel, expected);
    }

    /// The dispatch macro must pick the arm matching the detected flags.
    #[test]
    fn test_simd_dispatch_caps_macro_branch_selection() {
        let caps = SimdCapabilityInfo::detect();
        let chosen: u32 = simd_dispatch_caps!(
            &caps,
            avx512 => 512u32,
            avx2 => 256u32,
            sse42 => 128u32,
            neon => 1000u32,
            scalar => 1u32,
        );
        let expected: u32 = if caps.has_avx512_full() {
            512
        } else if caps.has_avx2_fma() {
            256
        } else if caps.has_sse42 {
            128
        } else if caps.has_neon {
            1000
        } else {
            1
        };
        assert_eq!(chosen, expected);
    }

    /// Minimal dispatcher whose SIMD tiers all defer to the scalar path, so
    /// the result is the same whichever tier `dispatch` selects.
    struct DotProduct<'a> {
        x: &'a [f64],
        y: &'a [f64],
    }

    impl SimdDispatcher for DotProduct<'_> {
        type Output = f64;

        fn dispatch_avx512(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_avx2(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_neon(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_scalar(&self) -> f64 {
            self.x.iter().zip(self.y.iter()).map(|(a, b)| a * b).sum()
        }
    }

    /// `dispatch` must reach an implementation and return its result.
    #[test]
    fn test_simd_dispatcher_trait_correctness() {
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let y = [5.0_f64, 6.0, 7.0, 8.0];
        let result = DotProduct { x: &x, y: &y }.dispatch();
        // 1*5 + 2*6 + 3*7 + 4*8 = 70
        assert!((result - 70.0).abs() < f64::EPSILON);
    }

    /// Every kernel kind has a non-empty display name.
    #[test]
    fn test_gemm_kernel_kind_names_non_empty() {
        for kind in [
            GemmKernelKind::Avx512,
            GemmKernelKind::Avx2,
            GemmKernelKind::Neon,
            GemmKernelKind::Scalar,
        ] {
            assert!(!kind.name().is_empty());
        }
    }

    /// Only the scalar kind reports itself as non-SIMD.
    #[test]
    fn test_gemm_kernel_kind_is_simd() {
        assert!(GemmKernelKind::Avx512.is_simd());
        assert!(GemmKernelKind::Avx2.is_simd());
        assert!(GemmKernelKind::Neon.is_simd());
        assert!(!GemmKernelKind::Scalar.is_simd());
    }

    /// The re-exported free helpers must not contradict `detect()`.
    #[test]
    fn test_free_helper_fns_agree_with_detect() {
        let caps = SimdCapabilityInfo::detect();
        if caps.has_avx512_full() {
            assert!(has_avx512());
        }
        if caps.has_avx2_fma() {
            assert!(has_avx2_fma());
        }
        if caps.has_neon {
            assert!(has_neon());
        }
    }

    /// Dispatching over empty inputs yields the additive identity.
    #[test]
    fn test_dispatcher_scalar_ground_truth() {
        let x = [0.0_f64; 0];
        let y = [0.0_f64; 0];
        let result = DotProduct { x: &x, y: &y }.dispatch();
        assert_eq!(result, 0.0, "empty dot product must be zero");
    }
}