Skip to main content

oxirs_vec/
ivf.rs

//! Inverted File (IVF) index for approximate nearest-neighbor search.
//!
//! IVF is a clustering-based index that partitions the vector space into Voronoi
//! cells. Each cell has a centroid, and every vector is assigned to its nearest
//! centroid. A search examines only a subset of the cells (the `n_probes` cells
//! nearest to the query), which greatly reduces search time at some cost in accuracy.
7
8use crate::{
9    pq::{PQConfig, PQIndex},
10    Vector, VectorIndex,
11};
12use anyhow::{anyhow, Result};
13use std::sync::{Arc, RwLock};
14
15/// Quantization strategy for residuals
16#[derive(Debug, Clone, PartialEq)]
17pub enum QuantizationStrategy {
18    /// No quantization - store full residuals
19    None,
20    /// Single-level product quantization
21    ProductQuantization(PQConfig),
22    /// Residual quantization with multiple levels
23    ResidualQuantization {
24        levels: usize,
25        pq_configs: Vec<PQConfig>,
26    },
27    /// Multi-codebook quantization for improved accuracy
28    MultiCodebook {
29        num_codebooks: usize,
30        pq_configs: Vec<PQConfig>,
31    },
32}
33
34/// Configuration for IVF index
35#[derive(Debug, Clone)]
36pub struct IvfConfig {
37    /// Number of clusters (Voronoi cells)
38    pub n_clusters: usize,
39    /// Number of probes during search (cells to examine)
40    pub n_probes: usize,
41    /// Maximum iterations for k-means clustering
42    pub max_iterations: usize,
43    /// Convergence threshold for k-means
44    pub convergence_threshold: f32,
45    /// Random seed for reproducibility
46    pub seed: Option<u64>,
47    /// Quantization strategy for residuals
48    pub quantization: QuantizationStrategy,
49    /// Enable residual quantization for compression (deprecated - use quantization field)
50    pub enable_residual_quantization: bool,
51    /// Product quantization configuration for residuals (deprecated - use quantization field)
52    pub pq_config: Option<PQConfig>,
53}
54
55impl Default for IvfConfig {
56    fn default() -> Self {
57        Self {
58            n_clusters: 256,
59            n_probes: 8,
60            max_iterations: 100,
61            convergence_threshold: 1e-4,
62            seed: None,
63            quantization: QuantizationStrategy::None,
64            enable_residual_quantization: false,
65            pq_config: None,
66        }
67    }
68}
69
70/// Storage format for vectors in inverted lists
71#[derive(Debug, Clone)]
72enum VectorStorage {
73    /// Store full vectors
74    Full(Vector),
75    /// Store quantized residuals with PQ codes
76    Quantized(Vec<u8>),
77    /// Store multi-level residual quantization codes
78    MultiLevelQuantized {
79        levels: Vec<Vec<u8>>,           // PQ codes for each quantization level
80        final_residual: Option<Vector>, // Optional final unquantized residual
81    },
82    /// Store multi-codebook quantization codes
83    MultiCodebook {
84        codebooks: Vec<Vec<u8>>, // PQ codes from different codebooks
85        weights: Vec<f32>,       // Weights for combining codebook predictions
86    },
87}
88
89/// Inverted list storing vectors for a single cluster
90#[derive(Debug, Clone)]
91struct InvertedList {
92    /// Vectors in this cluster with their storage format
93    vectors: Vec<(String, VectorStorage)>,
94    /// Quantization strategy used for this list
95    quantization: QuantizationStrategy,
96    /// Product quantizer for single-level quantization
97    pq_index: Option<PQIndex>,
98    /// Multiple PQ indexes for multi-level residual quantization
99    multi_level_pq: Vec<PQIndex>,
100    /// Multiple PQ indexes for multi-codebook quantization
101    multi_codebook_pq: Vec<PQIndex>,
102    /// Codebook weights for multi-codebook quantization
103    codebook_weights: Vec<f32>,
104}
105
106impl InvertedList {
107    fn new() -> Self {
108        Self {
109            vectors: Vec::new(),
110            quantization: QuantizationStrategy::None,
111            pq_index: None,
112            multi_level_pq: Vec::new(),
113            multi_codebook_pq: Vec::new(),
114            codebook_weights: Vec::new(),
115        }
116    }
117
118    fn new_with_quantization(quantization: QuantizationStrategy) -> Result<Self> {
119        let mut list = Self {
120            vectors: Vec::new(),
121            quantization: quantization.clone(),
122            pq_index: None,
123            multi_level_pq: Vec::new(),
124            multi_codebook_pq: Vec::new(),
125            codebook_weights: Vec::new(),
126        };
127
128        match quantization {
129            QuantizationStrategy::None => {}
130            QuantizationStrategy::ProductQuantization(pq_config) => {
131                list.pq_index = Some(PQIndex::new(pq_config));
132            }
133            QuantizationStrategy::ResidualQuantization {
134                levels: _,
135                ref pq_configs,
136            } => {
137                for pq_config in pq_configs {
138                    list.multi_level_pq.push(PQIndex::new(pq_config.clone()));
139                }
140            }
141            QuantizationStrategy::MultiCodebook {
142                num_codebooks,
143                ref pq_configs,
144            } => {
145                for pq_config in pq_configs {
146                    list.multi_codebook_pq.push(PQIndex::new(pq_config.clone()));
147                }
148                // Initialize equal weights for all codebooks
149                list.codebook_weights = vec![1.0 / num_codebooks as f32; num_codebooks];
150            }
151        }
152
153        Ok(list)
154    }
155
156    // Deprecated - use new_with_quantization instead
157    fn new_with_pq(pq_config: PQConfig) -> Result<Self> {
158        Self::new_with_quantization(QuantizationStrategy::ProductQuantization(pq_config))
159    }
160
161    fn add_full(&mut self, uri: String, vector: Vector) {
162        self.vectors.push((uri, VectorStorage::Full(vector)));
163    }
164
165    fn add_residual(&mut self, uri: String, residual: Vector, _centroid: &Vector) -> Result<()> {
166        match &self.quantization {
167            QuantizationStrategy::ProductQuantization(_) => {
168                if let Some(ref mut pq_index) = self.pq_index {
169                    // Train PQ on residuals if not already trained
170                    if !pq_index.is_trained() {
171                        let training_residuals = vec![residual.clone()];
172                        pq_index.train(&training_residuals)?;
173                    }
174
175                    let codes = pq_index.encode(&residual)?;
176                    self.vectors.push((uri, VectorStorage::Quantized(codes)));
177                } else {
178                    return Err(anyhow!(
179                        "PQ index not initialized for residual quantization"
180                    ));
181                }
182            }
183            QuantizationStrategy::ResidualQuantization { levels, .. } => {
184                self.add_multi_level_residual(uri, residual, *levels)?;
185            }
186            QuantizationStrategy::MultiCodebook { .. } => {
187                self.add_multi_codebook(uri, residual)?;
188            }
189            QuantizationStrategy::None => {
190                self.add_full(uri, residual);
191            }
192        }
193        Ok(())
194    }
195
196    /// Add vector using multi-level residual quantization
197    fn add_multi_level_residual(
198        &mut self,
199        uri: String,
200        mut residual: Vector,
201        levels: usize,
202    ) -> Result<()> {
203        let mut level_codes = Vec::new();
204
205        for level in 0..levels.min(self.multi_level_pq.len()) {
206            // Train this level's PQ if not already trained
207            if !self.multi_level_pq[level].is_trained() {
208                let training_residuals = vec![residual.clone()];
209                self.multi_level_pq[level].train(&training_residuals)?;
210            }
211
212            // Encode residual at this level
213            let codes = self.multi_level_pq[level].encode(&residual)?;
214            level_codes.push(codes);
215
216            // Compute and subtract the quantized approximation to get next level residual
217            let approximation = self.multi_level_pq[level].decode_vector(&level_codes[level])?;
218            residual = residual.subtract(&approximation)?;
219        }
220
221        // Store the final residual if we haven't exhausted all levels
222        let final_residual = if level_codes.len() < levels {
223            Some(residual)
224        } else {
225            None
226        };
227
228        self.vectors.push((
229            uri,
230            VectorStorage::MultiLevelQuantized {
231                levels: level_codes,
232                final_residual,
233            },
234        ));
235
236        Ok(())
237    }
238
239    /// Add vector using multi-codebook quantization
240    fn add_multi_codebook(&mut self, uri: String, residual: Vector) -> Result<()> {
241        let mut codebook_codes = Vec::new();
242
243        for pq_index in self.multi_codebook_pq.iter_mut() {
244            // Train this codebook's PQ if not already trained
245            if !pq_index.is_trained() {
246                let training_residuals = vec![residual.clone()];
247                pq_index.train(&training_residuals)?;
248            }
249
250            // Encode residual with this codebook
251            let codes = pq_index.encode(&residual)?;
252            codebook_codes.push(codes);
253        }
254
255        self.vectors.push((
256            uri,
257            VectorStorage::MultiCodebook {
258                codebooks: codebook_codes,
259                weights: self.codebook_weights.clone(),
260            },
261        ));
262
263        Ok(())
264    }
265
266    fn search(&self, query: &Vector, centroid: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
267        let mut distances: Vec<(String, f32)> = Vec::new();
268        let query_residual = query.subtract(centroid)?;
269
270        for (uri, storage) in &self.vectors {
271            let distance = match storage {
272                VectorStorage::Full(vec) => query.euclidean_distance(vec).unwrap_or(f32::INFINITY),
273                VectorStorage::Quantized(codes) => {
274                    if let Some(ref pq_index) = self.pq_index {
275                        pq_index.compute_distance(&query_residual, codes)?
276                    } else {
277                        f32::INFINITY
278                    }
279                }
280                VectorStorage::MultiLevelQuantized {
281                    levels,
282                    final_residual,
283                } => self.compute_multi_level_distance(&query_residual, levels, final_residual)?,
284                VectorStorage::MultiCodebook { codebooks, weights } => {
285                    self.compute_multi_codebook_distance(&query_residual, codebooks, weights)?
286                }
287            };
288            distances.push((uri.clone(), distance));
289        }
290
291        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
292        distances.truncate(k);
293
294        // Convert distances to similarities (1 / (1 + distance))
295        Ok(distances
296            .into_iter()
297            .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
298            .collect())
299    }
300
301    /// Compute distance for multi-level residual quantization
302    fn compute_multi_level_distance(
303        &self,
304        query_residual: &Vector,
305        level_codes: &[Vec<u8>],
306        final_residual: &Option<Vector>,
307    ) -> Result<f32> {
308        let mut reconstructed_residual = Vector::new(vec![0.0; query_residual.dimensions]);
309
310        // Reconstruct vector from quantized levels
311        for (level, codes) in level_codes.iter().enumerate() {
312            if level < self.multi_level_pq.len() {
313                let level_reconstruction = self.multi_level_pq[level].decode_vector(codes)?;
314                reconstructed_residual = reconstructed_residual.add(&level_reconstruction)?;
315            }
316        }
317
318        // Add final unquantized residual if present
319        if let Some(final_res) = final_residual {
320            reconstructed_residual = reconstructed_residual.add(final_res)?;
321        }
322
323        // Compute distance between query residual and reconstructed residual
324        query_residual.euclidean_distance(&reconstructed_residual)
325    }
326
327    /// Compute distance for multi-codebook quantization
328    fn compute_multi_codebook_distance(
329        &self,
330        query_residual: &Vector,
331        codebook_codes: &[Vec<u8>],
332        weights: &[f32],
333    ) -> Result<f32> {
334        let mut weighted_distance = 0.0;
335        let mut total_weight = 0.0;
336
337        // Compute weighted combination of distances from all codebooks
338        for (i, codes) in codebook_codes.iter().enumerate() {
339            if i < self.multi_codebook_pq.len() && i < weights.len() {
340                let codebook_distance =
341                    self.multi_codebook_pq[i].compute_distance(query_residual, codes)?;
342                weighted_distance += weights[i] * codebook_distance;
343                total_weight += weights[i];
344            }
345        }
346
347        // Normalize by total weight
348        if total_weight > 0.0 {
349            Ok(weighted_distance / total_weight)
350        } else {
351            Ok(f32::INFINITY)
352        }
353    }
354
355    /// Train product quantizer on collected residuals
356    fn train_pq(&mut self, residuals: &[Vector]) -> Result<()> {
357        match &self.quantization {
358            QuantizationStrategy::ProductQuantization(_) => {
359                if let Some(ref mut pq_index) = self.pq_index {
360                    pq_index.train(residuals)?;
361                }
362            }
363            QuantizationStrategy::ResidualQuantization { levels, .. } => {
364                self.train_multi_level_pq(residuals, *levels)?;
365            }
366            QuantizationStrategy::MultiCodebook { .. } => {
367                self.train_multi_codebook_pq(residuals)?;
368            }
369            QuantizationStrategy::None => {}
370        }
371        Ok(())
372    }
373
374    /// Train multi-level residual quantization
375    fn train_multi_level_pq(&mut self, residuals: &[Vector], levels: usize) -> Result<()> {
376        let mut current_residuals = residuals.to_vec();
377
378        for level in 0..levels.min(self.multi_level_pq.len()) {
379            // Train PQ at this level
380            self.multi_level_pq[level].train(&current_residuals)?;
381
382            // Compute residuals for next level by subtracting quantized approximation
383            let mut next_residuals = Vec::new();
384            for residual in &current_residuals {
385                let codes = self.multi_level_pq[level].encode(residual)?;
386                let approximation = self.multi_level_pq[level].decode_vector(&codes)?;
387                let next_residual = residual.subtract(&approximation)?;
388                next_residuals.push(next_residual);
389            }
390            current_residuals = next_residuals;
391        }
392
393        Ok(())
394    }
395
396    /// Train multi-codebook quantization
397    fn train_multi_codebook_pq(&mut self, residuals: &[Vector]) -> Result<()> {
398        // Train each codebook independently on the same residuals
399        for pq_index in &mut self.multi_codebook_pq {
400            pq_index.train(residuals)?;
401        }
402
403        // Optionally, optimize codebook weights based on reconstruction quality
404        self.optimize_codebook_weights(residuals)?;
405
406        Ok(())
407    }
408
409    /// Optimize weights for multi-codebook quantization
410    fn optimize_codebook_weights(&mut self, residuals: &[Vector]) -> Result<()> {
411        if self.multi_codebook_pq.is_empty() || residuals.is_empty() {
412            return Ok(());
413        }
414
415        let num_codebooks = self.multi_codebook_pq.len();
416        let mut reconstruction_errors = vec![0.0; num_codebooks];
417
418        // Compute reconstruction error for each codebook
419        for (i, pq_index) in self.multi_codebook_pq.iter().enumerate() {
420            let mut total_error = 0.0;
421            for residual in residuals {
422                let codes = pq_index.encode(residual)?;
423                let reconstruction = pq_index.decode_vector(&codes)?;
424                let error = residual
425                    .euclidean_distance(&reconstruction)
426                    .unwrap_or(f32::INFINITY);
427                total_error += error;
428            }
429            reconstruction_errors[i] = total_error / residuals.len() as f32;
430        }
431
432        // Compute weights inversely proportional to reconstruction error
433        let max_error = reconstruction_errors.iter().fold(0.0f32, |a, &b| a.max(b));
434        if max_error > 0.0 {
435            let mut total_weight = 0.0;
436            for (i, &error) in reconstruction_errors.iter().enumerate().take(num_codebooks) {
437                // Higher weight for lower error
438                self.codebook_weights[i] = (max_error - error + 1e-6) / max_error;
439                total_weight += self.codebook_weights[i];
440            }
441
442            // Normalize weights
443            if total_weight > 0.0 {
444                for weight in &mut self.codebook_weights {
445                    *weight /= total_weight;
446                }
447            }
448        }
449
450        Ok(())
451    }
452
453    /// Get statistics about this inverted list
454    fn stats(&self) -> InvertedListStats {
455        let mut full_vectors = 0;
456        let mut quantized_vectors = 0;
457        let mut multi_level_vectors = 0;
458        let mut multi_codebook_vectors = 0;
459
460        for (_, storage) in &self.vectors {
461            match storage {
462                VectorStorage::Full(_) => full_vectors += 1,
463                VectorStorage::Quantized(_) => quantized_vectors += 1,
464                VectorStorage::MultiLevelQuantized { .. } => {
465                    quantized_vectors += 1;
466                    multi_level_vectors += 1;
467                }
468                VectorStorage::MultiCodebook { .. } => {
469                    quantized_vectors += 1;
470                    multi_codebook_vectors += 1;
471                }
472            }
473        }
474
475        let total_vectors = self.vectors.len();
476        let compression_ratio = if total_vectors > 0 {
477            quantized_vectors as f32 / total_vectors as f32
478        } else {
479            0.0
480        };
481
482        InvertedListStats {
483            total_vectors,
484            full_vectors,
485            quantized_vectors,
486            compression_ratio,
487            multi_level_vectors,
488            multi_codebook_vectors,
489            quantization_strategy: self.quantization.clone(),
490        }
491    }
492}
493
494/// Statistics for an inverted list
495#[derive(Debug, Clone)]
496pub struct InvertedListStats {
497    pub total_vectors: usize,
498    pub full_vectors: usize,
499    pub quantized_vectors: usize,
500    pub compression_ratio: f32,
501    pub multi_level_vectors: usize,
502    pub multi_codebook_vectors: usize,
503    pub quantization_strategy: QuantizationStrategy,
504}
505
506/// IVF index for approximate nearest neighbor search
507pub struct IvfIndex {
508    config: IvfConfig,
509    /// Cluster centroids
510    centroids: Vec<Vector>,
511    /// Inverted lists (one per cluster)
512    inverted_lists: Vec<Arc<RwLock<InvertedList>>>,
513    /// Dimensions of vectors
514    dimensions: Option<usize>,
515    /// Total number of vectors
516    n_vectors: usize,
517    /// Whether the index has been trained
518    is_trained: bool,
519}
520
521impl IvfIndex {
522    /// Create a new IVF index
523    pub fn new(config: IvfConfig) -> Result<Self> {
524        let mut inverted_lists = Vec::with_capacity(config.n_clusters);
525
526        // Determine quantization strategy (backward compatibility support)
527        let quantization = if config.enable_residual_quantization {
528            if let Some(ref pq_config) = config.pq_config {
529                QuantizationStrategy::ProductQuantization(pq_config.clone())
530            } else {
531                return Err(anyhow!(
532                    "PQ config required when residual quantization is enabled"
533                ));
534            }
535        } else {
536            config.quantization.clone()
537        };
538
539        for _ in 0..config.n_clusters {
540            let inverted_list = Arc::new(RwLock::new(InvertedList::new_with_quantization(
541                quantization.clone(),
542            )?));
543            inverted_lists.push(inverted_list);
544        }
545
546        Ok(Self {
547            config,
548            centroids: Vec::new(),
549            inverted_lists,
550            dimensions: None,
551            n_vectors: 0,
552            is_trained: false,
553        })
554    }
555
556    /// Create a new IVF index with product quantization
557    pub fn new_with_product_quantization(
558        n_clusters: usize,
559        n_probes: usize,
560        pq_config: PQConfig,
561    ) -> Result<Self> {
562        let config = IvfConfig {
563            n_clusters,
564            n_probes,
565            quantization: QuantizationStrategy::ProductQuantization(pq_config),
566            ..Default::default()
567        };
568        Self::new(config)
569    }
570
571    /// Create a new IVF index with multi-level residual quantization
572    pub fn new_with_multi_level_quantization(
573        n_clusters: usize,
574        n_probes: usize,
575        levels: usize,
576        pq_configs: Vec<PQConfig>,
577    ) -> Result<Self> {
578        if pq_configs.len() < levels {
579            return Err(anyhow!(
580                "Number of PQ configs must be at least equal to levels"
581            ));
582        }
583
584        let config = IvfConfig {
585            n_clusters,
586            n_probes,
587            quantization: QuantizationStrategy::ResidualQuantization { levels, pq_configs },
588            ..Default::default()
589        };
590        Self::new(config)
591    }
592
593    /// Create a new IVF index with multi-codebook quantization
594    pub fn new_with_multi_codebook_quantization(
595        n_clusters: usize,
596        n_probes: usize,
597        num_codebooks: usize,
598        pq_configs: Vec<PQConfig>,
599    ) -> Result<Self> {
600        if pq_configs.len() != num_codebooks {
601            return Err(anyhow!(
602                "Number of PQ configs must equal number of codebooks"
603            ));
604        }
605
606        let config = IvfConfig {
607            n_clusters,
608            n_probes,
609            quantization: QuantizationStrategy::MultiCodebook {
610                num_codebooks,
611                pq_configs,
612            },
613            ..Default::default()
614        };
615        Self::new(config)
616    }
617
618    /// Create a new IVF index with residual quantization enabled (deprecated)
619    pub fn new_with_residual_quantization(
620        n_clusters: usize,
621        n_probes: usize,
622        pq_config: PQConfig,
623    ) -> Result<Self> {
624        Self::new_with_product_quantization(n_clusters, n_probes, pq_config)
625    }
626
627    /// Get the configuration of this index
628    pub fn config(&self) -> &IvfConfig {
629        &self.config
630    }
631
632    /// Train the index with a sample of vectors
633    pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
634        if training_vectors.is_empty() {
635            return Err(anyhow!("Cannot train IVF index with empty training set"));
636        }
637
638        // Validate dimensions
639        let dims = training_vectors[0].dimensions;
640        if !training_vectors.iter().all(|v| v.dimensions == dims) {
641            return Err(anyhow!(
642                "All training vectors must have the same dimensions"
643            ));
644        }
645
646        self.dimensions = Some(dims);
647
648        // Initialize centroids using k-means++
649        self.centroids = self.initialize_centroids_kmeans_plus_plus(training_vectors)?;
650
651        // Run k-means clustering
652        let mut iteration = 0;
653        let mut prev_error = f32::INFINITY;
654
655        while iteration < self.config.max_iterations {
656            // Assign vectors to nearest centroids
657            let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); self.config.n_clusters];
658
659            for vector in training_vectors {
660                let nearest_idx = self.find_nearest_centroid(vector)?;
661                clusters[nearest_idx].push(vector);
662            }
663
664            // Update centroids
665            let mut total_error = 0.0;
666            for (i, cluster) in clusters.iter().enumerate() {
667                if !cluster.is_empty() {
668                    let new_centroid = self.compute_centroid(cluster);
669                    total_error += self.centroids[i]
670                        .euclidean_distance(&new_centroid)
671                        .unwrap_or(0.0);
672                    self.centroids[i] = new_centroid;
673                }
674            }
675
676            // Check convergence
677            if (prev_error - total_error).abs() < self.config.convergence_threshold {
678                break;
679            }
680
681            prev_error = total_error;
682            iteration += 1;
683        }
684
685        self.is_trained = true;
686
687        // Train quantization if enabled
688        if !matches!(self.config.quantization, QuantizationStrategy::None)
689            || self.config.enable_residual_quantization
690        {
691            self.train_residual_quantization(training_vectors)?;
692        }
693
694        Ok(())
695    }
696
697    /// Train residual quantization on all clusters
698    fn train_residual_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
699        // Collect residuals for each cluster
700        let mut cluster_residuals: Vec<Vec<Vector>> = vec![Vec::new(); self.config.n_clusters];
701
702        for vector in training_vectors {
703            let cluster_idx = self.find_nearest_centroid(vector)?;
704            let centroid = &self.centroids[cluster_idx];
705            let residual = vector.subtract(centroid)?;
706            cluster_residuals[cluster_idx].push(residual);
707        }
708
709        // Train PQ for each cluster that has enough residuals
710        for (cluster_idx, residuals) in cluster_residuals.iter().enumerate() {
711            if residuals.len() > 10 {
712                // Minimum threshold for training
713                let mut list = self.inverted_lists[cluster_idx]
714                    .write()
715                    .expect("inverted_lists lock should not be poisoned");
716                list.train_pq(residuals)?;
717            }
718        }
719
720        Ok(())
721    }
722
723    /// Initialize centroids using k-means++ algorithm
724    fn initialize_centroids_kmeans_plus_plus(&self, vectors: &[Vector]) -> Result<Vec<Vector>> {
725        use std::collections::hash_map::DefaultHasher;
726        use std::hash::{Hash, Hasher};
727
728        let mut hasher = DefaultHasher::new();
729        self.config.seed.unwrap_or(42).hash(&mut hasher);
730        let mut rng_state = hasher.finish();
731
732        let mut centroids = Vec::with_capacity(self.config.n_clusters);
733
734        // Choose first centroid randomly
735        let first_idx = (rng_state as usize) % vectors.len();
736        centroids.push(vectors[first_idx].clone());
737
738        // Choose remaining centroids
739        while centroids.len() < self.config.n_clusters {
740            let mut distances = Vec::with_capacity(vectors.len());
741            let mut sum_distances = 0.0;
742
743            // Calculate distance to nearest centroid for each vector
744            for vector in vectors {
745                let min_dist = centroids
746                    .iter()
747                    .map(|c| vector.euclidean_distance(c).unwrap_or(f32::INFINITY))
748                    .fold(f32::INFINITY, |a, b| a.min(b));
749
750                distances.push(min_dist * min_dist); // Square for k-means++
751                sum_distances += min_dist * min_dist;
752            }
753
754            // Choose next centroid with probability proportional to squared distance
755            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
756            let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
757
758            let mut cumulative = 0.0;
759            for (i, &dist) in distances.iter().enumerate() {
760                cumulative += dist;
761                if cumulative >= threshold {
762                    centroids.push(vectors[i].clone());
763                    break;
764                }
765            }
766        }
767
768        Ok(centroids)
769    }
770
771    /// Compute centroid of a cluster
772    fn compute_centroid(&self, cluster: &[&Vector]) -> Vector {
773        if cluster.is_empty() {
774            return Vector::new(vec![0.0; self.dimensions.unwrap_or(0)]);
775        }
776
777        let dims = cluster[0].dimensions;
778        let mut sum = vec![0.0; dims];
779
780        for vector in cluster {
781            let values = vector.as_f32();
782            for (i, &val) in values.iter().enumerate() {
783                sum[i] += val;
784            }
785        }
786
787        let count = cluster.len() as f32;
788        for val in &mut sum {
789            *val /= count;
790        }
791
792        Vector::new(sum)
793    }
794
795    /// Find the nearest centroid for a vector
796    fn find_nearest_centroid(&self, vector: &Vector) -> Result<usize> {
797        if self.centroids.is_empty() {
798            return Err(anyhow!("No centroids available"));
799        }
800
801        let mut min_distance = f32::INFINITY;
802        let mut nearest_idx = 0;
803
804        for (i, centroid) in self.centroids.iter().enumerate() {
805            let distance = vector.euclidean_distance(centroid)?;
806            if distance < min_distance {
807                min_distance = distance;
808                nearest_idx = i;
809            }
810        }
811
812        Ok(nearest_idx)
813    }
814
815    /// Find the n_probes nearest centroids for a query
816    fn find_nearest_centroids(&self, query: &Vector, n_probes: usize) -> Result<Vec<usize>> {
817        let mut distances: Vec<(usize, f32)> = self
818            .centroids
819            .iter()
820            .enumerate()
821            .map(|(i, centroid)| {
822                let dist = query.euclidean_distance(centroid).unwrap_or(f32::INFINITY);
823                (i, dist)
824            })
825            .collect();
826
827        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
828
829        Ok(distances
830            .into_iter()
831            .take(n_probes.min(self.centroids.len()))
832            .map(|(i, _)| i)
833            .collect())
834    }
835
836    /// Get comprehensive statistics about the IVF index including compression info
837    pub fn stats(&self) -> IvfStats {
838        let mut total_list_stats = InvertedListStats {
839            total_vectors: 0,
840            full_vectors: 0,
841            quantized_vectors: 0,
842            compression_ratio: 0.0,
843            multi_level_vectors: 0,
844            multi_codebook_vectors: 0,
845            quantization_strategy: QuantizationStrategy::None,
846        };
847
848        let mut cluster_stats = Vec::new();
849        let mut vectors_per_cluster = Vec::new();
850        let mut non_empty_clusters = 0;
851
852        for list in &self.inverted_lists {
853            let list_guard = list
854                .read()
855                .expect("inverted list lock should not be poisoned");
856            let stats = list_guard.stats();
857
858            total_list_stats.total_vectors += stats.total_vectors;
859            total_list_stats.full_vectors += stats.full_vectors;
860            total_list_stats.quantized_vectors += stats.quantized_vectors;
861            total_list_stats.multi_level_vectors += stats.multi_level_vectors;
862            total_list_stats.multi_codebook_vectors += stats.multi_codebook_vectors;
863
864            vectors_per_cluster.push(stats.total_vectors);
865            if stats.total_vectors > 0 {
866                non_empty_clusters += 1;
867            }
868
869            cluster_stats.push(stats);
870        }
871
872        // Calculate overall compression ratio
873        if total_list_stats.total_vectors > 0 {
874            total_list_stats.compression_ratio =
875                total_list_stats.quantized_vectors as f32 / total_list_stats.total_vectors as f32;
876        }
877
878        let avg_vectors_per_cluster = if self.config.n_clusters > 0 {
879            self.n_vectors as f32 / self.config.n_clusters as f32
880        } else {
881            0.0
882        };
883
884        IvfStats {
885            n_clusters: self.config.n_clusters,
886            n_probes: self.config.n_probes,
887            n_vectors: self.n_vectors,
888            is_trained: self.is_trained,
889            dimensions: self.dimensions,
890            vectors_per_cluster,
891            avg_vectors_per_cluster,
892            non_empty_clusters,
893            enable_residual_quantization: self.config.enable_residual_quantization,
894            quantization_strategy: self.config.quantization.clone(),
895            compression_stats: Some(total_list_stats),
896            cluster_stats,
897        }
898    }
899}
900
901impl VectorIndex for IvfIndex {
902    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
903        if !self.is_trained {
904            return Err(anyhow!(
905                "IVF index must be trained before inserting vectors"
906            ));
907        }
908
909        // Validate dimensions
910        if let Some(dims) = self.dimensions {
911            if vector.dimensions != dims {
912                return Err(anyhow!(
913                    "Vector dimensions {} don't match index dimensions {}",
914                    vector.dimensions,
915                    dims
916                ));
917            }
918        }
919
920        // Find nearest centroid
921        let cluster_idx = self.find_nearest_centroid(&vector)?;
922        let centroid = &self.centroids[cluster_idx];
923
924        let mut list = self.inverted_lists[cluster_idx]
925            .write()
926            .expect("inverted_lists lock should not be poisoned");
927
928        // Handle quantization based on strategy
929        match &self.config.quantization {
930            QuantizationStrategy::None => {
931                if self.config.enable_residual_quantization {
932                    // Backward compatibility: use residual quantization
933                    let residual = vector.subtract(centroid)?;
934                    list.add_residual(uri, residual, centroid)?;
935                } else {
936                    list.add_full(uri, vector);
937                }
938            }
939            _ => {
940                // Use new quantization strategies
941                let residual = vector.subtract(centroid)?;
942                list.add_residual(uri, residual, centroid)?;
943            }
944        }
945
946        self.n_vectors += 1;
947        Ok(())
948    }
949
950    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
951        if !self.is_trained {
952            return Err(anyhow!("IVF index must be trained before searching"));
953        }
954
955        // Find nearest centroids to probe
956        let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
957
958        // Search in selected inverted lists
959        let mut all_results = Vec::new();
960        for idx in probe_indices {
961            let list = self.inverted_lists[idx]
962                .read()
963                .expect("inverted_lists lock should not be poisoned");
964            let centroid = &self.centroids[idx];
965            let mut results = list.search(query, centroid, k)?;
966            all_results.append(&mut results);
967        }
968
969        // Sort and truncate to k results
970        all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
971        all_results.truncate(k);
972
973        Ok(all_results)
974    }
975
976    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
977        if !self.is_trained {
978            return Err(anyhow!("IVF index must be trained before searching"));
979        }
980
981        // Find nearest centroids to probe
982        let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
983
984        // Search in selected inverted lists
985        let mut all_results = Vec::new();
986        for idx in probe_indices {
987            let list = self.inverted_lists[idx]
988                .read()
989                .expect("inverted_lists lock should not be poisoned");
990            let centroid = &self.centroids[idx];
991            let results = list.search(query, centroid, self.n_vectors)?;
992
993            // Filter by threshold
994            for (uri, similarity) in results {
995                if similarity >= threshold {
996                    all_results.push((uri, similarity));
997                }
998            }
999        }
1000
1001        // Sort by similarity
1002        all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1003
1004        Ok(all_results)
1005    }
1006
1007    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
1008        // IVF doesn't maintain a direct URI to vector mapping
1009        // Would need to search through all inverted lists
1010        // For efficiency, we return None
1011        None
1012    }
1013}
1014
1015/// Statistics for IVF index
1016#[derive(Debug, Clone)]
1017pub struct IvfStats {
1018    pub n_vectors: usize,
1019    pub n_clusters: usize,
1020    pub n_probes: usize,
1021    pub is_trained: bool,
1022    pub dimensions: Option<usize>,
1023    pub vectors_per_cluster: Vec<usize>,
1024    pub avg_vectors_per_cluster: f32,
1025    pub non_empty_clusters: usize,
1026    pub enable_residual_quantization: bool,
1027    pub quantization_strategy: QuantizationStrategy,
1028    pub compression_stats: Option<InvertedListStats>,
1029    pub cluster_stats: Vec<InvertedListStats>,
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;

    /// Train on a small 2-D point set, insert it, and check the nearest hit for a nearby query.
    #[test]
    fn test_ivf_basic() {
        let config = IvfConfig {
            n_clusters: 4,
            n_probes: 2,
            ..Default::default()
        };

        let mut index = IvfIndex::new(config).unwrap();

        // Create training vectors
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, 0.0]),
            Vector::new(vec![0.0, -1.0]),
            Vector::new(vec![0.5, 0.5]),
            Vector::new(vec![-0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5]),
            Vector::new(vec![0.5, -0.5]),
        ];

        // Train the index
        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        // Insert vectors
        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone()).unwrap();
        }

        // Search for nearest neighbors
        let query = Vector::new(vec![0.9, 0.1]);
        let results = index.search_knn(&query, 3).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        // The first result should be vec0 (closest to [1.0, 0.0])
        assert_eq!(results[0].0, "vec0");
    }

    /// Threshold search must return only hits whose similarity meets the threshold.
    #[test]
    fn test_ivf_threshold_search() {
        let config = IvfConfig {
            n_clusters: 2,
            n_probes: 2,
            ..Default::default()
        };

        let mut index = IvfIndex::new(config).unwrap();

        // Create and train with vectors
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0]),
        ];

        index.train(&training_vectors).unwrap();

        // Insert vectors as v1..v4
        for (name, vector) in ["v1", "v2", "v3", "v4"].iter().zip(&training_vectors) {
            index.insert(name.to_string(), vector.clone()).unwrap();
        }

        // Search with threshold
        let query = Vector::new(vec![0.9, 0.1, 0.0]);
        let results = index.search_threshold(&query, 0.5).unwrap();

        assert!(!results.is_empty());
        // Should find vectors with similarity >= 0.5
        for (_, similarity) in &results {
            assert!(*similarity >= 0.5);
        }
    }

    /// `stats()` must reflect configuration, training state, dimensionality and inserts.
    #[test]
    fn test_ivf_stats() {
        let config = IvfConfig {
            n_clusters: 3,
            n_probes: 1,
            ..Default::default()
        };

        let mut index = IvfIndex::new(config).unwrap();

        // Train with simple vectors
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, -1.0]),
        ];

        index.train(&training_vectors).unwrap();

        // Insert some vectors
        index
            .insert("a".to_string(), Vector::new(vec![1.1, 0.1]))
            .unwrap();
        index
            .insert("b".to_string(), Vector::new(vec![0.1, 1.1]))
            .unwrap();

        let stats = index.stats();
        assert_eq!(stats.n_vectors, 2);
        assert_eq!(stats.n_clusters, 3);
        assert!(stats.is_trained);
        assert_eq!(stats.dimensions, Some(2));
        assert!(stats.compression_stats.is_some());
    }

    /// Residual (multi-level) quantization: search works and multi-level vectors are counted.
    #[test]
    fn test_ivf_multi_level_quantization() {
        use crate::pq::PQConfig;

        // Create PQ configs for different levels
        let pq_config_1 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let pq_config_2 = PQConfig {
            n_subquantizers: 2,
            n_bits: 4,
            ..Default::default()
        };

        let mut index =
            IvfIndex::new_with_multi_level_quantization(4, 2, 2, vec![pq_config_1, pq_config_2])
                .unwrap();

        // Create training vectors
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.5, 0.5]),
        ];

        // Train the index
        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        // Insert vectors
        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone()).unwrap();
        }

        // Search for nearest neighbors
        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        // Check stats: the aggregate is always reported, so assert on it unconditionally.
        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::ResidualQuantization { .. }
        ));
        let compression_stats = stats
            .compression_stats
            .as_ref()
            .expect("stats() should always report aggregate compression stats");
        assert!(compression_stats.multi_level_vectors > 0);
    }

    /// Multi-codebook quantization: search works and multi-codebook vectors are counted.
    #[test]
    fn test_ivf_multi_codebook_quantization() {
        use crate::pq::PQConfig;

        // Create PQ configs for different codebooks
        let pq_config_1 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let pq_config_2 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };

        let mut index =
            IvfIndex::new_with_multi_codebook_quantization(4, 2, 2, vec![pq_config_1, pq_config_2])
                .unwrap();

        // Create training vectors
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
        ];

        // Train the index
        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        // Insert vectors
        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone()).unwrap();
        }

        // Search for nearest neighbors
        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);

        // Check stats: the aggregate is always reported, so assert on it unconditionally.
        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::MultiCodebook { .. }
        ));
        let compression_stats = stats
            .compression_stats
            .as_ref()
            .expect("stats() should always report aggregate compression stats");
        assert!(compression_stats.multi_codebook_vectors > 0);
    }

    /// Every quantization strategy must be accepted by `IvfIndex::new`.
    #[test]
    fn test_quantization_strategies() {
        use crate::pq::PQConfig;

        let pq_config = PQConfig::default();

        // Test different quantization strategies
        let strategies = vec![
            QuantizationStrategy::None,
            QuantizationStrategy::ProductQuantization(pq_config.clone()),
            QuantizationStrategy::ResidualQuantization {
                levels: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
            QuantizationStrategy::MultiCodebook {
                num_codebooks: 2,
                pq_configs: vec![pq_config.clone(), pq_config],
            },
        ];

        for strategy in strategies {
            let config = IvfConfig {
                n_clusters: 2,
                n_probes: 1,
                quantization: strategy.clone(),
                ..Default::default()
            };

            let index = IvfIndex::new(config);
            assert!(
                index.is_ok(),
                "Failed to create index with strategy: {strategy:?}"
            );
        }
    }
}