1use std::path::Path;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use parking_lot::{Mutex, RwLock};
7use petgraph::graph::NodeIndex;
8use petgraph::stable_graph::StableGraph;
9use petgraph::Directed;
10use rusqlite::{params, Connection};
11use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
12use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
13
14use crate::errors::{MCSError, Result};
15use crate::ivf::{IvfFlatIndex, Metric as IvfMetric};
16use crate::kg::push_json_str;
17
18pub type EntityId = i64;
19
/// Nearest-neighbour index backend; the variant is chosen from
/// `VectorConfig::index_kind` when the store is built.
enum AnnIndex {
    /// usearch index held behind an `Arc`.
    Hnsw(Arc<Index>),
    /// Crate-local IVF-flat index, boxed.
    Ivf(Box<IvfFlatIndex>),
}
27
28impl AnnIndex {
29 fn capacity(&self) -> usize {
31 match self {
32 AnnIndex::Hnsw(i) => i.capacity(),
33 AnnIndex::Ivf(i) => i.len(),
34 }
35 }
36
37 fn reserve(&self, target: usize) -> Result<()> {
39 if let AnnIndex::Hnsw(i) = self {
40 i.reserve_capacity_and_threads(target, 1)
41 .map_err(|e| MCSError::MemoryError(format!("usearch reserve: {e}")))?;
42 }
43 Ok(())
44 }
45
46 fn add(&self, id: u64, vector: &[f32]) -> Result<()> {
48 match self {
49 AnnIndex::Hnsw(i) => i
50 .add(id, vector)
51 .map_err(|e| MCSError::MemoryError(format!("usearch add: {e}"))),
52 AnnIndex::Ivf(i) => i
53 .upsert(id, vector)
54 .map(|_| ())
55 .map_err(MCSError::MemoryError),
56 }
57 }
58
59 fn remove(&self, id: u64) -> Result<bool> {
61 match self {
62 AnnIndex::Hnsw(i) => i
63 .remove(id)
64 .map(|n| n > 0)
65 .map_err(|e| MCSError::MemoryError(format!("usearch remove: {e}"))),
66 AnnIndex::Ivf(i) => Ok(i.remove(id)),
67 }
68 }
69
70 fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Result<Vec<(u64, f32)>> {
72 match self {
73 AnnIndex::Hnsw(i) => {
74 let m = i
75 .search(query, top_k)
76 .map_err(|e| MCSError::MemoryError(format!("usearch search: {e}")))?;
77 let cap = m.keys.len().min(m.distances.len());
78 Ok((0..cap).map(|j| (m.keys[j], m.distances[j])).collect())
79 }
80 AnnIndex::Ivf(i) => i.search(query, top_k, nprobe).map_err(MCSError::MemoryError),
81 }
82 }
83
84 fn train(&self) -> Result<()> {
86 if let AnnIndex::Ivf(i) = self {
87 i.train().map_err(MCSError::MemoryError)?;
88 }
89 Ok(())
90 }
91
92 const fn kind(&self) -> IndexKind {
93 match self {
94 AnnIndex::Hnsw(_) => IndexKind::Hnsw,
95 AnnIndex::Ivf(_) => IndexKind::Ivf,
96 }
97 }
98
99 fn memory_bytes(&self) -> usize {
100 match self {
101 AnnIndex::Hnsw(i) => i.memory_usage(),
102 AnnIndex::Ivf(i) => i.memory_bytes(),
103 }
104 }
105
106 fn memory_breakdown(&self) -> (usize, usize) {
108 match self {
109 AnnIndex::Hnsw(i) => {
110 let s = i.memory_stats();
111 (
112 s.graph_allocated + s.graph_reserved,
113 s.vectors_allocated + s.vectors_reserved,
114 )
115 }
116 AnnIndex::Ivf(i) => (0, i.memory_bytes()),
117 }
118 }
119}
120
/// Fixed-size prefix written ahead of the raw `f32` payload of a stored
/// embedding blob.
#[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
#[repr(C)]
struct BlobHeader {
    /// Number of `f32` elements that follow the header.
    dims: u32,
}
126
/// Which index backend a store uses.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum IndexKind {
    /// usearch-backed index (the default).
    #[default]
    Hnsw,
    /// Crate-local IVF-flat index.
    Ivf,
}
138
/// Construction parameters for a [`VectorStore`].
#[derive(Clone, Copy, Debug)]
pub struct VectorConfig {
    /// Embedding dimensionality.
    pub dims: u32,
    /// Backend to build.
    pub index_kind: IndexKind,
    /// Distance metric (converted via `IvfMetric::from_usearch` for IVF).
    pub metric: MetricKind,
    /// Scalar quantization passed to usearch.
    pub quantization: ScalarKind,
    /// usearch `connectivity` option.
    pub connectivity: usize,
    /// usearch `expansion_add` option.
    pub expansion_add: usize,
    /// usearch `expansion_search` option.
    pub expansion_search: usize,
    /// `nlist` argument passed to the IVF index constructor.
    pub ivf_nlist: usize,
    /// `nprobe` argument passed to the IVF index constructor and to searches.
    pub ivf_nprobe: usize,
}
163
164impl VectorConfig {
165 pub const fn new(dims: u32) -> Self {
167 Self {
168 dims,
169 index_kind: IndexKind::Hnsw,
170 metric: MetricKind::Cos,
171 quantization: ScalarKind::F32,
172 connectivity: 16,
173 expansion_add: 200,
174 expansion_search: 50,
175 ivf_nlist: 256,
176 ivf_nprobe: 8,
177 }
178 }
179}
180
/// Embedding store: a SQLite table of blobs plus an in-memory index, with
/// name/id caches and a cached relation graph.
pub struct VectorStore {
    /// Cache of entity name -> id.
    pub name_to_id: Arc<DashMap<String, EntityId>>,
    /// Cache of entity id -> name.
    pub id_to_name: Arc<DashMap<EntityId, String>>,

