mcpkit_actix/
session.rs

1//! Session management for MCP HTTP connections.
2
3use dashmap::DashMap;
4use mcpkit_core::capability::ClientCapabilities;
5use std::collections::VecDeque;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::sync::broadcast;
11
/// Idle period after which a session counts as expired by default (one hour).
pub const DEFAULT_SESSION_TIMEOUT: Duration = Duration::from_secs(60 * 60);

/// Default capacity of each session's SSE broadcast channel.
pub const DEFAULT_SSE_CAPACITY: usize = 100;
17
18/// A single MCP session.
19#[derive(Debug, Clone)]
20pub struct Session {
21    /// Unique session identifier.
22    pub id: String,
23    /// When the session was created.
24    pub created_at: Instant,
25    /// When the session was last active.
26    pub last_active: Instant,
27    /// Whether the session has been initialized.
28    pub initialized: bool,
29    /// Client capabilities from initialization.
30    pub client_capabilities: Option<ClientCapabilities>,
31}
32
33impl Session {
34    /// Create a new session.
35    #[must_use]
36    pub fn new(id: String) -> Self {
37        let now = Instant::now();
38        Self {
39            id,
40            created_at: now,
41            last_active: now,
42            initialized: false,
43            client_capabilities: None,
44        }
45    }
46
47    /// Check if the session has expired.
48    #[must_use]
49    pub fn is_expired(&self, timeout: Duration) -> bool {
50        self.last_active.elapsed() >= timeout
51    }
52
53    /// Mark the session as active.
54    pub fn touch(&mut self) {
55        self.last_active = Instant::now();
56    }
57
58    /// Mark the session as initialized.
59    pub fn mark_initialized(&mut self, capabilities: Option<ClientCapabilities>) {
60        self.initialized = true;
61        self.client_capabilities = capabilities;
62    }
63}
64
/// A stored SSE event for replay support.
#[derive(Debug, Clone)]
pub struct StoredEvent {
    /// The event ID.
    ///
    /// IDs produced by [`EventStore::store_auto_id`] are unique within one store;
    /// IDs passed to [`EventStore::store`] are chosen by the caller and not checked.
    pub id: String,
    /// The event type (e.g., "message", "connected").
    pub event_type: String,
    /// The event data.
    pub data: String,
    /// When the event was stored.
    pub stored_at: Instant,
}

impl StoredEvent {
    /// Create a new stored event timestamped with the current instant.
    #[must_use]
    pub fn new(id: String, event_type: impl Into<String>, data: impl Into<String>) -> Self {
        Self {
            id,
            event_type: event_type.into(),
            data: data.into(),
            stored_at: Instant::now(),
        }
    }
}

/// Configuration for event store retention.
#[derive(Debug, Clone)]
pub struct EventStoreConfig {
    /// Maximum number of events to retain per stream.
    ///
    /// The oldest events are dropped first; `0` retains nothing.
    pub max_events: usize,
    /// Maximum age of events to retain.
    ///
    /// Events strictly older than this are dropped whenever the store is
    /// written to or [`EventStore::cleanup_expired`] runs.
    pub max_age: Duration,
}

impl Default for EventStoreConfig {
    /// 1000 events, retained for at most 5 minutes.
    fn default() -> Self {
        Self {
            max_events: 1000,
            max_age: Duration::from_secs(300), // 5 minutes
        }
    }
}

impl EventStoreConfig {
    /// Create a new event store configuration.
    #[must_use]
    pub const fn new(max_events: usize, max_age: Duration) -> Self {
        Self {
            max_events,
            max_age,
        }
    }

    /// Set the maximum number of events to retain.
    #[must_use]
    pub const fn with_max_events(mut self, max_events: usize) -> Self {
        self.max_events = max_events;
        self
    }

    /// Set the maximum age of events to retain.
    #[must_use]
    pub const fn with_max_age(mut self, max_age: Duration) -> Self {
        self.max_age = max_age;
        self
    }
}

/// Event store for SSE message resumability.
///
/// Per the MCP Streamable HTTP specification, servers MAY store events
/// with IDs to support client reconnection with `Last-Event-ID`.
///
/// Events are kept oldest-first and trimmed on every write according to the
/// store's [`EventStoreConfig`].
#[derive(Debug)]
pub struct EventStore {
    /// Retained events, oldest first.
    ///
    /// This is deliberately a *synchronous* lock. Every critical section in this
    /// type is short and never spans an `.await`. The sync entry points
    /// (`store`, `store_auto_id`) and the async ones can therefore share it
    /// without ever driving a future to completion on an executor thread.
    events: std::sync::RwLock<VecDeque<StoredEvent>>,
    /// Retention limits applied on every write.
    config: EventStoreConfig,
    /// Counter backing [`EventStore::next_event_id`]; starts at 1.
    next_id: AtomicU64,
}

