mcpkit_warp/
session.rs

1//! Session management for MCP Warp integration.
2
3use dashmap::DashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::broadcast;
7use uuid::Uuid;
8
9/// Session manager for tracking MCP client sessions.
10#[derive(Clone)]
11pub struct SessionStore {
12    sessions: Arc<DashMap<String, SessionState>>,
13    sse_channels: Arc<DashMap<String, broadcast::Sender<String>>>,
14}
15
16struct SessionState {
17    last_seen: Instant,
18}
19
20impl Default for SessionStore {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl SessionStore {
27    /// Create a new session store.
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            sessions: Arc::new(DashMap::new()),
32            sse_channels: Arc::new(DashMap::new()),
33        }
34    }
35
36    /// Create a new session and return its ID.
37    #[must_use]
38    pub fn create(&self) -> String {
39        let id = Uuid::new_v4().to_string();
40        let now = Instant::now();
41        self.sessions
42            .insert(id.clone(), SessionState { last_seen: now });
43        id
44    }
45
46    /// Update the last seen time for a session.
47    pub fn touch(&self, id: &str) {
48        if let Some(mut session) = self.sessions.get_mut(id) {
49            session.last_seen = Instant::now();
50        }
51    }
52
53    /// Check if a session exists.
54    #[must_use]
55    pub fn exists(&self, id: &str) -> bool {
56        self.sessions.contains_key(id)
57    }
58
59    /// Get or create an SSE channel for a session.
60    #[must_use]
61    pub fn create_session(&self) -> (String, broadcast::Receiver<String>) {
62        let id = self.create();
63        let (tx, rx) = broadcast::channel(100);
64        self.sse_channels.insert(id.clone(), tx);
65        (id, rx)
66    }
67
68    /// Get a receiver for an existing SSE session.
69    #[must_use]
70    pub fn get_receiver(&self, id: &str) -> Option<broadcast::Receiver<String>> {
71        self.sse_channels.get(id).map(|tx| tx.subscribe())
72    }
73
74    /// Remove sessions older than the given duration.
75    pub fn cleanup(&self, max_age: Duration) {
76        let now = Instant::now();
77        self.sessions
78            .retain(|_, session| now.duration_since(session.last_seen) < max_age);
79    }
80}
81
/// Minimal interface for components that track MCP sessions by string ID.
pub trait SessionManager {
    /// Start a new session and return the ID assigned to it.
    fn create_session(&self) -> String;

    /// Record activity on the session identified by `id`, refreshing its
    /// last seen time.
    fn touch_session(&self, id: &str);

    /// Report whether a session with the given `id` is currently known.
    fn session_exists(&self, id: &str) -> bool;
}
93
94impl SessionManager for SessionStore {
95    fn create_session(&self) -> String {
96        self.create()
97    }
98
99    fn touch_session(&self, id: &str) {
100        self.touch(id);
101    }
102
103    fn session_exists(&self, id: &str) -> bool {
104        self.exists(id)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// A created session has a non-empty ID and is known to the store.
    #[test]
    fn test_session_store_creation() {
        let store = SessionStore::new();
        let id = store.create();

        assert!(!id.is_empty());
        assert!(store.exists(&id));
    }

    /// A store built through `Default` behaves like one built through `new`.
    #[test]
    fn test_session_store_default() {
        let store = SessionStore::default();
        let id = store.create();

        assert!(store.exists(&id));
    }

    #[test]
    fn test_session_store_touch() {
        let store = SessionStore::new();
        let id = store.create();

        // Touch should not panic
        store.touch(&id);
        assert!(store.exists(&id));

        // Touching non-existent session should be no-op
        store.touch("non-existent");
        assert!(!store.exists("non-existent"));
    }

    #[test]
    fn test_session_store_exists() {
        let store = SessionStore::new();
        let id = store.create();

        assert!(store.exists(&id));
        assert!(!store.exists("non-existent"));
    }

    #[test]
    fn test_session_store_cleanup() {
        let store = SessionStore::new();
        let id = store.create();
        assert!(store.exists(&id));

        // A generous max_age must keep a session that was just created.
        store.cleanup(Duration::from_secs(3600));
        assert!(store.exists(&id));

        // Cleanup with 0 duration should remove all sessions: no age is
        // strictly below zero.
        store.cleanup(Duration::from_secs(0));
        assert!(!store.exists(&id));
    }

    #[tokio::test]
    async fn test_session_store_sse_channel() {
        let store = SessionStore::new();
        let (id, mut rx) = store.create_session();

        // Get the sender and send
        let tx = store.sse_channels.get(&id).unwrap();
        tx.send("test message".to_string()).unwrap();
        // Release the map guard before awaiting.
        drop(tx);

        // Receive the message
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "test message");
    }

    #[test]
    fn test_session_store_get_receiver() {
        let store = SessionStore::new();
        let (id, _rx) = store.create_session();

        // Should be able to get another receiver
        let rx2 = store.get_receiver(&id);
        assert!(rx2.is_some());

        // Non-existent session should return None
        let rx3 = store.get_receiver("non-existent");
        assert!(rx3.is_none());
    }

    #[test]
    fn test_session_manager_trait() {
        let store = SessionStore::new();

        // Test via trait
        let id = SessionManager::create_session(&store);
        assert!(SessionManager::session_exists(&store, &id));

        SessionManager::touch_session(&store, &id);
        assert!(SessionManager::session_exists(&store, &id));

        assert!(!SessionManager::session_exists(&store, "non-existent"));
    }

    #[test]
    fn test_multiple_sessions() {
        let store = SessionStore::new();

        let id1 = store.create();
        let id2 = store.create();
        let id3 = store.create();

        assert!(store.exists(&id1));
        assert!(store.exists(&id2));
        assert!(store.exists(&id3));

        // All IDs should be unique
        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }
}