/// A cluster centre of the coarse quantizer.
///
/// `FlatIvfIndex::train` produces one centroid per cell and numbers them by
/// their position in the resulting list.
#[derive(Debug, Clone, PartialEq)]
pub struct Centroid {
    /// Identifier of the centroid; `FlatIvfIndex::search` uses it as the
    /// index of the cell that belongs to this centroid.
    pub id: usize,
    /// Coordinates of the cluster centre.
    pub vector: Vec<f32>,
}
14
/// One inverted list: the vectors assigned to a single centroid.
///
/// `vector_ids` and `vectors` are parallel arrays. `FlatIvfIndex::insert`
/// pushes to both and `FlatIvfIndex::remove` deletes the same position from
/// both, so entry `i` of one always describes entry `i` of the other.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct IvfCell {
    /// Identifier of the centroid this cell belongs to.
    pub centroid_id: usize,
    /// Caller-supplied ids of the stored vectors, in insertion order.
    pub vector_ids: Vec<u64>,
    /// The stored vectors themselves, parallel to `vector_ids`.
    pub vectors: Vec<Vec<f32>>,
}
22
/// A single hit returned by `FlatIvfIndex::search`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Id that was supplied when the matching vector was inserted.
    pub id: u64,
    /// Euclidean (L2) distance between the query and the stored vector.
    pub distance: f32,
}
29
30pub struct FlatIvfIndex {
32 pub dim: usize,
33 pub num_cells: usize,
34 pub cells: Vec<IvfCell>,
35 pub centroids: Vec<Centroid>,
36}
37
38impl FlatIvfIndex {
39 pub fn new(dim: usize, num_cells: usize) -> Self {
41 let cells: Vec<IvfCell> = (0..num_cells)
42 .map(|id| IvfCell {
43 centroid_id: id,
44 vector_ids: Vec::new(),
45 vectors: Vec::new(),
46 })
47 .collect();
48 FlatIvfIndex {
49 dim,
50 num_cells,
51 cells,
52 centroids: Vec::new(),
53 }
54 }
55
56 pub fn train(&mut self, vectors: &[Vec<f32>]) {
62 if vectors.is_empty() || self.num_cells == 0 {
63 return;
64 }
65 let k = self.num_cells.min(vectors.len());
66
67 let mut centroids: Vec<Vec<f32>> = (0..k)
69 .map(|i| {
70 let idx = (i * vectors.len()) / k;
71 vectors[idx].clone()
72 })
73 .collect();
74
75 for _ in 0..20 {
77 let assignments: Vec<usize> = vectors
79 .iter()
80 .map(|v| Self::nearest_centroid_from_list(¢roids, v))
81 .collect();
82
83 let mut new_centroids: Vec<Vec<f32>> = vec![vec![0.0f32; self.dim]; k];
85 let mut counts: Vec<usize> = vec![0; k];
86
87 for (v, &c) in vectors.iter().zip(assignments.iter()) {
88 for (d, x) in new_centroids[c].iter_mut().zip(v.iter()) {
89 *d += x;
90 }
91 counts[c] += 1;
92 }
93
94 let mut converged = true;
95 for c in 0..k {
96 if counts[c] == 0 {
97 new_centroids[c] = centroids[c].clone();
99 } else {
100 for d in new_centroids[c].iter_mut() {
101 *d /= counts[c] as f32;
102 }
103 }
104 let change = Self::l2_distance(¢roids[c], &new_centroids[c]);
105 if change > 1e-6 {
106 converged = false;
107 }
108 }
109 centroids = new_centroids;
110 if converged {
111 break;
112 }
113 }
114
115 self.centroids = centroids
117 .into_iter()
118 .enumerate()
119 .map(|(id, vector)| Centroid { id, vector })
120 .collect();
121
122 self.cells = (0..k)
124 .map(|id| IvfCell {
125 centroid_id: id,
126 vector_ids: Vec::new(),
127 vectors: Vec::new(),
128 })
129 .collect();
130 self.num_cells = k;
131 }
132
133 pub fn insert(&mut self, id: u64, vector: Vec<f32>) {
137 let cell_idx = self.nearest_centroid(&vector);
138 let cell = &mut self.cells[cell_idx];
139 cell.vector_ids.push(id);
140 cell.vectors.push(vector);
141 }
142
143 pub fn search(&self, query: &[f32], k: usize, n_probe: usize) -> Vec<SearchResult> {
147 if self.centroids.is_empty() || k == 0 {
148 return Vec::new();
149 }
150
151 let n_probe = n_probe.min(self.num_cells);
152
153 let mut centroid_dists: Vec<(usize, f32)> = self
155 .centroids
156 .iter()
157 .map(|c| (c.id, Self::l2_distance(query, &c.vector)))
158 .collect();
159 centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
160
161 let mut candidates: Vec<SearchResult> = Vec::new();
163 for (cell_id, _) in centroid_dists.iter().take(n_probe) {
164 let cell = &self.cells[*cell_id];
165 for (vec_id, vec) in cell.vector_ids.iter().zip(cell.vectors.iter()) {
166 let dist = Self::l2_distance(query, vec);
167 candidates.push(SearchResult {
168 id: *vec_id,
169 distance: dist,
170 });
171 }
172 }
173
174 candidates.sort_by(|a, b| {
176 a.distance
177 .partial_cmp(&b.distance)
178 .unwrap_or(std::cmp::Ordering::Equal)
179 });
180 candidates.truncate(k);
181 candidates
182 }
183
184 pub fn remove(&mut self, id: u64) -> bool {
188 for cell in &mut self.cells {
189 if let Some(pos) = cell.vector_ids.iter().position(|&x| x == id) {
190 cell.vector_ids.remove(pos);
191 cell.vectors.remove(pos);
192 return true;
193 }
194 }
195 false
196 }
197
198 pub fn len(&self) -> usize {
202 self.cells.iter().map(|c| c.vector_ids.len()).sum()
203 }
204
205 pub fn is_empty(&self) -> bool {
207 self.len() == 0
208 }
209
210 pub fn nearest_centroid(&self, vec: &[f32]) -> usize {
214 if self.centroids.is_empty() {
215 return 0;
217 }
218 Self::nearest_centroid_from_list(
219 &self
220 .centroids
221 .iter()
222 .map(|c| c.vector.clone())
223 .collect::<Vec<_>>(),
224 vec,
225 )
226 }
227
228 fn nearest_centroid_from_list(centroids: &[Vec<f32>], vec: &[f32]) -> usize {
229 centroids
230 .iter()
231 .enumerate()
232 .map(|(i, c)| (i, Self::l2_distance(vec, c)))
233 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
234 .map(|(i, _)| i)
235 .unwrap_or(0)
236 }
237
238 pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
240 a.iter()
241 .zip(b.iter())
242 .map(|(x, y)| (x - y) * (x - y))
243 .sum::<f32>()
244 .sqrt()
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-dimensional vector with every component set to `val`.
    fn unit_vec(dim: usize, val: f32) -> Vec<f32> {
        vec![val; dim]
    }

    // ---- construction ----

    #[test]
    fn test_new_index() {
        let idx = FlatIvfIndex::new(4, 3);
        assert_eq!(idx.dim, 4);
        assert_eq!(idx.num_cells, 3);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    // ---- training ----

    #[test]
    fn test_train_basic() {
        let mut idx = FlatIvfIndex::new(2, 2);
        let vecs: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];
        idx.train(&vecs);
        assert_eq!(idx.centroids.len(), 2);
        assert_eq!(idx.cells.len(), 2);

        // Each centroid should sit in the middle of its pair of points.
        let expected = [[0.05f32, 0.05], [10.05, 10.05]];
        for (centroid, want) in idx.centroids.iter().zip(expected.iter()) {
            assert!(FlatIvfIndex::l2_distance(&centroid.vector, want) < 1e-4);
        }
    }

    #[test]
    fn test_train_empty() {
        let mut idx = FlatIvfIndex::new(2, 3);
        idx.train(&[]);
        assert!(idx.centroids.is_empty());
        assert_eq!(idx.num_cells, 3);
    }

    #[test]
    fn test_train_fewer_vectors_than_cells() {
        let mut idx = FlatIvfIndex::new(2, 10);
        let vecs = vec![vec![1.0f32, 2.0], vec![3.0, 4.0]];
        idx.train(&vecs);
        // The cluster count is capped by the number of training vectors.
        assert_eq!(idx.centroids.len(), 2);
        assert_eq!(idx.num_cells, 2);
    }

    // ---- insertion ----

    #[test]
    fn test_insert_and_len() {
        let mut idx = FlatIvfIndex::new(2, 2);
        let vecs = vec![vec![0.0f32, 0.0], vec![10.0, 10.0]];
        idx.train(&vecs);
        idx.insert(1, vec![0.0, 0.0]);
        idx.insert(2, vec![10.0, 10.0]);
        assert_eq!(idx.len(), 2);
        assert!(!idx.is_empty());
    }

    #[test]
    fn test_insert_many() {
        let mut idx = FlatIvfIndex::new(1, 3);
        let vecs: Vec<Vec<f32>> = (0..30).map(|i| vec![i as f32]).collect();
        idx.train(&vecs);
        for i in 0u64..30 {
            idx.insert(i, vec![i as f32]);
        }
        assert_eq!(idx.len(), 30);
    }

    // ---- removal ----

    #[test]
    fn test_remove_existing() {
        let mut idx = FlatIvfIndex::new(2, 2);
        idx.train(&[vec![0.0f32, 0.0], vec![5.0, 5.0]]);
        idx.insert(42, vec![0.0, 0.0]);
        assert!(idx.remove(42));
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
        // A second removal finds nothing.
        assert!(!idx.remove(42));
    }

    #[test]
    fn test_remove_nonexistent() {
        let mut idx = FlatIvfIndex::new(2, 2);
        idx.train(&[vec![0.0f32, 0.0], vec![5.0, 5.0]]);
        assert!(!idx.remove(999));
    }

    #[test]
    fn test_remove_and_search() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![10.0]]);
        idx.insert(1, vec![0.0]);
        idx.insert(2, vec![10.0]);
        assert!(idx.remove(1));
        let results = idx.search(&[0.0], 10, 2);
        assert!(!results.iter().any(|r| r.id == 1));
        let ids: Vec<u64> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![2]);
    }

    // ---- search ----

    #[test]
    fn test_search_nearest() {
        let mut idx = FlatIvfIndex::new(1, 2);
        let train_vecs = vec![vec![0.0f32], vec![100.0]];
        idx.train(&train_vecs);
        idx.insert(0, vec![0.0]);
        idx.insert(1, vec![1.0]);
        idx.insert(2, vec![100.0]);
        let results = idx.search(&[0.5], 1, 1);
        assert_eq!(results.len(), 1);
        // Ids 0 and 1 are equally far from the query; either is acceptable.
        assert!(results[0].id == 0 || results[0].id == 1);
        assert!((results[0].distance - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_search_k_results() {
        let mut idx = FlatIvfIndex::new(1, 2);
        let vecs: Vec<Vec<f32>> = vec![vec![0.0], vec![100.0]];
        idx.train(&vecs);
        for i in 0u64..5 {
            idx.insert(i, vec![i as f32]);
        }
        let results = idx.search(&[0.0], 3, 2);
        let ids: Vec<u64> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![0, 1, 2]);
    }

    #[test]
    fn test_search_k_0_returns_empty() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![1.0]]);
        idx.insert(0, vec![0.0]);
        let results = idx.search(&[0.0], 0, 1);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_empty_index() {
        let idx = FlatIvfIndex::new(2, 3);
        let results = idx.search(&[0.0, 0.0], 5, 2);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_n_probe_all_cells() {
        let mut idx = FlatIvfIndex::new(1, 3);
        let train_vecs: Vec<Vec<f32>> = vec![vec![0.0], vec![5.0], vec![10.0]];
        idx.train(&train_vecs);
        idx.insert(0, vec![0.0]);
        idx.insert(1, vec![5.0]);
        idx.insert(2, vec![10.0]);
        let results = idx.search(&[5.0], 3, 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_search_sorted_by_distance() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![10.0]]);
        idx.insert(0, vec![0.0]);
        idx.insert(1, vec![3.0]);
        idx.insert(2, vec![10.0]);
        let results = idx.search(&[0.0], 3, 2);
        assert_eq!(results.len(), 3);
        assert!(results.windows(2).all(|w| w[0].distance <= w[1].distance));
    }

    // ---- distance ----

    #[test]
    fn test_l2_distance_zero() {
        let a = vec![1.0f32, 2.0, 3.0];
        assert!((FlatIvfIndex::l2_distance(&a, &a)).abs() < 1e-6);
    }

    #[test]
    fn test_l2_distance_unit_vector() {
        let a = vec![1.0f32, 0.0];
        let b = vec![0.0f32, 0.0];
        assert!((FlatIvfIndex::l2_distance(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_distance_symmetric() {
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![4.0f32, 5.0, 6.0];
        let d1 = FlatIvfIndex::l2_distance(&a, &b);
        let d2 = FlatIvfIndex::l2_distance(&b, &a);
        assert!((d1 - d2).abs() < 1e-6);
    }

    // ---- centroid lookup ----

    #[test]
    fn test_nearest_centroid_basic() {
        let mut idx = FlatIvfIndex::new(1, 2);
        idx.train(&[vec![0.0f32], vec![100.0]]);
        let near_zero = idx.nearest_centroid(&[1.0]);
        let near_hundred = idx.nearest_centroid(&[99.0]);
        assert_ne!(near_zero, near_hundred);
    }

    // ---- probing ----

    #[test]
    fn test_n_probe_1_vs_all() {
        let mut idx = FlatIvfIndex::new(1, 4);
        let tv: Vec<Vec<f32>> = vec![vec![0.0], vec![10.0], vec![20.0], vec![30.0]];
        idx.train(&tv);
        for i in 0..8u64 {
            idx.insert(i, vec![(i as f32) * 5.0]);
        }
        let r1 = idx.search(&[15.0], 8, 1);
        let r_all = idx.search(&[15.0], 8, 4);
        assert!(!r1.is_empty());
        assert_eq!(r_all.len(), 8);
        assert!(r_all.len() >= r1.len());
        // Probing one cell can only return hits that a full probe also returns.
        assert!(r1.iter().all(|hit| r_all.iter().any(|other| other.id == hit.id)));
    }

    // ---- end to end ----

    #[test]
    fn test_2d_cluster_separation() {
        let mut idx = FlatIvfIndex::new(2, 2);
        let tv = vec![
            vec![0.0f32, 0.0],
            vec![0.5, 0.5],
            vec![100.0, 100.0],
            vec![100.5, 100.5],
        ];
        idx.train(&tv);
        idx.insert(10, vec![0.2, 0.2]);
        idx.insert(11, vec![100.2, 100.2]);

        let results = idx.search(&[0.1, 0.1], 1, 1);
        // The probed cell must contain the nearby vector; an empty result is a failure.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 10);
    }

    #[test]
    fn test_exact_match() {
        let mut idx = FlatIvfIndex::new(3, 2);
        idx.train(&[vec![1.0f32, 2.0, 3.0], vec![10.0, 20.0, 30.0]]);
        idx.insert(99, unit_vec(3, 5.0));
        let query = unit_vec(3, 5.0);
        let results = idx.search(&query, 1, 2);
        assert!(!results.is_empty());
        assert!((results[0].distance).abs() < 1e-5);
        assert_eq!(results[0].id, 99);
    }
}