use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
use uuid::Uuid;
/// Result alias whose error type is fixed to [`MemoryError`].
pub type MemoryResult<T> = std::result::Result<T, MemoryError>;
#[derive(Error, Debug)]
pub enum MemoryError {
#[error("Document not found: {0}")]
NotFound(Uuid),
#[error("Storage error: {0}")]
Storage(String),
#[error("Embedding error: {0}")]
Embedding(String),
#[error("Index error: {0}")]
Index(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("Serialization error: {0}")]
Serialization(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
/// A piece of textual content plus the descriptive data attached to it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier of the document; `None` when no identifier is set.
    pub id: Option<Uuid>,
    /// Text content of the document.
    pub content: String,
    /// Arbitrary string key/value metadata.
    pub metadata: HashMap<String, String>,
    /// Optional source descriptor.
    /// NOTE(review): the expected format (path, URL, ...) is not defined here.
    pub source: Option<String>,
    /// Optional creation time.
    /// NOTE(review): presumably a Unix timestamp; unit and epoch to be confirmed.
    pub created_at: Option<i64>,
}
/// A fragment of content that refers back to a document by id.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Chunk {
    /// Identifier of the chunk; `None` when no identifier is set.
    pub id: Option<Uuid>,
    /// Identifier of the document this chunk refers to.
    pub document_id: Uuid,
    /// Text content of the chunk.
    pub content: String,
    /// Index of the chunk.
    /// NOTE(review): presumably its ordinal within the parent document — confirm.
    pub index: usize,
    /// Embedding vector, when one is attached.
    pub embedding: Option<Vec<f32>>,
    /// Arbitrary string key/value metadata.
    pub metadata: HashMap<String, String>,
}
/// One search hit: a chunk, its score and the retrieval path it came from.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SearchResult {
    /// The matched chunk.
    pub chunk: Chunk,
    /// Score assigned to the chunk.
    /// NOTE(review): scale and direction (higher = better?) are not defined here.
    pub score: f32,
    /// Retrieval path recorded for this hit.
    pub source: RetrievalSource,
}
/// Retrieval path tag stored on a [`SearchResult`].
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum RetrievalSource {
    /// Vector retrieval.
    Vector,
    /// BM25 retrieval.
    BM25,
    /// Hybrid retrieval.
    /// NOTE(review): presumably a blend of the vector and BM25 paths — confirm.
    Hybrid,
}
/// Tuning knobs passed to `MemoryService::hybrid_search`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HybridConfig {
    /// Weight associated with the vector side.
    pub vector_weight: f32,
    /// Weight associated with the BM25 side.
    pub bm25_weight: f32,
    /// Whether a reranker should be used.
    pub use_reranker: bool,
    /// Top-k value for the reranker.
    /// NOTE(review): unclear from here whether this bounds its input or output.
    pub reranker_top_k: usize,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
vector_weight: 0.7,
bm25_weight: 0.3,
use_reranker: true,
reranker_top_k: 10,
}
}
}
/// A set of search results together with token accounting.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ContextWindow {
    /// The search results that make up the window.
    pub chunks: Vec<SearchResult>,
    /// Token total.
    /// NOTE(review): presumably counted over `chunks`; tokenizer not specified here.
    pub total_tokens: usize,
    /// Truncation flag.
    /// NOTE(review): presumably set when content was dropped to fit a budget — confirm.
    pub truncated: bool,
}
/// Parameters handed to `MemoryService::create_index`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct IndexConfig {
    /// Name of the index.
    pub name: String,
    /// Number of vector dimensions.
    pub dimensions: usize,
    /// Distance metric for the index.
    pub metric: DistanceMetric,
    /// NOTE(review): looks like the HNSW `ef_construction` build parameter — confirm
    /// against the index implementation.
    pub ef_construction: usize,
    /// NOTE(review): looks like the HNSW `M` (links per node) parameter — confirm
    /// against the index implementation.
    pub m: usize,
}
impl Default for IndexConfig {
fn default() -> Self {
Self {
name: "default".to_string(),
dimensions: 384,
metric: DistanceMetric::Cosine,
ef_construction: 200,
m: 16,
}
}
}
/// Distance metric selectable in [`IndexConfig`].
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum DistanceMetric {
    /// Cosine metric.
    Cosine,
    /// Euclidean metric.
    Euclidean,
    /// Dot-product metric.
    DotProduct,
}
/// Counters returned by `MemoryService::get_stats`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct IndexStats {
    /// Total number of documents.
    pub total_documents: usize,
    /// Total number of chunks.
    pub total_chunks: usize,
    /// Total number of vectors.
    pub total_vectors: usize,
    /// Size of the index in bytes.
    pub index_size_bytes: u64,
    /// Time of the last update.
    /// NOTE(review): presumably a Unix timestamp; unit and epoch to be confirmed.
    pub last_updated: i64,
}
/// Service-level configuration exposed through `MemoryService::config`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MemoryConfig {
    /// Chunk size.
    /// NOTE(review): unit (tokens vs. characters) is not established here.
    pub chunk_size: usize,
    /// Overlap between chunks, presumably in the same unit as `chunk_size`.
    pub chunk_overlap: usize,
    /// Name of the embedding model.
    pub embedding_model: String,
    /// Number of embedding dimensions.
    pub embedding_dimensions: usize,
    /// Maximum number of context tokens.
    pub max_context_tokens: usize,
    /// Optional storage path.
    /// NOTE(review): behaviour when `None` is decided by the implementation.
    pub storage_path: Option<String>,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
