1use anyhow::Result;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12use std::fs;
13
14use crate::providers::Message;
15
/// Shortens `s` for display, appending `...` when it is longer than `max_len` bytes.
///
/// Lengths are measured in bytes (like `str::len`). The cut point is moved back to the
/// nearest UTF-8 character boundary, so multi-byte text (e.g. CJK memory content) can
/// never cause a slicing panic.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // Leave room for the three-byte "..." suffix, then snap down to a char boundary.
    // `is_char_boundary(0)` is always true, so this loop terminates.
    let mut end = max_len.saturating_sub(3);
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
24
/// Returns at most the first `max_len` bytes of `s`, without any suffix.
///
/// If `max_len` falls inside a multi-byte character, the cut is moved back to the
/// previous character boundary instead of panicking.
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    s[..end].to_string()
}
32
// ---- Importance --------------------------------------------------------------

/// Upper bound for an entry's importance; reference bumps are capped here and
/// contextual scoring divides by it to normalise importance.
pub const MAX_IMPORTANCE_CEILING: f64 = 100.0;

/// Entries at or above this importance get a ⭐ marker in `format_line`.
pub const IMPORTANCE_STAR_THRESHOLD: f64 = 80.0;

// ---- Duplicate / conflict detection -----------------------------------------

/// Content shorter than this many bytes (lowercased) is never fuzzy-matched as a duplicate.
pub const MIN_SIMILARITY_LENGTH: usize = 10;

/// Word-set Jaccard similarity at or above which two memories count as duplicates.
pub const SIMILARITY_THRESHOLD: f64 = 0.85;

/// Minimum word-overlap ratio (relative to the smaller word set) for a conflict.
/// NOTE(review): "OVERLAY" looks like a typo for "OVERLAP"; name kept for compatibility.
pub const CONFLICT_OVERLAY_THRESHOLD: f64 = 0.5;

/// Lower overlap ratio used when the new text itself carries a change signal.
pub const CONFLICT_OVERLAY_THRESHOLD_WITH_SIGNAL: f64 = 0.3;

// ---- Content sizes -----------------------------------------------------------

/// Presumably the minimum length of auto-detected memory content (not read in this chunk).
pub const MIN_MEMORY_CONTENT_LENGTH: usize = 20;

/// Presumably the cap on entries detected per pass (not read in this chunk).
pub const MAX_DETECTED_ENTRIES: usize = 5;

/// Byte limit for an entry's content when rendered into a prompt.
pub const MAX_MEMORY_CONTENT_LENGTH: usize = 200;

/// Byte limit for an entry's content in the one-line listing.
pub const MAX_DISPLAY_LENGTH: usize = 60;

// ---- Contextual ranking --------------------------------------------------------

/// Weight of keyword relevance in the contextual ranking score.
pub const CONTEXT_RELEVANCE_WEIGHT: f64 = 0.6;

/// Weight of normalised importance in the contextual ranking score.
pub const CONTEXT_IMPORTANCE_WEIGHT: f64 = 0.4;

// ---- AI assistance -------------------------------------------------------------

/// Model identifier presumably used for AI memory extraction (not read in this chunk).
pub const DEFAULT_MEMORY_EXTRACTOR_MODEL: &str = "claude-3-5-haiku-20241022";

/// In `AiKeywordMode::Auto`, AI keywords are used when fewer than this many were found locally.
pub const MIN_KEYWORDS_FOR_AI_FALLBACK: usize = 2;
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
84pub enum AiKeywordMode {
85 #[default]
87 Auto,
88 Always,
90 Never,
92}
93
94impl AiKeywordMode {
95 pub fn from_env() -> Self {
97 match std::env::var("MEMORY_AI_KEYWORDS")
98 .unwrap_or_default()
99 .to_lowercase()
100 .as_str()
101 {
102 "always" | "true" | "1" => AiKeywordMode::Always,
103 "never" | "false" | "0" => AiKeywordMode::Never,
104 "auto" | "" => AiKeywordMode::Auto,
105 other => {
106 log::warn!("Unknown MEMORY_AI_KEYWORDS value: '{}', using 'auto'", other);
107 AiKeywordMode::Auto
108 }
109 }
110 }
111
112 pub fn should_use_ai(&self, keyword_count: usize) -> bool {
114 match self {
115 AiKeywordMode::Always => true,
116 AiKeywordMode::Never => false,
117 AiKeywordMode::Auto => keyword_count < MIN_KEYWORDS_FOR_AI_FALLBACK,
118 }
119 }
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
125pub enum AiDetectionMode {
126 #[default]
128 Auto,
129 Always,
131 Never,
133}
134
135impl AiDetectionMode {
136 pub fn from_env() -> Self {
138 match std::env::var("MEMORY_AI_DETECTION")
139 .unwrap_or_default()
140 .to_lowercase()
141 .as_str()
142 {
143 "always" | "true" | "1" => AiDetectionMode::Always,
144 "never" | "false" | "0" => AiDetectionMode::Never,
145 "auto" | "" => AiDetectionMode::Auto,
146 other => {
147 log::warn!("Unknown MEMORY_AI_DETECTION value: '{}', using 'auto'", other);
148 AiDetectionMode::Auto
149 }
150 }
151 }
152
153 pub fn should_use_ai(&self) -> bool {
155 match self {
156 AiDetectionMode::Always => true,
157 AiDetectionMode::Never => false,
158 AiDetectionMode::Auto => {
159 false }
163 }
164 }
165
166 pub fn should_use_ai_for_text(&self, text_len: usize) -> bool {
169 match self {
170 AiDetectionMode::Always => true,
171 AiDetectionMode::Never => false,
172 AiDetectionMode::Auto => text_len > 500, }
174 }
175}
176
177pub const DEFAULT_FAST_MODEL: &str = "claude-3-5-haiku-20241022";
179
180pub const DEFAULT_IMPORTANCE_DECISION: f64 = 75.0; pub const DEFAULT_IMPORTANCE_SOLUTION: f64 = 70.0; pub const DEFAULT_IMPORTANCE_PREF: f64 = 65.0; pub const DEFAULT_IMPORTANCE_FINDING: f64 = 55.0; pub const DEFAULT_IMPORTANCE_TECH: f64 = 45.0; pub const DEFAULT_IMPORTANCE_STRUCTURE: f64 = 35.0; #[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct MemoryConfig {
196 pub max_entries: usize,
198 pub min_importance: f64,
200 pub enabled: bool,
202 pub decay_start_days: i64,
204 pub decay_rate: f64,
206 pub reference_increment: f64,
208 pub max_importance_ceiling: f64,
210}
211
212impl Default for MemoryConfig {
213 fn default() -> Self {
214 Self {
215 max_entries: 100,
216 min_importance: 30.0,
217 enabled: true,
218 decay_start_days: 30,
219 decay_rate: 0.5,
220 reference_increment: 1.0, max_importance_ceiling: MAX_IMPORTANCE_CEILING,
222 }
223 }
224}
225
226impl MemoryConfig {
227 pub fn with_max_entries(max: usize) -> Self {
229 Self {
230 max_entries: max,
231 ..Self::default()
232 }
233 }
234
235 pub fn minimal() -> Self {
237 Self {
238 max_entries: 50,
239 min_importance: 50.0,
240 enabled: true,
241 decay_start_days: 14,
242 decay_rate: 0.6,
243 reference_increment: 1.0,
244 max_importance_ceiling: MAX_IMPORTANCE_CEILING,
245 }
246 }
247
248 pub fn archival() -> Self {
250 Self {
251 max_entries: 500,
252 min_importance: 20.0,
253 enabled: true,
254 decay_start_days: 90,
255 decay_rate: 0.3,
256 reference_increment: 3.0,
257 max_importance_ceiling: MAX_IMPORTANCE_CEILING,
258 }
259 }
260}
261
262#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
268#[serde(rename_all = "snake_case")]
269pub enum MemoryCategory {
270 Preference,
272 Decision,
274 Finding,
276 Solution,
278 Technical,
280 Structure,
282}
283
284impl MemoryCategory {
285 pub fn display_name(&self) -> &'static str {
287 match self {
288 MemoryCategory::Preference => "偏好",
289 MemoryCategory::Decision => "决策",
290 MemoryCategory::Finding => "发现",
291 MemoryCategory::Solution => "解决方案",
292 MemoryCategory::Technical => "技术",
293 MemoryCategory::Structure => "结构",
294 }
295 }
296
297 pub fn icon(&self) -> &'static str {
299 match self {
300 MemoryCategory::Preference => "👤",
301 MemoryCategory::Decision => "🎯",
302 MemoryCategory::Finding => "💡",
303 MemoryCategory::Solution => "🔧",
304 MemoryCategory::Technical => "📚",
305 MemoryCategory::Structure => "🏗️",
306 }
307 }
308
309 pub fn default_importance(&self) -> f64 {
311 match self {
312 MemoryCategory::Decision => DEFAULT_IMPORTANCE_DECISION,
313 MemoryCategory::Solution => DEFAULT_IMPORTANCE_SOLUTION,
314 MemoryCategory::Preference => DEFAULT_IMPORTANCE_PREF,
315 MemoryCategory::Finding => DEFAULT_IMPORTANCE_FINDING,
316 MemoryCategory::Technical => DEFAULT_IMPORTANCE_TECH,
317 MemoryCategory::Structure => DEFAULT_IMPORTANCE_STRUCTURE,
318 }
319 }
320}
321
/// A single remembered item, serialized as part of an `AutoMemory` JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier; `MemoryEntry::new` assigns a random UUID v4 string.
    pub id: String,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Time of the most recent reference; time decay is measured from here.
    pub last_referenced: DateTime<Utc>,
    /// Category; determines the default importance and the display label/icon.
    pub category: MemoryCategory,
    /// The remembered text.
    pub content: String,
    /// Session the entry was detected in, if any (`None` for manual entries).
    pub source_session: Option<String>,
    /// How many times the entry has been marked as referenced.
    pub reference_count: u32,
    /// Ranking score; reference bumps are capped at `MAX_IMPORTANCE_CEILING`.
    pub importance: f64,
    /// Free-form tags, matched case-insensitively by `AutoMemory::search_with_limit`.
    pub tags: Vec<String>,
    /// `true` for user-added entries; they skip time decay and are kept first by `prune`.
    pub is_manual: bool,
}
350
351impl MemoryEntry {
352 pub fn new(category: MemoryCategory, content: String, source_session: Option<String>) -> Self {
354 let id = uuid::Uuid::new_v4().to_string();
355 Self {
356 id,
357 created_at: Utc::now(),
358 last_referenced: Utc::now(),
359 category,
360 content,
361 source_session,
362 reference_count: 0,
363 importance: category.default_importance(),
364 tags: Vec::new(),
365 is_manual: false,
366 }
367 }
368
369 pub fn manual(category: MemoryCategory, content: String) -> Self {
371 let mut entry = Self::new(category, content, None);
372 entry.is_manual = true;
373 entry.importance = 95.0; entry
375 }
376
377 pub fn mark_referenced(&mut self) {
379 self.mark_referenced_with_increment(2.0);
380 }
381
382 pub fn mark_referenced_with_increment(&mut self, increment: f64) {
384 self.reference_count += 1;
385 self.last_referenced = Utc::now();
386 self.importance = (self.importance + increment).min(MAX_IMPORTANCE_CEILING);
388 }
389
390 pub fn format_line(&self) -> String {
392 let time = self.created_at.format("%Y-%m-%d %H:%M");
393 let importance_marker = if self.importance >= IMPORTANCE_STAR_THRESHOLD { "⭐" } else { "" };
394 let manual_marker = if self.is_manual { "📝" } else { "" };
395 format!(
396 "{} {} {}{}{} {}",
397 self.category.icon(),
398 time,
399 importance_marker,
400 manual_marker,
401 self.category.display_name(),
402 truncate_str(&self.content, MAX_DISPLAY_LENGTH)
403 )
404 }
405
406 pub fn format_for_prompt(&self) -> String {
408 let category_name = self.category.display_name();
409 if self.content.len() > MAX_MEMORY_CONTENT_LENGTH {
410 format!("{}: {}...", category_name, truncate(&self.content, MAX_MEMORY_CONTENT_LENGTH - 3))
411 } else {
412 format!("{}: {}", category_name, self.content)
413 }
414 }
415}
416
/// A collection of memory entries plus tuning knobs, persisted as JSON.
///
/// `max_entries`, `min_importance` and `enabled` mirror fields of `config`: the
/// constructors copy them from `config`, while `prune` reads the top-level copies and
/// `apply_time_decay` / `update_references` read `config`. NOTE(review): the two can
/// drift after deserialization. Confirm which one is authoritative.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMemory {
    /// All stored entries.
    pub entries: Vec<MemoryEntry>,
    /// Full configuration; `MemoryConfig::default()` when absent from the JSON.
    #[serde(default)]
    pub config: MemoryConfig,
    /// Capacity enforced by `prune` (100 when absent from the JSON).
    #[serde(default = "default_max_entries")]
    pub max_entries: usize,
    /// Auto entries below this importance may be dropped by `prune` (30.0 when absent).
    #[serde(default = "default_min_importance")]
    pub min_importance: f64,
    /// Enable flag (`true` when absent). NOTE(review): not read by the code in this chunk.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Lazily built lookup index; never serialized and rebuilt on demand.
    #[serde(skip)]
    search_index: Option<SearchIndex>,
}
440
441#[derive(Debug, Clone)]
443struct SearchIndex {
444 content_lower: Vec<String>,
446 by_category: HashMap<MemoryCategory, Vec<usize>>,
448 by_importance: Vec<usize>,
450 #[allow(dead_code)]
452 word_freq: HashMap<String, usize>,
453}
454
455impl SearchIndex {
456 fn build(entries: &[MemoryEntry]) -> Self {
458 let content_lower: Vec<String> = entries
460 .iter()
461 .map(|e| e.content.to_lowercase())
462 .collect();
463
464 let mut by_category: HashMap<MemoryCategory, Vec<usize>> = HashMap::new();
466 for (i, entry) in entries.iter().enumerate() {
467 by_category.entry(entry.category).or_default().push(i);
468 }
469
470 let mut by_importance: Vec<usize> = (0..entries.len()).collect();
472 by_importance.sort_by(|a, b| {
473 entries[*b].importance.partial_cmp(&entries[*a].importance)
474 .unwrap_or(std::cmp::Ordering::Equal)
475 });
476
477 let mut word_freq: HashMap<String, usize> = HashMap::new();
479 for content in &content_lower {
480 for word in content.split_whitespace() {
481 *word_freq.entry(word.to_string()).or_default() += 1;
482 }
483 }
484
485 Self {
486 content_lower,
487 by_category,
488 by_importance,
489 word_freq,
490 }
491 }
492
493 #[allow(dead_code)]
495 fn get_lower(&self, idx: usize) -> &str {
496 &self.content_lower[idx]
497 }
498
499 fn search(&self, _entries: &[MemoryEntry], query_lower: &str, limit: Option<usize>) -> Vec<usize> {
501 let matches: Vec<usize> = self.by_importance
503 .iter()
504 .filter(|&idx| self.content_lower[*idx].contains(query_lower))
505 .copied()
506 .collect();
507
508 if let Some(max) = limit {
509 matches.into_iter().take(max).collect()
510 } else {
511 matches
512 }
513 }
514
515 fn search_multi(&self, keywords_lower: &[String]) -> Vec<usize> {
517 self.by_importance
518 .iter()
519 .filter(|&idx| {
520 let content = &self.content_lower[*idx];
521 keywords_lower.iter().any(|k| content.contains(k))
522 })
523 .copied()
524 .collect()
525 }
526
527 #[allow(dead_code)]
529 fn rebuild(&mut self, entries: &[MemoryEntry]) {
530 *self = Self::build(entries);
531 }
532}
533
/// Serde fallback for `AutoMemory::max_entries` (same value as `MemoryConfig::default()`).
fn default_max_entries() -> usize {
    100
}

/// Serde fallback for `AutoMemory::min_importance` (same value as `MemoryConfig::default()`).
fn default_min_importance() -> f64 {
    30.0
}

