// oxirs_vec/flat_ivf_index.rs

//! Flat IVF (Inverted File Index) for approximate nearest-neighbour search.
//!
//! Uses k-means clustering (Lloyd's algorithm) to partition the vector space
//! into `num_cells` Voronoi cells.  Search probes the `n_probe` nearest cells.

// ── Types ─────────────────────────────────────────────────────────────────────

/// A cluster centroid.
#[derive(Debug, Clone)]
pub struct Centroid {
    /// Identifier of the centroid.  `FlatIvfIndex::train` sets it to the
    /// centroid's position in `FlatIvfIndex::centroids`, and
    /// `FlatIvfIndex::search` uses it to index `FlatIvfIndex::cells`.
    pub id: usize,
    /// Coordinates of the centroid.
    pub vector: Vec<f32>,
}

/// One inverted-file cell: the centroid id, plus all vectors assigned to it.
///
/// `vector_ids` and `vectors` are parallel arrays: entry `i` of one belongs to
/// entry `i` of the other.
#[derive(Debug, Clone, Default)]
pub struct IvfCell {
    /// Id of the centroid this cell belongs to.
    pub centroid_id: usize,
    /// Caller-supplied ids of the stored vectors.
    pub vector_ids: Vec<u64>,
    /// The stored vectors themselves.
    pub vectors: Vec<Vec<f32>>,
}

/// A single search result.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Id that was passed to `FlatIvfIndex::insert`.
    pub id: u64,
    /// Euclidean distance to the query, as computed by `FlatIvfIndex::l2_distance`.
    pub distance: f32,
}

/// Flat IVF index.
#[derive(Debug, Clone)]
pub struct FlatIvfIndex {
    /// Dimension used for the centroid accumulators during training.
    pub dim: usize,
    /// Number of cells.  `train` lowers it when fewer training vectors than
    /// cells are supplied.
    pub num_cells: usize,
    /// Inverted lists, one per cell.
    pub cells: Vec<IvfCell>,
    /// Trained centroids; empty until `train` has run successfully.
    pub centroids: Vec<Centroid>,
}

impl FlatIvfIndex {
    /// Upper bound on the number of Lloyd iterations performed by `train`.
    const KMEANS_MAX_ITERATIONS: usize = 20;

    /// k-means counts as converged once no centroid moves farther than this
    /// in a single iteration.
    const KMEANS_TOLERANCE: f32 = 1e-6;

    /// Create an untrained index with `num_cells` cells and vectors of dimension `dim`.
    pub fn new(dim: usize, num_cells: usize) -> Self {
        FlatIvfIndex {
            dim,
            num_cells,
            cells: Self::empty_cells(num_cells),
            centroids: Vec::new(),
        }
    }

    // ── Training ──────────────────────────────────────────────────────────

    /// Train the index using Lloyd's k-means (up to 20 iterations) on `vectors`.
    ///
    /// Centroids are initialised by selecting evenly-spaced samples from the input.
    ///
    /// Does nothing when `vectors` is empty or the index has zero cells.
    /// Otherwise:
    ///
    /// * the number of cells becomes `min(num_cells, vectors.len())`;
    /// * every cell is reset, so vectors inserted before this call are
    ///   discarded and must be inserted again.
    pub fn train(&mut self, vectors: &[Vec<f32>]) {
        if vectors.is_empty() || self.num_cells == 0 {
            return;
        }
        let k = self.num_cells.min(vectors.len());

        // --- Initialise centroids from evenly-spaced samples ---
        let mut centroids: Vec<Vec<f32>> = (0..k)
            .map(|i| vectors[(i * vectors.len()) / k].clone())
            .collect();

        // --- Lloyd's iterations ---
        for _ in 0..Self::KMEANS_MAX_ITERATIONS {
            // Accumulate per-cluster sums and member counts.  `centroids` is
            // not modified inside this loop, so assigning and accumulating in
            // one pass is equivalent to assigning everything first.
            let mut sums: Vec<Vec<f32>> = vec![vec![0.0f32; self.dim]; k];
            let mut counts: Vec<usize> = vec![0; k];
            for v in vectors {
                let c = Self::nearest_centroid_from_list(&centroids, v);
                for (acc, x) in sums[c].iter_mut().zip(v) {
                    *acc += *x;
                }
                counts[c] += 1;
            }

            // Turn sums into means and measure how far each centroid moved.
            let mut converged = true;
            for ((sum, &count), old) in sums.iter_mut().zip(&counts).zip(&centroids) {
                if count == 0 {
                    // Keep old centroid if empty cluster.
                    *sum = old.clone();
                } else {
                    let n = count as f32;
                    for d in sum.iter_mut() {
                        *d /= n;
                    }
                }
                if Self::l2_distance(old.as_slice(), sum.as_slice()) > Self::KMEANS_TOLERANCE {
                    converged = false;
                }
            }
            centroids = sums;
            if converged {
                break;
            }
        }

        // Store trained centroids.
        self.centroids = centroids
            .into_iter()
            .enumerate()
            .map(|(id, vector)| Centroid { id, vector })
            .collect();

        // Reset cells so that they line up with the new centroids.
        self.cells = Self::empty_cells(k);
        self.num_cells = k;
    }

