1use std::borrow::Cow;
33use std::sync::Arc;
34
35use rmcp::ServiceExt;
36use rmcp::model::RawContent;
37use tokio::sync::RwLock;
38
39use crate::completion::ToolDefinition;
40use crate::tool::ToolDyn;
41use crate::tool::ToolError;
42use crate::tool::server::{ToolServerError, ToolServerHandle};
43use crate::wasm_compat::WasmBoxedFuture;
44
45#[derive(Clone)]
50pub struct McpTool {
51 definition: rmcp::model::Tool,
52 client: rmcp::service::ServerSink,
53}
54
55impl McpTool {
56 pub fn from_mcp_server(
58 definition: rmcp::model::Tool,
59 client: rmcp::service::ServerSink,
60 ) -> Self {
61 Self { definition, client }
62 }
63}
64
65impl From<&rmcp::model::Tool> for ToolDefinition {
66 fn from(val: &rmcp::model::Tool) -> Self {
67 Self {
68 name: val.name.to_string(),
69 description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
70 parameters: val.schema_as_json_value(),
71 }
72 }
73}
74
75impl From<rmcp::model::Tool> for ToolDefinition {
76 fn from(val: rmcp::model::Tool) -> Self {
77 Self {
78 name: val.name.to_string(),
79 description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
80 parameters: val.schema_as_json_value(),
81 }
82 }
83}
84
85#[derive(Debug, thiserror::Error)]
86#[error("MCP tool error: {0}")]
87pub struct McpToolError(String);
88
89impl From<McpToolError> for ToolError {
90 fn from(e: McpToolError) -> Self {
91 ToolError::ToolCallError(Box::new(e))
92 }
93}
94
95impl ToolDyn for McpTool {
96 fn name(&self) -> String {
97 self.definition.name.to_string()
98 }
99
100 fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
101 Box::pin(async move {
102 ToolDefinition {
103 name: self.definition.name.to_string(),
104 description: self
105 .definition
106 .description
107 .clone()
108 .unwrap_or(Cow::from(""))
109 .to_string(),
110 parameters: serde_json::to_value(&self.definition.input_schema).unwrap_or_default(),
111 }
112 })
113 }
114
115 fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
116 let name = self.definition.name.clone();
117 let arguments: Option<rmcp::model::JsonObject> =
118 serde_json::from_str(&args).unwrap_or_default();
119
120 Box::pin(async move {
121 let request = arguments
122 .map(|arguments| {
123 rmcp::model::CallToolRequestParams::new(name.clone()).with_arguments(arguments)
124 })
125 .unwrap_or_else(|| rmcp::model::CallToolRequestParams::new(name));
126
127 let result = self
128 .client
129 .call_tool(request)
130 .await
131 .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;
132
133 if let Some(true) = result.is_error {
134 let error_msg = result
135 .content
136 .into_iter()
137 .map(|x| x.raw.as_text().map(|y| y.to_owned()))
138 .map(|x| x.map(|x| x.clone().text))
139 .collect::<Option<Vec<String>>>();
140
141 let error_message = error_msg.map(|x| x.join("\n"));
142 if let Some(error_message) = error_message {
143 return Err(McpToolError(error_message).into());
144 } else {
145 return Err(McpToolError("No message returned".to_string()).into());
146 }
147 };
148
149 let mut content = String::new();
150
151 for item in result.content {
152 let chunk = match item.raw {
153 rmcp::model::RawContent::Text(raw) => raw.text,
154 rmcp::model::RawContent::Image(raw) => {
155 format!("data:{};base64,{}", raw.mime_type, raw.data)
156 }
157 rmcp::model::RawContent::Resource(raw) => match raw.resource {
158 rmcp::model::ResourceContents::TextResourceContents {
159 uri,
160 mime_type,
161 text,
162 ..
163 } => {
164 format!(
165 "{mime_type}{uri}:{text}",
166 mime_type =
167 mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
168 )
169 }
170 rmcp::model::ResourceContents::BlobResourceContents {
171 uri,
172 mime_type,
173 blob,
174 ..
175 } => format!(
176 "{mime_type}{uri}:{blob}",
177 mime_type = mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
178 ),
179 },
180 RawContent::Audio(_) => {
181 return Err(McpToolError(
182 "MCP tool returned audio content, which Rig does not support yet"
183 .to_string(),
184 )
185 .into());
186 }
187 thing => {
188 return Err(McpToolError(format!(
189 "MCP tool returned unsupported content: {thing:?}"
190 ))
191 .into());
192 }
193 };
194
195 content.push_str(&chunk);
196 }
197
198 Ok(content)
199 })
200 }
201}
202
203#[derive(Debug, thiserror::Error)]
205pub enum McpClientError {
206 #[error("MCP connection error: {0}")]
208 ConnectionError(String),
209
210 #[error("Failed to fetch MCP tool list: {0}")]
212 ToolFetchError(#[from] rmcp::ServiceError),
213
214 #[error("Tool server error: {0}")]
216 ToolServerError(#[from] ToolServerError),
217}
218
219pub struct McpClientHandler {
243 client_info: rmcp::model::ClientInfo,
244 tool_server_handle: ToolServerHandle,
245 managed_tool_names: Arc<RwLock<Vec<String>>>,
248}
249
250impl McpClientHandler {
251 pub fn new(client_info: rmcp::model::ClientInfo, tool_server_handle: ToolServerHandle) -> Self {
256 Self {
257 client_info,
258 tool_server_handle,
259 managed_tool_names: Arc::new(RwLock::new(Vec::new())),
260 }
261 }
262
263 pub async fn connect<T, E, A>(
276 self,
277 transport: T,
278 ) -> Result<rmcp::service::RunningService<rmcp::service::RoleClient, Self>, McpClientError>
279 where
280 T: rmcp::transport::IntoTransport<rmcp::service::RoleClient, E, A>,
281 E: std::error::Error + Send + Sync + 'static,
282 {
283 let service = ServiceExt::serve(self, transport)
284 .await
285 .map_err(|e| McpClientError::ConnectionError(e.to_string()))?;
286
287 let tools = service.peer().list_all_tools().await?;
288
289 {
290 let handler = service.service();
291 let mut managed = handler.managed_tool_names.write().await;
292
293 for tool in tools {
294 let tool_name = tool.name.to_string();
295 let mcp_tool = McpTool::from_mcp_server(tool, service.peer().clone());
296 handler.tool_server_handle.add_tool(mcp_tool).await?;
297 managed.push(tool_name);
298 }
299 }
300
301 Ok(service)
302 }
303}
304
305impl rmcp::handler::client::ClientHandler for McpClientHandler {
306 fn get_info(&self) -> rmcp::model::ClientInfo {
307 self.client_info.clone()
308 }
309
310 async fn on_tool_list_changed(
311 &self,
312 context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
313 ) {
314 let tools = match context.peer.list_all_tools().await {
315 Ok(tools) => tools,
316 Err(e) => {
317 tracing::error!("Failed to re-fetch MCP tool list: {e}");
318 return;
319 }
320 };
321
322 let mut managed = self.managed_tool_names.write().await;
323
324 for name in managed.drain(..) {
325 if let Err(e) = self.tool_server_handle.remove_tool(&name).await {
326 tracing::warn!("Failed to remove MCP tool '{name}' during refresh: {e}");
327 }
328 }
329
330 for tool in tools {
331 let tool_name = tool.name.to_string();
332 let mcp_tool = McpTool::from_mcp_server(tool, context.peer.clone());
333 match self.tool_server_handle.add_tool(mcp_tool).await {
334 Ok(()) => {
335 managed.push(tool_name);
336 }
337 Err(e) => {
338 tracing::error!("Failed to register MCP tool '{tool_name}': {e}");
339 }
340 }
341 }
342
343 tracing::info!(
344 tool_count = managed.len(),
345 "MCP tool list refreshed successfully"
346 );
347 }
348}
349
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use rmcp::handler::client::ClientHandler;
    use rmcp::model::*;
    use rmcp::service::RequestContext;
    use rmcp::{RoleServer, ServerHandler, ServiceExt};
    use tokio::sync::RwLock;

    use super::McpClientHandler;
    use crate::tool::server::{ToolServer, ToolServerHandle};

    /// Upper bound on how long a test waits for an asynchronous tool-list refresh to land.
    const REFRESH_TIMEOUT: Duration = Duration::from_secs(5);

    /// Minimal MCP server whose advertised tool list can be swapped at runtime.
    #[derive(Clone)]
    struct DynamicToolServer {
        tools: Arc<RwLock<Vec<Tool>>>,
    }

    impl DynamicToolServer {
        fn new(tools: Vec<Tool>) -> Self {
            Self {
                tools: Arc::new(RwLock::new(tools)),
            }
        }

        async fn set_tools(&self, tools: Vec<Tool>) {
            *self.tools.write().await = tools;
        }
    }

    impl ServerHandler for DynamicToolServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_protocol_version(ProtocolVersion::LATEST)
                .with_server_info(Implementation::new("test-dynamic-server", "0.1.0"))
        }

        async fn list_tools(
            &self,
            _request: Option<PaginatedRequestParams>,
            _context: RequestContext<RoleServer>,
        ) -> Result<ListToolsResult, ErrorData> {
            let tools = self.tools.read().await.clone();
            Ok(ListToolsResult::with_all_items(tools))
        }

        async fn call_tool(
            &self,
            request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            Ok(CallToolResult::success(vec![Content::text(format!(
                "called {}",
                request.name
            ))]))
        }
    }

    fn make_tool(name: &str, description: &str) -> Tool {
        Tool::new(
            name.to_string(),
            description.to_string(),
            Arc::new(serde_json::Map::new()),
        )
    }

    /// Polls `handle` until its registered tool names, sorted, equal `expected` (sorted).
    ///
    /// Used instead of a fixed sleep. The refresh runs inside the client's notification
    /// handler, so the test cannot know exactly when it completes.
    ///
    /// # Panics
    ///
    /// Panics if the expected set is not observed within [`REFRESH_TIMEOUT`].
    async fn wait_for_tool_names(handle: &ToolServerHandle, expected: &[&str]) {
        let mut expected: Vec<String> = expected.iter().map(|name| name.to_string()).collect();
        expected.sort();

        let poll = async {
            loop {
                let mut names: Vec<String> = handle
                    .get_tool_defs(None)
                    .await
                    .unwrap()
                    .into_iter()
                    .map(|def| def.name)
                    .collect();
                names.sort();
                if names == expected {
                    return;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        };

        let outcome = tokio::time::timeout(REFRESH_TIMEOUT, poll).await;
        assert!(
            outcome.is_ok(),
            "timed out waiting for registered tools to become {expected:?}"
        );
    }

    #[tokio::test]
    async fn test_mcp_client_handler_initial_tool_registration() {
        let initial_tools = vec![
            make_tool("tool_a", "First tool"),
            make_tool("tool_b", "Second tool"),
        ];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        tokio::spawn(async move {
            let service = server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            service.waiting().await.expect("server error");
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }

    #[tokio::test]
    async fn test_mcp_client_handler_refreshes_on_tool_list_changed() {
        let initial_tools = vec![make_tool("alpha", "Alpha tool")];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        let server_service_handle = tokio::spawn(async move {
            server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start")
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "alpha");

        server
            .set_tools(vec![
                make_tool("beta", "Beta tool"),
                make_tool("gamma", "Gamma tool"),
            ])
            .await;

        let server_service = server_service_handle.await.unwrap();
        server_service
            .peer()
            .notify_tool_list_changed()
            .await
            .expect("failed to send notification");

        // The refresh happens asynchronously on the client side; wait for it to settle.
        wait_for_tool_names(&tool_server_handle, &["beta", "gamma"]).await;

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"beta"), "expected 'beta' in {names:?}");
        assert!(names.contains(&"gamma"), "expected 'gamma' in {names:?}");
        assert!(
            !names.contains(&"alpha"),
            "expected 'alpha' to be removed, found {names:?}"
        );
    }

    #[tokio::test]
    async fn test_mcp_client_handler_get_info_delegates() {
        let client_info = ClientInfo::new(
            ClientCapabilities::default(),
            Implementation::new("test-client", "1.0.0"),
        );

        let tool_server_handle = ToolServer::new().run();
        let handler = McpClientHandler::new(client_info.clone(), tool_server_handle);

        let returned = handler.get_info();
        assert_eq!(returned.client_info.name, "test-client");
        assert_eq!(returned.client_info.version, "1.0.0");
    }
}