mcpkit_axum/
session.rs

//! Session management for MCP HTTP connections.
2
3use dashmap::DashMap;
4use mcpkit_core::capability::ClientCapabilities;
5use std::collections::VecDeque;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::sync::broadcast;
11
12/// A single MCP session.
13#[derive(Debug, Clone)]
14pub struct Session {
15    /// Unique session identifier.
16    pub id: String,
17    /// When the session was created.
18    pub created_at: Instant,
19    /// When the session was last active.
20    pub last_active: Instant,
21    /// Whether the session has been initialized.
22    pub initialized: bool,
23    /// Client capabilities from initialization.
24    pub client_capabilities: Option<ClientCapabilities>,
25}
26
27impl Session {
28    /// Create a new session.
29    #[must_use]
30    pub fn new(id: String) -> Self {
31        let now = Instant::now();
32        Self {
33            id,
34            created_at: now,
35            last_active: now,
36            initialized: false,
37            client_capabilities: None,
38        }
39    }
40
41    /// Check if the session has expired.
42    #[must_use]
43    pub fn is_expired(&self, timeout: Duration) -> bool {
44        self.last_active.elapsed() >= timeout
45    }
46
47    /// Mark the session as active.
48    pub fn touch(&mut self) {
49        self.last_active = Instant::now();
50    }
51
52    /// Mark the session as initialized.
53    pub fn mark_initialized(&mut self, capabilities: Option<ClientCapabilities>) {
54        self.initialized = true;
55        self.client_capabilities = capabilities;
56    }
57}
58
/// An SSE event kept in memory so it can be replayed after a reconnect.
#[derive(Clone, Debug)]
pub struct StoredEvent {
    /// Identifier of the event; a reconnecting client names the last one it saw
    /// via `Last-Event-ID`.
    pub id: String,
    /// SSE event name, e.g. `"message"` or `"connected"`.
    pub event_type: String,
    /// Payload carried in the event's data field.
    pub data: String,
    /// Moment the event was created; compared against the retention `max_age`.
    pub stored_at: Instant,
}

impl StoredEvent {
    /// Build an event from its parts, timestamped with the current instant.
    #[must_use]
    pub fn new(id: String, event_type: impl Into<String>, data: impl Into<String>) -> Self {
        let event_type = event_type.into();
        let data = data.into();
        Self {
            id,
            event_type,
            data,
            stored_at: Instant::now(),
        }
    }
}
84
/// Retention limits for an event store.
///
/// Both limits apply to each store independently; the oldest events are the
/// first to go once either limit is exceeded.
#[derive(Clone, Debug)]
pub struct EventStoreConfig {
    /// Most events a single store keeps.
    pub max_events: usize,
    /// Oldest an event may become before it is evicted.
    pub max_age: Duration,
}

impl Default for EventStoreConfig {
    /// Keep up to 1000 events for at most five minutes.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        Self::new(1000, FIVE_MINUTES)
    }
}

impl EventStoreConfig {
    /// Build a configuration from explicit limits.
    #[must_use]
    pub const fn new(max_events: usize, max_age: Duration) -> Self {
        Self { max_events, max_age }
    }

    /// Builder-style setter: replace the event-count limit, keep the age limit.
    #[must_use]
    pub const fn with_max_events(self, max_events: usize) -> Self {
        Self::new(max_events, self.max_age)
    }

