1use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::keywords_config::KeywordsConfig;
9use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
10
/// Common interface for components that turn text into memory entries.
#[async_trait::async_trait]
pub trait MemoryExtractor: Send + Sync {
    /// Extracts memory entries from `text`.
    ///
    /// `session_id` and `project_path` are optional context values handed to
    /// the implementation alongside the text.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation fails to produce entries.
    async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>>;

    /// Returns the model name associated with this extractor.
    fn model_name(&self) -> &str;
}
24
/// Memory extractor that delegates extraction to a chat provider.
pub struct AiMemoryExtractor {
    /// Provider that receives the chat request built in `extract`.
    provider: Box<dyn crate::providers::Provider>,
    /// Model name returned by `model_name()`.
    model: String,
}
30
31impl AiMemoryExtractor {
32 pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
34 Self { provider, model }
35 }
36
37 pub fn new_minimal(model: String) -> Self {
40 Self {
43 provider: crate::create_minimal_provider(&model),
44 model,
45 }
46 }
47}
48
/// System prompt (written in Chinese) sent with every extraction request.
///
/// It asks the model to answer with strict JSON containing a `memories`
/// array whose items carry `category`, `content`, `importance`, `keywords`
/// and `tags`. The six category names listed here are the same strings that
/// `parse_memory_response` accepts.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。从对话中提取值得长期记忆的关键信息。

记忆类型:
- decision: 项目或技术选型的决定
- preference: 用户习惯或偏好
- solution: 解决问题的具体方法
- finding: 重要发现或信息
- technical: 技术栈或框架信息
- structure: 项目结构信息

输出格式(严格 JSON):
{
 "memories": [
 {
 "category": "decision",
 "content": "采用 PostgreSQL 作为主数据库",
 "importance": 85,
 "keywords": ["PostgreSQL", "数据库", "database"],
 "tags": ["backend", "storage"]
 }
 ]
}

关键词提取要求:
- 提取 3-5 个核心关键词(技术名词、项目名、关键概念)
- 中英文关键词都提取
- 用于后续记忆检索匹配

标签提取要求:
- 提取 1-3 个分类标签(如 backend、frontend、config、auth 等)
- 用于记忆分类筛选

