Skip to main content

matrixcode_core/memory/
retrieval.rs

1//! Retrieval helpers: TF-IDF search, semantic aliases, keyword extraction.
2
3use std::collections::{HashMap, HashSet};
4
5use super::keywords_config::KeywordsConfig;
6use super::types::{AutoMemory, MemoryEntry};
7
8// ============================================================================
9// Keyword Extraction (uses KeywordsConfig)
10// ============================================================================
11
12/// Extract meaningful keywords from conversation context.
13pub fn extract_context_keywords(context: &str) -> Vec<String> {
14    let config = KeywordsConfig::load();
15    let stop_words = config.get_stop_words_set();
16
17    let lower = context.to_lowercase();
18    let mut keywords: HashSet<String> = HashSet::new();
19
20    // 1. Extract English words (must be meaningful - at least 3 chars)
21    for word in lower.split_whitespace() {
22        let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
23        if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
24            keywords.insert(cleaned);
25        }
26    }
27
28    // 2. Extract patterns from config
29    for category_patterns in config.patterns.values() {
30        for pattern in category_patterns {
31            if lower.contains(&pattern.to_lowercase()) {
32                keywords.insert(pattern.clone());
33            }
34        }
35    }
36
37    // 3. Extract tech patterns (camelCase, snake_case, file paths)
38    let tech_regexes = [
39        r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}",       // file extensions
40        r"[A-Z][a-z]+[A-Z][a-zA-Z]*",                   // CamelCase
41        r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",              // snake_case
42        r"[0-9]+[kKmMgGtT][bB]?",                       // sizes like 4KB
43    ];
44
45    for pattern in tech_regexes {
46        if let Ok(re) = regex::Regex::new(pattern) {
47            for cap in re.find_iter(&lower) {
48                let match_str = cap.as_str();
49                if !stop_words.contains(match_str) {
50                    keywords.insert(match_str.to_string());
51                }
52            }
53        }
54    }
55
56    // Sort by length and limit
57    let mut result: Vec<String> = keywords.into_iter().collect();
58    result.sort_by_key(|b| std::cmp::Reverse(b.len()));
59    result.truncate(10);
60    result
61}
62
63// ============================================================================
64// Skip Simple Messages (Greeting/Short)
65// ============================================================================
66
/// Greeting / filler openings that mark a message as not worth keyword extraction.
const GREETING_PATTERNS: &[&str] = &[
    "你好", "您好", "hi", "hello", "hey", "嗨", "早上好", "下午好", "晚上好",
    "good morning", "good afternoon", "good evening",
    "请问", "帮忙", "帮我", "帮我看", "看看", "help", "请",
    "开始", "start", "准备好了", "ready",
];

