#[cfg(feature = "f16")]
pub mod half_convert_utils;
#[macro_use]
mod half_macros;
pub mod activations;
pub mod binary;
pub mod clamp;
pub mod compare;
pub mod conv;
pub mod cumulative;
pub mod dot;
pub mod fused_activation_mul;
pub mod fused_elementwise;
pub mod index;
pub mod logsumexp;
pub mod math;
pub mod matmul;
pub mod norm;
pub mod reduce;
pub mod scalar;
pub mod softmax;
pub mod softmax_bwd;
pub mod special;
pub mod unary;
pub mod where_select;
#[cfg(target_arch = "x86_64")]
pub mod streaming;
use std::sync::OnceLock;
/// SIMD capability tier used to pick a kernel implementation.
///
/// Every variant has an explicit discriminant and the derived
/// `PartialOrd`/`Ord` compare by that discriminant, so the tiers rank as
/// `Avx512 > Avx2Fma > NeonFp16 > Neon > Scalar`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[allow(dead_code)] // Presumably because not every tier is constructed on every target.
pub enum SimdLevel {
    /// No SIMD tier; one element at a time.
    Scalar = 0,
    /// ARM NEON.
    Neon = 1,
    /// ARM NEON with the FP16 extension.
    NeonFp16 = 2,
    /// x86 AVX2 together with FMA.
    Avx2Fma = 3,
    /// x86 AVX-512.
    Avx512 = 4,
}

#[allow(dead_code)]
impl SimdLevel {
    /// Returns `true` for the x86 tiers (`Avx2Fma` and `Avx512`).
    #[inline]
    pub const fn is_x86(self) -> bool {
        match self {
            Self::Avx2Fma | Self::Avx512 => true,
            Self::Scalar | Self::Neon | Self::NeonFp16 => false,
        }
    }

    /// Returns `true` for the ARM tiers (`Neon` and `NeonFp16`).
    #[inline]
    pub const fn is_arm64(self) -> bool {
        match self {
            Self::Neon | Self::NeonFp16 => true,
            Self::Scalar | Self::Avx2Fma | Self::Avx512 => false,
        }
    }

    /// Returns `true` only for `Avx512`.
    #[inline]
    pub const fn has_avx512(self) -> bool {
        match self {
            Self::Avx512 => true,
            Self::Scalar | Self::Neon | Self::NeonFp16 | Self::Avx2Fma => false,
        }
    }

    /// Returns `true` for `Avx2Fma` and for `Avx512`.
    #[inline]
    pub const fn has_avx2(self) -> bool {
        match self {
            Self::Avx2Fma | Self::Avx512 => true,
            Self::Scalar | Self::Neon | Self::NeonFp16 => false,
        }
    }

    /// Returns `true` for `Neon` and `NeonFp16`.
    #[inline]
    pub const fn has_neon(self) -> bool {
        match self {
            Self::Neon | Self::NeonFp16 => true,
            Self::Scalar | Self::Avx2Fma | Self::Avx512 => false,
        }
    }

    /// Number of `f32` lanes associated with this tier (`1` for `Scalar`).
    #[inline]
    pub const fn f32_lanes(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Neon | Self::NeonFp16 => 4,
            Self::Avx2Fma => 8,
            Self::Avx512 => 16,
        }
    }

    /// Number of `f64` lanes associated with this tier (`1` for `Scalar`).
    #[inline]
    pub const fn f64_lanes(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Neon | Self::NeonFp16 => 2,
            Self::Avx2Fma => 4,
            Self::Avx512 => 8,
        }
    }

    /// Human-readable name of the tier.
    #[inline]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Scalar => "Scalar",
            Self::Neon => "NEON",
            Self::NeonFp16 => "NEON+FP16",
            Self::Avx2Fma => "AVX2+FMA",
            Self::Avx512 => "AVX-512",
        }
    }
}
/// Formats the level using the same text as [`SimdLevel::as_str`].
impl std::fmt::Display for SimdLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width, fill, alignment and precision flags (e.g. `{:<10}`),
        // which a bare `write_str` would silently ignore. With a plain `{}` the
        // output is unchanged.
        f.pad(self.as_str())
    }
}
/// Cache for the detected SIMD level; filled by the first call to [`detect_simd`].
static SIMD_LEVEL: OnceLock<SimdLevel> = OnceLock::new();

