Skip to main content

matrixcode_core/memory/
extractor.rs

1//! Memory extraction: AI-based and rule-based detection.
2
3use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::entry::{MemoryCategory, MemoryEntry};
9use super::manager::AutoMemory;
10
11// ============================================================================
12// Memory Extractor Trait
13// ============================================================================
14
/// Trait for memory extraction implementations.
///
/// Implementors must be `Send + Sync`. The `async_trait` attribute is what
/// allows the trait to declare an `async fn`.
#[async_trait::async_trait]
pub trait MemoryExtractor: Send + Sync {
    /// Extract memories from conversation text using AI.
    ///
    /// # Arguments
    /// * `text` - conversation text to analyse.
    /// * `session_id` - optional session identifier; the implementation in
    ///   this file forwards it to every entry it creates.
    /// * `project_path` - optional project path, forwarded the same way.
    ///
    /// # Errors
    /// Returns an error when extraction fails; the concrete failure modes
    /// depend on the implementor.
    async fn extract(
        &self,
        text: &str,
        session_id: Option<&str>,
        project_path: Option<&str>,
    ) -> Result<Vec<MemoryEntry>>;

    /// Get the model name used for extraction.
    fn model_name(&self) -> &str;
}
29
/// AI-based memory extractor using a fast/cheap model.
///
/// Pairs a chat provider with the model name reported by `model_name()`.
pub struct AiMemoryExtractor {
    /// Provider used to send the extraction chat request.
    provider: Box<dyn crate::providers::Provider>,
    /// Model name; returned by `model_name()` and also used by
    /// `new_minimal` to build the provider.
    model: String,
}
35
impl AiMemoryExtractor {
    /// Create a new AI memory extractor from an explicit provider and model name.
    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
        Self { provider, model }
    }

    /// Create an extractor whose provider is built from the model name alone
    /// via `crate::create_minimal_provider`.
    ///
    /// Intended for non-blocking background extraction.
    ///
    /// NOTE(review): an earlier version of this doc mentioned a "simplified
    /// prompt", but nothing in this constructor selects a prompt — confirm
    /// whether `create_minimal_provider` changes prompting.
    pub fn new_minimal(model: String) -> Self {
        // The provider is derived from the model name only; presumably it
        // reads the remaining settings from global config — TODO confirm.
        Self {
            provider: crate::create_minimal_provider(&model),
            model,
        }
    }
}
53
/// System prompt (in Chinese) sent with every AI extraction request.
///
/// It lists the six memory types (`decision`, `preference`, `solution`,
/// `finding`, `technical`, `structure`), states what must not be saved, and
/// requires a strict-JSON reply of the form `{"memories": [...]}` whose items
/// carry `category`, `content`, `importance`, `keywords` and `tags`.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是记忆提取助手。从对话中提取值得长期记忆的关键信息。

# 记忆类型

<types>
<type>
    <name>decision</name>
    <description>项目或技术选型的决定</description>
    <when_to_save>用户明确做出技术决策时</when_to_save>
    <body_structure>先写决策内容,然后 **Why:** 决策原因,**Context:** 适用场景</body_structure>
</type>
<type>
    <name>preference</name>
    <description>用户习惯或偏好</description>
    <when_to_save>用户表达"我喜欢/习惯/偏好"时</when_to_save>
    <body_structure>先写偏好内容,然后 **Why:** 偏好原因(如有)</body_structure>
</type>
<type>
    <name>solution</name>
    <description>解决问题的具体方法</description>
    <when_to_save>问题成功解决且方法可复用时</when_to_save>
    <body_structure>先写解决方案,然后 **Problem:** 解决的问题,**Key:** 关键步骤</body_structure>
</type>
<type>
    <name>finding</name>
    <description>重要发现或信息</description>
    <when_to_save>发现非显而易见的信息时</when_to_save>
</type>
<type>
    <name>technical</name>
    <description>技术栈或框架信息</description>
    <when_to_save>确认项目使用的技术时</when_to_save>
</type>
<type>
    <name>structure</name>
    <description>项目结构信息</description>
    <when_to_save>发现关键入口或核心文件时</when_to_save>
</type>
</types>

# 不要保存什么到记忆中

- 代码路径、文件名、目录结构 — 可从项目实时获取
- Git 历史、最近更改 — git log/blame 是权威来源
- 临时状态:进行中的任务、当前对话上下文
- 已在 CLAUDE.md/MATRIX.md 中记录的内容
- 错误信息和调试细节 — 问题解决后无需保留

这些排除规则即使当用户要求保存时也适用。
如果他们要求保存临时信息,问:"有什么 surprising 或 non-obvious 的部分?"

# 输出格式

严格 JSON:
{
  "memories": [
    {
      "category": "decision",
      "content": "采用 PostgreSQL 作为主数据库。**Why:** 性能要求和团队经验",
      "importance": 85,
      "keywords": ["PostgreSQL", "数据库", "database"],
      "tags": ["backend", "storage"]
    }
  ]
}

