use crate::{
embedding::{EmbeddingPipeline, EmbeddingProvider, OpenAIEmbedding},
indexing::IndexManager,
retrieval::{HybridResult, HybridRetriever, RetrievalStats as MemRetrievalStats},
storage::{AccessContext, AccessLevel, Storage},
Document as MemDocument, DocumentType, Error as MemError, MatchSource, Result as MemResult,
Source, SourceType,
};
use async_trait::async_trait;
use chrono::Utc;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
/// Convenience alias for results produced by the memory-service layer.
pub type MemoryResult<T> = Result<T, MemoryError>;

/// Public error type for the memory-service facade.
///
/// Mirrors the internal crate error (`MemError`) variant-for-variant so that
/// callers of [`MemoryService`] never depend on internal error types directly
/// (see the `From<MemError>` impl below).
#[derive(thiserror::Error, Debug)]
pub enum MemoryError {
    /// The requested document id does not exist.
    #[error("Document not found: {0}")]
    NotFound(Uuid),
    /// Failure in the underlying storage backend.
    #[error("Storage error: {0}")]
    Storage(String),
    /// Failure while computing embeddings.
    #[error("Embedding error: {0}")]
    Embedding(String),
    /// Failure in the vector/sparse index.
    #[error("Index error: {0}")]
    Index(String),
    /// Invalid or missing configuration (e.g. no embedding pipeline).
    #[error("Configuration error: {0}")]
    Config(String),
    /// Failure (de)serializing documents or index data.
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// Wrapped I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
impl From<MemError> for MemoryError {
    /// Maps the internal crate error onto the public service error,
    /// variant-for-variant where a counterpart exists.
    fn from(e: MemError) -> Self {
        match e {
            // The internal NotFound carries a string id. If it is not a valid
            // UUID this falls back to `Uuid::default()` (the nil UUID), which
            // silently discards the original identifier.
            // NOTE(review): consider preserving the raw string when parsing
            // fails — TODO confirm callers only ever pass UUID strings here.
            MemError::NotFound(s) => MemoryError::NotFound(Uuid::parse_str(&s).unwrap_or_default()),
            MemError::Storage(s) => MemoryError::Storage(s),
            MemError::Embedding(s) => MemoryError::Embedding(s),
            MemError::Indexing(s) => MemoryError::Index(s),
            MemError::Config(s) => MemoryError::Config(s),
            MemError::Serialization(s) => MemoryError::Serialization(s),
            MemError::Io(e) => MemoryError::Io(e),
            // Any internal variant without a dedicated counterpart is coerced
            // into a Storage error carrying its display text.
            other => MemoryError::Storage(other.to_string()),
        }
    }
}
/// Public-facing document as accepted/returned by [`MemoryService`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Document {
    /// Identifier; `None` lets the service assign one on store.
    pub id: Option<Uuid>,
    /// Raw document text.
    pub content: String,
    /// Free-form key/value metadata.
    /// NOTE(review): only the keys survive conversion to the internal
    /// document (they become tags); values are currently discarded.
    pub metadata: HashMap<String, String>,
    /// Optional source path for provenance.
    pub source: Option<String>,
    /// Creation time as a Unix timestamp (seconds); `None` if unknown.
    pub created_at: Option<i64>,
}
/// A slice of a document, optionally carrying its embedding.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Chunk {
    /// Chunk identifier; `None` until assigned.
    pub id: Option<Uuid>,
    /// Identifier of the parent document.
    pub document_id: Uuid,
    /// Chunk text.
    pub content: String,
    /// Position of this chunk within the parent document.
    pub index: usize,
    /// Dense embedding for this chunk, if computed.
    pub embedding: Option<Vec<f32>>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,
}
/// A single search hit: the matched chunk, its score, and which
/// retrieval path produced it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// The matched chunk.
    pub chunk: Chunk,
    /// Relevance score (higher is better).
    pub score: f32,
    /// Which retrieval mechanism produced this hit.
    pub source: RetrievalSource,
}
/// Which retrieval path produced a [`SearchResult`].
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum RetrievalSource {
    /// Dense (vector similarity) match.
    Vector,
    /// Sparse (BM25 keyword) match.
    BM25,
    /// Fused dense + sparse match (also used for RAPTOR matches).
    Hybrid,
}
/// Tunables for hybrid (dense + sparse) retrieval.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HybridConfig {
    /// Relative weight given to dense (vector) scores.
    pub vector_weight: f32,
    /// Relative weight given to sparse (BM25) scores.
    pub bm25_weight: f32,
    /// Whether to run a reranking pass over the fused candidates.
    pub use_reranker: bool,
    /// How many fused candidates to hand to the reranker.
    pub reranker_top_k: usize,
}