/// Returns the SIMD level of the running CPU.
///
/// The first call runs `detect_simd_uncached` and stores its result in
/// `SIMD_LEVEL`; later calls return the stored value.
#[inline]
pub fn detect_simd() -> SimdLevel {
    let cached: &SimdLevel = SIMD_LEVEL.get_or_init(detect_simd_uncached);
    *cached
}
/// Probes the CPU for the best supported [`SimdLevel`] without consulting the cache.
///
/// * `x86_64`: `Avx512` when `avx512f`, `avx512vl` and `fma` are all detected;
///   otherwise `Avx2Fma` when `avx2` and `fma` are detected; otherwise `Scalar`.
/// * `aarch64`: `NeonFp16` when `fp16` is detected, otherwise `Neon`.
/// * Any other architecture: `Scalar`.
#[cold]
fn detect_simd_uncached() -> SimdLevel {
    #[cfg(target_arch = "x86_64")]
    {
        // Both x86 tiers require FMA, so gate them on a single probe.
        if is_x86_feature_detected!("fma") {
            if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vl") {
                return SimdLevel::Avx512;
            }
            if is_x86_feature_detected!("avx2") {
                return SimdLevel::Avx2Fma;
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return if std::arch::is_aarch64_feature_detected!("fp16") {
            SimdLevel::NeonFp16
        } else {
            SimdLevel::Neon
        };
    }

    // Fallback for x86_64 without the features above and for other architectures.
    // On aarch64 the block above always returns, hence the lint allowance.
    #[allow(unreachable_code)]
    SimdLevel::Scalar
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two consecutive calls to `detect_simd` must agree.
    #[test]
    fn test_simd_detection_is_cached() {
        assert_eq!(detect_simd(), detect_simd());
    }

    /// Each tier must compare greater than the one listed after it.
    #[test]
    fn test_simd_level_ordering() {
        let descending = [
            SimdLevel::Avx512,
            SimdLevel::Avx2Fma,
            SimdLevel::NeonFp16,
            SimdLevel::Neon,
            SimdLevel::Scalar,
        ];
        for pair in descending.windows(2) {
            assert!(
                pair[0] > pair[1],
                "{:?} should rank above {:?}",
                pair[0],
                pair[1]
            );
        }
    }

    /// Checks the `has_*` predicates against fixed expectation tables.
    #[test]
    fn test_simd_level_capabilities() {
        // (level, has_avx512, has_avx2)
        let x86_expectations = [
            (SimdLevel::Avx512, true, true),
            (SimdLevel::Avx2Fma, false, true),
            (SimdLevel::Scalar, false, false),
        ];
        for (level, avx512, avx2) in x86_expectations {
            assert_eq!(level.has_avx512(), avx512, "has_avx512 for {:?}", level);
            assert_eq!(level.has_avx2(), avx2, "has_avx2 for {:?}", level);
        }

        // (level, has_neon)
        let neon_expectations = [
            (SimdLevel::Neon, true),
            (SimdLevel::NeonFp16, true),
            (SimdLevel::Avx512, false),
            (SimdLevel::Scalar, false),
        ];
        for (level, neon) in neon_expectations {
            assert_eq!(level.has_neon(), neon, "has_neon for {:?}", level);
        }
    }

    /// Checks the architecture predicates against fixed expectation tables.
    #[test]
    fn test_architecture_detection() {
        // (level, is_x86)
        let x86_expectations = [
            (SimdLevel::Avx512, true),
            (SimdLevel::Avx2Fma, true),
            (SimdLevel::Neon, false),
            (SimdLevel::Scalar, false),
        ];
        for (level, expected) in x86_expectations {
            assert_eq!(level.is_x86(), expected, "is_x86 for {:?}", level);
        }

        // (level, is_arm64)
        let arm_expectations = [
            (SimdLevel::Neon, true),
            (SimdLevel::NeonFp16, true),
            (SimdLevel::Avx512, false),
            (SimdLevel::Scalar, false),
        ];
        for (level, expected) in arm_expectations {
            assert_eq!(level.is_arm64(), expected, "is_arm64 for {:?}", level);
        }
    }

    /// Checks `f32_lanes` and `f64_lanes` for every tier.
    #[test]
    fn test_lane_counts() {
        // (level, f32 lanes, f64 lanes)
        let expectations = [
            (SimdLevel::Avx512, 16, 8),
            (SimdLevel::Avx2Fma, 8, 4),
            (SimdLevel::Neon, 4, 2),
            (SimdLevel::NeonFp16, 4, 2),
            (SimdLevel::Scalar, 1, 1),
        ];
        for (level, f32_lanes, f64_lanes) in expectations {
            assert_eq!(level.f32_lanes(), f32_lanes, "f32_lanes for {:?}", level);
            assert_eq!(level.f64_lanes(), f64_lanes, "f64_lanes for {:?}", level);
        }
    }
}