1#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6#[cfg(target_arch = "aarch64")]
7use std::arch::aarch64::*;
8
9use std::sync::OnceLock;
10
11use super::simd_config;
12
13pub type DotKernel = fn(&[f32], &[f32]) -> f32;
15
16static DOT_PRODUCT_KERNEL: OnceLock<DotKernel> = OnceLock::new();
17
18#[inline]
22pub fn resolved_dot_product_kernel() -> DotKernel {
23 *DOT_PRODUCT_KERNEL.get_or_init(resolve_dot_product_kernel)
24}
25
26fn resolve_dot_product_kernel() -> DotKernel {
27 let config = simd_config();
28
29 #[cfg(target_arch = "x86_64")]
30 {
31 if config.avx512f_enabled {
32 return dot_product_avx512_kernel;
33 }
34 if config.avx2_enabled && config.fma_enabled {
35 return dot_product_avx2_kernel;
36 }
37 }
38
39 #[cfg(target_arch = "aarch64")]
40 {
41 if config.neon_enabled {
42 return dot_product_neon_kernel;
43 }
44 }
45
46 dot_product_scalar
47}
48
49pub type DotBatch4Kernel = fn(&[f32], &[f32], &[f32], &[f32], &[f32]) -> [f32; 4];
58
59static DOT_PRODUCT_BATCH4_KERNEL: OnceLock<DotBatch4Kernel> = OnceLock::new();
60
61#[inline]
65pub fn resolved_dot_product_batch4_kernel() -> DotBatch4Kernel {
66 *DOT_PRODUCT_BATCH4_KERNEL.get_or_init(resolve_dot_product_batch4_kernel)
67}
68
69#[inline]
73pub fn dot_product_batch4(
74 query: &[f32],
75 c0: &[f32],
76 c1: &[f32],
77 c2: &[f32],
78 c3: &[f32],
79) -> [f32; 4] {
80 if query.len() != c0.len()
81 || query.len() != c1.len()
82 || query.len() != c2.len()
83 || query.len() != c3.len()
84 {
85 debug_assert!(
86 false,
87 "dot_product_batch4: dimension mismatch (query={}, c0={}, c1={}, c2={}, c3={})",
88 query.len(),
89 c0.len(),
90 c1.len(),
91 c2.len(),
92 c3.len()
93 );
94 return [0.0; 4];
95 }
96 resolved_dot_product_batch4_kernel()(query, c0, c1, c2, c3)
97}
98
99fn resolve_dot_product_batch4_kernel() -> DotBatch4Kernel {
100 let config = simd_config();
101
102 #[cfg(target_arch = "x86_64")]
103 {
104 if config.avx2_enabled && config.fma_enabled {
105 return dot_product_batch4_avx2_kernel;
106 }
107 }
108
109 #[cfg(target_arch = "aarch64")]
110 {
111 if config.neon_enabled {
112 return dot_product_batch4_neon_kernel;
113 }
114 }
115
116 dot_product_batch4_scalar
117}
118
/// Portable fallback for the batch-of-four kernel.
///
/// Walks the query once and accumulates `q[i] * cK[i]` into four running sums, one
/// per candidate, in index order. Only the first `q.len()` elements of each
/// candidate are read; a candidate shorter than the query causes a panic.
fn dot_product_batch4_scalar(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    let dim = q.len();
    // Trim every candidate to the query length up front: an undersized candidate
    // panics here, and the loop below then needs no per-element bounds checks.
    let (c0, c1, c2, c3) = (&c0[..dim], &c1[..dim], &c2[..dim], &c3[..dim]);

    let mut acc = [0.0f32; 4];
    for ((((qv, v0), v1), v2), v3) in q.iter().zip(c0).zip(c1).zip(c2).zip(c3) {
        acc[0] += qv * v0;
        acc[1] += qv * v1;
        acc[2] += qv * v2;
        acc[3] += qv * v3;
    }
    acc
}
137
138#[cfg(target_arch = "x86_64")]
139#[inline]
140fn dot_product_avx512_kernel(a: &[f32], b: &[f32]) -> f32 {
141 unsafe { dot_product_avx512_unrolled(a, b) }
143}
144
145#[cfg(target_arch = "x86_64")]
146#[inline]
147fn dot_product_avx2_kernel(a: &[f32], b: &[f32]) -> f32 {
148 if a.len() == 384 {
150 unsafe { dot_product_384_avx2(a, b) }
151 } else {
152 unsafe { dot_product_avx2_8acc(a, b) }
153 }
154}
155
/// Safe [`DotKernel`]-compatible entry point for the NEON dot product.
///
/// Because this function is safe and is handed out through the public
/// `resolved_dot_product_kernel` accessor, it validates the slice lengths itself:
/// on a mismatch it trips a debug assertion and returns `0.0`, matching
/// `dot_product`, rather than letting the vector loads read past the end of the
/// shorter slice.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_product_neon_kernel(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        debug_assert!(
            false,
            "dot_product_neon_kernel: dimension mismatch (a={}, b={})",
            a.len(),
            b.len()
        );
        return 0.0;
    }
    // SAFETY: the lengths are equal (checked above), so every unchecked load in the
    // kernel stays inside both slices. This function is only selected by
    // `resolve_dot_product_kernel` when `neon_enabled` is set in the SIMD config.
    unsafe { dot_product_neon_unrolled(a, b) }
}
162
163#[inline]
168pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
169 if a.len() != b.len() {
171 return 0.0;
172 }
173 debug_assert_eq!(a.len(), b.len());
174 resolved_dot_product_kernel()(a, b)
175}
176
/// Portable dot product: multiplies the slices element by element and sums the
/// products in index order.
///
/// When the lengths differ, only the common prefix contributes, because `zip` stops
/// at the shorter slice.
#[inline]
pub(crate) fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum::<f32>()
}
182
183#[cfg(target_arch = "x86_64")]
202#[target_feature(enable = "avx512f")]
203unsafe fn dot_product_avx512_unrolled(a: &[f32], b: &[f32]) -> f32 {
204 const SIMD_WIDTH: usize = 16;
205 const UNROLL: usize = 4;
206 const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL; let n = a.len();
209 debug_assert_eq!(n, b.len());
210 let chunks = n / CHUNK_SIZE;
211
212 let mut sum0 = _mm512_setzero_ps();
214 let mut sum1 = _mm512_setzero_ps();
215 let mut sum2 = _mm512_setzero_ps();
216 let mut sum3 = _mm512_setzero_ps();
217
218 for i in 0..chunks {
219 let base = i * CHUNK_SIZE;
220
221 let a0 = _mm512_loadu_ps(a.as_ptr().add(base));
222 let b0 = _mm512_loadu_ps(b.as_ptr().add(base));
223 sum0 = _mm512_fmadd_ps(a0, b0, sum0);
224
225 let a1 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
226 let b1 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
227 sum1 = _mm512_fmadd_ps(a1, b1, sum1);
228
229 let a2 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
230 let b2 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
231 sum2 = _mm512_fmadd_ps(a2, b2, sum2);
232
233 let a3 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
234 let b3 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
235 sum3 = _mm512_fmadd_ps(a3, b3, sum3);
236 }
237
238 let sum01 = _mm512_add_ps(sum0, sum1);
240 let sum23 = _mm512_add_ps(sum2, sum3);
241 let sum_vec = _mm512_add_ps(sum01, sum23);
242
243 let main_sum = horizontal_sum_avx512(sum_vec);
244
245 let main_processed = chunks * CHUNK_SIZE;
247 let remaining = n - main_processed;
248 let remaining_chunks = remaining / SIMD_WIDTH;
249
250 let mut remainder_sum = _mm512_setzero_ps();
251 for i in 0..remaining_chunks {
252 let offset = main_processed + i * SIMD_WIDTH;
253 let a_vec = _mm512_loadu_ps(a.as_ptr().add(offset));
254 let b_vec = _mm512_loadu_ps(b.as_ptr().add(offset));
255 remainder_sum = _mm512_fmadd_ps(a_vec, b_vec, remainder_sum);
256 }
257
258 let mut total = main_sum + horizontal_sum_avx512(remainder_sum);
259
260 let scalar_start = main_processed + remaining_chunks * SIMD_WIDTH;
262 for i in scalar_start..n {
263 total += a[i] * b[i];
264 }
265
266 total
267}
268
269#[cfg(target_arch = "x86_64")]
275#[target_feature(enable = "avx512f")]
276#[inline]
277pub(crate) unsafe fn horizontal_sum_avx512(v: __m512) -> f32 {
278 _mm512_reduce_add_ps(v)
279}
280
281#[cfg(target_arch = "x86_64")]
300#[target_feature(enable = "avx2", enable = "fma")]
301unsafe fn dot_product_avx2_8acc(a: &[f32], b: &[f32]) -> f32 {
302 const SIMD_WIDTH: usize = 8;
303 const UNROLL: usize = 8;
304 const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL; let n = a.len();
306 debug_assert_eq!(n, b.len());
307 let chunks = n / CHUNK_SIZE;
308
309 let mut sum0 = _mm256_setzero_ps();
311 let mut sum1 = _mm256_setzero_ps();
312 let mut sum2 = _mm256_setzero_ps();
313 let mut sum3 = _mm256_setzero_ps();
314 let mut sum4 = _mm256_setzero_ps();
315 let mut sum5 = _mm256_setzero_ps();
316 let mut sum6 = _mm256_setzero_ps();
317 let mut sum7 = _mm256_setzero_ps();
318
319 for i in 0..chunks {
320 let base = i * CHUNK_SIZE;
321
322 let a0 = _mm256_loadu_ps(a.as_ptr().add(base));
323 let b0 = _mm256_loadu_ps(b.as_ptr().add(base));
324 sum0 = _mm256_fmadd_ps(a0, b0, sum0);
325
326 let a1 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
327 let b1 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
328 sum1 = _mm256_fmadd_ps(a1, b1, sum1);
329
330 let a2 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
331 let b2 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
332 sum2 = _mm256_fmadd_ps(a2, b2, sum2);
333
334 let a3 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
335 let b3 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
336 sum3 = _mm256_fmadd_ps(a3, b3, sum3);
337
338 let a4 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 4));
339 let b4 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 4));
340 sum4 = _mm256_fmadd_ps(a4, b4, sum4);
341
342 let a5 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 5));
343 let b5 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 5));
344 sum5 = _mm256_fmadd_ps(a5, b5, sum5);
345
346 let a6 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 6));
347 let b6 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 6));
348 sum6 = _mm256_fmadd_ps(a6, b6, sum6);
349
350 let a7 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 7));
351 let b7 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 7));
352 sum7 = _mm256_fmadd_ps(a7, b7, sum7);
353 }
354
355 let sum01 = _mm256_add_ps(sum0, sum1);
357 let sum23 = _mm256_add_ps(sum2, sum3);
358 let sum45 = _mm256_add_ps(sum4, sum5);
359 let sum67 = _mm256_add_ps(sum6, sum7);
360 let sum0123 = _mm256_add_ps(sum01, sum23);
361 let sum4567 = _mm256_add_ps(sum45, sum67);
362 let sum_vec = _mm256_add_ps(sum0123, sum4567);
363
364 let sum = horizontal_sum_avx2(sum_vec);
365
366 let main_processed = chunks * CHUNK_SIZE;
368 let remaining = n - main_processed;
369 let remaining_chunks = remaining / SIMD_WIDTH;
370
371 let mut remainder_sum = _mm256_setzero_ps();
372 for i in 0..remaining_chunks {
373 let offset = main_processed + i * SIMD_WIDTH;
374 let a_vec = _mm256_loadu_ps(a.as_ptr().add(offset));
375 let b_vec = _mm256_loadu_ps(b.as_ptr().add(offset));
376 remainder_sum = _mm256_fmadd_ps(a_vec, b_vec, remainder_sum);
377 }
378
379 let mut total = sum + horizontal_sum_avx2(remainder_sum);
380
381 let scalar_start = main_processed + remaining_chunks * SIMD_WIDTH;
383 for i in scalar_start..n {
384 total += a[i] * b[i];
385 }
386
387 total
388}
389
390#[cfg(target_arch = "x86_64")]
406#[target_feature(enable = "avx2", enable = "fma")]
407unsafe fn dot_product_384_avx2(a: &[f32], b: &[f32]) -> f32 {
408 const SIMD_WIDTH: usize = 8;
409 const UNROLL: usize = 8;
411 const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL; const CHUNKS: usize = 384 / CHUNK_SIZE; const TAIL_ITERS: usize = (384 - CHUNKS * CHUNK_SIZE) / SIMD_WIDTH; debug_assert_eq!(a.len(), 384);
416 debug_assert_eq!(b.len(), 384);
417 debug_assert_eq!(CHUNKS * CHUNK_SIZE + TAIL_ITERS * SIMD_WIDTH, 384);
418
419 let mut sum0 = _mm256_setzero_ps();
421 let mut sum1 = _mm256_setzero_ps();
422 let mut sum2 = _mm256_setzero_ps();
423 let mut sum3 = _mm256_setzero_ps();
424 let mut sum4 = _mm256_setzero_ps();
425 let mut sum5 = _mm256_setzero_ps();
426 let mut sum6 = _mm256_setzero_ps();
427 let mut sum7 = _mm256_setzero_ps();
428
429 for i in 0..CHUNKS {
431 let base = i * CHUNK_SIZE;
432
433 let a0 = _mm256_loadu_ps(a.as_ptr().add(base));
434 let b0 = _mm256_loadu_ps(b.as_ptr().add(base));
435 sum0 = _mm256_fmadd_ps(a0, b0, sum0);
436
437 let a1 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
438 let b1 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
439 sum1 = _mm256_fmadd_ps(a1, b1, sum1);
440
441 let a2 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
442 let b2 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
443 sum2 = _mm256_fmadd_ps(a2, b2, sum2);
444
445 let a3 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
446 let b3 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
447 sum3 = _mm256_fmadd_ps(a3, b3, sum3);
448
449 let a4 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 4));
450 let b4 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 4));
451 sum4 = _mm256_fmadd_ps(a4, b4, sum4);
452
453 let a5 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 5));
454 let b5 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 5));
455 sum5 = _mm256_fmadd_ps(a5, b5, sum5);
456
457 let a6 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 6));
458 let b6 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 6));
459 sum6 = _mm256_fmadd_ps(a6, b6, sum6);
460
461 let a7 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 7));
462 let b7 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 7));
463 sum7 = _mm256_fmadd_ps(a7, b7, sum7);
464 }
465
466 let sum01 = _mm256_add_ps(sum0, sum1);
468 let sum23 = _mm256_add_ps(sum2, sum3);
469 let sum45 = _mm256_add_ps(sum4, sum5);
470 let sum67 = _mm256_add_ps(sum6, sum7);
471 let sum0123 = _mm256_add_ps(sum01, sum23);
472 let sum4567 = _mm256_add_ps(sum45, sum67);
473 let sum_vec = _mm256_add_ps(sum0123, sum4567);
474
475 horizontal_sum_avx2(sum_vec)
476}
477
/// Adds the eight `f32` lanes of `v` into a single scalar.
///
/// The reduction halves the width three times: 8 → 4 (upper 128 bits onto lower),
/// 4 → 2 (odd lanes onto even lanes), 2 → 1 (upper pair onto lower pair).
///
/// # Safety
///
/// The executing CPU must support `avx2`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
pub(crate) unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    // 8 → 4: fold the upper 128-bit half onto the lower one.
    let quad = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));

    // 4 → 2: lanes 1 and 3 are duplicated downwards and added to lanes 0 and 2.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));

    // 2 → 1: bring lane 2 down to lane 0 and add; only lane 0 matters from here on.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));

    _mm_cvtss_f32(total)
}
500
501#[cfg(target_arch = "x86_64")]
507#[inline]
508fn dot_product_batch4_avx2_kernel(
509 q: &[f32],
510 c0: &[f32],
511 c1: &[f32],
512 c2: &[f32],
513 c3: &[f32],
514) -> [f32; 4] {
515 if q.len() == 384 {
516 unsafe { dot_product_384_batch4_avx2(q, c0, c1, c2, c3) }
517 } else {
518 unsafe { dot_product_batch4_avx2(q, c0, c1, c2, c3) }
519 }
520}
521
522#[cfg(target_arch = "x86_64")]
533#[target_feature(enable = "avx2", enable = "fma")]
534unsafe fn dot_product_384_batch4_avx2(
535 q: &[f32],
536 c0: &[f32],
537 c1: &[f32],
538 c2: &[f32],
539 c3: &[f32],
540) -> [f32; 4] {
541 const W: usize = 8; const CHUNK: usize = W * 2; const CHUNKS: usize = 384 / CHUNK; debug_assert_eq!(q.len(), 384);
546
547 let mut acc00 = _mm256_setzero_ps();
548 let mut acc01 = _mm256_setzero_ps();
549 let mut acc10 = _mm256_setzero_ps();
550 let mut acc11 = _mm256_setzero_ps();
551 let mut acc20 = _mm256_setzero_ps();
552 let mut acc21 = _mm256_setzero_ps();
553 let mut acc30 = _mm256_setzero_ps();
554 let mut acc31 = _mm256_setzero_ps();
555
556 for i in 0..CHUNKS {
557 let base = i * CHUNK;
558 let q0 = _mm256_loadu_ps(q.as_ptr().add(base));
559 let q1 = _mm256_loadu_ps(q.as_ptr().add(base + W));
560
561 acc00 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c0.as_ptr().add(base)), acc00);
562 acc01 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c0.as_ptr().add(base + W)), acc01);
563 acc10 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c1.as_ptr().add(base)), acc10);
564 acc11 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c1.as_ptr().add(base + W)), acc11);
565 acc20 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c2.as_ptr().add(base)), acc20);
566 acc21 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c2.as_ptr().add(base + W)), acc21);
567 acc30 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c3.as_ptr().add(base)), acc30);
568 acc31 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c3.as_ptr().add(base + W)), acc31);
569 }
570
571 [
572 horizontal_sum_avx2(_mm256_add_ps(acc00, acc01)),
573 horizontal_sum_avx2(_mm256_add_ps(acc10, acc11)),
574 horizontal_sum_avx2(_mm256_add_ps(acc20, acc21)),
575 horizontal_sum_avx2(_mm256_add_ps(acc30, acc31)),
576 ]
577}
578
579#[cfg(target_arch = "x86_64")]
594#[target_feature(enable = "avx2", enable = "fma")]
595unsafe fn dot_product_batch4_avx2(
596 q: &[f32],
597 c0: &[f32],
598 c1: &[f32],
599 c2: &[f32],
600 c3: &[f32],
601) -> [f32; 4] {
602 const W: usize = 8;
603 const CHUNK: usize = W * 2; let n = q.len();
606 let chunks = n / CHUNK;
607
608 let mut acc00 = _mm256_setzero_ps();
609 let mut acc01 = _mm256_setzero_ps();
610 let mut acc10 = _mm256_setzero_ps();
611 let mut acc11 = _mm256_setzero_ps();
612 let mut acc20 = _mm256_setzero_ps();
613 let mut acc21 = _mm256_setzero_ps();
614 let mut acc30 = _mm256_setzero_ps();
615 let mut acc31 = _mm256_setzero_ps();
616
617 for i in 0..chunks {
618 let base = i * CHUNK;
619 let q0 = _mm256_loadu_ps(q.as_ptr().add(base));
620 let q1 = _mm256_loadu_ps(q.as_ptr().add(base + W));
621
622 acc00 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c0.as_ptr().add(base)), acc00);
623 acc01 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c0.as_ptr().add(base + W)), acc01);
624 acc10 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c1.as_ptr().add(base)), acc10);
625 acc11 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c1.as_ptr().add(base + W)), acc11);
626 acc20 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c2.as_ptr().add(base)), acc20);
627 acc21 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c2.as_ptr().add(base + W)), acc21);
628 acc30 = _mm256_fmadd_ps(q0, _mm256_loadu_ps(c3.as_ptr().add(base)), acc30);
629 acc31 = _mm256_fmadd_ps(q1, _mm256_loadu_ps(c3.as_ptr().add(base + W)), acc31);
630 }
631
632 let mut out = [
633 horizontal_sum_avx2(_mm256_add_ps(acc00, acc01)),
634 horizontal_sum_avx2(_mm256_add_ps(acc10, acc11)),
635 horizontal_sum_avx2(_mm256_add_ps(acc20, acc21)),
636 horizontal_sum_avx2(_mm256_add_ps(acc30, acc31)),
637 ];
638
639 let scalar_start = chunks * CHUNK;
640 for i in scalar_start..n {
641 let qi = q[i];
642 out[0] += qi * c0[i];
643 out[1] += qi * c1[i];
644 out[2] += qi * c2[i];
645 out[3] += qi * c3[i];
646 }
647
648 out
649}
650
/// Dot product of `a` and `b` using NEON fused multiply-adds.
///
/// The input is consumed in three phases:
/// 1. 16 floats per iteration through four independent 4-lane accumulators,
/// 2. any remaining full 4-float vectors through a separate accumulator,
/// 3. the final `< 4` elements one at a time.
///
/// # Safety
///
/// `b` must hold at least `a.len()` elements: the vector loads go through raw
/// pointers without bounds checks, and the length equality is only verified by a
/// `debug_assert`.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_neon_unrolled(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 4;
    const STRIDE: usize = LANES * 4;

    let len = a.len();
    debug_assert_eq!(len, b.len());

    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // acc += a[at..at + LANES] * b[at..at + LANES], fused.
    macro_rules! fma_at {
        ($acc:ident, $at:expr) => {
            $acc = vfmaq_f32($acc, vld1q_f32(pa.add($at)), vld1q_f32(pb.add($at)))
        };
    }

    let mut acc0 = vdupq_n_f32(0.0);
    let mut acc1 = vdupq_n_f32(0.0);
    let mut acc2 = vdupq_n_f32(0.0);
    let mut acc3 = vdupq_n_f32(0.0);

    // Phase 1: whole 16-float strides.
    let unrolled_end = len - len % STRIDE;
    let mut pos = 0;
    while pos < unrolled_end {
        fma_at!(acc0, pos);
        fma_at!(acc1, pos + LANES);
        fma_at!(acc2, pos + LANES * 2);
        fma_at!(acc3, pos + LANES * 3);
        pos += STRIDE;
    }

    // Pairwise reduction of the four accumulators, then across lanes.
    let mut total = horizontal_sum_neon(vaddq_f32(
        vaddq_f32(acc0, acc1),
        vaddq_f32(acc2, acc3),
    ));

    // Phase 2: leftover full vectors (at most three of them).
    let vector_end = len - len % LANES;
    let mut tail_acc = vdupq_n_f32(0.0);
    while pos < vector_end {
        fma_at!(tail_acc, pos);
        pos += LANES;
    }
    total += horizontal_sum_neon(tail_acc);

    // Phase 3: scalar remainder.
    for (x, y) in a[pos..].iter().zip(&b[pos..]) {
        total += x * y;
    }

    total
}
733
/// Adds the four `f32` lanes of `v` into a single scalar.
///
/// Thin wrapper over `vaddvq_f32`, kept as a named helper so the NEON kernel reads
/// the same way as the x86_64 kernels.
///
/// # Safety
///
/// Declared `unsafe` like the other SIMD helpers in this file; the body performs no
/// memory access and only calls the `vaddvq_f32` intrinsic.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) unsafe fn horizontal_sum_neon(v: float32x4_t) -> f32 {
    vaddvq_f32(v)
}
744
/// Safe [`DotBatch4Kernel`]-compatible entry point for the NEON batch kernel.
///
/// Because this function is safe and is handed out through the public
/// `resolved_dot_product_batch4_kernel` accessor, it validates the slice lengths
/// itself: if any candidate's length differs from the query's, it trips a debug
/// assertion and returns `[0.0; 4]`, matching `dot_product_batch4`, rather than
/// letting the vector loads read past the end of a shorter slice.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_product_batch4_neon_kernel(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    let dim = q.len();
    if c0.len() != dim || c1.len() != dim || c2.len() != dim || c3.len() != dim {
        debug_assert!(
            false,
            "dot_product_batch4_neon_kernel: dimension mismatch (q={}, c0={}, c1={}, c2={}, c3={})",
            dim,
            c0.len(),
            c1.len(),
            c2.len(),
            c3.len()
        );
        return [0.0; 4];
    }
    // SAFETY: all five slices have the same length (checked above), so every
    // unchecked load in the kernel stays in bounds. This function is only selected
    // by `resolve_dot_product_batch4_kernel` when `neon_enabled` is set.
    unsafe { dot_product_batch4_neon(q, c0, c1, c2, c3) }
}
762
/// Batch-of-four dot product for arbitrary-length vectors (NEON).
///
/// Each vector iteration loads 8 query floats once and reuses them against all four
/// candidates, keeping two 4-lane accumulators per candidate. The final `< 8`
/// elements are handled by a scalar loop.
///
/// Returns `[q·c0, q·c1, q·c2, q·c3]`.
///
/// # Safety
///
/// `c0`, `c1`, `c2` and `c3` must each hold at least `q.len()` elements: the vector
/// loads go through raw pointers without bounds checks, and this function does not
/// verify the lengths.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_batch4_neon(
    q: &[f32],
    c0: &[f32],
    c1: &[f32],
    c2: &[f32],
    c3: &[f32],
) -> [f32; 4] {
    const LANES: usize = 4;
    const STRIDE: usize = LANES * 2;

    let len = q.len();
    let vector_end = len - len % STRIDE;
    let qp = q.as_ptr();

    // Fold one candidate's 8 floats at `$at` into its (low, high) accumulator pair.
    macro_rules! fma_candidate {
        ($lo:ident, $hi:ident, $cand:ident, $q_lo:ident, $q_hi:ident, $at:ident) => {
            let cp = $cand.as_ptr();
            $lo = vfmaq_f32($lo, $q_lo, vld1q_f32(cp.add($at)));
            $hi = vfmaq_f32($hi, $q_hi, vld1q_f32(cp.add($at + LANES)));
        };
    }

    let mut c0_lo = vdupq_n_f32(0.0);
    let mut c0_hi = vdupq_n_f32(0.0);
    let mut c1_lo = vdupq_n_f32(0.0);
    let mut c1_hi = vdupq_n_f32(0.0);
    let mut c2_lo = vdupq_n_f32(0.0);
    let mut c2_hi = vdupq_n_f32(0.0);
    let mut c3_lo = vdupq_n_f32(0.0);
    let mut c3_hi = vdupq_n_f32(0.0);

    let mut pos = 0;
    while pos < vector_end {
        let q_lo = vld1q_f32(qp.add(pos));
        let q_hi = vld1q_f32(qp.add(pos + LANES));

        fma_candidate!(c0_lo, c0_hi, c0, q_lo, q_hi, pos);
        fma_candidate!(c1_lo, c1_hi, c1, q_lo, q_hi, pos);
        fma_candidate!(c2_lo, c2_hi, c2, q_lo, q_hi, pos);
        fma_candidate!(c3_lo, c3_hi, c3, q_lo, q_hi, pos);

        pos += STRIDE;
    }

    let mut totals = [
        vaddvq_f32(vaddq_f32(c0_lo, c0_hi)),
        vaddvq_f32(vaddq_f32(c1_lo, c1_hi)),
        vaddvq_f32(vaddq_f32(c2_lo, c2_hi)),
        vaddvq_f32(vaddq_f32(c3_lo, c3_hi)),
    ];

    // Scalar remainder: everything the 8-float stride did not cover.
    for idx in vector_end..len {
        let qv = q[idx];
        totals[0] += qv * c0[idx];
        totals[1] += qv * c1[idx];
        totals[2] += qv * c2[idx];
        totals[3] += qv * c3[idx];
    }

    totals
}
835
/// Reports whether a group of four `(query, candidate)` pairs can go through the
/// batch-of-four kernel.
///
/// That is the case when every pair's query is the very same slice (same start
/// pointer and same length — contents are not compared) and every candidate has
/// that length too. `chunk` is expected to hold exactly four pairs.
#[inline]
fn same_query_batch4(chunk: &[(&[f32], &[f32])]) -> bool {
    debug_assert_eq!(chunk.len(), 4);

    let (lead_query, lead_candidate) = chunk[0];
    if lead_candidate.len() != lead_query.len() {
        return false;
    }

    let anchor = lead_query.as_ptr();
    let dim = lead_query.len();
    !chunk
        .iter()
        .any(|&(query, cand)| query.as_ptr() != anchor || query.len() != dim || cand.len() != dim)
}
847
848pub fn batch_dot_product(pairs: &[(&[f32], &[f32])]) -> Vec<f32> {
854 let pair_kernel = resolved_dot_product_kernel();
855 let batch4_kernel = resolved_dot_product_batch4_kernel();
856 let mut out = Vec::with_capacity(pairs.len());
857
858 let mut chunks = pairs.chunks_exact(4);
859 for chunk in &mut chunks {
860 if same_query_batch4(chunk) {
861 let q = chunk[0].0;
862 let dots = batch4_kernel(q, chunk[0].1, chunk[1].1, chunk[2].1, chunk[3].1);
863 out.extend_from_slice(&dots);
864 } else {
865 for &(a, b) in chunk {
866 out.push(if a.len() == b.len() {
867 pair_kernel(a, b)
868 } else {
869 0.0
870 });
871 }
872 }
873 }
874 for &(a, b) in chunks.remainder() {
875 out.push(if a.len() == b.len() {
876 pair_kernel(a, b)
877 } else {
878 0.0
879 });
880 }
881 out
882}