impl Default for HybridConfig {
    /// Defaults favour dense retrieval (0.7 vs 0.3) with reranking
    /// enabled over the top ten candidates.
    fn default() -> Self {
        let (vector_weight, bm25_weight) = (0.7, 0.3);
        HybridConfig {
            vector_weight,
            bm25_weight,
            use_reranker: true,
            reranker_top_k: 10,
        }
    }
}
/// A token-budgeted bundle of search results, as assembled by
/// [`MemoryService::get_context`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextWindow {
    /// Results included within the token budget, in retrieval order.
    pub chunks: Vec<SearchResult>,
    /// Approximate token count of the included chunks.
    pub total_tokens: usize,
    /// True if at least one result was dropped to respect the budget.
    pub truncated: bool,
}
/// Parameters for creating a vector index (HNSW-style knobs).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IndexConfig {
    /// Human-readable index name.
    pub name: String,
    /// Dimensionality of the stored vectors.
    pub dimensions: usize,
    /// Distance function used for similarity.
    pub metric: DistanceMetric,
    /// Construction-time beam width (`ef_construction`).
    pub ef_construction: usize,
    /// Maximum connections per graph node (`M`).
    pub m: usize,
}

impl Default for IndexConfig {
    /// A 384-dimensional cosine index named "default", matching the
    /// default embedding dimensionality in [`MemoryConfig`].
    fn default() -> Self {
        IndexConfig {
            m: 16,
            ef_construction: 200,
            metric: DistanceMetric::Cosine,
            dimensions: 384,
            name: String::from("default"),
        }
    }
}
/// Distance/similarity function for vector comparison.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Cosine similarity.
    Cosine,
    /// Euclidean (L2) distance.
    Euclidean,
    /// Inner (dot) product.
    DotProduct,
}
/// Aggregate statistics for the service's stored data and index.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IndexStats {
    /// Number of stored documents.
    pub total_documents: usize,
    /// Number of stored chunks.
    pub total_chunks: usize,
    /// Number of stored embedding vectors.
    pub total_vectors: usize,
    /// Combined storage + index footprint in bytes.
    pub index_size_bytes: u64,
    /// Unix timestamp (seconds) of when these stats were produced.
    pub last_updated: i64,
}
/// Top-level configuration for the memory service.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MemoryConfig {
    /// Target chunk size (used as a rough tokens-per-chunk estimate).
    pub chunk_size: usize,
    /// Overlap between consecutive chunks.
    pub chunk_overlap: usize,
    /// Name of the embedding model.
    pub embedding_model: String,
    /// Dimensionality of embedding vectors.
    pub embedding_dimensions: usize,
    /// Token budget used when assembling context windows.
    pub max_context_tokens: usize,
    /// Optional on-disk storage location; `None` means in-memory.
    pub storage_path: Option<String>,
}

