Skip to main content

matrixcode_core/memory/
extractor.rs

1//! Memory extraction: AI-based and rule-based detection.
2
3use anyhow::Result;
4use serde::Deserialize;
5use crate::truncate::truncate_chars;
6
7use super::config::*;
8use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
9
10// ============================================================================
11// Memory Extractor Trait
12// ============================================================================
13
/// Trait for memory extraction implementations.
///
/// Implementors must be `Send + Sync`. The `async_trait` attribute is what allows the
/// `async fn` in this trait definition.
#[async_trait::async_trait]
pub trait MemoryExtractor: Send + Sync {
    /// Extract memories from conversation text using AI.
    ///
    /// * `text` - the conversation text to analyse.
    /// * `session_id` - optional session identifier; the implementation in this file
    ///   forwards it to every `MemoryEntry` it creates.
    ///
    /// Returns the extracted entries, or an error if extraction fails.
    async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>>;

    /// Get the model name used for extraction.
    fn model_name(&self) -> &str;
}
23
/// AI-based memory extractor using a fast/cheap model.
pub struct AiMemoryExtractor {
    /// Chat provider that receives the extraction request.
    provider: Box<dyn crate::providers::Provider>,
    /// Model name reported by `model_name()`.
    // NOTE(review): the request built in `extract` carries no model field, so this string
    // appears to be informational only — confirm the provider is configured separately.
    model: String,
}

impl AiMemoryExtractor {
    /// Create a new AI memory extractor.
    ///
    /// Takes ownership of the boxed `provider` and the `model` name and stores both as-is.
    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
        Self { provider, model }
    }
}
36
/// System prompt (written in Chinese) sent with every extraction request.
///
/// It tells the model to pick out information worth remembering long-term, lists the six
/// memory categories (decision, preference, solution, finding, technical, structure) and
/// requires a strict JSON reply shaped like
/// `{"memories": [{"category": ..., "content": ..., "importance": ...}]}`.
/// This is runtime text consumed by the model, so it must not be reworded casually.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。你的任务是从对话中识别并提取值得长期记忆的关键信息。

记忆类型:
1. Decision(决策): 项目或技术选型的决定
2. Preference(偏好): 用户习惯或偏好
3. Solution(解决方案): 解决问题的具体方法
4. Finding(发现): 重要发现或信息
5. Technical(技术): 技术栈或框架信息
6. Structure(结构): 项目结构信息