/// Serde fallback for `AutoMemory::enabled`.
fn default_enabled() -> bool {
    true
}
537
538impl Default for AutoMemory {
539 fn default() -> Self {
540 let config = MemoryConfig::default();
541 Self {
542 entries: Vec::new(),
543 config: config.clone(),
544 max_entries: config.max_entries,
545 min_importance: config.min_importance,
546 enabled: config.enabled,
547 search_index: None,
548 }
549 }
550}
551
552impl AutoMemory {
553 pub fn new() -> Self {
555 Self::default()
556 }
557
558 fn ensure_index(&mut self) {
560 if self.search_index.is_none() {
561 self.rebuild_index();
562 }
563 }
564
565 pub fn rebuild_index(&mut self) {
567 self.search_index = Some(SearchIndex::build(&self.entries));
568 }
569
570 fn invalidate_index(&mut self) {
572 self.search_index = None;
573 }
574
575 pub fn with_config(config: MemoryConfig) -> Self {
577 Self {
578 entries: Vec::new(),
579 config: config.clone(),
580 max_entries: config.max_entries,
581 min_importance: config.min_importance,
582 enabled: config.enabled,
583 search_index: None,
584 }
585 }
586
587 pub fn minimal() -> Self {
589 Self::with_config(MemoryConfig::minimal())
590 }
591
592 pub fn archival() -> Self {
594 Self::with_config(MemoryConfig::archival())
595 }
596
597 pub fn add(&mut self, entry: MemoryEntry) {
599 self.entries.push(entry);
600 self.invalidate_index(); self.prune();
602 }
603
604 pub fn add_memory(
606 &mut self,
607 category: MemoryCategory,
608 content: String,
609 source_session: Option<String>,
610 ) {
611 if self.has_similar(&content) {
613 return;
614 }
615
616 if let Some(conflict_idx) = self.find_conflict(&content, category) {
618 let old_content = self.entries[conflict_idx].content.clone();
620 log::debug!("Memory conflict detected: '{}' supersedes '{}'", content, old_content);
621 self.entries.remove(conflict_idx);
622 self.invalidate_index();
623 }
624
625 let entry = MemoryEntry::new(category, content, source_session);
626 self.add(entry);
627 }
628
629 fn find_conflict(&self, new_content: &str, category: MemoryCategory) -> Option<usize> {
640 let new_lower = new_content.to_lowercase();
641 let new_words: std::collections::HashSet<&str> = new_lower.split_whitespace().collect();
642
643 let has_change_signal = has_contradiction_signal("", &new_lower);
645 let overlap_threshold = if has_change_signal {
646 CONFLICT_OVERLAY_THRESHOLD_WITH_SIGNAL
647 } else {
648 CONFLICT_OVERLAY_THRESHOLD
649 };
650
651 for (i, entry) in self.entries.iter().enumerate() {
653 if entry.category != category {
654 continue;
655 }
656
657 let entry_lower = entry.content.to_lowercase();
658 let entry_words: std::collections::HashSet<&str> = entry_lower.split_whitespace().collect();
659
660 let intersection = new_words.intersection(&entry_words).count();
662 let min_len = new_words.len().min(entry_words.len());
663
664 if min_len == 0 {
665 continue;
666 }
667
668 let topic_overlap = intersection as f64 / min_len as f64;
669
670 let jaccard = Self::calculate_similarity(&entry_lower, &new_lower);
672
673 if topic_overlap > overlap_threshold && jaccard < SIMILARITY_THRESHOLD {
674 if has_contradiction_signal(&entry_lower, &new_lower) {
676 return Some(i);
677 }
678 }
679
680 if has_change_signal {
683 let old_key_terms: Vec<&str> = entry_words.iter()
685 .filter(|w| w.len() > 2)
686 .copied()
687 .collect();
688 let referenced = old_key_terms.iter()
689 .any(|term| new_lower.contains(term));
690 if referenced {
691 return Some(i);
692 }
693 }
694 }
695
696 None
697 }
698
699 pub fn has_similar(&self, content: &str) -> bool {
702 let content_lower = content.to_lowercase();
703
704 if content_lower.len() < MIN_SIMILARITY_LENGTH {
706 return false;
707 }
708
709 self.entries.iter().any(|e| {
710 let entry_lower = e.content.to_lowercase();
711
712 if entry_lower == content_lower {
714 return true;
715 }
716
717 if entry_lower.len() < MIN_SIMILARITY_LENGTH {
719 return false;
720 }
721
722 let similarity = Self::calculate_similarity(&entry_lower, &content_lower);
724 similarity >= SIMILARITY_THRESHOLD
725 })
726 }
727
728fn calculate_similarity(a: &str, b: &str) -> f64 {
731 use std::collections::HashSet;
732
733 let a_words: HashSet<&str> = a.split_whitespace().collect();
734 let b_words: HashSet<&str> = b.split_whitespace().collect();
735
736 if a_words.is_empty() || b_words.is_empty() {
737 return 0.0;
738 }
739
740 let intersection = a_words.intersection(&b_words).count();
741 let union = a_words.union(&b_words).count();
742
743 if union == 0 {
744 0.0
745 } else {
746 intersection as f64 / union as f64
747 }
748 }
749
750 pub fn prune(&mut self) {
753 if self.entries.len() <= self.max_entries {
754 return;
755 }
756
757 let (manual_entries, auto_entries): (Vec<_>, Vec<_>) = self.entries
760 .iter()
761 .cloned()
762 .partition(|e| e.is_manual);
763
764 let mut sorted_auto = auto_entries;
766 sorted_auto.sort_by(|a, b| {
767 let importance_cmp = b.importance.partial_cmp(&a.importance)
769 .unwrap_or(std::cmp::Ordering::Equal);
770
771 if importance_cmp == std::cmp::Ordering::Equal {
773 b.last_referenced.cmp(&a.last_referenced)
774 } else {
775 importance_cmp
776 }
777 });
778
779 let kept_auto: Vec<_> = sorted_auto
781 .into_iter()
782 .filter(|e| e.importance >= self.min_importance)
783 .take(self.max_entries.saturating_sub(manual_entries.len()))
784 .collect();
785
786 self.entries = manual_entries.into_iter().chain(kept_auto).collect();
788
789 if self.entries.len() > self.max_entries {
791 self.entries.sort_by(|a, b| {
792 let importance_cmp = b.importance.partial_cmp(&a.importance)
793 .unwrap_or(std::cmp::Ordering::Equal);
794 if importance_cmp == std::cmp::Ordering::Equal {
795 b.last_referenced.cmp(&a.last_referenced)
796 } else {
797 importance_cmp
798 }
799 });
800 self.entries.truncate(self.max_entries);
801 }
802
803 self.invalidate_index(); }
805
806 pub fn by_category(&self, category: MemoryCategory) -> Vec<&MemoryEntry> {
808 self.entries.iter().filter(|e| e.category == category).collect()
809 }
810
811 pub fn by_category_fast(&mut self, category: MemoryCategory) -> Vec<&MemoryEntry> {
813 self.ensure_index();
814 if let Some(ref index) = self.search_index {
815 index.by_category.get(&category)
816 .map(|indices| indices.iter().map(|&i| &self.entries[i]).collect())
817 .unwrap_or_default()
818 } else {
819 self.by_category(category)
820 }
821 }
822
823 pub fn top_n(&self, n: usize) -> Vec<&MemoryEntry> {
825 let mut sorted: Vec<_> = self.entries.iter().collect();
826 sorted.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap_or(std::cmp::Ordering::Equal));
827 sorted.into_iter().take(n).collect()
828 }
829
830 pub fn top_n_fast(&mut self, n: usize) -> Vec<&MemoryEntry> {
832 self.ensure_index();
833 if let Some(ref index) = self.search_index {
834 index.by_importance
835 .iter()
836 .take(n)
837 .map(|&i| &self.entries[i])
838 .collect()
839 } else {
840 self.top_n(n)
841 }
842 }
843
844 pub fn search(&self, query: &str) -> Vec<&MemoryEntry> {
846 self.search_with_limit(query, None)
847 }
848
849 pub fn search_with_limit(&self, query: &str, limit: Option<usize>) -> Vec<&MemoryEntry> {
851 let query_lower = query.to_lowercase();
852 let mut results: Vec<_> = self.entries
853 .iter()
854 .filter(|e| {
855 e.content.to_lowercase().contains(&query_lower) ||
856 e.tags.iter().any(|t| t.to_lowercase().contains(&query_lower))
857 })
858 .collect();
859
860 results.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap_or(std::cmp::Ordering::Equal));
862
863 if let Some(max) = limit {
864 results.into_iter().take(max).collect()
865 } else {
866 results
867 }
868 }
869
870 pub fn search_fast(&mut self, query: &str, limit: Option<usize>) -> Vec<&MemoryEntry> {
872 self.ensure_index();
873 let query_lower = query.to_lowercase();
874
875 if let Some(ref index) = self.search_index {
876 let indices = index.search(&self.entries, &query_lower, limit);
877 indices.iter().map(|&i| &self.entries[i]).collect()
878 } else {
879 self.search_with_limit(query, limit)
880 }
881 }
882
883 pub fn search_multi(&self, keywords: &[&str]) -> Vec<&MemoryEntry> {
885 if keywords.is_empty() {
886 return Vec::new();
887 }
888
889 let keywords_lower: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();
890
891 self.entries
892 .iter()
893 .filter(|e| {
894 let content_lower = e.content.to_lowercase();
895 keywords_lower.iter().any(|k| content_lower.contains(k))
896 })
897 .collect()
898 }
899
900 pub fn search_multi_fast(&mut self, keywords: &[&str]) -> Vec<&MemoryEntry> {
902 if keywords.is_empty() {
903 return Vec::new();
904 }
905
906 self.ensure_index();
907 let keywords_lower: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();
908
909 if let Some(ref index) = self.search_index {
910 let indices = index.search_multi(&keywords_lower);
911 indices.iter().map(|&i| &self.entries[i]).collect()
912 } else {
913 self.search_multi(keywords)
914 }
915 }
916
917 pub fn add_batch(&mut self, entries: Vec<MemoryEntry>) {
920 for entry in entries {
922 if !self.has_similar(&entry.content) {
923 self.entries.push(entry);
924 }
925 }
926 self.prune();
928 }
929
930 pub fn update_references(&mut self, messages: &[Message]) {
933 let increment = self.config.reference_increment;
934
935 let texts_lower: Vec<String> = messages
937 .iter()
938 .filter_map(Self::extract_message_text_lower)
939 .collect();
940
941 let entry_contents_lower: Vec<String> = self.entries
943 .iter()
944 .map(|e| e.content.to_lowercase())
945 .collect();
946
947 for (i, entry) in self.entries.iter_mut().enumerate() {
949 let entry_lower = &entry_contents_lower[i];
950 if texts_lower.iter().any(|t| t.contains(entry_lower)) {
951 entry.mark_referenced_with_increment(increment);
952 }
953 }
954 }
955
956 fn extract_message_text_lower(msg: &Message) -> Option<String> {
958 match &msg.content {
959 crate::providers::MessageContent::Text(t) => Some(t.to_lowercase()),
960 crate::providers::MessageContent::Blocks(blocks) => {
961 let text = blocks
962 .iter()
963 .filter_map(|b| {
964 if let crate::providers::ContentBlock::Text { text } = b {
965 Some(text.as_str())
966 } else {
967 None
968 }
969 })
970 .collect::<Vec<_>>()
971 .join(" ");
972 Some(text.to_lowercase())
973 }
974 }
975 }
976
977 pub fn generate_prompt_summary(&self, max_entries: usize) -> String {
979 if self.entries.is_empty() {
980 return String::new();
981 }
982
983 let top_entries = self.top_n(max_entries);
984 if top_entries.is_empty() {
985 return String::new();
986 }
987
988 let mut summary = String::from("【自动记忆摘要】\n\n");
989
990 let mut by_cat: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
992 for entry in top_entries {
993 by_cat.entry(entry.category).or_default().push(entry);
994 }
995
996 for (cat, entries) in by_cat {
997 summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
998 for entry in entries {
999 summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
1000 }
1001 summary.push('\n');
1002 }
1003
1004 summary
1005 }
1006
1007 pub fn generate_contextual_summary(&self, context: &str, max_entries: usize) -> String {
1017 let keywords = extract_context_keywords(context);
1019 self.generate_contextual_summary_with_keywords(&keywords, max_entries)
1020 }
1021
1022 pub fn generate_contextual_summary_with_keywords(&self, context_keywords: &[String], max_entries: usize) -> String {
1025 if self.entries.is_empty() {
1026 return String::new();
1027 }
1028
1029 let mut scored: Vec<(&MemoryEntry, f64)> = self.entries
1031 .iter()
1032 .map(|entry| {
1033 let relevance = compute_relevance(entry, &context_keywords);
1034 (entry, relevance)
1035 })
1036 .collect();
1037
1038 scored.sort_by(|a, b| {
1040 if a.0.is_manual && !b.0.is_manual {
1042 return std::cmp::Ordering::Less;
1043 }
1044 if !a.0.is_manual && b.0.is_manual {
1045 return std::cmp::Ordering::Greater;
1046 }
1047
1048 let score_a = a.1 * CONTEXT_RELEVANCE_WEIGHT + (a.0.importance / MAX_IMPORTANCE_CEILING) * CONTEXT_IMPORTANCE_WEIGHT;
1050 let score_b = b.1 * CONTEXT_RELEVANCE_WEIGHT + (b.0.importance / MAX_IMPORTANCE_CEILING) * CONTEXT_IMPORTANCE_WEIGHT;
1051
1052 score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
1053 });
1054
1055 let selected: Vec<&MemoryEntry> = scored
1057 .iter()
1058 .take(max_entries)
1059 .map(|(entry, _)| *entry)
1060 .collect();
1061
1062 if selected.is_empty() {
1063 return String::new();
1064 }
1065
1066 let mut summary = String::from("【跨会话记忆】\n\n");
1067
1068 let mut by_cat: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
1070 for entry in selected {
1071 by_cat.entry(entry.category).or_default().push(entry);
1072 }
1073
1074 for (cat, entries) in by_cat {
1075 summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
1076 for entry in entries {
1077 summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
1078 }
1079 summary.push('\n');
1080 }
1081
1082 summary
1083 }
1084
1085 pub async fn generate_contextual_summary_async(
1090 &self,
1091 context: &str,
1092 max_entries: usize,
1093 fast_provider: Option<&dyn crate::providers::Provider>,
1094 ) -> String {
1095 if self.entries.is_empty() {
1096 return String::new();
1097 }
1098
1099 let context_keywords = if let Some(provider) = fast_provider {
1101 extract_keywords_hybrid(context, Some(provider)).await
1102 } else {
1103 extract_context_keywords(context)
1104 };
1105
1106 let mut scored: Vec<(&MemoryEntry, f64)> = self.entries
1108 .iter()
1109 .map(|entry| {
1110 let relevance = compute_relevance(entry, &context_keywords);
1111 (entry, relevance)
1112 })
1113 .collect();
1114
1115 scored.sort_by(|a, b| {
1117 if a.0.is_manual && !b.0.is_manual {
1119 return std::cmp::Ordering::Less;
1120 }
1121 if !a.0.is_manual && b.0.is_manual {
1122 return std::cmp::Ordering::Greater;
1123 }
1124
1125 let score_a = a.1 * CONTEXT_RELEVANCE_WEIGHT + (a.0.importance / MAX_IMPORTANCE_CEILING) * CONTEXT_IMPORTANCE_WEIGHT;
1127 let score_b = b.1 * CONTEXT_RELEVANCE_WEIGHT + (b.0.importance / MAX_IMPORTANCE_CEILING) * CONTEXT_IMPORTANCE_WEIGHT;
1128
1129 score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
1130 });
1131
1132 let selected: Vec<&MemoryEntry> = scored
1134 .iter()
1135 .take(max_entries)
1136 .map(|(entry, _)| *entry)
1137 .collect();
1138
1139 if selected.is_empty() {
1140 return String::new();
1141 }
1142
1143 let mut summary = String::from("【跨会话记忆】\n\n");
1144
1145 let mut by_cat: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
1147 for entry in selected {
1148 by_cat.entry(entry.category).or_default().push(entry);
1149 }
1150
1151 for (cat, entries) in by_cat {
1152 summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
1153 for entry in entries {
1154 summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
1155 }
1156 summary.push('\n');
1157 }
1158
1159 summary
1160 }
1161
1162 pub fn format_all(&self) -> String {
1164 if self.entries.is_empty() {
1165 return "[no memories accumulated]".to_string();
1166 }
1167
1168 let mut result = String::from("Accumulated memories:\n\n");
1169
1170 let mut sorted: Vec<_> = self.entries.iter().collect();
1172 sorted.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap_or(std::cmp::Ordering::Equal));
1173
1174 for entry in sorted {
1175 result.push_str(&entry.format_line());
1176 result.push('\n');
1177 }
1178
1179 result
1180 }
1181
1182 pub fn generate_statistics(&self) -> MemoryStatistics {
1184 let total = self.entries.len();
1185 let manual = self.entries.iter().filter(|e| e.is_manual).count();
1186 let auto = total - manual;
1187
1188 let by_category: HashMap<MemoryCategory, usize> = self.entries
1190 .iter()
1191 .fold(HashMap::new(), |mut acc, e| {
1192 *acc.entry(e.category).or_default() += 1;
1193 acc
1194 });
1195
1196 let avg_importance = if total > 0 {
1198 self.entries.iter().map(|e| e.importance).sum::<f64>() / total as f64
1199 } else {
1200 0.0
1201 };
1202
1203 let oldest = self.entries
1205 .iter()
1206 .min_by_key(|e| e.created_at)
1207 .map(|e| e.created_at);
1208 let newest = self.entries
1209 .iter()
1210 .max_by_key(|e| e.created_at)
1211 .map(|e| e.created_at);
1212
1213 let highly_referenced = self.entries
1215 .iter()
1216 .filter(|e| e.reference_count >= 3)
1217 .count();
1218
1219 MemoryStatistics {
1220 total,
1221 manual,
1222 auto,
1223 by_category,
1224 avg_importance,
1225 oldest,
1226 newest,
1227 highly_referenced,
1228 }
1229 }
1230
1231 pub fn clear(&mut self) {
1233 self.entries.clear();
1234 self.invalidate_index();
1235 }
1236
1237 pub fn remove(&mut self, id: &str) -> bool {
1239 let idx = self.entries.iter().position(|e| e.id == id);
1240 if let Some(i) = idx {
1241 self.entries.remove(i);
1242 self.invalidate_index();
1243 true
1244 } else {
1245 false
1246 }
1247 }
1248
1249 pub fn apply_time_decay(&mut self) {
1252 let now = Utc::now();
1253 let decay_start_days = self.config.decay_start_days;
1254 let decay_rate = self.config.decay_rate;
1255 let decay_period_days = 30; for entry in &mut self.entries {
1258 if entry.is_manual {
1260 continue;
1261 }
1262
1263 let days_since_reference = (now - entry.last_referenced)
1265 .num_days()
1266 .max(0);
1267
1268 if days_since_reference > decay_start_days {
1270 let decay_periods = (days_since_reference - decay_start_days) / decay_period_days;
1272
1273 let decay_factor = decay_rate.powi(decay_periods as i32);
1275 entry.importance *= decay_factor;
1276
1277 entry.importance = entry.importance.max(self.min_importance * 0.5);
1279 }
1280 }
1281
1282 self.prune();
1284 }
1285}
1286
/// Aggregate numbers produced by `AutoMemory::generate_statistics`.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Number of entries.
    pub total: usize,
    /// Entries with `is_manual == true`.
    pub manual: usize,
    /// `total - manual`.
    pub auto: usize,
    /// Entry count per category (only categories that occur are present).
    pub by_category: HashMap<MemoryCategory, usize>,
    /// Mean importance, or 0.0 when there are no entries.
    pub avg_importance: f64,
    /// Earliest `created_at`, if any entries exist.
    pub oldest: Option<DateTime<Utc>>,
    /// Latest `created_at`, if any entries exist.
    pub newest: Option<DateTime<Utc>>,
    /// Entries referenced at least 3 times.
    pub highly_referenced: usize,
}
1307
1308impl MemoryStatistics {
1309 pub fn format_summary(&self) -> String {
1311 use std::fmt::Write;
1312
1313 let mut output = String::new();
1314
1315 writeln!(output, "记忆统计:").unwrap();
1316 writeln!(output, " 总计: {} 条", self.total).unwrap();
1317 writeln!(output, " ├─ 手动添加: {} 条", self.manual).unwrap();
1318 writeln!(output, " └─ 自动检测: {} 条", self.auto).unwrap();
1319 writeln!(output).unwrap();
1320
1321 writeln!(output, "分类统计:").unwrap();
1322 for (cat, count) in &self.by_category {
1323 writeln!(output, " {} {}: {} 条", cat.icon(), cat.display_name(), count).unwrap();
1324 }
1325 writeln!(output).unwrap();
1326
1327 writeln!(output, "质量指标:").unwrap();
1328 writeln!(output, " 平均重要性: {:.1} 分", self.avg_importance).unwrap();
1329 writeln!(output, " 高频引用: {} 条 (≥3次)", self.highly_referenced).unwrap();
1330
1331 if let Some(oldest) = self.oldest {
1332 let days = (Utc::now() - oldest).num_days();
1333 writeln!(output, " 记忆跨度: {} 天", days).unwrap();
1334 }
1335
1336 output
1337 }
1338}
1339
1340pub struct MemoryFileLock {
1347 lock_path: PathBuf,
1349 locked: bool,
1351}
1352
1353impl MemoryFileLock {
1354 pub fn new(base_dir: &Path) -> Self {
1356 Self {
1357 lock_path: base_dir.join("memory.lock"),
1358 locked: false,
1359 }
1360 }
1361
1362 pub fn acquire(&mut self, timeout_ms: u64) -> Result<bool> {
1365 if self.locked {
1366 return Ok(true); }
1368
1369 let start = std::time::Instant::now();
1370
1371 while start.elapsed().as_millis() < timeout_ms as u128 {
1372 match fs::File::create_new(&self.lock_path) {
1374 Ok(_) => {
1375 let lock_info = format!(
1377 "{}:{}",
1378 std::process::id(),
1379 Utc::now().to_rfc3339()
1380 );
1381 fs::write(&self.lock_path, lock_info)?;
1382 self.locked = true;
1383 return Ok(true);
1384 }
1385 Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
1386 if self.is_stale_lock()? {
1388 self.remove_stale_lock()?;
1389 }
1390 std::thread::sleep(std::time::Duration::from_millis(50));
1392 }
1393 Err(e) => {
1394 return Err(e.into());
1395 }
1396 }
1397 }
1398
1399 Ok(false) }
1401
1402 fn is_stale_lock(&self) -> Result<bool> {
1404 if !self.lock_path.exists() {
1405 return Ok(false);
1406 }
1407
1408 let metadata = fs::metadata(&self.lock_path)?;
1410 let modified = metadata.modified()?;
1411 let age = std::time::SystemTime::now()
1412 .duration_since(modified)
1413 .unwrap_or(std::time::Duration::ZERO);
1414
1415 Ok(age > std::time::Duration::from_secs(30))
1417 }
1418
1419 fn remove_stale_lock(&self) -> Result<()> {
1421 if self.lock_path.exists() {
1422 fs::remove_file(&self.lock_path)?;
1423 }
1424 Ok(())
1425 }
1426
1427 pub fn release(&mut self) -> Result<()> {
1429 if self.locked {
1430 fs::remove_file(&self.lock_path)?;
1431 self.locked = false;
1432 }
1433 Ok(())
1434 }
1435}
1436
1437impl Drop for MemoryFileLock {
1438 fn drop(&mut self) {
1439 let _ = self.release();
1441 }
1442}
1443
/// Locates, locks and loads the persisted memory files.
///
/// Global memory lives under `base_dir`. Project memory, when a project root is
/// given, lives under `<project_root>/.matrix/`.
pub struct MemoryStorage {
    /// `<home>/.matrix`, where home comes from `HOME` or else `USERPROFILE`.
    base_dir: PathBuf,
    /// Optional project root whose `.matrix/` directory holds project-scoped memory.
    project_root: Option<PathBuf>,
    /// Lock on `<base_dir>/memory.lock`.
    lock: MemoryFileLock,
}
1453
1454impl MemoryStorage {
1455 pub fn new(project_root: Option<&Path>) -> Result<Self> {
1457 let base_dir = Self::get_base_dir()?;
1458 let lock = MemoryFileLock::new(&base_dir);
1459 Ok(Self {
1460 base_dir,
1461 project_root: project_root.map(|p| p.to_path_buf()),
1462 lock,
1463 })
1464 }
1465
1466 pub fn with_lock_timeout(project_root: Option<&Path>, timeout_ms: u64) -> Result<Self> {
1468 let mut storage = Self::new(project_root)?;
1469 storage.lock.acquire(timeout_ms)?;
1470 Ok(storage)
1471 }
1472
1473 fn get_base_dir() -> Result<PathBuf> {
1475 let home = std::env::var_os("HOME")
1476 .or_else(|| std::env::var_os("USERPROFILE"))
1477 .ok_or_else(|| anyhow::anyhow!("HOME or USERPROFILE not set"))?;
1478 let mut p = PathBuf::from(home);
1479 p.push(".matrix");
1480 Ok(p)
1481 }
1482
1483 pub fn global_memory_path(&self) -> PathBuf {
1485 self.base_dir.join("memory.json")
1486 }
1487
1488 pub fn project_memory_path(&self) -> Option<PathBuf> {
1490 self.project_root.as_ref().map(|p| p.join(".matrix/memory.json"))
1491 }
1492
1493 pub fn config_path(&self) -> PathBuf {
1495 self.base_dir.join("memory_config.json")
1496 }
1497
1498 fn ensure_dirs(&self) -> Result<()> {
1500 fs::create_dir_all(&self.base_dir)?;
1501 if let Some(root) = &self.project_root {
1502 let memory_dir = root.join(".matrix");
1503 fs::create_dir_all(memory_dir)?;
1504 }
1505 Ok(())
1506 }
1507
1508 fn acquire_lock(&mut self) -> Result<()> {
1510 self.lock.acquire(5000)?; Ok(())
1512 }
1513
1514 fn release_lock(&mut self) -> Result<()> {
1516 self.lock.release()?;
1517 Ok(())
1518 }
1519
1520 pub fn load_global(&self) -> Result<AutoMemory> {
1522 let path = self.global_memory_path();
1523 if !path.exists() {
1524 return Ok(AutoMemory::new());
1525 }
1526 let data = fs::read_to_string(&path)?;
1527 let memory: AutoMemory = serde_json::from_str(&data)?;
1528 Ok(memory)
1529 }
1530
1531 pub fn load_project(&self) -> Result<Option<AutoMemory>> {
1533 let path = self.project_memory_path();
1534 match path {
1535 Some(p) if p.exists() => {
1536 let data = fs::read_to_string(&p)?;
1537 let memory: AutoMemory = serde_json::from_str(&data)?;
1538 Ok(Some(memory))
1539 }
1540 _ => Ok(None),
1541 }
1542 }
1543
1544 pub fn load_combined(&self) -> Result<AutoMemory> {
1546 let mut combined = self.load_global()?;
1547
1548 if let Some(project) = self.load_project()? {
1549 for entry in project.entries {
1551 let mut tagged_entry = entry;
1553 if !tagged_entry.tags.contains(&"project".to_string()) {
1554 tagged_entry.tags.push("project".to_string());
1555 }
1556 combined.entries.push(tagged_entry);
1557 }
1558 combined.prune();
1559 }
1560
1561 Ok(combined)
1562 }
1563
1564 pub fn save_global(&mut self, memory: &AutoMemory) -> Result<()> {
1566 self.acquire_lock()?;
1567 self.ensure_dirs()?;
1568
1569 let path = self.global_memory_path();
1570 let json = serde_json::to_string_pretty(memory)?;
1571
1572 let tmp = path.with_extension("json.tmp");
1574 fs::write(&tmp, json)?;
1575 fs::rename(&tmp, &path)?;
1576
1577 self.release_lock()?;
1578 Ok(())
1579 }
1580
1581 pub fn save_project(&mut self, memory: &AutoMemory) -> Result<()> {
1583 self.acquire_lock()?;
1584 self.ensure_dirs()?;
1585
1586 let path = self.project_memory_path()
1587 .ok_or_else(|| anyhow::anyhow!("no project root"))?;
1588 let json = serde_json::to_string_pretty(memory)?;
1589
1590 let tmp = path.with_extension("json.tmp");
1591 fs::write(&tmp, json)?;
1592 fs::rename(&tmp, &path)?;
1593
1594 self.release_lock()?;
1595 Ok(())
1596 }
1597
1598 pub fn save_config(&mut self, config: &MemoryConfig) -> Result<()> {
1600 self.ensure_dirs()?;
1601 let path = self.config_path();
1602 let json = serde_json::to_string_pretty(config)?;
1603 fs::write(&path, json)?;
1604 Ok(())
1605 }
1606
1607 pub fn load_config(&self) -> Result<MemoryConfig> {
1609 let path = self.config_path();
1610 if !path.exists() {
1611 return Ok(MemoryConfig::default());
1612 }
1613 let data = fs::read_to_string(&path)?;
1614 let config: MemoryConfig = serde_json::from_str(&data)?;
1615 Ok(config)
1616 }
1617
1618 pub fn add_entry(&mut self, entry: MemoryEntry, is_project_specific: bool) -> Result<()> {
1620 self.acquire_lock()?;
1621
1622 if is_project_specific {
1623 let mut project = self.load_project()?.unwrap_or_else(AutoMemory::new);
1624 project.add(entry);
1625 self.save_project_locked(&project)?;
1626 } else {
1627 let mut global = self.load_global()?;
1628 global.add(entry);
1629 self.save_global_locked(&global)?;
1630 }
1631
1632 self.release_lock()?;
1633 Ok(())
1634 }
1635
1636 pub fn remove_entry(&mut self, id: &str, is_project_specific: bool) -> Result<bool> {
1638 self.acquire_lock()?;
1639
1640 let removed = if is_project_specific {
1641 if let Some(mut project) = self.load_project()? {
1642 let removed = project.remove(id);
1643 if removed {
1644 self.save_project_locked(&project)?;
1645 }
1646 removed
1647 } else {
1648 false
1649 }
1650 } else {
1651 let mut global = self.load_global()?;
1652 let removed = global.remove(id);
1653 if removed {
1654 self.save_global_locked(&global)?;
1655 }
1656 removed
1657 };
1658
1659 self.release_lock()?;
1660 Ok(removed)
1661 }
1662
1663 fn save_global_locked(&self, memory: &AutoMemory) -> Result<()> {
1665 let path = self.global_memory_path();
1666 let json = serde_json::to_string_pretty(memory)?;
1667 let tmp = path.with_extension("json.tmp");
1668 fs::write(&tmp, json)?;
1669 fs::rename(&tmp, &path)?;
1670 Ok(())
1671 }
1672
1673 fn save_project_locked(&self, memory: &AutoMemory) -> Result<()> {
1674 let path = self.project_memory_path()
1675 .ok_or_else(|| anyhow::anyhow!("no project root"))?;
1676 let json = serde_json::to_string_pretty(memory)?;
1677 let tmp = path.with_extension("json.tmp");
1678 fs::write(&tmp, json)?;
1679 fs::rename(&tmp, &path)?;
1680 Ok(())
1681 }
1682}
1683
/// Free-function entry point for the text-similarity score used by `AutoMemory`.
///
/// Delegates directly to `AutoMemory::calculate_similarity`; callers inside this file pass
/// lowercased text, so case handling is up to the caller (TODO confirm the delegate's own
/// normalization before relying on case-insensitivity).
pub fn calculate_similarity(a: &str, b: &str) -> f64 {
    AutoMemory::calculate_similarity(a, b)
}
1694
1695pub fn extract_context_keywords(context: &str) -> Vec<String> {
1699 use std::collections::HashSet;
1700
1701 let stop_words: HashSet<&str> = [
1703 "的", "了", "是", "在", "我", "有", "和", "就", "不", "人", "都", "一", "一个",
1705 "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好",
1706 "自己", "这", "他", "她", "它", "们", "那", "些", "什么", "怎么", "如何", "请",
1707 "能", "可以", "需要", "应该", "可能", "因为", "所以", "但是", "然后", "还是",
1708 "已经", "正在", "将要", "曾经", "一下", "一点", "一些", "所有", "每个", "任何",
1709 "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
1711 "have", "has", "had", "do", "does", "did", "will", "would", "could",
1712 "should", "may", "might", "can", "shall", "to", "of", "in", "for",
1713 "on", "with", "at", "by", "from", "as", "into", "through", "during",
1714 "before", "after", "above", "below", "between", "and", "but", "or",
1715 "not", "no", "so", "if", "then", "than", "too", "very", "just",
1716 "this", "that", "these", "those", "it", "its", "i", "me", "my",
1717 "we", "our", "you", "your", "he", "his", "she", "her", "they", "their",
1718 "please", "help", "need", "want", "make", "get", "let", "use",
1719 ].iter().copied().collect();
1720
1721 let tech_patterns: HashSet<&str> = [
1723 "api", "cli", "gui", "tui", "web", "http", "json", "xml", "sql", "db",
1725 "git", "npm", "cargo", "rust", "js", "ts", "py", "go", "java", "cpp",
1726 "cpu", "gpu", "io", "fs", "os", "ui", "ux", "ai", "ml", "dl",
1727 "rs", "js", "ts", "py", "go", "java", "c", "h", "cpp", "hpp",
1729 "json", "yaml", "yml", "toml", "md", "txt", "html", "css", "scss",
1730 "bug", "fix", "add", "new", "old", "use", "run", "build", "test",
1732 "code", "data", "file", "dir", "path", "name", "type", "value",
1733 ].iter().copied().collect();
1734
1735 let lower = context.to_lowercase();
1736 let mut keywords: HashSet<String> = HashSet::new();
1737
1738 for word in lower.split_whitespace() {
1740 let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
1741 if cleaned.len() >= 2 && !stop_words.contains(cleaned.as_str()) {
1742 keywords.insert(cleaned.clone());
1743 }
1744 if tech_patterns.contains(cleaned.as_str()) {
1746 keywords.insert(cleaned);
1747 }
1748 }
1749
1750 let chinese_chars: Vec<char> = lower
1753 .chars()
1754 .filter(|c| *c >= '\u{4E00}' && *c <= '\u{9FFF}') .collect();
1756
1757 for window_size in 2..=4 {
1759 if chinese_chars.len() >= window_size {
1760 for window in chinese_chars.windows(window_size) {
1761 let phrase: String = window.iter().collect();
1762 let has_stop = stop_words.iter().any(|sw| phrase.contains(sw));
1764 if !has_stop && phrase.len() >= window_size {
1765 keywords.insert(phrase);
1766 }
1767 }
1768 }
1769 }
1770
1771 let patterns = [
1774 r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}", r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*", r"[A-Z][a-z]+[A-Z][a-zA-Z]*", r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*", r"[0-9]+[kKmMgGtT][bB]?", ];
1783
1784 for pattern in patterns {
1785 if let Ok(re) = regex::Regex::new(pattern) {
1786 for cap in re.find_iter(&lower) {
1787 keywords.insert(cap.as_str().to_string());
1788 }
1789 }
1790 }
1791
1792 let mut result: Vec<String> = keywords.into_iter().collect();
1794 result.sort_by(|a, b| b.len().cmp(&a.len()));
1795
1796 result.truncate(15);
1798
1799 result
1800}
1801
1802fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
1805 if context_keywords.is_empty() {
1806 return 0.0;
1807 }
1808
1809 let content_lower = entry.content.to_lowercase();
1810
1811 let matches = context_keywords
1813 .iter()
1814 .filter(|kw| content_lower.contains(kw.as_str()))
1815 .count();
1816
1817 let keyword_score = matches as f64 / context_keywords.len() as f64;
1819
1820 let tag_matches = entry.tags
1822 .iter()
1823 .filter(|tag| {
1824 let tag_lower = tag.to_lowercase();
1825 context_keywords.iter().any(|kw| tag_lower.contains(kw.as_str()))
1826 })
1827 .count();
1828
1829 let tag_score = if tag_matches > 0 { 0.2 } else { 0.0 };
1830
1831 (keyword_score + tag_score).min(1.0)
1833}
1834
/// Heuristically decides whether `new` likely supersedes or contradicts `old`.
///
/// Returns `true` when either:
/// * `new` contains an explicit change signal ("改用", "switched to", "deprecated", ...), or
/// * both texts contain the same decision/usage verb ("使用", "using", "adopted", ...) or the
///   same preference verb ("偏好", "prefer", ...).
///
/// Matching is case-insensitive, so sentence-initial forms such as "Switched to" or "Using"
/// are recognized. This is a substring heuristic: for example "like" also matches "likely".
fn has_contradiction_signal(old: &str, new: &str) -> bool {
    const CHANGE_SIGNALS: &[&str] = &[
        "改用", "换成", "替换", "改为", "切换到", "迁移到",
        "不再使用", "弃用", "放弃", "取消",
        "switched to", "replaced", "migrated to", "changed to",
        "no longer", "deprecated", "abandoned",
    ];
    const ACTION_VERBS: &[&str] = &[
        "决定使用", "选择使用", "采用", "使用",
        "decided to use", "chose", "using", "adopted",
    ];
    const PREFERENCE_VERBS: &[&str] = &["偏好", "喜欢", "prefer", "like"];

    let old = old.to_lowercase();
    let new = new.to_lowercase();

    CHANGE_SIGNALS.iter().any(|signal| new.contains(signal))
        || ACTION_VERBS
            .iter()
            .chain(PREFERENCE_VERBS)
            .any(|verb| old.contains(verb) && new.contains(verb))
}
1881
/// Extracts long-term memory entries from free-form conversation text.
///
/// Implementors must be `Send + Sync`; the async method is declared through `async_trait`.
#[async_trait::async_trait]
pub trait MemoryExtractor: Send + Sync {
    /// Extracts memory entries from `text`.
    ///
    /// `session_id`, when given, identifies the originating session (the AI implementation
    /// in this file passes it to `MemoryEntry::new` as the entry's source session).
    async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>>;

