Skip to main content

adk_tool/mcp/
reconnect.rs

1// MCP Connection Refresher
2//
3// Provides automatic reconnection for MCP connections when they fail.
4// Based on adk-go's connectionRefresher pattern.
5//
6// Handles:
7// - Connection closed errors
8// - EOF errors
9// - Session not found errors
10// - Automatic retry with reconnection
11
12use rmcp::{
13    RoleClient,
14    model::{CallToolRequestParams, CallToolResult, Tool as McpTool},
15    service::RunningService,
16};
17use std::sync::Arc;
18use tokio::sync::Mutex;
19use tracing::{debug, info, warn};
20
/// Returns `true` when an error message suggests the MCP session is dead or
/// stale, so the failed operation should be retried on a fresh connection.
///
/// The check lower-cases `error` and looks for any of these substrings:
/// connection closed, EOF / closed or broken pipe, session not found or
/// missing (server restarted), transport errors, and connection resets.
/// Anything else (invalid arguments, tool failures, ...) yields `false`.
///
/// The match is purely textual, so e.g. `"eof"` also matches inside longer
/// words.
pub fn should_refresh_connection(error: &str) -> bool {
    // Lower-case markers; finding any one of them triggers a refresh.
    const REFRESH_MARKERS: [&str; 9] = [
        // Connection closed
        "connection closed",
        "connectionclosed",
        // EOF / pipe closed
        "eof",
        "closed pipe",
        "broken pipe",
        // Session not found (server restarted)
        "session not found",
        "session missing",
        // Transport errors
        "transport error",
        "connection reset",
    ];

    let normalized = error.to_lowercase();
    REFRESH_MARKERS.iter().any(|&marker| normalized.contains(marker))
}
50
/// Value produced by a retrying operation (such as
/// `ConnectionRefresher::call_tool`), tagged with whether a reconnection
/// happened along the way.
#[derive(Debug, Clone)]
pub struct RetryResult<T> {
    /// The value produced by the operation.
    pub value: T,
    /// Whether a reconnection occurred before the value was obtained.
    pub reconnected: bool,
}

impl<T> RetryResult<T> {
    /// Wrap a value obtained without reconnecting (`reconnected == false`).
    pub fn ok(value: T) -> Self {
        Self::tagged(value, false)
    }

    /// Wrap a value obtained after reconnecting (`reconnected == true`).
    pub fn reconnected(value: T) -> Self {
        Self::tagged(value, true)
    }

    // Shared constructor behind `ok` and `reconnected`.
    fn tagged(value: T, reconnected: bool) -> Self {
        RetryResult { value, reconnected }
    }
}
71
/// Tuning knobs for how failed operations are retried with reconnection.
///
/// Defaults: 3 reconnection attempts, 1000 ms between attempts, logging on.
#[derive(Debug, Clone)]
pub struct RefreshConfig {
    /// Maximum number of reconnection attempts
    pub max_attempts: u32,
    /// Delay between reconnection attempts in milliseconds
    pub retry_delay_ms: u64,
    /// Whether to log reconnection attempts
    pub log_reconnections: bool,
}

impl Default for RefreshConfig {
    fn default() -> Self {
        RefreshConfig {
            max_attempts: 3,
            retry_delay_ms: 1000,
            log_reconnections: true,
        }
    }
}

impl RefreshConfig {
    /// Return this config with `max_attempts` set to `attempts`.
    pub fn with_max_attempts(self, attempts: u32) -> Self {
        RefreshConfig { max_attempts: attempts, ..self }
    }

    /// Return this config with `retry_delay_ms` set to `delay_ms`.
    pub fn with_retry_delay_ms(self, delay_ms: u64) -> Self {
        RefreshConfig { retry_delay_ms: delay_ms, ..self }
    }

