// engine/ivfpq.rs
//! IVF-PQ (Inverted File with Product Quantization) Hybrid Index
//!
//! Combines IVF clustering for fast candidate retrieval with PQ compression
//! for memory-efficient storage. This is similar to the approach used by
//! FAISS and other production vector databases.
//!
//! Architecture:
//! 1. IVF partitions vectors into clusters using k-means centroids
//! 2. Within each cluster, vectors are stored as PQ codes (compressed)
//! 3. Search: find nearest clusters → scan PQ codes within those clusters
use common::{DistanceMetric, Vector};
use parking_lot::RwLock;
use rand::seq::SliceRandom;
use std::collections::HashMap;

use crate::pq::{PQConfig, ProductQuantizer};

/// Configuration for IVF-PQ hybrid index.
///
/// Tuning guide: larger `n_clusters` shrinks each inverted list (faster
/// scans, coarser recall per probe); larger `n_probe` trades latency for
/// recall. `pq_subquantizers` must evenly divide the vector dimension for
/// the PQ codebook memory estimate in `stats()` to be meaningful.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IvfPqConfig {
    /// Number of IVF clusters (coarse quantizer)
    pub n_clusters: usize,
    /// Number of clusters to probe during search
    pub n_probe: usize,
    /// Number of PQ subquantizers (also the code length in bytes per vector)
    pub pq_subquantizers: usize,
    /// Number of centroids per subquantizer (typically 256, so codes fit in u8)
    pub pq_centroids: usize,
    /// K-means iterations for IVF training
    pub ivf_iterations: usize,
    /// K-means iterations for PQ training
    pub pq_iterations: usize,
    /// Distance metric (passed through to PQ; the coarse quantizer itself
    /// always uses Euclidean distance — see `find_nearest_centroid`)
    pub metric: DistanceMetric,
}

38impl Default for IvfPqConfig {
39    fn default() -> Self {
40        Self {
41            n_clusters: 256,
42            n_probe: 8,
43            pq_subquantizers: 8,
44            pq_centroids: 256,
45            ivf_iterations: 20,
46            pq_iterations: 10,
47            metric: DistanceMetric::Euclidean,
48        }
49    }
50}
51
/// Search result from IVF-PQ index.
#[derive(Debug, Clone)]
pub struct IvfPqSearchResult {
    /// Identifier of the matched vector.
    pub id: String,
    /// Approximate similarity score from asymmetric distance computation
    /// (higher = more similar; for Euclidean this is a negated distance).
    pub score: f32,
    /// IVF cluster the match was found in (useful for debugging recall).
    pub cluster_id: usize,
}

/// Entry stored in each IVF bucket (PQ-encoded vector).
#[derive(Debug, Clone)]
struct PqEntry {
    // Original vector id, kept verbatim for result reporting.
    id: String,
    /// Residual PQ codes (vector - centroid, then PQ encoded);
    /// length equals `pq_subquantizers`, one u8 code per subquantizer.
    codes: Vec<u8>,
}

/// IVF-PQ Hybrid Index.
///
/// Training (`train`) requires `&mut self`; after training, `add` and
/// `search` take `&self` — per-list `RwLock`s allow concurrent inserts
/// and reads on different (or the same) inverted lists.
pub struct IvfPqIndex {
    config: IvfPqConfig,
    /// Vector dimensionality; `None` until `train` is called.
    dimension: Option<usize>,
    /// IVF centroids (coarse quantizer), one `Vec<f32>` per cluster.
    centroids: Vec<Vec<f32>>,
    /// PQ quantizer for residuals; `Some` only after training.
    pq: Option<ProductQuantizer>,
    /// Inverted lists: cluster_id -> list of PQ entries.
    /// Indexed positionally; length equals `centroids.len()` after training.
    inverted_lists: Vec<RwLock<Vec<PqEntry>>>,
    /// Whether the index is trained.
    trained: bool,
}

impl IvfPqIndex {
    /// Create a new, untrained IVF-PQ index with the given configuration.
    pub fn new(config: IvfPqConfig) -> Self {
        Self {
            config,
            dimension: None,
            centroids: Vec::new(),
            pq: None,
            inverted_lists: Vec::new(),
            trained: false,
        }
    }

    /// Check if the index is trained (i.e. `train` completed successfully).
    pub fn is_trained(&self) -> bool {
        self.trained
    }

