Skip to main content

engine/
ivf.rs

1//! IVF (Inverted File) Index with K-means clustering
2//!
3//! This index partitions vectors into clusters using k-means,
4//! enabling sublinear search time by only searching relevant clusters.
5
6use common::DistanceMetric;
7use parking_lot::RwLock;
8use rand::Rng;
9use std::collections::HashMap;
10
11use crate::distance::calculate_distance;
12
/// Configuration for IVF index
///
/// Serializable via `serde` so it can be stored alongside a snapshot.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IvfConfig {
    /// Number of clusters (centroids) requested for training.
    /// `IvfIndex::train` caps this at the number of training vectors.
    pub n_clusters: usize,
    /// Number of clusters to probe during search.
    /// `IvfIndex::search` caps this at the number of centroids.
    pub n_probe: usize,
    /// Maximum k-means iterations
    pub max_iterations: usize,
    /// Convergence threshold for k-means: iteration stops once the largest
    /// centroid shift (measured with Euclidean distance) drops below it.
    pub convergence_threshold: f32,
    /// Distance metric to use for centroid assignment and candidate scoring
    pub metric: DistanceMetric,
}
27
28impl Default for IvfConfig {
29    fn default() -> Self {
30        Self {
31            n_clusters: 256,
32            n_probe: 10,
33            max_iterations: 100,
34            convergence_threshold: 1e-4,
35            metric: DistanceMetric::Cosine,
36        }
37    }
38}
39
/// A vector stored in the index with its ID
///
/// This is the element type of every inverted list.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IndexedVector {
    /// Caller-supplied identifier, returned in search results and used by `remove`
    pub id: String,
    /// The vector's components
    pub values: Vec<f32>,
}
46
/// IVF Index for approximate nearest neighbor search
///
/// Mutable state lives behind `RwLock`s so `add`, `search` and `remove` can
/// take `&self`; only `train` needs `&mut self` because it sets `dimension`.
pub struct IvfIndex {
    /// Configuration supplied at construction (or restored from a snapshot)
    config: IvfConfig,
    /// Vector dimension; `None` until `train` has run
    dimension: Option<usize>,
    /// Cluster centroids
    centroids: RwLock<Vec<Vec<f32>>>,
    /// Inverted lists: cluster_id -> vectors in that cluster
    inverted_lists: RwLock<HashMap<usize, Vec<IndexedVector>>>,
    /// Total vector count
    vector_count: RwLock<usize>,
    /// Whether the index has been trained
    is_trained: RwLock<bool>,
}
60
61impl IvfIndex {
62    /// Create a new IVF index with the given configuration
63    pub fn new(config: IvfConfig) -> Self {
64        Self {
65            config,
66            dimension: None,
67            centroids: RwLock::new(Vec::new()),
68            inverted_lists: RwLock::new(HashMap::new()),
69            vector_count: RwLock::new(0),
70            is_trained: RwLock::new(false),
71        }
72    }
73
74    /// Create with default configuration
75    pub fn with_defaults() -> Self {
76        Self::new(IvfConfig::default())
77    }
78
79    /// Train the index using k-means clustering
80    pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<(), String> {
81        if vectors.is_empty() {
82            return Err("Cannot train on empty vector set".to_string());
83        }
84
85        let dim = vectors[0].len();
86        if dim == 0 {
87            return Err("Vector dimension cannot be zero".to_string());
88        }
89
90        // Validate all vectors have same dimension
91        for v in vectors {
92            if v.len() != dim {
93                return Err(format!(
94                    "Dimension mismatch: expected {}, got {}",
95                    dim,
96                    v.len()
97                ));
98            }
99        }
100
101        self.dimension = Some(dim);
102
103        // Adjust n_clusters if we have fewer vectors
104        let n_clusters = self.config.n_clusters.min(vectors.len());
105
106        // Run k-means clustering
107        let centroids = self.kmeans(vectors, n_clusters)?;
108
109        *self.centroids.write() = centroids;
110        *self.is_trained.write() = true;
111
112        // Initialize empty inverted lists
113        let mut lists = self.inverted_lists.write();
114        lists.clear();
115        for i in 0..n_clusters {
116            lists.insert(i, Vec::new());
117        }
118
119        tracing::info!(
120            n_clusters = n_clusters,
121            dimension = dim,
122            training_vectors = vectors.len(),
123            "IVF index trained"
124        );
125
126        Ok(())
127    }
128
129    /// K-means clustering algorithm
130    fn kmeans(&self, vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>, String> {
131        let dim = vectors[0].len();
132        let mut rng = rand::thread_rng();
133
134        // Initialize centroids using k-means++ initialization
135        let mut centroids = self.kmeans_plus_plus_init(vectors, k, &mut rng);
136
137        for iteration in 0..self.config.max_iterations {
138            // Assign vectors to nearest centroid
139            let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
140
141            for (idx, vector) in vectors.iter().enumerate() {
142                let nearest = self.find_nearest_centroid(vector, &centroids);
143                assignments[nearest].push(idx);
144            }
145
146            // Compute new centroids
147            let mut new_centroids = Vec::with_capacity(k);
148            let mut max_shift = 0.0f32;
149
150            for (cluster_idx, indices) in assignments.iter().enumerate() {
151                if indices.is_empty() {
152                    // Keep old centroid if cluster is empty
153                    new_centroids.push(centroids[cluster_idx].clone());
154                    continue;
155                }
156
157                // Compute mean of assigned vectors
158                let mut new_centroid = vec![0.0f32; dim];
159                for &idx in indices {
160                    for (j, val) in vectors[idx].iter().enumerate() {
161                        new_centroid[j] += val;
162                    }
163                }
164                for val in &mut new_centroid {
165                    *val /= indices.len() as f32;
166                }
167
168                // Compute shift from old centroid
169                let shift = euclidean_distance(&centroids[cluster_idx], &new_centroid);
170                max_shift = max_shift.max(shift);
171
172                new_centroids.push(new_centroid);
173            }
174
175            centroids = new_centroids;
176
177            // Check convergence
178            if max_shift < self.config.convergence_threshold {
179                tracing::debug!(
180                    iteration = iteration,
181                    max_shift = max_shift,
182                    "K-means converged"
183                );
184                break;
185            }
186        }
187
188        Ok(centroids)
189    }
190
191    /// K-means++ initialization for better initial centroids
192    fn kmeans_plus_plus_init<R: Rng>(
193        &self,
194        vectors: &[Vec<f32>],
195        k: usize,
196        rng: &mut R,
197    ) -> Vec<Vec<f32>> {
198        let mut centroids = Vec::with_capacity(k);
199
200        // Choose first centroid randomly
201        let first_idx = rng.gen_range(0..vectors.len());
202        centroids.push(vectors[first_idx].clone());
203
204        // Choose remaining centroids with probability proportional to squared distance
205        for _ in 1..k {
206            let mut distances: Vec<f32> = vectors
207                .iter()
208                .map(|v| {
209                    centroids
210                        .iter()
211                        .map(|c| euclidean_distance(v, c))
212                        .fold(f32::MAX, f32::min)
213                        .powi(2)
214                })
215                .collect();
216
217            let total: f32 = distances.iter().sum();
218            if total == 0.0 {
219                // All remaining vectors are at centroid positions
220                break;
221            }
222
223            // Normalize to probabilities
224            for d in &mut distances {
225                *d /= total;
226            }
227
228            // Sample according to distribution
229            let sample: f32 = rng.gen();
230            let mut cumsum = 0.0;
231            let mut selected = 0;
232            for (i, &d) in distances.iter().enumerate() {
233                cumsum += d;
234                if cumsum >= sample {
235                    selected = i;
236                    break;
237                }
238            }
239
240            centroids.push(vectors[selected].clone());
241        }
242
243        centroids
244    }
245
246    /// Find the index of the nearest centroid to a vector
247    fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
248        let mut best_idx = 0;
249        let mut best_score = f32::NEG_INFINITY;
250
251        for (idx, centroid) in centroids.iter().enumerate() {
252            let score = calculate_distance(vector, centroid, self.config.metric);
253            if score > best_score {
254                best_score = score;
255                best_idx = idx;
256            }
257        }
258
259        best_idx
260    }
261
262    /// Find the top-n nearest centroids to a vector
263    fn find_nearest_centroids(&self, vector: &[f32], n: usize) -> Vec<usize> {
264        let centroids = self.centroids.read();
265        let mut scores: Vec<(usize, f32)> = centroids
266            .iter()
267            .enumerate()
268            .map(|(idx, c)| (idx, calculate_distance(vector, c, self.config.metric)))
269            .collect();
270
271        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
272        scores.into_iter().take(n).map(|(idx, _)| idx).collect()
273    }
274
275    /// Add a vector to the index (must be trained first)
276    pub fn add(&self, id: String, vector: Vec<f32>) -> Result<(), String> {
277        if !*self.is_trained.read() {
278            return Err("Index must be trained before adding vectors".to_string());
279        }
280
281        if let Some(dim) = self.dimension {
282            if vector.len() != dim {
283                return Err(format!(
284                    "Dimension mismatch: expected {}, got {}",
285                    dim,
286                    vector.len()
287                ));
288            }
289        }
290
291        let centroids = self.centroids.read();
292        let cluster_idx = self.find_nearest_centroid(&vector, &centroids);
293        drop(centroids);
294
295        let indexed = IndexedVector { id, values: vector };
296
297        let mut lists = self.inverted_lists.write();
298        lists.entry(cluster_idx).or_default().push(indexed);
299        drop(lists);
300
301        *self.vector_count.write() += 1;
302
303        Ok(())
304    }
305
306    /// Add multiple vectors to the index
307    pub fn add_batch(&self, vectors: Vec<(String, Vec<f32>)>) -> Result<usize, String> {
308        let mut count = 0;
309        for (id, vector) in vectors {
310            self.add(id, vector)?;
311            count += 1;
312        }
313        Ok(count)
314    }
315
316    /// Search for the k nearest neighbors
317    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>, String> {
318        if !*self.is_trained.read() {
319            return Err("Index must be trained before searching".to_string());
320        }
321
322        if let Some(dim) = self.dimension {
323            if query.len() != dim {
324                return Err(format!(
325                    "Dimension mismatch: expected {}, got {}",
326                    dim,
327                    query.len()
328                ));
329            }
330        }
331
332        // Find nearest centroids to probe
333        let n_probe = self.config.n_probe.min(self.centroids.read().len());
334        let probe_clusters = self.find_nearest_centroids(query, n_probe);
335
336        // Search in selected clusters
337        let mut candidates: Vec<SearchResult> = Vec::new();
338        let lists = self.inverted_lists.read();
339
340        for cluster_idx in probe_clusters {
341            if let Some(vectors) = lists.get(&cluster_idx) {
342                for indexed in vectors {
343                    let score = calculate_distance(query, &indexed.values, self.config.metric);
344                    candidates.push(SearchResult {
345                        id: indexed.id.clone(),
346                        score,
347                    });
348                }
349            }
350        }
351
352        // Sort by score (descending) and take top-k
353        candidates.sort_by(|a, b| {
354            b.score
355                .partial_cmp(&a.score)
356                .unwrap_or(std::cmp::Ordering::Equal)
357        });
358        candidates.truncate(k);
359
360        Ok(candidates)
361    }
362
363    /// Remove a vector by ID
364    pub fn remove(&self, id: &str) -> bool {
365        let mut lists = self.inverted_lists.write();
366        let mut removed = false;
367
368        for vectors in lists.values_mut() {
369            if let Some(pos) = vectors.iter().position(|v| v.id == id) {
370                vectors.remove(pos);
371                removed = true;
372                break;
373            }
374        }
375
376        if removed {
377            *self.vector_count.write() -= 1;
378        }
379
380        removed
381    }
382
383    /// Get total number of indexed vectors
384    pub fn len(&self) -> usize {
385        *self.vector_count.read()
386    }
387
388    /// Check if index is empty
389    pub fn is_empty(&self) -> bool {
390        self.len() == 0
391    }
392
393    /// Check if index is trained
394    pub fn is_trained(&self) -> bool {
395        *self.is_trained.read()
396    }
397
398    /// Get number of clusters
399    pub fn n_clusters(&self) -> usize {
400        self.centroids.read().len()
401    }
402
403    /// Get configuration
404    pub fn config(&self) -> &IvfConfig {
405        &self.config
406    }
407
408    /// Get dimension
409    pub fn dimension(&self) -> Option<usize> {
410        self.dimension
411    }
412
413    /// Get read access to centroids for persistence
414    pub(crate) fn centroids_read(&self) -> Vec<Vec<f32>> {
415        self.centroids.read().clone()
416    }
417
418    /// Get read access to inverted lists for persistence
419    pub(crate) fn inverted_lists_read(&self) -> HashMap<usize, Vec<IndexedVector>> {
420        self.inverted_lists.read().clone()
421    }
422
423    /// Restore IVF index from a full snapshot
424    pub fn from_snapshot(snapshot: crate::persistence::IvfFullSnapshot) -> Result<Self, String> {
425        let mut inverted_lists = HashMap::new();
426        for (cluster_id, vectors) in snapshot.inverted_lists {
427            inverted_lists.insert(cluster_id, vectors);
428        }
429
430        Ok(Self {
431            config: snapshot.config,
432            dimension: snapshot.dimension,
433            centroids: RwLock::new(snapshot.centroids),
434            inverted_lists: RwLock::new(inverted_lists),
435            vector_count: RwLock::new(snapshot.vector_count),
436            is_trained: RwLock::new(snapshot.is_trained),
437        })
438    }
439}
440
/// Search result from IVF index
///
/// `IvfIndex::search` returns these sorted by descending `score`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// ID of the matched vector, as supplied when it was added
    pub id: String,
    /// Value returned by `calculate_distance` for the query and this vector
    pub score: f32,
}
447
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; when the slices differ in length the
/// extra trailing components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a.iter().zip(b).map(|(lhs, rhs)| (lhs - rhs).powi(2));
    let sum_of_squares: f32 = squared_gaps.sum();
    sum_of_squares.sqrt()
}
456
457#[cfg(test)]
458mod tests {
459    use super::*;
460
461    fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
462        let mut rng = rand::thread_rng();
463        (0..n)
464            .map(|_| (0..dim).map(|_| rng.gen::<f32>()).collect())
465            .collect()
466    }
467
468    #[test]
469    fn test_ivf_train() {
470        let vectors = generate_random_vectors(100, 32);
471        let mut index = IvfIndex::new(IvfConfig {
472            n_clusters: 10,
473            ..Default::default()
474        });
475
476        index.train(&vectors).unwrap();
477        assert!(index.is_trained());
478        assert_eq!(index.n_clusters(), 10);
479    }
480
481    #[test]
482    fn test_ivf_add_and_search() {
483        let training_vectors = generate_random_vectors(100, 32);
484        let mut index = IvfIndex::new(IvfConfig {
485            n_clusters: 10,
486            n_probe: 3,
487            ..Default::default()
488        });
489
490        index.train(&training_vectors).unwrap();
491
492        // Add vectors
493        for (i, v) in training_vectors.iter().enumerate() {
494            index.add(format!("vec_{}", i), v.clone()).unwrap();
495        }
496
497        assert_eq!(index.len(), 100);
498
499        // Search
500        let query = &training_vectors[0];
501        let results = index.search(query, 5).unwrap();
502
503        assert!(!results.is_empty());
504        assert!(results.len() <= 5);
505        // First result should be the query itself (exact match)
506        assert_eq!(results[0].id, "vec_0");
507    }
508
509    #[test]
510    fn test_ivf_remove() {
511        let vectors = generate_random_vectors(50, 16);
512        let mut index = IvfIndex::new(IvfConfig {
513            n_clusters: 5,
514            ..Default::default()
515        });
516
517        index.train(&vectors).unwrap();
518
519        for (i, v) in vectors.iter().enumerate() {
520            index.add(format!("vec_{}", i), v.clone()).unwrap();
521        }
522
523        assert_eq!(index.len(), 50);
524
525        let removed = index.remove("vec_10");
526        assert!(removed);
527        assert_eq!(index.len(), 49);
528
529        let not_removed = index.remove("nonexistent");
530        assert!(!not_removed);
531    }
532
533    #[test]
534    fn test_ivf_dimension_mismatch() {
535        let vectors = generate_random_vectors(50, 16);
536        let mut index = IvfIndex::new(IvfConfig {
537            n_clusters: 5,
538            ..Default::default()
539        });
540
541        index.train(&vectors).unwrap();
542        index.add("test".to_string(), vectors[0].clone()).unwrap();
543
544        // Try to add vector with wrong dimension
545        let wrong_dim = vec![0.0; 32];
546        let result = index.add("wrong".to_string(), wrong_dim);
547        assert!(result.is_err());
548    }
549
550    #[test]
551    fn test_ivf_untrained_error() {
552        let index = IvfIndex::with_defaults();
553
554        let result = index.add("test".to_string(), vec![0.0; 32]);
555        assert!(result.is_err());
556
557        let result = index.search(&[0.0; 32], 5);
558        assert!(result.is_err());
559    }
560
561    #[test]
562    fn test_kmeans_convergence() {
563        // Create clustered data with well-separated clusters
564        let mut vectors = Vec::new();
565        let mut rng = rand::thread_rng();
566
567        // Cluster 1 around [1, 0] (pointing right)
568        for _ in 0..30 {
569            vectors.push(vec![1.0 + rng.gen::<f32>() * 0.1, rng.gen::<f32>() * 0.1]);
570        }
571
572        // Cluster 2 around [0, 1] (pointing up)
573        for _ in 0..30 {
574            vectors.push(vec![rng.gen::<f32>() * 0.1, 1.0 + rng.gen::<f32>() * 0.1]);
575        }
576
577        let mut index = IvfIndex::new(IvfConfig {
578            n_clusters: 2,
579            max_iterations: 50,
580            convergence_threshold: 1e-4,
581            metric: DistanceMetric::Euclidean,
582            ..Default::default()
583        });
584
585        index.train(&vectors).unwrap();
586
587        // Centroids should be near [1, 0] and [0, 1]
588        let centroids = index.centroids.read();
589        assert_eq!(centroids.len(), 2);
590
591        // Check that centroids are distinct
592        let c1 = &centroids[0];
593        let c2 = &centroids[1];
594        let dist = euclidean_distance(c1, c2);
595        assert!(
596            dist > 0.5,
597            "Centroids should be well separated, got dist={}",
598            dist
599        );
600    }
601
602    // =====================================================
603    // IVF Index Accuracy Tests
604    // =====================================================
605
606    /// Compute exact k nearest neighbors using brute force
607    fn brute_force_knn(
608        query: &[f32],
609        vectors: &[(String, Vec<f32>)],
610        k: usize,
611        metric: DistanceMetric,
612    ) -> Vec<String> {
613        let mut distances: Vec<(String, f32)> = vectors
614            .iter()
615            .map(|(id, v)| (id.clone(), calculate_distance(query, v, metric)))
616            .collect();
617
618        // Sort by score descending (higher = more similar)
619        distances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
620        distances.into_iter().take(k).map(|(id, _)| id).collect()
621    }
622
623    /// Calculate recall@k: fraction of true top-k neighbors found
624    fn calculate_recall(predicted: &[String], actual: &[String]) -> f32 {
625        let predicted_set: std::collections::HashSet<_> = predicted.iter().collect();
626        let found = actual
627            .iter()
628            .filter(|id| predicted_set.contains(id))
629            .count();
630        found as f32 / actual.len() as f32
631    }
632
633    #[test]
634    fn test_ivf_recall_at_k() {
635        // Test recall@k with controlled dataset
636        let n_vectors = 500;
637        let dim = 64;
638        let n_clusters = 20;
639        let k = 10;
640
641        let vectors = generate_random_vectors(n_vectors, dim);
642        let mut index = IvfIndex::new(IvfConfig {
643            n_clusters,
644            n_probe: 5, // Probe 25% of clusters
645            metric: DistanceMetric::Euclidean,
646            ..Default::default()
647        });
648
649        index.train(&vectors).unwrap();
650
651        // Index all vectors
652        let indexed: Vec<(String, Vec<f32>)> = vectors
653            .iter()
654            .enumerate()
655            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
656            .collect();
657
658        for (id, v) in &indexed {
659            index.add(id.clone(), v.clone()).unwrap();
660        }
661
662        // Test recall on multiple queries
663        let n_queries = 20;
664        let mut total_recall = 0.0;
665
666        for q_idx in 0..n_queries {
667            let query = &vectors[q_idx * (n_vectors / n_queries)];
668
669            // Get IVF results
670            let ivf_results = index.search(query, k).unwrap();
671            let ivf_ids: Vec<String> = ivf_results.iter().map(|r| r.id.clone()).collect();
672
673            // Get exact results
674            let exact_ids = brute_force_knn(query, &indexed, k, DistanceMetric::Euclidean);
675
676            let recall = calculate_recall(&ivf_ids, &exact_ids);
677            total_recall += recall;
678        }
679
680        let avg_recall = total_recall / n_queries as f32;
681
682        // With n_probe=5 out of 20 clusters (25%), expect recall > 0.5
683        assert!(
684            avg_recall > 0.5,
685            "Average recall@{} should be > 0.5, got {}",
686            k,
687            avg_recall
688        );
689    }
690
691    #[test]
692    fn test_ivf_nprobe_effect_on_recall() {
693        // Verify that increasing nprobe improves recall
694        let n_vectors = 300;
695        let dim = 32;
696        let n_clusters = 15;
697        let k = 5;
698
699        let vectors = generate_random_vectors(n_vectors, dim);
700
701        // Index vectors with small nprobe first
702        let mut index_low = IvfIndex::new(IvfConfig {
703            n_clusters,
704            n_probe: 2, // Low nprobe
705            metric: DistanceMetric::Euclidean,
706            ..Default::default()
707        });
708
709        index_low.train(&vectors).unwrap();
710
711        let indexed: Vec<(String, Vec<f32>)> = vectors
712            .iter()
713            .enumerate()
714            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
715            .collect();
716
717        for (id, v) in &indexed {
718            index_low.add(id.clone(), v.clone()).unwrap();
719        }
720
721        // Index with high nprobe (same centroids)
722        let mut index_high = IvfIndex::new(IvfConfig {
723            n_clusters,
724            n_probe: 10, // High nprobe
725            metric: DistanceMetric::Euclidean,
726            ..Default::default()
727        });
728
729        index_high.train(&vectors).unwrap();
730
731        for (id, v) in &indexed {
732            index_high.add(id.clone(), v.clone()).unwrap();
733        }
734
735        // Test recall on multiple queries
736        let n_queries = 10;
737        let mut recall_low = 0.0;
738        let mut recall_high = 0.0;
739
740        for q_idx in 0..n_queries {
741            let query = &vectors[q_idx * (n_vectors / n_queries)];
742
743            let low_results = index_low.search(query, k).unwrap();
744            let low_ids: Vec<String> = low_results.iter().map(|r| r.id.clone()).collect();
745
746            let high_results = index_high.search(query, k).unwrap();
747            let high_ids: Vec<String> = high_results.iter().map(|r| r.id.clone()).collect();
748
749            let exact_ids = brute_force_knn(query, &indexed, k, DistanceMetric::Euclidean);
750
751            recall_low += calculate_recall(&low_ids, &exact_ids);
752            recall_high += calculate_recall(&high_ids, &exact_ids);
753        }
754
755        let avg_recall_low = recall_low / n_queries as f32;
756        let avg_recall_high = recall_high / n_queries as f32;
757
758        // High nprobe should generally have better recall
759        assert!(
760            avg_recall_high >= avg_recall_low,
761            "Higher nprobe should give equal or better recall: low={}, high={}",
762            avg_recall_low,
763            avg_recall_high
764        );
765    }
766
767    #[test]
768    fn test_ivf_cluster_distribution() {
769        // Verify vectors are distributed across clusters
770        let n_vectors = 200;
771        let dim = 16;
772        let n_clusters = 10;
773
774        let vectors = generate_random_vectors(n_vectors, dim);
775        let mut index = IvfIndex::new(IvfConfig {
776            n_clusters,
777            n_probe: 3,
778            metric: DistanceMetric::Euclidean,
779            ..Default::default()
780        });
781
782        index.train(&vectors).unwrap();
783
784        for (i, v) in vectors.iter().enumerate() {
785            index.add(format!("vec_{}", i), v.clone()).unwrap();
786        }
787
788        // Check distribution across clusters
789        let lists = index.inverted_lists.read();
790        let cluster_sizes: Vec<usize> = lists.values().map(|v| v.len()).collect();
791
792        // Verify multiple clusters are used
793        let non_empty_clusters = cluster_sizes.iter().filter(|&&s| s > 0).count();
794        assert!(
795            non_empty_clusters >= n_clusters / 2,
796            "At least half of clusters should be used: {} out of {}",
797            non_empty_clusters,
798            n_clusters
799        );
800
801        // Verify no single cluster has all vectors (reasonable distribution)
802        let max_cluster_size = cluster_sizes.iter().max().copied().unwrap_or(0);
803        assert!(
804            max_cluster_size < n_vectors * 3 / 4,
805            "No cluster should have more than 75% of vectors: {} out of {}",
806            max_cluster_size,
807            n_vectors
808        );
809    }
810
811    #[test]
812    fn test_ivf_high_dimensional_accuracy() {
813        // Test with higher dimensions (128D is common for embeddings)
814        let n_vectors = 200;
815        let dim = 128;
816        let n_clusters = 16;
817        let k = 5;
818
819        let vectors = generate_random_vectors(n_vectors, dim);
820        let mut index = IvfIndex::new(IvfConfig {
821            n_clusters,
822            n_probe: 4,
823            metric: DistanceMetric::Cosine, // Cosine is common for embeddings
824            ..Default::default()
825        });
826
827        index.train(&vectors).unwrap();
828
829        let indexed: Vec<(String, Vec<f32>)> = vectors
830            .iter()
831            .enumerate()
832            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
833            .collect();
834
835        for (id, v) in &indexed {
836            index.add(id.clone(), v.clone()).unwrap();
837        }
838
839        // Verify search works and returns valid results
840        let query = &vectors[0];
841        let results = index.search(query, k).unwrap();
842
843        assert!(!results.is_empty());
844        assert!(results.len() <= k);
845
846        // First result should be the query itself (exact match)
847        assert_eq!(results[0].id, "vec_0");
848
849        // All scores should be valid (not NaN or Inf)
850        for result in &results {
851            assert!(
852                result.score.is_finite(),
853                "Score should be finite, got {}",
854                result.score
855            );
856        }
857    }
858
859    #[test]
860    fn test_ivf_cosine_vs_euclidean() {
861        // Compare behavior with different metrics
862        let vectors = vec![
863            vec![1.0, 0.0, 0.0],
864            vec![0.9, 0.1, 0.0],
865            vec![0.0, 1.0, 0.0],
866            vec![0.0, 0.0, 1.0],
867            vec![0.5, 0.5, 0.0],
868        ];
869
870        // Test with Cosine metric
871        let mut index_cosine = IvfIndex::new(IvfConfig {
872            n_clusters: 2,
873            n_probe: 2,
874            metric: DistanceMetric::Cosine,
875            ..Default::default()
876        });
877        index_cosine.train(&vectors).unwrap();
878
879        for (i, v) in vectors.iter().enumerate() {
880            index_cosine.add(format!("vec_{}", i), v.clone()).unwrap();
881        }
882
883        // Query similar to [1, 0, 0]
884        let query = vec![0.95, 0.05, 0.0];
885        let results_cosine = index_cosine.search(&query, 3).unwrap();
886
887        // With cosine, vec_0 [1,0,0] and vec_1 [0.9,0.1,0] should be most similar
888        assert_eq!(results_cosine.len(), 3);
889
890        // Test with Euclidean metric
891        let mut index_euclidean = IvfIndex::new(IvfConfig {
892            n_clusters: 2,
893            n_probe: 2,
894            metric: DistanceMetric::Euclidean,
895            ..Default::default()
896        });
897        index_euclidean.train(&vectors).unwrap();
898
899        for (i, v) in vectors.iter().enumerate() {
900            index_euclidean
901                .add(format!("vec_{}", i), v.clone())
902                .unwrap();
903        }
904
905        let results_euclidean = index_euclidean.search(&query, 3).unwrap();
906        assert_eq!(results_euclidean.len(), 3);
907
908        // Both should return vec_0 or vec_1 as top result for this query
909        let top_cosine = &results_cosine[0].id;
910        let top_euclidean = &results_euclidean[0].id;
911        assert!(
912            top_cosine == "vec_0" || top_cosine == "vec_1",
913            "Cosine top result should be vec_0 or vec_1, got {}",
914            top_cosine
915        );
916        assert!(
917            top_euclidean == "vec_0" || top_euclidean == "vec_1",
918            "Euclidean top result should be vec_0 or vec_1, got {}",
919            top_euclidean
920        );
921    }
922
923    #[test]
924    fn test_ivf_batch_accuracy() {
925        // Test add_batch doesn't affect accuracy
926        let n_vectors = 100;
927        let dim = 32;
928
929        let vectors = generate_random_vectors(n_vectors, dim);
930        let mut index = IvfIndex::new(IvfConfig {
931            n_clusters: 10,
932            n_probe: 5,
933            metric: DistanceMetric::Euclidean,
934            ..Default::default()
935        });
936
937        index.train(&vectors).unwrap();
938
939        // Add all vectors in batch
940        let batch: Vec<(String, Vec<f32>)> = vectors
941            .iter()
942            .enumerate()
943            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
944            .collect();
945
946        let added = index.add_batch(batch.clone()).unwrap();
947        assert_eq!(added, n_vectors);
948        assert_eq!(index.len(), n_vectors);
949
950        // Verify search still works correctly
951        let query = &vectors[0];
952        let results = index.search(query, 5).unwrap();
953
954        assert!(!results.is_empty());
955        // Query itself should be in results
956        assert!(
957            results.iter().any(|r| r.id == "vec_0"),
958            "Query vector should be in search results"
959        );
960    }
961
962    #[test]
963    fn test_ivf_empty_cluster_handling() {
964        // Test behavior when some clusters are empty
965        // Use very few vectors with many clusters
966        let vectors = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0]];
967
968        let mut index = IvfIndex::new(IvfConfig {
969            n_clusters: 3, // Same as vector count
970            n_probe: 3,
971            metric: DistanceMetric::Euclidean,
972            ..Default::default()
973        });
974
975        index.train(&vectors).unwrap();
976
977        for (i, v) in vectors.iter().enumerate() {
978            index.add(format!("vec_{}", i), v.clone()).unwrap();
979        }
980
981        // Search should still work even if probing empty clusters
982        let results = index.search(&vec![0.5, 0.5], 2).unwrap();
983        assert!(!results.is_empty());
984    }
985}