    /// Name of the model backing this extractor, for reporting.
    fn model_name(&self) -> &str;
}
1895
/// `MemoryExtractor` that asks an LLM provider to extract memories from conversation text.
pub struct AiMemoryExtractor {
    /// Provider that receives the extraction chat request.
    provider: Box<dyn crate::providers::Provider>,
    /// Model name reported by `model_name`; not part of the chat request built in `extract`.
    model: String,
}
1901
1902impl AiMemoryExtractor {
1903 pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
1905 Self { provider, model }
1906 }
1907}
1908
1909const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。你的任务是从对话中识别并提取值得长期记忆的关键信息。
1911
1912记忆类型:
19131. Decision(决策): 项目或技术选型的决定,如"决定使用 PostgreSQL"
19142. Preference(偏好): 用户习惯或偏好,如"我喜欢用 vim"
19153. Solution(解决方案): 解决问题的具体方法,如"通过添加 middleware 修复 bug"
19164. Finding(发现): 重要发现或信息,如"API 端点在 /api/v2"
19175. Technical(技术): 技术栈或框架信息,如"使用 React Query 做数据获取"
19186. Structure(结构): 项目结构信息,如"入口文件是 src/index.ts"
1919
1920提取原则:
1921- 只提取有价值、可复用的信息
1922- 避免提取临时性、一次性信息
1923- 避免提取过于具体的代码细节
1924- 每条记忆应简洁明确(一句话)
1925- 最多提取 5 条记忆
1926
1927输出格式(严格 JSON):
1928```json
1929{
1930 "memories": [
1931 {
1932 "category": "decision",
1933 "content": "决定使用 PostgreSQL 作为主数据库",
1934 "importance": 90
1935 },
1936 {
1937 "category": "preference",
1938 "content": "用户偏好 TypeScript 而非 JavaScript",
1939 "importance": 70
1940 }
1941 ]
1942}
1943```
1944
1945如果没有值得记忆的内容,返回:
1946```json
1947{"memories": []}
1948```
1949
1950直接输出 JSON,不要加代码块包裹。"#;
1951
1952#[async_trait::async_trait]
1953impl MemoryExtractor for AiMemoryExtractor {
1954 async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
1955 use crate::providers::{ChatRequest, Message, MessageContent, Role};
1956
1957 let truncated_text = if text.len() > 4000 {
1959 truncate_str(text, 4000)
1960 } else {
1961 text.to_string()
1962 };
1963
1964 let request = ChatRequest {
1965 messages: vec![Message {
1966 role: Role::User,
1967 content: MessageContent::Text(format!(
1968 "请从以下对话中提取值得记忆的关键信息:\n\n{}",
1969 truncated_text
1970 )),
1971 }],
1972 tools: vec![], system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
1974 think: false, max_tokens: 512, server_tools: vec![],
1977 enable_caching: false,
1978 };
1979
1980 let response = self.provider.chat(request).await?;
1981
1982 let response_text = response.content
1984 .iter()
1985 .filter_map(|block| {
1986 if let crate::providers::ContentBlock::Text { text } = block {
1987 Some(text.clone())
1988 } else {
1989 None
1990 }
1991 })
1992 .collect::<Vec<_>>()
1993 .join("");
1994
1995 parse_memory_response(&response_text, session_id)
1997 }
1998
1999 fn model_name(&self) -> &str {
2000 &self.model
2001 }
2002}
2003
2004fn parse_memory_response(json_text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
2006 let cleaned = json_text
2008 .trim()
2009 .trim_start_matches("```json")
2010 .trim_start_matches("```")
2011 .trim_end_matches("```")
2012 .trim();
2013
2014 #[derive(serde::Deserialize)]
2016 struct MemoryResponse {
2017 memories: Vec<MemoryItem>,
2018 }
2019
2020 #[derive(serde::Deserialize)]
2021 struct MemoryItem {
2022 category: String,
2023 content: String,
2024 #[serde(default)]
2025 importance: f64,
2026 }
2027
2028 let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
2029
2030 let entries = parsed.memories
2032 .into_iter()
2033 .filter_map(|item| {
2034 let category = match item.category.to_lowercase().as_str() {
2036 "decision" => MemoryCategory::Decision,
2037 "preference" => MemoryCategory::Preference,
2038 "solution" => MemoryCategory::Solution,
2039 "finding" => MemoryCategory::Finding,
2040 "technical" => MemoryCategory::Technical,
2041 "structure" => MemoryCategory::Structure,
2042 _ => return None, };
2044
2045 if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
2047 return None;
2048 }
2049
2050 let mut entry = MemoryEntry::new(
2052 category,
2053 item.content,
2054 session_id.map(|s| s.to_string()),
2055 );
2056
2057 if item.importance > 0.0 {
2059 entry.importance = item.importance.clamp(0.0, 100.0);
2060 }
2061
2062 Some(entry)
2063 })
2064 .collect();
2065
2066 Ok(deduplicate_entries(entries))
2068}
2069
2070const KEYWORD_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个关键词提取助手。你的任务是从用户输入中提取有意义的关键词,用于检索相关记忆。
2076
2077提取原则:
20781. 只提取有实际意义的词汇(技术名词、项目名、概念等)
20792. 过滤掉常见的停用词(的、是、在、我、你、the、a、is 等)
20803. 保留专有名词和技术术语
20814. 中英文混合输入时,两种语言的关键词都提取
20825. 提取 3-10 个关键词
2083
2084输出格式(严格 JSON):
2085```json
2086{
2087 "keywords": ["数据库", "PostgreSQL", "优化", "查询"]
2088}
2089```
2090
2091如果没有有意义的关键词,返回:
2092```json
2093{"keywords": []}
2094```
2095
2096直接输出 JSON,不要加代码块包裹。"#;
2097
2098pub async fn extract_keywords_with_ai(
2103 context: &str,
2104 provider: &dyn crate::providers::Provider,
2105) -> Result<Vec<String>> {
2106 use crate::providers::{ChatRequest, Message, MessageContent, Role};
2107
2108 let truncated = if context.len() > 1000 {
2110 truncate_str(context, 1000)
2111 } else {
2112 context.to_string()
2113 };
2114
2115 let request = ChatRequest {
2116 messages: vec![Message {
2117 role: Role::User,
2118 content: MessageContent::Text(format!(
2119 "请从以下文本中提取关键词:\n\n{}",
2120 truncated
2121 )),
2122 }],
2123 tools: vec![],
2124 system: Some(KEYWORD_EXTRACT_SYSTEM_PROMPT.to_string()),
2125 think: false,
2126 max_tokens: 256,
2127 server_tools: vec![],
2128 enable_caching: false,
2129 };
2130
2131 let response = provider.chat(request).await?;
2132
2133 let response_text = response.content
2135 .iter()
2136 .filter_map(|block| {
2137 if let crate::providers::ContentBlock::Text { text } = block {
2138 Some(text.clone())
2139 } else {
2140 None
2141 }
2142 })
2143 .collect::<Vec<_>>()
2144 .join("");
2145
2146 parse_keyword_response(&response_text)
2148}
2149
2150fn parse_keyword_response(json_text: &str) -> Result<Vec<String>> {
2152 let cleaned = json_text
2154 .trim()
2155 .trim_start_matches("```json")
2156 .trim_start_matches("```")
2157 .trim_end_matches("```")
2158 .trim();
2159
2160 #[derive(serde::Deserialize)]
2161 struct KeywordResponse {
2162 keywords: Vec<String>,
2163 }
2164
2165 let parsed: KeywordResponse = serde_json::from_str(cleaned)?;
2166
2167 Ok(parsed.keywords
2169 .into_iter()
2170 .filter(|k| k.len() >= 2)
2171 .collect())
2172}
2173
2174pub async fn extract_keywords_hybrid(
2181 context: &str,
2182 fast_provider: Option<&dyn crate::providers::Provider>,
2183) -> Vec<String> {
2184 let mode = AiKeywordMode::from_env();
2186
2187 if mode == AiKeywordMode::Never {
2189 return extract_context_keywords(context);
2190 }
2191
2192 let keywords = if mode == AiKeywordMode::Always {
2194 Vec::new() } else {
2196 extract_context_keywords(context)
2197 };
2198
2199 if !mode.should_use_ai(keywords.len()) {
2201 return keywords;
2202 }
2203
2204 if let Some(provider) = fast_provider {
2206 match extract_keywords_with_ai(context, provider).await {
2207 Ok(ai_keywords) if !ai_keywords.is_empty() => {
2208 log::debug!("AI extracted {} keywords: {:?}", ai_keywords.len(), ai_keywords);
2209 if mode == AiKeywordMode::Auto && !keywords.is_empty() {
2211 let merged = keywords
2212 .into_iter()
2213 .chain(ai_keywords.into_iter())
2214 .collect::<std::collections::HashSet<_>>();
2215 return merged.into_iter().collect();
2216 }
2217 return ai_keywords;
2218 }
2219 Ok(_) => {
2220 log::debug!("AI returned no keywords, keeping rule-based results");
2221 }
2222 Err(e) => {
2223 log::warn!("AI keyword extraction failed: {}, keeping rule-based results", e);
2224 }
2225 }
2226 }
2227
2228 keywords
2230}
2231
2232const MEMORY_SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个记忆摘要助手。你的任务是将多条相关记忆合并为一条精炼的摘要记忆。
2238
2239摘要原则:
22401. 保留核心信息,去除冗余细节
22412. 使用简洁明确的一句话表达
22423. 保留关键的技术名词和决策结论
22434. 如果多条记忆主题相同,合并为一条综合性记忆
22445. 优先保留高价值的决策和解决方案
2245
2246输出格式(严格 JSON):
2247```json
2248{
2249 "summary": "决定使用 PostgreSQL 作为主数据库,Redis 作为缓存层",
2250 "category": "decision",
2251 "importance": 90
2252}
2253```
2254
2255如果没有值得保留的信息,返回:
2256```json
2257{"summary": "", "category": "", "importance": 0}
2258```
2259
2260直接输出 JSON,不要加代码块包裹。"#;
2261
2262const MEMORY_CONFLICT_SYSTEM_PROMPT: &str = r#"你是一个记忆冲突检测助手。你的任务是判断两条记忆是否矛盾或需要更新。
2264
2265冲突类型:
22661. 直接矛盾:两条记忆结论相反(如"使用 PostgreSQL" vs "使用 MySQL")
22672. 过时更新:新记忆明确替换旧记忆(如"改用 Redis" 替换 "使用 Memcached")
22683. 补充关系:新记忆补充旧记忆(如"PostgreSQL 版本为 15" 补充 "使用 PostgreSQL")
22694. 无关关系:两条记忆主题不同,不冲突
2270
2271输出格式(严格 JSON):
2272```json
2273{
2274 "conflict_type": "direct_conflict",
2275 "should_replace": true,
2276 "reason": "两条记忆都是数据库选型决策,但选择了不同的数据库",
2277 "winner": "new"
2278}
2279```
2280
2281conflict_type 可选值:
2282- "direct_conflict": 直接矛盾,需要选择一条
2283- "outdated_update": 过时更新,新记忆替换旧记忆
2284- "supplement": 补充关系,两者可共存
2285- "no_conflict": 无关关系,不冲突
2286
2287should_replace: true 表示需要替换旧记忆,false 表示保留两者
2288winner: "new" 表示新记忆胜出,"old" 表示旧记忆胜出(仅在 direct_conflict 时有意义)
2289
2290直接输出 JSON,不要加代码块包裹。"#;
2291
2292const MEMORY_QUALITY_SYSTEM_PROMPT: &str = r#"你是一个记忆质量评估助手。你的任务是评估记忆的长期价值和重要程度。
2294
2295评估维度:
22961. 复用价值:这条信息在未来的���话中会被引用吗?
22972. 决策权重:这是重要的项目决策还是次要细节?
22983. 时效性:这条信息会很快过时吗?
22994. 独特性:这条信息是否足够独特,不与其他记忆重叠?
2300
2301评分标准:
2302- 90-100: 核心决策,长期有效,高复用价值(如数据库选型、框架选择)
2303- 70-89: 重要偏好或解决方案,中等复用价值
2304- 50-69: 有用的技术信息或发现,时效性中等
2305- 30-49: 一般性信息,复用价值较低
2306- 0-29: 过时或过于具体的细节,建议丢弃
2307
2308输出格式(严格 JSON):
2309```json
2310{
2311 "quality_score": 85,
2312 "reason": "这是核心的技术选型决策,长期有效,高复用价值",
2313 "should_keep": true,
2314 "suggested_category": "decision"
2315}
2316```
2317
2318直接输出 JSON,不要加代码块包裹。"#;
2319
2320const MEMORY_MERGE_SYSTEM_PROMPT: &str = r#"你是一个记忆合并助手。你的任务是将多条相似或相关的记忆合并为一条精炼的记忆。
2322
2323合并原则:
23241. 相同主题的记忆应合并为一条综合性记忆
23252. 保留所有关键信息,去除重复内容
23263. 使用简洁的一句话表达
23274. 合并后的记忆应比原记忆更全面但更简洁
23285. 如果记忆完全不相关,返回空结果表示不应合并
2329
2330输出格式(严格 JSON):
2331```json
2332{
2333 "merged_content": "使用 PostgreSQL 作为主数据库(版本15),Redis 作为缓存层,通过连接池优化性能",
2334 "category": "technical",
2335 "importance": 75,
2336 "merged_from_count": 3,
2337 "summary_reason": "三条记忆都与数据库和缓存技术栈相关,合并为一条综合性技术栈记忆"
2338}
2339```
2340
2341如果不应合并,返回:
2342```json
2343{"merged_content": "", "category": "", "importance": 0, "merged_from_count": 0, "summary_reason": "记忆主题不同,不应合并"}
2344```
2345
2346直接输出 JSON,不要加代码块包裹。"#;
2347
/// Model reply for the summarization prompt: one condensed memory.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemorySummaryResult {
    /// Condensed memory text; empty when the model found nothing worth keeping.
    pub summary: String,
    /// Category name as emitted by the model (e.g. `"decision"`), parsed by `parse_category`.
    pub category: String,
    /// Suggested importance; `summarize_memories` clamps it to 0-100.
    pub importance: f64,
}
2355
/// Model reply for the conflict-detection prompt comparing an old and a new memory.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemoryConflictResult {
    /// One of `direct_conflict`, `outdated_update`, `supplement`, `no_conflict` (per the prompt).
    pub conflict_type: String,
    /// `true` when the old memory should be replaced by the new one.
    pub should_replace: bool,
    /// Model's free-text justification.
    pub reason: String,
    /// `"new"` or `"old"`; per the prompt only meaningful for `direct_conflict`.
    pub winner: Option<String>,
}
2364
/// Model reply for the quality-assessment prompt.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemoryQualityResult {
    /// Long-term value on the prompt's 0-100 rubric.
    pub quality_score: f64,
    /// Model's free-text justification.
    pub reason: String,
    /// Model's keep/discard verdict.
    pub should_keep: bool,
    /// Optional category suggestion (e.g. `"decision"`).
    pub suggested_category: Option<String>,
}
2373
/// Model reply for the merge prompt.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemoryMergeResult {
    /// Merged memory text; empty when the memories should not be merged.
    pub merged_content: String,
    /// Category name as emitted by the model, parsed by `parse_category`.
    pub category: String,
    /// Suggested importance; `merge_memories` clamps it to 0-100.
    pub importance: f64,
    /// Number of inputs merged; 0 is treated as "do not merge".
    pub merged_from_count: usize,
    /// Model's free-text justification.
    pub summary_reason: String,
}
2383
/// LLM-backed helpers for maintaining memories: summarizing, conflict detection, quality
/// assessment and merging.
pub struct AiMemoryProcessor {
    /// Provider that receives every processing request.
    provider: Box<dyn crate::providers::Provider>,
    /// Model name reported by `model_name`; not part of the requests built here.
    model: String,
}
2390
2391impl AiMemoryProcessor {
2392 pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
2394 Self { provider, model }
2395 }
2396
2397 pub async fn summarize_memories(&self, memories: &[&MemoryEntry]) -> Result<Option<MemoryEntry>> {
2399 if memories.is_empty() {
2400 return Ok(None);
2401 }
2402
2403 let memories_text = memories
2405 .iter()
2406 .map(|m| format!("[{}] {}", m.category.display_name(), m.content))
2407 .collect::<Vec<_>>()
2408 .join("\n");
2409
2410 let request = build_ai_request(
2411 MEMORY_SUMMARY_SYSTEM_PROMPT,
2412 &format!("请将以下记忆合并为一条精炼的摘要:\n\n{}", memories_text),
2413 );
2414
2415 let response = self.provider.chat(request).await?;
2416 let response_text = extract_response_text(&response);
2417
2418 let result: MemorySummaryResult = parse_json_response(&response_text)?;
2419
2420 if result.summary.is_empty() {
2421 return Ok(None);
2422 }
2423
2424 let category = parse_category(&result.category)?;
2425 let mut entry = MemoryEntry::new(category, result.summary, None);
2426 entry.importance = result.importance.clamp(0.0, 100.0);
2427
2428 Ok(Some(entry))
2429 }
2430
2431 pub async fn detect_conflict(&self, old: &MemoryEntry, new: &MemoryEntry) -> Result<MemoryConflictResult> {
2433 let input = format!(
2434 "旧记忆:[{}] {}\n新记忆:[{}] {}\n\n请判断这两条记忆是否存在冲突。",
2435 old.category.display_name(),
2436 old.content,
2437 new.category.display_name(),
2438 new.content
2439 );
2440
2441 let request = build_ai_request(MEMORY_CONFLICT_SYSTEM_PROMPT, &input);
2442 let response = self.provider.chat(request).await?;
2443 let response_text = extract_response_text(&response);
2444
2445 parse_json_response(&response_text)
2446 }
2447
2448 pub async fn assess_quality(&self, memory: &MemoryEntry) -> Result<MemoryQualityResult> {
2450 let input = format!(
2451 "记忆内容:[{}] {}\n\n请评估这条记忆的质量和长期价值。",
2452 memory.category.display_name(),
2453 memory.content
2454 );
2455
2456 let request = build_ai_request(MEMORY_QUALITY_SYSTEM_PROMPT, &input);
2457 let response = self.provider.chat(request).await?;
2458 let response_text = extract_response_text(&response);
2459
2460 parse_json_response(&response_text)
2461 }
2462
2463 pub async fn merge_memories(&self, memories: &[&MemoryEntry]) -> Result<Option<MemoryEntry>> {
2465 if memories.len() < 2 {
2466 return Ok(None);
2467 }
2468
2469 let memories_text = memories
2470 .iter()
2471 .map(|m| format!("[{}] {}", m.category.display_name(), m.content))
2472 .collect::<Vec<_>>()
2473 .join("\n");
2474
2475 let request = build_ai_request(
2476 MEMORY_MERGE_SYSTEM_PROMPT,
2477 &format!("请判断以下记忆是否应该合并,如果应该则生成合并后的记忆:\n\n{}", memories_text),
2478 );
2479
2480 let response = self.provider.chat(request).await?;
2481 let response_text = extract_response_text(&response);
2482
2483 let result: MemoryMergeResult = parse_json_response(&response_text)?;
2484
2485 if result.merged_content.is_empty() || result.merged_from_count == 0 {
2486 return Ok(None);
2487 }
2488
2489 let category = parse_category(&result.category)?;
2490 let mut entry = MemoryEntry::new(category, result.merged_content, None);
2491 entry.importance = result.importance.clamp(0.0, 100.0);
2492
2493 Ok(Some(entry))
2494 }
2495
2496 pub fn model_name(&self) -> &str {
2498 &self.model
2499 }
2500}
2501
2502fn build_ai_request(system_prompt: &str, user_input: &str) -> crate::providers::ChatRequest {
2504 use crate::providers::{ChatRequest, Message, MessageContent, Role};
2505
2506 ChatRequest {
2507 messages: vec![Message {
2508 role: Role::User,
2509 content: MessageContent::Text(user_input.to_string()),
2510 }],
2511 tools: vec![],
2512 system: Some(system_prompt.to_string()),
2513 think: false,
2514 max_tokens: 512,
2515 server_tools: vec![],
2516 enable_caching: false,
2517 }
2518}
2519
2520fn extract_response_text(response: &crate::providers::ChatResponse) -> String {
2522 response.content
2523 .iter()
2524 .filter_map(|block| {
2525 if let crate::providers::ContentBlock::Text { text } = block {
2526 Some(text.clone())
2527 } else {
2528 None
2529 }
2530 })
2531 .collect::<Vec<_>>()
2532 .join("")
2533}
2534
2535fn parse_json_response<T: serde::de::DeserializeOwned>(json_text: &str) -> Result<T> {
2537 let cleaned = json_text
2538 .trim()
2539 .trim_start_matches("```json")
2540 .trim_start_matches("```")
2541 .trim_end_matches("```")
2542 .trim();
2543
2544 serde_json::from_str(cleaned).map_err(|e| anyhow::anyhow!("JSON parse error: {}", e))
2545}
2546
2547fn parse_category(s: &str) -> Result<MemoryCategory> {
2549 match s.to_lowercase().as_str() {
2550 "decision" | "决策" => Ok(MemoryCategory::Decision),
2551 "preference" | "偏好" => Ok(MemoryCategory::Preference),
2552 "solution" | "解决方案" => Ok(MemoryCategory::Solution),
2553 "finding" | "发现" => Ok(MemoryCategory::Finding),
2554 "technical" | "技术" => Ok(MemoryCategory::Technical),
2555 "structure" | "结构" => Ok(MemoryCategory::Structure),
2556 _ => anyhow::bail!("Unknown category: {}", s),
2557 }
2558}
2559
/// Switches and thresholds for the AI-assisted memory maintenance passes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AiMemoryConfig {
    /// Whether AI summarization is used (checked by `generate_ai_summary`).
    pub enable_summarization: bool,
    /// Whether AI conflict detection should be used.
    pub enable_conflict_detection: bool,
    /// Whether `assess_quality_with_ai` runs (one model call per non-manual entry).
    pub enable_quality_assessment: bool,
    /// Whether `merge_similar_with_ai` runs.
    pub enable_merging: bool,
    /// Entry-count threshold for summarization (consumer not visible here; verify usage).
    pub summarize_threshold: usize,
    /// Entries scoring below this in quality assessment are removed.
    pub quality_threshold: f64,
    /// Minimum similarity for two entries to be grouped for an AI merge.
    pub merge_similarity_threshold: f64,
}
2578
2579impl Default for AiMemoryConfig {
2580 fn default() -> Self {
2581 Self {
2582 enable_summarization: true,
2583 enable_conflict_detection: true,
2584 enable_quality_assessment: false, enable_merging: true,
2586 summarize_threshold: 5,
2587 quality_threshold: 30.0,
2588 merge_similarity_threshold: 0.6,
2589 }
2590 }
2591}
2592
2593impl AiMemoryConfig {
2594 pub fn minimal() -> Self {
2596 Self {
2597 enable_summarization: false,
2598 enable_conflict_detection: false,
2599 enable_quality_assessment: false,
2600 enable_merging: false,
2601 summarize_threshold: 10,
2602 quality_threshold: 20.0,
2603 merge_similarity_threshold: 0.8,
2604 }
2605 }
2606
2607 pub fn aggressive() -> Self {
2609 Self {
2610 enable_summarization: true,
2611 enable_conflict_detection: true,
2612 enable_quality_assessment: true,
2613 enable_merging: true,
2614 summarize_threshold: 3,
2615 quality_threshold: 40.0,
2616 merge_similarity_threshold: 0.5,
2617 }
2618 }
2619
2620 pub fn from_env() -> Self {
2622 let enable_all = std::env::var("MEMORY_AI_ALL")
2623 .map(|v| v == "true" || v == "1")
2624 .unwrap_or(false);
2625
2626 if enable_all {
2627 return Self::aggressive();
2628 }
2629
2630 Self {
2631 enable_summarization: std::env::var("MEMORY_AI_SUMMARY")
2632 .map(|v| v != "false" && v != "0")
2633 .unwrap_or(true),
2634 enable_conflict_detection: std::env::var("MEMORY_AI_CONFLICT")
2635 .map(|v| v != "false" && v != "0")
2636 .unwrap_or(true),
2637 enable_quality_assessment: std::env::var("MEMORY_AI_QUALITY")
2638 .map(|v| v == "true" || v == "1")
2639 .unwrap_or(false),
2640 enable_merging: std::env::var("MEMORY_AI_MERGE")
2641 .map(|v| v != "false" && v != "0")
2642 .unwrap_or(true),
2643 summarize_threshold: std::env::var("MEMORY_SUMMARY_THRESHOLD")
2644 .and_then(|v| v.parse().map_err(|_| std::env::VarError::NotPresent))
2645 .unwrap_or(5),
2646 quality_threshold: std::env::var("MEMORY_QUALITY_THRESHOLD")
2647 .and_then(|v| v.parse().map_err(|_| std::env::VarError::NotPresent))
2648 .unwrap_or(30.0),
2649 merge_similarity_threshold: std::env::var("MEMORY_MERGE_THRESHOLD")
2650 .and_then(|v| v.parse().map_err(|_| std::env::VarError::NotPresent))
2651 .unwrap_or(0.6),
2652 }
2653 }
2654}
2655
2656impl AutoMemory {
2658 pub async fn add_memory_with_ai_conflict(
2660 &mut self,
2661 category: MemoryCategory,
2662 content: String,
2663 source_session: Option<String>,
2664 processor: Option<&AiMemoryProcessor>,
2665 ) -> Result<()> {
2666 if self.has_similar(&content) {
2668 return Ok(());
2669 }
2670
2671 let new_entry = MemoryEntry::new(category, content.clone(), source_session);
2673
2674 let potential_conflicts: Vec<(usize, &MemoryEntry)> = self.entries
2676 .iter()
2677 .enumerate()
2678 .filter(|(_, e)| {
2679 e.category == category &&
2680 Self::calculate_similarity(&e.content.to_lowercase(), &content.to_lowercase()) > 0.3
2681 })
2682 .collect();
2683
2684 if let Some(processor) = processor {
2685 for (idx, old_entry) in potential_conflicts {
2687 let result = processor.detect_conflict(old_entry, &new_entry).await?;
2688
2689 if result.should_replace {
2690 log::debug!("AI detected conflict: {} -> replacing '{}' with '{}'",
2691 result.conflict_type, old_entry.content, content);
2692 self.entries.remove(idx);
2693 self.invalidate_index();
2694 break;
2695 }
2696 }
2697 } else {
2698 if let Some(conflict_idx) = self.find_conflict(&content, category) {
2700 self.entries.remove(conflict_idx);
2701 self.invalidate_index();
2702 }
2703 }
2704
2705 self.add(new_entry);
2706 Ok(())
2707 }
2708
2709 pub async fn assess_quality_with_ai(
2711 &mut self,
2712 processor: &AiMemoryProcessor,
2713 config: &AiMemoryConfig,
2714 ) -> Result<usize> {
2715 if !config.enable_quality_assessment {
2716 return Ok(0);
2717 }
2718
2719 let indices_to_assess: Vec<usize> = self.entries
2721 .iter()
2722 .enumerate()
2723 .filter(|(_, entry)| !entry.is_manual)
2724 .map(|(i, _)| i)
2725 .collect();
2726
2727 let mut to_remove: Vec<usize> = Vec::new();
2729 let mut importance_updates: Vec<(usize, f64)> = Vec::new();
2730
2731 for i in indices_to_assess {
2732 let entry = &self.entries[i];
2733 let result = processor.assess_quality(entry).await?;
2734
2735 if !result.should_keep || result.quality_score < config.quality_threshold {
2736 log::debug!("AI quality assessment: removing '{}' (score: {:.1}, reason: {})",
2737 entry.content, result.quality_score, result.reason);
2738 to_remove.push(i);
2739 } else {
2740 importance_updates.push((i, result.quality_score));
2742 }
2743 }
2744
2745 for (i, score) in importance_updates {
2747 self.entries[i].importance = score;
2748 }
2749
2750 let removed_count = to_remove.len();
2751
2752 for idx in to_remove.into_iter().rev() {
2754 self.entries.remove(idx);
2755 }
2756
2757 if removed_count > 0 {
2758 self.invalidate_index();
2759 self.prune();
2760 }
2761
2762 Ok(removed_count)
2763 }
2764
2765 pub async fn merge_similar_with_ai(
2767 &mut self,
2768 processor: &AiMemoryProcessor,
2769 config: &AiMemoryConfig,
2770 ) -> Result<usize> {
2771 if !config.enable_merging || self.entries.len() < 2 {
2772 return Ok(0);
2773 }
2774
2775 let mut merged_count = 0;
2776 let mut to_remove: Vec<usize> = Vec::new();
2777 let mut new_entries: Vec<MemoryEntry> = Vec::new();
2778
2779 let mut processed: std::collections::HashSet<usize> = std::collections::HashSet::new();
2781
2782 for i in 0..self.entries.len() {
2783 if processed.contains(&i) {
2784 continue;
2785 }
2786
2787 let mut similar_group: Vec<usize> = vec![i];
2789
2790 for j in (i + 1)..self.entries.len() {
2791 if processed.contains(&j) {
2792 continue;
2793 }
2794
2795 let sim = Self::calculate_similarity(
2796 &self.entries[i].content.to_lowercase(),
2797 &self.entries[j].content.to_lowercase(),
2798 );
2799
2800 if sim >= config.merge_similarity_threshold {
2801 similar_group.push(j);
2802 }
2803 }
2804
2805 if similar_group.len() >= 2 {
2807 let group_entries: Vec<&MemoryEntry> = similar_group
2808 .iter()
2809 .map(|&idx| &self.entries[idx])
2810 .collect();
2811
2812 if let Some(merged) = processor.merge_memories(&group_entries).await? {
2813 log::debug!("AI merged {} memories into: '{}'",
2814 similar_group.len(), merged.content);
2815
2816 new_entries.push(merged);
2817 to_remove.extend(similar_group.iter().copied());
2818 processed.extend(similar_group.iter().copied());
2819 merged_count += similar_group.len() - 1;
2820 }
2821 }
2822 }
2823
2824 let mut sorted_remove: Vec<usize> = to_remove;
2826 sorted_remove.sort();
2827 for idx in sorted_remove.into_iter().rev() {
2828 self.entries.remove(idx);
2829 }
2830
2831 for entry in new_entries {
2833 self.entries.push(entry);
2834 }
2835
2836 if merged_count > 0 {
2837 self.invalidate_index();
2838 self.prune();
2839 }
2840
2841 Ok(merged_count)
2842 }
2843
2844 pub async fn generate_ai_summary(
2846 &self,
2847 max_entries: usize,
2848 processor: Option<&AiMemoryProcessor>,
2849 config: Option<&AiMemoryConfig>,
2850 ) -> Result<String> {
2851 if self.entries.is_empty() {
2852 return Ok(String::new());
2853 }
2854
2855 let default_config = AiMemoryConfig::default();
2856 let config = config.unwrap_or(&default_config);
2857
2858 if config.enable_summarization
2860 && let Some(processor) = processor
2861 && self.entries.len() >= config.summarize_threshold
2862 {
2863
2864 let mut by_category: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
2866 for entry in &self.entries {
2867 by_category.entry(entry.category).or_default().push(entry);
2868 }
2869
2870 let mut summary = String::from("【跨会话记忆 (AI摘要)】\n\n");
2871
2872 for (cat, entries) in by_category {
2873 if entries.is_empty() {
2874 continue;
2875 }
2876
2877 let top_entries: Vec<&MemoryEntry> = entries
2879 .iter()
2880 .take(max_entries.min(entries.len()))
2881 .copied()
2882 .collect();
2883
2884 if let Some(ai_summary) = processor.summarize_memories(&top_entries).await? {
2886 summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
2887 summary.push_str(&format!(" {}\n\n", ai_summary.content));
2888 } else {
2889 summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
2891 for entry in top_entries {
2892 summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
2893 }
2894 summary.push('\n');
2895 }
2896 }
2897
2898 Ok(summary)
2899 } else {
2900 Ok(self.generate_contextual_summary("", max_entries))
2902 }
2903 }
2904}
2905
2906
2907
2908pub fn detect_memories_fallback(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
2916 let mut entries = Vec::new();
2917 let text_lower = text.to_lowercase();
2918
2919 let patterns: Vec<(MemoryCategory, Vec<&str>)> = vec![
2921 (MemoryCategory::Decision, vec![
2922 "最终决定", "决定采用", "我们决定", "最终选择", "经过讨论决定",
2924 "项目决定", "团队决定", "最终选定", "确定使用",
2925 "we decided", "final decision", "decided to use", "chose to use",
2927 "team decided", "final choice", "ultimately chose",
2928 ]),
2929 (MemoryCategory::Preference, vec![
2930 "我喜欢", "我最喜欢", "我特别喜欢", "我非常喜欢",
2933 "我偏好", "我偏好使用", "个人偏好",
2935 "我习惯", "我习惯用", "我的习惯", "通常我会",
2937 "我倾向于", "更倾向于", "我偏爱",
2939 "i like", "i prefer", "my favorite", "i love",
2941 "i prefer using", "my preference is", "i usually use",
2942 "i tend to use", "my habit is", "i really like",
2943 ]),
2944 (MemoryCategory::Solution, vec![
2945 "通过修改", "通过添加", "通过删除", "解决方案是",
2947 "修复方法是", "解决方法是", "根本原因是",
2948 "修复了问题", "解决了问题", "关键修复",
2949 "fixed by", "solved by", "solution is", "root cause is",
2951 "the fix was", "fixed the issue",
2952 ]),
2953 (MemoryCategory::Finding, vec![
2954 "关键发现", "重要发现", "我注意到", "发现问题是",
2956 "问题根源是", "问题出在", "主要原因是",
2957 "key finding", "important discovery", "found that the",
2959 "the issue is", "root cause", "discovered that",
2960 ]),
2961 (MemoryCategory::Technical, vec![
2962 "技术栈是", "框架使用", "依赖的是", "构建工具是",
2964 "数据库是", "后端框架", "前端框架",
2965 "tech stack is", "using framework", "built with",
2967 "database is", "backend uses", "frontend uses",
2968 ]),
2969 (MemoryCategory::Structure, vec![
2970 "入口文件是", "主文件位于", "核心模块是", "项目结构是",
2972 "主要目录", "核心目录", "重要文件是",
2973 "entry point is", "main file is", "core module is",
2975 "project structure", "main directory",
2976 ]),
2977 ];
2978
2979 for (category, keywords) in patterns {
2980 for keyword in keywords {
2981 if text_lower.contains(keyword) {
2982 let content = extract_memory_content(text, keyword);
2984 if !content.is_empty() && content.len() >= MIN_MEMORY_CONTENT_LENGTH {
2986 let entry = MemoryEntry::new(
2987 category,
2988 content,
2989 session_id.map(|s| s.to_string()),
2990 );
2991 entries.push(entry);
2992 }
2993 }
2994 }
2995 }
2996
2997 deduplicate_entries(entries)
2999}
3000
/// Detects memories in `text` using the rule-based detector only.
///
/// This is currently a plain alias for [`detect_memories_fallback`] and never
/// calls an AI extractor; `detect_memories_smart` is the AI-aware entry point.
pub fn detect_memories_from_text(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
    detect_memories_fallback(text, session_id)
}
3006
3007pub async fn detect_memories_smart(
3011 text: &str,
3012 session_id: Option<&str>,
3013 extractor: Option<&dyn MemoryExtractor>,
3014) -> Vec<MemoryEntry> {
3015 let mode = AiDetectionMode::from_env();
3016
3017 if mode.should_use_ai() && extractor.is_some() {
3018 match detect_memories_with_ai(text, session_id, extractor).await {
3020 Ok(entries) if !entries.is_empty() => {
3021 log::debug!("AI memory detection found {} entries", entries.len());
3022 return entries;
3023 }
3024 Ok(_) => {
3025 log::debug!("AI detection returned empty, falling back to rules");
3026 }
3027 Err(e) => {
3028 log::warn!("AI memory detection failed: {}, falling back to rules", e);
3029 }
3030 }
3031 }
3032
3033 detect_memories_fallback(text, session_id)
3035}
3036
3037pub async fn detect_memories_with_ai(
3040 text: &str,
3041 session_id: Option<&str>,
3042 extractor: Option<&dyn MemoryExtractor>,
3043) -> Result<Vec<MemoryEntry>> {
3044 if let Some(ai_extractor) = extractor {
3045 match ai_extractor.extract(text, session_id).await {
3047 Ok(entries) if !entries.is_empty() => {
3048 return Ok(entries);
3049 }
3050 Ok(_) => {
3051 }
3053 Err(_) => {
3054 }
3056 }
3057 }
3058
3059 Ok(detect_memories_fallback(text, session_id))
3061}
3062
3063fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
3066 if entries.is_empty() {
3067 return entries;
3068 }
3069
3070 let mut sorted = entries;
3072 sorted.sort_by(|a, b| b.content.len().cmp(&a.content.len()));
3073
3074 let mut unique: Vec<MemoryEntry> = Vec::new();
3076 for entry in sorted {
3077 let entry_lower = entry.content.to_lowercase();
3078
3079 let is_duplicate = unique.iter().any(|existing| {
3081 let existing_lower = existing.content.to_lowercase();
3082
3083 if existing_lower == entry_lower {
3085 return true;
3086 }
3087
3088 let similarity = calculate_similarity(&existing_lower, &entry_lower);
3090 similarity >= 0.8
3091 });
3092
3093 if !is_duplicate {
3094 unique.push(entry);
3095 }
3096
3097 if unique.len() >= MAX_DETECTED_ENTRIES {
3099 break;
3100 }
3101 }
3102
3103 unique
3104}
3105
3106fn extract_memory_content(text: &str, keyword: &str) -> String {
3109 let text_lower = text.to_lowercase();
3110 let keyword_lower = keyword.to_lowercase();
3111
3112 let pos = match text_lower.find(&keyword_lower) {
3114 Some(p) => p,
3115 None => return String::new(),
3116 };
3117
3118 let sentence_end_markers: &[char] = &['.', '!', '?', '。', '!', '?', '\n'];
3121 let sentence_start_markers: &[char] = &['\n'];
3122
3123 let start = text[..pos].rfind(sentence_start_markers)
3126 .map(|i| {
3127 match text[i..].char_indices().nth(1) {
3129 Some((next_idx, _)) => i + next_idx,
3130 None => pos,
3131 }
3132 })
3133 .unwrap_or_else(|| {
3134 text[..pos].rfind(sentence_end_markers)
3137 .map(|i| {
3138 match text[i..].char_indices().nth(1) {
3139 Some((next_idx, _)) => i + next_idx,
3140 None => pos,
3141 }
3142 })
3143 .unwrap_or(0)
3144 });
3145
3146 let end = text[pos..].find(sentence_end_markers)
3148 .map(|i| {
3149 let marker_pos = pos + i;
3150 match text[marker_pos..].char_indices().nth(1) {
3152 Some((next_idx, _)) => marker_pos + next_idx,
3153 None => text.len(),
3154 }
3155 })
3156 .unwrap_or_else(|| {
3157 let max_end = (pos + MAX_MEMORY_CONTENT_LENGTH).min(text.len());
3159 let mut boundary = max_end;
3161 while boundary > pos && !text.is_char_boundary(boundary) {
3162 boundary -= 1;
3163 }
3164 boundary
3165 });
3166
3167 if start >= end || start > text.len() || end > text.len() {
3169 return String::new();
3170 }
3171
3172 let content = text[start..end].trim();
3173
3174 if is_low_quality_memory(content) {
3176 return String::new();
3177 }
3178
3179 let trimmed = content.trim_start();
3182 if let Some(first_char) = trimmed.chars().next() {
3183 if first_char.is_lowercase() && first_char > '\u{4E00}' {
3185 return String::new();
3187 }
3188 }
3189
3190 if content.len() > MAX_MEMORY_CONTENT_LENGTH {
3192 let truncation_point = content[..MAX_MEMORY_CONTENT_LENGTH]
3194 .rfind(sentence_end_markers)
3195 .map(|i| i + 1) .unwrap_or(MAX_MEMORY_CONTENT_LENGTH - 3);
3197 truncate_str(content, truncation_point)
3198 } else {
3199 content.to_string()
3200 }
3201}
3202
3203fn is_low_quality_memory(content: &str) -> bool {
3206 if content.len() < MIN_MEMORY_CONTENT_LENGTH {
3208 return true;
3209 }
3210
3211 let formatting_chars = ['│', '├', '└', '┌', '┐', '─', '═', '║', '╔', '╗', '╚', '╝'];
3213 if content.chars().any(|c| formatting_chars.contains(&c)) {
3214 return true;
3215 }
3216
3217 let first_char = content.chars().next().unwrap_or(' ');
3219 if !first_char.is_alphanumeric() && !first_char.is_ascii_punctuation() && first_char > '\u{FF}' {
3220 return true; }
3222
3223 if content.contains("【自动记忆摘要】") || content.contains("[ACCUMULATED MEMORY]") ||
3225 content.contains("记忆统计") || content.contains("memory.json") ||
3226 content.contains("Debug Report") || content.contains("诊断报告") {
3227 return true;
3228 }
3229
3230 if (content.starts_with("- ") || content.starts_with("* ") || content.starts_with("• "))
3232 && content.len() < 30 {
3233 return true;
3234 }
3235
3236 let alpha_count = content.chars().filter(|c| c.is_alphabetic()).count();
3238 let total_count = content.chars().count();
3239 if total_count > 0 && alpha_count < total_count / 4 {
3240 return true;
3241 }
3242
3243 if content.starts_with("rs**") || content.starts_with("rs:") ||
3246 content.starts_with("fn ") || content.starts_with("pub fn") ||
3247 content.starts_with("let ") || content.starts_with("use ") {
3248 return true;
3249 }
3250
3251 let trimmed = content.trim();
3254 if let Some(second_char) = trimmed.chars().nth(1) {
3255 let first = trimmed.chars().next().unwrap_or(' ');
3256 if !first.is_alphanumeric() && second_char.is_lowercase() && second_char > '\u{4E00}' {
3258 return true;
3259 }
3260 }
3261
3262 if content.len() < 25 && (
3265 content.contains("好的") || content.contains("好的,") ||
3266 content.contains("可以") || content.contains("没问题")
3267 ) {
3268 return true;
3269 }
3270
3271 let punct_count = content.chars().filter(|&c|
3273 c == '.' || c == ',' || c == '!' || c == '?' || c == '。' || c == ','
3274 ).count();
3275 if punct_count > content.len() / 5 {
3276 return true;
3277 }
3278
3279 false
3280}
3281
/// Outcome of [`summarize_up_to`]: the rewritten message history plus
/// bookkeeping about what was collapsed.
#[derive(Debug, Clone)]
pub struct RewindResult {
    /// Number of messages in the input history.
    pub original_count: usize,
    /// Number of messages in `new_messages`.
    pub new_count: usize,
    /// Index passed to `summarize_up_to`; messages before it were summarized.
    pub rewind_index: usize,
    /// Summary text that replaced the earlier messages. `summarize_up_to`
    /// leaves it `None` only when called with index 0.
    pub summary: Option<String>,
    /// Rewritten history. Normally a summary message followed by the kept
    /// messages; the unchanged input when `rewind_index` is 0.
    pub new_messages: Vec<Message>,
}
3300
3301pub async fn summarize_up_to(
3304 messages: &[Message],
3305 index: usize,
3306 compressor: Option<&dyn crate::compress::Compressor>,
3307) -> Result<RewindResult> {
3308 if index >= messages.len() {
3309 anyhow::bail!("rewind index {} out of bounds (messages: {})", index, messages.len());
3310 }
3311
3312 if index == 0 {
3313 return Ok(RewindResult {
3315 original_count: messages.len(),
3316 new_count: messages.len(),
3317 rewind_index: 0,
3318 summary: None,
3319 new_messages: messages.to_vec(),
3320 });
3321 }
3322
3323 let to_summarize = &messages[..index];
3324 let to_keep = &messages[index..];
3325
3326 let summary = if let Some(comp) = compressor {
3328 let segment = comp.summarize(to_summarize, &crate::compress::CompressionConfig::default()).await?;
3330 Some(segment.summary)
3331 } else {
3332 Some(generate_simple_summary(to_summarize))
3334 };
3335
3336 let summary_msg = create_summary_message(&summary, to_summarize.len());
3338
3339 let new_messages: Vec<Message> = std::iter::once(summary_msg)
3341 .chain(to_keep.iter().cloned())
3342 .collect();
3343
3344 let new_count = new_messages.len();
3345
3346 Ok(RewindResult {
3347 original_count: messages.len(),
3348 new_count,
3349 rewind_index: index,
3350 summary,
3351 new_messages,
3352 })
3353}
3354
3355fn create_summary_message(summary: &Option<String>, original_count: usize) -> Message {
3357 let content = match summary {
3358 Some(s) => format!("[对话摘要 - 原 {} 条消息]\n\n{}", original_count, s),
3359 None => format!("[对话摘要 - 原 {} 条消息已压缩]", original_count),
3360 };
3361
3362 Message {
3363 role: crate::providers::Role::User,
3364 content: crate::providers::MessageContent::Text(content),
3365 }
3366}
3367
3368fn generate_simple_summary(messages: &[Message]) -> String {
3370 let mut parts: Vec<String> = Vec::new();
3371
3372 for msg in messages {
3374 if msg.role == crate::providers::Role::User {
3375 let text = match &msg.content {
3376 crate::providers::MessageContent::Text(t) => t,
3377 _ => continue,
3378 };
3379 let first_line = text.lines().next().unwrap_or("");
3381 if first_line.len() > 20 {
3382 parts.push(truncate_str(first_line, 100));
3383 }
3384 }
3385 }
3386
3387 if parts.is_empty() {
3388 "对话已压缩".to_string()
3389 } else if parts.len() <= 5 {
3390 parts.join(" | ")
3391 } else {
3392 format!("{} ... (共 {} 个话题)", parts[0], parts.len())
3393 }
3394}
3395
/// Stateless helpers for comparing embedding vectors.
pub struct SemanticUtils;