impl Default for MemoryConfig {
    /// Defaults for a small local setup: MiniLM embeddings (384 dims),
    /// 512-token chunks with 50-token overlap, 4096-token context.
    fn default() -> Self {
        MemoryConfig {
            storage_path: None,
            max_context_tokens: 4096,
            embedding_dimensions: 384,
            embedding_model: String::from("all-MiniLM-L6-v2"),
            chunk_overlap: 50,
            chunk_size: 512,
        }
    }
}
/// The public contract of the memory service: document CRUD, search,
/// embeddings, index management, and lifecycle hooks.
#[async_trait]
pub trait MemoryService: Send + Sync {
    /// Stores a document and returns its assigned id.
    async fn store_document(&self, doc: &Document) -> MemoryResult<Uuid>;
    /// Stores pre-chunked content; returns one id per chunk, in order.
    async fn store_chunks(&self, chunks: &[Chunk]) -> MemoryResult<Vec<Uuid>>;
    /// Removes the document with the given id.
    async fn delete_document(&self, id: Uuid) -> MemoryResult<()>;
    /// Replaces the document stored under `id` with `doc`.
    async fn update_document(&self, id: Uuid, doc: &Document) -> MemoryResult<()>;
    /// Similarity search returning up to `top_k` results.
    async fn search(&self, query: &str, top_k: usize) -> MemoryResult<Vec<SearchResult>>;
    /// Hybrid (dense + sparse) search with per-call weighting.
    async fn hybrid_search(
        &self,
        query: &str,
        top_k: usize,
        config: HybridConfig,
    ) -> MemoryResult<Vec<SearchResult>>;
    /// Fetches a document by id; `Ok(None)` if absent.
    async fn get_by_id(&self, id: Uuid) -> MemoryResult<Option<Document>>;
    /// Assembles a token-budgeted context window of results for `query`.
    async fn get_context(&self, query: &str, max_tokens: usize) -> MemoryResult<ContextWindow>;
    /// Embeds a single text into a dense vector.
    async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>>;
    /// Embeds a batch of texts; one vector per input, in order.
    async fn embed_batch(&self, texts: &[&str]) -> MemoryResult<Vec<Vec<f32>>>;
    /// Creates (or prepares) an index with the given parameters.
    async fn create_index(&self, config: IndexConfig) -> MemoryResult<()>;
    /// Rebuilds/optimizes the existing index.
    async fn rebuild_index(&self) -> MemoryResult<()>;
    /// Returns aggregate document/chunk/vector statistics.
    async fn get_stats(&self) -> MemoryResult<IndexStats>;
    /// Current service configuration.
    fn config(&self) -> &MemoryConfig;
    /// Replaces the service configuration.
    fn set_config(&mut self, config: MemoryConfig);
    /// Liveness probe; false once shutdown has been requested.
    async fn health_check(&self) -> MemoryResult<bool>;
    /// Flushes any buffered state to the backing store.
    async fn flush(&self) -> MemoryResult<()>;
    /// Signals shutdown; the service reports unhealthy afterwards.
    async fn shutdown(&self) -> MemoryResult<()>;
}
/// Default [`MemoryService`] implementation backed by a [`HybridRetriever`].
pub struct MemServiceImpl {
    // Combined dense + sparse retrieval engine over storage and index.
    retriever: HybridRetriever,
    // Optional embedding pipeline; `embed`/`embed_batch` return a Config
    // error when this is `None`.
    embedding_pipeline: Option<Arc<EmbeddingPipeline>>,
    // Service configuration; the lock enables `&self`-based builder updates.
    config: RwLock<MemoryConfig>,
    // Health flag; set on construction, cleared by `shutdown`.
    is_healthy: std::sync::atomic::AtomicBool,
}
impl MemServiceImpl {
    /// Creates a service backed entirely by in-memory storage/indexing.
    pub fn in_memory() -> MemResult<Self> {
        let retriever = HybridRetriever::in_memory()?;
        Ok(Self {
            retriever,
            embedding_pipeline: None,
            config: RwLock::new(MemoryConfig::default()),
            is_healthy: std::sync::atomic::AtomicBool::new(true),
        })
    }

    /// Creates a service over explicit storage and index backends.
    pub fn new(storage: Storage, index: IndexManager) -> Self {
        let retriever = HybridRetriever::new(storage, index);
        Self {
            retriever,
            embedding_pipeline: None,
            config: RwLock::new(MemoryConfig::default()),
            is_healthy: std::sync::atomic::AtomicBool::new(true),
        }
    }

    /// Builder: attaches an embedding pipeline to both the service
    /// (for `embed`/`embed_batch`) and the underlying retriever.
    pub fn with_embedding_pipeline(mut self, pipeline: Arc<EmbeddingPipeline>) -> Self {
        self.embedding_pipeline = Some(pipeline.clone());
        self.retriever = self.retriever.with_embedding_pipeline(pipeline);
        self
    }

    /// Builder: wires up an OpenAI-backed embedding pipeline.
    pub fn with_openai_embeddings(self) -> MemResult<Self> {
        let provider = OpenAIEmbedding::openai()?;
        let pipeline = Arc::new(EmbeddingPipeline::new(Arc::new(provider)));
        Ok(self.with_embedding_pipeline(pipeline))
    }

    /// Builder: replaces the service configuration.
    pub fn with_config(self, config: MemoryConfig) -> Self {
        *self.config.write().unwrap() = config;
        self
    }

    /// Read-only access to the underlying hybrid retriever.
    pub fn retriever(&self) -> &HybridRetriever {
        &self.retriever
    }

