1#[cfg(target_arch = "aarch64")]
18#[allow(unsafe_op_in_unsafe_fn)]
19mod neon {
20 use std::arch::aarch64::*;
21
22 #[target_feature(enable = "neon")]
24 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
25 let chunks = count / 16;
26 let remainder = count % 16;
27
28 for chunk in 0..chunks {
29 let base = chunk * 16;
30 let in_ptr = input.as_ptr().add(base);
31
32 let bytes = vld1q_u8(in_ptr);
34
35 let low8 = vget_low_u8(bytes);
37 let high8 = vget_high_u8(bytes);
38
39 let low16 = vmovl_u8(low8);
40 let high16 = vmovl_u8(high8);
41
42 let v0 = vmovl_u16(vget_low_u16(low16));
43 let v1 = vmovl_u16(vget_high_u16(low16));
44 let v2 = vmovl_u16(vget_low_u16(high16));
45 let v3 = vmovl_u16(vget_high_u16(high16));
46
47 let out_ptr = output.as_mut_ptr().add(base);
48 vst1q_u32(out_ptr, v0);
49 vst1q_u32(out_ptr.add(4), v1);
50 vst1q_u32(out_ptr.add(8), v2);
51 vst1q_u32(out_ptr.add(12), v3);
52 }
53
54 let base = chunks * 16;
56 for i in 0..remainder {
57 output[base + i] = input[base + i] as u32;
58 }
59 }
60
61 #[target_feature(enable = "neon")]
63 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
64 let chunks = count / 8;
65 let remainder = count % 8;
66
67 for chunk in 0..chunks {
68 let base = chunk * 8;
69 let in_ptr = input.as_ptr().add(base * 2) as *const u16;
70
71 let vals = vld1q_u16(in_ptr);
72 let low = vmovl_u16(vget_low_u16(vals));
73 let high = vmovl_u16(vget_high_u16(vals));
74
75 let out_ptr = output.as_mut_ptr().add(base);
76 vst1q_u32(out_ptr, low);
77 vst1q_u32(out_ptr.add(4), high);
78 }
79
80 let base = chunks * 8;
82 for i in 0..remainder {
83 let idx = (base + i) * 2;
84 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
85 }
86 }
87
88 #[target_feature(enable = "neon")]
90 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
91 let chunks = count / 4;
92 let remainder = count % 4;
93
94 let in_ptr = input.as_ptr() as *const u32;
95 let out_ptr = output.as_mut_ptr();
96
97 for chunk in 0..chunks {
98 let vals = vld1q_u32(in_ptr.add(chunk * 4));
99 vst1q_u32(out_ptr.add(chunk * 4), vals);
100 }
101
102 let base = chunks * 4;
104 for i in 0..remainder {
105 let idx = (base + i) * 4;
106 output[base + i] =
107 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
108 }
109 }
110
111 #[inline]
115 #[target_feature(enable = "neon")]
116 unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
117 let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
120 let sum1 = vaddq_u32(v, shifted1);
121
122 let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
125 vaddq_u32(sum1, shifted2)
126 }
127
128 #[target_feature(enable = "neon")]
132 pub unsafe fn delta_decode(
133 output: &mut [u32],
134 deltas: &[u32],
135 first_doc_id: u32,
136 count: usize,
137 ) {
138 if count == 0 {
139 return;
140 }
141
142 output[0] = first_doc_id;
143 if count == 1 {
144 return;
145 }
146
147 let ones = vdupq_n_u32(1);
148 let mut carry = vdupq_n_u32(first_doc_id);
149
150 let full_groups = (count - 1) / 4;
151 let remainder = (count - 1) % 4;
152
153 for group in 0..full_groups {
154 let base = group * 4;
155
156 let d = vld1q_u32(deltas[base..].as_ptr());
158 let gaps = vaddq_u32(d, ones);
159
160 let prefix = prefix_sum_4(gaps);
162
163 let result = vaddq_u32(prefix, carry);
165
166 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
168
169 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
171 }
172
173 let base = full_groups * 4;
175 let mut scalar_carry = vgetq_lane_u32(carry, 0);
176 for j in 0..remainder {
177 scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
178 output[base + j + 1] = scalar_carry;
179 }
180 }
181
182 #[target_feature(enable = "neon")]
184 pub unsafe fn add_one(values: &mut [u32], count: usize) {
185 let ones = vdupq_n_u32(1);
186 let chunks = count / 4;
187 let remainder = count % 4;
188
189 for chunk in 0..chunks {
190 let base = chunk * 4;
191 let ptr = values.as_mut_ptr().add(base);
192 let v = vld1q_u32(ptr);
193 let result = vaddq_u32(v, ones);
194 vst1q_u32(ptr, result);
195 }
196
197 let base = chunks * 4;
198 for i in 0..remainder {
199 values[base + i] += 1;
200 }
201 }
202
203 #[target_feature(enable = "neon")]
206 pub unsafe fn unpack_8bit_delta_decode(
207 input: &[u8],
208 output: &mut [u32],
209 first_value: u32,
210 count: usize,
211 ) {
212 output[0] = first_value;
213 if count <= 1 {
214 return;
215 }
216
217 let ones = vdupq_n_u32(1);
218 let mut carry = vdupq_n_u32(first_value);
219
220 let full_groups = (count - 1) / 4;
221 let remainder = (count - 1) % 4;
222
223 for group in 0..full_groups {
224 let base = group * 4;
225
226 let b0 = input[base] as u32;
228 let b1 = input[base + 1] as u32;
229 let b2 = input[base + 2] as u32;
230 let b3 = input[base + 3] as u32;
231 let deltas = [b0, b1, b2, b3];
232 let d = vld1q_u32(deltas.as_ptr());
233
234 let gaps = vaddq_u32(d, ones);
236
237 let prefix = prefix_sum_4(gaps);
239
240 let result = vaddq_u32(prefix, carry);
242
243 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
245
246 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
248 }
249
250 let base = full_groups * 4;
252 let mut scalar_carry = vgetq_lane_u32(carry, 0);
253 for j in 0..remainder {
254 scalar_carry = scalar_carry
255 .wrapping_add(input[base + j] as u32)
256 .wrapping_add(1);
257 output[base + j + 1] = scalar_carry;
258 }
259 }
260
261 #[target_feature(enable = "neon")]
263 pub unsafe fn unpack_16bit_delta_decode(
264 input: &[u8],
265 output: &mut [u32],
266 first_value: u32,
267 count: usize,
268 ) {
269 output[0] = first_value;
270 if count <= 1 {
271 return;
272 }
273
274 let ones = vdupq_n_u32(1);
275 let mut carry = vdupq_n_u32(first_value);
276
277 let full_groups = (count - 1) / 4;
278 let remainder = (count - 1) % 4;
279
280 for group in 0..full_groups {
281 let base = group * 4;
282 let in_ptr = input.as_ptr().add(base * 2) as *const u16;
283
284 let vals = vld1_u16(in_ptr);
286 let d = vmovl_u16(vals);
287
288 let gaps = vaddq_u32(d, ones);
290
291 let prefix = prefix_sum_4(gaps);
293
294 let result = vaddq_u32(prefix, carry);
296
297 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
299
300 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
302 }
303
304 let base = full_groups * 4;
306 let mut scalar_carry = vgetq_lane_u32(carry, 0);
307 for j in 0..remainder {
308 let idx = (base + j) * 2;
309 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
310 scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
311 output[base + j + 1] = scalar_carry;
312 }
313 }
314
315 #[inline]
317 pub fn is_available() -> bool {
318 true
319 }
320}
321
322#[cfg(target_arch = "x86_64")]
327#[allow(unsafe_op_in_unsafe_fn)]
328mod sse {
329 use std::arch::x86_64::*;
330
331 #[target_feature(enable = "sse2", enable = "sse4.1")]
333 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
334 let chunks = count / 16;
335 let remainder = count % 16;
336
337 for chunk in 0..chunks {
338 let base = chunk * 16;
339 let in_ptr = input.as_ptr().add(base);
340
341 let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
342
343 let v0 = _mm_cvtepu8_epi32(bytes);
345 let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
346 let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
347 let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
348
349 let out_ptr = output.as_mut_ptr().add(base);
350 _mm_storeu_si128(out_ptr as *mut __m128i, v0);
351 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
352 _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
353 _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
354 }
355
356 let base = chunks * 16;
357 for i in 0..remainder {
358 output[base + i] = input[base + i] as u32;
359 }
360 }
361
362 #[target_feature(enable = "sse2", enable = "sse4.1")]
364 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
365 let chunks = count / 8;
366 let remainder = count % 8;
367
368 for chunk in 0..chunks {
369 let base = chunk * 8;
370 let in_ptr = input.as_ptr().add(base * 2);
371
372 let vals = _mm_loadu_si128(in_ptr as *const __m128i);
373 let low = _mm_cvtepu16_epi32(vals);
374 let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
375
376 let out_ptr = output.as_mut_ptr().add(base);
377 _mm_storeu_si128(out_ptr as *mut __m128i, low);
378 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
379 }
380
381 let base = chunks * 8;
382 for i in 0..remainder {
383 let idx = (base + i) * 2;
384 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
385 }
386 }
387
388 #[target_feature(enable = "sse2")]
390 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
391 let chunks = count / 4;
392 let remainder = count % 4;
393
394 let in_ptr = input.as_ptr() as *const __m128i;
395 let out_ptr = output.as_mut_ptr() as *mut __m128i;
396
397 for chunk in 0..chunks {
398 let vals = _mm_loadu_si128(in_ptr.add(chunk));
399 _mm_storeu_si128(out_ptr.add(chunk), vals);
400 }
401
402 let base = chunks * 4;
404 for i in 0..remainder {
405 let idx = (base + i) * 4;
406 output[base + i] =
407 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
408 }
409 }
410
411 #[inline]
415 #[target_feature(enable = "sse2")]
416 unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
417 let shifted1 = _mm_slli_si128(v, 4);
420 let sum1 = _mm_add_epi32(v, shifted1);
421
422 let shifted2 = _mm_slli_si128(sum1, 8);
425 _mm_add_epi32(sum1, shifted2)
426 }
427
428 #[target_feature(enable = "sse2", enable = "sse4.1")]
430 pub unsafe fn delta_decode(
431 output: &mut [u32],
432 deltas: &[u32],
433 first_doc_id: u32,
434 count: usize,
435 ) {
436 if count == 0 {
437 return;
438 }
439
440 output[0] = first_doc_id;
441 if count == 1 {
442 return;
443 }
444
445 let ones = _mm_set1_epi32(1);
446 let mut carry = _mm_set1_epi32(first_doc_id as i32);
447
448 let full_groups = (count - 1) / 4;
449 let remainder = (count - 1) % 4;
450
451 for group in 0..full_groups {
452 let base = group * 4;
453
454 let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
456 let gaps = _mm_add_epi32(d, ones);
457
458 let prefix = prefix_sum_4(gaps);
460
461 let result = _mm_add_epi32(prefix, carry);
463
464 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
466
467 carry = _mm_shuffle_epi32(result, 0xFF); }
470
471 let base = full_groups * 4;
473 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
474 for j in 0..remainder {
475 scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
476 output[base + j + 1] = scalar_carry;
477 }
478 }
479
480 #[target_feature(enable = "sse2")]
482 pub unsafe fn add_one(values: &mut [u32], count: usize) {
483 let ones = _mm_set1_epi32(1);
484 let chunks = count / 4;
485 let remainder = count % 4;
486
487 for chunk in 0..chunks {
488 let base = chunk * 4;
489 let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
490 let v = _mm_loadu_si128(ptr);
491 let result = _mm_add_epi32(v, ones);
492 _mm_storeu_si128(ptr, result);
493 }
494
495 let base = chunks * 4;
496 for i in 0..remainder {
497 values[base + i] += 1;
498 }
499 }
500
501 #[target_feature(enable = "sse2", enable = "sse4.1")]
503 pub unsafe fn unpack_8bit_delta_decode(
504 input: &[u8],
505 output: &mut [u32],
506 first_value: u32,
507 count: usize,
508 ) {
509 output[0] = first_value;
510 if count <= 1 {
511 return;
512 }
513
514 let ones = _mm_set1_epi32(1);
515 let mut carry = _mm_set1_epi32(first_value as i32);
516
517 let full_groups = (count - 1) / 4;
518 let remainder = (count - 1) % 4;
519
520 for group in 0..full_groups {
521 let base = group * 4;
522
523 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
525 input.as_ptr().add(base) as *const i32
526 ));
527 let d = _mm_cvtepu8_epi32(bytes);
528
529 let gaps = _mm_add_epi32(d, ones);
531
532 let prefix = prefix_sum_4(gaps);
534
535 let result = _mm_add_epi32(prefix, carry);
537
538 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
540
541 carry = _mm_shuffle_epi32(result, 0xFF);
543 }
544
545 let base = full_groups * 4;
547 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
548 for j in 0..remainder {
549 scalar_carry = scalar_carry
550 .wrapping_add(input[base + j] as u32)
551 .wrapping_add(1);
552 output[base + j + 1] = scalar_carry;
553 }
554 }
555
556 #[target_feature(enable = "sse2", enable = "sse4.1")]
558 pub unsafe fn unpack_16bit_delta_decode(
559 input: &[u8],
560 output: &mut [u32],
561 first_value: u32,
562 count: usize,
563 ) {
564 output[0] = first_value;
565 if count <= 1 {
566 return;
567 }
568
569 let ones = _mm_set1_epi32(1);
570 let mut carry = _mm_set1_epi32(first_value as i32);
571
572 let full_groups = (count - 1) / 4;
573 let remainder = (count - 1) % 4;
574
575 for group in 0..full_groups {
576 let base = group * 4;
577 let in_ptr = input.as_ptr().add(base * 2);
578
579 let vals = _mm_loadl_epi64(in_ptr as *const __m128i); let d = _mm_cvtepu16_epi32(vals);
582
583 let gaps = _mm_add_epi32(d, ones);
585
586 let prefix = prefix_sum_4(gaps);
588
589 let result = _mm_add_epi32(prefix, carry);
591
592 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
594
595 carry = _mm_shuffle_epi32(result, 0xFF);
597 }
598
599 let base = full_groups * 4;
601 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
602 for j in 0..remainder {
603 let idx = (base + j) * 2;
604 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
605 scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
606 output[base + j + 1] = scalar_carry;
607 }
608 }
609
610 #[inline]
612 pub fn is_available() -> bool {
613 is_x86_feature_detected!("sse4.1")
614 }
615}
616
617#[cfg(target_arch = "x86_64")]
622#[allow(unsafe_op_in_unsafe_fn)]
623mod avx2 {
624 use std::arch::x86_64::*;
625
626 #[target_feature(enable = "avx2")]
628 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
629 let chunks = count / 32;
630 let remainder = count % 32;
631
632 for chunk in 0..chunks {
633 let base = chunk * 32;
634 let in_ptr = input.as_ptr().add(base);
635
636 let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
638 let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
639
640 let v0 = _mm256_cvtepu8_epi32(bytes_lo);
642 let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
643 let v2 = _mm256_cvtepu8_epi32(bytes_hi);
644 let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));
645
646 let out_ptr = output.as_mut_ptr().add(base);
647 _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
648 _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
649 _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
650 _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
651 }
652
653 let base = chunks * 32;
655 for i in 0..remainder {
656 output[base + i] = input[base + i] as u32;
657 }
658 }
659
660 #[target_feature(enable = "avx2")]
662 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
663 let chunks = count / 16;
664 let remainder = count % 16;
665
666 for chunk in 0..chunks {
667 let base = chunk * 16;
668 let in_ptr = input.as_ptr().add(base * 2);
669
670 let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
672 let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
673
674 let v0 = _mm256_cvtepu16_epi32(vals_lo);
676 let v1 = _mm256_cvtepu16_epi32(vals_hi);
677
678 let out_ptr = output.as_mut_ptr().add(base);
679 _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
680 _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
681 }
682
683 let base = chunks * 16;
685 for i in 0..remainder {
686 let idx = (base + i) * 2;
687 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
688 }
689 }
690
691 #[target_feature(enable = "avx2")]
693 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
694 let chunks = count / 8;
695 let remainder = count % 8;
696
697 let in_ptr = input.as_ptr() as *const __m256i;
698 let out_ptr = output.as_mut_ptr() as *mut __m256i;
699
700 for chunk in 0..chunks {
701 let vals = _mm256_loadu_si256(in_ptr.add(chunk));
702 _mm256_storeu_si256(out_ptr.add(chunk), vals);
703 }
704
705 let base = chunks * 8;
707 for i in 0..remainder {
708 let idx = (base + i) * 4;
709 output[base + i] =
710 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
711 }
712 }
713
714 #[target_feature(enable = "avx2")]
716 pub unsafe fn add_one(values: &mut [u32], count: usize) {
717 let ones = _mm256_set1_epi32(1);
718 let chunks = count / 8;
719 let remainder = count % 8;
720
721 for chunk in 0..chunks {
722 let base = chunk * 8;
723 let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
724 let v = _mm256_loadu_si256(ptr);
725 let result = _mm256_add_epi32(v, ones);
726 _mm256_storeu_si256(ptr, result);
727 }
728
729 let base = chunks * 8;
730 for i in 0..remainder {
731 values[base + i] += 1;
732 }
733 }
734
735 #[inline]
737 pub fn is_available() -> bool {
738 is_x86_feature_detected!("avx2")
739 }
740}
741
742#[allow(dead_code)]
747mod scalar {
748 #[inline]
750 pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
751 for i in 0..count {
752 output[i] = input[i] as u32;
753 }
754 }
755
756 #[inline]
758 pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
759 for (i, out) in output.iter_mut().enumerate().take(count) {
760 let idx = i * 2;
761 *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
762 }
763 }
764
765 #[inline]
767 pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
768 for (i, out) in output.iter_mut().enumerate().take(count) {
769 let idx = i * 4;
770 *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
771 }
772 }
773
774 #[inline]
776 pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
777 if count == 0 {
778 return;
779 }
780
781 output[0] = first_doc_id;
782 let mut carry = first_doc_id;
783
784 for i in 0..count - 1 {
785 carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
786 output[i + 1] = carry;
787 }
788 }
789
790 #[inline]
792 pub fn add_one(values: &mut [u32], count: usize) {
793 for val in values.iter_mut().take(count) {
794 *val += 1;
795 }
796 }
797}
798
799#[inline]
805pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
806 #[cfg(target_arch = "aarch64")]
807 {
808 if neon::is_available() {
809 unsafe {
810 neon::unpack_8bit(input, output, count);
811 }
812 return;
813 }
814 }
815
816 #[cfg(target_arch = "x86_64")]
817 {
818 if avx2::is_available() {
820 unsafe {
821 avx2::unpack_8bit(input, output, count);
822 }
823 return;
824 }
825 if sse::is_available() {
826 unsafe {
827 sse::unpack_8bit(input, output, count);
828 }
829 return;
830 }
831 }
832
833 scalar::unpack_8bit(input, output, count);
834}
835
836#[inline]
838pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
839 #[cfg(target_arch = "aarch64")]
840 {
841 if neon::is_available() {
842 unsafe {
843 neon::unpack_16bit(input, output, count);
844 }
845 return;
846 }
847 }
848
849 #[cfg(target_arch = "x86_64")]
850 {
851 if avx2::is_available() {
853 unsafe {
854 avx2::unpack_16bit(input, output, count);
855 }
856 return;
857 }
858 if sse::is_available() {
859 unsafe {
860 sse::unpack_16bit(input, output, count);
861 }
862 return;
863 }
864 }
865
866 scalar::unpack_16bit(input, output, count);
867}
868
869#[inline]
871pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
872 #[cfg(target_arch = "aarch64")]
873 {
874 if neon::is_available() {
875 unsafe {
876 neon::unpack_32bit(input, output, count);
877 }
878 }
879 }
880
881 #[cfg(target_arch = "x86_64")]
882 {
883 if avx2::is_available() {
885 unsafe {
886 avx2::unpack_32bit(input, output, count);
887 }
888 } else {
889 unsafe {
891 sse::unpack_32bit(input, output, count);
892 }
893 }
894 }
895
896 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
897 {
898 scalar::unpack_32bit(input, output, count);
899 }
900}
901
902#[inline]
908pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
909 #[cfg(target_arch = "aarch64")]
910 {
911 if neon::is_available() {
912 unsafe {
913 neon::delta_decode(output, deltas, first_value, count);
914 }
915 return;
916 }
917 }
918
919 #[cfg(target_arch = "x86_64")]
920 {
921 if sse::is_available() {
922 unsafe {
923 sse::delta_decode(output, deltas, first_value, count);
924 }
925 return;
926 }
927 }
928
929 scalar::delta_decode(output, deltas, first_value, count);
930}
931
932#[inline]
936pub fn add_one(values: &mut [u32], count: usize) {
937 #[cfg(target_arch = "aarch64")]
938 {
939 if neon::is_available() {
940 unsafe {
941 neon::add_one(values, count);
942 }
943 }
944 }
945
946 #[cfg(target_arch = "x86_64")]
947 {
948 if avx2::is_available() {
950 unsafe {
951 avx2::add_one(values, count);
952 }
953 } else {
954 unsafe {
956 sse::add_one(values, count);
957 }
958 }
959 }
960
961 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
962 {
963 scalar::add_one(values, count);
964 }
965}
966
967#[inline]
969pub fn bits_needed(val: u32) -> u8 {
970 if val == 0 {
971 0
972 } else {
973 32 - val.leading_zeros() as u8
974 }
975}
976
977#[derive(Debug, Clone, Copy, PartialEq, Eq)]
994#[repr(u8)]
995pub enum RoundedBitWidth {
996 Zero = 0,
997 Bits8 = 8,
998 Bits16 = 16,
999 Bits32 = 32,
1000}
1001
1002impl RoundedBitWidth {
1003 #[inline]
1005 pub fn from_exact(bits: u8) -> Self {
1006 match bits {
1007 0 => RoundedBitWidth::Zero,
1008 1..=8 => RoundedBitWidth::Bits8,
1009 9..=16 => RoundedBitWidth::Bits16,
1010 _ => RoundedBitWidth::Bits32,
1011 }
1012 }
1013
1014 #[inline]
1016 pub fn from_u8(bits: u8) -> Self {
1017 match bits {
1018 0 => RoundedBitWidth::Zero,
1019 8 => RoundedBitWidth::Bits8,
1020 16 => RoundedBitWidth::Bits16,
1021 32 => RoundedBitWidth::Bits32,
1022 _ => RoundedBitWidth::Bits32, }
1024 }
1025
1026 #[inline]
1028 pub fn bytes_per_value(self) -> usize {
1029 match self {
1030 RoundedBitWidth::Zero => 0,
1031 RoundedBitWidth::Bits8 => 1,
1032 RoundedBitWidth::Bits16 => 2,
1033 RoundedBitWidth::Bits32 => 4,
1034 }
1035 }
1036
1037 #[inline]
1039 pub fn as_u8(self) -> u8 {
1040 self as u8
1041 }
1042}
1043
1044#[inline]
1046pub fn round_bit_width(bits: u8) -> u8 {
1047 RoundedBitWidth::from_exact(bits).as_u8()
1048}
1049
1050#[inline]
1055pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
1056 let count = values.len();
1057 match bit_width {
1058 RoundedBitWidth::Zero => 0,
1059 RoundedBitWidth::Bits8 => {
1060 for (i, &v) in values.iter().enumerate() {
1061 output[i] = v as u8;
1062 }
1063 count
1064 }
1065 RoundedBitWidth::Bits16 => {
1066 for (i, &v) in values.iter().enumerate() {
1067 let bytes = (v as u16).to_le_bytes();
1068 output[i * 2] = bytes[0];
1069 output[i * 2 + 1] = bytes[1];
1070 }
1071 count * 2
1072 }
1073 RoundedBitWidth::Bits32 => {
1074 for (i, &v) in values.iter().enumerate() {
1075 let bytes = v.to_le_bytes();
1076 output[i * 4] = bytes[0];
1077 output[i * 4 + 1] = bytes[1];
1078 output[i * 4 + 2] = bytes[2];
1079 output[i * 4 + 3] = bytes[3];
1080 }
1081 count * 4
1082 }
1083 }
1084}
1085
1086#[inline]
1090pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
1091 match bit_width {
1092 RoundedBitWidth::Zero => {
1093 for out in output.iter_mut().take(count) {
1094 *out = 0;
1095 }
1096 }
1097 RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
1098 RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
1099 RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
1100 }
1101}
1102
1103#[inline]
1107pub fn unpack_rounded_delta_decode(
1108 input: &[u8],
1109 bit_width: RoundedBitWidth,
1110 output: &mut [u32],
1111 first_value: u32,
1112 count: usize,
1113) {
1114 match bit_width {
1115 RoundedBitWidth::Zero => {
1116 let mut val = first_value;
1118 for out in output.iter_mut().take(count) {
1119 *out = val;
1120 val = val.wrapping_add(1);
1121 }
1122 }
1123 RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
1124 RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
1125 RoundedBitWidth::Bits32 => {
1126 unpack_32bit(input, output, count);
1128 if count > 0 {
1131 let mut carry = first_value;
1132 output[0] = first_value;
1133 for item in output.iter_mut().take(count).skip(1) {
1134 carry = carry.wrapping_add(*item).wrapping_add(1);
1136 *item = carry;
1137 }
1138 }
1139 }
1140 }
1141}
1142
1143#[inline]
1152pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1153 if count == 0 {
1154 return;
1155 }
1156
1157 output[0] = first_value;
1158 if count == 1 {
1159 return;
1160 }
1161
1162 #[cfg(target_arch = "aarch64")]
1163 {
1164 if neon::is_available() {
1165 unsafe {
1166 neon::unpack_8bit_delta_decode(input, output, first_value, count);
1167 }
1168 return;
1169 }
1170 }
1171
1172 #[cfg(target_arch = "x86_64")]
1173 {
1174 if sse::is_available() {
1175 unsafe {
1176 sse::unpack_8bit_delta_decode(input, output, first_value, count);
1177 }
1178 return;
1179 }
1180 }
1181
1182 let mut carry = first_value;
1184 for i in 0..count - 1 {
1185 carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
1186 output[i + 1] = carry;
1187 }
1188}
1189
1190#[inline]
1192pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1193 if count == 0 {
1194 return;
1195 }
1196
1197 output[0] = first_value;
1198 if count == 1 {
1199 return;
1200 }
1201
1202 #[cfg(target_arch = "aarch64")]
1203 {
1204 if neon::is_available() {
1205 unsafe {
1206 neon::unpack_16bit_delta_decode(input, output, first_value, count);
1207 }
1208 return;
1209 }
1210 }
1211
1212 #[cfg(target_arch = "x86_64")]
1213 {
1214 if sse::is_available() {
1215 unsafe {
1216 sse::unpack_16bit_delta_decode(input, output, first_value, count);
1217 }
1218 return;
1219 }
1220 }
1221
1222 let mut carry = first_value;
1224 for i in 0..count - 1 {
1225 let idx = i * 2;
1226 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
1227 carry = carry.wrapping_add(delta).wrapping_add(1);
1228 output[i + 1] = carry;
1229 }
1230}
1231
1232#[inline]
1237pub fn unpack_delta_decode(
1238 input: &[u8],
1239 bit_width: u8,
1240 output: &mut [u32],
1241 first_value: u32,
1242 count: usize,
1243) {
1244 if count == 0 {
1245 return;
1246 }
1247
1248 output[0] = first_value;
1249 if count == 1 {
1250 return;
1251 }
1252
1253 match bit_width {
1255 0 => {
1256 let mut val = first_value;
1258 for item in output.iter_mut().take(count).skip(1) {
1259 val = val.wrapping_add(1);
1260 *item = val;
1261 }
1262 }
1263 8 => unpack_8bit_delta_decode(input, output, first_value, count),
1264 16 => unpack_16bit_delta_decode(input, output, first_value, count),
1265 32 => {
1266 let mut carry = first_value;
1268 for i in 0..count - 1 {
1269 let idx = i * 4;
1270 let delta = u32::from_le_bytes([
1271 input[idx],
1272 input[idx + 1],
1273 input[idx + 2],
1274 input[idx + 3],
1275 ]);
1276 carry = carry.wrapping_add(delta).wrapping_add(1);
1277 output[i + 1] = carry;
1278 }
1279 }
1280 _ => {
1281 let mask = (1u64 << bit_width) - 1;
1283 let bit_width_usize = bit_width as usize;
1284 let mut bit_pos = 0usize;
1285 let input_ptr = input.as_ptr();
1286 let mut carry = first_value;
1287
1288 for i in 0..count - 1 {
1289 let byte_idx = bit_pos >> 3;
1290 let bit_offset = bit_pos & 7;
1291
1292 let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
1294 let delta = ((word >> bit_offset) & mask) as u32;
1295
1296 carry = carry.wrapping_add(delta).wrapping_add(1);
1297 output[i + 1] = carry;
1298 bit_pos += bit_width_usize;
1299 }
1300 }
1301 }
1302}
1303
1304#[inline]
1312pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
1313 #[cfg(target_arch = "aarch64")]
1314 {
1315 if neon::is_available() {
1316 unsafe {
1317 dequantize_uint8_neon(input, output, scale, min_val, count);
1318 }
1319 return;
1320 }
1321 }
1322
1323 #[cfg(target_arch = "x86_64")]
1324 {
1325 if sse::is_available() {
1326 unsafe {
1327 dequantize_uint8_sse(input, output, scale, min_val, count);
1328 }
1329 return;
1330 }
1331 }
1332
1333 for i in 0..count {
1335 output[i] = input[i] as f32 * scale + min_val;
1336 }
1337}
1338
1339#[cfg(target_arch = "aarch64")]
1340#[target_feature(enable = "neon")]
1341#[allow(unsafe_op_in_unsafe_fn)]
1342unsafe fn dequantize_uint8_neon(
1343 input: &[u8],
1344 output: &mut [f32],
1345 scale: f32,
1346 min_val: f32,
1347 count: usize,
1348) {
1349 use std::arch::aarch64::*;
1350
1351 let scale_v = vdupq_n_f32(scale);
1352 let min_v = vdupq_n_f32(min_val);
1353
1354 let chunks = count / 16;
1355 let remainder = count % 16;
1356
1357 for chunk in 0..chunks {
1358 let base = chunk * 16;
1359 let in_ptr = input.as_ptr().add(base);
1360
1361 let bytes = vld1q_u8(in_ptr);
1363
1364 let low8 = vget_low_u8(bytes);
1366 let high8 = vget_high_u8(bytes);
1367
1368 let low16 = vmovl_u8(low8);
1369 let high16 = vmovl_u8(high8);
1370
1371 let u32_0 = vmovl_u16(vget_low_u16(low16));
1373 let u32_1 = vmovl_u16(vget_high_u16(low16));
1374 let u32_2 = vmovl_u16(vget_low_u16(high16));
1375 let u32_3 = vmovl_u16(vget_high_u16(high16));
1376
1377 let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
1379 let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
1380 let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
1381 let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);
1382
1383 let out_ptr = output.as_mut_ptr().add(base);
1384 vst1q_f32(out_ptr, f32_0);
1385 vst1q_f32(out_ptr.add(4), f32_1);
1386 vst1q_f32(out_ptr.add(8), f32_2);
1387 vst1q_f32(out_ptr.add(12), f32_3);
1388 }
1389
1390 let base = chunks * 16;
1392 for i in 0..remainder {
1393 output[base + i] = input[base + i] as f32 * scale + min_val;
1394 }
1395}
1396
1397#[cfg(target_arch = "x86_64")]
1398#[target_feature(enable = "sse2", enable = "sse4.1")]
1399#[allow(unsafe_op_in_unsafe_fn)]
1400unsafe fn dequantize_uint8_sse(
1401 input: &[u8],
1402 output: &mut [f32],
1403 scale: f32,
1404 min_val: f32,
1405 count: usize,
1406) {
1407 use std::arch::x86_64::*;
1408
1409 let scale_v = _mm_set1_ps(scale);
1410 let min_v = _mm_set1_ps(min_val);
1411
1412 let chunks = count / 4;
1413 let remainder = count % 4;
1414
1415 for chunk in 0..chunks {
1416 let base = chunk * 4;
1417
1418 let b0 = input[base] as i32;
1420 let b1 = input[base + 1] as i32;
1421 let b2 = input[base + 2] as i32;
1422 let b3 = input[base + 3] as i32;
1423
1424 let ints = _mm_set_epi32(b3, b2, b1, b0);
1425 let floats = _mm_cvtepi32_ps(ints);
1426
1427 let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);
1429
1430 _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
1431 }
1432
1433 let base = chunks * 4;
1435 for i in 0..remainder {
1436 output[base + i] = input[base + i] as f32 * scale + min_val;
1437 }
1438}
1439
1440#[inline]
1442pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
1443 #[cfg(target_arch = "aarch64")]
1444 {
1445 if neon::is_available() {
1446 return unsafe { dot_product_f32_neon(a, b, count) };
1447 }
1448 }
1449
1450 #[cfg(target_arch = "x86_64")]
1451 {
1452 if is_x86_feature_detected!("avx512f") {
1453 return unsafe { dot_product_f32_avx512(a, b, count) };
1454 }
1455 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1456 return unsafe { dot_product_f32_avx2(a, b, count) };
1457 }
1458 if sse::is_available() {
1459 return unsafe { dot_product_f32_sse(a, b, count) };
1460 }
1461 }
1462
1463 let mut sum = 0.0f32;
1465 for i in 0..count {
1466 sum += a[i] * b[i];
1467 }
1468 sum
1469}
1470
1471#[cfg(target_arch = "aarch64")]
1472#[target_feature(enable = "neon")]
1473#[allow(unsafe_op_in_unsafe_fn)]
1474unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
1475 use std::arch::aarch64::*;
1476
1477 let chunks16 = count / 16;
1478 let remainder = count % 16;
1479
1480 let mut acc0 = vdupq_n_f32(0.0);
1481 let mut acc1 = vdupq_n_f32(0.0);
1482 let mut acc2 = vdupq_n_f32(0.0);
1483 let mut acc3 = vdupq_n_f32(0.0);
1484
1485 for c in 0..chunks16 {
1486 let base = c * 16;
1487 acc0 = vfmaq_f32(
1488 acc0,
1489 vld1q_f32(a.as_ptr().add(base)),
1490 vld1q_f32(b.as_ptr().add(base)),
1491 );
1492 acc1 = vfmaq_f32(
1493 acc1,
1494 vld1q_f32(a.as_ptr().add(base + 4)),
1495 vld1q_f32(b.as_ptr().add(base + 4)),
1496 );
1497 acc2 = vfmaq_f32(
1498 acc2,
1499 vld1q_f32(a.as_ptr().add(base + 8)),
1500 vld1q_f32(b.as_ptr().add(base + 8)),
1501 );
1502 acc3 = vfmaq_f32(
1503 acc3,
1504 vld1q_f32(a.as_ptr().add(base + 12)),
1505 vld1q_f32(b.as_ptr().add(base + 12)),
1506 );
1507 }
1508
1509 let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
1510 let mut sum = vaddvq_f32(acc);
1511
1512 let base = chunks16 * 16;
1513 for i in 0..remainder {
1514 sum += a[base + i] * b[base + i];
1515 }
1516
1517 sum
1518}
1519
1520#[cfg(target_arch = "x86_64")]
1521#[target_feature(enable = "avx2", enable = "fma")]
1522#[allow(unsafe_op_in_unsafe_fn)]
1523unsafe fn dot_product_f32_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
1524 use std::arch::x86_64::*;
1525
1526 let chunks32 = count / 32;
1527 let remainder = count % 32;
1528
1529 let mut acc0 = _mm256_setzero_ps();
1530 let mut acc1 = _mm256_setzero_ps();
1531 let mut acc2 = _mm256_setzero_ps();
1532 let mut acc3 = _mm256_setzero_ps();
1533
1534 for c in 0..chunks32 {
1535 let base = c * 32;
1536 acc0 = _mm256_fmadd_ps(
1537 _mm256_loadu_ps(a.as_ptr().add(base)),
1538 _mm256_loadu_ps(b.as_ptr().add(base)),
1539 acc0,
1540 );
1541 acc1 = _mm256_fmadd_ps(
1542 _mm256_loadu_ps(a.as_ptr().add(base + 8)),
1543 _mm256_loadu_ps(b.as_ptr().add(base + 8)),
1544 acc1,
1545 );
1546 acc2 = _mm256_fmadd_ps(
1547 _mm256_loadu_ps(a.as_ptr().add(base + 16)),
1548 _mm256_loadu_ps(b.as_ptr().add(base + 16)),
1549 acc2,
1550 );
1551 acc3 = _mm256_fmadd_ps(
1552 _mm256_loadu_ps(a.as_ptr().add(base + 24)),
1553 _mm256_loadu_ps(b.as_ptr().add(base + 24)),
1554 acc3,
1555 );
1556 }
1557
1558 let acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
1559
1560 let hi = _mm256_extractf128_ps(acc, 1);
1562 let lo = _mm256_castps256_ps128(acc);
1563 let sum128 = _mm_add_ps(lo, hi);
1564 let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
1565 let sums = _mm_add_ps(sum128, shuf);
1566 let shuf2 = _mm_movehl_ps(sums, sums);
1567 let final_sum = _mm_add_ss(sums, shuf2);
1568
1569 let mut sum = _mm_cvtss_f32(final_sum);
1570
1571 let base = chunks32 * 32;
1572 for i in 0..remainder {
1573 sum += a[base + i] * b[base + i];
1574 }
1575
1576 sum
1577}
1578
1579#[cfg(target_arch = "x86_64")]
1580#[target_feature(enable = "sse")]
1581#[allow(unsafe_op_in_unsafe_fn)]
1582unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
1583 use std::arch::x86_64::*;
1584
1585 let chunks = count / 4;
1586 let remainder = count % 4;
1587
1588 let mut acc = _mm_setzero_ps();
1589
1590 for chunk in 0..chunks {
1591 let base = chunk * 4;
1592 let va = _mm_loadu_ps(a.as_ptr().add(base));
1593 let vb = _mm_loadu_ps(b.as_ptr().add(base));
1594 acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
1595 }
1596
1597 let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); let sums = _mm_add_ps(acc, shuf); let shuf2 = _mm_movehl_ps(sums, sums); let final_sum = _mm_add_ss(sums, shuf2); let mut sum = _mm_cvtss_f32(final_sum);
1604
1605 let base = chunks * 4;
1607 for i in 0..remainder {
1608 sum += a[base + i] * b[base + i];
1609 }
1610
1611 sum
1612}
1613
1614#[cfg(target_arch = "x86_64")]
1615#[target_feature(enable = "avx512f")]
1616#[allow(unsafe_op_in_unsafe_fn)]
1617unsafe fn dot_product_f32_avx512(a: &[f32], b: &[f32], count: usize) -> f32 {
1618 use std::arch::x86_64::*;
1619
1620 let chunks64 = count / 64;
1621 let remainder = count % 64;
1622
1623 let mut acc0 = _mm512_setzero_ps();
1624 let mut acc1 = _mm512_setzero_ps();
1625 let mut acc2 = _mm512_setzero_ps();
1626 let mut acc3 = _mm512_setzero_ps();
1627
1628 for c in 0..chunks64 {
1629 let base = c * 64;
1630 acc0 = _mm512_fmadd_ps(
1631 _mm512_loadu_ps(a.as_ptr().add(base)),
1632 _mm512_loadu_ps(b.as_ptr().add(base)),
1633 acc0,
1634 );
1635 acc1 = _mm512_fmadd_ps(
1636 _mm512_loadu_ps(a.as_ptr().add(base + 16)),
1637 _mm512_loadu_ps(b.as_ptr().add(base + 16)),
1638 acc1,
1639 );
1640 acc2 = _mm512_fmadd_ps(
1641 _mm512_loadu_ps(a.as_ptr().add(base + 32)),
1642 _mm512_loadu_ps(b.as_ptr().add(base + 32)),
1643 acc2,
1644 );
1645 acc3 = _mm512_fmadd_ps(
1646 _mm512_loadu_ps(a.as_ptr().add(base + 48)),
1647 _mm512_loadu_ps(b.as_ptr().add(base + 48)),
1648 acc3,
1649 );
1650 }
1651
1652 let acc = _mm512_add_ps(_mm512_add_ps(acc0, acc1), _mm512_add_ps(acc2, acc3));
1653 let mut sum = _mm512_reduce_add_ps(acc);
1654
1655 let base = chunks64 * 64;
1656 for i in 0..remainder {
1657 sum += a[base + i] * b[base + i];
1658 }
1659
1660 sum
1661}
1662
1663#[cfg(target_arch = "x86_64")]
1664#[target_feature(enable = "avx512f")]
1665#[allow(unsafe_op_in_unsafe_fn)]
1666unsafe fn fused_dot_norm_avx512(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1667 use std::arch::x86_64::*;
1668
1669 let chunks64 = count / 64;
1670 let remainder = count % 64;
1671
1672 let mut d0 = _mm512_setzero_ps();
1673 let mut d1 = _mm512_setzero_ps();
1674 let mut d2 = _mm512_setzero_ps();
1675 let mut d3 = _mm512_setzero_ps();
1676 let mut n0 = _mm512_setzero_ps();
1677 let mut n1 = _mm512_setzero_ps();
1678 let mut n2 = _mm512_setzero_ps();
1679 let mut n3 = _mm512_setzero_ps();
1680
1681 for c in 0..chunks64 {
1682 let base = c * 64;
1683 let vb0 = _mm512_loadu_ps(b.as_ptr().add(base));
1684 d0 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base)), vb0, d0);
1685 n0 = _mm512_fmadd_ps(vb0, vb0, n0);
1686 let vb1 = _mm512_loadu_ps(b.as_ptr().add(base + 16));
1687 d1 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 16)), vb1, d1);
1688 n1 = _mm512_fmadd_ps(vb1, vb1, n1);
1689 let vb2 = _mm512_loadu_ps(b.as_ptr().add(base + 32));
1690 d2 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 32)), vb2, d2);
1691 n2 = _mm512_fmadd_ps(vb2, vb2, n2);
1692 let vb3 = _mm512_loadu_ps(b.as_ptr().add(base + 48));
1693 d3 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 48)), vb3, d3);
1694 n3 = _mm512_fmadd_ps(vb3, vb3, n3);
1695 }
1696
1697 let acc_dot = _mm512_add_ps(_mm512_add_ps(d0, d1), _mm512_add_ps(d2, d3));
1698 let acc_norm = _mm512_add_ps(_mm512_add_ps(n0, n1), _mm512_add_ps(n2, n3));
1699 let mut dot = _mm512_reduce_add_ps(acc_dot);
1700 let mut norm = _mm512_reduce_add_ps(acc_norm);
1701
1702 let base = chunks64 * 64;
1703 for i in 0..remainder {
1704 dot += a[base + i] * b[base + i];
1705 norm += b[base + i] * b[base + i];
1706 }
1707
1708 (dot, norm)
1709}
1710
1711#[inline]
1713pub fn max_f32(values: &[f32], count: usize) -> f32 {
1714 if count == 0 {
1715 return f32::NEG_INFINITY;
1716 }
1717
1718 #[cfg(target_arch = "aarch64")]
1719 {
1720 if neon::is_available() {
1721 return unsafe { max_f32_neon(values, count) };
1722 }
1723 }
1724
1725 #[cfg(target_arch = "x86_64")]
1726 {
1727 if sse::is_available() {
1728 return unsafe { max_f32_sse(values, count) };
1729 }
1730 }
1731
1732 values[..count]
1734 .iter()
1735 .cloned()
1736 .fold(f32::NEG_INFINITY, f32::max)
1737}
1738
1739#[cfg(target_arch = "aarch64")]
1740#[target_feature(enable = "neon")]
1741#[allow(unsafe_op_in_unsafe_fn)]
1742unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
1743 use std::arch::aarch64::*;
1744
1745 let chunks = count / 4;
1746 let remainder = count % 4;
1747
1748 let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);
1749
1750 for chunk in 0..chunks {
1751 let base = chunk * 4;
1752 let v = vld1q_f32(values.as_ptr().add(base));
1753 max_v = vmaxq_f32(max_v, v);
1754 }
1755
1756 let mut max_val = vmaxvq_f32(max_v);
1758
1759 let base = chunks * 4;
1761 for i in 0..remainder {
1762 max_val = max_val.max(values[base + i]);
1763 }
1764
1765 max_val
1766}
1767
1768#[cfg(target_arch = "x86_64")]
1769#[target_feature(enable = "sse")]
1770#[allow(unsafe_op_in_unsafe_fn)]
1771unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
1772 use std::arch::x86_64::*;
1773
1774 let chunks = count / 4;
1775 let remainder = count % 4;
1776
1777 let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);
1778
1779 for chunk in 0..chunks {
1780 let base = chunk * 4;
1781 let v = _mm_loadu_ps(values.as_ptr().add(base));
1782 max_v = _mm_max_ps(max_v, v);
1783 }
1784
1785 let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01); let max1 = _mm_max_ps(max_v, shuf); let shuf2 = _mm_movehl_ps(max1, max1); let final_max = _mm_max_ss(max1, shuf2); let mut max_val = _mm_cvtss_f32(final_max);
1792
1793 let base = chunks * 4;
1795 for i in 0..remainder {
1796 max_val = max_val.max(values[base + i]);
1797 }
1798
1799 max_val
1800}
1801
1802#[inline]
1811fn fused_dot_norm(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1812 #[cfg(target_arch = "aarch64")]
1813 {
1814 if neon::is_available() {
1815 return unsafe { fused_dot_norm_neon(a, b, count) };
1816 }
1817 }
1818
1819 #[cfg(target_arch = "x86_64")]
1820 {
1821 if is_x86_feature_detected!("avx512f") {
1822 return unsafe { fused_dot_norm_avx512(a, b, count) };
1823 }
1824 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1825 return unsafe { fused_dot_norm_avx2(a, b, count) };
1826 }
1827 if sse::is_available() {
1828 return unsafe { fused_dot_norm_sse(a, b, count) };
1829 }
1830 }
1831
1832 let mut dot = 0.0f32;
1834 let mut norm_b = 0.0f32;
1835 for i in 0..count {
1836 dot += a[i] * b[i];
1837 norm_b += b[i] * b[i];
1838 }
1839 (dot, norm_b)
1840}
1841
1842#[cfg(target_arch = "aarch64")]
1843#[target_feature(enable = "neon")]
1844#[allow(unsafe_op_in_unsafe_fn)]
1845unsafe fn fused_dot_norm_neon(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1846 use std::arch::aarch64::*;
1847
1848 let chunks16 = count / 16;
1849 let remainder = count % 16;
1850
1851 let mut d0 = vdupq_n_f32(0.0);
1852 let mut d1 = vdupq_n_f32(0.0);
1853 let mut d2 = vdupq_n_f32(0.0);
1854 let mut d3 = vdupq_n_f32(0.0);
1855 let mut n0 = vdupq_n_f32(0.0);
1856 let mut n1 = vdupq_n_f32(0.0);
1857 let mut n2 = vdupq_n_f32(0.0);
1858 let mut n3 = vdupq_n_f32(0.0);
1859
1860 for c in 0..chunks16 {
1861 let base = c * 16;
1862 let va0 = vld1q_f32(a.as_ptr().add(base));
1863 let vb0 = vld1q_f32(b.as_ptr().add(base));
1864 d0 = vfmaq_f32(d0, va0, vb0);
1865 n0 = vfmaq_f32(n0, vb0, vb0);
1866 let va1 = vld1q_f32(a.as_ptr().add(base + 4));
1867 let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
1868 d1 = vfmaq_f32(d1, va1, vb1);
1869 n1 = vfmaq_f32(n1, vb1, vb1);
1870 let va2 = vld1q_f32(a.as_ptr().add(base + 8));
1871 let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
1872 d2 = vfmaq_f32(d2, va2, vb2);
1873 n2 = vfmaq_f32(n2, vb2, vb2);
1874 let va3 = vld1q_f32(a.as_ptr().add(base + 12));
1875 let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
1876 d3 = vfmaq_f32(d3, va3, vb3);
1877 n3 = vfmaq_f32(n3, vb3, vb3);
1878 }
1879
1880 let acc_dot = vaddq_f32(vaddq_f32(d0, d1), vaddq_f32(d2, d3));
1881 let acc_norm = vaddq_f32(vaddq_f32(n0, n1), vaddq_f32(n2, n3));
1882 let mut dot = vaddvq_f32(acc_dot);
1883 let mut norm = vaddvq_f32(acc_norm);
1884
1885 let base = chunks16 * 16;
1886 for i in 0..remainder {
1887 dot += a[base + i] * b[base + i];
1888 norm += b[base + i] * b[base + i];
1889 }
1890
1891 (dot, norm)
1892}
1893
1894#[cfg(target_arch = "x86_64")]
1895#[target_feature(enable = "avx2", enable = "fma")]
1896#[allow(unsafe_op_in_unsafe_fn)]
1897unsafe fn fused_dot_norm_avx2(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1898 use std::arch::x86_64::*;
1899
1900 let chunks32 = count / 32;
1901 let remainder = count % 32;
1902
1903 let mut d0 = _mm256_setzero_ps();
1904 let mut d1 = _mm256_setzero_ps();
1905 let mut d2 = _mm256_setzero_ps();
1906 let mut d3 = _mm256_setzero_ps();
1907 let mut n0 = _mm256_setzero_ps();
1908 let mut n1 = _mm256_setzero_ps();
1909 let mut n2 = _mm256_setzero_ps();
1910 let mut n3 = _mm256_setzero_ps();
1911
1912 for c in 0..chunks32 {
1913 let base = c * 32;
1914 let vb0 = _mm256_loadu_ps(b.as_ptr().add(base));
1915 d0 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base)), vb0, d0);
1916 n0 = _mm256_fmadd_ps(vb0, vb0, n0);
1917 let vb1 = _mm256_loadu_ps(b.as_ptr().add(base + 8));
1918 d1 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 8)), vb1, d1);
1919 n1 = _mm256_fmadd_ps(vb1, vb1, n1);
1920 let vb2 = _mm256_loadu_ps(b.as_ptr().add(base + 16));
1921 d2 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 16)), vb2, d2);
1922 n2 = _mm256_fmadd_ps(vb2, vb2, n2);
1923 let vb3 = _mm256_loadu_ps(b.as_ptr().add(base + 24));
1924 d3 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 24)), vb3, d3);
1925 n3 = _mm256_fmadd_ps(vb3, vb3, n3);
1926 }
1927
1928 let acc_dot = _mm256_add_ps(_mm256_add_ps(d0, d1), _mm256_add_ps(d2, d3));
1929 let acc_norm = _mm256_add_ps(_mm256_add_ps(n0, n1), _mm256_add_ps(n2, n3));
1930
1931 let hi_d = _mm256_extractf128_ps(acc_dot, 1);
1933 let lo_d = _mm256_castps256_ps128(acc_dot);
1934 let sum_d = _mm_add_ps(lo_d, hi_d);
1935 let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
1936 let sums_d = _mm_add_ps(sum_d, shuf_d);
1937 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
1938 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
1939
1940 let hi_n = _mm256_extractf128_ps(acc_norm, 1);
1941 let lo_n = _mm256_castps256_ps128(acc_norm);
1942 let sum_n = _mm_add_ps(lo_n, hi_n);
1943 let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
1944 let sums_n = _mm_add_ps(sum_n, shuf_n);
1945 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
1946 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
1947
1948 let base = chunks32 * 32;
1949 for i in 0..remainder {
1950 dot += a[base + i] * b[base + i];
1951 norm += b[base + i] * b[base + i];
1952 }
1953
1954 (dot, norm)
1955}
1956
1957#[cfg(target_arch = "x86_64")]
1958#[target_feature(enable = "sse")]
1959#[allow(unsafe_op_in_unsafe_fn)]
1960unsafe fn fused_dot_norm_sse(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1961 use std::arch::x86_64::*;
1962
1963 let chunks = count / 4;
1964 let remainder = count % 4;
1965
1966 let mut acc_dot = _mm_setzero_ps();
1967 let mut acc_norm = _mm_setzero_ps();
1968
1969 for chunk in 0..chunks {
1970 let base = chunk * 4;
1971 let va = _mm_loadu_ps(a.as_ptr().add(base));
1972 let vb = _mm_loadu_ps(b.as_ptr().add(base));
1973 acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
1974 acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
1975 }
1976
1977 let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
1979 let sums_d = _mm_add_ps(acc_dot, shuf_d);
1980 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
1981 let final_d = _mm_add_ss(sums_d, shuf2_d);
1982 let mut dot = _mm_cvtss_f32(final_d);
1983
1984 let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
1985 let sums_n = _mm_add_ps(acc_norm, shuf_n);
1986 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
1987 let final_n = _mm_add_ss(sums_n, shuf2_n);
1988 let mut norm = _mm_cvtss_f32(final_n);
1989
1990 let base = chunks * 4;
1991 for i in 0..remainder {
1992 dot += a[base + i] * b[base + i];
1993 norm += b[base + i] * b[base + i];
1994 }
1995
1996 (dot, norm)
1997}
1998
1999#[inline]
2005pub fn fast_inv_sqrt(x: f32) -> f32 {
2006 let half = 0.5 * x;
2007 let i = 0x5F37_5A86_u32.wrapping_sub(x.to_bits() >> 1);
2008 let y = f32::from_bits(i);
2009 let y = y * (1.5 - half * y * y); y * (1.5 - half * y * y) }
2012
2013#[inline]
2024pub fn batch_cosine_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
2025 let n = scores.len();
2026 debug_assert!(vectors.len() >= n * dim);
2027 debug_assert_eq!(query.len(), dim);
2028
2029 if dim == 0 || n == 0 {
2030 return;
2031 }
2032
2033 let norm_q_sq = dot_product_f32(query, query, dim);
2035 if norm_q_sq < f32::EPSILON {
2036 for s in scores.iter_mut() {
2037 *s = 0.0;
2038 }
2039 return;
2040 }
2041 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2042
2043 for i in 0..n {
2044 let vec = &vectors[i * dim..(i + 1) * dim];
2045 let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
2046 if norm_v_sq < f32::EPSILON {
2047 scores[i] = 0.0;
2048 } else {
2049 scores[i] = dot * inv_norm_q * fast_inv_sqrt(norm_v_sq);
2050 }
2051 }
2052}
2053
2054#[inline]
2060pub fn f32_to_f16(value: f32) -> u16 {
2061 let bits = value.to_bits();
2062 let sign = (bits >> 16) & 0x8000;
2063 let exp = ((bits >> 23) & 0xFF) as i32;
2064 let mantissa = bits & 0x7F_FFFF;
2065
2066 if exp == 255 {
2067 return (sign | 0x7C00 | ((mantissa >> 13) & 0x3FF)) as u16;
2069 }
2070
2071 let exp16 = exp - 127 + 15;
2072
2073 if exp16 >= 31 {
2074 return (sign | 0x7C00) as u16; }
2076
2077 if exp16 <= 0 {
2078 if exp16 < -10 {
2079 return sign as u16; }
2081 let m = (mantissa | 0x80_0000) >> (1 - exp16);
2082 return (sign | (m >> 13)) as u16;
2083 }
2084
2085 (sign | ((exp16 as u32) << 10) | (mantissa >> 13)) as u16
2086}
2087
2088#[inline]
2090pub fn f16_to_f32(half: u16) -> f32 {
2091 let sign = ((half & 0x8000) as u32) << 16;
2092 let exp = ((half >> 10) & 0x1F) as u32;
2093 let mantissa = (half & 0x3FF) as u32;
2094
2095 if exp == 0 {
2096 if mantissa == 0 {
2097 return f32::from_bits(sign);
2098 }
2099 let mut e = 0u32;
2101 let mut m = mantissa;
2102 while (m & 0x400) == 0 {
2103 m <<= 1;
2104 e += 1;
2105 }
2106 return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | ((m & 0x3FF) << 13));
2107 }
2108
2109 if exp == 31 {
2110 return f32::from_bits(sign | 0x7F80_0000 | (mantissa << 13));
2111 }
2112
2113 f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mantissa << 13))
2114}
2115
2116const U8_SCALE: f32 = 127.5;
2121const U8_INV_SCALE: f32 = 1.0 / 127.5;
2122
2123#[inline]
2125pub fn f32_to_u8_saturating(value: f32) -> u8 {
2126 ((value.clamp(-1.0, 1.0) + 1.0) * U8_SCALE) as u8
2127}
2128
2129#[inline]
2131pub fn u8_to_f32(byte: u8) -> f32 {
2132 byte as f32 * U8_INV_SCALE - 1.0
2133}
2134
2135pub fn batch_f32_to_f16(src: &[f32], dst: &mut [u16]) {
2141 debug_assert_eq!(src.len(), dst.len());
2142 for (s, d) in src.iter().zip(dst.iter_mut()) {
2143 *d = f32_to_f16(*s);
2144 }
2145}
2146
2147pub fn batch_f32_to_u8(src: &[f32], dst: &mut [u8]) {
2149 debug_assert_eq!(src.len(), dst.len());
2150 for (s, d) in src.iter().zip(dst.iter_mut()) {
2151 *d = f32_to_u8_saturating(*s);
2152 }
2153}
2154
2155#[cfg(target_arch = "aarch64")]
2160#[allow(unsafe_op_in_unsafe_fn)]
2161mod neon_quant {
2162 use std::arch::aarch64::*;
2163
2164 #[target_feature(enable = "neon")]
2170 pub unsafe fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2171 let chunks8 = dim / 8;
2172 let remainder = dim % 8;
2173
2174 let mut acc_dot = vdupq_n_f32(0.0);
2175 let mut acc_norm = vdupq_n_f32(0.0);
2176
2177 for c in 0..chunks8 {
2178 let base = c * 8;
2179
2180 let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
2182 let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
2183 let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
2184
2185 let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
2187 let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
2188 let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
2189
2190 acc_dot = vfmaq_f32(acc_dot, q_lo, v_lo);
2191 acc_dot = vfmaq_f32(acc_dot, q_hi, v_hi);
2192 acc_norm = vfmaq_f32(acc_norm, v_lo, v_lo);
2193 acc_norm = vfmaq_f32(acc_norm, v_hi, v_hi);
2194 }
2195
2196 let mut dot = vaddvq_f32(acc_dot);
2197 let mut norm = vaddvq_f32(acc_norm);
2198
2199 let base = chunks8 * 8;
2200 for i in 0..remainder {
2201 let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2202 let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2203 dot += q * v;
2204 norm += v * v;
2205 }
2206
2207 (dot, norm)
2208 }
2209
2210 #[target_feature(enable = "neon")]
2213 pub unsafe fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2214 let scale = vdupq_n_f32(super::U8_INV_SCALE);
2215 let offset = vdupq_n_f32(-1.0);
2216
2217 let chunks16 = dim / 16;
2218 let remainder = dim % 16;
2219
2220 let mut acc_dot = vdupq_n_f32(0.0);
2221 let mut acc_norm = vdupq_n_f32(0.0);
2222
2223 for c in 0..chunks16 {
2224 let base = c * 16;
2225
2226 let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2228
2229 let lo8 = vget_low_u8(bytes);
2231 let hi8 = vget_high_u8(bytes);
2232 let lo16 = vmovl_u8(lo8);
2233 let hi16 = vmovl_u8(hi8);
2234
2235 let f0 = vaddq_f32(
2236 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2237 offset,
2238 );
2239 let f1 = vaddq_f32(
2240 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2241 offset,
2242 );
2243 let f2 = vaddq_f32(
2244 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2245 offset,
2246 );
2247 let f3 = vaddq_f32(
2248 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2249 offset,
2250 );
2251
2252 let q0 = vld1q_f32(query.as_ptr().add(base));
2253 let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2254 let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2255 let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2256
2257 acc_dot = vfmaq_f32(acc_dot, q0, f0);
2258 acc_dot = vfmaq_f32(acc_dot, q1, f1);
2259 acc_dot = vfmaq_f32(acc_dot, q2, f2);
2260 acc_dot = vfmaq_f32(acc_dot, q3, f3);
2261
2262 acc_norm = vfmaq_f32(acc_norm, f0, f0);
2263 acc_norm = vfmaq_f32(acc_norm, f1, f1);
2264 acc_norm = vfmaq_f32(acc_norm, f2, f2);
2265 acc_norm = vfmaq_f32(acc_norm, f3, f3);
2266 }
2267
2268 let mut dot = vaddvq_f32(acc_dot);
2269 let mut norm = vaddvq_f32(acc_norm);
2270
2271 let base = chunks16 * 16;
2272 for i in 0..remainder {
2273 let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2274 dot += *query.get_unchecked(base + i) * v;
2275 norm += v * v;
2276 }
2277
2278 (dot, norm)
2279 }
2280
2281 #[target_feature(enable = "neon")]
2283 pub unsafe fn dot_product_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2284 let chunks8 = dim / 8;
2285 let remainder = dim % 8;
2286
2287 let mut acc = vdupq_n_f32(0.0);
2288
2289 for c in 0..chunks8 {
2290 let base = c * 8;
2291 let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
2292 let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
2293 let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
2294 let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
2295 let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
2296 let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
2297 acc = vfmaq_f32(acc, q_lo, v_lo);
2298 acc = vfmaq_f32(acc, q_hi, v_hi);
2299 }
2300
2301 let mut dot = vaddvq_f32(acc);
2302 let base = chunks8 * 8;
2303 for i in 0..remainder {
2304 let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2305 let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2306 dot += q * v;
2307 }
2308 dot
2309 }
2310
2311 #[target_feature(enable = "neon")]
2313 pub unsafe fn dot_product_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2314 let scale = vdupq_n_f32(super::U8_INV_SCALE);
2315 let offset = vdupq_n_f32(-1.0);
2316 let chunks16 = dim / 16;
2317 let remainder = dim % 16;
2318
2319 let mut acc = vdupq_n_f32(0.0);
2320
2321 for c in 0..chunks16 {
2322 let base = c * 16;
2323 let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2324 let lo8 = vget_low_u8(bytes);
2325 let hi8 = vget_high_u8(bytes);
2326 let lo16 = vmovl_u8(lo8);
2327 let hi16 = vmovl_u8(hi8);
2328 let f0 = vaddq_f32(
2329 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2330 offset,
2331 );
2332 let f1 = vaddq_f32(
2333 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2334 offset,
2335 );
2336 let f2 = vaddq_f32(
2337 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2338 offset,
2339 );
2340 let f3 = vaddq_f32(
2341 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2342 offset,
2343 );
2344 let q0 = vld1q_f32(query.as_ptr().add(base));
2345 let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2346 let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2347 let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2348 acc = vfmaq_f32(acc, q0, f0);
2349 acc = vfmaq_f32(acc, q1, f1);
2350 acc = vfmaq_f32(acc, q2, f2);
2351 acc = vfmaq_f32(acc, q3, f3);
2352 }
2353
2354 let mut dot = vaddvq_f32(acc);
2355 let base = chunks16 * 16;
2356 for i in 0..remainder {
2357 let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2358 dot += *query.get_unchecked(base + i) * v;
2359 }
2360 dot
2361 }
2362}
2363
2364#[allow(dead_code)]
2369fn fused_dot_norm_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2370 let mut dot = 0.0f32;
2371 let mut norm = 0.0f32;
2372 for i in 0..dim {
2373 let v = f16_to_f32(vec_f16[i]);
2374 let q = f16_to_f32(query_f16[i]);
2375 dot += q * v;
2376 norm += v * v;
2377 }
2378 (dot, norm)
2379}
2380
2381#[allow(dead_code)]
2382fn fused_dot_norm_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2383 let mut dot = 0.0f32;
2384 let mut norm = 0.0f32;
2385 for i in 0..dim {
2386 let v = u8_to_f32(vec_u8[i]);
2387 dot += query[i] * v;
2388 norm += v * v;
2389 }
2390 (dot, norm)
2391}
2392
2393#[allow(dead_code)]
2394fn dot_product_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2395 let mut dot = 0.0f32;
2396 for i in 0..dim {
2397 dot += f16_to_f32(query_f16[i]) * f16_to_f32(vec_f16[i]);
2398 }
2399 dot
2400}
2401
2402#[allow(dead_code)]
2403fn dot_product_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2404 let mut dot = 0.0f32;
2405 for i in 0..dim {
2406 dot += query[i] * u8_to_f32(vec_u8[i]);
2407 }
2408 dot
2409}
2410
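// A minimal consistency-check sketch (added here, not part of the original API): the
// scalar functions above define the reference semantics that every SIMD path should
// reproduce, so a hypothetical debug helper like this can spot-check the dispatching
// u8 dot product against its scalar baseline.
#[allow(dead_code)]
fn debug_check_dot_product_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> bool {
    let simd = dot_product_u8_quant(query, vec_u8, dim);
    let scalar = dot_product_u8_scalar(query, vec_u8, dim);
    // Accumulation order differs between the SIMD and scalar paths, so allow a small
    // relative tolerance rather than exact equality.
    (simd - scalar).abs() <= 1e-3 * scalar.abs().max(1.0)
}
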
2411#[cfg(target_arch = "x86_64")]
2416#[target_feature(enable = "sse2", enable = "sse4.1")]
2417#[allow(unsafe_op_in_unsafe_fn)]
2418unsafe fn fused_dot_norm_f16_sse(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2419 use std::arch::x86_64::*;
2420
2421 let chunks = dim / 4;
2422 let remainder = dim % 4;
2423
2424 let mut acc_dot = _mm_setzero_ps();
2425 let mut acc_norm = _mm_setzero_ps();
2426
2427 for chunk in 0..chunks {
2428 let base = chunk * 4;
2429 let v0 = f16_to_f32(*vec_f16.get_unchecked(base));
2431 let v1 = f16_to_f32(*vec_f16.get_unchecked(base + 1));
2432 let v2 = f16_to_f32(*vec_f16.get_unchecked(base + 2));
2433 let v3 = f16_to_f32(*vec_f16.get_unchecked(base + 3));
2434 let vb = _mm_set_ps(v3, v2, v1, v0);
2435
2436 let q0 = f16_to_f32(*query_f16.get_unchecked(base));
2437 let q1 = f16_to_f32(*query_f16.get_unchecked(base + 1));
2438 let q2 = f16_to_f32(*query_f16.get_unchecked(base + 2));
2439 let q3 = f16_to_f32(*query_f16.get_unchecked(base + 3));
2440 let va = _mm_set_ps(q3, q2, q1, q0);
2441
2442 acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2443 acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2444 }
2445
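    // Horizontal sum of the four f32 lanes: swap adjacent pairs and add, then fold the
    // upper half onto the lower half. The same reduction pattern recurs in the SSE and
    // AVX paths below.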
2446 let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2448 let sums_d = _mm_add_ps(acc_dot, shuf_d);
2449 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2450 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2451
2452 let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2453 let sums_n = _mm_add_ps(acc_norm, shuf_n);
2454 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2455 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2456
2457 let base = chunks * 4;
2458 for i in 0..remainder {
2459 let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2460 let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2461 dot += q * v;
2462 norm += v * v;
2463 }
2464
2465 (dot, norm)
2466}
2467
2468#[cfg(target_arch = "x86_64")]
2469#[target_feature(enable = "sse2", enable = "sse4.1")]
2470#[allow(unsafe_op_in_unsafe_fn)]
2471unsafe fn fused_dot_norm_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2472 use std::arch::x86_64::*;
2473
2474 let scale = _mm_set1_ps(U8_INV_SCALE);
2475 let offset = _mm_set1_ps(-1.0);
2476
2477 let chunks = dim / 4;
2478 let remainder = dim % 4;
2479
2480 let mut acc_dot = _mm_setzero_ps();
2481 let mut acc_norm = _mm_setzero_ps();
2482
2483 for chunk in 0..chunks {
2484 let base = chunk * 4;
2485
2486 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2488 vec_u8.as_ptr().add(base) as *const i32
2489 ));
2490 let ints = _mm_cvtepu8_epi32(bytes);
2491 let floats = _mm_cvtepi32_ps(ints);
2492 let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2493
2494 let va = _mm_loadu_ps(query.as_ptr().add(base));
2495
2496 acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2497 acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2498 }
2499
2500 let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2502 let sums_d = _mm_add_ps(acc_dot, shuf_d);
2503 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2504 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2505
2506 let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2507 let sums_n = _mm_add_ps(acc_norm, shuf_n);
2508 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2509 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2510
2511 let base = chunks * 4;
2512 for i in 0..remainder {
2513 let v = u8_to_f32(*vec_u8.get_unchecked(base + i));
2514 dot += *query.get_unchecked(base + i) * v;
2515 norm += v * v;
2516 }
2517
2518 (dot, norm)
2519}
2520
2521#[cfg(target_arch = "x86_64")]
2526#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2527#[allow(unsafe_op_in_unsafe_fn)]
2528unsafe fn fused_dot_norm_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2529 use std::arch::x86_64::*;
2530
2531 let chunks = dim / 8;
2532 let remainder = dim % 8;
2533
2534 let mut acc_dot = _mm256_setzero_ps();
2535 let mut acc_norm = _mm256_setzero_ps();
2536
2537 for chunk in 0..chunks {
2538 let base = chunk * 8;
2539 let v_raw = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2541 let vb = _mm256_cvtph_ps(v_raw);
2542 let q_raw = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2543 let qa = _mm256_cvtph_ps(q_raw);
2544 acc_dot = _mm256_fmadd_ps(qa, vb, acc_dot);
2545 acc_norm = _mm256_fmadd_ps(vb, vb, acc_norm);
2546 }
2547
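    // Reduce each 256-bit accumulator: add the high and low 128-bit halves, then apply
    // the same 4-lane horizontal sum used in the SSE paths.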
2548 let hi_d = _mm256_extractf128_ps(acc_dot, 1);
2550 let lo_d = _mm256_castps256_ps128(acc_dot);
2551 let sum_d = _mm_add_ps(lo_d, hi_d);
2552 let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
2553 let sums_d = _mm_add_ps(sum_d, shuf_d);
2554 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2555 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2556
2557 let hi_n = _mm256_extractf128_ps(acc_norm, 1);
2558 let lo_n = _mm256_castps256_ps128(acc_norm);
2559 let sum_n = _mm_add_ps(lo_n, hi_n);
2560 let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
2561 let sums_n = _mm_add_ps(sum_n, shuf_n);
2562 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2563 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2564
2565 let base = chunks * 8;
2566 for i in 0..remainder {
2567 let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2568 let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2569 dot += q * v;
2570 norm += v * v;
2571 }
2572
2573 (dot, norm)
2574}
2575
2576#[cfg(target_arch = "x86_64")]
2577#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2578#[allow(unsafe_op_in_unsafe_fn)]
2579unsafe fn dot_product_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2580 use std::arch::x86_64::*;
2581
2582 let chunks = dim / 8;
2583 let remainder = dim % 8;
2584 let mut acc = _mm256_setzero_ps();
2585
2586 for chunk in 0..chunks {
2587 let base = chunk * 8;
2588 let v_raw = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2589 let vb = _mm256_cvtph_ps(v_raw);
2590 let q_raw = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2591 let qa = _mm256_cvtph_ps(q_raw);
2592 acc = _mm256_fmadd_ps(qa, vb, acc);
2593 }
2594
2595 let hi = _mm256_extractf128_ps(acc, 1);
2596 let lo = _mm256_castps256_ps128(acc);
2597 let sum = _mm_add_ps(lo, hi);
2598 let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
2599 let sums = _mm_add_ps(sum, shuf);
2600 let shuf2 = _mm_movehl_ps(sums, sums);
2601 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2602
2603 let base = chunks * 8;
2604 for i in 0..remainder {
2605 let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2606 let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2607 dot += q * v;
2608 }
2609 dot
2610}
2611
2612#[cfg(target_arch = "x86_64")]
2613#[target_feature(enable = "sse2", enable = "sse4.1")]
2614#[allow(unsafe_op_in_unsafe_fn)]
2615unsafe fn dot_product_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2616 use std::arch::x86_64::*;
2617
2618 let scale = _mm_set1_ps(U8_INV_SCALE);
2619 let offset = _mm_set1_ps(-1.0);
2620 let chunks = dim / 4;
2621 let remainder = dim % 4;
2622 let mut acc = _mm_setzero_ps();
2623
2624 for chunk in 0..chunks {
2625 let base = chunk * 4;
2626 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2627 vec_u8.as_ptr().add(base) as *const i32
2628 ));
2629 let ints = _mm_cvtepu8_epi32(bytes);
2630 let floats = _mm_cvtepi32_ps(ints);
2631 let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2632 let va = _mm_loadu_ps(query.as_ptr().add(base));
2633 acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
2634 }
2635
2636 let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
2637 let sums = _mm_add_ps(acc, shuf);
2638 let shuf2 = _mm_movehl_ps(sums, sums);
2639 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2640
2641 let base = chunks * 4;
2642 for i in 0..remainder {
2643 dot += *query.get_unchecked(base + i) * u8_to_f32(*vec_u8.get_unchecked(base + i));
2644 }
2645 dot
2646}
2647
2648#[inline]
2653fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2654 #[cfg(target_arch = "aarch64")]
2655 {
2656 return unsafe { neon_quant::fused_dot_norm_f16(query_f16, vec_f16, dim) };
2657 }
2658
2659 #[cfg(target_arch = "x86_64")]
2660 {
2661 if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2662 return unsafe { fused_dot_norm_f16_f16c(query_f16, vec_f16, dim) };
2663 }
2664 if sse::is_available() {
2665 return unsafe { fused_dot_norm_f16_sse(query_f16, vec_f16, dim) };
2666 }
2667 }
2668
2669 #[allow(unreachable_code)]
2670 fused_dot_norm_f16_scalar(query_f16, vec_f16, dim)
2671}
2672
2673#[inline]
2674fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2675 #[cfg(target_arch = "aarch64")]
2676 {
2677 return unsafe { neon_quant::fused_dot_norm_u8(query, vec_u8, dim) };
2678 }
2679
2680 #[cfg(target_arch = "x86_64")]
2681 {
2682 if sse::is_available() {
2683 return unsafe { fused_dot_norm_u8_sse(query, vec_u8, dim) };
2684 }
2685 }
2686
2687 #[allow(unreachable_code)]
2688 fused_dot_norm_u8_scalar(query, vec_u8, dim)
2689}
2690
2691#[inline]
2694fn dot_product_f16_quant(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2695 #[cfg(target_arch = "aarch64")]
2696 {
2697 return unsafe { neon_quant::dot_product_f16(query_f16, vec_f16, dim) };
2698 }
2699
2700 #[cfg(target_arch = "x86_64")]
2701 {
2702 if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2703 return unsafe { dot_product_f16_f16c(query_f16, vec_f16, dim) };
2704 }
2705 }
2706
2707 #[allow(unreachable_code)]
2708 dot_product_f16_scalar(query_f16, vec_f16, dim)
2709}
2710
2711#[inline]
2712fn dot_product_u8_quant(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2713 #[cfg(target_arch = "aarch64")]
2714 {
2715 return unsafe { neon_quant::dot_product_u8(query, vec_u8, dim) };
2716 }
2717
2718 #[cfg(target_arch = "x86_64")]
2719 {
2720 if sse::is_available() {
2721 return unsafe { dot_product_u8_sse(query, vec_u8, dim) };
2722 }
2723 }
2724
2725 #[allow(unreachable_code)]
2726 dot_product_u8_scalar(query, vec_u8, dim)
2727}
2728
2729#[inline]
2740pub fn batch_cosine_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2741 let n = scores.len();
2742 if dim == 0 || n == 0 {
2743 return;
2744 }
2745
2746 let norm_q_sq = dot_product_f32(query, query, dim);
2748 if norm_q_sq < f32::EPSILON {
2749 for s in scores.iter_mut() {
2750 *s = 0.0;
2751 }
2752 return;
2753 }
2754 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2755
2756 let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
2758
2759 let vec_bytes = dim * 2;
2760 debug_assert!(vectors_raw.len() >= n * vec_bytes);
2761
2762 debug_assert!(
2765 (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
2766 "f16 vector data not 2-byte aligned"
2767 );
2768
2769 for i in 0..n {
2770 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2771 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2772
2773 let (dot, norm_v_sq) = fused_dot_norm_f16(&query_f16, f16_slice, dim);
2774 scores[i] = if norm_v_sq < f32::EPSILON {
2775 0.0
2776 } else {
2777 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2778 };
2779 }
2780}
2781
2782#[inline]
2788pub fn batch_cosine_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2789 let n = scores.len();
2790 if dim == 0 || n == 0 {
2791 return;
2792 }
2793
2794 let norm_q_sq = dot_product_f32(query, query, dim);
2795 if norm_q_sq < f32::EPSILON {
2796 for s in scores.iter_mut() {
2797 *s = 0.0;
2798 }
2799 return;
2800 }
2801 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2802
2803 debug_assert!(vectors_raw.len() >= n * dim);
2804
2805 for i in 0..n {
2806 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
2807
2808 let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
2809 scores[i] = if norm_v_sq < f32::EPSILON {
2810 0.0
2811 } else {
2812 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2813 };
2814 }
2815}
2816
2817#[inline]
2826pub fn batch_dot_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
2827 let n = scores.len();
2828 debug_assert!(vectors.len() >= n * dim);
2829 debug_assert_eq!(query.len(), dim);
2830
2831 if dim == 0 || n == 0 {
2832 return;
2833 }
2834
2835 let norm_q_sq = dot_product_f32(query, query, dim);
2836 if norm_q_sq < f32::EPSILON {
2837 for s in scores.iter_mut() {
2838 *s = 0.0;
2839 }
2840 return;
2841 }
2842 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2843
2844 for i in 0..n {
2845 let vec = &vectors[i * dim..(i + 1) * dim];
2846 let dot = dot_product_f32(query, vec, dim);
2847 scores[i] = dot * inv_norm_q;
2848 }
2849}
2850
2851#[inline]
2856pub fn batch_dot_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2857 let n = scores.len();
2858 if dim == 0 || n == 0 {
2859 return;
2860 }
2861
2862 let norm_q_sq = dot_product_f32(query, query, dim);
2863 if norm_q_sq < f32::EPSILON {
2864 for s in scores.iter_mut() {
2865 *s = 0.0;
2866 }
2867 return;
2868 }
2869 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2870
2871 let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
2872 let vec_bytes = dim * 2;
2873 debug_assert!(vectors_raw.len() >= n * vec_bytes);
2874 debug_assert!(
2875 (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
2876 "f16 vector data not 2-byte aligned"
2877 );
2878
2879 for i in 0..n {
2880 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2881 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2882 let dot = dot_product_f16_quant(&query_f16, f16_slice, dim);
2883 scores[i] = dot * inv_norm_q;
2884 }
2885}
2886
2887#[inline]
2892pub fn batch_dot_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2893 let n = scores.len();
2894 if dim == 0 || n == 0 {
2895 return;
2896 }
2897
2898 let norm_q_sq = dot_product_f32(query, query, dim);
2899 if norm_q_sq < f32::EPSILON {
2900 for s in scores.iter_mut() {
2901 *s = 0.0;
2902 }
2903 return;
2904 }
2905 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2906
2907 debug_assert!(vectors_raw.len() >= n * dim);
2908
2909 for i in 0..n {
2910 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
2911 let dot = dot_product_u8_quant(query, u8_slice, dim);
2912 scores[i] = dot * inv_norm_q;
2913 }
2914}
2915
2916#[inline]
2922pub fn batch_cosine_scores_precomp(
2923 query: &[f32],
2924 vectors: &[f32],
2925 dim: usize,
2926 scores: &mut [f32],
2927 inv_norm_q: f32,
2928) {
2929 let n = scores.len();
2930 debug_assert!(vectors.len() >= n * dim);
2931 for i in 0..n {
2932 let vec = &vectors[i * dim..(i + 1) * dim];
2933 let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
2934 scores[i] = if norm_v_sq < f32::EPSILON {
2935 0.0
2936 } else {
2937 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2938 };
2939 }
2940}
2941
2942#[inline]
2944pub fn batch_cosine_scores_f16_precomp(
2945 query_f16: &[u16],
2946 vectors_raw: &[u8],
2947 dim: usize,
2948 scores: &mut [f32],
2949 inv_norm_q: f32,
2950) {
2951 let n = scores.len();
2952 let vec_bytes = dim * 2;
2953 debug_assert!(vectors_raw.len() >= n * vec_bytes);
2954 for i in 0..n {
2955 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2956 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2957 let (dot, norm_v_sq) = fused_dot_norm_f16(query_f16, f16_slice, dim);
2958 scores[i] = if norm_v_sq < f32::EPSILON {
2959 0.0
2960 } else {
2961 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2962 };
2963 }
2964}
2965
2966#[inline]
2968pub fn batch_cosine_scores_u8_precomp(
2969 query: &[f32],
2970 vectors_raw: &[u8],
2971 dim: usize,
2972 scores: &mut [f32],
2973 inv_norm_q: f32,
2974) {
2975 let n = scores.len();
2976 debug_assert!(vectors_raw.len() >= n * dim);
2977 for i in 0..n {
2978 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
2979 let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
2980 scores[i] = if norm_v_sq < f32::EPSILON {
2981 0.0
2982 } else {
2983 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2984 };
2985 }
2986}
2987
2988#[inline]
2990pub fn batch_dot_scores_precomp(
2991 query: &[f32],
2992 vectors: &[f32],
2993 dim: usize,
2994 scores: &mut [f32],
2995 inv_norm_q: f32,
2996) {
2997 let n = scores.len();
2998 debug_assert!(vectors.len() >= n * dim);
2999 for i in 0..n {
3000 let vec = &vectors[i * dim..(i + 1) * dim];
3001 scores[i] = dot_product_f32(query, vec, dim) * inv_norm_q;
3002 }
3003}
3004
3005#[inline]
3007pub fn batch_dot_scores_f16_precomp(
3008 query_f16: &[u16],
3009 vectors_raw: &[u8],
3010 dim: usize,
3011 scores: &mut [f32],
3012 inv_norm_q: f32,
3013) {
3014 let n = scores.len();
3015 let vec_bytes = dim * 2;
3016 debug_assert!(vectors_raw.len() >= n * vec_bytes);
3017 for i in 0..n {
3018 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3019 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3020 scores[i] = dot_product_f16_quant(query_f16, f16_slice, dim) * inv_norm_q;
3021 }
3022}
3023
3024#[inline]
3026pub fn batch_dot_scores_u8_precomp(
3027 query: &[f32],
3028 vectors_raw: &[u8],
3029 dim: usize,
3030 scores: &mut [f32],
3031 inv_norm_q: f32,
3032) {
3033 let n = scores.len();
3034 debug_assert!(vectors_raw.len() >= n * dim);
3035 for i in 0..n {
3036 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3037 scores[i] = dot_product_u8_quant(query, u8_slice, dim) * inv_norm_q;
3038 }
3039}
3040
3041#[inline]
3046pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
3047 debug_assert_eq!(a.len(), b.len());
3048 let count = a.len();
3049
3050 if count == 0 {
3051 return 0.0;
3052 }
3053
3054 let dot = dot_product_f32(a, b, count);
3055 let norm_a = dot_product_f32(a, a, count);
3056 let norm_b = dot_product_f32(b, b, count);
3057
3058 let denom = (norm_a * norm_b).sqrt();
3059 if denom < f32::EPSILON {
3060 return 0.0;
3061 }
3062
3063 dot / denom
3064}
3065
3066#[inline]
3070pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
3071 debug_assert_eq!(a.len(), b.len());
3072 let count = a.len();
3073
3074 if count == 0 {
3075 return 0.0;
3076 }
3077
3078 #[cfg(target_arch = "aarch64")]
3079 {
3080 if neon::is_available() {
3081 return unsafe { squared_euclidean_neon(a, b, count) };
3082 }
3083 }
3084
3085 #[cfg(target_arch = "x86_64")]
3086 {
3087 if avx2::is_available() {
3088 return unsafe { squared_euclidean_avx2(a, b, count) };
3089 }
3090 if sse::is_available() {
3091 return unsafe { squared_euclidean_sse(a, b, count) };
3092 }
3093 }
3094
3095 a.iter()
3097 .zip(b.iter())
3098 .map(|(&x, &y)| {
3099 let d = x - y;
3100 d * d
3101 })
3102 .sum()
3103}
3104
3105#[cfg(target_arch = "aarch64")]
3106#[target_feature(enable = "neon")]
3107#[allow(unsafe_op_in_unsafe_fn)]
3108unsafe fn squared_euclidean_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
3109 use std::arch::aarch64::*;
3110
3111 let chunks = count / 4;
3112 let remainder = count % 4;
3113
3114 let mut acc = vdupq_n_f32(0.0);
3115
3116 for chunk in 0..chunks {
3117 let base = chunk * 4;
3118 let va = vld1q_f32(a.as_ptr().add(base));
3119 let vb = vld1q_f32(b.as_ptr().add(base));
3120 let diff = vsubq_f32(va, vb);
        acc = vfmaq_f32(acc, diff, diff);
    }
3123
3124 let mut sum = vaddvq_f32(acc);
3126
3127 let base = chunks * 4;
3129 for i in 0..remainder {
3130 let d = a[base + i] - b[base + i];
3131 sum += d * d;
3132 }
3133
3134 sum
3135}
3136
3137#[cfg(target_arch = "x86_64")]
3138#[target_feature(enable = "sse")]
3139#[allow(unsafe_op_in_unsafe_fn)]
3140unsafe fn squared_euclidean_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
3141 use std::arch::x86_64::*;
3142
3143 let chunks = count / 4;
3144 let remainder = count % 4;
3145
3146 let mut acc = _mm_setzero_ps();
3147
3148 for chunk in 0..chunks {
3149 let base = chunk * 4;
3150 let va = _mm_loadu_ps(a.as_ptr().add(base));
3151 let vb = _mm_loadu_ps(b.as_ptr().add(base));
3152 let diff = _mm_sub_ps(va, vb);
3153 acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
3154 }
3155
    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
    let sums = _mm_add_ps(acc, shuf);
    let shuf2 = _mm_movehl_ps(sums, sums);
    let final_sum = _mm_add_ss(sums, shuf2);
    let mut sum = _mm_cvtss_f32(final_sum);
3163
3164 let base = chunks * 4;
3166 for i in 0..remainder {
3167 let d = a[base + i] - b[base + i];
3168 sum += d * d;
3169 }
3170
3171 sum
3172}
3173
3174#[cfg(target_arch = "x86_64")]
3175#[target_feature(enable = "avx2")]
3176#[allow(unsafe_op_in_unsafe_fn)]
3177unsafe fn squared_euclidean_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
3178 use std::arch::x86_64::*;
3179
3180 let chunks = count / 8;
3181 let remainder = count % 8;
3182
3183 let mut acc = _mm256_setzero_ps();
3184
3185 for chunk in 0..chunks {
3186 let base = chunk * 8;
3187 let va = _mm256_loadu_ps(a.as_ptr().add(base));
3188 let vb = _mm256_loadu_ps(b.as_ptr().add(base));
3189 let diff = _mm256_sub_ps(va, vb);
        acc = _mm256_fmadd_ps(diff, diff, acc);
    }
3192
3193 let high = _mm256_extractf128_ps(acc, 1);
3196 let low = _mm256_castps256_ps128(acc);
3197 let sum128 = _mm_add_ps(low, high);
3198
3199 let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
3201 let sums = _mm_add_ps(sum128, shuf);
3202 let shuf2 = _mm_movehl_ps(sums, sums);
3203 let final_sum = _mm_add_ss(sums, shuf2);
3204
3205 let mut sum = _mm_cvtss_f32(final_sum);
3206
3207 let base = chunks * 8;
3209 for i in 0..remainder {
3210 let d = a[base + i] - b[base + i];
3211 sum += d * d;
3212 }
3213
3214 sum
3215}
3216
3217#[inline]
3223pub fn batch_squared_euclidean_distances(
3224 query: &[f32],
3225 vectors: &[Vec<f32>],
3226 distances: &mut [f32],
3227) {
3228 debug_assert_eq!(vectors.len(), distances.len());
3229
3230 #[cfg(target_arch = "x86_64")]
3231 {
3232 if avx2::is_available() {
3233 for (i, vec) in vectors.iter().enumerate() {
3234 distances[i] = unsafe { squared_euclidean_avx2(query, vec, query.len()) };
3235 }
3236 return;
3237 }
3238 }
3239
3240 for (i, vec) in vectors.iter().enumerate() {
3242 distances[i] = squared_euclidean_distance(query, vec);
3243 }
3244}
3245
3246#[cfg(test)]
3247mod tests {
3248 use super::*;
3249
3250 #[test]
3251 fn test_unpack_8bit() {
3252 let input: Vec<u8> = (0..128).collect();
3253 let mut output = vec![0u32; 128];
3254 unpack_8bit(&input, &mut output, 128);
3255
3256 for (i, &v) in output.iter().enumerate() {
3257 assert_eq!(v, i as u32);
3258 }
3259 }
3260
3261 #[test]
3262 fn test_unpack_16bit() {
3263 let mut input = vec![0u8; 256];
3264 for i in 0..128 {
3265 let val = (i * 100) as u16;
3266 input[i * 2] = val as u8;
3267 input[i * 2 + 1] = (val >> 8) as u8;
3268 }
3269
3270 let mut output = vec![0u32; 128];
3271 unpack_16bit(&input, &mut output, 128);
3272
3273 for (i, &v) in output.iter().enumerate() {
3274 assert_eq!(v, (i * 100) as u32);
3275 }
3276 }
3277
3278 #[test]
3279 fn test_unpack_32bit() {
3280 let mut input = vec![0u8; 512];
3281 for i in 0..128 {
3282 let val = (i * 1000) as u32;
3283 let bytes = val.to_le_bytes();
3284 input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
3285 }
3286
3287 let mut output = vec![0u32; 128];
3288 unpack_32bit(&input, &mut output, 128);
3289
3290 for (i, &v) in output.iter().enumerate() {
3291 assert_eq!(v, (i * 1000) as u32);
3292 }
3293 }
3294
3295 #[test]
3296 fn test_delta_decode() {
3297 let deltas = vec![4u32, 4, 9, 19];
3301 let mut output = vec![0u32; 5];
3302
3303 delta_decode(&mut output, &deltas, 10, 5);
3304
3305 assert_eq!(output, vec![10, 15, 20, 30, 50]);
3306 }
3307
3308 #[test]
3309 fn test_add_one() {
3310 let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
3311 add_one(&mut values, 8);
3312
3313 assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
3314 }
3315
3316 #[test]
3317 fn test_bits_needed() {
3318 assert_eq!(bits_needed(0), 0);
3319 assert_eq!(bits_needed(1), 1);
3320 assert_eq!(bits_needed(2), 2);
3321 assert_eq!(bits_needed(3), 2);
3322 assert_eq!(bits_needed(4), 3);
3323 assert_eq!(bits_needed(255), 8);
3324 assert_eq!(bits_needed(256), 9);
3325 assert_eq!(bits_needed(u32::MAX), 32);
3326 }
3327
3328 #[test]
3329 fn test_unpack_8bit_delta_decode() {
3330 let input: Vec<u8> = vec![4, 4, 9, 19];
3334 let mut output = vec![0u32; 5];
3335
3336 unpack_8bit_delta_decode(&input, &mut output, 10, 5);
3337
3338 assert_eq!(output, vec![10, 15, 20, 30, 50]);
3339 }
3340
3341 #[test]
3342 fn test_unpack_16bit_delta_decode() {
3343 let mut input = vec![0u8; 8];
3347 for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
3348 input[i * 2] = delta as u8;
3349 input[i * 2 + 1] = (delta >> 8) as u8;
3350 }
3351 let mut output = vec![0u32; 5];
3352
3353 unpack_16bit_delta_decode(&input, &mut output, 100, 5);
3354
3355 assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
3356 }
3357
3358 #[test]
3359 fn test_fused_vs_separate_8bit() {
3360 let input: Vec<u8> = (0..127).collect();
3362 let first_value = 1000u32;
3363 let count = 128;
3364
3365 let mut unpacked = vec![0u32; 128];
3367 unpack_8bit(&input, &mut unpacked, 127);
3368 let mut separate_output = vec![0u32; 128];
3369 delta_decode(&mut separate_output, &unpacked, first_value, count);
3370
3371 let mut fused_output = vec![0u32; 128];
3373 unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
3374
3375 assert_eq!(separate_output, fused_output);
3376 }
3377
3378 #[test]
3379 fn test_round_bit_width() {
3380 assert_eq!(round_bit_width(0), 0);
3381 assert_eq!(round_bit_width(1), 8);
3382 assert_eq!(round_bit_width(5), 8);
3383 assert_eq!(round_bit_width(8), 8);
3384 assert_eq!(round_bit_width(9), 16);
3385 assert_eq!(round_bit_width(12), 16);
3386 assert_eq!(round_bit_width(16), 16);
3387 assert_eq!(round_bit_width(17), 32);
3388 assert_eq!(round_bit_width(24), 32);
3389 assert_eq!(round_bit_width(32), 32);
3390 }
3391
3392 #[test]
3393 fn test_rounded_bitwidth_from_exact() {
3394 assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
3395 assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
3396 assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
3397 assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
3398 assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
3399 assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
3400 assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
3401 }
3402
3403 #[test]
3404 fn test_pack_unpack_rounded_8bit() {
3405 let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
3406 let mut packed = vec![0u8; 128];
3407
3408 let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
3409 assert_eq!(bytes_written, 128);
3410
3411 let mut unpacked = vec![0u32; 128];
3412 unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);
3413
3414 assert_eq!(values, unpacked);
3415 }
3416
3417 #[test]
3418 fn test_pack_unpack_rounded_16bit() {
3419 let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
3420 let mut packed = vec![0u8; 256];
3421
3422 let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
3423 assert_eq!(bytes_written, 256);
3424
3425 let mut unpacked = vec![0u32; 128];
3426 unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);
3427
3428 assert_eq!(values, unpacked);
3429 }
3430
3431 #[test]
3432 fn test_pack_unpack_rounded_32bit() {
3433 let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
3434 let mut packed = vec![0u8; 512];
3435
3436 let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
3437 assert_eq!(bytes_written, 512);
3438
3439 let mut unpacked = vec![0u32; 128];
3440 unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);
3441
3442 assert_eq!(values, unpacked);
3443 }
3444
3445 #[test]
3446 fn test_unpack_rounded_delta_decode() {
3447 let input: Vec<u8> = vec![4, 4, 9, 19];
3452 let mut output = vec![0u32; 5];
3453
3454 unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);
3455
3456 assert_eq!(output, vec![10, 15, 20, 30, 50]);
3457 }
3458
3459 #[test]
3460 fn test_unpack_rounded_delta_decode_zero() {
3461 let input: Vec<u8> = vec![];
3463 let mut output = vec![0u32; 5];
3464
3465 unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);
3466
3467 assert_eq!(output, vec![100, 101, 102, 103, 104]);
3468 }
3469
3470 #[test]
3475 fn test_dequantize_uint8() {
3476 let input: Vec<u8> = vec![0, 128, 255, 64, 192];
3477 let mut output = vec![0.0f32; 5];
3478 let scale = 0.1;
3479 let min_val = 1.0;
3480
3481 dequantize_uint8(&input, &mut output, scale, min_val, 5);
3482
        assert!((output[0] - 1.0).abs() < 1e-6); // 0 * 0.1 + 1.0
        assert!((output[1] - 13.8).abs() < 1e-6); // 128 * 0.1 + 1.0
        assert!((output[2] - 26.5).abs() < 1e-6); // 255 * 0.1 + 1.0
        assert!((output[3] - 7.4).abs() < 1e-6); // 64 * 0.1 + 1.0
        assert!((output[4] - 20.2).abs() < 1e-6); // 192 * 0.1 + 1.0
    }
3490
3491 #[test]
3492 fn test_dequantize_uint8_large() {
3493 let input: Vec<u8> = (0..128).collect();
3495 let mut output = vec![0.0f32; 128];
3496 let scale = 2.0;
3497 let min_val = -10.0;
3498
3499 dequantize_uint8(&input, &mut output, scale, min_val, 128);
3500
3501 for (i, &out) in output.iter().enumerate().take(128) {
3502 let expected = i as f32 * scale + min_val;
3503 assert!(
3504 (out - expected).abs() < 1e-5,
3505 "Mismatch at {}: expected {}, got {}",
3506 i,
3507 expected,
3508 out
3509 );
3510 }
3511 }
3512
3513 #[test]
3514 fn test_dot_product_f32() {
3515 let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
3516 let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
3517
3518 let result = dot_product_f32(&a, &b, 5);
3519
3520 assert!((result - 70.0).abs() < 1e-5);
3522 }
3523
3524 #[test]
3525 fn test_dot_product_f32_large() {
3526 let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
3528 let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
3529
3530 let result = dot_product_f32(&a, &b, 128);
3531
3532 let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
3534 assert!(
3535 (result - expected).abs() < 1e-3,
3536 "Expected {}, got {}",
3537 expected,
3538 result
3539 );
3540 }
3541
3542 #[test]
3543 fn test_max_f32() {
3544 let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
3545 let result = max_f32(&values, 6);
3546 assert!((result - 9.0).abs() < 1e-6);
3547 }
3548
3549 #[test]
3550 fn test_max_f32_large() {
3551 let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
3553 values[77] = 1000.0;
3554
3555 let result = max_f32(&values, 128);
3556 assert!((result - 1000.0).abs() < 1e-5);
3557 }
3558
3559 #[test]
3560 fn test_max_f32_negative() {
3561 let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
3562 let result = max_f32(&values, 5);
3563 assert!((result - (-1.0)).abs() < 1e-6);
3564 }
3565
3566 #[test]
3567 fn test_max_f32_empty() {
3568 let values: Vec<f32> = vec![];
3569 let result = max_f32(&values, 0);
3570 assert_eq!(result, f32::NEG_INFINITY);
3571 }
3572
3573 #[test]
3574 fn test_fused_dot_norm() {
3575 let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3576 let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3577 let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3578
3579 let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3580 let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3581 assert!(
3582 (dot - expected_dot).abs() < 1e-5,
3583 "dot: expected {}, got {}",
3584 expected_dot,
3585 dot
3586 );
3587 assert!(
3588 (norm_b - expected_norm).abs() < 1e-5,
3589 "norm: expected {}, got {}",
3590 expected_norm,
3591 norm_b
3592 );
3593 }
3594
3595 #[test]
3596 fn test_fused_dot_norm_large() {
3597 let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
3598 let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect();
3599 let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3600
3601 let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3602 let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3603 assert!(
3604 (dot - expected_dot).abs() < 1.0,
3605 "dot: expected {}, got {}",
3606 expected_dot,
3607 dot
3608 );
3609 assert!(
3610 (norm_b - expected_norm).abs() < 1.0,
3611 "norm: expected {}, got {}",
3612 expected_norm,
3613 norm_b
3614 );
3615 }
3616
3617 #[test]
3618 fn test_batch_cosine_scores() {
3619 let query = vec![1.0f32, 0.0, 0.0];
3621 let vectors = vec![
            1.0, 0.0, 0.0, // identical to the query
            0.0, 1.0, 0.0, // orthogonal
            -1.0, 0.0, 0.0, // opposite
            0.5, 0.5, 0.0, // 45 degrees from the query
        ];
3627 let mut scores = vec![0f32; 4];
3628 batch_cosine_scores(&query, &vectors, 3, &mut scores);
3629
3630 assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
3631 assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
3632 assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
3633 let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
3634 assert!(
3635 (scores[3] - expected_45).abs() < 1e-5,
3636 "45deg: expected {}, got {}",
3637 expected_45,
3638 scores[3]
3639 );
3640 }
3641
3642 #[test]
3643 fn test_batch_cosine_scores_matches_individual() {
3644 let query: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
3645 let n = 50;
3646 let dim = 128;
3647 let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 7 + 3) as f32) * 0.01).collect();
3648
3649 let mut batch_scores = vec![0f32; n];
3650 batch_cosine_scores(&query, &vectors, dim, &mut batch_scores);
3651
3652 for i in 0..n {
3653 let vec_i = &vectors[i * dim..(i + 1) * dim];
3654 let individual = cosine_similarity(&query, vec_i);
3655 assert!(
3656 (batch_scores[i] - individual).abs() < 1e-5,
3657 "vec {}: batch={}, individual={}",
3658 i,
3659 batch_scores[i],
3660 individual
3661 );
3662 }
3663 }
3664
3665 #[test]
3666 fn test_batch_cosine_scores_empty() {
3667 let query = vec![1.0f32, 2.0, 3.0];
3668 let vectors: Vec<f32> = vec![];
3669 let mut scores: Vec<f32> = vec![];
3670 batch_cosine_scores(&query, &vectors, 3, &mut scores);
3671 assert!(scores.is_empty());
3672 }
3673
3674 #[test]
3675 fn test_batch_cosine_scores_zero_query() {
3676 let query = vec![0.0f32, 0.0, 0.0];
3677 let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
3678 let mut scores = vec![0f32; 2];
3679 batch_cosine_scores(&query, &vectors, 3, &mut scores);
3680 assert_eq!(scores[0], 0.0);
3681 assert_eq!(scores[1], 0.0);
3682 }
3683
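    // A sketch test added here (not part of the original suite): `batch_dot_scores`
    // divides each raw dot product by the query norm only, so a vector pointing along
    // the query scores its query-normalized dot product and an orthogonal vector
    // scores zero.
    #[test]
    fn test_batch_dot_scores_basic() {
        let query = vec![2.0f32, 0.0, 0.0];
        let vectors = vec![
            1.0, 0.0, 0.0, // along the query
            0.0, 3.0, 0.0, // orthogonal to the query
        ];
        let mut scores = vec![0f32; 2];
        batch_dot_scores(&query, &vectors, 3, &mut scores);

        assert!((scores[0] - 1.0).abs() < 1e-5, "along query: {}", scores[0]);
        assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
    }
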
3684 #[test]
3685 fn test_squared_euclidean_distance() {
3686 let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3687 let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3688 let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3689 let result = squared_euclidean_distance(&a, &b);
3690 assert!(
3691 (result - expected).abs() < 1e-5,
3692 "expected {}, got {}",
3693 expected,
3694 result
3695 );
3696 }
3697
3698 #[test]
3699 fn test_squared_euclidean_distance_large() {
3700 let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
3701 let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
3702 let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3703 let result = squared_euclidean_distance(&a, &b);
3704 assert!(
3705 (result - expected).abs() < 1e-3,
3706 "expected {}, got {}",
3707 expected,
3708 result
3709 );
3710 }
3711
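    // A sketch test added here (not part of the original suite): the batch Euclidean
    // helper must agree with `squared_euclidean_distance` applied row by row,
    // regardless of which SIMD path the dispatcher picks.
    #[test]
    fn test_batch_squared_euclidean_distances_matches_individual() {
        let query: Vec<f32> = (0..32).map(|i| i as f32 * 0.25).collect();
        let vectors: Vec<Vec<f32>> = (0..4)
            .map(|k| (0..32).map(|i| (i + k) as f32 * 0.5 - 1.0).collect())
            .collect();

        let mut distances = vec![0f32; 4];
        batch_squared_euclidean_distances(&query, &vectors, &mut distances);

        for (k, vec) in vectors.iter().enumerate() {
            let expected = squared_euclidean_distance(&query, vec);
            assert!(
                (distances[k] - expected).abs() < 1e-4,
                "vec {}: batch={}, individual={}",
                k,
                distances[k],
                expected
            );
        }
    }
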
3712 #[test]
3717 fn test_f16_roundtrip_normal() {
3718 for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 0.333, 65504.0] {
3719 let h = f32_to_f16(v);
3720 let back = f16_to_f32(h);
3721 let err = (back - v).abs() / v.abs().max(1e-6);
3722 assert!(
3723 err < 0.002,
3724 "f16 roundtrip {v} → {h:#06x} → {back}, rel err {err}"
3725 );
3726 }
3727 }
3728
3729 #[test]
3730 fn test_f16_special() {
3731 assert_eq!(f16_to_f32(f32_to_f16(0.0)), 0.0);
3733 assert_eq!(f32_to_f16(-0.0), 0x8000);
3735 assert!(f16_to_f32(f32_to_f16(f32::INFINITY)).is_infinite());
3737 assert!(f16_to_f32(f32_to_f16(f32::NAN)).is_nan());
3739 }
3740
3741 #[test]
3742 fn test_f16_embedding_range() {
3743 let values: Vec<f32> = (-100..=100).map(|i| i as f32 / 100.0).collect();
3745 for &v in &values {
3746 let back = f16_to_f32(f32_to_f16(v));
3747 assert!((back - v).abs() < 0.001, "f16 error for {v}: got {back}");
3748 }
3749 }
3750
3751 #[test]
3756 fn test_u8_roundtrip() {
3757 assert_eq!(f32_to_u8_saturating(-1.0), 0);
3759 assert_eq!(f32_to_u8_saturating(1.0), 255);
        assert_eq!(f32_to_u8_saturating(0.0), 127); // 0.0 maps to the middle of the byte range
        assert_eq!(f32_to_u8_saturating(-2.0), 0); // values below -1.0 saturate to 0
3764 assert_eq!(f32_to_u8_saturating(2.0), 255);
3765 }
3766
3767 #[test]
3768 fn test_u8_dequantize() {
3769 assert!((u8_to_f32(0) - (-1.0)).abs() < 0.01);
3770 assert!((u8_to_f32(255) - 1.0).abs() < 0.01);
3771 assert!((u8_to_f32(127) - 0.0).abs() < 0.01);
3772 }
3773
3774 #[test]
3779 fn test_batch_cosine_scores_f16() {
3780 let query = vec![0.6f32, 0.8, 0.0, 0.0];
3781 let dim = 4;
3782 let vecs_f32 = vec![
            0.6f32, 0.8, 0.0, 0.0, // identical to the query
            0.0, 0.0, 0.6, 0.8, // orthogonal
        ];
3786
3787 let mut f16_buf = vec![0u16; 8];
3789 batch_f32_to_f16(&vecs_f32, &mut f16_buf);
3790 let raw: &[u8] =
3791 unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
3792
3793 let mut scores = vec![0f32; 2];
3794 batch_cosine_scores_f16(&query, raw, dim, &mut scores);
3795
3796 assert!(
3797 (scores[0] - 1.0).abs() < 0.01,
3798 "identical vectors: {}",
3799 scores[0]
3800 );
3801 assert!(scores[1].abs() < 0.01, "orthogonal vectors: {}", scores[1]);
3802 }
3803
3804 #[test]
3805 fn test_batch_cosine_scores_u8() {
3806 let query = vec![0.6f32, 0.8, 0.0, 0.0];
3807 let dim = 4;
3808 let vecs_f32 = vec![
            0.6f32, 0.8, 0.0, 0.0, // same direction as the query
            -0.6, -0.8, 0.0, 0.0, // opposite direction
        ];
3812
3813 let mut u8_buf = vec![0u8; 8];
3815 batch_f32_to_u8(&vecs_f32, &mut u8_buf);
3816
3817 let mut scores = vec![0f32; 2];
3818 batch_cosine_scores_u8(&query, &u8_buf, dim, &mut scores);
3819
3820 assert!(scores[0] > 0.95, "similar vectors: {}", scores[0]);
3821 assert!(scores[1] < -0.95, "opposite vectors: {}", scores[1]);
3822 }
3823
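    // A sketch test added here (not part of the original suite): dot scoring against
    // u8-quantized vectors, whose bytes dequantize linearly into [-1, 1].
    #[test]
    fn test_batch_dot_scores_u8() {
        let query = vec![1.0f32, 0.0, 0.0, 0.0];
        let dim = 4;
        let vecs_f32 = vec![
            1.0f32, 0.0, 0.0, 0.0, // same direction as the query
            -1.0, 0.0, 0.0, 0.0, // opposite direction
        ];

        let mut u8_buf = vec![0u8; 8];
        batch_f32_to_u8(&vecs_f32, &mut u8_buf);

        let mut scores = vec![0f32; 2];
        batch_dot_scores_u8(&query, &u8_buf, dim, &mut scores);

        assert!(scores[0] > 0.95, "same direction: {}", scores[0]);
        assert!(scores[1] < -0.95, "opposite direction: {}", scores[1]);
    }
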
3824 #[test]
3825 fn test_batch_cosine_scores_f16_large_dim() {
3826 let dim = 768;
3828 let query: Vec<f32> = (0..dim).map(|i| (i as f32 / dim as f32) - 0.5).collect();
3829 let vec2: Vec<f32> = query.iter().map(|x| x * 0.9 + 0.01).collect();
3830
3831 let mut all_vecs = query.clone();
3832 all_vecs.extend_from_slice(&vec2);
3833
3834 let mut f16_buf = vec![0u16; all_vecs.len()];
3835 batch_f32_to_f16(&all_vecs, &mut f16_buf);
3836 let raw: &[u8] =
3837 unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
3838
3839 let mut scores = vec![0f32; 2];
3840 batch_cosine_scores_f16(&query, raw, dim, &mut scores);
3841
3842 assert!((scores[0] - 1.0).abs() < 0.01, "self-sim: {}", scores[0]);
3844 assert!(scores[1] > 0.99, "scaled-sim: {}", scores[1]);
3846 }
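
    // A sketch test added here (not part of the original suite): precomputing
    // 1/||query|| once and calling the `_precomp` variant should reproduce the scores
    // that `batch_cosine_scores` derives internally.
    #[test]
    fn test_batch_cosine_scores_precomp_matches() {
        let dim = 64;
        let n = 8;
        let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.03 - 0.5).collect();
        let vectors: Vec<f32> = (0..dim * n).map(|i| ((i % 97) as f32) * 0.02 - 0.7).collect();

        let mut expected = vec![0f32; n];
        batch_cosine_scores(&query, &vectors, dim, &mut expected);

        let inv_norm_q = fast_inv_sqrt(dot_product_f32(&query, &query, dim));
        let mut actual = vec![0f32; n];
        batch_cosine_scores_precomp(&query, &vectors, dim, &mut actual, inv_norm_q);

        for i in 0..n {
            assert!(
                (actual[i] - expected[i]).abs() < 1e-4,
                "vec {}: precomp={}, direct={}",
                i,
                actual[i],
                expected[i]
            );
        }
    }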
3847}
3848
3849#[inline]
3862pub fn find_first_ge_u32(slice: &[u32], target: u32) -> usize {
3863 #[cfg(target_arch = "aarch64")]
3864 {
3865 if neon::is_available() {
3866 return unsafe { find_first_ge_u32_neon(slice, target) };
3867 }
3868 }
3869
3870 #[cfg(target_arch = "x86_64")]
3871 {
3872 if sse::is_available() {
3873 return unsafe { find_first_ge_u32_sse(slice, target) };
3874 }
3875 }
3876
3877 slice.partition_point(|&d| d < target)
3879}
3880
3881#[cfg(target_arch = "aarch64")]
3882#[target_feature(enable = "neon")]
3883#[allow(unsafe_op_in_unsafe_fn)]
3884unsafe fn find_first_ge_u32_neon(slice: &[u32], target: u32) -> usize {
3885 use std::arch::aarch64::*;
3886
3887 let n = slice.len();
3888 let ptr = slice.as_ptr();
3889 let target_vec = vdupq_n_u32(target);
3890 let bit_mask: uint32x4_t = core::mem::transmute([1u32, 2u32, 4u32, 8u32]);
3892
3893 let chunks = n / 16;
3894 let mut base = 0usize;
3895
3896 for _ in 0..chunks {
3898 let v0 = vld1q_u32(ptr.add(base));
3899 let v1 = vld1q_u32(ptr.add(base + 4));
3900 let v2 = vld1q_u32(ptr.add(base + 8));
3901 let v3 = vld1q_u32(ptr.add(base + 12));
3902
3903 let c0 = vcgeq_u32(v0, target_vec);
3904 let c1 = vcgeq_u32(v1, target_vec);
3905 let c2 = vcgeq_u32(v2, target_vec);
3906 let c3 = vcgeq_u32(v3, target_vec);
3907
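            // Each comparison lane is all-ones or all-zero; masking with {1, 2, 4, 8}
            // and horizontally adding packs the four lanes into a 4-bit mask, so
            // trailing_zeros() yields the first matching lane.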
3908 let m0 = vaddvq_u32(vandq_u32(c0, bit_mask));
3909 if m0 != 0 {
3910 return base + m0.trailing_zeros() as usize;
3911 }
3912 let m1 = vaddvq_u32(vandq_u32(c1, bit_mask));
3913 if m1 != 0 {
3914 return base + 4 + m1.trailing_zeros() as usize;
3915 }
3916 let m2 = vaddvq_u32(vandq_u32(c2, bit_mask));
3917 if m2 != 0 {
3918 return base + 8 + m2.trailing_zeros() as usize;
3919 }
3920 let m3 = vaddvq_u32(vandq_u32(c3, bit_mask));
3921 if m3 != 0 {
3922 return base + 12 + m3.trailing_zeros() as usize;
3923 }
3924 base += 16;
3925 }
3926
3927 while base + 4 <= n {
3929 let vals = vld1q_u32(ptr.add(base));
3930 let cmp = vcgeq_u32(vals, target_vec);
3931 let mask = vaddvq_u32(vandq_u32(cmp, bit_mask));
3932 if mask != 0 {
3933 return base + mask.trailing_zeros() as usize;
3934 }
3935 base += 4;
3936 }
3937
3938 while base < n {
3940 if *slice.get_unchecked(base) >= target {
3941 return base;
3942 }
3943 base += 1;
3944 }
3945 n
3946}
3947
3948#[cfg(target_arch = "x86_64")]
3949#[target_feature(enable = "sse2", enable = "sse4.1")]
3950#[allow(unsafe_op_in_unsafe_fn)]
3951unsafe fn find_first_ge_u32_sse(slice: &[u32], target: u32) -> usize {
3952 use std::arch::x86_64::*;
3953
3954 let n = slice.len();
3955 let ptr = slice.as_ptr();
3956
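    // Unsigned comparison via sign-bit flip: XOR both operands with i32::MIN so that
    // the signed eq/gt comparisons below reproduce unsigned `>=`.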
3957 let sign_flip = _mm_set1_epi32(i32::MIN);
3959 let target_xor = _mm_xor_si128(_mm_set1_epi32(target as i32), sign_flip);
3960
3961 let chunks = n / 16;
3962 let mut base = 0usize;
3963
3964 for _ in 0..chunks {
3966 let v0 = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
3967 let v1 = _mm_xor_si128(
3968 _mm_loadu_si128(ptr.add(base + 4) as *const __m128i),
3969 sign_flip,
3970 );
3971 let v2 = _mm_xor_si128(
3972 _mm_loadu_si128(ptr.add(base + 8) as *const __m128i),
3973 sign_flip,
3974 );
3975 let v3 = _mm_xor_si128(
3976 _mm_loadu_si128(ptr.add(base + 12) as *const __m128i),
3977 sign_flip,
3978 );
3979
3980 let ge0 = _mm_or_si128(
3982 _mm_cmpeq_epi32(v0, target_xor),
3983 _mm_cmpgt_epi32(v0, target_xor),
3984 );
3985 let m0 = _mm_movemask_ps(_mm_castsi128_ps(ge0)) as u32;
3986 if m0 != 0 {
3987 return base + m0.trailing_zeros() as usize;
3988 }
3989
3990 let ge1 = _mm_or_si128(
3991 _mm_cmpeq_epi32(v1, target_xor),
3992 _mm_cmpgt_epi32(v1, target_xor),
3993 );
3994 let m1 = _mm_movemask_ps(_mm_castsi128_ps(ge1)) as u32;
3995 if m1 != 0 {
3996 return base + 4 + m1.trailing_zeros() as usize;
3997 }
3998
3999 let ge2 = _mm_or_si128(
4000 _mm_cmpeq_epi32(v2, target_xor),
4001 _mm_cmpgt_epi32(v2, target_xor),
4002 );
4003 let m2 = _mm_movemask_ps(_mm_castsi128_ps(ge2)) as u32;
4004 if m2 != 0 {
4005 return base + 8 + m2.trailing_zeros() as usize;
4006 }
4007
4008 let ge3 = _mm_or_si128(
4009 _mm_cmpeq_epi32(v3, target_xor),
4010 _mm_cmpgt_epi32(v3, target_xor),
4011 );
4012 let m3 = _mm_movemask_ps(_mm_castsi128_ps(ge3)) as u32;
4013 if m3 != 0 {
4014 return base + 12 + m3.trailing_zeros() as usize;
4015 }
4016 base += 16;
4017 }
4018
4019 while base + 4 <= n {
4021 let vals = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
4022 let ge = _mm_or_si128(
4023 _mm_cmpeq_epi32(vals, target_xor),
4024 _mm_cmpgt_epi32(vals, target_xor),
4025 );
4026 let mask = _mm_movemask_ps(_mm_castsi128_ps(ge)) as u32;
4027 if mask != 0 {
4028 return base + mask.trailing_zeros() as usize;
4029 }
4030 base += 4;
4031 }
4032
4033 while base < n {
4035 if *slice.get_unchecked(base) >= target {
4036 return base;
4037 }
4038 base += 1;
4039 }
4040 n
4041}
4042
4043#[cfg(test)]
4044mod find_first_ge_tests {
4045 use super::find_first_ge_u32;
4046
4047 #[test]
4048 fn test_find_first_ge_basic() {
        let data: Vec<u32> = (0..128).map(|i| i * 3).collect(); // 0, 3, 6, ..., 381
        assert_eq!(find_first_ge_u32(&data, 0), 0);
        assert_eq!(find_first_ge_u32(&data, 1), 1);
        assert_eq!(find_first_ge_u32(&data, 3), 1);
        assert_eq!(find_first_ge_u32(&data, 4), 2);
        assert_eq!(find_first_ge_u32(&data, 381), 127);
        assert_eq!(find_first_ge_u32(&data, 382), 128); // past the last element
    }
4057
4058 #[test]
4059 fn test_find_first_ge_matches_partition_point() {
4060 let data: Vec<u32> = vec![1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75];
4061 for target in 0..80 {
4062 let expected = data.partition_point(|&d| d < target);
4063 let actual = find_first_ge_u32(&data, target);
4064 assert_eq!(actual, expected, "target={}", target);
4065 }
4066 }
4067
4068 #[test]
4069 fn test_find_first_ge_small_slices() {
4070 assert_eq!(find_first_ge_u32(&[], 5), 0);
4072 assert_eq!(find_first_ge_u32(&[10], 5), 0);
4074 assert_eq!(find_first_ge_u32(&[10], 10), 0);
4075 assert_eq!(find_first_ge_u32(&[10], 11), 1);
4076 assert_eq!(find_first_ge_u32(&[2, 4, 6], 5), 2);
4078 }
4079
4080 #[test]
4081 fn test_find_first_ge_full_block() {
4082 let data: Vec<u32> = (100..228).collect();
4084 assert_eq!(find_first_ge_u32(&data, 100), 0);
4085 assert_eq!(find_first_ge_u32(&data, 150), 50);
4086 assert_eq!(find_first_ge_u32(&data, 227), 127);
4087 assert_eq!(find_first_ge_u32(&data, 228), 128);
4088 assert_eq!(find_first_ge_u32(&data, 99), 0);
4089 }
4090
4091 #[test]
4092 fn test_find_first_ge_u32_max() {
4093 let data = vec![u32::MAX - 10, u32::MAX - 5, u32::MAX - 1, u32::MAX];
4095 assert_eq!(find_first_ge_u32(&data, u32::MAX - 10), 0);
4096 assert_eq!(find_first_ge_u32(&data, u32::MAX - 7), 1);
4097 assert_eq!(find_first_ge_u32(&data, u32::MAX), 3);
4098 }
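
    // A sketch test added here (not part of the original suite): slice lengths chosen
    // so the SIMD implementations exercise the 16-wide unrolled loop, the 4-wide loop,
    // and the scalar tail, always matching `partition_point`.
    #[test]
    fn test_find_first_ge_covers_all_tail_paths() {
        for &len in &[1usize, 3, 4, 5, 15, 16, 17, 23, 64, 100] {
            let data: Vec<u32> = (0..len as u32).map(|i| i * 7 + 3).collect();
            for target in 0..(len as u32 * 7 + 10) {
                let expected = data.partition_point(|&d| d < target);
                assert_eq!(
                    find_first_ge_u32(&data, target),
                    expected,
                    "len={}, target={}",
                    len,
                    target
                );
            }
        }
    }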
4099}