1use super::storage::{BLOCK_LANES, PERM0};
15
/// Largest value a quantized lookup-table entry may take.
///
/// `QueryLut::build` uses it both as the divisor when deriving the
/// quantization step from the widest table span and as the upper clamp bound
/// for every quantized entry.
const MAX_LUT: f32 = 127.0;
17
/// Names the implementation a `PerBlockScorer` reports through `kernel()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreKernel {
    /// Portable implementation without explicit SIMD (`ScalarScorer`).
    Scalar,
    /// x86/x86_64 implementation compiled with `avx2` and `fma` enabled.
    Avx2,
    /// x86/x86_64 implementation compiled with `avx512bw` and `avx512f` enabled.
    Avx512Bw,
    /// AArch64 implementation compiled with `neon` enabled.
    Neon,
}
25
26#[derive(Debug, Clone)]
29pub struct QueryLut {
30 pub bytes: Vec<u8>,
31 pub n_byte_groups: usize,
32 pub scale: f32,
33 pub bias: f32,
34}
35
36impl QueryLut {
37 pub fn build(query_terms: &[f32], centroids: &[f64]) -> Self {
38 let n_byte_groups = query_terms.len().div_ceil(2);
39 let mut float_luts = vec![0.0f32; n_byte_groups * 32];
40 let mut max_span = 0.0f32;
41 let mut bias = 0.0f32;
42
43 for group in 0..n_byte_groups {
44 let lo_term = query_terms[group * 2];
45 let hi_term = query_terms.get(group * 2 + 1).copied().unwrap_or(0.0);
46 let group_base = group * 32;
47
48 for code in 0..16 {
49 float_luts[group_base + code] = hi_term * centroids[code] as f32;
50 float_luts[group_base + 16 + code] = lo_term * centroids[code] as f32;
51 }
52
53 let hi = &float_luts[group_base..group_base + 16];
54 let lo = &float_luts[group_base + 16..group_base + 32];
55 let hi_min = hi.iter().copied().fold(f32::INFINITY, f32::min);
56 let hi_max = hi.iter().copied().fold(f32::NEG_INFINITY, f32::max);
57 let lo_min = lo.iter().copied().fold(f32::INFINITY, f32::min);
58 let lo_max = lo.iter().copied().fold(f32::NEG_INFINITY, f32::max);
59 bias += hi_min + lo_min;
60 max_span = max_span.max((hi_max - hi_min).max(lo_max - lo_min));
61 }
62
63 let scale = if max_span > 1e-10 {
64 max_span / MAX_LUT
65 } else {
66 1.0
67 };
68 let inv_scale = 1.0 / scale;
69 let mut bytes = vec![0u8; n_byte_groups * 32];
70
71 for group in 0..n_byte_groups {
72 let group_base = group * 32;
73 let hi_min = float_luts[group_base..group_base + 16]
74 .iter()
75 .copied()
76 .fold(f32::INFINITY, f32::min);
77 let lo_min = float_luts[group_base + 16..group_base + 32]
78 .iter()
79 .copied()
80 .fold(f32::INFINITY, f32::min);
81 for code in 0..16 {
82 bytes[group_base + code] = ((float_luts[group_base + code] - hi_min) * inv_scale)
83 .round()
84 .clamp(0.0, MAX_LUT) as u8;
85 bytes[group_base + 16 + code] = ((float_luts[group_base + 16 + code] - lo_min)
86 * inv_scale)
87 .round()
88 .clamp(0.0, MAX_LUT) as u8;
89 }
90 }
91
92 Self {
93 bytes,
94 n_byte_groups,
95 scale,
96 bias,
97 }
98 }
99}
100
/// Scores one block of packed 4-bit codes against a `QueryLut`.
pub trait PerBlockScorer: Sync + Send {
    /// Reports which kernel this scorer runs.
    fn kernel(&self) -> ScoreKernel;

