// mofa_kernel/agent/components/memory.rs

//! 记忆组件 (Memory component)
//!
//! 定义 Agent 的记忆/状态持久化能力。
//! Defines the agent's memory / state-persistence capabilities.
4
5use crate::agent::error::AgentResult;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// 记忆组件 Trait
11///
12/// 负责 Agent 的记忆存储和检索
13///
14/// # 示例
15///
16/// ```rust,ignore
17/// use mofa_kernel::agent::components::memory::{Memory, MemoryValue, MemoryItem};
18///
19/// struct InMemoryStorage {
20///     data: HashMap<String, MemoryValue>,
21/// }
22///
23/// #[async_trait]
24/// impl Memory for InMemoryStorage {
25///     async fn store(&mut self, key: &str, value: MemoryValue) -> AgentResult<()> {
26///         self.data.insert(key.to_string(), value);
27///         Ok(())
28///     }
29///
30///     async fn retrieve(&self, key: &str) -> AgentResult<Option<MemoryValue>> {
31///         Ok(self.data.get(key).cloned())
32///     }
33///
34///     // ... other methods
35/// }
36/// ```
37#[async_trait]
38pub trait Memory: Send + Sync {
39    /// 存储记忆项
40    async fn store(&mut self, key: &str, value: MemoryValue) -> AgentResult<()>;
41
42    /// 检索记忆项
43    async fn retrieve(&self, key: &str) -> AgentResult<Option<MemoryValue>>;
44
45    /// 删除记忆项
46    async fn remove(&mut self, key: &str) -> AgentResult<bool>;
47
48    /// 检查是否存在
49    async fn contains(&self, key: &str) -> AgentResult<bool> {
50        Ok(self.retrieve(key).await?.is_some())
51    }
52
53    /// 语义搜索
54    async fn search(&self, query: &str, limit: usize) -> AgentResult<Vec<MemoryItem>>;
55
56    /// 清空所有记忆
57    async fn clear(&mut self) -> AgentResult<()>;
58
59    /// 获取对话历史
60    async fn get_history(&self, session_id: &str) -> AgentResult<Vec<Message>>;
61
62    /// 添加对话消息
63    async fn add_to_history(&mut self, session_id: &str, message: Message) -> AgentResult<()>;
64
65    /// 清空对话历史
66    async fn clear_history(&mut self, session_id: &str) -> AgentResult<()>;
67
68    /// 获取记忆统计
69    async fn stats(&self) -> AgentResult<MemoryStats> {
70        Ok(MemoryStats::default())
71    }
72
73    /// 记忆类型名称
74    fn memory_type(&self) -> &str {
75        "memory"
76    }
77}
78
79/// 记忆值类型
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub enum MemoryValue {
82    /// 文本
83    Text(String),
84    /// 嵌入向量
85    Embedding(Vec<f32>),
86    /// 结构化数据
87    Structured(serde_json::Value),
88    /// 二进制数据
89    Binary(Vec<u8>),
90    /// 带嵌入的文本
91    TextWithEmbedding { text: String, embedding: Vec<f32> },
92}
93
94impl MemoryValue {
95    /// 创建文本值
96    pub fn text(s: impl Into<String>) -> Self {
97        Self::Text(s.into())
98    }
99
100    /// 创建嵌入向量值
101    pub fn embedding(e: Vec<f32>) -> Self {
102        Self::Embedding(e)
103    }
104
105    /// 创建结构化值
106    pub fn structured(v: serde_json::Value) -> Self {
107        Self::Structured(v)
108    }
109
110    /// 创建带嵌入的文本
111    pub fn text_with_embedding(text: impl Into<String>, embedding: Vec<f32>) -> Self {
112        Self::TextWithEmbedding {
113            text: text.into(),
114            embedding,
115        }
116    }
117
118    /// 获取文本内容
119    pub fn as_text(&self) -> Option<&str> {
120        match self {
121            Self::Text(s) => Some(s),
122            Self::TextWithEmbedding { text, .. } => Some(text),
123            _ => None,
124        }
125    }
126
127    /// 获取嵌入向量
128    pub fn as_embedding(&self) -> Option<&[f32]> {
129        match self {
130            Self::Embedding(e) => Some(e),
131            Self::TextWithEmbedding { embedding, .. } => Some(embedding),
132            _ => None,
133        }
134    }
135
136    /// 获取结构化数据
137    pub fn as_structured(&self) -> Option<&serde_json::Value> {
138        match self {
139            Self::Structured(v) => Some(v),
140            _ => None,
141        }
142    }
143}
144
145impl From<String> for MemoryValue {
146    fn from(s: String) -> Self {
147        Self::Text(s)
148    }
149}
150
151impl From<&str> for MemoryValue {
152    fn from(s: &str) -> Self {
153        Self::Text(s.to_string())
154    }
155}
156
157impl From<serde_json::Value> for MemoryValue {
158    fn from(v: serde_json::Value) -> Self {
159        Self::Structured(v)
160    }
161}
162
163/// 记忆项 (搜索结果)
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct MemoryItem {
166    /// 记忆键
167    pub key: String,
168    /// 记忆值
169    pub value: MemoryValue,
170    /// 相似度分数 (0.0 - 1.0)
171    pub score: f32,
172    /// 元数据
173    pub metadata: HashMap<String, String>,
174    /// 创建时间
175    pub created_at: u64,
176    /// 最后访问时间
177    pub last_accessed: u64,
178}
179
180impl MemoryItem {
181    /// 创建新的记忆项
182    pub fn new(key: impl Into<String>, value: MemoryValue) -> Self {
183        let now = std::time::SystemTime::now()
184            .duration_since(std::time::UNIX_EPOCH)
185            .unwrap_or_default()
186            .as_millis() as u64;
187
188        Self {
189            key: key.into(),
190            value,
191            score: 1.0,
192            metadata: HashMap::new(),
193            created_at: now,
194            last_accessed: now,
195        }
196    }
197
198    /// 设置分数
199    pub fn with_score(mut self, score: f32) -> Self {
200        self.score = score.clamp(0.0, 1.0);
201        self
202    }
203
204    /// 添加元数据
205    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
206        self.metadata.insert(key.into(), value.into());
207        self
208    }
209}
210
211/// 对话消息
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct Message {
214    /// 消息角色
215    pub role: MessageRole,
216    /// 消息内容
217    pub content: String,
218    /// 时间戳
219    pub timestamp: u64,
220    /// 元数据
221    pub metadata: HashMap<String, serde_json::Value>,
222}
223
224impl Message {
225    /// 创建新消息
226    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
227        let now = std::time::SystemTime::now()
228            .duration_since(std::time::UNIX_EPOCH)
229            .unwrap_or_default()
230            .as_millis() as u64;
231
232        Self {
233            role,
234            content: content.into(),
235            timestamp: now,
236            metadata: HashMap::new(),
237        }
238    }
239
240    /// 创建系统消息
241    pub fn system(content: impl Into<String>) -> Self {
242        Self::new(MessageRole::System, content)
243    }
244
245    /// 创建用户消息
246    pub fn user(content: impl Into<String>) -> Self {
247        Self::new(MessageRole::User, content)
248    }
249
250    /// 创建助手消息
251    pub fn assistant(content: impl Into<String>) -> Self {
252        Self::new(MessageRole::Assistant, content)
253    }
254
255    /// 创建工具消息
256    pub fn tool(tool_name: impl Into<String>, content: impl Into<String>) -> Self {
257        let mut msg = Self::new(MessageRole::Tool, content);
258        msg.metadata.insert(
259            "tool_name".to_string(),
260            serde_json::Value::String(tool_name.into()),
261        );
262        msg
263    }
264
265    /// 添加元数据
266    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
267        self.metadata.insert(key.into(), value);
268        self
269    }
270}
271
272/// 消息角色
273#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
274pub enum MessageRole {
275    /// 系统消息
276    System,
277    /// 用户消息
278    User,
279    /// 助手消息
280    Assistant,
281    /// 工具消息
282    Tool,
283}
284
285impl std::fmt::Display for MessageRole {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        match self {
288            Self::System => write!(f, "system"),
289            Self::User => write!(f, "user"),
290            Self::Assistant => write!(f, "assistant"),
291            Self::Tool => write!(f, "tool"),
292        }
293    }
294}
295
296/// 记忆统计
297#[derive(Debug, Clone, Default, Serialize, Deserialize)]
298pub struct MemoryStats {
299    /// 总记忆项数
300    pub total_items: usize,
301    /// 总对话会话数
302    pub total_sessions: usize,
303    /// 总消息数
304    pub total_messages: usize,
305    /// 内存使用 (字节)
306    pub memory_bytes: usize,
307}