Skip to main content

matrixcode_core/memory/
retrieval.rs

1//! Retrieval helpers: TF-IDF search, keyword extraction.
2
3use std::collections::{HashMap, HashSet};
4
5use super::entry::MemoryEntry;
6use super::manager::AutoMemory;
7
8// ============================================================================
9// Hard-coded Stop Words (simplified)
10// ============================================================================
11
/// Build the minimal built-in stop-word set (Chinese and English).
///
/// Keyword extraction discards any candidate contained in this set. A new
/// `HashSet` is constructed on every call.
fn get_stop_words() -> HashSet<&'static str> {
    const CHINESE: &[&str] = &[
        "的", "了", "是", "在", "我", "有", "和", "就", "不", "都", "一", "也", "很", "到", "要",
        "去", "你", "会", "着", "没有", "看", "好", "这", "那", "什么", "怎么", "请", "能", "可以",
        "需要",
    ];
    const ENGLISH: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "have", "has", "do", "will", "would",
        "could", "should", "can", "to", "of", "in", "for", "on", "with", "at", "by", "from", "and",
        "but", "or", "not", "if", "then", "this", "that", "it", "i", "me", "my", "we", "you", "he",
        "she", "they", "please", "help", "need", "want", "let", "use",
    ];

    CHINESE.iter().chain(ENGLISH).copied().collect()
}
25
26// ============================================================================
27// Keyword Extraction (simplified)
28// ============================================================================
29
30/// Extract meaningful keywords from conversation context.
31pub fn extract_context_keywords(context: &str) -> Vec<String> {
32    let stop_words = get_stop_words();
33    let lower = context.to_lowercase();
34    let mut keywords: HashSet<String> = HashSet::new();
35
36    // 1. Extract English words (must be meaningful - at least 3 chars)
37    for word in lower.split_whitespace() {
38        let cleaned = word
39            .trim_matches(|c: char| !c.is_alphanumeric())
40            .to_string();
41        if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
42            keywords.insert(cleaned);
43        }
44    }
45
46    // 2. Extract tech patterns (camelCase, snake_case, file paths)
47    let tech_regexes = [
48        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}", // file extensions
49        r"[A-Z][a-z]+[A-Z][a-zA-Z]*",             // CamelCase
50        r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",        // snake_case
51        r"[0-9]+[kKmMgGtT][bB]?",                 // sizes like 4KB
52    ];
53
54    for pattern in tech_regexes {
55        if let Ok(re) = regex::Regex::new(pattern) {
56            for cap in re.find_iter(&lower) {
57                let match_str = cap.as_str();
58                if !stop_words.contains(match_str) {
59                    keywords.insert(match_str.to_string());
60                }
61            }
62        }
63    }
64
65    // Sort by length and limit
66    let mut result: Vec<String> = keywords.into_iter().collect();
67    result.sort_by_key(|b| std::cmp::Reverse(b.len()));
68    result.truncate(10);
69    result
70}
71
72// ============================================================================
73// Skip Simple Messages (Greeting/Short)
74// ============================================================================
75
/// Greeting and filler prefixes that mark a message as not worth keyword
/// extraction. They are compared against the trimmed, lowercased message.
const GREETING_PATTERNS: &[&str] = &[
    "你好",
    "您好",
    "hi",
    "hello",
    "hey",
    "嗨",
    "早上好",
    "下午好",
    "晚上好",
    "good morning",
    "good afternoon",
    "good evening",
    "请问",
    "帮忙",
    "帮我",
    "帮我看",
    "看看",
    "help",
    "请",
    "开始",
    "start",
    "准备好了",
    "ready",
];

