use scirs2_core::ndarray::{Array1, Array2};
/// Number of f32 lanes in a 512-bit AVX-512 vector register (512 / 32).
pub const AVX512_WIDTH: usize = 16;
/// Number of f32 lanes in a 256-bit AVX2 vector register (256 / 32).
pub const AVX2_WIDTH: usize = 8;
/// Reports whether the running CPU supports the AVX-512 Foundation
/// (`avx512f`) instruction set. Always `false` on non-x86_64 targets.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn is_avx512_available() -> bool {
    is_x86_feature_detected!("avx512f")
}

/// Reports whether the running CPU supports the AVX-512 Foundation
/// (`avx512f`) instruction set. Always `false` on non-x86_64 targets.
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn is_avx512_available() -> bool {
    false
}
/// Reports whether the running CPU supports AVX2.
/// Always `false` on non-x86_64 targets.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn is_avx2_available() -> bool {
    is_x86_feature_detected!("avx2")
}

/// Reports whether the running CPU supports AVX2.
/// Always `false` on non-x86_64 targets.
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn is_avx2_available() -> bool {
    false
}
/// Computes the dot product of `a` and `b`, dispatching at runtime to the
/// widest SIMD kernel the CPU supports (AVX-512 → AVX2 → scalar).
///
/// # Panics
/// Panics if `a.len() != b.len()`. This is a hard assert rather than a
/// `debug_assert`: the unsafe SIMD kernels read `a.len()` elements from
/// *both* slices, so a length mismatch in a release build would otherwise
/// be an out-of-bounds read (undefined behavior).
#[inline]
pub fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "dot_product_avx512: length mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            // SAFETY: avx512f support was just verified at runtime and the
            // assert above guarantees equal lengths.
            unsafe { dot_product_avx512_impl(a, b) }
        } else if is_x86_feature_detected!("avx2") {
            // SAFETY: avx2 support was just verified at runtime and the
            // assert above guarantees equal lengths.
            unsafe { dot_product_avx2_impl(a, b) }
        } else {
            dot_product_scalar(a, b)
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        dot_product_scalar(a, b)
    }
}
/// AVX-512 dot-product kernel: a 4-way unrolled FMA loop over 16-lane f32
/// vectors, followed by a scalar pass over the trailing `len % 16` elements.
///
/// # Safety
/// Callers must guarantee that the CPU supports `avx512f` (the public
/// dispatcher checks this at runtime) and that `a` and `b` have the same
/// length — the unaligned loads read `a.len()` elements from both slices.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / AVX512_WIDTH;
let remainder = len % AVX512_WIDTH;
// Four independent accumulators let the CPU keep several fused
// multiply-adds in flight at once, hiding FMA latency.
let mut acc0 = _mm512_setzero_ps();
let mut acc1 = _mm512_setzero_ps();
let mut acc2 = _mm512_setzero_ps();
let mut acc3 = _mm512_setzero_ps();
let mut i = 0;
// Main unrolled loop: 4 vectors (64 f32s) per iteration.
let chunks_unrolled = chunks / 4;
for _ in 0..chunks_unrolled {
let a0 = _mm512_loadu_ps(a.as_ptr().add(i));
let b0 = _mm512_loadu_ps(b.as_ptr().add(i));
acc0 = _mm512_fmadd_ps(a0, b0, acc0);
let a1 = _mm512_loadu_ps(a.as_ptr().add(i + AVX512_WIDTH));
let b1 = _mm512_loadu_ps(b.as_ptr().add(i + AVX512_WIDTH));
acc1 = _mm512_fmadd_ps(a1, b1, acc1);
let a2 = _mm512_loadu_ps(a.as_ptr().add(i + AVX512_WIDTH * 2));
let b2 = _mm512_loadu_ps(b.as_ptr().add(i + AVX512_WIDTH * 2));
acc2 = _mm512_fmadd_ps(a2, b2, acc2);
let a3 = _mm512_loadu_ps(a.as_ptr().add(i + AVX512_WIDTH * 3));
let b3 = _mm512_loadu_ps(b.as_ptr().add(i + AVX512_WIDTH * 3));
acc3 = _mm512_fmadd_ps(a3, b3, acc3);
i += AVX512_WIDTH * 4;
}
// Drain the 0..=3 full vectors the unrolled loop did not cover.
for _ in 0..(chunks % 4) {
let a_vec = _mm512_loadu_ps(a.as_ptr().add(i));
let b_vec = _mm512_loadu_ps(b.as_ptr().add(i));
acc0 = _mm512_fmadd_ps(a_vec, b_vec, acc0);
i += AVX512_WIDTH;
}
// Combine the four partial vectors and reduce horizontally to one f32.
let acc = _mm512_add_ps(_mm512_add_ps(acc0, acc1), _mm512_add_ps(acc2, acc3));
let sum = _mm512_reduce_add_ps(acc);
// Scalar tail for the final len % 16 elements.
let mut remainder_sum = 0.0f32;
for j in 0..remainder {
remainder_sum += a[i + j] * b[i + j];
}
sum + remainder_sum
}
/// AVX2 dot-product kernel: a 2-way unrolled FMA loop over 8-lane f32
/// vectors, followed by a scalar pass over the trailing `len % 8` elements.
///
/// # Safety
/// Callers must guarantee that the CPU supports `avx2` (and FMA — NOTE(review):
/// `_mm256_fmadd_ps` technically requires the `fma` feature, which ships with
/// every AVX2 CPU but is not explicitly enabled here; confirm intended) and
/// that `a` and `b` have the same length — the unaligned loads read
/// `a.len()` elements from both slices.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / AVX2_WIDTH;
let remainder = len % AVX2_WIDTH;
// Two independent accumulators to overlap FMA latencies.
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut i = 0;
// Main unrolled loop: 2 vectors (16 f32s) per iteration.
for _ in 0..chunks / 2 {
let a0 = _mm256_loadu_ps(a.as_ptr().add(i));
let b0 = _mm256_loadu_ps(b.as_ptr().add(i));
acc0 = _mm256_fmadd_ps(a0, b0, acc0);
let a1 = _mm256_loadu_ps(a.as_ptr().add(i + AVX2_WIDTH));
let b1 = _mm256_loadu_ps(b.as_ptr().add(i + AVX2_WIDTH));
acc1 = _mm256_fmadd_ps(a1, b1, acc1);
i += AVX2_WIDTH * 2;
}
// Handle the odd leftover vector, if any.
if chunks % 2 == 1 {
let a_vec = _mm256_loadu_ps(a.as_ptr().add(i));
let b_vec = _mm256_loadu_ps(b.as_ptr().add(i));
acc0 = _mm256_fmadd_ps(a_vec, b_vec, acc0);
i += AVX2_WIDTH;
}
// Horizontal reduction via a stack buffer (AVX2 has no reduce intrinsic).
let acc = _mm256_add_ps(acc0, acc1);
let mut temp = [0.0f32; AVX2_WIDTH];
_mm256_storeu_ps(temp.as_mut_ptr(), acc);
let sum = temp.iter().sum::<f32>();
// Scalar tail for the final len % 8 elements.
let mut remainder_sum = 0.0f32;
for j in 0..remainder {
remainder_sum += a[i + j] * b[i + j];
}
sum + remainder_sum
}
/// Scalar dot-product fallback, 4-way unrolled so the accumulation order
/// (and therefore the rounded f32 result) mirrors the SIMD kernels' use of
/// multiple independent accumulators.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let split = a.len() - a.len() % 4;
    let mut acc = [0.0f32; 4];
    // Bulk pass: exact 4-element chunks, one partial sum per lane.
    for (qa, qb) in a[..split].chunks_exact(4).zip(b[..split].chunks_exact(4)) {
        acc[0] += qa[0] * qb[0];
        acc[1] += qa[1] * qb[1];
        acc[2] += qa[2] * qb[2];
        acc[3] += qa[3] * qb[3];
    }
    // Tail pass: fold the last 0..=3 products into the first lane,
    // matching the original's remainder handling.
    for (&x, &y) in a[split..].iter().zip(b[split..].iter()) {
        acc[0] += x * y;
    }
    acc[0] + acc[1] + acc[2] + acc[3]
}
/// Matrix-vector product `out = mat · vec` using the SIMD-dispatched dot
/// product for each row.
///
/// The vector's slice view is hoisted out of the loop (it is loop-invariant),
/// and rows that are not contiguous in memory (e.g. views of a transposed or
/// sliced matrix) fall back to a scalar iterator dot product instead of
/// panicking on `as_slice().unwrap()` as the previous version did.
#[inline]
pub fn matvec_avx512(mat: &Array2<f32>, vec: &Array1<f32>, out: &mut Array1<f32>) {
    debug_assert_eq!(mat.ncols(), vec.len());
    debug_assert_eq!(mat.nrows(), out.len());
    // Loop-invariant: compute the (optional) contiguous view of `vec` once.
    let vec_slice = vec.as_slice();
    for (i, out_elem) in out.iter_mut().enumerate() {
        let row = mat.row(i);
        *out_elem = match (row.as_slice(), vec_slice) {
            // Fast path: both operands contiguous, use the SIMD kernel.
            (Some(r), Some(v)) => dot_product_avx512(r, v),
            // Non-contiguous layout: scalar fallback without copying.
            _ => row.iter().zip(vec.iter()).map(|(&x, &y)| x * y).sum(),
        };
    }
}
pub mod elementwise {
    //! Element-wise slice operations with runtime AVX-512 dispatch.
    #[cfg(target_arch = "x86_64")]
    use super::AVX512_WIDTH;

