1use crate::{
9 pq::{PQConfig, PQIndex},
10 Vector, VectorIndex,
11};
12use anyhow::{anyhow, Result};
13use std::sync::{Arc, RwLock};
14
15#[derive(Debug, Clone, PartialEq)]
17pub enum QuantizationStrategy {
18 None,
20 ProductQuantization(PQConfig),
22 ResidualQuantization {
24 levels: usize,
25 pq_configs: Vec<PQConfig>,
26 },
27 MultiCodebook {
29 num_codebooks: usize,
30 pq_configs: Vec<PQConfig>,
31 },
32}
33
34#[derive(Debug, Clone)]
36pub struct IvfConfig {
37 pub n_clusters: usize,
39 pub n_probes: usize,
41 pub max_iterations: usize,
43 pub convergence_threshold: f32,
45 pub seed: Option<u64>,
47 pub quantization: QuantizationStrategy,
49 pub enable_residual_quantization: bool,
51 pub pq_config: Option<PQConfig>,
53}
54
55impl Default for IvfConfig {
56 fn default() -> Self {
57 Self {
58 n_clusters: 256,
59 n_probes: 8,
60 max_iterations: 100,
61 convergence_threshold: 1e-4,
62 seed: None,
63 quantization: QuantizationStrategy::None,
64 enable_residual_quantization: false,
65 pq_config: None,
66 }
67 }
68}
69
70#[derive(Debug, Clone)]
72enum VectorStorage {
73 Full(Vector),
75 Quantized(Vec<u8>),
77 MultiLevelQuantized {
79 levels: Vec<Vec<u8>>, final_residual: Option<Vector>, },
82 MultiCodebook {
84 codebooks: Vec<Vec<u8>>, weights: Vec<f32>, },
87}
88
89#[derive(Debug, Clone)]
91struct InvertedList {
92 vectors: Vec<(String, VectorStorage)>,
94 quantization: QuantizationStrategy,
96 pq_index: Option<PQIndex>,
98 multi_level_pq: Vec<PQIndex>,
100 multi_codebook_pq: Vec<PQIndex>,
102 codebook_weights: Vec<f32>,
104}
105
106impl InvertedList {
107 fn new() -> Self {
108 Self {
109 vectors: Vec::new(),
110 quantization: QuantizationStrategy::None,
111 pq_index: None,
112 multi_level_pq: Vec::new(),
113 multi_codebook_pq: Vec::new(),
114 codebook_weights: Vec::new(),
115 }
116 }
117
118 fn new_with_quantization(quantization: QuantizationStrategy) -> Result<Self> {
119 let mut list = Self {
120 vectors: Vec::new(),
121 quantization: quantization.clone(),
122 pq_index: None,
123 multi_level_pq: Vec::new(),
124 multi_codebook_pq: Vec::new(),
125 codebook_weights: Vec::new(),
126 };
127
128 match quantization {
129 QuantizationStrategy::None => {}
130 QuantizationStrategy::ProductQuantization(pq_config) => {
131 list.pq_index = Some(PQIndex::new(pq_config));
132 }
133 QuantizationStrategy::ResidualQuantization {
134 levels: _,
135 ref pq_configs,
136 } => {
137 for pq_config in pq_configs {
138 list.multi_level_pq.push(PQIndex::new(pq_config.clone()));
139 }
140 }
141 QuantizationStrategy::MultiCodebook {
142 num_codebooks,
143 ref pq_configs,
144 } => {
145 for pq_config in pq_configs {
146 list.multi_codebook_pq.push(PQIndex::new(pq_config.clone()));
147 }
148 list.codebook_weights = vec![1.0 / num_codebooks as f32; num_codebooks];
150 }
151 }
152
153 Ok(list)
154 }
155
156 fn new_with_pq(pq_config: PQConfig) -> Result<Self> {
158 Self::new_with_quantization(QuantizationStrategy::ProductQuantization(pq_config))
159 }
160
161 fn add_full(&mut self, uri: String, vector: Vector) {
162 self.vectors.push((uri, VectorStorage::Full(vector)));
163 }
164
165 fn add_residual(&mut self, uri: String, residual: Vector, _centroid: &Vector) -> Result<()> {
166 match &self.quantization {
167 QuantizationStrategy::ProductQuantization(_) => {
168 if let Some(ref mut pq_index) = self.pq_index {
169 if !pq_index.is_trained() {
171 let training_residuals = vec![residual.clone()];
172 pq_index.train(&training_residuals)?;
173 }
174
175 let codes = pq_index.encode(&residual)?;
176 self.vectors.push((uri, VectorStorage::Quantized(codes)));
177 } else {
178 return Err(anyhow!(
179 "PQ index not initialized for residual quantization"
180 ));
181 }
182 }
183 QuantizationStrategy::ResidualQuantization { levels, .. } => {
184 self.add_multi_level_residual(uri, residual, *levels)?;
185 }
186 QuantizationStrategy::MultiCodebook { .. } => {
187 self.add_multi_codebook(uri, residual)?;
188 }
189 QuantizationStrategy::None => {
190 self.add_full(uri, residual);
191 }
192 }
193 Ok(())
194 }
195
196 fn add_multi_level_residual(
198 &mut self,
199 uri: String,
200 mut residual: Vector,
201 levels: usize,
202 ) -> Result<()> {
203 let mut level_codes = Vec::new();
204
205 for level in 0..levels.min(self.multi_level_pq.len()) {
206 if !self.multi_level_pq[level].is_trained() {
208 let training_residuals = vec![residual.clone()];
209 self.multi_level_pq[level].train(&training_residuals)?;
210 }
211
212 let codes = self.multi_level_pq[level].encode(&residual)?;
214 level_codes.push(codes);
215
216 let approximation = self.multi_level_pq[level].decode_vector(&level_codes[level])?;
218 residual = residual.subtract(&approximation)?;
219 }
220
221 let final_residual = if level_codes.len() < levels {
223 Some(residual)
224 } else {
225 None
226 };
227
228 self.vectors.push((
229 uri,
230 VectorStorage::MultiLevelQuantized {
231 levels: level_codes,
232 final_residual,
233 },
234 ));
235
236 Ok(())
237 }
238
239 fn add_multi_codebook(&mut self, uri: String, residual: Vector) -> Result<()> {
241 let mut codebook_codes = Vec::new();
242
243 for pq_index in self.multi_codebook_pq.iter_mut() {
244 if !pq_index.is_trained() {
246 let training_residuals = vec![residual.clone()];
247 pq_index.train(&training_residuals)?;
248 }
249
250 let codes = pq_index.encode(&residual)?;
252 codebook_codes.push(codes);
253 }
254
255 self.vectors.push((
256 uri,
257 VectorStorage::MultiCodebook {
258 codebooks: codebook_codes,
259 weights: self.codebook_weights.clone(),
260 },
261 ));
262
263 Ok(())
264 }
265
266 fn search(&self, query: &Vector, centroid: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
267 let mut distances: Vec<(String, f32)> = Vec::new();
268 let query_residual = query.subtract(centroid)?;
269
270 for (uri, storage) in &self.vectors {
271 let distance = match storage {
272 VectorStorage::Full(vec) => query.euclidean_distance(vec).unwrap_or(f32::INFINITY),
273 VectorStorage::Quantized(codes) => {
274 if let Some(ref pq_index) = self.pq_index {
275 pq_index.compute_distance(&query_residual, codes)?
276 } else {
277 f32::INFINITY
278 }
279 }
280 VectorStorage::MultiLevelQuantized {
281 levels,
282 final_residual,
283 } => self.compute_multi_level_distance(&query_residual, levels, final_residual)?,
284 VectorStorage::MultiCodebook { codebooks, weights } => {
285 self.compute_multi_codebook_distance(&query_residual, codebooks, weights)?
286 }
287 };
288 distances.push((uri.clone(), distance));
289 }
290
291 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
292 distances.truncate(k);
293
294 Ok(distances
296 .into_iter()
297 .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
298 .collect())
299 }
300
301 fn compute_multi_level_distance(
303 &self,
304 query_residual: &Vector,
305 level_codes: &[Vec<u8>],
306 final_residual: &Option<Vector>,
307 ) -> Result<f32> {
308 let mut reconstructed_residual = Vector::new(vec![0.0; query_residual.dimensions]);
309
310 for (level, codes) in level_codes.iter().enumerate() {
312 if level < self.multi_level_pq.len() {
313 let level_reconstruction = self.multi_level_pq[level].decode_vector(codes)?;
314 reconstructed_residual = reconstructed_residual.add(&level_reconstruction)?;
315 }
316 }
317
318 if let Some(final_res) = final_residual {
320 reconstructed_residual = reconstructed_residual.add(final_res)?;
321 }
322
323 query_residual.euclidean_distance(&reconstructed_residual)
325 }
326
327 fn compute_multi_codebook_distance(
329 &self,
330 query_residual: &Vector,
331 codebook_codes: &[Vec<u8>],
332 weights: &[f32],
333 ) -> Result<f32> {
334 let mut weighted_distance = 0.0;
335 let mut total_weight = 0.0;
336
337 for (i, codes) in codebook_codes.iter().enumerate() {
339 if i < self.multi_codebook_pq.len() && i < weights.len() {
340 let codebook_distance =
341 self.multi_codebook_pq[i].compute_distance(query_residual, codes)?;
342 weighted_distance += weights[i] * codebook_distance;
343 total_weight += weights[i];
344 }
345 }
346
347 if total_weight > 0.0 {
349 Ok(weighted_distance / total_weight)
350 } else {
351 Ok(f32::INFINITY)
352 }
353 }
354
355 fn train_pq(&mut self, residuals: &[Vector]) -> Result<()> {
357 match &self.quantization {
358 QuantizationStrategy::ProductQuantization(_) => {
359 if let Some(ref mut pq_index) = self.pq_index {
360 pq_index.train(residuals)?;
361 }
362 }
363 QuantizationStrategy::ResidualQuantization { levels, .. } => {
364 self.train_multi_level_pq(residuals, *levels)?;
365 }
366 QuantizationStrategy::MultiCodebook { .. } => {
367 self.train_multi_codebook_pq(residuals)?;
368 }
369 QuantizationStrategy::None => {}
370 }
371 Ok(())
372 }
373
374 fn train_multi_level_pq(&mut self, residuals: &[Vector], levels: usize) -> Result<()> {
376 let mut current_residuals = residuals.to_vec();
377
378 for level in 0..levels.min(self.multi_level_pq.len()) {
379 self.multi_level_pq[level].train(¤t_residuals)?;
381
382 let mut next_residuals = Vec::new();
384 for residual in ¤t_residuals {
385 let codes = self.multi_level_pq[level].encode(residual)?;
386 let approximation = self.multi_level_pq[level].decode_vector(&codes)?;
387 let next_residual = residual.subtract(&approximation)?;
388 next_residuals.push(next_residual);
389 }
390 current_residuals = next_residuals;
391 }
392
393 Ok(())
394 }
395
396 fn train_multi_codebook_pq(&mut self, residuals: &[Vector]) -> Result<()> {
398 for pq_index in &mut self.multi_codebook_pq {
400 pq_index.train(residuals)?;
401 }
402
403 self.optimize_codebook_weights(residuals)?;
405
406 Ok(())
407 }
408
409 fn optimize_codebook_weights(&mut self, residuals: &[Vector]) -> Result<()> {
411 if self.multi_codebook_pq.is_empty() || residuals.is_empty() {
412 return Ok(());
413 }
414
415 let num_codebooks = self.multi_codebook_pq.len();
416 let mut reconstruction_errors = vec![0.0; num_codebooks];
417
418 for (i, pq_index) in self.multi_codebook_pq.iter().enumerate() {
420 let mut total_error = 0.0;
421 for residual in residuals {
422 let codes = pq_index.encode(residual)?;
423 let reconstruction = pq_index.decode_vector(&codes)?;
424 let error = residual
425 .euclidean_distance(&reconstruction)
426 .unwrap_or(f32::INFINITY);
427 total_error += error;
428 }
429 reconstruction_errors[i] = total_error / residuals.len() as f32;
430 }
431
432 let max_error = reconstruction_errors.iter().fold(0.0f32, |a, &b| a.max(b));
434 if max_error > 0.0 {
435 let mut total_weight = 0.0;
436 for (i, &error) in reconstruction_errors.iter().enumerate().take(num_codebooks) {
437 self.codebook_weights[i] = (max_error - error + 1e-6) / max_error;
439 total_weight += self.codebook_weights[i];
440 }
441
442 if total_weight > 0.0 {
444 for weight in &mut self.codebook_weights {
445 *weight /= total_weight;
446 }
447 }
448 }
449
450 Ok(())
451 }
452
453 fn stats(&self) -> InvertedListStats {
455 let mut full_vectors = 0;
456 let mut quantized_vectors = 0;
457 let mut multi_level_vectors = 0;
458 let mut multi_codebook_vectors = 0;
459
460 for (_, storage) in &self.vectors {
461 match storage {
462 VectorStorage::Full(_) => full_vectors += 1,
463 VectorStorage::Quantized(_) => quantized_vectors += 1,
464 VectorStorage::MultiLevelQuantized { .. } => {
465 quantized_vectors += 1;
466 multi_level_vectors += 1;
467 }
468 VectorStorage::MultiCodebook { .. } => {
469 quantized_vectors += 1;
470 multi_codebook_vectors += 1;
471 }
472 }
473 }
474
475 let total_vectors = self.vectors.len();
476 let compression_ratio = if total_vectors > 0 {
477 quantized_vectors as f32 / total_vectors as f32
478 } else {
479 0.0
480 };
481
482 InvertedListStats {
483 total_vectors,
484 full_vectors,
485 quantized_vectors,
486 compression_ratio,
487 multi_level_vectors,
488 multi_codebook_vectors,
489 quantization_strategy: self.quantization.clone(),
490 }
491 }
492}
493
494#[derive(Debug, Clone)]
496pub struct InvertedListStats {
497 pub total_vectors: usize,
498 pub full_vectors: usize,
499 pub quantized_vectors: usize,
500 pub compression_ratio: f32,
501 pub multi_level_vectors: usize,
502 pub multi_codebook_vectors: usize,
503 pub quantization_strategy: QuantizationStrategy,
504}
505
506pub struct IvfIndex {
508 config: IvfConfig,
509 centroids: Vec<Vector>,
511 inverted_lists: Vec<Arc<RwLock<InvertedList>>>,
513 dimensions: Option<usize>,
515 n_vectors: usize,
517 is_trained: bool,
519}
520
521impl IvfIndex {
522 pub fn new(config: IvfConfig) -> Result<Self> {
524 let mut inverted_lists = Vec::with_capacity(config.n_clusters);
525
526 let quantization = if config.enable_residual_quantization {
528 if let Some(ref pq_config) = config.pq_config {
529 QuantizationStrategy::ProductQuantization(pq_config.clone())
530 } else {
531 return Err(anyhow!(
532 "PQ config required when residual quantization is enabled"
533 ));
534 }
535 } else {
536 config.quantization.clone()
537 };
538
539 for _ in 0..config.n_clusters {
540 let inverted_list = Arc::new(RwLock::new(InvertedList::new_with_quantization(
541 quantization.clone(),
542 )?));
543 inverted_lists.push(inverted_list);
544 }
545
546 Ok(Self {
547 config,
548 centroids: Vec::new(),
549 inverted_lists,
550 dimensions: None,
551 n_vectors: 0,
552 is_trained: false,
553 })
554 }
555
556 pub fn new_with_product_quantization(
558 n_clusters: usize,
559 n_probes: usize,
560 pq_config: PQConfig,
561 ) -> Result<Self> {
562 let config = IvfConfig {
563 n_clusters,
564 n_probes,
565 quantization: QuantizationStrategy::ProductQuantization(pq_config),
566 ..Default::default()
567 };
568 Self::new(config)
569 }
570
571 pub fn new_with_multi_level_quantization(
573 n_clusters: usize,
574 n_probes: usize,
575 levels: usize,
576 pq_configs: Vec<PQConfig>,
577 ) -> Result<Self> {
578 if pq_configs.len() < levels {
579 return Err(anyhow!(
580 "Number of PQ configs must be at least equal to levels"
581 ));
582 }
583
584 let config = IvfConfig {
585 n_clusters,
586 n_probes,
587 quantization: QuantizationStrategy::ResidualQuantization { levels, pq_configs },
588 ..Default::default()
589 };
590 Self::new(config)
591 }
592
593 pub fn new_with_multi_codebook_quantization(
595 n_clusters: usize,
596 n_probes: usize,
597 num_codebooks: usize,
598 pq_configs: Vec<PQConfig>,
599 ) -> Result<Self> {
600 if pq_configs.len() != num_codebooks {
601 return Err(anyhow!(
602 "Number of PQ configs must equal number of codebooks"
603 ));
604 }
605
606 let config = IvfConfig {
607 n_clusters,
608 n_probes,
609 quantization: QuantizationStrategy::MultiCodebook {
610 num_codebooks,
611 pq_configs,
612 },
613 ..Default::default()
614 };
615 Self::new(config)
616 }
617
618 pub fn new_with_residual_quantization(
620 n_clusters: usize,
621 n_probes: usize,
622 pq_config: PQConfig,
623 ) -> Result<Self> {
624 Self::new_with_product_quantization(n_clusters, n_probes, pq_config)
625 }
626
627 pub fn config(&self) -> &IvfConfig {
629 &self.config
630 }
631
632 pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
634 if training_vectors.is_empty() {
635 return Err(anyhow!("Cannot train IVF index with empty training set"));
636 }
637
638 let dims = training_vectors[0].dimensions;
640 if !training_vectors.iter().all(|v| v.dimensions == dims) {
641 return Err(anyhow!(
642 "All training vectors must have the same dimensions"
643 ));
644 }
645
646 self.dimensions = Some(dims);
647
648 self.centroids = self.initialize_centroids_kmeans_plus_plus(training_vectors)?;
650
651 let mut iteration = 0;
653 let mut prev_error = f32::INFINITY;
654
655 while iteration < self.config.max_iterations {
656 let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); self.config.n_clusters];
658
659 for vector in training_vectors {
660 let nearest_idx = self.find_nearest_centroid(vector)?;
661 clusters[nearest_idx].push(vector);
662 }
663
664 let mut total_error = 0.0;
666 for (i, cluster) in clusters.iter().enumerate() {
667 if !cluster.is_empty() {
668 let new_centroid = self.compute_centroid(cluster);
669 total_error += self.centroids[i]
670 .euclidean_distance(&new_centroid)
671 .unwrap_or(0.0);
672 self.centroids[i] = new_centroid;
673 }
674 }
675
676 if (prev_error - total_error).abs() < self.config.convergence_threshold {
678 break;
679 }
680
681 prev_error = total_error;
682 iteration += 1;
683 }
684
685 self.is_trained = true;
686
687 if !matches!(self.config.quantization, QuantizationStrategy::None)
689 || self.config.enable_residual_quantization
690 {
691 self.train_residual_quantization(training_vectors)?;
692 }
693
694 Ok(())
695 }
696
697 fn train_residual_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
699 let mut cluster_residuals: Vec<Vec<Vector>> = vec![Vec::new(); self.config.n_clusters];
701
702 for vector in training_vectors {
703 let cluster_idx = self.find_nearest_centroid(vector)?;
704 let centroid = &self.centroids[cluster_idx];
705 let residual = vector.subtract(centroid)?;
706 cluster_residuals[cluster_idx].push(residual);
707 }
708
709 for (cluster_idx, residuals) in cluster_residuals.iter().enumerate() {
711 if residuals.len() > 10 {
712 let mut list = self.inverted_lists[cluster_idx]
714 .write()
715 .expect("inverted_lists lock should not be poisoned");
716 list.train_pq(residuals)?;
717 }
718 }
719
720 Ok(())
721 }
722
723 fn initialize_centroids_kmeans_plus_plus(&self, vectors: &[Vector]) -> Result<Vec<Vector>> {
725 use std::collections::hash_map::DefaultHasher;
726 use std::hash::{Hash, Hasher};
727
728 let mut hasher = DefaultHasher::new();
729 self.config.seed.unwrap_or(42).hash(&mut hasher);
730 let mut rng_state = hasher.finish();
731
732 let mut centroids = Vec::with_capacity(self.config.n_clusters);
733
734 let first_idx = (rng_state as usize) % vectors.len();
736 centroids.push(vectors[first_idx].clone());
737
738 while centroids.len() < self.config.n_clusters {
740 let mut distances = Vec::with_capacity(vectors.len());
741 let mut sum_distances = 0.0;
742
743 for vector in vectors {
745 let min_dist = centroids
746 .iter()
747 .map(|c| vector.euclidean_distance(c).unwrap_or(f32::INFINITY))
748 .fold(f32::INFINITY, |a, b| a.min(b));
749
750 distances.push(min_dist * min_dist); sum_distances += min_dist * min_dist;
752 }
753
754 rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
756 let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
757
758 let mut cumulative = 0.0;
759 for (i, &dist) in distances.iter().enumerate() {
760 cumulative += dist;
761 if cumulative >= threshold {
762 centroids.push(vectors[i].clone());
763 break;
764 }
765 }
766 }
767
768 Ok(centroids)
769 }
770
771 fn compute_centroid(&self, cluster: &[&Vector]) -> Vector {
773 if cluster.is_empty() {
774 return Vector::new(vec![0.0; self.dimensions.unwrap_or(0)]);
775 }
776
777 let dims = cluster[0].dimensions;
778 let mut sum = vec![0.0; dims];
779
780 for vector in cluster {
781 let values = vector.as_f32();
782 for (i, &val) in values.iter().enumerate() {
783 sum[i] += val;
784 }
785 }
786
787 let count = cluster.len() as f32;
788 for val in &mut sum {
789 *val /= count;
790 }
791
792 Vector::new(sum)
793 }
794
795 fn find_nearest_centroid(&self, vector: &Vector) -> Result<usize> {
797 if self.centroids.is_empty() {
798 return Err(anyhow!("No centroids available"));
799 }
800
801 let mut min_distance = f32::INFINITY;
802 let mut nearest_idx = 0;
803
804 for (i, centroid) in self.centroids.iter().enumerate() {
805 let distance = vector.euclidean_distance(centroid)?;
806 if distance < min_distance {
807 min_distance = distance;
808 nearest_idx = i;
809 }
810 }
811
812 Ok(nearest_idx)
813 }
814
815 fn find_nearest_centroids(&self, query: &Vector, n_probes: usize) -> Result<Vec<usize>> {
817 let mut distances: Vec<(usize, f32)> = self
818 .centroids
819 .iter()
820 .enumerate()
821 .map(|(i, centroid)| {
822 let dist = query.euclidean_distance(centroid).unwrap_or(f32::INFINITY);
823 (i, dist)
824 })
825 .collect();
826
827 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
828
829 Ok(distances
830 .into_iter()
831 .take(n_probes.min(self.centroids.len()))
832 .map(|(i, _)| i)
833 .collect())
834 }
835
836 pub fn stats(&self) -> IvfStats {
838 let mut total_list_stats = InvertedListStats {
839 total_vectors: 0,
840 full_vectors: 0,
841 quantized_vectors: 0,
842 compression_ratio: 0.0,
843 multi_level_vectors: 0,
844 multi_codebook_vectors: 0,
845 quantization_strategy: QuantizationStrategy::None,
846 };
847
848 let mut cluster_stats = Vec::new();
849 let mut vectors_per_cluster = Vec::new();
850 let mut non_empty_clusters = 0;
851
852 for list in &self.inverted_lists {
853 let list_guard = list
854 .read()
855 .expect("inverted list lock should not be poisoned");
856 let stats = list_guard.stats();
857
858 total_list_stats.total_vectors += stats.total_vectors;
859 total_list_stats.full_vectors += stats.full_vectors;
860 total_list_stats.quantized_vectors += stats.quantized_vectors;
861 total_list_stats.multi_level_vectors += stats.multi_level_vectors;
862 total_list_stats.multi_codebook_vectors += stats.multi_codebook_vectors;
863
864 vectors_per_cluster.push(stats.total_vectors);
865 if stats.total_vectors > 0 {
866 non_empty_clusters += 1;
867 }
868
869 cluster_stats.push(stats);
870 }
871
872 if total_list_stats.total_vectors > 0 {
874 total_list_stats.compression_ratio =
875 total_list_stats.quantized_vectors as f32 / total_list_stats.total_vectors as f32;
876 }
877
878 let avg_vectors_per_cluster = if self.config.n_clusters > 0 {
879 self.n_vectors as f32 / self.config.n_clusters as f32
880 } else {
881 0.0
882 };
883
884 IvfStats {
885 n_clusters: self.config.n_clusters,
886 n_probes: self.config.n_probes,
887 n_vectors: self.n_vectors,
888 is_trained: self.is_trained,
889 dimensions: self.dimensions,
890 vectors_per_cluster,
891 avg_vectors_per_cluster,
892 non_empty_clusters,
893 enable_residual_quantization: self.config.enable_residual_quantization,
894 quantization_strategy: self.config.quantization.clone(),
895 compression_stats: Some(total_list_stats),
896 cluster_stats,
897 }
898 }
899}
900
901impl VectorIndex for IvfIndex {
902 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
903 if !self.is_trained {
904 return Err(anyhow!(
905 "IVF index must be trained before inserting vectors"
906 ));
907 }
908
909 if let Some(dims) = self.dimensions {
911 if vector.dimensions != dims {
912 return Err(anyhow!(
913 "Vector dimensions {} don't match index dimensions {}",
914 vector.dimensions,
915 dims
916 ));
917 }
918 }
919
920 let cluster_idx = self.find_nearest_centroid(&vector)?;
922 let centroid = &self.centroids[cluster_idx];
923
924 let mut list = self.inverted_lists[cluster_idx]
925 .write()
926 .expect("inverted_lists lock should not be poisoned");
927
928 match &self.config.quantization {
930 QuantizationStrategy::None => {
931 if self.config.enable_residual_quantization {
932 let residual = vector.subtract(centroid)?;
934 list.add_residual(uri, residual, centroid)?;
935 } else {
936 list.add_full(uri, vector);
937 }
938 }
939 _ => {
940 let residual = vector.subtract(centroid)?;
942 list.add_residual(uri, residual, centroid)?;
943 }
944 }
945
946 self.n_vectors += 1;
947 Ok(())
948 }
949
950 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
951 if !self.is_trained {
952 return Err(anyhow!("IVF index must be trained before searching"));
953 }
954
955 let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
957
958 let mut all_results = Vec::new();
960 for idx in probe_indices {
961 let list = self.inverted_lists[idx]
962 .read()
963 .expect("inverted_lists lock should not be poisoned");
964 let centroid = &self.centroids[idx];
965 let mut results = list.search(query, centroid, k)?;
966 all_results.append(&mut results);
967 }
968
969 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
971 all_results.truncate(k);
972
973 Ok(all_results)
974 }
975
976 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
977 if !self.is_trained {
978 return Err(anyhow!("IVF index must be trained before searching"));
979 }
980
981 let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
983
984 let mut all_results = Vec::new();
986 for idx in probe_indices {
987 let list = self.inverted_lists[idx]
988 .read()
989 .expect("inverted_lists lock should not be poisoned");
990 let centroid = &self.centroids[idx];
991 let results = list.search(query, centroid, self.n_vectors)?;
992
993 for (uri, similarity) in results {
995 if similarity >= threshold {
996 all_results.push((uri, similarity));
997 }
998 }
999 }
1000
1001 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1003
1004 Ok(all_results)
1005 }
1006
1007 fn get_vector(&self, _uri: &str) -> Option<&Vector> {
1008 None
1012 }
1013}
1014
1015#[derive(Debug, Clone)]
1017pub struct IvfStats {
1018 pub n_vectors: usize,
1019 pub n_clusters: usize,
1020 pub n_probes: usize,
1021 pub is_trained: bool,
1022 pub dimensions: Option<usize>,
1023 pub vectors_per_cluster: Vec<usize>,
1024 pub avg_vectors_per_cluster: f32,
1025 pub non_empty_clusters: usize,
1026 pub enable_residual_quantization: bool,
1027 pub quantization_strategy: QuantizationStrategy,
1028 pub compression_stats: Option<InvertedListStats>,
1029 pub cluster_stats: Vec<InvertedListStats>,
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `Vector` per fixed-size row of components.
    fn vectors<const N: usize>(rows: &[[f32; N]]) -> Vec<Vector> {
        rows.iter().map(|row| Vector::new(row.to_vec())).collect()
    }

    /// Trains on eight 2-D points, inserts them and checks that the nearest
    /// neighbour of a point close to `vec0` is `vec0`.
    #[test]
    fn test_ivf_basic() {
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 4,
            n_probes: 2,
            ..Default::default()
        })
        .unwrap();

        let training_vectors = vectors(&[
            [1.0, 0.0],
            [0.0, 1.0],
            [-1.0, 0.0],
            [0.0, -1.0],
            [0.5, 0.5],
            [-0.5, 0.5],
            [-0.5, -0.5],
            [0.5, -0.5],
        ]);

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone()).unwrap();
        }

        let results = index
            .search_knn(&Vector::new(vec![0.9, 0.1]), 3)
            .unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        assert_eq!(results[0].0, "vec0");
    }

    /// Every hit of a threshold search must reach the requested similarity.
    #[test]
    fn test_ivf_threshold_search() {
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 2,
            n_probes: 2,
            ..Default::default()
        })
        .unwrap();

        let training_vectors = vectors(&[
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.5, 0.5, 0.0],
        ]);

        index.train(&training_vectors).unwrap();

        // Inserted as "v1" .. "v4".
        for (position, vec) in training_vectors.iter().enumerate() {
            index
                .insert(format!("v{}", position + 1), vec.clone())
                .unwrap();
        }

        let query = Vector::new(vec![0.9, 0.1, 0.0]);
        let results = index.search_threshold(&query, 0.5).unwrap();

        assert!(!results.is_empty());
        for (_, similarity) in &results {
            assert!(*similarity >= 0.5);
        }
    }

    /// The statistics reflect the configuration, training state and inserts.
    #[test]
    fn test_ivf_stats() {
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 3,
            n_probes: 1,
            ..Default::default()
        })
        .unwrap();

        let training_vectors = vectors(&[[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]);
        index.train(&training_vectors).unwrap();

        index
            .insert("a".to_string(), Vector::new(vec![1.1, 0.1]))
            .unwrap();
        index
            .insert("b".to_string(), Vector::new(vec![0.1, 1.1]))
            .unwrap();

        let stats = index.stats();
        assert_eq!(stats.n_vectors, 2);
        assert_eq!(stats.n_clusters, 3);
        assert!(stats.is_trained);
        assert_eq!(stats.dimensions, Some(2));
    }

    /// End-to-end run with a two-level residual quantization cascade.
    #[test]
    fn test_ivf_multi_level_quantization() {
        use crate::pq::PQConfig;

        let first_level = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let second_level = PQConfig {
            n_subquantizers: 2,
            n_bits: 4,
            ..Default::default()
        };

        let mut index =
            IvfIndex::new_with_multi_level_quantization(4, 2, 2, vec![first_level, second_level])
                .unwrap();

        let training_vectors = vectors(&[
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.5, 0.5, 0.0, 0.0],
            [0.0, 0.0, 0.5, 0.5],
        ]);

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone()).unwrap();
        }

        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::ResidualQuantization { .. }
        ));
        if let Some(compression_stats) = &stats.compression_stats {
            assert!(compression_stats.multi_level_vectors > 0);
        }
    }

    /// End-to-end run with two weighted codebooks.
    #[test]
    fn test_ivf_multi_codebook_quantization() {
        use crate::pq::PQConfig;

        let first_codebook = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let second_codebook = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };

        let mut index = IvfIndex::new_with_multi_codebook_quantization(
            4,
            2,
            2,
            vec![first_codebook, second_codebook],
        )
        .unwrap();

        let training_vectors = vectors(&[
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.5, 0.5, 0.5, 0.5],
        ]);

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone()).unwrap();
        }

        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);

        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::MultiCodebook { .. }
        ));
        if let Some(compression_stats) = &stats.compression_stats {
            assert!(compression_stats.multi_codebook_vectors > 0);
        }
    }

    /// An index can be constructed with every quantization strategy.
    #[test]
    fn test_quantization_strategies() {
        use crate::pq::PQConfig;

        let pq_config = PQConfig::default();

        let strategies = vec![
            QuantizationStrategy::None,
            QuantizationStrategy::ProductQuantization(pq_config.clone()),
            QuantizationStrategy::ResidualQuantization {
                levels: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
            QuantizationStrategy::MultiCodebook {
                num_codebooks: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
        ];

        for strategy in strategies {
            let index = IvfIndex::new(IvfConfig {
                n_clusters: 2,
                n_probes: 1,
                quantization: strategy.clone(),
                ..Default::default()
            });
            assert!(
                index.is_ok(),
                "Failed to create index with strategy: {strategy:?}"
            );
        }
    }
}