Skip to main content

mcp_memory/
vector_store.rs

1use std::path::Path;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use parking_lot::{Mutex, RwLock};
7use petgraph::graph::NodeIndex;
8use petgraph::stable_graph::StableGraph;
9use petgraph::Directed;
10use rusqlite::{params, Connection};
11use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
12use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
13
14use crate::errors::{MCSError, Result};
15use crate::kg::push_json_str;
16
/// Row id of an entity in the KG's `entity` table.
///
/// The same value, cast with `as u64`, is the key of that entity's vector in
/// the usearch index, and `entity_id` in the `vector_embedding` table.
pub type EntityId = i64;
18
/// Fixed-size prefix of a serialized embedding blob.
///
/// `dims` is the number of `f32` values that follow the header. The header is
/// written with `IntoBytes::as_bytes` and read back with
/// `FromBytes::ref_from_prefix`, i.e. in native byte order.
/// NOTE(review): blobs are therefore not portable across hosts of different
/// endianness.
#[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
#[repr(C)]
struct BlobHeader {
    dims: u32,
}
24
/// Tunable parameters for the usearch HNSW index. Built from CLI flags in the
/// `mcp-memory-vec` binary; [`VectorConfig::new`] supplies the defaults used by
/// tests and any caller that only cares about the embedding dimension.
///
/// Every field is forwarded into usearch's `IndexOptions` by
/// [`VectorStore::with_config`] (which always sets `multi: false`).
#[derive(Clone, Copy, Debug)]
pub struct VectorConfig {
    /// Embedding dimension. All upserted/queried vectors must match this.
    pub dims: u32,
    /// Distance metric used by the index.
    pub metric: MetricKind,
    /// On-disk/in-index scalar representation (enables quantization).
    pub quantization: ScalarKind,
    /// HNSW graph degree (`M`). Higher = better recall, more memory.
    pub connectivity: usize,
    /// HNSW `efConstruction`. Higher = better index quality, slower inserts.
    pub expansion_add: usize,
    /// HNSW `efSearch`. Higher = better recall, slower queries.
    pub expansion_search: usize,
}
43
44impl VectorConfig {
45    /// Default HNSW configuration for the given embedding dimension.
46    pub const fn new(dims: u32) -> Self {
47        Self {
48            dims,
49            metric: MetricKind::Cos,
50            quantization: ScalarKind::F32,
51            connectivity: 16,
52            expansion_add: 200,
53            expansion_search: 50,
54        }
55    }
56}
57
/// Persistent embedding store: a usearch HNSW index in memory, backed by the
/// `vector_embedding` table in the same SQLite file as the knowledge graph.
pub struct VectorStore {
    /// Cache of entity name -> id. Filled from the `entity` table on load and
    /// by lookups that fall through to SQLite.
    pub name_to_id: Arc<DashMap<String, EntityId>>,
    /// Reverse of `name_to_id`, kept in step with it.
    pub id_to_name: Arc<DashMap<EntityId, String>>,

    /// Directed graph over entities that have an embedding; edges mirror rows
    /// of the `relation` table. Built by `rebuild_graph_cache`.
    pub(crate) graph: Arc<RwLock<StableGraph<EntityId, (), Directed, u32>>>,
    /// Maps an entity id to its node in `graph`.
    pub(crate) node_map: Arc<DashMap<EntityId, NodeIndex<u32>>>,

    /// usearch index keyed by `EntityId as u64`.
    pub index: Arc<Index>,
    /// SQLite connection. Writers (upsert/delete/graph rebuild) hold this lock
    /// for their whole operation, which also serialises their index updates.
    pub(crate) db: Mutex<Connection>,

    /// Configured embedding dimension.
    pub dims: u32,
    /// Number of vectors currently tracked in `index`; updated alongside index
    /// inserts and removals.
    pub count: AtomicUsize,

    /// Path the SQLite database was opened from.
    pub db_path: std::path::PathBuf,
}
73
74fn sqlite_err(e: rusqlite::Error) -> MCSError {
75    MCSError::IoError(std::io::Error::other(e))
76}
77
thread_local! {
    /// Per-thread reusable `f32` buffer handed out by [`with_scratch`].
    static SCRATCH: std::cell::RefCell<Vec<f32>> = const { std::cell::RefCell::new(Vec::new()) };
}

