1use common::DistanceMetric;
7use parking_lot::RwLock;
8use rand::Rng;
9use std::collections::HashMap;
10
11use crate::distance::calculate_distance;
12
/// Configuration for an IVF (inverted-file) index.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IvfConfig {
    /// Number of k-means clusters (inverted lists) built during training.
    pub n_clusters: usize,
    /// Number of clusters scanned per query; higher improves recall at the
    /// cost of query latency.
    pub n_probe: usize,
    /// Upper bound on k-means iterations during training.
    pub max_iterations: usize,
    /// Training stops early once the largest centroid shift between two
    /// iterations falls below this value.
    pub convergence_threshold: f32,
    /// Metric used for centroid assignment and scoring search results.
    pub metric: DistanceMetric,
}
27
28impl Default for IvfConfig {
29 fn default() -> Self {
30 Self {
31 n_clusters: 256,
32 n_probe: 10,
33 max_iterations: 100,
34 convergence_threshold: 1e-4,
35 metric: DistanceMetric::Cosine,
36 }
37 }
38}
39
/// A stored vector together with its caller-supplied identifier.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IndexedVector {
    /// Caller-supplied identifier; used for removal and returned in results.
    pub id: String,
    /// The vector's components.
    pub values: Vec<f32>,
}
46
/// An IVF (inverted-file) approximate nearest-neighbor index.
///
/// Vectors are partitioned into k-means clusters at training time; a query
/// scans only the `n_probe` clusters whose centroids score best against it.
pub struct IvfIndex {
    config: IvfConfig,
    // Fixed by `train`; `None` until training succeeds.
    dimension: Option<usize>,
    // One centroid per cluster, each of length `dimension`.
    centroids: RwLock<Vec<Vec<f32>>>,
    // Cluster index -> vectors assigned to that cluster.
    inverted_lists: RwLock<HashMap<usize, Vec<IndexedVector>>>,
    // Total number of vectors stored across all inverted lists.
    vector_count: RwLock<usize>,
    // Set by `train`; `add` and `search` refuse to run until it is true.
    is_trained: RwLock<bool>,
}
60
61impl IvfIndex {
62 pub fn new(config: IvfConfig) -> Self {
64 Self {
65 config,
66 dimension: None,
67 centroids: RwLock::new(Vec::new()),
68 inverted_lists: RwLock::new(HashMap::new()),
69 vector_count: RwLock::new(0),
70 is_trained: RwLock::new(false),
71 }
72 }
73
74 pub fn with_defaults() -> Self {
76 Self::new(IvfConfig::default())
77 }
78
79 pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<(), String> {
81 if vectors.is_empty() {
82 return Err("Cannot train on empty vector set".to_string());
83 }
84
85 let dim = vectors[0].len();
86 if dim == 0 {
87 return Err("Vector dimension cannot be zero".to_string());
88 }
89
90 for v in vectors {
92 if v.len() != dim {
93 return Err(format!(
94 "Dimension mismatch: expected {}, got {}",
95 dim,
96 v.len()
97 ));
98 }
99 }
100
101 self.dimension = Some(dim);
102
103 let n_clusters = self.config.n_clusters.min(vectors.len());
105
106 let centroids = self.kmeans(vectors, n_clusters)?;
108
109 *self.centroids.write() = centroids;
110 *self.is_trained.write() = true;
111
112 let mut lists = self.inverted_lists.write();
114 lists.clear();
115 for i in 0..n_clusters {
116 lists.insert(i, Vec::new());
117 }
118
119 tracing::info!(
120 n_clusters = n_clusters,
121 dimension = dim,
122 training_vectors = vectors.len(),
123 "IVF index trained"
124 );
125
126 Ok(())
127 }
128
129 fn kmeans(&self, vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>, String> {
131 let dim = vectors[0].len();
132 let mut rng = rand::thread_rng();
133
134 let mut centroids = self.kmeans_plus_plus_init(vectors, k, &mut rng);
136
137 for iteration in 0..self.config.max_iterations {
138 let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
140
141 for (idx, vector) in vectors.iter().enumerate() {
142 let nearest = self.find_nearest_centroid(vector, ¢roids);
143 assignments[nearest].push(idx);
144 }
145
146 let mut new_centroids = Vec::with_capacity(k);
148 let mut max_shift = 0.0f32;
149
150 for (cluster_idx, indices) in assignments.iter().enumerate() {
151 if indices.is_empty() {
152 new_centroids.push(centroids[cluster_idx].clone());
154 continue;
155 }
156
157 let mut new_centroid = vec![0.0f32; dim];
159 for &idx in indices {
160 for (j, val) in vectors[idx].iter().enumerate() {
161 new_centroid[j] += val;
162 }
163 }
164 for val in &mut new_centroid {
165 *val /= indices.len() as f32;
166 }
167
168 let shift = euclidean_distance(¢roids[cluster_idx], &new_centroid);
170 max_shift = max_shift.max(shift);
171
172 new_centroids.push(new_centroid);
173 }
174
175 centroids = new_centroids;
176
177 if max_shift < self.config.convergence_threshold {
179 tracing::debug!(
180 iteration = iteration,
181 max_shift = max_shift,
182 "K-means converged"
183 );
184 break;
185 }
186 }
187
188 Ok(centroids)
189 }
190
191 fn kmeans_plus_plus_init<R: Rng>(
193 &self,
194 vectors: &[Vec<f32>],
195 k: usize,
196 rng: &mut R,
197 ) -> Vec<Vec<f32>> {
198 let mut centroids = Vec::with_capacity(k);
199
200 let first_idx = rng.gen_range(0..vectors.len());
202 centroids.push(vectors[first_idx].clone());
203
204 for _ in 1..k {
206 let mut distances: Vec<f32> = vectors
207 .iter()
208 .map(|v| {
209 centroids
210 .iter()
211 .map(|c| euclidean_distance(v, c))
212 .fold(f32::MAX, f32::min)
213 .powi(2)
214 })
215 .collect();
216
217 let total: f32 = distances.iter().sum();
218 if total == 0.0 {
219 break;
221 }
222
223 for d in &mut distances {
225 *d /= total;
226 }
227
228 let sample: f32 = rng.gen();
230 let mut cumsum = 0.0;
231 let mut selected = 0;
232 for (i, &d) in distances.iter().enumerate() {
233 cumsum += d;
234 if cumsum >= sample {
235 selected = i;
236 break;
237 }
238 }
239
240 centroids.push(vectors[selected].clone());
241 }
242
243 centroids
244 }
245
246 fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
248 let mut best_idx = 0;
249 let mut best_score = f32::NEG_INFINITY;
250
251 for (idx, centroid) in centroids.iter().enumerate() {
252 let score = calculate_distance(vector, centroid, self.config.metric);
253 if score > best_score {
254 best_score = score;
255 best_idx = idx;
256 }
257 }
258
259 best_idx
260 }
261
262 fn find_nearest_centroids(&self, vector: &[f32], n: usize) -> Vec<usize> {
264 let centroids = self.centroids.read();
265 let mut scores: Vec<(usize, f32)> = centroids
266 .iter()
267 .enumerate()
268 .map(|(idx, c)| (idx, calculate_distance(vector, c, self.config.metric)))
269 .collect();
270
271 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
272 scores.into_iter().take(n).map(|(idx, _)| idx).collect()
273 }
274
275 pub fn add(&self, id: String, vector: Vec<f32>) -> Result<(), String> {
277 if !*self.is_trained.read() {
278 return Err("Index must be trained before adding vectors".to_string());
279 }
280
281 if let Some(dim) = self.dimension {
282 if vector.len() != dim {
283 return Err(format!(
284 "Dimension mismatch: expected {}, got {}",
285 dim,
286 vector.len()
287 ));
288 }
289 }
290
291 let centroids = self.centroids.read();
292 let cluster_idx = self.find_nearest_centroid(&vector, ¢roids);
293 drop(centroids);
294
295 let indexed = IndexedVector { id, values: vector };
296
297 let mut lists = self.inverted_lists.write();
298 lists.entry(cluster_idx).or_default().push(indexed);
299 drop(lists);
300
301 *self.vector_count.write() += 1;
302
303 Ok(())
304 }
305
306 pub fn add_batch(&self, vectors: Vec<(String, Vec<f32>)>) -> Result<usize, String> {
308 let mut count = 0;
309 for (id, vector) in vectors {
310 self.add(id, vector)?;
311 count += 1;
312 }
313 Ok(count)
314 }
315
316 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>, String> {
318 if !*self.is_trained.read() {
319 return Err("Index must be trained before searching".to_string());
320 }
321
322 if let Some(dim) = self.dimension {
323 if query.len() != dim {
324 return Err(format!(
325 "Dimension mismatch: expected {}, got {}",
326 dim,
327 query.len()
328 ));
329 }
330 }
331
332 let n_probe = self.config.n_probe.min(self.centroids.read().len());
334 let probe_clusters = self.find_nearest_centroids(query, n_probe);
335
336 let mut candidates: Vec<SearchResult> = Vec::new();
338 let lists = self.inverted_lists.read();
339
340 for cluster_idx in probe_clusters {
341 if let Some(vectors) = lists.get(&cluster_idx) {
342 for indexed in vectors {
343 let score = calculate_distance(query, &indexed.values, self.config.metric);
344 candidates.push(SearchResult {
345 id: indexed.id.clone(),
346 score,
347 });
348 }
349 }
350 }
351
352 candidates.sort_by(|a, b| {
354 b.score
355 .partial_cmp(&a.score)
356 .unwrap_or(std::cmp::Ordering::Equal)
357 });
358 candidates.truncate(k);
359
360 Ok(candidates)
361 }
362
363 pub fn remove(&self, id: &str) -> bool {
365 let mut lists = self.inverted_lists.write();
366 let mut removed = false;
367
368 for vectors in lists.values_mut() {
369 if let Some(pos) = vectors.iter().position(|v| v.id == id) {
370 vectors.remove(pos);
371 removed = true;
372 break;
373 }
374 }
375
376 if removed {
377 *self.vector_count.write() -= 1;
378 }
379
380 removed
381 }
382
    /// Number of vectors currently stored across all inverted lists.
    pub fn len(&self) -> usize {
        *self.vector_count.read()
    }
387
    /// `true` when the index holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
392
    /// `true` once `train` has completed successfully.
    pub fn is_trained(&self) -> bool {
        *self.is_trained.read()
    }
397
    /// Number of centroids currently held (0 before training).
    pub fn n_clusters(&self) -> usize {
        self.centroids.read().len()
    }
402
    /// Borrows the index configuration.
    pub fn config(&self) -> &IvfConfig {
        &self.config
    }
407
    /// Vector dimension fixed at training time; `None` before training.
    pub fn dimension(&self) -> Option<usize> {
        self.dimension
    }
412
    /// Snapshot copy of the current centroids (used for persistence).
    pub(crate) fn centroids_read(&self) -> Vec<Vec<f32>> {
        self.centroids.read().clone()
    }
417
    /// Snapshot copy of the inverted lists (used for persistence).
    pub(crate) fn inverted_lists_read(&self) -> HashMap<usize, Vec<IndexedVector>> {
        self.inverted_lists.read().clone()
    }
422
423 pub fn from_snapshot(snapshot: crate::persistence::IvfFullSnapshot) -> Result<Self, String> {
425 let mut inverted_lists = HashMap::new();
426 for (cluster_id, vectors) in snapshot.inverted_lists {
427 inverted_lists.insert(cluster_id, vectors);
428 }
429
430 Ok(Self {
431 config: snapshot.config,
432 dimension: snapshot.dimension,
433 centroids: RwLock::new(snapshot.centroids),
434 inverted_lists: RwLock::new(inverted_lists),
435 vector_count: RwLock::new(snapshot.vector_count),
436 is_trained: RwLock::new(snapshot.is_trained),
437 })
438 }
439}
440
/// A single search hit: a stored id and its score under the index's
/// configured metric (higher is better).
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier supplied when the vector was added.
    pub id: String,
    /// Similarity score of the stored vector against the query.
    pub score: f32,
}
447
/// Euclidean (L2) distance between two equal-length vectors.
/// Extra elements of the longer slice are ignored (zip semantics).
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let d = x - y;
        acc + d * d
    });
    sum_sq.sqrt()
}
456
457#[cfg(test)]
458mod tests {
459 use super::*;
460
461 fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
462 let mut rng = rand::thread_rng();
463 (0..n)
464 .map(|_| (0..dim).map(|_| rng.gen::<f32>()).collect())
465 .collect()
466 }
467
    #[test]
    fn test_ivf_train() {
        // Training 100 vectors into 10 clusters should mark the index
        // trained and produce exactly 10 centroids.
        let vectors = generate_random_vectors(100, 32);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 10,
            ..Default::default()
        });

        index.train(&vectors).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.n_clusters(), 10);
    }
