1use crate::{
9 pq::{PQConfig, PQIndex},
10 Vector, VectorIndex,
11};
12use anyhow::{anyhow, Result};
13use std::sync::{Arc, RwLock};
14
#[derive(Debug, Clone, PartialEq)]
/// How the vectors held in each inverted list are compressed.
pub enum QuantizationStrategy {
    /// Store vectors as given, without any quantization.
    None,
    /// Encode each residual with a single product quantizer built from this
    /// configuration.
    ProductQuantization(PQConfig),
    /// Encode residuals with a cascade of product quantizers: each level
    /// encodes what is left after subtracting the previous level's
    /// approximation.
    ResidualQuantization {
        /// Upper bound on the number of quantizer levels applied per vector
        /// (the number of configured quantizers is the other bound).
        levels: usize,
        /// One product-quantizer configuration per level, in cascade order.
        pq_configs: Vec<PQConfig>,
    },
    /// Encode the same residual independently with several product
    /// quantizers and combine their distances with per-codebook weights.
    MultiCodebook {
        /// Declared number of codebooks; used to size the initial uniform
        /// weight vector.
        num_codebooks: usize,
        /// One product-quantizer configuration per codebook.
        pq_configs: Vec<PQConfig>,
    },
}
33
#[derive(Debug, Clone)]
/// Construction and training parameters for [`IvfIndex`].
pub struct IvfConfig {
    /// Number of k-means clusters; one inverted list is created per cluster.
    pub n_clusters: usize,
    /// Number of nearest clusters whose lists are scanned for each query.
    pub n_probes: usize,
    /// Maximum number of k-means refinement iterations during training.
    pub max_iterations: usize,
    /// Training stops once the total centroid movement changes by less than
    /// this amount between two consecutive iterations.
    pub convergence_threshold: f32,
    /// Seed for centroid initialization; `42` is used when `None`.
    pub seed: Option<u64>,
    /// Quantization applied inside the inverted lists. Overridden by product
    /// quantization when `enable_residual_quantization` is set.
    pub quantization: QuantizationStrategy,
    /// When `true`, the lists use product quantization built from
    /// `pq_config`, regardless of `quantization`.
    pub enable_residual_quantization: bool,
    /// Product-quantizer configuration; required when
    /// `enable_residual_quantization` is `true`.
    pub pq_config: Option<PQConfig>,
}
54
55impl Default for IvfConfig {
56 fn default() -> Self {
57 Self {
58 n_clusters: 256,
59 n_probes: 8,
60 max_iterations: 100,
61 convergence_threshold: 1e-4,
62 seed: None,
63 quantization: QuantizationStrategy::None,
64 enable_residual_quantization: false,
65 pq_config: None,
66 }
67 }
68}
69
#[derive(Debug, Clone)]
/// Representation of a single entry stored in an inverted list.
enum VectorStorage {
    /// The vector itself, stored without compression.
    Full(Vector),
    /// Codes produced by the list's single product quantizer.
    Quantized(Vec<u8>),
    /// Codes produced by a cascade of product quantizers.
    MultiLevelQuantized {
        /// One code sequence per quantizer level that was applied.
        levels: Vec<Vec<u8>>,
        /// Remainder kept uncompressed when fewer levels were applied than
        /// requested; `None` otherwise.
        final_residual: Option<Vector>,
    },
    /// Codes produced by several independent product quantizers.
    MultiCodebook {
        /// One code sequence per codebook.
        codebooks: Vec<Vec<u8>>,
        /// Snapshot of the list's codebook weights at insertion time.
        weights: Vec<f32>,
    },
}
88
#[derive(Debug, Clone)]
/// The entries assigned to one cluster, together with the quantizers used to
/// encode and score them.
struct InvertedList {
    /// `(uri, stored representation)` pairs in insertion order.
    vectors: Vec<(String, VectorStorage)>,
    /// Strategy that decides how residuals are encoded and quantizers trained.
    quantization: QuantizationStrategy,
    /// Quantizer used by the `ProductQuantization` strategy.
    pq_index: Option<PQIndex>,
    /// Per-level quantizers used by the `ResidualQuantization` strategy.
    multi_level_pq: Vec<PQIndex>,
    /// Per-codebook quantizers used by the `MultiCodebook` strategy.
    multi_codebook_pq: Vec<PQIndex>,
    /// Weight of each codebook when distances are combined.
    codebook_weights: Vec<f32>,
}
105
106impl InvertedList {
107 fn new() -> Self {
108 Self {
109 vectors: Vec::new(),
110 quantization: QuantizationStrategy::None,
111 pq_index: None,
112 multi_level_pq: Vec::new(),
113 multi_codebook_pq: Vec::new(),
114 codebook_weights: Vec::new(),
115 }
116 }
117
118 fn new_with_quantization(quantization: QuantizationStrategy) -> Result<Self> {
119 let mut list = Self {
120 vectors: Vec::new(),
121 quantization: quantization.clone(),
122 pq_index: None,
123 multi_level_pq: Vec::new(),
124 multi_codebook_pq: Vec::new(),
125 codebook_weights: Vec::new(),
126 };
127
128 match quantization {
129 QuantizationStrategy::None => {}
130 QuantizationStrategy::ProductQuantization(pq_config) => {
131 list.pq_index = Some(PQIndex::new(pq_config));
132 }
133 QuantizationStrategy::ResidualQuantization {
134 levels: _,
135 ref pq_configs,
136 } => {
137 for pq_config in pq_configs {
138 list.multi_level_pq.push(PQIndex::new(pq_config.clone()));
139 }
140 }
141 QuantizationStrategy::MultiCodebook {
142 num_codebooks,
143 ref pq_configs,
144 } => {
145 for pq_config in pq_configs {
146 list.multi_codebook_pq.push(PQIndex::new(pq_config.clone()));
147 }
148 list.codebook_weights = vec![1.0 / num_codebooks as f32; num_codebooks];
150 }
151 }
152
153 Ok(list)
154 }
155
156 fn new_with_pq(pq_config: PQConfig) -> Result<Self> {
158 Self::new_with_quantization(QuantizationStrategy::ProductQuantization(pq_config))
159 }
160
161 fn add_full(&mut self, uri: String, vector: Vector) {
162 self.vectors.push((uri, VectorStorage::Full(vector)));
163 }
164
165 fn add_residual(&mut self, uri: String, residual: Vector, _centroid: &Vector) -> Result<()> {
166 match &self.quantization {
167 QuantizationStrategy::ProductQuantization(_) => {
168 if let Some(ref mut pq_index) = self.pq_index {
169 if !pq_index.is_trained() {
171 let training_residuals = vec![residual.clone()];
172 pq_index.train(&training_residuals)?;
173 }
174
175 let codes = pq_index.encode(&residual)?;
176 self.vectors.push((uri, VectorStorage::Quantized(codes)));
177 } else {
178 return Err(anyhow!(
179 "PQ index not initialized for residual quantization"
180 ));
181 }
182 }
183 QuantizationStrategy::ResidualQuantization { levels, .. } => {
184 self.add_multi_level_residual(uri, residual, *levels)?;
185 }
186 QuantizationStrategy::MultiCodebook { .. } => {
187 self.add_multi_codebook(uri, residual)?;
188 }
189 QuantizationStrategy::None => {
190 self.add_full(uri, residual);
191 }
192 }
193 Ok(())
194 }
195
196 fn add_multi_level_residual(
198 &mut self,
199 uri: String,
200 mut residual: Vector,
201 levels: usize,
202 ) -> Result<()> {
203 let mut level_codes = Vec::new();
204
205 for level in 0..levels.min(self.multi_level_pq.len()) {
206 if !self.multi_level_pq[level].is_trained() {
208 let training_residuals = vec![residual.clone()];
209 self.multi_level_pq[level].train(&training_residuals)?;
210 }
211
212 let codes = self.multi_level_pq[level].encode(&residual)?;
214 level_codes.push(codes);
215
216 let approximation = self.multi_level_pq[level].decode_vector(&level_codes[level])?;
218 residual = residual.subtract(&approximation)?;
219 }
220
221 let final_residual = if level_codes.len() < levels {
223 Some(residual)
224 } else {
225 None
226 };
227
228 self.vectors.push((
229 uri,
230 VectorStorage::MultiLevelQuantized {
231 levels: level_codes,
232 final_residual,
233 },
234 ));
235
236 Ok(())
237 }
238
239 fn add_multi_codebook(&mut self, uri: String, residual: Vector) -> Result<()> {
241 let mut codebook_codes = Vec::new();
242
243 for pq_index in self.multi_codebook_pq.iter_mut() {
244 if !pq_index.is_trained() {
246 let training_residuals = vec![residual.clone()];
247 pq_index.train(&training_residuals)?;
248 }
249
250 let codes = pq_index.encode(&residual)?;
252 codebook_codes.push(codes);
253 }
254
255 self.vectors.push((
256 uri,
257 VectorStorage::MultiCodebook {
258 codebooks: codebook_codes,
259 weights: self.codebook_weights.clone(),
260 },
261 ));
262
263 Ok(())
264 }
265
266 fn search(&self, query: &Vector, centroid: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
267 let mut distances: Vec<(String, f32)> = Vec::new();
268 let query_residual = query.subtract(centroid)?;
269
270 for (uri, storage) in &self.vectors {
271 let distance = match storage {
272 VectorStorage::Full(vec) => query.euclidean_distance(vec).unwrap_or(f32::INFINITY),
273 VectorStorage::Quantized(codes) => {
274 if let Some(ref pq_index) = self.pq_index {
275 pq_index.compute_distance(&query_residual, codes)?
276 } else {
277 f32::INFINITY
278 }
279 }
280 VectorStorage::MultiLevelQuantized {
281 levels,
282 final_residual,
283 } => self.compute_multi_level_distance(&query_residual, levels, final_residual)?,
284 VectorStorage::MultiCodebook { codebooks, weights } => {
285 self.compute_multi_codebook_distance(&query_residual, codebooks, weights)?
286 }
287 };
288 distances.push((uri.clone(), distance));
289 }
290
291 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
292 distances.truncate(k);
293
294 Ok(distances
296 .into_iter()
297 .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
298 .collect())
299 }
300
301 fn compute_multi_level_distance(
303 &self,
304 query_residual: &Vector,
305 level_codes: &[Vec<u8>],
306 final_residual: &Option<Vector>,
307 ) -> Result<f32> {
308 let mut reconstructed_residual = Vector::new(vec![0.0; query_residual.dimensions]);
309
310 for (level, codes) in level_codes.iter().enumerate() {
312 if level < self.multi_level_pq.len() {
313 let level_reconstruction = self.multi_level_pq[level].decode_vector(codes)?;
314 reconstructed_residual = reconstructed_residual.add(&level_reconstruction)?;
315 }
316 }
317
318 if let Some(final_res) = final_residual {
320 reconstructed_residual = reconstructed_residual.add(final_res)?;
321 }
322
323 query_residual.euclidean_distance(&reconstructed_residual)
325 }
326
327 fn compute_multi_codebook_distance(
329 &self,
330 query_residual: &Vector,
331 codebook_codes: &[Vec<u8>],
332 weights: &[f32],
333 ) -> Result<f32> {
334 let mut weighted_distance = 0.0;
335 let mut total_weight = 0.0;
336
337 for (i, codes) in codebook_codes.iter().enumerate() {
339 if i < self.multi_codebook_pq.len() && i < weights.len() {
340 let codebook_distance =
341 self.multi_codebook_pq[i].compute_distance(query_residual, codes)?;
342 weighted_distance += weights[i] * codebook_distance;
343 total_weight += weights[i];
344 }
345 }
346
347 if total_weight > 0.0 {
349 Ok(weighted_distance / total_weight)
350 } else {
351 Ok(f32::INFINITY)
352 }
353 }
354
355 fn train_pq(&mut self, residuals: &[Vector]) -> Result<()> {
357 match &self.quantization {
358 QuantizationStrategy::ProductQuantization(_) => {
359 if let Some(ref mut pq_index) = self.pq_index {
360 pq_index.train(residuals)?;
361 }
362 }
363 QuantizationStrategy::ResidualQuantization { levels, .. } => {
364 self.train_multi_level_pq(residuals, *levels)?;
365 }
366 QuantizationStrategy::MultiCodebook { .. } => {
367 self.train_multi_codebook_pq(residuals)?;
368 }
369 QuantizationStrategy::None => {}
370 }
371 Ok(())
372 }
373
374 fn train_multi_level_pq(&mut self, residuals: &[Vector], levels: usize) -> Result<()> {
376 let mut current_residuals = residuals.to_vec();
377
378 for level in 0..levels.min(self.multi_level_pq.len()) {
379 self.multi_level_pq[level].train(¤t_residuals)?;
381
382 let mut next_residuals = Vec::new();
384 for residual in ¤t_residuals {
385 let codes = self.multi_level_pq[level].encode(residual)?;
386 let approximation = self.multi_level_pq[level].decode_vector(&codes)?;
387 let next_residual = residual.subtract(&approximation)?;
388 next_residuals.push(next_residual);
389 }
390 current_residuals = next_residuals;
391 }
392
393 Ok(())
394 }
395
396 fn train_multi_codebook_pq(&mut self, residuals: &[Vector]) -> Result<()> {
398 for pq_index in &mut self.multi_codebook_pq {
400 pq_index.train(residuals)?;
401 }
402
403 self.optimize_codebook_weights(residuals)?;
405
406 Ok(())
407 }
408
409 fn optimize_codebook_weights(&mut self, residuals: &[Vector]) -> Result<()> {
411 if self.multi_codebook_pq.is_empty() || residuals.is_empty() {
412 return Ok(());
413 }
414
415 let num_codebooks = self.multi_codebook_pq.len();
416 let mut reconstruction_errors = vec![0.0; num_codebooks];
417
418 for (i, pq_index) in self.multi_codebook_pq.iter().enumerate() {
420 let mut total_error = 0.0;
421 for residual in residuals {
422 let codes = pq_index.encode(residual)?;
423 let reconstruction = pq_index.decode_vector(&codes)?;
424 let error = residual
425 .euclidean_distance(&reconstruction)
426 .unwrap_or(f32::INFINITY);
427 total_error += error;
428 }
429 reconstruction_errors[i] = total_error / residuals.len() as f32;
430 }
431
432 let max_error = reconstruction_errors.iter().fold(0.0f32, |a, &b| a.max(b));
434 if max_error > 0.0 {
435 let mut total_weight = 0.0;
436 for (i, &error) in reconstruction_errors.iter().enumerate().take(num_codebooks) {
437 self.codebook_weights[i] = (max_error - error + 1e-6) / max_error;
439 total_weight += self.codebook_weights[i];
440 }
441
442 if total_weight > 0.0 {
444 for weight in &mut self.codebook_weights {
445 *weight /= total_weight;
446 }
447 }
448 }
449
450 Ok(())
451 }
452
453 fn stats(&self) -> InvertedListStats {
455 let mut full_vectors = 0;
456 let mut quantized_vectors = 0;
457 let mut multi_level_vectors = 0;
458 let mut multi_codebook_vectors = 0;
459
460 for (_, storage) in &self.vectors {
461 match storage {
462 VectorStorage::Full(_) => full_vectors += 1,
463 VectorStorage::Quantized(_) => quantized_vectors += 1,
464 VectorStorage::MultiLevelQuantized { .. } => {
465 quantized_vectors += 1;
466 multi_level_vectors += 1;
467 }
468 VectorStorage::MultiCodebook { .. } => {
469 quantized_vectors += 1;
470 multi_codebook_vectors += 1;
471 }
472 }
473 }
474
475 let total_vectors = self.vectors.len();
476 let compression_ratio = if total_vectors > 0 {
477 quantized_vectors as f32 / total_vectors as f32
478 } else {
479 0.0
480 };
481
482 InvertedListStats {
483 total_vectors,
484 full_vectors,
485 quantized_vectors,
486 compression_ratio,
487 multi_level_vectors,
488 multi_codebook_vectors,
489 quantization_strategy: self.quantization.clone(),
490 }
491 }
492}
493
#[derive(Debug, Clone)]
/// Per-list (or aggregated) counts of stored entries by storage kind.
pub struct InvertedListStats {
    /// Number of stored entries.
    pub total_vectors: usize,
    /// Entries stored uncompressed.
    pub full_vectors: usize,
    /// Entries stored in any quantized form (single-quantizer, multi-level
    /// or multi-codebook).
    pub quantized_vectors: usize,
    /// Fraction of entries that are quantized: `quantized_vectors /
    /// total_vectors`, or `0.0` when there are no entries.
    pub compression_ratio: f32,
    /// Entries stored with multi-level residual codes.
    pub multi_level_vectors: usize,
    /// Entries stored with multi-codebook codes.
    pub multi_codebook_vectors: usize,
    /// Quantization strategy these statistics refer to.
    pub quantization_strategy: QuantizationStrategy,
}
505
/// Inverted-file index: vectors are assigned to their nearest k-means
/// centroid and stored in that centroid's inverted list.
pub struct IvfIndex {
    /// Configuration the index was built with.
    config: IvfConfig,
    /// Cluster centroids; empty until the index is trained.
    centroids: Vec<Vector>,
    /// One lock-protected inverted list per cluster.
    inverted_lists: Vec<Arc<RwLock<InvertedList>>>,
    /// Dimensionality of the training vectors; `None` until trained.
    dimensions: Option<usize>,
    /// Number of vectors inserted so far.
    n_vectors: usize,
    /// Whether training has been run.
    is_trained: bool,
}
520
521impl IvfIndex {
522 pub fn new(config: IvfConfig) -> Result<Self> {
524 let mut inverted_lists = Vec::with_capacity(config.n_clusters);
525
526 let quantization = if config.enable_residual_quantization {
528 if let Some(ref pq_config) = config.pq_config {
529 QuantizationStrategy::ProductQuantization(pq_config.clone())
530 } else {
531 return Err(anyhow!(
532 "PQ config required when residual quantization is enabled"
533 ));
534 }
535 } else {
536 config.quantization.clone()
537 };
538
539 for _ in 0..config.n_clusters {
540 let inverted_list = Arc::new(RwLock::new(InvertedList::new_with_quantization(
541 quantization.clone(),
542 )?));
543 inverted_lists.push(inverted_list);
544 }
545
546 Ok(Self {
547 config,
548 centroids: Vec::new(),
549 inverted_lists,
550 dimensions: None,
551 n_vectors: 0,
552 is_trained: false,
553 })
554 }
555
556 pub fn new_with_product_quantization(
558 n_clusters: usize,
559 n_probes: usize,
560 pq_config: PQConfig,
561 ) -> Result<Self> {
562 let config = IvfConfig {
563 n_clusters,
564 n_probes,
565 quantization: QuantizationStrategy::ProductQuantization(pq_config),
566 ..Default::default()
567 };
568 Self::new(config)
569 }
570
571 pub fn new_with_multi_level_quantization(
573 n_clusters: usize,
574 n_probes: usize,
575 levels: usize,
576 pq_configs: Vec<PQConfig>,
577 ) -> Result<Self> {
578 if pq_configs.len() < levels {
579 return Err(anyhow!(
580 "Number of PQ configs must be at least equal to levels"
581 ));
582 }
583
584 let config = IvfConfig {
585 n_clusters,
586 n_probes,
587 quantization: QuantizationStrategy::ResidualQuantization { levels, pq_configs },
588 ..Default::default()
589 };
590 Self::new(config)
591 }
592
593 pub fn new_with_multi_codebook_quantization(
595 n_clusters: usize,
596 n_probes: usize,
597 num_codebooks: usize,
598 pq_configs: Vec<PQConfig>,
599 ) -> Result<Self> {
600 if pq_configs.len() != num_codebooks {
601 return Err(anyhow!(
602 "Number of PQ configs must equal number of codebooks"
603 ));
604 }
605
606 let config = IvfConfig {
607 n_clusters,
608 n_probes,
609 quantization: QuantizationStrategy::MultiCodebook {
610 num_codebooks,
611 pq_configs,
612 },
613 ..Default::default()
614 };
615 Self::new(config)
616 }
617
618 pub fn new_with_residual_quantization(
620 n_clusters: usize,
621 n_probes: usize,
622 pq_config: PQConfig,
623 ) -> Result<Self> {
624 Self::new_with_product_quantization(n_clusters, n_probes, pq_config)
625 }
626
627 pub fn config(&self) -> &IvfConfig {
629 &self.config
630 }
631
632 pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
634 if training_vectors.is_empty() {
635 return Err(anyhow!("Cannot train IVF index with empty training set"));
636 }
637
638 let dims = training_vectors[0].dimensions;
640 if !training_vectors.iter().all(|v| v.dimensions == dims) {
641 return Err(anyhow!(
642 "All training vectors must have the same dimensions"
643 ));
644 }
645
646 self.dimensions = Some(dims);
647
648 self.centroids = self.initialize_centroids_kmeans_plus_plus(training_vectors)?;
650
651 let mut iteration = 0;
653 let mut prev_error = f32::INFINITY;
654
655 while iteration < self.config.max_iterations {
656 let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); self.config.n_clusters];
658
659 for vector in training_vectors {
660 let nearest_idx = self.find_nearest_centroid(vector)?;
661 clusters[nearest_idx].push(vector);
662 }
663
664 let mut total_error = 0.0;
666 for (i, cluster) in clusters.iter().enumerate() {
667 if !cluster.is_empty() {
668 let new_centroid = self.compute_centroid(cluster);
669 total_error += self.centroids[i]
670 .euclidean_distance(&new_centroid)
671 .unwrap_or(0.0);
672 self.centroids[i] = new_centroid;
673 }
674 }
675
676 if (prev_error - total_error).abs() < self.config.convergence_threshold {
678 break;
679 }
680
681 prev_error = total_error;
682 iteration += 1;
683 }
684
685 self.is_trained = true;
686
687 if !matches!(self.config.quantization, QuantizationStrategy::None)
689 || self.config.enable_residual_quantization
690 {
691 self.train_residual_quantization(training_vectors)?;
692 }
693
694 Ok(())
695 }
696
697 fn train_residual_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
699 let mut cluster_residuals: Vec<Vec<Vector>> = vec![Vec::new(); self.config.n_clusters];
701
702 for vector in training_vectors {
703 let cluster_idx = self.find_nearest_centroid(vector)?;
704 let centroid = &self.centroids[cluster_idx];
705 let residual = vector.subtract(centroid)?;
706 cluster_residuals[cluster_idx].push(residual);
707 }
708
709 for (cluster_idx, residuals) in cluster_residuals.iter().enumerate() {
711 if residuals.len() > 10 {
712 let mut list = self.inverted_lists[cluster_idx].write().unwrap();
714 list.train_pq(residuals)?;
715 }
716 }
717
718 Ok(())
719 }
720
721 fn initialize_centroids_kmeans_plus_plus(&self, vectors: &[Vector]) -> Result<Vec<Vector>> {
723 use std::collections::hash_map::DefaultHasher;
724 use std::hash::{Hash, Hasher};
725
726 let mut hasher = DefaultHasher::new();
727 self.config.seed.unwrap_or(42).hash(&mut hasher);
728 let mut rng_state = hasher.finish();
729
730 let mut centroids = Vec::with_capacity(self.config.n_clusters);
731
732 let first_idx = (rng_state as usize) % vectors.len();
734 centroids.push(vectors[first_idx].clone());
735
736 while centroids.len() < self.config.n_clusters {
738 let mut distances = Vec::with_capacity(vectors.len());
739 let mut sum_distances = 0.0;
740
741 for vector in vectors {
743 let min_dist = centroids
744 .iter()
745 .map(|c| vector.euclidean_distance(c).unwrap_or(f32::INFINITY))
746 .fold(f32::INFINITY, |a, b| a.min(b));
747
748 distances.push(min_dist * min_dist); sum_distances += min_dist * min_dist;
750 }
751
752 rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
754 let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
755
756 let mut cumulative = 0.0;
757 for (i, &dist) in distances.iter().enumerate() {
758 cumulative += dist;
759 if cumulative >= threshold {
760 centroids.push(vectors[i].clone());
761 break;
762 }
763 }
764 }
765
766 Ok(centroids)
767 }
768
769 fn compute_centroid(&self, cluster: &[&Vector]) -> Vector {
771 if cluster.is_empty() {
772 return Vector::new(vec![0.0; self.dimensions.unwrap_or(0)]);
773 }
774
775 let dims = cluster[0].dimensions;
776 let mut sum = vec![0.0; dims];
777
778 for vector in cluster {
779 let values = vector.as_f32();
780 for (i, &val) in values.iter().enumerate() {
781 sum[i] += val;
782 }
783 }
784
785 let count = cluster.len() as f32;
786 for val in &mut sum {
787 *val /= count;
788 }
789
790 Vector::new(sum)
791 }
792
793 fn find_nearest_centroid(&self, vector: &Vector) -> Result<usize> {
795 if self.centroids.is_empty() {
796 return Err(anyhow!("No centroids available"));
797 }
798
799 let mut min_distance = f32::INFINITY;
800 let mut nearest_idx = 0;
801
802 for (i, centroid) in self.centroids.iter().enumerate() {
803 let distance = vector.euclidean_distance(centroid)?;
804 if distance < min_distance {
805 min_distance = distance;
806 nearest_idx = i;
807 }
808 }
809
810 Ok(nearest_idx)
811 }
812
813 fn find_nearest_centroids(&self, query: &Vector, n_probes: usize) -> Result<Vec<usize>> {
815 let mut distances: Vec<(usize, f32)> = self
816 .centroids
817 .iter()
818 .enumerate()
819 .map(|(i, centroid)| {
820 let dist = query.euclidean_distance(centroid).unwrap_or(f32::INFINITY);
821 (i, dist)
822 })
823 .collect();
824
825 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
826
827 Ok(distances
828 .into_iter()
829 .take(n_probes.min(self.centroids.len()))
830 .map(|(i, _)| i)
831 .collect())
832 }
833
834 pub fn stats(&self) -> IvfStats {
836 let mut total_list_stats = InvertedListStats {
837 total_vectors: 0,
838 full_vectors: 0,
839 quantized_vectors: 0,
840 compression_ratio: 0.0,
841 multi_level_vectors: 0,
842 multi_codebook_vectors: 0,
843 quantization_strategy: QuantizationStrategy::None,
844 };
845
846 let mut cluster_stats = Vec::new();
847 let mut vectors_per_cluster = Vec::new();
848 let mut non_empty_clusters = 0;
849
850 for list in &self.inverted_lists {
851 let list_guard = list
852 .read()
853 .expect("inverted list lock should not be poisoned");
854 let stats = list_guard.stats();
855
856 total_list_stats.total_vectors += stats.total_vectors;
857 total_list_stats.full_vectors += stats.full_vectors;
858 total_list_stats.quantized_vectors += stats.quantized_vectors;
859 total_list_stats.multi_level_vectors += stats.multi_level_vectors;
860 total_list_stats.multi_codebook_vectors += stats.multi_codebook_vectors;
861
862 vectors_per_cluster.push(stats.total_vectors);
863 if stats.total_vectors > 0 {
864 non_empty_clusters += 1;
865 }
866
867 cluster_stats.push(stats);
868 }
869
870 if total_list_stats.total_vectors > 0 {
872 total_list_stats.compression_ratio =
873 total_list_stats.quantized_vectors as f32 / total_list_stats.total_vectors as f32;
874 }
875
876 let avg_vectors_per_cluster = if self.config.n_clusters > 0 {
877 self.n_vectors as f32 / self.config.n_clusters as f32
878 } else {
879 0.0
880 };
881
882 IvfStats {
883 n_clusters: self.config.n_clusters,
884 n_probes: self.config.n_probes,
885 n_vectors: self.n_vectors,
886 is_trained: self.is_trained,
887 dimensions: self.dimensions,
888 vectors_per_cluster,
889 avg_vectors_per_cluster,
890 non_empty_clusters,
891 enable_residual_quantization: self.config.enable_residual_quantization,
892 quantization_strategy: self.config.quantization.clone(),
893 compression_stats: Some(total_list_stats),
894 cluster_stats,
895 }
896 }
897}
898
899impl VectorIndex for IvfIndex {
900 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
901 if !self.is_trained {
902 return Err(anyhow!(
903 "IVF index must be trained before inserting vectors"
904 ));
905 }
906
907 if let Some(dims) = self.dimensions {
909 if vector.dimensions != dims {
910 return Err(anyhow!(
911 "Vector dimensions {} don't match index dimensions {}",
912 vector.dimensions,
913 dims
914 ));
915 }
916 }
917
918 let cluster_idx = self.find_nearest_centroid(&vector)?;
920 let centroid = &self.centroids[cluster_idx];
921
922 let mut list = self.inverted_lists[cluster_idx].write().unwrap();
923
924 match &self.config.quantization {
926 QuantizationStrategy::None => {
927 if self.config.enable_residual_quantization {
928 let residual = vector.subtract(centroid)?;
930 list.add_residual(uri, residual, centroid)?;
931 } else {
932 list.add_full(uri, vector);
933 }
934 }
935 _ => {
936 let residual = vector.subtract(centroid)?;
938 list.add_residual(uri, residual, centroid)?;
939 }
940 }
941
942 self.n_vectors += 1;
943 Ok(())
944 }
945
946 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
947 if !self.is_trained {
948 return Err(anyhow!("IVF index must be trained before searching"));
949 }
950
951 let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
953
954 let mut all_results = Vec::new();
956 for idx in probe_indices {
957 let list = self.inverted_lists[idx].read().unwrap();
958 let centroid = &self.centroids[idx];
959 let mut results = list.search(query, centroid, k)?;
960 all_results.append(&mut results);
961 }
962
963 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
965 all_results.truncate(k);
966
967 Ok(all_results)
968 }
969
970 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
971 if !self.is_trained {
972 return Err(anyhow!("IVF index must be trained before searching"));
973 }
974
975 let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
977
978 let mut all_results = Vec::new();
980 for idx in probe_indices {
981 let list = self.inverted_lists[idx].read().unwrap();
982 let centroid = &self.centroids[idx];
983 let results = list.search(query, centroid, self.n_vectors)?;
984
985 for (uri, similarity) in results {
987 if similarity >= threshold {
988 all_results.push((uri, similarity));
989 }
990 }
991 }
992
993 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
995
996 Ok(all_results)
997 }
998
    /// Always returns `None`: this index does not hand out references to the
    /// vectors it stores, whatever `_uri` is.
    // NOTE(review): presumably because entries live behind per-list locks
    // and may be stored only in quantized form — confirm the intent.
    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
        None
    }
