1use super::task::{McpTaskConfig, TaskError, TaskStatus};
10use super::{ConnectionFactory, RefreshConfig, should_refresh_connection};
11use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
12use async_trait::async_trait;
13use rmcp::{
14 RoleClient,
15 model::{CallToolRequestParams, RawContent, ResourceContents},
16 service::RunningService,
17};
18use serde_json::{Value, json};
19use std::ops::Deref;
20use std::sync::Arc;
21use std::time::Instant;
22use tokio::sync::Mutex;
23use tracing::{debug, warn};
24
25type DynConnectionFactory<S> = Arc<dyn ConnectionFactory<S>>;
27
28pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Send + Sync>;
30
31fn sanitize_schema(value: &mut Value) {
35 if let Value::Object(map) = value {
36 map.remove("$schema");
37 map.remove("definitions");
38 map.remove("$ref");
39 map.remove("additionalProperties");
40
41 for (_, v) in map.iter_mut() {
42 sanitize_schema(v);
43 }
44 } else if let Value::Array(arr) = value {
45 for v in arr.iter_mut() {
46 sanitize_schema(v);
47 }
48 }
49}
50
51fn should_retry_mcp_operation(
52 error: &str,
53 attempt: u32,
54 refresh_config: &RefreshConfig,
55 has_connection_factory: bool,
56) -> bool {
57 has_connection_factory
58 && attempt < refresh_config.max_attempts
59 && should_refresh_connection(error)
60}
61
62pub struct McpToolset<S = ()>
94where
95 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
96{
97 client: Arc<Mutex<RunningService<RoleClient, S>>>,
99 tool_filter: Option<ToolFilter>,
101 name: String,
103 task_config: McpTaskConfig,
105 connection_factory: Option<DynConnectionFactory<S>>,
107 refresh_config: RefreshConfig,
109}
110
111impl<S> McpToolset<S>
112where
113 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
114{
115 pub fn new(client: RunningService<RoleClient, S>) -> Self {
133 Self {
134 client: Arc::new(Mutex::new(client)),
135 tool_filter: None,
136 name: "mcp_toolset".to_string(),
137 task_config: McpTaskConfig::default(),
138 connection_factory: None,
139 refresh_config: RefreshConfig::default(),
140 }
141 }
142
143 pub fn with_name(mut self, name: impl Into<String>) -> Self {
145 self.name = name.into();
146 self
147 }
148
149 pub fn with_task_support(mut self, config: McpTaskConfig) -> Self {
163 self.task_config = config;
164 self
165 }
166
167 pub fn with_connection_factory<F>(mut self, factory: Arc<F>) -> Self
169 where
170 F: ConnectionFactory<S> + 'static,
171 {
172 self.connection_factory = Some(factory);
173 self
174 }
175
176 pub fn with_refresh_config(mut self, config: RefreshConfig) -> Self {
178 self.refresh_config = config;
179 self
180 }
181
182 pub fn with_filter<F>(mut self, filter: F) -> Self
196 where
197 F: Fn(&str) -> bool + Send + Sync + 'static,
198 {
199 self.tool_filter = Some(Arc::new(filter));
200 self
201 }
202
203 pub fn with_tools(self, tool_names: &[&str]) -> Self {
212 let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
213 self.with_filter(move |name| names.iter().any(|n| n == name))
214 }
215
216 pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
233 let client = self.client.lock().await;
234 client.cancellation_token()
235 }
236
237 async fn try_refresh_connection(&self) -> Result<bool> {
238 let Some(factory) = self.connection_factory.clone() else {
239 return Ok(false);
240 };
241
242 let new_client = factory
243 .create_connection()
244 .await
245 .map_err(|e| AdkError::Tool(format!("Failed to refresh MCP connection: {}", e)))?;
246
247 let mut client = self.client.lock().await;
248 let old_token = client.cancellation_token();
249 old_token.cancel();
250 *client = new_client;
251 Ok(true)
252 }
253}
254
255#[async_trait]
256impl<S> Toolset for McpToolset<S>
257where
258 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
259{
260 fn name(&self) -> &str {
261 &self.name
262 }
263
264 async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
265 let mut attempt = 0u32;
266 let has_connection_factory = self.connection_factory.is_some();
267 let mcp_tools = loop {
268 let list_result = {
269 let client = self.client.lock().await;
270 client.list_all_tools().await.map_err(|e| e.to_string())
271 };
272
273 match list_result {
274 Ok(tools) => break tools,
275 Err(error) => {
276 if !should_retry_mcp_operation(
277 &error,
278 attempt,
279 &self.refresh_config,
280 has_connection_factory,
281 ) {
282 return Err(AdkError::Tool(format!("Failed to list MCP tools: {}", error)));
283 }
284
285 let retry_attempt = attempt + 1;
286 if self.refresh_config.log_reconnections {
287 warn!(
288 attempt = retry_attempt,
289 max_attempts = self.refresh_config.max_attempts,
290 error = %error,
291 "MCP list_all_tools failed; reconnecting and retrying"
292 );
293 }
294
295 if self.refresh_config.retry_delay_ms > 0 {
296 tokio::time::sleep(tokio::time::Duration::from_millis(
297 self.refresh_config.retry_delay_ms,
298 ))
299 .await;
300 }
301
302 if !self.try_refresh_connection().await? {
303 return Err(AdkError::Tool(format!("Failed to list MCP tools: {}", error)));
304 }
305 attempt += 1;
306 }
307 }
308 };
309
310 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
312
313 for mcp_tool in mcp_tools {
314 let tool_name = mcp_tool.name.to_string();
315
316 if let Some(ref filter) = self.tool_filter {
318 if !filter(&tool_name) {
319 continue;
320 }
321 }
322
323 let adk_tool = McpTool {
324 name: tool_name,
325 description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
326 input_schema: {
327 let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
328 sanitize_schema(&mut schema);
329 Some(schema)
330 },
331 output_schema: mcp_tool.output_schema.map(|s| {
332 let mut schema = Value::Object(s.as_ref().clone());
333 sanitize_schema(&mut schema);
334 schema
335 }),
336 client: self.client.clone(),
337 connection_factory: self.connection_factory.clone(),
338 refresh_config: self.refresh_config.clone(),
339 is_long_running: false, task_config: self.task_config.clone(),
341 };
342
343 tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
344 }
345
346 Ok(tools)
347 }
348}
349
350struct McpTool<S>
354where
355 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
356{
357 name: String,
358 description: String,
359 input_schema: Option<Value>,
360 output_schema: Option<Value>,
361 client: Arc<Mutex<RunningService<RoleClient, S>>>,
362 connection_factory: Option<DynConnectionFactory<S>>,
363 refresh_config: RefreshConfig,
364 is_long_running: bool,
366 task_config: McpTaskConfig,
368}
369
370impl<S> McpTool<S>
371where
372 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
373{
374 async fn try_refresh_connection(&self) -> Result<bool> {
375 let Some(factory) = self.connection_factory.clone() else {
376 return Ok(false);
377 };
378
379 let new_client = factory
380 .create_connection()
381 .await
382 .map_err(|e| AdkError::Tool(format!("Failed to refresh MCP connection: {}", e)))?;
383
384 let mut client = self.client.lock().await;
385 let old_token = client.cancellation_token();
386 old_token.cancel();
387 *client = new_client;
388 Ok(true)
389 }
390
391 async fn call_tool_with_retry(
392 &self,
393 params: CallToolRequestParams,
394 ) -> Result<rmcp::model::CallToolResult> {
395 let has_connection_factory = self.connection_factory.is_some();
396 let mut attempt = 0u32;
397
398 loop {
399 let call_result = {
400 let client = self.client.lock().await;
401 client.call_tool(params.clone()).await.map_err(|e| e.to_string())
402 };
403
404 match call_result {
405 Ok(result) => return Ok(result),
406 Err(error) => {
407 if !should_retry_mcp_operation(
408 &error,
409 attempt,
410 &self.refresh_config,
411 has_connection_factory,
412 ) {
413 return Err(AdkError::Tool(format!(
414 "Failed to call MCP tool '{}': {}",
415 self.name, error
416 )));
417 }
418
419 let retry_attempt = attempt + 1;
420 if self.refresh_config.log_reconnections {
421 warn!(
422 tool = %self.name,
423 attempt = retry_attempt,
424 max_attempts = self.refresh_config.max_attempts,
425 error = %error,
426 "MCP call_tool failed; reconnecting and retrying"
427 );
428 }
429
430 if self.refresh_config.retry_delay_ms > 0 {
431 tokio::time::sleep(tokio::time::Duration::from_millis(
432 self.refresh_config.retry_delay_ms,
433 ))
434 .await;
435 }
436
437 if !self.try_refresh_connection().await? {
438 return Err(AdkError::Tool(format!(
439 "Failed to call MCP tool '{}': {}",
440 self.name, error
441 )));
442 }
443 attempt += 1;
444 }
445 }
446 }
447 }
448
449 async fn poll_task(&self, task_id: &str) -> std::result::Result<Value, TaskError> {
451 let start = Instant::now();
452 let mut attempts = 0u32;
453
454 loop {
455 if let Some(timeout_ms) = self.task_config.timeout_ms {
457 let elapsed = start.elapsed().as_millis() as u64;
458 if elapsed >= timeout_ms {
459 return Err(TaskError::Timeout {
460 task_id: task_id.to_string(),
461 elapsed_ms: elapsed,
462 });
463 }
464 }
465
466 if let Some(max_attempts) = self.task_config.max_poll_attempts {
468 if attempts >= max_attempts {
469 return Err(TaskError::MaxAttemptsExceeded {
470 task_id: task_id.to_string(),
471 attempts,
472 });
473 }
474 }
475
476 tokio::time::sleep(self.task_config.poll_duration()).await;
478 attempts += 1;
479
480 debug!(task_id = task_id, attempt = attempts, "Polling MCP task status");
481
482 let poll_result = self
485 .call_tool_with_retry(CallToolRequestParams {
486 name: "tasks/get".into(),
487 arguments: Some(serde_json::Map::from_iter([(
488 "task_id".to_string(),
489 Value::String(task_id.to_string()),
490 )])),
491 task: None,
492 meta: None,
493 })
494 .await
495 .map_err(|e| TaskError::PollFailed(e.to_string()))?;
496
497 let status = self.parse_task_status(&poll_result)?;
499
500 match status {
501 TaskStatus::Completed => {
502 debug!(task_id = task_id, "Task completed successfully");
503 return self.extract_task_result(&poll_result);
505 }
506 TaskStatus::Failed => {
507 let error_msg = self.extract_error_message(&poll_result);
508 return Err(TaskError::TaskFailed {
509 task_id: task_id.to_string(),
510 error: error_msg,
511 });
512 }
513 TaskStatus::Cancelled => {
514 return Err(TaskError::Cancelled(task_id.to_string()));
515 }
516 TaskStatus::Pending | TaskStatus::Running => {
517 debug!(
519 task_id = task_id,
520 status = ?status,
521 "Task still in progress"
522 );
523 }
524 }
525 }
526 }
527
528 fn parse_task_status(
530 &self,
531 result: &rmcp::model::CallToolResult,
532 ) -> std::result::Result<TaskStatus, TaskError> {
533 if let Some(ref structured) = result.structured_content {
535 if let Some(status_str) = structured.get("status").and_then(|v| v.as_str()) {
536 return match status_str {
537 "pending" => Ok(TaskStatus::Pending),
538 "running" => Ok(TaskStatus::Running),
539 "completed" => Ok(TaskStatus::Completed),
540 "failed" => Ok(TaskStatus::Failed),
541 "cancelled" => Ok(TaskStatus::Cancelled),
542 _ => {
543 warn!(status = status_str, "Unknown task status");
544 Ok(TaskStatus::Running) }
546 };
547 }
548 }
549
550 for content in &result.content {
552 if let Some(text_content) = content.deref().as_text() {
553 if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
555 if let Some(status_str) = parsed.get("status").and_then(|v| v.as_str()) {
556 return match status_str {
557 "pending" => Ok(TaskStatus::Pending),
558 "running" => Ok(TaskStatus::Running),
559 "completed" => Ok(TaskStatus::Completed),
560 "failed" => Ok(TaskStatus::Failed),
561 "cancelled" => Ok(TaskStatus::Cancelled),
562 _ => Ok(TaskStatus::Running),
563 };
564 }
565 }
566 }
567 }
568
569 Ok(TaskStatus::Running)
571 }
572
573 fn extract_task_result(
575 &self,
576 result: &rmcp::model::CallToolResult,
577 ) -> std::result::Result<Value, TaskError> {
578 if let Some(ref structured) = result.structured_content {
580 if let Some(output) = structured.get("result") {
581 return Ok(json!({ "output": output }));
582 }
583 return Ok(json!({ "output": structured }));
584 }
585
586 let mut text_parts: Vec<String> = Vec::new();
588 for content in &result.content {
589 if let Some(text_content) = content.deref().as_text() {
590 text_parts.push(text_content.text.clone());
591 }
592 }
593
594 if text_parts.is_empty() {
595 Ok(json!({ "output": null }))
596 } else {
597 Ok(json!({ "output": text_parts.join("\n") }))
598 }
599 }
600
601 fn extract_error_message(&self, result: &rmcp::model::CallToolResult) -> String {
603 if let Some(ref structured) = result.structured_content {
605 if let Some(error) = structured.get("error").and_then(|v| v.as_str()) {
606 return error.to_string();
607 }
608 }
609
610 for content in &result.content {
612 if let Some(text_content) = content.deref().as_text() {
613 return text_content.text.clone();
614 }
615 }
616
617 "Unknown error".to_string()
618 }
619
620 fn extract_task_id(
622 &self,
623 result: &rmcp::model::CallToolResult,
624 ) -> std::result::Result<String, TaskError> {
625 if let Some(ref structured) = result.structured_content {
627 if let Some(task_id) = structured.get("task_id").and_then(|v| v.as_str()) {
628 return Ok(task_id.to_string());
629 }
630 }
631
632 for content in &result.content {
634 if let Some(text_content) = content.deref().as_text() {
635 if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
636 if let Some(task_id) = parsed.get("task_id").and_then(|v| v.as_str()) {
637 return Ok(task_id.to_string());
638 }
639 }
640 }
641 }
642
643 Err(TaskError::CreateFailed("No task_id in response".to_string()))
644 }
645}
646
647#[async_trait]
648impl<S> Tool for McpTool<S>
649where
650 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
651{
652 fn name(&self) -> &str {
653 &self.name
654 }
655
656 fn description(&self) -> &str {
657 &self.description
658 }
659
660 fn is_long_running(&self) -> bool {
661 self.is_long_running
662 }
663
664 fn parameters_schema(&self) -> Option<Value> {
665 self.input_schema.clone()
666 }
667
668 fn response_schema(&self) -> Option<Value> {
669 self.output_schema.clone()
670 }
671
672 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
673 let use_task_mode = self.task_config.enable_tasks && self.is_long_running;
675
676 if use_task_mode {
677 debug!(tool = self.name, "Executing tool in task mode (long-running)");
678
679 let task_params = self.task_config.to_task_params();
681 let task_map = task_params.as_object().cloned();
682
683 let create_result = self
684 .call_tool_with_retry(CallToolRequestParams {
685 name: self.name.clone().into(),
686 arguments: if args.is_null() || args == json!({}) {
687 None
688 } else {
689 match args {
690 Value::Object(map) => Some(map),
691 _ => {
692 return Err(AdkError::Tool(
693 "Tool arguments must be an object".to_string(),
694 ));
695 }
696 }
697 },
698 task: task_map,
699 meta: None,
700 })
701 .await?;
702
703 let task_id = self
705 .extract_task_id(&create_result)
706 .map_err(|e| AdkError::Tool(format!("Failed to get task ID: {}", e)))?;
707
708 debug!(tool = self.name, task_id = task_id, "Task created, polling for completion");
709
710 let result = self
712 .poll_task(&task_id)
713 .await
714 .map_err(|e| AdkError::Tool(format!("Task execution failed: {}", e)))?;
715
716 return Ok(result);
717 }
718
719 let result = self
721 .call_tool_with_retry(CallToolRequestParams {
722 name: self.name.clone().into(),
723 arguments: if args.is_null() || args == json!({}) {
724 None
725 } else {
726 match args {
728 Value::Object(map) => Some(map),
729 _ => {
730 return Err(AdkError::Tool(
731 "Tool arguments must be an object".to_string(),
732 ));
733 }
734 }
735 },
736 task: None,
737 meta: None,
738 })
739 .await?;
740
741 if result.is_error.unwrap_or(false) {
743 let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
744
745 for content in &result.content {
747 if let Some(text_content) = content.deref().as_text() {
749 error_msg.push_str(": ");
750 error_msg.push_str(&text_content.text);
751 break;
752 }
753 }
754
755 return Err(AdkError::Tool(error_msg));
756 }
757
758 if let Some(structured) = result.structured_content {
760 return Ok(json!({ "output": structured }));
761 }
762
763 let mut text_parts: Vec<String> = Vec::new();
765
766 for content in &result.content {
767 let raw: &RawContent = content.deref();
769 match raw {
770 RawContent::Text(text_content) => {
771 text_parts.push(text_content.text.clone());
772 }
773 RawContent::Image(image_content) => {
774 text_parts.push(format!(
776 "[Image: {} bytes, mime: {}]",
777 image_content.data.len(),
778 image_content.mime_type
779 ));
780 }
781 RawContent::Resource(resource_content) => {
782 let uri = match &resource_content.resource {
783 ResourceContents::TextResourceContents { uri, .. } => uri,
784 ResourceContents::BlobResourceContents { uri, .. } => uri,
785 };
786 text_parts.push(format!("[Resource: {}]", uri));
787 }
788 RawContent::Audio(_) => {
789 text_parts.push("[Audio content]".to_string());
790 }
791 RawContent::ResourceLink(link) => {
792 text_parts.push(format!("[ResourceLink: {}]", link.uri));
793 }
794 }
795 }
796
797 if text_parts.is_empty() {
798 return Err(AdkError::Tool(format!("MCP tool '{}' returned no content", self.name)));
799 }
800
801 Ok(json!({ "output": text_parts.join("\n") }))
802 }
803}
804
805unsafe impl<S> Send for McpTool<S> where
807 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
808{
809}
810unsafe impl<S> Sync for McpTool<S> where
811 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
812{
813}
814
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection-level errors are retried while the attempt budget lasts.
    #[test]
    fn test_should_retry_mcp_operation_reconnectable_errors() {
        let config = RefreshConfig::default().with_max_attempts(3);
        let cases = [("EOF", 0), ("connection reset by peer", 1)];
        for (error, attempt) in cases {
            assert!(
                should_retry_mcp_operation(error, attempt, &config, true),
                "expected a retry for {error:?} at attempt {attempt}"
            );
        }
    }

    /// Once `attempt` reaches `max_attempts`, no further retry is allowed.
    #[test]
    fn test_should_retry_mcp_operation_stops_at_max_attempts() {
        let config = RefreshConfig::default().with_max_attempts(2);
        let retry = should_retry_mcp_operation("EOF", 2, &config, true);
        assert!(!retry);
    }

    /// Without a connection factory there is nothing to reconnect with.
    #[test]
    fn test_should_retry_mcp_operation_requires_factory() {
        let config = RefreshConfig::default().with_max_attempts(3);
        let retry = should_retry_mcp_operation("EOF", 0, &config, false);
        assert!(!retry);
    }

    /// Application-level errors are not worth a reconnect.
    #[test]
    fn test_should_retry_mcp_operation_non_reconnectable_error() {
        let config = RefreshConfig::default().with_max_attempts(3);
        let retry = should_retry_mcp_operation("invalid arguments for tool", 0, &config, true);
        assert!(!retry);
    }
}