Skip to main content

matrixcode_core/memory/
extractor.rs

1//! Memory extraction: AI-based and rule-based detection.
2
3use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::entry::{MemoryCategory, MemoryEntry};
9use super::manager::AutoMemory;
10
11// ============================================================================
12// Memory Extractor Trait
13// ============================================================================
14
15/// Trait for memory extraction implementations.
16#[async_trait::async_trait]
17pub trait MemoryExtractor: Send + Sync {
18    /// Extract memories from conversation text using AI.
19    async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>>;
20
21    /// Get the model name used for extraction.
22    fn model_name(&self) -> &str;
23}
24
25/// AI-based memory extractor using a fast/cheap model.
26pub struct AiMemoryExtractor {
27    provider: Box<dyn crate::providers::Provider>,
28    model: String,
29}
30
31impl AiMemoryExtractor {
32    /// Create a new AI memory extractor.
33    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
34        Self { provider, model }
35    }
36
37    /// Create a minimal extractor (for background tasks, uses simplified prompt).
38    /// This is more efficient for non-blocking background extraction.
39    pub fn new_minimal(model: String) -> Self {
40        // Create a minimal provider that uses the global config
41        // This is for background tasks, so we use a simplified approach
42        Self {
43            provider: crate::create_minimal_provider(&model),
44            model,
45        }
46    }
47}
48
/// System prompt used for AI memory extraction (passed as `system` in
/// `AiMemoryExtractor::extract`).
///
/// The prompt is written in Chinese and contains, in order:
/// - the six memory types (`decision`, `preference`, `solution`, `finding`, `technical`,
///   `structure`) with a description and a "when to save" hint for each;
/// - a list of things that must not be saved as memories;
/// - the required output format: strict JSON with a top-level `memories` array whose items
///   carry `category`, `content`, `importance`, `keywords` and `tags`.
///
/// `parse_memory_response` expects replies in exactly that JSON shape.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是记忆提取助手。从对话中提取值得长期记忆的关键信息。

# 记忆类型

<types>
<type>
    <name>decision</name>
    <description>项目或技术选型的决定</description>
    <when_to_save>用户明确做出技术决策时</when_to_save>
    <body_structure>先写决策内容,然后 **Why:** 决策原因,**Context:** 适用场景</body_structure>
</type>
<type>
    <name>preference</name>
    <description>用户习惯或偏好</description>
    <when_to_save>用户表达"我喜欢/习惯/偏好"时</when_to_save>
    <body_structure>先写偏好内容,然后 **Why:** 偏好原因(如有)</body_structure>
</type>
<type>
    <name>solution</name>
    <description>解决问题的具体方法</description>
    <when_to_save>问题成功解决且方法可复用时</when_to_save>
    <body_structure>先写解决方案,然后 **Problem:** 解决的问题,**Key:** 关键步骤</body_structure>
</type>
<type>
    <name>finding</name>
    <description>重要发现或信息</description>
    <when_to_save>发现非显而易见的信息时</when_to_save>
</type>
<type>
    <name>technical</name>
    <description>技术栈或框架信息</description>
    <when_to_save>确认项目使用的技术时</when_to_save>
</type>
<type>
    <name>structure</name>
    <description>项目结构信息</description>
    <when_to_save>发现关键入口或核心文件时</when_to_save>
</type>
</types>

# 不要保存什么到记忆中

- 代码路径、文件名、目录结构 — 可从项目实时获取
- Git 历史、最近更改 — git log/blame 是权威来源
- 临时状态:进行中的任务、当前对话上下文
- 已在 CLAUDE.md/MATRIX.md 中记录的内容
- 错误信息和调试细节 — 问题解决后无需保留

这些排除规则即使当用户要求保存时也适用。
如果他们要求保存临时信息,问:"有什么 surprising 或 non-obvious 的部分?"

# 输出格式

严格 JSON:
{
  "memories": [
    {
      "category": "decision",
      "content": "采用 PostgreSQL 作为主数据库。**Why:** 性能要求和团队经验",
      "importance": 85,
      "keywords": ["PostgreSQL", "数据库", "database"],
      "tags": ["backend", "storage"]
    }
  ]
}

关键词提取:3-5 个核心关键词(技术名词、项目名、关键概念)
标签提取:1-3 个分类标签(backend、frontend、config、auth 等)