480
    #[test]
    fn test_ivf_add_and_search() {
        // Index the training vectors themselves; querying with one of them
        // should return that exact vector as the top hit.
        let training_vectors = generate_random_vectors(100, 32);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 10,
            n_probe: 3,
            ..Default::default()
        });

        index.train(&training_vectors).unwrap();

        for (i, v) in training_vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }

        assert_eq!(index.len(), 100);

        let query = &training_vectors[0];
        let results = index.search(query, 5).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        // The query is itself indexed, so it must rank first.
        assert_eq!(results[0].id, "vec_0");
    }
508
    #[test]
    fn test_ivf_remove() {
        // Removing an existing id shrinks the index by one; removing an
        // unknown id is a no-op that returns false.
        let vectors = generate_random_vectors(50, 16);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 5,
            ..Default::default()
        });

        index.train(&vectors).unwrap();

        for (i, v) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }

        assert_eq!(index.len(), 50);

        let removed = index.remove("vec_10");
        assert!(removed);
        assert_eq!(index.len(), 49);

        let not_removed = index.remove("nonexistent");
        assert!(!not_removed);
    }
532
    #[test]
    fn test_ivf_dimension_mismatch() {
        // Adding a vector whose length differs from the training dimension
        // (32 vs 16 here) must be rejected.
        let vectors = generate_random_vectors(50, 16);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 5,
            ..Default::default()
        });

        index.train(&vectors).unwrap();
        index.add("test".to_string(), vectors[0].clone()).unwrap();

        let wrong_dim = vec![0.0; 32];
        let result = index.add("wrong".to_string(), wrong_dim);
        assert!(result.is_err());
    }
