1use std::path::Path;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use parking_lot::{Mutex, RwLock};
7use petgraph::graph::NodeIndex;
8use petgraph::stable_graph::StableGraph;
9use petgraph::Directed;
10use rusqlite::{params, Connection};
11use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
12use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
13
14use crate::errors::{MCSError, Result};
15use crate::kg::push_json_str;
16
17pub type EntityId = i64;
18
/// Fixed-size prefix stored in front of every serialized embedding blob.
///
/// `serialize_embedding` writes this header followed by the raw `f32` values,
/// and `parse_embedding_blob` reads it back to know how many values to expect.
#[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
#[repr(C)]
struct BlobHeader {
    /// Number of `f32` values that follow the header.
    dims: u32,
}
24
/// Tuning parameters used by [`VectorStore::with_config`] to build the
/// in-memory vector index. Every field is forwarded to the matching
/// `IndexOptions` field.
#[derive(Clone, Copy, Debug)]
pub struct VectorConfig {
    /// Number of components in every embedding (forwarded as `dimensions`).
    pub dims: u32,
    /// Distance metric used by the index.
    pub metric: MetricKind,
    /// Scalar type the index stores vectors in.
    pub quantization: ScalarKind,
    /// Forwarded to `IndexOptions::connectivity`.
    pub connectivity: usize,
    /// Forwarded to `IndexOptions::expansion_add`.
    pub expansion_add: usize,
    /// Forwarded to `IndexOptions::expansion_search`.
    pub expansion_search: usize,
}
43
44impl VectorConfig {
45 pub const fn new(dims: u32) -> Self {
47 Self {
48 dims,
49 metric: MetricKind::Cos,
50 quantization: ScalarKind::F32,
51 connectivity: 16,
52 expansion_add: 200,
53 expansion_search: 50,
54 }
55 }
56}
57
/// Embedding store combining an in-memory vector index with a SQLite table
/// (`vector_embedding`) that persists the raw vectors.
pub struct VectorStore {
    /// Cache of entity name -> entity id, filled from the `entity` table.
    pub name_to_id: Arc<DashMap<String, EntityId>>,
    /// Reverse cache of entity id -> entity name.
    pub id_to_name: Arc<DashMap<EntityId, String>>,

    /// Directed graph over entities that have an embedding; rebuilt by
    /// `rebuild_graph_cache`.
    pub(crate) graph: Arc<RwLock<StableGraph<EntityId, (), Directed, u32>>>,
    /// Entity id -> node index inside `graph`.
    pub(crate) node_map: Arc<DashMap<EntityId, NodeIndex<u32>>>,

    /// Vector index; keys are entity ids cast to `u64`.
    pub index: Arc<Index>,
    /// SQLite connection; the mutex also serializes all mutating operations.
    pub(crate) db: Mutex<Connection>,

    /// Number of components every embedding must have.
    pub dims: u32,
    /// Number of vectors tracked as present in `index`.
    pub count: AtomicUsize,

    /// Path of the SQLite database this store was opened on.
    pub db_path: std::path::PathBuf,
}
73
74fn sqlite_err(e: rusqlite::Error) -> MCSError {
75 MCSError::IoError(std::io::Error::other(e))
76}
77
thread_local! {
    /// Per-thread reusable `f32` buffer handed out by [`with_scratch`].
    static SCRATCH: std::cell::RefCell<Vec<f32>> =
        const { std::cell::RefCell::new(Vec::new()) };
}

