1use crate::{
9 pq::{PQConfig, PQIndex},
10 Vector, VectorIndex,
11};
12use anyhow::{anyhow, Result};
13use std::sync::{Arc, RwLock};
14
/// Strategy used to compress the vectors stored in each inverted list.
///
/// Every non-`None` strategy encodes the *residual* of a vector, i.e. the vector minus the
/// centroid of the cluster it was assigned to.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizationStrategy {
    /// No quantization: vectors are stored uncompressed.
    None,
    /// A single product quantizer applied to the residual.
    ProductQuantization(PQConfig),
    /// Multi-stage residual quantization: each stage encodes what the previous stages
    /// failed to reconstruct.
    ResidualQuantization {
        /// Number of stages requested. Only `min(levels, pq_configs.len())` stages are
        /// encoded; when fewer quantizers than levels exist, the leftover residual is kept
        /// uncompressed alongside the codes.
        levels: usize,
        /// One PQ configuration per stage, in stage order.
        pq_configs: Vec<PQConfig>,
    },
    /// Several independent codebooks, each encoding the same residual. Query distances are
    /// combined as a weighted average across the codebooks.
    MultiCodebook {
        /// Number of codebooks; also used to size the initial uniform weight vector.
        num_codebooks: usize,
        /// One PQ configuration per codebook.
        pq_configs: Vec<PQConfig>,
    },
}
33
/// Configuration for an [`IvfIndex`].
#[derive(Debug, Clone)]
pub struct IvfConfig {
    /// Number of k-means centroids, which is also the number of inverted lists.
    pub n_clusters: usize,
    /// Number of closest inverted lists scanned per query.
    pub n_probes: usize,
    /// Upper bound on k-means refinement iterations during training.
    pub max_iterations: usize,
    /// Tolerance used to decide when k-means iterations stop early.
    pub convergence_threshold: f32,
    /// Seed for centroid initialization; `42` is used when `None`.
    pub seed: Option<u64>,
    /// Quantization applied to the stored residuals.
    pub quantization: QuantizationStrategy,
    /// Legacy switch: when `true`, the lists use plain product quantization built from
    /// `pq_config`, overriding `quantization` for list construction.
    pub enable_residual_quantization: bool,
    /// PQ configuration used only when `enable_residual_quantization` is set (required then).
    pub pq_config: Option<PQConfig>,
}
54
55impl Default for IvfConfig {
56 fn default() -> Self {
57 Self {
58 n_clusters: 256,
59 n_probes: 8,
60 max_iterations: 100,
61 convergence_threshold: 1e-4,
62 seed: None,
63 quantization: QuantizationStrategy::None,
64 enable_residual_quantization: false,
65 pq_config: None,
66 }
67 }
68}
69
/// Representation of a single entry inside an inverted list.
#[derive(Debug, Clone)]
enum VectorStorage {
    /// Uncompressed vector; `search` compares it directly against the raw query.
    Full(Vector),
    /// Product-quantization codes of the residual.
    Quantized(Vec<u8>),
    /// Multi-stage residual codes.
    MultiLevelQuantized {
        levels: Vec<Vec<u8>>, // one code vector per encoded stage
        final_residual: Option<Vector>, // leftover residual when fewer stages than requested were encoded
    },
    /// Codes from several independent codebooks.
    MultiCodebook {
        codebooks: Vec<Vec<u8>>, // one code vector per codebook
        weights: Vec<f32>, // snapshot of the list's codebook weights at insertion time
    },
}
88
/// The bucket of vectors assigned to one IVF cluster, together with the quantizers used to
/// encode them.
#[derive(Debug, Clone)]
struct InvertedList {
    /// `(uri, storage)` pairs in insertion order.
    vectors: Vec<(String, VectorStorage)>,
    /// Strategy this list was built with; selects which quantizer fields are used.
    quantization: QuantizationStrategy,
    /// Quantizer for `ProductQuantization`.
    pq_index: Option<PQIndex>,
    /// One quantizer per stage for `ResidualQuantization`.
    multi_level_pq: Vec<PQIndex>,
    /// One quantizer per codebook for `MultiCodebook`.
    multi_codebook_pq: Vec<PQIndex>,
    /// Per-codebook weights for `MultiCodebook` distance averaging.
    codebook_weights: Vec<f32>,
}
105
106impl InvertedList {
107 fn new() -> Self {
108 Self {
109 vectors: Vec::new(),
110 quantization: QuantizationStrategy::None,
111 pq_index: None,
112 multi_level_pq: Vec::new(),
113 multi_codebook_pq: Vec::new(),
114 codebook_weights: Vec::new(),
115 }
116 }
117
118 fn new_with_quantization(quantization: QuantizationStrategy) -> Result<Self> {
119 let mut list = Self {
120 vectors: Vec::new(),
121 quantization: quantization.clone(),
122 pq_index: None,
123 multi_level_pq: Vec::new(),
124 multi_codebook_pq: Vec::new(),
125 codebook_weights: Vec::new(),
126 };
127
128 match quantization {
129 QuantizationStrategy::None => {}
130 QuantizationStrategy::ProductQuantization(pq_config) => {
131 list.pq_index = Some(PQIndex::new(pq_config));
132 }
133 QuantizationStrategy::ResidualQuantization {
134 levels: _,
135 ref pq_configs,
136 } => {
137 for pq_config in pq_configs {
138 list.multi_level_pq.push(PQIndex::new(pq_config.clone()));
139 }
140 }
141 QuantizationStrategy::MultiCodebook {
142 num_codebooks,
143 ref pq_configs,
144 } => {
145 for pq_config in pq_configs {
146 list.multi_codebook_pq.push(PQIndex::new(pq_config.clone()));
147 }
148 list.codebook_weights = vec![1.0 / num_codebooks as f32; num_codebooks];
150 }
151 }
152
153 Ok(list)
154 }
155
156 fn new_with_pq(pq_config: PQConfig) -> Result<Self> {
158 Self::new_with_quantization(QuantizationStrategy::ProductQuantization(pq_config))
159 }
160
161 fn add_full(&mut self, uri: String, vector: Vector) {
162 self.vectors.push((uri, VectorStorage::Full(vector)));
163 }
164
165 fn add_residual(&mut self, uri: String, residual: Vector, _centroid: &Vector) -> Result<()> {
166 match &self.quantization {
167 QuantizationStrategy::ProductQuantization(_) => {
168 if let Some(ref mut pq_index) = self.pq_index {
169 if !pq_index.is_trained() {
171 let training_residuals = vec![residual.clone()];
172 pq_index.train(&training_residuals)?;
173 }
174
175 let codes = pq_index.encode(&residual)?;
176 self.vectors.push((uri, VectorStorage::Quantized(codes)));
177 } else {
178 return Err(anyhow!(
179 "PQ index not initialized for residual quantization"
180 ));
181 }
182 }
183 QuantizationStrategy::ResidualQuantization { levels, .. } => {
184 self.add_multi_level_residual(uri, residual, *levels)?;
185 }
186 QuantizationStrategy::MultiCodebook { .. } => {
187 self.add_multi_codebook(uri, residual)?;
188 }
189 QuantizationStrategy::None => {
190 self.add_full(uri, residual);
191 }
192 }
193 Ok(())
194 }
195
196 fn add_multi_level_residual(
198 &mut self,
199 uri: String,
200 mut residual: Vector,
201 levels: usize,
202 ) -> Result<()> {
203 let mut level_codes = Vec::new();
204
205 for level in 0..levels.min(self.multi_level_pq.len()) {
206 if !self.multi_level_pq[level].is_trained() {
208 let training_residuals = vec![residual.clone()];
209 self.multi_level_pq[level].train(&training_residuals)?;
210 }
211
212 let codes = self.multi_level_pq[level].encode(&residual)?;
214 level_codes.push(codes);
215
216 let approximation = self.multi_level_pq[level].decode_vector(&level_codes[level])?;
218 residual = residual.subtract(&approximation)?;
219 }
220
221 let final_residual = if level_codes.len() < levels {
223 Some(residual)
224 } else {
225 None
226 };
227
228 self.vectors.push((
229 uri,
230 VectorStorage::MultiLevelQuantized {
231 levels: level_codes,
232 final_residual,
233 },
234 ));
235
236 Ok(())
237 }
238
239 fn add_multi_codebook(&mut self, uri: String, residual: Vector) -> Result<()> {
241 let mut codebook_codes = Vec::new();
242
243 for pq_index in self.multi_codebook_pq.iter_mut() {
244 if !pq_index.is_trained() {
246 let training_residuals = vec![residual.clone()];
247 pq_index.train(&training_residuals)?;
248 }
249
250 let codes = pq_index.encode(&residual)?;
252 codebook_codes.push(codes);
253 }
254
255 self.vectors.push((
256 uri,
257 VectorStorage::MultiCodebook {
258 codebooks: codebook_codes,
259 weights: self.codebook_weights.clone(),
260 },
261 ));
262
263 Ok(())
264 }
265
266 fn search(&self, query: &Vector, centroid: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
267 let mut distances: Vec<(String, f32)> = Vec::new();
268 let query_residual = query.subtract(centroid)?;
269
270 for (uri, storage) in &self.vectors {
271 let distance = match storage {
272 VectorStorage::Full(vec) => query.euclidean_distance(vec).unwrap_or(f32::INFINITY),
273 VectorStorage::Quantized(codes) => {
274 if let Some(ref pq_index) = self.pq_index {
275 pq_index.compute_distance(&query_residual, codes)?
276 } else {
277 f32::INFINITY
278 }
279 }
280 VectorStorage::MultiLevelQuantized {
281 levels,
282 final_residual,
283 } => self.compute_multi_level_distance(&query_residual, levels, final_residual)?,
284 VectorStorage::MultiCodebook { codebooks, weights } => {
285 self.compute_multi_codebook_distance(&query_residual, codebooks, weights)?
286 }
287 };
288 distances.push((uri.clone(), distance));
289 }
290
291 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
292 distances.truncate(k);
293
294 Ok(distances
296 .into_iter()
297 .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
298 .collect())
299 }
300
301 fn compute_multi_level_distance(
303 &self,
304 query_residual: &Vector,
305 level_codes: &[Vec<u8>],
306 final_residual: &Option<Vector>,
307 ) -> Result<f32> {
308 let mut reconstructed_residual = Vector::new(vec![0.0; query_residual.dimensions]);
309
310 for (level, codes) in level_codes.iter().enumerate() {
312 if level < self.multi_level_pq.len() {
313 let level_reconstruction = self.multi_level_pq[level].decode_vector(codes)?;
314 reconstructed_residual = reconstructed_residual.add(&level_reconstruction)?;
315 }
316 }
317
318 if let Some(final_res) = final_residual {
320 reconstructed_residual = reconstructed_residual.add(final_res)?;
321 }
322
323 query_residual.euclidean_distance(&reconstructed_residual)
325 }
326
327 fn compute_multi_codebook_distance(
329 &self,
330 query_residual: &Vector,
331 codebook_codes: &[Vec<u8>],
332 weights: &[f32],
333 ) -> Result<f32> {
334 let mut weighted_distance = 0.0;
335 let mut total_weight = 0.0;
336
337 for (i, codes) in codebook_codes.iter().enumerate() {
339 if i < self.multi_codebook_pq.len() && i < weights.len() {
340 let codebook_distance =
341 self.multi_codebook_pq[i].compute_distance(query_residual, codes)?;
342 weighted_distance += weights[i] * codebook_distance;
343 total_weight += weights[i];
344 }
345 }
346
347 if total_weight > 0.0 {
349 Ok(weighted_distance / total_weight)
350 } else {
351 Ok(f32::INFINITY)
352 }
353 }
354
355 fn train_pq(&mut self, residuals: &[Vector]) -> Result<()> {
357 match &self.quantization {
358 QuantizationStrategy::ProductQuantization(_) => {
359 if let Some(ref mut pq_index) = self.pq_index {
360 pq_index.train(residuals)?;
361 }
362 }
363 QuantizationStrategy::ResidualQuantization { levels, .. } => {
364 self.train_multi_level_pq(residuals, *levels)?;
365 }
366 QuantizationStrategy::MultiCodebook { .. } => {
367 self.train_multi_codebook_pq(residuals)?;
368 }
369 QuantizationStrategy::None => {}
370 }
371 Ok(())
372 }
373
374 fn train_multi_level_pq(&mut self, residuals: &[Vector], levels: usize) -> Result<()> {
376 let mut current_residuals = residuals.to_vec();
377
378 for level in 0..levels.min(self.multi_level_pq.len()) {
379 self.multi_level_pq[level].train(¤t_residuals)?;
381
382 let mut next_residuals = Vec::new();
384 for residual in ¤t_residuals {
385 let codes = self.multi_level_pq[level].encode(residual)?;
386 let approximation = self.multi_level_pq[level].decode_vector(&codes)?;
387 let next_residual = residual.subtract(&approximation)?;
388 next_residuals.push(next_residual);
389 }
390 current_residuals = next_residuals;
391 }
392
393 Ok(())
394 }
395
396 fn train_multi_codebook_pq(&mut self, residuals: &[Vector]) -> Result<()> {
398 for pq_index in &mut self.multi_codebook_pq {
400 pq_index.train(residuals)?;
401 }
402
403 self.optimize_codebook_weights(residuals)?;
405
406 Ok(())
407 }
408
409 fn optimize_codebook_weights(&mut self, residuals: &[Vector]) -> Result<()> {
411 if self.multi_codebook_pq.is_empty() || residuals.is_empty() {
412 return Ok(());
413 }
414
415 let num_codebooks = self.multi_codebook_pq.len();
416 let mut reconstruction_errors = vec![0.0; num_codebooks];
417
418 for (i, pq_index) in self.multi_codebook_pq.iter().enumerate() {
420 let mut total_error = 0.0;
421 for residual in residuals {
422 let codes = pq_index.encode(residual)?;
423 let reconstruction = pq_index.decode_vector(&codes)?;
424 let error = residual
425 .euclidean_distance(&reconstruction)
426 .unwrap_or(f32::INFINITY);
427 total_error += error;
428 }
429 reconstruction_errors[i] = total_error / residuals.len() as f32;
430 }
431
432 let max_error = reconstruction_errors.iter().fold(0.0f32, |a, &b| a.max(b));
434 if max_error > 0.0 {
435 let mut total_weight = 0.0;
436 for (i, &error) in reconstruction_errors.iter().enumerate().take(num_codebooks) {
437 self.codebook_weights[i] = (max_error - error + 1e-6) / max_error;
439 total_weight += self.codebook_weights[i];
440 }
441
442 if total_weight > 0.0 {
444 for weight in &mut self.codebook_weights {
445 *weight /= total_weight;
446 }
447 }
448 }
449
450 Ok(())
451 }
452
453 fn stats(&self) -> InvertedListStats {
455 let mut full_vectors = 0;
456 let mut quantized_vectors = 0;
457 let mut multi_level_vectors = 0;
458 let mut multi_codebook_vectors = 0;
459
460 for (_, storage) in &self.vectors {
461 match storage {
462 VectorStorage::Full(_) => full_vectors += 1,
463 VectorStorage::Quantized(_) => quantized_vectors += 1,
464 VectorStorage::MultiLevelQuantized { .. } => {
465 quantized_vectors += 1;
466 multi_level_vectors += 1;
467 }
468 VectorStorage::MultiCodebook { .. } => {
469 quantized_vectors += 1;
470 multi_codebook_vectors += 1;
471 }
472 }
473 }
474
475 let total_vectors = self.vectors.len();
476 let compression_ratio = if total_vectors > 0 {
477 quantized_vectors as f32 / total_vectors as f32
478 } else {
479 0.0
480 };
481
482 InvertedListStats {
483 total_vectors,
484 full_vectors,
485 quantized_vectors,
486 compression_ratio,
487 multi_level_vectors,
488 multi_codebook_vectors,
489 quantization_strategy: self.quantization.clone(),
490 }
491 }
492}
493
/// Storage statistics for one inverted list, or an aggregate over all lists.
#[derive(Debug, Clone)]
pub struct InvertedListStats {
    /// Number of stored entries.
    pub total_vectors: usize,
    /// Entries stored uncompressed.
    pub full_vectors: usize,
    /// Entries stored in any quantized form (single PQ, multi-level, or multi-codebook).
    pub quantized_vectors: usize,
    /// Fraction of entries stored quantized (`quantized_vectors / total_vectors`); this is
    /// not a byte-size compression ratio.
    pub compression_ratio: f32,
    /// Entries stored as multi-level residual codes.
    pub multi_level_vectors: usize,
    /// Entries stored as multi-codebook codes.
    pub multi_codebook_vectors: usize,
    /// Quantization strategy the list was built with.
    pub quantization_strategy: QuantizationStrategy,
}
505
/// Inverted-file (IVF) vector index.
///
/// Vectors are bucketed by their nearest k-means centroid, and only the `n_probes` closest
/// buckets are scanned at query time.
pub struct IvfIndex {
    /// Construction-time configuration.
    config: IvfConfig,
    /// Cluster centroids produced by `train`; empty before training.
    centroids: Vec<Vector>,
    /// One inverted list per cluster, each behind its own lock.
    inverted_lists: Vec<Arc<RwLock<InvertedList>>>,
    /// Dimensionality fixed by the training set; `None` before training.
    dimensions: Option<usize>,
    /// Number of vectors inserted so far.
    n_vectors: usize,
    /// Set by `train`; inserts and searches are rejected while this is `false`.
    is_trained: bool,
}
520
521impl IvfIndex {
522 pub fn new(config: IvfConfig) -> Result<Self> {
524 let mut inverted_lists = Vec::with_capacity(config.n_clusters);
525
526 let quantization = if config.enable_residual_quantization {
528 if let Some(ref pq_config) = config.pq_config {
529 QuantizationStrategy::ProductQuantization(pq_config.clone())
530 } else {
531 return Err(anyhow!(
532 "PQ config required when residual quantization is enabled"
533 ));
534 }
535 } else {
536 config.quantization.clone()
537 };
538
539 for _ in 0..config.n_clusters {
540 let inverted_list = Arc::new(RwLock::new(InvertedList::new_with_quantization(
541 quantization.clone(),
542 )?));
543 inverted_lists.push(inverted_list);
544 }
545
546 Ok(Self {
547 config,
548 centroids: Vec::new(),
549 inverted_lists,
550 dimensions: None,
551 n_vectors: 0,
552 is_trained: false,
553 })
554 }
555
556 pub fn new_with_product_quantization(
558 n_clusters: usize,
559 n_probes: usize,
560 pq_config: PQConfig,
561 ) -> Result<Self> {
562 let config = IvfConfig {
563 n_clusters,
564 n_probes,
565 quantization: QuantizationStrategy::ProductQuantization(pq_config),
566 ..Default::default()
567 };
568 Self::new(config)
569 }
570
571 pub fn new_with_multi_level_quantization(
573 n_clusters: usize,
574 n_probes: usize,
575 levels: usize,
576 pq_configs: Vec<PQConfig>,
577 ) -> Result<Self> {
578 if pq_configs.len() < levels {
579 return Err(anyhow!(
580 "Number of PQ configs must be at least equal to levels"
581 ));
582 }
583
584 let config = IvfConfig {
585 n_clusters,
586 n_probes,
587 quantization: QuantizationStrategy::ResidualQuantization { levels, pq_configs },
588 ..Default::default()
589 };
590 Self::new(config)
591 }
592
593 pub fn new_with_multi_codebook_quantization(
595 n_clusters: usize,
596 n_probes: usize,
597 num_codebooks: usize,
598 pq_configs: Vec<PQConfig>,
599 ) -> Result<Self> {
600 if pq_configs.len() != num_codebooks {
601 return Err(anyhow!(
602 "Number of PQ configs must equal number of codebooks"
603 ));
604 }
605
606 let config = IvfConfig {
607 n_clusters,
608 n_probes,
609 quantization: QuantizationStrategy::MultiCodebook {
610 num_codebooks,
611 pq_configs,
612 },
613 ..Default::default()
614 };
615 Self::new(config)
616 }
617
618 pub fn new_with_residual_quantization(
620 n_clusters: usize,
621 n_probes: usize,
622 pq_config: PQConfig,
623 ) -> Result<Self> {
624 Self::new_with_product_quantization(n_clusters, n_probes, pq_config)
625 }
626
    /// Returns the configuration this index was created with.
    pub fn config(&self) -> &IvfConfig {
        &self.config
    }
631
632 pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
634 if training_vectors.is_empty() {
635 return Err(anyhow!("Cannot train IVF index with empty training set"));
636 }
637
638 let dims = training_vectors[0].dimensions;
640 if !training_vectors.iter().all(|v| v.dimensions == dims) {
641 return Err(anyhow!(
642 "All training vectors must have the same dimensions"
643 ));
644 }
645
646 self.dimensions = Some(dims);
647
648 self.centroids = self.initialize_centroids_kmeans_plus_plus(training_vectors)?;
650
651 let mut iteration = 0;
653 let mut prev_error = f32::INFINITY;
654
655 while iteration < self.config.max_iterations {
656 let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); self.config.n_clusters];
658
659 for vector in training_vectors {
660 let nearest_idx = self.find_nearest_centroid(vector)?;
661 clusters[nearest_idx].push(vector);
662 }
663
664 let mut total_error = 0.0;
666 for (i, cluster) in clusters.iter().enumerate() {
667 if !cluster.is_empty() {
668 let new_centroid = self.compute_centroid(cluster);
669 total_error += self.centroids[i]
670 .euclidean_distance(&new_centroid)
671 .unwrap_or(0.0);
672 self.centroids[i] = new_centroid;
673 }
674 }
675
676 if (prev_error - total_error).abs() < self.config.convergence_threshold {
678 break;
679 }
680
681 prev_error = total_error;
682 iteration += 1;
683 }
684
685 self.is_trained = true;
686
687 if !matches!(self.config.quantization, QuantizationStrategy::None)
689 || self.config.enable_residual_quantization
690 {
691 self.train_residual_quantization(training_vectors)?;
692 }
693
694 Ok(())
695 }
696
697 fn train_residual_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
699 let mut cluster_residuals: Vec<Vec<Vector>> = vec![Vec::new(); self.config.n_clusters];
701
702 for vector in training_vectors {
703 let cluster_idx = self.find_nearest_centroid(vector)?;
704 let centroid = &self.centroids[cluster_idx];
705 let residual = vector.subtract(centroid)?;
706 cluster_residuals[cluster_idx].push(residual);
707 }
708
709 for (cluster_idx, residuals) in cluster_residuals.iter().enumerate() {
711 if residuals.len() > 10 {
712 let mut list = self.inverted_lists[cluster_idx]
714 .write()
715 .expect("inverted_lists lock should not be poisoned");
716 list.train_pq(residuals)?;
717 }
718 }
719
720 Ok(())
721 }
722
723 fn initialize_centroids_kmeans_plus_plus(&self, vectors: &[Vector]) -> Result<Vec<Vector>> {
725 use std::collections::hash_map::DefaultHasher;
726 use std::hash::{Hash, Hasher};
727
728 let mut hasher = DefaultHasher::new();
729 self.config.seed.unwrap_or(42).hash(&mut hasher);
730 let mut rng_state = hasher.finish();
731
732 let mut centroids = Vec::with_capacity(self.config.n_clusters);
733
734 let first_idx = (rng_state as usize) % vectors.len();
736 centroids.push(vectors[first_idx].clone());
737
738 while centroids.len() < self.config.n_clusters {
740 let mut distances = Vec::with_capacity(vectors.len());
741 let mut sum_distances = 0.0;
742
743 for vector in vectors {
745 let min_dist = centroids
746 .iter()
747 .map(|c| vector.euclidean_distance(c).unwrap_or(f32::INFINITY))
748 .fold(f32::INFINITY, |a, b| a.min(b));
749
750 distances.push(min_dist * min_dist); sum_distances += min_dist * min_dist;
752 }
753
754 rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
756 let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
757
758 let mut cumulative = 0.0;
759 for (i, &dist) in distances.iter().enumerate() {
760 cumulative += dist;
761 if cumulative >= threshold {
762 centroids.push(vectors[i].clone());
763 break;
764 }
765 }
766 }
767
768 Ok(centroids)
769 }
770
771 fn compute_centroid(&self, cluster: &[&Vector]) -> Vector {
773 if cluster.is_empty() {
774 return Vector::new(vec![0.0; self.dimensions.unwrap_or(0)]);
775 }
776
777 let dims = cluster[0].dimensions;
778 let mut sum = vec![0.0; dims];
779
780 for vector in cluster {
781 let values = vector.as_f32();
782 for (i, &val) in values.iter().enumerate() {
783 sum[i] += val;
784 }
785 }
786
787 let count = cluster.len() as f32;
788 for val in &mut sum {
789 *val /= count;
790 }
791
792 Vector::new(sum)
793 }
794
795 fn find_nearest_centroid(&self, vector: &Vector) -> Result<usize> {
797 if self.centroids.is_empty() {
798 return Err(anyhow!("No centroids available"));
799 }
800
801 let mut min_distance = f32::INFINITY;
802 let mut nearest_idx = 0;
803
804 for (i, centroid) in self.centroids.iter().enumerate() {
805 let distance = vector.euclidean_distance(centroid)?;
806 if distance < min_distance {
807 min_distance = distance;
808 nearest_idx = i;
809 }
810 }
811
812 Ok(nearest_idx)
813 }
814
815 fn find_nearest_centroids(&self, query: &Vector, n_probes: usize) -> Result<Vec<usize>> {
817 let mut distances: Vec<(usize, f32)> = self
818 .centroids
819 .iter()
820 .enumerate()
821 .map(|(i, centroid)| {
822 let dist = query.euclidean_distance(centroid).unwrap_or(f32::INFINITY);
823 (i, dist)
824 })
825 .collect();
826
827 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
828
829 Ok(distances
830 .into_iter()
831 .take(n_probes.min(self.centroids.len()))
832 .map(|(i, _)| i)
833 .collect())
834 }
835
836 pub fn stats(&self) -> IvfStats {
838 let mut total_list_stats = InvertedListStats {
839 total_vectors: 0,
840 full_vectors: 0,
841 quantized_vectors: 0,
842 compression_ratio: 0.0,
843 multi_level_vectors: 0,
844 multi_codebook_vectors: 0,
845 quantization_strategy: QuantizationStrategy::None,
846 };
847
848 let mut cluster_stats = Vec::new();
849 let mut vectors_per_cluster = Vec::new();
850 let mut non_empty_clusters = 0;
851
852 for list in &self.inverted_lists {
853 let list_guard = list
854 .read()
855 .expect("inverted list lock should not be poisoned");
856 let stats = list_guard.stats();
857
858 total_list_stats.total_vectors += stats.total_vectors;
859 total_list_stats.full_vectors += stats.full_vectors;
860 total_list_stats.quantized_vectors += stats.quantized_vectors;
861 total_list_stats.multi_level_vectors += stats.multi_level_vectors;
862 total_list_stats.multi_codebook_vectors += stats.multi_codebook_vectors;
863
864 vectors_per_cluster.push(stats.total_vectors);
865 if stats.total_vectors > 0 {
866 non_empty_clusters += 1;
867 }
868
869 cluster_stats.push(stats);
870 }
871
872 if total_list_stats.total_vectors > 0 {
874 total_list_stats.compression_ratio =
875 total_list_stats.quantized_vectors as f32 / total_list_stats.total_vectors as f32;
876 }
877
878 let avg_vectors_per_cluster = if self.config.n_clusters > 0 {
879 self.n_vectors as f32 / self.config.n_clusters as f32
880 } else {
881 0.0
882 };
883
884 IvfStats {
885 n_clusters: self.config.n_clusters,
886 n_probes: self.config.n_probes,
887 n_vectors: self.n_vectors,
888 is_trained: self.is_trained,
889 dimensions: self.dimensions,
890 vectors_per_cluster,
891 avg_vectors_per_cluster,
892 non_empty_clusters,
893 enable_residual_quantization: self.config.enable_residual_quantization,
894 quantization_strategy: self.config.quantization.clone(),
895 compression_stats: Some(total_list_stats),
896 cluster_stats,
897 }
898 }
899}
900
901impl VectorIndex for IvfIndex {
902 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
903 if !self.is_trained {
904 return Err(anyhow!(
905 "IVF index must be trained before inserting vectors"
906 ));
907 }
908
909 if let Some(dims) = self.dimensions {
911 if vector.dimensions != dims {
912 return Err(anyhow!(
913 "Vector dimensions {} don't match index dimensions {}",
914 vector.dimensions,
915 dims
916 ));
917 }
918 }
919
920 let cluster_idx = self.find_nearest_centroid(&vector)?;
922 let centroid = &self.centroids[cluster_idx];
923
924 let mut list = self.inverted_lists[cluster_idx]
925 .write()
926 .expect("inverted_lists lock should not be poisoned");
927
928 match &self.config.quantization {
930 QuantizationStrategy::None => {
931 if self.config.enable_residual_quantization {
932 let residual = vector.subtract(centroid)?;
934 list.add_residual(uri, residual, centroid)?;
935 } else {
936 list.add_full(uri, vector);
937 }
938 }
939 _ => {
940 let residual = vector.subtract(centroid)?;
942 list.add_residual(uri, residual, centroid)?;
943 }
944 }
945
946 self.n_vectors += 1;
947 Ok(())
948 }
949
950 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
951 if !self.is_trained {
952 return Err(anyhow!("IVF index must be trained before searching"));
953 }
954
955 let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
957
958 let mut all_results = Vec::new();
960 for idx in probe_indices {
961 let list = self.inverted_lists[idx]
962 .read()
963 .expect("inverted_lists lock should not be poisoned");
964 let centroid = &self.centroids[idx];
965 let mut results = list.search(query, centroid, k)?;
966 all_results.append(&mut results);
967 }
968
969 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
971 all_results.truncate(k);
972
973 Ok(all_results)
974 }
975
976 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
977 if !self.is_trained {
978 return Err(anyhow!("IVF index must be trained before searching"));
979 }
980
981 let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
983
984 let mut all_results = Vec::new();
986 for idx in probe_indices {
987 let list = self.inverted_lists[idx]
988 .read()
989 .expect("inverted_lists lock should not be poisoned");
990 let centroid = &self.centroids[idx];
991 let results = list.search(query, centroid, self.n_vectors)?;
992
993 for (uri, similarity) in results {
995 if similarity >= threshold {
996 all_results.push((uri, similarity));
997 }
998 }
999 }
1000
1001 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1003
1004 Ok(all_results)
1005 }
1006
    /// This implementation always returns `None`. Entries may be held as residuals or
    /// quantized codes, and no URI-to-vector lookup is kept.
    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
        None
    }
