1use anyhow::Result;
4use serde::Deserialize;
5use crate::truncate::truncate_chars;
6
7use super::config::*;
8use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
9
10#[async_trait::async_trait]
16pub trait MemoryExtractor: Send + Sync {
17 async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>>;
19
20 fn model_name(&self) -> &str;
22}
23
24pub struct AiMemoryExtractor {
26 provider: Box<dyn crate::providers::Provider>,
27 model: String,
28}
29
30impl AiMemoryExtractor {
31 pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
33 Self { provider, model }
34 }
35}
36
/// System prompt (in Chinese) sent with every memory-extraction request.
///
/// The prompt does the following:
/// - tells the model it is a memory-extraction assistant;
/// - lists six memory categories: decision, preference, solution, finding, technical, structure
///   (the same set `parse_memory_response` accepts);
/// - asks for strict JSON of the form
///   `{"memories": [{"category": ..., "content": ..., "importance": ...}]}`.
///
/// The string is runtime behavior; edit it only deliberately.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。你的任务是从对话中识别并提取值得长期记忆的关键信息。

记忆类型:
1. Decision(决策): 项目或技术选型的决定
2. Preference(偏好): 用户习惯或偏好
3. Solution(解决方案): 解决问题的具体方法
4. Finding(发现): 重要发现或信息
5. Technical(技术): 技术栈或框架信息
6. Structure(结构): 项目结构信息

