#[cfg(feature = "std")]
use std::sync::OnceLock;
/// SIMD feature flags of the CPU plus derived vector/cache sizing.
///
/// Obtain an instance through `SimdCapabilities::detect()`. The x86-64 flags
/// (`has_sse42` .. `has_avx512vl`) are only ever set on x86-64 builds, and
/// `has_neon` / `has_sve` only on AArch64 builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdCapabilities {
    /// SSE4.2 is available.
    pub has_sse42: bool,
    /// AVX is available.
    pub has_avx: bool,
    /// AVX2 is available.
    pub has_avx2: bool,
    /// FMA (fused multiply-add) is available.
    pub has_fma: bool,
    /// AVX-512 Foundation is available.
    pub has_avx512f: bool,
    /// AVX-512 Byte/Word instructions are available.
    pub has_avx512bw: bool,
    /// AVX-512 Vector Length extensions are available.
    pub has_avx512vl: bool,
    /// AArch64 Advanced SIMD (NEON) is available.
    pub has_neon: bool,
    /// AArch64 Scalable Vector Extension is available.
    pub has_sve: bool,
    /// Assumed cache-line size in bytes (the detection code currently always reports 64).
    pub cache_line_bytes: usize,
    /// Width in bytes of the vector registers considered usable; the per-type lane counts
    /// (`f64_simd_width` / `f32_simd_width`) are derived from this value.
    pub vector_width_bytes: usize,
}
impl SimdCapabilities {
    /// Returns the capabilities of the running CPU.
    ///
    /// With the `std` feature, detection runs once at runtime and the result is cached for
    /// the rest of the process, so every call returns the same `&'static` value.
    #[cfg(feature = "std")]
    pub fn detect() -> &'static Self {
        simd_caps()
    }

    /// Returns the capabilities of the build target.
    ///
    /// Without `std`, no runtime CPU probing is performed: the flags reflect the
    /// compile-time target configuration, and the value is recomputed on each call.
    #[cfg(not(feature = "std"))]
    pub fn detect() -> Self {
        simd_caps()
    }

    /// Widest floating-point vector width, in bytes, implied by a set of x86-64 flags.
    ///
    /// Plain AVX (without AVX2) already provides 256-bit floating-point registers, so it
    /// maps to 32 bytes. This is consistent with `SimdLevel::Avx.f64_width() == 4`.
    /// Without SSE4.2 (or anything wider) the width is reported as 8 bytes, i.e. one `f64`
    /// lane.
    #[cfg(target_arch = "x86_64")]
    const fn x86_vector_width_bytes(
        has_avx512f: bool,
        has_avx2: bool,
        has_avx: bool,
        has_sse42: bool,
    ) -> usize {
        if has_avx512f {
            64
        } else if has_avx2 || has_avx {
            32
        } else if has_sse42 {
            16
        } else {
            8
        }
    }

    /// x86-64 with `std`: runtime detection via `is_x86_feature_detected!`.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    fn compute() -> Self {
        let has_avx512f = is_x86_feature_detected!("avx512f");
        let has_avx512bw = is_x86_feature_detected!("avx512bw");
        let has_avx512vl = is_x86_feature_detected!("avx512vl");
        let has_avx2 = is_x86_feature_detected!("avx2");
        let has_fma = is_x86_feature_detected!("fma");
        let has_avx = is_x86_feature_detected!("avx");
        let has_sse42 = is_x86_feature_detected!("sse4.2");
        Self {
            has_sse42,
            has_avx,
            has_avx2,
            has_fma,
            has_avx512f,
            has_avx512bw,
            has_avx512vl,
            has_neon: false,
            has_sve: false,
            // Assumed, not queried from the CPU.
            cache_line_bytes: 64,
            vector_width_bytes: Self::x86_vector_width_bytes(
                has_avx512f,
                has_avx2,
                has_avx,
                has_sse42,
            ),
        }
    }

    /// x86-64 without `std`: only features enabled at compile time (`target_feature`)
    /// are reported.
    #[cfg(all(target_arch = "x86_64", not(feature = "std")))]
    fn compute() -> Self {
        let has_avx512f = cfg!(target_feature = "avx512f");
        let has_avx512bw = cfg!(target_feature = "avx512bw");
        let has_avx512vl = cfg!(target_feature = "avx512vl");
        let has_avx2 = cfg!(target_feature = "avx2");
        let has_fma = cfg!(target_feature = "fma");
        let has_avx = cfg!(target_feature = "avx");
        let has_sse42 = cfg!(target_feature = "sse4.2");
        Self {
            has_sse42,
            has_avx,
            has_avx2,
            has_fma,
            has_avx512f,
            has_avx512bw,
            has_avx512vl,
            has_neon: false,
            has_sve: false,
            // Assumed, not queried from the CPU.
            cache_line_bytes: 64,
            vector_width_bytes: Self::x86_vector_width_bytes(
                has_avx512f,
                has_avx2,
                has_avx,
                has_sse42,
            ),
        }
    }

    /// AArch64: NEON is always reported as present.
    ///
    /// SVE is probed at runtime when `std` is available, mirroring the x86-64 path.
    /// Otherwise it is taken from the compile-time `target_feature`.
    #[cfg(target_arch = "aarch64")]
    fn compute() -> Self {
        #[cfg(feature = "std")]
        let has_sve = std::arch::is_aarch64_feature_detected!("sve");
        #[cfg(not(feature = "std"))]
        let has_sve = cfg!(target_feature = "sve");
        Self {
            has_sse42: false,
            has_avx: false,
            has_avx2: false,
            has_fma: false,
            has_avx512f: false,
            has_avx512bw: false,
            has_avx512vl: false,
            has_neon: true,
            has_sve,
            // Assumed, not queried from the CPU.
            cache_line_bytes: 64,
            // 128-bit NEON width; reported even when SVE is present.
            vector_width_bytes: 16,
        }
    }

    /// Any other architecture: no SIMD features, scalar (8-byte) width.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    fn compute() -> Self {
        Self {
            has_sse42: false,
            has_avx: false,
            has_avx2: false,
            has_fma: false,
            has_avx512f: false,
            has_avx512bw: false,
            has_avx512vl: false,
            has_neon: false,
            has_sve: false,
            cache_line_bytes: 64,
            vector_width_bytes: 8,
        }
    }

    /// `true` when AVX-512 F, BW and VL are all available.
    #[inline]
    pub fn has_avx512_full(&self) -> bool {
        self.has_avx512f && self.has_avx512bw && self.has_avx512vl
    }

    /// `true` when both AVX2 and FMA are available.
    #[inline]
    pub fn has_avx2_fma(&self) -> bool {
        self.has_avx2 && self.has_fma
    }

    /// Number of `f64` lanes in one vector of `vector_width_bytes`.
    #[inline]
    pub fn f64_simd_width(&self) -> usize {
        self.vector_width_bytes / core::mem::size_of::<f64>()
    }

    /// Number of `f32` lanes in one vector of `vector_width_bytes`.
    #[inline]
    pub fn f32_simd_width(&self) -> usize {
        self.vector_width_bytes / core::mem::size_of::<f32>()
    }

    /// Highest [`SimdLevel`] supported by these capabilities.
    ///
    /// Priority: full AVX-512 (F+BW+VL), AVX2+FMA, AVX, SSE4.2, NEON, SVE, scalar.
    /// Because `has_neon` is tested before `has_sve`, `SimdLevel::Sve` is returned only
    /// when SVE is flagged without NEON.
    #[inline]
    pub fn optimal_level(&self) -> SimdLevel {
        if self.has_avx512_full() {
            SimdLevel::Avx512
        } else if self.has_avx2_fma() {
            SimdLevel::Avx2
        } else if self.has_avx {
            SimdLevel::Avx
        } else if self.has_sse42 {
            SimdLevel::Sse42
        } else if self.has_neon {
            SimdLevel::Neon
        } else if self.has_sve {
            SimdLevel::Sve
        } else {
            SimdLevel::Scalar
        }
    }
}
/// SIMD instruction-set tier, as chosen by `SimdCapabilities::optimal_level`.
///
/// The derived ordering compares the explicit discriminants. The x86-64 tiers are numbered
/// 0..=4 in increasing capability, so `Avx512 > Avx2 > Avx > Sse42 > Scalar`. The ARM
/// tiers start at 10, so they compare greater than every x86-64 tier. Comparisons across
/// architecture families therefore carry no meaning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No SIMD; one element at a time.
    Scalar = 0,
    /// SSE4.2 (128-bit).
    Sse42 = 1,
    /// AVX (256-bit floating point).
    Avx = 2,
    /// AVX2 together with FMA (256-bit).
    Avx2 = 3,
    /// AVX-512 F+BW+VL (512-bit).
    Avx512 = 4,
    /// AArch64 NEON (128-bit).
    Neon = 10,
    /// AArch64 SVE (vector-length agnostic).
    Sve = 11,
}

