1use cersei_types::*;
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
/// Minimum number of messages the conversation must contain before
/// `should_extract` allows an extraction pass.
const MIN_MESSAGES_TO_EXTRACT: usize = 20;
/// Once at least one extraction has happened, the number of tool calls
/// (`SessionMemoryState::tool_calls_since_last`) that must accumulate before
/// `should_extract` allows another pass.
const MIN_TOOL_CALLS_BETWEEN_EXTRACTIONS: usize = 3;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
20pub enum MemoryCategory {
21 UserPreference,
22 ProjectFact,
23 CodePattern,
24 Decision,
25 Constraint,
26}
27
28impl MemoryCategory {
29 pub fn label(&self) -> &'static str {
30 match self {
31 Self::UserPreference => "preference",
32 Self::ProjectFact => "project",
33 Self::CodePattern => "pattern",
34 Self::Decision => "decision",
35 Self::Constraint => "constraint",
36 }
37 }
38
39 pub fn from_str(s: &str) -> Option<Self> {
40 match s.to_lowercase().as_str() {
41 "preference" | "userpreference" | "user_preference" => Some(Self::UserPreference),
42 "project" | "projectfact" | "project_fact" => Some(Self::ProjectFact),
43 "pattern" | "codepattern" | "code_pattern" => Some(Self::CodePattern),
44 "decision" => Some(Self::Decision),
45 "constraint" => Some(Self::Constraint),
46 _ => None,
47 }
48 }
49}
50
/// A single fact parsed from the extraction model's output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedMemory {
    /// The fact text (the third `|`-separated field of a `MEMORY:` line).
    pub content: String,
    /// Which kind of fact this is.
    pub category: MemoryCategory,
    /// Confidence score. `parse_extraction_output` stores the model's 0-10 score
    /// divided by 10 and clamped to `0.0..=1.0`; `persist_memories` renders it as
    /// a percentage.
    pub confidence: f32,
}
58
/// Per-session bookkeeping consulted by `should_extract`.
///
/// `should_extract` only reads this state; updating it is left to the caller.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionMemoryState {
    /// Presumably the message index up to which the previous extraction ran,
    /// suitable as `since_index` for `count_tool_calls_since` (TODO confirm
    /// against callers). Not read by `should_extract`.
    pub last_extracted_message_index: usize,
    /// Tool calls seen since the previous extraction. Compared against
    /// `MIN_TOOL_CALLS_BETWEEN_EXTRACTIONS` once `extraction_count > 0`.
    pub tool_calls_since_last: usize,
    /// Number of extractions performed so far. While it is `0`, the tool-call
    /// cooldown does not apply.
    pub extraction_count: u32,
}
66
67pub fn should_extract(messages: &[Message], state: &SessionMemoryState) -> bool {
71 if messages.len() < MIN_MESSAGES_TO_EXTRACT {
73 return false;
74 }
75
76 if state.extraction_count > 0
78 && state.tool_calls_since_last < MIN_TOOL_CALLS_BETWEEN_EXTRACTIONS
79 {
80 return false;
81 }
82
83 if let Some(last) = messages.iter().rev().find(|m| m.role == Role::Assistant) {
85 if last.has_tool_use() {
86 return false;
87 }
88 }
89
90 true
91}
92
93pub fn count_tool_calls_since(messages: &[Message], since_index: usize) -> usize {
95 messages[since_index..]
96 .iter()
97 .map(|m| m.get_tool_use_blocks().len())
98 .sum()
99}
100
/// System prompt asking the model to list memorable facts.
///
/// Each fact is requested as a `MEMORY: <category> | <confidence 0-10> | <fact>`
/// line, which is the format `parse_extraction_output` reads back.
pub fn extraction_prompt() -> &'static str {
    concat!(
        "You are a memory extraction system. Read the conversation and extract ",
        "key facts worth remembering for future sessions.\n\n",
        "For each fact, output one line in this exact format:\n",
        "MEMORY: <category> | <confidence 0-10> | <fact>\n\n",
        "Categories: preference, project, pattern, decision, constraint\n\n",
        "Only extract facts that would be genuinely useful in future sessions. ",
        "Don't extract trivial or ephemeral information. Be specific and actionable.",
    )
}
113
114pub fn parse_extraction_output(output: &str) -> Vec<ExtractedMemory> {
116 output
117 .lines()
118 .filter_map(|line| {
119 let line = line.trim();
120 if !line.starts_with("MEMORY:") {
121 return None;
122 }
123 let rest = line.strip_prefix("MEMORY:")?.trim();
124 let parts: Vec<&str> = rest.splitn(3, '|').collect();
125 if parts.len() != 3 {
126 return None;
127 }
128
129 let category = MemoryCategory::from_str(parts[0].trim())?;
130 let confidence = parts[1].trim().parse::<f32>().ok()? / 10.0;
131 let content = parts[2].trim().to_string();
132
133 if content.is_empty() || confidence < 0.0 {
134 return None;
135 }
136
137 Some(ExtractedMemory {
138 content,
139 category,
140 confidence: confidence.clamp(0.0, 1.0),
141 })
142 })
143 .collect()
144}
145
146pub fn persist_memories(memories: &[ExtractedMemory], target_path: &Path) -> std::io::Result<()> {
150 if memories.is_empty() {
151 return Ok(());
152 }
153
154 if let Some(parent) = target_path.parent() {
155 std::fs::create_dir_all(parent)?;
156 }
157
158 let existing = std::fs::read_to_string(target_path).unwrap_or_default();
159
160 let date = chrono::Utc::now().format("%Y-%m-%d").to_string();
161 let section_header = "## Auto-extracted memories";
162 let date_header = format!("### Session memories — {}", date);
163
164 let mut new_entries = String::new();
165 for mem in memories {
166 new_entries.push_str(&format!(
167 "- **[{}]** {} *(confidence: {:.0}%)*\n",
168 mem.category.label(),
169 mem.content,
170 mem.confidence * 100.0,
171 ));
172 }
173
174 let output = if existing.contains(section_header) {
175 if existing.contains(&date_header) {
177 existing.replace(&date_header, &format!("{}\n{}", date_header, new_entries))
179 } else {
180 let insert_pos = existing.find(section_header).unwrap() + section_header.len();
182 let (before, after) = existing.split_at(insert_pos);
183 format!("{}\n\n{}\n{}\n{}", before, date_header, new_entries, after)
184 }
185 } else {
186 if existing.is_empty() {
188 format!("{}\n\n{}\n{}", section_header, date_header, new_entries)
189 } else {
190 format!(
191 "{}\n\n{}\n\n{}\n{}",
192 existing.trim(),
193 section_header,
194 date_header,
195 new_entries
196 )
197 }
198 };
199
200 std::fs::write(target_path, output)
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` alternating plain-text user/assistant messages (no tool use).
    fn make_messages(n: usize) -> Vec<Message> {
        (0..n)
            .map(|i| {
                if i % 2 == 0 {
                    Message::user(format!("Msg {}", i))
                } else {
                    Message::assistant(format!("Response {}", i))
                }
            })
            .collect()
    }

    #[test]
    fn test_should_extract_below_threshold() {
        let msgs = make_messages(10);
        let state = SessionMemoryState::default();
        assert!(!should_extract(&msgs, &state));
    }

    #[test]
    fn test_should_extract_above_threshold() {
        let msgs = make_messages(25);
        let state = SessionMemoryState::default();
        assert!(should_extract(&msgs, &state));
    }

    #[test]
    fn test_should_extract_cooldown() {
        let msgs = make_messages(25);
        let state = SessionMemoryState {
            extraction_count: 1,
            tool_calls_since_last: 1,
            ..Default::default()
        };
        assert!(!should_extract(&msgs, &state));
    }

    #[test]
    fn test_should_extract_after_cooldown_elapses() {
        let msgs = make_messages(25);
        let state = SessionMemoryState {
            extraction_count: 1,
            tool_calls_since_last: MIN_TOOL_CALLS_BETWEEN_EXTRACTIONS,
            ..Default::default()
        };
        assert!(should_extract(&msgs, &state));
    }

    #[test]
    fn test_count_tool_calls_since_plain_messages() {
        let msgs = make_messages(6);
        assert_eq!(count_tool_calls_since(&msgs, 0), 0);
        assert_eq!(count_tool_calls_since(&msgs, msgs.len()), 0);
    }

    #[test]
    fn test_category_labels_round_trip() {
        let categories = [
            MemoryCategory::UserPreference,
            MemoryCategory::ProjectFact,
            MemoryCategory::CodePattern,
            MemoryCategory::Decision,
            MemoryCategory::Constraint,
        ];
        for cat in categories.iter() {
            let parsed = MemoryCategory::from_str(cat.label()).map(|c| c.label());
            assert_eq!(parsed, Some(cat.label()));
        }
        assert!(MemoryCategory::from_str("unknown").is_none());
    }

    #[test]
    fn test_prompt_categories_are_parseable() {
        let listed = extraction_prompt()
            .lines()
            .find_map(|l| l.strip_prefix("Categories:"))
            .expect("prompt lists its categories");
        for name in listed.split(',') {
            assert!(
                MemoryCategory::from_str(name.trim()).is_some(),
                "prompt category {:?} is not parseable",
                name
            );
        }
    }

    #[test]
    fn test_parse_extraction_output() {
        let output = "\
MEMORY: preference | 8 | User prefers Rust over Python
MEMORY: project | 9 | The API uses REST with JSON responses
MEMORY: decision | 7 | We chose PostgreSQL for the database
not a memory line
MEMORY: invalid | 5 | this category won't parse
";
        let memories = parse_extraction_output(output);
        assert_eq!(memories.len(), 3);
        assert_eq!(memories[0].content, "User prefers Rust over Python");
        assert_eq!(memories[0].category.label(), "preference");
        assert!((memories[0].confidence - 0.8).abs() < 0.01);
        assert_eq!(memories[1].content, "The API uses REST with JSON responses");
        assert_eq!(memories[2].category.label(), "decision");
    }

    #[test]
    fn test_parse_extraction_output_rejects_and_clamps() {
        let output = concat!(
            "MEMORY: pattern | 15 | Over-confident fact\n",
            "MEMORY: pattern | -3 | Negative confidence\n",
            "MEMORY: pattern | high | Non-numeric confidence\n",
            "MEMORY: pattern | 5 |\n",
            "MEMORY: pattern | 5\n",
        );
        let memories = parse_extraction_output(output);
        assert_eq!(memories.len(), 1);
        assert_eq!(memories[0].content, "Over-confident fact");
        assert_eq!(memories[0].category.label(), "pattern");
        assert!((memories[0].confidence - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_persist_memories() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("memories.md");

        let memories = vec![
            ExtractedMemory {
                content: "User prefers dark mode".into(),
                category: MemoryCategory::UserPreference,
                confidence: 0.9,
            },
            ExtractedMemory {
                content: "Project uses Cersei SDK".into(),
                category: MemoryCategory::ProjectFact,
                confidence: 0.95,
            },
        ];

        persist_memories(&memories, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("Auto-extracted memories"));
        assert!(content.contains("dark mode"));
        assert!(content.contains("Cersei SDK"));
        assert!(content.contains("preference"));
        assert!(content.contains("90%"));

        // A second write must append without losing earlier entries or
        // duplicating the section.
        let more = vec![ExtractedMemory {
            content: "Tests use tokio".into(),
            category: MemoryCategory::CodePattern,
            confidence: 0.7,
        }];
        persist_memories(&more, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("tokio"));
        assert!(content.contains("dark mode"));
        assert!(content.contains("Cersei SDK"));
        assert_eq!(content.matches("## Auto-extracted memories").count(), 1);
    }

    #[test]
    fn test_persist_memories_empty_is_noop() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("nested").join("memories.md");
        persist_memories(&[], &path).unwrap();
        assert!(!path.exists());
    }

    #[test]
    fn test_persist_memories_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("a").join("b").join("memories.md");
        let memories = vec![ExtractedMemory {
            content: "Deploys go through staging".into(),
            category: MemoryCategory::Decision,
            confidence: 0.6,
        }];
        persist_memories(&memories, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("**[decision]** Deploys go through staging"));
    }

    #[test]
    fn test_persist_memories_preserves_existing_content() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("memories.md");
        std::fs::write(&path, "# Project notes\n\nKeep this line.\n").unwrap();

        let memories = vec![ExtractedMemory {
            content: "Builds run in CI".into(),
            category: MemoryCategory::Constraint,
            confidence: 0.5,
        }];
        persist_memories(&memories, &path).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.starts_with("# Project notes"));
        assert!(content.contains("Keep this line."));
        assert!(content.contains("**[constraint]** Builds run in CI"));
        assert!(content.contains("50%"));
    }
}