    /// Writes one score per lane of the block into `out`.
    ///
    /// The implementations in this file sum, for every lane, one entry of
    /// `lut.bytes` per nibble of the lane's codes and store
    /// `lut.scale * sum + lut.bias`. Lanes at index `n_vectors` and above are
    /// set to `0.0`.
    ///
    /// * `lut` - quantized tables for the query.
    /// * `block_codes` - packed codes of the block, `BLOCK_LANES` bytes per
    ///   byte group.
    /// * `n_byte_groups` - number of byte groups to process; the
    ///   implementations expect it to equal `lut.n_byte_groups`.
    /// * `n_vectors` - number of lanes that hold real vectors.
    /// * `out` - receives the scores.
    fn score_block(
        &self,
        lut: &QueryLut,
        block_codes: &[u8],
        n_byte_groups: usize,
        n_vectors: usize,
        out: &mut [f32; BLOCK_LANES],
    );
}
125
126pub struct ScalarScorer;
130
131impl PerBlockScorer for ScalarScorer {
132 fn kernel(&self) -> ScoreKernel {
133 ScoreKernel::Scalar
134 }
135
136 fn score_block(
137 &self,
138 lut: &QueryLut,
139 block_codes: &[u8],
140 n_byte_groups: usize,
141 n_vectors: usize,
142 out: &mut [f32; BLOCK_LANES],
143 ) {
144 debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
145 debug_assert!(block_codes.len() >= n_byte_groups * BLOCK_LANES);
146 debug_assert!(n_vectors <= BLOCK_LANES);
147
148 for (lane, slot) in out.iter_mut().enumerate() {
149 if lane >= n_vectors {
150 *slot = 0.0;
151 continue;
152 }
153 let mut acc = 0u32;
154 for g in 0..n_byte_groups {
155 let (hi, lo) = decode_perm0_byte(block_codes, g, lane);
156 acc = acc.wrapping_add(lut.bytes[g * 32 + hi as usize] as u32);
157 acc = acc.wrapping_add(lut.bytes[g * 32 + 16 + lo as usize] as u32);
158 }
159 *slot = lut.scale.mul_add(acc as f32, lut.bias);
160 }
161 }
162}
163
/// Shared scalar scorer; `select_scorer` returns it when no SIMD kernel is
/// selected.
static SCALAR_SCORER: ScalarScorer = ScalarScorer;

/// Shared AVX2 scorer returned by `select_scorer` on x86/x86_64 hosts that
/// report `avx2` and `fma`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
static AVX2_SCORER: Avx2Scorer = Avx2Scorer;

/// Shared AVX-512BW scorer returned by `select_scorer` on x86/x86_64 hosts
/// that report `avx512bw`, `avx512f` and `fma`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
static AVX512BW_SCORER: Avx512BwScorer = Avx512BwScorer;