impl SimdLevel {
    /// Short display name of the tier, e.g. `"AVX2+FMA"`.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Scalar => "scalar",
            Self::Sse42 => "SSE4.2",
            Self::Avx => "AVX",
            Self::Avx2 => "AVX2+FMA",
            Self::Avx512 => "AVX-512",
            Self::Neon => "NEON",
            Self::Sve => "SVE",
        }
    }

    /// Number of `f64` lanes per vector at this tier.
    ///
    /// SVE reports 2 lanes (128 bits), the smallest vector length the architecture allows.
    /// Real SVE hardware may be wider.
    #[inline]
    pub const fn f64_width(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Sse42 | Self::Neon | Self::Sve => 2,
            Self::Avx | Self::Avx2 => 4,
            Self::Avx512 => 8,
        }
    }

    /// Number of `f32` lanes per vector at this tier (twice `f64_width` except for scalar).
    #[inline]
    pub const fn f32_width(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Sse42 | Self::Neon | Self::Sve => 4,
            Self::Avx | Self::Avx2 => 8,
            Self::Avx512 => 16,
        }
    }
}
/// Process-wide cache for the detected capabilities, filled on first use by `simd_caps()`.
#[cfg(feature = "std")]
static SIMD_CAPS: OnceLock<SimdCapabilities> = OnceLock::new();
/// Returns the SIMD capabilities of the running CPU.
///
/// Detection (`SimdCapabilities::compute`) is run through `OnceLock::get_or_init`, so its
/// result is stored once and the same `&'static` reference is returned on every call.
#[cfg(feature = "std")]
#[inline]
pub fn simd_caps() -> &'static SimdCapabilities {
    SIMD_CAPS.get_or_init(SimdCapabilities::compute)
}
/// Returns the SIMD capabilities of the build target.
///
/// Without `std` there is no global cache, so `SimdCapabilities::compute` runs on every
/// call and the value is returned by copy.
#[cfg(not(feature = "std"))]
#[inline]
pub fn simd_caps() -> SimdCapabilities {
    SimdCapabilities::compute()
}
/// Highest `SimdLevel` supported by the running CPU (see `SimdCapabilities::optimal_level`).
#[inline]
pub fn optimal_simd_level() -> SimdLevel {
    // `simd_caps()` yields `&'static SimdCapabilities` with `std` and an owned value
    // without it; method-call auto-referencing makes one expression valid for both.
    simd_caps().optimal_level()
}
/// `true` when the running CPU supports AVX-512 F, BW and VL together.
#[inline]
pub fn has_avx512() -> bool {
    // Works for both the cached `&'static` (std) and the owned (no_std) return of `simd_caps`.
    let caps = simd_caps();
    caps.has_avx512_full()
}
/// `true` when the running CPU supports both AVX2 and FMA.
#[inline]
pub fn has_avx2_fma() -> bool {
    // Same expression serves the std (`&'static`) and no_std (owned) forms of `simd_caps`.
    let caps = simd_caps();
    caps.has_avx2_fma()
}
/// `true` when the running CPU supports AArch64 NEON.
#[inline]
pub fn has_neon() -> bool {
    // Field access works on both the `&'static` (std) and owned (no_std) `simd_caps` result.
    let caps = simd_caps();
    caps.has_neon
}
/// Evaluates exactly one of several expressions, chosen from SIMD capabilities.
///
/// `$caps` is evaluated once. It must provide `has_avx512_full()`, `has_avx2_fma()`,
/// `has_sse42` and `has_neon` (e.g. a `SimdCapabilities` or a reference to one).
/// Priority: full AVX-512, then AVX2+FMA, SSE4.2, NEON, and finally scalar. Only the
/// selected branch expression is evaluated, and all branches must have the same type.
///
/// ```ignore
/// let lanes = simd_dispatch!(caps, avx512 => 8, avx2 => 4, sse42 => 2, neon => 2, scalar => 1);
/// ```
#[macro_export]
macro_rules! simd_dispatch {
    (
        $caps:expr,
        avx512 => $avx512:expr,
        avx2 => $avx2:expr,
        sse42 => $sse42:expr,
        neon => $neon:expr,
        scalar => $scalar:expr $(,)?
    ) => {{
        // Bind once so `$caps` is not re-evaluated for every test below.
        let caps = $caps;
        if caps.has_avx512_full() {
            $avx512
        } else if caps.has_avx2_fma() {
            $avx2
        } else if caps.has_sse42 {
            $sse42
        } else if caps.has_neon {
            $neon
        } else {
            $scalar
        }
    }};
}
/// Re-export so the macro can also be named by path through this module.
pub use simd_dispatch;
/// An operation with one implementation per SIMD tier.
///
/// Implementors supply the per-tier bodies. `dispatch` picks one from the detected CPU
/// capabilities, in the priority order full AVX-512, AVX2+FMA, SSE4.2, NEON, scalar.
pub trait SimdDispatcher {
    /// Result type shared by every tier.
    type Output;

