hermes_core/structures/simd.rs

//! Shared SIMD-accelerated functions for posting list compression
//!
//! This module provides platform-optimized implementations for common operations:
//! - **Unpacking**: Convert packed 8/16/32-bit values to u32 arrays
//! - **Delta decoding**: Prefix sum for converting deltas to absolute values
//! - **Add one**: Increment all values in an array (for TF decoding)
//!
//! Supports:
//! - **NEON** on aarch64 (Apple Silicon, ARM servers)
//! - **SSE/SSE4.1** on x86_64 (Intel/AMD)
//! - **Scalar fallback** for other architectures
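//!
//! A minimal usage sketch (the `hermes_core::structures::simd` path is assumed from
//! the file location; adjust it to wherever this module is actually mounted):
//!
//! ```ignore
//! use hermes_core::structures::simd;
//!
//! // Deltas store (gap - 1); decode back to absolute doc IDs starting at 10.
//! let deltas = [4u32, 4, 9, 19];
//! let mut doc_ids = vec![0u32; 5];
//! simd::delta_decode(&mut doc_ids, &deltas, 10, 5);
//! assert_eq!(doc_ids, vec![10, 15, 20, 30, 50]);
//!
//! // Term frequencies are stored as (tf - 1); bump them back up in place.
//! let mut tfs = vec![0u32, 2, 4];
//! simd::add_one(&mut tfs, 3);
//! assert_eq!(tfs, vec![1, 3, 5]);
//! ```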

// ============================================================================
// NEON intrinsics for aarch64 (Apple Silicon, ARM servers)
// ============================================================================

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod neon {
    use std::arch::aarch64::*;

    /// SIMD unpack for 8-bit values using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            // Load 16 bytes
            let bytes = vld1q_u8(in_ptr);

            // Widen u8 -> u16 -> u32
            let low8 = vget_low_u8(bytes);
            let high8 = vget_high_u8(bytes);

            let low16 = vmovl_u8(low8);
            let high16 = vmovl_u8(high8);

            let v0 = vmovl_u16(vget_low_u16(low16));
            let v1 = vmovl_u16(vget_high_u16(low16));
            let v2 = vmovl_u16(vget_low_u16(high16));
            let v3 = vmovl_u16(vget_high_u16(high16));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, v0);
            vst1q_u32(out_ptr.add(4), v1);
            vst1q_u32(out_ptr.add(8), v2);
            vst1q_u32(out_ptr.add(12), v3);
        }

        // Handle remainder
        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// SIMD unpack for 16-bit values using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            let vals = vld1q_u16(in_ptr);
            let low = vmovl_u16(vget_low_u16(vals));
            let high = vmovl_u16(vget_high_u16(vals));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, low);
            vst1q_u32(out_ptr.add(4), high);
        }

        // Handle remainder
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// SIMD unpack for 32-bit values using NEON (fast copy)
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 4;
        let remainder = count % 4;

        let in_ptr = input.as_ptr() as *const u32;
        let out_ptr = output.as_mut_ptr();

        for chunk in 0..chunks {
            let vals = vld1q_u32(in_ptr.add(chunk * 4));
            vst1q_u32(out_ptr.add(chunk * 4), vals);
        }

        // Handle remainder
        let base = chunks * 4;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// SIMD prefix sum for 4 u32 values using NEON
    /// Input:  [a, b, c, d]
    /// Output: [a, a+b, a+b+c, a+b+c+d]
    #[inline]
    #[target_feature(enable = "neon")]
    unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
        // Step 1: shift by 1 and add
        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
        let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
        let sum1 = vaddq_u32(v, shifted1);

        // Step 2: shift by 2 and add
        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
        let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
        vaddq_u32(sum1, shifted2)
    }

    /// SIMD delta decode: convert deltas to absolute doc IDs
    /// deltas[i] stores (gap - 1), output[i] = first + sum(gaps[0..i])
    /// Uses NEON SIMD prefix sum for high throughput
    #[target_feature(enable = "neon")]
    pub unsafe fn delta_decode(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_doc_id);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 deltas and add 1 (since we store gap-1)
            let d = vld1q_u32(deltas[base..].as_ptr());
            let gaps = vaddq_u32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry (broadcast last element of previous group)
            let result = vaddq_u32(prefix, carry);

            // Store result
            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            // Update carry: broadcast the last element for next iteration
            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// SIMD add 1 to all values (for TF decoding: stored as tf-1)
    #[target_feature(enable = "neon")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = vdupq_n_u32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base);
            let v = vld1q_u32(ptr);
            let result = vaddq_u32(v, ones);
            vst1q_u32(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Fused unpack 8-bit + delta decode using NEON
    /// Processes 4 values at a time, fusing unpack and prefix sum
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_value);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 bytes and widen to u32
            let b0 = input[base] as u32;
            let b1 = input[base + 1] as u32;
            let b2 = input[base + 2] as u32;
            let b3 = input[base + 3] as u32;
            let deltas = [b0, b1, b2, b3];
            let d = vld1q_u32(deltas.as_ptr());

            // Add 1 (since we store gap-1)
            let gaps = vaddq_u32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = vaddq_u32(prefix, carry);

            // Store result
            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            // Update carry
            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            scalar_carry = scalar_carry
                .wrapping_add(input[base + j] as u32)
                .wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Fused unpack 16-bit + delta decode using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = vdupq_n_u32(1);
        let mut carry = vdupq_n_u32(first_value);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            // Load 4 u16 values and widen to u32
            let vals = vld1_u16(in_ptr);
            let d = vmovl_u16(vals);

            // Add 1 (since we store gap-1)
            let gaps = vaddq_u32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = vaddq_u32(prefix, carry);

            // Store result
            vst1q_u32(output[base + 1..].as_mut_ptr(), result);

            // Update carry
            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = vgetq_lane_u32(carry, 0);
        for j in 0..remainder {
            let idx = (base + j) * 2;
            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Check if NEON is available (always true on aarch64)
    #[inline]
    pub fn is_available() -> bool {
        true
    }
}

// ============================================================================
// SSE intrinsics for x86_64 (Intel/AMD)
// ============================================================================

#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod sse {
    use std::arch::x86_64::*;

    /// SIMD unpack for 8-bit values using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);

            // Zero extend u8 -> u32 using SSE4.1 pmovzx
            let v0 = _mm_cvtepu8_epi32(bytes);
            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
        }

        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// SIMD unpack for 16-bit values using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2);

            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
            let low = _mm_cvtepu16_epi32(vals);
            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, low);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// SIMD unpack for 32-bit values using SSE (fast copy)
    #[target_feature(enable = "sse2")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 4;
        let remainder = count % 4;

        let in_ptr = input.as_ptr() as *const __m128i;
        let out_ptr = output.as_mut_ptr() as *mut __m128i;

        for chunk in 0..chunks {
            let vals = _mm_loadu_si128(in_ptr.add(chunk));
            _mm_storeu_si128(out_ptr.add(chunk), vals);
        }

        // Handle remainder
        let base = chunks * 4;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// SIMD prefix sum for 4 u32 values using SSE
    /// Input:  [a, b, c, d]
    /// Output: [a, a+b, a+b+c, a+b+c+d]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
        // Step 1: shift by 1 element (4 bytes) and add
        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
        let shifted1 = _mm_slli_si128(v, 4);
        let sum1 = _mm_add_epi32(v, shifted1);

        // Step 2: shift by 2 elements (8 bytes) and add
        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
        let shifted2 = _mm_slli_si128(sum1, 8);
        _mm_add_epi32(sum1, shifted2)
    }

    /// SIMD delta decode using SSE with true SIMD prefix sum
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn delta_decode(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_doc_id as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 deltas and add 1 (since we store gap-1)
            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
            let gaps = _mm_add_epi32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry (broadcast last element of previous group)
            let result = _mm_add_epi32(prefix, carry);

            // Store result
            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            // Update carry: broadcast the last element for next iteration
            carry = _mm_shuffle_epi32(result, 0xFF); // broadcast lane 3
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// SIMD add 1 to all values using SSE
    #[target_feature(enable = "sse2")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = _mm_set1_epi32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
            let v = _mm_loadu_si128(ptr);
            let result = _mm_add_epi32(v, ones);
            _mm_storeu_si128(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Fused unpack 8-bit + delta decode using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_value as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 bytes with an unaligned read and zero-extend to u32
            let bytes =
                _mm_cvtsi32_si128((input.as_ptr().add(base) as *const i32).read_unaligned());
            let d = _mm_cvtepu8_epi32(bytes);

            // Add 1 (since we store gap-1)
            let gaps = _mm_add_epi32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = _mm_add_epi32(prefix, carry);

            // Store result
            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            // Update carry: broadcast the last element
            carry = _mm_shuffle_epi32(result, 0xFF);
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            scalar_carry = scalar_carry
                .wrapping_add(input[base + j] as u32)
                .wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Fused unpack 16-bit + delta decode using SSE
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit_delta_decode(
        input: &[u8],
        output: &mut [u32],
        first_value: u32,
        count: usize,
    ) {
        output[0] = first_value;
        if count <= 1 {
            return;
        }

        let ones = _mm_set1_epi32(1);
        let mut carry = _mm_set1_epi32(first_value as i32);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;
            let in_ptr = input.as_ptr().add(base * 2);

            // Load 8 bytes (4 u16 values) and zero-extend to u32
            let vals = _mm_loadl_epi64(in_ptr as *const __m128i);
            let d = _mm_cvtepu16_epi32(vals);

            // Add 1 (since we store gap-1)
            let gaps = _mm_add_epi32(d, ones);

            // Compute prefix sum within the 4 elements
            let prefix = prefix_sum_4(gaps);

            // Add carry
            let result = _mm_add_epi32(prefix, carry);

            // Store result
            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);

            // Update carry: broadcast the last element
            carry = _mm_shuffle_epi32(result, 0xFF);
        }

        // Handle remainder
        let base = full_groups * 4;
        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
        for j in 0..remainder {
            let idx = (base + j) * 2;
            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
            output[base + j + 1] = scalar_carry;
        }
    }

    /// Check if SSE4.1 is available at runtime
    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("sse4.1")
    }
}

// ============================================================================
// AVX2 intrinsics for x86_64 (Intel/AMD with 256-bit registers)
// ============================================================================

#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod avx2 {
    use std::arch::x86_64::*;

    /// AVX2 unpack for 8-bit values (processes 32 bytes at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 32;
        let remainder = count % 32;

        for chunk in 0..chunks {
            let base = chunk * 32;
            let in_ptr = input.as_ptr().add(base);

            // Load 32 bytes as two 128-bit halves
            let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
            let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);

            // Zero-extend u8 -> u32, eight bytes per 256-bit register
            let v0 = _mm256_cvtepu8_epi32(bytes_lo);
            let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
            let v2 = _mm256_cvtepu8_epi32(bytes_hi);
            let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
            _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
            _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
        }

        // Handle remainder with scalar code
        let base = chunks * 32;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// AVX2 unpack for 16-bit values (processes 16 values at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base * 2);

            // Load 32 bytes (16 u16 values)
            let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
            let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);

            // Zero extend u16 -> u32
            let v0 = _mm256_cvtepu16_epi32(vals_lo);
            let v1 = _mm256_cvtepu16_epi32(vals_hi);

            let out_ptr = output.as_mut_ptr().add(base);
            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
        }

        // Handle remainder
        let base = chunks * 16;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// AVX2 unpack for 32-bit values (fast copy, 8 values at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        let in_ptr = input.as_ptr() as *const __m256i;
        let out_ptr = output.as_mut_ptr() as *mut __m256i;

        for chunk in 0..chunks {
            let vals = _mm256_loadu_si256(in_ptr.add(chunk));
            _mm256_storeu_si256(out_ptr.add(chunk), vals);
        }

        // Handle remainder
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 4;
            output[base + i] =
                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// AVX2 add 1 to all values (8 values at a time)
    #[target_feature(enable = "avx2")]
    pub unsafe fn add_one(values: &mut [u32], count: usize) {
        let ones = _mm256_set1_epi32(1);
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
            let v = _mm256_loadu_si256(ptr);
            let result = _mm256_add_epi32(v, ones);
            _mm256_storeu_si256(ptr, result);
        }

        let base = chunks * 8;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Check if AVX2 is available at runtime
    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("avx2")
    }
}

// ============================================================================
// Scalar fallback implementations
// ============================================================================

#[allow(dead_code)]
mod scalar {
    /// Scalar unpack for 8-bit values
    #[inline]
    pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
        for i in 0..count {
            output[i] = input[i] as u32;
        }
    }

    /// Scalar unpack for 16-bit values
    #[inline]
    pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 2;
            *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// Scalar unpack for 32-bit values
    #[inline]
    pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 4;
            *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
        }
    }

    /// Scalar delta decode
    #[inline]
    pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        let mut carry = first_doc_id;

        for i in 0..count - 1 {
            carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
            output[i + 1] = carry;
        }
    }

    /// Scalar add 1 to all values
    #[inline]
    pub fn add_one(values: &mut [u32], count: usize) {
        for val in values.iter_mut().take(count) {
            *val += 1;
        }
    }
}

// ============================================================================
// Public dispatch functions that select SIMD or scalar at runtime
// ============================================================================

/// Unpack 8-bit packed values to u32 with SIMD acceleration
#[inline]
pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_8bit(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::unpack_8bit(input, output, count);
            }
            return;
        }
        if sse::is_available() {
            unsafe {
                sse::unpack_8bit(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_8bit(input, output, count);
}

/// Unpack 16-bit packed values to u32 with SIMD acceleration
#[inline]
pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_16bit(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::unpack_16bit(input, output, count);
            }
            return;
        }
        if sse::is_available() {
            unsafe {
                sse::unpack_16bit(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_16bit(input, output, count);
}

/// Unpack 32-bit packed values to u32 with SIMD acceleration
#[inline]
pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_32bit(input, output, count);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::unpack_32bit(input, output, count);
            }
        } else {
            // SSE2 is always available on x86_64
            unsafe {
                sse::unpack_32bit(input, output, count);
            }
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::unpack_32bit(input, output, count);
    }
}

/// Delta decode with SIMD acceleration
///
/// Converts delta-encoded values to absolute values.
/// Input: deltas[i] = value[i+1] - value[i] - 1 (gap minus one)
/// Output: absolute values starting from first_value
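///
/// For example, `first_value = 10` with `deltas = [4, 4, 9, 19]` (gaps of 5, 5, 10, 20)
/// decodes to `[10, 15, 20, 30, 50]`.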
#[inline]
pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::delta_decode(output, deltas, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::delta_decode(output, deltas, first_value, count);
            }
            return;
        }
    }

    scalar::delta_decode(output, deltas, first_value, count);
}

/// Add 1 to all values with SIMD acceleration
///
/// Used for TF decoding where values are stored as (tf - 1)
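///
/// For example, stored values `[0, 2, 4]` decode to term frequencies `[1, 3, 5]`.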
#[inline]
pub fn add_one(values: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::add_one(values, count);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
        if avx2::is_available() {
            unsafe {
                avx2::add_one(values, count);
            }
        } else {
            // SSE2 is always available on x86_64
            unsafe {
                sse::add_one(values, count);
            }
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::add_one(values, count);
    }
}

/// Compute the number of bits needed to represent a value
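///
/// For example, `bits_needed(0) == 0`, `bits_needed(255) == 8`, and `bits_needed(256) == 9`.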
#[inline]
pub fn bits_needed(val: u32) -> u8 {
    if val == 0 {
        0
    } else {
        32 - val.leading_zeros() as u8
    }
}

// ============================================================================
// Rounded bitpacking for truly vectorized encoding/decoding
// ============================================================================
//
// Instead of using arbitrary bit widths (1-32), we round up to SIMD-friendly
// widths: 0, 8, 16, or 32 bits. This trades ~10-20% more space for much faster
// decoding since we can use direct SIMD widening instructions (pmovzx) without
// any bit-shifting or masking.
//
// Bit width mapping:
//   0      -> 0  (all zeros)
//   1-8    -> 8  (u8)
//   9-16   -> 16 (u16)
//   17-32  -> 32 (u32)
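//
// For example, a block of 128 deltas whose largest value needs 7 bits is stored at
// 8 bits per value (128 bytes instead of 112, about 14% more), but each value then
// decodes with a single widening load and no shift/mask work.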

/// Rounded bit width type for SIMD-friendly encoding
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RoundedBitWidth {
    Zero = 0,
    Bits8 = 8,
    Bits16 = 16,
    Bits32 = 32,
}

impl RoundedBitWidth {
    /// Round an exact bit width to the nearest SIMD-friendly width
    #[inline]
    pub fn from_exact(bits: u8) -> Self {
        match bits {
            0 => RoundedBitWidth::Zero,
            1..=8 => RoundedBitWidth::Bits8,
            9..=16 => RoundedBitWidth::Bits16,
            _ => RoundedBitWidth::Bits32,
        }
    }

    /// Convert from stored u8 value (must be 0, 8, 16, or 32)
    #[inline]
    pub fn from_u8(bits: u8) -> Self {
        match bits {
            0 => RoundedBitWidth::Zero,
            8 => RoundedBitWidth::Bits8,
            16 => RoundedBitWidth::Bits16,
            32 => RoundedBitWidth::Bits32,
            _ => RoundedBitWidth::Bits32, // Fallback for invalid values
        }
    }

    /// Get the byte size per value
    #[inline]
    pub fn bytes_per_value(self) -> usize {
        match self {
            RoundedBitWidth::Zero => 0,
            RoundedBitWidth::Bits8 => 1,
            RoundedBitWidth::Bits16 => 2,
            RoundedBitWidth::Bits32 => 4,
        }
    }

    /// Get the raw bit width value
    #[inline]
    pub fn as_u8(self) -> u8 {
        self as u8
    }
}

/// Round a bit width to the nearest SIMD-friendly width (0, 8, 16, or 32)
#[inline]
pub fn round_bit_width(bits: u8) -> u8 {
    RoundedBitWidth::from_exact(bits).as_u8()
}

/// Pack values using rounded bit width (SIMD-friendly)
///
/// This is much simpler than arbitrary bitpacking since values are byte-aligned.
/// Returns the number of bytes written.
#[inline]
pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
    let count = values.len();
    match bit_width {
        RoundedBitWidth::Zero => 0,
        RoundedBitWidth::Bits8 => {
            for (i, &v) in values.iter().enumerate() {
                output[i] = v as u8;
            }
            count
        }
        RoundedBitWidth::Bits16 => {
            for (i, &v) in values.iter().enumerate() {
                let bytes = (v as u16).to_le_bytes();
                output[i * 2] = bytes[0];
                output[i * 2 + 1] = bytes[1];
            }
            count * 2
        }
        RoundedBitWidth::Bits32 => {
            for (i, &v) in values.iter().enumerate() {
                let bytes = v.to_le_bytes();
                output[i * 4] = bytes[0];
                output[i * 4 + 1] = bytes[1];
                output[i * 4 + 2] = bytes[2];
                output[i * 4 + 3] = bytes[3];
            }
            count * 4
        }
    }
}

/// Unpack values using rounded bit width with SIMD acceleration
///
/// This is the fast path - no bit manipulation needed, just widening.
#[inline]
pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
    match bit_width {
        RoundedBitWidth::Zero => {
            for out in output.iter_mut().take(count) {
                *out = 0;
            }
        }
        RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
        RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
        RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
    }
}

/// Fused unpack + delta decode using rounded bit width
///
/// Combines unpacking and prefix sum in a single pass for better cache utilization.
#[inline]
pub fn unpack_rounded_delta_decode(
    input: &[u8],
    bit_width: RoundedBitWidth,
    output: &mut [u32],
    first_value: u32,
    count: usize,
) {
    match bit_width {
        RoundedBitWidth::Zero => {
            // All deltas are 0, meaning gaps of 1
            let mut val = first_value;
            for out in output.iter_mut().take(count) {
                *out = val;
                val = val.wrapping_add(1);
            }
        }
        RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
        RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
        RoundedBitWidth::Bits32 => {
            // 32-bit deltas are already full words: unpack the (count - 1) stored
            // deltas into output[1..], then prefix-sum in place. This matches the
            // (count - 1)-delta layout used by the 8-bit and 16-bit fused paths.
            if count > 0 {
                output[0] = first_value;
                if count > 1 {
                    unpack_32bit(input, &mut output[1..count], count - 1);
                    let mut carry = first_value;
                    for item in output.iter_mut().take(count).skip(1) {
                        // item currently holds a delta (gap - 1)
                        carry = carry.wrapping_add(*item).wrapping_add(1);
                        *item = carry;
                    }
                }
            }
        }
    }
}

// ============================================================================
// Fused operations for better cache utilization
// ============================================================================

/// Fused unpack 8-bit + delta decode in a single pass
///
/// This avoids writing the intermediate unpacked values to memory,
/// improving cache utilization for large blocks.
#[inline]
pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_8bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::unpack_8bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    // Scalar fallback
    let mut carry = first_value;
    for i in 0..count - 1 {
        carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
        output[i + 1] = carry;
    }
}

/// Fused unpack 16-bit + delta decode in a single pass
#[inline]
pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                neon::unpack_16bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                sse::unpack_16bit_delta_decode(input, output, first_value, count);
            }
            return;
        }
    }

    // Scalar fallback
    let mut carry = first_value;
    for i in 0..count - 1 {
        let idx = i * 2;
        let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        carry = carry.wrapping_add(delta).wrapping_add(1);
        output[i + 1] = carry;
    }
}

/// Fused unpack + delta decode for arbitrary bit widths
///
/// Combines unpacking and prefix sum in a single pass, avoiding intermediate buffer.
/// Uses SIMD-accelerated paths for 8/16-bit widths, scalar for others.
#[inline]
pub fn unpack_delta_decode(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32],
    first_value: u32,
    count: usize,
) {
    if count == 0 {
        return;
    }

    output[0] = first_value;
    if count == 1 {
        return;
    }

    // Fast paths for SIMD-friendly bit widths
    match bit_width {
        0 => {
            // All zeros = consecutive doc IDs (gap of 1)
            let mut val = first_value;
            for item in output.iter_mut().take(count).skip(1) {
                val = val.wrapping_add(1);
                *item = val;
            }
        }
        8 => unpack_8bit_delta_decode(input, output, first_value, count),
        16 => unpack_16bit_delta_decode(input, output, first_value, count),
        32 => {
            // 32-bit: unpack inline and delta decode
            let mut carry = first_value;
            for i in 0..count - 1 {
                let idx = i * 4;
                let delta = u32::from_le_bytes([
                    input[idx],
                    input[idx + 1],
                    input[idx + 2],
                    input[idx + 3],
                ]);
                carry = carry.wrapping_add(delta).wrapping_add(1);
                output[i + 1] = carry;
            }
        }
        _ => {
            // Generic bit width: fused unpack + delta decode
            let mask = (1u64 << bit_width) - 1;
            let bit_width_usize = bit_width as usize;
            let mut bit_pos = 0usize;
            let input_ptr = input.as_ptr();
            let mut carry = first_value;

            for i in 0..count - 1 {
                let byte_idx = bit_pos >> 3;
                let bit_offset = bit_pos & 7;

                // SAFETY: Caller guarantees input has enough data
                let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
                let delta = ((word >> bit_offset) & mask) as u32;

                carry = carry.wrapping_add(delta).wrapping_add(1);
                output[i + 1] = carry;
                bit_pos += bit_width_usize;
            }
        }
    }
}

// ============================================================================
// Sparse Vector SIMD Functions
// ============================================================================

/// Dequantize UInt8 weights to f32 with SIMD acceleration
///
/// Computes: output[i] = input[i] as f32 * scale + min_val
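///
/// For example, with `scale = 0.5` and `min_val = 1.0`, the quantized byte `4`
/// dequantizes to `3.0`.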
#[inline]
pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            unsafe {
                dequantize_uint8_neon(input, output, scale, min_val, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            unsafe {
                dequantize_uint8_sse(input, output, scale, min_val, count);
            }
            return;
        }
    }

    // Scalar fallback
    for i in 0..count {
        output[i] = input[i] as f32 * scale + min_val;
    }
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dequantize_uint8_neon(
    input: &[u8],
    output: &mut [f32],
    scale: f32,
    min_val: f32,
    count: usize,
) {
    use std::arch::aarch64::*;

    let scale_v = vdupq_n_f32(scale);
    let min_v = vdupq_n_f32(min_val);

    let chunks = count / 16;
    let remainder = count % 16;

    for chunk in 0..chunks {
        let base = chunk * 16;
        let in_ptr = input.as_ptr().add(base);

        // Load 16 bytes
        let bytes = vld1q_u8(in_ptr);

        // Widen u8 -> u16 -> u32 -> f32
        let low8 = vget_low_u8(bytes);
        let high8 = vget_high_u8(bytes);

        let low16 = vmovl_u8(low8);
        let high16 = vmovl_u8(high8);

        // Process 4 values at a time
        let u32_0 = vmovl_u16(vget_low_u16(low16));
        let u32_1 = vmovl_u16(vget_high_u16(low16));
        let u32_2 = vmovl_u16(vget_low_u16(high16));
        let u32_3 = vmovl_u16(vget_high_u16(high16));

        // Convert to f32 and apply scale + min_val
        let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
        let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
        let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
        let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);

        let out_ptr = output.as_mut_ptr().add(base);
        vst1q_f32(out_ptr, f32_0);
        vst1q_f32(out_ptr.add(4), f32_1);
        vst1q_f32(out_ptr.add(8), f32_2);
        vst1q_f32(out_ptr.add(12), f32_3);
    }

    // Handle remainder
    let base = chunks * 16;
    for i in 0..remainder {
        output[base + i] = input[base + i] as f32 * scale + min_val;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dequantize_uint8_sse(
    input: &[u8],
    output: &mut [f32],
    scale: f32,
    min_val: f32,
    count: usize,
) {
    use std::arch::x86_64::*;

    let scale_v = _mm_set1_ps(scale);
    let min_v = _mm_set1_ps(min_val);

    let chunks = count / 4;
    let remainder = count % 4;

    for chunk in 0..chunks {
        let base = chunk * 4;

        // Load 4 bytes and zero-extend to 32-bit
        let b0 = input[base] as i32;
        let b1 = input[base + 1] as i32;
        let b2 = input[base + 2] as i32;
        let b3 = input[base + 3] as i32;

        let ints = _mm_set_epi32(b3, b2, b1, b0);
        let floats = _mm_cvtepi32_ps(ints);

        // Apply scale and min_val: result = floats * scale + min_val
        let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);

        _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
    }

    // Handle remainder
    let base = chunks * 4;
    for i in 0..remainder {
        output[base + i] = input[base + i] as f32 * scale + min_val;
    }
}

/// Compute dot product of two f32 arrays with SIMD acceleration
#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            return unsafe { dot_product_f32_neon(a, b, count) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            return unsafe { dot_product_f32_sse(a, b, count) };
        }
    }

    // Scalar fallback
    let mut sum = 0.0f32;
    for i in 0..count {
        sum += a[i] * b[i];
    }
    sum
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
    use std::arch::aarch64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut acc = vdupq_n_f32(0.0);

    for chunk in 0..chunks {
        let base = chunk * 4;
        let va = vld1q_f32(a.as_ptr().add(base));
        let vb = vld1q_f32(b.as_ptr().add(base));
        acc = vfmaq_f32(acc, va, vb);
    }

    // Horizontal sum
    let mut sum = vaddvq_f32(acc);

    // Handle remainder
    let base = chunks * 4;
    for i in 0..remainder {
        sum += a[base + i] * b[base + i];
    }

    sum
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
    use std::arch::x86_64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut acc = _mm_setzero_ps();

    for chunk in 0..chunks {
        let base = chunk * 4;
        let va = _mm_loadu_ps(a.as_ptr().add(base));
        let vb = _mm_loadu_ps(b.as_ptr().add(base));
        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
    }

    // Horizontal sum: [a, b, c, d] -> a + b + c + d
    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); // [b, a, d, c]
    let sums = _mm_add_ps(acc, shuf); // [a+b, a+b, c+d, c+d]
    let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, c+d, ?, ?]
    let final_sum = _mm_add_ss(sums, shuf2); // [a+b+c+d, ?, ?, ?]

    let mut sum = _mm_cvtss_f32(final_sum);

    // Handle remainder
    let base = chunks * 4;
    for i in 0..remainder {
        sum += a[base + i] * b[base + i];
    }

    sum
}

/// Find maximum value in f32 array with SIMD acceleration
#[inline]
pub fn max_f32(values: &[f32], count: usize) -> f32 {
    if count == 0 {
        return f32::NEG_INFINITY;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            return unsafe { max_f32_neon(values, count) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            return unsafe { max_f32_sse(values, count) };
        }
    }

    // Scalar fallback
    values[..count]
        .iter()
        .cloned()
        .fold(f32::NEG_INFINITY, f32::max)
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
    use std::arch::aarch64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);

    for chunk in 0..chunks {
        let base = chunk * 4;
        let v = vld1q_f32(values.as_ptr().add(base));
        max_v = vmaxq_f32(max_v, v);
    }

    // Horizontal max
    let mut max_val = vmaxvq_f32(max_v);

    // Handle remainder
    let base = chunks * 4;
    for i in 0..remainder {
        max_val = max_val.max(values[base + i]);
    }

    max_val
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
    use std::arch::x86_64::*;

    let chunks = count / 4;
    let remainder = count % 4;

    let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);

    for chunk in 0..chunks {
        let base = chunk * 4;
        let v = _mm_loadu_ps(values.as_ptr().add(base));
        max_v = _mm_max_ps(max_v, v);
    }

    // Horizontal max: [a, b, c, d] -> max(a, b, c, d)
    let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01); // [b, a, d, c]
    let max1 = _mm_max_ps(max_v, shuf); // [max(a,b), max(a,b), max(c,d), max(c,d)]
    let shuf2 = _mm_movehl_ps(max1, max1); // [max(c,d), max(c,d), ?, ?]
    let final_max = _mm_max_ss(max1, shuf2); // [max(a,b,c,d), ?, ?, ?]

    let mut max_val = _mm_cvtss_f32(final_max);

    // Handle remainder
    let base = chunks * 4;
    for i in 0..remainder {
        max_val = max_val.max(values[base + i]);
    }

    max_val
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unpack_8bit() {
        let input: Vec<u8> = (0..128).collect();
        let mut output = vec![0u32; 128];
        unpack_8bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, i as u32);
        }
    }

    #[test]
    fn test_unpack_16bit() {
        let mut input = vec![0u8; 256];
        for i in 0..128 {
            let val = (i * 100) as u16;
            input[i * 2] = val as u8;
            input[i * 2 + 1] = (val >> 8) as u8;
        }

        let mut output = vec![0u32; 128];
        unpack_16bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, (i * 100) as u32);
        }
    }

    #[test]
    fn test_unpack_32bit() {
        let mut input = vec![0u8; 512];
        for i in 0..128 {
            let val = (i * 1000) as u32;
            let bytes = val.to_le_bytes();
            input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
        }

        let mut output = vec![0u32; 128];
        unpack_32bit(&input, &mut output, 128);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, (i * 1000) as u32);
        }
    }

    #[test]
    fn test_delta_decode() {
        // doc_ids: [10, 15, 20, 30, 50]
        // gaps: [5, 5, 10, 20]
        // deltas (gap-1): [4, 4, 9, 19]
        let deltas = vec![4u32, 4, 9, 19];
        let mut output = vec![0u32; 5];

        delta_decode(&mut output, &deltas, 10, 5);

        assert_eq!(output, vec![10, 15, 20, 30, 50]);
    }

    #[test]
    fn test_add_one() {
        let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
        add_one(&mut values, 8);

        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }
1688
1689    #[test]
1690    fn test_bits_needed() {
1691        assert_eq!(bits_needed(0), 0);
1692        assert_eq!(bits_needed(1), 1);
1693        assert_eq!(bits_needed(2), 2);
1694        assert_eq!(bits_needed(3), 2);
1695        assert_eq!(bits_needed(4), 3);
1696        assert_eq!(bits_needed(255), 8);
1697        assert_eq!(bits_needed(256), 9);
1698        assert_eq!(bits_needed(u32::MAX), 32);
1699    }
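
    // Sketch: under the minimum-bit-count convention checked above, bits_needed(v)
    // should coincide with the position of the highest set bit, i.e.
    // 32 - v.leading_zeros(). The cast only aligns integer widths for the comparison.
    #[test]
    fn test_bits_needed_matches_leading_zeros() {
        for v in [0u32, 1, 7, 8, 255, 256, 4096, u32::MAX] {
            assert_eq!(bits_needed(v) as u32, 32 - v.leading_zeros());
        }
    }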
1700
1701    #[test]
1702    fn test_unpack_8bit_delta_decode() {
1703        // doc_ids: [10, 15, 20, 30, 50]
1704        // gaps: [5, 5, 10, 20]
1705        // deltas (gap-1): [4, 4, 9, 19] stored as u8
1706        let input: Vec<u8> = vec![4, 4, 9, 19];
1707        let mut output = vec![0u32; 5];
1708
1709        unpack_8bit_delta_decode(&input, &mut output, 10, 5);
1710
1711        assert_eq!(output, vec![10, 15, 20, 30, 50]);
1712    }
1713
1714    #[test]
1715    fn test_unpack_16bit_delta_decode() {
1716        // doc_ids: [100, 600, 1100, 2100, 4100]
1717        // gaps: [500, 500, 1000, 2000]
1718        // deltas (gap-1): [499, 499, 999, 1999] stored as u16
1719        let mut input = vec![0u8; 8];
1720        for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
1721            input[i * 2] = delta as u8;
1722            input[i * 2 + 1] = (delta >> 8) as u8;
1723        }
1724        let mut output = vec![0u32; 5];
1725
1726        unpack_16bit_delta_decode(&input, &mut output, 100, 5);
1727
1728        assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
1729    }
1730
1731    #[test]
1732    fn test_fused_vs_separate_8bit() {
1733        // Fused and separate paths must agree: 127 one-byte deltas plus the first value decode to 128 doc IDs
1734        let input: Vec<u8> = (0..127).collect();
1735        let first_value = 1000u32;
1736        let count = 128;
1737
1738        // Separate: unpack then delta_decode
1739        let mut unpacked = vec![0u32; 128];
1740        unpack_8bit(&input, &mut unpacked, 127);
1741        let mut separate_output = vec![0u32; 128];
1742        delta_decode(&mut separate_output, &unpacked, first_value, count);
1743
1744        // Fused
1745        let mut fused_output = vec![0u32; 128];
1746        unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
1747
1748        assert_eq!(separate_output, fused_output);
1749    }
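
    // Sketch: the same fused-vs-separate equivalence for the 16-bit path, mirroring
    // test_fused_vs_separate_8bit above (127 little-endian u16 deltas + first value
    // = 128 doc IDs). Only functions already exercised in this module are used.
    #[test]
    fn test_fused_vs_separate_16bit_sketch() {
        let mut input = vec![0u8; 254];
        for i in 0..127usize {
            let delta = (i * 3) as u16;
            input[i * 2..i * 2 + 2].copy_from_slice(&delta.to_le_bytes());
        }
        let first_value = 1000u32;
        let count = 128;

        // Separate: unpack then delta_decode
        let mut unpacked = vec![0u32; 128];
        unpack_16bit(&input, &mut unpacked, 127);
        let mut separate_output = vec![0u32; 128];
        delta_decode(&mut separate_output, &unpacked, first_value, count);

        // Fused
        let mut fused_output = vec![0u32; 128];
        unpack_16bit_delta_decode(&input, &mut fused_output, first_value, count);

        assert_eq!(separate_output, fused_output);
    }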
1750
1751    #[test]
1752    fn test_round_bit_width() {
1753        assert_eq!(round_bit_width(0), 0);
1754        assert_eq!(round_bit_width(1), 8);
1755        assert_eq!(round_bit_width(5), 8);
1756        assert_eq!(round_bit_width(8), 8);
1757        assert_eq!(round_bit_width(9), 16);
1758        assert_eq!(round_bit_width(12), 16);
1759        assert_eq!(round_bit_width(16), 16);
1760        assert_eq!(round_bit_width(17), 32);
1761        assert_eq!(round_bit_width(24), 32);
1762        assert_eq!(round_bit_width(32), 32);
1763    }
1764
1765    #[test]
1766    fn test_rounded_bitwidth_from_exact() {
1767        assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
1768        assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
1769        assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
1770        assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
1771        assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
1772        assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
1773        assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
1774    }
1775
1776    #[test]
1777    fn test_pack_unpack_rounded_8bit() {
1778        let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
1779        let mut packed = vec![0u8; 128];
1780
1781        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
1782        assert_eq!(bytes_written, 128);
1783
1784        let mut unpacked = vec![0u32; 128];
1785        unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);
1786
1787        assert_eq!(values, unpacked);
1788    }
1789
1790    #[test]
1791    fn test_pack_unpack_rounded_16bit() {
1792        let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
1793        let mut packed = vec![0u8; 256];
1794
1795        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
1796        assert_eq!(bytes_written, 256);
1797
1798        let mut unpacked = vec![0u32; 128];
1799        unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);
1800
1801        assert_eq!(values, unpacked);
1802    }
1803
1804    #[test]
1805    fn test_pack_unpack_rounded_32bit() {
1806        let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
1807        let mut packed = vec![0u8; 512];
1808
1809        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
1810        assert_eq!(bytes_written, 512);
1811
1812        let mut unpacked = vec![0u32; 128];
1813        unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);
1814
1815        assert_eq!(values, unpacked);
1816    }
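
    // Sketch: the doc-ID block pipeline as the pieces above compose it, assuming
    // pack_rounded and unpack_rounded_delta_decode fit together the way the
    // individual tests suggest: gap-1 deltas packed at Bits8, then decoded in one step.
    // The scalar encoding loop is illustrative.
    #[test]
    fn test_pack_rounded_then_delta_decode_sketch() {
        let doc_ids = vec![10u32, 15, 20, 30, 50, 51, 80];

        // Encode: delta[i] = doc_ids[i + 1] - doc_ids[i] - 1; all deltas fit in 8 bits.
        let deltas: Vec<u32> = doc_ids.windows(2).map(|w| w[1] - w[0] - 1).collect();

        let mut packed = vec![0u8; 6];
        let bytes_written = pack_rounded(&deltas, RoundedBitWidth::Bits8, &mut packed);
        assert_eq!(bytes_written, 6); // 6 deltas at one byte each

        let mut decoded = vec![0u32; 7];
        unpack_rounded_delta_decode(&packed, RoundedBitWidth::Bits8, &mut decoded, 10, 7);

        assert_eq!(decoded, doc_ids);
    }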
1817
1818    #[test]
1819    fn test_unpack_rounded_delta_decode() {
1820        // Test 8-bit rounded delta decode
1821        // doc_ids: [10, 15, 20, 30, 50]
1822        // gaps: [5, 5, 10, 20]
1823        // deltas (gap-1): [4, 4, 9, 19] stored as u8
1824        let input: Vec<u8> = vec![4, 4, 9, 19];
1825        let mut output = vec![0u32; 5];
1826
1827        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);
1828
1829        assert_eq!(output, vec![10, 15, 20, 30, 50]);
1830    }
1831
1832    #[test]
1833    fn test_unpack_rounded_delta_decode_zero() {
1834        // Zero width stores no delta bytes, so every implicit delta is 0, i.e. gaps of 1 (consecutive doc IDs)
1835        let input: Vec<u8> = vec![];
1836        let mut output = vec![0u32; 5];
1837
1838        unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);
1839
1840        assert_eq!(output, vec![100, 101, 102, 103, 104]);
1841    }
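
    // Sketch: the Bits16 branch of unpack_rounded_delta_decode, reusing the deltas and
    // expected doc IDs from test_unpack_16bit_delta_decode above; this assumes the
    // rounded decoder reads the same little-endian layout as the dedicated 16-bit one.
    #[test]
    fn test_unpack_rounded_delta_decode_16bit_sketch() {
        let mut input = vec![0u8; 8];
        for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
            input[i * 2..i * 2 + 2].copy_from_slice(&delta.to_le_bytes());
        }
        let mut output = vec![0u32; 5];

        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits16, &mut output, 100, 5);

        assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
    }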
1842
1843    // ========================================================================
1844    // Sparse Vector SIMD Tests
1845    // ========================================================================
1846
1847    #[test]
1848    fn test_dequantize_uint8() {
1849        let input: Vec<u8> = vec![0, 128, 255, 64, 192];
1850        let mut output = vec![0.0f32; 5];
1851        let scale = 0.1;
1852        let min_val = 1.0;
1853
1854        dequantize_uint8(&input, &mut output, scale, min_val, 5);
1855
1856        // Expected: input[i] * scale + min_val
1857        assert!((output[0] - 1.0).abs() < 1e-6); // 0 * 0.1 + 1.0 = 1.0
1858        assert!((output[1] - 13.8).abs() < 1e-6); // 128 * 0.1 + 1.0 = 13.8
1859        assert!((output[2] - 26.5).abs() < 1e-6); // 255 * 0.1 + 1.0 = 26.5
1860        assert!((output[3] - 7.4).abs() < 1e-6); // 64 * 0.1 + 1.0 = 7.4
1861        assert!((output[4] - 20.2).abs() < 1e-6); // 192 * 0.1 + 1.0 = 20.2
1862    }
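
    // Sketch: compare the SIMD path against a plain scalar evaluation of
    // out[i] = in[i] as f32 * scale + min_val, with a length that is not a multiple of
    // the vector width so the remainder loop runs too. Scale and offset are exactly
    // representable, so a loose bound is comfortably sufficient.
    #[test]
    fn test_dequantize_uint8_scalar_reference_sketch() {
        let input: Vec<u8> = (0..133).map(|i| (i * 7 % 256) as u8).collect();
        let mut output = vec![0.0f32; 133];

        dequantize_uint8(&input, &mut output, 0.25, -3.5, 133);

        for (i, (&b, &out)) in input.iter().zip(output.iter()).enumerate() {
            let expected = b as f32 * 0.25 - 3.5;
            assert!(
                (out - expected).abs() < 1e-4,
                "Mismatch at {}: expected {}, got {}",
                i,
                expected,
                out
            );
        }
    }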
1863
1864    #[test]
1865    fn test_dequantize_uint8_large() {
1866        // Test with 128 values (full SIMD block)
1867        let input: Vec<u8> = (0..128).collect();
1868        let mut output = vec![0.0f32; 128];
1869        let scale = 2.0;
1870        let min_val = -10.0;
1871
1872        dequantize_uint8(&input, &mut output, scale, min_val, 128);
1873
1874        for (i, &out) in output.iter().enumerate().take(128) {
1875            let expected = i as f32 * scale + min_val;
1876            assert!(
1877                (out - expected).abs() < 1e-5,
1878                "Mismatch at {}: expected {}, got {}",
1879                i,
1880                expected,
1881                out
1882            );
1883        }
1884    }
1885
1886    #[test]
1887    fn test_dot_product_f32() {
1888        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
1889        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
1890
1891        let result = dot_product_f32(&a, &b, 5);
1892
1893        // Expected: 1*2 + 2*3 + 3*4 + 4*5 + 5*6 = 2 + 6 + 12 + 20 + 30 = 70
1894        assert!((result - 70.0).abs() < 1e-5);
1895    }
1896
1897    #[test]
1898    fn test_dot_product_f32_large() {
1899        // Test with 128 values
1900        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
1901        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
1902
1903        let result = dot_product_f32(&a, &b, 128);
1904
1905        // Compute expected
1906        let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
1907        assert!(
1908            (result - expected).abs() < 1e-3,
1909            "Expected {}, got {}",
1910            expected,
1911            result
1912        );
1913    }
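
    // Sketch: an odd length so the scalar remainder of the dot product runs as well.
    // Every product here is a small integer, so the partial sums are exact in f32 and
    // a comparison against a plain scalar fold is safe at a tight tolerance.
    #[test]
    fn test_dot_product_f32_odd_length_sketch() {
        let a: Vec<f32> = (0..131).map(|i| (i % 17) as f32).collect();
        let b: Vec<f32> = (0..131).map(|i| (i % 13) as f32).collect();

        let result = dot_product_f32(&a, &b, 131);

        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!(
            (result - expected).abs() < 1e-3,
            "Expected {}, got {}",
            expected,
            result
        );
    }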
1914
1915    #[test]
1916    fn test_max_f32() {
1917        let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
1918        let result = max_f32(&values, 6);
1919        assert!((result - 9.0).abs() < 1e-6);
1920    }
1921
1922    #[test]
1923    fn test_max_f32_large() {
1924        // Test with 128 values, max at position 77
1925        let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
1926        values[77] = 1000.0;
1927
1928        let result = max_f32(&values, 128);
1929        assert!((result - 1000.0).abs() < 1e-5);
1930    }
1931
1932    #[test]
1933    fn test_max_f32_negative() {
1934        let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
1935        let result = max_f32(&values, 5);
1936        assert!((result - (-1.0)).abs() < 1e-6);
1937    }
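
    // Sketch: length 37 forces the remainder loop, and the maximum sits in the scalar
    // tail so it cannot be found by the vectorized body alone.
    #[test]
    fn test_max_f32_remainder_sketch() {
        let mut values: Vec<f32> = (0..37).map(|i| i as f32).collect();
        values[36] = 500.0; // last element, handled by the remainder loop

        let result = max_f32(&values, 37);
        assert!((result - 500.0).abs() < 1e-6);
    }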
1938
1939    #[test]
1940    fn test_max_f32_empty() {
1941        let values: Vec<f32> = vec![];
1942        let result = max_f32(&values, 0);
1943        assert_eq!(result, f32::NEG_INFINITY);
1944    }
1945}