    /// Builds an admin-level access context for internal storage calls,
    /// tagged with the calling operation name for auditing.
    fn admin_context(&self, operation: &str) -> AccessContext {
        AccessContext::new(
            "mem-service".to_string(),
            AccessLevel::Admin,
            operation.to_string(),
        )
    }

    /// Converts the public [`Document`] into the internal crate document.
    ///
    /// The whole content becomes a single chunk (index 0). Metadata VALUES
    /// are dropped: only the keys are carried over, as tags.
    fn to_mem_document(&self, doc: &Document) -> MemDocument {
        use crate::types::{Chunk as MemChunk, EmbeddingIds};
        let source = Source {
            source_type: SourceType::Local,
            url: None,
            path: doc.source.clone(),
            arxiv_id: None,
            github_repo: None,
            retrieved_at: Utc::now(),
            version: None,
        };
        let mut mem_doc =
            MemDocument::new(DocumentType::Note, source).with_content(doc.content.clone());
        // Preserve a caller-supplied id; otherwise keep the generated one.
        if let Some(id) = doc.id {
            mem_doc.id = id;
        }
        // Only metadata keys survive the conversion (as tags).
        mem_doc.metadata.tags = doc.metadata.keys().cloned().collect();
        if !doc.content.is_empty() {
            let chunk = MemChunk {
                id: Uuid::new_v4(),
                text: doc.content.clone(),
                index: 0,
                start_char: 0,
                end_char: doc.content.len(),
                // Whitespace-split word count as a cheap token estimate.
                token_count: Some(doc.content.split_whitespace().count()),
                section: None,
                page: None,
                embedding_ids: EmbeddingIds::default(),
            };
            mem_doc.chunks.push(chunk);
        }
        mem_doc
    }

    /// Converts an internal retrieval hit into the public [`SearchResult`].
    fn to_search_result(&self, result: &HybridResult) -> SearchResult {
        let source = match result.match_source {
            MatchSource::Dense => RetrievalSource::Vector,
            MatchSource::Sparse => RetrievalSource::BM25,
            // RAPTOR hits are surfaced as Hybrid to the public API.
            MatchSource::Hybrid | MatchSource::Raptor => RetrievalSource::Hybrid,
        };
        SearchResult {
            chunk: Chunk {
                id: Some(result.chunk_id),
                document_id: result.doc_id,
                content: result.text.clone(),
                // NOTE(review): chunk position is hard-coded to 0 and the
                // embedding is not propagated — presumably HybridResult does
                // not carry them; confirm before relying on `chunk.index`.
                index: 0, embedding: None,
                metadata: HashMap::new(),
            },
            score: result.score,
            source,
        }
    }

    /// Converts an internal document back into the public [`Document`].
    ///
    /// Inverse of `to_mem_document`'s tag handling: each tag becomes a
    /// metadata key with the placeholder value "true" (original metadata
    /// values were not stored and cannot be recovered).
    fn to_external_document(&self, doc: &MemDocument) -> Document {
        let mut metadata = HashMap::new();
        for tag in &doc.metadata.tags {
            metadata.insert(tag.clone(), "true".to_string());
        }
        Document {
            id: Some(doc.id),
            content: doc.content.raw.clone(),
            metadata,
            source: doc.source.path.clone(),
            created_at: Some(doc.created_at.timestamp()),
        }
    }
}
#[async_trait]
impl MemoryService for MemServiceImpl {
    /// Converts and persists a document via the retriever; returns its id.
    async fn store_document(&self, doc: &Document) -> MemoryResult<Uuid> {
        let mem_doc = self.to_mem_document(doc);
        let doc_id = mem_doc.id;
        self.retriever.add_document(&mem_doc).await?;
        Ok(doc_id)
    }

    /// Stores each chunk as an independent single-chunk document.
    ///
    /// NOTE(review): the chunk's `document_id`, `index`, and any precomputed
    /// `embedding` are dropped in this conversion — only id, content, and
    /// metadata survive.
    async fn store_chunks(&self, chunks: &[Chunk]) -> MemoryResult<Vec<Uuid>> {
        let mut ids = Vec::with_capacity(chunks.len());
        for chunk in chunks {
            let doc = Document {
                id: chunk.id,
                content: chunk.content.clone(),
                metadata: chunk.metadata.clone(),
                source: None,
                created_at: None,
            };
            let id = self.store_document(&doc).await?;
            ids.push(id);
        }
        Ok(ids)
    }

