Skip to main content

reddb_server/storage/engine/turboquant/
scoring.rs

1//! TurboQuant per-block scoring kernels.
2//!
//! Slice S1 of PRD #688 / ADR 0024 introduced the reference scalar
//! kernel ([`ScalarScorer`]); later slices added the SIMD kernels
//! (NEON, AVX2, AVX-512BW), which must match it bit-exactly. Every
//! kernel joins [`select_scorer`] without changing the dispatch surface,
//! and the kernel is chosen once per query rather than once per block.
7//!
8//! MIT notice: LUT construction shape and the PERM0-aware decode loop
9//! are derived from RyanCodrai/turbovec (commit
10//! `4a4f2cd2db233f24405911b1ceaf1823fa23b4ac`, MIT); the RedDB
11//! `PerBlockScorer` trait, dispatch, and scalar oracle are
12//! clean-room.
13
14use super::storage::{BLOCK_LANES, PERM0};
15
/// Largest quantized LUT entry: [`QueryLut::build`] scales and clamps
/// every table byte into `0..=MAX_LUT`.
const MAX_LUT: f32 = 127_f32;
17
/// Identifies which [`PerBlockScorer`] implementation is in use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScoreKernel {
    /// Portable reference kernel ([`ScalarScorer`]), the bit-exact oracle.
    Scalar,
    /// x86 AVX2 + FMA kernel.
    Avx2,
    /// x86 AVX-512BW/F + FMA kernel.
    Avx512Bw,
    /// AArch64 NEON kernel.
    Neon,
}
25
26/// Per-query lookup table. Built once per `score_many` call and shared
27/// across every block of a collection.
28#[derive(Debug, Clone)]
29pub struct QueryLut {
30    pub bytes: Vec<u8>,
31    pub n_byte_groups: usize,
32    pub scale: f32,
33    pub bias: f32,
34}
35
36impl QueryLut {
37    pub fn build(query_terms: &[f32], centroids: &[f64]) -> Self {
38        let n_byte_groups = query_terms.len().div_ceil(2);
39        let mut float_luts = vec![0.0f32; n_byte_groups * 32];
40        let mut max_span = 0.0f32;
41        let mut bias = 0.0f32;
42
43        for group in 0..n_byte_groups {
44            let lo_term = query_terms[group * 2];
45            let hi_term = query_terms.get(group * 2 + 1).copied().unwrap_or(0.0);
46            let group_base = group * 32;
47
48            for code in 0..16 {
49                float_luts[group_base + code] = hi_term * centroids[code] as f32;
50                float_luts[group_base + 16 + code] = lo_term * centroids[code] as f32;
51            }
52
53            let hi = &float_luts[group_base..group_base + 16];
54            let lo = &float_luts[group_base + 16..group_base + 32];
55            let hi_min = hi.iter().copied().fold(f32::INFINITY, f32::min);
56            let hi_max = hi.iter().copied().fold(f32::NEG_INFINITY, f32::max);
57            let lo_min = lo.iter().copied().fold(f32::INFINITY, f32::min);
58            let lo_max = lo.iter().copied().fold(f32::NEG_INFINITY, f32::max);
59            bias += hi_min + lo_min;
60            max_span = max_span.max((hi_max - hi_min).max(lo_max - lo_min));
61        }
62
63        let scale = if max_span > 1e-10 {
64            max_span / MAX_LUT
65        } else {
66            1.0
67        };
68        let inv_scale = 1.0 / scale;
69        let mut bytes = vec![0u8; n_byte_groups * 32];
70
71        for group in 0..n_byte_groups {
72            let group_base = group * 32;
73            let hi_min = float_luts[group_base..group_base + 16]
74                .iter()
75                .copied()
76                .fold(f32::INFINITY, f32::min);
77            let lo_min = float_luts[group_base + 16..group_base + 32]
78                .iter()
79                .copied()
80                .fold(f32::INFINITY, f32::min);
81            for code in 0..16 {
82                bytes[group_base + code] = ((float_luts[group_base + code] - hi_min) * inv_scale)
83                    .round()
84                    .clamp(0.0, MAX_LUT) as u8;
85                bytes[group_base + 16 + code] = ((float_luts[group_base + 16 + code] - lo_min)
86                    * inv_scale)
87                    .round()
88                    .clamp(0.0, MAX_LUT) as u8;
89            }
90        }
91
92        Self {
93            bytes,
94            n_byte_groups,
95            scale,
96            bias,
97        }
98    }
99}
100
101/// Single block scoring trait. Each implementation consumes one
102/// `block_codes` slice (PERM0-interleaved, 64-byte aligned by
103/// [`super::storage::BlockedCodeStorage`]) and writes one score per
104/// lane into `out`.
105///
106/// SIMD slices add impls (NEON, AVX2, AVX-512BW) and join
107/// [`select_scorer`] without changing this trait.
108pub trait PerBlockScorer: Sync + Send {
109    fn kernel(&self) -> ScoreKernel;
110
111    /// Compute `lut.scale * sum_g(lut[g][hi] + lut[g][lo]) + lut.bias`
112    /// for each lane in `0..n_vectors`. Lanes `>= n_vectors` are filled
113    /// with `0.0`. The output is the unit-rotated-query dot product —
114    /// outer code applies metric-specific transforms and per-vector
115    /// scales.
116    fn score_block(
117        &self,
118        lut: &QueryLut,
119        block_codes: &[u8],
120        n_byte_groups: usize,
121        n_vectors: usize,
122        out: &mut [f32; BLOCK_LANES],
123    );
124}
125
126/// Reference scalar implementation. Acts as the oracle the
127/// equivalence-test harness keys off — every SIMD slice must match
128/// this kernel bit-exactly.
129pub struct ScalarScorer;
130
131impl PerBlockScorer for ScalarScorer {
132    fn kernel(&self) -> ScoreKernel {
133        ScoreKernel::Scalar
134    }
135
136    fn score_block(
137        &self,
138        lut: &QueryLut,
139        block_codes: &[u8],
140        n_byte_groups: usize,
141        n_vectors: usize,
142        out: &mut [f32; BLOCK_LANES],
143    ) {
144        debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
145        debug_assert!(block_codes.len() >= n_byte_groups * BLOCK_LANES);
146        debug_assert!(n_vectors <= BLOCK_LANES);
147
148        for (lane, slot) in out.iter_mut().enumerate() {
149            if lane >= n_vectors {
150                *slot = 0.0;
151                continue;
152            }
153            let mut acc = 0u32;
154            for g in 0..n_byte_groups {
155                let (hi, lo) = decode_perm0_byte(block_codes, g, lane);
156                acc = acc.wrapping_add(lut.bytes[g * 32 + hi as usize] as u32);
157                acc = acc.wrapping_add(lut.bytes[g * 32 + 16 + lo as usize] as u32);
158            }
159            *slot = lut.scale.mul_add(acc as f32, lut.bias);
160        }
161    }
162}
163
164static SCALAR_SCORER: ScalarScorer = ScalarScorer;
165
166#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
167static AVX2_SCORER: Avx2Scorer = Avx2Scorer;
168
169#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
170static AVX512BW_SCORER: Avx512BwScorer = Avx512BwScorer;
171
172#[cfg(target_arch = "aarch64")]
173static NEON_SCORER: NeonScorer = NeonScorer;
174
175/// Pick the best available scoring kernel for this host. SIMD slices
176/// register themselves here without touching the call site; the choice
177/// is made once per query (no per-block branch cost).
178pub fn select_scorer() -> &'static dyn PerBlockScorer {
179    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
180    {
181        // AVX-512BW is the widest x86 kernel and takes precedence over
182        // AVX2 when available. FMA is required throughout to match
183        // `ScalarScorer`'s `f32::mul_add` bit-exactly; the SIMD paths
184        // use `vfmadd` rather than separate mul+add to stay byte-
185        // identical. AVX-512F implies AVX2, and FMA3 ships on every
186        // relevant Intel/AMD AVX2+ core; rare hosts that drop FMA fall
187        // back to scalar.
188        if std::is_x86_feature_detected!("avx512bw")
189            && std::is_x86_feature_detected!("avx512f")
190            && std::is_x86_feature_detected!("fma")
191        {
192            return &AVX512BW_SCORER;
193        }
194        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
195            return &AVX2_SCORER;
196        }
197    }
198    #[cfg(target_arch = "aarch64")]
199    {
200        // NEON (Advanced SIMD) is mandatory in the AArch64 base ISA, so
201        // no runtime detection is needed — every aarch64 host supports
202        // the intrinsics this kernel uses. `vfmaq_f32` is the
203        // single-rounding fused multiply-add that matches
204        // `ScalarScorer`'s `f32::mul_add` bit-exactly.
205        return &NEON_SCORER;
206    }
207    #[allow(unreachable_code)]
208    &SCALAR_SCORER
209}
210
211/// AVX2 block scorer. Reads aligned 256-bit lanes straight from
212/// [`super::storage::BlockedCodeStorage::block_codes`] with no
213/// per-query repack and table-looks up nibble scores via `vpshufb`.
214///
215/// MIT notice: the SIMD body is adapted from RyanCodrai/turbovec
216/// (commit `4a4f2cd2db233f24405911b1ceaf1823fa23b4ac`, MIT). The
217/// per-vector scale and tail handling are clean-room — the trait
218/// returns the unit-rotated dot product only; outer code applies
219/// metric and per-vector scale, matching [`ScalarScorer`].
220#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
221pub struct Avx2Scorer;
222
223#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
224impl PerBlockScorer for Avx2Scorer {
225    fn kernel(&self) -> ScoreKernel {
226        ScoreKernel::Avx2
227    }
228
229    fn score_block(
230        &self,
231        lut: &QueryLut,
232        block_codes: &[u8],
233        n_byte_groups: usize,
234        n_vectors: usize,
235        out: &mut [f32; BLOCK_LANES],
236    ) {
237        debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
238        debug_assert!(block_codes.len() >= n_byte_groups * BLOCK_LANES);
239        debug_assert!(n_vectors <= BLOCK_LANES);
240        debug_assert!(std::is_x86_feature_detected!("avx2"));
241        debug_assert!(std::is_x86_feature_detected!("fma"));
242        // SAFETY: AVX2 + FMA availability is enforced by `select_scorer`
243        // and re-asserted by the debug checks above.
244        unsafe {
245            score_block_avx2_inner(lut, block_codes, n_byte_groups, n_vectors, out);
246        }
247    }
248}
249
/// AVX2 + FMA body of [`Avx2Scorer::score_block`].
///
/// For each byte group, one 256-bit load brings in the group's code bytes
/// (perm positions 0..16 in the low 128-bit half, 16..32 in the high
/// half). A second load brings in the group's 32 LUT bytes (hi table in
/// the low half, lo table in the high half). `vpshufb` looks up per
/// 128-bit half, so the low-nibble lookup (`lo_scores`) yields half-0
/// lane scores for both tables at once. The high-nibble lookup
/// (`hi_scores`) does the same for half 1.
///
/// Scores are summed in u16 lanes:
/// - `accum[0]`/`accum[2]` add each 16-bit word whole (even byte +
///   256 * odd byte, wrapping).
/// - `accum[1]`/`accum[3]` add only the odd byte.
///
/// Subtracting `accum[1] << 8` afterwards therefore leaves the even-byte
/// sums.
///
/// # Safety
///
/// - The host must support AVX2 and FMA.
/// - `block_codes` must hold at least `n_byte_groups * BLOCK_LANES` bytes
///   and `lut.bytes` at least `n_byte_groups * 32` bytes. Both are read
///   through raw pointers with no bounds checks.
///
/// Results match [`ScalarScorer`] only while every per-lane sum fits in a
/// u16. That sum is at most `2 * n_byte_groups * MAX_LUT`, i.e.
/// `n_byte_groups <= 258` for LUTs from [`QueryLut::build`]. Larger sums
/// wrap silently.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn score_block_avx2_inner(
    lut: &QueryLut,
    block_codes: &[u8],
    n_byte_groups: usize,
    n_vectors: usize,
    out: &mut [f32; BLOCK_LANES],
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut accum = [_mm256_setzero_si256(); 4];
    let nibble_mask = _mm256_set1_epi8(0x0f);

    for g in 0..n_byte_groups {
        let codes = _mm256_loadu_si256(block_codes.as_ptr().add(g * BLOCK_LANES) as *const __m256i);
        let clo = _mm256_and_si256(codes, nibble_mask);
        let chi = _mm256_and_si256(_mm256_srli_epi16(codes, 4), nibble_mask);
        let table = _mm256_loadu_si256(lut.bytes.as_ptr().add(g * 32) as *const __m256i);
        let lo_scores = _mm256_shuffle_epi8(table, clo);
        let hi_scores = _mm256_shuffle_epi8(table, chi);

        accum[0] = _mm256_add_epi16(accum[0], lo_scores);
        accum[1] = _mm256_add_epi16(accum[1], _mm256_srli_epi16(lo_scores, 8));
        accum[2] = _mm256_add_epi16(accum[2], hi_scores);
        accum[3] = _mm256_add_epi16(accum[3], _mm256_srli_epi16(hi_scores, 8));
    }

    // Recover even-byte sums: whole-word sum minus (odd-byte sum << 8),
    // exact modulo 2^16.
    accum[0] = _mm256_sub_epi16(accum[0], _mm256_slli_epi16(accum[1], 8));
    accum[2] = _mm256_sub_epi16(accum[2], _mm256_slli_epi16(accum[3], 8));

    // Add the hi-table and lo-table halves together. `dis0` ends up with
    // half-0 totals for even perm positions in its low 8 u16 lanes and odd
    // perm positions in its high 8; `dis1` holds the same for half 1.
    let dis0 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum[0], accum[1], 0x21),
        _mm256_blend_epi32(accum[0], accum[1], 0xf0),
    );
    let dis1 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum[2], accum[3], 0x21),
        _mm256_blend_epi32(accum[2], accum[3], 0xf0),
    );

    // Widen each group of 8 u16 totals to u32, then convert to f32 (exact
    // for these magnitudes).
    let sums = [
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_castsi256_si128(dis0))),
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(dis0, 1))),
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_castsi256_si128(dis1))),
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(dis1, 1))),
    ];
    let v_scale = _mm256_set1_ps(lut.scale);
    let v_bias = _mm256_set1_ps(lut.bias);

    // Stored without a PERM0 scatter: within each half, out[i] receives perm
    // position 2i and out[8 + i] position 2i + 1. That equals lane order only
    // if PERM0[2i] == i and PERM0[2i + 1] == 8 + i.
    // NOTE(review): confirm against `storage::PERM0`; the scalar-equivalence
    // tests would catch a mismatch.
    for (chunk, sum) in sums.iter().enumerate() {
        let lane_start = chunk * 8;
        // FMA matches `f32::mul_add` used by `ScalarScorer` bit-exactly.
        let score = _mm256_fmadd_ps(v_scale, *sum, v_bias);
        _mm256_storeu_ps(out.as_mut_ptr().add(lane_start), score);
    }

    // Tail lanes match the scalar oracle: unused slots are 0.0.
    for score in out.iter_mut().take(BLOCK_LANES).skip(n_vectors) {
        *score = 0.0;
    }
}
314
315/// AVX-512BW block scorer. Processes pairs of byte groups in a single
316/// 512-bit register so that `vpshufb` resolves four 128-bit lane
317/// lookups per iteration (two groups × hi/lo tables).
318///
319/// MIT notice: the paired-group `vpshufb` + u16-accumulator structure
320/// is adapted from RyanCodrai/turbovec (commit
321/// `4a4f2cd2db233f24405911b1ceaf1823fa23b4ac`, MIT). The single-block
322/// trait surface, the fold from 512→256-bit lanes, and the tail
323/// handling are clean-room.
324#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
325pub struct Avx512BwScorer;
326
327#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
328impl PerBlockScorer for Avx512BwScorer {
329    fn kernel(&self) -> ScoreKernel {
330        ScoreKernel::Avx512Bw
331    }
332
333    fn score_block(
334        &self,
335        lut: &QueryLut,
336        block_codes: &[u8],
337        n_byte_groups: usize,
338        n_vectors: usize,
339        out: &mut [f32; BLOCK_LANES],
340    ) {
341        debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
342        debug_assert!(block_codes.len() >= n_byte_groups * BLOCK_LANES);
343        debug_assert!(n_vectors <= BLOCK_LANES);
344        debug_assert!(std::is_x86_feature_detected!("avx512bw"));
345        debug_assert!(std::is_x86_feature_detected!("avx512f"));
346        debug_assert!(std::is_x86_feature_detected!("fma"));
347        // SAFETY: AVX-512BW/F + FMA availability is enforced by
348        // `select_scorer` and re-asserted by the debug checks above.
349        unsafe {
350            score_block_avx512bw_inner(lut, block_codes, n_byte_groups, n_vectors, out);
351        }
352    }
353}
354
/// AVX-512BW body of [`Avx512BwScorer::score_block`].
///
/// The main loop consumes two byte groups per iteration: one 512-bit load
/// covers the code bytes of groups `g` and `g + 1`, and another covers
/// their LUT bytes. Each 256-bit half of the accumulators has the same
/// layout as the AVX2 kernel's. After the loop:
/// 1. The two halves are folded together.
/// 2. An odd trailing group is added with 256-bit instructions.
/// 3. The even/odd split and interleave mirror the AVX2 kernel.
///
/// # Safety
///
/// - The host must support AVX-512BW, AVX-512F, AVX2, AVX and FMA (the
///   features enabled on this function).
/// - `block_codes` must hold at least `n_byte_groups * BLOCK_LANES` bytes
///   and `lut.bytes` at least `n_byte_groups * 32` bytes. Both are read
///   through raw pointers with no bounds checks.
///
/// Results match [`ScalarScorer`] only while every per-lane sum fits in a
/// u16 (`n_byte_groups <= 258` for LUTs from [`QueryLut::build`]).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512bw,avx512f,avx2,avx,fma")]
unsafe fn score_block_avx512bw_inner(
    lut: &QueryLut,
    block_codes: &[u8],
    n_byte_groups: usize,
    n_vectors: usize,
    out: &mut [f32; BLOCK_LANES],
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // 32 u16 lanes per accumulator — one position per byte across two
    // adjacent groups. The lower 256 bits hold group g, the upper 256
    // hold group g+1; they are folded together after the loop. The u16
    // headroom matches the AVX2 kernel: max sum per lane = 2 *
    // n_byte_groups * MAX_LUT, safe for n_byte_groups <= 258.
    let mut accum = [_mm512_setzero_si512(); 4];
    let nibble_mask = _mm512_set1_epi8(0x0f);

    let mut g = 0;
    while g + 2 <= n_byte_groups {
        let codes = _mm512_loadu_si512(block_codes.as_ptr().add(g * BLOCK_LANES) as *const _);
        let clo = _mm512_and_si512(codes, nibble_mask);
        let chi = _mm512_and_si512(_mm512_srli_epi16(codes, 4), nibble_mask);
        let table = _mm512_loadu_si512(lut.bytes.as_ptr().add(g * 32) as *const _);
        let lo_scores = _mm512_shuffle_epi8(table, clo);
        let hi_scores = _mm512_shuffle_epi8(table, chi);

        accum[0] = _mm512_add_epi16(accum[0], lo_scores);
        accum[1] = _mm512_add_epi16(accum[1], _mm512_srli_epi16(lo_scores, 8));
        accum[2] = _mm512_add_epi16(accum[2], hi_scores);
        accum[3] = _mm512_add_epi16(accum[3], _mm512_srli_epi16(hi_scores, 8));

        g += 2;
    }

    // Fold the 512-bit accumulators back to 256-bit by summing the two
    // halves. After this, each `accum256[i]` matches what the AVX2
    // kernel computes for the same `accum[i]`, modulo the unpaired
    // tail group which is added below.
    let mut accum256 = [
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[0]),
            _mm512_extracti64x4_epi64(accum[0], 1),
        ),
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[1]),
            _mm512_extracti64x4_epi64(accum[1], 1),
        ),
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[2]),
            _mm512_extracti64x4_epi64(accum[2], 1),
        ),
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[3]),
            _mm512_extracti64x4_epi64(accum[3], 1),
        ),
    ];

    // Tail: an odd number of byte groups leaves one group unpaired.
    // Handle it with the AVX2-shaped 256-bit kernel body so the result
    // remains bit-identical to `ScalarScorer` and `Avx2Scorer`.
    if g < n_byte_groups {
        let nibble_mask_256 = _mm256_set1_epi8(0x0f);
        let codes = _mm256_loadu_si256(block_codes.as_ptr().add(g * BLOCK_LANES) as *const __m256i);
        let clo = _mm256_and_si256(codes, nibble_mask_256);
        let chi = _mm256_and_si256(_mm256_srli_epi16(codes, 4), nibble_mask_256);
        let table = _mm256_loadu_si256(lut.bytes.as_ptr().add(g * 32) as *const __m256i);
        let lo_scores = _mm256_shuffle_epi8(table, clo);
        let hi_scores = _mm256_shuffle_epi8(table, chi);

        accum256[0] = _mm256_add_epi16(accum256[0], lo_scores);
        accum256[1] = _mm256_add_epi16(accum256[1], _mm256_srli_epi16(lo_scores, 8));
        accum256[2] = _mm256_add_epi16(accum256[2], hi_scores);
        accum256[3] = _mm256_add_epi16(accum256[3], _mm256_srli_epi16(hi_scores, 8));
    }

    // Split the (high, low) byte-position sums per the AVX2 trick:
    // accum256[0] currently holds, per 16-bit lane i,
    //   lo_scores[2i] + lo_scores[2i+1] * 256 summed across groups.
    // accum256[1] holds sum of lo_scores[2i+1] (zero-extended).
    // Subtracting (accum256[1] << 8) leaves sum of lo_scores[2i].
    accum256[0] = _mm256_sub_epi16(accum256[0], _mm256_slli_epi16(accum256[1], 8));
    accum256[2] = _mm256_sub_epi16(accum256[2], _mm256_slli_epi16(accum256[3], 8));

    // Interleave even/odd byte-position sums back into lane order.
    let dis0 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum256[0], accum256[1], 0x21),
        _mm256_blend_epi32(accum256[0], accum256[1], 0xf0),
    );
    let dis1 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum256[2], accum256[3], 0x21),
        _mm256_blend_epi32(accum256[2], accum256[3], 0xf0),
    );

    let v_scale = _mm512_set1_ps(lut.scale);
    let v_bias = _mm512_set1_ps(lut.bias);

    // 16 u16 → 16 u32 → 16 f32 per store. FMA matches `f32::mul_add`
    // bit-exactly, the contract the scalar oracle locks down.
    // As in the AVX2 kernel, there is no PERM0 scatter here: the store order
    // is lane order only if PERM0[2i] == i and PERM0[2i + 1] == 8 + i.
    // NOTE(review): confirm against `storage::PERM0`.
    let sum0 = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(dis0));
    let scores0 = _mm512_fmadd_ps(v_scale, sum0, v_bias);
    _mm512_storeu_ps(out.as_mut_ptr(), scores0);

    let sum1 = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(dis1));
    let scores1 = _mm512_fmadd_ps(v_scale, sum1, v_bias);
    _mm512_storeu_ps(out.as_mut_ptr().add(16), scores1);

    // Tail lanes match the scalar oracle: unused slots are 0.0.
    for score in out.iter_mut().take(BLOCK_LANES).skip(n_vectors) {
        *score = 0.0;
    }
}
471
/// NEON block scorer. Reads aligned 128-bit lanes straight from
/// [`super::storage::BlockedCodeStorage::block_codes`] with no
/// per-query repack and table-looks up nibble scores via `vqtbl1q_u8`.
///
/// MIT notice: the SIMD body is adapted from RyanCodrai/turbovec
/// (commit `4a4f2cd2db233f24405911b1ceaf1823fa23b4ac`, MIT). The
/// single-block trait surface, the PERM0-aware lane scatter, and the
/// tail handling are clean-room — the trait returns the unit-rotated
/// dot product only; outer code applies metric and per-vector scale,
/// matching [`ScalarScorer`].
///
/// The SIMD path accumulates in u16 lanes. Blocks with more byte groups
/// than that headroom allows are delegated to [`ScalarScorer`], so the
/// result always equals the oracle's.
#[cfg(target_arch = "aarch64")]
pub struct NeonScorer;

