1use std::collections::{HashMap, HashSet};
4
5use super::entry::MemoryEntry;
6use super::manager::AutoMemory;
7
/// Builds the set of stop words that keyword extraction ignores.
///
/// The set mixes common Chinese function words with common English ones
/// (articles, auxiliaries, pronouns, and filler verbs such as "please" or
/// "help"). All entries are lowercase, so callers should lowercase their
/// input before looking words up.
fn get_stop_words() -> HashSet<&'static str> {
    // Chinese particles, pronouns and other high-frequency function words.
    const CHINESE: &[&str] = &[
        "的", "了", "是", "在", "我", "有", "和", "就", "不", "都", "一", "也", "很", "到", "要",
        "去", "你", "会", "着", "没有", "看", "好", "这", "那", "什么", "怎么", "请", "能", "可以",
        "需要",
    ];
    // English articles, auxiliaries, prepositions, pronouns and filler verbs.
    const ENGLISH: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "have", "has", "do", "will", "would",
        "could", "should", "can", "to", "of", "in", "for", "on", "with", "at", "by", "from",
        "and", "but", "or", "not", "if", "then", "this", "that", "it", "i", "me", "my", "we",
        "you", "he", "she", "they", "please", "help", "need", "want", "let", "use",
    ];

    CHINESE.iter().chain(ENGLISH.iter()).copied().collect()
}
25
26pub fn extract_context_keywords(context: &str) -> Vec<String> {
32 let stop_words = get_stop_words();
33 let lower = context.to_lowercase();
34 let mut keywords: HashSet<String> = HashSet::new();
35
36 for word in lower.split_whitespace() {
38 let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
39 if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
40 keywords.insert(cleaned);
41 }
42 }
43
44 let tech_regexes = [
46 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}", r"[A-Z][a-z]+[A-Z][a-zA-Z]*", r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*", r"[0-9]+[kKmMgGtT][bB]?", ];
51
52 for pattern in tech_regexes {
53 if let Ok(re) = regex::Regex::new(pattern) {
54 for cap in re.find_iter(&lower) {
55 let match_str = cap.as_str();
56 if !stop_words.contains(match_str) {
57 keywords.insert(match_str.to_string());
58 }
59 }
60 }
61 }
62
63 let mut result: Vec<String> = keywords.into_iter().collect();
65 result.sort_by_key(|b| std::cmp::Reverse(b.len()));
66 result.truncate(10);
67 result
68}
69
/// Lowercase openers (greetings, polite requests, "start"/"ready" signals)
/// that mark a message as not worth a memory lookup.
const GREETING_PATTERNS: &[&str] = &[
    "你好", "您好", "hi", "hello", "hey", "嗨", "早上好", "下午好", "晚上好",
    "good morning", "good afternoon", "good evening",
    "请问", "帮忙", "帮我", "帮我看", "看看", "help", "请",
    "开始", "start", "准备好了", "ready",
];