    /// Cached relation graph; replaced wholesale by `rebuild_graph_cache`.
    pub(crate) graph: Arc<RwLock<StableGraph<EntityId, (), Directed, u32>>>,
    /// Entity id -> node index in `graph`.
    pub(crate) node_map: Arc<DashMap<EntityId, NodeIndex<u32>>>,

    /// In-memory nearest-neighbour index.
    index: AnnIndex,
    /// SQLite connection, serialized behind a mutex.
    pub(crate) db: Mutex<Connection>,

    /// Embedding dimensionality enforced by `upsert_embedding`.
    pub dims: u32,
    /// Number of vectors tracked in `index`.
    pub count: AtomicUsize,
    /// `nprobe` passed to every search (used by the IVF backend only).
    ivf_nprobe: usize,

    /// Path the SQLite connection was opened from.
    pub db_path: std::path::PathBuf,
}
198
199fn sqlite_err(e: rusqlite::Error) -> MCSError {
200 MCSError::IoError(std::io::Error::other(e))
201}
202
thread_local! {
    /// Per-thread reusable `f32` buffer handed out by `with_scratch`.
    static SCRATCH: std::cell::RefCell<Vec<f32>> = const {
        std::cell::RefCell::new(Vec::new())
    };
}

/// Runs `f` with an empty scratch vector and returns its result.
///
/// The vector is the calling thread's reusable buffer, cleared before `f`
/// sees it. If `f` itself calls `with_scratch`, the inner call gets a fresh
/// temporary vector instead of panicking on a double borrow of the buffer.
pub fn with_scratch<R>(f: impl FnOnce(&mut Vec<f32>) -> R) -> R {
    SCRATCH.with(|cell| match cell.try_borrow_mut() {
        Ok(mut buf) => {
            buf.clear();
            f(&mut buf)
        }
        // The buffer is already lent out further up this thread's stack.
        Err(_) => f(&mut Vec::new()),
    })
}
216
217fn serialize_embedding(emb: &[f32]) -> Vec<u8> {
218 let header = BlobHeader {
219 dims: emb.len() as u32,
220 };
221 let f32_bytes: &[u8] = unsafe {
222 std::slice::from_raw_parts(emb.as_ptr() as *const u8, emb.len() * 4)
223 };
224 let mut bytes = Vec::with_capacity(4 + f32_bytes.len());
225 bytes.extend_from_slice(header.as_bytes());
226 bytes.extend_from_slice(f32_bytes);
227 bytes
228}
229
230fn parse_embedding_blob(blob: &[u8]) -> Result<&[f32]> {
231 let (header, rest) = BlobHeader::ref_from_prefix(blob)
232 .map_err(|_| MCSError::MemoryError("Invalid blob header".into()))?;
233 let count = header.dims as usize;
234 let bytes = rest
235 .get(..count * 4)
236 .ok_or_else(|| MCSError::MemoryError("Blob data too short".into()))?;
237 let emb = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, count) };
238 Ok(emb)
239}
240
241impl VectorStore {
    /// Opens the store at `db_path` using `VectorConfig::new(dims)`.
    pub fn new(db_path: &Path, dims: u32) -> Result<Self> {
        Self::with_config(db_path, &VectorConfig::new(dims))
    }
246
247 pub fn with_config(db_path: &Path, cfg: &VectorConfig) -> Result<Self> {
249 let dims = cfg.dims;
250 let conn = Connection::open(db_path).map_err(sqlite_err)?;
251 conn.busy_timeout(std::time::Duration::from_secs(5))
252 .map_err(sqlite_err)?;
253 conn.execute_batch(
254 "PRAGMA journal_mode = WAL;
255 PRAGMA synchronous = NORMAL;
256 PRAGMA temp_store = MEMORY;
257 CREATE TABLE IF NOT EXISTS vector_embedding (
258 entity_id INTEGER PRIMARY KEY,
259 dims INTEGER NOT NULL,
260 blob BLOB NOT NULL,
261 model TEXT NOT NULL DEFAULT '',
262 created_us INTEGER NOT NULL
263 );",
264 )
265 .map_err(sqlite_err)?;
266
267 let index = match cfg.index_kind {
268 IndexKind::Hnsw => {
269 let index_opts = IndexOptions {
270 dimensions: dims as usize,
271 metric: cfg.metric,
272 quantization: cfg.quantization,
273 connectivity: cfg.connectivity,
274 expansion_add: cfg.expansion_add,
275 expansion_search: cfg.expansion_search,
276 multi: false,
277 };
278 let index = Index::new(&index_opts)
279 .map_err(|e| MCSError::MemoryError(format!("usearch init: {e}")))?;
280 AnnIndex::Hnsw(Arc::new(index))
281 }
282 IndexKind::Ivf => AnnIndex::Ivf(Box::new(IvfFlatIndex::new(
283 dims as usize,
284 IvfMetric::from_usearch(cfg.metric),
285 cfg.ivf_nlist,
286 cfg.ivf_nprobe,
287 ))),
288 };
289
290 let name_to_id = Arc::new(DashMap::new());
291 let id_to_name = Arc::new(DashMap::new());
292 let graph = Arc::new(RwLock::new(StableGraph::<EntityId, (), Directed, u32>::new()));
293 let node_map = Arc::new(DashMap::new());
294 let db = Mutex::new(conn);
295
296 let store = Self {
297 name_to_id,
298 id_to_name,
299 graph,
300 node_map,
301 index,
302 db,
303 dims,
304 count: AtomicUsize::new(0),
305 ivf_nprobe: cfg.ivf_nprobe,
306 db_path: db_path.to_path_buf(),
307 };
308 store.load_existing()?;
309
310 Ok(store)
311 }
312
313 fn load_existing(&self) -> Result<()> {
314 let conn = self.db.lock();
315 let count: usize = conn
316 .query_row("SELECT COUNT(*) FROM vector_embedding", [], |r| {
317 r.get::<_, i64>(0)
318 })
319 .map_err(sqlite_err)?
320 as usize;
321
322 if count == 0 {
323 return Ok(());
324 }
325
326 self.index.reserve(count)?;
327
328 let mut stmt = conn
329 .prepare("SELECT entity_id, dims, blob, model FROM vector_embedding")
330 .map_err(sqlite_err)?;
331
332 let rows = stmt
333 .query_map([], |row| {
334 let id: i64 = row.get(0)?;
335 let dims: i64 = row.get(1)?;
336 let blob: Vec<u8> = row.get(2)?;
337 let model: String = row.get(3)?;
338 Ok((id, dims, blob, model))
339 })
340 .map_err(sqlite_err)?;
341
342 for row in rows {
343 let (id, _row_dims, blob, _model) = row.map_err(sqlite_err)?;
344 let emb = parse_embedding_blob(&blob)?;
345 self.index.add(id as u64, emb)?;
346 self.count.fetch_add(1, Ordering::Relaxed);
347 }
348
349 self.index.train()?;
352
353 self.load_names_from_entity_table(&conn)?;
354 Ok(())
355 }
356
357 fn load_names_from_entity_table(&self, conn: &Connection) -> Result<()> {
358 let mut stmt = conn
359 .prepare("SELECT id, name FROM entity WHERE flags = 0")
360 .map_err(sqlite_err)?;
361 let rows = stmt
362 .query_map([], |row| {
363 let id: i64 = row.get(0)?;
364 let name: String = row.get(1)?;
365 Ok((id, name))
366 })
367 .map_err(sqlite_err)?;
368
369 self.name_to_id.clear();
370 self.id_to_name.clear();
371
372 for row in rows {
373 let (id, name) = row.map_err(sqlite_err)?;
374 self.name_to_id.insert(name.clone(), id);
375 self.id_to_name.insert(id, name);
376 }
377 Ok(())
378 }
379
380 fn get_entity_id_and_name(&self, conn: &Connection, entity_name: &str) -> Result<Option<(EntityId, String)>> {
381 if let Some(entry) = self.name_to_id.get(entity_name) {
382 let id = *entry;
383 let name = entity_name.to_string();
384 return Ok(Some((id, name)));
385 }
386 let h = crate::kg::name_hash(entity_name);
387 let mut stmt = conn
388 .prepare_cached(
389 "SELECT id, name FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
390 )
391 .map_err(sqlite_err)?;
392 match stmt.query_row(params![h, entity_name], |row| {
393 let id: i64 = row.get(0)?;
394 let name: String = row.get(1)?;
395 Ok((id, name))
396 }) {
397 Ok(tup) => {
398 self.name_to_id.insert(tup.1.clone(), tup.0);
399 self.id_to_name.insert(tup.0, tup.1.clone());
400 Ok(Some(tup))
401 }
402 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
403 Err(e) => Err(sqlite_err(e)),
404 }
405 }
406
407 pub fn upsert_embedding(&self, entity_name: &str, embedding: &[f32], model: &str) -> Result<()> {
408 if embedding.len() != self.dims as usize {
409 return Err(MCSError::InvalidParams(format!(
410 "Embedding dimension mismatch: got {}, expected {}",
411 embedding.len(),
412 self.dims
413 )));
414 }
415
416 let conn = self.db.lock();
417 let entity = self
418 .get_entity_id_and_name(&conn, entity_name)?
419 .ok_or_else(|| {
420 MCSError::InvalidParams(format!("Entity '{entity_name}' not found in KG"))
421 })?;
422 let entity_id = entity.0;
423
424 let needed = self.count.load(Ordering::Relaxed).saturating_add(1);
427 if needed > self.index.capacity() {
428 const CHUNK: usize = 1024;
429 let target = needed.div_ceil(CHUNK).saturating_mul(CHUNK);
430 self.index.reserve(target)?;
431 }
432 let existed = self.index.remove(entity_id as u64).unwrap_or(false);
433 self.index.add(entity_id as u64, embedding)?;
434
435 self.name_to_id
436 .insert(entity_name.to_string(), entity_id);
437 self.id_to_name.insert(entity_id, entity_name.to_string());
438
439 let blob = serialize_embedding(embedding);
440 let now = std::time::SystemTime::now()
441 .duration_since(std::time::UNIX_EPOCH)
442 .unwrap_or_default()
443 .as_micros() as i64;
444
445 conn.execute(
446 "INSERT OR REPLACE INTO vector_embedding (entity_id, dims, blob, model, created_us) VALUES (?1, ?2, ?3, ?4, ?5)",
447 params![entity_id, self.dims, blob, model, now],
448 )
449 .map_err(sqlite_err)?;
450
451 if !existed {
452 self.count.fetch_add(1, Ordering::Relaxed);
453 }
454 Ok(())
455 }
456
457 pub fn delete_embedding(&self, entity_name: &str) -> Result<bool> {
458 let conn = self.db.lock();
459 let entity_id = match self.name_to_id.get(entity_name) {
460 Some(entry) => *entry,
461 None => {
462 return Ok(false);
463 }
464 };
465
466 self.index.remove(entity_id as u64)?;
467
468 self.name_to_id.remove(entity_name);
469 self.id_to_name.remove(&entity_id);
470
471 conn.execute(
472 "DELETE FROM vector_embedding WHERE entity_id = ?1",
473 params![entity_id],
474 )
475 .map_err(sqlite_err)?;
476
477 {
478 let mut g = self.graph.write();
479 if let Some(nx) = self.node_map.get(&entity_id) {
480 g.remove_node(*nx);
481 self.node_map.remove(&entity_id);
482 }
483 }
484
485 self.count.fetch_sub(1, Ordering::Relaxed);
486 Ok(true)
487 }
488
489 pub fn search_embeddings(
490 &self,
491 query: &[f32],
492 top_k: usize,
493 ) -> Result<Vec<(EntityId, f32)>> {
494 if self.count.load(Ordering::Relaxed) == 0 {
495 return Ok(Vec::new());
496 }
497 let top_k = top_k.clamp(1, 100);
498 let matches = self.index.search(query, top_k, Some(self.ivf_nprobe))?;
499 Ok(matches
500 .into_iter()
501 .map(|(id, dist)| (id as EntityId, dist))
502 .collect())
503 }
504
505 pub fn search_entities_json(
506 &self,
507 query: &[f32],
508 top_k: usize,
509 entity_type_filter: Option<&str>,
510 ) -> Result<String> {
511 let results = self.search_embeddings(query, top_k)?;
512 if results.is_empty() {
513 return Ok(r#"{"results":[],"count":0}"#.to_string());
514 }
515
516 let conn = self.db.lock();
517 let mut out = String::with_capacity(128 + results.len() * 64);
518 out.push_str(r#"{"results":["#);
519 let mut first = true;
520 let mut actual_count = 0usize;
521
522 for &(id, dist) in &results {
523 let name = self
524 .id_to_name
525 .get(&id)
526 .map(|r| r.value().clone())
527 .or_else(|| {
528 conn.query_row(
529 "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
530 params![id],
531 |row| row.get::<_, String>(0),
532 )
533 .ok()
534 });
535
536 let name = match name {
537 Some(n) => n,
538 None => continue,
539 };
540
541 let etype: String = conn
542 .query_row(
543 "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
544 params![id],
545 |row| row.get(0),
546 )
547 .unwrap_or_default();
548
549 if let Some(filter_type) = entity_type_filter
550 && etype != filter_type
551 {
552 continue;
553 }
554
555 if !first {
556 out.push(',');
557 }
558 first = false;
559
560 out.push_str(r#"{"name":"#);
561 push_json_str(&mut out, &name);
562 out.push_str(r#","entityType":"#);
563 push_json_str(&mut out, &etype);
564 write_f32(&mut out, dist);
565 out.push('}');
566 actual_count += 1;
567 }
568
569 out.push_str(r#"],"count":"#);
570 out.push_str(&actual_count.to_string());
571 out.push('}');
572 Ok(out)
573 }
574
575 pub fn build_search_response_json(&self, results: &[(EntityId, f32)]) -> String {
576 let mut out = String::with_capacity(128 + results.len() * 64);
577 out.push_str(r#"{"results":["#);
578 for (i, &(id, dist)) in results.iter().enumerate() {
579 if i > 0 {
580 out.push(',');
581 }
582 out.push_str(r#"{"entityId":"#);
583 out.push_str(&id.to_string());
584 out.push_str(r#","distance":"#);
585 write_f32(&mut out, dist);
586 out.push('}');
587 }
588 out.push_str(r#"],"count":"#);
589 out.push_str(&results.len().to_string());
590 out.push('}');
591 out
592 }
593
594 pub fn rebuild_graph_cache(&self) -> Result<()> {
595 let conn = self.db.lock();
596
597 let mut ent_stmt = conn
598 .prepare("SELECT entity_id FROM vector_embedding")
599 .map_err(sqlite_err)?;
600 let ids: Vec<EntityId> = ent_stmt
601 .query_map([], |r| r.get::<_, i64>(0))
602 .map_err(sqlite_err)?
603 .filter_map(|r| r.ok())
604 .collect();
605
606 let mut g = StableGraph::<EntityId, (), Directed, u32>::with_capacity(ids.len(), 0);
607 let nm = DashMap::new();
608
609 for &id in &ids {
610 let nx = g.add_node(id);
611 nm.insert(id, nx);
612 }
613
614 if !ids.is_empty() {
615 const BATCH_SIZE: usize = 5000;
616 for chunk in ids.chunks(BATCH_SIZE) {
617 let placeholders: Vec<String> = chunk.iter().map(|_| "?".to_string()).collect();
618 let sql = format!(
619 "SELECT from_id, to_id FROM relation WHERE from_id IN ({}) AND to_id IN ({})",
620 placeholders.join(","),
621 placeholders.join(",")
622 );
623 let mut rel_stmt = conn.prepare(&sql).map_err(sqlite_err)?;
624
625 let mut param_values: Vec<&dyn rusqlite::types::ToSql> = Vec::with_capacity(chunk.len() * 2);
626 for id in chunk {
627 param_values.push(id as &dyn rusqlite::types::ToSql);
628 }
629 for id in chunk {
630 param_values.push(id as &dyn rusqlite::types::ToSql);
631 }
632
633 let rel_rows = rel_stmt
634 .query_map(param_values.as_slice(), |row| {
635 let from: i64 = row.get(0)?;
636 let to: i64 = row.get(1)?;
637 Ok((from, to))
638 })
639 .map_err(sqlite_err)?;
640
641 for rel in rel_rows {
642 let (from, to) = rel.map_err(sqlite_err)?;
643 if let (Some(f_nx), Some(t_nx)) = (nm.get(&from), nm.get(&to))
644 && g.find_edge(*f_nx, *t_nx).is_none()
645 {
646 g.add_edge(*f_nx, *t_nx, ());
647 }
648 }
649 }
650 }
651
652 *self.graph.write() = g;
653 self.node_map.clear();
654 for entry in nm.iter() {
655 self.node_map.insert(*entry.key(), *entry.value());
656 }
657
658 Ok(())
659 }
660
    /// Number of entries in the entity-id -> node-index map.
    pub fn graph_node_count(&self) -> usize {
        self.node_map.len()
    }
664
    /// Number of edges in the cached graph (takes the graph read lock).
    pub fn graph_edge_count(&self) -> usize {
        self.graph.read().edge_count()
    }
668
669 pub fn get_entity_type(&self, entity_id: EntityId) -> Result<Option<String>> {
670 let conn = self.db.lock();
671 let etype = conn
672 .query_row(
673 "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
674 params![entity_id],
675 |row| row.get(0),
676 )
677 .ok();
678 Ok(etype)
679 }
680
    /// Current value of the vector counter (relaxed load).
    pub fn count(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }
684
    /// Embedding dimensionality of this store.
    pub const fn dims(&self) -> u32 {
        self.dims
    }
688
    /// Total memory reported by the index backend, in bytes.
    pub fn index_memory_bytes(&self) -> usize {
        self.index.memory_bytes()
    }
693
    /// `(graph_bytes, vector_bytes)` as reported by the index backend; the
    /// IVF backend reports `0` graph bytes.
    pub fn index_memory_breakdown(&self) -> (usize, usize) {
        self.index.memory_breakdown()
    }
699
    /// Capacity reported by the index backend (for IVF, its current length).
    pub fn index_capacity(&self) -> usize {
        self.index.capacity()
    }
705
    /// Which backend this store was built with.
    pub const fn index_kind(&self) -> IndexKind {
        self.index.kind()
    }
710
    /// Retrains the index; only the IVF backend does any work here.
    pub fn reindex(&self) -> Result<()> {
        self.index.train()
    }
717
    /// Resolves `name` to its entity id via the cache or the `entity` table;
    /// `Ok(None)` when no such entity exists.
    pub fn entity_id_of(&self, name: &str) -> Result<Option<EntityId>> {
        let conn = self.db.lock();
        Ok(self.get_entity_id_and_name(&conn, name)?.map(|(id, _)| id))
    }
723
724 pub fn get_embedding_by_id(&self, id: EntityId) -> Result<Option<Vec<f32>>> {
726 let conn = self.db.lock();
727 let blob: Option<Vec<u8>> = conn
728 .query_row(
729 "SELECT blob FROM vector_embedding WHERE entity_id = ?1",
730 params![id],
731 |r| r.get(0),
732 )
733 .ok();
734 match blob {
735 Some(b) => Ok(Some(parse_embedding_blob(&b)?.to_vec())),
736 None => Ok(None),
737 }
738 }
739
740 pub fn get_embedding_by_name(
742 &self,
743 name: &str,
744 ) -> Result<Option<(EntityId, Vec<f32>, String)>> {
745 let id = match self.entity_id_of(name)? {
746 Some(id) => id,
747 None => return Ok(None),
748 };
749 let conn = self.db.lock();
750 let row: Option<(Vec<u8>, String)> = conn
751 .query_row(
752 "SELECT blob, model FROM vector_embedding WHERE entity_id = ?1",
753 params![id],
754 |r| Ok((r.get(0)?, r.get(1)?)),
755 )
756 .ok();
757 match row {
758 Some((blob, model)) => Ok(Some((id, parse_embedding_blob(&blob)?.to_vec(), model))),
759 None => Ok(None),
760 }
761 }
762
763 pub fn resolve_name_type(&self, id: EntityId) -> (String, String) {
766 let conn = self.db.lock();
767 let name = self
768 .id_to_name
769 .get(&id)
770 .map(|r| r.value().clone())
771 .or_else(|| {
772 conn.query_row(
773 "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
774 params![id],
775 |row| row.get::<_, String>(0),
776 )
777 .ok()
778 })
779 .unwrap_or_default();
780 let etype: String = conn
781 .query_row(
782 "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
783 params![id],
784 |row| row.get(0),
785 )
786 .unwrap_or_default();
787 (name, etype)
788 }
789
790 pub fn search_resolved(
794 &self,
795 query: &[f32],
796 top_k: usize,
797 entity_type: Option<&str>,
798 exclude: &std::collections::HashSet<EntityId>,
799 ) -> Result<Vec<(EntityId, String, String, f32)>> {
800 let fetch = (top_k.saturating_mul(3) + exclude.len()).clamp(top_k, 100);
801 let raw = self.search_embeddings(query, fetch)?;
802 let mut out = Vec::with_capacity(top_k);
803 for (id, dist) in raw {
804 if exclude.contains(&id) {
805 continue;
806 }
807 let (name, etype) = self.resolve_name_type(id);
808 if name.is_empty() {
809 continue;
810 }
811 if let Some(ft) = entity_type
812 && etype != ft
813 {
814 continue;
815 }
816 out.push((id, name, etype, dist));
817 if out.len() >= top_k {
818 break;
819 }
820 }
821 Ok(out)
822 }
823
824 pub fn invalidate_entity_cache(&self, names: &[String]) {
825 for name in names {
826 if let Some((_, id)) = self.name_to_id.remove(name.as_str()) {
827 self.id_to_name.remove(&id);
828 }
829 }
830 }
831
    /// Borrows the name -> id cache.
    pub fn name_to_id(&self) -> &DashMap<String, EntityId> {
        &self.name_to_id
    }
835
    /// Borrows the id -> name cache.
    pub fn id_to_name(&self) -> &DashMap<EntityId, String> {
        &self.id_to_name
    }
839}
840
/// Appends `,"score":<val>` to `buf`, with `val` printed to six decimals.
///
/// NaN and infinities have no JSON number form, so they are written as `null`.
fn write_f32(buf: &mut String, val: f32) {
    use std::fmt::Write;
    if val.is_finite() {
        // Writing into a `String` cannot fail, so the result is ignored.
        let _ = write!(buf, r#","score":{val:.6}"#);
    } else {
        buf.push_str(r#","score":null"#);
    }
}
845
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kg::GraphHandle;
    use crate::config::{Durability, SqliteTuning};
    use crate::types::Entity;
    use std::num::NonZeroUsize;

    /// Knowledge graph and vector store sharing one temporary database file.
    struct TestEnv {
        kg: GraphHandle,
        vs: VectorStore,
        // Kept alive so the temporary directory outlives the handles.
        _dir: tempfile::TempDir,
    }

    /// Fresh environment with the default (HNSW) configuration.
    fn setup(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let vs = VectorStore::new(&db_path, dims).unwrap();
        TestEnv {
            kg,
            vs,
            _dir: dir,
        }
    }

    /// Fresh environment using the IVF backend with a small list count.
    fn setup_ivf(dims: u32) -> TestEnv {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let mut cfg = VectorConfig::new(dims);
        cfg.index_kind = IndexKind::Ivf;
        cfg.ivf_nlist = 4;
        cfg.ivf_nprobe = 4;
        let vs = VectorStore::with_config(&db_path, &cfg).unwrap();
        TestEnv {
            kg,
            vs,
            _dir: dir,
        }
    }

    /// Creates one entity with a single placeholder observation.
    fn create_test_entity(kg: &GraphHandle, name: &str, etype: &str) {
        kg.create_entities(&[Entity {
            name: name.into(),
            entity_type: etype.into(),
            observations: vec!["test observation".into()],
        }])
        .unwrap();
    }

    /// Embedding of `dims` components, all equal to `value`.
    fn make_embedding(dims: u32, value: f32) -> Vec<f32> {
        vec![value; dims as usize]
    }

    /// Two stored vectors come back ordered by distance.
    #[test]
    fn test_vector_upsert_and_search() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");

        let emb_a = make_embedding(4, 1.0);
        let emb_b = make_embedding(4, 0.1);
        env.vs.upsert_embedding("alice", &emb_a, "test-model").unwrap();
        env.vs.upsert_embedding("bob", &emb_b, "test-model").unwrap();

        let query = make_embedding(4, 1.0);
        let results = env.vs.search_embeddings(&query, 10).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].1 < results[1].1);
    }

    /// Deleting the only embedding empties the store.
    #[test]
    fn test_vector_delete_embedding() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let deleted = env.vs.delete_embedding("alice").unwrap();
        assert!(deleted);
        assert_eq!(env.vs.count(), 0);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert!(results.is_empty());
    }

    /// Upserting for an unknown entity is rejected.
    #[test]
    fn test_vector_upsert_nonexistent_entity() {
        let env = setup(4);
        let err = env.vs.upsert_embedding("nonexistent", &make_embedding(4, 1.0), "");
        assert!(err.is_err());
    }

    /// An embedding of the wrong length is rejected.
    #[test]
    fn test_vector_dimension_mismatch() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let err = env.vs.upsert_embedding("alice", &make_embedding(8, 1.0), "");
        assert!(err.is_err());
    }

    /// `top_k` limits the number of hits.
    #[test]
    fn test_vector_search_top_k() {
        let env = setup(4);
        for i in 0..5 {
            create_test_entity(&env.kg, &format!("e{i}"), "test");
            env.vs.upsert_embedding(&format!("e{i}"), &make_embedding(4, i as f32 * 0.2), "")
                .unwrap();
        }
        let results = env.vs.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    /// The type filter removes hits of other entity types.
    #[test]
    fn test_vector_search_type_filter() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "acme", "organization");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("acme", &make_embedding(4, 0.95), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, Some("person")).unwrap();
        assert!(json.contains("alice"));
        assert!(!json.contains("acme"));
    }

    /// Serialize followed by parse returns the original values.
    #[test]
    fn test_vector_blob_roundtrip() {
        let emb: Vec<f32> = vec![1.0, 2.5, -3.0, 0.0];
        let blob = serialize_embedding(&emb);
        let parsed = parse_embedding_blob(&blob).unwrap();
        assert_eq!(parsed.len(), emb.len());
        for (a, b) in parsed.iter().zip(emb.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    /// The scratch buffer is handed out empty on every call.
    #[test]
    fn test_vector_scratch_buffer() {
        with_scratch(|buf| {
            buf.push(1.0);
            buf.push(2.0);
            assert_eq!(buf.len(), 2);
        });
        with_scratch(|buf| {
            assert!(buf.is_empty());
            buf.extend_from_slice(&[3.0, 4.0, 5.0]);
            assert_eq!(buf.len(), 3);
        });
    }

    /// The rebuilt graph has one node per embedding and one edge per relation.
    #[test]
    fn test_vector_rebuild_graph_cache() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        create_test_entity(&env.kg, "charlie", "person");

        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.5), "").unwrap();
        env.vs.upsert_embedding("charlie", &make_embedding(4, 0.0), "").unwrap();

        env.kg
            .create_relations(&[crate::types::Relation {
                from: "alice".into(),
                to: "bob".into(),
                relation_type: "knows".into(),
            }])
            .unwrap();

        env.vs.rebuild_graph_cache().unwrap();
        assert_eq!(env.vs.graph_node_count(), 3);
        assert_eq!(env.vs.graph_edge_count(), 1);
    }

    /// A second upsert for the same entity replaces rather than adds.
    #[test]
    fn test_vector_upsert_replace() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("alice", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 1);

        let results = env.vs.search_embeddings(&make_embedding(4, 0.5), 10).unwrap();
        assert_eq!(results.len(), 1);
        let name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(name.as_deref(), Some("alice"));
    }

    /// Capacity is reserved in chunks and stays put while a chunk fills up.
    #[test]
    fn test_vector_index_capacity_grows_in_chunks() {
        let env = setup(4);
        assert_eq!(env.vs.count(), 0);

        create_test_entity(&env.kg, "e0", "t");
        env.vs.upsert_embedding("e0", &make_embedding(4, 0.0), "").unwrap();
        let cap_after_first = env.vs.index_capacity();
        assert!(cap_after_first >= 1024, "capacity {cap_after_first} < 1024");

        for i in 1..50 {
            let name = format!("e{i}");
            create_test_entity(&env.kg, &name, "t");
            env.vs.upsert_embedding(&name, &make_embedding(4, i as f32 * 0.01), "").unwrap();
        }
        assert_eq!(env.vs.count(), 50);
        assert_eq!(env.vs.index_capacity(), cap_after_first, "capacity changed mid-chunk");

        env.vs.upsert_embedding("e0", &make_embedding(4, 0.5), "").unwrap();
        assert_eq!(env.vs.count(), 50);
        assert_eq!(env.vs.index_capacity(), cap_after_first);

        assert!(env.vs.index_memory_bytes() > 0);
        let (graph_bytes, vec_bytes) = env.vs.index_memory_breakdown();
        assert!(graph_bytes + vec_bytes > 0);
    }

    /// Searching an empty store yields the fixed empty JSON document.
    #[test]
    fn test_vector_empty_store_search() {
        let env = setup(4);
        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert_eq!(json, r#"{"results":[],"count":0}"#);
    }

    /// Embeddings survive closing and reopening the database.
    #[test]
    fn test_vector_persistence_across_reopen() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("persist.db");
        let lru = NonZeroUsize::new(10000).unwrap();

        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        kg.create_entities(&[Entity {
            name: "alice".into(),
            entity_type: "person".into(),
            observations: vec![],
        }])
        .unwrap();

        let vs1 = VectorStore::new(&db_path, 4).unwrap();
        vs1.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();
        assert_eq!(vs1.count(), 1);
        drop(vs1);
        drop(kg);

        let kg2 = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let vs2 = VectorStore::new(&db_path, 4).unwrap();
        assert_eq!(vs2.count(), 1);

        let results = vs2.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
        drop(vs2);
        drop(kg2);
    }

    /// The JSON response carries the name, type, score and count fields.
    #[test]
    fn test_vector_search_json_format() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "").unwrap();

        let json = env.vs.search_entities_json(&make_embedding(4, 1.0), 10, None).unwrap();
        assert!(json.contains("alice"));
        assert!(json.contains("person"));
        assert!(json.contains("score"));
        assert!(json.contains("count"));
    }

    /// Upserts issued from several threads all land in the store.
    #[test]
    fn test_vector_concurrent_upsert() {
        let env = setup(8);
        // Create the entities before any worker starts; otherwise an upsert
        // can race entity creation and fail with "not found".
        for i in 0..4 {
            create_test_entity(&env.kg, &format!("thread_{i}"), "t");
        }
        let vs = Arc::new(env.vs);

        let threads: Vec<_> = (0..4)
            .map(|i| {
                let vs = Arc::clone(&vs);
                std::thread::spawn(move || {
                    let name = format!("thread_{i}");
                    vs.upsert_embedding(&name, &make_embedding(8, i as f32 * 0.25), "")
                        .unwrap();
                })
            })
            .collect();

        for t in threads {
            t.join().unwrap();
        }
        assert_eq!(vs.count(), 4);
    }

    /// Upsert, search and delete through the IVF backend.
    #[test]
    fn test_ivf_store_upsert_search_delete() {
        let env = setup_ivf(4);
        assert_eq!(env.vs.index_kind(), IndexKind::Ivf);
        create_test_entity(&env.kg, "alice", "person");
        create_test_entity(&env.kg, "bob", "person");
        env.vs.upsert_embedding("alice", &make_embedding(4, 1.0), "m").unwrap();
        env.vs.upsert_embedding("bob", &make_embedding(4, 0.1), "m").unwrap();
        assert_eq!(env.vs.count(), 2);

        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 2);
        let top_name = env.vs.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(top_name.as_deref(), Some("alice"));

        assert!(env.vs.delete_embedding("alice").unwrap());
        assert_eq!(env.vs.count(), 1);
        let results = env.vs.search_embeddings(&make_embedding(4, 1.0), 10).unwrap();
        assert_eq!(results.len(), 1);
    }

    /// An IVF store reloads its vectors and still finds the nearest one.
    #[test]
    fn test_ivf_persistence_and_reindex() {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("ivf.db");
        let lru = NonZeroUsize::new(10000).unwrap();
        let kg = GraphHandle::new(&db_path, Durability::Async, SqliteTuning::default(), lru, 4).unwrap();
        let mut cfg = VectorConfig::new(4);
        cfg.index_kind = IndexKind::Ivf;
        cfg.ivf_nlist = 3;
        cfg.ivf_nprobe = 3;

        {
            let vs = VectorStore::with_config(&db_path, &cfg).unwrap();
            for i in 0..12 {
                let name = format!("e{i}");
                create_test_entity(&kg, &name, "t");
                vs.upsert_embedding(&name, &make_embedding(4, i as f32 * 0.1), "").unwrap();
            }
            vs.reindex().unwrap();
            assert_eq!(vs.count(), 12);
        }

        let vs2 = VectorStore::with_config(&db_path, &cfg).unwrap();
        assert_eq!(vs2.count(), 12);
        let results = vs2.search_embeddings(&make_embedding(4, 0.0), 3).unwrap();
        assert!(!results.is_empty());
        let top = vs2.id_to_name.get(&results[0].0).map(|r| r.value().clone());
        assert_eq!(top.as_deref(), Some("e0"));
    }

    /// The by-id and by-name getters return what was stored.
    #[test]
    fn test_get_embedding_helpers() {
        let env = setup(4);
        create_test_entity(&env.kg, "alice", "person");
        let emb = vec![0.1, 0.2, 0.3, 0.4];
        env.vs.upsert_embedding("alice", &emb, "model-x").unwrap();

        let id = env.vs.entity_id_of("alice").unwrap().unwrap();
        let by_id = env.vs.get_embedding_by_id(id).unwrap().unwrap();
        assert_eq!(by_id, emb);

        let (got_id, got_emb, model) = env.vs.get_embedding_by_name("alice").unwrap().unwrap();
        assert_eq!(got_id, id);
        assert_eq!(got_emb, emb);
        assert_eq!(model, "model-x");

        assert!(env.vs.get_embedding_by_name("nobody").unwrap().is_none());
    }

    /// `search_resolved` honours both the exclusion set and the type filter.
    #[test]
    fn test_search_resolved_excludes_and_filters() {
        let env = setup(4);
        create_test_entity(&env.kg, "a", "doc");
        create_test_entity(&env.kg, "b", "doc");
        create_test_entity(&env.kg, "c", "note");
        env.vs.upsert_embedding("a", &make_embedding(4, 1.0), "").unwrap();
        env.vs.upsert_embedding("b", &make_embedding(4, 0.9), "").unwrap();
        env.vs.upsert_embedding("c", &make_embedding(4, 0.95), "").unwrap();

        let id_a = env.vs.entity_id_of("a").unwrap().unwrap();
        let mut exclude = std::collections::HashSet::new();
        exclude.insert(id_a);

        let rows = env.vs.search_resolved(&make_embedding(4, 1.0), 10, None, &exclude).unwrap();
        let names: Vec<&str> = rows.iter().map(|(_, n, _, _)| n.as_str()).collect();
        assert!(!names.contains(&"a"));
        assert!(names.contains(&"b") && names.contains(&"c"));

        let rows = env.vs.search_resolved(&make_embedding(4, 1.0), 10, Some("doc"), &exclude).unwrap();
        let names: Vec<&str> = rows.iter().map(|(_, n, _, _)| n.as_str()).collect();
        assert_eq!(names, vec!["b"]);
    }
}