只返回 JSON,不要其他解释。"#;
119
120#[async_trait::async_trait]
121impl MemoryExtractor for AiMemoryExtractor {
122    async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
123        use crate::providers::{ChatRequest, Message, MessageContent, Role};
124
125        // Safely truncate to ~4000 chars respecting UTF-8 boundaries
126        let truncated = truncate_chars(text, 4000);
127
128        let request = ChatRequest {
129            messages: vec![Message {
130                role: Role::User,
131                content: MessageContent::Text(format!(
132                    "请从以下对话中提取值得记忆的关键信息:\n\n{}",
133                    truncated
134                )),
135            }],
136            tools: vec![],
137            system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
138            think: false,
139            max_tokens: 512,
140            server_tools: vec![],
141            enable_caching: false,
142        };
143
144        let response = self.provider.chat(request).await?;
145
146        let response_text = response
147            .content
148            .iter()
149            .filter_map(|b| {
150                if let crate::providers::ContentBlock::Text { text } = b {
151                    Some(text.clone())
152                } else {
153                    None
154                }
155            })
156            .collect::<Vec<_>>()
157            .join("");
158
159        parse_memory_response(&response_text, session_id, project_path)
160    }
161
162    fn model_name(&self) -> &str {
163        &self.model
164    }
165}
166
167fn parse_memory_response(json_text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
168    let cleaned = json_text
169        .trim()
170        .trim_start_matches("```json")
171        .trim_start_matches("```")
172        .trim_end_matches("```")
173        .trim();
174
175    #[derive(Deserialize)]
176    struct MemoryResponse {
177        memories: Vec<MemoryItem>,
178    }
179
180    #[derive(Deserialize)]
181    struct MemoryItem {
182        category: String,
183        content: String,
184        #[serde(default)]
185        importance: f64,
186        #[serde(default)]
187        keywords: Vec<String>,
188        #[serde(default)]
189        tags: Vec<String>,
190    }
191
192    let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
193
194    let entries = parsed
195        .memories
196        .into_iter()
197        .filter_map(|item| {
198            let category = match item.category.to_lowercase().as_str() {
199                "decision" => MemoryCategory::Decision,
200                "preference" => MemoryCategory::Preference,
201                "solution" => MemoryCategory::Solution,
202                "finding" => MemoryCategory::Finding,
203                "technical" => MemoryCategory::Technical,
204                "structure" => MemoryCategory::Structure,
205                _ => return None,
206            };
207
208            if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
209                return None;
210            }
211
212            let mut entry =
213                MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()), project_path.map(|p| p.to_string()));
214            if item.importance > 0.0 {
215                entry.importance = item.importance.clamp(0.0, 100.0);
216            }
217            // Add AI-extracted keywords and tags
218            if !item.keywords.is_empty() {
219                entry.tags.extend(item.keywords);
220            }
221            if !item.tags.is_empty() {
222                entry.tags.extend(item.tags);
223            }
224            entry.tags.dedup();
225
226            Some(entry)
227        })
228        .collect();
229
230    Ok(deduplicate_entries(entries))
231}
232
233fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
234    let mut seen: Vec<String> = Vec::new();
235    entries
236        .into_iter()
237        .filter(|e| {
238            let content_lower = e.content.to_lowercase();
239            if seen.iter().any(|s| {
240                AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
241            }) {
242                false
243            } else {
244                seen.push(content_lower);
245                true
246            }
247        })
248        .take(MAX_DETECTED_ENTRIES)
249        .collect()
250}
251
252// ============================================================================
// Rule-based Detection (hard-coded keyword patterns)
254// ============================================================================
255
256/// Detect memories from text using hard-coded patterns.
257pub fn detect_memories_fallback(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
258    let mut entries = Vec::new();
259    let text_lower = text.to_lowercase();
260
261    // Hard-coded patterns for each category
262    let patterns = [
263        (MemoryCategory::Decision, ["决定", "选择", "采用", "定下", "decided", "chose"]),
264        (MemoryCategory::Preference, ["偏好", "习惯", "喜欢", "首选", "prefer", "like"]),
265        (MemoryCategory::Solution, ["解决", "修复", "搞定", "改成", "fixed", "solved"]),
266        (MemoryCategory::Finding, ["发现", "原来", "原因", "定位", "found", "reason"]),
267        (MemoryCategory::Technical, ["技术栈", "框架", "用的", "基于", "stack", "using"]),
268        (MemoryCategory::Structure, ["入口", "主文件", "目录", "位于", "entry", "main"]),
269    ];
270
271    for (category, keywords) in patterns {
272        for keyword in keywords {
273            if text_lower.contains(&keyword.to_lowercase()) {
274                let content = extract_memory_content(text, keyword);
275                if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
276                    entries.push(MemoryEntry::new(
277                        category,
278                        content,
279                        session_id.map(|s| s.to_string()),
280                        project_path.map(|p| p.to_string()),
281                    ));
282                }
283            }
284        }
285    }
286
287    deduplicate_entries(entries)
288}
289
290/// Detect memories from text (wrapper for fallback).
291pub fn detect_memories_from_text(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
292    detect_memories_fallback(text, session_id, project_path)
293}
294
295/// Smart detection: AI-first with rule-based fallback.
296///
297/// Priority order:
298/// 1. AI extraction (if text > 200 chars and extractor available)
299/// 2. Rule-based fallback (if AI fails or text too short)
300pub async fn detect_memories_smart(
301    text: &str,
302    session_id: Option<&str>,
303    project_path: Option<&str>,
304    extractor: Option<&AiMemoryExtractor>,
305) -> Vec<MemoryEntry> {
306    let mode = AiDetectionMode::from_env();
307    let text_len = text.len();
308
309    // Determine if we should try AI first
310    // Only use AI for text > 200 chars (avoid API overhead for short texts)
311    let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200;
312
313    // Debug log: show method and model
314    let model_name = extractor.map(|e| e.model_name()).unwrap_or("none");
315    crate::debug::debug_log().memory_ai_detection(
316        model_name,
317        0, // Will update after detection
318        text_len,
319        should_try_ai,
320    );
321
322    if should_try_ai && let Some(ex) = extractor {
323        if let Ok(ai_entries) = ex.extract(text, session_id, project_path).await {
324            // AI succeeded - use AI results entirely (skip hardcoded rules)
325            // Debug log: AI result
326            crate::debug::debug_log().memory_ai_detection(
327                ex.model_name(),
328                ai_entries.len(),
329                text_len,
330                true,
331            );
332            return deduplicate_entries(ai_entries);
333        }
334        // AI failed - log and skip rule-based fallback (per user request)
335        log::warn!("AI memory extraction failed, skipping detection for this turn");
336        return Vec::new();
337    }
338
339    // For short texts (< 200 chars), skip detection entirely (per user request)
340    // No rule-based fallback
341    Vec::new()
342}
343
344fn extract_memory_content(text: &str, keyword: &str) -> String {
345    let text_lower = text.to_lowercase();
346    let keyword_lower = keyword.to_lowercase();
347
348    let pos = match text_lower.find(&keyword_lower) {
349        Some(p) => p,
350        None => return String::new(),
351    };
352
353    // Find sentence containing the keyword
354    let start = text[..pos]
355        .rfind(['.', '。', '\n'])
356        .map(|i| i + 1)
357        .unwrap_or(0);
358
359    let end = text[pos..]
360        .find(['.', '。', '\n'])
361        .map(|i| pos + i + 1)
362        .unwrap_or(text.len());
363
364    let sentence = text[start..end].trim();
365
366    if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
367        sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
368    } else {
369        sentence.to_string()
370    }
371}
372
373/// Infer category from content.
374pub fn infer_category_from_content(content: &str) -> MemoryCategory {
375    let lower = content.to_lowercase();
376
377    if lower.contains("决定")
378        || lower.contains("选择")
379        || lower.contains("采用")
380        || lower.contains("decided")
381    {
382        return MemoryCategory::Decision;
383    }
384    if lower.contains("喜欢")
385        || lower.contains("偏好")
386        || lower.contains("习惯")
387        || lower.contains("prefer")
388    {
389        return MemoryCategory::Preference;
390    }
391    if lower.contains("解决")
392        || lower.contains("修复")
393        || lower.contains("搞定")
394        || lower.contains("fixed")
395    {
396        return MemoryCategory::Solution;
397    }
398    if lower.contains("发现")
399        || lower.contains("原因")
400        || lower.contains("原来")
401        || lower.contains("found")
402    {
403        return MemoryCategory::Finding;
404    }
405    if lower.contains("技术")
406        || lower.contains("框架")
407        || lower.contains("库")
408        || lower.contains("tech")
409    {
410        return MemoryCategory::Technical;
411    }
412    if lower.contains("文件")
413        || lower.contains("目录")
414        || lower.contains("入口")
415        || lower.contains("file")
416    {
417        return MemoryCategory::Structure;
418    }
419
420    MemoryCategory::Finding // Default
421}