Skip to main content

matrixcode_core/memory/
extractor.rs

1//! Memory extraction: AI-based and rule-based detection.
2
3use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::keywords_config::KeywordsConfig;
9use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
10
11// ============================================================================
12// Memory Extractor Trait
13// ============================================================================
14
15/// Trait for memory extraction implementations.
16#[async_trait::async_trait]
17pub trait MemoryExtractor: Send + Sync {
18    /// Extract memories from conversation text using AI.
19    async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>>;
20
21    /// Get the model name used for extraction.
22    fn model_name(&self) -> &str;
23}
24
25/// AI-based memory extractor using a fast/cheap model.
26pub struct AiMemoryExtractor {
27    provider: Box<dyn crate::providers::Provider>,
28    model: String,
29}
30
31impl AiMemoryExtractor {
32    /// Create a new AI memory extractor.
33    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
34        Self { provider, model }
35    }
36}
37
/// System prompt (written in Chinese) sent with every extraction request.
///
/// It asks the model to pull out information worth remembering long-term and
/// to answer with strict JSON only, shaped as
/// `{"memories": [{"category", "content", "importance", "keywords", "tags"}]}`.
/// The six category names it lists (`decision`, `preference`, `solution`,
/// `finding`, `technical`, `structure`) are the same strings that
/// `parse_memory_response` accepts.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。从对话中提取值得长期记忆的关键信息。

记忆类型:
- decision: 项目或技术选型的决定
- preference: 用户习惯或偏好
- solution: 解决问题的具体方法
- finding: 重要发现或信息
- technical: 技术栈或框架信息
- structure: 项目结构信息

输出格式(严格 JSON):
{
  "memories": [
    {
      "category": "decision",
      "content": "采用 PostgreSQL 作为主数据库",
      "importance": 85,
      "keywords": ["PostgreSQL", "数据库", "database"],
      "tags": ["backend", "storage"]
    }
  ]
}

关键词提取要求:
- 提取 3-5 个核心关键词(技术名词、项目名、关键概念)
- 中英文关键词都提取
- 用于后续记忆检索匹配

标签提取要求:
- 提取 1-3 个分类标签(如 backend、frontend、config、auth 等)
- 用于记忆分类筛选

