#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use std::sync::OnceLock;
use super::simd_config;
/// Signature shared by every single-pair dot-product kernel: two `f32` slices
/// in, one `f32` out.
pub type DotKernel = fn(&[f32], &[f32]) -> f32;
/// Process-wide cache for the kernel picked by `resolve_dot_product_kernel`.
static DOT_PRODUCT_KERNEL: OnceLock<DotKernel> = OnceLock::new();
/// Returns the single-pair dot-product kernel selected for this process.
///
/// The first call runs `resolve_dot_product_kernel` and stores its result in
/// `DOT_PRODUCT_KERNEL`; every later call returns the stored function pointer.
///
/// NOTE: the SIMD kernels size their vector loads from the first slice only,
/// so callers should pass equal-length slices (as `dot_product` verifies
/// before calling).
#[inline]
pub fn resolved_dot_product_kernel() -> DotKernel {
*DOT_PRODUCT_KERNEL.get_or_init(resolve_dot_product_kernel)
}
/// Picks the single-pair dot-product kernel for the current SIMD configuration.
///
/// On x86_64 the preference order is AVX-512F, then AVX2 (only when FMA is
/// enabled as well). On aarch64 the NEON kernel is used when enabled. If no
/// SIMD option is enabled, or on any other architecture, the scalar kernel is
/// returned.
fn resolve_dot_product_kernel() -> DotKernel {
    let simd = simd_config();

    #[cfg(target_arch = "x86_64")]
    {
        if simd.avx512f_enabled {
            return dot_product_avx512_kernel;
        } else if simd.avx2_enabled && simd.fma_enabled {
            return dot_product_avx2_kernel;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if simd.neon_enabled {
            return dot_product_neon_kernel;
        }
    }

    // Portable fallback.
    dot_product_scalar
}
/// Signature of a "one query against four candidates" kernel: the first slice
/// is the query, the next four are candidates, and the result holds the four
/// dot products in candidate order.
pub type DotBatch4Kernel = fn(&[f32], &[f32], &[f32], &[f32], &[f32]) -> [f32; 4];
/// Process-wide cache for the kernel picked by
/// `resolve_dot_product_batch4_kernel`.
static DOT_PRODUCT_BATCH4_KERNEL: OnceLock<DotBatch4Kernel> = OnceLock::new();
/// Returns the batch-of-4 dot-product kernel selected for this process.
///
/// The first call runs `resolve_dot_product_batch4_kernel` and stores its
/// result; every later call returns the stored function pointer.
///
/// NOTE: the SIMD kernels size their vector loads from the query slice only,
/// so callers should pass candidates of the same length as the query (as
/// `dot_product_batch4` verifies before calling).
#[inline]
pub fn resolved_dot_product_batch4_kernel() -> DotBatch4Kernel {
*DOT_PRODUCT_BATCH4_KERNEL.get_or_init(resolve_dot_product_batch4_kernel)
}
/// Computes the dot product of `query` with each of four candidates.
///
/// Returns `[query·c0, query·c1, query·c2, query·c3]` using the kernel from
/// `resolved_dot_product_batch4_kernel`.
///
/// If any candidate's length differs from `query.len()`, debug builds trip a
/// `debug_assert!` and release builds return `[0.0; 4]`.
#[inline]
pub fn dot_product_batch4(
    query: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    let dim = query.len();
    let lengths_agree = [c0, c1, c2, c3].iter().all(|candidate| candidate.len() == dim);
    if lengths_agree {
        return resolved_dot_product_batch4_kernel()(query, c0, c1, c2, c3);
    }

    debug_assert!(
        false,
        "dot_product_batch4: dimension mismatch (query={}, c0={}, c1={}, c2={}, c3={})",
        query.len(),
        c0.len(),
        c1.len(),
        c2.len(),
        c3.len()
    );
    [0.0; 4]
}
/// Picks the batch-of-4 dot-product kernel for the current SIMD configuration.
///
/// x86_64 uses the AVX2 kernel when both AVX2 and FMA are enabled (there is no
/// AVX-512 variant in this dispatcher); aarch64 uses the NEON kernel when
/// enabled. Everything else falls back to the scalar implementation.
fn resolve_dot_product_batch4_kernel() -> DotBatch4Kernel {
    let simd = simd_config();

    #[cfg(target_arch = "x86_64")]
    {
        let avx2_with_fma = simd.avx2_enabled && simd.fma_enabled;
        if avx2_with_fma {
            return dot_product_batch4_avx2_kernel;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if simd.neon_enabled {
            return dot_product_batch4_neon_kernel;
        }
    }

    // Portable fallback.
    dot_product_batch4_scalar
}
/// Portable batch-of-4 kernel: dot product of `q` with each of `c0..c3`.
///
/// Elements are accumulated in index order, one running total per candidate.
/// Every candidate is indexed up to `q.len()`, so a shorter candidate panics
/// and extra trailing elements of a longer one are ignored.
fn dot_product_batch4_scalar(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    let (mut d0, mut d1, mut d2, mut d3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
    for (idx, &qi) in q.iter().enumerate() {
        d0 += qi * c0[idx];
        d1 += qi * c1[idx];
        d2 += qi * c2[idx];
        d3 += qi * c3[idx];
    }
    [d0, d1, d2, d3]
}
/// Safe entry point for the AVX-512 dot-product kernel.
///
/// Returns `0.0` when the slices differ in length (the same convention as
/// `dot_product`). The check is required for soundness: this function is a
/// safe `fn` handed out by `resolved_dot_product_kernel`, and the underlying
/// kernel reads `b` through raw pointers using offsets derived from `a.len()`.
#[cfg(target_arch = "x86_64")]
#[inline]
fn dot_product_avx512_kernel(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // SAFETY: the lengths were verified equal above, so every vector load
    // stays in bounds. This kernel is only installed by
    // `resolve_dot_product_kernel` when the SIMD config reports AVX-512F as
    // enabled (assumes that config reflects what the CPU actually supports).
    unsafe { dot_product_avx512_unrolled(a, b) }
}
/// Safe entry point for the AVX2+FMA dot-product kernels.
///
/// Uses the fixed-size kernel for 384-element inputs and the general
/// 8-accumulator kernel otherwise. Returns `0.0` when the slices differ in
/// length (the same convention as `dot_product`). The check is required for
/// soundness: this function is a safe `fn` handed out by
/// `resolved_dot_product_kernel`, and both kernels read `b` through raw
/// pointers using offsets derived from `a.len()` only.
#[cfg(target_arch = "x86_64")]
#[inline]
fn dot_product_avx2_kernel(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // SAFETY (both branches): the lengths were verified equal above, so every
    // vector load stays in bounds. This kernel is only installed by
    // `resolve_dot_product_kernel` when the SIMD config reports AVX2 and FMA
    // as enabled (assumes that config reflects what the CPU actually supports).
    if a.len() == 384 {
        unsafe { dot_product_384_avx2(a, b) }
    } else {
        unsafe { dot_product_avx2_8acc(a, b) }
    }
}
/// Safe entry point for the NEON dot-product kernel.
///
/// Returns `0.0` when the slices differ in length (the same convention as
/// `dot_product`). The check is required for soundness: this function is a
/// safe `fn` handed out by `resolved_dot_product_kernel`, and the underlying
/// kernel reads `b` through raw pointers using offsets derived from `a.len()`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_product_neon_kernel(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // SAFETY: the lengths were verified equal above, so every vector load
    // stays in bounds.
    unsafe { dot_product_neon_unrolled(a, b) }
}
/// Computes the dot product of `a` and `b` with the best available kernel.
///
/// Returns `0.0` when the slices differ in length; otherwise forwards to the
/// kernel returned by `resolved_dot_product_kernel`.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    // The length guard is the whole contract; a `debug_assert_eq!` placed
    // after it could never fire, so none is kept here.
    if a.len() != b.len() {
        return 0.0;
    }
    resolved_dot_product_kernel()(a, b)
}
/// Portable dot product: multiplies the slices element-wise and sums the
/// products in index order.
///
/// The slices are zipped, so when the lengths differ only the common prefix
/// contributes to the result.
#[inline]
pub(crate) fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Dot product of `a` and `b` using four independent AVX-512 accumulators.
///
/// Work is done in three phases: 64-element blocks (four 16-lane FMAs per
/// iteration), then any remaining whole 16-lane vectors, then a scalar tail.
///
/// # Safety
///
/// * The CPU must support AVX-512F.
/// * `b` must hold at least `a.len()` elements: the vector loads read `b`
///   through raw pointers at offsets derived from `a.len()` only. Equal
///   lengths are checked by `debug_assert_eq!` in debug builds only.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn dot_product_avx512_unrolled(a: &[f32], b: &[f32]) -> f32 {
    const W: usize = 16;
    const ACCUMULATORS: usize = 4;
    const BLOCK: usize = W * ACCUMULATORS;

    let len = a.len();
    debug_assert_eq!(len, b.len());

    // Phase 1: full 64-element blocks, one accumulator per 16-lane vector.
    let full_blocks = len / BLOCK;
    let mut acc0 = _mm512_setzero_ps();
    let mut acc1 = _mm512_setzero_ps();
    let mut acc2 = _mm512_setzero_ps();
    let mut acc3 = _mm512_setzero_ps();
    for block in 0..full_blocks {
        let x = a.as_ptr().add(block * BLOCK);
        let y = b.as_ptr().add(block * BLOCK);
        acc0 = _mm512_fmadd_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y), acc0);
        acc1 = _mm512_fmadd_ps(_mm512_loadu_ps(x.add(W)), _mm512_loadu_ps(y.add(W)), acc1);
        acc2 = _mm512_fmadd_ps(_mm512_loadu_ps(x.add(W * 2)), _mm512_loadu_ps(y.add(W * 2)), acc2);
        acc3 = _mm512_fmadd_ps(_mm512_loadu_ps(x.add(W * 3)), _mm512_loadu_ps(y.add(W * 3)), acc3);
    }
    // Pairwise reduction of the accumulators: (0 + 1) + (2 + 3).
    let blocks_total = horizontal_sum_avx512(_mm512_add_ps(
        _mm512_add_ps(acc0, acc1),
        _mm512_add_ps(acc2, acc3),
    ));

    // Phase 2: leftover whole vectors (at most three).
    let blocks_done = full_blocks * BLOCK;
    let tail_vectors = (len - blocks_done) / W;
    let mut tail_acc = _mm512_setzero_ps();
    for v in 0..tail_vectors {
        let x = a.as_ptr().add(blocks_done + v * W);
        let y = b.as_ptr().add(blocks_done + v * W);
        tail_acc = _mm512_fmadd_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y), tail_acc);
    }
    let mut dot = blocks_total + horizontal_sum_avx512(tail_acc);

    // Phase 3: scalar tail (fewer than 16 elements).
    let vectors_done = blocks_done + tail_vectors * W;
    for idx in vectors_done..len {
        dot += a[idx] * b[idx];
    }
    dot
}
/// Adds all `f32` lanes of `v` into a single scalar via
/// `_mm512_reduce_add_ps`.
///
/// # Safety
///
/// The CPU must support AVX-512F.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn horizontal_sum_avx512(v: __m512) -> f32 {
_mm512_reduce_add_ps(v)
}
/// Dot product of `a` and `b` using eight independent AVX2 FMA accumulators.
///
/// Work is done in three phases: 64-element blocks (eight 8-lane FMAs per
/// iteration), then any remaining whole 8-lane vectors, then a scalar tail.
///
/// # Safety
///
/// * The CPU must support AVX2 and FMA.
/// * `b` must hold at least `a.len()` elements: the vector loads read `b`
///   through raw pointers at offsets derived from `a.len()` only. Equal
///   lengths are checked by `debug_assert_eq!` in debug builds only.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_product_avx2_8acc(a: &[f32], b: &[f32]) -> f32 {
    const W: usize = 8;
    const ACCUMULATORS: usize = 8;
    const BLOCK: usize = W * ACCUMULATORS;

    let len = a.len();
    debug_assert_eq!(len, b.len());

    // Phase 1: full 64-element blocks, one accumulator per 8-lane vector.
    let full_blocks = len / BLOCK;
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();
    let mut acc4 = _mm256_setzero_ps();
    let mut acc5 = _mm256_setzero_ps();
    let mut acc6 = _mm256_setzero_ps();
    let mut acc7 = _mm256_setzero_ps();
    for block in 0..full_blocks {
        let x = a.as_ptr().add(block * BLOCK);
        let y = b.as_ptr().add(block * BLOCK);
        acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y), acc0);
        acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W)), _mm256_loadu_ps(y.add(W)), acc1);
        acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 2)), _mm256_loadu_ps(y.add(W * 2)), acc2);
        acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 3)), _mm256_loadu_ps(y.add(W * 3)), acc3);
        acc4 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 4)), _mm256_loadu_ps(y.add(W * 4)), acc4);
        acc5 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 5)), _mm256_loadu_ps(y.add(W * 5)), acc5);
        acc6 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 6)), _mm256_loadu_ps(y.add(W * 6)), acc6);
        acc7 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 7)), _mm256_loadu_ps(y.add(W * 7)), acc7);
    }
    // Pairwise reduction tree: ((0+1) + (2+3)) + ((4+5) + (6+7)).
    let lower_half = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
    let upper_half = _mm256_add_ps(_mm256_add_ps(acc4, acc5), _mm256_add_ps(acc6, acc7));
    let blocks_total = horizontal_sum_avx2(_mm256_add_ps(lower_half, upper_half));

    // Phase 2: leftover whole vectors (at most seven).
    let blocks_done = full_blocks * BLOCK;
    let tail_vectors = (len - blocks_done) / W;
    let mut tail_acc = _mm256_setzero_ps();
    for v in 0..tail_vectors {
        let x = a.as_ptr().add(blocks_done + v * W);
        let y = b.as_ptr().add(blocks_done + v * W);
        tail_acc = _mm256_fmadd_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y), tail_acc);
    }
    let mut dot = blocks_total + horizontal_sum_avx2(tail_acc);

    // Phase 3: scalar tail (fewer than 8 elements).
    let vectors_done = blocks_done + tail_vectors * W;
    for idx in vectors_done..len {
        dot += a[idx] * b[idx];
    }
    dot
}
/// Dot product specialised for 384-element inputs (AVX2 + FMA).
///
/// 384 is exactly six 64-element blocks, so the loop has a compile-time trip
/// count and there is no vector or scalar tail to handle.
///
/// # Safety
///
/// * The CPU must support AVX2 and FMA.
/// * Both `a` and `b` must hold at least 384 elements; the lengths are only
///   checked by `debug_assert_eq!` in debug builds, and all loads go through
///   raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_product_384_avx2(a: &[f32], b: &[f32]) -> f32 {
    const DIM: usize = 384;
    const W: usize = 8;
    const ACCUMULATORS: usize = 8;
    const BLOCK: usize = W * ACCUMULATORS;
    const BLOCKS: usize = DIM / BLOCK;
    const TAIL_VECTORS: usize = (DIM - BLOCKS * BLOCK) / W;

    debug_assert_eq!(a.len(), DIM);
    debug_assert_eq!(b.len(), DIM);
    // Guards the "no tail" assumption should the constants ever change.
    debug_assert_eq!(BLOCKS * BLOCK + TAIL_VECTORS * W, DIM);

    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();
    let mut acc4 = _mm256_setzero_ps();
    let mut acc5 = _mm256_setzero_ps();
    let mut acc6 = _mm256_setzero_ps();
    let mut acc7 = _mm256_setzero_ps();
    for block in 0..BLOCKS {
        let x = a.as_ptr().add(block * BLOCK);
        let y = b.as_ptr().add(block * BLOCK);
        acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y), acc0);
        acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W)), _mm256_loadu_ps(y.add(W)), acc1);
        acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 2)), _mm256_loadu_ps(y.add(W * 2)), acc2);
        acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 3)), _mm256_loadu_ps(y.add(W * 3)), acc3);
        acc4 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 4)), _mm256_loadu_ps(y.add(W * 4)), acc4);
        acc5 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 5)), _mm256_loadu_ps(y.add(W * 5)), acc5);
        acc6 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 6)), _mm256_loadu_ps(y.add(W * 6)), acc6);
        acc7 = _mm256_fmadd_ps(_mm256_loadu_ps(x.add(W * 7)), _mm256_loadu_ps(y.add(W * 7)), acc7);
    }

    // Pairwise reduction tree: ((0+1) + (2+3)) + ((4+5) + (6+7)).
    let lower_half = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
    let upper_half = _mm256_add_ps(_mm256_add_ps(acc4, acc5), _mm256_add_ps(acc6, acc7));
    horizontal_sum_avx2(_mm256_add_ps(lower_half, upper_half))
}
/// Adds the eight `f32` lanes of `v` into a single scalar.
///
/// The reduction halves the width three times: 8 lanes to 4, 4 to 2, 2 to 1.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
pub(crate) unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    // 8 -> 4: fold the upper 128-bit half onto the lower half.
    let upper = _mm256_extractf128_ps(v, 1);
    let lower = _mm256_castps256_ps128(v);
    let quad = _mm_add_ps(upper, lower);
    // 4 -> 2: add each odd lane onto the even lane before it.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // 2 -> 1: add lane 2 onto lane 0; only lane 0 of the result is read.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));
    _mm_cvtss_f32(total)
}
/// Safe entry point for the AVX2+FMA batch-of-4 kernels.
///
/// Uses the fixed-size kernel for 384-element queries and the general kernel
/// otherwise. Returns `[0.0; 4]` when any candidate's length differs from
/// `q.len()` (the same fallback value as `dot_product_batch4`). The check is
/// required for soundness: this function is a safe `fn` handed out by
/// `resolved_dot_product_batch4_kernel`, and both kernels read the candidates
/// through raw pointers using offsets derived from `q.len()` only.
#[cfg(target_arch = "x86_64")]
#[inline]
fn dot_product_batch4_avx2_kernel(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    let dim = q.len();
    if c0.len() != dim || c1.len() != dim || c2.len() != dim || c3.len() != dim {
        return [0.0; 4];
    }
    // SAFETY (both branches): all five slices were verified to have `dim`
    // elements, so every vector load stays in bounds. This kernel is only
    // installed by `resolve_dot_product_batch4_kernel` when the SIMD config
    // reports AVX2 and FMA as enabled (assumes that config reflects what the
    // CPU actually supports).
    if dim == 384 {
        unsafe { dot_product_384_batch4_avx2(q, c0, c1, c2, c3) }
    } else {
        unsafe { dot_product_batch4_avx2(q, c0, c1, c2, c3) }
    }
}
/// Batch-of-4 dot product specialised for 384-element inputs (AVX2 + FMA).
///
/// Each iteration loads two 8-lane query vectors once and reuses them against
/// all four candidates, keeping two accumulators per candidate. 384 is exactly
/// 24 such 16-element steps, so there is no tail.
///
/// # Safety
///
/// * The CPU must support AVX2 and FMA.
/// * `q`, `c0`, `c1`, `c2` and `c3` must each hold at least 384 elements. Only
///   `q.len()` is checked, and only by `debug_assert_eq!` in debug builds; all
///   loads go through raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_product_384_batch4_avx2(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    const DIM: usize = 384;
    const LANES: usize = 8;
    const STEP: usize = LANES * 2;
    const STEPS: usize = DIM / STEP;

    debug_assert_eq!(q.len(), DIM);
    let (pq, p0, p1, p2, p3) = (q.as_ptr(), c0.as_ptr(), c1.as_ptr(), c2.as_ptr(), c3.as_ptr());

    // `_lo` accumulates the first vector of each step, `_hi` the second.
    let mut c0_lo = _mm256_setzero_ps();
    let mut c0_hi = _mm256_setzero_ps();
    let mut c1_lo = _mm256_setzero_ps();
    let mut c1_hi = _mm256_setzero_ps();
    let mut c2_lo = _mm256_setzero_ps();
    let mut c2_hi = _mm256_setzero_ps();
    let mut c3_lo = _mm256_setzero_ps();
    let mut c3_hi = _mm256_setzero_ps();
    for step in 0..STEPS {
        let lo = step * STEP;
        let hi = lo + LANES;
        let q_lo = _mm256_loadu_ps(pq.add(lo));
        let q_hi = _mm256_loadu_ps(pq.add(hi));
        c0_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p0.add(lo)), c0_lo);
        c0_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p0.add(hi)), c0_hi);
        c1_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p1.add(lo)), c1_lo);
        c1_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p1.add(hi)), c1_hi);
        c2_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p2.add(lo)), c2_lo);
        c2_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p2.add(hi)), c2_hi);
        c3_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p3.add(lo)), c3_lo);
        c3_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p3.add(hi)), c3_hi);
    }

    [
        horizontal_sum_avx2(_mm256_add_ps(c0_lo, c0_hi)),
        horizontal_sum_avx2(_mm256_add_ps(c1_lo, c1_hi)),
        horizontal_sum_avx2(_mm256_add_ps(c2_lo, c2_hi)),
        horizontal_sum_avx2(_mm256_add_ps(c3_lo, c3_hi)),
    ]
}
/// Batch-of-4 dot product for arbitrary lengths (AVX2 + FMA).
///
/// Processes 16 elements per iteration: two 8-lane query vectors are loaded
/// once and reused against all four candidates, with two accumulators per
/// candidate. The remaining (fewer than 16) elements are handled by a scalar
/// loop.
///
/// # Safety
///
/// * The CPU must support AVX2 and FMA.
/// * `c0`, `c1`, `c2` and `c3` must each hold at least `q.len()` elements: the
///   vector loads read them through raw pointers at offsets derived from
///   `q.len()` only, and no length check is made here.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_product_batch4_avx2(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    const LANES: usize = 8;
    const STEP: usize = LANES * 2;

    let len = q.len();
    let steps = len / STEP;
    let (pq, p0, p1, p2, p3) = (q.as_ptr(), c0.as_ptr(), c1.as_ptr(), c2.as_ptr(), c3.as_ptr());

    // `_lo` accumulates the first vector of each step, `_hi` the second.
    let mut c0_lo = _mm256_setzero_ps();
    let mut c0_hi = _mm256_setzero_ps();
    let mut c1_lo = _mm256_setzero_ps();
    let mut c1_hi = _mm256_setzero_ps();
    let mut c2_lo = _mm256_setzero_ps();
    let mut c2_hi = _mm256_setzero_ps();
    let mut c3_lo = _mm256_setzero_ps();
    let mut c3_hi = _mm256_setzero_ps();
    for step in 0..steps {
        let lo = step * STEP;
        let hi = lo + LANES;
        let q_lo = _mm256_loadu_ps(pq.add(lo));
        let q_hi = _mm256_loadu_ps(pq.add(hi));
        c0_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p0.add(lo)), c0_lo);
        c0_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p0.add(hi)), c0_hi);
        c1_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p1.add(lo)), c1_lo);
        c1_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p1.add(hi)), c1_hi);
        c2_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p2.add(lo)), c2_lo);
        c2_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p2.add(hi)), c2_hi);
        c3_lo = _mm256_fmadd_ps(q_lo, _mm256_loadu_ps(p3.add(lo)), c3_lo);
        c3_hi = _mm256_fmadd_ps(q_hi, _mm256_loadu_ps(p3.add(hi)), c3_hi);
    }

    let mut dots = [
        horizontal_sum_avx2(_mm256_add_ps(c0_lo, c0_hi)),
        horizontal_sum_avx2(_mm256_add_ps(c1_lo, c1_hi)),
        horizontal_sum_avx2(_mm256_add_ps(c2_lo, c2_hi)),
        horizontal_sum_avx2(_mm256_add_ps(c3_lo, c3_hi)),
    ];

    // Scalar tail for the elements that did not fill a 16-wide step.
    for idx in (steps * STEP)..len {
        let qi = q[idx];
        dots[0] += qi * c0[idx];
        dots[1] += qi * c1[idx];
        dots[2] += qi * c2[idx];
        dots[3] += qi * c3[idx];
    }
    dots
}
/// Dot product of `a` and `b` using four independent NEON accumulators.
///
/// Work is done in three phases: 16-element blocks (four 4-lane FMAs per
/// iteration), then any remaining whole 4-lane vectors, then a scalar tail.
///
/// # Safety
///
/// `b` must hold at least `a.len()` elements: the vector loads read `b`
/// through raw pointers at offsets derived from `a.len()` only. Equal lengths
/// are checked by `debug_assert_eq!` in debug builds only.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_neon_unrolled(a: &[f32], b: &[f32]) -> f32 {
    const W: usize = 4;
    const ACCUMULATORS: usize = 4;
    const BLOCK: usize = W * ACCUMULATORS;

    let len = a.len();
    debug_assert_eq!(len, b.len());

    // Phase 1: full 16-element blocks, one accumulator per 4-lane vector.
    let full_blocks = len / BLOCK;
    let mut acc0 = vdupq_n_f32(0.0);
    let mut acc1 = vdupq_n_f32(0.0);
    let mut acc2 = vdupq_n_f32(0.0);
    let mut acc3 = vdupq_n_f32(0.0);
    for block in 0..full_blocks {
        let x = a.as_ptr().add(block * BLOCK);
        let y = b.as_ptr().add(block * BLOCK);
        acc0 = vfmaq_f32(acc0, vld1q_f32(x), vld1q_f32(y));
        acc1 = vfmaq_f32(acc1, vld1q_f32(x.add(W)), vld1q_f32(y.add(W)));
        acc2 = vfmaq_f32(acc2, vld1q_f32(x.add(W * 2)), vld1q_f32(y.add(W * 2)));
        acc3 = vfmaq_f32(acc3, vld1q_f32(x.add(W * 3)), vld1q_f32(y.add(W * 3)));
    }
    // Pairwise reduction of the accumulators: (0 + 1) + (2 + 3).
    let blocks_total =
        horizontal_sum_neon(vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3)));

    // Phase 2: leftover whole vectors (at most three).
    let blocks_done = full_blocks * BLOCK;
    let tail_vectors = (len - blocks_done) / W;
    let mut tail_acc = vdupq_n_f32(0.0);
    for v in 0..tail_vectors {
        let x = a.as_ptr().add(blocks_done + v * W);
        let y = b.as_ptr().add(blocks_done + v * W);
        tail_acc = vfmaq_f32(tail_acc, vld1q_f32(x), vld1q_f32(y));
    }
    let mut dot = blocks_total + horizontal_sum_neon(tail_acc);

    // Phase 3: scalar tail (fewer than 4 elements).
    let vectors_done = blocks_done + tail_vectors * W;
    for idx in vectors_done..len {
        dot += a[idx] * b[idx];
    }
    dot
}
/// Adds all `f32` lanes of `v` into a single scalar via `vaddvq_f32`.
///
/// # Safety
///
/// NOTE(review): the body only calls `vaddvq_f32` on a by-value vector; no
/// caller obligation is visible from this function alone.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) unsafe fn horizontal_sum_neon(v: float32x4_t) -> f32 {
vaddvq_f32(v)
}
/// Safe entry point for the NEON batch-of-4 kernel.
///
/// Returns `[0.0; 4]` when any candidate's length differs from `q.len()` (the
/// same fallback value as `dot_product_batch4`). The check is required for
/// soundness: this function is a safe `fn` handed out by
/// `resolved_dot_product_batch4_kernel`, and the underlying kernel reads the
/// candidates through raw pointers using offsets derived from `q.len()` only.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_product_batch4_neon_kernel(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    let dim = q.len();
    if c0.len() != dim || c1.len() != dim || c2.len() != dim || c3.len() != dim {
        return [0.0; 4];
    }
    // SAFETY: all five slices were verified to have `dim` elements, so every
    // vector load stays in bounds.
    unsafe { dot_product_batch4_neon(q, c0, c1, c2, c3) }
}
/// Batch-of-4 dot product for arbitrary lengths (NEON).
///
/// Processes 8 elements per iteration: two 4-lane query vectors are loaded
/// once and reused against all four candidates, with two accumulators per
/// candidate. The remaining (fewer than 8) elements are handled by a scalar
/// loop.
///
/// # Safety
///
/// `c0`, `c1`, `c2` and `c3` must each hold at least `q.len()` elements: the
/// vector loads read them through raw pointers at offsets derived from
/// `q.len()` only, and no length check is made here.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_batch4_neon(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    const LANES: usize = 4;
    const STEP: usize = LANES * 2;

    let len = q.len();
    let steps = len / STEP;
    let (pq, p0, p1, p2, p3) = (q.as_ptr(), c0.as_ptr(), c1.as_ptr(), c2.as_ptr(), c3.as_ptr());

    // `_lo` accumulates the first vector of each step, `_hi` the second.
    let mut c0_lo = vdupq_n_f32(0.0);
    let mut c0_hi = vdupq_n_f32(0.0);
    let mut c1_lo = vdupq_n_f32(0.0);
    let mut c1_hi = vdupq_n_f32(0.0);
    let mut c2_lo = vdupq_n_f32(0.0);
    let mut c2_hi = vdupq_n_f32(0.0);
    let mut c3_lo = vdupq_n_f32(0.0);
    let mut c3_hi = vdupq_n_f32(0.0);
    for step in 0..steps {
        let lo = step * STEP;
        let hi = lo + LANES;
        let q_lo = vld1q_f32(pq.add(lo));
        let q_hi = vld1q_f32(pq.add(hi));
        c0_lo = vfmaq_f32(c0_lo, q_lo, vld1q_f32(p0.add(lo)));
        c0_hi = vfmaq_f32(c0_hi, q_hi, vld1q_f32(p0.add(hi)));
        c1_lo = vfmaq_f32(c1_lo, q_lo, vld1q_f32(p1.add(lo)));
        c1_hi = vfmaq_f32(c1_hi, q_hi, vld1q_f32(p1.add(hi)));
        c2_lo = vfmaq_f32(c2_lo, q_lo, vld1q_f32(p2.add(lo)));
        c2_hi = vfmaq_f32(c2_hi, q_hi, vld1q_f32(p2.add(hi)));
        c3_lo = vfmaq_f32(c3_lo, q_lo, vld1q_f32(p3.add(lo)));
        c3_hi = vfmaq_f32(c3_hi, q_hi, vld1q_f32(p3.add(hi)));
    }

    let mut dots = [
        vaddvq_f32(vaddq_f32(c0_lo, c0_hi)),
        vaddvq_f32(vaddq_f32(c1_lo, c1_hi)),
        vaddvq_f32(vaddq_f32(c2_lo, c2_hi)),
        vaddvq_f32(vaddq_f32(c3_lo, c3_hi)),
    ];

    // Scalar tail for the elements that did not fill an 8-wide step.
    for idx in (steps * STEP)..len {
        let qi = q[idx];
        dots[0] += qi * c0[idx];
        dots[1] += qi * c1[idx];
        dots[2] += qi * c2[idx];
        dots[3] += qi * c3[idx];
    }
    dots
}
/// Reports whether a group of four `(query, candidate)` pairs can go through
/// the batch-of-4 kernel.
///
/// That requires every pair to reference the very same query slice (same
/// start pointer and same length as the first pair's query) and every
/// candidate to have that same length.
///
/// `chunk` is expected to hold exactly four pairs (checked in debug builds).
#[inline]
fn same_query_batch4(chunk: &[(&[f32], &[f32])]) -> bool {
    debug_assert_eq!(chunk.len(), 4);
    let (shared_query, first_candidate) = chunk[0];
    let dim = shared_query.len();
    if first_candidate.len() != dim {
        return false;
    }
    for &(query, candidate) in chunk {
        let is_shared_query = query.as_ptr() == shared_query.as_ptr() && query.len() == dim;
        if !is_shared_query || candidate.len() != dim {
            return false;
        }
    }
    true
}
/// Computes one dot product per `(a, b)` pair, in input order.
///
/// Pairs are walked in groups of four. A group whose four pairs share the
/// same query slice (see `same_query_batch4`) is dispatched to the batch-of-4
/// kernel; any other group, and the final partial group, is evaluated pair by
/// pair. A pair whose slices differ in length contributes `0.0`.
pub fn batch_dot_product(pairs: &[(&[f32], &[f32])]) -> Vec<f32> {
    let pair_kernel = resolved_dot_product_kernel();
    let batch4_kernel = resolved_dot_product_batch4_kernel();

    // Per-pair fallback shared by mixed groups and the trailing remainder.
    let score_pair = |a: &[f32], b: &[f32]| -> f32 {
        if a.len() == b.len() {
            pair_kernel(a, b)
        } else {
            0.0
        }
    };

    let mut results = Vec::with_capacity(pairs.len());
    let mut groups = pairs.chunks_exact(4);
    for group in groups.by_ref() {
        if same_query_batch4(group) {
            let (query, first) = group[0];
            let dots = batch4_kernel(query, first, group[1].1, group[2].1, group[3].1);
            results.extend(dots);
        } else {
            results.extend(group.iter().map(|&(a, b)| score_pair(a, b)));
        }
    }
    results.extend(groups.remainder().iter().map(|&(a, b)| score_pair(a, b)));
    results
}