549
    #[test]
    fn test_ivf_untrained_error() {
        // Both add and search must fail before the index is trained.
        let index = IvfIndex::with_defaults();

        let result = index.add("test".to_string(), vec![0.0; 32]);
        assert!(result.is_err());

        let result = index.search(&[0.0; 32], 5);
        assert!(result.is_err());
    }
560
561 #[test]
562 fn test_kmeans_convergence() {
563 let mut vectors = Vec::new();
565 let mut rng = rand::thread_rng();
566
567 for _ in 0..30 {
569 vectors.push(vec![1.0 + rng.gen::<f32>() * 0.1, rng.gen::<f32>() * 0.1]);
570 }
571
572 for _ in 0..30 {
574 vectors.push(vec![rng.gen::<f32>() * 0.1, 1.0 + rng.gen::<f32>() * 0.1]);
575 }
576
577 let mut index = IvfIndex::new(IvfConfig {
578 n_clusters: 2,
579 max_iterations: 50,
580 convergence_threshold: 1e-4,
581 metric: DistanceMetric::Euclidean,
582 ..Default::default()
583 });
584
585 index.train(&vectors).unwrap();
586
587 let centroids = index.centroids.read();
589 assert_eq!(centroids.len(), 2);
590
591 let c1 = ¢roids[0];
593 let c2 = ¢roids[1];
594 let dist = euclidean_distance(c1, c2);
595 assert!(
596 dist > 0.5,
597 "Centroids should be well separated, got dist={}",
598 dist
599 );
600 }
601
602 fn brute_force_knn(
608 query: &[f32],
609 vectors: &[(String, Vec<f32>)],
610 k: usize,
611 metric: DistanceMetric,
612 ) -> Vec<String> {
613 let mut distances: Vec<(String, f32)> = vectors
614 .iter()
615 .map(|(id, v)| (id.clone(), calculate_distance(query, v, metric)))
616 .collect();
617
618 distances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
620 distances.into_iter().take(k).map(|(id, _)| id).collect()
621 }
622
623 fn calculate_recall(predicted: &[String], actual: &[String]) -> f32 {
625 let predicted_set: std::collections::HashSet<_> = predicted.iter().collect();
626 let found = actual
627 .iter()
628 .filter(|id| predicted_set.contains(id))
629 .count();
630 found as f32 / actual.len() as f32
631 }
632
    #[test]
    fn test_ivf_recall_at_k() {
        // Average recall@k of IVF search versus exact brute-force k-NN over
        // queries drawn from the indexed set itself.
        let n_vectors = 500;
        let dim = 64;
        let n_clusters = 20;
        let k = 10;

        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 5,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });

        index.train(&vectors).unwrap();

        let indexed: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();

        for (id, v) in &indexed {
            index.add(id.clone(), v.clone()).unwrap();
        }

        let n_queries = 20;
        let mut total_recall = 0.0;

        for q_idx in 0..n_queries {
            // Spread query points evenly across the data set.
            let query = &vectors[q_idx * (n_vectors / n_queries)];

            let ivf_results = index.search(query, k).unwrap();
            let ivf_ids: Vec<String> = ivf_results.iter().map(|r| r.id.clone()).collect();

            let exact_ids = brute_force_knn(query, &indexed, k, DistanceMetric::Euclidean);

            let recall = calculate_recall(&ivf_ids, &exact_ids);
            total_recall += recall;
        }

        let avg_recall = total_recall / n_queries as f32;

        // Loose bound: probing 5 of 20 clusters should recover well over
        // half the true neighbors on average.
        assert!(
            avg_recall > 0.5,
            "Average recall@{} should be > 0.5, got {}",
            k,
            avg_recall
        );
    }