1005}
1006
#[derive(Debug, Clone)]
/// Snapshot of an [`IvfIndex`]'s configuration and contents.
pub struct IvfStats {
    /// Number of vectors inserted into the index.
    pub n_vectors: usize,
    /// Configured number of clusters.
    pub n_clusters: usize,
    /// Configured number of clusters probed per query.
    pub n_probes: usize,
    /// Whether the index has been trained.
    pub is_trained: bool,
    /// Dimensionality of the training vectors, if trained.
    pub dimensions: Option<usize>,
    /// Number of entries in each inverted list, in cluster order.
    pub vectors_per_cluster: Vec<usize>,
    /// `n_vectors / n_clusters`, or `0.0` when there are no clusters.
    pub avg_vectors_per_cluster: f32,
    /// Number of inverted lists holding at least one entry.
    pub non_empty_clusters: usize,
    /// Value of the configuration flag of the same name.
    pub enable_residual_quantization: bool,
    /// Quantization strategy from the configuration.
    pub quantization_strategy: QuantizationStrategy,
    /// Counts aggregated over all inverted lists.
    pub compression_stats: Option<InvertedListStats>,
    /// Statistics of each inverted list, in cluster order.
    pub cluster_stats: Vec<InvertedListStats>,
}
1023
1024#[cfg(test)]
1025mod tests {
1026 use super::*;
1027
1028 #[test]
1029 fn test_ivf_basic() {
1030 let config = IvfConfig {
1031 n_clusters: 4,
1032 n_probes: 2,
1033 ..Default::default()
1034 };
1035
1036 let mut index = IvfIndex::new(config).unwrap();
1037
1038 let training_vectors = vec![
1040 Vector::new(vec![1.0, 0.0]),
1041 Vector::new(vec![0.0, 1.0]),
1042 Vector::new(vec![-1.0, 0.0]),
1043 Vector::new(vec![0.0, -1.0]),
1044 Vector::new(vec![0.5, 0.5]),
1045 Vector::new(vec![-0.5, 0.5]),
1046 Vector::new(vec![-0.5, -0.5]),
1047 Vector::new(vec![0.5, -0.5]),
1048 ];
1049
1050 index.train(&training_vectors).unwrap();
1052 assert!(index.is_trained);
1053
1054 for (i, vec) in training_vectors.iter().enumerate() {
1056 index.insert(format!("vec{i}"), vec.clone()).unwrap();
1057 }
1058
1059 let query = Vector::new(vec![0.9, 0.1]);
1061 let results = index.search_knn(&query, 3).unwrap();
1062
1063 assert!(!results.is_empty());
1064 assert!(results.len() <= 3);
1065
1066 assert_eq!(results[0].0, "vec0");
1068 }
1069
1070 #[test]
1071 fn test_ivf_threshold_search() {
1072 let config = IvfConfig {
1073 n_clusters: 2,
1074 n_probes: 2,
1075 ..Default::default()
1076 };
1077
1078 let mut index = IvfIndex::new(config).unwrap();
1079
1080 let training_vectors = vec![
1082 Vector::new(vec![1.0, 0.0, 0.0]),
1083 Vector::new(vec![0.0, 1.0, 0.0]),
1084 Vector::new(vec![0.0, 0.0, 1.0]),
1085 Vector::new(vec![0.5, 0.5, 0.0]),
1086 ];
1087
1088 index.train(&training_vectors).unwrap();
1089
1090 index
1092 .insert("v1".to_string(), training_vectors[0].clone())
1093 .unwrap();
1094 index
1095 .insert("v2".to_string(), training_vectors[1].clone())
1096 .unwrap();
1097 index
1098 .insert("v3".to_string(), training_vectors[2].clone())
1099 .unwrap();
1100 index
1101 .insert("v4".to_string(), training_vectors[3].clone())
1102 .unwrap();
1103
1104 let query = Vector::new(vec![0.9, 0.1, 0.0]);
1106 let results = index.search_threshold(&query, 0.5).unwrap();
1107
1108 assert!(!results.is_empty());
1109 for (_, similarity) in &results {
1111 assert!(*similarity >= 0.5);
1112 }
1113 }
1114
1115 #[test]
1116 fn test_ivf_stats() {
1117 let config = IvfConfig {
1118 n_clusters: 3,
1119 n_probes: 1,
1120 ..Default::default()
1121 };
1122
1123 let mut index = IvfIndex::new(config).unwrap();
1124
1125 let training_vectors = vec![
1127 Vector::new(vec![1.0, 0.0]),
1128 Vector::new(vec![0.0, 1.0]),
1129 Vector::new(vec![-1.0, -1.0]),
1130 ];
1131
1132 index.train(&training_vectors).unwrap();
1133
1134 index
1136 .insert("a".to_string(), Vector::new(vec![1.1, 0.1]))
1137 .unwrap();
1138 index
1139 .insert("b".to_string(), Vector::new(vec![0.1, 1.1]))
1140 .unwrap();
1141
1142 let stats = index.stats();
1143 assert_eq!(stats.n_vectors, 2);
1144 assert_eq!(stats.n_clusters, 3);
1145 assert!(stats.is_trained);
1146 assert_eq!(stats.dimensions, Some(2));
1147 }
1148
1149 #[test]
1150 fn test_ivf_multi_level_quantization() {
1151 use crate::pq::PQConfig;
1152
1153 let pq_config_1 = PQConfig {
1155 n_subquantizers: 2,
1156 n_bits: 8,
1157 ..Default::default()
1158 };
1159 let pq_config_2 = PQConfig {
1160 n_subquantizers: 2,
1161 n_bits: 4,
1162 ..Default::default()
1163 };
1164
1165 let mut index =
1166 IvfIndex::new_with_multi_level_quantization(4, 2, 2, vec![pq_config_1, pq_config_2])
1167 .unwrap();
1168
1169 let training_vectors = vec![
1171 Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
1172 Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
1173 Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
1174 Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
1175 Vector::new(vec![0.5, 0.5, 0.0, 0.0]),
1176 Vector::new(vec![0.0, 0.0, 0.5, 0.5]),
1177 ];
1178
1179 index.train(&training_vectors).unwrap();
1181 assert!(index.is_trained);
1182
1183 for (i, vec) in training_vectors.iter().enumerate() {
1185 index.insert(format!("vec{i}"), vec.clone()).unwrap();
1186 }
1187
1188 let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
1190 let results = index.search_knn(&query, 3).unwrap();
1191
1192 assert!(!results.is_empty());
1193 assert!(results.len() <= 3);
1194
1195 let stats = index.stats();
1197 assert!(matches!(
1198 stats.quantization_strategy,
1199 QuantizationStrategy::ResidualQuantization { .. }
1200 ));
1201 if let Some(compression_stats) = &stats.compression_stats {
1202 assert!(compression_stats.multi_level_vectors > 0);
1203 }
1204 }
1205
1206 #[test]
1207 fn test_ivf_multi_codebook_quantization() {
1208 use crate::pq::PQConfig;
1209
1210 let pq_config_1 = PQConfig {
1212 n_subquantizers: 2,
1213 n_bits: 8,
1214 ..Default::default()
1215 };
1216 let pq_config_2 = PQConfig {
1217 n_subquantizers: 2,
1218 n_bits: 8,
1219 ..Default::default()
1220 };
1221
1222 let mut index =
1223 IvfIndex::new_with_multi_codebook_quantization(4, 2, 2, vec![pq_config_1, pq_config_2])
1224 .unwrap();
1225
1226 let training_vectors = vec![
1228 Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
1229 Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
1230 Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
1231 Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
1232 Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
1233 ];
1234
1235 index.train(&training_vectors).unwrap();
1237 assert!(index.is_trained);
1238
1239 for (i, vec) in training_vectors.iter().enumerate() {
1241 index.insert(format!("vec{i}"), vec.clone()).unwrap();
1242 }
1243
1244 let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
1246 let results = index.search_knn(&query, 2).unwrap();
1247
1248 assert!(!results.is_empty());
1249 assert!(results.len() <= 2);
1250
1251 let stats = index.stats();
1253 assert!(matches!(
1254 stats.quantization_strategy,
1255 QuantizationStrategy::MultiCodebook { .. }
1256 ));
1257 if let Some(compression_stats) = &stats.compression_stats {
1258 assert!(compression_stats.multi_codebook_vectors > 0);
1259 }
1260 }
1261
1262 #[test]
1263 fn test_quantization_strategies() {
1264 use crate::pq::PQConfig;
1265
1266 let pq_config = PQConfig::default();
1267
1268 let strategies = vec![
1270 QuantizationStrategy::None,
1271 QuantizationStrategy::ProductQuantization(pq_config.clone()),
1272 QuantizationStrategy::ResidualQuantization {
1273 levels: 2,
1274 pq_configs: vec![pq_config.clone(), pq_config.clone()],
1275 },
1276 QuantizationStrategy::MultiCodebook {
1277 num_codebooks: 2,
1278 pq_configs: vec![pq_config.clone(), pq_config.clone()],
1279 },
1280 ];
1281
1282 for strategy in strategies {
1283 let config = IvfConfig {
1284 n_clusters: 2,
1285 n_probes: 1,
1286 quantization: strategy.clone(),
1287 ..Default::default()
1288 };
1289
1290 let index = IvfIndex::new(config);
1291 assert!(
1292 index.is_ok(),
1293 "Failed to create index with strategy: {strategy:?}"
1294 );
1295 }
1296 }
1297}