use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
// When the `memory` feature is enabled, the data types used by this file are
// re-exported from the `reasonkit_mem` crate under their original names.
#[cfg(feature = "memory")]
pub use reasonkit_mem::{
Chunk, Document, DocumentContent, DocumentType, MatchSource, Metadata, ProcessingState,
ProcessingStatus, RetrievalConfig, SearchResult, Source, SourceType,
};
// Fallbacks compiled only when the `memory` feature is disabled: every name
// that would otherwise come from `reasonkit_mem` is aliased to the unit type
// `()`. NOTE(review): presumably this exists so that the structs and the trait
// in this file still type-check without the optional crate; the aliases carry
// no data.
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::Chunk` when `memory` is off.
pub type Chunk = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::Document` when `memory` is off.
pub type Document = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::DocumentContent` when `memory` is off.
pub type DocumentContent = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::DocumentType` when `memory` is off.
pub type DocumentType = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::Metadata` when `memory` is off.
pub type Metadata = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::ProcessingState` when `memory` is off.
pub type ProcessingState = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::ProcessingStatus` when `memory` is off.
pub type ProcessingStatus = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::RetrievalConfig` when `memory` is off.
pub type RetrievalConfig = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::SearchResult` when `memory` is off.
pub type SearchResult = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::Source` when `memory` is off.
pub type Source = ();
#[cfg(not(feature = "memory"))]
/// Unit placeholder for `reasonkit_mem::SourceType` when `memory` is off.
pub type SourceType = ();
/// Local stand-in for `reasonkit_mem::MatchSource`, compiled only when the
/// `memory` feature is disabled.
///
/// Unlike the other fallbacks in this file it is a real enum rather than `()`.
///
/// NOTE(review): the variant meanings are not shown here; the names suggest
/// dense (vector) matching, sparse (keyword) matching, a hybrid of both, and
/// RAPTOR-based retrieval — confirm against `reasonkit_mem`.
#[cfg(not(feature = "memory"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MatchSource {
Dense,
Sparse,
Hybrid,
Raptor,
}
/// Error value for the memory API; it is the `Err` type of `MemoryResult`.
///
/// Its `Display` output is `<category>: <message>`, with ` (<context>)`
/// appended when `context` is `Some`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryError {
/// Coarse classification of the failure.
pub category: ErrorCategory,
/// Human-readable description of what went wrong.
pub message: String,
/// Optional extra detail; shown in parentheses by `Display` when present.
pub context: Option<String>,
}
/// Classification attached to every `MemoryError`.
///
/// `MemoryError`'s `Display` impl prints the variant via `Debug`, so the
/// variant names below appear verbatim in error text.
///
/// NOTE(review): nothing in this file shows which operation produces which
/// category; the meanings are implied by the variant names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ErrorCategory {
Storage,
Embedding,
Retrieval,
Indexing,
NotFound,
InvalidInput,
Config,
Internal,
}
impl std::fmt::Display for MemoryError {
    /// Writes the error as `<category>: <message>`, followed by
    /// ` (<context>)` when a context is present.
    ///
    /// The category is rendered through its `Debug` form, i.e. the bare
    /// variant name (for example `NotFound`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}: {}", self.category, self.message)?;
        // Stream the optional suffix straight into the formatter instead of
        // building a temporary `String` for it on every call.
        if let Some(context) = &self.context {
            write!(f, " ({})", context)?;
        }
        Ok(())
    }
}
// Empty impl: every `std::error::Error` method keeps its default, so
// `source()` is not overridden and no underlying cause is exposed.
impl std::error::Error for MemoryError {}
/// `Result` alias whose error type is fixed to `MemoryError`.
pub type MemoryResult<T> = std::result::Result<T, MemoryError>;
/// Options passed to `MemoryService::search_with_config` and
/// `MemoryService::get_context_with_config`.
///
/// NOTE(review): how each option is interpreted is decided by the
/// `MemoryService` implementation, which is not visible in this file; the
/// field notes below are inferred from names and should be confirmed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextConfig {
/// Presumably the maximum number of results to return (default: 10).
pub top_k: usize,
/// Presumably a score threshold below which results are dropped (default: 0.0).
pub min_score: f32,
/// Presumably a dense/sparse blending weight (default: 0.7) — TODO confirm
/// which end of the range favours which retrieval mode.
pub alpha: f32,
/// Switch for RAPTOR-based retrieval, judging by the name (default: false).
pub use_raptor: bool,
/// Switch for a re-ranking step, judging by the name (default: false).
pub rerank: bool,
/// Switch for including metadata in results, judging by the name (default: true).
pub include_metadata: bool,
}
impl Default for ContextConfig {
fn default() -> Self {
Self {
top_k: 10,
min_score: 0.0,
alpha: 0.7, use_raptor: false,
rerank: false,
include_metadata: true,
}
}
}
/// Return type of `MemoryService::get_context` and
/// `MemoryService::get_context_with_config`.
///
/// NOTE(review): nothing in this file states how the vectors relate to one
/// another; `scores` and `sources` are presumably parallel to `chunks` (one
/// entry per chunk) — confirm against the code that builds this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextWindow {
/// Chunks that make up the context.
pub chunks: Vec<Chunk>,
/// Documents associated with the context (presumably the chunks' parents — confirm).
pub documents: Vec<Document>,
/// Score values for the context entries.
pub scores: Vec<f32>,
/// Which kind of match produced each entry.
pub sources: Vec<MatchSource>,
/// Token count for the context. NOTE(review): the tokenizer used is not shown here.
pub token_count: usize,
/// Aggregate quality figures for this context.
pub quality: ContextQuality,
}
/// Aggregate quality figures stored in `ContextWindow::quality`.
///
/// NOTE(review): how these values are computed is not visible in this file.
/// The unit test in this file only exercises sample values with `diversity`
/// and `coverage` inside `0.0..=1.0`; whether that range is guaranteed should
/// be confirmed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextQuality {
/// Average score (presumably over the entries of the context window).
pub avg_score: f32,
/// Highest score.
pub max_score: f32,
/// Lowest score.
pub min_score: f32,
/// Diversity figure — definition not shown here.
pub diversity: f32,
/// Coverage figure — definition not shown here.
pub coverage: f32,
}
/// Summary counters returned by `MemoryService::stats`.
///
/// NOTE(review): the counting rules (e.g. whether `indexed_count` counts
/// documents or chunks) are defined by the implementation, not by this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
/// Number of documents.
pub document_count: usize,
/// Number of chunks.
pub chunk_count: usize,
/// Number of embeddings.
pub embedding_count: usize,
/// Storage size in bytes.
pub storage_size_bytes: u64,
/// Number of indexed items — unit (documents vs. chunks) to be confirmed.
pub indexed_count: usize,
/// Overall health flag.
pub is_healthy: bool,
}
/// Async interface covering document storage, search, context assembly,
/// embedding and index maintenance. Implementors must be `Send + Sync`.
///
/// Every method returns a `MemoryResult`, so failures are reported as
/// `MemoryError` values.
///
/// NOTE(review): no implementation is visible in this file. The per-method
/// notes below are derived from names and signatures only and must be
/// confirmed against the implementing type.
#[async_trait]
pub trait MemoryService: Send + Sync {
/// Stores a single document and yields a `Uuid` (presumably the id under
/// which it can later be fetched — confirm).
async fn store_document(&self, document: &Document) -> MemoryResult<Uuid>;
/// Batch counterpart of `store_document`; yields a `Vec<Uuid>`.
/// NOTE(review): whether the ids follow input order is not specified here.
async fn store_documents(&self, documents: &[Document]) -> MemoryResult<Vec<Uuid>>;
/// Fetches a document by id. The `Option` in the return type leaves room
/// for reporting an unknown id as `Ok(None)` — confirm implementors do so.
async fn get_document(&self, doc_id: &Uuid) -> MemoryResult<Option<Document>>;
/// Deletes the document with the given id.
/// NOTE(review): behaviour for an unknown id is not specified here.
async fn delete_document(&self, doc_id: &Uuid) -> MemoryResult<()>;
/// Lists document ids.
async fn list_documents(&self) -> MemoryResult<Vec<Uuid>>;
/// Searches with a text query and a `top_k` value (presumably the maximum
/// number of results).
async fn search(&self, query: &str, top_k: usize) -> MemoryResult<Vec<SearchResult>>;
/// Searches with a text query, taking its options from a `ContextConfig`.
async fn search_with_config(
&self,
query: &str,
config: &ContextConfig,
) -> MemoryResult<Vec<SearchResult>>;
/// Searches with a caller-supplied embedding instead of query text.
/// NOTE(review): the expected embedding dimension is not shown here.
async fn search_by_vector(
&self,
embedding: &[f32],
top_k: usize,
) -> MemoryResult<Vec<SearchResult>>;
/// Keyword-based search over a text query, judging by the name.
async fn search_by_keywords(
&self,
query: &str,
top_k: usize,
) -> MemoryResult<Vec<SearchResult>>;
/// Builds a `ContextWindow` for a text query and a `top_k` value.
async fn get_context(&self, query: &str, top_k: usize) -> MemoryResult<ContextWindow>;
/// Builds a `ContextWindow` for a text query, taking its options from a
/// `ContextConfig`.
async fn get_context_with_config(
&self,
query: &str,
config: &ContextConfig,
) -> MemoryResult<ContextWindow>;
/// Returns the chunks belonging to the document with the given id.
async fn get_document_chunks(&self, doc_id: &Uuid) -> MemoryResult<Vec<Chunk>>;
/// Produces an `f32` vector for one piece of text.
async fn embed(&self, text: &str) -> MemoryResult<Vec<f32>>;
/// Batch counterpart of `embed`; yields a vector of `f32` vectors
/// (presumably one per input text, in input order — confirm).
async fn embed_batch(&self, texts: &[&str]) -> MemoryResult<Vec<Vec<f32>>>;
/// Builds the indexes.
async fn build_indexes(&self) -> MemoryResult<()>;
/// Rebuilds the indexes.
/// NOTE(review): how this differs from `build_indexes` is not shown here.
async fn rebuild_indexes(&self) -> MemoryResult<()>;
/// Reports index state as an `IndexStats` value.
async fn check_index_health(&self) -> MemoryResult<IndexStats>;
/// Reports overall counters as a `MemoryStats` value.
async fn stats(&self) -> MemoryResult<MemoryStats>;
/// Reports a health flag.
async fn is_healthy(&self) -> MemoryResult<bool>;
/// NOTE(review): the name suggests this removes all stored data and is
/// therefore destructive — confirm scope before calling.
async fn clear_all(&self) -> MemoryResult<()>;
}
/// Index state returned by `MemoryService::check_index_health`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
/// Number of indexed documents.
pub indexed_docs: usize,
/// Number of indexed chunks.
pub indexed_chunks: usize,
/// Index size in bytes.
pub index_size_bytes: u64,
/// Time of the last indexing run.
/// NOTE(review): unit and epoch are not shown here — presumably a Unix
/// timestamp; confirm seconds vs. milliseconds.
pub last_indexed_at: i64,
/// Validity flag for the index.
pub is_valid: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ContextConfig::default()` must yield the documented baseline for
    /// every field, not just a subset.
    #[test]
    fn test_context_config_default() {
        let config = ContextConfig::default();
        assert_eq!(config.top_k, 10);
        assert_eq!(config.min_score, 0.0);
        assert_eq!(config.alpha, 0.7);
        assert!(!config.use_raptor);
        assert!(!config.rerank);
        assert!(config.include_metadata);
    }

    /// With a context present, `Display` shows category, message and the
    /// context in parentheses.
    #[test]
    fn test_memory_error_display() {
        let err = MemoryError {
            category: ErrorCategory::NotFound,
            message: "Document not found".to_string(),
            context: Some("doc_id=123".to_string()),
        };
        let display = err.to_string();
        assert!(display.contains("NotFound"));
        assert!(display.contains("Document not found"));
        // The context was previously never checked, so a regression that
        // dropped it would have gone unnoticed.
        assert!(display.contains("doc_id=123"));
        assert_eq!(display, "NotFound: Document not found (doc_id=123)");
    }

    /// Without a context, `Display` emits no trailing parentheses.
    #[test]
    fn test_memory_error_display_without_context() {
        let err = MemoryError {
            category: ErrorCategory::Storage,
            message: "disk full".to_string(),
            context: None,
        };
        assert_eq!(err.to_string(), "Storage: disk full");
    }

    /// Sanity check on a hand-built `ContextQuality` sample.
    #[test]
    fn test_context_quality_fields() {
        let quality = ContextQuality {
            avg_score: 0.8,
            max_score: 0.95,
            min_score: 0.65,
            diversity: 0.7,
            coverage: 0.85,
        };
        assert!(quality.avg_score < quality.max_score);
        assert!(quality.min_score < quality.avg_score);
        assert!((0.0..=1.0).contains(&quality.diversity));
        assert!((0.0..=1.0).contains(&quality.coverage));
    }
}