use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::fs;
use crate::providers::Message;
/// Shortens `s` for display, appending `...` when it had to be cut.
///
/// Strings of at most `max_len` bytes are returned unchanged. Longer strings keep
/// at most `max_len - 3` bytes of the original followed by `...`. The cut point is
/// moved back to the nearest UTF-8 character boundary so multi-byte text (e.g. CJK)
/// can never cause a slicing panic.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    let mut cut = max_len.saturating_sub(3);
    // Index 0 is always a boundary, so this loop terminates.
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
/// Returns the longest prefix of `s` that fits in `max_len` bytes.
///
/// The cut is moved back to the nearest UTF-8 character boundary, so the result
/// may be slightly shorter than `max_len` for multi-byte text but slicing never
/// panics.
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    let mut cut = max_len;
    // Index 0 is always a boundary, so this loop terminates.
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    s[..cut].to_string()
}
/// Upper bound for an entry's importance; reference bumps are capped at this value
/// and contextual scoring divides importance by it.
pub const MAX_IMPORTANCE_CEILING: f64 = 100.0;
/// Minimum lowercased byte length a text needs before fuzzy duplicate detection applies.
pub const MIN_SIMILARITY_LENGTH: usize = 10;
/// Word-set Jaccard similarity at or above which two memories are treated as duplicates.
pub const SIMILARITY_THRESHOLD: f64 = 0.7;
/// NOTE(review): not referenced in this part of the file; presumably the minimum
/// length of content worth storing — confirm against the detector.
pub const MIN_MEMORY_CONTENT_LENGTH: usize = 15;
/// NOTE(review): not referenced in this part of the file; presumably caps how many
/// entries one detection pass may produce — confirm.
pub const MAX_DETECTED_ENTRIES: usize = 5;
/// Byte length above which entry content is shortened when formatted for a prompt.
pub const MAX_MEMORY_CONTENT_LENGTH: usize = 200;
/// Byte budget for entry content in single-line listings.
pub const MAX_DISPLAY_LENGTH: usize = 60;
/// Word-overlap ratio a same-category entry must exceed before it is checked for
/// contradiction with new content.
pub const CONFLICT_OVERLAY_THRESHOLD: f64 = 0.5;
/// Lower overlap ratio used when the new content itself carries a change signal.
pub const CONFLICT_OVERLAY_THRESHOLD_WITH_SIGNAL: f64 = 0.3;
/// Importance at or above which a listing line is marked with a star.
pub const IMPORTANCE_STAR_THRESHOLD: f64 = 80.0;
/// Weight of keyword relevance in the contextual ranking score.
pub const CONTEXT_RELEVANCE_WEIGHT: f64 = 0.6;
/// Weight of normalized importance in the contextual ranking score.
pub const CONTEXT_IMPORTANCE_WEIGHT: f64 = 0.4;
/// NOTE(review): not referenced in this part of the file; by its name the default
/// model id for AI-based memory extraction — confirm.
pub const DEFAULT_MEMORY_EXTRACTOR_MODEL: &str = "claude-3-5-haiku-20241022";
/// In `AiKeywordMode::Auto`, AI keyword extraction is used when fewer keywords than
/// this were found.
pub const MIN_KEYWORDS_FOR_AI_FALLBACK: usize = 2;
/// Policy deciding whether AI-based keyword extraction is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AiKeywordMode {
/// Use AI only when too few keywords were found (see `should_use_ai`).
#[default]
Auto,
/// Always use AI.
Always,
/// Never use AI.
Never,
}
impl AiKeywordMode {
    /// Builds the mode from the `MEMORY_AI_KEYWORDS` environment variable.
    ///
    /// The value is compared case-insensitively. An unset or empty variable means
    /// `Auto`; an unrecognized value logs a warning and also yields `Auto`.
    pub fn from_env() -> Self {
        let raw = std::env::var("MEMORY_AI_KEYWORDS")
            .unwrap_or_default()
            .to_lowercase();
        match raw.as_str() {
            "always" | "true" | "1" => Self::Always,
            "never" | "false" | "0" => Self::Never,
            "auto" | "" => Self::Auto,
            unknown => {
                log::warn!("Unknown MEMORY_AI_KEYWORDS value: '{}', using 'auto'", unknown);
                Self::Auto
            }
        }
    }

