// hermes_core/structures/simd.rs

//! Shared SIMD-accelerated functions for posting list compression
//!
//! This module provides platform-optimized implementations for common operations:
//! - **Unpacking**: Convert packed 8/16/32-bit values to u32 arrays
//! - **Delta decoding**: Prefix sum for converting deltas to absolute values
//! - **Add one**: Increment all values in an array (for TF decoding)
//!
//! Supports:
//! - **NEON** on aarch64 (Apple Silicon, ARM servers)
//! - **SSE2/SSE4.1** on x86_64 (Intel/AMD)
//! - **AVX2** on x86_64 (used when detected at runtime)
//! - **Scalar fallback** for other architectures
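//!
//! A minimal usage sketch of the public dispatch API (illustrative only; it mirrors
//! `test_unpack_8bit_delta_decode` at the bottom of this file):
//!
//! ```ignore
//! // Doc-ID deltas are stored as (gap - 1) in packed 8-bit form.
//! let packed: &[u8] = &[4, 4, 9, 19];
//! let mut doc_ids = vec![0u32; 5];
//! unpack_8bit_delta_decode(packed, &mut doc_ids, 10, 5);
//! assert_eq!(doc_ids, vec![10, 15, 20, 30, 50]);
//! ```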

// ============================================================================
// NEON intrinsics for aarch64 (Apple Silicon, ARM servers)
// ============================================================================

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod neon {
    use std::arch::aarch64::*;

    /// SIMD unpack for 8-bit values using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            // Load 16 bytes
            let bytes = vld1q_u8(in_ptr);

            // Widen u8 -> u16 -> u32
            let low8 = vget_low_u8(bytes);
            let high8 = vget_high_u8(bytes);

            let low16 = vmovl_u8(low8);
            let high16 = vmovl_u8(high8);

            let v0 = vmovl_u16(vget_low_u16(low16));
            let v1 = vmovl_u16(vget_high_u16(low16));
            let v2 = vmovl_u16(vget_low_u16(high16));
            let v3 = vmovl_u16(vget_high_u16(high16));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, v0);
            vst1q_u32(out_ptr.add(4), v1);
            vst1q_u32(out_ptr.add(8), v2);
            vst1q_u32(out_ptr.add(12), v3);
        }

        // Handle remainder
        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// SIMD unpack for 16-bit values using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            let vals = vld1q_u16(in_ptr);
            let low = vmovl_u16(vget_low_u16(vals));
            let high = vmovl_u16(vget_high_u16(vals));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, low);
            vst1q_u32(out_ptr.add(4), high);
        }

        // Handle remainder
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// SIMD unpack for 32-bit values using NEON (fast copy)
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 4;
        let remainder = count % 4;

        let in_ptr = input.as_ptr() as *const u32;
        let out_ptr = output.as_mut_ptr();

        for chunk in 0..chunks {
            let vals = vld1q_u32(in_ptr.add(chunk * 4));
            vst1q_u32(out_ptr.add(chunk * 4), vals);
        }

        // Handle remainder
        let base = chunks * 4;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// SIMD prefix sum for 4 u32 values using NEON
    /// Input:  [a, b, c, d]
    /// Output: [a, a+b, a+b+c, a+b+c+d]
    #[inline]
    #[target_feature(enable = "neon")]
    unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
        // Step 1: shift by 1 and add
        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
        let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
        let sum1 = vaddq_u32(v, shifted1);

        // Step 2: shift by 2 and add
        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
        let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
        vaddq_u32(sum1, shifted2)
    }

    /// SIMD delta decode: convert deltas to absolute doc IDs
    /// deltas[i] stores (gap - 1), output[i] = first + sum(gaps[0..i])
    /// Uses NEON SIMD prefix sum for high throughput
    #[target_feature(enable = "neon")]
    pub unsafe fn delta_decode(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_doc_id);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 deltas and add 1 (since we store gap-1)
            let d = vld1q_u32(deltas[base..].as_ptr());
            let gaps = vaddq_u32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry (broadcast last element of previous group)
            let result = vaddq_u32(prefix, carry);

            // Store result
            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            // Update carry: broadcast the last element for next iteration
            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// SIMD add 1 to all values (for TF decoding: stored as tf-1)
    #[target_feature(enable = "neon")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = vdupq_n_u32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base);
            let v = vld1q_u32(ptr);
            let result = vaddq_u32(v, ones);
            vst1q_u32(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Fused unpack 8-bit + delta decode using NEON
    /// Processes 4 values at a time, fusing unpack and prefix sum
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_value);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 bytes and widen to u32
            let b0 = input[base] as u32;
            let b1 = input[base + 1] as u32;
            let b2 = input[base + 2] as u32;
            let b3 = input[base + 3] as u32;
            let deltas = [b0, b1, b2, b3];
            let d = vld1q_u32(deltas.as_ptr());

            // Add 1 (since we store gap-1)
            let gaps = vaddq_u32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = vaddq_u32(prefix, carry);

            // Store result
            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            // Update carry
            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            scalar_carry = scalar_carry
                .wrapping_add(input[base + j] as u32)
                .wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Fused unpack 16-bit + delta decode using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_value);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            // Load 4 u16 values and widen to u32
            let vals = vld1_u16(in_ptr);
            let d = vmovl_u16(vals);

            // Add 1 (since we store gap-1)
            let gaps = vaddq_u32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = vaddq_u32(prefix, carry);

            // Store result
            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            // Update carry
            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            let idx = (base + j) * 2;
            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Check if NEON is available (always true on aarch64)
    #[inline]
    pub fn is_available() -> bool {
        true
    }
}

// ============================================================================
// SSE intrinsics for x86_64 (Intel/AMD)
// ============================================================================

#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod sse {
    use std::arch::x86_64::*;

    /// SIMD unpack for 8-bit values using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);

            // Zero extend u8 -> u32 using SSE4.1 pmovzx
            let v0 = _mm_cvtepu8_epi32(bytes);
            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
        }

        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// SIMD unpack for 16-bit values using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2);

            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
            let low = _mm_cvtepu16_epi32(vals);
            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, low);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// SIMD unpack for 32-bit values using SSE (fast copy)
    #[target_feature(enable = "sse2")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 4;
        let remainder = count % 4;

        let in_ptr = input.as_ptr() as *const __m128i;
        let out_ptr = output.as_mut_ptr() as *mut __m128i;

        for chunk in 0..chunks {
            let vals = _mm_loadu_si128(in_ptr.add(chunk));
            _mm_storeu_si128(out_ptr.add(chunk), vals);
        }

        // Handle remainder
        let base = chunks * 4;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// SIMD prefix sum for 4 u32 values using SSE
    /// Input:  [a, b, c, d]
    /// Output: [a, a+b, a+b+c, a+b+c+d]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
        // Step 1: shift by 1 element (4 bytes) and add
        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
        let shifted1 = _mm_slli_si128(v, 4);
        let sum1 = _mm_add_epi32(v, shifted1);

        // Step 2: shift by 2 elements (8 bytes) and add
        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
        let shifted2 = _mm_slli_si128(sum1, 8);
        _mm_add_epi32(sum1, shifted2)
    }

    /// SIMD delta decode using SSE with true SIMD prefix sum
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn delta_decode(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_doc_id as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 deltas and add 1 (since we store gap-1)
            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
            let gaps = _mm_add_epi32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry (broadcast last element of previous group)
            let result = _mm_add_epi32(prefix, carry);

            // Store result
            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            // Update carry: broadcast the last element for next iteration
            carry = _mm_shuffle_epi32(result, 0xFF); // broadcast lane 3
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// SIMD add 1 to all values using SSE
    #[target_feature(enable = "sse2")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = _mm_set1_epi32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
            let v = _mm_loadu_si128(ptr);
            let result = _mm_add_epi32(v, ones);
            _mm_storeu_si128(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Fused unpack 8-bit + delta decode using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_value as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 bytes and zero-extend to u32
            // (unaligned read: the input offset has no alignment guarantee)
            let bytes =
                _mm_cvtsi32_si128((input.as_ptr().add(base) as *const i32).read_unaligned());
            let d = _mm_cvtepu8_epi32(bytes);

            // Add 1 (since we store gap-1)
            let gaps = _mm_add_epi32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = _mm_add_epi32(prefix, carry);

            // Store result
            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            // Update carry: broadcast the last element
            carry = _mm_shuffle_epi32(result, 0xFF);
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            scalar_carry = scalar_carry
                .wrapping_add(input[base + j] as u32)
                .wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Fused unpack 16-bit + delta decode using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_value as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;
            let in_ptr = input.as_ptr().add(base * 2);

            // Load 8 bytes (4 u16 values) and zero-extend to u32
            let vals = _mm_loadl_epi64(in_ptr as *const __m128i);
            let d = _mm_cvtepu16_epi32(vals);

            // Add 1 (since we store gap-1)
            let gaps = _mm_add_epi32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = _mm_add_epi32(prefix, carry);

            // Store result
            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            // Update carry: broadcast the last element
            carry = _mm_shuffle_epi32(result, 0xFF);
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            let idx = (base + j) * 2;
            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Check if SSE4.1 is available at runtime
    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("sse4.1")
    }
}

// ============================================================================
// AVX2 intrinsics for x86_64 (Intel/AMD with 256-bit registers)
// ============================================================================

#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod avx2 {
    use std::arch::x86_64::*;

    /// AVX2 unpack for 8-bit values (processes 32 bytes at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 32;
        let remainder = count % 32;

        for chunk in 0..chunks {
            let base = chunk * 32;
            let in_ptr = input.as_ptr().add(base);

            // Load 32 bytes as two 128-bit halves
            let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
            let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);

            // Zero extend u8 -> u32, 8 values per 256-bit register
            let v0 = _mm256_cvtepu8_epi32(bytes_lo);
            let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
            let v2 = _mm256_cvtepu8_epi32(bytes_hi);
            let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
            _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
            _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
        }

        // Handle remainder with scalar code
        let base = chunks * 32;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// AVX2 unpack for 16-bit values (processes 16 values at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base * 2);

            // Load 32 bytes (16 u16 values)
            let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
            let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);

            // Zero extend u16 -> u32
            let v0 = _mm256_cvtepu16_epi32(vals_lo);
            let v1 = _mm256_cvtepu16_epi32(vals_hi);

            let out_ptr = output.as_mut_ptr().add(base);
            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
        }

        // Handle remainder
        let base = chunks * 16;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// AVX2 unpack for 32-bit values (fast copy, 8 values at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        let in_ptr = input.as_ptr() as *const __m256i;
        let out_ptr = output.as_mut_ptr() as *mut __m256i;

        for chunk in 0..chunks {
            let vals = _mm256_loadu_si256(in_ptr.add(chunk));
            _mm256_storeu_si256(out_ptr.add(chunk), vals);
        }

        // Handle remainder
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// AVX2 add 1 to all values (8 values at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = _mm256_set1_epi32(1);
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
            let v = _mm256_loadu_si256(ptr);
            let result = _mm256_add_epi32(v, ones);
            _mm256_storeu_si256(ptr, result);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Check if AVX2 is available at runtime
    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("avx2")
    }
}

// ============================================================================
// Scalar fallback implementations
// ============================================================================

#[allow(dead_code)]
mod scalar {
    /// Scalar unpack for 8-bit values
    #[inline]
    pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        for i in 0..count {
            output[i] = input[i] as u32;
        }
    }

    /// Scalar unpack for 16-bit values
    #[inline]
    pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 2;
            *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// Scalar unpack for 32-bit values
    #[inline]
    pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 4;
            *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// Scalar delta decode
    #[inline]
    pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        let mut carry = first_doc_id;

        for i in 0..count - 1 {
            carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
            output[i + 1] = carry;
        }
    }

    /// Scalar add 1 to all values
    #[inline]
    pub fn add_one(values: &mut [u32], count: usize) {
        for val in values.iter_mut().take(count) {
            *val += 1;
        }
    }
}

// ============================================================================
// Public dispatch functions that select SIMD or scalar at runtime
// ============================================================================

/// Unpack 8-bit packed values to u32 with SIMD acceleration
#[inline]
pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_8bit(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::unpack_8bit(input, output, count);
            }
            return;
        }
        if sse::is_available() {
            unsafe {
                sse::unpack_8bit(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_8bit(input, output, count);
}

/// Unpack 16-bit packed values to u32 with SIMD acceleration
#[inline]
pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_16bit(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::unpack_16bit(input, output, count);
            }
            return;
        }
        if sse::is_available() {
            unsafe {
                sse::unpack_16bit(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_16bit(input, output, count);
}

/// Unpack 32-bit packed values to u32 with SIMD acceleration
#[inline]
pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_32bit(input, output, count);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::unpack_32bit(input, output, count);
            }
        } else {
            // SSE2 is always available on x86_64
            unsafe {
                sse::unpack_32bit(input, output, count);
            }
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::unpack_32bit(input, output, count);
    }
}

/// Delta decode with SIMD acceleration
///
/// Converts delta-encoded values to absolute values.
/// Input: deltas[i] = value[i+1] - value[i] - 1 (gap minus one)
/// Output: absolute values starting from first_value
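///
/// A small illustrative example (mirrors `test_delta_decode` below); the gap
/// between consecutive values is `delta + 1`:
///
/// ```ignore
/// let deltas = [4u32, 4, 9, 19];
/// let mut out = [0u32; 5];
/// delta_decode(&mut out, &deltas, 10, 5);
/// assert_eq!(out, [10, 15, 20, 30, 50]);
/// ```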
#[inline]
pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::delta_decode(output, deltas, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::delta_decode(output, deltas, first_value, count);
            }
            return;
        }
    }

    scalar::delta_decode(output, deltas, first_value, count);
}

/// Add 1 to all values with SIMD acceleration
///
/// Used for TF decoding where values are stored as (tf - 1)
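///
/// For example (illustrative; matches `test_add_one` below):
///
/// ```ignore
/// let mut values = vec![0u32, 1, 2, 3];
/// add_one(&mut values, 4);
/// assert_eq!(values, vec![1, 2, 3, 4]);
/// ```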
#[inline]
pub fn add_one(values: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::add_one(values, count);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::add_one(values, count);
            }
        } else {
            // SSE2 is always available on x86_64
            unsafe {
                sse::add_one(values, count);
            }
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::add_one(values, count);
    }
}

/// Compute the number of bits needed to represent a value
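///
/// For example (illustrative; matches `test_bits_needed` below):
///
/// ```ignore
/// assert_eq!(bits_needed(0), 0);
/// assert_eq!(bits_needed(255), 8);
/// assert_eq!(bits_needed(256), 9);
/// ```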
#[inline]
pub fn bits_needed(val: u32) -> u8 {
    if val == 0 {
        0
    } else {
        32 - val.leading_zeros() as u8
    }
}

// ============================================================================
// Fused operations for better cache utilization
// ============================================================================

/// Fused unpack 8-bit + delta decode in a single pass
///
/// This avoids writing the intermediate unpacked values to memory,
/// improving cache utilization for large blocks.
#[inline]
pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_8bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::unpack_8bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    // Scalar fallback
    let mut carry = first_value;
    for i in 0..count - 1 {
        carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
        output[i + 1] = carry;
    }
}

/// Fused unpack 16-bit + delta decode in a single pass
#[inline]
pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_16bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::unpack_16bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    // Scalar fallback
    let mut carry = first_value;
    for i in 0..count - 1 {
        let idx = i * 2;
        let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        carry = carry.wrapping_add(delta).wrapping_add(1);
        output[i + 1] = carry;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unpack_8bit() {
        let input: Vec<u8> = (0..128).collect();
        let mut output = vec![0u32; 128];
        unpack_8bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, i as u32);
        }
    }

    #[test]
    fn test_unpack_16bit() {
        let mut input = vec![0u8; 256];
        for i in 0..128 {
            let val = (i * 100) as u16;
            input[i * 2] = val as u8;
            input[i * 2 + 1] = (val >> 8) as u8;
        }

        let mut output = vec![0u32; 128];
        unpack_16bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, (i * 100) as u32);
        }
    }

    #[test]
    fn test_unpack_32bit() {
        let mut input = vec![0u8; 512];
        for i in 0..128 {
            let val = (i * 1000) as u32;
            let bytes = val.to_le_bytes();
            input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
        }

        let mut output = vec![0u32; 128];
        unpack_32bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, (i * 1000) as u32);
        }
    }

    #[test]
    fn test_delta_decode() {
        // doc_ids: [10, 15, 20, 30, 50]
        // gaps: [5, 5, 10, 20]
        // deltas (gap-1): [4, 4, 9, 19]
        let deltas = vec![4u32, 4, 9, 19];
        let mut output = vec![0u32; 5];

        delta_decode(&mut output, &deltas, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }

    #[test]
    fn test_add_one() {
        let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
        add_one(&mut values, 8);

        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_bits_needed() {
        assert_eq!(bits_needed(0), 0);
        assert_eq!(bits_needed(1), 1);
        assert_eq!(bits_needed(2), 2);
        assert_eq!(bits_needed(3), 2);
        assert_eq!(bits_needed(4), 3);
        assert_eq!(bits_needed(255), 8);
        assert_eq!(bits_needed(256), 9);
        assert_eq!(bits_needed(u32::MAX), 32);
    }

    #[test]
    fn test_unpack_8bit_delta_decode() {
        // doc_ids: [10, 15, 20, 30, 50]
        // gaps: [5, 5, 10, 20]
        // deltas (gap-1): [4, 4, 9, 19] stored as u8
        let input: Vec<u8> = vec![4, 4, 9, 19];
        let mut output = vec![0u32; 5];

        unpack_8bit_delta_decode(&input, &mut output, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }

    #[test]
    fn test_unpack_16bit_delta_decode() {
        // doc_ids: [100, 600, 1100, 2100, 4100]
        // gaps: [500, 500, 1000, 2000]
        // deltas (gap-1): [499, 499, 999, 1999] stored as u16
        let mut input = vec![0u8; 8];
        for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
            input[i * 2] = delta as u8;
            input[i * 2 + 1] = (delta >> 8) as u8;
        }
        let mut output = vec![0u32; 5];

        unpack_16bit_delta_decode(&input, &mut output, 100, 5);

        assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
    }

    #[test]
    fn test_fused_vs_separate_8bit() {
        // Test that fused and separate operations produce the same result
        let input: Vec<u8> = (0..127).collect();
        let first_value = 1000u32;
        let count = 128;

        // Separate: unpack then delta_decode
        let mut unpacked = vec![0u32; 128];
        unpack_8bit(&input, &mut unpacked, 127);
        let mut separate_output = vec![0u32; 128];
        delta_decode(&mut separate_output, &unpacked, first_value, count);

        // Fused
        let mut fused_output = vec![0u32; 128];
        unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);

        assert_eq!(separate_output, fused_output);
    }
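
    // Supplementary check: exercises the remainder path of the SIMD kernels
    // with a count that is not a multiple of 4, compared against a scalar
    // reference computed inline.
    #[test]
    fn test_delta_decode_remainder() {
        // 7 deltas -> 8 outputs: one full group of 4 plus a remainder of 3.
        let deltas = vec![0u32, 2, 4, 6, 8, 10, 12];
        let mut output = vec![0u32; 8];
        delta_decode(&mut output, &deltas, 5, 8);

        // Scalar reference: each gap is (delta + 1).
        let mut expected = vec![0u32; 8];
        expected[0] = 5;
        for i in 0..7 {
            expected[i + 1] = expected[i] + deltas[i] + 1;
        }
        assert_eq!(output, expected);
    }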
}