Skip to main content

mcp_memory/
vector_store.rs

1use std::path::Path;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use parking_lot::{Mutex, RwLock};
7use petgraph::graph::NodeIndex;
8use petgraph::stable_graph::StableGraph;
9use petgraph::Directed;
10use rusqlite::{params, Connection};
11use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
12use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
13
14use crate::errors::{MCSError, Result};
15use crate::ivf::{IvfFlatIndex, Metric as IvfMetric};
16use crate::kg::push_json_str;
17
/// Row id of an entity in the knowledge-graph `entity` table; also used as the
/// `entity_id` key of `vector_embedding` and (cast to `u64`) as the ANN key.
pub type EntityId = i64;
19
20/// Unifies the two ANN backends behind one small surface so [`VectorStore`] does
21/// not branch on the index kind at every call site. Distances follow usearch's
22/// convention (smaller = closer) for both backends.
23enum AnnIndex {
24    Hnsw(Arc<Index>),
25    Ivf(Box<IvfFlatIndex>),
26}
27
28impl AnnIndex {
29    /// Current allocated capacity (HNSW) or live count (IVF; it grows on demand).
30    fn capacity(&self) -> usize {
31        match self {
32            AnnIndex::Hnsw(i) => i.capacity(),
33            AnnIndex::Ivf(i) => i.len(),
34        }
35    }
36
37    /// Ensure room for `target` vectors. No-op for IVF (a growable `Vec`).
38    fn reserve(&self, target: usize) -> Result<()> {
39        if let AnnIndex::Hnsw(i) = self {
40            i.reserve_capacity_and_threads(target, 1)
41                .map_err(|e| MCSError::MemoryError(format!("usearch reserve: {e}")))?;
42        }
43        Ok(())
44    }
45
46    /// Add `id`/`vector` to the index (caller has already removed any prior entry).
47    fn add(&self, id: u64, vector: &[f32]) -> Result<()> {
48        match self {
49            AnnIndex::Hnsw(i) => i
50                .add(id, vector)
51                .map_err(|e| MCSError::MemoryError(format!("usearch add: {e}"))),
52            AnnIndex::Ivf(i) => i
53                .upsert(id, vector)
54                .map(|_| ())
55                .map_err(MCSError::MemoryError),
56        }
57    }
58
59    /// Remove `id`; returns whether it existed.
60    fn remove(&self, id: u64) -> Result<bool> {
61        match self {
62            AnnIndex::Hnsw(i) => i
63                .remove(id)
64                .map(|n| n > 0)
65                .map_err(|e| MCSError::MemoryError(format!("usearch remove: {e}"))),
66            AnnIndex::Ivf(i) => Ok(i.remove(id)),
67        }
68    }
69
70    /// Nearest `top_k` ids with distances (ascending). `nprobe` applies to IVF only.
71    fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Result<Vec<(u64, f32)>> {
72        match self {
73            AnnIndex::Hnsw(i) => {
74                let m = i
75                    .search(query, top_k)
76                    .map_err(|e| MCSError::MemoryError(format!("usearch search: {e}")))?;
77                let cap = m.keys.len().min(m.distances.len());
78                Ok((0..cap).map(|j| (m.keys[j], m.distances[j])).collect())
79            }
80            AnnIndex::Ivf(i) => i.search(query, top_k, nprobe).map_err(MCSError::MemoryError),
81        }
82    }
83
84    /// (Re)train the IVF centroids. No-op for HNSW.
85    fn train(&self) -> Result<()> {
86        if let AnnIndex::Ivf(i) = self {
87            i.train().map_err(MCSError::MemoryError)?;
88        }
89        Ok(())
90    }
91
92    const fn kind(&self) -> IndexKind {
93        match self {
94            AnnIndex::Hnsw(_) => IndexKind::Hnsw,
95            AnnIndex::Ivf(_) => IndexKind::Ivf,
96        }
97    }
98
99    fn memory_bytes(&self) -> usize {
100        match self {
101            AnnIndex::Hnsw(i) => i.memory_usage(),
102            AnnIndex::Ivf(i) => i.memory_bytes(),
103        }
104    }
105
106    /// (graph_bytes, vectors_bytes). IVF has no graph, so its graph component is 0.
107    fn memory_breakdown(&self) -> (usize, usize) {
108        match self {
109            AnnIndex::Hnsw(i) => {
110                let s = i.memory_stats();
111                (
112                    s.graph_allocated + s.graph_reserved,
113                    s.vectors_allocated + s.vectors_reserved,
114                )
115            }
116            AnnIndex::Ivf(i) => (0, i.memory_bytes()),
117        }
118    }
119}
120
121#[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
122#[repr(C)]
123struct BlobHeader {
124    dims: u32,
125}
126
/// Selects which ANN backend a [`VectorStore`] builds.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum IndexKind {
    /// usearch HNSW (the default): best recall and latency, at the price of
    /// more memory and a costlier build.
    #[default]
    Hnsw,
    /// IVF-Flat: vectors are bucketed into k-means cells and each probed cell
    /// is scanned exactly. Cheaper to build and lighter in memory, which suits
    /// large corpora that are batch-ingested and periodically rebuilt. Until
    /// it is trained, search is brute force (exact).
    Ivf,
}
138
139/// Tunable parameters for the vector index. Built from CLI flags;
140/// [`VectorConfig::new`] supplies the defaults used by tests and any caller that
141/// only cares about the embedding dimension.
142#[derive(Clone, Copy, Debug)]
143pub struct VectorConfig {
144    /// Embedding dimension. All upserted/queried vectors must match this.
145    pub dims: u32,
146    /// Which ANN backend to use.
147    pub index_kind: IndexKind,
148    /// Distance metric used by the index.
149    pub metric: MetricKind,
150    /// On-disk/in-index scalar representation (enables quantization). HNSW only.
151    pub quantization: ScalarKind,
152    /// HNSW graph degree (`M`). Higher = better recall, more memory.
153    pub connectivity: usize,
154    /// HNSW `efConstruction`. Higher = better index quality, slower inserts.
155    pub expansion_add: usize,
156    /// HNSW `efSearch`. Higher = better recall, slower queries.
157    pub expansion_search: usize,
158    /// IVF number of Voronoi cells (centroids). IVF only.
159    pub ivf_nlist: usize,
160    /// IVF default cells probed per query. IVF only.
161    pub ivf_nprobe: usize,
162}
163
164impl VectorConfig {
165    /// Default HNSW configuration for the given embedding dimension.
166    pub const fn new(dims: u32) -> Self {
167        Self {
168            dims,
169            index_kind: IndexKind::Hnsw,
170            metric: MetricKind::Cos,
171            quantization: ScalarKind::F32,
172            connectivity: 16,
173            expansion_add: 200,
174            expansion_search: 50,
175            ivf_nlist: 256,
176            ivf_nprobe: 8,
177        }
178    }
179}
180
181pub struct VectorStore {
182    pub name_to_id: Arc<DashMap<String, EntityId>>,
183    pub id_to_name: Arc<DashMap<EntityId, String>>,
184
185    pub(crate) graph: Arc<RwLock<StableGraph<EntityId, (), Directed, u32>>>,
186    pub(crate) node_map: Arc<DashMap<EntityId, NodeIndex<u32>>>,
187
188    index: AnnIndex,
189    pub(crate) db: Mutex<Connection>,
190
191    pub dims: u32,
192    pub count: AtomicUsize,
193    /// Default cells probed per IVF query (ignored by HNSW).
194    ivf_nprobe: usize,
195
196    pub db_path: std::path::PathBuf,
197}
198
199fn sqlite_err(e: rusqlite::Error) -> MCSError {
200    MCSError::IoError(std::io::Error::other(e))
201}
202
thread_local! {
    /// Per-thread reusable `f32` buffer handed out by [`with_scratch`].
    static SCRATCH: std::cell::RefCell<Vec<f32>> = const {
        std::cell::RefCell::new(Vec::new())
    };
}