    /// AVX-512 (F+BW+VL) implementation.
    fn dispatch_avx512(&self) -> Self::Output;

    /// AVX2 + FMA implementation.
    fn dispatch_avx2(&self) -> Self::Output;

    /// NEON implementation.
    fn dispatch_neon(&self) -> Self::Output;

    /// Portable scalar implementation.
    fn dispatch_scalar(&self) -> Self::Output;

    /// SSE4.2 implementation; defaults to the scalar path when not overridden.
    fn dispatch_sse42(&self) -> Self::Output {
        self.dispatch_scalar()
    }

    /// Runs the implementation for the best tier the CPU supports.
    fn dispatch(&self) -> Self::Output {
        // `detect()` returns `&'static Self` with `std` and `Self` without; both support
        // the method and field accesses below.
        let caps = SimdCapabilities::detect();
        if caps.has_avx512_full() {
            return self.dispatch_avx512();
        }
        if caps.has_avx2_fma() {
            return self.dispatch_avx2();
        }
        if caps.has_sse42 {
            return self.dispatch_sse42();
        }
        if caps.has_neon {
            return self.dispatch_neon();
        }
        self.dispatch_scalar()
    }
}
/// Family of GEMM micro-kernel, as picked by `KernelSelector`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmKernelKind {
    /// AVX-512 (F+BW+VL) kernel.
    Avx512,
    /// AVX2 + FMA kernel.
    Avx2,
    /// SSE4.2 kernel.
    Sse42,
    /// NEON kernel.
    Neon,
    /// Portable scalar fallback.
    Scalar,
}

