// lean_ctx/core/episodic_memory.rs

//! Episodic Memory — persistent cross-session experiences with outcomes.
//!
//! Automatically records what the agent did in each session, with what result.
//! Enables learning from past experiences: "What happened last time I refactored auth?"

6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9
10use crate::core::memory_policy::EpisodicPolicy;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct EpisodicStore {
14    pub project_hash: String,
15    pub episodes: Vec<Episode>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Episode {
20    pub id: String,
21    pub session_id: String,
22    pub timestamp: DateTime<Utc>,
23    pub task_description: String,
24    pub actions: Vec<Action>,
25    pub outcome: Outcome,
26    pub affected_files: Vec<String>,
27    pub summary: String,
28    pub duration_secs: u64,
29    pub tokens_used: u64,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Action {
34    pub tool: String,
35    pub description: String,
36    pub timestamp: DateTime<Utc>,
37    pub duration_ms: u64,
38    pub success: bool,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
42pub enum Outcome {
43    Success { tests_passed: bool },
44    Failure { error: String },
45    Partial { details: String },
46    Unknown,
47}
48
49impl Outcome {
50    pub fn label(&self) -> &'static str {
51        match self {
52            Outcome::Success { .. } => "success",
53            Outcome::Failure { .. } => "failure",
54            Outcome::Partial { .. } => "partial",
55            Outcome::Unknown => "unknown",
56        }
57    }
58}
59
60impl EpisodicStore {
61    pub fn new(project_hash: &str) -> Self {
62        Self {
63            project_hash: project_hash.to_string(),
64            episodes: Vec::new(),
65        }
66    }
67
68    pub fn record_episode(&mut self, mut episode: Episode, policy: &EpisodicPolicy) {
69        episode.actions.truncate(policy.max_actions_per_episode);
70
71        if episode.summary.is_empty() {
72            episode.summary = auto_summarize(&episode, policy.summary_max_chars);
73        }
74
75        self.episodes.push(episode);
76
77        if self.episodes.len() > policy.max_episodes {
78            self.episodes
79                .drain(0..self.episodes.len() - policy.max_episodes);
80        }
81    }
82
83    pub fn search(&self, query: &str) -> Vec<&Episode> {
84        let q = query.to_lowercase();
85        let terms: Vec<&str> = q.split_whitespace().collect();
86
87        let mut scored: Vec<(&Episode, f32)> = self
88            .episodes
89            .iter()
90            .filter_map(|ep| {
91                let searchable = format!(
92                    "{} {} {}",
93                    ep.task_description.to_lowercase(),
94                    ep.summary.to_lowercase(),
95                    ep.affected_files.join(" ").to_lowercase()
96                );
97                let hits = terms.iter().filter(|t| searchable.contains(**t)).count();
98                if hits > 0 {
99                    let relevance = hits as f32 / terms.len() as f32;
100                    Some((ep, relevance))
101                } else {
102                    None
103                }
104            })
105            .collect();
106
107        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
108        scored.into_iter().map(|(ep, _)| ep).collect()
109    }
110
111    pub fn recent(&self, n: usize) -> Vec<&Episode> {
112        self.episodes.iter().rev().take(n).collect()
113    }
114
115    pub fn by_outcome(&self, outcome_label: &str) -> Vec<&Episode> {
116        self.episodes
117            .iter()
118            .filter(|ep| ep.outcome.label() == outcome_label)
119            .collect()
120    }
121
122    pub fn by_file(&self, file_path: &str) -> Vec<&Episode> {
123        self.episodes
124            .iter()
125            .filter(|ep| ep.affected_files.iter().any(|f| f.contains(file_path)))
126            .collect()
127    }
128
129    pub fn stats(&self) -> EpisodicStats {
130        let total = self.episodes.len();
131        let successes = self
132            .episodes
133            .iter()
134            .filter(|ep| matches!(ep.outcome, Outcome::Success { .. }))
135            .count();
136        let failures = self
137            .episodes
138            .iter()
139            .filter(|ep| matches!(ep.outcome, Outcome::Failure { .. }))
140            .count();
141        let total_tokens: u64 = self.episodes.iter().map(|ep| ep.tokens_used).sum();
142
143        EpisodicStats {
144            total_episodes: total,
145            successes,
146            failures,
147            success_rate: if total > 0 {
148                successes as f32 / total as f32
149            } else {
150                0.0
151            },
152            total_tokens,
153        }
154    }
155
156    fn store_path(project_hash: &str) -> Option<PathBuf> {
157        let dir = crate::core::data_dir::lean_ctx_data_dir()
158            .ok()?
159            .join("memory")
160            .join("episodes");
161        Some(dir.join(format!("{project_hash}.json")))
162    }
163
164    pub fn load(project_hash: &str) -> Option<Self> {
165        let path = Self::store_path(project_hash)?;
166        let data = std::fs::read_to_string(path).ok()?;
167        serde_json::from_str(&data).ok()
168    }
169
170    pub fn load_or_create(project_hash: &str) -> Self {
171        Self::load(project_hash).unwrap_or_else(|| Self::new(project_hash))
172    }
173
174    pub fn save(&self) -> Result<(), String> {
175        let path = Self::store_path(&self.project_hash)
176            .ok_or_else(|| "Cannot determine data directory".to_string())?;
177        if let Some(dir) = path.parent() {
178            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
179        }
180        let json = serde_json::to_string_pretty(self).map_err(|e| format!("{e}"))?;
181        std::fs::write(path, json).map_err(|e| format!("{e}"))
182    }
183}
184
/// Aggregate counters over every episode in an `EpisodicStore`.
///
/// Plain data, so it is `Copy`, comparable, and defaults to all zeros
/// (which matches what an empty store reports).
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct EpisodicStats {
    /// Number of stored episodes.
    pub total_episodes: usize,
    /// Episodes whose outcome is `Outcome::Success`.
    pub successes: usize,
    /// Episodes whose outcome is `Outcome::Failure`.
    pub failures: usize,
    /// `successes / total_episodes`, or `0.0` for an empty store. Partial and
    /// unknown outcomes count toward the total but toward neither tally.
    pub success_rate: f32,
    /// Sum of `tokens_used` over all episodes.
    pub total_tokens: u64,
}
193
194pub fn create_episode_from_session(
195    session: &super::session::SessionState,
196    tool_calls: &[(String, u64)],
197) -> Episode {
198    let actions: Vec<Action> = tool_calls
199        .iter()
200        .map(|(tool, duration_ms)| Action {
201            tool: tool.clone(),
202            description: String::new(),
203            timestamp: Utc::now(),
204            duration_ms: *duration_ms,
205            success: true,
206        })
207        .collect();
208
209    let affected_files: Vec<String> = session
210        .files_touched
211        .iter()
212        .map(|f| f.path.clone())
213        .collect();
214
215    let task_description = session
216        .task
217        .as_ref()
218        .map(|t| t.description.clone())
219        .unwrap_or_default();
220
221    let outcome = if session.findings.iter().any(|f| {
222        f.summary.to_lowercase().contains("error") || f.summary.to_lowercase().contains("failed")
223    }) {
224        Outcome::Failure {
225            error: session
226                .findings
227                .iter()
228                .find(|f| {
229                    f.summary.to_lowercase().contains("error")
230                        || f.summary.to_lowercase().contains("failed")
231                })
232                .map(|f| f.summary.clone())
233                .unwrap_or_default(),
234        }
235    } else if !session.findings.is_empty() || !session.decisions.is_empty() {
236        Outcome::Success { tests_passed: true }
237    } else {
238        Outcome::Unknown
239    };
240
241    Episode {
242        id: format!("ep-{}", &session.id[..8.min(session.id.len())]),
243        session_id: session.id.clone(),
244        timestamp: Utc::now(),
245        task_description,
246        actions,
247        outcome,
248        affected_files,
249        summary: String::new(),
250        duration_secs: 0,
251        tokens_used: session.stats.total_tokens_saved,
252    }
253}
254
255fn auto_summarize(episode: &Episode, max_chars: usize) -> String {
256    let tool_counts = count_tools(&episode.actions);
257    let top_tools: Vec<String> = tool_counts
258        .into_iter()
259        .take(3)
260        .map(|(tool, count)| format!("{tool}x{count}"))
261        .collect();
262
263    let files_hint = if episode.affected_files.len() <= 3 {
264        episode.affected_files.join(", ")
265    } else {
266        format!(
267            "{}, ... +{} more",
268            episode.affected_files[..3].join(", "),
269            episode.affected_files.len() - 3
270        )
271    };
272
273    let task = if episode.task_description.chars().count() > max_chars {
274        episode.task_description.chars().take(max_chars).collect()
275    } else {
276        episode.task_description.clone()
277    };
278    let mut summary = format!(
279        "{task} [{}] tools:[{}]",
280        episode.outcome.label(),
281        top_tools.join(",")
282    );
283
284    if !files_hint.is_empty() {
285        summary.push_str(&format!(" files:[{files_hint}]"));
286    }
287
288    summary
289}
290
291fn count_tools(actions: &[Action]) -> Vec<(String, usize)> {
292    let mut counts: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
293    for action in actions {
294        *counts.entry(&action.tool).or_insert(0) += 1;
295    }
296    let mut sorted: Vec<(String, usize)> = counts
297        .into_iter()
298        .map(|(k, v)| (k.to_string(), v))
299        .collect();
300    sorted.sort_by_key(|item| std::cmp::Reverse(item.1));
301    sorted
302}
303
304pub fn format_episode_compact(episode: &Episode) -> String {
305    format!(
306        "[{}] {} — {} ({} actions, {} files)",
307        episode.outcome.label(),
308        episode.task_description,
309        episode.summary,
310        episode.actions.len(),
311        episode.affected_files.len()
312    )
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an episode with two fixed actions (`ctx_read`, `ctx_shell`)
    /// and two affected files.
    fn make_episode(task: &str, outcome: Outcome) -> Episode {
        let action = |tool: &str, duration_ms: u64| Action {
            tool: tool.to_string(),
            description: String::new(),
            timestamp: Utc::now(),
            duration_ms,
            success: true,
        };
        Episode {
            id: "ep-test".to_string(),
            session_id: "sess-1".to_string(),
            timestamp: Utc::now(),
            task_description: task.to_string(),
            actions: vec![action("ctx_read", 50), action("ctx_shell", 200)],
            outcome,
            affected_files: vec!["src/main.rs".to_string(), "src/lib.rs".to_string()],
            summary: String::new(),
            duration_secs: 60,
            tokens_used: 5000,
        }
    }

    /// Records each `(task, outcome)` pair into a fresh store under the default policy.
    fn store_with(entries: Vec<(&str, Outcome)>) -> EpisodicStore {
        let policy = EpisodicPolicy::default();
        let mut store = EpisodicStore::new("test");
        for (task, outcome) in entries {
            store.record_episode(make_episode(task, outcome), &policy);
        }
        store
    }

    #[test]
    fn record_and_search() {
        let store = store_with(vec![
            ("Refactor auth module", Outcome::Success { tests_passed: true }),
            (
                "Fix database connection",
                Outcome::Failure {
                    error: "timeout".to_string(),
                },
            ),
        ]);

        let hits = store.search("auth refactor");
        assert_eq!(hits.len(), 1);
        assert!(hits[0].task_description.contains("auth"));
    }

    #[test]
    fn filter_by_outcome() {
        let store = store_with(vec![
            ("Task 1", Outcome::Success { tests_passed: true }),
            (
                "Task 2",
                Outcome::Failure {
                    error: "err".to_string(),
                },
            ),
            (
                "Task 3",
                Outcome::Success {
                    tests_passed: false,
                },
            ),
        ]);

        assert_eq!(store.by_outcome("success").len(), 2);
        assert_eq!(store.by_outcome("failure").len(), 1);
    }

    #[test]
    fn filter_by_file() {
        let store = store_with(vec![("Task", Outcome::Unknown)]);

        assert_eq!(store.by_file("main.rs").len(), 1);
        assert!(store.by_file("nonexistent.rs").is_empty());
    }

    #[test]
    fn recent_episodes() {
        let tasks: Vec<String> = (0..5).map(|i| format!("Task {i}")).collect();
        let store = store_with(
            tasks
                .iter()
                .map(|t| (t.as_str(), Outcome::Unknown))
                .collect(),
        );

        let recent = store.recent(3);
        assert_eq!(recent.len(), 3);
        assert!(recent[0].task_description.contains('4'));
    }

    #[test]
    fn stats_calculation() {
        let stats = store_with(vec![
            ("T1", Outcome::Success { tests_passed: true }),
            (
                "T2",
                Outcome::Failure {
                    error: "e".to_string(),
                },
            ),
            (
                "T3",
                Outcome::Success {
                    tests_passed: false,
                },
            ),
        ])
        .stats();

        assert_eq!(stats.total_episodes, 3);
        assert_eq!(stats.successes, 2);
        assert_eq!(stats.failures, 1);
        assert!((stats.success_rate - 0.6667).abs() < 0.01);
    }

    #[test]
    fn auto_summary_generation() {
        let episode = make_episode("Fix the login bug", Outcome::Success { tests_passed: true });
        let summary = auto_summarize(&episode, EpisodicPolicy::default().summary_max_chars);

        for needle in ["Fix the login bug", "[success]", "ctx_read"] {
            assert!(summary.contains(needle), "summary {summary:?} lacks {needle:?}");
        }
    }

    #[test]
    fn max_episodes_enforced() {
        let policy = EpisodicPolicy::default();
        let mut store = EpisodicStore::new("test");
        (0..510).for_each(|i| {
            store.record_episode(make_episode(&format!("Task {i}"), Outcome::Unknown), &policy)
        });
        assert!(store.episodes.len() <= policy.max_episodes);
    }

    #[test]
    fn format_compact() {
        let line = format_episode_compact(&make_episode(
            "Deploy v2",
            Outcome::Success { tests_passed: true },
        ));
        assert!(line.contains("[success]"));
        assert!(line.contains("Deploy v2"));
    }
}