只返回 JSON,不要其他解释。"#;
71
72#[async_trait::async_trait]
73impl MemoryExtractor for AiMemoryExtractor {
74    async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
75        use crate::providers::{ChatRequest, Message, MessageContent, Role};
76
77        // Safely truncate to ~4000 chars respecting UTF-8 boundaries
78        let truncated = truncate_chars(text, 4000);
79
80        let request = ChatRequest {
81            messages: vec![Message {
82                role: Role::User,
83                content: MessageContent::Text(format!(
84                    "请从以下对话中提取值得记忆的关键信息:\n\n{}",
85                    truncated
86                )),
87            }],
88            tools: vec![],
89            system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
90            think: false,
91            max_tokens: 512,
92            server_tools: vec![],
93            enable_caching: false,
94        };
95
96        let response = self.provider.chat(request).await?;
97
98        let response_text = response
99            .content
100            .iter()
101            .filter_map(|b| {
102                if let crate::providers::ContentBlock::Text { text } = b {
103                    Some(text.clone())
104                } else {
105                    None
106                }
107            })
108            .collect::<Vec<_>>()
109            .join("");
110
111        parse_memory_response(&response_text, session_id, project_path)
112    }
113
114    fn model_name(&self) -> &str {
115        &self.model
116    }
117}
118
119fn parse_memory_response(json_text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
120    let cleaned = json_text
121        .trim()
122        .trim_start_matches("```json")
123        .trim_start_matches("```")
124        .trim_end_matches("```")
125        .trim();
126
127    #[derive(Deserialize)]
128    struct MemoryResponse {
129        memories: Vec<MemoryItem>,
130    }
131
132    #[derive(Deserialize)]
133    struct MemoryItem {
134        category: String,
135        content: String,
136        #[serde(default)]
137        importance: f64,
138        #[serde(default)]
139        keywords: Vec<String>,
140        #[serde(default)]
141        tags: Vec<String>,
142    }
143
144    let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
145
146    let entries = parsed
147        .memories
148        .into_iter()
149        .filter_map(|item| {
150            let category = match item.category.to_lowercase().as_str() {
151                "decision" => MemoryCategory::Decision,
152                "preference" => MemoryCategory::Preference,
153                "solution" => MemoryCategory::Solution,
154                "finding" => MemoryCategory::Finding,
155                "technical" => MemoryCategory::Technical,
156                "structure" => MemoryCategory::Structure,
157                _ => return None,
158            };
159
160            if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
161                return None;
162            }
163
164            let mut entry =
165                MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()), project_path.map(|p| p.to_string()));
166            if item.importance > 0.0 {
167                entry.importance = item.importance.clamp(0.0, 100.0);
168            }
169            // Add AI-extracted keywords and tags
170            if !item.keywords.is_empty() {
171                entry.tags.extend(item.keywords);
172            }
173            if !item.tags.is_empty() {
174                entry.tags.extend(item.tags);
175            }
176            entry.tags.dedup();
177
178            Some(entry)
179        })
180        .collect();
181
182    Ok(deduplicate_entries(entries))
183}
184
185fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
186    let mut seen: Vec<String> = Vec::new();
187    entries
188        .into_iter()
189        .filter(|e| {
190            let content_lower = e.content.to_lowercase();
191            if seen.iter().any(|s| {
192                AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
193            }) {
194                false
195            } else {
196                seen.push(content_lower);
197                true
198            }
199        })
200        .take(MAX_DETECTED_ENTRIES)
201        .collect()
202}
203
204// ============================================================================
205// Rule-based Detection (uses KeywordsConfig)
206// ============================================================================
207
208/// Detect memories from text using configurable patterns.
209pub fn detect_memories_fallback(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
210    let config = KeywordsConfig::load();
211    let mut entries = Vec::new();
212    let text_lower = text.to_lowercase();
213
214    let categories = [
215        (MemoryCategory::Decision, "decision"),
216        (MemoryCategory::Preference, "preference"),
217        (MemoryCategory::Solution, "solution"),
218        (MemoryCategory::Finding, "finding"),
219        (MemoryCategory::Technical, "technical"),
220        (MemoryCategory::Structure, "structure"),
221    ];
222
223    for (category, key) in categories {
224        let patterns = config
225            .patterns
226            .get(key)
227            .map(|v| v.as_slice())
228            .unwrap_or(&[]);
229        for keyword in patterns {
230            if text_lower.contains(&keyword.to_lowercase()) {
231                let content = extract_memory_content(text, keyword);
232                if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
233                    entries.push(MemoryEntry::new(
234                        category,
235                        content,
236                        session_id.map(|s| s.to_string()),
237                        project_path.map(|p| p.to_string()),
238                    ));
239                }
240            }
241        }
242    }
243
244    deduplicate_entries(entries)
245}
246
247/// Detect memories from text (wrapper for fallback).
248pub fn detect_memories_from_text(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
249    detect_memories_fallback(text, session_id, project_path)
250}
251
252/// Smart detection: AI-first with rule-based fallback.
253///
254/// Priority order:
255/// 1. AI extraction (if text > 200 chars and extractor available)
256/// 2. Rule-based fallback (if AI fails or text too short)
257pub async fn detect_memories_smart(
258    text: &str,
259    session_id: Option<&str>,
260    project_path: Option<&str>,
261    extractor: Option<&AiMemoryExtractor>,
262) -> Vec<MemoryEntry> {
263    let mode = AiDetectionMode::from_env();
264    let text_len = text.len();
265
266    // Determine if we should try AI first
267    // Only use AI for text > 200 chars (avoid API overhead for short texts)
268    let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200;
269
270    // Debug log: show method and model
271    let model_name = extractor.map(|e| e.model_name()).unwrap_or("none");
272    crate::debug::debug_log().memory_ai_detection(
273        model_name,
274        0, // Will update after detection
275        text_len,
276        should_try_ai,
277    );
278
279    if should_try_ai && let Some(ex) = extractor {
280        if let Ok(ai_entries) = ex.extract(text, session_id, project_path).await {
281            // AI succeeded - use AI results entirely (skip hardcoded rules)
282            // Debug log: AI result
283            crate::debug::debug_log().memory_ai_detection(
284                ex.model_name(),
285                ai_entries.len(),
286                text_len,
287                true,
288            );
289            return deduplicate_entries(ai_entries);
290        }
291        // AI failed - log and skip rule-based fallback (per user request)
292        log::warn!("AI memory extraction failed, skipping detection for this turn");
293        return Vec::new();
294    }
295
296    // For short texts (< 200 chars), skip detection entirely (per user request)
297    // No rule-based fallback
298    Vec::new()
299}
300
301fn extract_memory_content(text: &str, keyword: &str) -> String {
302    let text_lower = text.to_lowercase();
303    let keyword_lower = keyword.to_lowercase();
304
305    let pos = match text_lower.find(&keyword_lower) {
306        Some(p) => p,
307        None => return String::new(),
308    };
309
310    // Find sentence containing the keyword
311    let start = text[..pos]
312        .rfind(['.', '。', '\n'])
313        .map(|i| i + 1)
314        .unwrap_or(0);
315
316    let end = text[pos..]
317        .find(['.', '。', '\n'])
318        .map(|i| pos + i + 1)
319        .unwrap_or(text.len());
320
321    let sentence = text[start..end].trim();
322
323    if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
324        sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
325    } else {
326        sentence.to_string()
327    }
328}
329
330/// Infer category from content.
331pub fn infer_category_from_content(content: &str) -> MemoryCategory {
332    let lower = content.to_lowercase();
333
334    if lower.contains("决定")
335        || lower.contains("选择")
336        || lower.contains("采用")
337        || lower.contains("decided")
338    {
339        return MemoryCategory::Decision;
340    }
341    if lower.contains("喜欢")
342        || lower.contains("偏好")
343        || lower.contains("习惯")
344        || lower.contains("prefer")
345    {
346        return MemoryCategory::Preference;
347    }
348    if lower.contains("解决")
349        || lower.contains("修复")
350        || lower.contains("搞定")
351        || lower.contains("fixed")
352    {
353        return MemoryCategory::Solution;
354    }
355    if lower.contains("发现")
356        || lower.contains("原因")
357        || lower.contains("原来")
358        || lower.contains("found")
359    {
360        return MemoryCategory::Finding;
361    }
362    if lower.contains("技术")
363        || lower.contains("框架")
364        || lower.contains("库")
365        || lower.contains("tech")
366    {
367        return MemoryCategory::Technical;
368    }
369    if lower.contains("文件")
370        || lower.contains("目录")
371        || lower.contains("入口")
372        || lower.contains("file")
373    {
374        return MemoryCategory::Structure;
375    }
376
377    MemoryCategory::Finding // Default
378}