1013}
1014
/// Snapshot of an [`IvfIndex`]'s configuration and storage statistics.
#[derive(Debug, Clone)]
pub struct IvfStats {
    /// Number of vectors inserted.
    pub n_vectors: usize,
    /// Configured number of clusters.
    pub n_clusters: usize,
    /// Configured number of lists scanned per query.
    pub n_probes: usize,
    /// Whether the index has been trained.
    pub is_trained: bool,
    /// Dimensionality fixed at training time, if trained.
    pub dimensions: Option<usize>,
    /// Entry count of each inverted list, in cluster order.
    pub vectors_per_cluster: Vec<usize>,
    /// `n_vectors / n_clusters`.
    pub avg_vectors_per_cluster: f32,
    /// Number of lists holding at least one entry.
    pub non_empty_clusters: usize,
    /// Mirrors `IvfConfig::enable_residual_quantization`.
    pub enable_residual_quantization: bool,
    /// Mirrors `IvfConfig::quantization`.
    pub quantization_strategy: QuantizationStrategy,
    /// Statistics aggregated over all inverted lists.
    pub compression_stats: Option<InvertedListStats>,
    /// Per-list statistics, in cluster order.
    pub cluster_stats: Vec<InvertedListStats>,
}
1031
// Unit tests for IVF training, insertion, search and statistics across quantization strategies.
#[cfg(test)]
mod tests {
    use super::*;

