1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{broadcast, mpsc, watch, Mutex, RwLock};
10use tokio::time::{sleep, timeout};
11
12use crate::client::mcp_client::{ClientConfig, McpClient};
13use crate::core::error::{McpError, McpResult};
14use crate::protocol::{messages::*, types::*};
15use crate::transport::traits::Transport;
16
/// Connection lifecycle state of a client session.
///
/// Derives `Eq` in addition to `PartialEq`: every variant payload (`String`) has total
/// equality, so the type can be used wherever an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// No connection is established.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The session is connected.
    Connected,
    /// A reconnection attempt is in progress.
    Reconnecting,
    /// The last attempt failed; the payload is the error rendered as text.
    Failed(String),
}
31
/// Callback interface for notifications received by a session.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
pub trait NotificationHandler: Send + Sync {
    /// Handles one notification; the handler receives its own owned value.
    fn handle_notification(&self, notification: JsonRpcNotification);
}
37
/// Tunable parameters for a client session. All durations are in milliseconds.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Whether reconnection is allowed at all.
    pub auto_reconnect: bool,
    /// Upper bound on the number of reconnection attempts.
    pub max_reconnect_attempts: u32,
    /// Base delay before a reconnection attempt.
    pub reconnect_delay_ms: u64,
    /// Ceiling for the reconnection delay.
    pub max_reconnect_delay_ms: u64,
    /// Multiplier used to grow the reconnection delay between attempts.
    pub reconnect_backoff: f64,
    /// Time limit for establishing a connection.
    pub connection_timeout_ms: u64,
    /// Period between heartbeats.
    pub heartbeat_interval_ms: u64,
    /// Time limit for a single heartbeat.
    pub heartbeat_timeout_ms: u64,
}