    /// Returns whether AI extraction should run given how many keywords were
    /// already found.
    pub fn should_use_ai(&self, keyword_count: usize) -> bool {
        match *self {
            Self::Never => false,
            Self::Always => true,
            Self::Auto => keyword_count < MIN_KEYWORDS_FOR_AI_FALLBACK,
        }
    }
}
/// Initial importance of a new `Decision` memory.
pub const DEFAULT_IMPORTANCE_DECISION: f64 = 90.0;
/// Initial importance of a new `Solution` memory.
pub const DEFAULT_IMPORTANCE_SOLUTION: f64 = 85.0;
/// Initial importance of a new `Preference` memory.
pub const DEFAULT_IMPORTANCE_PREF: f64 = 70.0;
/// Initial importance of a new `Finding` memory.
pub const DEFAULT_IMPORTANCE_FINDING: f64 = 60.0;
/// Initial importance of a new `Technical` memory.
pub const DEFAULT_IMPORTANCE_TECH: f64 = 50.0;
/// Initial importance of a new `Structure` memory.
pub const DEFAULT_IMPORTANCE_STRUCTURE: f64 = 40.0;
/// Tunable parameters for a memory store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
/// Capacity copied into `AutoMemory::max_entries` by `AutoMemory::with_config`.
pub max_entries: usize,
/// Importance floor copied into `AutoMemory::min_importance`.
pub min_importance: f64,
/// Enabled flag copied into `AutoMemory::enabled`.
pub enabled: bool,
/// Days without a reference before time decay starts.
pub decay_start_days: i64,
/// Multiplier applied to importance once per elapsed 30-day decay period.
pub decay_rate: f64,
/// Importance added each time an entry is found referenced in messages.
pub reference_increment: f64,
/// NOTE(review): not read in this part of the file; reference bumps are capped
/// by the global `MAX_IMPORTANCE_CEILING` instead — confirm intended use.
pub max_importance_ceiling: f64,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
max_entries: 100,
min_importance: 30.0,
enabled: true,
decay_start_days: 30,
decay_rate: 0.5,
reference_increment: 2.0,
max_importance_ceiling: MAX_IMPORTANCE_CEILING,
}
}
}
impl MemoryConfig {
pub fn with_max_entries(max: usize) -> Self {
Self {
max_entries: max,
..Self::default()
}
}
pub fn minimal() -> Self {
Self {
max_entries: 50,
min_importance: 50.0,
enabled: true,
decay_start_days: 14,
decay_rate: 0.6,
reference_increment: 1.0,
max_importance_ceiling: MAX_IMPORTANCE_CEILING,
}
}
pub fn archival() -> Self {
Self {
max_entries: 500,
min_importance: 20.0,
enabled: true,
decay_start_days: 90,
decay_rate: 0.3,
reference_increment: 3.0,
max_importance_ceiling: MAX_IMPORTANCE_CEILING,
}
}
}
/// Kind of information a memory entry holds; serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[serde(rename_all = "snake_case")]
pub enum MemoryCategory {
/// A user preference.
Preference,
/// A decision that was made.
Decision,
/// Something that was discovered.
Finding,
/// A solution to a problem.
Solution,
/// Technical knowledge.
Technical,
/// Structural information.
Structure,
}
impl MemoryCategory {
    /// Human-readable (Chinese) label used in listings and prompt summaries.
    pub fn display_name(&self) -> &'static str {
        match *self {
            Self::Preference => "偏好",
            Self::Decision => "决策",
            Self::Finding => "发现",
            Self::Solution => "解决方案",
            Self::Technical => "技术",
            Self::Structure => "结构",
        }
    }

    /// Emoji shown next to the category label.
    pub fn icon(&self) -> &'static str {
        match *self {
            Self::Preference => "👤",
            Self::Decision => "🎯",
            Self::Finding => "💡",
            Self::Solution => "🔧",
            Self::Technical => "📚",
            Self::Structure => "🏗️",
        }
    }

    /// Importance assigned to a freshly created entry of this category.
    pub fn default_importance(&self) -> f64 {
        match *self {
            Self::Preference => DEFAULT_IMPORTANCE_PREF,
            Self::Decision => DEFAULT_IMPORTANCE_DECISION,
            Self::Finding => DEFAULT_IMPORTANCE_FINDING,
            Self::Solution => DEFAULT_IMPORTANCE_SOLUTION,
            Self::Technical => DEFAULT_IMPORTANCE_TECH,
            Self::Structure => DEFAULT_IMPORTANCE_STRUCTURE,
        }
    }
}
/// A single remembered fact together with its bookkeeping data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Unique id (a UUID v4 string when created through `MemoryEntry::new`).
pub id: String,
/// Creation time.
pub created_at: DateTime<Utc>,
/// Time of the most recent reference; drives time decay.
pub last_referenced: DateTime<Utc>,
/// Category of the memory.
pub category: MemoryCategory,
/// The remembered text.
pub content: String,
/// Optional id of the session that produced the entry.
pub source_session: Option<String>,
/// Number of times the entry was marked as referenced.
pub reference_count: u32,
/// Current importance score.
pub importance: f64,
/// Free-form tags; `load_combined` adds `"project"` to project entries.
pub tags: Vec<String>,
/// True for entries added by hand; such entries are exempt from time decay.
pub is_manual: bool,
}
impl MemoryEntry {
    /// Creates an automatically detected entry with a fresh UUID, the current time
    /// for both timestamps and the category's default importance.
    pub fn new(category: MemoryCategory, content: String, source_session: Option<String>) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            created_at: Utc::now(),
            last_referenced: Utc::now(),
            category,
            content,
            source_session,
            reference_count: 0,
            importance: category.default_importance(),
            tags: Vec::new(),
            is_manual: false,
        }
    }

    /// Creates a hand-written entry: flagged manual, importance fixed at 95.
    pub fn manual(category: MemoryCategory, content: String) -> Self {
        Self {
            is_manual: true,
            importance: 95.0,
            ..Self::new(category, content, None)
        }
    }

    /// Records a reference using the standard increment of 2.
    pub fn mark_referenced(&mut self) {
        self.mark_referenced_with_increment(2.0);
    }

    /// Records a reference: bumps the counter, refreshes `last_referenced` and
    /// raises importance by `increment`, capped at `MAX_IMPORTANCE_CEILING`.
    pub fn mark_referenced_with_increment(&mut self, increment: f64) {
        self.reference_count += 1;
        self.last_referenced = Utc::now();
        let raised = self.importance + increment;
        self.importance = raised.min(MAX_IMPORTANCE_CEILING);
    }

    /// Renders a single listing line: icon, creation time, star/manual markers,
    /// category label and shortened content.
    pub fn format_line(&self) -> String {
        let star = if self.importance < IMPORTANCE_STAR_THRESHOLD { "" } else { "⭐" };
        let pencil = if self.is_manual { "📝" } else { "" };
        let timestamp = self.created_at.format("%Y-%m-%d %H:%M");
        let preview = truncate_str(&self.content, MAX_DISPLAY_LENGTH);
        format!(
            "{} {} {}{}{} {}",
            self.category.icon(),
            timestamp,
            star,
            pencil,
            self.category.display_name(),
            preview
        )
    }

    /// Renders `label: content` for prompts, shortening over-long content and
    /// marking the cut with `...`.
    pub fn format_for_prompt(&self) -> String {
        let label = self.category.display_name();
        if self.content.len() <= MAX_MEMORY_CONTENT_LENGTH {
            return format!("{}: {}", label, self.content);
        }
        let shortened = truncate(&self.content, MAX_MEMORY_CONTENT_LENGTH - 3);
        format!("{}: {}...", label, shortened)
    }
}
/// In-memory collection of memory entries plus the limits that govern pruning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMemory {
/// All stored entries.
pub entries: Vec<MemoryEntry>,
/// Full configuration; decay and reference settings are read from here.
#[serde(default)]
pub config: MemoryConfig,
/// Capacity enforced by `prune` (initialized from `config.max_entries`).
#[serde(default = "default_max_entries")]
pub max_entries: usize,
/// Importance floor used by `prune` and `apply_time_decay`.
#[serde(default = "default_min_importance")]
pub min_importance: f64,
/// Enabled flag (initialized from `config.enabled`).
#[serde(default = "default_enabled")]
pub enabled: bool,
/// Lazily built lookup cache; never serialized and reset to `None` when invalidated.
#[serde(skip)]
search_index: Option<SearchIndex>,
}
/// Precomputed lookup tables over a snapshot of the entry list.
/// All indices refer to positions in the slice the index was built from.
#[derive(Debug, Clone)]
struct SearchIndex {
/// Lowercased content of every entry, by position.
content_lower: Vec<String>,
/// Entry positions grouped by category.
by_category: HashMap<MemoryCategory, Vec<usize>>,
/// Entry positions ordered by descending importance.
by_importance: Vec<usize>,
/// Occurrence count of each whitespace-separated lowercased word.
#[allow(dead_code)]
word_freq: HashMap<String, usize>,
}
impl SearchIndex {
    /// Builds every lookup table from the given snapshot of entries.
    fn build(entries: &[MemoryEntry]) -> Self {
        let mut content_lower: Vec<String> = Vec::with_capacity(entries.len());
        let mut by_category: HashMap<MemoryCategory, Vec<usize>> = HashMap::new();
        let mut by_importance: Vec<usize> = Vec::with_capacity(entries.len());
        for (pos, entry) in entries.iter().enumerate() {
            content_lower.push(entry.content.to_lowercase());
            by_category.entry(entry.category).or_default().push(pos);
            by_importance.push(pos);
        }

        // Stable sort: equally important entries keep their original order.
        by_importance.sort_by(|&left, &right| {
            entries[right]
                .importance
                .partial_cmp(&entries[left].importance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut word_freq: HashMap<String, usize> = HashMap::new();
        for word in content_lower.iter().flat_map(|text| text.split_whitespace()) {
            *word_freq.entry(word.to_string()).or_default() += 1;
        }

        Self {
            content_lower,
            by_category,
            by_importance,
            word_freq,
        }
    }

    /// Lowercased content of the entry at `idx`.
    #[allow(dead_code)]
    fn get_lower(&self, idx: usize) -> &str {
        self.content_lower[idx].as_str()
    }

    /// Positions of entries whose lowercased content contains `query_lower`,
    /// most important first, optionally limited to `limit` results.
    fn search(&self, _entries: &[MemoryEntry], query_lower: &str, limit: Option<usize>) -> Vec<usize> {
        let hits = self
            .by_importance
            .iter()
            .copied()
            .filter(|&pos| self.content_lower[pos].contains(query_lower));
        match limit {
            Some(max) => hits.take(max).collect(),
            None => hits.collect(),
        }
    }

    /// Positions of entries containing at least one of the lowercased keywords,
    /// most important first.
    fn search_multi(&self, keywords_lower: &[String]) -> Vec<usize> {
        let mut hits = Vec::new();
        for &pos in &self.by_importance {
            let text = &self.content_lower[pos];
            if keywords_lower.iter().any(|keyword| text.contains(keyword)) {
                hits.push(pos);
            }
        }
        hits
    }

    /// Replaces this index with one built from `entries`.
    #[allow(dead_code)]
    fn rebuild(&mut self, entries: &[MemoryEntry]) {
        let fresh = Self::build(entries);
        *self = fresh;
    }
}
/// Serde fallback for `AutoMemory::max_entries`.
fn default_max_entries() -> usize {
    100
}

/// Serde fallback for `AutoMemory::min_importance`.
fn default_min_importance() -> f64 {
    30.0
}

/// Serde fallback for `AutoMemory::enabled`.
fn default_enabled() -> bool {
    true
}
impl Default for AutoMemory {
    /// Empty store governed by `MemoryConfig::default()`.
    fn default() -> Self {
        // `with_config` performs exactly the field-by-field setup needed here.
        Self::with_config(MemoryConfig::default())
    }
}
impl AutoMemory {
pub fn new() -> Self {
Self::default()
}
/// Builds the search index if none is cached.
fn ensure_index(&mut self) {
    if self.search_index.is_some() {
        return;
    }
    self.rebuild_index();
}

/// Unconditionally rebuilds the search index from the current entries.
pub fn rebuild_index(&mut self) {
    let fresh = SearchIndex::build(&self.entries);
    self.search_index = Some(fresh);
}

/// Drops the cached index so the next indexed lookup rebuilds it.
fn invalidate_index(&mut self) {
    self.search_index = None;
}
/// Creates an empty store whose limits are copied from `config`.
pub fn with_config(config: MemoryConfig) -> Self {
    // Copy the scalar limits out first so `config` itself can be moved in.
    let max_entries = config.max_entries;
    let min_importance = config.min_importance;
    let enabled = config.enabled;
    Self {
        entries: Vec::new(),
        config,
        max_entries,
        min_importance,
        enabled,
        search_index: None,
    }
}
/// Empty store using the `MemoryConfig::minimal` preset.
pub fn minimal() -> Self {
    let preset = MemoryConfig::minimal();
    Self::with_config(preset)
}

/// Empty store using the `MemoryConfig::archival` preset.
pub fn archival() -> Self {
    let preset = MemoryConfig::archival();
    Self::with_config(preset)
}
/// Appends `entry` without any duplicate check, then prunes to capacity.
pub fn add(&mut self, entry: MemoryEntry) {
    // The cached index no longer matches once the entry list changes.
    self.invalidate_index();
    self.entries.push(entry);
    self.prune();
}
/// Adds a detected memory unless a similar one already exists.
///
/// If the new content conflicts with an existing entry of the same category, that
/// entry is removed first so the new content supersedes it.
pub fn add_memory(
    &mut self,
    category: MemoryCategory,
    content: String,
    source_session: Option<String>,
) {
    if self.has_similar(&content) {
        return;
    }

    if let Some(conflict_idx) = self.find_conflict(&content, category) {
        let superseded = self.entries.remove(conflict_idx);
        log::debug!("Memory conflict detected: '{}' supersedes '{}'", content, superseded.content);
        self.invalidate_index();
    }

    self.add(MemoryEntry::new(category, content, source_session));
}
/// Finds the first entry of `category` that `new_content` appears to supersede.
///
/// An entry conflicts when either
/// * its word overlap with the new content exceeds the threshold, the two are not
///   near-duplicates, and `has_contradiction_signal` reports a contradiction, or
/// * the new content carries a change signal on its own and mentions any word
///   longer than two bytes from the entry.
fn find_conflict(&self, new_content: &str, category: MemoryCategory) -> Option<usize> {
    use std::collections::HashSet;

    let new_lower = new_content.to_lowercase();
    let new_words: HashSet<&str> = new_lower.split_whitespace().collect();
    let has_change_signal = has_contradiction_signal("", &new_lower);
    let overlap_threshold = match has_change_signal {
        true => CONFLICT_OVERLAY_THRESHOLD_WITH_SIGNAL,
        false => CONFLICT_OVERLAY_THRESHOLD,
    };

    self.entries.iter().position(|entry| {
        if entry.category != category {
            return false;
        }
        let entry_lower = entry.content.to_lowercase();
        let entry_words: HashSet<&str> = entry_lower.split_whitespace().collect();
        let smaller = new_words.len().min(entry_words.len());
        if smaller == 0 {
            return false;
        }

        let shared = new_words.intersection(&entry_words).count();
        let topic_overlap = shared as f64 / smaller as f64;
        let jaccard = Self::calculate_similarity(&entry_lower, &new_lower);

        // Same topic, different wording, and an explicit contradiction.
        let contradicts = topic_overlap > overlap_threshold
            && jaccard < SIMILARITY_THRESHOLD
            && has_contradiction_signal(&entry_lower, &new_lower);
        if contradicts {
            return true;
        }

        // The new text signals a change and mentions a key term of the old entry.
        has_change_signal
            && entry_words
                .iter()
                .any(|term| term.len() > 2 && new_lower.contains(*term))
    })
}
/// Reports whether an equal or near-duplicate entry is already stored.
///
/// Comparison is case-insensitive. Content shorter than `MIN_SIMILARITY_LENGTH`
/// bytes is never considered similar to anything; short stored entries only match
/// exactly.
pub fn has_similar(&self, content: &str) -> bool {
    let needle = content.to_lowercase();
    if needle.len() < MIN_SIMILARITY_LENGTH {
        return false;
    }

    for existing in &self.entries {
        let candidate = existing.content.to_lowercase();
        if candidate == needle {
            return true;
        }
        if candidate.len() >= MIN_SIMILARITY_LENGTH
            && Self::calculate_similarity(&candidate, &needle) >= SIMILARITY_THRESHOLD
        {
            return true;
        }
    }
    false
}
/// Jaccard similarity of the whitespace-separated word sets of `a` and `b`.
///
/// Returns a value in `[0, 1]`; `0.0` when either text has no words.
fn calculate_similarity(a: &str, b: &str) -> f64 {
    use std::collections::HashSet;

    let left: HashSet<&str> = a.split_whitespace().collect();
    let right: HashSet<&str> = b.split_whitespace().collect();
    if left.is_empty() || right.is_empty() {
        return 0.0;
    }

    let shared = left.intersection(&right).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|; non-zero because both sets are non-empty.
    let combined = left.len() + right.len() - shared;
    shared as f64 / combined as f64
}
pub fn prune(&mut self) {
if self.entries.len() <= self.max_entries {
return;
}
let (manual_entries, auto_entries): (Vec<_>, Vec<_>) = self.entries
.iter()
.cloned()
.partition(|e| e.is_manual);
let mut sorted_auto = auto_entries;
sorted_auto.sort_by(|a, b| {
let importance_cmp = b.importance.partial_cmp(&a.importance)
.unwrap_or(std::cmp::Ordering::Equal);
if importance_cmp == std::cmp::Ordering::Equal {
b.last_referenced.cmp(&a.last_referenced)
} else {
importance_cmp
}
});
let kept_auto: Vec<_> = sorted_auto
.into_iter()
.filter(|e| e.importance >= self.min_importance)
.take(self.max_entries.saturating_sub(manual_entries.len()))
.collect();
self.entries = manual_entries.into_iter().chain(kept_auto).collect();
if self.entries.len() > self.max_entries {
self.entries.sort_by(|a, b| {
let importance_cmp = b.importance.partial_cmp(&a.importance)
.unwrap_or(std::cmp::Ordering::Equal);
if importance_cmp == std::cmp::Ordering::Equal {
b.last_referenced.cmp(&a.last_referenced)
} else {
importance_cmp
}
});
self.entries.truncate(self.max_entries);
}
self.invalidate_index(); }
/// Entries of the given category, in storage order (linear scan).
pub fn by_category(&self, category: MemoryCategory) -> Vec<&MemoryEntry> {
    let mut matching = Vec::new();
    for entry in &self.entries {
        if entry.category == category {
            matching.push(entry);
        }
    }
    matching
}
pub fn by_category_fast(&mut self, category: MemoryCategory) -> Vec<&MemoryEntry> {
self.ensure_index();
if let Some(ref index) = self.search_index {
index.by_category.get(&category)
.map(|indices| indices.iter().map(|&i| &self.entries[i]).collect())
.unwrap_or_default()
} else {
self.by_category(category)
}
}
/// The `n` most important entries, highest importance first.
pub fn top_n(&self, n: usize) -> Vec<&MemoryEntry> {
    let mut ranked: Vec<&MemoryEntry> = self.entries.iter().collect();
    // Stable sort keeps storage order among equally important entries.
    ranked.sort_by(|left, right| {
        right
            .importance
            .partial_cmp(&left.importance)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(n);
    ranked
}
pub fn top_n_fast(&mut self, n: usize) -> Vec<&MemoryEntry> {
self.ensure_index();
if let Some(ref index) = self.search_index {
index.by_importance
.iter()
.take(n)
.map(|&i| &self.entries[i])
.collect()
} else {
self.top_n(n)
}
}
/// Case-insensitive substring search over content and tags, without a limit.
pub fn search(&self, query: &str) -> Vec<&MemoryEntry> {
    self.search_with_limit(query, None)
}

/// Case-insensitive substring search over content and tags.
///
/// Results are ordered by descending importance and cut to `limit` when given.
pub fn search_with_limit(&self, query: &str, limit: Option<usize>) -> Vec<&MemoryEntry> {
    let needle = query.to_lowercase();

    let mut hits: Vec<&MemoryEntry> = Vec::new();
    for entry in &self.entries {
        let in_content = entry.content.to_lowercase().contains(&needle);
        if in_content || entry.tags.iter().any(|tag| tag.to_lowercase().contains(&needle)) {
            hits.push(entry);
        }
    }

    hits.sort_by(|left, right| {
        right
            .importance
            .partial_cmp(&left.importance)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    if let Some(max) = limit {
        hits.truncate(max);
    }
    hits
}
pub fn search_fast(&mut self, query: &str, limit: Option<usize>) -> Vec<&MemoryEntry> {
self.ensure_index();
let query_lower = query.to_lowercase();
if let Some(ref index) = self.search_index {
let indices = index.search(&self.entries, &query_lower, limit);
indices.iter().map(|&i| &self.entries[i]).collect()
} else {
self.search_with_limit(query, limit)
}
}
/// Entries whose content contains at least one keyword (case-insensitive),
/// returned in storage order. An empty keyword list yields no results.
pub fn search_multi(&self, keywords: &[&str]) -> Vec<&MemoryEntry> {
    if keywords.is_empty() {
        return Vec::new();
    }

    let wanted: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();
    let mut hits = Vec::new();
    for entry in &self.entries {
        let text = entry.content.to_lowercase();
        if wanted.iter().any(|k| text.contains(k)) {
            hits.push(entry);
        }
    }
    hits
}
pub fn search_multi_fast(&mut self, keywords: &[&str]) -> Vec<&MemoryEntry> {
if keywords.is_empty() {
return Vec::new();
}
self.ensure_index();
let keywords_lower: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();
if let Some(ref index) = self.search_index {
let indices = index.search_multi(&keywords_lower);
indices.iter().map(|&i| &self.entries[i]).collect()
} else {
self.search_multi(keywords)
}
}
/// Adds every entry that is not similar to one already stored (including entries
/// accepted earlier in the same batch), then prunes once at the end.
pub fn add_batch(&mut self, entries: Vec<MemoryEntry>) {
    let before = self.entries.len();
    for entry in entries {
        if !self.has_similar(&entry.content) {
            self.entries.push(entry);
        }
    }
    // `prune` returns early while under capacity, so the cached index must be
    // dropped here or the `*_fast` lookups would miss the new entries.
    if self.entries.len() != before {
        self.invalidate_index();
    }
    self.prune();
}
/// Marks every entry whose content appears (case-insensitively) in one of the
/// messages as referenced, raising its importance by `config.reference_increment`.
///
/// Entries with empty content are skipped: an empty string is contained in every
/// text and would otherwise be bumped by any message.
pub fn update_references(&mut self, messages: &[Message]) {
    let increment = self.config.reference_increment;
    let texts_lower: Vec<String> = messages
        .iter()
        .filter_map(Self::extract_message_text_lower)
        .collect();
    if texts_lower.is_empty() {
        return;
    }

    let mut touched = false;
    for entry in &mut self.entries {
        let needle = entry.content.to_lowercase();
        if needle.is_empty() {
            continue;
        }
        if texts_lower.iter().any(|text| text.contains(&needle)) {
            entry.mark_referenced_with_increment(increment);
            touched = true;
        }
    }

    // Importance changed, so the importance-ordered index is out of date.
    if touched {
        self.invalidate_index();
    }
}
fn extract_message_text_lower(msg: &Message) -> Option<String> {
match &msg.content {
crate::providers::MessageContent::Text(t) => Some(t.to_lowercase()),
crate::providers::MessageContent::Blocks(blocks) => {
let text = blocks
.iter()
.filter_map(|b| {
if let crate::providers::ContentBlock::Text { text } = b {
Some(text.as_str())
} else {
None
}
})
.collect::<Vec<_>>()
.join(" ");
Some(text.to_lowercase())
}
}
}
pub fn generate_prompt_summary(&self, max_entries: usize) -> String {
if self.entries.is_empty() {
return String::new();
}
let top_entries = self.top_n(max_entries);
if top_entries.is_empty() {
return String::new();
}
let mut summary = String::from("【自动记忆摘要】\n\n");
let mut by_cat: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
for entry in top_entries {
by_cat.entry(entry.category).or_default().push(entry);
}
for (cat, entries) in by_cat {
summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
for entry in entries {
summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
}
summary.push('\n');
}
summary
}
/// Builds a context-aware memory section, extracting keywords from `context` with
/// the rule-based extractor.
pub fn generate_contextual_summary(&self, context: &str, max_entries: usize) -> String {
    let keywords = extract_context_keywords(context);
    self.generate_contextual_summary_with_keywords(&keywords, max_entries)
}

/// Builds a context-aware memory section from already extracted keywords.
///
/// Manual entries always rank first; within each group entries are ordered by a
/// weighted mix of keyword relevance and normalized importance. The best
/// `max_entries` are grouped by category, with categories appearing in the order
/// of their highest-ranked entry so the output is deterministic (grouping through
/// a `HashMap` would shuffle the sections between runs). Returns an empty string
/// when nothing is selected.
pub fn generate_contextual_summary_with_keywords(&self, context_keywords: &[String], max_entries: usize) -> String {
    if self.entries.is_empty() || max_entries == 0 {
        return String::new();
    }

    let blended = |relevance: f64, importance: f64| {
        relevance * CONTEXT_RELEVANCE_WEIGHT
            + (importance / MAX_IMPORTANCE_CEILING) * CONTEXT_IMPORTANCE_WEIGHT
    };

    let mut scored: Vec<(&MemoryEntry, f64)> = self.entries
        .iter()
        .map(|entry| (entry, compute_relevance(entry, context_keywords)))
        .collect();
    scored.sort_by(|a, b| {
        // Manual entries first, then higher blended score.
        b.0.is_manual.cmp(&a.0.is_manual).then_with(|| {
            blended(b.1, b.0.importance)
                .partial_cmp(&blended(a.1, a.0.importance))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });

    // Group while preserving first-seen (i.e. rank) order of the categories.
    let mut groups: Vec<(MemoryCategory, Vec<&MemoryEntry>)> = Vec::new();
    for (entry, _) in scored.into_iter().take(max_entries) {
        match groups.iter().position(|group| group.0 == entry.category) {
            Some(pos) => groups[pos].1.push(entry),
            None => groups.push((entry.category, vec![entry])),
        }
    }

    let mut summary = String::from("【跨会话记忆】\n\n");
    for (cat, entries) in groups {
        summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
        for entry in entries {
            summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
        }
        summary.push('\n');
    }
    summary
}
pub async fn generate_contextual_summary_async(
&self,
context: &str,
max_entries: usize,
fast_provider: Option<&dyn crate::providers::Provider>,
) -> String {
if self.entries.is_empty() {
return String::new();
}
let context_keywords = if let Some(provider) = fast_provider {
extract_keywords_hybrid(context, Some(provider)).await
} else {
extract_context_keywords(context)
};
let mut scored: Vec<(&MemoryEntry, f64)> = self.entries
.iter()
.map(|entry| {
let relevance = compute_relevance(entry, &context_keywords);
(entry, relevance)
})
.collect();
scored.sort_by(|a, b| {
if a.0.is_manual && !b.0.is_manual {
return std::cmp::Ordering::Less;
}
if !a.0.is_manual && b.0.is_manual {
return std::cmp::Ordering::Greater;
}
let score_a = a.1 * CONTEXT_RELEVANCE_WEIGHT + (a.0.importance / MAX_IMPORTANCE_CEILING) * CONTEXT_IMPORTANCE_WEIGHT;
let score_b = b.1 * CONTEXT_RELEVANCE_WEIGHT + (b.0.importance / MAX_IMPORTANCE_CEILING) * CONTEXT_IMPORTANCE_WEIGHT;
score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
});
let selected: Vec<&MemoryEntry> = scored
.iter()
.take(max_entries)
.map(|(entry, _)| *entry)
.collect();
if selected.is_empty() {
return String::new();
}
let mut summary = String::from("【跨会话记忆】\n\n");
let mut by_cat: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
for entry in selected {
by_cat.entry(entry.category).or_default().push(entry);
}
for (cat, entries) in by_cat {
summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
for entry in entries {
summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
}
summary.push('\n');
}
summary
}
pub fn format_all(&self) -> String {
if self.entries.is_empty() {
return "[no memories accumulated]".to_string();
}
let mut result = String::from("Accumulated memories:\n\n");
let mut sorted: Vec<_> = self.entries.iter().collect();
sorted.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap_or(std::cmp::Ordering::Equal));
for entry in sorted {
result.push_str(&entry.format_line());
result.push('\n');
}
result
}
/// Computes aggregate statistics over all entries: counts (total, manual, auto,
/// per category, referenced at least three times), average importance and the
/// oldest/newest creation times.
pub fn generate_statistics(&self) -> MemoryStatistics {
    let total = self.entries.len();

    let mut manual = 0usize;
    let mut highly_referenced = 0usize;
    let mut by_category: HashMap<MemoryCategory, usize> = HashMap::new();
    for entry in &self.entries {
        if entry.is_manual {
            manual += 1;
        }
        if entry.reference_count >= 3 {
            highly_referenced += 1;
        }
        *by_category.entry(entry.category).or_default() += 1;
    }

    let avg_importance = match total {
        0 => 0.0,
        _ => self.entries.iter().map(|e| e.importance).sum::<f64>() / total as f64,
    };

    MemoryStatistics {
        total,
        manual,
        auto: total - manual,
        by_category,
        avg_importance,
        oldest: self.entries.iter().map(|e| e.created_at).min(),
        newest: self.entries.iter().map(|e| e.created_at).max(),
        highly_referenced,
    }
}
/// Removes every entry.
pub fn clear(&mut self) {
    self.invalidate_index();
    self.entries.clear();
}

/// Removes the first entry with the given id; returns whether one was found.
pub fn remove(&mut self, id: &str) -> bool {
    match self.entries.iter().position(|e| e.id == id) {
        Some(pos) => {
            self.entries.remove(pos);
            self.invalidate_index();
            true
        }
        None => false,
    }
}
/// Lowers the importance of automatic entries that have not been referenced for
/// more than `config.decay_start_days` days, then prunes.
///
/// Importance is multiplied by `config.decay_rate` once per full 30-day period
/// beyond the start threshold and is clamped at half of `min_importance`. The
/// clamp never raises an entry that is already below that floor. Manual entries
/// are not decayed.
///
/// NOTE: the decay is applied to the current importance on every call, so calling
/// this repeatedly compounds the reduction.
pub fn apply_time_decay(&mut self) {
    const DECAY_PERIOD_DAYS: i64 = 30;

    let now = Utc::now();
    let decay_start_days = self.config.decay_start_days;
    let decay_rate = self.config.decay_rate;
    let floor = self.min_importance * 0.5;

    let mut changed = false;
    for entry in self.entries.iter_mut().filter(|e| !e.is_manual) {
        let days_since_reference = (now - entry.last_referenced).num_days().max(0);
        if days_since_reference <= decay_start_days {
            continue;
        }
        let decay_periods = (days_since_reference - decay_start_days) / DECAY_PERIOD_DAYS;
        // Saturate instead of wrapping on absurdly large period counts.
        let exponent = i32::try_from(decay_periods).unwrap_or(i32::MAX);
        // Clamp at the floor, but never above where the entry already was.
        let entry_floor = floor.min(entry.importance);
        let decayed = (entry.importance * decay_rate.powi(exponent)).max(entry_floor);
        if decayed != entry.importance {
            entry.importance = decayed;
            changed = true;
        }
    }

    // `prune` only invalidates when it actually trims, so do it here as well.
    if changed {
        self.invalidate_index();
    }
    self.prune();
}
}
/// Aggregate figures produced by `AutoMemory::generate_statistics`.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
/// Number of entries.
pub total: usize,
/// Number of manually added entries.
pub manual: usize,
/// Number of automatically detected entries (`total - manual`).
pub auto: usize,
/// Entry count per category (categories without entries are absent).
pub by_category: HashMap<MemoryCategory, usize>,
/// Mean importance, `0.0` for an empty store.
pub avg_importance: f64,
/// Earliest creation time, if any entry exists.
pub oldest: Option<DateTime<Utc>>,
/// Latest creation time, if any entry exists.
pub newest: Option<DateTime<Utc>>,
/// Number of entries referenced at least three times.
pub highly_referenced: usize,
}
impl MemoryStatistics {
pub fn format_summary(&self) -> String {
use std::fmt::Write;
let mut output = String::new();
writeln!(output, "记忆统计:").unwrap();
writeln!(output, " 总计: {} 条", self.total).unwrap();
writeln!(output, " ├─ 手动添加: {} 条", self.manual).unwrap();
writeln!(output, " └─ 自动检测: {} 条", self.auto).unwrap();
writeln!(output).unwrap();
writeln!(output, "分类统计:").unwrap();
for (cat, count) in &self.by_category {
writeln!(output, " {} {}: {} 条", cat.icon(), cat.display_name(), count).unwrap();
}
writeln!(output).unwrap();
writeln!(output, "质量指标:").unwrap();
writeln!(output, " 平均重要性: {:.1} 分", self.avg_importance).unwrap();
writeln!(output, " 高频引用: {} 条 (≥3次)", self.highly_referenced).unwrap();
if let Some(oldest) = self.oldest {
let days = (Utc::now() - oldest).num_days();
writeln!(output, " 记忆跨度: {} 天", days).unwrap();
}
output
}
}
/// Advisory lock implemented as a `memory.lock` file created exclusively.
/// The lock is released when the value is dropped.
pub struct MemoryFileLock {
/// Location of the lock file.
lock_path: PathBuf,
/// Whether this instance currently holds the lock.
locked: bool,
}
impl MemoryFileLock {
    /// Creates an unlocked handle for `<base_dir>/memory.lock`.
    pub fn new(base_dir: &Path) -> Self {
        Self {
            lock_path: base_dir.join("memory.lock"),
            locked: false,
        }
    }

    /// Tries to take the lock, retrying every 50 ms until `timeout_ms` has elapsed.
    ///
    /// Returns `Ok(true)` when the lock is held (including when this instance
    /// already held it) and `Ok(false)` on timeout. At least one attempt is made
    /// even with a zero timeout. The lock directory is created if missing, and a
    /// lock file older than 30 seconds is treated as abandoned and removed.
    ///
    /// # Errors
    /// Any I/O error other than "lock file already exists".
    pub fn acquire(&mut self, timeout_ms: u64) -> Result<bool> {
        use std::io::Write;
        use std::time::{Duration, Instant};

        if self.locked {
            return Ok(true);
        }

        // Exclusive creation fails with NotFound when the directory is missing,
        // e.g. on the very first run.
        if let Some(parent) = self.lock_path.parent() {
            if !parent.as_os_str().is_empty() {
                fs::create_dir_all(parent)?;
            }
        }

        let timeout = Duration::from_millis(timeout_ms);
        let start = Instant::now();
        loop {
            match fs::File::create_new(&self.lock_path) {
                Ok(mut file) => {
                    let lock_info = format!(
                        "{}:{}",
                        std::process::id(),
                        Utc::now().to_rfc3339()
                    );
                    if let Err(e) = file.write_all(lock_info.as_bytes()) {
                        // Do not leave behind a lock file nobody owns.
                        drop(file);
                        let _ = fs::remove_file(&self.lock_path);
                        return Err(e.into());
                    }
                    self.locked = true;
                    return Ok(true);
                }
                Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
                    if self.is_stale_lock()? {
                        self.remove_stale_lock()?;
                    }
                }
                Err(e) => return Err(e.into()),
            }

            if start.elapsed() >= timeout {
                return Ok(false);
            }
            std::thread::sleep(Duration::from_millis(50));
        }
    }

    /// Reports whether the lock file exists and was last modified more than
    /// 30 seconds ago. A file that vanished in the meantime is not stale.
    fn is_stale_lock(&self) -> Result<bool> {
        let metadata = match fs::metadata(&self.lock_path) {
            Ok(metadata) => metadata,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(false),
            Err(e) => return Err(e.into()),
        };
        let modified = metadata.modified()?;
        // A modification time in the future counts as age zero.
        let age = std::time::SystemTime::now()
            .duration_since(modified)
            .unwrap_or(std::time::Duration::ZERO);
        Ok(age > std::time::Duration::from_secs(30))
    }

    /// Deletes an abandoned lock file; succeeds if it is already gone.
    fn remove_stale_lock(&self) -> Result<()> {
        match fs::remove_file(&self.lock_path) {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    /// Releases the lock if this instance holds it.
    ///
    /// A lock file that has already disappeared (for example removed as stale by
    /// another process) counts as released. On any other error the lock stays
    /// marked as held so a later call can retry.
    pub fn release(&mut self) -> Result<()> {
        if !self.locked {
            return Ok(());
        }
        match fs::remove_file(&self.lock_path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }
        self.locked = false;
        Ok(())
    }
}
impl Drop for MemoryFileLock {
    /// Best-effort release; a failure cannot be reported from `drop` and is ignored.
    fn drop(&mut self) {
        drop(self.release());
    }
}
/// File-backed persistence for global and per-project memories.
pub struct MemoryStorage {
/// Directory holding the global memory, the config and the lock file.
base_dir: PathBuf,
/// Root of the current project, if any; project memory lives in its `.matrix` directory.
project_root: Option<PathBuf>,
/// Lock file guarding writes.
lock: MemoryFileLock,
}
impl MemoryStorage {
pub fn new(project_root: Option<&Path>) -> Result<Self> {
let base_dir = Self::get_base_dir()?;
let lock = MemoryFileLock::new(&base_dir);
Ok(Self {
base_dir,
project_root: project_root.map(|p| p.to_path_buf()),
lock,
})
}
/// Creates a storage handle and tries to take the lock right away.
///
/// NOTE: a timeout (`Ok(false)` from `acquire`) is not treated as an error, so
/// the returned storage may not hold the lock.
pub fn with_lock_timeout(project_root: Option<&Path>, timeout_ms: u64) -> Result<Self> {
    let mut storage = Self::new(project_root)?;
    let _acquired = storage.lock.acquire(timeout_ms)?;
    Ok(storage)
}
/// Resolves `<home>/.matrix`, taking the home directory from `HOME` or, failing
/// that, `USERPROFILE`.
fn get_base_dir() -> Result<PathBuf> {
    let home = match std::env::var_os("HOME").or_else(|| std::env::var_os("USERPROFILE")) {
        Some(dir) => dir,
        None => return Err(anyhow::anyhow!("HOME or USERPROFILE not set")),
    };
    Ok(PathBuf::from(home).join(".matrix"))
}
/// Path of the global memory file.
pub fn global_memory_path(&self) -> PathBuf {
    let mut path = self.base_dir.clone();
    path.push("memory.json");
    path
}

/// Path of the project memory file, or `None` without a project root.
pub fn project_memory_path(&self) -> Option<PathBuf> {
    let root = self.project_root.as_ref()?;
    Some(root.join(".matrix/memory.json"))
}

/// Path of the memory configuration file.
pub fn config_path(&self) -> PathBuf {
    let mut path = self.base_dir.clone();
    path.push("memory_config.json");
    path
}
/// Creates the base directory and, when a project root is set, its `.matrix`
/// directory.
fn ensure_dirs(&self) -> Result<()> {
    fs::create_dir_all(&self.base_dir)?;
    match &self.project_root {
        Some(root) => fs::create_dir_all(root.join(".matrix"))?,
        None => {}
    }
    Ok(())
}
fn acquire_lock(&mut self) -> Result<()> {
self.lock.acquire(5000)?; Ok(())
}
fn release_lock(&mut self) -> Result<()> {
self.lock.release()?;
Ok(())
}
/// Loads the global memory, or an empty default store when the file is absent.
pub fn load_global(&self) -> Result<AutoMemory> {
    let path = self.global_memory_path();
    if path.exists() {
        let data = fs::read_to_string(&path)?;
        Ok(serde_json::from_str::<AutoMemory>(&data)?)
    } else {
        Ok(AutoMemory::new())
    }
}
/// Loads the project memory; `None` without a project root or memory file.
pub fn load_project(&self) -> Result<Option<AutoMemory>> {
    let Some(path) = self.project_memory_path() else {
        return Ok(None);
    };
    if !path.exists() {
        return Ok(None);
    }
    let data = fs::read_to_string(&path)?;
    let memory = serde_json::from_str::<AutoMemory>(&data)?;
    Ok(Some(memory))
}
/// Loads the global memory and merges the project memory into it.
///
/// Project entries are tagged `"project"` (once) and appended; the result is
/// pruned using the global store's limits. Without project memory the global
/// store is returned untouched.
pub fn load_combined(&self) -> Result<AutoMemory> {
    let mut combined = self.load_global()?;
    let Some(project) = self.load_project()? else {
        return Ok(combined);
    };

    for mut entry in project.entries {
        let already_tagged = entry.tags.iter().any(|tag| tag == "project");
        if !already_tagged {
            entry.tags.push("project".to_string());
        }
        combined.entries.push(entry);
    }
    combined.prune();
    Ok(combined)
}
/// Writes the global memory atomically (temp file + rename) under the lock.
///
/// Directories are created before the lock is taken, because the lock file
/// itself lives in the base directory. The lock is released even when writing
/// fails; a write error takes precedence over a release error.
pub fn save_global(&mut self, memory: &AutoMemory) -> Result<()> {
    self.ensure_dirs()?;
    self.acquire_lock()?;
    let written = self.save_global_locked(memory);
    let released = self.release_lock();
    written.and(released)
}
/// Writes the project memory atomically (temp file + rename) under the lock.
///
/// Directories are created before the lock is taken, because the lock file
/// itself lives in the base directory. The lock is released even when writing
/// fails; a write error takes precedence over a release error.
///
/// # Errors
/// Fails with "no project root" when the storage has no project root.
pub fn save_project(&mut self, memory: &AutoMemory) -> Result<()> {
    self.ensure_dirs()?;
    self.acquire_lock()?;
    let written = self.save_project_locked(memory);
    let released = self.release_lock();
    written.and(released)
}
/// Writes the configuration as pretty JSON.
///
/// NOTE: unlike the memory files this is written in place, without the lock or a
/// temp-file rename.
pub fn save_config(&mut self, config: &MemoryConfig) -> Result<()> {
    self.ensure_dirs()?;
    let target = self.config_path();
    let serialized = serde_json::to_string_pretty(config)?;
    fs::write(&target, serialized)?;
    Ok(())
}
/// Loads the configuration, or the defaults when no config file exists.
pub fn load_config(&self) -> Result<MemoryConfig> {
    let path = self.config_path();
    if path.exists() {
        let data = fs::read_to_string(&path)?;
        Ok(serde_json::from_str::<MemoryConfig>(&data)?)
    } else {
        Ok(MemoryConfig::default())
    }
}
/// Loads the project or global memory, adds `entry` and writes it back, all
/// under the lock.
///
/// Directories are created first so that both the lock file and the memory file
/// can be written on a fresh installation. The lock is released even when the
/// update fails; an update error takes precedence over a release error.
pub fn add_entry(&mut self, entry: MemoryEntry, is_project_specific: bool) -> Result<()> {
    self.ensure_dirs()?;
    self.acquire_lock()?;

    let updated = (|| -> Result<()> {
        if is_project_specific {
            let mut project = self.load_project()?.unwrap_or_else(AutoMemory::new);
            project.add(entry);
            self.save_project_locked(&project)
        } else {
            let mut global = self.load_global()?;
            global.add(entry);
            self.save_global_locked(&global)
        }
    })();

    let released = self.release_lock();
    updated.and(released)
}
/// Removes the entry with the given id from the project or global memory under
/// the lock; returns whether an entry was removed.
///
/// The file is only rewritten when something was removed. The lock is released
/// even when the update fails; an update error takes precedence over a release
/// error.
pub fn remove_entry(&mut self, id: &str, is_project_specific: bool) -> Result<bool> {
    self.acquire_lock()?;

    let outcome = (|| -> Result<bool> {
        if is_project_specific {
            let Some(mut project) = self.load_project()? else {
                return Ok(false);
            };
            let removed = project.remove(id);
            if removed {
                self.save_project_locked(&project)?;
            }
            Ok(removed)
        } else {
            let mut global = self.load_global()?;
            let removed = global.remove(id);
            if removed {
                self.save_global_locked(&global)?;
            }
            Ok(removed)
        }
    })();

    let released = self.release_lock();
    let removed = outcome?;
    released?;
    Ok(removed)
}
/// Writes the global memory via a temp file and rename. The caller is expected
/// to hold the lock.
fn save_global_locked(&self, memory: &AutoMemory) -> Result<()> {
    let target = self.global_memory_path();
    let serialized = serde_json::to_string_pretty(memory)?;
    let staging = target.with_extension("json.tmp");
    fs::write(&staging, serialized)?;
    fs::rename(&staging, &target)?;
    Ok(())
}

/// Writes the project memory via a temp file and rename. The caller is expected
/// to hold the lock. Fails with "no project root" without a project root.
fn save_project_locked(&self, memory: &AutoMemory) -> Result<()> {
    let target = match self.project_memory_path() {
        Some(path) => path,
        None => return Err(anyhow::anyhow!("no project root")),
    };
    let serialized = serde_json::to_string_pretty(memory)?;
    let staging = target.with_extension("json.tmp");
    fs::write(&staging, serialized)?;
    fs::rename(&staging, &target)?;
    Ok(())
}
}
/// Public entry point for the word-set Jaccard similarity used for duplicate
/// detection; forwards to `AutoMemory::calculate_similarity`.
pub fn calculate_similarity(a: &str, b: &str) -> f64 {
AutoMemory::calculate_similarity(a, b)
}
/// Extracts up to 15 lowercase keywords from `context` using rules only (no AI).
///
/// Three sources are combined:
/// 1. whitespace-separated words (trimmed of non-alphanumeric edges) that are
///    at least two bytes long and not stop words, plus known short tech terms;
/// 2. 2–4 character windows over the CJK ideographs found in the text that do
///    not contain a stop word;
/// 3. regex matches for file names, dotted paths, CamelCase identifiers,
///    snake_case identifiers, and sizes such as `512mb`.
///
/// The result is sorted by byte length (longest first) with a lexicographic
/// tie-break so the same input always yields the same keywords, then cut to 15.
pub fn extract_context_keywords(context: &str) -> Vec<String> {
    use std::collections::HashSet;

    let stop_words: HashSet<&str> = [
        "的", "了", "是", "在", "我", "有", "和", "就", "不", "人", "都", "一", "一个",
        "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好",
        "自己", "这", "他", "她", "它", "们", "那", "些", "什么", "怎么", "如何", "请",
        "能", "可以", "需要", "应该", "可能", "因为", "所以", "但是", "然后", "还是",
        "已经", "正在", "将要", "曾经", "一下", "一点", "一些", "所有", "每个", "任何",
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "could",
        "should", "may", "might", "can", "shall", "to", "of", "in", "for",
        "on", "with", "at", "by", "from", "as", "into", "through", "during",
        "before", "after", "above", "below", "between", "and", "but", "or",
        "not", "no", "so", "if", "then", "than", "too", "very", "just",
        "this", "that", "these", "those", "it", "its", "i", "me", "my",
        "we", "our", "you", "your", "he", "his", "she", "her", "they", "their",
        "please", "help", "need", "want", "make", "get", "let", "use",
    ].iter().copied().collect();
    let tech_patterns: HashSet<&str> = [
        "api", "cli", "gui", "tui", "web", "http", "json", "xml", "sql", "db",
        "git", "npm", "cargo", "rust", "js", "ts", "py", "go", "java", "cpp",
        "cpu", "gpu", "io", "fs", "os", "ui", "ux", "ai", "ml", "dl",
        "rs", "js", "ts", "py", "go", "java", "c", "h", "cpp", "hpp",
        "json", "yaml", "yml", "toml", "md", "txt", "html", "css", "scss",
        "bug", "fix", "add", "new", "old", "use", "run", "build", "test",
        "code", "data", "file", "dir", "path", "name", "type", "value",
    ].iter().copied().collect();

    let lower = context.to_lowercase();
    let mut keywords: HashSet<String> = HashSet::new();

    // Source 1: plain words. Tech terms are kept even when they are short or
    // also listed as stop words (e.g. "use", "c").
    for word in lower.split_whitespace() {
        let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric());
        let is_regular_word = cleaned.len() >= 2 && !stop_words.contains(cleaned);
        if is_regular_word || tech_patterns.contains(cleaned) {
            keywords.insert(cleaned.to_string());
        }
    }

    // Source 2: sliding windows over the CJK ideographs. Non-CJK characters
    // are filtered out first, so a window can span a gap in the original text.
    let chinese_chars: Vec<char> = lower
        .chars()
        .filter(|c| ('\u{4E00}'..='\u{9FFF}').contains(c))
        .collect();
    for window_size in 2..=4 {
        // `windows` yields nothing when the slice is shorter than the window.
        for window in chinese_chars.windows(window_size) {
            let phrase: String = window.iter().collect();
            if !stop_words.iter().any(|sw| phrase.contains(*sw)) {
                keywords.insert(phrase);
            }
        }
    }

    // Source 3: structural patterns. The CamelCase pattern needs the original
    // casing (it can never match lowercased text); every match is lowercased
    // on insertion so all keywords stay lowercase.
    let patterns: [(&str, &str); 5] = [
        (r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z]{1,4}", lower.as_str()),
        (r"[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*", lower.as_str()),
        (r"[A-Z][a-z]+[A-Z][a-zA-Z]*", context),
        (r"[a-z][a-z0-9]*_[a-z][a-z0-9_]*", lower.as_str()),
        (r"[0-9]+[kKmMgGtT][bB]?", lower.as_str()),
    ];
    for (pattern, haystack) in patterns {
        if let Ok(re) = regex::Regex::new(pattern) {
            for found in re.find_iter(haystack) {
                keywords.insert(found.as_str().to_lowercase());
            }
        }
    }

    let mut result: Vec<String> = keywords.into_iter().collect();
    // Longest first; the lexicographic tie-break removes the dependence on
    // `HashSet` iteration order when deciding which 15 keywords survive.
    result.sort_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));
    result.truncate(15);
    result
}
/// Scores how relevant `entry` is to the given context keywords, in `0.0..=1.0`.
///
/// The base score is the fraction of keywords that occur in the lowercased
/// entry content. A flat bonus of `0.2` is added when at least one tag
/// contains at least one keyword. Returns `0.0` when there are no keywords.
fn compute_relevance(entry: &MemoryEntry, context_keywords: &[String]) -> f64 {
    let total = context_keywords.len();
    if total == 0 {
        return 0.0;
    }

    let haystack = entry.content.to_lowercase();
    let mut hits = 0usize;
    for kw in context_keywords {
        if haystack.contains(kw.as_str()) {
            hits += 1;
        }
    }
    let keyword_score = hits as f64 / total as f64;

    let any_tag_hit = entry.tags.iter().any(|tag| {
        let tag_lower = tag.to_lowercase();
        context_keywords
            .iter()
            .any(|kw| tag_lower.contains(kw.as_str()))
    });
    let tag_bonus = if any_tag_hit { 0.2 } else { 0.0 };

    (keyword_score + tag_bonus).min(1.0)
}
/// Heuristically decides whether `new` signals a change relative to `old`.
///
/// Returns `true` when `new` contains an explicit change phrase (for example
/// "switched to"), or when `old` and `new` both contain the same action or
/// preference verb. Matching is case-sensitive substring search.
fn has_contradiction_signal(old: &str, new: &str) -> bool {
    // Phrases that announce a change on their own.
    const CHANGE_SIGNALS: &[&str] = &[
        "改用", "换成", "替换", "改为", "切换到", "迁移到",
        "不再使用", "弃用", "放弃", "取消",
        "switched to", "replaced", "migrated to", "changed to",
        "no longer", "deprecated", "abandoned",
    ];
    // Verbs that only count when both memories use the same one.
    const SHARED_VERBS: &[&str] = &[
        "决定使用", "选择使用", "采用", "使用",
        "decided to use", "chose", "using", "adopted",
        "偏好", "喜欢", "prefer", "like",
    ];

    if CHANGE_SIGNALS.iter().any(|signal| new.contains(*signal)) {
        return true;
    }
    SHARED_VERBS
        .iter()
        .any(|verb| old.contains(*verb) && new.contains(*verb))
}
/// Abstraction over components that turn free text into [`MemoryEntry`] values.
///
/// Implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait MemoryExtractor: Send + Sync {
/// Extracts memory entries from `text`. `session_id`, when present, is handed
/// to the implementation; how it is used is up to the implementor.
async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>>;
/// Returns the model name associated with this extractor.
fn model_name(&self) -> &str;
}
/// [`MemoryExtractor`] implementation backed by a chat provider.
pub struct AiMemoryExtractor {
/// Provider that receives the extraction chat request.
provider: Box<dyn crate::providers::Provider>,
/// Model name reported by `model_name`.
model: String,
}
impl AiMemoryExtractor {
/// Creates an extractor that stores the given provider and model name.
pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
Self { provider, model }
}
}
/// System prompt (in Chinese) for memory extraction.
///
/// It asks the model to pick out at most five long-term memories from a
/// conversation, classify each into one of six categories (decision,
/// preference, solution, finding, technical, structure), and answer with a
/// JSON object of the form `{"memories": [{category, content, importance}]}`.
///
/// NOTE(review): the examples are shown inside ```json fences while the last
/// line tells the model not to wrap its output in a code block; the parsing
/// side strips such fences, so both shapes are tolerated.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。你的任务是从对话中识别并提取值得长期记忆的关键信息。
记忆类型:
1. Decision(决策): 项目或技术选型的决定,如"决定使用 PostgreSQL"
2. Preference(偏好): 用户习惯或偏好,如"我喜欢用 vim"
3. Solution(解决方案): 解决问题的具体方法,如"通过添加 middleware 修复 bug"
4. Finding(发现): 重要发现或信息,如"API 端点在 /api/v2"
5. Technical(技术): 技术栈或框架信息,如"使用 React Query 做数据获取"
6. Structure(结构): 项目结构信息,如"入口文件是 src/index.ts"
提取原则:
- 只提取有价值、可复用的信息
- 避免提取临时性、一次性信息
- 避免提取过于具体的代码细节
- 每条记忆应简洁明确(一句话)
- 最多提取 5 条记忆
输出格式(严格 JSON):
```json
{
"memories": [
{
"category": "decision",
"content": "决定使用 PostgreSQL 作为主数据库",
"importance": 90
},
{
"category": "preference",
"content": "用户偏好 TypeScript 而非 JavaScript",
"importance": 70
}
]
}
```
如果没有值得记忆的内容,返回:
```json
{"memories": []}
```
直接输出 JSON,不要加代码块包裹。"#;
#[async_trait::async_trait]
impl MemoryExtractor for AiMemoryExtractor {
async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
use crate::providers::{ChatRequest, Message, MessageContent, Role};
let truncated_text = if text.len() > 4000 {
truncate_str(text, 4000)
} else {
text.to_string()
};
let request = ChatRequest {
messages: vec![Message {
role: Role::User,
content: MessageContent::Text(format!(
"请从以下对话中提取值得记忆的关键信息:\n\n{}",
truncated_text
)),
}],
tools: vec![], system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
think: false, max_tokens: 512, server_tools: vec![],
enable_caching: false,
};
let response = self.provider.chat(request).await?;
let response_text = response.content
.iter()
.filter_map(|block| {
if let crate::providers::ContentBlock::Text { text } = block {
Some(text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("");
parse_memory_response(&response_text, session_id)
}
fn model_name(&self) -> &str {
&self.model
}
}
fn parse_memory_response(json_text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
let cleaned = json_text
.trim()
.trim_start_matches("```json")
.trim_start_matches("```")
.trim_end_matches("```")
.trim();
#[derive(serde::Deserialize)]
struct MemoryResponse {
memories: Vec<MemoryItem>,
}
#[derive(serde::Deserialize)]
struct MemoryItem {
category: String,
content: String,
#[serde(default)]
importance: f64,
}
let parsed: MemoryResponse = serde_json::from_str(cleaned)?;
let entries = parsed.memories
.into_iter()
.filter_map(|item| {
let category = match item.category.to_lowercase().as_str() {
"decision" => MemoryCategory::Decision,
"preference" => MemoryCategory::Preference,
"solution" => MemoryCategory::Solution,
"finding" => MemoryCategory::Finding,
"technical" => MemoryCategory::Technical,
"structure" => MemoryCategory::Structure,
_ => return None, };
if item.content.len() < MIN_MEMORY_CONTENT_LENGTH {
return None;
}
let mut entry = MemoryEntry::new(
category,
item.content,
session_id.map(|s| s.to_string()),
);
if item.importance > 0.0 {
entry.importance = item.importance.clamp(0.0, 100.0);
}
Some(entry)
})
.collect();
Ok(deduplicate_entries(entries))
}
/// System prompt (in Chinese) for AI keyword extraction.
///
/// It asks the model to return 3–10 meaningful keywords (technical terms,
/// project names, concepts) from the user's input, skipping stop words, as a
/// JSON object of the form `{"keywords": [...]}`.
const KEYWORD_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个关键词提取助手。你的任务是从用户输入中提取有意义的关键词,用于检索相关记忆。
提取原则:
1. 只提取有实际意义的词汇(技术名词、项目名、概念等)
2. 过滤掉常见的停用词(的、是、在、我、你、the、a、is 等)
3. 保留专有名词和技术术语
4. 中英文混合输入时,两种语言的关键词都提取
5. 提取 3-10 个关键词
输出格式(严格 JSON):
```json
{
"keywords": ["数据库", "PostgreSQL", "优化", "查询"]
}
```
如果没有有意义的关键词,返回:
```json
{"keywords": []}
```
直接输出 JSON,不要加代码块包裹。"#;
pub async fn extract_keywords_with_ai(
context: &str,
provider: &dyn crate::providers::Provider,
) -> Result<Vec<String>> {
use crate::providers::{ChatRequest, Message, MessageContent, Role};
let truncated = if context.len() > 1000 {
truncate_str(context, 1000)
} else {
context.to_string()
};
let request = ChatRequest {
messages: vec![Message {
role: Role::User,
content: MessageContent::Text(format!(
"请从以下文本中提取关键词:\n\n{}",
truncated
)),
}],
tools: vec![],
system: Some(KEYWORD_EXTRACT_SYSTEM_PROMPT.to_string()),
think: false,
max_tokens: 256,
server_tools: vec![],
enable_caching: false,
};
let response = provider.chat(request).await?;
let response_text = response.content
.iter()
.filter_map(|block| {
if let crate::providers::ContentBlock::Text { text } = block {
Some(text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("");
parse_keyword_response(&response_text)
}
/// Parses the model's keyword reply (`{"keywords": [...]}`) into a list.
///
/// Surrounding Markdown code fences are stripped first. Keywords shorter than
/// two bytes are discarded.
///
/// # Errors
/// Returns an error when the cleaned text is not valid JSON of the expected shape.
fn parse_keyword_response(json_text: &str) -> Result<Vec<String>> {
    #[derive(serde::Deserialize)]
    struct KeywordResponse {
        keywords: Vec<String>,
    }

    let body = json_text.trim();
    let body = body.trim_start_matches("```json");
    let body = body.trim_start_matches("```");
    let body = body.trim_end_matches("```");
    let body = body.trim();

    let parsed: KeywordResponse = serde_json::from_str(body)?;
    let mut keywords = parsed.keywords;
    keywords.retain(|k| k.len() >= 2);
    Ok(keywords)
}
/// Extracts keywords with rules, optionally assisted by an AI provider.
///
/// The mode comes from `AiKeywordMode::from_env()`:
/// - `Never`: rule-based extraction only.
/// - `Auto`: rule-based first; the AI is consulted only when
///   `should_use_ai` says the rule-based result is too small, and its
///   keywords are appended to the rule-based ones.
/// - `Always`: the AI result is used; if no provider is given, the call
///   fails, or it yields nothing, rule-based extraction is used as a fallback
///   instead of returning an empty list.
///
/// The returned list contains no duplicates from the AI side and keeps a
/// stable order (rule-based keywords first, then new AI keywords in the order
/// the model returned them).
pub async fn extract_keywords_hybrid(
    context: &str,
    fast_provider: Option<&dyn crate::providers::Provider>,
) -> Vec<String> {
    let mode = AiKeywordMode::from_env();
    if mode == AiKeywordMode::Never {
        return extract_context_keywords(context);
    }

    // In `Always` mode the rule-based pass is skipped up front and only run
    // later if the AI path produces nothing.
    let mut keywords = match mode {
        AiKeywordMode::Always => Vec::new(),
        AiKeywordMode::Auto | AiKeywordMode::Never => extract_context_keywords(context),
    };
    if !mode.should_use_ai(keywords.len()) {
        return keywords;
    }

    if let Some(provider) = fast_provider {
        match extract_keywords_with_ai(context, provider).await {
            Ok(ai_keywords) if !ai_keywords.is_empty() => {
                log::debug!("AI extracted {} keywords: {:?}", ai_keywords.len(), ai_keywords);
                // Order-preserving union; the lists are tiny, so a linear
                // `contains` is fine.
                for kw in ai_keywords {
                    if !keywords.contains(&kw) {
                        keywords.push(kw);
                    }
                }
                return keywords;
            }
            Ok(_) => {
                log::debug!("AI returned no keywords, keeping rule-based results");
            }
            Err(e) => {
                log::warn!("AI keyword extraction failed: {}, keeping rule-based results", e);
            }
        }
    }

    // The AI was unavailable or unproductive.
    if mode == AiKeywordMode::Always {
        extract_context_keywords(context)
    } else {
        keywords
    }
}
/// System prompt (in Chinese) for condensing several related memories into one.
///
/// The model is asked to reply with a JSON object carrying `summary`,
/// `category`, and `importance`, or with an empty `summary` when nothing is
/// worth keeping.
const MEMORY_SUMMARY_SYSTEM_PROMPT: &str = r#"你是一个记忆摘要助手。你的任务是将多条相关记忆合并为一条精炼的摘要记忆。
摘要原则:
1. 保留核心信息,去除冗余细节
2. 使用简洁明确的一句话表达
3. 保留关键的技术名词和决策结论
4. 如果多条记忆主题相同,合并为一条综合性记忆
5. 优先保留高价值的决策和解决方案
输出格式(严格 JSON):
```json
{
"summary": "决定使用 PostgreSQL 作为主数据库,Redis 作为缓存层",
"category": "decision",
"importance": 90
}
```
如果没有值得保留的信息,返回:
```json
{"summary": "", "category": "", "importance": 0}
```
直接输出 JSON,不要加代码块包裹。"#;
/// System prompt (in Chinese) for judging whether two memories conflict.
///
/// The model is asked to classify the pair as `direct_conflict`,
/// `outdated_update`, `supplement`, or `no_conflict`, and to reply with a JSON
/// object carrying `conflict_type`, `should_replace`, `reason`, and `winner`.
const MEMORY_CONFLICT_SYSTEM_PROMPT: &str = r#"你是一个记忆冲突检测助手。你的任务是判断两条记忆是否矛盾或需要更新。
冲突类型:
1. 直接矛盾:两条记忆结论相反(如"使用 PostgreSQL" vs "使用 MySQL")
2. 过时更新:新记忆明确替换旧记忆(如"改用 Redis" 替换 "使用 Memcached")
3. 补充关系:新记忆补充旧记忆(如"PostgreSQL 版本为 15" 补充 "使用 PostgreSQL")
4. 无关关系:两条记忆主题不同,不冲突
输出格式(严格 JSON):
```json
{
"conflict_type": "direct_conflict",
"should_replace": true,
"reason": "两条记忆都是数据库选型决策,但选择了不同的数据库",
"winner": "new"
}
```
conflict_type 可选值:
- "direct_conflict": 直接矛盾,需要选择一条
- "outdated_update": 过时更新,新记忆替换旧记忆
- "supplement": 补充关系,两者可共存
- "no_conflict": 无关关系,不冲突
should_replace: true 表示需要替换旧记忆,false 表示保留两者
winner: "new" 表示新记忆胜出,"old" 表示旧记忆胜出(仅在 direct_conflict 时有意义)
直接输出 JSON,不要加代码块包裹。"#;
/// System prompt (in Chinese) for scoring the long-term value of a memory.
///
/// The model is asked to weigh reuse value, decision weight, timeliness, and
/// uniqueness, and to reply with a JSON object carrying `quality_score`
/// (0–100), `reason`, `should_keep`, and `suggested_category`.
///
/// The first evaluation criterion previously contained mis-encoded
/// replacement characters in the middle of the word "对话" (conversation);
/// the word is restored here so the model receives clean text.
const MEMORY_QUALITY_SYSTEM_PROMPT: &str = r#"你是一个记忆质量评估助手。你的任务是评估记忆的长期价值和重要程度。
评估维度:
1. 复用价值:这条信息在未来的对话中会被引用吗?
2. 决策权重:这是重要的项目决策还是次要细节?
3. 时效性:这条信息会很快过时吗?
4. 独特性:这条信息是否足够独特,不与其他记忆重叠?
评分标准:
- 90-100: 核心决策,长期有效,高复用价值(如数据库选型、框架选择)
- 70-89: 重要偏好或解决方案,中等复用价值
- 50-69: 有用的技术信息或发现,时效性中等
- 30-49: 一般性信息,复用价值较低
- 0-29: 过时或过于具体的细节,建议丢弃
输出格式(严格 JSON):
```json
{
"quality_score": 85,
"reason": "这是核心的技术选型决策,长期有效,高复用价值",
"should_keep": true,
"suggested_category": "decision"
}
```
直接输出 JSON,不要加代码块包裹。"#;
/// System prompt (in Chinese) for merging similar or related memories.
///
/// The model is asked to reply with a JSON object carrying `merged_content`,
/// `category`, `importance`, `merged_from_count`, and `summary_reason`, or
/// with an empty `merged_content` and a zero count when the memories should
/// not be merged.
const MEMORY_MERGE_SYSTEM_PROMPT: &str = r#"你是一个记忆合并助手。你的任务是将多条相似或相关的记忆合并为一条精炼的记忆。
合并原则:
1. 相同主题的记忆应合并为一条综合性记忆
2. 保留所有关键信息,去除重复内容
3. 使用简洁的一句话表达
4. 合并后的记忆应比原记忆更全面但更简洁
5. 如果记忆完全不相关,返回空结果表示不应合并
输出格式(严格 JSON):
```json
{
"merged_content": "使用 PostgreSQL 作为主数据库(版本15),Redis 作为缓存层,通过连接池优化性能",
"category": "technical",
"importance": 75,
"merged_from_count": 3,
"summary_reason": "三条记忆都与数据库和缓存技术栈相关,合并为一条综合性技术栈记忆"
}
```
如果不应合并,返回:
```json
{"merged_content": "", "category": "", "importance": 0, "merged_from_count": 0, "summary_reason": "记忆主题不同,不应合并"}
```
直接输出 JSON,不要加代码块包裹。"#;
/// Deserialized reply to the memory-summary prompt.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemorySummaryResult {
/// Summary text; an empty string is treated as "nothing to keep" by `summarize_memories`.
pub summary: String,
/// Category name as returned by the model; converted with `parse_category`.
pub category: String,
/// Importance as returned by the model; clamped to 0..=100 by `summarize_memories`.
pub importance: f64,
}
/// Deserialized reply to the memory-conflict prompt.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemoryConflictResult {
/// Conflict classification string as returned by the model
/// (the prompt lists `direct_conflict`, `outdated_update`, `supplement`, `no_conflict`).
pub conflict_type: String,
/// Whether the old memory should be replaced by the new one.
pub should_replace: bool,
/// Model-provided explanation.
pub reason: String,
/// `"new"` or `"old"` per the prompt; may be absent from the reply.
pub winner: Option<String>,
}
/// Deserialized reply to the memory-quality prompt.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemoryQualityResult {
/// Quality score as returned by the model (the prompt describes a 0–100 scale).
pub quality_score: f64,
/// Model-provided explanation.
pub reason: String,
/// Whether the model recommends keeping the memory.
pub should_keep: bool,
/// Category suggested by the model; may be absent from the reply.
pub suggested_category: Option<String>,
}
/// Deserialized reply to the memory-merge prompt.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct MemoryMergeResult {
/// Merged text; an empty string is treated as "do not merge" by `merge_memories`.
pub merged_content: String,
/// Category name as returned by the model; converted with `parse_category`.
pub category: String,
/// Importance as returned by the model; clamped to 0..=100 by `merge_memories`.
pub importance: f64,
/// Number of source memories the model reports having merged; zero is treated as "do not merge".
pub merged_from_count: usize,
/// Model-provided explanation of the merge decision.
pub summary_reason: String,
}
/// AI-backed helper offering summarization, conflict detection, quality
/// assessment, and merging of memory entries.
pub struct AiMemoryProcessor {
/// Provider that receives every chat request issued by this processor.
provider: Box<dyn crate::providers::Provider>,
/// Model name reported by `model_name`.
model: String,
}
impl AiMemoryProcessor {
    /// Creates a processor that stores the given provider and model name.
    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
        Self { provider, model }
    }

    /// Renders memories as one `[category] content` line each.
    fn render_memory_lines(memories: &[&MemoryEntry]) -> String {
        let mut lines = Vec::with_capacity(memories.len());
        for memory in memories {
            lines.push(format!("[{}] {}", memory.category.display_name(), memory.content));
        }
        lines.join("\n")
    }

    /// Sends one system + user prompt pair and returns the reply's text.
    async fn request_text(&self, system_prompt: &str, user_input: &str) -> Result<String> {
        let request = build_ai_request(system_prompt, user_input);
        let response = self.provider.chat(request).await?;
        Ok(extract_response_text(&response))
    }

    /// Condenses `memories` into a single new entry.
    ///
    /// Returns `Ok(None)` for an empty input or when the model returns an
    /// empty summary. The new entry has no source session and its importance
    /// is clamped to 0..=100.
    ///
    /// # Errors
    /// Propagates provider, JSON, and unknown-category failures.
    pub async fn summarize_memories(&self, memories: &[&MemoryEntry]) -> Result<Option<MemoryEntry>> {
        if memories.is_empty() {
            return Ok(None);
        }
        let prompt = format!(
            "请将以下记忆合并为一条精炼的摘要:\n\n{}",
            Self::render_memory_lines(memories)
        );
        let raw = self.request_text(MEMORY_SUMMARY_SYSTEM_PROMPT, &prompt).await?;
        let outcome = parse_json_response::<MemorySummaryResult>(&raw)?;
        if outcome.summary.is_empty() {
            return Ok(None);
        }
        let category = parse_category(&outcome.category)?;
        let mut entry = MemoryEntry::new(category, outcome.summary, None);
        entry.importance = outcome.importance.clamp(0.0, 100.0);
        Ok(Some(entry))
    }

    /// Asks the model whether `new` conflicts with `old`.
    ///
    /// # Errors
    /// Propagates provider and JSON failures.
    pub async fn detect_conflict(&self, old: &MemoryEntry, new: &MemoryEntry) -> Result<MemoryConflictResult> {
        let prompt = format!(
            "旧记忆:[{}] {}\n新记忆:[{}] {}\n\n请判断这两条记忆是否存在冲突。",
            old.category.display_name(),
            old.content,
            new.category.display_name(),
            new.content
        );
        let raw = self.request_text(MEMORY_CONFLICT_SYSTEM_PROMPT, &prompt).await?;
        parse_json_response(&raw)
    }

    /// Asks the model to rate the long-term value of `memory`.
    ///
    /// # Errors
    /// Propagates provider and JSON failures.
    pub async fn assess_quality(&self, memory: &MemoryEntry) -> Result<MemoryQualityResult> {
        let prompt = format!(
            "记忆内容:[{}] {}\n\n请评估这条记忆的质量和长期价值。",
            memory.category.display_name(),
            memory.content
        );
        let raw = self.request_text(MEMORY_QUALITY_SYSTEM_PROMPT, &prompt).await?;
        parse_json_response(&raw)
    }

    /// Asks the model to merge `memories` into a single new entry.
    ///
    /// Returns `Ok(None)` when fewer than two memories are given, or when the
    /// model returns empty content or a zero merge count. The new entry has
    /// no source session and its importance is clamped to 0..=100.
    ///
    /// # Errors
    /// Propagates provider, JSON, and unknown-category failures.
    pub async fn merge_memories(&self, memories: &[&MemoryEntry]) -> Result<Option<MemoryEntry>> {
        if memories.len() < 2 {
            return Ok(None);
        }
        let prompt = format!(
            "请判断以下记忆是否应该合并,如果应该则生成合并后的记忆:\n\n{}",
            Self::render_memory_lines(memories)
        );
        let raw = self.request_text(MEMORY_MERGE_SYSTEM_PROMPT, &prompt).await?;
        let outcome = parse_json_response::<MemoryMergeResult>(&raw)?;
        if outcome.merged_content.is_empty() || outcome.merged_from_count == 0 {
            return Ok(None);
        }
        let category = parse_category(&outcome.category)?;
        let mut entry = MemoryEntry::new(category, outcome.merged_content, None);
        entry.importance = outcome.importance.clamp(0.0, 100.0);
        Ok(Some(entry))
    }

    /// Returns the model name this processor was created with.
    pub fn model_name(&self) -> &str {
        &self.model
    }
}
/// Builds a single-turn chat request: one user message plus a system prompt.
///
/// The request carries no tools or server tools, disables thinking and
/// caching, and limits the reply to 512 tokens.
fn build_ai_request(system_prompt: &str, user_input: &str) -> crate::providers::ChatRequest {
    use crate::providers::{ChatRequest, Message, MessageContent, Role};

    let user_message = Message {
        role: Role::User,
        content: MessageContent::Text(user_input.to_owned()),
    };

    ChatRequest {
        messages: vec![user_message],
        tools: Vec::new(),
        system: Some(system_prompt.to_owned()),
        think: false,
        max_tokens: 512,
        server_tools: Vec::new(),
        enable_caching: false,
    }
}
/// Concatenates the text of every `ContentBlock::Text` block in `response`,
/// in order and without a separator. Other block kinds are skipped.
fn extract_response_text(response: &crate::providers::ChatResponse) -> String {
    let mut combined = String::new();
    for block in response.content.iter() {
        if let crate::providers::ContentBlock::Text { text } = block {
            combined.push_str(text);
        }
    }
    combined
}
fn parse_json_response<T: serde::de::DeserializeOwned>(json_text: &str) -> Result<T> {
let cleaned = json_text
.trim()
.trim_start_matches("```json")
.trim_start_matches("```")
.trim_end_matches("```")
.trim();
serde_json::from_str(cleaned).map_err(|e| anyhow::anyhow!("JSON parse error: {}", e))
}
/// Maps a category name (English, case-insensitive, or Chinese) to a
/// [`MemoryCategory`].
///
/// # Errors
/// Returns `"Unknown category: …"` for any other input.
fn parse_category(s: &str) -> Result<MemoryCategory> {
    let normalized = s.to_lowercase();
    let category = match normalized.as_str() {
        "decision" | "决策" => MemoryCategory::Decision,
        "preference" | "偏好" => MemoryCategory::Preference,
        "solution" | "解决方案" => MemoryCategory::Solution,
        "finding" | "发现" => MemoryCategory::Finding,
        "technical" | "技术" => MemoryCategory::Technical,
        "structure" | "结构" => MemoryCategory::Structure,
        _ => anyhow::bail!("Unknown category: {}", s),
    };
    Ok(category)
}
/// Feature switches and thresholds for the AI-assisted memory operations.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AiMemoryConfig {
/// Gates the AI path of `AutoMemory::generate_ai_summary`.
pub enable_summarization: bool,
/// NOTE(review): not read in this part of the file — presumably gates AI conflict detection; confirm against callers.
pub enable_conflict_detection: bool,
/// Gates `AutoMemory::assess_quality_with_ai`; when false that method returns 0 without assessing anything.
pub enable_quality_assessment: bool,
/// Gates `AutoMemory::merge_similar_with_ai`.
pub enable_merging: bool,
/// Minimum number of stored entries before `generate_ai_summary` takes the AI path.
pub summarize_threshold: usize,
/// Entries whose AI quality score is below this value are removed by `assess_quality_with_ai`.
pub quality_threshold: f64,
/// Minimum similarity for two entries to be grouped for merging.
pub merge_similarity_threshold: f64,
}
impl Default for AiMemoryConfig {
    /// Default profile: summarization, conflict detection, and merging are on;
    /// quality assessment is off.
    fn default() -> Self {
        Self {
            // Feature switches.
            enable_summarization: true,
            enable_conflict_detection: true,
            enable_merging: true,
            enable_quality_assessment: false,
            // Thresholds.
            summarize_threshold: 5,
            quality_threshold: 30.0,
            merge_similarity_threshold: 0.6,
        }
    }
}
impl AiMemoryConfig {
    /// Profile with every AI feature switched off.
    pub fn minimal() -> Self {
        Self {
            enable_summarization: false,
            enable_conflict_detection: false,
            enable_quality_assessment: false,
            enable_merging: false,
            summarize_threshold: 10,
            quality_threshold: 20.0,
            merge_similarity_threshold: 0.8,
        }
    }

    /// Profile with every AI feature switched on and lower summarize/merge
    /// thresholds than the default profile.
    pub fn aggressive() -> Self {
        Self {
            enable_summarization: true,
            enable_conflict_detection: true,
            enable_quality_assessment: true,
            enable_merging: true,
            summarize_threshold: 3,
            quality_threshold: 40.0,
            merge_similarity_threshold: 0.5,
        }
    }

    /// Builds a configuration from environment variables.
    ///
    /// `MEMORY_AI_ALL=true|1` selects [`Self::aggressive`]. Otherwise each
    /// field starts from [`Default`] and may be overridden by
    /// `MEMORY_AI_SUMMARY`, `MEMORY_AI_CONFLICT`, `MEMORY_AI_QUALITY`,
    /// `MEMORY_AI_MERGE`, `MEMORY_SUMMARY_THRESHOLD`,
    /// `MEMORY_QUALITY_THRESHOLD`, and `MEMORY_MERGE_THRESHOLD`.
    ///
    /// Boolean values are compared case-insensitively after trimming, matching
    /// how `AiKeywordMode::from_env` treats its variable: a default-on flag is
    /// disabled only by `false`/`0`, a default-off flag is enabled only by
    /// `true`/`1`. Unset or unparsable numeric values fall back to the default.
    pub fn from_env() -> Self {
        /// Reads a boolean flag, keeping `default` when the variable is unset.
        fn env_flag(name: &str, default: bool) -> bool {
            match std::env::var(name) {
                Ok(raw) => {
                    let value = raw.trim().to_ascii_lowercase();
                    if default {
                        !matches!(value.as_str(), "false" | "0")
                    } else {
                        matches!(value.as_str(), "true" | "1")
                    }
                }
                Err(_) => default,
            }
        }

        /// Reads and parses a value, keeping `default` when unset or invalid.
        fn env_parse<T: std::str::FromStr>(name: &str, default: T) -> T {
            std::env::var(name)
                .ok()
                .and_then(|raw| raw.trim().parse().ok())
                .unwrap_or(default)
        }

        if env_flag("MEMORY_AI_ALL", false) {
            return Self::aggressive();
        }

        let defaults = Self::default();
        Self {
            enable_summarization: env_flag("MEMORY_AI_SUMMARY", defaults.enable_summarization),
            enable_conflict_detection: env_flag(
                "MEMORY_AI_CONFLICT",
                defaults.enable_conflict_detection,
            ),
            enable_quality_assessment: env_flag(
                "MEMORY_AI_QUALITY",
                defaults.enable_quality_assessment,
            ),
            enable_merging: env_flag("MEMORY_AI_MERGE", defaults.enable_merging),
            summarize_threshold: env_parse(
                "MEMORY_SUMMARY_THRESHOLD",
                defaults.summarize_threshold,
            ),
            quality_threshold: env_parse("MEMORY_QUALITY_THRESHOLD", defaults.quality_threshold),
            merge_similarity_threshold: env_parse(
                "MEMORY_MERGE_THRESHOLD",
                defaults.merge_similarity_threshold,
            ),
        }
    }
}
impl AutoMemory {
/// Adds a memory, first removing at most one existing entry it supersedes.
///
/// Nothing happens when `has_similar` already reports a similar entry.
/// With a processor, every same-category entry whose similarity to the new
/// content exceeds 0.3 is checked by the AI in order, and the first one the
/// AI says should be replaced is removed. Without a processor, the
/// rule-based `find_conflict` is used instead.
///
/// # Errors
/// Propagates AI conflict-detection failures; the memory is not added then.
pub async fn add_memory_with_ai_conflict(
    &mut self,
    category: MemoryCategory,
    content: String,
    source_session: Option<String>,
    processor: Option<&AiMemoryProcessor>,
) -> Result<()> {
    if self.has_similar(&content) {
        return Ok(());
    }
    let new_entry = MemoryEntry::new(category, content.clone(), source_session);

    match processor {
        Some(ai) => {
            let content_lower = content.to_lowercase();
            let candidate_indices: Vec<usize> = self
                .entries
                .iter()
                .enumerate()
                .filter(|(_, existing)| {
                    existing.category == category
                        && Self::calculate_similarity(
                            &existing.content.to_lowercase(),
                            &content_lower,
                        ) > 0.3
                })
                .map(|(idx, _)| idx)
                .collect();

            for idx in candidate_indices {
                let verdict = ai.detect_conflict(&self.entries[idx], &new_entry).await?;
                if verdict.should_replace {
                    log::debug!("AI detected conflict: {} -> replacing '{}' with '{}'",
                        verdict.conflict_type, self.entries[idx].content, content);
                    self.entries.remove(idx);
                    self.invalidate_index();
                    // Only one entry is replaced, so later indices never go stale.
                    break;
                }
            }
        }
        None => {
            if let Some(conflict_idx) = self.find_conflict(&content, category) {
                self.entries.remove(conflict_idx);
                self.invalidate_index();
            }
        }
    }

    self.add(new_entry);
    Ok(())
}
pub async fn assess_quality_with_ai(
&mut self,
processor: &AiMemoryProcessor,
config: &AiMemoryConfig,
) -> Result<usize> {
if !config.enable_quality_assessment {
return Ok(0);
}
let indices_to_assess: Vec<usize> = self.entries
.iter()
.enumerate()
.filter(|(_, entry)| !entry.is_manual)
.map(|(i, _)| i)
.collect();
let mut to_remove: Vec<usize> = Vec::new();
let mut importance_updates: Vec<(usize, f64)> = Vec::new();
for i in indices_to_assess {
let entry = &self.entries[i];
let result = processor.assess_quality(entry).await?;
if !result.should_keep || result.quality_score < config.quality_threshold {
log::debug!("AI quality assessment: removing '{}' (score: {:.1}, reason: {})",
entry.content, result.quality_score, result.reason);
to_remove.push(i);
} else {
importance_updates.push((i, result.quality_score));
}
}
for (i, score) in importance_updates {
self.entries[i].importance = score;
}
let removed_count = to_remove.len();
for idx in to_remove.into_iter().rev() {
self.entries.remove(idx);
}
if removed_count > 0 {
self.invalidate_index();
self.prune();
}
Ok(removed_count)
}
/// Groups similar entries and asks the AI to merge each group into one entry.
///
/// For each not-yet-merged entry, later entries whose similarity to it is at
/// least `config.merge_similarity_threshold` form a group. When the AI returns
/// a merged entry, the group members are removed and the merged entry is
/// appended. Returns the net number of entries eliminated (group size minus
/// one per merged group), or 0 when merging is disabled or there are fewer
/// than two entries.
///
/// NOTE(review): unlike `assess_quality_with_ai`, manual entries are not
/// exempt here — confirm that is intended.
///
/// # Errors
/// Propagates the first AI failure; no entry is modified in that case.
pub async fn merge_similar_with_ai(
    &mut self,
    processor: &AiMemoryProcessor,
    config: &AiMemoryConfig,
) -> Result<usize> {
    if !config.enable_merging || self.entries.len() < 2 {
        return Ok(0);
    }

    let entry_count = self.entries.len();
    // Lowercase each content once instead of once per compared pair.
    let lowered: Vec<String> = self
        .entries
        .iter()
        .map(|e| e.content.to_lowercase())
        .collect();
    // `consumed[i]` is set once entry `i` has been folded into a merged entry.
    let mut consumed = vec![false; entry_count];
    let mut merged_entries: Vec<MemoryEntry> = Vec::new();
    let mut merged_count = 0;

    for i in 0..entry_count {
        if consumed[i] {
            continue;
        }
        let mut group: Vec<usize> = vec![i];
        for j in (i + 1)..entry_count {
            if consumed[j] {
                continue;
            }
            let sim = Self::calculate_similarity(&lowered[i], &lowered[j]);
            if sim >= config.merge_similarity_threshold {
                group.push(j);
            }
        }
        if group.len() < 2 {
            continue;
        }

        let members: Vec<&MemoryEntry> = group.iter().map(|&idx| &self.entries[idx]).collect();
        if let Some(merged) = processor.merge_memories(&members).await? {
            log::debug!("AI merged {} memories into: '{}'",
                group.len(), merged.content);
            merged_entries.push(merged);
            merged_count += group.len() - 1;
            for idx in group {
                consumed[idx] = true;
            }
        }
    }

    if merged_count > 0 {
        // `retain` visits elements in order exactly once, so a running
        // position lines up with the `consumed` flags.
        let mut position = 0;
        self.entries.retain(|_| {
            let keep = !consumed[position];
            position += 1;
            keep
        });
        self.entries.extend(merged_entries);
        self.invalidate_index();
        self.prune();
    }
    Ok(merged_count)
}
/// Renders the stored memories as a prompt-ready summary, using the AI to
/// condense each category when possible.
///
/// The AI path is taken only when summarization is enabled, a processor is
/// given, and at least `summarize_threshold` entries exist; otherwise the
/// result of `generate_contextual_summary("", max_entries)` is returned.
/// On the AI path, entries are grouped by category in order of first
/// appearance (so the output is stable across runs), and the first
/// `max_entries` entries of each group are summarized. If the AI returns no
/// summary for a group, those entries are listed individually instead.
/// Returns an empty string when there are no entries.
///
/// # Errors
/// Propagates AI summarization failures.
pub async fn generate_ai_summary(
    &self,
    max_entries: usize,
    processor: Option<&AiMemoryProcessor>,
    config: Option<&AiMemoryConfig>,
) -> Result<String> {
    if self.entries.is_empty() {
        return Ok(String::new());
    }
    let default_config = AiMemoryConfig::default();
    let config = config.unwrap_or(&default_config);

    let ai = match processor {
        Some(p)
            if config.enable_summarization
                && self.entries.len() >= config.summarize_threshold =>
        {
            p
        }
        _ => return Ok(self.generate_contextual_summary("", max_entries)),
    };

    // Group by category, preserving first-appearance order. There are only a
    // handful of categories, so a linear lookup is fine.
    let mut by_category: Vec<(MemoryCategory, Vec<&MemoryEntry>)> = Vec::new();
    for entry in &self.entries {
        let slot = by_category
            .iter()
            .position(|(cat, _)| *cat == entry.category);
        match slot {
            Some(i) => by_category[i].1.push(entry),
            None => by_category.push((entry.category, vec![entry])),
        }
    }

    let mut summary = String::from("【跨会话记忆 (AI摘要)】\n\n");
    for (cat, entries) in by_category {
        let top_entries: Vec<&MemoryEntry> = entries.into_iter().take(max_entries).collect();
        let ai_summary = ai.summarize_memories(&top_entries).await?;

        summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
        match ai_summary {
            Some(condensed) => {
                summary.push_str(&format!(" {}\n\n", condensed.content));
            }
            None => {
                for entry in top_entries {
                    summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
                }
                summary.push('\n');
            }
        }
    }
    Ok(summary)
}
}
/// Rule-based memory detection: scans `text` for trigger phrases per category
/// and extracts the surrounding sentence for each hit.
///
/// Extracted content that is empty or shorter than
/// `MIN_MEMORY_CONTENT_LENGTH` bytes is skipped. The collected entries are
/// passed through `deduplicate_entries` before being returned.
pub fn detect_memories_fallback(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
    let text_lower = text.to_lowercase();
    // Trigger phrases are lowercase because they are matched against `text_lower`.
    let patterns: Vec<(MemoryCategory, Vec<&str>)> = vec![
        (MemoryCategory::Decision, vec![
            "决定", "决定使用", "选择使用", "采用", "decided to", "decision to",
            "chose to", "adopted", "选定", "最终选择",
        ]),
        (MemoryCategory::Preference, vec![
            "我喜欢", "我偏好", "prefer to", "i prefer", "my preference is",
            "习惯用", "我习惯", "usually prefer", "偏好使用",
        ]),
        (MemoryCategory::Solution, vec![
            "修复了", "解决了", "fixed by", "solved by", "resolved by",
            "通过添加", "通过修改", "通过删除", "解决方法是",
        ]),
        (MemoryCategory::Finding, vec![
            "发现", "注意到", "found that", "noticed that", "discovered",
            "观察到", "api 端点", "位于", "located at", "关键发现",
        ]),
        (MemoryCategory::Technical, vec![
            "使用框架", "using framework", "built with", "基于",
            "框架是", "技术栈", "依赖库",
        ]),
        (MemoryCategory::Structure, vec![
            "入口文件", "entry point is", "主文件是", "main file",
            "配置文件", "config file", "核心文件",
        ]),
    ];

    let mut entries = Vec::new();
    for (category, keywords) in patterns {
        let hits = keywords
            .into_iter()
            .filter(|keyword| text_lower.contains(*keyword));
        for keyword in hits {
            let content = extract_memory_content(text, keyword);
            if content.is_empty() || content.len() < MIN_MEMORY_CONTENT_LENGTH {
                continue;
            }
            entries.push(MemoryEntry::new(
                category,
                content,
                session_id.map(|s| s.to_string()),
            ));
        }
    }
    deduplicate_entries(entries)
}
/// Detects memories in `text` using the rule-based detector only.
///
/// Thin wrapper that forwards both arguments to [`detect_memories_fallback`].
pub fn detect_memories_from_text(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
detect_memories_fallback(text, session_id)
}
pub async fn detect_memories_with_ai(
text: &str,
session_id: Option<&str>,
extractor: Option<&dyn MemoryExtractor>,
) -> Result<Vec<MemoryEntry>> {
if let Some(ai_extractor) = extractor {
match ai_extractor.extract(text, session_id).await {
Ok(entries) if !entries.is_empty() => {
return Ok(entries);
}
Ok(_) => {
}
Err(_) => {
}
}
}
Ok(detect_memories_fallback(text, session_id))
}
/// Removes near-duplicate entries and caps the result at `MAX_DETECTED_ENTRIES`.
///
/// Entries are considered longest-content first (ties keep their input
/// order). An entry is dropped when its lowercased content equals, or has a
/// similarity of at least 0.8 with, an entry that was already kept.
fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
    let mut candidates = entries;
    // Stable sort, longest content first.
    candidates.sort_by_key(|e| std::cmp::Reverse(e.content.len()));

    let mut kept: Vec<MemoryEntry> = Vec::new();
    // Lowercased content of `kept`, index-aligned, so it is computed once per entry.
    let mut kept_lower: Vec<String> = Vec::new();

    for candidate in candidates {
        let lowered = candidate.content.to_lowercase();
        let duplicate = kept_lower
            .iter()
            .any(|seen| *seen == lowered || calculate_similarity(seen, &lowered) >= 0.8);
        if !duplicate {
            kept.push(candidate);
            kept_lower.push(lowered);
        }
        if kept.len() >= MAX_DETECTED_ENTRIES {
            break;
        }
    }
    kept
}
/// Extracts the sentence around the first case-insensitive occurrence of
/// `keyword` in `text`.
///
/// The sentence starts after the previous `.`, newline, or `。` (or at the
/// start of the text) and ends just after the next such marker. When no
/// marker follows, it ends at most `MAX_MEMORY_CONTENT_LENGTH` bytes after
/// the keyword. The result is trimmed; content rejected by
/// `is_low_quality_memory` yields an empty string, as does a missing keyword.
/// Content longer than `MAX_MEMORY_CONTENT_LENGTH` bytes is shortened and
/// suffixed with `...`.
///
/// All slicing is done on character boundaries, so multi-byte text (e.g.
/// Chinese) cannot cause a panic.
fn extract_memory_content(text: &str, keyword: &str) -> String {
    const SENTENCE_END_MARKERS: [char; 3] = ['.', '\n', '。'];
    const ELLIPSIS: &str = "...";

    let text_lower = text.to_lowercase();
    let keyword_lower = keyword.to_lowercase();
    let Some(found) = text_lower.find(&keyword_lower) else {
        return String::new();
    };

    // `found` is an offset into the lowercased copy. Lowercasing can change
    // byte lengths for some characters, so clamp the offset into `text` and
    // move it back to a character boundary before slicing the original.
    let mut pos = found.min(text.len());
    while !text.is_char_boundary(pos) {
        pos -= 1;
    }

    // Start right after the previous sentence marker, if any.
    let start = text[..pos]
        .rfind(SENTENCE_END_MARKERS)
        .map(|i| match text[i..].char_indices().nth(1) {
            Some((next_idx, _)) => i + next_idx,
            None => pos,
        })
        .unwrap_or(0);

    // End right after the next sentence marker, or after a bounded window.
    let end = match text[pos..].find(SENTENCE_END_MARKERS) {
        Some(i) => {
            let marker_pos = pos + i;
            match text[marker_pos..].char_indices().nth(1) {
                Some((next_idx, _)) => marker_pos + next_idx,
                None => text.len(),
            }
        }
        None => {
            let mut boundary = (pos + MAX_MEMORY_CONTENT_LENGTH).min(text.len());
            while boundary > pos && !text.is_char_boundary(boundary) {
                boundary -= 1;
            }
            boundary
        }
    };

    if start >= end || end > text.len() {
        return String::new();
    }

    let content = text[start..end].trim();
    if is_low_quality_memory(content) {
        return String::new();
    }
    if content.len() <= MAX_MEMORY_CONTENT_LENGTH {
        return content.to_string();
    }

    // Same byte budget as before (limit - 3 overall, ellipsis included), but
    // the cut is moved back to a character boundary instead of panicking.
    let mut cut = MAX_MEMORY_CONTENT_LENGTH - 3 - ELLIPSIS.len();
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}{}", &content[..cut], ELLIPSIS)
}
/// Returns `true` when `content` should not be stored as a memory.
///
/// Rejected content: shorter than `MIN_MEMORY_CONTENT_LENGTH` bytes; contains
/// box-drawing characters; starts with one of the listed icon emoji; contains
/// one of the memory-summary marker strings; is a short (`< 30` bytes) bullet
/// starting with `"- "`; or has fewer alphabetic characters than a quarter
/// (integer division) of its character count.
fn is_low_quality_memory(content: &str) -> bool {
    const BOX_DRAWING: [char; 12] = ['│', '├', '└', '┌', '┐', '─', '═', '║', '╔', '╗', '╚', '╝'];
    const ICON_PREFIXES: [&str; 11] = [
        "🎯", "🔧", "💡", "📚", "🏗", "👤", "⭐", "📝", "✅", "❌", "⚠",
    ];
    const SUMMARY_MARKERS: [&str; 4] = [
        "【自动记忆摘要】", "[ACCUMULATED MEMORY]", "记忆统计", "memory.json",
    ];

    if content.len() < MIN_MEMORY_CONTENT_LENGTH {
        return true;
    }

    // Table or tree output.
    if content.contains(&BOX_DRAWING[..]) {
        return true;
    }

    // Lines that start with one of the listed icons.
    let leading = content.chars().next().unwrap_or(' ');
    let leading_is_symbol =
        leading > '\u{FF}' && !leading.is_alphanumeric() && !leading.is_ascii_punctuation();
    if leading_is_symbol && ICON_PREFIXES.iter().any(|icon| content.starts_with(*icon)) {
        return true;
    }

    // Echoes of the memory summary itself.
    if SUMMARY_MARKERS.iter().any(|marker| content.contains(*marker)) {
        return true;
    }

    // Short bullet items.
    if content.len() < 30 && content.starts_with("- ") {
        return true;
    }

    // Mostly digits, punctuation, or whitespace.
    let (letters, total) = content.chars().fold((0usize, 0usize), |(l, t), c| {
        (l + usize::from(c.is_alphabetic()), t + 1)
    });
    total > 0 && letters < total / 4
}
/// Outcome of [`summarize_up_to`].
#[derive(Debug, Clone)]
pub struct RewindResult {
/// Number of messages in the input.
pub original_count: usize,
/// Number of messages in `new_messages`.
pub new_count: usize,
/// Index at which the conversation was split.
pub rewind_index: usize,
/// Summary text of the messages before the index; `None` when the index was 0 and nothing was summarized.
pub summary: Option<String>,
/// Resulting message list: the summary message followed by the kept messages, or a copy of the input when the index was 0.
pub new_messages: Vec<Message>,
}
/// Replaces the messages before `index` with a single summary message.
///
/// With `index == 0` the messages are returned unchanged and no summary is
/// produced. Otherwise `messages[..index]` is summarized — by `compressor`
/// when given, else by `generate_simple_summary` — and the result is the
/// summary message followed by `messages[index..]`.
///
/// # Errors
/// Fails when `index` is not a valid position in `messages`, and propagates
/// compressor failures.
pub async fn summarize_up_to(
    messages: &[Message],
    index: usize,
    compressor: Option<&dyn crate::compress::Compressor>,
) -> Result<RewindResult> {
    let original_count = messages.len();
    if index >= original_count {
        anyhow::bail!("rewind index {} out of bounds (messages: {})", index, messages.len());
    }
    if index == 0 {
        return Ok(RewindResult {
            original_count,
            new_count: original_count,
            rewind_index: 0,
            summary: None,
            new_messages: messages.to_vec(),
        });
    }

    let (to_summarize, to_keep) = messages.split_at(index);
    let summary_text = match compressor {
        Some(comp) => {
            let config = crate::compress::CompressionConfig::default();
            comp.summarize(to_summarize, &config).await?.summary
        }
        None => generate_simple_summary(to_summarize),
    };
    let summary = Some(summary_text);

    let mut new_messages: Vec<Message> = Vec::with_capacity(to_keep.len() + 1);
    new_messages.push(create_summary_message(&summary, to_summarize.len()));
    new_messages.extend(to_keep.iter().cloned());

    Ok(RewindResult {
        original_count,
        new_count: new_messages.len(),
        rewind_index: index,
        summary,
        new_messages,
    })
}
fn create_summary_message(summary: &Option<String>, original_count: usize) -> Message {
let content = match summary {
Some(s) => format!("[对话摘要 - 原 {} 条消息]\n\n{}", original_count, s),
None => format!("[对话摘要 - 原 {} 条消息已压缩]", original_count),
};
Message {
role: crate::providers::Role::User,
content: crate::providers::MessageContent::Text(content),
}
}
/// Builds a summary without AI from the first lines of user text messages.
///
/// Only user messages with text content whose first line is longer than 20
/// bytes contribute a topic. Topics longer than 100 bytes are shortened to at
/// most 100 bytes including a trailing `...`; the cut is moved back to a
/// character boundary so multi-byte text cannot cause a slicing panic.
/// Up to five topics are joined with `" | "`; with more, only the first is
/// shown together with the topic count.
fn generate_simple_summary(messages: &[Message]) -> String {
    const MAX_TOPIC_BYTES: usize = 100;
    const ELLIPSIS: &str = "...";

    let mut parts: Vec<String> = Vec::new();
    for msg in messages {
        if msg.role == crate::providers::Role::User {
            let text = match &msg.content {
                crate::providers::MessageContent::Text(t) => t,
                _ => continue,
            };
            let first_line = text.lines().next().unwrap_or("");
            if first_line.len() > 20 {
                let topic = if first_line.len() > MAX_TOPIC_BYTES {
                    let mut cut = MAX_TOPIC_BYTES - ELLIPSIS.len();
                    // Index 0 is always a boundary, so this loop terminates.
                    while !first_line.is_char_boundary(cut) {
                        cut -= 1;
                    }
                    format!("{}{}", &first_line[..cut], ELLIPSIS)
                } else {
                    first_line.to_string()
                };
                parts.push(topic);
            }
        }
    }

    if parts.is_empty() {
        "对话已压缩".to_string()
    } else if parts.len() <= 5 {
        parts.join(" | ")
    } else {
        format!("{} ... (共 {} 个话题)", parts[0], parts.len())
    }
}
/// Namespace for vector-similarity helpers.
pub struct SemanticUtils;

impl SemanticUtils {
    /// Cosine similarity of two vectors.
    ///
    /// Returns `0.0` when the vectors are empty, differ in length, or either
    /// has zero magnitude.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.is_empty() || a.len() != b.len() {
            return 0.0;
        }

        let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_a = magnitude(a);
        let norm_b = magnitude(b);
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        dot_product / (norm_a * norm_b)
    }
}
/// In-memory TF-IDF index over memory entry contents.
pub struct TfIdfSearch {
/// Per-document term frequencies, keyed by the document's content string.
doc_word_freq: HashMap<String, HashMap<String, f32>>,
/// Document count used as the numerator of the IDF computation.
total_docs: usize,
/// Inverse document frequency per term, filled by `compute_idf`.
idf_cache: HashMap<String, f32>,
}
impl TfIdfSearch {
/// Creates an empty index with no documents and no cached IDF values.
pub fn new() -> Self {
    Self {
        total_docs: 0,
        doc_word_freq: HashMap::new(),
        idf_cache: HashMap::new(),
    }
}
/// Rebuilds the index from the entries of `memory`.
///
/// The previous state is cleared, each entry's content is tokenized and its
/// term frequencies stored under the content string, and the IDF cache is
/// recomputed. Documents are keyed by content, so entries with identical
/// content collapse into one document; `total_docs` is therefore taken from
/// the number of stored documents so that it stays consistent with the
/// per-term document counts used for IDF.
pub fn index(&mut self, memory: &AutoMemory) {
    self.clear();
    for entry in &memory.entries {
        let words = self.tokenize(&entry.content);
        let word_freq = self.compute_word_freq(&words);
        self.doc_word_freq.insert(entry.content.clone(), word_freq);
    }
    self.total_docs = self.doc_word_freq.len();
    self.compute_idf();
}
/// Splits `text` into lowercase tokens.
///
/// For each whitespace-separated word (trimmed of non-alphanumeric edges):
/// the word itself is emitted when longer than one byte; if it contains any
/// CJK character, each CJK character is emitted as its own token, followed by
/// every adjacent character pair in which at least one character is CJK.
fn tokenize(&self, text: &str) -> Vec<String> {
    let lower = text.to_lowercase();
    let mut tokens = Vec::new();

    for raw in lower.split_whitespace() {
        let word = raw.trim_matches(|c: char| !c.is_alphanumeric());
        if word.len() > 1 {
            tokens.push(word.to_string());
        }

        let chars: Vec<char> = word.chars().collect();
        if !chars.iter().any(|c| Self::is_cjk(*c)) {
            continue;
        }

        // Single CJK characters.
        tokens.extend(
            chars
                .iter()
                .copied()
                .filter(|c| Self::is_cjk(*c))
                .map(String::from),
        );
        // Bigrams that touch at least one CJK character.
        tokens.extend(
            chars
                .windows(2)
                .filter(|pair| Self::is_cjk(pair[0]) || Self::is_cjk(pair[1]))
                .map(|pair| pair.iter().collect::<String>()),
        );
    }
    tokens
}
fn is_cjk(c: char) -> bool {
matches!(c,
'\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}' | '\u{F900}'..='\u{FAFF}' | '\u{3000}'..='\u{303F}' | '\u{3040}'..='\u{309F}' | '\u{30A0}'..='\u{30FF}' )
}
fn compute_word_freq(&self, words: &[String]) -> HashMap<String, f32> {
let total = words.len() as f32;
let mut freq = HashMap::new();
for word in words {
*freq.entry(word.clone()).or_insert(0.0) += 1.0;
}
for (_, count) in freq.iter_mut() {
*count /= total;
}
freq
}
fn compute_idf(&mut self) {
let mut word_doc_count: HashMap<String, usize> = HashMap::new();
for word_freq in &self.doc_word_freq {
for word in word_freq.1.keys() {
*word_doc_count.entry(word.clone()).or_insert(0) += 1;
}
}
for (word, count) in word_doc_count {
let idf = (self.total_docs as f32 / count as f32).ln();
self.idf_cache.insert(word, idf);
}
}
pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<(String, f32)> {
let query_words = self.tokenize(query);
let query_freq = self.compute_word_freq(&query_words);
let mut results: Vec<(String, f32)> = Vec::new();
for (doc, doc_freq) in &self.doc_word_freq {
let similarity = self.compute_similarity(&query_freq, doc_freq);
if similarity > 0.0 {
results.push((doc.clone(), similarity));
}
}
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
if let Some(max) = limit {
results.into_iter().take(max).collect()
} else {
results
}
}
fn compute_similarity(&self, query_freq: &HashMap<String, f32>, doc_freq: &HashMap<String, f32>) -> f32 {
let mut similarity = 0.0;
for (word, tf_query) in query_freq {
if let Some(tf_doc) = doc_freq.get(word)
&& let Some(idf) = self.idf_cache.get(word) {
similarity += tf_query * idf * tf_doc * idf;
}
}
similarity
}
pub fn clear(&mut self) {
self.doc_word_freq.clear();
self.idf_cache.clear();
self.total_docs = 0;
}
}
impl Default for TfIdfSearch {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A new automatic Decision entry is expected to carry importance 90 and
/// not be flagged as manual.
#[test]
fn test_memory_entry_creation() {
    let content = String::from("Decided to use PostgreSQL for database");
    let session = Some(String::from("session-123"));
    let created = MemoryEntry::new(MemoryCategory::Decision, content, session);

    assert_eq!(created.category, MemoryCategory::Decision);
    assert_eq!(created.importance, 90.0);
    assert!(!created.is_manual);
}
/// Each `mark_referenced` call is expected to raise importance by 2.
#[test]
fn test_memory_reference_increase() {
    let mut finding = MemoryEntry::new(
        MemoryCategory::Finding,
        String::from("API endpoint is at /api/v2"),
        None,
    );
    assert_eq!(finding.importance, 60.0);

    finding.mark_referenced();
    assert_eq!(finding.importance, 62.0);

    // Two further references.
    for _ in 0..2 {
        finding.mark_referenced();
    }
    assert_eq!(finding.importance, 66.0);
}
/// Adding more entries than `max_entries` must never leave the store over
/// capacity.
#[test]
fn test_auto_memory_add_and_prune() {
    let mut store = AutoMemory::new();
    store.max_entries = 5;

    let notes = (0..10)
        .map(|i| MemoryEntry::new(MemoryCategory::Technical, format!("Note {}", i), None));
    for note in notes {
        store.add(note);
    }

    assert!(store.entries.len() <= store.max_entries);
}
/// Submitting the same content twice through `add_memory` keeps one entry.
#[test]
fn test_duplicate_detection() {
    let mut store = AutoMemory::new();

    for _ in 0..2 {
        store.add_memory(MemoryCategory::Decision, "Use PostgreSQL".to_string(), None);
    }

    assert_eq!(store.entries.len(), 1);
}
/// Each sample sentence must yield at least one detected memory, and the
/// first one must have the expected category.
#[test]
fn test_memory_detection() {
    let cases = [
        ("我决定使用 React 作为前端框架", MemoryCategory::Decision),
        ("解决了认证问题,通过添加 token refresh 机制", MemoryCategory::Solution),
        ("我偏好使用 TypeScript 进行开发", MemoryCategory::Preference),
    ];

    for (text, expected) in cases {
        let detected = detect_memories_from_text(text, None);
        assert!(!detected.is_empty());
        assert_eq!(detected[0].category, expected);
    }
}
/// Checks the relative ordering of default importances between categories.
#[test]
fn test_category_importance() {
    let decision = MemoryCategory::Decision.default_importance();
    let structure = MemoryCategory::Structure.default_importance();
    assert!(decision > structure);

    let solution = MemoryCategory::Solution.default_importance();
    let technical = MemoryCategory::Technical.default_importance();
    assert!(solution > technical);
}
/// `top_n(2)` on three entries returns two, with the Decision entry first.
#[test]
fn test_top_n_entries() {
    let mut store = AutoMemory::new();
    let seeds = [
        (MemoryCategory::Decision, "Decision 1"),
        (MemoryCategory::Finding, "Finding 1"),
        (MemoryCategory::Structure, "Structure 1"),
    ];
    for (category, content) in seeds {
        store.add(MemoryEntry::new(category, content.into(), None));
    }

    let top = store.top_n(2);
    assert_eq!(top.len(), 2);
    assert_eq!(top[0].category, MemoryCategory::Decision);
}
/// Covers identical, disjoint, partially overlapping and empty inputs.
#[test]
fn test_similarity_calculation() {
    // Identical strings.
    assert_eq!(AutoMemory::calculate_similarity("hello world", "hello world"), 1.0);

    // No shared words.
    assert_eq!(AutoMemory::calculate_similarity("hello world", "foo bar"), 0.0);

    // One shared word out of two.
    let partial = AutoMemory::calculate_similarity("hello world", "hello there");
    assert!(partial > 0.0 && partial < 1.0);

    // Empty left-hand side.
    assert_eq!(AutoMemory::calculate_similarity("", "hello"), 0.0);
}
/// A sentence differing from a stored one only in its last word must not be
/// stored as a second entry by `add_memory`.
#[test]
fn test_similarity_threshold() {
    let mut store = AutoMemory::new();

    let stored = MemoryEntry::new(
        MemoryCategory::Decision,
        String::from("We decided to use PostgreSQL for our database system"),
        None,
    );
    store.add(stored);

    let near_duplicate = String::from("We decided to use PostgreSQL for our database backend");
    store.add_memory(MemoryCategory::Decision, near_duplicate, None);

    assert_eq!(store.entries.len(), 1);
}
/// Two very short, different strings must both be stored.
/// Presumably both fall below `MIN_SIMILARITY_LENGTH`, so no similarity
/// check merges them — confirm against `add_memory`.
#[test]
fn test_short_content_skipped() {
    let mut store = AutoMemory::new();

    let first = MemoryEntry::new(MemoryCategory::Technical, String::from("short"), None);
    store.add(first);

    store.add_memory(MemoryCategory::Technical, String::from("brief"), None);

    assert_eq!(store.entries.len(), 2);
}
/// A manual entry with very low importance must survive pruning caused by
/// adding automatic entries beyond `max_entries`.
#[test]
fn test_prune_preserves_manual() {
    let mut store = AutoMemory::new();
    store.max_entries = 3;

    let mut pinned = MemoryEntry::manual(MemoryCategory::Decision, "Manual decision".into());
    pinned.importance = 10.0;
    store.add(pinned);

    for i in 0..5 {
        store.add(MemoryEntry::new(
            MemoryCategory::Decision,
            format!("Auto decision {}", i),
            None,
        ));
    }

    let manual_survived = store.entries.iter().any(|entry| entry.is_manual);
    assert!(manual_survived);
    assert!(store.entries.len() <= store.max_entries);
}
/// Deduplication keeps between one and three of the inputs; if the two
/// PostgreSQL sentences collapse into one, the survivor must mention
/// "backend".
#[test]
fn test_deduplicate_entries() {
    let candidates = vec![
        MemoryEntry::new(
            MemoryCategory::Decision,
            "We chose PostgreSQL database system for our backend".into(),
            None,
        ),
        MemoryEntry::new(
            MemoryCategory::Decision,
            "We chose PostgreSQL database system backend".into(),
            None,
        ),
        MemoryEntry::new(
            MemoryCategory::Decision,
            "Using Redis for caching layer".into(),
            None,
        ),
    ];

    let deduped = deduplicate_entries(candidates);
    assert!(!deduped.is_empty());
    assert!(deduped.len() <= 3);

    let postgres: Vec<_> = deduped
        .iter()
        .filter(|entry| entry.content.to_lowercase().contains("postgresql"))
        .collect();
    if let [only] = postgres.as_slice() {
        assert!(only.content.contains("backend"));
    }
}
/// Empty input and bare keywords yield nothing; a sentence with several
/// signals never exceeds `MAX_DETECTED_ENTRIES`.
#[test]
fn test_memory_detection_edge_cases() {
    for fragment in ["", "决定", "使用"] {
        let detected = detect_memories_from_text(fragment, None);
        assert!(detected.is_empty());
    }

    let combined = "我决定使用React,解决了性能问题通过添加缓存机制";
    let detected = detect_memories_from_text(combined, None);
    assert!(detected.len() <= MAX_DETECTED_ENTRIES);
}
/// Repeated references must not push importance above 100.
#[test]
fn test_importance_ceiling() {
    let mut decision = MemoryEntry::new(
        MemoryCategory::Decision,
        "Important decision".into(),
        None,
    );
    assert_eq!(decision.importance, 90.0);

    (0..10).for_each(|_| {
        decision.mark_referenced();
    });

    assert!(decision.importance <= 100.0);
}
/// After `apply_time_decay`: the manual entry keeps its importance, the
/// recent entry stays at 60 or above, and the entry last referenced 60 days
/// ago (if still present) has decayed but not below half of
/// `min_importance`.
#[test]
fn test_time_decay() {
    let mut store = AutoMemory::new();
    store.min_importance = 30.0;

    let mut pinned = MemoryEntry::manual(MemoryCategory::Decision, "Manual entry".into());
    pinned.importance = 50.0;
    store.add(pinned);

    let mut stale = MemoryEntry::new(
        MemoryCategory::Technical,
        "Old technical note".into(),
        None,
    );
    stale.importance = 60.0;
    stale.last_referenced = Utc::now() - chrono::Duration::days(60);
    store.add(stale);

    store.add(MemoryEntry::new(
        MemoryCategory::Finding,
        "Recent finding".into(),
        None,
    ));

    store.apply_time_decay();

    // Manual entries are untouched.
    let manual_after = store.entries.iter().find(|entry| entry.is_manual);
    assert!(manual_after.is_some());
    assert_eq!(manual_after.unwrap().importance, 50.0);

    // Recently created entries are not decayed.
    let recent_after = store
        .entries
        .iter()
        .find(|entry| entry.content.contains("Recent"));
    assert!(recent_after.is_some());
    assert!(recent_after.unwrap().importance >= 60.0);

    // The stale entry may have been removed; only check it when present.
    let stale_after = store
        .entries
        .iter()
        .find(|entry| entry.content.contains("Old"));
    if let Some(decayed) = stale_after {
        assert!(decayed.importance < 60.0);
        assert!(decayed.importance >= store.min_importance * 0.5);
    }
}
/// Exercises `parse_memory_response` with a normal payload, an empty list,
/// a fenced markdown payload, an unknown category and too-short content.
#[test]
fn test_parse_memory_response() {
    let parse = |raw: &str| parse_memory_response(raw, None).unwrap();

    // Two well-formed memories.
    let two_memories = r#"{"memories": [{"category": "decision", "content": "决定使用 PostgreSQL 作为数据库", "importance": 90}, {"category": "preference", "content": "我偏好 TypeScript 而非 JavaScript", "importance": 70}]}"#;
    let parsed = parse(two_memories);
    assert_eq!(parsed.len(), 2);
    let has_decision = parsed
        .iter()
        .any(|entry| entry.category == MemoryCategory::Decision);
    let has_preference = parsed
        .iter()
        .any(|entry| entry.category == MemoryCategory::Preference);
    assert!(has_decision);
    assert!(has_preference);
    let decision = parsed
        .iter()
        .find(|entry| entry.category == MemoryCategory::Decision);
    assert!(decision.is_some());
    assert_eq!(decision.unwrap().importance, 90.0);

    // Empty list.
    let none = parse(r#"{"memories": []}"#);
    assert!(none.is_empty());

    // JSON wrapped in a markdown code fence.
    let fenced = concat!(
        "```json\n",
        r#"{"memories": [{"category": "solution", "content": "通过添加 middleware 修复认证问题", "importance": 85}]}"#,
        "\n```"
    );
    let from_fence = parse(fenced);
    assert_eq!(from_fence.len(), 1);
    assert_eq!(from_fence[0].category, MemoryCategory::Solution);

    // Unknown categories are dropped.
    let unknown = parse(
        r#"{"memories": [{"category": "unknown", "content": "This should be skipped content", "importance": 50}]}"#,
    );
    assert!(unknown.is_empty());

    // Content that is too short is dropped.
    let too_short = parse(
        r#"{"memories": [{"category": "finding", "content": "short", "importance": 60}]}"#,
    );
    assert!(too_short.is_empty());
}
/// `has_similar` accepts an exact or near-exact match and rejects unrelated
/// or very short text.
#[test]
fn test_public_has_similar() {
    let mut store = AutoMemory::new();
    store.add(MemoryEntry::new(
        MemoryCategory::Decision,
        String::from("We decided to use PostgreSQL for our main database system"),
        None,
    ));

    let similar = [
        "We decided to use PostgreSQL for our main database system",
        "We decided to use PostgreSQL for our main database backend",
    ];
    for text in similar {
        assert!(store.has_similar(text), "expected a match for: {}", text);
    }

    let different = [
        "We decided to use Redis for caching",
        "The project uses React for frontend",
        "short",
    ];
    for text in different {
        assert!(!store.has_similar(text), "expected no match for: {}", text);
    }
}
/// After an explicit `prune`, the store holds at most `max_entries`.
#[test]
fn test_public_prune() {
    let mut store = AutoMemory::new();
    store.max_entries = 5;
    store.min_importance = 30.0;

    let notes = (0..10).map(|i| {
        MemoryEntry::new(
            MemoryCategory::Technical,
            format!("Technical note number {} with sufficient length", i),
            None,
        )
    });
    for note in notes {
        store.add(note);
    }

    store.prune();

    assert!(store.entries.len() <= store.max_entries);
}
/// Checks totals, the manual/auto split, the highly-referenced count and
/// per-category presence in the generated statistics.
#[test]
fn test_statistics() {
    let mut store = AutoMemory::new();
    store.add(MemoryEntry::new(
        MemoryCategory::Decision,
        String::from("Decision one with enough content"),
        None,
    ));
    store.add(MemoryEntry::new(
        MemoryCategory::Preference,
        String::from("Preference for TypeScript over JavaScript"),
        None,
    ));
    store.add(MemoryEntry::manual(
        MemoryCategory::Technical,
        String::from("Manual technical note"),
    ));

    // Reference the first entry three times.
    for _ in 0..3 {
        store.entries[0].mark_referenced();
    }

    let stats = store.generate_statistics();
    assert_eq!(stats.total, 3);
    assert_eq!(stats.manual, 1);
    assert_eq!(stats.auto, 2);
    assert_eq!(stats.highly_referenced, 1);
    for category in [
        MemoryCategory::Decision,
        MemoryCategory::Preference,
        MemoryCategory::Technical,
    ] {
        assert!(stats.by_category.contains_key(&category));
    }
    assert!(stats.avg_importance > 0.0);
}
/// Verifies the default, minimal, archival and custom-size configurations.
#[test]
fn test_memory_config() {
    let defaults = MemoryConfig::default();
    assert_eq!(defaults.max_entries, 100);
    assert_eq!(defaults.min_importance, 30.0);
    assert_eq!(defaults.decay_start_days, 30);
    assert_eq!(defaults.decay_rate, 0.5);

    // Minimal: fewer entries, stricter importance floor.
    let minimal = MemoryConfig::minimal();
    assert_eq!(minimal.max_entries, 50);
    assert!(minimal.min_importance > defaults.min_importance);

    // Archival: more entries, looser importance floor.
    let archival = MemoryConfig::archival();
    assert_eq!(archival.max_entries, 500);
    assert!(archival.min_importance < defaults.min_importance);

    // Custom size keeps the default importance floor.
    let sized = MemoryConfig::with_max_entries(200);
    assert_eq!(sized.max_entries, 200);
    assert_eq!(sized.min_importance, 30.0);
}
/// A store built from the minimal config adopts its limits and enforces the
/// entry cap.
#[test]
fn test_auto_memory_with_config() {
    let mut store = AutoMemory::with_config(MemoryConfig::minimal());
    assert_eq!(store.max_entries, 50);
    assert_eq!(store.min_importance, 50.0);

    let notes = (0..60).map(|i| {
        MemoryEntry::new(
            MemoryCategory::Technical,
            format!("Technical note {} with enough length for detection", i),
            None,
        )
    });
    for note in notes {
        store.add(note);
    }

    assert!(store.entries.len() <= 50);
}
/// `add_batch` stores new entries and skips one whose content already
/// exists.
#[test]
fn test_batch_add() {
    let entry = |category, text: &str| MemoryEntry::new(category, text.into(), None);
    let mut store = AutoMemory::new();

    let first_batch: Vec<MemoryEntry> = vec![
        entry(MemoryCategory::Decision, "First decision with sufficient content"),
        entry(MemoryCategory::Finding, "First finding with sufficient content"),
        entry(MemoryCategory::Solution, "First solution with sufficient content"),
    ];
    store.add_batch(first_batch);
    assert_eq!(store.entries.len(), 3);

    // One repeat of an existing entry plus one genuinely new entry.
    let second_batch: Vec<MemoryEntry> = vec![
        entry(MemoryCategory::Decision, "First decision with sufficient content"),
        entry(MemoryCategory::Technical, "New technical note with sufficient content"),
    ];
    store.add_batch(second_batch);
    assert_eq!(store.entries.len(), 4);
}
/// An unlimited search returns every match; a limited search returns
/// exactly the limit, with the first result at least as important as the
/// last.
#[test]
fn test_search_with_limit() {
    let mut store = AutoMemory::new();
    let notes = (0..10).map(|i| {
        MemoryEntry::new(
            MemoryCategory::Technical,
            format!("PostgreSQL technical note {} with details", i),
            None,
        )
    });
    for note in notes {
        store.add(note);
    }

    let unlimited = store.search("postgresql");
    assert_eq!(unlimited.len(), 10);

    let limited = store.search_with_limit("postgresql", Some(5));
    assert_eq!(limited.len(), 5);
    let last = limited.len() - 1;
    assert!(limited[0].importance >= limited[last].importance);
}
/// `search_multi` returns entries matching any keyword, and nothing when no
/// keyword matches.
#[test]
fn test_multi_keyword_search() {
    let mut store = AutoMemory::new();
    let seeds = [
        (MemoryCategory::Decision, "Decided to use PostgreSQL"),
        (MemoryCategory::Technical, "Using Redis for caching"),
        (MemoryCategory::Solution, "Fixed by adding middleware"),
    ];
    for (category, content) in seeds {
        store.add(MemoryEntry::new(category, content.into(), None));
    }

    let matches = store.search_multi(&["postgresql", "redis"]);
    assert_eq!(matches.len(), 2);

    let no_matches = store.search_multi(&["mongodb"]);
    assert!(no_matches.is_empty());
}
/// A custom increment is applied as given, the plain call adds 2, and the
/// result never exceeds 100.
#[test]
fn test_mark_referenced_with_increment() {
    let mut finding = MemoryEntry::new(
        MemoryCategory::Finding,
        "API endpoint location".into(),
        None,
    );
    assert_eq!(finding.importance, 60.0);

    finding.mark_referenced_with_increment(5.0);
    assert_eq!(finding.importance, 65.0);

    finding.mark_referenced();
    assert_eq!(finding.importance, 67.0);

    // Far more than enough to reach the ceiling.
    (0..20).for_each(|_| {
        finding.mark_referenced_with_increment(10.0);
    });
    assert!(finding.importance <= 100.0);
}
/// Builds the search index over 30 entries and checks the `*_fast` query
/// helpers against it.
#[test]
fn test_search_index() {
    let mut store = AutoMemory::new();

    let technical = (0..20).map(|i| {
        MemoryEntry::new(
            MemoryCategory::Technical,
            format!("PostgreSQL technical note {} with sufficient content length", i),
            None,
        )
    });
    for note in technical {
        store.add(note);
    }

    let decisions = (0..10).map(|i| {
        MemoryEntry::new(
            MemoryCategory::Decision,
            format!("Redis decision {} with sufficient content for testing", i),
            None,
        )
    });
    for decision in decisions {
        store.add(decision);
    }

    store.rebuild_index();
    assert!(store.search_index.is_some());

    // Keyword search honours the limit and only returns matching entries.
    let hits = store.search_fast("postgresql", Some(5));
    assert!(hits.len() <= 5);
    assert!(hits
        .iter()
        .all(|entry| entry.content.to_lowercase().contains("postgresql")));

    let multi_hits = store.search_multi_fast(&["postgresql", "redis"]);
    assert!(!multi_hits.is_empty());

    // Category lookups return every entry of that category.
    let technical_hits = store.by_category_fast(MemoryCategory::Technical);
    assert_eq!(technical_hits.len(), 20);
    let decision_hits = store.by_category_fast(MemoryCategory::Decision);
    assert_eq!(decision_hits.len(), 10);

    // Top-N is ordered by descending importance.
    let top = store.top_n_fast(5);
    assert_eq!(top.len(), 5);
    let last = top.len() - 1;
    assert!(top[0].importance >= top[last].importance);
}
/// The index starts absent, is created by `search_fast`, is dropped by
/// `clear`, and is created again by the next `search_fast`.
#[test]
fn test_index_auto_rebuild() {
    let mut store = AutoMemory::new();
    assert!(store.search_index.is_none());

    store.add(MemoryEntry::new(
        MemoryCategory::Decision,
        "Test decision with sufficient content length".into(),
        None,
    ));
    let hits = store.search_fast("test", None);
    assert!(!hits.is_empty());
    assert!(store.search_index.is_some());

    store.clear();
    assert!(store.search_index.is_none());

    store.add(MemoryEntry::new(
        MemoryCategory::Finding,
        "New finding with sufficient content".into(),
        None,
    ));
    let _ = store.search_fast("finding", None);
    assert!(store.search_index.is_some());
}
/// Covers identical, orthogonal, opposite, partially aligned and empty
/// vectors.
#[test]
fn test_cosine_similarity() {
    let unit_x = [1.0, 0.0, 0.0];

    // Identical direction.
    assert_eq!(SemanticUtils::cosine_similarity(&unit_x, &[1.0, 0.0, 0.0]), 1.0);

    // Orthogonal.
    let orthogonal = SemanticUtils::cosine_similarity(&unit_x, &[0.0, 1.0, 0.0]);
    assert!((orthogonal - 0.0).abs() < 0.001);

    // Opposite direction.
    let opposite = SemanticUtils::cosine_similarity(&unit_x, &[-1.0, 0.0, 0.0]);
    assert!((opposite - (-1.0)).abs() < 0.001);

    // Partially aligned.
    let partial = SemanticUtils::cosine_similarity(&[1.0, 1.0, 0.0], &unit_x);
    assert!(partial > 0.0 && partial < 1.0);

    // Empty vectors.
    let empty: [f32; 0] = [];
    assert_eq!(SemanticUtils::cosine_similarity(&empty, &empty), 0.0);
}
/// TF-IDF search finds CJK and Latin terms and returns nothing for a term
/// that no entry contains.
#[test]
fn test_tfidf_search() {
    let mut store = AutoMemory::new();
    let seeds = [
        (MemoryCategory::Decision, "使用 PostgreSQL 作为主数据库系统"),
        (MemoryCategory::Technical, "Redis 缓存配置为 10 个连接"),
        (MemoryCategory::Solution, "通过添加 middleware 修复认证问题"),
        (MemoryCategory::Finding, "数据库连接池设置为 20"),
    ];
    for (category, content) in seeds {
        store.add(MemoryEntry::new(category, content.into(), None));
    }

    let mut tfidf = TfIdfSearch::new();
    tfidf.index(&store);

    let database_hits = tfidf.search("数据库", Some(5));
    assert!(!database_hits.is_empty());
    assert!(database_hits[0].0.contains("数据库"));

    let redis_hits = tfidf.search("redis", Some(5));
    assert!(!redis_hits.is_empty());
    assert!(redis_hits[0].0.to_lowercase().contains("redis"));

    let mongo_hits = tfidf.search("mongodb", Some(5));
    assert!(mongo_hits.is_empty());
}
/// When at least two documents match, the first score must not be lower
/// than the second.
#[test]
fn test_tfidf_ranking() {
    let mut store = AutoMemory::new();
    let seeds = [
        (MemoryCategory::Decision, "使用 PostgreSQL 数据库 作为主数据库"),
        (MemoryCategory::Technical, "数据库连接池配置"),
        (MemoryCategory::Solution, "修复了前端样式问题"),
    ];
    for (category, content) in seeds {
        store.add(MemoryEntry::new(category, content.into(), None));
    }

    let mut tfidf = TfIdfSearch::new();
    tfidf.index(&store);

    let ranked = tfidf.search("数据库", None);
    if ranked.len() >= 2 {
        let (best, runner_up) = (ranked[0].1, ranked[1].1);
        assert!(best >= runner_up);
    }
}
/// A second decision about the same subject replaces the first instead of
/// being stored alongside it.
#[test]
fn test_conflict_detection() {
    let mut store = AutoMemory::new();

    store.add_memory(
        MemoryCategory::Decision,
        String::from("决定使用 PostgreSQL 作为主数据库"),
        None,
    );
    assert_eq!(store.entries.len(), 1);
    assert!(store.entries[0].content.contains("PostgreSQL"));

    // Same sentence with a different database name.
    store.add_memory(
        MemoryCategory::Decision,
        String::from("决定使用 MySQL 作为主数据库"),
        None,
    );
    assert_eq!(store.entries.len(), 1);
    assert!(store.entries[0].content.contains("MySQL"));
}
/// A preference phrased as an explicit switch replaces the earlier
/// preference.
#[test]
fn test_conflict_with_change_signal() {
    let mut store = AutoMemory::new();

    store.add_memory(
        MemoryCategory::Preference,
        String::from("偏好使用 vim 编辑器"),
        None,
    );
    assert_eq!(store.entries.len(), 1);

    store.add_memory(
        MemoryCategory::Preference,
        String::from("改用 vscode 编辑器,不再使用 vim"),
        None,
    );
    assert_eq!(store.entries.len(), 1);
    assert!(store.entries[0].content.contains("vscode"));
}
/// Two decisions about different subjects must both be kept.
#[test]
fn test_no_false_conflict() {
    let mut store = AutoMemory::new();

    let decisions = [
        "决定使用 PostgreSQL 作为主数据库",
        "决定使用 Redis 作为缓存系统",
    ];
    for text in decisions {
        store.add_memory(MemoryCategory::Decision, text.to_string(), None);
    }

    assert_eq!(store.entries.len(), 2);
}
/// The contextual summary surfaces the entry relevant to the given context
/// and is non-empty even for an empty context.
#[test]
fn test_contextual_summary() {
    let mut store = AutoMemory::new();
    let seeds = [
        (MemoryCategory::Decision, "决定使用 PostgreSQL 作为主数据库"),
        (MemoryCategory::Technical, "前端使用 React 框架开发"),
        (MemoryCategory::Solution, "通过添加 Redis 缓存解决性能问题"),
        (MemoryCategory::Finding, "API 响应时间在 200ms 以内"),
        (MemoryCategory::Preference, "偏好使用 TypeScript 而非 JavaScript"),
    ];
    for (category, content) in seeds {
        store.add(MemoryEntry::new(category, content.into(), None));
    }

    let database_summary = store.generate_contextual_summary("数据库查询优化", 3);
    assert!(database_summary.contains("PostgreSQL"));

    let frontend_summary = store.generate_contextual_summary("React 组件开发", 3);
    assert!(frontend_summary.contains("React"));

    let fallback_summary = store.generate_contextual_summary("", 3);
    assert!(!fallback_summary.is_empty());
}
/// Decorated or too-short text is flagged as low quality; plain descriptive
/// sentences are not.
#[test]
fn test_low_quality_memory_filter() {
    let rejected = [
        "│ 🎯 决策: 决定使用 PostgreSQL.",
        "├── Structure: 入口文件是 main.",
        "🔧 解决方案: 通过添加 middleware.",
        "【自动记忆摘要】",
        "short",
    ];
    for text in rejected {
        assert!(is_low_quality_memory(text), "expected low quality: {}", text);
    }

    let accepted = [
        "决定使用 PostgreSQL 作为主数据库系统",
        "通过添加 Redis 缓存层解决了性能问题",
        "用户偏好使用 TypeScript 进行开发",
    ];
    for text in accepted {
        assert!(!is_low_quality_memory(text), "expected acceptable: {}", text);
    }
}
}