    /// Get the dimension (`None` before training).
    pub fn dimension(&self) -> Option<usize> {
        self.dimension
    }

    /// Get statistics about the index.
    ///
    /// Takes a read lock on each inverted list in turn, so the counts are a
    /// point-in-time snapshot and may lag concurrent `add` calls.
    pub fn stats(&self) -> IvfPqStats {
        let mut list_sizes = Vec::with_capacity(self.inverted_lists.len());
        let mut total_vectors = 0usize;

        for list in &self.inverted_lists {
            let size = list.read().len();
            list_sizes.push(size);
            total_vectors += size;
        }

        let avg_list_size = if list_sizes.is_empty() {
            0.0
        } else {
            total_vectors as f64 / list_sizes.len() as f64
        };

        let max_list_size = list_sizes.iter().copied().max().unwrap_or(0);
        let min_list_size = list_sizes.iter().copied().min().unwrap_or(0);

        // Calculate memory usage estimate (4 bytes per f32 component).
        let centroid_memory = self.centroids.len() * self.dimension.unwrap_or(0) * 4;
        // PQ codebooks: num_subquantizers tables of num_centroids sub-vectors,
        // each of length dim / num_subquantizers. Integer division assumes the
        // subquantizer count divides the dimension — TODO confirm with PQConfig.
        let pq_memory = self
            .pq
            .as_ref()
            .map(|pq| {
                pq.config.num_subquantizers
                    * pq.config.num_centroids
                    * (self.dimension.unwrap_or(0) / pq.config.num_subquantizers)
                    * 4
            })
            .unwrap_or(0);
        // Each stored vector costs one u8 code per subquantizer.
        let codes_memory = total_vectors * self.config.pq_subquantizers;

        IvfPqStats {
            n_clusters: self.centroids.len(),
            total_vectors,
            avg_list_size,
            max_list_size,
            min_list_size,
            trained: self.trained,
            dimension: self.dimension,
            memory_bytes: centroid_memory + pq_memory + codes_memory,
        }
    }

    /// Train the index on a set of vectors
    ///
    /// This performs:
    /// 1. K-means clustering to learn IVF centroids
    /// 2. Compute residuals (vectors - their assigned centroid)
    /// 3. Train PQ on the residuals
    ///
    /// # Errors
    /// Returns an error if `vectors` is empty, if any vector's dimension
    /// differs from the first one's, or if PQ construction/training fails.
    pub fn train(&mut self, vectors: &[Vector]) -> Result<(), String> {
        if vectors.is_empty() {
            return Err("Cannot train on empty vector set".to_string());
        }

        // The first vector's length defines the index dimension.
        let dim = vectors[0].values.len();
        self.dimension = Some(dim);

        // Validate all vectors have same dimension
        for v in vectors {
            if v.values.len() != dim {
                return Err(format!(
                    "Dimension mismatch: expected {}, got {}",
                    dim,
                    v.values.len()
                ));
            }
        }

        // Step 1: Train IVF centroids using k-means.
        // Clamp cluster count so we never ask for more clusters than vectors.
        let n_clusters = self.config.n_clusters.min(vectors.len());
        self.centroids = self.kmeans_train(vectors, n_clusters)?;

        // Initialize inverted lists (one lockable bucket per cluster)
        self.inverted_lists = (0..n_clusters).map(|_| RwLock::new(Vec::new())).collect();

        // Step 2: Compute residuals for PQ training. Encoding residuals
        // (vector - assigned centroid) instead of raw vectors lets PQ model
        // the smaller within-cluster variance.
        let mut residuals = Vec::with_capacity(vectors.len());
        for v in vectors {
            let (cluster_id, _) = self.find_nearest_centroid(&v.values);
            let residual = self.compute_residual(&v.values, cluster_id);
            residuals.push(Vector {
                id: v.id.clone(),
                values: residual,
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            });
        }

        // Step 3: Train PQ on residuals
        let pq_config = PQConfig {
            num_subquantizers: self.config.pq_subquantizers,
            num_centroids: self.config.pq_centroids,
            kmeans_iterations: self.config.pq_iterations,
            distance_metric: self.config.metric,
        };

        let mut pq = ProductQuantizer::new(pq_config, dim)?;
        pq.train(&residuals)?;
        self.pq = Some(pq);

        // Only flip the flag once every step has succeeded.
        self.trained = true;
        Ok(())
    }