impl EventStore {
    /// Create a new event store with the given configuration.
    ///
    /// The buffer starts empty and grows on demand. A large `max_events`
    /// (including `usize::MAX` as "unlimited") therefore costs nothing up front
    /// and cannot overflow the allocation.
    #[must_use]
    pub fn new(config: EventStoreConfig) -> Self {
        Self {
            events: std::sync::RwLock::new(VecDeque::new()),
            config,
            next_id: AtomicU64::new(1),
        }
    }

    /// Create a new event store with default configuration.
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(EventStoreConfig::default())
    }

    /// Generate the next event ID (`evt-1`, `evt-2`, ...).
    ///
    /// This only reserves an ID; nothing is stored.
    #[must_use]
    pub fn next_event_id(&self) -> String {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        format!("evt-{id}")
    }

    /// Store an event with automatic ID generation.
    ///
    /// Returns the generated event ID.
    pub fn store_auto_id(&self, event_type: impl Into<String>, data: impl Into<String>) -> String {
        let event_type: String = event_type.into();
        let data: String = data.into();
        let mut events = self.write_events();
        // Allocate the ID while holding the write lock. Retained events then stay
        // in ID order even with concurrent writers, which `get_events_after`
        // relies on so that replay never skips an event.
        let id = self.next_event_id();
        self.push_and_trim(&mut events, StoredEvent::new(id.clone(), event_type, data));
        id
    }

    /// Store an event with a specific ID.
    ///
    /// Only a short-lived synchronous lock is taken; no future is blocked on.
    /// This makes the method safe to call from async request handlers.
    pub fn store(
        &self,
        id: impl Into<String>,
        event_type: impl Into<String>,
        data: impl Into<String>,
    ) {
        let id: String = id.into();
        let event_type: String = event_type.into();
        let data: String = data.into();
        let mut events = self.write_events();
        self.push_and_trim(&mut events, StoredEvent::new(id, event_type, data));
    }

    /// Store an event asynchronously.
    ///
    /// Behaves exactly like [`EventStore::store`].
    pub async fn store_async(
        &self,
        id: impl Into<String>,
        event_type: impl Into<String>,
        data: impl Into<String>,
    ) {
        self.store(id, event_type, data);
    }

    /// Get all events after the specified event ID.
    ///
    /// Used for replaying events when a client reconnects with `Last-Event-ID`.
    /// If `last_event_id` is not among the retained events, every retained
    /// event is returned.
    pub async fn get_events_after(&self, last_event_id: &str) -> Vec<StoredEvent> {
        let events = self.read_events();
        let start_idx = events
            .iter()
            .position(|e| e.id == last_event_id)
            .map_or(0, |i| i + 1);
        events.iter().skip(start_idx).cloned().collect()
    }

    /// Get all stored events, oldest first.
    pub async fn get_all_events(&self) -> Vec<StoredEvent> {
        self.read_events().iter().cloned().collect()
    }

    /// Get the number of stored events.
    pub async fn len(&self) -> usize {
        self.read_events().len()
    }

    /// Check if the store is empty.
    pub async fn is_empty(&self) -> bool {
        self.read_events().is_empty()
    }

    /// Clear all stored events.
    pub async fn clear(&self) {
        self.write_events().clear();
    }

    /// Clean up expired events.
    pub async fn cleanup_expired(&self) {
        let mut events = self.write_events();
        self.evict_expired(&mut events);
    }

    /// Acquire the read lock, ignoring poisoning.
    ///
    /// The worst case after a panicking writer is a not-yet-trimmed buffer,
    /// which the next write trims again.
    fn read_events(&self) -> std::sync::RwLockReadGuard<'_, VecDeque<StoredEvent>> {
        self.events
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquire the write lock, ignoring poisoning (see [`Self::read_events`]).
    fn write_events(&self) -> std::sync::RwLockWriteGuard<'_, VecDeque<StoredEvent>> {
        self.events
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Append `event`, then enforce both the count and the age limit.
    fn push_and_trim(&self, events: &mut VecDeque<StoredEvent>, event: StoredEvent) {
        events.push_back(event);
        while events.len() > self.config.max_events {
            events.pop_front();
        }
        self.evict_expired(events);
    }

    /// Drop events older than `max_age` from the front of the buffer.
    ///
    /// Events are created (and timestamped) while the write lock is held, so
    /// timestamps never decrease from front to back. The scan can therefore
    /// stop at the first event that is still young enough.
    fn evict_expired(&self, events: &mut VecDeque<StoredEvent>) {
        let now = Instant::now();
        while let Some(oldest) = events.front() {
            if now.duration_since(oldest.stored_at) <= self.config.max_age {
                break;
            }
            events.pop_front();
        }
    }
}
283
284/// Session manager for SSE connections.
285///
286/// Manages broadcast channels for pushing messages to SSE clients,
287/// with optional event storage for message resumability.
288#[derive(Debug)]
289pub struct SessionManager {
290    sessions: DashMap<String, broadcast::Sender<String>>,
291    /// Event stores for each session (for SSE resumability).
292    event_stores: DashMap<String, Arc<EventStore>>,
293    /// Configuration for event stores.
294    event_store_config: EventStoreConfig,
295    capacity: usize,
296}
297
298impl Default for SessionManager {
299    fn default() -> Self {
300        Self::new()
301    }
302}
303
304impl SessionManager {
305    /// Create a new session manager.
306    #[must_use]
307    pub fn new() -> Self {
308        Self::with_capacity(DEFAULT_SSE_CAPACITY)
309    }
310
311    /// Create a new session manager with custom channel capacity.
312    #[must_use]
313    pub fn with_capacity(capacity: usize) -> Self {
314        Self {
315            sessions: DashMap::new(),
316            event_stores: DashMap::new(),
317            event_store_config: EventStoreConfig::default(),
318            capacity,
319        }
320    }
321
322    /// Create a new session manager with custom event store configuration.
323    #[must_use]
324    pub fn with_event_store_config(config: EventStoreConfig) -> Self {
325        Self {
326            sessions: DashMap::new(),
327            event_stores: DashMap::new(),
328            event_store_config: config,
329            capacity: DEFAULT_SSE_CAPACITY,
330        }
331    }
332
333    /// Create a new session and return its ID and receiver.
334    #[must_use]
335    pub fn create_session(&self) -> (String, broadcast::Receiver<String>) {
336        let id = uuid::Uuid::new_v4().to_string();
337        let (tx, rx) = broadcast::channel(self.capacity);
338        self.sessions.insert(id.clone(), tx);
339
340        // Create an event store for this session
341        let event_store = Arc::new(EventStore::new(self.event_store_config.clone()));
342        self.event_stores.insert(id.clone(), event_store);
343
344        (id, rx)
345    }
346
347    /// Get a receiver for an existing session.
348    #[must_use]
349    pub fn get_receiver(&self, id: &str) -> Option<broadcast::Receiver<String>> {
350        self.sessions.get(id).map(|tx| tx.subscribe())
351    }
352
353    /// Get the event store for a session.
354    #[must_use]
355    pub fn get_event_store(&self, id: &str) -> Option<Arc<EventStore>> {
356        self.event_stores.get(id).map(|store| Arc::clone(&store))
357    }
358
359    /// Send a message to a specific session.
360    ///
361    /// Returns `true` if the message was sent, `false` if the session doesn't exist.
362    #[must_use]
363    pub fn send_to_session(&self, id: &str, message: String) -> bool {
364        if let Some(tx) = self.sessions.get(id) {
365            let _ = tx.send(message);
366            true
367        } else {
368            false
369        }
370    }
371
372    /// Send a message to a specific session and store it for replay.
373    ///
374    /// Returns the event ID if the message was sent and stored, `None` if the session doesn't exist.
375    #[must_use]
376    pub fn send_to_session_with_storage(
377        &self,
378        session_id: &str,
379        event_type: impl Into<String>,
380        message: String,
381    ) -> Option<String> {
382        if let Some(tx) = self.sessions.get(session_id) {
383            let event_id = if let Some(store) = self.event_stores.get(session_id) {
384                store.store_auto_id(event_type, message.clone())
385            } else {
386                let store = Arc::new(EventStore::new(self.event_store_config.clone()));
387                let event_id = store.store_auto_id(event_type, message.clone());
388                self.event_stores.insert(session_id.to_string(), store);
389                event_id
390            };
391
392            let _ = tx.send(message);
393            Some(event_id)
394        } else {
395            None
396        }
397    }
398
399    /// Broadcast a message to all sessions.
400    pub fn broadcast(&self, message: String) {
401        for entry in &self.sessions {
402            let _ = entry.value().send(message.clone());
403        }
404    }
405
406    /// Broadcast a message to all sessions with storage.
407    pub fn broadcast_with_storage(&self, event_type: impl Into<String> + Clone, message: String) {
408        for entry in &self.sessions {
409            let session_id = entry.key();
410
411            if let Some(store) = self.event_stores.get(session_id) {
412                store.store_auto_id(event_type.clone(), message.clone());
413            }
414
415            let _ = entry.value().send(message.clone());
416        }
417    }
418
419    /// Remove a session.
420    pub fn remove_session(&self, id: &str) {
421        self.sessions.remove(id);
422        self.event_stores.remove(id);
423    }
424
425    /// Get the number of active sessions.
426    #[must_use]
427    pub fn session_count(&self) -> usize {
428        self.sessions.len()
429    }
430
431    /// Clean up expired events across all sessions.
432    pub async fn cleanup_expired_events(&self) {
433        for entry in &self.event_stores {
434            entry.value().cleanup_expired().await;
435        }
436    }
437
438    /// Get events after the specified event ID for replay.
439    pub async fn get_events_for_replay(
440        &self,
441        session_id: &str,
442        last_event_id: &str,
443    ) -> Option<Vec<StoredEvent>> {
444        if let Some(store) = self.event_stores.get(session_id) {
445            Some(store.get_events_after(last_event_id).await)
446        } else {
447            None
448        }
449    }
450}
451
452/// Thread-safe session store with automatic cleanup.
453///
454/// Stores session metadata for HTTP request handling.
455#[derive(Debug)]
456pub struct SessionStore {
457    sessions: DashMap<String, Session>,
458    timeout: Duration,
459}
460
461impl SessionStore {
462    /// Create a new session store with the given timeout.
463    #[must_use]
464    pub fn new(timeout: Duration) -> Self {
465        Self {
466            sessions: DashMap::new(),
467            timeout,
468        }
469    }
470
471    /// Create a new session store with a default 1-hour timeout.
472    #[must_use]
473    pub fn with_default_timeout() -> Self {
474        Self::new(DEFAULT_SESSION_TIMEOUT)
475    }
476
477    /// Create a new session and return its ID.
478    #[must_use]
479    pub fn create(&self) -> String {
480        let id = uuid::Uuid::new_v4().to_string();
481        self.sessions.insert(id.clone(), Session::new(id.clone()));
482        id
483    }
484
485    /// Get a session by ID.
486    #[must_use]
487    pub fn get(&self, id: &str) -> Option<Session> {
488        self.sessions.get(id).map(|r| r.clone())
489    }
490
491    /// Touch a session to update its last active time.
492    pub fn touch(&self, id: &str) {
493        if let Some(mut session) = self.sessions.get_mut(id) {
494            session.touch();
495        }
496    }
497
498    /// Update a session.
499    pub fn update<F>(&self, id: &str, f: F)
500    where
501        F: FnOnce(&mut Session),
502    {
503        if let Some(mut session) = self.sessions.get_mut(id) {
504            f(&mut session);
505        }
506    }
507
508    /// Remove expired sessions.
509    pub fn cleanup_expired(&self) {
510        let timeout = self.timeout;
511        self.sessions.retain(|_, s| !s.is_expired(timeout));
512    }
513
514    /// Remove a session.
515    #[must_use]
516    pub fn remove(&self, id: &str) -> Option<Session> {
517        self.sessions.remove(id).map(|(_, s)| s)
518    }
519
520    /// Get the number of active sessions.
521    #[must_use]
522    pub fn session_count(&self) -> usize {
523        self.sessions.len()
524    }
525
526    /// Start a background task to periodically clean up expired sessions.
527    pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
528        let store = Arc::clone(self);
529        tokio::spawn(async move {
530            loop {
531                tokio::time::sleep(interval).await;
532                store.cleanup_expired();
533            }
534        });
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    #[test]
    fn test_session_creation() {
        let Session {
            id,
            initialized,
            client_capabilities,
            ..
        } = Session::new("test-123".to_owned());
        assert_eq!(id, "test-123");
        assert!(!initialized);
        assert!(client_capabilities.is_none());
    }

    #[test]
    fn test_session_expiry() -> TestResult {
        let one_minute = Duration::from_secs(60);
        let mut session = Session::new("test".to_owned());
        assert!(!session.is_expired(one_minute));

        let two_minutes_ago = Instant::now()
            .checked_sub(2 * one_minute)
            .ok_or("Failed to subtract duration")?;
        session.last_active = two_minutes_ago;
        assert!(session.is_expired(one_minute));
        Ok(())
    }

    #[test]
    fn test_session_store() {
        let store = SessionStore::new(Duration::from_secs(60));
        let session_id = store.create();

        assert!(store.get(&session_id).is_some());
        store.touch(&session_id);

        let _ = store.remove(&session_id);
        assert!(store.get(&session_id).is_none());
    }

    #[tokio::test]
    async fn test_session_manager() {
        let manager = SessionManager::new();
        let (session_id, mut receiver) = manager.create_session();

        assert!(manager.send_to_session(&session_id, String::from("test message")));

        let received = receiver.recv().await.expect("Should receive message");
        assert_eq!(received, "test message");

        manager.remove_session(&session_id);
        let delivered = manager.send_to_session(&session_id, String::from("another"));
        assert!(!delivered);
    }

    #[tokio::test]
    async fn test_event_store_creation() {
        let store = EventStore::with_defaults();
        assert!(store.is_empty().await);
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_event_store_store_and_retrieve() {
        let store = EventStore::with_defaults();
        for (event_id, payload) in [("evt-1", "data1"), ("evt-2", "data2"), ("evt-3", "data3")] {
            store.store_async(event_id, "message", payload).await;
        }

        assert_eq!(store.len().await, 3);

        let ids: Vec<String> = store
            .get_all_events()
            .await
            .into_iter()
            .map(|event| event.id)
            .collect();
        assert_eq!(ids, ["evt-1", "evt-2", "evt-3"]);
    }

    #[tokio::test]
    async fn test_event_store_get_events_after() {
        let store = EventStore::with_defaults();
        for (event_id, payload) in [("evt-1", "data1"), ("evt-2", "data2"), ("evt-3", "data3")] {
            store.store_async(event_id, "message", payload).await;
        }

        let ids_of = |events: Vec<StoredEvent>| {
            events.into_iter().map(|event| event.id).collect::<Vec<_>>()
        };

        assert_eq!(ids_of(store.get_events_after("evt-1").await), ["evt-2", "evt-3"]);
        assert!(store.get_events_after("evt-3").await.is_empty());
        assert_eq!(store.get_events_after("unknown").await.len(), 3);
    }

    #[tokio::test]
    async fn test_session_manager_with_event_store() -> TestResult {
        let manager = SessionManager::new();
        let (session_id, _receiver) = manager.create_session();

        let maybe_store = manager.get_event_store(&session_id);
        assert!(maybe_store.is_some());

        let store = maybe_store.ok_or("Event store not found")?;
        assert!(store.is_empty().await);
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_send_with_storage() -> TestResult {
        let manager = SessionManager::new();
        let (session_id, mut receiver) = manager.create_session();

        let event_id = manager.send_to_session_with_storage(
            &session_id,
            "message",
            "test data".to_owned(),
        );
        assert!(event_id.is_some());

        assert_eq!(receiver.recv().await?, "test data");

        let store = manager
            .get_event_store(&session_id)
            .ok_or("Event store not found")?;
        assert_eq!(store.len().await, 1);

        let stored = store.get_all_events().await;
        let first = &stored[0];
        assert_eq!(first.data, "test data");
        assert_eq!(first.event_type, "message");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_replay() -> TestResult {
        let manager = SessionManager::new();
        let (session_id, _receiver) = manager.create_session();

        let mut event_ids = Vec::new();
        for payload in ["msg1", "msg2", "msg3"] {
            event_ids.push(manager.send_to_session_with_storage(
                &session_id,
                "message",
                payload.to_string(),
            ));
        }

        let anchor = event_ids.swap_remove(1).ok_or("Failed to get event ID")?;
        let replayed = manager
            .get_events_for_replay(&session_id, &anchor)
            .await
            .ok_or("Failed to get events for replay")?;

        assert_eq!(replayed.len(), 1);
        assert_eq!(replayed[0].data, "msg3");
        Ok(())
    }

    #[test]
    fn test_event_store_config() {
        let defaults = EventStoreConfig::default();
        assert_eq!(
            (defaults.max_events, defaults.max_age),
            (1000, Duration::from_secs(300))
        );

        let tuned = EventStoreConfig::new(500, Duration::from_secs(120))
            .with_max_events(600)
            .with_max_age(Duration::from_secs(180));
        assert_eq!(
            (tuned.max_events, tuned.max_age),
            (600, Duration::from_secs(180))
        );
    }

    #[test]
    fn test_stored_event() {
        let StoredEvent {
            id,
            event_type,
            data,
            ..
        } = StoredEvent::new("evt-123".to_string(), "message", "test data");
        assert_eq!(id, "evt-123");
        assert_eq!(event_type, "message");
        assert_eq!(data, "test data");
    }
}