Skip to main content

aster/mcp/
connection_manager.rs

1//! MCP Connection Manager
2//!
3//! This module implements the connection manager for MCP servers.
4//! It manages multiple connections, handles reconnection, heartbeat monitoring,
5//! and provides a unified interface for sending requests to MCP servers.
6//!
7//! # Features
8//!
9//! - Multi-transport support (stdio, HTTP, SSE, WebSocket)
10//! - Automatic reconnection with exponential backoff
11//! - Heartbeat monitoring for connection health
12//! - Request/response matching by ID
13//! - Connection pooling and lifecycle management
14
15use async_trait::async_trait;
16use chrono::Utc;
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use uuid::Uuid;
23
24use crate::mcp::error::{McpError, McpResult};
25use crate::mcp::transport::{
26    BoxedTransport, McpRequest, McpResponse, TransportConfig, TransportFactory, TransportState,
27};
28use crate::mcp::types::{
29    ConnectionOptions, ConnectionStatus, McpConnection, McpServerInfo, TransportType,
30};
31
32/// Connection event for monitoring connection state changes
33#[derive(Debug, Clone)]
34pub enum ConnectionEvent {
35    /// Connection is being established
36    Establishing(McpConnection),
37    /// Connection established successfully
38    Established(McpConnection),
39    /// Connection closed
40    Closed(McpConnection),
41    /// Connection error occurred
42    Error(McpConnection, String),
43    /// Reconnection attempt started
44    Reconnecting(McpConnection),
45    /// Heartbeat failed
46    HeartbeatFailed(String, String),
47}
48
49/// Internal connection state
50struct ConnectionState {
51    /// Connection info
52    info: McpConnection,
53    /// Transport instance
54    transport: BoxedTransport,
55    /// Server info used to create this connection (for reconnection)
56    #[allow(dead_code)]
57    server_info: McpServerInfo,
58    /// Reconnection attempt count (for exponential backoff)
59    #[allow(dead_code)]
60    reconnect_attempts: u32,
61    /// Last heartbeat time
62    last_heartbeat: Option<chrono::DateTime<Utc>>,
63    /// Heartbeat task handle
64    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
65}
66
67/// Pending request info for tracking and cancellation
68#[derive(Debug, Clone)]
69pub struct PendingRequestInfo {
70    /// Request ID
71    pub request_id: String,
72    /// Connection ID
73    pub connection_id: String,
74    /// Method name
75    pub method: String,
76    /// Start time
77    pub start_time: chrono::DateTime<Utc>,
78}
79
80/// Connection manager trait
81///
82/// Defines the interface for managing MCP server connections.
83#[async_trait]
84pub trait ConnectionManager: Send + Sync {
85    /// Connect to an MCP server
86    async fn connect(&self, server: McpServerInfo) -> McpResult<McpConnection>;
87
88    /// Disconnect from a server
89    async fn disconnect(&self, connection_id: &str) -> McpResult<()>;
90
91    /// Disconnect all connections
92    async fn disconnect_all(&self) -> McpResult<()>;
93
94    /// Send a request to a server
95    async fn send(&self, connection_id: &str, request: McpRequest) -> McpResult<McpResponse>;
96
97    /// Send a request with timeout
98    async fn send_with_timeout(
99        &self,
100        connection_id: &str,
101        request: McpRequest,
102        timeout: Duration,
103    ) -> McpResult<McpResponse>;
104
105    /// Send a request with retry
106    async fn send_with_retry(
107        &self,
108        connection_id: &str,
109        request: McpRequest,
110    ) -> McpResult<McpResponse>;
111
112    /// Cancel a pending request by sending a cancellation notification
113    async fn cancel_request(&self, connection_id: &str, request_id: &str) -> McpResult<()>;
114
115    /// Get a connection by ID
116    fn get_connection(&self, id: &str) -> Option<McpConnection>;
117
118    /// Get a connection by server name
119    fn get_connection_by_server(&self, server_name: &str) -> Option<McpConnection>;
120
121    /// Get all connections
122    fn get_all_connections(&self) -> Vec<McpConnection>;
123
124    /// Subscribe to connection events
125    fn subscribe(&self) -> mpsc::Receiver<ConnectionEvent>;
126}
127
128/// Default implementation of the connection manager
129pub struct McpConnectionManager {
130    /// Active connections
131    connections: Arc<RwLock<HashMap<String, ConnectionState>>>,
132    /// Server name to connection ID mapping
133    server_to_connection: Arc<RwLock<HashMap<String, String>>>,
134    /// Default connection options
135    pub default_options: ConnectionOptions,
136    /// Event channel sender
137    event_tx: Arc<Mutex<Option<mpsc::Sender<ConnectionEvent>>>>,
138    /// Request ID counter
139    request_counter: AtomicU64,
140    /// Enable heartbeat monitoring
141    enable_heartbeat: bool,
142    /// Enable auto-reconnect
143    enable_auto_reconnect: bool,
144}
145
146impl McpConnectionManager {
147    /// Create a new connection manager with default options
148    pub fn new() -> Self {
149        Self::with_options(ConnectionOptions::default())
150    }
151
152    /// Create a new connection manager with custom options
153    pub fn with_options(options: ConnectionOptions) -> Self {
154        Self {
155            connections: Arc::new(RwLock::new(HashMap::new())),
156            server_to_connection: Arc::new(RwLock::new(HashMap::new())),
157            default_options: options,
158            event_tx: Arc::new(Mutex::new(None)),
159            request_counter: AtomicU64::new(1),
160            enable_heartbeat: true,
161            enable_auto_reconnect: true,
162        }
163    }
164
165    /// Enable or disable heartbeat monitoring
166    pub fn set_heartbeat_enabled(&mut self, enabled: bool) {
167        self.enable_heartbeat = enabled;
168    }
169
170    /// Enable or disable auto-reconnect
171    pub fn set_auto_reconnect_enabled(&mut self, enabled: bool) {
172        self.enable_auto_reconnect = enabled;
173    }
174
175    /// Generate a unique connection ID
176    pub fn generate_connection_id() -> String {
177        Uuid::new_v4().to_string()
178    }
179
180    /// Generate a unique request ID
181    pub fn next_request_id(&self) -> String {
182        let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
183        format!("mcp-req-{}", id)
184    }
185
186    /// Emit a connection event
187    async fn emit_event(&self, event: ConnectionEvent) {
188        if let Some(tx) = self.event_tx.lock().await.as_ref() {
189            let _ = tx.send(event).await;
190        }
191    }
192
193    /// Create transport config from server info
194    pub fn create_transport_config(server: &McpServerInfo) -> McpResult<TransportConfig> {
195        match server.transport_type {
196            TransportType::Stdio => {
197                let command = server
198                    .command
199                    .clone()
200                    .ok_or_else(|| McpError::config("Stdio transport requires a command"))?;
201                Ok(TransportConfig::Stdio {
202                    command,
203                    args: server.args.clone().unwrap_or_default(),
204                    env: server.env.clone().unwrap_or_default(),
205                    cwd: None,
206                })
207            }
208            TransportType::Http => {
209                let url = server
210                    .url
211                    .clone()
212                    .ok_or_else(|| McpError::config("HTTP transport requires a URL"))?;
213                Ok(TransportConfig::Http {
214                    url,
215                    headers: server.headers.clone().unwrap_or_default(),
216                })
217            }
218            TransportType::Sse => {
219                let url = server
220                    .url
221                    .clone()
222                    .ok_or_else(|| McpError::config("SSE transport requires a URL"))?;
223                Ok(TransportConfig::Sse {
224                    url,
225                    headers: server.headers.clone().unwrap_or_default(),
226                })
227            }
228            TransportType::WebSocket => {
229                let url = server
230                    .url
231                    .clone()
232                    .ok_or_else(|| McpError::config("WebSocket transport requires a URL"))?;
233                Ok(TransportConfig::WebSocket {
234                    url,
235                    headers: server.headers.clone().unwrap_or_default(),
236                })
237            }
238        }
239    }
240
241    /// Perform MCP protocol handshake
242    async fn perform_handshake(
243        transport: &mut BoxedTransport,
244        connection: &mut McpConnection,
245    ) -> McpResult<()> {
246        // Send initialize request
247        let init_request = McpRequest::with_params(
248            serde_json::json!("init-1"),
249            "initialize",
250            serde_json::json!({
251                "protocolVersion": "2024-11-05",
252                "capabilities": {
253                    "roots": { "listChanged": true },
254                    "sampling": {}
255                },
256                "clientInfo": {
257                    "name": "aster",
258                    "version": env!("CARGO_PKG_VERSION")
259                }
260            }),
261        );
262
263        let response = transport.send_request(init_request).await?;
264
265        // Parse server capabilities from response
266        if let Some(result) = response.result {
267            if let Some(protocol_version) = result.get("protocolVersion").and_then(|v| v.as_str()) {
268                connection.protocol_version = Some(protocol_version.to_string());
269            }
270
271            // Parse capabilities if available
272            if let Some(capabilities) = result.get("capabilities") {
273                if let Ok(caps) = serde_json::from_value(capabilities.clone()) {
274                    connection.capabilities = Some(caps);
275                }
276            }
277        }
278
279        // Send initialized notification
280        let initialized_notification =
281            crate::mcp::transport::McpNotification::new("notifications/initialized");
282        transport
283            .send(crate::mcp::transport::McpMessage::Notification(
284                initialized_notification,
285            ))
286            .await?;
287
288        Ok(())
289    }
290
291    /// Start heartbeat monitoring for a connection
292    fn start_heartbeat(&self, connection_id: String, interval: Duration) {
293        let connections = self.connections.clone();
294        let event_tx = self.event_tx.clone();
295        let enable_auto_reconnect = self.enable_auto_reconnect;
296
297        tokio::spawn(async move {
298            let mut interval_timer = tokio::time::interval(interval);
299
300            loop {
301                interval_timer.tick().await;
302
303                let mut conns = connections.write().await;
304                if let Some(state) = conns.get_mut(&connection_id) {
305                    // Check if transport is still connected
306                    if state.transport.state() != TransportState::Connected {
307                        // Emit heartbeat failed event
308                        if let Some(tx) = event_tx.lock().await.as_ref() {
309                            let _ = tx
310                                .send(ConnectionEvent::HeartbeatFailed(
311                                    connection_id.clone(),
312                                    "Transport disconnected".to_string(),
313                                ))
314                                .await;
315                        }
316
317                        // Attempt reconnection if enabled
318                        if enable_auto_reconnect {
319                            state.info.status = ConnectionStatus::Reconnecting;
320                            // Reconnection will be handled by the reconnect logic
321                        }
322                        break;
323                    }
324
325                    // Send ping request to check connection health
326                    let ping_request = McpRequest::new(
327                        serde_json::json!(format!("ping-{}", Uuid::new_v4())),
328                        "ping",
329                    );
330
331                    match state.transport.send_request(ping_request).await {
332                        Ok(_) => {
333                            state.last_heartbeat = Some(Utc::now());
334                            state.info.last_activity = Utc::now();
335                        }
336                        Err(e) => {
337                            // Emit heartbeat failed event
338                            if let Some(tx) = event_tx.lock().await.as_ref() {
339                                let _ = tx
340                                    .send(ConnectionEvent::HeartbeatFailed(
341                                        connection_id.clone(),
342                                        e.to_string(),
343                                    ))
344                                    .await;
345                            }
346
347                            if enable_auto_reconnect {
348                                state.info.status = ConnectionStatus::Reconnecting;
349                            }
350                            break;
351                        }
352                    }
353                } else {
354                    // Connection no longer exists
355                    break;
356                }
357            }
358        });
359    }
360
361    /// Calculate reconnection delay with exponential backoff
362    pub fn calculate_reconnect_delay(&self, attempt: u32) -> Duration {
363        let base = self.default_options.reconnect_delay_base.as_millis() as u64;
364        let max = self.default_options.reconnect_delay_max.as_millis() as u64;
365
366        // Exponential backoff: base * 2^attempt
367        let delay_ms = base.saturating_mul(1u64 << attempt.min(10));
368        Duration::from_millis(delay_ms.min(max))
369    }
370
371    /// Attempt to reconnect a disconnected connection
372    ///
373    /// This method implements automatic reconnection with exponential backoff.
374    /// It will retry up to `max_retries` times before giving up.
375    pub async fn reconnect(&self, connection_id: &str) -> McpResult<McpConnection> {
376        let (server_info, max_retries) = {
377            let conns = self.connections.read().await;
378            if let Some(state) = conns.get(connection_id) {
379                (state.server_info.clone(), self.default_options.max_retries)
380            } else {
381                return Err(McpError::connection(format!(
382                    "Connection not found: {}",
383                    connection_id
384                )));
385            }
386        };
387
388        // Update status to reconnecting
389        {
390            let mut conns = self.connections.write().await;
391            if let Some(state) = conns.get_mut(connection_id) {
392                state.info.status = ConnectionStatus::Reconnecting;
393                self.emit_event(ConnectionEvent::Reconnecting(state.info.clone()))
394                    .await;
395            }
396        }
397
398        let mut last_error = None;
399
400        for attempt in 0..=max_retries {
401            if attempt > 0 {
402                let delay = self.calculate_reconnect_delay(attempt - 1);
403                tokio::time::sleep(delay).await;
404            }
405
406            // Try to reconnect
407            match self.try_reconnect(connection_id, &server_info).await {
408                Ok(connection) => {
409                    // Reset reconnect attempts on success
410                    {
411                        let mut conns = self.connections.write().await;
412                        if let Some(state) = conns.get_mut(connection_id) {
413                            state.reconnect_attempts = 0;
414                        }
415                    }
416                    return Ok(connection);
417                }
418                Err(e) => {
419                    last_error = Some(e);
420                    // Update reconnect attempts
421                    {
422                        let mut conns = self.connections.write().await;
423                        if let Some(state) = conns.get_mut(connection_id) {
424                            state.reconnect_attempts = attempt + 1;
425                        }
426                    }
427                }
428            }
429        }
430
431        // All retries failed
432        {
433            let mut conns = self.connections.write().await;
434            if let Some(state) = conns.get_mut(connection_id) {
435                state.info.status = ConnectionStatus::Error;
436                self.emit_event(ConnectionEvent::Error(
437                    state.info.clone(),
438                    last_error
439                        .as_ref()
440                        .map(|e| e.to_string())
441                        .unwrap_or_else(|| "Unknown error".to_string()),
442                ))
443                .await;
444            }
445        }
446
447        Err(last_error.unwrap_or_else(|| McpError::connection("Reconnection failed after retries")))
448    }
449
450    /// Internal method to attempt a single reconnection
451    async fn try_reconnect(
452        &self,
453        connection_id: &str,
454        server_info: &McpServerInfo,
455    ) -> McpResult<McpConnection> {
456        // Create new transport
457        let transport_config = Self::create_transport_config(server_info)?;
458        let mut transport =
459            TransportFactory::create(transport_config, server_info.options.clone())?;
460
461        // Connect transport
462        transport.connect().await?;
463
464        // Create new connection info
465        let mut connection = McpConnection::new(
466            connection_id.to_string(),
467            server_info.name.clone(),
468            server_info.transport_type,
469        );
470
471        // Perform handshake
472        Self::perform_handshake(&mut transport, &mut connection).await?;
473
474        // Update connection status
475        connection.status = ConnectionStatus::Connected;
476        connection.touch();
477
478        // Update stored connection
479        {
480            let mut conns = self.connections.write().await;
481            if let Some(state) = conns.get_mut(connection_id) {
482                state.info = connection.clone();
483                state.transport = transport;
484                state.last_heartbeat = Some(Utc::now());
485            }
486        }
487
488        // Emit established event
489        self.emit_event(ConnectionEvent::Established(connection.clone()))
490            .await;
491
492        Ok(connection)
493    }
494}
495
496impl Default for McpConnectionManager {
497    fn default() -> Self {
498        Self::new()
499    }
500}
501
502#[async_trait]
503impl ConnectionManager for McpConnectionManager {
504    async fn connect(&self, server: McpServerInfo) -> McpResult<McpConnection> {
505        // Check if already connected to this server
506        {
507            let server_map = self.server_to_connection.read().await;
508            if let Some(conn_id) = server_map.get(&server.name) {
509                let conns = self.connections.read().await;
510                if let Some(state) = conns.get(conn_id) {
511                    if state.info.status == ConnectionStatus::Connected {
512                        return Ok(state.info.clone());
513                    }
514                }
515            }
516        }
517
518        // Create connection ID and info
519        let connection_id = Self::generate_connection_id();
520        let mut connection = McpConnection::new(
521            connection_id.clone(),
522            server.name.clone(),
523            server.transport_type,
524        );
525
526        // Emit establishing event
527        self.emit_event(ConnectionEvent::Establishing(connection.clone()))
528            .await;
529
530        // Create transport config
531        let transport_config = Self::create_transport_config(&server)?;
532
533        // Create and connect transport
534        let options = server.options.clone();
535        let mut transport = TransportFactory::create(transport_config, options.clone())?;
536
537        transport.connect().await?;
538
539        // Perform MCP handshake
540        Self::perform_handshake(&mut transport, &mut connection).await?;
541
542        // Update connection status
543        connection.status = ConnectionStatus::Connected;
544        connection.touch();
545
546        // Store connection
547        {
548            let mut conns = self.connections.write().await;
549            conns.insert(
550                connection_id.clone(),
551                ConnectionState {
552                    info: connection.clone(),
553                    transport,
554                    server_info: server.clone(),
555                    reconnect_attempts: 0,
556                    last_heartbeat: Some(Utc::now()),
557                    heartbeat_handle: None,
558                },
559            );
560        }
561
562        // Update server mapping
563        {
564            let mut server_map = self.server_to_connection.write().await;
565            server_map.insert(server.name.clone(), connection_id.clone());
566        }
567
568        // Start heartbeat if enabled
569        if self.enable_heartbeat {
570            self.start_heartbeat(connection_id, options.heartbeat_interval);
571        }
572
573        // Emit established event
574        self.emit_event(ConnectionEvent::Established(connection.clone()))
575            .await;
576
577        Ok(connection)
578    }
579
580    async fn disconnect(&self, connection_id: &str) -> McpResult<()> {
581        let mut conns = self.connections.write().await;
582
583        if let Some(mut state) = conns.remove(connection_id) {
584            // Cancel heartbeat task
585            if let Some(handle) = state.heartbeat_handle.take() {
586                handle.abort();
587            }
588
589            // Disconnect transport
590            state.transport.disconnect().await?;
591
592            // Update status
593            state.info.status = ConnectionStatus::Disconnected;
594
595            // Remove from server mapping
596            {
597                let mut server_map = self.server_to_connection.write().await;
598                server_map.remove(&state.info.server_name);
599            }
600
601            // Emit closed event
602            self.emit_event(ConnectionEvent::Closed(state.info)).await;
603
604            Ok(())
605        } else {
606            Err(McpError::connection(format!(
607                "Connection not found: {}",
608                connection_id
609            )))
610        }
611    }
612
613    async fn disconnect_all(&self) -> McpResult<()> {
614        let connection_ids: Vec<String> = {
615            let conns = self.connections.read().await;
616            conns.keys().cloned().collect()
617        };
618
619        for id in connection_ids {
620            if let Err(e) = self.disconnect(&id).await {
621                tracing::warn!("Failed to disconnect {}: {}", id, e);
622            }
623        }
624
625        Ok(())
626    }
627
628    async fn send(&self, connection_id: &str, request: McpRequest) -> McpResult<McpResponse> {
629        let mut conns = self.connections.write().await;
630
631        if let Some(state) = conns.get_mut(connection_id) {
632            if state.info.status != ConnectionStatus::Connected {
633                return Err(McpError::connection("Connection is not active"));
634            }
635
636            let response = state.transport.send_request(request).await?;
637            state.info.touch();
638
639            Ok(response)
640        } else {
641            Err(McpError::connection(format!(
642                "Connection not found: {}",
643                connection_id
644            )))
645        }
646    }
647
648    async fn send_with_timeout(
649        &self,
650        connection_id: &str,
651        request: McpRequest,
652        timeout: Duration,
653    ) -> McpResult<McpResponse> {
654        let mut conns = self.connections.write().await;
655
656        if let Some(state) = conns.get_mut(connection_id) {
657            if state.info.status != ConnectionStatus::Connected {
658                return Err(McpError::connection("Connection is not active"));
659            }
660
661            let response = state
662                .transport
663                .send_request_with_timeout(request, timeout)
664                .await?;
665            state.info.touch();
666
667            Ok(response)
668        } else {
669            Err(McpError::connection(format!(
670                "Connection not found: {}",
671                connection_id
672            )))
673        }
674    }
675
676    async fn send_with_retry(
677        &self,
678        connection_id: &str,
679        request: McpRequest,
680    ) -> McpResult<McpResponse> {
681        let max_retries = self.default_options.max_retries;
682        let mut last_error = None;
683
684        for attempt in 0..=max_retries {
685            match self.send(connection_id, request.clone()).await {
686                Ok(response) => return Ok(response),
687                Err(e) => {
688                    last_error = Some(e);
689                    if attempt < max_retries {
690                        let delay = self.calculate_reconnect_delay(attempt);
691                        tokio::time::sleep(delay).await;
692                    }
693                }
694            }
695        }
696
697        Err(last_error.unwrap_or_else(|| McpError::connection("Request failed after retries")))
698    }
699
700    async fn cancel_request(&self, connection_id: &str, request_id: &str) -> McpResult<()> {
701        let mut conns = self.connections.write().await;
702
703        if let Some(state) = conns.get_mut(connection_id) {
704            if state.info.status != ConnectionStatus::Connected {
705                return Err(McpError::connection("Connection is not active"));
706            }
707
708            // Send cancellation notification per MCP protocol
709            let cancel_notification = crate::mcp::transport::McpNotification::with_params(
710                "notifications/cancelled",
711                serde_json::json!({
712                    "requestId": request_id,
713                    "reason": "Cancelled by client"
714                }),
715            );
716
717            state
718                .transport
719                .send(crate::mcp::transport::McpMessage::Notification(
720                    cancel_notification,
721                ))
722                .await?;
723
724            Ok(())
725        } else {
726            Err(McpError::connection(format!(
727                "Connection not found: {}",
728                connection_id
729            )))
730        }
731    }
732
733    fn get_connection(&self, id: &str) -> Option<McpConnection> {
734        // Use try_read to avoid blocking
735        self.connections
736            .try_read()
737            .ok()
738            .and_then(|conns| conns.get(id).map(|s| s.info.clone()))
739    }
740
741    fn get_connection_by_server(&self, server_name: &str) -> Option<McpConnection> {
742        let server_map = self.server_to_connection.try_read().ok()?;
743        let conn_id = server_map.get(server_name)?;
744        self.get_connection(conn_id)
745    }
746
747    fn get_all_connections(&self) -> Vec<McpConnection> {
748        self.connections
749            .try_read()
750            .map(|conns| conns.values().map(|s| s.info.clone()).collect())
751            .unwrap_or_default()
752    }
753
754    fn subscribe(&self) -> mpsc::Receiver<ConnectionEvent> {
755        let (tx, rx) = mpsc::channel(100);
756        let event_tx = self.event_tx.clone();
757        tokio::spawn(async move {
758            *event_tx.lock().await = Some(tx);
759        });
760        rx
761    }
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal server description named "test"; callers pick transport and endpoint.
    fn server_with(
        transport_type: TransportType,
        command: Option<&str>,
        url: Option<&str>,
    ) -> McpServerInfo {
        McpServerInfo {
            name: "test".to_string(),
            transport_type,
            command: command.map(String::from),
            args: None,
            env: None,
            url: url.map(String::from),
            headers: None,
            options: ConnectionOptions::default(),
        }
    }

    #[test]
    fn test_connection_manager_new() {
        let manager = McpConnectionManager::new();
        assert_eq!(manager.get_all_connections().len(), 0);
    }

    #[test]
    fn test_connection_manager_with_options() {
        let manager = McpConnectionManager::with_options(ConnectionOptions {
            timeout: Duration::from_secs(60),
            max_retries: 5,
            ..Default::default()
        });
        assert_eq!(manager.default_options.timeout, Duration::from_secs(60));
        assert_eq!(manager.default_options.max_retries, 5);
    }

    #[test]
    fn test_generate_connection_id() {
        let first = McpConnectionManager::generate_connection_id();
        let second = McpConnectionManager::generate_connection_id();
        assert_ne!(first, second);
        // IDs are UUID strings.
        assert!(Uuid::parse_str(&first).is_ok());
    }

    #[test]
    fn test_next_request_id() {
        let manager = McpConnectionManager::new();
        let first = manager.next_request_id();
        let second = manager.next_request_id();
        assert_ne!(first, second);
        assert!(first.starts_with("mcp-req-"));
    }

    #[test]
    fn test_calculate_reconnect_delay() {
        let manager = McpConnectionManager::new();
        let delays: Vec<Duration> = (0..3)
            .map(|attempt| manager.calculate_reconnect_delay(attempt))
            .collect();

        // Backoff grows with each attempt...
        assert!(delays[1] > delays[0]);
        assert!(delays[2] > delays[1]);

        // ...but never exceeds the configured cap.
        let capped = manager.calculate_reconnect_delay(100);
        assert!(capped <= manager.default_options.reconnect_delay_max);
    }

    #[test]
    fn test_create_transport_config_stdio() {
        let mut server = server_with(TransportType::Stdio, Some("node"), None);
        server.args = Some(vec!["server.js".to_string()]);

        let config = McpConnectionManager::create_transport_config(&server);
        assert!(config.is_ok());
        assert_eq!(config.unwrap().transport_type(), TransportType::Stdio);
    }

    #[test]
    fn test_create_transport_config_http() {
        let server = server_with(TransportType::Http, None, Some("http://localhost:8080"));

        let config = McpConnectionManager::create_transport_config(&server);
        assert!(config.is_ok());
        assert_eq!(config.unwrap().transport_type(), TransportType::Http);
    }

    #[test]
    fn test_create_transport_config_missing_command() {
        // Stdio without a command must be rejected.
        let server = server_with(TransportType::Stdio, None, None);
        assert!(McpConnectionManager::create_transport_config(&server).is_err());
    }

    #[test]
    fn test_create_transport_config_missing_url() {
        // HTTP without a URL must be rejected.
        let server = server_with(TransportType::Http, None, None);
        assert!(McpConnectionManager::create_transport_config(&server).is_err());
    }

    #[tokio::test]
    async fn test_get_connection_not_found() {
        let manager = McpConnectionManager::new();
        assert!(manager.get_connection("nonexistent").is_none());
    }

    #[tokio::test]
    async fn test_get_connection_by_server_not_found() {
        let manager = McpConnectionManager::new();
        assert!(manager.get_connection_by_server("nonexistent").is_none());
    }

    #[tokio::test]
    async fn test_disconnect_not_found() {
        let manager = McpConnectionManager::new();
        assert!(manager.disconnect("nonexistent").await.is_err());
    }

    #[tokio::test]
    async fn test_send_not_found() {
        let manager = McpConnectionManager::new();
        let request = McpRequest::new(serde_json::json!(1), "test");
        assert!(manager.send("nonexistent", request).await.is_err());
    }
}