matrixcode_core/memory/
extractor.rs1use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::keywords_config::KeywordsConfig;
9use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
10
11#[async_trait::async_trait]
17pub trait MemoryExtractor: Send + Sync {
18 async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>>;
20
21 fn model_name(&self) -> &str;
23}
24
25pub struct AiMemoryExtractor {
27 provider: Box<dyn crate::providers::Provider>,
28 model: String,
29}
30
31impl AiMemoryExtractor {
32 pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
34 Self { provider, model }
35 }
36}
37
/// System prompt sent by `AiMemoryExtractor::extract` (written in Chinese).
///
/// It lists the six memory categories (decision, preference, solution, finding, technical,
/// structure) and asks for strict JSON of the form
/// `{"memories": [{"category", "content", "importance", "keywords", "tags"}]}`, with 3-5
/// keywords (Chinese and English) and 1-3 tags per memory, and nothing but the JSON.
/// The model's reply is parsed by `parse_memory_response`.
38const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。从对话中提取值得长期记忆的关键信息。
39
40记忆类型:
41- decision: 项目或技术选型的决定
42- preference: 用户习惯或偏好
43- solution: 解决问题的具体方法
44- finding: 重要发现或信息
45- technical: 技术栈或框架信息
46- structure: 项目结构信息
47
48输出格式(严格 JSON):
49{
50 "memories": [
51 {
52 "category": "decision",
53 "content": "采用 PostgreSQL 作为主数据库",
54 "importance": 85,
55 "keywords": ["PostgreSQL", "数据库", "database"],
56 "tags": ["backend", "storage"]
57 }
58 ]
59}
60
61关键词提取要求:
62- 提取 3-5 个核心关键词(技术名词、项目名、关键概念)
63- 中英文关键词都提取
64- 用于后续记忆检索匹配
65
66标签提取要求:
67- 提取 1-3 个分类标签(如 backend、frontend、config、auth 等)
68- 用于记忆分类筛选
69
70只返回 JSON,不要其他解释。"#;
71
72#[async_trait::async_trait]
73impl MemoryExtractor for AiMemoryExtractor {
74 async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
75 use crate::providers::{ChatRequest, Message, MessageContent, Role};
76
77 let truncated = truncate_chars(text, 4000);
79
80 let request = ChatRequest {
81 messages: vec![Message {
82 role: Role::User,
83 content: MessageContent::Text(format!(
84 "请从以下对话中提取值得记忆的关键信息:\n\n{}",
85 truncated
86 )),
87 }],
88 tools: vec![],
89 system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
90 think: false,
91 max_tokens: 512,
92 server_tools: vec![],
93 enable_caching: false,
94 };
95
96 let response = self.provider.chat(request).await?;
97
98 let response_text = response
99 .content
100 .iter()
101 .filter_map(|b| {
102 if let crate::providers::ContentBlock::Text { text } = b {
103 Some(text.clone())
104 } else {
105 None
106 }
107 })
108 .collect::<Vec<_>>()
109 .join("");
110
111 parse_memory_response(&response_text, session_id)
112 }
113
114 fn model_name(&self) -> &str {
115 &self.model
116 }
117}
118
119fn parse_memory_response(json_text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
120 let cleaned = json_text
121 .trim()
122 .trim_start_matches("```json")
123 .trim_start_matches("```")
124 .trim_end_matches("```")
125 .trim();
126
127 #[derive(Deserialize)]
128 struct MemoryResponse {
129 memories: Vec<MemoryItem>,
130 }
131
132 #[derive(Deserialize)]
133 struct MemoryItem {
134 category: String,
135 content: String,
136 #[serde(default)]
137 importance: f64,
138 #[serde(default)]
139 keywords: Vec<String>,
140 #[serde(default)]
141 tags: Vec<String>,
142 }
143
144 let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
145
146 let entries = parsed
147 .memories
148 .into_iter()
149 .filter_map(|item| {
150 let category = match item.category.to_lowercase().as_str() {
151 "decision" => MemoryCategory::Decision,
152 "preference" => MemoryCategory::Preference,
153 "solution" => MemoryCategory::Solution,
154 "finding" => MemoryCategory::Finding,
155 "technical" => MemoryCategory::Technical,
156 "structure" => MemoryCategory::Structure,
157 _ => return None,
158 };
159
160 if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
161 return None;
162 }
163
164 let mut entry =
165 MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()));
166 if item.importance > 0.0 {
167 entry.importance = item.importance.clamp(0.0, 100.0);
168 }
169 if !item.keywords.is_empty() {
171 entry.tags.extend(item.keywords);
172 }
173 if !item.tags.is_empty() {
174 entry.tags.extend(item.tags);
175 }
176 entry.tags.dedup();
177
178 Some(entry)
179 })
180 .collect();
181
182 Ok(deduplicate_entries(entries))
183}
184
185fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
186 let mut seen: Vec<String> = Vec::new();
187 entries
188 .into_iter()
189 .filter(|e| {
190 let content_lower = e.content.to_lowercase();
191 if seen.iter().any(|s| {
192 AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
193 }) {
194 false
195 } else {
196 seen.push(content_lower);
197 true
198 }
199 })
200 .take(MAX_DETECTED_ENTRIES)
201 .collect()
202}
203
204pub fn detect_memories_fallback(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
210 let config = KeywordsConfig::load();
211 let mut entries = Vec::new();
212 let text_lower = text.to_lowercase();
213
214 let categories = [
215 (MemoryCategory::Decision, "decision"),
216 (MemoryCategory::Preference, "preference"),
217 (MemoryCategory::Solution, "solution"),
218 (MemoryCategory::Finding, "finding"),
219 (MemoryCategory::Technical, "technical"),
220 (MemoryCategory::Structure, "structure"),
221 ];
222
223 for (category, key) in categories {
224 let patterns = config
225 .patterns
226 .get(key)
227 .map(|v| v.as_slice())
228 .unwrap_or(&[]);
229 for keyword in patterns {
230 if text_lower.contains(&keyword.to_lowercase()) {
231 let content = extract_memory_content(text, keyword);
232 if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
233 entries.push(MemoryEntry::new(
234 category,
235 content,
236 session_id.map(|s| s.to_string()),
237 ));
238 }
239 }
240 }
241 }
242
243 deduplicate_entries(entries)
244}
245
246pub fn detect_memories_from_text(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
248 detect_memories_fallback(text, session_id)
249}
250
251pub async fn detect_memories_smart(
257 text: &str,
258 session_id: Option<&str>,
259 extractor: Option<&AiMemoryExtractor>,
260) -> Vec<MemoryEntry> {
261 let mode = AiDetectionMode::from_env();
262 let text_len = text.len();
263
264 let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200; if should_try_ai && let Some(ex) = extractor {
268 if let Ok(ai_entries) = ex.extract(text, session_id).await {
269 return deduplicate_entries(ai_entries);
271 }
272 log::warn!("AI memory extraction failed, falling back to rule-based");
274 }
275
276 detect_memories_fallback(text, session_id)
278}
279
280fn extract_memory_content(text: &str, keyword: &str) -> String {
281 let text_lower = text.to_lowercase();
282 let keyword_lower = keyword.to_lowercase();
283
284 let pos = match text_lower.find(&keyword_lower) {
285 Some(p) => p,
286 None => return String::new(),
287 };
288
289 let start = text[..pos]
291 .rfind(['.', '。', '\n'])
292 .map(|i| i + 1)
293 .unwrap_or(0);
294
295 let end = text[pos..]
296 .find(['.', '。', '\n'])
297 .map(|i| pos + i + 1)
298 .unwrap_or(text.len());
299
300 let sentence = text[start..end].trim();
301
302 if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
303 sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
304 } else {
305 sentence.to_string()
306 }
307}
308
309pub fn infer_category_from_content(content: &str) -> MemoryCategory {
311 let lower = content.to_lowercase();
312
313 if lower.contains("决定")
314 || lower.contains("选择")
315 || lower.contains("采用")
316 || lower.contains("decided")
317 {
318 return MemoryCategory::Decision;
319 }
320 if lower.contains("喜欢")
321 || lower.contains("偏好")
322 || lower.contains("习惯")
323 || lower.contains("prefer")
324 {
325 return MemoryCategory::Preference;
326 }
327 if lower.contains("解决")
328 || lower.contains("修复")
329 || lower.contains("搞定")
330 || lower.contains("fixed")
331 {
332 return MemoryCategory::Solution;
333 }
334 if lower.contains("发现")
335 || lower.contains("原因")
336 || lower.contains("原来")
337 || lower.contains("found")
338 {
339 return MemoryCategory::Finding;
340 }
341 if lower.contains("技术")
342 || lower.contains("框架")
343 || lower.contains("库")
344 || lower.contains("tech")
345 {
346 return MemoryCategory::Technical;
347 }
348 if lower.contains("文件")
349 || lower.contains("目录")
350 || lower.contains("入口")
351 || lower.contains("file")
352 {
353 return MemoryCategory::Structure;
354 }
355
356 MemoryCategory::Finding }