1#[cfg(target_arch = "aarch64")]
18#[allow(unsafe_op_in_unsafe_fn)]
19mod neon {
20 use std::arch::aarch64::*;
21
22 #[target_feature(enable = "neon")]
24 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
25 let chunks = count / 16;
26 let remainder = count % 16;
27
28 for chunk in 0..chunks {
29 let base = chunk * 16;
30 let in_ptr = input.as_ptr().add(base);
31
32 let bytes = vld1q_u8(in_ptr);
34
35 let low8 = vget_low_u8(bytes);
37 let high8 = vget_high_u8(bytes);
38
39 let low16 = vmovl_u8(low8);
40 let high16 = vmovl_u8(high8);
41
42 let v0 = vmovl_u16(vget_low_u16(low16));
43 let v1 = vmovl_u16(vget_high_u16(low16));
44 let v2 = vmovl_u16(vget_low_u16(high16));
45 let v3 = vmovl_u16(vget_high_u16(high16));
46
47 let out_ptr = output.as_mut_ptr().add(base);
48 vst1q_u32(out_ptr, v0);
49 vst1q_u32(out_ptr.add(4), v1);
50 vst1q_u32(out_ptr.add(8), v2);
51 vst1q_u32(out_ptr.add(12), v3);
52 }
53
54 let base = chunks * 16;
56 for i in 0..remainder {
57 output[base + i] = input[base + i] as u32;
58 }
59 }
60
61 #[target_feature(enable = "neon")]
63 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
64 let chunks = count / 8;
65 let remainder = count % 8;
66
67 for chunk in 0..chunks {
68 let base = chunk * 8;
69 let in_ptr = input.as_ptr().add(base * 2) as *const u16;
70
71 let vals = vld1q_u16(in_ptr);
72 let low = vmovl_u16(vget_low_u16(vals));
73 let high = vmovl_u16(vget_high_u16(vals));
74
75 let out_ptr = output.as_mut_ptr().add(base);
76 vst1q_u32(out_ptr, low);
77 vst1q_u32(out_ptr.add(4), high);
78 }
79
80 let base = chunks * 8;
82 for i in 0..remainder {
83 let idx = (base + i) * 2;
84 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
85 }
86 }
87
88 #[target_feature(enable = "neon")]
90 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
91 let chunks = count / 4;
92 let remainder = count % 4;
93
94 let in_ptr = input.as_ptr() as *const u32;
95 let out_ptr = output.as_mut_ptr();
96
97 for chunk in 0..chunks {
98 let vals = vld1q_u32(in_ptr.add(chunk * 4));
99 vst1q_u32(out_ptr.add(chunk * 4), vals);
100 }
101
102 let base = chunks * 4;
104 for i in 0..remainder {
105 let idx = (base + i) * 4;
106 output[base + i] =
107 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
108 }
109 }
110
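    /// Inclusive prefix sum of the four u32 lanes: [a, b, c, d] -> [a, a+b, a+b+c, a+b+c+d].
    /// Two shift-and-add steps (by one lane, then by two) cover all four lanes.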
111 #[inline]
115 #[target_feature(enable = "neon")]
116 unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
117 let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
120 let sum1 = vaddq_u32(v, shifted1);
121
122 let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
125 vaddq_u32(sum1, shifted2)
126 }
127
128 #[target_feature(enable = "neon")]
132 pub unsafe fn delta_decode(
133 output: &mut [u32],
134 deltas: &[u32],
135 first_doc_id: u32,
136 count: usize,
137 ) {
138 if count == 0 {
139 return;
140 }
141
142 output[0] = first_doc_id;
143 if count == 1 {
144 return;
145 }
146
147 let ones = vdupq_n_u32(1);
148 let mut carry = vdupq_n_u32(first_doc_id);
149
150 let full_groups = (count - 1) / 4;
151 let remainder = (count - 1) % 4;
152
153 for group in 0..full_groups {
154 let base = group * 4;
155
156 let d = vld1q_u32(deltas[base..].as_ptr());
158 let gaps = vaddq_u32(d, ones);
159
160 let prefix = prefix_sum_4(gaps);
162
163 let result = vaddq_u32(prefix, carry);
165
166 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
168
169 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
171 }
172
173 let base = full_groups * 4;
175 let mut scalar_carry = vgetq_lane_u32(carry, 0);
176 for j in 0..remainder {
177 scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
178 output[base + j + 1] = scalar_carry;
179 }
180 }
181
182 #[target_feature(enable = "neon")]
184 pub unsafe fn add_one(values: &mut [u32], count: usize) {
185 let ones = vdupq_n_u32(1);
186 let chunks = count / 4;
187 let remainder = count % 4;
188
189 for chunk in 0..chunks {
190 let base = chunk * 4;
191 let ptr = values.as_mut_ptr().add(base);
192 let v = vld1q_u32(ptr);
193 let result = vaddq_u32(v, ones);
194 vst1q_u32(ptr, result);
195 }
196
197 let base = chunks * 4;
198 for i in 0..remainder {
199 values[base + i] += 1;
200 }
201 }
202
203 #[target_feature(enable = "neon")]
206 pub unsafe fn unpack_8bit_delta_decode(
207 input: &[u8],
208 output: &mut [u32],
209 first_value: u32,
210 count: usize,
211 ) {
212 output[0] = first_value;
213 if count <= 1 {
214 return;
215 }
216
217 let ones = vdupq_n_u32(1);
218 let mut carry = vdupq_n_u32(first_value);
219
220 let full_groups = (count - 1) / 4;
221 let remainder = (count - 1) % 4;
222
223 for group in 0..full_groups {
224 let base = group * 4;
225
226 let b0 = input[base] as u32;
228 let b1 = input[base + 1] as u32;
229 let b2 = input[base + 2] as u32;
230 let b3 = input[base + 3] as u32;
231 let deltas = [b0, b1, b2, b3];
232 let d = vld1q_u32(deltas.as_ptr());
233
234 let gaps = vaddq_u32(d, ones);
236
237 let prefix = prefix_sum_4(gaps);
239
240 let result = vaddq_u32(prefix, carry);
242
243 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
245
246 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
248 }
249
250 let base = full_groups * 4;
252 let mut scalar_carry = vgetq_lane_u32(carry, 0);
253 for j in 0..remainder {
254 scalar_carry = scalar_carry
255 .wrapping_add(input[base + j] as u32)
256 .wrapping_add(1);
257 output[base + j + 1] = scalar_carry;
258 }
259 }
260
261 #[target_feature(enable = "neon")]
263 pub unsafe fn unpack_16bit_delta_decode(
264 input: &[u8],
265 output: &mut [u32],
266 first_value: u32,
267 count: usize,
268 ) {
269 output[0] = first_value;
270 if count <= 1 {
271 return;
272 }
273
274 let ones = vdupq_n_u32(1);
275 let mut carry = vdupq_n_u32(first_value);
276
277 let full_groups = (count - 1) / 4;
278 let remainder = (count - 1) % 4;
279
280 for group in 0..full_groups {
281 let base = group * 4;
282 let in_ptr = input.as_ptr().add(base * 2) as *const u16;
283
284 let vals = vld1_u16(in_ptr);
286 let d = vmovl_u16(vals);
287
288 let gaps = vaddq_u32(d, ones);
290
291 let prefix = prefix_sum_4(gaps);
293
294 let result = vaddq_u32(prefix, carry);
296
297 vst1q_u32(output[base + 1..].as_mut_ptr(), result);
299
300 carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
302 }
303
304 let base = full_groups * 4;
306 let mut scalar_carry = vgetq_lane_u32(carry, 0);
307 for j in 0..remainder {
308 let idx = (base + j) * 2;
309 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
310 scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
311 output[base + j + 1] = scalar_carry;
312 }
313 }
314
315 #[inline]
317 pub fn is_available() -> bool {
318 true
319 }
320}
321
322#[cfg(target_arch = "x86_64")]
327#[allow(unsafe_op_in_unsafe_fn)]
328mod sse {
329 use std::arch::x86_64::*;
330
331 #[target_feature(enable = "sse2", enable = "sse4.1")]
333 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
334 let chunks = count / 16;
335 let remainder = count % 16;
336
337 for chunk in 0..chunks {
338 let base = chunk * 16;
339 let in_ptr = input.as_ptr().add(base);
340
341 let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
342
343 let v0 = _mm_cvtepu8_epi32(bytes);
345 let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
346 let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
347 let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
348
349 let out_ptr = output.as_mut_ptr().add(base);
350 _mm_storeu_si128(out_ptr as *mut __m128i, v0);
351 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
352 _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
353 _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
354 }
355
356 let base = chunks * 16;
357 for i in 0..remainder {
358 output[base + i] = input[base + i] as u32;
359 }
360 }
361
362 #[target_feature(enable = "sse2", enable = "sse4.1")]
364 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
365 let chunks = count / 8;
366 let remainder = count % 8;
367
368 for chunk in 0..chunks {
369 let base = chunk * 8;
370 let in_ptr = input.as_ptr().add(base * 2);
371
372 let vals = _mm_loadu_si128(in_ptr as *const __m128i);
373 let low = _mm_cvtepu16_epi32(vals);
374 let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
375
376 let out_ptr = output.as_mut_ptr().add(base);
377 _mm_storeu_si128(out_ptr as *mut __m128i, low);
378 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
379 }
380
381 let base = chunks * 8;
382 for i in 0..remainder {
383 let idx = (base + i) * 2;
384 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
385 }
386 }
387
388 #[target_feature(enable = "sse2")]
390 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
391 let chunks = count / 4;
392 let remainder = count % 4;
393
394 let in_ptr = input.as_ptr() as *const __m128i;
395 let out_ptr = output.as_mut_ptr() as *mut __m128i;
396
397 for chunk in 0..chunks {
398 let vals = _mm_loadu_si128(in_ptr.add(chunk));
399 _mm_storeu_si128(out_ptr.add(chunk), vals);
400 }
401
402 let base = chunks * 4;
404 for i in 0..remainder {
405 let idx = (base + i) * 4;
406 output[base + i] =
407 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
408 }
409 }
410
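    /// Inclusive prefix sum of the four u32 lanes: [a, b, c, d] -> [a, a+b, a+b+c, a+b+c+d],
    /// built from two byte-shift-and-add steps (`_mm_slli_si128` by 4, then by 8 bytes).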
411 #[inline]
415 #[target_feature(enable = "sse2")]
416 unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
417 let shifted1 = _mm_slli_si128(v, 4);
420 let sum1 = _mm_add_epi32(v, shifted1);
421
422 let shifted2 = _mm_slli_si128(sum1, 8);
425 _mm_add_epi32(sum1, shifted2)
426 }
427
428 #[target_feature(enable = "sse2", enable = "sse4.1")]
430 pub unsafe fn delta_decode(
431 output: &mut [u32],
432 deltas: &[u32],
433 first_doc_id: u32,
434 count: usize,
435 ) {
436 if count == 0 {
437 return;
438 }
439
440 output[0] = first_doc_id;
441 if count == 1 {
442 return;
443 }
444
445 let ones = _mm_set1_epi32(1);
446 let mut carry = _mm_set1_epi32(first_doc_id as i32);
447
448 let full_groups = (count - 1) / 4;
449 let remainder = (count - 1) % 4;
450
451 for group in 0..full_groups {
452 let base = group * 4;
453
454 let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
456 let gaps = _mm_add_epi32(d, ones);
457
458 let prefix = prefix_sum_4(gaps);
460
461 let result = _mm_add_epi32(prefix, carry);
463
464 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
466
            // Broadcast lane 3 (the last decoded value) into every lane for the next group.
            carry = _mm_shuffle_epi32(result, 0xFF);
        }
470
471 let base = full_groups * 4;
473 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
474 for j in 0..remainder {
475 scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
476 output[base + j + 1] = scalar_carry;
477 }
478 }
479
480 #[target_feature(enable = "sse2")]
482 pub unsafe fn add_one(values: &mut [u32], count: usize) {
483 let ones = _mm_set1_epi32(1);
484 let chunks = count / 4;
485 let remainder = count % 4;
486
487 for chunk in 0..chunks {
488 let base = chunk * 4;
489 let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
490 let v = _mm_loadu_si128(ptr);
491 let result = _mm_add_epi32(v, ones);
492 _mm_storeu_si128(ptr, result);
493 }
494
495 let base = chunks * 4;
496 for i in 0..remainder {
497 values[base + i] += 1;
498 }
499 }
500
501 #[target_feature(enable = "sse2", enable = "sse4.1")]
503 pub unsafe fn unpack_8bit_delta_decode(
504 input: &[u8],
505 output: &mut [u32],
506 first_value: u32,
507 count: usize,
508 ) {
509 output[0] = first_value;
510 if count <= 1 {
511 return;
512 }
513
514 let ones = _mm_set1_epi32(1);
515 let mut carry = _mm_set1_epi32(first_value as i32);
516
517 let full_groups = (count - 1) / 4;
518 let remainder = (count - 1) % 4;
519
520 for group in 0..full_groups {
521 let base = group * 4;
522
523 let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
525 input.as_ptr().add(base) as *const i32
526 ));
527 let d = _mm_cvtepu8_epi32(bytes);
528
529 let gaps = _mm_add_epi32(d, ones);
531
532 let prefix = prefix_sum_4(gaps);
534
535 let result = _mm_add_epi32(prefix, carry);
537
538 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
540
541 carry = _mm_shuffle_epi32(result, 0xFF);
543 }
544
545 let base = full_groups * 4;
547 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
548 for j in 0..remainder {
549 scalar_carry = scalar_carry
550 .wrapping_add(input[base + j] as u32)
551 .wrapping_add(1);
552 output[base + j + 1] = scalar_carry;
553 }
554 }
555
556 #[target_feature(enable = "sse2", enable = "sse4.1")]
558 pub unsafe fn unpack_16bit_delta_decode(
559 input: &[u8],
560 output: &mut [u32],
561 first_value: u32,
562 count: usize,
563 ) {
564 output[0] = first_value;
565 if count <= 1 {
566 return;
567 }
568
569 let ones = _mm_set1_epi32(1);
570 let mut carry = _mm_set1_epi32(first_value as i32);
571
572 let full_groups = (count - 1) / 4;
573 let remainder = (count - 1) % 4;
574
575 for group in 0..full_groups {
576 let base = group * 4;
577 let in_ptr = input.as_ptr().add(base * 2);
578
            // Load four u16 deltas and widen them to u32 lanes.
            let vals = _mm_loadl_epi64(in_ptr as *const __m128i);
            let d = _mm_cvtepu16_epi32(vals);
582
583 let gaps = _mm_add_epi32(d, ones);
585
586 let prefix = prefix_sum_4(gaps);
588
589 let result = _mm_add_epi32(prefix, carry);
591
592 _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
594
595 carry = _mm_shuffle_epi32(result, 0xFF);
597 }
598
599 let base = full_groups * 4;
601 let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
602 for j in 0..remainder {
603 let idx = (base + j) * 2;
604 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
605 scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
606 output[base + j + 1] = scalar_carry;
607 }
608 }
609
610 #[inline]
612 pub fn is_available() -> bool {
613 is_x86_feature_detected!("sse4.1")
614 }
615}
616
617#[cfg(target_arch = "x86_64")]
622#[allow(unsafe_op_in_unsafe_fn)]
623mod avx2 {
624 use std::arch::x86_64::*;
625
626 #[target_feature(enable = "avx2")]
628 pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
629 let chunks = count / 32;
630 let remainder = count % 32;
631
632 for chunk in 0..chunks {
633 let base = chunk * 32;
634 let in_ptr = input.as_ptr().add(base);
635
636 let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
638 let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
639
640 let v0 = _mm256_cvtepu8_epi32(bytes_lo);
642 let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
643 let v2 = _mm256_cvtepu8_epi32(bytes_hi);
644 let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));
645
646 let out_ptr = output.as_mut_ptr().add(base);
647 _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
648 _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
649 _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
650 _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
651 }
652
653 let base = chunks * 32;
655 for i in 0..remainder {
656 output[base + i] = input[base + i] as u32;
657 }
658 }
659
660 #[target_feature(enable = "avx2")]
662 pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
663 let chunks = count / 16;
664 let remainder = count % 16;
665
666 for chunk in 0..chunks {
667 let base = chunk * 16;
668 let in_ptr = input.as_ptr().add(base * 2);
669
670 let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
672 let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
673
674 let v0 = _mm256_cvtepu16_epi32(vals_lo);
676 let v1 = _mm256_cvtepu16_epi32(vals_hi);
677
678 let out_ptr = output.as_mut_ptr().add(base);
679 _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
680 _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
681 }
682
683 let base = chunks * 16;
685 for i in 0..remainder {
686 let idx = (base + i) * 2;
687 output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
688 }
689 }
690
691 #[target_feature(enable = "avx2")]
693 pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
694 let chunks = count / 8;
695 let remainder = count % 8;
696
697 let in_ptr = input.as_ptr() as *const __m256i;
698 let out_ptr = output.as_mut_ptr() as *mut __m256i;
699
700 for chunk in 0..chunks {
701 let vals = _mm256_loadu_si256(in_ptr.add(chunk));
702 _mm256_storeu_si256(out_ptr.add(chunk), vals);
703 }
704
705 let base = chunks * 8;
707 for i in 0..remainder {
708 let idx = (base + i) * 4;
709 output[base + i] =
710 u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
711 }
712 }
713
714 #[target_feature(enable = "avx2")]
716 pub unsafe fn add_one(values: &mut [u32], count: usize) {
717 let ones = _mm256_set1_epi32(1);
718 let chunks = count / 8;
719 let remainder = count % 8;
720
721 for chunk in 0..chunks {
722 let base = chunk * 8;
723 let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
724 let v = _mm256_loadu_si256(ptr);
725 let result = _mm256_add_epi32(v, ones);
726 _mm256_storeu_si256(ptr, result);
727 }
728
729 let base = chunks * 8;
730 for i in 0..remainder {
731 values[base + i] += 1;
732 }
733 }
734
735 #[inline]
737 pub fn is_available() -> bool {
738 is_x86_feature_detected!("avx2")
739 }
740}
741
742#[allow(dead_code)]
747mod scalar {
748 #[inline]
750 pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
751 for i in 0..count {
752 output[i] = input[i] as u32;
753 }
754 }
755
756 #[inline]
758 pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
759 for (i, out) in output.iter_mut().enumerate().take(count) {
760 let idx = i * 2;
761 *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
762 }
763 }
764
765 #[inline]
767 pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
768 for (i, out) in output.iter_mut().enumerate().take(count) {
769 let idx = i * 4;
770 *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
771 }
772 }
773
774 #[inline]
776 pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
777 if count == 0 {
778 return;
779 }
780
781 output[0] = first_doc_id;
782 let mut carry = first_doc_id;
783
784 for i in 0..count - 1 {
785 carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
786 output[i + 1] = carry;
787 }
788 }
789
790 #[inline]
792 pub fn add_one(values: &mut [u32], count: usize) {
793 for val in values.iter_mut().take(count) {
794 *val += 1;
795 }
796 }
797}
798
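/// Unpacks `count` 8-bit values from `input` into u32 slots of `output`, dispatching to
/// NEON, AVX2, or SSE 4.1 when available and falling back to the scalar loop otherwise.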
799#[inline]
805pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
806 #[cfg(target_arch = "aarch64")]
807 {
808 if neon::is_available() {
809 unsafe {
810 neon::unpack_8bit(input, output, count);
811 }
812 return;
813 }
814 }
815
816 #[cfg(target_arch = "x86_64")]
817 {
818 if avx2::is_available() {
820 unsafe {
821 avx2::unpack_8bit(input, output, count);
822 }
823 return;
824 }
825 if sse::is_available() {
826 unsafe {
827 sse::unpack_8bit(input, output, count);
828 }
829 return;
830 }
831 }
832
833 scalar::unpack_8bit(input, output, count);
834}
835
836#[inline]
838pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
839 #[cfg(target_arch = "aarch64")]
840 {
841 if neon::is_available() {
842 unsafe {
843 neon::unpack_16bit(input, output, count);
844 }
845 return;
846 }
847 }
848
849 #[cfg(target_arch = "x86_64")]
850 {
851 if avx2::is_available() {
853 unsafe {
854 avx2::unpack_16bit(input, output, count);
855 }
856 return;
857 }
858 if sse::is_available() {
859 unsafe {
860 sse::unpack_16bit(input, output, count);
861 }
862 return;
863 }
864 }
865
866 scalar::unpack_16bit(input, output, count);
867}
868
869#[inline]
871pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
872 #[cfg(target_arch = "aarch64")]
873 {
874 if neon::is_available() {
875 unsafe {
876 neon::unpack_32bit(input, output, count);
877 }
878 }
879 }
880
881 #[cfg(target_arch = "x86_64")]
882 {
883 if avx2::is_available() {
885 unsafe {
886 avx2::unpack_32bit(input, output, count);
887 }
888 } else {
889 unsafe {
891 sse::unpack_32bit(input, output, count);
892 }
893 }
894 }
895
896 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
897 {
898 scalar::unpack_32bit(input, output, count);
899 }
900}
901
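/// Delta-decodes gap-encoded values: `output[0] = first_value` and
/// `output[i + 1] = output[i] + deltas[i] + 1`, so `deltas` must hold `count - 1` entries.
/// For example, deltas `[4, 4, 9, 19]` with `first_value = 10` decode to `[10, 15, 20, 30, 50]`.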
902#[inline]
908pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
909 #[cfg(target_arch = "aarch64")]
910 {
911 if neon::is_available() {
912 unsafe {
913 neon::delta_decode(output, deltas, first_value, count);
914 }
915 return;
916 }
917 }
918
919 #[cfg(target_arch = "x86_64")]
920 {
921 if sse::is_available() {
922 unsafe {
923 sse::delta_decode(output, deltas, first_value, count);
924 }
925 return;
926 }
927 }
928
929 scalar::delta_decode(output, deltas, first_value, count);
930}
931
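/// Adds 1 to each of the first `count` elements of `values` in place.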
932#[inline]
936pub fn add_one(values: &mut [u32], count: usize) {
937 #[cfg(target_arch = "aarch64")]
938 {
939 if neon::is_available() {
940 unsafe {
941 neon::add_one(values, count);
942 }
943 }
944 }
945
946 #[cfg(target_arch = "x86_64")]
947 {
948 if avx2::is_available() {
950 unsafe {
951 avx2::add_one(values, count);
952 }
953 } else {
954 unsafe {
956 sse::add_one(values, count);
957 }
958 }
959 }
960
961 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
962 {
963 scalar::add_one(values, count);
964 }
965}
966
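/// Number of bits needed to represent `val`: 0 for 0, otherwise `32 - leading_zeros`.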
967#[inline]
969pub fn bits_needed(val: u32) -> u8 {
970 if val == 0 {
971 0
972 } else {
973 32 - val.leading_zeros() as u8
974 }
975}
976
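/// Bit widths rounded up to byte-aligned sizes (0, 8, 16, or 32 bits) so values can be
/// packed and unpacked with plain byte copies and SIMD widening loads instead of
/// arbitrary bit shifting.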
977#[derive(Debug, Clone, Copy, PartialEq, Eq)]
994#[repr(u8)]
995pub enum RoundedBitWidth {
996 Zero = 0,
997 Bits8 = 8,
998 Bits16 = 16,
999 Bits32 = 32,
1000}
1001
1002impl RoundedBitWidth {
1003 #[inline]
1005 pub fn from_exact(bits: u8) -> Self {
1006 match bits {
1007 0 => RoundedBitWidth::Zero,
1008 1..=8 => RoundedBitWidth::Bits8,
1009 9..=16 => RoundedBitWidth::Bits16,
1010 _ => RoundedBitWidth::Bits32,
1011 }
1012 }
1013
1014 #[inline]
1016 pub fn from_u8(bits: u8) -> Self {
1017 match bits {
1018 0 => RoundedBitWidth::Zero,
1019 8 => RoundedBitWidth::Bits8,
1020 16 => RoundedBitWidth::Bits16,
1021 32 => RoundedBitWidth::Bits32,
            _ => RoundedBitWidth::Bits32,
        }
1024 }
1025
1026 #[inline]
1028 pub fn bytes_per_value(self) -> usize {
1029 match self {
1030 RoundedBitWidth::Zero => 0,
1031 RoundedBitWidth::Bits8 => 1,
1032 RoundedBitWidth::Bits16 => 2,
1033 RoundedBitWidth::Bits32 => 4,
1034 }
1035 }
1036
1037 #[inline]
1039 pub fn as_u8(self) -> u8 {
1040 self as u8
1041 }
1042}
1043
1044#[inline]
1046pub fn round_bit_width(bits: u8) -> u8 {
1047 RoundedBitWidth::from_exact(bits).as_u8()
1048}
1049
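/// Packs `values` into `output` at the given rounded width using little-endian byte order.
/// Returns the number of bytes written (0, 1, 2, or 4 per value).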
1050#[inline]
1055pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
1056 let count = values.len();
1057 match bit_width {
1058 RoundedBitWidth::Zero => 0,
1059 RoundedBitWidth::Bits8 => {
1060 for (i, &v) in values.iter().enumerate() {
1061 output[i] = v as u8;
1062 }
1063 count
1064 }
1065 RoundedBitWidth::Bits16 => {
1066 for (i, &v) in values.iter().enumerate() {
1067 let bytes = (v as u16).to_le_bytes();
1068 output[i * 2] = bytes[0];
1069 output[i * 2 + 1] = bytes[1];
1070 }
1071 count * 2
1072 }
1073 RoundedBitWidth::Bits32 => {
1074 for (i, &v) in values.iter().enumerate() {
1075 let bytes = v.to_le_bytes();
1076 output[i * 4] = bytes[0];
1077 output[i * 4 + 1] = bytes[1];
1078 output[i * 4 + 2] = bytes[2];
1079 output[i * 4 + 3] = bytes[3];
1080 }
1081 count * 4
1082 }
1083 }
1084}
1085
1086#[inline]
1090pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
1091 match bit_width {
1092 RoundedBitWidth::Zero => {
1093 for out in output.iter_mut().take(count) {
1094 *out = 0;
1095 }
1096 }
1097 RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
1098 RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
1099 RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
1100 }
1101}
1102
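/// Unpacks gap-encoded deltas stored at a rounded width and delta-decodes them in one pass.
/// A width of `Zero` means every gap is exactly 1, so the output is the consecutive run
/// starting at `first_value`.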
1103#[inline]
1107pub fn unpack_rounded_delta_decode(
1108 input: &[u8],
1109 bit_width: RoundedBitWidth,
1110 output: &mut [u32],
1111 first_value: u32,
1112 count: usize,
1113) {
1114 match bit_width {
1115 RoundedBitWidth::Zero => {
1116 let mut val = first_value;
1118 for out in output.iter_mut().take(count) {
1119 *out = val;
1120 val = val.wrapping_add(1);
1121 }
1122 }
1123 RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
1124 RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
1125 RoundedBitWidth::Bits32 => {
1126 unpack_32bit(input, output, count);
1128 if count > 0 {
1131 let mut carry = first_value;
1132 output[0] = first_value;
1133 for item in output.iter_mut().take(count).skip(1) {
1134 carry = carry.wrapping_add(*item).wrapping_add(1);
1136 *item = carry;
1137 }
1138 }
1139 }
1140 }
1141}
1142
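/// Fused unpack + delta decode for 8-bit gaps: reads `count - 1` single-byte deltas from
/// `input` and writes `count` decoded values, skipping the intermediate u32 delta buffer
/// that the separate unpack-then-decode path would need.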
1143#[inline]
1152pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1153 if count == 0 {
1154 return;
1155 }
1156
1157 output[0] = first_value;
1158 if count == 1 {
1159 return;
1160 }
1161
1162 #[cfg(target_arch = "aarch64")]
1163 {
1164 if neon::is_available() {
1165 unsafe {
1166 neon::unpack_8bit_delta_decode(input, output, first_value, count);
1167 }
1168 return;
1169 }
1170 }
1171
1172 #[cfg(target_arch = "x86_64")]
1173 {
1174 if sse::is_available() {
1175 unsafe {
1176 sse::unpack_8bit_delta_decode(input, output, first_value, count);
1177 }
1178 return;
1179 }
1180 }
1181
1182 let mut carry = first_value;
1184 for i in 0..count - 1 {
1185 carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
1186 output[i + 1] = carry;
1187 }
1188}
1189
1190#[inline]
1192pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1193 if count == 0 {
1194 return;
1195 }
1196
1197 output[0] = first_value;
1198 if count == 1 {
1199 return;
1200 }
1201
1202 #[cfg(target_arch = "aarch64")]
1203 {
1204 if neon::is_available() {
1205 unsafe {
1206 neon::unpack_16bit_delta_decode(input, output, first_value, count);
1207 }
1208 return;
1209 }
1210 }
1211
1212 #[cfg(target_arch = "x86_64")]
1213 {
1214 if sse::is_available() {
1215 unsafe {
1216 sse::unpack_16bit_delta_decode(input, output, first_value, count);
1217 }
1218 return;
1219 }
1220 }
1221
1222 let mut carry = first_value;
1224 for i in 0..count - 1 {
1225 let idx = i * 2;
1226 let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
1227 carry = carry.wrapping_add(delta).wrapping_add(1);
1228 output[i + 1] = carry;
1229 }
1230}
1231
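/// Unpacks deltas stored at an arbitrary `bit_width` and delta-decodes them into `output`.
/// Widths of 0, 8, 16, and 32 take dedicated fast paths; any other width falls back to a
/// generic bit-extraction loop.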
1232#[inline]
1237pub fn unpack_delta_decode(
1238 input: &[u8],
1239 bit_width: u8,
1240 output: &mut [u32],
1241 first_value: u32,
1242 count: usize,
1243) {
1244 if count == 0 {
1245 return;
1246 }
1247
1248 output[0] = first_value;
1249 if count == 1 {
1250 return;
1251 }
1252
1253 match bit_width {
1255 0 => {
1256 let mut val = first_value;
1258 for item in output.iter_mut().take(count).skip(1) {
1259 val = val.wrapping_add(1);
1260 *item = val;
1261 }
1262 }
1263 8 => unpack_8bit_delta_decode(input, output, first_value, count),
1264 16 => unpack_16bit_delta_decode(input, output, first_value, count),
1265 32 => {
1266 let mut carry = first_value;
1268 for i in 0..count - 1 {
1269 let idx = i * 4;
1270 let delta = u32::from_le_bytes([
1271 input[idx],
1272 input[idx + 1],
1273 input[idx + 2],
1274 input[idx + 3],
1275 ]);
1276 carry = carry.wrapping_add(delta).wrapping_add(1);
1277 output[i + 1] = carry;
1278 }
1279 }
1280 _ => {
1281 let mask = (1u64 << bit_width) - 1;
1283 let bit_width_usize = bit_width as usize;
1284 let mut bit_pos = 0usize;
1285 let input_ptr = input.as_ptr();
1286 let mut carry = first_value;
1287
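            // Each iteration reads 8 unaligned bytes at the current bit offset, so the
            // packed input is assumed to be sized (or padded) so that read stays in bounds.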
1288 for i in 0..count - 1 {
1289 let byte_idx = bit_pos >> 3;
1290 let bit_offset = bit_pos & 7;
1291
1292 let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
1294 let delta = ((word >> bit_offset) & mask) as u32;
1295
1296 carry = carry.wrapping_add(delta).wrapping_add(1);
1297 output[i + 1] = carry;
1298 bit_pos += bit_width_usize;
1299 }
1300 }
1301 }
1302}
1303
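/// Dequantizes `count` bytes to f32 as `input[i] as f32 * scale + min_val`, using NEON or
/// SSE when available.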
1304#[inline]
1312pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
1313 #[cfg(target_arch = "aarch64")]
1314 {
1315 if neon::is_available() {
1316 unsafe {
1317 dequantize_uint8_neon(input, output, scale, min_val, count);
1318 }
1319 return;
1320 }
1321 }
1322
1323 #[cfg(target_arch = "x86_64")]
1324 {
1325 if sse::is_available() {
1326 unsafe {
1327 dequantize_uint8_sse(input, output, scale, min_val, count);
1328 }
1329 return;
1330 }
1331 }
1332
1333 for i in 0..count {
1335 output[i] = input[i] as f32 * scale + min_val;
1336 }
1337}
1338
1339#[cfg(target_arch = "aarch64")]
1340#[target_feature(enable = "neon")]
1341#[allow(unsafe_op_in_unsafe_fn)]
1342unsafe fn dequantize_uint8_neon(
1343 input: &[u8],
1344 output: &mut [f32],
1345 scale: f32,
1346 min_val: f32,
1347 count: usize,
1348) {
1349 use std::arch::aarch64::*;
1350
1351 let scale_v = vdupq_n_f32(scale);
1352 let min_v = vdupq_n_f32(min_val);
1353
1354 let chunks = count / 16;
1355 let remainder = count % 16;
1356
1357 for chunk in 0..chunks {
1358 let base = chunk * 16;
1359 let in_ptr = input.as_ptr().add(base);
1360
1361 let bytes = vld1q_u8(in_ptr);
1363
1364 let low8 = vget_low_u8(bytes);
1366 let high8 = vget_high_u8(bytes);
1367
1368 let low16 = vmovl_u8(low8);
1369 let high16 = vmovl_u8(high8);
1370
1371 let u32_0 = vmovl_u16(vget_low_u16(low16));
1373 let u32_1 = vmovl_u16(vget_high_u16(low16));
1374 let u32_2 = vmovl_u16(vget_low_u16(high16));
1375 let u32_3 = vmovl_u16(vget_high_u16(high16));
1376
1377 let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
1379 let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
1380 let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
1381 let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);
1382
1383 let out_ptr = output.as_mut_ptr().add(base);
1384 vst1q_f32(out_ptr, f32_0);
1385 vst1q_f32(out_ptr.add(4), f32_1);
1386 vst1q_f32(out_ptr.add(8), f32_2);
1387 vst1q_f32(out_ptr.add(12), f32_3);
1388 }
1389
1390 let base = chunks * 16;
1392 for i in 0..remainder {
1393 output[base + i] = input[base + i] as f32 * scale + min_val;
1394 }
1395}
1396
1397#[cfg(target_arch = "x86_64")]
1398#[target_feature(enable = "sse2", enable = "sse4.1")]
1399#[allow(unsafe_op_in_unsafe_fn)]
1400unsafe fn dequantize_uint8_sse(
1401 input: &[u8],
1402 output: &mut [f32],
1403 scale: f32,
1404 min_val: f32,
1405 count: usize,
1406) {
1407 use std::arch::x86_64::*;
1408
1409 let scale_v = _mm_set1_ps(scale);
1410 let min_v = _mm_set1_ps(min_val);
1411
1412 let chunks = count / 4;
1413 let remainder = count % 4;
1414
1415 for chunk in 0..chunks {
1416 let base = chunk * 4;
1417
1418 let b0 = input[base] as i32;
1420 let b1 = input[base + 1] as i32;
1421 let b2 = input[base + 2] as i32;
1422 let b3 = input[base + 3] as i32;
1423
1424 let ints = _mm_set_epi32(b3, b2, b1, b0);
1425 let floats = _mm_cvtepi32_ps(ints);
1426
1427 let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);
1429
1430 _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
1431 }
1432
1433 let base = chunks * 4;
1435 for i in 0..remainder {
1436 output[base + i] = input[base + i] as f32 * scale + min_val;
1437 }
1438}
1439
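/// Dot product of the first `count` elements of `a` and `b`, with NEON and SSE fast paths.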
1440#[inline]
1442pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
1443 #[cfg(target_arch = "aarch64")]
1444 {
1445 if neon::is_available() {
1446 return unsafe { dot_product_f32_neon(a, b, count) };
1447 }
1448 }
1449
1450 #[cfg(target_arch = "x86_64")]
1451 {
1452 if sse::is_available() {
1453 return unsafe { dot_product_f32_sse(a, b, count) };
1454 }
1455 }
1456
1457 let mut sum = 0.0f32;
1459 for i in 0..count {
1460 sum += a[i] * b[i];
1461 }
1462 sum
1463}
1464
1465#[cfg(target_arch = "aarch64")]
1466#[target_feature(enable = "neon")]
1467#[allow(unsafe_op_in_unsafe_fn)]
1468unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
1469 use std::arch::aarch64::*;
1470
1471 let chunks = count / 4;
1472 let remainder = count % 4;
1473
1474 let mut acc = vdupq_n_f32(0.0);
1475
1476 for chunk in 0..chunks {
1477 let base = chunk * 4;
1478 let va = vld1q_f32(a.as_ptr().add(base));
1479 let vb = vld1q_f32(b.as_ptr().add(base));
1480 acc = vfmaq_f32(acc, va, vb);
1481 }
1482
1483 let mut sum = vaddvq_f32(acc);
1485
1486 let base = chunks * 4;
1488 for i in 0..remainder {
1489 sum += a[base + i] * b[base + i];
1490 }
1491
1492 sum
1493}
1494
1495#[cfg(target_arch = "x86_64")]
1496#[target_feature(enable = "sse")]
1497#[allow(unsafe_op_in_unsafe_fn)]
1498unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
1499 use std::arch::x86_64::*;
1500
1501 let chunks = count / 4;
1502 let remainder = count % 4;
1503
1504 let mut acc = _mm_setzero_ps();
1505
1506 for chunk in 0..chunks {
1507 let base = chunk * 4;
1508 let va = _mm_loadu_ps(a.as_ptr().add(base));
1509 let vb = _mm_loadu_ps(b.as_ptr().add(base));
1510 acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
1511 }
1512
    // Horizontal sum of the four accumulator lanes.
    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
    let sums = _mm_add_ps(acc, shuf);
    let shuf2 = _mm_movehl_ps(sums, sums);
    let final_sum = _mm_add_ss(sums, shuf2);
    let mut sum = _mm_cvtss_f32(final_sum);
1520
1521 let base = chunks * 4;
1523 for i in 0..remainder {
1524 sum += a[base + i] * b[base + i];
1525 }
1526
1527 sum
1528}
1529
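/// Maximum of the first `count` elements; returns `f32::NEG_INFINITY` for an empty input.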
1530#[inline]
1532pub fn max_f32(values: &[f32], count: usize) -> f32 {
1533 if count == 0 {
1534 return f32::NEG_INFINITY;
1535 }
1536
1537 #[cfg(target_arch = "aarch64")]
1538 {
1539 if neon::is_available() {
1540 return unsafe { max_f32_neon(values, count) };
1541 }
1542 }
1543
1544 #[cfg(target_arch = "x86_64")]
1545 {
1546 if sse::is_available() {
1547 return unsafe { max_f32_sse(values, count) };
1548 }
1549 }
1550
1551 values[..count]
1553 .iter()
1554 .cloned()
1555 .fold(f32::NEG_INFINITY, f32::max)
1556}
1557
1558#[cfg(target_arch = "aarch64")]
1559#[target_feature(enable = "neon")]
1560#[allow(unsafe_op_in_unsafe_fn)]
1561unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
1562 use std::arch::aarch64::*;
1563
1564 let chunks = count / 4;
1565 let remainder = count % 4;
1566
1567 let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);
1568
1569 for chunk in 0..chunks {
1570 let base = chunk * 4;
1571 let v = vld1q_f32(values.as_ptr().add(base));
1572 max_v = vmaxq_f32(max_v, v);
1573 }
1574
1575 let mut max_val = vmaxvq_f32(max_v);
1577
1578 let base = chunks * 4;
1580 for i in 0..remainder {
1581 max_val = max_val.max(values[base + i]);
1582 }
1583
1584 max_val
1585}
1586
1587#[cfg(target_arch = "x86_64")]
1588#[target_feature(enable = "sse")]
1589#[allow(unsafe_op_in_unsafe_fn)]
1590unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
1591 use std::arch::x86_64::*;
1592
1593 let chunks = count / 4;
1594 let remainder = count % 4;
1595
1596 let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);
1597
1598 for chunk in 0..chunks {
1599 let base = chunk * 4;
1600 let v = _mm_loadu_ps(values.as_ptr().add(base));
1601 max_v = _mm_max_ps(max_v, v);
1602 }
1603
    // Horizontal max of the four accumulator lanes.
    let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01);
    let max1 = _mm_max_ps(max_v, shuf);
    let shuf2 = _mm_movehl_ps(max1, max1);
    let final_max = _mm_max_ss(max1, shuf2);
    let mut max_val = _mm_cvtss_f32(final_max);
1611
1612 let base = chunks * 4;
1614 for i in 0..remainder {
1615 max_val = max_val.max(values[base + i]);
1616 }
1617
1618 max_val
1619}
1620
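/// Computes the dot product of `a` and `b` together with the squared norm of `b` in a
/// single pass, so each element of `b` is only loaded once when scoring many vectors
/// against the same query.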
1621#[inline]
1630fn fused_dot_norm(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1631 #[cfg(target_arch = "aarch64")]
1632 {
1633 if neon::is_available() {
1634 return unsafe { fused_dot_norm_neon(a, b, count) };
1635 }
1636 }
1637
1638 #[cfg(target_arch = "x86_64")]
1639 {
1640 if sse::is_available() {
1641 return unsafe { fused_dot_norm_sse(a, b, count) };
1642 }
1643 }
1644
1645 let mut dot = 0.0f32;
1647 let mut norm_b = 0.0f32;
1648 for i in 0..count {
1649 dot += a[i] * b[i];
1650 norm_b += b[i] * b[i];
1651 }
1652 (dot, norm_b)
1653}
1654
1655#[cfg(target_arch = "aarch64")]
1656#[target_feature(enable = "neon")]
1657#[allow(unsafe_op_in_unsafe_fn)]
1658unsafe fn fused_dot_norm_neon(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1659 use std::arch::aarch64::*;
1660
1661 let chunks = count / 4;
1662 let remainder = count % 4;
1663
1664 let mut acc_dot = vdupq_n_f32(0.0);
1665 let mut acc_norm = vdupq_n_f32(0.0);
1666
1667 for chunk in 0..chunks {
1668 let base = chunk * 4;
1669 let va = vld1q_f32(a.as_ptr().add(base));
1670 let vb = vld1q_f32(b.as_ptr().add(base));
1671 acc_dot = vfmaq_f32(acc_dot, va, vb);
1672 acc_norm = vfmaq_f32(acc_norm, vb, vb);
1673 }
1674
1675 let mut dot = vaddvq_f32(acc_dot);
1676 let mut norm = vaddvq_f32(acc_norm);
1677
1678 let base = chunks * 4;
1679 for i in 0..remainder {
1680 dot += a[base + i] * b[base + i];
1681 norm += b[base + i] * b[base + i];
1682 }
1683
1684 (dot, norm)
1685}
1686
1687#[cfg(target_arch = "x86_64")]
1688#[target_feature(enable = "sse")]
1689#[allow(unsafe_op_in_unsafe_fn)]
1690unsafe fn fused_dot_norm_sse(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1691 use std::arch::x86_64::*;
1692
1693 let chunks = count / 4;
1694 let remainder = count % 4;
1695
1696 let mut acc_dot = _mm_setzero_ps();
1697 let mut acc_norm = _mm_setzero_ps();
1698
1699 for chunk in 0..chunks {
1700 let base = chunk * 4;
1701 let va = _mm_loadu_ps(a.as_ptr().add(base));
1702 let vb = _mm_loadu_ps(b.as_ptr().add(base));
1703 acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
1704 acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
1705 }
1706
1707 let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
1709 let sums_d = _mm_add_ps(acc_dot, shuf_d);
1710 let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
1711 let final_d = _mm_add_ss(sums_d, shuf2_d);
1712 let mut dot = _mm_cvtss_f32(final_d);
1713
1714 let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
1715 let sums_n = _mm_add_ps(acc_norm, shuf_n);
1716 let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
1717 let final_n = _mm_add_ss(sums_n, shuf2_n);
1718 let mut norm = _mm_cvtss_f32(final_n);
1719
1720 let base = chunks * 4;
1721 for i in 0..remainder {
1722 dot += a[base + i] * b[base + i];
1723 norm += b[base + i] * b[base + i];
1724 }
1725
1726 (dot, norm)
1727}
1728
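/// Cosine similarity of `query` against `scores.len()` vectors of length `dim` stored
/// contiguously (row-major) in `vectors`. A zero-norm query or vector scores 0.0.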
1729#[inline]
1739pub fn batch_cosine_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
1740 let n = scores.len();
1741 debug_assert!(vectors.len() >= n * dim);
1742 debug_assert_eq!(query.len(), dim);
1743
1744 if dim == 0 || n == 0 {
1745 return;
1746 }
1747
1748 let norm_q_sq = dot_product_f32(query, query, dim);
1750 if norm_q_sq < f32::EPSILON {
1751 for s in scores.iter_mut() {
1752 *s = 0.0;
1753 }
1754 return;
1755 }
1756 let norm_q = norm_q_sq.sqrt();
1757
1758 for i in 0..n {
1759 let vec = &vectors[i * dim..(i + 1) * dim];
1760 let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
1761 if norm_v_sq < f32::EPSILON {
1762 scores[i] = 0.0;
1763 } else {
1764 scores[i] = dot / (norm_q * norm_v_sq.sqrt());
1765 }
1766 }
1767}
1768
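/// Converts an f32 to IEEE 754 binary16 bits, truncating the mantissa (round toward zero)
/// and handling overflow to infinity, subnormal results, infinities, and NaN.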
1769#[inline]
1775pub fn f32_to_f16(value: f32) -> u16 {
1776 let bits = value.to_bits();
1777 let sign = (bits >> 16) & 0x8000;
1778 let exp = ((bits >> 23) & 0xFF) as i32;
1779 let mantissa = bits & 0x7F_FFFF;
1780
    if exp == 255 {
        // Infinity or NaN. Keep NaN-ness even when the payload bits that survive the
        // 13-bit shift are all zero.
        let nan_bit = if mantissa != 0 { 0x200u32 } else { 0 };
        return (sign | 0x7C00 | nan_bit | ((mantissa >> 13) & 0x3FF)) as u16;
    }
1785
1786 let exp16 = exp - 127 + 15;
1787
1788 if exp16 >= 31 {
        return (sign | 0x7C00) as u16;
    }
1791
1792 if exp16 <= 0 {
1793 if exp16 < -10 {
            return sign as u16;
        }
1796 let m = (mantissa | 0x80_0000) >> (1 - exp16);
1797 return (sign | (m >> 13)) as u16;
1798 }
1799
1800 (sign | ((exp16 as u32) << 10) | (mantissa >> 13)) as u16
1801}
1802
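/// Converts IEEE 754 binary16 bits back to f32, including subnormals, infinities, and NaN.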
1803#[inline]
1805pub fn f16_to_f32(half: u16) -> f32 {
1806 let sign = ((half & 0x8000) as u32) << 16;
1807 let exp = ((half >> 10) & 0x1F) as u32;
1808 let mantissa = (half & 0x3FF) as u32;
1809
1810 if exp == 0 {
1811 if mantissa == 0 {
1812 return f32::from_bits(sign);
1813 }
1814 let mut e = 0u32;
1816 let mut m = mantissa;
1817 while (m & 0x400) == 0 {
1818 m <<= 1;
1819 e += 1;
1820 }
1821 return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | ((m & 0x3FF) << 13));
1822 }
1823
1824 if exp == 31 {
1825 return f32::from_bits(sign | 0x7F80_0000 | (mantissa << 13));
1826 }
1827
1828 f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mantissa << 13))
1829}
1830
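// Quantization scale for mapping f32 values clamped to [-1.0, 1.0] onto the full u8 range
// (0..=255) and back.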
1831const U8_SCALE: f32 = 127.5;
1836const U8_INV_SCALE: f32 = 1.0 / 127.5;
1837
1838#[inline]
1840pub fn f32_to_u8_saturating(value: f32) -> u8 {
1841 ((value.clamp(-1.0, 1.0) + 1.0) * U8_SCALE) as u8
1842}
1843
1844#[inline]
1846pub fn u8_to_f32(byte: u8) -> f32 {
1847 byte as f32 * U8_INV_SCALE - 1.0
1848}
1849
1850pub fn batch_f32_to_f16(src: &[f32], dst: &mut [u16]) {
1856 debug_assert_eq!(src.len(), dst.len());
1857 for (s, d) in src.iter().zip(dst.iter_mut()) {
1858 *d = f32_to_f16(*s);
1859 }
1860}
1861
1862pub fn batch_f32_to_u8(src: &[f32], dst: &mut [u8]) {
1864 debug_assert_eq!(src.len(), dst.len());
1865 for (s, d) in src.iter().zip(dst.iter_mut()) {
1866 *d = f32_to_u8_saturating(*s);
1867 }
1868}
1869
1870#[cfg(target_arch = "aarch64")]
1875#[allow(unsafe_op_in_unsafe_fn)]
1876mod neon_quant {
1877 use std::arch::aarch64::*;
1878
1879 #[target_feature(enable = "neon")]
1885 pub unsafe fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
1886 let chunks8 = dim / 8;
1887 let remainder = dim % 8;
1888
1889 let mut acc_dot = vdupq_n_f32(0.0);
1890 let mut acc_norm = vdupq_n_f32(0.0);
1891
1892 for c in 0..chunks8 {
1893 let base = c * 8;
1894
1895 let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
1897 let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
1898 let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
1899
1900 let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
1902 let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
1903 let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
1904
1905 acc_dot = vfmaq_f32(acc_dot, q_lo, v_lo);
1906 acc_dot = vfmaq_f32(acc_dot, q_hi, v_hi);
1907 acc_norm = vfmaq_f32(acc_norm, v_lo, v_lo);
1908 acc_norm = vfmaq_f32(acc_norm, v_hi, v_hi);
1909 }
1910
1911 let mut dot = vaddvq_f32(acc_dot);
1912 let mut norm = vaddvq_f32(acc_norm);
1913
1914 let base = chunks8 * 8;
1915 for i in 0..remainder {
1916 let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
1917 let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
1918 dot += q * v;
1919 norm += v * v;
1920 }
1921
1922 (dot, norm)
1923 }
1924
1925 #[target_feature(enable = "neon")]
1928 pub unsafe fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
1929 let scale = vdupq_n_f32(super::U8_INV_SCALE);
1930 let offset = vdupq_n_f32(-1.0);
1931
1932 let chunks16 = dim / 16;
1933 let remainder = dim % 16;
1934
1935 let mut acc_dot = vdupq_n_f32(0.0);
1936 let mut acc_norm = vdupq_n_f32(0.0);
1937
1938 for c in 0..chunks16 {
1939 let base = c * 16;
1940
1941 let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
1943
1944 let lo8 = vget_low_u8(bytes);
1946 let hi8 = vget_high_u8(bytes);
1947 let lo16 = vmovl_u8(lo8);
1948 let hi16 = vmovl_u8(hi8);
1949
1950 let f0 = vaddq_f32(
1951 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
1952 offset,
1953 );
1954 let f1 = vaddq_f32(
1955 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
1956 offset,
1957 );
1958 let f2 = vaddq_f32(
1959 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
1960 offset,
1961 );
1962 let f3 = vaddq_f32(
1963 vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
1964 offset,
1965 );
1966
1967 let q0 = vld1q_f32(query.as_ptr().add(base));
1968 let q1 = vld1q_f32(query.as_ptr().add(base + 4));
1969 let q2 = vld1q_f32(query.as_ptr().add(base + 8));
1970 let q3 = vld1q_f32(query.as_ptr().add(base + 12));
1971
1972 acc_dot = vfmaq_f32(acc_dot, q0, f0);
1973 acc_dot = vfmaq_f32(acc_dot, q1, f1);
1974 acc_dot = vfmaq_f32(acc_dot, q2, f2);
1975 acc_dot = vfmaq_f32(acc_dot, q3, f3);
1976
1977 acc_norm = vfmaq_f32(acc_norm, f0, f0);
1978 acc_norm = vfmaq_f32(acc_norm, f1, f1);
1979 acc_norm = vfmaq_f32(acc_norm, f2, f2);
1980 acc_norm = vfmaq_f32(acc_norm, f3, f3);
1981 }
1982
1983 let mut dot = vaddvq_f32(acc_dot);
1984 let mut norm = vaddvq_f32(acc_norm);
1985
1986 let base = chunks16 * 16;
1987 for i in 0..remainder {
1988 let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
1989 dot += *query.get_unchecked(base + i) * v;
1990 norm += v * v;
1991 }
1992
1993 (dot, norm)
1994 }
1995}
1996
1997#[allow(dead_code)]
2002fn fused_dot_norm_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2003 let mut dot = 0.0f32;
2004 let mut norm = 0.0f32;
2005 for i in 0..dim {
2006 let v = f16_to_f32(vec_f16[i]);
2007 let q = f16_to_f32(query_f16[i]);
2008 dot += q * v;
2009 norm += v * v;
2010 }
2011 (dot, norm)
2012}
2013
2014#[allow(dead_code)]
2015fn fused_dot_norm_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2016 let mut dot = 0.0f32;
2017 let mut norm = 0.0f32;
2018 for i in 0..dim {
2019 let v = u8_to_f32(vec_u8[i]);
2020 dot += query[i] * v;
2021 norm += v * v;
2022 }
2023 (dot, norm)
2024}
2025
2026#[inline]
2031fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2032 #[cfg(target_arch = "aarch64")]
2033 {
2034 unsafe { neon_quant::fused_dot_norm_f16(query_f16, vec_f16, dim) }
2035 }
2036 #[cfg(not(target_arch = "aarch64"))]
2037 {
2038 fused_dot_norm_f16_scalar(query_f16, vec_f16, dim)
2039 }
2040}
2041
2042#[inline]
2043fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2044 #[cfg(target_arch = "aarch64")]
2045 {
2046 unsafe { neon_quant::fused_dot_norm_u8(query, vec_u8, dim) }
2047 }
2048 #[cfg(not(target_arch = "aarch64"))]
2049 {
2050 fused_dot_norm_u8_scalar(query, vec_u8, dim)
2051 }
2052}
2053
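/// Cosine scores of an f32 `query` against vectors stored as raw f16 bits (2 bytes per
/// component, native byte order). The query is converted to f16 once up front; the vector
/// buffer must be 2-byte aligned (checked with a debug assertion).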
2054#[inline]
2065pub fn batch_cosine_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2066 let n = scores.len();
2067 if dim == 0 || n == 0 {
2068 return;
2069 }
2070
2071 let norm_q_sq = dot_product_f32(query, query, dim);
2073 if norm_q_sq < f32::EPSILON {
2074 for s in scores.iter_mut() {
2075 *s = 0.0;
2076 }
2077 return;
2078 }
2079 let norm_q = norm_q_sq.sqrt();
2080
2081 let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
2083
2084 let vec_bytes = dim * 2;
2085 debug_assert!(vectors_raw.len() >= n * vec_bytes);
2086
2087 debug_assert!(
2090 (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
2091 "f16 vector data not 2-byte aligned"
2092 );
2093
2094 for i in 0..n {
2095 let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2096 let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2097
2098 let (dot, norm_v_sq) = fused_dot_norm_f16(&query_f16, f16_slice, dim);
2099 scores[i] = if norm_v_sq < f32::EPSILON {
2100 0.0
2101 } else {
2102 dot / (norm_q * norm_v_sq.sqrt())
2103 };
2104 }
2105}
2106
2107#[inline]
2113pub fn batch_cosine_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2114 let n = scores.len();
2115 if dim == 0 || n == 0 {
2116 return;
2117 }
2118
2119 let norm_q_sq = dot_product_f32(query, query, dim);
2120 if norm_q_sq < f32::EPSILON {
2121 for s in scores.iter_mut() {
2122 *s = 0.0;
2123 }
2124 return;
2125 }
2126 let norm_q = norm_q_sq.sqrt();
2127
2128 debug_assert!(vectors_raw.len() >= n * dim);
2129
2130 for i in 0..n {
2131 let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
2132
2133 let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
2134 scores[i] = if norm_v_sq < f32::EPSILON {
2135 0.0
2136 } else {
2137 dot / (norm_q * norm_v_sq.sqrt())
2138 };
2139 }
2140}
2141
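/// Cosine similarity between two equal-length f32 slices; returns 0.0 for empty or
/// zero-norm inputs.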
2142#[inline]
2147pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
2148 debug_assert_eq!(a.len(), b.len());
2149 let count = a.len();
2150
2151 if count == 0 {
2152 return 0.0;
2153 }
2154
2155 let dot = dot_product_f32(a, b, count);
2156 let norm_a = dot_product_f32(a, a, count);
2157 let norm_b = dot_product_f32(b, b, count);
2158
2159 let denom = (norm_a * norm_b).sqrt();
2160 if denom < f32::EPSILON {
2161 return 0.0;
2162 }
2163
2164 dot / denom
2165}
2166
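/// Squared Euclidean distance between two equal-length f32 slices, with AVX2, SSE, and
/// NEON fast paths.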
2167#[inline]
2171pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
2172 debug_assert_eq!(a.len(), b.len());
2173 let count = a.len();
2174
2175 if count == 0 {
2176 return 0.0;
2177 }
2178
2179 #[cfg(target_arch = "aarch64")]
2180 {
2181 if neon::is_available() {
2182 return unsafe { squared_euclidean_neon(a, b, count) };
2183 }
2184 }
2185
2186 #[cfg(target_arch = "x86_64")]
2187 {
2188 if avx2::is_available() {
2189 return unsafe { squared_euclidean_avx2(a, b, count) };
2190 }
2191 if sse::is_available() {
2192 return unsafe { squared_euclidean_sse(a, b, count) };
2193 }
2194 }
2195
2196 a.iter()
2198 .zip(b.iter())
2199 .map(|(&x, &y)| {
2200 let d = x - y;
2201 d * d
2202 })
2203 .sum()
2204}
2205
2206#[cfg(target_arch = "aarch64")]
2207#[target_feature(enable = "neon")]
2208#[allow(unsafe_op_in_unsafe_fn)]
2209unsafe fn squared_euclidean_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
2210 use std::arch::aarch64::*;
2211
2212 let chunks = count / 4;
2213 let remainder = count % 4;
2214
2215 let mut acc = vdupq_n_f32(0.0);
2216
2217 for chunk in 0..chunks {
2218 let base = chunk * 4;
2219 let va = vld1q_f32(a.as_ptr().add(base));
2220 let vb = vld1q_f32(b.as_ptr().add(base));
2221 let diff = vsubq_f32(va, vb);
        acc = vfmaq_f32(acc, diff, diff);
    }
2224
2225 let mut sum = vaddvq_f32(acc);
2227
2228 let base = chunks * 4;
2230 for i in 0..remainder {
2231 let d = a[base + i] - b[base + i];
2232 sum += d * d;
2233 }
2234
2235 sum
2236}
2237
2238#[cfg(target_arch = "x86_64")]
2239#[target_feature(enable = "sse")]
2240#[allow(unsafe_op_in_unsafe_fn)]
2241unsafe fn squared_euclidean_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
2242 use std::arch::x86_64::*;
2243
2244 let chunks = count / 4;
2245 let remainder = count % 4;
2246
2247 let mut acc = _mm_setzero_ps();
2248
2249 for chunk in 0..chunks {
2250 let base = chunk * 4;
2251 let va = _mm_loadu_ps(a.as_ptr().add(base));
2252 let vb = _mm_loadu_ps(b.as_ptr().add(base));
2253 let diff = _mm_sub_ps(va, vb);
2254 acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
2255 }
2256
    // Horizontal sum of the four accumulator lanes.
    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
    let sums = _mm_add_ps(acc, shuf);
    let shuf2 = _mm_movehl_ps(sums, sums);
    let final_sum = _mm_add_ss(sums, shuf2);
    let mut sum = _mm_cvtss_f32(final_sum);
2264
2265 let base = chunks * 4;
2267 for i in 0..remainder {
2268 let d = a[base + i] - b[base + i];
2269 sum += d * d;
2270 }
2271
2272 sum
2273}
2274
2275#[cfg(target_arch = "x86_64")]
2276#[target_feature(enable = "avx2")]
2277#[allow(unsafe_op_in_unsafe_fn)]
2278unsafe fn squared_euclidean_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
2279 use std::arch::x86_64::*;
2280
2281 let chunks = count / 8;
2282 let remainder = count % 8;
2283
2284 let mut acc = _mm256_setzero_ps();
2285
2286 for chunk in 0..chunks {
2287 let base = chunk * 8;
2288 let va = _mm256_loadu_ps(a.as_ptr().add(base));
2289 let vb = _mm256_loadu_ps(b.as_ptr().add(base));
2290 let diff = _mm256_sub_ps(va, vb);
        acc = _mm256_fmadd_ps(diff, diff, acc);
    }
2293
2294 let high = _mm256_extractf128_ps(acc, 1);
2297 let low = _mm256_castps256_ps128(acc);
2298 let sum128 = _mm_add_ps(low, high);
2299
2300 let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
2302 let sums = _mm_add_ps(sum128, shuf);
2303 let shuf2 = _mm_movehl_ps(sums, sums);
2304 let final_sum = _mm_add_ss(sums, shuf2);
2305
2306 let mut sum = _mm_cvtss_f32(final_sum);
2307
2308 let base = chunks * 8;
2310 for i in 0..remainder {
2311 let d = a[base + i] - b[base + i];
2312 sum += d * d;
2313 }
2314
2315 sum
2316}
2317
2318#[inline]
2324pub fn batch_squared_euclidean_distances(
2325 query: &[f32],
2326 vectors: &[Vec<f32>],
2327 distances: &mut [f32],
2328) {
2329 debug_assert_eq!(vectors.len(), distances.len());
2330
2331 #[cfg(target_arch = "x86_64")]
2332 {
2333 if avx2::is_available() {
2334 for (i, vec) in vectors.iter().enumerate() {
2335 distances[i] = unsafe { squared_euclidean_avx2(query, vec, query.len()) };
2336 }
2337 return;
2338 }
2339 }
2340
2341 for (i, vec) in vectors.iter().enumerate() {
2343 distances[i] = squared_euclidean_distance(query, vec);
2344 }
2345}
2346
2347#[cfg(test)]
2348mod tests {
2349 use super::*;
2350
2351 #[test]
2352 fn test_unpack_8bit() {
2353 let input: Vec<u8> = (0..128).collect();
2354 let mut output = vec![0u32; 128];
2355 unpack_8bit(&input, &mut output, 128);
2356
2357 for (i, &v) in output.iter().enumerate() {
2358 assert_eq!(v, i as u32);
2359 }
2360 }
2361
2362 #[test]
2363 fn test_unpack_16bit() {
2364 let mut input = vec![0u8; 256];
2365 for i in 0..128 {
2366 let val = (i * 100) as u16;
2367 input[i * 2] = val as u8;
2368 input[i * 2 + 1] = (val >> 8) as u8;
2369 }
2370
2371 let mut output = vec![0u32; 128];
2372 unpack_16bit(&input, &mut output, 128);
2373
2374 for (i, &v) in output.iter().enumerate() {
2375 assert_eq!(v, (i * 100) as u32);
2376 }
2377 }
2378
2379 #[test]
2380 fn test_unpack_32bit() {
2381 let mut input = vec![0u8; 512];
2382 for i in 0..128 {
2383 let val = (i * 1000) as u32;
2384 let bytes = val.to_le_bytes();
2385 input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
2386 }
2387
2388 let mut output = vec![0u32; 128];
2389 unpack_32bit(&input, &mut output, 128);
2390
2391 for (i, &v) in output.iter().enumerate() {
2392 assert_eq!(v, (i * 1000) as u32);
2393 }
2394 }
2395
2396 #[test]
2397 fn test_delta_decode() {
2398 let deltas = vec![4u32, 4, 9, 19];
2402 let mut output = vec![0u32; 5];
2403
2404 delta_decode(&mut output, &deltas, 10, 5);
2405
2406 assert_eq!(output, vec![10, 15, 20, 30, 50]);
2407 }
2408
2409 #[test]
2410 fn test_add_one() {
2411 let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
2412 add_one(&mut values, 8);
2413
2414 assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
2415 }
2416
2417 #[test]
2418 fn test_bits_needed() {
2419 assert_eq!(bits_needed(0), 0);
2420 assert_eq!(bits_needed(1), 1);
2421 assert_eq!(bits_needed(2), 2);
2422 assert_eq!(bits_needed(3), 2);
2423 assert_eq!(bits_needed(4), 3);
2424 assert_eq!(bits_needed(255), 8);
2425 assert_eq!(bits_needed(256), 9);
2426 assert_eq!(bits_needed(u32::MAX), 32);
2427 }
2428
2429 #[test]
2430 fn test_unpack_8bit_delta_decode() {
2431 let input: Vec<u8> = vec![4, 4, 9, 19];
2435 let mut output = vec![0u32; 5];
2436
2437 unpack_8bit_delta_decode(&input, &mut output, 10, 5);
2438
2439 assert_eq!(output, vec![10, 15, 20, 30, 50]);
2440 }
2441
2442 #[test]
2443 fn test_unpack_16bit_delta_decode() {
2444 let mut input = vec![0u8; 8];
2448 for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
2449 input[i * 2] = delta as u8;
2450 input[i * 2 + 1] = (delta >> 8) as u8;
2451 }
2452 let mut output = vec![0u32; 5];
2453
2454 unpack_16bit_delta_decode(&input, &mut output, 100, 5);
2455
2456 assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
2457 }
2458
2459 #[test]
2460 fn test_fused_vs_separate_8bit() {
2461 let input: Vec<u8> = (0..127).collect();
2463 let first_value = 1000u32;
2464 let count = 128;
2465
2466 let mut unpacked = vec![0u32; 128];
2468 unpack_8bit(&input, &mut unpacked, 127);
2469 let mut separate_output = vec![0u32; 128];
2470 delta_decode(&mut separate_output, &unpacked, first_value, count);
2471
2472 let mut fused_output = vec![0u32; 128];
2474 unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
2475
2476 assert_eq!(separate_output, fused_output);
2477 }
2478
2479 #[test]
2480 fn test_round_bit_width() {
2481 assert_eq!(round_bit_width(0), 0);
2482 assert_eq!(round_bit_width(1), 8);
2483 assert_eq!(round_bit_width(5), 8);
2484 assert_eq!(round_bit_width(8), 8);
2485 assert_eq!(round_bit_width(9), 16);
2486 assert_eq!(round_bit_width(12), 16);
2487 assert_eq!(round_bit_width(16), 16);
2488 assert_eq!(round_bit_width(17), 32);
2489 assert_eq!(round_bit_width(24), 32);
2490 assert_eq!(round_bit_width(32), 32);
2491 }

    #[test]
    fn test_rounded_bitwidth_from_exact() {
        assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
        assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
        assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
        assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
        assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
        assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
        assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
    }

    #[test]
    fn test_pack_unpack_rounded_8bit() {
        let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
        let mut packed = vec![0u8; 128];

        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
        assert_eq!(bytes_written, 128);

        let mut unpacked = vec![0u32; 128];
        unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_rounded_16bit() {
        let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
        let mut packed = vec![0u8; 256];

        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
        assert_eq!(bytes_written, 256);

        let mut unpacked = vec![0u32; 128];
        unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_rounded_32bit() {
        let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
        let mut packed = vec![0u8; 512];

        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
        assert_eq!(bytes_written, 512);

        let mut unpacked = vec![0u32; 128];
        unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_unpack_rounded_delta_decode() {
        // Same byte-packed deltas as test_unpack_8bit_delta_decode, dispatched
        // through the rounded-width entry point.
        let input: Vec<u8> = vec![4, 4, 9, 19];
        let mut output = vec![0u32; 5];

        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }

    #[test]
    fn test_unpack_rounded_delta_decode_zero() {
        // Zero bit width carries no payload: every delta is 0, so the doc ids
        // are strictly consecutive starting from the first value.
        let input: Vec<u8> = vec![];
        let mut output = vec![0u32; 5];

        unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);

        assert_eq!(output, vec![100, 101, 102, 103, 104]);
    }

    #[test]
    fn test_dequantize_uint8() {
        // Expected values follow output = input * scale + min_val.
        let input: Vec<u8> = vec![0, 128, 255, 64, 192];
        let mut output = vec![0.0f32; 5];
        let scale = 0.1;
        let min_val = 1.0;

        dequantize_uint8(&input, &mut output, scale, min_val, 5);

        assert!((output[0] - 1.0).abs() < 1e-6);
        assert!((output[1] - 13.8).abs() < 1e-6);
        assert!((output[2] - 26.5).abs() < 1e-6);
        assert!((output[3] - 7.4).abs() < 1e-6);
        assert!((output[4] - 20.2).abs() < 1e-6);
    }

    #[test]
    fn test_dequantize_uint8_large() {
        let input: Vec<u8> = (0..128).collect();
        let mut output = vec![0.0f32; 128];
        let scale = 2.0;
        let min_val = -10.0;

        dequantize_uint8(&input, &mut output, scale, min_val, 128);

        for (i, &out) in output.iter().enumerate().take(128) {
            let expected = i as f32 * scale + min_val;
            assert!(
                (out - expected).abs() < 1e-5,
                "Mismatch at {}: expected {}, got {}",
                i,
                expected,
                out
            );
        }
    }

    #[test]
    fn test_dot_product_f32() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];

        let result = dot_product_f32(&a, &b, 5);

        // 1*2 + 2*3 + 3*4 + 4*5 + 5*6 = 70
        assert!((result - 70.0).abs() < 1e-5);
    }

    #[test]
    fn test_dot_product_f32_large() {
        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();

        let result = dot_product_f32(&a, &b, 128);

        let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
        assert!(
            (result - expected).abs() < 1e-3,
            "Expected {}, got {}",
            expected,
            result
        );
    }
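
    // Hedged addition (not in the original suite): a count that is not a multiple
    // of the vector width forces the scalar remainder path.
    #[test]
    fn test_dot_product_f32_odd_count() {
        let a: Vec<f32> = (0..13).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..13).map(|i| (2 * i) as f32).collect();

        let result = dot_product_f32(&a, &b, 13);

        // sum of 2*i^2 for i in 0..13 = 2 * 650 = 1300
        let expected: f32 = (0..13).map(|i| (2 * i * i) as f32).sum();
        assert!(
            (result - expected).abs() < 1e-4,
            "Expected {}, got {}",
            expected,
            result
        );
    }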

    #[test]
    fn test_max_f32() {
        let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
        let result = max_f32(&values, 6);
        assert!((result - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_f32_large() {
        let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
        values[77] = 1000.0;

        let result = max_f32(&values, 128);
        assert!((result - 1000.0).abs() < 1e-5);
    }

    #[test]
    fn test_max_f32_negative() {
        let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
        let result = max_f32(&values, 5);
        assert!((result - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_max_f32_empty() {
        let values: Vec<f32> = vec![];
        let result = max_f32(&values, 0);
        assert_eq!(result, f32::NEG_INFINITY);
    }

    #[test]
    fn test_fused_dot_norm() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());

        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
        assert!(
            (dot - expected_dot).abs() < 1e-5,
            "dot: expected {}, got {}",
            expected_dot,
            dot
        );
        assert!(
            (norm_b - expected_norm).abs() < 1e-5,
            "norm: expected {}, got {}",
            expected_norm,
            norm_b
        );
    }

    #[test]
    fn test_fused_dot_norm_large() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect();
        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());

        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
        assert!(
            (dot - expected_dot).abs() < 1.0,
            "dot: expected {}, got {}",
            expected_dot,
            dot
        );
        assert!(
            (norm_b - expected_norm).abs() < 1.0,
            "norm: expected {}, got {}",
            expected_norm,
            norm_b
        );
    }
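
    // Hedged addition (not in the original suite): a cosine score assembled from
    // fused_dot_norm's (dot, |b|^2) and the query norm should match
    // cosine_similarity computed directly.
    #[test]
    fn test_fused_dot_norm_builds_cosine() {
        let a: Vec<f32> = (0..96).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..96).map(|i| (i as f32).cos()).collect();

        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
        let norm_a: f32 = a.iter().map(|x| x * x).sum();
        let assembled = dot / (norm_a.sqrt() * norm_b.sqrt());

        let direct = cosine_similarity(&a, &b);
        assert!(
            (assembled - direct).abs() < 1e-4,
            "assembled {}, direct {}",
            assembled,
            direct
        );
    }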

    #[test]
    fn test_batch_cosine_scores() {
        let query = vec![1.0f32, 0.0, 0.0];
        let vectors = vec![
            1.0, 0.0, 0.0, // identical to the query
            0.0, 1.0, 0.0, // orthogonal
            -1.0, 0.0, 0.0, // opposite
            0.5, 0.5, 0.0, // 45 degrees in the xy-plane
        ];
        let mut scores = vec![0f32; 4];
        batch_cosine_scores(&query, &vectors, 3, &mut scores);

        assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
        assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
        assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
        let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
        assert!(
            (scores[3] - expected_45).abs() < 1e-5,
            "45deg: expected {}, got {}",
            expected_45,
            scores[3]
        );
    }

    #[test]
    fn test_batch_cosine_scores_matches_individual() {
        let query: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
        let n = 50;
        let dim = 128;
        let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 7 + 3) as f32) * 0.01).collect();

        let mut batch_scores = vec![0f32; n];
        batch_cosine_scores(&query, &vectors, dim, &mut batch_scores);

        for i in 0..n {
            let vec_i = &vectors[i * dim..(i + 1) * dim];
            let individual = cosine_similarity(&query, vec_i);
            assert!(
                (batch_scores[i] - individual).abs() < 1e-5,
                "vec {}: batch={}, individual={}",
                i,
                batch_scores[i],
                individual
            );
        }
    }

    #[test]
    fn test_batch_cosine_scores_empty() {
        let query = vec![1.0f32, 2.0, 3.0];
        let vectors: Vec<f32> = vec![];
        let mut scores: Vec<f32> = vec![];
        batch_cosine_scores(&query, &vectors, 3, &mut scores);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_batch_cosine_scores_zero_query() {
        let query = vec![0.0f32, 0.0, 0.0];
        let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut scores = vec![0f32; 2];
        batch_cosine_scores(&query, &vectors, 3, &mut scores);
        assert_eq!(scores[0], 0.0);
        assert_eq!(scores[1], 0.0);
    }

    #[test]
    fn test_squared_euclidean_distance() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        let result = squared_euclidean_distance(&a, &b);
        assert!(
            (result - expected).abs() < 1e-5,
            "expected {}, got {}",
            expected,
            result
        );
    }

    #[test]
    fn test_squared_euclidean_distance_large() {
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        let result = squared_euclidean_distance(&a, &b);
        assert!(
            (result - expected).abs() < 1e-3,
            "expected {}, got {}",
            expected,
            result
        );
    }
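
    // Hedged addition (not in the original suite): cross-check the distance kernel
    // against the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b), computed
    // with the other kernels under test.
    #[test]
    fn test_squared_euclidean_matches_dot_identity() {
        let a: Vec<f32> = (0..64).map(|i| (i as f32) * 0.25).collect();
        let b: Vec<f32> = (0..64).map(|i| 16.0 - (i as f32) * 0.25).collect();

        let (dot_ab, norm_b) = fused_dot_norm(&a, &b, 64);
        let (_, norm_a) = fused_dot_norm(&b, &a, 64);
        let via_identity = norm_a + norm_b - 2.0 * dot_ab;

        let direct = squared_euclidean_distance(&a, &b);
        assert!(
            (direct - via_identity).abs() < 1.0,
            "direct {}, via identity {}",
            direct,
            via_identity
        );
    }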

    #[test]
    fn test_f16_roundtrip_normal() {
        // 65504.0 is the largest finite f16 value.
        for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 0.333, 65504.0] {
            let h = f32_to_f16(v);
            let back = f16_to_f32(h);
            let err = (back - v).abs() / v.abs().max(1e-6);
            assert!(
                err < 0.002,
                "f16 roundtrip {v} → {h:#06x} → {back}, rel err {err}"
            );
        }
    }

    #[test]
    fn test_f16_special() {
        // Positive zero survives the round trip.
        assert_eq!(f16_to_f32(f32_to_f16(0.0)), 0.0);
        // Negative zero keeps its sign bit.
        assert_eq!(f32_to_f16(-0.0), 0x8000);
        // Infinity and NaN are preserved as such.
        assert!(f16_to_f32(f32_to_f16(f32::INFINITY)).is_infinite());
        assert!(f16_to_f32(f32_to_f16(f32::NAN)).is_nan());
    }
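
    // Hedged addition (not in the original suite), assuming f32_to_f16 produces
    // standard IEEE 754 binary16 bit patterns (the -0.0 == 0x8000 check above
    // suggests it does): 1.0 is 0x3C00, -2.0 is 0xC000, +infinity is 0x7C00.
    #[test]
    fn test_f16_known_bit_patterns() {
        assert_eq!(f32_to_f16(1.0), 0x3C00);
        assert_eq!(f32_to_f16(-2.0), 0xC000);
        assert_eq!(f32_to_f16(f32::INFINITY), 0x7C00);
    }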

    #[test]
    fn test_f16_embedding_range() {
        // Typical normalized embedding components fall in [-1.0, 1.0].
        let values: Vec<f32> = (-100..=100).map(|i| i as f32 / 100.0).collect();
        for &v in &values {
            let back = f16_to_f32(f32_to_f16(v));
            assert!((back - v).abs() < 0.001, "f16 error for {v}: got {back}");
        }
    }

    #[test]
    fn test_u8_roundtrip() {
        // Endpoints of the [-1.0, 1.0] input range map to the ends of the u8 range.
        assert_eq!(f32_to_u8_saturating(-1.0), 0);
        assert_eq!(f32_to_u8_saturating(1.0), 255);
        // 0.0 lands in the middle of the code range.
        assert_eq!(f32_to_u8_saturating(0.0), 127);
        // Out-of-range inputs saturate instead of wrapping.
        assert_eq!(f32_to_u8_saturating(-2.0), 0);
        assert_eq!(f32_to_u8_saturating(2.0), 255);
    }

    #[test]
    fn test_u8_dequantize() {
        assert!((u8_to_f32(0) - (-1.0)).abs() < 0.01);
        assert!((u8_to_f32(255) - 1.0).abs() < 0.01);
        assert!((u8_to_f32(127) - 0.0).abs() < 0.01);
    }
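
    // Hedged addition (not in the original suite): whatever the exact affine
    // mapping, dequantization should be strictly increasing across the u8 range.
    #[test]
    fn test_u8_dequantize_monotone() {
        for q in 1..=255 {
            assert!(
                u8_to_f32(q) > u8_to_f32(q - 1),
                "dequantization not monotone at code {}",
                q
            );
        }
    }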

    #[test]
    fn test_batch_cosine_scores_f16() {
        let query = vec![0.6f32, 0.8, 0.0, 0.0];
        let dim = 4;
        let vecs_f32 = vec![
            0.6f32, 0.8, 0.0, 0.0, // identical to the query
            0.0, 0.0, 0.6, 0.8, // orthogonal to the query
        ];

        // Quantize to f16 and reinterpret the u16 buffer as raw bytes.
        let mut f16_buf = vec![0u16; 8];
        batch_f32_to_f16(&vecs_f32, &mut f16_buf);
        let raw: &[u8] =
            unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };

        let mut scores = vec![0f32; 2];
        batch_cosine_scores_f16(&query, raw, dim, &mut scores);

        assert!(
            (scores[0] - 1.0).abs() < 0.01,
            "identical vectors: {}",
            scores[0]
        );
        assert!(scores[1].abs() < 0.01, "orthogonal vectors: {}", scores[1]);
    }

    #[test]
    fn test_batch_cosine_scores_u8() {
        let query = vec![0.6f32, 0.8, 0.0, 0.0];
        let dim = 4;
        let vecs_f32 = vec![
            0.6f32, 0.8, 0.0, 0.0, // same direction as the query
            -0.6, -0.8, 0.0, 0.0, // opposite direction
        ];

        let mut u8_buf = vec![0u8; 8];
        batch_f32_to_u8(&vecs_f32, &mut u8_buf);

        let mut scores = vec![0f32; 2];
        batch_cosine_scores_u8(&query, &u8_buf, dim, &mut scores);

        assert!(scores[0] > 0.95, "similar vectors: {}", scores[0]);
        assert!(scores[1] < -0.95, "opposite vectors: {}", scores[1]);
    }

    #[test]
    fn test_batch_cosine_scores_f16_large_dim() {
        let dim = 768;
        let query: Vec<f32> = (0..dim).map(|i| (i as f32 / dim as f32) - 0.5).collect();
        let vec2: Vec<f32> = query.iter().map(|x| x * 0.9 + 0.01).collect();

        let mut all_vecs = query.clone();
        all_vecs.extend_from_slice(&vec2);

        let mut f16_buf = vec![0u16; all_vecs.len()];
        batch_f32_to_f16(&all_vecs, &mut f16_buf);
        let raw: &[u8] =
            unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };

        let mut scores = vec![0f32; 2];
        batch_cosine_scores_f16(&query, raw, dim, &mut scores);

        assert!((scores[0] - 1.0).abs() < 0.01, "self-sim: {}", scores[0]);
        assert!(scores[1] > 0.99, "scaled-sim: {}", scores[1]);
    }
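
    // Hedged addition (not in the original suite): on well-conditioned inputs the
    // f16 scorer should track the f32 scorer to within half-precision error.
    #[test]
    fn test_batch_cosine_scores_f16_matches_f32() {
        let dim = 64;
        let query: Vec<f32> = (0..dim).map(|i| ((i % 13) as f32 - 6.0) * 0.1).collect();
        let vectors: Vec<f32> = (0..dim * 3).map(|i| ((i % 29) as f32 - 14.0) * 0.05).collect();

        let mut f32_scores = vec![0f32; 3];
        batch_cosine_scores(&query, &vectors, dim, &mut f32_scores);

        let mut f16_buf = vec![0u16; vectors.len()];
        batch_f32_to_f16(&vectors, &mut f16_buf);
        let raw: &[u8] =
            unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
        let mut f16_scores = vec![0f32; 3];
        batch_cosine_scores_f16(&query, raw, dim, &mut f16_scores);

        for i in 0..3 {
            assert!(
                (f32_scores[i] - f16_scores[i]).abs() < 0.01,
                "vec {}: f32={}, f16={}",
                i,
                f32_scores[i],
                f16_scores[i]
            );
        }
    }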
}