    // ── Insertion ─────────────────────────────────────────────────────────

    /// Insert a vector with the given id into the nearest cell.
    ///
    /// On an untrained index every vector goes to cell 0; such vectors are
    /// counted by `len` but are not returned by `search`, which yields nothing
    /// until the index is trained.
    ///
    /// # Panics
    ///
    /// Panics if the chosen cell does not exist, which happens when the index
    /// was created with zero cells.
    pub fn insert(&mut self, id: u64, vector: Vec<f32>) {
        let cell_idx = self.nearest_centroid(&vector);
        let cell = &mut self.cells[cell_idx];
        cell.vector_ids.push(id);
        cell.vectors.push(vector);
    }

    // ── Search ────────────────────────────────────────────────────────────

    /// Search for the `k` nearest neighbours of `query`, probing `n_probe` cells.
    ///
    /// Returns at most `k` results sorted by ascending distance; results whose
    /// distance is NaN are ordered last.  Returns an empty vector when the
    /// index is untrained or `k == 0`.
    pub fn search(&self, query: &[f32], k: usize, n_probe: usize) -> Vec<SearchResult> {
        if self.centroids.is_empty() || k == 0 {
            return Vec::new();
        }

        let n_probe = n_probe.min(self.num_cells);

        // Rank centroids by distance to the query.
        let mut centroid_dists: Vec<(usize, f32)> = self
            .centroids
            .iter()
            .map(|c| (c.id, Self::l2_distance(query, &c.vector)))
            .collect();
        centroid_dists.sort_by(|a, b| Self::cmp_distance(a.1, b.1));

        // Collect candidates from the top `n_probe` cells.
        let mut candidates: Vec<SearchResult> = Vec::new();
        for &(cell_id, _) in centroid_dists.iter().take(n_probe) {
            let cell = &self.cells[cell_id];
            candidates.extend(cell.vector_ids.iter().zip(&cell.vectors).map(|(&id, v)| {
                SearchResult {
                    id,
                    distance: Self::l2_distance(query, v),
                }
            }));
        }

        // Stable sort by distance (ties keep probe/insertion order), then keep top-k.
        candidates.sort_by(|a, b| Self::cmp_distance(a.distance, b.distance));
        candidates.truncate(k);
        candidates
    }

    // ── Removal ───────────────────────────────────────────────────────────

    /// Remove the vector with `id` from the index. Returns `true` if found.
    ///
    /// Only the first occurrence (in cell order) is removed when the same id
    /// was inserted more than once.
    pub fn remove(&mut self, id: u64) -> bool {
        for cell in &mut self.cells {
            if let Some(pos) = cell.vector_ids.iter().position(|&x| x == id) {
                // Remove from both parallel arrays to keep them aligned.
                cell.vector_ids.remove(pos);
                cell.vectors.remove(pos);
                return true;
            }
        }
        false
    }

    // ── Metadata ──────────────────────────────────────────────────────────

    /// Total number of vectors in the index.
    pub fn len(&self) -> usize {
        self.cells.iter().map(|c| c.vector_ids.len()).sum()
    }

    /// Returns `true` if no vectors are stored.
    pub fn is_empty(&self) -> bool {
        self.cells.iter().all(|c| c.vector_ids.is_empty())
    }

    // ── Internal helpers ──────────────────────────────────────────────────

    /// Return the index of the nearest cell for `vec`.
    ///
    /// An untrained index has no centroids to compare against, so this
    /// returns cell 0 in that case.
    pub fn nearest_centroid(&self, vec: &[f32]) -> usize {
        if self.centroids.is_empty() {
            return 0;
        }
        // Borrow the centroid vectors instead of cloning them on every call.
        Self::nearest_index(self.centroids.iter().map(|c| c.vector.as_slice()), vec)
    }

    /// Index of the entry of `centroids` closest to `vec` (0 if the list is empty).
    fn nearest_centroid_from_list(centroids: &[Vec<f32>], vec: &[f32]) -> usize {
        Self::nearest_index(centroids.iter().map(|c| c.as_slice()), vec)
    }

