mod embedding;
mod retrieval;
mod store;
pub use embedding::{Embedding, EmbeddingModel, EmbeddingProvider};
pub use retrieval::{RetrievalConfig, RetrievalResult, Retriever};
pub use store::{Memory, MemoryId, MemoryStore, MemoryType};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Two-tier memory: a bounded short-term buffer plus a long-term store searched by embedding
/// similarity. Both tiers sit behind `tokio::sync::RwLock`s shared via `Arc`.
pub struct MemorySystem {
    /// Recent entries; capacity is `MemoryConfig::max_short_term`.
    short_term: Arc<RwLock<ShortTermMemory>>,
    /// Long-term memories, stored with embeddings and queried via `search`.
    long_term: Arc<RwLock<Box<dyn MemoryStore>>>,
    /// Embeds stored content and retrieval queries.
    embedder: Arc<dyn EmbeddingProvider>,
    // Built in `new` but not read by any method in this module yet, hence the allow.
    #[allow(dead_code)]
    retriever: Arc<Retriever>,
    /// Configuration supplied at construction.
    config: MemoryConfig,
}
/// Settings for [`MemorySystem`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// When false, `MemorySystem::store` errors and `MemorySystem::retrieve` returns nothing.
    pub enabled: bool,
    /// Capacity of the short-term buffer.
    pub max_short_term: usize,
    /// Capacity passed to the long-term store on construction.
    pub max_long_term: usize,
    /// Embedding model name; reported in `MemoryStats`.
    // NOTE(review): `MemorySystem::new` always builds a mock embedder regardless of this value — confirm.
    pub embedding_model: String,
    // NOTE(review): not read in this module; presumably a persistence location — confirm.
    pub storage_path: Option<PathBuf>,
    /// Minimum score a long-term hit must reach to be returned by `MemorySystem::retrieve`.
    pub similarity_threshold: f32,
    // NOTE(review): not read in this module; presumably the budget for `build_rag_context` — confirm.
    pub max_memory_tokens: usize,
    // NOTE(review): not read in this module; presumably a persistence period in seconds — confirm.
    pub persist_interval_secs: u64,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
enabled: true,
max_short_term: 100,
max_long_term: 10000,
embedding_model: "text-embedding-3-small".to_string(),
storage_path: None,
similarity_threshold: 0.7,
max_memory_tokens: 500,
persist_interval_secs: 300,
}
}
}
/// Bounded buffer of [`ShortTermEntry`] values kept in insertion order.
///
/// Note: the derived `Default` produces `max_entries == 0`.
#[derive(Debug, Default)]
pub struct ShortTermMemory {
    /// Entries, oldest first (new entries are pushed to the end).
    entries: Vec<ShortTermEntry>,
    /// Capacity limit enforced by `add`.
    max_entries: usize,
}
/// One item held in short-term memory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShortTermEntry {
    /// Caller-supplied identifier; uniqueness is not checked by `ShortTermMemory`.
    pub id: String,
    /// Text that is embedded and stored if the entry is consolidated.
    pub content: String,
    /// Category used for filtering and for picking the long-term memory type.
    pub content_type: ContentType,
    /// Caller-supplied UTC timestamp; not read by `ShortTermMemory`.
    pub timestamp: DateTime<Utc>,
    /// Producing agent, if any; matched by `ShortTermMemory::by_agent`.
    pub agent_id: Option<String>,
    /// Related task, if any.
    pub task_id: Option<String>,
    /// Relative importance: the lowest is evicted first when the buffer is full, and only
    /// entries at or above the consolidation threshold are promoted to long-term memory.
    pub importance: f32,
}
/// Kind of content stored in a [`ShortTermEntry`].
///
/// `ShortTermMemory::by_type` filters on it, and `MemorySystem::consolidate` uses it to choose
/// the long-term `MemoryType` an entry is stored under.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ContentType {
    UserMessage,
    AgentResponse,
    ToolOutput,
    Error,
    System,
    TaskContext,
    Code,
    Documentation,
}
impl ShortTermMemory {
    /// Create an empty buffer that holds at most `max_entries` entries.
    pub fn new(max_entries: usize) -> Self {
        Self {
            entries: Vec::with_capacity(max_entries),
            max_entries,
        }
    }