    /// Return this config with reconnection logging switched off.
    pub fn without_logging(self) -> Self {
        RefreshConfig { log_reconnections: false, ..self }
    }
}
108
109/// Factory trait for creating new MCP connections.
110///
111/// Implement this trait to provide reconnection capability to `ConnectionRefresher`.
112#[async_trait::async_trait]
113pub trait ConnectionFactory<S>: Send + Sync
114where
115    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
116{
117    /// Create a new connection to the MCP server.
118    async fn create_connection(&self) -> Result<RunningService<RoleClient, S>, String>;
119}
120
121/// Connection refresher that wraps an MCP client and handles automatic reconnection.
122///
123/// This is similar to adk-go's `connectionRefresher` struct. It transparently
124/// retries operations after reconnecting when the underlying session fails.
125///
126/// # Type Parameters
127///
128/// * `S` - The service type for the MCP client
129/// * `F` - The factory type for creating new connections
130///
131/// # Example
132///
133/// ```rust,ignore
134/// use adk_tool::mcp::{ConnectionRefresher, ConnectionFactory};
135///
136/// struct MyFactory { /* ... */ }
137///
138/// #[async_trait::async_trait]
139/// impl ConnectionFactory<MyService> for MyFactory {
140///     async fn create_connection(&self) -> Result<RunningService<RoleClient, MyService>, String> {
141///         // Create and return a new connection
142///     }
143/// }
144///
145/// let refresher = ConnectionRefresher::new(initial_client, Arc::new(factory));
146///
147/// // Operations automatically retry on connection failure
148/// let tools = refresher.list_tools().await?;
149/// let result = refresher.call_tool(params).await?;
150/// ```
151pub struct ConnectionRefresher<S, F>
152where
153    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
154    F: ConnectionFactory<S>,
155{
156    /// The current MCP client session
157    client: Arc<Mutex<Option<RunningService<RoleClient, S>>>>,
158    /// Factory for creating new connections
159    factory: Arc<F>,
160    /// Configuration for refresh behavior
161    config: RefreshConfig,
162}
163
164impl<S, F> ConnectionRefresher<S, F>
165where
166    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
167    F: ConnectionFactory<S>,
168{
169    /// Create a new connection refresher with an initial client and factory.
170    ///
171    /// # Arguments
172    ///
173    /// * `client` - The initial MCP client connection
174    /// * `factory` - Factory for creating new connections when needed
175    pub fn new(client: RunningService<RoleClient, S>, factory: Arc<F>) -> Self {
176        Self {
177            client: Arc::new(Mutex::new(Some(client))),
178            factory,
179            config: RefreshConfig::default(),
180        }
181    }
182
183    /// Create a new connection refresher without an initial connection.
184    ///
185    /// The first operation will trigger a connection.
186    pub fn lazy(factory: Arc<F>) -> Self {
187        Self { client: Arc::new(Mutex::new(None)), factory, config: RefreshConfig::default() }
188    }
189
190    /// Set the refresh configuration.
191    pub fn with_config(mut self, config: RefreshConfig) -> Self {
192        self.config = config;
193        self
194    }
195
196    /// Set the maximum number of reconnection attempts.
197    pub fn with_max_attempts(mut self, attempts: u32) -> Self {
198        self.config.max_attempts = attempts;
199        self
200    }
201
202    /// Ensure we have a valid connection, creating one if needed.
203    async fn ensure_connected(&self) -> Result<(), String> {
204        let mut guard = self.client.lock().await;
205
206        if guard.is_none() {
207            if self.config.log_reconnections {
208                info!("MCP client not connected, creating connection");
209            }
210            let new_client = self.factory.create_connection().await?;
211            *guard = Some(new_client);
212        }
213
214        Ok(())
215    }
216
217    /// Refresh the connection by creating a new client.
218    async fn refresh_connection(&self) -> Result<(), String> {
219        let mut guard = self.client.lock().await;
220
221        // Close existing connection if any
222        if let Some(old_client) = guard.take() {
223            if self.config.log_reconnections {
224                debug!("Closing old MCP connection");
225            }
226            let token = old_client.cancellation_token();
227            token.cancel();
228            // Give it a moment to clean up
229            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
230        }
231
232        if self.config.log_reconnections {
233            info!("Refreshing MCP connection");
234        }
235        let new_client = self.factory.create_connection().await?;
236        *guard = Some(new_client);
237
238        Ok(())
239    }
240
241    /// List all tools from the MCP server with automatic reconnection.
242    ///
243    /// Handles pagination internally and restarts from scratch if
244    /// reconnection occurs (per MCP spec, cursors don't persist across sessions).
245    pub async fn list_tools(&self) -> Result<RetryResult<Vec<McpTool>>, String> {
246        // Ensure we have a connection
247        self.ensure_connected().await?;
248
249        // First attempt
250        {
251            let guard = self.client.lock().await;
252            if let Some(ref client) = *guard {
253                match client.list_all_tools().await {
254                    Ok(tools) => return Ok(RetryResult::ok(tools)),
255                    Err(e) => {
256                        let error_str = e.to_string();
257                        if !should_refresh_connection(&error_str) {
258                            return Err(error_str);
259                        }
260                        if self.config.log_reconnections {
261                            warn!(error = %error_str, "list_tools failed, will retry with reconnection");
262                        }
263                    }
264                }
265            }
266        }
267
268        // Retry with reconnection
269        for attempt in 1..=self.config.max_attempts {
270            if self.config.log_reconnections {
271                info!(
272                    attempt = attempt,
273                    max = self.config.max_attempts,
274                    "Reconnection attempt for list_tools"
275                );
276            }
277
278            // Wait before retry
279            if self.config.retry_delay_ms > 0 {
280                tokio::time::sleep(tokio::time::Duration::from_millis(self.config.retry_delay_ms))
281                    .await;
282            }
283
284            // Try to refresh
285            if let Err(e) = self.refresh_connection().await {
286                if self.config.log_reconnections {
287                    warn!(error = %e, attempt = attempt, "Refresh failed");
288                }
289                continue;
290            }
291
292            // Retry operation
293            let guard = self.client.lock().await;
294            if let Some(ref client) = *guard {
295                match client.list_all_tools().await {
296                    Ok(tools) => {
297                        if self.config.log_reconnections {
298                            debug!(
299                                attempt = attempt,
300                                tool_count = tools.len(),
301                                "list_tools succeeded after reconnection"
302                            );
303                        }
304                        return Ok(RetryResult::reconnected(tools));
305                    }
306                    Err(e) => {
307                        if self.config.log_reconnections {
308                            warn!(error = %e, attempt = attempt, "list_tools failed after reconnection");
309                        }
310                    }
311                }
312            }
313        }
314
315        // Final attempt
316        let guard = self.client.lock().await;
317        if let Some(ref client) = *guard {
318            client.list_all_tools().await.map(RetryResult::ok).map_err(|e| e.to_string())
319        } else {
320            Err("No MCP client available".to_string())
321        }
322    }
323
324    /// Call a tool on the MCP server with automatic reconnection.
325    pub async fn call_tool(
326        &self,
327        params: CallToolRequestParams,
328    ) -> Result<RetryResult<CallToolResult>, String> {
329        // Ensure we have a connection
330        self.ensure_connected().await?;
331
332        // First attempt
333        {
334            let guard = self.client.lock().await;
335            if let Some(ref client) = *guard {
336                match client.call_tool(params.clone()).await {
337                    Ok(result) => return Ok(RetryResult::ok(result)),
338                    Err(e) => {
339                        let error_str = e.to_string();
340                        if !should_refresh_connection(&error_str) {
341                            return Err(error_str);
342                        }
343                        if self.config.log_reconnections {
344                            warn!(error = %error_str, tool = %params.name, "call_tool failed, will retry with reconnection");
345                        }
346                    }
347                }
348            }
349        }
350
351        // Retry with reconnection
352        for attempt in 1..=self.config.max_attempts {
353            if self.config.log_reconnections {
354                info!(attempt = attempt, max = self.config.max_attempts, tool = %params.name, "Reconnection attempt for call_tool");
355            }
356
357            // Wait before retry
358            if self.config.retry_delay_ms > 0 {
359                tokio::time::sleep(tokio::time::Duration::from_millis(self.config.retry_delay_ms))
360                    .await;
361            }
362
363            // Try to refresh
364            if let Err(e) = self.refresh_connection().await {
365                if self.config.log_reconnections {
366                    warn!(error = %e, attempt = attempt, "Refresh failed");
367                }
368                continue;
369            }
370
371            // Retry operation
372            let guard = self.client.lock().await;
373            if let Some(ref client) = *guard {
374                match client.call_tool(params.clone()).await {
375                    Ok(result) => {
376                        if self.config.log_reconnections {
377                            debug!(attempt = attempt, tool = %params.name, "call_tool succeeded after reconnection");
378                        }
379                        return Ok(RetryResult::reconnected(result));
380                    }
381                    Err(e) => {
382                        if self.config.log_reconnections {
383                            warn!(error = %e, attempt = attempt, "call_tool failed after reconnection");
384                        }
385                    }
386                }
387            }
388        }
389
390        // Final attempt
391        let guard = self.client.lock().await;
392        if let Some(ref client) = *guard {
393            client.call_tool(params).await.map(RetryResult::ok).map_err(|e| e.to_string())
394        } else {
395            Err("No MCP client available".to_string())
396        }
397    }
398
399    /// Get the cancellation token for the current connection.
400    pub async fn cancellation_token(
401        &self,
402    ) -> Option<rmcp::service::RunningServiceCancellationToken> {
403        let guard = self.client.lock().await;
404        guard.as_ref().map(|c| c.cancellation_token())
405    }
406
407    /// Check if currently connected.
408    pub async fn is_connected(&self) -> bool {
409        let guard = self.client.lock().await;
410        guard.is_some()
411    }
412
413    /// Force a reconnection.
414    pub async fn reconnect(&self) -> Result<(), String> {
415        self.refresh_connection().await
416    }
417
418    /// Close the connection.
419    pub async fn close(&self) {
420        let mut guard = self.client.lock().await;
421        if let Some(client) = guard.take() {
422            let token = client.cancellation_token();
423            token.cancel();
424        }
425    }
426}
427
428/// Simple wrapper for MCP clients that don't support reconnection.
429///
430/// Use this for stdio-based MCP servers where reconnection isn't possible
431/// without restarting the server process.
432pub struct SimpleClient<S>
433where
434    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
435{
436    client: Arc<Mutex<RunningService<RoleClient, S>>>,
437}
438
439impl<S> SimpleClient<S>
440where
441    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
442{
443    /// Create a new simple client wrapper.
444    pub fn new(client: RunningService<RoleClient, S>) -> Self {
445        Self { client: Arc::new(Mutex::new(client)) }
446    }
447
448    /// List all tools from the MCP server.
449    pub async fn list_tools(&self) -> Result<Vec<McpTool>, String> {
450        let client = self.client.lock().await;
451        client.list_all_tools().await.map_err(|e| e.to_string())
452    }
453
454    /// Call a tool on the MCP server.
455    pub async fn call_tool(&self, params: CallToolRequestParams) -> Result<CallToolResult, String> {
456        let client = self.client.lock().await;
457        client.call_tool(params).await.map_err(|e| e.to_string())
458    }
459
460    /// Get the cancellation token.
461    pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
462        let client = self.client.lock().await;
463        client.cancellation_token()
464    }
465
466    /// Get access to the underlying client mutex.
467    pub fn inner(&self) -> &Arc<Mutex<RunningService<RoleClient, S>>> {
468        &self.client
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_refresh_connection() {
        assert!(should_refresh_connection("connection closed"));
        assert!(should_refresh_connection("ConnectionClosed"));
        assert!(should_refresh_connection("EOF"));
        assert!(should_refresh_connection("eof error"));
        assert!(should_refresh_connection("broken pipe"));
        assert!(should_refresh_connection("closed pipe"));
        assert!(should_refresh_connection("session not found"));
        assert!(should_refresh_connection("session missing"));
        assert!(should_refresh_connection("transport error"));
        assert!(should_refresh_connection("connection reset"));

        // Matching is case-insensitive and works on substrings of longer messages.
        assert!(should_refresh_connection("Transport Error: stream ended"));
        assert!(should_refresh_connection("Connection reset by peer"));
        assert!(should_refresh_connection("Session Not Found for id 42"));

        // Should not refresh for other errors
        assert!(!should_refresh_connection("invalid argument"));
        assert!(!should_refresh_connection("permission denied"));
        assert!(!should_refresh_connection("tool not found"));
        assert!(!should_refresh_connection(""));
    }

    #[test]
    fn test_refresh_config_default() {
        let config = RefreshConfig::default();
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.retry_delay_ms, 1000);
        assert!(config.log_reconnections);
    }

    #[test]
    fn test_refresh_config_builder() {
        let config = RefreshConfig::default()
            .with_max_attempts(5)
            .with_retry_delay_ms(500)
            .without_logging();

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.retry_delay_ms, 500);
        assert!(!config.log_reconnections);
    }

    #[test]
    fn test_refresh_config_builder_keeps_other_fields() {
        let config = RefreshConfig::default().with_retry_delay_ms(0);

        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.retry_delay_ms, 0);
        assert!(config.log_reconnections);
    }

    #[test]
    fn test_retry_result() {
        let ok_result = RetryResult::ok(42);
        assert_eq!(ok_result.value, 42);
        assert!(!ok_result.reconnected);

        let reconnected_result = RetryResult::reconnected(42);
        assert_eq!(reconnected_result.value, 42);
        assert!(reconnected_result.reconnected);
    }
}