1use std::borrow::Cow;
33use std::sync::Arc;
34use std::time::Duration;
35
36use rmcp::ServiceExt;
37use rmcp::model::RawContent;
38use tokio::sync::RwLock;
39
40use crate::completion::ToolDefinition;
41use crate::tool::ToolDyn;
42use crate::tool::ToolError;
43use crate::tool::server::{ToolServerError, ToolServerHandle};
44use crate::wasm_compat::WasmBoxedFuture;
45
/// Per-call timeout (five minutes) given to MCP tools unless a caller overrides
/// it with `with_timeout`; `None` there disables the bound entirely.
pub const DEFAULT_MCP_TOOL_TIMEOUT: Duration = Duration::from_secs(5 * 60);
55
56#[derive(Clone)]
61pub struct McpTool {
62 definition: rmcp::model::Tool,
63 client: rmcp::service::ServerSink,
64 timeout: Option<Duration>,
72}
73
74impl McpTool {
75 pub fn from_mcp_server(
82 definition: rmcp::model::Tool,
83 client: rmcp::service::ServerSink,
84 ) -> Self {
85 Self {
86 definition,
87 client,
88 timeout: Some(DEFAULT_MCP_TOOL_TIMEOUT),
89 }
90 }
91
92 pub fn with_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
101 self.timeout = timeout.into();
102 self
103 }
104
105 pub fn timeout(&self) -> Option<Duration> {
107 self.timeout
108 }
109}
110
111impl From<&rmcp::model::Tool> for ToolDefinition {
112 fn from(val: &rmcp::model::Tool) -> Self {
113 Self {
114 name: val.name.to_string(),
115 description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
116 parameters: val.schema_as_json_value(),
117 }
118 }
119}
120
121impl From<rmcp::model::Tool> for ToolDefinition {
122 fn from(val: rmcp::model::Tool) -> Self {
123 Self {
124 name: val.name.to_string(),
125 description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
126 parameters: val.schema_as_json_value(),
127 }
128 }
129}
130
131#[derive(Debug, thiserror::Error)]
132#[error("MCP tool error: {0}")]
133pub struct McpToolError(String);
134
135impl From<McpToolError> for ToolError {
136 fn from(e: McpToolError) -> Self {
137 ToolError::ToolCallError(Box::new(e))
138 }
139}
140
141impl ToolDyn for McpTool {
142 fn name(&self) -> String {
143 self.definition.name.to_string()
144 }
145
146 fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
147 Box::pin(async move {
148 ToolDefinition {
149 name: self.definition.name.to_string(),
150 description: self
151 .definition
152 .description
153 .clone()
154 .unwrap_or(Cow::from(""))
155 .to_string(),
156 parameters: serde_json::to_value(&self.definition.input_schema).unwrap_or_default(),
157 }
158 })
159 }
160
161 fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
162 let name = self.definition.name.clone();
163 let arguments: Option<rmcp::model::JsonObject> =
164 serde_json::from_str(&args).unwrap_or_default();
165
166 Box::pin(async move {
167 let request = arguments
168 .map(|arguments| {
169 rmcp::model::CallToolRequestParams::new(name.clone()).with_arguments(arguments)
170 })
171 .unwrap_or_else(|| rmcp::model::CallToolRequestParams::new(name));
172
173 let call = self.client.call_tool(request);
174 let call_result = match self.timeout {
177 Some(timeout) => {
178 crate::wasm_compat::timeout(timeout, call)
179 .await
180 .map_err(|_| {
181 McpToolError(format!(
182 "MCP tool '{}' timed out after {timeout:?}",
183 self.definition.name
184 ))
185 })?
186 }
187 None => call.await,
188 };
189 let result =
190 call_result.map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;
191
192 if let Some(true) = result.is_error {
193 let error_msg = result
194 .content
195 .into_iter()
196 .map(|x| x.raw.as_text().map(|y| y.to_owned()))
197 .map(|x| x.map(|x| x.clone().text))
198 .collect::<Option<Vec<String>>>();
199
200 let error_message = error_msg.map(|x| x.join("\n"));
201 if let Some(error_message) = error_message {
202 return Err(McpToolError(error_message).into());
203 } else {
204 return Err(McpToolError("No message returned".to_string()).into());
205 }
206 };
207
208 let mut content = String::new();
209
210 for item in result.content {
211 let chunk = match item.raw {
212 rmcp::model::RawContent::Text(raw) => raw.text,
213 rmcp::model::RawContent::Image(raw) => {
214 format!("data:{};base64,{}", raw.mime_type, raw.data)
215 }
216 rmcp::model::RawContent::Resource(raw) => match raw.resource {
217 rmcp::model::ResourceContents::TextResourceContents {
218 uri,
219 mime_type,
220 text,
221 ..
222 } => {
223 format!(
224 "{mime_type}{uri}:{text}",
225 mime_type =
226 mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
227 )
228 }
229 rmcp::model::ResourceContents::BlobResourceContents {
230 uri,
231 mime_type,
232 blob,
233 ..
234 } => format!(
235 "{mime_type}{uri}:{blob}",
236 mime_type = mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
237 ),
238 },
239 RawContent::Audio(_) => {
240 return Err(McpToolError(
241 "MCP tool returned audio content, which Rig does not support yet"
242 .to_string(),
243 )
244 .into());
245 }
246 thing => {
247 return Err(McpToolError(format!(
248 "MCP tool returned unsupported content: {thing:?}"
249 ))
250 .into());
251 }
252 };
253
254 content.push_str(&chunk);
255 }
256
257 Ok(content)
258 })
259 }
260}
261
262#[derive(Debug, thiserror::Error)]
264pub enum McpClientError {
265 #[error("MCP connection error: {0}")]
267 ConnectionError(String),
268
269 #[error("Failed to fetch MCP tool list: {0}")]
271 ToolFetchError(#[from] rmcp::ServiceError),
272
273 #[error("Tool server error: {0}")]
275 ToolServerError(#[from] ToolServerError),
276}
277
278pub struct McpClientHandler {
302 client_info: rmcp::model::ClientInfo,
303 tool_server_handle: ToolServerHandle,
304 timeout: Option<Duration>,
307 managed_tool_names: Arc<RwLock<Vec<String>>>,
310}
311
312impl McpClientHandler {
313 pub fn new(client_info: rmcp::model::ClientInfo, tool_server_handle: ToolServerHandle) -> Self {
319 Self {
320 client_info,
321 tool_server_handle,
322 timeout: Some(DEFAULT_MCP_TOOL_TIMEOUT),
323 managed_tool_names: Arc::new(RwLock::new(Vec::new())),
324 }
325 }
326
327 pub fn with_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
332 self.timeout = timeout.into();
333 self
334 }
335
336 fn build_tool(&self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> McpTool {
338 McpTool::from_mcp_server(tool, client).with_timeout(self.timeout)
339 }
340
341 pub async fn connect<T, E, A>(
354 self,
355 transport: T,
356 ) -> Result<rmcp::service::RunningService<rmcp::service::RoleClient, Self>, McpClientError>
357 where
358 T: rmcp::transport::IntoTransport<rmcp::service::RoleClient, E, A>,
359 E: std::error::Error + Send + Sync + 'static,
360 {
361 let service = ServiceExt::serve(self, transport)
362 .await
363 .map_err(|e| McpClientError::ConnectionError(e.to_string()))?;
364
365 let tools = service.peer().list_all_tools().await?;
366
367 {
368 let handler = service.service();
369 let mut managed = handler.managed_tool_names.write().await;
370
371 for tool in tools {
372 let tool_name = tool.name.to_string();
373 let mcp_tool = handler.build_tool(tool, service.peer().clone());
374 handler.tool_server_handle.add_tool(mcp_tool).await?;
375 managed.push(tool_name);
376 }
377 }
378
379 Ok(service)
380 }
381}
382
383impl rmcp::handler::client::ClientHandler for McpClientHandler {
384 fn get_info(&self) -> rmcp::model::ClientInfo {
385 self.client_info.clone()
386 }
387
388 async fn on_tool_list_changed(
389 &self,
390 context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
391 ) {
392 let tools = match context.peer.list_all_tools().await {
393 Ok(tools) => tools,
394 Err(e) => {
395 tracing::error!("Failed to re-fetch MCP tool list: {e}");
396 return;
397 }
398 };
399
400 let mut managed = self.managed_tool_names.write().await;
401
402 for name in managed.drain(..) {
403 if let Err(e) = self.tool_server_handle.remove_tool(&name).await {
404 tracing::warn!("Failed to remove MCP tool '{name}' during refresh: {e}");
405 }
406 }
407
408 for tool in tools {
409 let tool_name = tool.name.to_string();
410 let mcp_tool = self.build_tool(tool, context.peer.clone());
411 match self.tool_server_handle.add_tool(mcp_tool).await {
412 Ok(()) => {
413 managed.push(tool_name);
414 }
415 Err(e) => {
416 tracing::error!("Failed to register MCP tool '{tool_name}': {e}");
417 }
418 }
419 }
420
421 tracing::info!(
422 tool_count = managed.len(),
423 "MCP tool list refreshed successfully"
424 );
425 }
426}
427
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use rmcp::handler::client::ClientHandler;
    use rmcp::model::*;
    use rmcp::service::RequestContext;
    use rmcp::{RoleServer, ServerHandler, ServiceExt};
    use tokio::sync::RwLock;

    use super::McpClientHandler;
    use crate::tool::server::ToolServer;

    /// Test MCP server whose advertised tool list can be swapped at runtime.
    #[derive(Clone)]
    struct DynamicToolServer {
        tools: Arc<RwLock<Vec<Tool>>>,
    }

    impl DynamicToolServer {
        fn new(tools: Vec<Tool>) -> Self {
            Self {
                tools: Arc::new(RwLock::new(tools)),
            }
        }

        /// Replaces the advertised tools. The caller must still send the
        /// tool-list-changed notification itself.
        async fn set_tools(&self, tools: Vec<Tool>) {
            *self.tools.write().await = tools;
        }
    }

    impl ServerHandler for DynamicToolServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_protocol_version(ProtocolVersion::LATEST)
                .with_server_info(Implementation::new("test-dynamic-server", "0.1.0"))
        }

        async fn list_tools(
            &self,
            _request: Option<PaginatedRequestParams>,
            _context: RequestContext<RoleServer>,
        ) -> Result<ListToolsResult, ErrorData> {
            let tools = self.tools.read().await.clone();
            Ok(ListToolsResult::with_all_items(tools))
        }

        /// Answers every call immediately with the text `called <tool name>`.
        async fn call_tool(
            &self,
            request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            Ok(CallToolResult::success(vec![Content::text(format!(
                "called {}",
                request.name
            ))]))
        }
    }

    /// Builds a tool descriptor with an empty input schema.
    fn make_tool(name: &str, description: &str) -> Tool {
        Tool::new(
            name.to_string(),
            description.to_string(),
            Arc::new(serde_json::Map::new()),
        )
    }

    /// Test MCP server advertising one tool whose call never completes, used to
    /// exercise the per-call timeout.
    #[derive(Clone)]
    struct HangingToolServer;

    impl ServerHandler for HangingToolServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_protocol_version(ProtocolVersion::LATEST)
                .with_server_info(Implementation::new("hanging-server", "0.1.0"))
        }

        async fn list_tools(
            &self,
            _request: Option<PaginatedRequestParams>,
            _context: RequestContext<RoleServer>,
        ) -> Result<ListToolsResult, ErrorData> {
            Ok(ListToolsResult::with_all_items(vec![make_tool(
                "hang_forever",
                "A tool whose handler never returns",
            )]))
        }

        async fn call_tool(
            &self,
            _request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            std::future::pending::<Result<CallToolResult, ErrorData>>().await
        }
    }

    #[tokio::test]
    async fn test_mcp_client_handler_initial_tool_registration() {
        let initial_tools = vec![
            make_tool("tool_a", "First tool"),
            make_tool("tool_b", "Second tool"),
        ];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        tokio::spawn(async move {
            let _service = server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            _service.waiting().await.expect("server error");
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }

    #[tokio::test]
    async fn test_mcp_client_handler_refreshes_on_tool_list_changed() {
        let initial_tools = vec![make_tool("alpha", "Alpha tool")];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        let server_service_handle = tokio::spawn(async move {
            server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start")
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "alpha");

        server
            .set_tools(vec![
                make_tool("beta", "Beta tool"),
                make_tool("gamma", "Gamma tool"),
            ])
            .await;

        let server_service = server_service_handle.await.unwrap();
        server_service
            .peer()
            .notify_tool_list_changed()
            .await
            .expect("failed to send notification");

        // The refresh runs asynchronously on the client side. Poll for the new
        // state, bounded by a generous deadline, instead of sleeping for a fixed
        // interval, which is flaky on slow machines and wasteful on fast ones.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
        let defs = loop {
            let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
            let refreshed = defs.len() == 2
                && defs.iter().any(|d| d.name == "beta")
                && defs.iter().any(|d| d.name == "gamma");
            if refreshed || tokio::time::Instant::now() >= deadline {
                break defs;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        };
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"beta"), "expected 'beta' in {names:?}");
        assert!(names.contains(&"gamma"), "expected 'gamma' in {names:?}");
        assert!(
            !names.contains(&"alpha"),
            "expected 'alpha' to be removed, found {names:?}"
        );
    }

    #[tokio::test]
    async fn test_mcp_client_handler_get_info_delegates() {
        let client_info = ClientInfo::new(
            ClientCapabilities::default(),
            Implementation::new("test-client", "1.0.0"),
        );

        let tool_server_handle = ToolServer::new().run();
        let handler = McpClientHandler::new(client_info.clone(), tool_server_handle);

        let returned = handler.get_info();
        assert_eq!(returned.client_info.name, "test-client");
        assert_eq!(returned.client_info.version, "1.0.0");
    }

    /// With the timeout disabled, a call to a hanging server must not resolve
    /// on its own. The outer 150 ms guard is expected to fire.
    #[tokio::test]
    async fn mcp_tool_call_without_timeout_is_unbounded() {
        use super::McpTool;
        use crate::tool::ToolDyn;

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_task = tokio::spawn(async move {
            let running = HangingToolServer
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            running.waiting().await.expect("server error");
        });

        let client = ClientInfo::default()
            .serve((client_from_server, client_to_server))
            .await
            .expect("client connect failed");

        let tools = client
            .peer()
            .list_all_tools()
            .await
            .expect("list_tools failed");
        assert_eq!(tools.len(), 1, "expected exactly one advertised tool");

        let mcp_tool = McpTool::from_mcp_server(tools[0].clone(), client.peer().clone());
        assert_eq!(mcp_tool.timeout(), Some(super::DEFAULT_MCP_TOOL_TIMEOUT));
        let mcp_tool = mcp_tool.with_timeout(None);
        assert_eq!(mcp_tool.timeout(), None);

        let timed =
            tokio::time::timeout(Duration::from_millis(150), mcp_tool.call("{}".to_string())).await;

        assert!(
            timed.is_err(),
            "with the timeout disabled, McpTool::call must stay unbounded; got {:?}",
            timed.ok(),
        );

        server_task.abort();
    }

    /// A configured per-call timeout must turn a hanging call into an error.
    #[tokio::test]
    async fn mcp_tool_call_with_timeout_errors_instead_of_hanging() {
        use super::McpTool;
        use crate::tool::ToolDyn;

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_task = tokio::spawn(async move {
            let running = HangingToolServer
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            running.waiting().await.expect("server error");
        });

        let client = ClientInfo::default()
            .serve((client_from_server, client_to_server))
            .await
            .expect("client connect failed");

        let tools = client
            .peer()
            .list_all_tools()
            .await
            .expect("list_tools failed");

        let mcp_tool = McpTool::from_mcp_server(tools[0].clone(), client.peer().clone())
            .with_timeout(Duration::from_millis(200));

        let timed =
            tokio::time::timeout(Duration::from_secs(5), mcp_tool.call("{}".to_string())).await;

        let result = timed.expect(
            "regression: McpTool::call hung past the safety timeout; the per-call \
             timeout did not fire (issue #1914 fix is broken)",
        );
        let err =
            result.expect_err("call should resolve to an error when the server never responds");
        assert!(
            err.to_string().contains("timed out"),
            "expected a timeout error, got: {err}"
        );

        server_task.abort();
    }

    /// A responsive server must answer well within the configured timeout.
    #[tokio::test]
    async fn mcp_tool_call_returns_promptly_for_responsive_server() {
        use super::McpTool;
        use crate::tool::ToolDyn;

        let server = DynamicToolServer::new(vec![make_tool("ping", "responds immediately")]);

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        let server_task = tokio::spawn(async move {
            let running = server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            running.waiting().await.expect("server error");
        });

        let client = ClientInfo::default()
            .serve((client_from_server, client_to_server))
            .await
            .expect("client connect failed");

        let tools = client
            .peer()
            .list_all_tools()
            .await
            .expect("list_tools failed");
        let mcp_tool = McpTool::from_mcp_server(tools[0].clone(), client.peer().clone())
            .with_timeout(Duration::from_secs(2));

        let timed =
            tokio::time::timeout(Duration::from_secs(5), mcp_tool.call("{}".to_string())).await;

        let result = timed
            .expect("responsive tool should resolve within the safety window")
            .expect("tool call should succeed");
        assert!(result.contains("ping"), "unexpected tool output: {result}");

        server_task.abort();
    }

    /// Tools registered through `McpClientHandler::with_timeout` inherit its timeout.
    #[tokio::test]
    async fn mcp_client_handler_with_timeout_bounds_registered_tools() {
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_task = tokio::spawn(async move {
            let running = HangingToolServer
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            running.waiting().await.expect("server error");
        });

        let handler = McpClientHandler::new(ClientInfo::default(), tool_server_handle.clone())
            .with_timeout(Duration::from_millis(200));
        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let timed = tokio::time::timeout(
            Duration::from_secs(5),
            tool_server_handle.call_tool("hang_forever", "{}"),
        )
        .await;

        let result = timed.expect("handler-registered tool hung past the safety timeout");
        let err = result.expect_err("call should time out when the server never responds");
        assert!(
            err.to_string().contains("timed out"),
            "expected a timeout error, got: {err}"
        );

        server_task.abort();
    }

    /// Tools registered via `ToolServer::rmcp_tool_with_timeout` honour the timeout.
    #[tokio::test]
    async fn tool_server_rmcp_tool_with_timeout_bounds_calls() {
        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_task = tokio::spawn(async move {
            let running = HangingToolServer
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            running.waiting().await.expect("server error");
        });

        let client = ClientInfo::default()
            .serve((client_from_server, client_to_server))
            .await
            .expect("client connect failed");

        let handle = ToolServer::new()
            .rmcp_tool_with_timeout(
                make_tool("hang_forever", "never returns"),
                client.peer().clone(),
                Duration::from_millis(200),
            )
            .run();

        let timed = tokio::time::timeout(
            Duration::from_secs(5),
            handle.call_tool("hang_forever", "{}"),
        )
        .await;

        let result = timed.expect("ToolServer-registered tool hung past the safety timeout");
        let err = result.expect_err("call should time out when the server never responds");
        assert!(
            err.to_string().contains("timed out"),
            "expected a timeout error, got: {err}"
        );

        server_task.abort();
    }
}