
hermes_core/structures/simd.rs

1//! Shared SIMD-accelerated functions for posting list compression
2//!
3//! This module provides platform-optimized implementations for common operations:
4//! - **Unpacking**: Convert packed 8/16/32-bit values to u32 arrays
5//! - **Delta decoding**: Prefix sum for converting deltas to absolute values
6//! - **Add one**: Increment all values in an array (for TF decoding)
7//!
8//! Supports:
9//! - **NEON** on aarch64 (Apple Silicon, ARM servers)
10//! - **SSE/SSE4.1** on x86_64 (Intel/AMD)
11//! - **Scalar fallback** for other architectures
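//!
//! # Example
//!
//! A minimal, illustrative sketch of how these helpers fit together when
//! decoding one block of a posting list (doc-ID deltas stored as `gap - 1`,
//! term frequencies stored as `tf - 1`); the buffer sizes and values below are
//! made up for the example:
//!
//! ```ignore
//! let packed_deltas: &[u8] = &[0, 2, 0];      // three 8-bit deltas (gap - 1)
//! let mut doc_ids = [0u32; 4];
//! // Fused unpack + prefix sum; the first doc ID is 10.
//! unpack_8bit_delta_decode(packed_deltas, &mut doc_ids, 10, 4);
//! assert_eq!(doc_ids, [10, 11, 14, 15]);
//!
//! let mut tfs = [0u32, 2, 4, 1];              // stored as tf - 1
//! add_one(&mut tfs, 4);
//! assert_eq!(tfs, [1, 3, 5, 2]);
//! ```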
12
13// ============================================================================
14// NEON intrinsics for aarch64 (Apple Silicon, ARM servers)
15// ============================================================================
16
17#[cfg(target_arch = "aarch64")]
18#[allow(unsafe_op_in_unsafe_fn)]
19mod neon {
20    use std::arch::aarch64::*;
21
22    /// SIMD unpack for 8-bit values using NEON
23    #[target_feature(enable = "neon")]
24    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
25        let chunks = count / 16;
26        let remainder = count % 16;
27
28        for chunk in 0..chunks {
29            let base = chunk * 16;
30            let in_ptr = input.as_ptr().add(base);
31
32            // Load 16 bytes
33            let bytes = vld1q_u8(in_ptr);
34
35            // Widen u8 -> u16 -> u32
36            let low8 = vget_low_u8(bytes);
37            let high8 = vget_high_u8(bytes);
38
39            let low16 = vmovl_u8(low8);
40            let high16 = vmovl_u8(high8);
41
42            let v0 = vmovl_u16(vget_low_u16(low16));
43            let v1 = vmovl_u16(vget_high_u16(low16));
44            let v2 = vmovl_u16(vget_low_u16(high16));
45            let v3 = vmovl_u16(vget_high_u16(high16));
46
47            let out_ptr = output.as_mut_ptr().add(base);
48            vst1q_u32(out_ptr, v0);
49            vst1q_u32(out_ptr.add(4), v1);
50            vst1q_u32(out_ptr.add(8), v2);
51            vst1q_u32(out_ptr.add(12), v3);
52        }
53
54        // Handle remainder
55        let base = chunks * 16;
56        for i in 0..remainder {
57            output[base + i] = input[base + i] as u32;
58        }
59    }
60
61    /// SIMD unpack for 16-bit values using NEON
62    #[target_feature(enable = "neon")]
63    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
64        let chunks = count / 8;
65        let remainder = count % 8;
66
67        for chunk in 0..chunks {
68            let base = chunk * 8;
69            let in_ptr = input.as_ptr().add(base * 2) as *const u16;
70
71            let vals = vld1q_u16(in_ptr);
72            let low = vmovl_u16(vget_low_u16(vals));
73            let high = vmovl_u16(vget_high_u16(vals));
74
75            let out_ptr = output.as_mut_ptr().add(base);
76            vst1q_u32(out_ptr, low);
77            vst1q_u32(out_ptr.add(4), high);
78        }
79
80        // Handle remainder
81        let base = chunks * 8;
82        for i in 0..remainder {
83            let idx = (base + i) * 2;
84            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
85        }
86    }
87
88    /// SIMD unpack for 32-bit values using NEON (fast copy)
89    #[target_feature(enable = "neon")]
90    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
91        let chunks = count / 4;
92        let remainder = count % 4;
93
94        let in_ptr = input.as_ptr() as *const u32;
95        let out_ptr = output.as_mut_ptr();
96
97        for chunk in 0..chunks {
98            let vals = vld1q_u32(in_ptr.add(chunk * 4));
99            vst1q_u32(out_ptr.add(chunk * 4), vals);
100        }
101
102        // Handle remainder
103        let base = chunks * 4;
104        for i in 0..remainder {
105            let idx = (base + i) * 4;
106            output[base + i] =
107                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
108        }
109    }
110
111    /// SIMD prefix sum for 4 u32 values using NEON
112    /// Input:  [a, b, c, d]
113    /// Output: [a, a+b, a+b+c, a+b+c+d]
114    #[inline]
115    #[target_feature(enable = "neon")]
116    unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
117        // Step 1: shift by 1 and add
118        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
119        let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
120        let sum1 = vaddq_u32(v, shifted1);
121
122        // Step 2: shift by 2 and add
123        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
124        let shifted2 = vextq_u32(vdupq_n_u32(0), sum1, 2);
125        vaddq_u32(sum1, shifted2)
126    }
127
128    /// SIMD delta decode: convert deltas to absolute doc IDs
129    /// deltas[i] stores (gap - 1), output[i] = first + sum(gaps[0..i])
130    /// Uses NEON SIMD prefix sum for high throughput
131    #[target_feature(enable = "neon")]
132    pub unsafe fn delta_decode(
133        output: &mut [u32],
134        deltas: &[u32],
135        first_doc_id: u32,
136        count: usize,
137    ) {
138        if count == 0 {
139            return;
140        }
141
142        output[0] = first_doc_id;
143        if count == 1 {
144            return;
145        }
146
147        let ones = vdupq_n_u32(1);
148        let mut carry = vdupq_n_u32(first_doc_id);
149
150        let full_groups = (count - 1) / 4;
151        let remainder = (count - 1) % 4;
152
153        for group in 0..full_groups {
154            let base = group * 4;
155
156            // Load 4 deltas and add 1 (since we store gap-1)
157            let d = vld1q_u32(deltas[base..].as_ptr());
158            let gaps = vaddq_u32(d, ones);
159
160            // Compute prefix sum within the 4 elements
161            let prefix = prefix_sum_4(gaps);
162
163            // Add carry (broadcast last element of previous group)
164            let result = vaddq_u32(prefix, carry);
165
166            // Store result
167            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
168
169            // Update carry: broadcast the last element for next iteration
170            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
171        }
172
173        // Handle remainder
174        let base = full_groups * 4;
175        let mut scalar_carry = vgetq_lane_u32(carry, 0);
176        for j in 0..remainder {
177            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
178            output[base + j + 1] = scalar_carry;
179        }
180    }
181
182    /// SIMD add 1 to all values (for TF decoding: stored as tf-1)
183    #[target_feature(enable = "neon")]
184    pub unsafe fn add_one(values: &mut [u32], count: usize) {
185        let ones = vdupq_n_u32(1);
186        let chunks = count / 4;
187        let remainder = count % 4;
188
189        for chunk in 0..chunks {
190            let base = chunk * 4;
191            let ptr = values.as_mut_ptr().add(base);
192            let v = vld1q_u32(ptr);
193            let result = vaddq_u32(v, ones);
194            vst1q_u32(ptr, result);
195        }
196
197        let base = chunks * 4;
198        for i in 0..remainder {
199            values[base + i] += 1;
200        }
201    }
202
203    /// Fused unpack 8-bit + delta decode using NEON
204    /// Processes 4 values at a time, fusing unpack and prefix sum
205    #[target_feature(enable = "neon")]
206    pub unsafe fn unpack_8bit_delta_decode(
207        input: &[u8],
208        output: &mut [u32],
209        first_value: u32,
210        count: usize,
211    ) {
212        output[0] = first_value;
213        if count <= 1 {
214            return;
215        }
216
217        let ones = vdupq_n_u32(1);
218        let mut carry = vdupq_n_u32(first_value);
219
220        let full_groups = (count - 1) / 4;
221        let remainder = (count - 1) % 4;
222
223        for group in 0..full_groups {
224            let base = group * 4;
225
226            // Load 4 bytes and widen to u32
227            let b0 = input[base] as u32;
228            let b1 = input[base + 1] as u32;
229            let b2 = input[base + 2] as u32;
230            let b3 = input[base + 3] as u32;
231            let deltas = [b0, b1, b2, b3];
232            let d = vld1q_u32(deltas.as_ptr());
233
234            // Add 1 (since we store gap-1)
235            let gaps = vaddq_u32(d, ones);
236
237            // Compute prefix sum within the 4 elements
238            let prefix = prefix_sum_4(gaps);
239
240            // Add carry
241            let result = vaddq_u32(prefix, carry);
242
243            // Store result
244            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
245
246            // Update carry
247            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
248        }
249
250        // Handle remainder
251        let base = full_groups * 4;
252        let mut scalar_carry = vgetq_lane_u32(carry, 0);
253        for j in 0..remainder {
254            scalar_carry = scalar_carry
255                .wrapping_add(input[base + j] as u32)
256                .wrapping_add(1);
257            output[base + j + 1] = scalar_carry;
258        }
259    }
260
261    /// Fused unpack 16-bit + delta decode using NEON
262    #[target_feature(enable = "neon")]
263    pub unsafe fn unpack_16bit_delta_decode(
264        input: &[u8],
265        output: &mut [u32],
266        first_value: u32,
267        count: usize,
268    ) {
269        output[0] = first_value;
270        if count <= 1 {
271            return;
272        }
273
274        let ones = vdupq_n_u32(1);
275        let mut carry = vdupq_n_u32(first_value);
276
277        let full_groups = (count - 1) / 4;
278        let remainder = (count - 1) % 4;
279
280        for group in 0..full_groups {
281            let base = group * 4;
282            let in_ptr = input.as_ptr().add(base * 2) as *const u16;
283
284            // Load 4 u16 values and widen to u32
285            let vals = vld1_u16(in_ptr);
286            let d = vmovl_u16(vals);
287
288            // Add 1 (since we store gap-1)
289            let gaps = vaddq_u32(d, ones);
290
291            // Compute prefix sum within the 4 elements
292            let prefix = prefix_sum_4(gaps);
293
294            // Add carry
295            let result = vaddq_u32(prefix, carry);
296
297            // Store result
298            vst1q_u32(output[base + 1..].as_mut_ptr(), result);
299
300            // Update carry
301            carry = vdupq_n_u32(vgetq_lane_u32(result, 3));
302        }
303
304        // Handle remainder
305        let base = full_groups * 4;
306        let mut scalar_carry = vgetq_lane_u32(carry, 0);
307        for j in 0..remainder {
308            let idx = (base + j) * 2;
309            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
310            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
311            output[base + j + 1] = scalar_carry;
312        }
313    }
314
315    /// Check if NEON is available (always true on aarch64)
316    #[inline]
317    pub fn is_available() -> bool {
318        true
319    }
320}
321
322// ============================================================================
323// SSE intrinsics for x86_64 (Intel/AMD)
324// ============================================================================
325
326#[cfg(target_arch = "x86_64")]
327#[allow(unsafe_op_in_unsafe_fn)]
328mod sse {
329    use std::arch::x86_64::*;
330
331    /// SIMD unpack for 8-bit values using SSE
332    #[target_feature(enable = "sse2", enable = "sse4.1")]
333    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
334        let chunks = count / 16;
335        let remainder = count % 16;
336
337        for chunk in 0..chunks {
338            let base = chunk * 16;
339            let in_ptr = input.as_ptr().add(base);
340
341            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
342
343            // Zero extend u8 -> u32 using SSE4.1 pmovzx
344            let v0 = _mm_cvtepu8_epi32(bytes);
345            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
346            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
347            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
348
349            let out_ptr = output.as_mut_ptr().add(base);
350            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
351            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
352            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
353            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
354        }
355
356        let base = chunks * 16;
357        for i in 0..remainder {
358            output[base + i] = input[base + i] as u32;
359        }
360    }
361
362    /// SIMD unpack for 16-bit values using SSE
363    #[target_feature(enable = "sse2", enable = "sse4.1")]
364    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
365        let chunks = count / 8;
366        let remainder = count % 8;
367
368        for chunk in 0..chunks {
369            let base = chunk * 8;
370            let in_ptr = input.as_ptr().add(base * 2);
371
372            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
373            let low = _mm_cvtepu16_epi32(vals);
374            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
375
376            let out_ptr = output.as_mut_ptr().add(base);
377            _mm_storeu_si128(out_ptr as *mut __m128i, low);
378            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
379        }
380
381        let base = chunks * 8;
382        for i in 0..remainder {
383            let idx = (base + i) * 2;
384            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
385        }
386    }
387
388    /// SIMD unpack for 32-bit values using SSE (fast copy)
389    #[target_feature(enable = "sse2")]
390    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
391        let chunks = count / 4;
392        let remainder = count % 4;
393
394        let in_ptr = input.as_ptr() as *const __m128i;
395        let out_ptr = output.as_mut_ptr() as *mut __m128i;
396
397        for chunk in 0..chunks {
398            let vals = _mm_loadu_si128(in_ptr.add(chunk));
399            _mm_storeu_si128(out_ptr.add(chunk), vals);
400        }
401
402        // Handle remainder
403        let base = chunks * 4;
404        for i in 0..remainder {
405            let idx = (base + i) * 4;
406            output[base + i] =
407                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
408        }
409    }
410
411    /// SIMD prefix sum for 4 u32 values using SSE
412    /// Input:  [a, b, c, d]
413    /// Output: [a, a+b, a+b+c, a+b+c+d]
414    #[inline]
415    #[target_feature(enable = "sse2")]
416    unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
417        // Step 1: shift by 1 element (4 bytes) and add
418        // [a, b, c, d] + [0, a, b, c] = [a, a+b, b+c, c+d]
419        let shifted1 = _mm_slli_si128(v, 4);
420        let sum1 = _mm_add_epi32(v, shifted1);
421
422        // Step 2: shift by 2 elements (8 bytes) and add
423        // [a, a+b, b+c, c+d] + [0, 0, a, a+b] = [a, a+b, a+b+c, a+b+c+d]
424        let shifted2 = _mm_slli_si128(sum1, 8);
425        _mm_add_epi32(sum1, shifted2)
426    }
427
428    /// SIMD delta decode using SSE with true SIMD prefix sum
429    #[target_feature(enable = "sse2", enable = "sse4.1")]
430    pub unsafe fn delta_decode(
431        output: &mut [u32],
432        deltas: &[u32],
433        first_doc_id: u32,
434        count: usize,
435    ) {
436        if count == 0 {
437            return;
438        }
439
440        output[0] = first_doc_id;
441        if count == 1 {
442            return;
443        }
444
445        let ones = _mm_set1_epi32(1);
446        let mut carry = _mm_set1_epi32(first_doc_id as i32);
447
448        let full_groups = (count - 1) / 4;
449        let remainder = (count - 1) % 4;
450
451        for group in 0..full_groups {
452            let base = group * 4;
453
454            // Load 4 deltas and add 1 (since we store gap-1)
455            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
456            let gaps = _mm_add_epi32(d, ones);
457
458            // Compute prefix sum within the 4 elements
459            let prefix = prefix_sum_4(gaps);
460
461            // Add carry (broadcast last element of previous group)
462            let result = _mm_add_epi32(prefix, carry);
463
464            // Store result
465            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
466
467            // Update carry: broadcast the last element for next iteration
468            carry = _mm_shuffle_epi32(result, 0xFF); // broadcast lane 3
469        }
470
471        // Handle remainder
472        let base = full_groups * 4;
473        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
474        for j in 0..remainder {
475            scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
476            output[base + j + 1] = scalar_carry;
477        }
478    }
479
480    /// SIMD add 1 to all values using SSE
481    #[target_feature(enable = "sse2")]
482    pub unsafe fn add_one(values: &mut [u32], count: usize) {
483        let ones = _mm_set1_epi32(1);
484        let chunks = count / 4;
485        let remainder = count % 4;
486
487        for chunk in 0..chunks {
488            let base = chunk * 4;
489            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
490            let v = _mm_loadu_si128(ptr);
491            let result = _mm_add_epi32(v, ones);
492            _mm_storeu_si128(ptr, result);
493        }
494
495        let base = chunks * 4;
496        for i in 0..remainder {
497            values[base + i] += 1;
498        }
499    }
500
501    /// Fused unpack 8-bit + delta decode using SSE
502    #[target_feature(enable = "sse2", enable = "sse4.1")]
503    pub unsafe fn unpack_8bit_delta_decode(
504        input: &[u8],
505        output: &mut [u32],
506        first_value: u32,
507        count: usize,
508    ) {
509        output[0] = first_value;
510        if count <= 1 {
511            return;
512        }
513
514        let ones = _mm_set1_epi32(1);
515        let mut carry = _mm_set1_epi32(first_value as i32);
516
517        let full_groups = (count - 1) / 4;
518        let remainder = (count - 1) % 4;
519
520        for group in 0..full_groups {
521            let base = group * 4;
522
523            // Load 4 bytes (unaligned) and zero-extend to u32
524            let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
525                input.as_ptr().add(base) as *const i32
526            ));
527            let d = _mm_cvtepu8_epi32(bytes);
528
529            // Add 1 (since we store gap-1)
530            let gaps = _mm_add_epi32(d, ones);
531
532            // Compute prefix sum within the 4 elements
533            let prefix = prefix_sum_4(gaps);
534
535            // Add carry
536            let result = _mm_add_epi32(prefix, carry);
537
538            // Store result
539            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
540
541            // Update carry: broadcast the last element
542            carry = _mm_shuffle_epi32(result, 0xFF);
543        }
544
545        // Handle remainder
546        let base = full_groups * 4;
547        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
548        for j in 0..remainder {
549            scalar_carry = scalar_carry
550                .wrapping_add(input[base + j] as u32)
551                .wrapping_add(1);
552            output[base + j + 1] = scalar_carry;
553        }
554    }
555
556    /// Fused unpack 16-bit + delta decode using SSE
557    #[target_feature(enable = "sse2", enable = "sse4.1")]
558    pub unsafe fn unpack_16bit_delta_decode(
559        input: &[u8],
560        output: &mut [u32],
561        first_value: u32,
562        count: usize,
563    ) {
564        output[0] = first_value;
565        if count <= 1 {
566            return;
567        }
568
569        let ones = _mm_set1_epi32(1);
570        let mut carry = _mm_set1_epi32(first_value as i32);
571
572        let full_groups = (count - 1) / 4;
573        let remainder = (count - 1) % 4;
574
575        for group in 0..full_groups {
576            let base = group * 4;
577            let in_ptr = input.as_ptr().add(base * 2);
578
579            // Load 8 bytes (4 u16 values, unaligned) and zero-extend to u32
580            let vals = _mm_loadl_epi64(in_ptr as *const __m128i); // loadl_epi64 supports unaligned
581            let d = _mm_cvtepu16_epi32(vals);
582
583            // Add 1 (since we store gap-1)
584            let gaps = _mm_add_epi32(d, ones);
585
586            // Compute prefix sum within the 4 elements
587            let prefix = prefix_sum_4(gaps);
588
589            // Add carry
590            let result = _mm_add_epi32(prefix, carry);
591
592            // Store result
593            _mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
594
595            // Update carry: broadcast the last element
596            carry = _mm_shuffle_epi32(result, 0xFF);
597        }
598
599        // Handle remainder
600        let base = full_groups * 4;
601        let mut scalar_carry = _mm_extract_epi32(carry, 0) as u32;
602        for j in 0..remainder {
603            let idx = (base + j) * 2;
604            let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
605            scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
606            output[base + j + 1] = scalar_carry;
607        }
608    }
609
610    /// Check if SSE4.1 is available at runtime
611    #[inline]
612    pub fn is_available() -> bool {
613        is_x86_feature_detected!("sse4.1")
614    }
615}
616
617// ============================================================================
618// AVX2 intrinsics for x86_64 (Intel/AMD with 256-bit registers)
619// ============================================================================
620
621#[cfg(target_arch = "x86_64")]
622#[allow(unsafe_op_in_unsafe_fn)]
623mod avx2 {
624    use std::arch::x86_64::*;
625
626    /// AVX2 unpack for 8-bit values (processes 32 bytes at a time)
627    #[target_feature(enable = "avx2")]
628    pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
629        let chunks = count / 32;
630        let remainder = count % 32;
631
632        for chunk in 0..chunks {
633            let base = chunk * 32;
634            let in_ptr = input.as_ptr().add(base);
635
636            // Load 32 bytes (two 128-bit loads, then combine)
637            let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
638            let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
639
640            // Zero extend first 16 bytes: u8 -> u32
641            let v0 = _mm256_cvtepu8_epi32(bytes_lo);
642            let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_lo, 8));
643            let v2 = _mm256_cvtepu8_epi32(bytes_hi);
644            let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128(bytes_hi, 8));
645
646            let out_ptr = output.as_mut_ptr().add(base);
647            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
648            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
649            _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
650            _mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
651        }
652
653        // Handle remainder with scalar code
654        let base = chunks * 32;
655        for i in 0..remainder {
656            output[base + i] = input[base + i] as u32;
657        }
658    }
659
660    /// AVX2 unpack for 16-bit values (processes 16 values at a time)
661    #[target_feature(enable = "avx2")]
662    pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
663        let chunks = count / 16;
664        let remainder = count % 16;
665
666        for chunk in 0..chunks {
667            let base = chunk * 16;
668            let in_ptr = input.as_ptr().add(base * 2);
669
670            // Load 32 bytes (16 u16 values)
671            let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
672            let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
673
674            // Zero extend u16 -> u32
675            let v0 = _mm256_cvtepu16_epi32(vals_lo);
676            let v1 = _mm256_cvtepu16_epi32(vals_hi);
677
678            let out_ptr = output.as_mut_ptr().add(base);
679            _mm256_storeu_si256(out_ptr as *mut __m256i, v0);
680            _mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
681        }
682
683        // Handle remainder
684        let base = chunks * 16;
685        for i in 0..remainder {
686            let idx = (base + i) * 2;
687            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
688        }
689    }
690
691    /// AVX2 unpack for 32-bit values (fast copy, 8 values at a time)
692    #[target_feature(enable = "avx2")]
693    pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
694        let chunks = count / 8;
695        let remainder = count % 8;
696
697        let in_ptr = input.as_ptr() as *const __m256i;
698        let out_ptr = output.as_mut_ptr() as *mut __m256i;
699
700        for chunk in 0..chunks {
701            let vals = _mm256_loadu_si256(in_ptr.add(chunk));
702            _mm256_storeu_si256(out_ptr.add(chunk), vals);
703        }
704
705        // Handle remainder
706        let base = chunks * 8;
707        for i in 0..remainder {
708            let idx = (base + i) * 4;
709            output[base + i] =
710                u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
711        }
712    }
713
714    /// AVX2 add 1 to all values (8 values at a time)
715    #[target_feature(enable = "avx2")]
716    pub unsafe fn add_one(values: &mut [u32], count: usize) {
717        let ones = _mm256_set1_epi32(1);
718        let chunks = count / 8;
719        let remainder = count % 8;
720
721        for chunk in 0..chunks {
722            let base = chunk * 8;
723            let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
724            let v = _mm256_loadu_si256(ptr);
725            let result = _mm256_add_epi32(v, ones);
726            _mm256_storeu_si256(ptr, result);
727        }
728
729        let base = chunks * 8;
730        for i in 0..remainder {
731            values[base + i] += 1;
732        }
733    }
734
735    /// Check if AVX2 is available at runtime
736    #[inline]
737    pub fn is_available() -> bool {
738        is_x86_feature_detected!("avx2")
739    }
740}
741
742// ============================================================================
743// Scalar fallback implementations
744// ============================================================================
745
746#[allow(dead_code)]
747mod scalar {
748    /// Scalar unpack for 8-bit values
749    #[inline]
750    pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
751        for i in 0..count {
752            output[i] = input[i] as u32;
753        }
754    }
755
756    /// Scalar unpack for 16-bit values
757    #[inline]
758    pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
759        for (i, out) in output.iter_mut().enumerate().take(count) {
760            let idx = i * 2;
761            *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
762        }
763    }
764
765    /// Scalar unpack for 32-bit values
766    #[inline]
767    pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
768        for (i, out) in output.iter_mut().enumerate().take(count) {
769            let idx = i * 4;
770            *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
771        }
772    }
773
774    /// Scalar delta decode
775    #[inline]
776    pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
777        if count == 0 {
778            return;
779        }
780
781        output[0] = first_doc_id;
782        let mut carry = first_doc_id;
783
784        for i in 0..count - 1 {
785            carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
786            output[i + 1] = carry;
787        }
788    }
789
790    /// Scalar add 1 to all values
791    #[inline]
792    pub fn add_one(values: &mut [u32], count: usize) {
793        for val in values.iter_mut().take(count) {
794            *val += 1;
795        }
796    }
797}
798
799// ============================================================================
800// Public dispatch functions that select SIMD or scalar at runtime
801// ============================================================================
802
803/// Unpack 8-bit packed values to u32 with SIMD acceleration
804#[inline]
805pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
806    #[cfg(target_arch = "aarch64")]
807    {
808        if neon::is_available() {
809            unsafe {
810                neon::unpack_8bit(input, output, count);
811            }
812            return;
813        }
814    }
815
816    #[cfg(target_arch = "x86_64")]
817    {
818        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
819        if avx2::is_available() {
820            unsafe {
821                avx2::unpack_8bit(input, output, count);
822            }
823            return;
824        }
825        if sse::is_available() {
826            unsafe {
827                sse::unpack_8bit(input, output, count);
828            }
829            return;
830        }
831    }
832
833    scalar::unpack_8bit(input, output, count);
834}
835
836/// Unpack 16-bit packed values to u32 with SIMD acceleration
837#[inline]
838pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
839    #[cfg(target_arch = "aarch64")]
840    {
841        if neon::is_available() {
842            unsafe {
843                neon::unpack_16bit(input, output, count);
844            }
845            return;
846        }
847    }
848
849    #[cfg(target_arch = "x86_64")]
850    {
851        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
852        if avx2::is_available() {
853            unsafe {
854                avx2::unpack_16bit(input, output, count);
855            }
856            return;
857        }
858        if sse::is_available() {
859            unsafe {
860                sse::unpack_16bit(input, output, count);
861            }
862            return;
863        }
864    }
865
866    scalar::unpack_16bit(input, output, count);
867}
868
869/// Unpack 32-bit packed values to u32 with SIMD acceleration
870#[inline]
871pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
872    #[cfg(target_arch = "aarch64")]
873    {
874        if neon::is_available() {
875            unsafe {
876                neon::unpack_32bit(input, output, count);
877            }
878        }
879    }
880
881    #[cfg(target_arch = "x86_64")]
882    {
883        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
884        if avx2::is_available() {
885            unsafe {
886                avx2::unpack_32bit(input, output, count);
887            }
888        } else {
889            // SSE2 is always available on x86_64
890            unsafe {
891                sse::unpack_32bit(input, output, count);
892            }
893        }
894    }
895
896    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
897    {
898        scalar::unpack_32bit(input, output, count);
899    }
900}
901
902/// Delta decode with SIMD acceleration
903///
904/// Converts delta-encoded values to absolute values.
905/// Input: deltas[i] = value[i+1] - value[i] - 1 (gap minus one)
906/// Output: absolute values starting from first_value
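///
/// A small illustrative example (values chosen arbitrarily):
///
/// ```ignore
/// let deltas = [5u32, 0, 1];              // each stores gap - 1
/// let mut docs = [0u32; 4];
/// delta_decode(&mut docs, &deltas, 100, 4);
/// assert_eq!(docs, [100, 106, 107, 109]);
/// ```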
907#[inline]
908pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
909    #[cfg(target_arch = "aarch64")]
910    {
911        if neon::is_available() {
912            unsafe {
913                neon::delta_decode(output, deltas, first_value, count);
914            }
915            return;
916        }
917    }
918
919    #[cfg(target_arch = "x86_64")]
920    {
921        if sse::is_available() {
922            unsafe {
923                sse::delta_decode(output, deltas, first_value, count);
924            }
925            return;
926        }
927    }
928
929    scalar::delta_decode(output, deltas, first_value, count);
930}
931
932/// Add 1 to all values with SIMD acceleration
933///
934/// Used for TF decoding where values are stored as (tf - 1)
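///
/// For example, a stored block `[0, 2, 4]` decodes to term frequencies
/// `[1, 3, 5]`.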
935#[inline]
936pub fn add_one(values: &mut [u32], count: usize) {
937    #[cfg(target_arch = "aarch64")]
938    {
939        if neon::is_available() {
940            unsafe {
941                neon::add_one(values, count);
942            }
943        }
944    }
945
946    #[cfg(target_arch = "x86_64")]
947    {
948        // Prefer AVX2 (256-bit) over SSE (128-bit) when available
949        if avx2::is_available() {
950            unsafe {
951                avx2::add_one(values, count);
952            }
953        } else {
954            // SSE2 is always available on x86_64
955            unsafe {
956                sse::add_one(values, count);
957            }
958        }
959    }
960
961    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
962    {
963        scalar::add_one(values, count);
964    }
965}
966
967/// Compute the number of bits needed to represent a value
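///
/// For example, `bits_needed(0) == 0`, `bits_needed(255) == 8`, and
/// `bits_needed(256) == 9`.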
968#[inline]
969pub fn bits_needed(val: u32) -> u8 {
970    if val == 0 {
971        0
972    } else {
973        32 - val.leading_zeros() as u8
974    }
975}
976
977// ============================================================================
978// Rounded bitpacking for truly vectorized encoding/decoding
979// ============================================================================
980//
981// Instead of using arbitrary bit widths (1-32), we round up to SIMD-friendly
982// widths: 0, 8, 16, or 32 bits. This trades ~10-20% more space for much faster
983// decoding since we can use direct SIMD widening instructions (pmovzx) without
984// any bit-shifting or masking.
985//
986// Bit width mapping:
987//   0      -> 0  (all zeros)
988//   1-8    -> 8  (u8)
989//   9-16   -> 16 (u16)
990//   17-32  -> 32 (u32)
991
992/// Rounded bit width type for SIMD-friendly encoding
993#[derive(Debug, Clone, Copy, PartialEq, Eq)]
994#[repr(u8)]
995pub enum RoundedBitWidth {
996    Zero = 0,
997    Bits8 = 8,
998    Bits16 = 16,
999    Bits32 = 32,
1000}
1001
1002impl RoundedBitWidth {
1003    /// Round an exact bit width to the nearest SIMD-friendly width
1004    #[inline]
1005    pub fn from_exact(bits: u8) -> Self {
1006        match bits {
1007            0 => RoundedBitWidth::Zero,
1008            1..=8 => RoundedBitWidth::Bits8,
1009            9..=16 => RoundedBitWidth::Bits16,
1010            _ => RoundedBitWidth::Bits32,
1011        }
1012    }
1013
1014    /// Convert from stored u8 value (must be 0, 8, 16, or 32)
1015    #[inline]
1016    pub fn from_u8(bits: u8) -> Self {
1017        match bits {
1018            0 => RoundedBitWidth::Zero,
1019            8 => RoundedBitWidth::Bits8,
1020            16 => RoundedBitWidth::Bits16,
1021            32 => RoundedBitWidth::Bits32,
1022            _ => RoundedBitWidth::Bits32, // Fallback for invalid values
1023        }
1024    }
1025
1026    /// Get the byte size per value
1027    #[inline]
1028    pub fn bytes_per_value(self) -> usize {
1029        match self {
1030            RoundedBitWidth::Zero => 0,
1031            RoundedBitWidth::Bits8 => 1,
1032            RoundedBitWidth::Bits16 => 2,
1033            RoundedBitWidth::Bits32 => 4,
1034        }
1035    }
1036
1037    /// Get the raw bit width value
1038    #[inline]
1039    pub fn as_u8(self) -> u8 {
1040        self as u8
1041    }
1042}
1043
1044/// Round a bit width to the nearest SIMD-friendly width (0, 8, 16, or 32)
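///
/// For example, an exact width of 7 rounds to 8, 13 rounds to 16, and 21
/// rounds to 32.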
1045#[inline]
1046pub fn round_bit_width(bits: u8) -> u8 {
1047    RoundedBitWidth::from_exact(bits).as_u8()
1048}
1049
1050/// Pack values using rounded bit width (SIMD-friendly)
1051///
1052/// This is much simpler than arbitrary bitpacking since values are byte-aligned.
1053/// Returns the number of bytes written.
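///
/// A sketch of a round trip with the companion `unpack_rounded` (the values
/// are illustrative):
///
/// ```ignore
/// let values = [3u32, 12, 250];
/// let width = RoundedBitWidth::from_exact(bits_needed(250)); // Bits8
/// let mut packed = vec![0u8; values.len() * width.bytes_per_value()];
/// let written = pack_rounded(&values, width, &mut packed);
/// assert_eq!(written, 3);
///
/// let mut unpacked = [0u32; 3];
/// unpack_rounded(&packed, width, &mut unpacked, 3);
/// assert_eq!(unpacked, values);
/// ```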
1054#[inline]
1055pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
1056    let count = values.len();
1057    match bit_width {
1058        RoundedBitWidth::Zero => 0,
1059        RoundedBitWidth::Bits8 => {
1060            for (i, &v) in values.iter().enumerate() {
1061                output[i] = v as u8;
1062            }
1063            count
1064        }
1065        RoundedBitWidth::Bits16 => {
1066            for (i, &v) in values.iter().enumerate() {
1067                let bytes = (v as u16).to_le_bytes();
1068                output[i * 2] = bytes[0];
1069                output[i * 2 + 1] = bytes[1];
1070            }
1071            count * 2
1072        }
1073        RoundedBitWidth::Bits32 => {
1074            for (i, &v) in values.iter().enumerate() {
1075                let bytes = v.to_le_bytes();
1076                output[i * 4] = bytes[0];
1077                output[i * 4 + 1] = bytes[1];
1078                output[i * 4 + 2] = bytes[2];
1079                output[i * 4 + 3] = bytes[3];
1080            }
1081            count * 4
1082        }
1083    }
1084}
1085
1086/// Unpack values using rounded bit width with SIMD acceleration
1087///
1088/// This is the fast path - no bit manipulation needed, just widening.
1089#[inline]
1090pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
1091    match bit_width {
1092        RoundedBitWidth::Zero => {
1093            for out in output.iter_mut().take(count) {
1094                *out = 0;
1095            }
1096        }
1097        RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
1098        RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
1099        RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
1100    }
1101}
1102
1103/// Fused unpack + delta decode using rounded bit width
1104///
1105/// Combines unpacking and prefix sum in a single pass for better cache utilization.
1106#[inline]
1107pub fn unpack_rounded_delta_decode(
1108    input: &[u8],
1109    bit_width: RoundedBitWidth,
1110    output: &mut [u32],
1111    first_value: u32,
1112    count: usize,
1113) {
1114    match bit_width {
1115        RoundedBitWidth::Zero => {
1116            // All deltas are 0, meaning gaps of 1
1117            let mut val = first_value;
1118            for out in output.iter_mut().take(count) {
1119                *out = val;
1120                val = val.wrapping_add(1);
1121            }
1122        }
1123        RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
1124        RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
1125        RoundedBitWidth::Bits32 => {
1126            // 32-bit deltas are byte-aligned u32s, so decode them directly from
1127            // the input bytes: count - 1 deltas of (gap - 1), matching the
1128            // 8/16-bit fused paths above.
1129            if count > 0 {
1130                output[0] = first_value;
1131                let mut carry = first_value;
1132                for i in 0..count - 1 {
1133                    let idx = i * 4;
1134                    let delta = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
1135                    carry = carry.wrapping_add(delta).wrapping_add(1);
1136                    output[i + 1] = carry;
1137                }
1138            }
1139        }
1140    }
1141}
1142
1143// ============================================================================
1144// Fused operations for better cache utilization
1145// ============================================================================
1146
1147/// Fused unpack 8-bit + delta decode in a single pass
1148///
1149/// This avoids writing the intermediate unpacked values to memory,
1150/// improving cache utilization for large blocks.
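///
/// Illustrative example (deltas stored as `gap - 1`):
///
/// ```ignore
/// let packed: &[u8] = &[3, 0, 7];
/// let mut docs = [0u32; 4];
/// unpack_8bit_delta_decode(packed, &mut docs, 20, 4);
/// assert_eq!(docs, [20, 24, 25, 33]);
/// ```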
1151#[inline]
1152pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1153    if count == 0 {
1154        return;
1155    }
1156
1157    output[0] = first_value;
1158    if count == 1 {
1159        return;
1160    }
1161
1162    #[cfg(target_arch = "aarch64")]
1163    {
1164        if neon::is_available() {
1165            unsafe {
1166                neon::unpack_8bit_delta_decode(input, output, first_value, count);
1167            }
1168            return;
1169        }
1170    }
1171
1172    #[cfg(target_arch = "x86_64")]
1173    {
1174        if sse::is_available() {
1175            unsafe {
1176                sse::unpack_8bit_delta_decode(input, output, first_value, count);
1177            }
1178            return;
1179        }
1180    }
1181
1182    // Scalar fallback
1183    let mut carry = first_value;
1184    for i in 0..count - 1 {
1185        carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
1186        output[i + 1] = carry;
1187    }
1188}
1189
1190/// Fused unpack 16-bit + delta decode in a single pass
1191#[inline]
1192pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
1193    if count == 0 {
1194        return;
1195    }
1196
1197    output[0] = first_value;
1198    if count == 1 {
1199        return;
1200    }
1201
1202    #[cfg(target_arch = "aarch64")]
1203    {
1204        if neon::is_available() {
1205            unsafe {
1206                neon::unpack_16bit_delta_decode(input, output, first_value, count);
1207            }
1208            return;
1209        }
1210    }
1211
1212    #[cfg(target_arch = "x86_64")]
1213    {
1214        if sse::is_available() {
1215            unsafe {
1216                sse::unpack_16bit_delta_decode(input, output, first_value, count);
1217            }
1218            return;
1219        }
1220    }
1221
1222    // Scalar fallback
1223    let mut carry = first_value;
1224    for i in 0..count - 1 {
1225        let idx = i * 2;
1226        let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
1227        carry = carry.wrapping_add(delta).wrapping_add(1);
1228        output[i + 1] = carry;
1229    }
1230}
1231
1232/// Fused unpack + delta decode for arbitrary bit widths
1233///
1234/// Combines unpacking and prefix sum in a single pass, avoiding intermediate buffer.
1235/// Uses SIMD-accelerated paths for 8/16-bit widths, scalar for others.
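///
/// Illustrative example for a non-rounded width (4 bits per delta, LSB-first).
/// The trailing zero bytes pad the buffer so the 8-byte unaligned reads used by
/// the generic path stay in bounds, which the caller is expected to guarantee:
///
/// ```ignore
/// // Deltas (gap - 1) = [1, 2, 3], packed as nibbles: 0x21, 0x03.
/// let packed: &[u8] = &[0x21, 0x03, 0, 0, 0, 0, 0, 0, 0, 0];
/// let mut docs = [0u32; 4];
/// unpack_delta_decode(packed, 4, &mut docs, 7, 4);
/// assert_eq!(docs, [7, 9, 12, 16]);
/// ```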
1236#[inline]
1237pub fn unpack_delta_decode(
1238    input: &[u8],
1239    bit_width: u8,
1240    output: &mut [u32],
1241    first_value: u32,
1242    count: usize,
1243) {
1244    if count == 0 {
1245        return;
1246    }
1247
1248    output[0] = first_value;
1249    if count == 1 {
1250        return;
1251    }
1252
1253    // Fast paths for SIMD-friendly bit widths
1254    match bit_width {
1255        0 => {
1256            // All zeros = consecutive doc IDs (gap of 1)
1257            let mut val = first_value;
1258            for item in output.iter_mut().take(count).skip(1) {
1259                val = val.wrapping_add(1);
1260                *item = val;
1261            }
1262        }
1263        8 => unpack_8bit_delta_decode(input, output, first_value, count),
1264        16 => unpack_16bit_delta_decode(input, output, first_value, count),
1265        32 => {
1266            // 32-bit: unpack inline and delta decode
1267            let mut carry = first_value;
1268            for i in 0..count - 1 {
1269                let idx = i * 4;
1270                let delta = u32::from_le_bytes([
1271                    input[idx],
1272                    input[idx + 1],
1273                    input[idx + 2],
1274                    input[idx + 3],
1275                ]);
1276                carry = carry.wrapping_add(delta).wrapping_add(1);
1277                output[i + 1] = carry;
1278            }
1279        }
1280        _ => {
1281            // Generic bit width: fused unpack + delta decode
1282            let mask = (1u64 << bit_width) - 1;
1283            let bit_width_usize = bit_width as usize;
1284            let mut bit_pos = 0usize;
1285            let input_ptr = input.as_ptr();
1286            let mut carry = first_value;
1287
1288            for i in 0..count - 1 {
1289                let byte_idx = bit_pos >> 3;
1290                let bit_offset = bit_pos & 7;
1291
1292                // SAFETY: Caller guarantees input has enough data
1293                let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
1294                let delta = ((word >> bit_offset) & mask) as u32;
1295
1296                carry = carry.wrapping_add(delta).wrapping_add(1);
1297                output[i + 1] = carry;
1298                bit_pos += bit_width_usize;
1299            }
1300        }
1301    }
1302}
1303
1304// ============================================================================
1305// Sparse Vector SIMD Functions
1306// ============================================================================
1307
1308/// Dequantize UInt8 weights to f32 with SIMD acceleration
1309///
1310/// Computes: output[i] = input[i] as f32 * scale + min_val
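///
/// Illustrative example (scale and offset chosen so the results are exact in
/// f32):
///
/// ```ignore
/// let quantized: &[u8] = &[0, 128, 255];
/// let mut weights = [0.0f32; 3];
/// dequantize_uint8(quantized, &mut weights, 0.5, -1.0, 3);
/// assert_eq!(weights, [-1.0, 63.0, 126.5]);
/// ```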
1311#[inline]
1312pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
1313    #[cfg(target_arch = "aarch64")]
1314    {
1315        if neon::is_available() {
1316            unsafe {
1317                dequantize_uint8_neon(input, output, scale, min_val, count);
1318            }
1319            return;
1320        }
1321    }
1322
1323    #[cfg(target_arch = "x86_64")]
1324    {
1325        if sse::is_available() {
1326            unsafe {
1327                dequantize_uint8_sse(input, output, scale, min_val, count);
1328            }
1329            return;
1330        }
1331    }
1332
1333    // Scalar fallback
1334    for i in 0..count {
1335        output[i] = input[i] as f32 * scale + min_val;
1336    }
1337}
1338
1339#[cfg(target_arch = "aarch64")]
1340#[target_feature(enable = "neon")]
1341#[allow(unsafe_op_in_unsafe_fn)]
1342unsafe fn dequantize_uint8_neon(
1343    input: &[u8],
1344    output: &mut [f32],
1345    scale: f32,
1346    min_val: f32,
1347    count: usize,
1348) {
1349    use std::arch::aarch64::*;
1350
1351    let scale_v = vdupq_n_f32(scale);
1352    let min_v = vdupq_n_f32(min_val);
1353
1354    let chunks = count / 16;
1355    let remainder = count % 16;
1356
1357    for chunk in 0..chunks {
1358        let base = chunk * 16;
1359        let in_ptr = input.as_ptr().add(base);
1360
1361        // Load 16 bytes
1362        let bytes = vld1q_u8(in_ptr);
1363
1364        // Widen u8 -> u16 -> u32 -> f32
1365        let low8 = vget_low_u8(bytes);
1366        let high8 = vget_high_u8(bytes);
1367
1368        let low16 = vmovl_u8(low8);
1369        let high16 = vmovl_u8(high8);
1370
1371        // Process 4 values at a time
1372        let u32_0 = vmovl_u16(vget_low_u16(low16));
1373        let u32_1 = vmovl_u16(vget_high_u16(low16));
1374        let u32_2 = vmovl_u16(vget_low_u16(high16));
1375        let u32_3 = vmovl_u16(vget_high_u16(high16));
1376
1377        // Convert to f32 and apply scale + min_val
1378        let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
1379        let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
1380        let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
1381        let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);
1382
1383        let out_ptr = output.as_mut_ptr().add(base);
1384        vst1q_f32(out_ptr, f32_0);
1385        vst1q_f32(out_ptr.add(4), f32_1);
1386        vst1q_f32(out_ptr.add(8), f32_2);
1387        vst1q_f32(out_ptr.add(12), f32_3);
1388    }
1389
1390    // Handle remainder
1391    let base = chunks * 16;
1392    for i in 0..remainder {
1393        output[base + i] = input[base + i] as f32 * scale + min_val;
1394    }
1395}
1396
1397#[cfg(target_arch = "x86_64")]
1398#[target_feature(enable = "sse2", enable = "sse4.1")]
1399#[allow(unsafe_op_in_unsafe_fn)]
1400unsafe fn dequantize_uint8_sse(
1401    input: &[u8],
1402    output: &mut [f32],
1403    scale: f32,
1404    min_val: f32,
1405    count: usize,
1406) {
1407    use std::arch::x86_64::*;
1408
1409    let scale_v = _mm_set1_ps(scale);
1410    let min_v = _mm_set1_ps(min_val);
1411
1412    let chunks = count / 4;
1413    let remainder = count % 4;
1414
1415    for chunk in 0..chunks {
1416        let base = chunk * 4;
1417
1418        // Load 4 bytes and zero-extend to 32-bit
1419        let b0 = input[base] as i32;
1420        let b1 = input[base + 1] as i32;
1421        let b2 = input[base + 2] as i32;
1422        let b3 = input[base + 3] as i32;
1423
1424        let ints = _mm_set_epi32(b3, b2, b1, b0);
1425        let floats = _mm_cvtepi32_ps(ints);
1426
1427        // Apply scale and min_val: result = floats * scale + min_val
1428        let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);
1429
1430        _mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
1431    }
1432
1433    // Handle remainder
1434    let base = chunks * 4;
1435    for i in 0..remainder {
1436        output[base + i] = input[base + i] as f32 * scale + min_val;
1437    }
1438}
1439
1440/// Compute dot product of two f32 arrays with SIMD acceleration
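///
/// Illustrative example:
///
/// ```ignore
/// let a = [1.0f32, 2.0, 3.0];
/// let b = [4.0f32, 5.0, 6.0];
/// assert_eq!(dot_product_f32(&a, &b, 3), 32.0);
/// ```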
1441#[inline]
1442pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
1443    #[cfg(target_arch = "aarch64")]
1444    {
1445        if neon::is_available() {
1446            return unsafe { dot_product_f32_neon(a, b, count) };
1447        }
1448    }
1449
1450    #[cfg(target_arch = "x86_64")]
1451    {
1452        if is_x86_feature_detected!("avx512f") {
1453            return unsafe { dot_product_f32_avx512(a, b, count) };
1454        }
1455        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1456            return unsafe { dot_product_f32_avx2(a, b, count) };
1457        }
1458        if sse::is_available() {
1459            return unsafe { dot_product_f32_sse(a, b, count) };
1460        }
1461    }
1462
1463    // Scalar fallback
1464    let mut sum = 0.0f32;
1465    for i in 0..count {
1466        sum += a[i] * b[i];
1467    }
1468    sum
1469}
1470
1471#[cfg(target_arch = "aarch64")]
1472#[target_feature(enable = "neon")]
1473#[allow(unsafe_op_in_unsafe_fn)]
1474unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
1475    use std::arch::aarch64::*;
1476
1477    let chunks16 = count / 16;
1478    let remainder = count % 16;
1479
1480    let mut acc0 = vdupq_n_f32(0.0);
1481    let mut acc1 = vdupq_n_f32(0.0);
1482    let mut acc2 = vdupq_n_f32(0.0);
1483    let mut acc3 = vdupq_n_f32(0.0);
1484
1485    for c in 0..chunks16 {
1486        let base = c * 16;
1487        acc0 = vfmaq_f32(
1488            acc0,
1489            vld1q_f32(a.as_ptr().add(base)),
1490            vld1q_f32(b.as_ptr().add(base)),
1491        );
1492        acc1 = vfmaq_f32(
1493            acc1,
1494            vld1q_f32(a.as_ptr().add(base + 4)),
1495            vld1q_f32(b.as_ptr().add(base + 4)),
1496        );
1497        acc2 = vfmaq_f32(
1498            acc2,
1499            vld1q_f32(a.as_ptr().add(base + 8)),
1500            vld1q_f32(b.as_ptr().add(base + 8)),
1501        );
1502        acc3 = vfmaq_f32(
1503            acc3,
1504            vld1q_f32(a.as_ptr().add(base + 12)),
1505            vld1q_f32(b.as_ptr().add(base + 12)),
1506        );
1507    }
1508
1509    let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
1510    let mut sum = vaddvq_f32(acc);
1511
1512    let base = chunks16 * 16;
1513    for i in 0..remainder {
1514        sum += a[base + i] * b[base + i];
1515    }
1516
1517    sum
1518}
1519
1520#[cfg(target_arch = "x86_64")]
1521#[target_feature(enable = "avx2", enable = "fma")]
1522#[allow(unsafe_op_in_unsafe_fn)]
1523unsafe fn dot_product_f32_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
1524    use std::arch::x86_64::*;
1525
1526    let chunks32 = count / 32;
1527    let remainder = count % 32;
1528
1529    let mut acc0 = _mm256_setzero_ps();
1530    let mut acc1 = _mm256_setzero_ps();
1531    let mut acc2 = _mm256_setzero_ps();
1532    let mut acc3 = _mm256_setzero_ps();
1533
1534    for c in 0..chunks32 {
1535        let base = c * 32;
1536        acc0 = _mm256_fmadd_ps(
1537            _mm256_loadu_ps(a.as_ptr().add(base)),
1538            _mm256_loadu_ps(b.as_ptr().add(base)),
1539            acc0,
1540        );
1541        acc1 = _mm256_fmadd_ps(
1542            _mm256_loadu_ps(a.as_ptr().add(base + 8)),
1543            _mm256_loadu_ps(b.as_ptr().add(base + 8)),
1544            acc1,
1545        );
1546        acc2 = _mm256_fmadd_ps(
1547            _mm256_loadu_ps(a.as_ptr().add(base + 16)),
1548            _mm256_loadu_ps(b.as_ptr().add(base + 16)),
1549            acc2,
1550        );
1551        acc3 = _mm256_fmadd_ps(
1552            _mm256_loadu_ps(a.as_ptr().add(base + 24)),
1553            _mm256_loadu_ps(b.as_ptr().add(base + 24)),
1554            acc3,
1555        );
1556    }
1557
1558    let acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
1559
1560    // Horizontal sum: 256-bit → 128-bit → scalar
1561    let hi = _mm256_extractf128_ps(acc, 1);
1562    let lo = _mm256_castps256_ps128(acc);
1563    let sum128 = _mm_add_ps(lo, hi);
1564    let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
1565    let sums = _mm_add_ps(sum128, shuf);
1566    let shuf2 = _mm_movehl_ps(sums, sums);
1567    let final_sum = _mm_add_ss(sums, shuf2);
1568
1569    let mut sum = _mm_cvtss_f32(final_sum);
1570
1571    let base = chunks32 * 32;
1572    for i in 0..remainder {
1573        sum += a[base + i] * b[base + i];
1574    }
1575
1576    sum
1577}
1578
1579#[cfg(target_arch = "x86_64")]
1580#[target_feature(enable = "sse")]
1581#[allow(unsafe_op_in_unsafe_fn)]
1582unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
1583    use std::arch::x86_64::*;
1584
1585    let chunks = count / 4;
1586    let remainder = count % 4;
1587
1588    let mut acc = _mm_setzero_ps();
1589
1590    for chunk in 0..chunks {
1591        let base = chunk * 4;
1592        let va = _mm_loadu_ps(a.as_ptr().add(base));
1593        let vb = _mm_loadu_ps(b.as_ptr().add(base));
1594        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
1595    }
1596
1597    // Horizontal sum: [a, b, c, d] -> a + b + c + d
1598    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); // [b, a, d, c]
1599    let sums = _mm_add_ps(acc, shuf); // [a+b, a+b, c+d, c+d]
1600    let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, c+d, ?, ?]
1601    let final_sum = _mm_add_ss(sums, shuf2); // [a+b+c+d, ?, ?, ?]
1602
1603    let mut sum = _mm_cvtss_f32(final_sum);
1604
1605    // Handle remainder
1606    let base = chunks * 4;
1607    for i in 0..remainder {
1608        sum += a[base + i] * b[base + i];
1609    }
1610
1611    sum
1612}
1613
1614#[cfg(target_arch = "x86_64")]
1615#[target_feature(enable = "avx512f")]
1616#[allow(unsafe_op_in_unsafe_fn)]
1617unsafe fn dot_product_f32_avx512(a: &[f32], b: &[f32], count: usize) -> f32 {
1618    use std::arch::x86_64::*;
1619
1620    let chunks64 = count / 64;
1621    let remainder = count % 64;
1622
1623    let mut acc0 = _mm512_setzero_ps();
1624    let mut acc1 = _mm512_setzero_ps();
1625    let mut acc2 = _mm512_setzero_ps();
1626    let mut acc3 = _mm512_setzero_ps();
1627
1628    for c in 0..chunks64 {
1629        let base = c * 64;
1630        acc0 = _mm512_fmadd_ps(
1631            _mm512_loadu_ps(a.as_ptr().add(base)),
1632            _mm512_loadu_ps(b.as_ptr().add(base)),
1633            acc0,
1634        );
1635        acc1 = _mm512_fmadd_ps(
1636            _mm512_loadu_ps(a.as_ptr().add(base + 16)),
1637            _mm512_loadu_ps(b.as_ptr().add(base + 16)),
1638            acc1,
1639        );
1640        acc2 = _mm512_fmadd_ps(
1641            _mm512_loadu_ps(a.as_ptr().add(base + 32)),
1642            _mm512_loadu_ps(b.as_ptr().add(base + 32)),
1643            acc2,
1644        );
1645        acc3 = _mm512_fmadd_ps(
1646            _mm512_loadu_ps(a.as_ptr().add(base + 48)),
1647            _mm512_loadu_ps(b.as_ptr().add(base + 48)),
1648            acc3,
1649        );
1650    }
1651
1652    let acc = _mm512_add_ps(_mm512_add_ps(acc0, acc1), _mm512_add_ps(acc2, acc3));
1653    let mut sum = _mm512_reduce_add_ps(acc);
1654
1655    let base = chunks64 * 64;
1656    for i in 0..remainder {
1657        sum += a[base + i] * b[base + i];
1658    }
1659
1660    sum
1661}
1662
1663#[cfg(target_arch = "x86_64")]
1664#[target_feature(enable = "avx512f")]
1665#[allow(unsafe_op_in_unsafe_fn)]
1666unsafe fn fused_dot_norm_avx512(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1667    use std::arch::x86_64::*;
1668
1669    let chunks64 = count / 64;
1670    let remainder = count % 64;
1671
1672    let mut d0 = _mm512_setzero_ps();
1673    let mut d1 = _mm512_setzero_ps();
1674    let mut d2 = _mm512_setzero_ps();
1675    let mut d3 = _mm512_setzero_ps();
1676    let mut n0 = _mm512_setzero_ps();
1677    let mut n1 = _mm512_setzero_ps();
1678    let mut n2 = _mm512_setzero_ps();
1679    let mut n3 = _mm512_setzero_ps();
1680
1681    for c in 0..chunks64 {
1682        let base = c * 64;
1683        let vb0 = _mm512_loadu_ps(b.as_ptr().add(base));
1684        d0 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base)), vb0, d0);
1685        n0 = _mm512_fmadd_ps(vb0, vb0, n0);
1686        let vb1 = _mm512_loadu_ps(b.as_ptr().add(base + 16));
1687        d1 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 16)), vb1, d1);
1688        n1 = _mm512_fmadd_ps(vb1, vb1, n1);
1689        let vb2 = _mm512_loadu_ps(b.as_ptr().add(base + 32));
1690        d2 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 32)), vb2, d2);
1691        n2 = _mm512_fmadd_ps(vb2, vb2, n2);
1692        let vb3 = _mm512_loadu_ps(b.as_ptr().add(base + 48));
1693        d3 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 48)), vb3, d3);
1694        n3 = _mm512_fmadd_ps(vb3, vb3, n3);
1695    }
1696
1697    let acc_dot = _mm512_add_ps(_mm512_add_ps(d0, d1), _mm512_add_ps(d2, d3));
1698    let acc_norm = _mm512_add_ps(_mm512_add_ps(n0, n1), _mm512_add_ps(n2, n3));
1699    let mut dot = _mm512_reduce_add_ps(acc_dot);
1700    let mut norm = _mm512_reduce_add_ps(acc_norm);
1701
1702    let base = chunks64 * 64;
1703    for i in 0..remainder {
1704        dot += a[base + i] * b[base + i];
1705        norm += b[base + i] * b[base + i];
1706    }
1707
1708    (dot, norm)
1709}
1710
1711/// Find the maximum value in an f32 array with SIMD acceleration (returns `f32::NEG_INFINITY` for an empty input)
1712#[inline]
1713pub fn max_f32(values: &[f32], count: usize) -> f32 {
1714    if count == 0 {
1715        return f32::NEG_INFINITY;
1716    }
1717
1718    #[cfg(target_arch = "aarch64")]
1719    {
1720        if neon::is_available() {
1721            return unsafe { max_f32_neon(values, count) };
1722        }
1723    }
1724
1725    #[cfg(target_arch = "x86_64")]
1726    {
1727        if sse::is_available() {
1728            return unsafe { max_f32_sse(values, count) };
1729        }
1730    }
1731
1732    // Scalar fallback
1733    values[..count]
1734        .iter()
1735        .cloned()
1736        .fold(f32::NEG_INFINITY, f32::max)
1737}
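
// Editor's illustrative test (not in the original source): checks `max_f32`
// against hand-computed results, including the prefix-`count` and empty cases.
#[cfg(test)]
#[test]
fn max_f32_matches_expected_values() {
    let values = [3.0f32, -7.5, 12.25, 0.0, 11.9, -0.5, 4.0];
    assert_eq!(max_f32(&values, values.len()), 12.25);
    // Only the first `count` elements are considered.
    assert_eq!(max_f32(&values, 2), 3.0);
    // An empty input yields the identity element for max.
    assert_eq!(max_f32(&values, 0), f32::NEG_INFINITY);
}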
1738
1739#[cfg(target_arch = "aarch64")]
1740#[target_feature(enable = "neon")]
1741#[allow(unsafe_op_in_unsafe_fn)]
1742unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
1743    use std::arch::aarch64::*;
1744
1745    let chunks = count / 4;
1746    let remainder = count % 4;
1747
1748    let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);
1749
1750    for chunk in 0..chunks {
1751        let base = chunk * 4;
1752        let v = vld1q_f32(values.as_ptr().add(base));
1753        max_v = vmaxq_f32(max_v, v);
1754    }
1755
1756    // Horizontal max
1757    let mut max_val = vmaxvq_f32(max_v);
1758
1759    // Handle remainder
1760    let base = chunks * 4;
1761    for i in 0..remainder {
1762        max_val = max_val.max(values[base + i]);
1763    }
1764
1765    max_val
1766}
1767
1768#[cfg(target_arch = "x86_64")]
1769#[target_feature(enable = "sse")]
1770#[allow(unsafe_op_in_unsafe_fn)]
1771unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
1772    use std::arch::x86_64::*;
1773
1774    let chunks = count / 4;
1775    let remainder = count % 4;
1776
1777    let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);
1778
1779    for chunk in 0..chunks {
1780        let base = chunk * 4;
1781        let v = _mm_loadu_ps(values.as_ptr().add(base));
1782        max_v = _mm_max_ps(max_v, v);
1783    }
1784
1785    // Horizontal max: [a, b, c, d] -> max(a, b, c, d)
1786    let shuf = _mm_shuffle_ps(max_v, max_v, 0b10_11_00_01); // [b, a, d, c]
1787    let max1 = _mm_max_ps(max_v, shuf); // [max(a,b), max(a,b), max(c,d), max(c,d)]
1788    let shuf2 = _mm_movehl_ps(max1, max1); // [max(c,d), max(c,d), ?, ?]
1789    let final_max = _mm_max_ss(max1, shuf2); // [max(a,b,c,d), ?, ?, ?]
1790
1791    let mut max_val = _mm_cvtss_f32(final_max);
1792
1793    // Handle remainder
1794    let base = chunks * 4;
1795    for i in 0..remainder {
1796        max_val = max_val.max(values[base + i]);
1797    }
1798
1799    max_val
1800}
1801
1802// ============================================================================
1803// Batched Cosine Similarity for Dense Vector Search
1804// ============================================================================
1805
1806/// Fused dot-product + self-norm in a single pass (SIMD accelerated).
1807///
1808/// Returns (dot(a, b), dot(b, b)) — i.e. the dot product of a·b and ||b||².
1809/// Loads `b` only once (halves memory bandwidth vs two separate dot products).
1810#[inline]
1811fn fused_dot_norm(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1812    #[cfg(target_arch = "aarch64")]
1813    {
1814        if neon::is_available() {
1815            return unsafe { fused_dot_norm_neon(a, b, count) };
1816        }
1817    }
1818
1819    #[cfg(target_arch = "x86_64")]
1820    {
1821        if is_x86_feature_detected!("avx512f") {
1822            return unsafe { fused_dot_norm_avx512(a, b, count) };
1823        }
1824        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
1825            return unsafe { fused_dot_norm_avx2(a, b, count) };
1826        }
1827        if sse::is_available() {
1828            return unsafe { fused_dot_norm_sse(a, b, count) };
1829        }
1830    }
1831
1832    // Scalar fallback
1833    let mut dot = 0.0f32;
1834    let mut norm_b = 0.0f32;
1835    for i in 0..count {
1836        dot += a[i] * b[i];
1837        norm_b += b[i] * b[i];
1838    }
1839    (dot, norm_b)
1840}
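
// Editor's illustrative test (not in the original source): `fused_dot_norm`
// should agree with two separate `dot_product_f32` calls; only the memory
// traffic differs. The tolerance is chosen loosely to absorb summation-order
// differences between the SIMD and scalar paths.
#[cfg(test)]
#[test]
fn fused_dot_norm_matches_separate_dot_products() {
    let a: Vec<f32> = (0..37).map(|i| (i as f32) * 0.25 - 4.0).collect();
    let b: Vec<f32> = (0..37).map(|i| 2.0 - (i as f32) * 0.125).collect();

    let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
    let expected_dot = dot_product_f32(&a, &b, a.len());
    let expected_norm = dot_product_f32(&b, &b, b.len());

    assert!((dot - expected_dot).abs() < 1e-3 * expected_dot.abs().max(1.0));
    assert!((norm_b - expected_norm).abs() < 1e-3 * expected_norm.max(1.0));
}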
1841
1842#[cfg(target_arch = "aarch64")]
1843#[target_feature(enable = "neon")]
1844#[allow(unsafe_op_in_unsafe_fn)]
1845unsafe fn fused_dot_norm_neon(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1846    use std::arch::aarch64::*;
1847
1848    let chunks16 = count / 16;
1849    let remainder = count % 16;
1850
1851    let mut d0 = vdupq_n_f32(0.0);
1852    let mut d1 = vdupq_n_f32(0.0);
1853    let mut d2 = vdupq_n_f32(0.0);
1854    let mut d3 = vdupq_n_f32(0.0);
1855    let mut n0 = vdupq_n_f32(0.0);
1856    let mut n1 = vdupq_n_f32(0.0);
1857    let mut n2 = vdupq_n_f32(0.0);
1858    let mut n3 = vdupq_n_f32(0.0);
1859
1860    for c in 0..chunks16 {
1861        let base = c * 16;
1862        let va0 = vld1q_f32(a.as_ptr().add(base));
1863        let vb0 = vld1q_f32(b.as_ptr().add(base));
1864        d0 = vfmaq_f32(d0, va0, vb0);
1865        n0 = vfmaq_f32(n0, vb0, vb0);
1866        let va1 = vld1q_f32(a.as_ptr().add(base + 4));
1867        let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
1868        d1 = vfmaq_f32(d1, va1, vb1);
1869        n1 = vfmaq_f32(n1, vb1, vb1);
1870        let va2 = vld1q_f32(a.as_ptr().add(base + 8));
1871        let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
1872        d2 = vfmaq_f32(d2, va2, vb2);
1873        n2 = vfmaq_f32(n2, vb2, vb2);
1874        let va3 = vld1q_f32(a.as_ptr().add(base + 12));
1875        let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
1876        d3 = vfmaq_f32(d3, va3, vb3);
1877        n3 = vfmaq_f32(n3, vb3, vb3);
1878    }
1879
1880    let acc_dot = vaddq_f32(vaddq_f32(d0, d1), vaddq_f32(d2, d3));
1881    let acc_norm = vaddq_f32(vaddq_f32(n0, n1), vaddq_f32(n2, n3));
1882    let mut dot = vaddvq_f32(acc_dot);
1883    let mut norm = vaddvq_f32(acc_norm);
1884
1885    let base = chunks16 * 16;
1886    for i in 0..remainder {
1887        dot += a[base + i] * b[base + i];
1888        norm += b[base + i] * b[base + i];
1889    }
1890
1891    (dot, norm)
1892}
1893
1894#[cfg(target_arch = "x86_64")]
1895#[target_feature(enable = "avx2", enable = "fma")]
1896#[allow(unsafe_op_in_unsafe_fn)]
1897unsafe fn fused_dot_norm_avx2(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1898    use std::arch::x86_64::*;
1899
1900    let chunks32 = count / 32;
1901    let remainder = count % 32;
1902
1903    let mut d0 = _mm256_setzero_ps();
1904    let mut d1 = _mm256_setzero_ps();
1905    let mut d2 = _mm256_setzero_ps();
1906    let mut d3 = _mm256_setzero_ps();
1907    let mut n0 = _mm256_setzero_ps();
1908    let mut n1 = _mm256_setzero_ps();
1909    let mut n2 = _mm256_setzero_ps();
1910    let mut n3 = _mm256_setzero_ps();
1911
1912    for c in 0..chunks32 {
1913        let base = c * 32;
1914        let vb0 = _mm256_loadu_ps(b.as_ptr().add(base));
1915        d0 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base)), vb0, d0);
1916        n0 = _mm256_fmadd_ps(vb0, vb0, n0);
1917        let vb1 = _mm256_loadu_ps(b.as_ptr().add(base + 8));
1918        d1 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 8)), vb1, d1);
1919        n1 = _mm256_fmadd_ps(vb1, vb1, n1);
1920        let vb2 = _mm256_loadu_ps(b.as_ptr().add(base + 16));
1921        d2 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 16)), vb2, d2);
1922        n2 = _mm256_fmadd_ps(vb2, vb2, n2);
1923        let vb3 = _mm256_loadu_ps(b.as_ptr().add(base + 24));
1924        d3 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 24)), vb3, d3);
1925        n3 = _mm256_fmadd_ps(vb3, vb3, n3);
1926    }
1927
1928    let acc_dot = _mm256_add_ps(_mm256_add_ps(d0, d1), _mm256_add_ps(d2, d3));
1929    let acc_norm = _mm256_add_ps(_mm256_add_ps(n0, n1), _mm256_add_ps(n2, n3));
1930
1931    // Horizontal sums: 256→128→scalar
1932    let hi_d = _mm256_extractf128_ps(acc_dot, 1);
1933    let lo_d = _mm256_castps256_ps128(acc_dot);
1934    let sum_d = _mm_add_ps(lo_d, hi_d);
1935    let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
1936    let sums_d = _mm_add_ps(sum_d, shuf_d);
1937    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
1938    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
1939
1940    let hi_n = _mm256_extractf128_ps(acc_norm, 1);
1941    let lo_n = _mm256_castps256_ps128(acc_norm);
1942    let sum_n = _mm_add_ps(lo_n, hi_n);
1943    let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
1944    let sums_n = _mm_add_ps(sum_n, shuf_n);
1945    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
1946    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
1947
1948    let base = chunks32 * 32;
1949    for i in 0..remainder {
1950        dot += a[base + i] * b[base + i];
1951        norm += b[base + i] * b[base + i];
1952    }
1953
1954    (dot, norm)
1955}
1956
1957#[cfg(target_arch = "x86_64")]
1958#[target_feature(enable = "sse")]
1959#[allow(unsafe_op_in_unsafe_fn)]
1960unsafe fn fused_dot_norm_sse(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
1961    use std::arch::x86_64::*;
1962
1963    let chunks = count / 4;
1964    let remainder = count % 4;
1965
1966    let mut acc_dot = _mm_setzero_ps();
1967    let mut acc_norm = _mm_setzero_ps();
1968
1969    for chunk in 0..chunks {
1970        let base = chunk * 4;
1971        let va = _mm_loadu_ps(a.as_ptr().add(base));
1972        let vb = _mm_loadu_ps(b.as_ptr().add(base));
1973        acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
1974        acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
1975    }
1976
1977    // Horizontal sums
1978    let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
1979    let sums_d = _mm_add_ps(acc_dot, shuf_d);
1980    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
1981    let final_d = _mm_add_ss(sums_d, shuf2_d);
1982    let mut dot = _mm_cvtss_f32(final_d);
1983
1984    let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
1985    let sums_n = _mm_add_ps(acc_norm, shuf_n);
1986    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
1987    let final_n = _mm_add_ss(sums_n, shuf2_n);
1988    let mut norm = _mm_cvtss_f32(final_n);
1989
1990    let base = chunks * 4;
1991    for i in 0..remainder {
1992        dot += a[base + i] * b[base + i];
1993        norm += b[base + i] * b[base + i];
1994    }
1995
1996    (dot, norm)
1997}
1998
1999/// Fast approximate reciprocal square root: 1/sqrt(x).
2000///
2001/// Uses the IEEE 754 bit trick (Quake III magic constant) plus two Newton-Raphson
2002/// refinement steps, which is accurate enough for cosine similarity scoring.
2003/// ~3-5x faster than `1.0 / x.sqrt()` on most architectures.
2004#[inline]
2005pub fn fast_inv_sqrt(x: f32) -> f32 {
2006    let half = 0.5 * x;
2007    let i = 0x5F37_5A86_u32.wrapping_sub(x.to_bits() >> 1);
2008    let y = f32::from_bits(i);
2009    let y = y * (1.5 - half * y * y); // first Newton-Raphson step
2010    y * (1.5 - half * y * y) // second step: ~23-bit precision
2011}
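
// Editor's illustrative test (not in the original source): after its two
// refinement steps the approximation should track 1/sqrt(x) closely; the
// 1e-4 relative tolerance is a loose bound chosen by the editor.
#[cfg(test)]
#[test]
fn fast_inv_sqrt_tracks_exact_value() {
    for &x in &[1e-3f32, 0.25, 1.0, 2.0, 42.0, 1e6] {
        let approx = fast_inv_sqrt(x);
        let exact = 1.0 / x.sqrt();
        assert!((approx - exact).abs() / exact < 1e-4, "x = {x}");
    }
}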
2012
2013/// Batch cosine similarity: query vs N contiguous vectors.
2014///
2015/// `vectors` is a contiguous buffer of `n * dim` floats (row-major).
2016/// `scores` must have length >= n.
2017///
2018/// Optimizations over calling `cosine_similarity` N times:
2019/// 1. Query norm computed once (not N times)
2020/// 2. Fused dot+norm kernel — each vector loaded once (halves bandwidth)
2021/// 3. No per-call overhead (branch prediction, function calls)
2022/// 4. Fast reciprocal square root (~3-5x faster than 1/sqrt)
2023#[inline]
2024pub fn batch_cosine_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
2025    let n = scores.len();
2026    debug_assert!(vectors.len() >= n * dim);
2027    debug_assert_eq!(query.len(), dim);
2028
2029    if dim == 0 || n == 0 {
2030        return;
2031    }
2032
2033    // Pre-compute query inverse norm once
2034    let norm_q_sq = dot_product_f32(query, query, dim);
2035    if norm_q_sq < f32::EPSILON {
2036        for s in scores.iter_mut() {
2037            *s = 0.0;
2038        }
2039        return;
2040    }
2041    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2042
2043    for i in 0..n {
2044        let vec = &vectors[i * dim..(i + 1) * dim];
2045        let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
2046        if norm_v_sq < f32::EPSILON {
2047            scores[i] = 0.0;
2048        } else {
2049            scores[i] = dot * inv_norm_q * fast_inv_sqrt(norm_v_sq);
2050        }
2051    }
2052}
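
// Editor's illustrative test (not in the original source): the batched path
// should agree with per-pair `cosine_similarity` up to the small error of
// `fast_inv_sqrt`.
#[cfg(test)]
#[test]
fn batch_cosine_scores_matches_pairwise_cosine() {
    let dim = 6;
    let query = [1.0f32, 0.5, -0.25, 2.0, 0.0, -1.5];
    let vectors = [
        0.5f32, 1.0, 0.0, -1.0, 2.0, 0.25, // vector 0
        -1.0, 0.5, 0.5, 1.5, -0.5, 2.0, // vector 1
    ];
    let mut scores = [0.0f32; 2];
    batch_cosine_scores(&query, &vectors, dim, &mut scores);
    for i in 0..2 {
        let expected = cosine_similarity(&query, &vectors[i * dim..(i + 1) * dim]);
        assert!((scores[i] - expected).abs() < 1e-4, "vector {i}");
    }
}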
2053
2054// ============================================================================
2055// f16 (IEEE 754 half-precision) conversion
2056// ============================================================================
2057
2058/// Convert f32 to f16 (IEEE 754 half-precision), stored as u16
2059#[inline]
2060pub fn f32_to_f16(value: f32) -> u16 {
2061    let bits = value.to_bits();
2062    let sign = (bits >> 16) & 0x8000;
2063    let exp = ((bits >> 23) & 0xFF) as i32;
2064    let mantissa = bits & 0x7F_FFFF;
2065
2066    if exp == 255 {
2067        // Inf/NaN
2068        return (sign | 0x7C00 | ((mantissa >> 13) & 0x3FF)) as u16;
2069    }
2070
2071    let exp16 = exp - 127 + 15;
2072
2073    if exp16 >= 31 {
2074        return (sign | 0x7C00) as u16; // overflow → infinity
2075    }
2076
2077    if exp16 <= 0 {
2078        if exp16 < -10 {
2079            return sign as u16; // too small → zero
2080        }
2081        let m = (mantissa | 0x80_0000) >> (1 - exp16);
2082        return (sign | (m >> 13)) as u16;
2083    }
2084
2085    (sign | ((exp16 as u32) << 10) | (mantissa >> 13)) as u16
2086}
2087
2088/// Convert f16 (stored as u16) to f32
2089#[inline]
2090pub fn f16_to_f32(half: u16) -> f32 {
2091    let sign = ((half & 0x8000) as u32) << 16;
2092    let exp = ((half >> 10) & 0x1F) as u32;
2093    let mantissa = (half & 0x3FF) as u32;
2094
2095    if exp == 0 {
2096        if mantissa == 0 {
2097            return f32::from_bits(sign);
2098        }
2099        // Subnormal: normalize
2100        let mut e = 0u32;
2101        let mut m = mantissa;
2102        while (m & 0x400) == 0 {
2103            m <<= 1;
2104            e += 1;
2105        }
2106        return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | ((m & 0x3FF) << 13));
2107    }
2108
2109    if exp == 31 {
2110        return f32::from_bits(sign | 0x7F80_0000 | (mantissa << 13));
2111    }
2112
2113    f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mantissa << 13))
2114}
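
// Editor's illustrative test (not in the original source): values that are
// exactly representable in f16 survive the round trip, and out-of-range
// values saturate to infinity as handled above.
#[cfg(test)]
#[test]
fn f16_round_trip_and_overflow() {
    for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.09375, 2.75, 65504.0] {
        assert_eq!(f16_to_f32(f32_to_f16(v)), v);
    }
    // Largest finite f16 value is 65504; anything above overflows to infinity.
    assert_eq!(f16_to_f32(f32_to_f16(70000.0)), f32::INFINITY);
    assert_eq!(f16_to_f32(f32_to_f16(-70000.0)), f32::NEG_INFINITY);
}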
2115
2116// ============================================================================
2117// uint8 scalar quantization for [-1, 1] range
2118// ============================================================================
2119
2120const U8_SCALE: f32 = 127.5;
2121const U8_INV_SCALE: f32 = 1.0 / 127.5;
2122
2123/// Quantize f32 in [-1, 1] to u8 [0, 255]
2124#[inline]
2125pub fn f32_to_u8_saturating(value: f32) -> u8 {
2126    ((value.clamp(-1.0, 1.0) + 1.0) * U8_SCALE) as u8
2127}
2128
2129/// Dequantize u8 [0, 255] to f32 in [-1, 1]
2130#[inline]
2131pub fn u8_to_f32(byte: u8) -> f32 {
2132    byte as f32 * U8_INV_SCALE - 1.0
2133}
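
// Editor's illustrative test (not in the original source): the u8 codec
// saturates outside [-1, 1] and reconstructs in-range values to within one
// quantization step (1/127.5 ≈ 0.0078); the 0.008 bound below is a loose
// editor-chosen tolerance.
#[cfg(test)]
#[test]
fn u8_quantization_saturates_and_bounds_error() {
    assert_eq!(f32_to_u8_saturating(-5.0), 0);
    assert_eq!(f32_to_u8_saturating(5.0), 255);
    for i in 0..=200 {
        let x = -1.0 + (i as f32) * 0.01;
        let err = (u8_to_f32(f32_to_u8_saturating(x)) - x).abs();
        assert!(err < 0.008, "x = {x}, err = {err}");
    }
}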
2134
2135// ============================================================================
2136// Batch conversion (used during builder write)
2137// ============================================================================
2138
2139/// Batch convert f32 slice to f16 (stored as u16)
2140pub fn batch_f32_to_f16(src: &[f32], dst: &mut [u16]) {
2141    debug_assert_eq!(src.len(), dst.len());
2142    for (s, d) in src.iter().zip(dst.iter_mut()) {
2143        *d = f32_to_f16(*s);
2144    }
2145}
2146
2147/// Batch convert f32 slice to u8 with [-1,1] → [0,255] mapping
2148pub fn batch_f32_to_u8(src: &[f32], dst: &mut [u8]) {
2149    debug_assert_eq!(src.len(), dst.len());
2150    for (s, d) in src.iter().zip(dst.iter_mut()) {
2151        *d = f32_to_u8_saturating(*s);
2152    }
2153}
2154
2155// ============================================================================
2156// NEON-accelerated fused dot+norm for quantized vectors
2157// ============================================================================
2158
2159#[cfg(target_arch = "aarch64")]
2160#[allow(unsafe_op_in_unsafe_fn)]
2161mod neon_quant {
2162    use std::arch::aarch64::*;
2163
2164    /// Fused dot(query_f16, vec_f16) + norm(vec_f16) for f16 vectors on NEON.
2165    ///
2166    /// Both query and vectors are f16 (stored as u16). Uses hardware `vcvt_f32_f16`
2167    /// for SIMD f16→f32 conversion (replaces scalar bit manipulation), processes
2168    /// 8 elements per iteration with f32 accumulation for precision.
2169    #[target_feature(enable = "neon")]
2170    pub unsafe fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2171        let chunks8 = dim / 8;
2172        let remainder = dim % 8;
2173
2174        let mut acc_dot = vdupq_n_f32(0.0);
2175        let mut acc_norm = vdupq_n_f32(0.0);
2176
2177        for c in 0..chunks8 {
2178            let base = c * 8;
2179
2180            // Load 8 f16 vector values, hardware-convert to 2×4 f32
2181            let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
2182            let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
2183            let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
2184
2185            // Load 8 f16 query values, hardware-convert to 2×4 f32
2186            let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
2187            let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
2188            let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
2189
2190            acc_dot = vfmaq_f32(acc_dot, q_lo, v_lo);
2191            acc_dot = vfmaq_f32(acc_dot, q_hi, v_hi);
2192            acc_norm = vfmaq_f32(acc_norm, v_lo, v_lo);
2193            acc_norm = vfmaq_f32(acc_norm, v_hi, v_hi);
2194        }
2195
2196        let mut dot = vaddvq_f32(acc_dot);
2197        let mut norm = vaddvq_f32(acc_norm);
2198
2199        let base = chunks8 * 8;
2200        for i in 0..remainder {
2201            let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2202            let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2203            dot += q * v;
2204            norm += v * v;
2205        }
2206
2207        (dot, norm)
2208    }
2209
2210    /// Fused dot(query, vec) + norm(vec) for u8 vectors on NEON.
2211    /// Processes 16 u8 values per iteration using NEON widening chain.
2212    #[target_feature(enable = "neon")]
2213    pub unsafe fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2214        let scale = vdupq_n_f32(super::U8_INV_SCALE);
2215        let offset = vdupq_n_f32(-1.0);
2216
2217        let chunks16 = dim / 16;
2218        let remainder = dim % 16;
2219
2220        let mut acc_dot = vdupq_n_f32(0.0);
2221        let mut acc_norm = vdupq_n_f32(0.0);
2222
2223        for c in 0..chunks16 {
2224            let base = c * 16;
2225
2226            // Load 16 u8 values
2227            let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2228
2229            // Widen: 16×u8 → 2×8×u16 → 4×4×u32 → 4×4×f32
2230            let lo8 = vget_low_u8(bytes);
2231            let hi8 = vget_high_u8(bytes);
2232            let lo16 = vmovl_u8(lo8);
2233            let hi16 = vmovl_u8(hi8);
2234
2235            let f0 = vaddq_f32(
2236                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2237                offset,
2238            );
2239            let f1 = vaddq_f32(
2240                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2241                offset,
2242            );
2243            let f2 = vaddq_f32(
2244                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2245                offset,
2246            );
2247            let f3 = vaddq_f32(
2248                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2249                offset,
2250            );
2251
2252            let q0 = vld1q_f32(query.as_ptr().add(base));
2253            let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2254            let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2255            let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2256
2257            acc_dot = vfmaq_f32(acc_dot, q0, f0);
2258            acc_dot = vfmaq_f32(acc_dot, q1, f1);
2259            acc_dot = vfmaq_f32(acc_dot, q2, f2);
2260            acc_dot = vfmaq_f32(acc_dot, q3, f3);
2261
2262            acc_norm = vfmaq_f32(acc_norm, f0, f0);
2263            acc_norm = vfmaq_f32(acc_norm, f1, f1);
2264            acc_norm = vfmaq_f32(acc_norm, f2, f2);
2265            acc_norm = vfmaq_f32(acc_norm, f3, f3);
2266        }
2267
2268        let mut dot = vaddvq_f32(acc_dot);
2269        let mut norm = vaddvq_f32(acc_norm);
2270
2271        let base = chunks16 * 16;
2272        for i in 0..remainder {
2273            let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2274            dot += *query.get_unchecked(base + i) * v;
2275            norm += v * v;
2276        }
2277
2278        (dot, norm)
2279    }
2280
2281    /// Dot product only for f16 vectors on NEON (no norm — for unit_norm vectors).
2282    #[target_feature(enable = "neon")]
2283    pub unsafe fn dot_product_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2284        let chunks8 = dim / 8;
2285        let remainder = dim % 8;
2286
2287        let mut acc = vdupq_n_f32(0.0);
2288
2289        for c in 0..chunks8 {
2290            let base = c * 8;
2291            let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
2292            let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
2293            let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
2294            let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
2295            let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
2296            let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
2297            acc = vfmaq_f32(acc, q_lo, v_lo);
2298            acc = vfmaq_f32(acc, q_hi, v_hi);
2299        }
2300
2301        let mut dot = vaddvq_f32(acc);
2302        let base = chunks8 * 8;
2303        for i in 0..remainder {
2304            let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
2305            let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
2306            dot += q * v;
2307        }
2308        dot
2309    }
2310
2311    /// Dot product only for u8 vectors on NEON (no norm — for unit_norm vectors).
2312    #[target_feature(enable = "neon")]
2313    pub unsafe fn dot_product_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2314        let scale = vdupq_n_f32(super::U8_INV_SCALE);
2315        let offset = vdupq_n_f32(-1.0);
2316        let chunks16 = dim / 16;
2317        let remainder = dim % 16;
2318
2319        let mut acc = vdupq_n_f32(0.0);
2320
2321        for c in 0..chunks16 {
2322            let base = c * 16;
2323            let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
2324            let lo8 = vget_low_u8(bytes);
2325            let hi8 = vget_high_u8(bytes);
2326            let lo16 = vmovl_u8(lo8);
2327            let hi16 = vmovl_u8(hi8);
2328            let f0 = vaddq_f32(
2329                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
2330                offset,
2331            );
2332            let f1 = vaddq_f32(
2333                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
2334                offset,
2335            );
2336            let f2 = vaddq_f32(
2337                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
2338                offset,
2339            );
2340            let f3 = vaddq_f32(
2341                vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
2342                offset,
2343            );
2344            let q0 = vld1q_f32(query.as_ptr().add(base));
2345            let q1 = vld1q_f32(query.as_ptr().add(base + 4));
2346            let q2 = vld1q_f32(query.as_ptr().add(base + 8));
2347            let q3 = vld1q_f32(query.as_ptr().add(base + 12));
2348            acc = vfmaq_f32(acc, q0, f0);
2349            acc = vfmaq_f32(acc, q1, f1);
2350            acc = vfmaq_f32(acc, q2, f2);
2351            acc = vfmaq_f32(acc, q3, f3);
2352        }
2353
2354        let mut dot = vaddvq_f32(acc);
2355        let base = chunks16 * 16;
2356        for i in 0..remainder {
2357            let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
2358            dot += *query.get_unchecked(base + i) * v;
2359        }
2360        dot
2361    }
2362}
2363
2364// ============================================================================
2365// Scalar fallback for fused dot+norm on quantized vectors
2366// ============================================================================
2367
2368#[allow(dead_code)]
2369fn fused_dot_norm_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2370    let mut dot = 0.0f32;
2371    let mut norm = 0.0f32;
2372    for i in 0..dim {
2373        let v = f16_to_f32(vec_f16[i]);
2374        let q = f16_to_f32(query_f16[i]);
2375        dot += q * v;
2376        norm += v * v;
2377    }
2378    (dot, norm)
2379}
2380
2381#[allow(dead_code)]
2382fn fused_dot_norm_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2383    let mut dot = 0.0f32;
2384    let mut norm = 0.0f32;
2385    for i in 0..dim {
2386        let v = u8_to_f32(vec_u8[i]);
2387        dot += query[i] * v;
2388        norm += v * v;
2389    }
2390    (dot, norm)
2391}
2392
2393#[allow(dead_code)]
2394fn dot_product_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2395    let mut dot = 0.0f32;
2396    for i in 0..dim {
2397        dot += f16_to_f32(query_f16[i]) * f16_to_f32(vec_f16[i]);
2398    }
2399    dot
2400}
2401
2402#[allow(dead_code)]
2403fn dot_product_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2404    let mut dot = 0.0f32;
2405    for i in 0..dim {
2406        dot += query[i] * u8_to_f32(vec_u8[i]);
2407    }
2408    dot
2409}
2410
2411// ============================================================================
2412// x86_64 SSE4.1 quantized fused dot+norm
2413// ============================================================================
2414
2415#[cfg(target_arch = "x86_64")]
2416#[target_feature(enable = "sse2", enable = "sse4.1")]
2417#[allow(unsafe_op_in_unsafe_fn)]
2418unsafe fn fused_dot_norm_f16_sse(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2419    use std::arch::x86_64::*;
2420
2421    let chunks = dim / 4;
2422    let remainder = dim % 4;
2423
2424    let mut acc_dot = _mm_setzero_ps();
2425    let mut acc_norm = _mm_setzero_ps();
2426
2427    for chunk in 0..chunks {
2428        let base = chunk * 4;
2429        // Load 4 f16 values and convert to f32 using scalar conversion
2430        let v0 = f16_to_f32(*vec_f16.get_unchecked(base));
2431        let v1 = f16_to_f32(*vec_f16.get_unchecked(base + 1));
2432        let v2 = f16_to_f32(*vec_f16.get_unchecked(base + 2));
2433        let v3 = f16_to_f32(*vec_f16.get_unchecked(base + 3));
2434        let vb = _mm_set_ps(v3, v2, v1, v0);
2435
2436        let q0 = f16_to_f32(*query_f16.get_unchecked(base));
2437        let q1 = f16_to_f32(*query_f16.get_unchecked(base + 1));
2438        let q2 = f16_to_f32(*query_f16.get_unchecked(base + 2));
2439        let q3 = f16_to_f32(*query_f16.get_unchecked(base + 3));
2440        let va = _mm_set_ps(q3, q2, q1, q0);
2441
2442        acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2443        acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2444    }
2445
2446    // Horizontal sums
2447    let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2448    let sums_d = _mm_add_ps(acc_dot, shuf_d);
2449    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2450    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2451
2452    let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2453    let sums_n = _mm_add_ps(acc_norm, shuf_n);
2454    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2455    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2456
2457    let base = chunks * 4;
2458    for i in 0..remainder {
2459        let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2460        let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2461        dot += q * v;
2462        norm += v * v;
2463    }
2464
2465    (dot, norm)
2466}
2467
2468#[cfg(target_arch = "x86_64")]
2469#[target_feature(enable = "sse2", enable = "sse4.1")]
2470#[allow(unsafe_op_in_unsafe_fn)]
2471unsafe fn fused_dot_norm_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2472    use std::arch::x86_64::*;
2473
2474    let scale = _mm_set1_ps(U8_INV_SCALE);
2475    let offset = _mm_set1_ps(-1.0);
2476
2477    let chunks = dim / 4;
2478    let remainder = dim % 4;
2479
2480    let mut acc_dot = _mm_setzero_ps();
2481    let mut acc_norm = _mm_setzero_ps();
2482
2483    for chunk in 0..chunks {
2484        let base = chunk * 4;
2485
2486        // Load 4 bytes, zero-extend to i32, convert to f32, dequantize
2487        let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2488            vec_u8.as_ptr().add(base) as *const i32
2489        ));
2490        let ints = _mm_cvtepu8_epi32(bytes);
2491        let floats = _mm_cvtepi32_ps(ints);
2492        let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2493
2494        let va = _mm_loadu_ps(query.as_ptr().add(base));
2495
2496        acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
2497        acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
2498    }
2499
2500    // Horizontal sums
2501    let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
2502    let sums_d = _mm_add_ps(acc_dot, shuf_d);
2503    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2504    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2505
2506    let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
2507    let sums_n = _mm_add_ps(acc_norm, shuf_n);
2508    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2509    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2510
2511    let base = chunks * 4;
2512    for i in 0..remainder {
2513        let v = u8_to_f32(*vec_u8.get_unchecked(base + i));
2514        dot += *query.get_unchecked(base + i) * v;
2515        norm += v * v;
2516    }
2517
2518    (dot, norm)
2519}
2520
2521// ============================================================================
2522// x86_64 F16C + AVX + FMA accelerated f16 scoring
2523// ============================================================================
2524
2525#[cfg(target_arch = "x86_64")]
2526#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2527#[allow(unsafe_op_in_unsafe_fn)]
2528unsafe fn fused_dot_norm_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2529    use std::arch::x86_64::*;
2530
2531    let chunks = dim / 8;
2532    let remainder = dim % 8;
2533
2534    let mut acc_dot = _mm256_setzero_ps();
2535    let mut acc_norm = _mm256_setzero_ps();
2536
2537    for chunk in 0..chunks {
2538        let base = chunk * 8;
2539        // Hardware f16→f32: 8 values at once via F16C
2540        let v_raw = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2541        let vb = _mm256_cvtph_ps(v_raw);
2542        let q_raw = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2543        let qa = _mm256_cvtph_ps(q_raw);
2544        acc_dot = _mm256_fmadd_ps(qa, vb, acc_dot);
2545        acc_norm = _mm256_fmadd_ps(vb, vb, acc_norm);
2546    }
2547
2548    // Horizontal sum 256→128→scalar
2549    let hi_d = _mm256_extractf128_ps(acc_dot, 1);
2550    let lo_d = _mm256_castps256_ps128(acc_dot);
2551    let sum_d = _mm_add_ps(lo_d, hi_d);
2552    let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
2553    let sums_d = _mm_add_ps(sum_d, shuf_d);
2554    let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
2555    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
2556
2557    let hi_n = _mm256_extractf128_ps(acc_norm, 1);
2558    let lo_n = _mm256_castps256_ps128(acc_norm);
2559    let sum_n = _mm_add_ps(lo_n, hi_n);
2560    let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
2561    let sums_n = _mm_add_ps(sum_n, shuf_n);
2562    let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
2563    let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
2564
2565    let base = chunks * 8;
2566    for i in 0..remainder {
2567        let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2568        let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2569        dot += q * v;
2570        norm += v * v;
2571    }
2572
2573    (dot, norm)
2574}
2575
2576#[cfg(target_arch = "x86_64")]
2577#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
2578#[allow(unsafe_op_in_unsafe_fn)]
2579unsafe fn dot_product_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2580    use std::arch::x86_64::*;
2581
2582    let chunks = dim / 8;
2583    let remainder = dim % 8;
2584    let mut acc = _mm256_setzero_ps();
2585
2586    for chunk in 0..chunks {
2587        let base = chunk * 8;
2588        let v_raw = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
2589        let vb = _mm256_cvtph_ps(v_raw);
2590        let q_raw = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
2591        let qa = _mm256_cvtph_ps(q_raw);
2592        acc = _mm256_fmadd_ps(qa, vb, acc);
2593    }
2594
2595    let hi = _mm256_extractf128_ps(acc, 1);
2596    let lo = _mm256_castps256_ps128(acc);
2597    let sum = _mm_add_ps(lo, hi);
2598    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
2599    let sums = _mm_add_ps(sum, shuf);
2600    let shuf2 = _mm_movehl_ps(sums, sums);
2601    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2602
2603    let base = chunks * 8;
2604    for i in 0..remainder {
2605        let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
2606        let q = f16_to_f32(*query_f16.get_unchecked(base + i));
2607        dot += q * v;
2608    }
2609    dot
2610}
2611
2612#[cfg(target_arch = "x86_64")]
2613#[target_feature(enable = "sse2", enable = "sse4.1")]
2614#[allow(unsafe_op_in_unsafe_fn)]
2615unsafe fn dot_product_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2616    use std::arch::x86_64::*;
2617
2618    let scale = _mm_set1_ps(U8_INV_SCALE);
2619    let offset = _mm_set1_ps(-1.0);
2620    let chunks = dim / 4;
2621    let remainder = dim % 4;
2622    let mut acc = _mm_setzero_ps();
2623
2624    for chunk in 0..chunks {
2625        let base = chunk * 4;
2626        let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
2627            vec_u8.as_ptr().add(base) as *const i32
2628        ));
2629        let ints = _mm_cvtepu8_epi32(bytes);
2630        let floats = _mm_cvtepi32_ps(ints);
2631        let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
2632        let va = _mm_loadu_ps(query.as_ptr().add(base));
2633        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
2634    }
2635
2636    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
2637    let sums = _mm_add_ps(acc, shuf);
2638    let shuf2 = _mm_movehl_ps(sums, sums);
2639    let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
2640
2641    let base = chunks * 4;
2642    for i in 0..remainder {
2643        dot += *query.get_unchecked(base + i) * u8_to_f32(*vec_u8.get_unchecked(base + i));
2644    }
2645    dot
2646}
2647
2648// ============================================================================
2649// Platform dispatch
2650// ============================================================================
2651
2652#[inline]
2653fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
2654    #[cfg(target_arch = "aarch64")]
2655    {
2656        return unsafe { neon_quant::fused_dot_norm_f16(query_f16, vec_f16, dim) };
2657    }
2658
2659    #[cfg(target_arch = "x86_64")]
2660    {
2661        if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2662            return unsafe { fused_dot_norm_f16_f16c(query_f16, vec_f16, dim) };
2663        }
2664        if sse::is_available() {
2665            return unsafe { fused_dot_norm_f16_sse(query_f16, vec_f16, dim) };
2666        }
2667    }
2668
2669    #[allow(unreachable_code)]
2670    fused_dot_norm_f16_scalar(query_f16, vec_f16, dim)
2671}
2672
2673#[inline]
2674fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
2675    #[cfg(target_arch = "aarch64")]
2676    {
2677        return unsafe { neon_quant::fused_dot_norm_u8(query, vec_u8, dim) };
2678    }
2679
2680    #[cfg(target_arch = "x86_64")]
2681    {
2682        if sse::is_available() {
2683            return unsafe { fused_dot_norm_u8_sse(query, vec_u8, dim) };
2684        }
2685    }
2686
2687    #[allow(unreachable_code)]
2688    fused_dot_norm_u8_scalar(query, vec_u8, dim)
2689}
2690
2691// ── Dot-product-only dispatch (for unit_norm vectors) ─────────────────────
2692
2693#[inline]
2694fn dot_product_f16_quant(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
2695    #[cfg(target_arch = "aarch64")]
2696    {
2697        return unsafe { neon_quant::dot_product_f16(query_f16, vec_f16, dim) };
2698    }
2699
2700    #[cfg(target_arch = "x86_64")]
2701    {
2702        if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
2703            return unsafe { dot_product_f16_f16c(query_f16, vec_f16, dim) };
2704        }
2705    }
2706
2707    #[allow(unreachable_code)]
2708    dot_product_f16_scalar(query_f16, vec_f16, dim)
2709}
2710
2711#[inline]
2712fn dot_product_u8_quant(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
2713    #[cfg(target_arch = "aarch64")]
2714    {
2715        return unsafe { neon_quant::dot_product_u8(query, vec_u8, dim) };
2716    }
2717
2718    #[cfg(target_arch = "x86_64")]
2719    {
2720        if sse::is_available() {
2721            return unsafe { dot_product_u8_sse(query, vec_u8, dim) };
2722        }
2723    }
2724
2725    #[allow(unreachable_code)]
2726    dot_product_u8_scalar(query, vec_u8, dim)
2727}
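
// Editor's illustrative test (not in the original source): whichever platform
// kernel the dispatchers above select, the results should match the scalar
// reference implementations; any difference comes only from summation order.
#[cfg(test)]
#[test]
fn quantized_dispatch_matches_scalar_reference() {
    let dim = 19; // odd length exercises the remainder loops
    let query: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.37).sin()).collect();
    let values: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.11).cos()).collect();

    let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
    let vec_f16: Vec<u16> = values.iter().map(|&v| f32_to_f16(v)).collect();
    let (d, n) = fused_dot_norm_f16(&query_f16, &vec_f16, dim);
    let (ds, ns) = fused_dot_norm_f16_scalar(&query_f16, &vec_f16, dim);
    assert!((d - ds).abs() < 1e-3 && (n - ns).abs() < 1e-3);

    let vec_u8: Vec<u8> = values.iter().map(|&v| f32_to_u8_saturating(v)).collect();
    let (d, n) = fused_dot_norm_u8(&query, &vec_u8, dim);
    let (ds, ns) = fused_dot_norm_u8_scalar(&query, &vec_u8, dim);
    assert!((d - ds).abs() < 1e-3 && (n - ns).abs() < 1e-3);
}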
2728
2729// ============================================================================
2730// Public batch cosine scoring for quantized vectors
2731// ============================================================================
2732
2733/// Batch cosine similarity: f32 query vs N contiguous f16 vectors.
2734///
2735/// `vectors_raw` is raw bytes: N vectors × dim × 2 bytes (f16 stored as u16).
2736/// Query is quantized to f16 once, then both query and vectors are scored in
2737/// f16 space using hardware SIMD conversion (8 elements/iteration on NEON).
2738/// Memory bandwidth is halved for both query and vector loads.
2739#[inline]
2740pub fn batch_cosine_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2741    let n = scores.len();
2742    if dim == 0 || n == 0 {
2743        return;
2744    }
2745
2746    // Compute query inverse norm in f32 (full precision, before quantization)
2747    let norm_q_sq = dot_product_f32(query, query, dim);
2748    if norm_q_sq < f32::EPSILON {
2749        for s in scores.iter_mut() {
2750            *s = 0.0;
2751        }
2752        return;
2753    }
2754    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2755
2756    // Quantize query to f16 once (O(dim)), reused for all N vector scorings
2757    let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
2758
2759    let vec_bytes = dim * 2;
2760    debug_assert!(vectors_raw.len() >= n * vec_bytes);
2761
2762    // Vectors file uses data-first layout with 8-byte padding between fields,
2763    // so mmap slices are always 2-byte aligned for u16 access.
2764    debug_assert!(
2765        (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
2766        "f16 vector data not 2-byte aligned"
2767    );
2768
2769    for i in 0..n {
2770        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2771        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2772
2773        let (dot, norm_v_sq) = fused_dot_norm_f16(&query_f16, f16_slice, dim);
2774        scores[i] = if norm_v_sq < f32::EPSILON {
2775            0.0
2776        } else {
2777            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2778        };
2779    }
2780}
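
// Editor's illustrative test (not in the original source): builds the raw f16
// byte layout described above from a Vec<u16> (which guarantees 2-byte
// alignment) and checks the scores against full-precision cosine similarity.
// The 1e-2 tolerance is a loose editor-chosen bound for f16 quantization.
#[cfg(test)]
#[test]
fn batch_cosine_scores_f16_roughly_matches_f32_scoring() {
    let dim = 8;
    let query = vec![0.1f32, -0.2, 0.3, 0.4, -0.5, 0.6, 0.7, -0.8];
    let vecs_f32 = vec![
        vec![0.2f32, 0.1, -0.3, 0.5, 0.4, -0.6, 0.2, 0.1],
        vec![-0.1f32, 0.9, 0.2, -0.4, 0.3, 0.1, -0.7, 0.5],
    ];

    // Quantize into a contiguous u16 buffer, then view it as raw bytes.
    let mut packed = vec![0u16; 2 * dim];
    for (i, v) in vecs_f32.iter().enumerate() {
        batch_f32_to_f16(v, &mut packed[i * dim..(i + 1) * dim]);
    }
    let raw =
        unsafe { std::slice::from_raw_parts(packed.as_ptr() as *const u8, packed.len() * 2) };

    let mut scores = vec![0.0f32; 2];
    batch_cosine_scores_f16(&query, raw, dim, &mut scores);
    for (i, v) in vecs_f32.iter().enumerate() {
        let expected = cosine_similarity(&query, v);
        assert!((scores[i] - expected).abs() < 1e-2, "vector {i}");
    }
}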
2781
2782/// Batch cosine similarity: f32 query vs N contiguous u8 vectors.
2783///
2784/// `vectors_raw` is raw bytes: N vectors × dim bytes (u8, mapping [-1,1]→[0,255]).
2785/// Converts u8→f32 using NEON widening chain (16 values/iteration), scores with FMA.
2786/// Memory bandwidth is quartered compared to f32 scoring.
2787#[inline]
2788pub fn batch_cosine_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2789    let n = scores.len();
2790    if dim == 0 || n == 0 {
2791        return;
2792    }
2793
2794    let norm_q_sq = dot_product_f32(query, query, dim);
2795    if norm_q_sq < f32::EPSILON {
2796        for s in scores.iter_mut() {
2797            *s = 0.0;
2798        }
2799        return;
2800    }
2801    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2802
2803    debug_assert!(vectors_raw.len() >= n * dim);
2804
2805    for i in 0..n {
2806        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
2807
2808        let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
2809        scores[i] = if norm_v_sq < f32::EPSILON {
2810            0.0
2811        } else {
2812            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2813        };
2814    }
2815}
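
// Editor's illustrative test (not in the original source): quantize two
// vectors with `batch_f32_to_u8`, score them against an f32 query, and
// compare with full-precision cosine similarity. The 0.03 tolerance is a
// loose editor-chosen bound for the u8 quantization error.
#[cfg(test)]
#[test]
fn batch_cosine_scores_u8_roughly_matches_f32_scoring() {
    let dim = 5;
    let query = [0.9f32, -0.3, 0.2, 0.8, -0.6];
    let originals = [
        [0.5f32, 0.5, -0.5, 0.25, 0.75],
        [-0.9f32, 0.1, 0.6, -0.2, 0.3],
    ];

    let mut raw = vec![0u8; 2 * dim];
    for (i, v) in originals.iter().enumerate() {
        batch_f32_to_u8(v, &mut raw[i * dim..(i + 1) * dim]);
    }

    let mut scores = [0.0f32; 2];
    batch_cosine_scores_u8(&query, &raw, dim, &mut scores);
    for (i, v) in originals.iter().enumerate() {
        let expected = cosine_similarity(&query, v);
        assert!((scores[i] - expected).abs() < 0.03, "vector {i}");
    }
}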
2816
2817// ============================================================================
2818// Batch dot-product scoring for unit-norm vectors
2819// ============================================================================
2820
2821/// Batch dot-product scoring: f32 query vs N contiguous f32 unit-norm vectors.
2822///
2823/// For pre-normalized vectors (||v|| = 1), cosine = dot(q, v) / ||q||.
2824/// Skips per-vector norm computation — ~40% less work than `batch_cosine_scores`.
2825#[inline]
2826pub fn batch_dot_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
2827    let n = scores.len();
2828    debug_assert!(vectors.len() >= n * dim);
2829    debug_assert_eq!(query.len(), dim);
2830
2831    if dim == 0 || n == 0 {
2832        return;
2833    }
2834
2835    let norm_q_sq = dot_product_f32(query, query, dim);
2836    if norm_q_sq < f32::EPSILON {
2837        for s in scores.iter_mut() {
2838            *s = 0.0;
2839        }
2840        return;
2841    }
2842    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2843
2844    for i in 0..n {
2845        let vec = &vectors[i * dim..(i + 1) * dim];
2846        let dot = dot_product_f32(query, vec, dim);
2847        scores[i] = dot * inv_norm_q;
2848    }
2849}
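
// Editor's illustrative test (not in the original source): once the stored
// vectors are pre-normalized, the dot-only path should agree with the full
// cosine path, which is exactly the assumption documented above.
#[cfg(test)]
#[test]
fn batch_dot_scores_matches_cosine_for_unit_norm_vectors() {
    let dim = 4;
    let query = [0.3f32, -1.2, 0.7, 2.0];
    // Two vectors that are already unit-norm: [0.6, -0.8, 0, 0] and [0, 0, 1, 0].
    let vectors = [0.6f32, -0.8, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];

    let mut dot_scores = [0.0f32; 2];
    let mut cos_scores = [0.0f32; 2];
    batch_dot_scores(&query, &vectors, dim, &mut dot_scores);
    batch_cosine_scores(&query, &vectors, dim, &mut cos_scores);
    for i in 0..2 {
        assert!((dot_scores[i] - cos_scores[i]).abs() < 1e-4);
    }
}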
2850
2851/// Batch dot-product scoring: f32 query vs N contiguous f16 unit-norm vectors.
2852///
2853/// For pre-normalized vectors (||v|| = 1), cosine = dot(q, v) / ||q||.
2854/// Uses F16C/NEON hardware conversion + dot-only kernel.
2855#[inline]
2856pub fn batch_dot_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2857    let n = scores.len();
2858    if dim == 0 || n == 0 {
2859        return;
2860    }
2861
2862    let norm_q_sq = dot_product_f32(query, query, dim);
2863    if norm_q_sq < f32::EPSILON {
2864        for s in scores.iter_mut() {
2865            *s = 0.0;
2866        }
2867        return;
2868    }
2869    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2870
2871    let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
2872    let vec_bytes = dim * 2;
2873    debug_assert!(vectors_raw.len() >= n * vec_bytes);
2874    debug_assert!(
2875        (vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
2876        "f16 vector data not 2-byte aligned"
2877    );
2878
2879    for i in 0..n {
2880        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2881        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2882        let dot = dot_product_f16_quant(&query_f16, f16_slice, dim);
2883        scores[i] = dot * inv_norm_q;
2884    }
2885}
2886
2887/// Batch dot-product scoring: f32 query vs N contiguous u8 unit-norm vectors.
2888///
2889/// For pre-normalized vectors (||v|| = 1), cosine = dot(q, v) / ||q||.
2890/// Uses NEON/SSE widening chain for u8→f32 conversion + dot-only kernel.
2891#[inline]
2892pub fn batch_dot_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
2893    let n = scores.len();
2894    if dim == 0 || n == 0 {
2895        return;
2896    }
2897
2898    let norm_q_sq = dot_product_f32(query, query, dim);
2899    if norm_q_sq < f32::EPSILON {
2900        for s in scores.iter_mut() {
2901            *s = 0.0;
2902        }
2903        return;
2904    }
2905    let inv_norm_q = fast_inv_sqrt(norm_q_sq);
2906
2907    debug_assert!(vectors_raw.len() >= n * dim);
2908
2909    for i in 0..n {
2910        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
2911        let dot = dot_product_u8_quant(query, u8_slice, dim);
2912        scores[i] = dot * inv_norm_q;
2913    }
2914}
2915
2916// ============================================================================
2917// Precomputed-norm batch scoring (avoids redundant query norm + f16 conversion)
2918// ============================================================================
2919
2920/// Batch cosine: f32 query vs N f32 vectors, with precomputed `inv_norm_q`.
2921#[inline]
2922pub fn batch_cosine_scores_precomp(
2923    query: &[f32],
2924    vectors: &[f32],
2925    dim: usize,
2926    scores: &mut [f32],
2927    inv_norm_q: f32,
2928) {
2929    let n = scores.len();
2930    debug_assert!(vectors.len() >= n * dim);
2931    for i in 0..n {
2932        let vec = &vectors[i * dim..(i + 1) * dim];
2933        let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
2934        scores[i] = if norm_v_sq < f32::EPSILON {
2935            0.0
2936        } else {
2937            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2938        };
2939    }
2940}
2941
2942/// Batch cosine: precomputed `inv_norm_q` + `query_f16` vs N f16 vectors.
2943#[inline]
2944pub fn batch_cosine_scores_f16_precomp(
2945    query_f16: &[u16],
2946    vectors_raw: &[u8],
2947    dim: usize,
2948    scores: &mut [f32],
2949    inv_norm_q: f32,
2950) {
2951    let n = scores.len();
2952    let vec_bytes = dim * 2;
2953    debug_assert!(vectors_raw.len() >= n * vec_bytes);
2954    for i in 0..n {
2955        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
2956        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
2957        let (dot, norm_v_sq) = fused_dot_norm_f16(query_f16, f16_slice, dim);
2958        scores[i] = if norm_v_sq < f32::EPSILON {
2959            0.0
2960        } else {
2961            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2962        };
2963    }
2964}
2965
2966/// Batch cosine: precomputed `inv_norm_q` vs N u8 vectors.
2967#[inline]
2968pub fn batch_cosine_scores_u8_precomp(
2969    query: &[f32],
2970    vectors_raw: &[u8],
2971    dim: usize,
2972    scores: &mut [f32],
2973    inv_norm_q: f32,
2974) {
2975    let n = scores.len();
2976    debug_assert!(vectors_raw.len() >= n * dim);
2977    for i in 0..n {
2978        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
2979        let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
2980        scores[i] = if norm_v_sq < f32::EPSILON {
2981            0.0
2982        } else {
2983            dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
2984        };
2985    }
2986}
2987
2988/// Batch dot-product: precomputed `inv_norm_q` vs N f32 unit-norm vectors.
2989#[inline]
2990pub fn batch_dot_scores_precomp(
2991    query: &[f32],
2992    vectors: &[f32],
2993    dim: usize,
2994    scores: &mut [f32],
2995    inv_norm_q: f32,
2996) {
2997    let n = scores.len();
2998    debug_assert!(vectors.len() >= n * dim);
2999    for i in 0..n {
3000        let vec = &vectors[i * dim..(i + 1) * dim];
3001        scores[i] = dot_product_f32(query, vec, dim) * inv_norm_q;
3002    }
3003}
3004
3005/// Batch dot-product: precomputed `inv_norm_q` + `query_f16` vs N f16 unit-norm vectors.
3006#[inline]
3007pub fn batch_dot_scores_f16_precomp(
3008    query_f16: &[u16],
3009    vectors_raw: &[u8],
3010    dim: usize,
3011    scores: &mut [f32],
3012    inv_norm_q: f32,
3013) {
3014    let n = scores.len();
3015    let vec_bytes = dim * 2;
3016    debug_assert!(vectors_raw.len() >= n * vec_bytes);
3017    for i in 0..n {
3018        let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
3019        let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
3020        scores[i] = dot_product_f16_quant(query_f16, f16_slice, dim) * inv_norm_q;
3021    }
3022}
3023
3024/// Batch dot-product: precomputed `inv_norm_q` vs N u8 unit-norm vectors.
3025#[inline]
3026pub fn batch_dot_scores_u8_precomp(
3027    query: &[f32],
3028    vectors_raw: &[u8],
3029    dim: usize,
3030    scores: &mut [f32],
3031    inv_norm_q: f32,
3032) {
3033    let n = scores.len();
3034    debug_assert!(vectors_raw.len() >= n * dim);
3035    for i in 0..n {
3036        let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
3037        scores[i] = dot_product_u8_quant(query, u8_slice, dim) * inv_norm_q;
3038    }
3039}
3040
3041/// Compute cosine similarity between two f32 vectors with SIMD acceleration
3042///
3043/// Returns dot(a,b) / (||a|| * ||b||), range [-1, 1]
3044/// Returns 0.0 if either vector has zero norm.
3045#[inline]
3046pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
3047    debug_assert_eq!(a.len(), b.len());
3048    let count = a.len();
3049
3050    if count == 0 {
3051        return 0.0;
3052    }
3053
3054    let dot = dot_product_f32(a, b, count);
3055    let norm_a = dot_product_f32(a, a, count);
3056    let norm_b = dot_product_f32(b, b, count);
3057
3058    let denom = (norm_a * norm_b).sqrt();
3059    if denom < f32::EPSILON {
3060        return 0.0;
3061    }
3062
3063    dot / denom
3064}
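
// Editor's illustrative test (not in the original source): parallel vectors
// score 1, anti-parallel vectors score -1, orthogonal vectors score 0, and a
// zero-norm vector falls back to 0.0 as documented.
#[cfg(test)]
#[test]
fn cosine_similarity_basic_properties() {
    let a = [1.0f32, 2.0, 3.0, 4.0];
    let scaled: Vec<f32> = a.iter().map(|&x| 2.5 * x).collect();
    assert!((cosine_similarity(&a, &scaled) - 1.0).abs() < 1e-5);

    let negated: Vec<f32> = a.iter().map(|&x| -x).collect();
    assert!((cosine_similarity(&a, &negated) + 1.0).abs() < 1e-5);

    let orthogonal = [2.0f32, -1.0, 0.0, 0.0];
    assert!(cosine_similarity(&[1.0, 2.0, 0.0, 0.0], &orthogonal).abs() < 1e-6);

    assert_eq!(cosine_similarity(&a, &[0.0; 4]), 0.0);
}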
3065
3066/// Compute squared Euclidean distance between two f32 vectors with SIMD acceleration
3067///
3068/// Returns sum((a[i] - b[i])^2) for all i
3069#[inline]
3070pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
3071    debug_assert_eq!(a.len(), b.len());
3072    let count = a.len();
3073
3074    if count == 0 {
3075        return 0.0;
3076    }
3077
3078    #[cfg(target_arch = "aarch64")]
3079    {
3080        if neon::is_available() {
3081            return unsafe { squared_euclidean_neon(a, b, count) };
3082        }
3083    }
3084
3085    #[cfg(target_arch = "x86_64")]
3086    {
3087        if avx2::is_available() {
3088            return unsafe { squared_euclidean_avx2(a, b, count) };
3089        }
3090        if sse::is_available() {
3091            return unsafe { squared_euclidean_sse(a, b, count) };
3092        }
3093    }
3094
3095    // Scalar fallback
3096    a.iter()
3097        .zip(b.iter())
3098        .map(|(&x, &y)| {
3099            let d = x - y;
3100            d * d
3101        })
3102        .sum()
3103}
3104
3105#[cfg(target_arch = "aarch64")]
3106#[target_feature(enable = "neon")]
3107#[allow(unsafe_op_in_unsafe_fn)]
3108unsafe fn squared_euclidean_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
3109    use std::arch::aarch64::*;
3110
3111    let chunks = count / 4;
3112    let remainder = count % 4;
3113
3114    let mut acc = vdupq_n_f32(0.0);
3115
3116    for chunk in 0..chunks {
3117        let base = chunk * 4;
3118        let va = vld1q_f32(a.as_ptr().add(base));
3119        let vb = vld1q_f32(b.as_ptr().add(base));
3120        let diff = vsubq_f32(va, vb);
3121        acc = vfmaq_f32(acc, diff, diff); // acc += diff * diff (fused multiply-add)
3122    }
3123
3124    // Horizontal sum
3125    let mut sum = vaddvq_f32(acc);
3126
3127    // Handle remainder
3128    let base = chunks * 4;
3129    for i in 0..remainder {
3130        let d = a[base + i] - b[base + i];
3131        sum += d * d;
3132    }
3133
3134    sum
3135}
3136
3137#[cfg(target_arch = "x86_64")]
3138#[target_feature(enable = "sse")]
3139#[allow(unsafe_op_in_unsafe_fn)]
3140unsafe fn squared_euclidean_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
3141    use std::arch::x86_64::*;
3142
3143    let chunks = count / 4;
3144    let remainder = count % 4;
3145
3146    let mut acc = _mm_setzero_ps();
3147
3148    for chunk in 0..chunks {
3149        let base = chunk * 4;
3150        let va = _mm_loadu_ps(a.as_ptr().add(base));
3151        let vb = _mm_loadu_ps(b.as_ptr().add(base));
3152        let diff = _mm_sub_ps(va, vb);
3153        acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
3154    }
3155
3156    // Horizontal sum: [a, b, c, d] -> a + b + c + d
3157    let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01); // [b, a, d, c]
3158    let sums = _mm_add_ps(acc, shuf); // [a+b, a+b, c+d, c+d]
3159    let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, c+d, ?, ?]
3160    let final_sum = _mm_add_ss(sums, shuf2); // [a+b+c+d, ?, ?, ?]
3161
3162    let mut sum = _mm_cvtss_f32(final_sum);
3163
3164    // Handle remainder
3165    let base = chunks * 4;
3166    for i in 0..remainder {
3167        let d = a[base + i] - b[base + i];
3168        sum += d * d;
3169    }
3170
3171    sum
3172}
3173
3174#[cfg(target_arch = "x86_64")]
3175#[target_feature(enable = "avx2", enable = "fma")]
3176#[allow(unsafe_op_in_unsafe_fn)]
3177unsafe fn squared_euclidean_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
3178    use std::arch::x86_64::*;
3179
3180    let chunks = count / 8;
3181    let remainder = count % 8;
3182
3183    let mut acc = _mm256_setzero_ps();
3184
3185    for chunk in 0..chunks {
3186        let base = chunk * 8;
3187        let va = _mm256_loadu_ps(a.as_ptr().add(base));
3188        let vb = _mm256_loadu_ps(b.as_ptr().add(base));
3189        let diff = _mm256_sub_ps(va, vb);
3190        acc = _mm256_fmadd_ps(diff, diff, acc); // acc += diff * diff (FMA)
3191    }
3192
3193    // Horizontal sum of 8 floats
3194    // First, add high 128 bits to low 128 bits
3195    let high = _mm256_extractf128_ps(acc, 1);
3196    let low = _mm256_castps256_ps128(acc);
3197    let sum128 = _mm_add_ps(low, high);
3198
3199    // Now sum the 4 floats in sum128
3200    let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
3201    let sums = _mm_add_ps(sum128, shuf);
3202    let shuf2 = _mm_movehl_ps(sums, sums);
3203    let final_sum = _mm_add_ss(sums, shuf2);
3204
3205    let mut sum = _mm_cvtss_f32(final_sum);
3206
3207    // Handle remainder
3208    let base = chunks * 8;
3209    for i in 0..remainder {
3210        let d = a[base + i] - b[base + i];
3211        sum += d * d;
3212    }
3213
3214    sum
3215}
3216
3217/// Batch compute squared Euclidean distances from one query to multiple vectors
3218///
3219/// Writes `distances[i] = squared_euclidean_distance(query, &vectors[i])` for each `i`.
3220/// This is more efficient than calling `squared_euclidean_distance` in a loop because the
3221/// runtime CPU-feature dispatch runs once and the query stays hot in cache across vectors.
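///
/// # Example
///
/// An illustrative sketch (marked `ignore` since the exact re-export path is not shown here):
///
/// ```ignore
/// let query = vec![0.0f32, 0.0];
/// let vectors = vec![vec![3.0f32, 4.0], vec![1.0, 1.0]];
/// let mut distances = vec![0.0f32; vectors.len()];
/// batch_squared_euclidean_distances(&query, &vectors, &mut distances);
/// assert_eq!(distances, vec![25.0, 2.0]); // 3*3 + 4*4 = 25, 1*1 + 1*1 = 2
/// ```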
3222#[inline]
3223pub fn batch_squared_euclidean_distances(
3224    query: &[f32],
3225    vectors: &[Vec<f32>],
3226    distances: &mut [f32],
3227) {
3228    debug_assert_eq!(vectors.len(), distances.len());
3229
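    // Hoist runtime CPU-feature detection out of the per-vector loop: on x86_64 the AVX2
    // check runs once here, and every distance then calls the wide kernel directly.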
3230    #[cfg(target_arch = "x86_64")]
3231    {
3232        if avx2::is_available() {
3233            for (i, vec) in vectors.iter().enumerate() {
3234                distances[i] = unsafe { squared_euclidean_avx2(query, vec, query.len()) };
3235            }
3236            return;
3237        }
3238    }
3239
3240    // Fallback to individual calls
3241    for (i, vec) in vectors.iter().enumerate() {
3242        distances[i] = squared_euclidean_distance(query, vec);
3243    }
3244}
3245
3246#[cfg(test)]
3247mod tests {
3248    use super::*;
3249
3250    #[test]
3251    fn test_unpack_8bit() {
3252        let input: Vec<u8> = (0..128).collect();
3253        let mut output = vec![0u32; 128];
3254        unpack_8bit(&input, &mut output, 128);
3255
3256        for (i, &v) in output.iter().enumerate() {
3257            assert_eq!(v, i as u32);
3258        }
3259    }
3260
3261    #[test]
3262    fn test_unpack_16bit() {
3263        let mut input = vec![0u8; 256];
3264        for i in 0..128 {
3265            let val = (i * 100) as u16;
3266            input[i * 2] = val as u8;
3267            input[i * 2 + 1] = (val >> 8) as u8;
3268        }
3269
3270        let mut output = vec![0u32; 128];
3271        unpack_16bit(&input, &mut output, 128);
3272
3273        for (i, &v) in output.iter().enumerate() {
3274            assert_eq!(v, (i * 100) as u32);
3275        }
3276    }
3277
3278    #[test]
3279    fn test_unpack_32bit() {
3280        let mut input = vec![0u8; 512];
3281        for i in 0..128 {
3282            let val = (i * 1000) as u32;
3283            let bytes = val.to_le_bytes();
3284            input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
3285        }
3286
3287        let mut output = vec![0u32; 128];
3288        unpack_32bit(&input, &mut output, 128);
3289
3290        for (i, &v) in output.iter().enumerate() {
3291            assert_eq!(v, (i * 1000) as u32);
3292        }
3293    }
3294
3295    #[test]
3296    fn test_delta_decode() {
3297        // doc_ids: [10, 15, 20, 30, 50]
3298        // gaps: [5, 5, 10, 20]
3299        // deltas (gap-1): [4, 4, 9, 19]
3300        let deltas = vec![4u32, 4, 9, 19];
3301        let mut output = vec![0u32; 5];
3302
3303        delta_decode(&mut output, &deltas, 10, 5);
3304
3305        assert_eq!(output, vec![10, 15, 20, 30, 50]);
3306    }
3307
3308    #[test]
3309    fn test_add_one() {
3310        let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
3311        add_one(&mut values, 8);
3312
3313        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
3314    }
3315
3316    #[test]
3317    fn test_bits_needed() {
3318        assert_eq!(bits_needed(0), 0);
3319        assert_eq!(bits_needed(1), 1);
3320        assert_eq!(bits_needed(2), 2);
3321        assert_eq!(bits_needed(3), 2);
3322        assert_eq!(bits_needed(4), 3);
3323        assert_eq!(bits_needed(255), 8);
3324        assert_eq!(bits_needed(256), 9);
3325        assert_eq!(bits_needed(u32::MAX), 32);
3326    }
3327
3328    #[test]
3329    fn test_unpack_8bit_delta_decode() {
3330        // doc_ids: [10, 15, 20, 30, 50]
3331        // gaps: [5, 5, 10, 20]
3332        // deltas (gap-1): [4, 4, 9, 19] stored as u8
3333        let input: Vec<u8> = vec![4, 4, 9, 19];
3334        let mut output = vec![0u32; 5];
3335
3336        unpack_8bit_delta_decode(&input, &mut output, 10, 5);
3337
3338        assert_eq!(output, vec![10, 15, 20, 30, 50]);
3339    }
3340
3341    #[test]
3342    fn test_unpack_16bit_delta_decode() {
3343        // doc_ids: [100, 600, 1100, 2100, 4100]
3344        // gaps: [500, 500, 1000, 2000]
3345        // deltas (gap-1): [499, 499, 999, 1999] stored as u16
3346        let mut input = vec![0u8; 8];
3347        for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
3348            input[i * 2] = delta as u8;
3349            input[i * 2 + 1] = (delta >> 8) as u8;
3350        }
3351        let mut output = vec![0u32; 5];
3352
3353        unpack_16bit_delta_decode(&input, &mut output, 100, 5);
3354
3355        assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
3356    }
3357
3358    #[test]
3359    fn test_fused_vs_separate_8bit() {
3360        // Test that fused and separate operations produce the same result
3361        let input: Vec<u8> = (0..127).collect();
3362        let first_value = 1000u32;
3363        let count = 128;
3364
3365        // Separate: unpack then delta_decode
3366        let mut unpacked = vec![0u32; 128];
3367        unpack_8bit(&input, &mut unpacked, 127);
3368        let mut separate_output = vec![0u32; 128];
3369        delta_decode(&mut separate_output, &unpacked, first_value, count);
3370
3371        // Fused
3372        let mut fused_output = vec![0u32; 128];
3373        unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
3374
3375        assert_eq!(separate_output, fused_output);
3376    }
3377
3378    #[test]
3379    fn test_round_bit_width() {
3380        assert_eq!(round_bit_width(0), 0);
3381        assert_eq!(round_bit_width(1), 8);
3382        assert_eq!(round_bit_width(5), 8);
3383        assert_eq!(round_bit_width(8), 8);
3384        assert_eq!(round_bit_width(9), 16);
3385        assert_eq!(round_bit_width(12), 16);
3386        assert_eq!(round_bit_width(16), 16);
3387        assert_eq!(round_bit_width(17), 32);
3388        assert_eq!(round_bit_width(24), 32);
3389        assert_eq!(round_bit_width(32), 32);
3390    }
3391
3392    #[test]
3393    fn test_rounded_bitwidth_from_exact() {
3394        assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
3395        assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
3396        assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
3397        assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
3398        assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
3399        assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
3400        assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
3401    }
3402
3403    #[test]
3404    fn test_pack_unpack_rounded_8bit() {
3405        let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
3406        let mut packed = vec![0u8; 128];
3407
3408        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
3409        assert_eq!(bytes_written, 128);
3410
3411        let mut unpacked = vec![0u32; 128];
3412        unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);
3413
3414        assert_eq!(values, unpacked);
3415    }
3416
3417    #[test]
3418    fn test_pack_unpack_rounded_16bit() {
3419        let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
3420        let mut packed = vec![0u8; 256];
3421
3422        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
3423        assert_eq!(bytes_written, 256);
3424
3425        let mut unpacked = vec![0u32; 128];
3426        unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);
3427
3428        assert_eq!(values, unpacked);
3429    }
3430
3431    #[test]
3432    fn test_pack_unpack_rounded_32bit() {
3433        let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
3434        let mut packed = vec![0u8; 512];
3435
3436        let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
3437        assert_eq!(bytes_written, 512);
3438
3439        let mut unpacked = vec![0u32; 128];
3440        unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);
3441
3442        assert_eq!(values, unpacked);
3443    }
3444
3445    #[test]
3446    fn test_unpack_rounded_delta_decode() {
3447        // Test 8-bit rounded delta decode
3448        // doc_ids: [10, 15, 20, 30, 50]
3449        // gaps: [5, 5, 10, 20]
3450        // deltas (gap-1): [4, 4, 9, 19] stored as u8
3451        let input: Vec<u8> = vec![4, 4, 9, 19];
3452        let mut output = vec![0u32; 5];
3453
3454        unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);
3455
3456        assert_eq!(output, vec![10, 15, 20, 30, 50]);
3457    }
3458
3459    #[test]
3460    fn test_unpack_rounded_delta_decode_zero() {
3461        // All zeros means gaps of 1 (consecutive doc IDs)
3462        let input: Vec<u8> = vec![];
3463        let mut output = vec![0u32; 5];
3464
3465        unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);
3466
3467        assert_eq!(output, vec![100, 101, 102, 103, 104]);
3468    }
3469
3470    // ========================================================================
3471    // Sparse Vector SIMD Tests
3472    // ========================================================================
3473
3474    #[test]
3475    fn test_dequantize_uint8() {
3476        let input: Vec<u8> = vec![0, 128, 255, 64, 192];
3477        let mut output = vec![0.0f32; 5];
3478        let scale = 0.1;
3479        let min_val = 1.0;
3480
3481        dequantize_uint8(&input, &mut output, scale, min_val, 5);
3482
3483        // Expected: input[i] * scale + min_val
3484        assert!((output[0] - 1.0).abs() < 1e-6); // 0 * 0.1 + 1.0 = 1.0
3485        assert!((output[1] - 13.8).abs() < 1e-6); // 128 * 0.1 + 1.0 = 13.8
3486        assert!((output[2] - 26.5).abs() < 1e-6); // 255 * 0.1 + 1.0 = 26.5
3487        assert!((output[3] - 7.4).abs() < 1e-6); // 64 * 0.1 + 1.0 = 7.4
3488        assert!((output[4] - 20.2).abs() < 1e-6); // 192 * 0.1 + 1.0 = 20.2
3489    }
3490
3491    #[test]
3492    fn test_dequantize_uint8_large() {
3493        // Test with 128 values (full SIMD block)
3494        let input: Vec<u8> = (0..128).collect();
3495        let mut output = vec![0.0f32; 128];
3496        let scale = 2.0;
3497        let min_val = -10.0;
3498
3499        dequantize_uint8(&input, &mut output, scale, min_val, 128);
3500
3501        for (i, &out) in output.iter().enumerate().take(128) {
3502            let expected = i as f32 * scale + min_val;
3503            assert!(
3504                (out - expected).abs() < 1e-5,
3505                "Mismatch at {}: expected {}, got {}",
3506                i,
3507                expected,
3508                out
3509            );
3510        }
3511    }
3512
3513    #[test]
3514    fn test_dot_product_f32() {
3515        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
3516        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
3517
3518        let result = dot_product_f32(&a, &b, 5);
3519
3520        // Expected: 1*2 + 2*3 + 3*4 + 4*5 + 5*6 = 2 + 6 + 12 + 20 + 30 = 70
3521        assert!((result - 70.0).abs() < 1e-5);
3522    }
3523
3524    #[test]
3525    fn test_dot_product_f32_large() {
3526        // Test with 128 values
3527        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
3528        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
3529
3530        let result = dot_product_f32(&a, &b, 128);
3531
3532        // Compute expected
3533        let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
3534        assert!(
3535            (result - expected).abs() < 1e-3,
3536            "Expected {}, got {}",
3537            expected,
3538            result
3539        );
3540    }
3541
3542    #[test]
3543    fn test_max_f32() {
3544        let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
3545        let result = max_f32(&values, 6);
3546        assert!((result - 9.0).abs() < 1e-6);
3547    }
3548
3549    #[test]
3550    fn test_max_f32_large() {
3551        // Test with 128 values, max at position 77
3552        let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
3553        values[77] = 1000.0;
3554
3555        let result = max_f32(&values, 128);
3556        assert!((result - 1000.0).abs() < 1e-5);
3557    }
3558
3559    #[test]
3560    fn test_max_f32_negative() {
3561        let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
3562        let result = max_f32(&values, 5);
3563        assert!((result - (-1.0)).abs() < 1e-6);
3564    }
3565
3566    #[test]
3567    fn test_max_f32_empty() {
3568        let values: Vec<f32> = vec![];
3569        let result = max_f32(&values, 0);
3570        assert_eq!(result, f32::NEG_INFINITY);
3571    }
3572
3573    #[test]
3574    fn test_fused_dot_norm() {
3575        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3576        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3577        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3578
3579        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3580        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3581        assert!(
3582            (dot - expected_dot).abs() < 1e-5,
3583            "dot: expected {}, got {}",
3584            expected_dot,
3585            dot
3586        );
3587        assert!(
3588            (norm_b - expected_norm).abs() < 1e-5,
3589            "norm: expected {}, got {}",
3590            expected_norm,
3591            norm_b
3592        );
3593    }
3594
3595    #[test]
3596    fn test_fused_dot_norm_large() {
3597        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
3598        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect();
3599        let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
3600
3601        let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
3602        let expected_norm: f32 = b.iter().map(|x| x * x).sum();
3603        assert!(
3604            (dot - expected_dot).abs() < 1.0,
3605            "dot: expected {}, got {}",
3606            expected_dot,
3607            dot
3608        );
3609        assert!(
3610            (norm_b - expected_norm).abs() < 1.0,
3611            "norm: expected {}, got {}",
3612            expected_norm,
3613            norm_b
3614        );
3615    }
3616
3617    #[test]
3618    fn test_batch_cosine_scores() {
3619        // 4 vectors of dim 3
3620        let query = vec![1.0f32, 0.0, 0.0];
3621        let vectors = vec![
3622            1.0, 0.0, 0.0, // identical to query
3623            0.0, 1.0, 0.0, // orthogonal
3624            -1.0, 0.0, 0.0, // opposite
3625            0.5, 0.5, 0.0, // 45 degrees
3626        ];
3627        let mut scores = vec![0f32; 4];
3628        batch_cosine_scores(&query, &vectors, 3, &mut scores);
3629
3630        assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
3631        assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
3632        assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
3633        let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
3634        assert!(
3635            (scores[3] - expected_45).abs() < 1e-5,
3636            "45deg: expected {}, got {}",
3637            expected_45,
3638            scores[3]
3639        );
3640    }
3641
3642    #[test]
3643    fn test_batch_cosine_scores_matches_individual() {
3644        let query: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
3645        let n = 50;
3646        let dim = 128;
3647        let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 7 + 3) as f32) * 0.01).collect();
3648
3649        let mut batch_scores = vec![0f32; n];
3650        batch_cosine_scores(&query, &vectors, dim, &mut batch_scores);
3651
3652        for i in 0..n {
3653            let vec_i = &vectors[i * dim..(i + 1) * dim];
3654            let individual = cosine_similarity(&query, vec_i);
3655            assert!(
3656                (batch_scores[i] - individual).abs() < 1e-5,
3657                "vec {}: batch={}, individual={}",
3658                i,
3659                batch_scores[i],
3660                individual
3661            );
3662        }
3663    }
3664
3665    #[test]
3666    fn test_batch_cosine_scores_empty() {
3667        let query = vec![1.0f32, 2.0, 3.0];
3668        let vectors: Vec<f32> = vec![];
3669        let mut scores: Vec<f32> = vec![];
3670        batch_cosine_scores(&query, &vectors, 3, &mut scores);
3671        assert!(scores.is_empty());
3672    }
3673
3674    #[test]
3675    fn test_batch_cosine_scores_zero_query() {
3676        let query = vec![0.0f32, 0.0, 0.0];
3677        let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
3678        let mut scores = vec![0f32; 2];
3679        batch_cosine_scores(&query, &vectors, 3, &mut scores);
3680        assert_eq!(scores[0], 0.0);
3681        assert_eq!(scores[1], 0.0);
3682    }
3683
3684    #[test]
3685    fn test_squared_euclidean_distance() {
3686        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
3687        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
3688        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3689        let result = squared_euclidean_distance(&a, &b);
3690        assert!(
3691            (result - expected).abs() < 1e-5,
3692            "expected {}, got {}",
3693            expected,
3694            result
3695        );
3696    }
3697
3698    #[test]
3699    fn test_squared_euclidean_distance_large() {
3700        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
3701        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
3702        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
3703        let result = squared_euclidean_distance(&a, &b);
3704        assert!(
3705            (result - expected).abs() < 1e-3,
3706            "expected {}, got {}",
3707            expected,
3708            result
3709        );
3710    }
3711
3712    // ================================================================
3713    // f16 conversion tests
3714    // ================================================================
3715
3716    #[test]
3717    fn test_f16_roundtrip_normal() {
3718        for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 0.333, 65504.0] {
3719            let h = f32_to_f16(v);
3720            let back = f16_to_f32(h);
3721            let err = (back - v).abs() / v.abs().max(1e-6);
3722            assert!(
3723                err < 0.002,
3724                "f16 roundtrip {v} → {h:#06x} → {back}, rel err {err}"
3725            );
3726        }
3727    }
3728
3729    #[test]
3730    fn test_f16_special() {
3731        // Zero
3732        assert_eq!(f16_to_f32(f32_to_f16(0.0)), 0.0);
3733        // Negative zero
3734        assert_eq!(f32_to_f16(-0.0), 0x8000);
3735        // Infinity
3736        assert!(f16_to_f32(f32_to_f16(f32::INFINITY)).is_infinite());
3737        // NaN
3738        assert!(f16_to_f32(f32_to_f16(f32::NAN)).is_nan());
3739    }
3740
3741    #[test]
3742    fn test_f16_embedding_range() {
3743        // Typical embedding values in [-1, 1]
3744        let values: Vec<f32> = (-100..=100).map(|i| i as f32 / 100.0).collect();
3745        for &v in &values {
3746            let back = f16_to_f32(f32_to_f16(v));
3747            assert!((back - v).abs() < 0.001, "f16 error for {v}: got {back}");
3748        }
3749    }
3750
3751    // ================================================================
3752    // u8 conversion tests
3753    // ================================================================
3754
3755    #[test]
3756    fn test_u8_roundtrip() {
3757        // Boundary values
3758        assert_eq!(f32_to_u8_saturating(-1.0), 0);
3759        assert_eq!(f32_to_u8_saturating(1.0), 255);
3760        assert_eq!(f32_to_u8_saturating(0.0), 127); // ~127.5 truncated
3761
3762        // Saturation
3763        assert_eq!(f32_to_u8_saturating(-2.0), 0);
3764        assert_eq!(f32_to_u8_saturating(2.0), 255);
3765    }
3766
3767    #[test]
3768    fn test_u8_dequantize() {
3769        assert!((u8_to_f32(0) - (-1.0)).abs() < 0.01);
3770        assert!((u8_to_f32(255) - 1.0).abs() < 0.01);
3771        assert!((u8_to_f32(127) - 0.0).abs() < 0.01);
3772    }
3773
3774    // ================================================================
3775    // Batch scoring tests for quantized vectors
3776    // ================================================================
3777
3778    #[test]
3779    fn test_batch_cosine_scores_f16() {
3780        let query = vec![0.6f32, 0.8, 0.0, 0.0];
3781        let dim = 4;
3782        let vecs_f32 = vec![
3783            0.6f32, 0.8, 0.0, 0.0, // identical to query
3784            0.0, 0.0, 0.6, 0.8, // orthogonal
3785        ];
3786
3787        // Quantize to f16
3788        let mut f16_buf = vec![0u16; 8];
3789        batch_f32_to_f16(&vecs_f32, &mut f16_buf);
3790        let raw: &[u8] =
3791            unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
3792
3793        let mut scores = vec![0f32; 2];
3794        batch_cosine_scores_f16(&query, raw, dim, &mut scores);
3795
3796        assert!(
3797            (scores[0] - 1.0).abs() < 0.01,
3798            "identical vectors: {}",
3799            scores[0]
3800        );
3801        assert!(scores[1].abs() < 0.01, "orthogonal vectors: {}", scores[1]);
3802    }
3803
3804    #[test]
3805    fn test_batch_cosine_scores_u8() {
3806        let query = vec![0.6f32, 0.8, 0.0, 0.0];
3807        let dim = 4;
3808        let vecs_f32 = vec![
3809            0.6f32, 0.8, 0.0, 0.0, // ~identical to query
3810            -0.6, -0.8, 0.0, 0.0, // opposite
3811        ];
3812
3813        // Quantize to u8
3814        let mut u8_buf = vec![0u8; 8];
3815        batch_f32_to_u8(&vecs_f32, &mut u8_buf);
3816
3817        let mut scores = vec![0f32; 2];
3818        batch_cosine_scores_u8(&query, &u8_buf, dim, &mut scores);
3819
3820        assert!(scores[0] > 0.95, "similar vectors: {}", scores[0]);
3821        assert!(scores[1] < -0.95, "opposite vectors: {}", scores[1]);
3822    }
3823
3824    #[test]
3825    fn test_batch_cosine_scores_f16_large_dim() {
3826        // Test with typical embedding dimension
3827        let dim = 768;
3828        let query: Vec<f32> = (0..dim).map(|i| (i as f32 / dim as f32) - 0.5).collect();
3829        let vec2: Vec<f32> = query.iter().map(|x| x * 0.9 + 0.01).collect();
3830
3831        let mut all_vecs = query.clone();
3832        all_vecs.extend_from_slice(&vec2);
3833
3834        let mut f16_buf = vec![0u16; all_vecs.len()];
3835        batch_f32_to_f16(&all_vecs, &mut f16_buf);
3836        let raw: &[u8] =
3837            unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
3838
3839        let mut scores = vec![0f32; 2];
3840        batch_cosine_scores_f16(&query, raw, dim, &mut scores);
3841
3842        // Self-similarity should be ~1.0
3843        assert!((scores[0] - 1.0).abs() < 0.01, "self-sim: {}", scores[0]);
3844        // High similarity with scaled version
3845        assert!(scores[1] > 0.99, "scaled-sim: {}", scores[1]);
3846    }
3847}
3848
3849// ============================================================================
3850// SIMD-accelerated linear scan for sorted u32 slices (within-block seek)
3851// ============================================================================
3852
3853/// Find index of first element >= `target` in a sorted `u32` slice.
3854///
3855/// Equivalent to `slice.partition_point(|&d| d < target)`, but uses SIMD to
3856/// compare 4 elements per vector instruction. Typically faster than binary
3857/// search for slices of ≤ 256 elements because it avoids binary search's
3858/// serial data-dependency chain (~8-10 cycles/probe vs ~1-2 cycles/compare).
3859///
3860/// Returns `slice.len()` if no element >= `target`.
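///
/// # Example
///
/// An illustrative sketch (marked `ignore` since the exact re-export path is not shown here):
///
/// ```ignore
/// let block = [2u32, 4, 6, 8];
/// assert_eq!(find_first_ge_u32(&block, 5), 2); // first element >= 5 is 6, at index 2
/// assert_eq!(find_first_ge_u32(&block, 9), 4); // no element >= 9, so slice.len()
/// ```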
3861#[inline]
3862pub fn find_first_ge_u32(slice: &[u32], target: u32) -> usize {
3863    #[cfg(target_arch = "aarch64")]
3864    {
3865        if neon::is_available() {
3866            return unsafe { find_first_ge_u32_neon(slice, target) };
3867        }
3868    }
3869
3870    #[cfg(target_arch = "x86_64")]
3871    {
3872        if sse::is_available() {
3873            return unsafe { find_first_ge_u32_sse(slice, target) };
3874        }
3875    }
3876
3877    // Scalar fallback (WASM, other architectures)
3878    slice.partition_point(|&d| d < target)
3879}
3880
3881#[cfg(target_arch = "aarch64")]
3882#[target_feature(enable = "neon")]
3883#[allow(unsafe_op_in_unsafe_fn)]
3884unsafe fn find_first_ge_u32_neon(slice: &[u32], target: u32) -> usize {
3885    use std::arch::aarch64::*;
3886
3887    let n = slice.len();
3888    let ptr = slice.as_ptr();
3889    let target_vec = vdupq_n_u32(target);
3890    // Bit positions for each lane: [1, 2, 4, 8]
3891    let bit_mask: uint32x4_t = core::mem::transmute([1u32, 2u32, 4u32, 8u32]);
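    // Each `vcgeq_u32` lane is all-ones (0xFFFF_FFFF) when `value >= target`, so ANDing with
    // [1, 2, 4, 8] and horizontally adding (`vaddvq_u32`) yields a 4-bit mask whose lowest set
    // bit (`trailing_zeros`) is the first matching lane, emulating x86's movemask on NEON.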
3892
3893    let chunks = n / 16;
3894    let mut base = 0usize;
3895
3896    // Process 16 elements per iteration (4 × 4-wide NEON compares)
3897    for _ in 0..chunks {
3898        let v0 = vld1q_u32(ptr.add(base));
3899        let v1 = vld1q_u32(ptr.add(base + 4));
3900        let v2 = vld1q_u32(ptr.add(base + 8));
3901        let v3 = vld1q_u32(ptr.add(base + 12));
3902
3903        let c0 = vcgeq_u32(v0, target_vec);
3904        let c1 = vcgeq_u32(v1, target_vec);
3905        let c2 = vcgeq_u32(v2, target_vec);
3906        let c3 = vcgeq_u32(v3, target_vec);
3907
3908        let m0 = vaddvq_u32(vandq_u32(c0, bit_mask));
3909        if m0 != 0 {
3910            return base + m0.trailing_zeros() as usize;
3911        }
3912        let m1 = vaddvq_u32(vandq_u32(c1, bit_mask));
3913        if m1 != 0 {
3914            return base + 4 + m1.trailing_zeros() as usize;
3915        }
3916        let m2 = vaddvq_u32(vandq_u32(c2, bit_mask));
3917        if m2 != 0 {
3918            return base + 8 + m2.trailing_zeros() as usize;
3919        }
3920        let m3 = vaddvq_u32(vandq_u32(c3, bit_mask));
3921        if m3 != 0 {
3922            return base + 12 + m3.trailing_zeros() as usize;
3923        }
3924        base += 16;
3925    }
3926
3927    // Process remaining 4 elements at a time
3928    while base + 4 <= n {
3929        let vals = vld1q_u32(ptr.add(base));
3930        let cmp = vcgeq_u32(vals, target_vec);
3931        let mask = vaddvq_u32(vandq_u32(cmp, bit_mask));
3932        if mask != 0 {
3933            return base + mask.trailing_zeros() as usize;
3934        }
3935        base += 4;
3936    }
3937
3938    // Scalar remainder (0-3 elements)
3939    while base < n {
3940        if *slice.get_unchecked(base) >= target {
3941            return base;
3942        }
3943        base += 1;
3944    }
3945    n
3946}
3947
3948#[cfg(target_arch = "x86_64")]
3949#[target_feature(enable = "sse2", enable = "sse4.1")]
3950#[allow(unsafe_op_in_unsafe_fn)]
3951unsafe fn find_first_ge_u32_sse(slice: &[u32], target: u32) -> usize {
3952    use std::arch::x86_64::*;
3953
3954    let n = slice.len();
3955    let ptr = slice.as_ptr();
3956
3957    // For unsigned >= comparison: XOR with 0x80000000 converts to signed domain
3958    let sign_flip = _mm_set1_epi32(i32::MIN);
3959    let target_xor = _mm_xor_si128(_mm_set1_epi32(target as i32), sign_flip);
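    // Flipping the sign bit maps unsigned order onto signed order (0 -> i32::MIN,
    // u32::MAX -> i32::MAX), so SSE2's signed `cmpeq`/`cmpgt` implement unsigned `>=` below;
    // `_mm_movemask_ps` then packs the per-lane results into a 4-bit mask for `trailing_zeros`.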
3960
3961    let chunks = n / 16;
3962    let mut base = 0usize;
3963
3964    // Process 16 elements per iteration (4 × 4-wide SSE compares)
3965    for _ in 0..chunks {
3966        let v0 = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
3967        let v1 = _mm_xor_si128(
3968            _mm_loadu_si128(ptr.add(base + 4) as *const __m128i),
3969            sign_flip,
3970        );
3971        let v2 = _mm_xor_si128(
3972            _mm_loadu_si128(ptr.add(base + 8) as *const __m128i),
3973            sign_flip,
3974        );
3975        let v3 = _mm_xor_si128(
3976            _mm_loadu_si128(ptr.add(base + 12) as *const __m128i),
3977            sign_flip,
3978        );
3979
3980        // ge = eq | gt (in signed domain after XOR)
3981        let ge0 = _mm_or_si128(
3982            _mm_cmpeq_epi32(v0, target_xor),
3983            _mm_cmpgt_epi32(v0, target_xor),
3984        );
3985        let m0 = _mm_movemask_ps(_mm_castsi128_ps(ge0)) as u32;
3986        if m0 != 0 {
3987            return base + m0.trailing_zeros() as usize;
3988        }
3989
3990        let ge1 = _mm_or_si128(
3991            _mm_cmpeq_epi32(v1, target_xor),
3992            _mm_cmpgt_epi32(v1, target_xor),
3993        );
3994        let m1 = _mm_movemask_ps(_mm_castsi128_ps(ge1)) as u32;
3995        if m1 != 0 {
3996            return base + 4 + m1.trailing_zeros() as usize;
3997        }
3998
3999        let ge2 = _mm_or_si128(
4000            _mm_cmpeq_epi32(v2, target_xor),
4001            _mm_cmpgt_epi32(v2, target_xor),
4002        );
4003        let m2 = _mm_movemask_ps(_mm_castsi128_ps(ge2)) as u32;
4004        if m2 != 0 {
4005            return base + 8 + m2.trailing_zeros() as usize;
4006        }
4007
4008        let ge3 = _mm_or_si128(
4009            _mm_cmpeq_epi32(v3, target_xor),
4010            _mm_cmpgt_epi32(v3, target_xor),
4011        );
4012        let m3 = _mm_movemask_ps(_mm_castsi128_ps(ge3)) as u32;
4013        if m3 != 0 {
4014            return base + 12 + m3.trailing_zeros() as usize;
4015        }
4016        base += 16;
4017    }
4018
4019    // Process remaining 4 elements at a time
4020    while base + 4 <= n {
4021        let vals = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
4022        let ge = _mm_or_si128(
4023            _mm_cmpeq_epi32(vals, target_xor),
4024            _mm_cmpgt_epi32(vals, target_xor),
4025        );
4026        let mask = _mm_movemask_ps(_mm_castsi128_ps(ge)) as u32;
4027        if mask != 0 {
4028            return base + mask.trailing_zeros() as usize;
4029        }
4030        base += 4;
4031    }
4032
4033    // Scalar remainder (0-3 elements)
4034    while base < n {
4035        if *slice.get_unchecked(base) >= target {
4036            return base;
4037        }
4038        base += 1;
4039    }
4040    n
4041}
4042
4043#[cfg(test)]
4044mod find_first_ge_tests {
4045    use super::find_first_ge_u32;
4046
4047    #[test]
4048    fn test_find_first_ge_basic() {
4049        let data: Vec<u32> = (0..128).map(|i| i * 3).collect(); // [0, 3, 6, ..., 381]
4050        assert_eq!(find_first_ge_u32(&data, 0), 0);
4051        assert_eq!(find_first_ge_u32(&data, 1), 1); // first >= 1 is 3 at idx 1
4052        assert_eq!(find_first_ge_u32(&data, 3), 1);
4053        assert_eq!(find_first_ge_u32(&data, 4), 2); // first >= 4 is 6 at idx 2
4054        assert_eq!(find_first_ge_u32(&data, 381), 127);
4055        assert_eq!(find_first_ge_u32(&data, 382), 128); // past end
4056    }
4057
4058    #[test]
4059    fn test_find_first_ge_matches_partition_point() {
4060        let data: Vec<u32> = vec![1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75];
4061        for target in 0..80 {
4062            let expected = data.partition_point(|&d| d < target);
4063            let actual = find_first_ge_u32(&data, target);
4064            assert_eq!(actual, expected, "target={}", target);
4065        }
4066    }
4067
4068    #[test]
4069    fn test_find_first_ge_small_slices() {
4070        // Empty
4071        assert_eq!(find_first_ge_u32(&[], 5), 0);
4072        // Single element
4073        assert_eq!(find_first_ge_u32(&[10], 5), 0);
4074        assert_eq!(find_first_ge_u32(&[10], 10), 0);
4075        assert_eq!(find_first_ge_u32(&[10], 11), 1);
4076        // Three elements (< SIMD width)
4077        assert_eq!(find_first_ge_u32(&[2, 4, 6], 5), 2);
4078    }
4079
4080    #[test]
4081    fn test_find_first_ge_full_block() {
4082        // Simulate a full 128-entry block
4083        let data: Vec<u32> = (100..228).collect();
4084        assert_eq!(find_first_ge_u32(&data, 100), 0);
4085        assert_eq!(find_first_ge_u32(&data, 150), 50);
4086        assert_eq!(find_first_ge_u32(&data, 227), 127);
4087        assert_eq!(find_first_ge_u32(&data, 228), 128);
4088        assert_eq!(find_first_ge_u32(&data, 99), 0);
4089    }
4090
4091    #[test]
4092    fn test_find_first_ge_u32_max() {
4093        // Test with large u32 values (unsigned correctness)
4094        let data = vec![u32::MAX - 10, u32::MAX - 5, u32::MAX - 1, u32::MAX];
4095        assert_eq!(find_first_ge_u32(&data, u32::MAX - 10), 0);
4096        assert_eq!(find_first_ge_u32(&data, u32::MAX - 7), 1);
4097        assert_eq!(find_first_ge_u32(&data, u32::MAX), 3);
4098    }
4099}