1use std::collections::{HashMap, HashSet};
4
5use super::keywords_config::KeywordsConfig;
6use super::types::{AutoMemory, MemoryEntry};
7
8pub fn extract_context_keywords(context: &str) -> Vec<String> {
14 let config = KeywordsConfig::load();
15 let stop_words = config.get_stop_words_set();
16
17 let lower = context.to_lowercase();
18 let mut keywords: HashSet<String> = HashSet::new();
19
20 for word in lower.split_whitespace() {
22 let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
23 if cleaned.len() >= 3 && !stop_words.contains(cleaned.as_str()) {
24 keywords.insert(cleaned);
25 }
26 }
27
28 for category_patterns in config.patterns.values() {
30 for pattern in category_patterns {
31 if lower.contains(&pattern.to_lowercase()) {
32 keywords.insert(pattern.clone());
33 }
34 }
35 }
36
37 let tech_regexes = [
39 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}", r"[A-Z][a-z]+[A-Z][a-zA-Z]*", r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*", r"[0-9]+[kKmMgGtT][bB]?", ];
44
45 for pattern in tech_regexes {
46 if let Ok(re) = regex::Regex::new(pattern) {
47 for cap in re.find_iter(&lower) {
48 let match_str = cap.as_str();
49 if !stop_words.contains(match_str) {
50 keywords.insert(match_str.to_string());
51 }
52 }
53 }
54 }
55
56 let mut result: Vec<String> = keywords.into_iter().collect();
58 result.sort_by_key(|b| std::cmp::Reverse(b.len()));
59 result.truncate(10);
60 result
61}
62
/// Lowercase prefixes (Chinese and English) that mark a message as a greeting
/// or a generic conversation opener.
const GREETING_PATTERNS: &[&str] = &[
    "你好", "您好", "hi", "hello", "hey", "嗨", "早上好", "下午好", "晚上好",
    "good morning", "good afternoon", "good evening",
    "请问", "帮忙", "帮我", "帮我看", "看看", "help", "请",
    "开始", "start", "准备好了", "ready",
];

