Skip to main content

reddb_server/storage/engine/
binary_quantize.rs

1//! Binary Quantization for Vector Embeddings
2//!
3//! Compresses fp32 vectors to binary (1 bit per dimension) for ultra-fast
4//! approximate nearest neighbor search using Hamming distance.
5//!
6//! # Compression Ratio
7//!
8//! - fp32: 4 bytes per dimension
9//! - binary: 1 bit per dimension = 32x compression
10//!
11//! Example: 1024-dim vector
12//! - fp32: 4096 bytes
13//! - binary: 128 bytes
14//!
15//! # Algorithm
16//!
17//! Simple sign-based quantization:
18//! - positive values → 1
19//! - negative/zero values → 0
20//!
21//! For normalized embeddings (e.g., from sentence transformers),
22//! this preserves ~95-97% of retrieval quality.
23//!
24//! # Usage
25//!
26//! ```ignore
27//! // Quantize a vector
28//! let binary = BinaryVector::from_f32(&embedding);
29//!
30//! // Compute Hamming distance (number of differing bits)
31//! let distance = binary.hamming_distance(&other);
32//!
33//! // For retrieval: lower Hamming distance = more similar
34//! ```
35//!
36//! # References
37//!
38//! - "Embedding Quantization" - HuggingFace Blog
39//! - Binary embedding with Matryoshka representation learning
40
41use std::cmp::Ordering;
42
43/// Binary quantized vector stored as packed u64 words
44#[derive(Clone, Debug)]
45pub struct BinaryVector {
46    /// Packed binary data (each u64 holds 64 dimensions)
47    data: Vec<u64>,
48    /// Original dimensionality
49    dim: usize,
50}
51
52impl BinaryVector {
53    /// Create a binary vector from fp32 values using sign-based quantization
54    ///
55    /// Positive values become 1, negative/zero become 0
56    pub fn from_f32(values: &[f32]) -> Self {
57        let dim = values.len();
58        let n_words = dim.div_ceil(64); // Ceiling division
59        let mut data = vec![0u64; n_words];
60
61        for (i, &v) in values.iter().enumerate() {
62            if v > 0.0 {
63                let word_idx = i / 64;
64                let bit_idx = i % 64;
65                data[word_idx] |= 1u64 << bit_idx;
66            }
67        }
68
69        Self { data, dim }
70    }
71
72    /// Create a binary vector from threshold-based quantization
73    ///
74    /// Values above threshold become 1, below become 0
75    pub fn from_f32_threshold(values: &[f32], threshold: f32) -> Self {
76        let dim = values.len();
77        let n_words = dim.div_ceil(64);
78        let mut data = vec![0u64; n_words];
79
80        for (i, &v) in values.iter().enumerate() {
81            if v > threshold {
82                let word_idx = i / 64;
83                let bit_idx = i % 64;
84                data[word_idx] |= 1u64 << bit_idx;
85            }
86        }
87
88        Self { data, dim }
89    }
90
91    /// Create a binary vector from median-based quantization
92    ///
93    /// Values above median become 1, below become 0.
94    /// Better for non-normalized vectors.
95    pub fn from_f32_median(values: &[f32]) -> Self {
96        let mut sorted = values.to_vec();
97        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
98        let median = if sorted.len().is_multiple_of(2) {
99            (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
100        } else {
101            sorted[sorted.len() / 2]
102        };
103
104        Self::from_f32_threshold(values, median)
105    }
106
107    /// Create from raw packed data
108    pub fn from_raw(data: Vec<u64>, dim: usize) -> Self {
109        Self { data, dim }
110    }
111
112    /// Get the dimensionality of the original vector
113    #[inline]
114    pub fn dim(&self) -> usize {
115        self.dim
116    }
117
118    /// Get the packed binary data
119    #[inline]
120    pub fn data(&self) -> &[u64] {
121        &self.data
122    }
123
124    /// Get size in bytes
125    #[inline]
126    pub fn size_bytes(&self) -> usize {
127        self.data.len() * 8
128    }
129
130    /// Compute Hamming distance to another binary vector
131    ///
132    /// Hamming distance = number of positions where bits differ.
133    /// Uses popcount which is a single CPU instruction on modern x86.
134    #[inline]
135    pub fn hamming_distance(&self, other: &Self) -> u32 {
136        debug_assert_eq!(self.dim, other.dim, "Dimensions must match");
137
138        hamming_distance_simd(&self.data, &other.data)
139    }
140
141    /// Compute normalized Hamming distance (0.0 to 1.0)
142    ///
143    /// 0.0 = identical, 1.0 = completely different
144    #[inline]
145    pub fn hamming_distance_normalized(&self, other: &Self) -> f32 {
146        let dist = self.hamming_distance(other) as f32;
147        dist / self.dim as f32
148    }
149
150    /// Convert Hamming distance to approximate cosine similarity
151    ///
152    /// For normalized embeddings, there's a relationship between
153    /// Hamming distance and cosine similarity:
154    /// cos_sim ≈ 1 - 2 * (hamming_dist / dim)
155    #[inline]
156    pub fn approx_cosine_similarity(&self, other: &Self) -> f32 {
157        let normalized_dist = self.hamming_distance_normalized(other);
158        1.0 - 2.0 * normalized_dist
159    }
160}
161
162// ============================================================================
163// Hamming Distance with SIMD
164// ============================================================================
165
166/// Compute Hamming distance between two packed binary vectors using SIMD
167#[inline]
168pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
169    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
170
171    #[cfg(target_arch = "x86_64")]
172    {
173        if is_x86_feature_detected!("popcnt") {
174            return unsafe { hamming_distance_popcnt(a, b) };
175        }
176    }
177
178    hamming_distance_scalar(a, b)
179}
180
/// Portable Hamming distance: XOR the words pairwise and total the set bits.
///
/// Words are paired with `zip`, so when the slices differ in length only the
/// common prefix is visited.
#[inline]
fn hamming_distance_scalar(a: &[u64], b: &[u64]) -> u32 {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| (lhs ^ rhs).count_ones())
        .sum()
}
190
/// Hamming distance compiled with the `popcnt` target feature enabled.
///
/// Words are XORed pairwise and their set bits counted with `count_ones`,
/// which the compiler can lower to the `popcnt` instruction inside this
/// function because the feature is enabled for it.
///
/// Only the common prefix of `a` and `b` is compared, which matches the
/// `zip` behaviour of the scalar fallback when the lengths differ.
///
/// # Safety
///
/// The caller must ensure the running CPU supports `popcnt`, for example via
/// `is_x86_feature_detected!("popcnt")`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[inline]
unsafe fn hamming_distance_popcnt(a: &[u64], b: &[u64]) -> u32 {
    // Trim both inputs to the shared length so a length mismatch can neither
    // index out of bounds nor silently drop only one side's tail.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    // Process 4 u64s at a time for better instruction-level parallelism.
    let a_blocks = a.chunks_exact(4);
    let b_blocks = b.chunks_exact(4);
    let a_tail = a_blocks.remainder();
    let b_tail = b_blocks.remainder();

    let mut count = 0u32;
    for (x, y) in a_blocks.zip(b_blocks) {
        // Four independent popcounts per iteration, summed afterwards.
        let c0 = (x[0] ^ y[0]).count_ones();
        let c1 = (x[1] ^ y[1]).count_ones();
        let c2 = (x[2] ^ y[2]).count_ones();
        let c3 = (x[3] ^ y[3]).count_ones();
        count += c0 + c1 + c2 + c3;
    }

    // Handle the up-to-three words that did not fill a whole block.
    for (x, y) in a_tail.iter().zip(b_tail) {
        count += (x ^ y).count_ones();
    }

    count
}
222
223// ============================================================================
224// Batch Operations
225// ============================================================================
226
227/// Index of binary vectors for fast batch search
228#[derive(Clone)]
229pub struct BinaryIndex {
230    /// All binary vectors (flattened: n_vectors * n_words)
231    vectors: Vec<u64>,
232    /// Number of u64 words per vector
233    words_per_vector: usize,
234    /// Number of vectors
235    n_vectors: usize,
236    /// Original dimension
237    dim: usize,
238}
239
240impl BinaryIndex {
241    /// Create a new binary index
242    pub fn new(dim: usize) -> Self {
243        let words_per_vector = dim.div_ceil(64);
244        Self {
245            vectors: Vec::new(),
246            words_per_vector,
247            n_vectors: 0,
248            dim,
249        }
250    }
251
252    /// Create with pre-allocated capacity
253    pub fn with_capacity(dim: usize, capacity: usize) -> Self {
254        let words_per_vector = dim.div_ceil(64);
255        Self {
256            vectors: Vec::with_capacity(capacity * words_per_vector),
257            words_per_vector,
258            n_vectors: 0,
259            dim,
260        }
261    }
262
263    /// Add a vector to the index
264    pub fn add(&mut self, vector: &BinaryVector) {
265        debug_assert_eq!(vector.dim, self.dim, "Dimension mismatch");
266        self.vectors.extend_from_slice(&vector.data);
267        self.n_vectors += 1;
268    }
269
270    /// Add a fp32 vector (will be quantized)
271    pub fn add_f32(&mut self, vector: &[f32]) {
272        let binary = BinaryVector::from_f32(vector);
273        self.add(&binary);
274    }
275
276    /// Get number of vectors in the index
277    #[inline]
278    pub fn len(&self) -> usize {
279        self.n_vectors
280    }
281
282    /// Check if index is empty
283    #[inline]
284    pub fn is_empty(&self) -> bool {
285        self.n_vectors == 0
286    }
287
288    /// Get memory usage in bytes
289    pub fn memory_bytes(&self) -> usize {
290        self.vectors.len() * 8
291    }
292
293    /// Get a vector by index
294    pub fn get(&self, idx: usize) -> Option<BinaryVector> {
295        if idx >= self.n_vectors {
296            return None;
297        }
298        let start = idx * self.words_per_vector;
299        let end = start + self.words_per_vector;
300        Some(BinaryVector::from_raw(
301            self.vectors[start..end].to_vec(),
302            self.dim,
303        ))
304    }
305
306    /// Search for k nearest neighbors using Hamming distance
307    ///
308    /// Returns (index, hamming_distance) pairs sorted by distance.
309    pub fn search(&self, query: &BinaryVector, k: usize) -> Vec<(usize, u32)> {
310        if self.n_vectors == 0 {
311            return Vec::new();
312        }
313
314        let k = k.min(self.n_vectors);
315        let mut results: Vec<(usize, u32)> = Vec::with_capacity(self.n_vectors);
316
317        // Compute distances to all vectors
318        for i in 0..self.n_vectors {
319            let start = i * self.words_per_vector;
320            let end = start + self.words_per_vector;
321            let dist = hamming_distance_simd(&query.data, &self.vectors[start..end]);
322            results.push((i, dist));
323        }
324
325        // Partial sort to get top-k
326        if k < self.n_vectors {
327            results.select_nth_unstable_by_key(k - 1, |&(_, d)| d);
328            results.truncate(k);
329        }
330        results.sort_by_key(|&(_, d)| d);
331
332        results
333    }
334
335    /// Search from fp32 query (will be quantized)
336    pub fn search_f32(&self, query: &[f32], k: usize) -> Vec<(usize, u32)> {
337        let binary_query = BinaryVector::from_f32(query);
338        self.search(&binary_query, k)
339    }
340
341    /// Batch search for multiple queries
342    ///
343    /// More efficient than individual searches due to cache locality.
344    pub fn batch_search(&self, queries: &[BinaryVector], k: usize) -> Vec<Vec<(usize, u32)>> {
345        queries.iter().map(|q| self.search(q, k)).collect()
346    }
347}
348
349// ============================================================================
350// Distance Result for Integration
351// ============================================================================
352
/// One hit from a binary search, with room for a later, more precise score.
#[derive(Debug, Clone)]
pub struct BinarySearchResult {
    /// Index of the matched vector
    pub id: usize,
    /// Hamming distance to the query (lower = more similar)
    pub hamming_distance: u32,
    /// Distance from a rescoring pass (int8/fp32), if one has been applied
    pub rescored_distance: Option<f32>,
}