/// Return `true` when `msg` is too simple to be worth AI keyword extraction.
///
/// After trimming, a message is skipped when either of these holds:
/// - It is shorter than 15 **bytes**, not characters. A CJK character is 3
///   bytes in UTF-8, so this is roughly 5 CJK characters.
/// - It starts with one of [`GREETING_PATTERNS`]. ASCII patterns must end at
///   a word boundary, so `"hi there"` is skipped but `"history of ..."` or
///   `"startup crash ..."` are not. Non-ASCII (CJK) patterns have no word
///   boundaries, so they match as plain prefixes.
pub fn should_skip_simple_message(msg: &str) -> bool {
    let trimmed = msg.trim();

    if trimmed.len() < 15 {
        return true;
    }

    let lower = trimmed.to_lowercase();
    GREETING_PATTERNS.iter().any(|pattern| {
        lower.strip_prefix(*pattern).is_some_and(|rest| {
            // For ASCII greetings, the next character must not continue the
            // word. Otherwise "hi" would swallow "history" and "help" would
            // swallow "helper".
            !pattern.is_ascii()
                || !rest.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_')
        })
    })
}
123
124/// Calculate word-based similarity between two strings (Jaccard coefficient).
125pub fn calculate_similarity(a: &str, b: &str) -> f64 {
126    AutoMemory::calculate_similarity(a, b)
127}
128
129// ============================================================================
// Semantic Aliases (hard-coded table)
131// ============================================================================
132
/// Return the built-in table of `(alias, canonical)` keyword pairs.
///
/// Lowercase or Chinese terms map to a canonical form, e.g. "数据库" and "db"
/// both map to "database", and "修复" maps to "fix". The table is consumed by
/// [`expand_semantic_keywords`]. It is hard-coded, and a fresh `Vec` is
/// returned on each call.
pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
    const ALIASES: &[(&str, &str)] = &[
        // Technical terms
        ("rust", "Rust"),
        ("typescript", "TypeScript"),
        ("javascript", "JavaScript"),
        ("python", "Python"),
        ("react", "React"),
        ("vue", "Vue"),
        ("angular", "Angular"),
        ("数据库", "database"),
        ("db", "database"),
        // Actions
        ("修复", "fix"),
        ("解决", "solve"),
        ("优化", "optimize"),
        ("重构", "refactor"),
        ("更新", "update"),
        ("删除", "delete"),
        // Preferences
        ("喜欢", "prefer"),
        ("偏好", "prefer"),
        ("首选", "prefer"),
        // Structures
        ("入口", "entry"),
        ("主文件", "main"),
        ("目录", "directory"),
    ];

    ALIASES.to_vec()
}
163
164/// Expand keywords with semantic aliases.
165pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
166    let aliases = get_semantic_aliases();
167    let mut expanded: Vec<String> = keywords.to_vec();
168
169    for keyword in keywords {
170        let kw_lower = keyword.to_lowercase();
171        for &(alias, target) in &aliases {
172            if kw_lower.contains(alias) {
173                expanded.push(target.to_string());
174            }
175            if kw_lower.contains(target) {
176                expanded.push(alias.to_string());
177            }
178        }
179    }
180
181    expanded.sort();
182    expanded.dedup();
183    expanded
184}
185
186// ============================================================================
// Relevance & Contradiction Detection (hard-coded heuristics)
188// ============================================================================
189
190/// Compute relevance score of a memory entry to context keywords.
191/// Returns 0.0-1.0 where 1.0 means highly relevant.
192pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
193    if context_keywords.is_empty() {
194        return 0.0;
195    }
196
197    let expanded_keywords = expand_semantic_keywords(context_keywords);
198    let content_lower = entry.content.to_lowercase();
199
200    let matches = expanded_keywords
201        .iter()
202        .filter(|kw| content_lower.contains(&kw.to_lowercase()))
203        .count();
204
205    let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
206
207    let tag_matches = entry
208        .tags
209        .iter()
210        .filter(|tag| {
211            let tag_lower = tag.to_lowercase();
212            expanded_keywords.iter().any(|kw| {
213                tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
214            })
215        })
216        .count();
217
218    let tag_score = if tag_matches > 0 {
219        0.2 + (tag_matches as f64 * 0.05).min(0.1)
220    } else {
221        0.0
222    };
223
224    (keyword_score + tag_score).min(1.0)
225}
226
/// Heuristically decide whether `new` may contradict or supersede `old`.
///
/// Matching is case-insensitive substring search. Returns `true` when any of
/// these holds:
/// 1. `new` contains an explicit change phrase such as "不再", "改为",
///    "no longer" or "switched to".
/// 2. Both texts contain the same decision verb (e.g. "使用", "采用",
///    "using", "adopted"), suggesting two decisions about the same subject.
/// 3. Both texts contain the same preference verb ("偏好", "喜欢", "prefer",
///    "like").
///
/// The check is deliberately loose: any two texts that both contain "using"
/// are flagged. Treat `true` as "worth a closer look", not as proof of a
/// contradiction.
pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
    /// Phrases in the new text that explicitly announce a change.
    const CHANGE_SIGNALS: &[&str] = &[
        "不再",
        "改为",
        "换成",
        "放弃",
        "no longer",
        "instead of",
        "changed to",
        "switched to",
    ];
    /// Decision verbs; appearing in both texts suggests competing decisions.
    const ACTION_VERBS: &[&str] = &[
        "决定使用",
        "选择使用",
        "采用",
        "使用",
        "decided to use",
        "chose",
        "using",
        "adopted",
    ];
    /// Preference verbs; appearing in both texts suggests competing preferences.
    const PREFERENCE_VERBS: &[&str] = &["偏好", "喜欢", "prefer", "like"];

    // Lowercase both inputs so that "No longer ..." or "Switched to ..."
    // match the lowercase English phrases. CJK text is unaffected.
    let old = old.to_lowercase();
    let new = new.to_lowercase();

    let in_both = |term: &str| old.contains(term) && new.contains(term);

    CHANGE_SIGNALS.iter().any(|&signal| new.contains(signal))
        || ACTION_VERBS.iter().any(|&verb| in_both(verb))
        || PREFERENCE_VERBS.iter().any(|&verb| in_both(verb))
}
277
278// ============================================================================
279// TF-IDF Search
280// ============================================================================
281
282/// Semantic search using TF-IDF algorithm.
283///
284/// TF-IDF (Term Frequency-Inverse Document Frequency) is a
285/// semantic search method without needing an AI model.
286pub struct TfIdfSearch {
287    /// Word frequency in each document.
288    doc_word_freq: HashMap<String, HashMap<String, f32>>,
289    /// Total documents.
290    total_docs: usize,
291    /// IDF cache.
292    idf_cache: HashMap<String, f32>,
293}
294
295impl TfIdfSearch {
296    /// Create a new TF-IDF search instance.
297    pub fn new() -> Self {
298        Self {
299            doc_word_freq: HashMap::new(),
300            total_docs: 0,
301            idf_cache: HashMap::new(),
302        }
303    }
304
305    /// Index all memories for TF-IDF search.
306    pub fn index(&mut self, memory: &AutoMemory) {
307        self.clear();
308        self.total_docs = memory.entries.len();
309
310        for entry in &memory.entries {
311            let words = self.tokenize(&entry.content);
312            let word_freq = self.compute_word_freq(&words);
313            self.doc_word_freq.insert(entry.content.clone(), word_freq);
314        }
315
316        self.compute_idf();
317    }
318
319    /// Tokenize text into words.
320    fn tokenize(&self, text: &str) -> Vec<String> {
321        let lower = text.to_lowercase();
322        let mut tokens = Vec::new();
323
324        for word in lower.split_whitespace() {
325            let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
326            if trimmed.len() > 1 {
327                tokens.push(trimmed.to_string());
328            }
329
330            let chars: Vec<char> = trimmed.chars().collect();
331            let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
332
333            if has_cjk {
334                for c in &chars {
335                    if Self::is_cjk(*c) {
336                        tokens.push(c.to_string());
337                    }
338                }
339                for window in chars.windows(2) {
340                    if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
341                        tokens.push(window.iter().collect::<String>());
342                    }
343                }
344            }
345        }
346
347        tokens
348    }
349
350    /// Check if a character is CJK.
351    fn is_cjk(c: char) -> bool {
352        matches!(c,
353            '\u{4E00}'..='\u{9FFF}' |
354            '\u{3400}'..='\u{4DBF}' |
355            '\u{F900}'..='\u{FAFF}' |
356            '\u{3000}'..='\u{303F}' |
357            '\u{3040}'..='\u{309F}' |
358            '\u{30A0}'..='\u{30FF}'
359        )
360    }
361
362    /// Compute word frequency in a document.
363    fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
364        let total = words.len() as f32;
365        let mut freq = HashMap::new();
366
367        for word in words {
368            *freq.entry(word.clone()).or_insert(0.0) += 1.0;
369        }
370
371        for (_, count) in freq.iter_mut() {
372            *count /= total;
373        }
374
375        freq
376    }
377
378    /// Compute IDF for all words.
379    fn compute_idf(&mut self) {
380        let mut word_doc_count: HashMap<String, usize> = HashMap::new();
381
382        for word_freq in &self.doc_word_freq {
383            for word in word_freq.1.keys() {
384                *word_doc_count.entry(word.clone()).or_insert(0) += 1;
385            }
386        }
387
388        for (word, count) in word_doc_count {
389            let idf = (self.total_docs as f32 / count as f32).ln();
390            self.idf_cache.insert(word, idf);
391        }
392    }
393
394    /// Search using TF-IDF similarity.
395    pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
396        let query_words = self.tokenize(query);
397        let query_freq = self.compute_word_freq(&query_words);
398
399        let mut results: Vec<(String, f32)> = Vec::new();
400
401        for (doc, doc_freq) in &self.doc_word_freq {
402            let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
403
404            if similarity > 0.0 {
405                results.push((doc.clone(), similarity));
406            }
407        }
408
409        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
410
411        if let Some(max) = limit {
412            results.into_iter().take(max).collect()
413        } else {
414            results
415        }
416    }
417
418    /// Search with multiple keywords.
419    pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
420        let mut doc_scores: HashMap<String, f64> = HashMap::new();
421
422        for keyword in keywords {
423            let results = self.search(keyword, None);
424            for (doc, score) in results {
425                *doc_scores.entry(doc).or_insert(0.0) += score as f64;
426            }
427        }
428
429        let num_keywords = keywords.len().max(1);
430        for (_, score) in doc_scores.iter_mut() {
431            *score /= num_keywords as f64;
432        }
433
434        let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
435        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
436
437        if let Some(max) = limit {
438            results.into_iter().take(max).collect()
439        } else {
440            results
441        }
442    }
443
444    /// Compute TF-IDF similarity.
445    fn compute_tfidf_similarity(
446        &self,
447        query_freq: &HashMap<String, f32>,
448        doc_freq: &HashMap<String, f32>,
449    ) -> f32 {
450        let mut similarity = 0.0;
451
452        for (word, tf_query) in query_freq {
453            if let Some(tf_doc) = doc_freq.get(word)
454                && let Some(idf) = self.idf_cache.get(word)
455            {
456                similarity += tf_query * idf * tf_doc * idf;
457            }
458        }
459
460        similarity
461    }
462
463    /// Clear all indices.
464    pub fn clear(&mut self) {
465        self.doc_word_freq.clear();
466        self.idf_cache.clear();
467        self.total_docs = 0;
468    }
469}
470
471impl Default for TfIdfSearch {
472    fn default() -> Self {
473        Self::new()
474    }
475}
476
477// ============================================================================
478// AI Memory Selection (Claude Code style)
479// ============================================================================
480
/// System prompt (in Chinese) sent by [`ai_select_memories`].
///
/// Roughly, it says: "You are choosing memories useful for handling the
/// user's query. You will receive the query and a list of available memory
/// files with descriptions." It then asks the model to:
/// - return the most useful memory indices (at most 5) as JSON;
/// - select only memories it is confident will help, and skip uncertain ones;
/// - return an empty array `[]` if nothing is clearly useful;
/// - prefer memories directly related to the current question.
///
/// NOTE(review): the text asks for a "JSON array" but its example reply is an
/// object, `{"selected": [0, 2, 5]}`. `parse_selected_indices` accepts both
/// shapes.
const SELECT_MEMORIES_SYSTEM_PROMPT: &str = r#"你正在选择对处理用户查询有用的记忆。你会收到用户的查询和可用记忆文件列表(包含描述)。

