use core::sync::atomic::{AtomicU8, Ordering};
/// SIMD instruction-set level reported by [`detect_capability`].
///
/// The enum is `#[repr(u8)]` with explicit discriminants because the detected
/// value is cached as a raw byte and decoded again from that byte.
///
/// The derived ordering follows the discriminants, so `Neon` compares greater
/// than every x86 level even though its vector width is smaller than AVX2's;
/// comparisons across architectures say nothing about vector width.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SimdCapability {
    /// No vector extension available; plain scalar code.
    Scalar = 0,
    /// x86 SSE4.2.
    Sse42 = 1,
    /// x86 AVX2.
    Avx2 = 2,
    /// x86 AVX-512.
    Avx512 = 3,
    /// ARM NEON.
    Neon = 4,
}
impl SimdCapability {
    /// Width of one vector register in bytes.
    ///
    /// `Scalar` reports `1`; it has no vector register, so the value only
    /// serves as a neutral width/alignment.
    #[inline]
    pub const fn vector_width(self) -> usize {
        match self {
            SimdCapability::Scalar => 1,
            SimdCapability::Sse42 => 16,
            SimdCapability::Avx2 => 32,
            SimdCapability::Avx512 => 64,
            SimdCapability::Neon => 16,
        }
    }

    /// Returns `true` for every variant except `Scalar`.
    #[inline]
    pub const fn is_simd(self) -> bool {
        !matches!(self, SimdCapability::Scalar)
    }

    /// Human-readable name of the instruction set.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            SimdCapability::Scalar => "Scalar",
            SimdCapability::Sse42 => "SSE4.2",
            SimdCapability::Avx2 => "AVX2",
            SimdCapability::Avx512 => "AVX-512",
            SimdCapability::Neon => "NEON",
        }
    }

    /// Number of `f32` elements handled per operation.
    ///
    /// Always at least `1`: `Scalar` processes one element at a time, so it
    /// reports a single lane rather than `0`.
    #[inline]
    pub const fn f32_lanes(self) -> usize {
        // f32 is 4 bytes wide.
        self.lanes_of(4)
    }

    /// Number of `f64` elements handled per operation.
    ///
    /// Always at least `1`; see [`SimdCapability::f32_lanes`].
    #[inline]
    pub const fn f64_lanes(self) -> usize {
        // f64 is 8 bytes wide.
        self.lanes_of(8)
    }

    /// Number of `i32` elements handled per operation.
    ///
    /// Always at least `1`; see [`SimdCapability::f32_lanes`].
    #[inline]
    pub const fn i32_lanes(self) -> usize {
        // i32 is 4 bytes wide.
        self.lanes_of(4)
    }

    /// Lanes of `elem_bytes`-sized elements that fit in one vector, clamped to 1.
    ///
    /// Without the clamp `Scalar` (width 1) would yield `0`, which makes any
    /// caller that strides, chunks or divides by the lane count panic or spin.
    #[inline]
    const fn lanes_of(self, elem_bytes: usize) -> usize {
        let lanes = self.vector_width() / elem_bytes;
        if lanes == 0 {
            1
        } else {
            lanes
        }
    }
}
/// The default capability is whatever [`detect_capability`] reports, not a
/// fixed variant.
impl Default for SimdCapability {
    /// Delegates to [`detect_capability`].
    fn default() -> SimdCapability {
        detect_capability()
    }
}
/// Cache for the detected capability, stored as its `u8` discriminant.
///
/// `0xFF` is the "nothing stored yet" marker.
static CACHED_CAPABILITY: AtomicU8 = AtomicU8::new(0xFF);