    /// Builder-style setter: replace the age limit, keep the event-count limit.
    #[must_use]
    pub const fn with_max_age(self, max_age: Duration) -> Self {
        Self::new(self.max_events, max_age)
    }
}
127
128/// Event store for SSE message resumability.
129///
130/// Per the MCP Streamable HTTP specification, servers MAY store events
131/// with IDs to support client reconnection with `Last-Event-ID`.
132///
133/// # Example
134///
135/// ```rust
136/// use mcpkit_axum::{EventStore, EventStoreConfig};
137/// use std::time::Duration;
138///
139/// let config = EventStoreConfig::new(500, Duration::from_secs(120));
140/// let store = EventStore::new(config);
141///
142/// // Store an event
143/// store.store("evt-001", "message", r#"{"jsonrpc":"2.0",...}"#);
144///
145/// // Get events after a specific ID for replay (async)
146/// // let events = store.get_events_after("evt-000").await;
147/// ```
148#[derive(Debug)]
149pub struct EventStore {
150    events: RwLock<VecDeque<StoredEvent>>,
151    config: EventStoreConfig,
152    next_id: AtomicU64,
153}
154
155impl EventStore {
156    /// Create a new event store with the given configuration.
157    #[must_use]
158    pub fn new(config: EventStoreConfig) -> Self {
159        Self {
160            events: RwLock::new(VecDeque::with_capacity(config.max_events)),
161            config,
162            next_id: AtomicU64::new(1),
163        }
164    }
165
166    /// Create a new event store with default configuration.
167    #[must_use]
168    pub fn with_defaults() -> Self {
169        Self::new(EventStoreConfig::default())
170    }
171
172    /// Generate the next event ID.
173    #[must_use]
174    pub fn next_event_id(&self) -> String {
175        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
176        format!("evt-{id}")
177    }
178
179    /// Store an event with automatic ID generation.
180    ///
181    /// Returns the generated event ID.
182    pub fn store_auto_id(&self, event_type: impl Into<String>, data: impl Into<String>) -> String {
183        let id = self.next_event_id();
184        self.store(id.clone(), event_type, data);
185        id
186    }
187
188    /// Store an event with a specific ID.
189    pub fn store(
190        &self,
191        id: impl Into<String>,
192        event_type: impl Into<String>,
193        data: impl Into<String>,
194    ) {
195        let event = StoredEvent::new(id.into(), event_type, data);
196
197        // Use blocking write since we can't use async in this sync method
198        // In production, consider using parking_lot::RwLock for better sync performance
199        let mut events = futures::executor::block_on(self.events.write());
200
201        // Add the new event
202        events.push_back(event);
203
204        // Enforce max_events limit
205        while events.len() > self.config.max_events {
206            events.pop_front();
207        }
208
209        // Enforce max_age limit
210        let now = Instant::now();
211        while let Some(front) = events.front() {
212            if now.duration_since(front.stored_at) > self.config.max_age {
213                events.pop_front();
214            } else {
215                break;
216            }
217        }
218    }
219
220    /// Store an event asynchronously.
221    pub async fn store_async(
222        &self,
223        id: impl Into<String>,
224        event_type: impl Into<String>,
225        data: impl Into<String>,
226    ) {
227        let event = StoredEvent::new(id.into(), event_type, data);
228        let mut events = self.events.write().await;
229
230        events.push_back(event);
231
232        // Enforce limits
233        while events.len() > self.config.max_events {
234            events.pop_front();
235        }
236
237        let now = Instant::now();
238        while let Some(front) = events.front() {
239            if now.duration_since(front.stored_at) > self.config.max_age {
240                events.pop_front();
241            } else {
242                break;
243            }
244        }
245    }
246
247    /// Get all events after the specified event ID.
248    ///
249    /// Used for replaying events when a client reconnects with `Last-Event-ID`.
250    /// Returns events in chronological order.
251    pub async fn get_events_after(&self, last_event_id: &str) -> Vec<StoredEvent> {
252        let events = self.events.read().await;
253
254        // Find the index of the last event ID
255        // Start from the next event after last_event_id, or 0 if not found
256        let start_idx = events
257            .iter()
258            .position(|e| e.id == last_event_id)
259            .map_or(0, |i| i + 1);
260
261        events.iter().skip(start_idx).cloned().collect()
262    }
263
264    /// Get all stored events.
265    pub async fn get_all_events(&self) -> Vec<StoredEvent> {
266        let events = self.events.read().await;
267        events.iter().cloned().collect()
268    }
269
270    /// Get the number of stored events.
271    pub async fn len(&self) -> usize {
272        self.events.read().await.len()
273    }
274
275    /// Check if the store is empty.
276    pub async fn is_empty(&self) -> bool {
277        self.events.read().await.is_empty()
278    }
279
280    /// Clear all stored events.
281    pub async fn clear(&self) {
282        self.events.write().await.clear();
283    }
284
285    /// Clean up expired events.
286    pub async fn cleanup_expired(&self) {
287        let mut events = self.events.write().await;
288        let now = Instant::now();
289        while let Some(front) = events.front() {
290            if now.duration_since(front.stored_at) > self.config.max_age {
291                events.pop_front();
292            } else {
293                break;
294            }
295        }
296    }
297}
298
299/// Session manager for SSE connections.
300///
301/// Manages broadcast channels for pushing messages to SSE clients,
302/// with optional event storage for message resumability.
303#[derive(Debug)]
304pub struct SessionManager {
305    sessions: DashMap<String, broadcast::Sender<String>>,
306    /// Event stores for each session (for SSE resumability).
307    event_stores: DashMap<String, Arc<EventStore>>,
308    /// Configuration for event stores.
309    event_store_config: EventStoreConfig,
310}
311
312impl Default for SessionManager {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318impl SessionManager {
319    /// Create a new session manager.
320    #[must_use]
321    pub fn new() -> Self {
322        Self {
323            sessions: DashMap::new(),
324            event_stores: DashMap::new(),
325            event_store_config: EventStoreConfig::default(),
326        }
327    }
328
329    /// Create a new session manager with custom event store configuration.
330    #[must_use]
331    pub fn with_event_store_config(config: EventStoreConfig) -> Self {
332        Self {
333            sessions: DashMap::new(),
334            event_stores: DashMap::new(),
335            event_store_config: config,
336        }
337    }
338
339    /// Create a new session and return its ID and receiver.
340    #[must_use]
341    pub fn create_session(&self) -> (String, broadcast::Receiver<String>) {
342        let id = uuid::Uuid::new_v4().to_string();
343        let (tx, rx) = broadcast::channel(100);
344        self.sessions.insert(id.clone(), tx);
345
346        // Create an event store for this session
347        let event_store = Arc::new(EventStore::new(self.event_store_config.clone()));
348        self.event_stores.insert(id.clone(), event_store);
349
350        (id, rx)
351    }
352
353    /// Get a receiver for an existing session.
354    #[must_use]
355    pub fn get_receiver(&self, id: &str) -> Option<broadcast::Receiver<String>> {
356        self.sessions.get(id).map(|tx| tx.subscribe())
357    }
358
359    /// Get the event store for a session.
360    #[must_use]
361    pub fn get_event_store(&self, id: &str) -> Option<Arc<EventStore>> {
362        self.event_stores.get(id).map(|store| Arc::clone(&store))
363    }
364
365    /// Send a message to a specific session.
366    ///
367    /// Returns `true` if the message was sent, `false` if the session doesn't exist.
368    #[must_use]
369    pub fn send_to_session(&self, id: &str, message: String) -> bool {
370        if let Some(tx) = self.sessions.get(id) {
371            // Ignore send errors (no receivers)
372            let _ = tx.send(message);
373            true
374        } else {
375            false
376        }
377    }
378
379    /// Send a message to a specific session and store it for replay.
380    ///
381    /// This method stores the event in the event store before sending,
382    /// enabling message resumability for clients that reconnect.
383    ///
384    /// Returns the event ID if the message was sent and stored, `None` if the session doesn't exist.
385    #[must_use]
386    pub fn send_to_session_with_storage(
387        &self,
388        session_id: &str,
389        event_type: impl Into<String>,
390        message: String,
391    ) -> Option<String> {
392        if let Some(tx) = self.sessions.get(session_id) {
393            // Store the event first
394            let event_id = if let Some(store) = self.event_stores.get(session_id) {
395                store.store_auto_id(event_type, message.clone())
396            } else {
397                // Create a store if it doesn't exist (shouldn't happen normally)
398                let store = Arc::new(EventStore::new(self.event_store_config.clone()));
399                let event_id = store.store_auto_id(event_type, message.clone());
400                self.event_stores.insert(session_id.to_string(), store);
401                event_id
402            };
403
404            // Send the message
405            let _ = tx.send(message);
406            Some(event_id)
407        } else {
408            None
409        }
410    }
411
412    /// Broadcast a message to all sessions.
413    pub fn broadcast(&self, message: String) {
414        for entry in &self.sessions {
415            let _ = entry.value().send(message.clone());
416        }
417    }
418
419    /// Broadcast a message to all sessions with storage.
420    ///
421    /// Stores the event in each session's event store for resumability.
422    pub fn broadcast_with_storage(&self, event_type: impl Into<String> + Clone, message: String) {
423        for entry in &self.sessions {
424            let session_id = entry.key();
425
426            // Store in event store
427            if let Some(store) = self.event_stores.get(session_id) {
428                store.store_auto_id(event_type.clone(), message.clone());
429            }
430
431            // Send
432            let _ = entry.value().send(message.clone());
433        }
434    }
435
436    /// Remove a session.
437    pub fn remove_session(&self, id: &str) {
438        self.sessions.remove(id);
439        self.event_stores.remove(id);
440    }
441
442    /// Get the number of active sessions.
443    #[must_use]
444    pub fn session_count(&self) -> usize {
445        self.sessions.len()
446    }
447
448    /// Clean up expired events across all sessions.
449    pub async fn cleanup_expired_events(&self) {
450        for entry in &self.event_stores {
451            entry.value().cleanup_expired().await;
452        }
453    }
454
455    /// Get events after the specified event ID for replay.
456    ///
457    /// Used when a client reconnects with `Last-Event-ID`.
458    pub async fn get_events_for_replay(
459        &self,
460        session_id: &str,
461        last_event_id: &str,
462    ) -> Option<Vec<StoredEvent>> {
463        if let Some(store) = self.event_stores.get(session_id) {
464            Some(store.get_events_after(last_event_id).await)
465        } else {
466            None
467        }
468    }
469}
470
471/// Thread-safe session store with automatic cleanup.
472///
473/// Stores session metadata for HTTP request handling.
474#[derive(Debug)]
475pub struct SessionStore {
476    sessions: DashMap<String, Session>,
477    timeout: Duration,
478}
479
480impl SessionStore {
481    /// Create a new session store with the given timeout.
482    #[must_use]
483    pub fn new(timeout: Duration) -> Self {
484        Self {
485            sessions: DashMap::new(),
486            timeout,
487        }
488    }
489
490    /// Create a new session store with a default 1-hour timeout.
491    #[must_use]
492    pub fn with_default_timeout() -> Self {
493        Self::new(Duration::from_secs(3600))
494    }
495
496    /// Create a new session and return its ID.
497    #[must_use]
498    pub fn create(&self) -> String {
499        let id = uuid::Uuid::new_v4().to_string();
500        self.sessions.insert(id.clone(), Session::new(id.clone()));
501        id
502    }
503
504    /// Get a session by ID.
505    #[must_use]
506    pub fn get(&self, id: &str) -> Option<Session> {
507        self.sessions.get(id).map(|r| r.clone())
508    }
509
510    /// Touch a session to update its last active time.
511    pub fn touch(&self, id: &str) {
512        if let Some(mut session) = self.sessions.get_mut(id) {
513            session.touch();
514        }
515    }
516
517    /// Update a session.
518    pub fn update<F>(&self, id: &str, f: F)
519    where
520        F: FnOnce(&mut Session),
521    {
522        if let Some(mut session) = self.sessions.get_mut(id) {
523            f(&mut session);
524        }
525    }
526
527    /// Remove expired sessions.
528    pub fn cleanup_expired(&self) {
529        let timeout = self.timeout;
530        self.sessions.retain(|_, s| !s.is_expired(timeout));
531    }
532
533    /// Remove a session.
534    #[must_use]
535    pub fn remove(&self, id: &str) -> Option<Session> {
536        self.sessions.remove(id).map(|(_, s)| s)
537    }
538
539    /// Get the number of active sessions.
540    #[must_use]
541    pub fn session_count(&self) -> usize {
542        self.sessions.len()
543    }
544
545    /// Start a background task to periodically clean up expired sessions.
546    pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
547        let store = Arc::clone(self);
548        tokio::spawn(async move {
549            loop {
550                tokio::time::sleep(interval).await;
551                store.cleanup_expired();
552            }
553        });
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-123".to_string());
        assert_eq!(session.id, "test-123");
        assert!(!session.initialized);
        assert!(session.client_capabilities.is_none());
    }

    #[test]
    fn test_session_mark_initialized() {
        let mut session = Session::new("init".to_string());
        session.mark_initialized(None);
        assert!(session.initialized);
        assert!(session.client_capabilities.is_none());
    }

    #[test]
    fn test_session_expiry() -> Result<(), Box<dyn std::error::Error>> {
        let mut session = Session::new("test".to_string());
        assert!(!session.is_expired(Duration::from_secs(60)));

        // Simulate old session by setting last_active in the past
        session.last_active = Instant::now()
            .checked_sub(Duration::from_secs(120))
            .ok_or("Failed to subtract duration")?;
        assert!(session.is_expired(Duration::from_secs(60)));
        Ok(())
    }

    #[test]
    fn test_session_store() {
        let store = SessionStore::new(Duration::from_secs(60));
        let id = store.create();

        assert!(store.get(&id).is_some());
        store.touch(&id);

        let _ = store.remove(&id);
        assert!(store.get(&id).is_none());
    }

    #[test]
    fn test_session_store_update() -> Result<(), Box<dyn std::error::Error>> {
        let store = SessionStore::with_default_timeout();
        let id = store.create();

        store.update(&id, |s| s.mark_initialized(None));

        let session = store.get(&id).ok_or("Session not found")?;
        assert!(session.initialized);
        Ok(())
    }

    #[test]
    fn test_session_store_cleanup_expired() -> Result<(), Box<dyn std::error::Error>> {
        let store = SessionStore::new(Duration::from_secs(60));
        let stale = store.create();
        let fresh = store.create();

        let past = Instant::now()
            .checked_sub(Duration::from_secs(120))
            .ok_or("Failed to subtract duration")?;
        store.update(&stale, |s| s.last_active = past);

        store.cleanup_expired();

        assert!(store.get(&stale).is_none());
        assert!(store.get(&fresh).is_some());
        assert_eq!(store.session_count(), 1);
        Ok(())
    }

    #[test]
    fn test_session_store_unknown_id_is_noop() {
        let store = SessionStore::with_default_timeout();
        store.touch("missing");
        store.update("missing", |s| s.initialized = true);

        assert!(store.get("missing").is_none());
        assert!(store.remove("missing").is_none());
        assert_eq!(store.session_count(), 0);
    }

    #[tokio::test]
    async fn test_session_manager() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, mut rx) = manager.create_session();

        // Send a message
        assert!(manager.send_to_session(&id, "test message".to_string()));

        // Receive the message
        let msg = rx.recv().await?;
        assert_eq!(msg, "test message");

        // Remove session
        manager.remove_session(&id);
        assert!(!manager.send_to_session(&id, "another".to_string()));
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_broadcast() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (_a, mut rx_a) = manager.create_session();
        let (_b, mut rx_b) = manager.create_session();
        assert_eq!(manager.session_count(), 2);

        manager.broadcast("hello".to_string());

        assert_eq!(rx_a.recv().await?, "hello");
        assert_eq!(rx_b.recv().await?, "hello");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_broadcast_with_storage() -> Result<(), Box<dyn std::error::Error>>
    {
        let manager = SessionManager::new();
        let (a, mut rx_a) = manager.create_session();
        let (b, mut rx_b) = manager.create_session();

        manager.broadcast_with_storage("message", "hello".to_string());

        assert_eq!(rx_a.recv().await?, "hello");
        assert_eq!(rx_b.recv().await?, "hello");

        for id in [a.as_str(), b.as_str()] {
            let store = manager
                .get_event_store(id)
                .ok_or("Event store not found")?;
            let events = store.get_all_events().await;
            assert_eq!(events.len(), 1);
            assert_eq!(events[0].data, "hello");
            assert_eq!(events[0].event_type, "message");
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_remove_clears_state() {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();

        manager.remove_session(&id);

        assert_eq!(manager.session_count(), 0);
        assert!(manager.get_receiver(&id).is_none());
        assert!(manager.get_event_store(&id).is_none());
        assert!(manager
            .send_to_session_with_storage(&id, "message", "late".to_string())
            .is_none());
        assert!(manager.get_events_for_replay(&id, "evt-1").await.is_none());
    }

    #[tokio::test]
    async fn test_event_store_creation() {
        let store = EventStore::with_defaults();
        assert!(store.is_empty().await);
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_event_store_store_and_retrieve() {
        let store = EventStore::with_defaults();

        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;
        store.store_async("evt-3", "message", "data3").await;

        assert_eq!(store.len().await, 3);

        let all_events = store.get_all_events().await;
        assert_eq!(all_events.len(), 3);
        assert_eq!(all_events[0].id, "evt-1");
        assert_eq!(all_events[1].id, "evt-2");
        assert_eq!(all_events[2].id, "evt-3");
    }

    #[tokio::test]
    async fn test_event_store_get_events_after() {
        let store = EventStore::with_defaults();

        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;
        store.store_async("evt-3", "message", "data3").await;

        // Get events after evt-1
        let events = store.get_events_after("evt-1").await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].id, "evt-2");
        assert_eq!(events[1].id, "evt-3");

        // Get events after evt-2
        let events = store.get_events_after("evt-2").await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, "evt-3");

        // Get events after evt-3 (should be empty)
        let events = store.get_events_after("evt-3").await;
        assert_eq!(events.len(), 0);

        // Get events after unknown ID (should return all)
        let events = store.get_events_after("unknown").await;
        assert_eq!(events.len(), 3);
    }

    #[tokio::test]
    async fn test_event_store_auto_id() {
        let store = EventStore::with_defaults();

        let id1 = store.store_auto_id("message", "data1");
        let id2 = store.store_auto_id("message", "data2");

        assert!(id1.starts_with("evt-"));
        assert!(id2.starts_with("evt-"));
        assert_ne!(id1, id2);

        assert_eq!(store.len().await, 2);
    }

    #[tokio::test]
    async fn test_event_store_max_events_limit() {
        let config = EventStoreConfig::new(3, Duration::from_secs(300));
        let store = EventStore::new(config);

        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;
        store.store_async("evt-3", "message", "data3").await;
        store.store_async("evt-4", "message", "data4").await;

        // Should only have 3 events (oldest removed)
        assert_eq!(store.len().await, 3);

        let events = store.get_all_events().await;
        assert_eq!(events[0].id, "evt-2"); // evt-1 was evicted
        assert_eq!(events[1].id, "evt-3");
        assert_eq!(events[2].id, "evt-4");
    }

    #[tokio::test]
    async fn test_event_store_zero_max_events_retains_nothing() {
        let store = EventStore::new(EventStoreConfig::new(0, Duration::from_secs(300)));

        let id = store.store_auto_id("message", "data");

        assert!(id.starts_with("evt-"));
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_event_store_clear() {
        let store = EventStore::with_defaults();

        store.store_async("evt-1", "message", "data1").await;
        store.store_async("evt-2", "message", "data2").await;

        assert_eq!(store.len().await, 2);

        store.clear().await;

        assert!(store.is_empty().await);
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_session_manager_with_event_store() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();

        // Event store should be created automatically
        let store = manager.get_event_store(&id);
        assert!(store.is_some());

        let store = store.ok_or("Event store not found")?;
        assert!(store.is_empty().await);
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_custom_event_store_config(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::with_event_store_config(EventStoreConfig::new(
            2,
            Duration::from_secs(300),
        ));
        let (id, _rx) = manager.create_session();

        for msg in ["a", "b", "c"] {
            let _ = manager.send_to_session_with_storage(&id, "message", msg.to_string());
        }

        let store = manager
            .get_event_store(&id)
            .ok_or("Event store not found")?;
        let events = store.get_all_events().await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, "b");
        assert_eq!(events[1].data, "c");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_send_with_storage() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, mut rx) = manager.create_session();

        // Send with storage
        let event_id =
            manager.send_to_session_with_storage(&id, "message", "test data".to_string());
        assert!(event_id.is_some());

        // Verify message was received
        let msg = rx.recv().await?;
        assert_eq!(msg, "test data");

        // Verify event was stored
        let store = manager
            .get_event_store(&id)
            .ok_or("Event store not found")?;
        assert_eq!(store.len().await, 1);

        let events = store.get_all_events().await;
        assert_eq!(events[0].data, "test data");
        assert_eq!(events[0].event_type, "message");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_cleanup_keeps_fresh_events(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();
        let _ = manager.send_to_session_with_storage(&id, "message", "fresh".to_string());

        manager.cleanup_expired_events().await;

        let store = manager
            .get_event_store(&id)
            .ok_or("Event store not found")?;
        assert_eq!(store.len().await, 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_session_manager_replay() -> Result<(), Box<dyn std::error::Error>> {
        let manager = SessionManager::new();
        let (id, _rx) = manager.create_session();

        // Send multiple messages with storage
        let _ = manager.send_to_session_with_storage(&id, "message", "msg1".to_string());
        let evt2 = manager.send_to_session_with_storage(&id, "message", "msg2".to_string());
        let _ = manager.send_to_session_with_storage(&id, "message", "msg3".to_string());

        // Simulate reconnection - get events after evt2
        let events = manager
            .get_events_for_replay(&id, &evt2.ok_or("Failed to get event ID")?)
            .await
            .ok_or("Failed to get events for replay")?;

        // Should only get msg3
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "msg3");
        Ok(())
    }

    #[test]
    fn test_event_store_config() {
        let config = EventStoreConfig::default();
        assert_eq!(config.max_events, 1000);
        assert_eq!(config.max_age, Duration::from_secs(300));

        let config = EventStoreConfig::new(500, Duration::from_secs(120))
            .with_max_events(600)
            .with_max_age(Duration::from_secs(180));

        assert_eq!(config.max_events, 600);
        assert_eq!(config.max_age, Duration::from_secs(180));
    }

    #[test]
    fn test_stored_event() {
        let event = StoredEvent::new("evt-123".to_string(), "message", "test data");
        assert_eq!(event.id, "evt-123");
        assert_eq!(event.event_type, "message");
        assert_eq!(event.data, "test data");
    }
}