    /// Removes the document with the given id from storage and index.
    async fn delete_document(&self, id: Uuid) -> MemoryResult<()> {
        self.retriever.delete_document(&id).await?;
        Ok(())
    }

    /// Replaces a document via delete-then-store, keeping the same id.
    ///
    /// NOTE(review): not atomic — if the store step fails the old document
    /// has already been deleted.
    async fn update_document(&self, id: Uuid, doc: &Document) -> MemoryResult<()> {
        self.delete_document(id).await?;
        let mut new_doc = doc.clone();
        new_doc.id = Some(id);
        self.store_document(&new_doc).await?;
        Ok(())
    }

    /// Runs the retriever's default search and maps hits to the public type.
    async fn search(&self, query: &str, top_k: usize) -> MemoryResult<Vec<SearchResult>> {
        let results = self.retriever.search(query, top_k).await?;
        Ok(results.iter().map(|r| self.to_search_result(r)).collect())
    }

    /// Hybrid search with caller-supplied weighting.
    ///
    /// NOTE(review): only `vector_weight` (as `alpha`) and `use_reranker`
    /// are forwarded; `bm25_weight` and `reranker_top_k` are ignored by the
    /// internal `RetrievalConfig` — confirm whether `alpha` alone encodes
    /// the dense/sparse balance.
    async fn hybrid_search(
        &self,
        query: &str,
        top_k: usize,
        config: HybridConfig,
    ) -> MemoryResult<Vec<SearchResult>> {
        let retrieval_config = crate::RetrievalConfig {
            top_k,
            min_score: 0.0,
            alpha: config.vector_weight,
            use_raptor: false,
            rerank: config.use_reranker,
        };
        let results = self
            .retriever
            .search_hybrid(query, None, &retrieval_config)
            .await?;
        Ok(results.iter().map(|r| self.to_search_result(r)).collect())
    }

    /// Fetches a document by id using an admin-level access context.
    async fn get_by_id(&self, id: Uuid) -> MemoryResult<Option<Document>> {
        let context = self.admin_context("get_by_id");
        match self.retriever.storage().get_document(&id, &context).await {
            Ok(Some(doc)) => Ok(Some(self.to_external_document(&doc))),
            Ok(None) => Ok(None),
            Err(e) => Err(e.into()),
        }
    }

    /// Searches for `query` and packs results into a token-budgeted window.
    async fn get_context(&self, query: &str, max_tokens: usize) -> MemoryResult<ContextWindow> {
        // Estimate how many chunks fit the budget; at least 5 are requested.
        let top_k = {
            let config = self.config.read().unwrap();
            max_tokens / config.chunk_size.max(1)
        };
        let results = self.search(query, top_k.max(5)).await?;
        let mut total_tokens = 0;
        let mut chunks = Vec::new();
        let mut truncated = false;
        for result in results {
            // Rough heuristic: ~4 characters per token.
            let chunk_tokens = result.chunk.content.len() / 4;
            if total_tokens + chunk_tokens > max_tokens {
                truncated = true;
                break;
            }
            total_tokens += chunk_tokens;
            chunks.push(result);
        }
        Ok(ContextWindow {
            chunks,
            total_tokens,
            truncated,
        })
    }