只返回 JSON,不要其他解释。"#;
82
83#[async_trait::async_trait]
84impl MemoryExtractor for AiMemoryExtractor {
85 async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
86 use crate::providers::{ChatRequest, Message, MessageContent, Role};
87
88 let truncated = truncate_chars(text, 4000);
90
91 let request = ChatRequest {
92 messages: vec![Message {
93 role: Role::User,
94 content: MessageContent::Text(format!(
95 "请从以下对话中提取值得记忆的关键信息:\n\n{}",
96 truncated
97 )),
98 }],
99 tools: vec![],
100 system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
101 think: false,
102 max_tokens: 512,
103 server_tools: vec![],
104 enable_caching: false,
105 };
106
107 let response = self.provider.chat(request).await?;
108
109 let response_text = response
110 .content
111 .iter()
112 .filter_map(|b| {
113 if let crate::providers::ContentBlock::Text { text } = b {
114 Some(text.clone())
115 } else {
116 None
117 }
118 })
119 .collect::<Vec<_>>()
120 .join("");
121
122 parse_memory_response(&response_text, session_id, project_path)
123 }
124
125 fn model_name(&self) -> &str {
126 &self.model
127 }
128}
129
130fn parse_memory_response(json_text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
131 let cleaned = json_text
132 .trim()
133 .trim_start_matches("```json")
134 .trim_start_matches("```")
135 .trim_end_matches("```")
136 .trim();
137
138 #[derive(Deserialize)]
139 struct MemoryResponse {
140 memories: Vec<MemoryItem>,
141 }
142
143 #[derive(Deserialize)]
144 struct MemoryItem {
145 category: String,
146 content: String,
147 #[serde(default)]
148 importance: f64,
149 #[serde(default)]
150 keywords: Vec<String>,
151 #[serde(default)]
152 tags: Vec<String>,
153 }
154
155 let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
156
157 let entries = parsed
158 .memories
159 .into_iter()
160 .filter_map(|item| {
161 let category = match item.category.to_lowercase().as_str() {
162 "decision" => MemoryCategory::Decision,
163 "preference" => MemoryCategory::Preference,
164 "solution" => MemoryCategory::Solution,
165 "finding" => MemoryCategory::Finding,
166 "technical" => MemoryCategory::Technical,
167 "structure" => MemoryCategory::Structure,
168 _ => return None,
169 };
170
171 if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
172 return None;
173 }
174
175 let mut entry =
176 MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()), project_path.map(|p| p.to_string()));
177 if item.importance > 0.0 {
178 entry.importance = item.importance.clamp(0.0, 100.0);
179 }
180 if !item.keywords.is_empty() {
182 entry.tags.extend(item.keywords);
183 }
184 if !item.tags.is_empty() {
185 entry.tags.extend(item.tags);
186 }
187 entry.tags.dedup();
188
189 Some(entry)
190 })
191 .collect();
192
193 Ok(deduplicate_entries(entries))
194}
195
196fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
197 let mut seen: Vec<String> = Vec::new();
198 entries
199 .into_iter()
200 .filter(|e| {
201 let content_lower = e.content.to_lowercase();
202 if seen.iter().any(|s| {
203 AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
204 }) {
205 false
206 } else {
207 seen.push(content_lower);
208 true
209 }
210 })
211 .take(MAX_DETECTED_ENTRIES)
212 .collect()
213}
214
215pub fn detect_memories_fallback(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
221 let config = KeywordsConfig::load();
222 let mut entries = Vec::new();
223 let text_lower = text.to_lowercase();
224
225 let categories = [
226 (MemoryCategory::Decision, "decision"),
227 (MemoryCategory::Preference, "preference"),
228 (MemoryCategory::Solution, "solution"),
229 (MemoryCategory::Finding, "finding"),
230 (MemoryCategory::Technical, "technical"),
231 (MemoryCategory::Structure, "structure"),
232 ];
233
234 for (category, key) in categories {
235 let patterns = config
236 .patterns
237 .get(key)
238 .map(|v| v.as_slice())
239 .unwrap_or(&[]);
240 for keyword in patterns {
241 if text_lower.contains(&keyword.to_lowercase()) {
242 let content = extract_memory_content(text, keyword);
243 if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
244 entries.push(MemoryEntry::new(
245 category,
246 content,
247 session_id.map(|s| s.to_string()),
248 project_path.map(|p| p.to_string()),
249 ));
250 }
251 }
252 }
253 }
254
255 deduplicate_entries(entries)
256}
257
/// Detects memories in `text` by delegating directly to
/// `detect_memories_fallback` with the same arguments.
pub fn detect_memories_from_text(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
    detect_memories_fallback(text, session_id, project_path)
}
262
263pub async fn detect_memories_smart(
269 text: &str,
270 session_id: Option<&str>,
271 project_path: Option<&str>,
272 extractor: Option<&AiMemoryExtractor>,
273) -> Vec<MemoryEntry> {
274 let mode = AiDetectionMode::from_env();
275 let text_len = text.len();
276
277 let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200;
280
281 let model_name = extractor.map(|e| e.model_name()).unwrap_or("none");
283 crate::debug::debug_log().memory_ai_detection(
284 model_name,
285 0, text_len,
287 should_try_ai,
288 );
289
290 if should_try_ai && let Some(ex) = extractor {
291 if let Ok(ai_entries) = ex.extract(text, session_id, project_path).await {
292 crate::debug::debug_log().memory_ai_detection(
295 ex.model_name(),
296 ai_entries.len(),
297 text_len,
298 true,
299 );
300 return deduplicate_entries(ai_entries);
301 }
302 log::warn!("AI memory extraction failed, skipping detection for this turn");
304 return Vec::new();
305 }
306
307 Vec::new()
310}
311
312fn extract_memory_content(text: &str, keyword: &str) -> String {
313 let text_lower = text.to_lowercase();
314 let keyword_lower = keyword.to_lowercase();
315
316 let pos = match text_lower.find(&keyword_lower) {
317 Some(p) => p,
318 None => return String::new(),
319 };
320
321 let start = text[..pos]
323 .rfind(['.', '。', '\n'])
324 .map(|i| i + 1)
325 .unwrap_or(0);
326
327 let end = text[pos..]
328 .find(['.', '。', '\n'])
329 .map(|i| pos + i + 1)
330 .unwrap_or(text.len());
331
332 let sentence = text[start..end].trim();
333
334 if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
335 sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
336 } else {
337 sentence.to_string()
338 }
339}
340
341pub fn infer_category_from_content(content: &str) -> MemoryCategory {
343 let lower = content.to_lowercase();
344
345 if lower.contains("决定")
346 || lower.contains("选择")
347 || lower.contains("采用")
348 || lower.contains("decided")
349 {
350 return MemoryCategory::Decision;
351 }
352 if lower.contains("喜欢")
353 || lower.contains("偏好")
354 || lower.contains("习惯")
355 || lower.contains("prefer")
356 {
357 return MemoryCategory::Preference;
358 }
359 if lower.contains("解决")
360 || lower.contains("修复")
361 || lower.contains("搞定")
362 || lower.contains("fixed")
363 {
364 return MemoryCategory::Solution;
365 }
366 if lower.contains("发现")
367 || lower.contains("原因")
368 || lower.contains("原来")
369 || lower.contains("found")
370 {
371 return MemoryCategory::Finding;
372 }
373 if lower.contains("技术")
374 || lower.contains("框架")
375 || lower.contains("库")
376 || lower.contains("tech")
377 {
378 return MemoryCategory::Technical;
379 }
380 if lower.contains("文件")
381 || lower.contains("目录")
382 || lower.contains("入口")
383 || lower.contains("file")
384 {
385 return MemoryCategory::Structure;
386 }
387
388 MemoryCategory::Finding }