1use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{Mutex, RwLock, broadcast, mpsc, watch};
9use tokio::time::{sleep, timeout};
10
11use crate::client::mcp_client::McpClient;
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, methods, types::*};
14use crate::transport::traits::Transport;
15
/// Lifecycle state of a [`ClientSession`].
///
/// The current value is returned by `ClientSession::state`, and every transition is also
/// published on the watch channel returned by `ClientSession::subscribe_state_changes`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Not entered by `ClientSession` in this module; presumably reserved for callers — confirm.
    Idle,
    /// Initial state; also published after `disconnect` and when a heartbeat check fails.
    Disconnected,
    /// A `connect` call is in progress.
    Connecting,
    /// The last `connect` call completed successfully.
    Connected,
    /// `reconnect` is waiting out its back-off delay before connecting again.
    Reconnecting,
    /// A connection attempt failed or timed out, or `reconnect` ran out of attempts.
    /// Carries the error message.
    Failed(String),
}
32
33pub trait NotificationHandler: Send + Sync {
35 fn handle_notification(&self, notification: JsonRpcNotification);
37}
38
/// Tuning knobs for a [`ClientSession`].
///
/// Fields ending in `_ms` are plain millisecond counts; the remaining timing fields are
/// [`Duration`]s.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// When `false`, `ClientSession::reconnect` refuses to run and returns an error.
    pub auto_reconnect: bool,
    /// Number of `reconnect` calls allowed before the session is marked failed.
    pub max_reconnect_attempts: u32,
    /// Delay before the first reconnect attempt, in milliseconds.
    pub reconnect_delay_ms: u64,
    /// Upper bound on the back-off delay between reconnect attempts, in milliseconds.
    pub max_reconnect_delay_ms: u64,
    /// Multiplier applied to the delay for each further reconnect attempt.
    pub reconnect_backoff: f64,
    /// Maximum time `connect` may take before it fails with a timeout, in milliseconds.
    pub connection_timeout_ms: u64,
    /// Interval between heartbeat pings, in milliseconds; `0` disables the heartbeat.
    pub heartbeat_interval_ms: u64,
    /// Time a single heartbeat ping may take before the session is considered down, in ms.
    pub heartbeat_timeout_ms: u64,

    /// NOTE(review): not read by `ClientSession` in this module; presumably consumed elsewhere.
    pub session_timeout: Duration,
    /// NOTE(review): not read by `ClientSession` in this module; presumably consumed elsewhere.
    pub request_timeout: Duration,
    /// NOTE(review): not read by `ClientSession` in this module; presumably consumed elsewhere.
    pub max_concurrent_requests: u32,
    /// NOTE(review): not read by `ClientSession` in this module; presumably consumed elsewhere.
    pub enable_compression: bool,
    /// NOTE(review): not read by `ClientSession` in this module; presumably a byte count — confirm.
    pub buffer_size: usize,
}
71
72impl Default for SessionConfig {
73 fn default() -> Self {
74 Self {
75 auto_reconnect: true,
76 max_reconnect_attempts: 5,
77 reconnect_delay_ms: 1000,
78 max_reconnect_delay_ms: 30000,
79 reconnect_backoff: 2.0,
80 connection_timeout_ms: 10000,
81 heartbeat_interval_ms: 30000,
82 heartbeat_timeout_ms: 5000,
83 session_timeout: Duration::from_secs(300),
84 request_timeout: Duration::from_secs(30),
85 max_concurrent_requests: 10,
86 enable_compression: false,
87 buffer_size: 8192,
88 }
89 }
90}
91
92pub struct ClientSession {
94 client: Arc<Mutex<McpClient>>,
96 config: SessionConfig,
98 state: Arc<RwLock<SessionState>>,
100 state_tx: watch::Sender<SessionState>,
102 state_rx: watch::Receiver<SessionState>,
104 notification_handlers: Arc<RwLock<Vec<Box<dyn NotificationHandler>>>>,
106 connected_at: Arc<RwLock<Option<Instant>>>,
108 reconnect_attempts: Arc<Mutex<u32>>,
110 shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
112}
113
114impl ClientSession {
115 pub fn new(client: McpClient) -> Self {
117 let (state_tx, state_rx) = watch::channel(SessionState::Disconnected);
118
119 Self {
120 client: Arc::new(Mutex::new(client)),
121 config: SessionConfig::default(),
122 state: Arc::new(RwLock::new(SessionState::Disconnected)),
123 state_tx,
124 state_rx,
125 notification_handlers: Arc::new(RwLock::new(Vec::new())),
126 connected_at: Arc::new(RwLock::new(None)),
127 reconnect_attempts: Arc::new(Mutex::new(0)),
128 shutdown_tx: Arc::new(Mutex::new(None)),
129 }
130 }
131
132 pub fn with_config(client: McpClient, config: SessionConfig) -> Self {
134 let mut session = Self::new(client);
135 session.config = config;
136 session
137 }
138
139 pub async fn state(&self) -> SessionState {
141 let state = self.state.read().await;
142 state.clone()
143 }
144
145 pub fn subscribe_state_changes(&self) -> watch::Receiver<SessionState> {
147 self.state_rx.clone()
148 }
149
150 pub async fn is_connected(&self) -> bool {
152 let state = self.state.read().await;
153 matches!(*state, SessionState::Connected)
154 }
155
156 pub async fn uptime(&self) -> Option<Duration> {
158 let connected_at = self.connected_at.read().await;
159 connected_at.map(|time| time.elapsed())
160 }
161
162 pub async fn add_notification_handler<H>(&self, handler: H)
164 where
165 H: NotificationHandler + 'static,
166 {
167 let mut handlers = self.notification_handlers.write().await;
168 handlers.push(Box::new(handler));
169 }
170
171 pub async fn connect<T>(&self, transport: T) -> McpResult<InitializeResult>
173 where
174 T: Transport + 'static,
175 {
176 self.transition_state(SessionState::Connecting).await?;
177
178 let connect_future = async {
179 let mut client = self.client.lock().await;
180 client.connect(transport).await
181 };
182
183 let result = timeout(
184 Duration::from_millis(self.config.connection_timeout_ms),
185 connect_future,
186 )
187 .await;
188
189 match result {
190 Ok(Ok(init_result)) => {
191 self.transition_state(SessionState::Connected).await?;
192
193 {
195 let mut connected_at = self.connected_at.write().await;
196 *connected_at = Some(Instant::now());
197 }
198
199 {
201 let mut attempts = self.reconnect_attempts.lock().await;
202 *attempts = 0;
203 }
204
205 self.start_background_tasks().await?;
207
208 Ok(init_result)
209 }
210 Ok(Err(error)) => {
211 self.transition_state(SessionState::Failed(error.to_string()))
212 .await?;
213 Err(error)
214 }
215 Err(_) => {
216 let error = McpError::Connection("Connection timeout".to_string());
217 self.transition_state(SessionState::Failed(error.to_string()))
218 .await?;
219 Err(error)
220 }
221 }
222 }
223
224 pub async fn disconnect(&self) -> McpResult<()> {
226 self.stop_background_tasks().await;
228
229 {
231 let client = self.client.lock().await;
232 client.disconnect().await?;
233 }
234
235 self.transition_state(SessionState::Disconnected).await?;
237
238 {
240 let mut connected_at = self.connected_at.write().await;
241 *connected_at = None;
242 }
243
244 Ok(())
245 }
246
247 pub async fn reconnect<T>(
249 &self,
250 transport_factory: impl Fn() -> T,
251 ) -> McpResult<InitializeResult>
252 where
253 T: Transport + 'static,
254 {
255 if !self.config.auto_reconnect {
256 return Err(McpError::Connection(
257 "Auto-reconnect is disabled".to_string(),
258 ));
259 }
260
261 let mut attempts = self.reconnect_attempts.lock().await;
262 if *attempts >= self.config.max_reconnect_attempts {
263 let error = McpError::Connection("Max reconnection attempts exceeded".to_string());
264 self.transition_state(SessionState::Failed(error.to_string()))
265 .await?;
266 return Err(error);
267 }
268
269 *attempts += 1;
270 let current_attempts = *attempts;
271 drop(attempts);
272
273 self.transition_state(SessionState::Reconnecting).await?;
274
275 let delay = std::cmp::min(
277 (self.config.reconnect_delay_ms as f64
278 * self
279 .config
280 .reconnect_backoff
281 .powi(current_attempts as i32 - 1)) as u64,
282 self.config.max_reconnect_delay_ms,
283 );
284
285 sleep(Duration::from_millis(delay)).await;
286
287 self.connect(transport_factory()).await
289 }
290
291 pub fn client(&self) -> Arc<Mutex<McpClient>> {
293 self.client.clone()
294 }
295
296 pub fn config(&self) -> &SessionConfig {
298 &self.config
299 }
300
301 async fn start_background_tasks(&self) -> McpResult<()> {
307 let (_shutdown_tx, shutdown_rx): (broadcast::Sender<()>, broadcast::Receiver<()>) =
308 broadcast::channel(16);
309 {
310 let mut shutdown_guard = self.shutdown_tx.lock().await;
311 *shutdown_guard = Some(mpsc::channel(1).0); }
313
314 {
316 let client = self.client.clone();
317 let handlers = self.notification_handlers.clone();
318 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
319
320 tokio::spawn(async move {
321 loop {
322 tokio::select! {
323 _ = shutdown_rx_clone.recv() => break,
324 notification_result = async {
325 let client_guard = client.lock().await;
326 client_guard.receive_notification().await
327 } => {
328 match notification_result {
329 Ok(Some(notification)) => {
330 let handlers_guard = handlers.read().await;
331 for handler in handlers_guard.iter() {
332 handler.handle_notification(notification.clone());
333 }
334 }
335 Ok(None) => {
336 }
338 Err(_) => {
339 break;
341 }
342 }
343 }
344 }
345 }
346 });
347 }
348
349 if self.config.heartbeat_interval_ms > 0 {
351 let client = self.client.clone();
352 let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
353 let heartbeat_timeout = Duration::from_millis(self.config.heartbeat_timeout_ms);
354 let state = self.state.clone();
355 let state_tx = self.state_tx.clone();
356 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
357
358 tokio::spawn(async move {
359 let mut interval = tokio::time::interval(heartbeat_interval);
360
361 loop {
362 tokio::select! {
363 _ = shutdown_rx_clone.recv() => break,
364 _ = interval.tick() => {
365 {
367 let current_state = state.read().await;
368 if !matches!(*current_state, SessionState::Connected) {
369 break;
370 }
371 }
372
373 let ping_result = timeout(heartbeat_timeout, async {
375 let client_guard = client.lock().await;
376 client_guard.ping().await
377 }).await;
378
379 if ping_result.is_err() {
380 let _ = state_tx.send(SessionState::Disconnected);
382 break;
383 }
384 }
385 }
386 }
387 });
388 }
389
390 Ok(())
391 }
392
393 async fn stop_background_tasks(&self) {
395 let shutdown_tx = {
396 let mut shutdown_guard = self.shutdown_tx.lock().await;
397 shutdown_guard.take()
398 };
399
400 if let Some(tx) = shutdown_tx {
401 let _ = tx.send(()).await; }
403 }
404
405 async fn transition_state(&self, new_state: SessionState) -> McpResult<()> {
407 {
408 let mut state = self.state.write().await;
409 *state = new_state.clone();
410 }
411
412 if self.state_tx.send(new_state).is_err() {
414 }
416
417 Ok(())
418 }
419}
420
/// Notification handler that logs each notification's method and params at `info` level
/// through `tracing`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct LoggingNotificationHandler;
423
424impl NotificationHandler for LoggingNotificationHandler {
425 fn handle_notification(&self, notification: JsonRpcNotification) {
426 tracing::info!(
427 "Received notification: {} {:?}",
428 notification.method,
429 notification.params
430 );
431 }
432}
433
/// Notification handler that reports the URI of each updated resource to a callback.
///
/// Build one with `ResourceUpdateHandler::new`. Only `RESOURCES_UPDATED` notifications reach
/// the callback.
pub struct ResourceUpdateHandler {
    /// Receives the `uri` field of each successfully parsed update.
    callback: Box<dyn Fn(String) + Send + Sync>,
}
438
439impl ResourceUpdateHandler {
440 pub fn new<F>(callback: F) -> Self
442 where
443 F: Fn(String) + Send + Sync + 'static,
444 {
445 Self {
446 callback: Box::new(callback),
447 }
448 }
449}
450
451impl NotificationHandler for ResourceUpdateHandler {
452 fn handle_notification(&self, notification: JsonRpcNotification) {
453 if notification.method == methods::RESOURCES_UPDATED {
454 if let Some(params) = notification.params {
455 if let Ok(update_params) = serde_json::from_value::<ResourceUpdatedParams>(params) {
456 (self.callback)(update_params.uri);
457 }
458 }
459 }
460 }
461}
462
/// Notification handler that fires a callback whenever the server's tool list changes.
///
/// Build one with `ToolListChangedHandler::new`.
pub struct ToolListChangedHandler {
    /// Invoked once per `TOOLS_LIST_CHANGED` notification.
    callback: Box<dyn Fn() + Send + Sync>,
}
467
468impl ToolListChangedHandler {
469 pub fn new<F>(callback: F) -> Self
471 where
472 F: Fn() + Send + Sync + 'static,
473 {
474 Self {
475 callback: Box::new(callback),
476 }
477 }
478}
479
480impl NotificationHandler for ToolListChangedHandler {
481 fn handle_notification(&self, notification: JsonRpcNotification) {
482 if notification.method == methods::TOOLS_LIST_CHANGED {
483 (self.callback)();
484 }
485 }
486}
487
/// Notification handler that forwards progress reports to a callback.
///
/// Build one with `ProgressHandler::new`. The callback receives
/// `(progress_token, progress, total)`.
pub struct ProgressHandler {
    /// Invoked once per successfully parsed `PROGRESS` notification.
    callback: Box<dyn Fn(String, f32, Option<u32>) + Send + Sync>,
}
492
493impl ProgressHandler {
494 pub fn new<F>(callback: F) -> Self
496 where
497 F: Fn(String, f32, Option<u32>) + Send + Sync + 'static,
498 {
499 Self {
500 callback: Box::new(callback),
501 }
502 }
503}
504
505impl NotificationHandler for ProgressHandler {
506 fn handle_notification(&self, notification: JsonRpcNotification) {
507 if notification.method == methods::PROGRESS {
508 if let Some(params) = notification.params {
509 if let Ok(progress_params) = serde_json::from_value::<ProgressParams>(params) {
510 (self.callback)(
511 progress_params.progress_token.to_string(),
512 progress_params.progress,
513 progress_params.total.map(|t| t as u32),
514 );
515 }
516 }
517 }
518 }
519}
520
521#[derive(Debug, Clone)]
523pub struct SessionStats {
524 pub state: SessionState,
526 pub uptime: Option<Duration>,
528 pub reconnect_attempts: u32,
530 pub connected_at: Option<Instant>,
532}
533
534impl ClientSession {
535 pub async fn stats(&self) -> SessionStats {
537 let state = self.state().await;
538 let uptime = self.uptime().await;
539 let reconnect_attempts = {
540 let attempts = self.reconnect_attempts.lock().await;
541 *attempts
542 };
543 let connected_at = {
544 let connected_at = self.connected_at.read().await;
545 *connected_at
546 };
547
548 SessionStats {
549 state,
550 uptime,
551 reconnect_attempts,
552 connected_at,
553 }
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::mcp_client::McpClient;
    use async_trait::async_trait;

    /// Transport stub. It answers every request with the same canned successful
    /// `InitializeResult`, accepts all outgoing notifications, never has an incoming
    /// notification, and closes without error.
    struct MockTransport;

    #[async_trait]
    impl Transport for MockTransport {
        // Every request gets the same success response (id 1) wrapping a test server's
        // `InitializeResult`, regardless of the request's method.
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let init_result = InitializeResult::new(
                crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
                ServerCapabilities::default(),
                ServerInfo {
                    name: "test-server".to_string(),
                    version: "1.0.0".to_string(),
                    title: Some("Test Server".to_string()),
                },
            );
            JsonRpcResponse::success(serde_json::Value::from(1), init_result)
                .map_err(|e| McpError::Serialization(e.to_string()))
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        // Always reports "nothing pending".
        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    // A fresh session is disconnected and has no recorded uptime.
    #[tokio::test]
    async fn test_session_creation() {
        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
        let session = ClientSession::new(client);

        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    // A successful connect moves the session to `Connected` and starts the uptime clock.
    #[tokio::test]
    async fn test_session_connection() {
        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
        let session = ClientSession::new(client);

        let transport = MockTransport;
        let result = session.connect(transport).await;

        assert!(result.is_ok());
        assert_eq!(session.state().await, SessionState::Connected);
        assert!(session.is_connected().await);
        assert!(session.uptime().await.is_some());
    }

    // Disconnect after a successful connect returns to `Disconnected` and clears uptime.
    #[tokio::test]
    async fn test_session_disconnect() {
        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
        let session = ClientSession::new(client);

        let transport = MockTransport;
        session.connect(transport).await.unwrap();
        assert!(session.is_connected().await);

        session.disconnect().await.unwrap();
        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    // Each built-in handler type can be registered; only the registration count is checked,
    // no notifications are dispatched here.
    #[tokio::test]
    async fn test_notification_handlers() {
        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
        let session = ClientSession::new(client);

        session
            .add_notification_handler(LoggingNotificationHandler)
            .await;

        session
            .add_notification_handler(ResourceUpdateHandler::new(|uri| {
                println!("Resource updated: {uri}");
            }))
            .await;

        session
            .add_notification_handler(ToolListChangedHandler::new(|| {
                println!("Tool list changed");
            }))
            .await;

        session
            .add_notification_handler(ProgressHandler::new(|token, progress, total| {
                println!("Progress {token}: {progress} / {total:?}");
            }))
            .await;

        let handlers = session.notification_handlers.read().await;
        assert_eq!(handlers.len(), 4);
    }

    // Stats of a never-connected session report the initial values.
    #[tokio::test]
    async fn test_session_stats() {
        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
        let session = ClientSession::new(client);

        let stats = session.stats().await;
        assert_eq!(stats.state, SessionState::Disconnected);
        assert!(stats.uptime.is_none());
        assert_eq!(stats.reconnect_attempts, 0);
        assert!(stats.connected_at.is_none());
    }

    // A custom config passed to `with_config` is exposed unchanged through `config()`.
    #[tokio::test]
    async fn test_session_config() {
        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
        let config = SessionConfig {
            auto_reconnect: false,
            max_reconnect_attempts: 10,
            reconnect_delay_ms: 2000,
            ..Default::default()
        };
        let session = ClientSession::with_config(client, config.clone());

        assert!(!session.config().auto_reconnect);
        assert_eq!(session.config().max_reconnect_attempts, 10);
        assert_eq!(session.config().reconnect_delay_ms, 2000);
    }

    // Subscribers see the initial state immediately and are woken by a transition.
    #[tokio::test]
    async fn test_state_subscription() {
        let client = McpClient::new("test-client".to_string(), "1.0.0".to_string());
        let session = ClientSession::new(client);

        let mut state_rx = session.subscribe_state_changes();

        assert_eq!(*state_rx.borrow(), SessionState::Disconnected);

        session
            .transition_state(SessionState::Connecting)
            .await
            .unwrap();

        state_rx.changed().await.unwrap();
        assert_eq!(*state_rx.borrow(), SessionState::Connecting);
    }
}