#[cfg(target_arch = "aarch64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod neon {
    use std::arch::aarch64::*;

    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            let bytes = vld1q_u8(in_ptr);

            let low8 = vget_low_u8(bytes);
            let high8 = vget_high_u8(bytes);

            let low16 = vmovl_u8(low8);
            let high16 = vmovl_u8(high8);

            let v0 = vmovl_u16(vget_low_u16(low16));
            let v1 = vmovl_u16(vget_high_u16(low16));
            let v2 = vmovl_u16(vget_low_u16(high16));
            let v3 = vmovl_u16(vget_high_u16(high16));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, v0);
            vst1q_u32(out_ptr.add(4), v1);
            vst1q_u32(out_ptr.add(8), v2);
            vst1q_u32(out_ptr.add(12), v3);
        }

        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            let vals = vld1q_u16(in_ptr);
            let low = vmovl_u16(vget_low_u16(vals));
            let high = vmovl_u16(vget_high_u16(vals));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, low);
            vst1q_u32(out_ptr.add(4), high);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 4;
        let remainder = count % 4;

        let in_ptr = input.as_ptr() as *const u32;
        let out_ptr = output.as_mut_ptr();

        for chunk in 0..chunks {
            let vals = vld1q_u32(in_ptr.add(chunk * 4));
            vst1q_u32(out_ptr.add(chunk * 4), vals);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    #[inline]
    #[target_feature(enable = "neon")]
    unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
        // Inclusive prefix sum over 4 lanes: shift the vector up by one lane
        // (zeros shifted in) and add, then by two lanes and add.
        let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
        let sum1 = vaddq_u32(v, shifted1);

        let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
        vaddq_u32(sum1, shifted2)
    }

    #[target_feature(enable = "neon")]
    pub unsafe fn delta_decode(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_doc_id);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            let d = vld1q_u32(deltas[base..].as_ptr());
            let gaps = vaddq_u32(d, ones);

            let prefix = prefix_sum_4(gaps);

            let result = vaddq_u32(prefix, carry);

            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    #[target_feature(enable = "neon")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = vdupq_n_u32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base);
            let v = vld1q_u32(ptr);
            let result = vaddq_u32(v, ones);
            vst1q_u32(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_value);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            let b0 = input[base] as u32;
            let b1 = input[base + 1] as u32;
            let b2 = input[base + 2] as u32;
            let b3 = input[base + 3] as u32;
            let deltas = [b0, b1, b2, b3];
            let d = vld1q_u32(deltas.as_ptr());

            let gaps = vaddq_u32(d, ones);

            let prefix = prefix_sum_4(gaps);

            let result = vaddq_u32(prefix, carry);

            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            scalar_carry = scalar_carry
                .wrapping_add(input[base + j] as u32)
                .wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_value);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            let vals = vld1_u16(in_ptr);
            let d = vmovl_u16(vals);

            let gaps = vaddq_u32(d, ones);

            let prefix = prefix_sum_4(gaps);

            let result = vaddq_u32(prefix, carry);

            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            let idx = (base + j) * 2;
            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    #[inline]
    pub fn is_available() -> bool {
        true
    }
}

#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod sse {
    use std::arch::x86_64::*;

    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);

            let v0 = _mm_cvtepu8_epi32(bytes);
            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
        }

        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2);

            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
            let low = _mm_cvtepu16_epi32(vals);
            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, low);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    #[target_feature(enable = "sse2")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 4;
        let remainder = count % 4;

        let in_ptr = input.as_ptr() as *const __m128i;
        let out_ptr = output.as_mut_ptr() as *mut __m128i;

        for chunk in 0..chunks {
            let vals = _mm_loadu_si128(in_ptr.add(chunk));
            _mm_storeu_si128(out_ptr.add(chunk), vals);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
        // Inclusive prefix sum over 4 lanes: shift left by one lane (4 bytes)
        // and add, then by two lanes (8 bytes) and add.
        let shifted1 = _mm_slli_si128(v, 4);
        let sum1 = _mm_add_epi32(v, shifted1);

        let shifted2 = _mm_slli_si128(sum1, 8);
        _mm_add_epi32(sum1, shifted2)
    }

    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn delta_decode(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_doc_id as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
            let gaps = _mm_add_epi32(d, ones);

            let prefix = prefix_sum_4(gaps);

            let result = _mm_add_epi32(prefix, carry);

            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            carry = _mm_shuffle_epi32(result, 0xFF);
        }

        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    #[target_feature(enable = "sse2")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = _mm_set1_epi32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
            let v = _mm_loadu_si128(ptr);
            let result = _mm_add_epi32(v, ones);
            _mm_storeu_si128(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_value as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 delta bytes with an unaligned read and widen them to u32 lanes.
            let bytes =
                _mm_cvtsi32_si128((input.as_ptr().add(base) as *const i32).read_unaligned());
            let d = _mm_cvtepu8_epi32(bytes);

            let gaps = _mm_add_epi32(d, ones);

            let prefix = prefix_sum_4(gaps);

            let result = _mm_add_epi32(prefix, carry);

            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            carry = _mm_shuffle_epi32(result, 0xFF);
        }

        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            scalar_carry = scalar_carry
                .wrapping_add(input[base + j] as u32)
                .wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_value as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;
            let in_ptr = input.as_ptr().add(base * 2);

            let vals = _mm_loadl_epi64(in_ptr as *const __m128i);
            let d = _mm_cvtepu16_epi32(vals);

            let gaps = _mm_add_epi32(d, ones);

            let prefix = prefix_sum_4(gaps);

            let result = _mm_add_epi32(prefix, carry);

            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            carry = _mm_shuffle_epi32(result, 0xFF);
        }

        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            let idx = (base + j) * 2;
            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("sse4.1")
    }
}

#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod avx2 {
    use std::arch::x86_64::*;

    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 32;
        let remainder = count % 32;

        for chunk in 0..chunks {
            let base = chunk * 32;
            let in_ptr = input.as_ptr().add(base);

            let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
            let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);

            let v0 = _mm256_cvtepu8_epi32(bytes_lo);
            let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
            let v2 = _mm256_cvtepu8_epi32(bytes_hi);
            let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
            _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
            _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
        }

        let base = chunks * 32;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base * 2);

            let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
            let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);

            let v0 = _mm256_cvtepu16_epi32(vals_lo);
            let v1 = _mm256_cvtepu16_epi32(vals_hi);

            let out_ptr = output.as_mut_ptr().add(base);
            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
        }

        let base = chunks * 16;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        let in_ptr = input.as_ptr() as *const __m256i;
        let out_ptr = output.as_mut_ptr() as *mut __m256i;

        for chunk in 0..chunks {
            let vals = _mm256_loadu_si256(in_ptr.add(chunk));
            _mm256_storeu_si256(out_ptr.add(chunk), vals);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    #[target_feature(enable = "avx2")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = _mm256_set1_epi32(1);
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
            let v = _mm256_loadu_si256(ptr);
            let result = _mm256_add_epi32(v, ones);
            _mm256_storeu_si256(ptr, result);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("avx2")
    }
}

#[allow(dead_code)]
mod scalar {
    #[inline]
    pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        for i in 0..count {
            output[i] = input[i] as u32;
        }
    }

    #[inline]
    pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 2;
            *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    #[inline]
    pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 4;
            *out =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    #[inline]
    pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        let mut carry = first_doc_id;

        for i in 0..count - 1 {
            carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
            output[i + 1] = carry;
        }
    }

    #[inline]
    pub fn add_one(values: &mut [u32], count: usize) {
        for val in values.iter_mut().take(count) {
            *val += 1;
        }
    }
}

#[inline]
pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_8bit(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if avx2::is_available() {
            unsafe {
                avx2::unpack_8bit(input, output, count);
            }
            return;
        }
        if sse::is_available() {
            unsafe {
                sse::unpack_8bit(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_8bit(input, output, count);
}

#[inline]
pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_16bit(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if avx2::is_available() {
            unsafe {
                avx2::unpack_16bit(input, output, count);
            }
            return;
        }
        if sse::is_available() {
            unsafe {
                sse::unpack_16bit(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_16bit(input, output, count);
}

#[inline]
pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_32bit(input, output, count);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if avx2::is_available() {
            unsafe {
                avx2::unpack_32bit(input, output, count);
            }
        } else {
            unsafe {
                sse::unpack_32bit(input, output, count);
            }
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::unpack_32bit(input, output, count);
    }
}

#[inline]
pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::delta_decode(output, deltas, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::delta_decode(output, deltas, first_value, count);
            }
            return;
        }
    }

    scalar::delta_decode(output, deltas, first_value, count);
}

#[inline]
pub fn add_one(values: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::add_one(values, count);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if avx2::is_available() {
            unsafe {
                avx2::add_one(values, count);
            }
        } else {
            unsafe {
                sse::add_one(values, count);
            }
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::add_one(values, count);
    }
}

#[inline]
pub fn bits_needed(val: u32) -> u8 {
    if val == 0 {
        0
    } else {
        32 - val.leading_zeros() as u8
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RoundedBitWidth {
    Zero = 0,
    Bits8 = 8,
    Bits16 = 16,
    Bits32 = 32,
}

impl RoundedBitWidth {
    #[inline]
    pub fn from_exact(bits: u8) -> Self {
        match bits {
            0 => RoundedBitWidth::Zero,
            1..=8 => RoundedBitWidth::Bits8,
            9..=16 => RoundedBitWidth::Bits16,
            _ => RoundedBitWidth::Bits32,
        }
    }

    #[inline]
    pub fn from_u8(bits: u8) -> Self {
        match bits {
            0 => RoundedBitWidth::Zero,
            8 => RoundedBitWidth::Bits8,
            16 => RoundedBitWidth::Bits16,
            32 => RoundedBitWidth::Bits32,
            _ => RoundedBitWidth::Bits32,
        }
    }

    #[inline]
    pub fn bytes_per_value(self) -> usize {
        match self {
            RoundedBitWidth::Zero => 0,
            RoundedBitWidth::Bits8 => 1,
            RoundedBitWidth::Bits16 => 2,
            RoundedBitWidth::Bits32 => 4,
        }
    }

    #[inline]
    pub fn as_u8(self) -> u8 {
        self as u8
    }
}

#[inline]
pub fn round_bit_width(bits: u8) -> u8 {
    RoundedBitWidth::from_exact(bits).as_u8()
}

#[inline]
pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
    let count = values.len();
    match bit_width {
        RoundedBitWidth::Zero => 0,
        RoundedBitWidth::Bits8 => {
            for (i, &v) in values.iter().enumerate() {
                output[i] = v as u8;
            }
            count
        }
        RoundedBitWidth::Bits16 => {
            for (i, &v) in values.iter().enumerate() {
                let bytes = (v as u16).to_le_bytes();
                output[i * 2] = bytes[0];
                output[i * 2 + 1] = bytes[1];
            }
            count * 2
        }
        RoundedBitWidth::Bits32 => {
            for (i, &v) in values.iter().enumerate() {
                let bytes = v.to_le_bytes();
                output[i * 4] = bytes[0];
                output[i * 4 + 1] = bytes[1];
                output[i * 4 + 2] = bytes[2];
                output[i * 4 + 3] = bytes[3];
            }
            count * 4
        }
    }
}

#[inline]
pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
    match bit_width {
        RoundedBitWidth::Zero => {
            for out in output.iter_mut().take(count) {
                *out = 0;
            }
        }
        RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
        RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
        RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
    }
}

#[inline]
pub fn unpack_rounded_delta_decode(
    input: &[u8],
    bit_width: RoundedBitWidth,
    output: &mut [u32],
    first_value: u32,
    count: usize,
) {
    match bit_width {
        RoundedBitWidth::Zero => {
            let mut val = first_value;
            for out in output.iter_mut().take(count) {
                *out = val;
                val = val.wrapping_add(1);
            }
        }
        RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
        RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
        RoundedBitWidth::Bits32 => {
            if count > 0 {
                // Unpack the count - 1 delta values into output[1..], matching the
                // 8-/16-bit paths, then turn them into absolute values in place.
                output[0] = first_value;
                unpack_32bit(input, &mut output[1..], count - 1);
                let mut carry = first_value;
                for item in output.iter_mut().take(count).skip(1) {
                    carry = carry.wrapping_add(*item).wrapping_add(1);
                    *item = carry;
                }
            }
        }
    }
}

#[inline]
pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_8bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::unpack_8bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    let mut carry = first_value;
    for i in 0..count - 1 {
        carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
        output[i + 1] = carry;
    }
}

#[inline]
pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_16bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::unpack_16bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    let mut carry = first_value;
    for i in 0..count - 1 {
        let idx = i * 2;
        let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        carry = carry.wrapping_add(delta).wrapping_add(1);
        output[i + 1] = carry;
    }
}

#[inline]
pub fn unpack_delta_decode(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32],
    first_value: u32,
    count: usize,
) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    match bit_width {
        0 => {
            let mut val = first_value;
            for item in output.iter_mut().take(count).skip(1) {
                val = val.wrapping_add(1);
                *item = val;
            }
        }
        8 => unpack_8bit_delta_decode(input, output, first_value, count),
        16 => unpack_16bit_delta_decode(input, output, first_value, count),
        32 => {
            let mut carry = first_value;
            for i in 0..count - 1 {
                let idx = i * 4;
                let delta = u32::from_le_bytes([
                    input[idx],
                    input[idx + 1],
                    input[idx + 2],
                    input[idx + 3],
                ]);
                carry = carry.wrapping_add(delta).wrapping_add(1);
                output[i + 1] = carry;
            }
        }
        _ => {
            // Generic bit widths: read an unaligned 8-byte window and mask out each
            // delta. The window can extend up to 7 bytes past the last packed delta,
            // so the input buffer is assumed to be padded accordingly.
            let mask = (1u64 << bit_width) - 1;
            let bit_width_usize = bit_width as usize;
            let mut bit_pos = 0usize;
            let input_ptr = input.as_ptr();
            let mut carry = first_value;

            for i in 0..count - 1 {
                let byte_idx = bit_pos >> 3;
                let bit_offset = bit_pos & 7;

                let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
                let delta = ((word >> bit_offset) & mask) as u32;

                carry = carry.wrapping_add(delta).wrapping_add(1);
                output[i + 1] = carry;
                bit_pos += bit_width_usize;
            }
        }
    }
}

#[inline]
pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                dequantize_uint8_neon(input, output, scale, min_val, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                dequantize_uint8_sse(input, output, scale, min_val, count);
            }
            return;
        }
    }

    for i in 0..count {
        output[i] = input[i] as f32 * scale + min_val;
    }
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dequantize_uint8_neon(
    input: &[u8],
    output: &mut [f32],
    scale: f32,
    min_val: f32,
    count: usize,
) {
    use std::arch::aarch64::*;

    let scale_v = vdupq_n_f32(scale);
    let min_v = vdupq_n_f32(min_val);

    let chunks = count / 16;
    let remainder = count % 16;

    for chunk in 0..chunks {
        let base = chunk * 16;
        let in_ptr = input.as_ptr().add(base);

        let bytes = vld1q_u8(in_ptr);

        let low8 = vget_low_u8(bytes);
        let high8 = vget_high_u8(bytes);

        let low16 = vmovl_u8(low8);
        let high16 = vmovl_u8(high8);

        let u32_0 = vmovl_u16(vget_low_u16(low16));
        let u32_1 = vmovl_u16(vget_high_u16(low16));
        let u32_2 = vmovl_u16(vget_low_u16(high16));
        let u32_3 = vmovl_u16(vget_high_u16(high16));

        let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
        let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
        let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
        let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);

        let out_ptr = output.as_mut_ptr().add(base);
        vst1q_f32(out_ptr, f32_0);
        vst1q_f32(out_ptr.add(4), f32_1);
        vst1q_f32(out_ptr.add(8), f32_2);
        vst1q_f32(out_ptr.add(12), f32_3);
    }

    let base = chunks * 16;
    for i in 0..remainder {
        output[base + i] = input[base + i] as f32 * scale + min_val;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dequantize_uint8_sse(
    input: &[u8],
    output: &mut [f32],
    scale: f32,
    min_val: f32,
    count: usize,
) {
    use std::arch::x86_64::*;

    let scale_v = _mm_set1_ps(scale);
    let min_v = _mm_set1_ps(min_val);

    let chunks = count / 4;
    let remainder = count % 4;

    for chunk in 0..chunks {
        let base = chunk * 4;

        let b0 = input[base] as i32;
        let b1 = input[base + 1] as i32;
        let b2 = input[base + 2] as i32;
        let b3 = input[base + 3] as i32;

        let ints = _mm_set_epi32(b3, b2, b1, b0);
        let floats = _mm_cvtepi32_ps(ints);

        let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);

        _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
    }

    let base = chunks * 4;
    for i in 0..remainder {
        output[base + i] = input[base + i] as f32 * scale + min_val;
    }
}

#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            return unsafe { dot_product_f32_neon(a, b, count) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            return unsafe { dot_product_f32_sse(a, b, count) };
        }
    }

    let mut sum = 0.0f32;
    for i in 0..count {
        sum += a[i] * b[i];
    }
    sum
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
    use std::arch::aarch64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut acc = vdupq_n_f32(0.0);

    for chunk in 0..chunks {
        let base = chunk * 4;
        let va = vld1q_f32(a.as_ptr().add(base));
        let vb = vld1q_f32(b.as_ptr().add(base));
        acc = vfmaq_f32(acc, va, vb);
    }

    let mut sum = vaddvq_f32(acc);

    let base = chunks * 4;
    for i in 0..remainder {
        sum += a[base + i] * b[base + i];
    }

    sum
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
    use std::arch::x86_64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut acc = _mm_setzero_ps();

    for chunk in 0..chunks {
        let base = chunk * 4;
        let va = _mm_loadu_ps(a.as_ptr().add(base));
        let vb = _mm_loadu_ps(b.as_ptr().add(base));
        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
    }

    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
    let sums = _mm_add_ps(acc, shuf);
    let shuf2 = _mm_movehl_ps(sums, sums);
    let final_sum = _mm_add_ss(sums, shuf2);
    let mut sum = _mm_cvtss_f32(final_sum);

    let base = chunks * 4;
    for i in 0..remainder {
        sum += a[base + i] * b[base + i];
    }

    sum
}

#[inline]
pub fn max_f32(values: &[f32], count: usize) -> f32 {
    if count == 0 {
        return f32::NEG_INFINITY;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            return unsafe { max_f32_neon(values, count) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            return unsafe { max_f32_sse(values, count) };
        }
    }

    values[..count]
        .iter()
        .cloned()
        .fold(f32::NEG_INFINITY, f32::max)
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
    use std::arch::aarch64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);

    for chunk in 0..chunks {
        let base = chunk * 4;
        let v = vld1q_f32(values.as_ptr().add(base));
        max_v = vmaxq_f32(max_v, v);
    }

    let mut max_val = vmaxvq_f32(max_v);

    let base = chunks * 4;
    for i in 0..remainder {
        max_val = max_val.max(values[base + i]);
    }

    max_val
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
    use std::arch::x86_64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);

    for chunk in 0..chunks {
        let base = chunk * 4;
        let v = _mm_loadu_ps(values.as_ptr().add(base));
        max_v = _mm_max_ps(max_v, v);
    }

    let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01);
    let max1 = _mm_max_ps(max_v, shuf);
    let shuf2 = _mm_movehl_ps(max1, max1);
    let final_max = _mm_max_ss(max1, shuf2);
    let mut max_val = _mm_cvtss_f32(final_max);

    let base = chunks * 4;
    for i in 0..remainder {
        max_val = max_val.max(values[base + i]);
    }

    max_val
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unpack_8bit() {
        let input: Vec<u8> = (0..128).collect();
        let mut output = vec![0u32; 128];
        unpack_8bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, i as u32);
        }
    }

    #[test]
    fn test_unpack_16bit() {
        let mut input = vec![0u8; 256];
        for i in 0..128 {
            let val = (i * 100) as u16;
            input[i * 2] = val as u8;
            input[i * 2 + 1] = (val >> 8) as u8;
        }

        let mut output = vec![0u32; 128];
        unpack_16bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, (i * 100) as u32);
        }
    }

    #[test]
    fn test_unpack_32bit() {
        let mut input = vec![0u8; 512];
        for i in 0..128 {
            let val = (i * 1000) as u32;
            let bytes = val.to_le_bytes();
            input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
        }

        let mut output = vec![0u32; 128];
        unpack_32bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, (i * 1000) as u32);
        }
    }

    #[test]
    fn test_delta_decode() {
        let deltas = vec![4u32, 4, 9, 19];
        let mut output = vec![0u32; 5];

        delta_decode(&mut output, &deltas, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }

    #[test]
    fn test_add_one() {
        let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
        add_one(&mut values, 8);

        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

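    // Counts that are not a multiple of the SIMD lane width fall through to the
    // scalar remainder loop, and elements past `count` should be left untouched.
    #[test]
    fn test_add_one_partial_count() {
        let mut values = vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        add_one(&mut values, 6);

        assert_eq!(values, vec![2, 3, 4, 5, 6, 7, 7, 8, 9, 10]);
    }
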
    #[test]
    fn test_bits_needed() {
        assert_eq!(bits_needed(0), 0);
        assert_eq!(bits_needed(1), 1);
        assert_eq!(bits_needed(2), 2);
        assert_eq!(bits_needed(3), 2);
        assert_eq!(bits_needed(4), 3);
        assert_eq!(bits_needed(255), 8);
        assert_eq!(bits_needed(256), 9);
        assert_eq!(bits_needed(u32::MAX), 32);
    }

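    // The generic (non byte-aligned) path of `unpack_delta_decode` reads 8-byte
    // windows, so the packed input here is padded out to cover the last window.
    // Deltas 1, 2, 3 packed at 4 bits each give bytes 0x21, 0x03; each decode
    // step adds delta + 1 to the running value.
    #[test]
    fn test_unpack_delta_decode_4bit() {
        let mut input = vec![0u8; 9];
        input[0] = 0x21;
        input[1] = 0x03;
        let mut output = vec![0u32; 4];

        unpack_delta_decode(&input, 4, &mut output, 100, 4);

        assert_eq!(output, vec![100, 102, 105, 109]);
    }

    // Bit width 0 stores no deltas at all: every gap is zero, so the decoded
    // values are consecutive starting from the first value.
    #[test]
    fn test_unpack_delta_decode_zero_width() {
        let input: Vec<u8> = vec![];
        let mut output = vec![0u32; 4];

        unpack_delta_decode(&input, 0, &mut output, 7, 4);

        assert_eq!(output, vec![7, 8, 9, 10]);
    }
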
    #[test]
    fn test_unpack_8bit_delta_decode() {
        let input: Vec<u8> = vec![4, 4, 9, 19];
        let mut output = vec![0u32; 5];

        unpack_8bit_delta_decode(&input, &mut output, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }

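    // The 32-bit arm of `unpack_delta_decode` reads count - 1 little-endian u32
    // gap values and applies the same gap + 1 accumulation as the other widths.
    #[test]
    fn test_unpack_delta_decode_32bit() {
        let mut input = vec![0u8; 8];
        input[0..4].copy_from_slice(&9u32.to_le_bytes());
        input[4..8].copy_from_slice(&19u32.to_le_bytes());
        let mut output = vec![0u32; 3];

        unpack_delta_decode(&input, 32, &mut output, 1, 3);

        assert_eq!(output, vec![1, 11, 31]);
    }
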
    #[test]
    fn test_unpack_16bit_delta_decode() {
        let mut input = vec![0u8; 8];
        for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
            input[i * 2] = delta as u8;
            input[i * 2 + 1] = (delta >> 8) as u8;
        }
        let mut output = vec![0u32; 5];

        unpack_16bit_delta_decode(&input, &mut output, 100, 5);

        assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
    }

    #[test]
    fn test_fused_vs_separate_8bit() {
        let input: Vec<u8> = (0..127).collect();
        let first_value = 1000u32;
        let count = 128;

        let mut unpacked = vec![0u32; 128];
        unpack_8bit(&input, &mut unpacked, 127);
        let mut separate_output = vec![0u32; 128];
        delta_decode(&mut separate_output, &unpacked, first_value, count);

        let mut fused_output = vec![0u32; 128];
        unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);

        assert_eq!(separate_output, fused_output);
    }

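    // Mirror of the 8-bit consistency check above: the fused 16-bit unpack +
    // delta decode should agree with unpacking first and delta decoding after.
    #[test]
    fn test_fused_vs_separate_16bit() {
        let deltas: Vec<u16> = (0..127).map(|i| (i * 3) as u16).collect();
        let mut input = vec![0u8; 254];
        for (i, &d) in deltas.iter().enumerate() {
            input[i * 2..i * 2 + 2].copy_from_slice(&d.to_le_bytes());
        }
        let first_value = 500u32;
        let count = 128;

        let mut unpacked = vec![0u32; 128];
        unpack_16bit(&input, &mut unpacked, 127);
        let mut separate_output = vec![0u32; 128];
        delta_decode(&mut separate_output, &unpacked, first_value, count);

        let mut fused_output = vec![0u32; 128];
        unpack_16bit_delta_decode(&input, &mut fused_output, first_value, count);

        assert_eq!(separate_output, fused_output);
    }
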
    #[test]
    fn test_round_bit_width() {
        assert_eq!(round_bit_width(0), 0);
        assert_eq!(round_bit_width(1), 8);
        assert_eq!(round_bit_width(5), 8);
        assert_eq!(round_bit_width(8), 8);
        assert_eq!(round_bit_width(9), 16);
        assert_eq!(round_bit_width(12), 16);
        assert_eq!(round_bit_width(16), 16);
        assert_eq!(round_bit_width(17), 32);
        assert_eq!(round_bit_width(24), 32);
        assert_eq!(round_bit_width(32), 32);
    }

    #[test]
    fn test_rounded_bitwidth_from_exact() {
        assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
        assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
        assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
        assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
        assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
        assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
        assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
    }

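    // `from_u8` expects an already-rounded width; anything else falls back to
    // the widest representation. Also exercises the byte-size helpers.
    #[test]
    fn test_rounded_bitwidth_from_u8_and_sizes() {
        assert_eq!(RoundedBitWidth::from_u8(0), RoundedBitWidth::Zero);
        assert_eq!(RoundedBitWidth::from_u8(8), RoundedBitWidth::Bits8);
        assert_eq!(RoundedBitWidth::from_u8(16), RoundedBitWidth::Bits16);
        assert_eq!(RoundedBitWidth::from_u8(32), RoundedBitWidth::Bits32);
        assert_eq!(RoundedBitWidth::from_u8(12), RoundedBitWidth::Bits32);

        assert_eq!(RoundedBitWidth::Zero.bytes_per_value(), 0);
        assert_eq!(RoundedBitWidth::Bits8.bytes_per_value(), 1);
        assert_eq!(RoundedBitWidth::Bits16.bytes_per_value(), 2);
        assert_eq!(RoundedBitWidth::Bits32.bytes_per_value(), 4);

        assert_eq!(RoundedBitWidth::Bits16.as_u8(), 16);
    }
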
    #[test]
    fn test_pack_unpack_rounded_8bit() {
        let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
        let mut packed = vec![0u8; 128];

        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
        assert_eq!(bytes_written, 128);

        let mut unpacked = vec![0u32; 128];
        unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_rounded_16bit() {
        let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
        let mut packed = vec![0u8; 256];

        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
        assert_eq!(bytes_written, 256);

        let mut unpacked = vec![0u32; 128];
        unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_rounded_32bit() {
        let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
        let mut packed = vec![0u8; 512];

        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
        assert_eq!(bytes_written, 512);

        let mut unpacked = vec![0u32; 128];
        unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_unpack_rounded_delta_decode() {
        let input: Vec<u8> = vec![4, 4, 9, 19];
        let mut output = vec![0u32; 5];

        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }

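    // Same decode through the 32-bit rounded path, assuming it stores the
    // count - 1 gaps as little-endian u32 values like the 8- and 16-bit paths.
    #[test]
    fn test_unpack_rounded_delta_decode_32bit() {
        let deltas = [4u32, 4, 9, 19];
        let mut input = vec![0u8; 16];
        for (i, &d) in deltas.iter().enumerate() {
            input[i * 4..i * 4 + 4].copy_from_slice(&d.to_le_bytes());
        }
        let mut output = vec![0u32; 5];

        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits32, &mut output, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }
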
    #[test]
    fn test_unpack_rounded_delta_decode_zero() {
        let input: Vec<u8> = vec![];
        let mut output = vec![0u32; 5];

        unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);

        assert_eq!(output, vec![100, 101, 102, 103, 104]);
    }

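    // Zero-width packing writes no bytes, and zero-width unpacking yields zeros.
    #[test]
    fn test_pack_unpack_rounded_zero() {
        let values = vec![0u32; 16];
        let mut packed = vec![0u8; 0];

        let bytes_written = pack_rounded(&values, RoundedBitWidth::Zero, &mut packed);
        assert_eq!(bytes_written, 0);

        let mut unpacked = vec![123u32; 16];
        unpack_rounded(&packed, RoundedBitWidth::Zero, &mut unpacked, 16);
        assert_eq!(unpacked, vec![0u32; 16]);
    }
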
    #[test]
    fn test_dequantize_uint8() {
        let input: Vec<u8> = vec![0, 128, 255, 64, 192];
        let mut output = vec![0.0f32; 5];
        let scale = 0.1;
        let min_val = 1.0;

        dequantize_uint8(&input, &mut output, scale, min_val, 5);

        assert!((output[0] - 1.0).abs() < 1e-6);
        assert!((output[1] - 13.8).abs() < 1e-6);
        assert!((output[2] - 26.5).abs() < 1e-6);
        assert!((output[3] - 7.4).abs() < 1e-6);
        assert!((output[4] - 20.2).abs() < 1e-6);
    }

    #[test]
    fn test_dequantize_uint8_large() {
        let input: Vec<u8> = (0..128).collect();
        let mut output = vec![0.0f32; 128];
        let scale = 2.0;
        let min_val = -10.0;

        dequantize_uint8(&input, &mut output, scale, min_val, 128);

        for (i, &out) in output.iter().enumerate().take(128) {
            let expected = i as f32 * scale + min_val;
            assert!(
                (out - expected).abs() < 1e-5,
                "Mismatch at {}: expected {}, got {}",
                i,
                expected,
                out
            );
        }
    }

    #[test]
    fn test_dot_product_f32() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];

        let result = dot_product_f32(&a, &b, 5);

        assert!((result - 70.0).abs() < 1e-5);
    }

    #[test]
    fn test_dot_product_f32_large() {
        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();

        let result = dot_product_f32(&a, &b, 128);

        let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
        assert!(
            (result - expected).abs() < 1e-3,
            "Expected {}, got {}",
            expected,
            result
        );
    }

    #[test]
    fn test_max_f32() {
        let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
        let result = max_f32(&values, 6);
        assert!((result - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_f32_large() {
        let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
        values[77] = 1000.0;

        let result = max_f32(&values, 128);
        assert!((result - 1000.0).abs() < 1e-5);
    }

    #[test]
    fn test_max_f32_negative() {
        let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
        let result = max_f32(&values, 5);
        assert!((result - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_max_f32_empty() {
        let values: Vec<f32> = vec![];
        let result = max_f32(&values, 0);
        assert_eq!(result, f32::NEG_INFINITY);
    }
}