mcp_protocol_sdk/client/
session.rs

1//! Client session management
2//!
3//! This module provides session management for MCP clients, including connection
4//! state tracking, notification handling, and automatic reconnection capabilities.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{broadcast, mpsc, watch, Mutex, RwLock};
10use tokio::time::{sleep, timeout};
11
12use crate::client::mcp_client::{ClientConfig, McpClient};
13use crate::core::error::{McpError, McpResult};
14use crate::protocol::{messages::*, types::*};
15use crate::transport::traits::Transport;
16
/// Lifecycle state of a client session.
///
/// `Eq` is derived alongside `PartialEq`: every variant payload (`String`) has
/// total equality, so the state can be used wherever an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Session is disconnected
    Disconnected,
    /// Session is connecting
    Connecting,
    /// Session is connected and active
    Connected,
    /// Session is reconnecting after a failure
    Reconnecting,
    /// Session has failed and cannot reconnect; carries the error description
    Failed(String),
}
31
32/// Notification handler trait
33pub trait NotificationHandler: Send + Sync {
34    /// Handle a notification from the server
35    fn handle_notification(&self, notification: JsonRpcNotification);
36}
37
/// Tunable parameters for a client session.
///
/// All durations are expressed in milliseconds.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Enables automatic reconnection when `true`
    pub auto_reconnect: bool,
    /// Upper bound on the number of reconnection attempts
    pub max_reconnect_attempts: u32,
    /// Delay before the first reconnection attempt (ms)
    pub reconnect_delay_ms: u64,
    /// Cap applied to the reconnection delay (ms)
    pub max_reconnect_delay_ms: u64,
    /// Factor by which the reconnection delay grows per attempt
    pub reconnect_backoff: f64,
    /// Time allowed for a connection attempt to complete (ms)
    pub connection_timeout_ms: u64,
    /// Time between heartbeats (ms); `0` disables the heartbeat
    pub heartbeat_interval_ms: u64,
    /// Time allowed for a single heartbeat to complete (ms)
    pub heartbeat_timeout_ms: u64,
}

