Skip to main content

matrixcode_core/memory/
extractor.rs

1//! Memory extraction: AI-based and rule-based detection.
2
3use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::keywords_config::KeywordsConfig;
9use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
10
11// ============================================================================
12// Memory Extractor Trait
13// ============================================================================
14
/// Trait for memory extraction implementations.
///
/// An extractor turns raw conversation text into [`MemoryEntry`] values. The
/// trait bound requires implementations to be `Send + Sync`.
#[async_trait::async_trait]
pub trait MemoryExtractor: Send + Sync {
    /// Extract memories from conversation text using AI.
    ///
    /// `session_id`, when given, is an optional session identifier; the AI
    /// implementation in this module passes it to `MemoryEntry::new` for every
    /// produced entry.
    ///
    /// # Errors
    /// Returns an error when extraction fails (for the AI implementation:
    /// provider errors or an unparseable model reply).
    async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>>;

    /// Get the model name used for extraction.
    fn model_name(&self) -> &str;
}
24
25/// AI-based memory extractor using a fast/cheap model.
26pub struct AiMemoryExtractor {
27    provider: Box<dyn crate::providers::Provider>,
28    model: String,
29}
30
31impl AiMemoryExtractor {
32    /// Create a new AI memory extractor.
33    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
34        Self { provider, model }
35    }
36}
37
/// System prompt sent with every [`AiMemoryExtractor`] request (written in Chinese).
///
/// It lists the six categories that `parse_memory_response` accepts
/// (decision, preference, solution, finding, technical, structure). It asks for
/// a strict-JSON reply of the form
/// `{"memories": [{category, content, importance, keywords, tags}]}` with
/// 3-5 bilingual keywords and 1-3 classification tags per memory, and no
/// surrounding explanation.
///
/// NOTE(review): the prompt only shows `85` as an example `importance` and never
/// states the range; the parser clamps positive values to `0..=100`.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。从对话中提取值得长期记忆的关键信息。

记忆类型:
- decision: 项目或技术选型的决定
- preference: 用户习惯或偏好
- solution: 解决问题的具体方法
- finding: 重要发现或信息
- technical: 技术栈或框架信息
- structure: 项目结构信息

输出格式(严格 JSON):
{
  "memories": [
    {
      "category": "decision",
      "content": "采用 PostgreSQL 作为主数据库",
      "importance": 85,
      "keywords": ["PostgreSQL", "数据库", "database"],
      "tags": ["backend", "storage"]
    }
  ]
}

关键词提取要求:
- 提取 3-5 个核心关键词(技术名词、项目名、关键概念)
- 中英文关键词都提取
- 用于后续记忆检索匹配

标签提取要求:
- 提取 1-3 个分类标签(如 backend、frontend、config、auth 等)
- 用于记忆分类筛选

