use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::{Error, ExecutionContext, Result};
use crate::namespace::Namespace;
use crate::traits::{Document, Embedder, RerankedDocument, Reranker, VectorFilter, VectorStore};
/// Semantic-memory operations expressed without any generic methods.
///
/// The reranking entry point takes `&dyn Reranker` instead of a type
/// parameter, so every method has a single concrete signature.
/// [`SemanticMemory`] implements this trait in this file.
#[async_trait]
pub trait SemanticMemoryBackend: Send + Sync + 'static {
    /// Namespace this backend is bound to.
    fn namespace(&self) -> &Namespace;

    /// Vector dimension reported by this backend.
    fn dimension(&self) -> usize;

    /// Text search: `query` is the raw query string and `top_k` the number of
    /// results requested.
    async fn search(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
    ) -> Result<Vec<Document>>;

    /// Same as [`search`](Self::search), with an additional `filter` supplied
    /// by the caller.
    async fn search_filtered(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
        filter: &VectorFilter,
    ) -> Result<Vec<Document>>;

    /// Search followed by reranking through a dynamically dispatched
    /// `reranker`.
    ///
    /// `candidates` and `top_k` are both forwarded by the caller; how they
    /// interact is up to the implementation.
    async fn search_with_rerank_dyn(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
        candidates: usize,
        reranker: &dyn Reranker,
    ) -> Result<Vec<RerankedDocument>>;

    /// Stores a single `document`.
    async fn add(&self, ctx: &ExecutionContext, document: Document) -> Result<()>;

    /// Stores several `documents` in one call.
    async fn add_batch(&self, ctx: &ExecutionContext, documents: Vec<Document>) -> Result<()>;

    /// Removes the document identified by `doc_id`.
    async fn delete(&self, ctx: &ExecutionContext, doc_id: &str) -> Result<()>;

    /// Replaces the document identified by `doc_id` with `document`.
    async fn update(
        &self,
        ctx: &ExecutionContext,
        doc_id: &str,
        document: Document,
    ) -> Result<()>;

    /// Counts stored documents, optionally restricted by `filter`.
    async fn count(
        &self,
        ctx: &ExecutionContext,
        filter: Option<&VectorFilter>,
    ) -> Result<usize>;

    /// Lists stored documents, optionally restricted by `filter`, using the
    /// caller-supplied `limit` and `offset`.
    async fn list(
        &self,
        ctx: &ExecutionContext,
        filter: Option<&VectorFilter>,
        limit: usize,
        offset: usize,
    ) -> Result<Vec<Document>>;
}
/// Pairs an embedder with a vector store under a single [`Namespace`].
///
/// Both collaborators are held behind `Arc`. [`SemanticMemory::new`] checks
/// that their reported dimensions agree before constructing a value.
pub struct SemanticMemory<E: Embedder, V: VectorStore> {
    /// Produces the vectors for document content and query text.
    embedder: Arc<E>,
    /// Receives those vectors and serves storage, search, count and list calls.
    vector_store: Arc<V>,
    /// Namespace handed to every vector-store call.
    namespace: Namespace,
}
impl<E, V> SemanticMemory<E, V>
where
    E: Embedder,
    V: VectorStore,
{
    /// Builds a memory over `embedder` and `vector_store`, scoped to
    /// `namespace`.
    ///
    /// # Errors
    ///
    /// Returns a config error when `embedder.dimension()` differs from
    /// `vector_store.dimension()`.
    pub fn new(embedder: Arc<E>, vector_store: Arc<V>, namespace: Namespace) -> Result<Self> {
        let e_dim = embedder.dimension();
        let v_dim = vector_store.dimension();
        if e_dim != v_dim {
            return Err(Error::config(format!(
                "SemanticMemory: embedder dimension ({e_dim}) does not match vector-store \
                 dimension ({v_dim})"
            )));
        }
        Ok(Self {
            embedder,
            vector_store,
            namespace,
        })
    }

    /// Namespace passed to every vector-store call made by this memory.
    pub const fn namespace(&self) -> &Namespace {
        &self.namespace
    }

    /// Embeds `document.content` and stores the document together with the
    /// resulting vector.
    ///
    /// # Errors
    ///
    /// Propagates any error from the embedder or the vector store.
    pub async fn add(&self, ctx: &ExecutionContext, document: Document) -> Result<()> {
        let embedding = self.embedder.embed(&document.content, ctx).await?;
        self.vector_store
            .add(ctx, &self.namespace, document, embedding.vector)
            .await
    }

    /// Embeds the content of every document with a single `embed_batch` call
    /// and stores all `(document, vector)` pairs with a single `add_batch`
    /// call.
    ///
    /// An empty `documents` list returns `Ok(())` without calling the embedder
    /// or the store.
    ///
    /// # Errors
    ///
    /// Returns a config error when the embedder yields a different number of
    /// vectors than there are documents; otherwise propagates embedder and
    /// vector-store errors.
    pub async fn add_batch(&self, ctx: &ExecutionContext, documents: Vec<Document>) -> Result<()> {
        if documents.is_empty() {
            return Ok(());
        }

        let texts: Vec<String> = documents.iter().map(|d| d.content.clone()).collect();
        let embeddings = self.embedder.embed_batch(&texts, ctx).await?;

        // `zip` below would silently drop the unmatched tail, so a count
        // mismatch has to be rejected explicitly.
        if embeddings.len() != texts.len() {
            return Err(Error::config(format!(
                "SemanticMemory::add_batch: embedder returned {} vectors for {} documents",
                embeddings.len(),
                texts.len()
            )));
        }

        let items: Vec<(Document, Vec<f32>)> = documents
            .into_iter()
            .zip(embeddings)
            .map(|(doc, embedding)| (doc, embedding.vector))
            .collect();
        self.vector_store
            .add_batch(ctx, &self.namespace, items)
            .await
    }

    /// Forwards the deletion of `doc_id` to the vector store.
    ///
    /// # Errors
    ///
    /// Propagates any error from the vector store.
    pub async fn delete(&self, ctx: &ExecutionContext, doc_id: &str) -> Result<()> {
        self.vector_store.delete(ctx, &self.namespace, doc_id).await
    }

    /// Re-embeds `document.content` and forwards the update of `doc_id` with
    /// the new document and vector to the vector store.
    ///
    /// # Errors
    ///
    /// Propagates any error from the embedder or the vector store.
    pub async fn update(
        &self,
        ctx: &ExecutionContext,
        doc_id: &str,
        document: Document,
    ) -> Result<()> {
        let embedding = self.embedder.embed(&document.content, ctx).await?;
        self.vector_store
            .update(ctx, &self.namespace, doc_id, document, embedding.vector)
            .await
    }

    /// Embeds `query` and runs a vector-store search with that vector and
    /// `top_k`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the embedder or the vector store.
    pub async fn search(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
    ) -> Result<Vec<Document>> {
        let embedding = self.embedder.embed(query, ctx).await?;
        self.vector_store
            .search(ctx, &self.namespace, &embedding.vector, top_k)
            .await
    }

    /// Like [`search`](Self::search), but also passes `filter` to the vector
    /// store's filtered search.
    ///
    /// # Errors
    ///
    /// Propagates any error from the embedder or the vector store.
    pub async fn search_filtered(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
        filter: &VectorFilter,
    ) -> Result<Vec<Document>> {
        let embedding = self.embedder.embed(query, ctx).await?;
        self.vector_store
            .search_filtered(ctx, &self.namespace, &embedding.vector, top_k, filter)
            .await
    }

    /// Searches for `max(candidates, top_k)` documents, then hands that pool
    /// to `reranker` together with `top_k`.
    ///
    /// `R` may be unsized, so a `&dyn Reranker` (for example borrowed from a
    /// `Box<dyn Reranker>`) is accepted as well as a concrete reranker type.
    ///
    /// # Errors
    ///
    /// Propagates any error from the search step or from the reranker.
    pub async fn search_with_rerank<R>(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
        candidates: usize,
        reranker: &R,
    ) -> Result<Vec<RerankedDocument>>
    where
        R: Reranker + ?Sized,
    {
        // Never request fewer candidates than the number of results wanted.
        let pool_size = candidates.max(top_k);
        let pool = self.search(ctx, query, pool_size).await?;
        reranker.rerank(query, pool, top_k, ctx).await
    }

    /// Forwards a count request, with the optional `filter`, to the vector
    /// store.
    ///
    /// # Errors
    ///
    /// Propagates any error from the vector store.
    pub async fn count(
        &self,
        ctx: &ExecutionContext,
        filter: Option<&VectorFilter>,
    ) -> Result<usize> {
        self.vector_store.count(ctx, &self.namespace, filter).await
    }

    /// Forwards a list request, with the optional `filter`, `limit` and
    /// `offset`, to the vector store.
    ///
    /// # Errors
    ///
    /// Propagates any error from the vector store.
    pub async fn list(
        &self,
        ctx: &ExecutionContext,
        filter: Option<&VectorFilter>,
        limit: usize,
        offset: usize,
    ) -> Result<Vec<Document>> {
        self.vector_store
            .list(ctx, &self.namespace, filter, limit, offset)
            .await
    }
}
/// Trait facade over the inherent [`SemanticMemory`] API.
///
/// Every method forwards to the inherent method of the same name through a
/// `Self::method(self, ..)` path call. The one exception is
/// `search_with_rerank_dyn`, which runs the search-then-rerank sequence here
/// because it receives the reranker as `&dyn Reranker`.
#[async_trait]
impl<E: Embedder, V: VectorStore> SemanticMemoryBackend for SemanticMemory<E, V> {
    /// Borrows the namespace stored on the struct.
    fn namespace(&self) -> &Namespace {
        &self.namespace
    }