    /// Position of the slice yielded by `centroids` that is closest to `vec`.
    ///
    /// Ties resolve to the earliest position; an empty iterator yields 0.
    fn nearest_index<'a, I>(centroids: I, vec: &[f32]) -> usize
    where
        I: Iterator<Item = &'a [f32]>,
    {
        centroids
            .enumerate()
            .map(|(i, c)| (i, Self::l2_distance(vec, c)))
            .min_by(|a, b| Self::cmp_distance(a.1, b.1))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Total ordering on distances: ascending, with NaN after every number.
    ///
    /// Treating NaN as "equal to everything" is not a consistent order, which
    /// can scramble sort results (and the standard sort may panic on such a
    /// comparator), so NaN is ranked last instead.
    fn cmp_distance(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        }
    }

    /// Build `count` empty cells whose `centroid_id` equals their position.
    fn empty_cells(count: usize) -> Vec<IvfCell> {
        (0..count)
            .map(|id| IvfCell {
                centroid_id: id,
                ..IvfCell::default()
            })
            .collect()
    }

    /// Compute the Euclidean (L2) distance between two slices.
    ///
    /// This is the square root of the sum of squared differences, not the
    /// squared distance.  If the slices differ in length, only the leading
    /// components they have in common are compared.
    pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    //! Unit tests for `FlatIvfIndex`: construction, training, insertion,
    //! removal, search and the distance helper.

    use super::*;

    // ── Construction ──────────────────────────────────────────────────────

    #[test]
    fn test_new_index() {
        let idx = FlatIvfIndex::new(4, 3);
        assert_eq!(idx.dim, 4);
        assert_eq!(idx.num_cells, 3);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    // ── Training ──────────────────────────────────────────────────────────

    #[test]
    fn test_train_basic() {
        let mut idx = FlatIvfIndex::new(2, 2);
        let vecs: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];
        idx.train(&vecs);
        assert_eq!(idx.centroids.len(), 2);
    }

    #[test]
    fn test_train_empty() {
        let mut idx = FlatIvfIndex::new(2, 3);
        idx.train(&[]);
        assert!(idx.centroids.is_empty());
    }

    #[test]
    fn test_train_fewer_vectors_than_cells() {
        let mut idx = FlatIvfIndex::new(2, 10);
        let vecs = vec![vec![1.0f32, 2.0], vec![3.0, 4.0]];
        idx.train(&vecs);
        // The cell count is clamped to the number of training vectors.
        assert_eq!(idx.centroids.len(), 2);
        assert_eq!(idx.num_cells, 2);
    }

    // ── Insert / len / is_empty ───────────────────────────────────────────

    #[test]
    fn test_insert_and_len() {
        let mut idx = FlatIvfIndex::new(2, 2);
        let vecs = vec![vec![0.0f32, 0.0], vec![10.0, 10.0]];
        idx.train(&vecs);
        idx.insert(1, vec![0.0, 0.0]);
        idx.insert(2, vec![10.0, 10.0]);
        assert_eq!(idx.len(), 2);
        assert!(!idx.is_empty());
    }

    #[test]
    fn test_insert_many() {
        let mut idx = FlatIvfIndex::new(1, 3);
        let vecs: Vec<Vec<f32>> = (0..30).map(|i| vec![i as f32]).collect();
        idx.train(&vecs);
        for i in 0u64..30 {
            idx.insert(i, vec![i as f32]);
        }
        assert_eq!(idx.len(), 30);
    }

    // ── Remove ────────────────────────────────────────────────────────────