impl GemmKernelKind {
    /// Short display name of the kernel family, e.g. `"AVX2+FMA"`.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Scalar => "scalar",
            Self::Sse42 => "SSE4.2",
            Self::Neon => "NEON",
            Self::Avx2 => "AVX2+FMA",
            Self::Avx512 => "AVX-512",
        }
    }
}
/// GEMM kernel family chosen for each floating-point precision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KernelSelector {
    /// Kernel family used for `f64` GEMM.
    pub gemm_f64_kernel: GemmKernelKind,
    /// Kernel family used for `f32` GEMM.
    pub gemm_f32_kernel: GemmKernelKind,
}

impl KernelSelector {
    /// Chooses the kernel family from a set of capabilities.
    ///
    /// Priority: full AVX-512 (F+BW+VL), AVX2+FMA, SSE4.2, NEON, scalar. Plain AVX and
    /// SVE have no dedicated kernel, so such CPUs get the next matching tier. Both
    /// precisions currently share the same family.
    fn from_caps(caps: &SimdCapabilities) -> Self {
        // Arms are tried top to bottom, so the first `true` flag in priority order wins.
        let kind = match (
            caps.has_avx512_full(),
            caps.has_avx2_fma(),
            caps.has_sse42,
            caps.has_neon,
        ) {
            (true, ..) => GemmKernelKind::Avx512,
            (_, true, ..) => GemmKernelKind::Avx2,
            (_, _, true, _) => GemmKernelKind::Sse42,
            (.., true) => GemmKernelKind::Neon,
            _ => GemmKernelKind::Scalar,
        };
        Self {
            gemm_f64_kernel: kind,
            gemm_f32_kernel: kind,
        }
    }