    /// Element-wise addition `out[i] = a[i] + b[i]`, using AVX-512 when
    /// available and a scalar loop otherwise.
    ///
    /// # Panics
    /// Panics if the three slices differ in length. This is a hard assert
    /// rather than a `debug_assert`: the unsafe kernel writes `a.len()`
    /// elements through `out`'s raw pointer, so a shorter `out` in a release
    /// build would otherwise be an out-of-bounds write (undefined behavior).
    #[inline]
    pub fn add_avx512(a: &[f32], b: &[f32], out: &mut [f32]) {
        assert_eq!(a.len(), b.len(), "add_avx512: input length mismatch");
        assert_eq!(a.len(), out.len(), "add_avx512: output length mismatch");
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") {
                // SAFETY: avx512f verified at runtime; lengths verified by
                // the asserts above.
                unsafe { add_avx512_impl(a, b, out) }
            } else {
                add_scalar(a, b, out);
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            add_scalar(a, b, out);
        }
    }

    /// AVX-512 kernel for `add_avx512`: 16 lanes per iteration plus a
    /// scalar tail for the final `len % 16` elements.
    ///
    /// # Safety
    /// The CPU must support `avx512f`, and `a`, `b`, `out` must all have
    /// length `a.len()` — loads and stores are driven by `a.len()` alone.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx512f")]
    unsafe fn add_avx512_impl(a: &[f32], b: &[f32], out: &mut [f32]) {
        use std::arch::x86_64::*;
        let len = a.len();
        let chunks = len / AVX512_WIDTH;
        let remainder = len % AVX512_WIDTH;
        for i in 0..chunks {
            let idx = i * AVX512_WIDTH;
            let a_vec = _mm512_loadu_ps(a.as_ptr().add(idx));
            let b_vec = _mm512_loadu_ps(b.as_ptr().add(idx));
            let sum = _mm512_add_ps(a_vec, b_vec);
            _mm512_storeu_ps(out.as_mut_ptr().add(idx), sum);
        }
        // Scalar tail.
        let idx = chunks * AVX512_WIDTH;
        for j in 0..remainder {
            out[idx + j] = a[idx + j] + b[idx + j];
        }
    }

    /// Portable scalar fallback for `add_avx512`.
    fn add_scalar(a: &[f32], b: &[f32], out: &mut [f32]) {
        for i in 0..a.len() {
            out[i] = a[i] + b[i];
        }
    }

    /// Element-wise multiplication `out[i] = a[i] * b[i]`, using AVX-512
    /// when available and a scalar loop otherwise.
    ///
    /// # Panics
    /// Panics if the three slices differ in length. Hard assert for the same
    /// reason as [`add_avx512`]: the unsafe kernel writes `a.len()` elements
    /// through `out`'s raw pointer.
    #[inline]
    pub fn mul_avx512(a: &[f32], b: &[f32], out: &mut [f32]) {
        assert_eq!(a.len(), b.len(), "mul_avx512: input length mismatch");
        assert_eq!(a.len(), out.len(), "mul_avx512: output length mismatch");
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") {
                // SAFETY: avx512f verified at runtime; lengths verified by
                // the asserts above.
                unsafe { mul_avx512_impl(a, b, out) }
            } else {
                mul_scalar(a, b, out);
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            mul_scalar(a, b, out);
        }
    }

    /// AVX-512 kernel for `mul_avx512`: 16 lanes per iteration plus a
    /// scalar tail for the final `len % 16` elements.
    ///
    /// # Safety
    /// The CPU must support `avx512f`, and `a`, `b`, `out` must all have
    /// length `a.len()` — loads and stores are driven by `a.len()` alone.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx512f")]
    unsafe fn mul_avx512_impl(a: &[f32], b: &[f32], out: &mut [f32]) {
        use std::arch::x86_64::*;
        let len = a.len();
        let chunks = len / AVX512_WIDTH;
        let remainder = len % AVX512_WIDTH;
        for i in 0..chunks {
            let idx = i * AVX512_WIDTH;
            let a_vec = _mm512_loadu_ps(a.as_ptr().add(idx));
            let b_vec = _mm512_loadu_ps(b.as_ptr().add(idx));
            let prod = _mm512_mul_ps(a_vec, b_vec);
            _mm512_storeu_ps(out.as_mut_ptr().add(idx), prod);
        }
        // Scalar tail.
        let idx = chunks * AVX512_WIDTH;
        for j in 0..remainder {
            out[idx + j] = a[idx + j] * b[idx + j];
        }
    }

    /// Portable scalar fallback for `mul_avx512`.
    fn mul_scalar(a: &[f32], b: &[f32], out: &mut [f32]) {
        for i in 0..a.len() {
            out[i] = a[i] * b[i];
        }
    }
}
pub mod activations {
    //! Activation-function kernels with runtime CPU-feature dispatch.