/// Returns `true` when `msg` is too simple to justify a memory lookup.
///
/// A message is skipped when, after trimming whitespace, it is shorter than
/// 15 bytes, or when its lowercased form begins with one of
/// [`GREETING_PATTERNS`].
///
/// Patterns that end in an ASCII letter or digit only match on a word
/// boundary: "hi there ..." is a greeting, while "history of ..." is not.
/// Patterns ending in a CJK character match as plain prefixes, since Chinese
/// text has no word separators.
pub fn should_skip_simple_message(msg: &str) -> bool {
    /// Tells whether `lower` opens with `pattern`, respecting word boundaries
    /// for patterns that end in an ASCII alphanumeric character.
    fn opens_with(lower: &str, pattern: &str) -> bool {
        let Some(rest) = lower.strip_prefix(pattern) else {
            return false;
        };
        let ends_in_ascii_word_char = pattern
            .chars()
            .next_back()
            .is_some_and(|c| c.is_ascii_alphanumeric());
        if !ends_in_ascii_word_char {
            return true;
        }
        // The greeting must not merely be the first letters of a longer word.
        !rest
            .chars()
            .next()
            .is_some_and(|c| c.is_ascii_alphanumeric())
    }

    let trimmed = msg.trim();

    // Length is measured in bytes: 15 ASCII characters or 5 CJK characters.
    if trimmed.len() < 15 {
        return true;
    }

    let lower = trimmed.to_lowercase();
    GREETING_PATTERNS
        .iter()
        .any(|pattern| opens_with(&lower, pattern))
}
102
/// Computes a similarity score between `a` and `b`.
///
/// This is a free-function convenience wrapper that forwards both arguments
/// unchanged to [`AutoMemory::calculate_similarity`]; the scoring itself is
/// defined there.
pub fn calculate_similarity(a: &str, b: &str) -> f64 {
    AutoMemory::calculate_similarity(a, b)
}
107
/// Returns the table of `(alias, target)` pairs used for keyword expansion.
///
/// The table covers three groups: technology names mapped to their canonical
/// capitalisation, Chinese or abbreviated terms mapped to an English
/// equivalent (for example "数据库" and "db" both map to "database"), and
/// Chinese action, preference and file-layout words mapped to English.
/// Several aliases may share one target. Order is the declaration order.
pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
    const ALIAS_TABLE: &[(&str, &str)] = &[
        // Technology names.
        ("rust", "Rust"),
        ("typescript", "TypeScript"),
        ("javascript", "JavaScript"),
        ("python", "Python"),
        ("react", "React"),
        ("vue", "Vue"),
        ("angular", "Angular"),
        ("数据库", "database"),
        ("db", "database"),
        // Actions.
        ("修复", "fix"),
        ("解决", "solve"),
        ("优化", "optimize"),
        ("重构", "refactor"),
        ("更新", "update"),
        ("删除", "delete"),
        // Preferences.
        ("喜欢", "prefer"),
        ("偏好", "prefer"),
        ("首选", "prefer"),
        // Project layout.
        ("入口", "entry"),
        ("主文件", "main"),
        ("目录", "directory"),
    ];

    ALIAS_TABLE.to_vec()
}
142
143pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
145 let aliases = get_semantic_aliases();
146 let mut expanded: Vec<String> = keywords.to_vec();
147
148 for keyword in keywords {
149 let kw_lower = keyword.to_lowercase();
150 for &(alias, target) in &aliases {
151 if kw_lower.contains(alias) {
152 expanded.push(target.to_string());
153 }
154 if kw_lower.contains(target) {
155 expanded.push(alias.to_string());
156 }
157 }
158 }
159
160 expanded.sort();
161 expanded.dedup();
162 expanded
163}
164
165pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
172 if context_keywords.is_empty() {
173 return 0.0;
174 }
175
176 let expanded_keywords = expand_semantic_keywords(context_keywords);
177 let content_lower = entry.content.to_lowercase();
178
179 let matches = expanded_keywords
180 .iter()
181 .filter(|kw| content_lower.contains(&kw.to_lowercase()))
182 .count();
183
184 let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
185
186 let tag_matches = entry
187 .tags
188 .iter()
189 .filter(|tag| {
190 let tag_lower = tag.to_lowercase();
191 expanded_keywords.iter().any(|kw| {
192 tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
193 })
194 })
195 .count();
196
197 let tag_score = if tag_matches > 0 {
198 0.2 + (tag_matches as f64 * 0.05).min(0.1)
199 } else {
200 0.0
201 };
202
203 (keyword_score + tag_score).min(1.0)
204}
205
/// Heuristically detects whether `new` may contradict or supersede `old`.
///
/// Returns `true` when either of the following holds:
///
/// * `new` contains an explicit change phrase such as "不再", "改为",
///   "no longer" or "switched to";
/// * `old` and `new` both contain the same decision or preference verb
///   (for example "使用", "decided to use", "prefer"), which suggests two
///   statements about the same kind of choice.
///
/// Matching is by substring and case-insensitive, so "No longer" and
/// "Switched to" at the start of a sentence are recognised too.
pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
    // Phrases that announce a change on their own; only `new` is inspected.
    const CHANGE_SIGNALS: [&str; 8] = [
        "不再",
        "改为",
        "换成",
        "放弃",
        "no longer",
        "instead of",
        "changed to",
        "switched to",
    ];
    // Decision and preference verbs; a signal only when both texts use the
    // same one.
    const SHARED_VERBS: [&str; 12] = [
        "决定使用",
        "选择使用",
        "采用",
        "使用",
        "decided to use",
        "chose",
        "using",
        "adopted",
        "偏好",
        "喜欢",
        "prefer",
        "like",
    ];

    // All patterns are lowercase, so lowercasing the inputs makes the match
    // case-insensitive without affecting the Chinese patterns.
    let new_lower = new.to_lowercase();
    if CHANGE_SIGNALS
        .iter()
        .any(|signal| new_lower.contains(*signal))
    {
        return true;
    }

    let old_lower = old.to_lowercase();
    SHARED_VERBS
        .iter()
        .any(|verb| old_lower.contains(*verb) && new_lower.contains(*verb))
}
256
257
258
/// In-memory TF-IDF index over the text of memory entries.
pub struct TfIdfSearch {
    /// Term frequencies of each indexed document, keyed by the document's
    /// full text (inner map: term -> relative frequency within the document).
    doc_word_freq: HashMap<String, HashMap<String, f32>>,
    /// Document count used as `N` when computing inverse document frequency.
    total_docs: usize,
    /// Inverse document frequency of every term seen while indexing.
    idf_cache: HashMap<String, f32>,
}
275
276impl TfIdfSearch {
277 pub fn new() -> Self {
279 Self {
280 doc_word_freq: HashMap::new(),
281 total_docs: 0,
282 idf_cache: HashMap::new(),
283 }
284 }
285
286 pub fn index(&mut self, memory: &AutoMemory) {
288 self.clear();
289 self.total_docs = memory.entries.len();
290
291 for entry in &memory.entries {
292 let words = self.tokenize(&entry.content);
293 let word_freq = self.compute_word_freq(&words);
294 self.doc_word_freq.insert(entry.content.clone(), word_freq);
295 }
296
297 self.compute_idf();
298 }
299
300 fn tokenize(&self, text: &str) -> Vec<String> {
302 let lower = text.to_lowercase();
303 let mut tokens = Vec::new();
304
305 for word in lower.split_whitespace() {
306 let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
307 if trimmed.len() > 1 {
308 tokens.push(trimmed.to_string());
309 }
310
311 let chars: Vec<char> = trimmed.chars().collect();
312 let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
313
314 if has_cjk {
315 for c in &chars {
316 if Self::is_cjk(*c) {
317 tokens.push(c.to_string());
318 }
319 }
320 for window in chars.windows(2) {
321 if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
322 tokens.push(window.iter().collect::<String>());
323 }
324 }
325 }
326 }
327
328 tokens
329 }
330
331 fn is_cjk(c: char) -> bool {
333 matches!(c,
334 '\u{4E00}'..='\u{9FFF}' |
335 '\u{3400}'..='\u{4DBF}' |
336 '\u{F900}'..='\u{FAFF}' |
337 '\u{3000}'..='\u{303F}' |
338 '\u{3040}'..='\u{309F}' |
339 '\u{30A0}'..='\u{30FF}'
340 )
341 }
342
343 fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
345 let total = words.len() as f32;
346 let mut freq = HashMap::new();
347
348 for word in words {
349 *freq.entry(word.clone()).or_insert(0.0) += 1.0;
350 }
351
352 for (_, count) in freq.iter_mut() {
353 *count /= total;
354 }
355
356 freq
357 }
358
359 fn compute_idf(&mut self) {
361 let mut word_doc_count: HashMap<String, usize> = HashMap::new();
362
363 for word_freq in &self.doc_word_freq {
364 for word in word_freq.1.keys() {
365 *word_doc_count.entry(word.clone()).or_insert(0) += 1;
366 }
367 }
368
369 for (word, count) in word_doc_count {
370 let idf = (self.total_docs as f32 / count as f32).ln();
371 self.idf_cache.insert(word, idf);
372 }
373 }
374
375 pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
377 let query_words = self.tokenize(query);
378 let query_freq = self.compute_word_freq(&query_words);
379
380 let mut results: Vec<(String, f32)> = Vec::new();
381
382 for (doc, doc_freq) in &self.doc_word_freq {
383 let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
384
385 if similarity > 0.0 {
386 results.push((doc.clone(), similarity));
387 }
388 }
389
390 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
391
392 if let Some(max) = limit {
393 results.into_iter().take(max).collect()
394 } else {
395 results
396 }
397 }
398
399 pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
401 let mut doc_scores: HashMap<String, f64> = HashMap::new();
402
403 for keyword in keywords {
404 let results = self.search(keyword, None);
405 for (doc, score) in results {
406 *doc_scores.entry(doc).or_insert(0.0) += score as f64;
407 }
408 }
409
410 let num_keywords = keywords.len().max(1);
411 for (_, score) in doc_scores.iter_mut() {
412 *score /= num_keywords as f64;
413 }
414
415 let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
416 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
417
418 if let Some(max) = limit {
419 results.into_iter().take(max).collect()
420 } else {
421 results
422 }
423 }
424
425 fn compute_tfidf_similarity(
427 &self,
428 query_freq: &HashMap<String, f32>,
429 doc_freq: &HashMap<String, f32>,
430 ) -> f32 {
431 let mut similarity = 0.0;
432
433 for (word, tf_query) in query_freq {
434 if let Some(tf_doc) = doc_freq.get(word)
435 && let Some(idf) = self.idf_cache.get(word)
436 {
437 similarity += tf_query * idf * tf_doc * idf;
438 }
439 }
440
441 similarity
442 }
443
444 pub fn clear(&mut self) {
446 self.doc_word_freq.clear();
447 self.idf_cache.clear();
448 self.total_docs = 0;
449 }
450}
451
/// `Default` yields the same empty index as [`TfIdfSearch::new`].
impl Default for TfIdfSearch {
    /// Delegates to [`TfIdfSearch::new`].
    fn default() -> Self {
        Self::new()
    }
}
457
/// System prompt (written in Chinese) for the memory-selection request.
///
/// It tells the model that it will receive a user query plus a list of
/// available memories, and asks it to return the indices of the most useful
/// ones (at most five, possibly none) as JSON in the form
/// `{"selected": [0, 2, 5]}`.
const SELECT_MEMORIES_SYSTEM_PROMPT: &str = r#"你正在选择对处理用户查询有用的记忆。你会收到用户的查询和可用记忆文件列表(包含描述)。

