//! # Product Quantization (PQ) for Vector Compression
//!
//! Product Quantization compresses high-dimensional vectors by:
//! 1. Dividing each vector into M subspaces (segments)
//! 2. Training K centroids per subspace (codebook)
//! 3. Representing each segment by its nearest centroid ID (1 byte for K=256)
//!
//! ## Compression Ratio
//!
//! For 128-dim float vectors (512 bytes) with M=8 subspaces:
//! - Original: 512 bytes
//! - Compressed: 8 bytes (one centroid ID per subspace)
//! - Compression: 64x
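//!
//! A quick sanity check of that arithmetic (a sketch; [`ProductQuantizer::stats`]
//! reports the same numbers for a trained quantizer):
//!
//! ```ignore
//! let dim = 128usize;           // f32 elements per vector
//! let m = 8usize;               // subspaces
//! let original_bytes = dim * 4; // 512
//! let code_bytes = m;           // one u8 centroid ID per subspace
//! assert_eq!(original_bytes / code_bytes, 64);
//! ```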
//!
//! ## Usage
//!
//! ```ignore
//! use diskann_rs::pq::{ProductQuantizer, PQConfig};
//!
//! // Train a quantizer on sample vectors
//! let vectors: Vec<Vec<f32>> = load_your_training_data();
//! let config = PQConfig::default(); // 8 subspaces, 256 centroids each
//! let pq = ProductQuantizer::train(&vectors, config).unwrap();
//!
//! // Encode vectors (each becomes M bytes)
//! let codes: Vec<Vec<u8>> = vectors.iter().map(|v| pq.encode(v)).collect();
//!
//! // Compute asymmetric distance (query vs quantized database vector)
//! let query = vec![0.0f32; 128];
//! let dist = pq.asymmetric_distance(&query, &codes[0]);
//! ```
//!
//! ## Asymmetric Distance Computation (ADC)
//!
//! For search, we compute exact query-to-centroid distances once,
//! then use a lookup table for fast distance approximation.
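//!
//! A minimal sketch, reusing the `pq`, `query`, and `codes` bindings from the
//! usage example above:
//!
//! ```ignore
//! // Build the per-query table once: num_subspaces * num_centroids entries.
//! let table = pq.create_distance_table(&query);
//!
//! // Scoring one code is then just M table lookups and additions.
//! let approx: Vec<f32> = codes
//!     .iter()
//!     .map(|c| pq.distance_with_table(&table, c))
//!     .collect();
//! ```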

use crate::DiskAnnError;
use rand::prelude::*;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};

/// Configuration for Product Quantization
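///
/// The vector dimension must be divisible by `num_subspaces`; this is checked in
/// [`ProductQuantizer::train`]. A sketch of overriding only the fields you care
/// about:
///
/// ```ignore
/// let config = PQConfig {
///     num_subspaces: 16, // 16-byte codes instead of the default 8
///     ..PQConfig::default()
/// };
/// ```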
#[derive(Clone, Copy, Debug)]
pub struct PQConfig {
    /// Number of subspaces (M). Vector is divided into M segments.
    /// Typical values: 4, 8, 16, 32
    pub num_subspaces: usize,
    /// Number of centroids per subspace (K). Typically 256 (fits in u8).
    pub num_centroids: usize,
    /// Number of k-means iterations for training
    pub kmeans_iterations: usize,
    /// Sample size for training (if 0, use all vectors)
    pub training_sample_size: usize,
}

impl Default for PQConfig {
    fn default() -> Self {
        Self {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 20,
            training_sample_size: 50_000,
        }
    }
}

/// Trained Product Quantizer
#[derive(Serialize, Deserialize, Clone)]
pub struct ProductQuantizer {
    /// Dimension of original vectors
    dim: usize,
    /// Number of subspaces
    num_subspaces: usize,
    /// Number of centroids per subspace
    num_centroids: usize,
    /// Dimension of each subspace
    subspace_dim: usize,
    /// Codebooks: `[num_subspaces][num_centroids][subspace_dim]`
    /// Flattened for cache efficiency
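    /// Centroid `k` of subspace `m` starts at index
    /// `(m * num_centroids + k) * subspace_dim` (see `get_centroid`).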
    codebooks: Vec<f32>,
}

impl ProductQuantizer {
    /// Train a product quantizer on a set of vectors
    pub fn train(vectors: &[Vec<f32>], config: PQConfig) -> Result<Self, DiskAnnError> {
        if vectors.is_empty() {
            return Err(DiskAnnError::IndexError("No vectors to train on".into()));
        }

        let dim = vectors[0].len();
        if config.num_centroids > 256 {
            return Err(DiskAnnError::IndexError(format!(
                "num_centroids {} exceeds 256 (codes are u8)",
                config.num_centroids
            )));
        }
        if dim % config.num_subspaces != 0 {
            return Err(DiskAnnError::IndexError(format!(
                "Dimension {} not divisible by num_subspaces {}",
                dim, config.num_subspaces
            )));
        }

        let subspace_dim = dim / config.num_subspaces;

        // Sample training vectors if needed
        let training_vectors: Vec<&Vec<f32>> = if config.training_sample_size > 0
            && vectors.len() > config.training_sample_size
        {
            let mut rng = thread_rng();
            vectors
                .choose_multiple(&mut rng, config.training_sample_size)
                .collect()
        } else {
            vectors.iter().collect()
        };

        // Train codebook for each subspace (parallel)
        let codebooks_per_subspace: Vec<Vec<f32>> = (0..config.num_subspaces)
            .into_par_iter()
            .map(|m| {
                // Extract subspace vectors
                let start = m * subspace_dim;
                let end = start + subspace_dim;

                let subspace_vectors: Vec<Vec<f32>> = training_vectors
                    .iter()
                    .map(|v| v[start..end].to_vec())
                    .collect();

                // Run k-means
                kmeans(
                    &subspace_vectors,
                    config.num_centroids,
                    config.kmeans_iterations,
                )
            })
            .collect();

        // Flatten codebooks for cache efficiency
        let mut codebooks =
            Vec::with_capacity(config.num_subspaces * config.num_centroids * subspace_dim);
        for cb in &codebooks_per_subspace {
            codebooks.extend_from_slice(cb);
        }

        Ok(Self {
            dim,
            num_subspaces: config.num_subspaces,
            num_centroids: config.num_centroids,
            subspace_dim,
            codebooks,
        })
    }

    /// Encode a vector into PQ codes (M bytes)
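    ///
    /// Each byte is the index of the nearest centroid for the corresponding
    /// subspace, so the output length always equals `num_subspaces`. A sketch
    /// using the module-level example bindings:
    ///
    /// ```ignore
    /// let codes = pq.encode(&vectors[0]);
    /// assert_eq!(codes.len(), config.num_subspaces);
    /// ```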
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");

        let mut codes = Vec::with_capacity(self.num_subspaces);

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let subvec = &vector[start..end];

            // Find nearest centroid
            let mut best_centroid = 0u8;
            let mut best_dist = f32::MAX;

            for k in 0..self.num_centroids {
                let centroid = self.get_centroid(m, k);
                let dist = l2_distance(subvec, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_centroid = k as u8;
                }
            }

            codes.push(best_centroid);
        }

        codes
    }

    /// Batch encode vectors (parallel)
    pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<u8>> {
        vectors.par_iter().map(|v| self.encode(v)).collect()
    }

    /// Compute the asymmetric distance between a raw query and a quantized vector.
    ///
    /// Distances are computed directly against the stored centroids on every call;
    /// for scoring many codes per query, build a table with `create_distance_table`
    /// and use `distance_with_table` instead.
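    ///
    /// The result is the squared L2 distance from the query to the decoded
    /// approximation, since both paths use the same centroids. A sketch with
    /// illustrative `query`/`target` bindings (as in the tests):
    ///
    /// ```ignore
    /// let codes = pq.encode(&target);
    /// let d = pq.asymmetric_distance(&query, &codes);
    /// // d equals the squared L2 distance between `query` and `pq.decode(&codes)`,
    /// // up to floating-point rounding.
    /// ```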
    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        assert_eq!(query.len(), self.dim, "Query dimension mismatch");
        assert_eq!(codes.len(), self.num_subspaces, "Code length mismatch");

        let mut total_dist = 0.0f32;

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let query_sub = &query[start..end];

            let centroid_id = codes[m] as usize;
            let centroid = self.get_centroid(m, centroid_id);

            total_dist += l2_distance(query_sub, centroid);
        }

        total_dist
    }

    /// Create a per-query distance lookup table (for scoring many codes against
    /// one query).
    ///
    /// Returns a flattened `[num_subspaces][num_centroids]` table: the distance
    /// from the query's `m`-th segment to centroid `k` is stored at index
    /// `m * num_centroids + k`.
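    ///
    /// A sketch of manual indexing (this is exactly what `distance_with_table`
    /// does; `stats()` exposes the centroid count):
    ///
    /// ```ignore
    /// let table = pq.create_distance_table(&query);
    /// let k_centroids = pq.stats().num_centroids;
    /// let m = 3;
    /// let d_mk = table[m * k_centroids + codes[m] as usize];
    /// ```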
    pub fn create_distance_table(&self, query: &[f32]) -> Vec<f32> {
        assert_eq!(query.len(), self.dim);

        let mut table = Vec::with_capacity(self.num_subspaces * self.num_centroids);

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let query_sub = &query[start..end];

            for k in 0..self.num_centroids {
                let centroid = self.get_centroid(m, k);
                table.push(l2_distance(query_sub, centroid));
            }
        }

        table
    }

    /// Compute distance using precomputed table (very fast)
    #[inline]
    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
        let mut dist = 0.0f32;
        for (m, &code) in codes.iter().enumerate() {
            let idx = m * self.num_centroids + code as usize;
            dist += table[idx];
        }
        dist
    }

    /// Decode PQ codes back to approximate vector
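    ///
    /// A sketch of the (lossy) round trip:
    ///
    /// ```ignore
    /// let approx = pq.decode(&pq.encode(&vectors[0]));
    /// assert_eq!(approx.len(), vectors[0].len());
    /// ```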
    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
        assert_eq!(codes.len(), self.num_subspaces);

        let mut vector = Vec::with_capacity(self.dim);

        for (m, &code) in codes.iter().enumerate() {
            let centroid = self.get_centroid(m, code as usize);
            vector.extend_from_slice(centroid);
        }

        vector
    }

    /// Get centroid for subspace m, centroid k
    #[inline]
    fn get_centroid(&self, m: usize, k: usize) -> &[f32] {
        let offset = (m * self.num_centroids + k) * self.subspace_dim;
        &self.codebooks[offset..offset + self.subspace_dim]
    }

    /// Save quantizer to file
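    ///
    /// A sketch of a save/load round trip (the path is illustrative):
    ///
    /// ```ignore
    /// pq.save("pq_codebooks.bin")?;
    /// let restored = ProductQuantizer::load("pq_codebooks.bin")?;
    /// ```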
    pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
        let file = File::create(path)?;
        let writer = BufWriter::new(file);
        bincode::serialize_into(writer, self)?;
        Ok(())
    }

    /// Load quantizer from file
    pub fn load(path: &str) -> Result<Self, DiskAnnError> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let pq: Self = bincode::deserialize_from(reader)?;
        Ok(pq)
    }

    /// Get stats about the quantizer
    pub fn stats(&self) -> PQStats {
        PQStats {
            dim: self.dim,
            num_subspaces: self.num_subspaces,
            num_centroids: self.num_centroids,
            subspace_dim: self.subspace_dim,
            codebook_size_bytes: self.codebooks.len() * 4,
            code_size_bytes: self.num_subspaces,
            compression_ratio: (self.dim * 4) as f32 / self.num_subspaces as f32,
        }
    }
}

/// Statistics about a ProductQuantizer
#[derive(Debug, Clone)]
pub struct PQStats {
    pub dim: usize,
    pub num_subspaces: usize,
    pub num_centroids: usize,
    pub subspace_dim: usize,
    pub codebook_size_bytes: usize,
    pub code_size_bytes: usize,
    pub compression_ratio: f32,
}

impl std::fmt::Display for PQStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Product Quantizer Stats:")?;
        writeln!(f, "  Original dimension: {}", self.dim)?;
        writeln!(f, "  Subspaces (M): {}", self.num_subspaces)?;
        writeln!(f, "  Centroids per subspace (K): {}", self.num_centroids)?;
        writeln!(f, "  Subspace dimension: {}", self.subspace_dim)?;
        writeln!(f, "  Codebook size: {} bytes", self.codebook_size_bytes)?;
        writeln!(f, "  Compressed code size: {} bytes", self.code_size_bytes)?;
        writeln!(f, "  Compression ratio: {:.1}x", self.compression_ratio)
    }
}

/// Simple k-means clustering.
///
/// Returns flattened centroids of shape `[k * dim]`.
/// Note: always returns exactly `k` centroids (replicates existing ones if n < k).
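///
/// A sketch of the output layout (illustrative bindings): centroid `c` is the
/// slice `&centroids[c * dim..(c + 1) * dim]`.
///
/// ```ignore
/// let centroids = kmeans(&subspace_vectors, 256, 20);
/// let c0 = &centroids[0..dim]; // first centroid of this subspace
/// ```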
fn kmeans(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<f32> {
    if vectors.is_empty() {
        return vec![0.0; k]; // shouldn't happen, but safe fallback
    }

    let dim = vectors[0].len();
    let n = vectors.len();
    let effective_k = k.min(n); // For clustering, use available vectors

    // Initialize centroids with k-means++ style
    let mut centroids = Vec::with_capacity(k * dim);
    let mut rng = thread_rng();

    // First centroid: random vector
    let first = rng.gen_range(0..n);
    centroids.extend_from_slice(&vectors[first]);

    // Remaining centroids: sampled with probability proportional to the squared
    // distance to the nearest existing centroid (k-means++ seeding). Seed only
    // effective_k - 1 more this way; if k > n the rest are replicated below.
    for _ in 1..effective_k {
        let num_current = centroids.len() / dim;
        let distances: Vec<f32> = vectors
            .iter()
            .map(|v| {
                let mut min_dist = f32::MAX;
                for c in 0..num_current {
                    let centroid = &centroids[c * dim..(c + 1) * dim];
                    let d = l2_distance(v, centroid);
                    min_dist = min_dist.min(d);
                }
                min_dist
            })
            .collect();

        // Sample weighted by distance
        let total: f32 = distances.iter().sum();
        if total == 0.0 {
            // All points are at centroids, pick random
            let idx = rng.gen_range(0..n);
            centroids.extend_from_slice(&vectors[idx]);
        } else {
            let threshold = rng.r#gen::<f32>() * total;
            let mut cumsum = 0.0;
            let mut picked = false;
            for (i, &d) in distances.iter().enumerate() {
                cumsum += d;
                if cumsum >= threshold {
                    centroids.extend_from_slice(&vectors[i]);
                    picked = true;
                    break;
                }
            }
            // Fallback if we didn't pick one (can happen with float precision)
            if !picked {
                centroids.extend_from_slice(&vectors[n - 1]);
            }
        }
    }

    // If k > n, replicate existing centroids to reach k
    while centroids.len() < k * dim {
        // Cycle through existing centroids
        let idx = (centroids.len() / dim) % effective_k;
        let centroid = centroids[idx * dim..(idx + 1) * dim].to_vec();
        centroids.extend_from_slice(&centroid);
    }
    centroids.truncate(k * dim);

    // Lloyd's algorithm iterations
    for _ in 0..iterations {
        // Assignment step (parallel)
        let assignments: Vec<usize> = vectors
            .par_iter()
            .map(|v| {
                let mut best_c = 0;
                let mut best_dist = f32::MAX;
                for c in 0..k {
                    let centroid = &centroids[c * dim..(c + 1) * dim];
                    let d = l2_distance(v, centroid);
                    if d < best_dist {
                        best_dist = d;
                        best_c = c;
                    }
                }
                best_c
            })
            .collect();

        // Update step
        let mut new_centroids = vec![0.0f32; k * dim];
        let mut counts = vec![0usize; k];

        for (i, &c) in assignments.iter().enumerate() {
            counts[c] += 1;
            for (j, &val) in vectors[i].iter().enumerate() {
                new_centroids[c * dim + j] += val;
            }
        }

        // Average
        for c in 0..k {
            if counts[c] > 0 {
                for j in 0..dim {
                    new_centroids[c * dim + j] /= counts[c] as f32;
                }
            } else {
                // Empty cluster: reinitialize from random point
                let idx = rng.gen_range(0..n);
                for j in 0..dim {
                    new_centroids[c * dim + j] = vectors[idx][j];
                }
            }
        }

        centroids = new_centroids;
    }

    centroids
}

/// L2 squared distance
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>()).collect())
            .collect()
    }

    #[test]
    fn test_pq_encode_decode() {
        let vectors = random_vectors(1000, 64, 42);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        // Encode and decode
        let original = &vectors[0];
        let codes = pq.encode(original);
        let decoded = pq.decode(&codes);

        // Should have same dimension
        assert_eq!(decoded.len(), original.len());

        // Decoded should be somewhat close to original (lossy compression)
        let dist = l2_distance(original, &decoded);
        assert!(
            dist < original.len() as f32 * 0.1,
            "Reconstruction error too high: {dist}"
        );
    }

    #[test]
    fn test_pq_asymmetric_distance() {
        let vectors = random_vectors(500, 32, 123);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let target = &vectors[100];

        let codes = pq.encode(target);

        // Asymmetric distance should be similar to distance to decoded
        let asym_dist = pq.asymmetric_distance(query, &codes);
        let decoded = pq.decode(&codes);
        let exact_dist = l2_distance(query, &decoded);

        // Should be very close (asymmetric uses same centroids)
        assert!(
            (asym_dist - exact_dist).abs() < 1e-5,
            "asym={asym_dist}, exact={exact_dist}"
        );
    }

    #[test]
    fn test_pq_distance_table() {
        let vectors = random_vectors(500, 32, 456);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let table = pq.create_distance_table(query);

        // Compare table-based vs direct asymmetric distance
        for target in vectors.iter().take(10) {
            let codes = pq.encode(target);
            let direct = pq.asymmetric_distance(query, &codes);
            let table_dist = pq.distance_with_table(&table, &codes);

            assert!(
                (direct - table_dist).abs() < 1e-5,
                "direct={direct}, table={table_dist}"
            );
        }
    }

    #[test]
    fn test_pq_batch_encode() {
        let vectors = random_vectors(100, 64, 789);
        let config = PQConfig::default();

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes = pq.encode_batch(&vectors);

        assert_eq!(codes.len(), vectors.len());
        for code in &codes {
            assert_eq!(code.len(), config.num_subspaces);
        }
    }

    #[test]
    fn test_pq_save_load() {
        let vectors = random_vectors(200, 64, 111);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 128,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes_before = pq.encode(&vectors[0]);

        let path = "test_pq.bin";
        pq.save(path).unwrap();

        let pq_loaded = ProductQuantizer::load(path).unwrap();
        let codes_after = pq_loaded.encode(&vectors[0]);

        assert_eq!(codes_before, codes_after);

        std::fs::remove_file(path).ok();
    }

    #[test]
    fn test_pq_stats() {
        let vectors = random_vectors(100, 128, 222);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let stats = pq.stats();

        assert_eq!(stats.dim, 128);
        assert_eq!(stats.num_subspaces, 8);
        assert_eq!(stats.num_centroids, 256);
        assert_eq!(stats.subspace_dim, 16);
        assert_eq!(stats.code_size_bytes, 8);
        assert!(stats.compression_ratio > 50.0); // 128*4 / 8 = 64x

        println!("{}", stats);
    }

    #[test]
    fn test_pq_preserves_ordering() {
        let vectors = random_vectors(500, 64, 333);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 15,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];

        // Compute true distances
        let mut true_dists: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, v)| (i, l2_distance(query, v)))
            .collect();
        true_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Compute PQ distances
        let table = pq.create_distance_table(query);
        let codes: Vec<Vec<u8>> = vectors.iter().map(|v| pq.encode(v)).collect();

        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, c)| (i, pq.distance_with_table(&table, c)))
            .collect();
        pq_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Check recall@10: how many of true top-10 appear in PQ top-10
        let true_top10: std::collections::HashSet<_> =
            true_dists.iter().take(10).map(|(i, _)| *i).collect();
        let pq_top10: std::collections::HashSet<_> =
            pq_dists.iter().take(10).map(|(i, _)| *i).collect();

        let recall: f32 = true_top10.intersection(&pq_top10).count() as f32 / 10.0;
        assert!(
            recall >= 0.5,
            "PQ recall@10 too low: {recall}. Expected >= 0.5"
        );
    }
}