use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::{ExecutionContext, Result};
use crate::namespace::Namespace;
use crate::traits::{Document, Embedder, RetrievalQuery, Retriever, VectorStore};
pub struct EmbeddingRetriever<E, V> {
embedder: Arc<E>,
store: Arc<V>,
namespace: Namespace,
}
impl<E, V> Clone for EmbeddingRetriever<E, V> {
fn clone(&self) -> Self {
Self {
embedder: Arc::clone(&self.embedder),
store: Arc::clone(&self.store),
namespace: self.namespace.clone(),
}
}
}
impl<E, V> std::fmt::Debug for EmbeddingRetriever<E, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EmbeddingRetriever")
.field("namespace", &self.namespace)
.finish_non_exhaustive()
}
}
impl<E, V> EmbeddingRetriever<E, V>
where
E: Embedder,
V: VectorStore,
{
#[must_use]
pub const fn new(embedder: Arc<E>, store: Arc<V>, namespace: Namespace) -> Self {
Self {
embedder,
store,
namespace,
}
}
#[must_use]
pub const fn embedder(&self) -> &Arc<E> {
&self.embedder
}
#[must_use]
pub const fn store(&self) -> &Arc<V> {
&self.store
}
#[must_use]
pub const fn namespace(&self) -> &Namespace {
&self.namespace
}
}
#[async_trait]
impl<E, V> Retriever for EmbeddingRetriever<E, V>
where
E: Embedder + 'static,
V: VectorStore + 'static,
{
async fn retrieve(
&self,
query: RetrievalQuery,
ctx: &ExecutionContext,
) -> Result<Vec<Document>> {
let embedding = self.embedder.embed(&query.text, ctx).await?;
let mut hits = match query.filter.as_ref() {
Some(filter) => {
self.store
.search_filtered(ctx, &self.namespace, &embedding.vector, query.top_k, filter)
.await?
}
None => {
self.store
.search(ctx, &self.namespace, &embedding.vector, query.top_k)
.await?
}
};
if let Some(floor) = query.min_score {
hits.retain(|doc| doc.score.is_some_and(|s| s >= floor));
}
Ok(hits)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::in_memory_vector_store::InMemoryVectorStore;
use crate::traits::{Embedding, VectorFilter};
use entelix_core::TenantId;
use std::sync::Arc;
struct BowEmbedder {
vocab: std::collections::HashMap<String, usize>,
dimension: usize,
}
impl BowEmbedder {
fn new(words: &[&str]) -> Self {
let dimension = words.len();
let vocab = words
.iter()
.enumerate()
.map(|(i, w)| ((*w).to_owned(), i))
.collect();
Self { vocab, dimension }
}
}
#[async_trait]
impl Embedder for BowEmbedder {
fn dimension(&self) -> usize {
self.dimension
}
async fn embed(&self, text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
let mut v = vec![0.0_f32; self.dimension];
for word in text.to_lowercase().split_whitespace() {
if let Some(&idx) = self.vocab.get(word)
&& let Some(slot) = v.get_mut(idx)
{
*slot += 1.0;
}
}
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut v {
*x /= norm;
}
}
Ok(Embedding::new(v))
}
}
fn ns(tenant: &str) -> Namespace {
Namespace::new(TenantId::new(tenant))
}
async fn seed_store(
embedder: &Arc<BowEmbedder>,
store: &Arc<InMemoryVectorStore>,
namespace: &Namespace,
docs: &[(&str, &str)],
) -> Result<()> {
let ctx = ExecutionContext::new();
let mut items = Vec::new();
for (id, content) in docs {
let emb = embedder.embed(content, &ctx).await?;
let doc = Document::new(*content).with_doc_id((*id).to_owned());
items.push((doc, emb.vector));
}
store.add_batch(&ctx, namespace, items).await
}
#[tokio::test]
async fn retrieves_top_k_for_query() -> Result<()> {
let embedder = Arc::new(BowEmbedder::new(&[
"rust", "agent", "tokio", "async", "memory", "graph",
]));
let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
let namespace = ns("acme");
seed_store(
&embedder,
&store,
&namespace,
&[
("a", "rust agent tokio"),
("b", "graph memory"),
("c", "async rust"),
],
)
.await?;
let retriever =
EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
let ctx = ExecutionContext::new();
let hits = retriever
.retrieve(RetrievalQuery::new("rust agent", 2), &ctx)
.await?;
assert_eq!(hits.len(), 2);
assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("a"));
Ok(())
}
#[tokio::test]
async fn min_score_post_filters_below_floor() -> Result<()> {
let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo", "charlie"]));
let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
let namespace = ns("acme");
seed_store(
&embedder,
&store,
&namespace,
&[("a", "alpha bravo"), ("b", "alpha"), ("c", "charlie")],
)
.await?;
let retriever =
EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
let ctx = ExecutionContext::new();
let hits = retriever
.retrieve(
RetrievalQuery::new("alpha bravo", 5).with_min_score(0.99),
&ctx,
)
.await?;
assert_eq!(hits.len(), 1);
assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("a"));
Ok(())
}
#[tokio::test]
async fn filter_routes_through_search_filtered() -> Result<()> {
let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo"]));
let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
let namespace = ns("acme");
let ctx = ExecutionContext::new();
let docs = [
("a", "alpha bravo", serde_json::json!({"kind": "code"})),
("b", "alpha", serde_json::json!({"kind": "doc"})),
];
let mut items = Vec::new();
for (id, content, meta) in &docs {
let emb = embedder.embed(content, &ctx).await?;
let doc = Document::new(*content)
.with_doc_id((*id).to_owned())
.with_metadata(meta.clone());
items.push((doc, emb.vector));
}
store.add_batch(&ctx, &namespace, items).await?;
let retriever =
EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
let hits = retriever
.retrieve(
RetrievalQuery::new("alpha", 5).with_filter(VectorFilter::Eq {
key: "kind".to_owned(),
value: serde_json::json!("doc"),
}),
&ctx,
)
.await?;
assert_eq!(hits.len(), 1);
assert_eq!(hits.first().and_then(|h| h.doc_id.as_deref()), Some("b"));
Ok(())
}
#[tokio::test]
async fn namespace_isolation_blocks_cross_tenant_reads() -> Result<()> {
let embedder = Arc::new(BowEmbedder::new(&["alpha", "bravo", "charlie"]));
let store = Arc::new(InMemoryVectorStore::new(embedder.dimension()));
let alice = ns("alice");
let bob = ns("bob");
seed_store(
&embedder,
&store,
&alice,
&[("alice-doc", "alpha bravo charlie")],
)
.await?;
let bob_retriever = EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), bob);
let ctx = ExecutionContext::new();
let hits = bob_retriever
.retrieve(RetrievalQuery::new("alpha bravo charlie", 10), &ctx)
.await?;
assert!(
hits.is_empty(),
"Bob must not observe Alice's documents: {hits:?}"
);
Ok(())
}
#[tokio::test]
async fn clone_shares_embedder_and_store() {
let embedder = Arc::new(BowEmbedder::new(&["x"]));
let store = Arc::new(InMemoryVectorStore::new(1));
let namespace = ns("acme");
let original =
EmbeddingRetriever::new(Arc::clone(&embedder), Arc::clone(&store), namespace.clone());
let cloned = original.clone();
assert!(Arc::ptr_eq(original.embedder(), cloned.embedder()));
assert!(Arc::ptr_eq(original.store(), cloned.store()));
assert_eq!(cloned.namespace(), &namespace);
}
}