1use std::collections::{HashMap, HashSet};
4
5use super::entry::MemoryEntry;
6use super::manager::AutoMemory;
7
/// Builds the set of stop words (Chinese and English) that keyword extraction ignores.
///
/// Every entry is lowercase (Chinese has no case), so it can be compared directly against
/// lowercased input text.
fn get_stop_words() -> HashSet<&'static str> {
    const CHINESE: &[&str] = &[
        "的", "了", "是", "在", "我", "有", "和", "就", "不", "都", "一", "也", "很", "到", "要",
        "去", "你", "会", "着", "没有", "看", "好", "这", "那", "什么", "怎么", "请", "能", "可以",
        "需要",
    ];
    const ENGLISH: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "have", "has", "do", "will", "would",
        "could", "should", "can", "to", "of", "in", "for", "on", "with", "at", "by", "from", "and",
        "but", "or", "not", "if", "then", "this", "that", "it", "i", "me", "my", "we", "you", "he",
        "she", "they", "please", "help", "need", "want", "let", "use",
    ];

    CHINESE.iter().chain(ENGLISH.iter()).copied().collect()
}
25
26pub fn extract_context_keywords(context: &str) -> Vec<String> {
32 let stop_words = get_stop_words();
33 let lower = context.to_lowercase();
34 let mut keywords: HashSet<String> = HashSet::new();
35
36 for word in lower.split_whitespace() {
38 let cleaned = word
39 .trim_matches(|c: char| !c.is_alphanumeric())
40 .to_string();
41 if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
42 keywords.insert(cleaned);
43 }
44 }
45
46 let tech_regexes = [
48 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}", r"[A-Z][a-z]+[A-Z][a-zA-Z]*", r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*", r"[0-9]+[kKmMgGtT][bB]?", ];
53
54 for pattern in tech_regexes {
55 if let Ok(re) = regex::Regex::new(pattern) {
56 for cap in re.find_iter(&lower) {
57 let match_str = cap.as_str();
58 if !stop_words.contains(match_str) {
59 keywords.insert(match_str.to_string());
60 }
61 }
62 }
63 }
64
65 let mut result: Vec<String> = keywords.into_iter().collect();
67 result.sort_by_key(|b| std::cmp::Reverse(b.len()));
68 result.truncate(10);
69 result
70}
71
/// Lowercase openers (greetings and generic requests) that mark a message as too simple to be
/// worth a memory lookup. ASCII entries only match as whole words. CJK entries match as plain
/// prefixes because CJK text does not separate words with spaces.
const GREETING_PATTERNS: &[&str] = &[
    "你好",
    "您好",
    "hi",
    "hello",
    "hey",
    "嗨",
    "早上好",
    "下午好",
    "晚上好",
    "good morning",
    "good afternoon",
    "good evening",
    "请问",
    "帮忙",
    "帮我",
    "帮我看",
    "看看",
    "help",
    "请",
    "开始",
    "start",
    "准备好了",
    "ready",
];