返回最有用的记忆索引列表(最多5个),以 JSON 数组格式返回。
- 只选择你确定会有帮助的记忆
- 如果不确定某个记忆是否有用,不要选择它
- 如果没有明显有用的记忆,可以返回空数组 []
- 优先选择与当前问题直接相关的记忆

返回格式示例:{"selected": [0, 2, 5]}
"#;
473
474pub async fn ai_select_memories(
479 query: &str,
480 memory_manifest: &str,
481 provider: &dyn crate::providers::Provider,
482) -> Vec<usize> {
483 use crate::providers::{ChatRequest, Message, MessageContent, Role};
484
485 let truncated_query = if query.len() > 1000 {
487 &query[..1000]
488 } else {
489 query
490 };
491
492 let user_prompt = format!(
493 "查询: {}\n\n可用记忆列表:\n{}\n\n请选择最有用的记忆索引(最多5个):",
494 truncated_query, memory_manifest
495 );
496
497 let request = ChatRequest {
498 messages: vec![Message {
499 role: Role::User,
500 content: MessageContent::Text(user_prompt),
501 }],
502 tools: vec![],
503 system: Some(SELECT_MEMORIES_SYSTEM_PROMPT.to_string()),
504 think: false,
505 max_tokens: 100,
506 server_tools: vec![],
507 enable_caching: false,
508 };
509
510 let response = match provider.chat(request).await {
511 Ok(r) => r,
512 Err(_) => return Vec::new(),
513 };
514
515 let text = response
517 .content
518 .iter()
519 .filter_map(|block| {
520 if let crate::providers::ContentBlock::Text { text } = block {
521 Some(text.clone())
522 } else {
523 None
524 }
525 })
526 .collect::<Vec<_>>()
527 .join("");
528
529 parse_selected_indices(&text)
531}
532
533fn parse_selected_indices(text: &str) -> Vec<usize> {
535 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
537 if let Some(selected) = json.get("selected").and_then(|s| s.as_array()) {
538 return selected
539 .iter()
540 .filter_map(|v| v.as_u64().map(|n| n as usize))
541 .collect();
542 }
543 if let Some(arr) = json.as_array() {
545 return arr
546 .iter()
547 .filter_map(|v| v.as_u64().map(|n| n as usize))
548 .collect();
549 }
550 }
551
552 let mut indices = Vec::new();
554 for part in text.split(',') {
555 let trimmed = part.trim();
556 if let Ok(n) = trimmed.parse::<usize>() {
557 indices.push(n);
558 }
559 }
560 indices
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryCategory;

    // Mixed Chinese/English input must yield at least one keyword.
    #[test]
    fn test_extract_keywords() {
        let keywords = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!keywords.is_empty());
    }

    // A Chinese alias must expand to its English target.
    #[test]
    fn test_semantic_aliases() {
        let keywords = vec!["数据库".to_string()];
        let expanded = expand_semantic_keywords(&keywords);
        assert!(expanded.contains(&"database".to_string()));
    }

    // Indexes three entries and checks that a Chinese query ranks the only
    // entry containing those characters first.
    #[test]
    fn test_tfidf_search() {
        let mut tfidf = TfIdfSearch::new();
        let mut memory = AutoMemory::new();

        memory.add(MemoryEntry::new(
            MemoryCategory::Decision,
            "使用 PostgreSQL 作为数据库".to_string(),
            None,
            None,
        ));
        memory.add(MemoryEntry::new(
            MemoryCategory::Decision,
            "前端使用 React 框架开发".to_string(),
            None,
            None,
        ));
        memory.add(MemoryEntry::new(
            MemoryCategory::Decision,
            "后端采用 Rust 编写".to_string(),
            None,
            None,
        ));

        tfidf.index(&memory);
        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());

        // The top hit must be the PostgreSQL entry.
        assert!(results[0].0.contains("PostgreSQL"));
    }
}