Skip to main content

memvid_core/
vec_pq.rs

//! Product Quantization (PQ) for vector compression
//!
//! Compresses 384-dim f32 vectors from 1,536 bytes to 96 bytes (16x compression)
//! while maintaining ~95% search accuracy using codebook-based quantization.
//!
//! **Algorithm**:
//! 1. Split 384-dim vector into 96 subspaces of 4 dimensions each
//! 2. For each subspace, train 256 centroids using k-means
//! 3. Each vector is encoded as 96 bytes (one u8 index per subspace)
//! 4. Search uses ADC (Asymmetric Distance Computation): the full-precision query is
//!    compared, subspace by subspace, against the centroid each stored code points to
11
12use blake3::hash;
13use serde::{Deserialize, Serialize};
14
15use crate::vec::VecSearchHit;
16use crate::{MemvidError, Result, types::FrameId};
17
18fn vec_config() -> impl bincode::config::Config {
19    bincode::config::standard()
20        .with_fixed_int_encoding()
21        .with_little_endian()
22}
23
/// Byte limit handed to bincode (`with_limit`) when decoding a serialized index.
// `MAX_INDEX_BYTES` is declared at the crate root; the lint is silenced because the
// `as usize` cast could truncate if that constant is wider than `usize` on the target.
#[allow(clippy::cast_possible_truncation)]
const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;

