1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4
5use crate::memory::{MemoryKind, MemoryStore};
6use crate::provider::{ContentBlock, Message, Provider, Role, StreamEventType};
7
/// System prompt passed to the provider for every extraction request.
///
/// It tells the model what counts as a durable memory and requires a JSON-only
/// reply with `add`, `update` and `delete` arrays; that reply is what
/// `ExtractionOps` deserializes. Keep the JSON shape described here in sync
/// with `ExtractionOps`, `AddOp` and `UpdateOp`.
const EXTRACT_SYSTEM: &str = "\
You are a memory extraction assistant for an AI coding agent. Your job is to identify important, durable facts from the conversation that are worth remembering across sessions.

Extract facts that would be useful in future conversations: user preferences, technical decisions, project context, workflow patterns, environment details, names, and tool configurations.

DO NOT extract:
- Transient task details (file contents being edited right now, current debugging steps)
- Information obvious from the codebase (imports, function signatures)
- Generic coding knowledge

You will receive:
1. Recent conversation messages
2. An existing memory snapshot (may be empty)

Return ONLY valid JSON with this structure:
{
 \"add\": [{\"content\": \"...\", \"kind\": \"fact|preference|decision|project|entity|belief\", \"importance\": 0.0-1.0}],
 \"update\": [{\"id\": \"existing-memory-id\", \"content\": \"updated text\", \"importance\": 0.0-1.0}],
 \"delete\": [\"memory-id-that-is-stale-or-contradicted\"]
}