/// Returns `true` when `msg` is too short or too trivial to warrant memory retrieval.
///
/// A message is skipped when either condition holds:
/// - its trimmed form is shorter than 15 bytes of UTF-8 (about five CJK characters or fifteen
///   ASCII characters);
/// - its lowercased form starts with one of [`GREETING_PATTERNS`].
///
/// ASCII patterns must be followed by a word boundary. "hi there, …" is skipped, but
/// "highlight the errors …" and "startup crashes …" are not.
pub fn should_skip_simple_message(msg: &str) -> bool {
    let trimmed = msg.trim();

    if trimmed.len() < 15 {
        return true;
    }

    let lower = trimmed.to_lowercase();
    GREETING_PATTERNS
        .iter()
        .any(|pattern| match lower.strip_prefix(*pattern) {
            None => false,
            // CJK patterns: no word boundaries to check.
            Some(_) if !pattern.is_ascii() => true,
            // ASCII patterns: the next character must not continue the word.
            Some(rest) => {
                !matches!(rest.chars().next(), Some(c) if c.is_ascii_alphanumeric() || c == '_')
            }
        })
}
123
124pub fn calculate_similarity(a: &str, b: &str) -> f64 {
126 AutoMemory::calculate_similarity(a, b)
127}
128
/// Returns the `(alias, canonical)` pairs used to expand search keywords.
///
/// Pairs cover language/framework capitalisation, Chinese↔English technical vocabulary and
/// a few abbreviations (e.g. `db` → `database`). The order is stable.
pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
    const ALIAS_TABLE: &[(&str, &str)] = &[
        // Languages and frameworks.
        ("rust", "Rust"),
        ("typescript", "TypeScript"),
        ("javascript", "JavaScript"),
        ("python", "Python"),
        ("react", "React"),
        ("vue", "Vue"),
        ("angular", "Angular"),
        // Storage.
        ("数据库", "database"),
        ("db", "database"),
        // Actions.
        ("修复", "fix"),
        ("解决", "solve"),
        ("优化", "optimize"),
        ("重构", "refactor"),
        ("更新", "update"),
        ("删除", "delete"),
        // Preferences.
        ("喜欢", "prefer"),
        ("偏好", "prefer"),
        ("首选", "prefer"),
        // Project layout.
        ("入口", "entry"),
        ("主文件", "main"),
        ("目录", "directory"),
    ];

    ALIAS_TABLE.to_vec()
}
163
164pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
166 let aliases = get_semantic_aliases();
167 let mut expanded: Vec<String> = keywords.to_vec();
168
169 for keyword in keywords {
170 let kw_lower = keyword.to_lowercase();
171 for &(alias, target) in &aliases {
172 if kw_lower.contains(alias) {
173 expanded.push(target.to_string());
174 }
175 if kw_lower.contains(target) {
176 expanded.push(alias.to_string());
177 }
178 }
179 }
180
181 expanded.sort();
182 expanded.dedup();
183 expanded
184}
185
186pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
193 if context_keywords.is_empty() {
194 return 0.0;
195 }
196
197 let expanded_keywords = expand_semantic_keywords(context_keywords);
198 let content_lower = entry.content.to_lowercase();
199
200 let matches = expanded_keywords
201 .iter()
202 .filter(|kw| content_lower.contains(&kw.to_lowercase()))
203 .count();
204
205 let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
206
207 let tag_matches = entry
208 .tags
209 .iter()
210 .filter(|tag| {
211 let tag_lower = tag.to_lowercase();
212 expanded_keywords.iter().any(|kw| {
213 tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
214 })
215 })
216 .count();
217
218 let tag_score = if tag_matches > 0 {
219 0.2 + (tag_matches as f64 * 0.05).min(0.1)
220 } else {
221 0.0
222 };
223
224 (keyword_score + tag_score).min(1.0)
225}
226
/// Heuristically decides whether memory text `new` may contradict or supersede `old`.
///
/// Both texts are compared case-insensitively, so "No longer …" and "Switched to …" at the
/// start of a sentence are recognised. Returns `true` when either:
/// - `new` contains an explicit change signal ("不再", "no longer", "switched to", …); or
/// - `old` and `new` both contain the same decision verb ("采用", "using", …) or preference
///   verb ("偏好", "prefer", …), suggesting two statements about the same kind of choice.
pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
    const CONTRADICTION_SIGNALS: &[&str] = &[
        "不再",
        "改为",
        "换成",
        "放弃",
        "no longer",
        "instead of",
        "changed to",
        "switched to",
    ];
    const ACTION_VERBS: &[&str] = &[
        "决定使用",
        "选择使用",
        "采用",
        "使用",
        "decided to use",
        "chose",
        "using",
        "adopted",
    ];
    const PREFERENCE_VERBS: &[&str] = &["偏好", "喜欢", "prefer", "like"];

    // The English patterns above are lowercase; lowering the inputs makes matching
    // case-insensitive. Lowercasing leaves CJK text unchanged.
    let old = old.to_lowercase();
    let new = new.to_lowercase();

    if CONTRADICTION_SIGNALS.iter().any(|signal| new.contains(*signal)) {
        return true;
    }

    let in_both = |verb: &&str| old.contains(*verb) && new.contains(*verb);
    ACTION_VERBS.iter().any(in_both) || PREFERENCE_VERBS.iter().any(in_both)
}
277
278pub struct TfIdfSearch {
287 doc_word_freq: HashMap<String, HashMap<String, f32>>,
289 total_docs: usize,
291 idf_cache: HashMap<String, f32>,
293}
294
295impl TfIdfSearch {
296 pub fn new() -> Self {
298 Self {
299 doc_word_freq: HashMap::new(),
300 total_docs: 0,
301 idf_cache: HashMap::new(),
302 }
303 }
304
305 pub fn index(&mut self, memory: &AutoMemory) {
307 self.clear();
308 self.total_docs = memory.entries.len();
309
310 for entry in &memory.entries {
311 let words = self.tokenize(&entry.content);
312 let word_freq = self.compute_word_freq(&words);
313 self.doc_word_freq.insert(entry.content.clone(), word_freq);
314 }
315
316 self.compute_idf();
317 }
318
319 fn tokenize(&self, text: &str) -> Vec<String> {
321 let lower = text.to_lowercase();
322 let mut tokens = Vec::new();
323
324 for word in lower.split_whitespace() {
325 let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
326 if trimmed.len() > 1 {
327 tokens.push(trimmed.to_string());
328 }
329
330 let chars: Vec<char> = trimmed.chars().collect();
331 let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
332
333 if has_cjk {
334 for c in &chars {
335 if Self::is_cjk(*c) {
336 tokens.push(c.to_string());
337 }
338 }
339 for window in chars.windows(2) {
340 if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
341 tokens.push(window.iter().collect::<String>());
342 }
343 }
344 }
345 }
346
347 tokens
348 }
349
350 fn is_cjk(c: char) -> bool {
352 matches!(c,
353 '\u{4E00}'..='\u{9FFF}' |
354 '\u{3400}'..='\u{4DBF}' |
355 '\u{F900}'..='\u{FAFF}' |
356 '\u{3000}'..='\u{303F}' |
357 '\u{3040}'..='\u{309F}' |
358 '\u{30A0}'..='\u{30FF}'
359 )
360 }
361
362 fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
364 let total = words.len() as f32;
365 let mut freq = HashMap::new();
366
367 for word in words {
368 *freq.entry(word.clone()).or_insert(0.0) += 1.0;
369 }
370
371 for (_, count) in freq.iter_mut() {
372 *count /= total;
373 }
374
375 freq
376 }
377
378 fn compute_idf(&mut self) {
380 let mut word_doc_count: HashMap<String, usize> = HashMap::new();
381
382 for word_freq in &self.doc_word_freq {
383 for word in word_freq.1.keys() {
384 *word_doc_count.entry(word.clone()).or_insert(0) += 1;
385 }
386 }
387
388 for (word, count) in word_doc_count {
389 let idf = (self.total_docs as f32 / count as f32).ln();
390 self.idf_cache.insert(word, idf);
391 }
392 }
393
394 pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
396 let query_words = self.tokenize(query);
397 let query_freq = self.compute_word_freq(&query_words);
398
399 let mut results: Vec<(String, f32)> = Vec::new();
400
401 for (doc, doc_freq) in &self.doc_word_freq {
402 let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
403
404 if similarity > 0.0 {
405 results.push((doc.clone(), similarity));
406 }
407 }
408
409 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
410
411 if let Some(max) = limit {
412 results.into_iter().take(max).collect()
413 } else {
414 results
415 }
416 }
417
418 pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
420 let mut doc_scores: HashMap<String, f64> = HashMap::new();
421
422 for keyword in keywords {
423 let results = self.search(keyword, None);
424 for (doc, score) in results {
425 *doc_scores.entry(doc).or_insert(0.0) += score as f64;
426 }
427 }
428
429 let num_keywords = keywords.len().max(1);
430 for (_, score) in doc_scores.iter_mut() {
431 *score /= num_keywords as f64;
432 }
433
434 let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
435 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
436
437 if let Some(max) = limit {
438 results.into_iter().take(max).collect()
439 } else {
440 results
441 }
442 }
443
444 fn compute_tfidf_similarity(
446 &self,
447 query_freq: &HashMap<String, f32>,
448 doc_freq: &HashMap<String, f32>,
449 ) -> f32 {
450 let mut similarity = 0.0;
451
452 for (word, tf_query) in query_freq {
453 if let Some(tf_doc) = doc_freq.get(word)
454 && let Some(idf) = self.idf_cache.get(word)
455 {
456 similarity += tf_query * idf * tf_doc * idf;
457 }
458 }
459
460 similarity
461 }
462
463 pub fn clear(&mut self) {
465 self.doc_word_freq.clear();
466 self.idf_cache.clear();
467 self.total_docs = 0;
468 }
469}
470
471impl Default for TfIdfSearch {
472 fn default() -> Self {
473 Self::new()
474 }
475}
476
/// System prompt (written in Chinese) sent by `ai_select_memories`.
///
/// It asks the model to pick at most 5 memory indices useful for the user's query and shows
/// `{"selected": [0, 2, 5]}` as the reply format.
///
/// NOTE(review): the prompt text asks for a JSON array ("JSON 数组"), but the example is a
/// JSON object. `parse_selected_indices` accepts both shapes; the wording could be aligned.
const SELECT_MEMORIES_SYSTEM_PROMPT: &str = r#"你正在选择对处理用户查询有用的记忆。你会收到用户的查询和可用记忆文件列表(包含描述)。