    // Unquantized index: the nearest stored vector to the query must rank first.
    #[test]
    fn test_ivf_basic() -> Result<()> {
        let config = IvfConfig {
            n_clusters: 4,
            n_probes: 2,
            ..Default::default()
        };

        let mut index = IvfIndex::new(config)?;

        // Eight 2-D points around the unit circle and its diagonals
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, 0.0]),
            Vector::new(vec![0.0, -1.0]),
            Vector::new(vec![0.5, 0.5]),
            Vector::new(vec![-0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5]),
            Vector::new(vec![0.5, -0.5]),
        ];

        index.train(&training_vectors)?;
        assert!(index.is_trained);

        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone())?;
        }

        let query = Vector::new(vec![0.9, 0.1]);
        let results = index.search_knn(&query, 3)?;

        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        // (1.0, 0.0) is the closest stored point to the query
        assert_eq!(results[0].0, "vec0");
        Ok(())
    }

    // Threshold search must only return results at or above the requested similarity.
    #[test]
    fn test_ivf_threshold_search() -> Result<()> {
        let config = IvfConfig {
            n_clusters: 2,
            n_probes: 2,
            ..Default::default()
        };

        let mut index = IvfIndex::new(config)?;

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0]),
        ];

        index.train(&training_vectors)?;

        index.insert("v1".to_string(), training_vectors[0].clone())?;
        index.insert("v2".to_string(), training_vectors[1].clone())?;
        index.insert("v3".to_string(), training_vectors[2].clone())?;
        index.insert("v4".to_string(), training_vectors[3].clone())?;

        let query = Vector::new(vec![0.9, 0.1, 0.0]);
        let results = index.search_threshold(&query, 0.5)?;

        assert!(!results.is_empty());
        for (_, similarity) in &results {
            assert!(*similarity >= 0.5);
        }
        Ok(())
    }

    // Stats must reflect configuration, training state and insert count.
    #[test]
    fn test_ivf_stats() -> Result<()> {
        let config = IvfConfig {
            n_clusters: 3,
            n_probes: 1,
            ..Default::default()
        };

        let mut index = IvfIndex::new(config)?;

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, -1.0]),
        ];

        index.train(&training_vectors)?;

        index.insert("a".to_string(), Vector::new(vec![1.1, 0.1]))?;
        index.insert("b".to_string(), Vector::new(vec![0.1, 1.1]))?;

        let stats = index.stats();
        assert_eq!(stats.n_vectors, 2);
        assert_eq!(stats.n_clusters, 3);
        assert!(stats.is_trained);
        assert_eq!(stats.dimensions, Some(2));
        Ok(())
    }

    // Two-stage residual quantization: entries must be stored as multi-level codes.
    #[test]
    fn test_ivf_multi_level_quantization() -> Result<()> {
        use crate::pq::PQConfig;

        let pq_config_1 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let pq_config_2 = PQConfig {
            n_subquantizers: 2,
            n_bits: 4,
            ..Default::default()
        };

        let mut index =
            IvfIndex::new_with_multi_level_quantization(4, 2, 2, vec![pq_config_1, pq_config_2])?;

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.5, 0.5]),
        ];

        index.train(&training_vectors)?;
        assert!(index.is_trained);

        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone())?;
        }

        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 3)?;

        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::ResidualQuantization { .. }
        ));
        if let Some(compression_stats) = &stats.compression_stats {
            assert!(compression_stats.multi_level_vectors > 0);
        }
        Ok(())
    }

    // Two independent codebooks: entries must be stored as multi-codebook codes.
    #[test]
    fn test_ivf_multi_codebook_quantization() -> Result<()> {
        use crate::pq::PQConfig;

        let pq_config_1 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let pq_config_2 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };

        let mut index = IvfIndex::new_with_multi_codebook_quantization(
            4,
            2,
            2,
            vec![pq_config_1, pq_config_2],
        )?;

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
        ];

        index.train(&training_vectors)?;
        assert!(index.is_trained);

        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone())?;
        }

        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 2)?;

        assert!(!results.is_empty());
        assert!(results.len() <= 2);

        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::MultiCodebook { .. }
        ));
        if let Some(compression_stats) = &stats.compression_stats {
            assert!(compression_stats.multi_codebook_vectors > 0);
        }
        Ok(())
    }

    // Every strategy variant with consistent parameters must construct successfully.
    #[test]
    fn test_quantization_strategies() {
        use crate::pq::PQConfig;

        let pq_config = PQConfig::default();

        let strategies = vec![
            QuantizationStrategy::None,
            QuantizationStrategy::ProductQuantization(pq_config.clone()),
            QuantizationStrategy::ResidualQuantization {
                levels: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
            QuantizationStrategy::MultiCodebook {
                num_codebooks: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
        ];

        for strategy in strategies {
            let config = IvfConfig {
                n_clusters: 2,
                n_probes: 1,
                quantization: strategy.clone(),
                ..Default::default()
            };

            let index = IvfIndex::new(config);
            assert!(
                index.is_ok(),
                "Failed to create index with strategy: {strategy:?}"
            );
        }
    }
}