use chrono::{DateTime, Duration, Utc};
use parking_lot::RwLock;
use ruvector_core::index::hnsw::HnswIndex;
use ruvector_core::index::VectorIndex;
use ruvector_core::types::{DistanceMetric, HnswConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::error::{Result, RuvLLMError};
/// Configuration for [`EpisodicMemory`]: embedding size, capacity,
/// HNSW tuning parameters, and the automatic-compression policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodicMemoryConfig {
    /// Dimensionality of episode summary embeddings (must match the index).
    pub embedding_dim: usize,
    /// Maximum number of stored episodes; oldest are evicted beyond this.
    pub max_episodes: usize,
    /// HNSW graph degree parameter `M`.
    pub hnsw_m: usize,
    /// HNSW beam width used while building the index.
    pub hnsw_ef_construction: usize,
    /// HNSW beam width used at query time.
    pub hnsw_ef_search: usize,
    /// Episodes older than this many days become eligible for compression.
    pub compression_age_days: i64,
    /// Fraction of trajectory steps kept when compressing (at least one).
    pub compression_ratio: f32,
    /// When true, every store triggers a compression pass over old episodes.
    pub auto_compress: bool,
}
impl Default for EpisodicMemoryConfig {
    fn default() -> Self {
        Self {
            // 768 matches common transformer sentence-embedding sizes —
            // TODO confirm against the embedder actually in use.
            embedding_dim: 768,
            max_episodes: 10_000,
            // HNSW quality/speed trade-off knobs: graph degree and the
            // construction/search beam widths.
            hnsw_m: 16,
            hnsw_ef_construction: 100,
            hnsw_ef_search: 50,
            // Episodes older than a week get lossily compressed...
            compression_age_days: 7,
            // ...keeping roughly half of their steps.
            compression_ratio: 0.5,
            auto_compress: true,
        }
    }
}
/// A complete recorded agent run: the ordered steps taken, the final
/// outcome, and bookkeeping metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trajectory {
    /// Unique identifier; reused as the episode id when stored.
    pub id: String,
    /// Ordered steps taken during the run.
    pub steps: Vec<TrajectoryStep>,
    /// Final outcome score; `search_successful` treats > 0.5 as success.
    pub outcome: f32,
    /// Quality score assigned to the run.
    pub quality_score: f32,
    /// Task category label (e.g. "coding").
    pub task_type: String,
    /// Optional category of the agent that produced the run.
    pub agent_type: Option<String>,
    /// Wall-clock duration of the run, in milliseconds.
    pub duration_ms: u64,
    /// Creation timestamp; compression eligibility is measured from this.
    pub created_at: DateTime<Utc>,
}
/// One step inside a [`Trajectory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Observed state before the action.
    pub state: String,
    /// Action taken; the first whitespace-separated token is treated as the
    /// action "verb" during pattern extraction.
    pub action: String,
    /// Optional result of performing the action.
    pub result: Option<String>,
    /// Optional per-step embedding; averaged during compression.
    pub embedding: Option<Vec<f32>>,
    /// Step reward; compression keeps the highest-reward steps.
    pub reward: f32,
    /// When the step occurred.
    pub timestamp: DateTime<Utc>,
}
/// Searchable/filterable metadata stored alongside every episode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodeMetadata {
    /// Identifier of the episode this metadata belongs to.
    pub episode_id: String,
    /// Task description (currently populated from the task type).
    pub task_description: String,
    /// Task category; used by `search_by_task_type`.
    pub task_type: String,
    /// Final outcome score copied from the trajectory.
    pub outcome: f32,
    /// Quality score copied from the trajectory.
    pub quality_score: f32,
    /// Optional agent category copied from the trajectory.
    pub agent_type: Option<String>,
    /// Number of steps the original trajectory contained.
    pub step_count: usize,
    /// Duration of the original run, in milliseconds.
    pub duration_ms: u64,
    /// Whether the full trajectory has been replaced by a compressed form.
    pub is_compressed: bool,
    /// Free-form tags supplied when the episode was stored.
    pub tags: Vec<String>,
    /// Creation timestamp copied from the trajectory.
    pub created_at: DateTime<Utc>,
}
/// A stored episode: summary embedding, metadata, and the payload — the full
/// trajectory while fresh, or its compressed form after compression (the
/// trajectory is `take`n when compressed, so at most one is populated).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
    /// Episode id (same as the originating trajectory id).
    pub id: String,
    /// Summary embedding the HNSW index searches over.
    pub embedding: Vec<f32>,
    /// Searchable metadata.
    pub metadata: EpisodeMetadata,
    /// Full trajectory; `None` once the episode has been compressed.
    pub trajectory: Option<Trajectory>,
    /// Compressed form; `Some` only after compression.
    pub compressed: Option<CompressedEpisode>,
}
/// Lossy summary of a trajectory produced by [`MemoryCompressor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedEpisode {
    /// Average of the kept steps' embeddings (empty if no step had one).
    pub embedding: Vec<f32>,
    /// One-line textual summary of the run.
    pub summary: String,
    /// States of the kept (highest-reward) steps.
    pub key_observations: Vec<String>,
    /// Actions of the kept (highest-reward) steps.
    pub key_actions: Vec<String>,
    /// Repeated action verbs, formatted as "verb:count".
    pub patterns: Vec<String>,
    /// Step count of the trajectory before compression.
    pub original_step_count: usize,
    /// When compression happened.
    pub compressed_at: DateTime<Utc>,
}
/// Lossy trajectory compressor: keeps the top `ratio` fraction of steps by
/// reward and averages their embeddings into a single vector.
pub struct MemoryCompressor {
    /// Fraction of steps to keep (at least one step is always kept).
    ratio: f32,
    /// Optional dimension to truncate the averaged embedding to.
    target_dim: Option<usize>,
}
impl MemoryCompressor {
    /// Creates a compressor that keeps `ratio` of a trajectory's steps and
    /// optionally truncates the averaged embedding to `target_dim`.
    pub fn new(ratio: f32, target_dim: Option<usize>) -> Self {
        Self { ratio, target_dim }
    }

