use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::token_count::estimate_content_tokens;
use crate::vector_store::{EmbeddingBackend, VectorIndex};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum EpisodeType {
#[default]
Conversation,
ToolExecution,
Error,
Success,
CodeChange,
Learning,
Decision,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub enum Importance {
Low = 1,
#[default]
Normal = 2,
High = 3,
Critical = 4,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
pub id: String,
pub episode_type: EpisodeType,
pub content: String,
pub context: HashMap<String, String>,
pub importance: Importance,
pub timestamp: u64,
pub session_id: String,
pub related_episodes: Vec<String>,
pub tags: Vec<String>,
pub outcome: Option<EpisodeOutcome>,
pub access_count: u32,
pub last_accessed: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodeOutcome {
pub success: bool,
pub description: String,
pub lessons: Vec<String>,
}
impl Episode {
pub fn new(
episode_type: EpisodeType,
content: impl Into<String>,
session_id: impl Into<String>,
) -> Self {
let content = content.into();
let session_id = session_id.into();
let mut hasher = Sha256::new();
hasher.update(content.as_bytes());
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
hasher.update(now.to_le_bytes());
let id = hex::encode(&hasher.finalize()[..8]);
Self {
id,
episode_type,
content,
context: HashMap::new(),
importance: Importance::Normal,
timestamp: now,
session_id,
related_episodes: Vec::new(),
tags: Vec::new(),
outcome: None,
access_count: 0,
last_accessed: now,
}
}
pub fn with_context(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.context.insert(key.into(), value.into());
self
}
pub fn with_importance(mut self, importance: Importance) -> Self {
self.importance = importance;
self
}
pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
self.tags.push(tag.into());
self
}
pub fn with_related(mut self, episode_id: impl Into<String>) -> Self {
self.related_episodes.push(episode_id.into());
self
}
pub fn with_outcome(mut self, outcome: EpisodeOutcome) -> Self {
self.outcome = Some(outcome);
self
}
pub fn record_access(&mut self) {
self.access_count += 1;
self.last_accessed = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
}
pub fn recency_score(&self) -> f32 {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let age_hours = (now - self.timestamp) as f32 / 3600.0;
(-age_hours / 24.0).exp()
}
pub fn relevance_score(&self, base_similarity: f32) -> f32 {
let importance_weight = match self.importance {
Importance::Low => 0.5,
Importance::Normal => 1.0,
Importance::High => 1.5,
Importance::Critical => 2.0,
};
let recency = self.recency_score();
let access_bonus = (self.access_count as f32 * 0.1).min(0.5);
base_similarity * importance_weight * (0.5 + 0.5 * recency) + access_bonus
}
pub fn searchable_text(&self) -> String {
let mut parts = vec![self.content.clone()];
for value in self.context.values() {
parts.push(value.clone());
}
parts.extend(self.tags.clone());
if let Some(ref outcome) = self.outcome {
parts.push(outcome.description.clone());
parts.extend(outcome.lessons.clone());
}
parts.join(" ")
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
pub id: String,
pub started_at: u64,
pub ended_at: Option<u64>,
pub working_dir: PathBuf,
pub summary: Option<String>,
pub task: Option<String>,
pub episode_count: usize,
pub success_rate: f32,
}
impl Session {
pub fn new(working_dir: impl Into<PathBuf>) -> Self {
let id = uuid::Uuid::new_v4().to_string();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
Self {
id,
started_at: now,
ended_at: None,
working_dir: working_dir.into(),
summary: None,
task: None,
episode_count: 0,
success_rate: 0.0,
}
}
pub fn with_task(mut self, task: impl Into<String>) -> Self {
self.task = Some(task.into());
self
}
pub fn end(&mut self, summary: impl Into<String>) {
self.ended_at = Some(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
);
self.summary = Some(summary.into());
}
pub fn duration_secs(&self) -> u64 {
let end = self.ended_at.unwrap_or_else(|| {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
});
end.saturating_sub(self.started_at)
}
pub fn is_active(&self) -> bool {
self.ended_at.is_none()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pattern {
pub id: String,
pub description: String,
pub episode_ids: Vec<String>,
pub frequency: usize,
pub pattern_type: PatternType,
pub confidence: f32,
pub suggestion: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PatternType {
RecurringError,
SuccessfulApproach,
#[default]
Workflow,
Preference,
AntiPattern,
}
impl Pattern {
pub fn new(description: impl Into<String>, pattern_type: PatternType) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
description: description.into(),
episode_ids: Vec::new(),
frequency: 0,
pattern_type,
confidence: 0.5,
suggestion: None,
}
}
pub fn add_episode(&mut self, episode_id: impl Into<String>) {
self.episode_ids.push(episode_id.into());
self.frequency = self.episode_ids.len();
self.confidence = (self.confidence + 0.1).min(1.0);
}
pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
self.suggestion = Some(suggestion.into());
self
}
}
#[derive(Debug, Clone)]
pub struct MemoryResult {
pub episode: Episode,
pub similarity: f32,
pub relevance: f32,
}
pub struct EpisodicMemory {
episodes: HashMap<String, Episode>,
sessions: HashMap<String, Session>,
current_session: Option<String>,
index: VectorIndex,
provider: Arc<EmbeddingBackend>,
patterns: Vec<Pattern>,
storage_path: Option<PathBuf>,
config: EpisodicMemoryConfig,
recent: VecDeque<String>,
}
#[derive(Debug, Clone)]
pub struct EpisodicMemoryConfig {
pub max_episodes: usize,
pub max_recent: usize,
pub min_importance_to_keep: Importance,
pub age_threshold_secs: u64,
pub pattern_threshold: usize,
}
impl Default for EpisodicMemoryConfig {
fn default() -> Self {
Self {
max_episodes: 10000,
max_recent: 50,
min_importance_to_keep: Importance::Normal,
age_threshold_secs: 7 * 24 * 3600, pattern_threshold: 3,
}
}
}
impl EpisodicMemory {
pub fn new(provider: Arc<EmbeddingBackend>) -> Self {
Self {
episodes: HashMap::new(),
sessions: HashMap::new(),
current_session: None,
index: VectorIndex::new(provider.dimension()),
provider,
patterns: Vec::new(),
storage_path: None,
config: EpisodicMemoryConfig::default(),
recent: VecDeque::new(),
}
}
pub fn with_storage(mut self, path: impl Into<PathBuf>) -> Self {
self.storage_path = Some(path.into());
self
}
pub fn with_config(mut self, config: EpisodicMemoryConfig) -> Self {
self.config = config;
self
}
pub fn start_session(&mut self, working_dir: impl Into<PathBuf>) -> String {
let session = Session::new(working_dir);
let id = session.id.clone();
self.sessions.insert(id.clone(), session);
self.current_session = Some(id.clone());
id
}
pub fn start_session_with_task(
&mut self,
working_dir: impl Into<PathBuf>,
task: impl Into<String>,
) -> String {
let session = Session::new(working_dir).with_task(task);
let id = session.id.clone();
self.sessions.insert(id.clone(), session);
self.current_session = Some(id.clone());
id
}
pub fn end_session(&mut self, summary: impl Into<String>) {
if let Some(ref session_id) = self.current_session {
if let Some(session) = self.sessions.get_mut(session_id) {
let session_episodes: Vec<&Episode> = self
.episodes
.values()
.filter(|e| &e.session_id == session_id)
.collect();
let successes = session_episodes
.iter()
.filter(|e| e.outcome.as_ref().map(|o| o.success).unwrap_or(false))
.count();
session.episode_count = session_episodes.len();
if !session_episodes.is_empty() {
session.success_rate = successes as f32 / session_episodes.len() as f32;
}
session.end(summary);
}
}
self.current_session = None;
}
pub fn current_session(&self) -> Option<&Session> {
self.current_session
.as_ref()
.and_then(|id| self.sessions.get(id))
}
pub async fn record(&mut self, mut episode: Episode) -> Result<String> {
if episode.session_id.is_empty() {
if let Some(ref session_id) = self.current_session {
episode.session_id = session_id.clone();
}
}
let text = episode.searchable_text();
let embedding = self.provider.embed(&text).await?;
let id = episode.id.clone();
self.index.add(id.clone(), embedding)?;
self.recent.push_back(id.clone());
while self.recent.len() > self.config.max_recent {
self.recent.pop_front();
}
self.episodes.insert(id.clone(), episode);
if self.episodes.len() > self.config.max_episodes {
self.cleanup();
}
Ok(id)
}
pub async fn record_conversation(&mut self, content: impl Into<String>) -> Result<String> {
let session_id = self
.current_session
.clone()
.unwrap_or_else(|| "default".to_string());
let episode = Episode::new(EpisodeType::Conversation, content, session_id);
self.record(episode).await
}
pub async fn record_tool_execution(
&mut self,
tool: impl Into<String>,
args: impl Into<String>,
result: impl Into<String>,
success: bool,
) -> Result<String> {
let session_id = self
.current_session
.clone()
.unwrap_or_else(|| "default".to_string());
let tool = tool.into();
let content = format!("Tool: {}", tool);
let episode = Episode::new(EpisodeType::ToolExecution, content, session_id)
.with_context("tool", &tool)
.with_context("args", args)
.with_context("result", result)
.with_outcome(EpisodeOutcome {
success,
description: if success {
"Tool executed successfully".to_string()
} else {
"Tool execution failed".to_string()
},
lessons: Vec::new(),
});
self.record(episode).await
}
pub async fn record_error(
&mut self,
error: impl Into<String>,
context: impl Into<String>,
) -> Result<String> {
let session_id = self
.current_session
.clone()
.unwrap_or_else(|| "default".to_string());
let episode = Episode::new(EpisodeType::Error, error, session_id)
.with_context("error_context", context)
.with_importance(Importance::High)
.with_tag("error");
self.record(episode).await
}
pub async fn record_success(
&mut self,
description: impl Into<String>,
lessons: Vec<String>,
) -> Result<String> {
let session_id = self
.current_session
.clone()
.unwrap_or_else(|| "default".to_string());
let episode = Episode::new(EpisodeType::Success, description, session_id)
.with_importance(Importance::High)
.with_outcome(EpisodeOutcome {
success: true,
description: "Task completed successfully".to_string(),
lessons,
});
self.record(episode).await
}
pub async fn record_learning(&mut self, insight: impl Into<String>) -> Result<String> {
let session_id = self
.current_session
.clone()
.unwrap_or_else(|| "default".to_string());
let episode = Episode::new(EpisodeType::Learning, insight, session_id)
.with_importance(Importance::High)
.with_tag("learning");
self.record(episode).await
}
pub async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<MemoryResult>> {
let embedding = self.provider.embed(query).await?;
let search_results = self.index.search(&embedding, limit * 2);
let mut results: Vec<MemoryResult> = search_results
.into_iter()
.filter_map(|(id, similarity)| {
self.episodes.get(&id).map(|episode| {
let relevance = episode.relevance_score(similarity);
MemoryResult {
episode: episode.clone(),
similarity,
relevance,
}
})
})
.collect();
results.sort_by(|a, b| {
b.relevance
.partial_cmp(&a.relevance)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(limit);
Ok(results)
}
pub fn retrieve_recent(&self, limit: usize) -> Vec<&Episode> {
self.recent
.iter()
.rev()
.take(limit)
.filter_map(|id| self.episodes.get(id))
.collect()
}
pub fn retrieve_by_type(&self, episode_type: EpisodeType, limit: usize) -> Vec<&Episode> {
let mut episodes: Vec<&Episode> = self
.episodes
.values()
.filter(|e| e.episode_type == episode_type)
.collect();
episodes.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
episodes.truncate(limit);
episodes
}
pub fn session_errors(&self) -> Vec<&Episode> {
if let Some(ref session_id) = self.current_session {
self.episodes
.values()
.filter(|e| &e.session_id == session_id && e.episode_type == EpisodeType::Error)
.collect()
} else {
Vec::new()
}
}
pub fn get(&self, id: &str) -> Option<&Episode> {
self.episodes.get(id)
}
pub fn get_mut(&mut self, id: &str) -> Option<&mut Episode> {
self.episodes.get_mut(id)
}
pub fn access(&mut self, id: &str) {
if let Some(episode) = self.episodes.get_mut(id) {
episode.record_access();
}
}
pub fn detect_patterns(&mut self) {
let error_episodes: Vec<&Episode> = self
.episodes
.values()
.filter(|e| e.episode_type == EpisodeType::Error)
.collect();
let mut error_patterns: HashMap<String, Vec<String>> = HashMap::new();
for episode in error_episodes {
let key = episode
.content
.split_whitespace()
.take(5)
.collect::<Vec<_>>()
.join(" ");
error_patterns
.entry(key)
.or_default()
.push(episode.id.clone());
}
for (description, episode_ids) in error_patterns {
if episode_ids.len() >= self.config.pattern_threshold {
let mut pattern = Pattern::new(
format!("Recurring error: {}", description),
PatternType::RecurringError,
);
for id in episode_ids {
pattern.add_episode(id);
}
pattern.suggestion =
Some("Consider addressing the root cause of this error".to_string());
self.patterns.push(pattern);
}
}
let success_episodes: Vec<&Episode> = self
.episodes
.values()
.filter(|e| e.outcome.as_ref().map(|o| o.success).unwrap_or(false))
.collect();
if success_episodes.len() >= self.config.pattern_threshold {
let mut tag_counts: HashMap<String, Vec<String>> = HashMap::new();
for episode in success_episodes {
for tag in &episode.tags {
tag_counts
.entry(tag.clone())
.or_default()
.push(episode.id.clone());
}
}
for (tag, episode_ids) in tag_counts {
if episode_ids.len() >= self.config.pattern_threshold {
let mut pattern = Pattern::new(
format!("Successful approach with '{}'", tag),
PatternType::SuccessfulApproach,
);
for id in episode_ids {
pattern.add_episode(id);
}
self.patterns.push(pattern);
}
}
}
}
pub fn patterns(&self) -> &[Pattern] {
&self.patterns
}
pub fn patterns_by_type(&self, pattern_type: PatternType) -> Vec<&Pattern> {
self.patterns
.iter()
.filter(|p| p.pattern_type == pattern_type)
.collect()
}
pub async fn reconstruct_context(&self, query: &str, max_tokens: usize) -> Result<String> {
let memories = self.retrieve(query, 10).await?;
let mut context_parts = Vec::new();
let mut token_count = 0;
for memory in memories {
let episode = &memory.episode;
let text = format!(
"[{:?} - {}]\n{}\n",
episode.episode_type,
chrono::DateTime::from_timestamp(episode.timestamp as i64, 0)
.map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
.unwrap_or_else(|| "Unknown".to_string()),
episode.content
);
let estimated_tokens = estimate_content_tokens(&text);
if token_count + estimated_tokens > max_tokens {
break;
}
context_parts.push(text);
token_count += estimated_tokens;
}
Ok(context_parts.join("\n---\n"))
}
fn cleanup(&mut self) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let threshold = now - self.config.age_threshold_secs;
let to_remove: Vec<String> = self
.episodes
.iter()
.filter(|(_, e)| {
e.timestamp < threshold && e.importance < self.config.min_importance_to_keep
})
.map(|(id, _)| id.clone())
.collect();
for id in to_remove {
self.episodes.remove(&id);
self.index.remove(&id);
}
}
pub fn save(&self) -> Result<()> {
let path = self
.storage_path
.as_ref()
.ok_or_else(|| anyhow!("Storage path not set"))?;
std::fs::create_dir_all(path)?;
fn atomic_write(path: &std::path::Path, json: &str) -> Result<()> {
let tmp_path = path.with_extension(format!("tmp.{}", std::process::id()));
std::fs::write(&tmp_path, json)?;
std::fs::rename(&tmp_path, path)?;
Ok(())
}
let episodes_path = path.join("episodes.json");
atomic_write(
&episodes_path,
&serde_json::to_string_pretty(&self.episodes)?,
)?;
let sessions_path = path.join("sessions.json");
atomic_write(
&sessions_path,
&serde_json::to_string_pretty(&self.sessions)?,
)?;
let patterns_path = path.join("patterns.json");
atomic_write(
&patterns_path,
&serde_json::to_string_pretty(&self.patterns)?,
)?;
Ok(())
}
pub fn load(&mut self) -> Result<()> {
let path = self
.storage_path
.as_ref()
.ok_or_else(|| anyhow!("Storage path not set"))?
.clone();
if !path.exists() {
return Ok(());
}
let episodes_path = path.join("episodes.json");
if episodes_path.exists() {
let json = std::fs::read_to_string(&episodes_path)?;
self.episodes = serde_json::from_str(&json)?;
}
let sessions_path = path.join("sessions.json");
if sessions_path.exists() {
let json = std::fs::read_to_string(&sessions_path)?;
self.sessions = serde_json::from_str(&json)?;
}
let patterns_path = path.join("patterns.json");
if patterns_path.exists() {
let json = std::fs::read_to_string(&patterns_path)?;
self.patterns = serde_json::from_str(&json)?;
}
Ok(())
}
pub fn stats(&self) -> MemoryStats {
let by_type: HashMap<EpisodeType, usize> =
self.episodes.values().fold(HashMap::new(), |mut acc, e| {
*acc.entry(e.episode_type).or_insert(0) += 1;
acc
});
let avg_importance = if self.episodes.is_empty() {
0.0
} else {
self.episodes
.values()
.map(|e| e.importance as u8 as f32)
.sum::<f32>()
/ self.episodes.len() as f32
};
MemoryStats {
total_episodes: self.episodes.len(),
total_sessions: self.sessions.len(),
active_session: self.current_session.is_some(),
recent_count: self.recent.len(),
pattern_count: self.patterns.len(),
episodes_by_type: by_type,
average_importance: avg_importance,
}
}
}
#[derive(Debug, Clone)]
pub struct MemoryStats {
pub total_episodes: usize,
pub total_sessions: usize,
pub active_session: bool,
pub recent_count: usize,
pub pattern_count: usize,
pub episodes_by_type: HashMap<EpisodeType, usize>,
pub average_importance: f32,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::vector_store::MockEmbeddingProvider;
use std::sync::Arc;
use tempfile::tempdir;
fn mock_provider() -> Arc<EmbeddingBackend> {
Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()))
}
#[test]
fn test_episode_type_default() {
assert_eq!(EpisodeType::default(), EpisodeType::Conversation);
}
#[test]
fn test_importance_ordering() {
assert!(Importance::Critical > Importance::High);
assert!(Importance::High > Importance::Normal);
assert!(Importance::Normal > Importance::Low);
}
#[test]
fn test_episode_creation() {
let episode = Episode::new(EpisodeType::Conversation, "Test content", "session1")
.with_context("key", "value")
.with_importance(Importance::High)
.with_tag("test");
assert!(!episode.id.is_empty());
assert_eq!(episode.content, "Test content");
assert_eq!(episode.importance, Importance::High);
assert!(episode.tags.contains(&"test".to_string()));
}
#[test]
fn test_episode_outcome() {
let outcome = EpisodeOutcome {
success: true,
description: "Task completed".to_string(),
lessons: vec!["Lesson 1".to_string()],
};
let episode =
Episode::new(EpisodeType::Success, "Success", "session1").with_outcome(outcome);
assert!(episode.outcome.is_some());
assert!(episode.outcome.unwrap().success);
}
#[test]
fn test_episode_recency_score() {
let episode = Episode::new(EpisodeType::Conversation, "Test", "session1");
let score = episode.recency_score();
assert!(score > 0.9);
}
#[test]
fn test_episode_access_tracking() {
let mut episode = Episode::new(EpisodeType::Conversation, "Test", "session1");
assert_eq!(episode.access_count, 0);
episode.record_access();
assert_eq!(episode.access_count, 1);
episode.record_access();
assert_eq!(episode.access_count, 2);
}
#[test]
fn test_episode_searchable_text() {
let episode = Episode::new(EpisodeType::Conversation, "Main content", "session1")
.with_context("ctx", "context value")
.with_tag("tag1");
let text = episode.searchable_text();
assert!(text.contains("Main content"));
assert!(text.contains("context value"));
assert!(text.contains("tag1"));
}
#[test]
fn test_session_creation() {
let session = Session::new("/tmp/project").with_task("Implement feature");
assert!(!session.id.is_empty());
assert!(session.is_active());
assert_eq!(session.task, Some("Implement feature".to_string()));
}
#[test]
fn test_session_end() {
let mut session = Session::new("/tmp/project");
assert!(session.is_active());
session.end("Session completed");
assert!(!session.is_active());
assert!(session.summary.is_some());
}
#[test]
fn test_session_duration() {
let session = Session::new("/tmp/project");
let duration = session.duration_secs();
assert!(duration < 2); }
#[test]
fn test_pattern_creation() {
let mut pattern = Pattern::new("Test pattern", PatternType::RecurringError)
.with_suggestion("Fix the root cause");
assert!(!pattern.id.is_empty());
assert_eq!(pattern.pattern_type, PatternType::RecurringError);
assert!(pattern.suggestion.is_some());
pattern.add_episode("ep1");
pattern.add_episode("ep2");
assert_eq!(pattern.frequency, 2);
assert!(pattern.confidence > 0.5);
}
#[tokio::test]
async fn test_episodic_memory_creation() {
let memory = EpisodicMemory::new(mock_provider());
let stats = memory.stats();
assert_eq!(stats.total_episodes, 0);
assert_eq!(stats.total_sessions, 0);
}
#[tokio::test]
async fn test_start_session() {
let mut memory = EpisodicMemory::new(mock_provider());
let session_id = memory.start_session("/tmp/project");
assert!(!session_id.is_empty());
assert!(memory.current_session().is_some());
}
#[tokio::test]
async fn test_record_episode() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let episode = Episode::new(EpisodeType::Conversation, "Test message", "");
let id = memory.record(episode).await.unwrap();
assert!(memory.get(&id).is_some());
}
#[tokio::test]
async fn test_record_conversation() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory.record_conversation("Hello world").await.unwrap();
let episode = memory.get(&id).unwrap();
assert_eq!(episode.episode_type, EpisodeType::Conversation);
}
#[tokio::test]
async fn test_record_tool_execution() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory
.record_tool_execution("file_read", "/tmp/test.txt", "File contents", true)
.await
.unwrap();
let episode = memory.get(&id).unwrap();
assert_eq!(episode.episode_type, EpisodeType::ToolExecution);
assert!(episode.outcome.as_ref().unwrap().success);
}
#[tokio::test]
async fn test_record_error() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory
.record_error("Something failed", "Error context")
.await
.unwrap();
let episode = memory.get(&id).unwrap();
assert_eq!(episode.episode_type, EpisodeType::Error);
assert_eq!(episode.importance, Importance::High);
}
#[tokio::test]
async fn test_retrieve_similar() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
memory
.record_conversation("Calculate the sum of two numbers")
.await
.unwrap();
memory
.record_conversation("Find the product of values")
.await
.unwrap();
let results = memory.retrieve("sum calculation", 5).await.unwrap();
assert!(results.len() <= 5);
}
#[tokio::test]
async fn test_retrieve_recent() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
memory.record_conversation("First").await.unwrap();
memory.record_conversation("Second").await.unwrap();
memory.record_conversation("Third").await.unwrap();
let recent = memory.retrieve_recent(2);
assert_eq!(recent.len(), 2);
assert_eq!(recent[0].content, "Third"); }
#[tokio::test]
async fn test_retrieve_by_type() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
memory.record_conversation("Chat").await.unwrap();
memory.record_error("Error 1", "ctx").await.unwrap();
memory.record_error("Error 2", "ctx").await.unwrap();
let errors = memory.retrieve_by_type(EpisodeType::Error, 10);
assert_eq!(errors.len(), 2);
}
#[tokio::test]
async fn test_session_errors() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
memory.record_error("Session error", "ctx").await.unwrap();
memory.record_conversation("Normal chat").await.unwrap();
let errors = memory.session_errors();
assert_eq!(errors.len(), 1);
}
#[tokio::test]
async fn test_end_session() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
memory.record_conversation("Test").await.unwrap();
memory.end_session("Session done");
assert!(memory.current_session().is_none());
}
#[tokio::test]
async fn test_pattern_detection() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.config.pattern_threshold = 2;
memory.start_session("/tmp");
memory
.record_error("Connection to server failed due to timeout", "ctx")
.await
.unwrap();
memory
.record_error("Connection to server failed due to DNS", "ctx")
.await
.unwrap();
memory
.record_error("Connection to server failed due to firewall", "ctx")
.await
.unwrap();
memory.detect_patterns();
let error_patterns = memory.patterns_by_type(PatternType::RecurringError);
let _ = error_patterns.len();
}
#[tokio::test]
async fn test_context_reconstruction() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
memory
.record_conversation("Working on feature X")
.await
.unwrap();
memory
.record_conversation("Added new function")
.await
.unwrap();
let context = memory.reconstruct_context("feature", 1000).await.unwrap();
assert!(!context.is_empty());
}
#[tokio::test]
async fn test_memory_persistence() {
let dir = tempdir().unwrap();
let storage_path = dir.path().to_path_buf();
{
let mut memory = EpisodicMemory::new(mock_provider()).with_storage(&storage_path);
memory.start_session("/tmp");
memory
.record_conversation("Persistent message")
.await
.unwrap();
memory.save().unwrap();
}
{
let mut memory = EpisodicMemory::new(mock_provider()).with_storage(&storage_path);
memory.load().unwrap();
assert_eq!(memory.stats().total_episodes, 1);
}
}
#[test]
fn test_memory_stats() {
let memory = EpisodicMemory::new(mock_provider());
let stats = memory.stats();
assert_eq!(stats.total_episodes, 0);
assert!(!stats.active_session);
}
#[test]
fn test_episodic_memory_config_default() {
let config = EpisodicMemoryConfig::default();
assert_eq!(config.max_episodes, 10000);
assert_eq!(config.max_recent, 50);
}
#[test]
fn test_pattern_type_default() {
assert_eq!(PatternType::default(), PatternType::Workflow);
}
#[tokio::test]
async fn test_access_episode() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory.record_conversation("Test").await.unwrap();
memory.access(&id);
let episode = memory.get(&id).unwrap();
assert_eq!(episode.access_count, 1);
}
#[tokio::test]
async fn test_record_success() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory
.record_success("Task completed", vec!["Lesson 1".to_string()])
.await
.unwrap();
let episode = memory.get(&id).unwrap();
assert_eq!(episode.episode_type, EpisodeType::Success);
}
#[tokio::test]
async fn test_record_learning() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory
.record_learning("Learned something new")
.await
.unwrap();
let episode = memory.get(&id).unwrap();
assert_eq!(episode.episode_type, EpisodeType::Learning);
assert!(episode.tags.contains(&"learning".to_string()));
}
#[test]
fn test_episode_related() {
let episode = Episode::new(EpisodeType::Conversation, "Test", "session1")
.with_related("other_episode");
assert!(episode
.related_episodes
.contains(&"other_episode".to_string()));
}
#[test]
fn test_episode_type_all_variants_debug() {
let types = [
EpisodeType::Conversation,
EpisodeType::ToolExecution,
EpisodeType::Error,
EpisodeType::Success,
EpisodeType::CodeChange,
EpisodeType::Learning,
EpisodeType::Decision,
];
for t in types {
let _ = format!("{:?}", t);
}
}
#[test]
fn test_importance_default_value() {
assert_eq!(Importance::default(), Importance::Normal);
}
#[test]
fn test_importance_values() {
assert_eq!(Importance::Low as u8, 1);
assert_eq!(Importance::Normal as u8, 2);
assert_eq!(Importance::High as u8, 3);
assert_eq!(Importance::Critical as u8, 4);
}
#[test]
fn test_episode_with_context() {
let episode = Episode::new(EpisodeType::ToolExecution, "Test", "session")
.with_context("key1", "value1")
.with_context("key2", "value2");
assert_eq!(episode.context.get("key1"), Some(&"value1".to_string()));
assert_eq!(episode.context.get("key2"), Some(&"value2".to_string()));
}
#[test]
fn test_episode_with_importance() {
let episode = Episode::new(EpisodeType::Error, "Critical error", "session")
.with_importance(Importance::Critical);
assert_eq!(episode.importance, Importance::Critical);
}
#[test]
fn test_episode_with_tag() {
let episode = Episode::new(EpisodeType::Learning, "Lesson", "session")
.with_tag("rust")
.with_tag("testing");
assert!(episode.tags.contains(&"rust".to_string()));
assert!(episode.tags.contains(&"testing".to_string()));
}
#[test]
fn test_episode_with_outcome() {
let outcome = EpisodeOutcome {
success: true,
description: "Task completed".to_string(),
lessons: vec!["Use smaller steps".to_string()],
};
let episode = Episode::new(EpisodeType::Success, "Done", "session").with_outcome(outcome);
assert!(episode.outcome.is_some());
assert!(episode.outcome.as_ref().unwrap().success);
}
#[test]
fn test_episode_record_access() {
let mut episode = Episode::new(EpisodeType::Conversation, "Test", "session");
assert_eq!(episode.access_count, 0);
episode.record_access();
assert_eq!(episode.access_count, 1);
episode.record_access();
assert_eq!(episode.access_count, 2);
}
#[test]
fn test_episode_relevance_score() {
let episode = Episode::new(EpisodeType::Conversation, "Test", "session")
.with_importance(Importance::High);
let score = episode.relevance_score(0.8);
assert!(score > 0.8);
}
#[test]
fn test_episode_relevance_score_with_access_bonus() {
let mut episode = Episode::new(EpisodeType::Conversation, "Test", "session");
episode.access_count = 5;
let score = episode.relevance_score(0.8);
assert!(score > 0.8);
}
#[test]
fn test_episode_searchable_text_with_outcome() {
let outcome = EpisodeOutcome {
success: true,
description: "Outcome desc".to_string(),
lessons: vec!["Lesson 1".to_string()],
};
let episode =
Episode::new(EpisodeType::Success, "Content", "session").with_outcome(outcome);
let text = episode.searchable_text();
assert!(text.contains("Outcome desc"));
assert!(text.contains("Lesson 1"));
}
#[test]
fn test_session_new() {
let session = Session::new("/home/user/project");
assert!(!session.id.is_empty());
assert!(session.started_at > 0);
assert!(session.ended_at.is_none());
assert!(session.is_active());
}
#[test]
fn test_session_with_task() {
let session = Session::new("/tmp").with_task("Build feature X");
assert_eq!(session.task, Some("Build feature X".to_string()));
}
#[test]
fn test_pattern_new() {
let pattern = Pattern::new("Recurring timeout errors", PatternType::RecurringError);
assert!(!pattern.id.is_empty());
assert_eq!(pattern.pattern_type, PatternType::RecurringError);
assert_eq!(pattern.frequency, 0);
assert!((pattern.confidence - 0.5).abs() < f32::EPSILON);
}
#[test]
fn test_pattern_add_episode() {
let mut pattern = Pattern::new("Test pattern", PatternType::Workflow);
pattern.add_episode("ep1");
assert_eq!(pattern.frequency, 1);
assert!(pattern.confidence > 0.5);
pattern.add_episode("ep2");
assert_eq!(pattern.frequency, 2);
assert!(pattern.confidence > 0.6);
}
#[test]
fn test_pattern_with_suggestion() {
let pattern = Pattern::new("Error pattern", PatternType::AntiPattern)
.with_suggestion("Avoid using this approach");
assert_eq!(
pattern.suggestion,
Some("Avoid using this approach".to_string())
);
}
#[test]
fn test_pattern_type_variants() {
let types = [
PatternType::RecurringError,
PatternType::SuccessfulApproach,
PatternType::Workflow,
PatternType::Preference,
PatternType::AntiPattern,
];
for t in types {
let _ = format!("{:?}", t);
}
}
#[test]
fn test_memory_result_debug() {
let episode = Episode::new(EpisodeType::Conversation, "Test", "session");
let result = MemoryResult {
episode,
similarity: 0.9,
relevance: 0.85,
};
let debug_str = format!("{:?}", result);
assert!(debug_str.contains("MemoryResult"));
}
#[test]
fn test_episodic_memory_config_clone() {
let config = EpisodicMemoryConfig::default();
let cloned = config.clone();
assert_eq!(config.max_episodes, cloned.max_episodes);
assert_eq!(config.max_recent, cloned.max_recent);
}
#[test]
fn test_episode_outcome_clone() {
let outcome = EpisodeOutcome {
success: true,
description: "Done".to_string(),
lessons: vec!["L1".to_string()],
};
let cloned = outcome.clone();
assert_eq!(outcome.success, cloned.success);
assert_eq!(outcome.description, cloned.description);
}
#[test]
fn test_episode_clone() {
let episode = Episode::new(EpisodeType::Decision, "Decision made", "session");
let cloned = episode.clone();
assert_eq!(episode.id, cloned.id);
assert_eq!(episode.content, cloned.content);
}
#[test]
fn test_session_clone() {
let session = Session::new("/tmp");
let cloned = session.clone();
assert_eq!(session.id, cloned.id);
assert_eq!(session.working_dir, cloned.working_dir);
}
#[test]
fn test_pattern_clone() {
let pattern = Pattern::new("Test", PatternType::Workflow);
let cloned = pattern.clone();
assert_eq!(pattern.id, cloned.id);
assert_eq!(pattern.description, cloned.description);
}
#[tokio::test]
async fn test_start_session_with_task() {
let mut memory = EpisodicMemory::new(mock_provider());
let session_id = memory.start_session_with_task("/tmp", "Build new feature");
let session = memory.current_session().unwrap();
assert_eq!(session.id, session_id);
assert_eq!(session.task, Some("Build new feature".to_string()));
}
#[tokio::test]
async fn test_record_tool_execution_success() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory
.record_tool_execution(
"file_read",
"{\"path\": \"test.txt\"}",
"file contents",
true,
)
.await
.unwrap();
let episode = memory.get(&id).unwrap();
assert_eq!(episode.episode_type, EpisodeType::ToolExecution);
assert!(episode.outcome.as_ref().unwrap().success);
}
#[tokio::test]
async fn test_record_tool_execution_failure() {
let mut memory = EpisodicMemory::new(mock_provider());
memory.start_session("/tmp");
let id = memory
.record_tool_execution("file_read", "{}", "File not found", false)
.await
.unwrap();
let episode = memory.get(&id).unwrap();
assert!(!episode.outcome.as_ref().unwrap().success);
}
#[test]
fn test_episodic_memory_with_config() {
let config = EpisodicMemoryConfig {
max_episodes: 500,
max_recent: 20,
min_importance_to_keep: Importance::High,
age_threshold_secs: 3600,
pattern_threshold: 5,
};
let memory = EpisodicMemory::new(mock_provider()).with_config(config);
assert_eq!(memory.config.max_episodes, 500);
assert_eq!(memory.config.max_recent, 20);
}
#[test]
fn test_episodic_memory_with_storage() {
let memory = EpisodicMemory::new(mock_provider()).with_storage("/tmp/memory");
assert!(memory.storage_path.is_some());
}
#[test]
fn test_episode_type_eq() {
assert_eq!(EpisodeType::Conversation, EpisodeType::Conversation);
assert_ne!(EpisodeType::Error, EpisodeType::Success);
}
#[test]
fn test_episode_type_hash() {
use std::collections::HashSet;
let mut set = HashSet::new();
set.insert(EpisodeType::Conversation);
set.insert(EpisodeType::Error);
assert_eq!(set.len(), 2);
}
}