1use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::keywords_config::KeywordsConfig;
9use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
10
11#[async_trait::async_trait]
17pub trait MemoryExtractor: Send + Sync {
18 async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>>;
20
21 fn model_name(&self) -> &str;
23}
24
25pub struct AiMemoryExtractor {
27 provider: Box<dyn crate::providers::Provider>,
28 model: String,
29}
30
31impl AiMemoryExtractor {
32 pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
34 Self { provider, model }
35 }
36}
37
/// System prompt sent with every AI extraction request.
///
/// It lists the six memory categories (decision, preference, solution,
/// finding, technical, structure) and asks the model to answer with strict
/// JSON: an object whose `memories` array holds items with `category`,
/// `content`, `importance`, `keywords` and `tags`.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。从对话中提取值得长期记忆的关键信息。

记忆类型:
- decision: 项目或技术选型的决定
- preference: 用户习惯或偏好
- solution: 解决问题的具体方法
- finding: 重要发现或信息
- technical: 技术栈或框架信息
- structure: 项目结构信息

输出格式(严格 JSON):
{
 "memories": [
 {
 "category": "decision",
 "content": "采用 PostgreSQL 作为主数据库",
 "importance": 85,
 "keywords": ["PostgreSQL", "数据库", "database"],
 "tags": ["backend", "storage"]
 }
 ]
}

关键词提取要求:
- 提取 3-5 个核心关键词(技术名词、项目名、关键概念)
- 中英文关键词都提取
- 用于后续记忆检索匹配

标签提取要求:
- 提取 1-3 个分类标签(如 backend、frontend、config、auth 等)
- 用于记忆分类筛选