输出格式(严格 JSON):
{"memories": [{"category": "decision", "content": "...", "importance": 90}]}
"#;
50
51#[async_trait::async_trait]
52impl MemoryExtractor for AiMemoryExtractor {
53    async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
54        use crate::providers::{ChatRequest, Message, MessageContent, Role};
55
56        // Safely truncate to ~4000 chars respecting UTF-8 boundaries
57        let truncated = truncate_chars(text, 4000);
58
59        let request = ChatRequest {
60            messages: vec![Message {
61                role: Role::User,
62                content: MessageContent::Text(format!(
63                    "请从以下对话中提取值得记忆的关键信息:\n\n{}",
64                    truncated
65                )),
66            }],
67            tools: vec![],
68            system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
69            think: false,
70            max_tokens: 512,
71            server_tools: vec![],
72            enable_caching: false,
73        };
74
75        let response = self.provider.chat(request).await?;
76
77        let response_text = response
78            .content
79            .iter()
80            .filter_map(|b| {
81                if let crate::providers::ContentBlock::Text { text } = b {
82                    Some(text.clone())
83                } else {
84                    None
85                }
86            })
87            .collect::<Vec<_>>()
88            .join("");
89
90        parse_memory_response(&response_text, session_id)
91    }
92
93    fn model_name(&self) -> &str {
94        &self.model
95    }
96}
97
98fn parse_memory_response(json_text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
99    let cleaned = json_text
100        .trim()
101        .trim_start_matches("```json")
102        .trim_start_matches("```")
103        .trim_end_matches("```")
104        .trim();
105
106    #[derive(Deserialize)]
107    struct MemoryResponse {
108        memories: Vec<MemoryItem>,
109    }
110
111    #[derive(Deserialize)]
112    struct MemoryItem {
113        category: String,
114        content: String,
115        #[serde(default)]
116        importance: f64,
117    }
118
119    let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
120
121    let entries = parsed
122        .memories
123        .into_iter()
124        .filter_map(|item| {
125            let category = match item.category.to_lowercase().as_str() {
126                "decision" => MemoryCategory::Decision,
127                "preference" => MemoryCategory::Preference,
128                "solution" => MemoryCategory::Solution,
129                "finding" => MemoryCategory::Finding,
130                "technical" => MemoryCategory::Technical,
131                "structure" => MemoryCategory::Structure,
132                _ => return None,
133            };
134
135            if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
136                return None;
137            }
138
139            let mut entry =
140                MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()));
141            if item.importance > 0.0 {
142                entry.importance = item.importance.clamp(0.0, 100.0);
143            }
144
145            Some(entry)
146        })
147        .collect();
148
149    Ok(deduplicate_entries(entries))
150}
151
152fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
153    let mut seen: Vec<String> = Vec::new();
154    entries
155        .into_iter()
156        .filter(|e| {
157            let content_lower = e.content.to_lowercase();
158            if seen.iter().any(|s| {
159                AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
160            }) {
161                false
162            } else {
163                seen.push(content_lower);
164                true
165            }
166        })
167        .take(MAX_DETECTED_ENTRIES)
168        .collect()
169}
170
171// ============================================================================
172// Rule-based Detection
173// ============================================================================
174
175/// Detect memories from text using rule-based patterns.
176pub fn detect_memories_fallback(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
177    let mut entries = Vec::new();
178    let text_lower = text.to_lowercase();
179
180    let patterns: Vec<(MemoryCategory, Vec<&str>)> = vec![
181        (
182            MemoryCategory::Decision,
183            vec![
184                "最终决定",
185                "决定采用",
186                "我们决定",
187                "选择使用",
188                "采用方案",
189                "定下来",
190                "就定这个",
191                "敲定",
192                "拍板",
193                "we decided",
194                "final decision",
195            ],
196        ),
197        (
198            MemoryCategory::Preference,
199            vec![
200                "我喜欢",
201                "我偏好",
202                "我习惯",
203                "最常用",
204                "一直用",
205                "推荐",
206                "建议使用",
207                "首选",
208                "i like",
209                "i prefer",
210            ],
211        ),
212        (
213            MemoryCategory::Solution,
214            vec![
215                "通过修改",
216                "解决方案是",
217                "搞定",
218                "解决了",
219                "修复成功",
220                "改成",
221                "优化了",
222                "fixed by",
223                "solved by",
224            ],
225        ),
226        (
227            MemoryCategory::Finding,
228            vec![
229                "发现",
230                "注意到",
231                "原来",
232                "找到问题",
233                "定位到",
234                "排查发现",
235                "原因是",
236                "found that",
237                "discovered",
238            ],
239        ),
240        (
241            MemoryCategory::Technical,
242            vec![
243                "技术栈是",
244                "框架使用",
245                "用的是",
246                "基于",
247                "tech stack",
248                "using framework",
249                "built with",
250            ],
251        ),
252        (
253            MemoryCategory::Structure,
254            vec![
255                "入口文件是",
256                "主文件位于",
257                "项目结构是",
258                "入口是",
259                "目录是",
260                "entry point",
261                "main file",
262            ],
263        ),
264    ];
265
266    for (category, keywords) in patterns {
267        for keyword in keywords {
268            if text_lower.contains(keyword) {
269                let content = extract_memory_content(text, keyword);
270                if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
271                    entries.push(MemoryEntry::new(
272                        category,
273                        content,
274                        session_id.map(|s| s.to_string()),
275                    ));
276                }
277            }
278        }
279    }
280
281    deduplicate_entries(entries)
282}
283
/// Detect memories from text (wrapper for fallback).
///
/// Forwards both arguments unchanged to `detect_memories_fallback` and returns its result;
/// no AI extraction is involved on this path.
pub fn detect_memories_from_text(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
    detect_memories_fallback(text, session_id)
}
288
289/// Smart detection: rule-based + AI fallback.
290pub async fn detect_memories_smart(
291    text: &str,
292    session_id: Option<&str>,
293    extractor: Option<&AiMemoryExtractor>,
294) -> Vec<MemoryEntry> {
295    // First try rule-based
296    let rule_entries = detect_memories_fallback(text, session_id);
297
298    // Check if we need AI fallback
299    let mode = AiDetectionMode::from_env();
300    if mode.should_use_ai_for_text(text.len())
301        && extractor.is_some()
302        && let Some(ex) = extractor
303        && let Ok(ai_entries) = ex.extract(text, session_id).await
304    {
305        // Combine and deduplicate
306        let combined = rule_entries.into_iter().chain(ai_entries).collect();
307        return deduplicate_entries(combined);
308    }
309
310    rule_entries
311}
312
313fn extract_memory_content(text: &str, keyword: &str) -> String {
314    let text_lower = text.to_lowercase();
315    let keyword_lower = keyword.to_lowercase();
316
317    let pos = match text_lower.find(&keyword_lower) {
318        Some(p) => p,
319        None => return String::new(),
320    };
321
322    // Find sentence containing the keyword
323    let start = text[..pos]
324        .rfind(['.', '。', '\n'])
325        .map(|i| i + 1)
326        .unwrap_or(0);
327
328    let end = text[pos..]
329        .find(['.', '。', '\n'])
330        .map(|i| pos + i + 1)
331        .unwrap_or(text.len());
332
333    let sentence = text[start..end].trim();
334
335    if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
336        sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
337    } else {
338        sentence.to_string()
339    }
340}
341
342/// Infer category from content.
343pub fn infer_category_from_content(content: &str) -> MemoryCategory {
344    let lower = content.to_lowercase();
345
346    if lower.contains("决定")
347        || lower.contains("选择")
348        || lower.contains("采用")
349        || lower.contains("decided")
350    {
351        return MemoryCategory::Decision;
352    }
353    if lower.contains("喜欢")
354        || lower.contains("偏好")
355        || lower.contains("习惯")
356        || lower.contains("prefer")
357    {
358        return MemoryCategory::Preference;
359    }
360    if lower.contains("解决")
361        || lower.contains("修复")
362        || lower.contains("搞定")
363        || lower.contains("fixed")
364    {
365        return MemoryCategory::Solution;
366    }
367    if lower.contains("发现")
368        || lower.contains("原因")
369        || lower.contains("原来")
370        || lower.contains("found")
371    {
372        return MemoryCategory::Finding;
373    }
374    if lower.contains("技术")
375        || lower.contains("框架")
376        || lower.contains("库")
377        || lower.contains("tech")
378    {
379        return MemoryCategory::Technical;
380    }
381    if lower.contains("文件")
382        || lower.contains("目录")
383        || lower.contains("入口")
384        || lower.contains("file")
385    {
386        return MemoryCategory::Structure;
387    }
388
389    MemoryCategory::Finding // Default
390}