Skip to main content

lattice_embed/simd/
dot_product.rs

1//! SIMD-accelerated dot product operations.
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6#[cfg(target_arch = "aarch64")]
7use std::arch::aarch64::*;
8
9use std::sync::OnceLock;
10
11use super::simd_config;
12
13/// SIMD kernel function pointer type for f32 dot product.
14pub type DotKernel = fn(&[f32], &[f32]) -> f32;
15
16static DOT_PRODUCT_KERNEL: OnceLock<DotKernel> = OnceLock::new();
17
18/// Resolve the best available f32 dot-product kernel once and return it.
19///
20/// Used by `batch_dot_product` to hoist SIMD dispatch out of batch loops.
21#[inline]
22pub fn resolved_dot_product_kernel() -> DotKernel {
23    *DOT_PRODUCT_KERNEL.get_or_init(resolve_dot_product_kernel)
24}
25
26fn resolve_dot_product_kernel() -> DotKernel {
27    let config = simd_config();
28
29    #[cfg(target_arch = "x86_64")]
30    {
31        if config.avx512f_enabled {
32            return dot_product_avx512_kernel;
33        }
34        if config.avx2_enabled && config.fma_enabled {
35            return dot_product_avx2_kernel;
36        }
37    }
38
39    #[cfg(target_arch = "aarch64")]
40    {
41        if config.neon_enabled {
42            return dot_product_neon_kernel;
43        }
44    }
45
46    dot_product_scalar
47}
48
49// ---------------------------------------------------------------------------
50// Batch-4 dot product kernel (query vs. 4 candidates simultaneously)
51// ---------------------------------------------------------------------------
52
53/// SIMD kernel type for batch-4 f32 dot product.
54///
55/// Signature: (query, c0, c1, c2, c3) → [dot(q,c0), dot(q,c1), dot(q,c2), dot(q,c3)].
56/// All slices must have equal length (enforced by `dot_product_batch4`).
57pub type DotBatch4Kernel = fn(&[f32], &[f32], &[f32], &[f32], &[f32]) -> [f32; 4];
58
59static DOT_PRODUCT_BATCH4_KERNEL: OnceLock<DotBatch4Kernel> = OnceLock::new();
60
61/// Resolve the best available batch-4 f32 dot-product kernel once and return it.
62///
63/// Used by HNSW expansion loop and `batch_dot_product` for same-query chunks.
64#[inline]
65pub fn resolved_dot_product_batch4_kernel() -> DotBatch4Kernel {
66    *DOT_PRODUCT_BATCH4_KERNEL.get_or_init(resolve_dot_product_batch4_kernel)
67}
68
69/// Compute dot product of one query against 4 candidates simultaneously.
70///
71/// Returns `[0.0; 4]` if any candidate length differs from the query length.
72#[inline]
73pub fn dot_product_batch4(
74    query: &[f32],
75    c0: &[f32],
76    c1: &[f32],
77    c2: &[f32],
78    c3: &[f32],
79) -> [f32; 4] {
80    if query.len() != c0.len()
81        || query.len() != c1.len()
82        || query.len() != c2.len()
83        || query.len() != c3.len()
84    {
85        debug_assert!(
86            false,
87            "dot_product_batch4: dimension mismatch (query={}, c0={}, c1={}, c2={}, c3={})",
88            query.len(),
89            c0.len(),
90            c1.len(),
91            c2.len(),
92            c3.len()
93        );
94        return [0.0; 4];
95    }
96    resolved_dot_product_batch4_kernel()(query, c0, c1, c2, c3)
97}
98
99fn resolve_dot_product_batch4_kernel() -> DotBatch4Kernel {
100    let config = simd_config();
101
102    #[cfg(target_arch = "x86_64")]
103    {
104        if config.avx2_enabled && config.fma_enabled {
105            return dot_product_batch4_avx2_kernel;
106        }
107    }
108
109    #[cfg(target_arch = "aarch64")]
110    {
111        if config.neon_enabled {
112            return dot_product_batch4_neon_kernel;
113        }
114    }
115
116    dot_product_batch4_scalar
117}
118
/// Portable batch-4 dot product, used when no SIMD kernel is selected.
///
/// Accumulates `q[i] * cN[i]` for each candidate in index order over `q.len()` elements.
/// Panics if any candidate is shorter than `q`; extra candidate elements are ignored.
fn dot_product_batch4_scalar(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    let dim = q.len();
    // Trim every candidate to the query length once; a short candidate panics here.
    let lanes = [&c0[..dim], &c1[..dim], &c2[..dim], &c3[..dim]];

    let mut acc = [0.0f32; 4];
    for (idx, &qv) in q.iter().enumerate() {
        for (slot, lane) in acc.iter_mut().zip(lanes.iter()) {
            *slot += qv * lane[idx];
        }
    }
    acc
}
137
138#[cfg(target_arch = "x86_64")]
139#[inline]
140fn dot_product_avx512_kernel(a: &[f32], b: &[f32]) -> f32 {
141    // SAFETY: only stored in DOT_PRODUCT_KERNEL when avx512f was detected at init time.
142    unsafe { dot_product_avx512_unrolled(a, b) }
143}
144
145#[cfg(target_arch = "x86_64")]
146#[inline]
147fn dot_product_avx2_kernel(a: &[f32], b: &[f32]) -> f32 {
148    // SAFETY: only stored in DOT_PRODUCT_KERNEL when avx2+fma were detected at init time.
149    if a.len() == 384 {
150        unsafe { dot_product_384_avx2(a, b) }
151    } else {
152        unsafe { dot_product_avx2_8acc(a, b) }
153    }
154}
155
/// Safe NEON entry point stored in `DOT_PRODUCT_KERNEL`.
///
/// Returns the dot product of `a` and `b`, or `0.0` when their lengths differ
/// (the same convention as `dot_product`).
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_product_neon_kernel(a: &[f32], b: &[f32]) -> f32 {
    // This function is handed out as a safe `DotKernel` pointer by the public
    // `resolved_dot_product_kernel()`, so it cannot assume callers validated the lengths:
    // the NEON body reads `b` through raw pointers for `a.len()` elements.
    if a.len() != b.len() {
        return 0.0;
    }
    // SAFETY: only stored in DOT_PRODUCT_KERNEL when neon was detected at init time
    // (always true on aarch64), and the guard above guarantees equal slice lengths.
    unsafe { dot_product_neon_unrolled(a, b) }
}
162
163/// **Unstable**: SIMD dispatch layer; use `lattice_embed::utils::dot_product` for the stable wrapper.
164///
165/// For normalized vectors, this equals cosine similarity.
166/// Returns 0.0 if vectors have different lengths.
167#[inline]
168pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
169    // Runtime length check to prevent UB in release builds
170    if a.len() != b.len() {
171        return 0.0;
172    }
173    debug_assert_eq!(a.len(), b.len());
174    resolved_dot_product_kernel()(a, b)
175}
176
/// Portable dot product: sums the element-wise products of `a` and `b` in index order.
///
/// Pairs are formed with `zip`, so when the lengths differ only the common prefix
/// contributes.
#[inline]
pub(crate) fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
182
183/// AVX-512F-accelerated dot product using FMA with 4x unrolling and multiple accumulators.
184///
185/// Processes 64 floats per iteration (4 x 16 floats) with 4 independent accumulators
186/// to break dependency chains and maximize throughput.
187///
188/// # Safety
189///
190/// Caller must ensure:
191/// - CPU supports AVX-512F instructions (verified via `simd_config()`)
192/// - `a` and `b` have equal length (checked by caller)
193///
194/// Memory safety:
195/// - Uses `_mm512_loadu_ps` for unaligned loads (safe for any alignment)
196/// - Pointer arithmetic stays within slice bounds via chunk/remainder calculation:
197///   `chunks = n / CHUNK_SIZE` (floor), so `chunks * CHUNK_SIZE <= n`.
198///   `remaining_chunks = remaining / SIMD_WIDTH` (floor), so all SIMD loads stay in bounds.
199/// - Final scalar loop iterates `scalar_start..n` using safe `a[i]` / `b[i]` indexing
200///   and never reads past the end of the slice.
201#[cfg(target_arch = "x86_64")]
202#[target_feature(enable = "avx512f")]
203unsafe fn dot_product_avx512_unrolled(a: &[f32], b: &[f32]) -> f32 {
204    const SIMD_WIDTH: usize = 16;
205    const UNROLL: usize = 4;
206    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL; // 64 floats per iteration
207
208    let n = a.len();
209    debug_assert_eq!(n, b.len());
210    let chunks = n / CHUNK_SIZE;
211
212    // 4 independent accumulators to break dependency chains
213    let mut sum0 = _mm512_setzero_ps();
214    let mut sum1 = _mm512_setzero_ps();
215    let mut sum2 = _mm512_setzero_ps();
216    let mut sum3 = _mm512_setzero_ps();
217
218    for i in 0..chunks {
219        let base = i * CHUNK_SIZE;
220
221        let a0 = _mm512_loadu_ps(a.as_ptr().add(base));
222        let b0 = _mm512_loadu_ps(b.as_ptr().add(base));
223        sum0 = _mm512_fmadd_ps(a0, b0, sum0);
224
225        let a1 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
226        let b1 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
227        sum1 = _mm512_fmadd_ps(a1, b1, sum1);
228
229        let a2 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
230        let b2 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
231        sum2 = _mm512_fmadd_ps(a2, b2, sum2);
232
233        let a3 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
234        let b3 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
235        sum3 = _mm512_fmadd_ps(a3, b3, sum3);
236    }
237
238    // Combine accumulators (dependencies are introduced only once at the end)
239    let sum01 = _mm512_add_ps(sum0, sum1);
240    let sum23 = _mm512_add_ps(sum2, sum3);
241    let sum_vec = _mm512_add_ps(sum01, sum23);
242
243    let main_sum = horizontal_sum_avx512(sum_vec);
244
245    // Handle remainder with single-register loop
246    let main_processed = chunks * CHUNK_SIZE;
247    let remaining = n - main_processed;
248    let remaining_chunks = remaining / SIMD_WIDTH;
249
250    let mut remainder_sum = _mm512_setzero_ps();
251    for i in 0..remaining_chunks {
252        let offset = main_processed + i * SIMD_WIDTH;
253        let a_vec = _mm512_loadu_ps(a.as_ptr().add(offset));
254        let b_vec = _mm512_loadu_ps(b.as_ptr().add(offset));
255        remainder_sum = _mm512_fmadd_ps(a_vec, b_vec, remainder_sum);
256    }
257
258    let mut total = main_sum + horizontal_sum_avx512(remainder_sum);
259
260    // Final scalar remainder
261    let scalar_start = main_processed + remaining_chunks * SIMD_WIDTH;
262    for i in scalar_start..n {
263        total += a[i] * b[i];
264    }
265
266    total
267}
268
269/// Horizontal sum of AVX-512 register (16 floats -> 1 float).
270///
271/// # Safety
272///
273/// Caller must ensure CPU supports AVX-512F (verified via `target_feature` gate).
274#[cfg(target_arch = "x86_64")]
275#[target_feature(enable = "avx512f")]
276#[inline]
277pub(crate) unsafe fn horizontal_sum_avx512(v: __m512) -> f32 {
278    _mm512_reduce_add_ps(v)
279}
280
281/// AVX2-accelerated dot product using 8 independent accumulators.
282///
283/// Processes 64 floats per iteration (8 x 8 floats) with 8 independent accumulators
284/// to better hide FMA latency on modern x86 CPUs. AVX2 provides 16 YMM registers;
285/// 8 accumulators + 1 A load + 1 B load = 10 registers, well within budget.
286///
287/// # Safety
288///
289/// Caller must ensure:
290/// - CPU supports AVX2 and FMA instructions (verified via `simd_config()`)
291/// - `a` and `b` have equal length (checked by caller)
292///
293/// Memory safety:
294/// - Uses `_mm256_loadu_ps` for unaligned loads (safe for any alignment)
295/// - Pointer arithmetic stays within slice bounds via chunk/remainder calculation:
296///   `chunks = n / CHUNK_SIZE` (floor), so `chunks * CHUNK_SIZE <= n`.
297///   `remaining_chunks = remaining / SIMD_WIDTH` (floor), so all SIMD loads stay in bounds.
298/// - Final scalar loop uses safe `a[i]` / `b[i]` indexing and never reads past slice end.
299#[cfg(target_arch = "x86_64")]
300#[target_feature(enable = "avx2", enable = "fma")]
301unsafe fn dot_product_avx2_8acc(a: &[f32], b: &[f32]) -> f32 {
302    const SIMD_WIDTH: usize = 8;
303    const UNROLL: usize = 8;
304    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL; // 64 floats per iteration
305    let n = a.len();
306    debug_assert_eq!(n, b.len());
307    let chunks = n / CHUNK_SIZE;
308
309    // 8 independent accumulators to break dependency chains
310    let mut sum0 = _mm256_setzero_ps();
311    let mut sum1 = _mm256_setzero_ps();
312    let mut sum2 = _mm256_setzero_ps();
313    let mut sum3 = _mm256_setzero_ps();
314    let mut sum4 = _mm256_setzero_ps();
315    let mut sum5 = _mm256_setzero_ps();
316    let mut sum6 = _mm256_setzero_ps();
317    let mut sum7 = _mm256_setzero_ps();
318
319    for i in 0..chunks {
320        let base = i * CHUNK_SIZE;
321
322        let a0 = _mm256_loadu_ps(a.as_ptr().add(base));
323        let b0 = _mm256_loadu_ps(b.as_ptr().add(base));
324        sum0 = _mm256_fmadd_ps(a0, b0, sum0);
325
326        let a1 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
327        let b1 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
328        sum1 = _mm256_fmadd_ps(a1, b1, sum1);
329
330        let a2 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
331        let b2 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
332        sum2 = _mm256_fmadd_ps(a2, b2, sum2);
333
334        let a3 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
335        let b3 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
336        sum3 = _mm256_fmadd_ps(a3, b3, sum3);
337
338        let a4 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 4));
339        let b4 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 4));
340        sum4 = _mm256_fmadd_ps(a4, b4, sum4);
341
342        let a5 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 5));
343        let b5 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 5));
344        sum5 = _mm256_fmadd_ps(a5, b5, sum5);
345
346        let a6 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 6));
347        let b6 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 6));
348        sum6 = _mm256_fmadd_ps(a6, b6, sum6);
349
350        let a7 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 7));
351        let b7 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 7));
352        sum7 = _mm256_fmadd_ps(a7, b7, sum7);
353    }
354
355    // Combine accumulators pairwise to reduce dependency chain depth
356    let sum01 = _mm256_add_ps(sum0, sum1);
357    let sum23 = _mm256_add_ps(sum2, sum3);
358    let sum45 = _mm256_add_ps(sum4, sum5);
359    let sum67 = _mm256_add_ps(sum6, sum7);
360    let sum0123 = _mm256_add_ps(sum01, sum23);
361    let sum4567 = _mm256_add_ps(sum45, sum67);
362    let sum_vec = _mm256_add_ps(sum0123, sum4567);
363
364    let sum = horizontal_sum_avx2(sum_vec);
365
366    // Handle remainder with single-vector loop
367    let main_processed = chunks * CHUNK_SIZE;
368    let remaining = n - main_processed;
369    let remaining_chunks = remaining / SIMD_WIDTH;
370
371    let mut remainder_sum = _mm256_setzero_ps();
372    for i in 0..remaining_chunks {
373        let offset = main_processed + i * SIMD_WIDTH;
374        let a_vec = _mm256_loadu_ps(a.as_ptr().add(offset));
375        let b_vec = _mm256_loadu_ps(b.as_ptr().add(offset));
376        remainder_sum = _mm256_fmadd_ps(a_vec, b_vec, remainder_sum);
377    }
378
379    let mut total = sum + horizontal_sum_avx2(remainder_sum);
380
381    // Final scalar remainder
382    let scalar_start = main_processed + remaining_chunks * SIMD_WIDTH;
383    for i in scalar_start..n {
384        total += a[i] * b[i];
385    }
386
387    total
388}
389
390/// AVX2-accelerated dot product specialized for 384-dimension vectors.
391///
392/// 384 = 48 x 8, so 384d vectors divide evenly into 48 AVX2 iterations
393/// with zero remainder. This eliminates all remainder handling branches.
394/// Uses 8 accumulators across 6 iterations of 8 FMAs each (48 total).
395///
396/// # Safety
397///
398/// Caller must ensure:
399/// - CPU supports AVX2 and FMA instructions (verified via `simd_config()`)
400/// - `a` and `b` have equal length == 384 (checked by caller)
401///
402/// Memory safety:
403/// - Uses `_mm256_loadu_ps` for unaligned loads (safe for any alignment)
404/// - Fixed iteration count (48) covers exactly 384 elements, no out-of-bounds
405#[cfg(target_arch = "x86_64")]
406#[target_feature(enable = "avx2", enable = "fma")]
407unsafe fn dot_product_384_avx2(a: &[f32], b: &[f32]) -> f32 {
408    const SIMD_WIDTH: usize = 8;
409    // 384 / 8 = 48 iterations, processed as 6 groups of 8 for accumulator reuse
410    const UNROLL: usize = 8;
411    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL; // 64 floats per iteration
412    const CHUNKS: usize = 384 / CHUNK_SIZE; // 6 full chunks
413    const TAIL_ITERS: usize = (384 - CHUNKS * CHUNK_SIZE) / SIMD_WIDTH; // 0 remainder
414
415    debug_assert_eq!(a.len(), 384);
416    debug_assert_eq!(b.len(), 384);
417    debug_assert_eq!(CHUNKS * CHUNK_SIZE + TAIL_ITERS * SIMD_WIDTH, 384);
418
419    // 8 independent accumulators
420    let mut sum0 = _mm256_setzero_ps();
421    let mut sum1 = _mm256_setzero_ps();
422    let mut sum2 = _mm256_setzero_ps();
423    let mut sum3 = _mm256_setzero_ps();
424    let mut sum4 = _mm256_setzero_ps();
425    let mut sum5 = _mm256_setzero_ps();
426    let mut sum6 = _mm256_setzero_ps();
427    let mut sum7 = _mm256_setzero_ps();
428
429    // 6 full chunks of 64 elements = 384 elements total
430    for i in 0..CHUNKS {
431        let base = i * CHUNK_SIZE;
432
433        let a0 = _mm256_loadu_ps(a.as_ptr().add(base));
434        let b0 = _mm256_loadu_ps(b.as_ptr().add(base));
435        sum0 = _mm256_fmadd_ps(a0, b0, sum0);
436
437        let a1 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
438        let b1 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
439        sum1 = _mm256_fmadd_ps(a1, b1, sum1);
440
441        let a2 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
442        let b2 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
443        sum2 = _mm256_fmadd_ps(a2, b2, sum2);
444
445        let a3 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
446        let b3 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
447        sum3 = _mm256_fmadd_ps(a3, b3, sum3);
448
449        let a4 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 4));
450        let b4 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 4));
451        sum4 = _mm256_fmadd_ps(a4, b4, sum4);
452
453        let a5 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 5));
454        let b5 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 5));
455        sum5 = _mm256_fmadd_ps(a5, b5, sum5);
456
457        let a6 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 6));
458        let b6 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 6));
459        sum6 = _mm256_fmadd_ps(a6, b6, sum6);
460
461        let a7 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 7));
462        let b7 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 7));
463        sum7 = _mm256_fmadd_ps(a7, b7, sum7);
464    }
465
466    // Combine accumulators pairwise
467    let sum01 = _mm256_add_ps(sum0, sum1);
468    let sum23 = _mm256_add_ps(sum2, sum3);
469    let sum45 = _mm256_add_ps(sum4, sum5);
470    let sum67 = _mm256_add_ps(sum6, sum7);
471    let sum0123 = _mm256_add_ps(sum01, sum23);
472    let sum4567 = _mm256_add_ps(sum45, sum67);
473    let sum_vec = _mm256_add_ps(sum0123, sum4567);
474
475    horizontal_sum_avx2(sum_vec)
476}
477
/// Collapse the 8 `f32` lanes of an AVX2 register into their sum.
///
/// With input lanes `v0..v7` the result is
/// `((v4+v0) + (v5+v1)) + ((v6+v2) + (v7+v3))`.
///
/// # Safety
///
/// The CPU must support AVX2; this function is compiled with that target feature
/// enabled and must not be reached on hardware without it.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
pub(crate) unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    // Fold the upper 128-bit half onto the lower half: 8 lanes -> 4.
    let folded = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));

    // Add odd lanes onto even lanes: lane 0 holds f0+f1, lane 2 holds f2+f3.
    let pairs = _mm_add_ps(folded, _mm_movehdup_ps(folded));

    // Bring lane 2 down to lane 0 and add; only lane 0 of the result is meaningful.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));

    _mm_cvtss_f32(total)
}
500
501/// AVX2 batch-4 dot product kernel wrapper (routes to 384-specialized or general path).
502///
503/// # Safety
504///
505/// Only stored in `DOT_PRODUCT_BATCH4_KERNEL` when AVX2+FMA was detected at init time.
506#[cfg(target_arch = "x86_64")]
507#[inline]
508fn dot_product_batch4_avx2_kernel(
509    q: &[f32],
510    c0: &[f32],
511    c1: &[f32],
512    c2: &[f32],
513    c3: &[f32],
514) -> [f32; 4] {
515    if q.len() == 384 {
516        unsafe { dot_product_384_batch4_avx2(q, c0, c1, c2, c3) }
517    } else {
518        unsafe { dot_product_batch4_avx2(q, c0, c1, c2, c3) }
519    }
520}
521
522/// AVX2 batch-4 dot product specialized for 384-dimension vectors.
523///
524/// 384 = 24 × 16, processed exactly as 24 chunks (2 AVX2 loads per chunk) with no
525/// scalar remainder, matching the existing `dot_product_384_avx2` specialization.
526///
527/// # Safety
528///
529/// Caller must ensure:
530/// - CPU supports AVX2 and FMA (verified via dispatch table)
531/// - All 5 slices have equal length == 384 (enforced by `dot_product_batch4`)
532#[cfg(target_arch = "x86_64")]
533#[target_feature(enable = "avx2", enable = "fma")]
534unsafe fn dot_product_384_batch4_avx2(
535    q: &[f32],
536    c0: &[f32],
537    c1: &[f32],
538    c2: &[f32],
539    c3: &[f32],
540) -> [f32; 4] {
541    const W: usize = 8; // floats per AVX2 register
542    const CHUNK: usize = W * 2; // 16 floats per loop (2 query loads reused across 4 candidates)
543    const CHUNKS: usize = 384 / CHUNK; // 24 chunks, zero remainder
544
545    debug_assert_eq!(q.len(), 384);
546
547    let mut acc00 = _mm256_setzero_ps();
548    let mut acc01 = _mm256_setzero_ps();
549    let mut acc10 = _mm256_setzero_ps();
550    let mut acc11 = _mm256_setzero_ps();
551    let mut acc20 = _mm256_setzero_ps();
552    let mut acc21 = _mm256_setzero_ps();
553    let mut acc30 = _mm256_setzero_ps();
554    let mut acc31 = _mm256_setzero_ps();
555
556    for i in 0..CHUNKS {
557        let base = i * CHUNK;
558        let q0 = _mm256_loadu_ps(q.as_ptr().add(base));
559        let q1 = _mm256_loadu_ps(q.as_ptr().add(base + W));
560
561        acc00 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c0.as_ptr().add(base)), acc00);
562        acc01 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c0.as_ptr().add(base + W)), acc01);
563        acc10 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c1.as_ptr().add(base)), acc10);
564        acc11 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c1.as_ptr().add(base + W)), acc11);
565        acc20 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c2.as_ptr().add(base)), acc20);
566        acc21 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c2.as_ptr().add(base + W)), acc21);
567        acc30 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c3.as_ptr().add(base)), acc30);
568        acc31 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c3.as_ptr().add(base + W)), acc31);
569    }
570
571    [
572        horizontal_sum_avx2(_mm256_add_ps(acc00, acc01)),
573        horizontal_sum_avx2(_mm256_add_ps(acc10, acc11)),
574        horizontal_sum_avx2(_mm256_add_ps(acc20, acc21)),
575        horizontal_sum_avx2(_mm256_add_ps(acc30, acc31)),
576    ]
577}
578
579/// AVX2 batch-4 dot product for arbitrary-length vectors.
580///
581/// Processes 16 floats per loop (2 AVX2 query loads reused across 4 candidates),
582/// with 2 accumulators per candidate to break FMA dependency chains.
583///
584/// # Safety
585///
586/// Caller must ensure:
587/// - CPU supports AVX2 and FMA (verified via dispatch table)
588/// - All 5 slices have equal length (enforced by `dot_product_batch4`)
589///
590/// Memory safety:
591/// - `chunks * CHUNK <= q.len()` by construction (floor division)
592/// - Scalar tail uses safe `q[i]` / `cN[i]` indexing within slice bounds
593#[cfg(target_arch = "x86_64")]
594#[target_feature(enable = "avx2", enable = "fma")]
595unsafe fn dot_product_batch4_avx2(
596    q: &[f32],
597    c0: &[f32],
598    c1: &[f32],
599    c2: &[f32],
600    c3: &[f32],
601) -> [f32; 4] {
602    const W: usize = 8;
603    const CHUNK: usize = W * 2; // 16 floats per loop
604
605    let n = q.len();
606    let chunks = n / CHUNK;
607
608    let mut acc00 = _mm256_setzero_ps();
609    let mut acc01 = _mm256_setzero_ps();
610    let mut acc10 = _mm256_setzero_ps();
611    let mut acc11 = _mm256_setzero_ps();
612    let mut acc20 = _mm256_setzero_ps();
613    let mut acc21 = _mm256_setzero_ps();
614    let mut acc30 = _mm256_setzero_ps();
615    let mut acc31 = _mm256_setzero_ps();
616
617    for i in 0..chunks {
618        let base = i * CHUNK;
619        let q0 = _mm256_loadu_ps(q.as_ptr().add(base));
620        let q1 = _mm256_loadu_ps(q.as_ptr().add(base + W));
621
622        acc00 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c0.as_ptr().add(base)), acc00);
623        acc01 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c0.as_ptr().add(base + W)), acc01);
624        acc10 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c1.as_ptr().add(base)), acc10);
625        acc11 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c1.as_ptr().add(base + W)), acc11);
626        acc20 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c2.as_ptr().add(base)), acc20);
627        acc21 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c2.as_ptr().add(base + W)), acc21);
628        acc30 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c3.as_ptr().add(base)), acc30);
629        acc31 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c3.as_ptr().add(base + W)), acc31);
630    }
631
632    let mut out = [
633        horizontal_sum_avx2(_mm256_add_ps(acc00, acc01)),
634        horizontal_sum_avx2(_mm256_add_ps(acc10, acc11)),
635        horizontal_sum_avx2(_mm256_add_ps(acc20, acc21)),
636        horizontal_sum_avx2(_mm256_add_ps(acc30, acc31)),
637    ];
638
639    let scalar_start = chunks * CHUNK;
640    for i in scalar_start..n {
641        let qi = q[i];
642        out[0] += qi * c0[i];
643        out[1] += qi * c1[i];
644        out[2] += qi * c2[i];
645        out[3] += qi * c3[i];
646    }
647
648    out
649}
650
/// NEON dot product: fused multiply-add over four independent 4-float accumulators.
///
/// Each pass of the main loop consumes 16 floats (4 registers x 4 lanes). Leftover
/// 4-float groups go through a single-register loop and the final `< 4` elements are
/// handled with scalar code.
///
/// # Safety
///
/// Caller must ensure:
/// - the code is running on aarch64 (NEON is mandatory there)
/// - `a` and `b` have equal length
///
/// Memory safety:
/// - `vld1q_f32` loads have no alignment requirement beyond that of `f32`
/// - the main loop only runs while at least 16 elements remain past `pos`, and the
///   single-register loop only while at least 4 remain, so every load ends inside the
///   slices
/// - the scalar tail uses bounds-checked slicing of `pos..len`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_neon_unrolled(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 4;
    const STRIDE: usize = 4 * LANES; // 16 floats per main-loop pass

    let len = a.len();
    debug_assert_eq!(len, b.len());
    let lhs = a.as_ptr();
    let rhs = b.as_ptr();

    // Multiply-accumulate one 4-float group at `offset` into the named accumulator.
    macro_rules! fma_lane {
        ($acc:ident, $offset:expr) => {
            $acc = vfmaq_f32($acc, vld1q_f32(lhs.add($offset)), vld1q_f32(rhs.add($offset)))
        };
    }

    let mut acc0 = vdupq_n_f32(0.0);
    let mut acc1 = vdupq_n_f32(0.0);
    let mut acc2 = vdupq_n_f32(0.0);
    let mut acc3 = vdupq_n_f32(0.0);

    let mut pos = 0usize;
    while len - pos >= STRIDE {
        fma_lane!(acc0, pos);
        fma_lane!(acc1, pos + LANES);
        fma_lane!(acc2, pos + LANES * 2);
        fma_lane!(acc3, pos + LANES * 3);
        pos += STRIDE;
    }

    // Pairwise merge: (acc0 + acc1) + (acc2 + acc3), then reduce across lanes.
    let wide_sum = horizontal_sum_neon(vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3)));

    // Whole 4-float groups that did not fill a 16-float stride.
    let mut tail_acc = vdupq_n_f32(0.0);
    while len - pos >= LANES {
        fma_lane!(tail_acc, pos);
        pos += LANES;
    }

    let mut total = wide_sum + horizontal_sum_neon(tail_acc);

    // Fewer than 4 elements remain; finish with scalar code.
    for (x, y) in a[pos..len].iter().zip(&b[pos..len]) {
        total += x * y;
    }

    total
}
733
/// Collapse the 4 `f32` lanes of a NEON register into their sum.
///
/// Thin wrapper over `vaddvq_f32`.
///
/// # Safety
///
/// Must only be called on aarch64, where NEON is a mandatory part of the architecture.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) unsafe fn horizontal_sum_neon(v: float32x4_t) -> f32 {
    vaddvq_f32(v)
}
744
/// Safe NEON batch-4 entry point stored in `DOT_PRODUCT_BATCH4_KERNEL`.
///
/// Returns the four dot products of `q` against `c0..c3`, or `[0.0; 4]` when any
/// candidate's length differs from the query's (the same release-build convention as
/// `dot_product_batch4`).
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_product_batch4_neon_kernel(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    // This function is handed out as a safe `DotBatch4Kernel` pointer by the public
    // `resolved_dot_product_batch4_kernel()`, so it cannot assume callers validated the
    // lengths: the NEON body reads every candidate through raw pointers sized by
    // `q.len()`.
    let dim = q.len();
    if c0.len() != dim || c1.len() != dim || c2.len() != dim || c3.len() != dim {
        return [0.0; 4];
    }
    // SAFETY: only stored when NEON detected, which is mandatory on aarch64, and the
    // guard above guarantees all five slices have the same length.
    unsafe { dot_product_batch4_neon(q, c0, c1, c2, c3) }
}
762
/// NEON batch-4 dot product: one query against four candidates in a single sweep.
///
/// Each pass handles 8 floats: two 4-float query registers are loaded once and reused
/// against all four candidates. Every candidate has two accumulators so back-to-back
/// `vfmaq_f32` operations stay independent. Elements beyond the last full 8-float pass
/// are handled with scalar code.
///
/// # Safety
///
/// Caller must ensure:
/// - the code is running on aarch64 (NEON is mandatory there)
/// - all five slices have equal length (enforced by `dot_product_batch4`)
///
/// Memory safety:
/// - the vector loop only runs while at least 8 elements remain past `pos`, so every
///   load ends inside the slices
/// - the scalar tail uses bounds-checked slicing of `pos..n`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_batch4_neon(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    const LANES: usize = 4; // floats per NEON register
    const STRIDE: usize = 2 * LANES; // 8 floats per pass

    let n = q.len();
    let qp = q.as_ptr();
    let (p0, p1, p2, p3) = (c0.as_ptr(), c1.as_ptr(), c2.as_ptr(), c3.as_ptr());

    // Two accumulators per candidate: `_lo` pairs with the first query register of a
    // pass, `_hi` with the second.
    let mut c0_lo = vdupq_n_f32(0.0);
    let mut c0_hi = vdupq_n_f32(0.0);
    let mut c1_lo = vdupq_n_f32(0.0);
    let mut c1_hi = vdupq_n_f32(0.0);
    let mut c2_lo = vdupq_n_f32(0.0);
    let mut c2_hi = vdupq_n_f32(0.0);
    let mut c3_lo = vdupq_n_f32(0.0);
    let mut c3_hi = vdupq_n_f32(0.0);

    let mut pos = 0usize;
    while n - pos >= STRIDE {
        let next = pos + LANES;
        let q_lo = vld1q_f32(qp.add(pos));
        let q_hi = vld1q_f32(qp.add(next));

        c0_lo = vfmaq_f32(c0_lo, q_lo, vld1q_f32(p0.add(pos)));
        c0_hi = vfmaq_f32(c0_hi, q_hi, vld1q_f32(p0.add(next)));
        c1_lo = vfmaq_f32(c1_lo, q_lo, vld1q_f32(p1.add(pos)));
        c1_hi = vfmaq_f32(c1_hi, q_hi, vld1q_f32(p1.add(next)));
        c2_lo = vfmaq_f32(c2_lo, q_lo, vld1q_f32(p2.add(pos)));
        c2_hi = vfmaq_f32(c2_hi, q_hi, vld1q_f32(p2.add(next)));
        c3_lo = vfmaq_f32(c3_lo, q_lo, vld1q_f32(p3.add(pos)));
        c3_hi = vfmaq_f32(c3_hi, q_hi, vld1q_f32(p3.add(next)));

        pos += STRIDE;
    }

    let mut dots = [
        vaddvq_f32(vaddq_f32(c0_lo, c0_hi)),
        vaddvq_f32(vaddq_f32(c1_lo, c1_hi)),
        vaddvq_f32(vaddq_f32(c2_lo, c2_hi)),
        vaddvq_f32(vaddq_f32(c3_lo, c3_hi)),
    ];

    // Fewer than 8 elements remain; finish each candidate with scalar code.
    let tails = [&c0[pos..n], &c1[pos..n], &c2[pos..n], &c3[pos..n]];
    for (k, &qv) in q[pos..n].iter().enumerate() {
        for (slot, tail) in dots.iter_mut().zip(tails.iter()) {
            *slot += qv * tail[k];
        }
    }

    dots
}
835
/// Decide whether a 4-pair chunk can be handled by the batch-4 kernel.
///
/// Returns `true` only if every pair's query slice starts at the same address and has
/// the same length as the first pair's query, and every candidate has that length too.
/// Expects `chunk.len() == 4` (checked in debug builds); an empty chunk panics.
#[inline]
fn same_query_batch4(chunk: &[(&[f32], &[f32])]) -> bool {
    debug_assert_eq!(chunk.len(), 4);

    let (lead_query, lead_candidate) = chunk[0];
    let anchor = lead_query.as_ptr();
    let dim = lead_query.len();

    if lead_candidate.len() != dim {
        return false;
    }
    for &(query, candidate) in chunk {
        let shares_query = query.as_ptr() == anchor && query.len() == dim;
        if !shares_query || candidate.len() != dim {
            return false;
        }
    }
    true
}
847
848/// **Unstable**: SIMD batch dispatch; use `lattice_embed::utils::batch_dot_product` for stable wrapper.
849///
850/// Uses the batch-4 SIMD kernel for consecutive same-query chunks (e.g., query-vs-N
851/// search pattern where the left-hand slice is the same borrowed reference for every pair).
852/// Falls back to the per-pair kernel for mixed or remainder inputs.
853pub fn batch_dot_product(pairs: &[(&[f32], &[f32])]) -> Vec<f32> {
854    let pair_kernel = resolved_dot_product_kernel();
855    let batch4_kernel = resolved_dot_product_batch4_kernel();
856    let mut out = Vec::with_capacity(pairs.len());
857
858    let mut chunks = pairs.chunks_exact(4);
859    for chunk in &mut chunks {
860        if same_query_batch4(chunk) {
861            let q = chunk[0].0;
862            let dots = batch4_kernel(q, chunk[0].1, chunk[1].1, chunk[2].1, chunk[3].1);
863            out.extend_from_slice(&dots);
864        } else {
865            for &(a, b) in chunk {
866                out.push(if a.len() == b.len() {
867                    pair_kernel(a, b)
868                } else {
869                    0.0
870                });
871            }
872        }
873    }
874    for &(a, b) in chunks.remainder() {
875        out.push(if a.len() == b.len() {
876            pair_kernel(a, b)
877        } else {
878            0.0
879        });
880    }
881    out
882}