Skip to main content

entelix_memory/
embedding_retriever.rs

1//! `EmbeddingRetriever` — adapter that wires a [`VectorStore`] +
2//! [`Embedder`] into a [`Retriever`].
3//!
4//! The two pieces every dense-retrieval pipeline needs sit on
5//! opposite sides of the [`Retriever`] surface — the embedder turns
6//! query text into a vector, the vector store does the
7//! nearest-neighbour search. This adapter wires them together with
8//! a fixed [`Namespace`] (the multi-tenant boundary, invariant 11)
9//! and exposes the canonical [`Retriever::retrieve`] shape so the
10//! result drops directly into [`SemanticMemory`](crate::SemanticMemory),
11//! `entelix-rag` recipes, or any custom retrieval-aware agent.
12//!
13//! ## Filter + score handling
14//!
15//! - [`RetrievalQuery::filter`] routes through
16//!   [`VectorStore::search_filtered`] when set; backends without
17//!   filter support surface their own
18//!   [`Error::Config`](entelix_core::Error::Config) as the trait
19//!   contract dictates. Filter-less queries route through
20//!   [`VectorStore::search`].
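//! - [`RetrievalQuery::min_score`] is applied as a post-filter on
//!   the returned hits: a hit is kept only when it carries a score
//!   and that score is `>=` the floor, so hits without a score are
//!   dropped. The comparison assumes higher scores mean closer
//!   matches. Backend-side score floors are not portable across
//!   dot-product, cosine, and L2 distance backends, so the adapter
//!   trims locally rather than translating the floor.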
25//! - [`RetrievalQuery::top_k`] flows directly into the backend
26//!   call. The min-score post-filter then trims further; a query
27//!   that requests `top_k = 10` with `min_score = 0.5` may return
28//!   fewer than 10 hits when scores fall below the floor.
29
30use std::sync::Arc;
31
32use async_trait::async_trait;
33use entelix_core::{ExecutionContext, Result};
34
35use crate::namespace::Namespace;
36use crate::traits::{Document, Embedder, RetrievalQuery, Retriever, VectorStore};
37
38/// Adapter that combines an [`Embedder`] and a [`VectorStore`]
39/// (scoped to one [`Namespace`]) into a [`Retriever`].
40///
41/// Cloning is cheap — both the embedder and the store sit behind
42/// `Arc`, so multiple retrievers can share one connection pool /
43/// embedding client.
44pub struct EmbeddingRetriever<E, V> {
45    embedder: Arc<E>,
46    store: Arc<V>,
47    namespace: Namespace,
48}
49
50impl<E, V> Clone for EmbeddingRetriever<E, V> {
51    fn clone(&self) -> Self {
52        Self {
53            embedder: Arc::clone(&self.embedder),
54            store: Arc::clone(&self.store),
55            namespace: self.namespace.clone(),
56        }
57    }
58}
59
60impl<E, V> std::fmt::Debug for EmbeddingRetriever<E, V> {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("EmbeddingRetriever")
63            .field("namespace", &self.namespace)
64            .finish_non_exhaustive()
65    }
66}
67
68impl<E, V> EmbeddingRetriever<E, V>
69where
70    E: Embedder,
71    V: VectorStore,
72{
73    /// Build a retriever that runs every query against `store`
74    /// scoped to `namespace`, using `embedder` to turn query text
75    /// into the search vector.
76    #[must_use]
77    pub const fn new(embedder: Arc<E>, store: Arc<V>, namespace: Namespace) -> Self {
78        Self {
79            embedder,
80            store,
81            namespace,
82        }
83    }
84
85    /// Borrow the wired embedder.
86    #[must_use]
87    pub const fn embedder(&self) -> &Arc<E> {
88        &self.embedder
89    }
90
91    /// Borrow the wired vector store.
92    #[must_use]
93    pub const fn store(&self) -> &Arc<V> {
94        &self.store
95    }
96
97    /// Borrow the configured namespace.
98    #[must_use]
99    pub const fn namespace(&self) -> &Namespace {
100        &self.namespace
101    }
102}
103
104#[async_trait]
105impl<E, V> Retriever for EmbeddingRetriever<E, V>
106where
107    E: Embedder + 'static,
108    V: VectorStore + 'static,
109{
110    async fn retrieve(
111        &self,
112        query: RetrievalQuery,
113        ctx: &ExecutionContext,
114    ) -> Result<Vec<Document>> {
115        let embedding = self.embedder.embed(&query.text, ctx).await?;
116        let mut hits = match query.filter.as_ref() {
117            Some(filter) => {
118                self.store
119                    .search_filtered(ctx, &self.namespace, &embedding.vector, query.top_k, filter)
120                    .await?
121            }
122            None => {
123                self.store
124                    .search(ctx, &self.namespace, &embedding.vector, query.top_k)
125                    .await?
126            }
127        };
128        if let Some(floor) = query.min_score {
129            hits.retain(|doc| doc.score.is_some_and(|s| s >= floor));
130        }
131        Ok(hits)
132    }
133}
134
135#[cfg(test)]
136mod tests {
137    use super::*;
138    use crate::in_memory_vector_store::InMemoryVectorStore;
139    use crate::traits::{Embedding, VectorFilter};
140    use entelix_core::TenantId;
141    use std::sync::Arc;
142
143    /// Tiny BoW embedder over a fixed vocabulary — every recognised
144    /// word increments one basis component, the vector is L2-
145    /// normalised. Stable, no IO, deterministic.
146    struct BowEmbedder {
147        vocab: std::collections::HashMap<String, usize>,
148        dimension: usize,
149    }
150
151    impl BowEmbedder {
152        fn new(words: &[&str]) -> Self {
153            let dimension = words.len();
154            let vocab = words
155                .iter()
156                .enumerate()
157                .map(|(i, w)| ((*w).to_owned(), i))
158                .collect();
159            Self { vocab, dimension }
160        }
161    }
162
163    #[async_trait]
164    impl Embedder for BowEmbedder {
165        fn dimension(&self) -> usize {
166            self.dimension
167        }
168        async fn embed(&self, text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
169            let mut v = vec![0.0_f32; self.dimension];
170            for word in text.to_lowercase().split_whitespace() {
171                if let Some(&idx) = self.vocab.get(word)
172                    && let Some(slot) = v.get_mut(idx)
173                {
174                    *slot += 1.0;
175                }
176            }
177            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
178            if norm > 0.0 {
179                for x in &mut v {
180                    *x /= norm;
181                }
182            }
183            Ok(Embedding::new(v))
184        }
185    }
186
187    fn ns(tenant: &str) -> Namespace {
188        Namespace::new(TenantId::new(tenant))
189    }
190
191    async fn seed_store(
192        embedder: &Arc<BowEmbedder>,
193        store: &Arc<InMemoryVectorStore>,
194        namespace: &Namespace,
195        docs: &[(&str, &str)],
196    ) -> Result<()> {
197        let ctx = ExecutionContext::new();
198        let mut items = Vec::new();
199        for (id, content) in docs {
200            let emb = embedder.embed(content, &ctx).await?;
201            let doc = Document::new(*content).with_doc_id((*id).to_owned());
202            items.push((doc, emb.vector));
203        }
204        store.add_batch(&ctx, namespace, items).await
205    }
206
207    #[tokio::test]
208    async fn retrieves_top_k_for_query() -> Result<()> {
209        let embedder = Arc::new(BowEmbedder::new(&[
210            "rust", "agent", "tokio", "async", "memory", "graph",
211        ]));
212        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
213        let namespace = ns("acme");
214        seed_store(
215            &embedder,
216            &store,
217            &namespace,
218            &[
219                ("a", "rust agent tokio"),
220                ("b", "graph memory"),
221                ("c", "async rust"),
222            ],
223        )
224        .await?;
225
226        let retriever =
227            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
228        let ctx = ExecutionContext::new();
229        let hits = retriever
230            .retrieve(RetrievalQuery::new("rust agent", 2), &ctx)
231            .await?;
232        assert_eq!(hits.len(), 2);
233        // The doc with both "rust" + "agent" must rank first.
234        assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("a"));
235        Ok(())
236    }
237
238    #[tokio::test]
239    async fn min_score_post_filters_below_floor() -> Result<()> {
240        let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo", "charlie"]));
241        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
242        let namespace = ns("acme");
243        seed_store(
244            &embedder,
245            &store,
246            &namespace,
247            &[("a", "alpha bravo"), ("b", "alpha"), ("c", "charlie")],
248        )
249        .await?;
250
251        let retriever =
252            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
253        let ctx = ExecutionContext::new();
254        // Floor of 0.99 — only the exact match (cosine 1.0) survives.
255        let hits = retriever
256            .retrieve(
257                RetrievalQuery::new("alpha bravo", 5).with_min_score(0.99),
258                &ctx,
259            )
260            .await?;
261        assert_eq!(hits.len(), 1);
262        assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("a"));
263        Ok(())
264    }
265
266    #[tokio::test]
267    async fn filter_routes_through_search_filtered() -> Result<()> {
268        // InMemoryVectorStore implements search_filtered; verifying
269        // the adapter takes the filtered branch when query.filter is
270        // set.
271        let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo"]));
272        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
273        let namespace = ns("acme");
274        let ctx = ExecutionContext::new();
275        let docs = [
276            ("a", "alpha bravo", serde_json::json!({"kind": "code"})),
277            ("b", "alpha", serde_json::json!({"kind": "doc"})),
278        ];
279        let mut items = Vec::new();
280        for (id, content, meta) in &docs {
281            let emb = embedder.embed(content, &ctx).await?;
282            let doc = Document::new(*content)
283                .with_doc_id((*id).to_owned())
284                .with_metadata(meta.clone());
285            items.push((doc, emb.vector));
286        }
287        store.add_batch(&ctx, &namespace, items).await?;
288
289        let retriever =
290            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
291        let hits = retriever
292            .retrieve(
293                RetrievalQuery::new("alpha", 5).with_filter(VectorFilter::Eq {
294                    key: "kind".to_owned(),
295                    value: serde_json::json!("doc"),
296                }),
297                &ctx,
298            )
299            .await?;
300        assert_eq!(hits.len(), 1);
301        assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("b"));
302        Ok(())
303    }
304
305    #[tokio::test]
306    async fn namespace_isolation_blocks_cross_tenant_reads() -> Result<()> {
307        let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo", "charlie"]));
308        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
309        let alice = ns("alice");
310        let bob = ns("bob");
311        seed_store(
312            &embedder,
313            &store,
314            &alice,
315            &[("alice-doc", "alpha bravo charlie")],
316        )
317        .await?;
318        // Bob's namespace stays empty.
319
320        let bob_retriever = EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), bob);
321        let ctx = ExecutionContext::new();
322        let hits = bob_retriever
323            .retrieve(RetrievalQuery::new("alpha bravo charlie", 10), &ctx)
324            .await?;
325        assert!(
326            hits.is_empty(),
327            "Bob must not observe Alice's documents: {hits:?}"
328        );
329        Ok(())
330    }
331
332    #[tokio::test]
333    async fn clone_shares_embedder_and_store() {
334        let embedder = Arc::new(BowEmbedder::new(&["x"]));
335        let store = Arc::new(InMemoryVectorStore::new(1));
336        let namespace = ns("acme");
337        let original =
338            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
339        let cloned = original.clone();
340        assert!(Arc::ptr_eq(original.embedder(), cloned.embedder()));
341        assert!(Arc::ptr_eq(original.store(), cloned.store()));
342        assert_eq!(cloned.namespace(), &namespace);
343    }
344}