1use std::collections::VecDeque;
6use std::sync::Arc;
7
8use dashmap::DashMap;
9
10use crate::agent::message::Message;
11use std::collections::HashMap;
12use std::path::PathBuf;
13use std::sync::Weak;
14use async_trait::async_trait;
15
16use crate::agent::scheduler::Scheduler;
17
18#[async_trait]
20pub trait Memory: Send + Sync {
21 async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()>;
23
24 async fn store_batch(&self, user_id: &str, agent_id: Option<&str>, messages: Vec<Message>) -> crate::error::Result<()> {
26 for msg in messages {
27 self.store(user_id, agent_id, msg).await?;
28 }
29 Ok(())
30 }
31
32 async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message>;
34
35 async fn search(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
37 let _ = (user_id, agent_id, query, limit);
38 Ok(Vec::new())
39 }
40
41 async fn store_knowledge(&self, user_id: &str, agent_id: Option<&str>, title: &str, content: &str, collection: &str) -> crate::error::Result<()> {
43 let _ = (user_id, agent_id, title, content, collection);
44 Ok(())
45 }
46
47 async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()>;
49
50 async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>>;
52
53 async fn update_summary(&self, collection: &str, path: &str, summary: &str) -> crate::error::Result<()> {
55 let _ = (collection, path, summary);
56 Ok(())
57 }
58
59 fn link_scheduler(&self, _scheduler: Weak<Scheduler>) {}
61
62 async fn fetch_document(&self, collection: &str, path: &str) -> crate::error::Result<Option<crate::knowledge::rag::Document>> {
64 let _ = (collection, path);
65 Ok(None)
66 }
67
68 async fn store_session(&self, _session: crate::agent::session::AgentSession) -> crate::error::Result<()> {
70 Ok(())
71 }
72
73 async fn retrieve_session(&self, _session_id: &str) -> crate::error::Result<Option<crate::agent::session::AgentSession>> {
75 Ok(None)
76 }
77}
78
79pub struct ShortTermMemory {
83 max_messages: usize,
85 max_users: usize,
87 store: DashMap<String, VecDeque<Message>>,
89 last_access: DashMap<String, std::time::Instant>,
91 path: PathBuf,
93}
94
95impl ShortTermMemory {
96 pub async fn new(max_messages: usize, max_users: usize, path: impl Into<PathBuf>) -> Self {
98 let path = path.into();
99 let store = DashMap::new();
100 let last_access = DashMap::new();
101
102 let mem = Self {
103 max_messages,
104 max_users,
105 store,
106 last_access,
107 path,
108 };
109
110 if let Err(e) = mem.load().await {
112 tracing::warn!("Failed to load short-term memory from {:?}: {}", mem.path, e);
113 }
114
115 mem
116 }
117
118 pub async fn default_capacity() -> Self {
120 Self::new(100, 1000, "data/short_term_memory.json").await
121 }
122
123 async fn load(&self) -> crate::error::Result<()> {
125 if !self.path.exists() {
126 return Ok(());
127 }
128
129 let content = tokio::fs::read_to_string(&self.path).await
130 .map_err(|e| crate::error::Error::Internal(format!("Failed to read memory file: {}", e)))?;
131
132 if content.trim().is_empty() {
133 return Ok(());
134 }
135
136 let data: HashMap<String, VecDeque<Message>> = serde_json::from_str(&content)
137 .map_err(|e| crate::error::Error::Internal(format!("Failed to parse memory file: {}", e)))?;
138
139 self.store.clear();
140 for (k, v) in data {
141 self.store.insert(k.clone(), v);
142 self.last_access.insert(k, std::time::Instant::now());
143 }
144
145 tracing::info!("Loaded short-term memory for {} users", self.store.len());
146 Ok(())
147 }
148
149 async fn save(&self) -> crate::error::Result<()> {
151 if let Some(parent) = self.path.parent() {
152 tokio::fs::create_dir_all(parent).await.ok();
153 }
154
155 let data: HashMap<_, _> = self.store.iter().map(|r| (r.key().clone(), r.value().clone())).collect();
157
158 let json = serde_json::to_string_pretty(&data)
159 .map_err(|e| crate::error::Error::Internal(format!("Failed to serialize memory: {}", e)))?;
160
161 let tmp_path = self.path.with_extension("tmp");
163 tokio::fs::write(&tmp_path, json).await
164 .map_err(|e| crate::error::Error::Internal(format!("Failed to write temporary memory file: {}", e)))?;
165
166 tokio::fs::rename(tmp_path, &self.path).await
167 .map_err(|e| crate::error::Error::Internal(format!("Failed to rename memory file: {}", e)))?;
168
169 Ok(())
170 }
171
172 pub fn message_count(&self, user_id: &str, agent_id: Option<&str>) -> usize {
174 let key = self.key(user_id, agent_id);
175 self.store.get(&key).map(|v| v.len()).unwrap_or(0)
176 }
177
178 fn key(&self, user_id: &str, agent_id: Option<&str>) -> String {
180 if let Some(agent) = agent_id {
181 format!("{}:{}", user_id, agent)
182 } else {
183 user_id.to_string()
184 }
185 }
186
187 pub fn prune_inactive(&self, duration: std::time::Duration) {
189 let now = std::time::Instant::now();
190 self.last_access.retain(|key, last_time| {
192 let keep = now.duration_since(*last_time) < duration;
193 if !keep {
194 self.store.remove(key);
195 }
196 keep
197 });
198 }
199
200 fn enforce_user_capacity(&self) {
202 if self.store.len() < self.max_users {
203 return;
204 }
205
206 let mut oldest_key = None;
207 let mut oldest_time = std::time::Instant::now();
208
209 for r in self.last_access.iter() {
210 if *r.value() < oldest_time {
211 oldest_time = *r.value();
212 oldest_key = Some(r.key().clone());
213 }
214 }
215
216 if let Some(key) = oldest_key {
217 self.store.remove(&key);
218 self.last_access.remove(&key);
219 }
220 }
221
222 pub async fn pop_oldest(&self, user_id: &str, agent_id: Option<&str>, count: usize) -> Vec<Message> {
224 let key = self.key(user_id, agent_id);
225 let mut popped = Vec::new();
226
227 if let Some(mut entry) = self.store.get_mut(&key) {
228 for _ in 0..count {
229 if let Some(msg) = entry.pop_front() {
230 popped.push(msg);
231 } else {
232 break;
233 }
234 }
235 }
236
237 if !popped.is_empty() {
238 let _ = self.save().await;
240 }
241
242 popped
243 }
244}
245
246#[async_trait]
247impl Memory for ShortTermMemory {
248 async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
249 let key = self.key(user_id, agent_id);
250
251 if !self.store.contains_key(&key) {
253 self.enforce_user_capacity();
254 }
255
256 {
257 let mut entry = self.store.entry(key.clone()).or_default();
258
259 if entry.len() >= self.max_messages {
263 entry.pop_front();
264 }
265 entry.push_back(message);
266 } self.last_access.insert(key, std::time::Instant::now());
270
271 if let Err(e) = self.save().await {
274 tracing::error!("Failed to persist short-term memory: {}", e);
275 }
276
277 Ok(())
278 }
279
280
281
282 async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
283 let key = self.key(user_id, agent_id);
284 self.store
285 .get(&key)
286 .map(|v| {
287 self.last_access.insert(key, std::time::Instant::now());
289
290 let skip = v.len().saturating_sub(limit);
291 v.iter().skip(skip).cloned().collect()
292 })
293 .unwrap_or_default()
294 }
295
296 async fn store_knowledge(&self, user_id: &str, agent_id: Option<&str>, title: &str, content: &str, collection: &str) -> crate::error::Result<()> {
297 let text = format!("[{}] {}: {}", collection, title, content);
298 self.store(user_id, agent_id, Message::assistant(text)).await
299 }
300
301 async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()> {
302 let key = self.key(user_id, agent_id);
303 self.store.remove(&key);
304 self.last_access.remove(&key);
305
306 self.save().await
307 }
308
309 async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
310 let key = self.key(user_id, agent_id);
311 let msg = {
312 let mut entry = self.store.entry(key.clone()).or_default();
313 entry.pop_back()
314 };
315
316 if msg.is_some() {
317 self.save().await?;
318 }
319
320 Ok(msg)
321 }
322
323 async fn search(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
324 let query_lower = query.to_lowercase();
325 let messages = self.retrieve(user_id, agent_id, 1000).await; let mut results = Vec::new();
328 for (i, msg) in messages.iter().enumerate() {
329 let content = msg.text();
330 if content.to_lowercase().contains(&query_lower) {
331 results.push(crate::knowledge::rag::Document {
332 id: format!("stm_{}_{}", self.key(user_id, agent_id), i),
333 title: format!("Recent conversation ({})", msg.role.as_str()),
334 content: content.to_string(),
335 summary: None,
336 collection: None,
337 path: None,
338 metadata: HashMap::new(),
339 score: 0.9, });
341 }
342 if results.len() >= limit {
343 break;
344 }
345 }
346
347 Ok(results)
348 }
349}
350
351pub struct InMemoryMemory {
353 store: DashMap<String, VecDeque<Message>>,
354}
355
356impl InMemoryMemory {
357 pub fn new() -> Self {
359 Self { store: DashMap::new() }
360 }
361
362 fn key(&self, user_id: &str, agent_id: Option<&str>) -> String {
363 if let Some(agent) = agent_id {
364 format!("{}:{}", user_id, agent)
365 } else {
366 user_id.to_string()
367 }
368 }
369}
370
371#[async_trait]
372impl Memory for InMemoryMemory {
373 async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
374 let key = self.key(user_id, agent_id);
375 self.store.entry(key).or_default().push_back(message);
376 Ok(())
377 }
378
379 async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
380 let key = self.key(user_id, agent_id);
381 self.store.get(&key).map(|v| {
382 let skip = v.len().saturating_sub(limit);
383 v.iter().skip(skip).cloned().collect()
384 }).unwrap_or_default()
385 }
386
387 async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()> {
388 let key = self.key(user_id, agent_id);
389 self.store.remove(&key);
390 Ok(())
391 }
392
393 async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
394 let key = self.key(user_id, agent_id);
395 Ok(self.store.get_mut(&key).and_then(|mut v| v.pop_back()))
396 }
397}
398
399pub struct MemoryManager {
401 pub hot_tier: Arc<dyn Memory>,
403 pub cold_tier: Arc<dyn Memory>,
405}
406
407impl MemoryManager {
408 pub fn new(hot_tier: Arc<dyn Memory>, cold_tier: Arc<dyn Memory>) -> Self {
410 Self { hot_tier, cold_tier }
411 }
412
413 pub async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
416 self.hot_tier.store(user_id, agent_id, message).await?;
418
419 Ok(())
426 }
427
428 pub async fn retrieve_unified(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
431 let mut messages = self.hot_tier.retrieve(user_id, agent_id, limit).await;
432
433 if messages.len() < limit {
434 let needed = limit - messages.len();
435 let cold_messages = self.cold_tier.retrieve(user_id, agent_id, needed).await;
436
437 let mut combined = cold_messages;
438 combined.extend(messages);
439 messages = combined;
440 }
441
442 messages
443 }
444
445 pub async fn search_unified(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
447 let hot_results = self.hot_tier.search(user_id, agent_id, query, limit).await?;
448 let cold_results = self.cold_tier.search(user_id, agent_id, query, limit).await?;
449
450 let mut combined = hot_results;
451 for cold_res in cold_results {
452 if !combined.iter().any(|r| r.content == cold_res.content) {
453 combined.push(cold_res);
454 }
455 }
456
457 combined.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
458 combined.truncate(limit);
459
460 Ok(combined)
461 }
462
463 pub async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
465 let hot_msg = self.hot_tier.undo(user_id, agent_id).await?;
466 let _ = self.cold_tier.undo(user_id, agent_id).await?;
467 Ok(hot_msg)
468 }
469}
470
471#[async_trait]
472impl Memory for MemoryManager {
473 async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
474 self.store(user_id, agent_id, message).await
475 }
476
477 async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
478 self.retrieve_unified(user_id, agent_id, limit).await
479 }
480
481 async fn search(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
482 self.search_unified(user_id, agent_id, query, limit).await
483 }
484
485 async fn store_knowledge(&self, user_id: &str, agent_id: Option<&str>, title: &str, content: &str, collection: &str) -> crate::error::Result<()> {
486 self.cold_tier.store_knowledge(user_id, agent_id, title, content, collection).await
488 }
489
490 async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()> {
491 self.hot_tier.clear(user_id, agent_id).await?;
492 self.cold_tier.clear(user_id, agent_id).await?;
493 Ok(())
494 }
495
496 async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
497 self.undo(user_id, agent_id).await
498 }
499
500 async fn store_session(&self, session: crate::agent::session::AgentSession) -> crate::error::Result<()> {
501 self.cold_tier.store_session(session).await
502 }
503
504 async fn retrieve_session(&self, session_id: &str) -> crate::error::Result<Option<crate::agent::session::AgentSession>> {
505 self.cold_tier.retrieve_session(session_id).await
506 }
507
508 async fn fetch_document(&self, collection: &str, path: &str) -> crate::error::Result<Option<crate::knowledge::rag::Document>> {
509 self.cold_tier.fetch_document(collection, path).await
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes the wrapped file on drop, so cleanup also happens when an assertion panics.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[tokio::test]
    async fn test_short_term_memory() {
        // Per-process path in the OS temp dir: no stale state from an earlier run
        // and no files left in the working directory.
        let path = std::env::temp_dir().join(format!("test_stm_{}.json", std::process::id()));
        let _ = std::fs::remove_file(&path);
        let _cleanup = TempFile(path.clone());

        let memory = ShortTermMemory::new(3, 10, path.clone()).await;

        memory.store("user1", None, Message::user("Hello")).await.unwrap();
        memory.store("user1", None, Message::assistant("Hi there")).await.unwrap();
        memory.store("user1", None, Message::user("How are you?")).await.unwrap();
        memory.store("user1", None, Message::assistant("I'm good!")).await.unwrap();

        // Capacity 3: the first message was evicted.
        let messages = memory.retrieve("user1", None, 10).await;
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].text(), "Hi there");

        // Undoing an unknown conversation is a no-op.
        assert!(memory.undo("nobody", None).await.unwrap().is_none());
        assert_eq!(memory.message_count("nobody", None), 0);

        // State survives a reload from disk.
        let reloaded = ShortTermMemory::new(3, 10, path.clone()).await;
        let messages = reloaded.retrieve("user1", None, 10).await;
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].text(), "Hi there");
    }
}