/// Check if message is simple (greeting/short) and should skip AI keyword extraction.
///
/// Returns `true` when the trimmed message is shorter than 15 bytes, or when its
/// lowercased form starts with one of [`GREETING_PATTERNS`].
///
/// ASCII patterns must end on a word boundary: "hi there" and "hello, …" are
/// skipped, but "history …", "helper …" and "start_server …" are not. CJK patterns
/// match as plain prefixes because CJK text has no word separators.
pub fn should_skip_simple_message(msg: &str) -> bool {
    let trimmed = msg.trim();

    // Byte length: 15 bytes is ~15 ASCII characters or 5 CJK characters.
    if trimmed.len() < 15 {
        return true;
    }

    let lower = trimmed.to_lowercase();
    GREETING_PATTERNS.iter().copied().any(|pattern| {
        lower.strip_prefix(pattern).is_some_and(|rest| {
            // An ASCII greeting followed by more word characters is part of a longer word.
            !pattern.is_ascii()
                || !rest.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_')
        })
    })
}
95
96/// Calculate word-based similarity between two strings (Jaccard coefficient).
97pub fn calculate_similarity(a: &str, b: &str) -> f64 {
98    AutoMemory::calculate_similarity(a, b)
99}
100
101// ============================================================================
102// Semantic Aliases (uses KeywordsConfig)
103// ============================================================================
104
105/// Get semantic aliases.
106pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
107    KeywordsConfig::get_aliases()
108}
109
110/// Expand keywords with semantic aliases.
111pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
112    let aliases = KeywordsConfig::get_aliases();
113    let mut expanded: Vec<String> = keywords.to_vec();
114
115    for keyword in keywords {
116        let kw_lower = keyword.to_lowercase();
117        for &(alias, target) in &aliases {
118            if kw_lower.contains(alias) {
119                expanded.push(target.to_string());
120            }
121            if kw_lower.contains(target) {
122                expanded.push(alias.to_string());
123            }
124        }
125    }
126
127    expanded.sort();
128    expanded.dedup();
129    expanded
130}
131
132// ============================================================================
133// Relevance & Contradiction Detection (uses KeywordsConfig)
134// ============================================================================
135
136/// Compute relevance score of a memory entry to context keywords.
137/// Returns 0.0-1.0 where 1.0 means highly relevant.
138pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
139    if context_keywords.is_empty() {
140        return 0.0;
141    }
142
143    let expanded_keywords = expand_semantic_keywords(context_keywords);
144    let content_lower = entry.content.to_lowercase();
145
146    let matches = expanded_keywords
147        .iter()
148        .filter(|kw| content_lower.contains(&kw.to_lowercase()))
149        .count();
150
151    let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
152
153    let tag_matches = entry
154        .tags
155        .iter()
156        .filter(|tag| {
157            let tag_lower = tag.to_lowercase();
158            expanded_keywords.iter().any(|kw| {
159                tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
160            })
161        })
162        .count();
163
164    let tag_score = if tag_matches > 0 {
165        0.2 + (tag_matches as f64 * 0.05).min(0.1)
166    } else {
167        0.0
168    };
169
170    (keyword_score + tag_score).min(1.0)
171}
172
173/// Detect if two memory contents have contradiction signals.
174/// Uses KeywordsConfig for contradiction signals.
175pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
176    let config = KeywordsConfig::load();
177
178    // Check contradiction signals from config
179    for signal in &config.contradiction_signals {
180        if new.contains(signal) {
181            return true;
182        }
183    }
184
185    // Check action verbs that indicate change
186    let action_verbs = [
187        "决定使用",
188        "选择使用",
189        "采用",
190        "使用",
191        "decided to use",
192        "chose",
193        "using",
194        "adopted",
195    ];
196
197    for verb in &action_verbs {
198        if old.contains(verb) && new.contains(verb) {
199            return true;
200        }
201    }
202
203    // Check preference verbs
204    let pref_verbs = ["偏好", "喜欢", "prefer", "like"];
205    for verb in &pref_verbs {
206        if old.contains(verb) && new.contains(verb) {
207            return true;
208        }
209    }
210
211    false
212}
213
214// ============================================================================
215// AI Keyword Extraction (Hybrid) - DEPRECATED, use extract_context_keywords instead
216// ============================================================================
217
218/// Extract keywords using hybrid approach (rule-based only now, AI removed).
219/// DEPRECATED: This function now just calls extract_context_keywords.
220/// Use extract_context_keywords directly for clarity.
221pub async fn extract_keywords_hybrid(
222    context: &str,
223    _fast_provider: Option<&dyn crate::providers::Provider>,
224) -> Vec<String> {
225    // AI keyword extraction removed - just use rule-based
226    extract_context_keywords(context)
227}
228
229// ============================================================================
230// TF-IDF Search
231// ============================================================================
232
/// Lexical relevance search over memory contents using TF-IDF.
///
/// TF-IDF (term frequency × inverse document frequency) ranks documents by the
/// terms they share with a query, weighting rarer terms higher. It needs no AI
/// model, but it only matches terms that co-occur after tokenization. It is not
/// semantic in the embedding sense.
#[derive(Debug, Clone)]
pub struct TfIdfSearch {
    /// Per-document term frequencies (term count / token count), keyed by document text.
    doc_word_freq: HashMap<String, HashMap<String, f32>>,
    /// Number of documents used as `N` in the IDF formula.
    total_docs: usize,
    /// Cached inverse document frequency per term.
    idf_cache: HashMap<String, f32>,
}
245
246impl TfIdfSearch {
247    /// Create a new TF-IDF search instance.
248    pub fn new() -> Self {
249        Self {
250            doc_word_freq: HashMap::new(),
251            total_docs: 0,
252            idf_cache: HashMap::new(),
253        }
254    }
255
256    /// Index all memories for TF-IDF search.
257    pub fn index(&mut self, memory: &AutoMemory) {
258        self.clear();
259        self.total_docs = memory.entries.len();
260
261        for entry in &memory.entries {
262            let words = self.tokenize(&entry.content);
263            let word_freq = self.compute_word_freq(&words);
264            self.doc_word_freq.insert(entry.content.clone(), word_freq);
265        }
266
267        self.compute_idf();
268    }
269
270    /// Tokenize text into words.
271    fn tokenize(&self, text: &str) -> Vec<String> {
272        let lower = text.to_lowercase();
273        let mut tokens = Vec::new();
274
275        for word in lower.split_whitespace() {
276            let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
277            if trimmed.len() > 1 {
278                tokens.push(trimmed.to_string());
279            }
280
281            let chars: Vec<char> = trimmed.chars().collect();
282            let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
283
284            if has_cjk {
285                for c in &chars {
286                    if Self::is_cjk(*c) {
287                        tokens.push(c.to_string());
288                    }
289                }
290                for window in chars.windows(2) {
291                    if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
292                        tokens.push(window.iter().collect::<String>());
293                    }
294                }
295            }
296        }
297
298        tokens
299    }
300
301    /// Check if a character is CJK.
302    fn is_cjk(c: char) -> bool {
303        matches!(c,
304            '\u{4E00}'..='\u{9FFF}' |
305            '\u{3400}'..='\u{4DBF}' |
306            '\u{F900}'..='\u{FAFF}' |
307            '\u{3000}'..='\u{303F}' |
308            '\u{3040}'..='\u{309F}' |
309            '\u{30A0}'..='\u{30FF}'
310        )
311    }
312
313    /// Compute word frequency in a document.
314    fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
315        let total = words.len() as f32;
316        let mut freq = HashMap::new();
317
318        for word in words {
319            *freq.entry(word.clone()).or_insert(0.0) += 1.0;
320        }
321
322        for (_, count) in freq.iter_mut() {
323            *count /= total;
324        }
325
326        freq
327    }
328
329    /// Compute IDF for all words.
330    fn compute_idf(&mut self) {
331        let mut word_doc_count: HashMap<String, usize> = HashMap::new();
332
333        for word_freq in &self.doc_word_freq {
334            for word in word_freq.1.keys() {
335                *word_doc_count.entry(word.clone()).or_insert(0) += 1;
336            }
337        }
338
339        for (word, count) in word_doc_count {
340            let idf = (self.total_docs as f32 / count as f32).ln();
341            self.idf_cache.insert(word, idf);
342        }
343    }
344
345    /// Search using TF-IDF similarity.
346    pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
347        let query_words = self.tokenize(query);
348        let query_freq = self.compute_word_freq(&query_words);
349
350        let mut results: Vec<(String, f32)> = Vec::new();
351
352        for (doc, doc_freq) in &self.doc_word_freq {
353            let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
354
355            if similarity > 0.0 {
356                results.push((doc.clone(), similarity));
357            }
358        }
359
360        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
361
362        if let Some(max) = limit {
363            results.into_iter().take(max).collect()
364        } else {
365            results
366        }
367    }
368
369    /// Search with multiple keywords.
370    pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
371        let mut doc_scores: HashMap<String, f64> = HashMap::new();
372
373        for keyword in keywords {
374            let results = self.search(keyword, None);
375            for (doc, score) in results {
376                *doc_scores.entry(doc).or_insert(0.0) += score as f64;
377            }
378        }
379
380        let num_keywords = keywords.len().max(1);
381        for (_, score) in doc_scores.iter_mut() {
382            *score /= num_keywords as f64;
383        }
384
385        let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
386        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
387
388        if let Some(max) = limit {
389            results.into_iter().take(max).collect()
390        } else {
391            results
392        }
393    }
394
395    /// Compute TF-IDF similarity.
396    fn compute_tfidf_similarity(
397        &self,
398        query_freq: &HashMap<String, f32>,
399        doc_freq: &HashMap<String, f32>,
400    ) -> f32 {
401        let mut similarity = 0.0;
402
403        for (word, tf_query) in query_freq {
404            if let Some(tf_doc) = doc_freq.get(word)
405                && let Some(idf) = self.idf_cache.get(word)
406            {
407                similarity += tf_query * idf * tf_doc * idf;
408            }
409        }
410
411        similarity
412    }
413
414    /// Clear all indices.
415    pub fn clear(&mut self) {
416        self.doc_word_freq.clear();
417        self.idf_cache.clear();
418        self.total_docs = 0;
419    }
420}
421
422impl Default for TfIdfSearch {
423    fn default() -> Self {
424        Self::new()
425    }
426}
427
428// ============================================================================
429// AI Memory Selection (Claude Code style)
430// ============================================================================
431
/// System prompt for AI memory selection.
///
/// Written in Chinese. It tells the model it will receive a user query plus a list
/// of memories with descriptions. It asks the model to return the indices of at
/// most 5 clearly useful memories, to skip uncertain ones, and to return an empty
/// list when none are useful. It gives `{"selected": [0, 2, 5]}` as the reply
/// format example.
/// NOTE(review): the prose asks for a "JSON array" while the example is an object
/// with a `selected` key; `parse_selected_indices` accepts both shapes.
const SELECT_MEMORIES_SYSTEM_PROMPT: &str = r#"你正在选择对处理用户查询有用的记忆。你会收到用户的查询和可用记忆文件列表(包含描述)。

