Skip to main content

matrixcode_core/memory/
retrieval.rs

1//! Retrieval helpers: TF-IDF search, semantic aliases, keyword extraction.
2
3use std::collections::{HashMap, HashSet};
4
5use super::config::*;
6use super::keywords_config::KeywordsConfig;
7use super::types::{AutoMemory, MemoryEntry};
8
9// ============================================================================
10// Keyword Extraction (uses KeywordsConfig)
11// ============================================================================
12
13/// Extract meaningful keywords from conversation context.
14/// Uses KeywordsConfig for stop words and tech keywords.
15pub fn extract_context_keywords(context: &str) -> Vec<String> {
16    let config = KeywordsConfig::load();
17    let stop_words = config.get_stop_words_set();
18    let tech_patterns = config.get_tech_keywords_set();
19
20    let lower = context.to_lowercase();
21    let mut keywords: HashSet<String> = HashSet::new();
22
23    // 1. Extract English words
24    for word in lower.split_whitespace() {
25        let cleaned = word
26            .trim_matches(|c: char| !c.is_alphanumeric())
27            .to_string();
28        if cleaned.len() >= 2 && !stop_words.contains(cleaned.as_str()) {
29            keywords.insert(cleaned.clone());
30        }
31        if tech_patterns.contains(cleaned.as_str()) {
32            keywords.insert(cleaned);
33        }
34    }
35
36    // 2. Extract Chinese words/phrases
37    let chinese_chars: Vec<char> = lower
38        .chars()
39        .filter(|c| *c >= '\u{4E00}' && *c <= '\u{9FFF}')
40        .collect();
41
42    for window_size in 2..=4 {
43        if chinese_chars.len() >= window_size {
44            for window in chinese_chars.windows(window_size) {
45                let phrase: String = window.iter().collect();
46                let has_stop = stop_words.iter().any(|sw| phrase.contains(sw));
47                if !has_stop && phrase.len() >= window_size {
48                    keywords.insert(phrase);
49                }
50            }
51        }
52    }
53
54    // 3. Extract specific patterns
55    let patterns = [
56        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}",
57        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*",
58        r"[A-Z][a-z]+[A-Z][a-zA-Z]*",
59        r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",
60        r"[0-9]+[kKmMgGtT][bB]?",
61    ];
62
63    for pattern in patterns {
64        if let Ok(re) = regex::Regex::new(pattern) {
65            for cap in re.find_iter(&lower) {
66                keywords.insert(cap.as_str().to_string());
67            }
68        }
69    }
70
71    let mut result: Vec<String> = keywords.into_iter().collect();
72    result.sort_by_key(|b| std::cmp::Reverse(b.len()));
73    result.truncate(15);
74
75    result
76}
77
/// Calculate word-based similarity between two strings (Jaccard coefficient).
///
/// Module-level convenience wrapper: both arguments are forwarded unchanged to
/// [`AutoMemory::calculate_similarity`] and its result is returned as-is.
/// NOTE(review): the scoring itself lives in that method; confirm the exact
/// tokenization and value range there before relying on them.
pub fn calculate_similarity(a: &str, b: &str) -> f64 {
    AutoMemory::calculate_similarity(a, b)
}
82
83// ============================================================================
84// Semantic Aliases (uses KeywordsConfig)
85// ============================================================================
86
87/// Get semantic aliases from KeywordsConfig.
88pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
89    // Note: This returns static references for compatibility
90    // For dynamic config, use KeywordsConfig::load().get_aliases()
91    SEMANTIC_ALIASES_DEFAULT.to_vec()
92}
93
/// Default semantic aliases (embedded for fallback).
///
/// Each entry is an `(alias, target)` pair of lowercase strings, grouped by
/// topic. The table is a plain list, not a map:
///
/// - an alias may appear more than once with different targets
///   (e.g. `"登录"` maps to both `"login"` and `"auth"`);
/// - some pairs point at each other in both directions
///   (e.g. `"go"`/`"golang"`, `"cpp"`/`"c++"`, `"node"`/`"nodejs"`);
/// - some pairs map a term to itself (e.g. `("mysql", "mysql")`).
pub const SEMANTIC_ALIASES_DEFAULT: &[(&str, &str)] = &[
    // Database related
    ("数据库", "database"),
    ("db", "database"),
    ("postgresql", "postgres"),
    ("mysql", "mysql"),
    ("mongodb", "mongo"),
    ("redis", "redis"),
    ("sqlite", "sqlite"),
    ("sql", "database"),
    // Frontend related
    ("前端", "frontend"),
    ("ui", "frontend"),
    ("界面", "frontend"),
    ("页面", "page"),
    ("组件", "component"),
    ("react", "react"),
    ("vue", "vue"),
    ("angular", "angular"),
    // Backend related
    ("后端", "backend"),
    ("api", "api"),
    ("接口", "api"),
    ("服务", "service"),
    ("server", "backend"),
    ("服务器", "backend"),
    // Framework/Language
    ("rust", "rust"),
    ("python", "python"),
    ("javascript", "js"),
    ("typescript", "ts"),
    ("java", "java"),
    ("go", "golang"),
    ("golang", "go"),
    ("c++", "cpp"),
    ("cpp", "c++"),
    ("nodejs", "node"),
    ("node", "nodejs"),
    // Tools
    ("编辑器", "editor"),
    ("ide", "editor"),
    ("vim", "vim"),
    ("vscode", "vscode"),
    ("emacs", "emacs"),
    // Config
    ("配置", "config"),
    ("设置", "config"),
    ("config", "config"),
    ("setting", "config"),
    // Structure
    ("目录", "directory"),
    ("文件", "file"),
    ("文件夹", "directory"),
    ("路径", "path"),
    ("模块", "module"),
    ("包", "package"),
    // Testing
    ("测试", "test"),
    ("test", "test"),
    ("单元测试", "unittest"),
    ("unittest", "test"),
    // Cache
    ("缓存", "cache"),
    ("cache", "cache"),
    // Auth
    ("认证", "auth"),
    ("登录", "login"),
    ("auth", "auth"),
    ("登录", "auth"),
    // Performance
    ("性能", "performance"),
    ("优化", "optimize"),
    ("速度", "speed"),
    ("慢", "slow"),
    // Common verbs
    ("创建", "create"),
    ("删除", "delete"),
    ("修改", "modify"),
    ("添加", "add"),
    ("更新", "update"),
    ("查询", "query"),
];
177
178/// Expand keywords with semantic aliases from KeywordsConfig.
179pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
180    let config = KeywordsConfig::load();
181    let mut expanded: Vec<String> = keywords.to_vec();
182
183    for keyword in keywords {
184        let kw_lower = keyword.to_lowercase();
185        for (alias, target) in config.get_aliases() {
186            if kw_lower.contains(alias) {
187                expanded.push(target.to_string());
188            }
189            if kw_lower.contains(target) {
190                expanded.push(alias.to_string());
191            }
192        }
193    }
194
195    expanded.sort();
196    expanded.dedup();
197    expanded
198}
199
200// ============================================================================
201// Relevance & Contradiction Detection (uses KeywordsConfig)
202// ============================================================================
203
204/// Compute relevance score of a memory entry to context keywords.
205/// Returns 0.0-1.0 where 1.0 means highly relevant.
206pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
207    if context_keywords.is_empty() {
208        return 0.0;
209    }
210
211    let expanded_keywords = expand_semantic_keywords(context_keywords);
212    let content_lower = entry.content.to_lowercase();
213
214    let matches = expanded_keywords
215        .iter()
216        .filter(|kw| content_lower.contains(&kw.to_lowercase()))
217        .count();
218
219    let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
220
221    let tag_matches = entry
222        .tags
223        .iter()
224        .filter(|tag| {
225            let tag_lower = tag.to_lowercase();
226            expanded_keywords.iter().any(|kw| {
227                tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
228            })
229        })
230        .count();
231
232    let tag_score = if tag_matches > 0 {
233        0.2 + (tag_matches as f64 * 0.05).min(0.1)
234    } else {
235        0.0
236    };
237
238    (keyword_score + tag_score).min(1.0)
239}
240
241/// Detect if two memory contents have contradiction signals.
242/// Uses KeywordsConfig for contradiction signals.
243pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
244    let config = KeywordsConfig::load();
245
246    // Check contradiction signals from config
247    for signal in &config.contradiction_signals {
248        if new.contains(signal) {
249            return true;
250        }
251    }
252
253    // Check action verbs that indicate change
254    let action_verbs = [
255        "决定使用",
256        "选择使用",
257        "采用",
258        "使用",
259        "decided to use",
260        "chose",
261        "using",
262        "adopted",
263    ];
264
265    for verb in &action_verbs {
266        if old.contains(verb) && new.contains(verb) {
267            return true;
268        }
269    }
270
271    // Check preference verbs
272    let pref_verbs = ["偏好", "喜欢", "prefer", "like"];
273    for verb in &pref_verbs {
274        if old.contains(verb) && new.contains(verb) {
275            return true;
276        }
277    }
278
279    false
280}
281
282// ============================================================================
283// AI Keyword Extraction (Hybrid)
284// ============================================================================
285
286/// Extract keywords using hybrid approach (rule-based + AI fallback).
287pub async fn extract_keywords_hybrid(
288    context: &str,
289    fast_provider: Option<&dyn crate::providers::Provider>,
290) -> Vec<String> {
291    // First try rule-based extraction
292    let rule_keywords = extract_context_keywords(context);
293
294    // Check if we need AI fallback
295    let mode = AiKeywordMode::from_env();
296    if mode.should_use_ai(rule_keywords.len()) && fast_provider.is_some() {
297        // Use AI for keyword extraction
298        if let Some(provider) = fast_provider {
299            let ai_keywords = extract_keywords_with_ai(context, provider).await;
300            if !ai_keywords.is_empty() {
301                return ai_keywords;
302            }
303        }
304    }
305
306    rule_keywords
307}
308
309/// Extract keywords using AI provider.
310async fn extract_keywords_with_ai(
311    context: &str,
312    provider: &dyn crate::providers::Provider,
313) -> Vec<String> {
314    use crate::providers::{ChatRequest, Message, MessageContent, Role};
315
316    let truncated = if context.len() > 2000 {
317        &context[..2000]
318    } else {
319        context
320    };
321
322    let prompt = format!(
323        "从以下对话内容中提取关键词(用于记忆检索),最多返回10个关键词,以逗号分隔:\n\n{}",
324        truncated
325    );
326
327    let request = ChatRequest {
328        messages: vec![Message {
329            role: Role::User,
330            content: MessageContent::Text(prompt),
331        }],
332        tools: vec![],
333        system: Some("你是一个关键词提取助手,返回关键词列表,不要其他解释。".to_string()),
334        think: false,
335        max_tokens: 100,
336        server_tools: vec![],
337        enable_caching: false,
338    };
339
340    let response = match provider.chat(request).await {
341        Ok(r) => r,
342        Err(_) => return Vec::new(),
343    };
344
345    let text = response
346        .content
347        .iter()
348        .filter_map(|block| {
349            if let crate::providers::ContentBlock::Text { text } = block {
350                Some(text.clone())
351            } else {
352                None
353            }
354        })
355        .collect::<Vec<_>>()
356        .join("");
357
358    text.split(',')
359        .map(|s| s.trim().to_string())
360        .filter(|s| s.len() >= 2)
361        .collect()
362}
363
364// ============================================================================
365// TF-IDF Search
366// ============================================================================
367
368/// Semantic search using TF-IDF algorithm.
369///
370/// TF-IDF (Term Frequency-Inverse Document Frequency) is a
371/// semantic search method without needing an AI model.
372pub struct TfIdfSearch {
373    /// Word frequency in each document.
374    doc_word_freq: HashMap<String, HashMap<String, f32>>,
375    /// Total documents.
376    total_docs: usize,
377    /// IDF cache.
378    idf_cache: HashMap<String, f32>,
379}
380
381impl TfIdfSearch {
382    /// Create a new TF-IDF search instance.
383    pub fn new() -> Self {
384        Self {
385            doc_word_freq: HashMap::new(),
386            total_docs: 0,
387            idf_cache: HashMap::new(),
388        }
389    }
390
391    /// Index all memories for TF-IDF search.
392    pub fn index(&mut self, memory: &AutoMemory) {
393        self.clear();
394        self.total_docs = memory.entries.len();
395
396        for entry in &memory.entries {
397            let words = self.tokenize(&entry.content);
398            let word_freq = self.compute_word_freq(&words);
399            self.doc_word_freq.insert(entry.content.clone(), word_freq);
400        }
401
402        self.compute_idf();
403    }
404
405    /// Tokenize text into words.
406    fn tokenize(&self, text: &str) -> Vec<String> {
407        let lower = text.to_lowercase();
408        let mut tokens = Vec::new();
409
410        for word in lower.split_whitespace() {
411            let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
412            if trimmed.len() > 1 {
413                tokens.push(trimmed.to_string());
414            }
415
416            let chars: Vec<char> = trimmed.chars().collect();
417            let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
418
419            if has_cjk {
420                for c in &chars {
421                    if Self::is_cjk(*c) {
422                        tokens.push(c.to_string());
423                    }
424                }
425                for window in chars.windows(2) {
426                    if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
427                        tokens.push(window.iter().collect::<String>());
428                    }
429                }
430            }
431        }
432
433        tokens
434    }
435
436    /// Check if a character is CJK.
437    fn is_cjk(c: char) -> bool {
438        matches!(c,
439            '\u{4E00}'..='\u{9FFF}' |
440            '\u{3400}'..='\u{4DBF}' |
441            '\u{F900}'..='\u{FAFF}' |
442            '\u{3000}'..='\u{303F}' |
443            '\u{3040}'..='\u{309F}' |
444            '\u{30A0}'..='\u{30FF}'
445        )
446    }
447
448    /// Compute word frequency in a document.
449    fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
450        let total = words.len() as f32;
451        let mut freq = HashMap::new();
452
453        for word in words {
454            *freq.entry(word.clone()).or_insert(0.0) += 1.0;
455        }
456
457        for (_, count) in freq.iter_mut() {
458            *count /= total;
459        }
460
461        freq
462    }
463
464    /// Compute IDF for all words.
465    fn compute_idf(&mut self) {
466        let mut word_doc_count: HashMap<String, usize> = HashMap::new();
467
468        for word_freq in &self.doc_word_freq {
469            for word in word_freq.1.keys() {
470                *word_doc_count.entry(word.clone()).or_insert(0) += 1;
471            }
472        }
473
474        for (word, count) in word_doc_count {
475            let idf = (self.total_docs as f32 / count as f32).ln();
476            self.idf_cache.insert(word, idf);
477        }
478    }
479
480    /// Search using TF-IDF similarity.
481    pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
482        let query_words = self.tokenize(query);
483        let query_freq = self.compute_word_freq(&query_words);
484
485        let mut results: Vec<(String, f32)> = Vec::new();
486
487        for (doc, doc_freq) in &self.doc_word_freq {
488            let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
489
490            if similarity > 0.0 {
491                results.push((doc.clone(), similarity));
492            }
493        }
494
495        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
496
497        if let Some(max) = limit {
498            results.into_iter().take(max).collect()
499        } else {
500            results
501        }
502    }
503
504    /// Search with multiple keywords.
505    pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
506        let mut doc_scores: HashMap<String, f64> = HashMap::new();
507
508        for keyword in keywords {
509            let results = self.search(keyword, None);
510            for (doc, score) in results {
511                *doc_scores.entry(doc).or_insert(0.0) += score as f64;
512            }
513        }
514
515        let num_keywords = keywords.len().max(1);
516        for (_, score) in doc_scores.iter_mut() {
517            *score /= num_keywords as f64;
518        }
519
520        let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
521        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
522
523        if let Some(max) = limit {
524            results.into_iter().take(max).collect()
525        } else {
526            results
527        }
528    }
529
530    /// Compute TF-IDF similarity.
531    fn compute_tfidf_similarity(
532        &self,
533        query_freq: &HashMap<String, f32>,
534        doc_freq: &HashMap<String, f32>,
535    ) -> f32 {
536        let mut similarity = 0.0;
537
538        for (word, tf_query) in query_freq {
539            if let Some(tf_doc) = doc_freq.get(word)
540                && let Some(idf) = self.idf_cache.get(word)
541            {
542                similarity += tf_query * idf * tf_doc * idf;
543            }
544        }
545
546        similarity
547    }
548
549    /// Clear all indices.
550    pub fn clear(&mut self) {
551        self.doc_word_freq.clear();
552        self.idf_cache.clear();
553        self.total_docs = 0;
554    }
555}
556
557impl Default for TfIdfSearch {
558    fn default() -> Self {
559        Self::new()
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed Chinese/English input must yield at least one keyword.
    #[test]
    fn test_extract_keywords() {
        let keywords = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!keywords.is_empty());
    }

    /// A Chinese alias must expand to its English target.
    #[test]
    fn test_semantic_aliases() {
        let keywords = vec!["数据库".to_string()];
        let expanded = expand_semantic_keywords(&keywords);
        assert!(expanded.contains(&"database".to_string()));
    }

    /// A query must rank the only document that contains its terms first.
    #[test]
    fn test_tfidf_search() {
        let mut tfidf = TfIdfSearch::new();
        let mut memory = AutoMemory::new();

        // Index several documents so the query terms are specific to one of
        // them; a term present in every document carries little or no IDF
        // weight.
        memory.add(super::super::types::MemoryEntry::new(
            super::super::types::MemoryCategory::Decision,
            "使用 PostgreSQL 作为数据库".to_string(),
            None,
        ));
        memory.add(super::super::types::MemoryEntry::new(
            super::super::types::MemoryCategory::Decision,
            "前端使用 React 框架开发".to_string(),
            None,
        ));
        memory.add(super::super::types::MemoryEntry::new(
            super::super::types::MemoryCategory::Decision,
            "后端采用 Rust 编写".to_string(),
            None,
        ));

        tfidf.index(&memory);
        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());

        // The PostgreSQL document should be the top result
        assert!(results[0].0.contains("PostgreSQL"));
    }
}