Rules:
- Each memory should be a single, self-contained statement
- Prefer updating an existing memory over adding a duplicate
- Delete memories that are contradicted by new information
- importance: 0.9+ for core identity/preferences, 0.7-0.9 for project facts, 0.5-0.7 for contextual details, <0.5 for ephemeral
- Return {\"add\":[],\"update\":[],\"delete\":[]} if nothing worth remembering";
35
/// Counts of the store operations that succeeded during one extraction pass.
///
/// Operations that the store rejected are not counted; `Default` yields the
/// all-zero result used when there was nothing to extract.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ExtractionResult {
    /// Number of new memories successfully added.
    pub added: usize,
    /// Number of existing memories successfully updated.
    pub updated: usize,
    /// Number of memories successfully deleted.
    pub deleted: usize,
}
42
43pub async fn extract(
44 messages: &[Message],
45 provider: &dyn Provider,
46 store: &Arc<MemoryStore>,
47 conversation_id: &str,
48) -> Result<ExtractionResult> {
49 let conversation_text = format_messages(messages);
50 if conversation_text.trim().is_empty() {
51 return Ok(ExtractionResult {
52 added: 0,
53 updated: 0,
54 deleted: 0,
55 });
56 }
57
58 let snapshot = store.snapshot(50).unwrap_or_default();
59 let snapshot_text = if snapshot.is_empty() {
60 "No existing memories.".to_string()
61 } else {
62 snapshot
63 .iter()
64 .map(|m| {
65 format!(
66 "- [{}] id={} importance={:.2}: {}",
67 m.kind, m.id, m.importance, m.content
68 )
69 })
70 .collect::<Vec<_>>()
71 .join("\n")
72 };
73
74 let prompt = format!(
75 "## Recent Conversation\n{}\n\n## Existing Memory Snapshot\n{}\n\nExtract memories as JSON.",
76 conversation_text, snapshot_text
77 );
78
79 let request = vec![Message {
80 role: Role::User,
81 content: vec![ContentBlock::Text(prompt)],
82 }];
83
84 let mut rx = provider
85 .stream(&request, Some(EXTRACT_SYSTEM), &[], 2048, 0)
86 .await
87 .context("starting extraction stream")?;
88
89 let mut response = String::new();
90 while let Some(event) = rx.recv().await {
91 if let StreamEventType::TextDelta(text) = event.event_type {
92 response.push_str(&text);
93 }
94 }
95
96 let json = extract_json(&response).unwrap_or(&response);
97 let ops: ExtractionOps = serde_json::from_str(json).unwrap_or_default();
98
99 let mut added = 0usize;
100 let mut updated = 0usize;
101 let mut deleted = 0usize;
102
103 for item in &ops.add {
104 let kind = MemoryKind::parse(&item.kind);
105 let importance = item.importance.clamp(0.0, 1.0);
106 if item.content.len() < 5 {
107 continue;
108 }
109 match store.add(&item.content, &kind, importance, Some(conversation_id)) {
110 Ok(_) => added += 1,
111 Err(e) => tracing::warn!("memory add failed: {e}"),
112 }
113 }
114
115 for item in &ops.update {
116 let importance = item.importance.clamp(0.0, 1.0);
117 match store.update(&item.id, &item.content, importance) {
118 Ok(()) => updated += 1,
119 Err(e) => tracing::warn!("memory update failed for {}: {e}", item.id),
120 }
121 }
122
123 for id in &ops.delete {
124 match store.delete(id) {
125 Ok(()) => deleted += 1,
126 Err(e) => tracing::warn!("memory delete failed for {id}: {e}"),
127 }
128 }
129
130 tracing::info!(
131 "memory extraction: +{added} ~{updated} -{deleted} from conversation {conversation_id}"
132 );
133
134 Ok(ExtractionResult {
135 added,
136 updated,
137 deleted,
138 })
139}
140
141fn format_messages(messages: &[Message]) -> String {
142 let tail = if messages.len() > 20 {
143 &messages[messages.len() - 20..]
144 } else {
145 messages
146 };
147 let mut out = String::new();
148 for msg in tail {
149 let role = match msg.role {
150 Role::User => "User",
151 Role::Assistant => "Assistant",
152 Role::System => continue,
153 };
154 for block in &msg.content {
155 match block {
156 ContentBlock::Text(t) if !t.is_empty() => {
157 let truncated: String = t.chars().take(2000).collect();
158 out.push_str(&format!("{role}: {truncated}\n\n"));
159 }
160 ContentBlock::ToolUse { name, .. } => {
161 out.push_str(&format!("{role}: [used tool: {name}]\n\n"));
162 }
163 _ => {}
164 }
165 }
166 }
167 out
168}
169
/// Returns the first balanced top-level `{ ... }` object found in `text`.
///
/// Scanning starts at the first `{`, so leading prose or a Markdown code fence
/// is skipped, and stops at the brace that closes that object, so trailing
/// text is dropped as well. Braces that appear inside JSON string literals
/// (including after escaped quotes such as `\"`) are not counted; otherwise a
/// memory whose text contains `{` or `}` would truncate or lose the object.
///
/// Returns `None` when `text` contains no `{` or the object is never closed
/// (for example because the reply was cut off).
fn extract_json(text: &str) -> Option<&str> {
    let start = text.find('{')?;
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (offset, ch) in text[start..].char_indices() {
        if in_string {
            if escaped {
                // The character after a backslash is always part of the string.
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => depth += 1,
            '}' => {
                // The scan starts on `{`, so depth is at least 1 whenever a
                // closing brace is seen outside a string; this cannot underflow.
                depth -= 1;
                if depth == 0 {
                    let end = start + offset + ch.len_utf8();
                    return Some(&text[start..end]);
                }
            }
            _ => {}
        }
    }

    None
}
187
/// Operations parsed from the model's JSON reply.
///
/// Every list is optional in the JSON and defaults to empty, so a reply that
/// only contains e.g. `"add"` still deserializes. `Default` is the "do
/// nothing" value.
#[derive(Debug, Default, serde::Deserialize)]
struct ExtractionOps {
    /// New memories to create.
    #[serde(default)]
    add: Vec<AddOp>,
    /// Existing memories to overwrite, addressed by id.
    #[serde(default)]
    update: Vec<UpdateOp>,
    /// Ids of memories to remove.
    #[serde(default)]
    delete: Vec<String>,
}
197
/// A single "add" operation from the model's reply.
#[derive(Debug, serde::Deserialize)]
struct AddOp {
    /// Text of the memory to store (required in the JSON).
    content: String,
    /// Memory kind as a string; falls back to `default_kind()` when omitted.
    #[serde(default = "default_kind")]
    kind: String,
    /// Importance score; falls back to `default_importance()` when omitted.
    #[serde(default = "default_importance")]
    importance: f32,
}
206
/// A single "update" operation from the model's reply.
#[derive(Debug, serde::Deserialize)]
struct UpdateOp {
    /// Id of the existing memory to update (required in the JSON).
    id: String,
    /// Replacement text for the memory (required in the JSON).
    content: String,
    /// Importance score; falls back to `default_importance()` when omitted.
    #[serde(default = "default_importance")]
    importance: f32,
}
214
/// Serde fallback for an omitted `kind` field: the generic `"fact"` kind.
fn default_kind() -> String {
    String::from("fact")
}
218
/// Serde fallback for an omitted `importance` field: a mid-range score.
fn default_importance() -> f32 {
    const FALLBACK_IMPORTANCE: f32 = 0.5;
    FALLBACK_IMPORTANCE
}