返回最有用的记忆索引列表(最多5个),以 JSON 数组格式返回。
- 只选择你确定会有帮助的记忆
- 如果不确定某个记忆是否有用,不要选择它
- 如果没有明显有用的记忆,可以返回空数组 []
- 优先选择与当前问题直接相关的记忆

返回格式示例:{"selected": [0, 2, 5]}
"#;
443
444/// Select relevant memories using AI (Claude Code style).
445///
446/// Takes user query and memory manifest (descriptions), uses AI to select
447/// the most relevant ones (up to 5).
448pub async fn ai_select_memories(
449    query: &str,
450    memory_manifest: &str,
451    provider: &dyn crate::providers::Provider,
452) -> Vec<usize> {
453    use crate::providers::{ChatRequest, Message, MessageContent, Role};
454
455    // Truncate query if too long
456    let truncated_query = if query.len() > 1000 {
457        &query[..1000]
458    } else {
459        query
460    };
461
462    let user_prompt = format!(
463        "查询: {}\n\n可用记忆列表:\n{}\n\n请选择最有用的记忆索引(最多5个):",
464        truncated_query, memory_manifest
465    );
466
467    let request = ChatRequest {
468        messages: vec![Message {
469            role: Role::User,
470            content: MessageContent::Text(user_prompt),
471        }],
472        tools: vec![],
473        system: Some(SELECT_MEMORIES_SYSTEM_PROMPT.to_string()),
474        think: false,
475        max_tokens: 100,
476        server_tools: vec![],
477        enable_caching: false,
478    };
479
480    let response = match provider.chat(request).await {
481        Ok(r) => r,
482        Err(_) => return Vec::new(),
483    };
484
485    // Extract text from response
486    let text = response
487        .content
488        .iter()
489        .filter_map(|block| {
490            if let crate::providers::ContentBlock::Text { text } = block {
491                Some(text.clone())
492            } else {
493                None
494            }
495        })
496        .collect::<Vec<_>>()
497        .join("");
498
499    // Parse JSON response
500    parse_selected_indices(&text)
501}
502
503/// Parse selected indices from AI response.
504fn parse_selected_indices(text: &str) -> Vec<usize> {
505    // Try to parse as JSON
506    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
507        if let Some(selected) = json.get("selected").and_then(|s| s.as_array()) {
508            return selected
509                .iter()
510                .filter_map(|v| v.as_u64().map(|n| n as usize))
511                .collect();
512        }
513        // Also try direct array format
514        if let Some(arr) = json.as_array() {
515            return arr
516                .iter()
517                .filter_map(|v| v.as_u64().map(|n| n as usize))
518                .collect();
519        }
520    }
521
522    // Fallback: try to extract numbers from text
523    let mut indices = Vec::new();
524    for part in text.split(',') {
525        let trimmed = part.trim();
526        if let Ok(n) = trimmed.parse::<usize>() {
527            indices.push(n);
528        }
529    }
530    indices
531}
532
#[cfg(test)]
mod tests {
    use super::super::types::{MemoryCategory, MemoryEntry};
    use super::*;

    /// Build an `AutoMemory` containing one `Decision` entry per content string.
    fn decision_memory(contents: &[&str]) -> AutoMemory {
        let mut memory = AutoMemory::new();
        for content in contents {
            memory.add(MemoryEntry::new(
                MemoryCategory::Decision,
                content.to_string(),
                None,
                None,
            ));
        }
        memory
    }

    #[test]
    fn test_extract_keywords() {
        let extracted = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!extracted.is_empty());
    }

    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.iter().any(|k| k == "database"));
    }

    #[test]
    fn test_tfidf_search() {
        // Several documents are required: IDF = ln(N / df), so a term that appears
        // in every document (as with a single document) scores zero.
        let memory = decision_memory(&[
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ]);

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let hits = tfidf.search("数据库", Some(5));
        assert!(!hits.is_empty());
        // The PostgreSQL document should be the top result
        assert!(hits[0].0.contains("PostgreSQL"));
    }
}