use std::borrow::Cow;
use std::sync::Arc;
use rmcp::ServiceExt;
use rmcp::model::RawContent;
use tokio::sync::RwLock;
use crate::completion::ToolDefinition;
use crate::tool::ToolDyn;
use crate::tool::ToolError;
use crate::tool::server::{ToolServerError, ToolServerHandle};
use crate::wasm_compat::WasmBoxedFuture;
/// A tool advertised by a connected MCP server, adapted to Rig's [`ToolDyn`] interface.
///
/// Pairs the server-provided tool description with the sink through which
/// `call_tool` requests for this tool are sent.
#[derive(Clone)]
pub struct McpTool {
    /// Name, description and input schema as reported by the MCP server.
    definition: rmcp::model::Tool,
    /// Handle to the MCP server; tool invocations are sent through it.
    client: rmcp::service::ServerSink,
}

impl McpTool {
    /// Wraps a tool description received from an MCP server together with the
    /// server sink used to invoke it.
    pub fn from_mcp_server(
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    ) -> Self {
        McpTool { client, definition }
    }
}
impl From<&rmcp::model::Tool> for ToolDefinition {
    /// Converts a borrowed MCP tool description into Rig's [`ToolDefinition`].
    ///
    /// A tool without a description gets an empty string; the parameters are the
    /// tool's input schema rendered as a JSON value.
    fn from(val: &rmcp::model::Tool) -> Self {
        let description: Cow<'_, str> = val.description.clone().unwrap_or_default();
        let parameters = val.schema_as_json_value();
        Self {
            name: val.name.to_string(),
            description: description.into_owned(),
            parameters,
        }
    }
}
impl From<rmcp::model::Tool> for ToolDefinition {
    /// Converts an owned MCP tool description into Rig's [`ToolDefinition`].
    ///
    /// Delegates to the `From<&rmcp::model::Tool>` conversion so the owned and
    /// borrowed forms always produce identical definitions (name, description
    /// defaulting to `""`, and the input schema as JSON).
    fn from(val: rmcp::model::Tool) -> Self {
        Self::from(&val)
    }
}
/// Failure raised while invoking an MCP tool, or reported by the tool itself.
///
/// The wrapped string is the detail message; its `Display` output is
/// `"MCP tool error: <detail>"`.
#[derive(Debug, thiserror::Error)]
#[error("MCP tool error: {0}")]
pub struct McpToolError(String);