/// Run `f` with this thread's scratch buffer, emptied beforehand.
///
/// The buffer keeps its allocation between calls on the same thread, so
/// repeated use does not reallocate. Calling `with_scratch` again from inside
/// `f` panics, because the buffer is already mutably borrowed.
pub fn with_scratch<R>(f: impl FnOnce(&mut Vec<f32>) -> R) -> R {
    SCRATCH.with_borrow_mut(|scratch| {
        scratch.clear();
        f(scratch)
    })
}
91
92fn serialize_embedding(emb: &[f32]) -> Vec<u8> {
93    let header = BlobHeader {
94        dims: emb.len() as u32,
95    };
96    let f32_bytes: &[u8] = unsafe {
97        std::slice::from_raw_parts(emb.as_ptr() as *const u8, emb.len() * 4)
98    };
99    let mut bytes = Vec::with_capacity(4 + f32_bytes.len());
100    bytes.extend_from_slice(header.as_bytes());
101    bytes.extend_from_slice(f32_bytes);
102    bytes
103}
104
105fn parse_embedding_blob(blob: &[u8]) -> Result<&[f32]> {
106    let (header, rest) = BlobHeader::ref_from_prefix(blob)
107        .map_err(|_| MCSError::MemoryError("Invalid blob header".into()))?;
108    let count = header.dims as usize;
109    let bytes = rest
110        .get(..count * 4)
111        .ok_or_else(|| MCSError::MemoryError("Blob data too short".into()))?;
112    let emb = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, count) };
113    Ok(emb)
114}
115
116impl VectorStore {
117    /// Open a store with the default HNSW configuration for `dims`.
118    pub fn new(db_path: &Path, dims: u32) -> Result<Self> {
119        Self::with_config(db_path, &VectorConfig::new(dims))
120    }
121
122    /// Open a store with an explicit HNSW configuration.
123    pub fn with_config(db_path: &Path, cfg: &VectorConfig) -> Result<Self> {
124        let dims = cfg.dims;
125        let conn = Connection::open(db_path).map_err(sqlite_err)?;
126        conn.busy_timeout(std::time::Duration::from_secs(5))
127            .map_err(sqlite_err)?;
128        conn.execute_batch(
129            "PRAGMA journal_mode = WAL;
130             PRAGMA synchronous = NORMAL;
131             PRAGMA temp_store = MEMORY;
132             CREATE TABLE IF NOT EXISTS vector_embedding (
133                 entity_id INTEGER PRIMARY KEY,
134                 dims      INTEGER NOT NULL,
135                 blob      BLOB    NOT NULL,
136                 model     TEXT    NOT NULL DEFAULT '',
137                 created_us INTEGER NOT NULL
138             );",
139        )
140        .map_err(sqlite_err)?;
141
142        let index_opts = IndexOptions {
143            dimensions: dims as usize,
144            metric: cfg.metric,
145            quantization: cfg.quantization,
146            connectivity: cfg.connectivity,
147            expansion_add: cfg.expansion_add,
148            expansion_search: cfg.expansion_search,
149            multi: false,
150        };
151        let index = Index::new(&index_opts)
152            .map_err(|e| MCSError::MemoryError(format!("usearch init: {e}")))?;
153        let index = Arc::new(index);
154
155        let name_to_id = Arc::new(DashMap::new());
156        let id_to_name = Arc::new(DashMap::new());
157        let graph = Arc::new(RwLock::new(StableGraph::<EntityId, (), Directed, u32>::new()));
158        let node_map = Arc::new(DashMap::new());
159        let db = Mutex::new(conn);
160
161        let store = Self {
162            name_to_id,
163            id_to_name,
164            graph,
165            node_map,
166            index,
167            db,
168            dims,
169            count: AtomicUsize::new(0),
170            db_path: db_path.to_path_buf(),
171        };
172        store.load_existing()?;
173
174        Ok(store)
175    }
176
177    fn load_existing(&self) -> Result<()> {
178        let conn = self.db.lock();
179        let count: usize = conn
180            .query_row("SELECT COUNT(*) FROM vector_embedding", [], |r| {
181                r.get::<_, i64>(0)
182            })
183            .map_err(sqlite_err)?
184            as usize;
185
186        if count == 0 {
187            return Ok(());
188        }
189
190        self.index
191            .reserve_capacity_and_threads(count, 1)
192            .map_err(|e| MCSError::MemoryError(format!("usearch reserve: {e}")))?;
193
194        let mut stmt = conn
195            .prepare("SELECT entity_id, dims, blob, model FROM vector_embedding")
196            .map_err(sqlite_err)?;
197
198        let rows = stmt
199            .query_map([], |row| {
200                let id: i64 = row.get(0)?;
201                let dims: i64 = row.get(1)?;
202                let blob: Vec<u8> = row.get(2)?;
203                let model: String = row.get(3)?;
204                Ok((id, dims, blob, model))
205            })
206            .map_err(sqlite_err)?;
207
208        for row in rows {
209            let (id, _row_dims, blob, _model) = row.map_err(sqlite_err)?;
210            let emb = parse_embedding_blob(&blob)?;
211            self.index
212                .add(id as u64, emb)
213                .map_err(|e| MCSError::MemoryError(format!("usearch add: {e}")))?;
214            self.count.fetch_add(1, Ordering::Relaxed);
215        }
216
217        if count > 0 {
218            self.load_names_from_entity_table(&conn)?;
219        }
220        Ok(())
221    }
222
223    fn load_names_from_entity_table(&self, conn: &Connection) -> Result<()> {
224        let mut stmt = conn
225            .prepare("SELECT id, name FROM entity WHERE flags = 0")
226            .map_err(sqlite_err)?;
227        let rows = stmt
228            .query_map([], |row| {
229                let id: i64 = row.get(0)?;
230                let name: String = row.get(1)?;
231                Ok((id, name))
232            })
233            .map_err(sqlite_err)?;
234
235        self.name_to_id.clear();
236        self.id_to_name.clear();
237
238        for row in rows {
239            let (id, name) = row.map_err(sqlite_err)?;
240            self.name_to_id.insert(name.clone(), id);
241            self.id_to_name.insert(id, name);
242        }
243        Ok(())
244    }
245
246    fn get_entity_id_and_name(&self, conn: &Connection, entity_name: &str) -> Result<Option<(EntityId, String)>> {
247        if let Some(entry) = self.name_to_id.get(entity_name) {
248            let id = *entry;
249            let name = entity_name.to_string();
250            return Ok(Some((id, name)));
251        }
252        let h = crate::kg::name_hash(entity_name);
253        let mut stmt = conn
254            .prepare_cached(
255                "SELECT id, name FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
256            )
257            .map_err(sqlite_err)?;
258        match stmt.query_row(params![h, entity_name], |row| {
259            let id: i64 = row.get(0)?;
260            let name: String = row.get(1)?;
261            Ok((id, name))
262        }) {
263            Ok(tup) => {
264                self.name_to_id.insert(tup.1.clone(), tup.0);
265                self.id_to_name.insert(tup.0, tup.1.clone());
266                Ok(Some(tup))
267            }
268            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
269            Err(e) => Err(sqlite_err(e)),
270        }
271    }
272
273    pub fn upsert_embedding(&self, entity_name: &str, embedding: &[f32], model: &str) -> Result<()> {
274        if embedding.len() != self.dims as usize {
275            return Err(MCSError::InvalidParams(format!(
276                "Embedding dimension mismatch: got {}, expected {}",
277                embedding.len(),
278                self.dims
279            )));
280        }
281
282        let conn = self.db.lock();
283        let entity = self
284            .get_entity_id_and_name(&conn, entity_name)?
285            .ok_or_else(|| {
286                MCSError::InvalidParams(format!("Entity '{entity_name}' not found in KG"))
287            })?;
288        let entity_id = entity.0;
289
290        let total = self.count.load(Ordering::Relaxed);
291        self.index
292            .reserve_capacity_and_threads(total.saturating_add(1), 1)
293            .map_err(|e| MCSError::MemoryError(format!("usearch reserve: {e}")))?;
294        let existed = self
295            .index
296            .remove(entity_id as u64)
297            .unwrap_or(0) > 0;
298        self.index
299            .add(entity_id as u64, embedding)
300            .map_err(|e| MCSError::MemoryError(format!("usearch add: {e}")))?;
301
302        self.name_to_id
303            .insert(entity_name.to_string(), entity_id);
304        self.id_to_name.insert(entity_id, entity_name.to_string());
305
306        let blob = serialize_embedding(embedding);
307        let now = std::time::SystemTime::now()
308            .duration_since(std::time::UNIX_EPOCH)
309            .unwrap_or_default()
310            .as_micros() as i64;
311
312        conn.execute(
313            "INSERT OR REPLACE INTO vector_embedding (entity_id, dims, blob, model, created_us) VALUES (?1, ?2, ?3, ?4, ?5)",
314            params![entity_id, self.dims, blob, model, now],
315        )
316        .map_err(sqlite_err)?;
317
318        if !existed {
319            self.count.fetch_add(1, Ordering::Relaxed);
320        }
321        Ok(())
322    }
323
324    pub fn delete_embedding(&self, entity_name: &str) -> Result<bool> {
325        let conn = self.db.lock();
326        let entity_id = match self.name_to_id.get(entity_name) {
327            Some(entry) => *entry,
328            None => {
329                return Ok(false);
330            }
331        };
332
333        self.index
334            .remove(entity_id as u64)
335            .map_err(|e| MCSError::MemoryError(format!("usearch remove: {e}")))?;
336
337        self.name_to_id.remove(entity_name);
338        self.id_to_name.remove(&entity_id);
339
340        conn.execute(
341            "DELETE FROM vector_embedding WHERE entity_id = ?1",
342            params![entity_id],
343        )
344        .map_err(sqlite_err)?;
345
346        {
347            let mut g = self.graph.write();
348            if let Some(nx) = self.node_map.get(&entity_id) {
349                g.remove_node(*nx);
350                self.node_map.remove(&entity_id);
351            }
352        }
353
354        self.count.fetch_sub(1, Ordering::Relaxed);
355        Ok(true)
356    }
357
358    pub fn search_embeddings(
359        &self,
360        query: &[f32],
361        top_k: usize,
362    ) -> Result<Vec<(EntityId, f32)>> {
363        if self.count.load(Ordering::Relaxed) == 0 {
364            return Ok(Vec::new());
365        }
366        let top_k = top_k.clamp(1, 100);
367        let matches = self
368            .index
369            .search(query, top_k)
370            .map_err(|e| MCSError::MemoryError(format!("usearch search: {e}")))?;
371
372        let cap = matches.keys.len().min(matches.distances.len());
373        let mut results = Vec::with_capacity(cap);
374        for i in 0..cap {
375            let id = matches.keys[i] as EntityId;
376            let dist = matches.distances[i];
377            results.push((id, dist));
378        }
379        Ok(results)
380    }
381
382    pub fn search_entities_json(
383        &self,
384        query: &[f32],
385        top_k: usize,
386        entity_type_filter: Option<&str>,
387    ) -> Result<String> {
388        let results = self.search_embeddings(query, top_k)?;
389        if results.is_empty() {
390            return Ok(r#"{"results":[],"count":0}"#.to_string());
391        }
392
393        let conn = self.db.lock();
394        let mut out = String::with_capacity(128 + results.len() * 64);
395        out.push_str(r#"{"results":["#);
396        let mut first = true;
397        let mut actual_count = 0usize;
398
399        for &(id, dist) in &results {
400            let name = self
401                .id_to_name
402                .get(&id)
403                .map(|r| r.value().clone())
404                .or_else(|| {
405                    conn.query_row(
406                        "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
407                        params![id],
408                        |row| row.get::<_, String>(0),
409                    )
410                    .ok()
411                });
412
413            let name = match name {
414                Some(n) => n,
415                None => continue,
416            };
417
418            if let Some(filter_type) = entity_type_filter {
419                let actual_type: Option<String> = conn
420                    .query_row(
421                        "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
422                        params![id],
423                        |row| row.get(0),
424                    )
425                    .ok();
426                match actual_type {
427                    Some(t) if t == filter_type => {}
428                    _ => continue,
429                }
430            }
431
432            if !first {
433                out.push(',');
434            }
435            first = false;
436
437            let etype: String = conn
438                .query_row(
439                    "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
440                    params![id],
441                    |row| row.get(0),
442                )
443                .unwrap_or_default();
444
445            out.push_str(r#"{"name":"#);
446            push_json_str(&mut out, &name);
447            out.push_str(r#","entityType":"#);
448            push_json_str(&mut out, &etype);
449            write_f32(&mut out, dist);
450            out.push('}');
451            actual_count += 1;
452        }
453
454        out.push_str(r#"],"count":"#);
455        out.push_str(&actual_count.to_string());
456        out.push('}');
457        Ok(out)
458    }
459
460    pub fn build_search_response_json(&self, results: &[(EntityId, f32)]) -> String {
461        let mut out = String::with_capacity(128 + results.len() * 64);
462        out.push_str(r#"{"results":["#);
463        for (i, &(id, dist)) in results.iter().enumerate() {
464            if i > 0 {
465                out.push(',');
466            }
467            out.push_str(r#"{"entityId":"#);
468            out.push_str(&id.to_string());
469            out.push_str(r#","distance":"#);
470            write_f32(&mut out, dist);
471            out.push('}');
472        }
473        out.push_str(r#"],"count":"#);
474        out.push_str(&results.len().to_string());
475        out.push('}');
476        out
477    }
478
479    pub fn rebuild_graph_cache(&self) -> Result<()> {
480        let conn = self.db.lock();
481
482        let mut ent_stmt = conn
483            .prepare("SELECT entity_id FROM vector_embedding")
484            .map_err(sqlite_err)?;
485        let ids: Vec<EntityId> = ent_stmt
486            .query_map([], |r| r.get::<_, i64>(0))
487            .map_err(sqlite_err)?
488            .filter_map(|r| r.ok())
489            .collect();
490
491        let mut g = StableGraph::<EntityId, (), Directed, u32>::with_capacity(ids.len(), 0);
492        let nm = DashMap::new();
493
494        for &id in &ids {
495            let nx = g.add_node(id);
496            nm.insert(id, nx);
497        }
498
499        if !ids.is_empty() {
500            let placeholders: Vec<String> = ids.iter().map(|_| "?".to_string()).collect();
501            let sql = format!(
502                "SELECT from_id, to_id FROM relation WHERE from_id IN ({}) AND to_id IN ({})",
503                placeholders.join(","),
504                placeholders.join(",")
505            );
506            let mut rel_stmt = conn.prepare(&sql).map_err(sqlite_err)?;
507
508            let mut param_values: Vec<&dyn rusqlite::types::ToSql> = Vec::with_capacity(ids.len() * 2);
509            for id in &ids {
510                param_values.push(id as &dyn rusqlite::types::ToSql);
511            }
512            for id in &ids {
513                param_values.push(id as &dyn rusqlite::types::ToSql);
514            }
515
516            let rel_rows = rel_stmt
517                .query_map(param_values.as_slice(), |row| {
518                    let from: i64 = row.get(0)?;
519                    let to: i64 = row.get(1)?;
520                    Ok((from, to))
521                })
522                .map_err(sqlite_err)?;
523
524            for rel in rel_rows {
525                let (from, to) = rel.map_err(sqlite_err)?;
526                if let (Some(f_nx), Some(t_nx)) = (nm.get(&from), nm.get(&to))
527                    && g.find_edge(*f_nx, *t_nx).is_none()
528                {
529                    g.add_edge(*f_nx, *t_nx, ());
530                }
531            }
532        }
533
534        *self.graph.write() = g;
535        self.node_map.clear();
536        for entry in nm.iter() {
537            self.node_map.insert(*entry.key(), *entry.value());
538        }
539
540        Ok(())
541    }
542
543    pub fn graph_node_count(&self) -> usize {
544        self.node_map.len()
545    }
546
547    pub fn graph_edge_count(&self) -> usize {
548        self.graph.read().edge_count()
549    }
550
551    pub fn get_entity_type(&self, entity_id: EntityId) -> Result<Option<String>> {
552        let conn = self.db.lock();
553        let etype = conn
554            .query_row(
555                "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
556                params![entity_id],
557                |row| row.get(0),
558            )
559            .ok();
560        Ok(etype)
561    }
562
563    pub fn count(&self) -> usize {
564        self.count.load(Ordering::Relaxed)
565    }
566
567    pub const fn dims(&self) -> u32 {
568        self.dims
569    }
570
571    pub fn name_to_id(&self) -> &DashMap<String, EntityId> {
572        &self.name_to_id
573    }
574
575    pub fn id_to_name(&self) -> &DashMap<EntityId, String> {
576        &self.id_to_name
577    }
578}
579
/// Append `,"score":<val>` to `buf`, with `val` printed to six decimal places.
///
/// JSON cannot represent NaN or infinity. Non-finite values are therefore
/// written as `null`, which keeps the surrounding document parseable.
fn write_f32(buf: &mut String, val: f32) {
    use std::fmt::Write;
    if val.is_finite() {
        write!(buf, r#","score":{val:.6}"#).expect("writing to a String cannot fail");
    } else {
        buf.push_str(r#","score":null"#);
    }
}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Durability;
    use crate::kg::GraphHandle;
    use crate::types::Entity;
    use std::num::NonZeroUsize;

    /// A KG handle and a vector store sharing one temporary SQLite file.
    struct TestEnv {
        kg: GraphHandle,
        vs: VectorStore,
        _dir: tempfile::TempDir,
    }

    fn setup(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, 268435456, lru, 4).unwrap();
        let vs = VectorStore::new(&db_path, dims).unwrap();
        TestEnv {
            kg,
            vs,
            _dir: dir,
        }
    }

    fn create_test_entity(kg: &GraphHandle, name: &str, etype: &str) {
        kg.create_entities(&[Entity {
            name: name.into(),
            entity_type: etype.into(),
            observations: vec!["test observation".into()],
        }])
        .unwrap();
    }

    fn make_embedding(dims: u32, value: f32) -> Vec<f32> {
        vec![value; dims as usize]
    }

    #[test]
    fn test_vector_upsert_and_search() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");

        // Cosine distance ignores magnitude: scaled copies of one vector tie at
        // distance 0. Orthogonal directions give a strict ordering.
        let emb_a = [1.0, 0.0, 0.0, 0.0];
        let emb_b = [0.0, 1.0, 0.0, 0.0];
        env.vs.upsert_embedding("alice", &emb_a, "test-model").unwrap();
        env.vs.upsert_embedding("bob", &emb_b, "test-model").unwrap();

        let results = env.vs.search_embeddings(&emb_a, 10).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].1 < results[1].1);
        let best = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(best.as_deref(), Some("alice"));
    }

    #[test]
    fn test_vector_delete_embedding() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let deleted = env.vs.delete_embedding("alice").unwrap();
        assert!(deleted);
        assert_eq!(env.vs.count(), 0);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_vector_upsert_nonexistent_entity() {
        let env = setup(4);
        let err = env.vs.upsert_embedding("nonexistent", &make_embedding(4, 1.0), "");
        assert!(err.is_err());
    }

    #[test]
    fn test_vector_dimension_mismatch() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let err = env.vs.upsert_embedding("alice", &make_embedding(8, 1.0), "");
        assert!(err.is_err());
    }

    #[test]
    fn test_vector_search_top_k() {
        let env = setup(4);
        for i in 0..5 {
            create_test_entity(&env.kg, &format!("e{i}"), "test");
            env.vs.upsert_embedding(&format!("e{i}"), &make_embedding(4, i as f32 * 0.2), "")
                .unwrap();
        }
        let results = env.vs.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_search_type_filter() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "acme", "organization");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("acme", &make_embedding(4, 0.95), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, Some("person")).unwrap();
        assert!(json.contains("alice"));
        assert!(!json.contains("acme"));
    }

    #[test]
    fn test_vector_blob_roundtrip() {
        let emb: Vec<f32> = vec![1.0, 2.5, -3.0, 0.0];
        let blob = serialize_embedding(&emb);
        let parsed = parse_embedding_blob(&blob).unwrap();
        assert_eq!(parsed.len(), emb.len());
        for (a, b) in parsed.iter().zip(emb.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_vector_scratch_buffer() {
        with_scratch(|buf| {
            buf.push(1.0);
            buf.push(2.0);
            assert_eq!(buf.len(), 2);
        });
        with_scratch(|buf| {
            assert!(buf.is_empty());
            buf.extend_from_slice(&[3.0, 4.0, 5.0]);
            assert_eq!(buf.len(), 3);
        });
    }

    #[test]
    fn test_vector_rebuild_graph_cache() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        create_test_entity(&env.kg, "charlie", "person");

        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.5), "").unwrap();
        env.vs.upsert_embedding("charlie", &make_embedding(4, 0.0), "").unwrap();

        env.kg
            .create_relations(&[crate::types::Relation {
                from: "alice".into(),
                to: "bob".into(),
                relation_type: "knows".into(),
            }])
            .unwrap();

        env.vs.rebuild_graph_cache().unwrap();
        assert_eq!(env.vs.graph_node_count(), 3);
        assert_eq!(env.vs.graph_edge_count(), 1);
    }

    #[test]
    fn test_vector_upsert_replace() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("alice", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let results = env.vs.search_embeddings(&make_embedding(4, 0.5), 10).unwrap();
        assert_eq!(results.len(), 1);
        let name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(name.as_deref(), Some("alice"));
    }

    #[test]
    fn test_vector_empty_store_search() {
        let env = setup(4);
        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert_eq!(json, r#"{"results":[],"count":0}"#);
    }

    #[test]
    fn test_vector_persistence_across_reopen() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("persist.db");
        let lru = NonZeroUsize::new(10000).unwrap();

        let kg = GraphHandle::new(&db_path, Durability::Async, 268435456, lru, 4).unwrap();
        kg.create_entities(&[Entity {
            name: "alice".into(),
            entity_type: "person".into(),
            observations: vec![],
        }])
        .unwrap();

        let vs1 = VectorStore::new(&db_path, 4).unwrap();
        vs1.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(vs1.count(), 1);
        drop(vs1);
        drop(kg);

        let kg2 = GraphHandle::new(&db_path, Durability::Async, 268435456, lru, 4).unwrap();
        let vs2 = VectorStore::new(&db_path, 4).unwrap();
        assert_eq!(vs2.count(), 1);

        let results = vs2.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
        drop(vs2);
        drop(kg2);
    }

    #[test]
    fn test_vector_search_json_format() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert!(json.contains("alice"));
        assert!(json.contains("person"));
        assert!(json.contains("score"));
        assert!(json.contains("count"));
    }

    #[test]
    fn test_vector_concurrent_upsert() {
        let env = setup(8);
        // Entities must exist before their embeddings can be upserted;
        // otherwise every upsert fails with "not found".
        for i in 0..4 {
            create_test_entity(&env.kg, &format!("thread_{i}"), "t");
        }
        let vs = Arc::new(env.vs);

        let threads: Vec<_> = (0..4)
            .map(|i| {
                let vs = Arc::clone(&vs);
                std::thread::spawn(move || {
                    // Distinct unit vectors, so no thread upserts a zero vector.
                    let mut emb = make_embedding(8, 0.0);
                    emb[i] = 1.0;
                    vs.upsert_embedding(&format!("thread_{i}"), &emb, "")
                })
            })
            .collect();

        for t in threads {
            t.join().unwrap().unwrap();
        }
        assert_eq!(vs.count(), 4);
    }
}
831        for t in threads {
832            t.join().unwrap();
833        }
834    }
835}