
hermes_core/structures/simd.rs

1//! Shared SIMD-accelerated functions for posting list compression
2//!
3//! This module provides platform-optimized implementations for common operations:
4//! - **Unpacking**: Convert packed 8/16/32-bit values to u32 arrays
5//! - **Delta decoding**: Prefix sum for converting deltas to absolute values
6//! - **Add one**: Increment all values in an array (for TF decoding)
7//!
8//! Supports:
9//! - **NEON** on aarch64 (Apple Silicon, ARM servers)
10//! - **SSE/SSE4.1** on x86_64 (Intel/AMD)
11//! - **Scalar fallback** for other architectures
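//!
//! # Example
//!
//! An illustrative sketch of the delta convention used throughout this module
//! (deltas store `gap - 1`, so consecutive doc IDs encode as zeros):
//!
//! ```ignore
//! let deltas = [0u32, 2, 0];              // gaps of 1, 3, 1
//! let mut out = [0u32; 4];
//! delta_decode(&mut out, &deltas, 10, 4); // first doc ID is 10
//! assert_eq!(out, [10, 11, 14, 15]);
//! ```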
12
13// ============================================================================
14// NEON intrinsics for aarch64 (Apple Silicon, ARM servers)
15// ============================================================================
16
17#[cfg(target_arch = "aarch64")]
18#[allow(unsafe_op_in_unsafe_fn)]
19mod neon {
20    use std::arch::aarch64::*;
21
22    /// SIMD unpack for 8-bit values using NEON
23    #[target_feature(enable = "neon")]
24    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
25        let chunks = count / 16;
26        let remainder = count % 16;
27
28        for chunk in 0..chunks {
29            let base = chunk * 16;
30            let in_ptr = input.as_ptr().add(base);
31
32            // Load 16 bytes
33            let bytes = vld1q_u8(in_ptr);
34
35            // Widen u8 -> u16 -> u32
36            let low8 = vget_low_u8(bytes);
37            let high8 = vget_high_u8(bytes);
38
39            let low16 = vmovl_u8(low8);
40            let high16 = vmovl_u8(high8);
41
42            let v0 = vmovl_u16(vget_low_u16(low16));
43            let v1 = vmovl_u16(vget_high_u16(low16));
44            let v2 = vmovl_u16(vget_low_u16(high16));
45            let v3 = vmovl_u16(vget_high_u16(high16));
46
47            let out_ptr = output.as_mut_ptr().add(base);
48            vst1q_u32(out_ptr, v0);
49            vst1q_u32(out_ptr.add(4), v1);
50            vst1q_u32(out_ptr.add(8), v2);
51            vst1q_u32(out_ptr.add(12), v3);
52        }
53
54        // Handle remainder
55        let base = chunks * 16;
56        for i in 0..remainder {
57            output[base + i] = input[base + i] as u32;
58        }
59    }
60
61    /// SIMD unpack for 16-bit values using NEON
62    #[target_feature(enable = "neon")]
63    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
64        let chunks = count / 8;
65        let remainder = count % 8;
66
67        for chunk in 0..chunks {
68            let base = chunk * 8;
69            let in_ptr = input.as_ptr().add(base * 2) as *const u16;
70
71            let vals = vld1q_u16(in_ptr);
72            let low = vmovl_u16(vget_low_u16(vals));
73            let high = vmovl_u16(vget_high_u16(vals));
74
75            let out_ptr = output.as_mut_ptr().add(base);
76            vst1q_u32(out_ptr, low);
77            vst1q_u32(out_ptr.add(4), high);
78        }
79
80        // Handle remainder
81        let base = chunks * 8;
82        for i in 0..remainder {
83            let idx = (base + i) * 2;
84            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
85        }
86    }
87
88    /// SIMD unpack for 32-bit values using NEON (fast copy)
89    #[target_feature(enable = "neon")]
90    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
91        let chunks = count / 4;
92        let remainder = count % 4;
93
94        let in_ptr = input.as_ptr() as *const u32;
95        let out_ptr = output.as_mut_ptr();
96
97        for chunk in 0..chunks {
98            let vals = vld1q_u32(in_ptr.add(chunk * 4));
99            vst1q_u32(out_ptr.add(chunk * 4), vals);
100        }
101
102        // Handle remainder
103        let base = chunks * 4;
104        for i in 0..remainder {
105            let idx = (base + i) * 4;
106            output[base + i] =
107                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
108        }
109    }
110
111    /// SIMD prefix sum for 4 u32 values using NEON
112    /// Input:  [a, b, c, d]
113    /// Output: [a, a+b, a+b+c, a+b+c+d]
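    /// e.g. [1, 2, 3, 4] -> [1, 3, 6, 10]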
114    #[inline]
115    #[target_feature(enable = "neon")]
116    unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
117        // Step 1: shift by 1 and add
118        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
119        let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
120        let sum1 = vaddq_u32(v, shifted1);
121
122        // Step 2: shift by 2 and add
123        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
124        let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
125        vaddq_u32(sum1, shifted2)
126    }
127
128    /// SIMD delta decode: convert deltas to absolute doc IDs
129    /// deltas[i] stores (gap - 1), output[i] = first + sum(gaps[0..i])
130    /// Uses NEON SIMD prefix sum for high throughput
131    #[target_feature(enable = "neon")]
132    pub unsafe fn delta_decode(
133        output: &mut [u32],
134        deltas: &[u32],
135        first_doc_id: u32,
136        count: usize,
137    ) {
138        if count == 0 {
139            return;
140        }
141
142        output[0] = first_doc_id;
143        if count == 1 {
144            return;
145        }
146
147        let ones = vdupq_n_u32(1);
148        let mut carry = vdupq_n_u32(first_doc_id);
149
150        let full_groups = (count - 1) / 4;
151        let remainder = (count - 1) % 4;
152
153        for group in 0..full_groups {
154            let base = group * 4;
155
156            // Load 4 deltas and add 1 (since we store gap-1)
157            let d = vld1q_u32(deltas[base..].as_ptr());
158            let gaps = vaddq_u32(d, ones);
159
160            // Compute prefix sum within the 4 elements
161            let prefix = prefix_sum_4(gaps);
162
163            // Add carry (broadcast last element of previous group)
164            let result = vaddq_u32(prefix, carry);
165
166            // Store result
167            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
168
169            // Update carry: broadcast the last element for next iteration
170            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
171        }
172
173        // Handle remainder
174        let base = full_groups * 4;
175        let mut scalar_carry = vgetq_lane_u32(carry, 0);
176        for j in 0..remainder {
177            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
178            output[base + j + 1] = scalar_carry;
179        }
180    }
181
182    /// SIMD add 1 to all values (for TF decoding: stored as tf-1)
183    #[target_feature(enable = "neon")]
184    pub unsafe fn add_one(values: &mut [u32], count: usize) {
185        let ones = vdupq_n_u32(1);
186        let chunks = count / 4;
187        let remainder = count % 4;
188
189        for chunk in 0..chunks {
190            let base = chunk * 4;
191            let ptr = values.as_mut_ptr().add(base);
192            let v = vld1q_u32(ptr);
193            let result = vaddq_u32(v, ones);
194            vst1q_u32(ptr, result);
195        }
196
197        let base = chunks * 4;
198        for i in 0..remainder {
199            values[base + i] += 1;
200        }
201    }
202
203    /// Fused unpack 8-bit + delta decode using NEON
204    /// Processes 4 values at a time, fusing unpack and prefix sum
205    #[target_feature(enable = "neon")]
206    pub unsafe fn unpack_8bit_delta_decode(
207        input: &[u8],
208        output: &mut [u32],
209        first_value: u32,
210        count: usize,
211    ) {
212        output[0] = first_value;
213        if count <= 1 {
214            return;
215        }
216
217        let ones = vdupq_n_u32(1);
218        let mut carry = vdupq_n_u32(first_value);
219
220        let full_groups = (count - 1) / 4;
221        let remainder = (count - 1) % 4;
222
223        for group in 0..full_groups {
224            let base = group * 4;
225
226            // Load 4 bytes as a u32, then widen u8→u16→u32 via NEON
227            let raw = std::ptr::read_unaligned(input.as_ptr().add(base) as *const u32);
228            let bytes = vreinterpret_u8_u32(vdup_n_u32(raw));
229            let u16s = vmovl_u8(bytes); // 8×u8 → 8×u16 (only low 4 matter)
230            let d = vmovl_u16(vget_low_u16(u16s)); // 4×u16 → 4×u32
231
232            // Add 1 (since we store gap-1)
233            let gaps = vaddq_u32(d, ones);
234
235            // Compute prefix sum within the 4 elements
236            let prefix = prefix_sum_4(gaps);
237
238            // Add carry
239            let result = vaddq_u32(prefix, carry);
240
241            // Store result
242            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
243
244            // Update carry
245            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
246        }
247
248        // Handle remainder
249        let base = full_groups * 4;
250        let mut scalar_carry = vgetq_lane_u32(carry, 0);
251        for j in 0..remainder {
252            scalar_carry = scalar_carry
253                .wrapping_add(input[base + j] as u32)
254                .wrapping_add(1);
255            output[base + j + 1] = scalar_carry;
256        }
257    }
258
259    /// Fused unpack 16-bit + delta decode using NEON
260    #[target_feature(enable = "neon")]
261    pub unsafe fn unpack_16bit_delta_decode(
262        input: &[u8],
263        output: &mut [u32],
264        first_value: u32,
265        count: usize,
266    ) {
267        output[0] = first_value;
268        if count <= 1 {
269            return;
270        }
271
272        let ones = vdupq_n_u32(1);
273        let mut carry = vdupq_n_u32(first_value);
274
275        let full_groups = (count - 1) / 4;
276        let remainder = (count - 1) % 4;
277
278        for group in 0..full_groups {
279            let base = group * 4;
280            let in_ptr = input.as_ptr().add(base * 2) as *const u16;
281
282            // Load 4 u16 values and widen to u32
283            let vals = vld1_u16(in_ptr);
284            let d = vmovl_u16(vals);
285
286            // Add 1 (since we store gap-1)
287            let gaps = vaddq_u32(d, ones);
288
289            // Compute prefix sum within the 4 elements
290            let prefix = prefix_sum_4(gaps);
291
292            // Add carry
293            let result = vaddq_u32(prefix, carry);
294
295            // Store result
296            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
297
298            // Update carry
299            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
300        }
301
302        // Handle remainder
303        let base = full_groups * 4;
304        let mut scalar_carry = vgetq_lane_u32(carry, 0);
305        for j in 0..remainder {
306            let idx = (base + j) * 2;
307            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
308            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
309            output[base + j + 1] = scalar_carry;
310        }
311    }
312
313    /// Check if NEON is available (always true on aarch64)
314    #[inline]
315    pub fn is_available() -> bool {
316        true
317    }
318}
319
320// ============================================================================
321// SSE intrinsics for x86_64 (Intel/AMD)
322// ============================================================================
323
324#[cfg(target_arch = "x86_64")]
325#[allow(unsafe_op_in_unsafe_fn)]
326mod sse {
327    use std::arch::x86_64::*;
328
329    /// SIMD unpack for 8-bit values using SSE
330    #[target_feature(enable = "sse2", enable = "sse4.1")]
331    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
332        let chunks = count / 16;
333        let remainder = count % 16;
334
335        for chunk in 0..chunks {
336            let base = chunk * 16;
337            let in_ptr = input.as_ptr().add(base);
338
339            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
340
341            // Zero extend u8 -> u32 using SSE4.1 pmovzx
342            let v0 = _mm_cvtepu8_epi32(bytes);
343            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
344            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
345            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
346
347            let out_ptr = output.as_mut_ptr().add(base);
348            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
349            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
350            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
351            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
352        }
353
354        let base = chunks * 16;
355        for i in 0..remainder {
356            output[base + i] = input[base + i] as u32;
357        }
358    }
359
360    /// SIMD unpack for 16-bit values using SSE
361    #[target_feature(enable = "sse2", enable = "sse4.1")]
362    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
363        let chunks = count / 8;
364        let remainder = count % 8;
365
366        for chunk in 0..chunks {
367            let base = chunk * 8;
368            let in_ptr = input.as_ptr().add(base * 2);
369
370            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
371            let low = _mm_cvtepu16_epi32(vals);
372            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
373
374            let out_ptr = output.as_mut_ptr().add(base);
375            _mm_storeu_si128(out_ptr as *mut __m128i, low);
376            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
377        }
378
379        let base = chunks * 8;
380        for i in 0..remainder {
381            let idx = (base + i) * 2;
382            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
383        }
384    }
385
386    /// SIMD unpack for 32-bit values using SSE (fast copy)
387    #[target_feature(enable = "sse2")]
388    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
389        let chunks = count / 4;
390        let remainder = count % 4;
391
392        let in_ptr = input.as_ptr() as *const __m128i;
393        let out_ptr = output.as_mut_ptr() as *mut __m128i;
394
395        for chunk in 0..chunks {
396            let vals = _mm_loadu_si128(in_ptr.add(chunk));
397            _mm_storeu_si128(out_ptr.add(chunk), vals);
398        }
399
400        // Handle remainder
401        let base = chunks * 4;
402        for i in 0..remainder {
403            let idx = (base + i) * 4;
404            output[base + i] =
405                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
406        }
407    }
408
409    /// SIMD prefix sum for 4 u32 values using SSE
410    /// Input:  [a, b, c, d]
411    /// Output: [a, a+b, a+b+c, a+b+c+d]
412    #[inline]
413    #[target_feature(enable = "sse2")]
414    unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
415        // Step 1: shift by 1 element (4 bytes) and add
416        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
417        let shifted1 = _mm_slli_si128(v, 4);
418        let sum1 = _mm_add_epi32(v, shifted1);
419
420        // Step 2: shift by 2 elements (8 bytes) and add
421        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
422        let shifted2 = _mm_slli_si128(sum1, 8);
423        _mm_add_epi32(sum1, shifted2)
424    }
425
426    /// SIMD delta decode using SSE with true SIMD prefix sum
427    #[target_feature(enable = "sse2", enable = "sse4.1")]
428    pub unsafe fn delta_decode(
429        output: &mut [u32],
430        deltas: &[u32],
431        first_doc_id: u32,
432        count: usize,
433    ) {
434        if count == 0 {
435            return;
436        }
437
438        output[0] = first_doc_id;
439        if count == 1 {
440            return;
441        }
442
443        let ones = _mm_set1_epi32(1);
444        let mut carry = _mm_set1_epi32(first_doc_id as i32);
445
446        let full_groups = (count - 1) / 4;
447        let remainder = (count - 1) % 4;
448
449        for group in 0..full_groups {
450            let base = group * 4;
451
452            // Load 4 deltas and add 1 (since we store gap-1)
453            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
454            let gaps = _mm_add_epi32(d, ones);
455
456            // Compute prefix sum within the 4 elements
457            let prefix = prefix_sum_4(gaps);
458
459            // Add carry (broadcast last element of previous group)
460            let result = _mm_add_epi32(prefix, carry);
461
462            // Store result
463            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
464
465            // Update carry: broadcast the last element for next iteration
466            carry = _mm_shuffle_epi32(result, 0xFF); // broadcast lane 3
467        }
468
469        // Handle remainder
470        let base = full_groups * 4;
471        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
472        for j in 0..remainder {
473            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
474            output[base + j + 1] = scalar_carry;
475        }
476    }
477
478    /// SIMD add 1 to all values using SSE
479    #[target_feature(enable = "sse2")]
480    pub unsafe fn add_one(values: &mut [u32], count: usize) {
481        let ones = _mm_set1_epi32(1);
482        let chunks = count / 4;
483        let remainder = count % 4;
484
485        for chunk in 0..chunks {
486            let base = chunk * 4;
487            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
488            let v = _mm_loadu_si128(ptr);
489            let result = _mm_add_epi32(v, ones);
490            _mm_storeu_si128(ptr, result);
491        }
492
493        let base = chunks * 4;
494        for i in 0..remainder {
495            values[base + i] += 1;
496        }
497    }
498
499    /// Fused unpack 8-bit + delta decode using SSE
500    #[target_feature(enable = "sse2", enable = "sse4.1")]
501    pub unsafe fn unpack_8bit_delta_decode(
502        input: &[u8],
503        output: &mut [u32],
504        first_value: u32,
505        count: usize,
506    ) {
507        output[0] = first_value;
508        if count <= 1 {
509            return;
510        }
511
512        let ones = _mm_set1_epi32(1);
513        let mut carry = _mm_set1_epi32(first_value as i32);
514
515        let full_groups = (count - 1) / 4;
516        let remainder = (count - 1) % 4;
517
518        for group in 0..full_groups {
519            let base = group * 4;
520
521            // Load 4 bytes (unaligned) and zero-extend to u32
522            let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
523                input.as_ptr().add(base) as *const i32
524            ));
525            let d = _mm_cvtepu8_epi32(bytes);
526
527            // Add 1 (since we store gap-1)
528            let gaps = _mm_add_epi32(d, ones);
529
530            // Compute prefix sum within the 4 elements
531            let prefix = prefix_sum_4(gaps);
532
533            // Add carry
534            let result = _mm_add_epi32(prefix, carry);
535
536            // Store result
537            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
538
539            // Update carry: broadcast the last element
540            carry = _mm_shuffle_epi32(result, 0xFF);
541        }
542
543        // Handle remainder
544        let base = full_groups * 4;
545        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
546        for j in 0..remainder {
547            scalar_carry = scalar_carry
548                .wrapping_add(input[base + j] as u32)
549                .wrapping_add(1);
550            output[base + j + 1] = scalar_carry;
551        }
552    }
553
554    /// Fused unpack 16-bit + delta decode using SSE
555    #[target_feature(enable = "sse2", enable = "sse4.1")]
556    pub unsafe fn unpack_16bit_delta_decode(
557        input: &[u8],
558        output: &mut [u32],
559        first_value: u32,
560        count: usize,
561    ) {
562        output[0] = first_value;
563        if count <= 1 {
564            return;
565        }
566
567        let ones = _mm_set1_epi32(1);
568        let mut carry = _mm_set1_epi32(first_value as i32);
569
570        let full_groups = (count - 1) / 4;
571        let remainder = (count - 1) % 4;
572
573        for group in 0..full_groups {
574            let base = group * 4;
575            let in_ptr = input.as_ptr().add(base * 2);
576
577            // Load 8 bytes (4 u16 values, unaligned) and zero-extend to u32
578            let vals = _mm_loadl_epi64(in_ptr as *const __m128i); // loadl_epi64 supports unaligned
579            let d = _mm_cvtepu16_epi32(vals);
580
581            // Add 1 (since we store gap-1)
582            let gaps = _mm_add_epi32(d, ones);
583
584            // Compute prefix sum within the 4 elements
585            let prefix = prefix_sum_4(gaps);
586
587            // Add carry
588            let result = _mm_add_epi32(prefix, carry);
589
590            // Store result
591            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
592
593            // Update carry: broadcast the last element
594            carry = _mm_shuffle_epi32(result, 0xFF);
595        }
596
597        // Handle remainder
598        let base = full_groups * 4;
599        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
600        for j in 0..remainder {
601            let idx = (base + j) * 2;
602            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
603            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
604            output[base + j + 1] = scalar_carry;
605        }
606    }
607
608    /// Check if SSE4.1 is available at runtime
609    #[inline]
610    pub fn is_available() -> bool {
611        is_x86_feature_detected!("sse4.1")
612    }
613}
614
615// ============================================================================
616// AVX2 intrinsics for x86_64 (Intel/AMD with 256-bit registers)
617// ============================================================================
618
619#[cfg(target_arch = "x86_64")]
620#[allow(unsafe_op_in_unsafe_fn)]
621mod avx2 {
622    use std::arch::x86_64::*;
623
624    /// AVX2 unpack for 8-bit values (processes 32 bytes at a time)
625    #[target_feature(enable = "avx2")]
626    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
627        let chunks = count / 32;
628        let remainder = count % 32;
629
630        for chunk in 0..chunks {
631            let base = chunk * 32;
632            let in_ptr = input.as_ptr().add(base);
633
634            // Load 32 bytes (two 128-bit loads, then combine)
635            let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
636            let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
637
638            // Zero-extend u8 -> u32: 8 values per 256-bit register, covering all 32 bytes
639            let v0 = _mm256_cvtepu8_epi32(bytes_lo);
640            let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
641            let v2 = _mm256_cvtepu8_epi32(bytes_hi);
642            let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));
643
644            let out_ptr = output.as_mut_ptr().add(base);
645            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
646            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
647            _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
648            _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
649        }
650
651        // Handle remainder with scalar
652        let base = chunks * 32;
653        for i in 0..remainder {
654            output[base + i] = input[base + i] as u32;
655        }
656    }
657
658    /// AVX2 unpack for 16-bit values (processes 16 values at a time)
659    #[target_feature(enable = "avx2")]
660    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
661        let chunks = count / 16;
662        let remainder = count % 16;
663
664        for chunk in 0..chunks {
665            let base = chunk * 16;
666            let in_ptr = input.as_ptr().add(base * 2);
667
668            // Load 32 bytes (16 u16 values)
669            let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
670            let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
671
672            // Zero extend u16 -> u32
673            let v0 = _mm256_cvtepu16_epi32(vals_lo);
674            let v1 = _mm256_cvtepu16_epi32(vals_hi);
675
676            let out_ptr = output.as_mut_ptr().add(base);
677            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
678            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
679        }
680
681        // Handle remainder
682        let base = chunks * 16;
683        for i in 0..remainder {
684            let idx = (base + i) * 2;
685            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
686        }
687    }
688
689    /// AVX2 unpack for 32-bit values (fast copy, 8 values at a time)
690    #[target_feature(enable = "avx2")]
691    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
692        let chunks = count / 8;
693        let remainder = count % 8;
694
695        let in_ptr = input.as_ptr() as *const __m256i;
696        let out_ptr = output.as_mut_ptr() as *mut __m256i;
697
698        for chunk in 0..chunks {
699            let vals = _mm256_loadu_si256(in_ptr.add(chunk));
700            _mm256_storeu_si256(out_ptr.add(chunk), vals);
701        }
702
703        // Handle remainder
704        let base = chunks * 8;
705        for i in 0..remainder {
706            let idx = (base + i) * 4;
707            output[base + i] =
708                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
709        }
710    }
711
712    /// AVX2 add 1 to all values (8 values at a time)
713    #[target_feature(enable = "avx2")]
714    pub unsafe fn add_one(values: &mut [u32], count: usize) {
715        let ones = _mm256_set1_epi32(1);
716        let chunks = count / 8;
717        let remainder = count % 8;
718
719        for chunk in 0..chunks {
720            let base = chunk * 8;
721            let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
722            let v = _mm256_loadu_si256(ptr);
723            let result = _mm256_add_epi32(v, ones);
724            _mm256_storeu_si256(ptr, result);
725        }
726
727        let base = chunks * 8;
728        for i in 0..remainder {
729            values[base + i] += 1;
730        }
731    }
732
733    /// AVX2 prefix sum for 8 u32 values (Hillis-Steele)
734    /// Input:  [a, b, c, d, e, f, g, h]
735    /// Output: [a, a+b, a+b+c, ..., a+b+c+d+e+f+g+h]
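    /// e.g. [1, 2, 3, 4, 5, 6, 7, 8] -> [1, 3, 6, 10, 15, 21, 28, 36]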
736    #[inline]
737    #[target_feature(enable = "avx2")]
738    unsafe fn prefix_sum_8(v: __m256i) -> __m256i {
739        // Step 1: intra-lane shift by 1 element (4 bytes) and add
740        let s1 = _mm256_slli_si256(v, 4);
741        let r1 = _mm256_add_epi32(v, s1);
742
743        // Step 2: intra-lane shift by 2 elements (8 bytes) and add
744        let s2 = _mm256_slli_si256(r1, 8);
745        let r2 = _mm256_add_epi32(r1, s2);
746
747        // Step 3: propagate lower lane sum to upper lane
748        // Broadcast element 3 (lower lane sum) within each lane
749        let lo_sum = _mm256_shuffle_epi32(r2, 0xFF);
750        // Duplicate lane 0 to both lanes
751        let carry = _mm256_permute2x128_si256(lo_sum, lo_sum, 0x00);
752        // Zero carry for lower lane, keep for upper
753        let carry_hi = _mm256_blend_epi32::<0xF0>(_mm256_setzero_si256(), carry);
754        _mm256_add_epi32(r2, carry_hi)
755    }
756
757    /// AVX2 fused unpack 8-bit + delta decode (processes 8 values at a time)
758    #[target_feature(enable = "avx2")]
759    pub unsafe fn unpack_8bit_delta_decode(
760        input: &[u8],
761        output: &mut [u32],
762        first_value: u32,
763        count: usize,
764    ) {
765        output[0] = first_value;
766        if count <= 1 {
767            return;
768        }
769
770        let ones = _mm256_set1_epi32(1);
771        let mut carry = _mm256_set1_epi32(first_value as i32);
772        let broadcast_idx = _mm256_set1_epi32(7);
773
774        let full_groups = (count - 1) / 8;
775        let remainder = (count - 1) % 8;
776
777        for group in 0..full_groups {
778            let base = group * 8;
779
780            // Load 8 bytes and zero-extend to 8×u32
781            let bytes = _mm_loadl_epi64(input.as_ptr().add(base) as *const __m128i);
782            let d = _mm256_cvtepu8_epi32(bytes);
783
784            // Add 1 (since we store gap-1)
785            let gaps = _mm256_add_epi32(d, ones);
786
787            // Compute prefix sum within 8 elements
788            let prefix = prefix_sum_8(gaps);
789
790            // Add carry from previous group
791            let result = _mm256_add_epi32(prefix, carry);
792
793            // Store 8 results
794            _mm256_storeu_si256(output[base + 1..].as_mut_ptr() as *mut __m256i, result);
795
796            // Update carry: broadcast element 7 to all positions
797            carry = _mm256_permutevar8x32_epi32(result, broadcast_idx);
798        }
799
800        // Handle remainder with scalar
801        let base = full_groups * 8;
802        let mut scalar_carry = _mm256_extract_epi32::<0>(carry) as u32;
803        for j in 0..remainder {
804            scalar_carry = scalar_carry
805                .wrapping_add(input[base + j] as u32)
806                .wrapping_add(1);
807            output[base + j + 1] = scalar_carry;
808        }
809    }
810
811    /// AVX2 fused unpack 16-bit + delta decode (processes 8 values at a time)
812    #[target_feature(enable = "avx2")]
813    pub unsafe fn unpack_16bit_delta_decode(
814        input: &[u8],
815        output: &mut [u32],
816        first_value: u32,
817        count: usize,
818    ) {
819        output[0] = first_value;
820        if count <= 1 {
821            return;
822        }
823
824        let ones = _mm256_set1_epi32(1);
825        let mut carry = _mm256_set1_epi32(first_value as i32);
826        let broadcast_idx = _mm256_set1_epi32(7);
827
828        let full_groups = (count - 1) / 8;
829        let remainder = (count - 1) % 8;
830
831        for group in 0..full_groups {
832            let base = group * 8;
833            let in_ptr = input.as_ptr().add(base * 2);
834
835            // Load 16 bytes (8 u16 values) and zero-extend to 8×u32
836            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
837            let d = _mm256_cvtepu16_epi32(vals);
838
839            // Add 1 (since we store gap-1)
840            let gaps = _mm256_add_epi32(d, ones);
841
842            // Compute prefix sum within 8 elements
843            let prefix = prefix_sum_8(gaps);
844
845            // Add carry from previous group
846            let result = _mm256_add_epi32(prefix, carry);
847
848            // Store 8 results
849            _mm256_storeu_si256(output[base + 1..].as_mut_ptr() as *mut __m256i, result);
850
851            // Update carry: broadcast element 7 to all positions
852            carry = _mm256_permutevar8x32_epi32(result, broadcast_idx);
853        }
854
855        // Handle remainder with scalar
856        let base = full_groups * 8;
857        let mut scalar_carry = _mm256_extract_epi32::<0>(carry) as u32;
858        for j in 0..remainder {
859            let idx = (base + j) * 2;
860            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
861            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
862            output[base + j + 1] = scalar_carry;
863        }
864    }
865
866    /// Check if AVX2 is available at runtime
867    #[inline]
868    pub fn is_available() -> bool {
869        is_x86_feature_detected!("avx2")
870    }
871}
872
873// ============================================================================
874// Scalar fallback implementations
875// ============================================================================
876
877#[allow(dead_code)]
878mod scalar {
879    /// Scalar unpack for 8-bit values
880    #[inline]
881    pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
882        for i in 0..count {
883            output[i] = input[i] as u32;
884        }
885    }
886
887    /// Scalar unpack for 16-bit values
888    #[inline]
889    pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
890        for (i, out) in output.iter_mut().enumerate().take(count) {
891            let idx = i * 2;
892            *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
893        }
894    }
895
896    /// Scalar unpack for 32-bit values
897    #[inline]
898    pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
899        for (i, out) in output.iter_mut().enumerate().take(count) {
900            let idx = i * 4;
901            *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
902        }
903    }
904
905    /// Scalar delta decode
906    #[inline]
907    pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
908        if count == 0 {
909            return;
910        }
911
912        output[0] = first_doc_id;
913        let mut carry = first_doc_id;
914
915        for i in 0..count - 1 {
916            carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
917            output[i + 1] = carry;
918        }
919    }
920
921    /// Scalar add 1 to all values
922    #[inline]
923    pub fn add_one(values: &mut [u32], count: usize) {
924        for val in values.iter_mut().take(count) {
925            *val += 1;
926        }
927    }
928}
929
930// ============================================================================
931// Public dispatch functions that select SIMD or scalar at runtime
932// ============================================================================
933
934/// Unpack 8-bit packed values to u32 with SIMD acceleration
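///
/// # Example
///
/// An illustrative sketch (the 8-bit layout is one byte per value):
/// ```ignore
/// let mut out = [0u32; 3];
/// unpack_8bit(&[5, 0, 250], &mut out, 3);
/// assert_eq!(out, [5, 0, 250]);
/// ```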
935#[inline]
936pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
937    #[cfg(target_arch = "aarch64")]
938    {
939        if neon::is_available() {
940            unsafe {
941                neon::unpack_8bit(input, output, count);
942            }
943            return;
944        }
945    }
946
947    #[cfg(target_arch = "x86_64")]
948    {
949        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
950        if avx2::is_available() {
951            unsafe {
952                avx2::unpack_8bit(input, output, count);
953            }
954            return;
955        }
956        if sse::is_available() {
957            unsafe {
958                sse::unpack_8bit(input, output, count);
959            }
960            return;
961        }
962    }
963
964    scalar::unpack_8bit(input, output, count);
965}
966
967/// Unpack 16-bit packed values to u32 with SIMD acceleration
968#[inline]
969pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
970    #[cfg(target_arch = "aarch64")]
971    {
972        if neon::is_available() {
973            unsafe {
974                neon::unpack_16bit(input, output, count);
975            }
976            return;
977        }
978    }
979
980    #[cfg(target_arch = "x86_64")]
981    {
982        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
983        if avx2::is_available() {
984            unsafe {
985                avx2::unpack_16bit(input, output, count);
986            }
987            return;
988        }
989        if sse::is_available() {
990            unsafe {
991                sse::unpack_16bit(input, output, count);
992            }
993            return;
994        }
995    }
996
997    scalar::unpack_16bit(input, output, count);
998}
999
1000/// Unpack 32-bit packed values to u32 with SIMD acceleration
1001#[inline]
1002pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
1003    #[cfg(target_arch = "aarch64")]
1004    {
1005        if neon::is_available() {
1006            unsafe {
1007                neon::unpack_32bit(input, output, count);
1008            }
1009        }
1010    }
1011
1012    #[cfg(target_arch = "x86_64")]
1013    {
1014        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
1015        if avx2::is_available() {
1016            unsafe {
1017                avx2::unpack_32bit(input, output, count);
1018            }
1019        } else {
1020            // SSE2 is always available on x86_64
1021            unsafe {
1022                sse::unpack_32bit(input, output, count);
1023            }
1024        }
1025    }
1026
1027    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
1028    {
1029        scalar::unpack_32bit(input, output, count);
1030    }
1031}
1032
1033/// Delta decode with SIMD acceleration
1034///
1035/// Converts delta-encoded values to absolute values.
1036/// Input: deltas[i] = value[i+1] - value[i] - 1 (gap minus one)
1037/// Output: absolute values starting from first_value
1038#[inline]
1039pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
1040    #[cfg(target_arch = "aarch64")]
1041    {
1042        if neon::is_available() {
1043            unsafe {
1044                neon::delta_decode(output, deltas, first_value, count);
1045            }
1046            return;
1047        }
1048    }
1049
1050    #[cfg(target_arch = "x86_64")]
1051    {
1052        if sse::is_available() {
1053            unsafe {
1054                sse::delta_decode(output, deltas, first_value, count);
1055            }
1056            return;
1057        }
1058    }
1059
1060    scalar::delta_decode(output, deltas, first_value, count);
1061}
1062
1063/// Add 1 to all values with SIMD acceleration
1064///
1065/// Used for TF decoding where values are stored as (tf - 1)
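///
/// # Example
///
/// An illustrative sketch:
/// ```ignore
/// let mut tfs = [0u32, 4, 9]; // stored as tf - 1
/// add_one(&mut tfs, 3);
/// assert_eq!(tfs, [1, 5, 10]);
/// ```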
1066#[inline]
1067pub fn add_one(values: &mut [u32], count: usize) {
1068    #[cfg(target_arch = "aarch64")]
1069    {
1070        if neon::is_available() {
1071            unsafe {
1072                neon::add_one(values, count);
1073            }
1074        }
1075    }
1076
1077    #[cfg(target_arch = "x86_64")]
1078    {
1079        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
1080        if avx2::is_available() {
1081            unsafe {
1082                avx2::add_one(values, count);
1083            }
1084        } else {
1085            // SSE2 is always available on x86_64
1086            unsafe {
1087                sse::add_one(values, count);
1088            }
1089        }
1090    }
1091
1092    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
1093    {
1094        scalar::add_one(values, count);
1095    }
1096}
1097
1098/// Compute the number of bits needed to represent a value
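///
/// # Example
///
/// ```ignore
/// assert_eq!(bits_needed(0), 0);
/// assert_eq!(bits_needed(255), 8);
/// assert_eq!(bits_needed(256), 9);
/// ```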
1099#[inline]
1100pub fn bits_needed(val: u32) -> u8 {
1101    if val == 0 {
1102        0
1103    } else {
1104        32 - val.leading_zeros() as u8
1105    }
1106}
1107
1108// ============================================================================
1109// Rounded bitpacking for truly vectorized encoding/decoding
1110// ============================================================================
1111//
1112// Instead of using arbitrary bit widths (1-32), we round up to SIMD-friendly
1113// widths: 0, 8, 16, or 32 bits. This trades ~10-20% more space for much faster
1114// decoding since we can use direct SIMD widening instructions (pmovzx) without
1115// any bit-shifting or masking.
1116//
1117// Bit width mapping:
1118//   0      -> 0  (all zeros)
1119//   1-8    -> 8  (u8)
1120//   9-16   -> 16 (u16)
1121//   17-32  -> 32 (u32)
1122
1123/// Rounded bit width type for SIMD-friendly encoding
1124#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1125#[repr(u8)]
1126pub enum RoundedBitWidth {
1127    Zero = 0,
1128    Bits8 = 8,
1129    Bits16 = 16,
1130    Bits32 = 32,
1131}
1132
1133impl RoundedBitWidth {
1134    /// Round an exact bit width to the nearest SIMD-friendly width
1135    #[inline]
1136    pub fn from_exact(bits: u8) -> Self {
1137        match bits {
1138            0 => RoundedBitWidth::Zero,
1139            1..=8 => RoundedBitWidth::Bits8,
1140            9..=16 => RoundedBitWidth::Bits16,
1141            _ => RoundedBitWidth::Bits32,
1142        }
1143    }
1144
1145    /// Convert from stored u8 value (must be 0, 8, 16, or 32)
1146    #[inline]
1147    pub fn from_u8(bits: u8) -> Self {
1148        match bits {
1149            0 => RoundedBitWidth::Zero,
1150            8 => RoundedBitWidth::Bits8,
1151            16 => RoundedBitWidth::Bits16,
1152            32 => RoundedBitWidth::Bits32,
1153            _ => RoundedBitWidth::Bits32, // Fallback for invalid values
1154        }
1155    }
1156
1157    /// Get the byte size per value
1158    #[inline]
1159    pub fn bytes_per_value(self) -> usize {
1160        match self {
1161            RoundedBitWidth::Zero => 0,
1162            RoundedBitWidth::Bits8 => 1,
1163            RoundedBitWidth::Bits16 => 2,
1164            RoundedBitWidth::Bits32 => 4,
1165        }
1166    }
1167
1168    /// Get the raw bit width value
1169    #[inline]
1170    pub fn as_u8(self) -> u8 {
1171        self as u8
1172    }
1173}
1174
1175/// Round a bit width to the nearest SIMD-friendly width (0, 8, 16, or 32)
1176#[inline]
1177pub fn round_bit_width(bits: u8) -> u8 {
1178    RoundedBitWidth::from_exact(bits).as_u8()
1179}
1180
1181/// Pack values using rounded bit width (SIMD-friendly)
1182///
1183/// This is much simpler than arbitrary bitpacking since values are byte-aligned.
1184/// Returns the number of bytes written.
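///
/// # Example
///
/// An illustrative round-trip sketch (widths and sizes follow the mapping above):
/// ```ignore
/// let values = [3u32, 300, 70_000];
/// let width = RoundedBitWidth::from_exact(bits_needed(70_000)); // 17 bits -> Bits32
/// let mut packed = vec![0u8; values.len() * width.bytes_per_value()];
/// let written = pack_rounded(&values, width, &mut packed);
/// assert_eq!(written, 12);
///
/// let mut unpacked = [0u32; 3];
/// unpack_rounded(&packed, width, &mut unpacked, 3);
/// assert_eq!(unpacked, values);
/// ```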
1185#[inline]
1186pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
1187    let count = values.len();
1188    match bit_width {
1189        RoundedBitWidth::Zero => 0,
1190        RoundedBitWidth::Bits8 => {
1191            for (i, &v) in values.iter().enumerate() {
1192                output[i] = v as u8;
1193            }
1194            count
1195        }
1196        RoundedBitWidth::Bits16 => {
1197            for (i, &v) in values.iter().enumerate() {
1198                let bytes = (v as u16).to_le_bytes();
1199                output[i * 2] = bytes[0];
1200                output[i * 2 + 1] = bytes[1];
1201            }
1202            count * 2
1203        }
1204        RoundedBitWidth::Bits32 => {
1205            for (i, &v) in values.iter().enumerate() {
1206                let bytes = v.to_le_bytes();
1207                output[i * 4] = bytes[0];
1208                output[i * 4 + 1] = bytes[1];
1209                output[i * 4 + 2] = bytes[2];
1210                output[i * 4 + 3] = bytes[3];
1211            }
1212            count * 4
1213        }
1214    }
1215}
1216
1217/// Unpack values using rounded bit width with SIMD acceleration
1218///
1219/// This is the fast path - no bit manipulation needed, just widening.
1220#[inline]
1221pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
1222    match bit_width {
1223        RoundedBitWidth::Zero => {
1224            for out in output.iter_mut().take(count) {
1225                *out = 0;
1226            }
1227        }
1228        RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
1229        RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
1230        RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
1231    }
1232}
1233
1234/// Fused unpack + delta decode using rounded bit width
1235///
1236/// Combines unpacking and prefix sum in a single pass for better cache utilization.
1237#[inline]
1238pub fn unpack_rounded_delta_decode(
1239    input: &[u8],
1240    bit_width: RoundedBitWidth,
1241    output: &mut [u32],
1242    first_value: u32,
1243    count: usize,
1244) {
1245    match bit_width {
1246        RoundedBitWidth::Zero => {
1247            // All deltas are 0, meaning gaps of 1
1248            let mut val = first_value;
1249            for out in output.iter_mut().take(count) {
1250                *out = val;
1251                val = val.wrapping_add(1);
1252            }
1253        }
1254        RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
1255        RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
1256        RoundedBitWidth::Bits32 => {
1257            // Unpack count-1 deltas from input, then prefix sum to absolute values
1258            if count > 0 {
1259                output[0] = first_value;
1260                let mut carry = first_value;
1261                for i in 0..count - 1 {
1262                    let idx = i * 4;
1263                    let delta = u32::from_le_bytes([
1264                        input[idx],
1265                        input[idx + 1],
1266                        input[idx + 2],
1267                        input[idx + 3],
1268                    ]);
1269                    carry = carry.wrapping_add(delta).wrapping_add(1);
1270                    output[i + 1] = carry;
1271                }
1272            }
1273        }
1274    }
1275}
1276
1277// ============================================================================
1278// Fused operations for better cache utilization
1279// ============================================================================
1280
1281/// Fused unpack 8-bit + delta decode in a single pass
1282///
1283/// This avoids writing the intermediate unpacked values to memory,
1284/// improving cache utilization for large blocks.
1285#[inline]
1286pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1287    if count == 0 {
1288        return;
1289    }
1290
1291    output[0] = first_value;
1292    if count == 1 {
1293        return;
1294    }
1295
1296    #[cfg(target_arch = "aarch64")]
1297    {
1298        if neon::is_available() {
1299            unsafe {
1300                neon::unpack_8bit_delta_decode(input, output, first_value, count);
1301            }
1302            return;
1303        }
1304    }
1305
1306    #[cfg(target_arch = "x86_64")]
1307    {
1308        if avx2::is_available() {
1309            unsafe {
1310                avx2::unpack_8bit_delta_decode(input, output, first_value, count);
1311            }
1312            return;
1313        }
1314        if sse::is_available() {
1315            unsafe {
1316                sse::unpack_8bit_delta_decode(input, output, first_value, count);
1317            }
1318            return;
1319        }
1320    }
1321
1322    // Scalar fallback
1323    let mut carry = first_value;
1324    for i in 0..count - 1 {
1325        carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
1326        output[i + 1] = carry;
1327    }
1328}
1329
1330/// Fused unpack 16-bit + delta decode in a single pass
1331#[inline]
1332pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1333    if count == 0 {
1334        return;
1335    }
1336
1337    output[0] = first_value;
1338    if count == 1 {
1339        return;
1340    }
1341
1342    #[cfg(target_arch = "aarch64")]
1343    {
1344        if neon::is_available() {
1345            unsafe {
1346                neon::unpack_16bit_delta_decode(input, output, first_value, count);
1347            }
1348            return;
1349        }
1350    }
1351
1352    #[cfg(target_arch = "x86_64")]
1353    {
1354        if avx2::is_available() {
1355            unsafe {
1356                avx2::unpack_16bit_delta_decode(input, output, first_value, count);
1357            }
1358            return;
1359        }
1360        if sse::is_available() {
1361            unsafe {
1362                sse::unpack_16bit_delta_decode(input, output, first_value, count);
1363            }
1364            return;
1365        }
1366    }
1367
1368    // Scalar fallback
1369    let mut carry = first_value;
1370    for i in 0..count - 1 {
1371        let idx = i * 2;
1372        let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
1373        carry = carry.wrapping_add(delta).wrapping_add(1);
1374        output[i + 1] = carry;
1375    }
1376}
1377
1378/// Fused unpack + delta decode for arbitrary bit widths
1379///
1380/// Combines unpacking and prefix sum in a single pass, avoiding intermediate buffer.
1381/// Uses SIMD-accelerated paths for 8/16-bit widths, scalar for others.
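///
/// # Example
///
/// An illustrative sketch using the 0-bit fast path (consecutive doc IDs):
/// ```ignore
/// let mut out = [0u32; 4];
/// unpack_delta_decode(&[], 0, &mut out, 100, 4);
/// assert_eq!(out, [100, 101, 102, 103]);
/// ```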
1382#[inline]
1383pub fn unpack_delta_decode(
1384    input: &[u8],
1385    bit_width: u8,
1386    output: &mut [u32],
1387    first_value: u32,
1388    count: usize,
1389) {
1390    if count == 0 {
1391        return;
1392    }
1393
1394    output[0] = first_value;
1395    if count == 1 {
1396        return;
1397    }
1398
1399    // Fast paths for SIMD-friendly bit widths
1400    match bit_width {
1401        0 => {
1402            // All zeros = consecutive doc IDs (gap of 1)
1403            let mut val = first_value;
1404            for item in output.iter_mut().take(count).skip(1) {
1405                val = val.wrapping_add(1);
1406                *item = val;
1407            }
1408        }
1409        8 => unpack_8bit_delta_decode(input, output, first_value, count),
1410        16 => unpack_16bit_delta_decode(input, output, first_value, count),
1411        32 => {
1412            // 32-bit: unpack inline and delta decode
1413            let mut carry = first_value;
1414            for i in 0..count - 1 {
1415                let idx = i * 4;
1416                let delta = u32::from_le_bytes([
1417                    input[idx],
1418                    input[idx + 1],
1419                    input[idx + 2],
1420                    input[idx + 3],
1421                ]);
1422                carry = carry.wrapping_add(delta).wrapping_add(1);
1423                output[i + 1] = carry;
1424            }
1425        }
1426        _ => {
1427            // Generic bit width: fused unpack + delta decode
1428            let mask = (1u64 << bit_width) - 1;
1429            let bit_width_usize = bit_width as usize;
1430            let mut bit_pos = 0usize;
1431            let input_ptr = input.as_ptr();
1432            let mut carry = first_value;
1433
1434            for i in 0..count - 1 {
1435                let byte_idx = bit_pos >> 3;
1436                let bit_offset = bit_pos & 7;
1437
1438                // SAFETY: Caller guarantees input has enough data
1439                let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
1440                let delta = ((word >> bit_offset) & mask) as u32;
1441
1442                carry = carry.wrapping_add(delta).wrapping_add(1);
1443                output[i + 1] = carry;
1444                bit_pos += bit_width_usize;
1445            }
1446        }
1447    }
1448}
1449
1450// ============================================================================
1451// Sparse Vector SIMD Functions
1452// ============================================================================
1453
1454/// Dequantize UInt8 weights to f32 with SIMD acceleration
1455///
1456/// Computes: output[i] = input[i] as f32 * scale + min_val
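///
/// # Example
///
/// An illustrative sketch (all values below are exactly representable in f32):
/// ```ignore
/// let mut out = [0.0f32; 3];
/// dequantize_uint8(&[0, 128, 255], &mut out, 0.5, 1.0, 3);
/// assert_eq!(out, [1.0, 65.0, 128.5]);
/// ```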
1457#[inline]
1458pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
1459    #[cfg(target_arch = "aarch64")]
1460    {
1461        if neon::is_available() {
1462            unsafe {
1463                dequantize_uint8_neon(input, output, scale, min_val, count);
1464            }
1465            return;
1466        }
1467    }
1468
1469    #[cfg(target_arch = "x86_64")]
1470    {
1471        if sse::is_available() {
1472            unsafe {
1473                dequantize_uint8_sse(input, output, scale, min_val, count);
1474            }
1475            return;
1476        }
1477    }
1478
1479    // Scalar fallback
1480    for i in 0..count {
1481        output[i] = input[i] as f32 * scale + min_val;
1482    }
1483}
1484
1485#[cfg(target_arch = "aarch64")]
1486#[target_feature(enable = "neon")]
1487#[allow(unsafe_op_in_unsafe_fn)]
1488unsafe fn dequantize_uint8_neon(
1489    input: &[u8],
1490    output: &mut [f32],
1491    scale: f32,
1492    min_val: f32,
1493    count: usize,
1494) {
1495    use std::arch::aarch64::*;
1496
1497    let scale_v = vdupq_n_f32(scale);
1498    let min_v = vdupq_n_f32(min_val);
1499
1500    let chunks = count / 16;
1501    let remainder = count % 16;
1502
1503    for chunk in 0..chunks {
1504        let base = chunk * 16;
1505        let in_ptr = input.as_ptr().add(base);
1506
1507        // Load 16 bytes
1508        let bytes = vld1q_u8(in_ptr);
1509
1510        // Widen u8 -> u16 -> u32 -> f32
1511        let low8 = vget_low_u8(bytes);
1512        let high8 = vget_high_u8(bytes);
1513
1514        let low16 = vmovl_u8(low8);
1515        let high16 = vmovl_u8(high8);
1516
1517        // Process 4 values at a time
1518        let u32_0 = vmovl_u16(vget_low_u16(low16));
1519        let u32_1 = vmovl_u16(vget_high_u16(low16));
1520        let u32_2 = vmovl_u16(vget_low_u16(high16));
1521        let u32_3 = vmovl_u16(vget_high_u16(high16));
1522
1523        // Convert to f32 and apply scale + min_val
1524        let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
1525        let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
1526        let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
1527        let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);
1528
1529        let out_ptr = output.as_mut_ptr().add(base);
1530        vst1q_f32(out_ptr, f32_0);
1531        vst1q_f32(out_ptr.add(4), f32_1);
1532        vst1q_f32(out_ptr.add(8), f32_2);
1533        vst1q_f32(out_ptr.add(12), f32_3);
1534    }
1535
1536    // Handle remainder
1537    let base = chunks * 16;
1538    for i in 0..remainder {
1539        output[base + i] = input[base + i] as f32 * scale + min_val;
1540    }
1541}
1542
1543#[cfg(target_arch = "x86_64")]
1544#[target_feature(enable = "sse2", enable = "sse4.1")]
1545#[allow(unsafe_op_in_unsafe_fn)]
1546unsafe fn dequantize_uint8_sse(
1547    input: &[u8],
1548    output: &mut [f32],
1549    scale: f32,
1550    min_val: f32,
1551    count: usize,
1552) {
1553    use std::arch::x86_64::*;
1554
1555    let scale_v = _mm_set1_ps(scale);
1556    let min_v = _mm_set1_ps(min_val);
1557
1558    let chunks = count / 4;
1559    let remainder = count % 4;
1560
1561    for chunk in 0..chunks {
1562        let base = chunk * 4;
1563
1564        // Load 4 bytes as a single i32 and zero-extend u8→u32 via SSE4.1
1565        let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
1566            input.as_ptr().add(base) as *const i32
1567        ));
1568        let ints = _mm_cvtepu8_epi32(bytes);
1569        let floats = _mm_cvtepi32_ps(ints);
1570
1571        // Apply scale and min_val: result = floats * scale + min_val
1572        let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);
1573
1574        _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
1575    }
1576
1577    // Handle remainder
1578    let base = chunks * 4;
1579    for i in 0..remainder {
1580        output[base + i] = input[base + i] as f32 * scale + min_val;
1581    }
1582}
1583
1584/// Compute dot product of two f32 arrays with SIMD acceleration
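///
/// # Example
///
/// ```ignore
/// let a = [1.0f32, 2.0, 3.0];
/// let b = [4.0f32, 5.0, 6.0];
/// assert_eq!(dot_product_f32(&a, &b, 3), 32.0);
/// ```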
1585#[inline]
1586pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
1587    #[cfg(target_arch = "aarch64")]
1588    {
1589        if neon::is_available() {
1590            return unsafe { dot_product_f32_neon(a, b, count) };
1591        }
1592    }
1593
1594    #[cfg(target_arch = "x86_64")]
1595    {
1596        if is_x86_feature_detected!("avx512f") {
1597            return unsafe { dot_product_f32_avx512(a, b, count) };
1598        }
1599        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1600            return unsafe { dot_product_f32_avx2(a, b, count) };
1601        }
1602        if sse::is_available() {
1603            return unsafe { dot_product_f32_sse(a, b, count) };
1604        }
1605    }
1606
1607    // Scalar fallback
1608    let mut sum = 0.0f32;
1609    for i in 0..count {
1610        sum += a[i] * b[i];
1611    }
1612    sum
1613}
1614
1615#[cfg(target_arch = "aarch64")]
1616#[target_feature(enable = "neon")]
1617#[allow(unsafe_op_in_unsafe_fn)]
1618unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
1619    use std::arch::aarch64::*;
1620
1621    let chunks16 = count / 16;
1622    let remainder = count % 16;
1623
1624    let mut acc0 = vdupq_n_f32(0.0);
1625    let mut acc1 = vdupq_n_f32(0.0);
1626    let mut acc2 = vdupq_n_f32(0.0);
1627    let mut acc3 = vdupq_n_f32(0.0);
1628
1629    for c in 0..chunks16 {
1630        let base = c * 16;
1631        acc0 = vfmaq_f32(
1632            acc0,
1633            vld1q_f32(a.as_ptr().add(base)),
1634            vld1q_f32(b.as_ptr().add(base)),
1635        );
1636        acc1 = vfmaq_f32(
1637            acc1,
1638            vld1q_f32(a.as_ptr().add(base + 4)),
1639            vld1q_f32(b.as_ptr().add(base + 4)),
1640        );
1641        acc2 = vfmaq_f32(
1642            acc2,
1643            vld1q_f32(a.as_ptr().add(base + 8)),
1644            vld1q_f32(b.as_ptr().add(base + 8)),
1645        );
1646        acc3 = vfmaq_f32(
1647            acc3,
1648            vld1q_f32(a.as_ptr().add(base + 12)),
1649            vld1q_f32(b.as_ptr().add(base + 12)),
1650        );
1651    }
1652
1653    let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
1654    let mut sum = vaddvq_f32(acc);
1655
1656    let base = chunks16 * 16;
1657    for i in 0..remainder {
1658        sum += a[base + i] * b[base + i];
1659    }
1660
1661    sum
1662}
1663
1664#[cfg(target_arch = "x86_64")]
1665#[target_feature(enable = "avx2", enable = "fma")]
1666#[allow(unsafe_op_in_unsafe_fn)]
1667unsafe fn dot_product_f32_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
1668    use std::arch::x86_64::*;
1669
1670    let chunks32 = count / 32;
1671    let remainder = count % 32;
1672
1673    let mut acc0 = _mm256_setzero_ps();
1674    let mut acc1 = _mm256_setzero_ps();
1675    let mut acc2 = _mm256_setzero_ps();
1676    let mut acc3 = _mm256_setzero_ps();
1677
1678    for c in 0..chunks32 {
1679        let base = c * 32;
1680        acc0 = _mm256_fmadd_ps(
1681            _mm256_loadu_ps(a.as_ptr().add(base)),
1682            _mm256_loadu_ps(b.as_ptr().add(base)),
1683            acc0,
1684        );
1685        acc1 = _mm256_fmadd_ps(
1686            _mm256_loadu_ps(a.as_ptr().add(base + 8)),
1687            _mm256_loadu_ps(b.as_ptr().add(base + 8)),
1688            acc1,
1689        );
1690        acc2 = _mm256_fmadd_ps(
1691            _mm256_loadu_ps(a.as_ptr().add(base + 16)),
1692            _mm256_loadu_ps(b.as_ptr().add(base + 16)),
1693            acc2,
1694        );
1695        acc3 = _mm256_fmadd_ps(
1696            _mm256_loadu_ps(a.as_ptr().add(base + 24)),
1697            _mm256_loadu_ps(b.as_ptr().add(base + 24)),
1698            acc3,
1699        );
1700    }
1701
1702    let acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
1703
1704    // Horizontal sum: 256-bit → 128-bit → scalar
1705    let hi = _mm256_extractf128_ps(acc, 1);
1706    let lo = _mm256_castps256_ps128(acc);
1707    let sum128 = _mm_add_ps(lo, hi);
1708    let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
1709    let sums = _mm_add_ps(sum128, shuf);
1710    let shuf2 = _mm_movehl_ps(sums, sums);
1711    let final_sum = _mm_add_ss(sums, shuf2);
1712
1713    let mut sum = _mm_cvtss_f32(final_sum);
1714
1715    let base = chunks32 * 32;
1716    for i in 0..remainder {
1717        sum += a[base + i] * b[base + i];
1718    }
1719
1720    sum
1721}
1722
1723#[cfg(target_arch = "x86_64")]
1724#[target_feature(enable = "sse")]
1725#[allow(unsafe_op_in_unsafe_fn)]
1726unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
1727    use std::arch::x86_64::*;
1728
1729    let chunks = count / 4;
1730    let remainder = count % 4;
1731
1732    let mut acc = _mm_setzero_ps();
1733
1734    for chunk in 0..chunks {
1735        let base = chunk * 4;
1736        let va = _mm_loadu_ps(a.as_ptr().add(base));
1737        let vb = _mm_loadu_ps(b.as_ptr().add(base));
1738        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
1739    }
1740
1741    // Horizontal sum: [a, b, c, d] -> a + b + c + d
1742    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); // [b, a, d, c]
1743    let sums = _mm_add_ps(acc, shuf); // [a+b, a+b, c+d, c+d]
1744    let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, c+d, ?, ?]
1745    let final_sum = _mm_add_ss(sums, shuf2); // [a+b+c+d, ?, ?, ?]
1746
1747    let mut sum = _mm_cvtss_f32(final_sum);
1748
1749    // Handle remainder
1750    let base = chunks * 4;
1751    for i in 0..remainder {
1752        sum += a[base + i] * b[base + i];
1753    }
1754
1755    sum
1756}
1757
1758#[cfg(target_arch = "x86_64")]
1759#[target_feature(enable = "avx512f")]
1760#[allow(unsafe_op_in_unsafe_fn)]
1761unsafe fn dot_product_f32_avx512(a: &[f32], b: &[f32], count: usize) -> f32 {
1762    use std::arch::x86_64::*;
1763
1764    let chunks64 = count / 64;
1765    let remainder = count % 64;
1766
1767    let mut acc0 = _mm512_setzero_ps();
1768    let mut acc1 = _mm512_setzero_ps();
1769    let mut acc2 = _mm512_setzero_ps();
1770    let mut acc3 = _mm512_setzero_ps();
1771
1772    for c in 0..chunks64 {
1773        let base = c * 64;
1774        acc0 = _mm512_fmadd_ps(
1775            _mm512_loadu_ps(a.as_ptr().add(base)),
1776            _mm512_loadu_ps(b.as_ptr().add(base)),
1777            acc0,
1778        );
1779        acc1 = _mm512_fmadd_ps(
1780            _mm512_loadu_ps(a.as_ptr().add(base + 16)),
1781            _mm512_loadu_ps(b.as_ptr().add(base + 16)),
1782            acc1,
1783        );
1784        acc2 = _mm512_fmadd_ps(
1785            _mm512_loadu_ps(a.as_ptr().add(base + 32)),
1786            _mm512_loadu_ps(b.as_ptr().add(base + 32)),
1787            acc2,
1788        );
1789        acc3 = _mm512_fmadd_ps(
1790            _mm512_loadu_ps(a.as_ptr().add(base + 48)),
1791            _mm512_loadu_ps(b.as_ptr().add(base + 48)),
1792            acc3,
1793        );
1794    }
1795
1796    let acc = _mm512_add_ps(_mm512_add_ps(acc0, acc1), _mm512_add_ps(acc2, acc3));
1797    let mut sum = _mm512_reduce_add_ps(acc);
1798
1799    let base = chunks64 * 64;
1800    for i in 0..remainder {
1801        sum += a[base + i] * b[base + i];
1802    }
1803
1804    sum
1805}
1806
1807#[cfg(target_arch = "x86_64")]
1808#[target_feature(enable = "avx512f")]
1809#[allow(unsafe_op_in_unsafe_fn)]
1810unsafe fn fused_dot_norm_avx512(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1811    use std::arch::x86_64::*;
1812
1813    let chunks64 = count / 64;
1814    let remainder = count % 64;
1815
1816    let mut d0 = _mm512_setzero_ps();
1817    let mut d1 = _mm512_setzero_ps();
1818    let mut d2 = _mm512_setzero_ps();
1819    let mut d3 = _mm512_setzero_ps();
1820    let mut n0 = _mm512_setzero_ps();
1821    let mut n1 = _mm512_setzero_ps();
1822    let mut n2 = _mm512_setzero_ps();
1823    let mut n3 = _mm512_setzero_ps();
1824
1825    for c in 0..chunks64 {
1826        let base = c * 64;
1827        let vb0 = _mm512_loadu_ps(b.as_ptr().add(base));
1828        d0 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base)), vb0, d0);
1829        n0 = _mm512_fmadd_ps(vb0, vb0, n0);
1830        let vb1 = _mm512_loadu_ps(b.as_ptr().add(base + 16));
1831        d1 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 16)), vb1, d1);
1832        n1 = _mm512_fmadd_ps(vb1, vb1, n1);
1833        let vb2 = _mm512_loadu_ps(b.as_ptr().add(base + 32));
1834        d2 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 32)), vb2, d2);
1835        n2 = _mm512_fmadd_ps(vb2, vb2, n2);
1836        let vb3 = _mm512_loadu_ps(b.as_ptr().add(base + 48));
1837        d3 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 48)), vb3, d3);
1838        n3 = _mm512_fmadd_ps(vb3, vb3, n3);
1839    }
1840
1841    let acc_dot = _mm512_add_ps(_mm512_add_ps(d0, d1), _mm512_add_ps(d2, d3));
1842    let acc_norm = _mm512_add_ps(_mm512_add_ps(n0, n1), _mm512_add_ps(n2, n3));
1843    let mut dot = _mm512_reduce_add_ps(acc_dot);
1844    let mut norm = _mm512_reduce_add_ps(acc_norm);
1845
1846    let base = chunks64 * 64;
1847    for i in 0..remainder {
1848        dot += a[base + i] * b[base + i];
1849        norm += b[base + i] * b[base + i];
1850    }
1851
1852    (dot, norm)
1853}
1854
1855/// Find maximum value in f32 array with SIMD acceleration
1856#[inline]
1857pub fn max_f32(values: &[f32], count: usize) -> f32 {
1858    if count == 0 {
1859        return f32::NEG_INFINITY;
1860    }
1861
1862    #[cfg(target_arch = "aarch64")]
1863    {
1864        if neon::is_available() {
1865            return unsafe { max_f32_neon(values, count) };
1866        }
1867    }
1868
1869    #[cfg(target_arch = "x86_64")]
1870    {
1871        if sse::is_available() {
1872            return unsafe { max_f32_sse(values, count) };
1873        }
1874    }
1875
1876    // Scalar fallback
1877    values[..count]
1878        .iter()
1879        .cloned()
1880        .fold(f32::NEG_INFINITY, f32::max)
1881}
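
// Illustrative check (a minimal sketch): `max_f32` over a slice whose length is not
// a multiple of the SIMD width, plus the empty-input convention.
#[cfg(test)]
mod max_f32_example {
    use super::max_f32;

    #[test]
    fn finds_maximum_including_remainder_tail() {
        let values = [-3.5f32, 2.0, 7.25, -1.0, 7.5, 0.0, 6.875];
        assert_eq!(max_f32(&values, values.len()), 7.5);

        // count == 0 returns the identity element for max.
        assert_eq!(max_f32(&values, 0), f32::NEG_INFINITY);
    }
}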
1882
1883#[cfg(target_arch = "aarch64")]
1884#[target_feature(enable = "neon")]
1885#[allow(unsafe_op_in_unsafe_fn)]
1886unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
1887    use std::arch::aarch64::*;
1888
1889    let chunks = count / 4;
1890    let remainder = count % 4;
1891
1892    let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);
1893
1894    for chunk in 0..chunks {
1895        let base = chunk * 4;
1896        let v = vld1q_f32(values.as_ptr().add(base));
1897        max_v = vmaxq_f32(max_v, v);
1898    }
1899
1900    // Horizontal max
1901    let mut max_val = vmaxvq_f32(max_v);
1902
1903    // Handle remainder
1904    let base = chunks * 4;
1905    for i in 0..remainder {
1906        max_val = max_val.max(values[base + i]);
1907    }
1908
1909    max_val
1910}
1911
1912#[cfg(target_arch = "x86_64")]
1913#[target_feature(enable = "sse")]
1914#[allow(unsafe_op_in_unsafe_fn)]
1915unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
1916    use std::arch::x86_64::*;
1917
1918    let chunks = count / 4;
1919    let remainder = count % 4;
1920
1921    let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);
1922
1923    for chunk in 0..chunks {
1924        let base = chunk * 4;
1925        let v = _mm_loadu_ps(values.as_ptr().add(base));
1926        max_v = _mm_max_ps(max_v, v);
1927    }
1928
1929    // Horizontal max: [a, b, c, d] -> max(a, b, c, d)
1930    let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01); // [b, a, d, c]
1931    let max1 = _mm_max_ps(max_v, shuf); // [max(a,b), max(a,b), max(c,d), max(c,d)]
1932    let shuf2 = _mm_movehl_ps(max1, max1); // [max(c,d), max(c,d), ?, ?]
1933    let final_max = _mm_max_ss(max1, shuf2); // [max(a,b,c,d), ?, ?, ?]
1934
1935    let mut max_val = _mm_cvtss_f32(final_max);
1936
1937    // Handle remainder
1938    let base = chunks * 4;
1939    for i in 0..remainder {
1940        max_val = max_val.max(values[base + i]);
1941    }
1942
1943    max_val
1944}
1945
1946// ============================================================================
1947// Batched Cosine Similarity for Dense Vector Search
1948// ============================================================================
1949
1950/// Fused dot-product + self-norm in a single pass (SIMD accelerated).
1951///
1952/// Returns (dot(a, b), dot(b, b)), i.e. the dot product a·b together with the squared norm ||b||².
1953/// Loads `b` only once (halves memory bandwidth vs two separate dot products).
1954#[inline]
1955fn fused_dot_norm(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1956    #[cfg(target_arch = "aarch64")]
1957    {
1958        if neon::is_available() {
1959            return unsafe { fused_dot_norm_neon(a, b, count) };
1960        }
1961    }
1962
1963    #[cfg(target_arch = "x86_64")]
1964    {
1965        if is_x86_feature_detected!("avx512f") {
1966            return unsafe { fused_dot_norm_avx512(a, b, count) };
1967        }
1968        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1969            return unsafe { fused_dot_norm_avx2(a, b, count) };
1970        }
1971        if sse::is_available() {
1972            return unsafe { fused_dot_norm_sse(a, b, count) };
1973        }
1974    }
1975
1976    // Scalar fallback
1977    let mut dot = 0.0f32;
1978    let mut norm_b = 0.0f32;
1979    for i in 0..count {
1980        dot += a[i] * b[i];
1981        norm_b += b[i] * b[i];
1982    }
1983    (dot, norm_b)
1984}
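
// Illustrative check (a minimal sketch): the fused kernel should return the same
// pair as two separate scalar reductions, dot(a, b) and dot(b, b).
#[cfg(test)]
mod fused_dot_norm_example {
    use super::fused_dot_norm;

    #[test]
    fn matches_two_scalar_reductions() {
        let a: Vec<f32> = (0..37).map(|i| (i as f32 * 0.3).sin()).collect();
        let b: Vec<f32> = (0..37).map(|i| (i as f32 * 0.2).cos()).collect();

        let dot_ref: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let norm_ref: f32 = b.iter().map(|y| y * y).sum();

        let (dot, norm) = fused_dot_norm(&a, &b, a.len());
        assert!((dot - dot_ref).abs() < 1e-4);
        assert!((norm - norm_ref).abs() < 1e-4);
    }
}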
1985
1986#[cfg(target_arch = "aarch64")]
1987#[target_feature(enable = "neon")]
1988#[allow(unsafe_op_in_unsafe_fn)]
1989unsafe fn fused_dot_norm_neon(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1990    use std::arch::aarch64::*;
1991
1992    let chunks16 = count / 16;
1993    let remainder = count % 16;
1994
1995    let mut d0 = vdupq_n_f32(0.0);
1996    let mut d1 = vdupq_n_f32(0.0);
1997    let mut d2 = vdupq_n_f32(0.0);
1998    let mut d3 = vdupq_n_f32(0.0);
1999    let mut n0 = vdupq_n_f32(0.0);
2000    let mut n1 = vdupq_n_f32(0.0);
2001    let mut n2 = vdupq_n_f32(0.0);
2002    let mut n3 = vdupq_n_f32(0.0);
2003
2004    for c in 0..chunks16 {
2005        let base = c * 16;
2006        let va0 = vld1q_f32(a.as_ptr().add(base));
2007        let vb0 = vld1q_f32(b.as_ptr().add(base));
2008        d0 = vfmaq_f32(d0, va0, vb0);
2009        n0 = vfmaq_f32(n0, vb0, vb0);
2010        let va1 = vld1q_f32(a.as_ptr().add(base + 4));
2011        let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
2012        d1 = vfmaq_f32(d1, va1, vb1);
2013        n1 = vfmaq_f32(n1, vb1, vb1);
2014        let va2 = vld1q_f32(a.as_ptr().add(base + 8));
2015        let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
2016        d2 = vfmaq_f32(d2, va2, vb2);
2017        n2 = vfmaq_f32(n2, vb2, vb2);
2018        let va3 = vld1q_f32(a.as_ptr().add(base + 12));
2019        let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
2020        d3 = vfmaq_f32(d3, va3, vb3);
2021        n3 = vfmaq_f32(n3, vb3, vb3);
2022    }
2023
2024    let acc_dot = vaddq_f32(vaddq_f32(d0, d1), vaddq_f32(d2, d3));
2025    let acc_norm = vaddq_f32(vaddq_f32(n0, n1), vaddq_f32(n2, n3));
2026    let mut dot = vaddvq_f32(acc_dot);
2027    let mut norm = vaddvq_f32(acc_norm);
2028
2029    let base = chunks16 * 16;
2030    for i in 0..remainder {
2031        dot += a[base + i] * b[base + i];
2032        norm += b[base + i] * b[base + i];
2033    }
2034
2035    (dot, norm)
2036}
2037
2038#[cfg(target_arch = "x86_64")]
2039#[target_feature(enable = "avx2", enable = "fma")]
2040#[allow(unsafe_op_in_unsafe_fn)]
2041unsafe fn fused_dot_norm_avx2(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
2042    use std::arch::x86_64::*;
2043
2044    let chunks32 = count / 32;
2045    let remainder = count % 32;
2046
2047    let mut d0 = _mm256_setzero_ps();
2048    let mut d1 = _mm256_setzero_ps();
2049    let mut d2 = _mm256_setzero_ps();
2050    let mut d3 = _mm256_setzero_ps();
2051    let mut n0 = _mm256_setzero_ps();
2052    let mut n1 = _mm256_setzero_ps();
2053    let mut n2 = _mm256_setzero_ps();
2054    let mut n3 = _mm256_setzero_ps();
2055
2056    for c in 0..chunks32 {
2057        let base = c * 32;
2058        let vb0 = _mm256_loadu_ps(b.as_ptr().add(base));
2059        d0 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base)), vb0, d0);
2060        n0 = _mm256_fmadd_ps(vb0, vb0, n0);
2061        let vb1 = _mm256_loadu_ps(b.as_ptr().add(base + 8));
2062        d1 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 8)), vb1, d1);
2063        n1 = _mm256_fmadd_ps(vb1, vb1, n1);
2064        let vb2 = _mm256_loadu_ps(b.as_ptr().add(base + 16));
2065        d2 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 16)), vb2, d2);
2066        n2 = _mm256_fmadd_ps(vb2, vb2, n2);
2067        let vb3 = _mm256_loadu_ps(b.as_ptr().add(base + 24));
2068        d3 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 24)), vb3, d3);
2069        n3 = _mm256_fmadd_ps(vb3, vb3, n3);
2070    }
2071
2072    let acc_dot = _mm256_add_ps(_mm256_add_ps(d0, d1), _mm256_add_ps(d2, d3));
2073    let acc_norm = _mm256_add_ps(_mm256_add_ps(n0, n1), _mm256_add_ps(n2, n3));
2074
2075    // Horizontal sums: 256→128→scalar
2076    let hi_d = _mm256_extractf128_ps(acc_dot, 1);
2077    let lo_d = _mm256_castps256_ps128(acc_dot);
2078    let sum_d = _mm_add_ps(lo_d, hi_d);
2079    let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
2080    let sums_d = _mm_add_ps(sum_d, shuf_d);
2081    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2082    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2083
2084    let hi_n = _mm256_extractf128_ps(acc_norm, 1);
2085    let lo_n = _mm256_castps256_ps128(acc_norm);
2086    let sum_n = _mm_add_ps(lo_n, hi_n);
2087    let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
2088    let sums_n = _mm_add_ps(sum_n, shuf_n);
2089    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2090    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2091
2092    let base = chunks32 * 32;
2093    for i in 0..remainder {
2094        dot += a[base + i] * b[base + i];
2095        norm += b[base + i] * b[base + i];
2096    }
2097
2098    (dot, norm)
2099}
2100
2101#[cfg(target_arch = "x86_64")]
2102#[target_feature(enable = "sse")]
2103#[allow(unsafe_op_in_unsafe_fn)]
2104unsafe fn fused_dot_norm_sse(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
2105    use std::arch::x86_64::*;
2106
2107    let chunks = count / 4;
2108    let remainder = count % 4;
2109
2110    let mut acc_dot = _mm_setzero_ps();
2111    let mut acc_norm = _mm_setzero_ps();
2112
2113    for chunk in 0..chunks {
2114        let base = chunk * 4;
2115        let va = _mm_loadu_ps(a.as_ptr().add(base));
2116        let vb = _mm_loadu_ps(b.as_ptr().add(base));
2117        acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2118        acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2119    }
2120
2121    // Horizontal sums
2122    let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2123    let sums_d = _mm_add_ps(acc_dot, shuf_d);
2124    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2125    let final_d = _mm_add_ss(sums_d, shuf2_d);
2126    let mut dot = _mm_cvtss_f32(final_d);
2127
2128    let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2129    let sums_n = _mm_add_ps(acc_norm, shuf_n);
2130    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2131    let final_n = _mm_add_ss(sums_n, shuf2_n);
2132    let mut norm = _mm_cvtss_f32(final_n);
2133
2134    let base = chunks * 4;
2135    for i in 0..remainder {
2136        dot += a[base + i] * b[base + i];
2137        norm += b[base + i] * b[base + i];
2138    }
2139
2140    (dot, norm)
2141}
2142
2143/// Fast approximate reciprocal square root: 1/sqrt(x).
2144///
2145/// Uses the IEEE 754 bit trick (Quake III) plus two Newton-Raphson iterations,
2146/// giving a maximum relative error below about 5e-6, which is ample for cosine similarity scoring.
2147/// ~3-5x faster than `1.0 / x.sqrt()` on most architectures.
2148#[inline]
2149pub fn fast_inv_sqrt(x: f32) -> f32 {
2150    let half = 0.5 * x;
2151    let i = 0x5F37_5A86_u32.wrapping_sub(x.to_bits() >> 1);
2152    let y = f32::from_bits(i);
2153    let y = y * (1.5 - half * y * y); // first Newton-Raphson step
2154    y * (1.5 - half * y * y) // second step: max relative error ~5e-6
2155}
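
// Illustrative check (a minimal sketch): across several magnitudes the approximation
// should stay well within the error budget stated above.
#[cfg(test)]
mod fast_inv_sqrt_example {
    use super::fast_inv_sqrt;

    #[test]
    fn close_to_exact_reciprocal_sqrt() {
        for &x in &[1e-3f32, 0.25, 1.0, 2.0, 42.0, 1e6] {
            let exact = 1.0 / x.sqrt();
            let approx = fast_inv_sqrt(x);
            assert!((approx - exact).abs() <= exact * 2e-5);
        }
    }
}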
2156
2157/// Batch cosine similarity: query vs N contiguous vectors.
2158///
2159/// `vectors` is a contiguous buffer of `n * dim` floats (row-major).
2160/// The number of vectors scored (`n`) is taken from `scores.len()`.
2161///
2162/// Optimizations over calling `cosine_similarity` N times:
2163/// 1. Query norm computed once (not N times)
2164/// 2. Fused dot+norm kernel — each vector loaded once (halves bandwidth)
2165/// 3. No per-call overhead (branch prediction, function calls)
2166/// 4. Fast approximate reciprocal square root (~3-5x faster than computing `1.0 / x.sqrt()` exactly)
2167#[inline]
2168pub fn batch_cosine_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
2169    let n = scores.len();
2170    debug_assert!(vectors.len() >= n * dim);
2171    debug_assert_eq!(query.len(), dim);
2172
2173    if dim == 0 || n == 0 {
2174        return;
2175    }
2176
2177    // Pre-compute query inverse norm once
2178    let norm_q_sq = dot_product_f32(query, query, dim);
2179    if norm_q_sq < f32::EPSILON {
2180        for s in scores.iter_mut() {
2181            *s = 0.0;
2182        }
2183        return;
2184    }
2185    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2186
2187    for i in 0..n {
2188        let vec = &vectors[i * dim..(i + 1) * dim];
2189        let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
2190        if norm_v_sq < f32::EPSILON {
2191            scores[i] = 0.0;
2192        } else {
2193            scores[i] = dot * inv_norm_q * fast_inv_sqrt(norm_v_sq);
2194        }
2195    }
2196}
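
// Usage sketch (illustrative values only): score a query against three row-major
// vectors packed into one contiguous buffer, as the dense-vector search path does.
#[cfg(test)]
mod batch_cosine_scores_example {
    use super::batch_cosine_scores;

    #[test]
    fn scores_parallel_orthogonal_and_opposite_vectors() {
        let dim = 4;
        let query = [1.0f32, 0.0, 0.0, 0.0];
        let vectors = [
            2.0f32, 0.0, 0.0, 0.0, // parallel to the query
            0.0, 3.0, 0.0, 0.0, // orthogonal
            -1.0, 0.0, 0.0, 0.0, // anti-parallel
        ];
        let mut scores = [0.0f32; 3];

        batch_cosine_scores(&query, &vectors, dim, &mut scores);

        assert!((scores[0] - 1.0).abs() < 1e-3);
        assert!(scores[1].abs() < 1e-3);
        assert!((scores[2] + 1.0).abs() < 1e-3);
    }
}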
2197
2198// ============================================================================
2199// f16 (IEEE 754 half-precision) conversion
2200// ============================================================================
2201
2202/// Convert f32 to f16 (IEEE 754 half-precision), stored as u16
2203#[inline]
2204pub fn f32_to_f16(value: f32) -> u16 {
2205    let bits = value.to_bits();
2206    let sign = (bits >> 16) & 0x8000;
2207    let exp = ((bits >> 23) & 0xFF) as i32;
2208    let mantissa = bits & 0x7F_FFFF;
2209
2210    if exp == 255 {
2211        // Inf/NaN: keep NaN inputs NaN even when the payload sits entirely in the low 13 bits
2212        return (sign | 0x7C00 | ((mantissa >> 13) & 0x3FF) | (mantissa != 0) as u32) as u16;
2213    }
2214
2215    let exp16 = exp - 127 + 15;
2216
2217    if exp16 >= 31 {
2218        return (sign | 0x7C00) as u16; // overflow → infinity
2219    }
2220
2221    if exp16 <= 0 {
2222        if exp16 < -10 {
2223            return sign as u16; // too small → zero
2224        }
2225        let shift = (1 - exp16) as u32;
2226        let m = (mantissa | 0x80_0000) >> shift;
2227        // Round-to-nearest-even
2228        let round_bit = (m >> 12) & 1;
2229        let sticky = m & 0xFFF;
2230        let m13 = m >> 13;
2231        let rounded = m13 + (round_bit & (m13 | if sticky != 0 { 1 } else { 0 }));
2232        return (sign | rounded) as u16;
2233    }
2234
2235    // Round-to-nearest-even for normal numbers
2236    let round_bit = (mantissa >> 12) & 1;
2237    let sticky = mantissa & 0xFFF;
2238    let m13 = mantissa >> 13;
2239    let rounded = m13 + (round_bit & (m13 | if sticky != 0 { 1 } else { 0 }));
2240    // Check if rounding caused mantissa overflow (carry into exponent)
2241    if rounded > 0x3FF {
2242        let exp16_inc = exp16 as u32 + 1;
2243        if exp16_inc >= 31 {
2244            return (sign | 0x7C00) as u16; // overflow → infinity
2245        }
2246        (sign | (exp16_inc << 10)) as u16
2247    } else {
2248        (sign | ((exp16 as u32) << 10) | rounded) as u16
2249    }
2250}
2251
2252/// Convert f16 (stored as u16) to f32
2253#[inline]
2254pub fn f16_to_f32(half: u16) -> f32 {
2255    let sign = ((half & 0x8000) as u32) << 16;
2256    let exp = ((half >> 10) & 0x1F) as u32;
2257    let mantissa = (half & 0x3FF) as u32;
2258
2259    if exp == 0 {
2260        if mantissa == 0 {
2261            return f32::from_bits(sign);
2262        }
2263        // Subnormal: normalize
2264        let mut e = 0u32;
2265        let mut m = mantissa;
2266        while (m & 0x400) == 0 {
2267            m <<= 1;
2268            e += 1;
2269        }
2270        return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | ((m & 0x3FF) << 13));
2271    }
2272
2273    if exp == 31 {
2274        return f32::from_bits(sign | 0x7F80_0000 | (mantissa << 13));
2275    }
2276
2277    f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mantissa << 13))
2278}
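
// Illustrative check (a minimal sketch): f32 -> f16 -> f32 round-trips exactly for
// values representable in half precision, and to within a half-precision ulp otherwise.
#[cfg(test)]
mod f16_conversion_example {
    use super::{f16_to_f32, f32_to_f16};

    #[test]
    fn round_trip() {
        // Exactly representable in f16: zeros, powers of two, small sums of powers of two.
        for &v in &[0.0f32, -0.0, 1.0, -1.0, 0.5, 2.0, 1024.0, -0.25] {
            assert_eq!(f16_to_f32(f32_to_f16(v)), v);
        }

        // Not exactly representable: error bounded by the f16 grid (about 1 part in 2048).
        let v = 0.1f32;
        let back = f16_to_f32(f32_to_f16(v));
        assert!((back - v).abs() <= v / 1024.0);
    }
}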
2279
2280// ============================================================================
2281// uint8 scalar quantization for [-1, 1] range
2282// ============================================================================
2283
2284const U8_SCALE: f32 = 127.5;
2285const U8_INV_SCALE: f32 = 1.0 / 127.5;
2286
2287/// Quantize f32 in [-1, 1] to u8 [0, 255]
2288#[inline]
2289pub fn f32_to_u8_saturating(value: f32) -> u8 {
2290    ((value.clamp(-1.0, 1.0) + 1.0) * U8_SCALE) as u8
2291}
2292
2293/// Dequantize u8 [0, 255] to f32 in [-1, 1]
2294#[inline]
2295pub fn u8_to_f32(byte: u8) -> f32 {
2296    byte as f32 * U8_INV_SCALE - 1.0
2297}
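
// Illustrative check (a minimal sketch): the u8 scheme maps [-1, 1] onto [0, 255],
// so a quantize/dequantize round trip is accurate to about one grid step (1/127.5).
#[cfg(test)]
mod u8_quantization_example {
    use super::{f32_to_u8_saturating, u8_to_f32};

    #[test]
    fn round_trip_error_is_bounded() {
        assert_eq!(f32_to_u8_saturating(-1.0), 0);
        assert_eq!(f32_to_u8_saturating(1.0), 255);
        assert_eq!(u8_to_f32(0), -1.0);
        assert!((u8_to_f32(255) - 1.0).abs() < 1e-6);

        for &v in &[-1.0f32, -0.5, -0.123, 0.0, 0.33, 0.5, 0.9, 1.0] {
            let back = u8_to_f32(f32_to_u8_saturating(v));
            assert!((back - v).abs() < 0.01);
        }
    }
}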
2298
2299// ============================================================================
2300// Batch conversion (used during builder write)
2301// ============================================================================
2302
2303/// Batch convert f32 slice to f16 (stored as u16)
2304pub fn batch_f32_to_f16(src: &[f32], dst: &mut [u16]) {
2305    debug_assert_eq!(src.len(), dst.len());
2306    for (s, d) in src.iter().zip(dst.iter_mut()) {
2307        *d = f32_to_f16(*s);
2308    }
2309}
2310
2311/// Batch convert f32 slice to u8 with [-1,1] → [0,255] mapping
2312pub fn batch_f32_to_u8(src: &[f32], dst: &mut [u8]) {
2313    debug_assert_eq!(src.len(), dst.len());
2314    for (s, d) in src.iter().zip(dst.iter_mut()) {
2315        *d = f32_to_u8_saturating(*s);
2316    }
2317}
2318
2319// ============================================================================
2320// NEON-accelerated fused dot+norm for quantized vectors
2321// ============================================================================
2322
2323#[cfg(target_arch = "aarch64")]
2324#[allow(unsafe_op_in_unsafe_fn)]
2325mod neon_quant {
2326    use std::arch::aarch64::*;
2327
2328    /// Fused dot(query_f16, vec_f16) + norm(vec_f16) for f16 vectors on NEON.
2329    ///
2330    /// Both query and vectors are f16 (stored as u16). Uses hardware `vcvt_f32_f16`
2331/// for SIMD f16→f32 conversion (replaces scalar bit manipulation) and processes
2332/// 16 elements per iteration, accumulating in f32 for precision.
2333    #[allow(clippy::incompatible_msrv)]
2334    #[target_feature(enable = "neon")]
2335    pub unsafe fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2336        let chunks16 = dim / 16;
2337        let remainder = dim % 16;
2338
2339        // 2 accumulator pairs to hide FMA latency (processes 16 f16 per iteration)
2340        let mut acc_dot0 = vdupq_n_f32(0.0);
2341        let mut acc_dot1 = vdupq_n_f32(0.0);
2342        let mut acc_norm0 = vdupq_n_f32(0.0);
2343        let mut acc_norm1 = vdupq_n_f32(0.0);
2344
2345        for c in 0..chunks16 {
2346            let base = c * 16;
2347
2348            // First 8 f16 elements
2349            let v_raw0 = vld1q_u16(vec_f16.as_ptr().add(base));
2350            let v_lo0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw0)));
2351            let v_hi0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw0)));
2352            let q_raw0 = vld1q_u16(query_f16.as_ptr().add(base));
2353            let q_lo0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw0)));
2354            let q_hi0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw0)));
2355
2356            acc_dot0 = vfmaq_f32(acc_dot0, q_lo0, v_lo0);
2357            acc_dot0 = vfmaq_f32(acc_dot0, q_hi0, v_hi0);
2358            acc_norm0 = vfmaq_f32(acc_norm0, v_lo0, v_lo0);
2359            acc_norm0 = vfmaq_f32(acc_norm0, v_hi0, v_hi0);
2360
2361            // Second 8 f16 elements (independent accumulator chain)
2362            let v_raw1 = vld1q_u16(vec_f16.as_ptr().add(base + 8));
2363            let v_lo1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw1)));
2364            let v_hi1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw1)));
2365            let q_raw1 = vld1q_u16(query_f16.as_ptr().add(base + 8));
2366            let q_lo1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw1)));
2367            let q_hi1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw1)));
2368
2369            acc_dot1 = vfmaq_f32(acc_dot1, q_lo1, v_lo1);
2370            acc_dot1 = vfmaq_f32(acc_dot1, q_hi1, v_hi1);
2371            acc_norm1 = vfmaq_f32(acc_norm1, v_lo1, v_lo1);
2372            acc_norm1 = vfmaq_f32(acc_norm1, v_hi1, v_hi1);
2373        }
2374
2375        // Combine accumulator pairs
2376        let mut dot = vaddvq_f32(vaddq_f32(acc_dot0, acc_dot1));
2377        let mut norm = vaddvq_f32(vaddq_f32(acc_norm0, acc_norm1));
2378
2379        // Handle remainder
2380        let base = chunks16 * 16;
2381        for i in 0..remainder {
2382            let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2383            let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2384            dot += q * v;
2385            norm += v * v;
2386        }
2387
2388        (dot, norm)
2389    }
2390
2391    /// Fused dot(query, vec) + norm(vec) for u8 vectors on NEON.
2392    /// Processes 16 u8 values per iteration using NEON widening chain.
2393    #[target_feature(enable = "neon")]
2394    pub unsafe fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2395        let scale = vdupq_n_f32(super::U8_INV_SCALE);
2396        let offset = vdupq_n_f32(-1.0);
2397
2398        let chunks16 = dim / 16;
2399        let remainder = dim % 16;
2400
2401        let mut acc_dot = vdupq_n_f32(0.0);
2402        let mut acc_norm = vdupq_n_f32(0.0);
2403
2404        for c in 0..chunks16 {
2405            let base = c * 16;
2406
2407            // Load 16 u8 values
2408            let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2409
2410            // Widen: 16×u8 → 2×8×u16 → 4×4×u32 → 4×4×f32
2411            let lo8 = vget_low_u8(bytes);
2412            let hi8 = vget_high_u8(bytes);
2413            let lo16 = vmovl_u8(lo8);
2414            let hi16 = vmovl_u8(hi8);
2415
2416            let f0 = vaddq_f32(
2417                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2418                offset,
2419            );
2420            let f1 = vaddq_f32(
2421                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2422                offset,
2423            );
2424            let f2 = vaddq_f32(
2425                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2426                offset,
2427            );
2428            let f3 = vaddq_f32(
2429                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2430                offset,
2431            );
2432
2433            let q0 = vld1q_f32(query.as_ptr().add(base));
2434            let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2435            let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2436            let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2437
2438            acc_dot = vfmaq_f32(acc_dot, q0, f0);
2439            acc_dot = vfmaq_f32(acc_dot, q1, f1);
2440            acc_dot = vfmaq_f32(acc_dot, q2, f2);
2441            acc_dot = vfmaq_f32(acc_dot, q3, f3);
2442
2443            acc_norm = vfmaq_f32(acc_norm, f0, f0);
2444            acc_norm = vfmaq_f32(acc_norm, f1, f1);
2445            acc_norm = vfmaq_f32(acc_norm, f2, f2);
2446            acc_norm = vfmaq_f32(acc_norm, f3, f3);
2447        }
2448
2449        let mut dot = vaddvq_f32(acc_dot);
2450        let mut norm = vaddvq_f32(acc_norm);
2451
2452        let base = chunks16 * 16;
2453        for i in 0..remainder {
2454            let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2455            dot += *query.get_unchecked(base + i) * v;
2456            norm += v * v;
2457        }
2458
2459        (dot, norm)
2460    }
2461
2462    /// Dot product only for f16 vectors on NEON (no norm — for unit_norm vectors).
2463    #[allow(clippy::incompatible_msrv)]
2464    #[target_feature(enable = "neon")]
2465    pub unsafe fn dot_product_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2466        let chunks8 = dim / 8;
2467        let remainder = dim % 8;
2468
2469        let mut acc = vdupq_n_f32(0.0);
2470
2471        for c in 0..chunks8 {
2472            let base = c * 8;
2473            let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
2474            let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
2475            let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
2476            let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
2477            let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
2478            let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
2479            acc = vfmaq_f32(acc, q_lo, v_lo);
2480            acc = vfmaq_f32(acc, q_hi, v_hi);
2481        }
2482
2483        let mut dot = vaddvq_f32(acc);
2484        let base = chunks8 * 8;
2485        for i in 0..remainder {
2486            let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2487            let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2488            dot += q * v;
2489        }
2490        dot
2491    }
2492
2493    /// Dot product only for u8 vectors on NEON (no norm — for unit_norm vectors).
2494    #[target_feature(enable = "neon")]
2495    pub unsafe fn dot_product_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2496        let scale = vdupq_n_f32(super::U8_INV_SCALE);
2497        let offset = vdupq_n_f32(-1.0);
2498        let chunks16 = dim / 16;
2499        let remainder = dim % 16;
2500
2501        let mut acc = vdupq_n_f32(0.0);
2502
2503        for c in 0..chunks16 {
2504            let base = c * 16;
2505            let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2506            let lo8 = vget_low_u8(bytes);
2507            let hi8 = vget_high_u8(bytes);
2508            let lo16 = vmovl_u8(lo8);
2509            let hi16 = vmovl_u8(hi8);
2510            let f0 = vaddq_f32(
2511                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2512                offset,
2513            );
2514            let f1 = vaddq_f32(
2515                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2516                offset,
2517            );
2518            let f2 = vaddq_f32(
2519                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2520                offset,
2521            );
2522            let f3 = vaddq_f32(
2523                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2524                offset,
2525            );
2526            let q0 = vld1q_f32(query.as_ptr().add(base));
2527            let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2528            let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2529            let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2530            acc = vfmaq_f32(acc, q0, f0);
2531            acc = vfmaq_f32(acc, q1, f1);
2532            acc = vfmaq_f32(acc, q2, f2);
2533            acc = vfmaq_f32(acc, q3, f3);
2534        }
2535
2536        let mut dot = vaddvq_f32(acc);
2537        let base = chunks16 * 16;
2538        for i in 0..remainder {
2539            let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2540            dot += *query.get_unchecked(base + i) * v;
2541        }
2542        dot
2543    }
2544}
2545
2546// ============================================================================
2547// Scalar fallback for fused dot+norm on quantized vectors
2548// ============================================================================
2549
2550#[allow(dead_code)]
2551fn fused_dot_norm_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2552    let mut dot = 0.0f32;
2553    let mut norm = 0.0f32;
2554    for i in 0..dim {
2555        let v = f16_to_f32(vec_f16[i]);
2556        let q = f16_to_f32(query_f16[i]);
2557        dot += q * v;
2558        norm += v * v;
2559    }
2560    (dot, norm)
2561}
2562
2563#[allow(dead_code)]
2564fn fused_dot_norm_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2565    let mut dot = 0.0f32;
2566    let mut norm = 0.0f32;
2567    for i in 0..dim {
2568        let v = u8_to_f32(vec_u8[i]);
2569        dot += query[i] * v;
2570        norm += v * v;
2571    }
2572    (dot, norm)
2573}
2574
2575#[allow(dead_code)]
2576fn dot_product_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2577    let mut dot = 0.0f32;
2578    for i in 0..dim {
2579        dot += f16_to_f32(query_f16[i]) * f16_to_f32(vec_f16[i]);
2580    }
2581    dot
2582}
2583
2584#[allow(dead_code)]
2585fn dot_product_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2586    let mut dot = 0.0f32;
2587    for i in 0..dim {
2588        dot += query[i] * u8_to_f32(vec_u8[i]);
2589    }
2590    dot
2591}
2592
2593// ============================================================================
2594// x86_64 SSE4.1 quantized fused dot+norm
2595// ============================================================================
2596
2597#[cfg(target_arch = "x86_64")]
2598#[target_feature(enable = "sse2", enable = "sse4.1")]
2599#[allow(unsafe_op_in_unsafe_fn)]
2600unsafe fn fused_dot_norm_f16_sse(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2601    use std::arch::x86_64::*;
2602
2603    let chunks = dim / 4;
2604    let remainder = dim % 4;
2605
2606    let mut acc_dot = _mm_setzero_ps();
2607    let mut acc_norm = _mm_setzero_ps();
2608
2609    for chunk in 0..chunks {
2610        let base = chunk * 4;
2611        // Load 4 f16 values and convert to f32 using scalar conversion
2612        let v0 = f16_to_f32(*vec_f16.get_unchecked(base));
2613        let v1 = f16_to_f32(*vec_f16.get_unchecked(base + 1));
2614        let v2 = f16_to_f32(*vec_f16.get_unchecked(base + 2));
2615        let v3 = f16_to_f32(*vec_f16.get_unchecked(base + 3));
2616        let vb = _mm_set_ps(v3, v2, v1, v0);
2617
2618        let q0 = f16_to_f32(*query_f16.get_unchecked(base));
2619        let q1 = f16_to_f32(*query_f16.get_unchecked(base + 1));
2620        let q2 = f16_to_f32(*query_f16.get_unchecked(base + 2));
2621        let q3 = f16_to_f32(*query_f16.get_unchecked(base + 3));
2622        let va = _mm_set_ps(q3, q2, q1, q0);
2623
2624        acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2625        acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2626    }
2627
2628    // Horizontal sums
2629    let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2630    let sums_d = _mm_add_ps(acc_dot, shuf_d);
2631    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2632    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2633
2634    let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2635    let sums_n = _mm_add_ps(acc_norm, shuf_n);
2636    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2637    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2638
2639    let base = chunks * 4;
2640    for i in 0..remainder {
2641        let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2642        let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2643        dot += q * v;
2644        norm += v * v;
2645    }
2646
2647    (dot, norm)
2648}
2649
2650#[cfg(target_arch = "x86_64")]
2651#[target_feature(enable = "sse2", enable = "sse4.1")]
2652#[allow(unsafe_op_in_unsafe_fn)]
2653unsafe fn fused_dot_norm_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2654    use std::arch::x86_64::*;
2655
2656    let scale = _mm_set1_ps(U8_INV_SCALE);
2657    let offset = _mm_set1_ps(-1.0);
2658
2659    let chunks = dim / 4;
2660    let remainder = dim % 4;
2661
2662    let mut acc_dot = _mm_setzero_ps();
2663    let mut acc_norm = _mm_setzero_ps();
2664
2665    for chunk in 0..chunks {
2666        let base = chunk * 4;
2667
2668        // Load 4 bytes, zero-extend to i32, convert to f32, dequantize
2669        let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2670            vec_u8.as_ptr().add(base) as *const i32
2671        ));
2672        let ints = _mm_cvtepu8_epi32(bytes);
2673        let floats = _mm_cvtepi32_ps(ints);
2674        let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2675
2676        let va = _mm_loadu_ps(query.as_ptr().add(base));
2677
2678        acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2679        acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2680    }
2681
2682    // Horizontal sums
2683    let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2684    let sums_d = _mm_add_ps(acc_dot, shuf_d);
2685    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2686    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2687
2688    let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2689    let sums_n = _mm_add_ps(acc_norm, shuf_n);
2690    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2691    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2692
2693    let base = chunks * 4;
2694    for i in 0..remainder {
2695        let v = u8_to_f32(*vec_u8.get_unchecked(base + i));
2696        dot += *query.get_unchecked(base + i) * v;
2697        norm += v * v;
2698    }
2699
2700    (dot, norm)
2701}
2702
2703// ============================================================================
2704// x86_64 F16C + AVX + FMA accelerated f16 scoring
2705// ============================================================================
2706
2707#[cfg(target_arch = "x86_64")]
2708#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2709#[allow(unsafe_op_in_unsafe_fn)]
2710unsafe fn fused_dot_norm_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2711    use std::arch::x86_64::*;
2712
2713    let chunks16 = dim / 16;
2714    let remainder = dim % 16;
2715
2716    // 2 accumulator pairs to hide FMA latency (processes 16 f16 per iteration)
2717    let mut acc_dot0 = _mm256_setzero_ps();
2718    let mut acc_dot1 = _mm256_setzero_ps();
2719    let mut acc_norm0 = _mm256_setzero_ps();
2720    let mut acc_norm1 = _mm256_setzero_ps();
2721
2722    for c in 0..chunks16 {
2723        let base = c * 16;
2724
2725        // First 8 f16 elements
2726        let v_raw0 = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2727        let vb0 = _mm256_cvtph_ps(v_raw0);
2728        let q_raw0 = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2729        let qa0 = _mm256_cvtph_ps(q_raw0);
2730        acc_dot0 = _mm256_fmadd_ps(qa0, vb0, acc_dot0);
2731        acc_norm0 = _mm256_fmadd_ps(vb0, vb0, acc_norm0);
2732
2733        // Second 8 f16 elements (independent accumulator chain)
2734        let v_raw1 = _mm_loadu_si128(vec_f16.as_ptr().add(base + 8) as *const __m128i);
2735        let vb1 = _mm256_cvtph_ps(v_raw1);
2736        let q_raw1 = _mm_loadu_si128(query_f16.as_ptr().add(base + 8) as *const __m128i);
2737        let qa1 = _mm256_cvtph_ps(q_raw1);
2738        acc_dot1 = _mm256_fmadd_ps(qa1, vb1, acc_dot1);
2739        acc_norm1 = _mm256_fmadd_ps(vb1, vb1, acc_norm1);
2740    }
2741
2742    // Combine accumulator pairs
2743    let acc_dot = _mm256_add_ps(acc_dot0, acc_dot1);
2744    let acc_norm = _mm256_add_ps(acc_norm0, acc_norm1);
2745
2746    // Horizontal sum 256→128→scalar
2747    let hi_d = _mm256_extractf128_ps(acc_dot, 1);
2748    let lo_d = _mm256_castps256_ps128(acc_dot);
2749    let sum_d = _mm_add_ps(lo_d, hi_d);
2750    let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
2751    let sums_d = _mm_add_ps(sum_d, shuf_d);
2752    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2753    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2754
2755    let hi_n = _mm256_extractf128_ps(acc_norm, 1);
2756    let lo_n = _mm256_castps256_ps128(acc_norm);
2757    let sum_n = _mm_add_ps(lo_n, hi_n);
2758    let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
2759    let sums_n = _mm_add_ps(sum_n, shuf_n);
2760    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2761    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2762
2763    let base = chunks16 * 16;
2764    for i in 0..remainder {
2765        let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2766        let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2767        dot += q * v;
2768        norm += v * v;
2769    }
2770
2771    (dot, norm)
2772}
2773
2774#[cfg(target_arch = "x86_64")]
2775#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2776#[allow(unsafe_op_in_unsafe_fn)]
2777unsafe fn dot_product_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2778    use std::arch::x86_64::*;
2779
2780    let chunks = dim / 8;
2781    let remainder = dim % 8;
2782    let mut acc = _mm256_setzero_ps();
2783
2784    for chunk in 0..chunks {
2785        let base = chunk * 8;
2786        let v_raw = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2787        let vb = _mm256_cvtph_ps(v_raw);
2788        let q_raw = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2789        let qa = _mm256_cvtph_ps(q_raw);
2790        acc = _mm256_fmadd_ps(qa, vb, acc);
2791    }
2792
2793    let hi = _mm256_extractf128_ps(acc, 1);
2794    let lo = _mm256_castps256_ps128(acc);
2795    let sum = _mm_add_ps(lo, hi);
2796    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
2797    let sums = _mm_add_ps(sum, shuf);
2798    let shuf2 = _mm_movehl_ps(sums, sums);
2799    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2800
2801    let base = chunks * 8;
2802    for i in 0..remainder {
2803        let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2804        let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2805        dot += q * v;
2806    }
2807    dot
2808}
2809
2810#[cfg(target_arch = "x86_64")]
2811#[target_feature(enable = "sse2", enable = "sse4.1")]
2812#[allow(unsafe_op_in_unsafe_fn)]
2813unsafe fn dot_product_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2814    use std::arch::x86_64::*;
2815
2816    let scale = _mm_set1_ps(U8_INV_SCALE);
2817    let offset = _mm_set1_ps(-1.0);
2818    let chunks = dim / 4;
2819    let remainder = dim % 4;
2820    let mut acc = _mm_setzero_ps();
2821
2822    for chunk in 0..chunks {
2823        let base = chunk * 4;
2824        let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2825            vec_u8.as_ptr().add(base) as *const i32
2826        ));
2827        let ints = _mm_cvtepu8_epi32(bytes);
2828        let floats = _mm_cvtepi32_ps(ints);
2829        let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2830        let va = _mm_loadu_ps(query.as_ptr().add(base));
2831        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
2832    }
2833
2834    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
2835    let sums = _mm_add_ps(acc, shuf);
2836    let shuf2 = _mm_movehl_ps(sums, sums);
2837    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2838
2839    let base = chunks * 4;
2840    for i in 0..remainder {
2841        dot += *query.get_unchecked(base + i) * u8_to_f32(*vec_u8.get_unchecked(base + i));
2842    }
2843    dot
2844}
2845
2846// ============================================================================
2847// Platform dispatch
2848// ============================================================================
2849
2850#[inline]
2851fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2852    #[cfg(target_arch = "aarch64")]
2853    {
2854        return unsafe { neon_quant::fused_dot_norm_f16(query_f16, vec_f16, dim) };
2855    }
2856
2857    #[cfg(target_arch = "x86_64")]
2858    {
2859        if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2860            return unsafe { fused_dot_norm_f16_f16c(query_f16, vec_f16, dim) };
2861        }
2862        if sse::is_available() {
2863            return unsafe { fused_dot_norm_f16_sse(query_f16, vec_f16, dim) };
2864        }
2865    }
2866
2867    #[allow(unreachable_code)]
2868    fused_dot_norm_f16_scalar(query_f16, vec_f16, dim)
2869}
2870
2871#[inline]
2872fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2873    #[cfg(target_arch = "aarch64")]
2874    {
2875        return unsafe { neon_quant::fused_dot_norm_u8(query, vec_u8, dim) };
2876    }
2877
2878    #[cfg(target_arch = "x86_64")]
2879    {
2880        if sse::is_available() {
2881            return unsafe { fused_dot_norm_u8_sse(query, vec_u8, dim) };
2882        }
2883    }
2884
2885    #[allow(unreachable_code)]
2886    fused_dot_norm_u8_scalar(query, vec_u8, dim)
2887}
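
// Illustrative check (a minimal sketch): whichever SIMD path the dispatchers above
// select at runtime, the quantized fused kernels should agree with the scalar
// fallbacks defined earlier in this file.
#[cfg(test)]
mod quantized_dispatch_example {
    use super::{
        f32_to_f16, f32_to_u8_saturating, fused_dot_norm_f16, fused_dot_norm_f16_scalar,
        fused_dot_norm_u8, fused_dot_norm_u8_scalar,
    };

    #[test]
    fn simd_paths_match_scalar_fallbacks() {
        let dim = 21; // exercises both the SIMD loop and the remainder tail
        let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.11).sin()).collect();
        let stored: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.07).cos()).collect();

        let q16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
        let v16: Vec<u16> = stored.iter().map(|&v| f32_to_f16(v)).collect();
        let (dot, norm) = fused_dot_norm_f16(&q16, &v16, dim);
        let (dot_ref, norm_ref) = fused_dot_norm_f16_scalar(&q16, &v16, dim);
        assert!((dot - dot_ref).abs() < 1e-3 && (norm - norm_ref).abs() < 1e-3);

        let v8: Vec<u8> = stored.iter().map(|&v| f32_to_u8_saturating(v)).collect();
        let (dot, norm) = fused_dot_norm_u8(&query, &v8, dim);
        let (dot_ref, norm_ref) = fused_dot_norm_u8_scalar(&query, &v8, dim);
        assert!((dot - dot_ref).abs() < 1e-3 && (norm - norm_ref).abs() < 1e-3);
    }
}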
2888
2889// ── Dot-product-only dispatch (for unit_norm vectors) ─────────────────────
2890
2891#[inline]
2892fn dot_product_f16_quant(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2893    #[cfg(target_arch = "aarch64")]
2894    {
2895        return unsafe { neon_quant::dot_product_f16(query_f16, vec_f16, dim) };
2896    }
2897
2898    #[cfg(target_arch = "x86_64")]
2899    {
2900        if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2901            return unsafe { dot_product_f16_f16c(query_f16, vec_f16, dim) };
2902        }
2903    }
2904
2905    #[allow(unreachable_code)]
2906    dot_product_f16_scalar(query_f16, vec_f16, dim)
2907}
2908
2909#[inline]
2910fn dot_product_u8_quant(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2911    #[cfg(target_arch = "aarch64")]
2912    {
2913        return unsafe { neon_quant::dot_product_u8(query, vec_u8, dim) };
2914    }
2915
2916    #[cfg(target_arch = "x86_64")]
2917    {
2918        if sse::is_available() {
2919            return unsafe { dot_product_u8_sse(query, vec_u8, dim) };
2920        }
2921    }
2922
2923    #[allow(unreachable_code)]
2924    dot_product_u8_scalar(query, vec_u8, dim)
2925}
2926
2927// ============================================================================
2928// Public batch cosine scoring for quantized vectors
2929// ============================================================================
2930
2931/// Batch cosine similarity: f32 query vs N contiguous f16 vectors.
2932///
2933/// `vectors_raw` is raw bytes: N vectors × dim × 2 bytes (f16 stored as u16).
2934/// Query is quantized to f16 once, then both query and vectors are scored in
2935/// f16 space using hardware SIMD conversion (16 elements per iteration on NEON).
2936/// Memory bandwidth is halved for both query and vector loads.
2937#[inline]
2938pub fn batch_cosine_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2939    let n = scores.len();
2940    if dim == 0 || n == 0 {
2941        return;
2942    }
2943
2944    // Compute query inverse norm in f32 (full precision, before quantization)
2945    let norm_q_sq = dot_product_f32(query, query, dim);
2946    if norm_q_sq < f32::EPSILON {
2947        for s in scores.iter_mut() {
2948            *s = 0.0;
2949        }
2950        return;
2951    }
2952    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2953
2954    // Quantize query to f16 once (O(dim)), reused for all N vector scorings
2955    let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
2956
2957    let vec_bytes = dim * 2;
2958    debug_assert!(vectors_raw.len() >= n * vec_bytes);
2959
2960    // Vectors file uses data-first layout with 8-byte padding between fields,
2961    // so mmap slices are always 2-byte aligned for u16 access.
2962    debug_assert!(
2963        (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
2964        "f16 vector data not 2-byte aligned"
2965    );
2966
2967    for i in 0..n {
2968        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2969        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2970
2971        let (dot, norm_v_sq) = fused_dot_norm_f16(&query_f16, f16_slice, dim);
2972        scores[i] = if norm_v_sq < f32::EPSILON {
2973            0.0
2974        } else {
2975            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2976        };
2977    }
2978}
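
// Usage sketch (illustrative values only): quantize stored vectors to f16 with
// `batch_f32_to_f16`, view the u16 buffer as the raw byte layout this function
// expects, and check that the scores track the full-precision f32 path.
#[cfg(test)]
mod batch_cosine_scores_f16_example {
    use super::{batch_cosine_scores, batch_cosine_scores_f16, batch_f32_to_f16};

    #[test]
    fn tracks_f32_scores() {
        let dim = 8;
        let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.37).sin()).collect();
        let vectors: Vec<f32> = (0..3 * dim).map(|i| (i as f32 * 0.19).cos()).collect();

        // Quantize the stored vectors to f16 (u16 storage) and reinterpret as bytes.
        let mut halves = vec![0u16; vectors.len()];
        batch_f32_to_f16(&vectors, &mut halves);
        let raw: &[u8] = unsafe {
            std::slice::from_raw_parts(halves.as_ptr() as *const u8, halves.len() * 2)
        };

        let mut exact = [0.0f32; 3];
        let mut approx = [0.0f32; 3];
        batch_cosine_scores(&query, &vectors, dim, &mut exact);
        batch_cosine_scores_f16(&query, raw, dim, &mut approx);

        for (e, a) in exact.iter().zip(&approx) {
            assert!((e - a).abs() < 1e-2);
        }
    }
}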
2979
2980/// Batch cosine similarity: f32 query vs N contiguous u8 vectors.
2981///
2982/// `vectors_raw` is raw bytes: N vectors × dim bytes (u8, mapping [-1,1]→[0,255]).
2983/// Converts u8→f32 using NEON widening chain (16 values/iteration), scores with FMA.
2984/// Memory bandwidth is quartered compared to f32 scoring.
2985#[inline]
2986pub fn batch_cosine_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2987    let n = scores.len();
2988    if dim == 0 || n == 0 {
2989        return;
2990    }
2991
2992    let norm_q_sq = dot_product_f32(query, query, dim);
2993    if norm_q_sq < f32::EPSILON {
2994        for s in scores.iter_mut() {
2995            *s = 0.0;
2996        }
2997        return;
2998    }
2999    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3000
3001    debug_assert!(vectors_raw.len() >= n * dim);
3002
3003    for i in 0..n {
3004        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3005
3006        let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
3007        scores[i] = if norm_v_sq < f32::EPSILON {
3008            0.0
3009        } else {
3010            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3011        };
3012    }
3013}
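
// Usage sketch (illustrative values only): quantize vectors that live in [-1, 1]
// to u8 with `batch_f32_to_u8` and compare against the full-precision cosine path.
// The 8-bit grid costs a few hundredths of score accuracy on short vectors.
#[cfg(test)]
mod batch_cosine_scores_u8_example {
    use super::{batch_cosine_scores, batch_cosine_scores_u8, batch_f32_to_u8};

    #[test]
    fn tracks_f32_scores() {
        let dim = 8;
        let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.37).sin()).collect();
        let vectors: Vec<f32> = (0..2 * dim).map(|i| (i as f32 * 0.19).cos()).collect();

        let mut quantized = vec![0u8; vectors.len()];
        batch_f32_to_u8(&vectors, &mut quantized);

        let mut exact = [0.0f32; 2];
        let mut approx = [0.0f32; 2];
        batch_cosine_scores(&query, &vectors, dim, &mut exact);
        batch_cosine_scores_u8(&query, &quantized, dim, &mut approx);

        for (e, a) in exact.iter().zip(&approx) {
            assert!((e - a).abs() < 0.05);
        }
    }
}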
3014
3015// ============================================================================
3016// Batch dot-product scoring for unit-norm vectors
3017// ============================================================================
3018
3019/// Batch dot-product scoring: f32 query vs N contiguous f32 unit-norm vectors.
3020///
3021/// For pre-normalized vectors (||v|| = 1), cosine = dot(q, v) / ||q||.
3022/// Skips per-vector norm computation — ~40% less work than `batch_cosine_scores`.
3023#[inline]
3024pub fn batch_dot_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
3025    let n = scores.len();
3026    debug_assert!(vectors.len() >= n * dim);
3027    debug_assert_eq!(query.len(), dim);
3028
3029    if dim == 0 || n == 0 {
3030        return;
3031    }
3032
3033    let norm_q_sq = dot_product_f32(query, query, dim);
3034    if norm_q_sq < f32::EPSILON {
3035        for s in scores.iter_mut() {
3036            *s = 0.0;
3037        }
3038        return;
3039    }
3040    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3041
3042    for i in 0..n {
3043        let vec = &vectors[i * dim..(i + 1) * dim];
3044        let dot = dot_product_f32(query, vec, dim);
3045        scores[i] = dot * inv_norm_q;
3046    }
3047}
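
// Usage sketch (illustrative values only): when the stored vectors are already
// unit-norm, `batch_dot_scores` reproduces the cosine scores without paying for
// per-vector norms.
#[cfg(test)]
mod batch_dot_scores_example {
    use super::{batch_cosine_scores, batch_dot_scores};

    #[test]
    fn matches_cosine_for_unit_norm_vectors() {
        let dim = 4;
        let query = [0.2f32, -0.4, 0.6, 0.1];
        // Two hand-normalized rows (||v|| = 1).
        let vectors = [
            0.5f32, 0.5, 0.5, 0.5, // all-equal direction, 1 / sqrt(4) per component
            1.0, 0.0, 0.0, 0.0, // axis-aligned
        ];

        let mut cosine = [0.0f32; 2];
        let mut dots = [0.0f32; 2];
        batch_cosine_scores(&query, &vectors, dim, &mut cosine);
        batch_dot_scores(&query, &vectors, dim, &mut dots);

        for (c, d) in cosine.iter().zip(&dots) {
            assert!((c - d).abs() < 1e-3);
        }
    }
}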
3048
3049/// Batch dot-product scoring: f32 query vs N contiguous f16 unit-norm vectors.
3050///
3051/// For pre-normalized vectors (||v|| = 1), cosine = dot(q, v) / ||q||.
3052/// Uses F16C/NEON hardware conversion + dot-only kernel.
3053#[inline]
3054pub fn batch_dot_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
3055    let n = scores.len();
3056    if dim == 0 || n == 0 {
3057        return;
3058    }
3059
3060    let norm_q_sq = dot_product_f32(query, query, dim);
3061    if norm_q_sq < f32::EPSILON {
3062        for s in scores.iter_mut() {
3063            *s = 0.0;
3064        }
3065        return;
3066    }
3067    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3068
3069    let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
3070    let vec_bytes = dim * 2;
3071    debug_assert!(vectors_raw.len() >= n * vec_bytes);
3072    debug_assert!(
3073        (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
3074        "f16 vector data not 2-byte aligned"
3075    );
3076
3077    for i in 0..n {
3078        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3079        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3080        let dot = dot_product_f16_quant(&query_f16, f16_slice, dim);
3081        scores[i] = dot * inv_norm_q;
3082    }
3083}
3084
3085/// Batch dot-product scoring: f32 query vs N contiguous u8 unit-norm vectors.
3086///
3087/// For pre-normalized vectors (||v|| = 1), cosine = dot(q, v) / ||q||.
3088/// Uses NEON/SSE widening chain for u8→f32 conversion + dot-only kernel.
3089#[inline]
3090pub fn batch_dot_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
3091    let n = scores.len();
3092    if dim == 0 || n == 0 {
3093        return;
3094    }
3095
3096    let norm_q_sq = dot_product_f32(query, query, dim);
3097    if norm_q_sq < f32::EPSILON {
3098        for s in scores.iter_mut() {
3099            *s = 0.0;
3100        }
3101        return;
3102    }
3103    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
3104
3105    debug_assert!(vectors_raw.len() >= n * dim);
3106
3107    for i in 0..n {
3108        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3109        let dot = dot_product_u8_quant(query, u8_slice, dim);
3110        scores[i] = dot * inv_norm_q;
3111    }
3112}
3113
3114// ============================================================================
3115// Precomputed-norm batch scoring (avoids redundant query norm + f16 conversion)
3116// ============================================================================
3117
3118/// Batch cosine: f32 query vs N f32 vectors, with precomputed `inv_norm_q`.
3119#[inline]
3120pub fn batch_cosine_scores_precomp(
3121    query: &[f32],
3122    vectors: &[f32],
3123    dim: usize,
3124    scores: &mut [f32],
3125    inv_norm_q: f32,
3126) {
3127    let n = scores.len();
3128    debug_assert!(vectors.len() >= n * dim);
3129    for i in 0..n {
3130        let vec = &vectors[i * dim..(i + 1) * dim];
3131        let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
3132        scores[i] = if norm_v_sq < f32::EPSILON {
3133            0.0
3134        } else {
3135            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3136        };
3137    }
3138}
3139
3140/// Batch cosine: precomputed `inv_norm_q` + `query_f16` vs N f16 vectors (2-byte aligned).
3141#[inline]
3142pub fn batch_cosine_scores_f16_precomp(
3143    query_f16: &[u16],
3144    vectors_raw: &[u8],
3145    dim: usize,
3146    scores: &mut [f32],
3147    inv_norm_q: f32,
3148) {
3149    let n = scores.len();
3150    let vec_bytes = dim * 2;
3151    debug_assert!(vectors_raw.len() >= n * vec_bytes);
3152    for i in 0..n {
3153        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3154        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3155        let (dot, norm_v_sq) = fused_dot_norm_f16(query_f16, f16_slice, dim);
3156        scores[i] = if norm_v_sq < f32::EPSILON {
3157            0.0
3158        } else {
3159            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3160        };
3161    }
3162}
3163
3164/// Batch cosine: precomputed `inv_norm_q` vs N u8 vectors.
3165#[inline]
3166pub fn batch_cosine_scores_u8_precomp(
3167    query: &[f32],
3168    vectors_raw: &[u8],
3169    dim: usize,
3170    scores: &mut [f32],
3171    inv_norm_q: f32,
3172) {
3173    let n = scores.len();
3174    debug_assert!(vectors_raw.len() >= n * dim);
3175    for i in 0..n {
3176        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3177        let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
3178        scores[i] = if norm_v_sq < f32::EPSILON {
3179            0.0
3180        } else {
3181            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
3182        };
3183    }
3184}
3185
3186/// Batch dot-product: precomputed `inv_norm_q` vs N f32 unit-norm vectors.
3187#[inline]
3188pub fn batch_dot_scores_precomp(
3189    query: &[f32],
3190    vectors: &[f32],
3191    dim: usize,
3192    scores: &mut [f32],
3193    inv_norm_q: f32,
3194) {
3195    let n = scores.len();
3196    debug_assert!(vectors.len() >= n * dim);
3197    for i in 0..n {
3198        let vec = &vectors[i * dim..(i + 1) * dim];
3199        scores[i] = dot_product_f32(query, vec, dim) * inv_norm_q;
3200    }
3201}
3202
3203/// Batch dot-product: precomputed `inv_norm_q` + `query_f16` vs N f16 unit-norm vectors (2-byte aligned).
3204#[inline]
3205pub fn batch_dot_scores_f16_precomp(
3206    query_f16: &[u16],
3207    vectors_raw: &[u8],
3208    dim: usize,
3209    scores: &mut [f32],
3210    inv_norm_q: f32,
3211) {
3212    let n = scores.len();
3213    let vec_bytes = dim * 2;
3214    debug_assert!(vectors_raw.len() >= n * vec_bytes);
3215    for i in 0..n {
3216        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3217        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3218        scores[i] = dot_product_f16_quant(query_f16, f16_slice, dim) * inv_norm_q;
3219    }
3220}
3221
3222/// Batch dot-product: precomputed `inv_norm_q` vs N u8 unit-norm vectors.
3223#[inline]
3224pub fn batch_dot_scores_u8_precomp(
3225    query: &[f32],
3226    vectors_raw: &[u8],
3227    dim: usize,
3228    scores: &mut [f32],
3229    inv_norm_q: f32,
3230) {
3231    let n = scores.len();
3232    debug_assert!(vectors_raw.len() >= n * dim);
3233    for i in 0..n {
3234        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3235        scores[i] = dot_product_u8_quant(query, u8_slice, dim) * inv_norm_q;
3236    }
3237}
3238
3239/// Compute cosine similarity between two f32 vectors with SIMD acceleration
3240///
3241/// Returns dot(a,b) / (||a|| * ||b||), range [-1, 1]
3242/// Returns 0.0 if either vector has zero norm.
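///
/// Illustrative example (kept as `ignore` rather than a doctest):
///
/// ```ignore
/// // Parallel vectors score ~1.0, orthogonal vectors score ~0.0.
/// assert!((cosine_similarity(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]) - 1.0).abs() < 1e-6);
/// assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
/// ```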
3243#[inline]
3244pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
3245    debug_assert_eq!(a.len(), b.len());
3246    let count = a.len();
3247
3248    if count == 0 {
3249        return 0.0;
3250    }
3251
3252    let dot = dot_product_f32(a, b, count);
3253    let norm_a_sq = dot_product_f32(a, a, count);
3254    let norm_b_sq = dot_product_f32(b, b, count);
3255
3256    let denom = (norm_a_sq * norm_b_sq).sqrt();
3257    if denom < f32::EPSILON {
3258        return 0.0;
3259    }
3260
3261    dot / denom
3262}
3263
3264/// Compute squared Euclidean distance between two f32 vectors with SIMD acceleration
3265///
3266/// Returns sum((a[i] - b[i])^2) for all i
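///
/// Illustrative example (kept as `ignore` rather than a doctest):
///
/// ```ignore
/// // (0 - 4)^2 + (3 - 0)^2 = 25.0
/// assert_eq!(squared_euclidean_distance(&[0.0, 3.0], &[4.0, 0.0]), 25.0);
/// ```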
3267#[inline]
3268pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
3269    debug_assert_eq!(a.len(), b.len());
3270    let count = a.len();
3271
3272    if count == 0 {
3273        return 0.0;
3274    }
3275
3276    #[cfg(target_arch = "aarch64")]
3277    {
3278        if neon::is_available() {
3279            return unsafe { squared_euclidean_neon(a, b, count) };
3280        }
3281    }
3282
3283    #[cfg(target_arch = "x86_64")]
3284    {
3285        if avx2::is_available() && is_x86_feature_detected!("fma") {
3286            return unsafe { squared_euclidean_avx2(a, b, count) };
3287        }
3288        if sse::is_available() {
3289            return unsafe { squared_euclidean_sse(a, b, count) };
3290        }
3291    }
3292
3293    // Scalar fallback
3294    a.iter()
3295        .zip(b.iter())
3296        .map(|(&x, &y)| {
3297            let d = x - y;
3298            d * d
3299        })
3300        .sum()
3301}
3302
3303#[cfg(target_arch = "aarch64")]
3304#[target_feature(enable = "neon")]
3305#[allow(unsafe_op_in_unsafe_fn)]
3306unsafe fn squared_euclidean_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
3307    use std::arch::aarch64::*;
3308
3309    let chunks16 = count / 16;
3310    let remainder = count % 16;
3311
3312    // 4 independent accumulators to hide FMA latency
3313    let mut acc0 = vdupq_n_f32(0.0);
3314    let mut acc1 = vdupq_n_f32(0.0);
3315    let mut acc2 = vdupq_n_f32(0.0);
3316    let mut acc3 = vdupq_n_f32(0.0);
3317
3318    for c in 0..chunks16 {
3319        let base = c * 16;
3320        let va0 = vld1q_f32(a.as_ptr().add(base));
3321        let vb0 = vld1q_f32(b.as_ptr().add(base));
3322        let d0 = vsubq_f32(va0, vb0);
3323        acc0 = vfmaq_f32(acc0, d0, d0);
3324
3325        let va1 = vld1q_f32(a.as_ptr().add(base + 4));
3326        let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
3327        let d1 = vsubq_f32(va1, vb1);
3328        acc1 = vfmaq_f32(acc1, d1, d1);
3329
3330        let va2 = vld1q_f32(a.as_ptr().add(base + 8));
3331        let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
3332        let d2 = vsubq_f32(va2, vb2);
3333        acc2 = vfmaq_f32(acc2, d2, d2);
3334
3335        let va3 = vld1q_f32(a.as_ptr().add(base + 12));
3336        let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
3337        let d3 = vsubq_f32(va3, vb3);
3338        acc3 = vfmaq_f32(acc3, d3, d3);
3339    }
3340
3341    // Combine accumulators and horizontal sum
3342    let combined = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
3343    let mut sum = vaddvq_f32(combined);
3344
3345    // Handle remainder
3346    let base = chunks16 * 16;
3347    for i in 0..remainder {
3348        let d = a[base + i] - b[base + i];
3349        sum += d * d;
3350    }
3351
3352    sum
3353}
3354
3355#[cfg(target_arch = "x86_64")]
3356#[target_feature(enable = "sse")]
3357#[allow(unsafe_op_in_unsafe_fn)]
3358unsafe fn squared_euclidean_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
3359    use std::arch::x86_64::*;
3360
3361    let chunks16 = count / 16;
3362    let remainder = count % 16;
3363
3364    // 4 independent accumulators to hide multiply latency
3365    let mut acc0 = _mm_setzero_ps();
3366    let mut acc1 = _mm_setzero_ps();
3367    let mut acc2 = _mm_setzero_ps();
3368    let mut acc3 = _mm_setzero_ps();
3369
3370    for c in 0..chunks16 {
3371        let base = c * 16;
3372
3373        let d0 = _mm_sub_ps(
3374            _mm_loadu_ps(a.as_ptr().add(base)),
3375            _mm_loadu_ps(b.as_ptr().add(base)),
3376        );
3377        acc0 = _mm_add_ps(acc0, _mm_mul_ps(d0, d0));
3378
3379        let d1 = _mm_sub_ps(
3380            _mm_loadu_ps(a.as_ptr().add(base + 4)),
3381            _mm_loadu_ps(b.as_ptr().add(base + 4)),
3382        );
3383        acc1 = _mm_add_ps(acc1, _mm_mul_ps(d1, d1));
3384
3385        let d2 = _mm_sub_ps(
3386            _mm_loadu_ps(a.as_ptr().add(base + 8)),
3387            _mm_loadu_ps(b.as_ptr().add(base + 8)),
3388        );
3389        acc2 = _mm_add_ps(acc2, _mm_mul_ps(d2, d2));
3390
3391        let d3 = _mm_sub_ps(
3392            _mm_loadu_ps(a.as_ptr().add(base + 12)),
3393            _mm_loadu_ps(b.as_ptr().add(base + 12)),
3394        );
3395        acc3 = _mm_add_ps(acc3, _mm_mul_ps(d3, d3));
3396    }
3397
3398    // Combine accumulators
3399    let combined = _mm_add_ps(_mm_add_ps(acc0, acc1), _mm_add_ps(acc2, acc3));
3400
3401    // Horizontal sum: [a, b, c, d] -> a + b + c + d
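    // shuf = [b, a, d, c], so sums = [a+b, a+b, c+d, c+d]; movehl then brings the
    // (c+d) pair into the low lanes so the final add_ss yields (a+b) + (c+d).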
3402    let shuf = _mm_shuffle_ps(combined, combined, 0b10_11_00_01);
3403    let sums = _mm_add_ps(combined, shuf);
3404    let shuf2 = _mm_movehl_ps(sums, sums);
3405    let final_sum = _mm_add_ss(sums, shuf2);
3406
3407    let mut sum = _mm_cvtss_f32(final_sum);
3408
3409    // Handle remainder
3410    let base = chunks16 * 16;
3411    for i in 0..remainder {
3412        let d = a[base + i] - b[base + i];
3413        sum += d * d;
3414    }
3415
3416    sum
3417}
3418
3419#[cfg(target_arch = "x86_64")]
3420#[target_feature(enable = "avx2", enable = "fma")]
3421#[allow(unsafe_op_in_unsafe_fn)]
3422unsafe fn squared_euclidean_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
3423    use std::arch::x86_64::*;
3424
3425    let chunks32 = count / 32;
3426    let remainder = count % 32;
3427
3428    // 4 independent accumulators to hide FMA latency
3429    let mut acc0 = _mm256_setzero_ps();
3430    let mut acc1 = _mm256_setzero_ps();
3431    let mut acc2 = _mm256_setzero_ps();
3432    let mut acc3 = _mm256_setzero_ps();
3433
3434    for c in 0..chunks32 {
3435        let base = c * 32;
3436
3437        let d0 = _mm256_sub_ps(
3438            _mm256_loadu_ps(a.as_ptr().add(base)),
3439            _mm256_loadu_ps(b.as_ptr().add(base)),
3440        );
3441        acc0 = _mm256_fmadd_ps(d0, d0, acc0);
3442
3443        let d1 = _mm256_sub_ps(
3444            _mm256_loadu_ps(a.as_ptr().add(base + 8)),
3445            _mm256_loadu_ps(b.as_ptr().add(base + 8)),
3446        );
3447        acc1 = _mm256_fmadd_ps(d1, d1, acc1);
3448
3449        let d2 = _mm256_sub_ps(
3450            _mm256_loadu_ps(a.as_ptr().add(base + 16)),
3451            _mm256_loadu_ps(b.as_ptr().add(base + 16)),
3452        );
3453        acc2 = _mm256_fmadd_ps(d2, d2, acc2);
3454
3455        let d3 = _mm256_sub_ps(
3456            _mm256_loadu_ps(a.as_ptr().add(base + 24)),
3457            _mm256_loadu_ps(b.as_ptr().add(base + 24)),
3458        );
3459        acc3 = _mm256_fmadd_ps(d3, d3, acc3);
3460    }
3461
3462    // Combine accumulators
3463    let combined = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
3464
3465    // Horizontal sum of 8 floats
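    // Fold the upper 128-bit lane onto the lower lane, then finish with the same
    // 4-wide shuffle/movehl reduction used by the SSE kernel above.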
3466    let high = _mm256_extractf128_ps(combined, 1);
3467    let low = _mm256_castps256_ps128(combined);
3468    let sum128 = _mm_add_ps(low, high);
3469
3470    let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
3471    let sums = _mm_add_ps(sum128, shuf);
3472    let shuf2 = _mm_movehl_ps(sums, sums);
3473    let final_sum = _mm_add_ss(sums, shuf2);
3474
3475    let mut sum = _mm_cvtss_f32(final_sum);
3476
3477    // Handle remainder
3478    let base = chunks32 * 32;
3479    for i in 0..remainder {
3480        let d = a[base + i] - b[base + i];
3481        sum += d * d;
3482    }
3483
3484    sum
3485}
3486
3487/// Batch compute squared Euclidean distances from one query to multiple vectors
3488///
3489/// Returns distances[i] = squared_euclidean_distance(query, vectors[i])
3490/// More efficient than calling `squared_euclidean_distance` in a loop because the
3491/// CPU-feature dispatch is done once per batch and the query stays hot in cache.
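///
/// Illustrative example (kept as `ignore` rather than a doctest):
///
/// ```ignore
/// let query = vec![0.0f32; 4];
/// let vectors = vec![vec![1.0f32; 4], vec![2.0f32; 4]];
/// let mut distances = vec![0.0f32; vectors.len()];
/// batch_squared_euclidean_distances(&query, &vectors, &mut distances);
/// // distances == [4.0, 16.0]
/// ```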
3492#[inline]
3493pub fn batch_squared_euclidean_distances(
3494    query: &[f32],
3495    vectors: &[Vec<f32>],
3496    distances: &mut [f32],
3497) {
3498    debug_assert_eq!(vectors.len(), distances.len());
3499
3500    #[cfg(target_arch = "x86_64")]
3501    {
3502        if avx2::is_available() && is_x86_feature_detected!("fma") {
3503            for (i, vec) in vectors.iter().enumerate() {
3504                distances[i] = unsafe { squared_euclidean_avx2(query, vec, query.len()) };
3505            }
3506            return;
3507        }
3508    }
3509
3510    // Fallback to individual calls
3511    for (i, vec) in vectors.iter().enumerate() {
3512        distances[i] = squared_euclidean_distance(query, vec);
3513    }
3514}
3515
3516#[cfg(test)]
3517mod tests {
3518    use super::*;
3519
3520    #[test]
3521    fn test_unpack_8bit() {
3522        let input: Vec<u8> = (0..128).collect();
3523        let mut output = vec![0u32; 128];
3524        unpack_8bit(&input, &mut output, 128);
3525
3526        for (i, &v) in output.iter().enumerate() {
3527            assert_eq!(v, i as u32);
3528        }
3529    }
3530
3531    #[test]
3532    fn test_unpack_16bit() {
3533        let mut input = vec![0u8; 256];
3534        for i in 0..128 {
3535            let val = (i * 100) as u16;
3536            input[i * 2] = val as u8;
3537            input[i * 2 + 1] = (val >> 8) as u8;
3538        }
3539
3540        let mut output = vec![0u32; 128];
3541        unpack_16bit(&input, &mut output, 128);
3542
3543        for (i, &v) in output.iter().enumerate() {
3544            assert_eq!(v, (i * 100) as u32);
3545        }
3546    }
3547
3548    #[test]
3549    fn test_unpack_32bit() {
3550        let mut input = vec![0u8; 512];
3551        for i in 0..128 {
3552            let val = (i * 1000) as u32;
3553            let bytes = val.to_le_bytes();
3554            input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
3555        }
3556
3557        let mut output = vec![0u32; 128];
3558        unpack_32bit(&input, &mut output, 128);
3559
3560        for (i, &v) in output.iter().enumerate() {
3561            assert_eq!(v, (i * 1000) as u32);
3562        }
3563    }
3564
3565    #[test]
3566    fn test_delta_decode() {
3567        // doc_ids: [10, 15, 20, 30, 50]
3568        // gaps: [5, 5, 10, 20]
3569        // deltas (gap-1): [4, 4, 9, 19]
3570        let deltas = vec![4u32, 4, 9, 19];
3571        let mut output = vec![0u32; 5];
3572
3573        delta_decode(&mut output, &deltas, 10, 5);
3574
3575        assert_eq!(output, vec![10, 15, 20, 30, 50]);
3576    }
3577
3578    #[test]
3579    fn test_add_one() {
3580        let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
3581        add_one(&mut values, 8);
3582
3583        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
3584    }
3585
3586    #[test]
3587    fn test_bits_needed() {
3588        assert_eq!(bits_needed(0), 0);
3589        assert_eq!(bits_needed(1), 1);
3590        assert_eq!(bits_needed(2), 2);
3591        assert_eq!(bits_needed(3), 2);
3592        assert_eq!(bits_needed(4), 3);
3593        assert_eq!(bits_needed(255), 8);
3594        assert_eq!(bits_needed(256), 9);
3595        assert_eq!(bits_needed(u32::MAX), 32);
3596    }
3597
3598    #[test]
3599    fn test_unpack_8bit_delta_decode() {
3600        // doc_ids: [10, 15, 20, 30, 50]
3601        // gaps: [5, 5, 10, 20]
3602        // deltas (gap-1): [4, 4, 9, 19] stored as u8
3603        let input: Vec<u8> = vec![4, 4, 9, 19];
3604        let mut output = vec![0u32; 5];
3605
3606        unpack_8bit_delta_decode(&input, &mut output, 10, 5);
3607
3608        assert_eq!(output, vec![10, 15, 20, 30, 50]);
3609    }
3610
3611    #[test]
3612    fn test_unpack_16bit_delta_decode() {
3613        // doc_ids: [100, 600, 1100, 2100, 4100]
3614        // gaps: [500, 500, 1000, 2000]
3615        // deltas (gap-1): [499, 499, 999, 1999] stored as u16
3616        let mut input = vec![0u8; 8];
3617        for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
3618            input[i * 2] = delta as u8;
3619            input[i * 2 + 1] = (delta >> 8) as u8;
3620        }
3621        let mut output = vec![0u32; 5];
3622
3623        unpack_16bit_delta_decode(&input, &mut output, 100, 5);
3624
3625        assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
3626    }
3627
3628    #[test]
3629    fn test_fused_vs_separate_8bit() {
3630        // Fused and separate paths must agree: 127 packed deltas decode to 128 values
3631        let input: Vec<u8> = (0..127).collect();
3632        let first_value = 1000u32;
3633        let count = 128;
3634
3635        // Separate: unpack then delta_decode
3636        let mut unpacked = vec![0u32; 128];
3637        unpack_8bit(&input, &mut unpacked, 127);
3638        let mut separate_output = vec![0u32; 128];
3639        delta_decode(&mut separate_output, &unpacked, first_value, count);
3640
3641        // Fused
3642        let mut fused_output = vec![0u32; 128];
3643        unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
3644
3645        assert_eq!(separate_output, fused_output);
3646    }
3647
3648    #[test]
3649    fn test_round_bit_width() {
3650        assert_eq!(round_bit_width(0), 0);
3651        assert_eq!(round_bit_width(1), 8);
3652        assert_eq!(round_bit_width(5), 8);
3653        assert_eq!(round_bit_width(8), 8);
3654        assert_eq!(round_bit_width(9), 16);
3655        assert_eq!(round_bit_width(12), 16);
3656        assert_eq!(round_bit_width(16), 16);
3657        assert_eq!(round_bit_width(17), 32);
3658        assert_eq!(round_bit_width(24), 32);
3659        assert_eq!(round_bit_width(32), 32);
3660    }
3661
3662    #[test]
3663    fn test_rounded_bitwidth_from_exact() {
3664        assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
3665        assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
3666        assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
3667        assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
3668        assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
3669        assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
3670        assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
3671    }
3672
3673    #[test]
3674    fn test_pack_unpack_rounded_8bit() {
3675        let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
3676        let mut packed = vec![0u8; 128];
3677
3678        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
3679        assert_eq!(bytes_written, 128);
3680
3681        let mut unpacked = vec![0u32; 128];
3682        unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);
3683
3684        assert_eq!(values, unpacked);
3685    }
3686
3687    #[test]
3688    fn test_pack_unpack_rounded_16bit() {
3689        let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
3690        let mut packed = vec![0u8; 256];
3691
3692        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
3693        assert_eq!(bytes_written, 256);
3694
3695        let mut unpacked = vec![0u32; 128];
3696        unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);
3697
3698        assert_eq!(values, unpacked);
3699    }
3700
3701    #[test]
3702    fn test_pack_unpack_rounded_32bit() {
3703        let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
3704        let mut packed = vec![0u8; 512];
3705
3706        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
3707        assert_eq!(bytes_written, 512);
3708
3709        let mut unpacked = vec![0u32; 128];
3710        unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);
3711
3712        assert_eq!(values, unpacked);
3713    }
3714
3715    #[test]
3716    fn test_unpack_rounded_delta_decode() {
3717        // Test 8-bit rounded delta decode
3718        // doc_ids: [10, 15, 20, 30, 50]
3719        // gaps: [5, 5, 10, 20]
3720        // deltas (gap-1): [4, 4, 9, 19] stored as u8
3721        let input: Vec<u8> = vec![4, 4, 9, 19];
3722        let mut output = vec![0u32; 5];
3723
3724        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);
3725
3726        assert_eq!(output, vec![10, 15, 20, 30, 50]);
3727    }
3728
3729    #[test]
3730    fn test_unpack_rounded_delta_decode_zero() {
3731        // Zero bit width stores no delta bytes, so gaps default to 1 (consecutive doc IDs)
3732        let input: Vec<u8> = vec![];
3733        let mut output = vec![0u32; 5];
3734
3735        unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);
3736
3737        assert_eq!(output, vec![100, 101, 102, 103, 104]);
3738    }
3739
3740    // ========================================================================
3741    // Sparse Vector SIMD Tests
3742    // ========================================================================
3743
3744    #[test]
3745    fn test_dequantize_uint8() {
3746        let input: Vec<u8> = vec![0, 128, 255, 64, 192];
3747        let mut output = vec![0.0f32; 5];
3748        let scale = 0.1;
3749        let min_val = 1.0;
3750
3751        dequantize_uint8(&input, &mut output, scale, min_val, 5);
3752
3753        // Expected: input[i] * scale + min_val
3754        assert!((output[0] - 1.0).abs() < 1e-6); // 0 * 0.1 + 1.0 = 1.0
3755        assert!((output[1] - 13.8).abs() < 1e-6); // 128 * 0.1 + 1.0 = 13.8
3756        assert!((output[2] - 26.5).abs() < 1e-6); // 255 * 0.1 + 1.0 = 26.5
3757        assert!((output[3] - 7.4).abs() < 1e-6); // 64 * 0.1 + 1.0 = 7.4
3758        assert!((output[4] - 20.2).abs() < 1e-6); // 192 * 0.1 + 1.0 = 20.2
3759    }
3760
3761    #[test]
3762    fn test_dequantize_uint8_large() {
3763        // Test with 128 values (full SIMD block)
3764        let input: Vec<u8> = (0..128).collect();
3765        let mut output = vec![0.0f32; 128];
3766        let scale = 2.0;
3767        let min_val = -10.0;
3768
3769        dequantize_uint8(&input, &mut output, scale, min_val, 128);
3770
3771        for (i, &out) in output.iter().enumerate().take(128) {
3772            let expected = i as f32 * scale + min_val;
3773            assert!(
3774                (out - expected).abs() < 1e-5,
3775                "Mismatch at {}: expected {}, got {}",
3776                i,
3777                expected,
3778                out
3779            );
3780        }
3781    }
3782
3783    #[test]
3784    fn test_dot_product_f32() {
3785        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
3786        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
3787
3788        let result = dot_product_f32(&a, &b, 5);
3789
3790        // Expected: 1*2 + 2*3 + 3*4 + 4*5 + 5*6 = 2 + 6 + 12 + 20 + 30 = 70
3791        assert!((result - 70.0).abs() < 1e-5);
3792    }
3793
3794    #[test]
3795    fn test_dot_product_f32_large() {
3796        // Test with 128 values
3797        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
3798        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
3799
3800        let result = dot_product_f32(&a, &b, 128);
3801
3802        // Compute expected
3803        let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
3804        assert!(
3805            (result - expected).abs() < 1e-3,
3806            "Expected {}, got {}",
3807            expected,
3808            result
3809        );
3810    }
3811
3812    #[test]
3813    fn test_max_f32() {
3814        let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
3815        let result = max_f32(&values, 6);
3816        assert!((result - 9.0).abs() < 1e-6);
3817    }
3818
3819    #[test]
3820    fn test_max_f32_large() {
3821        // Test with 128 values, max at position 77
3822        let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
3823        values[77] = 1000.0;
3824
3825        let result = max_f32(&values, 128);
3826        assert!((result - 1000.0).abs() < 1e-5);
3827    }
3828
3829    #[test]
3830    fn test_max_f32_negative() {
3831        let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
3832        let result = max_f32(&values, 5);
3833        assert!((result - (-1.0)).abs() < 1e-6);
3834    }
3835
3836    #[test]
3837    fn test_max_f32_empty() {
3838        let values: Vec<f32> = vec![];
3839        let result = max_f32(&values, 0);
3840        assert_eq!(result, f32::NEG_INFINITY);
3841    }
3842
3843    #[test]
3844    fn test_fused_dot_norm() {
3845        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3846        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3847        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3848
3849        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3850        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3851        assert!(
3852            (dot - expected_dot).abs() < 1e-5,
3853            "dot: expected {}, got {}",
3854            expected_dot,
3855            dot
3856        );
3857        assert!(
3858            (norm_b - expected_norm).abs() < 1e-5,
3859            "norm: expected {}, got {}",
3860            expected_norm,
3861            norm_b
3862        );
3863    }
3864
3865    #[test]
3866    fn test_fused_dot_norm_large() {
3867        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
3868        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect();
3869        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3870
3871        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3872        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3873        assert!(
3874            (dot - expected_dot).abs() < 1.0,
3875            "dot: expected {}, got {}",
3876            expected_dot,
3877            dot
3878        );
3879        assert!(
3880            (norm_b - expected_norm).abs() < 1.0,
3881            "norm: expected {}, got {}",
3882            expected_norm,
3883            norm_b
3884        );
3885    }
3886
3887    #[test]
3888    fn test_batch_cosine_scores() {
3889        // 4 vectors of dim 3
3890        let query = vec![1.0f32, 0.0, 0.0];
3891        let vectors = vec![
3892            1.0, 0.0, 0.0, // identical to query
3893            0.0, 1.0, 0.0, // orthogonal
3894            -1.0, 0.0, 0.0, // opposite
3895            0.5, 0.5, 0.0, // 45 degrees
3896        ];
3897        let mut scores = vec![0f32; 4];
3898        batch_cosine_scores(&query, &vectors, 3, &mut scores);
3899
3900        assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
3901        assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
3902        assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
3903        let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
3904        assert!(
3905            (scores[3] - expected_45).abs() < 1e-5,
3906            "45deg: expected {}, got {}",
3907            expected_45,
3908            scores[3]
3909        );
3910    }
3911
3912    #[test]
3913    fn test_batch_cosine_scores_matches_individual() {
3914        let query: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
3915        let n = 50;
3916        let dim = 128;
3917        let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 7 + 3) as f32) * 0.01).collect();
3918
3919        let mut batch_scores = vec![0f32; n];
3920        batch_cosine_scores(&query, &vectors, dim, &mut batch_scores);
3921
3922        for i in 0..n {
3923            let vec_i = &vectors[i * dim..(i + 1) * dim];
3924            let individual = cosine_similarity(&query, vec_i);
3925            assert!(
3926                (batch_scores[i] - individual).abs() < 1e-5,
3927                "vec {}: batch={}, individual={}",
3928                i,
3929                batch_scores[i],
3930                individual
3931            );
3932        }
3933    }
3934
3935    #[test]
3936    fn test_batch_cosine_scores_empty() {
3937        let query = vec![1.0f32, 2.0, 3.0];
3938        let vectors: Vec<f32> = vec![];
3939        let mut scores: Vec<f32> = vec![];
3940        batch_cosine_scores(&query, &vectors, 3, &mut scores);
3941        assert!(scores.is_empty());
3942    }
3943
3944    #[test]
3945    fn test_batch_cosine_scores_zero_query() {
3946        let query = vec![0.0f32, 0.0, 0.0];
3947        let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
3948        let mut scores = vec![0f32; 2];
3949        batch_cosine_scores(&query, &vectors, 3, &mut scores);
3950        assert_eq!(scores[0], 0.0);
3951        assert_eq!(scores[1], 0.0);
3952    }
3953
3954    #[test]
3955    fn test_squared_euclidean_distance() {
3956        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3957        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3958        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3959        let result = squared_euclidean_distance(&a, &b);
3960        assert!(
3961            (result - expected).abs() < 1e-5,
3962            "expected {}, got {}",
3963            expected,
3964            result
3965        );
3966    }
3967
3968    #[test]
3969    fn test_squared_euclidean_distance_large() {
3970        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
3971        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
3972        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3973        let result = squared_euclidean_distance(&a, &b);
3974        assert!(
3975            (result - expected).abs() < 1e-3,
3976            "expected {}, got {}",
3977            expected,
3978            result
3979        );
3980    }
3981
3982    // ================================================================
3983    // f16 conversion tests
3984    // ================================================================
3985
3986    #[test]
3987    fn test_f16_roundtrip_normal() {
3988        for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 0.333, 65504.0] {
3989            let h = f32_to_f16(v);
3990            let back = f16_to_f32(h);
3991            let err = (back - v).abs() / v.abs().max(1e-6);
3992            assert!(
3993                err < 0.002,
3994                "f16 roundtrip {v} → {h:#06x} → {back}, rel err {err}"
3995            );
3996        }
3997    }
3998
3999    #[test]
4000    fn test_f16_special() {
4001        // Zero
4002        assert_eq!(f16_to_f32(f32_to_f16(0.0)), 0.0);
4003        // Negative zero
4004        assert_eq!(f32_to_f16(-0.0), 0x8000);
4005        // Infinity
4006        assert!(f16_to_f32(f32_to_f16(f32::INFINITY)).is_infinite());
4007        // NaN
4008        assert!(f16_to_f32(f32_to_f16(f32::NAN)).is_nan());
4009    }
4010
4011    #[test]
4012    fn test_f16_embedding_range() {
4013        // Typical embedding values in [-1, 1]
4014        let values: Vec<f32> = (-100..=100).map(|i| i as f32 / 100.0).collect();
4015        for &v in &values {
4016            let back = f16_to_f32(f32_to_f16(v));
4017            assert!((back - v).abs() < 0.001, "f16 error for {v}: got {back}");
4018        }
4019    }
4020
4021    // ================================================================
4022    // u8 conversion tests
4023    // ================================================================
4024
4025    #[test]
4026    fn test_u8_roundtrip() {
4027        // Boundary values
4028        assert_eq!(f32_to_u8_saturating(-1.0), 0);
4029        assert_eq!(f32_to_u8_saturating(1.0), 255);
4030        assert_eq!(f32_to_u8_saturating(0.0), 127); // ~127.5 truncated
4031
4032        // Saturation
4033        assert_eq!(f32_to_u8_saturating(-2.0), 0);
4034        assert_eq!(f32_to_u8_saturating(2.0), 255);
4035    }
4036
4037    #[test]
4038    fn test_u8_dequantize() {
4039        assert!((u8_to_f32(0) - (-1.0)).abs() < 0.01);
4040        assert!((u8_to_f32(255) - 1.0).abs() < 0.01);
4041        assert!((u8_to_f32(127) - 0.0).abs() < 0.01);
4042    }
4043
4044    // ================================================================
4045    // Batch scoring tests for quantized vectors
4046    // ================================================================
4047
4048    #[test]
4049    fn test_batch_cosine_scores_f16() {
4050        let query = vec![0.6f32, 0.8, 0.0, 0.0];
4051        let dim = 4;
4052        let vecs_f32 = vec![
4053            0.6f32, 0.8, 0.0, 0.0, // identical to query
4054            0.0, 0.0, 0.6, 0.8, // orthogonal
4055        ];
4056
4057        // Quantize to f16
4058        let mut f16_buf = vec![0u16; 8];
4059        batch_f32_to_f16(&vecs_f32, &mut f16_buf);
4060        let raw: &[u8] =
4061            unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
4062
4063        let mut scores = vec![0f32; 2];
4064        batch_cosine_scores_f16(&query, raw, dim, &mut scores);
4065
4066        assert!(
4067            (scores[0] - 1.0).abs() < 0.01,
4068            "identical vectors: {}",
4069            scores[0]
4070        );
4071        assert!(scores[1].abs() < 0.01, "orthogonal vectors: {}", scores[1]);
4072    }
4073
4074    #[test]
4075    fn test_batch_cosine_scores_u8() {
4076        let query = vec![0.6f32, 0.8, 0.0, 0.0];
4077        let dim = 4;
4078        let vecs_f32 = vec![
4079            0.6f32, 0.8, 0.0, 0.0, // ~identical to query
4080            -0.6, -0.8, 0.0, 0.0, // opposite
4081        ];
4082
4083        // Quantize to u8
4084        let mut u8_buf = vec![0u8; 8];
4085        batch_f32_to_u8(&vecs_f32, &mut u8_buf);
4086
4087        let mut scores = vec![0f32; 2];
4088        batch_cosine_scores_u8(&query, &u8_buf, dim, &mut scores);
4089
4090        assert!(scores[0] > 0.95, "similar vectors: {}", scores[0]);
4091        assert!(scores[1] < -0.95, "opposite vectors: {}", scores[1]);
4092    }
4093
4094    #[test]
4095    fn test_batch_cosine_scores_f16_large_dim() {
4096        // Test with typical embedding dimension
4097        let dim = 768;
4098        let query: Vec<f32> = (0..dim).map(|i| (i as f32 / dim as f32) - 0.5).collect();
4099        let vec2: Vec<f32> = query.iter().map(|x| x * 0.9 + 0.01).collect();
4100
4101        let mut all_vecs = query.clone();
4102        all_vecs.extend_from_slice(&vec2);
4103
4104        let mut f16_buf = vec![0u16; all_vecs.len()];
4105        batch_f32_to_f16(&all_vecs, &mut f16_buf);
4106        let raw: &[u8] =
4107            unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
4108
4109        let mut scores = vec![0f32; 2];
4110        batch_cosine_scores_f16(&query, raw, dim, &mut scores);
4111
4112        // Self-similarity should be ~1.0
4113        assert!((scores[0] - 1.0).abs() < 0.01, "self-sim: {}", scores[0]);
4114        // High similarity with scaled version
4115        assert!(scores[1] > 0.99, "scaled-sim: {}", scores[1]);
4116    }
4117}
4118
4119// ============================================================================
4120// SIMD-accelerated linear scan for sorted u32 slices (within-block seek)
4121// ============================================================================
4122
4123/// Find index of first element >= `target` in a sorted `u32` slice.
4124///
4125/// Equivalent to `slice.partition_point(|&d| d < target)` but uses SIMD to compare
4126/// 4 elements per vector op (unrolled to 16 per iteration). Faster than binary
4127/// search for slices ≤ 256 elements because it avoids binary search's serial
4128/// data-dependency chain (~8-10 cycles per probe vs ~1-2 cycles per SIMD compare).
4129///
4130/// Returns `slice.len()` if no element >= `target`.
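///
/// Illustrative example (kept as `ignore` rather than a doctest):
///
/// ```ignore
/// let block = [10u32, 20, 30, 40];
/// assert_eq!(find_first_ge_u32(&block, 25), 2); // first element >= 25 is 30
/// assert_eq!(find_first_ge_u32(&block, 10), 0);
/// assert_eq!(find_first_ge_u32(&block, 50), 4); // no element >= 50
/// ```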
4131#[inline]
4132pub fn find_first_ge_u32(slice: &[u32], target: u32) -> usize {
4133    #[cfg(target_arch = "aarch64")]
4134    {
4135        if neon::is_available() {
4136            return unsafe { find_first_ge_u32_neon(slice, target) };
4137        }
4138    }
4139
4140    #[cfg(target_arch = "x86_64")]
4141    {
4142        if sse::is_available() {
4143            return unsafe { find_first_ge_u32_sse(slice, target) };
4144        }
4145    }
4146
4147    // Scalar fallback (WASM, other architectures)
4148    slice.partition_point(|&d| d < target)
4149}
4150
4151#[cfg(target_arch = "aarch64")]
4152#[target_feature(enable = "neon")]
4153#[allow(unsafe_op_in_unsafe_fn)]
4154unsafe fn find_first_ge_u32_neon(slice: &[u32], target: u32) -> usize {
4155    use std::arch::aarch64::*;
4156
4157    let n = slice.len();
4158    let ptr = slice.as_ptr();
4159    let target_vec = vdupq_n_u32(target);
4160    // Bit positions for each lane: [1, 2, 4, 8]
4161    let bit_mask: uint32x4_t = core::mem::transmute([1u32, 2u32, 4u32, 8u32]);
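    // Each vcgeq lane is all-ones or all-zero; AND-ing with [1, 2, 4, 8] and summing
    // across lanes (vaddvq_u32) yields a 4-bit mask whose trailing_zeros() is the
    // index of the first lane that satisfied `>= target`.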
4162
4163    let chunks = n / 16;
4164    let mut base = 0usize;
4165
4166    // Process 16 elements per iteration (4 × 4-wide NEON compares)
4167    for _ in 0..chunks {
4168        let v0 = vld1q_u32(ptr.add(base));
4169        let v1 = vld1q_u32(ptr.add(base + 4));
4170        let v2 = vld1q_u32(ptr.add(base + 8));
4171        let v3 = vld1q_u32(ptr.add(base + 12));
4172
4173        let c0 = vcgeq_u32(v0, target_vec);
4174        let c1 = vcgeq_u32(v1, target_vec);
4175        let c2 = vcgeq_u32(v2, target_vec);
4176        let c3 = vcgeq_u32(v3, target_vec);
4177
4178        let m0 = vaddvq_u32(vandq_u32(c0, bit_mask));
4179        if m0 != 0 {
4180            return base + m0.trailing_zeros() as usize;
4181        }
4182        let m1 = vaddvq_u32(vandq_u32(c1, bit_mask));
4183        if m1 != 0 {
4184            return base + 4 + m1.trailing_zeros() as usize;
4185        }
4186        let m2 = vaddvq_u32(vandq_u32(c2, bit_mask));
4187        if m2 != 0 {
4188            return base + 8 + m2.trailing_zeros() as usize;
4189        }
4190        let m3 = vaddvq_u32(vandq_u32(c3, bit_mask));
4191        if m3 != 0 {
4192            return base + 12 + m3.trailing_zeros() as usize;
4193        }
4194        base += 16;
4195    }
4196
4197    // Process remaining 4 elements at a time
4198    while base + 4 <= n {
4199        let vals = vld1q_u32(ptr.add(base));
4200        let cmp = vcgeq_u32(vals, target_vec);
4201        let mask = vaddvq_u32(vandq_u32(cmp, bit_mask));
4202        if mask != 0 {
4203            return base + mask.trailing_zeros() as usize;
4204        }
4205        base += 4;
4206    }
4207
4208    // Scalar remainder (0-3 elements)
4209    while base < n {
4210        if *slice.get_unchecked(base) >= target {
4211            return base;
4212        }
4213        base += 1;
4214    }
4215    n
4216}
4217
4218#[cfg(target_arch = "x86_64")]
4219#[target_feature(enable = "sse2")]
4220#[allow(unsafe_op_in_unsafe_fn)]
4221unsafe fn find_first_ge_u32_sse(slice: &[u32], target: u32) -> usize {
4222    use std::arch::x86_64::*;
4223
4224    let n = slice.len();
4225    let ptr = slice.as_ptr();
4226
4227    // For unsigned >= comparison: XOR with 0x80000000 converts to signed domain
4228    let sign_flip = _mm_set1_epi32(i32::MIN);
4229    let target_xor = _mm_xor_si128(_mm_set1_epi32(target as i32), sign_flip);
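    // After the XOR, unsigned order maps onto signed order: e.g. 0x0000_0010 becomes
    // 0x8000_0010 (negative) while 0xFFFF_FFF0 becomes 0x7FFF_FFF0 (positive), so the
    // signed cmpeq/cmpgt pair below implements an unsigned `>=` test.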
4230
4231    let chunks = n / 16;
4232    let mut base = 0usize;
4233
4234    // Process 16 elements per iteration (4 × 4-wide SSE compares)
4235    for _ in 0..chunks {
4236        let v0 = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
4237        let v1 = _mm_xor_si128(
4238            _mm_loadu_si128(ptr.add(base + 4) as *const __m128i),
4239            sign_flip,
4240        );
4241        let v2 = _mm_xor_si128(
4242            _mm_loadu_si128(ptr.add(base + 8) as *const __m128i),
4243            sign_flip,
4244        );
4245        let v3 = _mm_xor_si128(
4246            _mm_loadu_si128(ptr.add(base + 12) as *const __m128i),
4247            sign_flip,
4248        );
4249
4250        // ge = eq | gt (in signed domain after XOR)
4251        let ge0 = _mm_or_si128(
4252            _mm_cmpeq_epi32(v0, target_xor),
4253            _mm_cmpgt_epi32(v0, target_xor),
4254        );
4255        let m0 = _mm_movemask_ps(_mm_castsi128_ps(ge0)) as u32;
4256        if m0 != 0 {
4257            return base + m0.trailing_zeros() as usize;
4258        }
4259
4260        let ge1 = _mm_or_si128(
4261            _mm_cmpeq_epi32(v1, target_xor),
4262            _mm_cmpgt_epi32(v1, target_xor),
4263        );
4264        let m1 = _mm_movemask_ps(_mm_castsi128_ps(ge1)) as u32;
4265        if m1 != 0 {
4266            return base + 4 + m1.trailing_zeros() as usize;
4267        }
4268
4269        let ge2 = _mm_or_si128(
4270            _mm_cmpeq_epi32(v2, target_xor),
4271            _mm_cmpgt_epi32(v2, target_xor),
4272        );
4273        let m2 = _mm_movemask_ps(_mm_castsi128_ps(ge2)) as u32;
4274        if m2 != 0 {
4275            return base + 8 + m2.trailing_zeros() as usize;
4276        }
4277
4278        let ge3 = _mm_or_si128(
4279            _mm_cmpeq_epi32(v3, target_xor),
4280            _mm_cmpgt_epi32(v3, target_xor),
4281        );
4282        let m3 = _mm_movemask_ps(_mm_castsi128_ps(ge3)) as u32;
4283        if m3 != 0 {
4284            return base + 12 + m3.trailing_zeros() as usize;
4285        }
4286        base += 16;
4287    }
4288
4289    // Process remaining 4 elements at a time
4290    while base + 4 <= n {
4291        let vals = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
4292        let ge = _mm_or_si128(
4293            _mm_cmpeq_epi32(vals, target_xor),
4294            _mm_cmpgt_epi32(vals, target_xor),
4295        );
4296        let mask = _mm_movemask_ps(_mm_castsi128_ps(ge)) as u32;
4297        if mask != 0 {
4298            return base + mask.trailing_zeros() as usize;
4299        }
4300        base += 4;
4301    }
4302
4303    // Scalar remainder (0-3 elements)
4304    while base < n {
4305        if *slice.get_unchecked(base) >= target {
4306            return base;
4307        }
4308        base += 1;
4309    }
4310    n
4311}
4312
4313#[cfg(test)]
4314mod find_first_ge_tests {
4315    use super::find_first_ge_u32;
4316
4317    #[test]
4318    fn test_find_first_ge_basic() {
4319        let data: Vec<u32> = (0..128).map(|i| i * 3).collect(); // [0, 3, 6, ..., 381]
4320        assert_eq!(find_first_ge_u32(&data, 0), 0);
4321        assert_eq!(find_first_ge_u32(&data, 1), 1); // first >= 1 is 3 at idx 1
4322        assert_eq!(find_first_ge_u32(&data, 3), 1);
4323        assert_eq!(find_first_ge_u32(&data, 4), 2); // first >= 4 is 6 at idx 2
4324        assert_eq!(find_first_ge_u32(&data, 381), 127);
4325        assert_eq!(find_first_ge_u32(&data, 382), 128); // past end
4326    }
4327
4328    #[test]
4329    fn test_find_first_ge_matches_partition_point() {
4330        let data: Vec<u32> = vec![1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75];
4331        for target in 0..80 {
4332            let expected = data.partition_point(|&d| d < target);
4333            let actual = find_first_ge_u32(&data, target);
4334            assert_eq!(actual, expected, "target={}", target);
4335        }
4336    }
4337
4338    #[test]
4339    fn test_find_first_ge_small_slices() {
4340        // Empty
4341        assert_eq!(find_first_ge_u32(&[], 5), 0);
4342        // Single element
4343        assert_eq!(find_first_ge_u32(&[10], 5), 0);
4344        assert_eq!(find_first_ge_u32(&[10], 10), 0);
4345        assert_eq!(find_first_ge_u32(&[10], 11), 1);
4346        // Three elements (< SIMD width)
4347        assert_eq!(find_first_ge_u32(&[2, 4, 6], 5), 2);
4348    }
4349
4350    #[test]
4351    fn test_find_first_ge_full_block() {
4352        // Simulate a full 128-entry block
4353        let data: Vec<u32> = (100..228).collect();
4354        assert_eq!(find_first_ge_u32(&data, 100), 0);
4355        assert_eq!(find_first_ge_u32(&data, 150), 50);
4356        assert_eq!(find_first_ge_u32(&data, 227), 127);
4357        assert_eq!(find_first_ge_u32(&data, 228), 128);
4358        assert_eq!(find_first_ge_u32(&data, 99), 0);
4359    }
4360
4361    #[test]
4362    fn test_find_first_ge_u32_max() {
4363        // Test with large u32 values (unsigned correctness)
4364        let data = vec![u32::MAX - 10, u32::MAX - 5, u32::MAX - 1, u32::MAX];
4365        assert_eq!(find_first_ge_u32(&data, u32::MAX - 10), 0);
4366        assert_eq!(find_first_ge_u32(&data, u32::MAX - 7), 1);
4367        assert_eq!(find_first_ge_u32(&data, u32::MAX), 3);
4368    }
4369}