    /// Compresses a trajectory down to its highest-reward steps.
    ///
    /// Keeps `ratio * steps` (at least one) steps ranked by reward, builds a
    /// one-line summary, extracts repeated-action patterns, and averages the
    /// surviving step embeddings into a single vector.
    pub fn compress(&self, trajectory: &Trajectory) -> CompressedEpisode {
        let total_steps = trajectory.steps.len();
        // Always keep at least one step so compression never erases everything.
        let keep_count = ((total_steps as f32) * self.ratio).max(1.0) as usize;

        // Rank steps by reward, highest first. NaN rewards compare as Equal so
        // the comparator stays total; the stable sort preserves original order
        // among ties.
        let mut ranked: Vec<&TrajectoryStep> = trajectory.steps.iter().collect();
        ranked.sort_by(|a, b| {
            b.reward
                .partial_cmp(&a.reward)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let key_steps: Vec<&TrajectoryStep> = ranked.into_iter().take(keep_count).collect();

        let key_observations: Vec<String> = key_steps.iter().map(|s| s.state.clone()).collect();
        let key_actions: Vec<String> = key_steps.iter().map(|s| s.action.clone()).collect();
        let summary = format!(
            "Task: {} | Outcome: {:.2} | Steps: {} | Key actions: {}",
            trajectory.task_type,
            trajectory.outcome,
            total_steps,
            key_actions.len()
        );
        let patterns = self.extract_patterns(&key_actions);
        let embedding = self.compress_embedding(&key_steps);
        CompressedEpisode {
            embedding,
            summary,
            key_observations,
            key_actions,
            patterns,
            original_step_count: total_steps,
            compressed_at: Utc::now(),
        }
    }

    /// Counts repeated action verbs (the first whitespace-separated token of
    /// each action) and reports those seen more than once as "verb:count".
    ///
    /// The result is sorted so the output is deterministic (HashMap iteration
    /// order is not).
    fn extract_patterns(&self, actions: &[String]) -> Vec<String> {
        let mut action_counts: HashMap<String, usize> = HashMap::new();
        for action in actions {
            if let Some(action_type) = action.split_whitespace().next() {
                *action_counts.entry(action_type.to_string()).or_insert(0) += 1;
            }
        }
        let mut patterns: Vec<String> = action_counts
            .into_iter()
            .filter(|&(_, count)| count > 1)
            .map(|(pattern, count)| format!("{}:{}", pattern, count))
            .collect();
        patterns.sort_unstable();
        patterns
    }

    /// Averages the step embeddings into a single vector, truncated to
    /// `target_dim` when that is set and smaller than the source dimension.
    /// Steps without embeddings are skipped; returns an empty vector when no
    /// step carries one.
    ///
    /// The accumulation is bounded by the first embedding's length, so a
    /// longer-than-expected embedding cannot index out of bounds (the
    /// previous implementation indexed `avg[i]` and would panic on mismatch);
    /// extra trailing components are ignored.
    fn compress_embedding(&self, steps: &[&TrajectoryStep]) -> Vec<f32> {
        let embeddings: Vec<&Vec<f32>> =
            steps.iter().filter_map(|s| s.embedding.as_ref()).collect();
        if embeddings.is_empty() {
            return Vec::new();
        }
        let dim = embeddings[0].len();
        let mut avg = vec![0.0f32; dim];
        for emb in &embeddings {
            // `zip` stops at the shorter of the two, guarding against
            // heterogeneous embedding lengths.
            for (slot, v) in avg.iter_mut().zip(emb.iter()) {
                *slot += v;
            }
        }
        let n = embeddings.len() as f32;
        for v in &mut avg {
            *v /= n;
        }
        if let Some(target_dim) = self.target_dim {
            if target_dim < dim {
                avg.truncate(target_dim);
            }
        }
        avg
    }
}
/// Point-in-time snapshot of memory usage and search statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodicMemoryStats {
    /// Number of episodes currently stored.
    pub total_episodes: u64,
    /// How many of those have been compressed.
    pub compressed_episodes: u64,
    /// How many still hold their full trajectory.
    pub uncompressed_episodes: u64,
    /// Total `search_similar` calls since creation.
    pub total_searches: u64,
    /// Mean search latency in microseconds (0 when no searches yet).
    pub avg_search_latency_us: u64,
    /// Searches that returned at least one episode.
    pub successful_retrievals: u64,
}
/// Episodic memory store: an HNSW vector index over episode summary
/// embeddings plus a map from episode id to the episode payload.
pub struct EpisodicMemory {
    /// Sizing, tuning and compression policy.
    config: EpisodicMemoryConfig,
    /// Cosine-distance HNSW index keyed by episode id.
    index: Arc<RwLock<HnswIndex>>,
    /// Episode payloads keyed by episode id (source of truth).
    episodes: Arc<RwLock<HashMap<String, Episode>>>,
    /// Compressor applied to episodes older than the configured age.
    compressor: MemoryCompressor,
    /// Lock-free search counters.
    stats: EpisodicMemoryStatsInternal,
}
/// Lock-free counters backing [`EpisodicMemory::stats`].
#[derive(Debug, Default)]
struct EpisodicMemoryStatsInternal {
    /// Number of `search_similar` calls.
    total_searches: AtomicU64,
    /// Searches that returned at least one episode.
    successful_retrievals: AtomicU64,
    /// Sum of search latencies in microseconds (divided by searches for the mean).
    total_search_latency_us: AtomicU64,
}
impl EpisodicMemory {
pub fn new(config: EpisodicMemoryConfig) -> Result<Self> {
let hnsw_config = HnswConfig {
m: config.hnsw_m,
ef_construction: config.hnsw_ef_construction,
ef_search: config.hnsw_ef_search,
max_elements: config.max_episodes,
};
let index = HnswIndex::new(config.embedding_dim, DistanceMetric::Cosine, hnsw_config)
.map_err(|e| RuvLLMError::Ruvector(e.to_string()))?;
let compressor = MemoryCompressor::new(config.compression_ratio, None);
Ok(Self {
config,
index: Arc::new(RwLock::new(index)),
episodes: Arc::new(RwLock::new(HashMap::new())),
compressor,
stats: EpisodicMemoryStatsInternal::default(),
})
}
pub fn store_episode(
&self,
trajectory: Trajectory,
summary_embedding: Vec<f32>,
tags: Vec<String>,
) -> Result<String> {
let episode_id = trajectory.id.clone();
let metadata = EpisodeMetadata {
episode_id: episode_id.clone(),
task_description: trajectory.task_type.clone(),
task_type: trajectory.task_type.clone(),
outcome: trajectory.outcome,
quality_score: trajectory.quality_score,
agent_type: trajectory.agent_type.clone(),
step_count: trajectory.steps.len(),
duration_ms: trajectory.duration_ms,
is_compressed: false,
tags,
created_at: trajectory.created_at,
};
let episode = Episode {
id: episode_id.clone(),
embedding: summary_embedding.clone(),
metadata,
trajectory: Some(trajectory),
compressed: None,
};
{
let mut index = self.index.write();
index.add(episode_id.clone(), summary_embedding)?;
}
{
let mut episodes = self.episodes.write();
episodes.insert(episode_id.clone(), episode);
}
if self.config.auto_compress {
self.compress_old_episodes()?;
}
self.enforce_limit()?;
Ok(episode_id)
}
pub fn search_similar(&self, query_embedding: &[f32], k: usize) -> Result<Vec<Episode>> {
let start = std::time::Instant::now();
let results = {
let index = self.index.read();
index.search(query_embedding, k)?
};
let episodes = self.episodes.read();
let found: Vec<Episode> = results
.into_iter()
.filter_map(|r| episodes.get(&r.id).cloned())
.collect();
let latency = start.elapsed().as_micros() as u64;
self.stats.total_searches.fetch_add(1, Ordering::SeqCst);
self.stats
.total_search_latency_us
.fetch_add(latency, Ordering::SeqCst);
if !found.is_empty() {
self.stats
.successful_retrievals
.fetch_add(1, Ordering::SeqCst);
}
Ok(found)
}
pub fn search_with_filter<F>(
&self,
query_embedding: &[f32],
k: usize,
filter: F,
) -> Result<Vec<Episode>>
where
F: Fn(&EpisodeMetadata) -> bool,
{
let search_k = k * 3;
let results = self.search_similar(query_embedding, search_k)?;
let filtered: Vec<Episode> = results
.into_iter()
.filter(|e| filter(&e.metadata))
.take(k)
.collect();
Ok(filtered)
}
pub fn search_by_task_type(
&self,
query_embedding: &[f32],
task_type: &str,
k: usize,
) -> Result<Vec<Episode>> {
self.search_with_filter(query_embedding, k, |meta| meta.task_type == task_type)
}
pub fn search_successful(
&self,
query_embedding: &[f32],
min_quality: f32,
k: usize,
) -> Result<Vec<Episode>> {
self.search_with_filter(query_embedding, k, |meta| {
meta.outcome > 0.5 && meta.quality_score >= min_quality
})
}
pub fn compress_old_episodes(&self) -> Result<usize> {
let threshold = Utc::now() - Duration::days(self.config.compression_age_days);
let mut compressed_count = 0;
let episodes_to_compress: Vec<String> = {
let episodes = self.episodes.read();
episodes
.iter()
.filter(|(_, e)| {
e.metadata.created_at < threshold
&& !e.metadata.is_compressed
&& e.trajectory.is_some()
})
.map(|(id, _)| id.clone())
.collect()
};
for id in episodes_to_compress {
if let Some(episode) = self.episodes.write().get_mut(&id) {
if let Some(trajectory) = episode.trajectory.take() {
let compressed = self.compressor.compress(&trajectory);
episode.compressed = Some(compressed);
episode.metadata.is_compressed = true;
compressed_count += 1;
}
}
}
Ok(compressed_count)
}
pub fn get(&self, id: &str) -> Option<Episode> {
self.episodes.read().get(id).cloned()
}
pub fn delete(&self, id: &str) -> Result<bool> {
let removed = {
let mut episodes = self.episodes.write();
episodes.remove(id).is_some()
};
if removed {
let mut index = self.index.write();
index.remove(&id.to_string())?;
}
Ok(removed)
}
fn enforce_limit(&self) -> Result<()> {
let mut episodes = self.episodes.write();
while episodes.len() > self.config.max_episodes {
if let Some(oldest) = episodes
.iter()
.filter(|(_, e)| e.metadata.is_compressed)
.min_by_key(|(_, e)| e.metadata.created_at)
.map(|(id, _)| id.clone())
{
episodes.remove(&oldest);
let mut index = self.index.write();
let _ = index.remove(&oldest);
} else if let Some(oldest) = episodes
.iter()
.min_by_key(|(_, e)| e.metadata.created_at)
.map(|(id, _)| id.clone())
{
episodes.remove(&oldest);
let mut index = self.index.write();
let _ = index.remove(&oldest);
} else {
break;
}
}
Ok(())
}
pub fn stats(&self) -> EpisodicMemoryStats {
let episodes = self.episodes.read();
let compressed = episodes
.iter()
.filter(|(_, e)| e.metadata.is_compressed)
.count() as u64;
let total = episodes.len() as u64;
let searches = self.stats.total_searches.load(Ordering::SeqCst);
let total_latency = self.stats.total_search_latency_us.load(Ordering::SeqCst);
let avg_latency = total_latency.checked_div(searches).unwrap_or(0);
EpisodicMemoryStats {
total_episodes: total,
compressed_episodes: compressed,
uncompressed_episodes: total - compressed,
total_searches: searches,
avg_search_latency_us: avg_latency,
successful_retrievals: self.stats.successful_retrievals.load(Ordering::SeqCst),
}
}
pub fn clear(&self) -> Result<()> {
self.episodes.write().clear();
let hnsw_config = HnswConfig {
m: self.config.hnsw_m,
ef_construction: self.config.hnsw_ef_construction,
ef_search: self.config.hnsw_ef_search,
max_elements: self.config.max_episodes,
};
let new_index = HnswIndex::new(
self.config.embedding_dim,
DistanceMetric::Cosine,
hnsw_config,
)
.map_err(|e| RuvLLMError::Ruvector(e.to_string()))?;
*self.index.write() = new_index;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding dimension shared by all test fixtures.
    const DIM: usize = 128;

    /// Builds a memory instance sized for the test embedding dimension.
    fn make_memory() -> EpisodicMemory {
        let config = EpisodicMemoryConfig {
            embedding_dim: DIM,
            ..Default::default()
        };
        EpisodicMemory::new(config).unwrap()
    }

    /// Constant query/summary embedding.
    fn test_embedding(dim: usize) -> Vec<f32> {
        vec![0.1; dim]
    }

    /// Two-step "coding" trajectory with a successful outcome.
    fn test_trajectory() -> Trajectory {
        let step = |state: &str, action: &str, result: &str, fill: f32, reward: f32| {
            TrajectoryStep {
                state: state.to_string(),
                action: action.to_string(),
                result: Some(result.to_string()),
                embedding: Some(vec![fill; DIM]),
                reward,
                timestamp: Utc::now(),
            }
        };
        Trajectory {
            id: "traj-1".to_string(),
            steps: vec![
                step("Initial state", "read_file /src/main.rs", "file contents", 0.1, 0.5),
                step("After reading", "edit_file /src/main.rs", "edited", 0.2, 0.8),
            ],
            outcome: 1.0,
            quality_score: 0.9,
            task_type: "coding".to_string(),
            agent_type: Some("coder".to_string()),
            duration_ms: 5000,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn test_episodic_memory_creation() {
        let memory = make_memory();
        assert_eq!(memory.stats().total_episodes, 0);
    }

    #[test]
    fn test_store_and_search() {
        let memory = make_memory();
        let embedding = test_embedding(DIM);
        let id = memory
            .store_episode(test_trajectory(), embedding.clone(), vec!["test".to_string()])
            .unwrap();
        assert_eq!(id, "traj-1");
        assert_eq!(memory.stats().total_episodes, 1);

        let hits = memory.search_similar(&embedding, 5).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "traj-1");
    }

    #[test]
    fn test_search_with_filter() {
        let memory = make_memory();
        let embedding = test_embedding(DIM);
        memory
            .store_episode(test_trajectory(), embedding.clone(), vec!["test".to_string()])
            .unwrap();

        let matching = memory.search_by_task_type(&embedding, "coding", 5).unwrap();
        assert_eq!(matching.len(), 1);

        let non_matching = memory
            .search_by_task_type(&embedding, "research", 5)
            .unwrap();
        assert_eq!(non_matching.len(), 0);
    }

    #[test]
    fn test_compression() {
        let compressed = MemoryCompressor::new(0.5, None).compress(&test_trajectory());
        assert!(!compressed.summary.is_empty());
        assert!(!compressed.key_actions.is_empty());
        assert_eq!(compressed.original_step_count, 2);
    }

    #[test]
    fn test_delete() {
        let memory = make_memory();
        memory
            .store_episode(test_trajectory(), test_embedding(DIM), vec![])
            .unwrap();
        assert!(memory.get("traj-1").is_some());
        assert!(memory.delete("traj-1").unwrap());
        assert!(memory.get("traj-1").is_none());
    }

    #[test]
    fn test_clear() {
        let memory = make_memory();
        memory
            .store_episode(test_trajectory(), test_embedding(DIM), vec![])
            .unwrap();
        assert_eq!(memory.stats().total_episodes, 1);
        memory.clear().unwrap();
        assert_eq!(memory.stats().total_episodes, 0);
    }
}