#[cfg(target_arch = "aarch64")]
impl PerBlockScorer for NeonScorer {
    fn kernel(&self) -> ScoreKernel {
        ScoreKernel::Neon
    }

    /// Scores one block with the NEON kernel.
    ///
    /// # Panics
    ///
    /// Panics if `block_codes` or `lut.bytes` is too short for
    /// `n_byte_groups`. These checks stay on in release builds because the
    /// SIMD body reads through raw pointers.
    fn score_block(
        &self,
        lut: &QueryLut,
        block_codes: &[u8],
        n_byte_groups: usize,
        n_vectors: usize,
        out: &mut [f32; BLOCK_LANES],
    ) {
        // Each group adds at most `2 * MAX_LUT` to a position's sum. Beyond
        // this many groups the kernel's u16 accumulators could wrap
        // (65535 / 254 = 258 for `MAX_LUT == 127`).
        const MAX_U16_SAFE_GROUPS: usize = u16::MAX as usize / (2 * MAX_LUT as usize);

        debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
        debug_assert!(n_vectors <= BLOCK_LANES);
        assert!(
            block_codes.len() >= n_byte_groups * BLOCK_LANES,
            "block_codes too short for {n_byte_groups} byte groups",
        );
        assert!(
            lut.bytes.len() >= n_byte_groups * 32,
            "LUT too short for {n_byte_groups} byte groups",
        );

        if n_byte_groups > MAX_U16_SAFE_GROUPS {
            // A u16 overflow would silently diverge from the u32 oracle;
            // defer to it instead.
            ScalarScorer.score_block(lut, block_codes, n_byte_groups, n_vectors, out);
            return;
        }

        // SAFETY: NEON is part of the mandatory aarch64 base ISA, so
        // every `target_arch = "aarch64"` host supports these intrinsics.
        // Both buffer lengths were asserted above, and the group count
        // fits the u16 accumulators.
        unsafe {
            score_block_neon_inner(lut, block_codes, n_byte_groups, n_vectors, out);
        }
    }
}
509
/// NEON body of [`NeonScorer::score_block`].
///
/// For each byte group:
/// - looks up both nibbles of the hi-pair and lo-pair bytes with
///   `vqtbl1q_u8`;
/// - accumulates per-perm-position scores in four u16x8 registers.
///
/// After the loop it dequantizes with a fused multiply-add and scatters
/// the results back to lane order through `PERM0`.
///
/// # Safety
///
/// - NEON must be available. The caller relies on NEON being part of the
///   aarch64 base ISA.
/// - `block_codes` must hold at least `n_byte_groups * BLOCK_LANES` bytes
///   and `lut.bytes` at least `n_byte_groups * 32` bytes. Both are read
///   through raw pointers with no bounds checks.
///
/// Results match [`ScalarScorer`] only while every per-position sum fits in
/// a u16 (`n_byte_groups <= 258` for LUTs from [`QueryLut::build`]).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn score_block_neon_inner(
    lut: &QueryLut,
    block_codes: &[u8],
    n_byte_groups: usize,
    n_vectors: usize,
    out: &mut [f32; BLOCK_LANES],
) {
    use std::arch::aarch64::*;

    // Four u16x8 accumulators cover the 32-lane block split as two
    // 16-lane halves (half-0 = lanes `PERM0[i]`, half-1 = lanes
    // `PERM0[i] + 16` for `i in 0..16`). Each half splits again into a
    // low 8-position u16x8 and a high 8-position u16x8. u16 headroom
    // matches the AVX2/AVX-512BW kernels: max sum per position = 2 *
    // n_byte_groups * MAX_LUT, safe for n_byte_groups <= 258.
    let mut acc_h0_lo = vdupq_n_u16(0);
    let mut acc_h0_hi = vdupq_n_u16(0);
    let mut acc_h1_lo = vdupq_n_u16(0);
    let mut acc_h1_hi = vdupq_n_u16(0);
    let nibble_mask = vdupq_n_u8(0x0f);

    for g in 0..n_byte_groups {
        let base = g * BLOCK_LANES;
        // 16 bytes of "hi-pairs" (perm-positions 0..16) and 16 bytes of
        // "lo-pairs" (perm-positions 16..32). Each byte packs two
        // 4-bit codes: low nibble belongs to lane `PERM0[perm_pos]`
        // (half 0), high nibble to lane `PERM0[perm_pos] + 16` (half 1).
        let hi_pair = vld1q_u8(block_codes.as_ptr().add(base));
        let lo_pair = vld1q_u8(block_codes.as_ptr().add(base + 16));
        let hi_lut = vld1q_u8(lut.bytes.as_ptr().add(g * 32));
        let lo_lut = vld1q_u8(lut.bytes.as_ptr().add(g * 32 + 16));

        // Half-0 nibble indices (low nibbles).
        let idx_lo_h0 = vandq_u8(lo_pair, nibble_mask);
        let idx_hi_h0 = vandq_u8(hi_pair, nibble_mask);
        // Half-1 nibble indices (high nibbles).
        let idx_lo_h1 = vshrq_n_u8(lo_pair, 4);
        let idx_hi_h1 = vshrq_n_u8(hi_pair, 4);

        // `vqtbl1q_u8` does a 16-entry byte table lookup; all indices
        // are masked to 0..16 so no out-of-range result is consumed.
        let s_lo_h0 = vqtbl1q_u8(lo_lut, idx_lo_h0);
        let s_hi_h0 = vqtbl1q_u8(hi_lut, idx_hi_h0);
        let s_lo_h1 = vqtbl1q_u8(lo_lut, idx_lo_h1);
        let s_hi_h1 = vqtbl1q_u8(hi_lut, idx_hi_h1);

        // Per-position score = lo-table contribution + hi-table
        // contribution. `vaddl_u8` widens u8x8 + u8x8 → u16x8 in one
        // op, accumulating without overflow into the u16 lanes.
        acc_h0_lo = vaddq_u16(
            acc_h0_lo,
            vaddl_u8(vget_low_u8(s_lo_h0), vget_low_u8(s_hi_h0)),
        );
        acc_h0_hi = vaddq_u16(
            acc_h0_hi,
            vaddl_u8(vget_high_u8(s_lo_h0), vget_high_u8(s_hi_h0)),
        );
        acc_h1_lo = vaddq_u16(
            acc_h1_lo,
            vaddl_u8(vget_low_u8(s_lo_h1), vget_low_u8(s_hi_h1)),
        );
        acc_h1_hi = vaddq_u16(
            acc_h1_hi,
            vaddl_u8(vget_high_u8(s_lo_h1), vget_high_u8(s_hi_h1)),
        );
    }

    let scale = vdupq_n_f32(lut.scale);
    let bias = vdupq_n_f32(lut.bias);

    // Widen u16x8 → u32x4 × 2 → f32x4 × 2 then FMA with scale/bias.
    // `vfmaq_f32(a, b, c) = a + b * c` is a single-rounding fused
    // multiply-add and matches `f32::mul_add` bit-exactly — the
    // contract the scalar oracle locks down.
    let conv = |acc_u16: uint16x8_t| -> [f32; 8] {
        let lo_u32 = vmovl_u16(vget_low_u16(acc_u16));
        let hi_u32 = vmovl_u16(vget_high_u16(acc_u16));
        let lo_f = vcvtq_f32_u32(lo_u32);
        let hi_f = vcvtq_f32_u32(hi_u32);
        let lo_score = vfmaq_f32(bias, scale, lo_f);
        let hi_score = vfmaq_f32(bias, scale, hi_f);
        let mut tmp = [0.0f32; 8];
        vst1q_f32(tmp.as_mut_ptr(), lo_score);
        vst1q_f32(tmp.as_mut_ptr().add(4), hi_score);
        tmp
    };

    let h0_lo = conv(acc_h0_lo);
    let h0_hi = conv(acc_h0_hi);
    let h1_lo = conv(acc_h1_lo);
    let h1_hi = conv(acc_h1_hi);

    // Scatter perm-positions back to lane order. `scores_hN[perm_pos]`
    // is the score for the lane PERM0 maps `perm_pos` to (offset by 16
    // for half 1). Done in scalar — this is 32 stores at the tail of
    // the hot loop, dwarfed by the per-group SIMD work above.
    let mut scores_h0 = [0.0f32; 16];
    let mut scores_h1 = [0.0f32; 16];
    scores_h0[..8].copy_from_slice(&h0_lo);
    scores_h0[8..].copy_from_slice(&h0_hi);
    scores_h1[..8].copy_from_slice(&h1_lo);
    scores_h1[8..].copy_from_slice(&h1_hi);

    for perm_pos in 0..16 {
        let lane = PERM0[perm_pos];
        out[lane] = scores_h0[perm_pos];
        out[lane + 16] = scores_h1[perm_pos];
    }

    // Tail lanes match the scalar oracle: unused slots are 0.0.
    for score in out.iter_mut().take(BLOCK_LANES).skip(n_vectors) {
        *score = 0.0;
    }
}
626
627fn decode_perm0_byte(block_codes: &[u8], group: usize, lane: usize) -> (u8, u8) {
628    debug_assert!(lane < BLOCK_LANES);
629    let half = lane / 16;
630    let within_half = lane % 16;
631    let perm_pos = PERM0
632        .iter()
633        .position(|&v| v == within_half)
634        .expect("lane in perm0");
635    let group_base = group * BLOCK_LANES;
636    let hi_pair = block_codes[group_base + perm_pos];
637    let lo_pair = block_codes[group_base + 16 + perm_pos];
638    if half == 0 {
639        (hi_pair & 0x0f, lo_pair & 0x0f)
640    } else {
641        (hi_pair >> 4, lo_pair >> 4)
642    }
643}
644
645#[cfg(test)]
646mod tests {
647    use super::*;
648    use crate::storage::engine::turboquant::storage::BlockedCodeStorage;
649
650    fn centroids_for(bits: u8) -> Vec<f64> {
651        let levels = 1usize << bits;
652        let step = 2.0 / levels as f64;
653        (0..levels)
654            .map(|i| -1.0 + (i as f64 + 0.5) * step)
655            .collect()
656    }
657
658    #[test]
659    fn scalar_block_score_is_zero_when_query_is_zero() {
660        let lut = QueryLut::build(&[0.0f32; 4], &centroids_for(4));
661        let mut storage = BlockedCodeStorage::new(2);
662        storage.append(&[0x12, 0x34], 1.0);
663        let mut out = [0.0f32; BLOCK_LANES];
664        ScalarScorer.score_block(&lut, storage.block_codes(0), 2, 1, &mut out);
665        // Bias = 0 for zero query, scale arbitrary, acc clamps to LUT min
666        // (all zero entries) → 0 + 0 == 0.
667        assert_eq!(out[0], 0.0);
668    }
669
670    #[test]
671    fn scalar_block_score_matches_per_vector_scalar_lut() {
672        // Cross-check: reconstruct the per-vector scalar sum
673        // (sum of LUT[hi] + LUT[16+lo] over groups) by direct
674        // arithmetic on the per-vector packed bytes and confirm
675        // ScalarScorer agrees on the equivalent lane.
676        let centroids = centroids_for(4);
677        let query = vec![0.2f32, -0.3, 0.4, -0.5];
678        let lut = QueryLut::build(&query, &centroids);
679
680        let n_byte_groups = 2;
681        let mut storage = BlockedCodeStorage::new(n_byte_groups);
682        let packed_a = vec![0xa3u8, 0x5c];
683        let packed_b = vec![0x71u8, 0xfe];
684        storage.append(&packed_a, 1.0);
685        storage.append(&packed_b, 1.0);
686
687        let mut out = [0.0f32; BLOCK_LANES];
688        ScalarScorer.score_block(&lut, storage.block_codes(0), n_byte_groups, 2, &mut out);
689
690        for (lane, packed) in [&packed_a, &packed_b].iter().enumerate() {
691            let mut expected = 0u32;
692            for (g, byte) in packed.iter().enumerate() {
693                let lo = (byte & 0x0f) as usize;
694                let hi = (byte >> 4) as usize;
695                expected += lut.bytes[g * 32 + hi] as u32;
696                expected += lut.bytes[g * 32 + 16 + lo] as u32;
697            }
698            let expected_f = lut.scale.mul_add(expected as f32, lut.bias);
699            assert_eq!(
700                out[lane], expected_f,
701                "lane {lane} matches per-vector LUT scoring",
702            );
703        }
704
705        for lane in 2..BLOCK_LANES {
706            assert_eq!(out[lane], 0.0, "unused lane {lane} stays 0");
707        }
708    }
709
710    #[test]
711    fn select_scorer_matches_host_capability() {
712        let kernel = select_scorer().kernel();
713        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
714        {
715            if std::is_x86_feature_detected!("avx512bw")
716                && std::is_x86_feature_detected!("avx512f")
717                && std::is_x86_feature_detected!("fma")
718            {
719                assert_eq!(kernel, ScoreKernel::Avx512Bw);
720                return;
721            }
722            if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
723                assert_eq!(kernel, ScoreKernel::Avx2);
724                return;
725            }
726        }
727        #[cfg(target_arch = "aarch64")]
728        {
729            // NEON is mandatory on aarch64 — no runtime gate; the
730            // dispatch must always pick it on this arch.
731            assert_eq!(kernel, ScoreKernel::Neon);
732            return;
733        }
734        #[allow(unreachable_code)]
735        {
736            assert_eq!(kernel, ScoreKernel::Scalar);
737        }
738    }
739
740    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
741    #[test]
742    fn avx2_scorer_matches_scalar_oracle_across_dataset_sizes() {
743        if !std::is_x86_feature_detected!("avx2") || !std::is_x86_feature_detected!("fma") {
744            return;
745        }
746        let centroids = centroids_for(4);
747        // Queries chosen to exercise different LUT shapes (zero, sign-mixed,
748        // single-axis) and to keep n_byte_groups small enough that the
749        // AVX2 kernel's u16 accumulators cannot overflow vs the scalar
750        // u32 oracle (max sum per lane = 2*n_byte_groups*127).
751        let queries: [Vec<f32>; 4] = [
752            vec![0.0; 8],
753            vec![0.7, -0.3, 0.4, -0.1, 0.2, -0.5, 0.6, -0.9],
754            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
755            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
756        ];
757
758        for query in &queries {
759            let lut = QueryLut::build(query, &centroids);
760            let n_byte_groups = lut.n_byte_groups;
761
762            for n in [1usize, 31, 32, 33, 95, 96, 97] {
763                let mut storage = BlockedCodeStorage::new(n_byte_groups);
764                for i in 0..n {
765                    let packed: Vec<u8> = (0..n_byte_groups)
766                        .map(|g| {
767                            let lo = ((i + g * 3) & 0x0f) as u8;
768                            let hi = ((i * 5 + g * 7) & 0x0f) as u8;
769                            lo | (hi << 4)
770                        })
771                        .collect();
772                    storage.append(&packed, 1.0);
773                }
774
775                for b in 0..storage.n_blocks() {
776                    let filled = storage.block_lanes_filled(b);
777                    let mut scalar_out = [0.0f32; BLOCK_LANES];
778                    let mut avx2_out = [f32::NAN; BLOCK_LANES];
779                    ScalarScorer.score_block(
780                        &lut,
781                        storage.block_codes(b),
782                        n_byte_groups,
783                        filled,
784                        &mut scalar_out,
785                    );
786                    AVX2_SCORER.score_block(
787                        &lut,
788                        storage.block_codes(b),
789                        n_byte_groups,
790                        filled,
791                        &mut avx2_out,
792                    );
793                    for lane in 0..BLOCK_LANES {
794                        assert_eq!(
795                            avx2_out[lane].to_bits(),
796                            scalar_out[lane].to_bits(),
797                            "AVX2 diverges from scalar at N={n}, block {b}, lane {lane}",
798                        );
799                    }
800                }
801            }
802        }
803    }
804
805    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
806    #[test]
807    fn avx512bw_scorer_matches_scalar_oracle_across_dataset_sizes() {
808        if !std::is_x86_feature_detected!("avx512bw")
809            || !std::is_x86_feature_detected!("avx512f")
810            || !std::is_x86_feature_detected!("fma")
811        {
812            return;
813        }
814        let centroids = centroids_for(4);
815        // Queries chosen to exercise different LUT shapes (zero,
816        // sign-mixed, single-axis) plus a query whose n_byte_groups is
817        // odd (5) so the kernel exercises both the paired-group main
818        // loop and the unpaired-tail branch. `n_byte_groups` stays
819        // small enough that the u16 accumulators cannot overflow vs
820        // the scalar u32 oracle (max sum per lane = 2 *
821        // n_byte_groups * 127).
822        let queries: [Vec<f32>; 5] = [
823            vec![0.0; 8],
824            vec![0.7, -0.3, 0.4, -0.1, 0.2, -0.5, 0.6, -0.9],
825            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
826            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
827            // 10-dim → 5 byte groups: odd count exercises the tail.
828            vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8, 0.9, -1.0],
829        ];
830
831        for query in &queries {
832            let lut = QueryLut::build(query, &centroids);
833            let n_byte_groups = lut.n_byte_groups;
834
835            for n in [1usize, 31, 32, 33, 95, 96, 97] {
836                let mut storage = BlockedCodeStorage::new(n_byte_groups);
837                for i in 0..n {
838                    let packed: Vec<u8> = (0..n_byte_groups)
839                        .map(|g| {
840                            let lo = ((i + g * 3) & 0x0f) as u8;
841                            let hi = ((i * 5 + g * 7) & 0x0f) as u8;
842                            lo | (hi << 4)
843                        })
844                        .collect();
845                    storage.append(&packed, 1.0);
846                }
847
848                for b in 0..storage.n_blocks() {
849                    let filled = storage.block_lanes_filled(b);
850                    let mut scalar_out = [0.0f32; BLOCK_LANES];
851                    let mut avx512_out = [f32::NAN; BLOCK_LANES];
852                    ScalarScorer.score_block(
853                        &lut,
854                        storage.block_codes(b),
855                        n_byte_groups,
856                        filled,
857                        &mut scalar_out,
858                    );
859                    AVX512BW_SCORER.score_block(
860                        &lut,
861                        storage.block_codes(b),
862                        n_byte_groups,
863                        filled,
864                        &mut avx512_out,
865                    );
866                    for lane in 0..BLOCK_LANES {
867                        assert_eq!(
868                            avx512_out[lane].to_bits(),
869                            scalar_out[lane].to_bits(),
870                            "AVX-512BW diverges from scalar at N={n}, block {b}, lane {lane}",
871                        );
872                    }
873                }
874            }
875        }
876    }
877
878    #[cfg(target_arch = "aarch64")]
879    #[test]
880    fn neon_scorer_matches_scalar_oracle_across_dataset_sizes() {
881        // NEON is mandatory on aarch64; no runtime gate needed. Query
882        // shapes mirror the AVX-512BW test so the paired-byte structure
883        // and the odd-`n_byte_groups` tail-only case are both covered.
884        // `n_byte_groups` stays small enough that u16 accumulators
885        // cannot overflow vs the scalar u32 oracle (max sum per
886        // position = 2 * n_byte_groups * 127).
887        let centroids = centroids_for(4);
888        let queries: [Vec<f32>; 5] = [
889            vec![0.0; 8],
890            vec![0.7, -0.3, 0.4, -0.1, 0.2, -0.5, 0.6, -0.9],
891            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
892            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
893            // 10-dim → 5 byte groups: odd count exercises the tail of
894            // any paired-group kernel (NEON walks groups one at a time
895            // so there is no paired/unpaired distinction, but the case
896            // is kept for parity with the AVX-512BW test).
897            vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8, 0.9, -1.0],
898        ];
899
900        for query in &queries {
901            let lut = QueryLut::build(query, &centroids);
902            let n_byte_groups = lut.n_byte_groups;
903
904            // Coverage includes the tail-only case (n < BLOCK_LANES)
905            // and exact-block-boundary sizes (32, 96) per acceptance.
906            for n in [1usize, 31, 32, 33, 95, 96, 97] {
907                let mut storage = BlockedCodeStorage::new(n_byte_groups);
908                for i in 0..n {
909                    let packed: Vec<u8> = (0..n_byte_groups)
910                        .map(|g| {
911                            let lo = ((i + g * 3) & 0x0f) as u8;
912                            let hi = ((i * 5 + g * 7) & 0x0f) as u8;
913                            lo | (hi << 4)
914                        })
915                        .collect();
916                    storage.append(&packed, 1.0);
917                }
918
919                for b in 0..storage.n_blocks() {
920                    let filled = storage.block_lanes_filled(b);
921                    let mut scalar_out = [0.0f32; BLOCK_LANES];
922                    let mut neon_out = [f32::NAN; BLOCK_LANES];
923                    ScalarScorer.score_block(
924                        &lut,
925                        storage.block_codes(b),
926                        n_byte_groups,
927                        filled,
928                        &mut scalar_out,
929                    );
930                    NEON_SCORER.score_block(
931                        &lut,
932                        storage.block_codes(b),
933                        n_byte_groups,
934                        filled,
935                        &mut neon_out,
936                    );
937                    for lane in 0..BLOCK_LANES {
938                        assert_eq!(
939                            neon_out[lane].to_bits(),
940                            scalar_out[lane].to_bits(),
941                            "NEON diverges from scalar at N={n}, block {b}, lane {lane}",
942                        );
943                    }
944                }
945            }
946        }
947    }
948}