mcp_protocol_sdk/client/
session.rs

1//! Client session management
2//!
//! This module provides session management for MCP clients, including connection
//! state tracking, notification handling, and caller-driven reconnection with
//! exponential backoff.
5
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{broadcast, mpsc, watch, Mutex, RwLock};
9use tokio::time::{sleep, timeout};
10
11use crate::client::mcp_client::McpClient;
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, methods, types::*};
14use crate::transport::traits::Transport;
15
/// Lifecycle phase of a client session.
///
/// A session publishes every phase change both to its stored state and to the
/// watch channel handed out to subscribers.
#[derive(Clone, Debug, PartialEq)]
pub enum SessionState {
    /// Ready but not actively processing.
    /// NOTE(review): nothing in this module currently transitions into this state.
    Idle,
    /// No connection is established; this is the state of a freshly created session.
    Disconnected,
    /// A connection attempt (handshake) is in progress.
    Connecting,
    /// The handshake succeeded and the session is live.
    Connected,
    /// A reconnection was requested and the backoff delay is being waited out.
    Reconnecting,
    /// A connection attempt failed or the reconnection budget was exhausted;
    /// carries the error message.
    Failed(String),
}
32
33/// Notification handler trait
34pub trait NotificationHandler: Send + Sync {
35    /// Handle a notification from the server
36    fn handle_notification(&self, notification: JsonRpcNotification);
37}
38
/// Tunables for a client session.
///
/// The timing knobs the session itself consumes are plain millisecond counts;
/// the `Duration`-typed fields at the end exist for test compatibility.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Whether reconnection is permitted at all.
    pub auto_reconnect: bool,
    /// Upper bound on reconnection attempts before the session gives up.
    pub max_reconnect_attempts: u32,
    /// Delay before the first reconnection attempt, in milliseconds.
    pub reconnect_delay_ms: u64,
    /// Cap on the (growing) reconnection delay, in milliseconds.
    pub max_reconnect_delay_ms: u64,
    /// Factor by which the reconnection delay grows per attempt.
    pub reconnect_backoff: f64,
    /// Time allowed for a connection handshake, in milliseconds.
    pub connection_timeout_ms: u64,
    /// Period between heartbeat pings in milliseconds; `0` disables the heartbeat.
    pub heartbeat_interval_ms: u64,
    /// Time allowed for a single heartbeat ping, in milliseconds.
    pub heartbeat_timeout_ms: u64,

    // Additional fields for test compatibility
    /// Overall session timeout.
    pub session_timeout: Duration,
    /// Per-request timeout.
    pub request_timeout: Duration,
    /// Maximum number of in-flight requests.
    pub max_concurrent_requests: u32,
    /// Whether compression is enabled.
    pub enable_compression: bool,
    /// Size of operation buffers, in bytes.
    pub buffer_size: usize,
}