只返回 JSON,不要其他解释。"#;
71
72#[async_trait::async_trait]
73impl MemoryExtractor for AiMemoryExtractor {
74    async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
75        use crate::providers::{ChatRequest, Message, MessageContent, Role};
76
77        // Safely truncate to ~4000 chars respecting UTF-8 boundaries
78        let truncated = truncate_chars(text, 4000);
79
80        let request = ChatRequest {
81            messages: vec![Message {
82                role: Role::User,
83                content: MessageContent::Text(format!(
84                    "请从以下对话中提取值得记忆的关键信息:\n\n{}",
85                    truncated
86                )),
87            }],
88            tools: vec![],
89            system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
90            think: false,
91            max_tokens: 512,
92            server_tools: vec![],
93            enable_caching: false,
94        };
95
96        let response = self.provider.chat(request).await?;
97
98        let response_text = response
99            .content
100            .iter()
101            .filter_map(|b| {
102                if let crate::providers::ContentBlock::Text { text } = b {
103                    Some(text.clone())
104                } else {
105                    None
106                }
107            })
108            .collect::<Vec<_>>()
109            .join("");
110
111        parse_memory_response(&response_text, session_id)
112    }
113
114    fn model_name(&self) -> &str {
115        &self.model
116    }
117}
118
119fn parse_memory_response(json_text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
120    let cleaned = json_text
121        .trim()
122        .trim_start_matches("```json")
123        .trim_start_matches("```")
124        .trim_end_matches("```")
125        .trim();
126
127    #[derive(Deserialize)]
128    struct MemoryResponse {
129        memories: Vec<MemoryItem>,
130    }
131
132    #[derive(Deserialize)]
133    struct MemoryItem {
134        category: String,
135        content: String,
136        #[serde(default)]
137        importance: f64,
138        #[serde(default)]
139        keywords: Vec<String>,
140        #[serde(default)]
141        tags: Vec<String>,
142    }
143
144    let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
145
146    let entries = parsed
147        .memories
148        .into_iter()
149        .filter_map(|item| {
150            let category = match item.category.to_lowercase().as_str() {
151                "decision" => MemoryCategory::Decision,
152                "preference" => MemoryCategory::Preference,
153                "solution" => MemoryCategory::Solution,
154                "finding" => MemoryCategory::Finding,
155                "technical" => MemoryCategory::Technical,
156                "structure" => MemoryCategory::Structure,
157                _ => return None,
158            };
159
160            if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
161                return None;
162            }
163
164            let mut entry =
165                MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()));
166            if item.importance > 0.0 {
167                entry.importance = item.importance.clamp(0.0, 100.0);
168            }
169            // Add AI-extracted keywords and tags
170            if !item.keywords.is_empty() {
171                entry.tags.extend(item.keywords);
172            }
173            if !item.tags.is_empty() {
174                entry.tags.extend(item.tags);
175            }
176            entry.tags.dedup();
177
178            Some(entry)
179        })
180        .collect();
181
182    Ok(deduplicate_entries(entries))
183}
184
185fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
186    let mut seen: Vec<String> = Vec::new();
187    entries
188        .into_iter()
189        .filter(|e| {
190            let content_lower = e.content.to_lowercase();
191            if seen.iter().any(|s| {
192                AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
193            }) {
194                false
195            } else {
196                seen.push(content_lower);
197                true
198            }
199        })
200        .take(MAX_DETECTED_ENTRIES)
201        .collect()
202}
203
204// ============================================================================
205// Rule-based Detection (uses KeywordsConfig)
206// ============================================================================
207
208/// Detect memories from text using configurable patterns.
209pub fn detect_memories_fallback(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
210    let config = KeywordsConfig::load();
211    let mut entries = Vec::new();
212    let text_lower = text.to_lowercase();
213
214    let categories = [
215        (MemoryCategory::Decision, "decision"),
216        (MemoryCategory::Preference, "preference"),
217        (MemoryCategory::Solution, "solution"),
218        (MemoryCategory::Finding, "finding"),
219        (MemoryCategory::Technical, "technical"),
220        (MemoryCategory::Structure, "structure"),
221    ];
222
223    for (category, key) in categories {
224        let patterns = config
225            .patterns
226            .get(key)
227            .map(|v| v.as_slice())
228            .unwrap_or(&[]);
229        for keyword in patterns {
230            if text_lower.contains(&keyword.to_lowercase()) {
231                let content = extract_memory_content(text, keyword);
232                if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
233                    entries.push(MemoryEntry::new(
234                        category,
235                        content,
236                        session_id.map(|s| s.to_string()),
237                    ));
238                }
239            }
240        }
241    }
242
243    deduplicate_entries(entries)
244}
245
/// Detect memories from text (wrapper for fallback).
///
/// Forwards unchanged to `detect_memories_fallback`, so only the rule-based
/// detector runs here and no AI extraction is attempted.
pub fn detect_memories_from_text(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
    detect_memories_fallback(text, session_id)
}
250
251/// Smart detection: AI-first with rule-based fallback.
252///
253/// Priority order:
254/// 1. AI extraction (if text > 200 chars and extractor available)
255/// 2. Rule-based fallback (if AI fails or text too short)
256pub async fn detect_memories_smart(
257    text: &str,
258    session_id: Option<&str>,
259    extractor: Option<&AiMemoryExtractor>,
260) -> Vec<MemoryEntry> {
261    let mode = AiDetectionMode::from_env();
262    let text_len = text.len();
263
264    // Determine if we should try AI first
265    let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200; // Minimum text length for AI (avoid API overhead for short texts)
266
267    if should_try_ai && let Some(ex) = extractor {
268        if let Ok(ai_entries) = ex.extract(text, session_id).await {
269            // AI succeeded - use AI results entirely (skip hardcoded rules)
270            return deduplicate_entries(ai_entries);
271        }
272        // AI failed - log and fall back to rules
273        log::warn!("AI memory extraction failed, falling back to rule-based");
274    }
275
276    // Fallback: rule-based detection using KeywordsConfig
277    detect_memories_fallback(text, session_id)
278}
279
280fn extract_memory_content(text: &str, keyword: &str) -> String {
281    let text_lower = text.to_lowercase();
282    let keyword_lower = keyword.to_lowercase();
283
284    let pos = match text_lower.find(&keyword_lower) {
285        Some(p) => p,
286        None => return String::new(),
287    };
288
289    // Find sentence containing the keyword
290    let start = text[..pos]
291        .rfind(['.', '。', '\n'])
292        .map(|i| i + 1)
293        .unwrap_or(0);
294
295    let end = text[pos..]
296        .find(['.', '。', '\n'])
297        .map(|i| pos + i + 1)
298        .unwrap_or(text.len());
299
300    let sentence = text[start..end].trim();
301
302    if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
303        sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
304    } else {
305        sentence.to_string()
306    }
307}
308
309/// Infer category from content.
310pub fn infer_category_from_content(content: &str) -> MemoryCategory {
311    let lower = content.to_lowercase();
312
313    if lower.contains("决定")
314        || lower.contains("选择")
315        || lower.contains("采用")
316        || lower.contains("decided")
317    {
318        return MemoryCategory::Decision;
319    }
320    if lower.contains("喜欢")
321        || lower.contains("偏好")
322        || lower.contains("习惯")
323        || lower.contains("prefer")
324    {
325        return MemoryCategory::Preference;
326    }
327    if lower.contains("解决")
328        || lower.contains("修复")
329        || lower.contains("搞定")
330        || lower.contains("fixed")
331    {
332        return MemoryCategory::Solution;
333    }
334    if lower.contains("发现")
335        || lower.contains("原因")
336        || lower.contains("原来")
337        || lower.contains("found")
338    {
339        return MemoryCategory::Finding;
340    }
341    if lower.contains("技术")
342        || lower.contains("框架")
343        || lower.contains("库")
344        || lower.contains("tech")
345    {
346        return MemoryCategory::Technical;
347    }
348    if lower.contains("文件")
349        || lower.contains("目录")
350        || lower.contains("入口")
351        || lower.contains("file")
352    {
353        return MemoryCategory::Structure;
354    }
355
356    MemoryCategory::Finding // Default
357}