mcpkit_rocket/
session.rs

//! Session management for the MCP Rocket integration: tracks client sessions and their SSE broadcast channels.
2
3use dashmap::DashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
9/// Session manager for tracking MCP client sessions.
10#[derive(Clone)]
11pub struct SessionStore {
12    sessions: Arc<DashMap<String, SessionState>>,
13    sse_channels: Arc<DashMap<String, broadcast::Sender<String>>>,
14}
15
/// Per-session bookkeeping kept by `SessionStore`.
struct SessionState {
    /// Moment the session was created.
    #[allow(dead_code)] // Reserved for future session metadata/debugging
    created_at: Instant,
    /// Moment the session was created or last touched; expiry in `cleanup`
    /// is measured from this timestamp.
    last_seen: Instant,
}
21
22impl Default for SessionStore {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl SessionStore {
29    /// Create a new session store.
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            sessions: Arc::new(DashMap::new()),
34            sse_channels: Arc::new(DashMap::new()),
35        }
36    }
37
38    /// Create a new session and return its ID.
39    #[must_use]
40    pub fn create(&self) -> String {
41        let id = Uuid::new_v4().to_string();
42        let now = Instant::now();
43        self.sessions.insert(
44            id.clone(),
45            SessionState {
46                created_at: now,
47                last_seen: now,
48            },
49        );
50        id
51    }
52
53    /// Update the last seen time for a session.
54    pub fn touch(&self, id: &str) {
55        if let Some(mut session) = self.sessions.get_mut(id) {
56            session.last_seen = Instant::now();
57        }
58    }
59
60    /// Check if a session exists.
61    #[must_use]
62    pub fn exists(&self, id: &str) -> bool {
63        self.sessions.contains_key(id)
64    }
65
66    /// Get or create an SSE channel for a session.
67    #[must_use]
68    pub fn create_session(&self) -> (String, broadcast::Receiver<String>) {
69        let id = self.create();
70        let (tx, rx) = broadcast::channel(100);
71        self.sse_channels.insert(id.clone(), tx);
72        (id, rx)
73    }
74
75    /// Get a receiver for an existing SSE session.
76    #[must_use]
77    pub fn get_receiver(&self, id: &str) -> Option<broadcast::Receiver<String>> {
78        self.sse_channels.get(id).map(|tx| tx.subscribe())
79    }
80
81    /// Send a message to an SSE session.
82    pub fn send(
83        &self,
84        id: &str,
85        msg: String,
86    ) -> Result<usize, broadcast::error::SendError<String>> {
87        self.sse_channels
88            .get(id)
89            .ok_or_else(|| broadcast::error::SendError(msg.clone()))
90            .and_then(|tx| tx.send(msg))
91    }
92
93    /// Remove sessions older than the given duration.
94    pub fn cleanup(&self, max_age: Duration) {
95        let now = Instant::now();
96        self.sessions
97            .retain(|_, session| now.duration_since(session.last_seen) < max_age);
98    }
99}
100
/// Abstraction over a store of MCP client sessions.
///
/// Implementors hand out session identifiers, record activity on them, and
/// report whether an identifier is currently known.
pub trait SessionManager {
    /// Start a new session and return its identifier.
    fn create_session(&self) -> String;

    /// Record activity on the session `id`, i.e. refresh its last-seen time.
    fn touch_session(&self, id: &str);

    /// Return `true` if a session with identifier `id` is known.
    fn session_exists(&self, id: &str) -> bool;
}
112
113impl SessionManager for SessionStore {
114    fn create_session(&self) -> String {
115        self.create()
116    }
117
118    fn touch_session(&self, id: &str) {
119        self.touch(id);
120    }
121
122    fn session_exists(&self, id: &str) -> bool {
123        self.exists(id)
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_store_creation() {
        let store = SessionStore::new();
        let id = store.create();

        assert!(!id.is_empty());
        assert!(store.exists(&id));
    }

    #[test]
    fn test_session_store_default() {
        let store = SessionStore::default();
        let id = store.create();

        assert!(store.exists(&id));
    }

    #[test]
    fn test_session_store_touch() {
        let store = SessionStore::new();
        let id = store.create();

        // Touch should not panic
        store.touch(&id);
        assert!(store.exists(&id));

        // Touching non-existent session should be no-op
        store.touch("non-existent");
        assert!(!store.exists("non-existent"));
    }

    #[test]
    fn test_session_store_exists() {
        let store = SessionStore::new();
        let id = store.create();

        assert!(store.exists(&id));
        assert!(!store.exists("non-existent"));
    }

    #[test]
    fn test_session_store_cleanup() {
        let store = SessionStore::new();
        let id = store.create();

        // A generous max_age must keep a freshly created session.
        store.cleanup(Duration::from_secs(3600));
        assert!(store.exists(&id));

        // A max_age of zero removes every session.
        store.cleanup(Duration::from_secs(0));
        assert!(!store.exists(&id));
    }

    #[tokio::test]
    async fn test_session_store_sse_channel() {
        let store = SessionStore::new();
        let (id, mut rx) = store.create_session();

        // Session should exist
        assert!(store.exists(&id));

        // Send a message; exactly one receiver is subscribed.
        let result = store.send(&id, "test message".to_string());
        assert_eq!(result.unwrap(), 1);

        // Receive the message
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "test message");
    }

    #[test]
    fn test_session_store_get_receiver() {
        let store = SessionStore::new();
        let (id, _rx) = store.create_session();

        // An additional receiver must actually see broadcast messages.
        let mut rx2 = store
            .get_receiver(&id)
            .expect("a session made by create_session has an SSE channel");
        store.send(&id, "hello".to_string()).unwrap();
        assert_eq!(rx2.try_recv().unwrap(), "hello");

        // Non-existent session should return None
        assert!(store.get_receiver("non-existent").is_none());
    }

    #[test]
    fn test_session_store_send_to_nonexistent() {
        let store = SessionStore::new();

        // Sending to non-existent session should fail
        let result = store.send("non-existent", "test".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_send_to_session_without_channel_returns_message() {
        let store = SessionStore::new();
        let id = store.create();

        // Sessions made with `create` have no SSE channel; the message comes back.
        let err = store.send(&id, "payload".to_string()).unwrap_err();
        assert_eq!(err.0, "payload");
    }

    #[test]
    fn test_send_without_receivers_fails() {
        let store = SessionStore::new();
        let (id, rx) = store.create_session();
        drop(rx);

        // tokio broadcast refuses to send when no receiver is subscribed.
        assert!(store.send(&id, "lost".to_string()).is_err());
    }

    #[test]
    fn test_session_manager_trait() {
        let store = SessionStore::new();

        // Test via trait
        let id = SessionManager::create_session(&store);
        assert!(SessionManager::session_exists(&store, &id));

        SessionManager::touch_session(&store, &id);
        assert!(SessionManager::session_exists(&store, &id));

        assert!(!SessionManager::session_exists(&store, "non-existent"));
    }

    #[test]
    fn test_multiple_sessions() {
        let store = SessionStore::new();

        let id1 = store.create();
        let id2 = store.create();
        let id3 = store.create();

        assert!(store.exists(&id1));
        assert!(store.exists(&id2));
        assert!(store.exists(&id3));

        // All IDs should be unique
        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_clones_share_state() {
        let store = SessionStore::new();
        let clone = store.clone();
        let id = store.create();

        // Clones share the same Arc-backed maps.
        assert!(clone.exists(&id));
    }
}