use std::mem::MaybeUninit;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) mod x86_ops {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use std::mem::MaybeUninit;
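    /// Computes `(dot(a, b), |a|^2, |b|^2)` in a single fused pass over both
    /// slices, using 256-bit FMA accumulators for the bulk and scalar code
    /// for the tail. Panics if the slices differ in length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and FMA, e.g. via
    /// `is_x86_feature_detected!`.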
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn dot_and_magnitudes_avx2(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
assert_eq!(a.len(), b.len());
unsafe {
let a_chunks = a.chunks_exact(8);
let b_chunks = b.chunks_exact(8);
let a_rem = a_chunks.remainder();
let b_rem = b_chunks.remainder();
let mut dot_acc = _mm256_setzero_ps();
let mut mag_a_acc = _mm256_setzero_ps();
let mut mag_b_acc = _mm256_setzero_ps();
for (va_chunk, vb_chunk) in a_chunks.zip(b_chunks) {
let va = _mm256_loadu_ps(va_chunk.as_ptr());
let vb = _mm256_loadu_ps(vb_chunk.as_ptr());
dot_acc = _mm256_fmadd_ps(va, vb, dot_acc);
mag_a_acc = _mm256_fmadd_ps(va, va, mag_a_acc);
mag_b_acc = _mm256_fmadd_ps(vb, vb, mag_b_acc);
}
let dot = horizontal_sum_avx(dot_acc);
let mag_a = horizontal_sum_avx(mag_a_acc);
let mag_b = horizontal_sum_avx(mag_b_acc);
let mut dot_rem = 0.0f32;
let mut mag_a_rem = 0.0f32;
let mut mag_b_rem = 0.0f32;
for (&ai, &bi) in a_rem.iter().zip(b_rem) {
dot_rem += ai * bi;
mag_a_rem += ai * ai;
mag_b_rem += bi * bi;
}
(dot + dot_rem, mag_a + mag_a_rem, mag_b + mag_b_rem)
}
}
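    /// Reduces all eight `f32` lanes of `v` to their sum by adding the high
    /// 128-bit half onto the low half, then finishing with the SSE reduction.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.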
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn horizontal_sum_avx(v: __m256) -> f32 {
unsafe {
let high = _mm256_extractf128_ps(v, 1);
let low = _mm256_castps256_ps128(v);
let sum128 = _mm_add_ps(high, low);
horizontal_sum_sse(sum128)
}
}
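    /// SSE2 version of [`dot_and_magnitudes_avx2`]: the same fused pass with
    /// 128-bit mul/add accumulators. Panics if the slices differ in length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE2.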
#[target_feature(enable = "sse2")]
#[inline]
pub unsafe fn dot_and_magnitudes_sse2(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
assert_eq!(a.len(), b.len());
unsafe {
let a_chunks = a.chunks_exact(4);
let b_chunks = b.chunks_exact(4);
let a_rem = a_chunks.remainder();
let b_rem = b_chunks.remainder();
let mut dot_acc = _mm_setzero_ps();
let mut mag_a_acc = _mm_setzero_ps();
let mut mag_b_acc = _mm_setzero_ps();
for (va_chunk, vb_chunk) in a_chunks.zip(b_chunks) {
let va = _mm_loadu_ps(va_chunk.as_ptr());
let vb = _mm_loadu_ps(vb_chunk.as_ptr());
dot_acc = _mm_add_ps(dot_acc, _mm_mul_ps(va, vb));
mag_a_acc = _mm_add_ps(mag_a_acc, _mm_mul_ps(va, va));
mag_b_acc = _mm_add_ps(mag_b_acc, _mm_mul_ps(vb, vb));
}
let dot = horizontal_sum_sse(dot_acc);
let mag_a = horizontal_sum_sse(mag_a_acc);
let mag_b = horizontal_sum_sse(mag_b_acc);
let mut dot_rem = 0.0f32;
let mut mag_a_rem = 0.0f32;
let mut mag_b_rem = 0.0f32;
for (&ai, &bi) in a_rem.iter().zip(b_rem) {
dot_rem += ai * bi;
mag_a_rem += ai * ai;
mag_b_rem += bi * bi;
}
(dot + dot_rem, mag_a + mag_a_rem, mag_b + mag_b_rem)
}
}
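    /// Reduces all four `f32` lanes of `v` to their sum with two
    /// shuffle-and-add steps.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE2.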
#[target_feature(enable = "sse2")]
#[inline]
    unsafe fn horizontal_sum_sse(v: __m128) -> f32 {
        unsafe {
            // [a, b, c, d] + [b, a, d, c] = [a+b, a+b, c+d, c+d]
            let shuf = _mm_shuffle_ps(v, v, 0b10_11_00_01);
            let sum1 = _mm_add_ps(v, shuf);
            // Move the upper pair-sum into the low lane and add it on.
            let shuf2 = _mm_shuffle_ps(sum1, sum1, 0b00_00_11_10);
            let sum2 = _mm_add_ps(sum1, shuf2);
            _mm_cvtss_f32(sum2)
        }
    }
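    /// Computes the dot product of `a` and `b` with 256-bit FMA accumulation.
    /// Panics if the slices differ in length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and FMA.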
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
unsafe {
let a_chunks = a.chunks_exact(8);
let b_chunks = b.chunks_exact(8);
let a_rem = a_chunks.remainder();
let b_rem = b_chunks.remainder();
let mut acc = _mm256_setzero_ps();
for (va_chunk, vb_chunk) in a_chunks.zip(b_chunks) {
let va = _mm256_loadu_ps(va_chunk.as_ptr());
let vb = _mm256_loadu_ps(vb_chunk.as_ptr());
acc = _mm256_fmadd_ps(va, vb, acc);
}
let mut sum = horizontal_sum_avx(acc);
for (&va, &vb) in a_rem.iter().zip(b_rem) {
sum += va * vb;
}
sum
}
}
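    /// SSE2 version of [`dot_product_avx2`]. Panics if the slices differ in
    /// length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE2.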
#[target_feature(enable = "sse2")]
#[inline]
pub unsafe fn dot_product_sse2(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
unsafe {
let a_chunks = a.chunks_exact(4);
let b_chunks = b.chunks_exact(4);
let a_rem = a_chunks.remainder();
let b_rem = b_chunks.remainder();
let mut acc = _mm_setzero_ps();
for (va_chunk, vb_chunk) in a_chunks.zip(b_chunks) {
let va = _mm_loadu_ps(va_chunk.as_ptr());
let vb = _mm_loadu_ps(vb_chunk.as_ptr());
acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
}
let mut sum = horizontal_sum_sse(acc);
for (&va, &vb) in a_rem.iter().zip(b_rem) {
sum += va * vb;
}
sum
}
}
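    /// Computes the sum of squared differences (squared Euclidean distance)
    /// between `a` and `b` using 256-bit FMA lanes. Panics if the slices
    /// differ in length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and FMA.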
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn squared_diff_sum_avx2(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
unsafe {
let len = a.len();
let chunks = len / 8;
let remainder = len % 8;
let mut acc = _mm256_setzero_ps();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let va = _mm256_loadu_ps(a_ptr.add(offset));
let vb = _mm256_loadu_ps(b_ptr.add(offset));
let diff = _mm256_sub_ps(va, vb);
acc = _mm256_fmadd_ps(diff, diff, acc);
}
let mut sum = horizontal_sum_avx(acc);
let start = chunks * 8;
for i in 0..remainder {
let diff = a[start + i] - b[start + i];
sum += diff * diff;
}
sum
}
}
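    /// SSE2 version of [`squared_diff_sum_avx2`]. Panics if the slices differ
    /// in length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE2.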
#[target_feature(enable = "sse2")]
#[inline]
pub unsafe fn squared_diff_sum_sse2(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
unsafe {
let len = a.len();
let chunks = len / 4;
let remainder = len % 4;
let mut acc = _mm_setzero_ps();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = _mm_loadu_ps(a_ptr.add(offset));
let vb = _mm_loadu_ps(b_ptr.add(offset));
let diff = _mm_sub_ps(va, vb);
acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
}
let mut sum = horizontal_sum_sse(acc);
let start = chunks * 4;
for i in 0..remainder {
let diff = a[start + i] - b[start + i];
sum += diff * diff;
}
sum
}
}
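    /// Computes `|v|^2`, the sum of squared elements, with 256-bit FMA lanes.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and FMA.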
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn squared_magnitude_avx2(v: &[f32]) -> f32 {
unsafe {
let chunks = v.chunks_exact(8);
let rem = chunks.remainder();
let mut acc = _mm256_setzero_ps();
for chunk in chunks {
let va = _mm256_loadu_ps(chunk.as_ptr());
acc = _mm256_fmadd_ps(va, va, acc);
}
let mut sum = horizontal_sum_avx(acc);
for &x in rem {
sum += x * x;
}
sum
}
}
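    /// SSE2 version of [`squared_magnitude_avx2`].
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE2.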
#[target_feature(enable = "sse2")]
#[inline]
pub unsafe fn squared_magnitude_sse2(v: &[f32]) -> f32 {
unsafe {
let chunks = v.chunks_exact(4);
let rem = chunks.remainder();
let mut acc = _mm_setzero_ps();
for chunk in chunks {
let va = _mm_loadu_ps(chunk.as_ptr());
acc = _mm_add_ps(acc, _mm_mul_ps(va, va));
}
let mut sum = horizontal_sum_sse(acc);
for &x in rem {
sum += x * x;
}
sum
}
}
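    /// Multiplies every element of `v` by `scalar` in place, eight lanes at a
    /// time.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.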
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn scale_in_place_avx2(v: &mut [f32], scalar: f32) {
unsafe {
let scalar_vec = _mm256_set1_ps(scalar);
let mut chunks = v.chunks_exact_mut(8);
for chunk in chunks.by_ref() {
let va = _mm256_loadu_ps(chunk.as_ptr());
let result = _mm256_mul_ps(va, scalar_vec);
_mm256_storeu_ps(chunk.as_mut_ptr(), result);
}
for x in chunks.into_remainder() {
*x *= scalar;
}
}
}
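    /// SSE2 version of [`scale_in_place_avx2`], four lanes at a time.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE2.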
#[target_feature(enable = "sse2")]
#[inline]
pub unsafe fn scale_in_place_sse2(v: &mut [f32], scalar: f32) {
unsafe {
let scalar_vec = _mm_set1_ps(scalar);
let mut chunks = v.chunks_exact_mut(4);
for chunk in chunks.by_ref() {
let va = _mm_loadu_ps(chunk.as_ptr());
let result = _mm_mul_ps(va, scalar_vec);
_mm_storeu_ps(chunk.as_mut_ptr(), result);
}
for x in chunks.into_remainder() {
*x *= scalar;
}
}
}
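    /// Writes `src[i] * scalar` into `dst[i]`, initializing every element of
    /// `dst`. Panics if the slices differ in length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.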
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn scale_and_copy_avx2(src: &[f32], dst: &mut [MaybeUninit<f32>], scalar: f32) {
assert_eq!(src.len(), dst.len());
unsafe {
let scalar_vec = _mm256_set1_ps(scalar);
let mut src_chunks = src.chunks_exact(8);
let mut dst_chunks = dst.chunks_exact_mut(8);
for (s_chunk, d_chunk) in src_chunks.by_ref().zip(dst_chunks.by_ref()) {
let va = _mm256_loadu_ps(s_chunk.as_ptr());
let result = _mm256_mul_ps(va, scalar_vec);
_mm256_storeu_ps(d_chunk.as_mut_ptr() as *mut f32, result);
}
let src_rem = src_chunks.remainder();
let dst_rem = dst_chunks.into_remainder();
for (s, d) in src_rem.iter().zip(dst_rem.iter_mut()) {
d.write(*s * scalar);
}
}
}
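    /// SSE2 version of [`scale_and_copy_avx2`]. Panics if the slices differ
    /// in length.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE2.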
#[target_feature(enable = "sse2")]
#[inline]
pub unsafe fn scale_and_copy_sse2(src: &[f32], dst: &mut [MaybeUninit<f32>], scalar: f32) {
assert_eq!(src.len(), dst.len());
unsafe {
let scalar_vec = _mm_set1_ps(scalar);
let mut src_chunks = src.chunks_exact(4);
let mut dst_chunks = dst.chunks_exact_mut(4);
for (s_chunk, d_chunk) in src_chunks.by_ref().zip(dst_chunks.by_ref()) {
let va = _mm_loadu_ps(s_chunk.as_ptr());
let result = _mm_mul_ps(va, scalar_vec);
_mm_storeu_ps(d_chunk.as_mut_ptr() as *mut f32, result);
}
let src_rem = src_chunks.remainder();
let dst_rem = dst_chunks.into_remainder();
for (s, d) in src_rem.iter().zip(dst_rem.iter_mut()) {
d.write(*s * scalar);
}
}
}
#[test]
fn test_squared_magnitude_implementation_coverage() {
        use super::{squared_magnitude, squared_magnitude_scalar};
        let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let expected = 55.0;
        let res_scalar = squared_magnitude_scalar(&v);
        assert_eq!(res_scalar, expected, "Scalar implementation failed");
        if is_x86_feature_detected!("sse2") {
            let res_sse2 = unsafe { squared_magnitude_sse2(&v) };
            assert_eq!(res_sse2, expected, "SSE2 implementation failed");
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            let res_avx2 = unsafe { squared_magnitude_avx2(&v) };
            assert_eq!(res_avx2, expected, "AVX2 implementation failed");
        }
let res_dispatch = squared_magnitude(&v);
assert_eq!(res_dispatch, expected, "Dispatcher failed");
}
}
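// Portable scalar fallbacks. These double as the reference implementations
// the SIMD paths are tested against; `allow(dead_code)` keeps non-Miri x86
// builds warning-free when a fallback ends up statically unreferenced on
// that target.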
#[inline]
#[cfg_attr(
all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)),
allow(dead_code)
)]
pub(crate) fn dot_and_magnitudes_scalar(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
a.iter().zip(b.iter()).fold(
(0.0f32, 0.0f32, 0.0f32),
|(dot, mag_a, mag_b), (&ai, &bi)| (dot + ai * bi, mag_a + ai * ai, mag_b + bi * bi),
)
}
#[inline]
#[cfg_attr(
all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)),
allow(dead_code)
)]
pub(crate) fn squared_diff_sum_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(&ai, &bi)| {
let diff = ai - bi;
diff * diff
})
.sum()
}
#[inline]
#[cfg_attr(
all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)),
allow(dead_code)
)]
pub(crate) fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
}
#[inline]
#[cfg_attr(
all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)),
allow(dead_code)
)]
pub(crate) fn squared_magnitude_scalar(v: &[f32]) -> f32 {
v.iter().map(|&x| x * x).sum()
}
#[inline]
#[cfg_attr(
all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)),
allow(dead_code)
)]
pub(crate) fn scale_in_place_scalar(v: &mut [f32], scalar: f32) {
for x in v.iter_mut() {
*x *= scalar;
}
}
#[inline]
#[cfg_attr(
all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)),
allow(dead_code)
)]
pub(crate) fn scale_and_copy_scalar(src: &[f32], dst: &mut [MaybeUninit<f32>], scalar: f32) {
assert_eq!(src.len(), dst.len());
for (s, d) in src.iter().zip(dst.iter_mut()) {
d.write(*s * scalar);
}
}
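/// Multiplies every element of `v` by `scalar` in place, dispatching at
/// runtime to the widest SIMD implementation the CPU supports.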
#[inline(always)]
pub(crate) fn scale_in_place(v: &mut [f32], scalar: f32) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe {
x86_ops::scale_in_place_avx2(v, scalar);
}
return;
}
if is_x86_feature_detected!("sse2") {
unsafe {
x86_ops::scale_in_place_sse2(v, scalar);
}
return;
}
}
scale_in_place_scalar(v, scalar);
}
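/// Writes `src[i] * scalar` into `dst[i]`, initializing all of `dst`.
/// Panics if the slices differ in length.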
#[inline(always)]
pub(crate) fn scale_and_copy(src: &[f32], dst: &mut [MaybeUninit<f32>], scalar: f32) {
assert_eq!(src.len(), dst.len());
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe {
x86_ops::scale_and_copy_avx2(src, dst, scalar);
}
return;
}
if is_x86_feature_detected!("sse2") {
unsafe {
x86_ops::scale_and_copy_sse2(src, dst, scalar);
}
return;
}
}
scale_and_copy_scalar(src, dst, scalar);
}
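/// Returns `(dot(a, b), |a|^2, |b|^2)` in a single pass over both slices.
/// Panics if the slices differ in length.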
#[inline]
pub(crate) fn dot_and_magnitudes(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
assert_eq!(a.len(), b.len());
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { x86_ops::dot_and_magnitudes_avx2(a, b) };
}
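        // SSE2 is part of the x86_64 baseline ABI, so no runtime check is
        // needed before taking the SSE2 path.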
unsafe { x86_ops::dot_and_magnitudes_sse2(a, b) }
}
#[cfg(target_arch = "x86")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { x86_ops::dot_and_magnitudes_avx2(a, b) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { x86_ops::dot_and_magnitudes_sse2(a, b) };
}
return dot_and_magnitudes_scalar(a, b);
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
dot_and_magnitudes_scalar(a, b)
}
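/// Returns the sum of squared differences (squared Euclidean distance)
/// between `a` and `b`. Panics if the slices differ in length.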
#[inline]
pub(crate) fn squared_diff_sum(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { x86_ops::squared_diff_sum_avx2(a, b) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { x86_ops::squared_diff_sum_sse2(a, b) };
}
}
squared_diff_sum_scalar(a, b)
}
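/// Returns the dot product of `a` and `b`. Panics if the slices differ in
/// length.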
#[inline]
pub(crate) fn dot_product_sum(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { x86_ops::dot_product_avx2(a, b) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { x86_ops::dot_product_sse2(a, b) };
}
}
dot_product_scalar(a, b)
}
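/// Returns `|v|^2`, the sum of the squared elements of `v`.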
#[inline]
pub(crate) fn squared_magnitude(v: &[f32]) -> f32 {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { x86_ops::squared_magnitude_avx2(v) };
}
if is_x86_feature_detected!("sse2") {
return unsafe { x86_ops::squared_magnitude_sse2(v) };
}
}
squared_magnitude_scalar(v)
}
#[cfg(test)]
mod tests {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use super::x86_ops;
use super::*;
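    // Reinterprets an initialized `&mut [f32]` as `&mut [MaybeUninit<f32>]`.
    // Sound because `MaybeUninit<f32>` is guaranteed to have the same layout
    // as `f32`.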
fn as_uninit_mut(slice: &mut [f32]) -> &mut [MaybeUninit<f32>] {
unsafe {
std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<f32>, slice.len())
}
}
#[test]
fn test_scale_and_copy_implementation_coverage() {
        let src = vec![1.0f32; 17];
        let mut dst = vec![0.0f32; 17];
let scalar = 2.0;
let expected = vec![2.0f32; 17];
scale_and_copy_scalar(&src, as_uninit_mut(&mut dst), scalar);
assert_eq!(dst, expected, "Scalar implementation failed");
dst.fill(0.0);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse2") {
unsafe { x86_ops::scale_and_copy_sse2(&src, as_uninit_mut(&mut dst), scalar) };
assert_eq!(dst, expected, "SSE2 implementation failed");
dst.fill(0.0);
}
if is_x86_feature_detected!("avx2") {
unsafe { x86_ops::scale_and_copy_avx2(&src, as_uninit_mut(&mut dst), scalar) };
assert_eq!(dst, expected, "AVX2 implementation failed");
dst.fill(0.0);
}
}
}
#[test]
fn test_scale_in_place_implementation_coverage() {
let v = vec![1.0f32; 17];
let scalar = 2.0;
let expected = vec![2.0f32; 17];
let mut v_scalar = v.clone();
scale_in_place_scalar(&mut v_scalar, scalar);
assert_eq!(v_scalar, expected, "Scalar implementation failed");
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse2") {
let mut v_sse2 = v.clone();
unsafe { x86_ops::scale_in_place_sse2(&mut v_sse2, scalar) };
assert_eq!(v_sse2, expected, "SSE2 implementation failed");
}
if is_x86_feature_detected!("avx2") {
let mut v_avx2 = v.clone();
unsafe { x86_ops::scale_in_place_avx2(&mut v_avx2, scalar) };
assert_eq!(v_avx2, expected, "AVX2 implementation failed");
}
}
}
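    // A minimal cross-check, added as an illustrative sketch: whichever
    // implementation the dispatchers select should agree with the scalar
    // fallbacks up to rounding, since FMA and tree-shaped accumulation
    // reassociate the floating-point sums. Length 19 exercises both the
    // full-chunk and remainder paths.
    #[test]
    fn test_dispatchers_match_scalar_within_tolerance() {
        let a: Vec<f32> = (0..19).map(|i| i as f32 * 0.25 - 2.0).collect();
        let b: Vec<f32> = (0..19).map(|i| 1.5 - i as f32 * 0.1).collect();
        let close = |x: f32, y: f32| (x - y).abs() <= 1e-4 * (1.0 + y.abs());
        let (dot, mag_a, mag_b) = dot_and_magnitudes(&a, &b);
        let (dot_s, mag_a_s, mag_b_s) = dot_and_magnitudes_scalar(&a, &b);
        assert!(close(dot, dot_s), "dot mismatch: {dot} vs {dot_s}");
        assert!(close(mag_a, mag_a_s), "mag_a mismatch: {mag_a} vs {mag_a_s}");
        assert!(close(mag_b, mag_b_s), "mag_b mismatch: {mag_b} vs {mag_b_s}");
        assert!(close(dot_product_sum(&a, &b), dot_product_scalar(&a, &b)));
        assert!(close(squared_diff_sum(&a, &b), squared_diff_sum_scalar(&a, &b)));
        assert!(close(squared_magnitude(&a), squared_magnitude_scalar(&a)));
    }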
}