mcp_session_memory/
memory_store.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4
5use crate::error::MemoryError;
6use crate::store::{MemoryStore, SessionStore};
7use crate::types::*;
8
9pub struct InMemorySessionStore {
10 sessions: Arc<RwLock<HashMap<String, Session>>>,
11 events: Arc<RwLock<Vec<SessionEvent>>>,
12 snapshots: Arc<RwLock<HashMap<String, ReplaySnapshot>>>,
13}
14
15impl InMemorySessionStore {
16 pub fn new() -> Self {
17 Self {
18 sessions: Arc::new(RwLock::new(HashMap::new())),
19 events: Arc::new(RwLock::new(Vec::new())),
20 snapshots: Arc::new(RwLock::new(HashMap::new())),
21 }
22 }
23}
24
25#[async_trait::async_trait]
26impl SessionStore for InMemorySessionStore {
27 async fn get_session(&self, session_id: &str) -> Result<Session, MemoryError> {
28 self.sessions.read().await.get(session_id).cloned()
29 .ok_or_else(|| MemoryError::NotFound(session_id.into()))
30 }
31
32 async fn put_session(&self, session: Session) -> Result<(), MemoryError> {
33 self.sessions.write().await.insert(session.session_id.clone(), session);
34 Ok(())
35 }
36
37 async fn update_session(&self, session: Session) -> Result<(), MemoryError> {
38 let mut sessions = self.sessions.write().await;
39 if !sessions.contains_key(&session.session_id) {
40 return Err(MemoryError::NotFound(session.session_id));
41 }
42 sessions.insert(session.session_id.clone(), session);
43 Ok(())
44 }
45
46 async fn list_events(&self, session_id: &str, event_types: Option<&[EventType]>, limit: usize) -> Result<Vec<SessionEvent>, MemoryError> {
47 let events = self.events.read().await;
48 let filtered: Vec<_> = events.iter()
49 .filter(|e| e.session_id == session_id)
50 .filter(|e| event_types.is_none_or(|types| types.contains(&e.event_type)))
51 .rev().take(limit).cloned().collect();
52 Ok(filtered)
53 }
54
55 async fn put_event(&self, event: SessionEvent) -> Result<(), MemoryError> {
56 self.events.write().await.push(event);
57 Ok(())
58 }
59
60 async fn put_snapshot(&self, snapshot: ReplaySnapshot) -> Result<(), MemoryError> {
61 self.snapshots.write().await.insert(snapshot.snapshot_id.clone(), snapshot);
62 Ok(())
63 }
64
65 async fn get_snapshot(&self, snapshot_id: &str) -> Result<ReplaySnapshot, MemoryError> {
66 self.snapshots.read().await.get(snapshot_id).cloned()
67 .ok_or_else(|| MemoryError::NotFound(snapshot_id.into()))
68 }
69}
70
71pub struct InMemoryMemoryStore {
72 entries: Arc<RwLock<HashMap<String, MemoryEntry>>>,
73}
74
75impl InMemoryMemoryStore {
76 pub fn new() -> Self {
77 Self { entries: Arc::new(RwLock::new(HashMap::new())) }
78 }
79}
80
81#[async_trait::async_trait]
82impl MemoryStore for InMemoryMemoryStore {
83 async fn retrieve(&self, subject_type: &str, subject_id: &str, memory_types: Option<&[MemoryType]>, query: Option<&str>, limit: usize) -> Result<Vec<MemoryEntry>, MemoryError> {
84 let entries = self.entries.read().await;
85 let results: Vec<_> = entries.values()
86 .filter(|e| e.status == MemoryStatus::Active)
87 .filter(|e| e.subject_type == subject_type && e.subject_id == subject_id)
88 .filter(|e| memory_types.is_none_or(|types| types.contains(&e.memory_type)))
89 .filter(|e| query.is_none_or(|q| e.content.to_lowercase().contains(&q.to_lowercase())))
90 .take(limit).cloned().collect();
91 Ok(results)
92 }
93
94 async fn get_memory(&self, memory_id: &str) -> Result<MemoryEntry, MemoryError> {
95 self.entries.read().await.get(memory_id).cloned()
96 .ok_or_else(|| MemoryError::NotFound(memory_id.into()))
97 }
98
99 async fn put_memory(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
100 self.entries.write().await.insert(entry.memory_id.clone(), entry);
101 Ok(())
102 }
103
104 async fn update_memory(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
105 let mut entries = self.entries.write().await;
106 if !entries.contains_key(&entry.memory_id) {
107 return Err(MemoryError::NotFound(entry.memory_id));
108 }
109 entries.insert(entry.memory_id.clone(), entry);
110 Ok(())
111 }
112
113 async fn delete_memory(&self, memory_id: &str) -> Result<(), MemoryError> {
114 let mut entries = self.entries.write().await;
115 if let Some(entry) = entries.get_mut(memory_id) {
116 entry.status = MemoryStatus::Deleted;
117 entry.updated_at = chrono::Utc::now();
118 Ok(())
119 } else {
120 Err(MemoryError::NotFound(memory_id.into()))
121 }
122 }
123
124 async fn list_by_session(&self, session_id: &str) -> Result<Vec<MemoryEntry>, MemoryError> {
125 let entries = self.entries.read().await;
126 Ok(entries.values()
127 .filter(|e| e.source_session_id.as_deref() == Some(session_id) && e.status == MemoryStatus::Active)
128 .cloned().collect())
129 }
130}