690
    #[test]
    fn test_ivf_nprobe_effect_on_recall() {
        // Probing more clusters should not decrease recall.
        let n_vectors = 300;
        let dim = 32;
        let n_clusters = 15;
        let k = 5;

        let vectors = generate_random_vectors(n_vectors, dim);

        // Index probing only 2 of 15 clusters per query.
        let mut index_low = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });

        index_low.train(&vectors).unwrap();

        let indexed: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();

        for (id, v) in &indexed {
            index_low.add(id.clone(), v.clone()).unwrap();
        }

        // Index probing 10 of 15 clusters per query.
        let mut index_high = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 10,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });

        index_high.train(&vectors).unwrap();

        for (id, v) in &indexed {
            index_high.add(id.clone(), v.clone()).unwrap();
        }

        let n_queries = 10;
        let mut recall_low = 0.0;
        let mut recall_high = 0.0;

        for q_idx in 0..n_queries {
            let query = &vectors[q_idx * (n_vectors / n_queries)];

            let low_results = index_low.search(query, k).unwrap();
            let low_ids: Vec<String> = low_results.iter().map(|r| r.id.clone()).collect();

            let high_results = index_high.search(query, k).unwrap();
            let high_ids: Vec<String> = high_results.iter().map(|r| r.id.clone()).collect();

            let exact_ids = brute_force_knn(query, &indexed, k, DistanceMetric::Euclidean);

            recall_low += calculate_recall(&low_ids, &exact_ids);
            recall_high += calculate_recall(&high_ids, &exact_ids);
        }

        let avg_recall_low = recall_low / n_queries as f32;
        let avg_recall_high = recall_high / n_queries as f32;

        // NOTE(review): the two indexes are trained independently with
        // different random seeds, so this is a statistical expectation
        // rather than a strict invariant of a single index.
        assert!(
            avg_recall_high >= avg_recall_low,
            "Higher nprobe should give equal or better recall: low={}, high={}",
            avg_recall_low,
            avg_recall_high
        );
    }