/// Runs `f` with this thread's scratch buffer and returns whatever `f` returns.
///
/// The buffer is emptied before `f` sees it; its allocation is kept between
/// calls so repeated use on one thread does not reallocate.
///
/// # Panics
///
/// Panics if called again from inside `f` on the same thread, because the
/// buffer is already mutably borrowed.
pub fn with_scratch<R>(f: impl FnOnce(&mut Vec<f32>) -> R) -> R {
    SCRATCH.with_borrow_mut(|scratch| {
        scratch.clear();
        f(scratch)
    })
}
91
92fn serialize_embedding(emb: &[f32]) -> Vec<u8> {
93 let header = BlobHeader {
94 dims: emb.len() as u32,
95 };
96 let f32_bytes: &[u8] = unsafe {
97 std::slice::from_raw_parts(emb.as_ptr() as *const u8, emb.len() * 4)
98 };
99 let mut bytes = Vec::with_capacity(4 + f32_bytes.len());
100 bytes.extend_from_slice(header.as_bytes());
101 bytes.extend_from_slice(f32_bytes);
102 bytes
103}
104
105fn parse_embedding_blob(blob: &[u8]) -> Result<&[f32]> {
106 let (header, rest) = BlobHeader::ref_from_prefix(blob)
107 .map_err(|_| MCSError::MemoryError("Invalid blob header".into()))?;
108 let count = header.dims as usize;
109 let bytes = rest
110 .get(..count * 4)
111 .ok_or_else(|| MCSError::MemoryError("Blob data too short".into()))?;
112 let emb = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, count) };
113 Ok(emb)
114}
115
116impl VectorStore {
117 pub fn new(db_path: &Path, dims: u32) -> Result<Self> {
119 Self::with_config(db_path, &VectorConfig::new(dims))
120 }
121
122 pub fn with_config(db_path: &Path, cfg: &VectorConfig) -> Result<Self> {
124 let dims = cfg.dims;
125 let conn = Connection::open(db_path).map_err(sqlite_err)?;
126 conn.busy_timeout(std::time::Duration::from_secs(5))
127 .map_err(sqlite_err)?;
128 conn.execute_batch(
129 "PRAGMA journal_mode = WAL;
130 PRAGMA synchronous = NORMAL;
131 PRAGMA temp_store = MEMORY;
132 CREATE TABLE IF NOT EXISTS vector_embedding (
133 entity_id INTEGER PRIMARY KEY,
134 dims INTEGER NOT NULL,
135 blob BLOB NOT NULL,
136 model TEXT NOT NULL DEFAULT '',
137 created_us INTEGER NOT NULL
138 );",
139 )
140 .map_err(sqlite_err)?;
141
142 let index_opts = IndexOptions {
143 dimensions: dims as usize,
144 metric: cfg.metric,
145 quantization: cfg.quantization,
146 connectivity: cfg.connectivity,
147 expansion_add: cfg.expansion_add,
148 expansion_search: cfg.expansion_search,
149 multi: false,
150 };
151 let index = Index::new(&index_opts)
152 .map_err(|e| MCSError::MemoryError(format!("usearch init: {e}")))?;
153 let index = Arc::new(index);
154
155 let name_to_id = Arc::new(DashMap::new());
156 let id_to_name = Arc::new(DashMap::new());
157 let graph = Arc::new(RwLock::new(StableGraph::<EntityId, (), Directed, u32>::new()));
158 let node_map = Arc::new(DashMap::new());
159 let db = Mutex::new(conn);
160
161 let store = Self {
162 name_to_id,
163 id_to_name,
164 graph,
165 node_map,
166 index,
167 db,
168 dims,
169 count: AtomicUsize::new(0),
170 db_path: db_path.to_path_buf(),
171 };
172 store.load_existing()?;
173
174 Ok(store)
175 }
176
177 fn load_existing(&self) -> Result<()> {
178 let conn = self.db.lock();
179 let count: usize = conn
180 .query_row("SELECT COUNT(*) FROM vector_embedding", [], |r| {
181 r.get::<_, i64>(0)
182 })
183 .map_err(sqlite_err)?
184 as usize;
185
186 if count == 0 {
187 return Ok(());
188 }
189
190 self.index
191 .reserve_capacity_and_threads(count, 1)
192 .map_err(|e| MCSError::MemoryError(format!("usearch reserve: {e}")))?;
193
194 let mut stmt = conn
195 .prepare("SELECT entity_id, dims, blob, model FROM vector_embedding")
196 .map_err(sqlite_err)?;
197
198 let rows = stmt
199 .query_map([], |row| {
200 let id: i64 = row.get(0)?;
201 let dims: i64 = row.get(1)?;
202 let blob: Vec<u8> = row.get(2)?;
203 let model: String = row.get(3)?;
204 Ok((id, dims, blob, model))
205 })
206 .map_err(sqlite_err)?;
207
208 for row in rows {
209 let (id, _row_dims, blob, _model) = row.map_err(sqlite_err)?;
210 let emb = parse_embedding_blob(&blob)?;
211 self.index
212 .add(id as u64, emb)
213 .map_err(|e| MCSError::MemoryError(format!("usearch add: {e}")))?;
214 self.count.fetch_add(1, Ordering::Relaxed);
215 }
216
217 if count > 0 {
218 self.load_names_from_entity_table(&conn)?;
219 }
220 Ok(())
221 }
222
223 fn load_names_from_entity_table(&self, conn: &Connection) -> Result<()> {
224 let mut stmt = conn
225 .prepare("SELECT id, name FROM entity WHERE flags = 0")
226 .map_err(sqlite_err)?;
227 let rows = stmt
228 .query_map([], |row| {
229 let id: i64 = row.get(0)?;
230 let name: String = row.get(1)?;
231 Ok((id, name))
232 })
233 .map_err(sqlite_err)?;
234
235 self.name_to_id.clear();
236 self.id_to_name.clear();
237
238 for row in rows {
239 let (id, name) = row.map_err(sqlite_err)?;
240 self.name_to_id.insert(name.clone(), id);
241 self.id_to_name.insert(id, name);
242 }
243 Ok(())
244 }
245
246 fn get_entity_id_and_name(&self, conn: &Connection, entity_name: &str) -> Result<Option<(EntityId, String)>> {
247 if let Some(entry) = self.name_to_id.get(entity_name) {
248 let id = *entry;
249 let name = entity_name.to_string();
250 return Ok(Some((id, name)));
251 }
252 let h = crate::kg::name_hash(entity_name);
253 let mut stmt = conn
254 .prepare_cached(
255 "SELECT id, name FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
256 )
257 .map_err(sqlite_err)?;
258 match stmt.query_row(params![h, entity_name], |row| {
259 let id: i64 = row.get(0)?;
260 let name: String = row.get(1)?;
261 Ok((id, name))
262 }) {
263 Ok(tup) => {
264 self.name_to_id.insert(tup.1.clone(), tup.0);
265 self.id_to_name.insert(tup.0, tup.1.clone());
266 Ok(Some(tup))
267 }
268 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
269 Err(e) => Err(sqlite_err(e)),
270 }
271 }
272
273 pub fn upsert_embedding(&self, entity_name: &str, embedding: &[f32], model: &str) -> Result<()> {
274 if embedding.len() != self.dims as usize {
275 return Err(MCSError::InvalidParams(format!(
276 "Embedding dimension mismatch: got {}, expected {}",
277 embedding.len(),
278 self.dims
279 )));
280 }
281
282 let conn = self.db.lock();
283 let entity = self
284 .get_entity_id_and_name(&conn, entity_name)?
285 .ok_or_else(|| {
286 MCSError::InvalidParams(format!("Entity '{entity_name}' not found in KG"))
287 })?;
288 let entity_id = entity.0;
289
290 let total = self.count.load(Ordering::Relaxed);
291 self.index
292 .reserve_capacity_and_threads(total.saturating_add(1), 1)
293 .map_err(|e| MCSError::MemoryError(format!("usearch reserve: {e}")))?;
294 let existed = self
295 .index
296 .remove(entity_id as u64)
297 .unwrap_or(0) > 0;
298 self.index
299 .add(entity_id as u64, embedding)
300 .map_err(|e| MCSError::MemoryError(format!("usearch add: {e}")))?;
301
302 self.name_to_id
303 .insert(entity_name.to_string(), entity_id);
304 self.id_to_name.insert(entity_id, entity_name.to_string());
305
306 let blob = serialize_embedding(embedding);
307 let now = std::time::SystemTime::now()
308 .duration_since(std::time::UNIX_EPOCH)
309 .unwrap_or_default()
310 .as_micros() as i64;
311
312 conn.execute(
313 "INSERT OR REPLACE INTO vector_embedding (entity_id, dims, blob, model, created_us) VALUES (?1, ?2, ?3, ?4, ?5)",
314 params![entity_id, self.dims, blob, model, now],
315 )
316 .map_err(sqlite_err)?;
317
318 if !existed {
319 self.count.fetch_add(1, Ordering::Relaxed);
320 }
321 Ok(())
322 }
323
324 pub fn delete_embedding(&self, entity_name: &str) -> Result<bool> {
325 let conn = self.db.lock();
326 let entity_id = match self.name_to_id.get(entity_name) {
327 Some(entry) => *entry,
328 None => {
329 return Ok(false);
330 }
331 };
332
333 self.index
334 .remove(entity_id as u64)
335 .map_err(|e| MCSError::MemoryError(format!("usearch remove: {e}")))?;
336
337 self.name_to_id.remove(entity_name);
338 self.id_to_name.remove(&entity_id);
339
340 conn.execute(
341 "DELETE FROM vector_embedding WHERE entity_id = ?1",
342 params![entity_id],
343 )
344 .map_err(sqlite_err)?;
345
346 {
347 let mut g = self.graph.write();
348 if let Some(nx) = self.node_map.get(&entity_id) {
349 g.remove_node(*nx);
350 self.node_map.remove(&entity_id);
351 }
352 }
353
354 self.count.fetch_sub(1, Ordering::Relaxed);
355 Ok(true)
356 }
357
358 pub fn search_embeddings(
359 &self,
360 query: &[f32],
361 top_k: usize,
362 ) -> Result<Vec<(EntityId, f32)>> {
363 if self.count.load(Ordering::Relaxed) == 0 {
364 return Ok(Vec::new());
365 }
366 let top_k = top_k.clamp(1, 100);
367 let matches = self
368 .index
369 .search(query, top_k)
370 .map_err(|e| MCSError::MemoryError(format!("usearch search: {e}")))?;
371
372 let cap = matches.keys.len().min(matches.distances.len());
373 let mut results = Vec::with_capacity(cap);
374 for i in 0..cap {
375 let id = matches.keys[i] as EntityId;
376 let dist = matches.distances[i];
377 results.push((id, dist));
378 }
379 Ok(results)
380 }
381
382 pub fn search_entities_json(
383 &self,
384 query: &[f32],
385 top_k: usize,
386 entity_type_filter: Option<&str>,
387 ) -> Result<String> {
388 let results = self.search_embeddings(query, top_k)?;
389 if results.is_empty() {
390 return Ok(r#"{"results":[],"count":0}"#.to_string());
391 }
392
393 let conn = self.db.lock();
394 let mut out = String::with_capacity(128 + results.len() * 64);
395 out.push_str(r#"{"results":["#);
396 let mut first = true;
397 let mut actual_count = 0usize;
398
399 for &(id, dist) in &results {
400 let name = self
401 .id_to_name
402 .get(&id)
403 .map(|r| r.value().clone())
404 .or_else(|| {
405 conn.query_row(
406 "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
407 params![id],
408 |row| row.get::<_, String>(0),
409 )
410 .ok()
411 });
412
413 let name = match name {
414 Some(n) => n,
415 None => continue,
416 };
417
418 if let Some(filter_type) = entity_type_filter {
419 let actual_type: Option<String> = conn
420 .query_row(
421 "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
422 params![id],
423 |row| row.get(0),
424 )
425 .ok();
426 match actual_type {
427 Some(t) if t == filter_type => {}
428 _ => continue,
429 }
430 }
431
432 if !first {
433 out.push(',');
434 }
435 first = false;
436
437 let etype: String = conn
438 .query_row(
439 "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
440 params![id],
441 |row| row.get(0),
442 )
443 .unwrap_or_default();
444
445 out.push_str(r#"{"name":"#);
446 push_json_str(&mut out, &name);
447 out.push_str(r#","entityType":"#);
448 push_json_str(&mut out, &etype);
449 write_f32(&mut out, dist);
450 out.push('}');
451 actual_count += 1;
452 }
453
454 out.push_str(r#"],"count":"#);
455 out.push_str(&actual_count.to_string());
456 out.push('}');
457 Ok(out)
458 }
459
460 pub fn build_search_response_json(&self, results: &[(EntityId, f32)]) -> String {
461 let mut out = String::with_capacity(128 + results.len() * 64);
462 out.push_str(r#"{"results":["#);
463 for (i, &(id, dist)) in results.iter().enumerate() {
464 if i > 0 {
465 out.push(',');
466 }
467 out.push_str(r#"{"entityId":"#);
468 out.push_str(&id.to_string());
469 out.push_str(r#","distance":"#);
470 write_f32(&mut out, dist);
471 out.push('}');
472 }
473 out.push_str(r#"],"count":"#);
474 out.push_str(&results.len().to_string());
475 out.push('}');
476 out
477 }
478
479 pub fn rebuild_graph_cache(&self) -> Result<()> {
480 let conn = self.db.lock();
481
482 let mut ent_stmt = conn
483 .prepare("SELECT entity_id FROM vector_embedding")
484 .map_err(sqlite_err)?;
485 let ids: Vec<EntityId> = ent_stmt
486 .query_map([], |r| r.get::<_, i64>(0))
487 .map_err(sqlite_err)?
488 .filter_map(|r| r.ok())
489 .collect();
490
491 let mut g = StableGraph::<EntityId, (), Directed, u32>::with_capacity(ids.len(), 0);
492 let nm = DashMap::new();
493
494 for &id in &ids {
495 let nx = g.add_node(id);
496 nm.insert(id, nx);
497 }
498
499 if !ids.is_empty() {
500 let placeholders: Vec<String> = ids.iter().map(|_| "?".to_string()).collect();
501 let sql = format!(
502 "SELECT from_id, to_id FROM relation WHERE from_id IN ({}) AND to_id IN ({})",
503 placeholders.join(","),
504 placeholders.join(",")
505 );
506 let mut rel_stmt = conn.prepare(&sql).map_err(sqlite_err)?;
507
508 let mut param_values: Vec<&dyn rusqlite::types::ToSql> = Vec::with_capacity(ids.len() * 2);
509 for id in &ids {
510 param_values.push(id as &dyn rusqlite::types::ToSql);
511 }
512 for id in &ids {
513 param_values.push(id as &dyn rusqlite::types::ToSql);
514 }
515
516 let rel_rows = rel_stmt
517 .query_map(param_values.as_slice(), |row| {
518 let from: i64 = row.get(0)?;
519 let to: i64 = row.get(1)?;
520 Ok((from, to))
521 })
522 .map_err(sqlite_err)?;
523
524 for rel in rel_rows {
525 let (from, to) = rel.map_err(sqlite_err)?;
526 if let (Some(f_nx), Some(t_nx)) = (nm.get(&from), nm.get(&to))
527 && g.find_edge(*f_nx, *t_nx).is_none()
528 {
529 g.add_edge(*f_nx, *t_nx, ());
530 }
531 }
532 }
533
534 *self.graph.write() = g;
535 self.node_map.clear();
536 for entry in nm.iter() {
537 self.node_map.insert(*entry.key(), *entry.value());
538 }
539
540 Ok(())
541 }
542
543 pub fn graph_node_count(&self) -> usize {
544 self.node_map.len()
545 }
546
547 pub fn graph_edge_count(&self) -> usize {
548 self.graph.read().edge_count()
549 }
550
551 pub fn get_entity_type(&self, entity_id: EntityId) -> Result<Option<String>> {
552 let conn = self.db.lock();
553 let etype = conn
554 .query_row(
555 "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
556 params![entity_id],
557 |row| row.get(0),
558 )
559 .ok();
560 Ok(etype)
561 }
562
563 pub fn count(&self) -> usize {
564 self.count.load(Ordering::Relaxed)
565 }
566
567 pub const fn dims(&self) -> u32 {
568 self.dims
569 }
570
571 pub fn name_to_id(&self) -> &DashMap<String, EntityId> {
572 &self.name_to_id
573 }
574
575 pub fn id_to_name(&self) -> &DashMap<EntityId, String> {
576 &self.id_to_name
577 }
578}
579
/// Appends `,"score":<val>` to `buf`, printing `val` with six decimals.
///
/// NaN and infinities have no JSON number representation, so they are written
/// as `null` to keep the surrounding document parseable.
fn write_f32(buf: &mut String, val: f32) {
    use std::fmt::Write;

    if val.is_finite() {
        // Formatting into a `String` cannot fail, so the result is ignored.
        let _ = write!(buf, r#","score":{:.6}"#, val);
    } else {
        buf.push_str(r#","score":null"#);
    }
}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Durability;
    use crate::kg::GraphHandle;
    use crate::types::Entity;
    use std::num::NonZeroUsize;

    /// Knowledge graph and vector store sharing one temporary database file.
    struct TestEnv {
        kg: GraphHandle,
        vs: VectorStore,
        // Held only so the directory outlives the handles above.
        _dir: tempfile::TempDir,
    }

    /// Creates a fresh database with a graph handle and a `dims`-dimensional store.
    fn setup(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, 268435456, lru, 4).unwrap();
        let vs = VectorStore::new(&db_path, dims).unwrap();
        TestEnv { kg, vs, _dir: dir }
    }

    /// Inserts one entity with a single placeholder observation.
    fn create_test_entity(kg: &GraphHandle, name: &str, etype: &str) {
        kg.create_entities(&[Entity {
            name: name.into(),
            entity_type: etype.into(),
            observations: vec!["test observation".into()],
        }])
        .unwrap();
    }

    /// Returns a vector of `dims` components, all equal to `value`.
    fn make_embedding(dims: u32, value: f32) -> Vec<f32> {
        vec![value; dims as usize]
    }

    #[test]
    fn test_vector_upsert_and_search() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");

        let emb_a = make_embedding(4, 1.0);
        let emb_b = make_embedding(4, 0.1);
        env.vs.upsert_embedding("alice", &emb_a, "test-model").unwrap();
        env.vs.upsert_embedding("bob", &emb_b, "test-model").unwrap();

        let query = make_embedding(4, 1.0);
        let results = env.vs.search_embeddings(&query, 10).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].1 < results[1].1);
    }

    #[test]
    fn test_vector_delete_embedding() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let deleted = env.vs.delete_embedding("alice").unwrap();
        assert!(deleted);
        assert_eq!(env.vs.count(), 0);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_vector_upsert_nonexistent_entity() {
        let env = setup(4);
        let err = env.vs.upsert_embedding("nonexistent", &make_embedding(4, 1.0), "");
        assert!(err.is_err());
    }

    #[test]
    fn test_vector_dimension_mismatch() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let err = env.vs.upsert_embedding("alice", &make_embedding(8, 1.0), "");
        assert!(err.is_err());
    }

    #[test]
    fn test_vector_search_top_k() {
        let env = setup(4);
        for i in 0..5 {
            let name = format!("e{i}");
            create_test_entity(&env.kg, &name, "test");
            env.vs
                .upsert_embedding(&name, &make_embedding(4, i as f32 * 0.2), "")
                .unwrap();
        }
        let results = env.vs.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_search_type_filter() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "acme", "organization");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("acme", &make_embedding(4, 0.95), "").unwrap();

        let json = env
            .vs
            .search_entities_json(&make_embedding(4, 1.0), 10, Some("person"))
            .unwrap();
        assert!(json.contains("alice"));
        assert!(!json.contains("acme"));
    }

    #[test]
    fn test_vector_blob_roundtrip() {
        let emb: Vec<f32> = vec![1.0, 2.5, -3.0, 0.0];
        let blob = serialize_embedding(&emb);
        let parsed = parse_embedding_blob(&blob).unwrap();
        assert_eq!(parsed.len(), emb.len());
        for (a, b) in parsed.iter().zip(emb.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_vector_scratch_buffer() {
        with_scratch(|buf| {
            buf.push(1.0);
            buf.push(2.0);
            assert_eq!(buf.len(), 2);
        });
        // The second call must start from an empty buffer.
        with_scratch(|buf| {
            assert!(buf.is_empty());
            buf.extend_from_slice(&[3.0, 4.0, 5.0]);
            assert_eq!(buf.len(), 3);
        });
    }

    #[test]
    fn test_vector_rebuild_graph_cache() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        create_test_entity(&env.kg, "charlie", "person");

        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.5), "").unwrap();
        env.vs.upsert_embedding("charlie", &make_embedding(4, 0.0), "").unwrap();

        env.kg
            .create_relations(&[crate::types::Relation {
                from: "alice".into(),
                to: "bob".into(),
                relation_type: "knows".into(),
            }])
            .unwrap();

        env.vs.rebuild_graph_cache().unwrap();
        assert_eq!(env.vs.graph_node_count(), 3);
        assert_eq!(env.vs.graph_edge_count(), 1);
    }

    #[test]
    fn test_vector_upsert_replace() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("alice", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let results = env.vs.search_embeddings(&make_embedding(4, 0.5), 10).unwrap();
        assert_eq!(results.len(), 1);
        let name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(name.as_deref(), Some("alice"));
    }

    #[test]
    fn test_vector_empty_store_search() {
        let env = setup(4);
        let json = env
            .vs
            .search_entities_json(&make_embedding(4, 1.0), 10, None)
            .unwrap();
        assert_eq!(json, r#"{"results":[],"count":0}"#);
    }

    #[test]
    fn test_vector_persistence_across_reopen() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("persist.db");
        let lru = NonZeroUsize::new(10000).unwrap();

        let kg = GraphHandle::new(&db_path, Durability::Async, 268435456, lru, 4).unwrap();
        kg.create_entities(&[Entity {
            name: "alice".into(),
            entity_type: "person".into(),
            observations: vec![],
        }])
        .unwrap();

        let vs1 = VectorStore::new(&db_path, 4).unwrap();
        vs1.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(vs1.count(), 1);
        drop(vs1);
        drop(kg);

        // Reopening must restore the embedding from the database.
        let kg2 = GraphHandle::new(&db_path, Durability::Async, 268435456, lru, 4).unwrap();
        let vs2 = VectorStore::new(&db_path, 4).unwrap();
        assert_eq!(vs2.count(), 1);

        let results = vs2.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
        drop(vs2);
        drop(kg2);
    }

    #[test]
    fn test_vector_search_json_format() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();

        let json = env
            .vs
            .search_entities_json(&make_embedding(4, 1.0), 10, None)
            .unwrap();
        assert!(json.contains("alice"));
        assert!(json.contains("person"));
        assert!(json.contains("score"));
        assert!(json.contains("count"));
    }

    #[test]
    fn test_vector_concurrent_upsert() {
        let env = setup(8);
        // Entities must exist before the workers start; otherwise every upsert
        // can fail with "not found" and the test would exercise nothing.
        for i in 0..4 {
            create_test_entity(&env.kg, &format!("thread_{i}"), "t");
        }

        let vs = Arc::new(env.vs);
        let workers: Vec<_> = (0..4)
            .map(|i| {
                let vs = Arc::clone(&vs);
                std::thread::spawn(move || {
                    let name = format!("thread_{i}");
                    // A failure panics the worker and surfaces through `join`.
                    vs.upsert_embedding(&name, &make_embedding(8, i as f32 * 0.25), "")
                        .unwrap();
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(vs.count(), 4);
    }
}