Skip to main content

matrixcode_core/memory/
extractor.rs

1//! Memory extraction: AI-based and rule-based detection.
2
3use crate::truncate::truncate_chars;
4use anyhow::Result;
5use serde::Deserialize;
6
7use super::config::*;
8use super::keywords_config::KeywordsConfig;
9use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
10
11// ============================================================================
12// Memory Extractor Trait
13// ============================================================================
14
15/// Trait for memory extraction implementations.
16#[async_trait::async_trait]
17pub trait MemoryExtractor: Send + Sync {
18    /// Extract memories from conversation text using AI.
19    async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>>;
20
21    /// Get the model name used for extraction.
22    fn model_name(&self) -> &str;
23}
24
25/// AI-based memory extractor using a fast/cheap model.
26pub struct AiMemoryExtractor {
27    provider: Box<dyn crate::providers::Provider>,
28    model: String,
29}
30
31impl AiMemoryExtractor {
32    /// Create a new AI memory extractor.
33    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
34        Self { provider, model }
35    }
36
37    /// Create a minimal extractor (for background tasks, uses simplified prompt).
38    /// This is more efficient for non-blocking background extraction.
39    pub fn new_minimal(model: String) -> Self {
40        // Create a minimal provider that uses the global config
41        // This is for background tasks, so we use a simplified approach
42        Self {
43            provider: crate::create_minimal_provider(&model),
44            model,
45        }
46    }
47}
48
/// System prompt (written in Chinese) for the memory-extraction request.
///
/// It names six memory categories (decision, preference, solution, finding,
/// technical, structure) and asks the model to reply with strict JSON of the
/// form `{"memories": [{category, content, importance, keywords, tags}]}`,
/// with 3-5 keywords and 1-3 tags per memory and no extra explanation.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。从对话中提取值得长期记忆的关键信息。

记忆类型:
- decision: 项目或技术选型的决定
- preference: 用户习惯或偏好
- solution: 解决问题的具体方法
- finding: 重要发现或信息
- technical: 技术栈或框架信息
- structure: 项目结构信息

输出格式(严格 JSON):
{
  "memories": [
    {
      "category": "decision",
      "content": "采用 PostgreSQL 作为主数据库",
      "importance": 85,
      "keywords": ["PostgreSQL", "数据库", "database"],
      "tags": ["backend", "storage"]
    }
  ]
}

关键词提取要求:
- 提取 3-5 个核心关键词(技术名词、项目名、关键概念)
- 中英文关键词都提取
- 用于后续记忆检索匹配

标签提取要求:
- 提取 1-3 个分类标签(如 backend、frontend、config、auth 等)
- 用于记忆分类筛选

