memvid_core/vec_pq.rs

//! Product Quantization (PQ) for vector compression
//!
//! Compresses 384-dim f32 vectors from 1,536 bytes to 96 bytes (16x compression)
//! while maintaining ~95% search accuracy using codebook-based quantization.
//!
//! **Algorithm**:
//! 1. Split 384-dim vector into 96 subspaces of 4 dimensions each
//! 2. For each subspace, train 256 centroids using k-means
//! 3. Each vector is encoded as 96 bytes (one u8 index per subspace)
//! 4. Search uses ADC (Asymmetric Distance Computation) against the codebook centroids
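//!
//! A minimal end-to-end sketch (illustrative only, marked `ignore` so it is not
//! run as a doc-test; error handling is elided and `training_vectors` / `query`
//! are placeholders):
//!
//! ```ignore
//! let mut pq = ProductQuantizer::new(384)?;
//! pq.train(&training_vectors, 25)?;
//! let codes = pq.encode(&training_vectors[0])?; // 96 bytes per vector
//! let approx = pq.decode(&codes)?;              // lossy reconstruction
//! let dist = pq.asymmetric_distance(&query, &codes);
//! ```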

use blake3::hash;
use serde::{Deserialize, Serialize};

use crate::vec::VecSearchHit;
use crate::{MemvidError, Result, types::FrameId};

fn vec_config() -> impl bincode::config::Config {
    bincode::config::standard()
        .with_fixed_int_encoding()
        .with_little_endian()
}

const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;

/// Product Quantization parameters
const NUM_SUBSPACES: usize = 96; // 384 dims / 4 dims per subspace
const SUBSPACE_DIM: usize = 4; // Dimensions per subspace
const NUM_CENTROIDS: usize = 256; // 2^8 centroids (encoded as u8)
const TOTAL_DIM: usize = NUM_SUBSPACES * SUBSPACE_DIM; // 384
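
// Compile-time sanity checks (added for clarity): the layout above fixes
// 96 * 4 = 384 dimensions and 256 = 2^8 centroids per subspace, so a single
// u8 code per subspace is always sufficient.
const _: () = assert!(TOTAL_DIM == 384);
const _: () = assert!(NUM_CENTROIDS == 1 << 8);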

/// Codebook for one subspace: 256 centroids, each with SUBSPACE_DIM dimensions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubspaceCodebook {
    /// Flat array: centroids[i*SUBSPACE_DIM..(i+1)*SUBSPACE_DIM] is centroid i
    centroids: Vec<f32>,
}

impl SubspaceCodebook {
    fn new() -> Self {
        Self {
            centroids: vec![0.0; NUM_CENTROIDS * SUBSPACE_DIM],
        }
    }

    fn get_centroid(&self, index: u8) -> &[f32] {
        let start = (index as usize) * SUBSPACE_DIM;
        &self.centroids[start..start + SUBSPACE_DIM]
    }

    fn set_centroid(&mut self, index: u8, values: &[f32]) {
        assert_eq!(values.len(), SUBSPACE_DIM);
        let start = (index as usize) * SUBSPACE_DIM;
        self.centroids[start..start + SUBSPACE_DIM].copy_from_slice(values);
    }

    /// Find nearest centroid to a subspace vector
    fn quantize(&self, subspace: &[f32]) -> u8 {
        assert_eq!(subspace.len(), SUBSPACE_DIM);

        let mut best_idx = 0u8;
        let mut best_dist = f32::INFINITY;

        for i in 0..NUM_CENTROIDS {
            let centroid = self.get_centroid(i as u8);
            let dist = l2_distance_squared(subspace, centroid);
            if dist < best_dist {
                best_dist = dist;
                best_idx = i as u8;
            }
        }

        best_idx
    }
}

/// Product Quantizer with codebooks for all subspaces
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// One codebook per subspace
    codebooks: Vec<SubspaceCodebook>,
    dimension: u32,
}

impl ProductQuantizer {
    /// Create uninitialized quantizer
    pub fn new(dimension: u32) -> Result<Self> {
        if dimension as usize != TOTAL_DIM {
            return Err(MemvidError::InvalidQuery {
                reason: format!(
                    "PQ only supports {}-dim vectors, got {}",
                    TOTAL_DIM, dimension
                ),
            });
        }

        Ok(Self {
            codebooks: vec![SubspaceCodebook::new(); NUM_SUBSPACES],
            dimension,
        })
    }

    /// Train codebooks using k-means on sample vectors
    pub fn train(&mut self, training_vectors: &[Vec<f32>], max_iterations: usize) -> Result<()> {
        if training_vectors.is_empty() {
            return Err(MemvidError::InvalidQuery {
                reason: "Cannot train PQ with empty training set".to_string(),
            });
        }

        // Verify all vectors have correct dimension
        for vec in training_vectors {
            if vec.len() != TOTAL_DIM {
                return Err(MemvidError::InvalidQuery {
                    reason: format!(
                        "Training vector has wrong dimension: expected {}, got {}",
                        TOTAL_DIM,
                        vec.len()
                    ),
                });
            }
        }

        // Train each subspace independently
        for subspace_idx in 0..NUM_SUBSPACES {
            let start_dim = subspace_idx * SUBSPACE_DIM;
            let end_dim = start_dim + SUBSPACE_DIM;

            // Extract subspace vectors
            let subspace_vecs: Vec<Vec<f32>> = training_vectors
                .iter()
                .map(|v| v[start_dim..end_dim].to_vec())
                .collect();

            // Run k-means
            let centroids = kmeans(&subspace_vecs, NUM_CENTROIDS, max_iterations)?;

            // Store in codebook
            for (i, centroid) in centroids.iter().enumerate() {
                self.codebooks[subspace_idx].set_centroid(i as u8, centroid);
            }
        }

        Ok(())
    }

    /// Encode a vector into PQ codes (96 bytes)
    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
        if vector.len() != TOTAL_DIM {
            return Err(MemvidError::InvalidQuery {
                reason: format!(
                    "Vector dimension mismatch: expected {}, got {}",
                    TOTAL_DIM,
                    vector.len()
                ),
            });
        }

        let mut codes = Vec::with_capacity(NUM_SUBSPACES);

        for subspace_idx in 0..NUM_SUBSPACES {
            let start_dim = subspace_idx * SUBSPACE_DIM;
            let end_dim = start_dim + SUBSPACE_DIM;
            let subspace = &vector[start_dim..end_dim];

            let code = self.codebooks[subspace_idx].quantize(subspace);
            codes.push(code);
        }

        Ok(codes)
    }

    /// Decode PQ codes back to approximate vector (for debugging/verification)
    pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
        if codes.len() != NUM_SUBSPACES {
            return Err(MemvidError::InvalidQuery {
                reason: format!(
                    "Invalid PQ codes length: expected {}, got {}",
                    NUM_SUBSPACES,
                    codes.len()
                ),
            });
        }

        let mut vector = Vec::with_capacity(TOTAL_DIM);

        for (subspace_idx, &code) in codes.iter().enumerate() {
            let centroid = self.codebooks[subspace_idx].get_centroid(code);
            vector.extend_from_slice(centroid);
        }

        Ok(vector)
    }

    /// Compute the asymmetric distance between a raw query vector and a PQ-encoded vector.
    /// The squared distance is accumulated per subspace against the centroid named by each
    /// code; no per-query lookup table is materialized here.
    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        if query.len() != TOTAL_DIM || codes.len() != NUM_SUBSPACES {
            return f32::INFINITY;
        }

        let mut total_dist_sq = 0.0f32;

        for subspace_idx in 0..NUM_SUBSPACES {
            let start_dim = subspace_idx * SUBSPACE_DIM;
            let end_dim = start_dim + SUBSPACE_DIM;
            let query_subspace = &query[start_dim..end_dim];

            let code = codes[subspace_idx];
            let centroid = self.codebooks[subspace_idx].get_centroid(code);

            total_dist_sq += l2_distance_squared(query_subspace, centroid);
        }

        total_dist_sq.sqrt()
    }
}
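
// Illustrative sketch (not wired into `QuantizedVecIndex::search`): the classic
// ADC optimization precomputes, per query, a NUM_SUBSPACES x NUM_CENTROIDS table
// of squared distances so that every encoded vector is scored with 96 table
// lookups instead of 96 distance computations. The function names below are
// local to this sketch, not part of the public API.
#[allow(dead_code)]
fn build_adc_table(pq: &ProductQuantizer, query: &[f32]) -> Vec<f32> {
    debug_assert_eq!(query.len(), TOTAL_DIM);
    let mut table = vec![0.0f32; NUM_SUBSPACES * NUM_CENTROIDS];
    for s in 0..NUM_SUBSPACES {
        let query_subspace = &query[s * SUBSPACE_DIM..(s + 1) * SUBSPACE_DIM];
        for c in 0..NUM_CENTROIDS {
            let centroid = pq.codebooks[s].get_centroid(c as u8);
            table[s * NUM_CENTROIDS + c] = l2_distance_squared(query_subspace, centroid);
        }
    }
    table
}

// Scoring an encoded vector against the precomputed table: sum one entry per
// subspace and take the square root, matching `asymmetric_distance`.
#[allow(dead_code)]
fn adc_distance_from_table(table: &[f32], codes: &[u8]) -> f32 {
    codes
        .iter()
        .enumerate()
        .map(|(s, &code)| table[s * NUM_CENTROIDS + code as usize])
        .sum::<f32>()
        .sqrt()
}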

/// Compressed vector document
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVecDocument {
    pub frame_id: FrameId,
    /// PQ codes: 96 bytes (one u8 per subspace)
    pub codes: Vec<u8>,
}

/// Builder for compressed vector index
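///
/// Typical flow (sketch; error handling elided, and `samples`, `frame_id`, and
/// `embedding` are placeholders):
///
/// ```ignore
/// let mut builder = QuantizedVecIndexBuilder::new();
/// builder.train_quantizer(&samples, 384)?;
/// builder.add_document(frame_id, embedding)?;
/// let artifact = builder.finish()?;
/// ```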
#[derive(Default)]
pub struct QuantizedVecIndexBuilder {
    documents: Vec<QuantizedVecDocument>,
    quantizer: Option<ProductQuantizer>,
}

impl QuantizedVecIndexBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    /// Train quantizer on sample vectors before encoding
    pub fn train_quantizer(&mut self, training_vectors: &[Vec<f32>], dimension: u32) -> Result<()> {
        let mut pq = ProductQuantizer::new(dimension)?;
        pq.train(training_vectors, 25)?; // 25 k-means iterations
        self.quantizer = Some(pq);
        Ok(())
    }

    /// Add document with pre-trained quantizer
    pub fn add_document(&mut self, frame_id: FrameId, embedding: Vec<f32>) -> Result<()> {
        let quantizer = self
            .quantizer
            .as_ref()
            .ok_or_else(|| MemvidError::InvalidQuery {
                reason: "Quantizer not trained. Call train_quantizer first".to_string(),
            })?;

        let codes = quantizer.encode(&embedding)?;

        self.documents
            .push(QuantizedVecDocument { frame_id, codes });

        Ok(())
    }

    pub fn finish(self) -> Result<QuantizedVecIndexArtifact> {
        let quantizer = self.quantizer.ok_or_else(|| MemvidError::InvalidQuery {
            reason: "Quantizer not trained".to_string(),
        })?;

        let vector_count = self.documents.len() as u64;
        let bytes =
            bincode::serde::encode_to_vec(&(quantizer.clone(), self.documents), vec_config())?;
        let checksum = *hash(&bytes).as_bytes();

        Ok(QuantizedVecIndexArtifact {
            bytes,
            vector_count,
            dimension: quantizer.dimension,
            checksum,
            compression_ratio: 16.0, // per-vector: 1,536 raw bytes -> 96 code bytes (excludes shared codebook overhead)
        })
    }
}

#[derive(Debug, Clone)]
pub struct QuantizedVecIndexArtifact {
    pub bytes: Vec<u8>,
    pub vector_count: u64,
    pub dimension: u32,
    pub checksum: [u8; 32],
    pub compression_ratio: f64,
}

#[derive(Debug, Clone)]
pub struct QuantizedVecIndex {
    quantizer: ProductQuantizer,
    documents: Vec<QuantizedVecDocument>,
}

impl QuantizedVecIndex {
    pub fn decode(bytes: &[u8]) -> Result<Self> {
        // Try decoding with current format (with dimension field)
        let config = bincode::config::standard()
            .with_fixed_int_encoding()
            .with_little_endian()
            .with_limit::<VEC_DECODE_LIMIT>();

        if let Ok(((quantizer, documents), read)) = bincode::serde::decode_from_slice::<
            (ProductQuantizer, Vec<QuantizedVecDocument>),
            _,
        >(bytes, config.clone())
        {
            if read == bytes.len() {
                return Ok(Self {
                    quantizer,
                    documents,
                });
            }
        }

        // Fall back to old format (without dimension field)
        // Old ProductQuantizer struct without dimension field
        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
        struct OldProductQuantizer {
            codebooks: Vec<SubspaceCodebook>,
        }

        let ((old_quantizer, documents), read): (
            (OldProductQuantizer, Vec<QuantizedVecDocument>),
            usize,
        ) = bincode::serde::decode_from_slice(bytes, config)?;

        if read != bytes.len() {
            return Err(MemvidError::InvalidToc {
                reason: "unsupported quantized vector index encoding".into(),
            });
        }

        // Convert old format to new format
        let quantizer = ProductQuantizer {
            codebooks: old_quantizer.codebooks,
            dimension: (NUM_SUBSPACES * SUBSPACE_DIM) as u32,
        };

        Ok(Self {
            quantizer,
            documents,
        })
    }

    /// Search using asymmetric distance computation
    pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
        if query.is_empty() {
            return Vec::new();
        }

        let mut hits: Vec<VecSearchHit> = self
            .documents
            .iter()
            .map(|doc| {
                let distance = self.quantizer.asymmetric_distance(query, &doc.codes);
                VecSearchHit {
                    frame_id: doc.frame_id,
                    distance,
                }
            })
            .collect();

        hits.sort_by(|a, b| {
            a.distance
                .partial_cmp(&b.distance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        hits.truncate(limit);
        hits
    }

    pub fn remove(&mut self, frame_id: FrameId) {
        self.documents.retain(|doc| doc.frame_id != frame_id);
    }

    /// Get compression statistics
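    ///
    /// With the fixed 384-dim / 96-subspace layout, each vector costs 96 code
    /// bytes instead of 1,536 raw bytes, plus one shared codebook of
    /// 96 * 256 * 4 * 4 = 393,216 bytes amortized across the whole index.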
    pub fn compression_stats(&self) -> CompressionStats {
        let original_bytes = self.documents.len() * TOTAL_DIM * std::mem::size_of::<f32>();
        let compressed_bytes = self.documents.len() * NUM_SUBSPACES; // 96 bytes per vector
        let codebook_bytes =
            NUM_SUBSPACES * NUM_CENTROIDS * SUBSPACE_DIM * std::mem::size_of::<f32>();

        CompressionStats {
            vector_count: self.documents.len() as u64,
            original_bytes: original_bytes as u64,
            compressed_bytes: compressed_bytes as u64,
            codebook_bytes: codebook_bytes as u64,
            total_bytes: (compressed_bytes + codebook_bytes) as u64,
            compression_ratio: original_bytes as f64 / (compressed_bytes + codebook_bytes) as f64,
        }
    }
}

#[derive(Debug, Clone)]
pub struct CompressionStats {
    pub vector_count: u64,
    pub original_bytes: u64,
    pub compressed_bytes: u64,
    pub codebook_bytes: u64,
    pub total_bytes: u64,
    pub compression_ratio: f64,
}

/// K-means clustering for a single subspace
fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
    if vectors.is_empty() {
        return Err(MemvidError::InvalidQuery {
            reason: "Cannot run k-means on empty vector set".to_string(),
        });
    }

    let dim = vectors[0].len();

    // Initialize centroids using k-means++ for better convergence
    let mut centroids = kmeans_plus_plus_init(vectors, k)?;

    for _iteration in 0..max_iterations {
        // Assignment step: assign each vector to nearest centroid
        let mut assignments = vec![Vec::new(); k];

        for vec in vectors {
            let mut best_cluster = 0;
            let mut best_dist = f32::INFINITY;

            for (cluster_idx, centroid) in centroids.iter().enumerate() {
                let dist = l2_distance_squared(vec, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_cluster = cluster_idx;
                }
            }

            assignments[best_cluster].push(vec.clone());
        }

        // Update step: recompute centroids
        let mut changed = false;
        for (cluster_idx, cluster_vecs) in assignments.iter().enumerate() {
            if cluster_vecs.is_empty() {
                // Empty cluster: reseed deterministically with a training vector
                centroids[cluster_idx] = vectors[cluster_idx % vectors.len()].clone();
                changed = true;
                continue;
            }

            let mut new_centroid = vec![0.0f32; dim];
            for vec in cluster_vecs {
                for (i, &val) in vec.iter().enumerate() {
                    new_centroid[i] += val;
                }
            }
            for val in &mut new_centroid {
                *val /= cluster_vecs.len() as f32;
            }

            // Check if centroid changed
            if l2_distance_squared(&centroids[cluster_idx], &new_centroid) > 1e-6 {
                changed = true;
            }

            centroids[cluster_idx] = new_centroid;
        }

        if !changed {
            break; // Converged
        }
    }

    Ok(centroids)
}

/// Farthest-point initialization in the spirit of k-means++: deterministic,
/// picking the vector farthest from the current centroids rather than using
/// D^2-weighted random sampling
fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
    if vectors.is_empty() || k == 0 {
        return Err(MemvidError::InvalidQuery {
            reason: "Invalid k-means++ initialization".to_string(),
        });
    }

    let mut centroids = Vec::new();

    // Use the first vector as the initial centroid (deterministic, no RNG)
    centroids.push(vectors[0].clone());

    // Choose remaining k-1 centroids
    for _ in 1..k {
        let mut distances = Vec::new();

        // Compute distance to nearest existing centroid for each vector
        for vec in vectors {
            let mut min_dist = f32::INFINITY;
            for centroid in &centroids {
                let dist = l2_distance_squared(vec, centroid);
                min_dist = min_dist.min(dist);
            }
            distances.push(min_dist);
        }

        // Choose vector with maximum distance as next centroid
        let max_idx = distances
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        centroids.push(vectors[max_idx].clone());
    }

    Ok(centroids)
}

/// Squared L2 distance between two vectors
fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subspace_codebook() {
        let mut codebook = SubspaceCodebook::new();

        // Set a centroid
        codebook.set_centroid(0, &[1.0, 2.0, 3.0, 4.0]);

        // Retrieve it
        let centroid = codebook.get_centroid(0);
        assert_eq!(centroid, &[1.0, 2.0, 3.0, 4.0]);

        // Quantize a similar vector
        let code = codebook.quantize(&[1.1, 2.1, 3.1, 4.1]);
        assert_eq!(code, 0);
    }

    #[test]
    fn test_product_quantizer_roundtrip() {
        // Create sample 384-dim vectors
        let mut training_vecs = Vec::new();
        for i in 0..100 {
            let mut vec = vec![0.0f32; TOTAL_DIM];
            for j in 0..TOTAL_DIM {
                vec[j] = ((i * TOTAL_DIM + j) % 100) as f32 / 100.0;
            }
            training_vecs.push(vec);
        }

        // Train quantizer
        let mut pq = ProductQuantizer::new(TOTAL_DIM as u32).unwrap();
        pq.train(&training_vecs, 10).unwrap();

        // Encode a vector
        let test_vec = &training_vecs[0];
        let codes = pq.encode(test_vec).unwrap();
        assert_eq!(codes.len(), NUM_SUBSPACES);

        // Decode and verify approximate reconstruction
        let decoded = pq.decode(&codes).unwrap();
        assert_eq!(decoded.len(), TOTAL_DIM);

        // Distance between original and decoded should be small
        let dist = l2_distance_squared(test_vec, &decoded).sqrt();
        assert!(dist < 10.0, "Reconstruction error too large: {}", dist);
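
        // Added consistency check (sketch): ADC against the codes should match
        // the plain L2 distance to the decoded approximation, since both sum
        // the same per-subspace squared distances.
        let adc = pq.asymmetric_distance(test_vec, &codes);
        assert!((adc - dist).abs() < 1e-3, "ADC mismatch: {} vs {}", adc, dist);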
    }

    #[test]
    fn test_quantized_index_builder() {
        // Create sample vectors
        let mut training_vecs = Vec::new();
        for i in 0..50 {
            let mut vec = vec![0.0f32; TOTAL_DIM];
            for j in 0..TOTAL_DIM {
                vec[j] = ((i + j) % 10) as f32;
            }
            training_vecs.push(vec);
        }

        // Build index
        let mut builder = QuantizedVecIndexBuilder::new();
        builder
            .train_quantizer(&training_vecs, TOTAL_DIM as u32)
            .unwrap();

        for (i, vec) in training_vecs.iter().take(10).enumerate() {
            builder
                .add_document((i + 1) as FrameId, vec.clone())
                .unwrap();
        }

        let artifact = builder.finish().unwrap();
        assert_eq!(artifact.vector_count, 10);
        assert_eq!(artifact.dimension, TOTAL_DIM as u32);
        assert!(artifact.compression_ratio > 10.0);

        // Decode and search
        let index = QuantizedVecIndex::decode(&artifact.bytes).unwrap();
        let query = &training_vecs[0];
        let hits = index.search(query, 5);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].frame_id, 1); // Should find exact match first
    }

    #[test]
    fn test_kmeans_simple() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];

        let centroids = kmeans(&vectors, 2, 100).unwrap();
        assert_eq!(centroids.len(), 2);

        // One centroid should be near [0, 0], the other near [10, 10]
        let near_zero = centroids.iter().any(|c| c[0] < 5.0 && c[1] < 5.0);
        let near_ten = centroids.iter().any(|c| c[0] > 5.0 && c[1] > 5.0);
        assert!(near_zero && near_ten);
    }
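
    #[test]
    fn test_compression_stats_math() {
        // Added sketch: the stats are pure arithmetic over the module constants,
        // so they can be checked on an index built from placeholder codes (an
        // untrained quantizer with all-zero codebooks is fine for this purpose).
        let quantizer = ProductQuantizer::new(TOTAL_DIM as u32).unwrap();
        let documents: Vec<QuantizedVecDocument> = (0..4)
            .map(|i| QuantizedVecDocument {
                frame_id: (i + 1) as FrameId,
                codes: vec![0u8; NUM_SUBSPACES],
            })
            .collect();
        let index = QuantizedVecIndex {
            quantizer,
            documents,
        };

        let stats = index.compression_stats();
        assert_eq!(stats.vector_count, 4);
        assert_eq!(stats.original_bytes, 4 * 1536); // 4 vectors * 384 dims * 4 bytes
        assert_eq!(stats.compressed_bytes, 4 * 96); // 4 vectors * 96 code bytes
        assert_eq!(stats.codebook_bytes, 393_216); // 96 * 256 * 4 * 4
        assert_eq!(stats.total_bytes, 4 * 96 + 393_216);
    }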
}