Skip to main content

dot/memory/
extract.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4
5use crate::memory::{MemoryKind, MemoryStore};
6use crate::provider::{ContentBlock, Message, Provider, Role, StreamEventType};
7
/// System prompt sent with every extraction request.
///
/// It describes the task to the model and pins the JSON reply shape
/// (`add` / `update` / `delete` arrays) that the extraction code deserializes.
/// Written as a raw string so the embedded JSON example needs no escaping.
const EXTRACT_SYSTEM: &str = r#"You are a memory extraction assistant for an AI coding agent. Your job is to identify important, durable facts from the conversation that are worth remembering across sessions.

Extract facts that would be useful in future conversations: user preferences, technical decisions, project context, workflow patterns, environment details, names, and tool configurations.

DO NOT extract:
- Transient task details (file contents being edited right now, current debugging steps)
- Information obvious from the codebase (imports, function signatures)
- Generic coding knowledge

You will receive:
1. Recent conversation messages
2. An existing memory snapshot (may be empty)

Return ONLY valid JSON with this structure:
{
  "add": [{"content": "...", "kind": "fact|preference|decision|project|entity|belief", "importance": 0.0-1.0}],
  "update": [{"id": "existing-memory-id", "content": "updated text", "importance": 0.0-1.0}],
  "delete": ["memory-id-that-is-stale-or-contradicted"]
}