关键词提取:3-5 个核心关键词(技术名词、项目名、关键概念)
标签提取:1-3 个分类标签(backend、frontend、config、auth 等)

只返回 JSON,不要其他解释。"#;
124
125#[async_trait::async_trait]
126impl MemoryExtractor for AiMemoryExtractor {
127    async fn extract(
128        &self,
129        text: &str,
130        session_id: Option<&str>,
131        project_path: Option<&str>,
132    ) -> Result<Vec<MemoryEntry>> {
133        use crate::providers::{ChatRequest, Message, MessageContent, Role};
134
135        // Safely truncate to ~4000 chars respecting UTF-8 boundaries
136        let truncated = truncate_chars(text, 4000);
137
138        let request = ChatRequest {
139            messages: vec![Message {
140                role: Role::User,
141                content: MessageContent::Text(format!(
142                    "请从以下对话中提取值得记忆的关键信息:\n\n{}",
143                    truncated
144                )),
145            }],
146            tools: vec![],
147            system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
148            think: false,
149            max_tokens: 512,
150            server_tools: vec![],
151            enable_caching: false,
152        };
153
154        let response = self.provider.chat(request).await?;
155
156        let response_text = response
157            .content
158            .iter()
159            .filter_map(|b| {
160                if let crate::providers::ContentBlock::Text { text } = b {
161                    Some(text.clone())
162                } else {
163                    None
164                }
165            })
166            .collect::<Vec<_>>()
167            .join("");
168
169        parse_memory_response(&response_text, session_id, project_path)
170    }
171
172    fn model_name(&self) -> &str {
173        &self.model
174    }
175}
176
177fn parse_memory_response(
178    json_text: &str,
179    session_id: Option<&str>,
180    project_path: Option<&str>,
181) -> Result<Vec<MemoryEntry>> {
182    let cleaned = json_text
183        .trim()
184        .trim_start_matches("```json")
185        .trim_start_matches("```")
186        .trim_end_matches("```")
187        .trim();
188
189    #[derive(Deserialize)]
190    struct MemoryResponse {
191        memories: Vec<MemoryItem>,
192    }
193
194    #[derive(Deserialize)]
195    struct MemoryItem {
196        category: String,
197        content: String,
198        #[serde(default)]
199        importance: f64,
200        #[serde(default)]
201        keywords: Vec<String>,
202        #[serde(default)]
203        tags: Vec<String>,
204    }
205
206    let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
207
208    let entries = parsed
209        .memories
210        .into_iter()
211        .filter_map(|item| {
212            let category = match item.category.to_lowercase().as_str() {
213                "decision" => MemoryCategory::Decision,
214                "preference" => MemoryCategory::Preference,
215                "solution" => MemoryCategory::Solution,
216                "finding" => MemoryCategory::Finding,
217                "technical" => MemoryCategory::Technical,
218                "structure" => MemoryCategory::Structure,
219                _ => return None,
220            };
221
222            if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
223                return None;
224            }
225
226            let mut entry = MemoryEntry::new(
227                category,
228                item.content,
229                session_id.map(|s| s.to_string()),
230                project_path.map(|p| p.to_string()),
231            );
232            if item.importance > 0.0 {
233                entry.importance = item.importance.clamp(0.0, 100.0);
234            }
235            // Add AI-extracted keywords and tags
236            if !item.keywords.is_empty() {
237                entry.tags.extend(item.keywords);
238            }
239            if !item.tags.is_empty() {
240                entry.tags.extend(item.tags);
241            }
242            entry.tags.dedup();
243
244            Some(entry)
245        })
246        .collect();
247
248    Ok(deduplicate_entries(entries))
249}
250
251fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
252    let mut seen: Vec<String> = Vec::new();
253    entries
254        .into_iter()
255        .filter(|e| {
256            let content_lower = e.content.to_lowercase();
257            if seen.iter().any(|s| {
258                AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
259            }) {
260                false
261            } else {
262                seen.push(content_lower);
263                true
264            }
265        })
266        .take(MAX_DETECTED_ENTRIES)
267        .collect()
268}
269
270// ============================================================================
// Rule-based Detection (hard-coded keyword patterns)
272// ============================================================================
273
274/// Detect memories from text using hard-coded patterns.
275pub fn detect_memories_fallback(
276    text: &str,
277    session_id: Option<&str>,
278    project_path: Option<&str>,
279) -> Vec<MemoryEntry> {
280    let mut entries = Vec::new();
281    let text_lower = text.to_lowercase();
282
283    // Hard-coded patterns for each category
284    let patterns = [
285        (
286            MemoryCategory::Decision,
287            ["决定", "选择", "采用", "定下", "decided", "chose"],
288        ),
289        (
290            MemoryCategory::Preference,
291            ["偏好", "习惯", "喜欢", "首选", "prefer", "like"],
292        ),
293        (
294            MemoryCategory::Solution,
295            ["解决", "修复", "搞定", "改成", "fixed", "solved"],
296        ),
297        (
298            MemoryCategory::Finding,
299            ["发现", "原来", "原因", "定位", "found", "reason"],
300        ),
301        (
302            MemoryCategory::Technical,
303            ["技术栈", "框架", "用的", "基于", "stack", "using"],
304        ),
305        (
306            MemoryCategory::Structure,
307            ["入口", "主文件", "目录", "位于", "entry", "main"],
308        ),
309    ];
310
311    for (category, keywords) in patterns {
312        for keyword in keywords {
313            if text_lower.contains(&keyword.to_lowercase()) {
314                let content = extract_memory_content(text, keyword);
315                if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
316                    entries.push(MemoryEntry::new(
317                        category,
318                        content,
319                        session_id.map(|s| s.to_string()),
320                        project_path.map(|p| p.to_string()),
321                    ));
322                }
323            }
324        }
325    }
326
327    deduplicate_entries(entries)
328}
329
/// Detect memories from text (wrapper for fallback).
///
/// Forwards all three arguments unchanged to `detect_memories_fallback` and
/// returns its result.
pub fn detect_memories_from_text(
    text: &str,
    session_id: Option<&str>,
    project_path: Option<&str>,
) -> Vec<MemoryEntry> {
    detect_memories_fallback(text, session_id, project_path)
}
338
339/// Smart detection: AI-first with rule-based fallback.
340///
341/// Priority order:
342/// 1. AI extraction (if text > 200 chars and extractor available)
343/// 2. Rule-based fallback (if AI fails or text too short)
344pub async fn detect_memories_smart(
345    text: &str,
346    session_id: Option<&str>,
347    project_path: Option<&str>,
348    extractor: Option<&AiMemoryExtractor>,
349) -> Vec<MemoryEntry> {
350    let mode = AiDetectionMode::from_env();
351    let text_len = text.len();
352
353    // Determine if we should try AI first
354    // Only use AI for text > 200 chars (avoid API overhead for short texts)
355    let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200;
356
357    // Debug log: show method and model
358    let model_name = extractor.map(|e| e.model_name()).unwrap_or("none");
359    crate::debug::debug_log().memory_ai_detection(
360        model_name,
361        0, // Will update after detection
362        text_len,
363        should_try_ai,
364    );
365
366    if should_try_ai && let Some(ex) = extractor {
367        if let Ok(ai_entries) = ex.extract(text, session_id, project_path).await {
368            // AI succeeded - use AI results entirely (skip hardcoded rules)
369            // Debug log: AI result
370            crate::debug::debug_log().memory_ai_detection(
371                ex.model_name(),
372                ai_entries.len(),
373                text_len,
374                true,
375            );
376            return deduplicate_entries(ai_entries);
377        }
378        // AI failed - log and skip rule-based fallback (per user request)
379        log::warn!("AI memory extraction failed, skipping detection for this turn");
380        return Vec::new();
381    }
382
383    // For short texts (< 200 chars), skip detection entirely (per user request)
384    // No rule-based fallback
385    Vec::new()
386}
387
388fn extract_memory_content(text: &str, keyword: &str) -> String {
389    let text_lower = text.to_lowercase();
390    let keyword_lower = keyword.to_lowercase();
391
392    let pos = match text_lower.find(&keyword_lower) {
393        Some(p) => p,
394        None => return String::new(),
395    };
396
397    // Find sentence containing the keyword
398    let start = text[..pos]
399        .rfind(['.', '。', '\n'])
400        .map(|i| i + 1)
401        .unwrap_or(0);
402
403    let end = text[pos..]
404        .find(['.', '。', '\n'])
405        .map(|i| pos + i + 1)
406        .unwrap_or(text.len());
407
408    let sentence = text[start..end].trim();
409
410    if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
411        sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
412    } else {
413        sentence.to_string()
414    }
415}
416
417/// Infer category from content.
418pub fn infer_category_from_content(content: &str) -> MemoryCategory {
419    let lower = content.to_lowercase();
420
421    if lower.contains("决定")
422        || lower.contains("选择")
423        || lower.contains("采用")
424        || lower.contains("decided")
425    {
426        return MemoryCategory::Decision;
427    }
428    if lower.contains("喜欢")
429        || lower.contains("偏好")
430        || lower.contains("习惯")
431        || lower.contains("prefer")
432    {
433        return MemoryCategory::Preference;
434    }
435    if lower.contains("解决")
436        || lower.contains("修复")
437        || lower.contains("搞定")
438        || lower.contains("fixed")
439    {
440        return MemoryCategory::Solution;
441    }
442    if lower.contains("发现")
443        || lower.contains("原因")
444        || lower.contains("原来")
445        || lower.contains("found")
446    {
447        return MemoryCategory::Finding;
448    }
449    if lower.contains("技术")
450        || lower.contains("框架")
451        || lower.contains("库")
452        || lower.contains("tech")
453    {
454        return MemoryCategory::Technical;
455    }
456    if lower.contains("文件")
457        || lower.contains("目录")
458        || lower.contains("入口")
459        || lower.contains("file")
460    {
461        return MemoryCategory::Structure;
462    }
463
464    MemoryCategory::Finding // Default
465}