// mcp_protocol_sdk/client/session.rs

//! Client session management
//!
//! This module provides session management for MCP clients, including connection
//! state tracking, notification dispatch, heartbeats, and caller-driven reconnection
//! with exponential backoff.

use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, mpsc, watch, Mutex, RwLock};
use tokio::time::{sleep, timeout};

use crate::client::mcp_client::McpClient;
use crate::core::error::{McpError, McpResult};
use crate::protocol::{messages::*, types::*};
use crate::transport::traits::Transport;

/// Lifecycle state of a [`ClientSession`].
#[derive(Clone, Debug, PartialEq)]
pub enum SessionState {
    /// No connection is established.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection is established and the session is active.
    Connected,
    /// A reconnection attempt is in progress after an earlier failure.
    Reconnecting,
    /// The last connection attempt failed; carries a description of the error.
    Failed(String),
}

31/// Notification handler trait
32pub trait NotificationHandler: Send + Sync {
33    /// Handle a notification from the server
34    fn handle_notification(&self, notification: JsonRpcNotification);
35}
36
/// Tunables controlling connection, reconnection, and heartbeat behaviour of a
/// [`ClientSession`].
///
/// All durations are expressed in milliseconds.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Allow [`ClientSession::reconnect`] to run; when `false` it fails immediately.
    pub auto_reconnect: bool,
    /// Upper bound on reconnection attempts (since the last successful connection)
    /// before the session is marked failed.
    pub max_reconnect_attempts: u32,
    /// Delay before the first reconnection attempt.
    pub reconnect_delay_ms: u64,
    /// Cap applied to the exponentially growing reconnection delay.
    pub max_reconnect_delay_ms: u64,
    /// Factor by which the reconnection delay grows after each attempt.
    pub reconnect_backoff: f64,
    /// How long a single connection attempt may take before it is abandoned.
    pub connection_timeout_ms: u64,
    /// Period between heartbeat pings; `0` disables the heartbeat entirely.
    pub heartbeat_interval_ms: u64,
    /// How long a heartbeat ping may take before the connection is considered lost.
    pub heartbeat_timeout_ms: u64,
}

impl Default for SessionConfig {
    /// Up to 5 reconnect attempts starting at 1 s and doubling up to 30 s, a 10 s
    /// connection timeout, and a 30 s heartbeat with a 5 s ping timeout.
    fn default() -> Self {
        const SECOND_MS: u64 = 1_000;

        Self {
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_delay_ms: SECOND_MS,
            max_reconnect_delay_ms: 30 * SECOND_MS,
            reconnect_backoff: 2.0,
            connection_timeout_ms: 10 * SECOND_MS,
            heartbeat_interval_ms: 30 * SECOND_MS,
            heartbeat_timeout_ms: 5 * SECOND_MS,
        }
    }
}