766
    #[test]
    fn test_ivf_cluster_distribution() {
        // Sanity-check that k-means spreads vectors across clusters rather
        // than collapsing everything into one or two lists.
        let n_vectors = 200;
        let dim = 16;
        let n_clusters = 10;

        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 3,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });

        index.train(&vectors).unwrap();

        for (i, v) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }

        let lists = index.inverted_lists.read();
        let cluster_sizes: Vec<usize> = lists.values().map(|v| v.len()).collect();

        let non_empty_clusters = cluster_sizes.iter().filter(|&&s| s > 0).count();
        assert!(
            non_empty_clusters >= n_clusters / 2,
            "At least half of clusters should be used: {} out of {}",
            non_empty_clusters,
            n_clusters
        );

        let max_cluster_size = cluster_sizes.iter().max().copied().unwrap_or(0);
        assert!(
            max_cluster_size < n_vectors * 3 / 4,
            "No cluster should have more than 75% of vectors: {} out of {}",
            max_cluster_size,
            n_vectors
        );
    }
810
    #[test]
    fn test_ivf_high_dimensional_accuracy() {
        // 128-dimensional cosine search: querying an indexed vector must
        // return it first, and every score must be finite.
        let n_vectors = 200;
        let dim = 128;
        let n_clusters = 16;
        let k = 5;

        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 4,
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });

        index.train(&vectors).unwrap();

        let indexed: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();

        for (id, v) in &indexed {
            index.add(id.clone(), v.clone()).unwrap();
        }

        let query = &vectors[0];
        let results = index.search(query, k).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= k);

        // The query is itself indexed, so it must rank first.
        assert_eq!(results[0].id, "vec_0");

        for result in &results {
            assert!(
                result.score.is_finite(),
                "Score should be finite, got {}",
                result.score
            );
        }
    }