/// Run `f` on this thread's scratch buffer, emptied beforehand. The buffer is
/// cleared, not reallocated, so its capacity carries over between calls.
///
/// # Panics
/// Panics if `f` calls `with_scratch` again on the same thread, because the
/// buffer is already mutably borrowed.
pub fn with_scratch<R>(f: impl FnOnce(&mut Vec<f32>) -> R) -> R {
    SCRATCH.with(|scratch| {
        let mut guard = scratch.borrow_mut();
        guard.clear();
        f(&mut *guard)
    })
}
216
217fn serialize_embedding(emb: &[f32]) -> Vec<u8> {
218    let header = BlobHeader {
219        dims: emb.len() as u32,
220    };
221    let f32_bytes: &[u8] = unsafe {
222        std::slice::from_raw_parts(emb.as_ptr() as *const u8, emb.len() * 4)
223    };
224    let mut bytes = Vec::with_capacity(4 + f32_bytes.len());
225    bytes.extend_from_slice(header.as_bytes());
226    bytes.extend_from_slice(f32_bytes);
227    bytes
228}
229
230fn parse_embedding_blob(blob: &[u8]) -> Result<&[f32]> {
231    let (header, rest) = BlobHeader::ref_from_prefix(blob)
232        .map_err(|_| MCSError::MemoryError("Invalid blob header".into()))?;
233    let count = header.dims as usize;
234    let bytes = rest
235        .get(..count * 4)
236        .ok_or_else(|| MCSError::MemoryError("Blob data too short".into()))?;
237    let emb = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, count) };
238    Ok(emb)
239}
240
241impl VectorStore {
242    /// Open a store with the default HNSW configuration for `dims`.
243    pub fn new(db_path: &Path, dims: u32) -> Result<Self> {
244        Self::with_config(db_path, &VectorConfig::new(dims))
245    }
246
247    /// Open a store with an explicit HNSW configuration.
248    pub fn with_config(db_path: &Path, cfg: &VectorConfig) -> Result<Self> {
249        let dims = cfg.dims;
250        let conn = Connection::open(db_path).map_err(sqlite_err)?;
251        conn.busy_timeout(std::time::Duration::from_secs(5))
252            .map_err(sqlite_err)?;
253        conn.execute_batch(
254            "PRAGMA journal_mode = WAL;
255             PRAGMA synchronous = NORMAL;
256             PRAGMA temp_store = MEMORY;
257             CREATE TABLE IF NOT EXISTS vector_embedding (
258                 entity_id INTEGER PRIMARY KEY,
259                 dims      INTEGER NOT NULL,
260                 blob      BLOB    NOT NULL,
261                 model     TEXT    NOT NULL DEFAULT '',
262                 created_us INTEGER NOT NULL
263             );",
264        )
265        .map_err(sqlite_err)?;
266
267        let index = match cfg.index_kind {
268            IndexKind::Hnsw => {
269                let index_opts = IndexOptions {
270                    dimensions: dims as usize,
271                    metric: cfg.metric,
272                    quantization: cfg.quantization,
273                    connectivity: cfg.connectivity,
274                    expansion_add: cfg.expansion_add,
275                    expansion_search: cfg.expansion_search,
276                    multi: false,
277                };
278                let index = Index::new(&index_opts)
279                    .map_err(|e| MCSError::MemoryError(format!("usearch init: {e}")))?;
280                AnnIndex::Hnsw(Arc::new(index))
281            }
282            IndexKind::Ivf => AnnIndex::Ivf(Box::new(IvfFlatIndex::new(
283                dims as usize,
284                IvfMetric::from_usearch(cfg.metric),
285                cfg.ivf_nlist,
286                cfg.ivf_nprobe,
287            ))),
288        };
289
290        let name_to_id = Arc::new(DashMap::new());
291        let id_to_name = Arc::new(DashMap::new());
292        let graph = Arc::new(RwLock::new(StableGraph::<EntityId, (), Directed, u32>::new()));
293        let node_map = Arc::new(DashMap::new());
294        let db = Mutex::new(conn);
295
296        let store = Self {
297            name_to_id,
298            id_to_name,
299            graph,
300            node_map,
301            index,
302            db,
303            dims,
304            count: AtomicUsize::new(0),
305            ivf_nprobe: cfg.ivf_nprobe,
306            db_path: db_path.to_path_buf(),
307        };
308        store.load_existing()?;
309
310        Ok(store)
311    }
312
313    fn load_existing(&self) -> Result<()> {
314        let conn = self.db.lock();
315        let count: usize = conn
316            .query_row("SELECT COUNT(*) FROM vector_embedding", [], |r| {
317                r.get::<_, i64>(0)
318            })
319            .map_err(sqlite_err)?
320            as usize;
321
322        if count == 0 {
323            return Ok(());
324        }
325
326        self.index.reserve(count)?;
327
328        let mut stmt = conn
329            .prepare("SELECT entity_id, dims, blob, model FROM vector_embedding")
330            .map_err(sqlite_err)?;
331
332        let rows = stmt
333            .query_map([], |row| {
334                let id: i64 = row.get(0)?;
335                let dims: i64 = row.get(1)?;
336                let blob: Vec<u8> = row.get(2)?;
337                let model: String = row.get(3)?;
338                Ok((id, dims, blob, model))
339            })
340            .map_err(sqlite_err)?;
341
342        for row in rows {
343            let (id, _row_dims, blob, _model) = row.map_err(sqlite_err)?;
344            let emb = parse_embedding_blob(&blob)?;
345            self.index.add(id as u64, emb)?;
346            self.count.fetch_add(1, Ordering::Relaxed);
347        }
348
349        // Train the IVF backend over the freshly-loaded set so a reopened,
350        // populated database gets sub-linear search immediately (HNSW: no-op).
351        self.index.train()?;
352
353        self.load_names_from_entity_table(&conn)?;
354        Ok(())
355    }
356
357    fn load_names_from_entity_table(&self, conn: &Connection) -> Result<()> {
358        let mut stmt = conn
359            .prepare("SELECT id, name FROM entity WHERE flags = 0")
360            .map_err(sqlite_err)?;
361        let rows = stmt
362            .query_map([], |row| {
363                let id: i64 = row.get(0)?;
364                let name: String = row.get(1)?;
365                Ok((id, name))
366            })
367            .map_err(sqlite_err)?;
368
369        self.name_to_id.clear();
370        self.id_to_name.clear();
371
372        for row in rows {
373            let (id, name) = row.map_err(sqlite_err)?;
374            self.name_to_id.insert(name.clone(), id);
375            self.id_to_name.insert(id, name);
376        }
377        Ok(())
378    }
379
380    fn get_entity_id_and_name(&self, conn: &Connection, entity_name: &str) -> Result<Option<(EntityId, String)>> {
381        if let Some(entry) = self.name_to_id.get(entity_name) {
382            let id = *entry;
383            let name = entity_name.to_string();
384            return Ok(Some((id, name)));
385        }
386        let h = crate::kg::name_hash(entity_name);
387        let mut stmt = conn
388            .prepare_cached(
389                "SELECT id, name FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
390            )
391            .map_err(sqlite_err)?;
392        match stmt.query_row(params![h, entity_name], |row| {
393            let id: i64 = row.get(0)?;
394            let name: String = row.get(1)?;
395            Ok((id, name))
396        }) {
397            Ok(tup) => {
398                self.name_to_id.insert(tup.1.clone(), tup.0);
399                self.id_to_name.insert(tup.0, tup.1.clone());
400                Ok(Some(tup))
401            }
402            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
403            Err(e) => Err(sqlite_err(e)),
404        }
405    }
406
407    pub fn upsert_embedding(&self, entity_name: &str, embedding: &[f32], model: &str) -> Result<()> {
408        if embedding.len() != self.dims as usize {
409            return Err(MCSError::InvalidParams(format!(
410                "Embedding dimension mismatch: got {}, expected {}",
411                embedding.len(),
412                self.dims
413            )));
414        }
415
416        let conn = self.db.lock();
417        let entity = self
418            .get_entity_id_and_name(&conn, entity_name)?
419            .ok_or_else(|| {
420                MCSError::InvalidParams(format!("Entity '{entity_name}' not found in KG"))
421            })?;
422        let entity_id = entity.0;
423
424        // Grow the index capacity in chunks rather than one slot per upsert, so
425        // a bulk load doesn't trigger a reallocation on every insert.
426        let needed = self.count.load(Ordering::Relaxed).saturating_add(1);
427        if needed > self.index.capacity() {
428            const CHUNK: usize = 1024;
429            let target = needed.div_ceil(CHUNK).saturating_mul(CHUNK);
430            self.index.reserve(target)?;
431        }
432        let existed = self.index.remove(entity_id as u64).unwrap_or(false);
433        self.index.add(entity_id as u64, embedding)?;
434
435        self.name_to_id
436            .insert(entity_name.to_string(), entity_id);
437        self.id_to_name.insert(entity_id, entity_name.to_string());
438
439        let blob = serialize_embedding(embedding);
440        let now = std::time::SystemTime::now()
441            .duration_since(std::time::UNIX_EPOCH)
442            .unwrap_or_default()
443            .as_micros() as i64;
444
445        conn.execute(
446            "INSERT OR REPLACE INTO vector_embedding (entity_id, dims, blob, model, created_us) VALUES (?1, ?2, ?3, ?4, ?5)",
447            params![entity_id, self.dims, blob, model, now],
448        )
449        .map_err(sqlite_err)?;
450
451        if !existed {
452            self.count.fetch_add(1, Ordering::Relaxed);
453        }
454        Ok(())
455    }
456
457    pub fn delete_embedding(&self, entity_name: &str) -> Result<bool> {
458        let conn = self.db.lock();
459        let entity_id = match self.name_to_id.get(entity_name) {
460            Some(entry) => *entry,
461            None => {
462                return Ok(false);
463            }
464        };
465
466        self.index.remove(entity_id as u64)?;
467
468        self.name_to_id.remove(entity_name);
469        self.id_to_name.remove(&entity_id);
470
471        conn.execute(
472            "DELETE FROM vector_embedding WHERE entity_id = ?1",
473            params![entity_id],
474        )
475        .map_err(sqlite_err)?;
476
477        {
478            let mut g = self.graph.write();
479            if let Some(nx) = self.node_map.get(&entity_id) {
480                g.remove_node(*nx);
481                self.node_map.remove(&entity_id);
482            }
483        }
484
485        self.count.fetch_sub(1, Ordering::Relaxed);
486        Ok(true)
487    }
488
489    pub fn search_embeddings(
490        &self,
491        query: &[f32],
492        top_k: usize,
493    ) -> Result<Vec<(EntityId, f32)>> {
494        if self.count.load(Ordering::Relaxed) == 0 {
495            return Ok(Vec::new());
496        }
497        let top_k = top_k.clamp(1, 100);
498        let matches = self.index.search(query, top_k, Some(self.ivf_nprobe))?;
499        Ok(matches
500            .into_iter()
501            .map(|(id, dist)| (id as EntityId, dist))
502            .collect())
503    }
504
505    pub fn search_entities_json(
506        &self,
507        query: &[f32],
508        top_k: usize,
509        entity_type_filter: Option<&str>,
510    ) -> Result<String> {
511        let results = self.search_embeddings(query, top_k)?;
512        if results.is_empty() {
513            return Ok(r#"{"results":[],"count":0}"#.to_string());
514        }
515
516        let conn = self.db.lock();
517        let mut out = String::with_capacity(128 + results.len() * 64);
518        out.push_str(r#"{"results":["#);
519        let mut first = true;
520        let mut actual_count = 0usize;
521
522        for &(id, dist) in &results {
523            let name = self
524                .id_to_name
525                .get(&id)
526                .map(|r| r.value().clone())
527                .or_else(|| {
528                    conn.query_row(
529                        "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
530                        params![id],
531                        |row| row.get::<_, String>(0),
532                    )
533                    .ok()
534                });
535
536            let name = match name {
537                Some(n) => n,
538                None => continue,
539            };
540
541            let etype: String = conn
542                .query_row(
543                    "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
544                    params![id],
545                    |row| row.get(0),
546                )
547                .unwrap_or_default();
548
549            if let Some(filter_type) = entity_type_filter {
550                if etype != filter_type {
551                    continue;
552                }
553            }
554
555            if !first {
556                out.push(',');
557            }
558            first = false;
559
560            out.push_str(r#"{"name":"#);
561            push_json_str(&mut out, &name);
562            out.push_str(r#","entityType":"#);
563            push_json_str(&mut out, &etype);
564            write_f32(&mut out, dist);
565            out.push('}');
566            actual_count += 1;
567        }
568
569        out.push_str(r#"],"count":"#);
570        out.push_str(&actual_count.to_string());
571        out.push('}');
572        Ok(out)
573    }
574
575    pub fn build_search_response_json(&self, results: &[(EntityId, f32)]) -> String {
576        let mut out = String::with_capacity(128 + results.len() * 64);
577        out.push_str(r#"{"results":["#);
578        for (i, &(id, dist)) in results.iter().enumerate() {
579            if i > 0 {
580                out.push(',');
581            }
582            out.push_str(r#"{"entityId":"#);
583            out.push_str(&id.to_string());
584            out.push_str(r#","distance":"#);
585            write_f32(&mut out, dist);
586            out.push('}');
587        }
588        out.push_str(r#"],"count":"#);
589        out.push_str(&results.len().to_string());
590        out.push('}');
591        out
592    }
593
594    pub fn rebuild_graph_cache(&self) -> Result<()> {
595        let conn = self.db.lock();
596
597        let mut ent_stmt = conn
598            .prepare("SELECT entity_id FROM vector_embedding")
599            .map_err(sqlite_err)?;
600        let ids: Vec<EntityId> = ent_stmt
601            .query_map([], |r| r.get::<_, i64>(0))
602            .map_err(sqlite_err)?
603            .filter_map(|r| r.ok())
604            .collect();
605
606        let mut g = StableGraph::<EntityId, (), Directed, u32>::with_capacity(ids.len(), 0);
607        let nm = DashMap::new();
608
609        for &id in &ids {
610            let nx = g.add_node(id);
611            nm.insert(id, nx);
612        }
613
614        if !ids.is_empty() {
615            const BATCH_SIZE: usize = 5000;
616            for chunk in ids.chunks(BATCH_SIZE) {
617                let placeholders: Vec<String> = chunk.iter().map(|_| "?".to_string()).collect();
618                let sql = format!(
619                    "SELECT from_id, to_id FROM relation WHERE from_id IN ({}) AND to_id IN ({})",
620                    placeholders.join(","),
621                    placeholders.join(",")
622                );
623                let mut rel_stmt = conn.prepare(&sql).map_err(sqlite_err)?;
624
625                let mut param_values: Vec<&dyn rusqlite::types::ToSql> = Vec::with_capacity(chunk.len() * 2);
626                for id in chunk {
627                    param_values.push(id as &dyn rusqlite::types::ToSql);
628                }
629                for id in chunk {
630                    param_values.push(id as &dyn rusqlite::types::ToSql);
631                }
632
633                let rel_rows = rel_stmt
634                    .query_map(param_values.as_slice(), |row| {
635                        let from: i64 = row.get(0)?;
636                        let to: i64 = row.get(1)?;
637                        Ok((from, to))
638                    })
639                    .map_err(sqlite_err)?;
640
641                for rel in rel_rows {
642                    let (from, to) = rel.map_err(sqlite_err)?;
643                    if let (Some(f_nx), Some(t_nx)) = (nm.get(&from), nm.get(&to))
644                        && g.find_edge(*f_nx, *t_nx).is_none()
645                    {
646                        g.add_edge(*f_nx, *t_nx, ());
647                    }
648                }
649            }
650        }
651
652        *self.graph.write() = g;
653        self.node_map.clear();
654        for entry in nm.iter() {
655            self.node_map.insert(*entry.key(), *entry.value());
656        }
657
658        Ok(())
659    }
660
661    pub fn graph_node_count(&self) -> usize {
662        self.node_map.len()
663    }
664
665    pub fn graph_edge_count(&self) -> usize {
666        self.graph.read().edge_count()
667    }
668
669    pub fn get_entity_type(&self, entity_id: EntityId) -> Result<Option<String>> {
670        let conn = self.db.lock();
671        let etype = conn
672            .query_row(
673                "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
674                params![entity_id],
675                |row| row.get(0),
676            )
677            .ok();
678        Ok(etype)
679    }
680
681    pub fn count(&self) -> usize {
682        self.count.load(Ordering::Relaxed)
683    }
684
685    pub const fn dims(&self) -> u32 {
686        self.dims
687    }
688
689    /// Approximate resident RAM used by the ANN index, in bytes.
690    pub fn index_memory_bytes(&self) -> usize {
691        self.index.memory_bytes()
692    }
693
694    /// Breakdown of index RAM into (graph_bytes, vectors_bytes). IVF has no graph
695    /// component, so its `graph_bytes` is 0.
696    pub fn index_memory_breakdown(&self) -> (usize, usize) {
697        self.index.memory_breakdown()
698    }
699
700    /// Current allocated capacity of the index (number of vectors it can hold
701    /// before the next reservation).
702    pub fn index_capacity(&self) -> usize {
703        self.index.capacity()
704    }
705
706    /// The active ANN backend.
707    pub const fn index_kind(&self) -> IndexKind {
708        self.index.kind()
709    }
710
711    /// Rebuild the ANN index structure: retrains the IVF centroids over the
712    /// current vectors (no-op for HNSW). Call after large batch ingestion to keep
713    /// IVF recall high.
714    pub fn reindex(&self) -> Result<()> {
715        self.index.train()
716    }
717
718    /// Resolve a live entity id by name (cache first, then the KG table).
719    pub fn entity_id_of(&self, name: &str) -> Result<Option<EntityId>> {
720        let conn = self.db.lock();
721        Ok(self.get_entity_id_and_name(&conn, name)?.map(|(id, _)| id))
722    }
723
724    /// Fetch the stored embedding for an entity id, if any.
725    pub fn get_embedding_by_id(&self, id: EntityId) -> Result<Option<Vec<f32>>> {
726        let conn = self.db.lock();
727        let blob: Option<Vec<u8>> = conn
728            .query_row(
729                "SELECT blob FROM vector_embedding WHERE entity_id = ?1",
730                params![id],
731                |r| r.get(0),
732            )
733            .ok();
734        match blob {
735            Some(b) => Ok(Some(parse_embedding_blob(&b)?.to_vec())),
736            None => Ok(None),
737        }
738    }
739
740    /// Fetch `(entity_id, embedding, model)` for an entity by name.
741    pub fn get_embedding_by_name(
742        &self,
743        name: &str,
744    ) -> Result<Option<(EntityId, Vec<f32>, String)>> {
745        let id = match self.entity_id_of(name)? {
746            Some(id) => id,
747            None => return Ok(None),
748        };
749        let conn = self.db.lock();
750        let row: Option<(Vec<u8>, String)> = conn
751            .query_row(
752                "SELECT blob, model FROM vector_embedding WHERE entity_id = ?1",
753                params![id],
754                |r| Ok((r.get(0)?, r.get(1)?)),
755            )
756            .ok();
757        match row {
758            Some((blob, model)) => Ok(Some((id, parse_embedding_blob(&blob)?.to_vec(), model))),
759            None => Ok(None),
760        }
761    }
762
763    /// Resolve an entity id to `(name, entityType)`, preferring the in-memory name
764    /// cache and reading the type from the KG.
765    pub fn resolve_name_type(&self, id: EntityId) -> (String, String) {
766        let conn = self.db.lock();
767        let name = self
768            .id_to_name
769            .get(&id)
770            .map(|r| r.value().clone())
771            .or_else(|| {
772                conn.query_row(
773                    "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
774                    params![id],
775                    |row| row.get::<_, String>(0),
776                )
777                .ok()
778            })
779            .unwrap_or_default();
780        let etype: String = conn
781            .query_row(
782                "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
783                params![id],
784                |row| row.get(0),
785            )
786            .unwrap_or_default();
787        (name, etype)
788    }
789
790    /// k-NN that returns resolved `(id, name, entityType, distance)`, optionally
791    /// filtered by `entity_type` and excluding `exclude` ids. Over-fetches to
792    /// compensate for filtered-out rows.
793    pub fn search_resolved(
794        &self,
795        query: &[f32],
796        top_k: usize,
797        entity_type: Option<&str>,
798        exclude: &std::collections::HashSet<EntityId>,
799    ) -> Result<Vec<(EntityId, String, String, f32)>> {
800        let fetch = (top_k.saturating_mul(3) + exclude.len()).clamp(top_k, 100);
801        let raw = self.search_embeddings(query, fetch)?;
802        let mut out = Vec::with_capacity(top_k);
803        for (id, dist) in raw {
804            if exclude.contains(&id) {
805                continue;
806            }
807            let (name, etype) = self.resolve_name_type(id);
808            if name.is_empty() {
809                continue;
810            }
811            if let Some(ft) = entity_type
812                && etype != ft
813            {
814                continue;
815            }
816            out.push((id, name, etype, dist));
817            if out.len() >= top_k {
818                break;
819            }
820        }
821        Ok(out)
822    }
823
824    pub fn name_to_id(&self) -> &DashMap<String, EntityId> {
825        &self.name_to_id
826    }
827
828    pub fn id_to_name(&self) -> &DashMap<EntityId, String> {
829        &self.id_to_name
830    }
831}
832
/// Append `,"score":<val>` to `buf`, formatted with six decimal places.
///
/// JSON has no literal for NaN or ±infinity (Rust would print `NaN`/`inf`), so
/// non-finite values are written as `null` to keep the document parseable.
fn write_f32(buf: &mut String, val: f32) {
    use std::fmt::Write;
    if val.is_finite() {
        // Formatting into a `String` cannot fail.
        let _ = write!(buf, r#","score":{val:.6}"#);
    } else {
        buf.push_str(r#","score":null"#);
    }
}
837
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{Durability, SqliteTuning};
    use crate::kg::GraphHandle;
    use crate::types::Entity;
    use std::num::NonZeroUsize;

    /// LRU size handed to every `GraphHandle` opened by these tests.
    const TEST_LRU_CAPACITY: usize = 10_000;

    /// A graph handle and a vector store sharing one temporary SQLite file.
    ///
    /// `_dir` owns the temporary directory; it is the last field so it is
    /// dropped after the handles that point into it.
    struct TestEnv {
        kg: GraphHandle,
        vs: VectorStore,
        _dir: tempfile::TempDir,
    }

    /// Open a `GraphHandle` on `db_path` with the settings shared by all tests.
    fn open_graph(db_path: &Path) -> GraphHandle {
        let lru = NonZeroUsize::new(TEST_LRU_CAPACITY).expect("test LRU capacity is non-zero");
        GraphHandle::new(db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap()
    }

    /// Build an env whose vector store is created with `VectorStore::new`.
    fn setup(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let kg = open_graph(&db_path);
        let vs = VectorStore::new(&db_path, dims).unwrap();
        TestEnv { kg, vs, _dir: dir }
    }

    /// Build an env whose vector store uses the IVF backend with `nlist` and
    /// `nprobe` both set to 4, so searches probe every list.
    fn setup_ivf(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let kg = open_graph(&db_path);
        let mut cfg = VectorConfig::new(dims);
        cfg.index_kind = IndexKind::Ivf;
        cfg.ivf_nlist = 4;
        cfg.ivf_nprobe = 4;
        let vs = VectorStore::with_config(&db_path, &cfg).unwrap();
        TestEnv { kg, vs, _dir: dir }
    }

    /// Create one entity with a single placeholder observation.
    fn create_test_entity(kg: &GraphHandle, name: &str, etype: &str) {
        kg.create_entities(&[Entity {
            name: name.into(),
            entity_type: etype.into(),
            observations: vec!["test observation".into()],
        }])
        .unwrap();
    }

    /// A `dims`-long embedding with every component equal to `value`.
    fn make_embedding(dims: u32, value: f32) -> Vec<f32> {
        vec![value; dims as usize]
    }

    #[test]
    fn test_vector_upsert_and_search() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");

        let emb_a = make_embedding(4, 1.0);
        let emb_b = make_embedding(4, 0.1);
        env.vs.upsert_embedding("alice", &emb_a, "test-model").unwrap();
        env.vs.upsert_embedding("bob", &emb_b, "test-model").unwrap();

        let query = make_embedding(4, 1.0);
        let results = env.vs.search_embeddings(&query, 10).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].1 < results[1].1);
    }

    #[test]
    fn test_vector_delete_embedding() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let deleted = env.vs.delete_embedding("alice").unwrap();
        assert!(deleted);
        assert_eq!(env.vs.count(), 0);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_vector_upsert_nonexistent_entity() {
        let env = setup(4);
        let res = env.vs.upsert_embedding("nonexistent", &make_embedding(4, 1.0), "");
        assert!(res.is_err(), "embedding an unknown entity must fail");
    }

    #[test]
    fn test_vector_dimension_mismatch() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let res = env.vs.upsert_embedding("alice", &make_embedding(8, 1.0), "");
        assert!(res.is_err(), "an 8-dim vector must be rejected by a 4-dim store");
    }

    #[test]
    fn test_vector_search_top_k() {
        let env = setup(4);
        for i in 0..5 {
            let name = format!("e{i}");
            create_test_entity(&env.kg, &name, "test");
            env.vs.upsert_embedding(&name, &make_embedding(4, i as f32 * 0.2), "").unwrap();
        }
        let results = env.vs.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_search_type_filter() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "acme", "organization");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("acme", &make_embedding(4, 0.95), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, Some("person")).unwrap();
        assert!(json.contains("alice"));
        assert!(!json.contains("acme"));
    }

    #[test]
    fn test_vector_blob_roundtrip() {
        let emb: Vec<f32> = vec![1.0, 2.5, -3.0, 0.0];
        let blob = serialize_embedding(&emb);
        let parsed = parse_embedding_blob(&blob).unwrap();
        assert_eq!(parsed.len(), emb.len());
        for (a, b) in parsed.iter().zip(emb.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_vector_scratch_buffer() {
        with_scratch(|buf| {
            buf.push(1.0);
            buf.push(2.0);
            assert_eq!(buf.len(), 2);
        });
        // A second borrow must start from an empty buffer.
        with_scratch(|buf| {
            assert!(buf.is_empty());
            buf.extend_from_slice(&[3.0, 4.0, 5.0]);
            assert_eq!(buf.len(), 3);
        });
    }

    #[test]
    fn test_vector_rebuild_graph_cache() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        create_test_entity(&env.kg, "charlie", "person");

        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.5), "").unwrap();
        env.vs.upsert_embedding("charlie", &make_embedding(4, 0.0), "").unwrap();

        env.kg
            .create_relations(&[crate::types::Relation {
                from: "alice".into(),
                to: "bob".into(),
                relation_type: "knows".into(),
            }])
            .unwrap();

        env.vs.rebuild_graph_cache().unwrap();
        assert_eq!(env.vs.graph_node_count(), 3);
        assert_eq!(env.vs.graph_edge_count(), 1);
    }

    #[test]
    fn test_vector_upsert_replace() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("alice", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let results = env.vs.search_embeddings(&make_embedding(4, 0.5), 10).unwrap();
        assert_eq!(results.len(), 1);
        let name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(name.as_deref(), Some("alice"));
    }

    #[test]
    fn test_vector_index_capacity_grows_in_chunks() {
        let env = setup(4);
        // A fresh store holds no vectors.
        assert_eq!(env.vs.count(), 0);

        // First insert reserves at least a full 1024 chunk up front (usearch may
        // over-allocate beyond that, but it must not reserve just one slot).
        create_test_entity(&env.kg, "e0", "t");
        env.vs.upsert_embedding("e0", &make_embedding(4, 0.0), "").unwrap();
        let cap_after_first = env.vs.index_capacity();
        assert!(cap_after_first >= 1024, "capacity {cap_after_first} < 1024");

        // Inserts within the same chunk do not reallocate.
        for i in 1..50 {
            let name = format!("e{i}");
            create_test_entity(&env.kg, &name, "t");
            env.vs.upsert_embedding(&name, &make_embedding(4, i as f32 * 0.01), "").unwrap();
        }
        assert_eq!(env.vs.count(), 50);
        assert_eq!(env.vs.index_capacity(), cap_after_first, "capacity changed mid-chunk");

        // Overwriting an existing entity never grows capacity.
        env.vs.upsert_embedding("e0", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 50);
        assert_eq!(env.vs.index_capacity(), cap_after_first);

        // Memory accounting is exposed and non-zero once vectors are present.
        assert!(env.vs.index_memory_bytes() > 0);
        let (graph_bytes, vec_bytes) = env.vs.index_memory_breakdown();
        assert!(graph_bytes + vec_bytes > 0);
    }

    #[test]
    fn test_vector_empty_store_search() {
        let env = setup(4);
        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert_eq!(json, r#"{"results":[],"count":0}"#);
    }

    #[test]
    fn test_vector_persistence_across_reopen() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("persist.db");

        let kg = open_graph(&db_path);
        kg.create_entities(&[Entity {
            name: "alice".into(),
            entity_type: "person".into(),
            observations: vec![],
        }])
        .unwrap();

        let vs1 = VectorStore::new(&db_path, 4).unwrap();
        vs1.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(vs1.count(), 1);
        drop(vs1);
        drop(kg);

        // Reopen both handles on the same file; the embedding must reload.
        let kg2 = open_graph(&db_path);
        let vs2 = VectorStore::new(&db_path, 4).unwrap();
        assert_eq!(vs2.count(), 1);

        let results = vs2.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
        drop(vs2);
        drop(kg2);
    }

    #[test]
    fn test_vector_search_json_format() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert!(json.contains("alice"));
        assert!(json.contains("person"));
        assert!(json.contains("score"));
        assert!(json.contains("count"));
    }

    #[test]
    fn test_vector_concurrent_upsert() {
        const THREADS: usize = 4;
        let env = setup(8);

        // Entities must exist before embeddings can be attached to them;
        // creating them after spawning (as before) made most upserts fail.
        for i in 0..THREADS {
            create_test_entity(&env.kg, &format!("thread_{i}"), "t");
        }

        let vs = Arc::new(env.vs);
        // Collect first so every thread is running before any join.
        let handles: Vec<_> = (0..THREADS)
            .map(|i| {
                let vs = Arc::clone(&vs);
                std::thread::spawn(move || {
                    let name = format!("thread_{i}");
                    vs.upsert_embedding(&name, &make_embedding(8, i as f32 * 0.25), "")
                        .unwrap();
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("upsert thread panicked");
        }
        assert_eq!(vs.count(), THREADS);
    }

    // ── IVF backend (via VectorStore) ─────────────────────────────────────

    #[test]
    fn test_ivf_store_upsert_search_delete() {
        let env = setup_ivf(4);
        assert_eq!(env.vs.index_kind(), IndexKind::Ivf);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "m").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.1), "m").unwrap();
        assert_eq!(env.vs.count(), 2);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 2);
        // alice (all 1.0) is the closest match to the all-ones query.
        let top_name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(top_name.as_deref(), Some("alice"));

        assert!(env.vs.delete_embedding("alice").unwrap());
        assert_eq!(env.vs.count(), 1);
        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_ivf_persistence_and_reindex() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("ivf.db");
        let kg = open_graph(&db_path);
        let mut cfg = VectorConfig::new(4);
        cfg.index_kind = IndexKind::Ivf;
        cfg.ivf_nlist = 3;
        cfg.ivf_nprobe = 3;

        {
            let vs = VectorStore::with_config(&db_path, &cfg).unwrap();
            for i in 0..12 {
                let name = format!("e{i}");
                create_test_entity(&kg, &name, "t");
                vs.upsert_embedding(&name, &make_embedding(4, i as f32 * 0.1), "").unwrap();
            }
            vs.reindex().unwrap();
            assert_eq!(vs.count(), 12);
        }

        // Reopen: embeddings reload and the IVF index retrains on load.
        let vs2 = VectorStore::with_config(&db_path, &cfg).unwrap();
        assert_eq!(vs2.count(), 12);
        let results = vs2.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert!(!results.is_empty());
        // The exact match (e0 == all-zeros) should be the nearest.
        let top = vs2.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(top.as_deref(), Some("e0"));
    }

    // ── New retrieval helpers (backend-agnostic) ──────────────────────────

    #[test]
    fn test_get_embedding_helpers() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let emb = vec![0.1, 0.2, 0.3, 0.4];
        env.vs.upsert_embedding("alice", &emb, "model-x").unwrap();

        let id = env.vs.entity_id_of("alice").unwrap().unwrap();
        let by_id = env.vs.get_embedding_by_id(id).unwrap().unwrap();
        assert_eq!(by_id, emb);

        let (got_id, got_emb, model) = env.vs.get_embedding_by_name("alice").unwrap().unwrap();
        assert_eq!(got_id, id);
        assert_eq!(got_emb, emb);
        assert_eq!(model, "model-x");

        assert!(env.vs.get_embedding_by_name("nobody").unwrap().is_none());
    }

    #[test]
    fn test_search_resolved_excludes_and_filters() {
        let env = setup(4);
        create_test_entity(&env.kg, "a", "doc");
        create_test_entity(&env.kg, "b", "doc");
        create_test_entity(&env.kg, "c", "note");
        env.vs.upsert_embedding("a", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("b", &make_embedding(4, 0.9), "").unwrap();
        env.vs.upsert_embedding("c", &make_embedding(4, 0.95), "").unwrap();

        let id_a = env.vs.entity_id_of("a").unwrap().unwrap();
        let mut exclude = std::collections::HashSet::new();
        exclude.insert(id_a);

        // Exclude "a"; without a type filter we expect b and c.
        let rows = env.vs.search_resolved(&make_embedding(4, 1.0), 10, None, &exclude).unwrap();
        let names: Vec<&str> = rows.iter().map(|(_, n, _, _)| n.as_str()).collect();
        assert!(!names.contains(&"a"));
        assert!(names.contains(&"b") && names.contains(&"c"));

        // Now filter to type "doc": only "b" remains (a excluded, c is a note).
        let rows = env.vs.search_resolved(&make_embedding(4, 1.0), 10, Some("doc"), &exclude).unwrap();
        let names: Vec<&str> = rows.iter().map(|(_, n, _, _)| n.as_str()).collect();
        assert_eq!(names, vec!["b"]);
    }
}
1237        let rows = env.vs.search_resolved(&make_embedding(4, 1.0), 10, Some("doc"), &exclude).unwrap();
1238        let names: Vec<&str> = rows.iter().map(|(_, n, _, _)| n.as_str()).collect();
1239        assert_eq!(names, vec!["b"]);
1240    }
1241}