#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Snapshot of the x86 SIMD capabilities this module dispatches on.
///
/// Each flag is `true` when `SimdFeatures::detect` found the corresponding
/// CPU feature at runtime. On targets other than x86_64, `detect` sets every
/// flag to `false`.
#[derive(Debug, Clone, Copy)]
pub struct SimdFeatures {
/// Result of probing the `sse2` feature.
pub sse2: bool,
/// Result of probing the `avx` feature.
pub avx: bool,
/// Result of probing the `avx2` feature.
pub avx2: bool,
/// Result of probing the `avx512f` feature.
pub avx512f: bool,
/// Result of probing the `fma` feature.
pub fma: bool,
}
impl SimdFeatures {
#[cfg(target_arch = "x86_64")]
pub fn detect() -> Self {
Self {
sse2: is_x86_feature_detected!("sse2"),
avx: is_x86_feature_detected!("avx"),
avx2: is_x86_feature_detected!("avx2"),
avx512f: is_x86_feature_detected!("avx512f"),
fma: is_x86_feature_detected!("fma"),
}
}
#[cfg(not(target_arch = "x86_64"))]
pub fn detect() -> Self {
Self {
sse2: false,
avx: false,
avx2: false,
avx512f: false,
fma: false,
}
}
pub fn best_simd(&self) -> SimdLevel {
if self.avx512f {
SimdLevel::Avx512
} else if self.avx2 {
SimdLevel::Avx2
} else if self.avx {
SimdLevel::Avx
} else if self.sse2 {
SimdLevel::Sse2
} else {
SimdLevel::Scalar
}
}
}
/// Instruction-set tiers, ordered from least to most capable.
///
/// The explicit discriminants together with the derived `Ord` make
/// comparisons such as `SimdLevel::Avx2 > SimdLevel::Avx` meaningful.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    Scalar = 0,
    Sse2 = 1,
    Avx = 2,
    Avx2 = 3,
    Avx512 = 4,
}

impl SimdLevel {
    /// Vector register width associated with this tier.
    ///
    /// The SIMD tiers report 16, 32 and 64; `f32_lanes` divides these by the
    /// size of an `f32`, i.e. they are byte widths. `Scalar` reports `1`,
    /// which is kept as-is because existing callers and tests rely on it.
    pub fn vector_width(&self) -> usize {
        match *self {
            SimdLevel::Scalar => 1,
            SimdLevel::Sse2 => 16,
            SimdLevel::Avx | SimdLevel::Avx2 => 32,
            SimdLevel::Avx512 => 64,
        }
    }