858
    #[test]
    fn test_ivf_cosine_vs_euclidean() {
        // Small hand-built data set: the query (0.95, 0.05, 0) is closest
        // to vec_0/vec_1 under both the cosine and euclidean metrics.
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.0],
        ];

        let mut index_cosine = IvfIndex::new(IvfConfig {
            n_clusters: 2,
            n_probe: 2,
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });
        index_cosine.train(&vectors).unwrap();

        for (i, v) in vectors.iter().enumerate() {
            index_cosine.add(format!("vec_{}", i), v.clone()).unwrap();
        }

        let query = vec![0.95, 0.05, 0.0];
        let results_cosine = index_cosine.search(&query, 3).unwrap();

        assert_eq!(results_cosine.len(), 3);

        let mut index_euclidean = IvfIndex::new(IvfConfig {
            n_clusters: 2,
            n_probe: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index_euclidean.train(&vectors).unwrap();

        for (i, v) in vectors.iter().enumerate() {
            index_euclidean
                .add(format!("vec_{}", i), v.clone())
                .unwrap();
        }

        let results_euclidean = index_euclidean.search(&query, 3).unwrap();
        assert_eq!(results_euclidean.len(), 3);

        let top_cosine = &results_cosine[0].id;
        let top_euclidean = &results_euclidean[0].id;
        assert!(
            top_cosine == "vec_0" || top_cosine == "vec_1",
            "Cosine top result should be vec_0 or vec_1, got {}",
            top_cosine
        );
        assert!(
            top_euclidean == "vec_0" || top_euclidean == "vec_1",
            "Euclidean top result should be vec_0 or vec_1, got {}",
            top_euclidean
        );
    }
922
    #[test]
    fn test_ivf_batch_accuracy() {
        // add_batch should report the number inserted and leave the index
        // fully searchable.
        let n_vectors = 100;
        let dim = 32;

        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 10,
            n_probe: 5,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });

        index.train(&vectors).unwrap();

        let batch: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();

        let added = index.add_batch(batch.clone()).unwrap();
        assert_eq!(added, n_vectors);
        assert_eq!(index.len(), n_vectors);

        let query = &vectors[0];
        let results = index.search(query, 5).unwrap();

        assert!(!results.is_empty());
        assert!(
            results.iter().any(|r| r.id == "vec_0"),
            "Query vector should be in search results"
        );
    }
961
    #[test]
    fn test_ivf_empty_cluster_handling() {
        // With as many clusters as vectors, some clusters can end up empty;
        // search must still return results.
        let vectors = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0]];

        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 3,
            n_probe: 3,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });

        index.train(&vectors).unwrap();

        for (i, v) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }

        let results = index.search(&vec![0.5, 0.5], 2).unwrap();
        assert!(!results.is_empty());
    }
985}