/// Returns `true` when `msg` is too simple to be worth further processing.
///
/// A message is skipped when, after trimming whitespace, it is shorter than
/// 15 bytes, or when its lowercased form begins with one of
/// [`GREETING_PATTERNS`].
///
/// Patterns that end in an ASCII letter or digit only match at a word
/// boundary: `"hi there, ..."` is a greeting, but `"hibernate mapping ..."`
/// is not. Patterns ending in a non-ASCII character (the Chinese ones) match
/// as plain prefixes because Chinese text has no word separators.
pub fn should_skip_simple_message(msg: &str) -> bool {
    let trimmed = msg.trim();

    // Byte length, not character count: a CJK character occupies three bytes
    // in UTF-8, so the threshold is roughly five CJK characters.
    if trimmed.len() < 15 {
        return true;
    }

    let lower = trimmed.to_lowercase();
    GREETING_PATTERNS
        .iter()
        .any(|pattern| match lower.strip_prefix(*pattern) {
            None => false,
            Some(rest) => {
                let pattern_ends_word = pattern
                    .chars()
                    .next_back()
                    .map_or(false, |c| c.is_ascii_alphanumeric());
                let text_continues_word = rest
                    .chars()
                    .next()
                    .map_or(false, |c| c.is_ascii_alphanumeric());
                // Reject matches that merely share leading letters with a longer word.
                !(pattern_ends_word && text_continues_word)
            }
        })
}
95
/// Returns a similarity score for the strings `a` and `b`.
///
/// Thin wrapper: both arguments are forwarded unchanged to
/// `AutoMemory::calculate_similarity` and its result is returned as-is. The
/// scoring algorithm and the value range are defined by that method and are
/// not visible here.
pub fn calculate_similarity(a: &str, b: &str) -> f64 {
    AutoMemory::calculate_similarity(a, b)
}
100
/// Returns the semantic alias table as pairs of static string slices.
///
/// Thin wrapper around `KeywordsConfig::get_aliases`; the pairs are returned
/// exactly as that function provides them.
pub fn get_semantic_aliases() -> Vec<(&'static str, &'static str)> {
    KeywordsConfig::get_aliases()
}
109
110pub fn expand_semantic_keywords(keywords: &[String]) -> Vec<String> {
112 let aliases = KeywordsConfig::get_aliases();
113 let mut expanded: Vec<String> = keywords.to_vec();
114
115 for keyword in keywords {
116 let kw_lower = keyword.to_lowercase();
117 for &(alias, target) in &aliases {
118 if kw_lower.contains(alias) {
119 expanded.push(target.to_string());
120 }
121 if kw_lower.contains(target) {
122 expanded.push(alias.to_string());
123 }
124 }
125 }
126
127 expanded.sort();
128 expanded.dedup();
129 expanded
130}
131
132pub fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
139 if context_keywords.is_empty() {
140 return 0.0;
141 }
142
143 let expanded_keywords = expand_semantic_keywords(context_keywords);
144 let content_lower = entry.content.to_lowercase();
145
146 let matches = expanded_keywords
147 .iter()
148 .filter(|kw| content_lower.contains(&kw.to_lowercase()))
149 .count();
150
151 let keyword_score = matches as f64 / expanded_keywords.len().max(context_keywords.len()) as f64;
152
153 let tag_matches = entry
154 .tags
155 .iter()
156 .filter(|tag| {
157 let tag_lower = tag.to_lowercase();
158 expanded_keywords.iter().any(|kw| {
159 tag_lower.contains(&kw.to_lowercase()) || kw.to_lowercase().contains(&tag_lower)
160 })
161 })
162 .count();
163
164 let tag_score = if tag_matches > 0 {
165 0.2 + (tag_matches as f64 * 0.05).min(0.1)
166 } else {
167 0.0
168 };
169
170 (keyword_score + tag_score).min(1.0)
171}
172
173pub fn has_contradiction_signal(old: &str, new: &str) -> bool {
176 let config = KeywordsConfig::load();
177
178 for signal in &config.contradiction_signals {
180 if new.contains(signal) {
181 return true;
182 }
183 }
184
185 let action_verbs = [
187 "决定使用",
188 "选择使用",
189 "采用",
190 "使用",
191 "decided to use",
192 "chose",
193 "using",
194 "adopted",
195 ];
196
197 for verb in &action_verbs {
198 if old.contains(verb) && new.contains(verb) {
199 return true;
200 }
201 }
202
203 let pref_verbs = ["偏好", "喜欢", "prefer", "like"];
205 for verb in &pref_verbs {
206 if old.contains(verb) && new.contains(verb) {
207 return true;
208 }
209 }
210
211 false
212}
213
/// Extracts keywords from `context`.
///
/// Currently this delegates straight to `extract_context_keywords`: the
/// `_fast_provider` argument is accepted but never used, and nothing is
/// awaited inside the function body.
pub async fn extract_keywords_hybrid(
    context: &str,
    _fast_provider: Option<&dyn crate::providers::Provider>,
) -> Vec<String> {
    extract_context_keywords(context)
}
228
229pub struct TfIdfSearch {
238 doc_word_freq: HashMap<String, HashMap<String, f32>>,
240 total_docs: usize,
242 idf_cache: HashMap<String, f32>,
244}
245
246impl TfIdfSearch {
247 pub fn new() -> Self {
249 Self {
250 doc_word_freq: HashMap::new(),
251 total_docs: 0,
252 idf_cache: HashMap::new(),
253 }
254 }
255
256 pub fn index(&mut self, memory: &AutoMemory) {
258 self.clear();
259 self.total_docs = memory.entries.len();
260
261 for entry in &memory.entries {
262 let words = self.tokenize(&entry.content);
263 let word_freq = self.compute_word_freq(&words);
264 self.doc_word_freq.insert(entry.content.clone(), word_freq);
265 }
266
267 self.compute_idf();
268 }
269
270 fn tokenize(&self, text: &str) -> Vec<String> {
272 let lower = text.to_lowercase();
273 let mut tokens = Vec::new();
274
275 for word in lower.split_whitespace() {
276 let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
277 if trimmed.len() > 1 {
278 tokens.push(trimmed.to_string());
279 }
280
281 let chars: Vec<char> = trimmed.chars().collect();
282 let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
283
284 if has_cjk {
285 for c in &chars {
286 if Self::is_cjk(*c) {
287 tokens.push(c.to_string());
288 }
289 }
290 for window in chars.windows(2) {
291 if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
292 tokens.push(window.iter().collect::<String>());
293 }
294 }
295 }
296 }
297
298 tokens
299 }
300
301 fn is_cjk(c: char) -> bool {
303 matches!(c,
304 '\u{4E00}'..='\u{9FFF}' |
305 '\u{3400}'..='\u{4DBF}' |
306 '\u{F900}'..='\u{FAFF}' |
307 '\u{3000}'..='\u{303F}' |
308 '\u{3040}'..='\u{309F}' |
309 '\u{30A0}'..='\u{30FF}'
310 )
311 }
312
313 fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
315 let total = words.len() as f32;
316 let mut freq = HashMap::new();
317
318 for word in words {
319 *freq.entry(word.clone()).or_insert(0.0) += 1.0;
320 }
321
322 for (_, count) in freq.iter_mut() {
323 *count /= total;
324 }
325
326 freq
327 }
328
329 fn compute_idf(&mut self) {
331 let mut word_doc_count: HashMap<String, usize> = HashMap::new();
332
333 for word_freq in &self.doc_word_freq {
334 for word in word_freq.1.keys() {
335 *word_doc_count.entry(word.clone()).or_insert(0) += 1;
336 }
337 }
338
339 for (word, count) in word_doc_count {
340 let idf = (self.total_docs as f32 / count as f32).ln();
341 self.idf_cache.insert(word, idf);
342 }
343 }
344
345 pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
347 let query_words = self.tokenize(query);
348 let query_freq = self.compute_word_freq(&query_words);
349
350 let mut results: Vec<(String, f32)> = Vec::new();
351
352 for (doc, doc_freq) in &self.doc_word_freq {
353 let similarity = self.compute_tfidf_similarity(&query_freq, doc_freq);
354
355 if similarity > 0.0 {
356 results.push((doc.clone(), similarity));
357 }
358 }
359
360 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
361
362 if let Some(max) = limit {
363 results.into_iter().take(max).collect()
364 } else {
365 results
366 }
367 }
368
369 pub fn search_multi(&self, keywords: &[&str], limit: Option<usize>) -> Vec<(String, f64)> {
371 let mut doc_scores: HashMap<String, f64> = HashMap::new();
372
373 for keyword in keywords {
374 let results = self.search(keyword, None);
375 for (doc, score) in results {
376 *doc_scores.entry(doc).or_insert(0.0) += score as f64;
377 }
378 }
379
380 let num_keywords = keywords.len().max(1);
381 for (_, score) in doc_scores.iter_mut() {
382 *score /= num_keywords as f64;
383 }
384
385 let mut results: Vec<(String, f64)> = doc_scores.into_iter().collect();
386 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
387
388 if let Some(max) = limit {
389 results.into_iter().take(max).collect()
390 } else {
391 results
392 }
393 }
394
395 fn compute_tfidf_similarity(
397 &self,
398 query_freq: &HashMap<String, f32>,
399 doc_freq: &HashMap<String, f32>,
400 ) -> f32 {
401 let mut similarity = 0.0;
402
403 for (word, tf_query) in query_freq {
404 if let Some(tf_doc) = doc_freq.get(word)
405 && let Some(idf) = self.idf_cache.get(word)
406 {
407 similarity += tf_query * idf * tf_doc * idf;
408 }
409 }
410
411 similarity
412 }
413
414 pub fn clear(&mut self) {
416 self.doc_word_freq.clear();
417 self.idf_cache.clear();
418 self.total_docs = 0;
419 }
420}
421
422impl Default for TfIdfSearch {
423 fn default() -> Self {
424 Self::new()
425 }
426}
427
/// System prompt (written in Chinese) for the memory-selection request.
///
/// It tells the model that it receives a user query plus a list of available
/// memories, and asks it to return the indices of the most useful ones (at
/// most five, possibly none) as JSON, for example `{"selected": [0, 2, 5]}`.
const SELECT_MEMORIES_SYSTEM_PROMPT: &str = r#"你正在选择对处理用户查询有用的记忆。你会收到用户的查询和可用记忆文件列表(包含描述)。

