1use super::task::{McpTaskConfig, TaskError, TaskStatus};
10use super::{ConnectionFactory, RefreshConfig, should_refresh_connection};
11use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
12use async_trait::async_trait;
13use rmcp::{
14 RoleClient,
15 model::{
16 CallToolRequestParams, ErrorCode, RawContent, ReadResourceRequestParams, Resource,
17 ResourceContents, ResourceTemplate,
18 },
19 service::RunningService,
20};
21use serde_json::{Value, json};
22use std::ops::Deref;
23use std::sync::Arc;
24use std::time::Instant;
25use tokio::sync::Mutex;
26use tracing::{debug, warn};
27
28type DynConnectionFactory<S> = Arc<dyn ConnectionFactory<S>>;
30
31pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Send + Sync>;
33
34fn sanitize_schema(value: &mut Value) {
38 if let Value::Object(map) = value {
39 map.remove("$schema");
40 map.remove("definitions");
41 map.remove("$ref");
42 map.remove("additionalProperties");
43
44 for (_, v) in map.iter_mut() {
45 sanitize_schema(v);
46 }
47 } else if let Value::Array(arr) = value {
48 for v in arr.iter_mut() {
49 sanitize_schema(v);
50 }
51 }
52}
53
54fn should_retry_mcp_operation(
55 error: &str,
56 attempt: u32,
57 refresh_config: &RefreshConfig,
58 has_connection_factory: bool,
59) -> bool {
60 has_connection_factory
61 && attempt < refresh_config.max_attempts
62 && should_refresh_connection(error)
63}
64
65fn is_method_not_found(err: &rmcp::ServiceError) -> bool {
68 matches!(
69 err,
70 rmcp::ServiceError::McpError(e) if e.code == ErrorCode::METHOD_NOT_FOUND
71 )
72}
73
/// A toolset backed by a running MCP client service.
///
/// `S` is the rmcp client-side service type; it defaults to `()`.
pub struct McpToolset<S = ()>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Running MCP client. It sits behind an async mutex inside an `Arc` so it can be
    /// replaced on reconnect and shared with the tools this toolset creates.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
    /// Optional predicate over tool names; tools it rejects are left out of `tools()`.
    tool_filter: Option<ToolFilter>,
    /// Name reported through `Toolset::name`.
    name: String,
    /// Task settings copied into every tool this toolset creates.
    task_config: McpTaskConfig,
    /// Factory used to build a replacement connection; `None` disables reconnection.
    connection_factory: Option<DynConnectionFactory<S>>,
    /// Reconnect/retry settings (attempt limit, delay, logging flag).
    refresh_config: RefreshConfig,
}
122
123impl<S> McpToolset<S>
124where
125 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
126{
127 pub fn new(client: RunningService<RoleClient, S>) -> Self {
145 Self {
146 client: Arc::new(Mutex::new(client)),
147 tool_filter: None,
148 name: "mcp_toolset".to_string(),
149 task_config: McpTaskConfig::default(),
150 connection_factory: None,
151 refresh_config: RefreshConfig::default(),
152 }
153 }
154
155 pub fn with_client_handler(client: RunningService<RoleClient, S>) -> Self {
170 Self::new(client)
171 }
172
173 pub fn with_name(mut self, name: impl Into<String>) -> Self {
175 self.name = name.into();
176 self
177 }
178
179 pub fn with_task_support(mut self, config: McpTaskConfig) -> Self {
193 self.task_config = config;
194 self
195 }
196
197 pub fn with_connection_factory<F>(mut self, factory: Arc<F>) -> Self
199 where
200 F: ConnectionFactory<S> + 'static,
201 {
202 self.connection_factory = Some(factory);
203 self
204 }
205
206 pub fn with_refresh_config(mut self, config: RefreshConfig) -> Self {
208 self.refresh_config = config;
209 self
210 }
211
212 pub fn with_filter<F>(mut self, filter: F) -> Self
226 where
227 F: Fn(&str) -> bool + Send + Sync + 'static,
228 {
229 self.tool_filter = Some(Arc::new(filter));
230 self
231 }
232
233 pub fn with_tools(self, tool_names: &[&str]) -> Self {
242 let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
243 self.with_filter(move |name| names.iter().any(|n| n == name))
244 }
245
246 pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
263 let client = self.client.lock().await;
264 client.cancellation_token()
265 }
266
267 async fn try_refresh_connection(&self) -> Result<bool> {
268 let Some(factory) = self.connection_factory.clone() else {
269 return Ok(false);
270 };
271
272 let new_client = factory
273 .create_connection()
274 .await
275 .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
276
277 let mut client = self.client.lock().await;
278 let old_token = client.cancellation_token();
279 old_token.cancel();
280 *client = new_client;
281 Ok(true)
282 }
283
284 pub async fn list_resources(&self) -> Result<Vec<Resource>> {
295 let client = self.client.lock().await;
296 match client.list_all_resources().await {
297 Ok(resources) => Ok(resources),
298 Err(e) => {
299 if is_method_not_found(&e) {
300 Ok(vec![])
301 } else {
302 Err(AdkError::tool(format!("Failed to list MCP resources: {e}")))
303 }
304 }
305 }
306 }
307
308 pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>> {
319 let client = self.client.lock().await;
320 match client.list_all_resource_templates().await {
321 Ok(templates) => Ok(templates),
322 Err(e) => {
323 if is_method_not_found(&e) {
324 Ok(vec![])
325 } else {
326 Err(AdkError::tool(format!("Failed to list MCP resource templates: {e}")))
327 }
328 }
329 }
330 }
331
332 pub async fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContents>> {
343 let client = self.client.lock().await;
344 let params = ReadResourceRequestParams::new(uri.to_string());
345 match client.read_resource(params).await {
346 Ok(result) => Ok(result.contents),
347 Err(e) => {
348 if is_method_not_found(&e) {
349 Err(AdkError::tool(format!("resource not found: {uri}")))
350 } else {
351 Err(AdkError::tool(format!("Failed to read MCP resource '{uri}': {e}")))
352 }
353 }
354 }
355 }
356}
357
358#[async_trait]
359impl<S> Toolset for McpToolset<S>
360where
361 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
362{
363 fn name(&self) -> &str {
364 &self.name
365 }
366
367 async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
368 let mut attempt = 0u32;
369 let has_connection_factory = self.connection_factory.is_some();
370 let mcp_tools = loop {
371 let list_result = {
372 let client = self.client.lock().await;
373 client.list_all_tools().await.map_err(|e| e.to_string())
374 };
375
376 match list_result {
377 Ok(tools) => break tools,
378 Err(error) => {
379 if !should_retry_mcp_operation(
380 &error,
381 attempt,
382 &self.refresh_config,
383 has_connection_factory,
384 ) {
385 return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
386 }
387
388 let retry_attempt = attempt + 1;
389 if self.refresh_config.log_reconnections {
390 warn!(
391 attempt = retry_attempt,
392 max_attempts = self.refresh_config.max_attempts,
393 error = %error,
394 "MCP list_all_tools failed; reconnecting and retrying"
395 );
396 }
397
398 if self.refresh_config.retry_delay_ms > 0 {
399 tokio::time::sleep(tokio::time::Duration::from_millis(
400 self.refresh_config.retry_delay_ms,
401 ))
402 .await;
403 }
404
405 if !self.try_refresh_connection().await? {
406 return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
407 }
408 attempt += 1;
409 }
410 }
411 };
412
413 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
415
416 for mcp_tool in mcp_tools {
417 let tool_name = mcp_tool.name.to_string();
418
419 if let Some(ref filter) = self.tool_filter {
421 if !filter(&tool_name) {
422 continue;
423 }
424 }
425
426 let adk_tool = McpTool {
427 name: tool_name,
428 description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
429 input_schema: {
430 let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
431 sanitize_schema(&mut schema);
432 Some(schema)
433 },
434 output_schema: mcp_tool.output_schema.map(|s| {
435 let mut schema = Value::Object(s.as_ref().clone());
436 sanitize_schema(&mut schema);
437 schema
438 }),
439 client: self.client.clone(),
440 connection_factory: self.connection_factory.clone(),
441 refresh_config: self.refresh_config.clone(),
442 is_long_running: self.task_config.enable_tasks
447 && mcp_tool.annotations.as_ref().is_some_and(|a| {
448 a.read_only_hint != Some(true) && a.open_world_hint != Some(false)
449 }),
450 task_config: self.task_config.clone(),
451 };
452
453 tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
454 }
455
456 Ok(tools)
457 }
458}
459
460impl McpToolset<super::elicitation::AdkClientHandler> {
461 pub async fn with_elicitation_handler<T, E, A>(
502 transport: T,
503 handler: std::sync::Arc<dyn super::elicitation::ElicitationHandler>,
504 ) -> Result<Self>
505 where
506 T: rmcp::transport::IntoTransport<rmcp::RoleClient, E, A> + Send + 'static,
507 E: std::error::Error + Send + Sync + 'static,
508 {
509 use rmcp::ServiceExt;
510 let adk_handler = super::elicitation::AdkClientHandler::new(handler);
511 let client = adk_handler
512 .serve(transport)
513 .await
514 .map_err(|e| AdkError::tool(format!("failed to connect MCP server: {e}")))?;
515 Ok(Self::new(client))
516 }
517}
518
/// A single MCP server tool exposed through the ADK `Tool` trait.
struct McpTool<S>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Tool name as reported by the MCP server; also used when calling the tool.
    name: String,
    /// Tool description; empty when the server supplied none.
    description: String,
    /// Sanitized input (parameters) schema.
    input_schema: Option<Value>,
    /// Sanitized output (response) schema, when the server declared one.
    output_schema: Option<Value>,
    /// MCP client shared with the owning toolset.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
    /// Factory used to build a replacement connection; `None` disables reconnection.
    connection_factory: Option<DynConnectionFactory<S>>,
    /// Reconnect/retry settings (attempt limit, delay, logging flag).
    refresh_config: RefreshConfig,
    /// Whether the tool is reported as long-running; together with
    /// `task_config.enable_tasks` this selects task mode in `execute`.
    is_long_running: bool,
    /// Task settings (polling interval, timeout, attempt limit, task parameters).
    task_config: McpTaskConfig,
}
538
539impl<S> McpTool<S>
540where
541 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
542{
543 async fn try_refresh_connection(&self) -> Result<bool> {
544 let Some(factory) = self.connection_factory.clone() else {
545 return Ok(false);
546 };
547
548 let new_client = factory
549 .create_connection()
550 .await
551 .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
552
553 let mut client = self.client.lock().await;
554 let old_token = client.cancellation_token();
555 old_token.cancel();
556 *client = new_client;
557 Ok(true)
558 }
559
560 async fn call_tool_with_retry(
561 &self,
562 params: CallToolRequestParams,
563 ) -> Result<rmcp::model::CallToolResult> {
564 let has_connection_factory = self.connection_factory.is_some();
565 let mut attempt = 0u32;
566
567 loop {
568 let call_result = {
569 let client = self.client.lock().await;
570 client.call_tool(params.clone()).await.map_err(|e| e.to_string())
571 };
572
573 match call_result {
574 Ok(result) => return Ok(result),
575 Err(error) => {
576 if !should_retry_mcp_operation(
577 &error,
578 attempt,
579 &self.refresh_config,
580 has_connection_factory,
581 ) {
582 return Err(AdkError::tool(format!(
583 "Failed to call MCP tool '{}': {error}",
584 self.name
585 )));
586 }
587
588 let retry_attempt = attempt + 1;
589 if self.refresh_config.log_reconnections {
590 warn!(
591 tool = %self.name,
592 attempt = retry_attempt,
593 max_attempts = self.refresh_config.max_attempts,
594 error = %error,
595 "MCP call_tool failed; reconnecting and retrying"
596 );
597 }
598
599 if self.refresh_config.retry_delay_ms > 0 {
600 tokio::time::sleep(tokio::time::Duration::from_millis(
601 self.refresh_config.retry_delay_ms,
602 ))
603 .await;
604 }
605
606 if !self.try_refresh_connection().await? {
607 return Err(AdkError::tool(format!(
608 "Failed to call MCP tool '{}': {error}",
609 self.name
610 )));
611 }
612 attempt += 1;
613 }
614 }
615 }
616 }
617
618 async fn poll_task(&self, task_id: &str) -> std::result::Result<Value, TaskError> {
620 let start = Instant::now();
621 let mut attempts = 0u32;
622
623 loop {
624 if let Some(timeout_ms) = self.task_config.timeout_ms {
626 let elapsed = start.elapsed().as_millis() as u64;
627 if elapsed >= timeout_ms {
628 return Err(TaskError::Timeout {
629 task_id: task_id.to_string(),
630 elapsed_ms: elapsed,
631 });
632 }
633 }
634
635 if let Some(max_attempts) = self.task_config.max_poll_attempts {
637 if attempts >= max_attempts {
638 return Err(TaskError::MaxAttemptsExceeded {
639 task_id: task_id.to_string(),
640 attempts,
641 });
642 }
643 }
644
645 tokio::time::sleep(self.task_config.poll_duration()).await;
647 attempts += 1;
648
649 debug!(task_id = task_id, attempt = attempts, "Polling MCP task status");
650
651 let poll_result = self
654 .call_tool_with_retry(CallToolRequestParams::new("tasks/get").with_arguments(
655 serde_json::Map::from_iter([(
656 "task_id".to_string(),
657 Value::String(task_id.to_string()),
658 )]),
659 ))
660 .await
661 .map_err(|e| TaskError::PollFailed(e.to_string()))?;
662
663 let status = self.parse_task_status(&poll_result)?;
665
666 match status {
667 TaskStatus::Completed => {
668 debug!(task_id = task_id, "Task completed successfully");
669 return self.extract_task_result(&poll_result);
671 }
672 TaskStatus::Failed => {
673 let error_msg = self.extract_error_message(&poll_result);
674 return Err(TaskError::TaskFailed {
675 task_id: task_id.to_string(),
676 error: error_msg,
677 });
678 }
679 TaskStatus::Cancelled => {
680 return Err(TaskError::Cancelled(task_id.to_string()));
681 }
682 TaskStatus::Pending | TaskStatus::Running => {
683 debug!(
685 task_id = task_id,
686 status = ?status,
687 "Task still in progress"
688 );
689 }
690 }
691 }
692 }
693
694 fn parse_task_status(
696 &self,
697 result: &rmcp::model::CallToolResult,
698 ) -> std::result::Result<TaskStatus, TaskError> {
699 if let Some(ref structured) = result.structured_content {
701 if let Some(status_str) = structured.get("status").and_then(|v| v.as_str()) {
702 return match status_str {
703 "pending" => Ok(TaskStatus::Pending),
704 "running" => Ok(TaskStatus::Running),
705 "completed" => Ok(TaskStatus::Completed),
706 "failed" => Ok(TaskStatus::Failed),
707 "cancelled" => Ok(TaskStatus::Cancelled),
708 _ => {
709 warn!(status = status_str, "Unknown task status");
710 Ok(TaskStatus::Running) }
712 };
713 }
714 }
715
716 for content in &result.content {
718 if let Some(text_content) = content.deref().as_text() {
719 if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
721 if let Some(status_str) = parsed.get("status").and_then(|v| v.as_str()) {
722 return match status_str {
723 "pending" => Ok(TaskStatus::Pending),
724 "running" => Ok(TaskStatus::Running),
725 "completed" => Ok(TaskStatus::Completed),
726 "failed" => Ok(TaskStatus::Failed),
727 "cancelled" => Ok(TaskStatus::Cancelled),
728 _ => Ok(TaskStatus::Running),
729 };
730 }
731 }
732 }
733 }
734
735 Ok(TaskStatus::Running)
737 }
738
739 fn extract_task_result(
741 &self,
742 result: &rmcp::model::CallToolResult,
743 ) -> std::result::Result<Value, TaskError> {
744 if let Some(ref structured) = result.structured_content {
746 if let Some(output) = structured.get("result") {
747 return Ok(json!({ "output": output }));
748 }
749 return Ok(json!({ "output": structured }));
750 }
751
752 let mut text_parts: Vec<String> = Vec::new();
754 for content in &result.content {
755 if let Some(text_content) = content.deref().as_text() {
756 text_parts.push(text_content.text.clone());
757 }
758 }
759
760 if text_parts.is_empty() {
761 Ok(json!({ "output": null }))
762 } else {
763 Ok(json!({ "output": text_parts.join("\n") }))
764 }
765 }
766
767 fn extract_error_message(&self, result: &rmcp::model::CallToolResult) -> String {
769 if let Some(ref structured) = result.structured_content {
771 if let Some(error) = structured.get("error").and_then(|v| v.as_str()) {
772 return error.to_string();
773 }
774 }
775
776 for content in &result.content {
778 if let Some(text_content) = content.deref().as_text() {
779 return text_content.text.clone();
780 }
781 }
782
783 "Unknown error".to_string()
784 }
785
786 fn extract_task_id(
788 &self,
789 result: &rmcp::model::CallToolResult,
790 ) -> std::result::Result<String, TaskError> {
791 if let Some(ref structured) = result.structured_content {
793 if let Some(task_id) = structured.get("task_id").and_then(|v| v.as_str()) {
794 return Ok(task_id.to_string());
795 }
796 }
797
798 for content in &result.content {
800 if let Some(text_content) = content.deref().as_text() {
801 if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
802 if let Some(task_id) = parsed.get("task_id").and_then(|v| v.as_str()) {
803 return Ok(task_id.to_string());
804 }
805 }
806 }
807 }
808
809 Err(TaskError::CreateFailed("No task_id in response".to_string()))
810 }
811}
812
813#[async_trait]
814impl<S> Tool for McpTool<S>
815where
816 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
817{
818 fn name(&self) -> &str {
819 &self.name
820 }
821
822 fn description(&self) -> &str {
823 &self.description
824 }
825
826 fn is_long_running(&self) -> bool {
827 self.is_long_running
828 }
829
830 fn parameters_schema(&self) -> Option<Value> {
831 self.input_schema.clone()
832 }
833
834 fn response_schema(&self) -> Option<Value> {
835 self.output_schema.clone()
836 }
837
838 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
839 let use_task_mode = self.task_config.enable_tasks && self.is_long_running;
841
842 if use_task_mode {
843 debug!(tool = self.name, "Executing tool in task mode (long-running)");
844
845 let task_params = self.task_config.to_task_params();
847 let task_map = task_params.as_object().cloned();
848
849 let create_result = self
850 .call_tool_with_retry({
851 let mut params = CallToolRequestParams::new(self.name.clone());
852 if !(args.is_null() || args == json!({})) {
853 match args {
854 Value::Object(map) => {
855 params = params.with_arguments(map);
856 }
857 _ => {
858 return Err(AdkError::tool("Tool arguments must be an object"));
859 }
860 }
861 }
862 if let Some(task_map) = task_map {
863 params = params.with_task(task_map);
864 }
865 params
866 })
867 .await?;
868
869 let task_id = self
871 .extract_task_id(&create_result)
872 .map_err(|e| AdkError::tool(format!("Failed to get task ID: {e}")))?;
873
874 debug!(tool = self.name, task_id = task_id, "Task created, polling for completion");
875
876 let result = self
878 .poll_task(&task_id)
879 .await
880 .map_err(|e| AdkError::tool(format!("Task execution failed: {e}")))?;
881
882 return Ok(result);
883 }
884
885 let result = self
887 .call_tool_with_retry({
888 let mut params = CallToolRequestParams::new(self.name.clone());
889 if !(args.is_null() || args == json!({})) {
890 match args {
891 Value::Object(map) => {
892 params = params.with_arguments(map);
893 }
894 _ => {
895 return Err(AdkError::tool("Tool arguments must be an object"));
896 }
897 }
898 }
899 params
900 })
901 .await?;
902
903 if result.is_error.unwrap_or(false) {
905 let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
906
907 for content in &result.content {
909 if let Some(text_content) = content.deref().as_text() {
911 error_msg.push_str(": ");
912 error_msg.push_str(&text_content.text);
913 break;
914 }
915 }
916
917 return Err(AdkError::tool(error_msg));
918 }
919
920 if let Some(structured) = result.structured_content {
922 return Ok(json!({ "output": structured }));
923 }
924
925 let mut text_parts: Vec<String> = Vec::new();
927
928 for content in &result.content {
929 let raw: &RawContent = content.deref();
931 match raw {
932 RawContent::Text(text_content) => {
933 text_parts.push(text_content.text.clone());
934 }
935 RawContent::Image(image_content) => {
936 text_parts.push(format!(
938 "[Image: {} bytes, mime: {}]",
939 image_content.data.len(),
940 image_content.mime_type
941 ));
942 }
943 RawContent::Resource(resource_content) => {
944 let uri = match &resource_content.resource {
945 ResourceContents::TextResourceContents { uri, .. } => uri,
946 ResourceContents::BlobResourceContents { uri, .. } => uri,
947 };
948 text_parts.push(format!("[Resource: {}]", uri));
949 }
950 RawContent::Audio(_) => {
951 text_parts.push("[Audio content]".to_string());
952 }
953 RawContent::ResourceLink(link) => {
954 text_parts.push(format!("[ResourceLink: {}]", link.uri));
955 }
956 }
957 }
958
959 if text_parts.is_empty() {
960 return Err(AdkError::tool(format!("MCP tool '{}' returned no content", self.name)));
961 }
962
963 Ok(json!({ "output": text_parts.join("\n") }))
964 }
965}
966
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: both the tool and the toolset must be `Send + Sync`.
    #[test]
    fn mcp_tool_is_send_and_sync() {
        fn require_send_sync<T: Send + Sync>() {}

        require_send_sync::<McpTool<()>>();
        require_send_sync::<McpToolset<()>>();
    }

    /// Reconnectable errors are retried while attempts remain and a factory exists.
    #[test]
    fn test_should_retry_mcp_operation_reconnectable_errors() {
        let refresh = RefreshConfig::default().with_max_attempts(3);
        for (error, attempt) in [("EOF", 0), ("connection reset by peer", 1)] {
            assert!(should_retry_mcp_operation(error, attempt, &refresh, true));
        }
    }

    /// No retry once the attempt counter has reached the configured maximum.
    #[test]
    fn test_should_retry_mcp_operation_stops_at_max_attempts() {
        let refresh = RefreshConfig::default().with_max_attempts(2);
        let retry = should_retry_mcp_operation("EOF", 2, &refresh, true);
        assert!(!retry);
    }

    /// No retry without a connection factory, even for a reconnectable error.
    #[test]
    fn test_should_retry_mcp_operation_requires_factory() {
        let refresh = RefreshConfig::default().with_max_attempts(3);
        let retry = should_retry_mcp_operation("EOF", 0, &refresh, false);
        assert!(!retry);
    }

    /// Errors that are not connection-related are never retried.
    #[test]
    fn test_should_retry_mcp_operation_non_reconnectable_error() {
        let refresh = RefreshConfig::default().with_max_attempts(3);
        let retry = should_retry_mcp_operation("invalid arguments for tool", 0, &refresh, true);
        assert!(!retry);
    }
}