    /// Embeds one text; errors if no pipeline is configured.
    async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
        let pipeline = self
            .embedding_pipeline
            .as_ref()
            .ok_or_else(|| MemoryError::Config("Embedding pipeline not configured".into()))?;
        pipeline
            .embed_text(text)
            .await
            .map_err(|e| MemoryError::Embedding(e.to_string()))
    }

    /// Embeds a batch of texts via the provider directly; errors if no
    /// pipeline is configured or if any result lacks a dense vector.
    async fn embed_batch(&self, texts: &[&str]) -> MemoryResult<Vec<Vec<f32>>> {
        let pipeline = self
            .embedding_pipeline
            .as_ref()
            .ok_or_else(|| MemoryError::Config("Embedding pipeline not configured".into()))?;
        let provider = pipeline.provider();
        let results = provider
            .embed_batch(texts)
            .await
            .map_err(|e| MemoryError::Embedding(e.to_string()))?;
        // Fail the whole batch if any entry is missing its dense embedding.
        results
            .into_iter()
            .map(|r| {
                r.dense
                    .ok_or_else(|| MemoryError::Embedding("No dense embedding returned".into()))
            })
            .collect()
    }

    /// No-op: index creation is handled internally by the retriever; the
    /// supplied config is currently ignored.
    async fn create_index(&self, _config: IndexConfig) -> MemoryResult<()> {
        Ok(())
    }

    /// Triggers an optimize pass on the underlying index.
    async fn rebuild_index(&self) -> MemoryResult<()> {
        self.retriever
            .index()
            .optimize()
            .map_err(|e| MemoryError::Index(e.to_string()))?;
        Ok(())
    }

    /// Translates retriever statistics into the public [`IndexStats`].
    async fn get_stats(&self) -> MemoryResult<IndexStats> {
        let stats = self.retriever.stats().await?;
        Ok(IndexStats {
            total_documents: stats.document_count,
            total_chunks: stats.chunk_count,
            total_vectors: stats.embedding_count,
            index_size_bytes: stats.storage_bytes + stats.index_bytes,
            last_updated: Utc::now().timestamp(),
        })
    }

    /// Returns the current configuration.
    ///
    /// NOTE(review): the trait demands `&MemoryConfig` while the value lives
    /// behind an `RwLock`, so this clones the config and `Box::leak`s it —
    /// EVERY call leaks one `MemoryConfig` for the lifetime of the process.
    /// Consider changing the trait to return `MemoryConfig` by value.
    fn config(&self) -> &MemoryConfig {
        let config = self.config.read().unwrap().clone();
        Box::leak(Box::new(config))
    }

    /// Replaces the stored configuration.
    fn set_config(&mut self, config: MemoryConfig) {
        *self.config.write().unwrap() = config;
    }

    /// Reports the health flag (true until `shutdown` is called).
    async fn health_check(&self) -> MemoryResult<bool> {
        Ok(self.is_healthy.load(std::sync::atomic::Ordering::SeqCst))
    }

    /// No-op: nothing is buffered at this layer.
    async fn flush(&self) -> MemoryResult<()> {
        Ok(())
    }

    /// Marks the service unhealthy; no other teardown is performed here.
    async fn shutdown(&self) -> MemoryResult<()> {
        self.is_healthy
            .store(false, std::sync::atomic::Ordering::SeqCst);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed in-memory service reports healthy.
    #[tokio::test]
    async fn test_mem_service_creation() {
        let service = MemServiceImpl::in_memory().expect("Failed to create service");
        assert!(service.health_check().await.unwrap());
    }

    /// Storing a document makes it findable via sparse (BM25) search,
    /// even without an embedding pipeline configured.
    #[tokio::test]
    async fn test_store_and_search_sparse() {
        let service = MemServiceImpl::in_memory().expect("Failed to create service");
        let doc = Document {
            id: None,
            content: "Machine learning is a subset of artificial intelligence.".to_string(),
            metadata: HashMap::new(),
            source: Some("/test/doc.md".to_string()),
            created_at: None,
        };
        let id = service.store_document(&doc).await.unwrap();
        // A real (non-nil) id must have been assigned.
        assert_ne!(id, Uuid::nil());
        let results = service
            .retriever
            .search_sparse("machine learning", 5)
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    /// An empty service reports zero stored documents.
    #[tokio::test]
    async fn test_get_stats() {
        let service = MemServiceImpl::in_memory().expect("Failed to create service");
        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_documents, 0);
    }

    /// Shutdown flips the health flag from healthy to unhealthy.
    #[tokio::test]
    async fn test_shutdown() {
        let service = MemServiceImpl::in_memory().expect("Failed to create service");
        assert!(service.health_check().await.unwrap());
        service.shutdown().await.unwrap();
        assert!(!service.health_check().await.unwrap());
    }

    /// Default MemoryConfig values match the documented defaults.
    #[test]
    fn test_config_default() {
        let config = MemoryConfig::default();
        assert_eq!(config.chunk_size, 512);
        assert_eq!(config.chunk_overlap, 50);
        assert_eq!(config.embedding_dimensions, 384);
    }

    /// Default HybridConfig weights are 0.7 / 0.3 with reranking on.
    #[test]
    fn test_hybrid_config_default() {
        let config = HybridConfig::default();
        assert!((config.vector_weight - 0.7).abs() < 0.001);
        assert!((config.bm25_weight - 0.3).abs() < 0.001);
        assert!(config.use_reranker);
    }
}