/// Returns the SIMD capability, running the detection only while the cache is empty.
///
/// The first call (or several racing first calls) runs
/// `detect_capability_impl` and stores the result; later calls decode the
/// cached byte. `Relaxed` ordering is enough because the byte is the entire
/// payload and no other memory is published through it.
#[inline]
pub fn detect_capability() -> SimdCapability {
    match CACHED_CAPABILITY.load(Ordering::Relaxed) {
        // Cache still empty: detect now and remember the answer.
        0xFF => {
            let fresh = detect_capability_impl();
            CACHED_CAPABILITY.store(fresh as u8, Ordering::Relaxed);
            fresh
        }
        1 => SimdCapability::Sse42,
        2 => SimdCapability::Avx2,
        3 => SimdCapability::Avx512,
        4 => SimdCapability::Neon,
        // `0` and any unrecognised byte both decode to the scalar fallback.
        _ => SimdCapability::Scalar,
    }
}
/// Returns `true` when the detected capability is anything other than scalar.
#[inline]
pub fn is_simd_available() -> bool {
    let active = detect_capability();
    active.is_simd()
}
/// Returns the vector width of the detected capability, for use as an alignment.
///
/// The value is exactly `vector_width()` of whatever `detect_capability`
/// reports.
#[inline]
pub fn optimal_alignment() -> usize {
    let active = detect_capability();
    active.vector_width()
}
/// x86_64 detection: picks the widest of AVX-512F, AVX2 and SSE4.2 that is present.
///
/// When the crate is compiled with `avx512f` enabled the answer is fixed at
/// compile time and no runtime probe is executed; otherwise the features are
/// probed at runtime from widest to narrowest, falling back to `Scalar`.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
fn detect_capability_impl() -> SimdCapability {
    // `cfg!` is a compile-time constant, so the `||` short-circuits the
    // runtime probe away on AVX-512 builds.
    if cfg!(target_feature = "avx512f") || is_x86_feature_detected!("avx512f") {
        SimdCapability::Avx512
    } else if is_x86_feature_detected!("avx2") {
        SimdCapability::Avx2
    } else if is_x86_feature_detected!("sse4.2") {
        SimdCapability::Sse42
    } else {
        SimdCapability::Scalar
    }
}
/// 32-bit x86 detection: probes AVX2, then SSE4.2, at runtime.
///
/// Returns the first feature found, or `Scalar` when neither is present.
/// AVX-512 is not probed on this target.
#[cfg(all(target_arch = "x86", target_feature = "sse2"))]
fn detect_capability_impl() -> SimdCapability {
    if is_x86_feature_detected!("avx2") {
        SimdCapability::Avx2
    } else if is_x86_feature_detected!("sse4.2") {
        SimdCapability::Sse42
    } else {
        SimdCapability::Scalar
    }
}
/// AArch64 detection.
///
/// This variant is only compiled when `neon` is enabled at build time, so the
/// answer is unconditionally `Neon` and nothing is probed at runtime.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
fn detect_capability_impl() -> SimdCapability {
    SimdCapability::Neon
}
/// 32-bit ARM detection.
///
/// This variant is only compiled when `neon` is enabled at build time, so the
/// answer is `Neon` on every operating system; Linux and non-Linux targets
/// are not treated differently.
#[cfg(all(target_arch = "arm", target_feature = "neon"))]
fn detect_capability_impl() -> SimdCapability {
    SimdCapability::Neon
}
/// Fallback detection for every target not covered by the x86, x86_64,
/// AArch64 or ARM variants (including those targets built without their
/// baseline `sse2` / `neon` feature): always reports `Scalar`.
#[cfg(not(any(
    all(target_arch = "x86_64", target_feature = "sse2"),
    all(target_arch = "x86", target_feature = "sse2"),
    all(target_arch = "aarch64", target_feature = "neon"),
    all(target_arch = "arm", target_feature = "neon"),
)))]
fn detect_capability_impl() -> SimdCapability {
    SimdCapability::Scalar
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection must run, and the convenience wrappers must agree with it.
    #[test]
    fn test_detect_capability() {
        let cap = detect_capability();
        // Visible with `cargo test -- --nocapture`; handy when checking a new host.
        println!("Detected SIMD capability: {:?}", cap);
        println!("Vector width: {} bytes", cap.vector_width());
        println!("f32 lanes: {}", cap.f32_lanes());
        println!("f64 lanes: {}", cap.f64_lanes());

        assert_eq!(is_simd_available(), cap.is_simd());
        assert_eq!(optimal_alignment(), cap.vector_width());
        assert!(optimal_alignment().is_power_of_two());
        assert_eq!(SimdCapability::default(), cap);
        assert!(!cap.name().is_empty());
    }

    /// Repeated calls return the same (cached) value.
    #[test]
    fn test_cached_detection() {
        let cap1 = detect_capability();
        let cap2 = detect_capability();
        assert_eq!(cap1, cap2);
    }

    /// The derived ordering ranks the x86 levels from weakest to strongest.
    #[test]
    fn test_simd_capability_ordering() {
        assert!(SimdCapability::Scalar < SimdCapability::Sse42);
        assert!(SimdCapability::Sse42 < SimdCapability::Avx2);
        assert!(SimdCapability::Avx2 < SimdCapability::Avx512);
    }

    /// Vector widths are reported in bytes.
    #[test]
    fn test_vector_widths() {
        assert_eq!(SimdCapability::Scalar.vector_width(), 1);
        assert_eq!(SimdCapability::Sse42.vector_width(), 16);
        assert_eq!(SimdCapability::Avx2.vector_width(), 32);
        assert_eq!(SimdCapability::Avx512.vector_width(), 64);
        assert_eq!(SimdCapability::Neon.vector_width(), 16);
    }

    /// Only `Scalar` is not a SIMD capability.
    #[test]
    fn test_is_simd() {
        assert!(!SimdCapability::Scalar.is_simd());
        assert!(SimdCapability::Sse42.is_simd());
        assert!(SimdCapability::Avx2.is_simd());
        assert!(SimdCapability::Avx512.is_simd());
        assert!(SimdCapability::Neon.is_simd());
    }
}