    /// Add vectors to the trained index.
    ///
    /// Returns the number of vectors actually added. Vectors whose dimension
    /// does not match the index are silently skipped (so the return value can
    /// be smaller than `vectors.len()`).
    ///
    /// # Errors
    /// Returns an error if the index is untrained or PQ encoding fails.
    pub fn add(&self, vectors: &[Vector]) -> Result<usize, String> {
        if !self.trained {
            return Err("Index must be trained before adding vectors".to_string());
        }

        let pq = self.pq.as_ref().ok_or("PQ not initialized")?;
        let dim = self.dimension.ok_or("Dimension not set")?;

        let mut added = 0;
        for v in vectors {
            // Skip (rather than fail) dimension-mismatched vectors.
            if v.values.len() != dim {
                continue;
            }

            // Find nearest centroid
            let (cluster_id, _) = self.find_nearest_centroid(&v.values);

            // Compute residual
            let residual = self.compute_residual(&v.values, cluster_id);

            // Encode residual with PQ
            let codes = pq.encode(&residual)?;

            // Add to inverted list (write lock held only for the push)
            let entry = PqEntry {
                id: v.id.clone(),
                codes,
            };
            self.inverted_lists[cluster_id].write().push(entry);
            added += 1;
        }

        Ok(added)
    }

    /// Search for nearest neighbors.
    ///
    /// Probes the `n_probe` nearest IVF clusters and scores their entries via
    /// asymmetric distance computation (ADC) against the per-cluster query
    /// residual. Returns up to `k` results sorted by descending score.
    ///
    /// # Errors
    /// Returns an error if the index is untrained or the query dimension
    /// does not match the index dimension.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<IvfPqSearchResult>, String> {
        if !self.trained {
            return Err("Index must be trained before searching".to_string());
        }

        let pq = self.pq.as_ref().ok_or("PQ not initialized")?;
        let dim = self.dimension.ok_or("Dimension not set")?;

        if query.len() != dim {
            return Err(format!(
                "Query dimension {} doesn't match index dimension {}",
                query.len(),
                dim
            ));
        }

        // Find n_probe nearest clusters (clamped to the cluster count)
        let n_probe = self.config.n_probe.min(self.centroids.len());
        let nearest_clusters = self.find_nearest_centroids(query, n_probe);

        // Collect candidates from all probed clusters
        let mut candidates: Vec<IvfPqSearchResult> = Vec::new();

        for (cluster_id, _) in nearest_clusters {
            // Compute residual for this cluster — the residual differs per
            // cluster, so the distance table must be rebuilt for each probe.
            let query_residual = self.compute_residual(query, cluster_id);

            // Precompute distance table for asymmetric distance computation
            let distance_table = pq.compute_distance_table(&query_residual)?;

            // Scan inverted list (read lock held for the duration of the scan)
            let list = self.inverted_lists[cluster_id].read();
            for entry in list.iter() {
                // Compute approximate distance using ADC (asymmetric distance computation)
                // Note: For Euclidean, ADC returns negative distance (higher = more similar)
                let score = pq.compute_distance_adc(&distance_table, &entry.codes);

                candidates.push(IvfPqSearchResult {
                    id: entry.id.clone(),
                    score, // ADC already returns similarity-like score
                    cluster_id,
                });
            }
        }

        // Sort by score (descending) and take top k.
        // partial_cmp fallback treats NaN scores as equal rather than panicking.
        candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        candidates.truncate(k);