    /// Computes `out[i] = exp(x[i])` for every element, dispatching to the
    /// AVX-512 path when the CPU supports `avx512f`.
    #[inline]
    pub fn fast_exp_avx512(x: &[f32], out: &mut [f32]) {
        debug_assert_eq!(x.len(), out.len());
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") {
                // SAFETY: avx512f presence was verified just above.
                unsafe { fast_exp_avx512_impl(x, out) }
            } else {
                fast_exp_scalar(x, out);
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            fast_exp_scalar(x, out);
        }
    }

    /// NOTE(review): despite the `avx512f` target feature, this currently
    /// performs scalar `exp` calls — it is a placeholder for a genuinely
    /// vectorized implementation.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports `avx512f`.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx512f")]
    unsafe fn fast_exp_avx512_impl(x: &[f32], out: &mut [f32]) {
        for (i, &v) in x.iter().enumerate() {
            out[i] = v.exp();
        }
    }

    /// Portable scalar element-wise exponential.
    fn fast_exp_scalar(x: &[f32], out: &mut [f32]) {
        for (i, &v) in x.iter().enumerate() {
            out[i] = v.exp();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection must not panic on any host; print results for CI logs.
    #[test]
    fn test_avx512_detection() {
        println!("AVX-512 available: {}", is_avx512_available());
        println!("AVX2 available: {}", is_avx2_available());
    }

    /// Dot with an all-ones vector equals the sum of the other operand.
    #[test]
    fn test_dot_product_avx512() {
        let a: Vec<f32> = (1..=10).map(|i| i as f32).collect();
        let b = vec![1.0f32; 10];
        let expected: f32 = a.iter().sum();
        assert!((dot_product_avx512(&a, &b) - expected).abs() < 1e-5);
    }

    /// Large input exercises the unrolled SIMD loop plus the remainder tail.
    #[test]
    fn test_dot_product_avx512_large() {
        let n = 1024;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b = vec![2.0f32; n];
        let expected: f32 = a.iter().map(|&x| x * 2.0).sum();
        assert!((dot_product_avx512(&a, &b) - expected).abs() < 1.0);
    }

    /// Multiplying by an all-ones vector yields each row's sum.
    #[test]
    fn test_matvec_avx512() {
        let data: Vec<f32> = (1..=12).map(|i| i as f32).collect();
        let mat = Array2::from_shape_vec((3, 4), data).unwrap();
        let vec = Array1::from_vec(vec![1.0; 4]);
        let mut out = Array1::zeros(3);
        matvec_avx512(&mat, &vec, &mut out);
        assert_eq!(out[0], 10.0);
        assert_eq!(out[1], 26.0);
        assert_eq!(out[2], 42.0);
    }

    /// a[i] + b[i] is the constant 9 for these mirrored ramps.
    #[test]
    fn test_elementwise_add_avx512() {
        let a: Vec<f32> = (1..=8).map(|i| i as f32).collect();
        let b: Vec<f32> = (1..=8).rev().map(|i| i as f32).collect();
        let mut out = vec![0.0f32; 8];
        elementwise::add_avx512(&a, &b, &mut out);
        assert!(out.iter().all(|&v| v == 9.0));
    }

    /// Multiplying by a constant-2 vector doubles each element.
    #[test]
    fn test_elementwise_mul_avx512() {
        let a: Vec<f32> = (1..=8).map(|i| i as f32).collect();
        let b = vec![2.0f32; 8];
        let mut out = vec![0.0f32; 8];
        elementwise::mul_avx512(&a, &b, &mut out);
        for (o, x) in out.iter().zip(a.iter()) {
            assert_eq!(*o, x * 2.0);
        }
    }
}