/// Shared NEON scorer returned by `select_scorer` on AArch64 builds.
#[cfg(target_arch = "aarch64")]
static NEON_SCORER: NeonScorer = NeonScorer;
174
175pub fn select_scorer() -> &'static dyn PerBlockScorer {
179 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
180 {
181 if std::is_x86_feature_detected!("avx512bw")
189 && std::is_x86_feature_detected!("avx512f")
190 && std::is_x86_feature_detected!("fma")
191 {
192 return &AVX512BW_SCORER;
193 }
194 if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
195 return &AVX2_SCORER;
196 }
197 }
198 #[cfg(target_arch = "aarch64")]
199 {
200 return &NEON_SCORER;
206 }
207 #[allow(unreachable_code)]
208 &SCALAR_SCORER
209}
210
211#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
221pub struct Avx2Scorer;
222
223#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
224impl PerBlockScorer for Avx2Scorer {
225 fn kernel(&self) -> ScoreKernel {
226 ScoreKernel::Avx2
227 }
228
229 fn score_block(
230 &self,
231 lut: &QueryLut,
232 block_codes: &[u8],
233 n_byte_groups: usize,
234 n_vectors: usize,
235 out: &mut [f32; BLOCK_LANES],
236 ) {
237 debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
238 debug_assert!(block_codes.len() >= n_byte_groups * BLOCK_LANES);
239 debug_assert!(n_vectors <= BLOCK_LANES);
240 debug_assert!(std::is_x86_feature_detected!("avx2"));
241 debug_assert!(std::is_x86_feature_detected!("fma"));
242 unsafe {
245 score_block_avx2_inner(lut, block_codes, n_byte_groups, n_vectors, out);
246 }
247 }
248}
249
/// AVX2 kernel behind `Avx2Scorer::score_block`.
///
/// For every byte group the 32 code bytes are split into low and high
/// nibbles, each nibble selects an entry of the group's 32-byte table via
/// `_mm256_shuffle_epi8`, and the selected bytes are summed in 16-bit lanes.
/// The sums are widened to `f32` and stored as `scale * sum + bias`; lanes at
/// index `n_vectors` and above are then overwritten with `0.0`.
///
/// All sums are 16-bit, so they wrap modulo 2^16 if a lane's total exceeds
/// `u16::MAX`.
///
/// # Safety
///
/// * The CPU must support `avx2` and `fma`.
/// * `block_codes` must have 32 readable bytes starting at `g * BLOCK_LANES`
///   for every `g < n_byte_groups`.
/// * `lut.bytes` must hold at least `n_byte_groups * 32` bytes.
/// * `out` must have room for 32 `f32` values.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn score_block_avx2_inner(
    lut: &QueryLut,
    block_codes: &[u8],
    n_byte_groups: usize,
    n_vectors: usize,
    out: &mut [f32; BLOCK_LANES],
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // accum[0]/accum[1] collect the low-nibble lookups, accum[2]/accum[3] the
    // high-nibble lookups.
    let mut accum = [_mm256_setzero_si256(); 4];
    let nibble_mask = _mm256_set1_epi8(0x0f);

    for g in 0..n_byte_groups {
        let codes = _mm256_loadu_si256(block_codes.as_ptr().add(g * BLOCK_LANES) as *const __m256i);
        // `clo`/`chi` are the low/high nibble of every code byte. The 16-bit
        // shift drags bits across byte boundaries; the mask removes them.
        let clo = _mm256_and_si256(codes, nibble_mask);
        let chi = _mm256_and_si256(_mm256_srli_epi16(codes, 4), nibble_mask);
        let table = _mm256_loadu_si256(lut.bytes.as_ptr().add(g * 32) as *const __m256i);
        // The byte shuffle works per 128-bit half: code bytes 0..16 index
        // table bytes 0..16 and code bytes 16..32 index table bytes 16..32.
        let lo_scores = _mm256_shuffle_epi8(table, clo);
        let hi_scores = _mm256_shuffle_epi8(table, chi);

        // Adding the raw vector as u16 lanes accumulates
        // `even_byte + 256 * odd_byte`; the shifted copy accumulates the odd
        // bytes on their own.
        accum[0] = _mm256_add_epi16(accum[0], lo_scores);
        accum[1] = _mm256_add_epi16(accum[1], _mm256_srli_epi16(lo_scores, 8));
        accum[2] = _mm256_add_epi16(accum[2], hi_scores);
        accum[3] = _mm256_add_epi16(accum[3], _mm256_srli_epi16(hi_scores, 8));
    }

    // Remove the odd-byte contribution so accum[0]/accum[2] hold the
    // even-byte sums.
    accum[0] = _mm256_sub_epi16(accum[0], _mm256_slli_epi16(accum[1], 8));
    accum[2] = _mm256_sub_epi16(accum[2], _mm256_slli_epi16(accum[3], 8));

    // Low 128 bits: even-byte sums of both 128-bit halves added together.
    // High 128 bits: the same for the odd-byte sums.
    let dis0 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum[0], accum[1], 0x21),
        _mm256_blend_epi32(accum[0], accum[1], 0xf0),
    );
    let dis1 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum[2], accum[3], 0x21),
        _mm256_blend_epi32(accum[2], accum[3], 0xf0),
    );

    // Widen the 32 u16 totals to f32, eight at a time: dis0 feeds lanes
    // 0..16, dis1 lanes 16..32.
    let sums = [
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_castsi256_si128(dis0))),
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(dis0, 1))),
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_castsi256_si128(dis1))),
        _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(dis1, 1))),
    ];
    let v_scale = _mm256_set1_ps(lut.scale);
    let v_bias = _mm256_set1_ps(lut.bias);

    for (chunk, sum) in sums.iter().enumerate() {
        let lane_start = chunk * 8;
        // Fused `scale * sum + bias`.
        let score = _mm256_fmadd_ps(v_scale, *sum, v_bias);
        _mm256_storeu_ps(out.as_mut_ptr().add(lane_start), score);
    }

    // Lanes without a vector report 0.0.
    for score in out.iter_mut().take(BLOCK_LANES).skip(n_vectors) {
        *score = 0.0;
    }
}
314
315#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
325pub struct Avx512BwScorer;
326
327#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
328impl PerBlockScorer for Avx512BwScorer {
329 fn kernel(&self) -> ScoreKernel {
330 ScoreKernel::Avx512Bw
331 }
332
333 fn score_block(
334 &self,
335 lut: &QueryLut,
336 block_codes: &[u8],
337 n_byte_groups: usize,
338 n_vectors: usize,
339 out: &mut [f32; BLOCK_LANES],
340 ) {
341 debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
342 debug_assert!(block_codes.len() >= n_byte_groups * BLOCK_LANES);
343 debug_assert!(n_vectors <= BLOCK_LANES);
344 debug_assert!(std::is_x86_feature_detected!("avx512bw"));
345 debug_assert!(std::is_x86_feature_detected!("avx512f"));
346 debug_assert!(std::is_x86_feature_detected!("fma"));
347 unsafe {
350 score_block_avx512bw_inner(lut, block_codes, n_byte_groups, n_vectors, out);
351 }
352 }
353}
354
/// AVX-512BW kernel behind `Avx512BwScorer::score_block`.
///
/// Byte groups are processed two at a time with 512-bit vectors; an odd
/// trailing group is handled with 256-bit vectors. Selected table bytes are
/// summed in 16-bit lanes, widened to `f32` and stored as
/// `scale * sum + bias`; lanes at index `n_vectors` and above are then
/// overwritten with `0.0`.
///
/// All sums are 16-bit, so they wrap modulo 2^16 if a lane's total exceeds
/// `u16::MAX`.
///
/// # Safety
///
/// * The CPU must support `avx512bw`, `avx512f`, `avx2`, `avx` and `fma`.
/// * `block_codes` must have 64 readable bytes starting at `g * BLOCK_LANES`
///   for every even `g` with `g + 2 <= n_byte_groups`, and 32 readable bytes
///   at the last group when `n_byte_groups` is odd.
/// * `lut.bytes` must hold at least `n_byte_groups * 32` bytes.
/// * `out` must have room for 32 `f32` values.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512bw,avx512f,avx2,avx,fma")]
unsafe fn score_block_avx512bw_inner(
    lut: &QueryLut,
    block_codes: &[u8],
    n_byte_groups: usize,
    n_vectors: usize,
    out: &mut [f32; BLOCK_LANES],
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // accum[0]/accum[1] collect the low-nibble lookups, accum[2]/accum[3] the
    // high-nibble lookups.
    let mut accum = [_mm512_setzero_si512(); 4];
    let nibble_mask = _mm512_set1_epi8(0x0f);

    let mut g = 0;
    while g + 2 <= n_byte_groups {
        // One 64-byte load covers the codes of groups `g` and `g + 1`.
        let codes = _mm512_loadu_si512(block_codes.as_ptr().add(g * BLOCK_LANES) as *const _);
        let clo = _mm512_and_si512(codes, nibble_mask);
        let chi = _mm512_and_si512(_mm512_srli_epi16(codes, 4), nibble_mask);
        // Likewise one 64-byte load covers both groups' tables; the byte
        // shuffle looks up within each 128-bit quarter independently.
        let table = _mm512_loadu_si512(lut.bytes.as_ptr().add(g * 32) as *const _);
        let lo_scores = _mm512_shuffle_epi8(table, clo);
        let hi_scores = _mm512_shuffle_epi8(table, chi);

        // Raw u16 add accumulates `even_byte + 256 * odd_byte`; the shifted
        // copy accumulates the odd bytes on their own.
        accum[0] = _mm512_add_epi16(accum[0], lo_scores);
        accum[1] = _mm512_add_epi16(accum[1], _mm512_srli_epi16(lo_scores, 8));
        accum[2] = _mm512_add_epi16(accum[2], hi_scores);
        accum[3] = _mm512_add_epi16(accum[3], _mm512_srli_epi16(hi_scores, 8));

        g += 2;
    }

    // Fold the upper 256 bits onto the lower 256 bits, merging the two groups
    // that each 512-bit accumulator tracked side by side.
    let mut accum256 = [
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[0]),
            _mm512_extracti64x4_epi64(accum[0], 1),
        ),
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[1]),
            _mm512_extracti64x4_epi64(accum[1], 1),
        ),
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[2]),
            _mm512_extracti64x4_epi64(accum[2], 1),
        ),
        _mm256_add_epi16(
            _mm512_castsi512_si256(accum[3]),
            _mm512_extracti64x4_epi64(accum[3], 1),
        ),
    ];

    // Odd group count: process the last group with 256-bit vectors.
    if g < n_byte_groups {
        let nibble_mask_256 = _mm256_set1_epi8(0x0f);
        let codes = _mm256_loadu_si256(block_codes.as_ptr().add(g * BLOCK_LANES) as *const __m256i);
        let clo = _mm256_and_si256(codes, nibble_mask_256);
        let chi = _mm256_and_si256(_mm256_srli_epi16(codes, 4), nibble_mask_256);
        let table = _mm256_loadu_si256(lut.bytes.as_ptr().add(g * 32) as *const __m256i);
        let lo_scores = _mm256_shuffle_epi8(table, clo);
        let hi_scores = _mm256_shuffle_epi8(table, chi);

        accum256[0] = _mm256_add_epi16(accum256[0], lo_scores);
        accum256[1] = _mm256_add_epi16(accum256[1], _mm256_srli_epi16(lo_scores, 8));
        accum256[2] = _mm256_add_epi16(accum256[2], hi_scores);
        accum256[3] = _mm256_add_epi16(accum256[3], _mm256_srli_epi16(hi_scores, 8));
    }

    // Remove the odd-byte contribution so accum256[0]/accum256[2] hold the
    // even-byte sums.
    accum256[0] = _mm256_sub_epi16(accum256[0], _mm256_slli_epi16(accum256[1], 8));
    accum256[2] = _mm256_sub_epi16(accum256[2], _mm256_slli_epi16(accum256[3], 8));

    // Low 128 bits: even-byte sums of both 128-bit halves added together.
    // High 128 bits: the same for the odd-byte sums.
    let dis0 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum256[0], accum256[1], 0x21),
        _mm256_blend_epi32(accum256[0], accum256[1], 0xf0),
    );
    let dis1 = _mm256_add_epi16(
        _mm256_permute2x128_si256(accum256[2], accum256[3], 0x21),
        _mm256_blend_epi32(accum256[2], accum256[3], 0xf0),
    );

    let v_scale = _mm512_set1_ps(lut.scale);
    let v_bias = _mm512_set1_ps(lut.bias);

    // dis0 becomes lanes 0..16: widen u16 -> i32 -> f32, then fused
    // `scale * sum + bias`.
    let sum0 = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(dis0));
    let scores0 = _mm512_fmadd_ps(v_scale, sum0, v_bias);
    _mm512_storeu_ps(out.as_mut_ptr(), scores0);

    // dis1 becomes lanes 16..32.
    let sum1 = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(dis1));
    let scores1 = _mm512_fmadd_ps(v_scale, sum1, v_bias);
    _mm512_storeu_ps(out.as_mut_ptr().add(16), scores1);

    // Lanes without a vector report 0.0.
    for score in out.iter_mut().take(BLOCK_LANES).skip(n_vectors) {
        *score = 0.0;
    }
}
471
/// Scorer backed by the NEON kernel `score_block_neon_inner`.
#[cfg(target_arch = "aarch64")]
pub struct NeonScorer;

