Skip to main content

matrixcode_core/memory/
retrieval.rs

//! Retrieval helpers: keyword extraction, relevance scoring, TF-IDF search, and AI-assisted memory selection.
2
3use std::collections::{HashMap, HashSet};
4
5use super::entry::MemoryEntry;
6use super::manager::AutoMemory;
7
8// ============================================================================
9// Hard-coded Stop Words (simplified)
10// ============================================================================
11
/// Build the stop-word set used to discard low-information tokens.
///
/// The set is a small hand-picked list of common Chinese and English function
/// words. It is rebuilt on every call.
fn get_stop_words() -> HashSet<&'static str> {
    const CHINESE: &[&str] = &[
        "的", "了", "是", "在", "我", "有", "和", "就", "不", "都", "一", "也", "很", "到", "要", "去",
        "你", "会", "着", "没有", "看", "好", "这", "那", "什么", "怎么", "请", "能", "可以", "需要",
    ];
    const ENGLISH: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "have", "has", "do", "will", "would",
        "could", "should", "can", "to", "of", "in", "for", "on", "with", "at", "by", "from", "and",
        "but", "or", "not", "if", "then", "this", "that", "it", "i", "me", "my", "we", "you", "he",
        "she", "they", "please", "help", "need", "want", "let", "use",
    ];

    CHINESE.iter().chain(ENGLISH).copied().collect()
}
25
26// ============================================================================
27// Keyword Extraction (simplified)
28// ============================================================================
29
30/// Extract meaningful keywords from conversation context.
31pub fn extract_context_keywords(context: &str) -> Vec<String> {
32    let stop_words = get_stop_words();
33    let lower = context.to_lowercase();
34    let mut keywords: HashSet<String> = HashSet::new();
35
36    // 1. Extract English words (must be meaningful - at least 3 chars)
37    for word in lower.split_whitespace() {
38        let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
39        if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
40            keywords.insert(cleaned);
41        }
42    }
43
44    // 2. Extract tech patterns (camelCase, snake_case, file paths)
45    let tech_regexes = [
46        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}",       // file extensions
47        r"[A-Z][a-z]+[A-Z][a-zA-Z]*",                   // CamelCase
48        r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",              // snake_case
49        r"[0-9]+[kKmMgGtT][bB]?",                       // sizes like 4KB
50    ];
51
52    for pattern in tech_regexes {
53        if let Ok(re) = regex::Regex::new(pattern) {
54            for cap in re.find_iter(&lower) {
55                let match_str = cap.as_str();
56                if !stop_words.contains(match_str) {
57                    keywords.insert(match_str.to_string());
58                }
59            }
60        }
61    }
62
63    // Sort by length and limit
64    let mut result: Vec<String> = keywords.into_iter().collect();
65    result.sort_by_key(|b| std::cmp::Reverse(b.len()));
66    result.truncate(10);
67    result
68}
69
70// ============================================================================
71// Skip Simple Messages (Greeting/Short)
72// ============================================================================
73
/// Greeting / opener prefixes that mark a message as not worth keyword extraction.
const GREETING_PATTERNS: &[&str] = &[
    "你好", "您好", "hi", "hello", "hey", "嗨", "早上好", "下午好", "晚上好",
    "good morning", "good afternoon", "good evening",
    "请问", "帮忙", "帮我", "帮我看", "看看", "help", "请",
    "开始", "start", "准备好了", "ready",
];

/// Return `true` when `text` opens with `pattern` used as a whole greeting.
///
/// A pattern that ends in an ASCII letter or digit only counts when the text
/// does not continue the same word, so "hi there" matches "hi" but "history"
/// does not. Patterns ending in any other character (the Chinese ones) have no
/// word boundary to check and match as a plain prefix.
fn starts_with_greeting(text: &str, pattern: &str) -> bool {
    let rest = match text.strip_prefix(pattern) {
        Some(rest) => rest,
        None => return false,
    };

    let pattern_ends_word = pattern
        .chars()
        .next_back()
        .is_some_and(|c| c.is_ascii_alphanumeric());
    let text_continues_word = rest
        .chars()
        .next()
        .is_some_and(|c| c.is_ascii_alphanumeric());

    !(pattern_ends_word && text_continues_word)
}

