1use std::collections::HashMap;
28use std::sync::Arc;
29
30use async_trait::async_trait;
31use entelix_core::{Error, ExecutionContext, Result};
32use parking_lot::RwLock;
33use uuid::Uuid;
34
35use crate::namespace::Namespace;
36use crate::traits::{Document, VectorFilter, VectorStore};
37
/// Vector store that keeps every document and embedding in process memory.
///
/// Cloning produces another handle onto the same underlying map (the `Arc`
/// is shared, not the data copied).
pub struct InMemoryVectorStore {
    /// Length every stored vector and every query vector must have.
    dimension: usize,
    /// Stored slots grouped by rendered namespace key, behind a shared
    /// read/write lock.
    inner: Arc<RwLock<HashMap<String, Vec<Slot>>>>,
}
46
/// One stored entry: a document, its embedding, and data cached for lookups.
#[derive(Clone, Debug)]
struct Slot {
    /// Identifier that `update` and `delete` match against.
    doc_id: String,
    /// The document as stored; writers always set its `doc_id` to `Some`.
    document: Document,
    /// Embedding vector supplied by the caller.
    vector: Vec<f32>,
    /// Euclidean norm of `vector`, computed once at write time so searches
    /// do not recompute it per query.
    norm: f32,
}
56
57impl InMemoryVectorStore {
58 #[must_use]
62 pub fn new(dimension: usize) -> Self {
63 Self {
64 dimension,
65 inner: Arc::new(RwLock::new(HashMap::new())),
66 }
67 }
68
69 #[must_use]
71 pub fn total_slots(&self) -> usize {
72 let guard = self.inner.read();
73 guard.values().map(Vec::len).sum()
74 }
75}
76
77impl Clone for InMemoryVectorStore {
78 fn clone(&self) -> Self {
79 Self {
80 dimension: self.dimension,
81 inner: Arc::clone(&self.inner),
82 }
83 }
84}
85
86impl std::fmt::Debug for InMemoryVectorStore {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 let guard = self.inner.read();
89 f.debug_struct("InMemoryVectorStore")
90 .field("dimension", &self.dimension)
91 .field("namespaces", &guard.len())
92 .field("total_slots", &guard.values().map(Vec::len).sum::<usize>())
93 .finish()
94 }
95}
96
/// Cosine similarity of `lhs` and `rhs`, given their precomputed Euclidean norms.
///
/// The dot product is taken over the pairwise-zipped components, so if the
/// slices differ in length only the common prefix contributes.
///
/// Returns `0.0` whenever the similarity is undefined:
/// * either norm is zero, or
/// * the computation is not finite (a norm or the dot product overflowed, or
///   a component was infinite or NaN).
///
/// The second rule keeps NaN out of the scores. A NaN score would be ordered
/// by its sign bit under `f32::total_cmp`, making rankings unpredictable.
fn cosine_similarity(lhs: &[f32], lhs_norm: f32, rhs: &[f32], rhs_norm: f32) -> f32 {
    if lhs_norm == 0.0 || rhs_norm == 0.0 {
        return 0.0;
    }
    let dot: f32 = lhs.iter().zip(rhs).map(|(a, b)| a * b).sum();
    let score = dot / (lhs_norm * rhs_norm);
    if score.is_finite() {
        score
    } else {
        0.0
    }
}
107
/// Euclidean (L2) norm of `v`: the square root of the sum of squared components.
fn vector_norm(v: &[f32]) -> f32 {
    let squared_sum: f32 = v.iter().map(|component| component * component).sum();
    squared_sum.sqrt()
}
111
112#[async_trait]
113impl VectorStore for InMemoryVectorStore {
114 fn dimension(&self) -> usize {
115 self.dimension
116 }
117
118 async fn add(
119 &self,
120 _ctx: &ExecutionContext,
121 ns: &Namespace,
122 document: Document,
123 vector: Vec<f32>,
124 ) -> Result<()> {
125 if vector.len() != self.dimension {
126 return Err(Error::invalid_request(format!(
127 "InMemoryVectorStore: vector dimension {} does not match index dimension {}",
128 vector.len(),
129 self.dimension
130 )));
131 }
132 let norm = vector_norm(&vector);
133 let doc_id = document
138 .doc_id
139 .clone()
140 .unwrap_or_else(|| Uuid::new_v4().to_string());
141 let stored_doc = Document {
142 doc_id: Some(doc_id.clone()),
143 ..document
144 };
145 let mut guard = self.inner.write();
146 guard.entry(ns.render()).or_default().push(Slot {
147 doc_id,
148 document: stored_doc,
149 vector,
150 norm,
151 });
152 Ok(())
153 }
154
155 async fn search(
156 &self,
157 _ctx: &ExecutionContext,
158 ns: &Namespace,
159 query_vector: &[f32],
160 top_k: usize,
161 ) -> Result<Vec<Document>> {
162 if query_vector.len() != self.dimension {
163 return Err(Error::invalid_request(format!(
164 "InMemoryVectorStore: query dimension {} does not match index dimension {}",
165 query_vector.len(),
166 self.dimension
167 )));
168 }
169 let q_norm = vector_norm(query_vector);
170 let key = ns.render();
171 let scored: Vec<(f32, Document)> = {
172 let guard = self.inner.read();
173 let Some(slots) = guard.get(&key) else {
174 return Ok(Vec::new());
175 };
176 let mut scored: Vec<(f32, Document)> = slots
177 .iter()
178 .map(|s| {
179 let score = cosine_similarity(query_vector, q_norm, &s.vector, s.norm);
180 let mut doc = s.document.clone();
181 doc.score = Some(score);
182 (score, doc)
183 })
184 .collect();
185 scored.sort_by(|a, b| b.0.total_cmp(&a.0));
186 scored.truncate(top_k);
187 scored
188 };
189 Ok(scored.into_iter().map(|(_, d)| d).collect())
190 }
191
192 async fn delete(&self, _ctx: &ExecutionContext, ns: &Namespace, doc_id: &str) -> Result<()> {
193 let key = ns.render();
194 let mut guard = self.inner.write();
195 if let Some(slots) = guard.get_mut(&key) {
196 slots.retain(|s| s.doc_id != doc_id);
197 }
198 Ok(())
199 }
200
201 async fn update(
202 &self,
203 _ctx: &ExecutionContext,
204 ns: &Namespace,
205 doc_id: &str,
206 document: Document,
207 vector: Vec<f32>,
208 ) -> Result<()> {
209 if vector.len() != self.dimension {
210 return Err(Error::invalid_request(format!(
211 "InMemoryVectorStore: vector dimension {} does not match index dimension {}",
212 vector.len(),
213 self.dimension
214 )));
215 }
216 let norm = vector_norm(&vector);
217 let stored_doc = Document {
218 doc_id: Some(doc_id.to_owned()),
219 ..document
220 };
221 let mut guard = self.inner.write();
222 let slots = guard.entry(ns.render()).or_default();
223 if let Some(slot) = slots.iter_mut().find(|s| s.doc_id == doc_id) {
224 slot.document = stored_doc;
225 slot.vector = vector;
226 slot.norm = norm;
227 } else {
228 return Err(Error::invalid_request(format!(
229 "InMemoryVectorStore::update: doc_id '{doc_id}' not found"
230 )));
231 }
232 Ok(())
233 }
234
235 async fn search_filtered(
236 &self,
237 _ctx: &ExecutionContext,
238 ns: &Namespace,
239 query_vector: &[f32],
240 top_k: usize,
241 filter: &VectorFilter,
242 ) -> Result<Vec<Document>> {
243 if query_vector.len() != self.dimension {
244 return Err(Error::invalid_request(format!(
245 "InMemoryVectorStore: query dimension {} does not match index dimension {}",
246 query_vector.len(),
247 self.dimension
248 )));
249 }
250 let q_norm = vector_norm(query_vector);
251 let key = ns.render();
252 let scored: Vec<(f32, Document)> = {
253 let guard = self.inner.read();
254 let Some(slots) = guard.get(&key) else {
255 return Ok(Vec::new());
256 };
257 let mut scored: Vec<(f32, Document)> = slots
258 .iter()
259 .filter(|s| evaluate_filter(filter, &s.document.metadata))
260 .map(|s| {
261 let score = cosine_similarity(query_vector, q_norm, &s.vector, s.norm);
262 let mut doc = s.document.clone();
263 doc.score = Some(score);
264 (score, doc)
265 })
266 .collect();
267 scored.sort_by(|a, b| b.0.total_cmp(&a.0));
268 scored.truncate(top_k);
269 scored
270 };
271 Ok(scored.into_iter().map(|(_, d)| d).collect())
272 }
273
274 async fn count(
275 &self,
276 _ctx: &ExecutionContext,
277 ns: &Namespace,
278 filter: Option<&VectorFilter>,
279 ) -> Result<usize> {
280 let key = ns.render();
281 let guard = self.inner.read();
282 let count = guard.get(&key).map_or(0, |slots| match filter {
283 None => slots.len(),
284 Some(f) => slots
285 .iter()
286 .filter(|s| evaluate_filter(f, &s.document.metadata))
287 .count(),
288 });
289 Ok(count)
290 }
291
292 async fn list(
293 &self,
294 _ctx: &ExecutionContext,
295 ns: &Namespace,
296 filter: Option<&VectorFilter>,
297 limit: usize,
298 offset: usize,
299 ) -> Result<Vec<Document>> {
300 let key = ns.render();
301 let guard = self.inner.read();
302 let Some(slots) = guard.get(&key) else {
303 return Ok(Vec::new());
304 };
305 let out = slots
306 .iter()
307 .filter(|s| match filter {
308 None => true,
309 Some(f) => evaluate_filter(f, &s.document.metadata),
310 })
311 .skip(offset)
312 .take(limit)
313 .map(|s| s.document.clone())
314 .collect();
315 Ok(out)
316 }
317}
318
319fn evaluate_filter(filter: &VectorFilter, metadata: &serde_json::Value) -> bool {
326 match filter {
327 VectorFilter::All => true,
328 VectorFilter::Eq { key, value } => lookup(metadata, key).is_some_and(|v| v == value),
329 VectorFilter::Lt { key, value } => {
330 compare_numeric(metadata, key, value, std::cmp::Ordering::Less, false)
331 }
332 VectorFilter::Lte { key, value } => {
333 compare_numeric(metadata, key, value, std::cmp::Ordering::Less, true)
334 }
335 VectorFilter::Gt { key, value } => {
336 compare_numeric(metadata, key, value, std::cmp::Ordering::Greater, false)
337 }
338 VectorFilter::Gte { key, value } => {
339 compare_numeric(metadata, key, value, std::cmp::Ordering::Greater, true)
340 }
341 VectorFilter::Range { key, min, max } => {
342 compare_numeric(metadata, key, min, std::cmp::Ordering::Greater, true)
343 && compare_numeric(metadata, key, max, std::cmp::Ordering::Less, true)
344 }
345 VectorFilter::In { key, values } => {
346 lookup(metadata, key).is_some_and(|v| values.contains(v))
347 }
348 VectorFilter::Exists { key } => lookup(metadata, key).is_some(),
349 VectorFilter::And(children) => children.iter().all(|c| evaluate_filter(c, metadata)),
350 VectorFilter::Or(children) => children.iter().any(|c| evaluate_filter(c, metadata)),
351 VectorFilter::Not(child) => !evaluate_filter(child, metadata),
352 }
353}
354
355fn lookup<'a>(value: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
359 let mut cursor = value;
360 for segment in key.split('.') {
361 cursor = cursor.as_object()?.get(segment)?;
362 }
363 Some(cursor)
364}
365
366fn compare_numeric(
372 metadata: &serde_json::Value,
373 key: &str,
374 rhs: &serde_json::Value,
375 direction: std::cmp::Ordering,
376 inclusive: bool,
377) -> bool {
378 let Some(lhs) = lookup(metadata, key).and_then(serde_json::Value::as_f64) else {
379 return false;
380 };
381 let Some(rhs) = rhs.as_f64() else {
382 return false;
383 };
384 let cmp = lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal);
385 if cmp == std::cmp::Ordering::Equal {
386 return inclusive;
387 }
388 cmp == direction
389}
390
#[cfg(test)]
// Test-only relaxations: unwrapping, exact float comparison and direct
// indexing keep the assertions short, and a panic is the desired failure mode.
#[allow(clippy::unwrap_used, clippy::float_cmp, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use entelix_core::TenantId;
    use serde_json::json;

    /// Namespace shared by most tests.
    fn ns() -> Namespace {
        Namespace::new(TenantId::new("acme")).with_scope("agent-a")
    }

    /// Fresh execution context; the in-memory store ignores it.
    fn ctx() -> ExecutionContext {
        ExecutionContext::new()
    }

    /// Builds a document with an explicit id, content and metadata.
    fn doc(id: &str, content: &str, metadata: serde_json::Value) -> Document {
        Document::new(content)
            .with_doc_id(id)
            .with_metadata(metadata)
    }

    /// Collects the ids of `docs` in order, skipping documents without one.
    fn ids(docs: &[Document]) -> Vec<&str> {
        docs.iter().filter_map(|d| d.doc_id.as_deref()).collect()
    }

    #[tokio::test]
    async fn add_then_search_returns_top_k_by_similarity() {
        let store = InMemoryVectorStore::new(3);
        let n = ns();
        store
            .add(
                &ctx(),
                &n,
                doc("a", "alpha", json!({})),
                vec![1.0, 0.0, 0.0],
            )
            .await
            .unwrap();
        store
            .add(&ctx(), &n, doc("b", "beta", json!({})), vec![0.0, 1.0, 0.0])
            .await
            .unwrap();
        store
            .add(
                &ctx(),
                &n,
                doc("c", "gamma", json!({})),
                vec![0.9, 0.1, 0.0],
            )
            .await
            .unwrap();
        let hits = store.search(&ctx(), &n, &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].doc_id.as_deref(), Some("a"));
        assert_eq!(hits[1].doc_id.as_deref(), Some("c"));
        assert!((hits[0].score.unwrap() - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn search_returns_empty_for_unknown_namespace() {
        let store = InMemoryVectorStore::new(2);
        let hits = store.search(&ctx(), &ns(), &[1.0, 0.0], 5).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn dimension_mismatch_is_invalid_request() {
        let store = InMemoryVectorStore::new(3);
        let err = store
            .add(&ctx(), &ns(), doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("dimension"));

        // The query path performs the same validation as the write path.
        let err = store
            .search(&ctx(), &ns(), &[1.0, 0.0], 1)
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("dimension"));
    }

    #[tokio::test]
    async fn delete_then_search_omits_deleted_doc() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        store.delete(&ctx(), &ns(), "a").await.unwrap();
        let hits = store.search(&ctx(), &ns(), &[1.0, 0.0], 5).await.unwrap();
        assert!(hits.is_empty());
        assert_eq!(store.total_slots(), 0);
    }

    #[tokio::test]
    async fn update_replaces_vector_atomically() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "v1", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        store
            .update(
                &ctx(),
                &ns(),
                "a",
                doc("a", "v2", json!({"version": 2})),
                vec![0.0, 1.0],
            )
            .await
            .unwrap();
        let hits = store.search(&ctx(), &ns(), &[0.0, 1.0], 1).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "v2");
        assert_eq!(hits[0].metadata["version"], 2);
    }

    #[tokio::test]
    async fn update_unknown_doc_returns_invalid_request() {
        let store = InMemoryVectorStore::new(2);
        let err = store
            .update(
                &ctx(),
                &ns(),
                "ghost",
                doc("ghost", "x", json!({})),
                vec![1.0, 0.0],
            )
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
        // A failed update must not store anything.
        assert_eq!(store.total_slots(), 0);
    }

    #[tokio::test]
    async fn search_filtered_honours_eq_filter() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(
                &ctx(),
                &ns(),
                doc("a", "x", json!({"category": "A"})),
                vec![1.0, 0.0],
            )
            .await
            .unwrap();
        store
            .add(
                &ctx(),
                &ns(),
                doc("b", "y", json!({"category": "B"})),
                vec![1.0, 0.0],
            )
            .await
            .unwrap();
        let filter = VectorFilter::Eq {
            key: "category".into(),
            value: json!("A"),
        };
        let hits = store
            .search_filtered(&ctx(), &ns(), &[1.0, 0.0], 5, &filter)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id.as_deref(), Some("a"));
    }

    #[tokio::test]
    async fn search_filtered_honours_range_and_negation() {
        let store = InMemoryVectorStore::new(2);
        for (id, score) in [("a", 5.0), ("b", 12.0), ("c", 25.0), ("d", 50.0)] {
            store
                .add(
                    &ctx(),
                    &ns(),
                    doc(id, "x", json!({"score": score})),
                    vec![1.0, 0.0],
                )
                .await
                .unwrap();
        }
        let in_range = VectorFilter::Range {
            key: "score".into(),
            min: json!(10.0),
            max: json!(30.0),
        };
        let hits = store
            .search_filtered(&ctx(), &ns(), &[1.0, 0.0], 10, &in_range)
            .await
            .unwrap();
        assert_eq!(hits.len(), 2);
        let inside = ids(&hits);
        assert!(inside.contains(&"b"));
        assert!(inside.contains(&"c"));

        let outside = VectorFilter::Not(Box::new(in_range));
        let hits = store
            .search_filtered(&ctx(), &ns(), &[1.0, 0.0], 10, &outside)
            .await
            .unwrap();
        assert_eq!(hits.len(), 2);
        // The negation must select exactly the complement of the range.
        let complement = ids(&hits);
        assert!(complement.contains(&"a"));
        assert!(complement.contains(&"d"));
    }

    #[tokio::test]
    async fn count_with_filter_returns_matching_subset() {
        let store = InMemoryVectorStore::new(2);
        for (id, cat) in [("a", "X"), ("b", "Y"), ("c", "X")] {
            store
                .add(
                    &ctx(),
                    &ns(),
                    doc(id, "x", json!({"cat": cat})),
                    vec![1.0, 0.0],
                )
                .await
                .unwrap();
        }
        assert_eq!(store.count(&ctx(), &ns(), None).await.unwrap(), 3);
        let only_x = VectorFilter::Eq {
            key: "cat".into(),
            value: json!("X"),
        };
        assert_eq!(store.count(&ctx(), &ns(), Some(&only_x)).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn list_paginates() {
        let store = InMemoryVectorStore::new(2);
        for i in 0..5 {
            store
                .add(
                    &ctx(),
                    &ns(),
                    doc(&format!("d{i}"), "x", json!({})),
                    vec![1.0, 0.0],
                )
                .await
                .unwrap();
        }
        let page = store.list(&ctx(), &ns(), None, 2, 1).await.unwrap();
        // Checking the ids (not just the length) proves `offset` is honoured
        // and that pages follow insertion order.
        assert_eq!(ids(&page), ["d1", "d2"]);
    }

    #[tokio::test]
    async fn add_batch_default_loops_through_add() {
        let store = InMemoryVectorStore::new(2);
        let items = vec![
            (doc("a", "x", json!({})), vec![1.0, 0.0]),
            (doc("b", "y", json!({})), vec![0.0, 1.0]),
        ];
        store.add_batch(&ctx(), &ns(), items).await.unwrap();
        assert_eq!(store.total_slots(), 2);
    }

    #[tokio::test]
    async fn namespaces_are_isolated() {
        let store = InMemoryVectorStore::new(2);
        let ns_a = Namespace::new(TenantId::new("acme")).with_scope("agent-a");
        let ns_b = Namespace::new(TenantId::new("acme")).with_scope("agent-b");
        store
            .add(&ctx(), &ns_a, doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        let hits_a = store.search(&ctx(), &ns_a, &[1.0, 0.0], 5).await.unwrap();
        let hits_b = store.search(&ctx(), &ns_b, &[1.0, 0.0], 5).await.unwrap();
        assert_eq!(hits_a.len(), 1);
        assert_eq!(hits_b.len(), 0);
    }
}