1use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{broadcast, mpsc, watch, Mutex, RwLock};
9use tokio::time::{sleep, timeout};
10
11use crate::client::mcp_client::McpClient;
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, types::*};
14use crate::transport::traits::Transport;
15
/// Lifecycle state of a client session.
///
/// The value is stored inside the session and also published on a
/// `tokio::sync::watch` channel so observers can react to transitions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// No connection is established.
    Disconnected,
    /// An initial connection attempt is in progress.
    Connecting,
    /// The connection handshake completed successfully.
    Connected,
    /// A reconnection attempt is in progress.
    Reconnecting,
    /// The last connection attempt failed; the payload is the error text.
    Failed(String),
}
30
31pub trait NotificationHandler: Send + Sync {
33 fn handle_notification(&self, notification: JsonRpcNotification);
35}
36
/// Tunables controlling connection, reconnection and heartbeat behaviour.
///
/// All durations are expressed in milliseconds.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// When `false`, `reconnect` refuses to run and returns an error.
    pub auto_reconnect: bool,
    /// Upper bound on reconnection attempts before the session is marked failed.
    pub max_reconnect_attempts: u32,
    /// Delay before the first reconnection attempt.
    pub reconnect_delay_ms: u64,
    /// Cap applied to the delay after back-off has been applied.
    pub max_reconnect_delay_ms: u64,
    /// Multiplier applied to the delay for each additional attempt.
    pub reconnect_backoff: f64,
    /// Maximum time allowed for the connection handshake.
    pub connection_timeout_ms: u64,
    /// Period between heartbeat pings; `0` disables the heartbeat task.
    pub heartbeat_interval_ms: u64,
    /// Maximum time to wait for a single heartbeat ping.
    pub heartbeat_timeout_ms: u64,
}

