1use rmcp::{
13 RoleClient,
14 model::{CallToolRequestParams, CallToolResult, Tool as McpTool},
15 service::RunningService,
16};
17use std::sync::Arc;
18use tokio::sync::Mutex;
19use tracing::{debug, info, warn};
20
/// Decides whether an MCP error message looks like a dead transport or session
/// that opening a fresh connection could cure.
///
/// The check is a case-insensitive substring search against a fixed list of
/// markers (closed connections, EOF / pipe failures, lost sessions, transport
/// errors, connection resets). Anything else is treated as a non-connection
/// error that reconnecting would not fix.
pub fn should_refresh_connection(error: &str) -> bool {
    const RECONNECT_MARKERS: &[&str] = &[
        // Peer or transport shut the connection down.
        "connection closed",
        "connectionclosed",
        // Stream-level failures on stdio / pipe transports.
        "eof",
        "closed pipe",
        "broken pipe",
        // Server no longer knows our session.
        "session not found",
        "session missing",
        // Generic transport breakage.
        "transport error",
        "connection reset",
    ];

    let haystack = error.to_lowercase();
    RECONNECT_MARKERS.iter().any(|marker| haystack.contains(marker))
}
50
/// A successful value together with whether it needed a reconnection to obtain.
#[derive(Debug, Clone)]
pub struct RetryResult<T> {
    /// The value returned by the underlying operation.
    pub value: T,
    /// `true` when the value was produced only after the connection was re-created.
    pub reconnected: bool,
}

impl<T> RetryResult<T> {
    /// Shared constructor for the two public variants below.
    fn tagged(value: T, reconnected: bool) -> Self {
        RetryResult { value, reconnected }
    }

    /// Wraps a value obtained on the existing connection (no reconnection).
    pub fn ok(value: T) -> Self {
        Self::tagged(value, false)
    }

    /// Wraps a value obtained after the connection was re-established.
    pub fn reconnected(value: T) -> Self {
        Self::tagged(value, true)
    }
}
71
/// Tuning knobs for the reconnect-and-retry behaviour of `ConnectionRefresher`.
#[derive(Debug, Clone)]
pub struct RefreshConfig {
    /// How many reconnect-and-retry rounds may follow a connection-level failure.
    pub max_attempts: u32,
    /// Pause before each reconnection round, in milliseconds; `0` disables the pause.
    pub retry_delay_ms: u64,
    /// Whether reconnection progress is reported through `tracing`.
    pub log_reconnections: bool,
}

impl Default for RefreshConfig {
    /// Three attempts, one second apart, with logging enabled.
    fn default() -> Self {
        RefreshConfig {
            log_reconnections: true,
            retry_delay_ms: 1_000,
            max_attempts: 3,
        }
    }
}

impl RefreshConfig {
    /// Returns this config with `max_attempts` replaced by `attempts`.
    pub fn with_max_attempts(self, attempts: u32) -> Self {
        Self { max_attempts: attempts, ..self }
    }

    /// Returns this config with `retry_delay_ms` replaced by `delay_ms`.
    pub fn with_retry_delay_ms(self, delay_ms: u64) -> Self {
        Self { retry_delay_ms: delay_ms, ..self }
    }