Rules:
- Each memory should be a single, self-contained statement
- Prefer updating an existing memory over adding a duplicate
- Delete memories that are contradicted by new information
- importance: 0.9+ for core identity/preferences, 0.7-0.9 for project facts, 0.5-0.7 for contextual details, <0.5 for ephemeral
- Return {"add":[],"update":[],"delete":[]} if nothing worth remembering"#;
35
/// Tally of the memory-store operations that succeeded during one extraction pass.
///
/// Operations that failed (and were only logged) are not counted.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ExtractionResult {
    /// Number of new memories written to the store.
    pub added: usize,
    /// Number of existing memories whose content/importance was rewritten.
    pub updated: usize,
    /// Number of memories removed from the store.
    pub deleted: usize,
}
42
43pub async fn extract(
44    messages: &[Message],
45    provider: &dyn Provider,
46    store: &Arc<MemoryStore>,
47    conversation_id: &str,
48) -> Result<ExtractionResult> {
49    let conversation_text = format_messages(messages);
50    if conversation_text.trim().is_empty() {
51        return Ok(ExtractionResult {
52            added: 0,
53            updated: 0,
54            deleted: 0,
55        });
56    }
57
58    let snapshot = store.snapshot(50).unwrap_or_default();
59    let snapshot_text = if snapshot.is_empty() {
60        "No existing memories.".to_string()
61    } else {
62        snapshot
63            .iter()
64            .map(|m| {
65                format!(
66                    "- [{}] id={} importance={:.2}: {}",
67                    m.kind, m.id, m.importance, m.content
68                )
69            })
70            .collect::<Vec<_>>()
71            .join("\n")
72    };
73
74    let prompt = format!(
75        "## Recent Conversation\n{}\n\n## Existing Memory Snapshot\n{}\n\nExtract memories as JSON.",
76        conversation_text, snapshot_text
77    );
78
79    let request = vec![Message {
80        role: Role::User,
81        content: vec![ContentBlock::Text(prompt)],
82    }];
83
84    let mut rx = provider
85        .stream(&request, Some(EXTRACT_SYSTEM), &[], 2048, 0)
86        .await
87        .context("starting extraction stream")?;
88
89    let mut response = String::new();
90    while let Some(event) = rx.recv().await {
91        if let StreamEventType::TextDelta(text) = event.event_type {
92            response.push_str(&text);
93        }
94    }
95
96    let json = extract_json(&response).unwrap_or(&response);
97    let ops: ExtractionOps = serde_json::from_str(json).unwrap_or_default();
98
99    let mut added = 0usize;
100    let mut updated = 0usize;
101    let mut deleted = 0usize;
102
103    for item in &ops.add {
104        let kind = MemoryKind::parse(&item.kind);
105        let importance = item.importance.clamp(0.0, 1.0);
106        if item.content.len() < 5 {
107            continue;
108        }
109        match store.add(&item.content, &kind, importance, Some(conversation_id)) {
110            Ok(_) => added += 1,
111            Err(e) => tracing::warn!("memory add failed: {e}"),
112        }
113    }
114
115    for item in &ops.update {
116        let importance = item.importance.clamp(0.0, 1.0);
117        match store.update(&item.id, &item.content, importance) {
118            Ok(()) => updated += 1,
119            Err(e) => tracing::warn!("memory update failed for {}: {e}", item.id),
120        }
121    }
122
123    for id in &ops.delete {
124        match store.delete(id) {
125            Ok(()) => deleted += 1,
126            Err(e) => tracing::warn!("memory delete failed for {id}: {e}"),
127        }
128    }
129
130    tracing::info!(
131        "memory extraction: +{added} ~{updated} -{deleted} from conversation {conversation_id}"
132    );
133
134    Ok(ExtractionResult {
135        added,
136        updated,
137        deleted,
138    })
139}
140
141fn format_messages(messages: &[Message]) -> String {
142    let tail = if messages.len() > 20 {
143        &messages[messages.len() - 20..]
144    } else {
145        messages
146    };
147    let mut out = String::new();
148    for msg in tail {
149        let role = match msg.role {
150            Role::User => "User",
151            Role::Assistant => "Assistant",
152            Role::System => continue,
153        };
154        for block in &msg.content {
155            match block {
156                ContentBlock::Text(t) if !t.is_empty() => {
157                    let truncated: String = t.chars().take(2000).collect();
158                    out.push_str(&format!("{role}: {truncated}\n\n"));
159                }
160                ContentBlock::ToolUse { name, .. } => {
161                    out.push_str(&format!("{role}: [used tool: {name}]\n\n"));
162                }
163                _ => {}
164            }
165        }
166    }
167    out
168}
169
/// Returns the first balanced `{ ... }` object found in `text`, if any.
///
/// Scanning starts at the first `{` and tracks brace depth until it returns to
/// zero. Braces inside JSON string literals are ignored, and backslash escapes
/// (such as `\"`) are honoured so an escaped quote does not end the string.
/// Returns `None` when `text` has no `{` or the braces never balance.
fn extract_json(text: &str) -> Option<&str> {
    let start = text.find('{')?;

    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (offset, ch) in text[start..].char_indices() {
        if in_string {
            if escaped {
                // Whatever follows a backslash is literal, including `"` and `\`.
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => depth += 1,
            '}' => {
                // Cannot underflow: the scan starts on `{` and returns as soon as
                // the depth comes back to zero.
                depth -= 1;
                if depth == 0 {
                    // `}` is a single byte, so the inclusive end is a char boundary.
                    return Some(&text[start..=start + offset]);
                }
            }
            _ => {}
        }
    }
    None
}
187
188#[derive(Debug, Default, serde::Deserialize)]
189struct ExtractionOps {
190    #[serde(default)]
191    add: Vec<AddOp>,
192    #[serde(default)]
193    update: Vec<UpdateOp>,
194    #[serde(default)]
195    delete: Vec<String>,
196}
197
/// One entry of the `add` array in the model's reply: a new memory to store.
#[derive(Debug, serde::Deserialize)]
struct AddOp {
    /// Text of the memory; required in the JSON.
    content: String,
    /// Memory kind as a free-form string; falls back to `default_kind()` when absent.
    #[serde(default = "default_kind")]
    kind: String,
    /// Importance score; falls back to `default_importance()` when absent.
    #[serde(default = "default_importance")]
    importance: f32,
}
206
/// One entry of the `update` array in the model's reply: a rewrite of an existing memory.
#[derive(Debug, serde::Deserialize)]
struct UpdateOp {
    /// Id of the memory to rewrite; required in the JSON.
    id: String,
    /// Replacement text for the memory; required in the JSON.
    content: String,
    /// New importance score; falls back to `default_importance()` when absent.
    #[serde(default = "default_importance")]
    importance: f32,
}
214
/// Serde fallback for a missing `kind` field: the string `"fact"`.
fn default_kind() -> String {
    String::from("fact")
}
218
/// Serde fallback for a missing `importance` field: `0.5`.
fn default_importance() -> f32 {
    0.5_f32
}