返回最有用的记忆索引列表(最多5个),以 JSON 数组格式返回。
- 只选择你确定会有帮助的记忆
- 如果不确定某个记忆是否有用,不要选择它
- 如果没有明显有用的记忆,可以返回空数组 []
- 优先选择与当前问题直接相关的记忆

返回格式示例:{"selected": [0, 2, 5]}
"#;
443
444pub async fn ai_select_memories(
449 query: &str,
450 memory_manifest: &str,
451 provider: &dyn crate::providers::Provider,
452) -> Vec<usize> {
453 use crate::providers::{ChatRequest, Message, MessageContent, Role};
454
455 let truncated_query = if query.len() > 1000 {
457 &query[..1000]
458 } else {
459 query
460 };
461
462 let user_prompt = format!(
463 "查询: {}\n\n可用记忆列表:\n{}\n\n请选择最有用的记忆索引(最多5个):",
464 truncated_query, memory_manifest
465 );
466
467 let request = ChatRequest {
468 messages: vec![Message {
469 role: Role::User,
470 content: MessageContent::Text(user_prompt),
471 }],
472 tools: vec![],
473 system: Some(SELECT_MEMORIES_SYSTEM_PROMPT.to_string()),
474 think: false,
475 max_tokens: 100,
476 server_tools: vec![],
477 enable_caching: false,
478 };
479
480 let response = match provider.chat(request).await {
481 Ok(r) => r,
482 Err(_) => return Vec::new(),
483 };
484
485 let text = response
487 .content
488 .iter()
489 .filter_map(|block| {
490 if let crate::providers::ContentBlock::Text { text } = block {
491 Some(text.clone())
492 } else {
493 None
494 }
495 })
496 .collect::<Vec<_>>()
497 .join("");
498
499 parse_selected_indices(&text)
501}
502
503fn parse_selected_indices(text: &str) -> Vec<usize> {
505 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
507 if let Some(selected) = json.get("selected").and_then(|s| s.as_array()) {
508 return selected
509 .iter()
510 .filter_map(|v| v.as_u64().map(|n| n as usize))
511 .collect();
512 }
513 if let Some(arr) = json.as_array() {
515 return arr
516 .iter()
517 .filter_map(|v| v.as_u64().map(|n| n as usize))
518 .collect();
519 }
520 }
521
522 let mut indices = Vec::new();
524 for part in text.split(',') {
525 let trimmed = part.trim();
526 if let Ok(n) = trimmed.parse::<usize>() {
527 indices.push(n);
528 }
529 }
530 indices
531}
532
#[cfg(test)]
mod tests {
    use super::super::types::MemoryCategory;
    use super::*;

    /// Keyword extraction yields at least one keyword for mixed Chinese/English text.
    #[test]
    fn test_extract_keywords() {
        let extracted = extract_context_keywords("使用 PostgreSQL 数据库配置");
        assert!(!extracted.is_empty());
    }

    /// Expanding the Chinese word for "database" adds its English alias.
    #[test]
    fn test_semantic_aliases() {
        let expanded = expand_semantic_keywords(&["数据库".to_string()]);
        assert!(expanded.iter().any(|kw| kw == "database"));
    }

    /// A TF-IDF query ranks the only entry mentioning the query term first.
    #[test]
    fn test_tfidf_search() {
        let mut memory = AutoMemory::new();
        for content in [
            "使用 PostgreSQL 作为数据库",
            "前端使用 React 框架开发",
            "后端采用 Rust 编写",
        ] {
            memory.add(MemoryEntry::new(
                MemoryCategory::Decision,
                content.to_string(),
                None,
                None,
            ));
        }

        let mut tfidf = TfIdfSearch::new();
        tfidf.index(&memory);

        let results = tfidf.search("数据库", Some(5));
        assert!(!results.is_empty());
        assert!(results[0].0.contains("PostgreSQL"));
    }
}