Skip to main content

matrixcode_core/memory/
retrieval.rs

1//! Retrieval helpers: TF-IDF search, semantic aliases, keyword extraction.
2
3use std::collections::{HashMap, HashSet};
4
5use super::config::*;
6use super::types::{AutoMemory, MemoryEntry};
7
8// ============================================================================
9// Keyword Extraction
10// ============================================================================
11
12/// Extract meaningful keywords from conversation context.
13/// Filters out common stop words and short tokens.
14pub fn extract_context_keywords(context: &str) -> Vec<String> {
15    // Common stop words (Chinese + English)
16    let stop_words: HashSet<&str> = [
17        // Chinese stop words
18        "的", "了", "是", "在", "我", "有", "和", "就", "不", "人", "都", "一", "一个", "上", "也",
19        "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好", "自己", "这", "他",
20        "她", "它", "们", "那", "些", "什么", "怎么", "如何", "请", "能", "可以", "需要", "应该",
21        "可能", "因为", "所以", "但是", "然后", "还是", "已经", "正在", "将要", "曾经", "一下",
22        "一点", "一些", "所有", "每个", "任何", // English stop words
23        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
24        "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "shall",
25        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
26        "during", "before", "after", "above", "below", "between", "and", "but", "or", "not", "no",
27        "so", "if", "then", "than", "too", "very", "just", "this", "that", "these", "those", "it",
28        "its", "i", "me", "my", "we", "our", "you", "your", "he", "his", "she", "her", "they",
29        "their", "please", "help", "need", "want", "make", "get", "let", "use",
30    ]
31    .iter()
32    .copied()
33    .collect();
34
35    // Technical/meaningful patterns
36    let tech_patterns: HashSet<&str> = [
37        "api", "cli", "gui", "tui", "web", "http", "json", "xml", "sql", "db", "git", "npm",
38        "cargo", "rust", "js", "ts", "py", "go", "java", "cpp", "cpu", "gpu", "io", "fs", "os",
39        "ui", "ux", "ai", "ml", "dl", "rs", "js", "ts", "py", "go", "java", "c", "h", "cpp", "hpp",
40        "json", "yaml", "yml", "toml", "md", "txt", "html", "css", "scss", "bug", "fix", "add",
41        "new", "old", "use", "run", "build", "test", "code", "data", "file", "dir", "path", "name",
42        "type", "value",
43    ]
44    .iter()
45    .copied()
46    .collect();
47
48    let lower = context.to_lowercase();
49    let mut keywords: HashSet<String> = HashSet::new();
50
51    // 1. Extract English words
52    for word in lower.split_whitespace() {
53        let cleaned = word
54            .trim_matches(|c: char| !c.is_alphanumeric())
55            .to_string();
56        if cleaned.len() >= 2 && !stop_words.contains(cleaned.as_str()) {
57            keywords.insert(cleaned.clone());
58        }
59        if tech_patterns.contains(cleaned.as_str()) {
60            keywords.insert(cleaned);
61        }
62    }
63
64    // 2. Extract Chinese words/phrases
65    let chinese_chars: Vec<char> = lower
66        .chars()
67        .filter(|c| *c >= '\u{4E00}' && *c <= '\u{9FFF}')
68        .collect();
69
70    for window_size in 2..=4 {
71        if chinese_chars.len() >= window_size {
72            for window in chinese_chars.windows(window_size) {
73                let phrase: String = window.iter().collect();
74                let has_stop = stop_words.iter().any(|sw| phrase.contains(sw));
75                if !has_stop && phrase.len() >= window_size {
76                    keywords.insert(phrase);
77                }
78            }
79        }
80    }
81
82    // 3. Extract specific patterns
83    let patterns = [
84        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}",
85        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*",
86        r"[A-Z][a-z]+[A-Z][a-zA-Z]*",
87        r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",
88        r"[0-9]+[kKmMgGtT][bB]?",
89    ];
90
91    for pattern in patterns {
92        if let Ok(re) = regex::Regex::new(pattern) {
93            for cap in re.find_iter(&lower) {
94                keywords.insert(cap.as_str().to_string());
95            }
96        }
97    }
98
99    let mut result: Vec<String> = keywords.into_iter().collect();
100    result.sort_by_key(|b| std::cmp::Reverse(b.len()));
101    result.truncate(15);
102
103    result
104}
105
106/// Calculate word-based similarity between two strings (Jaccard coefficient).
107pub fn calculate_similarity(a: &str, b: &str) -> f64 {
108    AutoMemory::calculate_similarity(a, b)
109}
110
111// ============================================================================
112// Semantic Aliases
113// ============================================================================
114
/// Semantic alias pairs used by [`expand_semantic_keywords`] to widen keyword matching.
///
/// Each `(alias, target)` pair is applied in both directions: a keyword
/// containing `alias` gains `target`, and a keyword containing `target` gains
/// `alias`. Identity pairs such as `("redis", "redis")` only add something when
/// a keyword merely contains the term (e.g. "redis-server" gains "redis"). A
/// term may appear in several pairs; for example "登录" maps to both "login"
/// and "auth".
pub const SEMANTIC_ALIASES: &[(&str, &str)] = &[
    // Database related
    ("数据库", "database"),
    ("db", "database"),
    ("postgresql", "postgres"),
    ("mysql", "mysql"),
    ("mongodb", "mongo"),
    ("redis", "redis"),
    ("sqlite", "sqlite"),
    ("sql", "database"),
    // Frontend related
    ("前端", "frontend"),
    ("ui", "frontend"),
    ("界面", "frontend"),
    ("页面", "page"),
    ("组件", "component"),
    ("react", "react"),
    ("vue", "vue"),
    ("angular", "angular"),
    // Backend related
    ("后端", "backend"),
    ("api", "api"),
    ("接口", "api"),
    ("服务", "service"),
    ("server", "backend"),
    ("服务器", "backend"),
    // Framework/Language
    ("rust", "rust"),
    ("python", "python"),
    ("javascript", "js"),
    ("typescript", "ts"),
    ("java", "java"),
    ("go", "golang"),
    ("golang", "go"),
    ("c++", "cpp"),
    ("cpp", "c++"),
    ("nodejs", "node"),
    ("node", "nodejs"),
    // Tools
    ("编辑器", "editor"),
    ("ide", "editor"),
    ("vim", "vim"),
    ("vscode", "vscode"),
    ("emacs", "emacs"),
    // Config
    ("配置", "config"),
    ("设置", "config"),
    ("config", "config"),
    ("setting", "config"),
    // Structure
    ("目录", "directory"),
    ("文件", "file"),
    ("文件夹", "directory"),
    ("路径", "path"),
    ("模块", "module"),
    ("包", "package"),
    // Testing
    ("测试", "test"),
    ("test", "test"),
    ("单元测试", "unittest"),
    ("unittest", "test"),
    // Cache
    ("缓存", "cache"),
    ("cache", "cache"),
    // Auth ("登录" intentionally maps to both "login" and "auth")
    ("认证", "auth"),
    ("登录", "login"),
    ("auth", "auth"),
    ("登录", "auth"),
    // Performance
    ("性能", "performance"),
    ("优化", "optimize"),
    ("速度", "speed"),
    ("慢", "slow"),
    // Common verbs
    ("创建", "create"),
    ("删除", "delete"),
    ("修改", "modify"),
    ("添加", "add"),
    ("更新", "update"),
    ("查询", "query"),
];

