#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use crate::optimizer::scalar;
pub fn add_avx512_impl(a: &[f32], b: &[f32], out: &mut [f32]) {
let len = a.len().min(b.len()).min(out.len());
let unroll_factor = 32;
let simd_len = len / unroll_factor;
let main_loop_len = simd_len * unroll_factor;
#[cfg(target_arch = "x86_64")]
{
unsafe {
for i in (0..main_loop_len).step_by(unroll_factor) {
let a_ptr = a.as_ptr().add(i);
let b_ptr = b.as_ptr().add(i);
let out_ptr = out.as_mut_ptr().add(i);
let va1 = _mm512_loadu_ps(a_ptr);
let va2 = _mm512_loadu_ps(a_ptr.add(16));
let vb1 = _mm512_loadu_ps(b_ptr);
let vb2 = _mm512_loadu_ps(b_ptr.add(16));
let vres1 = _mm512_add_ps(va1, vb1);
let vres2 = _mm512_add_ps(va2, vb2);
_mm512_storeu_ps(out_ptr, vres1);
_mm512_storeu_ps(out_ptr.add(16), vres2);
}
}
}
if main_loop_len < len {
scalar::add_impl(
&a[main_loop_len..len],
&b[main_loop_len..len],
&mut out[main_loop_len..len],
);
}
}