    /// Append `entry`, first evicting the least important stored entry if the buffer is full.
    ///
    /// Among entries tied for the lowest importance, the earliest inserted one is evicted
    /// (`min_by` yields the first minimum). Importance values that cannot be ordered (NaN)
    /// compare as equal. A buffer with `max_entries == 0` stores nothing.
    pub fn add(&mut self, entry: ShortTermEntry) {
        // With zero capacity there is nothing to evict, so without this guard the entry would be
        // pushed anyway and the buffer would silently keep one item.
        if self.max_entries == 0 {
            return;
        }
        if self.entries.len() >= self.max_entries {
            if let Some(min_idx) = self
                .entries
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    a.importance
                        .partial_cmp(&b.importance)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(idx, _)| idx)
            {
                // `remove` (not `swap_remove`) keeps insertion order, which `recent` relies on.
                self.entries.remove(min_idx);
            }
        }
        self.entries.push(entry);
    }

    /// Up to `count` entries, newest first.
    pub fn recent(&self, count: usize) -> Vec<&ShortTermEntry> {
        self.entries.iter().rev().take(count).collect()
    }

    /// Entries of the given type, oldest first.
    pub fn by_type(&self, content_type: ContentType) -> Vec<&ShortTermEntry> {
        self.entries
            .iter()
            .filter(|e| e.content_type == content_type)
            .collect()
    }

    /// Entries whose `agent_id` equals `agent_id`, oldest first.
    pub fn by_agent(&self, agent_id: &str) -> Vec<&ShortTermEntry> {
        self.entries
            .iter()
            .filter(|e| e.agent_id.as_deref() == Some(agent_id))
            .collect()
    }

    /// Remove all entries (capacity limit is unchanged).
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Entries with `importance >= min_importance`, oldest first.
    pub fn get_for_consolidation(&self, min_importance: f32) -> Vec<&ShortTermEntry> {
        self.entries
            .iter()
            .filter(|e| e.importance >= min_importance)
            .collect()
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
impl MemorySystem {
pub fn new(config: MemoryConfig) -> Self {
let short_term = Arc::new(RwLock::new(ShortTermMemory::new(config.max_short_term)));
let long_term = Arc::new(RwLock::new(Box::new(store::InMemoryStore::new(
config.max_long_term,
)) as Box<dyn MemoryStore>));
let embedder = Arc::new(embedding::MockEmbedder::new()) as Arc<dyn EmbeddingProvider>;
let retriever = Arc::new(Retriever::new(RetrievalConfig::default()));
Self {
short_term,
long_term,
embedder,
retriever,
config,
}
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub async fn store(
&self,
content: String,
memory_type: MemoryType,
) -> Result<MemoryId, String> {
if !self.config.enabled {
return Err("Memory system is disabled".to_string());
}
let embedding = self
.embedder
.embed(&content)
.await
.map_err(|e| format!("Embedding failed: {}", e))?;
let memory = Memory::new(content, memory_type).with_embedding(embedding);
let mut store = self.long_term.write().await;
store.store(memory.clone()).await?;
Ok(memory.id)
}
pub async fn store_short_term(&self, entry: ShortTermEntry) {
let mut short_term = self.short_term.write().await;
short_term.add(entry);
}
pub async fn retrieve(
&self,
query: &str,
limit: usize,
) -> Result<Vec<RetrievalResult>, String> {
if !self.config.enabled {
return Ok(Vec::new());
}
let query_embedding = self
.embedder
.embed(query)
.await
.map_err(|e| format!("Query embedding failed: {}", e))?;
let store = self.long_term.read().await;
let memories = store.search(&query_embedding, limit).await?;
let results = memories
.into_iter()
.filter(|(_, score)| *score >= self.config.similarity_threshold)
.map(|(memory, score)| RetrievalResult {
memory,
score,
source: "long_term".to_string(),
})
.collect();
Ok(results)
}
pub async fn get_recent(&self, count: usize) -> Vec<ShortTermEntry> {
let short_term = self.short_term.read().await;
short_term.recent(count).into_iter().cloned().collect()
}
pub async fn consolidate(&self, min_importance: f32) -> Result<usize, String> {
let short_term = self.short_term.read().await;
let entries = short_term.get_for_consolidation(min_importance);
let mut consolidated = 0;
for entry in entries {
let memory_type = match entry.content_type {
ContentType::Code => MemoryType::CodePattern,
ContentType::Error => MemoryType::ErrorPattern,
ContentType::Documentation => MemoryType::Documentation,
_ => MemoryType::Conversation,
};
if self.store(entry.content.clone(), memory_type).await.is_ok() {
consolidated += 1;
}
}
Ok(consolidated)
}
pub async fn clear_short_term(&self) {
let mut short_term = self.short_term.write().await;
short_term.clear();
}
pub async fn get_stats(&self) -> MemoryStats {
let short_term = self.short_term.read().await;
let long_term = self.long_term.read().await;
MemoryStats {
short_term_count: short_term.len(),
long_term_count: long_term.count().await,
max_short_term: self.config.max_short_term,
max_long_term: self.config.max_long_term,
embedding_model: self.config.embedding_model.clone(),
}
}
pub async fn delete(&self, id: &MemoryId) -> Result<(), String> {
let mut store = self.long_term.write().await;
store.delete(id).await
}
pub async fn update_importance(&self, id: &MemoryId, importance: f32) -> Result<(), String> {
let mut store = self.long_term.write().await;
store.update_importance(id, importance).await
}
}
impl Default for MemorySystem {
fn default() -> Self {
Self::new(MemoryConfig::default())
}
}
/// Snapshot of memory usage returned by `MemorySystem::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Entries currently in short-term memory.
    pub short_term_count: usize,
    /// Memories reported by the long-term store's `count`.
    pub long_term_count: usize,
    /// Configured short-term capacity.
    pub max_short_term: usize,
    /// Configured long-term capacity.
    pub max_long_term: usize,
    /// Configured embedding model name.
    pub embedding_model: String,
}
/// Render retrieved memories as a Markdown context section within a token budget.
///
/// Each memory becomes a `### Relevant Memory (score: X.XX)` heading, its content, and a blank
/// line, appended in the order given. Rendering stops at the first entry that would push the
/// running estimate past `max_tokens`; later, possibly shorter, entries are not tried, so the
/// input ranking is preserved.
///
/// Tokens are estimated as one per four bytes of UTF-8, rounded up, over the whole rendered
/// entry (heading included), so the returned string stays within `max_tokens` under that
/// heuristic. Returns an empty string if `memories` is empty or the first entry doesn't fit.
pub fn build_rag_context(memories: &[RetrievalResult], max_tokens: usize) -> String {
    let mut context = String::new();
    let mut tokens_used = 0usize;
    for result in memories {
        let entry = format!(
            "### Relevant Memory (score: {:.2})\n{}\n\n",
            result.score, result.memory.content
        );
        // Count the heading and separators too, and round up so no entry is "free".
        let estimated_tokens = entry.len().div_ceil(4);
        if tokens_used + estimated_tokens > max_tokens {
            break;
        }
        context.push_str(&entry);
        tokens_used += estimated_tokens;
    }
    context
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a short-term entry from the fields the tests vary; the rest use fixed values.
    fn make_entry(
        id: &str,
        content: &str,
        content_type: ContentType,
        agent_id: Option<&str>,
        importance: f32,
    ) -> ShortTermEntry {
        ShortTermEntry {
            id: id.to_string(),
            content: content.to_string(),
            content_type,
            timestamp: Utc::now(),
            agent_id: agent_id.map(str::to_string),
            task_id: None,
            importance,
        }
    }

    /// Build a retrieval hit as if it came from long-term memory.
    fn long_term_hit(content: &str, memory_type: MemoryType, score: f32) -> RetrievalResult {
        RetrievalResult {
            memory: Memory::new(content.to_string(), memory_type),
            score,
            source: "long_term".to_string(),
        }
    }

    #[test]
    fn test_short_term_memory() {
        let mut memory = ShortTermMemory::new(3);
        memory.add(make_entry("1", "First entry", ContentType::UserMessage, None, 0.5));
        memory.add(make_entry(
            "2",
            "Second entry",
            ContentType::AgentResponse,
            Some("agent1"),
            0.8,
        ));
        assert_eq!(memory.len(), 2);
        assert_eq!(memory.recent(5).len(), 2);
    }

    #[test]
    fn test_short_term_eviction() {
        let mut memory = ShortTermMemory::new(2);
        for (id, content, kind, importance) in [
            ("1", "Low importance", ContentType::System, 0.1),
            ("2", "High importance", ContentType::UserMessage, 0.9),
            ("3", "New entry", ContentType::AgentResponse, 0.5),
        ] {
            memory.add(make_entry(id, content, kind, None, importance));
        }
        assert_eq!(memory.len(), 2);
        let has = |wanted: &str| memory.entries.iter().any(|e| e.id == wanted);
        assert!(has("2"));
        assert!(has("3"));
    }

    #[tokio::test]
    async fn test_memory_system() {
        let system = MemorySystem::new(MemoryConfig::default());
        assert!(system.is_enabled());
        let id = system
            .store("This is a test memory".to_string(), MemoryType::Conversation)
            .await
            .unwrap();
        assert!(!id.0.is_empty());
        assert_eq!(system.get_stats().await.long_term_count, 1);
    }

    #[tokio::test]
    async fn test_short_term_operations() {
        let system = MemorySystem::new(MemoryConfig::default());
        let entry = make_entry(
            "test-1",
            "Test content",
            ContentType::UserMessage,
            Some("agent1"),
            0.7,
        );
        system.store_short_term(entry).await;
        assert_eq!(system.get_recent(10).await.len(), 1);
    }

    #[tokio::test]
    async fn test_memory_retrieval() {
        let system = MemorySystem::new(MemoryConfig::default());
        let seed = [
            ("How to implement user authentication", MemoryType::Documentation),
            ("Fix login bug in auth module", MemoryType::TaskContext),
        ];
        for (content, memory_type) in seed {
            system.store(content.to_string(), memory_type).await.unwrap();
        }
        let results = system.retrieve("authentication", 5).await.unwrap();
        assert!(results.len() <= 5);
    }

    #[test]
    fn test_build_rag_context() {
        let memories = [
            long_term_hit("First relevant memory", MemoryType::Conversation, 0.9),
            long_term_hit("Second relevant memory", MemoryType::Documentation, 0.8),
        ];
        let context = build_rag_context(&memories, 1000);
        for needle in ["First relevant memory", "Second relevant memory", "0.90"] {
            assert!(context.contains(needle));
        }
    }

    #[tokio::test]
    async fn test_memory_consolidation() {
        let system = MemorySystem::new(MemoryConfig::default());
        let entry = make_entry("consolidate-1", "Important pattern", ContentType::Code, None, 0.9);
        system.store_short_term(entry).await;
        assert_eq!(system.consolidate(0.8).await.unwrap(), 1);
        assert_eq!(system.get_stats().await.long_term_count, 1);
    }
}