Skip to main content

hermes_core/structures/vector/quantization/
rabitq.rs

//! RaBitQ: Randomized Binary Quantization for Dense Vector Search
//!
//! Implementation of the RaBitQ algorithm from SIGMOD 2024:
//! "RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound
//! for Approximate Nearest Neighbor Search"
//!
//! Key features:
//! - Roughly 32x compression (D-dimensional float32 → D-bit binary code plus
//!   three 4-byte scalars per vector: two floats and a popcount)
//! - Theoretical error bound for distance estimation
//! - LUT-based distance computation with per-architecture (NEON / SSSE3)
//!   dispatch hooks; the kernels in this file are currently scalar
//! - Asymmetric quantization (binary data, 4-bit query)
12
13use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(target_arch = "x86_64")]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
27/// Configuration for RaBitQ quantization
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RaBitQConfig {
30    /// Dimensionality of vectors
31    pub dim: usize,
32    /// Number of bits for query quantization (typically 4)
33    pub query_bits: u8,
34    /// Random seed for reproducible orthogonal matrix
35    pub seed: u64,
36}
37
38impl RaBitQConfig {
39    pub fn new(dim: usize) -> Self {
40        Self {
41            dim,
42            query_bits: 4,
43            seed: 42,
44        }
45    }
46
47    pub fn with_seed(mut self, seed: u64) -> Self {
48        self.seed = seed;
49        self
50    }
51}
52
53/// Quantized representation of a single vector
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct QuantizedVector {
56    /// Binary quantization code (D bits packed into bytes)
57    pub bits: Vec<u8>,
58    /// Distance from original vector to centroid: ||o_raw - c||
59    pub dist_to_centroid: f32,
60    /// Dot product of normalized vector with its quantized form: <o, o_bar>
61    pub self_dot: f32,
62    /// Number of 1-bits in the binary code (for fast computation)
63    pub popcount: u32,
64}
65
66impl QuantizedCode for QuantizedVector {
67    fn size_bytes(&self) -> usize {
68        self.bits.len() + 4 + 4 + 4 // bits + dist_to_centroid + self_dot + popcount
69    }
70}
71
/// Query-side data computed once per query and reused for every distance
/// estimate against encoded vectors.
#[derive(Debug, Clone)]
pub struct QuantizedQuery {
    /// 4-bit quantized query components, two per byte (even index in the low
    /// nibble, odd index in the high nibble).
    pub quantized: Vec<u8>,
    /// Euclidean norm of the (centroid-subtracted) query: ||q_raw - c||.
    pub dist_to_centroid: f32,
    /// Smallest transformed component; maps to quantized value 0.
    pub lower: f32,
    /// Span of the quantization range (largest minus smallest component).
    pub width: f32,
    /// Sum of all quantized component values.
    pub sum: u32,
    /// One 16-entry table per group of four consecutive dimensions; entry `p`
    /// is the sum of the quantized values whose bit is set in pattern `p`.
    pub luts: Vec<[u16; 16]>,
}
88
89/// RaBitQ codebook (random transform parameters)
90///
91/// Trained once, shared across all segments for merge compatibility.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct RaBitQCodebook {
94    /// Configuration
95    pub config: RaBitQConfig,
96    /// Random signs for transform (+1 or -1)
97    pub random_signs: Vec<i8>,
98    /// Random permutation for transform
99    pub random_perm: Vec<u32>,
100    /// Version for merge compatibility checking
101    pub version: u64,
102}
103
104impl RaBitQCodebook {
105    /// Create a new RaBitQ codebook with random transform
106    pub fn new(config: RaBitQConfig) -> Self {
107        let dim = config.dim;
108        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
109
110        // Generate random signs (+1 or -1) for each dimension
111        let random_signs: Vec<i8> = (0..dim)
112            .map(|_| if rng.random::<bool>() { 1 } else { -1 })
113            .collect();
114
115        // Generate random permutation
116        let mut random_perm: Vec<u32> = (0..dim as u32).collect();
117        for i in (1..dim).rev() {
118            let j = rng.random_range(0..=i);
119            random_perm.swap(i, j);
120        }
121
122        let version = std::time::SystemTime::now()
123            .duration_since(std::time::UNIX_EPOCH)
124            .unwrap_or_default()
125            .as_millis() as u64;
126
127        Self {
128            config,
129            random_signs,
130            random_perm,
131            version,
132        }
133    }
134
135    /// Encode a vector to binary quantized form
136    ///
137    /// If centroid is provided, encodes the residual (vector - centroid).
138    pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> QuantizedVector {
139        let dim = self.config.dim;
140
141        // Step 1: Subtract centroid (if provided) and compute norm
142        let centered: Vec<f32> = if let Some(c) = centroid {
143            vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
144        } else {
145            vector.to_vec()
146        };
147
148        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
149        let dist_to_centroid = norm;
150
151        // Normalize (handle zero vector)
152        let normalized: Vec<f32> = if norm > 1e-10 {
153            centered.iter().map(|x| x / norm).collect()
154        } else {
155            centered
156        };
157
158        // Step 2: Apply random transform (sign flip + permutation)
159        let transformed: Vec<f32> = (0..dim)
160            .map(|i| {
161                let src_idx = self.random_perm[i] as usize;
162                normalized[src_idx] * self.random_signs[src_idx] as f32
163            })
164            .collect();
165
166        // Step 3: Binary quantize
167        let num_bytes = dim.div_ceil(8);
168        let mut bits = vec![0u8; num_bytes];
169        let mut popcount = 0u32;
170
171        for i in 0..dim {
172            if transformed[i] >= 0.0 {
173                bits[i / 8] |= 1 << (i % 8);
174                popcount += 1;
175            }
176        }
177
178        // Step 4: Compute self dot product <o, o_bar>
179        let scale = 1.0 / (dim as f32).sqrt();
180        let mut self_dot = 0.0f32;
181        for i in 0..dim {
182            let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
183                scale
184            } else {
185                -scale
186            };
187            self_dot += transformed[i] * o_bar_i;
188        }
189
190        QuantizedVector {
191            bits,
192            dist_to_centroid,
193            self_dot,
194            popcount,
195        }
196    }
197
198    /// Prepare a query for fast distance estimation
199    pub fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> QuantizedQuery {
200        let dim = self.config.dim;
201
202        // Step 1: Subtract centroid (if provided) and compute norm
203        let centered: Vec<f32> = if let Some(c) = centroid {
204            query.iter().zip(c).map(|(&v, &c)| v - c).collect()
205        } else {
206            query.to_vec()
207        };
208
209        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
210        let dist_to_centroid = norm;
211
212        // Normalize
213        let normalized: Vec<f32> = if norm > 1e-10 {
214            centered.iter().map(|x| x / norm).collect()
215        } else {
216            centered
217        };
218
219        // Step 2: Apply random transform
220        let transformed: Vec<f32> = (0..dim)
221            .map(|i| {
222                let src_idx = self.random_perm[i] as usize;
223                normalized[src_idx] * self.random_signs[src_idx] as f32
224            })
225            .collect();
226
227        // Step 3: Scalar quantize to 4-bit
228        let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
229        let max_val = transformed
230            .iter()
231            .cloned()
232            .fold(f32::NEG_INFINITY, f32::max);
233        let lower = min_val;
234        let width = if max_val > min_val {
235            max_val - min_val
236        } else {
237            1.0
238        };
239
240        // Quantize to 0-15 range
241        let quantized_vals: Vec<u8> = transformed
242            .iter()
243            .map(|&x| {
244                let normalized = (x - lower) / width;
245                (normalized * 15.0).round().clamp(0.0, 15.0) as u8
246            })
247            .collect();
248
249        // Pack into bytes (2 values per byte)
250        let num_bytes = dim.div_ceil(2);
251        let mut quantized = vec![0u8; num_bytes];
252        for i in 0..dim {
253            if i % 2 == 0 {
254                quantized[i / 2] |= quantized_vals[i];
255            } else {
256                quantized[i / 2] |= quantized_vals[i] << 4;
257            }
258        }
259
260        // Compute sum of quantized values
261        let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();
262
263        // Step 4: Build LUTs for fast dot product
264        let num_luts = dim.div_ceil(4);
265        let mut luts = vec![[0u16; 16]; num_luts];
266
267        for (lut_idx, lut) in luts.iter_mut().enumerate() {
268            let base_dim = lut_idx * 4;
269            for pattern in 0u8..16 {
270                let mut dot = 0u16;
271                for bit in 0..4 {
272                    let dim_idx = base_dim + bit;
273                    if dim_idx < dim && (pattern >> bit) & 1 == 1 {
274                        dot += quantized_vals[dim_idx] as u16;
275                    }
276                }
277                lut[pattern as usize] = dot;
278            }
279        }
280
281        QuantizedQuery {
282            quantized,
283            dist_to_centroid,
284            lower,
285            width,
286            sum,
287            luts,
288        }
289    }
290
291    /// Estimate squared distance between query and a quantized vector
292    pub fn estimate_distance(&self, query: &QuantizedQuery, code: &QuantizedVector) -> f32 {
293        let dim = self.config.dim;
294
295        // Compute dot product using SIMD-accelerated LUT lookup
296        let dot_sum = lut_dot_product_simd(&code.bits, &query.luts);
297
298        let scale = 1.0 / (dim as f32).sqrt();
299
300        // Dequantize the dot product
301        let sum_positive = code.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;
302        let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;
303
304        // <q, o_bar> = scale * (2 * sum_positive - sum_all)
305        let q_obar_dot = scale * (2.0 * sum_positive - sum_all);
306
307        // Estimate <q, o> using the corrective factor <o, o_bar>
308        let q_o_estimate = if code.self_dot.abs() > 1e-6 {
309            q_obar_dot / code.self_dot
310        } else {
311            q_obar_dot
312        };
313
314        // Clamp the inner product to valid range [-1, 1]
315        let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);
316
317        // Compute squared distance
318        let dist_sq = code.dist_to_centroid * code.dist_to_centroid
319            + query.dist_to_centroid * query.dist_to_centroid
320            - 2.0 * code.dist_to_centroid * query.dist_to_centroid * q_o_clamped;
321
322        dist_sq.max(0.0)
323    }
324
325    /// Memory usage in bytes
326    pub fn size_bytes(&self) -> usize {
327        self.random_signs.len() + self.random_perm.len() * 4 + 64
328    }
329
330    /// Estimated memory usage in bytes (alias for size_bytes)
331    pub fn estimated_memory_bytes(&self) -> usize {
332        self.size_bytes()
333    }
334}
335
336impl Quantizer for RaBitQCodebook {
337    type Code = QuantizedVector;
338    type Config = RaBitQConfig;
339    type QueryData = QuantizedQuery;
340
341    fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
342        self.encode(vector, centroid)
343    }
344
345    fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
346        self.prepare_query(query, centroid)
347    }
348
349    fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
350        self.estimate_distance(query_data, code)
351    }
352
353    fn size_bytes(&self) -> usize {
354        self.size_bytes()
355    }
356}
357
358// ============================================================================
359// SIMD-accelerated LUT dot product
360// ============================================================================
361
362/// SIMD-accelerated LUT dot product for RaBitQ
363#[inline]
364fn lut_dot_product_simd(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
365    #[cfg(target_arch = "aarch64")]
366    {
367        if let Some(result) = lut_dot_product_neon(bits, luts) {
368            return result;
369        }
370    }
371
372    #[cfg(target_arch = "x86_64")]
373    {
374        if is_x86_feature_detected!("ssse3") {
375            unsafe {
376                if let Some(result) = lut_dot_product_ssse3(bits, luts) {
377                    return result;
378                }
379            }
380        }
381    }
382
383    lut_dot_product_scalar(bits, luts)
384}
385
/// Scalar implementation of the LUT dot product.
///
/// LUT `k` is indexed by the `k`-th 4-bit group of `bits`: the low nibble of
/// byte `k / 2` when `k` is even, the high nibble when `k` is odd. Groups that
/// lie past the end of `bits` read as zero, i.e. select entry 0. Returns the
/// sum of the selected entries.
#[inline]
fn lut_dot_product_scalar(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    luts.iter()
        .enumerate()
        .map(|(lut_idx, lut)| {
            let byte = bits.get(lut_idx / 2).copied().unwrap_or(0);
            // Groups are nibble-aligned, so the shift is always 0 or 4 and a
            // group never straddles two bytes.
            let shift = (lut_idx % 2) * 4;
            let pattern = (byte >> shift) & 0x0F;
            u32::from(lut[usize::from(pattern)])
        })
        .sum()
}
410
/// LUT dot product for the aarch64 dispatch path.
///
/// Returns `None` when fewer than 8 LUTs are supplied, leaving the caller to
/// use another kernel. Otherwise returns the same sum as the scalar kernel:
/// LUT `k` is indexed by the `k`-th 4-bit group of `bits` (low nibble of byte
/// `k / 2` for even `k`, high nibble for odd `k`), and groups past the end of
/// `bits` select entry 0.
///
/// NOTE: despite the name, this function uses no NEON intrinsics; it walks the
/// tables two at a time (one code byte per pair) in portable Rust.
#[cfg(target_arch = "aarch64")]
#[inline]
fn lut_dot_product_neon(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    const MIN_LUTS: usize = 8;
    if luts.len() < MIN_LUTS {
        return None;
    }

    let total: u32 = luts
        .chunks(2)
        .enumerate()
        .map(|(byte_idx, pair)| {
            let byte = bits.get(byte_idx).copied().unwrap_or(0);
            let low = u32::from(pair[0][usize::from(byte & 0x0F)]);
            // The last chunk holds a single table when the count is odd.
            let high = pair
                .get(1)
                .map_or(0, |lut| u32::from(lut[usize::from(byte >> 4)]));
            low + high
        })
        .sum();

    Some(total)
}
475
476/// SSSE3-accelerated LUT dot product (x86_64)
477#[cfg(target_arch = "x86_64")]
478#[target_feature(enable = "ssse3")]
479#[inline]
480unsafe fn lut_dot_product_ssse3(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
481    if luts.len() < 8 {
482        return None;
483    }
484    Some(lut_dot_product_scalar(bits, luts))
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a codebook with the default configuration for `dim` dimensions.
    fn codebook_for(dim: usize) -> RaBitQCodebook {
        RaBitQCodebook::new(RaBitQConfig::new(dim))
    }

    /// The transform tables must have one entry per dimension.
    #[test]
    fn test_rabitq_codebook_basic() {
        let cb = codebook_for(128);

        assert_eq!(cb.random_signs.len(), 128);
        assert_eq!(cb.random_perm.len(), 128);
    }

    /// Encoding a non-zero 64-d ramp yields 8 code bytes and a positive norm.
    #[test]
    fn test_encode_decode() {
        let cb = codebook_for(64);

        let ramp: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 32.0).collect();
        let encoded = cb.encode(&ramp, None);

        // 64 bits = 8 bytes
        assert_eq!(encoded.bits.len(), 8);
        assert!(encoded.dist_to_centroid > 0.0);
    }

    /// The estimated squared distance between two random vectors is non-negative.
    #[test]
    fn test_distance_estimation() {
        let cb = codebook_for(64);

        // Both vectors come from one RNG stream: first the data, then the query.
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let data: Vec<f32> = (0..64).map(|_| rng.random::<f32>() - 0.5).collect();
        let probe: Vec<f32> = (0..64).map(|_| rng.random::<f32>() - 0.5).collect();

        let encoded = cb.encode(&data, None);
        let prepared = cb.prepare_query(&probe, None);

        let estimate = cb.estimate_distance(&prepared, &encoded);
        assert!(estimate >= 0.0);
    }

    /// The same pipeline works when driven through the `Quantizer` trait.
    #[test]
    fn test_quantizer_trait() {
        let cb = codebook_for(32);

        let ascending: Vec<f32> = (0..32).map(|i| i as f32 / 32.0).collect();
        let descending: Vec<f32> = (0..32).map(|i| (31 - i) as f32 / 32.0).collect();

        // Fully-qualified calls so the trait methods are exercised.
        let encoded = Quantizer::encode(&cb, &ascending, None);
        let prepared = Quantizer::prepare_query(&cb, &descending, None);
        let distance = Quantizer::compute_distance(&cb, &prepared, &encoded);

        assert!(distance >= 0.0);
    }
}
543}