oxirs_vec/
ivf.rs

1//! Inverted File (IVF) index implementation for approximate nearest neighbor search
2//!
3//! IVF is a clustering-based indexing method that partitions the vector space into
4//! Voronoi cells. Each cell has a centroid, and vectors are assigned to their nearest
5//! centroid. During search, only a subset of cells are examined, greatly reducing
6//! search time at the cost of some accuracy.
7
8use crate::{
9    pq::{PQConfig, PQIndex},
10    Vector, VectorIndex,
11};
12use anyhow::{anyhow, Result};
13use std::sync::{Arc, RwLock};
14
/// Quantization strategy for residuals.
///
/// Selects which quantizers an inverted list builds and which storage format
/// it uses for vectors added through its residual path.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizationStrategy {
    /// No quantization - store full residuals
    None,
    /// Single-level product quantization
    ProductQuantization(PQConfig),
    /// Residual quantization with multiple levels
    ResidualQuantization {
        /// Number of quantization levels requested
        levels: usize,
        /// PQ configurations; one PQ index is built per entry
        pq_configs: Vec<PQConfig>,
    },
    /// Multi-codebook quantization for improved accuracy
    MultiCodebook {
        /// Number of codebooks requested
        num_codebooks: usize,
        /// PQ configurations; one PQ index is built per entry
        pq_configs: Vec<PQConfig>,
    },
}
33
/// Configuration for IVF index
///
/// The two deprecated fields are still honoured by `IvfIndex::new`: when
/// `enable_residual_quantization` is `true`, `pq_config` must be `Some` and
/// takes the place of `quantization` as a single-level product quantizer.
#[derive(Debug, Clone)]
pub struct IvfConfig {
    /// Number of clusters (Voronoi cells)
    pub n_clusters: usize,
    /// Number of probes during search (cells to examine)
    pub n_probes: usize,
    /// Maximum iterations for k-means clustering
    pub max_iterations: usize,
    /// Convergence threshold for k-means
    pub convergence_threshold: f32,
    /// Random seed for reproducibility
    pub seed: Option<u64>,
    /// Quantization strategy for residuals
    pub quantization: QuantizationStrategy,
    /// Enable residual quantization for compression (deprecated - use quantization field)
    pub enable_residual_quantization: bool,
    /// Product quantization configuration for residuals (deprecated - use quantization field)
    pub pq_config: Option<PQConfig>,
}
54
55impl Default for IvfConfig {
56    fn default() -> Self {
57        Self {
58            n_clusters: 256,
59            n_probes: 8,
60            max_iterations: 100,
61            convergence_threshold: 1e-4,
62            seed: None,
63            quantization: QuantizationStrategy::None,
64            enable_residual_quantization: false,
65            pq_config: None,
66        }
67    }
68}
69
/// Storage format for vectors in inverted lists
#[derive(Debug, Clone)]
enum VectorStorage {
    /// Store full vectors
    Full(Vector),
    /// Store quantized residuals with PQ codes
    Quantized(Vec<u8>),
    /// Store multi-level residual quantization codes
    MultiLevelQuantized {
        /// PQ codes for each quantization level
        levels: Vec<Vec<u8>>,
        /// Optional final unquantized residual
        final_residual: Option<Vector>,
    },
    /// Store multi-codebook quantization codes
    MultiCodebook {
        /// PQ codes from different codebooks
        codebooks: Vec<Vec<u8>>,
        /// Weights for combining codebook predictions
        weights: Vec<f32>,
    },
}
88
/// Inverted list storing vectors for a single cluster
///
/// Entries are kept as `(uri, storage)` pairs. Which of the quantizer fields
/// is populated depends on the list's `quantization` strategy.
#[derive(Debug, Clone)]
struct InvertedList {
    /// Vectors in this cluster with their storage format
    vectors: Vec<(String, VectorStorage)>,
    /// Quantization strategy used for this list
    quantization: QuantizationStrategy,
    /// Product quantizer for single-level quantization
    pq_index: Option<PQIndex>,
    /// Multiple PQ indexes for multi-level residual quantization
    multi_level_pq: Vec<PQIndex>,
    /// Multiple PQ indexes for multi-codebook quantization
    multi_codebook_pq: Vec<PQIndex>,
    /// Codebook weights for multi-codebook quantization
    codebook_weights: Vec<f32>,
}
105
106impl InvertedList {
107    fn new() -> Self {
108        Self {
109            vectors: Vec::new(),
110            quantization: QuantizationStrategy::None,
111            pq_index: None,
112            multi_level_pq: Vec::new(),
113            multi_codebook_pq: Vec::new(),
114            codebook_weights: Vec::new(),
115        }
116    }
117
118    fn new_with_quantization(quantization: QuantizationStrategy) -> Result<Self> {
119        let mut list = Self {
120            vectors: Vec::new(),
121            quantization: quantization.clone(),
122            pq_index: None,
123            multi_level_pq: Vec::new(),
124            multi_codebook_pq: Vec::new(),
125            codebook_weights: Vec::new(),
126        };
127
128        match quantization {
129            QuantizationStrategy::None => {}
130            QuantizationStrategy::ProductQuantization(pq_config) => {
131                list.pq_index = Some(PQIndex::new(pq_config));
132            }
133            QuantizationStrategy::ResidualQuantization {
134                levels: _,
135                ref pq_configs,
136            } => {
137                for pq_config in pq_configs {
138                    list.multi_level_pq.push(PQIndex::new(pq_config.clone()));
139                }
140            }
141            QuantizationStrategy::MultiCodebook {
142                num_codebooks,
143                ref pq_configs,
144            } => {
145                for pq_config in pq_configs {
146                    list.multi_codebook_pq.push(PQIndex::new(pq_config.clone()));
147                }
148                // Initialize equal weights for all codebooks
149                list.codebook_weights = vec![1.0 / num_codebooks as f32; num_codebooks];
150            }
151        }
152
153        Ok(list)
154    }
155
    /// Deprecated - use `new_with_quantization` instead.
    ///
    /// Shorthand for a list configured with single-level product quantization
    /// built from `pq_config`.
    fn new_with_pq(pq_config: PQConfig) -> Result<Self> {
        Self::new_with_quantization(QuantizationStrategy::ProductQuantization(pq_config))
    }
160
    /// Appends `vector` to the list under `uri`, stored unquantized.
    fn add_full(&mut self, uri: String, vector: Vector) {
        self.vectors.push((uri, VectorStorage::Full(vector)));
    }
164
165    fn add_residual(&mut self, uri: String, residual: Vector, _centroid: &Vector) -> Result<()> {
166        match &self.quantization {
167            QuantizationStrategy::ProductQuantization(_) => {
168                if let Some(ref mut pq_index) = self.pq_index {
169                    // Train PQ on residuals if not already trained
170                    if !pq_index.is_trained() {
171                        let training_residuals = vec![residual.clone()];
172                        pq_index.train(&training_residuals)?;
173                    }
174
175                    let codes = pq_index.encode(&residual)?;
176                    self.vectors.push((uri, VectorStorage::Quantized(codes)));
177                } else {
178                    return Err(anyhow!(
179                        "PQ index not initialized for residual quantization"
180                    ));
181                }
182            }
183            QuantizationStrategy::ResidualQuantization { levels, .. } => {
184                self.add_multi_level_residual(uri, residual, *levels)?;
185            }
186            QuantizationStrategy::MultiCodebook { .. } => {
187                self.add_multi_codebook(uri, residual)?;
188            }
189            QuantizationStrategy::None => {
190                self.add_full(uri, residual);
191            }
192        }
193        Ok(())
194    }
195
196    /// Add vector using multi-level residual quantization
197    fn add_multi_level_residual(
198        &mut self,
199        uri: String,
200        mut residual: Vector,
201        levels: usize,
202    ) -> Result<()> {
203        let mut level_codes = Vec::new();
204
205        for level in 0..levels.min(self.multi_level_pq.len()) {
206            // Train this level's PQ if not already trained
207            if !self.multi_level_pq[level].is_trained() {
208                let training_residuals = vec![residual.clone()];
209                self.multi_level_pq[level].train(&training_residuals)?;
210            }
211
212            // Encode residual at this level
213            let codes = self.multi_level_pq[level].encode(&residual)?;
214            level_codes.push(codes);
215
216            // Compute and subtract the quantized approximation to get next level residual
217            let approximation = self.multi_level_pq[level].decode_vector(&level_codes[level])?;
218            residual = residual.subtract(&approximation)?;
219        }
220
221        // Store the final residual if we haven't exhausted all levels
222        let final_residual = if level_codes.len() < levels {
223            Some(residual)
224        } else {
225            None
226        };
227
228        self.vectors.push((
229            uri,
230            VectorStorage::MultiLevelQuantized {
231                levels: level_codes,
232                final_residual,
233            },
234        ));
235
236        Ok(())
237    }
238
239    /// Add vector using multi-codebook quantization
240    fn add_multi_codebook(&mut self, uri: String, residual: Vector) -> Result<()> {
241        let mut codebook_codes = Vec::new();
242
243        for pq_index in self.multi_codebook_pq.iter_mut() {
244            // Train this codebook's PQ if not already trained
245            if !pq_index.is_trained() {
246                let training_residuals = vec![residual.clone()];
247                pq_index.train(&training_residuals)?;
248            }
249
250            // Encode residual with this codebook
251            let codes = pq_index.encode(&residual)?;
252            codebook_codes.push(codes);
253        }
254
255        self.vectors.push((
256            uri,
257            VectorStorage::MultiCodebook {
258                codebooks: codebook_codes,
259                weights: self.codebook_weights.clone(),
260            },
261        ));
262
263        Ok(())
264    }
265
266    fn search(&self, query: &Vector, centroid: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
267        let mut distances: Vec<(String, f32)> = Vec::new();
268        let query_residual = query.subtract(centroid)?;
269
270        for (uri, storage) in &self.vectors {
271            let distance = match storage {
272                VectorStorage::Full(vec) => query.euclidean_distance(vec).unwrap_or(f32::INFINITY),
273                VectorStorage::Quantized(codes) => {
274                    if let Some(ref pq_index) = self.pq_index {
275                        pq_index.compute_distance(&query_residual, codes)?
276                    } else {
277                        f32::INFINITY
278                    }
279                }
280                VectorStorage::MultiLevelQuantized {
281                    levels,
282                    final_residual,
283                } => self.compute_multi_level_distance(&query_residual, levels, final_residual)?,
284                VectorStorage::MultiCodebook { codebooks, weights } => {
285                    self.compute_multi_codebook_distance(&query_residual, codebooks, weights)?
286                }
287            };
288            distances.push((uri.clone(), distance));
289        }
290
291        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
292        distances.truncate(k);
293
294        // Convert distances to similarities (1 / (1 + distance))
295        Ok(distances
296            .into_iter()
297            .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
298            .collect())
299    }
300
301    /// Compute distance for multi-level residual quantization
302    fn compute_multi_level_distance(
303        &self,
304        query_residual: &Vector,
305        level_codes: &[Vec<u8>],
306        final_residual: &Option<Vector>,
307    ) -> Result<f32> {
308        let mut reconstructed_residual = Vector::new(vec![0.0; query_residual.dimensions]);
309
310        // Reconstruct vector from quantized levels
311        for (level, codes) in level_codes.iter().enumerate() {
312            if level < self.multi_level_pq.len() {
313                let level_reconstruction = self.multi_level_pq[level].decode_vector(codes)?;
314                reconstructed_residual = reconstructed_residual.add(&level_reconstruction)?;
315            }
316        }
317
318        // Add final unquantized residual if present
319        if let Some(final_res) = final_residual {
320            reconstructed_residual = reconstructed_residual.add(final_res)?;
321        }
322
323        // Compute distance between query residual and reconstructed residual
324        query_residual.euclidean_distance(&reconstructed_residual)
325    }
326
327    /// Compute distance for multi-codebook quantization
328    fn compute_multi_codebook_distance(
329        &self,
330        query_residual: &Vector,
331        codebook_codes: &[Vec<u8>],
332        weights: &[f32],
333    ) -> Result<f32> {
334        let mut weighted_distance = 0.0;
335        let mut total_weight = 0.0;
336
337        // Compute weighted combination of distances from all codebooks
338        for (i, codes) in codebook_codes.iter().enumerate() {
339            if i < self.multi_codebook_pq.len() && i < weights.len() {
340                let codebook_distance =
341                    self.multi_codebook_pq[i].compute_distance(query_residual, codes)?;
342                weighted_distance += weights[i] * codebook_distance;
343                total_weight += weights[i];
344            }
345        }
346
347        // Normalize by total weight
348        if total_weight > 0.0 {
349            Ok(weighted_distance / total_weight)
350        } else {
351            Ok(f32::INFINITY)
352        }
353    }
354
355    /// Train product quantizer on collected residuals
356    fn train_pq(&mut self, residuals: &[Vector]) -> Result<()> {
357        match &self.quantization {
358            QuantizationStrategy::ProductQuantization(_) => {
359                if let Some(ref mut pq_index) = self.pq_index {
360                    pq_index.train(residuals)?;
361                }
362            }
363            QuantizationStrategy::ResidualQuantization { levels, .. } => {
364                self.train_multi_level_pq(residuals, *levels)?;
365            }
366            QuantizationStrategy::MultiCodebook { .. } => {
367                self.train_multi_codebook_pq(residuals)?;
368            }
369            QuantizationStrategy::None => {}
370        }
371        Ok(())
372    }
373
374    /// Train multi-level residual quantization
375    fn train_multi_level_pq(&mut self, residuals: &[Vector], levels: usize) -> Result<()> {
376        let mut current_residuals = residuals.to_vec();
377
378        for level in 0..levels.min(self.multi_level_pq.len()) {
379            // Train PQ at this level
380            self.multi_level_pq[level].train(&current_residuals)?;
381
382            // Compute residuals for next level by subtracting quantized approximation
383            let mut next_residuals = Vec::new();
384            for residual in &current_residuals {
385                let codes = self.multi_level_pq[level].encode(residual)?;
386                let approximation = self.multi_level_pq[level].decode_vector(&codes)?;
387                let next_residual = residual.subtract(&approximation)?;
388                next_residuals.push(next_residual);
389            }
390            current_residuals = next_residuals;
391        }
392
393        Ok(())
394    }
395
396    /// Train multi-codebook quantization
397    fn train_multi_codebook_pq(&mut self, residuals: &[Vector]) -> Result<()> {
398        // Train each codebook independently on the same residuals
399        for pq_index in &mut self.multi_codebook_pq {
400            pq_index.train(residuals)?;
401        }
402
403        // Optionally, optimize codebook weights based on reconstruction quality
404        self.optimize_codebook_weights(residuals)?;
405
406        Ok(())
407    }
408
409    /// Optimize weights for multi-codebook quantization
410    fn optimize_codebook_weights(&mut self, residuals: &[Vector]) -> Result<()> {
411        if self.multi_codebook_pq.is_empty() || residuals.is_empty() {
412            return Ok(());
413        }
414
415        let num_codebooks = self.multi_codebook_pq.len();
416        let mut reconstruction_errors = vec![0.0; num_codebooks];
417
418        // Compute reconstruction error for each codebook
419        for (i, pq_index) in self.multi_codebook_pq.iter().enumerate() {
420            let mut total_error = 0.0;
421            for residual in residuals {
422                let codes = pq_index.encode(residual)?;
423                let reconstruction = pq_index.decode_vector(&codes)?;
424                let error = residual
425                    .euclidean_distance(&reconstruction)
426                    .unwrap_or(f32::INFINITY);
427                total_error += error;
428            }
429            reconstruction_errors[i] = total_error / residuals.len() as f32;
430        }
431
432        // Compute weights inversely proportional to reconstruction error
433        let max_error = reconstruction_errors.iter().fold(0.0f32, |a, &b| a.max(b));
434        if max_error > 0.0 {
435            let mut total_weight = 0.0;
436            for (i, &error) in reconstruction_errors.iter().enumerate().take(num_codebooks) {
437                // Higher weight for lower error
438                self.codebook_weights[i] = (max_error - error + 1e-6) / max_error;
439                total_weight += self.codebook_weights[i];
440            }
441
442            // Normalize weights
443            if total_weight > 0.0 {
444                for weight in &mut self.codebook_weights {
445                    *weight /= total_weight;
446                }
447            }
448        }
449
450        Ok(())
451    }
452
453    /// Get statistics about this inverted list
454    fn stats(&self) -> InvertedListStats {
455        let mut full_vectors = 0;
456        let mut quantized_vectors = 0;
457        let mut multi_level_vectors = 0;
458        let mut multi_codebook_vectors = 0;
459
460        for (_, storage) in &self.vectors {
461            match storage {
462                VectorStorage::Full(_) => full_vectors += 1,
463                VectorStorage::Quantized(_) => quantized_vectors += 1,
464                VectorStorage::MultiLevelQuantized { .. } => {
465                    quantized_vectors += 1;
466                    multi_level_vectors += 1;
467                }
468                VectorStorage::MultiCodebook { .. } => {
469                    quantized_vectors += 1;
470                    multi_codebook_vectors += 1;
471                }
472            }
473        }
474
475        let total_vectors = self.vectors.len();
476        let compression_ratio = if total_vectors > 0 {
477            quantized_vectors as f32 / total_vectors as f32
478        } else {
479            0.0
480        };
481
482        InvertedListStats {
483            total_vectors,
484            full_vectors,
485            quantized_vectors,
486            compression_ratio,
487            multi_level_vectors,
488            multi_codebook_vectors,
489            quantization_strategy: self.quantization.clone(),
490        }
491    }
492}
493
/// Statistics for an inverted list
#[derive(Debug, Clone)]
pub struct InvertedListStats {
    /// Number of entries in the list
    pub total_vectors: usize,
    /// Entries stored unquantized
    pub full_vectors: usize,
    /// Entries stored in any quantized format (single-level, multi-level or multi-codebook)
    pub quantized_vectors: usize,
    /// Fraction of entries that are quantized (`quantized_vectors / total_vectors`)
    pub compression_ratio: f32,
    /// Entries stored with multi-level residual codes
    pub multi_level_vectors: usize,
    /// Entries stored with multi-codebook codes
    pub multi_codebook_vectors: usize,
    /// Quantization strategy the statistics were collected under
    pub quantization_strategy: QuantizationStrategy,
}
505
/// IVF index for approximate nearest neighbor search
///
/// Holds the cluster centroids together with one lock-protected inverted list
/// per configured cluster.
pub struct IvfIndex {
    /// Configuration the index was created with
    config: IvfConfig,
    /// Cluster centroids
    centroids: Vec<Vector>,
    /// Inverted lists (one per cluster)
    inverted_lists: Vec<Arc<RwLock<InvertedList>>>,
    /// Dimensions of vectors
    dimensions: Option<usize>,
    /// Total number of vectors
    n_vectors: usize,
    /// Whether the index has been trained
    is_trained: bool,
}
520
521impl IvfIndex {
522    /// Create a new IVF index
523    pub fn new(config: IvfConfig) -> Result<Self> {
524        let mut inverted_lists = Vec::with_capacity(config.n_clusters);
525
526        // Determine quantization strategy (backward compatibility support)
527        let quantization = if config.enable_residual_quantization {
528            if let Some(ref pq_config) = config.pq_config {
529                QuantizationStrategy::ProductQuantization(pq_config.clone())
530            } else {
531                return Err(anyhow!(
532                    "PQ config required when residual quantization is enabled"
533                ));
534            }
535        } else {
536            config.quantization.clone()
537        };
538
539        for _ in 0..config.n_clusters {
540            let inverted_list = Arc::new(RwLock::new(InvertedList::new_with_quantization(
541                quantization.clone(),
542            )?));
543            inverted_lists.push(inverted_list);
544        }
545
546        Ok(Self {
547            config,
548            centroids: Vec::new(),
549            inverted_lists,
550            dimensions: None,
551            n_vectors: 0,
552            is_trained: false,
553        })
554    }
555
556    /// Create a new IVF index with product quantization
557    pub fn new_with_product_quantization(
558        n_clusters: usize,
559        n_probes: usize,
560        pq_config: PQConfig,
561    ) -> Result<Self> {
562        let config = IvfConfig {
563            n_clusters,
564            n_probes,
565            quantization: QuantizationStrategy::ProductQuantization(pq_config),
566            ..Default::default()
567        };
568        Self::new(config)
569    }
570
571    /// Create a new IVF index with multi-level residual quantization
572    pub fn new_with_multi_level_quantization(
573        n_clusters: usize,
574        n_probes: usize,
575        levels: usize,
576        pq_configs: Vec<PQConfig>,
577    ) -> Result<Self> {
578        if pq_configs.len() < levels {
579            return Err(anyhow!(
580                "Number of PQ configs must be at least equal to levels"
581            ));
582        }
583
584        let config = IvfConfig {
585            n_clusters,
586            n_probes,
587            quantization: QuantizationStrategy::ResidualQuantization { levels, pq_configs },
588            ..Default::default()
589        };
590        Self::new(config)
591    }
592
593    /// Create a new IVF index with multi-codebook quantization
594    pub fn new_with_multi_codebook_quantization(
595        n_clusters: usize,
596        n_probes: usize,
597        num_codebooks: usize,
598        pq_configs: Vec<PQConfig>,
599    ) -> Result<Self> {
600        if pq_configs.len() != num_codebooks {
601            return Err(anyhow!(
602                "Number of PQ configs must equal number of codebooks"
603            ));
604        }
605
606        let config = IvfConfig {
607            n_clusters,
608            n_probes,
609            quantization: QuantizationStrategy::MultiCodebook {
610                num_codebooks,
611                pq_configs,
612            },
613            ..Default::default()
614        };
615        Self::new(config)
616    }
617
    /// Create a new IVF index with residual quantization enabled (deprecated)
    ///
    /// Forwards unchanged to `new_with_product_quantization`.
    pub fn new_with_residual_quantization(
        n_clusters: usize,
        n_probes: usize,
        pq_config: PQConfig,
    ) -> Result<Self> {
        Self::new_with_product_quantization(n_clusters, n_probes, pq_config)
    }
626
    /// Get the configuration of this index
    ///
    /// Returns a shared reference to the configuration passed to `new`.
    pub fn config(&self) -> &IvfConfig {
        &self.config
    }
631
632    /// Train the index with a sample of vectors
633    pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
634        if training_vectors.is_empty() {
635            return Err(anyhow!("Cannot train IVF index with empty training set"));
636        }
637
638        // Validate dimensions
639        let dims = training_vectors[0].dimensions;
640        if !training_vectors.iter().all(|v| v.dimensions == dims) {
641            return Err(anyhow!(
642                "All training vectors must have the same dimensions"
643            ));
644        }
645
646        self.dimensions = Some(dims);
647
648        // Initialize centroids using k-means++
649        self.centroids = self.initialize_centroids_kmeans_plus_plus(training_vectors)?;
650
651        // Run k-means clustering
652        let mut iteration = 0;
653        let mut prev_error = f32::INFINITY;
654
655        while iteration < self.config.max_iterations {
656            // Assign vectors to nearest centroids
657            let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); self.config.n_clusters];
658
659            for vector in training_vectors {
660                let nearest_idx = self.find_nearest_centroid(vector)?;
661                clusters[nearest_idx].push(vector);
662            }
663
664            // Update centroids
665            let mut total_error = 0.0;
666            for (i, cluster) in clusters.iter().enumerate() {
667                if !cluster.is_empty() {
668                    let new_centroid = self.compute_centroid(cluster);
669                    total_error += self.centroids[i]
670                        .euclidean_distance(&new_centroid)
671                        .unwrap_or(0.0);
672                    self.centroids[i] = new_centroid;
673                }
674            }
675
676            // Check convergence
677            if (prev_error - total_error).abs() < self.config.convergence_threshold {
678                break;
679            }
680
681            prev_error = total_error;
682            iteration += 1;
683        }
684
685        self.is_trained = true;
686
687        // Train quantization if enabled
688        if !matches!(self.config.quantization, QuantizationStrategy::None)
689            || self.config.enable_residual_quantization
690        {
691            self.train_residual_quantization(training_vectors)?;
692        }
693
694        Ok(())
695    }
696
697    /// Train residual quantization on all clusters
698    fn train_residual_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
699        // Collect residuals for each cluster
700        let mut cluster_residuals: Vec<Vec<Vector>> = vec![Vec::new(); self.config.n_clusters];
701
702        for vector in training_vectors {
703            let cluster_idx = self.find_nearest_centroid(vector)?;
704            let centroid = &self.centroids[cluster_idx];
705            let residual = vector.subtract(centroid)?;
706            cluster_residuals[cluster_idx].push(residual);
707        }
708
709        // Train PQ for each cluster that has enough residuals
710        for (cluster_idx, residuals) in cluster_residuals.iter().enumerate() {
711            if residuals.len() > 10 {
712                // Minimum threshold for training
713                let mut list = self.inverted_lists[cluster_idx].write().unwrap();
714                list.train_pq(residuals)?;
715            }
716        }
717
718        Ok(())
719    }
720
721    /// Initialize centroids using k-means++ algorithm
722    fn initialize_centroids_kmeans_plus_plus(&self, vectors: &[Vector]) -> Result<Vec<Vector>> {
723        use std::collections::hash_map::DefaultHasher;
724        use std::hash::{Hash, Hasher};
725
726        let mut hasher = DefaultHasher::new();
727        self.config.seed.unwrap_or(42).hash(&mut hasher);
728        let mut rng_state = hasher.finish();
729
730        let mut centroids = Vec::with_capacity(self.config.n_clusters);
731
732        // Choose first centroid randomly
733        let first_idx = (rng_state as usize) % vectors.len();
734        centroids.push(vectors[first_idx].clone());
735
736        // Choose remaining centroids
737        while centroids.len() < self.config.n_clusters {
738            let mut distances = Vec::with_capacity(vectors.len());
739            let mut sum_distances = 0.0;
740
741            // Calculate distance to nearest centroid for each vector
742            for vector in vectors {
743                let min_dist = centroids
744                    .iter()
745                    .map(|c| vector.euclidean_distance(c).unwrap_or(f32::INFINITY))
746                    .fold(f32::INFINITY, |a, b| a.min(b));
747
748                distances.push(min_dist * min_dist); // Square for k-means++
749                sum_distances += min_dist * min_dist;
750            }
751
752            // Choose next centroid with probability proportional to squared distance
753            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
754            let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
755
756            let mut cumulative = 0.0;
757            for (i, &dist) in distances.iter().enumerate() {
758                cumulative += dist;
759                if cumulative >= threshold {
760                    centroids.push(vectors[i].clone());
761                    break;
762                }
763            }
764        }
765
766        Ok(centroids)
767    }
768
769    /// Compute centroid of a cluster
770    fn compute_centroid(&self, cluster: &[&Vector]) -> Vector {
771        if cluster.is_empty() {
772            return Vector::new(vec![0.0; self.dimensions.unwrap_or(0)]);
773        }
774
775        let dims = cluster[0].dimensions;
776        let mut sum = vec![0.0; dims];
777
778        for vector in cluster {
779            let values = vector.as_f32();
780            for (i, &val) in values.iter().enumerate() {
781                sum[i] += val;
782            }
783        }
784
785        let count = cluster.len() as f32;
786        for val in &mut sum {
787            *val /= count;
788        }
789
790        Vector::new(sum)
791    }
792
793    /// Find the nearest centroid for a vector
794    fn find_nearest_centroid(&self, vector: &Vector) -> Result<usize> {
795        if self.centroids.is_empty() {
796            return Err(anyhow!("No centroids available"));
797        }
798
799        let mut min_distance = f32::INFINITY;
800        let mut nearest_idx = 0;
801
802        for (i, centroid) in self.centroids.iter().enumerate() {
803            let distance = vector.euclidean_distance(centroid)?;
804            if distance < min_distance {
805                min_distance = distance;
806                nearest_idx = i;
807            }
808        }
809
810        Ok(nearest_idx)
811    }
812
813    /// Find the n_probes nearest centroids for a query
814    fn find_nearest_centroids(&self, query: &Vector, n_probes: usize) -> Result<Vec<usize>> {
815        let mut distances: Vec<(usize, f32)> = self
816            .centroids
817            .iter()
818            .enumerate()
819            .map(|(i, centroid)| {
820                let dist = query.euclidean_distance(centroid).unwrap_or(f32::INFINITY);
821                (i, dist)
822            })
823            .collect();
824
825        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
826
827        Ok(distances
828            .into_iter()
829            .take(n_probes.min(self.centroids.len()))
830            .map(|(i, _)| i)
831            .collect())
832    }
833
834    /// Get comprehensive statistics about the IVF index including compression info
835    pub fn stats(&self) -> IvfStats {
836        let mut total_list_stats = InvertedListStats {
837            total_vectors: 0,
838            full_vectors: 0,
839            quantized_vectors: 0,
840            compression_ratio: 0.0,
841            multi_level_vectors: 0,
842            multi_codebook_vectors: 0,
843            quantization_strategy: QuantizationStrategy::None,
844        };
845
846        let mut cluster_stats = Vec::new();
847        let mut vectors_per_cluster = Vec::new();
848        let mut non_empty_clusters = 0;
849
850        for list in &self.inverted_lists {
851            let list_guard = list
852                .read()
853                .expect("inverted list lock should not be poisoned");
854            let stats = list_guard.stats();
855
856            total_list_stats.total_vectors += stats.total_vectors;
857            total_list_stats.full_vectors += stats.full_vectors;
858            total_list_stats.quantized_vectors += stats.quantized_vectors;
859            total_list_stats.multi_level_vectors += stats.multi_level_vectors;
860            total_list_stats.multi_codebook_vectors += stats.multi_codebook_vectors;
861
862            vectors_per_cluster.push(stats.total_vectors);
863            if stats.total_vectors > 0 {
864                non_empty_clusters += 1;
865            }
866
867            cluster_stats.push(stats);
868        }
869
870        // Calculate overall compression ratio
871        if total_list_stats.total_vectors > 0 {
872            total_list_stats.compression_ratio =
873                total_list_stats.quantized_vectors as f32 / total_list_stats.total_vectors as f32;
874        }
875
876        let avg_vectors_per_cluster = if self.config.n_clusters > 0 {
877            self.n_vectors as f32 / self.config.n_clusters as f32
878        } else {
879            0.0
880        };
881
882        IvfStats {
883            n_clusters: self.config.n_clusters,
884            n_probes: self.config.n_probes,
885            n_vectors: self.n_vectors,
886            is_trained: self.is_trained,
887            dimensions: self.dimensions,
888            vectors_per_cluster,
889            avg_vectors_per_cluster,
890            non_empty_clusters,
891            enable_residual_quantization: self.config.enable_residual_quantization,
892            quantization_strategy: self.config.quantization.clone(),
893            compression_stats: Some(total_list_stats),
894            cluster_stats,
895        }
896    }
897}
898
899impl VectorIndex for IvfIndex {
900    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
901        if !self.is_trained {
902            return Err(anyhow!(
903                "IVF index must be trained before inserting vectors"
904            ));
905        }
906
907        // Validate dimensions
908        if let Some(dims) = self.dimensions {
909            if vector.dimensions != dims {
910                return Err(anyhow!(
911                    "Vector dimensions {} don't match index dimensions {}",
912                    vector.dimensions,
913                    dims
914                ));
915            }
916        }
917
918        // Find nearest centroid
919        let cluster_idx = self.find_nearest_centroid(&vector)?;
920        let centroid = &self.centroids[cluster_idx];
921
922        let mut list = self.inverted_lists[cluster_idx].write().unwrap();
923
924        // Handle quantization based on strategy
925        match &self.config.quantization {
926            QuantizationStrategy::None => {
927                if self.config.enable_residual_quantization {
928                    // Backward compatibility: use residual quantization
929                    let residual = vector.subtract(centroid)?;
930                    list.add_residual(uri, residual, centroid)?;
931                } else {
932                    list.add_full(uri, vector);
933                }
934            }
935            _ => {
936                // Use new quantization strategies
937                let residual = vector.subtract(centroid)?;
938                list.add_residual(uri, residual, centroid)?;
939            }
940        }
941
942        self.n_vectors += 1;
943        Ok(())
944    }
945
946    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
947        if !self.is_trained {
948            return Err(anyhow!("IVF index must be trained before searching"));
949        }
950
951        // Find nearest centroids to probe
952        let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
953
954        // Search in selected inverted lists
955        let mut all_results = Vec::new();
956        for idx in probe_indices {
957            let list = self.inverted_lists[idx].read().unwrap();
958            let centroid = &self.centroids[idx];
959            let mut results = list.search(query, centroid, k)?;
960            all_results.append(&mut results);
961        }
962
963        // Sort and truncate to k results
964        all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
965        all_results.truncate(k);
966
967        Ok(all_results)
968    }
969
970    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
971        if !self.is_trained {
972            return Err(anyhow!("IVF index must be trained before searching"));
973        }
974
975        // Find nearest centroids to probe
976        let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
977
978        // Search in selected inverted lists
979        let mut all_results = Vec::new();
980        for idx in probe_indices {
981            let list = self.inverted_lists[idx].read().unwrap();
982            let centroid = &self.centroids[idx];
983            let results = list.search(query, centroid, self.n_vectors)?;
984
985            // Filter by threshold
986            for (uri, similarity) in results {
987                if similarity >= threshold {
988                    all_results.push((uri, similarity));
989                }
990            }
991        }
992
993        // Sort by similarity
994        all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
995
996        Ok(all_results)
997    }
998
999    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
1000        // IVF doesn't maintain a direct URI to vector mapping
1001        // Would need to search through all inverted lists
1002        // For efficiency, we return None
1003        None
1004    }
1005}
1006
/// Statistics for IVF index
///
/// A snapshot produced by `IvfIndex::stats()`. Per-cluster vectors
/// (`vectors_per_cluster`, `cluster_stats`) are in inverted-list order.
#[derive(Debug, Clone)]
pub struct IvfStats {
    /// Number of vectors inserted into the index (the index's insert counter).
    pub n_vectors: usize,
    /// Configured number of clusters.
    pub n_clusters: usize,
    /// Configured number of clusters probed per search.
    pub n_probes: usize,
    /// Whether the index has been trained.
    pub is_trained: bool,
    /// Dimensionality of the indexed vectors, if the index has recorded one.
    pub dimensions: Option<usize>,
    /// Total vector count reported by each inverted list.
    pub vectors_per_cluster: Vec<usize>,
    /// `n_vectors` divided by the configured `n_clusters` (0.0 when `n_clusters` is 0).
    pub avg_vectors_per_cluster: f32,
    /// Number of inverted lists that report at least one vector.
    pub non_empty_clusters: usize,
    /// Value of the deprecated `enable_residual_quantization` configuration flag.
    pub enable_residual_quantization: bool,
    /// Quantization strategy taken from the index configuration.
    pub quantization_strategy: QuantizationStrategy,
    /// Counters summed over all inverted lists; `stats()` always fills this in.
    pub compression_stats: Option<InvertedListStats>,
    /// Statistics reported by each inverted list.
    pub cluster_stats: Vec<InvertedListStats>,
}
1023
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pq::PQConfig;

    /// Trains on eight 2-D points, inserts them all, and expects a query next
    /// to [1.0, 0.0] to rank "vec0" first.
    #[test]
    fn test_ivf_basic() {
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 4,
            n_probes: 2,
            ..Default::default()
        })
        .unwrap();

        // Four axis-aligned unit points plus four diagonal points.
        let points = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, 0.0]),
            Vector::new(vec![0.0, -1.0]),
            Vector::new(vec![0.5, 0.5]),
            Vector::new(vec![-0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5]),
            Vector::new(vec![0.5, -0.5]),
        ];

        index.train(&points).unwrap();
        assert!(index.is_trained);

        // Point number i is stored under the URI "vec<i>".
        for (i, point) in points.iter().enumerate() {
            index.insert(format!("vec{i}"), point.clone()).unwrap();
        }

        let hits = index
            .search_knn(&Vector::new(vec![0.9, 0.1]), 3)
            .unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 3);

        // [1.0, 0.0] is the closest stored point to the query.
        assert_eq!(hits[0].0, "vec0");
    }

    /// Threshold search returns at least one hit and nothing below the threshold.
    #[test]
    fn test_ivf_threshold_search() {
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 2,
            n_probes: 2,
            ..Default::default()
        })
        .unwrap();

        let points = vec![
            Vector::new(vec![1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0]),
        ];

        index.train(&points).unwrap();

        // Stored as "v1" .. "v4", in training order.
        for (i, point) in points.iter().enumerate() {
            index
                .insert(format!("v{}", i + 1), point.clone())
                .unwrap();
        }

        let hits = index
            .search_threshold(&Vector::new(vec![0.9, 0.1, 0.0]), 0.5)
            .unwrap();

        assert!(!hits.is_empty());
        // Every returned similarity must satisfy the threshold.
        assert!(hits.iter().all(|(_, similarity)| *similarity >= 0.5));
    }

    /// `stats()` reflects the configuration and the number of inserted vectors.
    #[test]
    fn test_ivf_stats() {
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 3,
            n_probes: 1,
            ..Default::default()
        })
        .unwrap();

        let points = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, -1.0]),
        ];
        index.train(&points).unwrap();

        // Two inserts, each close to one of the training points.
        let inserted = vec![
            ("a", Vector::new(vec![1.1, 0.1])),
            ("b", Vector::new(vec![0.1, 1.1])),
        ];
        for (uri, vector) in inserted {
            index.insert(uri.to_string(), vector).unwrap();
        }

        let snapshot = index.stats();
        assert_eq!(snapshot.n_vectors, 2);
        assert_eq!(snapshot.n_clusters, 3);
        assert!(snapshot.is_trained);
        assert_eq!(snapshot.dimensions, Some(2));
    }

    /// Multi-level residual quantization: train, insert, search, and check
    /// that the stats report the strategy and multi-level vectors.
    #[test]
    fn test_ivf_multi_level_quantization() {
        // One PQ configuration per quantization level.
        let level_configs = vec![
            PQConfig {
                n_subquantizers: 2,
                n_bits: 8,
                ..Default::default()
            },
            PQConfig {
                n_subquantizers: 2,
                n_bits: 4,
                ..Default::default()
            },
        ];

        let mut index =
            IvfIndex::new_with_multi_level_quantization(4, 2, 2, level_configs).unwrap();

        let points = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.5, 0.5]),
        ];

        index.train(&points).unwrap();
        assert!(index.is_trained);

        for (i, point) in points.iter().enumerate() {
            index.insert(format!("vec{i}"), point.clone()).unwrap();
        }

        let hits = index
            .search_knn(&Vector::new(vec![0.9, 0.1, 0.0, 0.0]), 3)
            .unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 3);

        let snapshot = index.stats();
        assert!(matches!(
            snapshot.quantization_strategy,
            QuantizationStrategy::ResidualQuantization { .. }
        ));
        if let Some(compression) = snapshot.compression_stats.as_ref() {
            assert!(compression.multi_level_vectors > 0);
        }
    }

    /// Multi-codebook quantization: train, insert, search, and check that the
    /// stats report the strategy and multi-codebook vectors.
    #[test]
    fn test_ivf_multi_codebook_quantization() {
        // One PQ configuration per codebook.
        let codebook_configs = vec![
            PQConfig {
                n_subquantizers: 2,
                n_bits: 8,
                ..Default::default()
            },
            PQConfig {
                n_subquantizers: 2,
                n_bits: 8,
                ..Default::default()
            },
        ];

        let mut index =
            IvfIndex::new_with_multi_codebook_quantization(4, 2, 2, codebook_configs).unwrap();

        let points = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
        ];

        index.train(&points).unwrap();
        assert!(index.is_trained);

        for (i, point) in points.iter().enumerate() {
            index.insert(format!("vec{i}"), point.clone()).unwrap();
        }

        let hits = index
            .search_knn(&Vector::new(vec![0.9, 0.1, 0.0, 0.0]), 2)
            .unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 2);

        let snapshot = index.stats();
        assert!(matches!(
            snapshot.quantization_strategy,
            QuantizationStrategy::MultiCodebook { .. }
        ));
        if let Some(compression) = snapshot.compression_stats.as_ref() {
            assert!(compression.multi_codebook_vectors > 0);
        }
    }

    /// An index can be constructed with every quantization strategy.
    #[test]
    fn test_quantization_strategies() {
        let pq_config = PQConfig::default();

        let strategies = vec![
            QuantizationStrategy::None,
            QuantizationStrategy::ProductQuantization(pq_config.clone()),
            QuantizationStrategy::ResidualQuantization {
                levels: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
            QuantizationStrategy::MultiCodebook {
                num_codebooks: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
        ];

        for strategy in strategies {
            let outcome = IvfIndex::new(IvfConfig {
                n_clusters: 2,
                n_probes: 1,
                quantization: strategy.clone(),
                ..Default::default()
            });

            assert!(
                outcome.is_ok(),
                "Failed to create index with strategy: {strategy:?}"
            );
        }
    }
}