1use std::collections::{HashMap, HashSet};
4
5use super::config::*;
6use super::types::{AutoMemory, MemoryEntry};
7
8pub fn extract_context_keywords(context: &str) -> Vec<String> {
15 let stop_words: HashSet<&str> = [
17 "的", "了", "是", "在", "我", "有", "和", "就", "不", "人", "都", "一", "一个", "上", "也",
19 "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好", "自己", "这", "他",
20 "她", "它", "们", "那", "些", "什么", "怎么", "如何", "请", "能", "可以", "需要", "应该",
21 "可能", "因为", "所以", "但是", "然后", "还是", "已经", "正在", "将要", "曾经", "一下",
22 "一点", "一些", "所有", "每个", "任何", "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
24 "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "shall",
25 "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
26 "during", "before", "after", "above", "below", "between", "and", "but", "or", "not", "no",
27 "so", "if", "then", "than", "too", "very", "just", "this", "that", "these", "those", "it",
28 "its", "i", "me", "my", "we", "our", "you", "your", "he", "his", "she", "her", "they",
29 "their", "please", "help", "need", "want", "make", "get", "let", "use",
30 ]
31 .iter()
32 .copied()
33 .collect();
34
35 let tech_patterns: HashSet<&str> = [
37 "api", "cli", "gui", "tui", "web", "http", "json", "xml", "sql", "db", "git", "npm",
38 "cargo", "rust", "js", "ts", "py", "go", "java", "cpp", "cpu", "gpu", "io", "fs", "os",
39 "ui", "ux", "ai", "ml", "dl", "rs", "js", "ts", "py", "go", "java", "c", "h", "cpp", "hpp",
40 "json", "yaml", "yml", "toml", "md", "txt", "html", "css", "scss", "bug", "fix", "add",
41 "new", "old", "use", "run", "build", "test", "code", "data", "file", "dir", "path", "name",
42 "type", "value",
43 ]
44 .iter()
45 .copied()
46 .collect();
47
48 let lower = context.to_lowercase();
49 let mut keywords: HashSet<String> = HashSet::new();
50
51 for word in lower.split_whitespace() {
53 let cleaned = word
54 .trim_matches(|c: char| !c.is_alphanumeric())
55 .to_string();
56 if cleaned.len() >= 2 && !stop_words.contains(cleaned.as_str()) {
57 keywords.insert(cleaned.clone());
58 }
59 if tech_patterns.contains(cleaned.as_str()) {
60 keywords.insert(cleaned);
61 }
62 }
63
64 let chinese_chars: Vec<char> = lower
66 .chars()
67 .filter(|c| *c >= '\u{4E00}' && *c <= '\u{9FFF}')
68 .collect();
69
70 for window_size in 2..=4 {
71 if chinese_chars.len() >= window_size {
72 for window in chinese_chars.windows(window_size) {
73 let phrase: String = window.iter().collect();
74 let has_stop = stop_words.iter().any(|sw| phrase.contains(sw));
75 if !has_stop && phrase.len() >= window_size {
76 keywords.insert(phrase);
77 }
78 }
79 }
80 }
81
82 let patterns = [
84 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}",
85 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*",
86 r"[A-Z][a-z]+[A-Z][a-zA-Z]*",
87 r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",
88 r"[0-9]+[kKmMgGtT][bB]?",
89 ];
90
91 for pattern in patterns {
92 if let Ok(re) = regex::Regex::new(pattern) {
93 for cap in re.find_iter(&lower) {
94 keywords.insert(cap.as_str().to_string());
95 }
96 }
97 }
98
99 let mut result: Vec<String> = keywords.into_iter().collect();
100 result.sort_by_key(|b| std::cmp::Reverse(b.len()));
101 result.truncate(15);
102
103 result
104}
105
/// Returns a similarity score for the two strings.
///
/// Thin free-function wrapper that forwards both arguments unchanged to
/// [`AutoMemory::calculate_similarity`]; the scoring itself is defined there.
pub fn calculate_similarity(a: &str, b: &str) -> f64 {
    AutoMemory::calculate_similarity(a, b)
}
110
/// Table of `(alias, target)` term pairs that are treated as related.
///
/// `expand_semantic_keywords` reads the table in both directions: a keyword containing
/// the alias gains the target, and a keyword containing the target gains the alias.
/// Pairs whose two sides are identical therefore add nothing new, and one alias may be
/// listed several times with different targets (e.g. "登录").
pub const SEMANTIC_ALIASES: &[(&str, &str)] = &[
    // Databases
    ("数据库", "database"),
    ("db", "database"),
    ("postgresql", "postgres"),
    ("mysql", "mysql"),
    ("mongodb", "mongo"),
    ("redis", "redis"),
    ("sqlite", "sqlite"),
    ("sql", "database"),
    // Frontend
    ("前端", "frontend"),
    ("ui", "frontend"),
    ("界面", "frontend"),
    ("页面", "page"),
    ("组件", "component"),
    ("react", "react"),
    ("vue", "vue"),
    ("angular", "angular"),
    // Backend
    ("后端", "backend"),
    ("api", "api"),
    ("接口", "api"),
    ("服务", "service"),
    ("server", "backend"),
    ("服务器", "backend"),
    // Programming languages and runtimes
    ("rust", "rust"),
    ("python", "python"),
    ("javascript", "js"),
    ("typescript", "ts"),
    ("java", "java"),
    ("go", "golang"),
    ("golang", "go"),
    ("c++", "cpp"),
    ("cpp", "c++"),
    ("nodejs", "node"),
    ("node", "nodejs"),
    // Editors
    ("编辑器", "editor"),
    ("ide", "editor"),
    ("vim", "vim"),
    ("vscode", "vscode"),
    ("emacs", "emacs"),
    // Configuration
    ("配置", "config"),
    ("设置", "config"),
    ("config", "config"),
    ("setting", "config"),
    // Files, directories and code organisation
    ("目录", "directory"),
    ("文件", "file"),
    ("文件夹", "directory"),
    ("路径", "path"),
    ("模块", "module"),
    ("包", "package"),
    // Testing
    ("测试", "test"),
    ("test", "test"),
    ("单元测试", "unittest"),
    ("unittest", "test"),
    // Caching
    ("缓存", "cache"),
    ("cache", "cache"),
    // Authentication
    ("认证", "auth"),
    ("登录", "login"),
    ("auth", "auth"),
    ("登录", "auth"),
    // Performance
    ("性能", "performance"),
    ("优化", "optimize"),
    ("速度", "speed"),
    ("慢", "slow"),
    // Common operations
    ("创建", "create"),
    ("删除", "delete"),
    ("修改", "modify"),
    ("添加", "add"),
    ("更新", "update"),
    ("查询", "query"),
];
198
199pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
201 let mut expanded: Vec<String> = keywords.to_vec();
202
203 for keyword in keywords {
204 let kw_lower = keyword.to_lowercase();
205 for (alias, target) in SEMANTIC_ALIASES {
206 if kw_lower.contains(alias) {
207 expanded.push(target.to_string());
208 }
209 if kw_lower.contains(target) {
210 expanded.push(alias.to_string());
211 }
212 }
213 }
214
215 expanded.sort();
216 expanded.dedup();
217 expanded
218}
219
220pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
227 if context_keywords.is_empty() {
228 return 0.0;
229 }
230
231 let expanded_keywords = expand_semantic_keywords(context_keywords);
232 let content_lower = entry.content.to_lowercase();
233
234 let matches = expanded_keywords
235 .iter()
236 .filter(|kw| content_lower.contains(&kw.to_lowercase()))
237 .count();
238
239 let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
240
241 let tag_matches = entry
242 .tags
243 .iter()
244 .filter(|tag| {
245 let tag_lower = tag.to_lowercase();
246 expanded_keywords.iter().any(|kw| {
247 tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
248 })
249 })
250 .count();
251
252 let tag_score = if tag_matches > 0 {
253 0.2 + (tag_matches as f64 * 0.05).min(0.1)
254 } else {
255 0.0
256 };
257
258 (keyword_score + tag_score).min(1.0)
259}
260
/// Heuristically decides whether `new` may contradict or supersede `old`.
///
/// Returns `true` when either of the following holds (all comparisons are
/// case-insensitive substring checks):
///
/// * `new` contains an explicit change phrase such as "改用", "switched to" or
///   "no longer"; `old` is not consulted in this case.
/// * `old` and `new` both contain the same decision verb (e.g. "使用", "using") or the
///   same preference verb (e.g. "喜欢", "prefer").
///
/// Returns `false` otherwise, including when both strings are empty.
pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
    // Phrases that by themselves announce a change of decision.
    const CHANGE_SIGNALS: &[&str] = &[
        "改用",
        "换成",
        "替换",
        "改为",
        "切换到",
        "迁移到",
        "不再使用",
        "弃用",
        "放弃",
        "取消",
        "switched to",
        "replaced",
        "migrated to",
        "changed to",
        "no longer",
        "deprecated",
        "abandoned",
    ];

    // Verbs that signal a decision; the same verb on both sides suggests two
    // competing decisions.
    const ACTION_VERBS: &[&str] = &[
        "决定使用",
        "选择使用",
        "采用",
        "使用",
        "decided to use",
        "chose",
        "using",
        "adopted",
    ];

    // Verbs that signal a preference; handled the same way as decision verbs.
    const PREF_VERBS: &[&str] = &["偏好", "喜欢", "prefer", "like"];

    // All phrases above are lower-case, so lower-casing the inputs makes the match
    // case-insensitive ("Switched to", "DEPRECATED", "Prefer", ...).
    let old = old.to_lowercase();
    let new = new.to_lowercase();

    if CHANGE_SIGNALS.iter().any(|signal| new.contains(*signal)) {
        return true;
    }

    ACTION_VERBS
        .iter()
        .chain(PREF_VERBS.iter())
        .any(|verb| old.contains(*verb) && new.contains(*verb))
}
315
316pub async fn extract_keywords_hybrid(
322 context: &str,
323 fast_provider: Option<&dyn crate::providers::Provider>,
324) -> Vec<String> {
325 let rule_keywords = extract_context_keywords(context);
327
328 let mode = AiKeywordMode::from_env();
330 if mode.should_use_ai(rule_keywords.len()) && fast_provider.is_some() {
331 if let Some(provider) = fast_provider {
333 let ai_keywords = extract_keywords_with_ai(context, provider).await;
334 if !ai_keywords.is_empty() {
335 return ai_keywords;
336 }
337 }
338 }
339
340 rule_keywords
341}
342
343async fn extract_keywords_with_ai(
345 context: &str,
346 provider: &dyn crate::providers::Provider,
347) -> Vec<String> {
348 use crate::providers::{ChatRequest, Message, MessageContent, Role};
349
350 let truncated = if context.len() > 2000 {
351 &context[..2000]
352 } else {
353 context
354 };
355
356 let prompt = format!(
357 "从以下对话内容中提取关键词(用于记忆检索),最多返回10个关键词,以逗号分隔:\n\n{}",
358 truncated
359 );
360
361 let request = ChatRequest {
362 messages: vec![Message {
363 role: Role::User,
364 content: MessageContent::Text(prompt),
365 }],
366 tools: vec![],
367 system: Some("你是一个关键词提取助手,返回关键词列表,不要其他解释。".to_string()),
368 think: false,
369 max_tokens: 100,
370 server_tools: vec![],
371 enable_caching: false,
372 };
373
374 let response = match provider.chat(request).await {
375 Ok(r) => r,
376 Err(_) => return Vec::new(),
377 };
378
379 let text = response
380 .content
381 .iter()
382 .filter_map(|block| {
383 if let crate::providers::ContentBlock::Text { text } = block {
384 Some(text.clone())
385 } else {
386 None
387 }
388 })
389 .collect::<Vec<_>>()
390 .join("");
391
392 text.split(',')
393 .map(|s| s.trim().to_string())
394 .filter(|s| s.len() >= 2)
395 .collect()
396}
397
/// In-memory TF-IDF index over the contents of memory entries.
pub struct TfIdfSearch {
    /// Term frequencies per document, keyed by the document's full content string.
    doc_word_freq: HashMap<String, HashMap<String, f32>>,
    /// Document count used as the corpus size when computing IDF values.
    total_docs: usize,
    /// Inverse document frequency of every term found in the indexed documents.
    idf_cache: HashMap<String, f32>,
}
414
415impl TfIdfSearch {
416 pub fn new() -> Self {
418 Self {
419 doc_word_freq: HashMap::new(),
420 total_docs: 0,
421 idf_cache: HashMap::new(),
422 }
423 }
424
425 pub fn index(&mut self, memory: &AutoMemory) {
427 self.clear();
428 self.total_docs = memory.entries.len();
429
430 for entry in &memory.entries {
431 let words = self.tokenize(&entry.content);
432 let word_freq = self.compute_word_freq(&words);
433 self.doc_word_freq.insert(entry.content.clone(), word_freq);
434 }
435
436 self.compute_idf();
437 }
438
439 fn tokenize(&self, text: &str) -> Vec<String> {
441 let lower = text.to_lowercase();
442 let mut tokens = Vec::new();
443
444 for word in lower.split_whitespace() {
445 let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
446 if trimmed.len() > 1 {
447 tokens.push(trimmed.to_string());
448 }
449
450 let chars: Vec<char> = trimmed.chars().collect();
451 let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
452
453 if has_cjk {
454 for c in &chars {
455 if Self::is_cjk(*c) {
456 tokens.push(c.to_string());
457 }
458 }
459 for window in chars.windows(2) {
460 if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
461 tokens.push(window.iter().collect::<String>());
462 }
463 }
464 }
465 }
466
467 tokens
468 }
469
470 fn is_cjk(c: char) -> bool {
472 matches!(c,
473 '\u{4E00}'..='\u{9FFF}' |
474 '\u{3400}'..='\u{4DBF}' |
475 '\u{F900}'..='\u{FAFF}' |
476 '\u{3000}'..='\u{303F}' |
477 '\u{3040}'..='\u{309F}' |
478 '\u{30A0}'..='\u{30FF}'
479 )
480 }
481
482 fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
484 let total = words.len() as f32;
485 let mut freq = HashMap::new();
486
487 for word in words {
488 *freq.entry(word.clone()).or_insert(0.0) += 1.0;
489 }
490
491 for (_, count) in freq.iter_mut() {
492 *count /= total;
493 }
494
495 freq
496 }
497
498 fn compute_idf(&mut self) {
500 let mut word_doc_count: HashMap<String, usize> = HashMap::new();
501
502 for word_freq in &self.doc_word_freq {
503 for word in word_freq.1.keys() {
504 *word_doc_count.entry(word.clone()).or_insert(0) += 1;
505 }
506 }
507
508 for (word, count) in word_doc_count {
509 let idf = (self.total_docs as f32 / count as f32).ln();
510 self.idf_cache.insert(word, idf);
511 }
512 }
513
514 pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
516 let query_words = self.tokenize(query);
517 let query_freq = self.compute_word_freq(&query_words);
518
519 let mut results: Vec<(String, f32)> = Vec::new();
520
521 for (doc, doc_freq) in &self.doc_word_freq {
522 let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
523
524 if similarity > 0.0 {
525 results.push((doc.clone(), similarity));
526 }
527 }
528
529 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
530
531 if let Some(max) = limit {
532 results.into_iter().take(max).collect()
533 } else {
534 results
535 }
536 }
537
538 pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
540 let mut doc_scores: HashMap<String, f64> = HashMap::new();
541
542 for keyword in keywords {
543 let results = self.search(keyword, None);
544 for (doc, score) in results {
545 *doc_scores.entry(doc).or_insert(0.0) += score as f64;
546 }
547 }
548
549 let num_keywords = keywords.len().max(1);
550 for (_, score) in doc_scores.iter_mut() {
551 *score /= num_keywords as f64;
552 }
553
554 let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
555 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
556
557 if let Some(max) = limit {
558 results.into_iter().take(max).collect()
559 } else {
560 results
561 }
562 }
563
564 fn compute_tfidf_similarity(
566 &self,
567 query_freq: &HashMap<String, f32>,
568 doc_freq: &HashMap<String, f32>,
569 ) -> f32 {
570 let mut similarity = 0.0;
571
572 for (word, tf_query) in query_freq {
573 if let Some(tf_doc) = doc_freq.get(word)
574 && let Some(idf) = self.idf_cache.get(word)
575 {
576 similarity += tf_query * idf * tf_doc * idf;
577 }
578 }
579
580 similarity
581 }
582
583 pub fn clear(&mut self) {
585 self.doc_word_freq.clear();
586 self.idf_cache.clear();
587 self.total_docs = 0;
588 }
589}
590
/// The default index is empty, identical to [`TfIdfSearch::new`].
impl Default for TfIdfSearch {
    fn default() -> Self {
        Self::new()
    }
}
596
#[cfg(test)]
mod tests {
    use super::super::types::{MemoryCategory, MemoryEntry};
    use super::*;

    /// Mixed Chinese/English input must yield at least one keyword.
    #[test]
    fn test_extract_keywords() {
        let extracted = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!extracted.is_empty());
    }

    /// A Chinese alias must be expanded to its English target.
    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.contains(&"database".to_string()));
    }

    /// Searching a three-entry index for "数据库" must rank the PostgreSQL entry first.
    #[test]
    fn test_tfidf_search() {
        let mut memory = AutoMemory::new();
        for content in [
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ] {
            memory.add(MemoryEntry::new(
                MemoryCategory::Decision,
                content.to_string(),
                None,
            ));
        }

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let hits = tfidf.search("数据库", Some(5));
        assert!(!hits.is_empty());

        assert!(hits[0].0.contains("PostgreSQL"));
    }
}