返回最有用的记忆索引列表(最多5个),以 JSON 数组格式返回。
- 只选择你确定会有帮助的记忆
- 如果不确定某个记忆是否有用,不要选择它
- 如果没有明显有用的记忆,可以返回空数组 []
- 优先选择与当前问题直接相关的记忆

返回格式示例:{"selected": [0, 2, 5]}
"#;
492
493pub async fn ai_select_memories(
498 query: &str,
499 memory_manifest: &str,
500 provider: &dyn crate::providers::Provider,
501) -> Vec<usize> {
502 use crate::providers::{ChatRequest, Message, MessageContent, Role};
503
504 let truncated_query = if query.len() > 1000 {
506 &query[..1000]
507 } else {
508 query
509 };
510
511 let user_prompt = format!(
512 "查询: {}\n\n可用记忆列表:\n{}\n\n请选择最有用的记忆索引(最多5个):",
513 truncated_query, memory_manifest
514 );
515
516 let request = ChatRequest {
517 messages: vec![Message {
518 role: Role::User,
519 content: MessageContent::Text(user_prompt),
520 }],
521 tools: vec![],
522 system: Some(SELECT_MEMORIES_SYSTEM_PROMPT.to_string()),
523 think: false,
524 max_tokens: 100,
525 server_tools: vec![],
526 enable_caching: false,
527 };
528
529 let response = match provider.chat(request).await {
530 Ok(r) => r,
531 Err(_) => return Vec::new(),
532 };
533
534 let text = response
536 .content
537 .iter()
538 .filter_map(|block| {
539 if let crate::providers::ContentBlock::Text { text } = block {
540 Some(text.clone())
541 } else {
542 None
543 }
544 })
545 .collect::<Vec<_>>()
546 .join("");
547
548 parse_selected_indices(&text)
550}
551
552fn parse_selected_indices(text: &str) -> Vec<usize> {
554 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
556 if let Some(selected) = json.get("selected").and_then(|s| s.as_array()) {
557 return selected
558 .iter()
559 .filter_map(|v| v.as_u64().map(|n| n as usize))
560 .collect();
561 }
562 if let Some(arr) = json.as_array() {
564 return arr
565 .iter()
566 .filter_map(|v| v.as_u64().map(|n| n as usize))
567 .collect();
568 }
569 }
570
571 let mut indices = Vec::new();
573 for part in text.split(',') {
574 let trimmed = part.trim();
575 if let Ok(n) = trimmed.parse::<usize>() {
576 indices.push(n);
577 }
578 }
579 indices
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryCategory;

    // Smoke test: mixed Chinese/English input yields at least one keyword.
    #[test]
    fn test_extract_keywords() {
        let keywords = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!keywords.is_empty());
    }

    // The Chinese keyword "数据库" must expand to its English alias "database".
    #[test]
    fn test_semantic_aliases() {
        let keywords = vec!["数据库".to_string()];
        let expanded = expand_semantic_keywords(&keywords);
        assert!(expanded.contains(&"database".to_string()));
    }

    // Indexes three decision memories and checks that a CJK-only query ("数据库") ranks the
    // PostgreSQL entry first, exercising the CJK unigram/bigram tokenization.
    #[test]
    fn test_tfidf_search() {
        let mut tfidf = TfIdfSearch::new();
        let mut memory = AutoMemory::new();

        memory.add(MemoryEntry::new(
            MemoryCategory::Decision,
            "使用 PostgreSQL 作为数据库".to_string(),
            None,
            None,
        ));
        memory.add(MemoryEntry::new(
            MemoryCategory::Decision,
            "前端使用 React 框架开发".to_string(),
            None,
            None,
        ));
        memory.add(MemoryEntry::new(
            MemoryCategory::Decision,
            "后端采用 Rust 编写".to_string(),
            None,
            None,
        ));

        tfidf.index(&memory);
        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());

        // Results are sorted by descending score, so the best match comes first.
        assert!(results[0].0.contains("PostgreSQL"));
    }
}
633}