impl Default for SessionConfig {
    /// Auto-reconnect enabled, 5 attempts, 1 s initial delay doubling up to
    /// 30 s, 10 s connection timeout, heartbeat every 30 s with a 5 s timeout.
    fn default() -> Self {
        SessionConfig {
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_delay_ms: 1_000,
            max_reconnect_delay_ms: 30_000,
            reconnect_backoff: 2.0,
            connection_timeout_ms: 10_000,
            heartbeat_interval_ms: 30_000,
            heartbeat_timeout_ms: 5_000,
        }
    }
}
72
73pub struct ClientSession {
75 client: Arc<Mutex<McpClient>>,
77 config: SessionConfig,
79 state: Arc<RwLock<SessionState>>,
81 state_tx: watch::Sender<SessionState>,
83 state_rx: watch::Receiver<SessionState>,
85 notification_handlers: Arc<RwLock<Vec<Box<dyn NotificationHandler>>>>,
87 connected_at: Arc<RwLock<Option<Instant>>>,
89 reconnect_attempts: Arc<Mutex<u32>>,
91 shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
93}
94
95impl ClientSession {
96 pub fn new(client: McpClient) -> Self {
98 let (state_tx, state_rx) = watch::channel(SessionState::Disconnected);
99
100 Self {
101 client: Arc::new(Mutex::new(client)),
102 config: SessionConfig::default(),
103 state: Arc::new(RwLock::new(SessionState::Disconnected)),
104 state_tx,
105 state_rx,
106 notification_handlers: Arc::new(RwLock::new(Vec::new())),
107 connected_at: Arc::new(RwLock::new(None)),
108 reconnect_attempts: Arc::new(Mutex::new(0)),
109 shutdown_tx: Arc::new(Mutex::new(None)),
110 }
111 }
112
113 pub fn with_config(client: McpClient, config: SessionConfig) -> Self {
115 let mut session = Self::new(client);
116 session.config = config;
117 session
118 }
119
120 pub async fn state(&self) -> SessionState {
122 let state = self.state.read().await;
123 state.clone()
124 }
125
126 pub fn subscribe_state_changes(&self) -> watch::Receiver<SessionState> {
128 self.state_rx.clone()
129 }
130
131 pub async fn is_connected(&self) -> bool {
133 let state = self.state.read().await;
134 matches!(*state, SessionState::Connected)
135 }
136
137 pub async fn uptime(&self) -> Option<Duration> {
139 let connected_at = self.connected_at.read().await;
140 connected_at.map(|time| time.elapsed())
141 }
142
143 pub async fn add_notification_handler<H>(&self, handler: H)
145 where
146 H: NotificationHandler + 'static,
147 {
148 let mut handlers = self.notification_handlers.write().await;
149 handlers.push(Box::new(handler));
150 }
151
152 pub async fn connect<T>(&self, transport: T) -> McpResult<InitializeResult>
154 where
155 T: Transport + 'static,
156 {
157 self.transition_state(SessionState::Connecting).await?;
158
159 let connect_future = async {
160 let mut client = self.client.lock().await;
161 client.connect(transport).await
162 };
163
164 let result = timeout(
165 Duration::from_millis(self.config.connection_timeout_ms),
166 connect_future,
167 )
168 .await;
169
170 match result {
171 Ok(Ok(init_result)) => {
172 self.transition_state(SessionState::Connected).await?;
173
174 {
176 let mut connected_at = self.connected_at.write().await;
177 *connected_at = Some(Instant::now());
178 }
179
180 {
182 let mut attempts = self.reconnect_attempts.lock().await;
183 *attempts = 0;
184 }
185
186 self.start_background_tasks().await?;
188
189 Ok(init_result)
190 }
191 Ok(Err(error)) => {
192 self.transition_state(SessionState::Failed(error.to_string()))
193 .await?;
194 Err(error)
195 }
196 Err(_) => {
197 let error = McpError::Connection("Connection timeout".to_string());
198 self.transition_state(SessionState::Failed(error.to_string()))
199 .await?;
200 Err(error)
201 }
202 }
203 }
204
205 pub async fn disconnect(&self) -> McpResult<()> {
207 self.stop_background_tasks().await;
209
210 {
212 let client = self.client.lock().await;
213 client.disconnect().await?;
214 }
215
216 self.transition_state(SessionState::Disconnected).await?;
218
219 {
221 let mut connected_at = self.connected_at.write().await;
222 *connected_at = None;
223 }
224
225 Ok(())
226 }
227
228 pub async fn reconnect<T>(
230 &self,
231 transport_factory: impl Fn() -> T,
232 ) -> McpResult<InitializeResult>
233 where
234 T: Transport + 'static,
235 {
236 if !self.config.auto_reconnect {
237 return Err(McpError::Connection(
238 "Auto-reconnect is disabled".to_string(),
239 ));
240 }
241
242 let mut attempts = self.reconnect_attempts.lock().await;
243 if *attempts >= self.config.max_reconnect_attempts {
244 let error = McpError::Connection("Max reconnection attempts exceeded".to_string());
245 self.transition_state(SessionState::Failed(error.to_string()))
246 .await?;
247 return Err(error);
248 }
249
250 *attempts += 1;
251 let current_attempts = *attempts;
252 drop(attempts);
253
254 self.transition_state(SessionState::Reconnecting).await?;
255
256 let delay = std::cmp::min(
258 (self.config.reconnect_delay_ms as f64
259 * self
260 .config
261 .reconnect_backoff
262 .powi(current_attempts as i32 - 1)) as u64,
263 self.config.max_reconnect_delay_ms,
264 );
265
266 sleep(Duration::from_millis(delay)).await;
267
268 self.connect(transport_factory()).await
270 }
271
272 pub fn client(&self) -> Arc<Mutex<McpClient>> {
274 self.client.clone()
275 }
276
277 pub fn config(&self) -> &SessionConfig {
279 &self.config
280 }
281
282 async fn start_background_tasks(&self) -> McpResult<()> {
288 let (_shutdown_tx, shutdown_rx): (broadcast::Sender<()>, broadcast::Receiver<()>) =
289 broadcast::channel(16);
290 {
291 let mut shutdown_guard = self.shutdown_tx.lock().await;
292 *shutdown_guard = Some(mpsc::channel(1).0); }
294
295 {
297 let client = self.client.clone();
298 let handlers = self.notification_handlers.clone();
299 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
300
301 tokio::spawn(async move {
302 loop {
303 tokio::select! {
304 _ = shutdown_rx_clone.recv() => break,
305 notification_result = async {
306 let client_guard = client.lock().await;
307 client_guard.receive_notification().await
308 } => {
309 match notification_result {
310 Ok(Some(notification)) => {
311 let handlers_guard = handlers.read().await;
312 for handler in handlers_guard.iter() {
313 handler.handle_notification(notification.clone());
314 }
315 }
316 Ok(None) => {
317 }
319 Err(_) => {
320 break;
322 }
323 }
324 }
325 }
326 }
327 });
328 }
329
330 if self.config.heartbeat_interval_ms > 0 {
332 let client = self.client.clone();
333 let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
334 let heartbeat_timeout = Duration::from_millis(self.config.heartbeat_timeout_ms);
335 let state = self.state.clone();
336 let state_tx = self.state_tx.clone();
337 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
338
339 tokio::spawn(async move {
340 let mut interval = tokio::time::interval(heartbeat_interval);
341
342 loop {
343 tokio::select! {
344 _ = shutdown_rx_clone.recv() => break,
345 _ = interval.tick() => {
346 {
348 let current_state = state.read().await;
349 if !matches!(*current_state, SessionState::Connected) {
350 break;
351 }
352 }
353
354 let ping_result = timeout(heartbeat_timeout, async {
356 let client_guard = client.lock().await;
357 client_guard.ping().await
358 }).await;
359
360 if ping_result.is_err() {
361 let _ = state_tx.send(SessionState::Disconnected);
363 break;
364 }
365 }
366 }
367 }
368 });
369 }
370
371 Ok(())
372 }
373
374 async fn stop_background_tasks(&self) {
376 let shutdown_tx = {
377 let mut shutdown_guard = self.shutdown_tx.lock().await;
378 shutdown_guard.take()
379 };
380
381 if let Some(tx) = shutdown_tx {
382 let _ = tx.send(()).await; }
384 }
385
386 async fn transition_state(&self, new_state: SessionState) -> McpResult<()> {
388 {
389 let mut state = self.state.write().await;
390 *state = new_state.clone();
391 }
392
393 if self.state_tx.send(new_state).is_err() {
395 }
397
398 Ok(())
399 }
400}
401
402pub struct LoggingNotificationHandler;
404
405impl NotificationHandler for LoggingNotificationHandler {
406 fn handle_notification(&self, notification: JsonRpcNotification) {
407 tracing::info!(
408 "Received notification: {} {:?}",
409 notification.method,
410 notification.params
411 );
412 }
413}
414
415pub struct ResourceUpdateHandler {
417 callback: Box<dyn Fn(String) + Send + Sync>,
418}
419
420impl ResourceUpdateHandler {
421 pub fn new<F>(callback: F) -> Self
423 where
424 F: Fn(String) + Send + Sync + 'static,
425 {
426 Self {
427 callback: Box::new(callback),
428 }
429 }
430}
431
432impl NotificationHandler for ResourceUpdateHandler {
433 fn handle_notification(&self, notification: JsonRpcNotification) {
434 if notification.method == methods::RESOURCES_UPDATED {
435 if let Some(params) = notification.params {
436 if let Ok(update_params) = serde_json::from_value::<ResourceUpdatedParams>(params) {
437 (self.callback)(update_params.uri);
438 }
439 }
440 }
441 }
442}
443
444pub struct ToolListChangedHandler {
446 callback: Box<dyn Fn() + Send + Sync>,
447}
448
449impl ToolListChangedHandler {
450 pub fn new<F>(callback: F) -> Self
452 where
453 F: Fn() + Send + Sync + 'static,
454 {
455 Self {
456 callback: Box::new(callback),
457 }
458 }
459}
460
461impl NotificationHandler for ToolListChangedHandler {
462 fn handle_notification(&self, notification: JsonRpcNotification) {
463 if notification.method == methods::TOOLS_LIST_CHANGED {
464 (self.callback)();
465 }
466 }
467}
468
469pub struct ProgressHandler {
471 callback: Box<dyn Fn(String, f32, Option<u32>) + Send + Sync>,
472}
473
474impl ProgressHandler {
475 pub fn new<F>(callback: F) -> Self
477 where
478 F: Fn(String, f32, Option<u32>) + Send + Sync + 'static,
479 {
480 Self {
481 callback: Box::new(callback),
482 }
483 }
484}
485
486impl NotificationHandler for ProgressHandler {
487 fn handle_notification(&self, notification: JsonRpcNotification) {
488 if notification.method == methods::PROGRESS {
489 if let Some(params) = notification.params {
490 if let Ok(progress_params) = serde_json::from_value::<ProgressParams>(params) {
491 (self.callback)(
492 progress_params.progress_token,
493 progress_params.progress,
494 progress_params.total,
495 );
496 }
497 }
498 }
499 }
500}
501
502#[derive(Debug, Clone)]
504pub struct SessionStats {
505 pub state: SessionState,
507 pub uptime: Option<Duration>,
509 pub reconnect_attempts: u32,
511 pub connected_at: Option<Instant>,
513}
514
515impl ClientSession {
516 pub async fn stats(&self) -> SessionStats {
518 let state = self.state().await;
519 let uptime = self.uptime().await;
520 let reconnect_attempts = {
521 let attempts = self.reconnect_attempts.lock().await;
522 *attempts
523 };
524 let connected_at = {
525 let connected_at = self.connected_at.read().await;
526 *connected_at
527 };
528
529 SessionStats {
530 state,
531 uptime,
532 reconnect_attempts,
533 connected_at,
534 }
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::mcp_client::McpClient;
    use async_trait::async_trait;

    /// Transport double: answers every request with a canned initialize
    /// result and never produces notifications.
    struct MockTransport;

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let server = ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            };
            let payload = InitializeResult::new(
                server,
                ServerCapabilities::default(),
                MCP_PROTOCOL_VERSION.to_string(),
            );

            JsonRpcResponse::success(serde_json::Value::from(1), payload)
                .map_err(McpError::Serialization)
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// Builds the client used by every test.
    fn new_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    /// Builds a default-configured session around a fresh test client.
    fn new_session() -> ClientSession {
        ClientSession::new(new_client())
    }

    #[tokio::test]
    async fn test_session_creation() {
        let session = new_session();

        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_session_connection() {
        let session = new_session();

        let outcome = session.connect(MockTransport).await;

        assert!(outcome.is_ok());
        assert_eq!(session.state().await, SessionState::Connected);
        assert!(session.is_connected().await);
        assert!(session.uptime().await.is_some());
    }

    #[tokio::test]
    async fn test_session_disconnect() {
        let session = new_session();

        // Establish the connection first.
        session.connect(MockTransport).await.unwrap();
        assert!(session.is_connected().await);

        // Then tear it down and verify everything is reset.
        session.disconnect().await.unwrap();
        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_notification_handlers() {
        let session = new_session();

        session
            .add_notification_handler(LoggingNotificationHandler)
            .await;

        let on_resource = ResourceUpdateHandler::new(|uri| {
            println!("Resource updated: {}", uri);
        });
        session.add_notification_handler(on_resource).await;

        let on_tools = ToolListChangedHandler::new(|| {
            println!("Tool list changed");
        });
        session.add_notification_handler(on_tools).await;

        let on_progress = ProgressHandler::new(|token, progress, total| {
            println!("Progress {}: {} / {:?}", token, progress, total);
        });
        session.add_notification_handler(on_progress).await;

        let registered = session.notification_handlers.read().await;
        assert_eq!(registered.len(), 4);
    }

    #[tokio::test]
    async fn test_session_stats() {
        let session = new_session();

        let stats = session.stats().await;
        assert_eq!(stats.state, SessionState::Disconnected);
        assert!(stats.uptime.is_none());
        assert_eq!(stats.reconnect_attempts, 0);
        assert!(stats.connected_at.is_none());
    }

    #[tokio::test]
    async fn test_session_config() {
        let config = SessionConfig {
            auto_reconnect: false,
            max_reconnect_attempts: 10,
            reconnect_delay_ms: 2000,
            ..Default::default()
        };
        let session = ClientSession::with_config(new_client(), config.clone());

        let applied = session.config();
        assert!(!applied.auto_reconnect);
        assert_eq!(applied.max_reconnect_attempts, 10);
        assert_eq!(applied.reconnect_delay_ms, 2000);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let session = new_session();
        let mut state_rx = session.subscribe_state_changes();

        // A new subscriber sees the initial state straight away.
        assert_eq!(*state_rx.borrow(), SessionState::Disconnected);

        session
            .transition_state(SessionState::Connecting)
            .await
            .unwrap();

        // The transition must be observable through the watch channel.
        state_rx.changed().await.unwrap();
        assert_eq!(*state_rx.borrow(), SessionState::Connecting);
    }
}