返回最有用的记忆索引列表(最多5个),以 JSON 数组格式返回。
- 只选择你确定会有帮助的记忆
- 如果不确定某个记忆是否有用,不要选择它
- 如果没有明显有用的记忆,可以返回空数组 []
- 优先选择与当前问题直接相关的记忆

返回格式示例:{"selected": [0, 2, 5]}
"#;
492
493/// Select relevant memories using AI (Claude Code style).
494///
495/// Takes user query and memory manifest (descriptions), uses AI to select
496/// the most relevant ones (up to 5).
497pub async fn ai_select_memories(
498    query: &str,
499    memory_manifest: &str,
500    provider: &dyn crate::providers::Provider,
501) -> Vec<usize> {
502    use crate::providers::{ChatRequest, Message, MessageContent, Role};
503
504    // Truncate query if too long
505    let truncated_query = if query.len() > 1000 {
506        &query[..1000]
507    } else {
508        query
509    };
510
511    let user_prompt = format!(
512        "查询: {}\n\n可用记忆列表:\n{}\n\n请选择最有用的记忆索引(最多5个):",
513        truncated_query, memory_manifest
514    );
515
516    let request = ChatRequest {
517        messages: vec![Message {
518            role: Role::User,
519            content: MessageContent::Text(user_prompt),
520        }],
521        tools: vec![],
522        system: Some(SELECT_MEMORIES_SYSTEM_PROMPT.to_string()),
523        think: false,
524        max_tokens: 100,
525        server_tools: vec![],
526        enable_caching: false,
527    };
528
529    let response = match provider.chat(request).await {
530        Ok(r) => r,
531        Err(_) => return Vec::new(),
532    };
533
534    // Extract text from response
535    let text = response
536        .content
537        .iter()
538        .filter_map(|block| {
539            if let crate::providers::ContentBlock::Text { text } = block {
540                Some(text.clone())
541            } else {
542                None
543            }
544        })
545        .collect::<Vec<_>>()
546        .join("");
547
548    // Parse JSON response
549    parse_selected_indices(&text)
550}
551
552/// Parse selected indices from AI response.
553fn parse_selected_indices(text: &str) -> Vec<usize> {
554    // Try to parse as JSON
555    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
556        if let Some(selected) = json.get("selected").and_then(|s| s.as_array()) {
557            return selected
558                .iter()
559                .filter_map(|v| v.as_u64().map(|n| n as usize))
560                .collect();
561        }
562        // Also try direct array format
563        if let Some(arr) = json.as_array() {
564            return arr
565                .iter()
566                .filter_map(|v| v.as_u64().map(|n| n as usize))
567                .collect();
568        }
569    }
570
571    // Fallback: try to extract numbers from text
572    let mut indices = Vec::new();
573    for part in text.split(',') {
574        let trimmed = part.trim();
575        if let Ok(n) = trimmed.parse::<usize>() {
576            indices.push(n);
577        }
578    }
579    indices
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryCategory;

    /// Build a `Decision` memory entry with no optional metadata.
    fn decision(content: &str) -> MemoryEntry {
        MemoryEntry::new(MemoryCategory::Decision, content.to_string(), None, None)
    }

    #[test]
    fn test_extract_keywords() {
        let found = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!found.is_empty());
    }

    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.iter().any(|k| k == "database"));
    }

    #[test]
    fn test_tfidf_search() {
        // Several documents, so that IDF weighting has something to
        // discriminate between.
        let mut memory = AutoMemory::new();
        for content in [
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ] {
            memory.add(decision(content));
        }

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let hits = tfidf.search("数据库", Some(5));
        assert!(!hits.is_empty());

        // The PostgreSQL document should rank first.
        assert!(hits[0].0.contains("PostgreSQL"));
    }
}