Skip to main content

mcp_memory/
ivf.rs

1//! A self-contained **IVF-Flat** (inverted-file, flat-storage) approximate
2//! nearest-neighbour index.
3//!
4//! IVF-Flat partitions the vector space into `nlist` Voronoi cells via k-means.
5//! Each vector is stored verbatim (no quantization — the "flat" part) in the
6//! cell of its nearest centroid. A query scans only the `nprobe` cells whose
7//! centroids are closest to it, trading a little recall for a large speed-up
8//! over brute force on big collections. It complements the HNSW (usearch)
9//! backend: IVF trains/builds far faster and uses less memory per vector, which
10//! suits large, batch-ingested, periodically-rebuilt corpora typical of RAG.
11//!
12//! All vectors live in RAM (rebuilt from the SQLite `vector_embedding` table on
13//! open, exactly like the HNSW backend), so this type owns no persistence.
//! Until the index is trained, search transparently falls back to an exact
//! brute-force scan. Once trained, only the probed cells are scanned, so results
//! are approximate unless `nprobe` covers every cell. Training a collection
//! smaller than `nlist` simply uses one centroid per stored vector.
17
18use parking_lot::RwLock;
19use rustc_hash::FxHashMap;
20use usearch::MetricKind;
21
22/// Distance functions supported by the IVF index. Smaller is always "closer",
23/// matching the convention usearch uses, so the two backends are interchangeable
24/// to the rest of the store.
25#[derive(Clone, Copy, Debug, PartialEq, Eq)]
26pub enum Metric {
27    /// `1 - cosine_similarity` (range `[0, 2]`).
28    Cos,
29    /// `1 - inner_product` (raw dot product; assumes caller-normalized vectors).
30    Ip,
31    /// Squared Euclidean distance.
32    L2sq,
33}
34
35impl Metric {
36    /// Map a usearch metric onto the IVF metric, falling back to cosine for the
37    /// metrics IVF does not model.
38    pub const fn from_usearch(m: MetricKind) -> Self {
39        match m {
40            MetricKind::IP => Metric::Ip,
41            MetricKind::L2sq => Metric::L2sq,
42            _ => Metric::Cos,
43        }
44    }
45}
46
/// Dot product of `a` and `b`; pairs beyond the shorter slice are ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&x, &y)| x * y);
    products.sum()
}
51
/// Squared Euclidean distance between `a` and `b`; pairs beyond the shorter
/// slice are ignored.
#[inline]
fn l2sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
56
/// Euclidean (L2) length of `a`: the square root of the sum of squares.
#[inline]
fn norm(a: &[f32]) -> f32 {
    let squared: f32 = a.iter().map(|&x| x * x).sum();
    squared.sqrt()
}
61
62struct Inner {
63    dims: usize,
64    metric: Metric,
65    /// Entity ids, parallel to `vecs` rows, `norms` and `assign`.
66    ids: Vec<u64>,
67    /// Flat row-major vectors: `ids.len() * dims` floats.
68    vecs: Vec<f32>,
69    /// Cached L2 norm per stored vector (used by the cosine metric).
70    norms: Vec<f32>,
71    /// Centroid index each vector belongs to, or `-1` when not yet assigned.
72    assign: Vec<i32>,
73    /// `id -> row position` for O(1) upsert/remove.
74    id_pos: FxHashMap<u64, usize>,
75    /// Trained centroids, flat row-major: `centroid_count * dims`. Empty until trained.
76    centroids: Vec<f32>,
77    /// Inverted lists: for each centroid, the row positions assigned to it.
78    lists: Vec<Vec<usize>>,
79}
80
81impl Inner {
82    #[inline]
83    fn row(&self, pos: usize) -> &[f32] {
84        &self.vecs[pos * self.dims..(pos + 1) * self.dims]
85    }
86
87    #[inline]
88    fn centroid(&self, c: usize) -> &[f32] {
89        &self.centroids[c * self.dims..(c + 1) * self.dims]
90    }
91
92    /// Distance between a query (with precomputed norm for cosine) and stored row.
93    #[inline]
94    fn dist_to_row(&self, q: &[f32], q_norm: f32, pos: usize) -> f32 {
95        let v = self.row(pos);
96        match self.metric {
97            Metric::Cos => {
98                let denom = q_norm * self.norms[pos];
99                if denom == 0.0 {
100                    1.0
101                } else {
102                    1.0 - dot(q, v) / denom
103                }
104            }
105            Metric::Ip => 1.0 - dot(q, v),
106            Metric::L2sq => l2sq(q, v),
107        }
108    }
109
110    /// Distance between a query and a centroid (centroid norm computed on the fly
111    /// — there are far fewer centroids than vectors, so this stays cheap).
112    #[inline]
113    fn dist_to_centroid(&self, q: &[f32], q_norm: f32, c: usize) -> f32 {
114        let v = self.centroid(c);
115        match self.metric {
116            Metric::Cos => {
117                let denom = q_norm * norm(v);
118                if denom == 0.0 {
119                    1.0
120                } else {
121                    1.0 - dot(q, v) / denom
122                }
123            }
124            Metric::Ip => 1.0 - dot(q, v),
125            Metric::L2sq => l2sq(q, v),
126        }
127    }
128
129    fn nearest_centroid(&self, v: &[f32], v_norm: f32) -> i32 {
130        let mut best = -1i32;
131        let mut best_d = f32::INFINITY;
132        for c in 0..self.lists.len() {
133            let d = self.dist_to_centroid(v, v_norm, c);
134            if d < best_d {
135                best_d = d;
136                best = c as i32;
137            }
138        }
139        best
140    }
141}
142
143/// An IVF-Flat index. All methods take `&self`; internal mutable state is guarded
144/// by a single `RwLock`, so the index is `Send + Sync` and safe to share behind
145/// an `Arc` like the usearch backend.
146pub struct IvfFlatIndex {
147    dims: usize,
148    metric: Metric,
149    /// Target number of Voronoi cells (centroids). Actual count is capped at the
150    /// number of stored vectors when training.
151    nlist: usize,
152    /// Default number of cells probed per query (clamped to the trained count).
153    nprobe: usize,
154    inner: RwLock<Inner>,
155}
156
157impl IvfFlatIndex {
158    pub fn new(dims: usize, metric: Metric, nlist: usize, nprobe: usize) -> Self {
159        let nlist = nlist.max(1);
160        let nprobe = nprobe.clamp(1, nlist);
161        Self {
162            dims,
163            metric,
164            nlist,
165            nprobe,
166            inner: RwLock::new(Inner {
167                dims,
168                metric,
169                ids: Vec::new(),
170                vecs: Vec::new(),
171                norms: Vec::new(),
172                assign: Vec::new(),
173                id_pos: FxHashMap::default(),
174                centroids: Vec::new(),
175                lists: Vec::new(),
176            }),
177        }
178    }
179
180    pub fn len(&self) -> usize {
181        self.inner.read().ids.len()
182    }
183
184    pub fn is_empty(&self) -> bool {
185        self.len() == 0
186    }
187
188    pub fn is_trained(&self) -> bool {
189        !self.inner.read().centroids.is_empty()
190    }
191
192    pub const fn metric(&self) -> Metric {
193        self.metric
194    }
195
196    pub const fn nlist(&self) -> usize {
197        self.nlist
198    }
199
200    pub const fn nprobe(&self) -> usize {
201        self.nprobe
202    }
203
204    /// The number of trained centroids (0 until [`IvfFlatIndex::train`] runs).
205    pub fn centroid_count(&self) -> usize {
206        self.inner.read().lists.len()
207    }
208
209    /// Approximate resident bytes: stored vectors + norms + centroids + bookkeeping.
210    pub fn memory_bytes(&self) -> usize {
211        let g = self.inner.read();
212        g.vecs.len() * 4
213            + g.norms.len() * 4
214            + g.centroids.len() * 4
215            + g.assign.len() * 4
216            + g.ids.len() * 8
217            + g.id_pos.len() * 16
218            + g.lists.iter().map(|l| l.len() * 8).sum::<usize>()
219    }
220
221    /// Insert or replace the vector for `id`. Returns `true` if it replaced an
222    /// existing entry.
223    pub fn upsert(&self, id: u64, v: &[f32]) -> Result<bool, String> {
224        if v.len() != self.dims {
225            return Err(format!(
226                "dimension mismatch: got {}, expected {}",
227                v.len(),
228                self.dims
229            ));
230        }
231        let mut g = self.inner.write();
232        let existed = g.id_pos.contains_key(&id);
233        if existed {
234            remove_locked(&mut g, id);
235        }
236
237        let pos = g.ids.len();
238        g.ids.push(id);
239        g.vecs.extend_from_slice(v);
240        g.norms.push(norm(v));
241        g.id_pos.insert(id, pos);
242
243        // Assign into a cell when the index is already trained, so the new vector
244        // is reachable by probe-limited search; otherwise leave it unassigned
245        // (it is still found by the brute-force fallback).
246        if !g.centroids.is_empty() {
247            let n = g.norms[pos];
248            let c = g.nearest_centroid(v, n);
249            g.assign.push(c);
250            if c >= 0 {
251                g.lists[c as usize].push(pos);
252            }
253        } else {
254            g.assign.push(-1);
255        }
256        Ok(existed)
257    }
258
259    /// Remove the vector for `id`. Returns `true` if it existed.
260    pub fn remove(&self, id: u64) -> bool {
261        let mut g = self.inner.write();
262        if !g.id_pos.contains_key(&id) {
263            return false;
264        }
265        remove_locked(&mut g, id);
266        true
267    }
268
269    /// Return the `top_k` nearest ids with their distances (ascending). Uses
270    /// `nprobe_override` cells when given, else the configured default; falls back
271    /// to an exact scan when untrained.
272    pub fn search(
273        &self,
274        query: &[f32],
275        top_k: usize,
276        nprobe_override: Option<usize>,
277    ) -> Result<Vec<(u64, f32)>, String> {
278        if query.len() != self.dims {
279            return Err(format!(
280                "dimension mismatch: got {}, expected {}",
281                query.len(),
282                self.dims
283            ));
284        }
285        if top_k == 0 {
286            return Ok(Vec::new());
287        }
288        let g = self.inner.read();
289        if g.ids.is_empty() {
290            return Ok(Vec::new());
291        }
292        let q_norm = norm(query);
293
294        // Gather candidate row positions: either from the probed lists, or all
295        // rows when the index has not been trained yet.
296        let candidates: Vec<usize> = if g.centroids.is_empty() {
297            (0..g.ids.len()).collect()
298        } else {
299            let nprobe = nprobe_override.unwrap_or(self.nprobe).clamp(1, g.lists.len());
300            // Rank centroids by distance, take the nearest `nprobe`.
301            let mut cd: Vec<(usize, f32)> = (0..g.lists.len())
302                .map(|c| (c, g.dist_to_centroid(query, q_norm, c)))
303                .collect();
304            cd.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
305            let mut cand = Vec::new();
306            for &(c, _) in cd.iter().take(nprobe) {
307                cand.extend_from_slice(&g.lists[c]);
308            }
309            cand
310        };
311
312        let mut scored: Vec<(u64, f32)> = candidates
313            .into_iter()
314            .map(|pos| (g.ids[pos], g.dist_to_row(query, q_norm, pos)))
315            .collect();
316        scored.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
317        scored.truncate(top_k);
318        Ok(scored)
319    }
320
321    /// (Re)train centroids over the currently stored vectors via k-means, then
322    /// rebuild the inverted lists. Cheap no-op when empty. The centroid count is
323    /// `min(nlist, n)`.
324    pub fn train(&self) -> Result<(), String> {
325        let mut g = self.inner.write();
326        let n = g.ids.len();
327        if n == 0 {
328            g.centroids.clear();
329            g.lists.clear();
330            return Ok(());
331        }
332        let k = self.nlist.min(n);
333        let dims = self.dims;
334
335        // k-means++ style seeding: first centroid random-ish (row 0), each
336        // subsequent centroid the row farthest from its nearest chosen centroid.
337        let mut centroids: Vec<f32> = Vec::with_capacity(k * dims);
338        centroids.extend_from_slice(g.row(0));
339        let mut min_d: Vec<f32> = (0..n)
340            .map(|p| dist_rows(&g, g.row(p), &centroids[0..dims]))
341            .collect();
342        while centroids.len() / dims < k {
343            // Pick the row with the largest distance to its nearest centroid.
344            let mut far = 0usize;
345            let mut far_d = -1.0f32;
346            for (p, &d) in min_d.iter().enumerate() {
347                if d > far_d {
348                    far_d = d;
349                    far = p;
350                }
351            }
352            let start = centroids.len();
353            centroids.extend_from_slice(g.row(far));
354            let new_c = &centroids[start..start + dims];
355            for (p, slot) in min_d.iter_mut().enumerate() {
356                let d = dist_rows(&g, g.row(p), new_c);
357                if d < *slot {
358                    *slot = d;
359                }
360            }
361        }
362
363        // Lloyd iterations.
364        let mut assign = vec![0i32; n];
365        for _ in 0..IVF_KMEANS_ITERS {
366            // Assignment step.
367            let mut changed = false;
368            for (p, a) in assign.iter_mut().enumerate() {
369                let row = g.row(p);
370                let mut best = 0usize;
371                let mut best_d = f32::INFINITY;
372                for c in 0..k {
373                    let d = dist_rows(&g, row, &centroids[c * dims..(c + 1) * dims]);
374                    if d < best_d {
375                        best_d = d;
376                        best = c;
377                    }
378                }
379                if *a != best as i32 {
380                    *a = best as i32;
381                    changed = true;
382                }
383            }
384            // Update step: centroid = mean of members; keep old centroid if empty.
385            let mut sums = vec![0f32; k * dims];
386            let mut counts = vec![0usize; k];
387            for (p, &c_raw) in assign.iter().enumerate() {
388                let c = c_raw as usize;
389                counts[c] += 1;
390                let row = g.row(p);
391                let base = c * dims;
392                for (j, &x) in row.iter().enumerate() {
393                    sums[base + j] += x;
394                }
395            }
396            for (c, &cnt) in counts.iter().enumerate() {
397                if cnt == 0 {
398                    continue;
399                }
400                let inv = 1.0 / cnt as f32;
401                let base = c * dims;
402                for (j, slot) in centroids[base..base + dims].iter_mut().enumerate() {
403                    *slot = sums[base + j] * inv;
404                }
405            }
406            if !changed {
407                break;
408            }
409        }
410
411        // Commit centroids + inverted lists + per-row assignment.
412        let mut lists: Vec<Vec<usize>> = vec![Vec::new(); k];
413        for (p, &c) in assign.iter().enumerate() {
414            lists[c as usize].push(p);
415        }
416        g.centroids = centroids;
417        g.lists = lists;
418        g.assign = assign;
419        Ok(())
420    }
421
422    /// Replace the entire contents in one shot (used for the initial bulk load).
423    /// Does not train; call [`IvfFlatIndex::train`] afterwards.
424    pub fn bulk_load(&self, items: impl IntoIterator<Item = (u64, Vec<f32>)>) -> Result<(), String> {
425        let mut g = self.inner.write();
426        for (id, v) in items {
427            if v.len() != self.dims {
428                return Err(format!(
429                    "dimension mismatch: got {}, expected {}",
430                    v.len(),
431                    self.dims
432                ));
433            }
434            let pos = g.ids.len();
435            g.ids.push(id);
436            g.vecs.extend_from_slice(&v);
437            g.norms.push(norm(&v));
438            g.assign.push(-1);
439            g.id_pos.insert(id, pos);
440        }
441        Ok(())
442    }
443}
444
445/// Distance between two raw rows under the inner metric (used during training).
446#[inline]
447fn dist_rows(inner: &Inner, a: &[f32], b: &[f32]) -> f32 {
448    match inner.metric {
449        Metric::Cos => {
450            let denom = norm(a) * norm(b);
451            if denom == 0.0 {
452                1.0
453            } else {
454                1.0 - dot(a, b) / denom
455            }
456        }
457        Metric::Ip => 1.0 - dot(a, b),
458        Metric::L2sq => l2sq(a, b),
459    }
460}
461
462/// Remove `id` from a locked `Inner` via swap-remove, keeping `id_pos`, the row
463/// arrays and the inverted lists consistent. Caller guarantees `id` is present.
464fn remove_locked(g: &mut Inner, id: u64) {
465    let dims = g.dims;
466    let pos = g.id_pos[&id];
467    let last = g.ids.len() - 1;
468
469    // Detach `pos` from its inverted list (if assigned).
470    let c_pos = g.assign[pos];
471    if c_pos >= 0 {
472        let list = &mut g.lists[c_pos as usize];
473        if let Some(i) = list.iter().position(|&p| p == pos) {
474            list.swap_remove(i);
475        }
476    }
477
478    if pos != last {
479        // Move the last row into `pos`.
480        let moved_id = g.ids[last];
481        let moved_c = g.assign[last];
482        g.ids.swap_remove(pos);
483        g.assign.swap_remove(pos);
484        g.norms.swap_remove(pos);
485        // vecs is flat: copy the last row over `pos`, then truncate.
486        let (head, tail) = g.vecs.split_at_mut(last * dims);
487        head[pos * dims..(pos + 1) * dims].copy_from_slice(&tail[..dims]);
488        g.vecs.truncate(last * dims);
489
490        g.id_pos.insert(moved_id, pos);
491        // Repoint the moved row in its list from `last` to `pos`.
492        if moved_c >= 0 {
493            let list = &mut g.lists[moved_c as usize];
494            if let Some(i) = list.iter().position(|&p| p == last) {
495                list[i] = pos;
496            }
497        }
498    } else {
499        g.ids.pop();
500        g.assign.pop();
501        g.norms.pop();
502        g.vecs.truncate(last * dims);
503    }
504    g.id_pos.remove(&id);
505}
506
/// Cap on Lloyd (k-means) iterations per training run, so training time on a
/// large collection stays bounded during startup or reindex.
const IVF_KMEANS_ITERS: usize = 15;
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Project a result list onto its ids, preserving rank order.
    fn ids_of(hits: &[(u64, f32)]) -> Vec<u64> {
        hits.iter().map(|&(id, _)| id).collect()
    }

    #[test]
    fn empty_search_returns_nothing() {
        let index = IvfFlatIndex::new(3, Metric::L2sq, 4, 2);
        let hits = index.search(&[1.0, 0.0, 0.0], 5, None).unwrap();
        assert!(hits.is_empty());
        assert_eq!(index.len(), 0);
        assert!(!index.is_trained());
    }

    #[test]
    fn brute_force_before_training_is_exact() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 8, 2);
        let rows: [(u64, [f32; 2]); 3] = [(1, [0.0, 0.0]), (2, [10.0, 10.0]), (3, [1.0, 1.0])];
        for (id, point) in rows {
            index.upsert(id, &point).unwrap();
        }
        // No centroids yet, so every row is scored and the ranking is exact.
        let ranked = ids_of(&index.search(&[0.0, 0.0], 2, None).unwrap());
        assert_eq!(ranked[0], 1);
        assert_eq!(ranked[1], 3);
    }

    #[test]
    fn trained_search_finds_cluster_members() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 2, 2);
        // Two well-separated clusters: ids 0..10 near the origin, 10..20 near (100, 100).
        for id in 0..20u64 {
            let jitter = id as f32 * 0.01;
            let point = if id < 10 {
                [jitter, 0.0]
            } else {
                [100.0 + jitter, 100.0]
            };
            index.upsert(id, &point).unwrap();
        }
        index.train().unwrap();
        assert!(index.is_trained());
        assert_eq!(index.centroid_count(), 2);
        // All three nearest hits must come from the first cluster.
        let hits = index.search(&[0.0, 0.0], 3, None).unwrap();
        for (id, _) in &hits {
            assert!(*id < 10, "unexpected id {id} from far cluster");
        }
    }

    #[test]
    fn upsert_replaces_and_counts() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 4, 2);
        let replaced_first = index.upsert(1, &[0.0, 0.0]).unwrap();
        let replaced_second = index.upsert(1, &[5.0, 5.0]).unwrap();
        assert!(!replaced_first);
        assert!(replaced_second); // the second write overwrote id 1
        assert_eq!(index.len(), 1);
        let best = index.search(&[5.0, 5.0], 1, None).unwrap()[0];
        assert_eq!(best.0, 1);
        assert!(best.1 < 0.001, "distance to exact match should be ~0");
    }

    #[test]
    fn remove_keeps_index_consistent() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 3, 3);
        for id in 0..6u64 {
            index.upsert(id, &[id as f32, 0.0]).unwrap();
        }
        index.train().unwrap();
        assert!(index.remove(2));
        assert!(!index.remove(2)); // a second removal is a no-op
        assert_eq!(index.len(), 5);
        // id 2 must be gone while every survivor is still reachable.
        let ids = ids_of(&index.search(&[5.0, 0.0], 6, None).unwrap());
        assert!(!ids.contains(&2));
        assert!(ids.contains(&5));
        assert_eq!(ids.len(), 5);
    }

    #[test]
    fn add_after_training_is_findable() {
        let index = IvfFlatIndex::new(2, Metric::L2sq, 2, 2);
        for id in 0..8u64 {
            index.upsert(id, &[id as f32, 0.0]).unwrap();
        }
        index.train().unwrap();
        index.upsert(99, &[3.5, 0.0]).unwrap();
        let top = index.search(&[3.5, 0.0], 1, None).unwrap();
        assert_eq!(top[0].0, 99);
    }

    #[test]
    fn cosine_metric_ranks_by_direction() {
        let index = IvfFlatIndex::new(2, Metric::Cos, 4, 4);
        index.upsert(1, &[1.0, 0.0]).unwrap();
        index.upsert(2, &[0.0, 1.0]).unwrap();
        index.upsert(3, &[10.0, 0.0]).unwrap(); // parallel to id 1, ten times longer
        // Magnitude is irrelevant under cosine: 1 and 3 tie near 0, 2 is orthogonal.
        let ranked = ids_of(&index.search(&[2.0, 0.0], 3, None).unwrap());
        assert!(matches!(ranked[0], 1 | 3));
        assert!(matches!(ranked[1], 1 | 3));
        assert_eq!(ranked[2], 2);
    }

    #[test]
    fn dimension_mismatch_errors() {
        let index = IvfFlatIndex::new(3, Metric::L2sq, 2, 2);
        assert!(index.upsert(1, &[1.0, 2.0]).is_err());
        assert!(index.search(&[1.0, 2.0], 1, None).is_err());
    }

    #[test]
    fn retrain_after_many_inserts() {
        let index = IvfFlatIndex::new(4, Metric::L2sq, 4, 4);
        let insert_range = |ids: std::ops::Range<u64>| {
            for id in ids {
                index.upsert(id, &[id as f32, 0.0, 0.0, 0.0]).unwrap();
            }
        };
        insert_range(0..50);
        index.train().unwrap();
        let first_count = index.centroid_count();
        insert_range(50..100);
        index.train().unwrap(); // retrain over the doubled set
        assert_eq!(index.len(), 100);
        assert_eq!(first_count, 4);
        // With every cell probed, the exact nearest must still win after retrain.
        let top = index.search(&[75.0, 0.0, 0.0, 0.0], 1, None).unwrap();
        assert_eq!(top[0].0, 75);
    }
}