1use async_trait::async_trait;
16use chrono::Utc;
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use uuid::Uuid;
23
24use crate::mcp::error::{McpError, McpResult};
25use crate::mcp::transport::{
26 BoxedTransport, McpRequest, McpResponse, TransportConfig, TransportFactory, TransportState,
27};
28use crate::mcp::types::{
29 ConnectionOptions, ConnectionStatus, McpConnection, McpServerInfo, TransportType,
30};
31
32#[derive(Debug, Clone)]
34pub enum ConnectionEvent {
35 Establishing(McpConnection),
37 Established(McpConnection),
39 Closed(McpConnection),
41 Error(McpConnection, String),
43 Reconnecting(McpConnection),
45 HeartbeatFailed(String, String),
47}
48
49struct ConnectionState {
51 info: McpConnection,
53 transport: BoxedTransport,
55 #[allow(dead_code)]
57 server_info: McpServerInfo,
58 #[allow(dead_code)]
60 reconnect_attempts: u32,
61 last_heartbeat: Option<chrono::DateTime<Utc>>,
63 heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
65}
66
67#[derive(Debug, Clone)]
69pub struct PendingRequestInfo {
70 pub request_id: String,
72 pub connection_id: String,
74 pub method: String,
76 pub start_time: chrono::DateTime<Utc>,
78}
79
80#[async_trait]
84pub trait ConnectionManager: Send + Sync {
85 async fn connect(&self, server: McpServerInfo) -> McpResult<McpConnection>;
87
88 async fn disconnect(&self, connection_id: &str) -> McpResult<()>;
90
91 async fn disconnect_all(&self) -> McpResult<()>;
93
94 async fn send(&self, connection_id: &str, request: McpRequest) -> McpResult<McpResponse>;
96
97 async fn send_with_timeout(
99 &self,
100 connection_id: &str,
101 request: McpRequest,
102 timeout: Duration,
103 ) -> McpResult<McpResponse>;
104
105 async fn send_with_retry(
107 &self,
108 connection_id: &str,
109 request: McpRequest,
110 ) -> McpResult<McpResponse>;
111
112 async fn cancel_request(&self, connection_id: &str, request_id: &str) -> McpResult<()>;
114
115 fn get_connection(&self, id: &str) -> Option<McpConnection>;
117
118 fn get_connection_by_server(&self, server_name: &str) -> Option<McpConnection>;
120
121 fn get_all_connections(&self) -> Vec<McpConnection>;
123
124 fn subscribe(&self) -> mpsc::Receiver<ConnectionEvent>;
126}
127
128pub struct McpConnectionManager {
130 connections: Arc<RwLock<HashMap<String, ConnectionState>>>,
132 server_to_connection: Arc<RwLock<HashMap<String, String>>>,
134 pub default_options: ConnectionOptions,
136 event_tx: Arc<Mutex<Option<mpsc::Sender<ConnectionEvent>>>>,
138 request_counter: AtomicU64,
140 enable_heartbeat: bool,
142 enable_auto_reconnect: bool,
144}
145
146impl McpConnectionManager {
147 pub fn new() -> Self {
149 Self::with_options(ConnectionOptions::default())
150 }
151
152 pub fn with_options(options: ConnectionOptions) -> Self {
154 Self {
155 connections: Arc::new(RwLock::new(HashMap::new())),
156 server_to_connection: Arc::new(RwLock::new(HashMap::new())),
157 default_options: options,
158 event_tx: Arc::new(Mutex::new(None)),
159 request_counter: AtomicU64::new(1),
160 enable_heartbeat: true,
161 enable_auto_reconnect: true,
162 }
163 }
164
165 pub fn set_heartbeat_enabled(&mut self, enabled: bool) {
167 self.enable_heartbeat = enabled;
168 }
169
170 pub fn set_auto_reconnect_enabled(&mut self, enabled: bool) {
172 self.enable_auto_reconnect = enabled;
173 }
174
175 pub fn generate_connection_id() -> String {
177 Uuid::new_v4().to_string()
178 }
179
180 pub fn next_request_id(&self) -> String {
182 let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
183 format!("mcp-req-{}", id)
184 }
185
186 async fn emit_event(&self, event: ConnectionEvent) {
188 if let Some(tx) = self.event_tx.lock().await.as_ref() {
189 let _ = tx.send(event).await;
190 }
191 }
192
193 pub fn create_transport_config(server: &McpServerInfo) -> McpResult<TransportConfig> {
195 match server.transport_type {
196 TransportType::Stdio => {
197 let command = server
198 .command
199 .clone()
200 .ok_or_else(|| McpError::config("Stdio transport requires a command"))?;
201 Ok(TransportConfig::Stdio {
202 command,
203 args: server.args.clone().unwrap_or_default(),
204 env: server.env.clone().unwrap_or_default(),
205 cwd: None,
206 })
207 }
208 TransportType::Http => {
209 let url = server
210 .url
211 .clone()
212 .ok_or_else(|| McpError::config("HTTP transport requires a URL"))?;
213 Ok(TransportConfig::Http {
214 url,
215 headers: server.headers.clone().unwrap_or_default(),
216 })
217 }
218 TransportType::Sse => {
219 let url = server
220 .url
221 .clone()
222 .ok_or_else(|| McpError::config("SSE transport requires a URL"))?;
223 Ok(TransportConfig::Sse {
224 url,
225 headers: server.headers.clone().unwrap_or_default(),
226 })
227 }
228 TransportType::WebSocket => {
229 let url = server
230 .url
231 .clone()
232 .ok_or_else(|| McpError::config("WebSocket transport requires a URL"))?;
233 Ok(TransportConfig::WebSocket {
234 url,
235 headers: server.headers.clone().unwrap_or_default(),
236 })
237 }
238 }
239 }
240
241 async fn perform_handshake(
243 transport: &mut BoxedTransport,
244 connection: &mut McpConnection,
245 ) -> McpResult<()> {
246 let init_request = McpRequest::with_params(
248 serde_json::json!("init-1"),
249 "initialize",
250 serde_json::json!({
251 "protocolVersion": "2024-11-05",
252 "capabilities": {
253 "roots": { "listChanged": true },
254 "sampling": {}
255 },
256 "clientInfo": {
257 "name": "aster",
258 "version": env!("CARGO_PKG_VERSION")
259 }
260 }),
261 );
262
263 let response = transport.send_request(init_request).await?;
264
265 if let Some(result) = response.result {
267 if let Some(protocol_version) = result.get("protocolVersion").and_then(|v| v.as_str()) {
268 connection.protocol_version = Some(protocol_version.to_string());
269 }
270
271 if let Some(capabilities) = result.get("capabilities") {
273 if let Ok(caps) = serde_json::from_value(capabilities.clone()) {
274 connection.capabilities = Some(caps);
275 }
276 }
277 }
278
279 let initialized_notification =
281 crate::mcp::transport::McpNotification::new("notifications/initialized");
282 transport
283 .send(crate::mcp::transport::McpMessage::Notification(
284 initialized_notification,
285 ))
286 .await?;
287
288 Ok(())
289 }
290
291 fn start_heartbeat(&self, connection_id: String, interval: Duration) {
293 let connections = self.connections.clone();
294 let event_tx = self.event_tx.clone();
295 let enable_auto_reconnect = self.enable_auto_reconnect;
296
297 tokio::spawn(async move {
298 let mut interval_timer = tokio::time::interval(interval);
299
300 loop {
301 interval_timer.tick().await;
302
303 let mut conns = connections.write().await;
304 if let Some(state) = conns.get_mut(&connection_id) {
305 if state.transport.state() != TransportState::Connected {
307 if let Some(tx) = event_tx.lock().await.as_ref() {
309 let _ = tx
310 .send(ConnectionEvent::HeartbeatFailed(
311 connection_id.clone(),
312 "Transport disconnected".to_string(),
313 ))
314 .await;
315 }
316
317 if enable_auto_reconnect {
319 state.info.status = ConnectionStatus::Reconnecting;
320 }
322 break;
323 }
324
325 let ping_request = McpRequest::new(
327 serde_json::json!(format!("ping-{}", Uuid::new_v4())),
328 "ping",
329 );
330
331 match state.transport.send_request(ping_request).await {
332 Ok(_) => {
333 state.last_heartbeat = Some(Utc::now());
334 state.info.last_activity = Utc::now();
335 }
336 Err(e) => {
337 if let Some(tx) = event_tx.lock().await.as_ref() {
339 let _ = tx
340 .send(ConnectionEvent::HeartbeatFailed(
341 connection_id.clone(),
342 e.to_string(),
343 ))
344 .await;
345 }
346
347 if enable_auto_reconnect {
348 state.info.status = ConnectionStatus::Reconnecting;
349 }
350 break;
351 }
352 }
353 } else {
354 break;
356 }
357 }
358 });
359 }
360
361 pub fn calculate_reconnect_delay(&self, attempt: u32) -> Duration {
363 let base = self.default_options.reconnect_delay_base.as_millis() as u64;
364 let max = self.default_options.reconnect_delay_max.as_millis() as u64;
365
366 let delay_ms = base.saturating_mul(1u64 << attempt.min(10));
368 Duration::from_millis(delay_ms.min(max))
369 }
370
371 pub async fn reconnect(&self, connection_id: &str) -> McpResult<McpConnection> {
376 let (server_info, max_retries) = {
377 let conns = self.connections.read().await;
378 if let Some(state) = conns.get(connection_id) {
379 (state.server_info.clone(), self.default_options.max_retries)
380 } else {
381 return Err(McpError::connection(format!(
382 "Connection not found: {}",
383 connection_id
384 )));
385 }
386 };
387
388 {
390 let mut conns = self.connections.write().await;
391 if let Some(state) = conns.get_mut(connection_id) {
392 state.info.status = ConnectionStatus::Reconnecting;
393 self.emit_event(ConnectionEvent::Reconnecting(state.info.clone()))
394 .await;
395 }
396 }
397
398 let mut last_error = None;
399
400 for attempt in 0..=max_retries {
401 if attempt > 0 {
402 let delay = self.calculate_reconnect_delay(attempt - 1);
403 tokio::time::sleep(delay).await;
404 }
405
406 match self.try_reconnect(connection_id, &server_info).await {
408 Ok(connection) => {
409 {
411 let mut conns = self.connections.write().await;
412 if let Some(state) = conns.get_mut(connection_id) {
413 state.reconnect_attempts = 0;
414 }
415 }
416 return Ok(connection);
417 }
418 Err(e) => {
419 last_error = Some(e);
420 {
422 let mut conns = self.connections.write().await;
423 if let Some(state) = conns.get_mut(connection_id) {
424 state.reconnect_attempts = attempt + 1;
425 }
426 }
427 }
428 }
429 }
430
431 {
433 let mut conns = self.connections.write().await;
434 if let Some(state) = conns.get_mut(connection_id) {
435 state.info.status = ConnectionStatus::Error;
436 self.emit_event(ConnectionEvent::Error(
437 state.info.clone(),
438 last_error
439 .as_ref()
440 .map(|e| e.to_string())
441 .unwrap_or_else(|| "Unknown error".to_string()),
442 ))
443 .await;
444 }
445 }
446
447 Err(last_error.unwrap_or_else(|| McpError::connection("Reconnection failed after retries")))
448 }
449
450 async fn try_reconnect(
452 &self,
453 connection_id: &str,
454 server_info: &McpServerInfo,
455 ) -> McpResult<McpConnection> {
456 let transport_config = Self::create_transport_config(server_info)?;
458 let mut transport =
459 TransportFactory::create(transport_config, server_info.options.clone())?;
460
461 transport.connect().await?;
463
464 let mut connection = McpConnection::new(
466 connection_id.to_string(),
467 server_info.name.clone(),
468 server_info.transport_type,
469 );
470
471 Self::perform_handshake(&mut transport, &mut connection).await?;
473
474 connection.status = ConnectionStatus::Connected;
476 connection.touch();
477
478 {
480 let mut conns = self.connections.write().await;
481 if let Some(state) = conns.get_mut(connection_id) {
482 state.info = connection.clone();
483 state.transport = transport;
484 state.last_heartbeat = Some(Utc::now());
485 }
486 }
487
488 self.emit_event(ConnectionEvent::Established(connection.clone()))
490 .await;
491
492 Ok(connection)
493 }
494}
495
496impl Default for McpConnectionManager {
497 fn default() -> Self {
498 Self::new()
499 }
500}
501
502#[async_trait]
503impl ConnectionManager for McpConnectionManager {
504 async fn connect(&self, server: McpServerInfo) -> McpResult<McpConnection> {
505 {
507 let server_map = self.server_to_connection.read().await;
508 if let Some(conn_id) = server_map.get(&server.name) {
509 let conns = self.connections.read().await;
510 if let Some(state) = conns.get(conn_id) {
511 if state.info.status == ConnectionStatus::Connected {
512 return Ok(state.info.clone());
513 }
514 }
515 }
516 }
517
518 let connection_id = Self::generate_connection_id();
520 let mut connection = McpConnection::new(
521 connection_id.clone(),
522 server.name.clone(),
523 server.transport_type,
524 );
525
526 self.emit_event(ConnectionEvent::Establishing(connection.clone()))
528 .await;
529
530 let transport_config = Self::create_transport_config(&server)?;
532
533 let options = server.options.clone();
535 let mut transport = TransportFactory::create(transport_config, options.clone())?;
536
537 transport.connect().await?;
538
539 Self::perform_handshake(&mut transport, &mut connection).await?;
541
542 connection.status = ConnectionStatus::Connected;
544 connection.touch();
545
546 {
548 let mut conns = self.connections.write().await;
549 conns.insert(
550 connection_id.clone(),
551 ConnectionState {
552 info: connection.clone(),
553 transport,
554 server_info: server.clone(),
555 reconnect_attempts: 0,
556 last_heartbeat: Some(Utc::now()),
557 heartbeat_handle: None,
558 },
559 );
560 }
561
562 {
564 let mut server_map = self.server_to_connection.write().await;
565 server_map.insert(server.name.clone(), connection_id.clone());
566 }
567
568 if self.enable_heartbeat {
570 self.start_heartbeat(connection_id, options.heartbeat_interval);
571 }
572
573 self.emit_event(ConnectionEvent::Established(connection.clone()))
575 .await;
576
577 Ok(connection)
578 }
579
580 async fn disconnect(&self, connection_id: &str) -> McpResult<()> {
581 let mut conns = self.connections.write().await;
582
583 if let Some(mut state) = conns.remove(connection_id) {
584 if let Some(handle) = state.heartbeat_handle.take() {
586 handle.abort();
587 }
588
589 state.transport.disconnect().await?;
591
592 state.info.status = ConnectionStatus::Disconnected;
594
595 {
597 let mut server_map = self.server_to_connection.write().await;
598 server_map.remove(&state.info.server_name);
599 }
600
601 self.emit_event(ConnectionEvent::Closed(state.info)).await;
603
604 Ok(())
605 } else {
606 Err(McpError::connection(format!(
607 "Connection not found: {}",
608 connection_id
609 )))
610 }
611 }
612
613 async fn disconnect_all(&self) -> McpResult<()> {
614 let connection_ids: Vec<String> = {
615 let conns = self.connections.read().await;
616 conns.keys().cloned().collect()
617 };
618
619 for id in connection_ids {
620 if let Err(e) = self.disconnect(&id).await {
621 tracing::warn!("Failed to disconnect {}: {}", id, e);
622 }
623 }
624
625 Ok(())
626 }
627
628 async fn send(&self, connection_id: &str, request: McpRequest) -> McpResult<McpResponse> {
629 let mut conns = self.connections.write().await;
630
631 if let Some(state) = conns.get_mut(connection_id) {
632 if state.info.status != ConnectionStatus::Connected {
633 return Err(McpError::connection("Connection is not active"));
634 }
635
636 let response = state.transport.send_request(request).await?;
637 state.info.touch();
638
639 Ok(response)
640 } else {
641 Err(McpError::connection(format!(
642 "Connection not found: {}",
643 connection_id
644 )))
645 }
646 }
647
648 async fn send_with_timeout(
649 &self,
650 connection_id: &str,
651 request: McpRequest,
652 timeout: Duration,
653 ) -> McpResult<McpResponse> {
654 let mut conns = self.connections.write().await;
655
656 if let Some(state) = conns.get_mut(connection_id) {
657 if state.info.status != ConnectionStatus::Connected {
658 return Err(McpError::connection("Connection is not active"));
659 }
660
661 let response = state
662 .transport
663 .send_request_with_timeout(request, timeout)
664 .await?;
665 state.info.touch();
666
667 Ok(response)
668 } else {
669 Err(McpError::connection(format!(
670 "Connection not found: {}",
671 connection_id
672 )))
673 }
674 }
675
676 async fn send_with_retry(
677 &self,
678 connection_id: &str,
679 request: McpRequest,
680 ) -> McpResult<McpResponse> {
681 let max_retries = self.default_options.max_retries;
682 let mut last_error = None;
683
684 for attempt in 0..=max_retries {
685 match self.send(connection_id, request.clone()).await {
686 Ok(response) => return Ok(response),
687 Err(e) => {
688 last_error = Some(e);
689 if attempt < max_retries {
690 let delay = self.calculate_reconnect_delay(attempt);
691 tokio::time::sleep(delay).await;
692 }
693 }
694 }
695 }
696
697 Err(last_error.unwrap_or_else(|| McpError::connection("Request failed after retries")))
698 }
699
700 async fn cancel_request(&self, connection_id: &str, request_id: &str) -> McpResult<()> {
701 let mut conns = self.connections.write().await;
702
703 if let Some(state) = conns.get_mut(connection_id) {
704 if state.info.status != ConnectionStatus::Connected {
705 return Err(McpError::connection("Connection is not active"));
706 }
707
708 let cancel_notification = crate::mcp::transport::McpNotification::with_params(
710 "notifications/cancelled",
711 serde_json::json!({
712 "requestId": request_id,
713 "reason": "Cancelled by client"
714 }),
715 );
716
717 state
718 .transport
719 .send(crate::mcp::transport::McpMessage::Notification(
720 cancel_notification,
721 ))
722 .await?;
723
724 Ok(())
725 } else {
726 Err(McpError::connection(format!(
727 "Connection not found: {}",
728 connection_id
729 )))
730 }
731 }
732
733 fn get_connection(&self, id: &str) -> Option<McpConnection> {
734 self.connections
736 .try_read()
737 .ok()
738 .and_then(|conns| conns.get(id).map(|s| s.info.clone()))
739 }
740
741 fn get_connection_by_server(&self, server_name: &str) -> Option<McpConnection> {
742 let server_map = self.server_to_connection.try_read().ok()?;
743 let conn_id = server_map.get(server_name)?;
744 self.get_connection(conn_id)
745 }
746
747 fn get_all_connections(&self) -> Vec<McpConnection> {
748 self.connections
749 .try_read()
750 .map(|conns| conns.values().map(|s| s.info.clone()).collect())
751 .unwrap_or_default()
752 }
753
754 fn subscribe(&self) -> mpsc::Receiver<ConnectionEvent> {
755 let (tx, rx) = mpsc::channel(100);
756 let event_tx = self.event_tx.clone();
757 tokio::spawn(async move {
758 *event_tx.lock().await = Some(tx);
759 });
760 rx
761 }
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server description named "test" with no args/env/headers and
    /// default options.
    fn server_info(
        transport_type: TransportType,
        command: Option<&str>,
        url: Option<&str>,
    ) -> McpServerInfo {
        McpServerInfo {
            name: "test".to_string(),
            transport_type,
            command: command.map(str::to_string),
            args: None,
            env: None,
            url: url.map(str::to_string),
            headers: None,
            options: ConnectionOptions::default(),
        }
    }

    #[test]
    fn test_connection_manager_new() {
        let manager = McpConnectionManager::new();
        assert!(manager.get_all_connections().is_empty());
        assert!(manager.enable_heartbeat);
        assert!(manager.enable_auto_reconnect);
    }

    #[test]
    fn test_setters_toggle_flags() {
        let mut manager = McpConnectionManager::default();
        manager.set_heartbeat_enabled(false);
        manager.set_auto_reconnect_enabled(false);
        assert!(!manager.enable_heartbeat);
        assert!(!manager.enable_auto_reconnect);
    }

    #[test]
    fn test_connection_manager_with_options() {
        let options = ConnectionOptions {
            timeout: Duration::from_secs(60),
            max_retries: 5,
            ..Default::default()
        };
        let manager = McpConnectionManager::with_options(options);
        assert_eq!(manager.default_options.timeout, Duration::from_secs(60));
        assert_eq!(manager.default_options.max_retries, 5);
    }

    #[test]
    fn test_generate_connection_id() {
        let id1 = McpConnectionManager::generate_connection_id();
        let id2 = McpConnectionManager::generate_connection_id();
        assert_ne!(id1, id2);
        assert!(Uuid::parse_str(&id1).is_ok());
    }

    #[test]
    fn test_next_request_id() {
        let manager = McpConnectionManager::new();
        assert_eq!(manager.next_request_id(), "mcp-req-1");
        assert_eq!(manager.next_request_id(), "mcp-req-2");
    }

    #[test]
    fn test_calculate_reconnect_delay() {
        let manager = McpConnectionManager::new();

        let delay0 = manager.calculate_reconnect_delay(0);
        let delay1 = manager.calculate_reconnect_delay(1);
        let delay2 = manager.calculate_reconnect_delay(2);

        assert!(delay1 > delay0);
        assert!(delay2 > delay1);

        let delay_max = manager.calculate_reconnect_delay(100);
        assert!(delay_max <= manager.default_options.reconnect_delay_max);
    }

    #[test]
    fn test_calculate_reconnect_delay_first_attempt_is_base() {
        let manager = McpConnectionManager::new();
        let base = manager.default_options.reconnect_delay_base.as_millis();
        let max = manager.default_options.reconnect_delay_max.as_millis();
        let expected = Duration::from_millis(base.min(max) as u64);
        assert_eq!(manager.calculate_reconnect_delay(0), expected);
    }

    #[test]
    fn test_calculate_reconnect_delay_is_monotonic() {
        let manager = McpConnectionManager::new();
        for attempt in 0..20 {
            assert!(
                manager.calculate_reconnect_delay(attempt + 1)
                    >= manager.calculate_reconnect_delay(attempt)
            );
        }
    }

    #[test]
    fn test_create_transport_config_stdio() {
        let server = McpServerInfo {
            args: Some(vec!["server.js".to_string()]),
            ..server_info(TransportType::Stdio, Some("node"), None)
        };

        let config = McpConnectionManager::create_transport_config(&server)
            .expect("stdio server with a command must produce a config");
        assert_eq!(config.transport_type(), TransportType::Stdio);
        match config {
            TransportConfig::Stdio {
                command, args, cwd, ..
            } => {
                assert_eq!(command, "node");
                assert_eq!(args, vec!["server.js".to_string()]);
                assert!(cwd.is_none());
            }
            _ => panic!("expected a Stdio transport config"),
        }
    }

    #[test]
    fn test_create_transport_config_http() {
        let server = server_info(TransportType::Http, None, Some("http://localhost:8080"));
        let config = McpConnectionManager::create_transport_config(&server);
        assert!(config.is_ok());
        assert_eq!(config.unwrap().transport_type(), TransportType::Http);
    }

    #[test]
    fn test_create_transport_config_sse() {
        let server = server_info(TransportType::Sse, None, Some("http://localhost:8080/sse"));
        let config = McpConnectionManager::create_transport_config(&server)
            .expect("SSE server with a URL must produce a config");
        assert!(matches!(config, TransportConfig::Sse { .. }));
    }

    #[test]
    fn test_create_transport_config_websocket() {
        let server = server_info(TransportType::WebSocket, None, Some("ws://localhost:8080"));
        let config = McpConnectionManager::create_transport_config(&server)
            .expect("WebSocket server with a URL must produce a config");
        assert!(matches!(config, TransportConfig::WebSocket { .. }));
    }

    #[test]
    fn test_create_transport_config_missing_command() {
        let server = server_info(TransportType::Stdio, None, None);
        let config = McpConnectionManager::create_transport_config(&server);
        assert!(config.is_err());
    }

    #[test]
    fn test_create_transport_config_missing_url() {
        for transport_type in [TransportType::Http, TransportType::Sse, TransportType::WebSocket] {
            let server = server_info(transport_type, None, None);
            let config = McpConnectionManager::create_transport_config(&server);
            assert!(config.is_err());
        }
    }

    #[tokio::test]
    async fn test_get_connection_not_found() {
        let manager = McpConnectionManager::new();
        let conn = manager.get_connection("nonexistent");
        assert!(conn.is_none());
    }

    #[tokio::test]
    async fn test_get_connection_by_server_not_found() {
        let manager = McpConnectionManager::new();
        let conn = manager.get_connection_by_server("nonexistent");
        assert!(conn.is_none());
    }

    #[tokio::test]
    async fn test_disconnect_not_found() {
        let manager = McpConnectionManager::new();
        let result = manager.disconnect("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_disconnect_all_empty() {
        let manager = McpConnectionManager::new();
        assert!(manager.disconnect_all().await.is_ok());
    }

    #[tokio::test]
    async fn test_send_not_found() {
        let manager = McpConnectionManager::new();
        let request = McpRequest::new(serde_json::json!(1), "test");
        let result = manager.send("nonexistent", request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_with_timeout_not_found() {
        let manager = McpConnectionManager::new();
        let request = McpRequest::new(serde_json::json!(1), "test");
        let result = manager
            .send_with_timeout("nonexistent", request, Duration::from_secs(1))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_cancel_request_not_found() {
        let manager = McpConnectionManager::new();
        let result = manager.cancel_request("nonexistent", "mcp-req-1").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_reconnect_not_found() {
        let manager = McpConnectionManager::new();
        let result = manager.reconnect("nonexistent").await;
        assert!(result.is_err());
    }
}