1#[cfg(target_arch = "aarch64")]
18#[allow(unsafe_op_in_unsafe_fn)]
19mod neon {
20 use std::arch::aarch64::*;
21
22 #[target_feature(enable = "neon")]
24 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
25 let chunks = count / 16;
26 let remainder = count % 16;
27
28 for chunk in 0..chunks {
29 let base = chunk * 16;
30 let in_ptr = input.as_ptr().add(base);
31
32 let bytes = vld1q_u8(in_ptr);
34
35 let low8 = vget_low_u8(bytes);
37 let high8 = vget_high_u8(bytes);
38
39 let low16 = vmovl_u8(low8);
40 let high16 = vmovl_u8(high8);
41
42 let v0 = vmovl_u16(vget_low_u16(low16));
43 let v1 = vmovl_u16(vget_high_u16(low16));
44 let v2 = vmovl_u16(vget_low_u16(high16));
45 let v3 = vmovl_u16(vget_high_u16(high16));
46
47 let out_ptr = output.as_mut_ptr().add(base);
48 vst1q_u32(out_ptr, v0);
49 vst1q_u32(out_ptr.add(4), v1);
50 vst1q_u32(out_ptr.add(8), v2);
51 vst1q_u32(out_ptr.add(12), v3);
52 }
53
54 let base = chunks * 16;
56 for i in 0..remainder {
57 output[base + i] = input[base + i] as u32;
58 }
59 }
60
61 #[target_feature(enable = "neon")]
63 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
64 let chunks = count / 8;
65 let remainder = count % 8;
66
67 for chunk in 0..chunks {
68 let base = chunk * 8;
69 let in_ptr = input.as_ptr().add(base * 2) as *const u16;
70
71 let vals = vld1q_u16(in_ptr);
72 let low = vmovl_u16(vget_low_u16(vals));
73 let high = vmovl_u16(vget_high_u16(vals));
74
75 let out_ptr = output.as_mut_ptr().add(base);
76 vst1q_u32(out_ptr, low);
77 vst1q_u32(out_ptr.add(4), high);
78 }
79
80 let base = chunks * 8;
82 for i in 0..remainder {
83 let idx = (base + i) * 2;
84 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
85 }
86 }
87
88 #[target_feature(enable = "neon")]
90 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
91 let chunks = count / 4;
92 let remainder = count % 4;
93
94 let in_ptr = input.as_ptr() as *const u32;
95 let out_ptr = output.as_mut_ptr();
96
97 for chunk in 0..chunks {
98 let vals = vld1q_u32(in_ptr.add(chunk * 4));
99 vst1q_u32(out_ptr.add(chunk * 4), vals);
100 }
101
102 let base = chunks * 4;
104 for i in 0..remainder {
105 let idx = (base + i) * 4;
106 output[base + i] =
107 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
108 }
109 }
110
111 #[inline]
115 #[target_feature(enable = "neon")]
116 unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
117 let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
120 let sum1 = vaddq_u32(v, shifted1);
121
122 let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
125 vaddq_u32(sum1, shifted2)
126 }
127
128 #[target_feature(enable = "neon")]
132 pub unsafe fn delta_decode(
133 output: &mut [u32],
134 deltas: &[u32],
135 first_doc_id: u32,
136 count: usize,
137 ) {
138 if count == 0 {
139 return;
140 }
141
142 output[0] = first_doc_id;
143 if count == 1 {
144 return;
145 }
146
147 let ones = vdupq_n_u32(1);
148 let mut carry = vdupq_n_u32(first_doc_id);
149
150 let full_groups = (count - 1) / 4;
151 let remainder = (count - 1) % 4;
152
153 for group in 0..full_groups {
154 let base = group * 4;
155
156 let d = vld1q_u32(deltas[base..].as_ptr());
158 let gaps = vaddq_u32(d, ones);
159
160 let prefix = prefix_sum_4(gaps);
162
163 let result = vaddq_u32(prefix, carry);
165
166 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
168
169 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
171 }
172
173 let base = full_groups * 4;
175 let mut scalar_carry = vgetq_lane_u32(carry, 0);
176 for j in 0..remainder {
177 scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
178 output[base + j + 1] = scalar_carry;
179 }
180 }
181
182 #[target_feature(enable = "neon")]
184 pub unsafe fn add_one(values: &mut [u32], count: usize) {
185 let ones = vdupq_n_u32(1);
186 let chunks = count / 4;
187 let remainder = count % 4;
188
189 for chunk in 0..chunks {
190 let base = chunk * 4;
191 let ptr = values.as_mut_ptr().add(base);
192 let v = vld1q_u32(ptr);
193 let result = vaddq_u32(v, ones);
194 vst1q_u32(ptr, result);
195 }
196
197 let base = chunks * 4;
198 for i in 0..remainder {
199 values[base + i] += 1;
200 }
201 }
202
203 #[target_feature(enable = "neon")]
206 pub unsafe fn unpack_8bit_delta_decode(
207 input: &[u8],
208 output: &mut [u32],
209 first_value: u32,
210 count: usize,
211 ) {
212 output[0] = first_value;
213 if count <= 1 {
214 return;
215 }
216
217 let ones = vdupq_n_u32(1);
218 let mut carry = vdupq_n_u32(first_value);
219
220 let full_groups = (count - 1) / 4;
221 let remainder = (count - 1) % 4;
222
223 for group in 0..full_groups {
224 let base = group * 4;
225
226 let raw = std::ptr::read_unaligned(input.as_ptr().add(base) as *const u32);
228 let bytes = vreinterpret_u8_u32(vdup_n_u32(raw));
229 let u16s = vmovl_u8(bytes); let d = vmovl_u16(vget_low_u16(u16s)); let gaps = vaddq_u32(d, ones);
234
235 let prefix = prefix_sum_4(gaps);
237
238 let result = vaddq_u32(prefix, carry);
240
241 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
243
244 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
246 }
247
248 let base = full_groups * 4;
250 let mut scalar_carry = vgetq_lane_u32(carry, 0);
251 for j in 0..remainder {
252 scalar_carry = scalar_carry
253 .wrapping_add(input[base + j] as u32)
254 .wrapping_add(1);
255 output[base + j + 1] = scalar_carry;
256 }
257 }
258
259 #[target_feature(enable = "neon")]
261 pub unsafe fn unpack_16bit_delta_decode(
262 input: &[u8],
263 output: &mut [u32],
264 first_value: u32,
265 count: usize,
266 ) {
267 output[0] = first_value;
268 if count <= 1 {
269 return;
270 }
271
272 let ones = vdupq_n_u32(1);
273 let mut carry = vdupq_n_u32(first_value);
274
275 let full_groups = (count - 1) / 4;
276 let remainder = (count - 1) % 4;
277
278 for group in 0..full_groups {
279 let base = group * 4;
280 let in_ptr = input.as_ptr().add(base * 2) as *const u16;
281
282 let vals = vld1_u16(in_ptr);
284 let d = vmovl_u16(vals);
285
286 let gaps = vaddq_u32(d, ones);
288
289 let prefix = prefix_sum_4(gaps);
291
292 let result = vaddq_u32(prefix, carry);
294
295 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
297
298 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
300 }
301
302 let base = full_groups * 4;
304 let mut scalar_carry = vgetq_lane_u32(carry, 0);
305 for j in 0..remainder {
306 let idx = (base + j) * 2;
307 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
308 scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
309 output[base + j + 1] = scalar_carry;
310 }
311 }
312
313 #[inline]
315 pub fn is_available() -> bool {
316 true
317 }
318}
319
320#[cfg(target_arch = "x86_64")]
325#[allow(unsafe_op_in_unsafe_fn)]
326mod sse {
327 use std::arch::x86_64::*;
328
329 #[target_feature(enable = "sse2", enable = "sse4.1")]
331 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
332 let chunks = count / 16;
333 let remainder = count % 16;
334
335 for chunk in 0..chunks {
336 let base = chunk * 16;
337 let in_ptr = input.as_ptr().add(base);
338
339 let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
340
341 let v0 = _mm_cvtepu8_epi32(bytes);
343 let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
344 let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
345 let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
346
347 let out_ptr = output.as_mut_ptr().add(base);
348 _mm_storeu_si128(out_ptr as *mut __m128i, v0);
349 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
350 _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
351 _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
352 }
353
354 let base = chunks * 16;
355 for i in 0..remainder {
356 output[base + i] = input[base + i] as u32;
357 }
358 }
359
360 #[target_feature(enable = "sse2", enable = "sse4.1")]
362 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
363 let chunks = count / 8;
364 let remainder = count % 8;
365
366 for chunk in 0..chunks {
367 let base = chunk * 8;
368 let in_ptr = input.as_ptr().add(base * 2);
369
370 let vals = _mm_loadu_si128(in_ptr as *const __m128i);
371 let low = _mm_cvtepu16_epi32(vals);
372 let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
373
374 let out_ptr = output.as_mut_ptr().add(base);
375 _mm_storeu_si128(out_ptr as *mut __m128i, low);
376 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
377 }
378
379 let base = chunks * 8;
380 for i in 0..remainder {
381 let idx = (base + i) * 2;
382 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
383 }
384 }
385
386 #[target_feature(enable = "sse2")]
388 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
389 let chunks = count / 4;
390 let remainder = count % 4;
391
392 let in_ptr = input.as_ptr() as *const __m128i;
393 let out_ptr = output.as_mut_ptr() as *mut __m128i;
394
395 for chunk in 0..chunks {
396 let vals = _mm_loadu_si128(in_ptr.add(chunk));
397 _mm_storeu_si128(out_ptr.add(chunk), vals);
398 }
399
400 let base = chunks * 4;
402 for i in 0..remainder {
403 let idx = (base + i) * 4;
404 output[base + i] =
405 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
406 }
407 }
408
409 #[inline]
413 #[target_feature(enable = "sse2")]
414 unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
415 let shifted1 = _mm_slli_si128(v, 4);
418 let sum1 = _mm_add_epi32(v, shifted1);
419
420 let shifted2 = _mm_slli_si128(sum1, 8);
423 _mm_add_epi32(sum1, shifted2)
424 }
425
426 #[target_feature(enable = "sse2", enable = "sse4.1")]
428 pub unsafe fn delta_decode(
429 output: &mut [u32],
430 deltas: &[u32],
431 first_doc_id: u32,
432 count: usize,
433 ) {
434 if count == 0 {
435 return;
436 }
437
438 output[0] = first_doc_id;
439 if count == 1 {
440 return;
441 }
442
443 let ones = _mm_set1_epi32(1);
444 let mut carry = _mm_set1_epi32(first_doc_id as i32);
445
446 let full_groups = (count - 1) / 4;
447 let remainder = (count - 1) % 4;
448
449 for group in 0..full_groups {
450 let base = group * 4;
451
452 let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
454 let gaps = _mm_add_epi32(d, ones);
455
456 let prefix = prefix_sum_4(gaps);
458
459 let result = _mm_add_epi32(prefix, carry);
461
462 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
464
465 carry = _mm_shuffle_epi32(result, 0xFF); }
468
469 let base = full_groups * 4;
471 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
472 for j in 0..remainder {
473 scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
474 output[base + j + 1] = scalar_carry;
475 }
476 }
477
478 #[target_feature(enable = "sse2")]
480 pub unsafe fn add_one(values: &mut [u32], count: usize) {
481 let ones = _mm_set1_epi32(1);
482 let chunks = count / 4;
483 let remainder = count % 4;
484
485 for chunk in 0..chunks {
486 let base = chunk * 4;
487 let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
488 let v = _mm_loadu_si128(ptr);
489 let result = _mm_add_epi32(v, ones);
490 _mm_storeu_si128(ptr, result);
491 }
492
493 let base = chunks * 4;
494 for i in 0..remainder {
495 values[base + i] += 1;
496 }
497 }
498
499 #[target_feature(enable = "sse2", enable = "sse4.1")]
501 pub unsafe fn unpack_8bit_delta_decode(
502 input: &[u8],
503 output: &mut [u32],
504 first_value: u32,
505 count: usize,
506 ) {
507 output[0] = first_value;
508 if count <= 1 {
509 return;
510 }
511
512 let ones = _mm_set1_epi32(1);
513 let mut carry = _mm_set1_epi32(first_value as i32);
514
515 let full_groups = (count - 1) / 4;
516 let remainder = (count - 1) % 4;
517
518 for group in 0..full_groups {
519 let base = group * 4;
520
521 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
523 input.as_ptr().add(base) as *const i32
524 ));
525 let d = _mm_cvtepu8_epi32(bytes);
526
527 let gaps = _mm_add_epi32(d, ones);
529
530 let prefix = prefix_sum_4(gaps);
532
533 let result = _mm_add_epi32(prefix, carry);
535
536 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
538
539 carry = _mm_shuffle_epi32(result, 0xFF);
541 }
542
543 let base = full_groups * 4;
545 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
546 for j in 0..remainder {
547 scalar_carry = scalar_carry
548 .wrapping_add(input[base + j] as u32)
549 .wrapping_add(1);
550 output[base + j + 1] = scalar_carry;
551 }
552 }
553
554 #[target_feature(enable = "sse2", enable = "sse4.1")]
556 pub unsafe fn unpack_16bit_delta_decode(
557 input: &[u8],
558 output: &mut [u32],
559 first_value: u32,
560 count: usize,
561 ) {
562 output[0] = first_value;
563 if count <= 1 {
564 return;
565 }
566
567 let ones = _mm_set1_epi32(1);
568 let mut carry = _mm_set1_epi32(first_value as i32);
569
570 let full_groups = (count - 1) / 4;
571 let remainder = (count - 1) % 4;
572
573 for group in 0..full_groups {
574 let base = group * 4;
575 let in_ptr = input.as_ptr().add(base * 2);
576
577 let vals = _mm_loadl_epi64(in_ptr as *const __m128i); let d = _mm_cvtepu16_epi32(vals);
580
581 let gaps = _mm_add_epi32(d, ones);
583
584 let prefix = prefix_sum_4(gaps);
586
587 let result = _mm_add_epi32(prefix, carry);
589
590 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
592
593 carry = _mm_shuffle_epi32(result, 0xFF);
595 }
596
597 let base = full_groups * 4;
599 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
600 for j in 0..remainder {
601 let idx = (base + j) * 2;
602 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
603 scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
604 output[base + j + 1] = scalar_carry;
605 }
606 }
607
608 #[inline]
610 pub fn is_available() -> bool {
611 is_x86_feature_detected!("sse4.1")
612 }
613}
614
615#[cfg(target_arch = "x86_64")]
620#[allow(unsafe_op_in_unsafe_fn)]
621mod avx2 {
622 use std::arch::x86_64::*;
623
624 #[target_feature(enable = "avx2")]
626 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
627 let chunks = count / 32;
628 let remainder = count % 32;
629
630 for chunk in 0..chunks {
631 let base = chunk * 32;
632 let in_ptr = input.as_ptr().add(base);
633
634 let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
636 let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
637
638 let v0 = _mm256_cvtepu8_epi32(bytes_lo);
640 let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
641 let v2 = _mm256_cvtepu8_epi32(bytes_hi);
642 let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));
643
644 let out_ptr = output.as_mut_ptr().add(base);
645 _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
646 _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
647 _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
648 _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
649 }
650
651 let base = chunks * 32;
653 for i in 0..remainder {
654 output[base + i] = input[base + i] as u32;
655 }
656 }
657
658 #[target_feature(enable = "avx2")]
660 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
661 let chunks = count / 16;
662 let remainder = count % 16;
663
664 for chunk in 0..chunks {
665 let base = chunk * 16;
666 let in_ptr = input.as_ptr().add(base * 2);
667
668 let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
670 let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
671
672 let v0 = _mm256_cvtepu16_epi32(vals_lo);
674 let v1 = _mm256_cvtepu16_epi32(vals_hi);
675
676 let out_ptr = output.as_mut_ptr().add(base);
677 _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
678 _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
679 }
680
681 let base = chunks * 16;
683 for i in 0..remainder {
684 let idx = (base + i) * 2;
685 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
686 }
687 }
688
689 #[target_feature(enable = "avx2")]
691 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
692 let chunks = count / 8;
693 let remainder = count % 8;
694
695 let in_ptr = input.as_ptr() as *const __m256i;
696 let out_ptr = output.as_mut_ptr() as *mut __m256i;
697
698 for chunk in 0..chunks {
699 let vals = _mm256_loadu_si256(in_ptr.add(chunk));
700 _mm256_storeu_si256(out_ptr.add(chunk), vals);
701 }
702
703 let base = chunks * 8;
705 for i in 0..remainder {
706 let idx = (base + i) * 4;
707 output[base + i] =
708 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
709 }
710 }
711
712 #[target_feature(enable = "avx2")]
714 pub unsafe fn add_one(values: &mut [u32], count: usize) {
715 let ones = _mm256_set1_epi32(1);
716 let chunks = count / 8;
717 let remainder = count % 8;
718
719 for chunk in 0..chunks {
720 let base = chunk * 8;
721 let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
722 let v = _mm256_loadu_si256(ptr);
723 let result = _mm256_add_epi32(v, ones);
724 _mm256_storeu_si256(ptr, result);
725 }
726
727 let base = chunks * 8;
728 for i in 0..remainder {
729 values[base + i] += 1;
730 }
731 }
732
733 #[inline]
737 #[target_feature(enable = "avx2")]
738 unsafe fn prefix_sum_8(v: __m256i) -> __m256i {
739 let s1 = _mm256_slli_si256(v, 4);
741 let r1 = _mm256_add_epi32(v, s1);
742
743 let s2 = _mm256_slli_si256(r1, 8);
745 let r2 = _mm256_add_epi32(r1, s2);
746
747 let lo_sum = _mm256_shuffle_epi32(r2, 0xFF);
750 let carry = _mm256_permute2x128_si256(lo_sum, lo_sum, 0x00);
752 let carry_hi = _mm256_blend_epi32::<0xF0>(_mm256_setzero_si256(), carry);
754 _mm256_add_epi32(r2, carry_hi)
755 }
756
757 #[target_feature(enable = "avx2")]
759 pub unsafe fn unpack_8bit_delta_decode(
760 input: &[u8],
761 output: &mut [u32],
762 first_value: u32,
763 count: usize,
764 ) {
765 output[0] = first_value;
766 if count <= 1 {
767 return;
768 }
769
770 let ones = _mm256_set1_epi32(1);
771 let mut carry = _mm256_set1_epi32(first_value as i32);
772 let broadcast_idx = _mm256_set1_epi32(7);
773
774 let full_groups = (count - 1) / 8;
775 let remainder = (count - 1) % 8;
776
777 for group in 0..full_groups {
778 let base = group * 8;
779
780 let bytes = _mm_loadl_epi64(input.as_ptr().add(base) as *const __m128i);
782 let d = _mm256_cvtepu8_epi32(bytes);
783
784 let gaps = _mm256_add_epi32(d, ones);
786
787 let prefix = prefix_sum_8(gaps);
789
790 let result = _mm256_add_epi32(prefix, carry);
792
793 _mm256_storeu_si256(output[base + 1..].as_mut_ptr() as *mut __m256i, result);
795
796 carry = _mm256_permutevar8x32_epi32(result, broadcast_idx);
798 }
799
800 let base = full_groups * 8;
802 let mut scalar_carry = _mm256_extract_epi32::<0>(carry) as u32;
803 for j in 0..remainder {
804 scalar_carry = scalar_carry
805 .wrapping_add(input[base + j] as u32)
806 .wrapping_add(1);
807 output[base + j + 1] = scalar_carry;
808 }
809 }
810
811 #[target_feature(enable = "avx2")]
813 pub unsafe fn unpack_16bit_delta_decode(
814 input: &[u8],
815 output: &mut [u32],
816 first_value: u32,
817 count: usize,
818 ) {
819 output[0] = first_value;
820 if count <= 1 {
821 return;
822 }
823
824 let ones = _mm256_set1_epi32(1);
825 let mut carry = _mm256_set1_epi32(first_value as i32);
826 let broadcast_idx = _mm256_set1_epi32(7);
827
828 let full_groups = (count - 1) / 8;
829 let remainder = (count - 1) % 8;
830
831 for group in 0..full_groups {
832 let base = group * 8;
833 let in_ptr = input.as_ptr().add(base * 2);
834
835 let vals = _mm_loadu_si128(in_ptr as *const __m128i);
837 let d = _mm256_cvtepu16_epi32(vals);
838
839 let gaps = _mm256_add_epi32(d, ones);
841
842 let prefix = prefix_sum_8(gaps);
844
845 let result = _mm256_add_epi32(prefix, carry);
847
848 _mm256_storeu_si256(output[base + 1..].as_mut_ptr() as *mut __m256i, result);
850
851 carry = _mm256_permutevar8x32_epi32(result, broadcast_idx);
853 }
854
855 let base = full_groups * 8;
857 let mut scalar_carry = _mm256_extract_epi32::<0>(carry) as u32;
858 for j in 0..remainder {
859 let idx = (base + j) * 2;
860 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
861 scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
862 output[base + j + 1] = scalar_carry;
863 }
864 }
865
866 #[inline]
868 pub fn is_available() -> bool {
869 is_x86_feature_detected!("avx2")
870 }
871}
872
873#[allow(dead_code)]
878mod scalar {
879 #[inline]
881 pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
882 for i in 0..count {
883 output[i] = input[i] as u32;
884 }
885 }
886
887 #[inline]
889 pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
890 for (i, out) in output.iter_mut().enumerate().take(count) {
891 let idx = i * 2;
892 *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
893 }
894 }
895
896 #[inline]
898 pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
899 for (i, out) in output.iter_mut().enumerate().take(count) {
900 let idx = i * 4;
901 *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
902 }
903 }
904
905 #[inline]
907 pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
908 if count == 0 {
909 return;
910 }
911
912 output[0] = first_doc_id;
913 let mut carry = first_doc_id;
914
915 for i in 0..count - 1 {
916 carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
917 output[i + 1] = carry;
918 }
919 }
920
921 #[inline]
923 pub fn add_one(values: &mut [u32], count: usize) {
924 for val in values.iter_mut().take(count) {
925 *val += 1;
926 }
927 }
928}
929
930#[inline]
936pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
937 #[cfg(target_arch = "aarch64")]
938 {
939 if neon::is_available() {
940 unsafe {
941 neon::unpack_8bit(input, output, count);
942 }
943 return;
944 }
945 }
946
947 #[cfg(target_arch = "x86_64")]
948 {
949 if avx2::is_available() {
951 unsafe {
952 avx2::unpack_8bit(input, output, count);
953 }
954 return;
955 }
956 if sse::is_available() {
957 unsafe {
958 sse::unpack_8bit(input, output, count);
959 }
960 return;
961 }
962 }
963
964 scalar::unpack_8bit(input, output, count);
965}
966
967#[inline]
969pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
970 #[cfg(target_arch = "aarch64")]
971 {
972 if neon::is_available() {
973 unsafe {
974 neon::unpack_16bit(input, output, count);
975 }
976 return;
977 }
978 }
979
980 #[cfg(target_arch = "x86_64")]
981 {
982 if avx2::is_available() {
984 unsafe {
985 avx2::unpack_16bit(input, output, count);
986 }
987 return;
988 }
989 if sse::is_available() {
990 unsafe {
991 sse::unpack_16bit(input, output, count);
992 }
993 return;
994 }
995 }
996
997 scalar::unpack_16bit(input, output, count);
998}
999
1000#[inline]
1002pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
1003 #[cfg(target_arch = "aarch64")]
1004 {
1005 if neon::is_available() {
1006 unsafe {
1007 neon::unpack_32bit(input, output, count);
1008 }
1009 }
1010 }
1011
1012 #[cfg(target_arch = "x86_64")]
1013 {
1014 if avx2::is_available() {
1016 unsafe {
1017 avx2::unpack_32bit(input, output, count);
1018 }
1019 } else {
1020 unsafe {
1022 sse::unpack_32bit(input, output, count);
1023 }
1024 }
1025 }
1026
1027 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
1028 {
1029 scalar::unpack_32bit(input, output, count);
1030 }
1031}
1032
1033#[inline]
1039pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
1040 #[cfg(target_arch = "aarch64")]
1041 {
1042 if neon::is_available() {
1043 unsafe {
1044 neon::delta_decode(output, deltas, first_value, count);
1045 }
1046 return;
1047 }
1048 }
1049
1050 #[cfg(target_arch = "x86_64")]
1051 {
1052 if sse::is_available() {
1053 unsafe {
1054 sse::delta_decode(output, deltas, first_value, count);
1055 }
1056 return;
1057 }
1058 }
1059
1060 scalar::delta_decode(output, deltas, first_value, count);
1061}
1062
1063#[inline]
1067pub fn add_one(values: &mut [u32], count: usize) {
1068 #[cfg(target_arch = "aarch64")]
1069 {
1070 if neon::is_available() {
1071 unsafe {
1072 neon::add_one(values, count);
1073 }
1074 }
1075 }
1076
1077 #[cfg(target_arch = "x86_64")]
1078 {
1079 if avx2::is_available() {
1081 unsafe {
1082 avx2::add_one(values, count);
1083 }
1084 } else {
1085 unsafe {
1087 sse::add_one(values, count);
1088 }
1089 }
1090 }
1091
1092 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
1093 {
1094 scalar::add_one(values, count);
1095 }
1096}
1097
1098#[inline]
1100pub fn bits_needed(val: u32) -> u8 {
1101 if val == 0 {
1102 0
1103 } else {
1104 32 - val.leading_zeros() as u8
1105 }
1106}
1107
1108#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1125#[repr(u8)]
1126pub enum RoundedBitWidth {
1127 Zero = 0,
1128 Bits8 = 8,
1129 Bits16 = 16,
1130 Bits32 = 32,
1131}
1132
1133impl RoundedBitWidth {
1134 #[inline]
1136 pub fn from_exact(bits: u8) -> Self {
1137 match bits {
1138 0 => RoundedBitWidth::Zero,
1139 1..=8 => RoundedBitWidth::Bits8,
1140 9..=16 => RoundedBitWidth::Bits16,
1141 _ => RoundedBitWidth::Bits32,
1142 }
1143 }
1144
1145 #[inline]
1147 pub fn from_u8(bits: u8) -> Self {
1148 match bits {
1149 0 => RoundedBitWidth::Zero,
1150 8 => RoundedBitWidth::Bits8,
1151 16 => RoundedBitWidth::Bits16,
1152 32 => RoundedBitWidth::Bits32,
1153 _ => RoundedBitWidth::Bits32, }
1155 }
1156
1157 #[inline]
1159 pub fn bytes_per_value(self) -> usize {
1160 match self {
1161 RoundedBitWidth::Zero => 0,
1162 RoundedBitWidth::Bits8 => 1,
1163 RoundedBitWidth::Bits16 => 2,
1164 RoundedBitWidth::Bits32 => 4,
1165 }
1166 }
1167
1168 #[inline]
1170 pub fn as_u8(self) -> u8 {
1171 self as u8
1172 }
1173}
1174
1175#[inline]
1177pub fn round_bit_width(bits: u8) -> u8 {
1178 RoundedBitWidth::from_exact(bits).as_u8()
1179}
1180
1181#[inline]
1186pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
1187 let count = values.len();
1188 match bit_width {
1189 RoundedBitWidth::Zero => 0,
1190 RoundedBitWidth::Bits8 => {
1191 for (i, &v) in values.iter().enumerate() {
1192 output[i] = v as u8;
1193 }
1194 count
1195 }
1196 RoundedBitWidth::Bits16 => {
1197 for (i, &v) in values.iter().enumerate() {
1198 let bytes = (v as u16).to_le_bytes();
1199 output[i * 2] = bytes[0];
1200 output[i * 2 + 1] = bytes[1];
1201 }
1202 count * 2
1203 }
1204 RoundedBitWidth::Bits32 => {
1205 for (i, &v) in values.iter().enumerate() {
1206 let bytes = v.to_le_bytes();
1207 output[i * 4] = bytes[0];
1208 output[i * 4 + 1] = bytes[1];
1209 output[i * 4 + 2] = bytes[2];
1210 output[i * 4 + 3] = bytes[3];
1211 }
1212 count * 4
1213 }
1214 }
1215}
1216
1217#[inline]
1221pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
1222 match bit_width {
1223 RoundedBitWidth::Zero => {
1224 for out in output.iter_mut().take(count) {
1225 *out = 0;
1226 }
1227 }
1228 RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
1229 RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
1230 RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
1231 }
1232}
1233
1234#[inline]
1238pub fn unpack_rounded_delta_decode(
1239 input: &[u8],
1240 bit_width: RoundedBitWidth,
1241 output: &mut [u32],
1242 first_value: u32,
1243 count: usize,
1244) {
1245 match bit_width {
1246 RoundedBitWidth::Zero => {
1247 let mut val = first_value;
1249 for out in output.iter_mut().take(count) {
1250 *out = val;
1251 val = val.wrapping_add(1);
1252 }
1253 }
1254 RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
1255 RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
1256 RoundedBitWidth::Bits32 => {
1257 if count > 0 {
1259 output[0] = first_value;
1260 let mut carry = first_value;
1261 for i in 0..count - 1 {
1262 let idx = i * 4;
1263 let delta = u32::from_le_bytes([
1264 input[idx],
1265 input[idx + 1],
1266 input[idx + 2],
1267 input[idx + 3],
1268 ]);
1269 carry = carry.wrapping_add(delta).wrapping_add(1);
1270 output[i + 1] = carry;
1271 }
1272 }
1273 }
1274 }
1275}
1276
1277#[inline]
1286pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1287 if count == 0 {
1288 return;
1289 }
1290
1291 output[0] = first_value;
1292 if count == 1 {
1293 return;
1294 }
1295
1296 #[cfg(target_arch = "aarch64")]
1297 {
1298 if neon::is_available() {
1299 unsafe {
1300 neon::unpack_8bit_delta_decode(input, output, first_value, count);
1301 }
1302 return;
1303 }
1304 }
1305
1306 #[cfg(target_arch = "x86_64")]
1307 {
1308 if avx2::is_available() {
1309 unsafe {
1310 avx2::unpack_8bit_delta_decode(input, output, first_value, count);
1311 }
1312 return;
1313 }
1314 if sse::is_available() {
1315 unsafe {
1316 sse::unpack_8bit_delta_decode(input, output, first_value, count);
1317 }
1318 return;
1319 }
1320 }
1321
1322 let mut carry = first_value;
1324 for i in 0..count - 1 {
1325 carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
1326 output[i + 1] = carry;
1327 }
1328}
1329
1330#[inline]
1332pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1333 if count == 0 {
1334 return;
1335 }
1336
1337 output[0] = first_value;
1338 if count == 1 {
1339 return;
1340 }
1341
1342 #[cfg(target_arch = "aarch64")]
1343 {
1344 if neon::is_available() {
1345 unsafe {
1346 neon::unpack_16bit_delta_decode(input, output, first_value, count);
1347 }
1348 return;
1349 }
1350 }
1351
1352 #[cfg(target_arch = "x86_64")]
1353 {
1354 if avx2::is_available() {
1355 unsafe {
1356 avx2::unpack_16bit_delta_decode(input, output, first_value, count);
1357 }
1358 return;
1359 }
1360 if sse::is_available() {
1361 unsafe {
1362 sse::unpack_16bit_delta_decode(input, output, first_value, count);
1363 }
1364 return;
1365 }
1366 }
1367
1368 let mut carry = first_value;
1370 for i in 0..count - 1 {
1371 let idx = i * 2;
1372 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
1373 carry = carry.wrapping_add(delta).wrapping_add(1);
1374 output[i + 1] = carry;
1375 }
1376}
1377
1378#[inline]
1383pub fn unpack_delta_decode(
1384 input: &[u8],
1385 bit_width: u8,
1386 output: &mut [u32],
1387 first_value: u32,
1388 count: usize,
1389) {
1390 if count == 0 {
1391 return;
1392 }
1393
1394 output[0] = first_value;
1395 if count == 1 {
1396 return;
1397 }
1398
1399 match bit_width {
1401 0 => {
1402 let mut val = first_value;
1404 for item in output.iter_mut().take(count).skip(1) {
1405 val = val.wrapping_add(1);
1406 *item = val;
1407 }
1408 }
1409 8 => unpack_8bit_delta_decode(input, output, first_value, count),
1410 16 => unpack_16bit_delta_decode(input, output, first_value, count),
1411 32 => {
1412 let mut carry = first_value;
1414 for i in 0..count - 1 {
1415 let idx = i * 4;
1416 let delta = u32::from_le_bytes([
1417 input[idx],
1418 input[idx + 1],
1419 input[idx + 2],
1420 input[idx + 3],
1421 ]);
1422 carry = carry.wrapping_add(delta).wrapping_add(1);
1423 output[i + 1] = carry;
1424 }
1425 }
1426 _ => {
1427 let mask = (1u64 << bit_width) - 1;
1429 let bit_width_usize = bit_width as usize;
1430 let mut bit_pos = 0usize;
1431 let input_ptr = input.as_ptr();
1432 let mut carry = first_value;
1433
1434 for i in 0..count - 1 {
1435 let byte_idx = bit_pos >> 3;
1436 let bit_offset = bit_pos & 7;
1437
1438 let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
1440 let delta = ((word >> bit_offset) & mask) as u32;
1441
1442 carry = carry.wrapping_add(delta).wrapping_add(1);
1443 output[i + 1] = carry;
1444 bit_pos += bit_width_usize;
1445 }
1446 }
1447 }
1448}
1449
1450#[inline]
1458pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
1459 #[cfg(target_arch = "aarch64")]
1460 {
1461 if neon::is_available() {
1462 unsafe {
1463 dequantize_uint8_neon(input, output, scale, min_val, count);
1464 }
1465 return;
1466 }
1467 }
1468
1469 #[cfg(target_arch = "x86_64")]
1470 {
1471 if sse::is_available() {
1472 unsafe {
1473 dequantize_uint8_sse(input, output, scale, min_val, count);
1474 }
1475 return;
1476 }
1477 }
1478
1479 for i in 0..count {
1481 output[i] = input[i] as f32 * scale + min_val;
1482 }
1483}
1484
1485#[cfg(target_arch = "aarch64")]
1486#[target_feature(enable = "neon")]
1487#[allow(unsafe_op_in_unsafe_fn)]
1488unsafe fn dequantize_uint8_neon(
1489 input: &[u8],
1490 output: &mut [f32],
1491 scale: f32,
1492 min_val: f32,
1493 count: usize,
1494) {
1495 use std::arch::aarch64::*;
1496
1497 let scale_v = vdupq_n_f32(scale);
1498 let min_v = vdupq_n_f32(min_val);
1499
1500 let chunks = count / 16;
1501 let remainder = count % 16;
1502
1503 for chunk in 0..chunks {
1504 let base = chunk * 16;
1505 let in_ptr = input.as_ptr().add(base);
1506
1507 let bytes = vld1q_u8(in_ptr);
1509
1510 let low8 = vget_low_u8(bytes);
1512 let high8 = vget_high_u8(bytes);
1513
1514 let low16 = vmovl_u8(low8);
1515 let high16 = vmovl_u8(high8);
1516
1517 let u32_0 = vmovl_u16(vget_low_u16(low16));
1519 let u32_1 = vmovl_u16(vget_high_u16(low16));
1520 let u32_2 = vmovl_u16(vget_low_u16(high16));
1521 let u32_3 = vmovl_u16(vget_high_u16(high16));
1522
1523 let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
1525 let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
1526 let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
1527 let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);
1528
1529 let out_ptr = output.as_mut_ptr().add(base);
1530 vst1q_f32(out_ptr, f32_0);
1531 vst1q_f32(out_ptr.add(4), f32_1);
1532 vst1q_f32(out_ptr.add(8), f32_2);
1533 vst1q_f32(out_ptr.add(12), f32_3);
1534 }
1535
1536 let base = chunks * 16;
1538 for i in 0..remainder {
1539 output[base + i] = input[base + i] as f32 * scale + min_val;
1540 }
1541}
1542
1543#[cfg(target_arch = "x86_64")]
1544#[target_feature(enable = "sse2", enable = "sse4.1")]
1545#[allow(unsafe_op_in_unsafe_fn)]
1546unsafe fn dequantize_uint8_sse(
1547 input: &[u8],
1548 output: &mut [f32],
1549 scale: f32,
1550 min_val: f32,
1551 count: usize,
1552) {
1553 use std::arch::x86_64::*;
1554
1555 let scale_v = _mm_set1_ps(scale);
1556 let min_v = _mm_set1_ps(min_val);
1557
1558 let chunks = count / 4;
1559 let remainder = count % 4;
1560
1561 for chunk in 0..chunks {
1562 let base = chunk * 4;
1563
1564 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
1566 input.as_ptr().add(base) as *const i32
1567 ));
1568 let ints = _mm_cvtepu8_epi32(bytes);
1569 let floats = _mm_cvtepi32_ps(ints);
1570
1571 let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);
1573
1574 _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
1575 }
1576
1577 let base = chunks * 4;
1579 for i in 0..remainder {
1580 output[base + i] = input[base + i] as f32 * scale + min_val;
1581 }
1582}
1583
1584#[inline]
1586pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
1587 #[cfg(target_arch = "aarch64")]
1588 {
1589 if neon::is_available() {
1590 return unsafe { dot_product_f32_neon(a, b, count) };
1591 }
1592 }
1593
1594 #[cfg(target_arch = "x86_64")]
1595 {
1596 if is_x86_feature_detected!("avx512f") {
1597 return unsafe { dot_product_f32_avx512(a, b, count) };
1598 }
1599 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1600 return unsafe { dot_product_f32_avx2(a, b, count) };
1601 }
1602 if sse::is_available() {
1603 return unsafe { dot_product_f32_sse(a, b, count) };
1604 }
1605 }
1606
1607 let mut sum = 0.0f32;
1609 for i in 0..count {
1610 sum += a[i] * b[i];
1611 }
1612 sum
1613}
1614
1615#[cfg(target_arch = "aarch64")]
1616#[target_feature(enable = "neon")]
1617#[allow(unsafe_op_in_unsafe_fn)]
1618unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
1619 use std::arch::aarch64::*;
1620
1621 let chunks16 = count / 16;
1622 let remainder = count % 16;
1623
1624 let mut acc0 = vdupq_n_f32(0.0);
1625 let mut acc1 = vdupq_n_f32(0.0);
1626 let mut acc2 = vdupq_n_f32(0.0);
1627 let mut acc3 = vdupq_n_f32(0.0);
1628
1629 for c in 0..chunks16 {
1630 let base = c * 16;
1631 acc0 = vfmaq_f32(
1632 acc0,
1633 vld1q_f32(a.as_ptr().add(base)),
1634 vld1q_f32(b.as_ptr().add(base)),
1635 );
1636 acc1 = vfmaq_f32(
1637 acc1,
1638 vld1q_f32(a.as_ptr().add(base + 4)),
1639 vld1q_f32(b.as_ptr().add(base + 4)),
1640 );
1641 acc2 = vfmaq_f32(
1642 acc2,
1643 vld1q_f32(a.as_ptr().add(base + 8)),
1644 vld1q_f32(b.as_ptr().add(base + 8)),
1645 );
1646 acc3 = vfmaq_f32(
1647 acc3,
1648 vld1q_f32(a.as_ptr().add(base + 12)),
1649 vld1q_f32(b.as_ptr().add(base + 12)),
1650 );
1651 }
1652
1653 let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
1654 let mut sum = vaddvq_f32(acc);
1655
1656 let base = chunks16 * 16;
1657 for i in 0..remainder {
1658 sum += a[base + i] * b[base + i];
1659 }
1660
1661 sum
1662}
1663
1664#[cfg(target_arch = "x86_64")]
1665#[target_feature(enable = "avx2", enable = "fma")]
1666#[allow(unsafe_op_in_unsafe_fn)]
1667unsafe fn dot_product_f32_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
1668 use std::arch::x86_64::*;
1669
1670 let chunks32 = count / 32;
1671 let remainder = count % 32;
1672
1673 let mut acc0 = _mm256_setzero_ps();
1674 let mut acc1 = _mm256_setzero_ps();
1675 let mut acc2 = _mm256_setzero_ps();
1676 let mut acc3 = _mm256_setzero_ps();
1677
1678 for c in 0..chunks32 {
1679 let base = c * 32;
1680 acc0 = _mm256_fmadd_ps(
1681 _mm256_loadu_ps(a.as_ptr().add(base)),
1682 _mm256_loadu_ps(b.as_ptr().add(base)),
1683 acc0,
1684 );
1685 acc1 = _mm256_fmadd_ps(
1686 _mm256_loadu_ps(a.as_ptr().add(base + 8)),
1687 _mm256_loadu_ps(b.as_ptr().add(base + 8)),
1688 acc1,
1689 );
1690 acc2 = _mm256_fmadd_ps(
1691 _mm256_loadu_ps(a.as_ptr().add(base + 16)),
1692 _mm256_loadu_ps(b.as_ptr().add(base + 16)),
1693 acc2,
1694 );
1695 acc3 = _mm256_fmadd_ps(
1696 _mm256_loadu_ps(a.as_ptr().add(base + 24)),
1697 _mm256_loadu_ps(b.as_ptr().add(base + 24)),
1698 acc3,
1699 );
1700 }
1701
1702 let acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
1703
1704 let hi = _mm256_extractf128_ps(acc, 1);
1706 let lo = _mm256_castps256_ps128(acc);
1707 let sum128 = _mm_add_ps(lo, hi);
1708 let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
1709 let sums = _mm_add_ps(sum128, shuf);
1710 let shuf2 = _mm_movehl_ps(sums, sums);
1711 let final_sum = _mm_add_ss(sums, shuf2);
1712
1713 let mut sum = _mm_cvtss_f32(final_sum);
1714
1715 let base = chunks32 * 32;
1716 for i in 0..remainder {
1717 sum += a[base + i] * b[base + i];
1718 }
1719
1720 sum
1721}
1722
1723#[cfg(target_arch = "x86_64")]
1724#[target_feature(enable = "sse")]
1725#[allow(unsafe_op_in_unsafe_fn)]
1726unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
1727 use std::arch::x86_64::*;
1728
1729 let chunks = count / 4;
1730 let remainder = count % 4;
1731
1732 let mut acc = _mm_setzero_ps();
1733
1734 for chunk in 0..chunks {
1735 let base = chunk * 4;
1736 let va = _mm_loadu_ps(a.as_ptr().add(base));
1737 let vb = _mm_loadu_ps(b.as_ptr().add(base));
1738 acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
1739 }
1740
1741 let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); let sums = _mm_add_ps(acc, shuf); let shuf2 = _mm_movehl_ps(sums, sums); let final_sum = _mm_add_ss(sums, shuf2); let mut sum = _mm_cvtss_f32(final_sum);
1748
1749 let base = chunks * 4;
1751 for i in 0..remainder {
1752 sum += a[base + i] * b[base + i];
1753 }
1754
1755 sum
1756}
1757
1758#[cfg(target_arch = "x86_64")]
1759#[target_feature(enable = "avx512f")]
1760#[allow(unsafe_op_in_unsafe_fn)]
1761unsafe fn dot_product_f32_avx512(a: &[f32], b: &[f32], count: usize) -> f32 {
1762 use std::arch::x86_64::*;
1763
1764 let chunks64 = count / 64;
1765 let remainder = count % 64;
1766
1767 let mut acc0 = _mm512_setzero_ps();
1768 let mut acc1 = _mm512_setzero_ps();
1769 let mut acc2 = _mm512_setzero_ps();
1770 let mut acc3 = _mm512_setzero_ps();
1771
1772 for c in 0..chunks64 {
1773 let base = c * 64;
1774 acc0 = _mm512_fmadd_ps(
1775 _mm512_loadu_ps(a.as_ptr().add(base)),
1776 _mm512_loadu_ps(b.as_ptr().add(base)),
1777 acc0,
1778 );
1779 acc1 = _mm512_fmadd_ps(
1780 _mm512_loadu_ps(a.as_ptr().add(base + 16)),
1781 _mm512_loadu_ps(b.as_ptr().add(base + 16)),
1782 acc1,
1783 );
1784 acc2 = _mm512_fmadd_ps(
1785 _mm512_loadu_ps(a.as_ptr().add(base + 32)),
1786 _mm512_loadu_ps(b.as_ptr().add(base + 32)),
1787 acc2,
1788 );
1789 acc3 = _mm512_fmadd_ps(
1790 _mm512_loadu_ps(a.as_ptr().add(base + 48)),
1791 _mm512_loadu_ps(b.as_ptr().add(base + 48)),
1792 acc3,
1793 );
1794 }
1795
1796 let acc = _mm512_add_ps(_mm512_add_ps(acc0, acc1), _mm512_add_ps(acc2, acc3));
1797 let mut sum = _mm512_reduce_add_ps(acc);
1798
1799 let base = chunks64 * 64;
1800 for i in 0..remainder {
1801 sum += a[base + i] * b[base + i];
1802 }
1803
1804 sum
1805}
1806
1807#[cfg(target_arch = "x86_64")]
1808#[target_feature(enable = "avx512f")]
1809#[allow(unsafe_op_in_unsafe_fn)]
1810unsafe fn fused_dot_norm_avx512(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1811 use std::arch::x86_64::*;
1812
1813 let chunks64 = count / 64;
1814 let remainder = count % 64;
1815
1816 let mut d0 = _mm512_setzero_ps();
1817 let mut d1 = _mm512_setzero_ps();
1818 let mut d2 = _mm512_setzero_ps();
1819 let mut d3 = _mm512_setzero_ps();
1820 let mut n0 = _mm512_setzero_ps();
1821 let mut n1 = _mm512_setzero_ps();
1822 let mut n2 = _mm512_setzero_ps();
1823 let mut n3 = _mm512_setzero_ps();
1824
1825 for c in 0..chunks64 {
1826 let base = c * 64;
1827 let vb0 = _mm512_loadu_ps(b.as_ptr().add(base));
1828 d0 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base)), vb0, d0);
1829 n0 = _mm512_fmadd_ps(vb0, vb0, n0);
1830 let vb1 = _mm512_loadu_ps(b.as_ptr().add(base + 16));
1831 d1 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 16)), vb1, d1);
1832 n1 = _mm512_fmadd_ps(vb1, vb1, n1);
1833 let vb2 = _mm512_loadu_ps(b.as_ptr().add(base + 32));
1834 d2 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 32)), vb2, d2);
1835 n2 = _mm512_fmadd_ps(vb2, vb2, n2);
1836 let vb3 = _mm512_loadu_ps(b.as_ptr().add(base + 48));
1837 d3 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 48)), vb3, d3);
1838 n3 = _mm512_fmadd_ps(vb3, vb3, n3);
1839 }
1840
1841 let acc_dot = _mm512_add_ps(_mm512_add_ps(d0, d1), _mm512_add_ps(d2, d3));
1842 let acc_norm = _mm512_add_ps(_mm512_add_ps(n0, n1), _mm512_add_ps(n2, n3));
1843 let mut dot = _mm512_reduce_add_ps(acc_dot);
1844 let mut norm = _mm512_reduce_add_ps(acc_norm);
1845
1846 let base = chunks64 * 64;
1847 for i in 0..remainder {
1848 dot += a[base + i] * b[base + i];
1849 norm += b[base + i] * b[base + i];
1850 }
1851
1852 (dot, norm)
1853}
1854
1855#[inline]
1857pub fn max_f32(values: &[f32], count: usize) -> f32 {
1858 if count == 0 {
1859 return f32::NEG_INFINITY;
1860 }
1861
1862 #[cfg(target_arch = "aarch64")]
1863 {
1864 if neon::is_available() {
1865 return unsafe { max_f32_neon(values, count) };
1866 }
1867 }
1868
1869 #[cfg(target_arch = "x86_64")]
1870 {
1871 if sse::is_available() {
1872 return unsafe { max_f32_sse(values, count) };
1873 }
1874 }
1875
1876 values[..count]
1878 .iter()
1879 .cloned()
1880 .fold(f32::NEG_INFINITY, f32::max)
1881}
1882
1883#[cfg(target_arch = "aarch64")]
1884#[target_feature(enable = "neon")]
1885#[allow(unsafe_op_in_unsafe_fn)]
1886unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
1887 use std::arch::aarch64::*;
1888
1889 let chunks = count / 4;
1890 let remainder = count % 4;
1891
1892 let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);
1893
1894 for chunk in 0..chunks {
1895 let base = chunk * 4;
1896 let v = vld1q_f32(values.as_ptr().add(base));
1897 max_v = vmaxq_f32(max_v, v);
1898 }
1899
1900 let mut max_val = vmaxvq_f32(max_v);
1902
1903 let base = chunks * 4;
1905 for i in 0..remainder {
1906 max_val = max_val.max(values[base + i]);
1907 }
1908
1909 max_val
1910}
1911
1912#[cfg(target_arch = "x86_64")]
1913#[target_feature(enable = "sse")]
1914#[allow(unsafe_op_in_unsafe_fn)]
1915unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
1916 use std::arch::x86_64::*;
1917
1918 let chunks = count / 4;
1919 let remainder = count % 4;
1920
1921 let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);
1922
1923 for chunk in 0..chunks {
1924 let base = chunk * 4;
1925 let v = _mm_loadu_ps(values.as_ptr().add(base));
1926 max_v = _mm_max_ps(max_v, v);
1927 }
1928
1929 let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01); let max1 = _mm_max_ps(max_v, shuf); let shuf2 = _mm_movehl_ps(max1, max1); let final_max = _mm_max_ss(max1, shuf2); let mut max_val = _mm_cvtss_f32(final_max);
1936
1937 let base = chunks * 4;
1939 for i in 0..remainder {
1940 max_val = max_val.max(values[base + i]);
1941 }
1942
1943 max_val
1944}
1945
1946#[inline]
1955fn fused_dot_norm(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1956 #[cfg(target_arch = "aarch64")]
1957 {
1958 if neon::is_available() {
1959 return unsafe { fused_dot_norm_neon(a, b, count) };
1960 }
1961 }
1962
1963 #[cfg(target_arch = "x86_64")]
1964 {
1965 if is_x86_feature_detected!("avx512f") {
1966 return unsafe { fused_dot_norm_avx512(a, b, count) };
1967 }
1968 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1969 return unsafe { fused_dot_norm_avx2(a, b, count) };
1970 }
1971 if sse::is_available() {
1972 return unsafe { fused_dot_norm_sse(a, b, count) };
1973 }
1974 }
1975
1976 let mut dot = 0.0f32;
1978 let mut norm_b = 0.0f32;
1979 for i in 0..count {
1980 dot += a[i] * b[i];
1981 norm_b += b[i] * b[i];
1982 }
1983 (dot, norm_b)
1984}
1985
1986#[cfg(target_arch = "aarch64")]
1987#[target_feature(enable = "neon")]
1988#[allow(unsafe_op_in_unsafe_fn)]
1989unsafe fn fused_dot_norm_neon(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1990 use std::arch::aarch64::*;
1991
1992 let chunks16 = count / 16;
1993 let remainder = count % 16;
1994
1995 let mut d0 = vdupq_n_f32(0.0);
1996 let mut d1 = vdupq_n_f32(0.0);
1997 let mut d2 = vdupq_n_f32(0.0);
1998 let mut d3 = vdupq_n_f32(0.0);
1999 let mut n0 = vdupq_n_f32(0.0);
2000 let mut n1 = vdupq_n_f32(0.0);
2001 let mut n2 = vdupq_n_f32(0.0);
2002 let mut n3 = vdupq_n_f32(0.0);
2003
2004 for c in 0..chunks16 {
2005 let base = c * 16;
2006 let va0 = vld1q_f32(a.as_ptr().add(base));
2007 let vb0 = vld1q_f32(b.as_ptr().add(base));
2008 d0 = vfmaq_f32(d0, va0, vb0);
2009 n0 = vfmaq_f32(n0, vb0, vb0);
2010 let va1 = vld1q_f32(a.as_ptr().add(base + 4));
2011 let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
2012 d1 = vfmaq_f32(d1, va1, vb1);
2013 n1 = vfmaq_f32(n1, vb1, vb1);
2014 let va2 = vld1q_f32(a.as_ptr().add(base + 8));
2015 let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
2016 d2 = vfmaq_f32(d2, va2, vb2);
2017 n2 = vfmaq_f32(n2, vb2, vb2);
2018 let va3 = vld1q_f32(a.as_ptr().add(base + 12));
2019 let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
2020 d3 = vfmaq_f32(d3, va3, vb3);
2021 n3 = vfmaq_f32(n3, vb3, vb3);
2022 }
2023
2024 let acc_dot = vaddq_f32(vaddq_f32(d0, d1), vaddq_f32(d2, d3));
2025 let acc_norm = vaddq_f32(vaddq_f32(n0, n1), vaddq_f32(n2, n3));
2026 let mut dot = vaddvq_f32(acc_dot);
2027 let mut norm = vaddvq_f32(acc_norm);
2028
2029 let base = chunks16 * 16;
2030 for i in 0..remainder {
2031 dot += a[base + i] * b[base + i];
2032 norm += b[base + i] * b[base + i];
2033 }
2034
2035 (dot, norm)
2036}
2037
2038#[cfg(target_arch = "x86_64")]
2039#[target_feature(enable = "avx2", enable = "fma")]
2040#[allow(unsafe_op_in_unsafe_fn)]
2041unsafe fn fused_dot_norm_avx2(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
2042 use std::arch::x86_64::*;
2043
2044 let chunks32 = count / 32;
2045 let remainder = count % 32;
2046
2047 let mut d0 = _mm256_setzero_ps();
2048 let mut d1 = _mm256_setzero_ps();
2049 let mut d2 = _mm256_setzero_ps();
2050 let mut d3 = _mm256_setzero_ps();
2051 let mut n0 = _mm256_setzero_ps();
2052 let mut n1 = _mm256_setzero_ps();
2053 let mut n2 = _mm256_setzero_ps();
2054 let mut n3 = _mm256_setzero_ps();
2055
2056 for c in 0..chunks32 {
2057 let base = c * 32;
2058 let vb0 = _mm256_loadu_ps(b.as_ptr().add(base));
2059 d0 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base)), vb0, d0);
2060 n0 = _mm256_fmadd_ps(vb0, vb0, n0);
2061 let vb1 = _mm256_loadu_ps(b.as_ptr().add(base + 8));
2062 d1 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 8)), vb1, d1);
2063 n1 = _mm256_fmadd_ps(vb1, vb1, n1);
2064 let vb2 = _mm256_loadu_ps(b.as_ptr().add(base + 16));
2065 d2 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 16)), vb2, d2);
2066 n2 = _mm256_fmadd_ps(vb2, vb2, n2);
2067 let vb3 = _mm256_loadu_ps(b.as_ptr().add(base + 24));
2068 d3 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 24)), vb3, d3);
2069 n3 = _mm256_fmadd_ps(vb3, vb3, n3);
2070 }
2071
2072 let acc_dot = _mm256_add_ps(_mm256_add_ps(d0, d1), _mm256_add_ps(d2, d3));
2073 let acc_norm = _mm256_add_ps(_mm256_add_ps(n0, n1), _mm256_add_ps(n2, n3));
2074
2075 let hi_d = _mm256_extractf128_ps(acc_dot, 1);
2077 let lo_d = _mm256_castps256_ps128(acc_dot);
2078 let sum_d = _mm_add_ps(lo_d, hi_d);
2079 let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
2080 let sums_d = _mm_add_ps(sum_d, shuf_d);
2081 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2082 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2083
2084 let hi_n = _mm256_extractf128_ps(acc_norm, 1);
2085 let lo_n = _mm256_castps256_ps128(acc_norm);
2086 let sum_n = _mm_add_ps(lo_n, hi_n);
2087 let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
2088 let sums_n = _mm_add_ps(sum_n, shuf_n);
2089 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2090 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2091
2092 let base = chunks32 * 32;
2093 for i in 0..remainder {
2094 dot += a[base + i] * b[base + i];
2095 norm += b[base + i] * b[base + i];
2096 }
2097
2098 (dot, norm)
2099}
2100
2101#[cfg(target_arch = "x86_64")]
2102#[target_feature(enable = "sse")]
2103#[allow(unsafe_op_in_unsafe_fn)]
2104unsafe fn fused_dot_norm_sse(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
2105 use std::arch::x86_64::*;
2106
2107 let chunks = count / 4;
2108 let remainder = count % 4;
2109
2110 let mut acc_dot = _mm_setzero_ps();
2111 let mut acc_norm = _mm_setzero_ps();
2112
2113 for chunk in 0..chunks {
2114 let base = chunk * 4;
2115 let va = _mm_loadu_ps(a.as_ptr().add(base));
2116 let vb = _mm_loadu_ps(b.as_ptr().add(base));
2117 acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2118 acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2119 }
2120
2121 let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2123 let sums_d = _mm_add_ps(acc_dot, shuf_d);
2124 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2125 let final_d = _mm_add_ss(sums_d, shuf2_d);
2126 let mut dot = _mm_cvtss_f32(final_d);
2127
2128 let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2129 let sums_n = _mm_add_ps(acc_norm, shuf_n);
2130 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2131 let final_n = _mm_add_ss(sums_n, shuf2_n);
2132 let mut norm = _mm_cvtss_f32(final_n);
2133
2134 let base = chunks * 4;
2135 for i in 0..remainder {
2136 dot += a[base + i] * b[base + i];
2137 norm += b[base + i] * b[base + i];
2138 }
2139
2140 (dot, norm)
2141}
2142
2143#[inline]
2149pub fn fast_inv_sqrt(x: f32) -> f32 {
2150 let half = 0.5 * x;
2151 let i = 0x5F37_5A86_u32.wrapping_sub(x.to_bits() >> 1);
2152 let y = f32::from_bits(i);
2153 let y = y * (1.5 - half * y * y); y * (1.5 - half * y * y) }
2156
2157#[inline]
2168pub fn batch_cosine_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
2169 let n = scores.len();
2170 debug_assert!(vectors.len() >= n * dim);
2171 debug_assert_eq!(query.len(), dim);
2172
2173 if dim == 0 || n == 0 {
2174 return;
2175 }
2176
2177 let norm_q_sq = dot_product_f32(query, query, dim);
2179 if norm_q_sq < f32::EPSILON {
2180 for s in scores.iter_mut() {
2181 *s = 0.0;
2182 }
2183 return;
2184 }
2185 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2186
2187 for i in 0..n {
2188 let vec = &vectors[i * dim..(i + 1) * dim];
2189 let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
2190 if norm_v_sq < f32::EPSILON {
2191 scores[i] = 0.0;
2192 } else {
2193 scores[i] = dot * inv_norm_q * fast_inv_sqrt(norm_v_sq);
2194 }
2195 }
2196}
2197
2198#[inline]
2204pub fn f32_to_f16(value: f32) -> u16 {
2205 let bits = value.to_bits();
2206 let sign = (bits >> 16) & 0x8000;
2207 let exp = ((bits >> 23) & 0xFF) as i32;
2208 let mantissa = bits & 0x7F_FFFF;
2209
2210 if exp == 255 {
2211 return (sign | 0x7C00 | ((mantissa >> 13) & 0x3FF)) as u16;
2213 }
2214
2215 let exp16 = exp - 127 + 15;
2216
2217 if exp16 >= 31 {
2218 return (sign | 0x7C00) as u16; }
2220
2221 if exp16 <= 0 {
2222 if exp16 < -10 {
2223 return sign as u16; }
2225 let shift = (1 - exp16) as u32;
2226 let m = (mantissa | 0x80_0000) >> shift;
2227 let round_bit = (m >> 12) & 1;
2229 let sticky = m & 0xFFF;
2230 let m13 = m >> 13;
2231 let rounded = m13 + (round_bit & (m13 | if sticky != 0 { 1 } else { 0 }));
2232 return (sign | rounded) as u16;
2233 }
2234
2235 let round_bit = (mantissa >> 12) & 1;
2237 let sticky = mantissa & 0xFFF;
2238 let m13 = mantissa >> 13;
2239 let rounded = m13 + (round_bit & (m13 | if sticky != 0 { 1 } else { 0 }));
2240 if rounded > 0x3FF {
2242 let exp16_inc = exp16 as u32 + 1;
2243 if exp16_inc >= 31 {
2244 return (sign | 0x7C00) as u16; }
2246 (sign | (exp16_inc << 10)) as u16
2247 } else {
2248 (sign | ((exp16 as u32) << 10) | rounded) as u16
2249 }
2250}
2251
2252#[inline]
2254pub fn f16_to_f32(half: u16) -> f32 {
2255 let sign = ((half & 0x8000) as u32) << 16;
2256 let exp = ((half >> 10) & 0x1F) as u32;
2257 let mantissa = (half & 0x3FF) as u32;
2258
2259 if exp == 0 {
2260 if mantissa == 0 {
2261 return f32::from_bits(sign);
2262 }
2263 let mut e = 0u32;
2265 let mut m = mantissa;
2266 while (m & 0x400) == 0 {
2267 m <<= 1;
2268 e += 1;
2269 }
2270 return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | ((m & 0x3FF) << 13));
2271 }
2272
2273 if exp == 31 {
2274 return f32::from_bits(sign | 0x7F80_0000 | (mantissa << 13));
2275 }
2276
2277 f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mantissa << 13))
2278}
2279
2280const U8_SCALE: f32 = 127.5;
2285const U8_INV_SCALE: f32 = 1.0 / 127.5;
2286
2287#[inline]
2289pub fn f32_to_u8_saturating(value: f32) -> u8 {
2290 ((value.clamp(-1.0, 1.0) + 1.0) * U8_SCALE) as u8
2291}
2292
2293#[inline]
2295pub fn u8_to_f32(byte: u8) -> f32 {
2296 byte as f32 * U8_INV_SCALE - 1.0
2297}
2298
2299pub fn batch_f32_to_f16(src: &[f32], dst: &mut [u16]) {
2305 debug_assert_eq!(src.len(), dst.len());
2306 for (s, d) in src.iter().zip(dst.iter_mut()) {
2307 *d = f32_to_f16(*s);
2308 }
2309}
2310
2311pub fn batch_f32_to_u8(src: &[f32], dst: &mut [u8]) {
2313 debug_assert_eq!(src.len(), dst.len());
2314 for (s, d) in src.iter().zip(dst.iter_mut()) {
2315 *d = f32_to_u8_saturating(*s);
2316 }
2317}
2318
2319#[cfg(target_arch = "aarch64")]
2324#[allow(unsafe_op_in_unsafe_fn)]
2325mod neon_quant {
2326 use std::arch::aarch64::*;
2327
2328 #[allow(clippy::incompatible_msrv)]
2334 #[target_feature(enable = "neon")]
2335 pub unsafe fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2336 let chunks16 = dim / 16;
2337 let remainder = dim % 16;
2338
2339 let mut acc_dot0 = vdupq_n_f32(0.0);
2341 let mut acc_dot1 = vdupq_n_f32(0.0);
2342 let mut acc_norm0 = vdupq_n_f32(0.0);
2343 let mut acc_norm1 = vdupq_n_f32(0.0);
2344
2345 for c in 0..chunks16 {
2346 let base = c * 16;
2347
2348 let v_raw0 = vld1q_u16(vec_f16.as_ptr().add(base));
2350 let v_lo0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw0)));
2351 let v_hi0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw0)));
2352 let q_raw0 = vld1q_u16(query_f16.as_ptr().add(base));
2353 let q_lo0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw0)));
2354 let q_hi0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw0)));
2355
2356 acc_dot0 = vfmaq_f32(acc_dot0, q_lo0, v_lo0);
2357 acc_dot0 = vfmaq_f32(acc_dot0, q_hi0, v_hi0);
2358 acc_norm0 = vfmaq_f32(acc_norm0, v_lo0, v_lo0);
2359 acc_norm0 = vfmaq_f32(acc_norm0, v_hi0, v_hi0);
2360
2361 let v_raw1 = vld1q_u16(vec_f16.as_ptr().add(base + 8));
2363 let v_lo1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw1)));
2364 let v_hi1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw1)));
2365 let q_raw1 = vld1q_u16(query_f16.as_ptr().add(base + 8));
2366 let q_lo1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw1)));
2367 let q_hi1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw1)));
2368
2369 acc_dot1 = vfmaq_f32(acc_dot1, q_lo1, v_lo1);
2370 acc_dot1 = vfmaq_f32(acc_dot1, q_hi1, v_hi1);
2371 acc_norm1 = vfmaq_f32(acc_norm1, v_lo1, v_lo1);
2372 acc_norm1 = vfmaq_f32(acc_norm1, v_hi1, v_hi1);
2373 }
2374
2375 let mut dot = vaddvq_f32(vaddq_f32(acc_dot0, acc_dot1));
2377 let mut norm = vaddvq_f32(vaddq_f32(acc_norm0, acc_norm1));
2378
2379 let base = chunks16 * 16;
2381 for i in 0..remainder {
2382 let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2383 let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2384 dot += q * v;
2385 norm += v * v;
2386 }
2387
2388 (dot, norm)
2389 }
2390
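    /// Fused dot product and squared norm for a u8-quantized vector against an
    /// f32 query: bytes are widened to f32 and dequantized with scale/offset
    /// before accumulation.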
2391 #[target_feature(enable = "neon")]
2394 pub unsafe fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2395 let scale = vdupq_n_f32(super::U8_INV_SCALE);
2396 let offset = vdupq_n_f32(-1.0);
2397
2398 let chunks16 = dim / 16;
2399 let remainder = dim % 16;
2400
2401 let mut acc_dot = vdupq_n_f32(0.0);
2402 let mut acc_norm = vdupq_n_f32(0.0);
2403
2404 for c in 0..chunks16 {
2405 let base = c * 16;
2406
2407 let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2409
2410 let lo8 = vget_low_u8(bytes);
2412 let hi8 = vget_high_u8(bytes);
2413 let lo16 = vmovl_u8(lo8);
2414 let hi16 = vmovl_u8(hi8);
2415
2416 let f0 = vaddq_f32(
2417 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2418 offset,
2419 );
2420 let f1 = vaddq_f32(
2421 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2422 offset,
2423 );
2424 let f2 = vaddq_f32(
2425 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2426 offset,
2427 );
2428 let f3 = vaddq_f32(
2429 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2430 offset,
2431 );
2432
2433 let q0 = vld1q_f32(query.as_ptr().add(base));
2434 let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2435 let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2436 let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2437
2438 acc_dot = vfmaq_f32(acc_dot, q0, f0);
2439 acc_dot = vfmaq_f32(acc_dot, q1, f1);
2440 acc_dot = vfmaq_f32(acc_dot, q2, f2);
2441 acc_dot = vfmaq_f32(acc_dot, q3, f3);
2442
2443 acc_norm = vfmaq_f32(acc_norm, f0, f0);
2444 acc_norm = vfmaq_f32(acc_norm, f1, f1);
2445 acc_norm = vfmaq_f32(acc_norm, f2, f2);
2446 acc_norm = vfmaq_f32(acc_norm, f3, f3);
2447 }
2448
2449 let mut dot = vaddvq_f32(acc_dot);
2450 let mut norm = vaddvq_f32(acc_norm);
2451
2452 let base = chunks16 * 16;
2453 for i in 0..remainder {
2454 let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2455 dot += *query.get_unchecked(base + i) * v;
2456 norm += v * v;
2457 }
2458
2459 (dot, norm)
2460 }
2461
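    /// Dot product of two f16 vectors, converting 8 elements per iteration to f32.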
2462 #[allow(clippy::incompatible_msrv)]
2464 #[target_feature(enable = "neon")]
2465 pub unsafe fn dot_product_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2466 let chunks8 = dim / 8;
2467 let remainder = dim % 8;
2468
2469 let mut acc = vdupq_n_f32(0.0);
2470
2471 for c in 0..chunks8 {
2472 let base = c * 8;
2473 let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
2474 let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
2475 let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
2476 let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
2477 let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
2478 let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
2479 acc = vfmaq_f32(acc, q_lo, v_lo);
2480 acc = vfmaq_f32(acc, q_hi, v_hi);
2481 }
2482
2483 let mut dot = vaddvq_f32(acc);
2484 let base = chunks8 * 8;
2485 for i in 0..remainder {
2486 let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2487 let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2488 dot += q * v;
2489 }
2490 dot
2491 }
2492
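    /// Dot product of an f32 query against a u8-quantized vector, dequantizing
    /// 16 bytes per iteration.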
2493 #[target_feature(enable = "neon")]
2495 pub unsafe fn dot_product_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2496 let scale = vdupq_n_f32(super::U8_INV_SCALE);
2497 let offset = vdupq_n_f32(-1.0);
2498 let chunks16 = dim / 16;
2499 let remainder = dim % 16;
2500
2501 let mut acc = vdupq_n_f32(0.0);
2502
2503 for c in 0..chunks16 {
2504 let base = c * 16;
2505 let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2506 let lo8 = vget_low_u8(bytes);
2507 let hi8 = vget_high_u8(bytes);
2508 let lo16 = vmovl_u8(lo8);
2509 let hi16 = vmovl_u8(hi8);
2510 let f0 = vaddq_f32(
2511 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2512 offset,
2513 );
2514 let f1 = vaddq_f32(
2515 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2516 offset,
2517 );
2518 let f2 = vaddq_f32(
2519 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2520 offset,
2521 );
2522 let f3 = vaddq_f32(
2523 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2524 offset,
2525 );
2526 let q0 = vld1q_f32(query.as_ptr().add(base));
2527 let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2528 let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2529 let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2530 acc = vfmaq_f32(acc, q0, f0);
2531 acc = vfmaq_f32(acc, q1, f1);
2532 acc = vfmaq_f32(acc, q2, f2);
2533 acc = vfmaq_f32(acc, q3, f3);
2534 }
2535
2536 let mut dot = vaddvq_f32(acc);
2537 let base = chunks16 * 16;
2538 for i in 0..remainder {
2539 let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2540 dot += *query.get_unchecked(base + i) * v;
2541 }
2542 dot
2543 }
2544}
2545
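// Scalar fallbacks, used by the dispatchers below when no SIMD path applies.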
2546#[allow(dead_code)]
2551fn fused_dot_norm_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2552 let mut dot = 0.0f32;
2553 let mut norm = 0.0f32;
2554 for i in 0..dim {
2555 let v = f16_to_f32(vec_f16[i]);
2556 let q = f16_to_f32(query_f16[i]);
2557 dot += q * v;
2558 norm += v * v;
2559 }
2560 (dot, norm)
2561}
2562
2563#[allow(dead_code)]
2564fn fused_dot_norm_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2565 let mut dot = 0.0f32;
2566 let mut norm = 0.0f32;
2567 for i in 0..dim {
2568 let v = u8_to_f32(vec_u8[i]);
2569 dot += query[i] * v;
2570 norm += v * v;
2571 }
2572 (dot, norm)
2573}
2574
2575#[allow(dead_code)]
2576fn dot_product_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2577 let mut dot = 0.0f32;
2578 for i in 0..dim {
2579 dot += f16_to_f32(query_f16[i]) * f16_to_f32(vec_f16[i]);
2580 }
2581 dot
2582}
2583
2584#[allow(dead_code)]
2585fn dot_product_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2586 let mut dot = 0.0f32;
2587 for i in 0..dim {
2588 dot += query[i] * u8_to_f32(vec_u8[i]);
2589 }
2590 dot
2591}
2592
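// SSE fallback: f16 values are converted to f32 one at a time (no F16C), then
// processed four lanes per iteration.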
2593#[cfg(target_arch = "x86_64")]
2598#[target_feature(enable = "sse2", enable = "sse4.1")]
2599#[allow(unsafe_op_in_unsafe_fn)]
2600unsafe fn fused_dot_norm_f16_sse(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2601 use std::arch::x86_64::*;
2602
2603 let chunks = dim / 4;
2604 let remainder = dim % 4;
2605
2606 let mut acc_dot = _mm_setzero_ps();
2607 let mut acc_norm = _mm_setzero_ps();
2608
2609 for chunk in 0..chunks {
2610 let base = chunk * 4;
2611 let v0 = f16_to_f32(*vec_f16.get_unchecked(base));
2613 let v1 = f16_to_f32(*vec_f16.get_unchecked(base + 1));
2614 let v2 = f16_to_f32(*vec_f16.get_unchecked(base + 2));
2615 let v3 = f16_to_f32(*vec_f16.get_unchecked(base + 3));
2616 let vb = _mm_set_ps(v3, v2, v1, v0);
2617
2618 let q0 = f16_to_f32(*query_f16.get_unchecked(base));
2619 let q1 = f16_to_f32(*query_f16.get_unchecked(base + 1));
2620 let q2 = f16_to_f32(*query_f16.get_unchecked(base + 2));
2621 let q3 = f16_to_f32(*query_f16.get_unchecked(base + 3));
2622 let va = _mm_set_ps(q3, q2, q1, q0);
2623
2624 acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2625 acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2626 }
2627
2628 let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2630 let sums_d = _mm_add_ps(acc_dot, shuf_d);
2631 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2632 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2633
2634 let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2635 let sums_n = _mm_add_ps(acc_norm, shuf_n);
2636 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2637 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2638
2639 let base = chunks * 4;
2640 for i in 0..remainder {
2641 let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2642 let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2643 dot += q * v;
2644 norm += v * v;
2645 }
2646
2647 (dot, norm)
2648}
2649
2650#[cfg(target_arch = "x86_64")]
2651#[target_feature(enable = "sse2", enable = "sse4.1")]
2652#[allow(unsafe_op_in_unsafe_fn)]
2653unsafe fn fused_dot_norm_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2654 use std::arch::x86_64::*;
2655
2656 let scale = _mm_set1_ps(U8_INV_SCALE);
2657 let offset = _mm_set1_ps(-1.0);
2658
2659 let chunks = dim / 4;
2660 let remainder = dim % 4;
2661
2662 let mut acc_dot = _mm_setzero_ps();
2663 let mut acc_norm = _mm_setzero_ps();
2664
2665 for chunk in 0..chunks {
2666 let base = chunk * 4;
2667
2668 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2670 vec_u8.as_ptr().add(base) as *const i32
2671 ));
2672 let ints = _mm_cvtepu8_epi32(bytes);
2673 let floats = _mm_cvtepi32_ps(ints);
2674 let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2675
2676 let va = _mm_loadu_ps(query.as_ptr().add(base));
2677
2678 acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2679 acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2680 }
2681
2682 let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2684 let sums_d = _mm_add_ps(acc_dot, shuf_d);
2685 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2686 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2687
2688 let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2689 let sums_n = _mm_add_ps(acc_norm, shuf_n);
2690 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2691 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2692
2693 let base = chunks * 4;
2694 for i in 0..remainder {
2695 let v = u8_to_f32(*vec_u8.get_unchecked(base + i));
2696 dot += *query.get_unchecked(base + i) * v;
2697 norm += v * v;
2698 }
2699
2700 (dot, norm)
2701}
2702
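// AVX + F16C path: converts 8 f16 values per load with `_mm256_cvtph_ps` and
// accumulates with FMA, 16 elements per iteration.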
2703#[cfg(target_arch = "x86_64")]
2708#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2709#[allow(unsafe_op_in_unsafe_fn)]
2710unsafe fn fused_dot_norm_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2711 use std::arch::x86_64::*;
2712
2713 let chunks16 = dim / 16;
2714 let remainder = dim % 16;
2715
2716 let mut acc_dot0 = _mm256_setzero_ps();
2718 let mut acc_dot1 = _mm256_setzero_ps();
2719 let mut acc_norm0 = _mm256_setzero_ps();
2720 let mut acc_norm1 = _mm256_setzero_ps();
2721
2722 for c in 0..chunks16 {
2723 let base = c * 16;
2724
2725 let v_raw0 = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2727 let vb0 = _mm256_cvtph_ps(v_raw0);
2728 let q_raw0 = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2729 let qa0 = _mm256_cvtph_ps(q_raw0);
2730 acc_dot0 = _mm256_fmadd_ps(qa0, vb0, acc_dot0);
2731 acc_norm0 = _mm256_fmadd_ps(vb0, vb0, acc_norm0);
2732
2733 let v_raw1 = _mm_loadu_si128(vec_f16.as_ptr().add(base + 8) as *const __m128i);
2735 let vb1 = _mm256_cvtph_ps(v_raw1);
2736 let q_raw1 = _mm_loadu_si128(query_f16.as_ptr().add(base + 8) as *const __m128i);
2737 let qa1 = _mm256_cvtph_ps(q_raw1);
2738 acc_dot1 = _mm256_fmadd_ps(qa1, vb1, acc_dot1);
2739 acc_norm1 = _mm256_fmadd_ps(vb1, vb1, acc_norm1);
2740 }
2741
2742 let acc_dot = _mm256_add_ps(acc_dot0, acc_dot1);
2744 let acc_norm = _mm256_add_ps(acc_norm0, acc_norm1);
2745
2746 let hi_d = _mm256_extractf128_ps(acc_dot, 1);
2748 let lo_d = _mm256_castps256_ps128(acc_dot);
2749 let sum_d = _mm_add_ps(lo_d, hi_d);
2750 let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
2751 let sums_d = _mm_add_ps(sum_d, shuf_d);
2752 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2753 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2754
2755 let hi_n = _mm256_extractf128_ps(acc_norm, 1);
2756 let lo_n = _mm256_castps256_ps128(acc_norm);
2757 let sum_n = _mm_add_ps(lo_n, hi_n);
2758 let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
2759 let sums_n = _mm_add_ps(sum_n, shuf_n);
2760 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2761 let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2762
2763 let base = chunks16 * 16;
2764 for i in 0..remainder {
2765 let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2766 let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2767 dot += q * v;
2768 norm += v * v;
2769 }
2770
2771 (dot, norm)
2772}
2773
2774#[cfg(target_arch = "x86_64")]
2775#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2776#[allow(unsafe_op_in_unsafe_fn)]
2777unsafe fn dot_product_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2778 use std::arch::x86_64::*;
2779
2780 let chunks = dim / 8;
2781 let remainder = dim % 8;
2782 let mut acc = _mm256_setzero_ps();
2783
2784 for chunk in 0..chunks {
2785 let base = chunk * 8;
2786 let v_raw = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2787 let vb = _mm256_cvtph_ps(v_raw);
2788 let q_raw = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2789 let qa = _mm256_cvtph_ps(q_raw);
2790 acc = _mm256_fmadd_ps(qa, vb, acc);
2791 }
2792
2793 let hi = _mm256_extractf128_ps(acc, 1);
2794 let lo = _mm256_castps256_ps128(acc);
2795 let sum = _mm_add_ps(lo, hi);
2796 let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
2797 let sums = _mm_add_ps(sum, shuf);
2798 let shuf2 = _mm_movehl_ps(sums, sums);
2799 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2800
2801 let base = chunks * 8;
2802 for i in 0..remainder {
2803 let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2804 let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2805 dot += q * v;
2806 }
2807 dot
2808}
2809
2810#[cfg(target_arch = "x86_64")]
2811#[target_feature(enable = "sse2", enable = "sse4.1")]
2812#[allow(unsafe_op_in_unsafe_fn)]
2813unsafe fn dot_product_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2814 use std::arch::x86_64::*;
2815
2816 let scale = _mm_set1_ps(U8_INV_SCALE);
2817 let offset = _mm_set1_ps(-1.0);
2818 let chunks = dim / 4;
2819 let remainder = dim % 4;
2820 let mut acc = _mm_setzero_ps();
2821
2822 for chunk in 0..chunks {
2823 let base = chunk * 4;
2824 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2825 vec_u8.as_ptr().add(base) as *const i32
2826 ));
2827 let ints = _mm_cvtepu8_epi32(bytes);
2828 let floats = _mm_cvtepi32_ps(ints);
2829 let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2830 let va = _mm_loadu_ps(query.as_ptr().add(base));
2831 acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
2832 }
2833
2834 let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
2835 let sums = _mm_add_ps(acc, shuf);
2836 let shuf2 = _mm_movehl_ps(sums, sums);
2837 let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2838
2839 let base = chunks * 4;
2840 for i in 0..remainder {
2841 dot += *query.get_unchecked(base + i) * u8_to_f32(*vec_u8.get_unchecked(base + i));
2842 }
2843 dot
2844}
2845
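/// Dispatches the fused f16 dot/norm kernel to NEON, F16C+FMA, SSE, or the
/// scalar fallback depending on the target and detected CPU features.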
2846#[inline]
2851fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2852 #[cfg(target_arch = "aarch64")]
2853 {
2854 return unsafe { neon_quant::fused_dot_norm_f16(query_f16, vec_f16, dim) };
2855 }
2856
2857 #[cfg(target_arch = "x86_64")]
2858 {
2859 if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2860 return unsafe { fused_dot_norm_f16_f16c(query_f16, vec_f16, dim) };
2861 }
2862 if sse::is_available() {
2863 return unsafe { fused_dot_norm_f16_sse(query_f16, vec_f16, dim) };
2864 }
2865 }
2866
2867 #[allow(unreachable_code)]
2868 fused_dot_norm_f16_scalar(query_f16, vec_f16, dim)
2869}
2870
2871#[inline]
2872fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2873 #[cfg(target_arch = "aarch64")]
2874 {
2875 return unsafe { neon_quant::fused_dot_norm_u8(query, vec_u8, dim) };
2876 }
2877
2878 #[cfg(target_arch = "x86_64")]
2879 {
2880 if sse::is_available() {
2881 return unsafe { fused_dot_norm_u8_sse(query, vec_u8, dim) };
2882 }
2883 }
2884
2885 #[allow(unreachable_code)]
2886 fused_dot_norm_u8_scalar(query, vec_u8, dim)
2887}
2888
2889#[inline]
2892fn dot_product_f16_quant(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2893 #[cfg(target_arch = "aarch64")]
2894 {
2895 return unsafe { neon_quant::dot_product_f16(query_f16, vec_f16, dim) };
2896 }
2897
2898 #[cfg(target_arch = "x86_64")]
2899 {
2900 if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2901 return unsafe { dot_product_f16_f16c(query_f16, vec_f16, dim) };
2902 }
2903 }
2904
2905 #[allow(unreachable_code)]
2906 dot_product_f16_scalar(query_f16, vec_f16, dim)
2907}
2908
2909#[inline]
2910fn dot_product_u8_quant(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2911 #[cfg(target_arch = "aarch64")]
2912 {
2913 return unsafe { neon_quant::dot_product_u8(query, vec_u8, dim) };
2914 }
2915
2916 #[cfg(target_arch = "x86_64")]
2917 {
2918 if sse::is_available() {
2919 return unsafe { dot_product_u8_sse(query, vec_u8, dim) };
2920 }
2921 }
2922
2923 #[allow(unreachable_code)]
2924 dot_product_u8_scalar(query, vec_u8, dim)
2925}
2926
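/// Cosine similarity between an f32 query and `n` f16-encoded vectors stored
/// contiguously in `vectors_raw` (2 bytes per element). `vectors_raw` must be
/// 2-byte aligned (checked with a debug assertion); vectors with a near-zero
/// norm score 0.0.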
2927#[inline]
2938pub fn batch_cosine_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2939 let n = scores.len();
2940 if dim == 0 || n == 0 {
2941 return;
2942 }
2943
2944 let norm_q_sq = dot_product_f32(query, query, dim);
2946 if norm_q_sq < f32::EPSILON {
2947 for s in scores.iter_mut() {
2948 *s = 0.0;
2949 }
2950 return;
2951 }
2952 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2953
2954 let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
2956
2957 let vec_bytes = dim * 2;
2958 debug_assert!(vectors_raw.len() >= n * vec_bytes);
2959
2960 debug_assert!(
2963 (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
2964 "f16 vector data not 2-byte aligned"
2965 );
2966
2967 for i in 0..n {
2968 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2969 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2970
2971 let (dot, norm_v_sq) = fused_dot_norm_f16(&query_f16, f16_slice, dim);
2972 scores[i] = if norm_v_sq < f32::EPSILON {
2973 0.0
2974 } else {
2975 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2976 };
2977 }
2978}
2979
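/// Cosine similarity between an f32 query and `n` u8-quantized vectors stored
/// contiguously in `vectors_raw` (1 byte per element).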
2980#[inline]
2986pub fn batch_cosine_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2987 let n = scores.len();
2988 if dim == 0 || n == 0 {
2989 return;
2990 }
2991
2992 let norm_q_sq = dot_product_f32(query, query, dim);
2993 if norm_q_sq < f32::EPSILON {
2994 for s in scores.iter_mut() {
2995 *s = 0.0;
2996 }
2997 return;
2998 }
2999 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3000
3001 debug_assert!(vectors_raw.len() >= n * dim);
3002
3003 for i in 0..n {
3004 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3005
3006 let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
3007 scores[i] = if norm_v_sq < f32::EPSILON {
3008 0.0
3009 } else {
3010 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3011 };
3012 }
3013}
3014
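/// Dot-product scores between an f32 query and `n` contiguous f32 vectors,
/// scaled by the inverse norm of the query.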
3015#[inline]
3024pub fn batch_dot_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
3025 let n = scores.len();
3026 debug_assert!(vectors.len() >= n * dim);
3027 debug_assert_eq!(query.len(), dim);
3028
3029 if dim == 0 || n == 0 {
3030 return;
3031 }
3032
3033 let norm_q_sq = dot_product_f32(query, query, dim);
3034 if norm_q_sq < f32::EPSILON {
3035 for s in scores.iter_mut() {
3036 *s = 0.0;
3037 }
3038 return;
3039 }
3040 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3041
3042 for i in 0..n {
3043 let vec = &vectors[i * dim..(i + 1) * dim];
3044 let dot = dot_product_f32(query, vec, dim);
3045 scores[i] = dot * inv_norm_q;
3046 }
3047}
3048
3049#[inline]
3054pub fn batch_dot_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
3055 let n = scores.len();
3056 if dim == 0 || n == 0 {
3057 return;
3058 }
3059
3060 let norm_q_sq = dot_product_f32(query, query, dim);
3061 if norm_q_sq < f32::EPSILON {
3062 for s in scores.iter_mut() {
3063 *s = 0.0;
3064 }
3065 return;
3066 }
3067 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3068
3069 let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
3070 let vec_bytes = dim * 2;
3071 debug_assert!(vectors_raw.len() >= n * vec_bytes);
3072 debug_assert!(
3073 (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
3074 "f16 vector data not 2-byte aligned"
3075 );
3076
3077 for i in 0..n {
3078 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3079 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3080 let dot = dot_product_f16_quant(&query_f16, f16_slice, dim);
3081 scores[i] = dot * inv_norm_q;
3082 }
3083}
3084
3085#[inline]
3090pub fn batch_dot_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
3091 let n = scores.len();
3092 if dim == 0 || n == 0 {
3093 return;
3094 }
3095
3096 let norm_q_sq = dot_product_f32(query, query, dim);
3097 if norm_q_sq < f32::EPSILON {
3098 for s in scores.iter_mut() {
3099 *s = 0.0;
3100 }
3101 return;
3102 }
3103 let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3104
3105 debug_assert!(vectors_raw.len() >= n * dim);
3106
3107 for i in 0..n {
3108 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3109 let dot = dot_product_u8_quant(query, u8_slice, dim);
3110 scores[i] = dot * inv_norm_q;
3111 }
3112}
3113
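/// Variant of `batch_cosine_scores` that takes a precomputed inverse query
/// norm; the `_f16`, `_u8`, and dot-product `_precomp` variants below follow
/// the same pattern.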
3114#[inline]
3120pub fn batch_cosine_scores_precomp(
3121 query: &[f32],
3122 vectors: &[f32],
3123 dim: usize,
3124 scores: &mut [f32],
3125 inv_norm_q: f32,
3126) {
3127 let n = scores.len();
3128 debug_assert!(vectors.len() >= n * dim);
3129 for i in 0..n {
3130 let vec = &vectors[i * dim..(i + 1) * dim];
3131 let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
3132 scores[i] = if norm_v_sq < f32::EPSILON {
3133 0.0
3134 } else {
3135 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3136 };
3137 }
3138}
3139
3140#[inline]
3142pub fn batch_cosine_scores_f16_precomp(
3143 query_f16: &[u16],
3144 vectors_raw: &[u8],
3145 dim: usize,
3146 scores: &mut [f32],
3147 inv_norm_q: f32,
3148) {
3149 let n = scores.len();
3150 let vec_bytes = dim * 2;
3151 debug_assert!(vectors_raw.len() >= n * vec_bytes);
3152 for i in 0..n {
3153 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3154 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3155 let (dot, norm_v_sq) = fused_dot_norm_f16(query_f16, f16_slice, dim);
3156 scores[i] = if norm_v_sq < f32::EPSILON {
3157 0.0
3158 } else {
3159 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3160 };
3161 }
3162}
3163
3164#[inline]
3166pub fn batch_cosine_scores_u8_precomp(
3167 query: &[f32],
3168 vectors_raw: &[u8],
3169 dim: usize,
3170 scores: &mut [f32],
3171 inv_norm_q: f32,
3172) {
3173 let n = scores.len();
3174 debug_assert!(vectors_raw.len() >= n * dim);
3175 for i in 0..n {
3176 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3177 let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
3178 scores[i] = if norm_v_sq < f32::EPSILON {
3179 0.0
3180 } else {
3181 dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3182 };
3183 }
3184}
3185
3186#[inline]
3188pub fn batch_dot_scores_precomp(
3189 query: &[f32],
3190 vectors: &[f32],
3191 dim: usize,
3192 scores: &mut [f32],
3193 inv_norm_q: f32,
3194) {
3195 let n = scores.len();
3196 debug_assert!(vectors.len() >= n * dim);
3197 for i in 0..n {
3198 let vec = &vectors[i * dim..(i + 1) * dim];
3199 scores[i] = dot_product_f32(query, vec, dim) * inv_norm_q;
3200 }
3201}
3202
3203#[inline]
3205pub fn batch_dot_scores_f16_precomp(
3206 query_f16: &[u16],
3207 vectors_raw: &[u8],
3208 dim: usize,
3209 scores: &mut [f32],
3210 inv_norm_q: f32,
3211) {
3212 let n = scores.len();
3213 let vec_bytes = dim * 2;
3214 debug_assert!(vectors_raw.len() >= n * vec_bytes);
3215 for i in 0..n {
3216 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3217 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3218 scores[i] = dot_product_f16_quant(query_f16, f16_slice, dim) * inv_norm_q;
3219 }
3220}
3221
3222#[inline]
3224pub fn batch_dot_scores_u8_precomp(
3225 query: &[f32],
3226 vectors_raw: &[u8],
3227 dim: usize,
3228 scores: &mut [f32],
3229 inv_norm_q: f32,
3230) {
3231 let n = scores.len();
3232 debug_assert!(vectors_raw.len() >= n * dim);
3233 for i in 0..n {
3234 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3235 scores[i] = dot_product_u8_quant(query, u8_slice, dim) * inv_norm_q;
3236 }
3237}
3238
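/// Cosine similarity between two f32 slices; returns 0.0 for empty inputs or
/// when either norm is near zero.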
3239#[inline]
3244pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
3245 debug_assert_eq!(a.len(), b.len());
3246 let count = a.len();
3247
3248 if count == 0 {
3249 return 0.0;
3250 }
3251
3252 let dot = dot_product_f32(a, b, count);
3253 let norm_a = dot_product_f32(a, a, count);
3254 let norm_b = dot_product_f32(b, b, count);
3255
3256 let denom = (norm_a * norm_b).sqrt();
3257 if denom < f32::EPSILON {
3258 return 0.0;
3259 }
3260
3261 dot / denom
3262}
3263
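/// Squared Euclidean distance between two f32 slices, dispatched to NEON,
/// AVX2+FMA, or SSE when available.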
3264#[inline]
3268pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
3269 debug_assert_eq!(a.len(), b.len());
3270 let count = a.len();
3271
3272 if count == 0 {
3273 return 0.0;
3274 }
3275
3276 #[cfg(target_arch = "aarch64")]
3277 {
3278 if neon::is_available() {
3279 return unsafe { squared_euclidean_neon(a, b, count) };
3280 }
3281 }
3282
3283 #[cfg(target_arch = "x86_64")]
3284 {
3285 if avx2::is_available() && is_x86_feature_detected!("fma") {
3286 return unsafe { squared_euclidean_avx2(a, b, count) };
3287 }
3288 if sse::is_available() {
3289 return unsafe { squared_euclidean_sse(a, b, count) };
3290 }
3291 }
3292
3293 a.iter()
3295 .zip(b.iter())
3296 .map(|(&x, &y)| {
3297 let d = x - y;
3298 d * d
3299 })
3300 .sum()
3301}
3302
3303#[cfg(target_arch = "aarch64")]
3304#[target_feature(enable = "neon")]
3305#[allow(unsafe_op_in_unsafe_fn)]
3306unsafe fn squared_euclidean_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
3307 use std::arch::aarch64::*;
3308
3309 let chunks16 = count / 16;
3310 let remainder = count % 16;
3311
3312 let mut acc0 = vdupq_n_f32(0.0);
3314 let mut acc1 = vdupq_n_f32(0.0);
3315 let mut acc2 = vdupq_n_f32(0.0);
3316 let mut acc3 = vdupq_n_f32(0.0);
3317
3318 for c in 0..chunks16 {
3319 let base = c * 16;
3320 let va0 = vld1q_f32(a.as_ptr().add(base));
3321 let vb0 = vld1q_f32(b.as_ptr().add(base));
3322 let d0 = vsubq_f32(va0, vb0);
3323 acc0 = vfmaq_f32(acc0, d0, d0);
3324
3325 let va1 = vld1q_f32(a.as_ptr().add(base + 4));
3326 let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
3327 let d1 = vsubq_f32(va1, vb1);
3328 acc1 = vfmaq_f32(acc1, d1, d1);
3329
3330 let va2 = vld1q_f32(a.as_ptr().add(base + 8));
3331 let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
3332 let d2 = vsubq_f32(va2, vb2);
3333 acc2 = vfmaq_f32(acc2, d2, d2);
3334
3335 let va3 = vld1q_f32(a.as_ptr().add(base + 12));
3336 let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
3337 let d3 = vsubq_f32(va3, vb3);
3338 acc3 = vfmaq_f32(acc3, d3, d3);
3339 }
3340
3341 let combined = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
3343 let mut sum = vaddvq_f32(combined);
3344
3345 let base = chunks16 * 16;
3347 for i in 0..remainder {
3348 let d = a[base + i] - b[base + i];
3349 sum += d * d;
3350 }
3351
3352 sum
3353}
3354
3355#[cfg(target_arch = "x86_64")]
3356#[target_feature(enable = "sse")]
3357#[allow(unsafe_op_in_unsafe_fn)]
3358unsafe fn squared_euclidean_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
3359 use std::arch::x86_64::*;
3360
3361 let chunks16 = count / 16;
3362 let remainder = count % 16;
3363
3364 let mut acc0 = _mm_setzero_ps();
3366 let mut acc1 = _mm_setzero_ps();
3367 let mut acc2 = _mm_setzero_ps();
3368 let mut acc3 = _mm_setzero_ps();
3369
3370 for c in 0..chunks16 {
3371 let base = c * 16;
3372
3373 let d0 = _mm_sub_ps(
3374 _mm_loadu_ps(a.as_ptr().add(base)),
3375 _mm_loadu_ps(b.as_ptr().add(base)),
3376 );
3377 acc0 = _mm_add_ps(acc0, _mm_mul_ps(d0, d0));
3378
3379 let d1 = _mm_sub_ps(
3380 _mm_loadu_ps(a.as_ptr().add(base + 4)),
3381 _mm_loadu_ps(b.as_ptr().add(base + 4)),
3382 );
3383 acc1 = _mm_add_ps(acc1, _mm_mul_ps(d1, d1));
3384
3385 let d2 = _mm_sub_ps(
3386 _mm_loadu_ps(a.as_ptr().add(base + 8)),
3387 _mm_loadu_ps(b.as_ptr().add(base + 8)),
3388 );
3389 acc2 = _mm_add_ps(acc2, _mm_mul_ps(d2, d2));
3390
3391 let d3 = _mm_sub_ps(
3392 _mm_loadu_ps(a.as_ptr().add(base + 12)),
3393 _mm_loadu_ps(b.as_ptr().add(base + 12)),
3394 );
3395 acc3 = _mm_add_ps(acc3, _mm_mul_ps(d3, d3));
3396 }
3397
3398 let combined = _mm_add_ps(_mm_add_ps(acc0, acc1), _mm_add_ps(acc2, acc3));
3400
3401 let shuf = _mm_shuffle_ps(combined, combined, 0b10_11_00_01);
3403 let sums = _mm_add_ps(combined, shuf);
3404 let shuf2 = _mm_movehl_ps(sums, sums);
3405 let final_sum = _mm_add_ss(sums, shuf2);
3406
3407 let mut sum = _mm_cvtss_f32(final_sum);
3408
3409 let base = chunks16 * 16;
3411 for i in 0..remainder {
3412 let d = a[base + i] - b[base + i];
3413 sum += d * d;
3414 }
3415
3416 sum
3417}
3418
3419#[cfg(target_arch = "x86_64")]
3420#[target_feature(enable = "avx2", enable = "fma")]
3421#[allow(unsafe_op_in_unsafe_fn)]
3422unsafe fn squared_euclidean_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
3423 use std::arch::x86_64::*;
3424
3425 let chunks32 = count / 32;
3426 let remainder = count % 32;
3427
3428 let mut acc0 = _mm256_setzero_ps();
3430 let mut acc1 = _mm256_setzero_ps();
3431 let mut acc2 = _mm256_setzero_ps();
3432 let mut acc3 = _mm256_setzero_ps();
3433
3434 for c in 0..chunks32 {
3435 let base = c * 32;
3436
3437 let d0 = _mm256_sub_ps(
3438 _mm256_loadu_ps(a.as_ptr().add(base)),
3439 _mm256_loadu_ps(b.as_ptr().add(base)),
3440 );
3441 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
3442
3443 let d1 = _mm256_sub_ps(
3444 _mm256_loadu_ps(a.as_ptr().add(base + 8)),
3445 _mm256_loadu_ps(b.as_ptr().add(base + 8)),
3446 );
3447 acc1 = _mm256_fmadd_ps(d1, d1, acc1);
3448
3449 let d2 = _mm256_sub_ps(
3450 _mm256_loadu_ps(a.as_ptr().add(base + 16)),
3451 _mm256_loadu_ps(b.as_ptr().add(base + 16)),
3452 );
3453 acc2 = _mm256_fmadd_ps(d2, d2, acc2);
3454
3455 let d3 = _mm256_sub_ps(
3456 _mm256_loadu_ps(a.as_ptr().add(base + 24)),
3457 _mm256_loadu_ps(b.as_ptr().add(base + 24)),
3458 );
3459 acc3 = _mm256_fmadd_ps(d3, d3, acc3);
3460 }
3461
3462 let combined = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
3464
3465 let high = _mm256_extractf128_ps(combined, 1);
3467 let low = _mm256_castps256_ps128(combined);
3468 let sum128 = _mm_add_ps(low, high);
3469
3470 let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
3471 let sums = _mm_add_ps(sum128, shuf);
3472 let shuf2 = _mm_movehl_ps(sums, sums);
3473 let final_sum = _mm_add_ss(sums, shuf2);
3474
3475 let mut sum = _mm_cvtss_f32(final_sum);
3476
3477 let base = chunks32 * 32;
3479 for i in 0..remainder {
3480 let d = a[base + i] - b[base + i];
3481 sum += d * d;
3482 }
3483
3484 sum
3485}
3486
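/// Squared Euclidean distances from `query` to each vector in `vectors`,
/// written into `distances`; uses the AVX2 kernel directly when available,
/// otherwise falls back to `squared_euclidean_distance` per vector.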
3487#[inline]
3493pub fn batch_squared_euclidean_distances(
3494 query: &[f32],
3495 vectors: &[Vec<f32>],
3496 distances: &mut [f32],
3497) {
3498 debug_assert_eq!(vectors.len(), distances.len());
3499
3500 #[cfg(target_arch = "x86_64")]
3501 {
3502 if avx2::is_available() && is_x86_feature_detected!("fma") {
3503 for (i, vec) in vectors.iter().enumerate() {
3504 distances[i] = unsafe { squared_euclidean_avx2(query, vec, query.len()) };
3505 }
3506 return;
3507 }
3508 }
3509
3510 for (i, vec) in vectors.iter().enumerate() {
3512 distances[i] = squared_euclidean_distance(query, vec);
3513 }
3514}
3515
3516#[cfg(test)]
3517mod tests {
3518 use super::*;
3519
3520 #[test]
3521 fn test_unpack_8bit() {
3522 let input: Vec<u8> = (0..128).collect();
3523 let mut output = vec![0u32; 128];
3524 unpack_8bit(&input, &mut output, 128);
3525
3526 for (i, &v) in output.iter().enumerate() {
3527 assert_eq!(v, i as u32);
3528 }
3529 }
3530
3531 #[test]
3532 fn test_unpack_16bit() {
3533 let mut input = vec![0u8; 256];
3534 for i in 0..128 {
3535 let val = (i * 100) as u16;
3536 input[i * 2] = val as u8;
3537 input[i * 2 + 1] = (val >> 8) as u8;
3538 }
3539
3540 let mut output = vec![0u32; 128];
3541 unpack_16bit(&input, &mut output, 128);
3542
3543 for (i, &v) in output.iter().enumerate() {
3544 assert_eq!(v, (i * 100) as u32);
3545 }
3546 }
3547
3548 #[test]
3549 fn test_unpack_32bit() {
3550 let mut input = vec![0u8; 512];
3551 for i in 0..128 {
3552 let val = (i * 1000) as u32;
3553 let bytes = val.to_le_bytes();
3554 input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
3555 }
3556
3557 let mut output = vec![0u32; 128];
3558 unpack_32bit(&input, &mut output, 128);
3559
3560 for (i, &v) in output.iter().enumerate() {
3561 assert_eq!(v, (i * 1000) as u32);
3562 }
3563 }
3564
3565 #[test]
3566 fn test_delta_decode() {
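        // Stored deltas are (gap - 1); decoding adds the 1 back, so the
        // expected doc IDs are 10, 15, 20, 30, 50.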
3567 let deltas = vec![4u32, 4, 9, 19];
3571 let mut output = vec![0u32; 5];
3572
3573 delta_decode(&mut output, &deltas, 10, 5);
3574
3575 assert_eq!(output, vec![10, 15, 20, 30, 50]);
3576 }
3577
3578 #[test]
3579 fn test_add_one() {
3580 let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
3581 add_one(&mut values, 8);
3582
3583 assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
3584 }
3585
3586 #[test]
3587 fn test_bits_needed() {
3588 assert_eq!(bits_needed(0), 0);
3589 assert_eq!(bits_needed(1), 1);
3590 assert_eq!(bits_needed(2), 2);
3591 assert_eq!(bits_needed(3), 2);
3592 assert_eq!(bits_needed(4), 3);
3593 assert_eq!(bits_needed(255), 8);
3594 assert_eq!(bits_needed(256), 9);
3595 assert_eq!(bits_needed(u32::MAX), 32);
3596 }
3597
3598 #[test]
3599 fn test_unpack_8bit_delta_decode() {
3600 let input: Vec<u8> = vec![4, 4, 9, 19];
3604 let mut output = vec![0u32; 5];
3605
3606 unpack_8bit_delta_decode(&input, &mut output, 10, 5);
3607
3608 assert_eq!(output, vec![10, 15, 20, 30, 50]);
3609 }
3610
3611 #[test]
3612 fn test_unpack_16bit_delta_decode() {
3613 let mut input = vec![0u8; 8];
3617 for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
3618 input[i * 2] = delta as u8;
3619 input[i * 2 + 1] = (delta >> 8) as u8;
3620 }
3621 let mut output = vec![0u32; 5];
3622
3623 unpack_16bit_delta_decode(&input, &mut output, 100, 5);
3624
3625 assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
3626 }
3627
3628 #[test]
3629 fn test_fused_vs_separate_8bit() {
3630 let input: Vec<u8> = (0..127).collect();
3632 let first_value = 1000u32;
3633 let count = 128;
3634
3635 let mut unpacked = vec![0u32; 128];
3637 unpack_8bit(&input, &mut unpacked, 127);
3638 let mut separate_output = vec![0u32; 128];
3639 delta_decode(&mut separate_output, &unpacked, first_value, count);
3640
3641 let mut fused_output = vec![0u32; 128];
3643 unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
3644
3645 assert_eq!(separate_output, fused_output);
3646 }
3647
3648 #[test]
3649 fn test_round_bit_width() {
3650 assert_eq!(round_bit_width(0), 0);
3651 assert_eq!(round_bit_width(1), 8);
3652 assert_eq!(round_bit_width(5), 8);
3653 assert_eq!(round_bit_width(8), 8);
3654 assert_eq!(round_bit_width(9), 16);
3655 assert_eq!(round_bit_width(12), 16);
3656 assert_eq!(round_bit_width(16), 16);
3657 assert_eq!(round_bit_width(17), 32);
3658 assert_eq!(round_bit_width(24), 32);
3659 assert_eq!(round_bit_width(32), 32);
3660 }
3661
3662 #[test]
3663 fn test_rounded_bitwidth_from_exact() {
3664 assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
3665 assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
3666 assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
3667 assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
3668 assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
3669 assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
3670 assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
3671 }
3672
3673 #[test]
3674 fn test_pack_unpack_rounded_8bit() {
3675 let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
3676 let mut packed = vec![0u8; 128];
3677
3678 let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
3679 assert_eq!(bytes_written, 128);
3680
3681 let mut unpacked = vec![0u32; 128];
3682 unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);
3683
3684 assert_eq!(values, unpacked);
3685 }
3686
3687 #[test]
3688 fn test_pack_unpack_rounded_16bit() {
3689 let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
3690 let mut packed = vec![0u8; 256];
3691
3692 let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
3693 assert_eq!(bytes_written, 256);
3694
3695 let mut unpacked = vec![0u32; 128];
3696 unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);
3697
3698 assert_eq!(values, unpacked);
3699 }
3700
3701 #[test]
3702 fn test_pack_unpack_rounded_32bit() {
3703 let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
3704 let mut packed = vec![0u8; 512];
3705
3706 let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
3707 assert_eq!(bytes_written, 512);
3708
3709 let mut unpacked = vec![0u32; 128];
3710 unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);
3711
3712 assert_eq!(values, unpacked);
3713 }
3714
3715 #[test]
3716 fn test_unpack_rounded_delta_decode() {
3717 let input: Vec<u8> = vec![4, 4, 9, 19];
3722 let mut output = vec![0u32; 5];
3723
3724 unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);
3725
3726 assert_eq!(output, vec![10, 15, 20, 30, 50]);
3727 }
3728
3729 #[test]
3730 fn test_unpack_rounded_delta_decode_zero() {
3731 let input: Vec<u8> = vec![];
3733 let mut output = vec![0u32; 5];
3734
3735 unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);
3736
3737 assert_eq!(output, vec![100, 101, 102, 103, 104]);
3738 }
3739
3740 #[test]
3745 fn test_dequantize_uint8() {
3746 let input: Vec<u8> = vec![0, 128, 255, 64, 192];
3747 let mut output = vec![0.0f32; 5];
3748 let scale = 0.1;
3749 let min_val = 1.0;
3750
3751 dequantize_uint8(&input, &mut output, scale, min_val, 5);
3752
        assert!((output[0] - 1.0).abs() < 1e-6);
        assert!((output[1] - 13.8).abs() < 1e-6);
        assert!((output[2] - 26.5).abs() < 1e-6);
        assert!((output[3] - 7.4).abs() < 1e-6);
        assert!((output[4] - 20.2).abs() < 1e-6);
    }
3760
3761 #[test]
3762 fn test_dequantize_uint8_large() {
3763 let input: Vec<u8> = (0..128).collect();
3765 let mut output = vec![0.0f32; 128];
3766 let scale = 2.0;
3767 let min_val = -10.0;
3768
3769 dequantize_uint8(&input, &mut output, scale, min_val, 128);
3770
3771 for (i, &out) in output.iter().enumerate().take(128) {
3772 let expected = i as f32 * scale + min_val;
3773 assert!(
3774 (out - expected).abs() < 1e-5,
3775 "Mismatch at {}: expected {}, got {}",
3776 i,
3777 expected,
3778 out
3779 );
3780 }
3781 }
3782
3783 #[test]
3784 fn test_dot_product_f32() {
3785 let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
3786 let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
3787
3788 let result = dot_product_f32(&a, &b, 5);
3789
3790 assert!((result - 70.0).abs() < 1e-5);
3792 }
3793
3794 #[test]
3795 fn test_dot_product_f32_large() {
3796 let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
3798 let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
3799
3800 let result = dot_product_f32(&a, &b, 128);
3801
3802 let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
3804 assert!(
3805 (result - expected).abs() < 1e-3,
3806 "Expected {}, got {}",
3807 expected,
3808 result
3809 );
3810 }
3811
3812 #[test]
3813 fn test_max_f32() {
3814 let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
3815 let result = max_f32(&values, 6);
3816 assert!((result - 9.0).abs() < 1e-6);
3817 }
3818
3819 #[test]
3820 fn test_max_f32_large() {
3821 let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
3823 values[77] = 1000.0;
3824
3825 let result = max_f32(&values, 128);
3826 assert!((result - 1000.0).abs() < 1e-5);
3827 }
3828
3829 #[test]
3830 fn test_max_f32_negative() {
3831 let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
3832 let result = max_f32(&values, 5);
3833 assert!((result - (-1.0)).abs() < 1e-6);
3834 }
3835
3836 #[test]
3837 fn test_max_f32_empty() {
3838 let values: Vec<f32> = vec![];
3839 let result = max_f32(&values, 0);
3840 assert_eq!(result, f32::NEG_INFINITY);
3841 }
3842
3843 #[test]
3844 fn test_fused_dot_norm() {
3845 let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3846 let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3847 let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3848
3849 let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3850 let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3851 assert!(
3852 (dot - expected_dot).abs() < 1e-5,
3853 "dot: expected {}, got {}",
3854 expected_dot,
3855 dot
3856 );
3857 assert!(
3858 (norm_b - expected_norm).abs() < 1e-5,
3859 "norm: expected {}, got {}",
3860 expected_norm,
3861 norm_b
3862 );
3863 }
3864
3865 #[test]
3866 fn test_fused_dot_norm_large() {
3867 let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
3868 let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect();
3869 let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3870
3871 let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3872 let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3873 assert!(
3874 (dot - expected_dot).abs() < 1.0,
3875 "dot: expected {}, got {}",
3876 expected_dot,
3877 dot
3878 );
3879 assert!(
3880 (norm_b - expected_norm).abs() < 1.0,
3881 "norm: expected {}, got {}",
3882 expected_norm,
3883 norm_b
3884 );
3885 }
3886
3887 #[test]
3888 fn test_batch_cosine_scores() {
3889 let query = vec![1.0f32, 0.0, 0.0];
3891 let vectors = vec![
            1.0, 0.0, 0.0, // identical to the query
            0.0, 1.0, 0.0, // orthogonal
            -1.0, 0.0, 0.0, // opposite
            0.5, 0.5, 0.0, // 45 degrees from the query
        ];
3897 let mut scores = vec![0f32; 4];
3898 batch_cosine_scores(&query, &vectors, 3, &mut scores);
3899
3900 assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
3901 assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
3902 assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
3903 let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
3904 assert!(
3905 (scores[3] - expected_45).abs() < 1e-5,
3906 "45deg: expected {}, got {}",
3907 expected_45,
3908 scores[3]
3909 );
3910 }
3911
3912 #[test]
3913 fn test_batch_cosine_scores_matches_individual() {
3914 let query: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
3915 let n = 50;
3916 let dim = 128;
3917 let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 7 + 3) as f32) * 0.01).collect();
3918
3919 let mut batch_scores = vec![0f32; n];
3920 batch_cosine_scores(&query, &vectors, dim, &mut batch_scores);
3921
3922 for i in 0..n {
3923 let vec_i = &vectors[i * dim..(i + 1) * dim];
3924 let individual = cosine_similarity(&query, vec_i);
3925 assert!(
3926 (batch_scores[i] - individual).abs() < 1e-5,
3927 "vec {}: batch={}, individual={}",
3928 i,
3929 batch_scores[i],
3930 individual
3931 );
3932 }
3933 }
3934
3935 #[test]
3936 fn test_batch_cosine_scores_empty() {
3937 let query = vec![1.0f32, 2.0, 3.0];
3938 let vectors: Vec<f32> = vec![];
3939 let mut scores: Vec<f32> = vec![];
3940 batch_cosine_scores(&query, &vectors, 3, &mut scores);
3941 assert!(scores.is_empty());
3942 }
3943
3944 #[test]
3945 fn test_batch_cosine_scores_zero_query() {
3946 let query = vec![0.0f32, 0.0, 0.0];
3947 let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
3948 let mut scores = vec![0f32; 2];
3949 batch_cosine_scores(&query, &vectors, 3, &mut scores);
3950 assert_eq!(scores[0], 0.0);
3951 assert_eq!(scores[1], 0.0);
3952 }
3953
3954 #[test]
3955 fn test_squared_euclidean_distance() {
3956 let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3957 let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3958 let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3959 let result = squared_euclidean_distance(&a, &b);
3960 assert!(
3961 (result - expected).abs() < 1e-5,
3962 "expected {}, got {}",
3963 expected,
3964 result
3965 );
3966 }
3967
3968 #[test]
3969 fn test_squared_euclidean_distance_large() {
3970 let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
3971 let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
3972 let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3973 let result = squared_euclidean_distance(&a, &b);
3974 assert!(
3975 (result - expected).abs() < 1e-3,
3976 "expected {}, got {}",
3977 expected,
3978 result
3979 );
3980 }
3981
3982 #[test]
3987 fn test_f16_roundtrip_normal() {
3988 for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 0.333, 65504.0] {
3989 let h = f32_to_f16(v);
3990 let back = f16_to_f32(h);
3991 let err = (back - v).abs() / v.abs().max(1e-6);
3992 assert!(
3993 err < 0.002,
3994 "f16 roundtrip {v} → {h:#06x} → {back}, rel err {err}"
3995 );
3996 }
3997 }
3998
3999 #[test]
4000 fn test_f16_special() {
4001 assert_eq!(f16_to_f32(f32_to_f16(0.0)), 0.0);
4003 assert_eq!(f32_to_f16(-0.0), 0x8000);
4005 assert!(f16_to_f32(f32_to_f16(f32::INFINITY)).is_infinite());
4007 assert!(f16_to_f32(f32_to_f16(f32::NAN)).is_nan());
4009 }
4010
4011 #[test]
4012 fn test_f16_embedding_range() {
4013 let values: Vec<f32> = (-100..=100).map(|i| i as f32 / 100.0).collect();
4015 for &v in &values {
4016 let back = f16_to_f32(f32_to_f16(v));
4017 assert!((back - v).abs() < 0.001, "f16 error for {v}: got {back}");
4018 }
4019 }
4020
4021 #[test]
4026 fn test_u8_roundtrip() {
4027 assert_eq!(f32_to_u8_saturating(-1.0), 0);
4029 assert_eq!(f32_to_u8_saturating(1.0), 255);
        // 0.0 maps to 127.5, which truncates to 127.
        assert_eq!(f32_to_u8_saturating(0.0), 127);
        // Out-of-range inputs saturate.
        assert_eq!(f32_to_u8_saturating(-2.0), 0);
4034 assert_eq!(f32_to_u8_saturating(2.0), 255);
4035 }
4036
4037 #[test]
4038 fn test_u8_dequantize() {
4039 assert!((u8_to_f32(0) - (-1.0)).abs() < 0.01);
4040 assert!((u8_to_f32(255) - 1.0).abs() < 0.01);
4041 assert!((u8_to_f32(127) - 0.0).abs() < 0.01);
4042 }
4043
4044 #[test]
4049 fn test_batch_cosine_scores_f16() {
4050 let query = vec![0.6f32, 0.8, 0.0, 0.0];
4051 let dim = 4;
4052 let vecs_f32 = vec![
            0.6f32, 0.8, 0.0, 0.0, // identical to the query
            0.0, 0.0, 0.6, 0.8, // orthogonal to the query
        ];
4056
4057 let mut f16_buf = vec![0u16; 8];
4059 batch_f32_to_f16(&vecs_f32, &mut f16_buf);
4060 let raw: &[u8] =
4061 unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
4062
4063 let mut scores = vec![0f32; 2];
4064 batch_cosine_scores_f16(&query, raw, dim, &mut scores);
4065
4066 assert!(
4067 (scores[0] - 1.0).abs() < 0.01,
4068 "identical vectors: {}",
4069 scores[0]
4070 );
4071 assert!(scores[1].abs() < 0.01, "orthogonal vectors: {}", scores[1]);
4072 }
4073
4074 #[test]
4075 fn test_batch_cosine_scores_u8() {
4076 let query = vec![0.6f32, 0.8, 0.0, 0.0];
4077 let dim = 4;
4078 let vecs_f32 = vec![
            0.6f32, 0.8, 0.0, 0.0, // same direction as the query
            -0.6, -0.8, 0.0, 0.0, // opposite direction
        ];
4082
4083 let mut u8_buf = vec![0u8; 8];
4085 batch_f32_to_u8(&vecs_f32, &mut u8_buf);
4086
4087 let mut scores = vec![0f32; 2];
4088 batch_cosine_scores_u8(&query, &u8_buf, dim, &mut scores);
4089
4090 assert!(scores[0] > 0.95, "similar vectors: {}", scores[0]);
4091 assert!(scores[1] < -0.95, "opposite vectors: {}", scores[1]);
4092 }
4093
4094 #[test]
4095 fn test_batch_cosine_scores_f16_large_dim() {
4096 let dim = 768;
4098 let query: Vec<f32> = (0..dim).map(|i| (i as f32 / dim as f32) - 0.5).collect();
4099 let vec2: Vec<f32> = query.iter().map(|x| x * 0.9 + 0.01).collect();
4100
4101 let mut all_vecs = query.clone();
4102 all_vecs.extend_from_slice(&vec2);
4103
4104 let mut f16_buf = vec![0u16; all_vecs.len()];
4105 batch_f32_to_f16(&all_vecs, &mut f16_buf);
4106 let raw: &[u8] =
4107 unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
4108
4109 let mut scores = vec![0f32; 2];
4110 batch_cosine_scores_f16(&query, raw, dim, &mut scores);
4111
4112 assert!((scores[0] - 1.0).abs() < 0.01, "self-sim: {}", scores[0]);
4114 assert!(scores[1] > 0.99, "scaled-sim: {}", scores[1]);
4116 }
4117}
4118
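/// Returns the index of the first element >= `target`, or `slice.len()` if no
/// such element exists. The slice is expected to be sorted ascending: the SIMD
/// paths scan linearly, while the scalar fallback uses `partition_point`.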
4119#[inline]
4132pub fn find_first_ge_u32(slice: &[u32], target: u32) -> usize {
4133 #[cfg(target_arch = "aarch64")]
4134 {
4135 if neon::is_available() {
4136 return unsafe { find_first_ge_u32_neon(slice, target) };
4137 }
4138 }
4139
4140 #[cfg(target_arch = "x86_64")]
4141 {
4142 if sse::is_available() {
4143 return unsafe { find_first_ge_u32_sse(slice, target) };
4144 }
4145 }
4146
4147 slice.partition_point(|&d| d < target)
4149}
4150
4151#[cfg(target_arch = "aarch64")]
4152#[target_feature(enable = "neon")]
4153#[allow(unsafe_op_in_unsafe_fn)]
4154unsafe fn find_first_ge_u32_neon(slice: &[u32], target: u32) -> usize {
4155 use std::arch::aarch64::*;
4156
4157 let n = slice.len();
4158 let ptr = slice.as_ptr();
4159 let target_vec = vdupq_n_u32(target);
4160 let bit_mask: uint32x4_t = core::mem::transmute([1u32, 2u32, 4u32, 8u32]);
4162
4163 let chunks = n / 16;
4164 let mut base = 0usize;
4165
4166 for _ in 0..chunks {
4168 let v0 = vld1q_u32(ptr.add(base));
4169 let v1 = vld1q_u32(ptr.add(base + 4));
4170 let v2 = vld1q_u32(ptr.add(base + 8));
4171 let v3 = vld1q_u32(ptr.add(base + 12));
4172
4173 let c0 = vcgeq_u32(v0, target_vec);
4174 let c1 = vcgeq_u32(v1, target_vec);
4175 let c2 = vcgeq_u32(v2, target_vec);
4176 let c3 = vcgeq_u32(v3, target_vec);
4177
4178 let m0 = vaddvq_u32(vandq_u32(c0, bit_mask));
4179 if m0 != 0 {
4180 return base + m0.trailing_zeros() as usize;
4181 }
4182 let m1 = vaddvq_u32(vandq_u32(c1, bit_mask));
4183 if m1 != 0 {
4184 return base + 4 + m1.trailing_zeros() as usize;
4185 }
4186 let m2 = vaddvq_u32(vandq_u32(c2, bit_mask));
4187 if m2 != 0 {
4188 return base + 8 + m2.trailing_zeros() as usize;
4189 }
4190 let m3 = vaddvq_u32(vandq_u32(c3, bit_mask));
4191 if m3 != 0 {
4192 return base + 12 + m3.trailing_zeros() as usize;
4193 }
4194 base += 16;
4195 }
4196
4197 while base + 4 <= n {
4199 let vals = vld1q_u32(ptr.add(base));
4200 let cmp = vcgeq_u32(vals, target_vec);
4201 let mask = vaddvq_u32(vandq_u32(cmp, bit_mask));
4202 if mask != 0 {
4203 return base + mask.trailing_zeros() as usize;
4204 }
4205 base += 4;
4206 }
4207
4208 while base < n {
4210 if *slice.get_unchecked(base) >= target {
4211 return base;
4212 }
4213 base += 1;
4214 }
4215 n
4216}
4217
4218#[cfg(target_arch = "x86_64")]
4219#[target_feature(enable = "sse2")]
4220#[allow(unsafe_op_in_unsafe_fn)]
4221unsafe fn find_first_ge_u32_sse(slice: &[u32], target: u32) -> usize {
4222 use std::arch::x86_64::*;
4223
4224 let n = slice.len();
4225 let ptr = slice.as_ptr();
4226
4227 let sign_flip = _mm_set1_epi32(i32::MIN);
4229 let target_xor = _mm_xor_si128(_mm_set1_epi32(target as i32), sign_flip);
4230
4231 let chunks = n / 16;
4232 let mut base = 0usize;
4233
4234 for _ in 0..chunks {
4236 let v0 = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
4237 let v1 = _mm_xor_si128(
4238 _mm_loadu_si128(ptr.add(base + 4) as *const __m128i),
4239 sign_flip,
4240 );
4241 let v2 = _mm_xor_si128(
4242 _mm_loadu_si128(ptr.add(base + 8) as *const __m128i),
4243 sign_flip,
4244 );
4245 let v3 = _mm_xor_si128(
4246 _mm_loadu_si128(ptr.add(base + 12) as *const __m128i),
4247 sign_flip,
4248 );
4249
4250 let ge0 = _mm_or_si128(
4252 _mm_cmpeq_epi32(v0, target_xor),
4253 _mm_cmpgt_epi32(v0, target_xor),
4254 );
4255 let m0 = _mm_movemask_ps(_mm_castsi128_ps(ge0)) as u32;
4256 if m0 != 0 {
4257 return base + m0.trailing_zeros() as usize;
4258 }
4259
4260 let ge1 = _mm_or_si128(
4261 _mm_cmpeq_epi32(v1, target_xor),
4262 _mm_cmpgt_epi32(v1, target_xor),
4263 );
4264 let m1 = _mm_movemask_ps(_mm_castsi128_ps(ge1)) as u32;
4265 if m1 != 0 {
4266 return base + 4 + m1.trailing_zeros() as usize;
4267 }
4268
4269 let ge2 = _mm_or_si128(
4270 _mm_cmpeq_epi32(v2, target_xor),
4271 _mm_cmpgt_epi32(v2, target_xor),
4272 );
4273 let m2 = _mm_movemask_ps(_mm_castsi128_ps(ge2)) as u32;
4274 if m2 != 0 {
4275 return base + 8 + m2.trailing_zeros() as usize;
4276 }
4277
4278 let ge3 = _mm_or_si128(
4279 _mm_cmpeq_epi32(v3, target_xor),
4280 _mm_cmpgt_epi32(v3, target_xor),
4281 );
4282 let m3 = _mm_movemask_ps(_mm_castsi128_ps(ge3)) as u32;
4283 if m3 != 0 {
4284 return base + 12 + m3.trailing_zeros() as usize;
4285 }
4286 base += 16;
4287 }
4288
4289 while base + 4 <= n {
4291 let vals = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
4292 let ge = _mm_or_si128(
4293 _mm_cmpeq_epi32(vals, target_xor),
4294 _mm_cmpgt_epi32(vals, target_xor),
4295 );
4296 let mask = _mm_movemask_ps(_mm_castsi128_ps(ge)) as u32;
4297 if mask != 0 {
4298 return base + mask.trailing_zeros() as usize;
4299 }
4300 base += 4;
4301 }
4302
4303 while base < n {
4305 if *slice.get_unchecked(base) >= target {
4306 return base;
4307 }
4308 base += 1;
4309 }
4310 n
4311}
4312
4313#[cfg(test)]
4314mod find_first_ge_tests {
4315 use super::find_first_ge_u32;
4316
4317 #[test]
4318 fn test_find_first_ge_basic() {
        let data: Vec<u32> = (0..128).map(|i| i * 3).collect(); // 0, 3, 6, ..., 381
        assert_eq!(find_first_ge_u32(&data, 0), 0);
        assert_eq!(find_first_ge_u32(&data, 1), 1); // between elements
        assert_eq!(find_first_ge_u32(&data, 3), 1); // exact match
        assert_eq!(find_first_ge_u32(&data, 4), 2);
        assert_eq!(find_first_ge_u32(&data, 381), 127); // last element
        assert_eq!(find_first_ge_u32(&data, 382), 128); // past the end
    }
4327
4328 #[test]
4329 fn test_find_first_ge_matches_partition_point() {
4330 let data: Vec<u32> = vec![1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75];
4331 for target in 0..80 {
4332 let expected = data.partition_point(|&d| d < target);
4333 let actual = find_first_ge_u32(&data, target);
4334 assert_eq!(actual, expected, "target={}", target);
4335 }
4336 }
4337
4338 #[test]
4339 fn test_find_first_ge_small_slices() {
4340 assert_eq!(find_first_ge_u32(&[], 5), 0);
4342 assert_eq!(find_first_ge_u32(&[10], 5), 0);
4344 assert_eq!(find_first_ge_u32(&[10], 10), 0);
4345 assert_eq!(find_first_ge_u32(&[10], 11), 1);
4346 assert_eq!(find_first_ge_u32(&[2, 4, 6], 5), 2);
4348 }
4349
4350 #[test]
4351 fn test_find_first_ge_full_block() {
4352 let data: Vec<u32> = (100..228).collect();
4354 assert_eq!(find_first_ge_u32(&data, 100), 0);
4355 assert_eq!(find_first_ge_u32(&data, 150), 50);
4356 assert_eq!(find_first_ge_u32(&data, 227), 127);
4357 assert_eq!(find_first_ge_u32(&data, 228), 128);
4358 assert_eq!(find_first_ge_u32(&data, 99), 0);
4359 }
4360
4361 #[test]
4362 fn test_find_first_ge_u32_max() {
4363 let data = vec![u32::MAX - 10, u32::MAX - 5, u32::MAX - 1, u32::MAX];
4365 assert_eq!(find_first_ge_u32(&data, u32::MAX - 10), 0);
4366 assert_eq!(find_first_ge_u32(&data, u32::MAX - 7), 1);
4367 assert_eq!(find_first_ge_u32(&data, u32::MAX), 3);
4368 }
4369}