    /// Number of `f32` values processed per step at this tier.
    ///
    /// Always at least 1: the scalar tier handles one element at a time.
    /// Without the clamp, `Scalar` would yield `1 / 4 == 0`, which makes any
    /// caller that chunks or divides by the lane count misbehave.
    pub fn f32_lanes(&self) -> usize {
        (self.vector_width() / std::mem::size_of::<f32>()).max(1)
    }
}
#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Arrays must have equal length");
let features = SimdFeatures::detect();
#[cfg(target_arch = "x86_64")]
{
if features.avx2 && features.fma {
unsafe { dot_product_f32_avx2_fma(a, b) }
} else if features.avx {
unsafe { dot_product_f32_avx(a, b) }
} else {
dot_product_f32_scalar(a, b)
}
}
#[cfg(not(target_arch = "x86_64"))]
{
dot_product_f32_scalar(a, b)
}
}
/// Portable dot-product kernel: multiplies the slices pairwise and sums the
/// products front to back.
///
/// Pairing stops at the end of the shorter slice.
#[inline]
fn dot_product_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
/// Dot product of `a` and `b` using 256-bit multiplies and adds.
///
/// Full groups of eight elements are accumulated in one vector register,
/// which is then reduced with `horizontal_sum_avx`; the remaining
/// `a.len() % 8` products are added one by one afterwards.
///
/// Only the first `a.len()` elements of `b` are read.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
#[inline]
unsafe fn dot_product_f32_avx(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;

    let len = a.len();
    // One up-front bounds check: a short `b` now panics here instead of
    // being read past its end by the vector loads below.
    let b = &b[..len];
    let vector_len = len - len % LANES;
    let (a_body, a_tail) = a.split_at(vector_len);
    let (b_body, b_tail) = b.split_at(vector_len);

    // SAFETY: `chunks_exact(LANES)` only yields slices of exactly eight
    // contiguous `f32`s, so each unaligned 256-bit load stays in bounds.
    // AVX availability is guaranteed by this function's safety contract.
    let mut total = unsafe {
        let mut acc = _mm256_setzero_ps();
        for (ca, cb) in a_body.chunks_exact(LANES).zip(b_body.chunks_exact(LANES)) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
        }
        horizontal_sum_avx(acc)
    };

    // Scalar tail for the elements that do not fill a whole vector.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        total += x * y;
    }
    total
}
/// Dot product of `a` and `b` using 256-bit fused multiply-add.
///
/// Full groups of eight elements are accumulated with `_mm256_fmadd_ps` in
/// one vector register, which is then reduced with `horizontal_sum_avx`;
/// the remaining `a.len() % 8` products are added one by one afterwards.
///
/// Only the first `a.len()` elements of `b` are read.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_product_f32_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;

    let len = a.len();
    // One up-front bounds check: a short `b` now panics here instead of
    // being read past its end by the vector loads below.
    let b = &b[..len];
    let vector_len = len - len % LANES;
    let (a_body, a_tail) = a.split_at(vector_len);
    let (b_body, b_tail) = b.split_at(vector_len);

    // SAFETY: `chunks_exact(LANES)` only yields slices of exactly eight
    // contiguous `f32`s, so each unaligned 256-bit load stays in bounds.
    // AVX2/FMA availability is guaranteed by this function's safety contract.
    let mut total = unsafe {
        let mut acc = _mm256_setzero_ps();
        for (ca, cb) in a_body.chunks_exact(LANES).zip(b_body.chunks_exact(LANES)) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            acc = _mm256_fmadd_ps(va, vb, acc);
        }
        horizontal_sum_avx(acc)
    };

    // Scalar tail for the elements that do not fill a whole vector.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        total += x * y;
    }
    total
}
/// Adds the eight `f32` lanes of `v` together and returns the total.
///
/// The reduction halves the live lane count three times: 8 -> 4 -> 2 -> 1.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
#[inline]
unsafe fn horizontal_sum_avx(v: __m256) -> f32 {
    unsafe {
        // 8 -> 4: add the upper 128-bit half onto the lower half.
        let upper = _mm256_extractf128_ps::<1>(v);
        let lower = _mm256_castps256_ps128(v);
        let quad = _mm_add_ps(upper, lower);

        // 4 -> 2: bring lanes 2 and 3 down onto lanes 0 and 1.
        let quad_high = _mm_movehl_ps(quad, quad);
        let pair = _mm_add_ps(quad, quad_high);

        // 2 -> 1: move lane 1 into lane 0 and add the low lanes only.
        let lane_one = _mm_shuffle_ps::<1>(pair, pair);
        let total = _mm_add_ss(pair, lane_one);

        _mm_cvtss_f32(total)
    }
}
#[inline]
pub fn add_f32(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
let features = SimdFeatures::detect();
#[cfg(target_arch = "x86_64")]
{
if features.avx2 {
unsafe { add_f32_avx2(a, b, result) }
} else {
add_f32_scalar(a, b, result)
}
}
#[cfg(not(target_arch = "x86_64"))]
{
add_f32_scalar(a, b, result)
}
}
/// Portable element-wise addition: `result[i] = a[i] + b[i]` for every index
/// of `a`.
///
/// Elements of `b` and `result` beyond `a.len()` are ignored / left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. The check happens before
/// anything is written, so `result` is never left partially updated.
#[inline]
fn add_f32_scalar(a: &[f32], b: &[f32], result: &mut [f32]) {
    let len = a.len();
    // One up-front bounds check per slice instead of three per element.
    let b = &b[..len];
    let result = &mut result[..len];

    for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
        *out = x + y;
    }
}
/// Element-wise addition `result[i] = a[i] + b[i]` using 256-bit vectors.
///
/// Full groups of eight elements are added with vector instructions; the
/// remaining `a.len() % 8` elements are added one by one. Elements of `b`
/// and `result` beyond `a.len()` are ignored / left untouched.
///
/// NOTE(review): the intrinsics used here look like they need only AVX;
/// relaxing the `avx2` gate would also require changing the check in the
/// dispatcher — confirm before doing so.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn add_f32_avx2(a: &[f32], b: &[f32], result: &mut [f32]) {
    const LANES: usize = 8;

    let len = a.len();
    // One up-front bounds check per slice: a short `b` or `result` now
    // panics here instead of being accessed past its end below.
    let b = &b[..len];
    let result = &mut result[..len];
    let vector_len = len - len % LANES;
    let (a_body, a_tail) = a.split_at(vector_len);
    let (b_body, b_tail) = b.split_at(vector_len);
    let (out_body, out_tail) = result.split_at_mut(vector_len);

    let groups = a_body
        .chunks_exact(LANES)
        .zip(b_body.chunks_exact(LANES))
        .zip(out_body.chunks_exact_mut(LANES));
    for ((ca, cb), co) in groups {
        // SAFETY: every chunk is exactly eight contiguous `f32`s, so the
        // unaligned 256-bit loads and the store stay in bounds. AVX2
        // availability is guaranteed by this function's safety contract.
        unsafe {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            _mm256_storeu_ps(co.as_mut_ptr(), _mm256_add_ps(va, vb));
        }
    }

    // Scalar tail for the elements that do not fill a whole vector.
    for ((&x, &y), out) in a_tail.iter().zip(b_tail).zip(out_tail) {
        *out = x + y;
    }
}
#[inline]
pub fn relu_f32(input: &[f32], output: &mut [f32]) {
assert_eq!(input.len(), output.len());
let features = SimdFeatures::detect();
#[cfg(target_arch = "x86_64")]
{
if features.avx2 {
unsafe { relu_f32_avx2(input, output) }
} else {
relu_f32_scalar(input, output)
}
}
#[cfg(not(target_arch = "x86_64"))]
{
relu_f32_scalar(input, output)
}
}
/// Portable ReLU: `output[i] = input[i].max(0.0)` for every index of `input`.
///
/// Elements of `output` beyond `input.len()` are left untouched.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`. The check happens before
/// anything is written, so `output` is never left partially updated.
#[inline]
fn relu_f32_scalar(input: &[f32], output: &mut [f32]) {
    // One up-front bounds check instead of two per element.
    let output = &mut output[..input.len()];

    for (out, &x) in output.iter_mut().zip(input) {
        *out = x.max(0.0);
    }
}
/// ReLU using 256-bit vectors: `output[i] = max(input[i], 0.0)`.
///
/// Full groups of eight elements are processed with `_mm256_max_ps` against
/// a zero vector; the remaining `input.len() % 8` elements use
/// `f32::max(0.0)`. Elements of `output` beyond `input.len()` are left
/// untouched.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn relu_f32_avx2(input: &[f32], output: &mut [f32]) {
    const LANES: usize = 8;

    let len = input.len();
    // One up-front bounds check: a short `output` now panics here instead
    // of being written past its end by the vector stores below.
    let output = &mut output[..len];
    let vector_len = len - len % LANES;
    let (in_body, in_tail) = input.split_at(vector_len);
    let (out_body, out_tail) = output.split_at_mut(vector_len);

    // SAFETY: every chunk is exactly eight contiguous `f32`s, so the
    // unaligned 256-bit load and store stay in bounds. AVX2 availability is
    // guaranteed by this function's safety contract.
    unsafe {
        let zero = _mm256_setzero_ps();
        for (src, dst) in in_body
            .chunks_exact(LANES)
            .zip(out_body.chunks_exact_mut(LANES))
        {
            let v = _mm256_loadu_ps(src.as_ptr());
            _mm256_storeu_ps(dst.as_mut_ptr(), _mm256_max_ps(v, zero));
        }
    }

    // Scalar tail for the elements that do not fill a whole vector.
    for (&x, out) in in_tail.iter().zip(out_tail) {
        *out = x.max(0.0);
    }
}
/// Returns the detected [`SimdFeatures`], running `SimdFeatures::detect` on
/// the first call only and handing out a copy of the stored value afterwards.
pub fn simd_features() -> SimdFeatures {
    use std::sync::OnceLock;

    static DETECTED: OnceLock<SimdFeatures> = OnceLock::new();

    let cached: &SimdFeatures = DETECTED.get_or_init(SimdFeatures::detect);
    *cached
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection must not panic and, on x86_64, must report SSE2.
    #[test]
    fn test_simd_detection() {
        let features = SimdFeatures::detect();
        let level = features.best_simd();
        println!("Detected SIMD level: {:?}", level);
        println!("Features: {:?}", features);
        #[cfg(target_arch = "x86_64")]
        {
            assert!(features.sse2, "x86_64 always has SSE2");
        }
    }

    /// Exactly one full vector of eight elements.
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        assert!((result - expected).abs() < 1e-5);
    }

    /// Lengths 0..=20 cover the empty input, inputs shorter than one vector,
    /// exact multiples of eight, and every possible tail length.
    ///
    /// All values are small integers, so every partial sum is exactly
    /// representable and the result must match regardless of which kernel
    /// (and therefore which summation order) is selected.
    #[test]
    fn test_dot_product_all_remainders() {
        for len in 0..=20usize {
            let a: Vec<f32> = (0..len).map(|i| (i + 1) as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (i % 3) as f32 - 1.0).collect();
            let expected: i64 = (0..len as i64).map(|i| (i + 1) * (i % 3 - 1)).sum();
            assert_eq!(dot_product_f32(&a, &b), expected as f32, "len = {}", len);
        }
    }

    /// 100 elements: twelve full vectors plus a four-element tail.
    #[test]
    fn test_add_vectorized() {
        let a = vec![1.0; 100];
        let b = vec![2.0; 100];
        let mut result = vec![0.0; 100];
        add_f32(&a, &b, &mut result);
        for &r in &result {
            assert!((r - 3.0).abs() < 1e-5);
        }
    }

    /// Five elements: shorter than one vector, so only the tail path runs.
    #[test]
    fn test_relu() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];
        relu_f32(&input, &mut output);
        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];
        for (o, e) in output.iter().zip(&expected) {
            assert!((o - e).abs() < 1e-5);
        }
    }

    /// Nineteen elements: two full vectors plus a three-element tail, so the
    /// vector loop is exercised as well as the scalar remainder.
    #[test]
    fn test_relu_vector_and_tail() {
        let input: Vec<f32> = (0..19)
            .map(|i| if i % 2 == 0 { i as f32 } else { -(i as f32) })
            .collect();
        // Pre-fill with a sentinel so an element that is never written shows up.
        let mut output = vec![-1.0f32; input.len()];
        relu_f32(&input, &mut output);
        for (i, (&x, &y)) in input.iter().zip(&output).enumerate() {
            let expected = if x > 0.0 { x } else { 0.0 };
            assert_eq!(y, expected, "index {}", i);
        }
    }

    /// Large input compared against a sequential sum with a relative tolerance,
    /// since the kernels accumulate in a different order.
    #[test]
    fn test_large_dot_product() {
        let size = 10_000;
        let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..size).map(|i| (size - i) as f32).collect();
        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let relative_error = ((result - expected) / expected).abs();
        assert!(relative_error < 1e-4);
    }

    /// The derived ordering must follow capability.
    #[test]
    fn test_simd_level_comparison() {
        assert!(SimdLevel::Avx512 > SimdLevel::Avx2);
        assert!(SimdLevel::Avx2 > SimdLevel::Avx);
        assert!(SimdLevel::Avx > SimdLevel::Sse2);
        assert!(SimdLevel::Sse2 > SimdLevel::Scalar);
    }

    /// Pins the reported widths and lane counts.
    #[test]
    fn test_vector_widths() {
        assert_eq!(SimdLevel::Scalar.vector_width(), 1);
        assert_eq!(SimdLevel::Sse2.vector_width(), 16);
        assert_eq!(SimdLevel::Avx.vector_width(), 32);
        assert_eq!(SimdLevel::Avx2.vector_width(), 32);
        assert_eq!(SimdLevel::Avx512.vector_width(), 64);
        assert_eq!(SimdLevel::Avx2.f32_lanes(), 8);
        assert_eq!(SimdLevel::Avx512.f32_lanes(), 16);
    }
}