1use std::collections::{HashMap, HashSet};
4
5use super::config::*;
6use super::keywords_config::KeywordsConfig;
7use super::types::{AutoMemory, MemoryEntry};
8
9pub fn extract_context_keywords(context: &str) -> Vec<String> {
16 let config = KeywordsConfig::load();
17 let stop_words = config.get_stop_words_set();
18 let tech_patterns = config.get_tech_keywords_set();
19
20 let lower = context.to_lowercase();
21 let mut keywords: HashSet<String> = HashSet::new();
22
23 for word in lower.split_whitespace() {
25 let cleaned = word
26 .trim_matches(|c: char| !c.is_alphanumeric())
27 .to_string();
28 if cleaned.len() >= 2 && !stop_words.contains(cleaned.as_str()) {
29 keywords.insert(cleaned.clone());
30 }
31 if tech_patterns.contains(cleaned.as_str()) {
32 keywords.insert(cleaned);
33 }
34 }
35
36 let chinese_chars: Vec<char> = lower
38 .chars()
39 .filter(|c| *c >= '\u{4E00}' && *c <= '\u{9FFF}')
40 .collect();
41
42 for window_size in 2..=4 {
43 if chinese_chars.len() >= window_size {
44 for window in chinese_chars.windows(window_size) {
45 let phrase: String = window.iter().collect();
46 let has_stop = stop_words.iter().any(|sw| phrase.contains(sw));
47 if !has_stop && phrase.len() >= window_size {
48 keywords.insert(phrase);
49 }
50 }
51 }
52 }
53
54 let patterns = [
56 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}",
57 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*",
58 r"[A-Z][a-z]+[A-Z][a-zA-Z]*",
59 r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*",
60 r"[0-9]+[kKmMgGtT][bB]?",
61 ];
62
63 for pattern in patterns {
64 if let Ok(re) = regex::Regex::new(pattern) {
65 for cap in re.find_iter(&lower) {
66 keywords.insert(cap.as_str().to_string());
67 }
68 }
69 }
70
71 let mut result: Vec<String> = keywords.into_iter().collect();
72 result.sort_by_key(|b| std::cmp::Reverse(b.len()));
73 result.truncate(15);
74
75 result
76}
77
/// Returns the similarity score of `a` and `b`.
///
/// This free function only forwards its arguments to
/// `AutoMemory::calculate_similarity`; the scoring itself is defined there.
pub fn calculate_similarity(a: &str, b: &str) -> f64 {
    AutoMemory::calculate_similarity(a, b)
}
82
83pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
89 SEMANTIC_ALIASES_DEFAULT.to_vec()
92}
93
/// Built-in table of semantic alias pairs, returned by `get_semantic_aliases`.
///
/// Each entry is a pair of related terms (Chinese or English on the left,
/// English on the right). Some pairs map a term to itself, and a left-hand term
/// may appear more than once (e.g. "登录" maps to both "login" and "auth").
pub const SEMANTIC_ALIASES_DEFAULT: &[(&str, &str)] = &[
    // Databases
    ("数据库", "database"),
    ("db", "database"),
    ("postgresql", "postgres"),
    ("mysql", "mysql"),
    ("mongodb", "mongo"),
    ("redis", "redis"),
    ("sqlite", "sqlite"),
    ("sql", "database"),
    // Frontend
    ("前端", "frontend"),
    ("ui", "frontend"),
    ("界面", "frontend"),
    ("页面", "page"),
    ("组件", "component"),
    ("react", "react"),
    ("vue", "vue"),
    ("angular", "angular"),
    // Backend
    ("后端", "backend"),
    ("api", "api"),
    ("接口", "api"),
    ("服务", "service"),
    ("server", "backend"),
    ("服务器", "backend"),
    // Programming languages and runtimes
    ("rust", "rust"),
    ("python", "python"),
    ("javascript", "js"),
    ("typescript", "ts"),
    ("java", "java"),
    ("go", "golang"),
    ("golang", "go"),
    ("c++", "cpp"),
    ("cpp", "c++"),
    ("nodejs", "node"),
    ("node", "nodejs"),
    // Editors
    ("编辑器", "editor"),
    ("ide", "editor"),
    ("vim", "vim"),
    ("vscode", "vscode"),
    ("emacs", "emacs"),
    // Configuration
    ("配置", "config"),
    ("设置", "config"),
    ("config", "config"),
    ("setting", "config"),
    // Files, paths and modules
    ("目录", "directory"),
    ("文件", "file"),
    ("文件夹", "directory"),
    ("路径", "path"),
    ("模块", "module"),
    ("包", "package"),
    // Testing
    ("测试", "test"),
    ("test", "test"),
    ("单元测试", "unittest"),
    ("unittest", "test"),
    // Caching
    ("缓存", "cache"),
    ("cache", "cache"),
    // Authentication
    ("认证", "auth"),
    ("登录", "login"),
    ("auth", "auth"),
    ("登录", "auth"),
    // Performance
    ("性能", "performance"),
    ("优化", "optimize"),
    ("速度", "speed"),
    ("慢", "slow"),
    // Common actions
    ("创建", "create"),
    ("删除", "delete"),
    ("修改", "modify"),
    ("添加", "add"),
    ("更新", "update"),
    ("查询", "query"),
];
177
178pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
180 let config = KeywordsConfig::load();
181 let mut expanded: Vec<String> = keywords.to_vec();
182
183 for keyword in keywords {
184 let kw_lower = keyword.to_lowercase();
185 for (alias, target) in config.get_aliases() {
186 if kw_lower.contains(alias) {
187 expanded.push(target.to_string());
188 }
189 if kw_lower.contains(target) {
190 expanded.push(alias.to_string());
191 }
192 }
193 }
194
195 expanded.sort();
196 expanded.dedup();
197 expanded
198}
199
200pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
207 if context_keywords.is_empty() {
208 return 0.0;
209 }
210
211 let expanded_keywords = expand_semantic_keywords(context_keywords);
212 let content_lower = entry.content.to_lowercase();
213
214 let matches = expanded_keywords
215 .iter()
216 .filter(|kw| content_lower.contains(&kw.to_lowercase()))
217 .count();
218
219 let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
220
221 let tag_matches = entry
222 .tags
223 .iter()
224 .filter(|tag| {
225 let tag_lower = tag.to_lowercase();
226 expanded_keywords.iter().any(|kw| {
227 tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
228 })
229 })
230 .count();
231
232 let tag_score = if tag_matches > 0 {
233 0.2 + (tag_matches as f64 * 0.05).min(0.1)
234 } else {
235 0.0
236 };
237
238 (keyword_score + tag_score).min(1.0)
239}
240
241pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
244 let config = KeywordsConfig::load();
245
246 for signal in &config.contradiction_signals {
248 if new.contains(signal) {
249 return true;
250 }
251 }
252
253 let action_verbs = [
255 "决定使用",
256 "选择使用",
257 "采用",
258 "使用",
259 "decided to use",
260 "chose",
261 "using",
262 "adopted",
263 ];
264
265 for verb in &action_verbs {
266 if old.contains(verb) && new.contains(verb) {
267 return true;
268 }
269 }
270
271 let pref_verbs = ["偏好", "喜欢", "prefer", "like"];
273 for verb in &pref_verbs {
274 if old.contains(verb) && new.contains(verb) {
275 return true;
276 }
277 }
278
279 false
280}
281
282pub async fn extract_keywords_hybrid(
288 context: &str,
289 fast_provider: Option<&dyn crate::providers::Provider>,
290) -> Vec<String> {
291 let rule_keywords = extract_context_keywords(context);
293
294 let mode = AiKeywordMode::from_env();
296 if mode.should_use_ai(rule_keywords.len()) && fast_provider.is_some() {
297 if let Some(provider) = fast_provider {
299 let ai_keywords = extract_keywords_with_ai(context, provider).await;
300 if !ai_keywords.is_empty() {
301 return ai_keywords;
302 }
303 }
304 }
305
306 rule_keywords
307}
308
309async fn extract_keywords_with_ai(
311 context: &str,
312 provider: &dyn crate::providers::Provider,
313) -> Vec<String> {
314 use crate::providers::{ChatRequest, Message, MessageContent, Role};
315
316 let truncated = if context.len() > 2000 {
317 &context[..2000]
318 } else {
319 context
320 };
321
322 let prompt = format!(
323 "从以下对话内容中提取关键词(用于记忆检索),最多返回10个关键词,以逗号分隔:\n\n{}",
324 truncated
325 );
326
327 let request = ChatRequest {
328 messages: vec![Message {
329 role: Role::User,
330 content: MessageContent::Text(prompt),
331 }],
332 tools: vec![],
333 system: Some("你是一个关键词提取助手,返回关键词列表,不要其他解释。".to_string()),
334 think: false,
335 max_tokens: 100,
336 server_tools: vec![],
337 enable_caching: false,
338 };
339
340 let response = match provider.chat(request).await {
341 Ok(r) => r,
342 Err(_) => return Vec::new(),
343 };
344
345 let text = response
346 .content
347 .iter()
348 .filter_map(|block| {
349 if let crate::providers::ContentBlock::Text { text } = block {
350 Some(text.clone())
351 } else {
352 None
353 }
354 })
355 .collect::<Vec<_>>()
356 .join("");
357
358 text.split(',')
359 .map(|s| s.trim().to_string())
360 .filter(|s| s.len() >= 2)
361 .collect()
362}
363
/// TF-IDF index over the contents of memory entries.
pub struct TfIdfSearch {
    /// Term frequencies per document, keyed by the document's full content string.
    doc_word_freq: HashMap<String, HashMap<String, f32>>,
    /// Document count used as `N` when computing IDF; reset to 0 by `clear`.
    total_docs: usize,
    /// Inverse document frequency of every token seen while indexing.
    idf_cache: HashMap<String, f32>,
}
380
381impl TfIdfSearch {
382 pub fn new() -> Self {
384 Self {
385 doc_word_freq: HashMap::new(),
386 total_docs: 0,
387 idf_cache: HashMap::new(),
388 }
389 }
390
391 pub fn index(&mut self, memory: &AutoMemory) {
393 self.clear();
394 self.total_docs = memory.entries.len();
395
396 for entry in &memory.entries {
397 let words = self.tokenize(&entry.content);
398 let word_freq = self.compute_word_freq(&words);
399 self.doc_word_freq.insert(entry.content.clone(), word_freq);
400 }
401
402 self.compute_idf();
403 }
404
405 fn tokenize(&self, text: &str) -> Vec<String> {
407 let lower = text.to_lowercase();
408 let mut tokens = Vec::new();
409
410 for word in lower.split_whitespace() {
411 let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
412 if trimmed.len() > 1 {
413 tokens.push(trimmed.to_string());
414 }
415
416 let chars: Vec<char> = trimmed.chars().collect();
417 let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
418
419 if has_cjk {
420 for c in &chars {
421 if Self::is_cjk(*c) {
422 tokens.push(c.to_string());
423 }
424 }
425 for window in chars.windows(2) {
426 if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
427 tokens.push(window.iter().collect::<String>());
428 }
429 }
430 }
431 }
432
433 tokens
434 }
435
436 fn is_cjk(c: char) -> bool {
438 matches!(c,
439 '\u{4E00}'..='\u{9FFF}' |
440 '\u{3400}'..='\u{4DBF}' |
441 '\u{F900}'..='\u{FAFF}' |
442 '\u{3000}'..='\u{303F}' |
443 '\u{3040}'..='\u{309F}' |
444 '\u{30A0}'..='\u{30FF}'
445 )
446 }
447
448 fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
450 let total = words.len() as f32;
451 let mut freq = HashMap::new();
452
453 for word in words {
454 *freq.entry(word.clone()).or_insert(0.0) += 1.0;
455 }
456
457 for (_, count) in freq.iter_mut() {
458 *count /= total;
459 }
460
461 freq
462 }
463
464 fn compute_idf(&mut self) {
466 let mut word_doc_count: HashMap<String, usize> = HashMap::new();
467
468 for word_freq in &self.doc_word_freq {
469 for word in word_freq.1.keys() {
470 *word_doc_count.entry(word.clone()).or_insert(0) += 1;
471 }
472 }
473
474 for (word, count) in word_doc_count {
475 let idf = (self.total_docs as f32 / count as f32).ln();
476 self.idf_cache.insert(word, idf);
477 }
478 }
479
480 pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
482 let query_words = self.tokenize(query);
483 let query_freq = self.compute_word_freq(&query_words);
484
485 let mut results: Vec<(String, f32)> = Vec::new();
486
487 for (doc, doc_freq) in &self.doc_word_freq {
488 let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
489
490 if similarity > 0.0 {
491 results.push((doc.clone(), similarity));
492 }
493 }
494
495 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
496
497 if let Some(max) = limit {
498 results.into_iter().take(max).collect()
499 } else {
500 results
501 }
502 }
503
504 pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
506 let mut doc_scores: HashMap<String, f64> = HashMap::new();
507
508 for keyword in keywords {
509 let results = self.search(keyword, None);
510 for (doc, score) in results {
511 *doc_scores.entry(doc).or_insert(0.0) += score as f64;
512 }
513 }
514
515 let num_keywords = keywords.len().max(1);
516 for (_, score) in doc_scores.iter_mut() {
517 *score /= num_keywords as f64;
518 }
519
520 let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
521 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
522
523 if let Some(max) = limit {
524 results.into_iter().take(max).collect()
525 } else {
526 results
527 }
528 }
529
530 fn compute_tfidf_similarity(
532 &self,
533 query_freq: &HashMap<String, f32>,
534 doc_freq: &HashMap<String, f32>,
535 ) -> f32 {
536 let mut similarity = 0.0;
537
538 for (word, tf_query) in query_freq {
539 if let Some(tf_doc) = doc_freq.get(word)
540 && let Some(idf) = self.idf_cache.get(word)
541 {
542 similarity += tf_query * idf * tf_doc * idf;
543 }
544 }
545
546 similarity
547 }
548
549 pub fn clear(&mut self) {
551 self.doc_word_freq.clear();
552 self.idf_cache.clear();
553 self.total_docs = 0;
554 }
555}
556
557impl Default for TfIdfSearch {
558 fn default() -> Self {
559 Self::new()
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::super::types::MemoryCategory;
    use super::*;

    /// Keyword extraction yields at least one keyword for mixed Chinese/English text.
    #[test]
    fn test_extract_keywords() {
        let extracted = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!extracted.is_empty());
    }

    /// The Chinese term for "database" expands to its English alias.
    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.contains(&"database".to_string()));
    }

    /// A query for "数据库" ranks the PostgreSQL decision first.
    #[test]
    fn test_tfidf_search() {
        let mut memory = AutoMemory::new();
        for content in [
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ] {
            memory.add(MemoryEntry::new(
                MemoryCategory::Decision,
                content.to_string(),
                None,
            ));
        }

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());

        assert!(results[0].0.contains("PostgreSQL"));
    }
}