chunk_size: 512,
chunk_overlap: 50,
embedding_model: "all-MiniLM-L6-v2".to_string(),
embedding_dimensions: 384,
max_context_tokens: 4096,
storage_path: None,
}
}
}
/// Interface of a memory service: document storage, retrieval, embedding and
/// index management.
///
/// Implementors must be `Send + Sync`. Every fallible operation reports
/// failures as [`MemoryError`] via [`MemoryResult`]. The descriptions below
/// are derived from the signatures only; exact semantics are defined by each
/// implementation.
#[async_trait]
pub trait MemoryService: Send + Sync {
    // --- Storage -----------------------------------------------------------

    /// Stores `doc` and yields a [`Uuid`] on success
    /// (presumably the identifier of the stored document).
    async fn store_document(&self, doc: &Document) -> MemoryResult<Uuid>;

    /// Stores `chunks` and yields a list of [`Uuid`]s on success
    /// (presumably one per stored chunk — confirm ordering with the implementation).
    async fn store_chunks(&self, chunks: &[Chunk]) -> MemoryResult<Vec<Uuid>>;

    /// Deletes the document identified by `id`.
    async fn delete_document(&self, id: Uuid) -> MemoryResult<()>;

    /// Updates the document identified by `id` using `doc`.
    async fn update_document(&self, id: Uuid, doc: &Document) -> MemoryResult<()>;

    // --- Retrieval ---------------------------------------------------------

    /// Searches for `query`; `top_k` presumably caps the number of results.
    async fn search(&self, query: &str, top_k: usize) -> MemoryResult<Vec<SearchResult>>;

    /// Searches for `query` using the supplied [`HybridConfig`];
    /// `top_k` presumably caps the number of results.
    async fn hybrid_search(
        &self,
        query: &str,
        top_k: usize,
        config: HybridConfig,
    ) -> MemoryResult<Vec<SearchResult>>;

    /// Looks up a document by `id`; `Ok(None)` is available to signal absence.
    async fn get_by_id(&self, id: Uuid) -> MemoryResult<Option<Document>>;

    /// Produces a [`ContextWindow`] for `query`;
    /// `max_tokens` is presumably the token budget for the window.
    async fn get_context(&self, query: &str, max_tokens: usize) -> MemoryResult<ContextWindow>;

    // --- Embedding ---------------------------------------------------------

    /// Computes an `f32` vector for `text`.
    async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>>;

    /// Computes `f32` vectors for a batch of texts
    /// (presumably one output vector per input, in order — confirm).
    async fn embed_batch(&self, texts: &[&str]) -> MemoryResult<Vec<Vec<f32>>>;

    // --- Index management --------------------------------------------------

    /// Creates an index from `config`.
    async fn create_index(&self, config: IndexConfig) -> MemoryResult<()>;

    /// Rebuilds the index.
    async fn rebuild_index(&self) -> MemoryResult<()>;

    /// Returns current [`IndexStats`].
    async fn get_stats(&self) -> MemoryResult<IndexStats>;

    // --- Configuration -----------------------------------------------------

    /// Borrows the active [`MemoryConfig`].
    fn config(&self) -> &MemoryConfig;

    /// Replaces the active [`MemoryConfig`].
    ///
    /// NOTE(review): this is the only method taking `&mut self`, so it cannot
    /// be called through a shared reference to the service.
    fn set_config(&mut self, config: MemoryConfig);

    // --- Lifecycle ---------------------------------------------------------

    /// Reports service health as a boolean.
    async fn health_check(&self) -> MemoryResult<bool>;

    /// Flushes the service.
    /// NOTE(review): what is flushed (buffers, index, storage) is implementation-defined.
    async fn flush(&self) -> MemoryResult<()>;

    /// Shuts the service down.
    async fn shutdown(&self) -> MemoryResult<()>;
}