use rmcp::{
RoleClient,
model::{CallToolRequestParams, CallToolResult, Tool as McpTool},
service::RunningService,
};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info, warn};
/// Decides whether an error message indicates a lost or stale connection,
/// i.e. a failure that reconnecting might fix.
///
/// The check is a case-insensitive substring search against a fixed list of
/// markers (closed/reset connections, EOF, pipe failures, missing sessions,
/// transport errors). Any message containing none of them yields `false`.
pub fn should_refresh_connection(error: &str) -> bool {
    // Every marker is lowercase because the message is lowercased before matching.
    const CONNECTION_LOSS_MARKERS: [&str; 9] = [
        "connection closed",
        "connectionclosed",
        "eof",
        "closed pipe",
        "broken pipe",
        "session not found",
        "session missing",
        "transport error",
        "connection reset",
    ];

    let normalized = error.to_lowercase();
    CONNECTION_LOSS_MARKERS
        .iter()
        .any(|marker| normalized.contains(*marker))
}
/// A successful result paired with a flag telling whether a reconnection was
/// needed to obtain it.
#[derive(Debug, Clone)]
pub struct RetryResult<T> {
    /// The value produced by the operation.
    pub value: T,
    /// `true` when the value was obtained after the connection was refreshed.
    pub reconnected: bool,
}

impl<T> RetryResult<T> {
    /// Wraps `value` as a result obtained without reconnecting.
    pub fn ok(value: T) -> Self {
        Self::tagged(value, false)
    }

    /// Wraps `value` as a result obtained after a reconnection.
    pub fn reconnected(value: T) -> Self {
        Self::tagged(value, true)
    }

    /// Shared constructor behind the two public ones.
    fn tagged(value: T, reconnected: bool) -> Self {
        RetryResult { reconnected, value }
    }
}
/// Settings that control how [`ConnectionRefresher`] reconnects and retries.
#[derive(Debug, Clone)]
pub struct RefreshConfig {
    /// Number of reconnect-and-retry rounds performed after a request fails
    /// with a connection-level error.
    pub max_attempts: u32,
    /// Pause, in milliseconds, before each reconnect round; `0` skips the pause.
    pub retry_delay_ms: u64,
    /// Whether reconnection activity is reported through `tracing`.
    pub log_reconnections: bool,
}

impl Default for RefreshConfig {
    /// Three attempts, one second apart, with logging enabled.
    fn default() -> Self {
        RefreshConfig {
            log_reconnections: true,
            retry_delay_ms: 1000,
            max_attempts: 3,
        }
    }
}

impl RefreshConfig {
    /// Returns the configuration with `max_attempts` replaced by `attempts`.
    pub fn with_max_attempts(self, attempts: u32) -> Self {
        Self { max_attempts: attempts, ..self }
    }

    /// Returns the configuration with `retry_delay_ms` replaced by `delay_ms`.
    pub fn with_retry_delay_ms(self, delay_ms: u64) -> Self {
        Self { retry_delay_ms: delay_ms, ..self }
    }