        Ok(candidates)
    }

    /// Find the nearest centroid to a vector.
    ///
    /// Linear scan over all centroids; returns `(index, distance)`.
    /// Returns `(0, f32::MAX)` if there are no centroids — callers only
    /// invoke this after training, when centroids exist.
    fn find_nearest_centroid(&self, vector: &[f32]) -> (usize, f32) {
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;

        for (idx, centroid) in self.centroids.iter().enumerate() {
            let dist = euclidean_distance(vector, centroid);
            if dist < best_dist {
                best_dist = dist;
                best_idx = idx;
            }
        }

        (best_idx, best_dist)
    }

    /// Find the n nearest centroids, sorted by ascending distance.
    fn find_nearest_centroids(&self, vector: &[f32], n: usize) -> Vec<(usize, f32)> {
        let mut distances: Vec<(usize, f32)> = self
            .centroids
            .iter()
            .enumerate()
            .map(|(idx, centroid)| (idx, euclidean_distance(vector, centroid)))
            .collect();

        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        distances.truncate(n);
        distances
    }

    /// Compute residual (vector - centroid), element-wise.
    /// `zip` truncates to the shorter slice, so the caller must ensure
    /// `vector` has the index dimension.
    fn compute_residual(&self, vector: &[f32], cluster_id: usize) -> Vec<f32> {
        let centroid = &self.centroids[cluster_id];
        vector
            .iter()
            .zip(centroid.iter())
            .map(|(v, c)| v - c)
            .collect()
    }

    /// Train k-means clustering (Lloyd's algorithm) and return `k` centroids.
    ///
    /// Initialization is random selection of `k` distinct input vectors.
    /// Iterates up to `config.ivf_iterations` times, stopping early once no
    /// centroid moves by more than 1e-4 (Euclidean).
    fn kmeans_train(&self, vectors: &[Vector], k: usize) -> Result<Vec<Vec<f32>>, String> {
        if vectors.is_empty() || k == 0 {
            return Err("Invalid input for k-means".to_string());
        }

        let dim = vectors[0].values.len();
        let mut rng = rand::thread_rng();

        // Initialize centroids by random selection (shuffle, take first k)
        let mut indices: Vec<usize> = (0..vectors.len()).collect();
        indices.shuffle(&mut rng);
        let mut centroids: Vec<Vec<f32>> = indices
            .iter()
            .take(k)
            .map(|&i| vectors[i].values.clone())
            .collect();

        // Ensure we have k centroids. Defensive: unreachable when
        // k <= vectors.len(), which holds for the call from `train`.
        while centroids.len() < k {
            centroids.push(vec![0.0; dim]);
        }

        // K-means iterations
        for _ in 0..self.config.ivf_iterations {
            // Assignment step: map each vector index to its nearest centroid.
            // All clusters are pre-seeded so empty ones still appear.
            let mut assignments: HashMap<usize, Vec<usize>> = HashMap::new();
            for cluster_id in 0..k {
                assignments.insert(cluster_id, Vec::new());
            }

            for (vec_idx, v) in vectors.iter().enumerate() {
                let mut best_cluster = 0;
                let mut best_dist = f32::MAX;

                for (cluster_id, centroid) in centroids.iter().enumerate() {
                    let dist = euclidean_distance(&v.values, centroid);
                    if dist < best_dist {
                        best_dist = dist;
                        best_cluster = cluster_id;
                    }
                }

                if let Some(members) = assignments.get_mut(&best_cluster) {
                    members.push(vec_idx);
                }
            }

            // Update step: move each non-empty cluster's centroid to the mean
            // of its members. Empty clusters keep their previous centroid.
            let mut converged = true;
            for (cluster_id, member_indices) in &assignments {
                if member_indices.is_empty() {
                    continue;
                }

                let mut new_centroid = vec![0.0; dim];
                for &idx in member_indices {
                    for (j, val) in vectors[idx].values.iter().enumerate() {
                        new_centroid[j] += val;
                    }
                }
                for val in &mut new_centroid {
                    *val /= member_indices.len() as f32;
                }

                // Check convergence: any centroid moving more than the
                // threshold forces another iteration.
                let diff = euclidean_distance(&centroids[*cluster_id], &new_centroid);
                if diff > 1e-4 {
                    converged = false;
                }

                centroids[*cluster_id] = new_centroid;
            }

            if converged {
                break;
            }
        }

        Ok(centroids)
    }
}

/// Statistics about the IVF-PQ index, as returned by [`IvfPqIndex::stats`].
#[derive(Debug, Clone)]
pub struct IvfPqStats {
    /// Number of IVF clusters (coarse centroids).
    pub n_clusters: usize,
    /// Total vectors across all inverted lists.
    pub total_vectors: usize,
    /// Mean inverted-list length (0.0 when there are no lists).
    pub avg_list_size: f64,
    /// Largest inverted-list length.
    pub max_list_size: usize,
    /// Smallest inverted-list length.
    pub min_list_size: usize,
    /// Whether the index has been trained.
    pub trained: bool,
    /// Vector dimension, or `None` before training.
    pub dimension: Option<usize>,
    /// Rough memory estimate: centroids + PQ codebooks + stored codes.
    pub memory_bytes: usize,
}