73/// Client session that manages connection lifecycle and notifications
74pub struct ClientSession {
75    /// The underlying MCP client
76    client: Arc<Mutex<McpClient>>,
77    /// Session configuration
78    config: SessionConfig,
79    /// Current session state
80    state: Arc<RwLock<SessionState>>,
81    /// State change broadcaster
82    state_tx: watch::Sender<SessionState>,
83    /// State change receiver
84    state_rx: watch::Receiver<SessionState>,
85    /// Notification handlers
86    notification_handlers: Arc<RwLock<Vec<Box<dyn NotificationHandler>>>>,
87    /// Connection timestamp
88    connected_at: Arc<RwLock<Option<Instant>>>,
89    /// Reconnection attempts counter
90    reconnect_attempts: Arc<Mutex<u32>>,
91    /// Shutdown signal
92    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
93}
94
95impl ClientSession {
96    /// Create a new client session
97    pub fn new(client: McpClient) -> Self {
98        let (state_tx, state_rx) = watch::channel(SessionState::Disconnected);
99
100        Self {
101            client: Arc::new(Mutex::new(client)),
102            config: SessionConfig::default(),
103            state: Arc::new(RwLock::new(SessionState::Disconnected)),
104            state_tx,
105            state_rx,
106            notification_handlers: Arc::new(RwLock::new(Vec::new())),
107            connected_at: Arc::new(RwLock::new(None)),
108            reconnect_attempts: Arc::new(Mutex::new(0)),
109            shutdown_tx: Arc::new(Mutex::new(None)),
110        }
111    }
112
113    /// Create a new client session with custom configuration
114    pub fn with_config(client: McpClient, config: SessionConfig) -> Self {
115        let mut session = Self::new(client);
116        session.config = config;
117        session
118    }
119
120    /// Get the current session state
121    pub async fn state(&self) -> SessionState {
122        let state = self.state.read().await;
123        state.clone()
124    }
125
126    /// Subscribe to state changes
127    pub fn subscribe_state_changes(&self) -> watch::Receiver<SessionState> {
128        self.state_rx.clone()
129    }
130
131    /// Check if the session is connected
132    pub async fn is_connected(&self) -> bool {
133        let state = self.state.read().await;
134        matches!(*state, SessionState::Connected)
135    }
136
137    /// Get connection uptime
138    pub async fn uptime(&self) -> Option<Duration> {
139        let connected_at = self.connected_at.read().await;
140        connected_at.map(|time| time.elapsed())
141    }
142
143    /// Add a notification handler
144    pub async fn add_notification_handler<H>(&self, handler: H)
145    where
146        H: NotificationHandler + 'static,
147    {
148        let mut handlers = self.notification_handlers.write().await;
149        handlers.push(Box::new(handler));
150    }
151
152    /// Connect to the server with the provided transport
153    pub async fn connect<T>(&self, transport: T) -> McpResult<InitializeResult>
154    where
155        T: Transport + 'static,
156    {
157        self.transition_state(SessionState::Connecting).await?;
158
159        let connect_future = async {
160            let mut client = self.client.lock().await;
161            client.connect(transport).await
162        };
163
164        let result = timeout(
165            Duration::from_millis(self.config.connection_timeout_ms),
166            connect_future,
167        )
168        .await;
169
170        match result {
171            Ok(Ok(init_result)) => {
172                self.transition_state(SessionState::Connected).await?;
173
174                // Record connection time
175                {
176                    let mut connected_at = self.connected_at.write().await;
177                    *connected_at = Some(Instant::now());
178                }
179
180                // Reset reconnection attempts
181                {
182                    let mut attempts = self.reconnect_attempts.lock().await;
183                    *attempts = 0;
184                }
185
186                // Start background tasks
187                self.start_background_tasks().await?;
188
189                Ok(init_result)
190            }
191            Ok(Err(error)) => {
192                self.transition_state(SessionState::Failed(error.to_string()))
193                    .await?;
194                Err(error)
195            }
196            Err(_) => {
197                let error = McpError::Connection("Connection timeout".to_string());
198                self.transition_state(SessionState::Failed(error.to_string()))
199                    .await?;
200                Err(error)
201            }
202        }
203    }
204
205    /// Disconnect from the server
206    pub async fn disconnect(&self) -> McpResult<()> {
207        // Stop background tasks
208        self.stop_background_tasks().await;
209
210        // Disconnect the client
211        {
212            let client = self.client.lock().await;
213            client.disconnect().await?;
214        }
215
216        // Update state
217        self.transition_state(SessionState::Disconnected).await?;
218
219        // Clear connection time
220        {
221            let mut connected_at = self.connected_at.write().await;
222            *connected_at = None;
223        }
224
225        Ok(())
226    }
227
228    /// Reconnect to the server
229    pub async fn reconnect<T>(
230        &self,
231        transport_factory: impl Fn() -> T,
232    ) -> McpResult<InitializeResult>
233    where
234        T: Transport + 'static,
235    {
236        if !self.config.auto_reconnect {
237            return Err(McpError::Connection(
238                "Auto-reconnect is disabled".to_string(),
239            ));
240        }
241
242        let mut attempts = self.reconnect_attempts.lock().await;
243        if *attempts >= self.config.max_reconnect_attempts {
244            let error = McpError::Connection("Max reconnection attempts exceeded".to_string());
245            self.transition_state(SessionState::Failed(error.to_string()))
246                .await?;
247            return Err(error);
248        }
249
250        *attempts += 1;
251        let current_attempts = *attempts;
252        drop(attempts);
253
254        self.transition_state(SessionState::Reconnecting).await?;
255
256        // Calculate reconnection delay with exponential backoff
257        let delay = std::cmp::min(
258            (self.config.reconnect_delay_ms as f64
259                * self
260                    .config
261                    .reconnect_backoff
262                    .powi(current_attempts as i32 - 1)) as u64,
263            self.config.max_reconnect_delay_ms,
264        );
265
266        sleep(Duration::from_millis(delay)).await;
267
268        // Attempt to reconnect
269        self.connect(transport_factory()).await
270    }
271
272    /// Get the underlying client (for direct operations)
273    pub fn client(&self) -> Arc<Mutex<McpClient>> {
274        self.client.clone()
275    }
276
277    /// Get session configuration
278    pub fn config(&self) -> &SessionConfig {
279        &self.config
280    }
281
282    // ========================================================================
283    // Background Tasks
284    // ========================================================================
285
286    /// Start background tasks (notification handling, heartbeat)
287    async fn start_background_tasks(&self) -> McpResult<()> {
288        let (_shutdown_tx, shutdown_rx): (broadcast::Sender<()>, broadcast::Receiver<()>) =
289            broadcast::channel(16);
290        {
291            let mut shutdown_guard = self.shutdown_tx.lock().await;
292            *shutdown_guard = Some(mpsc::channel(1).0); // Store a dummy for interface compatibility
293        }
294
295        // Start notification handler task
296        {
297            let client = self.client.clone();
298            let handlers = self.notification_handlers.clone();
299            let mut shutdown_rx_clone = shutdown_rx.resubscribe();
300
301            tokio::spawn(async move {
302                loop {
303                    tokio::select! {
304                        _ = shutdown_rx_clone.recv() => break,
305                        notification_result = async {
306                            let client_guard = client.lock().await;
307                            client_guard.receive_notification().await
308                        } => {
309                            match notification_result {
310                                Ok(Some(notification)) => {
311                                    let handlers_guard = handlers.read().await;
312                                    for handler in handlers_guard.iter() {
313                                        handler.handle_notification(notification.clone());
314                                    }
315                                }
316                                Ok(None) => {
317                                    // No notification available, continue
318                                }
319                                Err(_) => {
320                                    // Error receiving notification, might be disconnected
321                                    break;
322                                }
323                            }
324                        }
325                    }
326                }
327            });
328        }
329
330        // Start heartbeat task if enabled
331        if self.config.heartbeat_interval_ms > 0 {
332            let client = self.client.clone();
333            let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
334            let heartbeat_timeout = Duration::from_millis(self.config.heartbeat_timeout_ms);
335            let state = self.state.clone();
336            let state_tx = self.state_tx.clone();
337            let mut shutdown_rx_clone = shutdown_rx.resubscribe();
338
339            tokio::spawn(async move {
340                let mut interval = tokio::time::interval(heartbeat_interval);
341
342                loop {
343                    tokio::select! {
344                        _ = shutdown_rx_clone.recv() => break,
345                        _ = interval.tick() => {
346                            // Check if we're still connected
347                            {
348                                let current_state = state.read().await;
349                                if !matches!(*current_state, SessionState::Connected) {
350                                    break;
351                                }
352                            }
353
354                            // Send ping
355                            let ping_result = timeout(heartbeat_timeout, async {
356                                let client_guard = client.lock().await;
357                                client_guard.ping().await
358                            }).await;
359
360                            if ping_result.is_err() {
361                                // Heartbeat failed, mark as disconnected
362                                let _ = state_tx.send(SessionState::Disconnected);
363                                break;
364                            }
365                        }
366                    }
367                }
368            });
369        }
370
371        Ok(())
372    }
373
374    /// Stop background tasks
375    async fn stop_background_tasks(&self) {
376        let shutdown_tx = {
377            let mut shutdown_guard = self.shutdown_tx.lock().await;
378            shutdown_guard.take()
379        };
380
381        if let Some(tx) = shutdown_tx {
382            let _ = tx.send(()).await; // Ignore error if receiver is dropped
383        }
384    }
385
386    /// Transition to a new state
387    async fn transition_state(&self, new_state: SessionState) -> McpResult<()> {
388        {
389            let mut state = self.state.write().await;
390            *state = new_state.clone();
391        }
392
393        // Broadcast the state change
394        if self.state_tx.send(new_state).is_err() {
395            // Receiver may have been dropped, which is okay
396        }
397
398        Ok(())
399    }
400}
401
402/// Default notification handler that logs notifications
403pub struct LoggingNotificationHandler;
404
405impl NotificationHandler for LoggingNotificationHandler {
406    fn handle_notification(&self, notification: JsonRpcNotification) {
407        tracing::info!(
408            "Received notification: {} {:?}",
409            notification.method,
410            notification.params
411        );
412    }
413}
414
415/// Resource update notification handler
416pub struct ResourceUpdateHandler {
417    callback: Box<dyn Fn(String) + Send + Sync>,
418}
419
420impl ResourceUpdateHandler {
421    /// Create a new resource update handler
422    pub fn new<F>(callback: F) -> Self
423    where
424        F: Fn(String) + Send + Sync + 'static,
425    {
426        Self {
427            callback: Box::new(callback),
428        }
429    }
430}
431
432impl NotificationHandler for ResourceUpdateHandler {
433    fn handle_notification(&self, notification: JsonRpcNotification) {
434        if notification.method == methods::RESOURCES_UPDATED {
435            if let Some(params) = notification.params {
436                if let Ok(update_params) = serde_json::from_value::<ResourceUpdatedParams>(params) {
437                    (self.callback)(update_params.uri);
438                }
439            }
440        }
441    }
442}
443
444/// Tool list changed notification handler
445pub struct ToolListChangedHandler {
446    callback: Box<dyn Fn() + Send + Sync>,
447}
448
449impl ToolListChangedHandler {
450    /// Create a new tool list changed handler
451    pub fn new<F>(callback: F) -> Self
452    where
453        F: Fn() + Send + Sync + 'static,
454    {
455        Self {
456            callback: Box::new(callback),
457        }
458    }
459}
460
461impl NotificationHandler for ToolListChangedHandler {
462    fn handle_notification(&self, notification: JsonRpcNotification) {
463        if notification.method == methods::TOOLS_LIST_CHANGED {
464            (self.callback)();
465        }
466    }
467}
468
469/// Progress notification handler
470pub struct ProgressHandler {
471    callback: Box<dyn Fn(String, f32, Option<u32>) + Send + Sync>,
472}
473
474impl ProgressHandler {
475    /// Create a new progress handler
476    pub fn new<F>(callback: F) -> Self
477    where
478        F: Fn(String, f32, Option<u32>) + Send + Sync + 'static,
479    {
480        Self {
481            callback: Box::new(callback),
482        }
483    }
484}
485
486impl NotificationHandler for ProgressHandler {
487    fn handle_notification(&self, notification: JsonRpcNotification) {
488        if notification.method == methods::PROGRESS {
489            if let Some(params) = notification.params {
490                if let Ok(progress_params) = serde_json::from_value::<ProgressParams>(params) {
491                    (self.callback)(
492                        progress_params.progress_token,
493                        progress_params.progress,
494                        progress_params.total,
495                    );
496                }
497            }
498        }
499    }
500}
501
502/// Session statistics
503#[derive(Debug, Clone)]
504pub struct SessionStats {
505    /// Current session state
506    pub state: SessionState,
507    /// Connection uptime
508    pub uptime: Option<Duration>,
509    /// Number of reconnection attempts
510    pub reconnect_attempts: u32,
511    /// Connection timestamp
512    pub connected_at: Option<Instant>,
513}
514
515impl ClientSession {
516    /// Get session statistics
517    pub async fn stats(&self) -> SessionStats {
518        let state = self.state().await;
519        let uptime = self.uptime().await;
520        let reconnect_attempts = {
521            let attempts = self.reconnect_attempts.lock().await;
522            *attempts
523        };
524        let connected_at = {
525            let connected_at = self.connected_at.read().await;
526            *connected_at
527        };
528
529        SessionStats {
530            state,
531            uptime,
532            reconnect_attempts,
533            connected_at,
534        }
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::mcp_client::McpClient;
    use async_trait::async_trait;

    /// Transport stub used by the session tests.
    ///
    /// Every request is answered with a successful `InitializeResult`, outgoing
    /// notifications are accepted and discarded, no incoming notification is ever
    /// pending, and closing always succeeds.
    struct MockTransport;

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let server_info = ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            };
            let init_result = InitializeResult::new(
                server_info,
                ServerCapabilities::default(),
                MCP_PROTOCOL_VERSION.to_string(),
            );
            JsonRpcResponse::success(serde_json::Value::from(1), init_result)
                .map_err(McpError::Serialization)
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// Client with the name and version shared by every test here.
    fn test_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    /// Fresh session with the default configuration.
    fn test_session() -> ClientSession {
        ClientSession::new(test_client())
    }

    #[tokio::test]
    async fn test_session_creation() {
        let session = test_session();

        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_session_connection() {
        let session = test_session();

        let outcome = session.connect(MockTransport).await;

        assert!(outcome.is_ok());
        assert_eq!(session.state().await, SessionState::Connected);
        assert!(session.is_connected().await);
        assert!(session.uptime().await.is_some());
    }

    #[tokio::test]
    async fn test_session_disconnect() {
        let session = test_session();

        // Establish a connection first.
        session.connect(MockTransport).await.unwrap();
        assert!(session.is_connected().await);

        // Tearing it down must reset state and uptime.
        session.disconnect().await.unwrap();
        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_notification_handlers() {
        let session = test_session();

        session
            .add_notification_handler(LoggingNotificationHandler)
            .await;
        session
            .add_notification_handler(ResourceUpdateHandler::new(|uri| {
                println!("Resource updated: {}", uri);
            }))
            .await;
        session
            .add_notification_handler(ToolListChangedHandler::new(|| {
                println!("Tool list changed");
            }))
            .await;
        session
            .add_notification_handler(ProgressHandler::new(|token, progress, total| {
                println!("Progress {}: {} / {:?}", token, progress, total);
            }))
            .await;

        let registered = session.notification_handlers.read().await.len();
        assert_eq!(registered, 4);
    }

    #[tokio::test]
    async fn test_session_stats() {
        let stats = test_session().stats().await;

        assert_eq!(stats.state, SessionState::Disconnected);
        assert!(stats.uptime.is_none());
        assert_eq!(stats.reconnect_attempts, 0);
        assert!(stats.connected_at.is_none());
    }

    #[tokio::test]
    async fn test_session_config() {
        let config = SessionConfig {
            auto_reconnect: false,
            max_reconnect_attempts: 10,
            reconnect_delay_ms: 2000,
            ..Default::default()
        };
        let session = ClientSession::with_config(test_client(), config);

        let applied = session.config();
        assert!(!applied.auto_reconnect);
        assert_eq!(applied.max_reconnect_attempts, 10);
        assert_eq!(applied.reconnect_delay_ms, 2000);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let session = test_session();
        let mut state_rx = session.subscribe_state_changes();

        // A new subscriber sees the initial state.
        assert_eq!(*state_rx.borrow(), SessionState::Disconnected);

        session
            .transition_state(SessionState::Connecting)
            .await
            .unwrap();

        // The transition is delivered to the subscriber.
        state_rx.changed().await.unwrap();
        assert_eq!(*state_rx.borrow(), SessionState::Connecting);
    }
}