1use parking_lot::RwLock;
19use rustc_hash::FxHashMap;
20use usearch::MetricKind;
21
/// Distance function the index uses to compare vectors.
///
/// For every variant a smaller value means "closer".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Cosine distance, `1 - cos(a, b)`.
    Cos,
    /// Inner-product distance, `1 - dot(a, b)`.
    Ip,
    /// Squared Euclidean distance.
    L2sq,
}
34
35impl Metric {
36 pub const fn from_usearch(m: MetricKind) -> Self {
39 match m {
40 MetricKind::IP => Metric::Ip,
41 MetricKind::L2sq => Metric::L2sq,
42 _ => Metric::Cos,
43 }
44 }
45}
46
/// Dot product of `a` and `b`.
///
/// Elements are paired with `zip`, so when the slices differ in length the
/// surplus tail of the longer one is ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&x, &y)| x * y);
    products.sum()
}
51
/// Squared Euclidean distance between `a` and `b`.
///
/// Elements are paired with `zip`, so when the slices differ in length the
/// surplus tail of the longer one is ignored.
#[inline]
fn l2sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
56
57#[inline]
58fn norm(a: &[f32]) -> f32 {
59 dot(a, a).sqrt()
60}
61
62struct Inner {
63 dims: usize,
64 metric: Metric,
65 ids: Vec<u64>,
67 vecs: Vec<f32>,
69 norms: Vec<f32>,
71 assign: Vec<i32>,
73 id_pos: FxHashMap<u64, usize>,
75 centroids: Vec<f32>,
77 lists: Vec<Vec<usize>>,
79}
80
81impl Inner {
82 #[inline]
83 fn row(&self, pos: usize) -> &[f32] {
84 &self.vecs[pos * self.dims..(pos + 1) * self.dims]
85 }
86
87 #[inline]
88 fn centroid(&self, c: usize) -> &[f32] {
89 &self.centroids[c * self.dims..(c + 1) * self.dims]
90 }
91
92 #[inline]
94 fn dist_to_row(&self, q: &[f32], q_norm: f32, pos: usize) -> f32 {
95 let v = self.row(pos);
96 match self.metric {
97 Metric::Cos => {
98 let denom = q_norm * self.norms[pos];
99 if denom == 0.0 {
100 1.0
101 } else {
102 1.0 - dot(q, v) / denom
103 }
104 }
105 Metric::Ip => 1.0 - dot(q, v),
106 Metric::L2sq => l2sq(q, v),
107 }
108 }
109
110 #[inline]
113 fn dist_to_centroid(&self, q: &[f32], q_norm: f32, c: usize) -> f32 {
114 let v = self.centroid(c);
115 match self.metric {
116 Metric::Cos => {
117 let denom = q_norm * norm(v);
118 if denom == 0.0 {
119 1.0
120 } else {
121 1.0 - dot(q, v) / denom
122 }
123 }
124 Metric::Ip => 1.0 - dot(q, v),
125 Metric::L2sq => l2sq(q, v),
126 }
127 }
128
129 fn nearest_centroid(&self, v: &[f32], v_norm: f32) -> i32 {
130 let mut best = -1i32;
131 let mut best_d = f32::INFINITY;
132 for c in 0..self.lists.len() {
133 let d = self.dist_to_centroid(v, v_norm, c);
134 if d < best_d {
135 best_d = d;
136 best = c as i32;
137 }
138 }
139 best
140 }
141}
142
143pub struct IvfFlatIndex {
147 dims: usize,
148 metric: Metric,
149 nlist: usize,
152 nprobe: usize,
154 inner: RwLock<Inner>,
155}
156
157impl IvfFlatIndex {
158 pub fn new(dims: usize, metric: Metric, nlist: usize, nprobe: usize) -> Self {
159 let nlist = nlist.max(1);
160 let nprobe = nprobe.clamp(1, nlist);
161 Self {
162 dims,
163 metric,
164 nlist,
165 nprobe,
166 inner: RwLock::new(Inner {
167 dims,
168 metric,
169 ids: Vec::new(),
170 vecs: Vec::new(),
171 norms: Vec::new(),
172 assign: Vec::new(),
173 id_pos: FxHashMap::default(),
174 centroids: Vec::new(),
175 lists: Vec::new(),
176 }),
177 }
178 }
179
180 pub fn len(&self) -> usize {
181 self.inner.read().ids.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
185 self.len() == 0
186 }
187
188 pub fn is_trained(&self) -> bool {
189 !self.inner.read().centroids.is_empty()
190 }
191
192 pub const fn metric(&self) -> Metric {
193 self.metric
194 }
195
196 pub const fn nlist(&self) -> usize {
197 self.nlist
198 }
199
200 pub const fn nprobe(&self) -> usize {
201 self.nprobe
202 }
203
204 pub fn centroid_count(&self) -> usize {
206 self.inner.read().lists.len()
207 }
208
209 pub fn memory_bytes(&self) -> usize {
211 let g = self.inner.read();
212 g.vecs.len() * 4
213 + g.norms.len() * 4
214 + g.centroids.len() * 4
215 + g.assign.len() * 4
216 + g.ids.len() * 8
217 + g.id_pos.len() * 16
218 + g.lists.iter().map(|l| l.len() * 8).sum::<usize>()
219 }
220
221 pub fn upsert(&self, id: u64, v: &[f32]) -> Result<bool, String> {
224 if v.len() != self.dims {
225 return Err(format!(
226 "dimension mismatch: got {}, expected {}",
227 v.len(),
228 self.dims
229 ));
230 }
231 let mut g = self.inner.write();
232 let existed = g.id_pos.contains_key(&id);
233 if existed {
234 remove_locked(&mut g, id);
235 }
236
237 let pos = g.ids.len();
238 g.ids.push(id);
239 g.vecs.extend_from_slice(v);
240 g.norms.push(norm(v));
241 g.id_pos.insert(id, pos);
242
243 if !g.centroids.is_empty() {
247 let n = g.norms[pos];
248 let c = g.nearest_centroid(v, n);
249 g.assign.push(c);
250 if c >= 0 {
251 g.lists[c as usize].push(pos);
252 }
253 } else {
254 g.assign.push(-1);
255 }
256 Ok(existed)
257 }
258
259 pub fn remove(&self, id: u64) -> bool {
261 let mut g = self.inner.write();
262 if !g.id_pos.contains_key(&id) {
263 return false;
264 }
265 remove_locked(&mut g, id);
266 true
267 }
268
269 pub fn search(
273 &self,
274 query: &[f32],
275 top_k: usize,
276 nprobe_override: Option<usize>,
277 ) -> Result<Vec<(u64, f32)>, String> {
278 if query.len() != self.dims {
279 return Err(format!(
280 "dimension mismatch: got {}, expected {}",
281 query.len(),
282 self.dims
283 ));
284 }
285 if top_k == 0 {
286 return Ok(Vec::new());
287 }
288 let g = self.inner.read();
289 if g.ids.is_empty() {
290 return Ok(Vec::new());
291 }
292 let q_norm = norm(query);
293
294 let candidates: Vec<usize> = if g.centroids.is_empty() {
297 (0..g.ids.len()).collect()
298 } else {
299 let nprobe = nprobe_override.unwrap_or(self.nprobe).clamp(1, g.lists.len());
300 let mut cd: Vec<(usize, f32)> = (0..g.lists.len())
302 .map(|c| (c, g.dist_to_centroid(query, q_norm, c)))
303 .collect();
304 cd.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
305 let mut cand = Vec::new();
306 for &(c, _) in cd.iter().take(nprobe) {
307 cand.extend_from_slice(&g.lists[c]);
308 }
309 cand
310 };
311
312 let mut scored: Vec<(u64, f32)> = candidates
313 .into_iter()
314 .map(|pos| (g.ids[pos], g.dist_to_row(query, q_norm, pos)))
315 .collect();
316 scored.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
317 scored.truncate(top_k);
318 Ok(scored)
319 }
320
321 pub fn train(&self) -> Result<(), String> {
325 let mut g = self.inner.write();
326 let n = g.ids.len();
327 if n == 0 {
328 g.centroids.clear();
329 g.lists.clear();
330 return Ok(());
331 }
332 let k = self.nlist.min(n);
333 let dims = self.dims;
334
335 let mut centroids: Vec<f32> = Vec::with_capacity(k * dims);
338 centroids.extend_from_slice(g.row(0));
339 let mut min_d: Vec<f32> = (0..n)
340 .map(|p| dist_rows(&g, g.row(p), ¢roids[0..dims]))
341 .collect();
342 while centroids.len() / dims < k {
343 let mut far = 0usize;
345 let mut far_d = -1.0f32;
346 for (p, &d) in min_d.iter().enumerate() {
347 if d > far_d {
348 far_d = d;
349 far = p;
350 }
351 }
352 let start = centroids.len();
353 centroids.extend_from_slice(g.row(far));
354 let new_c = ¢roids[start..start + dims];
355 for (p, slot) in min_d.iter_mut().enumerate() {
356 let d = dist_rows(&g, g.row(p), new_c);
357 if d < *slot {
358 *slot = d;
359 }
360 }
361 }
362
363 let mut assign = vec![0i32; n];
365 for _ in 0..IVF_KMEANS_ITERS {
366 let mut changed = false;
368 for (p, a) in assign.iter_mut().enumerate() {
369 let row = g.row(p);
370 let mut best = 0usize;
371 let mut best_d = f32::INFINITY;
372 for c in 0..k {
373 let d = dist_rows(&g, row, ¢roids[c * dims..(c + 1) * dims]);
374 if d < best_d {
375 best_d = d;
376 best = c;
377 }
378 }
379 if *a != best as i32 {
380 *a = best as i32;
381 changed = true;
382 }
383 }
384 let mut sums = vec![0f32; k * dims];
386 let mut counts = vec![0usize; k];
387 for (p, &c_raw) in assign.iter().enumerate() {
388 let c = c_raw as usize;
389 counts[c] += 1;
390 let row = g.row(p);
391 let base = c * dims;
392 for (j, &x) in row.iter().enumerate() {
393 sums[base + j] += x;
394 }
395 }
396 for (c, &cnt) in counts.iter().enumerate() {
397 if cnt == 0 {
398 continue;
399 }
400 let inv = 1.0 / cnt as f32;
401 let base = c * dims;
402 for (j, slot) in centroids[base..base + dims].iter_mut().enumerate() {
403 *slot = sums[base + j] * inv;
404 }
405 }
406 if !changed {
407 break;
408 }
409 }
410
411 let mut lists: Vec<Vec<usize>> = vec![Vec::new(); k];
413 for (p, &c) in assign.iter().enumerate() {
414 lists[c as usize].push(p);
415 }
416 g.centroids = centroids;
417 g.lists = lists;
418 g.assign = assign;
419 Ok(())
420 }
421
422 pub fn bulk_load(&self, items: impl IntoIterator<Item = (u64, Vec<f32>)>) -> Result<(), String> {
425 let mut g = self.inner.write();
426 for (id, v) in items {
427 if v.len() != self.dims {
428 return Err(format!(
429 "dimension mismatch: got {}, expected {}",
430 v.len(),
431 self.dims
432 ));
433 }
434 let pos = g.ids.len();
435 g.ids.push(id);
436 g.vecs.extend_from_slice(&v);
437 g.norms.push(norm(&v));
438 g.assign.push(-1);
439 g.id_pos.insert(id, pos);
440 }
441 Ok(())
442 }
443}
444
445#[inline]
447fn dist_rows(inner: &Inner, a: &[f32], b: &[f32]) -> f32 {
448 match inner.metric {
449 Metric::Cos => {
450 let denom = norm(a) * norm(b);
451 if denom == 0.0 {
452 1.0
453 } else {
454 1.0 - dot(a, b) / denom
455 }
456 }
457 Metric::Ip => 1.0 - dot(a, b),
458 Metric::L2sq => l2sq(a, b),
459 }
460}
461
462fn remove_locked(g: &mut Inner, id: u64) {
465 let dims = g.dims;
466 let pos = g.id_pos[&id];
467 let last = g.ids.len() - 1;
468
469 let c_pos = g.assign[pos];
471 if c_pos >= 0 {
472 let list = &mut g.lists[c_pos as usize];
473 if let Some(i) = list.iter().position(|&p| p == pos) {
474 list.swap_remove(i);
475 }
476 }
477
478 if pos != last {
479 let moved_id = g.ids[last];
481 let moved_c = g.assign[last];
482 g.ids.swap_remove(pos);
483 g.assign.swap_remove(pos);
484 g.norms.swap_remove(pos);
485 let (head, tail) = g.vecs.split_at_mut(last * dims);
487 head[pos * dims..(pos + 1) * dims].copy_from_slice(&tail[..dims]);
488 g.vecs.truncate(last * dims);
489
490 g.id_pos.insert(moved_id, pos);
491 if moved_c >= 0 {
493 let list = &mut g.lists[moved_c as usize];
494 if let Some(i) = list.iter().position(|&p| p == last) {
495 list[i] = pos;
496 }
497 }
498 } else {
499 g.ids.pop();
500 g.assign.pop();
501 g.norms.pop();
502 g.vecs.truncate(last * dims);
503 }
504 g.id_pos.remove(&id);
505}
506
/// Upper bound on the number of k-means assign/update rounds run by `train`.
const IVF_KMEANS_ITERS: usize = 15;
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Searching an index with no vectors yields no hits and reports untrained.
    #[test]
    fn empty_search_returns_nothing() {
        let index = IvfFlatIndex::new(3, Metric::L2sq, 4, 2);
        let hits = index.search(&[1.0, 0.0, 0.0], 5, None).unwrap();
        assert!(hits.is_empty());
        assert_eq!(index.len(), 0);
        assert!(!index.is_trained());
    }

    /// Before training, search is an exhaustive scan and therefore exact.
    #[test]
    fn brute_force_before_training_is_exact() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 8, 2);
        for (id, point) in [(1u64, [0.0f32, 0.0]), (2, [10.0, 10.0]), (3, [1.0, 1.0])] {
            index.upsert(id, &point).unwrap();
        }
        let hits = index.search(&[0.0, 0.0], 2, None).unwrap();
        assert_eq!(hits[0].0, 1);
        assert_eq!(hits[1].0, 3);
    }

    /// With two well-separated clusters, a query near one only returns its members.
    #[test]
    fn trained_search_finds_cluster_members() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 2, 2);
        for i in 0..10u64 {
            index.upsert(i, &[i as f32 * 0.01, 0.0]).unwrap();
        }
        for i in 10..20u64 {
            index.upsert(i, &[100.0 + i as f32 * 0.01, 100.0]).unwrap();
        }
        index.train().unwrap();
        assert!(index.is_trained());
        assert_eq!(index.centroid_count(), 2);

        let hits = index.search(&[0.0, 0.0], 3, None).unwrap();
        for &(id, _) in &hits {
            assert!(id < 10, "unexpected id {id} from far cluster");
        }
    }

    /// Upserting an existing id replaces its vector instead of adding a row.
    #[test]
    fn upsert_replaces_and_counts() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 4, 2);
        let first = index.upsert(1, &[0.0, 0.0]).unwrap();
        assert!(!first);
        let second = index.upsert(1, &[5.0, 5.0]).unwrap();
        assert!(second);
        assert_eq!(index.len(), 1);

        let hits = index.search(&[5.0, 5.0], 1, None).unwrap();
        assert_eq!(hits[0].0, 1);
        assert!(hits[0].1 < 0.001, "distance to exact match should be ~0");
    }

    /// Removing a row from a trained index leaves the remaining rows searchable.
    #[test]
    fn remove_keeps_index_consistent() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 3, 3);
        for i in 0..6u64 {
            index.upsert(i, &[i as f32, 0.0]).unwrap();
        }
        index.train().unwrap();

        assert!(index.remove(2));
        assert!(!index.remove(2));
        assert_eq!(index.len(), 5);

        let hits = index.search(&[5.0, 0.0], 6, None).unwrap();
        let ids: Vec<u64> = hits.iter().map(|&(id, _)| id).collect();
        assert!(!ids.contains(&2));
        assert!(ids.contains(&5));
        assert_eq!(ids.len(), 5);
    }

    /// A vector upserted after training can be found without retraining.
    #[test]
    fn add_after_training_is_findable() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 2, 2);
        for i in 0..8u64 {
            index.upsert(i, &[i as f32, 0.0]).unwrap();
        }
        index.train().unwrap();
        index.upsert(99, &[3.5, 0.0]).unwrap();

        let hits = index.search(&[3.5, 0.0], 1, None).unwrap();
        assert_eq!(hits[0].0, 99);
    }

    /// Cosine distance ignores magnitude: parallel vectors rank ahead of orthogonal ones.
    #[test]
    fn cosine_metric_ranks_by_direction() {
        let index = IvfFlatIndex::new(2, Metric::Cos, 4, 4);
        index.upsert(1, &[1.0, 0.0]).unwrap();
        index.upsert(2, &[0.0, 1.0]).unwrap();
        index.upsert(3, &[10.0, 0.0]).unwrap();

        let hits = index.search(&[2.0, 0.0], 3, None).unwrap();
        // Ids 1 and 3 are both parallel to the query, so their order is not fixed.
        assert!(hits[0].0 == 1 || hits[0].0 == 3);
        assert!(hits[1].0 == 1 || hits[1].0 == 3);
        assert_eq!(hits[2].0, 2);
    }

    /// Vectors of the wrong length are rejected by both upsert and search.
    #[test]
    fn dimension_mismatch_errors() {
        let index = IvfFlatIndex::new(3, Metric::L2sq, 2, 2);
        assert!(index.upsert(1, &[1.0, 2.0]).is_err());
        assert!(index.search(&[1.0, 2.0], 1, None).is_err());
    }

    /// Training a second time after more inserts keeps the index searchable.
    #[test]
    fn retrain_after_many_inserts() {
        let index = IvfFlatIndex::new(4, Metric::L2sq, 4, 4);
        for i in 0..50u64 {
            index.upsert(i, &[i as f32, 0.0, 0.0, 0.0]).unwrap();
        }
        index.train().unwrap();
        let first_count = index.centroid_count();

        for i in 50..100u64 {
            index.upsert(i, &[i as f32, 0.0, 0.0, 0.0]).unwrap();
        }
        index.train().unwrap();
        assert_eq!(index.len(), 100);
        assert_eq!(first_count, 4);

        let hits = index.search(&[75.0, 0.0, 0.0, 0.0], 1, None).unwrap();
        assert_eq!(hits[0].0, 75);
    }
}