只返回 JSON,不要其他解释。"#;
82
83#[async_trait::async_trait]
84impl MemoryExtractor for AiMemoryExtractor {
85    async fn extract(&self, text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
86        use crate::providers::{ChatRequest, Message, MessageContent, Role};
87
88        // Safely truncate to ~4000 chars respecting UTF-8 boundaries
89        let truncated = truncate_chars(text, 4000);
90
91        let request = ChatRequest {
92            messages: vec![Message {
93                role: Role::User,
94                content: MessageContent::Text(format!(
95                    "请从以下对话中提取值得记忆的关键信息:\n\n{}",
96                    truncated
97                )),
98            }],
99            tools: vec![],
100            system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
101            think: false,
102            max_tokens: 512,
103            server_tools: vec![],
104            enable_caching: false,
105        };
106
107        let response = self.provider.chat(request).await?;
108
109        let response_text = response
110            .content
111            .iter()
112            .filter_map(|b| {
113                if let crate::providers::ContentBlock::Text { text } = b {
114                    Some(text.clone())
115                } else {
116                    None
117                }
118            })
119            .collect::<Vec<_>>()
120            .join("");
121
122        parse_memory_response(&response_text, session_id, project_path)
123    }
124
125    fn model_name(&self) -> &str {
126        &self.model
127    }
128}
129
130fn parse_memory_response(json_text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Result<Vec<MemoryEntry>> {
131    let cleaned = json_text
132        .trim()
133        .trim_start_matches("```json")
134        .trim_start_matches("```")
135        .trim_end_matches("```")
136        .trim();
137
138    #[derive(Deserialize)]
139    struct MemoryResponse {
140        memories: Vec<MemoryItem>,
141    }
142
143    #[derive(Deserialize)]
144    struct MemoryItem {
145        category: String,
146        content: String,
147        #[serde(default)]
148        importance: f64,
149        #[serde(default)]
150        keywords: Vec<String>,
151        #[serde(default)]
152        tags: Vec<String>,
153    }
154
155    let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
156
157    let entries = parsed
158        .memories
159        .into_iter()
160        .filter_map(|item| {
161            let category = match item.category.to_lowercase().as_str() {
162                "decision" => MemoryCategory::Decision,
163                "preference" => MemoryCategory::Preference,
164                "solution" => MemoryCategory::Solution,
165                "finding" => MemoryCategory::Finding,
166                "technical" => MemoryCategory::Technical,
167                "structure" => MemoryCategory::Structure,
168                _ => return None,
169            };
170
171            if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
172                return None;
173            }
174
175            let mut entry =
176                MemoryEntry::new(category, item.content, session_id.map(|s| s.to_string()), project_path.map(|p| p.to_string()));
177            if item.importance > 0.0 {
178                entry.importance = item.importance.clamp(0.0, 100.0);
179            }
180            // Add AI-extracted keywords and tags
181            if !item.keywords.is_empty() {
182                entry.tags.extend(item.keywords);
183            }
184            if !item.tags.is_empty() {
185                entry.tags.extend(item.tags);
186            }
187            entry.tags.dedup();
188
189            Some(entry)
190        })
191        .collect();
192
193    Ok(deduplicate_entries(entries))
194}
195
196fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
197    let mut seen: Vec<String> = Vec::new();
198    entries
199        .into_iter()
200        .filter(|e| {
201            let content_lower = e.content.to_lowercase();
202            if seen.iter().any(|s| {
203                AutoMemory::calculate_similarity(s, &content_lower) >= SIMILARITY_THRESHOLD
204            }) {
205                false
206            } else {
207                seen.push(content_lower);
208                true
209            }
210        })
211        .take(MAX_DETECTED_ENTRIES)
212        .collect()
213}
214
215// ============================================================================
216// Rule-based Detection (uses KeywordsConfig)
217// ============================================================================
218
219/// Detect memories from text using configurable patterns.
220pub fn detect_memories_fallback(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
221    let config = KeywordsConfig::load();
222    let mut entries = Vec::new();
223    let text_lower = text.to_lowercase();
224
225    let categories = [
226        (MemoryCategory::Decision, "decision"),
227        (MemoryCategory::Preference, "preference"),
228        (MemoryCategory::Solution, "solution"),
229        (MemoryCategory::Finding, "finding"),
230        (MemoryCategory::Technical, "technical"),
231        (MemoryCategory::Structure, "structure"),
232    ];
233
234    for (category, key) in categories {
235        let patterns = config
236            .patterns
237            .get(key)
238            .map(|v| v.as_slice())
239            .unwrap_or(&[]);
240        for keyword in patterns {
241            if text_lower.contains(&keyword.to_lowercase()) {
242                let content = extract_memory_content(text, keyword);
243                if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
244                    entries.push(MemoryEntry::new(
245                        category,
246                        content,
247                        session_id.map(|s| s.to_string()),
248                        project_path.map(|p| p.to_string()),
249                    ));
250                }
251            }
252        }
253    }
254
255    deduplicate_entries(entries)
256}
257
258/// Detect memories from text (wrapper for fallback).
259pub fn detect_memories_from_text(text: &str, session_id: Option<&str>, project_path: Option<&str>) -> Vec<MemoryEntry> {
260    detect_memories_fallback(text, session_id, project_path)
261}
262
263/// Smart detection: AI-first with rule-based fallback.
264///
265/// Priority order:
266/// 1. AI extraction (if text > 200 chars and extractor available)
267/// 2. Rule-based fallback (if AI fails or text too short)
268pub async fn detect_memories_smart(
269    text: &str,
270    session_id: Option<&str>,
271    project_path: Option<&str>,
272    extractor: Option<&AiMemoryExtractor>,
273) -> Vec<MemoryEntry> {
274    let mode = AiDetectionMode::from_env();
275    let text_len = text.len();
276
277    // Determine if we should try AI first
278    // Only use AI for text > 200 chars (avoid API overhead for short texts)
279    let should_try_ai = mode != AiDetectionMode::Never && extractor.is_some() && text_len > 200;
280
281    // Debug log: show method and model
282    let model_name = extractor.map(|e| e.model_name()).unwrap_or("none");
283    crate::debug::debug_log().memory_ai_detection(
284        model_name,
285        0, // Will update after detection
286        text_len,
287        should_try_ai,
288    );
289
290    if should_try_ai && let Some(ex) = extractor {
291        if let Ok(ai_entries) = ex.extract(text, session_id, project_path).await {
292            // AI succeeded - use AI results entirely (skip hardcoded rules)
293            // Debug log: AI result
294            crate::debug::debug_log().memory_ai_detection(
295                ex.model_name(),
296                ai_entries.len(),
297                text_len,
298                true,
299            );
300            return deduplicate_entries(ai_entries);
301        }
302        // AI failed - log and skip rule-based fallback (per user request)
303        log::warn!("AI memory extraction failed, skipping detection for this turn");
304        return Vec::new();
305    }
306
307    // For short texts (< 200 chars), skip detection entirely (per user request)
308    // No rule-based fallback
309    Vec::new()
310}
311
312fn extract_memory_content(text: &str, keyword: &str) -> String {
313    let text_lower = text.to_lowercase();
314    let keyword_lower = keyword.to_lowercase();
315
316    let pos = match text_lower.find(&keyword_lower) {
317        Some(p) => p,
318        None => return String::new(),
319    };
320
321    // Find sentence containing the keyword
322    let start = text[..pos]
323        .rfind(['.', '。', '\n'])
324        .map(|i| i + 1)
325        .unwrap_or(0);
326
327    let end = text[pos..]
328        .find(['.', '。', '\n'])
329        .map(|i| pos + i + 1)
330        .unwrap_or(text.len());
331
332    let sentence = text[start..end].trim();
333
334    if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
335        sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
336    } else {
337        sentence.to_string()
338    }
339}
340
341/// Infer category from content.
342pub fn infer_category_from_content(content: &str) -> MemoryCategory {
343    let lower = content.to_lowercase();
344
345    if lower.contains("决定")
346        || lower.contains("选择")
347        || lower.contains("采用")
348        || lower.contains("decided")
349    {
350        return MemoryCategory::Decision;
351    }
352    if lower.contains("喜欢")
353        || lower.contains("偏好")
354        || lower.contains("习惯")
355        || lower.contains("prefer")
356    {
357        return MemoryCategory::Preference;
358    }
359    if lower.contains("解决")
360        || lower.contains("修复")
361        || lower.contains("搞定")
362        || lower.contains("fixed")
363    {
364        return MemoryCategory::Solution;
365    }
366    if lower.contains("发现")
367        || lower.contains("原因")
368        || lower.contains("原来")
369        || lower.contains("found")
370    {
371        return MemoryCategory::Finding;
372    }
373    if lower.contains("技术")
374        || lower.contains("框架")
375        || lower.contains("库")
376        || lower.contains("tech")
377    {
378        return MemoryCategory::Technical;
379    }
380    if lower.contains("文件")
381        || lower.contains("目录")
382        || lower.contains("入口")
383        || lower.contains("file")
384    {
385        return MemoryCategory::Structure;
386    }
387
388    MemoryCategory::Finding // Default
389}