Skip to main content

matrixcode_core/memory/
retrieval.rs

1//! Retrieval helpers: TF-IDF search, semantic aliases, keyword extraction.
2
3use std::collections::{HashMap, HashSet};
4
5use super::config::*;
6use super::keywords_config::KeywordsConfig;
7use super::types::{AutoMemory, MemoryEntry};
8
9// ============================================================================
10// Keyword Extraction (uses KeywordsConfig)
11// ============================================================================
12
13/// Extract meaningful keywords from conversation context.
14/// Uses KeywordsConfig for stop words and tech keywords.
15/// Improved: avoids meaningless character fragments.
16pub fn extract_context_keywords(context: &str) -> Vec<String> {
17    let config = KeywordsConfig::load();
18    let stop_words = config.get_stop_words_set();
19    let tech_patterns = config.get_tech_keywords_set();
20
21    let lower = context.to_lowercase();
22    let mut keywords: HashSet<String> = HashSet::new();
23
24    // 1. Extract English words (must be meaningful - at least 3 chars)
25    for word in lower.split_whitespace() {
26        let cleaned = word
27            .trim_matches(|c: char| !c.is_alphanumeric())
28            .to_string();
29        // Only accept words at least 3 chars to avoid fragments like "ok", "go"
30        if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
31            keywords.insert(cleaned.clone());
32        }
33        // Always accept known tech patterns even if short
34        if tech_patterns.contains(cleaned.as_str()) {
35            keywords.insert(cleaned);
36        }
37    }
38
39    // 2. Extract meaningful Chinese phrases using known patterns from config
40    // Instead of sliding window (which produces nonsense fragments),
41    // use the predefined patterns for decision/preference/solution/finding
42    for category_patterns in config.patterns.values() {
43        for pattern in category_patterns {
44            if lower.contains(&pattern.to_lowercase()) {
45                keywords.insert(pattern.clone());
46            }
47        }
48    }
49
50    // 3. Extract known tech keywords from config
51    for kw in &config.tech_keywords {
52        if lower.contains(&kw.to_lowercase()) && !stop_words.contains(kw.as_str()) {
53            keywords.insert(kw.clone());
54        }
55    }
56
57    // 4. Extract specific tech patterns (camelCase, snake_case, file paths)
58    let tech_regexes = [
59        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}",       // file extensions like .rs, .ts
60        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*", // module.function
61        r"[A-Z][a-z]+[A-Z][a-zA-Z]*",                   // CamelCase
62        r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",              // snake_case
63        r"[0-9]+[kKmMgGtT][bB]?",                       // sizes like 4KB, 2MB
64        r"[a-zA-Z]+-[a-zA-Z]+",                         // hyphenated like react-dom
65    ];
66
67    for pattern in tech_regexes {
68        if let Ok(re) = regex::Regex::new(pattern) {
69            for cap in re.find_iter(&lower) {
70                let match_str = cap.as_str();
71                if !stop_words.contains(match_str) {
72                    keywords.insert(match_str.to_string());
73                }
74            }
75        }
76    }
77
78    // Sort by length (longer = more specific) and limit to 10
79    let mut result: Vec<String> = keywords.into_iter().collect();
80    result.sort_by_key(|b| std::cmp::Reverse(b.len()));
81    result.truncate(10);
82
83    result
84}
85
86/// Calculate word-based similarity between two strings (Jaccard coefficient).
87pub fn calculate_similarity(a: &str, b: &str) -> f64 {
88    AutoMemory::calculate_similarity(a, b)
89}
90
91// ============================================================================
92// Semantic Aliases (uses KeywordsConfig)
93// ============================================================================
94
95/// Get semantic aliases from KeywordsConfig.
96pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
97    // Note: This returns static references for compatibility
98    // For dynamic config, use KeywordsConfig::load().get_aliases()
99    SEMANTIC_ALIASES_DEFAULT.to_vec()
100}
101
/// Built-in fallback table of semantic aliases.
///
/// Each entry is an `(alias, target)` pair. The first element is a term that
/// may appear in user text (Chinese or English); the second is a related term
/// intended to widen keyword matching. [`get_semantic_aliases`] returns a copy
/// of this table in this exact order.
///
/// Maintenance notes:
/// - Several entries are identity mappings (e.g. `("mysql", "mysql")`).
/// - Some pairs are listed in both directions (`go`/`golang`, `c++`/`cpp`,
///   `nodejs`/`node`).
/// - `"登录"` (log in) appears twice, once mapping to `"login"` and once to `"auth"`.
pub const SEMANTIC_ALIASES_DEFAULT: &[(&str, &str)] = &[
    // --- Databases ---
    ("数据库", "database"), // "database"
    ("db", "database"),
    ("postgresql", "postgres"),
    ("mysql", "mysql"),
    ("mongodb", "mongo"),
    ("redis", "redis"),
    ("sqlite", "sqlite"),
    ("sql", "database"),
    // --- Frontend ---
    ("前端", "frontend"), // "frontend"
    ("ui", "frontend"),
    ("界面", "frontend"), // "interface / UI"
    ("页面", "page"),     // "page"
    ("组件", "component"), // "component"
    ("react", "react"),
    ("vue", "vue"),
    ("angular", "angular"),
    // --- Backend ---
    ("后端", "backend"), // "backend"
    ("api", "api"),
    ("接口", "api"),     // "interface / API"
    ("服务", "service"), // "service"
    ("server", "backend"),
    ("服务器", "backend"), // "server"
    // --- Languages and runtimes ---
    ("rust", "rust"),
    ("python", "python"),
    ("javascript", "js"),
    ("typescript", "ts"),
    ("java", "java"),
    ("go", "golang"),
    ("golang", "go"),
    ("c++", "cpp"),
    ("cpp", "c++"),
    ("nodejs", "node"),
    ("node", "nodejs"),
    // --- Editors and tools ---
    ("编辑器", "editor"), // "editor"
    ("ide", "editor"),
    ("vim", "vim"),
    ("vscode", "vscode"),
    ("emacs", "emacs"),
    // --- Configuration ---
    ("配置", "config"), // "configuration"
    ("设置", "config"), // "settings"
    ("config", "config"),
    ("setting", "config"),
    // --- Project structure ---
    ("目录", "directory"),   // "directory"
    ("文件", "file"),        // "file"
    ("文件夹", "directory"), // "folder"
    ("路径", "path"),        // "path"
    ("模块", "module"),      // "module"
    ("包", "package"),       // "package"
    // --- Testing ---
    ("测试", "test"), // "test"
    ("test", "test"),
    ("单元测试", "unittest"), // "unit test"
    ("unittest", "test"),
    // --- Caching ---
    ("缓存", "cache"), // "cache"
    ("cache", "cache"),
    // --- Authentication ---
    ("认证", "auth"),  // "authentication"
    ("登录", "login"), // "log in"
    ("auth", "auth"),
    ("登录", "auth"), // "log in" (second mapping)
    // --- Performance ---
    ("性能", "performance"), // "performance"
    ("优化", "optimize"),    // "optimize"
    ("速度", "speed"),       // "speed"
    ("慢", "slow"),          // "slow"
    // --- Common verbs ---
    ("创建", "create"), // "create"
    ("删除", "delete"), // "delete"
    ("修改", "modify"), // "modify"
    ("添加", "add"),    // "add"
    ("更新", "update"), // "update"
    ("查询", "query"),  // "query"
];
185
186/// Expand keywords with semantic aliases from KeywordsConfig.
187pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
188    let config = KeywordsConfig::load();
189    let mut expanded: Vec<String> = keywords.to_vec();
190
191    for keyword in keywords {
192        let kw_lower = keyword.to_lowercase();
193        for (alias, target) in config.get_aliases() {
194            if kw_lower.contains(alias) {
195                expanded.push(target.to_string());
196            }
197            if kw_lower.contains(target) {
198                expanded.push(alias.to_string());
199            }
200        }
201    }
202
203    expanded.sort();
204    expanded.dedup();
205    expanded
206}
207
208// ============================================================================
209// Relevance & Contradiction Detection (uses KeywordsConfig)
210// ============================================================================
211
212/// Compute relevance score of a memory entry to context keywords.
213/// Returns 0.0-1.0 where 1.0 means highly relevant.
214pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
215    if context_keywords.is_empty() {
216        return 0.0;
217    }
218
219    let expanded_keywords = expand_semantic_keywords(context_keywords);
220    let content_lower = entry.content.to_lowercase();
221
222    let matches = expanded_keywords
223        .iter()
224        .filter(|kw| content_lower.contains(&kw.to_lowercase()))
225        .count();
226
227    let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
228
229    let tag_matches = entry
230        .tags
231        .iter()
232        .filter(|tag| {
233            let tag_lower = tag.to_lowercase();
234            expanded_keywords.iter().any(|kw| {
235                tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
236            })
237        })
238        .count();
239
240    let tag_score = if tag_matches > 0 {
241        0.2 + (tag_matches as f64 * 0.05).min(0.1)
242    } else {
243        0.0
244    };
245
246    (keyword_score + tag_score).min(1.0)
247}
248
249/// Detect if two memory contents have contradiction signals.
250/// Uses KeywordsConfig for contradiction signals.
251pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
252    let config = KeywordsConfig::load();
253
254    // Check contradiction signals from config
255    for signal in &config.contradiction_signals {
256        if new.contains(signal) {
257            return true;
258        }
259    }
260
261    // Check action verbs that indicate change
262    let action_verbs = [
263        "决定使用",
264        "选择使用",
265        "采用",
266        "使用",
267        "decided to use",
268        "chose",
269        "using",
270        "adopted",
271    ];
272
273    for verb in &action_verbs {
274        if old.contains(verb) && new.contains(verb) {
275            return true;
276        }
277    }
278
279    // Check preference verbs
280    let pref_verbs = ["偏好", "喜欢", "prefer", "like"];
281    for verb in &pref_verbs {
282        if old.contains(verb) && new.contains(verb) {
283            return true;
284        }
285    }
286
287    false
288}
289
290// ============================================================================
291// AI Keyword Extraction (Hybrid)
292// ============================================================================
293
294/// Extract keywords using hybrid approach (rule-based + AI fallback).
295pub async fn extract_keywords_hybrid(
296    context: &str,
297    fast_provider: Option<&dyn crate::providers::Provider>,
298) -> Vec<String> {
299    // First try rule-based extraction
300    let rule_keywords = extract_context_keywords(context);
301
302    // Check if we need AI fallback
303    let mode = AiKeywordMode::from_env();
304    if mode.should_use_ai(rule_keywords.len()) && fast_provider.is_some() {
305        // Use AI for keyword extraction
306        if let Some(provider) = fast_provider {
307            let ai_keywords = extract_keywords_with_ai(context, provider).await;
308            if !ai_keywords.is_empty() {
309                return ai_keywords;
310            }
311        }
312    }
313
314    rule_keywords
315}
316
317/// Extract keywords using AI provider.
318async fn extract_keywords_with_ai(
319    context: &str,
320    provider: &dyn crate::providers::Provider,
321) -> Vec<String> {
322    use crate::providers::{ChatRequest, Message, MessageContent, Role};
323
324    let truncated = if context.len() > 2000 {
325        &context[..2000]
326    } else {
327        context
328    };
329
330    let prompt = format!(
331        "从以下对话内容中提取关键词(用于记忆检索),最多返回10个关键词,以逗号分隔:\n\n{}",
332        truncated
333    );
334
335    let request = ChatRequest {
336        messages: vec![Message {
337            role: Role::User,
338            content: MessageContent::Text(prompt),
339        }],
340        tools: vec![],
341        system: Some("你是一个关键词提取助手,返回关键词列表,不要其他解释。".to_string()),
342        think: false,
343        max_tokens: 100,
344        server_tools: vec![],
345        enable_caching: false,
346    };
347
348    let response = match provider.chat(request).await {
349        Ok(r) => r,
350        Err(_) => return Vec::new(),
351    };
352
353    let text = response
354        .content
355        .iter()
356        .filter_map(|block| {
357            if let crate::providers::ContentBlock::Text { text } = block {
358                Some(text.clone())
359            } else {
360                None
361            }
362        })
363        .collect::<Vec<_>>()
364        .join("");
365
366    text.split(',')
367        .map(|s| s.trim().to_string())
368        .filter(|s| s.len() >= 2)
369        .collect()
370}
371
372// ============================================================================
373// TF-IDF Search
374// ============================================================================
375
376/// Semantic search using TF-IDF algorithm.
377///
378/// TF-IDF (Term Frequency-Inverse Document Frequency) is a
379/// semantic search method without needing an AI model.
380pub struct TfIdfSearch {
381    /// Word frequency in each document.
382    doc_word_freq: HashMap<String, HashMap<String, f32>>,
383    /// Total documents.
384    total_docs: usize,
385    /// IDF cache.
386    idf_cache: HashMap<String, f32>,
387}
388
389impl TfIdfSearch {
390    /// Create a new TF-IDF search instance.
391    pub fn new() -> Self {
392        Self {
393            doc_word_freq: HashMap::new(),
394            total_docs: 0,
395            idf_cache: HashMap::new(),
396        }
397    }
398
399    /// Index all memories for TF-IDF search.
400    pub fn index(&mut self, memory: &AutoMemory) {
401        self.clear();
402        self.total_docs = memory.entries.len();
403
404        for entry in &memory.entries {
405            let words = self.tokenize(&entry.content);
406            let word_freq = self.compute_word_freq(&words);
407            self.doc_word_freq.insert(entry.content.clone(), word_freq);
408        }
409
410        self.compute_idf();
411    }
412
413    /// Tokenize text into words.
414    fn tokenize(&self, text: &str) -> Vec<String> {
415        let lower = text.to_lowercase();
416        let mut tokens = Vec::new();
417
418        for word in lower.split_whitespace() {
419            let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
420            if trimmed.len() > 1 {
421                tokens.push(trimmed.to_string());
422            }
423
424            let chars: Vec<char> = trimmed.chars().collect();
425            let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
426
427            if has_cjk {
428                for c in &chars {
429                    if Self::is_cjk(*c) {
430                        tokens.push(c.to_string());
431                    }
432                }
433                for window in chars.windows(2) {
434                    if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
435                        tokens.push(window.iter().collect::<String>());
436                    }
437                }
438            }
439        }
440
441        tokens
442    }
443
444    /// Check if a character is CJK.
445    fn is_cjk(c: char) -> bool {
446        matches!(c,
447            '\u{4E00}'..='\u{9FFF}' |
448            '\u{3400}'..='\u{4DBF}' |
449            '\u{F900}'..='\u{FAFF}' |
450            '\u{3000}'..='\u{303F}' |
451            '\u{3040}'..='\u{309F}' |
452            '\u{30A0}'..='\u{30FF}'
453        )
454    }
455
456    /// Compute word frequency in a document.
457    fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
458        let total = words.len() as f32;
459        let mut freq = HashMap::new();
460
461        for word in words {
462            *freq.entry(word.clone()).or_insert(0.0) += 1.0;
463        }
464
465        for (_, count) in freq.iter_mut() {
466            *count /= total;
467        }
468
469        freq
470    }
471
472    /// Compute IDF for all words.
473    fn compute_idf(&mut self) {
474        let mut word_doc_count: HashMap<String, usize> = HashMap::new();
475
476        for word_freq in &self.doc_word_freq {
477            for word in word_freq.1.keys() {
478                *word_doc_count.entry(word.clone()).or_insert(0) += 1;
479            }
480        }
481
482        for (word, count) in word_doc_count {
483            let idf = (self.total_docs as f32 / count as f32).ln();
484            self.idf_cache.insert(word, idf);
485        }
486    }
487
488    /// Search using TF-IDF similarity.
489    pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
490        let query_words = self.tokenize(query);
491        let query_freq = self.compute_word_freq(&query_words);
492
493        let mut results: Vec<(String, f32)> = Vec::new();
494
495        for (doc, doc_freq) in &self.doc_word_freq {
496            let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
497
498            if similarity > 0.0 {
499                results.push((doc.clone(), similarity));
500            }
501        }
502
503        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
504
505        if let Some(max) = limit {
506            results.into_iter().take(max).collect()
507        } else {
508            results
509        }
510    }
511
512    /// Search with multiple keywords.
513    pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
514        let mut doc_scores: HashMap<String, f64> = HashMap::new();
515
516        for keyword in keywords {
517            let results = self.search(keyword, None);
518            for (doc, score) in results {
519                *doc_scores.entry(doc).or_insert(0.0) += score as f64;
520            }
521        }
522
523        let num_keywords = keywords.len().max(1);
524        for (_, score) in doc_scores.iter_mut() {
525            *score /= num_keywords as f64;
526        }
527
528        let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
529        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
530
531        if let Some(max) = limit {
532            results.into_iter().take(max).collect()
533        } else {
534            results
535        }
536    }
537
538    /// Compute TF-IDF similarity.
539    fn compute_tfidf_similarity(
540        &self,
541        query_freq: &HashMap<String, f32>,
542        doc_freq: &HashMap<String, f32>,
543    ) -> f32 {
544        let mut similarity = 0.0;
545
546        for (word, tf_query) in query_freq {
547            if let Some(tf_doc) = doc_freq.get(word)
548                && let Some(idf) = self.idf_cache.get(word)
549            {
550                similarity += tf_query * idf * tf_doc * idf;
551            }
552        }
553
554        similarity
555    }
556
557    /// Clear all indices.
558    pub fn clear(&mut self) {
559        self.doc_word_freq.clear();
560        self.idf_cache.clear();
561        self.total_docs = 0;
562    }
563}
564
565impl Default for TfIdfSearch {
566    fn default() -> Self {
567        Self::new()
568    }
569}
570
#[cfg(test)]
mod tests {
    use super::super::types::MemoryCategory;
    use super::*;

    /// Build a `Decision`-category memory entry holding `content`.
    fn decision(content: &str) -> MemoryEntry {
        MemoryEntry::new(MemoryCategory::Decision, content.to_string(), None)
    }

    #[test]
    fn test_extract_keywords() {
        let extracted = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!extracted.is_empty());
    }

    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.iter().any(|kw| kw == "database"));
    }

    #[test]
    fn test_tfidf_search() {
        // IDF is ln(N / df), so the corpus needs several documents: a term that
        // appears in every document gets zero weight.
        let mut memory = AutoMemory::new();
        for content in [
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ] {
            memory.add(decision(content));
        }

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);
        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());

        // Only the PostgreSQL document mentions 数据库 (database), so it ranks first.
        let (top_doc, _score) = &results[0];
        assert!(top_doc.contains("PostgreSQL"));
    }
}