/// Expand keywords with their semantic aliases from [`SEMANTIC_ALIASES`].
///
/// Each keyword is lowercased before matching. For every `(alias, target)`
/// pair, a keyword containing `alias` gains `target`, and a keyword containing
/// `target` gains `alias`.
///
/// How "containing" is decided:
/// - ASCII terms shorter than 4 bytes ("ui", "go", "db", "ide", "ts", ...)
///   must occur as a whole word, i.e. not flanked by an ASCII letter or digit.
///   This stops "build" expanding to "frontend", "algorithm" to "golang" and
///   "tests" to "typescript". As a consequence, compounds such as "mysql" no
///   longer expand via "sql".
/// - Longer ASCII terms match as plain substrings, so "databases" still
///   matches "database" and "authentication" still matches "auth".
/// - CJK terms also match as plain substrings, since CJK text has no word
///   delimiters.
///
/// Returns the original keywords plus all expansions, sorted and deduplicated.
pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
    // True when `term` occurs in `haystack` under the rules documented above.
    fn contains_term(haystack: &str, term: &str) -> bool {
        // Below this length, ASCII terms are too ambiguous to match as substrings.
        const MIN_SUBSTRING_LEN: usize = 4;

        if !term.is_ascii() || term.len() >= MIN_SUBSTRING_LEN {
            return haystack.contains(term);
        }

        let is_word_char = |c: char| c.is_ascii_alphanumeric();
        haystack.match_indices(term).any(|(start, _)| {
            // `start` and `start + term.len()` are char boundaries because
            // they delimit a match of `term`.
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + term.len()..].chars().next();
            !before.is_some_and(is_word_char) && !after.is_some_and(is_word_char)
        })
    }

    let mut expanded: Vec<String> = keywords.to_vec();

    for keyword in keywords {
        let kw_lower = keyword.to_lowercase();
        for (alias, target) in SEMANTIC_ALIASES {
            if contains_term(&kw_lower, alias) {
                expanded.push(target.to_string());
            }
            if contains_term(&kw_lower, target) {
                expanded.push(alias.to_string());
            }
        }
    }

    expanded.sort();
    expanded.dedup();
    expanded
}
219
220// ============================================================================
221// Relevance & Contradiction Detection
222// ============================================================================
223
224/// Compute relevance score of a memory entry to context keywords.
225/// Returns 0.0-1.0 where 1.0 means highly relevant.
226pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
227    if context_keywords.is_empty() {
228        return 0.0;
229    }
230
231    let expanded_keywords = expand_semantic_keywords(context_keywords);
232    let content_lower = entry.content.to_lowercase();
233
234    let matches = expanded_keywords
235        .iter()
236        .filter(|kw| content_lower.contains(&kw.to_lowercase()))
237        .count();
238
239    let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
240
241    let tag_matches = entry
242        .tags
243        .iter()
244        .filter(|tag| {
245            let tag_lower = tag.to_lowercase();
246            expanded_keywords.iter().any(|kw| {
247                tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
248            })
249        })
250        .count();
251
252    let tag_score = if tag_matches > 0 {
253        0.2 + (tag_matches as f64 * 0.05).min(0.1)
254    } else {
255        0.0
256    };
257
258    (keyword_score + tag_score).min(1.0)
259}
260
/// Heuristically decide whether `new` may contradict or supersede `old`.
///
/// Matching is case-insensitive: both texts are lowercased first, which
/// leaves Chinese unchanged. Returns `true` when any of the following holds:
/// 1. `new` contains an explicit change signal ("改用", "switched to",
///    "deprecated", ...);
/// 2. both texts contain the same decision verb ("使用", "采用", "using",
///    "adopted", ...);
/// 3. both texts contain the same preference verb ("偏好", "喜欢",
///    "prefer", "like").
///
/// This is plain substring matching. It does not check whether the two texts
/// talk about the same subject.
pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
    // Phrases in `new` that announce a change of decision. All are lowercase.
    const CHANGE_SIGNALS: &[&str] = &[
        "改用",
        "换成",
        "替换",
        "改为",
        "切换到",
        "迁移到",
        "不再使用",
        "弃用",
        "放弃",
        "取消",
        "switched to",
        "replaced",
        "migrated to",
        "changed to",
        "no longer",
        "deprecated",
        "abandoned",
    ];

    // Verbs recording a decision; both texts using one suggests competing decisions.
    const ACTION_VERBS: &[&str] = &[
        "决定使用",
        "选择使用",
        "采用",
        "使用",
        "decided to use",
        "chose",
        "using",
        "adopted",
    ];

    // Verbs recording a preference; both texts using one suggests competing preferences.
    const PREF_VERBS: &[&str] = &["偏好", "喜欢", "prefer", "like"];

    // Lowercase so that e.g. "Switched to", "Deprecated" and "Using" are recognised.
    let old = old.to_lowercase();
    let new = new.to_lowercase();

    CHANGE_SIGNALS.iter().any(|signal| new.contains(*signal))
        || ACTION_VERBS
            .iter()
            .chain(PREF_VERBS)
            .any(|verb| old.contains(*verb) && new.contains(*verb))
}
315
316// ============================================================================
317// AI Keyword Extraction (Hybrid)
318// ============================================================================
319
320/// Extract keywords using hybrid approach (rule-based + AI fallback).
321pub async fn extract_keywords_hybrid(
322    context: &str,
323    fast_provider: Option<&dyn crate::providers::Provider>,
324) -> Vec<String> {
325    // First try rule-based extraction
326    let rule_keywords = extract_context_keywords(context);
327
328    // Check if we need AI fallback
329    let mode = AiKeywordMode::from_env();
330    if mode.should_use_ai(rule_keywords.len()) && fast_provider.is_some() {
331        // Use AI for keyword extraction
332        if let Some(provider) = fast_provider {
333            let ai_keywords = extract_keywords_with_ai(context, provider).await;
334            if !ai_keywords.is_empty() {
335                return ai_keywords;
336            }
337        }
338    }
339
340    rule_keywords
341}
342
343/// Extract keywords using AI provider.
344async fn extract_keywords_with_ai(
345    context: &str,
346    provider: &dyn crate::providers::Provider,
347) -> Vec<String> {
348    use crate::providers::{ChatRequest, Message, MessageContent, Role};
349
350    let truncated = if context.len() > 2000 {
351        &context[..2000]
352    } else {
353        context
354    };
355
356    let prompt = format!(
357        "从以下对话内容中提取关键词(用于记忆检索),最多返回10个关键词,以逗号分隔:\n\n{}",
358        truncated
359    );
360
361    let request = ChatRequest {
362        messages: vec![Message {
363            role: Role::User,
364            content: MessageContent::Text(prompt),
365        }],
366        tools: vec![],
367        system: Some("你是一个关键词提取助手,返回关键词列表,不要其他解释。".to_string()),
368        think: false,
369        max_tokens: 100,
370        server_tools: vec![],
371        enable_caching: false,
372    };
373
374    let response = match provider.chat(request).await {
375        Ok(r) => r,
376        Err(_) => return Vec::new(),
377    };
378
379    let text = response
380        .content
381        .iter()
382        .filter_map(|block| {
383            if let crate::providers::ContentBlock::Text { text } = block {
384                Some(text.clone())
385            } else {
386                None
387            }
388        })
389        .collect::<Vec<_>>()
390        .join("");
391
392    text.split(',')
393        .map(|s| s.trim().to_string())
394        .filter(|s| s.len() >= 2)
395        .collect()
396}
397
398// ============================================================================
399// TF-IDF Search
400// ============================================================================
401
/// Keyword search over memory contents ranked by TF-IDF
/// (term frequency × inverse document frequency).
///
/// Gives similarity-style ranking without calling an AI model. Populate it with
/// `index`, then query it with `search` or `search_multi`.
pub struct TfIdfSearch {
    /// Document count used as N in the IDF formula.
    total_docs: usize,
    /// Per-document term frequencies (count / token total), keyed by document content.
    doc_word_freq: HashMap<String, HashMap<String, f32>>,
    /// Precomputed inverse document frequency for every indexed term.
    idf_cache: HashMap<String, f32>,
}
414
415impl TfIdfSearch {
416    /// Create a new TF-IDF search instance.
417    pub fn new() -> Self {
418        Self {
419            doc_word_freq: HashMap::new(),
420            total_docs: 0,
421            idf_cache: HashMap::new(),
422        }
423    }
424
425    /// Index all memories for TF-IDF search.
426    pub fn index(&mut self, memory: &AutoMemory) {
427        self.clear();
428        self.total_docs = memory.entries.len();
429
430        for entry in &memory.entries {
431            let words = self.tokenize(&entry.content);
432            let word_freq = self.compute_word_freq(&words);
433            self.doc_word_freq.insert(entry.content.clone(), word_freq);
434        }
435
436        self.compute_idf();
437    }
438
439    /// Tokenize text into words.
440    fn tokenize(&self, text: &str) -> Vec<String> {
441        let lower = text.to_lowercase();
442        let mut tokens = Vec::new();
443
444        for word in lower.split_whitespace() {
445            let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
446            if trimmed.len() > 1 {
447                tokens.push(trimmed.to_string());
448            }
449
450            let chars: Vec<char> = trimmed.chars().collect();
451            let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
452
453            if has_cjk {
454                for c in &chars {
455                    if Self::is_cjk(*c) {
456                        tokens.push(c.to_string());
457                    }
458                }
459                for window in chars.windows(2) {
460                    if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
461                        tokens.push(window.iter().collect::<String>());
462                    }
463                }
464            }
465        }
466
467        tokens
468    }
469
470    /// Check if a character is CJK.
471    fn is_cjk(c: char) -> bool {
472        matches!(c,
473            '\u{4E00}'..='\u{9FFF}' |
474            '\u{3400}'..='\u{4DBF}' |
475            '\u{F900}'..='\u{FAFF}' |
476            '\u{3000}'..='\u{303F}' |
477            '\u{3040}'..='\u{309F}' |
478            '\u{30A0}'..='\u{30FF}'
479        )
480    }
481
482    /// Compute word frequency in a document.
483    fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
484        let total = words.len() as f32;
485        let mut freq = HashMap::new();
486
487        for word in words {
488            *freq.entry(word.clone()).or_insert(0.0) += 1.0;
489        }
490
491        for (_, count) in freq.iter_mut() {
492            *count /= total;
493        }
494
495        freq
496    }
497
498    /// Compute IDF for all words.
499    fn compute_idf(&mut self) {
500        let mut word_doc_count: HashMap<String, usize> = HashMap::new();
501
502        for word_freq in &self.doc_word_freq {
503            for word in word_freq.1.keys() {
504                *word_doc_count.entry(word.clone()).or_insert(0) += 1;
505            }
506        }
507
508        for (word, count) in word_doc_count {
509            let idf = (self.total_docs as f32 / count as f32).ln();
510            self.idf_cache.insert(word, idf);
511        }
512    }
513
514    /// Search using TF-IDF similarity.
515    pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
516        let query_words = self.tokenize(query);
517        let query_freq = self.compute_word_freq(&query_words);
518
519        let mut results: Vec<(String, f32)> = Vec::new();
520
521        for (doc, doc_freq) in &self.doc_word_freq {
522            let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
523
524            if similarity > 0.0 {
525                results.push((doc.clone(), similarity));
526            }
527        }
528
529        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
530
531        if let Some(max) = limit {
532            results.into_iter().take(max).collect()
533        } else {
534            results
535        }
536    }
537
538    /// Search with multiple keywords.
539    pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
540        let mut doc_scores: HashMap<String, f64> = HashMap::new();
541
542        for keyword in keywords {
543            let results = self.search(keyword, None);
544            for (doc, score) in results {
545                *doc_scores.entry(doc).or_insert(0.0) += score as f64;
546            }
547        }
548
549        let num_keywords = keywords.len().max(1);
550        for (_, score) in doc_scores.iter_mut() {
551            *score /= num_keywords as f64;
552        }
553
554        let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
555        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
556
557        if let Some(max) = limit {
558            results.into_iter().take(max).collect()
559        } else {
560            results
561        }
562    }
563
564    /// Compute TF-IDF similarity.
565    fn compute_tfidf_similarity(
566        &self,
567        query_freq: &HashMap<String, f32>,
568        doc_freq: &HashMap<String, f32>,
569    ) -> f32 {
570        let mut similarity = 0.0;
571
572        for (word, tf_query) in query_freq {
573            if let Some(tf_doc) = doc_freq.get(word)
574                && let Some(idf) = self.idf_cache.get(word)
575            {
576                similarity += tf_query * idf * tf_doc * idf;
577            }
578        }
579
580        similarity
581    }
582
583    /// Clear all indices.
584    pub fn clear(&mut self) {
585        self.doc_word_freq.clear();
586        self.idf_cache.clear();
587        self.total_docs = 0;
588    }
589}
590
591impl Default for TfIdfSearch {
592    fn default() -> Self {
593        Self::new()
594    }
595}
596
#[cfg(test)]
mod tests {
    use super::super::types::{MemoryCategory, MemoryEntry};
    use super::*;

    /// Build a `Decision` memory entry with `content`, passing `None` as the
    /// third constructor argument.
    fn decision(content: &str) -> MemoryEntry {
        MemoryEntry::new(MemoryCategory::Decision, content.to_string(), None)
    }

    #[test]
    fn test_extract_keywords() {
        let found = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!found.is_empty());
    }

    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.iter().any(|kw| kw == "database"));
    }

    #[test]
    fn test_tfidf_search() {
        // Several documents are needed: IDF = ln(N / df), so a term that
        // appears in every document scores zero.
        let mut memory = AutoMemory::new();
        for content in [
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ] {
            memory.add(decision(content));
        }

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let results = tfidf.search("数据库", Some(5));
        let (top_doc, _) = results
            .first()
            .expect("query should match at least one document");

        // The PostgreSQL document should rank first.
        assert!(top_doc.contains("PostgreSQL"));
    }
}