    /// Returns the process-wide kernel selection, computed once from `simd_caps()` and
    /// cached for every later call.
    #[cfg(feature = "std")]
    pub fn select() -> &'static Self {
        static KERNEL_SEL: OnceLock<KernelSelector> = OnceLock::new();
        KERNEL_SEL.get_or_init(|| Self::from_caps(simd_caps()))
    }

    /// Returns the kernel selection for the build target, recomputed on each call
    /// (there is no global cache without `std`).
    #[cfg(not(feature = "std"))]
    pub fn select() -> Self {
        let caps = simd_caps();
        Self::from_caps(&caps)
    }
}
/// Prints a human-readable report of the detected SIMD capabilities and the selected GEMM
/// kernels to standard output.
#[cfg(feature = "std")]
pub fn print_capabilities() {
    let caps = simd_caps();

    println!("=== OxiBLAS SIMD Capabilities ===");
    println!("Optimal level : {}", caps.optimal_level().name());
    println!("Cache line : {} bytes", caps.cache_line_bytes);
    println!("Vector width : {} bytes", caps.vector_width_bytes);
    println!("f64 SIMD width : {} elements", caps.f64_simd_width());
    println!("f32 SIMD width : {} elements", caps.f32_simd_width());

    // Each table entry is the exact line prefix followed by the flag's `Display` value.
    #[cfg(target_arch = "x86_64")]
    {
        println!("x86-64 Features:");
        let x86_flags = [
            (" SSE4.2 : ", caps.has_sse42),
            (" AVX : ", caps.has_avx),
            (" AVX2 : ", caps.has_avx2),
            (" FMA : ", caps.has_fma),
            (" AVX-512F : ", caps.has_avx512f),
            (" AVX-512BW : ", caps.has_avx512bw),
            (" AVX-512VL : ", caps.has_avx512vl),
        ];
        for (label, enabled) in x86_flags {
            println!("{}{}", label, enabled);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        println!("AArch64 Features:");
        let arm_flags = [(" NEON : ", caps.has_neon), (" SVE : ", caps.has_sve)];
        for (label, enabled) in arm_flags {
            println!("{}{}", label, enabled);
        }
    }

    let sel = KernelSelector::select();
    println!("Kernel Selection:");
    let kernels = [
        (" GEMM f64 : ", sel.gemm_f64_kernel),
        (" GEMM f32 : ", sel.gemm_f32_kernel),
    ];
    for (label, kernel) in kernels {
        println!("{}{}", label, kernel.name());
    }
    println!("==================================");
}
// Unit tests for capability detection, level/kernel selection and dispatch helpers.
// Most tests assert internal consistency on whatever CPU runs them, rather than
// fixed feature sets.
#[cfg(test)]
mod tests {
    use super::*;
    // Detection must succeed and report a sane, power-of-two cache-line size.
    #[test]
    fn test_capabilities_detection_does_not_panic() {
        let caps = SimdCapabilities::detect();
        assert!(caps.cache_line_bytes >= 8);
        assert!(caps.cache_line_bytes.is_power_of_two());
    }
    // AArch64 builds must report NEON and no x86 features.
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_neon_always_present() {
        let caps = SimdCapabilities::detect();
        assert!(caps.has_neon, "NEON is mandatory on AArch64");
        assert!(!caps.has_avx2);
        assert!(!caps.has_avx512f);
    }
    // x86-64 builds must never report NEON; newer features imply older ones.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_x86_64_flag_consistency() {
        let caps = SimdCapabilities::detect();
        assert!(!caps.has_neon);
        if caps.has_avx2 {
            assert!(caps.has_avx, "AVX2 implies AVX");
        }
        if caps.has_avx512f {
            assert!(caps.has_sse42, "AVX-512 implies SSE4.2");
        }
    }
    // Only the AVX-512F and AVX2 widths are pinned here; other tiers are not checked.
    #[test]
    fn test_vector_width_consistent_with_capabilities() {
        let caps = SimdCapabilities::detect();
        if caps.has_avx512f {
            assert_eq!(caps.vector_width_bytes, 64);
        } else if caps.has_avx2 {
            assert_eq!(caps.vector_width_bytes, 32);
        }
    }
    // Lane counts are derived from `vector_width_bytes`, and f32 has twice the f64 lanes.
    #[test]
    fn test_simd_widths_derived_correctly() {
        let caps = SimdCapabilities::detect();
        assert_eq!(
            caps.f64_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f64>()
        );
        assert_eq!(
            caps.f32_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f32>()
        );
        assert_eq!(caps.f32_simd_width(), caps.f64_simd_width() * 2);
    }
    // The x86-64 tiers order by increasing capability (derived `Ord` on discriminants).
    #[test]
    fn test_simd_level_ordering() {
        assert!(SimdLevel::Avx512 > SimdLevel::Avx2);
        assert!(SimdLevel::Avx2 > SimdLevel::Avx);
        assert!(SimdLevel::Avx > SimdLevel::Sse42);
        assert!(SimdLevel::Sse42 > SimdLevel::Scalar);
    }
    // Fixed lane counts per tier.
    #[test]
    fn test_simd_level_widths() {
        assert_eq!(SimdLevel::Scalar.f64_width(), 1);
        assert_eq!(SimdLevel::Scalar.f32_width(), 1);
        assert_eq!(SimdLevel::Sse42.f64_width(), 2);
        assert_eq!(SimdLevel::Sse42.f32_width(), 4);
        assert_eq!(SimdLevel::Avx2.f64_width(), 4);
        assert_eq!(SimdLevel::Avx2.f32_width(), 8);
        assert_eq!(SimdLevel::Avx512.f64_width(), 8);
        assert_eq!(SimdLevel::Avx512.f32_width(), 16);
        assert_eq!(SimdLevel::Neon.f64_width(), 2);
        assert_eq!(SimdLevel::Neon.f32_width(), 4);
    }
    // With `std`, repeated calls must hand out the same cached instance.
    #[cfg(feature = "std")]
    #[test]
    fn test_simd_caps_cached_identity() {
        let a = simd_caps();
        let b = simd_caps();
        assert!(
            core::ptr::eq(a, b),
            "simd_caps() must return a stable &'static"
        );
    }
    // Whatever level is chosen, the flags must justify it (and rule out higher tiers).
    #[test]
    fn test_optimal_level_consistent_with_flags() {
        let caps = SimdCapabilities::detect();
        let level = caps.optimal_level();
        match level {
            SimdLevel::Avx512 => assert!(caps.has_avx512_full()),
            SimdLevel::Avx2 => {
                assert!(!caps.has_avx512_full());
                assert!(caps.has_avx2_fma());
            }
            SimdLevel::Avx => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(caps.has_avx);
            }
            SimdLevel::Sse42 => {
                assert!(!caps.has_avx);
                assert!(caps.has_sse42);
            }
            SimdLevel::Neon => {
                assert!(caps.has_neon);
                assert!(!caps.has_avx);
            }
            // `Sve` is only reachable when NEON is absent, since NEON is checked first.
            SimdLevel::Sve => {
                assert!(caps.has_sve);
                assert!(!caps.has_neon);
            }
            SimdLevel::Scalar => {
                assert!(!caps.has_sse42);
                assert!(!caps.has_avx);
                assert!(!caps.has_neon);
                assert!(!caps.has_sve);
            }
        }
    }
    // `select()` returns `&'static Self` with `std` (deref-copied here) and `Self` without.
    #[test]
    fn test_kernel_selector_valid_kinds() {
        #[cfg(feature = "std")]
        let sel = *KernelSelector::select();
        #[cfg(not(feature = "std"))]
        let sel = KernelSelector::select();
        assert!(matches!(
            sel.gemm_f64_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Sse42
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
        assert!(matches!(
            sel.gemm_f32_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Sse42
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
    }
    // The selected kernel follows the same priority ladder as the capability flags.
    #[test]
    fn test_kernel_selector_matches_optimal_level() {
        let caps = SimdCapabilities::detect();
        #[cfg(feature = "std")]
        let sel = *KernelSelector::select();
        #[cfg(not(feature = "std"))]
        let sel = KernelSelector::select();
        if caps.has_avx512_full() {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx512);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx512);
        } else if caps.has_avx2_fma() {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx2);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx2);
        } else if caps.has_sse42 {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Sse42);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Sse42);
        } else if caps.has_neon {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Neon);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Neon);
        } else {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Scalar);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Scalar);
        }
    }
    // The macro has no AVX or SVE arm: those levels are expected in the SSE4.2 / NEON arms.
    #[test]
    fn test_simd_dispatch_macro_selects_branch() {
        #[cfg(feature = "std")]
        let caps = simd_caps();
        #[cfg(not(feature = "std"))]
        let caps = &simd_caps();
        let result: u32 = simd_dispatch!(
            caps,
            avx512 => 512u32,
            avx2 => 256u32,
            sse42 => 128u32,
            neon => 1000u32,
            scalar => 1u32,
        );
        let expected: u32 = match caps.optimal_level() {
            SimdLevel::Avx512 => 512,
            SimdLevel::Avx2 => 256,
            SimdLevel::Avx | SimdLevel::Sse42 => 128,
            SimdLevel::Neon | SimdLevel::Sve => 1000,
            SimdLevel::Scalar => 1,
        };
        assert_eq!(result, expected);
    }
    // Test dispatcher whose every tier delegates to the scalar dot product,
    // so the result is independent of the CPU.
    struct ScalarDot<'a> {
        x: &'a [f64],
        y: &'a [f64],
    }
    impl SimdDispatcher for ScalarDot<'_> {
        type Output = f64;
        fn dispatch_avx512(&self) -> f64 {
            self.dispatch_scalar()
        }
        fn dispatch_avx2(&self) -> f64 {
            self.dispatch_scalar()
        }
        fn dispatch_neon(&self) -> f64 {
            self.dispatch_scalar()
        }
        fn dispatch_scalar(&self) -> f64 {
            self.x.iter().zip(self.y.iter()).map(|(a, b)| a * b).sum()
        }
    }
    // 1*5 + 2*6 + 3*7 + 4*8 = 70; small integers, so the f64 sum is exact.
    #[test]
    fn test_simd_dispatcher_trait_correctness() {
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let y = [5.0_f64, 6.0, 7.0, 8.0];
        let dot = ScalarDot { x: &x, y: &y };
        let result = dot.dispatch();
        assert!((result - 70.0).abs() < f64::EPSILON);
    }
    // Smoke test: the report is printed without panicking.
    #[cfg(feature = "std")]
    #[test]
    fn test_print_capabilities_does_not_panic() {
        print_capabilities();
    }
    #[test]
    fn test_gemm_kernel_kind_names_non_empty() {
        for kind in [
            GemmKernelKind::Avx512,
            GemmKernelKind::Avx2,
            GemmKernelKind::Sse42,
            GemmKernelKind::Neon,
            GemmKernelKind::Scalar,
        ] {
            assert!(!kind.name().is_empty());
        }
    }
    // The free-function helpers must agree with the capability struct.
    #[test]
    fn test_helper_functions_agree_with_caps() {
        let caps = SimdCapabilities::detect();
        assert_eq!(has_avx512(), caps.has_avx512_full());
        assert_eq!(has_avx2_fma(), caps.has_avx2_fma());
        assert_eq!(has_neon(), caps.has_neon);
    }
}