1use std::borrow::Cow;
33use std::sync::Arc;
34
35use rmcp::ServiceExt;
36use rmcp::model::RawContent;
37use tokio::sync::RwLock;
38
39use crate::completion::ToolDefinition;
40use crate::tool::ToolDyn;
41use crate::tool::ToolError;
42use crate::tool::server::{ToolServerError, ToolServerHandle};
43use crate::wasm_compat::WasmBoxedFuture;
44
45#[derive(Clone)]
50pub struct McpTool {
51 definition: rmcp::model::Tool,
52 client: rmcp::service::ServerSink,
53}
54
55impl McpTool {
56 pub fn from_mcp_server(
58 definition: rmcp::model::Tool,
59 client: rmcp::service::ServerSink,
60 ) -> Self {
61 Self { definition, client }
62 }
63}
64
65impl From<&rmcp::model::Tool> for ToolDefinition {
66 fn from(val: &rmcp::model::Tool) -> Self {
67 Self {
68 name: val.name.to_string(),
69 description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
70 parameters: val.schema_as_json_value(),
71 }
72 }
73}
74
75impl From<rmcp::model::Tool> for ToolDefinition {
76 fn from(val: rmcp::model::Tool) -> Self {
77 Self {
78 name: val.name.to_string(),
79 description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
80 parameters: val.schema_as_json_value(),
81 }
82 }
83}
84
85#[derive(Debug, thiserror::Error)]
86#[error("MCP tool error: {0}")]
87pub struct McpToolError(String);
88
89impl From<McpToolError> for ToolError {
90 fn from(e: McpToolError) -> Self {
91 ToolError::ToolCallError(Box::new(e))
92 }
93}
94
95impl ToolDyn for McpTool {
96 fn name(&self) -> String {
97 self.definition.name.to_string()
98 }
99
100 fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
101 Box::pin(async move {
102 ToolDefinition {
103 name: self.definition.name.to_string(),
104 description: self
105 .definition
106 .description
107 .clone()
108 .unwrap_or(Cow::from(""))
109 .to_string(),
110 parameters: serde_json::to_value(&self.definition.input_schema).unwrap_or_default(),
111 }
112 })
113 }
114
115 fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
116 let name = self.definition.name.clone();
117 let arguments = serde_json::from_str(&args).unwrap_or_default();
118
119 Box::pin(async move {
120 let result = self
121 .client
122 .call_tool(rmcp::model::CallToolRequestParams {
123 name,
124 arguments,
125 meta: None,
126 task: None,
127 })
128 .await
129 .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;
130
131 if let Some(true) = result.is_error {
132 let error_msg = result
133 .content
134 .into_iter()
135 .map(|x| x.raw.as_text().map(|y| y.to_owned()))
136 .map(|x| x.map(|x| x.clone().text))
137 .collect::<Option<Vec<String>>>();
138
139 let error_message = error_msg.map(|x| x.join("\n"));
140 if let Some(error_message) = error_message {
141 return Err(McpToolError(error_message).into());
142 } else {
143 return Err(McpToolError("No message returned".to_string()).into());
144 }
145 };
146
147 Ok(result
148 .content
149 .into_iter()
150 .map(|c| match c.raw {
151 rmcp::model::RawContent::Text(raw) => raw.text,
152 rmcp::model::RawContent::Image(raw) => {
153 format!("data:{};base64,{}", raw.mime_type, raw.data)
154 }
155 rmcp::model::RawContent::Resource(raw) => match raw.resource {
156 rmcp::model::ResourceContents::TextResourceContents {
157 uri,
158 mime_type,
159 text,
160 ..
161 } => {
162 format!(
163 "{mime_type}{uri}:{text}",
164 mime_type = mime_type
165 .map(|m| format!("data:{m};"))
166 .unwrap_or_default(),
167 )
168 }
169 rmcp::model::ResourceContents::BlobResourceContents {
170 uri,
171 mime_type,
172 blob,
173 ..
174 } => format!(
175 "{mime_type}{uri}:{blob}",
176 mime_type = mime_type
177 .map(|m| format!("data:{m};"))
178 .unwrap_or_default(),
179 ),
180 },
181 RawContent::Audio(_) => {
182 panic!("Support for audio results from an MCP tool is currently unimplemented. Come back later!")
183 }
184 thing => {
185 panic!("Unsupported type found: {thing:?}")
186 }
187 })
188 .collect::<String>())
189 })
190 }
191}
192
193#[derive(Debug, thiserror::Error)]
195pub enum McpClientError {
196 #[error("MCP connection error: {0}")]
198 ConnectionError(String),
199
200 #[error("Failed to fetch MCP tool list: {0}")]
202 ToolFetchError(#[from] rmcp::ServiceError),
203
204 #[error("Tool server error: {0}")]
206 ToolServerError(#[from] ToolServerError),
207}
208
209pub struct McpClientHandler {
233 client_info: rmcp::model::ClientInfo,
234 tool_server_handle: ToolServerHandle,
235 managed_tool_names: Arc<RwLock<Vec<String>>>,
238}
239
240impl McpClientHandler {
241 pub fn new(client_info: rmcp::model::ClientInfo, tool_server_handle: ToolServerHandle) -> Self {
246 Self {
247 client_info,
248 tool_server_handle,
249 managed_tool_names: Arc::new(RwLock::new(Vec::new())),
250 }
251 }
252
253 pub async fn connect<T, E, A>(
266 self,
267 transport: T,
268 ) -> Result<rmcp::service::RunningService<rmcp::service::RoleClient, Self>, McpClientError>
269 where
270 T: rmcp::transport::IntoTransport<rmcp::service::RoleClient, E, A>,
271 E: std::error::Error + Send + Sync + 'static,
272 {
273 let service = ServiceExt::serve(self, transport)
274 .await
275 .map_err(|e| McpClientError::ConnectionError(e.to_string()))?;
276
277 let tools = service.peer().list_all_tools().await?;
278
279 {
280 let handler = service.service();
281 let mut managed = handler.managed_tool_names.write().await;
282
283 for tool in tools {
284 let tool_name = tool.name.to_string();
285 let mcp_tool = McpTool::from_mcp_server(tool, service.peer().clone());
286 handler.tool_server_handle.add_tool(mcp_tool).await?;
287 managed.push(tool_name);
288 }
289 }
290
291 Ok(service)
292 }
293}
294
295impl rmcp::handler::client::ClientHandler for McpClientHandler {
296 fn get_info(&self) -> rmcp::model::ClientInfo {
297 self.client_info.clone()
298 }
299
300 async fn on_tool_list_changed(
301 &self,
302 context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
303 ) {
304 let tools = match context.peer.list_all_tools().await {
305 Ok(tools) => tools,
306 Err(e) => {
307 tracing::error!("Failed to re-fetch MCP tool list: {e}");
308 return;
309 }
310 };
311
312 let mut managed = self.managed_tool_names.write().await;
313
314 for name in managed.drain(..) {
315 if let Err(e) = self.tool_server_handle.remove_tool(&name).await {
316 tracing::warn!("Failed to remove MCP tool '{name}' during refresh: {e}");
317 }
318 }
319
320 for tool in tools {
321 let tool_name = tool.name.to_string();
322 let mcp_tool = McpTool::from_mcp_server(tool, context.peer.clone());
323 match self.tool_server_handle.add_tool(mcp_tool).await {
324 Ok(()) => {
325 managed.push(tool_name);
326 }
327 Err(e) => {
328 tracing::error!("Failed to register MCP tool '{tool_name}': {e}");
329 }
330 }
331 }
332
333 tracing::info!(
334 tool_count = managed.len(),
335 "MCP tool list refreshed successfully"
336 );
337 }
338}
339
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use rmcp::handler::client::ClientHandler;
    use rmcp::model::*;
    use rmcp::service::RequestContext;
    use rmcp::{RoleServer, ServerHandler, ServiceExt};
    use tokio::sync::RwLock;

    use super::McpClientHandler;
    use crate::tool::server::ToolServer;

    /// Minimal MCP server whose advertised tool list can be replaced at runtime.
    #[derive(Clone)]
    struct DynamicToolServer {
        tools: Arc<RwLock<Vec<Tool>>>,
    }

    impl DynamicToolServer {
        fn new(tools: Vec<Tool>) -> Self {
            Self {
                tools: Arc::new(RwLock::new(tools)),
            }
        }

        async fn set_tools(&self, tools: Vec<Tool>) {
            *self.tools.write().await = tools;
        }
    }

    impl ServerHandler for DynamicToolServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo {
                protocol_version: ProtocolVersion::V_2024_11_05,
                capabilities: ServerCapabilities::builder().enable_tools().build(),
                server_info: Implementation {
                    name: "test-dynamic-server".to_string(),
                    version: "0.1.0".to_string(),
                    ..Default::default()
                },
                instructions: None,
            }
        }

        async fn list_tools(
            &self,
            _request: Option<PaginatedRequestParams>,
            _context: RequestContext<RoleServer>,
        ) -> Result<ListToolsResult, ErrorData> {
            let tools = self.tools.read().await.clone();
            Ok(ListToolsResult {
                tools,
                next_cursor: None,
                meta: None,
            })
        }

        async fn call_tool(
            &self,
            request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            Ok(CallToolResult::success(vec![Content::text(format!(
                "called {}",
                request.name
            ))]))
        }
    }

    fn make_tool(name: &str, description: &str) -> Tool {
        Tool::new(
            name.to_string(),
            description.to_string(),
            Arc::new(serde_json::Map::new()),
        )
    }

    #[tokio::test]
    async fn test_mcp_client_handler_initial_tool_registration() {
        let initial_tools = vec![
            make_tool("tool_a", "First tool"),
            make_tool("tool_b", "Second tool"),
        ];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        tokio::spawn(async move {
            let service = server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            service.waiting().await.expect("server error");
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }

    #[tokio::test]
    async fn test_mcp_client_handler_refreshes_on_tool_list_changed() {
        let initial_tools = vec![make_tool("alpha", "Alpha tool")];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        let server_service_handle = tokio::spawn(async move {
            server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start")
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "alpha");

        server
            .set_tools(vec![
                make_tool("beta", "Beta tool"),
                make_tool("gamma", "Gamma tool"),
            ])
            .await;

        let server_service = server_service_handle.await.unwrap();
        server_service
            .peer()
            .notify_tool_list_changed()
            .await
            .expect("failed to send notification");

        // The refresh runs asynchronously on the client. Poll until it has fully applied
        // instead of sleeping a fixed amount, which is both slow and flaky under load.
        let names = tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                let mut names: Vec<String> = tool_server_handle
                    .get_tool_defs(None)
                    .await
                    .unwrap()
                    .into_iter()
                    .map(|d| d.name)
                    .collect();
                names.sort();
                if names == ["beta", "gamma"] {
                    break names;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("timed out waiting for the MCP tool list refresh");

        // 'alpha' must have been removed and replaced by exactly 'beta' and 'gamma'.
        assert_eq!(names, ["beta", "gamma"]);
    }

    #[tokio::test]
    async fn test_mcp_client_handler_get_info_delegates() {
        let client_info = ClientInfo {
            protocol_version: Default::default(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
                ..Default::default()
            },
            meta: None,
        };

        let tool_server_handle = ToolServer::new().run();
        let handler = McpClientHandler::new(client_info.clone(), tool_server_handle);

        let returned = handler.get_info();
        assert_eq!(returned.client_info.name, "test-client");
        assert_eq!(returned.client_info.version, "1.0.0");
    }
}