impl Default for SessionConfig {
    /// Defaults: reconnection enabled with 5 attempts, 1 s initial delay doubling
    /// up to 30 s, a 10 s connection timeout, and a 30 s heartbeat with a 5 s timeout.
    fn default() -> Self {
        SessionConfig {
            // Reconnection policy
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_delay_ms: 1_000,
            max_reconnect_delay_ms: 30_000,
            reconnect_backoff: 2.0,
            // Timeouts and heartbeat
            connection_timeout_ms: 10_000,
            heartbeat_interval_ms: 30_000,
            heartbeat_timeout_ms: 5_000,
        }
    }
}
73
74/// Client session that manages connection lifecycle and notifications
75pub struct ClientSession {
76    /// The underlying MCP client
77    client: Arc<Mutex<McpClient>>,
78    /// Session configuration
79    config: SessionConfig,
80    /// Current session state
81    state: Arc<RwLock<SessionState>>,
82    /// State change broadcaster
83    state_tx: watch::Sender<SessionState>,
84    /// State change receiver
85    state_rx: watch::Receiver<SessionState>,
86    /// Notification handlers
87    notification_handlers: Arc<RwLock<Vec<Box<dyn NotificationHandler>>>>,
88    /// Connection timestamp
89    connected_at: Arc<RwLock<Option<Instant>>>,
90    /// Reconnection attempts counter
91    reconnect_attempts: Arc<Mutex<u32>>,
92    /// Shutdown signal
93    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
94}
95
96impl ClientSession {
97    /// Create a new client session
98    pub fn new(client: McpClient) -> Self {
99        let (state_tx, state_rx) = watch::channel(SessionState::Disconnected);
100
101        Self {
102            client: Arc::new(Mutex::new(client)),
103            config: SessionConfig::default(),
104            state: Arc::new(RwLock::new(SessionState::Disconnected)),
105            state_tx,
106            state_rx,
107            notification_handlers: Arc::new(RwLock::new(Vec::new())),
108            connected_at: Arc::new(RwLock::new(None)),
109            reconnect_attempts: Arc::new(Mutex::new(0)),
110            shutdown_tx: Arc::new(Mutex::new(None)),
111        }
112    }
113
114    /// Create a new client session with custom configuration
115    pub fn with_config(client: McpClient, config: SessionConfig) -> Self {
116        let mut session = Self::new(client);
117        session.config = config;
118        session
119    }
120
121    /// Get the current session state
122    pub async fn state(&self) -> SessionState {
123        let state = self.state.read().await;
124        state.clone()
125    }
126
127    /// Subscribe to state changes
128    pub fn subscribe_state_changes(&self) -> watch::Receiver<SessionState> {
129        self.state_rx.clone()
130    }
131
132    /// Check if the session is connected
133    pub async fn is_connected(&self) -> bool {
134        let state = self.state.read().await;
135        matches!(*state, SessionState::Connected)
136    }
137
138    /// Get connection uptime
139    pub async fn uptime(&self) -> Option<Duration> {
140        let connected_at = self.connected_at.read().await;
141        connected_at.map(|time| time.elapsed())
142    }
143
144    /// Add a notification handler
145    pub async fn add_notification_handler<H>(&self, handler: H)
146    where
147        H: NotificationHandler + 'static,
148    {
149        let mut handlers = self.notification_handlers.write().await;
150        handlers.push(Box::new(handler));
151    }
152
153    /// Connect to the server with the provided transport
154    pub async fn connect<T>(&self, transport: T) -> McpResult<InitializeResult>
155    where
156        T: Transport + 'static,
157    {
158        self.transition_state(SessionState::Connecting).await?;
159
160        let connect_future = async {
161            let mut client = self.client.lock().await;
162            client.connect(transport).await
163        };
164
165        let result = timeout(
166            Duration::from_millis(self.config.connection_timeout_ms),
167            connect_future,
168        )
169        .await;
170
171        match result {
172            Ok(Ok(init_result)) => {
173                self.transition_state(SessionState::Connected).await?;
174
175                // Record connection time
176                {
177                    let mut connected_at = self.connected_at.write().await;
178                    *connected_at = Some(Instant::now());
179                }
180
181                // Reset reconnection attempts
182                {
183                    let mut attempts = self.reconnect_attempts.lock().await;
184                    *attempts = 0;
185                }
186
187                // Start background tasks
188                self.start_background_tasks().await?;
189
190                Ok(init_result)
191            }
192            Ok(Err(error)) => {
193                self.transition_state(SessionState::Failed(error.to_string()))
194                    .await?;
195                Err(error)
196            }
197            Err(_) => {
198                let error = McpError::Connection("Connection timeout".to_string());
199                self.transition_state(SessionState::Failed(error.to_string()))
200                    .await?;
201                Err(error)
202            }
203        }
204    }
205
206    /// Disconnect from the server
207    pub async fn disconnect(&self) -> McpResult<()> {
208        // Stop background tasks
209        self.stop_background_tasks().await;
210
211        // Disconnect the client
212        {
213            let client = self.client.lock().await;
214            client.disconnect().await?;
215        }
216
217        // Update state
218        self.transition_state(SessionState::Disconnected).await?;
219
220        // Clear connection time
221        {
222            let mut connected_at = self.connected_at.write().await;
223            *connected_at = None;
224        }
225
226        Ok(())
227    }
228
229    /// Reconnect to the server
230    pub async fn reconnect<T>(
231        &self,
232        transport_factory: impl Fn() -> T,
233    ) -> McpResult<InitializeResult>
234    where
235        T: Transport + 'static,
236    {
237        if !self.config.auto_reconnect {
238            return Err(McpError::Connection(
239                "Auto-reconnect is disabled".to_string(),
240            ));
241        }
242
243        let mut attempts = self.reconnect_attempts.lock().await;
244        if *attempts >= self.config.max_reconnect_attempts {
245            let error = McpError::Connection("Max reconnection attempts exceeded".to_string());
246            self.transition_state(SessionState::Failed(error.to_string()))
247                .await?;
248            return Err(error);
249        }
250
251        *attempts += 1;
252        let current_attempts = *attempts;
253        drop(attempts);
254
255        self.transition_state(SessionState::Reconnecting).await?;
256
257        // Calculate reconnection delay with exponential backoff
258        let delay = std::cmp::min(
259            (self.config.reconnect_delay_ms as f64
260                * self
261                    .config
262                    .reconnect_backoff
263                    .powi(current_attempts as i32 - 1)) as u64,
264            self.config.max_reconnect_delay_ms,
265        );
266
267        sleep(Duration::from_millis(delay)).await;
268
269        // Attempt to reconnect
270        self.connect(transport_factory()).await
271    }
272
273    /// Get the underlying client (for direct operations)
274    pub fn client(&self) -> Arc<Mutex<McpClient>> {
275        self.client.clone()
276    }
277
278    /// Get session configuration
279    pub fn config(&self) -> &SessionConfig {
280        &self.config
281    }
282
283    // ========================================================================
284    // Background Tasks
285    // ========================================================================
286
287    /// Start background tasks (notification handling, heartbeat)
288    async fn start_background_tasks(&self) -> McpResult<()> {
289        let (shutdown_tx, shutdown_rx): (broadcast::Sender<()>, broadcast::Receiver<()>) =
290            broadcast::channel(16);
291        {
292            let mut shutdown_guard = self.shutdown_tx.lock().await;
293            *shutdown_guard = Some(mpsc::channel(1).0); // Store a dummy for interface compatibility
294        }
295
296        // Start notification handler task
297        {
298            let client = self.client.clone();
299            let handlers = self.notification_handlers.clone();
300            let mut shutdown_rx_clone = shutdown_rx.resubscribe();
301
302            tokio::spawn(async move {
303                loop {
304                    tokio::select! {
305                        _ = shutdown_rx_clone.recv() => break,
306                        notification_result = async {
307                            let client_guard = client.lock().await;
308                            client_guard.receive_notification().await
309                        } => {
310                            match notification_result {
311                                Ok(Some(notification)) => {
312                                    let handlers_guard = handlers.read().await;
313                                    for handler in handlers_guard.iter() {
314                                        handler.handle_notification(notification.clone());
315                                    }
316                                }
317                                Ok(None) => {
318                                    // No notification available, continue
319                                }
320                                Err(_) => {
321                                    // Error receiving notification, might be disconnected
322                                    break;
323                                }
324                            }
325                        }
326                    }
327                }
328            });
329        }
330
331        // Start heartbeat task if enabled
332        if self.config.heartbeat_interval_ms > 0 {
333            let client = self.client.clone();
334            let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
335            let heartbeat_timeout = Duration::from_millis(self.config.heartbeat_timeout_ms);
336            let state = self.state.clone();
337            let state_tx = self.state_tx.clone();
338            let mut shutdown_rx_clone = shutdown_rx.resubscribe();
339
340            tokio::spawn(async move {
341                let mut interval = tokio::time::interval(heartbeat_interval);
342
343                loop {
344                    tokio::select! {
345                        _ = shutdown_rx_clone.recv() => break,
346                        _ = interval.tick() => {
347                            // Check if we're still connected
348                            {
349                                let current_state = state.read().await;
350                                if !matches!(*current_state, SessionState::Connected) {
351                                    break;
352                                }
353                            }
354
355                            // Send ping
356                            let ping_result = timeout(heartbeat_timeout, async {
357                                let client_guard = client.lock().await;
358                                client_guard.ping().await
359                            }).await;
360
361                            if ping_result.is_err() {
362                                // Heartbeat failed, mark as disconnected
363                                let _ = state_tx.send(SessionState::Disconnected);
364                                break;
365                            }
366                        }
367                    }
368                }
369            });
370        }
371
372        Ok(())
373    }
374
375    /// Stop background tasks
376    async fn stop_background_tasks(&self) {
377        let shutdown_tx = {
378            let mut shutdown_guard = self.shutdown_tx.lock().await;
379            shutdown_guard.take()
380        };
381
382        if let Some(tx) = shutdown_tx {
383            let _ = tx.send(()).await; // Ignore error if receiver is dropped
384        }
385    }
386
387    /// Transition to a new state
388    async fn transition_state(&self, new_state: SessionState) -> McpResult<()> {
389        {
390            let mut state = self.state.write().await;
391            *state = new_state.clone();
392        }
393
394        // Broadcast the state change
395        if let Err(_) = self.state_tx.send(new_state) {
396            // Receiver may have been dropped, which is okay
397        }
398
399        Ok(())
400    }
401}
402
403/// Default notification handler that logs notifications
404pub struct LoggingNotificationHandler;
405
406impl NotificationHandler for LoggingNotificationHandler {
407    fn handle_notification(&self, notification: JsonRpcNotification) {
408        tracing::info!(
409            "Received notification: {} {:?}",
410            notification.method,
411            notification.params
412        );
413    }
414}
415
416/// Resource update notification handler
417pub struct ResourceUpdateHandler {
418    callback: Box<dyn Fn(String) + Send + Sync>,
419}
420
421impl ResourceUpdateHandler {
422    /// Create a new resource update handler
423    pub fn new<F>(callback: F) -> Self
424    where
425        F: Fn(String) + Send + Sync + 'static,
426    {
427        Self {
428            callback: Box::new(callback),
429        }
430    }
431}
432
433impl NotificationHandler for ResourceUpdateHandler {
434    fn handle_notification(&self, notification: JsonRpcNotification) {
435        if notification.method == methods::RESOURCES_UPDATED {
436            if let Some(params) = notification.params {
437                if let Ok(update_params) = serde_json::from_value::<ResourceUpdatedParams>(params) {
438                    (self.callback)(update_params.uri);
439                }
440            }
441        }
442    }
443}
444
445/// Tool list changed notification handler
446pub struct ToolListChangedHandler {
447    callback: Box<dyn Fn() + Send + Sync>,
448}
449
450impl ToolListChangedHandler {
451    /// Create a new tool list changed handler
452    pub fn new<F>(callback: F) -> Self
453    where
454        F: Fn() + Send + Sync + 'static,
455    {
456        Self {
457            callback: Box::new(callback),
458        }
459    }
460}
461
462impl NotificationHandler for ToolListChangedHandler {
463    fn handle_notification(&self, notification: JsonRpcNotification) {
464        if notification.method == methods::TOOLS_LIST_CHANGED {
465            (self.callback)();
466        }
467    }
468}
469
470/// Progress notification handler
471pub struct ProgressHandler {
472    callback: Box<dyn Fn(String, f32, Option<u32>) + Send + Sync>,
473}
474
475impl ProgressHandler {
476    /// Create a new progress handler
477    pub fn new<F>(callback: F) -> Self
478    where
479        F: Fn(String, f32, Option<u32>) + Send + Sync + 'static,
480    {
481        Self {
482            callback: Box::new(callback),
483        }
484    }
485}
486
487impl NotificationHandler for ProgressHandler {
488    fn handle_notification(&self, notification: JsonRpcNotification) {
489        if notification.method == methods::PROGRESS {
490            if let Some(params) = notification.params {
491                if let Ok(progress_params) = serde_json::from_value::<ProgressParams>(params) {
492                    (self.callback)(
493                        progress_params.progress_token,
494                        progress_params.progress,
495                        progress_params.total,
496                    );
497                }
498            }
499        }
500    }
501}
502
503/// Session statistics
504#[derive(Debug, Clone)]
505pub struct SessionStats {
506    /// Current session state
507    pub state: SessionState,
508    /// Connection uptime
509    pub uptime: Option<Duration>,
510    /// Number of reconnection attempts
511    pub reconnect_attempts: u32,
512    /// Connection timestamp
513    pub connected_at: Option<Instant>,
514}
515
516impl ClientSession {
517    /// Get session statistics
518    pub async fn stats(&self) -> SessionStats {
519        let state = self.state().await;
520        let uptime = self.uptime().await;
521        let reconnect_attempts = {
522            let attempts = self.reconnect_attempts.lock().await;
523            *attempts
524        };
525        let connected_at = {
526            let connected_at = self.connected_at.read().await;
527            *connected_at
528        };
529
530        SessionStats {
531            state,
532            uptime,
533            reconnect_attempts,
534            connected_at,
535        }
536    }
537}
538
539#[cfg(test)]
540mod tests {
541    use super::*;
542    use crate::client::mcp_client::McpClient;
543    use async_trait::async_trait;
544
545    // Mock transport for testing
546    struct MockTransport;
547
548    #[async_trait]
549    impl Transport for MockTransport {
550        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
551            // Return a successful initialize response
552            let init_result = InitializeResult::new(
553                ServerInfo {
554                    name: "test-server".to_string(),
555                    version: "1.0.0".to_string(),
556                },
557                ServerCapabilities::default(),
558                MCP_PROTOCOL_VERSION.to_string(),
559            );
560            JsonRpcResponse::success(serde_json::Value::from(1), init_result)
561                .map_err(|e| McpError::Serialization(e))
562        }
563
564        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
565            Ok(())
566        }
567
568        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
569            Ok(None)
570        }
571
572        async fn close(&mut self) -> McpResult<()> {
573            Ok(())
574        }
575    }
576
577    #[tokio::test]
578    async fn test_session_creation() {
579        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
580        let session = ClientSession::new(client);
581
582        assert_eq!(session.state().await, SessionState::Disconnected);
583        assert!(!session.is_connected().await);
584        assert!(session.uptime().await.is_none());
585    }
586
587    #[tokio::test]
588    async fn test_session_connection() {
589        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
590        let session = ClientSession::new(client);
591
592        let transport = MockTransport;
593        let result = session.connect(transport).await;
594
595        assert!(result.is_ok());
596        assert_eq!(session.state().await, SessionState::Connected);
597        assert!(session.is_connected().await);
598        assert!(session.uptime().await.is_some());
599    }
600
601    #[tokio::test]
602    async fn test_session_disconnect() {
603        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
604        let session = ClientSession::new(client);
605
606        // Connect first
607        let transport = MockTransport;
608        session.connect(transport).await.unwrap();
609        assert!(session.is_connected().await);
610
611        // Then disconnect
612        session.disconnect().await.unwrap();
613        assert_eq!(session.state().await, SessionState::Disconnected);
614        assert!(!session.is_connected().await);
615        assert!(session.uptime().await.is_none());
616    }
617
618    #[tokio::test]
619    async fn test_notification_handlers() {
620        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
621        let session = ClientSession::new(client);
622
623        // Add a logging notification handler
624        session
625            .add_notification_handler(LoggingNotificationHandler)
626            .await;
627
628        // Add a resource update handler
629        session
630            .add_notification_handler(ResourceUpdateHandler::new(|uri| {
631                println!("Resource updated: {}", uri);
632            }))
633            .await;
634
635        // Add a tool list changed handler
636        session
637            .add_notification_handler(ToolListChangedHandler::new(|| {
638                println!("Tool list changed");
639            }))
640            .await;
641
642        // Add a progress handler
643        session
644            .add_notification_handler(ProgressHandler::new(|token, progress, total| {
645                println!("Progress {}: {} / {:?}", token, progress, total);
646            }))
647            .await;
648
649        let handlers = session.notification_handlers.read().await;
650        assert_eq!(handlers.len(), 4);
651    }
652
653    #[tokio::test]
654    async fn test_session_stats() {
655        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
656        let session = ClientSession::new(client);
657
658        let stats = session.stats().await;
659        assert_eq!(stats.state, SessionState::Disconnected);
660        assert!(stats.uptime.is_none());
661        assert_eq!(stats.reconnect_attempts, 0);
662        assert!(stats.connected_at.is_none());
663    }
664
665    #[tokio::test]
666    async fn test_session_config() {
667        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
668        let config = SessionConfig {
669            auto_reconnect: false,
670            max_reconnect_attempts: 10,
671            reconnect_delay_ms: 2000,
672            ..Default::default()
673        };
674        let session = ClientSession::with_config(client, config.clone());
675
676        assert_eq!(session.config().auto_reconnect, false);
677        assert_eq!(session.config().max_reconnect_attempts, 10);
678        assert_eq!(session.config().reconnect_delay_ms, 2000);
679    }
680
681    #[tokio::test]
682    async fn test_state_subscription() {
683        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
684        let session = ClientSession::new(client);
685
686        let mut state_rx = session.subscribe_state_changes();
687
688        // Initial state
689        assert_eq!(*state_rx.borrow(), SessionState::Disconnected);
690
691        // Change state
692        session
693            .transition_state(SessionState::Connecting)
694            .await
695            .unwrap();
696
697        // Wait for change
698        state_rx.changed().await.unwrap();
699        assert_eq!(*state_rx.borrow(), SessionState::Connecting);
700    }
701}