1use std::collections::{HashMap, HashSet};
4
5use super::config::*;
6use super::keywords_config::KeywordsConfig;
7use super::types::{AutoMemory, MemoryEntry};
8
9pub fn extract_context_keywords(context: &str) -> Vec<String> {
17 let config = KeywordsConfig::load();
18 let stop_words = config.get_stop_words_set();
19 let tech_patterns = config.get_tech_keywords_set();
20
21 let lower = context.to_lowercase();
22 let mut keywords: HashSet<String> = HashSet::new();
23
24 for word in lower.split_whitespace() {
26 let cleaned = word
27 .trim_matches(|c: char| !c.is_alphanumeric())
28 .to_string();
29 if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
31 keywords.insert(cleaned.clone());
32 }
33 if tech_patterns.contains(cleaned.as_str()) {
35 keywords.insert(cleaned);
36 }
37 }
38
39 for category_patterns in config.patterns.values() {
43 for pattern in category_patterns {
44 if lower.contains(&pattern.to_lowercase()) {
45 keywords.insert(pattern.clone());
46 }
47 }
48 }
49
50 for kw in &config.tech_keywords {
52 if lower.contains(&kw.to_lowercase()) && !stop_words.contains(kw.as_str()) {
53 keywords.insert(kw.clone());
54 }
55 }
56
57 let tech_regexes = [
59 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}", r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*", r"[A-Z][a-z]+[A-Z][a-zA-Z]*", r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*", r"[0-9]+[kKmMgGtT][bB]?", r"[a-zA-Z]+-[a-zA-Z]+", ];
66
67 for pattern in tech_regexes {
68 if let Ok(re) = regex::Regex::new(pattern) {
69 for cap in re.find_iter(&lower) {
70 let match_str = cap.as_str();
71 if !stop_words.contains(match_str) {
72 keywords.insert(match_str.to_string());
73 }
74 }
75 }
76 }
77
78 let mut result: Vec<String> = keywords.into_iter().collect();
80 result.sort_by_key(|b| std::cmp::Reverse(b.len()));
81 result.truncate(10);
82
83 result
84}
85
86pub fn calculate_similarity(a: &str, b: &str) -> f64 {
88 AutoMemory::calculate_similarity(a, b)
89}
90
91pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
97 SEMANTIC_ALIASES_DEFAULT.to_vec()
100}
101
/// Built-in `(alias, target)` pairs relating Chinese terms, abbreviations and
/// spelling variants to an English term.
///
/// The table is an ordered list, not a map: the same alias may appear more than
/// once with different targets (for example `"登录"`), and some pairs map a term
/// to itself.
pub const SEMANTIC_ALIASES_DEFAULT: &[(&str, &str)] = &[
    // Databases
    ("数据库", "database"),
    ("db", "database"),
    ("postgresql", "postgres"),
    ("mysql", "mysql"),
    ("mongodb", "mongo"),
    ("redis", "redis"),
    ("sqlite", "sqlite"),
    ("sql", "database"),
    // Frontend
    ("前端", "frontend"),
    ("ui", "frontend"),
    ("界面", "frontend"),
    ("页面", "page"),
    ("组件", "component"),
    ("react", "react"),
    ("vue", "vue"),
    ("angular", "angular"),
    // Backend
    ("后端", "backend"),
    ("api", "api"),
    ("接口", "api"),
    ("服务", "service"),
    ("server", "backend"),
    ("服务器", "backend"),
    // Programming languages and runtimes
    ("rust", "rust"),
    ("python", "python"),
    ("javascript", "js"),
    ("typescript", "ts"),
    ("java", "java"),
    ("go", "golang"),
    ("golang", "go"),
    ("c++", "cpp"),
    ("cpp", "c++"),
    ("nodejs", "node"),
    ("node", "nodejs"),
    // Editors
    ("编辑器", "editor"),
    ("ide", "editor"),
    ("vim", "vim"),
    ("vscode", "vscode"),
    ("emacs", "emacs"),
    // Configuration
    ("配置", "config"),
    ("设置", "config"),
    ("config", "config"),
    ("setting", "config"),
    // Files, paths and code organisation
    ("目录", "directory"),
    ("文件", "file"),
    ("文件夹", "directory"),
    ("路径", "path"),
    ("模块", "module"),
    ("包", "package"),
    // Testing
    ("测试", "test"),
    ("test", "test"),
    ("单元测试", "unittest"),
    ("unittest", "test"),
    // Caching
    ("缓存", "cache"),
    ("cache", "cache"),
    // Authentication
    ("认证", "auth"),
    ("登录", "login"),
    ("auth", "auth"),
    ("登录", "auth"),
    // Performance
    ("性能", "performance"),
    ("优化", "optimize"),
    ("速度", "speed"),
    ("慢", "slow"),
    // Common actions
    ("创建", "create"),
    ("删除", "delete"),
    ("修改", "modify"),
    ("添加", "add"),
    ("更新", "update"),
    ("查询", "query"),
];
185
186pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
188 let config = KeywordsConfig::load();
189 let mut expanded: Vec<String> = keywords.to_vec();
190
191 for keyword in keywords {
192 let kw_lower = keyword.to_lowercase();
193 for (alias, target) in config.get_aliases() {
194 if kw_lower.contains(alias) {
195 expanded.push(target.to_string());
196 }
197 if kw_lower.contains(target) {
198 expanded.push(alias.to_string());
199 }
200 }
201 }
202
203 expanded.sort();
204 expanded.dedup();
205 expanded
206}
207
208pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
215 if context_keywords.is_empty() {
216 return 0.0;
217 }
218
219 let expanded_keywords = expand_semantic_keywords(context_keywords);
220 let content_lower = entry.content.to_lowercase();
221
222 let matches = expanded_keywords
223 .iter()
224 .filter(|kw| content_lower.contains(&kw.to_lowercase()))
225 .count();
226
227 let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
228
229 let tag_matches = entry
230 .tags
231 .iter()
232 .filter(|tag| {
233 let tag_lower = tag.to_lowercase();
234 expanded_keywords.iter().any(|kw| {
235 tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
236 })
237 })
238 .count();
239
240 let tag_score = if tag_matches > 0 {
241 0.2 + (tag_matches as f64 * 0.05).min(0.1)
242 } else {
243 0.0
244 };
245
246 (keyword_score + tag_score).min(1.0)
247}
248
249pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
252 let config = KeywordsConfig::load();
253
254 for signal in &config.contradiction_signals {
256 if new.contains(signal) {
257 return true;
258 }
259 }
260
261 let action_verbs = [
263 "决定使用",
264 "选择使用",
265 "采用",
266 "使用",
267 "decided to use",
268 "chose",
269 "using",
270 "adopted",
271 ];
272
273 for verb in &action_verbs {
274 if old.contains(verb) && new.contains(verb) {
275 return true;
276 }
277 }
278
279 let pref_verbs = ["偏好", "喜欢", "prefer", "like"];
281 for verb in &pref_verbs {
282 if old.contains(verb) && new.contains(verb) {
283 return true;
284 }
285 }
286
287 false
288}
289
290pub async fn extract_keywords_hybrid(
296 context: &str,
297 fast_provider: Option<&dyn crate::providers::Provider>,
298) -> Vec<String> {
299 let rule_keywords = extract_context_keywords(context);
301
302 let mode = AiKeywordMode::from_env();
304 if mode.should_use_ai(rule_keywords.len()) && fast_provider.is_some() {
305 if let Some(provider) = fast_provider {
307 let ai_keywords = extract_keywords_with_ai(context, provider).await;
308 if !ai_keywords.is_empty() {
309 return ai_keywords;
310 }
311 }
312 }
313
314 rule_keywords
315}
316
317async fn extract_keywords_with_ai(
319 context: &str,
320 provider: &dyn crate::providers::Provider,
321) -> Vec<String> {
322 use crate::providers::{ChatRequest, Message, MessageContent, Role};
323
324 let truncated = if context.len() > 2000 {
325 &context[..2000]
326 } else {
327 context
328 };
329
330 let prompt = format!(
331 "从以下对话内容中提取关键词(用于记忆检索),最多返回10个关键词,以逗号分隔:\n\n{}",
332 truncated
333 );
334
335 let request = ChatRequest {
336 messages: vec![Message {
337 role: Role::User,
338 content: MessageContent::Text(prompt),
339 }],
340 tools: vec![],
341 system: Some("你是一个关键词提取助手,返回关键词列表,不要其他解释。".to_string()),
342 think: false,
343 max_tokens: 100,
344 server_tools: vec![],
345 enable_caching: false,
346 };
347
348 let response = match provider.chat(request).await {
349 Ok(r) => r,
350 Err(_) => return Vec::new(),
351 };
352
353 let text = response
354 .content
355 .iter()
356 .filter_map(|block| {
357 if let crate::providers::ContentBlock::Text { text } = block {
358 Some(text.clone())
359 } else {
360 None
361 }
362 })
363 .collect::<Vec<_>>()
364 .join("");
365
366 text.split(',')
367 .map(|s| s.trim().to_string())
368 .filter(|s| s.len() >= 2)
369 .collect()
370}
371
372pub struct TfIdfSearch {
381 doc_word_freq: HashMap<String, HashMap<String, f32>>,
383 total_docs: usize,
385 idf_cache: HashMap<String, f32>,
387}
388
389impl TfIdfSearch {
390 pub fn new() -> Self {
392 Self {
393 doc_word_freq: HashMap::new(),
394 total_docs: 0,
395 idf_cache: HashMap::new(),
396 }
397 }
398
399 pub fn index(&mut self, memory: &AutoMemory) {
401 self.clear();
402 self.total_docs = memory.entries.len();
403
404 for entry in &memory.entries {
405 let words = self.tokenize(&entry.content);
406 let word_freq = self.compute_word_freq(&words);
407 self.doc_word_freq.insert(entry.content.clone(), word_freq);
408 }
409
410 self.compute_idf();
411 }
412
413 fn tokenize(&self, text: &str) -> Vec<String> {
415 let lower = text.to_lowercase();
416 let mut tokens = Vec::new();
417
418 for word in lower.split_whitespace() {
419 let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
420 if trimmed.len() > 1 {
421 tokens.push(trimmed.to_string());
422 }
423
424 let chars: Vec<char> = trimmed.chars().collect();
425 let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
426
427 if has_cjk {
428 for c in &chars {
429 if Self::is_cjk(*c) {
430 tokens.push(c.to_string());
431 }
432 }
433 for window in chars.windows(2) {
434 if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
435 tokens.push(window.iter().collect::<String>());
436 }
437 }
438 }
439 }
440
441 tokens
442 }
443
444 fn is_cjk(c: char) -> bool {
446 matches!(c,
447 '\u{4E00}'..='\u{9FFF}' |
448 '\u{3400}'..='\u{4DBF}' |
449 '\u{F900}'..='\u{FAFF}' |
450 '\u{3000}'..='\u{303F}' |
451 '\u{3040}'..='\u{309F}' |
452 '\u{30A0}'..='\u{30FF}'
453 )
454 }
455
456 fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
458 let total = words.len() as f32;
459 let mut freq = HashMap::new();
460
461 for word in words {
462 *freq.entry(word.clone()).or_insert(0.0) += 1.0;
463 }
464
465 for (_, count) in freq.iter_mut() {
466 *count /= total;
467 }
468
469 freq
470 }
471
472 fn compute_idf(&mut self) {
474 let mut word_doc_count: HashMap<String, usize> = HashMap::new();
475
476 for word_freq in &self.doc_word_freq {
477 for word in word_freq.1.keys() {
478 *word_doc_count.entry(word.clone()).or_insert(0) += 1;
479 }
480 }
481
482 for (word, count) in word_doc_count {
483 let idf = (self.total_docs as f32 / count as f32).ln();
484 self.idf_cache.insert(word, idf);
485 }
486 }
487
488 pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
490 let query_words = self.tokenize(query);
491 let query_freq = self.compute_word_freq(&query_words);
492
493 let mut results: Vec<(String, f32)> = Vec::new();
494
495 for (doc, doc_freq) in &self.doc_word_freq {
496 let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
497
498 if similarity > 0.0 {
499 results.push((doc.clone(), similarity));
500 }
501 }
502
503 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
504
505 if let Some(max) = limit {
506 results.into_iter().take(max).collect()
507 } else {
508 results
509 }
510 }
511
512 pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
514 let mut doc_scores: HashMap<String, f64> = HashMap::new();
515
516 for keyword in keywords {
517 let results = self.search(keyword, None);
518 for (doc, score) in results {
519 *doc_scores.entry(doc).or_insert(0.0) += score as f64;
520 }
521 }
522
523 let num_keywords = keywords.len().max(1);
524 for (_, score) in doc_scores.iter_mut() {
525 *score /= num_keywords as f64;
526 }
527
528 let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
529 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
530
531 if let Some(max) = limit {
532 results.into_iter().take(max).collect()
533 } else {
534 results
535 }
536 }
537
538 fn compute_tfidf_similarity(
540 &self,
541 query_freq: &HashMap<String, f32>,
542 doc_freq: &HashMap<String, f32>,
543 ) -> f32 {
544 let mut similarity = 0.0;
545
546 for (word, tf_query) in query_freq {
547 if let Some(tf_doc) = doc_freq.get(word)
548 && let Some(idf) = self.idf_cache.get(word)
549 {
550 similarity += tf_query * idf * tf_doc * idf;
551 }
552 }
553
554 similarity
555 }
556
557 pub fn clear(&mut self) {
559 self.doc_word_freq.clear();
560 self.idf_cache.clear();
561 self.total_docs = 0;
562 }
563}
564
565impl Default for TfIdfSearch {
566 fn default() -> Self {
567 Self::new()
568 }
569}
570
#[cfg(test)]
mod tests {
    use super::super::types::{MemoryCategory, MemoryEntry};
    use super::*;

    /// Builds a memory holding one `Decision` entry per given text, in order.
    fn decision_memory(texts: &[&str]) -> AutoMemory {
        let mut memory = AutoMemory::new();
        for text in texts {
            memory.add(MemoryEntry::new(
                MemoryCategory::Decision,
                text.to_string(),
                None,
            ));
        }
        memory
    }

    /// Mixed Chinese/English input must yield at least one keyword.
    #[test]
    fn test_extract_keywords() {
        let keywords = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!keywords.is_empty());
    }

    /// A Chinese alias must expand to its English target.
    #[test]
    fn test_semantic_aliases() {
        let keywords = vec!["数据库".to_string()];
        let expanded = expand_semantic_keywords(&keywords);
        assert!(expanded.contains(&"database".to_string()));
    }

    /// A CJK query must find the entry that contains it inside a longer word.
    #[test]
    fn test_tfidf_search() {
        let memory = decision_memory(&[
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ]);

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());

        assert!(results[0].0.contains("PostgreSQL"));
    }
}