/// Product Quantization parameters
///
/// Number of subspaces a vector is split into; also the number of code bytes per vector.
const NUM_SUBSPACES: usize = 96; // 384 dims / 4 dims per subspace
/// Number of consecutive dimensions covered by one subspace.
const SUBSPACE_DIM: usize = 4; // Dimensions per subspace
/// Number of centroids per subspace codebook, so a centroid index fits in one `u8`.
const NUM_CENTROIDS: usize = 256; // 2^8 centroids (encoded as u8)
/// Full vector dimensionality supported by this module.
const TOTAL_DIM: usize = NUM_SUBSPACES * SUBSPACE_DIM; // 384
32
/// Codebook for one subspace: 256 centroids, each with `SUBSPACE_DIM` dimensions
///
/// Derives `Serialize`/`Deserialize`; codebooks are written out as part of the
/// encoded index, so the field layout is part of the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubspaceCodebook {
    /// Flat array: `centroids[i * SUBSPACE_DIM..(i + 1) * SUBSPACE_DIM]` is centroid `i`
    centroids: Vec<f32>,
}
39
40impl SubspaceCodebook {
41    fn new() -> Self {
42        Self {
43            centroids: vec![0.0; NUM_CENTROIDS * SUBSPACE_DIM],
44        }
45    }
46
47    fn get_centroid(&self, index: u8) -> &[f32] {
48        let start = (index as usize) * SUBSPACE_DIM;
49        &self.centroids[start..start + SUBSPACE_DIM]
50    }
51
52    fn set_centroid(&mut self, index: u8, values: &[f32]) {
53        assert_eq!(values.len(), SUBSPACE_DIM);
54        let start = (index as usize) * SUBSPACE_DIM;
55        self.centroids[start..start + SUBSPACE_DIM].copy_from_slice(values);
56    }
57
58    /// Find nearest centroid to a subspace vector
59    fn quantize(&self, subspace: &[f32]) -> u8 {
60        assert_eq!(subspace.len(), SUBSPACE_DIM);
61
62        let mut best_idx = 0u8;
63        let mut best_dist = f32::INFINITY;
64
65        for i in 0..NUM_CENTROIDS {
66            #[allow(clippy::cast_possible_truncation)]
67            let centroid = self.get_centroid(i as u8);
68            let dist = l2_distance_squared(subspace, centroid);
69            if dist < best_dist {
70                best_dist = dist;
71                #[allow(clippy::cast_possible_truncation)]
72                {
73                    best_idx = i as u8;
74                }
75            }
76        }
77
78        best_idx
79    }
80}
81
/// Product Quantizer with codebooks for all subspaces
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// One codebook per subspace (`NUM_SUBSPACES` entries when built through `new`)
    codebooks: Vec<SubspaceCodebook>,
    /// Dimensionality of the vectors this quantizer accepts; `new` only allows `TOTAL_DIM`.
    // Serialized after `codebooks`; the legacy index format handled in
    // `QuantizedVecIndex::decode` did not contain this field.
    dimension: u32,
}
89
90impl ProductQuantizer {
91    /// Create uninitialized quantizer
92    pub fn new(dimension: u32) -> Result<Self> {
93        if dimension as usize != TOTAL_DIM {
94            return Err(MemvidError::InvalidQuery {
95                reason: format!("PQ only supports {TOTAL_DIM}-dim vectors, got {dimension}"),
96            });
97        }
98
99        Ok(Self {
100            codebooks: vec![SubspaceCodebook::new(); NUM_SUBSPACES],
101            dimension,
102        })
103    }
104
105    /// Train codebooks using k-means on sample vectors
106    pub fn train(&mut self, training_vectors: &[Vec<f32>], max_iterations: usize) -> Result<()> {
107        if training_vectors.is_empty() {
108            return Err(MemvidError::InvalidQuery {
109                reason: "Cannot train PQ with empty training set".to_string(),
110            });
111        }
112
113        // Verify all vectors have correct dimension
114        for vec in training_vectors {
115            if vec.len() != TOTAL_DIM {
116                return Err(MemvidError::InvalidQuery {
117                    reason: format!(
118                        "Training vector has wrong dimension: expected {}, got {}",
119                        TOTAL_DIM,
120                        vec.len()
121                    ),
122                });
123            }
124        }
125
126        // Train each subspace independently
127        for subspace_idx in 0..NUM_SUBSPACES {
128            let start_dim = subspace_idx * SUBSPACE_DIM;
129            let end_dim = start_dim + SUBSPACE_DIM;
130
131            // Extract subspace vectors
132            let subspace_vecs: Vec<Vec<f32>> = training_vectors
133                .iter()
134                .map(|v| v[start_dim..end_dim].to_vec())
135                .collect();
136
137            // Run k-means
138            let centroids = kmeans(&subspace_vecs, NUM_CENTROIDS, max_iterations)?;
139
140            // Store in codebook
141            for (i, centroid) in centroids.iter().enumerate() {
142                #[allow(clippy::cast_possible_truncation)]
143                self.codebooks[subspace_idx].set_centroid(i as u8, centroid);
144            }
145        }
146
147        Ok(())
148    }
149
150    /// Encode a vector into PQ codes (96 bytes)
151    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
152        if vector.len() != TOTAL_DIM {
153            return Err(MemvidError::InvalidQuery {
154                reason: format!(
155                    "Vector dimension mismatch: expected {}, got {}",
156                    TOTAL_DIM,
157                    vector.len()
158                ),
159            });
160        }
161
162        let mut codes = Vec::with_capacity(NUM_SUBSPACES);
163
164        for subspace_idx in 0..NUM_SUBSPACES {
165            let start_dim = subspace_idx * SUBSPACE_DIM;
166            let end_dim = start_dim + SUBSPACE_DIM;
167            let subspace = &vector[start_dim..end_dim];
168
169            let code = self.codebooks[subspace_idx].quantize(subspace);
170            codes.push(code);
171        }
172
173        Ok(codes)
174    }
175
176    /// Decode PQ codes back to approximate vector (for debugging/verification)
177    pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
178        if codes.len() != NUM_SUBSPACES {
179            return Err(MemvidError::InvalidQuery {
180                reason: format!(
181                    "Invalid PQ codes length: expected {}, got {}",
182                    NUM_SUBSPACES,
183                    codes.len()
184                ),
185            });
186        }
187
188        let mut vector = Vec::with_capacity(TOTAL_DIM);
189
190        for (subspace_idx, &code) in codes.iter().enumerate() {
191            let centroid = self.codebooks[subspace_idx].get_centroid(code);
192            vector.extend_from_slice(centroid);
193        }
194
195        Ok(vector)
196    }
197
198    /// Compute asymmetric distance between query vector and PQ-encoded vector
199    /// Uses precomputed lookup tables for efficiency
200    #[must_use]
201    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
202        if query.len() != TOTAL_DIM || codes.len() != NUM_SUBSPACES {
203            return f32::INFINITY;
204        }
205
206        let mut total_dist_sq = 0.0f32;
207
208        for subspace_idx in 0..NUM_SUBSPACES {
209            let start_dim = subspace_idx * SUBSPACE_DIM;
210            let end_dim = start_dim + SUBSPACE_DIM;
211            let query_subspace = &query[start_dim..end_dim];
212
213            let code = codes[subspace_idx];
214            let centroid = self.codebooks[subspace_idx].get_centroid(code);
215
216            total_dist_sq += l2_distance_squared(query_subspace, centroid);
217        }
218
219        total_dist_sq.sqrt()
220    }
221}
222
/// Compressed vector document
///
/// Pairs a frame with the PQ codes of its embedding; a list of these is
/// serialized together with the quantizer to form the index payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVecDocument {
    /// Frame this embedding belongs to.
    pub frame_id: FrameId,
    /// PQ codes: 96 bytes (one u8 per subspace)
    pub codes: Vec<u8>,
}
230
/// Builder for compressed vector index
///
/// Train a quantizer first (`train_quantizer`), add documents (`add_document`),
/// then call `finish` to obtain the serialized artifact.
#[derive(Default)]
pub struct QuantizedVecIndexBuilder {
    /// Documents encoded so far, in insertion order.
    documents: Vec<QuantizedVecDocument>,
    /// Trained quantizer; `None` until `train_quantizer` succeeds.
    quantizer: Option<ProductQuantizer>,
}
237
238impl QuantizedVecIndexBuilder {
239    #[must_use]
240    pub fn new() -> Self {
241        Self::default()
242    }
243
244    /// Train quantizer on sample vectors before encoding
245    pub fn train_quantizer(&mut self, training_vectors: &[Vec<f32>], dimension: u32) -> Result<()> {
246        let mut pq = ProductQuantizer::new(dimension)?;
247        pq.train(training_vectors, 25)?; // 25 k-means iterations
248        self.quantizer = Some(pq);
249        Ok(())
250    }
251
252    /// Add document with pre-trained quantizer
253    pub fn add_document(&mut self, frame_id: FrameId, embedding: Vec<f32>) -> Result<()> {
254        let quantizer = self
255            .quantizer
256            .as_ref()
257            .ok_or_else(|| MemvidError::InvalidQuery {
258                reason: "Quantizer not trained. Call train_quantizer first".to_string(),
259            })?;
260
261        let codes = quantizer.encode(&embedding)?;
262
263        self.documents
264            .push(QuantizedVecDocument { frame_id, codes });
265
266        Ok(())
267    }
268
269    pub fn finish(self) -> Result<QuantizedVecIndexArtifact> {
270        let quantizer = self.quantizer.ok_or_else(|| MemvidError::InvalidQuery {
271            reason: "Quantizer not trained".to_string(),
272        })?;
273
274        let vector_count = self.documents.len() as u64;
275        let bytes =
276            bincode::serde::encode_to_vec(&(quantizer.clone(), self.documents), vec_config())?;
277        let checksum = *hash(&bytes).as_bytes();
278
279        Ok(QuantizedVecIndexArtifact {
280            bytes,
281            vector_count,
282            dimension: quantizer.dimension,
283            checksum,
284            compression_ratio: 16.0, // 1536 bytes -> 96 bytes
285        })
286    }
287}
288
/// Serialized quantized index produced by `QuantizedVecIndexBuilder::finish`.
#[derive(Debug, Clone)]
pub struct QuantizedVecIndexArtifact {
    /// Bincode encoding of the quantizer together with all documents.
    pub bytes: Vec<u8>,
    /// Number of documents contained in `bytes`.
    pub vector_count: u64,
    /// Dimensionality recorded in the quantizer.
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
    /// Nominal per-vector compression ratio (the builder always reports 16.0).
    pub compression_ratio: f64,
}
297
/// In-memory quantized vector index: a quantizer plus the PQ codes of every
/// document, searchable with asymmetric distance.
#[derive(Debug, Clone)]
pub struct QuantizedVecIndex {
    /// Quantizer whose codebooks give meaning to the stored codes.
    quantizer: ProductQuantizer,
    /// Encoded documents, in the order they were decoded.
    documents: Vec<QuantizedVecDocument>,
}
303
304impl QuantizedVecIndex {
305    pub fn decode(bytes: &[u8]) -> Result<Self> {
306        // Try decoding with current format (with dimension field)
307        let config = bincode::config::standard()
308            .with_fixed_int_encoding()
309            .with_little_endian()
310            .with_limit::<VEC_DECODE_LIMIT>();
311
312        if let Ok(((quantizer, documents), read)) = bincode::serde::decode_from_slice::<
313            (ProductQuantizer, Vec<QuantizedVecDocument>),
314            _,
315        >(bytes, config)
316        {
317            if read == bytes.len() {
318                return Ok(Self {
319                    quantizer,
320                    documents,
321                });
322            }
323        }
324
325        // Fall back to old format (without dimension field)
326        // Old ProductQuantizer struct without dimension field
327        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
328        struct OldProductQuantizer {
329            codebooks: Vec<SubspaceCodebook>,
330        }
331
332        let ((old_quantizer, documents), read): (
333            (OldProductQuantizer, Vec<QuantizedVecDocument>),
334            usize,
335        ) = bincode::serde::decode_from_slice(bytes, config)?;
336
337        if read != bytes.len() {
338            return Err(MemvidError::InvalidToc {
339                reason: "unsupported quantized vector index encoding".into(),
340            });
341        }
342
343        // Convert old format to new format
344        let quantizer = ProductQuantizer {
345            codebooks: old_quantizer.codebooks,
346            dimension: u32::try_from(NUM_SUBSPACES * SUBSPACE_DIM).unwrap_or(u32::MAX),
347        };
348
349        Ok(Self {
350            quantizer,
351            documents,
352        })
353    }
354
355    /// Search using asymmetric distance computation
356    #[must_use]
357    pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
358        if query.is_empty() {
359            return Vec::new();
360        }
361
362        let mut hits: Vec<VecSearchHit> = self
363            .documents
364            .iter()
365            .map(|doc| {
366                let distance = self.quantizer.asymmetric_distance(query, &doc.codes);
367                VecSearchHit {
368                    frame_id: doc.frame_id,
369                    distance,
370                }
371            })
372            .collect();
373
374        hits.sort_by(|a, b| {
375            a.distance
376                .partial_cmp(&b.distance)
377                .unwrap_or(std::cmp::Ordering::Equal)
378        });
379
380        hits.truncate(limit);
381        hits
382    }
383
384    pub fn remove(&mut self, frame_id: FrameId) {
385        self.documents.retain(|doc| doc.frame_id != frame_id);
386    }
387
388    /// Get compression statistics
389    #[must_use]
390    pub fn compression_stats(&self) -> CompressionStats {
391        let original_bytes = self.documents.len() * TOTAL_DIM * std::mem::size_of::<f32>();
392        let compressed_bytes = self.documents.len() * NUM_SUBSPACES; // 96 bytes per vector
393        let codebook_bytes =
394            NUM_SUBSPACES * NUM_CENTROIDS * SUBSPACE_DIM * std::mem::size_of::<f32>();
395
396        CompressionStats {
397            vector_count: self.documents.len() as u64,
398            original_bytes: original_bytes as u64,
399            compressed_bytes: compressed_bytes as u64,
400            codebook_bytes: codebook_bytes as u64,
401            total_bytes: (compressed_bytes + codebook_bytes) as u64,
402            compression_ratio: original_bytes as f64 / (compressed_bytes + codebook_bytes) as f64,
403        }
404    }
405}
406
/// Size accounting returned by `QuantizedVecIndex::compression_stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of documents in the index.
    pub vector_count: u64,
    /// Bytes the vectors would occupy uncompressed (`TOTAL_DIM` f32 values each).
    pub original_bytes: u64,
    /// Bytes used by the PQ codes (`NUM_SUBSPACES` per document).
    pub compressed_bytes: u64,
    /// Bytes used by the centroid tables of all codebooks.
    pub codebook_bytes: u64,
    /// `compressed_bytes + codebook_bytes`.
    pub total_bytes: u64,
    /// `original_bytes / total_bytes`.
    pub compression_ratio: f64,
}
416
417/// K-means clustering for a single subspace
418fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
419    if vectors.is_empty() {
420        return Err(MemvidError::InvalidQuery {
421            reason: "Cannot run k-means on empty vector set".to_string(),
422        });
423    }
424
425    let dim = vectors[0].len();
426
427    // Initialize centroids using k-means++ for better convergence
428    let mut centroids = kmeans_plus_plus_init(vectors, k)?;
429
430    for _iteration in 0..max_iterations {
431        // Assignment step: assign each vector to nearest centroid
432        let mut assignments = vec![Vec::new(); k];
433
434        for vec in vectors {
435            let mut best_cluster = 0;
436            let mut best_dist = f32::INFINITY;
437
438            for (cluster_idx, centroid) in centroids.iter().enumerate() {
439                let dist = l2_distance_squared(vec, centroid);
440                if dist < best_dist {
441                    best_dist = dist;
442                    best_cluster = cluster_idx;
443                }
444            }
445
446            assignments[best_cluster].push(vec.clone());
447        }
448
449        // Update step: recompute centroids
450        let mut changed = false;
451        for (cluster_idx, cluster_vecs) in assignments.iter().enumerate() {
452            if cluster_vecs.is_empty() {
453                // Empty cluster: reinitialize with random vector
454                centroids[cluster_idx] = vectors[cluster_idx % vectors.len()].clone();
455                changed = true;
456                continue;
457            }
458
459            let mut new_centroid = vec![0.0f32; dim];
460            for vec in cluster_vecs {
461                for (i, &val) in vec.iter().enumerate() {
462                    new_centroid[i] += val;
463                }
464            }
465            for val in &mut new_centroid {
466                *val /= cluster_vecs.len() as f32;
467            }
468
469            // Check if centroid changed
470            if l2_distance_squared(&centroids[cluster_idx], &new_centroid) > 1e-6 {
471                changed = true;
472            }
473
474            centroids[cluster_idx] = new_centroid;
475        }
476
477        if !changed {
478            break; // Converged
479        }
480    }
481
482    Ok(centroids)
483}
484
485/// K-means++ initialization for better initial centroids
486fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
487    if vectors.is_empty() || k == 0 {
488        return Err(MemvidError::InvalidQuery {
489            reason: "Invalid k-means++ initialization".to_string(),
490        });
491    }
492
493    let mut centroids = Vec::new();
494
495    // Choose first centroid randomly (use first vector for determinism)
496    centroids.push(vectors[0].clone());
497
498    // Choose remaining k-1 centroids
499    for _ in 1..k {
500        let mut distances = Vec::new();
501
502        // Compute distance to nearest existing centroid for each vector
503        for vec in vectors {
504            let mut min_dist = f32::INFINITY;
505            for centroid in &centroids {
506                let dist = l2_distance_squared(vec, centroid);
507                min_dist = min_dist.min(dist);
508            }
509            distances.push(min_dist);
510        }
511
512        // Choose vector with maximum distance as next centroid
513        let max_idx = distances
514            .iter()
515            .enumerate()
516            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
517            .map_or(0, |(idx, _)| idx);
518
519        centroids.push(vectors[max_idx].clone());
520    }
521
522    Ok(centroids)
523}
524
/// Squared L2 distance between two vectors
///
/// Thin wrapper that forwards to `crate::simd::l2_distance_squared_simd`.
// NOTE(review): how slices of different lengths are handled is decided by the SIMD
// helper, which is not visible from this file — confirm before relying on it.
fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    crate::simd::l2_distance_squared_simd(a, b)
}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// `TOTAL_DIM` in the `u32` form the public constructors expect.
    fn total_dim_u32() -> u32 {
        u32::try_from(TOTAL_DIM).unwrap()
    }

    /// A stored centroid can be read back and attracts nearby points.
    #[test]
    fn test_subspace_codebook() {
        let mut codebook = SubspaceCodebook::new();
        codebook.set_centroid(0, &[1.0, 2.0, 3.0, 4.0]);

        assert_eq!(codebook.get_centroid(0), &[1.0, 2.0, 3.0, 4.0]);

        // A slightly perturbed copy of the centroid must map back to it.
        assert_eq!(codebook.quantize(&[1.1, 2.1, 3.1, 4.1]), 0);
    }

    /// Train, encode, decode: shapes are right and reconstruction error is bounded.
    #[test]
    fn test_product_quantizer_roundtrip() {
        // 100 synthetic 384-dim vectors with values cycling through [0, 1).
        let training_vecs: Vec<Vec<f32>> = (0..100usize)
            .map(|i| {
                (0..TOTAL_DIM)
                    .map(|j| ((i * TOTAL_DIM + j) % 100) as f32 / 100.0)
                    .collect()
            })
            .collect();

        let mut pq = ProductQuantizer::new(total_dim_u32()).unwrap();
        pq.train(&training_vecs, 10).unwrap();

        let test_vec = &training_vecs[0];
        let codes = pq.encode(test_vec).unwrap();
        assert_eq!(codes.len(), NUM_SUBSPACES);

        let decoded = pq.decode(&codes).unwrap();
        assert_eq!(decoded.len(), TOTAL_DIM);

        // Distance between original and decoded should be small
        let dist = l2_distance_squared(test_vec, &decoded).sqrt();
        assert!(dist < 10.0, "Reconstruction error too large: {}", dist);
    }

    /// Build an artifact, decode it again and find an indexed vector by search.
    #[test]
    fn test_quantized_index_builder() {
        // 50 synthetic vectors with small integer-valued components.
        let training_vecs: Vec<Vec<f32>> = (0..50usize)
            .map(|i| (0..TOTAL_DIM).map(|j| ((i + j) % 10) as f32).collect())
            .collect();

        let mut builder = QuantizedVecIndexBuilder::new();
        builder
            .train_quantizer(&training_vecs, total_dim_u32())
            .unwrap();

        // Index the first ten vectors under frame ids 1..=10.
        for (i, vec) in training_vecs.iter().take(10).enumerate() {
            builder
                .add_document((i + 1) as FrameId, vec.clone())
                .unwrap();
        }

        let artifact = builder.finish().unwrap();
        assert_eq!(artifact.vector_count, 10);
        assert_eq!(artifact.dimension, total_dim_u32());
        assert!(artifact.compression_ratio > 10.0);

        let index = QuantizedVecIndex::decode(&artifact.bytes).unwrap();
        let hits = index.search(&training_vecs[0], 5);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].frame_id, 1); // Should find exact match first
    }

    /// Two well-separated pairs of points end up in two separate clusters.
    #[test]
    fn test_kmeans_simple() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];

        let centroids = kmeans(&vectors, 2, 100).unwrap();
        assert_eq!(centroids.len(), 2);

        // One centroid should be near [0, 0], the other near [10, 10]
        let near_zero = centroids.iter().any(|c| c[0] < 5.0 && c[1] < 5.0);
        let near_ten = centroids.iter().any(|c| c[0] > 5.0 && c[1] > 5.0);
        assert!(near_zero && near_ten);
    }
}