use crate::{Result, TensorError};
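/// Elementwise f32 kernels with per-platform SIMD dispatch: AVX2 on x86_64
/// (runtime-detected), NEON on aarch64, and a scalar fallback elsewhere.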
pub mod simd_f32_ops {
    use super::*;
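
    /// Elementwise addition: `output[i] = a[i] + b[i]`.
    ///
    /// All three slices must have the same length; otherwise an
    /// `invalid_argument` error is returned.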
    pub fn simd_add_f32(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != output.len() {
            return Err(TensorError::invalid_argument(
                "SIMD slice length mismatch".to_string(),
            ));
        }
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 support was verified at runtime just above.
                return unsafe { simd_add_f32_avx2(a, b, output) };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // NEON is part of the baseline aarch64 feature set, so no
            // runtime detection is needed.
            return simd_add_f32_neon(a, b, output);
        }
        // Scalar fallback: non-SIMD targets, or x86_64 without AVX2.
        // (Unreachable on aarch64, hence the lint allowance.)
        #[allow(unreachable_code)]
        for i in 0..a.len() {
            output[i] = a[i] + b[i];
        }
        Ok(())
    }
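
    /// Elementwise multiplication: `output[i] = a[i] * b[i]`.
    ///
    /// All three slices must have the same length; otherwise an
    /// `invalid_argument` error is returned.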
    pub fn simd_mul_f32(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != output.len() {
            return Err(TensorError::invalid_argument(
                "SIMD slice length mismatch".to_string(),
            ));
        }
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 support was verified at runtime just above.
                return unsafe { simd_mul_f32_avx2(a, b, output) };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // NEON is part of the baseline aarch64 feature set, so no
            // runtime detection is needed.
            return simd_mul_f32_neon(a, b, output);
        }
        // Scalar fallback: non-SIMD targets, or x86_64 without AVX2.
        // (Unreachable on aarch64, hence the lint allowance.)
        #[allow(unreachable_code)]
        for i in 0..a.len() {
            output[i] = a[i] * b[i];
        }
        Ok(())
    }
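
    /// AVX2 addition kernel: 8 f32 lanes per iteration plus a scalar tail.
    ///
    /// # Safety
    /// The caller must guarantee AVX2 is available, e.g. by checking
    /// `is_x86_feature_detected!("avx2")` first.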
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn simd_add_f32_avx2(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<()> {
        use std::arch::x86_64::*;
        let len = a.len();
        // Round down to a multiple of 8, the f32 lane count of a 256-bit vector.
        let simd_end = len & !7;
        unsafe {
            for i in (0..simd_end).step_by(8) {
                // Unaligned loads/stores: the slices carry no alignment guarantee.
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vb = _mm256_loadu_ps(b.as_ptr().add(i));
                let vr = _mm256_add_ps(va, vb);
                _mm256_storeu_ps(output.as_mut_ptr().add(i), vr);
            }
        }
        // Scalar tail for the remaining 0..=7 elements.
        for i in simd_end..len {
            output[i] = a[i] + b[i];
        }
        Ok(())
    }
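
    /// AVX2 multiplication kernel: 8 f32 lanes per iteration plus a scalar tail.
    ///
    /// # Safety
    /// The caller must guarantee AVX2 is available, e.g. by checking
    /// `is_x86_feature_detected!("avx2")` first.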
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn simd_mul_f32_avx2(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<()> {
        use std::arch::x86_64::*;
        let len = a.len();
        // Round down to a multiple of 8, the f32 lane count of a 256-bit vector.
        let simd_end = len & !7;
        unsafe {
            for i in (0..simd_end).step_by(8) {
                // Unaligned loads/stores: the slices carry no alignment guarantee.
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vb = _mm256_loadu_ps(b.as_ptr().add(i));
                let vr = _mm256_mul_ps(va, vb);
                _mm256_storeu_ps(output.as_mut_ptr().add(i), vr);
            }
        }
        // Scalar tail for the remaining 0..=7 elements.
        for i in simd_end..len {
            output[i] = a[i] * b[i];
        }
        Ok(())
    }
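
    /// NEON addition kernel: 4 f32 lanes per iteration plus a scalar tail.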
    #[cfg(target_arch = "aarch64")]
    fn simd_add_f32_neon(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<()> {
        use std::arch::aarch64::*;
        let len = a.len();
        // Round down to a multiple of 4, the f32 lane count of a 128-bit vector.
        let simd_end = len & !3;
        // SAFETY: NEON is mandatory on aarch64, and every access stays within
        // the equal slice lengths checked by the caller.
        unsafe {
            for i in (0..simd_end).step_by(4) {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                let vr = vaddq_f32(va, vb);
                vst1q_f32(output.as_mut_ptr().add(i), vr);
            }
        }
        // Scalar tail for the remaining 0..=3 elements.
        for i in simd_end..len {
            output[i] = a[i] + b[i];
        }
        Ok(())
    }
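
    /// NEON multiplication kernel: 4 f32 lanes per iteration plus a scalar tail.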
    #[cfg(target_arch = "aarch64")]
    fn simd_mul_f32_neon(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<()> {
        use std::arch::aarch64::*;
        let len = a.len();
        // Round down to a multiple of 4, the f32 lane count of a 128-bit vector.
        let simd_end = len & !3;
        // SAFETY: NEON is mandatory on aarch64, and every access stays within
        // the equal slice lengths checked by the caller.
        unsafe {
            for i in (0..simd_end).step_by(4) {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                let vr = vmulq_f32(va, vb);
                vst1q_f32(output.as_mut_ptr().add(i), vr);
            }
        }
        // Scalar tail for the remaining 0..=3 elements.
        for i in simd_end..len {
            output[i] = a[i] * b[i];
        }
        Ok(())
    }
}
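
// A minimal usage sketch, added as an illustration rather than taken from the
// original source: odd-length inputs exercise both the vector body and the
// scalar tail on whichever dispatch path the host machine selects.
#[cfg(test)]
mod tests {
    use super::simd_f32_ops::{simd_add_f32, simd_mul_f32};

    #[test]
    fn add_and_mul_match_scalar_reference() {
        // 11 elements is not a multiple of 8 (AVX2) or 4 (NEON), so the
        // remainder loop runs as well.
        let a: Vec<f32> = (0..11).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..11).map(|i| (i * 2) as f32 + 0.5).collect();
        let mut out = vec![0.0f32; a.len()];

        simd_add_f32(&a, &b, &mut out).unwrap();
        for i in 0..a.len() {
            assert_eq!(out[i], a[i] + b[i]);
        }

        simd_mul_f32(&a, &b, &mut out).unwrap();
        for i in 0..a.len() {
            assert_eq!(out[i], a[i] * b[i]);
        }
    }

    #[test]
    fn mismatched_lengths_are_rejected() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0];
        let mut out = [0.0f32; 3];
        assert!(simd_add_f32(&a, &b, &mut out).is_err());
    }
}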