Skip to main content

hermes_core/structures/
simd.rs

1//! Shared SIMD-accelerated functions for posting list compression
2//!
3//! This module provides platform-optimized implementations for common operations:
4//! - **Unpacking**: Convert packed 8/16/32-bit values to u32 arrays
5//! - **Delta decoding**: Prefix sum for converting deltas to absolute values
6//! - **Add one**: Increment all values in an array (for TF decoding)
7//!
8//! Supports:
9//! - **NEON** on aarch64 (Apple Silicon, ARM servers)
10//! - **SSE/SSE4.1** on x86_64 (Intel/AMD)
11//! - **Scalar fallback** for other architectures
12
13// ============================================================================
14// NEON intrinsics for aarch64 (Apple Silicon, ARM servers)
15// ============================================================================
16
17#[cfg(target_arch = "aarch64")]
18#[allow(unsafe_op_in_unsafe_fn)]
19mod neon {
20    use std::arch::aarch64::*;
21
22    /// SIMD unpack for 8-bit values using NEON
23    #[target_feature(enable = "neon")]
24    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
25        let chunks = count / 16;
26        let remainder = count % 16;
27
28        for chunk in 0..chunks {
29            let base = chunk * 16;
30            let in_ptr = input.as_ptr().add(base);
31
32            // Load 16 bytes
33            let bytes = vld1q_u8(in_ptr);
34
35            // Widen u8 -> u16 -> u32
36            let low8 = vget_low_u8(bytes);
37            let high8 = vget_high_u8(bytes);
38
39            let low16 = vmovl_u8(low8);
40            let high16 = vmovl_u8(high8);
41
42            let v0 = vmovl_u16(vget_low_u16(low16));
43            let v1 = vmovl_u16(vget_high_u16(low16));
44            let v2 = vmovl_u16(vget_low_u16(high16));
45            let v3 = vmovl_u16(vget_high_u16(high16));
46
47            let out_ptr = output.as_mut_ptr().add(base);
48            vst1q_u32(out_ptr, v0);
49            vst1q_u32(out_ptr.add(4), v1);
50            vst1q_u32(out_ptr.add(8), v2);
51            vst1q_u32(out_ptr.add(12), v3);
52        }
53
54        // Handle remainder
55        let base = chunks * 16;
56        for i in 0..remainder {
57            output[base + i] = input[base + i] as u32;
58        }
59    }
60
61    /// SIMD unpack for 16-bit values using NEON
62    #[target_feature(enable = "neon")]
63    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
64        let chunks = count / 8;
65        let remainder = count % 8;
66
67        for chunk in 0..chunks {
68            let base = chunk * 8;
69            let in_ptr = input.as_ptr().add(base * 2) as *const u16;
70
71            let vals = vld1q_u16(in_ptr);
72            let low = vmovl_u16(vget_low_u16(vals));
73            let high = vmovl_u16(vget_high_u16(vals));
74
75            let out_ptr = output.as_mut_ptr().add(base);
76            vst1q_u32(out_ptr, low);
77            vst1q_u32(out_ptr.add(4), high);
78        }
79
80        // Handle remainder
81        let base = chunks * 8;
82        for i in 0..remainder {
83            let idx = (base + i) * 2;
84            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
85        }
86    }
87
88    /// SIMD unpack for 32-bit values using NEON (fast copy)
89    #[target_feature(enable = "neon")]
90    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
91        let chunks = count / 4;
92        let remainder = count % 4;
93
94        let in_ptr = input.as_ptr() as *const u32;
95        let out_ptr = output.as_mut_ptr();
96
97        for chunk in 0..chunks {
98            let vals = vld1q_u32(in_ptr.add(chunk * 4));
99            vst1q_u32(out_ptr.add(chunk * 4), vals);
100        }
101
102        // Handle remainder
103        let base = chunks * 4;
104        for i in 0..remainder {
105            let idx = (base + i) * 4;
106            output[base + i] =
107                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
108        }
109    }
110
111    /// SIMD prefix sum for 4 u32 values using NEON
112    /// Input:  [a, b, c, d]
113    /// Output: [a, a+b, a+b+c, a+b+c+d]
114    #[inline]
115    #[target_feature(enable = "neon")]
116    unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
117        // Step 1: shift by 1 and add
118        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
119        let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
120        let sum1 = vaddq_u32(v, shifted1);
121
122        // Step 2: shift by 2 and add
123        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
124        let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
125        vaddq_u32(sum1, shifted2)
126    }
127
128    /// SIMD delta decode: convert deltas to absolute doc IDs
129    /// deltas[i] stores (gap - 1), output[i] = first + sum(gaps[0..i])
130    /// Uses NEON SIMD prefix sum for high throughput
131    #[target_feature(enable = "neon")]
132    pub unsafe fn delta_decode(
133        output: &mut [u32],
134        deltas: &[u32],
135        first_doc_id: u32,
136        count: usize,
137    ) {
138        if count == 0 {
139            return;
140        }
141
142        output[0] = first_doc_id;
143        if count == 1 {
144            return;
145        }
146
147        let ones = vdupq_n_u32(1);
148        let mut carry = vdupq_n_u32(first_doc_id);
149
150        let full_groups = (count - 1) / 4;
151        let remainder = (count - 1) % 4;
152
153        for group in 0..full_groups {
154            let base = group * 4;
155
156            // Load 4 deltas and add 1 (since we store gap-1)
157            let d = vld1q_u32(deltas[base..].as_ptr());
158            let gaps = vaddq_u32(d, ones);
159
160            // Compute prefix sum within the 4 elements
161            let prefix = prefix_sum_4(gaps);
162
163            // Add carry (broadcast last element of previous group)
164            let result = vaddq_u32(prefix, carry);
165
166            // Store result
167            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
168
169            // Update carry: broadcast the last element for next iteration
170            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
171        }
172
173        // Handle remainder
174        let base = full_groups * 4;
175        let mut scalar_carry = vgetq_lane_u32(carry, 0);
176        for j in 0..remainder {
177            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
178            output[base + j + 1] = scalar_carry;
179        }
180    }
181
182    /// SIMD add 1 to all values (for TF decoding: stored as tf-1)
183    #[target_feature(enable = "neon")]
184    pub unsafe fn add_one(values: &mut [u32], count: usize) {
185        let ones = vdupq_n_u32(1);
186        let chunks = count / 4;
187        let remainder = count % 4;
188
189        for chunk in 0..chunks {
190            let base = chunk * 4;
191            let ptr = values.as_mut_ptr().add(base);
192            let v = vld1q_u32(ptr);
193            let result = vaddq_u32(v, ones);
194            vst1q_u32(ptr, result);
195        }
196
197        let base = chunks * 4;
198        for i in 0..remainder {
199            values[base + i] += 1;
200        }
201    }
202
203    /// Fused unpack 8-bit + delta decode using NEON
204    /// Processes 4 values at a time, fusing unpack and prefix sum
205    #[target_feature(enable = "neon")]
206    pub unsafe fn unpack_8bit_delta_decode(
207        input: &[u8],
208        output: &mut [u32],
209        first_value: u32,
210        count: usize,
211    ) {
212        output[0] = first_value;
213        if count <= 1 {
214            return;
215        }
216
217        let ones = vdupq_n_u32(1);
218        let mut carry = vdupq_n_u32(first_value);
219
220        let full_groups = (count - 1) / 4;
221        let remainder = (count - 1) % 4;
222
223        for group in 0..full_groups {
224            let base = group * 4;
225
226            // Load 4 bytes and widen to u32
227            let b0 = input[base] as u32;
228            let b1 = input[base + 1] as u32;
229            let b2 = input[base + 2] as u32;
230            let b3 = input[base + 3] as u32;
231            let deltas = [b0, b1, b2, b3];
232            let d = vld1q_u32(deltas.as_ptr());
233
234            // Add 1 (since we store gap-1)
235            let gaps = vaddq_u32(d, ones);
236
237            // Compute prefix sum within the 4 elements
238            let prefix = prefix_sum_4(gaps);
239
240            // Add carry
241            let result = vaddq_u32(prefix, carry);
242
243            // Store result
244            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
245
246            // Update carry
247            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
248        }
249
250        // Handle remainder
251        let base = full_groups * 4;
252        let mut scalar_carry = vgetq_lane_u32(carry, 0);
253        for j in 0..remainder {
254            scalar_carry = scalar_carry
255                .wrapping_add(input[base + j] as u32)
256                .wrapping_add(1);
257            output[base + j + 1] = scalar_carry;
258        }
259    }
260
261    /// Fused unpack 16-bit + delta decode using NEON
262    #[target_feature(enable = "neon")]
263    pub unsafe fn unpack_16bit_delta_decode(
264        input: &[u8],
265        output: &mut [u32],
266        first_value: u32,
267        count: usize,
268    ) {
269        output[0] = first_value;
270        if count <= 1 {
271            return;
272        }
273
274        let ones = vdupq_n_u32(1);
275        let mut carry = vdupq_n_u32(first_value);
276
277        let full_groups = (count - 1) / 4;
278        let remainder = (count - 1) % 4;
279
280        for group in 0..full_groups {
281            let base = group * 4;
282            let in_ptr = input.as_ptr().add(base * 2) as *const u16;
283
284            // Load 4 u16 values and widen to u32
285            let vals = vld1_u16(in_ptr);
286            let d = vmovl_u16(vals);
287
288            // Add 1 (since we store gap-1)
289            let gaps = vaddq_u32(d, ones);
290
291            // Compute prefix sum within the 4 elements
292            let prefix = prefix_sum_4(gaps);
293
294            // Add carry
295            let result = vaddq_u32(prefix, carry);
296
297            // Store result
298            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
299
300            // Update carry
301            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
302        }
303
304        // Handle remainder
305        let base = full_groups * 4;
306        let mut scalar_carry = vgetq_lane_u32(carry, 0);
307        for j in 0..remainder {
308            let idx = (base + j) * 2;
309            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
310            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
311            output[base + j + 1] = scalar_carry;
312        }
313    }
314
315    /// Check if NEON is available (always true on aarch64)
316    #[inline]
317    pub fn is_available() -> bool {
318        true
319    }
320}
321
322// ============================================================================
323// SSE intrinsics for x86_64 (Intel/AMD)
324// ============================================================================
325
326#[cfg(target_arch = "x86_64")]
327#[allow(unsafe_op_in_unsafe_fn)]
328mod sse {
329    use std::arch::x86_64::*;
330
331    /// SIMD unpack for 8-bit values using SSE
332    #[target_feature(enable = "sse2", enable = "sse4.1")]
333    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
334        let chunks = count / 16;
335        let remainder = count % 16;
336
337        for chunk in 0..chunks {
338            let base = chunk * 16;
339            let in_ptr = input.as_ptr().add(base);
340
341            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
342
343            // Zero extend u8 -> u32 using SSE4.1 pmovzx
344            let v0 = _mm_cvtepu8_epi32(bytes);
345            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
346            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
347            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
348
349            let out_ptr = output.as_mut_ptr().add(base);
350            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
351            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
352            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
353            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
354        }
355
356        let base = chunks * 16;
357        for i in 0..remainder {
358            output[base + i] = input[base + i] as u32;
359        }
360    }
361
362    /// SIMD unpack for 16-bit values using SSE
363    #[target_feature(enable = "sse2", enable = "sse4.1")]
364    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
365        let chunks = count / 8;
366        let remainder = count % 8;
367
368        for chunk in 0..chunks {
369            let base = chunk * 8;
370            let in_ptr = input.as_ptr().add(base * 2);
371
372            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
373            let low = _mm_cvtepu16_epi32(vals);
374            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
375
376            let out_ptr = output.as_mut_ptr().add(base);
377            _mm_storeu_si128(out_ptr as *mut __m128i, low);
378            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
379        }
380
381        let base = chunks * 8;
382        for i in 0..remainder {
383            let idx = (base + i) * 2;
384            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
385        }
386    }
387
388    /// SIMD unpack for 32-bit values using SSE (fast copy)
389    #[target_feature(enable = "sse2")]
390    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
391        let chunks = count / 4;
392        let remainder = count % 4;
393
394        let in_ptr = input.as_ptr() as *const __m128i;
395        let out_ptr = output.as_mut_ptr() as *mut __m128i;
396
397        for chunk in 0..chunks {
398            let vals = _mm_loadu_si128(in_ptr.add(chunk));
399            _mm_storeu_si128(out_ptr.add(chunk), vals);
400        }
401
402        // Handle remainder
403        let base = chunks * 4;
404        for i in 0..remainder {
405            let idx = (base + i) * 4;
406            output[base + i] =
407                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
408        }
409    }
410
411    /// SIMD prefix sum for 4 u32 values using SSE
412    /// Input:  [a, b, c, d]
413    /// Output: [a, a+b, a+b+c, a+b+c+d]
414    #[inline]
415    #[target_feature(enable = "sse2")]
416    unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
417        // Step 1: shift by 1 element (4 bytes) and add
418        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
419        let shifted1 = _mm_slli_si128(v, 4);
420        let sum1 = _mm_add_epi32(v, shifted1);
421
422        // Step 2: shift by 2 elements (8 bytes) and add
423        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
424        let shifted2 = _mm_slli_si128(sum1, 8);
425        _mm_add_epi32(sum1, shifted2)
426    }
427
428    /// SIMD delta decode using SSE with true SIMD prefix sum
429    #[target_feature(enable = "sse2", enable = "sse4.1")]
430    pub unsafe fn delta_decode(
431        output: &mut [u32],
432        deltas: &[u32],
433        first_doc_id: u32,
434        count: usize,
435    ) {
436        if count == 0 {
437            return;
438        }
439
440        output[0] = first_doc_id;
441        if count == 1 {
442            return;
443        }
444
445        let ones = _mm_set1_epi32(1);
446        let mut carry = _mm_set1_epi32(first_doc_id as i32);
447
448        let full_groups = (count - 1) / 4;
449        let remainder = (count - 1) % 4;
450
451        for group in 0..full_groups {
452            let base = group * 4;
453
454            // Load 4 deltas and add 1 (since we store gap-1)
455            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
456            let gaps = _mm_add_epi32(d, ones);
457
458            // Compute prefix sum within the 4 elements
459            let prefix = prefix_sum_4(gaps);
460
461            // Add carry (broadcast last element of previous group)
462            let result = _mm_add_epi32(prefix, carry);
463
464            // Store result
465            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
466
467            // Update carry: broadcast the last element for next iteration
468            carry = _mm_shuffle_epi32(result, 0xFF); // broadcast lane 3
469        }
470
471        // Handle remainder
472        let base = full_groups * 4;
473        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
474        for j in 0..remainder {
475            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
476            output[base + j + 1] = scalar_carry;
477        }
478    }
479
480    /// SIMD add 1 to all values using SSE
481    #[target_feature(enable = "sse2")]
482    pub unsafe fn add_one(values: &mut [u32], count: usize) {
483        let ones = _mm_set1_epi32(1);
484        let chunks = count / 4;
485        let remainder = count % 4;
486
487        for chunk in 0..chunks {
488            let base = chunk * 4;
489            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
490            let v = _mm_loadu_si128(ptr);
491            let result = _mm_add_epi32(v, ones);
492            _mm_storeu_si128(ptr, result);
493        }
494
495        let base = chunks * 4;
496        for i in 0..remainder {
497            values[base + i] += 1;
498        }
499    }
500
501    /// Fused unpack 8-bit + delta decode using SSE
502    #[target_feature(enable = "sse2", enable = "sse4.1")]
503    pub unsafe fn unpack_8bit_delta_decode(
504        input: &[u8],
505        output: &mut [u32],
506        first_value: u32,
507        count: usize,
508    ) {
509        output[0] = first_value;
510        if count <= 1 {
511            return;
512        }
513
514        let ones = _mm_set1_epi32(1);
515        let mut carry = _mm_set1_epi32(first_value as i32);
516
517        let full_groups = (count - 1) / 4;
518        let remainder = (count - 1) % 4;
519
520        for group in 0..full_groups {
521            let base = group * 4;
522
523            // Load 4 bytes (unaligned) and zero-extend to u32
524            let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
525                input.as_ptr().add(base) as *const i32
526            ));
527            let d = _mm_cvtepu8_epi32(bytes);
528
529            // Add 1 (since we store gap-1)
530            let gaps = _mm_add_epi32(d, ones);
531
532            // Compute prefix sum within the 4 elements
533            let prefix = prefix_sum_4(gaps);
534
535            // Add carry
536            let result = _mm_add_epi32(prefix, carry);
537
538            // Store result
539            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
540
541            // Update carry: broadcast the last element
542            carry = _mm_shuffle_epi32(result, 0xFF);
543        }
544
545        // Handle remainder
546        let base = full_groups * 4;
547        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
548        for j in 0..remainder {
549            scalar_carry = scalar_carry
550                .wrapping_add(input[base + j] as u32)
551                .wrapping_add(1);
552            output[base + j + 1] = scalar_carry;
553        }
554    }
555
556    /// Fused unpack 16-bit + delta decode using SSE
557    #[target_feature(enable = "sse2", enable = "sse4.1")]
558    pub unsafe fn unpack_16bit_delta_decode(
559        input: &[u8],
560        output: &mut [u32],
561        first_value: u32,
562        count: usize,
563    ) {
564        output[0] = first_value;
565        if count <= 1 {
566            return;
567        }
568
569        let ones = _mm_set1_epi32(1);
570        let mut carry = _mm_set1_epi32(first_value as i32);
571
572        let full_groups = (count - 1) / 4;
573        let remainder = (count - 1) % 4;
574
575        for group in 0..full_groups {
576            let base = group * 4;
577            let in_ptr = input.as_ptr().add(base * 2);
578
579            // Load 8 bytes (4 u16 values, unaligned) and zero-extend to u32
580            let vals = _mm_loadl_epi64(in_ptr as *const __m128i); // loadl_epi64 supports unaligned
581            let d = _mm_cvtepu16_epi32(vals);
582
583            // Add 1 (since we store gap-1)
584            let gaps = _mm_add_epi32(d, ones);
585
586            // Compute prefix sum within the 4 elements
587            let prefix = prefix_sum_4(gaps);
588
589            // Add carry
590            let result = _mm_add_epi32(prefix, carry);
591
592            // Store result
593            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
594
595            // Update carry: broadcast the last element
596            carry = _mm_shuffle_epi32(result, 0xFF);
597        }
598
599        // Handle remainder
600        let base = full_groups * 4;
601        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
602        for j in 0..remainder {
603            let idx = (base + j) * 2;
604            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
605            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
606            output[base + j + 1] = scalar_carry;
607        }
608    }
609
610    /// Check if SSE4.1 is available at runtime
611    #[inline]
612    pub fn is_available() -> bool {
613        is_x86_feature_detected!("sse4.1")
614    }
615}
616
617// ============================================================================
618// AVX2 intrinsics for x86_64 (Intel/AMD with 256-bit registers)
619// ============================================================================
620
621#[cfg(target_arch = "x86_64")]
622#[allow(unsafe_op_in_unsafe_fn)]
623mod avx2 {
624    use std::arch::x86_64::*;
625
626    /// AVX2 unpack for 8-bit values (processes 32 bytes at a time)
627    #[target_feature(enable = "avx2")]
628    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
629        let chunks = count / 32;
630        let remainder = count % 32;
631
632        for chunk in 0..chunks {
633            let base = chunk * 32;
634            let in_ptr = input.as_ptr().add(base);
635
636            // Load 32 bytes (two 128-bit loads, then combine)
637            let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
638            let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
639
640            // Zero extend first 16 bytes: u8 -> u32
641            let v0 = _mm256_cvtepu8_epi32(bytes_lo);
642            let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
643            let v2 = _mm256_cvtepu8_epi32(bytes_hi);
644            let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));
645
646            let out_ptr = output.as_mut_ptr().add(base);
647            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
648            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
649            _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
650            _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
651        }
652
653        // Handle remainder with SSE
654        let base = chunks * 32;
655        for i in 0..remainder {
656            output[base + i] = input[base + i] as u32;
657        }
658    }
659
660    /// AVX2 unpack for 16-bit values (processes 16 values at a time)
661    #[target_feature(enable = "avx2")]
662    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
663        let chunks = count / 16;
664        let remainder = count % 16;
665
666        for chunk in 0..chunks {
667            let base = chunk * 16;
668            let in_ptr = input.as_ptr().add(base * 2);
669
670            // Load 32 bytes (16 u16 values)
671            let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
672            let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
673
674            // Zero extend u16 -> u32
675            let v0 = _mm256_cvtepu16_epi32(vals_lo);
676            let v1 = _mm256_cvtepu16_epi32(vals_hi);
677
678            let out_ptr = output.as_mut_ptr().add(base);
679            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
680            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
681        }
682
683        // Handle remainder
684        let base = chunks * 16;
685        for i in 0..remainder {
686            let idx = (base + i) * 2;
687            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
688        }
689    }
690
691    /// AVX2 unpack for 32-bit values (fast copy, 8 values at a time)
692    #[target_feature(enable = "avx2")]
693    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
694        let chunks = count / 8;
695        let remainder = count % 8;
696
697        let in_ptr = input.as_ptr() as *const __m256i;
698        let out_ptr = output.as_mut_ptr() as *mut __m256i;
699
700        for chunk in 0..chunks {
701            let vals = _mm256_loadu_si256(in_ptr.add(chunk));
702            _mm256_storeu_si256(out_ptr.add(chunk), vals);
703        }
704
705        // Handle remainder
706        let base = chunks * 8;
707        for i in 0..remainder {
708            let idx = (base + i) * 4;
709            output[base + i] =
710                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
711        }
712    }
713
714    /// AVX2 add 1 to all values (8 values at a time)
715    #[target_feature(enable = "avx2")]
716    pub unsafe fn add_one(values: &mut [u32], count: usize) {
717        let ones = _mm256_set1_epi32(1);
718        let chunks = count / 8;
719        let remainder = count % 8;
720
721        for chunk in 0..chunks {
722            let base = chunk * 8;
723            let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
724            let v = _mm256_loadu_si256(ptr);
725            let result = _mm256_add_epi32(v, ones);
726            _mm256_storeu_si256(ptr, result);
727        }
728
729        let base = chunks * 8;
730        for i in 0..remainder {
731            values[base + i] += 1;
732        }
733    }
734
735    /// Check if AVX2 is available at runtime
736    #[inline]
737    pub fn is_available() -> bool {
738        is_x86_feature_detected!("avx2")
739    }
740}
741
742// ============================================================================
743// Scalar fallback implementations
744// ============================================================================
745
746#[allow(dead_code)]
747mod scalar {
748    /// Scalar unpack for 8-bit values
749    #[inline]
750    pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
751        for i in 0..count {
752            output[i] = input[i] as u32;
753        }
754    }
755
756    /// Scalar unpack for 16-bit values
757    #[inline]
758    pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
759        for (i, out) in output.iter_mut().enumerate().take(count) {
760            let idx = i * 2;
761            *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
762        }
763    }
764
765    /// Scalar unpack for 32-bit values
766    #[inline]
767    pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
768        for (i, out) in output.iter_mut().enumerate().take(count) {
769            let idx = i * 4;
770            *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
771        }
772    }
773
774    /// Scalar delta decode
775    #[inline]
776    pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
777        if count == 0 {
778            return;
779        }
780
781        output[0] = first_doc_id;
782        let mut carry = first_doc_id;
783
784        for i in 0..count - 1 {
785            carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
786            output[i + 1] = carry;
787        }
788    }
789
790    /// Scalar add 1 to all values
791    #[inline]
792    pub fn add_one(values: &mut [u32], count: usize) {
793        for val in values.iter_mut().take(count) {
794            *val += 1;
795        }
796    }
797}
798
799// ============================================================================
800// Public dispatch functions that select SIMD or scalar at runtime
801// ============================================================================
802
803/// Unpack 8-bit packed values to u32 with SIMD acceleration
804#[inline]
805pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
806    #[cfg(target_arch = "aarch64")]
807    {
808        if neon::is_available() {
809            unsafe {
810                neon::unpack_8bit(input, output, count);
811            }
812            return;
813        }
814    }
815
816    #[cfg(target_arch = "x86_64")]
817    {
818        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
819        if avx2::is_available() {
820            unsafe {
821                avx2::unpack_8bit(input, output, count);
822            }
823            return;
824        }
825        if sse::is_available() {
826            unsafe {
827                sse::unpack_8bit(input, output, count);
828            }
829            return;
830        }
831    }
832
833    scalar::unpack_8bit(input, output, count);
834}
835
836/// Unpack 16-bit packed values to u32 with SIMD acceleration
837#[inline]
838pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
839    #[cfg(target_arch = "aarch64")]
840    {
841        if neon::is_available() {
842            unsafe {
843                neon::unpack_16bit(input, output, count);
844            }
845            return;
846        }
847    }
848
849    #[cfg(target_arch = "x86_64")]
850    {
851        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
852        if avx2::is_available() {
853            unsafe {
854                avx2::unpack_16bit(input, output, count);
855            }
856            return;
857        }
858        if sse::is_available() {
859            unsafe {
860                sse::unpack_16bit(input, output, count);
861            }
862            return;
863        }
864    }
865
866    scalar::unpack_16bit(input, output, count);
867}
868
869/// Unpack 32-bit packed values to u32 with SIMD acceleration
870#[inline]
871pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
872    #[cfg(target_arch = "aarch64")]
873    {
874        if neon::is_available() {
875            unsafe {
876                neon::unpack_32bit(input, output, count);
877            }
878        }
879    }
880
881    #[cfg(target_arch = "x86_64")]
882    {
883        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
884        if avx2::is_available() {
885            unsafe {
886                avx2::unpack_32bit(input, output, count);
887            }
888        } else {
889            // SSE2 is always available on x86_64
890            unsafe {
891                sse::unpack_32bit(input, output, count);
892            }
893        }
894    }
895
896    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
897    {
898        scalar::unpack_32bit(input, output, count);
899    }
900}
901
902/// Delta decode with SIMD acceleration
903///
904/// Converts delta-encoded values to absolute values.
905/// Input: deltas[i] = value[i+1] - value[i] - 1 (gap minus one)
906/// Output: absolute values starting from first_value
907#[inline]
908pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
909    #[cfg(target_arch = "aarch64")]
910    {
911        if neon::is_available() {
912            unsafe {
913                neon::delta_decode(output, deltas, first_value, count);
914            }
915            return;
916        }
917    }
918
919    #[cfg(target_arch = "x86_64")]
920    {
921        if sse::is_available() {
922            unsafe {
923                sse::delta_decode(output, deltas, first_value, count);
924            }
925            return;
926        }
927    }
928
929    scalar::delta_decode(output, deltas, first_value, count);
930}
931
932/// Add 1 to all values with SIMD acceleration
933///
934/// Used for TF decoding where values are stored as (tf - 1)
935#[inline]
936pub fn add_one(values: &mut [u32], count: usize) {
937    #[cfg(target_arch = "aarch64")]
938    {
939        if neon::is_available() {
940            unsafe {
941                neon::add_one(values, count);
942            }
943        }
944    }
945
946    #[cfg(target_arch = "x86_64")]
947    {
948        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
949        if avx2::is_available() {
950            unsafe {
951                avx2::add_one(values, count);
952            }
953        } else {
954            // SSE2 is always available on x86_64
955            unsafe {
956                sse::add_one(values, count);
957            }
958        }
959    }
960
961    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
962    {
963        scalar::add_one(values, count);
964    }
965}
966
967/// Compute the number of bits needed to represent a value
968#[inline]
969pub fn bits_needed(val: u32) -> u8 {
970    if val == 0 {
971        0
972    } else {
973        32 - val.leading_zeros() as u8
974    }
975}
976
977// ============================================================================
978// Rounded bitpacking for truly vectorized encoding/decoding
979// ============================================================================
980//
981// Instead of using arbitrary bit widths (1-32), we round up to SIMD-friendly
982// widths: 0, 8, 16, or 32 bits. This trades ~10-20% more space for much faster
983// decoding since we can use direct SIMD widening instructions (pmovzx) without
984// any bit-shifting or masking.
985//
986// Bit width mapping:
987//   0      -> 0  (all zeros)
988//   1-8    -> 8  (u8)
989//   9-16   -> 16 (u16)
990//   17-32  -> 32 (u32)
991
992/// Rounded bit width type for SIMD-friendly encoding
993#[derive(Debug, Clone, Copy, PartialEq, Eq)]
994#[repr(u8)]
995pub enum RoundedBitWidth {
996    Zero = 0,
997    Bits8 = 8,
998    Bits16 = 16,
999    Bits32 = 32,
1000}
1001
1002impl RoundedBitWidth {
1003    /// Round an exact bit width to the nearest SIMD-friendly width
1004    #[inline]
1005    pub fn from_exact(bits: u8) -> Self {
1006        match bits {
1007            0 => RoundedBitWidth::Zero,
1008            1..=8 => RoundedBitWidth::Bits8,
1009            9..=16 => RoundedBitWidth::Bits16,
1010            _ => RoundedBitWidth::Bits32,
1011        }
1012    }
1013
1014    /// Convert from stored u8 value (must be 0, 8, 16, or 32)
1015    #[inline]
1016    pub fn from_u8(bits: u8) -> Self {
1017        match bits {
1018            0 => RoundedBitWidth::Zero,
1019            8 => RoundedBitWidth::Bits8,
1020            16 => RoundedBitWidth::Bits16,
1021            32 => RoundedBitWidth::Bits32,
1022            _ => RoundedBitWidth::Bits32, // Fallback for invalid values
1023        }
1024    }
1025
1026    /// Get the byte size per value
1027    #[inline]
1028    pub fn bytes_per_value(self) -> usize {
1029        match self {
1030            RoundedBitWidth::Zero => 0,
1031            RoundedBitWidth::Bits8 => 1,
1032            RoundedBitWidth::Bits16 => 2,
1033            RoundedBitWidth::Bits32 => 4,
1034        }
1035    }
1036
1037    /// Get the raw bit width value
1038    #[inline]
1039    pub fn as_u8(self) -> u8 {
1040        self as u8
1041    }
1042}
1043
1044/// Round a bit width to the nearest SIMD-friendly width (0, 8, 16, or 32)
1045#[inline]
1046pub fn round_bit_width(bits: u8) -> u8 {
1047    RoundedBitWidth::from_exact(bits).as_u8()
1048}
1049
1050/// Pack values using rounded bit width (SIMD-friendly)
1051///
1052/// This is much simpler than arbitrary bitpacking since values are byte-aligned.
1053/// Returns the number of bytes written.
1054#[inline]
1055pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
1056    let count = values.len();
1057    match bit_width {
1058        RoundedBitWidth::Zero => 0,
1059        RoundedBitWidth::Bits8 => {
1060            for (i, &v) in values.iter().enumerate() {
1061                output[i] = v as u8;
1062            }
1063            count
1064        }
1065        RoundedBitWidth::Bits16 => {
1066            for (i, &v) in values.iter().enumerate() {
1067                let bytes = (v as u16).to_le_bytes();
1068                output[i * 2] = bytes[0];
1069                output[i * 2 + 1] = bytes[1];
1070            }
1071            count * 2
1072        }
1073        RoundedBitWidth::Bits32 => {
1074            for (i, &v) in values.iter().enumerate() {
1075                let bytes = v.to_le_bytes();
1076                output[i * 4] = bytes[0];
1077                output[i * 4 + 1] = bytes[1];
1078                output[i * 4 + 2] = bytes[2];
1079                output[i * 4 + 3] = bytes[3];
1080            }
1081            count * 4
1082        }
1083    }
1084}
1085
1086/// Unpack values using rounded bit width with SIMD acceleration
1087///
1088/// This is the fast path - no bit manipulation needed, just widening.
1089#[inline]
1090pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
1091    match bit_width {
1092        RoundedBitWidth::Zero => {
1093            for out in output.iter_mut().take(count) {
1094                *out = 0;
1095            }
1096        }
1097        RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
1098        RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
1099        RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
1100    }
1101}
1102
1103/// Fused unpack + delta decode using rounded bit width
1104///
1105/// Combines unpacking and prefix sum in a single pass for better cache utilization.
1106#[inline]
1107pub fn unpack_rounded_delta_decode(
1108    input: &[u8],
1109    bit_width: RoundedBitWidth,
1110    output: &mut [u32],
1111    first_value: u32,
1112    count: usize,
1113) {
1114    match bit_width {
1115        RoundedBitWidth::Zero => {
1116            // All deltas are 0, meaning gaps of 1
1117            let mut val = first_value;
1118            for out in output.iter_mut().take(count) {
1119                *out = val;
1120                val = val.wrapping_add(1);
1121            }
1122        }
1123        RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
1124        RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
1125        RoundedBitWidth::Bits32 => {
1126            // For 32-bit, unpack then delta decode (no fused version needed)
1127            unpack_32bit(input, output, count);
1128            // Delta decode in place - but we need the deltas separate
1129            // Actually for 32-bit we should just unpack and delta decode separately
1130            if count > 0 {
1131                let mut carry = first_value;
1132                output[0] = first_value;
1133                for item in output.iter_mut().take(count).skip(1) {
1134                    // item currently holds delta (gap-1)
1135                    carry = carry.wrapping_add(*item).wrapping_add(1);
1136                    *item = carry;
1137                }
1138            }
1139        }
1140    }
1141}
1142
1143// ============================================================================
1144// Fused operations for better cache utilization
1145// ============================================================================
1146
1147/// Fused unpack 8-bit + delta decode in a single pass
1148///
1149/// This avoids writing the intermediate unpacked values to memory,
1150/// improving cache utilization for large blocks.
1151#[inline]
1152pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1153    if count == 0 {
1154        return;
1155    }
1156
1157    output[0] = first_value;
1158    if count == 1 {
1159        return;
1160    }
1161
1162    #[cfg(target_arch = "aarch64")]
1163    {
1164        if neon::is_available() {
1165            unsafe {
1166                neon::unpack_8bit_delta_decode(input, output, first_value, count);
1167            }
1168            return;
1169        }
1170    }
1171
1172    #[cfg(target_arch = "x86_64")]
1173    {
1174        if sse::is_available() {
1175            unsafe {
1176                sse::unpack_8bit_delta_decode(input, output, first_value, count);
1177            }
1178            return;
1179        }
1180    }
1181
1182    // Scalar fallback
1183    let mut carry = first_value;
1184    for i in 0..count - 1 {
1185        carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
1186        output[i + 1] = carry;
1187    }
1188}
1189
1190/// Fused unpack 16-bit + delta decode in a single pass
1191#[inline]
1192pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1193    if count == 0 {
1194        return;
1195    }
1196
1197    output[0] = first_value;
1198    if count == 1 {
1199        return;
1200    }
1201
1202    #[cfg(target_arch = "aarch64")]
1203    {
1204        if neon::is_available() {
1205            unsafe {
1206                neon::unpack_16bit_delta_decode(input, output, first_value, count);
1207            }
1208            return;
1209        }
1210    }
1211
1212    #[cfg(target_arch = "x86_64")]
1213    {
1214        if sse::is_available() {
1215            unsafe {
1216                sse::unpack_16bit_delta_decode(input, output, first_value, count);
1217            }
1218            return;
1219        }
1220    }
1221
1222    // Scalar fallback
1223    let mut carry = first_value;
1224    for i in 0..count - 1 {
1225        let idx = i * 2;
1226        let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
1227        carry = carry.wrapping_add(delta).wrapping_add(1);
1228        output[i + 1] = carry;
1229    }
1230}
1231
1232/// Fused unpack + delta decode for arbitrary bit widths
1233///
1234/// Combines unpacking and prefix sum in a single pass, avoiding intermediate buffer.
1235/// Uses SIMD-accelerated paths for 8/16-bit widths, scalar for others.
1236#[inline]
1237pub fn unpack_delta_decode(
1238    input: &[u8],
1239    bit_width: u8,
1240    output: &mut [u32],
1241    first_value: u32,
1242    count: usize,
1243) {
1244    if count == 0 {
1245        return;
1246    }
1247
1248    output[0] = first_value;
1249    if count == 1 {
1250        return;
1251    }
1252
1253    // Fast paths for SIMD-friendly bit widths
1254    match bit_width {
1255        0 => {
1256            // All zeros = consecutive doc IDs (gap of 1)
1257            let mut val = first_value;
1258            for item in output.iter_mut().take(count).skip(1) {
1259                val = val.wrapping_add(1);
1260                *item = val;
1261            }
1262        }
1263        8 => unpack_8bit_delta_decode(input, output, first_value, count),
1264        16 => unpack_16bit_delta_decode(input, output, first_value, count),
1265        32 => {
1266            // 32-bit: unpack inline and delta decode
1267            let mut carry = first_value;
1268            for i in 0..count - 1 {
1269                let idx = i * 4;
1270                let delta = u32::from_le_bytes([
1271                    input[idx],
1272                    input[idx + 1],
1273                    input[idx + 2],
1274                    input[idx + 3],
1275                ]);
1276                carry = carry.wrapping_add(delta).wrapping_add(1);
1277                output[i + 1] = carry;
1278            }
1279        }
1280        _ => {
1281            // Generic bit width: fused unpack + delta decode
1282            let mask = (1u64 << bit_width) - 1;
1283            let bit_width_usize = bit_width as usize;
1284            let mut bit_pos = 0usize;
1285            let input_ptr = input.as_ptr();
1286            let mut carry = first_value;
1287
1288            for i in 0..count - 1 {
1289                let byte_idx = bit_pos >> 3;
1290                let bit_offset = bit_pos & 7;
1291
1292                // SAFETY: Caller guarantees input has enough data
1293                let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
1294                let delta = ((word >> bit_offset) & mask) as u32;
1295
1296                carry = carry.wrapping_add(delta).wrapping_add(1);
1297                output[i + 1] = carry;
1298                bit_pos += bit_width_usize;
1299            }
1300        }
1301    }
1302}
1303
1304// ============================================================================
1305// Sparse Vector SIMD Functions
1306// ============================================================================
1307
1308/// Dequantize UInt8 weights to f32 with SIMD acceleration
1309///
1310/// Computes: output[i] = input[i] as f32 * scale + min_val
1311#[inline]
1312pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
1313    #[cfg(target_arch = "aarch64")]
1314    {
1315        if neon::is_available() {
1316            unsafe {
1317                dequantize_uint8_neon(input, output, scale, min_val, count);
1318            }
1319            return;
1320        }
1321    }
1322
1323    #[cfg(target_arch = "x86_64")]
1324    {
1325        if sse::is_available() {
1326            unsafe {
1327                dequantize_uint8_sse(input, output, scale, min_val, count);
1328            }
1329            return;
1330        }
1331    }
1332
1333    // Scalar fallback
1334    for i in 0..count {
1335        output[i] = input[i] as f32 * scale + min_val;
1336    }
1337}
1338
1339#[cfg(target_arch = "aarch64")]
1340#[target_feature(enable = "neon")]
1341#[allow(unsafe_op_in_unsafe_fn)]
1342unsafe fn dequantize_uint8_neon(
1343    input: &[u8],
1344    output: &mut [f32],
1345    scale: f32,
1346    min_val: f32,
1347    count: usize,
1348) {
1349    use std::arch::aarch64::*;
1350
1351    let scale_v = vdupq_n_f32(scale);
1352    let min_v = vdupq_n_f32(min_val);
1353
1354    let chunks = count / 16;
1355    let remainder = count % 16;
1356
1357    for chunk in 0..chunks {
1358        let base = chunk * 16;
1359        let in_ptr = input.as_ptr().add(base);
1360
1361        // Load 16 bytes
1362        let bytes = vld1q_u8(in_ptr);
1363
1364        // Widen u8 -> u16 -> u32 -> f32
1365        let low8 = vget_low_u8(bytes);
1366        let high8 = vget_high_u8(bytes);
1367
1368        let low16 = vmovl_u8(low8);
1369        let high16 = vmovl_u8(high8);
1370
1371        // Process 4 values at a time
1372        let u32_0 = vmovl_u16(vget_low_u16(low16));
1373        let u32_1 = vmovl_u16(vget_high_u16(low16));
1374        let u32_2 = vmovl_u16(vget_low_u16(high16));
1375        let u32_3 = vmovl_u16(vget_high_u16(high16));
1376
1377        // Convert to f32 and apply scale + min_val
1378        let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
1379        let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
1380        let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
1381        let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);
1382
1383        let out_ptr = output.as_mut_ptr().add(base);
1384        vst1q_f32(out_ptr, f32_0);
1385        vst1q_f32(out_ptr.add(4), f32_1);
1386        vst1q_f32(out_ptr.add(8), f32_2);
1387        vst1q_f32(out_ptr.add(12), f32_3);
1388    }
1389
1390    // Handle remainder
1391    let base = chunks * 16;
1392    for i in 0..remainder {
1393        output[base + i] = input[base + i] as f32 * scale + min_val;
1394    }
1395}
1396
1397#[cfg(target_arch = "x86_64")]
1398#[target_feature(enable = "sse2", enable = "sse4.1")]
1399#[allow(unsafe_op_in_unsafe_fn)]
1400unsafe fn dequantize_uint8_sse(
1401    input: &[u8],
1402    output: &mut [f32],
1403    scale: f32,
1404    min_val: f32,
1405    count: usize,
1406) {
1407    use std::arch::x86_64::*;
1408
1409    let scale_v = _mm_set1_ps(scale);
1410    let min_v = _mm_set1_ps(min_val);
1411
1412    let chunks = count / 4;
1413    let remainder = count % 4;
1414
1415    for chunk in 0..chunks {
1416        let base = chunk * 4;
1417
1418        // Load 4 bytes and zero-extend to 32-bit
1419        let b0 = input[base] as i32;
1420        let b1 = input[base + 1] as i32;
1421        let b2 = input[base + 2] as i32;
1422        let b3 = input[base + 3] as i32;
1423
1424        let ints = _mm_set_epi32(b3, b2, b1, b0);
1425        let floats = _mm_cvtepi32_ps(ints);
1426
1427        // Apply scale and min_val: result = floats * scale + min_val
1428        let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);
1429
1430        _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
1431    }
1432
1433    // Handle remainder
1434    let base = chunks * 4;
1435    for i in 0..remainder {
1436        output[base + i] = input[base + i] as f32 * scale + min_val;
1437    }
1438}
1439
1440/// Compute dot product of two f32 arrays with SIMD acceleration
1441#[inline]
1442pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
1443    #[cfg(target_arch = "aarch64")]
1444    {
1445        if neon::is_available() {
1446            return unsafe { dot_product_f32_neon(a, b, count) };
1447        }
1448    }
1449
1450    #[cfg(target_arch = "x86_64")]
1451    {
1452        if sse::is_available() {
1453            return unsafe { dot_product_f32_sse(a, b, count) };
1454        }
1455    }
1456
1457    // Scalar fallback
1458    let mut sum = 0.0f32;
1459    for i in 0..count {
1460        sum += a[i] * b[i];
1461    }
1462    sum
1463}
1464
1465#[cfg(target_arch = "aarch64")]
1466#[target_feature(enable = "neon")]
1467#[allow(unsafe_op_in_unsafe_fn)]
1468unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
1469    use std::arch::aarch64::*;
1470
1471    let chunks = count / 4;
1472    let remainder = count % 4;
1473
1474    let mut acc = vdupq_n_f32(0.0);
1475
1476    for chunk in 0..chunks {
1477        let base = chunk * 4;
1478        let va = vld1q_f32(a.as_ptr().add(base));
1479        let vb = vld1q_f32(b.as_ptr().add(base));
1480        acc = vfmaq_f32(acc, va, vb);
1481    }
1482
1483    // Horizontal sum
1484    let mut sum = vaddvq_f32(acc);
1485
1486    // Handle remainder
1487    let base = chunks * 4;
1488    for i in 0..remainder {
1489        sum += a[base + i] * b[base + i];
1490    }
1491
1492    sum
1493}
1494
1495#[cfg(target_arch = "x86_64")]
1496#[target_feature(enable = "sse")]
1497#[allow(unsafe_op_in_unsafe_fn)]
1498unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
1499    use std::arch::x86_64::*;
1500
1501    let chunks = count / 4;
1502    let remainder = count % 4;
1503
1504    let mut acc = _mm_setzero_ps();
1505
1506    for chunk in 0..chunks {
1507        let base = chunk * 4;
1508        let va = _mm_loadu_ps(a.as_ptr().add(base));
1509        let vb = _mm_loadu_ps(b.as_ptr().add(base));
1510        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
1511    }
1512
1513    // Horizontal sum: [a, b, c, d] -> a + b + c + d
1514    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); // [b, a, d, c]
1515    let sums = _mm_add_ps(acc, shuf); // [a+b, a+b, c+d, c+d]
1516    let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, c+d, ?, ?]
1517    let final_sum = _mm_add_ss(sums, shuf2); // [a+b+c+d, ?, ?, ?]
1518
1519    let mut sum = _mm_cvtss_f32(final_sum);
1520
1521    // Handle remainder
1522    let base = chunks * 4;
1523    for i in 0..remainder {
1524        sum += a[base + i] * b[base + i];
1525    }
1526
1527    sum
1528}
1529
1530/// Find maximum value in f32 array with SIMD acceleration
1531#[inline]
1532pub fn max_f32(values: &[f32], count: usize) -> f32 {
1533    if count == 0 {
1534        return f32::NEG_INFINITY;
1535    }
1536
1537    #[cfg(target_arch = "aarch64")]
1538    {
1539        if neon::is_available() {
1540            return unsafe { max_f32_neon(values, count) };
1541        }
1542    }
1543
1544    #[cfg(target_arch = "x86_64")]
1545    {
1546        if sse::is_available() {
1547            return unsafe { max_f32_sse(values, count) };
1548        }
1549    }
1550
1551    // Scalar fallback
1552    values[..count]
1553        .iter()
1554        .cloned()
1555        .fold(f32::NEG_INFINITY, f32::max)
1556}
1557
1558#[cfg(target_arch = "aarch64")]
1559#[target_feature(enable = "neon")]
1560#[allow(unsafe_op_in_unsafe_fn)]
1561unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
1562    use std::arch::aarch64::*;
1563
1564    let chunks = count / 4;
1565    let remainder = count % 4;
1566
1567    let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);
1568
1569    for chunk in 0..chunks {
1570        let base = chunk * 4;
1571        let v = vld1q_f32(values.as_ptr().add(base));
1572        max_v = vmaxq_f32(max_v, v);
1573    }
1574
1575    // Horizontal max
1576    let mut max_val = vmaxvq_f32(max_v);
1577
1578    // Handle remainder
1579    let base = chunks * 4;
1580    for i in 0..remainder {
1581        max_val = max_val.max(values[base + i]);
1582    }
1583
1584    max_val
1585}
1586
1587#[cfg(target_arch = "x86_64")]
1588#[target_feature(enable = "sse")]
1589#[allow(unsafe_op_in_unsafe_fn)]
1590unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
1591    use std::arch::x86_64::*;
1592
1593    let chunks = count / 4;
1594    let remainder = count % 4;
1595
1596    let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);
1597
1598    for chunk in 0..chunks {
1599        let base = chunk * 4;
1600        let v = _mm_loadu_ps(values.as_ptr().add(base));
1601        max_v = _mm_max_ps(max_v, v);
1602    }
1603
1604    // Horizontal max: [a, b, c, d] -> max(a, b, c, d)
1605    let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01); // [b, a, d, c]
1606    let max1 = _mm_max_ps(max_v, shuf); // [max(a,b), max(a,b), max(c,d), max(c,d)]
1607    let shuf2 = _mm_movehl_ps(max1, max1); // [max(c,d), max(c,d), ?, ?]
1608    let final_max = _mm_max_ss(max1, shuf2); // [max(a,b,c,d), ?, ?, ?]
1609
1610    let mut max_val = _mm_cvtss_f32(final_max);
1611
1612    // Handle remainder
1613    let base = chunks * 4;
1614    for i in 0..remainder {
1615        max_val = max_val.max(values[base + i]);
1616    }
1617
1618    max_val
1619}
1620
1621// ============================================================================
1622// Batched Cosine Similarity for Dense Vector Search
1623// ============================================================================
1624
1625/// Fused dot-product + self-norm in a single pass (SIMD accelerated).
1626///
1627/// Returns (dot(a, b), dot(b, b)) — i.e. the dot product of a·b and ||b||².
1628/// Loads `b` only once (halves memory bandwidth vs two separate dot products).
1629#[inline]
1630fn fused_dot_norm(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1631    #[cfg(target_arch = "aarch64")]
1632    {
1633        if neon::is_available() {
1634            return unsafe { fused_dot_norm_neon(a, b, count) };
1635        }
1636    }
1637
1638    #[cfg(target_arch = "x86_64")]
1639    {
1640        if sse::is_available() {
1641            return unsafe { fused_dot_norm_sse(a, b, count) };
1642        }
1643    }
1644
1645    // Scalar fallback
1646    let mut dot = 0.0f32;
1647    let mut norm_b = 0.0f32;
1648    for i in 0..count {
1649        dot += a[i] * b[i];
1650        norm_b += b[i] * b[i];
1651    }
1652    (dot, norm_b)
1653}
1654
1655#[cfg(target_arch = "aarch64")]
1656#[target_feature(enable = "neon")]
1657#[allow(unsafe_op_in_unsafe_fn)]
1658unsafe fn fused_dot_norm_neon(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1659    use std::arch::aarch64::*;
1660
1661    let chunks = count / 4;
1662    let remainder = count % 4;
1663
1664    let mut acc_dot = vdupq_n_f32(0.0);
1665    let mut acc_norm = vdupq_n_f32(0.0);
1666
1667    for chunk in 0..chunks {
1668        let base = chunk * 4;
1669        let va = vld1q_f32(a.as_ptr().add(base));
1670        let vb = vld1q_f32(b.as_ptr().add(base));
1671        acc_dot = vfmaq_f32(acc_dot, va, vb);
1672        acc_norm = vfmaq_f32(acc_norm, vb, vb);
1673    }
1674
1675    let mut dot = vaddvq_f32(acc_dot);
1676    let mut norm = vaddvq_f32(acc_norm);
1677
1678    let base = chunks * 4;
1679    for i in 0..remainder {
1680        dot += a[base + i] * b[base + i];
1681        norm += b[base + i] * b[base + i];
1682    }
1683
1684    (dot, norm)
1685}
1686
1687#[cfg(target_arch = "x86_64")]
1688#[target_feature(enable = "sse")]
1689#[allow(unsafe_op_in_unsafe_fn)]
1690unsafe fn fused_dot_norm_sse(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1691    use std::arch::x86_64::*;
1692
1693    let chunks = count / 4;
1694    let remainder = count % 4;
1695
1696    let mut acc_dot = _mm_setzero_ps();
1697    let mut acc_norm = _mm_setzero_ps();
1698
1699    for chunk in 0..chunks {
1700        let base = chunk * 4;
1701        let va = _mm_loadu_ps(a.as_ptr().add(base));
1702        let vb = _mm_loadu_ps(b.as_ptr().add(base));
1703        acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
1704        acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
1705    }
1706
1707    // Horizontal sums
1708    let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
1709    let sums_d = _mm_add_ps(acc_dot, shuf_d);
1710    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
1711    let final_d = _mm_add_ss(sums_d, shuf2_d);
1712    let mut dot = _mm_cvtss_f32(final_d);
1713
1714    let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
1715    let sums_n = _mm_add_ps(acc_norm, shuf_n);
1716    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
1717    let final_n = _mm_add_ss(sums_n, shuf2_n);
1718    let mut norm = _mm_cvtss_f32(final_n);
1719
1720    let base = chunks * 4;
1721    for i in 0..remainder {
1722        dot += a[base + i] * b[base + i];
1723        norm += b[base + i] * b[base + i];
1724    }
1725
1726    (dot, norm)
1727}
1728
1729/// Batch cosine similarity: query vs N contiguous vectors.
1730///
1731/// `vectors` is a contiguous row-major buffer holding `dim` floats per vector.
1732/// One score is written per element of `scores`, so `vectors` must contain at least `scores.len() * dim` floats.
1733///
1734/// Optimizations over calling `cosine_similarity` N times:
1735/// 1. Query norm computed once (not N times)
1736/// 2. Fused dot+norm kernel — each vector loaded once (halves bandwidth)
1737/// 3. No per-call overhead (branch prediction, function calls)
1738#[inline]
1739pub fn batch_cosine_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
1740    let n = scores.len();
1741    debug_assert!(vectors.len() >= n * dim);
1742    debug_assert_eq!(query.len(), dim);
1743
1744    if dim == 0 || n == 0 {
1745        return;
1746    }
1747
1748    // Pre-compute query norm once
1749    let norm_q_sq = dot_product_f32(query, query, dim);
1750    if norm_q_sq < f32::EPSILON {
1751        for s in scores.iter_mut() {
1752            *s = 0.0;
1753        }
1754        return;
1755    }
1756    let norm_q = norm_q_sq.sqrt();
1757
1758    for i in 0..n {
1759        let vec = &vectors[i * dim..(i + 1) * dim];
1760        let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
1761        if norm_v_sq < f32::EPSILON {
1762            scores[i] = 0.0;
1763        } else {
1764            scores[i] = dot / (norm_q * norm_v_sq.sqrt());
1765        }
1766    }
1767}
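
// Usage sketch (hedged; `candidates` and its dimensionality are illustrative, not part
// of this module): score a query against a flat row-major buffer of n * dim floats.
//
//     let dim = 384;
//     let flat: Vec<f32> = candidates.iter().flatten().copied().collect(); // n * dim floats
//     let mut scores = vec![0.0f32; candidates.len()];
//     batch_cosine_scores(&query, &flat, dim, &mut scores);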
1768
1769/// Batch cosine similarity with stride: query vs N vectors stored at `stride` apart.
1770///
1771/// Computes cosine similarity using only the first `dim` elements of each vector,
1772/// where vectors are spaced `stride` elements apart in the buffer. This avoids
1773/// copying/trimming when only a prefix of each vector is needed (e.g., MRL).
1774///
1775/// - `query.len()` must equal `dim`
1776/// - `vectors.len()` must be >= `(n - 1) * stride + dim`, where `n = scores.len()`
1777/// - When `stride == dim`, this is identical to [`batch_cosine_scores`]
1778///
1779/// Example: 768-dim vectors, MRL trim to 256 →
1780///   `dim = 256, stride = 768` — reads first 256 floats of each vector, skips rest.
1781///   3× less SIMD work, zero copies.
1782#[inline]
1783pub fn batch_cosine_scores_strided(
1784    query: &[f32],
1785    vectors: &[f32],
1786    dim: usize,
1787    stride: usize,
1788    scores: &mut [f32],
1789) {
1790    let n = scores.len();
1791    debug_assert_eq!(query.len(), dim);
1792    debug_assert!(stride >= dim);
    debug_assert!(n == 0 || vectors.len() >= (n - 1) * stride + dim);
1793
1794    if dim == 0 || n == 0 {
1795        return;
1796    }
1797
1798    // Pre-compute query norm once
1799    let norm_q_sq = dot_product_f32(query, query, dim);
1800    if norm_q_sq < f32::EPSILON {
1801        for s in scores.iter_mut() {
1802            *s = 0.0;
1803        }
1804        return;
1805    }
1806    let norm_q = norm_q_sq.sqrt();
1807
1808    for (i, score) in scores.iter_mut().enumerate() {
1809        let start = i * stride;
1810        let vec = &vectors[start..start + dim];
1811        let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
1812        *score = if norm_v_sq < f32::EPSILON {
1813            0.0
1814        } else {
1815            dot / (norm_q * norm_v_sq.sqrt())
1816        };
1817    }
1818}
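
// MRL usage sketch (hedged; `storage`, `query`, and `n` are illustrative), mirroring the
// example in the doc comment above: vectors stored at 768 dims, scored on their first
// 256 dims with no copying.
//
//     let (mrl_dim, full_dim) = (256, 768);
//     let mut scores = vec![0.0f32; n];
//     batch_cosine_scores_strided(&query[..mrl_dim], &storage, mrl_dim, full_dim, &mut scores);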
1819
1820/// Compute cosine similarity between two f32 vectors with SIMD acceleration
1821///
1822/// Returns dot(a,b) / (||a|| * ||b||), range [-1, 1]
1823/// Returns 0.0 if either vector has zero norm.
1824#[inline]
1825pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
1826    debug_assert_eq!(a.len(), b.len());
1827    let count = a.len();
1828
1829    if count == 0 {
1830        return 0.0;
1831    }
1832
1833    let dot = dot_product_f32(a, b, count);
1834    let norm_a = dot_product_f32(a, a, count);
1835    let norm_b = dot_product_f32(b, b, count);
1836
1837    let denom = (norm_a * norm_b).sqrt();
1838    if denom < f32::EPSILON {
1839        return 0.0;
1840    }
1841
1842    dot / denom
1843}
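
// Worked example (comment only): for a = [1.0, 0.0] and b = [1.0, 1.0],
// dot = 1, ||a|| = 1, ||b|| = sqrt(2), so cosine_similarity(&a, &b) ≈ 0.7071.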
1844
1845/// Compute squared Euclidean distance between two f32 vectors with SIMD acceleration
1846///
1847/// Returns sum((a[i] - b[i])^2) for all i
1848#[inline]
1849pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
1850    debug_assert_eq!(a.len(), b.len());
1851    let count = a.len();
1852
1853    if count == 0 {
1854        return 0.0;
1855    }
1856
1857    #[cfg(target_arch = "aarch64")]
1858    {
1859        if neon::is_available() {
1860            return unsafe { squared_euclidean_neon(a, b, count) };
1861        }
1862    }
1863
1864    #[cfg(target_arch = "x86_64")]
1865    {
1866        if avx2::is_available() {
1867            return unsafe { squared_euclidean_avx2(a, b, count) };
1868        }
1869        if sse::is_available() {
1870            return unsafe { squared_euclidean_sse(a, b, count) };
1871        }
1872    }
1873
1874    // Scalar fallback
1875    a.iter()
1876        .zip(b.iter())
1877        .map(|(&x, &y)| {
1878            let d = x - y;
1879            d * d
1880        })
1881        .sum()
1882}
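
// Worked example (comment only): for a = [1.0, 2.0, 3.0] and b = [2.0, 4.0, 6.0],
// squared_euclidean_distance(&a, &b) = 1 + 4 + 9 = 14.0.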
1883
1884#[cfg(target_arch = "aarch64")]
1885#[target_feature(enable = "neon")]
1886#[allow(unsafe_op_in_unsafe_fn)]
1887unsafe fn squared_euclidean_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
1888    use std::arch::aarch64::*;
1889
1890    let chunks = count / 4;
1891    let remainder = count % 4;
1892
1893    let mut acc = vdupq_n_f32(0.0);
1894
1895    for chunk in 0..chunks {
1896        let base = chunk * 4;
1897        let va = vld1q_f32(a.as_ptr().add(base));
1898        let vb = vld1q_f32(b.as_ptr().add(base));
1899        let diff = vsubq_f32(va, vb);
1900        acc = vfmaq_f32(acc, diff, diff); // acc += diff * diff (fused multiply-add)
1901    }
1902
1903    // Horizontal sum
1904    let mut sum = vaddvq_f32(acc);
1905
1906    // Handle remainder
1907    let base = chunks * 4;
1908    for i in 0..remainder {
1909        let d = a[base + i] - b[base + i];
1910        sum += d * d;
1911    }
1912
1913    sum
1914}
1915
1916#[cfg(target_arch = "x86_64")]
1917#[target_feature(enable = "sse")]
1918#[allow(unsafe_op_in_unsafe_fn)]
1919unsafe fn squared_euclidean_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
1920    use std::arch::x86_64::*;
1921
1922    let chunks = count / 4;
1923    let remainder = count % 4;
1924
1925    let mut acc = _mm_setzero_ps();
1926
1927    for chunk in 0..chunks {
1928        let base = chunk * 4;
1929        let va = _mm_loadu_ps(a.as_ptr().add(base));
1930        let vb = _mm_loadu_ps(b.as_ptr().add(base));
1931        let diff = _mm_sub_ps(va, vb);
1932        acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
1933    }
1934
1935    // Horizontal sum: [a, b, c, d] -> a + b + c + d
1936    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); // [b, a, d, c]
1937    let sums = _mm_add_ps(acc, shuf); // [a+b, a+b, c+d, c+d]
1938    let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, c+d, ?, ?]
1939    let final_sum = _mm_add_ss(sums, shuf2); // [a+b+c+d, ?, ?, ?]
1940
1941    let mut sum = _mm_cvtss_f32(final_sum);
1942
1943    // Handle remainder
1944    let base = chunks * 4;
1945    for i in 0..remainder {
1946        let d = a[base + i] - b[base + i];
1947        sum += d * d;
1948    }
1949
1950    sum
1951}
1952
1953#[cfg(target_arch = "x86_64")]
1954#[target_feature(enable = "avx2")]
1955#[allow(unsafe_op_in_unsafe_fn)]
1956unsafe fn squared_euclidean_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
1957    use std::arch::x86_64::*;
1958
1959    let chunks = count / 8;
1960    let remainder = count % 8;
1961
1962    let mut acc = _mm256_setzero_ps();
1963
1964    for chunk in 0..chunks {
1965        let base = chunk * 8;
1966        let va = _mm256_loadu_ps(a.as_ptr().add(base));
1967        let vb = _mm256_loadu_ps(b.as_ptr().add(base));
1968        let diff = _mm256_sub_ps(va, vb);
1969        acc = _mm256_fmadd_ps(diff, diff, acc); // acc += diff * diff (FMA)
1970    }
1971
1972    // Horizontal sum of 8 floats
1973    // First, add high 128 bits to low 128 bits
1974    let high = _mm256_extractf128_ps(acc, 1);
1975    let low = _mm256_castps256_ps128(acc);
1976    let sum128 = _mm_add_ps(low, high);
1977
1978    // Now sum the 4 floats in sum128
1979    let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
1980    let sums = _mm_add_ps(sum128, shuf);
1981    let shuf2 = _mm_movehl_ps(sums, sums);
1982    let final_sum = _mm_add_ss(sums, shuf2);
1983
1984    let mut sum = _mm_cvtss_f32(final_sum);
1985
1986    // Handle remainder
1987    let base = chunks * 8;
1988    for i in 0..remainder {
1989        let d = a[base + i] - b[base + i];
1990        sum += d * d;
1991    }
1992
1993    sum
1994}
1995
1996/// Batch compute squared Euclidean distances from one query to multiple vectors
1997///
1998/// Returns distances[i] = squared_euclidean_distance(query, vectors[i])
1999/// This is more efficient than calling squared_euclidean_distance in a loop
2000/// because the CPU-feature dispatch is performed once for the whole batch instead of per call.
2001#[inline]
2002pub fn batch_squared_euclidean_distances(
2003    query: &[f32],
2004    vectors: &[Vec<f32>],
2005    distances: &mut [f32],
2006) {
2007    debug_assert_eq!(vectors.len(), distances.len());
2008
2009    #[cfg(target_arch = "x86_64")]
2010    {
2011        if avx2::is_available() {
2012            for (i, vec) in vectors.iter().enumerate() {
2013                distances[i] = unsafe { squared_euclidean_avx2(query, vec, query.len()) };
2014            }
2015            return;
2016        }
2017    }
2018
2019    // Fallback to individual calls
2020    for (i, vec) in vectors.iter().enumerate() {
2021        distances[i] = squared_euclidean_distance(query, vec);
2022    }
2023}
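
// Usage sketch (hedged; `query` and `candidates: Vec<Vec<f32>>` are illustrative):
//
//     let mut distances = vec![0.0f32; candidates.len()];
//     batch_squared_euclidean_distances(&query, &candidates, &mut distances);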
2024
2025#[cfg(test)]
2026mod tests {
2027    use super::*;
2028
2029    #[test]
2030    fn test_unpack_8bit() {
2031        let input: Vec<u8> = (0..128).collect();
2032        let mut output = vec![0u32; 128];
2033        unpack_8bit(&input, &mut output, 128);
2034
2035        for (i, &v) in output.iter().enumerate() {
2036            assert_eq!(v, i as u32);
2037        }
2038    }
2039
2040    #[test]
2041    fn test_unpack_16bit() {
2042        let mut input = vec![0u8; 256];
2043        for i in 0..128 {
2044            let val = (i * 100) as u16;
2045            input[i * 2] = val as u8;
2046            input[i * 2 + 1] = (val >> 8) as u8;
2047        }
2048
2049        let mut output = vec![0u32; 128];
2050        unpack_16bit(&input, &mut output, 128);
2051
2052        for (i, &v) in output.iter().enumerate() {
2053            assert_eq!(v, (i * 100) as u32);
2054        }
2055    }
2056
2057    #[test]
2058    fn test_unpack_32bit() {
2059        let mut input = vec![0u8; 512];
2060        for i in 0..128 {
2061            let val = (i * 1000) as u32;
2062            let bytes = val.to_le_bytes();
2063            input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
2064        }
2065
2066        let mut output = vec![0u32; 128];
2067        unpack_32bit(&input, &mut output, 128);
2068
2069        for (i, &v) in output.iter().enumerate() {
2070            assert_eq!(v, (i * 1000) as u32);
2071        }
2072    }
2073
2074    #[test]
2075    fn test_delta_decode() {
2076        // doc_ids: [10, 15, 20, 30, 50]
2077        // gaps: [5, 5, 10, 20]
2078        // deltas (gap-1): [4, 4, 9, 19]
2079        let deltas = vec![4u32, 4, 9, 19];
2080        let mut output = vec![0u32; 5];
2081
2082        delta_decode(&mut output, &deltas, 10, 5);
2083
2084        assert_eq!(output, vec![10, 15, 20, 30, 50]);
2085    }
2086
2087    #[test]
2088    fn test_add_one() {
2089        let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
2090        add_one(&mut values, 8);
2091
2092        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
2093    }
2094
2095    #[test]
2096    fn test_bits_needed() {
2097        assert_eq!(bits_needed(0), 0);
2098        assert_eq!(bits_needed(1), 1);
2099        assert_eq!(bits_needed(2), 2);
2100        assert_eq!(bits_needed(3), 2);
2101        assert_eq!(bits_needed(4), 3);
2102        assert_eq!(bits_needed(255), 8);
2103        assert_eq!(bits_needed(256), 9);
2104        assert_eq!(bits_needed(u32::MAX), 32);
2105    }
2106
2107    #[test]
2108    fn test_unpack_8bit_delta_decode() {
2109        // doc_ids: [10, 15, 20, 30, 50]
2110        // gaps: [5, 5, 10, 20]
2111        // deltas (gap-1): [4, 4, 9, 19] stored as u8
2112        let input: Vec<u8> = vec![4, 4, 9, 19];
2113        let mut output = vec![0u32; 5];
2114
2115        unpack_8bit_delta_decode(&input, &mut output, 10, 5);
2116
2117        assert_eq!(output, vec![10, 15, 20, 30, 50]);
2118    }
2119
2120    #[test]
2121    fn test_unpack_16bit_delta_decode() {
2122        // doc_ids: [100, 600, 1100, 2100, 4100]
2123        // gaps: [500, 500, 1000, 2000]
2124        // deltas (gap-1): [499, 499, 999, 1999] stored as u16
2125        let mut input = vec![0u8; 8];
2126        for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
2127            input[i * 2] = delta as u8;
2128            input[i * 2 + 1] = (delta >> 8) as u8;
2129        }
2130        let mut output = vec![0u32; 5];
2131
2132        unpack_16bit_delta_decode(&input, &mut output, 100, 5);
2133
2134        assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
2135    }
2136
2137    #[test]
2138    fn test_fused_vs_separate_8bit() {
2139        // Test that fused and separate operations produce the same result
2140        let input: Vec<u8> = (0..127).collect();
2141        let first_value = 1000u32;
2142        let count = 128;
2143
2144        // Separate: unpack then delta_decode
2145        let mut unpacked = vec![0u32; 128];
2146        unpack_8bit(&input, &mut unpacked, 127);
2147        let mut separate_output = vec![0u32; 128];
2148        delta_decode(&mut separate_output, &unpacked, first_value, count);
2149
2150        // Fused
2151        let mut fused_output = vec![0u32; 128];
2152        unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
2153
2154        assert_eq!(separate_output, fused_output);
2155    }
2156
2157    #[test]
2158    fn test_round_bit_width() {
2159        assert_eq!(round_bit_width(0), 0);
2160        assert_eq!(round_bit_width(1), 8);
2161        assert_eq!(round_bit_width(5), 8);
2162        assert_eq!(round_bit_width(8), 8);
2163        assert_eq!(round_bit_width(9), 16);
2164        assert_eq!(round_bit_width(12), 16);
2165        assert_eq!(round_bit_width(16), 16);
2166        assert_eq!(round_bit_width(17), 32);
2167        assert_eq!(round_bit_width(24), 32);
2168        assert_eq!(round_bit_width(32), 32);
2169    }
2170
2171    #[test]
2172    fn test_rounded_bitwidth_from_exact() {
2173        assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
2174        assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
2175        assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
2176        assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
2177        assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
2178        assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
2179        assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
2180    }
2181
2182    #[test]
2183    fn test_pack_unpack_rounded_8bit() {
2184        let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
2185        let mut packed = vec![0u8; 128];
2186
2187        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
2188        assert_eq!(bytes_written, 128);
2189
2190        let mut unpacked = vec![0u32; 128];
2191        unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);
2192
2193        assert_eq!(values, unpacked);
2194    }
2195
2196    #[test]
2197    fn test_pack_unpack_rounded_16bit() {
2198        let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
2199        let mut packed = vec![0u8; 256];
2200
2201        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
2202        assert_eq!(bytes_written, 256);
2203
2204        let mut unpacked = vec![0u32; 128];
2205        unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);
2206
2207        assert_eq!(values, unpacked);
2208    }
2209
2210    #[test]
2211    fn test_pack_unpack_rounded_32bit() {
2212        let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
2213        let mut packed = vec![0u8; 512];
2214
2215        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
2216        assert_eq!(bytes_written, 512);
2217
2218        let mut unpacked = vec![0u32; 128];
2219        unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);
2220
2221        assert_eq!(values, unpacked);
2222    }
2223
2224    #[test]
2225    fn test_unpack_rounded_delta_decode() {
2226        // Test 8-bit rounded delta decode
2227        // doc_ids: [10, 15, 20, 30, 50]
2228        // gaps: [5, 5, 10, 20]
2229        // deltas (gap-1): [4, 4, 9, 19] stored as u8
2230        let input: Vec<u8> = vec![4, 4, 9, 19];
2231        let mut output = vec![0u32; 5];
2232
2233        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);
2234
2235        assert_eq!(output, vec![10, 15, 20, 30, 50]);
2236    }
2237
2238    #[test]
2239    fn test_unpack_rounded_delta_decode_zero() {
2240        // All zeros means gaps of 1 (consecutive doc IDs)
2241        let input: Vec<u8> = vec![];
2242        let mut output = vec![0u32; 5];
2243
2244        unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);
2245
2246        assert_eq!(output, vec![100, 101, 102, 103, 104]);
2247    }
2248
2249    // ========================================================================
2250    // Sparse Vector SIMD Tests
2251    // ========================================================================
2252
2253    #[test]
2254    fn test_dequantize_uint8() {
2255        let input: Vec<u8> = vec![0, 128, 255, 64, 192];
2256        let mut output = vec![0.0f32; 5];
2257        let scale = 0.1;
2258        let min_val = 1.0;
2259
2260        dequantize_uint8(&input, &mut output, scale, min_val, 5);
2261
2262        // Expected: input[i] * scale + min_val
2263        assert!((output[0] - 1.0).abs() < 1e-6); // 0 * 0.1 + 1.0 = 1.0
2264        assert!((output[1] - 13.8).abs() < 1e-6); // 128 * 0.1 + 1.0 = 13.8
2265        assert!((output[2] - 26.5).abs() < 1e-6); // 255 * 0.1 + 1.0 = 26.5
2266        assert!((output[3] - 7.4).abs() < 1e-6); // 64 * 0.1 + 1.0 = 7.4
2267        assert!((output[4] - 20.2).abs() < 1e-6); // 192 * 0.1 + 1.0 = 20.2
2268    }
2269
2270    #[test]
2271    fn test_dequantize_uint8_large() {
2272        // Test with 128 values (full SIMD block)
2273        let input: Vec<u8> = (0..128).collect();
2274        let mut output = vec![0.0f32; 128];
2275        let scale = 2.0;
2276        let min_val = -10.0;
2277
2278        dequantize_uint8(&input, &mut output, scale, min_val, 128);
2279
2280        for (i, &out) in output.iter().enumerate().take(128) {
2281            let expected = i as f32 * scale + min_val;
2282            assert!(
2283                (out - expected).abs() < 1e-5,
2284                "Mismatch at {}: expected {}, got {}",
2285                i,
2286                expected,
2287                out
2288            );
2289        }
2290    }
2291
2292    #[test]
2293    fn test_dot_product_f32() {
2294        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
2295        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
2296
2297        let result = dot_product_f32(&a, &b, 5);
2298
2299        // Expected: 1*2 + 2*3 + 3*4 + 4*5 + 5*6 = 2 + 6 + 12 + 20 + 30 = 70
2300        assert!((result - 70.0).abs() < 1e-5);
2301    }
2302
2303    #[test]
2304    fn test_dot_product_f32_large() {
2305        // Test with 128 values
2306        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
2307        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
2308
2309        let result = dot_product_f32(&a, &b, 128);
2310
2311        // Compute expected
2312        let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
2313        assert!(
2314            (result - expected).abs() < 1e-3,
2315            "Expected {}, got {}",
2316            expected,
2317            result
2318        );
2319    }
2320
2321    #[test]
2322    fn test_max_f32() {
2323        let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
2324        let result = max_f32(&values, 6);
2325        assert!((result - 9.0).abs() < 1e-6);
2326    }
2327
2328    #[test]
2329    fn test_max_f32_large() {
2330        // Test with 128 values, max at position 77
2331        let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
2332        values[77] = 1000.0;
2333
2334        let result = max_f32(&values, 128);
2335        assert!((result - 1000.0).abs() < 1e-5);
2336    }
2337
2338    #[test]
2339    fn test_max_f32_negative() {
2340        let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
2341        let result = max_f32(&values, 5);
2342        assert!((result - (-1.0)).abs() < 1e-6);
2343    }
2344
2345    #[test]
2346    fn test_max_f32_empty() {
2347        let values: Vec<f32> = vec![];
2348        let result = max_f32(&values, 0);
2349        assert_eq!(result, f32::NEG_INFINITY);
2350    }
2351
2352    #[test]
2353    fn test_fused_dot_norm() {
2354        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
2355        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
2356        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
2357
2358        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
2359        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
2360        assert!(
2361            (dot - expected_dot).abs() < 1e-5,
2362            "dot: expected {}, got {}",
2363            expected_dot,
2364            dot
2365        );
2366        assert!(
2367            (norm_b - expected_norm).abs() < 1e-5,
2368            "norm: expected {}, got {}",
2369            expected_norm,
2370            norm_b
2371        );
2372    }
2373
2374    #[test]
2375    fn test_fused_dot_norm_large() {
2376        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
2377        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect();
2378        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
2379
2380        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
2381        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
2382        assert!(
2383            (dot - expected_dot).abs() < 1.0,
2384            "dot: expected {}, got {}",
2385            expected_dot,
2386            dot
2387        );
2388        assert!(
2389            (norm_b - expected_norm).abs() < 1.0,
2390            "norm: expected {}, got {}",
2391            expected_norm,
2392            norm_b
2393        );
2394    }
2395
2396    #[test]
2397    fn test_batch_cosine_scores() {
2398        // 4 vectors of dim 3
2399        let query = vec![1.0f32, 0.0, 0.0];
2400        let vectors = vec![
2401            1.0, 0.0, 0.0, // identical to query
2402            0.0, 1.0, 0.0, // orthogonal
2403            -1.0, 0.0, 0.0, // opposite
2404            0.5, 0.5, 0.0, // 45 degrees
2405        ];
2406        let mut scores = vec![0f32; 4];
2407        batch_cosine_scores(&query, &vectors, 3, &mut scores);
2408
2409        assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
2410        assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
2411        assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
2412        let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
2413        assert!(
2414            (scores[3] - expected_45).abs() < 1e-5,
2415            "45deg: expected {}, got {}",
2416            expected_45,
2417            scores[3]
2418        );
2419    }
2420
2421    #[test]
2422    fn test_batch_cosine_scores_matches_individual() {
2423        let query: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
2424        let n = 50;
2425        let dim = 128;
2426        let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 7 + 3) as f32) * 0.01).collect();
2427
2428        let mut batch_scores = vec![0f32; n];
2429        batch_cosine_scores(&query, &vectors, dim, &mut batch_scores);
2430
2431        for i in 0..n {
2432            let vec_i = &vectors[i * dim..(i + 1) * dim];
2433            let individual = cosine_similarity(&query, vec_i);
2434            assert!(
2435                (batch_scores[i] - individual).abs() < 1e-5,
2436                "vec {}: batch={}, individual={}",
2437                i,
2438                batch_scores[i],
2439                individual
2440            );
2441        }
2442    }
2443
2444    #[test]
2445    fn test_batch_cosine_scores_empty() {
2446        let query = vec![1.0f32, 2.0, 3.0];
2447        let vectors: Vec<f32> = vec![];
2448        let mut scores: Vec<f32> = vec![];
2449        batch_cosine_scores(&query, &vectors, 3, &mut scores);
2450        assert!(scores.is_empty());
2451    }
2452
2453    #[test]
2454    fn test_batch_cosine_scores_zero_query() {
2455        let query = vec![0.0f32, 0.0, 0.0];
2456        let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
2457        let mut scores = vec![0f32; 2];
2458        batch_cosine_scores(&query, &vectors, 3, &mut scores);
2459        assert_eq!(scores[0], 0.0);
2460        assert_eq!(scores[1], 0.0);
2461    }
2462
2463    #[test]
2464    fn test_batch_cosine_scores_strided_mrl() {
2465        // Simulate MRL: 4 vectors at full_dim=8, query trimmed to dim=3
2466        // Vectors: [v0(8 floats), v1(8 floats), v2(8 floats), v3(8 floats)]
2467        // Only first 3 floats of each vector should be used
2468        let query = vec![1.0f32, 0.0, 0.0]; // dim=3
2469        let vectors = vec![
2470            1.0, 0.0, 0.0, 99.0, 99.0, 99.0, 99.0, 99.0, // v0: identical prefix
2471            0.0, 1.0, 0.0, 99.0, 99.0, 99.0, 99.0, 99.0, // v1: orthogonal prefix
2472            -1.0, 0.0, 0.0, 99.0, 99.0, 99.0, 99.0, 99.0, // v2: opposite prefix
2473            0.5, 0.5, 0.0, 99.0, 99.0, 99.0, 99.0, 99.0, // v3: 45-degree prefix
2474        ];
2475        let mut scores = vec![0f32; 4];
2476        batch_cosine_scores_strided(&query, &vectors, 3, 8, &mut scores);
2477
2478        assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
2479        assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
2480        assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
2481        let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
2482        assert!(
2483            (scores[3] - expected_45).abs() < 1e-5,
2484            "45deg: {}",
2485            scores[3]
2486        );
2487    }
2488
2489    #[test]
2490    fn test_batch_cosine_scores_strided_matches_trimmed() {
2491        // Verify strided version matches manually-trimmed batch version
2492        let full_dim = 128;
2493        let trim_dim = 32; // MRL dimension
2494        let n = 20;
2495
2496        let full_query: Vec<f32> = (0..full_dim).map(|i| (i as f32) * 0.1).collect();
2497        let trimmed_query = &full_query[..trim_dim];
2498        let vectors: Vec<f32> = (0..n * full_dim)
2499            .map(|i| ((i * 7 + 3) as f32) * 0.01)
2500            .collect();
2501
2502        // Strided: operate on full buffer with stride
2503        let mut strided_scores = vec![0f32; n];
2504        batch_cosine_scores_strided(
2505            trimmed_query,
2506            &vectors,
2507            trim_dim,
2508            full_dim,
2509            &mut strided_scores,
2510        );
2511
2512        // Manual trim: copy first trim_dim floats per vector, then batch
2513        let trimmed_vectors: Vec<f32> = (0..n)
2514            .flat_map(|i| {
2515                vectors[i * full_dim..i * full_dim + trim_dim]
2516                    .iter()
2517                    .copied()
2518            })
2519            .collect();
2520        let mut trimmed_scores = vec![0f32; n];
2521        batch_cosine_scores(
2522            trimmed_query,
2523            &trimmed_vectors,
2524            trim_dim,
2525            &mut trimmed_scores,
2526        );
2527
2528        for i in 0..n {
2529            assert!(
2530                (strided_scores[i] - trimmed_scores[i]).abs() < 1e-5,
2531                "vec {}: strided={}, trimmed={}",
2532                i,
2533                strided_scores[i],
2534                trimmed_scores[i]
2535            );
2536        }
2537    }
2538
2539    #[test]
2540    fn test_batch_cosine_scores_strided_equals_non_strided() {
2541        // When stride == dim, should produce identical results to batch_cosine_scores
2542        let dim = 64;
2543        let n = 30;
2544        let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1).collect();
2545        let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 3 + 1) as f32) * 0.01).collect();
2546
2547        let mut scores_batch = vec![0f32; n];
2548        batch_cosine_scores(&query, &vectors, dim, &mut scores_batch);
2549
2550        let mut scores_strided = vec![0f32; n];
2551        batch_cosine_scores_strided(&query, &vectors, dim, dim, &mut scores_strided);
2552
2553        for i in 0..n {
2554            assert!(
2555                (scores_batch[i] - scores_strided[i]).abs() < 1e-6,
2556                "vec {}: batch={}, strided={}",
2557                i,
2558                scores_batch[i],
2559                scores_strided[i]
2560            );
2561        }
2562    }
2563
2564    #[test]
2565    fn test_squared_euclidean_distance() {
2566        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
2567        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
2568        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
2569        let result = squared_euclidean_distance(&a, &b);
2570        assert!(
2571            (result - expected).abs() < 1e-5,
2572            "expected {}, got {}",
2573            expected,
2574            result
2575        );
2576    }
2577
2578    #[test]
2579    fn test_squared_euclidean_distance_large() {
2580        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
2581        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
2582        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
2583        let result = squared_euclidean_distance(&a, &b);
2584        assert!(
2585            (result - expected).abs() < 1e-3,
2586            "expected {}, got {}",
2587            expected,
2588            result
2589        );
2590    }
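
    // Consistency check added as a sketch: the batch helper should agree with
    // per-vector squared_euclidean_distance calls on the same data.
    #[test]
    fn test_batch_squared_euclidean_distances_matches_individual() {
        let query: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|v| (0..64).map(|i| ((v * 64 + i) as f32) * 0.01).collect())
            .collect();

        let mut distances = vec![0.0f32; vectors.len()];
        batch_squared_euclidean_distances(&query, &vectors, &mut distances);

        for (i, vec) in vectors.iter().enumerate() {
            let expected = squared_euclidean_distance(&query, vec);
            assert!(
                (distances[i] - expected).abs() < 1e-3,
                "vec {}: batch={}, individual={}",
                i,
                distances[i],
                expected
            );
        }
    }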
2591}