Skip to main content

mcp_session_memory/
memory_store.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4
5use crate::error::MemoryError;
6use crate::store::{MemoryStore, SessionStore};
7use crate::types::*;
8
9pub struct InMemorySessionStore {
10    sessions: Arc<RwLock<HashMap<String, Session>>>,
11    events: Arc<RwLock<Vec<SessionEvent>>>,
12    snapshots: Arc<RwLock<HashMap<String, ReplaySnapshot>>>,
13}
14
15impl InMemorySessionStore {
16    pub fn new() -> Self {
17        Self {
18            sessions: Arc::new(RwLock::new(HashMap::new())),
19            events: Arc::new(RwLock::new(Vec::new())),
20            snapshots: Arc::new(RwLock::new(HashMap::new())),
21        }
22    }
23}
24
25#[async_trait::async_trait]
26impl SessionStore for InMemorySessionStore {
27    async fn get_session(&self, session_id: &str) -> Result<Session, MemoryError> {
28        self.sessions.read().await.get(session_id).cloned()
29            .ok_or_else(|| MemoryError::NotFound(session_id.into()))
30    }
31
32    async fn put_session(&self, session: Session) -> Result<(), MemoryError> {
33        self.sessions.write().await.insert(session.session_id.clone(), session);
34        Ok(())
35    }
36
37    async fn update_session(&self, session: Session) -> Result<(), MemoryError> {
38        let mut sessions = self.sessions.write().await;
39        if !sessions.contains_key(&session.session_id) {
40            return Err(MemoryError::NotFound(session.session_id));
41        }
42        sessions.insert(session.session_id.clone(), session);
43        Ok(())
44    }
45
46    async fn list_events(&self, session_id: &str, event_types: Option<&[EventType]>, limit: usize) -> Result<Vec<SessionEvent>, MemoryError> {
47        let events = self.events.read().await;
48        let filtered: Vec<_> = events.iter()
49            .filter(|e| e.session_id == session_id)
50            .filter(|e| event_types.is_none_or(|types| types.contains(&e.event_type)))
51            .rev().take(limit).cloned().collect();
52        Ok(filtered)
53    }
54
55    async fn put_event(&self, event: SessionEvent) -> Result<(), MemoryError> {
56        self.events.write().await.push(event);
57        Ok(())
58    }
59
60    async fn put_snapshot(&self, snapshot: ReplaySnapshot) -> Result<(), MemoryError> {
61        self.snapshots.write().await.insert(snapshot.snapshot_id.clone(), snapshot);
62        Ok(())
63    }
64
65    async fn get_snapshot(&self, snapshot_id: &str) -> Result<ReplaySnapshot, MemoryError> {
66        self.snapshots.read().await.get(snapshot_id).cloned()
67            .ok_or_else(|| MemoryError::NotFound(snapshot_id.into()))
68    }
69}
70
71pub struct InMemoryMemoryStore {
72    entries: Arc<RwLock<HashMap<String, MemoryEntry>>>,
73}
74
75impl InMemoryMemoryStore {
76    pub fn new() -> Self {
77        Self { entries: Arc::new(RwLock::new(HashMap::new())) }
78    }
79}
80
81#[async_trait::async_trait]
82impl MemoryStore for InMemoryMemoryStore {
83    async fn retrieve(&self, subject_type: &str, subject_id: &str, memory_types: Option<&[MemoryType]>, query: Option<&str>, limit: usize) -> Result<Vec<MemoryEntry>, MemoryError> {
84        let entries = self.entries.read().await;
85        let results: Vec<_> = entries.values()
86            .filter(|e| e.status == MemoryStatus::Active)
87            .filter(|e| e.subject_type == subject_type && e.subject_id == subject_id)
88            .filter(|e| memory_types.is_none_or(|types| types.contains(&e.memory_type)))
89            .filter(|e| query.is_none_or(|q| e.content.to_lowercase().contains(&q.to_lowercase())))
90            .take(limit).cloned().collect();
91        Ok(results)
92    }
93
94    async fn get_memory(&self, memory_id: &str) -> Result<MemoryEntry, MemoryError> {
95        self.entries.read().await.get(memory_id).cloned()
96            .ok_or_else(|| MemoryError::NotFound(memory_id.into()))
97    }
98
99    async fn put_memory(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
100        self.entries.write().await.insert(entry.memory_id.clone(), entry);
101        Ok(())
102    }
103
104    async fn update_memory(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
105        let mut entries = self.entries.write().await;
106        if !entries.contains_key(&entry.memory_id) {
107            return Err(MemoryError::NotFound(entry.memory_id));
108        }
109        entries.insert(entry.memory_id.clone(), entry);
110        Ok(())
111    }
112
113    async fn delete_memory(&self, memory_id: &str) -> Result<(), MemoryError> {
114        let mut entries = self.entries.write().await;
115        if let Some(entry) = entries.get_mut(memory_id) {
116            entry.status = MemoryStatus::Deleted;
117            entry.updated_at = chrono::Utc::now();
118            Ok(())
119        } else {
120            Err(MemoryError::NotFound(memory_id.into()))
121        }
122    }
123
124    async fn list_by_session(&self, session_id: &str) -> Result<Vec<MemoryEntry>, MemoryError> {
125        let entries = self.entries.read().await;
126        Ok(entries.values()
127            .filter(|e| e.source_session_id.as_deref() == Some(session_id) && e.status == MemoryStatus::Active)
128            .cloned().collect())
129    }
130}