输出格式(严格 JSON):
{"memories": [{"category": "decision", "content": "...", "importance": 90}]}
"#;
50
51#[async_trait::async_trait]
52impl MemoryExtractor for AiMemoryExtractor {
53 async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
54 use crate::providers::{ChatRequest, Message, MessageContent, Role};
55
56 let truncated = truncate_chars(text, 4000);
58
59 let request = ChatRequest {
60 messages: vec![Message {
61 role: Role::User,
62 content: MessageContent::Text(format!(
63 "请从以下对话中提取值得记忆的关键信息:\n\n{}",
64 truncated
65 )),
66 }],
67 tools: vec![],
68 system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
69 think: false,
70 max_tokens: 512,
71 server_tools: vec![],
72 enable_caching: false,
73 };
74
75 let response = self.provider.chat(request).await?;
76
77 let response_text = response
78 .content
79 .iter()
80 .filter_map(|b| {
81 if let crate::providers::ContentBlock::Text { text } = b {
82 Some(text.clone())
83 } else {
84 None
85 }
86 })
87 .collect::<Vec<_>>()
88 .join("");
89
90 parse_memory_response(&response_text, session_id)
91 }
92
93 fn model_name(&self) -> &str {
94 &self.model
95 }
96}
97
98fn parse_memory_response(json_text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
99 let cleaned = json_text
100 .trim()
101 .trim_start_matches("```json")
102 .trim_start_matches("```")
103 .trim_end_matches("```")
104 .trim();
105
106 #[derive(Deserialize)]
107 struct MemoryResponse {
108 memories: Vec<MemoryItem>,
109 }
110
111 #[derive(Deserialize)]
112 struct MemoryItem {
113 category: String,
114 content: String,
115 #[serde(default)]
116 importance: f64,
117 }
118
119 let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
120
121 let entries = parsed
122 .memories
123 .into_iter()
124 .filter_map(|item| {
125 let category = match item.category.to_lowercase().as_str() {
126 "decision" => MemoryCategory::Decision,
127 "preference" => MemoryCategory::Preference,
128 "solution" => MemoryCategory::Solution,
129 "finding" => MemoryCategory::Finding,
130 "technical" => MemoryCategory::Technical,
131 "structure" => MemoryCategory::Structure,
132 _ => return None,
133 };
134
135 if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
136 return None;
137 }
138
139 let mut entry =
140 MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()));
141 if item.importance > 0.0 {
142 entry.importance = item.importance.clamp(0.0, 100.0);
143 }
144
145 Some(entry)
146 })
147 .collect();
148
149 Ok(deduplicate_entries(entries))
150}
151
152fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
153 let mut seen: Vec<String> = Vec::new();
154 entries
155 .into_iter()
156 .filter(|e| {
157 let content_lower = e.content.to_lowercase();
158 if seen.iter().any(|s| {
159 AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
160 }) {
161 false
162 } else {
163 seen.push(content_lower);
164 true
165 }
166 })
167 .take(MAX_DETECTED_ENTRIES)
168 .collect()
169}
170
171pub fn detect_memories_fallback(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
177 let mut entries = Vec::new();
178 let text_lower = text.to_lowercase();
179
180 let patterns: Vec<(MemoryCategory, Vec<&str>)> = vec![
181 (
182 MemoryCategory::Decision,
183 vec![
184 "最终决定",
185 "决定采用",
186 "我们决定",
187 "选择使用",
188 "采用方案",
189 "定下来",
190 "就定这个",
191 "敲定",
192 "拍板",
193 "we decided",
194 "final decision",
195 ],
196 ),
197 (
198 MemoryCategory::Preference,
199 vec![
200 "我喜欢",
201 "我偏好",
202 "我习惯",
203 "最常用",
204 "一直用",
205 "推荐",
206 "建议使用",
207 "首选",
208 "i like",
209 "i prefer",
210 ],
211 ),
212 (
213 MemoryCategory::Solution,
214 vec![
215 "通过修改",
216 "解决方案是",
217 "搞定",
218 "解决了",
219 "修复成功",
220 "改成",
221 "优化了",
222 "fixed by",
223 "solved by",
224 ],
225 ),
226 (
227 MemoryCategory::Finding,
228 vec![
229 "发现",
230 "注意到",
231 "原来",
232 "找到问题",
233 "定位到",
234 "排查发现",
235 "原因是",
236 "found that",
237 "discovered",
238 ],
239 ),
240 (
241 MemoryCategory::Technical,
242 vec![
243 "技术栈是",
244 "框架使用",
245 "用的是",
246 "基于",
247 "tech stack",
248 "using framework",
249 "built with",
250 ],
251 ),
252 (
253 MemoryCategory::Structure,
254 vec![
255 "入口文件是",
256 "主文件位于",
257 "项目结构是",
258 "入口是",
259 "目录是",
260 "entry point",
261 "main file",
262 ],
263 ),
264 ];
265
266 for (category, keywords) in patterns {
267 for keyword in keywords {
268 if text_lower.contains(keyword) {
269 let content = extract_memory_content(text, keyword);
270 if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
271 entries.push(MemoryEntry::new(
272 category,
273 content,
274 session_id.map(|s| s.to_string()),
275 ));
276 }
277 }
278 }
279 }
280
281 deduplicate_entries(entries)
282}
283
/// Detects memories in `text` using only the keyword rules.
///
/// Currently a plain alias for [`detect_memories_fallback`]; no AI extraction happens here.
pub fn detect_memories_from_text(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
    detect_memories_fallback(text, session_id)
}
288
289pub async fn detect_memories_smart(
291 text: &str,
292 session_id: Option<&str>,
293 extractor: Option<&AiMemoryExtractor>,
294) -> Vec<MemoryEntry> {
295 let rule_entries = detect_memories_fallback(text, session_id);
297
298 let mode = AiDetectionMode::from_env();
300 if mode.should_use_ai_for_text(text.len())
301 && extractor.is_some()
302 && let Some(ex) = extractor
303 && let Ok(ai_entries) = ex.extract(text, session_id).await
304 {
305 let combined = rule_entries.into_iter().chain(ai_entries).collect();
307 return deduplicate_entries(combined);
308 }
309
310 rule_entries
311}
312
313fn extract_memory_content(text: &str, keyword: &str) -> String {
314 let text_lower = text.to_lowercase();
315 let keyword_lower = keyword.to_lowercase();
316
317 let pos = match text_lower.find(&keyword_lower) {
318 Some(p) => p,
319 None => return String::new(),
320 };
321
322 let start = text[..pos]
324 .rfind(['.', '。', '\n'])
325 .map(|i| i + 1)
326 .unwrap_or(0);
327
328 let end = text[pos..]
329 .find(['.', '。', '\n'])
330 .map(|i| pos + i + 1)
331 .unwrap_or(text.len());
332
333 let sentence = text[start..end].trim();
334
335 if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
336 sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
337 } else {
338 sentence.to_string()
339 }
340}
341
342pub fn infer_category_from_content(content: &str) -> MemoryCategory {
344 let lower = content.to_lowercase();
345
346 if lower.contains("决定")
347 || lower.contains("选择")
348 || lower.contains("采用")
349 || lower.contains("decided")
350 {
351 return MemoryCategory::Decision;
352 }
353 if lower.contains("喜欢")
354 || lower.contains("偏好")
355 || lower.contains("习惯")
356 || lower.contains("prefer")
357 {
358 return MemoryCategory::Preference;
359 }
360 if lower.contains("解决")
361 || lower.contains("修复")
362 || lower.contains("搞定")
363 || lower.contains("fixed")
364 {
365 return MemoryCategory::Solution;
366 }
367 if lower.contains("发现")
368 || lower.contains("原因")
369 || lower.contains("原来")
370 || lower.contains("found")
371 {
372 return MemoryCategory::Finding;
373 }
374 if lower.contains("技术")
375 || lower.contains("框架")
376 || lower.contains("库")
377 || lower.contains("tech")
378 {
379 return MemoryCategory::Technical;
380 }
381 if lower.contains("文件")
382 || lower.contains("目录")
383 || lower.contains("入口")
384 || lower.contains("file")
385 {
386 return MemoryCategory::Structure;
387 }
388
389 MemoryCategory::Finding }