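//! `mcp_memory/vector_store.rs`: embedding storage that persists vectors in
//! SQLite and serves nearest-neighbour search from an in-memory ANN index
//! (usearch HNSW or IVF-Flat).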

1use std::path::Path;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use parking_lot::{Mutex, RwLock};
7use petgraph::graph::NodeIndex;
8use petgraph::stable_graph::StableGraph;
9use petgraph::Directed;
10use rusqlite::{params, Connection};
11use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
12use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
13
14use crate::errors::{MCSError, Result};
15use crate::ivf::{IvfFlatIndex, Metric as IvfMetric};
16use crate::kg::push_json_str;
17
/// Identifier of a KG entity. Matches the SQLite `entity.id` /
/// `vector_embedding.entity_id` columns and, cast to `u64`, is the ANN index key.
pub type EntityId = i64;
19
20/// Unifies the two ANN backends behind one small surface so [`VectorStore`] does
21/// not branch on the index kind at every call site. Distances follow usearch's
22/// convention (smaller = closer) for both backends.
23enum AnnIndex {
24    Hnsw(Arc<Index>),
25    Ivf(Box<IvfFlatIndex>),
26}
27
28impl AnnIndex {
29    /// Current allocated capacity (HNSW) or live count (IVF; it grows on demand).
30    fn capacity(&self) -> usize {
31        match self {
32            AnnIndex::Hnsw(i) => i.capacity(),
33            AnnIndex::Ivf(i) => i.len(),
34        }
35    }
36
37    /// Ensure room for `target` vectors. No-op for IVF (a growable `Vec`).
38    fn reserve(&self, target: usize) -> Result<()> {
39        if let AnnIndex::Hnsw(i) = self {
40            i.reserve_capacity_and_threads(target, 1)
41                .map_err(|e| MCSError::MemoryError(format!("usearch reserve: {e}")))?;
42        }
43        Ok(())
44    }
45
46    /// Add `id`/`vector` to the index (caller has already removed any prior entry).
47    fn add(&self, id: u64, vector: &[f32]) -> Result<()> {
48        match self {
49            AnnIndex::Hnsw(i) => i
50                .add(id, vector)
51                .map_err(|e| MCSError::MemoryError(format!("usearch add: {e}"))),
52            AnnIndex::Ivf(i) => i
53                .upsert(id, vector)
54                .map(|_| ())
55                .map_err(MCSError::MemoryError),
56        }
57    }
58
59    /// Remove `id`; returns whether it existed.
60    fn remove(&self, id: u64) -> Result<bool> {
61        match self {
62            AnnIndex::Hnsw(i) => i
63                .remove(id)
64                .map(|n| n > 0)
65                .map_err(|e| MCSError::MemoryError(format!("usearch remove: {e}"))),
66            AnnIndex::Ivf(i) => Ok(i.remove(id)),
67        }
68    }
69
70    /// Nearest `top_k` ids with distances (ascending). `nprobe` applies to IVF only.
71    fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Result<Vec<(u64, f32)>> {
72        match self {
73            AnnIndex::Hnsw(i) => {
74                let m = i
75                    .search(query, top_k)
76                    .map_err(|e| MCSError::MemoryError(format!("usearch search: {e}")))?;
77                let cap = m.keys.len().min(m.distances.len());
78                Ok((0..cap).map(|j| (m.keys[j], m.distances[j])).collect())
79            }
80            AnnIndex::Ivf(i) => i.search(query, top_k, nprobe).map_err(MCSError::MemoryError),
81        }
82    }
83
84    /// (Re)train the IVF centroids. No-op for HNSW.
85    fn train(&self) -> Result<()> {
86        if let AnnIndex::Ivf(i) = self {
87            i.train().map_err(MCSError::MemoryError)?;
88        }
89        Ok(())
90    }
91
92    const fn kind(&self) -> IndexKind {
93        match self {
94            AnnIndex::Hnsw(_) => IndexKind::Hnsw,
95            AnnIndex::Ivf(_) => IndexKind::Ivf,
96        }
97    }
98
99    fn memory_bytes(&self) -> usize {
100        match self {
101            AnnIndex::Hnsw(i) => i.memory_usage(),
102            AnnIndex::Ivf(i) => i.memory_bytes(),
103        }
104    }
105
106    /// (graph_bytes, vectors_bytes). IVF has no graph, so its graph component is 0.
107    fn memory_breakdown(&self) -> (usize, usize) {
108        match self {
109            AnnIndex::Hnsw(i) => {
110                let s = i.memory_stats();
111                (
112                    s.graph_allocated + s.graph_reserved,
113                    s.vectors_allocated + s.vectors_reserved,
114                )
115            }
116            AnnIndex::Ivf(i) => (0, i.memory_bytes()),
117        }
118    }
119}
120
/// Fixed prefix of every serialized embedding blob: the number of `f32`
/// components that follow it. Written and read in native byte order.
#[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
#[repr(C)]
struct BlobHeader {
    /// Number of `f32` values stored after this header.
    dims: u32,
}
126
/// Selects the approximate-nearest-neighbour backend behind a [`VectorStore`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum IndexKind {
    /// usearch HNSW graph (the default): the best recall and latency, paid for
    /// with more memory and a costlier build.
    #[default]
    Hnsw,
    /// IVF-Flat: vectors are split into k-means cells and every probed cell is
    /// scanned exactly. Cheaper to build and lighter on memory, which fits large
    /// corpora that are ingested in batches and rebuilt periodically. Until it
    /// has been trained it behaves as an exact brute-force search.
    Ivf,
}
138
139/// Tunable parameters for the vector index. Built from CLI flags;
140/// [`VectorConfig::new`] supplies the defaults used by tests and any caller that
141/// only cares about the embedding dimension.
142#[derive(Clone, Copy, Debug)]
143pub struct VectorConfig {
144    /// Embedding dimension. All upserted/queried vectors must match this.
145    pub dims: u32,
146    /// Which ANN backend to use.
147    pub index_kind: IndexKind,
148    /// Distance metric used by the index.
149    pub metric: MetricKind,
150    /// On-disk/in-index scalar representation (enables quantization). HNSW only.
151    pub quantization: ScalarKind,
152    /// HNSW graph degree (`M`). Higher = better recall, more memory.
153    pub connectivity: usize,
154    /// HNSW `efConstruction`. Higher = better index quality, slower inserts.
155    pub expansion_add: usize,
156    /// HNSW `efSearch`. Higher = better recall, slower queries.
157    pub expansion_search: usize,
158    /// IVF number of Voronoi cells (centroids). IVF only.
159    pub ivf_nlist: usize,
160    /// IVF default cells probed per query. IVF only.
161    pub ivf_nprobe: usize,
162}
163
164impl VectorConfig {
165    /// Default HNSW configuration for the given embedding dimension.
166    pub const fn new(dims: u32) -> Self {
167        Self {
168            dims,
169            index_kind: IndexKind::Hnsw,
170            metric: MetricKind::Cos,
171            quantization: ScalarKind::F32,
172            connectivity: 16,
173            expansion_add: 200,
174            expansion_search: 50,
175            ivf_nlist: 256,
176            ivf_nprobe: 8,
177        }
178    }
179}
180
/// Embedding store: persists vectors in the SQLite `vector_embedding` table and
/// answers nearest-neighbour queries from an in-memory ANN index.
pub struct VectorStore {
    /// Cache of entity name -> id. Filled from the KG `entity` table when a
    /// populated store is opened, and lazily on lookup misses.
    pub name_to_id: Arc<DashMap<String, EntityId>>,
    /// Reverse cache of `name_to_id`: entity id -> name.
    pub id_to_name: Arc<DashMap<EntityId, String>>,

    /// In-memory relation graph; `rebuild_graph_cache` rebuilds it over the
    /// entities that have embeddings.
    pub(crate) graph: Arc<RwLock<StableGraph<EntityId, (), Directed, u32>>>,
    /// Entity id -> its node in `graph`.
    pub(crate) node_map: Arc<DashMap<EntityId, NodeIndex<u32>>>,

    /// ANN backend (HNSW or IVF) holding the loaded embeddings.
    index: AnnIndex,
    /// The store's SQLite connection. Upserts and deletes hold this lock while
    /// they also mutate `index`.
    pub(crate) db: Mutex<Connection>,

    /// Embedding dimension; `upsert_embedding` rejects vectors of any other length.
    pub dims: u32,
    /// Tracked number of indexed embeddings (maintained by load/upsert/delete).
    pub count: AtomicUsize,
    /// Default cells probed per IVF query (ignored by HNSW).
    ivf_nprobe: usize,

    /// Path of the SQLite database file this store was opened on.
    pub db_path: std::path::PathBuf,
}
198
199fn sqlite_err(e: rusqlite::Error) -> MCSError {
200    MCSError::IoError(std::io::Error::other(e))
201}
202
// Per-thread reusable `f32` buffer handed out by `with_scratch`. It keeps its
// capacity between uses, so repeated calls avoid reallocating.
thread_local! {
    static SCRATCH: std::cell::RefCell<Vec<f32>> = const {
        std::cell::RefCell::new(Vec::new())
    };
}
208
209pub fn with_scratch<R>(f: impl FnOnce(&mut Vec<f32>) -> R) -> R {
210    SCRATCH.with(|cell| {
211        let mut buf = cell.borrow_mut();
212        buf.clear();
213        f(&mut buf)
214    })
215}
216
217fn serialize_embedding(emb: &[f32]) -> Vec<u8> {
218    let header = BlobHeader {
219        dims: emb.len() as u32,
220    };
221    let f32_bytes: &[u8] = unsafe {
222        std::slice::from_raw_parts(emb.as_ptr() as *const u8, emb.len() * 4)
223    };
224    let mut bytes = Vec::with_capacity(4 + f32_bytes.len());
225    bytes.extend_from_slice(header.as_bytes());
226    bytes.extend_from_slice(f32_bytes);
227    bytes
228}
229
230fn parse_embedding_blob(blob: &[u8]) -> Result<&[f32]> {
231    let (header, rest) = BlobHeader::ref_from_prefix(blob)
232        .map_err(|_| MCSError::MemoryError("Invalid blob header".into()))?;
233    let count = header.dims as usize;
234    let bytes = rest
235        .get(..count * 4)
236        .ok_or_else(|| MCSError::MemoryError("Blob data too short".into()))?;
237    let emb = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, count) };
238    Ok(emb)
239}
240
241impl VectorStore {
242    /// Open a store with the default HNSW configuration for `dims`.
243    pub fn new(db_path: &Path, dims: u32) -> Result<Self> {
244        Self::with_config(db_path, &VectorConfig::new(dims))
245    }
246
247    /// Open a store with an explicit HNSW configuration.
248    pub fn with_config(db_path: &Path, cfg: &VectorConfig) -> Result<Self> {
249        let dims = cfg.dims;
250        let conn = Connection::open(db_path).map_err(sqlite_err)?;
251        conn.busy_timeout(std::time::Duration::from_secs(5))
252            .map_err(sqlite_err)?;
253        conn.execute_batch(
254            "PRAGMA journal_mode = WAL;
255             PRAGMA synchronous = NORMAL;
256             PRAGMA temp_store = MEMORY;
257             CREATE TABLE IF NOT EXISTS vector_embedding (
258                 entity_id INTEGER PRIMARY KEY,
259                 dims      INTEGER NOT NULL,
260                 blob      BLOB    NOT NULL,
261                 model     TEXT    NOT NULL DEFAULT '',
262                 created_us INTEGER NOT NULL
263             );",
264        )
265        .map_err(sqlite_err)?;
266
267        let index = match cfg.index_kind {
268            IndexKind::Hnsw => {
269                let index_opts = IndexOptions {
270                    dimensions: dims as usize,
271                    metric: cfg.metric,
272                    quantization: cfg.quantization,
273                    connectivity: cfg.connectivity,
274                    expansion_add: cfg.expansion_add,
275                    expansion_search: cfg.expansion_search,
276                    multi: false,
277                };
278                let index = Index::new(&index_opts)
279                    .map_err(|e| MCSError::MemoryError(format!("usearch init: {e}")))?;
280                AnnIndex::Hnsw(Arc::new(index))
281            }
282            IndexKind::Ivf => AnnIndex::Ivf(Box::new(IvfFlatIndex::new(
283                dims as usize,
284                IvfMetric::from_usearch(cfg.metric),
285                cfg.ivf_nlist,
286                cfg.ivf_nprobe,
287            ))),
288        };
289
290        let name_to_id = Arc::new(DashMap::new());
291        let id_to_name = Arc::new(DashMap::new());
292        let graph = Arc::new(RwLock::new(StableGraph::<EntityId, (), Directed, u32>::new()));
293        let node_map = Arc::new(DashMap::new());
294        let db = Mutex::new(conn);
295
296        let store = Self {
297            name_to_id,
298            id_to_name,
299            graph,
300            node_map,
301            index,
302            db,
303            dims,
304            count: AtomicUsize::new(0),
305            ivf_nprobe: cfg.ivf_nprobe,
306            db_path: db_path.to_path_buf(),
307        };
308        store.load_existing()?;
309
310        Ok(store)
311    }
312
313    fn load_existing(&self) -> Result<()> {
314        let conn = self.db.lock();
315        let count: usize = conn
316            .query_row("SELECT COUNT(*) FROM vector_embedding", [], |r| {
317                r.get::<_, i64>(0)
318            })
319            .map_err(sqlite_err)?
320            as usize;
321
322        if count == 0 {
323            return Ok(());
324        }
325
326        self.index.reserve(count)?;
327
328        let mut stmt = conn
329            .prepare("SELECT entity_id, dims, blob, model FROM vector_embedding")
330            .map_err(sqlite_err)?;
331
332        let rows = stmt
333            .query_map([], |row| {
334                let id: i64 = row.get(0)?;
335                let dims: i64 = row.get(1)?;
336                let blob: Vec<u8> = row.get(2)?;
337                let model: String = row.get(3)?;
338                Ok((id, dims, blob, model))
339            })
340            .map_err(sqlite_err)?;
341
342        for row in rows {
343            let (id, _row_dims, blob, _model) = row.map_err(sqlite_err)?;
344            let emb = parse_embedding_blob(&blob)?;
345            self.index.add(id as u64, emb)?;
346            self.count.fetch_add(1, Ordering::Relaxed);
347        }
348
349        // Train the IVF backend over the freshly-loaded set so a reopened,
350        // populated database gets sub-linear search immediately (HNSW: no-op).
351        self.index.train()?;
352
353        self.load_names_from_entity_table(&conn)?;
354        Ok(())
355    }
356
357    fn load_names_from_entity_table(&self, conn: &Connection) -> Result<()> {
358        let mut stmt = conn
359            .prepare("SELECT id, name FROM entity WHERE flags = 0")
360            .map_err(sqlite_err)?;
361        let rows = stmt
362            .query_map([], |row| {
363                let id: i64 = row.get(0)?;
364                let name: String = row.get(1)?;
365                Ok((id, name))
366            })
367            .map_err(sqlite_err)?;
368
369        self.name_to_id.clear();
370        self.id_to_name.clear();
371
372        for row in rows {
373            let (id, name) = row.map_err(sqlite_err)?;
374            self.name_to_id.insert(name.clone(), id);
375            self.id_to_name.insert(id, name);
376        }
377        Ok(())
378    }
379
380    fn get_entity_id_and_name(&self, conn: &Connection, entity_name: &str) -> Result<Option<(EntityId, String)>> {
381        if let Some(entry) = self.name_to_id.get(entity_name) {
382            let id = *entry;
383            let name = entity_name.to_string();
384            return Ok(Some((id, name)));
385        }
386        let h = crate::kg::name_hash(entity_name);
387        let mut stmt = conn
388            .prepare_cached(
389                "SELECT id, name FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
390            )
391            .map_err(sqlite_err)?;
392        match stmt.query_row(params![h, entity_name], |row| {
393            let id: i64 = row.get(0)?;
394            let name: String = row.get(1)?;
395            Ok((id, name))
396        }) {
397            Ok(tup) => {
398                self.name_to_id.insert(tup.1.clone(), tup.0);
399                self.id_to_name.insert(tup.0, tup.1.clone());
400                Ok(Some(tup))
401            }
402            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
403            Err(e) => Err(sqlite_err(e)),
404        }
405    }
406
407    pub fn upsert_embedding(&self, entity_name: &str, embedding: &[f32], model: &str) -> Result<()> {
408        if embedding.len() != self.dims as usize {
409            return Err(MCSError::InvalidParams(format!(
410                "Embedding dimension mismatch: got {}, expected {}",
411                embedding.len(),
412                self.dims
413            )));
414        }
415
416        let conn = self.db.lock();
417        let entity = self
418            .get_entity_id_and_name(&conn, entity_name)?
419            .ok_or_else(|| {
420                MCSError::InvalidParams(format!("Entity '{entity_name}' not found in KG"))
421            })?;
422        let entity_id = entity.0;
423
424        // Grow the index capacity in chunks rather than one slot per upsert, so
425        // a bulk load doesn't trigger a reallocation on every insert.
426        let needed = self.count.load(Ordering::Relaxed).saturating_add(1);
427        if needed > self.index.capacity() {
428            const CHUNK: usize = 1024;
429            let target = needed.div_ceil(CHUNK).saturating_mul(CHUNK);
430            self.index.reserve(target)?;
431        }
432        let existed = self.index.remove(entity_id as u64).unwrap_or(false);
433        self.index.add(entity_id as u64, embedding)?;
434
435        self.name_to_id
436            .insert(entity_name.to_string(), entity_id);
437        self.id_to_name.insert(entity_id, entity_name.to_string());
438
439        let blob = serialize_embedding(embedding);
440        let now = std::time::SystemTime::now()
441            .duration_since(std::time::UNIX_EPOCH)
442            .unwrap_or_default()
443            .as_micros() as i64;
444
445        conn.execute(
446            "INSERT OR REPLACE INTO vector_embedding (entity_id, dims, blob, model, created_us) VALUES (?1, ?2, ?3, ?4, ?5)",
447            params![entity_id, self.dims, blob, model, now],
448        )
449        .map_err(sqlite_err)?;
450
451        if !existed {
452            self.count.fetch_add(1, Ordering::Relaxed);
453        }
454        Ok(())
455    }
456
457    pub fn delete_embedding(&self, entity_name: &str) -> Result<bool> {
458        let conn = self.db.lock();
459        let entity_id = match self.name_to_id.get(entity_name) {
460            Some(entry) => *entry,
461            None => {
462                return Ok(false);
463            }
464        };
465
466        self.index.remove(entity_id as u64)?;
467
468        self.name_to_id.remove(entity_name);
469        self.id_to_name.remove(&entity_id);
470
471        conn.execute(
472            "DELETE FROM vector_embedding WHERE entity_id = ?1",
473            params![entity_id],
474        )
475        .map_err(sqlite_err)?;
476
477        {
478            let mut g = self.graph.write();
479            if let Some(nx) = self.node_map.get(&entity_id) {
480                g.remove_node(*nx);
481                self.node_map.remove(&entity_id);
482            }
483        }
484
485        self.count.fetch_sub(1, Ordering::Relaxed);
486        Ok(true)
487    }
488
489    pub fn search_embeddings(
490        &self,
491        query: &[f32],
492        top_k: usize,
493    ) -> Result<Vec<(EntityId, f32)>> {
494        if self.count.load(Ordering::Relaxed) == 0 {
495            return Ok(Vec::new());
496        }
497        let top_k = top_k.clamp(1, 100);
498        let matches = self.index.search(query, top_k, Some(self.ivf_nprobe))?;
499        Ok(matches
500            .into_iter()
501            .map(|(id, dist)| (id as EntityId, dist))
502            .collect())
503    }
504
505    pub fn search_entities_json(
506        &self,
507        query: &[f32],
508        top_k: usize,
509        entity_type_filter: Option<&str>,
510    ) -> Result<String> {
511        let results = self.search_embeddings(query, top_k)?;
512        if results.is_empty() {
513            return Ok(r#"{"results":[],"count":0}"#.to_string());
514        }
515
516        let conn = self.db.lock();
517        let mut out = String::with_capacity(128 + results.len() * 64);
518        out.push_str(r#"{"results":["#);
519        let mut first = true;
520        let mut actual_count = 0usize;
521
522        for &(id, dist) in &results {
523            let name = self
524                .id_to_name
525                .get(&id)
526                .map(|r| r.value().clone())
527                .or_else(|| {
528                    conn.query_row(
529                        "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
530                        params![id],
531                        |row| row.get::<_, String>(0),
532                    )
533                    .ok()
534                });
535
536            let name = match name {
537                Some(n) => n,
538                None => continue,
539            };
540
541            let etype: String = conn
542                .query_row(
543                    "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
544                    params![id],
545                    |row| row.get(0),
546                )
547                .unwrap_or_default();
548
549            if let Some(filter_type) = entity_type_filter
550                && etype != filter_type
551            {
552                continue;
553            }
554
555            if !first {
556                out.push(',');
557            }
558            first = false;
559
560            out.push_str(r#"{"name":"#);
561            push_json_str(&mut out, &name);
562            out.push_str(r#","entityType":"#);
563            push_json_str(&mut out, &etype);
564            write_f32(&mut out, dist);
565            out.push('}');
566            actual_count += 1;
567        }
568
569        out.push_str(r#"],"count":"#);
570        out.push_str(&actual_count.to_string());
571        out.push('}');
572        Ok(out)
573    }
574
575    pub fn build_search_response_json(&self, results: &[(EntityId, f32)]) -> String {
576        let mut out = String::with_capacity(128 + results.len() * 64);
577        out.push_str(r#"{"results":["#);
578        for (i, &(id, dist)) in results.iter().enumerate() {
579            if i > 0 {
580                out.push(',');
581            }
582            out.push_str(r#"{"entityId":"#);
583            out.push_str(&id.to_string());
584            out.push_str(r#","distance":"#);
585            write_f32(&mut out, dist);
586            out.push('}');
587        }
588        out.push_str(r#"],"count":"#);
589        out.push_str(&results.len().to_string());
590        out.push('}');
591        out
592    }
593
594    pub fn rebuild_graph_cache(&self) -> Result<()> {
595        let conn = self.db.lock();
596
597        let mut ent_stmt = conn
598            .prepare("SELECT entity_id FROM vector_embedding")
599            .map_err(sqlite_err)?;
600        let ids: Vec<EntityId> = ent_stmt
601            .query_map([], |r| r.get::<_, i64>(0))
602            .map_err(sqlite_err)?
603            .filter_map(|r| r.ok())
604            .collect();
605
606        let mut g = StableGraph::<EntityId, (), Directed, u32>::with_capacity(ids.len(), 0);
607        let nm = DashMap::new();
608
609        for &id in &ids {
610            let nx = g.add_node(id);
611            nm.insert(id, nx);
612        }
613
614        if !ids.is_empty() {
615            const BATCH_SIZE: usize = 5000;
616            for chunk in ids.chunks(BATCH_SIZE) {
617                let placeholders: Vec<String> = chunk.iter().map(|_| "?".to_string()).collect();
618                let sql = format!(
619                    "SELECT from_id, to_id FROM relation WHERE from_id IN ({}) AND to_id IN ({})",
620                    placeholders.join(","),
621                    placeholders.join(",")
622                );
623                let mut rel_stmt = conn.prepare(&sql).map_err(sqlite_err)?;
624
625                let mut param_values: Vec<&dyn rusqlite::types::ToSql> = Vec::with_capacity(chunk.len() * 2);
626                for id in chunk {
627                    param_values.push(id as &dyn rusqlite::types::ToSql);
628                }
629                for id in chunk {
630                    param_values.push(id as &dyn rusqlite::types::ToSql);
631                }
632
633                let rel_rows = rel_stmt
634                    .query_map(param_values.as_slice(), |row| {
635                        let from: i64 = row.get(0)?;
636                        let to: i64 = row.get(1)?;
637                        Ok((from, to))
638                    })
639                    .map_err(sqlite_err)?;
640
641                for rel in rel_rows {
642                    let (from, to) = rel.map_err(sqlite_err)?;
643                    if let (Some(f_nx), Some(t_nx)) = (nm.get(&from), nm.get(&to))
644                        && g.find_edge(*f_nx, *t_nx).is_none()
645                    {
646                        g.add_edge(*f_nx, *t_nx, ());
647                    }
648                }
649            }
650        }
651
652        *self.graph.write() = g;
653        self.node_map.clear();
654        for entry in nm.iter() {
655            self.node_map.insert(*entry.key(), *entry.value());
656        }
657
658        Ok(())
659    }
660
661    pub fn graph_node_count(&self) -> usize {
662        self.node_map.len()
663    }
664
665    pub fn graph_edge_count(&self) -> usize {
666        self.graph.read().edge_count()
667    }
668
669    pub fn get_entity_type(&self, entity_id: EntityId) -> Result<Option<String>> {
670        let conn = self.db.lock();
671        let etype = conn
672            .query_row(
673                "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
674                params![entity_id],
675                |row| row.get(0),
676            )
677            .ok();
678        Ok(etype)
679    }
680
681    pub fn count(&self) -> usize {
682        self.count.load(Ordering::Relaxed)
683    }
684
685    pub const fn dims(&self) -> u32 {
686        self.dims
687    }
688
689    /// Approximate resident RAM used by the ANN index, in bytes.
690    pub fn index_memory_bytes(&self) -> usize {
691        self.index.memory_bytes()
692    }
693
694    /// Breakdown of index RAM into (graph_bytes, vectors_bytes). IVF has no graph
695    /// component, so its `graph_bytes` is 0.
696    pub fn index_memory_breakdown(&self) -> (usize, usize) {
697        self.index.memory_breakdown()
698    }
699
700    /// Current allocated capacity of the index (number of vectors it can hold
701    /// before the next reservation).
702    pub fn index_capacity(&self) -> usize {
703        self.index.capacity()
704    }
705
706    /// The active ANN backend.
707    pub const fn index_kind(&self) -> IndexKind {
708        self.index.kind()
709    }
710
711    /// Rebuild the ANN index structure: retrains the IVF centroids over the
712    /// current vectors (no-op for HNSW). Call after large batch ingestion to keep
713    /// IVF recall high.
714    pub fn reindex(&self) -> Result<()> {
715        self.index.train()
716    }
717
718    /// Resolve a live entity id by name (cache first, then the KG table).
719    pub fn entity_id_of(&self, name: &str) -> Result<Option<EntityId>> {
720        let conn = self.db.lock();
721        Ok(self.get_entity_id_and_name(&conn, name)?.map(|(id, _)| id))
722    }
723
724    /// Fetch the stored embedding for an entity id, if any.
725    pub fn get_embedding_by_id(&self, id: EntityId) -> Result<Option<Vec<f32>>> {
726        let conn = self.db.lock();
727        let blob: Option<Vec<u8>> = conn
728            .query_row(
729                "SELECT blob FROM vector_embedding WHERE entity_id = ?1",
730                params![id],
731                |r| r.get(0),
732            )
733            .ok();
734        match blob {
735            Some(b) => Ok(Some(parse_embedding_blob(&b)?.to_vec())),
736            None => Ok(None),
737        }
738    }
739
740    /// Fetch `(entity_id, embedding, model)` for an entity by name.
741    pub fn get_embedding_by_name(
742        &self,
743        name: &str,
744    ) -> Result<Option<(EntityId, Vec<f32>, String)>> {
745        let id = match self.entity_id_of(name)? {
746            Some(id) => id,
747            None => return Ok(None),
748        };
749        let conn = self.db.lock();
750        let row: Option<(Vec<u8>, String)> = conn
751            .query_row(
752                "SELECT blob, model FROM vector_embedding WHERE entity_id = ?1",
753                params![id],
754                |r| Ok((r.get(0)?, r.get(1)?)),
755            )
756            .ok();
757        match row {
758            Some((blob, model)) => Ok(Some((id, parse_embedding_blob(&blob)?.to_vec(), model))),
759            None => Ok(None),
760        }
761    }
762
763    /// Resolve an entity id to `(name, entityType)`, preferring the in-memory name
764    /// cache and reading the type from the KG.
765    pub fn resolve_name_type(&self, id: EntityId) -> (String, String) {
766        let conn = self.db.lock();
767        let name = self
768            .id_to_name
769            .get(&id)
770            .map(|r| r.value().clone())
771            .or_else(|| {
772                conn.query_row(
773                    "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
774                    params![id],
775                    |row| row.get::<_, String>(0),
776                )
777                .ok()
778            })
779            .unwrap_or_default();
780        let etype: String = conn
781            .query_row(
782                "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
783                params![id],
784                |row| row.get(0),
785            )
786            .unwrap_or_default();
787        (name, etype)
788    }
789
790    /// k-NN that returns resolved `(id, name, entityType, distance)`, optionally
791    /// filtered by `entity_type` and excluding `exclude` ids. Over-fetches to
792    /// compensate for filtered-out rows.
793    pub fn search_resolved(
794        &self,
795        query: &[f32],
796        top_k: usize,
797        entity_type: Option<&str>,
798        exclude: &std::collections::HashSet<EntityId>,
799    ) -> Result<Vec<(EntityId, String, String, f32)>> {
800        let fetch = (top_k.saturating_mul(3) + exclude.len()).clamp(top_k, 100);
801        let raw = self.search_embeddings(query, fetch)?;
802        let mut out = Vec::with_capacity(top_k);
803        for (id, dist) in raw {
804            if exclude.contains(&id) {
805                continue;
806            }
807            let (name, etype) = self.resolve_name_type(id);
808            if name.is_empty() {
809                continue;
810            }
811            if let Some(ft) = entity_type
812                && etype != ft
813            {
814                continue;
815            }
816            out.push((id, name, etype, dist));
817            if out.len() >= top_k {
818                break;
819            }
820        }
821        Ok(out)
822    }
823
824    pub fn invalidate_entity_cache(&self, names: &[String]) {
825        for name in names {
826            if let Some((_, id)) = self.name_to_id.remove(name.as_str()) {
827                self.id_to_name.remove(&id);
828            }
829        }
830    }
831
832    pub fn name_to_id(&self) -> &DashMap<String, EntityId> {
833        &self.name_to_id
834    }
835
836    pub fn id_to_name(&self) -> &DashMap<EntityId, String> {
837        &self.id_to_name
838    }
839}
840
/// Append a `,"score":<val>` JSON member (6 decimal places) to `buf`.
fn write_f32(buf: &mut String, val: f32) {
    use std::fmt::Write as _;
    // Formatting into a `String` is infallible, so this never panics.
    write!(buf, r#","score":{val:.6}"#).expect("writing to a String cannot fail");
}
845
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kg::GraphHandle;
    use crate::config::{Durability, SqliteTuning};
    use crate::types::Entity;
    use std::num::NonZeroUsize;

    /// Graph + vector store sharing one SQLite file inside a temp dir.
    /// `_dir` keeps the directory alive for the lifetime of the test.
    struct TestEnv {
        kg: GraphHandle,
        vs: VectorStore,
        _dir: tempfile::TempDir,
    }

    /// Builds a [`TestEnv`] with the default (HNSW) vector backend.
    fn setup(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let vs = VectorStore::new(&db_path, dims).unwrap();
        TestEnv {
            kg,
            vs,
            _dir: dir,
        }
    }

    /// Builds a [`TestEnv`] backed by a small IVF index (4 lists, all probed).
    fn setup_ivf(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let mut cfg = VectorConfig::new(dims);
        cfg.index_kind = IndexKind::Ivf;
        cfg.ivf_nlist = 4;
        cfg.ivf_nprobe = 4;
        let vs = VectorStore::with_config(&db_path, &cfg).unwrap();
        TestEnv {
            kg,
            vs,
            _dir: dir,
        }
    }

    /// Creates one entity with a single placeholder observation.
    fn create_test_entity(kg: &GraphHandle, name: &str, etype: &str) {
        kg.create_entities(&[Entity {
            name: name.into(),
            entity_type: etype.into(),
            observations: vec!["test observation".into()],
        }])
        .unwrap();
    }

    /// A `dims`-long vector with every component equal to `value`.
    fn make_embedding(dims: u32, value: f32) -> Vec<f32> {
        vec![value; dims as usize]
    }

    #[test]
    fn test_vector_upsert_and_search() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");

        let emb_a = make_embedding(4, 1.0);
        let emb_b = make_embedding(4, 0.1);
        env.vs.upsert_embedding("alice", &emb_a, "test-model").unwrap();
        env.vs.upsert_embedding("bob", &emb_b, "test-model").unwrap();

        let query = make_embedding(4, 1.0);
        let results = env.vs.search_embeddings(&query, 10).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].1 < results[1].1);
    }

    #[test]
    fn test_vector_delete_embedding() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let deleted = env.vs.delete_embedding("alice").unwrap();
        assert!(deleted);
        assert_eq!(env.vs.count(), 0);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_vector_upsert_nonexistent_entity() {
        let env = setup(4);
        let err = env.vs.upsert_embedding("nonexistent", &make_embedding(4, 1.0), "");
        assert!(err.is_err());
    }

    #[test]
    fn test_vector_dimension_mismatch() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let err = env.vs.upsert_embedding("alice", &make_embedding(8, 1.0), "");
        assert!(err.is_err());
    }

    #[test]
    fn test_vector_search_top_k() {
        let env = setup(4);
        for i in 0..5 {
            create_test_entity(&env.kg, &format!("e{i}"), "test");
            env.vs.upsert_embedding(&format!("e{i}"), &make_embedding(4, i as f32 * 0.2), "")
                .unwrap();
        }
        let results = env.vs.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_search_type_filter() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "acme", "organization");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("acme", &make_embedding(4, 0.95), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, Some("person")).unwrap();
        assert!(json.contains("alice"));
        assert!(!json.contains("acme"));
    }

    #[test]
    fn test_vector_blob_roundtrip() {
        let emb: Vec<f32> = vec![1.0, 2.5, -3.0, 0.0];
        let blob = serialize_embedding(&emb);
        let parsed = parse_embedding_blob(&blob).unwrap();
        assert_eq!(parsed.len(), emb.len());
        for (a, b) in parsed.iter().zip(emb.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_vector_scratch_buffer() {
        with_scratch(|buf| {
            buf.push(1.0);
            buf.push(2.0);
            assert_eq!(buf.len(), 2);
        });
        with_scratch(|buf| {
            assert!(buf.is_empty());
            buf.extend_from_slice(&[3.0, 4.0, 5.0]);
            assert_eq!(buf.len(), 3);
        });
    }

    #[test]
    fn test_vector_rebuild_graph_cache() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        create_test_entity(&env.kg, "charlie", "person");

        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.5), "").unwrap();
        env.vs.upsert_embedding("charlie", &make_embedding(4, 0.0), "").unwrap();

        env.kg
            .create_relations(&[crate::types::Relation {
                from: "alice".into(),
                to: "bob".into(),
                relation_type: "knows".into(),
            }])
            .unwrap();

        env.vs.rebuild_graph_cache().unwrap();
        assert_eq!(env.vs.graph_node_count(), 3);
        assert_eq!(env.vs.graph_edge_count(), 1);
    }

    #[test]
    fn test_vector_upsert_replace() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("alice", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let results = env.vs.search_embeddings(&make_embedding(4, 0.5), 10).unwrap();
        assert_eq!(results.len(), 1);
        let name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(name.as_deref(), Some("alice"));
    }

    #[test]
    fn test_vector_index_capacity_grows_in_chunks() {
        let env = setup(4);
        // Empty store reserves nothing.
        assert_eq!(env.vs.count(), 0);

        // First insert reserves at least a full 1024 chunk up front (usearch may
        // over-allocate beyond that, but it must not reserve just one slot).
        create_test_entity(&env.kg, "e0", "t");
        env.vs.upsert_embedding("e0", &make_embedding(4, 0.0), "").unwrap();
        let cap_after_first = env.vs.index_capacity();
        assert!(cap_after_first >= 1024, "capacity {cap_after_first} < 1024");

        // Inserts within the same chunk do not reallocate.
        for i in 1..50 {
            let name = format!("e{i}");
            create_test_entity(&env.kg, &name, "t");
            env.vs.upsert_embedding(&name, &make_embedding(4, i as f32 * 0.01), "").unwrap();
        }
        assert_eq!(env.vs.count(), 50);
        assert_eq!(env.vs.index_capacity(), cap_after_first, "capacity changed mid-chunk");

        // Overwriting an existing entity never grows capacity.
        env.vs.upsert_embedding("e0", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 50);
        assert_eq!(env.vs.index_capacity(), cap_after_first);

        // Memory accounting is exposed and non-zero once vectors are present.
        assert!(env.vs.index_memory_bytes() > 0);
        let (graph_bytes, vec_bytes) = env.vs.index_memory_breakdown();
        assert!(graph_bytes + vec_bytes > 0);
    }

    #[test]
    fn test_vector_empty_store_search() {
        let env = setup(4);
        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert_eq!(json, r#"{"results":[],"count":0}"#);
    }

    #[test]
    fn test_vector_persistence_across_reopen() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("persist.db");
        let lru = NonZeroUsize::new(10000).unwrap();

        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        kg.create_entities(&[Entity {
            name: "alice".into(),
            entity_type: "person".into(),
            observations: vec![],
        }])
        .unwrap();

        let vs1 = VectorStore::new(&db_path, 4).unwrap();
        vs1.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(vs1.count(), 1);
        drop(vs1);
        drop(kg);

        let kg2 = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let vs2 = VectorStore::new(&db_path, 4).unwrap();
        assert_eq!(vs2.count(), 1);

        let results = vs2.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
        drop(vs2);
        drop(kg2);
    }

    #[test]
    fn test_vector_search_json_format() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert!(json.contains("alice"));
        assert!(json.contains("person"));
        assert!(json.contains("score"));
        assert!(json.contains("count"));
    }

    #[test]
    fn test_vector_concurrent_upsert() {
        let env = setup(8);
        // upsert_embedding rejects names the graph does not know (see
        // test_vector_upsert_nonexistent_entity). Create every entity before any
        // writer starts, so the test exercises concurrent upserts rather than
        // racing entity creation.
        for i in 0..4 {
            create_test_entity(&env.kg, &format!("thread_{i}"), "t");
        }

        let vs = Arc::new(env.vs);
        let threads: Vec<_> = (0..4)
            .map(|i| {
                let vs = Arc::clone(&vs);
                std::thread::spawn(move || {
                    let name = format!("thread_{i}");
                    vs.upsert_embedding(&name, &make_embedding(8, i as f32 * 0.25), "")
                        .unwrap();
                })
            })
            .collect();

        // join().unwrap() propagates any panic (including a failed upsert).
        for t in threads {
            t.join().unwrap();
        }
        assert_eq!(vs.count(), 4);
    }

    // ── IVF backend (via VectorStore) ─────────────────────────────────────

    #[test]
    fn test_ivf_store_upsert_search_delete() {
        let env = setup_ivf(4);
        assert_eq!(env.vs.index_kind(), IndexKind::Ivf);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "m").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.1), "m").unwrap();
        assert_eq!(env.vs.count(), 2);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 2);
        // alice (all 1.0) is the closest match to the all-ones query.
        let top_name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(top_name.as_deref(), Some("alice"));

        assert!(env.vs.delete_embedding("alice").unwrap());
        assert_eq!(env.vs.count(), 1);
        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_ivf_persistence_and_reindex() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("ivf.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let mut cfg = VectorConfig::new(4);
        cfg.index_kind = IndexKind::Ivf;
        cfg.ivf_nlist = 3;
        cfg.ivf_nprobe = 3;

        {
            let vs = VectorStore::with_config(&db_path, &cfg).unwrap();
            for i in 0..12 {
                let name = format!("e{i}");
                create_test_entity(&kg, &name, "t");
                vs.upsert_embedding(&name, &make_embedding(4, i as f32 * 0.1), "").unwrap();
            }
            vs.reindex().unwrap();
            assert_eq!(vs.count(), 12);
        }

        // Reopen: embeddings reload and the IVF index retrains on load.
        let vs2 = VectorStore::with_config(&db_path, &cfg).unwrap();
        assert_eq!(vs2.count(), 12);
        let results = vs2.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert!(!results.is_empty());
        // The exact match (e0 == all-zeros) should be the nearest.
        let top = vs2.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(top.as_deref(), Some("e0"));
    }

    // ── New retrieval helpers (backend-agnostic) ──────────────────────────

    #[test]
    fn test_get_embedding_helpers() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let emb = vec![0.1, 0.2, 0.3, 0.4];
        env.vs.upsert_embedding("alice", &emb, "model-x").unwrap();

        let id = env.vs.entity_id_of("alice").unwrap().unwrap();
        let by_id = env.vs.get_embedding_by_id(id).unwrap().unwrap();
        assert_eq!(by_id, emb);

        let (got_id, got_emb, model) = env.vs.get_embedding_by_name("alice").unwrap().unwrap();
        assert_eq!(got_id, id);
        assert_eq!(got_emb, emb);
        assert_eq!(model, "model-x");

        assert!(env.vs.get_embedding_by_name("nobody").unwrap().is_none());
    }

    #[test]
    fn test_search_resolved_excludes_and_filters() {
        let env = setup(4);
        create_test_entity(&env.kg, "a", "doc");
        create_test_entity(&env.kg, "b", "doc");
        create_test_entity(&env.kg, "c", "note");
        env.vs.upsert_embedding("a", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("b", &make_embedding(4, 0.9), "").unwrap();
        env.vs.upsert_embedding("c", &make_embedding(4, 0.95), "").unwrap();

        let id_a = env.vs.entity_id_of("a").unwrap().unwrap();
        let mut exclude = std::collections::HashSet::new();
        exclude.insert(id_a);

        // Exclude "a"; without a type filter we expect b and c.
        let rows = env.vs.search_resolved(&make_embedding(4, 1.0), 10, None, &exclude).unwrap();
        let names: Vec<&str> = rows.iter().map(|(_, n, _, _)| n.as_str()).collect();
        assert!(!names.contains(&"a"));
        assert!(names.contains(&"b") && names.contains(&"c"));

        // Now filter to type "doc": only "b" remains (a excluded, c is a note).
        let rows = env.vs.search_resolved(&make_embedding(4, 1.0), 10, Some("doc"), &exclude).unwrap();
        let names: Vec<&str> = rows.iter().map(|(_, n, _, _)| n.as_str()).collect();
        assert_eq!(names, vec!["b"]);
    }
}