impl SemanticUtils {
    /// Cosine similarity of two equal-length vectors: `dot(a, b) / (|a| * |b|)`.
    ///
    /// Returns `0.0` when the lengths differ, the vectors are empty, or either
    /// vector has zero magnitude.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let comparable = !a.is_empty() && a.len() == b.len();
        if !comparable {
            return 0.0;
        }

        let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (len_a, len_b) = (magnitude(a), magnitude(b));
        if len_a == 0.0 || len_b == 0.0 {
            return 0.0;
        }

        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot / (len_a * len_b)
    }
}
3431
3432
3433pub struct TfIdfSearch {
3466 doc_word_freq: HashMap<String, HashMap<String, f32>>,
3468 total_docs: usize,
3470 idf_cache: HashMap<String, f32>,
3472}
3473
3474impl TfIdfSearch {
3475 pub fn new() -> Self {
3477 Self {
3478 doc_word_freq: HashMap::new(),
3479 total_docs: 0,
3480 idf_cache: HashMap::new(),
3481 }
3482 }
3483
3484 pub fn index(&mut self, memory: &AutoMemory) {
3486 self.clear();
3487 self.total_docs = memory.entries.len();
3488
3489 for entry in &memory.entries {
3490 let words = self.tokenize(&entry.content);
3491 let word_freq = self.compute_word_freq(&words);
3492 self.doc_word_freq.insert(entry.content.clone(), word_freq);
3493 }
3494
3495 self.compute_idf();
3497 }
3498
3499 fn tokenize(&self, text: &str) -> Vec<String> {
3502 let lower = text.to_lowercase();
3503 let mut tokens = Vec::new();
3504
3505 for word in lower.split_whitespace() {
3507 let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
3508 if trimmed.len() > 1 {
3509 tokens.push(trimmed.to_string());
3510 }
3511
3512 let chars: Vec<char> = trimmed.chars().collect();
3514 let has_cjk = chars.iter().any(|c| Self::is_cjk(*c));
3515
3516 if has_cjk {
3517 for c in &chars {
3519 if Self::is_cjk(*c) {
3520 tokens.push(c.to_string());
3521 }
3522 }
3523 for window in chars.windows(2) {
3525 if Self::is_cjk(window[0]) || Self::is_cjk(window[1]) {
3526 tokens.push(window.iter().collect::<String>());
3527 }
3528 }
3529 }
3530 }
3531
3532 tokens
3533 }
3534
3535 fn is_cjk(c: char) -> bool {
3537 matches!(c,
3538 '\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}' | '\u{F900}'..='\u{FAFF}' | '\u{3000}'..='\u{303F}' | '\u{3040}'..='\u{309F}' | '\u{30A0}'..='\u{30FF}' )
3545 }
3546
3547 fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
3549 let total = words.len() as f32;
3550 let mut freq = HashMap::new();
3551
3552 for word in words {
3553 *freq.entry(word.clone()).or_insert(0.0) += 1.0;
3554 }
3555
3556 for (_, count) in freq.iter_mut() {
3558 *count /= total;
3559 }
3560
3561 freq
3562 }
3563
3564 fn compute_idf(&mut self) {
3566 let mut word_doc_count: HashMap<String, usize> = HashMap::new();
3568
3569 for word_freq in &self.doc_word_freq {
3570 for word in word_freq.1.keys() {
3571 *word_doc_count.entry(word.clone()).or_insert(0) += 1;
3572 }
3573 }
3574
3575 for (word, count) in word_doc_count {
3577 let idf = (self.total_docs as f32 / count as f32).ln();
3578 self.idf_cache.insert(word, idf);
3579 }
3580 }
3581
3582 pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
3584 let query_words = self.tokenize(query);
3585 let query_freq = self.compute_word_freq(&query_words);
3586
3587 let mut results: Vec<(String, f32)> = Vec::new();
3588
3589 for (doc, doc_freq) in &self.doc_word_freq {
3590 let similarity = self.compute_similarity(&query_freq, doc_freq);
3592
3593 if similarity > 0.0 {
3594 results.push((doc.clone(), similarity));
3595 }
3596 }
3597
3598 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
3600
3601 if let Some(max) = limit {
3603 results.into_iter().take(max).collect()
3604 } else {
3605 results
3606 }
3607 }
3608
3609 fn compute_similarity(&self, query_freq: &HashMap<String, f32>, doc_freq: &HashMap<String, f32>) -> f32 {
3611 let mut similarity = 0.0;
3612
3613 for (word, tf_query) in query_freq {
3614 if let Some(tf_doc) = doc_freq.get(word)
3615 && let Some(idf) = self.idf_cache.get(word) {
3616 similarity += tf_query * idf * tf_doc * idf;
3618 }
3619 }
3620
3621 similarity
3622 }
3623
3624 pub fn clear(&mut self) {
3626 self.doc_word_freq.clear();
3627 self.idf_cache.clear();
3628 self.total_docs = 0;
3629 }
3630}
3631
3632impl Default for TfIdfSearch {
3633 fn default() -> Self {
3634 Self::new()
3635 }
3636}
3637
3638#[cfg(test)]
3639mod tests {
3640 use super::*;
3641
3642 #[test]
3643 fn test_memory_entry_creation() {
3644 let entry = MemoryEntry::new(
3645 MemoryCategory::Decision,
3646 "Decided to use PostgreSQL for database".to_string(),
3647 Some("session-123".to_string()),
3648 );
3649 assert_eq!(entry.category, MemoryCategory::Decision);
3650 assert_eq!(entry.importance, DEFAULT_IMPORTANCE_DECISION); assert!(!entry.is_manual);
3652 }
3653
3654 #[test]
3655 fn test_memory_reference_increase() {
3656 let mut entry = MemoryEntry::new(
3657 MemoryCategory::Finding,
3658 "API endpoint is at /api/v2".to_string(),
3659 None,
3660 );
3661 assert_eq!(entry.importance, DEFAULT_IMPORTANCE_FINDING); entry.mark_referenced();
3663 assert_eq!(entry.importance, 57.0); entry.mark_referenced();
3667 entry.mark_referenced();
3668 assert_eq!(entry.importance, 61.0); }
3670
3671 #[test]
3672 fn test_auto_memory_add_and_prune() {
3673 let mut memory = AutoMemory::new();
3674 memory.max_entries = 5;
3675
3676 for i in 0..10 {
3677 memory.add(MemoryEntry::new(
3678 MemoryCategory::Technical,
3679 format!("Note {}", i),
3680 None,
3681 ));
3682 }
3683
3684 assert!(memory.entries.len() <= memory.max_entries);
3686 }
3687
3688 #[test]
3689 fn test_duplicate_detection() {
3690 let mut memory = AutoMemory::new();
3691 memory.add_memory(
3692 MemoryCategory::Decision,
3693 "Use PostgreSQL".to_string(),
3694 None,
3695 );
3696
3697 memory.add_memory(
3699 MemoryCategory::Decision,
3700 "Use PostgreSQL".to_string(),
3701 None,
3702 );
3703
3704 assert_eq!(memory.entries.len(), 1);
3705 }
3706
3707 #[test]
3708 fn test_memory_detection() {
3709 let text = "我们决定采用 React 作为前端框架";
3711 let entries = detect_memories_from_text(text, None);
3712 assert!(!entries.is_empty());
3713 assert_eq!(entries[0].category, MemoryCategory::Decision);
3714
3715 let text2 = "解决了认证问题,解决方案是通过添加 token refresh 机制";
3717 let entries2 = detect_memories_from_text(text2, None);
3718 assert!(!entries2.is_empty());
3719 assert_eq!(entries2[0].category, MemoryCategory::Solution);
3720
3721 let text3 = "我偏好使用 TypeScript 进行开发";
3723 let entries3 = detect_memories_from_text(text3, None);
3724 assert!(!entries3.is_empty());
3725 assert_eq!(entries3[0].category, MemoryCategory::Preference);
3726 }
3727
3728 #[test]
3729 fn test_category_importance() {
3730 assert!(MemoryCategory::Decision.default_importance() > MemoryCategory::Structure.default_importance());
3731 assert!(MemoryCategory::Solution.default_importance() > MemoryCategory::Technical.default_importance());
3732 }
3733
3734 #[test]
3735 fn test_top_n_entries() {
3736 let mut memory = AutoMemory::new();
3737
3738 memory.add(MemoryEntry::new(MemoryCategory::Decision, "Decision 1".into(), None));
3740 memory.add(MemoryEntry::new(MemoryCategory::Finding, "Finding 1".into(), None));
3741 memory.add(MemoryEntry::new(MemoryCategory::Structure, "Structure 1".into(), None));
3742
3743 let top = memory.top_n(2);
3744 assert_eq!(top.len(), 2);
3745 assert_eq!(top[0].category, MemoryCategory::Decision); }
3747
3748 #[test]
3749 fn test_similarity_calculation() {
3750 let sim = AutoMemory::calculate_similarity("hello world", "hello world");
3752 assert_eq!(sim, 1.0);
3753
3754 let sim = AutoMemory::calculate_similarity("hello world", "foo bar");
3756 assert_eq!(sim, 0.0);
3757
3758 let sim = AutoMemory::calculate_similarity("hello world", "hello there");
3760 assert!(sim > 0.0 && sim < 1.0);
3761
3762 let sim = AutoMemory::calculate_similarity("", "hello");
3764 assert_eq!(sim, 0.0);
3765 }
3766
3767 #[test]
3768 fn test_similarity_threshold() {
3769 let mut memory = AutoMemory::new();
3770
3771 memory.add(MemoryEntry::new(
3773 MemoryCategory::Decision,
3774 "We decided to use PostgreSQL for our database system".to_string(),
3775 None,
3776 ));
3777
3778 memory.add_memory(
3780 MemoryCategory::Decision,
3781 "We decided to use PostgreSQL for our database backend".to_string(),
3782 None,
3783 );
3784
3785 assert_eq!(memory.entries.len(), 1);
3787 }
3788
3789 #[test]
3790 fn test_short_content_skipped() {
3791 let mut memory = AutoMemory::new();
3792
3793 memory.add(MemoryEntry::new(
3795 MemoryCategory::Technical,
3796 "short".to_string(), None,
3798 ));
3799
3800 memory.add_memory(
3802 MemoryCategory::Technical,
3803 "brief".to_string(),
3804 None,
3805 );
3806
3807 assert_eq!(memory.entries.len(), 2);
3808 }
3809
3810 #[test]
3811 fn test_prune_preserves_manual() {
3812 let mut memory = AutoMemory::new();
3813 memory.max_entries = 3;
3814
3815 let mut manual = MemoryEntry::manual(MemoryCategory::Decision, "Manual decision".into());
3817 manual.importance = 10.0; memory.add(manual);
3819
3820 for i in 0..5 {
3822 let entry = MemoryEntry::new(
3823 MemoryCategory::Decision,
3824 format!("Auto decision {}", i),
3825 None,
3826 );
3827 memory.add(entry);
3828 }
3829
3830 assert!(memory.entries.iter().any(|e| e.is_manual));
3832 assert!(memory.entries.len() <= memory.max_entries);
3833 }
3834
3835 #[test]
3836 fn test_deduplicate_entries() {
3837 let entries = vec![
3839 MemoryEntry::new(MemoryCategory::Decision, "We chose PostgreSQL database system for our backend".into(), None),
3840 MemoryEntry::new(MemoryCategory::Decision, "We chose PostgreSQL database system backend".into(), None),
3841 MemoryEntry::new(MemoryCategory::Decision, "Using Redis for caching layer".into(), None),
3842 ];
3843
3844 let deduped = deduplicate_entries(entries);
3845
3846 assert!(deduped.len() >= 1);
3848 assert!(deduped.len() <= 3);
3849
3850 let pg_entries: Vec<_> = deduped.iter()
3852 .filter(|e| e.content.to_lowercase().contains("postgresql"))
3853 .collect();
3854
3855 if pg_entries.len() == 1 {
3856 assert!(pg_entries[0].content.contains("backend"));
3859 }
3860 }
3861
3862 #[test]
3863 fn test_memory_detection_edge_cases() {
3864 let entries = detect_memories_from_text("", None);
3866 assert!(entries.is_empty());
3867
3868 let entries = detect_memories_from_text("决定", None);
3870 assert!(entries.is_empty());
3871
3872 let entries = detect_memories_from_text("使用", None);
3874 assert!(entries.is_empty());
3875
3876 let text = "我决定使用React,解决了性能问题通过添加缓存机制";
3878 let entries = detect_memories_from_text(text, None);
3879 assert!(entries.len() <= MAX_DETECTED_ENTRIES);
3880 }
3881
3882 #[test]
3883 fn test_importance_ceiling() {
3884 let mut entry = MemoryEntry::new(
3885 MemoryCategory::Decision,
3886 "Important decision".into(),
3887 None,
3888 );
3889
3890 assert_eq!(entry.importance, DEFAULT_IMPORTANCE_DECISION);
3892
3893 for _ in 0..20 {
3895 entry.mark_referenced();
3896 }
3897
3898 assert!(entry.importance <= MAX_IMPORTANCE_CEILING);
3900 }
3901
3902 #[test]
3903 fn test_time_decay() {
3904 let mut memory = AutoMemory::new();
3905 memory.min_importance = 30.0;
3906
3907 let mut manual = MemoryEntry::manual(MemoryCategory::Decision, "Manual entry".into());
3909 manual.importance = 50.0;
3910 memory.add(manual);
3911
3912 let mut old_entry = MemoryEntry::new(
3914 MemoryCategory::Technical,
3915 "Old technical note".into(),
3916 None,
3917 );
3918 old_entry.importance = 60.0;
3919 old_entry.last_referenced = Utc::now() - chrono::Duration::days(60);
3921 memory.add(old_entry);
3922
3923 let recent_entry = MemoryEntry::new(
3925 MemoryCategory::Finding,
3926 "Recent finding".into(),
3927 None,
3928 );
3929 memory.add(recent_entry);
3930
3931 memory.apply_time_decay();
3933
3934 let manual_entry = memory.entries.iter().find(|e| e.is_manual);
3936 assert!(manual_entry.is_some());
3937 assert_eq!(manual_entry.unwrap().importance, 50.0);
3938
3939 let recent = memory.entries.iter().find(|e| e.content.contains("Recent"));
3941 assert!(recent.is_some());
3942 assert!(recent.unwrap().importance >= DEFAULT_IMPORTANCE_FINDING); let old = memory.entries.iter().find(|e| e.content.contains("Old"));
3946 if let Some(old_entry) = old {
3947 assert!(old_entry.importance < 60.0);
3950 assert!(old_entry.importance >= memory.min_importance * 0.5);
3952 }
3953 }
3954
3955 #[test]
3956 fn test_parse_memory_response() {
3957 let json = r#"{"memories": [{"category": "decision", "content": "决定使用 PostgreSQL 作为数据库", "importance": 90}, {"category": "preference", "content": "我偏好 TypeScript 而非 JavaScript", "importance": 70}]}"#;
3959 let entries = parse_memory_response(json, None).unwrap();
3960 assert_eq!(entries.len(), 2);
3961
3962 let has_decision = entries.iter().any(|e| e.category == MemoryCategory::Decision);
3964 let has_preference = entries.iter().any(|e| e.category == MemoryCategory::Preference);
3965 assert!(has_decision);
3966 assert!(has_preference);
3967
3968 let decision_entry = entries.iter().find(|e| e.category == MemoryCategory::Decision);
3970 assert!(decision_entry.is_some());
3971 assert_eq!(decision_entry.unwrap().importance, 90.0);
3972
3973 let empty_json = r#"{"memories": []}"#;
3975 let empty_entries = parse_memory_response(empty_json, None).unwrap();
3976 assert!(empty_entries.is_empty());
3977
3978 let markdown_json = r#"```json
3980{"memories": [{"category": "solution", "content": "通过添加 middleware 修复认证问题", "importance": 85}]}
3981```"#;
3982 let markdown_entries = parse_memory_response(markdown_json, None).unwrap();
3983 assert_eq!(markdown_entries.len(), 1);
3984 assert_eq!(markdown_entries[0].category, MemoryCategory::Solution);
3985
3986 let unknown_json = r#"{"memories": [{"category": "unknown", "content": "This should be skipped content", "importance": 50}]}"#;
3988 let unknown_entries = parse_memory_response(unknown_json, None).unwrap();
3989 assert!(unknown_entries.is_empty());
3990
3991 let short_json = r#"{"memories": [{"category": "finding", "content": "short", "importance": 60}]}"#;
3993 let short_entries = parse_memory_response(short_json, None).unwrap();
3994 assert!(short_entries.is_empty());
3995 }
3996
3997 #[test]
3998 fn test_public_has_similar() {
3999 let mut memory = AutoMemory::new();
4000
4001 memory.add(MemoryEntry::new(
4003 MemoryCategory::Decision,
4004 "We decided to use PostgreSQL for our main database system".to_string(),
4005 None,
4006 ));
4007
4008 assert!(memory.has_similar("We decided to use PostgreSQL for our main database system"));
4010
4011 assert!(memory.has_similar("We decided to use PostgreSQL for our main database system backend"));
4015
4016 assert!(!memory.has_similar("We decided to use Redis for caching"));
4020
4021 assert!(!memory.has_similar("The project uses React for frontend"));
4023
4024 assert!(!memory.has_similar("short"));
4026 }
4027
4028 #[test]
4029 fn test_public_prune() {
4030 let mut memory = AutoMemory::new();
4031 memory.max_entries = 5;
4032 memory.min_importance = 30.0;
4033
4034 for i in 0..10 {
4036 memory.add(MemoryEntry::new(
4037 MemoryCategory::Technical,
4038 format!("Technical note number {} with sufficient length", i),
4039 None,
4040 ));
4041 }
4042
4043 memory.prune();
4045
4046 assert!(memory.entries.len() <= memory.max_entries);
4048 }
4049
4050 #[test]
4051 fn test_statistics() {
4052 let mut memory = AutoMemory::new();
4053
4054 memory.add(MemoryEntry::new(MemoryCategory::Decision, "Decision one with enough content".to_string(), None));
4056 memory.add(MemoryEntry::new(MemoryCategory::Preference, "Preference for TypeScript over JavaScript".to_string(), None));
4057 memory.add(MemoryEntry::manual(MemoryCategory::Technical, "Manual technical note".to_string()));
4058
4059 memory.entries[0].mark_referenced();
4061 memory.entries[0].mark_referenced();
4062 memory.entries[0].mark_referenced();
4063
4064 let stats = memory.generate_statistics();
4065
4066 assert_eq!(stats.total, 3);
4067 assert_eq!(stats.manual, 1);
4068 assert_eq!(stats.auto, 2);
4069 assert_eq!(stats.highly_referenced, 1); assert!(stats.by_category.contains_key(&MemoryCategory::Decision));
4071 assert!(stats.by_category.contains_key(&MemoryCategory::Preference));
4072 assert!(stats.by_category.contains_key(&MemoryCategory::Technical));
4073 assert!(stats.avg_importance > 0.0);
4074 }
4075
4076 #[test]
4077 fn test_memory_config() {
4078 let config = MemoryConfig::default();
4080 assert_eq!(config.max_entries, 100);
4081 assert_eq!(config.min_importance, 30.0);
4082 assert_eq!(config.decay_start_days, 30);
4083 assert_eq!(config.decay_rate, 0.5);
4084
4085 let minimal = MemoryConfig::minimal();
4087 assert_eq!(minimal.max_entries, 50);
4088 assert!(minimal.min_importance > config.min_importance);
4089
4090 let archival = MemoryConfig::archival();
4092 assert_eq!(archival.max_entries, 500);
4093 assert!(archival.min_importance < config.min_importance);
4094
4095 let custom = MemoryConfig::with_max_entries(200);
4097 assert_eq!(custom.max_entries, 200);
4098 assert_eq!(custom.min_importance, 30.0); }
4100
4101 #[test]
4102 fn test_auto_memory_with_config() {
4103 let config = MemoryConfig::minimal();
4104 let mut memory = AutoMemory::with_config(config);
4105
4106 assert_eq!(memory.max_entries, 50);
4107 assert_eq!(memory.min_importance, 50.0);
4108
4109 for i in 0..60 {
4111 memory.add(MemoryEntry::new(
4112 MemoryCategory::Technical,
4113 format!("Technical note {} with enough length for detection", i),
4114 None,
4115 ));
4116 }
4117
4118 assert!(memory.entries.len() <= 50);
4120 }
4121
4122 #[test]
4123 fn test_batch_add() {
4124 let mut memory = AutoMemory::new();
4125
4126 let entries: Vec<MemoryEntry> = vec![
4128 MemoryEntry::new(MemoryCategory::Decision, "First decision with sufficient content".into(), None),
4129 MemoryEntry::new(MemoryCategory::Finding, "First finding with sufficient content".into(), None),
4130 MemoryEntry::new(MemoryCategory::Solution, "First solution with sufficient content".into(), None),
4131 ];
4132
4133 memory.add_batch(entries);
4134 assert_eq!(memory.entries.len(), 3);
4135
4136 let duplicate_entries: Vec<MemoryEntry> = vec![
4138 MemoryEntry::new(MemoryCategory::Decision, "First decision with sufficient content".into(), None), MemoryEntry::new(MemoryCategory::Technical, "New technical note with sufficient content".into(), None),
4140 ];
4141
4142 memory.add_batch(duplicate_entries);
4143 assert_eq!(memory.entries.len(), 4); }
4145
4146 #[test]
4147 fn test_search_with_limit() {
4148 let mut memory = AutoMemory::new();
4149
4150 for i in 0..10 {
4152 memory.add(MemoryEntry::new(
4153 MemoryCategory::Technical,
4154 format!("PostgreSQL technical note {} with details", i),
4155 None,
4156 ));
4157 }
4158
4159 let all = memory.search("postgresql");
4161 assert_eq!(all.len(), 10);
4162
4163 let limited = memory.search_with_limit("postgresql", Some(5));
4165 assert_eq!(limited.len(), 5);
4166
4167 assert!(limited[0].importance >= limited[limited.len() - 1].importance);
4169 }
4170
4171 #[test]
4172 fn test_multi_keyword_search() {
4173 let mut memory = AutoMemory::new();
4174
4175 memory.add(MemoryEntry::new(MemoryCategory::Decision, "Decided to use PostgreSQL".into(), None));
4176 memory.add(MemoryEntry::new(MemoryCategory::Technical, "Using Redis for caching".into(), None));
4177 memory.add(MemoryEntry::new(MemoryCategory::Solution, "Fixed by adding middleware".into(), None));
4178
4179 let results = memory.search_multi(&["postgresql", "redis"]);
4181 assert_eq!(results.len(), 2);
4182
4183 let empty = memory.search_multi(&["mongodb"]);
4185 assert!(empty.is_empty());
4186 }
4187
4188 #[test]
4189 fn test_mark_referenced_with_increment() {
4190 let mut entry = MemoryEntry::new(
4191 MemoryCategory::Finding,
4192 "API endpoint location".into(),
4193 None,
4194 );
4195
4196 assert_eq!(entry.importance, DEFAULT_IMPORTANCE_FINDING); entry.mark_referenced_with_increment(5.0);
4200 assert_eq!(entry.importance, 60.0); entry.mark_referenced();
4204 assert_eq!(entry.importance, 62.0); for _ in 0..20 {
4208 entry.mark_referenced_with_increment(10.0);
4209 }
4210 assert!(entry.importance <= MAX_IMPORTANCE_CEILING);
4211 }
4212
4213 #[test]
4214 fn test_search_index() {
4215 let mut memory = AutoMemory::new();
4216
4217 for i in 0..20 {
4219 memory.add(MemoryEntry::new(
4220 MemoryCategory::Technical,
4221 format!("PostgreSQL technical note {} with sufficient content length", i),
4222 None,
4223 ));
4224 }
4225 for i in 0..10 {
4226 memory.add(MemoryEntry::new(
4227 MemoryCategory::Decision,
4228 format!("Redis decision {} with sufficient content for testing", i),
4229 None,
4230 ));
4231 }
4232
4233 memory.rebuild_index();
4235 assert!(memory.search_index.is_some());
4236
4237 let results = memory.search_fast("postgresql", Some(5));
4239 assert!(results.len() <= 5);
4240 assert!(results.iter().all(|e| e.content.to_lowercase().contains("postgresql")));
4241
4242 let multi_results = memory.search_multi_fast(&["postgresql", "redis"]);
4244 assert!(multi_results.len() > 0);
4245
4246 let tech_entries = memory.by_category_fast(MemoryCategory::Technical);
4248 assert_eq!(tech_entries.len(), 20);
4249
4250 let decision_entries = memory.by_category_fast(MemoryCategory::Decision);
4251 assert_eq!(decision_entries.len(), 10);
4252
4253 let top = memory.top_n_fast(5);
4255 assert_eq!(top.len(), 5);
4256 assert!(top[0].importance >= top[top.len() - 1].importance);
4258 }
4259
4260 #[test]
4261 fn test_index_auto_rebuild() {
4262 let mut memory = AutoMemory::new();
4263
4264 assert!(memory.search_index.is_none());
4266
4267 memory.add(MemoryEntry::new(
4269 MemoryCategory::Decision,
4270 "Test decision with sufficient content length".into(),
4271 None,
4272 ));
4273
4274 let results = memory.search_fast("test", None);
4275 assert!(results.len() > 0);
4276 assert!(memory.search_index.is_some()); memory.clear();
4280 assert!(memory.search_index.is_none());
4281
4282 memory.add(MemoryEntry::new(
4284 MemoryCategory::Finding,
4285 "New finding with sufficient content".into(),
4286 None,
4287 ));
4288 let _ = memory.search_fast("finding", None);
4289 assert!(memory.search_index.is_some());
4290 }
4291
4292 #[test]
4293 fn test_cosine_similarity() {
4294 let a = vec![1.0, 0.0, 0.0];
4296 let b = vec![1.0, 0.0, 0.0];
4297 assert_eq!(SemanticUtils::cosine_similarity(&a, &b), 1.0);
4298
4299 let a = vec![1.0, 0.0, 0.0];
4301 let b = vec![0.0, 1.0, 0.0];
4302 assert!((SemanticUtils::cosine_similarity(&a, &b) - 0.0).abs() < 0.001);
4303
4304 let a = vec![1.0, 0.0, 0.0];
4306 let b = vec![-1.0, 0.0, 0.0];
4307 assert!((SemanticUtils::cosine_similarity(&a, &b) - (-1.0)).abs() < 0.001);
4308
4309 let a = vec![1.0, 1.0, 0.0];
4311 let b = vec![1.0, 0.0, 0.0];
4312 let sim = SemanticUtils::cosine_similarity(&a, &b);
4313 assert!(sim > 0.0 && sim < 1.0);
4314
4315 let a: Vec<f32> = vec![];
4317 let b: Vec<f32> = vec![];
4318 assert_eq!(SemanticUtils::cosine_similarity(&a, &b), 0.0);
4319 }
4320
4321 #[test]
4322 fn test_tfidf_search() {
4323 let mut memory = AutoMemory::new();
4324
4325 memory.add(MemoryEntry::new(MemoryCategory::Decision, "使用 PostgreSQL 作为主数据库系统".into(), None));
4326 memory.add(MemoryEntry::new(MemoryCategory::Technical, "Redis 缓存配置为 10 个连接".into(), None));
4327 memory.add(MemoryEntry::new(MemoryCategory::Solution, "通过添加 middleware 修复认证问题".into(), None));
4328 memory.add(MemoryEntry::new(MemoryCategory::Finding, "数据库连接池设置为 20".into(), None));
4329
4330 let mut tfidf = TfIdfSearch::new();
4331 tfidf.index(&memory);
4332
4333 let results = tfidf.search("数据库", Some(5));
4335 assert!(!results.is_empty());
4336 assert!(results[0].0.contains("数据库"));
4338
4339 let results = tfidf.search("redis", Some(5));
4341 assert!(!results.is_empty());
4342 assert!(results[0].0.to_lowercase().contains("redis"));
4343
4344 let results = tfidf.search("mongodb", Some(5));
4346 assert!(results.is_empty());
4347 }
4348
4349 #[test]
4350 fn test_tfidf_ranking() {
4351 let mut memory = AutoMemory::new();
4352
4353 memory.add(MemoryEntry::new(MemoryCategory::Decision, "使用 PostgreSQL 数据库 作为主数据库".into(), None));
4355 memory.add(MemoryEntry::new(MemoryCategory::Technical, "数据库连接池配置".into(), None));
4356 memory.add(MemoryEntry::new(MemoryCategory::Solution, "修复了前端样式问题".into(), None));
4357
4358 let mut tfidf = TfIdfSearch::new();
4359 tfidf.index(&memory);
4360
4361 let results = tfidf.search("数据库", None);
4362
4363 if results.len() >= 2 {
4365 assert!(results[0].1 >= results[1].1);
4366 }
4367 }
4368
4369 #[test]
4370 fn test_conflict_detection() {
4371 let mut memory = AutoMemory::new();
4372
4373 memory.add_memory(
4375 MemoryCategory::Decision,
4376 "决定使用 PostgreSQL 作为主数据库".to_string(),
4377 None,
4378 );
4379 assert_eq!(memory.entries.len(), 1);
4380 assert!(memory.entries[0].content.contains("PostgreSQL"));
4381
4382 memory.add_memory(
4384 MemoryCategory::Decision,
4385 "决定使用 MySQL 作为主数据库".to_string(),
4386 None,
4387 );
4388
4389 assert_eq!(memory.entries.len(), 1);
4391 assert!(memory.entries[0].content.contains("MySQL"));
4392 }
4393
4394 #[test]
4395 fn test_conflict_with_change_signal() {
4396 let mut memory = AutoMemory::new();
4397
4398 memory.add_memory(
4400 MemoryCategory::Preference,
4401 "偏好使用 vim 编辑器".to_string(),
4402 None,
4403 );
4404 assert_eq!(memory.entries.len(), 1);
4405
4406 memory.add_memory(
4408 MemoryCategory::Preference,
4409 "改用 vscode 编辑器,不再使用 vim".to_string(),
4410 None,
4411 );
4412
4413 assert_eq!(memory.entries.len(), 1);
4415 assert!(memory.entries[0].content.contains("vscode"));
4416 }
4417
4418 #[test]
4419 fn test_no_false_conflict() {
4420 let mut memory = AutoMemory::new();
4421
4422 memory.add_memory(
4424 MemoryCategory::Decision,
4425 "决定使用 PostgreSQL 作为主数据库".to_string(),
4426 None,
4427 );
4428 memory.add_memory(
4429 MemoryCategory::Decision,
4430 "决定使用 Redis 作为缓存系统".to_string(),
4431 None,
4432 );
4433
4434 assert_eq!(memory.entries.len(), 2);
4436 }
4437
4438 #[test]
4439 fn test_contextual_summary() {
4440 let mut memory = AutoMemory::new();
4441
4442 memory.add(MemoryEntry::new(MemoryCategory::Decision, "决定使用 PostgreSQL 作为主数据库".into(), None));
4444 memory.add(MemoryEntry::new(MemoryCategory::Technical, "前端使用 React 框架开发".into(), None));
4445 memory.add(MemoryEntry::new(MemoryCategory::Solution, "通过添加 Redis 缓存解决性能问题".into(), None));
4446 memory.add(MemoryEntry::new(MemoryCategory::Finding, "API 响应时间在 200ms 以内".into(), None));
4447 memory.add(MemoryEntry::new(MemoryCategory::Preference, "偏好使用 TypeScript 而非 JavaScript".into(), None));
4448
4449 let db_summary = memory.generate_contextual_summary("数据库查询优化", 3);
4451 assert!(db_summary.contains("PostgreSQL"));
4452
4453 let fe_summary = memory.generate_contextual_summary("React 组件开发", 3);
4455 assert!(fe_summary.contains("React"));
4456
4457 let empty_summary = memory.generate_contextual_summary("", 3);
4459 assert!(!empty_summary.is_empty());
4460 }
4461
4462 #[test]
4463 fn test_low_quality_memory_filter() {
4464 assert!(is_low_quality_memory("│ 🎯 决策: 决定使用 PostgreSQL."));
4466 assert!(is_low_quality_memory("├── Structure: 入口文件是 main."));
4467 assert!(is_low_quality_memory("🔧 解决方案: 通过添加 middleware."));
4468 assert!(is_low_quality_memory("【自动记忆摘要】"));
4469 assert!(is_low_quality_memory("short"));
4470
4471 assert!(!is_low_quality_memory("决定使用 PostgreSQL 作为主数据库系统"));
4473 assert!(!is_low_quality_memory("通过添加 Redis 缓存层解决了性能问题"));
4474 assert!(!is_low_quality_memory("用户偏好使用 TypeScript 进行开发"));
4475 }
4476}