impl From<McpToolError> for ToolError {
    /// Surfaces an MCP tool failure to Rig as [`ToolError::ToolCallError`].
    fn from(err: McpToolError) -> Self {
        Self::ToolCallError(Box::new(err))
    }
}
impl ToolDyn for McpTool {
fn name(&self) -> String {
self.definition.name.to_string()
}
fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
Box::pin(async move {
ToolDefinition {
name: self.definition.name.to_string(),
description: self
.definition
.description
.clone()
.unwrap_or(Cow::from(""))
.to_string(),
parameters: serde_json::to_value(&self.definition.input_schema).unwrap_or_default(),
}
})
}
fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
let name = self.definition.name.clone();
let arguments: Option<rmcp::model::JsonObject> =
serde_json::from_str(&args).unwrap_or_default();
Box::pin(async move {
let request = arguments
.map(|arguments| {
rmcp::model::CallToolRequestParams::new(name.clone()).with_arguments(arguments)
})
.unwrap_or_else(|| rmcp::model::CallToolRequestParams::new(name));
let result = self
.client
.call_tool(request)
.await
.map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;
if let Some(true) = result.is_error {
let error_msg = result
.content
.into_iter()
.map(|x| x.raw.as_text().map(|y| y.to_owned()))
.map(|x| x.map(|x| x.clone().text))
.collect::<Option<Vec<String>>>();
let error_message = error_msg.map(|x| x.join("\n"));
if let Some(error_message) = error_message {
return Err(McpToolError(error_message).into());
} else {
return Err(McpToolError("No message returned".to_string()).into());
}
};
let mut content = String::new();
for item in result.content {
let chunk = match item.raw {
rmcp::model::RawContent::Text(raw) => raw.text,
rmcp::model::RawContent::Image(raw) => {
format!("data:{};base64,{}", raw.mime_type, raw.data)
}
rmcp::model::RawContent::Resource(raw) => match raw.resource {
rmcp::model::ResourceContents::TextResourceContents {
uri,
mime_type,
text,
..
} => {
format!(
"{mime_type}{uri}:{text}",
mime_type =
mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
)
}
rmcp::model::ResourceContents::BlobResourceContents {
uri,
mime_type,
blob,
..
} => format!(
"{mime_type}{uri}:{blob}",
mime_type = mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
),
},
RawContent::Audio(_) => {
return Err(McpToolError(
"MCP tool returned audio content, which Rig does not support yet"
.to_string(),
)
.into());
}
thing => {
return Err(McpToolError(format!(
"MCP tool returned unsupported content: {thing:?}"
))
.into());
}
};
content.push_str(&chunk);
}
Ok(content)
})
}
}
#[derive(Debug, thiserror::Error)]
pub enum McpClientError {
#[error("MCP connection error: {0}")]
ConnectionError(String),
#[error("Failed to fetch MCP tool list: {0}")]
ToolFetchError(#[from] rmcp::ServiceError),
#[error("Tool server error: {0}")]
ToolServerError(#[from] ToolServerError),
}
pub struct McpClientHandler {
client_info: rmcp::model::ClientInfo,
tool_server_handle: ToolServerHandle,
managed_tool_names: Arc<RwLock<Vec<String>>>,
}
impl McpClientHandler {
pub fn new(client_info: rmcp::model::ClientInfo, tool_server_handle: ToolServerHandle) -> Self {
Self {
client_info,
tool_server_handle,
managed_tool_names: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn connect<T, E, A>(
self,
transport: T,
) -> Result<rmcp::service::RunningService<rmcp::service::RoleClient, Self>, McpClientError>
where
T: rmcp::transport::IntoTransport<rmcp::service::RoleClient, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
let service = ServiceExt::serve(self, transport)
.await
.map_err(|e| McpClientError::ConnectionError(e.to_string()))?;
let tools = service.peer().list_all_tools().await?;
{
let handler = service.service();
let mut managed = handler.managed_tool_names.write().await;
for tool in tools {
let tool_name = tool.name.to_string();
let mcp_tool = McpTool::from_mcp_server(tool, service.peer().clone());
handler.tool_server_handle.add_tool(mcp_tool).await?;
managed.push(tool_name);
}
}
Ok(service)
}
}
impl rmcp::handler::client::ClientHandler for McpClientHandler {
fn get_info(&self) -> rmcp::model::ClientInfo {
self.client_info.clone()
}
async fn on_tool_list_changed(
&self,
context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
) {
let tools = match context.peer.list_all_tools().await {
Ok(tools) => tools,
Err(e) => {
tracing::error!("Failed to re-fetch MCP tool list: {e}");
return;
}
};
let mut managed = self.managed_tool_names.write().await;
for name in managed.drain(..) {
if let Err(e) = self.tool_server_handle.remove_tool(&name).await {
tracing::warn!("Failed to remove MCP tool '{name}' during refresh: {e}");
}
}
for tool in tools {
let tool_name = tool.name.to_string();
let mcp_tool = McpTool::from_mcp_server(tool, context.peer.clone());
match self.tool_server_handle.add_tool(mcp_tool).await {
Ok(()) => {
managed.push(tool_name);
}
Err(e) => {
tracing::error!("Failed to register MCP tool '{tool_name}': {e}");
}
}
}
tracing::info!(
tool_count = managed.len(),
"MCP tool list refreshed successfully"
);
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use rmcp::handler::client::ClientHandler;
    use rmcp::model::*;
    use rmcp::service::RequestContext;
    use rmcp::{RoleServer, ServerHandler, ServiceExt};
    use tokio::sync::RwLock;
    use super::McpClientHandler;
    use crate::tool::server::ToolServer;

    /// Test MCP server whose advertised tool list can be swapped at runtime.
    #[derive(Clone)]
    struct DynamicToolServer {
        tools: Arc<RwLock<Vec<Tool>>>,
    }

    impl DynamicToolServer {
        fn new(tools: Vec<Tool>) -> Self {
            Self {
                tools: Arc::new(RwLock::new(tools)),
            }
        }

        /// Replaces the tool list returned by subsequent `list_tools` requests.
        async fn set_tools(&self, tools: Vec<Tool>) {
            *self.tools.write().await = tools;
        }
    }

    impl ServerHandler for DynamicToolServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_protocol_version(ProtocolVersion::LATEST)
                .with_server_info(Implementation::new("test-dynamic-server", "0.1.0"))
        }

        async fn list_tools(
            &self,
            _request: Option<PaginatedRequestParams>,
            _context: RequestContext<RoleServer>,
        ) -> Result<ListToolsResult, ErrorData> {
            let tools = self.tools.read().await.clone();
            Ok(ListToolsResult::with_all_items(tools))
        }

        async fn call_tool(
            &self,
            request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            Ok(CallToolResult::success(vec![Content::text(format!(
                "called {}",
                request.name
            ))]))
        }
    }

    /// Builds a tool with an empty input schema.
    fn make_tool(name: &str, description: &str) -> Tool {
        Tool::new(
            name.to_string(),
            description.to_string(),
            Arc::new(serde_json::Map::new()),
        )
    }

    #[tokio::test]
    async fn test_mcp_client_handler_initial_tool_registration() {
        let initial_tools = vec![
            make_tool("tool_a", "First tool"),
            make_tool("tool_b", "Second tool"),
        ];
        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();
        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);
        let server_clone = server.clone();
        tokio::spawn(async move {
            let _service = server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            _service.waiting().await.expect("server error");
        });
        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());
        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");
        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 2);
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }

    #[tokio::test]
    async fn test_mcp_client_handler_refreshes_on_tool_list_changed() {
        let initial_tools = vec![make_tool("alpha", "Alpha tool")];
        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();
        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);
        let server_clone = server.clone();
        let server_service_handle = tokio::spawn(async move {
            server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start")
        });
        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());
        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");
        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "alpha");
        server
            .set_tools(vec![
                make_tool("beta", "Beta tool"),
                make_tool("gamma", "Gamma tool"),
            ])
            .await;
        let server_service = server_service_handle.await.unwrap();
        server_service
            .peer()
            .notify_tool_list_changed()
            .await
            .expect("failed to send notification");
        // The refresh runs asynchronously on the client side after the notification
        // arrives. Poll with a bounded timeout instead of sleeping a fixed interval,
        // which is flaky on slow machines.
        let defs = tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
                let refreshed = ["beta", "gamma"]
                    .iter()
                    .all(|want| defs.iter().any(|d| d.name == *want));
                if refreshed {
                    break defs;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("timed out waiting for the MCP tool list refresh");
        assert_eq!(defs.len(), 2);
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"beta"), "expected 'beta' in {names:?}");
        assert!(names.contains(&"gamma"), "expected 'gamma' in {names:?}");
        assert!(
            !names.contains(&"alpha"),
            "expected 'alpha' to be removed, found {names:?}"
        );
    }

    #[tokio::test]
    async fn test_mcp_client_handler_get_info_delegates() {
        let client_info = ClientInfo::new(
            ClientCapabilities::default(),
            Implementation::new("test-client", "1.0.0"),
        );
        let tool_server_handle = ToolServer::new().run();
        let handler = McpClientHandler::new(client_info.clone(), tool_server_handle);
        let returned = handler.get_info();
        assert_eq!(returned.client_info.name, "test-client");
        assert_eq!(returned.client_info.version, "1.0.0");
    }
}