impl Default for SessionConfig {
    /// Reconnect enabled (5 attempts, 1 s initial delay doubling up to 30 s),
    /// 10 s handshake timeout, 30 s heartbeat with a 5 s ping timeout, 5 min
    /// session timeout, 30 s request timeout, 10 concurrent requests, no
    /// compression and an 8 KiB buffer.
    fn default() -> Self {
        const SECOND_MS: u64 = 1_000;

        Self {
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_delay_ms: SECOND_MS,
            max_reconnect_delay_ms: 30 * SECOND_MS,
            reconnect_backoff: 2.0,
            connection_timeout_ms: 10 * SECOND_MS,
            heartbeat_interval_ms: 30 * SECOND_MS,
            heartbeat_timeout_ms: 5 * SECOND_MS,
            session_timeout: Duration::from_secs(5 * 60),
            request_timeout: Duration::from_secs(30),
            max_concurrent_requests: 10,
            enable_compression: false,
            buffer_size: 8 * 1024,
        }
    }
}
91
92/// Client session that manages connection lifecycle and notifications
93pub struct ClientSession {
94    /// The underlying MCP client
95    client: Arc<Mutex<McpClient>>,
96    /// Session configuration
97    config: SessionConfig,
98    /// Current session state
99    state: Arc<RwLock<SessionState>>,
100    /// State change broadcaster
101    state_tx: watch::Sender<SessionState>,
102    /// State change receiver
103    state_rx: watch::Receiver<SessionState>,
104    /// Notification handlers
105    notification_handlers: Arc<RwLock<Vec<Box<dyn NotificationHandler>>>>,
106    /// Connection timestamp
107    connected_at: Arc<RwLock<Option<Instant>>>,
108    /// Reconnection attempts counter
109    reconnect_attempts: Arc<Mutex<u32>>,
110    /// Shutdown signal
111    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
112}
113
114impl ClientSession {
115    /// Create a new client session
116    pub fn new(client: McpClient) -> Self {
117        let (state_tx, state_rx) = watch::channel(SessionState::Disconnected);
118
119        Self {
120            client: Arc::new(Mutex::new(client)),
121            config: SessionConfig::default(),
122            state: Arc::new(RwLock::new(SessionState::Disconnected)),
123            state_tx,
124            state_rx,
125            notification_handlers: Arc::new(RwLock::new(Vec::new())),
126            connected_at: Arc::new(RwLock::new(None)),
127            reconnect_attempts: Arc::new(Mutex::new(0)),
128            shutdown_tx: Arc::new(Mutex::new(None)),
129        }
130    }
131
132    /// Create a new client session with custom configuration
133    pub fn with_config(client: McpClient, config: SessionConfig) -> Self {
134        let mut session = Self::new(client);
135        session.config = config;
136        session
137    }
138
139    /// Get the current session state
140    pub async fn state(&self) -> SessionState {
141        let state = self.state.read().await;
142        state.clone()
143    }
144
145    /// Subscribe to state changes
146    pub fn subscribe_state_changes(&self) -> watch::Receiver<SessionState> {
147        self.state_rx.clone()
148    }
149
150    /// Check if the session is connected
151    pub async fn is_connected(&self) -> bool {
152        let state = self.state.read().await;
153        matches!(*state, SessionState::Connected)
154    }
155
156    /// Get connection uptime
157    pub async fn uptime(&self) -> Option<Duration> {
158        let connected_at = self.connected_at.read().await;
159        connected_at.map(|time| time.elapsed())
160    }
161
162    /// Add a notification handler
163    pub async fn add_notification_handler<H>(&self, handler: H)
164    where
165        H: NotificationHandler + 'static,
166    {
167        let mut handlers = self.notification_handlers.write().await;
168        handlers.push(Box::new(handler));
169    }
170
171    /// Connect to the server with the provided transport
172    pub async fn connect<T>(&self, transport: T) -> McpResult<InitializeResult>
173    where
174        T: Transport + 'static,
175    {
176        self.transition_state(SessionState::Connecting).await?;
177
178        let connect_future = async {
179            let mut client = self.client.lock().await;
180            client.connect(transport).await
181        };
182
183        let result = timeout(
184            Duration::from_millis(self.config.connection_timeout_ms),
185            connect_future,
186        )
187        .await;
188
189        match result {
190            Ok(Ok(init_result)) => {
191                self.transition_state(SessionState::Connected).await?;
192
193                // Record connection time
194                {
195                    let mut connected_at = self.connected_at.write().await;
196                    *connected_at = Some(Instant::now());
197                }
198
199                // Reset reconnection attempts
200                {
201                    let mut attempts = self.reconnect_attempts.lock().await;
202                    *attempts = 0;
203                }
204
205                // Start background tasks
206                self.start_background_tasks().await?;
207
208                Ok(init_result)
209            }
210            Ok(Err(error)) => {
211                self.transition_state(SessionState::Failed(error.to_string()))
212                    .await?;
213                Err(error)
214            }
215            Err(_) => {
216                let error = McpError::Connection("Connection timeout".to_string());
217                self.transition_state(SessionState::Failed(error.to_string()))
218                    .await?;
219                Err(error)
220            }
221        }
222    }
223
224    /// Disconnect from the server
225    pub async fn disconnect(&self) -> McpResult<()> {
226        // Stop background tasks
227        self.stop_background_tasks().await;
228
229        // Disconnect the client
230        {
231            let client = self.client.lock().await;
232            client.disconnect().await?;
233        }
234
235        // Update state
236        self.transition_state(SessionState::Disconnected).await?;
237
238        // Clear connection time
239        {
240            let mut connected_at = self.connected_at.write().await;
241            *connected_at = None;
242        }
243
244        Ok(())
245    }
246
247    /// Reconnect to the server
248    pub async fn reconnect<T>(
249        &self,
250        transport_factory: impl Fn() -> T,
251    ) -> McpResult<InitializeResult>
252    where
253        T: Transport + 'static,
254    {
255        if !self.config.auto_reconnect {
256            return Err(McpError::Connection(
257                "Auto-reconnect is disabled".to_string(),
258            ));
259        }
260
261        let mut attempts = self.reconnect_attempts.lock().await;
262        if *attempts >= self.config.max_reconnect_attempts {
263            let error = McpError::Connection("Max reconnection attempts exceeded".to_string());
264            self.transition_state(SessionState::Failed(error.to_string()))
265                .await?;
266            return Err(error);
267        }
268
269        *attempts += 1;
270        let current_attempts = *attempts;
271        drop(attempts);
272
273        self.transition_state(SessionState::Reconnecting).await?;
274
275        // Calculate reconnection delay with exponential backoff
276        let delay = std::cmp::min(
277            (self.config.reconnect_delay_ms as f64
278                * self
279                    .config
280                    .reconnect_backoff
281                    .powi(current_attempts as i32 - 1)) as u64,
282            self.config.max_reconnect_delay_ms,
283        );
284
285        sleep(Duration::from_millis(delay)).await;
286
287        // Attempt to reconnect
288        self.connect(transport_factory()).await
289    }
290
291    /// Get the underlying client (for direct operations)
292    pub fn client(&self) -> Arc<Mutex<McpClient>> {
293        self.client.clone()
294    }
295
296    /// Get session configuration
297    pub fn config(&self) -> &SessionConfig {
298        &self.config
299    }
300
301    // ========================================================================
302    // Background Tasks
303    // ========================================================================
304
305    /// Start background tasks (notification handling, heartbeat)
306    async fn start_background_tasks(&self) -> McpResult<()> {
307        let (_shutdown_tx, shutdown_rx): (broadcast::Sender<()>, broadcast::Receiver<()>) =
308            broadcast::channel(16);
309        {
310            let mut shutdown_guard = self.shutdown_tx.lock().await;
311            *shutdown_guard = Some(mpsc::channel(1).0); // Store a dummy for interface compatibility
312        }
313
314        // Start notification handler task
315        {
316            let client = self.client.clone();
317            let handlers = self.notification_handlers.clone();
318            let mut shutdown_rx_clone = shutdown_rx.resubscribe();
319
320            tokio::spawn(async move {
321                loop {
322                    tokio::select! {
323                        _ = shutdown_rx_clone.recv() => break,
324                        notification_result = async {
325                            let client_guard = client.lock().await;
326                            client_guard.receive_notification().await
327                        } => {
328                            match notification_result {
329                                Ok(Some(notification)) => {
330                                    let handlers_guard = handlers.read().await;
331                                    for handler in handlers_guard.iter() {
332                                        handler.handle_notification(notification.clone());
333                                    }
334                                }
335                                Ok(None) => {
336                                    // No notification available, continue
337                                }
338                                Err(_) => {
339                                    // Error receiving notification, might be disconnected
340                                    break;
341                                }
342                            }
343                        }
344                    }
345                }
346            });
347        }
348
349        // Start heartbeat task if enabled
350        if self.config.heartbeat_interval_ms > 0 {
351            let client = self.client.clone();
352            let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
353            let heartbeat_timeout = Duration::from_millis(self.config.heartbeat_timeout_ms);
354            let state = self.state.clone();
355            let state_tx = self.state_tx.clone();
356            let mut shutdown_rx_clone = shutdown_rx.resubscribe();
357
358            tokio::spawn(async move {
359                let mut interval = tokio::time::interval(heartbeat_interval);
360
361                loop {
362                    tokio::select! {
363                        _ = shutdown_rx_clone.recv() => break,
364                        _ = interval.tick() => {
365                            // Check if we're still connected
366                            {
367                                let current_state = state.read().await;
368                                if !matches!(*current_state, SessionState::Connected) {
369                                    break;
370                                }
371                            }
372
373                            // Send ping
374                            let ping_result = timeout(heartbeat_timeout, async {
375                                let client_guard = client.lock().await;
376                                client_guard.ping().await
377                            }).await;
378
379                            if ping_result.is_err() {
380                                // Heartbeat failed, mark as disconnected
381                                let _ = state_tx.send(SessionState::Disconnected);
382                                break;
383                            }
384                        }
385                    }
386                }
387            });
388        }
389
390        Ok(())
391    }
392
393    /// Stop background tasks
394    async fn stop_background_tasks(&self) {
395        let shutdown_tx = {
396            let mut shutdown_guard = self.shutdown_tx.lock().await;
397            shutdown_guard.take()
398        };
399
400        if let Some(tx) = shutdown_tx {
401            let _ = tx.send(()).await; // Ignore error if receiver is dropped
402        }
403    }
404
405    /// Transition to a new state
406    async fn transition_state(&self, new_state: SessionState) -> McpResult<()> {
407        {
408            let mut state = self.state.write().await;
409            *state = new_state.clone();
410        }
411
412        // Broadcast the state change
413        if self.state_tx.send(new_state).is_err() {
414            // Receiver may have been dropped, which is okay
415        }
416
417        Ok(())
418    }
419}
420
421/// Default notification handler that logs notifications
422pub struct LoggingNotificationHandler;
423
424impl NotificationHandler for LoggingNotificationHandler {
425    fn handle_notification(&self, notification: JsonRpcNotification) {
426        tracing::info!(
427            "Received notification: {} {:?}",
428            notification.method,
429            notification.params
430        );
431    }
432}
433
434/// Resource update notification handler
435pub struct ResourceUpdateHandler {
436    callback: Box<dyn Fn(String) + Send + Sync>,
437}
438
439impl ResourceUpdateHandler {
440    /// Create a new resource update handler
441    pub fn new<F>(callback: F) -> Self
442    where
443        F: Fn(String) + Send + Sync + 'static,
444    {
445        Self {
446            callback: Box::new(callback),
447        }
448    }
449}
450
451impl NotificationHandler for ResourceUpdateHandler {
452    fn handle_notification(&self, notification: JsonRpcNotification) {
453        if notification.method == methods::RESOURCES_UPDATED {
454            if let Some(params) = notification.params {
455                if let Ok(update_params) = serde_json::from_value::<ResourceUpdatedParams>(params) {
456                    (self.callback)(update_params.uri);
457                }
458            }
459        }
460    }
461}
462
463/// Tool list changed notification handler
464pub struct ToolListChangedHandler {
465    callback: Box<dyn Fn() + Send + Sync>,
466}
467
468impl ToolListChangedHandler {
469    /// Create a new tool list changed handler
470    pub fn new<F>(callback: F) -> Self
471    where
472        F: Fn() + Send + Sync + 'static,
473    {
474        Self {
475            callback: Box::new(callback),
476        }
477    }
478}
479
480impl NotificationHandler for ToolListChangedHandler {
481    fn handle_notification(&self, notification: JsonRpcNotification) {
482        if notification.method == methods::TOOLS_LIST_CHANGED {
483            (self.callback)();
484        }
485    }
486}
487
488/// Progress notification handler
489pub struct ProgressHandler {
490    callback: Box<dyn Fn(String, f32, Option<u32>) + Send + Sync>,
491}
492
493impl ProgressHandler {
494    /// Create a new progress handler
495    pub fn new<F>(callback: F) -> Self
496    where
497        F: Fn(String, f32, Option<u32>) + Send + Sync + 'static,
498    {
499        Self {
500            callback: Box::new(callback),
501        }
502    }
503}
504
505impl NotificationHandler for ProgressHandler {
506    fn handle_notification(&self, notification: JsonRpcNotification) {
507        if notification.method == methods::PROGRESS {
508            if let Some(params) = notification.params {
509                if let Ok(progress_params) = serde_json::from_value::<ProgressParams>(params) {
510                    (self.callback)(
511                        progress_params.progress_token.to_string(),
512                        progress_params.progress,
513                        progress_params.total.map(|t| t as u32),
514                    );
515                }
516            }
517        }
518    }
519}
520
521/// Session statistics
522#[derive(Debug, Clone)]
523pub struct SessionStats {
524    /// Current session state
525    pub state: SessionState,
526    /// Connection uptime
527    pub uptime: Option<Duration>,
528    /// Number of reconnection attempts
529    pub reconnect_attempts: u32,
530    /// Connection timestamp
531    pub connected_at: Option<Instant>,
532}
533
534impl ClientSession {
535    /// Get session statistics
536    pub async fn stats(&self) -> SessionStats {
537        let state = self.state().await;
538        let uptime = self.uptime().await;
539        let reconnect_attempts = {
540            let attempts = self.reconnect_attempts.lock().await;
541            *attempts
542        };
543        let connected_at = {
544            let connected_at = self.connected_at.read().await;
545            *connected_at
546        };
547
548        SessionStats {
549            state,
550            uptime,
551            reconnect_attempts,
552            connected_at,
553        }
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::mcp_client::McpClient;
    use async_trait::async_trait;

    /// Fresh client with the identity shared by every test below.
    fn test_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    /// In-memory transport with the following behavior:
    ///
    /// - every request is answered with a successful `initialize` result;
    /// - outgoing notifications are accepted and discarded;
    /// - no server notifications are ever produced.
    struct MockTransport;

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let server_info = ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            };
            let init_result = InitializeResult::new(
                crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
                ServerCapabilities::default(),
                server_info,
            );
            JsonRpcResponse::success(serde_json::Value::from(1), init_result)
                .map_err(|err| McpError::Serialization(err.to_string()))
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_session_creation() {
        let session = ClientSession::new(test_client());

        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_session_connection() {
        let session = ClientSession::new(test_client());

        let outcome = session.connect(MockTransport).await;

        assert!(outcome.is_ok());
        assert_eq!(session.state().await, SessionState::Connected);
        assert!(session.is_connected().await);
        assert!(session.uptime().await.is_some());
    }

    #[tokio::test]
    async fn test_session_disconnect() {
        let session = ClientSession::new(test_client());

        session.connect(MockTransport).await.unwrap();
        assert!(session.is_connected().await);

        session.disconnect().await.unwrap();
        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_notification_handlers() {
        let session = ClientSession::new(test_client());

        session
            .add_notification_handler(LoggingNotificationHandler)
            .await;
        session
            .add_notification_handler(ResourceUpdateHandler::new(|uri| {
                println!("Resource updated: {}", uri);
            }))
            .await;
        session
            .add_notification_handler(ToolListChangedHandler::new(|| {
                println!("Tool list changed");
            }))
            .await;
        session
            .add_notification_handler(ProgressHandler::new(|token, progress, total| {
                println!("Progress {}: {} / {:?}", token, progress, total);
            }))
            .await;

        let registered = session.notification_handlers.read().await;
        assert_eq!(registered.len(), 4);
    }

    #[tokio::test]
    async fn test_session_stats() {
        let session = ClientSession::new(test_client());

        let snapshot = session.stats().await;

        assert_eq!(snapshot.state, SessionState::Disconnected);
        assert!(snapshot.uptime.is_none());
        assert_eq!(snapshot.reconnect_attempts, 0);
        assert!(snapshot.connected_at.is_none());
    }

    #[tokio::test]
    async fn test_session_config() {
        let session = ClientSession::with_config(
            test_client(),
            SessionConfig {
                auto_reconnect: false,
                max_reconnect_attempts: 10,
                reconnect_delay_ms: 2000,
                ..Default::default()
            },
        );

        let config = session.config();
        assert!(!config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 10);
        assert_eq!(config.reconnect_delay_ms, 2000);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let session = ClientSession::new(test_client());
        let mut watcher = session.subscribe_state_changes();

        assert_eq!(*watcher.borrow(), SessionState::Disconnected);

        session
            .transition_state(SessionState::Connecting)
            .await
            .unwrap();

        watcher.changed().await.unwrap();
        assert_eq!(*watcher.borrow(), SessionState::Connecting);
    }
}