    /// Returns this config with reconnection logging switched off.
    pub fn without_logging(self) -> Self {
        Self { log_reconnections: false, ..self }
    }
}
108
109#[async_trait::async_trait]
113pub trait ConnectionFactory<S>: Send + Sync
114where
115 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
116{
117 async fn create_connection(&self) -> Result<RunningService<RoleClient, S>, String>;
119}
120
121pub struct ConnectionRefresher<S, F>
152where
153 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
154 F: ConnectionFactory<S>,
155{
156 client: Arc<Mutex<Option<RunningService<RoleClient, S>>>>,
158 factory: Arc<F>,
160 config: RefreshConfig,
162}
163
164impl<S, F> ConnectionRefresher<S, F>
165where
166 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
167 F: ConnectionFactory<S>,
168{
169 pub fn new(client: RunningService<RoleClient, S>, factory: Arc<F>) -> Self {
176 Self {
177 client: Arc::new(Mutex::new(Some(client))),
178 factory,
179 config: RefreshConfig::default(),
180 }
181 }
182
183 pub fn lazy(factory: Arc<F>) -> Self {
187 Self { client: Arc::new(Mutex::new(None)), factory, config: RefreshConfig::default() }
188 }
189
190 pub fn with_config(mut self, config: RefreshConfig) -> Self {
192 self.config = config;
193 self
194 }
195
196 pub fn with_max_attempts(mut self, attempts: u32) -> Self {
198 self.config.max_attempts = attempts;
199 self
200 }
201
202 async fn ensure_connected(&self) -> Result<(), String> {
204 let mut guard = self.client.lock().await;
205
206 if guard.is_none() {
207 if self.config.log_reconnections {
208 info!("MCP client not connected, creating connection");
209 }
210 let new_client = self.factory.create_connection().await?;
211 *guard = Some(new_client);
212 }
213
214 Ok(())
215 }
216
217 async fn refresh_connection(&self) -> Result<(), String> {
219 let mut guard = self.client.lock().await;
220
221 if let Some(old_client) = guard.take() {
223 if self.config.log_reconnections {
224 debug!("Closing old MCP connection");
225 }
226 let token = old_client.cancellation_token();
227 token.cancel();
228 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
230 }
231
232 if self.config.log_reconnections {
233 info!("Refreshing MCP connection");
234 }
235 let new_client = self.factory.create_connection().await?;
236 *guard = Some(new_client);
237
238 Ok(())
239 }
240
241 pub async fn list_tools(&self) -> Result<RetryResult<Vec<McpTool>>, String> {
246 self.ensure_connected().await?;
248
249 {
251 let guard = self.client.lock().await;
252 if let Some(ref client) = *guard {
253 match client.list_all_tools().await {
254 Ok(tools) => return Ok(RetryResult::ok(tools)),
255 Err(e) => {
256 let error_str = e.to_string();
257 if !should_refresh_connection(&error_str) {
258 return Err(error_str);
259 }
260 if self.config.log_reconnections {
261 warn!(error = %error_str, "list_tools failed, will retry with reconnection");
262 }
263 }
264 }
265 }
266 }
267
268 for attempt in 1..=self.config.max_attempts {
270 if self.config.log_reconnections {
271 info!(
272 attempt = attempt,
273 max = self.config.max_attempts,
274 "Reconnection attempt for list_tools"
275 );
276 }
277
278 if self.config.retry_delay_ms > 0 {
280 tokio::time::sleep(tokio::time::Duration::from_millis(self.config.retry_delay_ms))
281 .await;
282 }
283
284 if let Err(e) = self.refresh_connection().await {
286 if self.config.log_reconnections {
287 warn!(error = %e, attempt = attempt, "Refresh failed");
288 }
289 continue;
290 }
291
292 let guard = self.client.lock().await;
294 if let Some(ref client) = *guard {
295 match client.list_all_tools().await {
296 Ok(tools) => {
297 if self.config.log_reconnections {
298 debug!(
299 attempt = attempt,
300 tool_count = tools.len(),
301 "list_tools succeeded after reconnection"
302 );
303 }
304 return Ok(RetryResult::reconnected(tools));
305 }
306 Err(e) => {
307 if self.config.log_reconnections {
308 warn!(error = %e, attempt = attempt, "list_tools failed after reconnection");
309 }
310 }
311 }
312 }
313 }
314
315 let guard = self.client.lock().await;
317 if let Some(ref client) = *guard {
318 client.list_all_tools().await.map(RetryResult::ok).map_err(|e| e.to_string())
319 } else {
320 Err("No MCP client available".to_string())
321 }
322 }
323
324 pub async fn call_tool(
326 &self,
327 params: CallToolRequestParams,
328 ) -> Result<RetryResult<CallToolResult>, String> {
329 self.ensure_connected().await?;
331
332 {
334 let guard = self.client.lock().await;
335 if let Some(ref client) = *guard {
336 match client.call_tool(params.clone()).await {
337 Ok(result) => return Ok(RetryResult::ok(result)),
338 Err(e) => {
339 let error_str = e.to_string();
340 if !should_refresh_connection(&error_str) {
341 return Err(error_str);
342 }
343 if self.config.log_reconnections {
344 warn!(error = %error_str, tool = %params.name, "call_tool failed, will retry with reconnection");
345 }
346 }
347 }
348 }
349 }
350
351 for attempt in 1..=self.config.max_attempts {
353 if self.config.log_reconnections {
354 info!(attempt = attempt, max = self.config.max_attempts, tool = %params.name, "Reconnection attempt for call_tool");
355 }
356
357 if self.config.retry_delay_ms > 0 {
359 tokio::time::sleep(tokio::time::Duration::from_millis(self.config.retry_delay_ms))
360 .await;
361 }
362
363 if let Err(e) = self.refresh_connection().await {
365 if self.config.log_reconnections {
366 warn!(error = %e, attempt = attempt, "Refresh failed");
367 }
368 continue;
369 }
370
371 let guard = self.client.lock().await;
373 if let Some(ref client) = *guard {
374 match client.call_tool(params.clone()).await {
375 Ok(result) => {
376 if self.config.log_reconnections {
377 debug!(attempt = attempt, tool = %params.name, "call_tool succeeded after reconnection");
378 }
379 return Ok(RetryResult::reconnected(result));
380 }
381 Err(e) => {
382 if self.config.log_reconnections {
383 warn!(error = %e, attempt = attempt, "call_tool failed after reconnection");
384 }
385 }
386 }
387 }
388 }
389
390 let guard = self.client.lock().await;
392 if let Some(ref client) = *guard {
393 client.call_tool(params).await.map(RetryResult::ok).map_err(|e| e.to_string())
394 } else {
395 Err("No MCP client available".to_string())
396 }
397 }
398
399 pub async fn cancellation_token(
401 &self,
402 ) -> Option<rmcp::service::RunningServiceCancellationToken> {
403 let guard = self.client.lock().await;
404 guard.as_ref().map(|c| c.cancellation_token())
405 }
406
407 pub async fn is_connected(&self) -> bool {
409 let guard = self.client.lock().await;
410 guard.is_some()
411 }
412
413 pub async fn reconnect(&self) -> Result<(), String> {
415 self.refresh_connection().await
416 }
417
418 pub async fn close(&self) {
420 let mut guard = self.client.lock().await;
421 if let Some(client) = guard.take() {
422 let token = client.cancellation_token();
423 token.cancel();
424 }
425 }
426}
427
428pub struct SimpleClient<S>
433where
434 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
435{
436 client: Arc<Mutex<RunningService<RoleClient, S>>>,
437}
438
439impl<S> SimpleClient<S>
440where
441 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
442{
443 pub fn new(client: RunningService<RoleClient, S>) -> Self {
445 Self { client: Arc::new(Mutex::new(client)) }
446 }
447
448 pub async fn list_tools(&self) -> Result<Vec<McpTool>, String> {
450 let client = self.client.lock().await;
451 client.list_all_tools().await.map_err(|e| e.to_string())
452 }
453
454 pub async fn call_tool(&self, params: CallToolRequestParams) -> Result<CallToolResult, String> {
456 let client = self.client.lock().await;
457 client.call_tool(params).await.map_err(|e| e.to_string())
458 }
459
460 pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
462 let client = self.client.lock().await;
463 client.cancellation_token()
464 }
465
466 pub fn inner(&self) -> &Arc<Mutex<RunningService<RoleClient, S>>> {
468 &self.client
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_refresh_connection() {
        // At least one message per recognised marker, in mixed casing, because
        // matching is case-insensitive.
        let reconnectable = [
            "connection closed",
            "ConnectionClosed",
            "EOF",
            "eof error",
            "unexpected EOF while reading frame",
            "write failed: closed pipe",
            "Broken Pipe",
            "session not found",
            "Session missing for id abc",
            "transport error",
            "Connection reset by peer",
        ];
        for msg in reconnectable.iter() {
            assert!(should_refresh_connection(msg), "expected reconnect for {:?}", msg);
        }

        let permanent = ["", "invalid argument", "permission denied", "tool not found"];
        for msg in permanent.iter() {
            assert!(!should_refresh_connection(msg), "expected no reconnect for {:?}", msg);
        }
    }

    #[test]
    fn test_refresh_config_default() {
        let config = RefreshConfig::default();
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.retry_delay_ms, 1000);
        assert!(config.log_reconnections);
    }

    #[test]
    fn test_refresh_config_builder() {
        let config = RefreshConfig::default()
            .with_max_attempts(5)
            .with_retry_delay_ms(500)
            .without_logging();

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.retry_delay_ms, 500);
        assert!(!config.log_reconnections);
    }

    #[test]
    fn test_refresh_config_builders_are_independent() {
        // Each setter must leave the other fields at their defaults.
        let attempts_only = RefreshConfig::default().with_max_attempts(0);
        assert_eq!(attempts_only.max_attempts, 0);
        assert_eq!(attempts_only.retry_delay_ms, 1000);
        assert!(attempts_only.log_reconnections);

        let delay_only = RefreshConfig::default().with_retry_delay_ms(0);
        assert_eq!(delay_only.max_attempts, 3);
        assert_eq!(delay_only.retry_delay_ms, 0);
        assert!(delay_only.log_reconnections);
    }

    #[test]
    fn test_retry_result() {
        let ok_result = RetryResult::ok(42);
        assert_eq!(ok_result.value, 42);
        assert!(!ok_result.reconnected);

        let reconnected_result = RetryResult::reconnected(42);
        assert_eq!(reconnected_result.value, 42);
        assert!(reconnected_result.reconnected);

        // Non-Copy payloads survive cloning with the flag intact.
        let owned = RetryResult::reconnected(String::from("tools"));
        let copy = owned.clone();
        assert_eq!(copy.value, "tools");
        assert!(copy.reconnected);
    }
}