1use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{broadcast, mpsc, watch, Mutex, RwLock};
9use tokio::time::{sleep, timeout};
10
11use crate::client::mcp_client::McpClient;
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, methods, types::*};
14use crate::transport::traits::Transport;
15
/// Lifecycle state of a [`ClientSession`].
///
/// The current value is available from [`ClientSession::state`], and every
/// transition is also published on the watch channel returned by
/// [`ClientSession::subscribe_state_changes`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Idle; not entered by any transition in this module.
    Idle,
    /// No active connection. This is the initial state of a new session.
    Disconnected,
    /// A `connect` call is in progress.
    Connecting,
    /// The client is connected; background tasks are started on entering this state.
    Connected,
    /// Waiting out the backoff delay before a reconnection attempt.
    Reconnecting,
    /// An operation failed; carries the error message.
    Failed(String),
}
32
33pub trait NotificationHandler: Send + Sync {
35 fn handle_notification(&self, notification: JsonRpcNotification);
37}
38
/// Tuning knobs for a [`ClientSession`].
///
/// Build one with struct-update syntax over [`Default`], e.g.
/// `SessionConfig { auto_reconnect: false, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// When `false`, `ClientSession::reconnect` returns an error immediately.
    pub auto_reconnect: bool,
    /// Upper bound on reconnect attempts since the last successful connect.
    pub max_reconnect_attempts: u32,
    /// Base delay (ms) before the first reconnect attempt.
    pub reconnect_delay_ms: u64,
    /// Cap (ms) applied to the exponentially growing reconnect delay.
    pub max_reconnect_delay_ms: u64,
    /// Multiplier applied to the delay for each additional reconnect attempt.
    pub reconnect_backoff: f64,
    /// Maximum time (ms) `ClientSession::connect` waits for the client to connect.
    pub connection_timeout_ms: u64,
    /// Period (ms) between heartbeat pings; `0` disables the heartbeat task.
    pub heartbeat_interval_ms: u64,
    /// Maximum time (ms) a single heartbeat ping may take.
    pub heartbeat_timeout_ms: u64,

    /// Not read by `ClientSession` in this module.
    pub session_timeout: Duration,
    /// Not read by `ClientSession` in this module.
    pub request_timeout: Duration,
    /// Not read by `ClientSession` in this module.
    pub max_concurrent_requests: u32,
    /// Not read by `ClientSession` in this module.
    pub enable_compression: bool,
    /// Not read by `ClientSession` in this module.
    pub buffer_size: usize,
}
71
72impl Default for SessionConfig {
73 fn default() -> Self {
74 Self {
75 auto_reconnect: true,
76 max_reconnect_attempts: 5,
77 reconnect_delay_ms: 1000,
78 max_reconnect_delay_ms: 30000,
79 reconnect_backoff: 2.0,
80 connection_timeout_ms: 10000,
81 heartbeat_interval_ms: 30000,
82 heartbeat_timeout_ms: 5000,
83 session_timeout: Duration::from_secs(300),
84 request_timeout: Duration::from_secs(30),
85 max_concurrent_requests: 10,
86 enable_compression: false,
87 buffer_size: 8192,
88 }
89 }
90}
91
/// A stateful wrapper around [`McpClient`] that tracks connection state,
/// dispatches notifications to registered handlers, and supports timed
/// connects and backoff-based reconnects.
pub struct ClientSession {
    /// The wrapped client; shared with background tasks and exposed via `client()`.
    client: Arc<Mutex<McpClient>>,
    /// Timeouts, heartbeat settings and reconnect policy.
    config: SessionConfig,
    /// Current state as returned by `state()` / `is_connected()`.
    state: Arc<RwLock<SessionState>>,
    /// Publishes state transitions to subscribers.
    state_tx: watch::Sender<SessionState>,
    /// Cloned out by `subscribe_state_changes()`; holding it also keeps the
    /// watch channel from ever having zero receivers.
    state_rx: watch::Receiver<SessionState>,
    /// Handlers invoked, in registration order, for each received notification.
    notification_handlers: Arc<RwLock<Vec<Box<dyn NotificationHandler>>>>,
    /// Set on a successful `connect`, cleared by `disconnect`.
    connected_at: Arc<RwLock<Option<Instant>>>,
    /// Reconnect attempts counted since the last successful `connect`.
    reconnect_attempts: Arc<Mutex<u32>>,
    /// Stop signal for background tasks: populated by `start_background_tasks`,
    /// taken and signalled by `stop_background_tasks`.
    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
}
113
114impl ClientSession {
115 pub fn new(client: McpClient) -> Self {
117 let (state_tx, state_rx) = watch::channel(SessionState::Disconnected);
118
119 Self {
120 client: Arc::new(Mutex::new(client)),
121 config: SessionConfig::default(),
122 state: Arc::new(RwLock::new(SessionState::Disconnected)),
123 state_tx,
124 state_rx,
125 notification_handlers: Arc::new(RwLock::new(Vec::new())),
126 connected_at: Arc::new(RwLock::new(None)),
127 reconnect_attempts: Arc::new(Mutex::new(0)),
128 shutdown_tx: Arc::new(Mutex::new(None)),
129 }
130 }
131
132 pub fn with_config(client: McpClient, config: SessionConfig) -> Self {
134 let mut session = Self::new(client);
135 session.config = config;
136 session
137 }
138
139 pub async fn state(&self) -> SessionState {
141 let state = self.state.read().await;
142 state.clone()
143 }
144
145 pub fn subscribe_state_changes(&self) -> watch::Receiver<SessionState> {
147 self.state_rx.clone()
148 }
149
150 pub async fn is_connected(&self) -> bool {
152 let state = self.state.read().await;
153 matches!(*state, SessionState::Connected)
154 }
155
156 pub async fn uptime(&self) -> Option<Duration> {
158 let connected_at = self.connected_at.read().await;
159 connected_at.map(|time| time.elapsed())
160 }
161
162 pub async fn add_notification_handler<H>(&self, handler: H)
164 where
165 H: NotificationHandler + 'static,
166 {
167 let mut handlers = self.notification_handlers.write().await;
168 handlers.push(Box::new(handler));
169 }
170
171 pub async fn connect<T>(&self, transport: T) -> McpResult<InitializeResult>
173 where
174 T: Transport + 'static,
175 {
176 self.transition_state(SessionState::Connecting).await?;
177
178 let connect_future = async {
179 let mut client = self.client.lock().await;
180 client.connect(transport).await
181 };
182
183 let result = timeout(
184 Duration::from_millis(self.config.connection_timeout_ms),
185 connect_future,
186 )
187 .await;
188
189 match result {
190 Ok(Ok(init_result)) => {
191 self.transition_state(SessionState::Connected).await?;
192
193 {
195 let mut connected_at = self.connected_at.write().await;
196 *connected_at = Some(Instant::now());
197 }
198
199 {
201 let mut attempts = self.reconnect_attempts.lock().await;
202 *attempts = 0;
203 }
204
205 self.start_background_tasks().await?;
207
208 Ok(init_result)
209 }
210 Ok(Err(error)) => {
211 self.transition_state(SessionState::Failed(error.to_string()))
212 .await?;
213 Err(error)
214 }
215 Err(_) => {
216 let error = McpError::Connection("Connection timeout".to_string());
217 self.transition_state(SessionState::Failed(error.to_string()))
218 .await?;
219 Err(error)
220 }
221 }
222 }
223
224 pub async fn disconnect(&self) -> McpResult<()> {
226 self.stop_background_tasks().await;
228
229 {
231 let client = self.client.lock().await;
232 client.disconnect().await?;
233 }
234
235 self.transition_state(SessionState::Disconnected).await?;
237
238 {
240 let mut connected_at = self.connected_at.write().await;
241 *connected_at = None;
242 }
243
244 Ok(())
245 }
246
247 pub async fn reconnect<T>(
249 &self,
250 transport_factory: impl Fn() -> T,
251 ) -> McpResult<InitializeResult>
252 where
253 T: Transport + 'static,
254 {
255 if !self.config.auto_reconnect {
256 return Err(McpError::Connection(
257 "Auto-reconnect is disabled".to_string(),
258 ));
259 }
260
261 let mut attempts = self.reconnect_attempts.lock().await;
262 if *attempts >= self.config.max_reconnect_attempts {
263 let error = McpError::Connection("Max reconnection attempts exceeded".to_string());
264 self.transition_state(SessionState::Failed(error.to_string()))
265 .await?;
266 return Err(error);
267 }
268
269 *attempts += 1;
270 let current_attempts = *attempts;
271 drop(attempts);
272
273 self.transition_state(SessionState::Reconnecting).await?;
274
275 let delay = std::cmp::min(
277 (self.config.reconnect_delay_ms as f64
278 * self
279 .config
280 .reconnect_backoff
281 .powi(current_attempts as i32 - 1)) as u64,
282 self.config.max_reconnect_delay_ms,
283 );
284
285 sleep(Duration::from_millis(delay)).await;
286
287 self.connect(transport_factory()).await
289 }
290
291 pub fn client(&self) -> Arc<Mutex<McpClient>> {
293 self.client.clone()
294 }
295
296 pub fn config(&self) -> &SessionConfig {
298 &self.config
299 }
300
301 async fn start_background_tasks(&self) -> McpResult<()> {
307 let (_shutdown_tx, shutdown_rx): (broadcast::Sender<()>, broadcast::Receiver<()>) =
308 broadcast::channel(16);
309 {
310 let mut shutdown_guard = self.shutdown_tx.lock().await;
311 *shutdown_guard = Some(mpsc::channel(1).0); }
313
314 {
316 let client = self.client.clone();
317 let handlers = self.notification_handlers.clone();
318 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
319
320 tokio::spawn(async move {
321 loop {
322 tokio::select! {
323 _ = shutdown_rx_clone.recv() => break,
324 notification_result = async {
325 let client_guard = client.lock().await;
326 client_guard.receive_notification().await
327 } => {
328 match notification_result {
329 Ok(Some(notification)) => {
330 let handlers_guard = handlers.read().await;
331 for handler in handlers_guard.iter() {
332 handler.handle_notification(notification.clone());
333 }
334 }
335 Ok(None) => {
336 }
338 Err(_) => {
339 break;
341 }
342 }
343 }
344 }
345 }
346 });
347 }
348
349 if self.config.heartbeat_interval_ms > 0 {
351 let client = self.client.clone();
352 let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
353 let heartbeat_timeout = Duration::from_millis(self.config.heartbeat_timeout_ms);
354 let state = self.state.clone();
355 let state_tx = self.state_tx.clone();
356 let mut shutdown_rx_clone = shutdown_rx.resubscribe();
357
358 tokio::spawn(async move {
359 let mut interval = tokio::time::interval(heartbeat_interval);
360
361 loop {
362 tokio::select! {
363 _ = shutdown_rx_clone.recv() => break,
364 _ = interval.tick() => {
365 {
367 let current_state = state.read().await;
368 if !matches!(*current_state, SessionState::Connected) {
369 break;
370 }
371 }
372
373 let ping_result = timeout(heartbeat_timeout, async {
375 let client_guard = client.lock().await;
376 client_guard.ping().await
377 }).await;
378
379 if ping_result.is_err() {
380 let _ = state_tx.send(SessionState::Disconnected);
382 break;
383 }
384 }
385 }
386 }
387 });
388 }
389
390 Ok(())
391 }
392
393 async fn stop_background_tasks(&self) {
395 let shutdown_tx = {
396 let mut shutdown_guard = self.shutdown_tx.lock().await;
397 shutdown_guard.take()
398 };
399
400 if let Some(tx) = shutdown_tx {
401 let _ = tx.send(()).await; }
403 }
404
405 async fn transition_state(&self, new_state: SessionState) -> McpResult<()> {
407 {
408 let mut state = self.state.write().await;
409 *state = new_state.clone();
410 }
411
412 if self.state_tx.send(new_state).is_err() {
414 }
416
417 Ok(())
418 }
419}
420
421pub struct LoggingNotificationHandler;
423
424impl NotificationHandler for LoggingNotificationHandler {
425 fn handle_notification(&self, notification: JsonRpcNotification) {
426 tracing::info!(
427 "Received notification: {} {:?}",
428 notification.method,
429 notification.params
430 );
431 }
432}
433
434pub struct ResourceUpdateHandler {
436 callback: Box<dyn Fn(String) + Send + Sync>,
437}
438
439impl ResourceUpdateHandler {
440 pub fn new<F>(callback: F) -> Self
442 where
443 F: Fn(String) + Send + Sync + 'static,
444 {
445 Self {
446 callback: Box::new(callback),
447 }
448 }
449}
450
451impl NotificationHandler for ResourceUpdateHandler {
452 fn handle_notification(&self, notification: JsonRpcNotification) {
453 if notification.method == methods::RESOURCES_UPDATED {
454 if let Some(params) = notification.params {
455 if let Ok(update_params) = serde_json::from_value::<ResourceUpdatedParams>(params) {
456 (self.callback)(update_params.uri);
457 }
458 }
459 }
460 }
461}
462
463pub struct ToolListChangedHandler {
465 callback: Box<dyn Fn() + Send + Sync>,
466}
467
468impl ToolListChangedHandler {
469 pub fn new<F>(callback: F) -> Self
471 where
472 F: Fn() + Send + Sync + 'static,
473 {
474 Self {
475 callback: Box::new(callback),
476 }
477 }
478}
479
480impl NotificationHandler for ToolListChangedHandler {
481 fn handle_notification(&self, notification: JsonRpcNotification) {
482 if notification.method == methods::TOOLS_LIST_CHANGED {
483 (self.callback)();
484 }
485 }
486}
487
488pub struct ProgressHandler {
490 callback: Box<dyn Fn(String, f32, Option<u32>) + Send + Sync>,
491}
492
493impl ProgressHandler {
494 pub fn new<F>(callback: F) -> Self
496 where
497 F: Fn(String, f32, Option<u32>) + Send + Sync + 'static,
498 {
499 Self {
500 callback: Box::new(callback),
501 }
502 }
503}
504
505impl NotificationHandler for ProgressHandler {
506 fn handle_notification(&self, notification: JsonRpcNotification) {
507 if notification.method == methods::PROGRESS {
508 if let Some(params) = notification.params {
509 if let Ok(progress_params) = serde_json::from_value::<ProgressParams>(params) {
510 (self.callback)(
511 progress_params.progress_token.to_string(),
512 progress_params.progress,
513 progress_params.total.map(|t| t as u32),
514 );
515 }
516 }
517 }
518 }
519}
520
/// Point-in-time summary of a [`ClientSession`], produced by `ClientSession::stats`.
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Session state at the time of the snapshot.
    pub state: SessionState,
    /// Time elapsed since the connection was established, if any.
    pub uptime: Option<Duration>,
    /// Reconnect attempts counted since the last successful connect.
    pub reconnect_attempts: u32,
    /// When the current connection was established, if any.
    pub connected_at: Option<Instant>,
}
533
534impl ClientSession {
535 pub async fn stats(&self) -> SessionStats {
537 let state = self.state().await;
538 let uptime = self.uptime().await;
539 let reconnect_attempts = {
540 let attempts = self.reconnect_attempts.lock().await;
541 *attempts
542 };
543 let connected_at = {
544 let connected_at = self.connected_at.read().await;
545 *connected_at
546 };
547
548 SessionStats {
549 state,
550 uptime,
551 reconnect_attempts,
552 connected_at,
553 }
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::mcp_client::McpClient;
    use async_trait::async_trait;

    /// Transport double: every request gets a successful `InitializeResult`
    /// response and there is never a pending notification.
    struct MockTransport;

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let init_result = InitializeResult::new(
                crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
                ServerCapabilities::default(),
                ServerInfo {
                    name: "test-server".to_string(),
                    version: "1.0.0".to_string(),
                },
            );
            JsonRpcResponse::success(serde_json::Value::from(1), init_result)
                .map_err(|e| McpError::Serialization(e.to_string()))
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    fn test_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    /// Session with the heartbeat disabled. `MockTransport` answers a ping
    /// with an `InitializeResult`, so whether a heartbeat succeeds against it
    /// depends on how the client decodes that reply. Keeping the heartbeat
    /// out of the way makes state assertions deterministic.
    fn quiet_session() -> ClientSession {
        let config = SessionConfig {
            heartbeat_interval_ms: 0,
            ..Default::default()
        };
        ClientSession::with_config(test_client(), config)
    }

    #[tokio::test]
    async fn test_session_creation() {
        let session = ClientSession::new(test_client());

        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_session_connection() {
        let session = quiet_session();

        let result = session.connect(MockTransport).await;

        assert!(result.is_ok());
        assert_eq!(session.state().await, SessionState::Connected);
        assert!(session.is_connected().await);
        assert!(session.uptime().await.is_some());
    }

    #[tokio::test]
    async fn test_session_disconnect() {
        let session = quiet_session();

        session.connect(MockTransport).await.unwrap();
        assert!(session.is_connected().await);

        session.disconnect().await.unwrap();
        assert_eq!(session.state().await, SessionState::Disconnected);
        assert!(!session.is_connected().await);
        assert!(session.uptime().await.is_none());
    }

    #[tokio::test]
    async fn test_notification_handlers() {
        let session = ClientSession::new(test_client());

        session
            .add_notification_handler(LoggingNotificationHandler)
            .await;

        session
            .add_notification_handler(ResourceUpdateHandler::new(|uri| {
                println!("Resource updated: {}", uri);
            }))
            .await;

        session
            .add_notification_handler(ToolListChangedHandler::new(|| {
                println!("Tool list changed");
            }))
            .await;

        session
            .add_notification_handler(ProgressHandler::new(|token, progress, total| {
                println!("Progress {}: {} / {:?}", token, progress, total);
            }))
            .await;

        let handlers = session.notification_handlers.read().await;
        assert_eq!(handlers.len(), 4);
    }

    #[tokio::test]
    async fn test_session_stats() {
        let session = ClientSession::new(test_client());

        let stats = session.stats().await;
        assert_eq!(stats.state, SessionState::Disconnected);
        assert!(stats.uptime.is_none());
        assert_eq!(stats.reconnect_attempts, 0);
        assert!(stats.connected_at.is_none());
    }

    #[tokio::test]
    async fn test_stats_after_connect() {
        let session = quiet_session();
        session.connect(MockTransport).await.unwrap();

        let stats = session.stats().await;
        assert_eq!(stats.state, SessionState::Connected);
        assert!(stats.connected_at.is_some());
        assert!(stats.uptime.is_some());
        assert_eq!(stats.reconnect_attempts, 0);
    }

    #[tokio::test]
    async fn test_session_config() {
        let config = SessionConfig {
            auto_reconnect: false,
            max_reconnect_attempts: 10,
            reconnect_delay_ms: 2000,
            ..Default::default()
        };
        let session = ClientSession::with_config(test_client(), config.clone());

        assert!(!session.config().auto_reconnect);
        assert_eq!(session.config().max_reconnect_attempts, 10);
        assert_eq!(session.config().reconnect_delay_ms, 2000);
    }

    #[tokio::test]
    async fn test_reconnect_disabled() {
        let config = SessionConfig {
            auto_reconnect: false,
            ..Default::default()
        };
        let session = ClientSession::with_config(test_client(), config);

        assert!(session.reconnect(|| MockTransport).await.is_err());
        assert_eq!(session.state().await, SessionState::Disconnected);
        assert_eq!(session.stats().await.reconnect_attempts, 0);
    }

    #[tokio::test]
    async fn test_reconnect_attempt_limit() {
        let config = SessionConfig {
            max_reconnect_attempts: 0,
            ..Default::default()
        };
        let session = ClientSession::with_config(test_client(), config);

        assert!(session.reconnect(|| MockTransport).await.is_err());
        assert!(matches!(session.state().await, SessionState::Failed(_)));
    }

    #[tokio::test]
    async fn test_reconnect_success_resets_attempts() {
        let config = SessionConfig {
            reconnect_delay_ms: 0,
            heartbeat_interval_ms: 0,
            ..Default::default()
        };
        let session = ClientSession::with_config(test_client(), config);

        session.reconnect(|| MockTransport).await.unwrap();
        assert!(session.is_connected().await);
        assert_eq!(session.stats().await.reconnect_attempts, 0);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let session = ClientSession::new(test_client());

        let mut state_rx = session.subscribe_state_changes();

        assert_eq!(*state_rx.borrow(), SessionState::Disconnected);

        session
            .transition_state(SessionState::Connecting)
            .await
            .unwrap();

        state_rx.changed().await.unwrap();
        assert_eq!(*state_rx.borrow(), SessionState::Connecting);
    }
}