1use std::sync::Arc;
31
32use async_trait::async_trait;
33use entelix_core::{ExecutionContext, Result};
34
35use crate::namespace::Namespace;
36use crate::traits::{Document, Embedder, RetrievalQuery, Retriever, VectorStore};
37
/// A [`Retriever`] that embeds query text with an [`Embedder`] and runs a
/// similarity search against a [`VectorStore`], scoped to one [`Namespace`].
///
/// The embedder and the store are held behind [`Arc`], so they can be shared
/// with other owners (and with clones of this retriever) without duplication.
pub struct EmbeddingRetriever<E, V> {
    /// Embeds the query text on every retrieval.
    embedder: Arc<E>,
    /// Vector index searched on every retrieval.
    store: Arc<V>,
    /// Namespace passed to every store search issued by this retriever.
    namespace: Namespace,
}
49
50impl<E, V> Clone for EmbeddingRetriever<E, V> {
51 fn clone(&self) -> Self {
52 Self {
53 embedder: Arc::clone(&self.embedder),
54 store: Arc::clone(&self.store),
55 namespace: self.namespace.clone(),
56 }
57 }
58}
59
60impl<E, V> std::fmt::Debug for EmbeddingRetriever<E, V> {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("EmbeddingRetriever")
63 .field("namespace", &self.namespace)
64 .finish_non_exhaustive()
65 }
66}
67
68impl<E, V> EmbeddingRetriever<E, V>
69where
70 E: Embedder,
71 V: VectorStore,
72{
73 #[must_use]
77 pub const fn new(embedder: Arc<E>, store: Arc<V>, namespace: Namespace) -> Self {
78 Self {
79 embedder,
80 store,
81 namespace,
82 }
83 }
84
85 #[must_use]
87 pub const fn embedder(&self) -> &Arc<E> {
88 &self.embedder
89 }
90
91 #[must_use]
93 pub const fn store(&self) -> &Arc<V> {
94 &self.store
95 }
96
97 #[must_use]
99 pub const fn namespace(&self) -> &Namespace {
100 &self.namespace
101 }
102}
103
104#[async_trait]
105impl<E, V> Retriever for EmbeddingRetriever<E, V>
106where
107 E: Embedder + 'static,
108 V: VectorStore + 'static,
109{
110 async fn retrieve(
111 &self,
112 query: RetrievalQuery,
113 ctx: &ExecutionContext,
114 ) -> Result<Vec<Document>> {
115 let embedding = self.embedder.embed(&query.text, ctx).await?;
116 let mut hits = match query.filter.as_ref() {
117 Some(filter) => {
118 self.store
119 .search_filtered(ctx, &self.namespace, &embedding.vector, query.top_k, filter)
120 .await?
121 }
122 None => {
123 self.store
124 .search(ctx, &self.namespace, &embedding.vector, query.top_k)
125 .await?
126 }
127 };
128 if let Some(floor) = query.min_score {
129 hits.retain(|doc| doc.score.is_some_and(|s| s >= floor));
130 }
131 Ok(hits)
132 }
133}
134
// Unit tests for `EmbeddingRetriever`, driven by a deterministic
// bag-of-words embedder and the crate's `InMemoryVectorStore`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_vector_store::InMemoryVectorStore;
    use crate::traits::{Embedding, VectorFilter};
    use entelix_core::TenantId;
    use std::sync::Arc;

    /// Deterministic bag-of-words test embedder.
    ///
    /// Each vocabulary word owns one dimension. `embed` lowercases the input,
    /// splits it on whitespace, counts occurrences of vocabulary words, and
    /// L2-normalises the counts (an all-zero vector is returned unchanged).
    /// Words outside the vocabulary are ignored. Vocabulary entries are stored
    /// verbatim, so they must already be lowercase to ever match.
    struct BowEmbedder {
        /// Word -> dimension index.
        vocab: std::collections::HashMap<String, usize>,
        /// Number of dimensions; equals the number of words passed to `new`.
        dimension: usize,
    }

    impl BowEmbedder {
        /// Builds the vocabulary, assigning dimension `i` to `words[i]`.
        fn new(words: &[&str]) -> Self {
            let dimension = words.len();
            let vocab = words
                .iter()
                .enumerate()
                .map(|(i, w)| ((*w).to_owned(), i))
                .collect();
            Self { vocab, dimension }
        }
    }

    #[async_trait]
    impl Embedder for BowEmbedder {
        fn dimension(&self) -> usize {
            self.dimension
        }
        async fn embed(&self, text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
            let mut v = vec![0.0_f32; self.dimension];
            for word in text.to_lowercase().split_whitespace() {
                // `get_mut` instead of indexing keeps this panic-free even if
                // an index were ever out of range.
                if let Some(&idx) = self.vocab.get(word)
                    && let Some(slot) = v.get_mut(idx)
                {
                    *slot += 1.0;
                }
            }
            // Unit-length vectors make dot products equal cosine similarity.
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for x in &mut v {
                    *x /= norm;
                }
            }
            Ok(Embedding::new(v))
        }
    }

    /// Namespace for the given tenant id.
    fn ns(tenant: &str) -> Namespace {
        Namespace::new(TenantId::new(tenant))
    }

    /// Embeds each `(doc_id, content)` pair with `embedder` and inserts all of
    /// them into `store` under `namespace` in a single `add_batch` call.
    async fn seed_store(
        embedder: &Arc<BowEmbedder>,
        store: &Arc<InMemoryVectorStore>,
        namespace: &Namespace,
        docs: &[(&str, &str)],
    ) -> Result<()> {
        let ctx = ExecutionContext::new();
        let mut items = Vec::new();
        for (id, content) in docs {
            let emb = embedder.embed(content, &ctx).await?;
            let doc = Document::new(*content).with_doc_id((*id).to_owned());
            items.push((doc, emb.vector));
        }
        store.add_batch(&ctx, namespace, items).await
    }

    // Three docs are seeded and top_k = 2, so exactly two hits come back;
    // "a" shares both query terms and is expected to rank first.
    #[tokio::test]
    async fn retrieves_top_k_for_query() -> Result<()> {
        let embedder = Arc::new(BowEmbedder::new(&[
            "rust", "agent", "tokio", "async", "memory", "graph",
        ]));
        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
        let namespace = ns("acme");
        seed_store(
            &embedder,
            &store,
            &namespace,
            &[
                ("a", "rust agent tokio"),
                ("b", "graph memory"),
                ("c", "async rust"),
            ],
        )
        .await?;

        let retriever =
            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
        let ctx = ExecutionContext::new();
        let hits = retriever
            .retrieve(RetrievalQuery::new("rust agent", 2), &ctx)
            .await?;
        assert_eq!(hits.len(), 2);
        assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("a"));
        Ok(())
    }

    // The query has the same terms as "a", so "a" should score ~1.0 while
    // "b" ("alpha" only) scores ~0.707 — assuming the store reports cosine
    // similarity (NOTE(review): confirm against InMemoryVectorStore). A 0.99
    // floor therefore keeps only "a" even though top_k = 5.
    #[tokio::test]
    async fn min_score_post_filters_below_floor() -> Result<()> {
        let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo", "charlie"]));
        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
        let namespace = ns("acme");
        seed_store(
            &embedder,
            &store,
            &namespace,
            &[("a", "alpha bravo"), ("b", "alpha"), ("c", "charlie")],
        )
        .await?;

        let retriever =
            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
        let ctx = ExecutionContext::new();
        let hits = retriever
            .retrieve(
                RetrievalQuery::new("alpha bravo", 5).with_min_score(0.99),
                &ctx,
            )
            .await?;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("a"));
        Ok(())
    }

    // Both docs contain "alpha", but only "b" has metadata kind == "doc".
    // Setting a filter makes the retriever call `search_filtered`, which is
    // expected to return "b" alone.
    #[tokio::test]
    async fn filter_routes_through_search_filtered() -> Result<()> {
        let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo"]));
        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
        let namespace = ns("acme");
        let ctx = ExecutionContext::new();
        let docs = [
            ("a", "alpha bravo", serde_json::json!({"kind": "code"})),
            ("b", "alpha", serde_json::json!({"kind": "doc"})),
        ];
        // Seeded inline (not via `seed_store`) because each doc needs metadata.
        let mut items = Vec::new();
        for (id, content, meta) in &docs {
            let emb = embedder.embed(content, &ctx).await?;
            let doc = Document::new(*content)
                .with_doc_id((*id).to_owned())
                .with_metadata(meta.clone());
            items.push((doc, emb.vector));
        }
        store.add_batch(&ctx, &namespace, items).await?;

        let retriever =
            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
        let hits = retriever
            .retrieve(
                RetrievalQuery::new("alpha", 5).with_filter(VectorFilter::Eq {
                    key: "kind".to_owned(),
                    value: serde_json::json!("doc"),
                }),
                &ctx,
            )
            .await?;
        assert_eq!(hits.len(), 1);
        assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("b"));
        Ok(())
    }

    // Alice's document lives only in Alice's namespace. A retriever scoped to
    // Bob, sharing the same store, must see nothing even for a query that
    // matches Alice's document exactly.
    #[tokio::test]
    async fn namespace_isolation_blocks_cross_tenant_reads() -> Result<()> {
        let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo", "charlie"]));
        let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
        let alice = ns("alice");
        let bob = ns("bob");
        seed_store(
            &embedder,
            &store,
            &alice,
            &[("alice-doc", "alpha bravo charlie")],
        )
        .await?;
        let bob_retriever = EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), bob);
        let ctx = ExecutionContext::new();
        let hits = bob_retriever
            .retrieve(RetrievalQuery::new("alpha bravo charlie", 10), &ctx)
            .await?;
        assert!(
            hits.is_empty(),
            "Bob must not observe Alice's documents: {hits:?}"
        );
        Ok(())
    }

    // Cloning must share the embedder and store `Arc`s (same allocation) and
    // carry over an equal namespace.
    #[tokio::test]
    async fn clone_shares_embedder_and_store() {
        let embedder = Arc::new(BowEmbedder::new(&["x"]));
        let store = Arc::new(InMemoryVectorStore::new(1));
        let namespace = ns("acme");
        let original =
            EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
        let cloned = original.clone();
        assert!(Arc::ptr_eq(original.embedder(), cloned.embedder()));
        assert!(Arc::ptr_eq(original.store(), cloned.store()));
        assert_eq!(cloned.namespace(), &namespace);
    }
}