只返回 JSON,不要其他解释。"#;
71
72#[async_trait::async_trait]
73impl MemoryExtractor for AiMemoryExtractor {
74 async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
75 use crate::providers::{ChatRequest, Message, MessageContent, Role};
76
77 let truncated = truncate_chars(text, 4000);
79
80 let request = ChatRequest {
81 messages: vec![Message {
82 role: Role::User,
83 content: MessageContent::Text(format!(
84 "请从以下对话中提取值得记忆的关键信息:\n\n{}",
85 truncated
86 )),
87 }],
88 tools: vec![],
89 system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
90 think: false,
91 max_tokens: 512,
92 server_tools: vec![],
93 enable_caching: false,
94 };
95
96 let response = self.provider.chat(request).await?;
97
98 let response_text = response
99 .content
100 .iter()
101 .filter_map(|b| {
102 if let crate::providers::ContentBlock::Text { text } = b {
103 Some(text.clone())
104 } else {
105 None
106 }
107 })
108 .collect::<Vec<_>>()
109 .join("");
110
111 parse_memory_response(&response_text, session_id, project_path)
112 }
113
114 fn model_name(&self) -> &str {
115 &self.model
116 }
117}
118
119fn parse_memory_response(json_text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
120 let cleaned = json_text
121 .trim()
122 .trim_start_matches("```json")
123 .trim_start_matches("```")
124 .trim_end_matches("```")
125 .trim();
126
127 #[derive(Deserialize)]
128 struct MemoryResponse {
129 memories: Vec<MemoryItem>,
130 }
131
132 #[derive(Deserialize)]
133 struct MemoryItem {
134 category: String,
135 content: String,
136 #[serde(default)]
137 importance: f64,
138 #[serde(default)]
139 keywords: Vec<String>,
140 #[serde(default)]
141 tags: Vec<String>,
142 }
143
144 let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
145
146 let entries = parsed
147 .memories
148 .into_iter()
149 .filter_map(|item| {
150 let category = match item.category.to_lowercase().as_str() {
151 "decision" => MemoryCategory::Decision,
152 "preference" => MemoryCategory::Preference,
153 "solution" => MemoryCategory::Solution,
154 "finding" => MemoryCategory::Finding,
155 "technical" => MemoryCategory::Technical,
156 "structure" => MemoryCategory::Structure,
157 _ => return None,
158 };
159
160 if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
161 return None;
162 }
163
164 let mut entry =
165 MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()), project_path.map(|p| p.to_string()));
166 if item.importance > 0.0 {
167 entry.importance = item.importance.clamp(0.0, 100.0);
168 }
169 if !item.keywords.is_empty() {
171 entry.tags.extend(item.keywords);
172 }
173 if !item.tags.is_empty() {
174 entry.tags.extend(item.tags);
175 }
176 entry.tags.dedup();
177
178 Some(entry)
179 })
180 .collect();
181
182 Ok(deduplicate_entries(entries))
183}
184
185fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
186 let mut seen: Vec<String> = Vec::new();
187 entries
188 .into_iter()
189 .filter(|e| {
190 let content_lower = e.content.to_lowercase();
191 if seen.iter().any(|s| {
192 AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
193 }) {
194 false
195 } else {
196 seen.push(content_lower);
197 true
198 }
199 })
200 .take(MAX_DETECTED_ENTRIES)
201 .collect()
202}
203
204pub fn detect_memories_fallback(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
210 let config = KeywordsConfig::load();
211 let mut entries = Vec::new();
212 let text_lower = text.to_lowercase();
213
214 let categories = [
215 (MemoryCategory::Decision, "decision"),
216 (MemoryCategory::Preference, "preference"),
217 (MemoryCategory::Solution, "solution"),
218 (MemoryCategory::Finding, "finding"),
219 (MemoryCategory::Technical, "technical"),
220 (MemoryCategory::Structure, "structure"),
221 ];
222
223 for (category, key) in categories {
224 let patterns = config
225 .patterns
226 .get(key)
227 .map(|v| v.as_slice())
228 .unwrap_or(&[]);
229 for keyword in patterns {
230 if text_lower.contains(&keyword.to_lowercase()) {
231 let content = extract_memory_content(text, keyword);
232 if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
233 entries.push(MemoryEntry::new(
234 category,
235 content,
236 session_id.map(|s| s.to_string()),
237 project_path.map(|p| p.to_string()),
238 ));
239 }
240 }
241 }
242 }
243
244 deduplicate_entries(entries)
245}
246
247pub fn detect_memories_from_text(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
249 detect_memories_fallback(text, session_id, project_path)
250}
251
252pub async fn detect_memories_smart(
258 text: &str,
259 session_id: Option<&str>,
260 project_path: Option<&str>,
261 extractor: Option<&AiMemoryExtractor>,
262) -> Vec<MemoryEntry> {
263 let mode = AiDetectionMode::from_env();
264 let text_len = text.len();
265
266 let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200;
269
270 let model_name = extractor.map(|e| e.model_name()).unwrap_or("none");
272 crate::debug::debug_log().memory_ai_detection(
273 model_name,
274 0, text_len,
276 should_try_ai,
277 );
278
279 if should_try_ai && let Some(ex) = extractor {
280 if let Ok(ai_entries) = ex.extract(text, session_id, project_path).await {
281 crate::debug::debug_log().memory_ai_detection(
284 ex.model_name(),
285 ai_entries.len(),
286 text_len,
287 true,
288 );
289 return deduplicate_entries(ai_entries);
290 }
291 log::warn!("AI memory extraction failed, skipping detection for this turn");
293 return Vec::new();
294 }
295
296 Vec::new()
299}
300
301fn extract_memory_content(text: &str, keyword: &str) -> String {
302 let text_lower = text.to_lowercase();
303 let keyword_lower = keyword.to_lowercase();
304
305 let pos = match text_lower.find(&keyword_lower) {
306 Some(p) => p,
307 None => return String::new(),
308 };
309
310 let start = text[..pos]
312 .rfind(['.', '。', '\n'])
313 .map(|i| i + 1)
314 .unwrap_or(0);
315
316 let end = text[pos..]
317 .find(['.', '。', '\n'])
318 .map(|i| pos + i + 1)
319 .unwrap_or(text.len());
320
321 let sentence = text[start..end].trim();
322
323 if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
324 sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
325 } else {
326 sentence.to_string()
327 }
328}
329
330pub fn infer_category_from_content(content: &str) -> MemoryCategory {
332 let lower = content.to_lowercase();
333
334 if lower.contains("决定")
335 || lower.contains("选择")
336 || lower.contains("采用")
337 || lower.contains("decided")
338 {
339 return MemoryCategory::Decision;
340 }
341 if lower.contains("喜欢")
342 || lower.contains("偏好")
343 || lower.contains("习惯")
344 || lower.contains("prefer")
345 {
346 return MemoryCategory::Preference;
347 }
348 if lower.contains("解决")
349 || lower.contains("修复")
350 || lower.contains("搞定")
351 || lower.contains("fixed")
352 {
353 return MemoryCategory::Solution;
354 }
355 if lower.contains("发现")
356 || lower.contains("原因")
357 || lower.contains("原来")
358 || lower.contains("found")
359 {
360 return MemoryCategory::Finding;
361 }
362 if lower.contains("技术")
363 || lower.contains("框架")
364 || lower.contains("库")
365 || lower.contains("tech")
366 {
367 return MemoryCategory::Technical;
368 }
369 if lower.contains("文件")
370 || lower.contains("目录")
371 || lower.contains("入口")
372 || lower.contains("file")
373 {
374 return MemoryCategory::Structure;
375 }
376
377 MemoryCategory::Finding }