diskann_rs/pq.rs

//! # Product Quantization (PQ) for Vector Compression
//!
//! Product Quantization compresses high-dimensional vectors by:
//! 1. Dividing each vector into M subspaces (segments)
//! 2. Training K centroids per subspace (codebook)
//! 3. Representing each segment by its nearest centroid ID (1 byte for K=256)
//!
//! ## Compression Ratio
//!
//! For 128-dim float vectors (512 bytes) with M=8 subspaces:
//! - Original: 512 bytes
//! - Compressed: 8 bytes (one centroid ID per subspace)
//! - Compression: 64x
//!
//! ## Usage
//!
//! ```ignore
//! use diskann_rs::pq::{ProductQuantizer, PQConfig};
//!
//! // Train a quantizer on sample vectors
//! let vectors: Vec<Vec<f32>> = load_your_training_data();
//! let config = PQConfig::default(); // 8 subspaces, 256 centroids each
//! let pq = ProductQuantizer::train(&vectors, config).unwrap();
//!
//! // Encode vectors (each becomes M bytes)
//! let codes: Vec<Vec<u8>> = vectors.iter().map(|v| pq.encode(v)).collect();
//!
//! // Compute asymmetric distance (query vs quantized database vector)
//! let query = vec![0.0f32; 128];
//! let dist = pq.asymmetric_distance(&query, &codes[0]);
//! ```
//!
//! ## Asymmetric Distance Computation (ADC)
//!
//! For search, we compute exact query-to-centroid distances once,
//! then use a lookup table for fast distance approximation.
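//!
//! A minimal sketch of that table-based flow, assuming the `pq`, `query`, and
//! `codes` values from the usage example above:
//!
//! ```ignore
//! // Build the lookup table once per query: one entry per (subspace, centroid).
//! let table = pq.create_distance_table(&query);
//!
//! // Score each quantized vector with M table lookups.
//! let dists: Vec<f32> = codes
//!     .iter()
//!     .map(|c| pq.distance_with_table(&table, c))
//!     .collect();
//! ```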

use crate::DiskAnnError;
use rand::prelude::*;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};

/// Configuration for Product Quantization
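///
/// A minimal sketch of overriding a single field while keeping the other
/// defaults (16 subspaces is just an illustrative choice):
///
/// ```ignore
/// use diskann_rs::pq::PQConfig;
///
/// let config = PQConfig {
///     num_subspaces: 16,
///     ..PQConfig::default()
/// };
/// ```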
#[derive(Clone, Copy, Debug)]
pub struct PQConfig {
    /// Number of subspaces (M). Vector is divided into M segments.
    /// Typical values: 4, 8, 16, 32
    pub num_subspaces: usize,
    /// Number of centroids per subspace (K). Typically 256 (fits in u8).
    pub num_centroids: usize,
    /// Number of k-means iterations for training
    pub kmeans_iterations: usize,
    /// Sample size for training (if 0, use all vectors)
    pub training_sample_size: usize,
}

impl Default for PQConfig {
    fn default() -> Self {
        Self {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 20,
            training_sample_size: 50_000,
        }
    }
}

/// Trained Product Quantizer
#[derive(Serialize, Deserialize, Clone)]
pub struct ProductQuantizer {
    /// Dimension of original vectors
    dim: usize,
    /// Number of subspaces
    num_subspaces: usize,
    /// Number of centroids per subspace
    num_centroids: usize,
    /// Dimension of each subspace
    subspace_dim: usize,
    /// Codebooks: `[num_subspaces][num_centroids][subspace_dim]`
    /// Flattened for cache efficiency
    codebooks: Vec<f32>,
}

impl ProductQuantizer {
    /// Train a product quantizer on a set of vectors
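    ///
    /// Returns an error if no vectors are provided or if the vector dimension
    /// is not divisible by `config.num_subspaces`.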
    pub fn train(vectors: &[Vec<f32>], config: PQConfig) -> Result<Self, DiskAnnError> {
        if vectors.is_empty() {
            return Err(DiskAnnError::IndexError("No vectors to train on".into()));
        }

        let dim = vectors[0].len();
        if dim % config.num_subspaces != 0 {
            return Err(DiskAnnError::IndexError(format!(
                "Dimension {} not divisible by num_subspaces {}",
                dim, config.num_subspaces
            )));
        }

        let subspace_dim = dim / config.num_subspaces;

        // Sample training vectors if needed
        let training_vectors: Vec<&Vec<f32>> = if config.training_sample_size > 0
            && vectors.len() > config.training_sample_size
        {
            let mut rng = thread_rng();
            vectors
                .choose_multiple(&mut rng, config.training_sample_size)
                .collect()
        } else {
            vectors.iter().collect()
        };

        // Train codebook for each subspace (parallel)
        let codebooks_per_subspace: Vec<Vec<f32>> = (0..config.num_subspaces)
            .into_par_iter()
            .map(|m| {
                // Extract subspace vectors
                let start = m * subspace_dim;
                let end = start + subspace_dim;

                let subspace_vectors: Vec<Vec<f32>> = training_vectors
                    .iter()
                    .map(|v| v[start..end].to_vec())
                    .collect();

                // Run k-means
                kmeans(
                    &subspace_vectors,
                    config.num_centroids,
                    config.kmeans_iterations,
                )
            })
            .collect();

        // Flatten codebooks for cache efficiency
        let mut codebooks =
            Vec::with_capacity(config.num_subspaces * config.num_centroids * subspace_dim);
        for cb in &codebooks_per_subspace {
            codebooks.extend_from_slice(cb);
        }

        Ok(Self {
            dim,
            num_subspaces: config.num_subspaces,
            num_centroids: config.num_centroids,
            subspace_dim,
            codebooks,
        })
    }

    /// Encode a vector into PQ codes (M bytes)
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");

        let mut codes = Vec::with_capacity(self.num_subspaces);

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let subvec = &vector[start..end];

            // Find nearest centroid
            let mut best_centroid = 0u8;
            let mut best_dist = f32::MAX;

            for k in 0..self.num_centroids {
                let centroid = self.get_centroid(m, k);
                let dist = l2_distance(subvec, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_centroid = k as u8;
                }
            }

            codes.push(best_centroid);
        }

        codes
    }

    /// Batch encode vectors (parallel)
    pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<u8>> {
        vectors.par_iter().map(|v| self.encode(v)).collect()
    }

    /// Compute the asymmetric distance between a query and a quantized vector.
    ///
    /// This computes the per-subspace query-to-centroid distances directly; to score
    /// many codes against the same query, build a table once with
    /// [`create_distance_table`](Self::create_distance_table) and use
    /// [`distance_with_table`](Self::distance_with_table).
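    ///
    /// Concretely, with subspace width `d = dim / num_subspaces`, this returns
    /// `Σ_m ||query[m*d..(m+1)*d] - codebook[m][codes[m]]||²` (squared L2 per
    /// subspace), which is exactly the squared L2 distance between the query and
    /// the decoded (reconstructed) vector.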
    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        assert_eq!(query.len(), self.dim, "Query dimension mismatch");
        assert_eq!(codes.len(), self.num_subspaces, "Code length mismatch");

        let mut total_dist = 0.0f32;

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let query_sub = &query[start..end];

            let centroid_id = codes[m] as usize;
            let centroid = self.get_centroid(m, centroid_id);

            total_dist += l2_distance(query_sub, centroid);
        }

        total_dist
    }

    /// Create a distance lookup table for a query (for fast batch queries)
    ///
    /// Returns: `[num_subspaces][num_centroids]` distance table
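    ///
    /// Layout note: the table is flattened row-major, so the entry for subspace
    /// `m` and centroid `k` lives at index `m * num_centroids + k` (the same
    /// indexing `distance_with_table` uses).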
    pub fn create_distance_table(&self, query: &[f32]) -> Vec<f32> {
        assert_eq!(query.len(), self.dim);

        let mut table = Vec::with_capacity(self.num_subspaces * self.num_centroids);

        for m in 0..self.num_subspaces {
            let start = m * self.subspace_dim;
            let end = start + self.subspace_dim;
            let query_sub = &query[start..end];

            for k in 0..self.num_centroids {
                let centroid = self.get_centroid(m, k);
                table.push(l2_distance(query_sub, centroid));
            }
        }

        table
    }

    /// Compute distance using precomputed table (very fast)
    #[inline]
    pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
        let mut dist = 0.0f32;
        for (m, &code) in codes.iter().enumerate() {
            let idx = m * self.num_centroids + code as usize;
            dist += table[idx];
        }
        dist
    }

    /// Decode PQ codes back to approximate vector
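    ///
    /// The result is the concatenation of each subspace's selected centroid, so it
    /// is only an approximation of the originally encoded vector (PQ is lossy).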
    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
        assert_eq!(codes.len(), self.num_subspaces);

        let mut vector = Vec::with_capacity(self.dim);

        for (m, &code) in codes.iter().enumerate() {
            let centroid = self.get_centroid(m, code as usize);
            vector.extend_from_slice(centroid);
        }

        vector
    }

    /// Get centroid for subspace m, centroid k
    #[inline]
    fn get_centroid(&self, m: usize, k: usize) -> &[f32] {
        let offset = (m * self.num_centroids + k) * self.subspace_dim;
        &self.codebooks[offset..offset + self.subspace_dim]
    }

    /// Save quantizer to file
    pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
        let file = File::create(path)?;
        let writer = BufWriter::new(file);
        bincode::serialize_into(writer, self)?;
        Ok(())
    }

    /// Load quantizer from file
    pub fn load(path: &str) -> Result<Self, DiskAnnError> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let pq: Self = bincode::deserialize_from(reader)?;
        Ok(pq)
    }

    /// Get stats about the quantizer
    pub fn stats(&self) -> PQStats {
        PQStats {
            dim: self.dim,
            num_subspaces: self.num_subspaces,
            num_centroids: self.num_centroids,
            subspace_dim: self.subspace_dim,
            codebook_size_bytes: self.codebooks.len() * 4,
            code_size_bytes: self.num_subspaces,
            compression_ratio: (self.dim * 4) as f32 / self.num_subspaces as f32,
        }
    }
}

/// Statistics about a ProductQuantizer
#[derive(Debug, Clone)]
pub struct PQStats {
    pub dim: usize,
    pub num_subspaces: usize,
    pub num_centroids: usize,
    pub subspace_dim: usize,
    pub codebook_size_bytes: usize,
    pub code_size_bytes: usize,
    pub compression_ratio: f32,
}

impl std::fmt::Display for PQStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Product Quantizer Stats:")?;
        writeln!(f, "  Original dimension: {}", self.dim)?;
        writeln!(f, "  Subspaces (M): {}", self.num_subspaces)?;
        writeln!(f, "  Centroids per subspace (K): {}", self.num_centroids)?;
        writeln!(f, "  Subspace dimension: {}", self.subspace_dim)?;
        writeln!(f, "  Codebook size: {} bytes", self.codebook_size_bytes)?;
        writeln!(f, "  Compressed code size: {} bytes", self.code_size_bytes)?;
        writeln!(f, "  Compression ratio: {:.1}x", self.compression_ratio)
    }
}

/// Simple k-means clustering
/// Returns flattened centroids of shape `[k * dim]`
/// Note: Always returns exactly `k` centroids (replicates if n < k)
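/// Seeding is k-means++-style (distance-weighted), followed by `iterations` rounds
/// of Lloyd's assignment/update; empty clusters are reseeded from a random point.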
fn kmeans(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<f32> {
    if vectors.is_empty() {
        return vec![0.0; k]; // shouldn't happen, but safe fallback
    }

    let dim = vectors[0].len();
    let n = vectors.len();
    let effective_k = k.min(n); // For clustering, use available vectors

    // Initialize centroids with k-means++ style
    let mut centroids = Vec::with_capacity(k * dim);
    let mut rng = thread_rng();

    // First centroid: random vector
    let first = rng.gen_range(0..n);
    centroids.extend_from_slice(&vectors[first]);

    // Remaining centroids: weighted by distance to nearest existing centroid (k-means++)
    // Only pick effective_k - 1 more this way (limited by the number of training vectors)
    for _ in 1..effective_k {
        let num_current = centroids.len() / dim;
        let distances: Vec<f32> = vectors
            .iter()
            .map(|v| {
                let mut min_dist = f32::MAX;
                for c in 0..num_current {
                    let centroid = &centroids[c * dim..(c + 1) * dim];
                    let d = l2_distance(v, centroid);
                    min_dist = min_dist.min(d);
                }
                min_dist
            })
            .collect();

        // Sample weighted by distance
        let total: f32 = distances.iter().sum();
        if total == 0.0 {
            // All points coincide with existing centroids: pick a random one
            let idx = rng.gen_range(0..n);
            centroids.extend_from_slice(&vectors[idx]);
        } else {
            let threshold = rng.r#gen::<f32>() * total;
            let mut cumsum = 0.0;
            let mut picked = false;
            for (i, &d) in distances.iter().enumerate() {
                cumsum += d;
                if cumsum >= threshold {
                    centroids.extend_from_slice(&vectors[i]);
                    picked = true;
                    break;
                }
            }
            // Fallback if we didn't pick one (can happen with float precision)
            if !picked {
                centroids.extend_from_slice(&vectors[n - 1]);
            }
        }
    }

    // If k > n, replicate existing centroids to reach k
    while centroids.len() < k * dim {
        // Cycle through existing centroids
        let idx = (centroids.len() / dim) % effective_k;
        let centroid = centroids[idx * dim..(idx + 1) * dim].to_vec();
        centroids.extend_from_slice(&centroid);
    }
    centroids.truncate(k * dim);

    // Lloyd's algorithm iterations
    for _ in 0..iterations {
        // Assignment step (parallel)
        let assignments: Vec<usize> = vectors
            .par_iter()
            .map(|v| {
                let mut best_c = 0;
                let mut best_dist = f32::MAX;
                for c in 0..k {
                    let centroid = &centroids[c * dim..(c + 1) * dim];
                    let d = l2_distance(v, centroid);
                    if d < best_dist {
                        best_dist = d;
                        best_c = c;
                    }
                }
                best_c
            })
            .collect();

        // Update step
        let mut new_centroids = vec![0.0f32; k * dim];
        let mut counts = vec![0usize; k];

        for (i, &c) in assignments.iter().enumerate() {
            counts[c] += 1;
            for (j, &val) in vectors[i].iter().enumerate() {
                new_centroids[c * dim + j] += val;
            }
        }

        // Average
        for c in 0..k {
            if counts[c] > 0 {
                for j in 0..dim {
                    new_centroids[c * dim + j] /= counts[c] as f32;
                }
            } else {
                // Empty cluster: reinitialize from random point
                let idx = rng.gen_range(0..n);
                for j in 0..dim {
                    new_centroids[c * dim + j] = vectors[idx][j];
                }
            }
        }

        centroids = new_centroids;
    }

    centroids
}

/// L2 squared distance
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>()).collect())
            .collect()
    }

    #[test]
    fn test_pq_encode_decode() {
        let vectors = random_vectors(1000, 64, 42);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        // Encode and decode
        let original = &vectors[0];
        let codes = pq.encode(original);
        let decoded = pq.decode(&codes);

        // Should have same dimension
        assert_eq!(decoded.len(), original.len());

        // Decoded should be somewhat close to original (lossy compression)
        let dist = l2_distance(original, &decoded);
        assert!(
            dist < original.len() as f32 * 0.1,
            "Reconstruction error too high: {dist}"
        );
    }

    #[test]
    fn test_pq_asymmetric_distance() {
        let vectors = random_vectors(500, 32, 123);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let target = &vectors[100];

        let codes = pq.encode(target);

        // Asymmetric distance should be similar to distance to decoded
        let asym_dist = pq.asymmetric_distance(query, &codes);
        let decoded = pq.decode(&codes);
        let exact_dist = l2_distance(query, &decoded);

        // Should be very close (asymmetric uses same centroids)
        assert!(
            (asym_dist - exact_dist).abs() < 1e-5,
            "asym={asym_dist}, exact={exact_dist}"
        );
    }

    #[test]
    fn test_pq_distance_table() {
        let vectors = random_vectors(500, 32, 456);
        let config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];
        let table = pq.create_distance_table(query);

        // Compare table-based vs direct asymmetric distance
        for target in vectors.iter().take(10) {
            let codes = pq.encode(target);
            let direct = pq.asymmetric_distance(query, &codes);
            let table_dist = pq.distance_with_table(&table, &codes);

            assert!(
                (direct - table_dist).abs() < 1e-5,
                "direct={direct}, table={table_dist}"
            );
        }
    }

    #[test]
    fn test_pq_batch_encode() {
        let vectors = random_vectors(100, 64, 789);
        let config = PQConfig::default();

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes = pq.encode_batch(&vectors);

        assert_eq!(codes.len(), vectors.len());
        for code in &codes {
            assert_eq!(code.len(), config.num_subspaces);
        }
    }

    #[test]
    fn test_pq_save_load() {
        let vectors = random_vectors(200, 64, 111);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 128,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let codes_before = pq.encode(&vectors[0]);

        let path = "test_pq.bin";
        pq.save(path).unwrap();

        let pq_loaded = ProductQuantizer::load(path).unwrap();
        let codes_after = pq_loaded.encode(&vectors[0]);

        assert_eq!(codes_before, codes_after);

        std::fs::remove_file(path).ok();
    }

    #[test]
    fn test_pq_stats() {
        let vectors = random_vectors(100, 128, 222);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 5,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();
        let stats = pq.stats();

        assert_eq!(stats.dim, 128);
        assert_eq!(stats.num_subspaces, 8);
        assert_eq!(stats.num_centroids, 256);
        assert_eq!(stats.subspace_dim, 16);
        assert_eq!(stats.code_size_bytes, 8);
        assert!(stats.compression_ratio > 50.0); // 128*4 / 8 = 64x

        println!("{}", stats);
    }

    #[test]
    fn test_pq_preserves_ordering() {
        let vectors = random_vectors(500, 64, 333);
        let config = PQConfig {
            num_subspaces: 8,
            num_centroids: 256,
            kmeans_iterations: 15,
            training_sample_size: 0,
        };

        let pq = ProductQuantizer::train(&vectors, config).unwrap();

        let query = &vectors[0];

        // Compute true distances
        let mut true_dists: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, v)| (i, l2_distance(query, v)))
            .collect();
        true_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Compute PQ distances
        let table = pq.create_distance_table(query);
        let codes: Vec<Vec<u8>> = vectors.iter().map(|v| pq.encode(v)).collect();

        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, c)| (i, pq.distance_with_table(&table, c)))
            .collect();
        pq_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Check recall@10: how many of true top-10 appear in PQ top-10
        let true_top10: std::collections::HashSet<_> =
            true_dists.iter().take(10).map(|(i, _)| *i).collect();
        let pq_top10: std::collections::HashSet<_> =
            pq_dists.iter().take(10).map(|(i, _)| *i).collect();

        let recall: f32 = true_top10.intersection(&pq_top10).count() as f32 / 10.0;
        assert!(
            recall >= 0.5,
            "PQ recall@10 too low: {recall}. Expected >= 0.5"
        );
    }
}