#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Adds two slices element-wise, writing `a[i] + b[i]` into `out[i]`.
///
/// On x86_64 the kernel is chosen at runtime: AVX2 if the CPU reports it,
/// otherwise SSE4.1, otherwise a portable scalar loop. Other architectures
/// always use the scalar loop.
///
/// # Panics
///
/// Panics if `a`, `b` and `out` do not all have the same length.
#[inline]
pub fn simd_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 was just detected, and the asserts above guarantee
            // that `b` and `out` are exactly as long as `a`.
            return unsafe { simd_add_f32_avx2(a, b, out) };
        }
        if is_x86_feature_detected!("sse4.1") {
            // SAFETY: SSE4.1 was just detected; lengths were checked above.
            return unsafe { simd_add_f32_sse(a, b, out) };
        }
    }

    scalar_add_f32(a, b, out)
}
/// Element-wise `out[i] = a[i] + b[i]` for every `i in 0..a.len()`.
///
/// Processes eight lanes at a time with 256-bit vectors. The final
/// `a.len() % 8` elements are added with scalar code.
///
/// Note: the intrinsics used are plain AVX; `avx2` is the feature the
/// dispatcher checks for and the one this function is compiled with.
///
/// # Safety
///
/// * The running CPU must support AVX2 (e.g. `is_x86_feature_detected!("avx2")`).
/// * `b` and `out` must each hold at least `a.len()` elements. The vector loop
///   reads and writes through raw pointers bounded only by `a.len()`, so a
///   shorter slice is an out-of-bounds access, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_add_f32_avx2(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        b.len() >= len && out.len() >= len,
        "simd_add_f32_avx2: b and out must be at least a.len() long"
    );

    let mut i = 0;
    while i + 8 <= len {
        // SAFETY: `i + 8 <= len`, and every slice holds at least `len` elements,
        // so each 8-lane unaligned load/store stays in bounds. The caller
        // guarantees AVX2, which implies the AVX these intrinsics need.
        unsafe {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(out.as_mut_ptr().add(i), _mm256_add_ps(va, vb));
        }
        i += 8;
    }

    // Scalar tail: fewer than eight elements remain.
    for ((o, &x), &y) in out[i..len].iter_mut().zip(&a[i..]).zip(&b[i..len]) {
        *o = x + y;
    }
}
/// Element-wise `out[i] = a[i] + b[i]` for every `i in 0..a.len()`.
///
/// Processes four lanes at a time with 128-bit vectors. The final
/// `a.len() % 4` elements are added with scalar code.
///
/// # Safety
///
/// * The running CPU must support SSE4.1 (e.g. `is_x86_feature_detected!("sse4.1")`),
///   the feature this function is compiled with.
/// * `b` and `out` must each hold at least `a.len()` elements. The vector loop
///   reads and writes through raw pointers bounded only by `a.len()`, so a
///   shorter slice is an out-of-bounds access, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_add_f32_sse(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        b.len() >= len && out.len() >= len,
        "simd_add_f32_sse: b and out must be at least a.len() long"
    );

    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len`, and every slice holds at least `len` elements,
        // so each 4-lane unaligned load/store stays in bounds.
        unsafe {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            _mm_storeu_ps(out.as_mut_ptr().add(i), _mm_add_ps(va, vb));
        }
        i += 4;
    }

    // Scalar tail: fewer than four elements remain.
    for ((o, &x), &y) in out[i..len].iter_mut().zip(&a[i..]).zip(&b[i..len]) {
        *o = x + y;
    }
}
/// Portable element-wise `out[i] = a[i] + b[i]` over the first `a.len()` elements.
///
/// Elements of `b` or `out` beyond `a.len()` are ignored and left untouched.
///
/// # Panics
///
/// Panics if `b` or `out` is shorter than `a`.
#[inline]
fn scalar_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    // Re-slicing up front does the length checks once, so the loop below has no
    // per-element bounds checks and can be auto-vectorized by the compiler.
    let (b, out) = (&b[..n], &mut out[..n]);
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x + y;
    }
}
/// Multiplies two slices element-wise, writing `a[i] * b[i]` into `out[i]`.
///
/// On x86_64 the kernel is chosen at runtime: AVX2 if the CPU reports it,
/// otherwise SSE4.1, otherwise a portable scalar loop. Other architectures
/// always use the scalar loop.
///
/// # Panics
///
/// Panics if `a`, `b` and `out` do not all have the same length.
#[inline]
pub fn simd_mul_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 was just detected, and the asserts above guarantee
            // that `b` and `out` are exactly as long as `a`.
            return unsafe { simd_mul_f32_avx2(a, b, out) };
        }
        if is_x86_feature_detected!("sse4.1") {
            // SAFETY: SSE4.1 was just detected; lengths were checked above.
            return unsafe { simd_mul_f32_sse(a, b, out) };
        }
    }

    scalar_mul_f32(a, b, out)
}
/// Element-wise `out[i] = a[i] * b[i]` for every `i in 0..a.len()`.
///
/// Processes eight lanes at a time with 256-bit vectors. The final
/// `a.len() % 8` elements are multiplied with scalar code.
///
/// Note: the intrinsics used are plain AVX; `avx2` is the feature the
/// dispatcher checks for and the one this function is compiled with.
///
/// # Safety
///
/// * The running CPU must support AVX2 (e.g. `is_x86_feature_detected!("avx2")`).
/// * `b` and `out` must each hold at least `a.len()` elements. The vector loop
///   reads and writes through raw pointers bounded only by `a.len()`, so a
///   shorter slice is an out-of-bounds access, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_mul_f32_avx2(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        b.len() >= len && out.len() >= len,
        "simd_mul_f32_avx2: b and out must be at least a.len() long"
    );

    let mut i = 0;
    while i + 8 <= len {
        // SAFETY: `i + 8 <= len`, and every slice holds at least `len` elements,
        // so each 8-lane unaligned load/store stays in bounds. The caller
        // guarantees AVX2, which implies the AVX these intrinsics need.
        unsafe {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(out.as_mut_ptr().add(i), _mm256_mul_ps(va, vb));
        }
        i += 8;
    }

    // Scalar tail: fewer than eight elements remain.
    for ((o, &x), &y) in out[i..len].iter_mut().zip(&a[i..]).zip(&b[i..len]) {
        *o = x * y;
    }
}
/// Element-wise `out[i] = a[i] * b[i]` for every `i in 0..a.len()`.
///
/// Processes four lanes at a time with 128-bit vectors. The final
/// `a.len() % 4` elements are multiplied with scalar code.
///
/// # Safety
///
/// * The running CPU must support SSE4.1 (e.g. `is_x86_feature_detected!("sse4.1")`),
///   the feature this function is compiled with.
/// * `b` and `out` must each hold at least `a.len()` elements. The vector loop
///   reads and writes through raw pointers bounded only by `a.len()`, so a
///   shorter slice is an out-of-bounds access, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_mul_f32_sse(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        b.len() >= len && out.len() >= len,
        "simd_mul_f32_sse: b and out must be at least a.len() long"
    );

    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len`, and every slice holds at least `len` elements,
        // so each 4-lane unaligned load/store stays in bounds.
        unsafe {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            _mm_storeu_ps(out.as_mut_ptr().add(i), _mm_mul_ps(va, vb));
        }
        i += 4;
    }

    // Scalar tail: fewer than four elements remain.
    for ((o, &x), &y) in out[i..len].iter_mut().zip(&a[i..]).zip(&b[i..len]) {
        *o = x * y;
    }
}
/// Portable element-wise `out[i] = a[i] * b[i]` over the first `a.len()` elements.
///
/// Elements of `b` or `out` beyond `a.len()` are ignored and left untouched.
///
/// # Panics
///
/// Panics if `b` or `out` is shorter than `a`.
#[inline]
fn scalar_mul_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    // Re-slicing up front does the length checks once, so the loop below has no
    // per-element bounds checks and can be auto-vectorized by the compiler.
    let (b, out) = (&b[..n], &mut out[..n]);
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x * y;
    }
}
/// Returns the dot product `sum(a[i] * b[i])` of two equal-length slices.
///
/// On x86_64 the kernel is chosen at runtime: AVX2 if the CPU reports it,
/// otherwise SSE4.1, otherwise a portable scalar loop. Other architectures
/// always use the scalar loop. The vector kernels accumulate in a different
/// order than the scalar loop, so results may differ in the last bits.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn simd_dot_f32(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 was just detected, and the assert above guarantees
            // that `b` is exactly as long as `a`.
            return unsafe { simd_dot_f32_avx2(a, b) };
        }
        if is_x86_feature_detected!("sse4.1") {
            // SAFETY: SSE4.1 was just detected; lengths were checked above.
            return unsafe { simd_dot_f32_sse(a, b) };
        }
    }

    scalar_dot_f32(a, b)
}
/// Dot product of `a` and the first `a.len()` elements of `b`, eight lanes at a time.
///
/// Products are accumulated into one 256-bit register. The eight lane sums
/// are then added left to right, followed by the scalar products of the
/// final `a.len() % 8` elements.
///
/// Note: the intrinsics used are plain AVX; `avx2` is the feature the
/// dispatcher checks for and the one this function is compiled with.
///
/// # Safety
///
/// * The running CPU must support AVX2 (e.g. `is_x86_feature_detected!("avx2")`).
/// * `b` must hold at least `a.len()` elements. The vector loop reads through
///   raw pointers bounded only by `a.len()`, so a shorter `b` is an
///   out-of-bounds read, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_dot_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        b.len() >= len,
        "simd_dot_f32_avx2: b must be at least a.len() long"
    );

    let mut acc = _mm256_setzero_ps();
    let mut i = 0;
    while i + 8 <= len {
        // SAFETY: `i + 8 <= len <= b.len()`, so both 8-lane unaligned loads are
        // in bounds. The caller guarantees AVX2, which implies AVX.
        unsafe {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
        }
        i += 8;
    }

    // Spill the accumulator to memory and reduce the lanes in index order.
    let mut lanes = [0.0f32; 8];
    // SAFETY: `lanes` is exactly eight f32s, the width of one 256-bit store.
    unsafe { _mm256_storeu_ps(lanes.as_mut_ptr(), acc) };
    let mut result = lanes.iter().fold(0.0f32, |sum, &lane| sum + lane);

    // Scalar tail: fewer than eight elements remain.
    for (&x, &y) in a[i..].iter().zip(&b[i..len]) {
        result += x * y;
    }
    result
}
/// Dot product of `a` and the first `a.len()` elements of `b`, four lanes at a time.
///
/// Products are accumulated into one 128-bit register. The four lane sums
/// are then added left to right, followed by the scalar products of the
/// final `a.len() % 4` elements.
///
/// # Safety
///
/// * The running CPU must support SSE4.1 (e.g. `is_x86_feature_detected!("sse4.1")`),
///   the feature this function is compiled with.
/// * `b` must hold at least `a.len()` elements. The vector loop reads through
///   raw pointers bounded only by `a.len()`, so a shorter `b` is an
///   out-of-bounds read, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_dot_f32_sse(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        b.len() >= len,
        "simd_dot_f32_sse: b must be at least a.len() long"
    );

    let mut acc = _mm_setzero_ps();
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len <= b.len()`, so both 4-lane unaligned loads are in bounds.
        unsafe {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
        }
        i += 4;
    }

    // Spill the accumulator to memory and reduce the lanes in index order.
    let mut lanes = [0.0f32; 4];
    // SAFETY: `lanes` is exactly four f32s, the width of one 128-bit store.
    unsafe { _mm_storeu_ps(lanes.as_mut_ptr(), acc) };
    let mut result = lanes.iter().fold(0.0f32, |sum, &lane| sum + lane);

    // Scalar tail: fewer than four elements remain.
    for (&x, &y) in a[i..].iter().zip(&b[i..len]) {
        result += x * y;
    }
    result
}
/// Portable dot product: the sum of `x * y` over the element pairs of `a` and `b`.
///
/// Pairing stops at the end of the shorter slice; the public wrapper checks
/// that the lengths match before calling this.
#[inline]
fn scalar_dot_f32(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Applies ReLU element-wise, writing `max(input[i], 0.0)` into `output[i]`.
///
/// On x86_64 the kernel is chosen at runtime: AVX2 if the CPU reports it,
/// otherwise SSE4.1, otherwise a portable scalar loop. Other architectures
/// always use the scalar loop.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths.
#[inline]
pub fn simd_relu_f32(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 was just detected, and the assert above guarantees
            // that `output` is exactly as long as `input`.
            return unsafe { simd_relu_f32_avx2(input, output) };
        }
        if is_x86_feature_detected!("sse4.1") {
            // SAFETY: SSE4.1 was just detected; lengths were checked above.
            return unsafe { simd_relu_f32_sse(input, output) };
        }
    }

    scalar_relu_f32(input, output)
}
/// ReLU: `output[i] = max(input[i], 0.0)` for every `i in 0..input.len()`.
///
/// Processes eight lanes at a time with 256-bit vectors. The final
/// `input.len() % 8` elements use `f32::max`.
///
/// Note: the intrinsics used are plain AVX; `avx2` is the feature the
/// dispatcher checks for and the one this function is compiled with.
///
/// # Safety
///
/// * The running CPU must support AVX2 (e.g. `is_x86_feature_detected!("avx2")`).
/// * `output` must hold at least `input.len()` elements. The vector loop
///   writes through a raw pointer bounded only by `input.len()`, so a shorter
///   `output` is an out-of-bounds write, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_relu_f32_avx2(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        output.len() >= len,
        "simd_relu_f32_avx2: output must be at least input.len() long"
    );

    let zero = _mm256_setzero_ps();
    let mut i = 0;
    while i + 8 <= len {
        // SAFETY: `i + 8 <= len <= output.len()`, so the 8-lane unaligned load
        // and store are in bounds. The caller guarantees AVX2, which implies AVX.
        unsafe {
            let v = _mm256_loadu_ps(input.as_ptr().add(i));
            // With a NaN lane, x86 MAX returns its second operand (0.0), which
            // matches `f32::max(NaN, 0.0)` in the scalar tail below.
            _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_max_ps(v, zero));
        }
        i += 8;
    }

    // Scalar tail: fewer than eight elements remain.
    for (o, &x) in output[i..len].iter_mut().zip(&input[i..]) {
        *o = x.max(0.0);
    }
}
/// ReLU: `output[i] = max(input[i], 0.0)` for every `i in 0..input.len()`.
///
/// Processes four lanes at a time with 128-bit vectors. The final
/// `input.len() % 4` elements use `f32::max`.
///
/// # Safety
///
/// * The running CPU must support SSE4.1 (e.g. `is_x86_feature_detected!("sse4.1")`),
///   the feature this function is compiled with.
/// * `output` must hold at least `input.len()` elements. The vector loop
///   writes through a raw pointer bounded only by `input.len()`, so a shorter
///   `output` is an out-of-bounds write, not a panic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_relu_f32_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    // Catch contract violations in debug builds before they turn into UB below.
    debug_assert!(
        output.len() >= len,
        "simd_relu_f32_sse: output must be at least input.len() long"
    );

    let zero = _mm_setzero_ps();
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len <= output.len()`, so the 4-lane unaligned load
        // and store are in bounds.
        unsafe {
            let v = _mm_loadu_ps(input.as_ptr().add(i));
            // With a NaN lane, x86 MAX returns its second operand (0.0), which
            // matches `f32::max(NaN, 0.0)` in the scalar tail below.
            _mm_storeu_ps(output.as_mut_ptr().add(i), _mm_max_ps(v, zero));
        }
        i += 4;
    }

    // Scalar tail: fewer than four elements remain.
    for (o, &x) in output[i..len].iter_mut().zip(&input[i..]) {
        *o = x.max(0.0);
    }
}
/// Portable ReLU: `output[i] = input[i].max(0.0)` over the first `input.len()` elements.
///
/// NaN inputs map to `0.0`, because `f32::max` returns the non-NaN argument.
/// Elements of `output` beyond `input.len()` are left untouched.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`.
#[inline]
fn scalar_relu_f32(input: &[f32], output: &mut [f32]) {
    // Re-slicing up front does the length check once, so the loop below has no
    // per-element bounds checks and can be auto-vectorized by the compiler.
    let output = &mut output[..input.len()];
    for (o, &x) in output.iter_mut().zip(input) {
        *o = x.max(0.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lengths that exercise empty input, tail-only input (shorter than one SSE
    /// or AVX vector), exact vector multiples, and a vector body plus a tail.
    const LENGTHS: [usize; 11] = [0, 1, 3, 4, 5, 7, 8, 9, 12, 16, 35];

    /// Small integer-valued test data. Every sum and product built from it is
    /// exactly representable in `f32`, so SIMD and scalar results must agree
    /// exactly, whatever the summation order.
    fn ints(len: usize, modulus: usize, shift: f32) -> Vec<f32> {
        (0..len).map(|i| (i % modulus) as f32 - shift).collect()
    }

    #[test]
    fn test_simd_add() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0f32, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let mut out = vec![0.0f32; 8];
        simd_add_f32(&a, &b, &mut out);
        assert_eq!(out, vec![9.0f32; 8]);
    }

    #[test]
    fn test_simd_mul() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0];
        let mut out = vec![0.0f32; 4];
        simd_mul_f32(&a, &b, &mut out);
        assert_eq!(out, vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_simd_dot() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![5.0f32, 6.0, 7.0, 8.0];
        let result = simd_dot_f32(&a, &b);
        assert_eq!(result, 70.0);
    }

    #[test]
    fn test_simd_relu() {
        let input = vec![-2.0f32, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0f32; 5];
        simd_relu_f32(&input, &mut output);
        assert_eq!(output, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    /// The public dispatchers must agree with the scalar reference for every
    /// body/tail split, whichever kernel the current CPU selects.
    #[test]
    fn dispatch_matches_scalar_for_all_lengths() {
        for &len in LENGTHS.iter() {
            let a = ints(len, 7, 3.0);
            let b = ints(len, 5, 1.0);
            let mut got = vec![0.0f32; len];
            let mut want = vec![0.0f32; len];

            simd_add_f32(&a, &b, &mut got);
            scalar_add_f32(&a, &b, &mut want);
            assert_eq!(got, want, "add, len = {}", len);

            simd_mul_f32(&a, &b, &mut got);
            scalar_mul_f32(&a, &b, &mut want);
            assert_eq!(got, want, "mul, len = {}", len);

            simd_relu_f32(&a, &mut got);
            scalar_relu_f32(&a, &mut want);
            assert_eq!(got, want, "relu, len = {}", len);

            assert_eq!(
                simd_dot_f32(&a, &b),
                scalar_dot_f32(&a, &b),
                "dot, len = {}",
                len
            );
        }
    }

    /// Calls each x86_64 kernel the current CPU supports directly. This
    /// exercises every vector body, even those the dispatcher would skip in
    /// favour of a wider kernel.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn x86_kernels_match_scalar() {
        let avx2 = is_x86_feature_detected!("avx2");
        let sse41 = is_x86_feature_detected!("sse4.1");
        for &len in LENGTHS.iter() {
            let a = ints(len, 7, 3.0);
            let b = ints(len, 5, 1.0);

            let mut want_add = vec![0.0f32; len];
            let mut want_mul = vec![0.0f32; len];
            let mut want_relu = vec![0.0f32; len];
            scalar_add_f32(&a, &b, &mut want_add);
            scalar_mul_f32(&a, &b, &mut want_mul);
            scalar_relu_f32(&a, &mut want_relu);
            let want_dot = scalar_dot_f32(&a, &b);

            let mut got = vec![0.0f32; len];
            // SAFETY: each kernel runs only if its target feature was detected,
            // and every slice passed in has length `len`.
            unsafe {
                if avx2 {
                    simd_add_f32_avx2(&a, &b, &mut got);
                    assert_eq!(got, want_add, "avx2 add, len = {}", len);
                    simd_mul_f32_avx2(&a, &b, &mut got);
                    assert_eq!(got, want_mul, "avx2 mul, len = {}", len);
                    simd_relu_f32_avx2(&a, &mut got);
                    assert_eq!(got, want_relu, "avx2 relu, len = {}", len);
                    assert_eq!(simd_dot_f32_avx2(&a, &b), want_dot, "avx2 dot, len = {}", len);
                }
                if sse41 {
                    simd_add_f32_sse(&a, &b, &mut got);
                    assert_eq!(got, want_add, "sse add, len = {}", len);
                    simd_mul_f32_sse(&a, &b, &mut got);
                    assert_eq!(got, want_mul, "sse mul, len = {}", len);
                    simd_relu_f32_sse(&a, &mut got);
                    assert_eq!(got, want_relu, "sse relu, len = {}", len);
                    assert_eq!(simd_dot_f32_sse(&a, &b), want_dot, "sse dot, len = {}", len);
                }
            }
        }
    }

    /// NaN lanes must become 0.0 in both the vector body and the scalar tail.
    #[test]
    fn relu_maps_nan_to_zero() {
        let input = vec![f32::NAN; 13];
        let mut output = vec![1.0f32; 13];
        simd_relu_f32(&input, &mut output);
        assert_eq!(output, vec![0.0f32; 13]);
    }

    #[test]
    #[should_panic]
    fn add_rejects_short_output() {
        let mut out = [0.0f32; 2];
        simd_add_f32(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], &mut out);
    }

    #[test]
    #[should_panic]
    fn mul_rejects_mismatched_inputs() {
        let mut out = [0.0f32; 3];
        simd_mul_f32(&[1.0, 2.0, 3.0], &[1.0, 2.0], &mut out);
    }

    #[test]
    #[should_panic]
    fn dot_rejects_mismatched_inputs() {
        let _ = simd_dot_f32(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic]
    fn relu_rejects_mismatched_lengths() {
        let mut output = [0.0f32; 1];
        simd_relu_f32(&[1.0, -1.0], &mut output);
    }
}