impl Default for SessionConfig {
    /// Returns the stock configuration: reconnection enabled with up to 5 attempts,
    /// a 1 s base delay doubling up to 30 s, a 10 s connection timeout, and a
    /// heartbeat every 30 s with a 5 s timeout.
    fn default() -> Self {
        Self {
            // Reconnection policy.
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_delay_ms: 1_000,
            max_reconnect_delay_ms: 30_000,
            reconnect_backoff: 2.0,
            // Connection establishment.
            connection_timeout_ms: 10_000,
            // Heartbeat.
            heartbeat_interval_ms: 30_000,
            heartbeat_timeout_ms: 5_000,
        }
    }
}
73
/// An [`McpClient`] wrapped with connection-state tracking, notification handlers and
/// reconnection bookkeeping.
pub struct ClientSession {
    /// The wrapped client, behind an async mutex so it can be shared with spawned tasks.
    client: Arc<Mutex<McpClient>>,
    /// Settings this session was created with.
    config: SessionConfig,
    /// Current session state.
    state: Arc<RwLock<SessionState>>,
    /// Publishes state changes to subscribers.
    state_tx: watch::Sender<SessionState>,
    /// Receiver kept by the session; cloned for each state subscriber.
    state_rx: watch::Receiver<SessionState>,
    /// Handlers registered to receive notifications.
    notification_handlers: Arc<RwLock<Vec<Box<dyn NotificationHandler>>>>,
    /// When the current connection was established; `None` while not connected.
    connected_at: Arc<RwLock<Option<Instant>>>,
    /// Reconnection attempts made since the last successful connect.
    reconnect_attempts: Arc<Mutex<u32>>,
    /// Slot for the shutdown-signal sender; filled when background tasks are started
    /// and taken when they are stopped.
    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
}
95
96impl ClientSession {
97 pub fn new(client: McpClient) -> Self {
99 let (state_tx, state_rx) = watch::channel(SessionState::Disconnected);
100
101 Self {
102 client: Arc::new(Mutex::new(client)),
103 config: SessionConfig::default(),
104 state: Arc::new(RwLock::new(SessionState::Disconnected)),
105 state_tx,
106 state_rx,
107 notification_handlers: Arc::new(RwLock::new(Vec::new())),
108 connected_at: Arc::new(RwLock::new(None)),
109 reconnect_attempts: Arc::new(Mutex::new(0)),
110 shutdown_tx: Arc::new(Mutex::new(None)),
111 }
112 }
113
114 pub fn with_config(client: McpClient, config: SessionConfig) -> Self {
116 let mut session = Self::new(client);
117 session.config = config;
118 session
119 }
120
121 pub async fn state(&self) -> SessionState {
123 let state = self.state.read().await;
124 state.clone()
125 }
126
127 pub fn subscribe_state_changes(&self) -> watch::Receiver<SessionState> {
129 self.state_rx.clone()
130 }
131
132 pub async fn is_connected(&self) -> bool {
134 let state = self.state.read().await;
135 matches!(*state, SessionState::Connected)
136 }
137
138 pub async fn uptime(&self) -> Option<Duration> {
140 let connected_at = self.connected_at.read().await;
141 connected_at.map(|time| time.elapsed())
142 }
143
144 pub async fn add_notification_handler<H>(&self, handler: H)
146 where
147 H: NotificationHandler + 'static,
148 {
149 let mut handlers = self.notification_handlers.write().await;
150 handlers.push(Box::new(handler));
151 }
152
153 pub async fn connect<T>(&self, transport: T) -> McpResult<InitializeResult>
155 where
156 T: Transport + 'static,
157 {
158 self.transition_state(SessionState::Connecting).await?;
159
160 let connect_future = async {
161 let mut client = self.client.lock().await;
162 client.connect(transport).await
163 };
164
165 let result = timeout(
166 Duration::from_millis(self.config.connection_timeout_ms),
167 connect_future,
168 )
169 .await;
170
171 match result {
172 Ok(Ok(init_result)) => {
173 self.transition_state(SessionState::Connected).await?;
174
175 {
177 let mut connected_at = self.connected_at.write().await;
178 *connected_at = Some(Instant::now());
179 }
180
181 {
183 let mut attempts = self.reconnect_attempts.lock().await;
184 *attempts = 0;
185 }
186
187 self.start_background_tasks().await?;
189
190 Ok(init_result)
191 }
192 Ok(Err(error)) => {
193 self.transition_state(SessionState::Failed(error.to_string()))
194 .await?;
195 Err(error)
196 }
197 Err(_) => {
198 let error = McpError::Connection("Connection timeout".to_string());
199 self.transition_state(SessionState::Failed(error.to_string()))
200 .await?;
201 Err(error)
202 }
203 }
204 }
205
206 pub async fn disconnect(&self) -> McpResult<()> {
208 self.stop_background_tasks().await;
210
211 {
213 let client = self.client.lock().await;
214 client.disconnect().await?;
215 }
216
217 self.transition_state(SessionState::Disconnected).await?;
219
220 {
222 let mut connected_at = self.connected_at.write().await;
223 *connected_at = None;
224 }
225
226 Ok(())
227 }
228
229 pub async fn reconnect<T>(
231 &self,
232 transport_factory: impl Fn() -> T,
233 ) -> McpResult<InitializeResult>
234 where
235 T: Transport + 'static,
236 {
237 if !self.config.auto_reconnect {
238 return Err(McpError::Connection(
239 "Auto-reconnect is disabled".to_string(),
240 ));
241 }
242
243 let mut attempts = self.reconnect_attempts.lock().await;
244 if *attempts >= self.config.max_reconnect_attempts {
245 let error = McpError::Connection("Max reconnection attempts exceeded".to_string());
246 self.transition_state(SessionState::Failed(error.to_string()))
247 .await?;
248 return Err(error);
249 }
250
251 *attempts += 1;
252 let current_attempts = *attempts;
253 drop(attempts);
254
255 self.transition_state(SessionState::Reconnecting).await?;
256
257 let delay = std::cmp::min(
259 (self.config.reconnect_delay_ms as f64
260 * self
261 .config
262 .reconnect_backoff
263 .powi(current_attempts as i32 - 1)) as u64,
264 self.config.max_reconnect_delay_ms,
265 );
266
267 sleep(Duration::from_millis(delay)).await;
268
269 self.connect(transport_factory()).await
271 }
272
273 pub fn client(&self) -> Arc<Mutex<McpClient>> {
275 self.client.clone()
276 }
277
278 pub fn config(&self) -> &SessionConfig {
280 &self.config
281 }
282
283 async fn start_background_tasks(&self) -> McpResult<()> {
289 let (shutdown_tx, shutdown_rx): (broadcast::Sender<()>, broadcast::Receiver<()>) =
290 broadcast::channel(16);
291 {
292 let mut shutdown_guard = self.shutdown_tx.lock().await;
293 *shutdown_guard = Some(mpsc::channel(1).0); }
295
296 {
298 let client = self.client.clone();
299 let handlers = self.notification_handlers.clone();
300 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
301
302 tokio::spawn(async move {
303 loop {
304 tokio::select! {
305 _ = shutdown_rx_clone.recv() => break,
306 notification_result = async {
307 let client_guard = client.lock().await;
308 client_guard.receive_notification().await
309 } => {
310 match notification_result {
311 Ok(Some(notification)) => {
312 let handlers_guard = handlers.read().await;
313 for handler in handlers_guard.iter() {
314 handler.handle_notification(notification.clone());
315 }
316 }
317 Ok(None) => {
318 }
320 Err(_) => {
321 break;
323 }
324 }
325 }
326 }
327 }
328 });
329 }
330
331 if self.config.heartbeat_interval_ms > 0 {
333 let client = self.client.clone();
334 let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
335 let heartbeat_timeout = Duration::from_millis(self.config.heartbeat_timeout_ms);
336 let state = self.state.clone();
337 let state_tx = self.state_tx.clone();
338 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
339
340 tokio::spawn(async move {
341 let mut interval = tokio::time::interval(heartbeat_interval);
342
343 loop {
344 tokio::select! {
345 _ = shutdown_rx_clone.recv() => break,
346 _ = interval.tick() => {
347 {
349 let current_state = state.read().await;
350 if !matches!(*current_state, SessionState::Connected) {
351 break;
352 }
353 }
354
355 let ping_result = timeout(heartbeat_timeout, async {
357 let client_guard = client.lock().await;
358 client_guard.ping().await
359 }).await;
360
361 if ping_result.is_err() {
362 let _ = state_tx.send(SessionState::Disconnected);
364 break;
365 }
366 }
367 }
368 }
369 });
370 }
371
372 Ok(())
373 }
374
375 async fn stop_background_tasks(&self) {
377 let shutdown_tx = {
378 let mut shutdown_guard = self.shutdown_tx.lock().await;
379 shutdown_guard.take()
380 };
381
382 if let Some(tx) = shutdown_tx {
383 let _ = tx.send(()).await; }
385 }
386
387 async fn transition_state(&self, new_state: SessionState) -> McpResult<()> {
389 {
390 let mut state = self.state.write().await;
391 *state = new_state.clone();
392 }
393
394 if let Err(_) = self.state_tx.send(new_state) {
396 }
398
399 Ok(())
400 }
401}
402
403pub struct LoggingNotificationHandler;
405
406impl NotificationHandler for LoggingNotificationHandler {
407 fn handle_notification(&self, notification: JsonRpcNotification) {
408 tracing::info!(
409 "Received notification: {} {:?}",
410 notification.method,
411 notification.params
412 );
413 }
414}
415
416pub struct ResourceUpdateHandler {
418 callback: Box<dyn Fn(String) + Send + Sync>,
419}
420
421impl ResourceUpdateHandler {
422 pub fn new<F>(callback: F) -> Self
424 where
425 F: Fn(String) + Send + Sync + 'static,
426 {
427 Self {
428 callback: Box::new(callback),
429 }
430 }
431}
432
433impl NotificationHandler for ResourceUpdateHandler {
434 fn handle_notification(&self, notification: JsonRpcNotification) {
435 if notification.method == methods::RESOURCES_UPDATED {
436 if let Some(params) = notification.params {
437 if let Ok(update_params) = serde_json::from_value::<ResourceUpdatedParams>(params) {
438 (self.callback)(update_params.uri);
439 }
440 }
441 }
442 }
443}
444
445pub struct ToolListChangedHandler {
447 callback: Box<dyn Fn() + Send + Sync>,
448}
449
450impl ToolListChangedHandler {
451 pub fn new<F>(callback: F) -> Self
453 where
454 F: Fn() + Send + Sync + 'static,
455 {
456 Self {
457 callback: Box::new(callback),
458 }
459 }
460}
461
462impl NotificationHandler for ToolListChangedHandler {
463 fn handle_notification(&self, notification: JsonRpcNotification) {
464 if notification.method == methods::TOOLS_LIST_CHANGED {
465 (self.callback)();
466 }
467 }
468}
469
470pub struct ProgressHandler {
472 callback: Box<dyn Fn(String, f32, Option<u32>) + Send + Sync>,
473}
474
475impl ProgressHandler {
476 pub fn new<F>(callback: F) -> Self
478 where
479 F: Fn(String, f32, Option<u32>) + Send + Sync + 'static,
480 {
481 Self {
482 callback: Box::new(callback),
483 }
484 }
485}
486
487impl NotificationHandler for ProgressHandler {
488 fn handle_notification(&self, notification: JsonRpcNotification) {
489 if notification.method == methods::PROGRESS {
490 if let Some(params) = notification.params {
491 if let Ok(progress_params) = serde_json::from_value::<ProgressParams>(params) {
492 (self.callback)(
493 progress_params.progress_token,
494 progress_params.progress,
495 progress_params.total,
496 );
497 }
498 }
499 }
500 }
501}
502
/// Point-in-time summary of a session.
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Session state at the time of the snapshot.
    pub state: SessionState,
    /// Time elapsed since the connection was established, if a timestamp is recorded.
    pub uptime: Option<Duration>,
    /// Value of the reconnection attempt counter.
    pub reconnect_attempts: u32,
    /// When the connection was established, if a timestamp is recorded.
    pub connected_at: Option<Instant>,
}
515
516impl ClientSession {
517 pub async fn stats(&self) -> SessionStats {
519 let state = self.state().await;
520 let uptime = self.uptime().await;
521 let reconnect_attempts = {
522 let attempts = self.reconnect_attempts.lock().await;
523 *attempts
524 };
525 let connected_at = {
526 let connected_at = self.connected_at.read().await;
527 *connected_at
528 };
529
530 SessionStats {
531 state,
532 uptime,
533 reconnect_attempts,
534 connected_at,
535 }
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::mcp_client::McpClient;
    use async_trait::async_trait;

    /// Transport stub: answers every request with a canned initialize result and never
    /// produces notifications.
    struct MockTransport;

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let server_info = ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            };
            let init_result = InitializeResult::new(
                server_info,
                ServerCapabilities::default(),
                MCP_PROTOCOL_VERSION.to_string(),
            );
            JsonRpcResponse::success(serde_json::Value::from(1), init_result)
                .map_err(|e| McpError::Serialization(e))
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// Builds the client every test uses.
    fn test_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    /// Builds a default-configured session around a fresh test client.
    fn test_session() -> ClientSession {
        ClientSession::new(test_client())
    }

    #[tokio::test]
    async fn test_session_creation() {
        let session = test_session();

        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_session_connection() {
        let session = test_session();

        let outcome = session.connect(MockTransport).await;

        assert!(outcome.is_ok());
        assert_eq!(session.state().await, SessionState::Connected);
        assert!(session.is_connected().await);
        assert!(session.uptime().await.is_some());
    }

    #[tokio::test]
    async fn test_session_disconnect() {
        let session = test_session();

        // Bring the session up first.
        session.connect(MockTransport).await.unwrap();
        assert!(session.is_connected().await);

        // Then tear it down and check everything is reset.
        session.disconnect().await.unwrap();
        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_notification_handlers() {
        let session = test_session();

        session
            .add_notification_handler(LoggingNotificationHandler)
            .await;

        let on_resource = ResourceUpdateHandler::new(|uri| {
            println!("Resource updated: {}", uri);
        });
        session.add_notification_handler(on_resource).await;

        let on_tools = ToolListChangedHandler::new(|| {
            println!("Tool list changed");
        });
        session.add_notification_handler(on_tools).await;

        let on_progress = ProgressHandler::new(|token, progress, total| {
            println!("Progress {}: {} / {:?}", token, progress, total);
        });
        session.add_notification_handler(on_progress).await;

        let registered = session.notification_handlers.read().await;
        assert_eq!(registered.len(), 4);
    }

    #[tokio::test]
    async fn test_session_stats() {
        let session = test_session();

        let snapshot = session.stats().await;
        assert_eq!(snapshot.state, SessionState::Disconnected);
        assert!(snapshot.uptime.is_none());
        assert_eq!(snapshot.reconnect_attempts, 0);
        assert!(snapshot.connected_at.is_none());
    }

    #[tokio::test]
    async fn test_session_config() {
        let custom = SessionConfig {
            auto_reconnect: false,
            max_reconnect_attempts: 10,
            reconnect_delay_ms: 2000,
            ..Default::default()
        };
        let session = ClientSession::with_config(test_client(), custom.clone());

        let effective = session.config();
        assert!(!effective.auto_reconnect);
        assert_eq!(effective.max_reconnect_attempts, 10);
        assert_eq!(effective.reconnect_delay_ms, 2000);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let session = test_session();
        let mut updates = session.subscribe_state_changes();

        // The subscriber starts out seeing the initial state.
        assert_eq!(*updates.borrow(), SessionState::Disconnected);

        session
            .transition_state(SessionState::Connecting)
            .await
            .unwrap();

        // The transition must be observable through the subscription.
        updates.changed().await.unwrap();
        assert_eq!(*updates.borrow(), SessionState::Connecting);
    }
}