Skip to main content

entelix_memory/
in_memory_vector_store.rs

1//! [`InMemoryVectorStore`] — concrete brute-force [`VectorStore`].
2//!
3//! In-process, namespace-scoped, cosine-similarity ranking over a
4//! linear scan. Designed for two real workloads:
5//!
6//! - **Tests / dev loops** — deterministic, zero I/O, no
7//!   external dependencies. Round-trips through the same
8//!   [`VectorStore`] surface a production backend (qdrant /
9//!   lancedb / pgvector) implements, so `SemanticMemory<E, V>`
10//!   wired to this store exercises the full pipeline end-to-end.
//! - **Small-corpus production** — under ~10K documents per
//!   namespace, a brute-force scan (one plain dot product per
//!   candidate) beats the operational complexity of a vector DB.
//!   The search hot path is a single read-locked `Vec` walk; writes
//!   briefly take the write lock.
16//!
17//! Above ~10K docs/namespace, swap in a companion `VectorStore` with
18//! an ANN index (HNSW / IVF). The trait surface is identical, so the
19//! swap is a one-line replacement at the [`SemanticMemory`]
20//! construction site.
21//!
22//! Filters: `search_filtered`, `count`, `list` all honour the
23//! [`VectorFilter`] taxonomy in full — no `LossyEncode` cases. The
24//! linear scan evaluates the predicate per row before the dot
25//! product, so filter selectivity directly reduces work.
26
27use std::collections::HashMap;
28use std::sync::Arc;
29
30use async_trait::async_trait;
31use entelix_core::{Error, ExecutionContext, Result};
32use parking_lot::RwLock;
33use uuid::Uuid;
34
35use crate::namespace::Namespace;
36use crate::traits::{Document, VectorFilter, VectorStore};
37
38/// In-process [`VectorStore`] backed by a per-namespace `Vec<Slot>`.
39///
40/// Cloning is cheap — internal state lives behind `Arc<RwLock<...>>`
41/// so multiple `SemanticMemory` instances can share one store.
42pub struct InMemoryVectorStore {
43    dimension: usize,
44    inner: Arc<RwLock<HashMap<String, Vec<Slot>>>>,
45}
46
47#[derive(Clone, Debug)]
48struct Slot {
49    doc_id: String,
50    document: Document,
51    vector: Vec<f32>,
52    /// Pre-computed `‖vector‖` so cosine similarity reduces to a
53    /// single dot product per candidate.
54    norm: f32,
55}
56
57impl InMemoryVectorStore {
58    /// Build an empty store fixed to `dimension`. Inserts whose
59    /// `vector.len()` differs surface
60    /// [`Error::InvalidRequest`] at `add` time.
61    #[must_use]
62    pub fn new(dimension: usize) -> Self {
63        Self {
64            dimension,
65            inner: Arc::new(RwLock::new(HashMap::new())),
66        }
67    }
68
69    /// Total slot count across every namespace. Useful for tests.
70    #[must_use]
71    pub fn total_slots(&self) -> usize {
72        let guard = self.inner.read();
73        guard.values().map(Vec::len).sum()
74    }
75}
76
77impl Clone for InMemoryVectorStore {
78    fn clone(&self) -> Self {
79        Self {
80            dimension: self.dimension,
81            inner: Arc::clone(&self.inner),
82        }
83    }
84}
85
86impl std::fmt::Debug for InMemoryVectorStore {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        let guard = self.inner.read();
89        f.debug_struct("InMemoryVectorStore")
90            .field("dimension", &self.dimension)
91            .field("namespaces", &guard.len())
92            .field("total_slots", &guard.values().map(Vec::len).sum::<usize>())
93            .finish()
94    }
95}
96
/// Cosine similarity of two equal-length vectors whose L2 norms the
/// caller already knows (`lhs_norm`, `rhs_norm`), so no `sqrt` is
/// needed per candidate. If either norm is zero the vector has no
/// direction, and the similarity is reported as `0.0` instead of
/// dividing by zero.
fn cosine_similarity(lhs: &[f32], lhs_norm: f32, rhs: &[f32], rhs_norm: f32) -> f32 {
    let degenerate = lhs_norm == 0.0 || rhs_norm == 0.0;
    if degenerate {
        return 0.0;
    }
    let dot = lhs.iter().zip(rhs).map(|(&x, &y)| x * y).sum::<f32>();
    let denominator = lhs_norm * rhs_norm;
    dot / denominator
}
107
/// Euclidean (L2) length of `v`: the square root of the sum of its
/// squared components.
fn vector_norm(v: &[f32]) -> f32 {
    let squared_sum: f32 = v.iter().map(|&component| component * component).sum();
    squared_sum.sqrt()
}
111
112#[async_trait]
113impl VectorStore for InMemoryVectorStore {
114    fn dimension(&self) -> usize {
115        self.dimension
116    }
117
118    async fn add(
119        &self,
120        _ctx: &ExecutionContext,
121        ns: &Namespace,
122        document: Document,
123        vector: Vec<f32>,
124    ) -> Result<()> {
125        if vector.len() != self.dimension {
126            return Err(Error::invalid_request(format!(
127                "InMemoryVectorStore: vector dimension {} does not match index dimension {}",
128                vector.len(),
129                self.dimension
130            )));
131        }
132        let norm = vector_norm(&vector);
133        // Backends typically mint a stable id at insertion time; we
134        // honour an operator-supplied `doc_id` and otherwise mint a
135        // fresh UUIDv4. Either way the slot carries a non-empty id
136        // so subsequent update/delete calls can address it.
137        let doc_id = document
138            .doc_id
139            .clone()
140            .unwrap_or_else(|| Uuid::new_v4().to_string());
141        let stored_doc = Document {
142            doc_id: Some(doc_id.clone()),
143            ..document
144        };
145        let mut guard = self.inner.write();
146        guard.entry(ns.render()).or_default().push(Slot {
147            doc_id,
148            document: stored_doc,
149            vector,
150            norm,
151        });
152        Ok(())
153    }
154
155    async fn search(
156        &self,
157        _ctx: &ExecutionContext,
158        ns: &Namespace,
159        query_vector: &[f32],
160        top_k: usize,
161    ) -> Result<Vec<Document>> {
162        if query_vector.len() != self.dimension {
163            return Err(Error::invalid_request(format!(
164                "InMemoryVectorStore: query dimension {} does not match index dimension {}",
165                query_vector.len(),
166                self.dimension
167            )));
168        }
169        let q_norm = vector_norm(query_vector);
170        let key = ns.render();
171        let scored: Vec<(f32, Document)> = {
172            let guard = self.inner.read();
173            let Some(slots) = guard.get(&key) else {
174                return Ok(Vec::new());
175            };
176            let mut scored: Vec<(f32, Document)> = slots
177                .iter()
178                .map(|s| {
179                    let score = cosine_similarity(query_vector, q_norm, &s.vector, s.norm);
180                    let mut doc = s.document.clone();
181                    doc.score = Some(score);
182                    (score, doc)
183                })
184                .collect();
185            scored.sort_by(|a, b| b.0.total_cmp(&a.0));
186            scored.truncate(top_k);
187            scored
188        };
189        Ok(scored.into_iter().map(|(_, d)| d).collect())
190    }
191
192    async fn delete(&self, _ctx: &ExecutionContext, ns: &Namespace, doc_id: &str) -> Result<()> {
193        let key = ns.render();
194        let mut guard = self.inner.write();
195        if let Some(slots) = guard.get_mut(&key) {
196            slots.retain(|s| s.doc_id != doc_id);
197        }
198        Ok(())
199    }
200
201    async fn update(
202        &self,
203        _ctx: &ExecutionContext,
204        ns: &Namespace,
205        doc_id: &str,
206        document: Document,
207        vector: Vec<f32>,
208    ) -> Result<()> {
209        if vector.len() != self.dimension {
210            return Err(Error::invalid_request(format!(
211                "InMemoryVectorStore: vector dimension {} does not match index dimension {}",
212                vector.len(),
213                self.dimension
214            )));
215        }
216        let norm = vector_norm(&vector);
217        let stored_doc = Document {
218            doc_id: Some(doc_id.to_owned()),
219            ..document
220        };
221        let mut guard = self.inner.write();
222        let slots = guard.entry(ns.render()).or_default();
223        if let Some(slot) = slots.iter_mut().find(|s| s.doc_id == doc_id) {
224            slot.document = stored_doc;
225            slot.vector = vector;
226            slot.norm = norm;
227        } else {
228            return Err(Error::invalid_request(format!(
229                "InMemoryVectorStore::update: doc_id '{doc_id}' not found"
230            )));
231        }
232        Ok(())
233    }
234
235    async fn search_filtered(
236        &self,
237        _ctx: &ExecutionContext,
238        ns: &Namespace,
239        query_vector: &[f32],
240        top_k: usize,
241        filter: &VectorFilter,
242    ) -> Result<Vec<Document>> {
243        if query_vector.len() != self.dimension {
244            return Err(Error::invalid_request(format!(
245                "InMemoryVectorStore: query dimension {} does not match index dimension {}",
246                query_vector.len(),
247                self.dimension
248            )));
249        }
250        let q_norm = vector_norm(query_vector);
251        let key = ns.render();
252        let scored: Vec<(f32, Document)> = {
253            let guard = self.inner.read();
254            let Some(slots) = guard.get(&key) else {
255                return Ok(Vec::new());
256            };
257            let mut scored: Vec<(f32, Document)> = slots
258                .iter()
259                .filter(|s| evaluate_filter(filter, &s.document.metadata))
260                .map(|s| {
261                    let score = cosine_similarity(query_vector, q_norm, &s.vector, s.norm);
262                    let mut doc = s.document.clone();
263                    doc.score = Some(score);
264                    (score, doc)
265                })
266                .collect();
267            scored.sort_by(|a, b| b.0.total_cmp(&a.0));
268            scored.truncate(top_k);
269            scored
270        };
271        Ok(scored.into_iter().map(|(_, d)| d).collect())
272    }
273
274    async fn count(
275        &self,
276        _ctx: &ExecutionContext,
277        ns: &Namespace,
278        filter: Option<&VectorFilter>,
279    ) -> Result<usize> {
280        let key = ns.render();
281        let guard = self.inner.read();
282        let count = guard.get(&key).map_or(0, |slots| match filter {
283            None => slots.len(),
284            Some(f) => slots
285                .iter()
286                .filter(|s| evaluate_filter(f, &s.document.metadata))
287                .count(),
288        });
289        Ok(count)
290    }
291
292    async fn list(
293        &self,
294        _ctx: &ExecutionContext,
295        ns: &Namespace,
296        filter: Option<&VectorFilter>,
297        limit: usize,
298        offset: usize,
299    ) -> Result<Vec<Document>> {
300        let key = ns.render();
301        let guard = self.inner.read();
302        let Some(slots) = guard.get(&key) else {
303            return Ok(Vec::new());
304        };
305        let out = slots
306            .iter()
307            .filter(|s| match filter {
308                None => true,
309                Some(f) => evaluate_filter(f, &s.document.metadata),
310            })
311            .skip(offset)
312            .take(limit)
313            .map(|s| s.document.clone())
314            .collect();
315        Ok(out)
316    }
317}
318
319/// Evaluate a [`VectorFilter`] against a document's metadata blob.
320/// Returns `true` when the predicate matches. The implementation
321/// matches the wire-level semantics every backend agrees on:
322/// - missing metadata key → predicate false (except [`VectorFilter::Not`]).
323/// - non-numeric operands on numeric variants → false.
324/// - equality is JSON-value equality (not relaxed coercion).
325fn evaluate_filter(filter: &VectorFilter, metadata: &serde_json::Value) -> bool {
326    match filter {
327        VectorFilter::All => true,
328        VectorFilter::Eq { key, value } => lookup(metadata, key).is_some_and(|v| v == value),
329        VectorFilter::Lt { key, value } => {
330            compare_numeric(metadata, key, value, std::cmp::Ordering::Less, false)
331        }
332        VectorFilter::Lte { key, value } => {
333            compare_numeric(metadata, key, value, std::cmp::Ordering::Less, true)
334        }
335        VectorFilter::Gt { key, value } => {
336            compare_numeric(metadata, key, value, std::cmp::Ordering::Greater, false)
337        }
338        VectorFilter::Gte { key, value } => {
339            compare_numeric(metadata, key, value, std::cmp::Ordering::Greater, true)
340        }
341        VectorFilter::Range { key, min, max } => {
342            compare_numeric(metadata, key, min, std::cmp::Ordering::Greater, true)
343                && compare_numeric(metadata, key, max, std::cmp::Ordering::Less, true)
344        }
345        VectorFilter::In { key, values } => {
346            lookup(metadata, key).is_some_and(|v| values.contains(v))
347        }
348        VectorFilter::Exists { key } => lookup(metadata, key).is_some(),
349        VectorFilter::And(children) => children.iter().all(|c| evaluate_filter(c, metadata)),
350        VectorFilter::Or(children) => children.iter().any(|c| evaluate_filter(c, metadata)),
351        VectorFilter::Not(child) => !evaluate_filter(child, metadata),
352    }
353}
354
355/// Look up a dotted metadata path. Each segment indexes one level
356/// of the JSON tree; non-object intermediates short-circuit to
357/// `None`.
358fn lookup<'a>(value: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
359    let mut cursor = value;
360    for segment in key.split('.') {
361        cursor = cursor.as_object()?.get(segment)?;
362    }
363    Some(cursor)
364}
365
366/// Numeric comparison helper. `direction` is the side of the ordering
367/// we want (Less means "lhs < rhs"); `inclusive` flips strict
368/// inequality to include equality. Returns false on non-numeric
369/// operands rather than coerce — a metadata field that stores `"42"`
370/// as a string is a schema bug; surface it by failing the predicate.
371fn compare_numeric(
372    metadata: &serde_json::Value,
373    key: &str,
374    rhs: &serde_json::Value,
375    direction: std::cmp::Ordering,
376    inclusive: bool,
377) -> bool {
378    let Some(lhs) = lookup(metadata, key).and_then(serde_json::Value::as_f64) else {
379        return false;
380    };
381    let Some(rhs) = rhs.as_f64() else {
382        return false;
383    };
384    let cmp = lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal);
385    if cmp == std::cmp::Ordering::Equal {
386        return inclusive;
387    }
388    cmp == direction
389}
390
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::float_cmp, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use entelix_core::TenantId;
    use serde_json::json;

    fn ns() -> Namespace {
        Namespace::new(TenantId::new("acme")).with_scope("agent-a")
    }

    fn ctx() -> ExecutionContext {
        ExecutionContext::new()
    }

    fn doc(id: &str, content: &str, metadata: serde_json::Value) -> Document {
        Document::new(content)
            .with_doc_id(id)
            .with_metadata(metadata)
    }

    /// `doc_id`s of `docs` in order, for compact assertions.
    fn ids(docs: &[Document]) -> Vec<&str> {
        docs.iter().filter_map(|d| d.doc_id.as_deref()).collect()
    }

    #[tokio::test]
    async fn add_then_search_returns_top_k_by_similarity() {
        let store = InMemoryVectorStore::new(3);
        let n = ns();
        store
            .add(
                &ctx(),
                &n,
                doc("a", "alpha", json!({})),
                vec![1.0, 0.0, 0.0],
            )
            .await
            .unwrap();
        store
            .add(&ctx(), &n, doc("b", "beta", json!({})), vec![0.0, 1.0, 0.0])
            .await
            .unwrap();
        store
            .add(
                &ctx(),
                &n,
                doc("c", "gamma", json!({})),
                vec![0.9, 0.1, 0.0],
            )
            .await
            .unwrap();
        let hits = store.search(&ctx(), &n, &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(ids(&hits), ["a", "c"]);
        // Score is cosine — exact match → 1.0.
        assert!((hits[0].score.unwrap() - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn search_returns_empty_for_unknown_namespace() {
        let store = InMemoryVectorStore::new(2);
        let hits = store.search(&ctx(), &ns(), &[1.0, 0.0], 5).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn top_k_zero_returns_empty() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        let hits = store.search(&ctx(), &ns(), &[1.0, 0.0], 0).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn zero_query_vector_scores_zero() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        let hits = store.search(&ctx(), &ns(), &[0.0, 0.0], 5).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].score, Some(0.0));
    }

    #[tokio::test]
    async fn dimension_mismatch_is_invalid_request() {
        let store = InMemoryVectorStore::new(3);
        let err = store
            .add(&ctx(), &ns(), doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("dimension"));
    }

    #[tokio::test]
    async fn search_query_dimension_mismatch_is_invalid_request() {
        let store = InMemoryVectorStore::new(3);
        let err = store
            .search(&ctx(), &ns(), &[1.0, 0.0], 5)
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("dimension"));
    }

    #[tokio::test]
    async fn delete_then_search_omits_deleted_doc() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        store.delete(&ctx(), &ns(), "a").await.unwrap();
        let hits = store.search(&ctx(), &ns(), &[1.0, 0.0], 5).await.unwrap();
        assert!(hits.is_empty());
        assert_eq!(store.total_slots(), 0);
    }

    #[tokio::test]
    async fn delete_unknown_doc_is_a_noop() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        store.delete(&ctx(), &ns(), "ghost").await.unwrap();
        assert_eq!(store.total_slots(), 1);
    }

    #[tokio::test]
    async fn update_replaces_vector_atomically() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "v1", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        store
            .update(
                &ctx(),
                &ns(),
                "a",
                doc("a", "v2", json!({"version": 2})),
                vec![0.0, 1.0],
            )
            .await
            .unwrap();
        let hits = store.search(&ctx(), &ns(), &[0.0, 1.0], 1).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "v2");
        assert_eq!(hits[0].metadata["version"], 2);
        assert_eq!(store.total_slots(), 1);
    }

    #[tokio::test]
    async fn update_unknown_doc_returns_invalid_request() {
        let store = InMemoryVectorStore::new(2);
        let err = store
            .update(
                &ctx(),
                &ns(),
                "ghost",
                doc("ghost", "x", json!({})),
                vec![1.0, 0.0],
            )
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    #[tokio::test]
    async fn search_filtered_honours_eq_filter() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(
                &ctx(),
                &ns(),
                doc("a", "x", json!({"category": "A"})),
                vec![1.0, 0.0],
            )
            .await
            .unwrap();
        store
            .add(
                &ctx(),
                &ns(),
                doc("b", "y", json!({"category": "B"})),
                vec![1.0, 0.0],
            )
            .await
            .unwrap();
        let filter = VectorFilter::Eq {
            key: "category".into(),
            value: json!("A"),
        };
        let hits = store
            .search_filtered(&ctx(), &ns(), &[1.0, 0.0], 5, &filter)
            .await
            .unwrap();
        assert_eq!(ids(&hits), ["a"]);
    }

    #[tokio::test]
    async fn search_filtered_honours_range_and_negation() {
        let store = InMemoryVectorStore::new(2);
        for (id, score) in [("a", 5.0), ("b", 12.0), ("c", 25.0), ("d", 50.0)] {
            store
                .add(
                    &ctx(),
                    &ns(),
                    doc(id, "x", json!({"score": score})),
                    vec![1.0, 0.0],
                )
                .await
                .unwrap();
        }
        let in_range = VectorFilter::Range {
            key: "score".into(),
            min: json!(10.0),
            max: json!(30.0),
        };
        let hits = store
            .search_filtered(&ctx(), &ns(), &[1.0, 0.0], 10, &in_range)
            .await
            .unwrap();
        // Equal vectors → equal scores → insertion order is kept.
        assert_eq!(ids(&hits), ["b", "c"]);

        let outside = VectorFilter::Not(Box::new(in_range));
        let hits = store
            .search_filtered(&ctx(), &ns(), &[1.0, 0.0], 10, &outside)
            .await
            .unwrap();
        assert_eq!(ids(&hits), ["a", "d"]);
    }

    #[tokio::test]
    async fn dotted_keys_resolve_nested_metadata() {
        let store = InMemoryVectorStore::new(2);
        for (id, lang) in [("a", "en"), ("b", "fr")] {
            store
                .add(
                    &ctx(),
                    &ns(),
                    doc(id, "x", json!({"meta": {"lang": lang}})),
                    vec![1.0, 0.0],
                )
                .await
                .unwrap();
        }
        let english = VectorFilter::Eq {
            key: "meta.lang".into(),
            value: json!("en"),
        };
        assert_eq!(store.count(&ctx(), &ns(), Some(&english)).await.unwrap(), 1);
        let has_lang = VectorFilter::Exists {
            key: "meta.lang".into(),
        };
        assert_eq!(store.count(&ctx(), &ns(), Some(&has_lang)).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn numeric_filters_reject_string_encoded_numbers() {
        let store = InMemoryVectorStore::new(2);
        store
            .add(&ctx(), &ns(), doc("a", "x", json!({"n": "42"})), vec![1.0, 0.0])
            .await
            .unwrap();
        store
            .add(&ctx(), &ns(), doc("b", "x", json!({"n": 42})), vec![1.0, 0.0])
            .await
            .unwrap();
        let above_ten = VectorFilter::Gt {
            key: "n".into(),
            value: json!(10),
        };
        let hits = store
            .search_filtered(&ctx(), &ns(), &[1.0, 0.0], 10, &above_ten)
            .await
            .unwrap();
        assert_eq!(ids(&hits), ["b"]);
    }

    #[tokio::test]
    async fn count_with_filter_returns_matching_subset() {
        let store = InMemoryVectorStore::new(2);
        for (id, cat) in [("a", "X"), ("b", "Y"), ("c", "X")] {
            store
                .add(
                    &ctx(),
                    &ns(),
                    doc(id, "x", json!({"cat": cat})),
                    vec![1.0, 0.0],
                )
                .await
                .unwrap();
        }
        assert_eq!(store.count(&ctx(), &ns(), None).await.unwrap(), 3);
        let only_x = VectorFilter::Eq {
            key: "cat".into(),
            value: json!("X"),
        };
        assert_eq!(store.count(&ctx(), &ns(), Some(&only_x)).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn list_paginates() {
        let store = InMemoryVectorStore::new(2);
        for i in 0..5 {
            store
                .add(
                    &ctx(),
                    &ns(),
                    doc(&format!("d{i}"), "x", json!({})),
                    vec![1.0, 0.0],
                )
                .await
                .unwrap();
        }
        let page = store.list(&ctx(), &ns(), None, 2, 1).await.unwrap();
        assert_eq!(ids(&page), ["d1", "d2"]);
        let past_end = store.list(&ctx(), &ns(), None, 2, 10).await.unwrap();
        assert!(past_end.is_empty());
    }

    #[tokio::test]
    async fn add_batch_default_loops_through_add() {
        let store = InMemoryVectorStore::new(2);
        let items = vec![
            (doc("a", "x", json!({})), vec![1.0, 0.0]),
            (doc("b", "y", json!({})), vec![0.0, 1.0]),
        ];
        store.add_batch(&ctx(), &ns(), items).await.unwrap();
        assert_eq!(store.total_slots(), 2);
    }

    #[tokio::test]
    async fn namespaces_are_isolated() {
        let store = InMemoryVectorStore::new(2);
        let ns_a = Namespace::new(TenantId::new("acme")).with_scope("agent-a");
        let ns_b = Namespace::new(TenantId::new("acme")).with_scope("agent-b");
        store
            .add(&ctx(), &ns_a, doc("a", "x", json!({})), vec![1.0, 0.0])
            .await
            .unwrap();
        let hits_a = store.search(&ctx(), &ns_a, &[1.0, 0.0], 5).await.unwrap();
        let hits_b = store.search(&ctx(), &ns_b, &[1.0, 0.0], 5).await.unwrap();
        assert_eq!(hits_a.len(), 1);
        assert_eq!(hits_b.len(), 0);
    }
}