#[cfg(target_arch = "aarch64")]
impl PerBlockScorer for NeonScorer {
    fn kernel(&self) -> ScoreKernel {
        ScoreKernel::Neon
    }

    /// Scores the block with the NEON kernel.
    ///
    /// # Panics
    ///
    /// Panics if `block_codes` is shorter than `n_byte_groups * BLOCK_LANES`
    /// or if `lut.bytes` is shorter than `n_byte_groups * 32`. Both are
    /// checked in release builds as well, because the kernel reads through raw
    /// pointers.
    fn score_block(
        &self,
        lut: &QueryLut,
        block_codes: &[u8],
        n_byte_groups: usize,
        n_vectors: usize,
        out: &mut [f32; BLOCK_LANES],
    ) {
        debug_assert_eq!(n_byte_groups, lut.n_byte_groups);
        debug_assert!(n_vectors <= BLOCK_LANES);
        // A safe method must not be able to trigger undefined behaviour, so
        // the slice lengths the unsafe kernel relies on are verified
        // unconditionally.
        assert!(
            n_byte_groups
                .checked_mul(BLOCK_LANES)
                .is_some_and(|needed| block_codes.len() >= needed),
            "block_codes is shorter than n_byte_groups * BLOCK_LANES",
        );
        assert!(
            n_byte_groups
                .checked_mul(32)
                .is_some_and(|needed| lut.bytes.len() >= needed),
            "lut.bytes is shorter than n_byte_groups * 32",
        );
        // SAFETY: both byte slices were checked to cover the two 16-byte code
        // loads and two 16-byte table loads the kernel performs per group.
        // NOTE(review): relies on `neon` being available on every AArch64
        // target this crate is built for — confirm.
        unsafe {
            score_block_neon_inner(lut, block_codes, n_byte_groups, n_vectors, out);
        }
    }
}
509
/// NEON kernel behind `NeonScorer::score_block`.
///
/// Per byte group, 16 bytes at the group's start are looked up in table
/// entries `0..16` and the following 16 bytes in table entries `16..32`, once
/// with their low nibbles and once with their high nibbles. The looked-up
/// bytes are summed in 16-bit lanes, converted to `f32` as
/// `scale * sum + bias`, and scattered into `out` through `PERM0`: low-nibble
/// results go to `out[PERM0[i]]`, high-nibble results to `out[PERM0[i] + 16]`.
/// Lanes at index `n_vectors` and above are then overwritten with `0.0`.
///
/// All sums are 16-bit, so they wrap modulo 2^16 if a lane's total exceeds
/// `u16::MAX`.
///
/// # Safety
///
/// * The CPU must support `neon`.
/// * `block_codes` must have 32 readable bytes starting at `g * BLOCK_LANES`
///   for every `g < n_byte_groups`.
/// * `lut.bytes` must hold at least `n_byte_groups * 32` bytes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn score_block_neon_inner(
    lut: &QueryLut,
    block_codes: &[u8],
    n_byte_groups: usize,
    n_vectors: usize,
    out: &mut [f32; BLOCK_LANES],
) {
    use std::arch::aarch64::*;

    // `h0` accumulators take the low-nibble lookups, `h1` the high-nibble
    // lookups; `_lo`/`_hi` are byte positions 0..8 and 8..16 of the
    // 16-byte vectors.
    let mut acc_h0_lo = vdupq_n_u16(0);
    let mut acc_h0_hi = vdupq_n_u16(0);
    let mut acc_h1_lo = vdupq_n_u16(0);
    let mut acc_h1_hi = vdupq_n_u16(0);
    let nibble_mask = vdupq_n_u8(0x0f);

    for g in 0..n_byte_groups {
        let base = g * BLOCK_LANES;
        // First 16 bytes pair with table entries 0..16, next 16 bytes with
        // table entries 16..32.
        let hi_pair = vld1q_u8(block_codes.as_ptr().add(base));
        let lo_pair = vld1q_u8(block_codes.as_ptr().add(base + 16));
        let hi_lut = vld1q_u8(lut.bytes.as_ptr().add(g * 32));
        let lo_lut = vld1q_u8(lut.bytes.as_ptr().add(g * 32 + 16));

        // Low nibbles (h0) and high nibbles (h1) of every code byte.
        let idx_lo_h0 = vandq_u8(lo_pair, nibble_mask);
        let idx_hi_h0 = vandq_u8(hi_pair, nibble_mask);
        let idx_lo_h1 = vshrq_n_u8(lo_pair, 4);
        let idx_hi_h1 = vshrq_n_u8(hi_pair, 4);

        // 16-way table lookups.
        let s_lo_h0 = vqtbl1q_u8(lo_lut, idx_lo_h0);
        let s_hi_h0 = vqtbl1q_u8(hi_lut, idx_hi_h0);
        let s_lo_h1 = vqtbl1q_u8(lo_lut, idx_lo_h1);
        let s_hi_h1 = vqtbl1q_u8(hi_lut, idx_hi_h1);

        // Widening add of the two table results for the same position, then
        // accumulate as u16.
        acc_h0_lo = vaddq_u16(
            acc_h0_lo,
            vaddl_u8(vget_low_u8(s_lo_h0), vget_low_u8(s_hi_h0)),
        );
        acc_h0_hi = vaddq_u16(
            acc_h0_hi,
            vaddl_u8(vget_high_u8(s_lo_h0), vget_high_u8(s_hi_h0)),
        );
        acc_h1_lo = vaddq_u16(
            acc_h1_lo,
            vaddl_u8(vget_low_u8(s_lo_h1), vget_low_u8(s_hi_h1)),
        );
        acc_h1_hi = vaddq_u16(
            acc_h1_hi,
            vaddl_u8(vget_high_u8(s_lo_h1), vget_high_u8(s_hi_h1)),
        );
    }

    let scale = vdupq_n_f32(lut.scale);
    let bias = vdupq_n_f32(lut.bias);

    // Converts eight u16 totals into eight scores `bias + scale * total`.
    let conv = |acc_u16: uint16x8_t| -> [f32; 8] {
        let lo_u32 = vmovl_u16(vget_low_u16(acc_u16));
        let hi_u32 = vmovl_u16(vget_high_u16(acc_u16));
        let lo_f = vcvtq_f32_u32(lo_u32);
        let hi_f = vcvtq_f32_u32(hi_u32);
        let lo_score = vfmaq_f32(bias, scale, lo_f);
        let hi_score = vfmaq_f32(bias, scale, hi_f);
        let mut tmp = [0.0f32; 8];
        vst1q_f32(tmp.as_mut_ptr(), lo_score);
        vst1q_f32(tmp.as_mut_ptr().add(4), hi_score);
        tmp
    };

    let h0_lo = conv(acc_h0_lo);
    let h0_hi = conv(acc_h0_hi);
    let h1_lo = conv(acc_h1_lo);
    let h1_hi = conv(acc_h1_hi);

    // Scores indexed by byte position within the 16-byte vectors.
    let mut scores_h0 = [0.0f32; 16];
    let mut scores_h1 = [0.0f32; 16];
    scores_h0[..8].copy_from_slice(&h0_lo);
    scores_h0[8..].copy_from_slice(&h0_hi);
    scores_h1[..8].copy_from_slice(&h1_lo);
    scores_h1[8..].copy_from_slice(&h1_hi);

    // Map byte positions back to lane indices through PERM0.
    for perm_pos in 0..16 {
        let lane = PERM0[perm_pos];
        out[lane] = scores_h0[perm_pos];
        out[lane + 16] = scores_h1[perm_pos];
    }

    // Lanes without a vector report 0.0.
    for score in out.iter_mut().take(BLOCK_LANES).skip(n_vectors) {
        *score = 0.0;
    }
}
626
627fn decode_perm0_byte(block_codes: &[u8], group: usize, lane: usize) -> (u8, u8) {
628 debug_assert!(lane < BLOCK_LANES);
629 let half = lane / 16;
630 let within_half = lane % 16;
631 let perm_pos = PERM0
632 .iter()
633 .position(|&v| v == within_half)
634 .expect("lane in perm0");
635 let group_base = group * BLOCK_LANES;
636 let hi_pair = block_codes[group_base + perm_pos];
637 let lo_pair = block_codes[group_base + 16 + perm_pos];
638 if half == 0 {
639 (hi_pair & 0x0f, lo_pair & 0x0f)
640 } else {
641 (hi_pair >> 4, lo_pair >> 4)
642 }
643}
644
645#[cfg(test)]
646mod tests {
647 use super::*;
648 use crate::storage::engine::turboquant::storage::BlockedCodeStorage;
649
650 fn centroids_for(bits: u8) -> Vec<f64> {
651 let levels = 1usize << bits;
652 let step = 2.0 / levels as f64;
653 (0..levels)
654 .map(|i| -1.0 + (i as f64 + 0.5) * step)
655 .collect()
656 }
657
658 #[test]
659 fn scalar_block_score_is_zero_when_query_is_zero() {
660 let lut = QueryLut::build(&[0.0f32; 4], ¢roids_for(4));
661 let mut storage = BlockedCodeStorage::new(2);
662 storage.append(&[0x12, 0x34], 1.0);
663 let mut out = [0.0f32; BLOCK_LANES];
664 ScalarScorer.score_block(&lut, storage.block_codes(0), 2, 1, &mut out);
665 assert_eq!(out[0], 0.0);
668 }
669
670 #[test]
671 fn scalar_block_score_matches_per_vector_scalar_lut() {
672 let centroids = centroids_for(4);
677 let query = vec![0.2f32, -0.3, 0.4, -0.5];
678 let lut = QueryLut::build(&query, ¢roids);
679
680 let n_byte_groups = 2;
681 let mut storage = BlockedCodeStorage::new(n_byte_groups);
682 let packed_a = vec![0xa3u8, 0x5c];
683 let packed_b = vec![0x71u8, 0xfe];
684 storage.append(&packed_a, 1.0);
685 storage.append(&packed_b, 1.0);
686
687 let mut out = [0.0f32; BLOCK_LANES];
688 ScalarScorer.score_block(&lut, storage.block_codes(0), n_byte_groups, 2, &mut out);
689
690 for (lane, packed) in [&packed_a, &packed_b].iter().enumerate() {
691 let mut expected = 0u32;
692 for (g, byte) in packed.iter().enumerate() {
693 let lo = (byte & 0x0f) as usize;
694 let hi = (byte >> 4) as usize;
695 expected += lut.bytes[g * 32 + hi] as u32;
696 expected += lut.bytes[g * 32 + 16 + lo] as u32;
697 }
698 let expected_f = lut.scale.mul_add(expected as f32, lut.bias);
699 assert_eq!(
700 out[lane], expected_f,
701 "lane {lane} matches per-vector LUT scoring",
702 );
703 }
704
705 for lane in 2..BLOCK_LANES {
706 assert_eq!(out[lane], 0.0, "unused lane {lane} stays 0");
707 }
708 }
709
710 #[test]
711 fn select_scorer_matches_host_capability() {
712 let kernel = select_scorer().kernel();
713 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
714 {
715 if std::is_x86_feature_detected!("avx512bw")
716 && std::is_x86_feature_detected!("avx512f")
717 && std::is_x86_feature_detected!("fma")
718 {
719 assert_eq!(kernel, ScoreKernel::Avx512Bw);
720 return;
721 }
722 if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
723 assert_eq!(kernel, ScoreKernel::Avx2);
724 return;
725 }
726 }
727 #[cfg(target_arch = "aarch64")]
728 {
729 assert_eq!(kernel, ScoreKernel::Neon);
732 return;
733 }
734 #[allow(unreachable_code)]
735 {
736 assert_eq!(kernel, ScoreKernel::Scalar);
737 }
738 }
739
740 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
741 #[test]
742 fn avx2_scorer_matches_scalar_oracle_across_dataset_sizes() {
743 if !std::is_x86_feature_detected!("avx2") || !std::is_x86_feature_detected!("fma") {
744 return;
745 }
746 let centroids = centroids_for(4);
747 let queries: [Vec<f32>; 4] = [
752 vec![0.0; 8],
753 vec![0.7, -0.3, 0.4, -0.1, 0.2, -0.5, 0.6, -0.9],
754 vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
755 vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
756 ];
757
758 for query in &queries {
759 let lut = QueryLut::build(query, ¢roids);
760 let n_byte_groups = lut.n_byte_groups;
761
762 for n in [1usize, 31, 32, 33, 95, 96, 97] {
763 let mut storage = BlockedCodeStorage::new(n_byte_groups);
764 for i in 0..n {
765 let packed: Vec<u8> = (0..n_byte_groups)
766 .map(|g| {
767 let lo = ((i + g * 3) & 0x0f) as u8;
768 let hi = ((i * 5 + g * 7) & 0x0f) as u8;
769 lo | (hi << 4)
770 })
771 .collect();
772 storage.append(&packed, 1.0);
773 }
774
775 for b in 0..storage.n_blocks() {
776 let filled = storage.block_lanes_filled(b);
777 let mut scalar_out = [0.0f32; BLOCK_LANES];
778 let mut avx2_out = [f32::NAN; BLOCK_LANES];
779 ScalarScorer.score_block(
780 &lut,
781 storage.block_codes(b),
782 n_byte_groups,
783 filled,
784 &mut scalar_out,
785 );
786 AVX2_SCORER.score_block(
787 &lut,
788 storage.block_codes(b),
789 n_byte_groups,
790 filled,
791 &mut avx2_out,
792 );
793 for lane in 0..BLOCK_LANES {
794 assert_eq!(
795 avx2_out[lane].to_bits(),
796 scalar_out[lane].to_bits(),
797 "AVX2 diverges from scalar at N={n}, block {b}, lane {lane}",
798 );
799 }
800 }
801 }
802 }
803 }
804
805 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
806 #[test]
807 fn avx512bw_scorer_matches_scalar_oracle_across_dataset_sizes() {
808 if !std::is_x86_feature_detected!("avx512bw")
809 || !std::is_x86_feature_detected!("avx512f")
810 || !std::is_x86_feature_detected!("fma")
811 {
812 return;
813 }
814 let centroids = centroids_for(4);
815 let queries: [Vec<f32>; 5] = [
823 vec![0.0; 8],
824 vec![0.7, -0.3, 0.4, -0.1, 0.2, -0.5, 0.6, -0.9],
825 vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
826 vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
827 vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8, 0.9, -1.0],
829 ];
830
831 for query in &queries {
832 let lut = QueryLut::build(query, ¢roids);
833 let n_byte_groups = lut.n_byte_groups;
834
835 for n in [1usize, 31, 32, 33, 95, 96, 97] {
836 let mut storage = BlockedCodeStorage::new(n_byte_groups);
837 for i in 0..n {
838 let packed: Vec<u8> = (0..n_byte_groups)
839 .map(|g| {
840 let lo = ((i + g * 3) & 0x0f) as u8;
841 let hi = ((i * 5 + g * 7) & 0x0f) as u8;
842 lo | (hi << 4)
843 })
844 .collect();
845 storage.append(&packed, 1.0);
846 }
847
848 for b in 0..storage.n_blocks() {
849 let filled = storage.block_lanes_filled(b);
850 let mut scalar_out = [0.0f32; BLOCK_LANES];
851 let mut avx512_out = [f32::NAN; BLOCK_LANES];
852 ScalarScorer.score_block(
853 &lut,
854 storage.block_codes(b),
855 n_byte_groups,
856 filled,
857 &mut scalar_out,
858 );
859 AVX512BW_SCORER.score_block(
860 &lut,
861 storage.block_codes(b),
862 n_byte_groups,
863 filled,
864 &mut avx512_out,
865 );
866 for lane in 0..BLOCK_LANES {
867 assert_eq!(
868 avx512_out[lane].to_bits(),
869 scalar_out[lane].to_bits(),
870 "AVX-512BW diverges from scalar at N={n}, block {b}, lane {lane}",
871 );
872 }
873 }
874 }
875 }
876 }
877
878 #[cfg(target_arch = "aarch64")]
879 #[test]
880 fn neon_scorer_matches_scalar_oracle_across_dataset_sizes() {
881 let centroids = centroids_for(4);
888 let queries: [Vec<f32>; 5] = [
889 vec![0.0; 8],
890 vec![0.7, -0.3, 0.4, -0.1, 0.2, -0.5, 0.6, -0.9],
891 vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
892 vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
893 vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8, 0.9, -1.0],
898 ];
899
900 for query in &queries {
901 let lut = QueryLut::build(query, ¢roids);
902 let n_byte_groups = lut.n_byte_groups;
903
904 for n in [1usize, 31, 32, 33, 95, 96, 97] {
907 let mut storage = BlockedCodeStorage::new(n_byte_groups);
908 for i in 0..n {
909 let packed: Vec<u8> = (0..n_byte_groups)
910 .map(|g| {
911 let lo = ((i + g * 3) & 0x0f) as u8;
912 let hi = ((i * 5 + g * 7) & 0x0f) as u8;
913 lo | (hi << 4)
914 })
915 .collect();
916 storage.append(&packed, 1.0);
917 }
918
919 for b in 0..storage.n_blocks() {
920 let filled = storage.block_lanes_filled(b);
921 let mut scalar_out = [0.0f32; BLOCK_LANES];
922 let mut neon_out = [f32::NAN; BLOCK_LANES];
923 ScalarScorer.score_block(
924 &lut,
925 storage.block_codes(b),
926 n_byte_groups,
927 filled,
928 &mut scalar_out,
929 );
930 NEON_SCORER.score_block(
931 &lut,
932 storage.block_codes(b),
933 n_byte_groups,
934 filled,
935 &mut neon_out,
936 );
937 for lane in 0..BLOCK_LANES {
938 assert_eq!(
939 neon_out[lane].to_bits(),
940 scalar_out[lane].to_bits(),
941 "NEON diverges from scalar at N={n}, block {b}, lane {lane}",
942 );
943 }
944 }
945 }
946 }
947 }
948}