/// Euclidean (L2) distance between two equal-length f32 slices.
///
/// In debug builds, panics if the slices differ in length. In release
/// builds a mismatch would silently truncate to the shorter slice (a
/// property of `zip`), so callers must guarantee equal dimensions —
/// the assert makes that contract visible during testing.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "euclidean_distance: dimension mismatch");
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` random vectors of dimension `dim` with values in [-1, 1).
    /// Seeded RNG keeps every test run deterministic.
    fn create_test_vectors(n: usize, dim: usize) -> Vec<Vector> {
        use rand::Rng;
        use rand::SeedableRng;

        // Use seeded RNG for reproducibility
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        (0..n)
            .map(|i| {
                let values: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
                Vector {
                    id: format!("v{}", i),
                    values,
                    metadata: None,
                    ttl_seconds: None,
                    expires_at: None,
                }
            })
            .collect()
    }

    // A freshly constructed index is untrained and dimensionless.
    #[test]
    fn test_ivfpq_creation() {
        let config = IvfPqConfig::default();
        let index = IvfPqIndex::new(config);

        assert!(!index.is_trained());
        assert_eq!(index.dimension(), None);
    }

    // Training succeeds on consistent vectors and records the dimension.
    #[test]
    fn test_ivfpq_training() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 2,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };

        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(100, 16);

        let result = index.train(&vectors);
        assert!(result.is_ok(), "Training failed: {:?}", result.err());
        assert!(index.is_trained());
        assert_eq!(index.dimension(), Some(16));
    }

    // End-to-end: train, add, then search should surface the query itself.
    #[test]
    fn test_ivfpq_add_and_search() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 4, // Probe all clusters to ensure we find the vector
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };

        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(100, 16);

        // Train
        index.train(&vectors).unwrap();

        // Add vectors
        let added = index.add(&vectors).unwrap();
        assert_eq!(added, 100);

        // Search
        let query = &vectors[0].values;
        let results = index.search(query, 10).unwrap();

        assert!(!results.is_empty(), "Results should not be empty");

        // The query vector itself should be among top results
        let found_self = results.iter().any(|r| r.id == "v0");
        assert!(
            found_self,
            "Query vector should be found in results. Got: {:?}",
            results.iter().map(|r| &r.id).collect::<Vec<_>>()
        );
    }

    // stats() reflects cluster count, vector count, and a nonzero memory estimate.
    #[test]
    fn test_ivfpq_stats() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 2,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };

        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(100, 16);

        index.train(&vectors).unwrap();
        index.add(&vectors).unwrap();

        let stats = index.stats();
        assert_eq!(stats.n_clusters, 4);
        assert_eq!(stats.total_vectors, 100);
        assert!(stats.trained);
        assert_eq!(stats.dimension, Some(16));
        assert!(stats.memory_bytes > 0);
    }

    // Approximate recall check: querying stored vectors should usually
    // return them within the top 20 (loose 50% bound tolerates PQ loss).
    #[test]
    fn test_ivfpq_search_quality() {
        let config = IvfPqConfig {
            n_clusters: 8,
            n_probe: 8, // Probe all clusters for this test
            pq_subquantizers: 4,
            pq_centroids: 16,
            ivf_iterations: 10,
            pq_iterations: 10,
            metric: DistanceMetric::Euclidean,
        };

        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(200, 32);

        index.train(&vectors).unwrap();
        index.add(&vectors).unwrap();

        // Search for multiple queries and check recall
        let mut total_recall = 0.0;
        let test_queries = 10;

        for i in 0..test_queries {
            let query = &vectors[i * 10].values;
            let results = index.search(query, 20).unwrap();

            // Check if the exact match is in results
            let expected_id = format!("v{}", i * 10);
            if results.iter().any(|r| r.id == expected_id) {
                total_recall += 1.0;
            }
        }

        let recall = total_recall / test_queries as f32;
        assert!(
            recall >= 0.5,
            "Recall should be at least 50%, got {}%",
            recall * 100.0
        );
    }

    // Searching a trained-but-empty index returns Ok with no results.
    #[test]
    fn test_ivfpq_empty_search() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 2,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };

        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(50, 16);

        index.train(&vectors).unwrap();
        // Don't add any vectors

        let query = &vectors[0].values;
        let results = index.search(query, 5).unwrap();

        assert!(results.is_empty());
    }

    // Searching before training is an error mentioning "trained".
    #[test]
    fn test_ivfpq_untrained_error() {
        let index = IvfPqIndex::new(IvfPqConfig::default());

        let result = index.search(&[0.0; 128], 5);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("trained"));
    }
}