Skip to main content

engine/
pq.rs

//! Product Quantization (PQ) for vector compression
//!
//! PQ compresses high-dimensional vectors by:
//! 1. Splitting vectors into M subvectors
//! 2. Training a codebook (K centroids) for each subspace
//! 3. Encoding each subvector as an index into its codebook
//! 4. Using lookup tables for fast distance computation (ADC)
8
9use std::collections::HashMap;
10
11use parking_lot::RwLock;
12use rand::seq::SliceRandom;
13use serde::{Deserialize, Serialize};
14
15use common::{DistanceMetric, Vector, VectorId};
16
17/// Configuration for Product Quantization
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct PQConfig {
20    /// Number of subquantizers (subspaces)
21    /// Vector dimension must be divisible by this
22    pub num_subquantizers: usize,
23    /// Number of centroids per subquantizer (typically 256 for 8-bit codes)
24    pub num_centroids: usize,
25    /// Number of k-means iterations for training
26    pub kmeans_iterations: usize,
27    /// Distance metric
28    pub distance_metric: DistanceMetric,
29}
30
31impl Default for PQConfig {
32    fn default() -> Self {
33        Self {
34            num_subquantizers: 8,
35            num_centroids: 256,
36            kmeans_iterations: 20,
37            distance_metric: DistanceMetric::Euclidean,
38        }
39    }
40}
41
42/// A trained Product Quantizer
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ProductQuantizer {
45    /// Configuration
46    pub config: PQConfig,
47    /// Codebooks: [num_subquantizers][num_centroids][subvector_dim]
48    pub codebooks: Vec<Vec<Vec<f32>>>,
49    /// Dimension of full vector
50    pub dimension: usize,
51    /// Dimension of each subvector
52    pub subvector_dim: usize,
53}
54
55impl ProductQuantizer {
56    /// Create a new untrained quantizer
57    pub fn new(config: PQConfig, dimension: usize) -> Result<Self, String> {
58        if !dimension.is_multiple_of(config.num_subquantizers) {
59            return Err(format!(
60                "Dimension {} not divisible by num_subquantizers {}",
61                dimension, config.num_subquantizers
62            ));
63        }
64
65        let subvector_dim = dimension / config.num_subquantizers;
66
67        Ok(Self {
68            config,
69            codebooks: Vec::new(),
70            dimension,
71            subvector_dim,
72        })
73    }
74
75    /// Train the quantizer on a set of vectors
76    pub fn train(&mut self, vectors: &[Vector]) -> Result<(), String> {
77        if vectors.is_empty() {
78            return Err("Cannot train on empty vectors".to_string());
79        }
80
81        if vectors[0].values.len() != self.dimension {
82            return Err(format!(
83                "Vector dimension {} doesn't match expected {}",
84                vectors[0].values.len(),
85                self.dimension
86            ));
87        }
88
89        let m = self.config.num_subquantizers;
90        let k = self.config.num_centroids;
91        let d = self.subvector_dim;
92
93        self.codebooks = Vec::with_capacity(m);
94
95        // Train a codebook for each subspace
96        for subspace_idx in 0..m {
97            let start = subspace_idx * d;
98            let end = start + d;
99
100            // Extract subvectors for this subspace
101            let subvectors: Vec<Vec<f32>> = vectors
102                .iter()
103                .map(|v| v.values[start..end].to_vec())
104                .collect();
105
106            // Run k-means on subvectors
107            let codebook = self.train_kmeans(&subvectors, k)?;
108            self.codebooks.push(codebook);
109        }
110
111        Ok(())
112    }
113
114    /// Train k-means for a single subspace
115    fn train_kmeans(&self, subvectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>, String> {
116        if subvectors.is_empty() {
117            return Err("Cannot train k-means on empty subvectors".to_string());
118        }
119        let actual_k = k.min(subvectors.len());
120        let dim = subvectors[0].len();
121
122        // Initialize centroids with k-means++
123        let mut centroids = self.kmeans_plus_plus(subvectors, actual_k);
124
125        // Run k-means iterations
126        for _ in 0..self.config.kmeans_iterations {
127            // Assign subvectors to nearest centroid
128            let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); actual_k];
129            for (i, subvec) in subvectors.iter().enumerate() {
130                let nearest = self.find_nearest_centroid(subvec, &centroids);
131                assignments[nearest].push(i);
132            }
133
134            // Update centroids
135            for (c_idx, assigned) in assignments.iter().enumerate() {
136                if assigned.is_empty() {
137                    continue;
138                }
139
140                let mut new_centroid = vec![0.0f32; dim];
141                for &vec_idx in assigned {
142                    for (j, &val) in subvectors[vec_idx].iter().enumerate() {
143                        new_centroid[j] += val;
144                    }
145                }
146
147                let count = assigned.len() as f32;
148                for val in &mut new_centroid {
149                    *val /= count;
150                }
151
152                centroids[c_idx] = new_centroid;
153            }
154        }
155
156        Ok(centroids)
157    }
158
159    /// K-means++ initialization
160    fn kmeans_plus_plus(&self, subvectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
161        let mut rng = rand::thread_rng();
162        let mut centroids = Vec::with_capacity(k);
163
164        // First centroid: random
165        if let Some(first) = subvectors.choose(&mut rng) {
166            centroids.push(first.clone());
167        } else {
168            return centroids;
169        }
170
171        // Remaining centroids: weighted by distance squared
172        for _ in 1..k {
173            let distances: Vec<f32> = subvectors
174                .iter()
175                .map(|v| {
176                    centroids
177                        .iter()
178                        .map(|c| self.squared_distance(v, c))
179                        .fold(f32::MAX, f32::min)
180                })
181                .collect();
182
183            let total: f32 = distances.iter().sum();
184            if total == 0.0 {
185                break;
186            }
187
188            let threshold: f32 = rand::random::<f32>() * total;
189            let mut cumsum = 0.0;
190
191            for (i, &d) in distances.iter().enumerate() {
192                cumsum += d;
193                if cumsum >= threshold {
194                    centroids.push(subvectors[i].clone());
195                    break;
196                }
197            }
198        }
199
200        centroids
201    }
202
203    /// Find nearest centroid index
204    fn find_nearest_centroid(&self, subvec: &[f32], centroids: &[Vec<f32>]) -> usize {
205        let mut best_idx = 0;
206        let mut best_dist = f32::MAX;
207
208        for (i, centroid) in centroids.iter().enumerate() {
209            let dist = self.squared_distance(subvec, centroid);
210            if dist < best_dist {
211                best_dist = dist;
212                best_idx = i;
213            }
214        }
215
216        best_idx
217    }
218
219    /// Squared Euclidean distance
220    #[inline]
221    fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
222        a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
223    }
224
225    /// Check if the quantizer is trained
226    pub fn is_trained(&self) -> bool {
227        !self.codebooks.is_empty()
228    }
229
230    /// Encode a vector into PQ codes
231    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>, String> {
232        if !self.is_trained() {
233            return Err("Quantizer not trained".to_string());
234        }
235
236        if vector.len() != self.dimension {
237            return Err(format!(
238                "Vector dimension {} doesn't match expected {}",
239                vector.len(),
240                self.dimension
241            ));
242        }
243
244        let m = self.config.num_subquantizers;
245        let d = self.subvector_dim;
246        let mut codes = Vec::with_capacity(m);
247
248        for subspace_idx in 0..m {
249            let start = subspace_idx * d;
250            let end = start + d;
251            let subvec = &vector[start..end];
252
253            let nearest = self.find_nearest_centroid(subvec, &self.codebooks[subspace_idx]);
254            codes.push(nearest as u8);
255        }
256
257        Ok(codes)
258    }
259
260    /// Decode PQ codes back to approximate vector
261    pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>, String> {
262        if !self.is_trained() {
263            return Err("Quantizer not trained".to_string());
264        }
265
266        if codes.len() != self.config.num_subquantizers {
267            return Err(format!(
268                "Code length {} doesn't match num_subquantizers {}",
269                codes.len(),
270                self.config.num_subquantizers
271            ));
272        }
273
274        let mut vector = Vec::with_capacity(self.dimension);
275
276        for (subspace_idx, &code) in codes.iter().enumerate() {
277            let centroid = &self.codebooks[subspace_idx][code as usize];
278            vector.extend_from_slice(centroid);
279        }
280
281        Ok(vector)
282    }
283
284    /// Compute distance lookup table for a query vector
285    /// Returns [num_subquantizers][num_centroids] distances
286    pub fn compute_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, String> {
287        if !self.is_trained() {
288            return Err("Quantizer not trained".to_string());
289        }
290
291        if query.len() != self.dimension {
292            return Err(format!(
293                "Query dimension {} doesn't match expected {}",
294                query.len(),
295                self.dimension
296            ));
297        }
298
299        let m = self.config.num_subquantizers;
300        let k = self.config.num_centroids;
301        let d = self.subvector_dim;
302
303        let mut table = Vec::with_capacity(m);
304
305        for subspace_idx in 0..m {
306            let start = subspace_idx * d;
307            let end = start + d;
308            let query_subvec = &query[start..end];
309
310            let mut distances = Vec::with_capacity(k);
311            for centroid in &self.codebooks[subspace_idx] {
312                let dist = match self.config.distance_metric {
313                    DistanceMetric::Euclidean => {
314                        -self.squared_distance(query_subvec, centroid).sqrt()
315                    }
316                    DistanceMetric::Cosine => self.cosine_sim(query_subvec, centroid),
317                    DistanceMetric::DotProduct => self.dot_product(query_subvec, centroid),
318                };
319                distances.push(dist);
320            }
321
322            table.push(distances);
323        }
324
325        Ok(table)
326    }
327
328    /// Compute distance using precomputed table (ADC - Asymmetric Distance Computation)
329    #[inline]
330    pub fn compute_distance_adc(&self, table: &[Vec<f32>], codes: &[u8]) -> f32 {
331        let mut total = 0.0f32;
332        for (subspace_idx, &code) in codes.iter().enumerate() {
333            total += table[subspace_idx][code as usize];
334        }
335        total
336    }
337
338    #[inline]
339    fn cosine_sim(&self, a: &[f32], b: &[f32]) -> f32 {
340        let mut dot = 0.0f32;
341        let mut norm_a = 0.0f32;
342        let mut norm_b = 0.0f32;
343
344        for (x, y) in a.iter().zip(b.iter()) {
345            dot += x * y;
346            norm_a += x * x;
347            norm_b += y * y;
348        }
349
350        let norm_a = norm_a.sqrt();
351        let norm_b = norm_b.sqrt();
352
353        if norm_a == 0.0 || norm_b == 0.0 {
354            0.0
355        } else {
356            dot / (norm_a * norm_b)
357        }
358    }
359
360    #[inline]
361    fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
362        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
363    }
364}
365
366/// PQ-based index for compressed vector search
367pub struct PQIndex {
368    /// Product quantizer
369    quantizer: RwLock<ProductQuantizer>,
370    /// Encoded vectors: vector_id -> codes
371    encoded_vectors: RwLock<HashMap<VectorId, Vec<u8>>>,
372    /// Original vectors (optional, for reconstruction)
373    original_vectors: RwLock<HashMap<VectorId, Vector>>,
374    /// Store original vectors for reconstruction
375    store_originals: bool,
376}
377
378/// Search result from PQ index
379#[derive(Debug, Clone)]
380pub struct PQSearchResult {
381    pub id: VectorId,
382    pub score: f32,
383    pub vector: Option<Vector>,
384}
385
386impl PQIndex {
387    /// Create a new PQ index
388    pub fn new(config: PQConfig, dimension: usize, store_originals: bool) -> Result<Self, String> {
389        let quantizer = ProductQuantizer::new(config, dimension)?;
390
391        Ok(Self {
392            quantizer: RwLock::new(quantizer),
393            encoded_vectors: RwLock::new(HashMap::new()),
394            original_vectors: RwLock::new(HashMap::new()),
395            store_originals,
396        })
397    }
398
399    /// Train the index on vectors
400    pub fn train(&self, vectors: &[Vector]) -> Result<(), String> {
401        let mut quantizer = self.quantizer.write();
402        quantizer.train(vectors)
403    }
404
405    /// Check if trained
406    pub fn is_trained(&self) -> bool {
407        self.quantizer.read().is_trained()
408    }
409
410    /// Add vectors to the index
411    pub fn add(&self, vectors: Vec<Vector>) -> Result<usize, String> {
412        let quantizer = self.quantizer.read();
413        if !quantizer.is_trained() {
414            return Err("Index not trained".to_string());
415        }
416
417        let mut encoded = self.encoded_vectors.write();
418        let mut originals = self.original_vectors.write();
419        let mut count = 0;
420
421        for vector in vectors {
422            let codes = quantizer.encode(&vector.values)?;
423            encoded.insert(vector.id.clone(), codes);
424
425            if self.store_originals {
426                originals.insert(vector.id.clone(), vector);
427            }
428
429            count += 1;
430        }
431
432        Ok(count)
433    }
434
435    /// Remove vectors from the index
436    pub fn remove(&self, ids: &[VectorId]) -> usize {
437        let mut encoded = self.encoded_vectors.write();
438        let mut originals = self.original_vectors.write();
439        let mut count = 0;
440
441        for id in ids {
442            if encoded.remove(id).is_some() {
443                count += 1;
444            }
445            originals.remove(id);
446        }
447
448        count
449    }
450
451    /// Search for nearest neighbors
452    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<PQSearchResult>, String> {
453        let quantizer = self.quantizer.read();
454        if !quantizer.is_trained() {
455            return Err("Index not trained".to_string());
456        }
457
458        // Compute distance lookup table
459        let table = quantizer.compute_distance_table(query)?;
460
461        let encoded = self.encoded_vectors.read();
462        let originals = self.original_vectors.read();
463
464        // Compute distances using ADC
465        let mut results: Vec<PQSearchResult> = encoded
466            .iter()
467            .map(|(id, codes)| {
468                let score = quantizer.compute_distance_adc(&table, codes);
469                let vector = originals.get(id).cloned();
470
471                PQSearchResult {
472                    id: id.clone(),
473                    score,
474                    vector,
475                }
476            })
477            .collect();
478
479        // Sort by score descending (higher = more similar)
480        results.sort_by(|a, b| {
481            b.score
482                .partial_cmp(&a.score)
483                .unwrap_or(std::cmp::Ordering::Equal)
484        });
485        results.truncate(k);
486
487        Ok(results)
488    }
489
490    /// Get number of indexed vectors
491    pub fn len(&self) -> usize {
492        self.encoded_vectors.read().len()
493    }
494
495    /// Check if empty
496    pub fn is_empty(&self) -> bool {
497        self.encoded_vectors.read().is_empty()
498    }
499
500    /// Get compression ratio
501    pub fn compression_ratio(&self) -> f32 {
502        let quantizer = self.quantizer.read();
503        let original_size = quantizer.dimension * 4; // 4 bytes per f32
504        let compressed_size = quantizer.config.num_subquantizers; // 1 byte per code
505        original_size as f32 / compressed_size as f32
506    }
507
508    /// Decode a vector from its codes
509    pub fn decode(&self, id: &VectorId) -> Result<Vec<f32>, String> {
510        let quantizer = self.quantizer.read();
511        let encoded = self.encoded_vectors.read();
512
513        let codes = encoded
514            .get(id)
515            .ok_or_else(|| format!("Vector {} not found", id))?;
516
517        quantizer.decode(codes)
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: `n` vectors of length `dim` sampled from a sine
    /// wave, with ids `v0`, `v1`, ...
    fn test_vectors(n: usize, dim: usize) -> Vec<Vector> {
        let mut fixtures = Vec::with_capacity(n);
        for i in 0..n {
            let values = (0..dim).map(|j| ((i + j) as f32 * 0.1).sin()).collect();
            fixtures.push(Vector {
                id: format!("v{}", i),
                values,
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            });
        }
        fixtures
    }

    /// Configuration shared by most tests: 4 subquantizers x 16 centroids,
    /// everything else at its default.
    fn small_config() -> PQConfig {
        PQConfig {
            num_subquantizers: 4,
            num_centroids: 16,
            ..Default::default()
        }
    }

    /// Quantizer built from `small_config` and trained on 100 fixture
    /// vectors of dimension 32; returns the training vectors as well.
    fn trained_small_pq() -> (ProductQuantizer, Vec<Vector>) {
        let mut pq = ProductQuantizer::new(small_config(), 32).unwrap();
        let vectors = test_vectors(100, 32);
        pq.train(&vectors).unwrap();
        (pq, vectors)
    }

    #[test]
    fn test_pq_config_validation() {
        let config = PQConfig {
            num_subquantizers: 8,
            ..Default::default()
        };

        // 64 is a multiple of 8, so construction succeeds
        let divisible = ProductQuantizer::new(config.clone(), 64);
        assert!(divisible.is_ok());

        // 65 is not, so construction is rejected
        let indivisible = ProductQuantizer::new(config, 65);
        assert!(indivisible.is_err());
    }

    #[test]
    fn test_pq_train() {
        let config = PQConfig {
            kmeans_iterations: 10,
            ..small_config()
        };

        let mut pq = ProductQuantizer::new(config, 32).unwrap();
        let training_set = test_vectors(100, 32);

        assert!(!pq.is_trained());
        pq.train(&training_set).unwrap();
        assert!(pq.is_trained());

        // Codebook shape: 4 subspaces x 16 centroids x (32 / 4 = 8) components
        assert_eq!(pq.codebooks.len(), 4);
        assert_eq!(pq.codebooks[0].len(), 16);
        assert_eq!(pq.codebooks[0][0].len(), 8);
    }

    #[test]
    fn test_pq_encode_decode() {
        let (pq, vectors) = trained_small_pq();

        // Round-trip the first training vector
        let original = &vectors[0].values;
        let codes = pq.encode(original).unwrap();
        assert_eq!(codes.len(), 4);

        let decoded = pq.decode(&codes).unwrap();
        assert_eq!(decoded.len(), 32);

        // The reconstruction is lossy, so only require the L2 error to stay
        // within a generous bound.
        let squared_error: f32 = original
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>();
        let error = squared_error.sqrt();

        assert!(error < 5.0, "Quantization error too high: {}", error);
    }

    #[test]
    fn test_pq_distance_table() {
        let (pq, vectors) = trained_small_pq();

        let table = pq.compute_distance_table(&vectors[0].values).unwrap();

        // One row per subquantizer, one entry per centroid
        assert_eq!(table.len(), 4);
        assert_eq!(table[0].len(), 16);
    }

    #[test]
    fn test_pq_adc() {
        let (pq, vectors) = trained_small_pq();

        let query = &vectors[50].values;
        let table = pq.compute_distance_table(query).unwrap();

        // Score the query against its own encoding
        let codes = pq.encode(query).unwrap();
        let dist = pq.compute_distance_adc(&table, &codes);

        // Euclidean scores are negated distances. Quantization error keeps
        // the self-score from being exactly 0, hence the lenient threshold.
        assert!(
            dist > -3.0,
            "Self-distance should be relatively small, got {}",
            dist
        );
    }

    #[test]
    fn test_pq_index_basic() {
        let index = PQIndex::new(small_config(), 32, true).unwrap();
        let vectors = test_vectors(100, 32);

        index.train(&vectors).unwrap();
        assert!(index.is_trained());

        let added = index.add(vectors.clone()).unwrap();
        assert_eq!(added, 100);
        assert_eq!(index.len(), 100);
    }

    #[test]
    fn test_pq_index_search() {
        let config = PQConfig {
            num_subquantizers: 4,
            num_centroids: 32,
            kmeans_iterations: 15,
            distance_metric: DistanceMetric::Euclidean,
        };

        let index = PQIndex::new(config, 32, true).unwrap();
        let vectors = test_vectors(200, 32);

        index.train(&vectors).unwrap();
        index.add(vectors.clone()).unwrap();

        // Query with a vector that is itself in the index
        let results = index.search(&vectors[100].values, 10).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 10);

        // Scores must be non-increasing from one hit to the next
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }

        // The query's own id should show up in the (approximate) top results
        let found = results.iter().any(|r| r.id == "v100");
        assert!(found, "Query vector not found in top results");
    }

    #[test]
    fn test_pq_index_remove() {
        let index = PQIndex::new(small_config(), 32, false).unwrap();
        let vectors = test_vectors(50, 32);

        index.train(&vectors).unwrap();
        index.add(vectors).unwrap();
        assert_eq!(index.len(), 50);

        let doomed = ["v0".to_string(), "v1".to_string()];
        let removed = index.remove(&doomed);
        assert_eq!(removed, 2);
        assert_eq!(index.len(), 48);
    }

    #[test]
    fn test_pq_compression_ratio() {
        let config = PQConfig {
            num_subquantizers: 8,
            num_centroids: 256,
            ..Default::default()
        };

        let index = PQIndex::new(config, 128, false).unwrap();

        // Uncompressed: 128 dimensions * 4 bytes = 512 bytes
        // Compressed:   8 subquantizers * 1 byte = 8 bytes
        // Expected ratio: 512 / 8 = 64x
        let ratio = index.compression_ratio();
        assert!((ratio - 64.0).abs() < 0.1);
    }

    #[test]
    fn test_pq_decode_from_index() {
        let index = PQIndex::new(small_config(), 32, false).unwrap();
        let vectors = test_vectors(50, 32);

        index.train(&vectors).unwrap();
        index.add(vectors).unwrap();

        // A stored id decodes to a full-length vector
        let decoded = index.decode(&"v10".to_string()).unwrap();
        assert_eq!(decoded.len(), 32);

        // An unknown id is an error
        let missing = index.decode(&"nonexistent".to_string());
        assert!(missing.is_err());
    }
}