Skip to main content

cersei_agent/
session_memory.rs

1//! Session memory extraction: extract key facts from conversations.
2//!
3//! After enough conversation activity (≥20 messages, ≥3 tool calls since
4//! last extraction), the extractor calls the LLM to identify reusable facts
5//! and persists them to the memory directory.
6
7use cersei_types::*;
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
11// ─── Constants ───────────────────────────────────────────────────────────────
12
13const MIN_MESSAGES_TO_EXTRACT: usize = 20;
14const MIN_TOOL_CALLS_BETWEEN_EXTRACTIONS: usize = 3;
15
16// ─── Types ───────────────────────────────────────────────────────────────────
17
18/// Categories of extracted memories.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub enum MemoryCategory {
21    UserPreference,
22    ProjectFact,
23    CodePattern,
24    Decision,
25    Constraint,
26}
27
28impl MemoryCategory {
29    pub fn label(&self) -> &'static str {
30        match self {
31            Self::UserPreference => "preference",
32            Self::ProjectFact => "project",
33            Self::CodePattern => "pattern",
34            Self::Decision => "decision",
35            Self::Constraint => "constraint",
36        }
37    }
38
39    pub fn from_str(s: &str) -> Option<Self> {
40        match s.to_lowercase().as_str() {
41            "preference" | "userpreference" | "user_preference" => Some(Self::UserPreference),
42            "project" | "projectfact" | "project_fact" => Some(Self::ProjectFact),
43            "pattern" | "codepattern" | "code_pattern" => Some(Self::CodePattern),
44            "decision" => Some(Self::Decision),
45            "constraint" => Some(Self::Constraint),
46            _ => None,
47        }
48    }
49}
50
51/// A single extracted memory fact.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ExtractedMemory {
54    pub content: String,
55    pub category: MemoryCategory,
56    pub confidence: f32,
57}
58
59/// Tracks extraction progress to avoid re-extracting.
60#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct SessionMemoryState {
62    pub last_extracted_message_index: usize,
63    pub tool_calls_since_last: usize,
64    pub extraction_count: u32,
65}
66
67// ─── Gate checks ─────────────────────────────────────────────────────────────
68
69/// Check if enough conversation has happened to warrant extraction.
70pub fn should_extract(messages: &[Message], state: &SessionMemoryState) -> bool {
71    // Need enough messages
72    if messages.len() < MIN_MESSAGES_TO_EXTRACT {
73        return false;
74    }
75
76    // Need enough tool calls since last extraction
77    if state.extraction_count > 0
78        && state.tool_calls_since_last < MIN_TOOL_CALLS_BETWEEN_EXTRACTIONS
79    {
80        return false;
81    }
82
83    // Don't extract if the last assistant message has pending tool calls
84    if let Some(last) = messages.iter().rev().find(|m| m.role == Role::Assistant) {
85        if last.has_tool_use() {
86            return false;
87        }
88    }
89
90    true
91}
92
93/// Count tool calls in messages since a given index.
94pub fn count_tool_calls_since(messages: &[Message], since_index: usize) -> usize {
95    messages[since_index..]
96        .iter()
97        .map(|m| m.get_tool_use_blocks().len())
98        .sum()
99}
100
101// ─── Extraction prompt ──────────────────────────────────────────────────────
102
/// Returns the system prompt that instructs the model how to report memories.
///
/// The prompt asks for one `MEMORY: <category> | <confidence 0-10> | <fact>`
/// line per fact, which is the format `parse_extraction_output` reads back.
pub fn extraction_prompt() -> &'static str {
    concat!(
        "You are a memory extraction system. ",
        "Read the conversation and extract key facts worth remembering for future sessions.\n",
        "\n",
        "For each fact, output one line in this exact format:\n",
        "MEMORY: <category> | <confidence 0-10> | <fact>\n",
        "\n",
        "Categories: preference, project, pattern, decision, constraint\n",
        "\n",
        "Only extract facts that would be genuinely useful in future sessions. ",
        "Don't extract trivial or ephemeral information. ",
        "Be specific and actionable.",
    )
}
113
114/// Parse extraction output into structured memories.
115pub fn parse_extraction_output(output: &str) -> Vec<ExtractedMemory> {
116    output
117        .lines()
118        .filter_map(|line| {
119            let line = line.trim();
120            if !line.starts_with("MEMORY:") {
121                return None;
122            }
123            let rest = line.strip_prefix("MEMORY:")?.trim();
124            let parts: Vec<&str> = rest.splitn(3, '|').collect();
125            if parts.len() != 3 {
126                return None;
127            }
128
129            let category = MemoryCategory::from_str(parts[0].trim())?;
130            let confidence = parts[1].trim().parse::<f32>().ok()? / 10.0;
131            let content = parts[2].trim().to_string();
132
133            if content.is_empty() || confidence < 0.0 {
134                return None;
135            }
136
137            Some(ExtractedMemory {
138                content,
139                category,
140                confidence: confidence.clamp(0.0, 1.0),
141            })
142        })
143        .collect()
144}
145
146// ─── Persistence ─────────────────────────────────────────────────────────────
147
148/// Persist extracted memories to a file under `## Auto-extracted memories`.
149pub fn persist_memories(memories: &[ExtractedMemory], target_path: &Path) -> std::io::Result<()> {
150    if memories.is_empty() {
151        return Ok(());
152    }
153
154    if let Some(parent) = target_path.parent() {
155        std::fs::create_dir_all(parent)?;
156    }
157
158    let existing = std::fs::read_to_string(target_path).unwrap_or_default();
159
160    let date = chrono::Utc::now().format("%Y-%m-%d").to_string();
161    let section_header = "## Auto-extracted memories";
162    let date_header = format!("### Session memories — {}", date);
163
164    let mut new_entries = String::new();
165    for mem in memories {
166        new_entries.push_str(&format!(
167            "- **[{}]** {} *(confidence: {:.0}%)*\n",
168            mem.category.label(),
169            mem.content,
170            mem.confidence * 100.0,
171        ));
172    }
173
174    let output = if existing.contains(section_header) {
175        // Append under existing section
176        if existing.contains(&date_header) {
177            // Append to existing date block
178            existing.replace(&date_header, &format!("{}\n{}", date_header, new_entries))
179        } else {
180            // Add new date block at end of section
181            let insert_pos = existing.find(section_header).unwrap() + section_header.len();
182            let (before, after) = existing.split_at(insert_pos);
183            format!("{}\n\n{}\n{}\n{}", before, date_header, new_entries, after)
184        }
185    } else {
186        // Create section
187        if existing.is_empty() {
188            format!("{}\n\n{}\n{}", section_header, date_header, new_entries)
189        } else {
190            format!(
191                "{}\n\n{}\n\n{}\n{}",
192                existing.trim(),
193                section_header,
194                date_header,
195                new_entries
196            )
197        }
198    };
199
200    std::fs::write(target_path, output)
201}
202
203// ─── Tests ───────────────────────────────────────────────────────────────────
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` messages that alternate user / assistant, starting with a
    /// user message at index 0.
    fn make_messages(n: usize) -> Vec<Message> {
        let mut messages = Vec::with_capacity(n);
        for idx in 0..n {
            let message = match idx % 2 {
                0 => Message::user(format!("Msg {}", idx)),
                _ => Message::assistant(format!("Response {}", idx)),
            };
            messages.push(message);
        }
        messages
    }

    /// Fewer messages than the minimum: never extract.
    #[test]
    fn test_should_extract_below_threshold() {
        let fresh = SessionMemoryState::default();
        let conversation = make_messages(10);
        assert!(!should_extract(&conversation, &fresh));
    }

    /// Enough messages and no prior extraction: extract.
    #[test]
    fn test_should_extract_above_threshold() {
        let fresh = SessionMemoryState::default();
        let conversation = make_messages(25);
        assert!(should_extract(&conversation, &fresh));
    }

    /// After a previous extraction, too few tool calls blocks a new one.
    #[test]
    fn test_should_extract_cooldown() {
        let cooling = SessionMemoryState {
            tool_calls_since_last: 1, // below the minimum of 3
            extraction_count: 1,
            ..Default::default()
        };
        let conversation = make_messages(25);
        assert!(!should_extract(&conversation, &cooling));
    }

    /// Valid lines are parsed; noise and unknown categories are dropped.
    #[test]
    fn test_parse_extraction_output() {
        let output = concat!(
            "MEMORY: preference | 8 | User prefers Rust over Python\n",
            "MEMORY: project | 9 | The API uses REST with JSON responses\n",
            "MEMORY: decision | 7 | We chose PostgreSQL for the database\n",
            "not a memory line\n",
            "MEMORY: invalid | 5 | this category won't parse\n",
        );

        let parsed = parse_extraction_output(output);

        assert_eq!(parsed.len(), 3);
        let first = &parsed[0];
        assert_eq!(first.content, "User prefers Rust over Python");
        assert!((first.confidence - 0.8).abs() < 0.01);
        assert_eq!(parsed[1].content, "The API uses REST with JSON responses");
    }

    /// Writing twice keeps the first batch and adds the second.
    #[test]
    fn test_persist_memories() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("memories.md");
        let read_back = || std::fs::read_to_string(&path).unwrap();

        let first_batch = vec![
            ExtractedMemory {
                content: "User prefers dark mode".into(),
                category: MemoryCategory::UserPreference,
                confidence: 0.9,
            },
            ExtractedMemory {
                content: "Project uses Cersei SDK".into(),
                category: MemoryCategory::ProjectFact,
                confidence: 0.95,
            },
        ];
        persist_memories(&first_batch, &path).unwrap();

        let content = read_back();
        for needle in &[
            "Auto-extracted memories",
            "dark mode",
            "Cersei SDK",
            "preference",
            "90%",
        ] {
            assert!(content.contains(*needle));
        }

        // A second write must add to the file, not replace it.
        let second_batch = vec![ExtractedMemory {
            content: "Tests use tokio".into(),
            category: MemoryCategory::CodePattern,
            confidence: 0.7,
        }];
        persist_memories(&second_batch, &path).unwrap();

        let content = read_back();
        assert!(content.contains("tokio"));
        assert!(content.contains("dark mode")); // original preserved
    }
}