hermes_core/structures/
rabitq.rs

//! RaBitQ: Randomized Binary Quantization for Dense Vector Search
//!
//! Implementation of the RaBitQ algorithm from SIGMOD 2024:
//! "RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound
//! for Approximate Nearest Neighbor Search"
//!
//! Key features:
//! - ~32x compression (D-dimensional float32 → D-bit binary code plus a few
//!   scalar correction factors)
//! - Theoretical error bound for distance estimation
//! - SIMD-accelerated distance computation via look-up tables (LUTs)
//! - Asymmetric quantization (binary data, 4-bit query)
//!
//! # Algorithm Overview
//!
//! ## Index Phase
//! 1. Compute the centroid of all vectors
//! 2. Normalize each vector: `o = (o_raw - c) / ||o_raw - c||`
//! 3. Apply a random orthogonal transform: `o' = P * o`
//! 4. Binary quantize: `b[i] = 1 if o'[i] >= 0 else 0`
//! 5. Store: bit vector, distance to centroid, dot product with the quantized form
//!
//! ## Query Phase
//! 1. Normalize the query the same way
//! 2. Apply the same transform P
//! 3. Scalar-quantize the query to 4 bits per dimension (asymmetric)
//! 4. Estimate distances using a LUT-based dot product plus corrective factors
//! 5. Re-rank the top candidates with exact distances (see the example below)
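//!
//! # Example
//!
//! A minimal usage sketch (illustrative; `vectors` and `query` are assumed to
//! be a `Vec<Vec<f32>>` and a `Vec<f32>` of matching dimensionality):
//!
//! ```ignore
//! let config = RaBitQConfig::new(128).with_seed(7);
//! let index = RaBitQIndex::build(config, &vectors, /* store_raw */ true);
//! let top10 = index.search(&query, 10, /* rerank_factor */ 10);
//! for (idx, dist_sq) in top10 {
//!     println!("id={idx} dist^2={dist_sq:.4}");
//! }
//! ```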

use rand::{Rng, SeedableRng};
use serde::{Deserialize, Serialize};

// SIMD imports for accelerated LUT dot product
#[cfg(target_arch = "aarch64")]
#[allow(unused_imports)]
use std::arch::aarch64::*;

#[cfg(target_arch = "x86_64")]
#[allow(unused_imports)]
use std::arch::x86_64::*;

/// SIMD-accelerated LUT dot product for RaBitQ
///
/// Computes the sum of LUT values indexed by 4-bit patterns from the binary vector.
/// Dispatches to the NEON path on ARM64 and the SSSE3 path on x86_64 when
/// available, falling back to the scalar implementation otherwise.
#[inline]
fn lut_dot_product_simd(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    // Try SIMD path first
    #[cfg(target_arch = "aarch64")]
    {
        if let Some(result) = lut_dot_product_neon(bits, luts) {
            return result;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("ssse3") {
            // Safety: we check for SSSE3 support
            unsafe {
                if let Some(result) = lut_dot_product_ssse3(bits, luts) {
                    return result;
                }
            }
        }
    }

    // Scalar fallback
    lut_dot_product_scalar(bits, luts)
}

/// Scalar implementation of LUT dot product
#[inline]
fn lut_dot_product_scalar(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    let mut dot_sum = 0u32;

    for (lut_idx, lut) in luts.iter().enumerate() {
        // Extract 4 bits from the binary code
        let base_bit = lut_idx * 4;
        let byte_idx = base_bit / 8;
        let bit_offset = base_bit % 8;

        // Get the 4-bit pattern from the binary code
        let byte = bits.get(byte_idx).copied().unwrap_or(0);
        let next_byte = bits.get(byte_idx + 1).copied().unwrap_or(0);

        // Extract the nibble. Since base_bit is a multiple of 4, bit_offset
        // is always 0 or 4 here; the cross-byte branch is kept defensively.
        let pattern = if bit_offset <= 4 {
            (byte >> bit_offset) & 0x0F
        } else {
            ((byte >> bit_offset) | (next_byte << (8 - bit_offset))) & 0x0F
        };
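        // Illustrative: for lut_idx = 1, base_bit = 4 and the pattern is the
        // high nibble of bits[0]; for lut_idx = 2, base_bit = 8 and the
        // pattern is the low nibble of bits[1].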

        dot_sum += lut[pattern as usize] as u32;
    }

    dot_sum
}

/// NEON-accelerated LUT dot product (ARM64)
///
/// Currently an unrolled scalar loop processing two LUTs per iteration; a
/// vtbl-based version doing parallel 16-entry lookups is a possible future
/// optimization.
#[cfg(target_arch = "aarch64")]
#[inline]
fn lut_dot_product_neon(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    if luts.len() < 8 {
        return None; // Not worth SIMD for small dimensions
    }

    let mut total = 0u32;
    let num_luts = luts.len();
    let mut lut_idx = 0;

    // Process two LUTs per iteration (each LUT is 16 u16 values = 32 bytes)
    while lut_idx + 2 <= num_luts {
        // Extract two 4-bit patterns
        let base_bit0 = lut_idx * 4;
        let base_bit1 = (lut_idx + 1) * 4;

        let byte_idx0 = base_bit0 / 8;
        let bit_offset0 = base_bit0 % 8;
        let byte_idx1 = base_bit1 / 8;
        let bit_offset1 = base_bit1 % 8;

        let byte0 = bits.get(byte_idx0).copied().unwrap_or(0);
        let next0 = bits.get(byte_idx0 + 1).copied().unwrap_or(0);
        let byte1 = bits.get(byte_idx1).copied().unwrap_or(0);
        let next1 = bits.get(byte_idx1 + 1).copied().unwrap_or(0);

        let pattern0 = if bit_offset0 <= 4 {
            (byte0 >> bit_offset0) & 0x0F
        } else {
            ((byte0 >> bit_offset0) | (next0 << (8 - bit_offset0))) & 0x0F
        };

        let pattern1 = if bit_offset1 <= 4 {
            (byte1 >> bit_offset1) & 0x0F
        } else {
            ((byte1 >> bit_offset1) | (next1 << (8 - bit_offset1))) & 0x0F
        };

        total += luts[lut_idx][pattern0 as usize] as u32;
        total += luts[lut_idx + 1][pattern1 as usize] as u32;

        lut_idx += 2;
    }

    // Handle remaining LUTs
    while lut_idx < num_luts {
        let base_bit = lut_idx * 4;
        let byte_idx = base_bit / 8;
        let bit_offset = base_bit % 8;

        let byte = bits.get(byte_idx).copied().unwrap_or(0);
        let next_byte = bits.get(byte_idx + 1).copied().unwrap_or(0);

        let pattern = if bit_offset <= 4 {
            (byte >> bit_offset) & 0x0F
        } else {
            ((byte >> bit_offset) | (next_byte << (8 - bit_offset))) & 0x0F
        };

        total += luts[lut_idx][pattern as usize] as u32;
        lut_idx += 1;
    }

    Some(total)
}

/// SSSE3-accelerated LUT dot product (x86_64)
///
/// Placeholder: currently delegates to the scalar implementation. A full
/// SIMD version would pack the LUTs into `__m128i` registers and use
/// `pshufb` for parallel 16-entry lookups.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
#[inline]
unsafe fn lut_dot_product_ssse3(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    if luts.len() < 8 {
        return None; // Not worth SIMD for small dimensions
    }

    Some(lut_dot_product_scalar(bits, luts))
}

/// Configuration for RaBitQ index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQConfig {
    /// Dimensionality of vectors
    pub dim: usize,
    /// Number of bits for query quantization (typically 4)
    pub query_bits: u8,
    /// Random seed for reproducible orthogonal matrix
    pub seed: u64,
}

impl RaBitQConfig {
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            query_bits: 4,
            seed: 42,
        }
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}

/// Quantized representation of a single vector
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Binary quantization code (D bits packed into bytes)
    pub bits: Vec<u8>,
    /// Distance from original vector to centroid: ||o_raw - c||
    pub dist_to_centroid: f32,
    /// Dot product of normalized vector with its quantized form: <o, o_bar>
    pub self_dot: f32,
    /// Number of 1-bits in the binary code (for fast computation)
    pub popcount: u32,
}

impl QuantizedVector {
    /// Size in bytes of this quantized vector
    pub fn size_bytes(&self) -> usize {
        self.bits.len() + 4 + 4 + 4 // bits + dist_to_centroid + self_dot + popcount
    }
}

/// Pre-computed query representation for fast distance estimation
#[derive(Debug, Clone)]
pub struct QuantizedQuery {
    /// 4-bit scalar quantized query (packed, 2 values per byte)
    pub quantized: Vec<u8>,
    /// Distance from query to centroid: ||q_raw - c||
    pub dist_to_centroid: f32,
    /// Lower bound of quantization range
    pub lower: f32,
    /// Width of quantization range (upper - lower)
    pub width: f32,
    /// Sum of all quantized values
    pub sum: u32,
    /// Look-up tables for fast dot product (16 entries per 4-bit sub-segment)
    pub luts: Vec<[u16; 16]>,
}

/// RaBitQ index for dense vector search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQIndex {
    /// Configuration
    pub config: RaBitQConfig,
    /// Centroid of all indexed vectors
    pub centroid: Vec<f32>,
    /// Random orthogonal transform P. For efficiency, a random sign-flip plus
    /// permutation (a signed permutation, which is orthogonal) is used instead
    /// of a dense rotation matrix.
    pub random_signs: Vec<i8>,
    pub random_perm: Vec<u32>,
    /// Quantized vectors
    pub vectors: Vec<QuantizedVector>,
    /// Original vectors for re-ranking (optional, can be stored separately)
    pub raw_vectors: Option<Vec<Vec<f32>>>,
}

impl RaBitQIndex {
    /// Create a new empty RaBitQ index
    pub fn new(config: RaBitQConfig) -> Self {
        let dim = config.dim;
        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);

        // Generate random signs (+1 or -1) for each dimension
        let random_signs: Vec<i8> = (0..dim)
            .map(|_| if rng.random::<bool>() { 1 } else { -1 })
            .collect();

        // Generate a random permutation (Fisher-Yates shuffle)
        let mut random_perm: Vec<u32> = (0..dim as u32).collect();
        for i in (1..dim).rev() {
            let j = rng.random_range(0..=i);
            random_perm.swap(i, j);
        }

        Self {
            config,
            centroid: vec![0.0; dim],
            random_signs,
            random_perm,
            vectors: Vec::new(),
            raw_vectors: None,
        }
    }

    /// Build index from a set of vectors
    pub fn build(config: RaBitQConfig, vectors: &[Vec<f32>], store_raw: bool) -> Self {
        let n = vectors.len();
        let dim = config.dim;

        assert!(n > 0, "Cannot build index from empty vector set");
        assert!(vectors[0].len() == dim, "Vector dimension mismatch");

        let mut index = Self::new(config);

        // Step 1: Compute centroid
        index.centroid = vec![0.0; dim];
        for v in vectors {
            for (i, &val) in v.iter().enumerate() {
                index.centroid[i] += val;
            }
        }
        for c in &mut index.centroid {
            *c /= n as f32;
        }

        // Step 2: Quantize each vector
        index.vectors = vectors.iter().map(|v| index.quantize_vector(v)).collect();

        // Step 3: Optionally store raw vectors for re-ranking
        if store_raw {
            index.raw_vectors = Some(vectors.to_vec());
        }

        index
    }

    /// Quantize a single vector
    fn quantize_vector(&self, raw: &[f32]) -> QuantizedVector {
        let dim = self.config.dim;

        // Step 1: Subtract centroid and compute norm
        let mut centered: Vec<f32> = raw
            .iter()
            .zip(&self.centroid)
            .map(|(&v, &c)| v - c)
            .collect();

        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
        let dist_to_centroid = norm;

        // Normalize (handle zero vector)
        if norm > 1e-10 {
            for x in &mut centered {
                *x /= norm;
            }
        }

        // Step 2: Apply random transform (sign flip + permutation)
        let transformed: Vec<f32> = (0..dim)
            .map(|i| {
                let src_idx = self.random_perm[i] as usize;
                centered[src_idx] * self.random_signs[src_idx] as f32
            })
            .collect();

        // Step 3: Binary quantize
        let num_bytes = dim.div_ceil(8);
        let mut bits = vec![0u8; num_bytes];
        let mut popcount = 0u32;

        for i in 0..dim {
            if transformed[i] >= 0.0 {
                bits[i / 8] |= 1 << (i % 8);
                popcount += 1;
            }
        }

        // Step 4: Compute self dot product <o, o_bar>
        // o_bar[i] = 1/sqrt(D) if bit[i] = 1, else -1/sqrt(D)
        let scale = 1.0 / (dim as f32).sqrt();
        let mut self_dot = 0.0f32;
        for i in 0..dim {
            let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
                scale
            } else {
                -scale
            };
            self_dot += transformed[i] * o_bar_i;
        }

        QuantizedVector {
            bits,
            dist_to_centroid,
            self_dot,
            popcount,
        }
    }

    /// Prepare a query for fast distance estimation
    pub fn prepare_query(&self, raw_query: &[f32]) -> QuantizedQuery {
        let dim = self.config.dim;

        // Step 1: Subtract centroid and compute norm
        let mut centered: Vec<f32> = raw_query
            .iter()
            .zip(&self.centroid)
            .map(|(&v, &c)| v - c)
            .collect();

        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
        let dist_to_centroid = norm;

        // Normalize
        if norm > 1e-10 {
            for x in &mut centered {
                *x /= norm;
            }
        }

        // Step 2: Apply random transform
        let transformed: Vec<f32> = (0..dim)
            .map(|i| {
                let src_idx = self.random_perm[i] as usize;
                centered[src_idx] * self.random_signs[src_idx] as f32
            })
            .collect();

        // Step 3: Scalar quantize to 4-bit
        let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = transformed
            .iter()
            .cloned()
            .fold(f32::NEG_INFINITY, f32::max);
        let lower = min_val;
        let width = if max_val > min_val {
            max_val - min_val
        } else {
            1.0
        };

        // Quantize to 0-15 range
        let quantized_vals: Vec<u8> = transformed
            .iter()
            .map(|&x| {
                let normalized = (x - lower) / width;
                (normalized * 15.0).round().clamp(0.0, 15.0) as u8
            })
            .collect();

        // Pack into bytes (2 values per byte)
        let num_bytes = dim.div_ceil(2);
        let mut quantized = vec![0u8; num_bytes];
        for i in 0..dim {
            if i % 2 == 0 {
                quantized[i / 2] |= quantized_vals[i];
            } else {
                quantized[i / 2] |= quantized_vals[i] << 4;
            }
        }

        // Compute sum of quantized values
        let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();

        // Step 4: Build LUTs for fast dot product
        // Each LUT covers 4 bits (dimensions) of the binary code
        let num_luts = dim.div_ceil(4);
        let mut luts = vec![[0u16; 16]; num_luts];

        for (lut_idx, lut) in luts.iter_mut().enumerate() {
            let base_dim = lut_idx * 4;
            for pattern in 0u8..16 {
                let mut dot = 0u16;
                for bit in 0..4 {
                    let dim_idx = base_dim + bit;
                    if dim_idx < dim && (pattern >> bit) & 1 == 1 {
                        dot += quantized_vals[dim_idx] as u16;
                    }
                }
                lut[pattern as usize] = dot;
            }
        }
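
        // Illustrative: for dim = 8 there are two LUTs. If the first four
        // quantized query values are [3, 0, 7, 1], then luts[0][0b0101]
        // = 3 + 7 = 10: the partial dot product against any binary code
        // whose bits 0 and 2 are set.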

        QuantizedQuery {
            quantized,
            dist_to_centroid,
            lower,
            width,
            sum,
            luts,
        }
    }

    /// Estimate squared distance between query and a quantized vector
    ///
    /// Uses the formula from the RaBitQ paper:
    /// ||o_r - q_r||^2 = ||o_r - c||^2 + ||q_r - c||^2 - 2 * ||o_r - c|| * ||q_r - c|| * <o, q>
    ///
    /// where <o, q> is estimated from the binary/scalar quantized representations.
    pub fn estimate_distance(&self, query: &QuantizedQuery, vec_idx: usize) -> f32 {
        let qv = &self.vectors[vec_idx];
        let dim = self.config.dim;

        // Compute dot product using SIMD-accelerated LUT lookup
        let dot_sum = lut_dot_product_simd(&qv.bits, &query.luts);

        // dot_sum is the sum of q_quantized[i] over dimensions where bit[i] = 1.
        // We need to convert this to an estimate of <q, o_bar>:
        //
        // o_bar[i] = +1/sqrt(D) if bit[i] = 1, else -1/sqrt(D)
        // q is dequantized from q_quantized: q[i] = lower + (q_quantized[i] / 15) * width
        //
        // <q, o_bar> = (1/sqrt(D)) * sum_i (q[i] * sign[i])
        //            = (1/sqrt(D)) * (sum_{bit=1} q[i] - sum_{bit=0} q[i])
        //            = (1/sqrt(D)) * (2 * sum_{bit=1} q[i] - sum_all q[i])

        let scale = 1.0 / (dim as f32).sqrt();

        // Dequantize the dot product
        // dot_sum = sum of q_quantized[i] where bit[i] = 1
        // We need: sum of q[i] where bit[i] = 1
        //        = sum of (lower + q_quantized[i] * width / 15) where bit[i] = 1
        //        = popcount * lower + (dot_sum * width / 15)
        let sum_positive = qv.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;

        // sum_all = D * lower + sum_q * width / 15
        let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;

        // <q, o_bar> = scale * (2 * sum_positive - sum_all)
        let q_obar_dot = scale * (2.0 * sum_positive - sum_all);

        // Estimate <q, o> using the corrective factor <o, o_bar>
        // The paper shows: <q, o> ≈ <q, o_bar> / <o, o_bar>
        let q_o_estimate = if qv.self_dot.abs() > 1e-6 {
            q_obar_dot / qv.self_dot
        } else {
            q_obar_dot // Fallback if self_dot is too small
        };

        // Clamp the inner product to the valid range [-1, 1] for unit vectors
        let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);

        // Compute squared distance using the formula:
        // ||o_r - q_r||^2 = ||o_r - c||^2 + ||q_r - c||^2 - 2 * ||o_r - c|| * ||q_r - c|| * <o, q>
        let dist_sq = qv.dist_to_centroid * qv.dist_to_centroid
            + query.dist_to_centroid * query.dist_to_centroid
            - 2.0 * qv.dist_to_centroid * query.dist_to_centroid * q_o_clamped;

        dist_sq.max(0.0) // Ensure non-negative
    }

    /// Search for k nearest neighbors
    pub fn search(&self, query: &[f32], k: usize, rerank_factor: usize) -> Vec<(usize, f32)> {
        let prepared = self.prepare_query(query);

        // Phase 1: Estimate distances for all vectors
        let mut candidates: Vec<(usize, f32)> = self
            .vectors
            .iter()
            .enumerate()
            .map(|(i, _)| (i, self.estimate_distance(&prepared, i)))
            .collect();

        // Sort by estimated distance
        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Phase 2: Re-rank top candidates with exact distances
        let rerank_count = (k * rerank_factor).min(candidates.len());

        if let Some(ref raw_vectors) = self.raw_vectors {
            let mut reranked: Vec<(usize, f32)> = candidates[..rerank_count]
                .iter()
                .map(|&(idx, _)| {
                    let exact_dist = euclidean_distance_squared(query, &raw_vectors[idx]);
                    (idx, exact_dist)
                })
                .collect();

            reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            reranked.truncate(k);
            reranked
        } else {
            // No raw vectors stored, return estimated distances
            candidates.truncate(k);
            candidates
        }
    }

    /// Number of indexed vectors
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Check if index is empty
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    /// Memory usage in bytes
    pub fn memory_usage(&self) -> usize {
        let vectors_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();
        let centroid_size = self.centroid.len() * 4;
        let transform_size = self.random_signs.len() + self.random_perm.len() * 4;
        let raw_size = self
            .raw_vectors
            .as_ref()
            .map(|vecs| vecs.iter().map(|v| v.len() * 4).sum())
            .unwrap_or(0);

        vectors_size + centroid_size + transform_size + raw_size
    }

    /// Compression ratio compared to raw float32 vectors
    pub fn compression_ratio(&self) -> f32 {
        if self.vectors.is_empty() {
            return 1.0;
        }

        let raw_size = self.vectors.len() * self.config.dim * 4; // float32
        let compressed_size: usize = self.vectors.iter().map(|v| v.size_bytes()).sum();

        raw_size as f32 / compressed_size as f32
    }
}

/// Compute squared Euclidean distance between two vectors
#[inline]
fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let d = x - y;
            d * d
        })
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rabitq_basic() {
        let dim = 128;
        let n = 100;

        // Generate random vectors
        let mut rng = rand::rngs::StdRng::seed_from_u64(12345);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        // Build index
        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        assert_eq!(index.len(), n);
        println!("Compression ratio: {:.1}x", index.compression_ratio());
    }

    #[test]
    fn test_rabitq_search() {
        let dim = 64;
        let n = 1000;
        let k = 10;

        // Generate random vectors
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        // Build index with raw vectors for re-ranking
        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, true);

        // Search with a random query
        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let results = index.search(&query, k, 10);

        assert_eq!(results.len(), k);

        // Verify results are sorted by distance
        for i in 1..results.len() {
            assert!(results[i].1 >= results[i - 1].1);
        }

        // Compute ground truth
        let mut ground_truth: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, euclidean_distance_squared(&query, v)))
            .collect();
        ground_truth.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Check recall (how many of top-k are in ground truth top-k)
        let gt_set: std::collections::HashSet<usize> =
            ground_truth[..k].iter().map(|x| x.0).collect();
        let result_set: std::collections::HashSet<usize> = results.iter().map(|x| x.0).collect();
        let recall = gt_set.intersection(&result_set).count() as f32 / k as f32;

        println!("Recall@{}: {:.2}", k, recall);
        assert!(recall >= 0.8, "Recall too low: {}", recall);
    }

    #[test]
    fn test_quantized_vector_size() {
        let dim = 768;
        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::new(config);

        let raw: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01).collect();
        let qv = index.quantize_vector(&raw);

        // D bits = D/8 bytes for the code, plus two f32s and a u32 (12 bytes)
        let expected_bits = dim.div_ceil(8);
        assert_eq!(qv.bits.len(), expected_bits);

        // Total: bits + 12 bytes of correction factors
        let total = qv.size_bytes();
        let raw_size = dim * 4;

        println!(
            "Raw size: {} bytes, Quantized size: {} bytes",
            raw_size, total
        );
        println!("Compression: {:.1}x", raw_size as f32 / total as f32);

        // Should achieve ~25-30x compression for 768-dim vectors
        assert!(raw_size as f32 / total as f32 > 20.0);
    }
725    #[test]
726    fn test_distance_estimation_accuracy() {
727        let dim = 128;
728        let n = 100;
729
730        let mut rng = rand::rngs::StdRng::seed_from_u64(999);
731        let vectors: Vec<Vec<f32>> = (0..n)
732            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
733            .collect();
734
735        let config = RaBitQConfig::new(dim);
736        let index = RaBitQIndex::build(config, &vectors, false);
737
738        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
739        let prepared = index.prepare_query(&query);
740
741        // Compare estimated vs exact distances
742        let mut errors = Vec::new();
743        for (i, v) in vectors.iter().enumerate() {
744            let estimated = index.estimate_distance(&prepared, i);
745            let exact = euclidean_distance_squared(&query, v);
746            let error = (estimated - exact).abs() / exact.max(1e-6);
747            errors.push(error);
748        }
749
750        let mean_error: f32 = errors.iter().sum::<f32>() / errors.len() as f32;
751        let max_error = errors.iter().cloned().fold(0.0f32, f32::max);
752
753        println!("Mean relative error: {:.2}%", mean_error * 100.0);
754        println!("Max relative error: {:.2}%", max_error * 100.0);
755
756        // Error should be reasonable (< 50% on average for this simple implementation)
757        assert!(mean_error < 0.5, "Mean error too high: {}", mean_error);
758    }
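
    #[test]
    fn test_lut_dot_product_consistency() {
        // Added consistency check (not from the original suite): the LUT-based
        // dot product must equal a naive sum of the packed 4-bit query values
        // over the dimensions whose binary-code bit is set.
        let dim = 64;
        let mut rng = rand::rngs::StdRng::seed_from_u64(7);
        let vectors: Vec<Vec<f32>> = (0..8)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - 0.5).collect())
            .collect();

        let config = RaBitQConfig::new(dim);
        let index = RaBitQIndex::build(config, &vectors, false);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let prepared = index.prepare_query(&query);

        for qv in &index.vectors {
            let fast = lut_dot_product_simd(&qv.bits, &prepared.luts);

            // Naive reference: unpack each 4-bit query value and add it
            // whenever the corresponding bit of the binary code is 1.
            let mut naive = 0u32;
            for i in 0..dim {
                if (qv.bits[i / 8] >> (i % 8)) & 1 == 1 {
                    let byte = prepared.quantized[i / 2];
                    let val = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
                    naive += val as u32;
                }
            }

            assert_eq!(fast, naive);
        }
    }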
}