Skip to main content

adk_tool/mcp/
toolset.rs

1// MCP (Model Context Protocol) Toolset Integration
2//
3// Based on Go implementation: adk-go/tool/mcptoolset/
4// Uses official Rust SDK: https://github.com/modelcontextprotocol/rust-sdk
5//
6// The McpToolset connects to an MCP server, discovers available tools,
7// and exposes them as ADK-compatible tools for use with LlmAgent.
8
9use super::task::{McpTaskConfig, TaskError, TaskStatus};
10use super::{ConnectionFactory, RefreshConfig, should_refresh_connection};
11use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
12use async_trait::async_trait;
13use rmcp::{
14    RoleClient,
15    model::{
16        CallToolRequestParams, ErrorCode, RawContent, ReadResourceRequestParams, Resource,
17        ResourceContents, ResourceTemplate,
18    },
19    service::RunningService,
20};
21use serde_json::{Value, json};
22use std::ops::Deref;
23use std::sync::Arc;
24use std::time::Instant;
25use tokio::sync::Mutex;
26use tracing::{debug, warn};
27
28/// Shared factory object used to recreate MCP connections for refresh/retry.
29type DynConnectionFactory<S> = Arc<dyn ConnectionFactory<S>>;
30
/// Shared predicate deciding, by tool name, whether an MCP tool is exposed.
///
/// Returning `true` keeps the tool; returning `false` hides it from the agent.
pub type ToolFilter = Arc<dyn for<'a> Fn(&'a str) -> bool + Send + Sync>;
33
34/// Sanitize JSON schema for LLM compatibility.
35/// Removes fields like `$schema`, `additionalProperties`, `definitions`, `$ref`
36/// that some LLM APIs (like Gemini) don't accept.
37fn sanitize_schema(value: &mut Value) {
38    if let Value::Object(map) = value {
39        map.remove("$schema");
40        map.remove("definitions");
41        map.remove("$ref");
42        map.remove("additionalProperties");
43
44        for (_, v) in map.iter_mut() {
45            sanitize_schema(v);
46        }
47    } else if let Value::Array(arr) = value {
48        for v in arr.iter_mut() {
49            sanitize_schema(v);
50        }
51    }
52}
53
54fn should_retry_mcp_operation(
55    error: &str,
56    attempt: u32,
57    refresh_config: &RefreshConfig,
58    has_connection_factory: bool,
59) -> bool {
60    has_connection_factory
61        && attempt < refresh_config.max_attempts
62        && should_refresh_connection(error)
63}
64
65/// Returns `true` when the `ServiceError` wraps an MCP `MethodNotFound` (-32601)
66/// JSON-RPC error, indicating the server does not implement the requested method.
67fn is_method_not_found(err: &rmcp::ServiceError) -> bool {
68    matches!(
69        err,
70        rmcp::ServiceError::McpError(e) if e.code == ErrorCode::METHOD_NOT_FOUND
71    )
72}
73
74/// MCP Toolset - connects to an MCP server and exposes its tools as ADK tools.
75///
76/// This toolset implements the ADK `Toolset` trait and bridges the gap between
77/// MCP servers and ADK agents. It:
78/// 1. Connects to an MCP server via the provided transport
79/// 2. Discovers available tools from the server
80/// 3. Converts MCP tools to ADK-compatible `Tool` implementations
81/// 4. Proxies tool execution calls to the MCP server
82///
83/// # Example
84///
85/// ```rust,ignore
86/// use adk_tool::McpToolset;
87/// use rmcp::{ServiceExt, transport::TokioChildProcess};
88/// use tokio::process::Command;
89///
90/// // Create MCP client connection to a local server
91/// let client = ().serve(TokioChildProcess::new(
92///     Command::new("npx")
93///         .arg("-y")
94///         .arg("@modelcontextprotocol/server-everything")
95/// )?).await?;
96///
97/// // Create toolset from the client
98/// let toolset = McpToolset::new(client);
99///
100/// // Add to agent
101/// let agent = LlmAgentBuilder::new("assistant")
102///     .toolset(Arc::new(toolset))
103///     .build()?;
104/// ```
105pub struct McpToolset<S = ()>
106where
107    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
108{
109    /// The running MCP client service
110    client: Arc<Mutex<RunningService<RoleClient, S>>>,
111    /// Optional filter to select which tools to expose
112    tool_filter: Option<ToolFilter>,
113    /// Name of this toolset
114    name: String,
115    /// Task configuration for long-running operations
116    task_config: McpTaskConfig,
117    /// Optional connection factory used for reconnection on transport failures.
118    connection_factory: Option<DynConnectionFactory<S>>,
119    /// Reconnection/retry configuration.
120    refresh_config: RefreshConfig,
121}
122
123impl<S> McpToolset<S>
124where
125    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
126{
127    /// Create a new MCP toolset from a running MCP client service.
128    ///
129    /// The client should already be connected and initialized.
130    /// Use `rmcp::ServiceExt::serve()` to create the client.
131    ///
132    /// # Example
133    ///
134    /// ```rust,ignore
135    /// use rmcp::{ServiceExt, transport::TokioChildProcess};
136    /// use tokio::process::Command;
137    ///
138    /// let client = ().serve(TokioChildProcess::new(
139    ///     Command::new("my-mcp-server")
140    /// )?).await?;
141    ///
142    /// let toolset = McpToolset::new(client);
143    /// ```
144    pub fn new(client: RunningService<RoleClient, S>) -> Self {
145        Self {
146            client: Arc::new(Mutex::new(client)),
147            tool_filter: None,
148            name: "mcp_toolset".to_string(),
149            task_config: McpTaskConfig::default(),
150            connection_factory: None,
151            refresh_config: RefreshConfig::default(),
152        }
153    }
154
155    /// Create a McpToolset from a RunningService with a custom ClientHandler.
156    ///
157    /// This is functionally identical to `new()` but makes the intent explicit
158    /// when using a custom `ClientHandler` type.
159    ///
160    /// # Example
161    ///
162    /// ```rust,ignore
163    /// use rmcp::ServiceExt;
164    /// use adk_tool::McpToolset;
165    ///
166    /// let client = my_custom_handler.serve(transport).await?;
167    /// let toolset = McpToolset::with_client_handler(client);
168    /// ```
169    pub fn with_client_handler(client: RunningService<RoleClient, S>) -> Self {
170        Self::new(client)
171    }
172
173    /// Set a custom name for this toolset.
174    pub fn with_name(mut self, name: impl Into<String>) -> Self {
175        self.name = name.into();
176        self
177    }
178
179    /// Enable task support for long-running operations.
180    ///
181    /// When enabled, tools marked as `is_long_running()` will use MCP's
182    /// async task lifecycle (SEP-1686) instead of blocking calls.
183    ///
184    /// # Example
185    ///
186    /// ```rust,ignore
187    /// let toolset = McpToolset::new(client)
188    ///     .with_task_support(McpTaskConfig::enabled()
189    ///         .poll_interval(Duration::from_secs(2))
190    ///         .timeout(Duration::from_secs(300)));
191    /// ```
192    pub fn with_task_support(mut self, config: McpTaskConfig) -> Self {
193        self.task_config = config;
194        self
195    }
196
197    /// Provide a connection factory to enable automatic MCP reconnection.
198    pub fn with_connection_factory<F>(mut self, factory: Arc<F>) -> Self
199    where
200        F: ConnectionFactory<S> + 'static,
201    {
202        self.connection_factory = Some(factory);
203        self
204    }
205
206    /// Configure MCP reconnect/retry behavior.
207    pub fn with_refresh_config(mut self, config: RefreshConfig) -> Self {
208        self.refresh_config = config;
209        self
210    }
211
212    /// Add a filter to select which tools to expose.
213    ///
214    /// The filter function receives a tool name and returns true if the tool
215    /// should be included.
216    ///
217    /// # Example
218    ///
219    /// ```rust,ignore
220    /// let toolset = McpToolset::new(client)
221    ///     .with_filter(|name| {
222    ///         matches!(name, "read_file" | "list_directory" | "search_files")
223    ///     });
224    /// ```
225    pub fn with_filter<F>(mut self, filter: F) -> Self
226    where
227        F: Fn(&str) -> bool + Send + Sync + 'static,
228    {
229        self.tool_filter = Some(Arc::new(filter));
230        self
231    }
232
233    /// Add a filter that only includes tools with the specified names.
234    ///
235    /// # Example
236    ///
237    /// ```rust,ignore
238    /// let toolset = McpToolset::new(client)
239    ///     .with_tools(&["read_file", "write_file"]);
240    /// ```
241    pub fn with_tools(self, tool_names: &[&str]) -> Self {
242        let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
243        self.with_filter(move |name| names.iter().any(|n| n == name))
244    }
245
246    /// Get a cancellation token that can be used to shutdown the MCP server.
247    ///
248    /// Call `cancel()` on the returned token to cleanly shutdown the MCP server.
249    /// This should be called before exiting to avoid EPIPE errors.
250    ///
251    /// # Example
252    ///
253    /// ```rust,ignore
254    /// let toolset = McpToolset::new(client);
255    /// let cancel_token = toolset.cancellation_token().await;
256    ///
257    /// // ... use the toolset ...
258    ///
259    /// // Before exiting:
260    /// cancel_token.cancel();
261    /// ```
262    pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
263        let client = self.client.lock().await;
264        client.cancellation_token()
265    }
266
267    async fn try_refresh_connection(&self) -> Result<bool> {
268        let Some(factory) = self.connection_factory.clone() else {
269            return Ok(false);
270        };
271
272        let new_client = factory
273            .create_connection()
274            .await
275            .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
276
277        let mut client = self.client.lock().await;
278        let old_token = client.cancellation_token();
279        old_token.cancel();
280        *client = new_client;
281        Ok(true)
282    }
283
284    /// List static resources from the connected MCP server.
285    ///
286    /// Returns the list of resources advertised by the server via the
287    /// `resources/list` protocol method. Returns an empty `Vec` when the
288    /// server does not support resources (i.e. responds with
289    /// `MethodNotFound`).
290    ///
291    /// # Errors
292    ///
293    /// Returns `AdkError::Tool` on transport or unexpected server errors.
294    pub async fn list_resources(&self) -> Result<Vec<Resource>> {
295        let client = self.client.lock().await;
296        match client.list_all_resources().await {
297            Ok(resources) => Ok(resources),
298            Err(e) => {
299                if is_method_not_found(&e) {
300                    Ok(vec![])
301                } else {
302                    Err(AdkError::tool(format!("Failed to list MCP resources: {e}")))
303                }
304            }
305        }
306    }
307
308    /// List URI template resources from the connected MCP server.
309    ///
310    /// Returns the list of resource templates advertised by the server via
311    /// the `resourceTemplates/list` protocol method. Returns an empty `Vec`
312    /// when the server does not support resource templates (i.e. responds
313    /// with `MethodNotFound`).
314    ///
315    /// # Errors
316    ///
317    /// Returns `AdkError::Tool` on transport or unexpected server errors.
318    pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>> {
319        let client = self.client.lock().await;
320        match client.list_all_resource_templates().await {
321            Ok(templates) => Ok(templates),
322            Err(e) => {
323                if is_method_not_found(&e) {
324                    Ok(vec![])
325                } else {
326                    Err(AdkError::tool(format!("Failed to list MCP resource templates: {e}")))
327                }
328            }
329        }
330    }
331
332    /// Read a resource by URI from the connected MCP server.
333    ///
334    /// Delegates to the `resources/read` protocol method. Returns the
335    /// resource contents on success.
336    ///
337    /// # Errors
338    ///
339    /// Returns `AdkError::Tool("resource not found: {uri}")` when the URI
340    /// does not match any resource on the server. Returns a generic
341    /// `AdkError::Tool` on transport or other server errors.
342    pub async fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContents>> {
343        let client = self.client.lock().await;
344        let params = ReadResourceRequestParams::new(uri.to_string());
345        match client.read_resource(params).await {
346            Ok(result) => Ok(result.contents),
347            Err(e) => {
348                if is_method_not_found(&e) {
349                    Err(AdkError::tool(format!("resource not found: {uri}")))
350                } else {
351                    Err(AdkError::tool(format!("Failed to read MCP resource '{uri}': {e}")))
352                }
353            }
354        }
355    }
356}
357
358#[async_trait]
359impl<S> Toolset for McpToolset<S>
360where
361    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
362{
363    fn name(&self) -> &str {
364        &self.name
365    }
366
367    async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
368        let mut attempt = 0u32;
369        let has_connection_factory = self.connection_factory.is_some();
370        let mcp_tools = loop {
371            let list_result = {
372                let client = self.client.lock().await;
373                client.list_all_tools().await.map_err(|e| e.to_string())
374            };
375
376            match list_result {
377                Ok(tools) => break tools,
378                Err(error) => {
379                    if !should_retry_mcp_operation(
380                        &error,
381                        attempt,
382                        &self.refresh_config,
383                        has_connection_factory,
384                    ) {
385                        return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
386                    }
387
388                    let retry_attempt = attempt + 1;
389                    if self.refresh_config.log_reconnections {
390                        warn!(
391                            attempt = retry_attempt,
392                            max_attempts = self.refresh_config.max_attempts,
393                            error = %error,
394                            "MCP list_all_tools failed; reconnecting and retrying"
395                        );
396                    }
397
398                    if self.refresh_config.retry_delay_ms > 0 {
399                        tokio::time::sleep(tokio::time::Duration::from_millis(
400                            self.refresh_config.retry_delay_ms,
401                        ))
402                        .await;
403                    }
404
405                    if !self.try_refresh_connection().await? {
406                        return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
407                    }
408                    attempt += 1;
409                }
410            }
411        };
412
413        // Convert MCP tools to ADK tools
414        let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
415
416        for mcp_tool in mcp_tools {
417            let tool_name = mcp_tool.name.to_string();
418
419            // Apply filter if present
420            if let Some(ref filter) = self.tool_filter {
421                if !filter(&tool_name) {
422                    continue;
423                }
424            }
425
426            let adk_tool = McpTool {
427                name: tool_name,
428                description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
429                input_schema: {
430                    let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
431                    sanitize_schema(&mut schema);
432                    Some(schema)
433                },
434                output_schema: mcp_tool.output_schema.map(|s| {
435                    let mut schema = Value::Object(s.as_ref().clone());
436                    sanitize_schema(&mut schema);
437                    schema
438                }),
439                client: self.client.clone(),
440                connection_factory: self.connection_factory.clone(),
441                refresh_config: self.refresh_config.clone(),
442                // MCP ToolAnnotations (read_only_hint, destructive_hint, etc.)
443                // do not include a "long_running" hint. When task support is
444                // enabled on this toolset, treat non-read-only open-world tools
445                // as potentially long-running so the task lifecycle activates.
446                is_long_running: self.task_config.enable_tasks
447                    && mcp_tool.annotations.as_ref().is_some_and(|a| {
448                        a.read_only_hint != Some(true) && a.open_world_hint != Some(false)
449                    }),
450                task_config: self.task_config.clone(),
451            };
452
453            tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
454        }
455
456        Ok(tools)
457    }
458}
459
460impl McpToolset<super::elicitation::AdkClientHandler> {
461    /// Create a McpToolset with elicitation support from a transport.
462    ///
463    /// This creates the MCP client using `AdkClientHandler`, which advertises
464    /// elicitation capabilities and delegates requests to the provided handler.
465    ///
466    /// # Example
467    ///
468    /// ```rust,ignore
469    /// use adk_tool::{McpToolset, ElicitationHandler, AutoDeclineElicitationHandler};
470    /// use rmcp::transport::TokioChildProcess;
471    /// use tokio::process::Command;
472    /// use std::sync::Arc;
473    ///
474    /// let transport = TokioChildProcess::new(Command::new("my-mcp-server"))?;
475    /// let handler = Arc::new(AutoDeclineElicitationHandler);
476    /// let toolset = McpToolset::with_elicitation_handler(transport, handler).await?;
477    /// ```
478    ///
479    /// # ConnectionFactory with Elicitation
480    ///
481    /// To preserve elicitation across reconnections, clone the `Arc<dyn ElicitationHandler>`
482    /// into your `ConnectionFactory` implementation:
483    ///
484    /// ```rust,ignore
485    /// use adk_tool::{McpToolset, ElicitationHandler};
486    /// use adk_tool::mcp::ConnectionFactory;
487    /// use rmcp::{ServiceExt, service::{RoleClient, RunningService}};
488    /// use rmcp::transport::TokioChildProcess;
489    /// use tokio::process::Command;
490    /// use std::sync::Arc;
491    ///
492    /// struct MyReconnectFactory {
493    ///     handler: Arc<dyn ElicitationHandler>,
494    ///     server_command: String,
495    /// }
496    ///
497    /// // The factory creates a fresh AdkClientHandler on each reconnection,
498    /// // so the new connection advertises elicitation capabilities.
499    /// // The ConnectionFactory trait itself is unchanged.
500    /// ```
501    pub async fn with_elicitation_handler<T, E, A>(
502        transport: T,
503        handler: std::sync::Arc<dyn super::elicitation::ElicitationHandler>,
504    ) -> Result<Self>
505    where
506        T: rmcp::transport::IntoTransport<rmcp::RoleClient, E, A> + Send + 'static,
507        E: std::error::Error + Send + Sync + 'static,
508    {
509        use rmcp::ServiceExt;
510        let adk_handler = super::elicitation::AdkClientHandler::new(handler);
511        let client = adk_handler
512            .serve(transport)
513            .await
514            .map_err(|e| AdkError::tool(format!("failed to connect MCP server: {e}")))?;
515        Ok(Self::new(client))
516    }
517}
518
519/// Individual MCP tool wrapper that implements the ADK `Tool` trait.
520///
521/// This struct wraps an MCP tool and proxies execution calls to the MCP server.
522struct McpTool<S>
523where
524    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
525{
526    name: String,
527    description: String,
528    input_schema: Option<Value>,
529    output_schema: Option<Value>,
530    client: Arc<Mutex<RunningService<RoleClient, S>>>,
531    connection_factory: Option<DynConnectionFactory<S>>,
532    refresh_config: RefreshConfig,
533    /// Whether this tool is long-running (from MCP tool metadata)
534    is_long_running: bool,
535    /// Task configuration
536    task_config: McpTaskConfig,
537}
538
539impl<S> McpTool<S>
540where
541    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
542{
543    async fn try_refresh_connection(&self) -> Result<bool> {
544        let Some(factory) = self.connection_factory.clone() else {
545            return Ok(false);
546        };
547
548        let new_client = factory
549            .create_connection()
550            .await
551            .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
552
553        let mut client = self.client.lock().await;
554        let old_token = client.cancellation_token();
555        old_token.cancel();
556        *client = new_client;
557        Ok(true)
558    }
559
560    async fn call_tool_with_retry(
561        &self,
562        params: CallToolRequestParams,
563    ) -> Result<rmcp::model::CallToolResult> {
564        let has_connection_factory = self.connection_factory.is_some();
565        let mut attempt = 0u32;
566
567        loop {
568            let call_result = {
569                let client = self.client.lock().await;
570                client.call_tool(params.clone()).await.map_err(|e| e.to_string())
571            };
572
573            match call_result {
574                Ok(result) => return Ok(result),
575                Err(error) => {
576                    if !should_retry_mcp_operation(
577                        &error,
578                        attempt,
579                        &self.refresh_config,
580                        has_connection_factory,
581                    ) {
582                        return Err(AdkError::tool(format!(
583                            "Failed to call MCP tool '{}': {error}",
584                            self.name
585                        )));
586                    }
587
588                    let retry_attempt = attempt + 1;
589                    if self.refresh_config.log_reconnections {
590                        warn!(
591                            tool = %self.name,
592                            attempt = retry_attempt,
593                            max_attempts = self.refresh_config.max_attempts,
594                            error = %error,
595                            "MCP call_tool failed; reconnecting and retrying"
596                        );
597                    }
598
599                    if self.refresh_config.retry_delay_ms > 0 {
600                        tokio::time::sleep(tokio::time::Duration::from_millis(
601                            self.refresh_config.retry_delay_ms,
602                        ))
603                        .await;
604                    }
605
606                    if !self.try_refresh_connection().await? {
607                        return Err(AdkError::tool(format!(
608                            "Failed to call MCP tool '{}': {error}",
609                            self.name
610                        )));
611                    }
612                    attempt += 1;
613                }
614            }
615        }
616    }
617
618    /// Poll a task until completion or timeout
619    async fn poll_task(&self, task_id: &str) -> std::result::Result<Value, TaskError> {
620        let start = Instant::now();
621        let mut attempts = 0u32;
622
623        loop {
624            // Check timeout
625            if let Some(timeout_ms) = self.task_config.timeout_ms {
626                let elapsed = start.elapsed().as_millis() as u64;
627                if elapsed >= timeout_ms {
628                    return Err(TaskError::Timeout {
629                        task_id: task_id.to_string(),
630                        elapsed_ms: elapsed,
631                    });
632                }
633            }
634
635            // Check max attempts
636            if let Some(max_attempts) = self.task_config.max_poll_attempts {
637                if attempts >= max_attempts {
638                    return Err(TaskError::MaxAttemptsExceeded {
639                        task_id: task_id.to_string(),
640                        attempts,
641                    });
642                }
643            }
644
645            // Wait before polling
646            tokio::time::sleep(self.task_config.poll_duration()).await;
647            attempts += 1;
648
649            debug!(task_id = task_id, attempt = attempts, "Polling MCP task status");
650
651            // Poll task status using tasks/get
652            // Note: This requires the MCP server to support SEP-1686 task lifecycle
653            let poll_result = self
654                .call_tool_with_retry(CallToolRequestParams::new("tasks/get").with_arguments(
655                    serde_json::Map::from_iter([(
656                        "task_id".to_string(),
657                        Value::String(task_id.to_string()),
658                    )]),
659                ))
660                .await
661                .map_err(|e| TaskError::PollFailed(e.to_string()))?;
662
663            // Parse task status from response
664            let status = self.parse_task_status(&poll_result)?;
665
666            match status {
667                TaskStatus::Completed => {
668                    debug!(task_id = task_id, "Task completed successfully");
669                    // Extract result from the poll response
670                    return self.extract_task_result(&poll_result);
671                }
672                TaskStatus::Failed => {
673                    let error_msg = self.extract_error_message(&poll_result);
674                    return Err(TaskError::TaskFailed {
675                        task_id: task_id.to_string(),
676                        error: error_msg,
677                    });
678                }
679                TaskStatus::Cancelled => {
680                    return Err(TaskError::Cancelled(task_id.to_string()));
681                }
682                TaskStatus::Pending | TaskStatus::Running => {
683                    // Continue polling
684                    debug!(
685                        task_id = task_id,
686                        status = ?status,
687                        "Task still in progress"
688                    );
689                }
690            }
691        }
692    }
693
694    /// Parse task status from poll response
695    fn parse_task_status(
696        &self,
697        result: &rmcp::model::CallToolResult,
698    ) -> std::result::Result<TaskStatus, TaskError> {
699        // Try to extract status from structured content first
700        if let Some(ref structured) = result.structured_content {
701            if let Some(status_str) = structured.get("status").and_then(|v| v.as_str()) {
702                return match status_str {
703                    "pending" => Ok(TaskStatus::Pending),
704                    "running" => Ok(TaskStatus::Running),
705                    "completed" => Ok(TaskStatus::Completed),
706                    "failed" => Ok(TaskStatus::Failed),
707                    "cancelled" => Ok(TaskStatus::Cancelled),
708                    _ => {
709                        warn!(status = status_str, "Unknown task status");
710                        Ok(TaskStatus::Running) // Assume still running
711                    }
712                };
713            }
714        }
715
716        // Try to extract from text content
717        for content in &result.content {
718            if let Some(text_content) = content.deref().as_text() {
719                // Try to parse as JSON
720                if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
721                    if let Some(status_str) = parsed.get("status").and_then(|v| v.as_str()) {
722                        return match status_str {
723                            "pending" => Ok(TaskStatus::Pending),
724                            "running" => Ok(TaskStatus::Running),
725                            "completed" => Ok(TaskStatus::Completed),
726                            "failed" => Ok(TaskStatus::Failed),
727                            "cancelled" => Ok(TaskStatus::Cancelled),
728                            _ => Ok(TaskStatus::Running),
729                        };
730                    }
731                }
732            }
733        }
734
735        // Default to running if we can't determine status
736        Ok(TaskStatus::Running)
737    }
738
739    /// Extract result from completed task
740    fn extract_task_result(
741        &self,
742        result: &rmcp::model::CallToolResult,
743    ) -> std::result::Result<Value, TaskError> {
744        // Try structured content first
745        if let Some(ref structured) = result.structured_content {
746            if let Some(output) = structured.get("result") {
747                return Ok(json!({ "output": output }));
748            }
749            return Ok(json!({ "output": structured }));
750        }
751
752        // Fall back to text content
753        let mut text_parts: Vec<String> = Vec::new();
754        for content in &result.content {
755            if let Some(text_content) = content.deref().as_text() {
756                text_parts.push(text_content.text.clone());
757            }
758        }
759
760        if text_parts.is_empty() {
761            Ok(json!({ "output": null }))
762        } else {
763            Ok(json!({ "output": text_parts.join("\n") }))
764        }
765    }
766
767    /// Extract error message from failed task
768    fn extract_error_message(&self, result: &rmcp::model::CallToolResult) -> String {
769        // Try structured content
770        if let Some(ref structured) = result.structured_content {
771            if let Some(error) = structured.get("error").and_then(|v| v.as_str()) {
772                return error.to_string();
773            }
774        }
775
776        // Try text content
777        for content in &result.content {
778            if let Some(text_content) = content.deref().as_text() {
779                return text_content.text.clone();
780            }
781        }
782
783        "Unknown error".to_string()
784    }
785
786    /// Extract task ID from create task response
787    fn extract_task_id(
788        &self,
789        result: &rmcp::model::CallToolResult,
790    ) -> std::result::Result<String, TaskError> {
791        // Try structured content
792        if let Some(ref structured) = result.structured_content {
793            if let Some(task_id) = structured.get("task_id").and_then(|v| v.as_str()) {
794                return Ok(task_id.to_string());
795            }
796        }
797
798        // Try text content (might be JSON)
799        for content in &result.content {
800            if let Some(text_content) = content.deref().as_text() {
801                if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
802                    if let Some(task_id) = parsed.get("task_id").and_then(|v| v.as_str()) {
803                        return Ok(task_id.to_string());
804                    }
805                }
806            }
807        }
808
809        Err(TaskError::CreateFailed("No task_id in response".to_string()))
810    }
811}
812
813#[async_trait]
814impl<S> Tool for McpTool<S>
815where
816    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
817{
818    fn name(&self) -> &str {
819        &self.name
820    }
821
822    fn description(&self) -> &str {
823        &self.description
824    }
825
826    fn is_long_running(&self) -> bool {
827        self.is_long_running
828    }
829
830    fn parameters_schema(&self) -> Option<Value> {
831        self.input_schema.clone()
832    }
833
834    fn response_schema(&self) -> Option<Value> {
835        self.output_schema.clone()
836    }
837
838    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
839        // Determine if we should use task mode
840        let use_task_mode = self.task_config.enable_tasks && self.is_long_running;
841
842        if use_task_mode {
843            debug!(tool = self.name, "Executing tool in task mode (long-running)");
844
845            // Create task request with task parameters
846            let task_params = self.task_config.to_task_params();
847            let task_map = task_params.as_object().cloned();
848
849            let create_result = self
850                .call_tool_with_retry({
851                    let mut params = CallToolRequestParams::new(self.name.clone());
852                    if !(args.is_null() || args == json!({})) {
853                        match args {
854                            Value::Object(map) => {
855                                params = params.with_arguments(map);
856                            }
857                            _ => {
858                                return Err(AdkError::tool("Tool arguments must be an object"));
859                            }
860                        }
861                    }
862                    if let Some(task_map) = task_map {
863                        params = params.with_task(task_map);
864                    }
865                    params
866                })
867                .await?;
868
869            // Extract task ID
870            let task_id = self
871                .extract_task_id(&create_result)
872                .map_err(|e| AdkError::tool(format!("Failed to get task ID: {e}")))?;
873
874            debug!(tool = self.name, task_id = task_id, "Task created, polling for completion");
875
876            // Poll for completion
877            let result = self
878                .poll_task(&task_id)
879                .await
880                .map_err(|e| AdkError::tool(format!("Task execution failed: {e}")))?;
881
882            return Ok(result);
883        }
884
885        // Standard synchronous execution
886        let result = self
887            .call_tool_with_retry({
888                let mut params = CallToolRequestParams::new(self.name.clone());
889                if !(args.is_null() || args == json!({})) {
890                    match args {
891                        Value::Object(map) => {
892                            params = params.with_arguments(map);
893                        }
894                        _ => {
895                            return Err(AdkError::tool("Tool arguments must be an object"));
896                        }
897                    }
898                }
899                params
900            })
901            .await?;
902
903        // Check for error response
904        if result.is_error.unwrap_or(false) {
905            let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
906
907            // Extract error details from content
908            for content in &result.content {
909                // Use Deref to access the inner RawContent
910                if let Some(text_content) = content.deref().as_text() {
911                    error_msg.push_str(": ");
912                    error_msg.push_str(&text_content.text);
913                    break;
914                }
915            }
916
917            return Err(AdkError::tool(error_msg));
918        }
919
920        // Return structured content if available
921        if let Some(structured) = result.structured_content {
922            return Ok(json!({ "output": structured }));
923        }
924
925        // Otherwise, collect text content
926        let mut text_parts: Vec<String> = Vec::new();
927
928        for content in &result.content {
929            // Access the inner RawContent via Deref
930            let raw: &RawContent = content.deref();
931            match raw {
932                RawContent::Text(text_content) => {
933                    text_parts.push(text_content.text.clone());
934                }
935                RawContent::Image(image_content) => {
936                    // Return image data as base64
937                    text_parts.push(format!(
938                        "[Image: {} bytes, mime: {}]",
939                        image_content.data.len(),
940                        image_content.mime_type
941                    ));
942                }
943                RawContent::Resource(resource_content) => {
944                    let uri = match &resource_content.resource {
945                        ResourceContents::TextResourceContents { uri, .. } => uri,
946                        ResourceContents::BlobResourceContents { uri, .. } => uri,
947                    };
948                    text_parts.push(format!("[Resource: {}]", uri));
949                }
950                RawContent::Audio(_) => {
951                    text_parts.push("[Audio content]".to_string());
952                }
953                RawContent::ResourceLink(link) => {
954                    text_parts.push(format!("[ResourceLink: {}]", link.uri));
955                }
956            }
957        }
958
959        if text_parts.is_empty() {
960            return Err(AdkError::tool(format!("MCP tool '{}' returned no content", self.name)));
961        }
962
963        Ok(json!({ "output": text_parts.join("\n") }))
964    }
965}
966
967// McpTool<S> is Send + Sync when S: Send + Sync because all fields are
968// composed of Send + Sync primitives (String, Arc<Mutex<_>>, Arc<dyn Send + Sync>, etc.).
969// The compiler enforces this through the Tool trait bound (Tool: Send + Sync).
970// No unsafe impl needed — the previous unsafe impl was removed as unnecessary.
971
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `McpTool<S>` and `McpToolset<S>` are `Send + Sync`
    /// for a concrete `S`, with no `unsafe impl` involved.
    ///
    /// If any field stopped being `Send` or `Sync`, the two calls below would fail
    /// to type-check, so the build breaks rather than the test failing at runtime.
    /// This test replaced an earlier, unnecessary `unsafe impl Send/Sync`.
    #[test]
    fn mcp_tool_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}

        // `()` serves as the concrete service type. Per the earlier note, it
        // implements Service<RoleClient> through rmcp's ClientHandler blanket impl.
        assert_send_sync::<McpTool<()>>();
        assert_send_sync::<McpToolset<()>>();
    }

    #[test]
    fn test_should_retry_mcp_operation_reconnectable_errors() {
        let config = RefreshConfig::default().with_max_attempts(3);
        let cases = [("EOF", 0), ("connection reset by peer", 1)];
        for (error, attempt) in cases {
            assert!(
                should_retry_mcp_operation(error, attempt, &config, true),
                "expected a retry for {error:?} on attempt {attempt}"
            );
        }
    }

    #[test]
    fn test_should_retry_mcp_operation_stops_at_max_attempts() {
        let config = RefreshConfig::default().with_max_attempts(2);
        let retried = should_retry_mcp_operation("EOF", 2, &config, true);
        assert!(!retried, "attempt 2 must not retry when max_attempts is 2");
    }

    #[test]
    fn test_should_retry_mcp_operation_requires_factory() {
        let config = RefreshConfig::default().with_max_attempts(3);
        let retried = should_retry_mcp_operation("EOF", 0, &config, false);
        assert!(!retried, "no retry is possible without a connection factory");
    }

    #[test]
    fn test_should_retry_mcp_operation_non_reconnectable_error() {
        let config = RefreshConfig::default().with_max_attempts(3);
        let retried = should_retry_mcp_operation("invalid arguments for tool", 0, &config, true);
        assert!(!retried, "argument errors are not connection failures");
    }
}