/// Check if a message is simple (greeting/short) and should skip AI keyword extraction.
///
/// Returns `true` (skip) when the trimmed message is shorter than 15 bytes, or
/// when its lowercased form starts with one of [`GREETING_PATTERNS`]. The length
/// is measured in UTF-8 bytes, so the cut-off is about 5 CJK characters.
pub fn should_skip_simple_message(msg: &str) -> bool {
    let trimmed = msg.trim();

    // Too short to carry useful keywords (byte length, not character count).
    if trimmed.len() < 15 {
        return true;
    }

    let lower = trimmed.to_lowercase();
    GREETING_PATTERNS
        .iter()
        .any(|pattern| starts_with_greeting(&lower, pattern))
}
102
103/// Calculate word-based similarity between two strings (Jaccard coefficient).
104pub fn calculate_similarity(a: &str, b: &str) -> f64 {
105    AutoMemory::calculate_similarity(a, b)
106}
107
108// ============================================================================
// Semantic Aliases (hard-coded)
110// ============================================================================
111
/// Return the built-in `(alias, target)` pairs used for keyword expansion.
///
/// The table is hard-coded and returned in a fixed order: technical terms,
/// actions, preferences, then project-structure words.
pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
    const ALIASES: &[(&str, &str)] = &[
        // Technical terms
        ("rust", "Rust"),
        ("typescript", "TypeScript"),
        ("javascript", "JavaScript"),
        ("python", "Python"),
        ("react", "React"),
        ("vue", "Vue"),
        ("angular", "Angular"),
        ("数据库", "database"),
        ("db", "database"),
        // Actions
        ("修复", "fix"),
        ("解决", "solve"),
        ("优化", "optimize"),
        ("重构", "refactor"),
        ("更新", "update"),
        ("删除", "delete"),
        // Preferences
        ("喜欢", "prefer"),
        ("偏好", "prefer"),
        ("首选", "prefer"),
        // Structures
        ("入口", "entry"),
        ("主文件", "main"),
        ("目录", "directory"),
    ];

    ALIASES.to_vec()
}
142
143/// Expand keywords with semantic aliases.
144pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
145    let aliases = get_semantic_aliases();
146    let mut expanded: Vec<String> = keywords.to_vec();
147
148    for keyword in keywords {
149        let kw_lower = keyword.to_lowercase();
150        for &(alias, target) in &aliases {
151            if kw_lower.contains(alias) {
152                expanded.push(target.to_string());
153            }
154            if kw_lower.contains(target) {
155                expanded.push(alias.to_string());
156            }
157        }
158    }
159
160    expanded.sort();
161    expanded.dedup();
162    expanded
163}
164
165// ============================================================================
// Relevance & Contradiction Detection (hard-coded signals)
167// ============================================================================
168
169/// Compute relevance score of a memory entry to context keywords.
170/// Returns 0.0-1.0 where 1.0 means highly relevant.
171pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
172    if context_keywords.is_empty() {
173        return 0.0;
174    }
175
176    let expanded_keywords = expand_semantic_keywords(context_keywords);
177    let content_lower = entry.content.to_lowercase();
178
179    let matches = expanded_keywords
180        .iter()
181        .filter(|kw| content_lower.contains(&kw.to_lowercase()))
182        .count();
183
184    let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
185
186    let tag_matches = entry
187        .tags
188        .iter()
189        .filter(|tag| {
190            let tag_lower = tag.to_lowercase();
191            expanded_keywords.iter().any(|kw| {
192                tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
193            })
194        })
195        .count();
196
197    let tag_score = if tag_matches > 0 {
198        0.2 + (tag_matches as f64 * 0.05).min(0.1)
199    } else {
200        0.0
201    };
202
203    (keyword_score + tag_score).min(1.0)
204}
205
/// Heuristically decide whether `new` may contradict or supersede `old`.
///
/// Returns `true` when any of the following holds (all checks are
/// case-sensitive substring matches against hard-coded lists):
/// - `new` contains an explicit change phrase such as "no longer" or "改为";
/// - `old` and `new` both contain the same action verb (e.g. "using", "采用");
/// - `old` and `new` both contain the same preference verb (e.g. "prefer").
pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
    const CHANGE_SIGNALS: [&str; 8] = [
        "不再",
        "改为",
        "换成",
        "放弃",
        "no longer",
        "instead of",
        "changed to",
        "switched to",
    ];
    const ACTION_VERBS: [&str; 8] = [
        "决定使用",
        "选择使用",
        "采用",
        "使用",
        "decided to use",
        "chose",
        "using",
        "adopted",
    ];
    const PREFERENCE_VERBS: [&str; 4] = ["偏好", "喜欢", "prefer", "like"];

    // True when some needle occurs in both the old and the new text.
    fn shared_by_both(old: &str, new: &str, needles: &[&str]) -> bool {
        needles
            .iter()
            .any(|needle| old.contains(*needle) && new.contains(*needle))
    }

    CHANGE_SIGNALS.iter().any(|signal| new.contains(*signal))
        || shared_by_both(old, new, &ACTION_VERBS)
        || shared_by_both(old, new, &PREFERENCE_VERBS)
}
256
257
258
259// ============================================================================
260// TF-IDF Search
261// ============================================================================
262
263/// Semantic search using TF-IDF algorithm.
264///
265/// TF-IDF (Term Frequency-Inverse Document Frequency) is a
266/// semantic search method without needing an AI model.
267pub struct TfIdfSearch {
268    /// Word frequency in each document.
269    doc_word_freq: HashMap<String, HashMap<String, f32>>,
270    /// Total documents.
271    total_docs: usize,
272    /// IDF cache.
273    idf_cache: HashMap<String, f32>,
274}
275
276impl TfIdfSearch {
277    /// Create a new TF-IDF search instance.
278    pub fn new() -> Self {
279        Self {
280            doc_word_freq: HashMap::new(),
281            total_docs: 0,
282            idf_cache: HashMap::new(),
283        }
284    }
285
286    /// Index all memories for TF-IDF search.
287    pub fn index(&mut self, memory: &AutoMemory) {
288        self.clear();
289        self.total_docs = memory.entries.len();
290
291        for entry in &memory.entries {
292            let words = self.tokenize(&entry.content);
293            let word_freq = self.compute_word_freq(&words);
294            self.doc_word_freq.insert(entry.content.clone(), word_freq);
295        }
296
297        self.compute_idf();
298    }
299
300    /// Tokenize text into words.
301    fn tokenize(&self, text: &str) -> Vec<String> {
302        let lower = text.to_lowercase();
303        let mut tokens = Vec::new();
304
305        for word in lower.split_whitespace() {
306            let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
307            if trimmed.len() > 1 {
308                tokens.push(trimmed.to_string());
309            }
310
311            let chars: Vec<char> = trimmed.chars().collect();
312            let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
313
314            if has_cjk {
315                for c in &chars {
316                    if Self::is_cjk(*c) {
317                        tokens.push(c.to_string());
318                    }
319                }
320                for window in chars.windows(2) {
321                    if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
322                        tokens.push(window.iter().collect::<String>());
323                    }
324                }
325            }
326        }
327
328        tokens
329    }
330
331    /// Check if a character is CJK.
332    fn is_cjk(c: char) -> bool {
333        matches!(c,
334            '\u{4E00}'..='\u{9FFF}' |
335            '\u{3400}'..='\u{4DBF}' |
336            '\u{F900}'..='\u{FAFF}' |
337            '\u{3000}'..='\u{303F}' |
338            '\u{3040}'..='\u{309F}' |
339            '\u{30A0}'..='\u{30FF}'
340        )
341    }
342
343    /// Compute word frequency in a document.
344    fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
345        let total = words.len() as f32;
346        let mut freq = HashMap::new();
347
348        for word in words {
349            *freq.entry(word.clone()).or_insert(0.0) += 1.0;
350        }
351
352        for (_, count) in freq.iter_mut() {
353            *count /= total;
354        }
355
356        freq
357    }
358
359    /// Compute IDF for all words.
360    fn compute_idf(&mut self) {
361        let mut word_doc_count: HashMap<String, usize> = HashMap::new();
362
363        for word_freq in &self.doc_word_freq {
364            for word in word_freq.1.keys() {
365                *word_doc_count.entry(word.clone()).or_insert(0) += 1;
366            }
367        }
368
369        for (word, count) in word_doc_count {
370            let idf = (self.total_docs as f32 / count as f32).ln();
371            self.idf_cache.insert(word, idf);
372        }
373    }
374
375    /// Search using TF-IDF similarity.
376    pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
377        let query_words = self.tokenize(query);
378        let query_freq = self.compute_word_freq(&query_words);
379
380        let mut results: Vec<(String, f32)> = Vec::new();
381
382        for (doc, doc_freq) in &self.doc_word_freq {
383            let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
384
385            if similarity > 0.0 {
386                results.push((doc.clone(), similarity));
387            }
388        }
389
390        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
391
392        if let Some(max) = limit {
393            results.into_iter().take(max).collect()
394        } else {
395            results
396        }
397    }
398
399    /// Search with multiple keywords.
400    pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
401        let mut doc_scores: HashMap<String, f64> = HashMap::new();
402
403        for keyword in keywords {
404            let results = self.search(keyword, None);
405            for (doc, score) in results {
406                *doc_scores.entry(doc).or_insert(0.0) += score as f64;
407            }
408        }
409
410        let num_keywords = keywords.len().max(1);
411        for (_, score) in doc_scores.iter_mut() {
412            *score /= num_keywords as f64;
413        }
414
415        let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
416        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
417
418        if let Some(max) = limit {
419            results.into_iter().take(max).collect()
420        } else {
421            results
422        }
423    }
424
425    /// Compute TF-IDF similarity.
426    fn compute_tfidf_similarity(
427        &self,
428        query_freq: &HashMap<String, f32>,
429        doc_freq: &HashMap<String, f32>,
430    ) -> f32 {
431        let mut similarity = 0.0;
432
433        for (word, tf_query) in query_freq {
434            if let Some(tf_doc) = doc_freq.get(word)
435                && let Some(idf) = self.idf_cache.get(word)
436            {
437                similarity += tf_query * idf * tf_doc * idf;
438            }
439        }
440
441        similarity
442    }
443
444    /// Clear all indices.
445    pub fn clear(&mut self) {
446        self.doc_word_freq.clear();
447        self.idf_cache.clear();
448        self.total_docs = 0;
449    }
450}
451
452impl Default for TfIdfSearch {
453    fn default() -> Self {
454        Self::new()
455    }
456}
457
458// ============================================================================
459// AI Memory Selection (Claude Code style)
460// ============================================================================
461
462/// System prompt for AI memory selection.
463const SELECT_MEMORIES_SYSTEM_PROMPT: &str = r#"你正在选择对处理用户查询有用的记忆。你会收到用户的查询和可用记忆文件列表(包含描述)。
464
465返回最有用的记忆索引列表(最多5个),以 JSON 数组格式返回。
466- 只选择你确定会有帮助的记忆
467- 如果不确定某个记忆是否有用,不要选择它
468- 如果没有明显有用的记忆,可以返回空数组 []
469- 优先选择与当前问题直接相关的记忆
470
471返回格式示例:{"selected": [0, 2, 5]}
472"#;
473
474/// Select relevant memories using AI (Claude Code style).
475///
476/// Takes user query and memory manifest (descriptions), uses AI to select
477/// the most relevant ones (up to 5).
478pub async fn ai_select_memories(
479    query: &str,
480    memory_manifest: &str,
481    provider: &dyn crate::providers::Provider,
482) -> Vec<usize> {
483    use crate::providers::{ChatRequest, Message, MessageContent, Role};
484
485    // Truncate query if too long
486    let truncated_query = if query.len() > 1000 {
487        &query[..1000]
488    } else {
489        query
490    };
491
492    let user_prompt = format!(
493        "查询: {}\n\n可用记忆列表:\n{}\n\n请选择最有用的记忆索引(最多5个):",
494        truncated_query, memory_manifest
495    );
496
497    let request = ChatRequest {
498        messages: vec![Message {
499            role: Role::User,
500            content: MessageContent::Text(user_prompt),
501        }],
502        tools: vec![],
503        system: Some(SELECT_MEMORIES_SYSTEM_PROMPT.to_string()),
504        think: false,
505        max_tokens: 100,
506        server_tools: vec![],
507        enable_caching: false,
508    };
509
510    let response = match provider.chat(request).await {
511        Ok(r) => r,
512        Err(_) => return Vec::new(),
513    };
514
515    // Extract text from response
516    let text = response
517        .content
518        .iter()
519        .filter_map(|block| {
520            if let crate::providers::ContentBlock::Text { text } = block {
521                Some(text.clone())
522            } else {
523                None
524            }
525        })
526        .collect::<Vec<_>>()
527        .join("");
528
529    // Parse JSON response
530    parse_selected_indices(&text)
531}
532
533/// Parse selected indices from AI response.
534fn parse_selected_indices(text: &str) -> Vec<usize> {
535    // Try to parse as JSON
536    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
537        if let Some(selected) = json.get("selected").and_then(|s| s.as_array()) {
538            return selected
539                .iter()
540                .filter_map(|v| v.as_u64().map(|n| n as usize))
541                .collect();
542        }
543        // Also try direct array format
544        if let Some(arr) = json.as_array() {
545            return arr
546                .iter()
547                .filter_map(|v| v.as_u64().map(|n| n as usize))
548                .collect();
549        }
550    }
551
552    // Fallback: try to extract numbers from text
553    let mut indices = Vec::new();
554    for part in text.split(',') {
555        let trimmed = part.trim();
556        if let Ok(n) = trimmed.parse::<usize>() {
557            indices.push(n);
558        }
559    }
560    indices
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryCategory;

    /// Build an `AutoMemory` holding one `Decision` entry per content string.
    fn memory_with_decisions(contents: &[&str]) -> AutoMemory {
        let mut memory = AutoMemory::new();
        for content in contents {
            memory.add(MemoryEntry::new(
                MemoryCategory::Decision,
                content.to_string(),
                None,
                None,
            ));
        }
        memory
    }

    /// Mixed Chinese/English input yields at least one keyword.
    #[test]
    fn test_extract_keywords() {
        let keywords = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!keywords.is_empty());
    }

    /// A Chinese alias is expanded to its English target.
    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.contains(&"database".to_string()));
    }

    /// Searching for a term finds the one document that contains it.
    #[test]
    fn test_tfidf_search() {
        // Index several documents so the query term distinguishes one of them.
        let memory = memory_with_decisions(&[
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ]);

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());

        // The PostgreSQL document should be the top result
        assert!(results[0].0.contains("PostgreSQL"));
    }
}