Skip to main content

aster/memory/
chat_memory.rs

1//! 对话记忆模块
2//!
3//! 负责存储和管理对话摘要,支持:
4//! - 层级压缩(工作记忆 → 短期记忆 → 核心记忆)
5//! - 关键词/话题/时间范围搜索
6//! - 核心记忆管理(永不遗忘)
7
8use std::fs;
9use std::path::{Path, PathBuf};
10
11use chrono::{DateTime, Utc};
12
13use super::types::{
14    ChatMemoryStats, ChatMemoryStore, ConversationSummary, MemoryHierarchyConfig, MemoryImportance,
15    Timestamp,
16};
17
/// Schema version stamped into every persisted chat-memory store.
const CHAT_MEMORY_VERSION: &str = "1.0.0";
/// File name (inside a memory directory) of the serialized full store.
const SUMMARIES_FILE: &str = "summaries.json";
/// File name (inside a memory directory) of the standalone core-memories snapshot.
const CORE_FILE: &str = "core.json";
21
22/// 获取当前时间戳
23fn now() -> Timestamp {
24    Utc::now().to_rfc3339()
25}
26
27/// 解析时间戳
28fn parse_timestamp(ts: &str) -> Option<DateTime<Utc>> {
29    DateTime::parse_from_rfc3339(ts)
30        .ok()
31        .map(|dt| dt.with_timezone(&Utc))
32}
33
34/// 计算天数差
35fn days_between(start: &str, end: &str) -> i64 {
36    let start_dt = parse_timestamp(start);
37    let end_dt = parse_timestamp(end);
38
39    match (start_dt, end_dt) {
40        (Some(s), Some(e)) => (e - s).num_days(),
41        _ => 0,
42    }
43}
44
45/// 对话记忆管理器
46pub struct ChatMemory {
47    global_dir: PathBuf,
48    project_dir: Option<PathBuf>,
49    store: ChatMemoryStore,
50    config: MemoryHierarchyConfig,
51}
52
53impl ChatMemory {
54    /// 创建新的对话记忆管理器
55    pub fn new(project_path: Option<&Path>, config: Option<MemoryHierarchyConfig>) -> Self {
56        let global_dir = dirs::home_dir()
57            .unwrap_or_default()
58            .join(".aster")
59            .join("memory")
60            .join("chat");
61
62        let project_dir = project_path.map(|p| p.join(".aster").join("memory").join("chat"));
63
64        let cfg = config.unwrap_or_default();
65        let project_path_str = project_path
66            .map(|p| p.display().to_string())
67            .unwrap_or_default();
68
69        let mut memory = Self {
70            global_dir,
71            project_dir,
72            store: Self::create_empty_store(&project_path_str),
73            config: cfg,
74        };
75
76        memory.load();
77        memory
78    }
79
80    /// 添加对话摘要
81    pub fn add_conversation(&mut self, mut summary: ConversationSummary) {
82        if summary.id.is_empty() {
83            summary.id = nanoid::nanoid!();
84        }
85
86        self.store.summaries.push(summary);
87        self.update_stats();
88
89        if self.store.summaries.len() > self.config.compression_threshold {
90            self.compress();
91        }
92
93        self.save();
94    }
95
96    /// 搜索对话
97    pub fn search(&self, query: &str, limit: Option<usize>) -> Vec<&ConversationSummary> {
98        let limit = limit.unwrap_or(10);
99        let query_lower = query.to_lowercase();
100
101        let mut results: Vec<(&ConversationSummary, f32)> = self
102            .store
103            .summaries
104            .iter()
105            .filter_map(|summary| {
106                let mut score = 0.0;
107
108                // 摘要内容匹配
109                if summary.summary.to_lowercase().contains(&query_lower) {
110                    score += 2.0;
111                }
112
113                // 话题匹配
114                let topic_matches = summary
115                    .topics
116                    .iter()
117                    .filter(|t| t.to_lowercase().contains(&query_lower))
118                    .count();
119                score += topic_matches as f32 * 3.0;
120
121                // 文件名匹配
122                if summary
123                    .files_discussed
124                    .iter()
125                    .any(|f| f.to_lowercase().contains(&query_lower))
126                {
127                    score += 1.0;
128                }
129
130                // 符号匹配
131                if summary
132                    .symbols_discussed
133                    .iter()
134                    .any(|s| s.to_lowercase().contains(&query_lower))
135                {
136                    score += 1.0;
137                }
138
139                // 重要性加权
140                score += summary.importance as u8 as f32;
141
142                if score > 0.0 {
143                    Some((summary, score))
144                } else {
145                    None
146                }
147            })
148            .collect();
149
150        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
151        results.into_iter().take(limit).map(|(s, _)| s).collect()
152    }
153
154    /// 按话题搜索
155    pub fn search_by_topic(&self, topic: &str, limit: Option<usize>) -> Vec<&ConversationSummary> {
156        let limit = limit.unwrap_or(10);
157        let topic_lower = topic.to_lowercase();
158
159        let mut results: Vec<_> = self
160            .store
161            .summaries
162            .iter()
163            .filter(|s| {
164                s.topics
165                    .iter()
166                    .any(|t| t.to_lowercase().contains(&topic_lower))
167            })
168            .collect();
169
170        results.sort_by(|a, b| b.end_time.cmp(&a.end_time));
171        results.into_iter().take(limit).collect()
172    }
173
174    /// 压缩旧记忆
175    pub fn compress(&mut self) {
176        let current_time = now();
177        let mut summaries = std::mem::take(&mut self.store.summaries);
178
179        // 按时间排序(新到旧)
180        summaries.sort_by(|a, b| b.end_time.cmp(&a.end_time));
181
182        // 分离工作记忆
183        let working_memory: Vec<_> = summaries
184            .iter()
185            .take(self.config.working_memory_size)
186            .cloned()
187            .collect();
188
189        let older_memories: Vec<_> = summaries
190            .into_iter()
191            .skip(self.config.working_memory_size)
192            .collect();
193
194        // 分离短期和长期记忆
195        let mut short_term = Vec::new();
196        let mut long_term = Vec::new();
197
198        for memory in older_memories {
199            let days = days_between(&memory.end_time, &current_time);
200            if days <= self.config.short_term_days as i64 {
201                short_term.push(memory);
202            } else {
203                long_term.push(memory);
204            }
205        }
206
207        // 处理长期记忆(保留高重要性的)
208        let compressed_long_term: Vec<_> = long_term
209            .into_iter()
210            .filter(|m| m.importance >= MemoryImportance::Medium)
211            .collect();
212
213        // 合并结果
214        self.store.summaries = working_memory;
215        self.store.summaries.extend(short_term);
216        self.store.summaries.extend(compressed_long_term);
217
218        self.update_stats();
219        self.save();
220    }
221
222    /// 获取核心记忆
223    pub fn get_core_memories(&self) -> &[String] {
224        &self.store.core_memories
225    }
226
227    /// 添加核心记忆
228    pub fn add_core_memory(&mut self, memory: String) {
229        if self.store.core_memories.contains(&memory) {
230            return;
231        }
232
233        if self.store.core_memories.len() >= self.config.max_core_memories {
234            self.store.core_memories.remove(0);
235        }
236
237        self.store.core_memories.push(memory);
238        self.save();
239    }
240
241    /// 移除核心记忆
242    pub fn remove_core_memory(&mut self, memory: &str) -> bool {
243        if let Some(pos) = self.store.core_memories.iter().position(|m| m == memory) {
244            self.store.core_memories.remove(pos);
245            self.save();
246            true
247        } else {
248            false
249        }
250    }
251
252    /// 获取最近 N 条摘要
253    pub fn get_recent(&self, count: usize) -> Vec<&ConversationSummary> {
254        let mut sorted: Vec<_> = self.store.summaries.iter().collect();
255        sorted.sort_by(|a, b| b.end_time.cmp(&a.end_time));
256        sorted.into_iter().take(count).collect()
257    }
258
259    /// 获取所有摘要
260    pub fn get_all(&self) -> &[ConversationSummary] {
261        &self.store.summaries
262    }
263
264    /// 根据 ID 获取摘要
265    pub fn get_by_id(&self, id: &str) -> Option<&ConversationSummary> {
266        self.store.summaries.iter().find(|s| s.id == id)
267    }
268
269    /// 删除摘要
270    pub fn delete_summary(&mut self, id: &str) -> bool {
271        if let Some(pos) = self.store.summaries.iter().position(|s| s.id == id) {
272            self.store.summaries.remove(pos);
273            self.update_stats();
274            self.save();
275            true
276        } else {
277            false
278        }
279    }
280
281    /// 获取统计信息
282    pub fn get_stats(&self) -> &ChatMemoryStats {
283        &self.store.stats
284    }
285
286    /// 导出记忆
287    pub fn export(&self) -> String {
288        serde_json::to_string_pretty(&self.store).unwrap_or_default()
289    }
290
291    /// 导入记忆
292    pub fn import(&mut self, data: &str) -> Result<(), String> {
293        let parsed: ChatMemoryStore =
294            serde_json::from_str(data).map_err(|e| format!("Invalid format: {}", e))?;
295
296        // 合并摘要
297        for summary in parsed.summaries {
298            if !self.store.summaries.iter().any(|s| s.id == summary.id) {
299                self.store.summaries.push(summary);
300            }
301        }
302
303        // 合并核心记忆
304        for memory in parsed.core_memories {
305            if !self.store.core_memories.contains(&memory) {
306                self.add_core_memory(memory);
307            }
308        }
309
310        self.update_stats();
311        self.save();
312        Ok(())
313    }
314
315    /// 清空所有记忆
316    pub fn clear(&mut self) {
317        self.store = Self::create_empty_store(&self.store.project_path);
318        self.save();
319    }
320
321    // === 私有方法 ===
322
323    fn create_empty_store(project_path: &str) -> ChatMemoryStore {
324        let current_time = now();
325        ChatMemoryStore {
326            version: CHAT_MEMORY_VERSION.to_string(),
327            project_path: project_path.to_string(),
328            summaries: Vec::new(),
329            core_memories: Vec::new(),
330            last_updated: current_time.clone(),
331            stats: ChatMemoryStats {
332                total_conversations: 0,
333                total_messages: 0,
334                oldest_conversation: current_time.clone(),
335                newest_conversation: current_time,
336            },
337        }
338    }
339
340    fn update_stats(&mut self) {
341        let summaries = &self.store.summaries;
342
343        self.store.stats.total_conversations = summaries.len();
344        self.store.stats.total_messages = summaries.iter().map(|s| s.message_count as usize).sum();
345
346        if !summaries.is_empty() {
347            let mut sorted: Vec<_> = summaries.iter().collect();
348            sorted.sort_by(|a, b| a.start_time.cmp(&b.start_time));
349
350            self.store.stats.oldest_conversation = sorted.first().unwrap().start_time.clone();
351            self.store.stats.newest_conversation = sorted.last().unwrap().end_time.clone();
352        }
353
354        self.store.last_updated = now();
355    }
356
357    fn load(&mut self) {
358        // 加载全局数据
359        if let Some(global_store) = self.load_from_dir(&self.global_dir) {
360            self.store.summaries = global_store.summaries;
361            self.store.core_memories = global_store.core_memories;
362        }
363
364        // 加载项目数据并合并
365        if let Some(ref project_dir) = self.project_dir {
366            if let Some(project_store) = self.load_from_dir(project_dir) {
367                for summary in project_store.summaries {
368                    if !self.store.summaries.iter().any(|s| s.id == summary.id) {
369                        self.store.summaries.push(summary);
370                    }
371                }
372                for memory in project_store.core_memories {
373                    if !self.store.core_memories.contains(&memory) {
374                        self.store.core_memories.push(memory);
375                    }
376                }
377            }
378        }
379
380        self.update_stats();
381    }
382
383    fn load_from_dir(&self, dir: &Path) -> Option<ChatMemoryStore> {
384        let summaries_path = dir.join(SUMMARIES_FILE);
385        if !summaries_path.exists() {
386            return None;
387        }
388
389        let content = fs::read_to_string(&summaries_path).ok()?;
390        serde_json::from_str(&content).ok()
391    }
392
393    fn save(&self) {
394        self.save_to_dir(&self.global_dir);
395        if let Some(ref project_dir) = self.project_dir {
396            self.save_to_dir(project_dir);
397        }
398    }
399
400    fn save_to_dir(&self, dir: &Path) {
401        if let Err(e) = fs::create_dir_all(dir) {
402            eprintln!("Failed to create directory {:?}: {}", dir, e);
403            return;
404        }
405
406        let summaries_path = dir.join(SUMMARIES_FILE);
407        let core_path = dir.join(CORE_FILE);
408
409        if let Ok(content) = serde_json::to_string_pretty(&self.store) {
410            let _ = fs::write(&summaries_path, content);
411        }
412
413        let core_data = serde_json::json!({
414            "version": CHAT_MEMORY_VERSION,
415            "memories": &self.store.core_memories,
416            "last_updated": &self.store.last_updated,
417        });
418
419        if let Ok(content) = serde_json::to_string_pretty(&core_data) {
420            let _ = fs::write(&core_path, content);
421        }
422    }
423}
424
425impl Default for ChatMemory {
426    fn default() -> Self {
427        Self::new(None, None)
428    }
429}