omendb_core/compression/fastscan.rs

//! FastScan SIMD-accelerated distance computation for quantized vectors
//!
//! FastScan uses SIMD shuffle instructions (pshufb/vqtbl1q) to perform
//! parallel LUT lookups, computing distances for 32 neighbors at once.
//!
//! # Performance
//!
//! Benchmark on M3 Max showed 5x speedup vs per-neighbor ADC:
//! - Per-neighbor ADC: 1.93 µs for 32 neighbors
//! - FastScan NEON: 390 ns for 32 neighbors
//!
//! # Memory Layout
//!
//! FastScan requires codes to be interleaved by sub-quantizer position:
//! ```text
//! [n0_sq0, n1_sq0, ..., n31_sq0]  // 32 bytes - sub-quantizer 0 for all neighbors
//! [n0_sq1, n1_sq1, ..., n31_sq1]  // 32 bytes - sub-quantizer 1 for all neighbors
//! ```
//!
//! For 4-bit RaBitQ with 768 dimensions:
//! - code_size = 768 / 2 = 384 bytes per vector
//! - 384 sub-quantizers, each holding 2 dimension codes (lo/hi nibbles)
//!
//! # LUT Format
//!
//! For 4-bit quantization, each sub-quantizer has TWO 16-entry u8 LUTs:
//! - `luts_lo[sq][code]` for the lo nibble (even dimension)
//! - `luts_hi[sq][code]` for the hi nibble (odd dimension)

use crate::compression::ADCTable;

/// Batch size for FastScan - AVX2/NEON process 32 bytes at a time
pub const BATCH_SIZE: usize = 32;
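
// Illustrative sketch (not part of the original API): one way to produce the
// sub-quantizer-interleaved layout described in the module docs. It assumes each
// `codes[n]` slice holds the packed 4-bit codes of neighbor `n` (one byte per
// sub-quantizer) and pads missing neighbors in a partial batch with code 0.
#[allow(dead_code)]
fn interleave_for_fastscan(codes: &[&[u8]], num_sq: usize) -> Vec<u8> {
    debug_assert!(codes.len() <= BATCH_SIZE);
    let mut out = vec![0u8; num_sq * BATCH_SIZE];
    for (n, neighbor_codes) in codes.iter().enumerate() {
        debug_assert_eq!(neighbor_codes.len(), num_sq);
        for (sq, &byte) in neighbor_codes.iter().enumerate() {
            // Sub-quantizer-major: the 32 bytes for sub-quantizer `sq` are contiguous.
            out[sq * BATCH_SIZE + n] = byte;
        }
    }
    out
}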

/// Quantized LUT for FastScan (u8 distances for SIMD efficiency)
///
/// Contains pre-computed distance contributions for each possible code value.
/// For 4-bit quantization: 16 entries per sub-quantizer, separate LUTs for lo/hi nibbles.
#[derive(Debug, Clone)]
pub struct FastScanLUT {
    /// Lo nibble LUTs: luts_lo[sq][code] = quantized distance for even dimension
    luts_lo: Vec<[u8; 16]>,

    /// Hi nibble LUTs: luts_hi[sq][code] = quantized distance for odd dimension
    luts_hi: Vec<[u8; 16]>,

    /// Scale factor to convert accumulated u16 back to approximate f32 distance
    scale: f32,

    /// Offset to add after scaling (for accurate reconstruction)
    offset: f32,
}

impl FastScanLUT {
    /// Build FastScan LUT from RaBitQ ADC table
    ///
    /// Returns `None` if the table is not 4-bit or if its dimension count is zero or odd.
    ///
    /// ADC table format: table[dim][code] = partial squared distance
    /// For 4-bit quantization with D dimensions:
    /// - D/2 sub-quantizers (each byte packs 2 dimensions)
    /// - table[sq*2][code] = distance contribution for lo nibble (even dim)
    /// - table[sq*2+1][code] = distance contribution for hi nibble (odd dim)
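    ///
    /// # Example
    ///
    /// A sketch of the intended flow (marked `ignore` because building an
    /// `ADCTable` happens outside this module; `adc` and `interleaved` are
    /// assumed inputs):
    ///
    /// ```ignore
    /// let lut = FastScanLUT::from_adc_table(&adc).expect("4-bit ADC table");
    /// // `interleaved` must hold num_sq * 32 bytes in the FastScan layout.
    /// let distances = fastscan_batch_with_lut(&lut, &interleaved);
    /// let approx = lut.to_f32(distances[0]);
    /// ```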
    #[must_use]
    pub fn from_adc_table(adc: &ADCTable) -> Option<Self> {
        // Only support 4-bit for now
        if adc.bits() != 4 {
            return None;
        }

        let dimensions = adc.dimensions();
        if dimensions == 0 || !dimensions.is_multiple_of(2) {
            return None;
        }

        let num_sq = dimensions / 2;

        // Find global min/max across all LUT entries for uniform scaling
        // We sum lo + hi contributions, so find min/max of sums
        let mut global_min = f32::MAX;
        let mut global_max = f32::MIN;

        for sq in 0..num_sq {
            for lo_code in 0..16 {
                for hi_code in 0..16 {
                    let dist_lo = adc.get(sq * 2, lo_code);
                    let dist_hi = adc.get(sq * 2 + 1, hi_code);
                    let sum = dist_lo + dist_hi;
                    global_min = global_min.min(sum);
                    global_max = global_max.max(sum);
                }
            }
        }

        // Calculate safe max per nibble to prevent u16 overflow
        // Max accumulation = num_sq * 2 * max_per_nibble <= 65535
        // Formula: max_per_nibble = floor(65535 / (num_sq * 2))
        // Examples: 128D->127, 512D->127, 768D->85, 1536D->42
        let safe_max_per_nibble = (65535.0 / (num_sq * 2) as f32).floor().min(127.0);

        // Each sub-quantizer contributes to the sum, so scale per-sq contributions
        // to fit in u8 such that the total sum fits in u16
        let range = global_max - global_min;
        let scale_factor = if range > 1e-7 {
            safe_max_per_nibble / (range / 2.0) // Divide range by 2 since lo+hi both contribute
        } else {
            1.0
        };

        let offset = global_min;

        // Build separate LUTs for lo and hi nibbles
        let mut luts_lo = Vec::with_capacity(num_sq);
        let mut luts_hi = Vec::with_capacity(num_sq);

        for sq in 0..num_sq {
            let dim_lo = sq * 2;
            let dim_hi = sq * 2 + 1;

            // Lo nibble LUT (even dimension)
            let mut lut_lo = [0u8; 16];
            for (code, entry) in lut_lo.iter_mut().enumerate() {
                let dist = adc.get(dim_lo, code);
                // Subtract per-dimension share of offset, then scale
                *entry = ((dist - offset / 2.0) * scale_factor)
                    .round()
                    .clamp(0.0, safe_max_per_nibble) as u8;
            }

            // Hi nibble LUT (odd dimension)
            let mut lut_hi = [0u8; 16];
            for (code, entry) in lut_hi.iter_mut().enumerate() {
                let dist = adc.get(dim_hi, code);
                *entry = ((dist - offset / 2.0) * scale_factor)
                    .round()
                    .clamp(0.0, safe_max_per_nibble) as u8;
            }

            luts_lo.push(lut_lo);
            luts_hi.push(lut_hi);
        }

        Some(Self {
            luts_lo,
            luts_hi,
            scale: 1.0 / scale_factor,
            offset,
        })
    }

    /// Get number of sub-quantizers
    #[must_use]
    pub fn num_sq(&self) -> usize {
        self.luts_lo.len()
    }

    /// Get lo nibble LUTs
    #[must_use]
    pub fn luts_lo(&self) -> &[[u8; 16]] {
        &self.luts_lo
    }

    /// Get hi nibble LUTs
    #[must_use]
    pub fn luts_hi(&self) -> &[[u8; 16]] {
        &self.luts_hi
    }

    /// Convert accumulated u16 distance back to approximate f32
    #[must_use]
    pub fn to_f32(&self, accumulated: u16) -> f32 {
        accumulated as f32 * self.scale + self.offset
    }
}

/// Compute batched distances using FastScan NEON (ARM)
///
/// # Arguments
/// * `luts_lo` - Lo nibble LUTs (one 16-byte LUT per sub-quantizer)
/// * `luts_hi` - Hi nibble LUTs (one 16-byte LUT per sub-quantizer)
/// * `interleaved_codes` - Interleaved neighbor codes (num_sq * 32 bytes)
///
/// # Returns
/// Array of 32 accumulated u16 distances
///
/// # Panics
/// Panics if `luts_lo` and `luts_hi` differ in length or if `interleaved_codes`
/// holds fewer than `luts_lo.len() * BATCH_SIZE` bytes.
#[cfg(target_arch = "aarch64")]
#[must_use]
pub fn fastscan_batch_neon(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    use std::arch::aarch64::{
        uint16x8_t, vaddl_u8, vaddq_u16, vandq_u8, vdupq_n_u16, vdupq_n_u8, vget_high_u8,
        vget_low_u8, vld1q_u8, vqtbl1q_u8, vshrq_n_u8, vst1q_u16,
    };

    // The loop below reads through raw pointers; validate slice lengths up front
    // so the unchecked loads cannot run past the end of the inputs.
    assert_eq!(luts_lo.len(), luts_hi.len());
    assert!(interleaved_codes.len() >= luts_lo.len() * BATCH_SIZE);

    unsafe {
        let low_mask = vdupq_n_u8(0x0F);

        // Four accumulators for 32 results (NEON processes 8 u16 at a time)
        let mut accum0: uint16x8_t = vdupq_n_u16(0);
        let mut accum1: uint16x8_t = vdupq_n_u16(0);
        let mut accum2: uint16x8_t = vdupq_n_u16(0);
        let mut accum3: uint16x8_t = vdupq_n_u16(0);

        // Process each sub-quantizer
        for sq in 0..luts_lo.len() {
            let base = sq * BATCH_SIZE;

            // Load separate LUTs for lo and hi nibbles
            let lut_lo_vec = vld1q_u8(luts_lo[sq].as_ptr());
            let lut_hi_vec = vld1q_u8(luts_hi[sq].as_ptr());

            // Load 32 bytes of codes (32 neighbors' codes for this sub-quantizer)
            let codes_0_15 = vld1q_u8(interleaved_codes.as_ptr().add(base));
            let codes_16_31 = vld1q_u8(interleaved_codes.as_ptr().add(base + 16));

            // Extract lo nibbles and lookup in lut_lo
            let idx_lo_0 = vandq_u8(codes_0_15, low_mask);
            let idx_lo_1 = vandq_u8(codes_16_31, low_mask);
            let vals_lo_0 = vqtbl1q_u8(lut_lo_vec, idx_lo_0);
            let vals_lo_1 = vqtbl1q_u8(lut_lo_vec, idx_lo_1);

            // Extract hi nibbles and lookup in lut_hi
            let idx_hi_0 = vshrq_n_u8(codes_0_15, 4);
            let idx_hi_1 = vshrq_n_u8(codes_16_31, 4);
            let vals_hi_0 = vqtbl1q_u8(lut_hi_vec, idx_hi_0);
            let vals_hi_1 = vqtbl1q_u8(lut_hi_vec, idx_hi_1);

            // Accumulate as u16 to avoid overflow
            // Neighbors 0-7
            accum0 = vaddq_u16(
                accum0,
                vaddl_u8(vget_low_u8(vals_lo_0), vget_low_u8(vals_hi_0)),
            );
            // Neighbors 8-15
            accum1 = vaddq_u16(
                accum1,
                vaddl_u8(vget_high_u8(vals_lo_0), vget_high_u8(vals_hi_0)),
            );
            // Neighbors 16-23
            accum2 = vaddq_u16(
                accum2,
                vaddl_u8(vget_low_u8(vals_lo_1), vget_low_u8(vals_hi_1)),
            );
            // Neighbors 24-31
            accum3 = vaddq_u16(
                accum3,
                vaddl_u8(vget_high_u8(vals_lo_1), vget_high_u8(vals_hi_1)),
            );
        }

        // Extract results
        let mut results = [0u16; BATCH_SIZE];
        vst1q_u16(results.as_mut_ptr(), accum0);
        vst1q_u16(results.as_mut_ptr().add(8), accum1);
        vst1q_u16(results.as_mut_ptr().add(16), accum2);
        vst1q_u16(results.as_mut_ptr().add(24), accum3);

        results
    }
}

/// Compute batched distances using FastScan AVX2 (x86_64)
///
/// Falls back to [`fastscan_batch_scalar`] when AVX2 is not detected at runtime.
///
/// # Panics
/// Panics if `luts_lo` and `luts_hi` differ in length or if `interleaved_codes`
/// holds fewer than `luts_lo.len() * BATCH_SIZE` bytes.
#[cfg(target_arch = "x86_64")]
#[allow(clippy::cast_ptr_alignment)] // loadu/storeu intrinsics handle unaligned access
#[must_use]
pub fn fastscan_batch_avx2(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    use std::arch::x86_64::{
        __m128i, __m256i, _mm256_add_epi16, _mm256_and_si256, _mm256_broadcastsi128_si256,
        _mm256_castsi256_si128, _mm256_cvtepu8_epi16, _mm256_extracti128_si256,
        _mm256_loadu_si256, _mm256_set1_epi8, _mm256_setzero_si256, _mm256_shuffle_epi8,
        _mm256_srli_epi16, _mm256_storeu_si256, _mm_loadu_si128,
    };

    if !std::is_x86_feature_detected!("avx2") {
        return fastscan_batch_scalar(luts_lo, luts_hi, interleaved_codes);
    }

    // The loop below reads through raw pointers; validate slice lengths up front
    // so the unchecked loads cannot run past the end of the inputs.
    assert_eq!(luts_lo.len(), luts_hi.len());
    assert!(interleaved_codes.len() >= luts_lo.len() * BATCH_SIZE);

    unsafe {
        let low_mask = _mm256_set1_epi8(0x0F);

        // Two accumulators for 32 u16 results
        let mut accum_lo = _mm256_setzero_si256(); // neighbors 0-15
        let mut accum_hi = _mm256_setzero_si256(); // neighbors 16-31

        for sq in 0..luts_lo.len() {
            let base = sq * BATCH_SIZE;

            // Broadcast 16-byte LUTs to 256-bit registers
            let lut_lo_128 = _mm_loadu_si128(luts_lo[sq].as_ptr() as *const __m128i);
            let lut_hi_128 = _mm_loadu_si128(luts_hi[sq].as_ptr() as *const __m128i);
            let lut_lo_vec = _mm256_broadcastsi128_si256(lut_lo_128);
            let lut_hi_vec = _mm256_broadcastsi128_si256(lut_hi_128);

            // Load 32 codes
            let codes = _mm256_loadu_si256(interleaved_codes.as_ptr().add(base) as *const __m256i);

            // Lo nibble lookups using lut_lo
            let idx_lo = _mm256_and_si256(codes, low_mask);
            let vals_lo = _mm256_shuffle_epi8(lut_lo_vec, idx_lo);

            // Hi nibble lookups using lut_hi
            let idx_hi = _mm256_and_si256(_mm256_srli_epi16(codes, 4), low_mask);
            let vals_hi = _mm256_shuffle_epi8(lut_hi_vec, idx_hi);

            // Widen the lo and hi lookup results from u8 to u16 before accumulating.
            // AVX2 widens only 16 bytes at a time, so handle the low and high
            // 128-bit halves (neighbors 0-15 and 16-31) separately.

            // Low 16 bytes: neighbors 0-15
            let vals_lo_128 = _mm256_castsi256_si128(vals_lo);
            let vals_hi_128 = _mm256_castsi256_si128(vals_hi);
            let sum_lo_16 = _mm256_cvtepu8_epi16(vals_lo_128);
            let sum_hi_16 = _mm256_cvtepu8_epi16(vals_hi_128);
            accum_lo = _mm256_add_epi16(accum_lo, sum_lo_16);
            accum_lo = _mm256_add_epi16(accum_lo, sum_hi_16);

            // High 16 bytes: neighbors 16-31
            let vals_lo_high = _mm256_extracti128_si256(vals_lo, 1);
            let vals_hi_high = _mm256_extracti128_si256(vals_hi, 1);
            let sum_lo_high_16 = _mm256_cvtepu8_epi16(vals_lo_high);
            let sum_hi_high_16 = _mm256_cvtepu8_epi16(vals_hi_high);
            accum_hi = _mm256_add_epi16(accum_hi, sum_lo_high_16);
            accum_hi = _mm256_add_epi16(accum_hi, sum_hi_high_16);
        }

        let mut results = [0u16; BATCH_SIZE];
        _mm256_storeu_si256(results.as_mut_ptr() as *mut __m256i, accum_lo);
        _mm256_storeu_si256(results.as_mut_ptr().add(16) as *mut __m256i, accum_hi);
        results
    }
}

/// Scalar fallback for platforms without SIMD
#[must_use]
pub fn fastscan_batch_scalar(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    let mut results = [0u16; BATCH_SIZE];

    for (sq, (lut_lo, lut_hi)) in luts_lo.iter().zip(luts_hi.iter()).enumerate() {
        let base = sq * BATCH_SIZE;
        for n in 0..BATCH_SIZE {
            let code = interleaved_codes[base + n];
            let lo_idx = (code & 0x0F) as usize;
            let hi_idx = ((code >> 4) & 0x0F) as usize;
            results[n] += lut_lo[lo_idx] as u16 + lut_hi[hi_idx] as u16;
        }
    }

    results
}

/// Choose the best FastScan implementation for the current platform
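///
/// All three backends compute the same integer accumulation, so for LUTs built
/// by [`FastScanLUT::from_adc_table`] (which caps per-nibble values to avoid u16
/// overflow) the results are identical; `test_fastscan_matches_scalar` below
/// checks the compiled SIMD path against the scalar reference.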
#[inline]
#[must_use]
pub fn fastscan_batch(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    #[cfg(target_arch = "aarch64")]
    {
        fastscan_batch_neon(luts_lo, luts_hi, interleaved_codes)
    }
    #[cfg(target_arch = "x86_64")]
    {
        fastscan_batch_avx2(luts_lo, luts_hi, interleaved_codes)
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        fastscan_batch_scalar(luts_lo, luts_hi, interleaved_codes)
    }
}

/// Convenience wrapper using FastScanLUT struct
#[inline]
#[must_use]
pub fn fastscan_batch_with_lut(lut: &FastScanLUT, interleaved_codes: &[u8]) -> [u16; BATCH_SIZE] {
    fastscan_batch(lut.luts_lo(), lut.luts_hi(), interleaved_codes)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fastscan_scalar() {
        // Create simple LUTs (distance = code value)
        let luts_lo: Vec<[u8; 16]> = (0..4).map(|_| core::array::from_fn(|i| i as u8)).collect();
        let luts_hi: Vec<[u8; 16]> = (0..4).map(|_| core::array::from_fn(|i| i as u8)).collect();

        // Create interleaved codes: all zeros
        let codes = vec![0u8; 4 * BATCH_SIZE];

        let results = fastscan_batch_scalar(&luts_lo, &luts_hi, &codes);

        // All distances should be 0 (code 0 maps to distance 0)
        for &r in &results {
            assert_eq!(r, 0);
        }
    }

    #[test]
    fn test_fastscan_scalar_nonzero() {
        // LUT where each code maps to its value
        let luts_lo: Vec<[u8; 16]> = (0..2).map(|_| core::array::from_fn(|i| i as u8)).collect();
        let luts_hi: Vec<[u8; 16]> = (0..2).map(|_| core::array::from_fn(|i| i as u8)).collect();

        // Create codes: neighbor 0 gets 0x11 in sub-quantizer 0 and 0x22 in sub-quantizer 1
        let mut codes = vec![0u8; 2 * BATCH_SIZE];
        codes[0] = 0x11; // sq0, neighbor 0: lo=1, hi=1
        codes[BATCH_SIZE] = 0x22; // sq1, neighbor 0: lo=2, hi=2

        let results = fastscan_batch_scalar(&luts_lo, &luts_hi, &codes);

        // Neighbor 0: (1+1) + (2+2) = 6
        assert_eq!(results[0], 6);
        // Other neighbors: all 0
        assert_eq!(results[1], 0);
    }

    #[test]
    fn test_fastscan_matches_scalar() {
        // Create random-ish LUTs
        let luts_lo: Vec<[u8; 16]> = (0..8)
            .map(|sq| core::array::from_fn(|i| ((sq * 17 + i * 7) % 100) as u8))
            .collect();
        let luts_hi: Vec<[u8; 16]> = (0..8)
            .map(|sq| core::array::from_fn(|i| ((sq * 13 + i * 11) % 100) as u8))
            .collect();

        // Create random-ish codes
        let mut codes = vec![0u8; 8 * BATCH_SIZE];
        for (i, code) in codes.iter_mut().enumerate() {
            *code = ((i * 31 + 17) % 256) as u8;
        }

        let scalar_results = fastscan_batch_scalar(&luts_lo, &luts_hi, &codes);
        let simd_results = fastscan_batch(&luts_lo, &luts_hi, &codes);

        // SIMD should match scalar
        for (i, (&scalar, &simd)) in scalar_results.iter().zip(simd_results.iter()).enumerate() {
            assert_eq!(
                scalar, simd,
                "Mismatch at neighbor {i}: scalar={scalar}, simd={simd}"
            );
        }
    }
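
    // Additional check (not in the original suite): with per-nibble LUT values at
    // the 127 cap used by `FastScanLUT::from_adc_table`, eight sub-quantizers
    // accumulate to 8 * (127 + 127) = 2032 per neighbor, comfortably inside u16,
    // and the SIMD and scalar paths agree on this saturated input.
    #[test]
    fn test_fastscan_capped_values_no_overflow() {
        let luts_lo: Vec<[u8; 16]> = (0..8).map(|_| [127u8; 16]).collect();
        let luts_hi: Vec<[u8; 16]> = (0..8).map(|_| [127u8; 16]).collect();
        let codes = vec![0xFFu8; 8 * BATCH_SIZE];

        let scalar_results = fastscan_batch_scalar(&luts_lo, &luts_hi, &codes);
        let simd_results = fastscan_batch(&luts_lo, &luts_hi, &codes);

        for (&s, &v) in scalar_results.iter().zip(simd_results.iter()) {
            assert_eq!(s, 2032);
            assert_eq!(v, 2032);
        }
    }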
}