use crate::simd_ops::SimdUnifiedOps;
use ::ndarray::{Array, Array1, ArrayBase, ArrayView1, ArrayView2, Data, Dimension, Zip};
use num_traits::Float;
use crate::simd::detect::get_cpu_features;
use crate::simd::dot::{simd_mul_f32, simd_mul_f32_fast};
use crate::simd::traits::SimdOps;
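/// Applies `op` element-wise over two equally shaped arrays via `ndarray::Zip`.
/// `op` is invoked per element, so this is the generic scalar path that the
/// specialized kernels below avoid.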
#[allow(dead_code)]
pub fn simd_binary_op<F, S1, S2, D>(
    a: &ArrayBase<S1, D>,
    b: &ArrayBase<S2, D>,
op: fn(F, F) -> F,
) -> Array<F, D>
where
F: SimdOps + Float + SimdUnifiedOps,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
D: Dimension,
{
let mut result = Array::zeros(a.raw_dim());
Zip::from(&mut result)
.and(a)
.and(b)
.for_each(|r, &a, &b| *r = op(a, b));
result
}
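/// Element-wise f32 product tuned for AVX2. When all three buffers are 32-byte
/// aligned and `len >= 64`, uses aligned loads/stores, an 8x-unrolled
/// 64-element loop, and software prefetch 256 elements ahead; the remainder is
/// handled by an unaligned 8-wide loop and a scalar tail. Falls back to scalar
/// code (or plain `ndarray` multiplication off x86_64) otherwise.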
#[allow(dead_code)]
pub fn simd_mul_f32_ultra(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays must have the same length");
#[cfg(target_arch = "x86_64")]
{
        use std::arch::x86_64::*;
        let len = a.len();
        let mut result = vec![0.0f32; len];
        // The raw-pointer loops below require contiguous (unit-stride) views.
        let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
        let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
        let result_ptr = result.as_mut_ptr();
        let features = get_cpu_features();
if features.has_avx2 {
unsafe {
let mut i = 0;
                // Prefetch 1 KiB (256 f32s) ahead of the streaming position.
                const PREFETCH_DISTANCE: usize = 256;
                // The aligned fast path needs all three buffers 32-byte aligned.
                let a_aligned = (a_ptr as usize) % 32 == 0;
                let b_aligned = (b_ptr as usize) % 32 == 0;
                let result_aligned = (result_ptr as usize) % 32 == 0;
if a_aligned && b_aligned && result_aligned && len >= 64 {
while i + 64 <= len {
if i + PREFETCH_DISTANCE < len {
_mm_prefetch(
a_ptr.add(i + PREFETCH_DISTANCE) as *const i8,
_MM_HINT_T0,
);
_mm_prefetch(
b_ptr.add(i + PREFETCH_DISTANCE) as *const i8,
_MM_HINT_T0,
);
}
let a_vec1 = _mm256_load_ps(a_ptr.add(i));
let b_vec1 = _mm256_load_ps(b_ptr.add(i));
let result_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
let a_vec2 = _mm256_load_ps(a_ptr.add(i + 8));
let b_vec2 = _mm256_load_ps(b_ptr.add(i + 8));
let result_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
let a_vec3 = _mm256_load_ps(a_ptr.add(i + 16));
let b_vec3 = _mm256_load_ps(b_ptr.add(i + 16));
let result_vec3 = _mm256_mul_ps(a_vec3, b_vec3);
let a_vec4 = _mm256_load_ps(a_ptr.add(i + 24));
let b_vec4 = _mm256_load_ps(b_ptr.add(i + 24));
let result_vec4 = _mm256_mul_ps(a_vec4, b_vec4);
let a_vec5 = _mm256_load_ps(a_ptr.add(i + 32));
let b_vec5 = _mm256_load_ps(b_ptr.add(i + 32));
let result_vec5 = _mm256_mul_ps(a_vec5, b_vec5);
let a_vec6 = _mm256_load_ps(a_ptr.add(i + 40));
let b_vec6 = _mm256_load_ps(b_ptr.add(i + 40));
let result_vec6 = _mm256_mul_ps(a_vec6, b_vec6);
let a_vec7 = _mm256_load_ps(a_ptr.add(i + 48));
let b_vec7 = _mm256_load_ps(b_ptr.add(i + 48));
let result_vec7 = _mm256_mul_ps(a_vec7, b_vec7);
let a_vec8 = _mm256_load_ps(a_ptr.add(i + 56));
let b_vec8 = _mm256_load_ps(b_ptr.add(i + 56));
let result_vec8 = _mm256_mul_ps(a_vec8, b_vec8);
_mm256_store_ps(result_ptr.add(i), result_vec1);
_mm256_store_ps(result_ptr.add(i + 8), result_vec2);
_mm256_store_ps(result_ptr.add(i + 16), result_vec3);
_mm256_store_ps(result_ptr.add(i + 24), result_vec4);
_mm256_store_ps(result_ptr.add(i + 32), result_vec5);
_mm256_store_ps(result_ptr.add(i + 40), result_vec6);
_mm256_store_ps(result_ptr.add(i + 48), result_vec7);
_mm256_store_ps(result_ptr.add(i + 56), result_vec8);
i += 64;
}
}
while i + 8 <= len {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
while i < len {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
}
} else {
for i in 0..len {
result[i] = a[i] * b[i];
}
}
return Array1::from_vec(result);
}
#[cfg(not(target_arch = "x86_64"))]
{
        a * b
}
}
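/// Element-wise fused multiply-add: `a[i] * b[i] + c[i]` for f32. Uses
/// hardware FMA (`_mm256_fmadd_ps` on x86_64, `vfmaq_f32` on AArch64) when
/// available, which rounds once instead of twice; otherwise falls back to a
/// separate multiply and add.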
#[allow(dead_code)]
pub fn simd_fused_multiply_add_f32(
a: &ArrayView1<f32>,
b: &ArrayView1<f32>,
c: &ArrayView1<f32>,
) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays a and b must have the same length");
assert_eq!(a.len(), c.len(), "Arrays a and c must have the same length");
let len = a.len();
let mut result = Vec::with_capacity(len);
#[cfg(target_arch = "x86_64")]
{
        use std::arch::x86_64::*;
        let features = get_cpu_features();
        // _mm256_fmadd_ps requires both FMA and AVX support on the running CPU.
        if features.has_fma && features.has_avx2 {
            // The raw-pointer loop below requires contiguous inputs.
            let a_s = a.as_slice().expect("Array a must be contiguous");
            let b_s = b.as_slice().expect("Array b must be contiguous");
            let c_s = c.as_slice().expect("Array c must be contiguous");
            unsafe {
                let mut i = 0;
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_s.as_ptr().add(i));
                    let b_vec = _mm256_loadu_ps(b_s.as_ptr().add(i));
                    let c_vec = _mm256_loadu_ps(c_s.as_ptr().add(i));
                    // Fused multiply-add: a * b + c with a single rounding step.
                    let result_vec = _mm256_fmadd_ps(a_vec, b_vec, c_vec);
                    let mut temp = [0.0f32; 8];
                    _mm256_storeu_ps(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 8;
                }
                for j in i..len {
                    result.push(a[j].mul_add(b[j], c[j]));
                }
            }
        } else if features.has_avx2 {
            let a_s = a.as_slice().expect("Array a must be contiguous");
            let b_s = b.as_slice().expect("Array b must be contiguous");
            let c_s = c.as_slice().expect("Array c must be contiguous");
            unsafe {
                let mut i = 0;
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_s.as_ptr().add(i));
                    let b_vec = _mm256_loadu_ps(b_s.as_ptr().add(i));
                    let c_vec = _mm256_loadu_ps(c_s.as_ptr().add(i));
                    // No FMA: separate multiply and add (two rounding steps).
                    let mul_result = _mm256_mul_ps(a_vec, b_vec);
                    let result_vec = _mm256_add_ps(mul_result, c_vec);
                    let mut temp = [0.0f32; 8];
                    _mm256_storeu_ps(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 8;
                }
                for j in i..len {
                    result.push(a[j] * b[j] + c[j]);
                }
            }
        } else {
            for i in 0..len {
                result.push(a[i] * b[i] + c[i]);
            }
        }
}
#[cfg(target_arch = "aarch64")]
{
use std::arch::aarch64::*;
        if std::arch::is_aarch64_feature_detected!("neon") {
            let a_s = a.as_slice().expect("Array a must be contiguous");
            let b_s = b.as_slice().expect("Array b must be contiguous");
            let c_s = c.as_slice().expect("Array c must be contiguous");
            unsafe {
                let mut i = 0;
                while i + 4 <= len {
                    let a_vec = vld1q_f32(a_s.as_ptr().add(i));
                    let b_vec = vld1q_f32(b_s.as_ptr().add(i));
                    let c_vec = vld1q_f32(c_s.as_ptr().add(i));
                    // vfmaq_f32(c, a, b) computes c + a * b with fused rounding.
                    let result_vec = vfmaq_f32(c_vec, a_vec, b_vec);
                    let mut temp = [0.0f32; 4];
                    vst1q_f32(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 4;
                }
for j in i..len {
result.push(a[j].mul_add(b[j], c[j]));
}
}
} else {
for i in 0..len {
result.push(a[i] * b[i] + c[i]);
}
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
for i in 0..len {
result.push(a[i] * b[i] + c[i]);
}
}
Array1::from_vec(result)
}
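/// f64 counterpart of `simd_fused_multiply_add_f32`: 4-wide AVX2/FMA lanes on
/// x86_64, 2-wide NEON lanes on AArch64, scalar `mul_add` elsewhere.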
#[allow(dead_code)]
pub fn simd_fused_multiply_add_f64(
a: &ArrayView1<f64>,
b: &ArrayView1<f64>,
c: &ArrayView1<f64>,
) -> Array1<f64> {
assert_eq!(a.len(), b.len(), "Arrays a and b must have the same length");
assert_eq!(a.len(), c.len(), "Arrays a and c must have the same length");
let len = a.len();
let mut result = Vec::with_capacity(len);
#[cfg(target_arch = "x86_64")]
{
        use std::arch::x86_64::*;
        let features = get_cpu_features();
        // _mm256_fmadd_pd requires both FMA and AVX support on the running CPU.
        if features.has_fma && features.has_avx2 {
            // The raw-pointer loop below requires contiguous inputs.
            let a_s = a.as_slice().expect("Array a must be contiguous");
            let b_s = b.as_slice().expect("Array b must be contiguous");
            let c_s = c.as_slice().expect("Array c must be contiguous");
            unsafe {
                let mut i = 0;
                while i + 4 <= len {
                    let a_vec = _mm256_loadu_pd(a_s.as_ptr().add(i));
                    let b_vec = _mm256_loadu_pd(b_s.as_ptr().add(i));
                    let c_vec = _mm256_loadu_pd(c_s.as_ptr().add(i));
                    // Fused multiply-add: a * b + c with a single rounding step.
                    let result_vec = _mm256_fmadd_pd(a_vec, b_vec, c_vec);
                    let mut temp = [0.0f64; 4];
                    _mm256_storeu_pd(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 4;
                }
                for j in i..len {
                    result.push(a[j].mul_add(b[j], c[j]));
                }
            }
        } else if features.has_avx2 {
            let a_s = a.as_slice().expect("Array a must be contiguous");
            let b_s = b.as_slice().expect("Array b must be contiguous");
            let c_s = c.as_slice().expect("Array c must be contiguous");
            unsafe {
                let mut i = 0;
                while i + 4 <= len {
                    let a_vec = _mm256_loadu_pd(a_s.as_ptr().add(i));
                    let b_vec = _mm256_loadu_pd(b_s.as_ptr().add(i));
                    let c_vec = _mm256_loadu_pd(c_s.as_ptr().add(i));
                    // No FMA: separate multiply and add (two rounding steps).
                    let mul_result = _mm256_mul_pd(a_vec, b_vec);
                    let result_vec = _mm256_add_pd(mul_result, c_vec);
                    let mut temp = [0.0f64; 4];
                    _mm256_storeu_pd(temp.as_mut_ptr(), result_vec);
                    result.extend_from_slice(&temp);
                    i += 4;
                }
                for j in i..len {
                    result.push(a[j] * b[j] + c[j]);
                }
            }
        } else {
            for i in 0..len {
                result.push(a[i].mul_add(b[i], c[i]));
            }
        }
}
#[cfg(target_arch = "aarch64")]
{
use std::arch::aarch64::*;
        // NEON is a mandatory feature of AArch64, so no runtime check is needed.
        let a_s = a.as_slice().expect("Array a must be contiguous");
        let b_s = b.as_slice().expect("Array b must be contiguous");
        let c_s = c.as_slice().expect("Array c must be contiguous");
        unsafe {
            let mut i = 0;
            while i + 2 <= len {
                let a_vec = vld1q_f64(a_s.as_ptr().add(i));
                let b_vec = vld1q_f64(b_s.as_ptr().add(i));
                let c_vec = vld1q_f64(c_s.as_ptr().add(i));
let result_vec = vfmaq_f64(c_vec, a_vec, b_vec);
let mut temp = [0.0f64; 2];
vst1q_f64(temp.as_mut_ptr(), result_vec);
result.extend_from_slice(&temp);
i += 2;
}
for j in i..len {
result.push(a[j].mul_add(b[j], c[j]));
}
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
for i in 0..len {
result.push(a[i].mul_add(b[i], c[i]));
}
}
Array1::from(result)
}
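/// BLAS-style GEMV: computes `y = alpha * (a @ x) + beta * y`.
///
/// Assumes the underlying `f32::simd_gemv(a, x, beta, y)` contract is
/// `y = a @ x + beta * y`; see the body for how `alpha` is folded in.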
#[allow(dead_code)]
pub fn simd_gemv_cache_optimized_f32(
alpha: f32,
a: &ArrayView2<f32>,
x: &ArrayView1<f32>,
beta: f32,
y: &mut Array1<f32>,
) {
    // Assumed contract: `f32::simd_gemv(a, x, beta', y)` computes
    // y = a @ x + beta' * y. Folding alpha into beta (beta' = beta / alpha) and
    // rescaling afterwards yields y = alpha * (a @ x) + beta * y; scaling y
    // after calling with the raw beta would wrongly scale beta * y by alpha too.
    if alpha == 0.0 {
        // Degenerate case: only the beta * y term survives.
        for elem in y.iter_mut() {
            *elem *= beta;
        }
        return;
    }
    f32::simd_gemv(a, x, beta / alpha, y);
    if alpha != 1.0 {
        for elem in y.iter_mut() {
            *elem *= alpha;
        }
    }
}
#[allow(dead_code)]
pub fn simd_add_cache_optimized_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
f32::simd_add_cache_optimized(a, b)
}
#[allow(dead_code)]
pub fn simd_fma_advanced_optimized_f32(
a: &ArrayView1<f32>,
b: &ArrayView1<f32>,
c: &ArrayView1<f32>,
) -> Array1<f32> {
f32::simd_fma_advanced_optimized(a, b, c)
}
#[allow(dead_code)]
pub fn simd_adaptive_add_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
f32::simd_add_adaptive(a, b)
}
#[allow(dead_code)]
pub fn simd_add_cache_optimized_f64(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
f64::simd_add_cache_optimized(a, b)
}
#[allow(dead_code)]
pub fn simd_fma_advanced_optimized_f64(
a: &ArrayView1<f64>,
b: &ArrayView1<f64>,
c: &ArrayView1<f64>,
) -> Array1<f64> {
f64::simd_fma_advanced_optimized(a, b, c)
}
#[allow(dead_code)]
pub fn simd_adaptive_add_f64(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
f64::simd_add_adaptive(a, b)
}
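/// Element-wise f32 product with unaligned AVX2 loops unrolled 8x (64
/// elements), 4x (32), and 1x (8), plus a scalar tail; delegates to
/// `simd_mul_f32_fast` when AVX2 is unavailable.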
#[allow(dead_code)]
pub fn simd_mul_f32_blazing(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;
        let len = a.len();
        if get_cpu_features().has_avx2 {
            // Allocate the output here; the non-AVX2 path below delegates to
            // simd_mul_f32_fast, which manages its own buffer.
            let mut result = vec![0.0f32; len];
            let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
unsafe {
let mut i = 0;
while i + 64 <= len {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));
let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
let r3 = _mm256_mul_ps(a3, b3);
let r4 = _mm256_mul_ps(a4, b4);
let r5 = _mm256_mul_ps(a5, b5);
let r6 = _mm256_mul_ps(a6, b6);
let r7 = _mm256_mul_ps(a7, b7);
let r8 = _mm256_mul_ps(a8, b8);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
_mm256_storeu_ps(result_ptr.add(i + 16), r3);
_mm256_storeu_ps(result_ptr.add(i + 24), r4);
_mm256_storeu_ps(result_ptr.add(i + 32), r5);
_mm256_storeu_ps(result_ptr.add(i + 40), r6);
_mm256_storeu_ps(result_ptr.add(i + 48), r7);
_mm256_storeu_ps(result_ptr.add(i + 56), r8);
i += 64;
}
while i + 32 <= len {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
let r3 = _mm256_mul_ps(a3, b3);
let r4 = _mm256_mul_ps(a4, b4);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
_mm256_storeu_ps(result_ptr.add(i + 16), r3);
_mm256_storeu_ps(result_ptr.add(i + 24), r4);
i += 32;
}
while i + 8 <= len {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
while i < len {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
return Array1::from(result);
}
}
}
simd_mul_f32_fast(a, b)
}
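/// Element-wise f32 product that consumes one 64-byte cache line (16 f32s) of
/// each input per iteration; delegates to `simd_mul_f32_fast` without AVX2.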
#[allow(dead_code)]
pub fn simd_mul_f32_cache_optimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;
        let len = a.len();
        if get_cpu_features().has_avx2 {
            let mut result = vec![0.0f32; len];
            let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
unsafe {
let mut i = 0;
                // Two 8-lane vectors = 16 f32s = one 64-byte cache line.
                const CACHE_LINE_ELEMENTS: usize = 16;
while i + CACHE_LINE_ELEMENTS <= len {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
i += CACHE_LINE_ELEMENTS;
}
while i + 8 <= len {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
while i < len {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
return Array1::from(result);
}
}
}
simd_mul_f32_fast(a, b)
}
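/// Minimal AVX2 multiply kernel: one 8-wide loop plus a scalar tail, with no
/// unrolling or prefetching. The adaptive dispatchers pick it for small
/// inputs, where setup overhead dominates.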
#[allow(dead_code)]
pub fn simd_mul_f32_lightweight(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    let len = a.len();
    let mut result = vec![0.0f32; len];
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;
        if std::arch::is_x86_feature_detected!("avx2") {
            let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
            unsafe {
let mut i = 0;
while i + 8 <= len {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
while i < len {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
return Array1::from(result);
}
}
}
for i in 0..len {
result[i] = a[i] * b[i];
}
Array1::from_vec(result)
}
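/// Element-wise f32 product gated on AVX-512F detection. Note that the body
/// still issues 256-bit AVX2 instructions (presumably to sidestep AVX-512
/// frequency throttling); the gate only selects the deeper unrolling, not
/// 512-bit registers.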
#[allow(dead_code)]
pub fn simd_mul_f32_avx512(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;
        let len = a.len();
        // Gate on AVX-512F, but execute with 256-bit operations (see doc above).
        if get_cpu_features().has_avx512f {
            let mut result = vec![0.0f32; len];
            let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
            unsafe {
let mut i = 0;
while i + 64 <= len {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));
let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
let r3 = _mm256_mul_ps(a3, b3);
let r4 = _mm256_mul_ps(a4, b4);
let r5 = _mm256_mul_ps(a5, b5);
let r6 = _mm256_mul_ps(a6, b6);
let r7 = _mm256_mul_ps(a7, b7);
let r8 = _mm256_mul_ps(a8, b8);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
_mm256_storeu_ps(result_ptr.add(i + 16), r3);
_mm256_storeu_ps(result_ptr.add(i + 24), r4);
_mm256_storeu_ps(result_ptr.add(i + 32), r5);
_mm256_storeu_ps(result_ptr.add(i + 40), r6);
_mm256_storeu_ps(result_ptr.add(i + 48), r7);
_mm256_storeu_ps(result_ptr.add(i + 56), r8);
i += 64;
}
while i + 16 <= len {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
i += 16;
}
while i < len {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
return Array1::from_vec(result);
}
}
}
simd_mul_f32_lightweight(a, b)
}
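/// Element-wise f32 product whose remainder is handled with an AVX masked
/// load/store rather than a scalar loop, keeping the kernel free of
/// per-element branches.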
#[allow(dead_code)]
pub fn simd_mul_f32_branchfree(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    let len = a.len();
    let mut result = vec![0.0f32; len];
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;
        if std::arch::is_x86_feature_detected!("avx2") {
            let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
            unsafe {
                let mut i = 0;
                // Round len down to a multiple of 8 for the full-vector loop.
                let vector_len = len & !7;
while i < vector_len {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
                if i < len {
                    let remaining = len - i;
                    // Branch-free remainder: build a lane mask (sign bit set =
                    // lane active) and use AVX masked load/store, so there is
                    // no scalar tail loop and no access past the buffer ends.
                    let mut mask_data = [0i32; 8];
                    for j in 0..remaining {
                        mask_data[j] = -1;
                    }
                    let mask = _mm256_loadu_si256(mask_data.as_ptr() as *const __m256i);
                    let a_vec = _mm256_maskload_ps(a_ptr.add(i), mask);
                    let b_vec = _mm256_maskload_ps(b_ptr.add(i), mask);
                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    _mm256_maskstore_ps(result_ptr.add(i), mask, result_vec);
                }
return Array1::from_vec(result);
}
}
}
for i in 0..len {
result[i] = a[i] * b[i];
}
Array1::from_vec(result)
}
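/// Element-wise f32 product that interleaves loads, multiplies, and stores
/// across four cache lines (64 f32s) per iteration to keep many memory
/// requests in flight at once.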
#[allow(dead_code)]
pub fn simd_mul_f32_bandwidth_saturated(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;
        let len = a.len();
        if std::arch::is_x86_feature_detected!("avx2") {
            let mut result = vec![0.0f32; len];
            let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
            let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
            let result_ptr = result.as_mut_ptr();
            unsafe {
                let mut i = 0;
                // Four 64-byte cache lines (64 f32s) per iteration, with loads
                // and multiplies interleaved to keep requests in flight.
                const CACHE_LINES_PER_ITERATION: usize = 4;
                const ELEMENTS_PER_ITERATION: usize = CACHE_LINES_PER_ITERATION * 16;
while i + ELEMENTS_PER_ITERATION <= len {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
let r3 = _mm256_mul_ps(a3, b3);
let r4 = _mm256_mul_ps(a4, b4);
let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));
let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));
let r5 = _mm256_mul_ps(a5, b5);
let r6 = _mm256_mul_ps(a6, b6);
let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));
let r7 = _mm256_mul_ps(a7, b7);
let r8 = _mm256_mul_ps(a8, b8);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
_mm256_storeu_ps(result_ptr.add(i + 16), r3);
_mm256_storeu_ps(result_ptr.add(i + 24), r4);
_mm256_storeu_ps(result_ptr.add(i + 32), r5);
_mm256_storeu_ps(result_ptr.add(i + 40), r6);
_mm256_storeu_ps(result_ptr.add(i + 48), r7);
_mm256_storeu_ps(result_ptr.add(i + 56), r8);
i += ELEMENTS_PER_ITERATION;
}
while i + 8 <= len {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
while i < len {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
return Array1::from_vec(result);
}
}
}
simd_mul_f32_lightweight(a, b)
}
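/// Size-dispatched multiply: mid-sized inputs (10K..10M elements) use the
/// mask-based `simd_mul_f32_branchfree`; very small and very large
/// (memory-bound) inputs use `simd_mul_f32_lightweight`.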
#[allow(dead_code)]
pub fn simd_mul_f32_ultimate(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    let len = a.len();
    // Mid-sized inputs get the mask-based remainder kernel; very small inputs
    // (setup-cost bound) and very large inputs (memory-bandwidth bound) use
    // the lightweight kernel.
    if (10_000..10_000_000).contains(&len) {
        simd_mul_f32_branchfree(a, b)
    } else {
        simd_mul_f32_lightweight(a, b)
    }
}
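/// Element-wise f32 product using non-temporal (`_mm256_stream_ps`) stores so
/// the output does not pollute the cache. Requires AVX and a 32-byte-aligned
/// destination; otherwise falls back to a scalar loop.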
pub fn simd_mul_f32_cacheline(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    let len = a.len();
    // Zero-initialize: `assume_init` on uninitialized f32 memory is undefined
    // behavior, and not every code path below touches every element.
    let mut result = Array1::zeros(len);
    // Raw-pointer arithmetic below assumes unit stride, so demand contiguity.
    let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
    let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
    let result_ptr: *mut f32 = result.as_mut_ptr();
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::*;
        // Non-temporal (streaming) stores require AVX support and a
        // 32-byte-aligned destination; otherwise fall through to the scalar
        // loop at the end of the function instead of faulting.
        if std::arch::is_x86_feature_detected!("avx") && (result_ptr as usize) % 32 == 0 {
            unsafe {
                // 16 f32s per iteration = one 64-byte cache line.
                let cache_line_size = 16;
                let vector_end = len - (len % cache_line_size);
                let mut i = 0;
while i < vector_end {
_mm_prefetch(a_ptr.add(i + cache_line_size) as *const i8, _MM_HINT_T0);
_mm_prefetch(b_ptr.add(i + cache_line_size) as *const i8, _MM_HINT_T0);
let a_vec1 = _mm256_loadu_ps(a_ptr.add(i));
let a_vec2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let b_vec1 = _mm256_loadu_ps(b_ptr.add(i));
let b_vec2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let result_vec1 = _mm256_mul_ps(a_vec1, b_vec1);
let result_vec2 = _mm256_mul_ps(a_vec2, b_vec2);
_mm256_stream_ps(result_ptr.add(i), result_vec1);
_mm256_stream_ps(result_ptr.add(i + 8), result_vec2);
i += cache_line_size;
}
                // Make the non-temporal stores globally visible before returning.
                _mm_sfence();
                while i < len {
                    *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
                    i += 1;
                }
            }
            return result;
        }
    }
    // Scalar fallback: non-x86_64 targets, x86_64 CPUs without AVX, or an
    // output buffer that is not 32-byte aligned.
    for i in 0..len {
        unsafe {
            *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
        }
    }
result
}
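/// Element-wise f32 product unrolled four vectors deep so independent
/// multiplies can overlap in the CPU pipeline; 16-wide, 8-wide, and scalar
/// loops finish the tail.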
pub fn simd_mul_f32_pipelined(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    let len = a.len();
    // Zero-initialize: `assume_init` on uninitialized f32 memory is undefined behavior.
    let mut result = Array1::zeros(len);
    // Raw-pointer arithmetic below assumes unit stride, so demand contiguity.
    let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
    let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
    let result_ptr: *mut f32 = result.as_mut_ptr();
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::*;
        // The 256-bit loads below require AVX; fall through to the scalar loop
        // at the end of the function instead of faulting on CPUs without it.
        if std::arch::is_x86_feature_detected!("avx") {
            unsafe {
                let mut i = 0;
                // Four independent 8-lane multiplies per iteration keep the
                // multiply pipeline busy while earlier loads complete.
                let block_size = 32;
                let block_end = len - (len % block_size);
while i < block_end {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
let r3 = _mm256_mul_ps(a3, b3);
let r4 = _mm256_mul_ps(a4, b4);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
_mm256_storeu_ps(result_ptr.add(i + 16), r3);
_mm256_storeu_ps(result_ptr.add(i + 24), r4);
i += block_size;
}
while i + 16 <= len {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
i += 16;
}
while i + 8 <= len {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
while i < len {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
            }
            return result;
        }
    }
    // Scalar fallback for non-x86_64 targets and x86_64 CPUs without AVX.
    for i in 0..len {
        unsafe {
            *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
        }
    }
result
}
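/// Element-wise f32 product processed in ~512 KiB chunks with up-front
/// prefetching, so each chunk's page translations stay warm in the TLB while
/// it is streamed through.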
pub fn simd_mul_f32_tlb_optimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
    assert_eq!(a.len(), b.len(), "Arrays must have the same length");
    let len = a.len();
    // Zero-initialize: `assume_init` on uninitialized f32 memory is undefined behavior.
    let mut result = Array1::zeros(len);
    // Raw-pointer arithmetic below assumes unit stride, so demand contiguity.
    let a_ptr = a.as_slice().expect("Array a must be contiguous").as_ptr();
    let b_ptr = b.as_slice().expect("Array b must be contiguous").as_ptr();
    let result_ptr: *mut f32 = result.as_mut_ptr();
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::*;
        // The 256-bit loads below require AVX; fall through to the scalar loop
        // at the end of the function instead of faulting on CPUs without it.
        if std::arch::is_x86_feature_detected!("avx") {
            unsafe {
                // Work in ~512 KiB chunks (128 K f32s) so a chunk's page
                // translations stay warm in the TLB while it is processed.
                const CHUNK_SIZE: usize = 512 * 1024 / 4;
                let mut pos = 0;
                while pos < len {
                    let chunk_end = std::cmp::min(pos + CHUNK_SIZE, len);
                    // Touch one cache line every 64 elements (256 bytes) to
                    // walk the chunk's pages before streaming through them.
                    let prefetch_distance = 64;
for j in (pos..chunk_end).step_by(prefetch_distance) {
_mm_prefetch(a_ptr.add(j) as *const i8, _MM_HINT_T0);
_mm_prefetch(b_ptr.add(j) as *const i8, _MM_HINT_T0);
}
let mut i = pos;
while i + 64 <= chunk_end {
let a1 = _mm256_loadu_ps(a_ptr.add(i));
let a2 = _mm256_loadu_ps(a_ptr.add(i + 8));
let a3 = _mm256_loadu_ps(a_ptr.add(i + 16));
let a4 = _mm256_loadu_ps(a_ptr.add(i + 24));
let a5 = _mm256_loadu_ps(a_ptr.add(i + 32));
let a6 = _mm256_loadu_ps(a_ptr.add(i + 40));
let a7 = _mm256_loadu_ps(a_ptr.add(i + 48));
let a8 = _mm256_loadu_ps(a_ptr.add(i + 56));
let b1 = _mm256_loadu_ps(b_ptr.add(i));
let b2 = _mm256_loadu_ps(b_ptr.add(i + 8));
let b3 = _mm256_loadu_ps(b_ptr.add(i + 16));
let b4 = _mm256_loadu_ps(b_ptr.add(i + 24));
let b5 = _mm256_loadu_ps(b_ptr.add(i + 32));
let b6 = _mm256_loadu_ps(b_ptr.add(i + 40));
let b7 = _mm256_loadu_ps(b_ptr.add(i + 48));
let b8 = _mm256_loadu_ps(b_ptr.add(i + 56));
let r1 = _mm256_mul_ps(a1, b1);
let r2 = _mm256_mul_ps(a2, b2);
let r3 = _mm256_mul_ps(a3, b3);
let r4 = _mm256_mul_ps(a4, b4);
let r5 = _mm256_mul_ps(a5, b5);
let r6 = _mm256_mul_ps(a6, b6);
let r7 = _mm256_mul_ps(a7, b7);
let r8 = _mm256_mul_ps(a8, b8);
_mm256_storeu_ps(result_ptr.add(i), r1);
_mm256_storeu_ps(result_ptr.add(i + 8), r2);
_mm256_storeu_ps(result_ptr.add(i + 16), r3);
_mm256_storeu_ps(result_ptr.add(i + 24), r4);
_mm256_storeu_ps(result_ptr.add(i + 32), r5);
_mm256_storeu_ps(result_ptr.add(i + 40), r6);
_mm256_storeu_ps(result_ptr.add(i + 48), r7);
_mm256_storeu_ps(result_ptr.add(i + 56), r8);
i += 64;
}
while i + 8 <= chunk_end {
let a_vec = _mm256_loadu_ps(a_ptr.add(i));
let b_vec = _mm256_loadu_ps(b_ptr.add(i));
let result_vec = _mm256_mul_ps(a_vec, b_vec);
_mm256_storeu_ps(result_ptr.add(i), result_vec);
i += 8;
}
while i < chunk_end {
*result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
i += 1;
}
pos = chunk_end;
}
            }
            return result;
        }
    }
    // Scalar fallback for non-x86_64 targets and x86_64 CPUs without AVX.
    for i in 0..len {
        unsafe {
            *result_ptr.add(i) = *a_ptr.add(i) * *b_ptr.add(i);
        }
    }
result
}
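/// Size-dispatched element-wise product over the blazing, fast, and basic
/// multiply kernels from this crate's SIMD layer.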
#[allow(dead_code)]
pub fn simd_mul_f32_adaptive(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
let len = a.len();
if len >= 100_000 {
simd_mul_f32_blazing(a, b)
} else if len >= 10_000 {
simd_mul_f32_fast(a, b)
} else {
simd_mul_f32(a, b)
}
}
#[allow(dead_code)]
pub fn simd_add_auto(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
simd_adaptive_add_f32(a, b)
}
#[allow(dead_code)]
pub fn simd_mul_f32_hyperoptimized(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
simd_mul_f32_ultra(a, b)
}
#[allow(dead_code)]
pub fn simd_add_f32_adaptive(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
f32::simd_add_adaptive(a, b)
}
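// A minimal sanity-check sketch for the kernels above, assuming this file is
// compiled as a module of the crate with its `ndarray` dependency available.
// Every multiply variant must agree exactly with the scalar product (f32
// multiplication is deterministic), and the FMA kernel must agree with
// `mul_add` up to one rounding step. The odd length 67 exercises the unrolled
// loops, the 8-wide loops, and the remainder handling in each kernel.
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{Array1, ArrayView1};

    #[test]
    fn mul_variants_match_scalar() {
        let a: Array1<f32> = (0..67).map(|i| i as f32 * 0.5).collect();
        let b: Array1<f32> = (0..67).map(|i| (i as f32).sin()).collect();
        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x * y).collect();
        // Every kernel shares the same signature, so they can be table-driven.
        let kernels: [fn(&ArrayView1<f32>, &ArrayView1<f32>) -> Array1<f32>; 10] = [
            simd_mul_f32_ultra,
            simd_mul_f32_blazing,
            simd_mul_f32_cache_optimized,
            simd_mul_f32_lightweight,
            simd_mul_f32_avx512,
            simd_mul_f32_branchfree,
            simd_mul_f32_bandwidth_saturated,
            simd_mul_f32_cacheline,
            simd_mul_f32_pipelined,
            simd_mul_f32_tlb_optimized,
        ];
        for kernel in kernels {
            let out = kernel(&a.view(), &b.view());
            assert_eq!(out.as_slice().expect("contiguous result"), expected.as_slice());
        }
    }

    #[test]
    fn fma_matches_mul_add() {
        let a: Array1<f32> = (0..19).map(|i| i as f32).collect();
        let b: Array1<f32> = (0..19).map(|i| 0.25 * i as f32).collect();
        let c: Array1<f32> = (0..19).map(|_| 1.5).collect();
        let out = simd_fused_multiply_add_f32(&a.view(), &b.view(), &c.view());
        for i in 0..a.len() {
            // Allow one rounding step of difference between fused and unfused.
            assert!((out[i] - a[i].mul_add(b[i], c[i])).abs() <= 1e-5);
        }
    }
}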