    /// Reports the embedder's dimension.
    fn dimension(&self) -> usize {
        self.embedder.dimension()
    }

    /// Forwards to the inherent `search`.
    async fn search(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
    ) -> Result<Vec<Document>> {
        Self::search(self, ctx, query, top_k).await
    }

    /// Forwards to the inherent `search_filtered`.
    async fn search_filtered(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
        filter: &VectorFilter,
    ) -> Result<Vec<Document>> {
        Self::search_filtered(self, ctx, query, top_k, filter).await
    }

    /// Searches for `max(candidates, top_k)` documents via the inherent
    /// `search`, then passes that pool to `reranker` with `top_k`.
    async fn search_with_rerank_dyn(
        &self,
        ctx: &ExecutionContext,
        query: &str,
        top_k: usize,
        candidates: usize,
        reranker: &dyn Reranker,
    ) -> Result<Vec<RerankedDocument>> {
        let pool_size = candidates.max(top_k);
        let pool = Self::search(self, ctx, query, pool_size).await?;
        reranker.rerank(query, pool, top_k, ctx).await
    }

    /// Forwards to the inherent `add`.
    async fn add(&self, ctx: &ExecutionContext, document: Document) -> Result<()> {
        Self::add(self, ctx, document).await
    }

    /// Forwards to the inherent `add_batch`.
    async fn add_batch(&self, ctx: &ExecutionContext, documents: Vec<Document>) -> Result<()> {
        Self::add_batch(self, ctx, documents).await
    }

    /// Forwards to the inherent `delete`.
    async fn delete(&self, ctx: &ExecutionContext, doc_id: &str) -> Result<()> {
        Self::delete(self, ctx, doc_id).await
    }

    /// Forwards to the inherent `update`.
    async fn update(
        &self,
        ctx: &ExecutionContext,
        doc_id: &str,
        document: Document,
    ) -> Result<()> {
        Self::update(self, ctx, doc_id, document).await
    }

    /// Forwards to the inherent `count`.
    async fn count(
        &self,
        ctx: &ExecutionContext,
        filter: Option<&VectorFilter>,
    ) -> Result<usize> {
        Self::count(self, ctx, filter).await
    }

    /// Forwards to the inherent `list`.
    async fn list(
        &self,
        ctx: &ExecutionContext,
        filter: Option<&VectorFilter>,
        limit: usize,
        offset: usize,
    ) -> Result<Vec<Document>> {
        Self::list(self, ctx, filter, limit, offset).await
    }
}