impl BinarySearchResult {
    /// Build a result that carries only the Hamming distance; no rescored
    /// distance is set.
    pub fn new(id: usize, hamming_distance: u32) -> Self {
        Self {
            rescored_distance: None,
            hamming_distance,
            id,
        }
    }

    /// Distance to rank by: the rescored value when present, otherwise the
    /// Hamming distance converted to `f32`.
    pub fn final_distance(&self) -> f32 {
        match self.rescored_distance {
            Some(rescored) => rescored,
            None => self.hamming_distance as f32,
        }
    }
}
379
380// ============================================================================
381// Tests
382// ============================================================================
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` values that are `1.0` where `positive(i)` holds and
    /// `-1.0` everywhere else.
    fn signed_pattern(len: usize, positive: impl Fn(usize) -> bool) -> Vec<f32> {
        (0..len)
            .map(|i| if positive(i) { 1.0 } else { -1.0 })
            .collect()
    }

    #[test]
    fn test_binary_quantization_positive() {
        let binary = BinaryVector::from_f32(&[1.0, -1.0, 0.5, -0.5, 0.0, 2.0, -2.0, 0.1]);

        // Only indices 0, 2, 5 and 7 hold strictly positive values, so
        // exactly those bits are expected: 0b10100101 = 165.
        let expected = (1u64 << 0) | (1u64 << 2) | (1u64 << 5) | (1u64 << 7);
        assert_eq!(binary.data[0] & 0xFF, expected);
    }

    #[test]
    fn test_hamming_distance_identical() {
        let pattern = [1.0, -1.0, 1.0, -1.0];
        let lhs = BinaryVector::from_f32(&pattern);
        let rhs = BinaryVector::from_f32(&pattern);
        assert_eq!(lhs.hamming_distance(&rhs), 0);
    }

    #[test]
    fn test_hamming_distance_opposite() {
        let all_positive = BinaryVector::from_f32(&[1.0; 4]);
        let all_negative = BinaryVector::from_f32(&[-1.0; 4]);
        assert_eq!(all_positive.hamming_distance(&all_negative), 4);
    }

    #[test]
    fn test_hamming_distance_partial() {
        let lhs = BinaryVector::from_f32(&[1.0, 1.0, -1.0, -1.0]);
        let rhs = BinaryVector::from_f32(&[1.0, -1.0, 1.0, -1.0]);
        // Positions 1 and 2 disagree.
        assert_eq!(lhs.hamming_distance(&rhs), 2);
    }

    #[test]
    fn test_large_vector() {
        // 1024 dimensions is a common embedding size.
        let every_second = BinaryVector::from_f32(&signed_pattern(1024, |i| i % 2 == 0));
        let every_third = BinaryVector::from_f32(&signed_pattern(1024, |i| i % 3 == 0));

        // 1024 bits = 16 u64 words = 128 bytes.
        assert_eq!(every_second.size_bytes(), 128);
        assert_eq!(every_second.data.len(), 16);

        let dist = every_second.hamming_distance(&every_third);
        assert!((1..1024).contains(&dist));
    }

    #[test]
    fn test_binary_index_search() {
        let mut index = BinaryIndex::new(64);

        // Three reference vectors: all positive, all negative, half and half.
        index.add_f32(&[1.0f32; 64]);
        index.add_f32(&[-1.0f32; 64]);
        index.add_f32(&signed_pattern(64, |i| i < 32));

        // The query differs from the all-positive vector in its last 4 values.
        let query = signed_pattern(64, |i| i < 60);
        let results = index.search_f32(&query, 3);

        let (nearest_id, nearest_dist) = results[0];
        assert_eq!(nearest_id, 0);
        assert_eq!(nearest_dist, 4);
    }

    #[test]
    fn test_approx_cosine() {
        let positive = BinaryVector::from_f32(&[1.0; 128]);

        // Identical vectors map to a similarity of 1.0.
        let same = positive.approx_cosine_similarity(&BinaryVector::from_f32(&[1.0; 128]));
        assert!((same - 1.0).abs() < 0.001);

        // Fully opposite vectors map to -1.0.
        let opposite = positive.approx_cosine_similarity(&BinaryVector::from_f32(&[-1.0; 128]));
        assert!((opposite - (-1.0)).abs() < 0.001);
    }

    #[test]
    fn test_compression_ratio() {
        // fp32 needs 1024 * 4 = 4096 bytes, the packed form 1024 / 8 = 128
        // bytes, which is a 32x reduction.
        let fp32_size = 1024 * 4;
        let binary_size = BinaryVector::from_f32(&vec![1.0; 1024]).size_bytes();

        assert_eq!(binary_size, 128);
        assert_eq!(fp32_size / binary_size, 32);
    }
}
487}