    #[test]
    fn test_remove_existing() {
        let mut idx = FlatIvfIndex::new(2, 2);
        idx.train(&[vec![0.0f32, 0.0], vec![5.0, 5.0]]);
        idx.insert(42, vec![0.0, 0.0]);
        assert!(idx.remove(42));
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn test_remove_nonexistent() {
        let mut idx = FlatIvfIndex::new(2, 2);
        idx.train(&[vec![0.0f32, 0.0], vec![5.0, 5.0]]);
        assert!(!idx.remove(999));
    }

    #[test]
    fn test_remove_and_search() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![10.0]]);
        idx.insert(1, vec![0.0]);
        idx.insert(2, vec![10.0]);
        idx.remove(1);
        let results = idx.search(&[0.0], 10, 2);
        assert!(!results.iter().any(|r| r.id == 1));
    }

    // ── Search ────────────────────────────────────────────────────────────

    #[test]
    fn test_search_nearest() {
        let mut idx = FlatIvfIndex::new(1, 2);
        let train_vecs = vec![vec![0.0f32], vec![100.0]];
        idx.train(&train_vecs);
        idx.insert(0, vec![0.0]);
        idx.insert(1, vec![1.0]);
        idx.insert(2, vec![100.0]);
        let results = idx.search(&[0.5], 1, 1);
        assert_eq!(results.len(), 1);
        // Either 0 or 1 is nearest; both are equally close to 0.5.
        assert!(results[0].id == 0 || results[0].id == 1);
    }

    #[test]
    fn test_search_k_results() {
        let mut idx = FlatIvfIndex::new(1, 2);
        let vecs: Vec<Vec<f32>> = vec![vec![0.0], vec![100.0]];
        idx.train(&vecs);
        for i in 0u64..5 {
            idx.insert(i, vec![i as f32]);
        }
        // Five candidates are available across both cells, so exactly `k` come back.
        let results = idx.search(&[0.0], 3, 2);
        assert_eq!(results.len(), 3);
        let ids: Vec<u64> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![0, 1, 2]);
    }

    #[test]
    fn test_search_k_0_returns_empty() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![1.0]]);
        idx.insert(0, vec![0.0]);
        let results = idx.search(&[0.0], 0, 1);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_empty_index() {
        let idx = FlatIvfIndex::new(2, 3);
        let results = idx.search(&[0.0, 0.0], 5, 2);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_n_probe_all_cells() {
        let mut idx = FlatIvfIndex::new(1, 3);
        let train_vecs: Vec<Vec<f32>> = vec![vec![0.0], vec![5.0], vec![10.0]];
        idx.train(&train_vecs);
        idx.insert(0, vec![0.0]);
        idx.insert(1, vec![5.0]);
        idx.insert(2, vec![10.0]);
        let results = idx.search(&[5.0], 3, 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_search_sorted_by_distance() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![10.0]]);
        idx.insert(0, vec![0.0]);
        idx.insert(1, vec![3.0]);
        idx.insert(2, vec![10.0]);
        let results = idx.search(&[0.0], 3, 2);
        for pair in results.windows(2) {
            assert!(pair[0].distance <= pair[1].distance);
        }
    }

    // ── l2_distance ───────────────────────────────────────────────────────

    #[test]
    fn test_l2_distance_zero() {
        let a = vec![1.0f32, 2.0, 3.0];
        assert!((FlatIvfIndex::l2_distance(&a, &a)).abs() < 1e-6);
    }

    #[test]
    fn test_l2_distance_unit_vector() {
        let a = vec![1.0f32, 0.0];
        let b = vec![0.0f32, 0.0];
        assert!((FlatIvfIndex::l2_distance(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_distance_symmetric() {
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![4.0f32, 5.0, 6.0];
        let d1 = FlatIvfIndex::l2_distance(&a, &b);
        let d2 = FlatIvfIndex::l2_distance(&b, &a);
        assert!((d1 - d2).abs() < 1e-6);
    }

    // ── nearest_centroid ──────────────────────────────────────────────────

    #[test]
    fn test_nearest_centroid_basic() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![100.0]]);
        let near_zero = idx.nearest_centroid(&[1.0]);
        let near_hundred = idx.nearest_centroid(&[99.0]);
        assert_ne!(near_zero, near_hundred);
    }

    // ── n_probe variation ─────────────────────────────────────────────────

    #[test]
    fn test_n_probe_1_vs_all() {
        let mut idx = FlatIvfIndex::new(1, 4);
        let tv: Vec<Vec<f32>> = vec![vec![0.0], vec![10.0], vec![20.0], vec![30.0]];
        idx.train(&tv);
        for i in 0..8u64 {
            idx.insert(i, vec![(i as f32) * 5.0]);
        }
        let r1 = idx.search(&[15.0], 8, 1);
        let r_all = idx.search(&[15.0], 8, 4);
        // Probing all cells should find at least as many results.
        assert!(r_all.len() >= r1.len());
    }

    // ── 2D vectors ────────────────────────────────────────────────────────

    #[test]
    fn test_2d_cluster_separation() {
        let mut idx = FlatIvfIndex::new(2, 2);
        let tv = vec![
            vec![0.0f32, 0.0],
            vec![0.5, 0.5],
            vec![100.0, 100.0],
            vec![100.5, 100.5],
        ];
        idx.train(&tv);
        idx.insert(10, vec![0.2, 0.2]);
        idx.insert(11, vec![100.2, 100.2]);

        // The query lies in the low cluster, whose cell holds exactly id 10.
        let results = idx.search(&[0.1, 0.1], 1, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 10);
    }

    #[test]
    fn test_exact_match() {
        let mut idx = FlatIvfIndex::new(3, 2);
        idx.train(&[vec![1.0f32, 2.0, 3.0], vec![10.0, 20.0, 30.0]]);
        idx.insert(99, vec![5.0, 5.0, 5.0]);
        let query = vec![5.0f32, 5.0, 5.0];
        let results = idx.search(&query, 1, 2);
        assert!(!results.is_empty());
        assert!((results[0].distance).abs() < 1e-5);
        assert_eq!(results[0].id, 99);
    }
}