Skip to main content

mofa_foundation/agent/components/
memory.rs

1//! 记忆组件
2//!
3//! 定义 Agent 的记忆/状态持久化能力
4
5use async_trait::async_trait;
6use chrono::Utc;
7use mofa_kernel::agent::error::AgentError;
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13pub use mofa_kernel::agent::components::memory::{
14    Memory, MemoryItem, MemoryStats, MemoryValue, Message, MessageRole,
15};
16
17// Use kernel's AgentResult type
18pub use mofa_kernel::agent::AgentResult;
19
20// ============================================================================
21// 内存实现
22// ============================================================================
23
/// Simple in-process memory store backed by two hash maps.
///
/// Nothing in this type touches the file system; contents live only as long
/// as the value itself.
pub struct InMemoryStorage {
    /// Key-value items, indexed by the key they were stored under.
    data: HashMap<String, MemoryItem>,
    /// Conversation history, indexed by session id.
    history: HashMap<String, Vec<Message>>,
}
29
30impl InMemoryStorage {
31    /// 创建新的内存存储
32    pub fn new() -> Self {
33        Self {
34            data: HashMap::new(),
35            history: HashMap::new(),
36        }
37    }
38}
39
40impl Default for InMemoryStorage {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46#[async_trait]
47impl Memory for InMemoryStorage {
48    async fn store(&mut self, key: &str, value: MemoryValue) -> AgentResult<()> {
49        let item = MemoryItem::new(key, value);
50        self.data.insert(key.to_string(), item);
51        Ok(())
52    }
53
54    async fn retrieve(&self, key: &str) -> AgentResult<Option<MemoryValue>> {
55        Ok(self.data.get(key).map(|item| item.value.clone()))
56    }
57
58    async fn remove(&mut self, key: &str) -> AgentResult<bool> {
59        Ok(self.data.remove(key).is_some())
60    }
61
62    async fn search(&self, query: &str, limit: usize) -> AgentResult<Vec<MemoryItem>> {
63        // 简单的关键词匹配搜索
64        let query_lower = query.to_lowercase();
65        let mut results: Vec<MemoryItem> = self
66            .data
67            .values()
68            .filter(|item| {
69                if let Some(text) = item.value.as_text() {
70                    text.to_lowercase().contains(&query_lower)
71                } else {
72                    false
73                }
74            })
75            .cloned()
76            .collect();
77
78        results.sort_by(|a, b| {
79            b.score
80                .partial_cmp(&a.score)
81                .unwrap_or(std::cmp::Ordering::Equal)
82        });
83        results.truncate(limit);
84        Ok(results)
85    }
86
87    async fn clear(&mut self) -> AgentResult<()> {
88        self.data.clear();
89        Ok(())
90    }
91
92    async fn get_history(&self, session_id: &str) -> AgentResult<Vec<Message>> {
93        Ok(self.history.get(session_id).cloned().unwrap_or_default())
94    }
95
96    async fn add_to_history(&mut self, session_id: &str, message: Message) -> AgentResult<()> {
97        self.history
98            .entry(session_id.to_string())
99            .or_default()
100            .push(message);
101        Ok(())
102    }
103
104    async fn clear_history(&mut self, session_id: &str) -> AgentResult<()> {
105        self.history.remove(session_id);
106        Ok(())
107    }
108
109    async fn stats(&self) -> AgentResult<MemoryStats> {
110        let total_messages: usize = self.history.values().map(|v| v.len()).sum();
111        Ok(MemoryStats {
112            total_items: self.data.len(),
113            total_sessions: self.history.len(),
114            total_messages,
115            memory_bytes: 0, // 简化,不计算实际内存
116        })
117    }
118
119    fn memory_type(&self) -> &str {
120        "in-memory"
121    }
122}
123
124// ============================================================================
125// 基于文件的持久化存储实现
126// ============================================================================
127
/// Memory store that persists its contents to files on disk.
///
/// Directory layout:
/// ```text
/// memory/
/// ├── data.json              # key-value store (MemoryValue items)
/// ├── MEMORY.md              # long-term memory
/// ├── sessions/              # conversation history
/// │   ├── <session_id>.json  # message history of a single session
/// ├── 2024-01-15.md          # daily notes (YYYY-MM-DD.md)
/// └── ...
/// ```
///
/// # Features
///
/// - Key-value items and session histories are persisted to disk
/// - In-memory state is held behind `Arc<RwLock<T>>`
/// - `data.json` and session files are written through a temporary file
///   followed by a rename
/// - Existing `data.json` and session files are read into memory when the
///   store is constructed
pub struct FileBasedStorage {
    /// Base directory passed to the constructor.
    base_dir: PathBuf,
    /// The `memory/` directory under `base_dir`.
    memory_dir: PathBuf,
    /// The `memory/sessions/` directory.
    sessions_dir: PathBuf,
    /// Path of `data.json`.
    data_file: PathBuf,
    /// Path of `MEMORY.md` (long-term memory).
    long_term_file: PathBuf,
    /// In-memory key-value data (key -> MemoryItem).
    data: Arc<RwLock<HashMap<String, MemoryItem>>>,
    /// In-memory session history (session_id -> Vec<Message>).
    sessions: Arc<RwLock<HashMap<String, Vec<Message>>>>,
}
163
164impl FileBasedStorage {
165    /// 创建新的基于文件的存储
166    ///
167    /// # 参数
168    ///
169    /// - `base_dir`: 基础目录,将在其下创建 memory/ 和 memory/sessions/
170    ///
171    /// # 示例
172    ///
173    /// ```rust,ignore
174    /// use mofa_foundation::agent::components::memory::FileBasedStorage;
175    ///
176    /// let storage = FileBasedStorage::new("/tmp/workspace").await?;
177    /// ```
178    pub async fn new(base_dir: impl AsRef<Path>) -> AgentResult<Self> {
179        let base_dir = base_dir.as_ref().to_path_buf();
180        let memory_dir = base_dir.join("memory");
181        let sessions_dir = memory_dir.join("sessions");
182        let data_file = memory_dir.join("data.json");
183        let long_term_file = memory_dir.join("MEMORY.md");
184
185        // 创建目录
186        tokio::fs::create_dir_all(&sessions_dir)
187            .await
188            .map_err(|e| {
189                AgentError::IoError(format!("Failed to create sessions directory: {}", e))
190            })?;
191
192        // 加载现有数据
193        let data = Self::load_data(&data_file).await?;
194        let sessions = Self::load_sessions(&sessions_dir).await?;
195
196        Ok(Self {
197            base_dir,
198            memory_dir,
199            sessions_dir,
200            data_file,
201            long_term_file,
202            data: Arc::new(RwLock::new(data)),
203            sessions: Arc::new(RwLock::new(sessions)),
204        })
205    }
206
207    /// 从 data.json 加载数据
208    async fn load_data(data_file: &Path) -> AgentResult<HashMap<String, MemoryItem>> {
209        if !data_file.exists() {
210            return Ok(HashMap::new());
211        }
212
213        let content = tokio::fs::read_to_string(data_file)
214            .await
215            .map_err(|e| AgentError::IoError(format!("Failed to read data.json: {}", e)))?;
216
217        if content.trim().is_empty() {
218            return Ok(HashMap::new());
219        }
220
221        serde_json::from_str(&content).map_err(|e| {
222            AgentError::SerializationError(format!("Failed to parse data.json: {}", e))
223        })
224    }
225
226    /// 从 sessions 目录加载所有会话
227    async fn load_sessions(sessions_dir: &Path) -> AgentResult<HashMap<String, Vec<Message>>> {
228        if !sessions_dir.exists() {
229            return Ok(HashMap::new());
230        }
231
232        let mut sessions = HashMap::new();
233        let mut entries = tokio::fs::read_dir(sessions_dir).await.map_err(|e| {
234            AgentError::IoError(format!("Failed to read sessions directory: {}", e))
235        })?;
236
237        while let Some(entry) = entries
238            .next_entry()
239            .await
240            .map_err(|e| AgentError::IoError(format!("Failed to read session entry: {}", e)))?
241        {
242            let path = entry.path();
243
244            // 只处理 .json 文件
245            if path.extension().and_then(|s: &std::ffi::OsStr| s.to_str()) != Some("json") {
246                continue;
247            }
248
249            // 从文件名获取 session_id (例如: "session-123.json" -> "session-123")
250            let session_id = path
251                .file_stem()
252                .and_then(|s: &std::ffi::OsStr| s.to_str())
253                .ok_or_else(|| {
254                    AgentError::IoError(format!("Invalid session file name: {:?}", path))
255                })?;
256
257            // 读取会话数据
258            let content = tokio::fs::read_to_string(&path).await.map_err(|e| {
259                AgentError::IoError(format!("Failed to read session file {:?}: {}", path, e))
260            })?;
261
262            let messages: Vec<Message> = serde_json::from_str(&content).map_err(|e| {
263                AgentError::SerializationError(format!(
264                    "Failed to parse session file {:?}: {}",
265                    path, e
266                ))
267            })?;
268
269            sessions.insert(session_id.to_string(), messages);
270        }
271
272        Ok(sessions)
273    }
274
275    /// 持久化数据到 data.json
276    ///
277    /// 使用原子写入: 写入临时文件然后 rename
278    async fn persist_data(&self) -> AgentResult<()> {
279        let data = self.data.read().await;
280        let json = serde_json::to_string_pretty(&*data).map_err(|e| {
281            AgentError::SerializationError(format!("Failed to serialize data: {}", e))
282        })?;
283        drop(data);
284
285        // 原子写入: 临时文件 + rename
286        let temp_file = self.data_file.with_extension("json.tmp");
287        tokio::fs::write(&temp_file, json)
288            .await
289            .map_err(|e| AgentError::IoError(format!("Failed to write temp data file: {}", e)))?;
290
291        tokio::fs::rename(&temp_file, &self.data_file)
292            .await
293            .map_err(|e| AgentError::IoError(format!("Failed to rename data file: {}", e)))?;
294
295        Ok(())
296    }
297
298    /// 持久化单个会话到文件
299    async fn persist_session(&self, session_id: &str) -> AgentResult<()> {
300        let sessions = self.sessions.read().await;
301        let messages = sessions.get(session_id);
302
303        let session_file = self.sessions_dir.join(format!("{}.json", session_id));
304
305        if let Some(messages) = messages {
306            // 写入会话数据
307            let json = serde_json::to_string_pretty(messages).map_err(|e| {
308                AgentError::SerializationError(format!("Failed to serialize session: {}", e))
309            })?;
310            drop(sessions);
311
312            // 原子写入
313            let temp_file = session_file.with_extension("json.tmp");
314            tokio::fs::write(&temp_file, json).await.map_err(|e| {
315                AgentError::IoError(format!("Failed to write temp session file: {}", e))
316            })?;
317
318            tokio::fs::rename(&temp_file, &session_file)
319                .await
320                .map_err(|e| {
321                    AgentError::IoError(format!("Failed to rename session file: {}", e))
322                })?;
323        } else {
324            // 会话不存在,删除文件
325            drop(sessions);
326            if session_file.exists() {
327                tokio::fs::remove_file(&session_file).await.map_err(|e| {
328                    AgentError::IoError(format!("Failed to remove session file: {}", e))
329                })?;
330            }
331        }
332
333        Ok(())
334    }
335
336    /// 获取今日日期字符串 (YYYY-MM-DD)
337    fn today_key() -> String {
338        Utc::now().format("%Y-%m-%d").to_string()
339    }
340
341    /// 获取今日文件路径 (YYYY-MM-DD.md)
342    fn today_file(&self) -> PathBuf {
343        self.memory_dir.join(format!("{}.md", Self::today_key()))
344    }
345
346    /// 读取今日笔记内容
347    pub async fn read_today_file(&self) -> AgentResult<String> {
348        let today_file = self.today_file();
349        if today_file.exists() {
350            tokio::fs::read_to_string(&today_file)
351                .await
352                .map_err(|e| AgentError::IoError(format!("Failed to read today file: {}", e)))
353        } else {
354            Ok(String::new())
355        }
356    }
357
358    /// 追加内容到今日笔记
359    pub async fn append_today_file(&self, content: &str) -> AgentResult<()> {
360        let today_file = self.today_file();
361        let final_content = if today_file.exists() {
362            let existing = tokio::fs::read_to_string(&today_file)
363                .await
364                .map_err(|e| AgentError::IoError(format!("Failed to read today file: {}", e)))?;
365            format!("{}\n{}", existing, content)
366        } else {
367            // 新文件,添加日期头部
368            let today = Self::today_key();
369            format!("# {}\n\n{}", today, content)
370        };
371
372        tokio::fs::write(&today_file, final_content)
373            .await
374            .map_err(|e| AgentError::IoError(format!("Failed to write today file: {}", e)))?;
375
376        Ok(())
377    }
378
379    /// 读取长期记忆 (MEMORY.md)
380    pub async fn read_long_term_file(&self) -> AgentResult<String> {
381        if self.long_term_file.exists() {
382            tokio::fs::read_to_string(&self.long_term_file)
383                .await
384                .map_err(|e| AgentError::IoError(format!("Failed to read long-term file: {}", e)))
385        } else {
386            Ok(String::new())
387        }
388    }
389
390    /// 写入长期记忆 (MEMORY.md)
391    pub async fn write_long_term_file(&self, content: &str) -> AgentResult<()> {
392        // 确保目录存在
393        tokio::fs::create_dir_all(&self.memory_dir)
394            .await
395            .map_err(|e| {
396                AgentError::IoError(format!("Failed to create memory directory: {}", e))
397            })?;
398
399        tokio::fs::write(&self.long_term_file, content)
400            .await
401            .map_err(|e| AgentError::IoError(format!("Failed to write long-term file: {}", e)))?;
402
403        Ok(())
404    }
405
406    /// 获取最近 N 天的记忆
407    pub async fn get_recent_memories_files(&self, days: u32) -> AgentResult<String> {
408        let mut memories = Vec::new();
409
410        for i in 0..days {
411            let date = Utc::now() - chrono::Duration::days(i as i64);
412            let date_str = date.format("%Y-%m-%d").to_string();
413            let file_path = self.memory_dir.join(format!("{}.md", date_str));
414
415            if file_path.exists() {
416                let content = tokio::fs::read_to_string(&file_path).await.map_err(|e| {
417                    AgentError::IoError(format!(
418                        "Failed to read memory file {:?}: {}",
419                        file_path, e
420                    ))
421                })?;
422                memories.push(content);
423            }
424        }
425
426        Ok(memories.join("\n\n---\n\n"))
427    }
428
429    /// 列出所有记忆文件 (按日期排序,最新的在前)
430    async fn list_memory_files(&self) -> AgentResult<Vec<PathBuf>> {
431        if !self.memory_dir.exists() {
432            return Ok(Vec::new());
433        }
434
435        let mut entries = tokio::fs::read_dir(&self.memory_dir)
436            .await
437            .map_err(|e| AgentError::IoError(format!("Failed to read memory directory: {}", e)))?;
438        let mut files = Vec::new();
439
440        while let Some(entry) = entries
441            .next_entry()
442            .await
443            .map_err(|e| AgentError::IoError(format!("Failed to read entry: {}", e)))?
444        {
445            let path = entry.path();
446            if let Some(name) = path.file_name().and_then(|n: &std::ffi::OsStr| n.to_str()) {
447                // 检查是否匹配 YYYY-MM-DD.md 模式
448                if Self::is_date_file(name) {
449                    files.push(path);
450                }
451            }
452        }
453
454        // 按文件名倒序排序 (最新的在前)
455        files.sort_by(|a: &PathBuf, b: &PathBuf| b.cmp(a));
456        Ok(files)
457    }
458
459    /// 检查文件名是否匹配日期格式 (YYYY-MM-DD.md)
460    fn is_date_file(name: &str) -> bool {
461        if name.len() != 13 {
462            // "2024-01-15.md" = 13 bytes
463            return false;
464        }
465        let bytes = name.as_bytes();
466        bytes[4] == b'-' && bytes[7] == b'-' && name.ends_with(".md")
467    }
468
469    /// 获取记忆上下文
470    pub async fn get_memory_context(&self) -> AgentResult<String> {
471        let mut parts = Vec::new();
472
473        // 长期记忆
474        let long_term = self.read_long_term_file().await?;
475        if !long_term.is_empty() {
476            parts.push(format!("## Long-term Memory\n{}", long_term));
477        }
478
479        // 今日笔记
480        let today = self.read_today_file().await?;
481        if !today.is_empty() {
482            parts.push(format!("## Today's Notes\n{}", today));
483        }
484
485        Ok(parts.join("\n\n"))
486    }
487
488    /// 读取今日笔记
489    pub async fn read_today(&self) -> AgentResult<String> {
490        self.read_today_file().await
491    }
492
493    /// 追加今日笔记
494    pub async fn append_today(&self, content: &str) -> AgentResult<()> {
495        self.append_today_file(content).await
496    }
497
498    /// 读取长期记忆
499    pub async fn read_long_term(&self) -> AgentResult<String> {
500        self.read_long_term_file().await
501    }
502
503    /// 写入长期记忆
504    pub async fn write_long_term(&self, content: &str) -> AgentResult<()> {
505        self.write_long_term_file(content).await
506    }
507
508    /// 获取最近记忆
509    pub async fn get_recent_memories(&self, days: u32) -> AgentResult<String> {
510        self.get_recent_memories_files(days).await
511    }
512}
513
514#[async_trait]
515impl Memory for FileBasedStorage {
516    async fn store(&mut self, key: &str, value: MemoryValue) -> AgentResult<()> {
517        let item = MemoryItem::new(key, value);
518        {
519            let mut data = self.data.write().await;
520            data.insert(key.to_string(), item);
521        }
522        self.persist_data().await?;
523        Ok(())
524    }
525
526    async fn retrieve(&self, key: &str) -> AgentResult<Option<MemoryValue>> {
527        let data = self.data.read().await;
528        Ok(data.get(key).map(|item| item.value.clone()))
529    }
530
531    async fn remove(&mut self, key: &str) -> AgentResult<bool> {
532        let removed = {
533            let mut data = self.data.write().await;
534            data.remove(key).is_some()
535        };
536        if removed {
537            self.persist_data().await?;
538        }
539        Ok(removed)
540    }
541
542    async fn search(&self, query: &str, limit: usize) -> AgentResult<Vec<MemoryItem>> {
543        // 在内存数据中搜索
544        let query_lower = query.to_lowercase();
545        let mut results: Vec<MemoryItem> = {
546            let data = self.data.read().await;
547            data.values()
548                .filter(|item| {
549                    if let Some(text) = item.value.as_text() {
550                        text.to_lowercase().contains(&query_lower)
551                    } else {
552                        false
553                    }
554                })
555                .cloned()
556                .collect()
557        };
558
559        // 同时在 markdown 文件中搜索
560        let memory_files = self.list_memory_files().await?;
561        for file_path in memory_files {
562            if let Ok(content) = tokio::fs::read_to_string(&file_path).await
563                && content.to_lowercase().contains(&query_lower)
564            {
565                let file_name = file_path
566                    .file_name()
567                    .and_then(|n| n.to_str())
568                    .unwrap_or("unknown")
569                    .to_string();
570
571                results.push(MemoryItem::new(file_name, MemoryValue::text(content)));
572            }
573        }
574
575        results.sort_by(|a, b| {
576            b.score
577                .partial_cmp(&a.score)
578                .unwrap_or(std::cmp::Ordering::Equal)
579        });
580        results.truncate(limit);
581        Ok(results)
582    }
583
584    async fn clear(&mut self) -> AgentResult<()> {
585        // 清空内存
586        {
587            let mut data = self.data.write().await;
588            data.clear();
589        }
590        {
591            let mut sessions = self.sessions.write().await;
592            sessions.clear();
593        }
594
595        // 删除所有文件
596        if self.data_file.exists() {
597            tokio::fs::remove_file(&self.data_file)
598                .await
599                .map_err(|e| AgentError::IoError(format!("Failed to remove data file: {}", e)))?;
600        }
601
602        if self.sessions_dir.exists() {
603            tokio::fs::remove_dir_all(&self.sessions_dir)
604                .await
605                .map_err(|e| {
606                    AgentError::IoError(format!("Failed to remove sessions directory: {}", e))
607                })?;
608            tokio::fs::create_dir_all(&self.sessions_dir)
609                .await
610                .map_err(|e| {
611                    AgentError::IoError(format!("Failed to recreate sessions directory: {}", e))
612                })?;
613        }
614
615        Ok(())
616    }
617
618    async fn get_history(&self, session_id: &str) -> AgentResult<Vec<Message>> {
619        let sessions = self.sessions.read().await;
620        Ok(sessions.get(session_id).cloned().unwrap_or_default())
621    }
622
623    async fn add_to_history(&mut self, session_id: &str, message: Message) -> AgentResult<()> {
624        {
625            let mut sessions = self.sessions.write().await;
626            sessions
627                .entry(session_id.to_string())
628                .or_default()
629                .push(message);
630        }
631        self.persist_session(session_id).await?;
632        Ok(())
633    }
634
635    async fn clear_history(&mut self, session_id: &str) -> AgentResult<()> {
636        {
637            let mut sessions = self.sessions.write().await;
638            sessions.remove(session_id);
639        }
640        self.persist_session(session_id).await?;
641        Ok(())
642    }
643
644    async fn stats(&self) -> AgentResult<MemoryStats> {
645        let data = self.data.read().await;
646        let sessions = self.sessions.read().await;
647
648        let total_messages: usize = sessions.values().map(|v| v.len()).sum();
649        let memory_bytes = data.len() * std::mem::size_of::<MemoryItem>();
650
651        Ok(MemoryStats {
652            total_items: data.len(),
653            total_sessions: sessions.len(),
654            total_messages,
655            memory_bytes,
656        })
657    }
658
659    fn memory_type(&self) -> &str {
660        "file-based"
661    }
662}