    /// Returns the configuration with reconnection logging turned off.
    pub fn without_logging(self) -> Self {
        Self { log_reconnections: false, ..self }
    }
}
/// Source of fresh MCP client connections.
///
/// [`ConnectionRefresher`] calls this when it holds no client yet and when it
/// replaces a client after a connection-level failure.
#[async_trait::async_trait]
pub trait ConnectionFactory<S>: Send + Sync
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
/// Creates a new connection and returns the running client service.
///
/// # Errors
///
/// On failure the implementation returns an error message as a `String`.
async fn create_connection(&self) -> Result<RunningService<RoleClient, S>, String>;
}
/// MCP client wrapper that replaces its connection and retries when a request
/// fails with an error that `should_refresh_connection` classifies as a
/// connection problem.
///
/// `S` is the client-side service type and `F` the factory used to build
/// replacement connections.
pub struct ConnectionRefresher<S, F>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
F: ConnectionFactory<S>,
{
// Current client; `None` when no connection is held (lazy start, after
// `close`, or after a failed refresh).
client: Arc<Mutex<Option<RunningService<RoleClient, S>>>>,
// Builds new connections on demand.
factory: Arc<F>,
// Retry count, retry delay and logging switch.
config: RefreshConfig,
}
impl<S, F> ConnectionRefresher<S, F>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
F: ConnectionFactory<S>,
{
pub fn new(client: RunningService<RoleClient, S>, factory: Arc<F>) -> Self {
Self {
client: Arc::new(Mutex::new(Some(client))),
factory,
config: RefreshConfig::default(),
}
}
pub fn lazy(factory: Arc<F>) -> Self {
Self { client: Arc::new(Mutex::new(None)), factory, config: RefreshConfig::default() }
}
pub fn with_config(mut self, config: RefreshConfig) -> Self {
self.config = config;
self
}
pub fn with_max_attempts(mut self, attempts: u32) -> Self {
self.config.max_attempts = attempts;
self
}
async fn ensure_connected(&self) -> Result<(), String> {
let mut guard = self.client.lock().await;
if guard.is_none() {
if self.config.log_reconnections {
info!("MCP client not connected, creating connection");
}
let new_client = self.factory.create_connection().await?;
*guard = Some(new_client);
}
Ok(())
}
async fn refresh_connection(&self) -> Result<(), String> {
let mut guard = self.client.lock().await;
if let Some(old_client) = guard.take() {
if self.config.log_reconnections {
debug!("Closing old MCP connection");
}
let token = old_client.cancellation_token();
token.cancel();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
if self.config.log_reconnections {
info!("Refreshing MCP connection");
}
let new_client = self.factory.create_connection().await?;
*guard = Some(new_client);
Ok(())
}
pub async fn list_tools(&self) -> Result<RetryResult<Vec<McpTool>>, String> {
self.ensure_connected().await?;
{
let guard = self.client.lock().await;
if let Some(ref client) = *guard {
match client.list_all_tools().await {
Ok(tools) => return Ok(RetryResult::ok(tools)),
Err(e) => {
let error_str = e.to_string();
if !should_refresh_connection(&error_str) {
return Err(error_str);
}
if self.config.log_reconnections {
warn!(error = %error_str, "list_tools failed, will retry with reconnection");
}
}
}
}
}
for attempt in 1..=self.config.max_attempts {
if self.config.log_reconnections {
info!(
attempt = attempt,
max = self.config.max_attempts,
"Reconnection attempt for list_tools"
);
}
if self.config.retry_delay_ms > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(self.config.retry_delay_ms))
.await;
}
if let Err(e) = self.refresh_connection().await {
if self.config.log_reconnections {
warn!(error = %e, attempt = attempt, "Refresh failed");
}
continue;
}
let guard = self.client.lock().await;
if let Some(ref client) = *guard {
match client.list_all_tools().await {
Ok(tools) => {
if self.config.log_reconnections {
debug!(
attempt = attempt,
tool_count = tools.len(),
"list_tools succeeded after reconnection"
);
}
return Ok(RetryResult::reconnected(tools));
}
Err(e) => {
if self.config.log_reconnections {
warn!(error = %e, attempt = attempt, "list_tools failed after reconnection");
}
}
}
}
}
let guard = self.client.lock().await;
if let Some(ref client) = *guard {
client.list_all_tools().await.map(RetryResult::ok).map_err(|e| e.to_string())
} else {
Err("No MCP client available".to_string())
}
}
pub async fn call_tool(
&self,
params: CallToolRequestParams,
) -> Result<RetryResult<CallToolResult>, String> {
self.ensure_connected().await?;
{
let guard = self.client.lock().await;
if let Some(ref client) = *guard {
match client.call_tool(params.clone()).await {
Ok(result) => return Ok(RetryResult::ok(result)),
Err(e) => {
let error_str = e.to_string();
if !should_refresh_connection(&error_str) {
return Err(error_str);
}
if self.config.log_reconnections {
warn!(error = %error_str, tool = %params.name, "call_tool failed, will retry with reconnection");
}
}
}
}
}
for attempt in 1..=self.config.max_attempts {
if self.config.log_reconnections {
info!(attempt = attempt, max = self.config.max_attempts, tool = %params.name, "Reconnection attempt for call_tool");
}
if self.config.retry_delay_ms > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(self.config.retry_delay_ms))
.await;
}
if let Err(e) = self.refresh_connection().await {
if self.config.log_reconnections {
warn!(error = %e, attempt = attempt, "Refresh failed");
}
continue;
}
let guard = self.client.lock().await;
if let Some(ref client) = *guard {
match client.call_tool(params.clone()).await {
Ok(result) => {
if self.config.log_reconnections {
debug!(attempt = attempt, tool = %params.name, "call_tool succeeded after reconnection");
}
return Ok(RetryResult::reconnected(result));
}
Err(e) => {
if self.config.log_reconnections {
warn!(error = %e, attempt = attempt, "call_tool failed after reconnection");
}
}
}
}
}
let guard = self.client.lock().await;
if let Some(ref client) = *guard {
client.call_tool(params).await.map(RetryResult::ok).map_err(|e| e.to_string())
} else {
Err("No MCP client available".to_string())
}
}
pub async fn cancellation_token(
&self,
) -> Option<rmcp::service::RunningServiceCancellationToken> {
let guard = self.client.lock().await;
guard.as_ref().map(|c| c.cancellation_token())
}
pub async fn is_connected(&self) -> bool {
let guard = self.client.lock().await;
guard.is_some()
}
pub async fn reconnect(&self) -> Result<(), String> {
self.refresh_connection().await
}
pub async fn close(&self) {
let mut guard = self.client.lock().await;
if let Some(client) = guard.take() {
let token = client.cancellation_token();
token.cancel();
}
}
}
/// Minimal MCP client wrapper with no reconnection logic: every call locks the
/// shared client and forwards the request.
pub struct SimpleClient<S>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// The running client, shared behind an async mutex.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
}

