1use super::task::{McpTaskConfig, TaskError, TaskStatus};
10use super::{ConnectionFactory, RefreshConfig, should_refresh_connection};
11use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
12use async_trait::async_trait;
13use rmcp::{
14 RoleClient,
15 model::{
16 CallToolRequestParams, ErrorCode, RawContent, ReadResourceRequestParams, Resource,
17 ResourceContents, ResourceTemplate,
18 },
19 service::RunningService,
20};
21use serde_json::{Value, json};
22use std::ops::Deref;
23use std::sync::Arc;
24use std::time::Instant;
25use tokio::sync::Mutex;
26use tracing::{debug, warn};
27
/// Shared, type-erased connection factory used to re-create the MCP client
/// connection after reconnectable failures.
type DynConnectionFactory<S> = Arc<dyn ConnectionFactory<S>>;

/// Predicate over MCP tool names; a tool is exposed by the toolset only when
/// this returns `true`.
pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Send + Sync>;
33
34fn sanitize_schema(value: &mut Value) {
38 if let Value::Object(map) = value {
39 map.remove("$schema");
40 map.remove("definitions");
41 map.remove("$ref");
42 map.remove("additionalProperties");
43
44 for (_, v) in map.iter_mut() {
45 sanitize_schema(v);
46 }
47 } else if let Value::Array(arr) = value {
48 for v in arr.iter_mut() {
49 sanitize_schema(v);
50 }
51 }
52}
53
54fn should_retry_mcp_operation(
55 error: &str,
56 attempt: u32,
57 refresh_config: &RefreshConfig,
58 has_connection_factory: bool,
59) -> bool {
60 has_connection_factory
61 && attempt < refresh_config.max_attempts
62 && should_refresh_connection(error)
63}
64
65fn is_method_not_found(err: &rmcp::ServiceError) -> bool {
68 matches!(
69 err,
70 rmcp::ServiceError::McpError(e) if e.code == ErrorCode::METHOD_NOT_FOUND
71 )
72}
73
/// A [`Toolset`] backed by a running MCP client connection.
///
/// `S` is the client-side service type of the underlying `RunningService`;
/// it defaults to `()`.
pub struct McpToolset<S = ()>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Shared client connection. Tools produced by this toolset hold clones of
    /// this `Arc`, so a connection swapped in by a refresh is seen by all of them.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
    /// Optional tool-name predicate; tools for which it returns `false` are skipped.
    tool_filter: Option<ToolFilter>,
    /// Name returned by `Toolset::name`.
    name: String,
    /// Task-mode configuration handed to every produced tool.
    task_config: McpTaskConfig,
    /// Factory used to create a replacement connection on retryable errors.
    connection_factory: Option<DynConnectionFactory<S>>,
    /// Reconnect/retry policy.
    refresh_config: RefreshConfig,
}
122
123impl<S> McpToolset<S>
124where
125 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
126{
127 pub fn new(client: RunningService<RoleClient, S>) -> Self {
145 Self {
146 client: Arc::new(Mutex::new(client)),
147 tool_filter: None,
148 name: "mcp_toolset".to_string(),
149 task_config: McpTaskConfig::default(),
150 connection_factory: None,
151 refresh_config: RefreshConfig::default(),
152 }
153 }
154
155 pub fn with_client_handler(client: RunningService<RoleClient, S>) -> Self {
170 Self::new(client)
171 }
172
173 pub fn with_name(mut self, name: impl Into<String>) -> Self {
175 self.name = name.into();
176 self
177 }
178
179 pub fn with_task_support(mut self, config: McpTaskConfig) -> Self {
193 self.task_config = config;
194 self
195 }
196
197 pub fn with_connection_factory<F>(mut self, factory: Arc<F>) -> Self
199 where
200 F: ConnectionFactory<S> + 'static,
201 {
202 self.connection_factory = Some(factory);
203 self
204 }
205
206 pub fn with_refresh_config(mut self, config: RefreshConfig) -> Self {
208 self.refresh_config = config;
209 self
210 }
211
212 pub fn with_filter<F>(mut self, filter: F) -> Self
226 where
227 F: Fn(&str) -> bool + Send + Sync + 'static,
228 {
229 self.tool_filter = Some(Arc::new(filter));
230 self
231 }
232
233 pub fn with_tools(self, tool_names: &[&str]) -> Self {
242 let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
243 self.with_filter(move |name| names.iter().any(|n| n == name))
244 }
245
246 pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
263 let client = self.client.lock().await;
264 client.cancellation_token()
265 }
266
267 pub async fn is_closed(&self) -> bool {
282 let client = self.client.lock().await;
283 client.is_closed()
284 }
285
286 async fn try_refresh_connection(&self) -> Result<bool> {
287 let Some(factory) = self.connection_factory.clone() else {
288 return Ok(false);
289 };
290
291 let new_client = factory
292 .create_connection()
293 .await
294 .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
295
296 let mut client = self.client.lock().await;
297 let old_token = client.cancellation_token();
298 old_token.cancel();
299 *client = new_client;
300 Ok(true)
301 }
302
303 pub async fn list_resources(&self) -> Result<Vec<Resource>> {
314 let client = self.client.lock().await;
315 match client.list_all_resources().await {
316 Ok(resources) => Ok(resources),
317 Err(e) => {
318 if is_method_not_found(&e) {
319 Ok(vec![])
320 } else {
321 Err(AdkError::tool(format!("Failed to list MCP resources: {e}")))
322 }
323 }
324 }
325 }
326
327 pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>> {
338 let client = self.client.lock().await;
339 match client.list_all_resource_templates().await {
340 Ok(templates) => Ok(templates),
341 Err(e) => {
342 if is_method_not_found(&e) {
343 Ok(vec![])
344 } else {
345 Err(AdkError::tool(format!("Failed to list MCP resource templates: {e}")))
346 }
347 }
348 }
349 }
350
351 pub async fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContents>> {
362 let client = self.client.lock().await;
363 let params = ReadResourceRequestParams::new(uri.to_string());
364 match client.read_resource(params).await {
365 Ok(result) => Ok(result.contents),
366 Err(e) => {
367 if is_method_not_found(&e) {
368 Err(AdkError::tool(format!("resource not found: {uri}")))
369 } else {
370 Err(AdkError::tool(format!("Failed to read MCP resource '{uri}': {e}")))
371 }
372 }
373 }
374 }
375}
376
377#[async_trait]
378impl<S> Toolset for McpToolset<S>
379where
380 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
381{
382 fn name(&self) -> &str {
383 &self.name
384 }
385
386 async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
387 let mut attempt = 0u32;
388 let has_connection_factory = self.connection_factory.is_some();
389 let mcp_tools = loop {
390 let list_result = {
391 let client = self.client.lock().await;
392 client.list_all_tools().await.map_err(|e| e.to_string())
393 };
394
395 match list_result {
396 Ok(tools) => break tools,
397 Err(error) => {
398 if !should_retry_mcp_operation(
399 &error,
400 attempt,
401 &self.refresh_config,
402 has_connection_factory,
403 ) {
404 return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
405 }
406
407 let retry_attempt = attempt + 1;
408 if self.refresh_config.log_reconnections {
409 warn!(
410 attempt = retry_attempt,
411 max_attempts = self.refresh_config.max_attempts,
412 error = %error,
413 "MCP list_all_tools failed; reconnecting and retrying"
414 );
415 }
416
417 if self.refresh_config.retry_delay_ms > 0 {
418 tokio::time::sleep(tokio::time::Duration::from_millis(
419 self.refresh_config.retry_delay_ms,
420 ))
421 .await;
422 }
423
424 if !self.try_refresh_connection().await? {
425 return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
426 }
427 attempt += 1;
428 }
429 }
430 };
431
432 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
434
435 for mcp_tool in mcp_tools {
436 let tool_name = mcp_tool.name.to_string();
437
438 if let Some(ref filter) = self.tool_filter {
440 if !filter(&tool_name) {
441 continue;
442 }
443 }
444
445 let adk_tool = McpTool {
446 name: tool_name,
447 description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
448 input_schema: {
449 let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
450 sanitize_schema(&mut schema);
451 Some(schema)
452 },
453 output_schema: mcp_tool.output_schema.map(|s| {
454 let mut schema = Value::Object(s.as_ref().clone());
455 sanitize_schema(&mut schema);
456 schema
457 }),
458 client: self.client.clone(),
459 connection_factory: self.connection_factory.clone(),
460 refresh_config: self.refresh_config.clone(),
461 is_long_running: self.task_config.enable_tasks
466 && mcp_tool.annotations.as_ref().is_some_and(|a| {
467 a.read_only_hint != Some(true) && a.open_world_hint != Some(false)
468 }),
469 task_config: self.task_config.clone(),
470 };
471
472 tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
473 }
474
475 Ok(tools)
476 }
477}
478
479impl McpToolset<super::elicitation::AdkClientHandler> {
480 pub async fn with_elicitation_handler<T, E, A>(
521 transport: T,
522 handler: std::sync::Arc<dyn super::elicitation::ElicitationHandler>,
523 ) -> Result<Self>
524 where
525 T: rmcp::transport::IntoTransport<rmcp::RoleClient, E, A> + Send + 'static,
526 E: std::error::Error + Send + Sync + 'static,
527 {
528 use rmcp::ServiceExt;
529 let adk_handler = super::elicitation::AdkClientHandler::new(handler);
530 let client = adk_handler
531 .serve(transport)
532 .await
533 .map_err(|e| AdkError::tool(format!("failed to connect MCP server: {e}")))?;
534 Ok(Self::new(client))
535 }
536
537 #[cfg(feature = "mcp-sampling")]
569 pub async fn with_sampling_handler<T, E, A>(
570 transport: T,
571 elicitation_handler: std::sync::Arc<dyn super::elicitation::ElicitationHandler>,
572 sampling_handler: std::sync::Arc<dyn crate::sampling::SamplingHandler>,
573 ) -> Result<Self>
574 where
575 T: rmcp::transport::IntoTransport<rmcp::RoleClient, E, A> + Send + 'static,
576 E: std::error::Error + Send + Sync + 'static,
577 {
578 use rmcp::ServiceExt;
579 let adk_handler = super::elicitation::AdkClientHandler::new(elicitation_handler)
580 .with_sampling_handler(sampling_handler);
581 let client = adk_handler
582 .serve(transport)
583 .await
584 .map_err(|e| AdkError::tool(format!("failed to connect MCP server: {e}")))?;
585 Ok(Self::new(client))
586 }
587}
588
/// ADK [`Tool`] adapter for a single tool exposed by an MCP server.
struct McpTool<S>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Tool name as reported by the server; also the name sent in `call_tool`.
    name: String,
    /// Tool description (empty when the server provided none).
    description: String,
    /// Sanitized JSON Schema for the tool's input.
    input_schema: Option<Value>,
    /// Sanitized JSON Schema for the tool's structured output, if declared.
    output_schema: Option<Value>,
    /// Client connection shared with the owning toolset and its other tools.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
    /// Factory used to replace `client` after reconnectable failures.
    connection_factory: Option<DynConnectionFactory<S>>,
    /// Reconnect/retry policy.
    refresh_config: RefreshConfig,
    /// Reported via `Tool::is_long_running`; together with
    /// `task_config.enable_tasks` it selects task-mode execution.
    is_long_running: bool,
    /// Task-mode configuration (task parameters, poll interval, limits).
    task_config: McpTaskConfig,
}
608
609impl<S> McpTool<S>
610where
611 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
612{
613 async fn try_refresh_connection(&self) -> Result<bool> {
614 let Some(factory) = self.connection_factory.clone() else {
615 return Ok(false);
616 };
617
618 let new_client = factory
619 .create_connection()
620 .await
621 .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
622
623 let mut client = self.client.lock().await;
624 let old_token = client.cancellation_token();
625 old_token.cancel();
626 *client = new_client;
627 Ok(true)
628 }
629
630 async fn call_tool_with_retry(
631 &self,
632 params: CallToolRequestParams,
633 ) -> Result<rmcp::model::CallToolResult> {
634 let has_connection_factory = self.connection_factory.is_some();
635 let mut attempt = 0u32;
636
637 loop {
638 let call_result = {
639 let client = self.client.lock().await;
640 client.call_tool(params.clone()).await.map_err(|e| e.to_string())
641 };
642
643 match call_result {
644 Ok(result) => return Ok(result),
645 Err(error) => {
646 if !should_retry_mcp_operation(
647 &error,
648 attempt,
649 &self.refresh_config,
650 has_connection_factory,
651 ) {
652 return Err(AdkError::tool(format!(
653 "Failed to call MCP tool '{}': {error}",
654 self.name
655 )));
656 }
657
658 let retry_attempt = attempt + 1;
659 if self.refresh_config.log_reconnections {
660 warn!(
661 tool = %self.name,
662 attempt = retry_attempt,
663 max_attempts = self.refresh_config.max_attempts,
664 error = %error,
665 "MCP call_tool failed; reconnecting and retrying"
666 );
667 }
668
669 if self.refresh_config.retry_delay_ms > 0 {
670 tokio::time::sleep(tokio::time::Duration::from_millis(
671 self.refresh_config.retry_delay_ms,
672 ))
673 .await;
674 }
675
676 if !self.try_refresh_connection().await? {
677 return Err(AdkError::tool(format!(
678 "Failed to call MCP tool '{}': {error}",
679 self.name
680 )));
681 }
682 attempt += 1;
683 }
684 }
685 }
686 }
687
688 async fn poll_task(&self, task_id: &str) -> std::result::Result<Value, TaskError> {
690 let start = Instant::now();
691 let mut attempts = 0u32;
692
693 loop {
694 if let Some(timeout_ms) = self.task_config.timeout_ms {
696 let elapsed = start.elapsed().as_millis() as u64;
697 if elapsed >= timeout_ms {
698 return Err(TaskError::Timeout {
699 task_id: task_id.to_string(),
700 elapsed_ms: elapsed,
701 });
702 }
703 }
704
705 if let Some(max_attempts) = self.task_config.max_poll_attempts {
707 if attempts >= max_attempts {
708 return Err(TaskError::MaxAttemptsExceeded {
709 task_id: task_id.to_string(),
710 attempts,
711 });
712 }
713 }
714
715 tokio::time::sleep(self.task_config.poll_duration()).await;
717 attempts += 1;
718
719 debug!(task_id = task_id, attempt = attempts, "Polling MCP task status");
720
721 let poll_result = self
724 .call_tool_with_retry(CallToolRequestParams::new("tasks/get").with_arguments(
725 serde_json::Map::from_iter([(
726 "task_id".to_string(),
727 Value::String(task_id.to_string()),
728 )]),
729 ))
730 .await
731 .map_err(|e| TaskError::PollFailed(e.to_string()))?;
732
733 let status = self.parse_task_status(&poll_result)?;
735
736 match status {
737 TaskStatus::Completed => {
738 debug!(task_id = task_id, "Task completed successfully");
739 return self.extract_task_result(&poll_result);
741 }
742 TaskStatus::Failed => {
743 let error_msg = self.extract_error_message(&poll_result);
744 return Err(TaskError::TaskFailed {
745 task_id: task_id.to_string(),
746 error: error_msg,
747 });
748 }
749 TaskStatus::Cancelled => {
750 return Err(TaskError::Cancelled(task_id.to_string()));
751 }
752 TaskStatus::Pending | TaskStatus::Running => {
753 debug!(
755 task_id = task_id,
756 status = ?status,
757 "Task still in progress"
758 );
759 }
760 }
761 }
762 }
763
764 fn parse_task_status(
766 &self,
767 result: &rmcp::model::CallToolResult,
768 ) -> std::result::Result<TaskStatus, TaskError> {
769 if let Some(ref structured) = result.structured_content {
771 if let Some(status_str) = structured.get("status").and_then(|v| v.as_str()) {
772 return match status_str {
773 "pending" => Ok(TaskStatus::Pending),
774 "running" => Ok(TaskStatus::Running),
775 "completed" => Ok(TaskStatus::Completed),
776 "failed" => Ok(TaskStatus::Failed),
777 "cancelled" => Ok(TaskStatus::Cancelled),
778 _ => {
779 warn!(status = status_str, "Unknown task status");
780 Ok(TaskStatus::Running) }
782 };
783 }
784 }
785
786 for content in &result.content {
788 if let Some(text_content) = content.deref().as_text() {
789 if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
791 if let Some(status_str) = parsed.get("status").and_then(|v| v.as_str()) {
792 return match status_str {
793 "pending" => Ok(TaskStatus::Pending),
794 "running" => Ok(TaskStatus::Running),
795 "completed" => Ok(TaskStatus::Completed),
796 "failed" => Ok(TaskStatus::Failed),
797 "cancelled" => Ok(TaskStatus::Cancelled),
798 _ => Ok(TaskStatus::Running),
799 };
800 }
801 }
802 }
803 }
804
805 Ok(TaskStatus::Running)
807 }
808
809 fn extract_task_result(
811 &self,
812 result: &rmcp::model::CallToolResult,
813 ) -> std::result::Result<Value, TaskError> {
814 if let Some(ref structured) = result.structured_content {
816 if let Some(output) = structured.get("result") {
817 return Ok(json!({ "output": output }));
818 }
819 return Ok(json!({ "output": structured }));
820 }
821
822 let mut text_parts: Vec<String> = Vec::new();
824 for content in &result.content {
825 if let Some(text_content) = content.deref().as_text() {
826 text_parts.push(text_content.text.clone());
827 }
828 }
829
830 if text_parts.is_empty() {
831 Ok(json!({ "output": null }))
832 } else {
833 Ok(json!({ "output": text_parts.join("\n") }))
834 }
835 }
836
837 fn extract_error_message(&self, result: &rmcp::model::CallToolResult) -> String {
839 if let Some(ref structured) = result.structured_content {
841 if let Some(error) = structured.get("error").and_then(|v| v.as_str()) {
842 return error.to_string();
843 }
844 }
845
846 for content in &result.content {
848 if let Some(text_content) = content.deref().as_text() {
849 return text_content.text.clone();
850 }
851 }
852
853 "Unknown error".to_string()
854 }
855
856 fn extract_task_id(
858 &self,
859 result: &rmcp::model::CallToolResult,
860 ) -> std::result::Result<String, TaskError> {
861 if let Some(ref structured) = result.structured_content {
863 if let Some(task_id) = structured.get("task_id").and_then(|v| v.as_str()) {
864 return Ok(task_id.to_string());
865 }
866 }
867
868 for content in &result.content {
870 if let Some(text_content) = content.deref().as_text() {
871 if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
872 if let Some(task_id) = parsed.get("task_id").and_then(|v| v.as_str()) {
873 return Ok(task_id.to_string());
874 }
875 }
876 }
877 }
878
879 Err(TaskError::CreateFailed("No task_id in response".to_string()))
880 }
881}
882
883#[async_trait]
884impl<S> Tool for McpTool<S>
885where
886 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
887{
888 fn name(&self) -> &str {
889 &self.name
890 }
891
892 fn description(&self) -> &str {
893 &self.description
894 }
895
896 fn is_long_running(&self) -> bool {
897 self.is_long_running
898 }
899
900 fn parameters_schema(&self) -> Option<Value> {
901 self.input_schema.clone()
902 }
903
904 fn response_schema(&self) -> Option<Value> {
905 self.output_schema.clone()
906 }
907
908 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
909 let use_task_mode = self.task_config.enable_tasks && self.is_long_running;
911
912 if use_task_mode {
913 debug!(tool = self.name, "Executing tool in task mode (long-running)");
914
915 let task_params = self.task_config.to_task_params();
917 let task_map = task_params.as_object().cloned();
918
919 let create_result = self
920 .call_tool_with_retry({
921 let mut params = CallToolRequestParams::new(self.name.clone());
922 if !(args.is_null() || args == json!({})) {
923 match args {
924 Value::Object(map) => {
925 params = params.with_arguments(map);
926 }
927 _ => {
928 return Err(AdkError::tool("Tool arguments must be an object"));
929 }
930 }
931 }
932 if let Some(task_map) = task_map {
933 params = params.with_task(task_map);
934 }
935 params
936 })
937 .await?;
938
939 let task_id = self
941 .extract_task_id(&create_result)
942 .map_err(|e| AdkError::tool(format!("Failed to get task ID: {e}")))?;
943
944 debug!(tool = self.name, task_id = task_id, "Task created, polling for completion");
945
946 let result = self
948 .poll_task(&task_id)
949 .await
950 .map_err(|e| AdkError::tool(format!("Task execution failed: {e}")))?;
951
952 return Ok(result);
953 }
954
955 let result = self
957 .call_tool_with_retry({
958 let mut params = CallToolRequestParams::new(self.name.clone());
959 if !(args.is_null() || args == json!({})) {
960 match args {
961 Value::Object(map) => {
962 params = params.with_arguments(map);
963 }
964 _ => {
965 return Err(AdkError::tool("Tool arguments must be an object"));
966 }
967 }
968 }
969 params
970 })
971 .await?;
972
973 if result.is_error.unwrap_or(false) {
975 let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
976
977 for content in &result.content {
979 if let Some(text_content) = content.deref().as_text() {
981 error_msg.push_str(": ");
982 error_msg.push_str(&text_content.text);
983 break;
984 }
985 }
986
987 return Err(AdkError::tool(error_msg));
988 }
989
990 if let Some(structured) = result.structured_content {
992 return Ok(json!({ "output": structured }));
993 }
994
995 let mut text_parts: Vec<String> = Vec::new();
997
998 for content in &result.content {
999 let raw: &RawContent = content.deref();
1001 match raw {
1002 RawContent::Text(text_content) => {
1003 text_parts.push(text_content.text.clone());
1004 }
1005 RawContent::Image(image_content) => {
1006 text_parts.push(format!(
1008 "[Image: {} bytes, mime: {}]",
1009 image_content.data.len(),
1010 image_content.mime_type
1011 ));
1012 }
1013 RawContent::Resource(resource_content) => {
1014 let uri = match &resource_content.resource {
1015 ResourceContents::TextResourceContents { uri, .. } => uri,
1016 ResourceContents::BlobResourceContents { uri, .. } => uri,
1017 };
1018 text_parts.push(format!("[Resource: {}]", uri));
1019 }
1020 RawContent::Audio(_) => {
1021 text_parts.push("[Audio content]".to_string());
1022 }
1023 RawContent::ResourceLink(link) => {
1024 text_parts.push(format!("[ResourceLink: {}]", link.uri));
1025 }
1026 }
1027 }
1028
1029 if text_parts.is_empty() {
1030 return Err(AdkError::tool(format!("MCP tool '{}' returned no content", self.name)));
1031 }
1032
1033 Ok(json!({ "output": text_parts.join("\n") }))
1034 }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the toolset and the tools it produces must be shareable across threads.
    #[test]
    fn mcp_tool_is_send_and_sync() {
        fn require_send_sync<T: Send + Sync>() {}

        require_send_sync::<McpTool<()>>();
        require_send_sync::<McpToolset<()>>();
    }

    #[test]
    fn test_should_retry_mcp_operation_reconnectable_errors() {
        let config = RefreshConfig::default().with_max_attempts(3);
        for (error, attempt) in [("EOF", 0), ("connection reset by peer", 1)] {
            assert!(
                should_retry_mcp_operation(error, attempt, &config, true),
                "expected a retry for {error:?} at attempt {attempt}"
            );
        }
    }

    #[test]
    fn test_should_retry_mcp_operation_stops_at_max_attempts() {
        let config = RefreshConfig::default().with_max_attempts(2);
        assert!(!should_retry_mcp_operation("EOF", 2, &config, true));
    }

    #[test]
    fn test_should_retry_mcp_operation_requires_factory() {
        let config = RefreshConfig::default().with_max_attempts(3);
        assert!(!should_retry_mcp_operation("EOF", 0, &config, false));
    }

    #[test]
    fn test_should_retry_mcp_operation_non_reconnectable_error() {
        let config = RefreshConfig::default().with_max_attempts(3);
        assert!(!should_retry_mcp_operation("invalid arguments for tool", 0, &config, true));
    }

    /// Unsupported keywords are removed at every nesting level, including
    /// inside property sub-schemas and array item schemas.
    #[test]
    fn test_sanitize_schema_strips_unsupported_keywords_recursively() {
        let mut schema = json!({
            "$schema": "http://json-schema.org/draft-07/schema#",
            "type": "object",
            "additionalProperties": false,
            "properties": {
                "query": { "type": "string" },
                "filters": {
                    "type": "object",
                    "additionalProperties": false,
                    "properties": { "tag": { "type": "string" } }
                },
                "ids": { "type": "array", "items": { "$ref": "#/definitions/Id" } }
            },
            "definitions": { "Id": { "type": "string" } }
        });

        sanitize_schema(&mut schema);

        assert_eq!(
            schema,
            json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" },
                    "filters": {
                        "type": "object",
                        "properties": { "tag": { "type": "string" } }
                    },
                    "ids": { "type": "array", "items": {} }
                }
            })
        );
    }
}
1091}