// reddb_server/storage/engine/pq.rs
//! Product Quantization (PQ) for Vector Compression
//!
//! Compresses high-dimensional vectors by splitting into sub-vectors
//! and quantizing each independently to a codebook.
//!
//! # Design
//!
//! - Split D-dimensional vector into M sub-vectors of D/M dimensions
//! - Each sub-vector is quantized to one of K centroids (codebook)
//! - Storage: M bytes per vector (with K=256)
//! - Distance: Asymmetric distance computation using lookup tables
//!
//! # Example
//!
//! ```ignore
//! let mut pq = ProductQuantizer::new(PQConfig {
//!     dimension: 384,
//!     n_subvectors: 48,      // 48 sub-vectors of 8 dimensions each
//!     n_centroids: 256,      // 8-bit codes
//!     max_iterations: 25,    // k-means training budget
//! });
//!
//! // Train on sample vectors
//! pq.train(&training_vectors);
//!
//! // Encode vectors
//! let codes = pq.encode(&vector);
//!
//! // Compute distance using lookup table
//! let distances = pq.compute_distances(&query, &code_database);
//! ```
use std::collections::HashMap;

use super::distance::{cmp_f32, l2_squared_simd};
use super::hnsw::NodeId;
/// PQ configuration.
///
/// Invariant: `dimension` must be divisible by `n_subvectors`; this is
/// enforced by [`PQConfig::new`], not by the struct itself, so literal
/// construction bypasses the check.
#[derive(Clone, Debug)]
pub struct PQConfig {
    /// Full vector dimension (D)
    pub dimension: usize,
    /// Number of sub-vectors (M); one byte of code per sub-vector
    pub n_subvectors: usize,
    /// Number of centroids per sub-vector (K, typically 256 so a code fits in a u8)
    pub n_centroids: usize,
    /// Maximum k-means iterations during codebook training
    pub max_iterations: usize,
}
49
50impl Default for PQConfig {
51    fn default() -> Self {
52        Self {
53            dimension: 128,
54            n_subvectors: 8,
55            n_centroids: 256,
56            max_iterations: 25,
57        }
58    }
59}
60
61impl PQConfig {
62    pub fn new(dimension: usize, n_subvectors: usize) -> Self {
63        assert!(
64            dimension.is_multiple_of(n_subvectors),
65            "dimension must be divisible by n_subvectors"
66        );
67        Self {
68            dimension,
69            n_subvectors,
70            n_centroids: 256,
71            max_iterations: 25,
72        }
73    }
74
75    /// Get sub-vector dimension
76    pub fn subvector_dim(&self) -> usize {
77        self.dimension / self.n_subvectors
78    }
79}
80
/// Codebook for a single sub-vector position.
///
/// Holds K centroids of `dim` floats each, fitted by k-means in `train`.
#[derive(Clone)]
struct Codebook {
    /// Centroids for this sub-vector [n_centroids x subvector_dim]
    centroids: Vec<Vec<f32>>,
    /// Sub-vector dimension (D / M)
    dim: usize,
}
89
90impl Codebook {
91    fn new(dim: usize, n_centroids: usize) -> Self {
92        Self {
93            centroids: vec![vec![0.0; dim]; n_centroids],
94            dim,
95        }
96    }
97
98    /// Train codebook using k-means on sub-vectors
99    fn train(&mut self, subvectors: &[Vec<f32>], max_iterations: usize) {
100        if subvectors.is_empty() {
101            return;
102        }
103
104        let k = self.centroids.len();
105
106        // Initialize centroids using sampling
107        let step = subvectors.len().max(1) / k.max(1);
108        for (i, centroid) in self.centroids.iter_mut().enumerate() {
109            let idx = (i * step).min(subvectors.len() - 1);
110            *centroid = subvectors[idx].clone();
111        }
112
113        // Run k-means
114        for _ in 0..max_iterations {
115            // Assign to nearest centroid
116            let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
117            for (i, sv) in subvectors.iter().enumerate() {
118                let nearest = self.find_nearest(sv);
119                assignments[nearest].push(i);
120            }
121
122            // Update centroids
123            let mut converged = true;
124            for (ci, indices) in assignments.iter().enumerate() {
125                if indices.is_empty() {
126                    continue;
127                }
128
129                let mut new_centroid = vec![0.0f32; self.dim];
130                for &idx in indices {
131                    for (j, &val) in subvectors[idx].iter().enumerate() {
132                        new_centroid[j] += val;
133                    }
134                }
135                for val in &mut new_centroid {
136                    *val /= indices.len() as f32;
137                }
138
139                // Check convergence
140                let shift = l2_squared_simd(&new_centroid, &self.centroids[ci]).sqrt();
141                if shift > 1e-4 {
142                    converged = false;
143                }
144
145                self.centroids[ci] = new_centroid;
146            }
147
148            if converged {
149                break;
150            }
151        }
152    }
153
154    /// Find nearest centroid index
155    fn find_nearest(&self, subvector: &[f32]) -> usize {
156        self.centroids
157            .iter()
158            .enumerate()
159            .map(|(i, c)| (i, l2_squared_simd(subvector, c)))
160            .min_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)))
161            .map(|(i, _)| i)
162            .unwrap_or(0)
163    }
164
165    /// Compute distance lookup table for a query sub-vector
166    fn compute_distance_table(&self, query_subvector: &[f32]) -> Vec<f32> {
167        self.centroids
168            .iter()
169            .map(|c| l2_squared_simd(query_subvector, c))
170            .collect()
171    }
172}
173
/// PQ code for a single vector: one `u8` centroid index per sub-vector (M bytes).
pub type PQCode = Vec<u8>;
176
/// Product Quantizer: compresses D-dim f32 vectors to M-byte codes.
///
/// Must be trained via `train` before codes are meaningful; an untrained
/// quantizer has all-zero centroids, so every vector encodes to code 0.
pub struct ProductQuantizer {
    /// Quantizer configuration (dimension, M, K, training budget)
    config: PQConfig,
    /// Codebooks for each sub-vector [M codebooks]
    codebooks: Vec<Codebook>,
    /// Whether `train` has completed on non-empty data
    trained: bool,
}
185
186impl ProductQuantizer {
187    /// Create a new product quantizer
188    pub fn new(config: PQConfig) -> Self {
189        let subdim = config.subvector_dim();
190        let codebooks = (0..config.n_subvectors)
191            .map(|_| Codebook::new(subdim, config.n_centroids))
192            .collect();
193
194        Self {
195            config,
196            codebooks,
197            trained: false,
198        }
199    }
200
201    /// Create with default config for dimension
202    pub fn with_dimension(dimension: usize) -> Self {
203        // Choose sensible defaults
204        let n_subvectors = if dimension >= 64 { 8 } else { 4 };
205        Self::new(PQConfig::new(dimension, n_subvectors))
206    }
207
208    /// Train the product quantizer
209    pub fn train(&mut self, vectors: &[Vec<f32>]) {
210        if vectors.is_empty() {
211            return;
212        }
213
214        let subdim = self.config.subvector_dim();
215
216        // Train each codebook independently
217        for (m, codebook) in self.codebooks.iter_mut().enumerate() {
218            // Extract sub-vectors for this codebook
219            let subvectors: Vec<Vec<f32>> = vectors
220                .iter()
221                .map(|v| v[m * subdim..(m + 1) * subdim].to_vec())
222                .collect();
223
224            codebook.train(&subvectors, self.config.max_iterations);
225        }
226
227        self.trained = true;
228    }
229
230    /// Encode a single vector to PQ codes
231    pub fn encode(&self, vector: &[f32]) -> PQCode {
232        let subdim = self.config.subvector_dim();
233
234        self.codebooks
235            .iter()
236            .enumerate()
237            .map(|(m, codebook)| {
238                let subvector = &vector[m * subdim..(m + 1) * subdim];
239                codebook.find_nearest(subvector) as u8
240            })
241            .collect()
242    }
243
244    /// Encode multiple vectors
245    pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<PQCode> {
246        vectors.iter().map(|v| self.encode(v)).collect()
247    }
248
249    /// Decode PQ codes back to approximate vector
250    pub fn decode(&self, code: &PQCode) -> Vec<f32> {
251        let subdim = self.config.subvector_dim();
252        let mut vector = Vec::with_capacity(self.config.dimension);
253
254        for (m, &c) in code.iter().enumerate() {
255            let centroid = &self.codebooks[m].centroids[c as usize];
256            vector.extend_from_slice(centroid);
257        }
258
259        vector
260    }
261
262    /// Compute asymmetric distances from query to all codes
263    pub fn compute_distances(&self, query: &[f32], codes: &[PQCode]) -> Vec<f32> {
264        // Build distance lookup tables for each sub-vector
265        let subdim = self.config.subvector_dim();
266        let tables: Vec<Vec<f32>> = self
267            .codebooks
268            .iter()
269            .enumerate()
270            .map(|(m, codebook)| {
271                let subquery = &query[m * subdim..(m + 1) * subdim];
272                codebook.compute_distance_table(subquery)
273            })
274            .collect();
275
276        // Compute distance for each code using table lookups
277        codes
278            .iter()
279            .map(|code| {
280                code.iter()
281                    .enumerate()
282                    .map(|(m, &c)| tables[m][c as usize])
283                    .sum::<f32>()
284                    .sqrt()
285            })
286            .collect()
287    }
288
289    /// Get compression ratio
290    pub fn compression_ratio(&self) -> f32 {
291        let original_bytes = self.config.dimension * 4; // f32 = 4 bytes
292        let compressed_bytes = self.config.n_subvectors; // 1 byte per subvector
293        original_bytes as f32 / compressed_bytes as f32
294    }
295
296    /// Get configuration
297    pub fn config(&self) -> &PQConfig {
298        &self.config
299    }
300
301    /// Check if trained
302    pub fn is_trained(&self) -> bool {
303        self.trained
304    }
305}
306
/// PQ-based index for compressed vector search (brute-force over codes).
pub struct PQIndex {
    /// Product quantizer used for encoding and distance tables
    pq: ProductQuantizer,
    /// Stored codes, parallel to `ids`
    codes: Vec<PQCode>,
    /// External ID for each stored code (same order as `codes`)
    ids: Vec<NodeId>,
    /// Reverse mapping from external ID to index into `codes`/`ids`
    id_to_idx: HashMap<NodeId, usize>,
    /// Original vectors, kept only when reranking is enabled (`with_originals`)
    originals: Option<Vec<Vec<f32>>>,
    /// Next ID handed out by `add`
    next_id: NodeId,
}
322
323impl PQIndex {
324    /// Create a new PQ index
325    pub fn new(config: PQConfig) -> Self {
326        Self {
327            pq: ProductQuantizer::new(config),
328            codes: Vec::new(),
329            ids: Vec::new(),
330            id_to_idx: HashMap::new(),
331            originals: None,
332            next_id: 0,
333        }
334    }
335
336    /// Create and enable original vector storage for reranking
337    pub fn with_originals(mut self) -> Self {
338        self.originals = Some(Vec::new());
339        self
340    }
341
342    /// Train the index
343    pub fn train(&mut self, vectors: &[Vec<f32>]) {
344        self.pq.train(vectors);
345    }
346
347    /// Add a vector
348    pub fn add(&mut self, vector: Vec<f32>) -> NodeId {
349        let id = self.next_id;
350        self.next_id += 1;
351        self.add_with_id(id, vector);
352        id
353    }
354
355    /// Add a vector with ID
356    pub fn add_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
357        let code = self.pq.encode(&vector);
358        let idx = self.codes.len();
359
360        self.codes.push(code);
361        self.ids.push(id);
362        self.id_to_idx.insert(id, idx);
363
364        if let Some(ref mut originals) = self.originals {
365            originals.push(vector);
366        }
367    }
368
369    /// Add multiple vectors
370    pub fn add_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
371        vectors.into_iter().map(|v| self.add(v)).collect()
372    }
373
374    /// Search for k nearest neighbors
375    pub fn search(&self, query: &[f32], k: usize) -> Vec<(NodeId, f32)> {
376        if self.codes.is_empty() {
377            return Vec::new();
378        }
379
380        let distances = self.pq.compute_distances(query, &self.codes);
381
382        let mut results: Vec<(usize, f32)> = distances.into_iter().enumerate().collect();
383
384        results.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
385        results.truncate(k);
386
387        results
388            .into_iter()
389            .map(|(idx, dist)| (self.ids[idx], dist))
390            .collect()
391    }
392
393    /// Search with reranking using original vectors
394    pub fn search_rerank(&self, query: &[f32], k: usize, rerank_k: usize) -> Vec<(NodeId, f32)> {
395        let originals = match &self.originals {
396            Some(o) => o,
397            None => return self.search(query, k),
398        };
399
400        // Get more candidates for reranking
401        let candidates = self.search(query, rerank_k);
402
403        // Rerank using original vectors
404        let mut reranked: Vec<(NodeId, f32)> = candidates
405            .into_iter()
406            .map(|(id, _)| {
407                let idx = self.id_to_idx[&id];
408                let dist = l2_squared_simd(query, &originals[idx]).sqrt();
409                (id, dist)
410            })
411            .collect();
412
413        reranked.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
414        reranked.truncate(k);
415        reranked
416    }
417
418    /// Get number of vectors
419    pub fn len(&self) -> usize {
420        self.codes.len()
421    }
422
423    /// Check if empty
424    pub fn is_empty(&self) -> bool {
425        self.codes.is_empty()
426    }
427
428    /// Get compression ratio
429    pub fn compression_ratio(&self) -> f32 {
430        self.pq.compression_ratio()
431    }
432
433    /// Memory usage in bytes
434    pub fn memory_usage(&self) -> usize {
435        let code_bytes = self.codes.len() * self.pq.config.n_subvectors;
436        let original_bytes = self
437            .originals
438            .as_ref()
439            .map(|o| o.len() * self.pq.config.dimension * 4)
440            .unwrap_or(0);
441        code_bytes + original_bytes
442    }
443}
444
// ============================================================================
// Tests
// ============================================================================
448
449#[cfg(test)]
450mod tests {
451    use super::*;
452
453    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
454        (0..dim)
455            .map(|i| ((seed * 1103515245 + i as u64 * 12345) % 1000) as f32 / 1000.0)
456            .collect()
457    }
458
459    #[test]
460    fn test_pq_encode_decode() {
461        let config = PQConfig::new(16, 4);
462        let mut pq = ProductQuantizer::new(config);
463
464        // Generate training data
465        let training: Vec<Vec<f32>> = (0..100).map(|i| random_vector(16, i)).collect();
466
467        pq.train(&training);
468        assert!(pq.is_trained());
469
470        // Encode and decode
471        let original = random_vector(16, 999);
472        let code = pq.encode(&original);
473        let decoded = pq.decode(&code);
474
475        assert_eq!(code.len(), 4); // M sub-vectors
476        assert_eq!(decoded.len(), 16);
477
478        // Decoded should be an approximation (not exact)
479        let reconstruction_error: f32 = original
480            .iter()
481            .zip(decoded.iter())
482            .map(|(a, b)| (a - b).powi(2))
483            .sum();
484        assert!(reconstruction_error < 1.0); // Should be close
485    }
486
487    #[test]
488    fn test_pq_compression_ratio() {
489        let pq = ProductQuantizer::new(PQConfig::new(128, 8));
490        // 128 floats * 4 bytes = 512 bytes original
491        // 8 sub-vectors * 1 byte = 8 bytes compressed
492        // Ratio = 512 / 8 = 64x
493        assert_eq!(pq.compression_ratio(), 64.0);
494    }
495
496    #[test]
497    fn test_pq_index_search() {
498        let mut index = PQIndex::new(PQConfig::new(8, 4));
499
500        // Training data
501        let training: Vec<Vec<f32>> = (0..50).map(|i| random_vector(8, i)).collect();
502
503        index.train(&training);
504
505        // Add vectors
506        for (i, v) in training.iter().enumerate() {
507            index.add_with_id(i as u64, v.clone());
508        }
509
510        // Search
511        let query = random_vector(8, 0);
512        let results = index.search(&query, 5);
513
514        assert_eq!(results.len(), 5);
515        // First result should be the query itself (ID 0)
516        assert_eq!(results[0].0, 0);
517    }
518
519    #[test]
520    fn test_pq_distance_tables() {
521        let config = PQConfig::new(8, 2);
522        let mut pq = ProductQuantizer::new(config);
523
524        let training: Vec<Vec<f32>> = vec![
525            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
526            vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
527        ];
528
529        pq.train(&training);
530
531        let query = vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
532        let codes = pq.encode_batch(&training);
533        let distances = pq.compute_distances(&query, &codes);
534
535        assert_eq!(distances.len(), 2);
536        // Distances should be approximately equal (equidistant from query)
537        assert!((distances[0] - distances[1]).abs() < 0.1);
538    }
539}