impl<S> SimpleClient<S>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Takes ownership of a running client and wraps it for shared use.
    pub fn new(client: RunningService<RoleClient, S>) -> Self {
        let shared = Arc::new(Mutex::new(client));
        Self { client: shared }
    }

    /// Lists all tools; a failure is returned as its display string.
    pub async fn list_tools(&self) -> Result<Vec<McpTool>, String> {
        let guard = self.client.lock().await;
        match guard.list_all_tools().await {
            Ok(tools) => Ok(tools),
            Err(err) => Err(err.to_string()),
        }
    }

    /// Calls a tool with `params`; a failure is returned as its display string.
    pub async fn call_tool(&self, params: CallToolRequestParams) -> Result<CallToolResult, String> {
        let guard = self.client.lock().await;
        match guard.call_tool(params).await {
            Ok(outcome) => Ok(outcome),
            Err(err) => Err(err.to_string()),
        }
    }

    /// Returns the cancellation token of the wrapped client.
    pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
        let guard = self.client.lock().await;
        let token = guard.cancellation_token();
        token
    }

    /// Gives direct access to the shared, mutex-protected client.
    pub fn inner(&self) -> &Arc<Mutex<RunningService<RoleClient, S>>> {
        &self.client
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection-loss messages trigger a refresh; ordinary failures do not.
    #[test]
    fn test_should_refresh_connection() {
        let refreshable = [
            "connection closed",
            "ConnectionClosed",
            "EOF",
            "eof error",
            "broken pipe",
            "session not found",
            "transport error",
            "connection reset",
        ];
        for message in refreshable.iter() {
            assert!(should_refresh_connection(message));
        }

        let permanent = ["invalid argument", "permission denied", "tool not found"];
        for message in permanent.iter() {
            assert!(!should_refresh_connection(message));
        }
    }

    /// The default configuration is three attempts, 1000 ms apart, with logging.
    #[test]
    fn test_refresh_config_default() {
        let RefreshConfig { max_attempts, retry_delay_ms, log_reconnections } =
            RefreshConfig::default();
        assert_eq!(max_attempts, 3);
        assert_eq!(retry_delay_ms, 1000);
        assert!(log_reconnections);
    }

    /// Each builder method overrides exactly its own field.
    #[test]
    fn test_refresh_config_builder() {
        let RefreshConfig { max_attempts, retry_delay_ms, log_reconnections } =
            RefreshConfig::default()
                .with_max_attempts(5)
                .with_retry_delay_ms(500)
                .without_logging();
        assert_eq!(max_attempts, 5);
        assert_eq!(retry_delay_ms, 500);
        assert!(!log_reconnections);
    }

    /// `ok` and `reconnected` keep the value and differ only in the flag.
    #[test]
    fn test_retry_result() {
        let RetryResult { value, reconnected } = RetryResult::ok(42);
        assert_eq!(value, 42);
        assert!(!reconnected);

        let RetryResult { value, reconnected } = RetryResult::reconnected(42);
        assert_eq!(value, 42);
        assert!(reconnected);
    }
}