#[cfg(feature = "rayon")]
use rayon::prelude::*;
#[inline]
pub fn relu_simd(data: &[f32]) -> Vec<f32> {
#[cfg(target_feature = "avx2")]
{
relu_avx2(data)
}
#[cfg(all(not(target_feature = "avx2"), target_feature = "sse2"))]
{
relu_sse2(data)
}
#[cfg(not(any(target_feature = "avx2", target_feature = "sse2")))]
{
relu_scalar(data)
}
}
#[cfg(target_feature = "avx2")]
#[inline]
fn relu_avx2(data: &[f32]) -> Vec<f32> {
use std::arch::x86_64::*;
let mut result = Vec::with_capacity(data.len());
unsafe {
let zero = _mm256_setzero_ps();
let chunks = data.chunks_exact(8);
let remainder = chunks.remainder();
for chunk in chunks {
let vec = _mm256_loadu_ps(chunk.as_ptr());
let max = _mm256_max_ps(vec, zero);
let mut out = [0.0f32; 8];
_mm256_storeu_ps(out.as_mut_ptr(), max);
result.extend_from_slice(&out);
}
result.extend(remainder.iter().map(|&x| x.max(0.0)));
}
result
}
#[cfg(target_feature = "sse2")]
#[inline]
fn relu_sse2(data: &[f32]) -> Vec<f32> {
use std::arch::x86_64::*;
let mut result = Vec::with_capacity(data.len());
unsafe {
let zero = _mm_setzero_ps();
let chunks = data.chunks_exact(4);
let remainder = chunks.remainder();
for chunk in chunks {
let vec = _mm_loadu_ps(chunk.as_ptr());
let max = _mm_max_ps(vec, zero);
let mut out = [0.0f32; 4];
_mm_storeu_ps(out.as_mut_ptr(), max);
result.extend_from_slice(&out);
}
result.extend(remainder.iter().map(|&x| x.max(0.0)));
}
result
}
#[allow(dead_code)]
#[inline]
fn relu_scalar(data: &[f32]) -> Vec<f32> {
data.iter().map(|&x| x.max(0.0)).collect()
}
#[inline]
pub fn sigmoid_simd(data: &[f32]) -> Vec<f32> {
data.iter()
.map(|&x| {
1.0 / (1.0 + fast_exp(-x))
})
.collect()
}
#[inline]
fn fast_exp(x: f32) -> f32 {
let x = x.clamp(-88.0, 88.0);
if x < 0.0 {
let x = -x;
let x2 = x * x;
let x3 = x2 * x;
let x4 = x2 * x2;
1.0 / (1.0 + x + x2 * 0.5 + x3 * 0.16666667 + x4 * 0.041666667)
} else {
let x2 = x * x;
let x3 = x2 * x;
let x4 = x2 * x2;
1.0 + x + x2 * 0.5 + x3 * 0.16666667 + x4 * 0.041666667
}
}
#[inline]
pub fn gelu_simd(data: &[f32]) -> Vec<f32> {
const SQRT_2_OVER_PI: f32 = 0.797_884_6;
const COEFF: f32 = 0.044715;
data.iter()
.map(|&x| {
let inner = SQRT_2_OVER_PI * (x + COEFF * x.powi(3));
0.5 * x * (1.0 + fast_tanh(inner))
})
.collect()
}
#[inline]
fn fast_tanh(x: f32) -> f32 {
let x = x.clamp(-3.0, 3.0);
let x2 = x * x;
x * (27.0 + x2) / (27.0 + 9.0 * x2)
}
#[inline]
pub fn add_simd(a: &[f32], b: &[f32]) -> Vec<f32> {
#[cfg(target_feature = "avx2")]
{
add_avx2(a, b)
}
#[cfg(all(not(target_feature = "avx2"), target_feature = "sse2"))]
{
add_sse2(a, b)
}
#[cfg(not(any(target_feature = "avx2", target_feature = "sse2")))]
{
add_scalar(a, b)
}
}
#[cfg(target_feature = "avx2")]
#[inline]
fn add_avx2(a: &[f32], b: &[f32]) -> Vec<f32> {
use std::arch::x86_64::*;
let mut result = Vec::with_capacity(a.len());
unsafe {
let chunks_a = a.chunks_exact(8);
let chunks_b = b.chunks_exact(8);
let remainder_a = chunks_a.remainder();
let remainder_b = chunks_b.remainder();
for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
let vec_a = _mm256_loadu_ps(chunk_a.as_ptr());
let vec_b = _mm256_loadu_ps(chunk_b.as_ptr());
let sum = _mm256_add_ps(vec_a, vec_b);
let mut out = [0.0f32; 8];
_mm256_storeu_ps(out.as_mut_ptr(), sum);
result.extend_from_slice(&out);
}
result.extend(remainder_a.iter().zip(remainder_b.iter()).map(|(&x, &y)| x + y));
}
result
}
#[cfg(target_feature = "sse2")]
#[inline]
fn add_sse2(a: &[f32], b: &[f32]) -> Vec<f32> {
use std::arch::x86_64::*;
let mut result = Vec::with_capacity(a.len());
unsafe {
let chunks_a = a.chunks_exact(4);
let chunks_b = b.chunks_exact(4);
let remainder_a = chunks_a.remainder();
let remainder_b = chunks_b.remainder();
for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
let vec_a = _mm_loadu_ps(chunk_a.as_ptr());
let vec_b = _mm_loadu_ps(chunk_b.as_ptr());
let sum = _mm_add_ps(vec_a, vec_b);
let mut out = [0.0f32; 4];
_mm_storeu_ps(out.as_mut_ptr(), sum);
result.extend_from_slice(&out);
}
result.extend(remainder_a.iter().zip(remainder_b.iter()).map(|(&x, &y)| x + y));
}
result
}
#[allow(dead_code)]
#[inline]
fn add_scalar(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_relu_simd() {
let data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
let result = relu_simd(&data);
assert_eq!(result, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
}
#[test]
fn test_sigmoid_simd() {
let data = vec![0.0];
let result = sigmoid_simd(&data);
assert!((result[0] - 0.5).abs() < 0.01);
}
#[test]
fn test_add_simd() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let result = add_simd(&a, &b);
assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
}
}