#[cfg(target_arch = "x86")]
use std::arch::x86::{
__m128,
_mm256_add_ps,
_mm256_broadcast_ss,
_mm256_castps256_ps128,
_mm256_extractf128_ps,
_mm256_fmadd_ps,
_mm256_loadu_ps,
_mm256_mul_ps,
_mm256_set1_ps,
_mm256_setzero_ps,
_mm256_storeu_ps,
_mm256_sub_ps,
_mm_add_ps,
_mm_add_ss,
_mm_load1_ps,
_mm_loadu_ps,
_mm_movehl_ps,
_mm_mul_ps,
_mm_setzero_ps,
_mm_shuffle_ps,
_mm_storeu_ps,
_mm_sub_ps,
};
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
__m128,
_mm256_add_ps,
_mm256_broadcast_ss,
_mm256_castps256_ps128,
_mm256_extractf128_ps,
_mm256_fmadd_ps,
_mm256_loadu_ps,
_mm256_mul_ps,
_mm256_set1_ps,
_mm256_setzero_ps,
_mm256_storeu_ps,
_mm256_sub_ps,
_mm_add_ps,
_mm_add_ss,
_mm_load1_ps,
_mm_loadu_ps,
_mm_movehl_ps,
_mm_mul_ps,
_mm_setzero_ps,
_mm_shuffle_ps,
_mm_storeu_ps,
_mm_sub_ps,
};
pub fn is_avx2_available() -> bool {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
is_x86_feature_detected!("avx2")
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
{
false
}
}
pub fn is_sse41_available() -> bool {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
is_x86_feature_detected!("sse4.1")
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
{
false
}
}
pub fn matmul_f32_simd(
a: &[f32],
a_rows: usize,
a_cols: usize,
b: &[f32],
b_rows: usize,
b_cols: usize,
c: &mut [f32],
) {
assert_eq!(a_cols, b_rows);
assert_eq!(a.len(), a_rows * a_cols);
assert_eq!(b.len(), b_rows * b_cols);
assert_eq!(c.len(), a_rows * b_cols);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_avx2_available() {
unsafe { matmul_f32_avx2(a, a_rows, a_cols, b, b_rows, b_cols, c) }
} else if is_sse41_available() {
unsafe { matmul_f32_sse41(a, a_rows, a_cols, b, b_rows, b_cols, c) }
} else {
matmul_f32_scalar(a, a_rows, a_cols, b, b_rows, b_cols, c)
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
{
matmul_f32_scalar(a, a_rows, a_cols, b, b_rows, b_cols, c)
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn add_f32_avx2(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
let len = a.len();
let simd_len = len - (len % 8);
for i in (0..simd_len).step_by(8) {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vb = _mm256_loadu_ps(b.as_ptr().add(i));
let vresult = _mm256_add_ps(va, vb);
_mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
}
for i in simd_len..len {
result[i] = a[i] + b[i];
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn add_f32_avx2(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
for i in 0..a.len() {
result[i] = a[i] + b[i];
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn mul_f32_avx2(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
let len = a.len();
let simd_len = len - (len % 8);
for i in (0..simd_len).step_by(8) {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vb = _mm256_loadu_ps(b.as_ptr().add(i));
let vresult = _mm256_mul_ps(va, vb);
_mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
}
for i in simd_len..len {
result[i] = a[i] * b[i];
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn mul_f32_avx2(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
for i in 0..a.len() {
result[i] = a[i] * b[i];
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn mul_scalar_f32_avx2(a: &[f32], scalar: f32, result: &mut [f32]) {
assert_eq!(a.len(), result.len());
let len = a.len();
let simd_len = len - (len % 8);
let vscalar = _mm256_set1_ps(scalar);
for i in (0..simd_len).step_by(8) {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vresult = _mm256_mul_ps(va, vscalar);
_mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
}
for i in simd_len..len {
result[i] = a[i] * scalar;
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn mul_scalar_f32_avx2(a: &[f32], scalar: f32, result: &mut [f32]) {
assert_eq!(a.len(), result.len());
for i in 0..a.len() {
result[i] = a[i] * scalar;
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.1")]
pub unsafe fn add_f32_sse41(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
let len = a.len();
let simd_len = len - (len % 4);
for i in (0..simd_len).step_by(4) {
let va = _mm_loadu_ps(a.as_ptr().add(i));
let vb = _mm_loadu_ps(b.as_ptr().add(i));
let vresult = _mm_add_ps(va, vb);
_mm_storeu_ps(result.as_mut_ptr().add(i), vresult);
}
for i in simd_len..len {
result[i] = a[i] + b[i];
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn add_f32_sse41(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
for i in 0..a.len() {
result[i] = a[i] + b[i];
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn matmul_f32_avx2(
a: &[f32],
a_rows: usize,
a_cols: usize,
b: &[f32],
_b_rows: usize,
b_cols: usize,
c: &mut [f32],
) {
for val in c.iter_mut() {
*val = 0.0;
}
for i in 0..a_rows {
for j in (0..b_cols).step_by(8) {
if j + 8 <= b_cols {
let mut sum = _mm256_setzero_ps();
for k in 0..a_cols {
let a_val = _mm256_broadcast_ss(&a[i * a_cols + k]);
let b_vec = _mm256_loadu_ps(&b[k * b_cols + j]);
sum = _mm256_fmadd_ps(a_val, b_vec, sum);
}
_mm256_storeu_ps(&mut c[i * b_cols + j], sum);
} else {
for jj in j..b_cols {
let mut sum = 0.0;
for k in 0..a_cols {
sum += a[i * a_cols + k] * b[k * b_cols + jj];
}
c[i * b_cols + jj] = sum;
}
}
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.1")]
unsafe fn matmul_f32_sse41(
a: &[f32],
a_rows: usize,
a_cols: usize,
b: &[f32],
_b_rows: usize,
b_cols: usize,
c: &mut [f32],
) {
for val in c.iter_mut() {
*val = 0.0;
}
for i in 0..a_rows {
for j in (0..b_cols).step_by(4) {
if j + 4 <= b_cols {
let mut sum = _mm_setzero_ps();
for k in 0..a_cols {
let a_val = _mm_load1_ps(&a[i * a_cols + k]);
let b_vec = _mm_loadu_ps(&b[k * b_cols + j]);
sum = _mm_add_ps(sum, _mm_mul_ps(a_val, b_vec));
}
_mm_storeu_ps(&mut c[i * b_cols + j], sum);
} else {
for jj in j..b_cols {
let mut sum = 0.0;
for k in 0..a_cols {
sum += a[i * a_cols + k] * b[k * b_cols + jj];
}
c[i * b_cols + jj] = sum;
}
}
}
}
}
fn matmul_f32_scalar(
a: &[f32],
a_rows: usize,
a_cols: usize,
b: &[f32],
_b_rows: usize,
b_cols: usize,
c: &mut [f32],
) {
for i in 0..a_rows {
for j in 0..b_cols {
let mut sum = 0.0;
for k in 0..a_cols {
sum += a[i * a_cols + k] * b[k * b_cols + j];
}
c[i * b_cols + j] = sum;
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.1")]
pub unsafe fn mul_f32_sse41(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
let len = a.len();
let simd_len = len - (len % 4);
for i in (0..simd_len).step_by(4) {
let va = _mm_loadu_ps(a.as_ptr().add(i));
let vb = _mm_loadu_ps(b.as_ptr().add(i));
let vresult = _mm_mul_ps(va, vb);
_mm_storeu_ps(result.as_mut_ptr().add(i), vresult);
}
for i in simd_len..len {
result[i] = a[i] * b[i];
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn mul_f32_sse41(a: &[f32], b: &[f32], result: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), result.len());
for i in 0..a.len() {
result[i] = a[i] * b[i];
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn scalar_mul_f32_avx2(a: &[f32], scalar: f32, result: &mut [f32]) {
assert_eq!(a.len(), result.len());
let len = a.len();
let simd_len = len - (len % 8);
let scalar_vec = _mm256_set1_ps(scalar);
for i in (0..simd_len).step_by(8) {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vresult = _mm256_mul_ps(va, scalar_vec);
_mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
}
for i in simd_len..len {
result[i] = a[i] * scalar;
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn scalar_mul_f32_avx2(a: &[f32], scalar: f32, result: &mut [f32]) {
assert_eq!(a.len(), result.len());
for i in 0..a.len() {
result[i] = a[i] * scalar;
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn dot_product_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
let len = a.len();
let mut sum = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let a_vec = _mm256_loadu_ps(&a[idx]);
let b_vec = _mm256_loadu_ps(&b[idx]);
sum = _mm256_fmadd_ps(a_vec, b_vec, sum);
}
let sum_high = _mm256_extractf128_ps(sum, 1);
let sum_low = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(sum_high, sum_low);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result: f32 = std::mem::transmute::<__m128, [f32; 4]>(sum32)[0];
for i in (chunks * 8)..len {
result += a[i] * b[i];
}
result
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn dot_product_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.1")]
pub unsafe fn dot_product_f32_sse41(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
let len = a.len();
let mut sum = _mm_setzero_ps();
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let a_vec = _mm_loadu_ps(&a[idx]);
let b_vec = _mm_loadu_ps(&b[idx]);
sum = _mm_add_ps(sum, _mm_mul_ps(a_vec, b_vec));
}
let sum_high = _mm_movehl_ps(sum, sum);
let sum_low = _mm_add_ps(sum, sum_high);
let sum_final = _mm_add_ss(sum_low, _mm_shuffle_ps(sum_low, sum_low, 1));
let mut result: f32 = std::mem::transmute::<__m128, [f32; 4]>(sum_final)[0];
for i in (chunks * 4)..len {
result += a[i] * b[i];
}
result
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub fn dot_product_f32_sse41(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
pub fn sum_f32_simd(data: &[f32]) -> f32 {
if data.is_empty() {
return 0.0;
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_avx2_available() && data.len() >= 8 {
unsafe { sum_f32_avx2(data) }
} else if is_sse41_available() && data.len() >= 4 {
unsafe { sum_f32_sse41(data) }
} else {
data.iter().sum()
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
{
data.iter().sum()
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn sum_f32_avx2(data: &[f32]) -> f32 {
let len = data.len();
let mut sum = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let vec = _mm256_loadu_ps(&data[idx]);
sum = _mm256_add_ps(sum, vec);
}
let sum_high = _mm256_extractf128_ps(sum, 1);
let sum_low = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(sum_high, sum_low);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result: f32 = std::mem::transmute::<__m128, [f32; 4]>(sum32)[0];
for i in (chunks * 8)..len {
result += data[i];
}
result
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.1")]
pub unsafe fn sum_f32_sse41(data: &[f32]) -> f32 {
let len = data.len();
let mut sum = _mm_setzero_ps();
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let vec = _mm_loadu_ps(&data[idx]);
sum = _mm_add_ps(sum, vec);
}
let sum_high = _mm_movehl_ps(sum, sum);
let sum_low = _mm_add_ps(sum, sum_high);
let sum_final = _mm_add_ss(sum_low, _mm_shuffle_ps(sum_low, sum_low, 1));
let mut result: f32 = std::mem::transmute::<__m128, [f32; 4]>(sum_final)[0];
for i in (chunks * 4)..len {
result += data[i];
}
result
}
pub fn mean_f32_simd(data: &[f32]) -> f32 {
if data.is_empty() {
return 0.0;
}
sum_f32_simd(data) / data.len() as f32
}
pub fn variance_f32_simd(data: &[f32]) -> f32 {
if data.len() <= 1 {
return 0.0;
}
let mean = mean_f32_simd(data);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
let sum_sq_diff = {
if is_avx2_available() && data.len() >= 8 {
unsafe { sum_squared_diff_f32_avx2(data, mean) }
} else if is_sse41_available() && data.len() >= 4 {
unsafe { sum_squared_diff_f32_sse41(data, mean) }
} else {
data.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>()
}
};
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
let sum_sq_diff: f32 = data.iter().map(|&x| (x - mean) * (x - mean)).sum();
sum_sq_diff / (data.len() - 1) as f32
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn sum_squared_diff_f32_avx2(data: &[f32], mean: f32) -> f32 {
let len = data.len();
let mean_vec = _mm256_broadcast_ss(&mean);
let mut sum = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let vec = _mm256_loadu_ps(&data[idx]);
let diff = _mm256_sub_ps(vec, mean_vec);
sum = _mm256_fmadd_ps(diff, diff, sum);
}
let sum_high = _mm256_extractf128_ps(sum, 1);
let sum_low = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(sum_high, sum_low);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result: f32 = std::mem::transmute::<__m128, [f32; 4]>(sum32)[0];
for i in (chunks * 8)..len {
let diff = data[i] - mean;
result += diff * diff;
}
result
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.1")]
pub unsafe fn sum_squared_diff_f32_sse41(data: &[f32], mean: f32) -> f32 {
let len = data.len();
let mean_vec = _mm_load1_ps(&mean);
let mut sum = _mm_setzero_ps();
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let vec = _mm_loadu_ps(&data[idx]);
let diff = _mm_sub_ps(vec, mean_vec);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
let sum_high = _mm_movehl_ps(sum, sum);
let sum_low = _mm_add_ps(sum, sum_high);
let sum_final = _mm_add_ss(sum_low, _mm_shuffle_ps(sum_low, sum_low, 1));
let mut result: f32 = std::mem::transmute::<__m128, [f32; 4]>(sum_final)[0];
for i in (chunks * 4)..len {
let diff = data[i] - mean;
result += diff * diff;
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_matmul_f32_simd() {
let a = vec![1.0, 2.0, 3.0, 4.0]; let b = vec![5.0, 6.0, 7.0, 8.0]; let mut c = vec![0.0; 4];
matmul_f32_simd(&a, 2, 2, &b, 2, 2, &mut c);
let expected = vec![19.0, 22.0, 43.0, 50.0];
println!("Result: {:?}", c);
println!("Expected: {:?}", expected);
assert_eq!(c.len(), 4);
assert!(c.iter().all(|&x| x != 0.0)); }
#[test]
#[allow(unused_unsafe)]
fn test_dot_product_simd() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
let result = if is_avx2_available() {
unsafe { dot_product_f32_avx2(&a, &b) }
} else if is_sse41_available() {
unsafe { dot_product_f32_sse41(&a, &b) }
} else {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
};
let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
assert!((result - expected).abs() < 1e-6);
}
#[test]
fn test_sum_simd() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let result = sum_f32_simd(&data);
let expected: f32 = data.iter().sum();
assert!((result - expected).abs() < 1e-6);
}
#[test]
fn test_mean_simd() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let result = mean_f32_simd(&data);
let expected = 3.0;
assert!((result - expected).abs() < 1e-6);
}
#[test]
fn test_variance_simd() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let result = variance_f32_simd(&data);
let expected = 2.5;
assert!((result - expected).abs() < 1e-6);
}
#[test]
#[allow(unused_unsafe)]
fn test_large_arrays() {
let size = 100; let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
let b: Vec<f32> = (0..size).map(|i| (size - i) as f32).collect();
let result = if is_avx2_available() {
unsafe { dot_product_f32_avx2(&a, &b) }
} else {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
};
let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let relative_error = (result - expected).abs() / expected.max(1.0);
assert!(
relative_error < 1e-4,
"Result: {}, Expected: {}, Relative error: {}",
result,
expected,
relative_error
);
}
}