Skip to main content

hermes_core/structures/vector/quantization/
rabitq.rs

1//! RaBitQ: Randomized Binary Quantization for Dense Vector Search
2//!
3//! Implementation of the RaBitQ algorithm from SIGMOD 2024:
4//! "RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound
5//! for Approximate Nearest Neighbor Search"
6//!
7//! Key features:
//! - ~32x compression (D-dimensional float32 → D-bit binary code + 3 scalars per vector)
//! - Theoretical error bound for distance estimation
//! - Table-lookup (LUT) distance computation, 4 dimensions per lookup
//! - Asymmetric quantization (binary data, scalar-quantized query; 4-bit by default)
12
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(target_arch = "x86_64")]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// Configuration for RaBitQ quantization.
///
/// Stored (and serialized) as part of [`RaBitQCodebook`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQConfig {
    /// Dimensionality of the vectors to be encoded.
    pub dim: usize,
    /// Number of bits for query scalar quantization (typically 4; `new` sets 4).
    pub query_bits: u8,
    /// Seed for the random sign-flip + permutation transform. The same seed and
    /// `dim` reproduce the same transform (for a given `rand` version, since
    /// `StdRng`'s algorithm is not guaranteed stable across versions).
    pub seed: u64,
}
37
38impl RaBitQConfig {
39    pub fn new(dim: usize) -> Self {
40        Self {
41            dim,
42            query_bits: 4,
43            seed: 42,
44        }
45    }
46
47    pub fn with_seed(mut self, seed: u64) -> Self {
48        self.seed = seed;
49        self
50    }
51}
52
/// Quantized representation of a single vector, as produced by
/// `RaBitQCodebook::encode`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Binary code, `ceil(dim / 8)` bytes: bit `i` (LSB-first within byte
    /// `i / 8`) is 1 iff the i-th transformed, normalized component is >= 0.
    pub bits: Vec<u8>,
    /// Norm of the residual: `||o_raw - c||`, or `||o_raw||` when encoded
    /// without a centroid.
    pub dist_to_centroid: f32,
    /// Inner product `<o, o_bar>` of the normalized, transformed vector with its
    /// ±1/sqrt(dim) code; the distance estimator divides by it as a correction.
    pub self_dot: f32,
    /// Number of 1-bits in `bits`.
    pub popcount: u32,
}
65
66impl QuantizedCode for QuantizedVector {
67    fn size_bytes(&self) -> usize {
68        self.bits.len() + 4 + 4 + 4 // bits + dist_to_centroid + self_dot + popcount
69    }
70}
71
/// Pre-computed query representation for fast distance estimation, as produced
/// by `RaBitQCodebook::prepare_query`.
#[derive(Debug, Clone)]
pub struct QuantizedQuery {
    /// Scalar-quantized query codes (each in `0..=15`), packed two per byte:
    /// even dimensions in the low nibble, odd dimensions in the high nibble.
    pub quantized: Vec<u8>,
    /// Distance from query to centroid: `||q_raw - c||` (or `||q_raw||` when
    /// prepared without a centroid).
    pub dist_to_centroid: f32,
    /// Lower bound of the quantization range: the smallest transformed,
    /// normalized query component.
    pub lower: f32,
    /// Width of the quantization range (`upper - lower`), or `1.0` when all
    /// transformed components are equal.
    pub width: f32,
    /// Sum of all quantized codes.
    pub sum: u32,
    /// One look-up table per group of 4 dimensions: entry `p` is the sum of the
    /// codes of the group's dimensions whose bit is set in `p` (bit 0 = the
    /// group's first dimension).
    pub luts: Vec<[u16; 16]>,
}
88
/// RaBitQ codebook: the random transform parameters shared by encoding and
/// query preparation.
///
/// Trained once, shared across all segments for merge compatibility.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQCodebook {
    /// Configuration (dimension, query bits, seed) the transform was built from.
    pub config: RaBitQConfig,
    /// Random sign (+1 or -1) for each input dimension.
    pub random_signs: Vec<i8>,
    /// Random permutation of `0..dim`: transformed coordinate `i` is taken from
    /// input coordinate `random_perm[i]` (multiplied by that coordinate's sign).
    pub random_perm: Vec<u32>,
    /// Version for merge compatibility checking. `new` sets it from the wall
    /// clock (milliseconds since the Unix epoch), not from the seed, so two
    /// codebooks with identical transforms may carry different versions.
    pub version: u64,
}
103
104impl RaBitQCodebook {
105    /// Create a new RaBitQ codebook with random transform
106    pub fn new(config: RaBitQConfig) -> Self {
107        let dim = config.dim;
108        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
109
110        // Generate random signs (+1 or -1) for each dimension
111        let random_signs: Vec<i8> = (0..dim)
112            .map(|_| if rng.random::<bool>() { 1 } else { -1 })
113            .collect();
114
115        // Generate random permutation
116        let mut random_perm: Vec<u32> = (0..dim as u32).collect();
117        for i in (1..dim).rev() {
118            let j = rng.random_range(0..=i);
119            random_perm.swap(i, j);
120        }
121
122        let version = std::time::SystemTime::now()
123            .duration_since(std::time::UNIX_EPOCH)
124            .unwrap_or_default()
125            .as_millis() as u64;
126
127        Self {
128            config,
129            random_signs,
130            random_perm,
131            version,
132        }
133    }
134
135    /// Encode a vector to binary quantized form
136    ///
137    /// If centroid is provided, encodes the residual (vector - centroid).
138    pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> QuantizedVector {
139        let dim = self.config.dim;
140
141        // Step 1: Subtract centroid (if provided) and compute norm
142        let centered: Vec<f32> = if let Some(c) = centroid {
143            vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
144        } else {
145            vector.to_vec()
146        };
147
148        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
149        let dist_to_centroid = norm;
150
151        // Normalize (handle zero vector)
152        let normalized: Vec<f32> = if norm > 1e-10 {
153            centered.iter().map(|x| x / norm).collect()
154        } else {
155            centered
156        };
157
158        // Step 2: Apply random transform (sign flip + permutation)
159        let transformed: Vec<f32> = (0..dim)
160            .map(|i| {
161                let src_idx = self.random_perm[i] as usize;
162                normalized[src_idx] * self.random_signs[src_idx] as f32
163            })
164            .collect();
165
166        // Step 3: Binary quantize
167        let num_bytes = dim.div_ceil(8);
168        let mut bits = vec![0u8; num_bytes];
169        let mut popcount = 0u32;
170
171        for i in 0..dim {
172            if transformed[i] >= 0.0 {
173                bits[i / 8] |= 1 << (i % 8);
174                popcount += 1;
175            }
176        }
177
178        // Step 4: Compute self dot product <o, o_bar>
179        let scale = 1.0 / (dim as f32).sqrt();
180        let mut self_dot = 0.0f32;
181        for i in 0..dim {
182            let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
183                scale
184            } else {
185                -scale
186            };
187            self_dot += transformed[i] * o_bar_i;
188        }
189
190        QuantizedVector {
191            bits,
192            dist_to_centroid,
193            self_dot,
194            popcount,
195        }
196    }
197
198    /// Prepare a query for fast distance estimation
199    pub fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> QuantizedQuery {
200        let dim = self.config.dim;
201
202        // Step 1: Subtract centroid (if provided) and compute norm
203        let centered: Vec<f32> = if let Some(c) = centroid {
204            query.iter().zip(c).map(|(&v, &c)| v - c).collect()
205        } else {
206            query.to_vec()
207        };
208
209        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
210        let dist_to_centroid = norm;
211
212        // Normalize
213        let normalized: Vec<f32> = if norm > 1e-10 {
214            centered.iter().map(|x| x / norm).collect()
215        } else {
216            centered
217        };
218
219        // Step 2: Apply random transform
220        let transformed: Vec<f32> = (0..dim)
221            .map(|i| {
222                let src_idx = self.random_perm[i] as usize;
223                normalized[src_idx] * self.random_signs[src_idx] as f32
224            })
225            .collect();
226
227        // Step 3: Scalar quantize to 4-bit
228        let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
229        let max_val = transformed
230            .iter()
231            .cloned()
232            .fold(f32::NEG_INFINITY, f32::max);
233        let lower = min_val;
234        let width = if max_val > min_val {
235            max_val - min_val
236        } else {
237            1.0
238        };
239
240        // Quantize to 0-15 range
241        let quantized_vals: Vec<u8> = transformed
242            .iter()
243            .map(|&x| {
244                let normalized = (x - lower) / width;
245                (normalized * 15.0).round().clamp(0.0, 15.0) as u8
246            })
247            .collect();
248
249        // Pack into bytes (2 values per byte)
250        let num_bytes = dim.div_ceil(2);
251        let mut quantized = vec![0u8; num_bytes];
252        for i in 0..dim {
253            if i % 2 == 0 {
254                quantized[i / 2] |= quantized_vals[i];
255            } else {
256                quantized[i / 2] |= quantized_vals[i] << 4;
257            }
258        }
259
260        // Compute sum of quantized values
261        let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();
262
263        // Step 4: Build LUTs for fast dot product
264        let num_luts = dim.div_ceil(4);
265        let mut luts = vec![[0u16; 16]; num_luts];
266
267        for (lut_idx, lut) in luts.iter_mut().enumerate() {
268            let base_dim = lut_idx * 4;
269            for pattern in 0u8..16 {
270                let mut dot = 0u16;
271                for bit in 0..4 {
272                    let dim_idx = base_dim + bit;
273                    if dim_idx < dim && (pattern >> bit) & 1 == 1 {
274                        dot += quantized_vals[dim_idx] as u16;
275                    }
276                }
277                lut[pattern as usize] = dot;
278            }
279        }
280
281        QuantizedQuery {
282            quantized,
283            dist_to_centroid,
284            lower,
285            width,
286            sum,
287            luts,
288        }
289    }
290
291    /// Estimate squared distance between query and a quantized vector
292    pub fn estimate_distance(&self, query: &QuantizedQuery, code: &QuantizedVector) -> f32 {
293        let dim = self.config.dim;
294
295        // Compute dot product using SIMD-accelerated LUT lookup
296        let dot_sum = lut_dot_product_simd(&code.bits, &query.luts);
297
298        let scale = 1.0 / (dim as f32).sqrt();
299
300        // Dequantize the dot product
301        let sum_positive = code.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;
302        let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;
303
304        // <q, o_bar> = scale * (2 * sum_positive - sum_all)
305        let q_obar_dot = scale * (2.0 * sum_positive - sum_all);
306
307        // Estimate <q, o> using the corrective factor <o, o_bar>
308        let q_o_estimate = if code.self_dot.abs() > 1e-6 {
309            q_obar_dot / code.self_dot
310        } else {
311            q_obar_dot
312        };
313
314        // Clamp the inner product to valid range [-1, 1]
315        let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);
316
317        // Compute squared distance
318        let dist_sq = code.dist_to_centroid * code.dist_to_centroid
319            + query.dist_to_centroid * query.dist_to_centroid
320            - 2.0 * code.dist_to_centroid * query.dist_to_centroid * q_o_clamped;
321
322        dist_sq.max(0.0)
323    }
324
325    /// Memory usage in bytes
326    pub fn size_bytes(&self) -> usize {
327        self.random_signs.len() + self.random_perm.len() * 4 + 64
328    }
329}
330
331impl Quantizer for RaBitQCodebook {
332    type Code = QuantizedVector;
333    type Config = RaBitQConfig;
334    type QueryData = QuantizedQuery;
335
336    fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
337        self.encode(vector, centroid)
338    }
339
340    fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
341        self.prepare_query(query, centroid)
342    }
343
344    fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
345        self.estimate_distance(query_data, code)
346    }
347
348    fn size_bytes(&self) -> usize {
349        self.size_bytes()
350    }
351}
352
353// ============================================================================
354// SIMD-accelerated LUT dot product
355// ============================================================================
356
357/// SIMD-accelerated LUT dot product for RaBitQ
358#[inline]
359fn lut_dot_product_simd(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
360    #[cfg(target_arch = "aarch64")]
361    {
362        if let Some(result) = lut_dot_product_neon(bits, luts) {
363            return result;
364        }
365    }
366
367    #[cfg(target_arch = "x86_64")]
368    {
369        if is_x86_feature_detected!("ssse3") {
370            unsafe {
371                if let Some(result) = lut_dot_product_ssse3(bits, luts) {
372                    return result;
373                }
374            }
375        }
376    }
377
378    lut_dot_product_scalar(bits, luts)
379}
380
/// Scalar LUT dot product: for each 4-bit group of `bits` (LSB-first, two
/// groups per byte), look up the entry of the matching LUT and sum them.
///
/// LUT `k` reads the low nibble of byte `k / 2` when `k` is even and the high
/// nibble when `k` is odd; bytes missing from `bits` are treated as zero.
#[inline]
fn lut_dot_product_scalar(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    luts.iter()
        .enumerate()
        .map(|(lut_idx, lut)| {
            let byte = bits.get(lut_idx / 2).copied().unwrap_or(0);
            // Groups start at multiples of 4 bits, so one never straddles two bytes.
            let pattern = if lut_idx % 2 == 0 { byte & 0x0F } else { byte >> 4 };
            u32::from(lut[usize::from(pattern)])
        })
        .sum()
}
405
/// LUT dot product used on aarch64.
///
/// Despite the name, this contains no NEON intrinsics: it walks the code one
/// byte at a time (LUT `2k` uses the low nibble of byte `k`, LUT `2k + 1` the
/// high nibble) and returns exactly what `lut_dot_product_scalar` returns.
/// Bytes missing from `bits` are treated as zero. Returns `None` for fewer than
/// 8 LUTs so the caller falls back to the scalar path.
#[cfg(target_arch = "aarch64")]
#[inline]
fn lut_dot_product_neon(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    if luts.len() < 8 {
        return None;
    }

    let total = luts
        .chunks(2)
        .enumerate()
        .map(|(byte_idx, pair)| {
            let byte = bits.get(byte_idx).copied().unwrap_or(0);
            let low = u32::from(pair[0][usize::from(byte & 0x0F)]);
            let high = pair
                .get(1)
                .map_or(0, |lut| u32::from(lut[usize::from(byte >> 4)]));
            low + high
        })
        .sum();

    Some(total)
}
470
471/// SSSE3-accelerated LUT dot product (x86_64)
472#[cfg(target_arch = "x86_64")]
473#[target_feature(enable = "ssse3")]
474#[inline]
475unsafe fn lut_dot_product_ssse3(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
476    if luts.len() < 8 {
477        return None;
478    }
479    Some(lut_dot_product_scalar(bits, luts))
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean norm, computed the same way the codebook computes residual norms.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_rabitq_codebook_basic() {
        let config = RaBitQConfig::new(128);
        let codebook = RaBitQCodebook::new(config);

        assert_eq!(codebook.random_signs.len(), 128);
        assert_eq!(codebook.random_perm.len(), 128);

        // Signs are ±1 and the permutation covers 0..dim exactly once.
        assert!(codebook.random_signs.iter().all(|&s| s == 1 || s == -1));
        let mut perm = codebook.random_perm.clone();
        perm.sort_unstable();
        assert_eq!(perm, (0..128).collect::<Vec<u32>>());
    }

    #[test]
    fn test_codebook_is_deterministic_for_seed() {
        let a = RaBitQCodebook::new(RaBitQConfig::new(64).with_seed(7));
        let b = RaBitQCodebook::new(RaBitQConfig::new(64).with_seed(7));
        assert_eq!(a.random_signs, b.random_signs);
        assert_eq!(a.random_perm, b.random_perm);
    }

    #[test]
    fn test_encode_decode() {
        let config = RaBitQConfig::new(64);
        let codebook = RaBitQCodebook::new(config);

        let vector: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 32.0).collect();
        let code = codebook.encode(&vector, None);

        assert_eq!(code.bits.len(), 8); // 64 bits = 8 bytes
        assert!((code.dist_to_centroid - l2_norm(&vector)).abs() < 1e-5);

        let ones: u32 = code.bits.iter().map(|b| b.count_ones()).sum();
        assert_eq!(code.popcount, ones);

        // <o, o_bar> = sum|o_i| / sqrt(D) lies in [1/sqrt(D), 1] for a unit vector.
        assert!(code.self_dot >= 1.0 / 8.0 - 1e-5);
        assert!(code.self_dot <= 1.0 + 1e-5);
    }

    #[test]
    fn test_encode_with_centroid() {
        let codebook = RaBitQCodebook::new(RaBitQConfig::new(16));
        let vector: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let centroid = vec![1.0f32; 16];

        let code = codebook.encode(&vector, Some(&centroid));

        let residual: Vec<f32> = vector.iter().zip(&centroid).map(|(v, c)| v - c).collect();
        assert!((code.dist_to_centroid - l2_norm(&residual)).abs() < 1e-4);
    }

    #[test]
    fn test_distance_estimation() {
        let config = RaBitQConfig::new(64);
        let codebook = RaBitQCodebook::new(config);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let v1: Vec<f32> = (0..64).map(|_| rng.random::<f32>() - 0.5).collect();
        let v2: Vec<f32> = (0..64).map(|_| rng.random::<f32>() - 0.5).collect();

        let code = codebook.encode(&v1, None);
        let query = codebook.prepare_query(&v2, None);

        let estimated = codebook.estimate_distance(&query, &code);
        assert!(estimated.is_finite());
        assert!(estimated >= 0.0);
        // The estimate uses a cosine clamped to [-1, 1], so it cannot exceed
        // (||v1|| + ||v2||)^2.
        let bound = (l2_norm(&v1) + l2_norm(&v2)).powi(2);
        assert!(estimated <= bound * (1.0 + 1e-4));
    }

    #[test]
    fn test_zero_vector_distance_is_query_norm() {
        let codebook = RaBitQCodebook::new(RaBitQConfig::new(32));
        let zero = vec![0.0f32; 32];
        let q: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) / 8.0).collect();

        let code = codebook.encode(&zero, None);
        assert_eq!(code.dist_to_centroid, 0.0);

        let est = codebook.estimate_distance(&codebook.prepare_query(&q, None), &code);
        let qn = l2_norm(&q);
        assert!((est - qn * qn).abs() <= 1e-4 * qn * qn);
    }

    #[test]
    fn test_lut_dispatch_matches_scalar() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(7);
        for num_luts in [3usize, 8, 17, 32] {
            let luts: Vec<[u16; 16]> = (0..num_luts)
                .map(|_| {
                    let mut lut = [0u16; 16];
                    for entry in lut.iter_mut() {
                        *entry = rng.random_range(0..=60u16);
                    }
                    lut
                })
                .collect();
            let bits: Vec<u8> = (0..num_luts.div_ceil(2))
                .map(|_| rng.random::<u8>())
                .collect();
            assert_eq!(
                lut_dot_product_simd(&bits, &luts),
                lut_dot_product_scalar(&bits, &luts)
            );
        }
    }

    #[test]
    fn test_quantizer_trait() {
        let config = RaBitQConfig::new(32);
        let codebook = RaBitQCodebook::new(config);

        let vector: Vec<f32> = (0..32).map(|i| i as f32 / 32.0).collect();
        let query: Vec<f32> = (0..32).map(|i| (31 - i) as f32 / 32.0).collect();

        // Use trait methods
        let code = Quantizer::encode(&codebook, &vector, None);
        let query_data = Quantizer::prepare_query(&codebook, &query, None);
        let dist = Quantizer::compute_distance(&codebook, &query_data, &code);

        assert!(dist.is_finite());
        assert!(dist >= 0.0);
        assert_eq!(
            dist,
            codebook.estimate_distance(&codebook.prepare_query(&query, None), &code)
        );
    }
}